update 03-22 09:28
This commit is contained in:
0
tests/__init__.py
Normal file
0
tests/__init__.py
Normal file
251
tests/test_execution.py
Normal file
251
tests/test_execution.py
Normal file
@@ -0,0 +1,251 @@
|
||||
"""Tests for execution, position tracking, and risk management."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import time
|
||||
|
||||
import pytest
|
||||
|
||||
from src.config import Config, RiskConfig, FeesConfig
|
||||
from src.data.models import (
|
||||
Asset,
|
||||
Direction,
|
||||
Position,
|
||||
Signal,
|
||||
Timeframe,
|
||||
Trade,
|
||||
TradeStatus,
|
||||
)
|
||||
from src.data.db import TradeDB
|
||||
from src.execution.position_tracker import PositionTracker
|
||||
from src.risk.risk_manager import RiskManager
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Fixtures
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@pytest.fixture
|
||||
def config():
|
||||
return Config()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def position_tracker(config):
|
||||
return PositionTracker(config)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def trade_db():
|
||||
return TradeDB(db_path=":memory:")
|
||||
|
||||
|
||||
def _make_signal(
|
||||
asset="BTC", direction=Direction.UP, price=0.50, size=100, edge=0.10
|
||||
) -> Signal:
|
||||
return Signal(
|
||||
direction=direction,
|
||||
asset=Asset(asset),
|
||||
timeframe=Timeframe.FIVE_MIN,
|
||||
token_id=f"token_{asset}_{direction.value}",
|
||||
price=price,
|
||||
size=size,
|
||||
edge=edge,
|
||||
estimated_prob=price + edge + 0.0156,
|
||||
)
|
||||
|
||||
|
||||
def _make_trade(
|
||||
signal=None, fill_price=None, fill_size=None, status=TradeStatus.FILLED
|
||||
) -> Trade:
|
||||
if signal is None:
|
||||
signal = _make_signal()
|
||||
return Trade(
|
||||
id="test_trade_1",
|
||||
signal=signal,
|
||||
order_id="order_1",
|
||||
status=status,
|
||||
fill_price=fill_price or signal.price,
|
||||
fill_size=fill_size or signal.size,
|
||||
fee=0.78,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# PositionTracker tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestPositionTracker:
|
||||
def test_open_position(self, position_tracker):
|
||||
trade = _make_trade()
|
||||
pos = position_tracker.open_position(trade)
|
||||
|
||||
assert pos.size == 100
|
||||
assert pos.avg_price == 0.50
|
||||
assert pos.asset == Asset.BTC
|
||||
assert position_tracker.position_count == 1
|
||||
|
||||
def test_close_position_win(self, position_tracker):
|
||||
trade = _make_trade()
|
||||
position_tracker.open_position(trade)
|
||||
|
||||
pnl = position_tracker.close_position(trade.signal.token_id, 1.0)
|
||||
|
||||
assert pnl is not None
|
||||
assert pnl == pytest.approx(50.0) # (1.0 - 0.50) * 100
|
||||
assert position_tracker.position_count == 0
|
||||
assert position_tracker.total_realized_pnl == pytest.approx(50.0)
|
||||
|
||||
def test_close_position_loss(self, position_tracker):
|
||||
trade = _make_trade()
|
||||
position_tracker.open_position(trade)
|
||||
|
||||
pnl = position_tracker.close_position(trade.signal.token_id, 0.0)
|
||||
|
||||
assert pnl is not None
|
||||
assert pnl == pytest.approx(-50.0) # (0 - 0.50) * 100
|
||||
assert position_tracker.total_realized_pnl == pytest.approx(-50.0)
|
||||
|
||||
def test_add_to_existing_position(self, position_tracker):
|
||||
trade1 = _make_trade(signal=_make_signal(price=0.50, size=100))
|
||||
trade2 = _make_trade(signal=_make_signal(price=0.60, size=50))
|
||||
trade2.fill_price = 0.60
|
||||
trade2.fill_size = 50
|
||||
|
||||
position_tracker.open_position(trade1)
|
||||
pos = position_tracker.open_position(trade2)
|
||||
|
||||
assert pos.size == 150
|
||||
expected_avg = (0.50 * 100 + 0.60 * 50) / 150
|
||||
assert pos.avg_price == pytest.approx(expected_avg, abs=0.001)
|
||||
|
||||
def test_total_exposure(self, position_tracker):
|
||||
trade = _make_trade()
|
||||
position_tracker.open_position(trade)
|
||||
|
||||
assert position_tracker.total_exposure == pytest.approx(50.0)
|
||||
|
||||
def test_mark_to_market(self, position_tracker):
|
||||
trade = _make_trade()
|
||||
position_tracker.open_position(trade)
|
||||
|
||||
position_tracker.update_mark(trade.signal.token_id, 0.70)
|
||||
|
||||
pos = position_tracker.get_position(trade.signal.token_id)
|
||||
assert pos.current_value == pytest.approx(70.0)
|
||||
assert pos.unrealized_pnl == pytest.approx(20.0)
|
||||
|
||||
def test_win_rate(self, position_tracker):
|
||||
# Open and close two positions, one win one loss
|
||||
trade1 = _make_trade(signal=_make_signal(asset="BTC"))
|
||||
trade1.id = "t1"
|
||||
trade2 = _make_trade(signal=_make_signal(asset="ETH"))
|
||||
trade2.id = "t2"
|
||||
trade2.signal = _make_signal(asset="ETH")
|
||||
|
||||
position_tracker.open_position(trade1)
|
||||
position_tracker.open_position(trade2)
|
||||
|
||||
position_tracker.close_position(trade1.signal.token_id, 1.0) # Win
|
||||
position_tracker.close_position(trade2.signal.token_id, 0.0) # Loss
|
||||
|
||||
assert position_tracker.win_rate == pytest.approx(0.5)
|
||||
|
||||
def test_get_summary(self, position_tracker):
|
||||
summary = position_tracker.get_summary()
|
||||
assert "positions" in summary
|
||||
assert "total_pnl" in summary
|
||||
assert "win_rate" in summary
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# RiskManager tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestRiskManager:
|
||||
def test_can_open_within_limits(self, config, position_tracker, trade_db):
|
||||
risk_mgr = RiskManager(config.risk, position_tracker, trade_db)
|
||||
assert risk_mgr.can_open_position(1000) is True
|
||||
|
||||
def test_rejects_over_exposure(self, position_tracker, trade_db):
|
||||
risk_config = RiskConfig(max_total_exposure_usd=100)
|
||||
risk_mgr = RiskManager(risk_config, position_tracker, trade_db)
|
||||
|
||||
# Fill up exposure
|
||||
trade = _make_trade(signal=_make_signal(price=0.50, size=250))
|
||||
position_tracker.open_position(trade)
|
||||
# Exposure = 125, limit = 100
|
||||
|
||||
assert risk_mgr.check_all() is False
|
||||
|
||||
def test_rejects_over_position_count(self, position_tracker, trade_db):
|
||||
risk_config = RiskConfig(max_concurrent_positions=1)
|
||||
risk_mgr = RiskManager(risk_config, position_tracker, trade_db)
|
||||
|
||||
trade = _make_trade()
|
||||
position_tracker.open_position(trade)
|
||||
|
||||
assert risk_mgr.can_open_position(50) is False
|
||||
|
||||
def test_halt_and_resume(self, config, position_tracker, trade_db):
|
||||
risk_mgr = RiskManager(config.risk, position_tracker, trade_db)
|
||||
|
||||
assert risk_mgr.is_halted is False
|
||||
risk_mgr._halt(halt_event="test", reason="manual test halt")
|
||||
assert risk_mgr.is_halted is True
|
||||
assert risk_mgr.check_all() is False
|
||||
|
||||
risk_mgr.resume()
|
||||
assert risk_mgr.is_halted is False
|
||||
|
||||
def test_risk_summary(self, config, position_tracker, trade_db):
|
||||
risk_mgr = RiskManager(config.risk, position_tracker, trade_db)
|
||||
summary = risk_mgr.get_risk_summary()
|
||||
assert "halted" in summary
|
||||
assert "daily_pnl" in summary
|
||||
assert "total_exposure" in summary
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# TradeDB tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestTradeDB:
|
||||
def test_log_and_query_trade(self, trade_db):
|
||||
trade = _make_trade()
|
||||
trade_db.log_trade(trade)
|
||||
|
||||
count = trade_db.get_today_trade_count()
|
||||
assert count == 1
|
||||
|
||||
def test_update_trade(self, trade_db):
|
||||
trade = _make_trade()
|
||||
trade_db.log_trade(trade)
|
||||
|
||||
trade_db.update_trade(trade.id, pnl=50.0, status="FILLED")
|
||||
# No exception = success
|
||||
|
||||
def test_log_balance(self, trade_db):
|
||||
trade_db.log_balance(10000, 0.0, event="start")
|
||||
trade_db.log_balance(10050, 50.0, event="update")
|
||||
|
||||
bal = trade_db.get_latest_balance()
|
||||
assert bal == pytest.approx(10050.0)
|
||||
|
||||
def test_today_pnl(self, trade_db):
|
||||
trade = _make_trade()
|
||||
trade.pnl = 25.0
|
||||
trade_db.log_trade(trade)
|
||||
|
||||
pnl = trade_db.get_today_pnl()
|
||||
assert pnl == pytest.approx(25.0)
|
||||
|
||||
def test_daily_summary(self, trade_db):
|
||||
trade = _make_trade()
|
||||
trade.pnl = 30.0
|
||||
trade_db.log_trade(trade)
|
||||
|
||||
today = time.strftime("%Y-%m-%d", time.gmtime())
|
||||
summary = trade_db.get_daily_summary(today)
|
||||
assert summary["total_trades"] == 1
|
||||
assert summary["total_pnl"] == pytest.approx(30.0)
|
||||
336
tests/test_strategy.py
Normal file
336
tests/test_strategy.py
Normal file
@@ -0,0 +1,336 @@
|
||||
"""Tests for temporal arbitrage strategy and related components."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import time
|
||||
|
||||
import pytest
|
||||
|
||||
from src.config import (
|
||||
FeesConfig,
|
||||
RiskConfig,
|
||||
TemporalArbConfig,
|
||||
SumToOneConfig,
|
||||
SpreadCaptureConfig,
|
||||
)
|
||||
from src.data.models import Asset, Direction, Timeframe, Signal, OrderBookLevel, OrderBookSnapshot
|
||||
from src.strategy.temporal_arb import TemporalArbStrategy
|
||||
from src.strategy.sum_to_one import SumToOneStrategy
|
||||
from src.risk.fee_calculator import FeeCalculator
|
||||
from src.risk.position_sizer import PositionSizer
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Fixtures
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@pytest.fixture
|
||||
def arb_config():
|
||||
return TemporalArbConfig(
|
||||
enabled=True,
|
||||
min_price_move_pct=0.03,
|
||||
max_poly_entry_price=0.65,
|
||||
min_edge=0.05,
|
||||
exit_before_resolution_sec=5,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def risk_config():
|
||||
return RiskConfig(
|
||||
max_position_per_market_usd=5000,
|
||||
max_total_exposure_usd=20000,
|
||||
max_daily_loss_usd=2000,
|
||||
kelly_fraction_cap=0.25,
|
||||
max_concurrent_positions=6,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def fees_config():
|
||||
return FeesConfig(taker_fee_5m=0.0156, taker_fee_15m=0.03)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def strategy(arb_config, risk_config, fees_config):
|
||||
return TemporalArbStrategy(
|
||||
arb_config=arb_config,
|
||||
risk_config=risk_config,
|
||||
fees_config=fees_config,
|
||||
balance=10000.0,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def fee_calc(fees_config):
|
||||
return FeeCalculator(fees_config)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# TemporalArbStrategy tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestTemporalArbStrategy:
|
||||
def test_no_signal_below_min_move(self, strategy):
|
||||
"""No signal when price move is too small."""
|
||||
result = asyncio.run(strategy.evaluate(
|
||||
symbol="BTC",
|
||||
cex_price=84010,
|
||||
window_start_price=84000,
|
||||
window_end_time=time.time() + 200,
|
||||
poly_up_ask=0.50,
|
||||
poly_down_ask=0.50,
|
||||
up_token_id="up_1",
|
||||
down_token_id="down_1",
|
||||
timeframe="5M",
|
||||
))
|
||||
assert result is None
|
||||
|
||||
def test_signal_generated_on_sufficient_move(self, strategy):
|
||||
"""Signal generated when price move and edge are sufficient."""
|
||||
result = asyncio.run(strategy.evaluate(
|
||||
symbol="BTC",
|
||||
cex_price=84300, # +0.36% move
|
||||
window_start_price=84000,
|
||||
window_end_time=time.time() + 200,
|
||||
poly_up_ask=0.50,
|
||||
poly_down_ask=0.50,
|
||||
up_token_id="up_1",
|
||||
down_token_id="down_1",
|
||||
timeframe="5M",
|
||||
))
|
||||
assert result is not None
|
||||
assert result.direction == Direction.UP
|
||||
assert result.asset == Asset.BTC
|
||||
assert result.price == 0.50
|
||||
assert result.edge > 0
|
||||
assert result.size > 0
|
||||
|
||||
def test_down_signal(self, strategy):
|
||||
"""Signal generated for DOWN direction."""
|
||||
result = asyncio.run(strategy.evaluate(
|
||||
symbol="ETH",
|
||||
cex_price=2290, # -0.43% from 2300
|
||||
window_start_price=2300,
|
||||
window_end_time=time.time() + 200,
|
||||
poly_up_ask=0.50,
|
||||
poly_down_ask=0.48,
|
||||
up_token_id="up_1",
|
||||
down_token_id="down_1",
|
||||
timeframe="15M",
|
||||
))
|
||||
assert result is not None
|
||||
assert result.direction == Direction.DOWN
|
||||
assert result.asset == Asset.ETH
|
||||
|
||||
def test_no_signal_when_poly_price_too_high(self, strategy):
|
||||
"""No signal when Polymarket price exceeds max entry price."""
|
||||
result = asyncio.run(strategy.evaluate(
|
||||
symbol="BTC",
|
||||
cex_price=84500,
|
||||
window_start_price=84000,
|
||||
window_end_time=time.time() + 200,
|
||||
poly_up_ask=0.70, # Above max_poly_entry_price=0.65
|
||||
poly_down_ask=0.30,
|
||||
up_token_id="up_1",
|
||||
down_token_id="down_1",
|
||||
timeframe="5M",
|
||||
))
|
||||
assert result is None
|
||||
|
||||
def test_no_signal_too_close_to_resolution(self, strategy):
|
||||
"""No signal when window is about to expire."""
|
||||
result = asyncio.run(strategy.evaluate(
|
||||
symbol="BTC",
|
||||
cex_price=84500,
|
||||
window_start_price=84000,
|
||||
window_end_time=time.time() + 3, # Only 3 seconds left
|
||||
poly_up_ask=0.50,
|
||||
poly_down_ask=0.50,
|
||||
up_token_id="up_1",
|
||||
down_token_id="down_1",
|
||||
timeframe="5M",
|
||||
))
|
||||
assert result is None
|
||||
|
||||
def test_probability_estimation(self, strategy):
|
||||
"""Probability increases with price magnitude."""
|
||||
prob_small = strategy.estimate_probability(0.1, 200, 300)
|
||||
prob_medium = strategy.estimate_probability(0.3, 200, 300)
|
||||
prob_large = strategy.estimate_probability(0.5, 200, 300)
|
||||
|
||||
assert prob_small < prob_medium < prob_large
|
||||
assert prob_small >= 0.50
|
||||
assert prob_large <= 0.95
|
||||
|
||||
def test_kelly_sizing_positive_edge(self, strategy):
|
||||
"""Kelly sizing returns positive size for positive edge."""
|
||||
size = strategy.calculate_kelly_size(
|
||||
edge=0.10, price=0.50, balance=10000, max_size=5000
|
||||
)
|
||||
assert size > 0
|
||||
assert size * 0.50 <= 5000 # Within max size
|
||||
|
||||
def test_kelly_sizing_zero_edge(self, strategy):
|
||||
"""Kelly sizing returns 0 for zero or negative edge."""
|
||||
size = strategy.calculate_kelly_size(
|
||||
edge=-0.05, price=0.50, balance=10000, max_size=5000
|
||||
)
|
||||
assert size == 0
|
||||
|
||||
def test_should_exit_early_reversal(self, strategy):
|
||||
"""Exit signal on price reversal."""
|
||||
should_exit = strategy.should_exit_early(
|
||||
entry_direction=Direction.UP,
|
||||
entry_price=0.50,
|
||||
current_poly_price=0.45,
|
||||
cex_price=83800, # Price reversed down
|
||||
window_start_price=84000,
|
||||
time_remaining=100,
|
||||
)
|
||||
assert should_exit is True
|
||||
|
||||
def test_should_not_exit_when_direction_holds(self, strategy):
|
||||
"""No exit when direction still holds."""
|
||||
should_exit = strategy.should_exit_early(
|
||||
entry_direction=Direction.UP,
|
||||
entry_price=0.50,
|
||||
current_poly_price=0.60,
|
||||
cex_price=84200,
|
||||
window_start_price=84000,
|
||||
time_remaining=100,
|
||||
)
|
||||
assert should_exit is False
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# FeeCalculator tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestFeeCalculator:
|
||||
def test_taker_fee_5m(self, fee_calc):
|
||||
"""5M taker fee calculation."""
|
||||
fee = fee_calc.taker_fee("5M", 0.50, 100)
|
||||
# Profit = 100*1.0 - 100*0.50 = 50, fee = 50 * 0.0156 = 0.78
|
||||
assert abs(fee - 0.78) < 0.01
|
||||
|
||||
def test_taker_fee_15m(self, fee_calc):
|
||||
"""15M taker fee is higher."""
|
||||
fee_5m = fee_calc.taker_fee("5M", 0.50, 100)
|
||||
fee_15m = fee_calc.taker_fee("15M", 0.50, 100)
|
||||
assert fee_15m > fee_5m
|
||||
|
||||
def test_net_payout_win(self, fee_calc):
|
||||
"""Net payout on a win."""
|
||||
payout = fee_calc.net_payout("5M", 0.50, 100, won=True)
|
||||
assert payout > 0
|
||||
assert payout < 50 # Less than gross profit due to fees
|
||||
|
||||
def test_net_payout_loss(self, fee_calc):
|
||||
"""Net payout on a loss."""
|
||||
payout = fee_calc.net_payout("5M", 0.50, 100, won=False)
|
||||
assert payout == -50.0 # Total loss of cost basis
|
||||
|
||||
def test_breakeven_price(self, fee_calc):
|
||||
"""Breakeven probability is higher than entry price."""
|
||||
be = fee_calc.breakeven_price("5M", 0.50)
|
||||
assert be > 0.50
|
||||
assert be < 1.0
|
||||
|
||||
def test_expected_value_positive_edge(self, fee_calc):
|
||||
"""EV is positive when estimated prob exceeds breakeven."""
|
||||
ev = fee_calc.expected_value("5M", 0.50, 0.70, 100)
|
||||
assert ev > 0
|
||||
|
||||
def test_expected_value_negative_edge(self, fee_calc):
|
||||
"""EV is negative when estimated prob is below breakeven."""
|
||||
ev = fee_calc.expected_value("5M", 0.50, 0.50, 100)
|
||||
assert ev < 0
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Data models tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestModels:
|
||||
def test_orderbook_snapshot(self):
|
||||
book = OrderBookSnapshot(
|
||||
token_id="test",
|
||||
bids=[OrderBookLevel(0.48, 100), OrderBookLevel(0.47, 200)],
|
||||
asks=[OrderBookLevel(0.52, 100), OrderBookLevel(0.53, 200)],
|
||||
)
|
||||
assert book.best_bid == 0.48
|
||||
assert book.best_ask == 0.52
|
||||
assert book.spread == pytest.approx(0.04)
|
||||
|
||||
def test_empty_orderbook(self):
|
||||
book = OrderBookSnapshot(token_id="test")
|
||||
assert book.best_bid is None
|
||||
assert book.best_ask is None
|
||||
assert book.spread is None
|
||||
|
||||
def test_signal_timestamp(self):
|
||||
sig = Signal(
|
||||
direction=Direction.UP,
|
||||
asset=Asset.BTC,
|
||||
timeframe=Timeframe.FIVE_MIN,
|
||||
token_id="test",
|
||||
price=0.50,
|
||||
size=100,
|
||||
edge=0.10,
|
||||
estimated_prob=0.65,
|
||||
)
|
||||
assert sig.timestamp > 0
|
||||
assert sig.price == 0.50
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# WindowTracker tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestWindowTracker:
|
||||
def test_window_creation_on_first_tick(self):
|
||||
from src.market.window_tracker import WindowTracker
|
||||
|
||||
tracker = WindowTracker(
|
||||
assets=[Asset.BTC],
|
||||
timeframes=[Timeframe.FIVE_MIN],
|
||||
)
|
||||
|
||||
changed = []
|
||||
tracker.on_window_change(lambda w: changed.append(w))
|
||||
|
||||
tracker.update_price("BTC", 84000.0, time.time())
|
||||
|
||||
assert len(changed) == 1
|
||||
assert changed[0].asset == Asset.BTC
|
||||
assert changed[0].start_price == 84000.0
|
||||
|
||||
def test_get_window(self):
|
||||
from src.market.window_tracker import WindowTracker
|
||||
|
||||
tracker = WindowTracker(
|
||||
assets=[Asset.BTC],
|
||||
timeframes=[Timeframe.FIVE_MIN],
|
||||
)
|
||||
tracker.update_price("BTC", 84000.0, time.time())
|
||||
|
||||
window = tracker.get_window("BTC", "5M")
|
||||
assert window is not None
|
||||
assert window.start_price == 84000.0
|
||||
|
||||
def test_price_update_within_window(self):
|
||||
from src.market.window_tracker import WindowTracker
|
||||
|
||||
tracker = WindowTracker(
|
||||
assets=[Asset.BTC],
|
||||
timeframes=[Timeframe.FIVE_MIN],
|
||||
)
|
||||
now = time.time()
|
||||
tracker.update_price("BTC", 84000.0, now)
|
||||
tracker.update_price("BTC", 84100.0, now + 1)
|
||||
|
||||
window = tracker.get_window("BTC", "5M")
|
||||
assert window.start_price == 84000.0
|
||||
assert window.current_price == 84100.0
|
||||
Reference in New Issue
Block a user