update 03-22 09:28
This commit is contained in:
323
src/data/db.py
Normal file
323
src/data/db.py
Normal file
@@ -0,0 +1,323 @@
|
||||
"""SQLite trade-log database backed by sqlite-utils.
|
||||
|
||||
Provides persistent storage for trades, window snapshots, and daily
|
||||
performance summaries with lightweight query helpers.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import time
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any, Optional
|
||||
|
||||
import structlog
|
||||
from sqlite_utils import Database
|
||||
|
||||
from src.data.models import Trade, WindowState
|
||||
|
||||
logger = structlog.get_logger(__name__)
|
||||
|
||||
|
||||
class TradeDB:
|
||||
"""Thin wrapper around a SQLite database for trade logging and analytics.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
db_path:
|
||||
Path to the SQLite database file. Use ``":memory:"`` for tests.
|
||||
"""
|
||||
|
||||
def __init__(self, db_path: str = "trades.db") -> None:
|
||||
self._db_path = db_path
|
||||
self._db = Database(db_path)
|
||||
self._log = logger.bind(component="TradeDB", db=db_path)
|
||||
self._ensure_tables()
|
||||
self._log.info("database_ready")
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Schema
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _ensure_tables(self) -> None:
|
||||
"""Create tables if they do not already exist."""
|
||||
|
||||
if "trades" not in self._db.table_names():
|
||||
self._db["trades"].create(
|
||||
{
|
||||
"id": str,
|
||||
"asset": str,
|
||||
"timeframe": str,
|
||||
"direction": str,
|
||||
"token_id": str,
|
||||
"entry_price": float,
|
||||
"fill_price": float,
|
||||
"size": int,
|
||||
"fee": float,
|
||||
"pnl": float,
|
||||
"status": str,
|
||||
"signal_edge": float,
|
||||
"signal_prob": float,
|
||||
"created_at": float,
|
||||
"updated_at": float,
|
||||
},
|
||||
pk="id",
|
||||
)
|
||||
self._log.debug("table_created", table="trades")
|
||||
|
||||
if "window_snapshots" not in self._db.table_names():
|
||||
self._db["window_snapshots"].create(
|
||||
{
|
||||
"id": int,
|
||||
"asset": str,
|
||||
"timeframe": str,
|
||||
"start_price": float,
|
||||
"end_price": float,
|
||||
"price_change_pct": float,
|
||||
"window_start": float,
|
||||
"window_end": float,
|
||||
"market_condition_id": str,
|
||||
"created_at": float,
|
||||
},
|
||||
pk="id",
|
||||
)
|
||||
self._log.debug("table_created", table="window_snapshots")
|
||||
|
||||
if "daily_summary" not in self._db.table_names():
|
||||
self._db["daily_summary"].create(
|
||||
{
|
||||
"date": str,
|
||||
"total_trades": int,
|
||||
"wins": int,
|
||||
"losses": int,
|
||||
"total_pnl": float,
|
||||
"total_fees": float,
|
||||
"total_volume": float,
|
||||
"best_trade_pnl": float,
|
||||
"worst_trade_pnl": float,
|
||||
},
|
||||
pk="date",
|
||||
)
|
||||
self._log.debug("table_created", table="daily_summary")
|
||||
|
||||
if "balance_history" not in self._db.table_names():
|
||||
self._db["balance_history"].create(
|
||||
{
|
||||
"id": int,
|
||||
"timestamp": float,
|
||||
"balance": float,
|
||||
"pnl": float,
|
||||
"event": str,
|
||||
},
|
||||
pk="id",
|
||||
)
|
||||
self._log.debug("table_created", table="balance_history")
|
||||
|
||||
if "oracle_snapshots" not in self._db.table_names():
|
||||
self._db["oracle_snapshots"].create(
|
||||
{
|
||||
"id": int,
|
||||
"timestamp": float,
|
||||
"asset": str,
|
||||
"oracle_price": float,
|
||||
"cex_price": float,
|
||||
"deviation_pct": float,
|
||||
"oracle_lag_sec": float,
|
||||
"oracle_round_id": str,
|
||||
},
|
||||
pk="id",
|
||||
)
|
||||
self._log.debug("table_created", table="oracle_snapshots")
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Trade CRUD
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def log_trade(self, trade: Trade) -> None:
|
||||
"""Insert a new trade record derived from a :class:`Trade` dataclass."""
|
||||
row = {
|
||||
"id": trade.id,
|
||||
"asset": trade.signal.asset.value,
|
||||
"timeframe": trade.signal.timeframe.value,
|
||||
"direction": trade.signal.direction.value,
|
||||
"token_id": trade.signal.token_id,
|
||||
"entry_price": trade.signal.price,
|
||||
"fill_price": trade.fill_price,
|
||||
"size": trade.signal.size,
|
||||
"fee": trade.fee,
|
||||
"pnl": trade.pnl,
|
||||
"status": trade.status.value,
|
||||
"signal_edge": trade.signal.edge,
|
||||
"signal_prob": trade.signal.estimated_prob,
|
||||
"created_at": trade.created_at,
|
||||
"updated_at": trade.updated_at,
|
||||
}
|
||||
self._db["trades"].insert(row)
|
||||
self._log.info("trade_logged", trade_id=trade.id, asset=row["asset"])
|
||||
|
||||
def update_trade(self, trade_id: str, **fields: Any) -> None:
|
||||
"""Update arbitrary fields on an existing trade row."""
|
||||
fields["updated_at"] = time.time()
|
||||
self._db["trades"].update(trade_id, fields)
|
||||
self._log.info("trade_updated", trade_id=trade_id, fields=list(fields.keys()))
|
||||
|
||||
def get_trades(
|
||||
self,
|
||||
start_time: Optional[float] = None,
|
||||
end_time: Optional[float] = None,
|
||||
asset: Optional[str] = None,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Return trades matching optional time-range and asset filters."""
|
||||
clauses: list[str] = []
|
||||
params: list[Any] = []
|
||||
|
||||
if start_time is not None:
|
||||
clauses.append("created_at >= ?")
|
||||
params.append(start_time)
|
||||
if end_time is not None:
|
||||
clauses.append("created_at <= ?")
|
||||
params.append(end_time)
|
||||
if asset is not None:
|
||||
clauses.append("asset = ?")
|
||||
params.append(asset)
|
||||
|
||||
where = " AND ".join(clauses) if clauses else "1=1"
|
||||
sql = f"SELECT * FROM trades WHERE {where} ORDER BY created_at DESC"
|
||||
return list(self._db.execute(sql, params).fetchall())
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Window snapshots
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def log_window(self, window: WindowState) -> None:
|
||||
"""Snapshot a completed window into the database."""
|
||||
row = {
|
||||
"asset": window.asset.value,
|
||||
"timeframe": window.timeframe.value,
|
||||
"start_price": window.start_price,
|
||||
"end_price": window.current_price,
|
||||
"price_change_pct": window.price_change_pct,
|
||||
"window_start": window.window_start_time,
|
||||
"window_end": window.window_end_time,
|
||||
"market_condition_id": (
|
||||
window.market.condition_id if window.market else None
|
||||
),
|
||||
"created_at": time.time(),
|
||||
}
|
||||
self._db["window_snapshots"].insert(row)
|
||||
self._log.info(
|
||||
"window_logged",
|
||||
asset=row["asset"],
|
||||
timeframe=row["timeframe"],
|
||||
change_pct=row["price_change_pct"],
|
||||
)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Daily summary helpers
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def get_daily_summary(self, date: str) -> dict[str, Any]:
|
||||
"""Compute (but do not store) daily stats for *date* (``YYYY-MM-DD``)."""
|
||||
sql = """
|
||||
SELECT
|
||||
COUNT(*) AS total_trades,
|
||||
SUM(CASE WHEN pnl > 0 THEN 1 ELSE 0 END) AS wins,
|
||||
SUM(CASE WHEN pnl <= 0 THEN 1 ELSE 0 END) AS losses,
|
||||
COALESCE(SUM(pnl), 0.0) AS total_pnl,
|
||||
COALESCE(SUM(fee), 0.0) AS total_fees,
|
||||
COALESCE(SUM(size * fill_price), 0.0) AS total_volume,
|
||||
COALESCE(MAX(pnl), 0.0) AS best_trade_pnl,
|
||||
COALESCE(MIN(pnl), 0.0) AS worst_trade_pnl
|
||||
FROM trades
|
||||
WHERE date(created_at, 'unixepoch') = ?
|
||||
"""
|
||||
row = self._db.execute(sql, [date]).fetchone()
|
||||
return {
|
||||
"date": date,
|
||||
"total_trades": row[0],
|
||||
"wins": row[1],
|
||||
"losses": row[2],
|
||||
"total_pnl": row[3],
|
||||
"total_fees": row[4],
|
||||
"total_volume": row[5],
|
||||
"best_trade_pnl": row[6],
|
||||
"worst_trade_pnl": row[7],
|
||||
}
|
||||
|
||||
def update_daily_summary(self, date: str) -> None:
|
||||
"""Recalculate and upsert the daily summary row for *date*."""
|
||||
summary = self.get_daily_summary(date)
|
||||
self._db["daily_summary"].upsert(summary, pk="date")
|
||||
self._log.info("daily_summary_updated", date=date, pnl=summary["total_pnl"])
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Quick-access aggregates
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Balance history
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def log_balance(self, balance: float, pnl: float, event: str = "update") -> None:
|
||||
"""Record a balance snapshot."""
|
||||
self._db["balance_history"].insert({
|
||||
"timestamp": time.time(),
|
||||
"balance": balance,
|
||||
"pnl": pnl,
|
||||
"event": event,
|
||||
})
|
||||
|
||||
def log_oracle(
|
||||
self,
|
||||
asset: str,
|
||||
oracle_price: float,
|
||||
cex_price: float,
|
||||
deviation_pct: float,
|
||||
oracle_lag_sec: float,
|
||||
oracle_round_id: int = 0,
|
||||
) -> None:
|
||||
"""Record an oracle price snapshot."""
|
||||
self._db["oracle_snapshots"].insert({
|
||||
"timestamp": time.time(),
|
||||
"asset": asset,
|
||||
"oracle_price": oracle_price,
|
||||
"cex_price": cex_price,
|
||||
"deviation_pct": round(deviation_pct, 4),
|
||||
"oracle_lag_sec": round(oracle_lag_sec, 1),
|
||||
"oracle_round_id": str(oracle_round_id),
|
||||
})
|
||||
|
||||
def get_latest_balance(self) -> Optional[float]:
|
||||
"""Return the most recent balance, or None."""
|
||||
row = self._db.execute(
|
||||
"SELECT balance FROM balance_history ORDER BY timestamp DESC LIMIT 1"
|
||||
).fetchone()
|
||||
return float(row[0]) if row else None
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Quick-access aggregates
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def get_total_pnl(self) -> float:
|
||||
"""Return the all-time cumulative PnL."""
|
||||
row = self._db.execute("SELECT COALESCE(SUM(pnl), 0.0) FROM trades").fetchone()
|
||||
return float(row[0])
|
||||
|
||||
def get_today_pnl(self) -> float:
|
||||
"""Return today's cumulative PnL (UTC)."""
|
||||
today = datetime.now(timezone.utc).strftime("%Y-%m-%d")
|
||||
row = self._db.execute(
|
||||
"SELECT COALESCE(SUM(pnl), 0.0) FROM trades "
|
||||
"WHERE date(created_at, 'unixepoch') = ?",
|
||||
[today],
|
||||
).fetchone()
|
||||
return float(row[0])
|
||||
|
||||
def get_today_trade_count(self) -> int:
|
||||
"""Return the number of trades placed today (UTC)."""
|
||||
today = datetime.now(timezone.utc).strftime("%Y-%m-%d")
|
||||
row = self._db.execute(
|
||||
"SELECT COUNT(*) FROM trades "
|
||||
"WHERE date(created_at, 'unixepoch') = ?",
|
||||
[today],
|
||||
).fetchone()
|
||||
return int(row[0])
|
||||
Reference in New Issue
Block a user