324 lines
12 KiB
Python
324 lines
12 KiB
Python
"""SQLite trade-log database backed by sqlite-utils.
|
|
|
|
Provides persistent storage for trades, window snapshots, and daily
|
|
performance summaries with lightweight query helpers.
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import time
|
|
from datetime import datetime, timezone
|
|
from typing import Any, Optional
|
|
|
|
import structlog
|
|
from sqlite_utils import Database
|
|
|
|
from src.data.models import Trade, WindowState
|
|
|
|
# Module-level structlog logger; each TradeDB instance binds its own
# component/db context onto it (see TradeDB.__init__).
logger = structlog.get_logger(__name__)
|
|
|
|
|
|
class TradeDB:
    """Thin wrapper around a SQLite database for trade logging and analytics.

    Parameters
    ----------
    db_path:
        Path to the SQLite database file. Use ``":memory:"`` for tests.

    Notes
    -----
    ``created_at`` values are treated as Unix epoch seconds: every per-day
    query converts them with SQLite's ``'unixepoch'`` modifier, so days are
    bucketed by their UTC calendar date.
    """

    # Column definitions for every table this class manages, keyed by table
    # name and paired with the primary-key column. Missing tables are created
    # in this order on start-up; existing tables are never altered.
    _TABLE_SCHEMAS: dict[str, tuple[dict[str, type], str]] = {
        "trades": (
            {
                "id": str,
                "asset": str,
                "timeframe": str,
                "direction": str,
                "token_id": str,
                "entry_price": float,
                "fill_price": float,
                "size": int,
                "fee": float,
                "pnl": float,
                "status": str,
                "signal_edge": float,
                "signal_prob": float,
                "created_at": float,
                "updated_at": float,
            },
            "id",
        ),
        "window_snapshots": (
            {
                "id": int,
                "asset": str,
                "timeframe": str,
                "start_price": float,
                "end_price": float,
                "price_change_pct": float,
                "window_start": float,
                "window_end": float,
                "market_condition_id": str,
                "created_at": float,
            },
            "id",
        ),
        "daily_summary": (
            {
                "date": str,
                "total_trades": int,
                "wins": int,
                "losses": int,
                "total_pnl": float,
                "total_fees": float,
                "total_volume": float,
                "best_trade_pnl": float,
                "worst_trade_pnl": float,
            },
            "date",
        ),
        "balance_history": (
            {
                "id": int,
                "timestamp": float,
                "balance": float,
                "pnl": float,
                "event": str,
            },
            "id",
        ),
        "oracle_snapshots": (
            {
                "id": int,
                "timestamp": float,
                "asset": str,
                "oracle_price": float,
                "cex_price": float,
                "deviation_pct": float,
                "oracle_lag_sec": float,
                "oracle_round_id": str,
            },
            "id",
        ),
    }

    def __init__(self, db_path: str = "trades.db") -> None:
        self._db_path = db_path
        self._db = Database(db_path)
        self._log = logger.bind(component="TradeDB", db=db_path)
        self._ensure_tables()
        self._log.info("database_ready")

    # ------------------------------------------------------------------
    # Schema
    # ------------------------------------------------------------------

    def _ensure_tables(self) -> None:
        """Create any table from ``_TABLE_SCHEMAS`` that does not exist yet."""
        # Query sqlite_master once instead of once per table.
        existing = set(self._db.table_names())
        for table, (columns, pk) in self._TABLE_SCHEMAS.items():
            if table in existing:
                continue
            self._db[table].create(columns, pk=pk)
            self._log.debug("table_created", table=table)

    # ------------------------------------------------------------------
    # Trade CRUD
    # ------------------------------------------------------------------

    def log_trade(self, trade: Trade) -> None:
        """Insert a new trade record derived from a :class:`Trade` dataclass.

        Signal attributes (asset, timeframe, direction, size, edge, ...) are
        flattened into the ``trades`` row. Inserting an ``id`` that already
        exists raises, because this is a plain insert rather than an upsert.
        """
        signal = trade.signal
        row = {
            "id": trade.id,
            "asset": signal.asset.value,
            "timeframe": signal.timeframe.value,
            "direction": signal.direction.value,
            "token_id": signal.token_id,
            "entry_price": signal.price,
            "fill_price": trade.fill_price,
            "size": signal.size,
            "fee": trade.fee,
            "pnl": trade.pnl,
            "status": trade.status.value,
            "signal_edge": signal.edge,
            "signal_prob": signal.estimated_prob,
            "created_at": trade.created_at,
            "updated_at": trade.updated_at,
        }
        self._db["trades"].insert(row)
        self._log.info("trade_logged", trade_id=trade.id, asset=row["asset"])

    def update_trade(self, trade_id: str, **fields: Any) -> None:
        """Update arbitrary columns on an existing trade row.

        ``updated_at`` is always overwritten with the current wall-clock
        time, even if the caller passed one in ``fields``.
        """
        fields["updated_at"] = time.time()
        self._db["trades"].update(trade_id, fields)
        self._log.info("trade_updated", trade_id=trade_id, fields=list(fields.keys()))

    def get_trades(
        self,
        start_time: Optional[float] = None,
        end_time: Optional[float] = None,
        asset: Optional[str] = None,
    ) -> list[dict[str, Any]]:
        """Return trades matching optional time-range and asset filters.

        Parameters
        ----------
        start_time, end_time:
            Inclusive bounds on ``created_at``; ``None`` leaves that side open.
        asset:
            Exact match on the ``asset`` column; ``None`` means any asset.

        Returns
        -------
        list[dict[str, Any]]
            One ``{column: value}`` dict per trade, newest first.
        """
        clauses: list[str] = []
        params: list[Any] = []

        if start_time is not None:
            clauses.append("created_at >= ?")
            params.append(start_time)
        if end_time is not None:
            clauses.append("created_at <= ?")
            params.append(end_time)
        if asset is not None:
            clauses.append("asset = ?")
            params.append(asset)

        # Only fixed clause text is interpolated; every value is a bound param.
        where = " AND ".join(clauses) if clauses else "1=1"
        sql = f"SELECT * FROM trades WHERE {where} ORDER BY created_at DESC"
        # Database.query yields dicts keyed by column name, matching the
        # declared return type. The previous execute().fetchall() returned
        # bare tuples.
        return list(self._db.query(sql, params))

    # ------------------------------------------------------------------
    # Window snapshots
    # ------------------------------------------------------------------

    def log_window(self, window: WindowState) -> None:
        """Snapshot a completed window into the database.

        The window's ``current_price`` is stored as ``end_price``.
        ``market_condition_id`` is NULL when the window has no market attached.
        """
        row = {
            "asset": window.asset.value,
            "timeframe": window.timeframe.value,
            "start_price": window.start_price,
            "end_price": window.current_price,
            "price_change_pct": window.price_change_pct,
            "window_start": window.window_start_time,
            "window_end": window.window_end_time,
            "market_condition_id": (
                window.market.condition_id if window.market else None
            ),
            "created_at": time.time(),
        }
        self._db["window_snapshots"].insert(row)
        self._log.info(
            "window_logged",
            asset=row["asset"],
            timeframe=row["timeframe"],
            change_pct=row["price_change_pct"],
        )

    # ------------------------------------------------------------------
    # Daily summary helpers
    # ------------------------------------------------------------------

    def get_daily_summary(self, date: str) -> dict[str, Any]:
        """Compute (but do not store) daily stats for *date* (``YYYY-MM-DD``, UTC).

        A trade counts as a win when ``pnl > 0`` and as a loss when
        ``pnl <= 0``, so break-even trades are losses. Rows whose ``pnl`` is
        NULL appear only in ``total_trades``. Every numeric field is 0 or 0.0
        on a day with no trades, never ``None``.
        """
        # SUM() over zero rows yields NULL in SQLite, hence the COALESCE on
        # every aggregate, including the win/loss counters.
        sql = """
            SELECT
                COUNT(*) AS total_trades,
                COALESCE(SUM(CASE WHEN pnl > 0 THEN 1 ELSE 0 END), 0) AS wins,
                COALESCE(SUM(CASE WHEN pnl <= 0 THEN 1 ELSE 0 END), 0) AS losses,
                COALESCE(SUM(pnl), 0.0) AS total_pnl,
                COALESCE(SUM(fee), 0.0) AS total_fees,
                COALESCE(SUM(size * fill_price), 0.0) AS total_volume,
                COALESCE(MAX(pnl), 0.0) AS best_trade_pnl,
                COALESCE(MIN(pnl), 0.0) AS worst_trade_pnl
            FROM trades
            WHERE date(created_at, 'unixepoch') = ?
        """
        row = self._db.execute(sql, [date]).fetchone()
        return {
            "date": date,
            "total_trades": row[0],
            "wins": row[1],
            "losses": row[2],
            "total_pnl": row[3],
            "total_fees": row[4],
            "total_volume": row[5],
            "best_trade_pnl": row[6],
            "worst_trade_pnl": row[7],
        }

    def update_daily_summary(self, date: str) -> None:
        """Recalculate and upsert the ``daily_summary`` row for *date*."""
        summary = self.get_daily_summary(date)
        self._db["daily_summary"].upsert(summary, pk="date")
        self._log.info("daily_summary_updated", date=date, pnl=summary["total_pnl"])

    # ------------------------------------------------------------------
    # Balance history
    # ------------------------------------------------------------------

    def log_balance(self, balance: float, pnl: float, event: str = "update") -> None:
        """Record a balance snapshot stamped with the current wall-clock time."""
        self._db["balance_history"].insert({
            "timestamp": time.time(),
            "balance": balance,
            "pnl": pnl,
            "event": event,
        })

    def get_latest_balance(self) -> Optional[float]:
        """Return the most recently recorded balance, or ``None`` if there is none."""
        # Break timestamp ties with the autoincrementing id so the latest
        # insert wins deterministically when two snapshots share a timestamp.
        row = self._db.execute(
            "SELECT balance FROM balance_history "
            "ORDER BY timestamp DESC, id DESC LIMIT 1"
        ).fetchone()
        return float(row[0]) if row else None

    # ------------------------------------------------------------------
    # Oracle snapshots
    # ------------------------------------------------------------------

    def log_oracle(
        self,
        asset: str,
        oracle_price: float,
        cex_price: float,
        deviation_pct: float,
        oracle_lag_sec: float,
        oracle_round_id: int = 0,
    ) -> None:
        """Record an oracle-vs-CEX price snapshot.

        ``deviation_pct`` is rounded to 4 decimals and ``oracle_lag_sec`` to 1.
        ``oracle_round_id`` is stored as text, which matches the column type.
        """
        self._db["oracle_snapshots"].insert({
            "timestamp": time.time(),
            "asset": asset,
            "oracle_price": oracle_price,
            "cex_price": cex_price,
            "deviation_pct": round(deviation_pct, 4),
            "oracle_lag_sec": round(oracle_lag_sec, 1),
            "oracle_round_id": str(oracle_round_id),
        })

    # ------------------------------------------------------------------
    # Quick-access aggregates
    # ------------------------------------------------------------------

    @staticmethod
    def _utc_today() -> str:
        """Return today's UTC date as ``YYYY-MM-DD``, the format the day queries use."""
        return datetime.now(timezone.utc).strftime("%Y-%m-%d")

    def get_total_pnl(self) -> float:
        """Return the all-time cumulative PnL (0.0 when there are no trades)."""
        row = self._db.execute("SELECT COALESCE(SUM(pnl), 0.0) FROM trades").fetchone()
        return float(row[0])

    def get_today_pnl(self) -> float:
        """Return today's cumulative PnL (UTC day, 0.0 when there are no trades)."""
        row = self._db.execute(
            "SELECT COALESCE(SUM(pnl), 0.0) FROM trades "
            "WHERE date(created_at, 'unixepoch') = ?",
            [self._utc_today()],
        ).fetchone()
        return float(row[0])

    def get_today_trade_count(self) -> int:
        """Return the number of trades created today (UTC day)."""
        row = self._db.execute(
            "SELECT COUNT(*) FROM trades "
            "WHERE date(created_at, 'unixepoch') = ?",
            [self._utc_today()],
        ).fetchone()
        return int(row[0])