Files

223 lines
7.9 KiB
Python
Raw Permalink Normal View History

2026-03-20 07:49:42 +09:00
"""Data access layer for the trading database.
Provides CRUD operations for positions, trades, daily performance,
and bot state using synchronous sqlite3.
"""
from __future__ import annotations

import json
import sqlite3
from datetime import date, datetime, timezone
from typing import Dict, List, Optional

from loguru import logger

from config import settings
from database.models import (
    DailyPerformance,
    PositionRecord,
    TradeRecord,
    init_db,
)
class TradingRepository:
    """Synchronous repository for all trading data.

    Holds a single ``sqlite3`` connection that is opened lazily on first
    access to :attr:`conn`, and exposes CRUD helpers for the ``positions``,
    ``trade_records``, ``daily_performance`` and ``bot_state`` tables.
    Every write commits before returning; if a write raises, its
    transaction is rolled back instead of being left open on the connection.
    """

    def __init__(self, db_path: str | None = None):
        """Create an unconnected repository.

        Args:
            db_path: Database file to open. Falls back to
                ``settings.DB_PATH`` when omitted or empty.
        """
        self._db_path = db_path or settings.DB_PATH
        self._conn: Optional[sqlite3.Connection] = None

    # ------------------------------------------------------------------
    # Connection management
    # ------------------------------------------------------------------

    def connect(self) -> None:
        """Open and initialise the database connection.

        Any connection this repository already holds is closed first, so
        calling ``connect()`` twice does not leak a handle. Rows come back
        as ``sqlite3.Row`` so columns can be read by name.
        """
        self.close()
        self._conn = init_db(self._db_path)
        self._conn.row_factory = sqlite3.Row

    def close(self) -> None:
        """Close the connection if one is open; safe to call repeatedly.

        The handle is cleared afterwards so the next access to :attr:`conn`
        reopens the database rather than returning a closed connection.
        """
        if self._conn is not None:
            self._conn.close()
            self._conn = None

    @property
    def conn(self) -> sqlite3.Connection:
        """The live connection, opened on first use or after ``close()``."""
        if self._conn is None:
            self.connect()
        return self._conn  # type: ignore[return-value]

    def _write(self, sql: str, params: tuple = ()) -> sqlite3.Cursor:
        """Execute one write statement in its own transaction.

        Using the connection as a context manager commits on success and
        rolls back if the statement raises, so a failed write never leaves
        a dangling transaction on the shared connection.
        """
        with self.conn as conn:
            return conn.execute(sql, params)

    # ------------------------------------------------------------------
    # Positions
    # ------------------------------------------------------------------

    def save_position(self, pos: PositionRecord) -> None:
        """Insert ``pos`` or overwrite the stored row it conflicts with.

        ``INSERT OR REPLACE`` replaces any existing row that collides on a
        unique key (presumably ``id``), so this doubles as an update.
        """
        self._write(
            """INSERT OR REPLACE INTO positions
               (id, symbol, direction, entry_price, amount, stop_loss,
                take_profit, trailing_stop, realized_pnl, status,
                opened_at, closed_at, close_reason, confluence_score,
                entry_reasons)
               VALUES (?,?,?,?,?,?,?,?,?,?,?,?,?,?,?)""",
            (
                pos.id, pos.symbol, pos.direction, pos.entry_price,
                pos.amount, pos.stop_loss, pos.take_profit, pos.trailing_stop,
                pos.realized_pnl, pos.status, pos.opened_at, pos.closed_at,
                pos.close_reason, pos.confluence_score, pos.entry_reasons,
            ),
        )

    def get_position(self, position_id: str) -> Optional[PositionRecord]:
        """Return the position with ``position_id``, or ``None`` if absent."""
        row = self.conn.execute(
            "SELECT * FROM positions WHERE id = ?", (position_id,)
        ).fetchone()
        return self._row_to_position(row) if row else None

    def get_open_positions(self) -> List[PositionRecord]:
        """Return every position whose status is ``'OPEN'``."""
        rows = self.conn.execute(
            "SELECT * FROM positions WHERE status = 'OPEN'"
        ).fetchall()
        return [self._row_to_position(r) for r in rows]

    def get_closed_positions(self, limit: int = 100) -> List[PositionRecord]:
        """Return up to ``limit`` closed positions, most recently closed first."""
        rows = self.conn.execute(
            "SELECT * FROM positions WHERE status = 'CLOSED' ORDER BY closed_at DESC LIMIT ?",
            (limit,),
        ).fetchall()
        return [self._row_to_position(r) for r in rows]

    def _row_to_position(self, row: sqlite3.Row) -> PositionRecord:
        """Build a ``PositionRecord`` from a ``positions`` row.

        NULL/empty ``opened_at``, ``confluence_score`` and ``entry_reasons``
        are replaced with ``""``, ``0`` and ``"[]"`` respectively.
        """
        return PositionRecord(
            id=row["id"],
            symbol=row["symbol"],
            direction=row["direction"],
            entry_price=row["entry_price"],
            amount=row["amount"],
            stop_loss=row["stop_loss"],
            take_profit=row["take_profit"],
            trailing_stop=row["trailing_stop"],
            realized_pnl=row["realized_pnl"],
            status=row["status"],
            opened_at=row["opened_at"] or "",
            closed_at=row["closed_at"],
            close_reason=row["close_reason"],
            confluence_score=row["confluence_score"] or 0,
            entry_reasons=row["entry_reasons"] or "[]",
        )

    # ------------------------------------------------------------------
    # Trade Records
    # ------------------------------------------------------------------

    def save_trade(self, trade: TradeRecord) -> None:
        """Insert ``trade`` or overwrite the stored row it conflicts with."""
        self._write(
            """INSERT OR REPLACE INTO trade_records
               (id, position_id, symbol, side, order_type, price, amount,
                fee, timestamp)
               VALUES (?,?,?,?,?,?,?,?,?)""",
            (
                trade.id, trade.position_id, trade.symbol, trade.side,
                trade.order_type, trade.price, trade.amount, trade.fee,
                trade.timestamp,
            ),
        )

    def get_trades_for_position(self, position_id: str) -> List[TradeRecord]:
        """Return the trades linked to ``position_id``, ordered by timestamp."""
        rows = self.conn.execute(
            "SELECT * FROM trade_records WHERE position_id = ? ORDER BY timestamp",
            (position_id,),
        ).fetchall()
        return [self._row_to_trade(r) for r in rows]

    def _row_to_trade(self, row: sqlite3.Row) -> TradeRecord:
        """Build a ``TradeRecord`` from a ``trade_records`` row (NULL timestamp -> ``""``)."""
        return TradeRecord(
            id=row["id"],
            position_id=row["position_id"],
            symbol=row["symbol"],
            side=row["side"],
            order_type=row["order_type"],
            price=row["price"],
            amount=row["amount"],
            fee=row["fee"],
            timestamp=row["timestamp"] or "",
        )

    # ------------------------------------------------------------------
    # Daily Performance
    # ------------------------------------------------------------------

    def update_daily_performance(
        self, pnl: float, is_win: bool, max_dd: float = 0.0
    ) -> None:
        """Fold one trade result into today's ``daily_performance`` row.

        The row is keyed by the *local* calendar date (``date.today()``)
        and is created by the first call of the day.

        Args:
            pnl: Amount added to ``total_pnl``.
            is_win: Counts the trade as a win when true, otherwise a loss.
            max_dd: Drawdown observation; the stored ``max_drawdown`` keeps
                the larger of its current value and this one.
        """
        today = date.today().isoformat()
        wins = 1 if is_win else 0
        losses = 1 - wins
        with self.conn as conn:
            # Bump an existing row first; an UPDATE that touches no rows
            # means nothing has been recorded for today yet.
            cur = conn.execute(
                """UPDATE daily_performance SET
                   total_trades = total_trades + 1,
                   winning_trades = winning_trades + ?,
                   losing_trades = losing_trades + ?,
                   total_pnl = total_pnl + ?,
                   max_drawdown = MAX(max_drawdown, ?)
                   WHERE date = ?""",
                (wins, losses, pnl, max_dd, today),
            )
            if cur.rowcount == 0:
                conn.execute(
                    """INSERT INTO daily_performance
                       (date, total_trades, winning_trades, losing_trades,
                        total_pnl, max_drawdown)
                       VALUES (?,1,?,?,?,?)""",
                    (today, wins, losses, pnl, max_dd),
                )

    def get_daily_performance(self, day: str | None = None) -> Optional[DailyPerformance]:
        """Return the stats row for ``day`` (ISO date string), or ``None``.

        ``day`` defaults to today's local date.
        """
        day = day or date.today().isoformat()
        row = self.conn.execute(
            "SELECT * FROM daily_performance WHERE date = ?", (day,)
        ).fetchone()
        return self._row_to_daily(row) if row else None

    def get_performance_history(self, days: int = 30) -> List[DailyPerformance]:
        """Return up to ``days`` stored rows, newest date first.

        This counts rows, not calendar days: dates with no stored row are
        simply absent from the result.
        """
        rows = self.conn.execute(
            "SELECT * FROM daily_performance ORDER BY date DESC LIMIT ?", (days,)
        ).fetchall()
        return [self._row_to_daily(r) for r in rows]

    def _row_to_daily(self, row: sqlite3.Row) -> DailyPerformance:
        """Build a ``DailyPerformance`` from a ``daily_performance`` row."""
        return DailyPerformance(
            date=row["date"],
            total_trades=row["total_trades"],
            winning_trades=row["winning_trades"],
            losing_trades=row["losing_trades"],
            total_pnl=row["total_pnl"],
            max_drawdown=row["max_drawdown"],
        )

    # ------------------------------------------------------------------
    # Bot State
    # ------------------------------------------------------------------

    def set_state(self, key: str, value: str) -> None:
        """Store ``value`` under ``key``, stamping ``updated_at``.

        The timestamp is a naive ISO-8601 string in UTC. It is built from an
        aware ``now(timezone.utc)`` with the tzinfo stripped, which matches
        the format ``datetime.utcnow()`` produced without using that
        deprecated (Python 3.12+) call.
        """
        stamp = datetime.now(timezone.utc).replace(tzinfo=None).isoformat()
        self._write(
            """INSERT OR REPLACE INTO bot_state (key, value, updated_at)
               VALUES (?, ?, ?)""",
            (key, value, stamp),
        )

    def get_state(self, key: str) -> Optional[str]:
        """Return the value stored under ``key``, or ``None`` if unset."""
        row = self.conn.execute(
            "SELECT value FROM bot_state WHERE key = ?", (key,)
        ).fetchone()
        return row["value"] if row else None