Files
polymarket-arb-bot/src/data/db.py
2026-03-22 09:28:14 +09:00

324 lines
12 KiB
Python

"""SQLite trade-log database backed by sqlite-utils.

Provides persistent storage for trades, window snapshots, daily
performance summaries, balance history and oracle price snapshots,
together with lightweight query helpers over those tables.
"""
from __future__ import annotations

import time
from datetime import datetime, timezone
from typing import Any, Optional

import structlog
from sqlite_utils import Database

from src.data.models import Trade, WindowState

# Module-level structured logger; each TradeDB instance binds its own
# context (component name and db path) onto it.
logger = structlog.get_logger(__name__)
class TradeDB:
    """Thin wrapper around a SQLite database for trade logging and analytics.

    Construction opens the database at *db_path* and creates any of the
    tables described in ``_TABLE_SCHEMAS`` that do not exist yet.

    Parameters
    ----------
    db_path:
        Path to the SQLite database file. Use ``":memory:"`` for tests.
    """

    # Table name -> (column name -> Python type, primary-key column), in the
    # form sqlite-utils' ``Table.create`` accepts. ``_ensure_tables`` walks
    # this mapping in order; tables that already exist are not altered.
    _TABLE_SCHEMAS = {
        "trades": (
            {
                "id": str,
                "asset": str,
                "timeframe": str,
                "direction": str,
                "token_id": str,
                "entry_price": float,
                "fill_price": float,
                "size": int,
                "fee": float,
                "pnl": float,
                "status": str,
                "signal_edge": float,
                "signal_prob": float,
                "created_at": float,
                "updated_at": float,
            },
            "id",
        ),
        "window_snapshots": (
            {
                "id": int,
                "asset": str,
                "timeframe": str,
                "start_price": float,
                "end_price": float,
                "price_change_pct": float,
                "window_start": float,
                "window_end": float,
                "market_condition_id": str,
                "created_at": float,
            },
            "id",
        ),
        "daily_summary": (
            {
                "date": str,
                "total_trades": int,
                "wins": int,
                "losses": int,
                "total_pnl": float,
                "total_fees": float,
                "total_volume": float,
                "best_trade_pnl": float,
                "worst_trade_pnl": float,
            },
            "date",
        ),
        "balance_history": (
            {
                "id": int,
                "timestamp": float,
                "balance": float,
                "pnl": float,
                "event": str,
            },
            "id",
        ),
        "oracle_snapshots": (
            {
                "id": int,
                "timestamp": float,
                "asset": str,
                "oracle_price": float,
                "cex_price": float,
                "deviation_pct": float,
                "oracle_lag_sec": float,
                "oracle_round_id": str,
            },
            "id",
        ),
    }

    def __init__(self, db_path: str = "trades.db") -> None:
        self._db_path = db_path
        self._db = Database(db_path)
        self._log = logger.bind(component="TradeDB", db=db_path)
        self._ensure_tables()
        self._log.info("database_ready")

    # ------------------------------------------------------------------
    # Schema
    # ------------------------------------------------------------------

    def _ensure_tables(self) -> None:
        """Create every table in ``_TABLE_SCHEMAS`` that does not exist yet.

        Tables that are already present are left exactly as they are; no
        columns are added to or changed on them.
        """
        # Read the table list once rather than re-querying it per table.
        existing = set(self._db.table_names())
        for table, (columns, pk) in self._TABLE_SCHEMAS.items():
            if table in existing:
                continue
            self._db[table].create(columns, pk=pk)
            self._log.debug("table_created", table=table)

    # ------------------------------------------------------------------
    # Trade CRUD
    # ------------------------------------------------------------------

    def log_trade(self, trade: Trade) -> None:
        """Insert a new ``trades`` row built from a :class:`Trade`.

        Asset, timeframe, direction and status are stored via their
        ``.value`` attribute; price, size, edge and probability come from
        ``trade.signal``.
        """
        signal = trade.signal
        row = {
            "id": trade.id,
            "asset": signal.asset.value,
            "timeframe": signal.timeframe.value,
            "direction": signal.direction.value,
            "token_id": signal.token_id,
            "entry_price": signal.price,
            "fill_price": trade.fill_price,
            "size": signal.size,
            "fee": trade.fee,
            "pnl": trade.pnl,
            "status": trade.status.value,
            "signal_edge": signal.edge,
            "signal_prob": signal.estimated_prob,
            "created_at": trade.created_at,
            "updated_at": trade.updated_at,
        }
        self._db["trades"].insert(row)
        self._log.info("trade_logged", trade_id=trade.id, asset=row["asset"])

    def update_trade(self, trade_id: str, **fields: Any) -> None:
        """Update arbitrary columns on the ``trades`` row whose id is *trade_id*.

        ``updated_at`` is always set to ``time.time()``, overriding any value
        supplied in *fields*.
        """
        fields["updated_at"] = time.time()
        self._db["trades"].update(trade_id, fields)
        self._log.info("trade_updated", trade_id=trade_id, fields=list(fields.keys()))

    def get_trades(
        self,
        start_time: Optional[float] = None,
        end_time: Optional[float] = None,
        asset: Optional[str] = None,
    ) -> list[dict[str, Any]]:
        """Return trades matching optional time-range and asset filters.

        Parameters
        ----------
        start_time, end_time:
            Inclusive bounds on ``created_at``; ``None`` leaves that side open.
        asset:
            Exact match on the ``asset`` column; ``None`` matches every asset.

        Returns
        -------
        list of dict
            One dict per row, keyed by column name, newest ``created_at`` first.
        """
        clauses: list[str] = []
        params: list[Any] = []
        if start_time is not None:
            clauses.append("created_at >= ?")
            params.append(start_time)
        if end_time is not None:
            clauses.append("created_at <= ?")
            params.append(end_time)
        if asset is not None:
            clauses.append("asset = ?")
            params.append(asset)
        # Only fixed clause strings are interpolated; every value is bound.
        where = " AND ".join(clauses) if clauses else "1=1"
        sql = f"SELECT * FROM trades WHERE {where} ORDER BY created_at DESC"
        cursor = self._db.execute(sql, params)
        # ``execute`` yields a DB-API cursor whose rows are plain tuples; pair
        # them with the column names so callers get the dicts the signature
        # promises.
        columns = [description[0] for description in cursor.description]
        return [dict(zip(columns, row)) for row in cursor.fetchall()]

    # ------------------------------------------------------------------
    # Window snapshots
    # ------------------------------------------------------------------

    def log_window(self, window: WindowState) -> None:
        """Snapshot a completed window into ``window_snapshots``.

        ``window.current_price`` is stored as ``end_price``; the market's
        ``condition_id`` is stored when ``window.market`` is set, else NULL.
        """
        row = {
            "asset": window.asset.value,
            "timeframe": window.timeframe.value,
            "start_price": window.start_price,
            "end_price": window.current_price,
            "price_change_pct": window.price_change_pct,
            "window_start": window.window_start_time,
            "window_end": window.window_end_time,
            "market_condition_id": (
                window.market.condition_id if window.market else None
            ),
            "created_at": time.time(),
        }
        self._db["window_snapshots"].insert(row)
        self._log.info(
            "window_logged",
            asset=row["asset"],
            timeframe=row["timeframe"],
            change_pct=row["price_change_pct"],
        )

    # ------------------------------------------------------------------
    # Daily summary helpers
    # ------------------------------------------------------------------

    def get_daily_summary(self, date: str) -> dict[str, Any]:
        """Compute (but do not store) trade statistics for one UTC day.

        Parameters
        ----------
        date:
            Day as ``YYYY-MM-DD``, matched against
            ``date(created_at, 'unixepoch')``, i.e. the UTC calendar day.

        Returns
        -------
        dict
            ``date`` plus keys matching the ``daily_summary`` columns. A day
            with no trades yields zero for every count and amount. Trades with
            ``pnl > 0`` count as wins and ``pnl <= 0`` as losses; rows whose
            ``pnl`` is NULL count as neither.
        """
        # SUM/MAX/MIN over zero rows return NULL in SQL, hence the COALESCEs.
        sql = """
            SELECT
                COUNT(*) AS total_trades,
                COALESCE(SUM(CASE WHEN pnl > 0 THEN 1 ELSE 0 END), 0) AS wins,
                COALESCE(SUM(CASE WHEN pnl <= 0 THEN 1 ELSE 0 END), 0) AS losses,
                COALESCE(SUM(pnl), 0.0) AS total_pnl,
                COALESCE(SUM(fee), 0.0) AS total_fees,
                COALESCE(SUM(size * fill_price), 0.0) AS total_volume,
                COALESCE(MAX(pnl), 0.0) AS best_trade_pnl,
                COALESCE(MIN(pnl), 0.0) AS worst_trade_pnl
            FROM trades
            WHERE date(created_at, 'unixepoch') = ?
        """
        row = self._db.execute(sql, [date]).fetchone()
        return {
            "date": date,
            "total_trades": row[0],
            "wins": row[1],
            "losses": row[2],
            "total_pnl": row[3],
            "total_fees": row[4],
            "total_volume": row[5],
            "best_trade_pnl": row[6],
            "worst_trade_pnl": row[7],
        }

    def update_daily_summary(self, date: str) -> None:
        """Recalculate the summary for *date* and upsert it into ``daily_summary``."""
        summary = self.get_daily_summary(date)
        self._db["daily_summary"].upsert(summary, pk="date")
        self._log.info("daily_summary_updated", date=date, pnl=summary["total_pnl"])

    # ------------------------------------------------------------------
    # Balance history
    # ------------------------------------------------------------------

    def log_balance(self, balance: float, pnl: float, event: str = "update") -> None:
        """Append a balance snapshot stamped with ``time.time()``."""
        self._db["balance_history"].insert({
            "timestamp": time.time(),
            "balance": balance,
            "pnl": pnl,
            "event": event,
        })

    def get_latest_balance(self) -> Optional[float]:
        """Return the most recently recorded balance, or ``None`` if none exist."""
        # Snapshots can share a timestamp; the higher ``id`` (normally the
        # later insert) breaks the tie so the result is deterministic.
        row = self._db.execute(
            "SELECT balance FROM balance_history "
            "ORDER BY timestamp DESC, id DESC LIMIT 1"
        ).fetchone()
        return float(row[0]) if row else None

    # ------------------------------------------------------------------
    # Oracle snapshots
    # ------------------------------------------------------------------

    def log_oracle(
        self,
        asset: str,
        oracle_price: float,
        cex_price: float,
        deviation_pct: float,
        oracle_lag_sec: float,
        oracle_round_id: int = 0,
    ) -> None:
        """Record an oracle-vs-CEX price snapshot stamped with ``time.time()``.

        ``deviation_pct`` is rounded to 4 decimals and ``oracle_lag_sec`` to 1
        before storage.
        """
        self._db["oracle_snapshots"].insert({
            "timestamp": time.time(),
            "asset": asset,
            "oracle_price": oracle_price,
            "cex_price": cex_price,
            "deviation_pct": round(deviation_pct, 4),
            "oracle_lag_sec": round(oracle_lag_sec, 1),
            # NOTE(review): stored as TEXT, presumably because round ids can
            # exceed SQLite's signed 64-bit INTEGER range — confirm.
            "oracle_round_id": str(oracle_round_id),
        })

    # ------------------------------------------------------------------
    # Quick-access aggregates
    # ------------------------------------------------------------------

    @staticmethod
    def _utc_today() -> str:
        """Return the current UTC date formatted as ``YYYY-MM-DD``."""
        return datetime.now(timezone.utc).strftime("%Y-%m-%d")

    def get_total_pnl(self) -> float:
        """Return the all-time cumulative PnL (0.0 when there are no trades)."""
        row = self._db.execute("SELECT COALESCE(SUM(pnl), 0.0) FROM trades").fetchone()
        return float(row[0])

    def get_today_pnl(self) -> float:
        """Return the cumulative PnL of trades created today (UTC)."""
        row = self._db.execute(
            "SELECT COALESCE(SUM(pnl), 0.0) FROM trades "
            "WHERE date(created_at, 'unixepoch') = ?",
            [self._utc_today()],
        ).fetchone()
        return float(row[0])

    def get_today_trade_count(self) -> int:
        """Return the number of trades created today (UTC)."""
        row = self._db.execute(
            "SELECT COUNT(*) FROM trades "
            "WHERE date(created_at, 'unixepoch') = ?",
            [self._utc_today()],
        ).fetchone()
        return int(row[0])