"""Tests for execution, position tracking, and risk management.""" from __future__ import annotations import time import pytest from src.config import Config, RiskConfig, FeesConfig from src.data.models import ( Asset, Direction, Position, Signal, Timeframe, Trade, TradeStatus, ) from src.data.db import TradeDB from src.execution.position_tracker import PositionTracker from src.risk.risk_manager import RiskManager # --------------------------------------------------------------------------- # Fixtures # --------------------------------------------------------------------------- @pytest.fixture def config(): return Config() @pytest.fixture def position_tracker(config): return PositionTracker(config) @pytest.fixture def trade_db(): return TradeDB(db_path=":memory:") def _make_signal( asset="BTC", direction=Direction.UP, price=0.50, size=100, edge=0.10 ) -> Signal: return Signal( direction=direction, asset=Asset(asset), timeframe=Timeframe.FIVE_MIN, token_id=f"token_{asset}_{direction.value}", price=price, size=size, edge=edge, estimated_prob=price + edge + 0.0156, ) def _make_trade( signal=None, fill_price=None, fill_size=None, status=TradeStatus.FILLED ) -> Trade: if signal is None: signal = _make_signal() return Trade( id="test_trade_1", signal=signal, order_id="order_1", status=status, fill_price=fill_price or signal.price, fill_size=fill_size or signal.size, fee=0.78, ) # --------------------------------------------------------------------------- # PositionTracker tests # --------------------------------------------------------------------------- class TestPositionTracker: def test_open_position(self, position_tracker): trade = _make_trade() pos = position_tracker.open_position(trade) assert pos.size == 100 assert pos.avg_price == 0.50 assert pos.asset == Asset.BTC assert position_tracker.position_count == 1 def test_close_position_win(self, position_tracker): trade = _make_trade() position_tracker.open_position(trade) pnl = 
position_tracker.close_position(trade.signal.token_id, 1.0) assert pnl is not None assert pnl == pytest.approx(50.0) # (1.0 - 0.50) * 100 assert position_tracker.position_count == 0 assert position_tracker.total_realized_pnl == pytest.approx(50.0) def test_close_position_loss(self, position_tracker): trade = _make_trade() position_tracker.open_position(trade) pnl = position_tracker.close_position(trade.signal.token_id, 0.0) assert pnl is not None assert pnl == pytest.approx(-50.0) # (0 - 0.50) * 100 assert position_tracker.total_realized_pnl == pytest.approx(-50.0) def test_add_to_existing_position(self, position_tracker): trade1 = _make_trade(signal=_make_signal(price=0.50, size=100)) trade2 = _make_trade(signal=_make_signal(price=0.60, size=50)) trade2.fill_price = 0.60 trade2.fill_size = 50 position_tracker.open_position(trade1) pos = position_tracker.open_position(trade2) assert pos.size == 150 expected_avg = (0.50 * 100 + 0.60 * 50) / 150 assert pos.avg_price == pytest.approx(expected_avg, abs=0.001) def test_total_exposure(self, position_tracker): trade = _make_trade() position_tracker.open_position(trade) assert position_tracker.total_exposure == pytest.approx(50.0) def test_mark_to_market(self, position_tracker): trade = _make_trade() position_tracker.open_position(trade) position_tracker.update_mark(trade.signal.token_id, 0.70) pos = position_tracker.get_position(trade.signal.token_id) assert pos.current_value == pytest.approx(70.0) assert pos.unrealized_pnl == pytest.approx(20.0) def test_win_rate(self, position_tracker): # Open and close two positions, one win one loss trade1 = _make_trade(signal=_make_signal(asset="BTC")) trade1.id = "t1" trade2 = _make_trade(signal=_make_signal(asset="ETH")) trade2.id = "t2" trade2.signal = _make_signal(asset="ETH") position_tracker.open_position(trade1) position_tracker.open_position(trade2) position_tracker.close_position(trade1.signal.token_id, 1.0) # Win position_tracker.close_position(trade2.signal.token_id, 0.0) 
# Loss assert position_tracker.win_rate == pytest.approx(0.5) def test_get_summary(self, position_tracker): summary = position_tracker.get_summary() assert "positions" in summary assert "total_pnl" in summary assert "win_rate" in summary # --------------------------------------------------------------------------- # RiskManager tests # --------------------------------------------------------------------------- class TestRiskManager: def test_can_open_within_limits(self, config, position_tracker, trade_db): risk_mgr = RiskManager(config.risk, position_tracker, trade_db) assert risk_mgr.can_open_position(1000) is True def test_rejects_over_exposure(self, position_tracker, trade_db): risk_config = RiskConfig(max_total_exposure_usd=100) risk_mgr = RiskManager(risk_config, position_tracker, trade_db) # Fill up exposure trade = _make_trade(signal=_make_signal(price=0.50, size=250)) position_tracker.open_position(trade) # Exposure = 125, limit = 100 assert risk_mgr.check_all() is False def test_rejects_over_position_count(self, position_tracker, trade_db): risk_config = RiskConfig(max_concurrent_positions=1) risk_mgr = RiskManager(risk_config, position_tracker, trade_db) trade = _make_trade() position_tracker.open_position(trade) assert risk_mgr.can_open_position(50) is False def test_halt_and_resume(self, config, position_tracker, trade_db): risk_mgr = RiskManager(config.risk, position_tracker, trade_db) assert risk_mgr.is_halted is False risk_mgr._halt(halt_event="test", reason="manual test halt") assert risk_mgr.is_halted is True assert risk_mgr.check_all() is False risk_mgr.resume() assert risk_mgr.is_halted is False def test_risk_summary(self, config, position_tracker, trade_db): risk_mgr = RiskManager(config.risk, position_tracker, trade_db) summary = risk_mgr.get_risk_summary() assert "halted" in summary assert "daily_pnl" in summary assert "total_exposure" in summary # --------------------------------------------------------------------------- # TradeDB tests # 
# ---------------------------------------------------------------------------


class TestTradeDB:
    """Tests for logging trades and balances and reading them back."""

    def test_log_and_query_trade(self, trade_db):
        """A single logged trade is counted as one trade for today."""
        trade_db.log_trade(_make_trade())
        assert trade_db.get_today_trade_count() == 1

    def test_update_trade(self, trade_db):
        """Updating a logged trade completes without raising."""
        logged = _make_trade()
        trade_db.log_trade(logged)
        # Nothing is asserted on the stored row: the test passes as long as
        # update_trade() does not raise.
        trade_db.update_trade(logged.id, pnl=50.0, status="FILLED")

    def test_log_balance(self, trade_db):
        """The latest balance reflects the most recently logged entry."""
        entries = (
            (10000, 0.0, "start"),
            (10050, 50.0, "update"),
        )
        for balance, change, event in entries:
            trade_db.log_balance(balance, change, event=event)
        assert trade_db.get_latest_balance() == pytest.approx(10050.0)

    def test_today_pnl(self, trade_db):
        """Today's PnL matches the pnl of the single logged trade."""
        logged = _make_trade()
        logged.pnl = 25.0
        trade_db.log_trade(logged)
        assert trade_db.get_today_pnl() == pytest.approx(25.0)

    def test_daily_summary(self, trade_db):
        """The summary for today's UTC date reports one trade and its pnl."""
        logged = _make_trade()
        logged.pnl = 30.0
        trade_db.log_trade(logged)
        # The date key is formatted from time.gmtime(), i.e. in UTC.
        utc_day = time.strftime("%Y-%m-%d", time.gmtime())
        report = trade_db.get_daily_summary(utc_day)
        assert report["total_trades"] == 1
        assert report["total_pnl"] == pytest.approx(30.0)