"""Data access layer for the trading database.

Provides CRUD operations for positions, trades, daily performance, and bot
state using synchronous sqlite3.
"""
from __future__ import annotations

import json
import sqlite3
from datetime import date, datetime, timezone
from typing import Dict, List, Optional

from loguru import logger

from config import settings
from database.models import (
    DailyPerformance,
    PositionRecord,
    TradeRecord,
    init_db,
)


class TradingRepository:
    """Synchronous repository for all trading data.

    The connection is opened lazily on first access to :attr:`conn` (or
    explicitly via :meth:`connect`), and every write method commits
    immediately.
    """

    def __init__(self, db_path: str | None = None):
        # Fall back to the configured database path when none is supplied.
        self._db_path = db_path or settings.DB_PATH
        self._conn: Optional[sqlite3.Connection] = None

    def connect(self) -> None:
        """Open and initialise the database connection.

        Any connection already held by this repository is closed first so
        that reconnecting does not leak the previous handle.
        """
        self.close()
        self._conn = init_db(self._db_path)
        # sqlite3.Row lets the mappers below read columns by name.
        self._conn.row_factory = sqlite3.Row

    def close(self) -> None:
        """Close the connection if one is open.

        The handle is reset to ``None`` so a later access to :attr:`conn`
        reconnects instead of operating on a closed connection.
        """
        if self._conn is not None:
            self._conn.close()
            self._conn = None

    @property
    def conn(self) -> sqlite3.Connection:
        """Return the open connection, connecting lazily on first use."""
        if self._conn is None:
            self.connect()
        return self._conn  # type: ignore

    # ------------------------------------------------------------------
    # Positions
    # ------------------------------------------------------------------

    def save_position(self, pos: PositionRecord) -> None:
        """Insert or replace a position row (keyed by ``pos.id``) and commit."""
        self.conn.execute(
            """INSERT OR REPLACE INTO positions
               (id, symbol, direction, entry_price, amount, stop_loss,
                take_profit, trailing_stop, realized_pnl, status,
                opened_at, closed_at, close_reason, confluence_score,
                entry_reasons)
               VALUES (?,?,?,?,?,?,?,?,?,?,?,?,?,?,?)""",
            (
                pos.id,
                pos.symbol,
                pos.direction,
                pos.entry_price,
                pos.amount,
                pos.stop_loss,
                pos.take_profit,
                pos.trailing_stop,
                pos.realized_pnl,
                pos.status,
                pos.opened_at,
                pos.closed_at,
                pos.close_reason,
                pos.confluence_score,
                pos.entry_reasons,
            ),
        )
        self.conn.commit()

    def get_position(self, position_id: str) -> Optional[PositionRecord]:
        """Return the position with ``position_id``, or ``None`` if absent."""
        row = self.conn.execute(
            "SELECT * FROM positions WHERE id = ?", (position_id,)
        ).fetchone()
        return self._row_to_position(row) if row else None

    def get_open_positions(self) -> List[PositionRecord]:
        """Return every position whose status is ``'OPEN'``."""
        rows = self.conn.execute(
            "SELECT * FROM positions WHERE status = 'OPEN'"
        ).fetchall()
        return [self._row_to_position(r) for r in rows]

    def get_closed_positions(self, limit: int = 100) -> List[PositionRecord]:
        """Return up to ``limit`` closed positions, most recently closed first."""
        rows = self.conn.execute(
            "SELECT * FROM positions WHERE status = 'CLOSED' ORDER BY closed_at DESC LIMIT ?",
            (limit,),
        ).fetchall()
        return [self._row_to_position(r) for r in rows]

    def _row_to_position(self, row: sqlite3.Row) -> PositionRecord:
        """Map a ``positions`` row to a :class:`PositionRecord`.

        NULL ``opened_at``, ``confluence_score`` and ``entry_reasons`` are
        replaced by ``""``, ``0`` and ``"[]"`` respectively.
        """
        return PositionRecord(
            id=row["id"],
            symbol=row["symbol"],
            direction=row["direction"],
            entry_price=row["entry_price"],
            amount=row["amount"],
            stop_loss=row["stop_loss"],
            take_profit=row["take_profit"],
            trailing_stop=row["trailing_stop"],
            realized_pnl=row["realized_pnl"],
            status=row["status"],
            opened_at=row["opened_at"] or "",
            closed_at=row["closed_at"],
            close_reason=row["close_reason"],
            confluence_score=row["confluence_score"] or 0,
            entry_reasons=row["entry_reasons"] or "[]",
        )

    # ------------------------------------------------------------------
    # Trade Records
    # ------------------------------------------------------------------

    def save_trade(self, trade: TradeRecord) -> None:
        """Insert or replace a trade row (keyed by ``trade.id``) and commit."""
        self.conn.execute(
            """INSERT OR REPLACE INTO trade_records
               (id, position_id, symbol, side, order_type, price, amount,
                fee, timestamp)
               VALUES (?,?,?,?,?,?,?,?,?)""",
            (
                trade.id,
                trade.position_id,
                trade.symbol,
                trade.side,
                trade.order_type,
                trade.price,
                trade.amount,
                trade.fee,
                trade.timestamp,
            ),
        )
        self.conn.commit()

    def get_trades_for_position(self, position_id: str) -> List[TradeRecord]:
        """Return all trades for ``position_id``, ordered by timestamp."""
        rows = self.conn.execute(
            "SELECT * FROM trade_records WHERE position_id = ? ORDER BY timestamp",
            (position_id,),
        ).fetchall()
        return [self._row_to_trade(r) for r in rows]

    @staticmethod
    def _row_to_trade(row: sqlite3.Row) -> TradeRecord:
        """Map a ``trade_records`` row to a :class:`TradeRecord`.

        A NULL ``timestamp`` is replaced by ``""``.
        """
        return TradeRecord(
            id=row["id"],
            position_id=row["position_id"],
            symbol=row["symbol"],
            side=row["side"],
            order_type=row["order_type"],
            price=row["price"],
            amount=row["amount"],
            fee=row["fee"],
            timestamp=row["timestamp"] or "",
        )

    # ------------------------------------------------------------------
    # Daily Performance
    # ------------------------------------------------------------------

    def update_daily_performance(
        self, pnl: float, is_win: bool, max_dd: float = 0.0
    ) -> None:
        """Record one finished trade in today's performance row and commit.

        Creates today's row if it does not exist yet; otherwise increments
        its counters and accumulates ``pnl``. ``max_drawdown`` keeps the
        larger of the stored value and ``max_dd``.

        Args:
            pnl: Profit/loss of the trade, added to ``total_pnl``.
            is_win: Counts the trade as a win when true, else as a loss.
            max_dd: Drawdown candidate compared against the stored maximum.
        """
        # NOTE(review): this is the *local* calendar date, whereas
        # set_state() stamps UTC — confirm which day boundary is intended.
        today = date.today().isoformat()
        wins = 1 if is_win else 0
        losses = 1 - wins

        # Only existence matters here, so avoid fetching the whole row.
        existing = self.conn.execute(
            "SELECT 1 FROM daily_performance WHERE date = ?", (today,)
        ).fetchone()

        if existing:
            self.conn.execute(
                """UPDATE daily_performance SET
                   total_trades = total_trades + 1,
                   winning_trades = winning_trades + ?,
                   losing_trades = losing_trades + ?,
                   total_pnl = total_pnl + ?,
                   max_drawdown = MAX(max_drawdown, ?)
                   WHERE date = ?""",
                (wins, losses, pnl, max_dd, today),
            )
        else:
            self.conn.execute(
                """INSERT INTO daily_performance
                   (date, total_trades, winning_trades, losing_trades,
                    total_pnl, max_drawdown)
                   VALUES (?,1,?,?,?,?)""",
                (today, wins, losses, pnl, max_dd),
            )
        self.conn.commit()

    def get_daily_performance(self, day: str | None = None) -> Optional[DailyPerformance]:
        """Return the performance row for ``day`` (ISO date; default today).

        Returns ``None`` when no row exists for that date.
        """
        day = day or date.today().isoformat()
        row = self.conn.execute(
            "SELECT * FROM daily_performance WHERE date = ?", (day,)
        ).fetchone()
        if not row:
            return None
        return self._row_to_daily_performance(row)

    def get_performance_history(self, days: int = 30) -> List[DailyPerformance]:
        """Return up to ``days`` performance rows, newest date first."""
        rows = self.conn.execute(
            "SELECT * FROM daily_performance ORDER BY date DESC LIMIT ?", (days,)
        ).fetchall()
        return [self._row_to_daily_performance(r) for r in rows]

    @staticmethod
    def _row_to_daily_performance(row: sqlite3.Row) -> DailyPerformance:
        """Map a ``daily_performance`` row to a :class:`DailyPerformance`."""
        return DailyPerformance(
            date=row["date"],
            total_trades=row["total_trades"],
            winning_trades=row["winning_trades"],
            losing_trades=row["losing_trades"],
            total_pnl=row["total_pnl"],
            max_drawdown=row["max_drawdown"],
        )

    # ------------------------------------------------------------------
    # Bot State
    # ------------------------------------------------------------------

    def set_state(self, key: str, value: str) -> None:
        """Insert or overwrite the bot-state entry ``key`` and commit.

        ``updated_at`` is stamped with the current UTC time as a naive ISO
        string.
        """
        # datetime.utcnow() is deprecated since Python 3.12; dropping tzinfo
        # keeps the exact same naive ISO format as previously stored rows.
        updated_at = datetime.now(timezone.utc).replace(tzinfo=None).isoformat()
        self.conn.execute(
            """INSERT OR REPLACE INTO bot_state (key, value, updated_at)
               VALUES (?, ?, ?)""",
            (key, value, updated_at),
        )
        self.conn.commit()

    def get_state(self, key: str) -> Optional[str]:
        """Return the stored value for ``key``, or ``None`` if unset."""
        row = self.conn.execute(
            "SELECT value FROM bot_state WHERE key = ?", (key,)
        ).fetchone()
        return row["value"] if row else None