import pandas as pd
from backtrader import TimeFrame, date2num
from sqlalchemy import create_engine, inspect
from tqdm import tqdm
from koapy.backtrader.SQLiteData import SQLiteData
from koapy.utils.data.KrxHistoricalDailyPriceDataForBacktestLoader import (
    KrxHistoricalDailyPriceDataForBacktestLoader,
)
class KrxHistoricalDailyPriceDataFromSQLite(SQLiteData):
    """Backtrader data feed for KRX historical daily prices stored in SQLite.

    Each row produced by the underlying cursor is expected to contain, in
    order: date, open, high, low, close, volume, amount, market cap and
    number of shares.  The three extra values are exposed as the additional
    lines ``amount``, ``marketcap`` and ``shares``.

    Only daily bars without compression are supported (enforced in
    ``__init__``).

    NOTE(review): ``tablename`` and ``_cursor`` are not declared in this
    class; they are presumably provided by ``SQLiteData`` -- confirm there.
    """

    # pylint: disable=no-member

    params = (
        ("engine", None),
        ("symbol", None),
        ("name", None),
        ("fromdate", None),
        ("todate", None),
        ("compression", 1),
        ("timeframe", TimeFrame.Days),
        ("calendar", None),
        ("timestampcolumn", 0),
        ("timestampcolumntimezone", None),
        ("lazy", False),
    )

    # Extra lines on top of the ones the base feed already defines.
    lines = (
        "amount",
        "marketcap",
        "shares",
    )

    def __init__(self):
        """Validate the feed parameters and fill in table/feed name defaults.

        The table name falls back to the symbol, and the feed name falls back
        to the symbol, then the table name, then an empty string.
        """
        assert (
            self.p.timeframe == TimeFrame.Days
        ), "only TimeFrame.Days is supported"
        assert self.p.compression == 1, "only compression == 1 is supported"

        self.p.tablename = self.p.tablename or self.p.symbol or None
        self.p.name = self.p.name or self.p.symbol or self.p.tablename or ""

        super().__init__()

    def _load(self):
        """Load the next row from the cursor into the current bar.

        Returns:
            ``True`` when a bar was loaded, ``False`` when there is no cursor
            or the cursor is exhausted.
        """
        if self._cursor is None:
            return False

        # Keep the try body minimal: only the fetch can signal exhaustion.
        try:
            row = next(self._cursor)
        except StopIteration:
            return False

        date, open_, high, low, close, volume, amount, marketcap, shares = row

        self.lines.datetime[0] = date2num(pd.Timestamp(date))
        self.lines.open[0] = open_
        self.lines.high[0] = high
        self.lines.low[0] = low
        self.lines.close[0] = close
        self.lines.volume[0] = volume
        # Open interest is not part of the row; it is always set to zero.
        self.lines.openinterest[0] = 0.0
        self.lines.amount[0] = amount
        self.lines.marketcap[0] = marketcap
        self.lines.shares[0] = shares
        return True

    @classmethod
    def dump_from_store(
        cls,
        source_filename,
        dest_filename,
        symbols=None,
        fromdate=None,
        todate=None,
        progress_bar=True,
    ):
        """Copy per-symbol price data from a source store into a SQLite file.

        One table per symbol is written to ``dest_filename``; an existing
        table with the same name is replaced.

        Args:
            source_filename: Passed to
                ``KrxHistoricalDailyPriceDataForBacktestLoader``.
            dest_filename: Path of the SQLite database file to write.
            symbols: Symbols to dump.  When ``None``, every symbol reported
                by the loader's ``get_symbols()`` is dumped.
            fromdate: Forwarded to the loader as ``start_time``.
            todate: Forwarded to the loader as ``end_time``.
            progress_bar: Show a ``tqdm`` progress bar when true.
        """
        loader = KrxHistoricalDailyPriceDataForBacktestLoader(source_filename)
        if symbols is None:
            symbols = loader.get_symbols()

        engine = create_engine("sqlite:///" + dest_filename)
        try:
            progress = tqdm(symbols, disable=not progress_bar)
            for symbol in progress:
                # Wrap in a tuple so a non-string symbol cannot be
                # misinterpreted as the format arguments.
                progress.set_description("Dumping Symbol [%s]" % (symbol,))
                data = loader.load(symbol, start_time=fromdate, end_time=todate)
                data.to_sql(symbol, engine, if_exists="replace")
        finally:
            # The engine is local to this call; release its connections even
            # when a load or write fails part-way through.
            engine.dispose()

    @classmethod
    def adddata_fromfile(
        cls,
        cerebro,
        filename,
        symbols=None,
        fromdate=None,
        todate=None,
        progress_bar=True,
    ):
        """Create one feed per symbol from a SQLite file and add it to cerebro.

        Args:
            cerebro: Object whose ``adddata(data, name=...)`` is called for
                every created feed.
            filename: Path of the SQLite database file to read.
            symbols: Table names to load.  When ``None``, every table found
                in the database is used.
            fromdate: Passed to each feed as ``fromdate``.
            todate: Passed to each feed as ``todate``.
            progress_bar: Show a ``tqdm`` progress bar when true.
        """
        # The engine is handed to every feed, so it is deliberately not
        # disposed here.
        engine = create_engine("sqlite:///" + filename)
        if symbols is None:
            symbols = inspect(engine).get_table_names()

        progress = tqdm(symbols, disable=not progress_bar)
        for symbol in progress:
            progress.set_description("Adding Symbol [%s]" % (symbol,))
            # pylint: disable=unexpected-keyword-arg
            data = cls(
                engine=engine,
                tablename=symbol,
                fromdate=fromdate,
                todate=todate,
                symbol=symbol,
                name=symbol,
            )
            cerebro.adddata(data, name=data.p.name)