252 lines
8.0 KiB
Python
252 lines
8.0 KiB
Python
"""Tests for execution, position tracking, and risk management."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import time
|
|
|
|
import pytest
|
|
|
|
from src.config import Config, RiskConfig, FeesConfig
|
|
from src.data.models import (
|
|
Asset,
|
|
Direction,
|
|
Position,
|
|
Signal,
|
|
Timeframe,
|
|
Trade,
|
|
TradeStatus,
|
|
)
|
|
from src.data.db import TradeDB
|
|
from src.execution.position_tracker import PositionTracker
|
|
from src.risk.risk_manager import RiskManager
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Fixtures
|
|
# ---------------------------------------------------------------------------
|
|
|
|
@pytest.fixture
def config():
    """Provide a default-constructed ``Config`` to each test."""
    cfg = Config()
    return cfg
|
|
|
|
|
|
@pytest.fixture
def position_tracker(config):
    """Provide a fresh ``PositionTracker`` built from the ``config`` fixture."""
    tracker = PositionTracker(config)
    return tracker
|
|
|
|
|
|
@pytest.fixture
def trade_db():
    """Provide a ``TradeDB`` opened on the ``":memory:"`` path.

    NOTE(review): presumably an in-memory SQLite database, so each test starts
    from an empty store — confirm against ``TradeDB``.
    """
    db = TradeDB(db_path=":memory:")
    return db
|
|
|
|
|
|
def _make_signal(
    asset="BTC", direction=Direction.UP, price=0.50, size=100, edge=0.10
) -> Signal:
    """Build a five-minute ``Signal`` for tests.

    Args:
        asset: Asset symbol, converted with ``Asset(asset)``.
        direction: Signal direction; its ``value`` is embedded in the token id.
        price: Signal price.
        size: Signal size.
        edge: Edge reported on the signal.

    Returns:
        A ``Signal`` whose token id is ``token_<asset>_<direction.value>``, so
        signals with the same asset and direction share one token id.
    """
    token = f"token_{asset}_{direction.value}"
    # NOTE(review): 0.0156 is an extra margin on top of price + edge; it looks
    # like a fee allowance — confirm against FeesConfig.
    prob = price + edge + 0.0156
    return Signal(
        asset=Asset(asset),
        direction=direction,
        timeframe=Timeframe.FIVE_MIN,
        token_id=token,
        price=price,
        size=size,
        edge=edge,
        estimated_prob=prob,
    )
|
|
|
|
|
|
def _make_trade(
    signal=None,
    fill_price=None,
    fill_size=None,
    status=TradeStatus.FILLED,
    *,
    trade_id="test_trade_1",
    order_id="order_1",
    fee=0.78,
) -> Trade:
    """Build a ``Trade`` for tests, defaulting the fill to the signal's values.

    Args:
        signal: Signal the trade executes; a fresh ``_make_signal()`` when None.
        fill_price: Executed price; falls back to ``signal.price`` only when None.
        fill_size: Executed size; falls back to ``signal.size`` only when None.
        status: Trade status, ``TradeStatus.FILLED`` by default.
        trade_id: Trade identifier; override when a test needs distinct trades.
        order_id: Order identifier recorded on the trade.
        fee: Fee recorded on the trade.

    Returns:
        The constructed ``Trade``.
    """
    if signal is None:
        signal = _make_signal()
    # Explicit None checks: ``x or default`` would silently replace a
    # legitimate 0 / 0.0 fill with the signal's value.
    price = signal.price if fill_price is None else fill_price
    size = signal.size if fill_size is None else fill_size
    return Trade(
        id=trade_id,
        signal=signal,
        order_id=order_id,
        status=status,
        fill_price=price,
        fill_size=size,
        fee=fee,
    )
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# PositionTracker tests
|
|
# ---------------------------------------------------------------------------
|
|
|
|
class TestPositionTracker:
    """Opening, adding to, marking and closing positions in ``PositionTracker``."""

    def test_open_position(self, position_tracker):
        trade = _make_trade()
        pos = position_tracker.open_position(trade)

        assert pos.size == 100
        assert pos.avg_price == 0.50
        assert pos.asset == Asset.BTC
        assert position_tracker.position_count == 1

    def test_close_position_win(self, position_tracker):
        trade = _make_trade()
        position_tracker.open_position(trade)

        pnl = position_tracker.close_position(trade.signal.token_id, 1.0)

        assert pnl is not None
        assert pnl == pytest.approx(50.0)  # (1.0 - 0.50) * 100
        assert position_tracker.position_count == 0
        assert position_tracker.total_realized_pnl == pytest.approx(50.0)

    def test_close_position_loss(self, position_tracker):
        trade = _make_trade()
        position_tracker.open_position(trade)

        pnl = position_tracker.close_position(trade.signal.token_id, 0.0)

        assert pnl is not None
        assert pnl == pytest.approx(-50.0)  # (0.0 - 0.50) * 100
        # A losing close must still remove the position, same as a win.
        assert position_tracker.position_count == 0
        assert position_tracker.total_realized_pnl == pytest.approx(-50.0)

    def test_add_to_existing_position(self, position_tracker):
        # Both signals use the default BTC/UP token, so the second fill should
        # be averaged into the first position. Fill price/size default to the
        # signal's values, so no manual overrides are needed.
        trade1 = _make_trade(signal=_make_signal(price=0.50, size=100))
        trade2 = _make_trade(signal=_make_signal(price=0.60, size=50))
        trade2.id = "test_trade_2"

        position_tracker.open_position(trade1)
        pos = position_tracker.open_position(trade2)

        assert position_tracker.position_count == 1
        assert pos.size == 150
        expected_avg = (0.50 * 100 + 0.60 * 50) / 150
        assert pos.avg_price == pytest.approx(expected_avg, abs=0.001)

    def test_total_exposure(self, position_tracker):
        position_tracker.open_position(_make_trade())

        assert position_tracker.total_exposure == pytest.approx(50.0)  # 0.50 * 100

    def test_mark_to_market(self, position_tracker):
        trade = _make_trade()
        position_tracker.open_position(trade)

        position_tracker.update_mark(trade.signal.token_id, 0.70)

        pos = position_tracker.get_position(trade.signal.token_id)
        assert pos is not None
        assert pos.current_value == pytest.approx(70.0)  # 0.70 * 100
        assert pos.unrealized_pnl == pytest.approx(20.0)  # (0.70 - 0.50) * 100

    def test_win_rate(self, position_tracker):
        # Two positions on different tokens (BTC vs ETH): one win, one loss.
        trade1 = _make_trade(signal=_make_signal(asset="BTC"))
        trade1.id = "t1"
        trade2 = _make_trade(signal=_make_signal(asset="ETH"))
        trade2.id = "t2"

        position_tracker.open_position(trade1)
        position_tracker.open_position(trade2)

        position_tracker.close_position(trade1.signal.token_id, 1.0)  # Win
        position_tracker.close_position(trade2.signal.token_id, 0.0)  # Loss

        assert position_tracker.win_rate == pytest.approx(0.5)

    def test_get_summary(self, position_tracker):
        summary = position_tracker.get_summary()

        assert "positions" in summary
        assert "total_pnl" in summary
        assert "win_rate" in summary
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# RiskManager tests
|
|
# ---------------------------------------------------------------------------
|
|
|
|
class TestRiskManager:
    """Limit checks and halt/resume behavior of ``RiskManager``."""

    def test_can_open_within_limits(self, config, position_tracker, trade_db):
        risk_mgr = RiskManager(config.risk, position_tracker, trade_db)

        assert risk_mgr.can_open_position(1000) is True

    def test_rejects_over_exposure(self, position_tracker, trade_db):
        risk_config = RiskConfig(max_total_exposure_usd=100)
        risk_mgr = RiskManager(risk_config, position_tracker, trade_db)

        # Push exposure past the limit: 0.50 * 250 = 125 > 100.
        trade = _make_trade(signal=_make_signal(price=0.50, size=250))
        position_tracker.open_position(trade)
        # Pin the precondition so a change in exposure accounting fails here,
        # not as a confusing risk-check failure below.
        assert position_tracker.total_exposure == pytest.approx(125.0)

        assert risk_mgr.check_all() is False

    def test_rejects_over_position_count(self, position_tracker, trade_db):
        risk_config = RiskConfig(max_concurrent_positions=1)
        risk_mgr = RiskManager(risk_config, position_tracker, trade_db)

        position_tracker.open_position(_make_trade())
        assert position_tracker.position_count == 1  # already at the limit

        assert risk_mgr.can_open_position(50) is False

    def test_halt_and_resume(self, config, position_tracker, trade_db):
        risk_mgr = RiskManager(config.risk, position_tracker, trade_db)

        assert risk_mgr.is_halted is False
        risk_mgr._halt(halt_event="test", reason="manual test halt")
        assert risk_mgr.is_halted is True
        assert risk_mgr.check_all() is False

        risk_mgr.resume()
        assert risk_mgr.is_halted is False

    def test_risk_summary(self, config, position_tracker, trade_db):
        risk_mgr = RiskManager(config.risk, position_tracker, trade_db)
        summary = risk_mgr.get_risk_summary()

        assert "halted" in summary
        assert "daily_pnl" in summary
        assert "total_exposure" in summary
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# TradeDB tests
|
|
# ---------------------------------------------------------------------------
|
|
|
|
class TestTradeDB:
    """Round-trips through a ``TradeDB`` opened on ``":memory:"``."""

    def test_log_and_query_trade(self, trade_db):
        trade_db.log_trade(_make_trade())

        assert trade_db.get_today_trade_count() == 1

    def test_update_trade(self, trade_db):
        trade = _make_trade()
        trade_db.log_trade(trade)

        trade_db.update_trade(trade.id, pnl=50.0, status="FILLED")

        # The update must modify the existing row rather than insert a new
        # one, and the new PnL must be visible to the daily aggregate.
        assert trade_db.get_today_trade_count() == 1
        assert trade_db.get_today_pnl() == pytest.approx(50.0)

    def test_log_balance(self, trade_db):
        trade_db.log_balance(10000, 0.0, event="start")
        trade_db.log_balance(10050, 50.0, event="update")

        assert trade_db.get_latest_balance() == pytest.approx(10050.0)

    def test_today_pnl(self, trade_db):
        trade = _make_trade()
        trade.pnl = 25.0
        trade_db.log_trade(trade)

        assert trade_db.get_today_pnl() == pytest.approx(25.0)

    def test_daily_summary(self, trade_db):
        trade = _make_trade()
        trade.pnl = 30.0
        trade_db.log_trade(trade)

        today = time.strftime("%Y-%m-%d", time.gmtime())
        summary = trade_db.get_daily_summary(today)

        assert summary["total_trades"] == 1
        assert summary["total_pnl"] == pytest.approx(30.0)
|