"""SQLite trade-log database backed by sqlite-utils. Provides persistent storage for trades, window snapshots, and daily performance summaries with lightweight query helpers. """ from __future__ import annotations import time from datetime import datetime, timezone from typing import Any, Optional import structlog from sqlite_utils import Database from src.data.models import Trade, WindowState logger = structlog.get_logger(__name__) class TradeDB: """Thin wrapper around a SQLite database for trade logging and analytics. Parameters ---------- db_path: Path to the SQLite database file. Use ``":memory:"`` for tests. """ def __init__(self, db_path: str = "trades.db") -> None: self._db_path = db_path self._db = Database(db_path) self._log = logger.bind(component="TradeDB", db=db_path) self._ensure_tables() self._log.info("database_ready") # ------------------------------------------------------------------ # Lifecycle # ------------------------------------------------------------------ def close(self) -> None: """Flush WAL and close the underlying SQLite connection.""" try: if self._db and hasattr(self._db, "conn") and self._db.conn: self._db.conn.execute("PRAGMA wal_checkpoint(TRUNCATE)") self._db.conn.close() self._log.info("database_closed") except Exception: self._log.exception("database_close_error") def __enter__(self): return self def __exit__(self, *exc): self.close() # ------------------------------------------------------------------ # Schema # ------------------------------------------------------------------ def _ensure_tables(self) -> None: """Create tables if they do not already exist.""" if "trades" not in self._db.table_names(): self._db["trades"].create( { "id": str, "asset": str, "timeframe": str, "direction": str, "token_id": str, "entry_price": float, "fill_price": float, "size": int, "fee": float, "pnl": float, "status": str, "signal_edge": float, "signal_prob": float, "created_at": float, "updated_at": float, }, pk="id", ) self._log.debug("table_created", table="trades") if "window_snapshots" not in self._db.table_names(): self._db["window_snapshots"].create( { "id": int, "asset": str, "timeframe": str, "start_price": float, "end_price": float, "price_change_pct": float, "window_start": float, "window_end": float, "market_condition_id": str, "created_at": float, }, pk="id", ) self._log.debug("table_created", table="window_snapshots") if "daily_summary" not in self._db.table_names(): self._db["daily_summary"].create( { "date": str, "total_trades": int, "wins": int, "losses": int, "total_pnl": float, "total_fees": float, "total_volume": float, "best_trade_pnl": float, "worst_trade_pnl": float, }, pk="date", ) self._log.debug("table_created", table="daily_summary") if "balance_history" not in self._db.table_names(): self._db["balance_history"].create( { "id": int, "timestamp": float, "balance": float, "pnl": float, "event": str, }, pk="id", ) self._log.debug("table_created", table="balance_history") if "oracle_snapshots" not in self._db.table_names(): self._db["oracle_snapshots"].create( { "id": int, "timestamp": float, "asset": str, "oracle_price": float, "cex_price": float, "deviation_pct": float, "oracle_lag_sec": float, "oracle_round_id": str, }, pk="id", ) self._log.debug("table_created", table="oracle_snapshots") # ------------------------------------------------------------------ # Trade CRUD # ------------------------------------------------------------------ def log_trade(self, trade: Trade) -> None: """Insert a new trade record derived from a :class:`Trade` dataclass.""" row = { "id": trade.id, "asset": trade.signal.asset.value, "timeframe": trade.signal.timeframe.value, "direction": trade.signal.direction.value, "token_id": trade.signal.token_id, "entry_price": trade.signal.price, "fill_price": trade.fill_price, "size": trade.signal.size, "fee": trade.fee, "pnl": trade.pnl, "status": trade.status.value, "signal_edge": trade.signal.edge, "signal_prob": trade.signal.estimated_prob, "created_at": trade.created_at, "updated_at": trade.updated_at, } self._db["trades"].insert(row) self._log.info("trade_logged", trade_id=trade.id, asset=row["asset"]) def update_trade(self, trade_id: str, **fields: Any) -> None: """Update arbitrary fields on an existing trade row.""" fields["updated_at"] = time.time() self._db["trades"].update(trade_id, fields) self._log.info("trade_updated", trade_id=trade_id, fields=list(fields.keys())) def get_trades( self, start_time: Optional[float] = None, end_time: Optional[float] = None, asset: Optional[str] = None, ) -> list[dict[str, Any]]: """Return trades matching optional time-range and asset filters.""" clauses: list[str] = [] params: list[Any] = [] if start_time is not None: clauses.append("created_at >= ?") params.append(start_time) if end_time is not None: clauses.append("created_at <= ?") params.append(end_time) if asset is not None: clauses.append("asset = ?") params.append(asset) where = " AND ".join(clauses) if clauses else "1=1" sql = f"SELECT * FROM trades WHERE {where} ORDER BY created_at DESC" return list(self._db.execute(sql, params).fetchall()) # ------------------------------------------------------------------ # Window snapshots # ------------------------------------------------------------------ def log_window(self, window: WindowState) -> None: """Snapshot a completed window into the database.""" row = { "asset": window.asset.value, "timeframe": window.timeframe.value, "start_price": window.start_price, "end_price": window.current_price, "price_change_pct": window.price_change_pct, "window_start": window.window_start_time, "window_end": window.window_end_time, "market_condition_id": ( window.market.condition_id if window.market else None ), "created_at": time.time(), } self._db["window_snapshots"].insert(row) self._log.info( "window_logged", asset=row["asset"], timeframe=row["timeframe"], change_pct=row["price_change_pct"], ) # ------------------------------------------------------------------ # Daily summary helpers # ------------------------------------------------------------------ def get_daily_summary(self, date: str) -> dict[str, Any]: """Compute (but do not store) daily stats for *date* (``YYYY-MM-DD``).""" sql = """ SELECT COUNT(*) AS total_trades, SUM(CASE WHEN pnl > 0 THEN 1 ELSE 0 END) AS wins, SUM(CASE WHEN pnl <= 0 THEN 1 ELSE 0 END) AS losses, COALESCE(SUM(pnl), 0.0) AS total_pnl, COALESCE(SUM(fee), 0.0) AS total_fees, COALESCE(SUM(size * fill_price), 0.0) AS total_volume, COALESCE(MAX(pnl), 0.0) AS best_trade_pnl, COALESCE(MIN(pnl), 0.0) AS worst_trade_pnl FROM trades WHERE date(created_at, 'unixepoch') = ? """ row = self._db.execute(sql, [date]).fetchone() return { "date": date, "total_trades": row[0], "wins": row[1], "losses": row[2], "total_pnl": row[3], "total_fees": row[4], "total_volume": row[5], "best_trade_pnl": row[6], "worst_trade_pnl": row[7], } def update_daily_summary(self, date: str) -> None: """Recalculate and upsert the daily summary row for *date*.""" summary = self.get_daily_summary(date) self._db["daily_summary"].upsert(summary, pk="date") self._log.info("daily_summary_updated", date=date, pnl=summary["total_pnl"]) # ------------------------------------------------------------------ # Quick-access aggregates # ------------------------------------------------------------------ # ------------------------------------------------------------------ # Balance history # ------------------------------------------------------------------ def log_balance(self, balance: float, pnl: float, event: str = "update") -> None: """Record a balance snapshot.""" self._db["balance_history"].insert({ "timestamp": time.time(), "balance": balance, "pnl": pnl, "event": event, }) def log_oracle( self, asset: str, oracle_price: float, cex_price: float, deviation_pct: float, oracle_lag_sec: float, oracle_round_id: int = 0, ) -> None: """Record an oracle price snapshot.""" self._db["oracle_snapshots"].insert({ "timestamp": time.time(), "asset": asset, "oracle_price": oracle_price, "cex_price": cex_price, "deviation_pct": round(deviation_pct, 4), "oracle_lag_sec": round(oracle_lag_sec, 1), "oracle_round_id": str(oracle_round_id), }) def get_latest_balance(self) -> Optional[float]: """Return the most recent balance, or None.""" row = self._db.execute( "SELECT balance FROM balance_history ORDER BY timestamp DESC LIMIT 1" ).fetchone() return float(row[0]) if row else None # ------------------------------------------------------------------ # Quick-access aggregates # ------------------------------------------------------------------ def get_total_pnl(self) -> float: """Return the all-time cumulative PnL.""" row = self._db.execute("SELECT COALESCE(SUM(pnl), 0.0) FROM trades").fetchone() return float(row[0]) def get_today_pnl(self) -> float: """Return today's cumulative PnL (UTC).""" today = datetime.now(timezone.utc).strftime("%Y-%m-%d") row = self._db.execute( "SELECT COALESCE(SUM(pnl), 0.0) FROM trades " "WHERE date(created_at, 'unixepoch') = ?", [today], ).fetchone() return float(row[0]) def get_today_trade_count(self) -> int: """Return the number of trades placed today (UTC).""" today = datetime.now(timezone.utc).strftime("%Y-%m-%d") row = self._db.execute( "SELECT COUNT(*) FROM trades " "WHERE date(created_at, 'unixepoch') = ?", [today], ).fetchone() return int(row[0]) # ------------------------------------------------------------------ # Maintenance # ------------------------------------------------------------------ def periodic_maintenance(self, retention_days: int = 7) -> None: """Prune old time-series data and checkpoint WAL.""" cutoff = time.time() - retention_days * 86400 pruned = {} for table in ("oracle_snapshots", "balance_history"): if table in self._db.table_names(): before = self._db.execute(f"SELECT COUNT(*) FROM {table}").fetchone()[0] self._db.execute(f"DELETE FROM {table} WHERE timestamp < ?", [cutoff]) after = self._db.execute(f"SELECT COUNT(*) FROM {table}").fetchone()[0] pruned[table] = before - after try: self._db.conn.execute("PRAGMA wal_checkpoint(TRUNCATE)") except Exception: pass self._log.info("db_maintenance", pruned=pruned, retention_days=retention_days)