From ffada928f2d04358b27cef3ed830858a2d91599d Mon Sep 17 00:00:00 2001 From: choijaewook Date: Fri, 20 Mar 2026 17:47:34 +0900 Subject: [PATCH] Add database layer with SQLite schema, ORM-style Database class, and tests Implements Task 2: creates db/schema.sql with five tables (signals, trades, positions, portfolio, settings) and indexes; data/db.py with a Database class covering all CRUD operations; tests/test_db.py with 6 passing pytest tests. Co-Authored-By: Claude Sonnet 4.6 --- data/db.py | 99 ++++++++++++++++++++++++++++++++++++++++++++++++ db/schema.sql | 51 +++++++++++++++++++++++++ tests/test_db.py | 51 +++++++++++++++++++++++++ 3 files changed, 201 insertions(+) create mode 100644 data/db.py create mode 100644 db/schema.sql create mode 100644 tests/test_db.py diff --git a/data/db.py b/data/db.py new file mode 100644 index 0000000..4d37599 --- /dev/null +++ b/data/db.py @@ -0,0 +1,99 @@ +import sqlite3 +import os + +class Database: + def __init__(self, db_path: str): + self.db_path = db_path + self.conn = None + + def init(self): + self.conn = sqlite3.connect(self.db_path, check_same_thread=False) + self.conn.execute("PRAGMA journal_mode=WAL") + self.conn.execute("PRAGMA busy_timeout=5000") + self.conn.row_factory = sqlite3.Row + schema_path = os.path.join(os.path.dirname(os.path.dirname(__file__)), "db", "schema.sql") + with open(schema_path) as f: + self.conn.executescript(f.read()) + + def close(self): + if self.conn: + self.conn.close() + + def execute(self, sql, params=()): + return self.conn.execute(sql, params) + + def insert_signal(self, coin, technical, news, social, ai, composite, signal): + self.conn.execute( + "INSERT INTO signals (coin, technical_score, news_score, social_score, ai_score, composite_score, signal) VALUES (?,?,?,?,?,?,?)", + (coin, technical, news, social, ai, composite, signal), + ) + self.conn.commit() + + def get_latest_signals(self, limit=50): + return self.conn.execute( + "SELECT * FROM signals WHERE id IN (SELECT MAX(id) FROM signals GROUP BY coin) ORDER BY composite_score DESC LIMIT ?", + (limit,), + ).fetchall() + + def insert_trade(self, coin, side, price, quantity, amount_usd, reason): + self.conn.execute( + "INSERT INTO trades (coin, side, price, quantity, amount_usd, reason) VALUES (?,?,?,?,?,?)", + (coin, side, price, quantity, amount_usd, reason), + ) + self.conn.commit() + + def get_trades(self, limit=100): + return self.conn.execute( + "SELECT * FROM trades ORDER BY timestamp DESC LIMIT ?", (limit,) + ).fetchall() + + def open_position(self, coin, entry_price, quantity, invested_usd): + self.conn.execute( + "INSERT INTO positions (coin, entry_price, quantity, invested_usd) VALUES (?,?,?,?)", + (coin, entry_price, quantity, invested_usd), + ) + self.conn.commit() + + def get_open_positions(self): + return self.conn.execute( + "SELECT * FROM positions WHERE status='OPEN'" + ).fetchall() + + def close_position(self, position_id): + self.conn.execute( + "UPDATE positions SET status='CLOSED', closed_at=CURRENT_TIMESTAMP WHERE id=?", + (position_id,), + ) + self.conn.commit() + + def update_position_quantity(self, position_id, new_quantity): + self.conn.execute( + "UPDATE positions SET quantity=? WHERE id=?", + (new_quantity, position_id), + ) + self.conn.commit() + + def insert_portfolio_snapshot(self, total_value, cash, pnl, pnl_pct): + self.conn.execute( + "INSERT INTO portfolio (total_value, cash, pnl, pnl_pct) VALUES (?,?,?,?)", + (total_value, cash, pnl, pnl_pct), + ) + self.conn.commit() + + def get_portfolio_history(self, limit=100): + return self.conn.execute( + "SELECT * FROM portfolio ORDER BY timestamp DESC LIMIT ?", (limit,) + ).fetchall() + + def save_setting(self, key, value): + self.conn.execute( + "INSERT OR REPLACE INTO settings (key, value) VALUES (?,?)", + (key, value), + ) + self.conn.commit() + + def load_setting(self, key): + row = self.conn.execute( + "SELECT value FROM settings WHERE key=?", (key,) + ).fetchone() + return row["value"] if row else None diff --git a/db/schema.sql b/db/schema.sql new file mode 100644 index 0000000..0568469 --- /dev/null +++ b/db/schema.sql @@ -0,0 +1,51 @@ +CREATE TABLE IF NOT EXISTS signals ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + coin TEXT NOT NULL, + timestamp DATETIME DEFAULT CURRENT_TIMESTAMP, + technical_score REAL DEFAULT 50, + news_score REAL DEFAULT 50, + social_score REAL DEFAULT 50, + ai_score REAL DEFAULT 50, + composite_score REAL DEFAULT 50, + signal TEXT DEFAULT 'HOLD' +); + +CREATE TABLE IF NOT EXISTS trades ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + coin TEXT NOT NULL, + side TEXT NOT NULL, + price REAL NOT NULL, + quantity REAL NOT NULL, + amount_usd REAL NOT NULL, + timestamp DATETIME DEFAULT CURRENT_TIMESTAMP, + reason TEXT +); + +CREATE TABLE IF NOT EXISTS positions ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + coin TEXT NOT NULL, + entry_price REAL NOT NULL, + quantity REAL NOT NULL, + invested_usd REAL NOT NULL, + status TEXT DEFAULT 'OPEN', + opened_at DATETIME DEFAULT CURRENT_TIMESTAMP, + closed_at DATETIME +); + +CREATE TABLE IF NOT EXISTS portfolio ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + timestamp DATETIME DEFAULT CURRENT_TIMESTAMP, + total_value REAL NOT NULL, + cash REAL NOT NULL, + pnl REAL NOT NULL, + pnl_pct REAL NOT NULL +); + +CREATE TABLE IF NOT EXISTS settings ( + key TEXT PRIMARY KEY, + value TEXT NOT NULL +); + +CREATE INDEX IF NOT EXISTS idx_signals_coin_ts ON signals(coin, timestamp); +CREATE INDEX IF NOT EXISTS idx_positions_status ON positions(status); +CREATE INDEX IF NOT EXISTS idx_trades_coin ON trades(coin); diff --git a/tests/test_db.py b/tests/test_db.py new file mode 100644 index 0000000..7a62f44 --- /dev/null +++ b/tests/test_db.py @@ -0,0 +1,51 @@ +import os +import pytest +from data.db import Database + +@pytest.fixture +def db(tmp_path): + db_path = str(tmp_path / "test.db") + database = Database(db_path) + database.init() + yield database + database.close() + +def test_init_creates_tables(db): + tables = db.execute("SELECT name FROM sqlite_master WHERE type='table'").fetchall() + names = {r[0] for r in tables} + assert "signals" in names + assert "trades" in names + assert "positions" in names + assert "portfolio" in names + assert "settings" in names + +def test_insert_signal(db): + db.insert_signal("BTCUSDT", 75.0, 60.0, 55.0, 70.0, 68.5, "HOLD") + rows = db.get_latest_signals() + assert len(rows) == 1 + assert rows[0]["coin"] == "BTCUSDT" + assert rows[0]["composite_score"] == 68.5 + +def test_insert_trade(db): + db.insert_trade("ETHUSDT", "BUY", 3500.0, 0.01, 35.0, "signal") + trades = db.get_trades() + assert len(trades) == 1 + assert trades[0]["coin"] == "ETHUSDT" + +def test_open_close_position(db): + db.open_position("SOLUSDT", 140.0, 0.5, 70.0) + positions = db.get_open_positions() + assert len(positions) == 1 + db.close_position(positions[0]["id"]) + assert len(db.get_open_positions()) == 0 + +def test_save_load_setting(db): + db.save_setting("weights", '{"technical": 0.7}') + val = db.load_setting("weights") + assert val == '{"technical": 0.7}' + +def test_portfolio_snapshot(db): + db.insert_portfolio_snapshot(210.0, 50.0, 10.0, 5.0) + snaps = db.get_portfolio_history() + assert len(snaps) == 1 + assert snaps[0]["total_value"] == 210.0