Add database layer with SQLite schema, ORM-style Database class, and tests
Implements Task 2: creates db/schema.sql with five tables (signals, trades, positions, portfolio, settings) and indexes; data/db.py with a Database class covering all CRUD operations; tests/test_db.py with 6 passing pytest tests. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
99
data/db.py
Normal file
99
data/db.py
Normal file
@@ -0,0 +1,99 @@
|
|||||||
|
import sqlite3
|
||||||
|
import os
|
||||||
|
|
||||||
|
class Database:
|
||||||
|
def __init__(self, db_path: str):
|
||||||
|
self.db_path = db_path
|
||||||
|
self.conn = None
|
||||||
|
|
||||||
|
def init(self):
|
||||||
|
self.conn = sqlite3.connect(self.db_path, check_same_thread=False)
|
||||||
|
self.conn.execute("PRAGMA journal_mode=WAL")
|
||||||
|
self.conn.execute("PRAGMA busy_timeout=5000")
|
||||||
|
self.conn.row_factory = sqlite3.Row
|
||||||
|
schema_path = os.path.join(os.path.dirname(os.path.dirname(__file__)), "db", "schema.sql")
|
||||||
|
with open(schema_path) as f:
|
||||||
|
self.conn.executescript(f.read())
|
||||||
|
|
||||||
|
def close(self):
|
||||||
|
if self.conn:
|
||||||
|
self.conn.close()
|
||||||
|
|
||||||
|
def execute(self, sql, params=()):
|
||||||
|
return self.conn.execute(sql, params)
|
||||||
|
|
||||||
|
def insert_signal(self, coin, technical, news, social, ai, composite, signal):
|
||||||
|
self.conn.execute(
|
||||||
|
"INSERT INTO signals (coin, technical_score, news_score, social_score, ai_score, composite_score, signal) VALUES (?,?,?,?,?,?,?)",
|
||||||
|
(coin, technical, news, social, ai, composite, signal),
|
||||||
|
)
|
||||||
|
self.conn.commit()
|
||||||
|
|
||||||
|
def get_latest_signals(self, limit=50):
|
||||||
|
return self.conn.execute(
|
||||||
|
"SELECT * FROM signals WHERE id IN (SELECT MAX(id) FROM signals GROUP BY coin) ORDER BY composite_score DESC LIMIT ?",
|
||||||
|
(limit,),
|
||||||
|
).fetchall()
|
||||||
|
|
||||||
|
def insert_trade(self, coin, side, price, quantity, amount_usd, reason):
|
||||||
|
self.conn.execute(
|
||||||
|
"INSERT INTO trades (coin, side, price, quantity, amount_usd, reason) VALUES (?,?,?,?,?,?)",
|
||||||
|
(coin, side, price, quantity, amount_usd, reason),
|
||||||
|
)
|
||||||
|
self.conn.commit()
|
||||||
|
|
||||||
|
def get_trades(self, limit=100):
|
||||||
|
return self.conn.execute(
|
||||||
|
"SELECT * FROM trades ORDER BY timestamp DESC LIMIT ?", (limit,)
|
||||||
|
).fetchall()
|
||||||
|
|
||||||
|
def open_position(self, coin, entry_price, quantity, invested_usd):
|
||||||
|
self.conn.execute(
|
||||||
|
"INSERT INTO positions (coin, entry_price, quantity, invested_usd) VALUES (?,?,?,?)",
|
||||||
|
(coin, entry_price, quantity, invested_usd),
|
||||||
|
)
|
||||||
|
self.conn.commit()
|
||||||
|
|
||||||
|
def get_open_positions(self):
|
||||||
|
return self.conn.execute(
|
||||||
|
"SELECT * FROM positions WHERE status='OPEN'"
|
||||||
|
).fetchall()
|
||||||
|
|
||||||
|
def close_position(self, position_id):
|
||||||
|
self.conn.execute(
|
||||||
|
"UPDATE positions SET status='CLOSED', closed_at=CURRENT_TIMESTAMP WHERE id=?",
|
||||||
|
(position_id,),
|
||||||
|
)
|
||||||
|
self.conn.commit()
|
||||||
|
|
||||||
|
def update_position_quantity(self, position_id, new_quantity):
|
||||||
|
self.conn.execute(
|
||||||
|
"UPDATE positions SET quantity=? WHERE id=?",
|
||||||
|
(new_quantity, position_id),
|
||||||
|
)
|
||||||
|
self.conn.commit()
|
||||||
|
|
||||||
|
def insert_portfolio_snapshot(self, total_value, cash, pnl, pnl_pct):
|
||||||
|
self.conn.execute(
|
||||||
|
"INSERT INTO portfolio (total_value, cash, pnl, pnl_pct) VALUES (?,?,?,?)",
|
||||||
|
(total_value, cash, pnl, pnl_pct),
|
||||||
|
)
|
||||||
|
self.conn.commit()
|
||||||
|
|
||||||
|
def get_portfolio_history(self, limit=100):
|
||||||
|
return self.conn.execute(
|
||||||
|
"SELECT * FROM portfolio ORDER BY timestamp DESC LIMIT ?", (limit,)
|
||||||
|
).fetchall()
|
||||||
|
|
||||||
|
def save_setting(self, key, value):
|
||||||
|
self.conn.execute(
|
||||||
|
"INSERT OR REPLACE INTO settings (key, value) VALUES (?,?)",
|
||||||
|
(key, value),
|
||||||
|
)
|
||||||
|
self.conn.commit()
|
||||||
|
|
||||||
|
def load_setting(self, key):
|
||||||
|
row = self.conn.execute(
|
||||||
|
"SELECT value FROM settings WHERE key=?", (key,)
|
||||||
|
).fetchone()
|
||||||
|
return row["value"] if row else None
|
||||||
51
db/schema.sql
Normal file
51
db/schema.sql
Normal file
@@ -0,0 +1,51 @@
|
|||||||
|
CREATE TABLE IF NOT EXISTS signals (
|
||||||
|
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||||
|
coin TEXT NOT NULL,
|
||||||
|
timestamp DATETIME DEFAULT CURRENT_TIMESTAMP,
|
||||||
|
technical_score REAL DEFAULT 50,
|
||||||
|
news_score REAL DEFAULT 50,
|
||||||
|
social_score REAL DEFAULT 50,
|
||||||
|
ai_score REAL DEFAULT 50,
|
||||||
|
composite_score REAL DEFAULT 50,
|
||||||
|
signal TEXT DEFAULT 'HOLD'
|
||||||
|
);
|
||||||
|
|
||||||
|
CREATE TABLE IF NOT EXISTS trades (
|
||||||
|
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||||
|
coin TEXT NOT NULL,
|
||||||
|
side TEXT NOT NULL,
|
||||||
|
price REAL NOT NULL,
|
||||||
|
quantity REAL NOT NULL,
|
||||||
|
amount_usd REAL NOT NULL,
|
||||||
|
timestamp DATETIME DEFAULT CURRENT_TIMESTAMP,
|
||||||
|
reason TEXT
|
||||||
|
);
|
||||||
|
|
||||||
|
CREATE TABLE IF NOT EXISTS positions (
|
||||||
|
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||||
|
coin TEXT NOT NULL,
|
||||||
|
entry_price REAL NOT NULL,
|
||||||
|
quantity REAL NOT NULL,
|
||||||
|
invested_usd REAL NOT NULL,
|
||||||
|
status TEXT DEFAULT 'OPEN',
|
||||||
|
opened_at DATETIME DEFAULT CURRENT_TIMESTAMP,
|
||||||
|
closed_at DATETIME
|
||||||
|
);
|
||||||
|
|
||||||
|
CREATE TABLE IF NOT EXISTS portfolio (
|
||||||
|
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||||
|
timestamp DATETIME DEFAULT CURRENT_TIMESTAMP,
|
||||||
|
total_value REAL NOT NULL,
|
||||||
|
cash REAL NOT NULL,
|
||||||
|
pnl REAL NOT NULL,
|
||||||
|
pnl_pct REAL NOT NULL
|
||||||
|
);
|
||||||
|
|
||||||
|
CREATE TABLE IF NOT EXISTS settings (
|
||||||
|
key TEXT PRIMARY KEY,
|
||||||
|
value TEXT NOT NULL
|
||||||
|
);
|
||||||
|
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_signals_coin_ts ON signals(coin, timestamp);
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_positions_status ON positions(status);
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_trades_coin ON trades(coin);
|
||||||
51
tests/test_db.py
Normal file
51
tests/test_db.py
Normal file
@@ -0,0 +1,51 @@
|
|||||||
|
import os
|
||||||
|
import pytest
|
||||||
|
from data.db import Database
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def db(tmp_path):
|
||||||
|
db_path = str(tmp_path / "test.db")
|
||||||
|
database = Database(db_path)
|
||||||
|
database.init()
|
||||||
|
yield database
|
||||||
|
database.close()
|
||||||
|
|
||||||
|
def test_init_creates_tables(db):
|
||||||
|
tables = db.execute("SELECT name FROM sqlite_master WHERE type='table'").fetchall()
|
||||||
|
names = {r[0] for r in tables}
|
||||||
|
assert "signals" in names
|
||||||
|
assert "trades" in names
|
||||||
|
assert "positions" in names
|
||||||
|
assert "portfolio" in names
|
||||||
|
assert "settings" in names
|
||||||
|
|
||||||
|
def test_insert_signal(db):
|
||||||
|
db.insert_signal("BTCUSDT", 75.0, 60.0, 55.0, 70.0, 68.5, "HOLD")
|
||||||
|
rows = db.get_latest_signals()
|
||||||
|
assert len(rows) == 1
|
||||||
|
assert rows[0]["coin"] == "BTCUSDT"
|
||||||
|
assert rows[0]["composite_score"] == 68.5
|
||||||
|
|
||||||
|
def test_insert_trade(db):
|
||||||
|
db.insert_trade("ETHUSDT", "BUY", 3500.0, 0.01, 35.0, "signal")
|
||||||
|
trades = db.get_trades()
|
||||||
|
assert len(trades) == 1
|
||||||
|
assert trades[0]["coin"] == "ETHUSDT"
|
||||||
|
|
||||||
|
def test_open_close_position(db):
|
||||||
|
db.open_position("SOLUSDT", 140.0, 0.5, 70.0)
|
||||||
|
positions = db.get_open_positions()
|
||||||
|
assert len(positions) == 1
|
||||||
|
db.close_position(positions[0]["id"])
|
||||||
|
assert len(db.get_open_positions()) == 0
|
||||||
|
|
||||||
|
def test_save_load_setting(db):
|
||||||
|
db.save_setting("weights", '{"technical": 0.7}')
|
||||||
|
val = db.load_setting("weights")
|
||||||
|
assert val == '{"technical": 0.7}'
|
||||||
|
|
||||||
|
def test_portfolio_snapshot(db):
|
||||||
|
db.insert_portfolio_snapshot(210.0, 50.0, 10.0, 5.0)
|
||||||
|
snaps = db.get_portfolio_history()
|
||||||
|
assert len(snaps) == 1
|
||||||
|
assert snaps[0]["total_value"] == 210.0
|
||||||
Reference in New Issue
Block a user