update 03-22 09:28

This commit is contained in:
2026-03-22 09:28:14 +09:00
commit 7f45211276
43 changed files with 9373 additions and 0 deletions

0
src/__init__.py Normal file
View File

262
src/config.py Normal file
View File

@@ -0,0 +1,262 @@
"""Configuration loader for the Polymarket Temporal Arbitrage Bot.
Loads environment variables from .env and typed settings from config.toml,
exposing them through nested dataclasses with validated, typed access.
"""
from __future__ import annotations
import logging
import os
import sys
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any
import structlog
import toml
from dotenv import load_dotenv
# ---------------------------------------------------------------------------
# Nested config dataclasses
# ---------------------------------------------------------------------------
@dataclass(frozen=True)
class GeneralConfig:
mode: str = "paper"
assets: list[str] = field(default_factory=lambda: ["BTC", "ETH", "SOL"])
timeframes: list[str] = field(default_factory=lambda: ["5M", "15M"])
log_level: str = "INFO"
starting_balance: float = 500.0
@dataclass(frozen=True)
class TemporalArbConfig:
enabled: bool = True
min_price_move_pct: float = 0.15
max_poly_entry_price: float = 0.65
min_edge: float = 0.20
exit_before_resolution_sec: int = 5
@dataclass(frozen=True)
class SumToOneConfig:
enabled: bool = True
min_spread_after_fee: float = 0.02
@dataclass(frozen=True)
class SpreadCaptureConfig:
enabled: bool = False
spread_target: float = 0.04
@dataclass(frozen=True)
class RiskConfig:
max_position_per_market_usd: float = 5000.0
max_total_exposure_usd: float = 20000.0
max_daily_loss_usd: float = 2000.0
kelly_fraction_cap: float = 0.25
max_concurrent_positions: int = 6
@dataclass(frozen=True)
class FeesConfig:
taker_fee_5m: float = 0.0156
taker_fee_15m: float = 0.03
def fee_for_timeframe(self, timeframe: str) -> float:
"""Return the taker fee for a given timeframe string (e.g. '5M')."""
mapping = {
"5M": self.taker_fee_5m,
"15M": self.taker_fee_15m,
}
if timeframe not in mapping:
raise ValueError(f"Unknown timeframe '{timeframe}'; expected one of {list(mapping)}")
return mapping[timeframe]
@dataclass(frozen=True)
class BinanceConfig:
ws_url: str = "wss://stream.binance.com:9443/stream"
symbols: list[str] = field(default_factory=lambda: ["btcusdt", "ethusdt", "solusdt"])
@dataclass(frozen=True)
class PolymarketConfig:
clob_url: str = "https://clob.polymarket.com"
gamma_url: str = "https://gamma-api.polymarket.com"
data_url: str = "https://data-api.polymarket.com"
ws_url: str = "wss://ws-subscriptions-clob.polymarket.com/ws/"
chain_id: int = 137
signature_type: int = 2
# Credentials sourced from environment variables
private_key: str = ""
proxy_wallet: str = ""
api_key: str = ""
api_secret: str = ""
api_passphrase: str = ""
@dataclass(frozen=True)
class NotificationConfig:
telegram_enabled: bool = True
telegram_token: str = ""
telegram_chat_id: str = ""
notify_on_trade: bool = True
notify_on_daily_summary: bool = True
notify_on_error: bool = True
# ---------------------------------------------------------------------------
# Top-level Config
# ---------------------------------------------------------------------------
@dataclass(frozen=True)
class Config:
general: GeneralConfig = field(default_factory=GeneralConfig)
temporal_arb: TemporalArbConfig = field(default_factory=TemporalArbConfig)
sum_to_one: SumToOneConfig = field(default_factory=SumToOneConfig)
spread_capture: SpreadCaptureConfig = field(default_factory=SpreadCaptureConfig)
risk: RiskConfig = field(default_factory=RiskConfig)
fees: FeesConfig = field(default_factory=FeesConfig)
binance: BinanceConfig = field(default_factory=BinanceConfig)
polymarket: PolymarketConfig = field(default_factory=PolymarketConfig)
notifications: NotificationConfig = field(default_factory=NotificationConfig)
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _build_dataclass(cls: type, raw: dict[str, Any]) -> Any:
"""Instantiate a frozen dataclass, silently ignoring unknown keys."""
valid_fields = {f.name for f in cls.__dataclass_fields__.values()}
filtered = {k: v for k, v in raw.items() if k in valid_fields}
return cls(**filtered)
def _setup_logging(level_name: str) -> None:
"""Configure structlog with human-readable console output."""
level = getattr(logging, level_name.upper(), logging.INFO)
structlog.configure(
processors=[
structlog.contextvars.merge_contextvars,
structlog.processors.add_log_level,
structlog.processors.StackInfoRenderer(),
structlog.dev.set_exc_info,
structlog.processors.TimeStamper(fmt="iso"),
structlog.dev.ConsoleRenderer(),
],
wrapper_class=structlog.make_filtering_bound_logger(level),
context_class=dict,
logger_factory=structlog.PrintLoggerFactory(),
cache_logger_on_first_use=True,
)
# Also align stdlib logging so third-party libraries respect the level
logging.basicConfig(
format="%(message)s",
stream=sys.stdout,
level=level,
)
# ---------------------------------------------------------------------------
# Public API
# ---------------------------------------------------------------------------
def load_config(config_path: str = "config.toml") -> Config:
"""Load configuration from *config_path* and environment variables.
1. Reads ``.env`` (if present) into the process environment.
2. Parses ``config.toml`` for all trading parameters.
3. Overlays sensitive credentials from environment variables.
4. Configures structured logging via *structlog*.
Returns a fully-populated, immutable :class:`Config` instance.
"""
# --- .env ---------------------------------------------------------------
env_path = Path(config_path).parent / ".env"
load_dotenv(dotenv_path=env_path if env_path.exists() else None)
# --- config.toml --------------------------------------------------------
config_file = Path(config_path)
if not config_file.exists():
raise FileNotFoundError(f"Config file not found: {config_file.resolve()}")
raw: dict[str, Any] = toml.load(config_file)
# --- Build nested configs -----------------------------------------------
general = _build_dataclass(GeneralConfig, raw.get("general", {}))
strategy = raw.get("strategy", {})
temporal_arb = _build_dataclass(TemporalArbConfig, strategy.get("temporal_arb", {}))
sum_to_one = _build_dataclass(SumToOneConfig, strategy.get("sum_to_one", {}))
spread_capture = _build_dataclass(SpreadCaptureConfig, strategy.get("spread_capture", {}))
risk = _build_dataclass(RiskConfig, raw.get("risk", {}))
fees = _build_dataclass(FeesConfig, raw.get("fees", {}))
exchanges = raw.get("exchange", {})
binance = _build_dataclass(BinanceConfig, exchanges.get("binance", {}))
# Polymarket: merge file config with env-var credentials
poly_raw = dict(exchanges.get("polymarket", {}))
poly_raw.update({
"private_key": os.getenv("POLYMARKET_PRIVATE_KEY", ""),
"proxy_wallet": os.getenv("POLYMARKET_PROXY_WALLET", ""),
"api_key": os.getenv("POLYMARKET_API_KEY", ""),
"api_secret": os.getenv("POLYMARKET_API_SECRET", ""),
"api_passphrase": os.getenv("POLYMARKET_API_PASSPHRASE", ""),
})
polymarket = _build_dataclass(PolymarketConfig, poly_raw)
# Notifications: merge file config with env-var tokens
notif_raw = dict(raw.get("notifications", {}))
notif_raw.setdefault("telegram_token", os.getenv("TELEGRAM_BOT_TOKEN", ""))
notif_raw.setdefault("telegram_chat_id", os.getenv("TELEGRAM_CHAT_ID", ""))
# Env vars override empty strings from the config file
if not notif_raw.get("telegram_token"):
notif_raw["telegram_token"] = os.getenv("TELEGRAM_BOT_TOKEN", "")
if not notif_raw.get("telegram_chat_id"):
notif_raw["telegram_chat_id"] = os.getenv("TELEGRAM_CHAT_ID", "")
notifications = _build_dataclass(NotificationConfig, notif_raw)
# --- Assemble top-level config ------------------------------------------
config = Config(
general=general,
temporal_arb=temporal_arb,
sum_to_one=sum_to_one,
spread_capture=spread_capture,
risk=risk,
fees=fees,
binance=binance,
polymarket=polymarket,
notifications=notifications,
)
# --- Logging ------------------------------------------------------------
_setup_logging(config.general.log_level)
log = structlog.get_logger()
log.info(
"config_loaded",
mode=config.general.mode,
assets=config.general.assets,
timeframes=config.general.timeframes,
strategies_enabled=[
name
for name, enabled in [
("temporal_arb", config.temporal_arb.enabled),
("sum_to_one", config.sum_to_one.enabled),
("spread_capture", config.spread_capture.enabled),
]
if enabled
],
)
return config

5
src/data/__init__.py Normal file
View File

@@ -0,0 +1,5 @@
"""Data layer for the Polymarket Arbitrage Bot."""
from .db import TradeDB
__all__ = ["TradeDB"]

323
src/data/db.py Normal file
View File

@@ -0,0 +1,323 @@
"""SQLite trade-log database backed by sqlite-utils.
Provides persistent storage for trades, window snapshots, and daily
performance summaries with lightweight query helpers.
"""
from __future__ import annotations
import time
from datetime import datetime, timezone
from typing import Any, Optional
import structlog
from sqlite_utils import Database
from src.data.models import Trade, WindowState
logger = structlog.get_logger(__name__)
class TradeDB:
"""Thin wrapper around a SQLite database for trade logging and analytics.
Parameters
----------
db_path:
Path to the SQLite database file. Use ``":memory:"`` for tests.
"""
def __init__(self, db_path: str = "trades.db") -> None:
self._db_path = db_path
self._db = Database(db_path)
self._log = logger.bind(component="TradeDB", db=db_path)
self._ensure_tables()
self._log.info("database_ready")
# ------------------------------------------------------------------
# Schema
# ------------------------------------------------------------------
def _ensure_tables(self) -> None:
"""Create tables if they do not already exist."""
if "trades" not in self._db.table_names():
self._db["trades"].create(
{
"id": str,
"asset": str,
"timeframe": str,
"direction": str,
"token_id": str,
"entry_price": float,
"fill_price": float,
"size": int,
"fee": float,
"pnl": float,
"status": str,
"signal_edge": float,
"signal_prob": float,
"created_at": float,
"updated_at": float,
},
pk="id",
)
self._log.debug("table_created", table="trades")
if "window_snapshots" not in self._db.table_names():
self._db["window_snapshots"].create(
{
"id": int,
"asset": str,
"timeframe": str,
"start_price": float,
"end_price": float,
"price_change_pct": float,
"window_start": float,
"window_end": float,
"market_condition_id": str,
"created_at": float,
},
pk="id",
)
self._log.debug("table_created", table="window_snapshots")
if "daily_summary" not in self._db.table_names():
self._db["daily_summary"].create(
{
"date": str,
"total_trades": int,
"wins": int,
"losses": int,
"total_pnl": float,
"total_fees": float,
"total_volume": float,
"best_trade_pnl": float,
"worst_trade_pnl": float,
},
pk="date",
)
self._log.debug("table_created", table="daily_summary")
if "balance_history" not in self._db.table_names():
self._db["balance_history"].create(
{
"id": int,
"timestamp": float,
"balance": float,
"pnl": float,
"event": str,
},
pk="id",
)
self._log.debug("table_created", table="balance_history")
if "oracle_snapshots" not in self._db.table_names():
self._db["oracle_snapshots"].create(
{
"id": int,
"timestamp": float,
"asset": str,
"oracle_price": float,
"cex_price": float,
"deviation_pct": float,
"oracle_lag_sec": float,
"oracle_round_id": str,
},
pk="id",
)
self._log.debug("table_created", table="oracle_snapshots")
# ------------------------------------------------------------------
# Trade CRUD
# ------------------------------------------------------------------
def log_trade(self, trade: Trade) -> None:
"""Insert a new trade record derived from a :class:`Trade` dataclass."""
row = {
"id": trade.id,
"asset": trade.signal.asset.value,
"timeframe": trade.signal.timeframe.value,
"direction": trade.signal.direction.value,
"token_id": trade.signal.token_id,
"entry_price": trade.signal.price,
"fill_price": trade.fill_price,
"size": trade.signal.size,
"fee": trade.fee,
"pnl": trade.pnl,
"status": trade.status.value,
"signal_edge": trade.signal.edge,
"signal_prob": trade.signal.estimated_prob,
"created_at": trade.created_at,
"updated_at": trade.updated_at,
}
self._db["trades"].insert(row)
self._log.info("trade_logged", trade_id=trade.id, asset=row["asset"])
def update_trade(self, trade_id: str, **fields: Any) -> None:
"""Update arbitrary fields on an existing trade row."""
fields["updated_at"] = time.time()
self._db["trades"].update(trade_id, fields)
self._log.info("trade_updated", trade_id=trade_id, fields=list(fields.keys()))
def get_trades(
self,
start_time: Optional[float] = None,
end_time: Optional[float] = None,
asset: Optional[str] = None,
) -> list[dict[str, Any]]:
"""Return trades matching optional time-range and asset filters."""
clauses: list[str] = []
params: list[Any] = []
if start_time is not None:
clauses.append("created_at >= ?")
params.append(start_time)
if end_time is not None:
clauses.append("created_at <= ?")
params.append(end_time)
if asset is not None:
clauses.append("asset = ?")
params.append(asset)
where = " AND ".join(clauses) if clauses else "1=1"
sql = f"SELECT * FROM trades WHERE {where} ORDER BY created_at DESC"
return list(self._db.execute(sql, params).fetchall())
# ------------------------------------------------------------------
# Window snapshots
# ------------------------------------------------------------------
def log_window(self, window: WindowState) -> None:
"""Snapshot a completed window into the database."""
row = {
"asset": window.asset.value,
"timeframe": window.timeframe.value,
"start_price": window.start_price,
"end_price": window.current_price,
"price_change_pct": window.price_change_pct,
"window_start": window.window_start_time,
"window_end": window.window_end_time,
"market_condition_id": (
window.market.condition_id if window.market else None
),
"created_at": time.time(),
}
self._db["window_snapshots"].insert(row)
self._log.info(
"window_logged",
asset=row["asset"],
timeframe=row["timeframe"],
change_pct=row["price_change_pct"],
)
# ------------------------------------------------------------------
# Daily summary helpers
# ------------------------------------------------------------------
def get_daily_summary(self, date: str) -> dict[str, Any]:
"""Compute (but do not store) daily stats for *date* (``YYYY-MM-DD``)."""
sql = """
SELECT
COUNT(*) AS total_trades,
SUM(CASE WHEN pnl > 0 THEN 1 ELSE 0 END) AS wins,
SUM(CASE WHEN pnl <= 0 THEN 1 ELSE 0 END) AS losses,
COALESCE(SUM(pnl), 0.0) AS total_pnl,
COALESCE(SUM(fee), 0.0) AS total_fees,
COALESCE(SUM(size * fill_price), 0.0) AS total_volume,
COALESCE(MAX(pnl), 0.0) AS best_trade_pnl,
COALESCE(MIN(pnl), 0.0) AS worst_trade_pnl
FROM trades
WHERE date(created_at, 'unixepoch') = ?
"""
row = self._db.execute(sql, [date]).fetchone()
return {
"date": date,
"total_trades": row[0],
"wins": row[1],
"losses": row[2],
"total_pnl": row[3],
"total_fees": row[4],
"total_volume": row[5],
"best_trade_pnl": row[6],
"worst_trade_pnl": row[7],
}
def update_daily_summary(self, date: str) -> None:
"""Recalculate and upsert the daily summary row for *date*."""
summary = self.get_daily_summary(date)
self._db["daily_summary"].upsert(summary, pk="date")
self._log.info("daily_summary_updated", date=date, pnl=summary["total_pnl"])
# ------------------------------------------------------------------
# Quick-access aggregates
# ------------------------------------------------------------------
# ------------------------------------------------------------------
# Balance history
# ------------------------------------------------------------------
def log_balance(self, balance: float, pnl: float, event: str = "update") -> None:
"""Record a balance snapshot."""
self._db["balance_history"].insert({
"timestamp": time.time(),
"balance": balance,
"pnl": pnl,
"event": event,
})
def log_oracle(
self,
asset: str,
oracle_price: float,
cex_price: float,
deviation_pct: float,
oracle_lag_sec: float,
oracle_round_id: int = 0,
) -> None:
"""Record an oracle price snapshot."""
self._db["oracle_snapshots"].insert({
"timestamp": time.time(),
"asset": asset,
"oracle_price": oracle_price,
"cex_price": cex_price,
"deviation_pct": round(deviation_pct, 4),
"oracle_lag_sec": round(oracle_lag_sec, 1),
"oracle_round_id": str(oracle_round_id),
})
def get_latest_balance(self) -> Optional[float]:
"""Return the most recent balance, or None."""
row = self._db.execute(
"SELECT balance FROM balance_history ORDER BY timestamp DESC LIMIT 1"
).fetchone()
return float(row[0]) if row else None
# ------------------------------------------------------------------
# Quick-access aggregates
# ------------------------------------------------------------------
def get_total_pnl(self) -> float:
"""Return the all-time cumulative PnL."""
row = self._db.execute("SELECT COALESCE(SUM(pnl), 0.0) FROM trades").fetchone()
return float(row[0])
def get_today_pnl(self) -> float:
"""Return today's cumulative PnL (UTC)."""
today = datetime.now(timezone.utc).strftime("%Y-%m-%d")
row = self._db.execute(
"SELECT COALESCE(SUM(pnl), 0.0) FROM trades "
"WHERE date(created_at, 'unixepoch') = ?",
[today],
).fetchone()
return float(row[0])
def get_today_trade_count(self) -> int:
"""Return the number of trades placed today (UTC)."""
today = datetime.now(timezone.utc).strftime("%Y-%m-%d")
row = self._db.execute(
"SELECT COUNT(*) FROM trades "
"WHERE date(created_at, 'unixepoch') = ?",
[today],
).fetchone()
return int(row[0])

136
src/data/models.py Normal file
View File

@@ -0,0 +1,136 @@
"""Core data models used throughout the Polymarket Arbitrage Bot."""
from dataclasses import dataclass, field
from enum import Enum
from typing import Optional
import time
class Asset(str, Enum):
BTC = "BTC"
ETH = "ETH"
SOL = "SOL"
class Timeframe(str, Enum):
FIVE_MIN = "5M"
FIFTEEN_MIN = "15M"
class Direction(str, Enum):
UP = "UP"
DOWN = "DOWN"
class OrderSide(str, Enum):
BUY = "BUY"
SELL = "SELL"
class TradeStatus(str, Enum):
PENDING = "PENDING"
FILLED = "FILLED"
PARTIALLY_FILLED = "PARTIALLY_FILLED"
CANCELLED = "CANCELLED"
FAILED = "FAILED"
@dataclass
class ActiveMarket:
condition_id: str
up_token_id: str
down_token_id: str
asset: Asset
timeframe: Timeframe
end_date: str # ISO format resolution time
question: str
description: str = ""
@dataclass
class WindowState:
asset: Asset
timeframe: Timeframe
window_start_time: float # Unix timestamp
window_end_time: float # Unix timestamp
start_price: Optional[float] = None
current_price: Optional[float] = None
market: Optional[ActiveMarket] = None
@property
def time_remaining(self) -> float:
return max(0, self.window_end_time - time.time())
@property
def price_change_pct(self) -> Optional[float]:
if self.start_price and self.current_price and self.start_price > 0:
return (self.current_price - self.start_price) / self.start_price * 100
return None
@dataclass
class Signal:
direction: Direction
asset: Asset
timeframe: Timeframe
token_id: str
price: float # Polymarket entry price
size: int # Number of shares
edge: float # Estimated edge after fees
estimated_prob: float # Estimated true probability
timestamp: float = field(default_factory=time.time)
@dataclass
class Trade:
id: str
signal: Signal
order_id: str = ""
status: TradeStatus = TradeStatus.PENDING
fill_price: float = 0.0
fill_size: int = 0
fee: float = 0.0
pnl: float = 0.0
created_at: float = field(default_factory=time.time)
updated_at: float = field(default_factory=time.time)
@dataclass
class Position:
market_id: str
asset: Asset
timeframe: Timeframe
direction: Direction
token_id: str
size: int
avg_price: float
current_value: float = 0.0
unrealized_pnl: float = 0.0
@dataclass
class OrderBookLevel:
price: float
size: float
@dataclass
class OrderBookSnapshot:
token_id: str
bids: list[OrderBookLevel] = field(default_factory=list)
asks: list[OrderBookLevel] = field(default_factory=list)
timestamp: float = field(default_factory=time.time)
@property
def best_bid(self) -> Optional[float]:
return self.bids[0].price if self.bids else None
@property
def best_ask(self) -> Optional[float]:
return self.asks[0].price if self.asks else None
@property
def spread(self) -> Optional[float]:
if self.best_bid is not None and self.best_ask is not None:
return self.best_ask - self.best_bid
return None

View File

@@ -0,0 +1,7 @@
"""Execution engine modules."""
from src.execution.clob_client import ClobClientWrapper
from src.execution.order_manager import OrderManager
from src.execution.position_tracker import PositionTracker
__all__ = ["ClobClientWrapper", "OrderManager", "PositionTracker"]

View File

@@ -0,0 +1,235 @@
"""Polymarket CLOB API wrapper with EIP-712 authentication.
Handles order creation, cancellation, and position queries via
the py-clob-client SDK.
"""
from __future__ import annotations
import asyncio
from typing import Any, Optional
import structlog
from src.config import PolymarketConfig
log = structlog.get_logger()
class ClobClientWrapper:
"""Async-friendly wrapper around py-clob-client.
The underlying SDK is synchronous, so heavy calls are delegated
to a thread-pool executor to avoid blocking the event loop.
"""
def __init__(self, config: PolymarketConfig) -> None:
self.config = config
self._client: Any = None
self._initialized = False
async def initialize(self) -> None:
"""Initialize the CLOB client with credentials.
Must be called before any trading operations.
"""
if self._initialized:
return
if not self.config.private_key:
log.warning("clob_client_no_key", msg="No private key — running in read-only mode")
return
loop = asyncio.get_event_loop()
try:
self._client = await loop.run_in_executor(None, self._create_client)
self._initialized = True
log.info("clob_client_initialized")
except Exception:
log.exception("clob_client_init_failed")
def _create_client(self) -> Any:
"""Create the synchronous CLOB client (runs in executor)."""
from py_clob_client.client import ClobClient
client = ClobClient(
host=self.config.clob_url,
key=self.config.private_key,
chain_id=self.config.chain_id,
signature_type=self.config.signature_type,
funder=self.config.proxy_wallet if self.config.proxy_wallet else None,
)
# Set API credentials if available
if self.config.api_key:
client.set_api_creds(client.create_or_derive_api_creds())
else:
# Generate new API key
api_creds = client.create_api_key()
client.set_api_creds(api_creds)
log.info("clob_api_key_created")
return client
@property
def is_ready(self) -> bool:
return self._initialized and self._client is not None
# ------------------------------------------------------------------
# Order operations
# ------------------------------------------------------------------
async def place_limit_order(
self,
token_id: str,
price: float,
size: int,
side: str = "BUY",
) -> Optional[dict]:
"""Place a limit order on the CLOB.
Args:
token_id: The CLOB token ID (Up or Down outcome).
price: Limit price (0.01 to 0.99).
size: Number of shares.
side: "BUY" or "SELL".
Returns:
Order response dict or None on failure.
"""
if not self.is_ready:
log.error("clob_not_ready", msg="Client not initialized")
return None
loop = asyncio.get_event_loop()
try:
result = await loop.run_in_executor(
None, self._place_order_sync, token_id, price, size, side
)
log.info(
"order_placed",
token_id=token_id[:16],
price=price,
size=size,
side=side,
order_id=result.get("orderID", ""),
)
return result
except Exception:
log.exception("order_place_failed", token_id=token_id[:16], price=price, size=size)
return None
def _place_order_sync(self, token_id: str, price: float, size: int, side: str) -> dict:
from py_clob_client.clob_types import OrderArgs, OrderType
order_args = OrderArgs(
token_id=token_id,
price=price,
size=size,
side=side,
)
signed_order = self._client.create_order(order_args)
return self._client.post_order(signed_order, orderType=OrderType.GTC)
async def place_market_order(
self,
token_id: str,
size: int,
side: str = "BUY",
) -> Optional[dict]:
"""Place a market order (FOK at worst available price)."""
if not self.is_ready:
return None
loop = asyncio.get_event_loop()
try:
result = await loop.run_in_executor(
None, self._place_market_order_sync, token_id, size, side
)
log.info("market_order_placed", token_id=token_id[:16], size=size, side=side)
return result
except Exception:
log.exception("market_order_failed")
return None
def _place_market_order_sync(self, token_id: str, size: int, side: str) -> dict:
from py_clob_client.clob_types import OrderArgs, OrderType
# Use price 0.99 for BUY (worst case) or 0.01 for SELL
price = 0.99 if side == "BUY" else 0.01
order_args = OrderArgs(
token_id=token_id,
price=price,
size=size,
side=side,
)
signed_order = self._client.create_order(order_args)
return self._client.post_order(signed_order, orderType=OrderType.FOK)
async def cancel_order(self, order_id: str) -> bool:
"""Cancel a specific order by ID."""
if not self.is_ready:
return False
loop = asyncio.get_event_loop()
try:
await loop.run_in_executor(None, self._client.cancel, order_id)
log.info("order_cancelled", order_id=order_id)
return True
except Exception:
log.exception("order_cancel_failed", order_id=order_id)
return False
async def cancel_all_orders(self) -> bool:
"""Cancel all open orders."""
if not self.is_ready:
return False
loop = asyncio.get_event_loop()
try:
await loop.run_in_executor(None, self._client.cancel_all)
log.info("all_orders_cancelled")
return True
except Exception:
log.exception("cancel_all_failed")
return False
# ------------------------------------------------------------------
# Query operations
# ------------------------------------------------------------------
async def get_order(self, order_id: str) -> Optional[dict]:
"""Fetch order status by ID."""
if not self.is_ready:
return None
loop = asyncio.get_event_loop()
try:
return await loop.run_in_executor(None, self._client.get_order, order_id)
except Exception:
log.exception("get_order_failed", order_id=order_id)
return None
async def get_open_orders(self) -> list[dict]:
"""Fetch all open orders."""
if not self.is_ready:
return []
loop = asyncio.get_event_loop()
try:
result = await loop.run_in_executor(None, self._client.get_orders)
return result if isinstance(result, list) else []
except Exception:
log.exception("get_open_orders_failed")
return []
async def get_orderbook(self, token_id: str) -> Optional[dict]:
"""Fetch current orderbook for a token."""
loop = asyncio.get_event_loop()
try:
return await loop.run_in_executor(
None, self._client.get_order_book, token_id
)
except Exception:
log.exception("get_orderbook_failed", token_id=token_id[:16])
return None

View File

@@ -0,0 +1,243 @@
"""Order Manager — handles order lifecycle from signal to fill.
Manages order creation, tracking, modification, and cancellation
with position limit enforcement.
"""
from __future__ import annotations
import asyncio
import time
import uuid
from typing import Callable, Optional
import structlog
from src.config import Config
from src.data.models import Direction, Signal, Trade, TradeStatus
from src.execution.clob_client import ClobClientWrapper
log = structlog.get_logger()
class OrderManager:
"""Manages the full lifecycle of orders from signal to fill.
Enforces position limits, tracks pending/active orders,
and handles cancellation near resolution time.
"""
def __init__(self, config: Config, clob_client: ClobClientWrapper) -> None:
self.config = config
self.clob = clob_client
self._pending_orders: dict[str, Trade] = {}
self._active_trades: dict[str, Trade] = {}
self._on_fill_callbacks: list[Callable[[Trade], None]] = []
self._running = False
def on_fill(self, callback: Callable[[Trade], None]) -> None:
"""Register callback for when an order is filled."""
self._on_fill_callbacks.append(callback)
# ------------------------------------------------------------------
# Order submission
# ------------------------------------------------------------------
async def submit_signal(self, signal: Signal) -> Optional[Trade]:
"""Submit a signal for execution.
Validates against risk limits, creates the order, and tracks it.
Returns a Trade object if the order was submitted.
"""
# Check concurrent position limit
active_count = len(self._active_trades) + len(self._pending_orders)
if active_count >= self.config.risk.max_concurrent_positions:
log.warning(
"position_limit_reached",
active=active_count,
limit=self.config.risk.max_concurrent_positions,
)
return None
# Check per-market exposure
market_exposure = self._get_market_exposure(signal.token_id)
new_exposure = signal.price * signal.size
if market_exposure + new_exposure > self.config.risk.max_position_per_market_usd:
log.warning(
"market_exposure_limit",
token_id=signal.token_id[:16],
current=round(market_exposure, 2),
new=round(new_exposure, 2),
)
return None
# Create trade record
trade = Trade(
id=str(uuid.uuid4())[:8],
signal=signal,
status=TradeStatus.PENDING,
)
# Submit order to CLOB
result = await self.clob.place_limit_order(
token_id=signal.token_id,
price=signal.price,
size=signal.size,
side="BUY",
)
if result is None:
trade.status = TradeStatus.FAILED
log.error("order_submission_failed", trade_id=trade.id)
return trade
trade.order_id = result.get("orderID", "")
self._pending_orders[trade.id] = trade
log.info(
"order_submitted",
trade_id=trade.id,
order_id=trade.order_id,
asset=signal.asset.value,
direction=signal.direction.value,
price=signal.price,
size=signal.size,
)
return trade
# ------------------------------------------------------------------
# Order monitoring
# ------------------------------------------------------------------
async def check_order_status(self, trade: Trade) -> Trade:
"""Check the status of a pending order and update trade accordingly."""
if not trade.order_id:
return trade
order_info = await self.clob.get_order(trade.order_id)
if order_info is None:
return trade
status = order_info.get("status", "").upper()
if status in ("MATCHED", "FILLED"):
trade.status = TradeStatus.FILLED
trade.fill_price = float(order_info.get("price", trade.signal.price))
trade.fill_size = int(order_info.get("size_matched", trade.signal.size))
trade.fee = self._calculate_fee(trade)
trade.updated_at = time.time()
# Move from pending to active
self._pending_orders.pop(trade.id, None)
self._active_trades[trade.id] = trade
for cb in self._on_fill_callbacks:
try:
cb(trade)
except Exception:
log.exception("fill_callback_error")
log.info(
"order_filled",
trade_id=trade.id,
fill_price=trade.fill_price,
fill_size=trade.fill_size,
fee=round(trade.fee, 4),
)
elif status == "CANCELLED":
trade.status = TradeStatus.CANCELLED
trade.updated_at = time.time()
self._pending_orders.pop(trade.id, None)
return trade
async def monitor_loop(self, interval: float = 2.0) -> None:
"""Continuously monitor pending orders for fills."""
self._running = True
while self._running:
for trade_id in list(self._pending_orders.keys()):
trade = self._pending_orders.get(trade_id)
if trade:
await self.check_order_status(trade)
await asyncio.sleep(interval)
def stop(self) -> None:
self._running = False
# ------------------------------------------------------------------
# Cancellation
# ------------------------------------------------------------------
async def cancel_pending(self, trade_id: str) -> bool:
"""Cancel a specific pending order."""
trade = self._pending_orders.get(trade_id)
if not trade or not trade.order_id:
return False
success = await self.clob.cancel_order(trade.order_id)
if success:
trade.status = TradeStatus.CANCELLED
trade.updated_at = time.time()
self._pending_orders.pop(trade_id, None)
return success
async def cancel_all_pending(self) -> int:
"""Cancel all pending orders. Returns count of cancelled orders."""
cancelled = 0
for trade_id in list(self._pending_orders.keys()):
if await self.cancel_pending(trade_id):
cancelled += 1
return cancelled
async def cancel_expiring_orders(self, seconds_before_resolution: float = 5.0) -> int:
"""Cancel orders that are too close to market resolution."""
cancelled = 0
now = time.time()
for trade_id, trade in list(self._pending_orders.items()):
# Check if the signal's window is about to resolve
# (Signal doesn't store window_end_time, but we can infer from context)
if trade.created_at + 300 - now < seconds_before_resolution: # Rough heuristic
if await self.cancel_pending(trade_id):
cancelled += 1
log.info("expiring_order_cancelled", trade_id=trade_id)
return cancelled
# ------------------------------------------------------------------
# Queries
# ------------------------------------------------------------------
def get_pending_trades(self) -> list[Trade]:
return list(self._pending_orders.values())
def get_active_trades(self) -> list[Trade]:
return list(self._active_trades.values())
def get_all_trades(self) -> list[Trade]:
return self.get_pending_trades() + self.get_active_trades()
# ------------------------------------------------------------------
# Helpers
# ------------------------------------------------------------------
def _get_market_exposure(self, token_id: str) -> float:
"""Calculate current dollar exposure for a given token."""
exposure = 0.0
for trade in list(self._pending_orders.values()) + list(self._active_trades.values()):
if trade.signal.token_id == token_id:
exposure += trade.signal.price * trade.signal.size
return exposure
def _calculate_fee(self, trade: Trade) -> float:
"""Calculate the taker fee for a filled trade."""
tf = trade.signal.timeframe.value
fee_rate = self.config.fees.fee_for_timeframe(tf)
# Fee is on the potential payout (winning amount)
payout = trade.fill_size * 1.0 # $1 per share if winning
cost = trade.fill_price * trade.fill_size
profit = payout - cost
return max(0, profit * fee_rate)

View File

@@ -0,0 +1,211 @@
"""Position Tracker — real-time position & PnL tracking.
Monitors active positions, calculates unrealized PnL based on
current Polymarket prices, and enforces exposure limits.
"""
from __future__ import annotations
import time
from typing import Optional
import structlog
from src.config import Config
from src.data.models import (
Asset,
Direction,
OrderBookSnapshot,
Position,
Timeframe,
Trade,
TradeStatus,
)
log = structlog.get_logger()
class PositionTracker:
"""Tracks all open positions and computes real-time PnL.
Positions are created from filled trades and closed when
the underlying market resolves or the position is sold.
"""
def __init__(self, config: Config) -> None:
self.config = config
self._positions: dict[str, Position] = {} # keyed by token_id
self._realized_pnl: float = 0.0
self._total_fees: float = 0.0
self._trade_count: int = 0
self._win_count: int = 0
# ------------------------------------------------------------------
# Position management
# ------------------------------------------------------------------
def open_position(self, trade: Trade) -> Position:
"""Open or add to a position from a filled trade."""
token_id = trade.signal.token_id
signal = trade.signal
if token_id in self._positions:
# Add to existing position
pos = self._positions[token_id]
total_cost = pos.avg_price * pos.size + trade.fill_price * trade.fill_size
new_size = pos.size + trade.fill_size
pos.avg_price = total_cost / new_size if new_size > 0 else 0
pos.size = new_size
else:
pos = Position(
market_id=token_id,
asset=signal.asset,
timeframe=signal.timeframe,
direction=signal.direction,
token_id=token_id,
size=trade.fill_size,
avg_price=trade.fill_price,
)
self._positions[token_id] = pos
self._trade_count += 1
self._total_fees += trade.fee
log.info(
"position_opened",
token_id=token_id[:16],
asset=signal.asset.value,
direction=signal.direction.value,
size=pos.size,
avg_price=round(pos.avg_price, 4),
)
return pos
def close_position(self, token_id: str, resolution_price: float) -> Optional[float]:
"""Close a position at resolution.
Args:
token_id: The token ID of the position.
resolution_price: 1.0 if the outcome won, 0.0 if lost.
Returns:
Realized PnL for this position, or None if no position.
"""
pos = self._positions.pop(token_id, None)
if pos is None:
return None
payout = resolution_price * pos.size
cost = pos.avg_price * pos.size
pnl = payout - cost
self._realized_pnl += pnl
if pnl > 0:
self._win_count += 1
log.info(
"position_closed",
token_id=token_id[:16],
asset=pos.asset.value,
direction=pos.direction.value,
size=pos.size,
avg_price=round(pos.avg_price, 4),
payout=round(payout, 2),
pnl=round(pnl, 2),
result="WIN" if pnl > 0 else "LOSS",
)
return pnl
# ------------------------------------------------------------------
# Mark-to-market
# ------------------------------------------------------------------
def update_mark(self, token_id: str, current_price: float) -> None:
"""Update mark-to-market for a position."""
pos = self._positions.get(token_id)
if pos is None:
return
pos.current_value = current_price * pos.size
pos.unrealized_pnl = pos.current_value - (pos.avg_price * pos.size)
def update_from_orderbooks(self, orderbooks: dict[str, OrderBookSnapshot]) -> None:
"""Update all positions from latest orderbook data."""
for token_id, pos in self._positions.items():
book = orderbooks.get(token_id)
if book and book.best_bid is not None:
self.update_mark(token_id, book.best_bid)
# ------------------------------------------------------------------
# Queries
# ------------------------------------------------------------------
def get_position(self, token_id: str) -> Optional[Position]:
return self._positions.get(token_id)
def get_all_positions(self) -> list[Position]:
return list(self._positions.values())
def get_positions_by_asset(self, asset: str) -> list[Position]:
a = Asset(asset)
return [p for p in self._positions.values() if p.asset == a]
@property
def total_unrealized_pnl(self) -> float:
return sum(p.unrealized_pnl for p in self._positions.values())
@property
def total_realized_pnl(self) -> float:
return self._realized_pnl
@property
def total_pnl(self) -> float:
return self._realized_pnl + self.total_unrealized_pnl
@property
def total_fees(self) -> float:
return self._total_fees
@property
def total_exposure(self) -> float:
return sum(p.avg_price * p.size for p in self._positions.values())
@property
def position_count(self) -> int:
return len(self._positions)
@property
def trade_count(self) -> int:
return self._trade_count
@property
def win_rate(self) -> float:
closed = self._trade_count - len(self._positions)
return self._win_count / closed if closed > 0 else 0.0
# ------------------------------------------------------------------
# Risk checks
# ------------------------------------------------------------------
def is_exposure_ok(self, additional_usd: float = 0) -> bool:
"""Check if total exposure is within limits."""
return (self.total_exposure + additional_usd) <= self.config.risk.max_total_exposure_usd
def is_daily_loss_ok(self, daily_pnl: float) -> bool:
"""Check if daily loss is within limits."""
return daily_pnl > -self.config.risk.max_daily_loss_usd
def get_summary(self) -> dict:
"""Return a summary dict for logging/display."""
return {
"positions": self.position_count,
"total_exposure": round(self.total_exposure, 2),
"unrealized_pnl": round(self.total_unrealized_pnl, 2),
"realized_pnl": round(self.total_realized_pnl, 2),
"total_pnl": round(self.total_pnl, 2),
"total_fees": round(self.total_fees, 2),
"trades": self.trade_count,
"win_rate": round(self.win_rate * 100, 1),
}

6
src/feeds/__init__.py Normal file
View File

@@ -0,0 +1,6 @@
"""Price feed integrations for the Polymarket Arbitrage Bot."""
from .binance_ws import BinanceFeed, PriceCallback
from .polymarket_ws import PolymarketFeed
__all__ = ["BinanceFeed", "PolymarketFeed", "PriceCallback"]

250
src/feeds/binance_ws.py Normal file
View File

@@ -0,0 +1,250 @@
"""Binance WebSocket price feed for real-time trade data.
Connects to Binance combined streams for BTC, ETH, and SOL trade data,
providing low-latency price updates via an async callback pattern.
"""
from __future__ import annotations
import asyncio
import json
import time
from typing import Callable, Optional
import structlog
import websockets
from websockets.exceptions import ConnectionClosed, InvalidStatusCode
logger = structlog.get_logger(__name__)
# Stream symbol mapping: Binance stream name -> canonical symbol
_STREAM_TO_SYMBOL: dict[str, str] = {
"btcusdt@trade": "BTC",
"ethusdt@trade": "ETH",
"solusdt@trade": "SOL",
}
_SYMBOL_TO_STREAM: dict[str, str] = {v: k for k, v in _STREAM_TO_SYMBOL.items()}
PriceCallback = Callable[[str, float, float, float], None]
_DEFAULT_URL = (
"wss://stream.binance.com:9443/stream"
"?streams=btcusdt@trade/ethusdt@trade/solusdt@trade"
)
_PING_INTERVAL_S = 30
_RECONNECT_BASE_S = 1.0
_RECONNECT_MAX_S = 30.0
class BinanceFeed:
"""Async Binance WebSocket price feed with auto-reconnect.
Usage::
feed = BinanceFeed()
feed.subscribe(my_callback)
await feed.start() # runs until stop() is called
"""
def __init__(self, url: str = _DEFAULT_URL) -> None:
self._url = url
self._callbacks: list[PriceCallback] = []
self._latest_prices: dict[str, float] = {}
self._last_recv_ts: dict[str, float] = {}
self._connected = False
self._ws: Optional[websockets.WebSocketClientProtocol] = None
self._stop_event: asyncio.Event = asyncio.Event()
self._tasks: list[asyncio.Task[None]] = []
self._reconnect_delay = _RECONNECT_BASE_S
self._log = logger.bind(component="BinanceFeed")
# ------------------------------------------------------------------
# Public API
# ------------------------------------------------------------------
@property
def is_connected(self) -> bool:
"""Return ``True`` when the WebSocket connection is open."""
return self._connected
def subscribe(self, callback: PriceCallback) -> None:
"""Register a price-update callback.
The callback signature is::
callback(symbol: str, price: float, timestamp: float, volume: float)
"""
if callback not in self._callbacks:
self._callbacks.append(callback)
self._log.info("callback_subscribed", total=len(self._callbacks))
def get_latest_price(self, symbol: str) -> Optional[float]:
"""Return the most recent trade price for *symbol*, or ``None``."""
return self._latest_prices.get(symbol.upper())
async def start(self) -> None:
"""Open the WebSocket and begin consuming messages.
Automatically reconnects on disconnection with exponential backoff.
Blocks until :meth:`stop` is called.
"""
self._stop_event.clear()
self._log.info("feed_starting", url=self._url)
while not self._stop_event.is_set():
try:
await self._connect_and_listen()
except (ConnectionClosed, ConnectionError, InvalidStatusCode, OSError) as exc:
self._set_disconnected()
self._log.warning(
"connection_lost",
error=str(exc),
reconnect_in=self._reconnect_delay,
)
except asyncio.CancelledError:
self._log.info("feed_cancelled")
break
except Exception:
self._set_disconnected()
self._log.exception(
"unexpected_error",
reconnect_in=self._reconnect_delay,
)
if self._stop_event.is_set():
break
# Wait with exponential backoff, but allow early exit via stop().
try:
await asyncio.wait_for(
self._stop_event.wait(),
timeout=self._reconnect_delay,
)
break # stop_event was set during the wait
except asyncio.TimeoutError:
pass
self._reconnect_delay = min(
self._reconnect_delay * 2, _RECONNECT_MAX_S
)
await self._cleanup()
self._log.info("feed_stopped")
async def stop(self) -> None:
"""Signal the feed to shut down gracefully."""
self._log.info("feed_stop_requested")
self._stop_event.set()
for task in self._tasks:
task.cancel()
if self._ws is not None:
await self._ws.close()
# ------------------------------------------------------------------
# Internals
# ------------------------------------------------------------------
async def _connect_and_listen(self) -> None:
"""Establish a WebSocket connection and consume messages."""
async with websockets.connect(
self._url,
ping_interval=None, # we manage our own heartbeat
close_timeout=5,
) as ws:
self._ws = ws
self._connected = True
self._reconnect_delay = _RECONNECT_BASE_S
self._log.info("connected")
ping_task = asyncio.create_task(self._heartbeat(ws))
self._tasks.append(ping_task)
try:
async for raw_msg in ws:
if self._stop_event.is_set():
break
self._handle_message(raw_msg)
finally:
ping_task.cancel()
try:
await ping_task
except asyncio.CancelledError:
pass
self._tasks = [t for t in self._tasks if t is not ping_task]
async def _heartbeat(self, ws: websockets.WebSocketClientProtocol) -> None:
"""Send a ping frame every ``_PING_INTERVAL_S`` seconds."""
try:
while True:
await asyncio.sleep(_PING_INTERVAL_S)
pong = await ws.ping()
await asyncio.wait_for(pong, timeout=10)
self._log.debug("heartbeat_ok")
except asyncio.CancelledError:
pass
except Exception as exc:
self._log.warning("heartbeat_failed", error=str(exc))
def _handle_message(self, raw: str | bytes) -> None:
"""Parse a combined-stream message and dispatch callbacks."""
try:
msg = json.loads(raw)
except (json.JSONDecodeError, TypeError) as exc:
self._log.warning("json_parse_error", error=str(exc), raw=raw[:200])
return
stream: str | None = msg.get("stream")
data: dict | None = msg.get("data")
if stream is None or data is None:
self._log.debug("ignored_message", keys=list(msg.keys()))
return
symbol = _STREAM_TO_SYMBOL.get(stream)
if symbol is None:
self._log.debug("unknown_stream", stream=stream)
return
try:
price = float(data["p"])
volume = float(data["q"])
# Binance trade timestamp is in milliseconds
timestamp = float(data["T"]) / 1000.0
except (KeyError, ValueError, TypeError) as exc:
self._log.warning(
"trade_parse_error",
stream=stream,
error=str(exc),
data=data,
)
return
self._latest_prices[symbol] = price
self._last_recv_ts[symbol] = time.time()
for cb in self._callbacks:
try:
cb(symbol, price, timestamp, volume)
except Exception:
self._log.exception("callback_error", symbol=symbol)
def _set_disconnected(self) -> None:
self._connected = False
self._ws = None
async def _cleanup(self) -> None:
"""Cancel lingering tasks and close the socket."""
for task in self._tasks:
task.cancel()
self._tasks.clear()
if self._ws is not None:
try:
await self._ws.close()
except Exception:
pass
self._set_disconnected()

303
src/feeds/polymarket_ws.py Normal file
View File

@@ -0,0 +1,303 @@
"""Polymarket CLOB WebSocket feed for real-time orderbook data.
Connects to the Polymarket WebSocket subscriptions endpoint, subscribes
to token-level orderbook channels, and maintains the latest
:class:`OrderBookSnapshot` per token with auto-reconnect.
"""
from __future__ import annotations
import asyncio
import json
import time
from typing import Callable, Optional
import structlog
import websockets
from websockets.exceptions import ConnectionClosed, InvalidStatusCode
from src.data.models import OrderBookLevel, OrderBookSnapshot
logger = structlog.get_logger(__name__)
OrderBookCallback = Callable[[str, OrderBookSnapshot], None]
_WS_URL = "wss://ws-subscriptions-clob.polymarket.com/ws/"
_RECONNECT_BASE_S = 1.0
_RECONNECT_MAX_S = 60.0
_PING_INTERVAL_S = 30
class PolymarketFeed:
"""Async Polymarket CLOB WebSocket feed with auto-reconnect.
Usage::
feed = PolymarketFeed()
feed.on_orderbook_update(my_callback)
feed.subscribe_market("0xabc...")
await feed.start()
"""
def __init__(self, url: str = _WS_URL) -> None:
self._url = url
self._callbacks: list[OrderBookCallback] = []
self._subscriptions: set[str] = set()
self._orderbooks: dict[str, OrderBookSnapshot] = {}
self._connected = False
self._ws: Optional[websockets.WebSocketClientProtocol] = None
self._stop_event: asyncio.Event = asyncio.Event()
self._tasks: list[asyncio.Task[None]] = []
self._reconnect_delay = _RECONNECT_BASE_S
self._log = logger.bind(component="PolymarketFeed")
# ------------------------------------------------------------------
# Public API
# ------------------------------------------------------------------
@property
def is_connected(self) -> bool:
"""Return ``True`` when the WebSocket connection is open."""
return self._connected
def subscribe_market(self, token_id: str) -> None:
"""Subscribe to orderbook updates for *token_id*.
If the WebSocket is already connected the subscription message is
sent immediately; otherwise it will be sent on the next connect.
"""
self._subscriptions.add(token_id)
self._log.info("market_subscribed", token_id=token_id)
if self._ws is not None and self._connected:
asyncio.ensure_future(self._send_subscribe(token_id))
def unsubscribe_market(self, token_id: str) -> None:
"""Unsubscribe from orderbook updates for *token_id*."""
self._subscriptions.discard(token_id)
self._orderbooks.pop(token_id, None)
self._log.info("market_unsubscribed", token_id=token_id)
if self._ws is not None and self._connected:
asyncio.ensure_future(self._send_unsubscribe(token_id))
def on_orderbook_update(self, callback: OrderBookCallback) -> None:
"""Register a callback invoked on every orderbook update.
Signature::
callback(token_id: str, snapshot: OrderBookSnapshot)
"""
if callback not in self._callbacks:
self._callbacks.append(callback)
self._log.info("callback_registered", total=len(self._callbacks))
def get_orderbook(self, token_id: str) -> Optional[OrderBookSnapshot]:
"""Return the latest cached orderbook for *token_id*, or ``None``."""
return self._orderbooks.get(token_id)
async def start(self) -> None:
"""Connect and begin receiving orderbook data.
Blocks until :meth:`stop` is called. Automatically reconnects
with exponential backoff on connection failures.
"""
self._stop_event.clear()
self._log.info("feed_starting", url=self._url)
while not self._stop_event.is_set():
try:
await self._connect_and_listen()
except (ConnectionClosed, ConnectionError, InvalidStatusCode, OSError) as exc:
self._set_disconnected()
self._log.warning(
"connection_lost",
error=str(exc),
reconnect_in=self._reconnect_delay,
)
except asyncio.CancelledError:
self._log.info("feed_cancelled")
break
except Exception:
self._set_disconnected()
self._log.exception(
"unexpected_error",
reconnect_in=self._reconnect_delay,
)
if self._stop_event.is_set():
break
# Exponential backoff, interruptible by stop().
try:
await asyncio.wait_for(
self._stop_event.wait(),
timeout=self._reconnect_delay,
)
break
except asyncio.TimeoutError:
pass
self._reconnect_delay = min(
self._reconnect_delay * 2, _RECONNECT_MAX_S
)
await self._cleanup()
self._log.info("feed_stopped")
async def stop(self) -> None:
"""Signal the feed to shut down gracefully."""
self._log.info("feed_stop_requested")
self._stop_event.set()
for task in self._tasks:
task.cancel()
if self._ws is not None:
await self._ws.close()
# ------------------------------------------------------------------
# Internals
# ------------------------------------------------------------------
async def _connect_and_listen(self) -> None:
"""Establish WebSocket and consume messages."""
async with websockets.connect(
self._url,
ping_interval=None,
close_timeout=5,
) as ws:
self._ws = ws
self._connected = True
self._reconnect_delay = _RECONNECT_BASE_S
self._log.info("connected")
# Re-subscribe to all tracked markets on (re)connect.
for token_id in self._subscriptions:
await self._send_subscribe(token_id)
ping_task = asyncio.create_task(self._heartbeat(ws))
self._tasks.append(ping_task)
try:
async for raw_msg in ws:
if self._stop_event.is_set():
break
self._handle_message(raw_msg)
finally:
ping_task.cancel()
try:
await ping_task
except asyncio.CancelledError:
pass
self._tasks = [t for t in self._tasks if t is not ping_task]
async def _heartbeat(self, ws: websockets.WebSocketClientProtocol) -> None:
"""Send periodic pings to keep the connection alive."""
try:
while True:
await asyncio.sleep(_PING_INTERVAL_S)
pong = await ws.ping()
await asyncio.wait_for(pong, timeout=10)
self._log.debug("heartbeat_ok")
except asyncio.CancelledError:
pass
except Exception as exc:
self._log.warning("heartbeat_failed", error=str(exc))
async def _send_subscribe(self, token_id: str) -> None:
"""Send a subscribe message for a token's orderbook channel."""
if self._ws is None:
return
msg = {
"type": "subscribe",
"channel": "book",
"assets_ids": [token_id],
}
await self._ws.send(json.dumps(msg))
self._log.debug("subscribe_sent", token_id=token_id)
async def _send_unsubscribe(self, token_id: str) -> None:
"""Send an unsubscribe message for a token's orderbook channel."""
if self._ws is None:
return
msg = {
"type": "unsubscribe",
"channel": "book",
"assets_ids": [token_id],
}
await self._ws.send(json.dumps(msg))
self._log.debug("unsubscribe_sent", token_id=token_id)
def _handle_message(self, raw: str | bytes) -> None:
"""Parse an incoming message and update the local orderbook cache."""
try:
msg = json.loads(raw)
except (json.JSONDecodeError, TypeError) as exc:
self._log.warning("json_parse_error", error=str(exc))
return
msg_type = msg.get("type") or msg.get("event_type")
# Handle book snapshot or delta messages
if msg_type in ("book", "book_snapshot", "book_delta"):
self._process_book_message(msg)
elif msg_type == "error":
self._log.error("server_error", message=msg.get("message"))
else:
self._log.debug("ignored_message", msg_type=msg_type)
def _process_book_message(self, msg: dict) -> None:
"""Extract bids/asks from a book message and update state."""
# Polymarket sends asset_id at the top level of book messages.
token_id: str | None = msg.get("asset_id") or msg.get("market")
if token_id is None:
self._log.debug("book_message_missing_token", keys=list(msg.keys()))
return
bids = [
OrderBookLevel(price=float(b["price"]), size=float(b["size"]))
for b in msg.get("bids", [])
]
asks = [
OrderBookLevel(price=float(a["price"]), size=float(a["size"]))
for a in msg.get("asks", [])
]
# Sort: bids descending, asks ascending
bids.sort(key=lambda lvl: lvl.price, reverse=True)
asks.sort(key=lambda lvl: lvl.price)
snapshot = OrderBookSnapshot(
token_id=token_id,
bids=bids,
asks=asks,
timestamp=time.time(),
)
self._orderbooks[token_id] = snapshot
# Dispatch to registered callbacks
for cb in self._callbacks:
try:
cb(token_id, snapshot)
except Exception:
self._log.exception("callback_error", token_id=token_id)
def _set_disconnected(self) -> None:
self._connected = False
self._ws = None
async def _cleanup(self) -> None:
"""Cancel lingering tasks and close the socket."""
for task in self._tasks:
task.cancel()
self._tasks.clear()
if self._ws is not None:
try:
await self._ws.close()
except Exception:
pass
self._set_disconnected()

762
src/main.py Normal file
View File

@@ -0,0 +1,762 @@
"""Main entrypoint for the Polymarket Temporal Arbitrage Bot.
Ties together all components: market discovery, window tracking, Binance and
Polymarket price feeds, strategy evaluation, order execution, position tracking,
risk management, and notifications in a single async event loop.
"""
from __future__ import annotations
import asyncio
import signal
import sys
import time
from datetime import datetime, timezone
from typing import Optional
import structlog
from src.config import Config, load_config
from src.data import TradeDB
from src.data.models import (
ActiveMarket,
Asset,
Direction,
OrderBookSnapshot,
Signal,
Timeframe,
Trade,
TradeStatus,
WindowState,
)
from src.execution.clob_client import ClobClientWrapper
from src.execution.order_manager import OrderManager
from src.execution.position_tracker import PositionTracker
from src.feeds import BinanceFeed, PolymarketFeed
from src.market import MarketDiscovery, WindowTracker
from src.risk.risk_manager import RiskManager
from src.strategy import SignalAggregator
from src.utils import get_logger, setup_logging
from src.utils.telegram import TelegramNotifier
# ---------------------------------------------------------------------------
# Constants
# ---------------------------------------------------------------------------
_STATUS_INTERVAL_S: float = 30.0
_DISCOVERY_INTERVAL_S: float = 30.0
_ORDER_MONITOR_INTERVAL_S: float = 2.0
_BALANCE_SNAPSHOT_INTERVAL_S: float = 60.0
_DAILY_SUMMARY_HOUR_UTC: int = 0 # midnight UTC
_BANNER = r"""
____ _ _ _ _ _
| _ \ ___ | |_ _ _ __ ___ __ _ _ __| | _______| |_ / \ _ __| |__
| |_) / _ \| | | | | '_ ` _ \ / _` | '__| |/ / _ \ __/ / _ \ | '__| '_ \
| __/ (_) | | |_| | | | | | | (_| | | | < __/ |_ / ___ \| | | |_) |
|_| \___/|_|\__, |_| |_| |_|\__,_|_| |_|\_\___|\__| /_/ \_\_| |_.__/
|___/
"""
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _print_startup_banner(cfg: Config) -> None:
"""Print a startup banner summarising the active configuration."""
print(_BANNER)
strategies = []
if cfg.temporal_arb.enabled:
strategies.append("temporal_arb")
if cfg.sum_to_one.enabled:
strategies.append("sum_to_one")
if cfg.spread_capture.enabled:
strategies.append("spread_capture")
print(f" Mode: {cfg.general.mode}")
print(f" Assets: {', '.join(cfg.general.assets)}")
print(f" Timeframes: {', '.join(cfg.general.timeframes)}")
print(f" Strategies: {', '.join(strategies) or 'none'}")
print(f" Log level: {cfg.general.log_level}")
print(f" Binance WS: {cfg.binance.ws_url}")
print(f" Polymarket: {cfg.polymarket.ws_url}")
print()
def _link_markets_to_tracker(
markets: list[ActiveMarket],
tracker: WindowTracker,
log: structlog.stdlib.BoundLogger,
) -> None:
"""Link each discovered market to the corresponding window in the tracker."""
for mkt in markets:
tracker.link_market(
asset=mkt.asset.value,
timeframe=mkt.timeframe.value,
market=mkt,
)
log.info(
"market_linked_to_window",
asset=mkt.asset.value,
timeframe=mkt.timeframe.value,
condition_id=mkt.condition_id,
)
def _subscribe_market_tokens(
markets: list[ActiveMarket],
poly_feed: PolymarketFeed,
log: structlog.stdlib.BoundLogger,
) -> None:
"""Subscribe the Polymarket feed to token IDs from discovered markets."""
for mkt in markets:
poly_feed.subscribe_market(mkt.up_token_id)
poly_feed.subscribe_market(mkt.down_token_id)
log.debug(
"poly_tokens_subscribed",
condition_id=mkt.condition_id,
up_token=mkt.up_token_id[:12],
down_token=mkt.down_token_id[:12],
)
# ---------------------------------------------------------------------------
# Core application
# ---------------------------------------------------------------------------
class ArbBot:
"""Full-featured arbitrage bot with strategy execution and risk management."""
def __init__(self, config: Config) -> None:
self._cfg = config
self._log = get_logger("arb_bot")
# Components
self._db: Optional[TradeDB] = None
self._discovery: Optional[MarketDiscovery] = None
self._tracker: Optional[WindowTracker] = None
self._binance_feed: Optional[BinanceFeed] = None
self._poly_feed: Optional[PolymarketFeed] = None
self._clob: Optional[ClobClientWrapper] = None
self._order_manager: Optional[OrderManager] = None
self._position_tracker: Optional[PositionTracker] = None
self._risk_manager: Optional[RiskManager] = None
self._signal_aggregator: Optional[SignalAggregator] = None
self._telegram: Optional[TelegramNotifier] = None
# Discovered markets cache
self._active_markets: list[ActiveMarket] = []
# Orderbook state
self._orderbooks: dict[str, OrderBookSnapshot] = {}
# Balance tracking
self._balance: float = config.general.starting_balance
# Shutdown coordination
self._shutdown_event = asyncio.Event()
# Daily summary tracking
self._last_daily_summary_date: str = ""
# ------------------------------------------------------------------
# Callbacks
# ------------------------------------------------------------------
def _on_binance_price(
self, symbol: str, price: float, timestamp: float, volume: float
) -> None:
"""Callback invoked for each Binance trade tick."""
if self._tracker is not None:
self._tracker.update_price(symbol, price, timestamp)
# Strategy evaluation on each tick
if self._signal_aggregator and self._tracker:
window = self._tracker.get_window(symbol, "5M")
if window:
asyncio.get_event_loop().call_soon(
lambda s=symbol, p=price, w=window: asyncio.ensure_future(
self._evaluate_strategies(s, p, w)
)
)
# Also evaluate 15M windows
window_15m = self._tracker.get_window(symbol, "15M")
if window_15m:
asyncio.get_event_loop().call_soon(
lambda s=symbol, p=price, w=window_15m: asyncio.ensure_future(
self._evaluate_strategies(s, p, w)
)
)
async def _evaluate_strategies(
self, symbol: str, cex_price: float, window: WindowState
) -> None:
"""Evaluate strategies for a given symbol and window."""
if not self._signal_aggregator:
return
try:
await self._signal_aggregator.on_price_tick(
symbol=symbol,
cex_price=cex_price,
window=window,
orderbooks=self._orderbooks,
)
except Exception:
self._log.exception("strategy_eval_error", symbol=symbol)
def _on_orderbook_update(
self, token_id: str, snapshot: OrderBookSnapshot
) -> None:
"""Callback invoked for each Polymarket orderbook update."""
self._orderbooks[token_id] = snapshot
# Update position mark-to-market
if self._position_tracker:
self._position_tracker.update_from_orderbooks(self._orderbooks)
def _on_window_change(self, window: WindowState) -> None:
"""Callback fired when a price window transitions."""
self._log.info(
"window_changed",
asset=window.asset.value,
timeframe=window.timeframe.value,
start_price=window.start_price,
window_start=window.window_start_time,
window_end=window.window_end_time,
)
# Snapshot the previous window to DB
if self._db is not None and window.start_price is not None:
try:
self._db.log_window(window)
except Exception:
self._log.exception("window_snapshot_failed")
# Resolve positions for this window (window ended = resolution)
if self._position_tracker and window.current_price and window.start_price:
self._resolve_window_positions(window)
def _resolve_window_positions(self, window: WindowState) -> None:
"""Close positions whose window just resolved."""
if not window.market or not self._position_tracker:
return
actual_direction = (
Direction.UP
if window.current_price and window.start_price
and window.current_price > window.start_price
else Direction.DOWN
)
market = window.market
for token_id, is_up in [
(market.up_token_id, True),
(market.down_token_id, False),
]:
pos = self._position_tracker.get_position(token_id)
if pos is None:
continue
won = (is_up and actual_direction == Direction.UP) or (
not is_up and actual_direction == Direction.DOWN
)
resolution_price = 1.0 if won else 0.0
pnl = self._position_tracker.close_position(token_id, resolution_price)
if pnl is not None:
self._balance += pnl
if self._db:
self._db.update_trade(
pos.market_id,
pnl=pnl,
status=TradeStatus.FILLED.value,
)
self._db.log_balance(self._balance, self._position_tracker.total_pnl)
# Notify
if self._telegram:
result = "WIN" if pnl > 0 else "LOSS"
asyncio.ensure_future(
self._telegram.send(
f"{'' if pnl > 0 else ''} <b>Position Resolved</b>\n"
f"Asset: {pos.asset.value} | {pos.direction.value}\n"
f"Result: {result} | PnL: ${pnl:+.2f}\n"
f"Balance: ${self._balance:.2f}"
)
)
self._log.info(
"position_resolved",
asset=pos.asset.value,
direction=pos.direction.value,
won=won,
pnl=round(pnl, 2),
balance=round(self._balance, 2),
)
# Update strategy balance
if self._signal_aggregator:
self._signal_aggregator.update_balance(self._balance)
def _on_new_markets(self, markets: list[ActiveMarket]) -> None:
"""Callback invoked when MarketDiscovery finds new markets."""
self._active_markets.extend(markets)
if self._tracker is not None:
_link_markets_to_tracker(markets, self._tracker, self._log)
if self._poly_feed is not None:
_subscribe_market_tokens(markets, self._poly_feed, self._log)
def _on_signal(self, signal: Signal) -> None:
"""Handle a trading signal from the strategy aggregator."""
if not self._risk_manager or not self._order_manager:
return
# Risk check
additional_usd = signal.price * signal.size
if not self._risk_manager.can_open_position(additional_usd):
self._log.warning(
"signal_rejected_risk",
asset=signal.asset.value,
direction=signal.direction.value,
reason=self._risk_manager.halt_reason or "risk_limit",
)
if self._risk_manager.is_halted and self._telegram:
asyncio.ensure_future(
self._telegram.notify_halt(self._risk_manager.halt_reason)
)
return
# Submit order
asyncio.ensure_future(self._execute_signal(signal))
async def _execute_signal(self, signal: Signal) -> None:
"""Execute a signal through the order manager."""
if not self._order_manager or not self._position_tracker:
return
trade = await self._order_manager.submit_signal(signal)
if trade is None:
return
if trade.status == TradeStatus.FAILED:
self._log.error("trade_failed", asset=signal.asset.value)
return
# Log trade to DB
if self._db:
self._db.log_trade(trade)
# Notify via telegram
if self._telegram:
await self._telegram.notify_trade(
asset=signal.asset.value,
direction=signal.direction.value,
timeframe=signal.timeframe.value,
price=signal.price,
size=signal.size,
edge=signal.edge,
)
self._log.info(
"trade_submitted",
trade_id=trade.id,
asset=signal.asset.value,
direction=signal.direction.value,
price=signal.price,
size=signal.size,
edge=round(signal.edge, 4),
)
def _on_fill(self, trade: Trade) -> None:
"""Handle a filled order."""
if not self._position_tracker:
return
# Open position
self._position_tracker.open_position(trade)
# Update DB
if self._db:
self._db.update_trade(
trade.id,
fill_price=trade.fill_price,
fill_size=trade.fill_size,
fee=trade.fee,
status=TradeStatus.FILLED.value,
)
# Notify
if self._telegram:
asyncio.ensure_future(
self._telegram.notify_fill(
asset=trade.signal.asset.value,
direction=trade.signal.direction.value,
fill_price=trade.fill_price,
fill_size=trade.fill_size,
trade_id=trade.id,
)
)
# ------------------------------------------------------------------
# Background loops
# ------------------------------------------------------------------
async def _status_loop(self) -> None:
"""Log a status summary every ``_STATUS_INTERVAL_S`` seconds."""
while not self._shutdown_event.is_set():
try:
await asyncio.wait_for(
self._shutdown_event.wait(),
timeout=_STATUS_INTERVAL_S,
)
break
except asyncio.TimeoutError:
pass
windows = (
self._tracker.get_all_active_windows()
if self._tracker
else []
)
latest_prices: dict[str, Optional[float]] = {}
if self._binance_feed is not None:
for sym in self._cfg.general.assets:
latest_prices[sym] = self._binance_feed.get_latest_price(sym)
position_summary = (
self._position_tracker.get_summary()
if self._position_tracker
else {}
)
risk_summary = (
self._risk_manager.get_risk_summary()
if self._risk_manager
else {}
)
strategy_stats = (
self._signal_aggregator.get_stats()
if self._signal_aggregator
else {}
)
self._log.info(
"status",
binance_connected=(
self._binance_feed.is_connected
if self._binance_feed
else False
),
polymarket_connected=(
self._poly_feed.is_connected
if self._poly_feed
else False
),
active_windows=len(windows),
active_markets=len(self._active_markets),
latest_prices=latest_prices,
balance=round(self._balance, 2),
positions=position_summary,
risk=risk_summary,
strategies=strategy_stats,
today_trades=(
self._db.get_today_trade_count() if self._db else 0
),
today_pnl=(
self._db.get_today_pnl() if self._db else 0.0
),
)
async def _balance_snapshot_loop(self) -> None:
"""Periodically snapshot balance to DB."""
while not self._shutdown_event.is_set():
try:
await asyncio.wait_for(
self._shutdown_event.wait(),
timeout=_BALANCE_SNAPSHOT_INTERVAL_S,
)
break
except asyncio.TimeoutError:
pass
if self._db and self._position_tracker:
self._db.log_balance(
self._balance,
self._position_tracker.total_pnl,
)
async def _daily_summary_loop(self) -> None:
"""Send daily summary at midnight UTC."""
while not self._shutdown_event.is_set():
try:
await asyncio.wait_for(
self._shutdown_event.wait(),
timeout=60.0,
)
break
except asyncio.TimeoutError:
pass
now_utc = datetime.now(timezone.utc)
today = now_utc.strftime("%Y-%m-%d")
if (
now_utc.hour == _DAILY_SUMMARY_HOUR_UTC
and today != self._last_daily_summary_date
):
self._last_daily_summary_date = today
# Summarize yesterday
yesterday = (
datetime.now(timezone.utc).replace(hour=0, minute=0, second=0)
)
yesterday_str = yesterday.strftime("%Y-%m-%d")
if self._db:
self._db.update_daily_summary(yesterday_str)
summary = self._db.get_daily_summary(yesterday_str)
if self._telegram and summary["total_trades"] > 0:
await self._telegram.notify_daily_summary(
date=yesterday_str,
total_trades=summary["total_trades"],
wins=summary["wins"],
losses=summary["losses"],
pnl=summary["total_pnl"],
fees=summary["total_fees"],
volume=summary["total_volume"],
)
async def _order_expiry_loop(self) -> None:
"""Cancel orders close to resolution."""
while not self._shutdown_event.is_set():
try:
await asyncio.wait_for(
self._shutdown_event.wait(),
timeout=5.0,
)
break
except asyncio.TimeoutError:
pass
if self._order_manager:
try:
cancelled = await self._order_manager.cancel_expiring_orders(
seconds_before_resolution=self._cfg.temporal_arb.exit_before_resolution_sec,
)
if cancelled > 0:
self._log.info("expiring_orders_cancelled", count=cancelled)
except Exception:
self._log.exception("expiry_loop_error")
# ------------------------------------------------------------------
# Lifecycle
# ------------------------------------------------------------------
async def start(self) -> None:
"""Initialise all components, then run feeds concurrently."""
self._log.info("bot_initialising")
# 1. Database
db_path = "trades.db"
if self._cfg.general.mode == "paper":
db_path = "paper_trades.db"
self._db = TradeDB(db_path=db_path)
# Record initial balance
self._db.log_balance(self._balance, 0.0, event="start")
# 2. Market discovery
self._discovery = MarketDiscovery(
on_new_markets=self._on_new_markets,
)
self._log.info("running_initial_discovery")
initial_markets = await self._discovery.discover()
self._active_markets = initial_markets
self._log.info(
"initial_discovery_complete",
count=len(initial_markets),
)
# 3. Window tracker
assets = [Asset(a) for a in self._cfg.general.assets]
timeframes = [Timeframe(t) for t in self._cfg.general.timeframes]
self._tracker = WindowTracker(assets=assets, timeframes=timeframes)
self._tracker.on_window_change(self._on_window_change)
# Link initial markets to windows
_link_markets_to_tracker(initial_markets, self._tracker, self._log)
# 4. CLOB client (for live mode)
self._clob = ClobClientWrapper(self._cfg.polymarket)
if self._cfg.general.mode == "live":
await self._clob.initialize()
self._log.info("clob_client_ready", is_ready=self._clob.is_ready)
# 5. Execution components
self._position_tracker = PositionTracker(self._cfg)
self._order_manager = OrderManager(self._cfg, self._clob)
self._order_manager.on_fill(self._on_fill)
self._risk_manager = RiskManager(
self._cfg.risk, self._position_tracker, self._db
)
# 6. Strategy aggregator
self._signal_aggregator = SignalAggregator(
self._cfg, balance=self._balance
)
self._signal_aggregator.on_signal(self._on_signal)
# 7. Telegram
self._telegram = TelegramNotifier(self._cfg.notifications)
await self._telegram.send(
f"🚀 <b>Polymarket Arb Bot Started</b>\n"
f"Mode: {self._cfg.general.mode}\n"
f"Balance: ${self._balance:.2f}\n"
f"Assets: {', '.join(self._cfg.general.assets)}"
)
# 8. Binance feed
ws_url = self._cfg.binance.ws_url
streams = "/".join(
f"{s}@trade" for s in self._cfg.binance.symbols
)
full_url = f"{ws_url}?streams={streams}"
self._binance_feed = BinanceFeed(url=full_url)
self._binance_feed.subscribe(self._on_binance_price)
# 9. Polymarket feed
self._poly_feed = PolymarketFeed(url=self._cfg.polymarket.ws_url)
self._poly_feed.on_orderbook_update(self._on_orderbook_update)
# Subscribe to discovered market tokens
_subscribe_market_tokens(
initial_markets, self._poly_feed, self._log
)
self._log.info("bot_starting_feeds")
# 10. Run all tasks concurrently
tasks = [
self._binance_feed.start(),
self._poly_feed.start(),
self._discovery.discover_loop(
interval_sec=_DISCOVERY_INTERVAL_S,
),
self._status_loop(),
self._balance_snapshot_loop(),
self._daily_summary_loop(),
self._order_expiry_loop(),
]
# Add order monitor for live mode
if self._cfg.general.mode == "live" and self._clob.is_ready:
tasks.append(
self._order_manager.monitor_loop(
interval=_ORDER_MONITOR_INTERVAL_S
)
)
try:
await asyncio.gather(*tasks)
except asyncio.CancelledError:
self._log.info("gather_cancelled")
async def shutdown(self) -> None:
"""Gracefully stop all components."""
self._log.info("bot_shutting_down")
self._shutdown_event.set()
# Cancel all pending orders
if self._order_manager:
cancelled = await self._order_manager.cancel_all_pending()
self._log.info("shutdown_orders_cancelled", count=cancelled)
self._order_manager.stop()
if self._binance_feed is not None:
await self._binance_feed.stop()
if self._poly_feed is not None:
await self._poly_feed.stop()
# Final balance snapshot
if self._db and self._position_tracker:
self._db.log_balance(
self._balance,
self._position_tracker.total_pnl,
event="shutdown",
)
# Update daily summary
today = datetime.now(timezone.utc).strftime("%Y-%m-%d")
self._db.update_daily_summary(today)
# Notify shutdown
if self._telegram:
position_summary = (
self._position_tracker.get_summary()
if self._position_tracker
else {}
)
await self._telegram.send(
f"🛑 <b>Bot Stopped</b>\n"
f"Balance: ${self._balance:.2f}\n"
f"Total PnL: ${position_summary.get('total_pnl', 0):.2f}\n"
f"Trades: {position_summary.get('trades', 0)}"
)
await self._telegram.close()
self._log.info("bot_stopped")
# ---------------------------------------------------------------------------
# Signal handling and entry point
# ---------------------------------------------------------------------------
def _install_signal_handlers(
loop: asyncio.AbstractEventLoop, bot: ArbBot
) -> None:
"""Register SIGINT / SIGTERM handlers that trigger graceful shutdown."""
for sig in (signal.SIGINT, signal.SIGTERM):
try:
loop.add_signal_handler(
sig,
lambda: asyncio.ensure_future(bot.shutdown()),
)
except NotImplementedError:
pass
async def async_main(config: Optional[Config] = None) -> None:
"""Async entry point — load config, build the bot, and run."""
if config is None:
config = load_config()
setup_logging(config.general.log_level)
_print_startup_banner(config)
bot = ArbBot(config)
loop = asyncio.get_running_loop()
_install_signal_handlers(loop, bot)
try:
await bot.start()
except KeyboardInterrupt:
pass
finally:
await bot.shutdown()
def main() -> None:
"""Synchronous entry point for ``python src/main.py``."""
try:
asyncio.run(async_main())
except KeyboardInterrupt:
print("\nShutdown complete.")
sys.exit(0)
if __name__ == "__main__":
main()

6
src/market/__init__.py Normal file
View File

@@ -0,0 +1,6 @@
"""Market discovery and window tracking."""
from src.market.discovery import MarketDiscovery
from src.market.window_tracker import WindowTracker
__all__ = ["MarketDiscovery", "WindowTracker"]

275
src/market/discovery.py Normal file
View File

@@ -0,0 +1,275 @@
"""Active 5min/15min Up/Down crypto market discovery via Gamma API."""
from __future__ import annotations
import asyncio
import re
from typing import Callable, Optional
import aiohttp
import structlog
from src.data.models import ActiveMarket, Asset, Timeframe
logger = structlog.get_logger(__name__)
GAMMA_API_URL = "https://gamma-api.polymarket.com/events"
GAMMA_QUERY_PARAMS = {
"tag": "crypto",
"active": "true",
"closed": "false",
"limit": "100",
}
SUPPORTED_ASSETS: dict[str, Asset] = {
"BTC": Asset.BTC,
"ETH": Asset.ETH,
"SOL": Asset.SOL,
}
TIMEFRAME_PATTERNS: dict[str, Timeframe] = {
"5 Min": Timeframe.FIVE_MIN,
"5 min": Timeframe.FIVE_MIN,
"5Min": Timeframe.FIVE_MIN,
"15 Min": Timeframe.FIFTEEN_MIN,
"15 min": Timeframe.FIFTEEN_MIN,
"15Min": Timeframe.FIFTEEN_MIN,
}
def _extract_asset(title: str) -> Optional[Asset]:
"""Extract the crypto asset from an event/market title."""
for keyword, asset in SUPPORTED_ASSETS.items():
if keyword in title:
return asset
return None
def _extract_timeframe(title: str) -> Optional[Timeframe]:
"""Extract the timeframe from an event/market title."""
for pattern, timeframe in TIMEFRAME_PATTERNS.items():
if pattern in title:
return timeframe
return None
def _extract_token_ids(market: dict) -> tuple[str, str] | None:
"""Extract (up_token_id, down_token_id) from a market dict.
The Gamma API returns token IDs as a JSON-encoded list or a
``clobTokenIds`` field. The first token is conventionally the
"Up" outcome, the second is "Down". We also inspect ``outcomes``
to verify ordering when available.
"""
tokens: list[str] = []
outcomes: list[str] = []
# Token IDs
raw_tokens = market.get("clobTokenIds")
if isinstance(raw_tokens, list):
tokens = [str(t) for t in raw_tokens]
elif isinstance(raw_tokens, str):
# Sometimes returned as JSON-encoded string "[\"0xabc\",\"0xdef\"]"
try:
import json
tokens = [str(t) for t in json.loads(raw_tokens)]
except (json.JSONDecodeError, TypeError):
return None
# Outcomes
raw_outcomes = market.get("outcomes")
if isinstance(raw_outcomes, list):
outcomes = [str(o).upper().strip() for o in raw_outcomes]
elif isinstance(raw_outcomes, str):
try:
import json
outcomes = [str(o).upper().strip() for o in json.loads(raw_outcomes)]
except (json.JSONDecodeError, TypeError):
pass
if len(tokens) < 2:
return None
# Determine which token is Up and which is Down
up_token: str = tokens[0]
down_token: str = tokens[1]
if len(outcomes) >= 2:
if outcomes[0] == "DOWN" and outcomes[1] == "UP":
up_token, down_token = tokens[1], tokens[0]
return up_token, down_token
class MarketDiscovery:
"""Discovers active 5-min and 15-min Up/Down crypto markets on Polymarket.
Uses the Gamma API to fetch events tagged as crypto, then filters for
short-duration binary Up/Down markets on BTC, ETH, or SOL.
"""
def __init__(
self,
session: Optional[aiohttp.ClientSession] = None,
on_new_markets: Optional[Callable[[list[ActiveMarket]], None]] = None,
) -> None:
self._external_session = session
self._session: Optional[aiohttp.ClientSession] = session
self._on_new_markets = on_new_markets
self._seen_condition_ids: set[str] = set()
self._log = logger.bind(component="MarketDiscovery")
async def _ensure_session(self) -> aiohttp.ClientSession:
if self._session is None or self._session.closed:
self._session = aiohttp.ClientSession()
return self._session
async def _close_session_if_owned(self) -> None:
"""Close the HTTP session only if we created it ourselves."""
if self._external_session is None and self._session is not None and not self._session.closed:
await self._session.close()
# ------------------------------------------------------------------
# Core discovery
# ------------------------------------------------------------------
async def discover(self) -> list[ActiveMarket]:
"""Fetch events from the Gamma API and return matching ActiveMarket instances."""
session = await self._ensure_session()
self._log.info("gamma_api_fetch_start")
try:
async with session.get(GAMMA_API_URL, params=GAMMA_QUERY_PARAMS, timeout=aiohttp.ClientTimeout(total=15)) as resp:
if resp.status == 429:
retry_after = float(resp.headers.get("Retry-After", "5"))
self._log.warning("gamma_api_rate_limited", retry_after=retry_after)
await asyncio.sleep(retry_after)
return []
if resp.status != 200:
self._log.error("gamma_api_http_error", status=resp.status, reason=resp.reason)
return []
try:
data = await resp.json()
except (aiohttp.ContentTypeError, ValueError) as exc:
self._log.error("gamma_api_json_parse_error", error=str(exc))
return []
except asyncio.TimeoutError:
self._log.error("gamma_api_timeout")
return []
except aiohttp.ClientError as exc:
self._log.error("gamma_api_client_error", error=str(exc))
return []
events: list[dict] = data if isinstance(data, list) else []
markets = self._filter_and_parse(events)
self._log.info("gamma_api_fetch_done", total_events=len(events), matched_markets=len(markets))
return markets
def _filter_and_parse(self, events: list[dict]) -> list[ActiveMarket]:
"""Filter events and their sub-markets, returning ActiveMarket instances."""
results: list[ActiveMarket] = []
for event in events:
title: str = event.get("title", "")
# Must be an Up or Down style market
if not re.search(r"[Uu]p\s+or\s+[Dd]own", title):
continue
# Must match a supported timeframe
timeframe = _extract_timeframe(title)
if timeframe is None:
continue
# Must be for a supported asset
asset = _extract_asset(title)
if asset is None:
continue
# Process each sub-market within the event
sub_markets: list[dict] = event.get("markets", [])
if not sub_markets:
# Some API shapes embed market data at the event level
sub_markets = [event]
for mkt in sub_markets:
if not mkt.get("active", False):
continue
if not mkt.get("enableOrderBook", False):
continue
condition_id: str = mkt.get("conditionId", "") or mkt.get("condition_id", "")
if not condition_id:
continue
token_pair = _extract_token_ids(mkt)
if token_pair is None:
self._log.debug("skip_market_missing_tokens", condition_id=condition_id)
continue
up_token, down_token = token_pair
end_date: str = mkt.get("endDate", "") or mkt.get("end_date_iso", "") or ""
results.append(
ActiveMarket(
condition_id=condition_id,
up_token_id=up_token,
down_token_id=down_token,
asset=asset,
timeframe=timeframe,
end_date=end_date,
question=mkt.get("question", title),
description=mkt.get("description", ""),
)
)
return results
# ------------------------------------------------------------------
# Continuous discovery loop
# ------------------------------------------------------------------
async def discover_loop(self, interval_sec: float = 30) -> None:
"""Continuously discover markets and invoke the callback with new ones.
Runs indefinitely. Each iteration sleeps for *interval_sec* seconds
after completing a fetch cycle. Previously seen ``condition_id`` values
are cached so the callback only receives genuinely new markets.
"""
self._log.info("discover_loop_start", interval_sec=interval_sec)
try:
while True:
try:
all_markets = await self.discover()
new_markets = [
m for m in all_markets
if m.condition_id not in self._seen_condition_ids
]
for m in new_markets:
self._seen_condition_ids.add(m.condition_id)
if new_markets:
self._log.info("new_markets_found", count=len(new_markets))
if self._on_new_markets is not None:
self._on_new_markets(new_markets)
except Exception:
self._log.exception("discover_loop_iteration_error")
await asyncio.sleep(interval_sec)
finally:
await self._close_session_if_owned()
# ------------------------------------------------------------------
# Utilities
# ------------------------------------------------------------------
def reset_cache(self) -> None:
"""Clear the set of previously seen condition IDs."""
self._seen_condition_ids.clear()
self._log.info("cache_cleared")

186
src/market/oracle.py Normal file
View File

@@ -0,0 +1,186 @@
"""Chainlink Oracle monitor — tracks oracle update latency via web3.
Monitors the delay between real CEX prices and Chainlink oracle updates
on Polygon, which is the core edge in temporal arbitrage.
"""
from __future__ import annotations
import asyncio
import time
from typing import Optional
import structlog
log = structlog.get_logger()
# Chainlink AggregatorV3Interface ABI (latestRoundData only)
AGGREGATOR_ABI = [
{
"inputs": [],
"name": "latestRoundData",
"outputs": [
{"name": "roundId", "type": "uint80"},
{"name": "answer", "type": "int256"},
{"name": "startedAt", "type": "uint256"},
{"name": "updatedAt", "type": "uint256"},
{"name": "answeredInRound", "type": "uint80"},
],
"stateMutability": "view",
"type": "function",
},
{
"inputs": [],
"name": "decimals",
"outputs": [{"name": "", "type": "uint8"}],
"stateMutability": "view",
"type": "function",
},
]
class OracleMonitor:
"""Monitors Chainlink oracle update frequency and latency.
The oracle typically updates every ~10-30 seconds or on 0.5% deviation.
Tracking this helps calibrate the temporal arbitrage opportunity window.
"""
# Chainlink price feed addresses on Polygon
FEEDS = {
"BTC": "0xc907E116054Ad103354f2D350FD2514433D57F6f",
"ETH": "0xF9680D99D6C9589e2a93a78A04A279e509205945",
"SOL": "0x10C8264C0935b3B9870013e057f330Ff3e9C56dC",
}
def __init__(self, rpc_url: Optional[str] = None) -> None:
self._rpc_url = rpc_url or "https://polygon.drpc.org"
self._w3 = None
self._contracts: dict[str, object] = {}
self._decimals: dict[str, int] = {}
self._last_oracle_prices: dict[str, float] = {}
self._last_oracle_timestamps: dict[str, float] = {}
self._last_oracle_round_ids: dict[str, int] = {}
self._update_intervals: dict[str, list[float]] = {
"BTC": [], "ETH": [], "SOL": []
}
self._initialized = False
async def initialize(self) -> bool:
"""Initialize web3 connection and contract instances."""
try:
from web3 import Web3
self._w3 = Web3(Web3.HTTPProvider(self._rpc_url))
if not self._w3.is_connected():
log.error("oracle_web3_not_connected", rpc_url=self._rpc_url)
return False
for asset, address in self.FEEDS.items():
checksum = self._w3.to_checksum_address(address)
contract = self._w3.eth.contract(
address=checksum, abi=AGGREGATOR_ABI
)
self._contracts[asset] = contract
self._decimals[asset] = contract.functions.decimals().call()
self._initialized = True
log.info(
"oracle_initialized",
rpc_url=self._rpc_url,
assets=list(self._contracts.keys()),
)
return True
except ImportError:
log.warning("oracle_web3_not_installed", msg="pip install web3")
return False
except Exception:
log.exception("oracle_init_failed")
return False
async def get_oracle_price(self, asset: str) -> Optional[float]:
"""Fetch the latest oracle price for an asset from Chainlink."""
if not self._initialized or asset not in self._contracts:
return self._last_oracle_prices.get(asset)
try:
contract = self._contracts[asset]
result = contract.functions.latestRoundData().call()
round_id, answer, started_at, updated_at, answered_in_round = result
decimals = self._decimals.get(asset, 8)
price = answer / (10 ** decimals)
# Record the update
self.record_oracle_update(asset, price, updated_at)
self._last_oracle_round_ids[asset] = round_id
return price
except Exception:
log.exception("oracle_fetch_failed", asset=asset)
return self._last_oracle_prices.get(asset)
def record_oracle_update(self, asset: str, price: float, timestamp: float) -> None:
"""Record an observed oracle price update."""
prev_ts = self._last_oracle_timestamps.get(asset)
if prev_ts is not None and timestamp > prev_ts:
interval = timestamp - prev_ts
intervals = self._update_intervals[asset]
intervals.append(interval)
# Keep last 100 intervals
if len(intervals) > 100:
intervals.pop(0)
self._last_oracle_prices[asset] = price
self._last_oracle_timestamps[asset] = timestamp
def get_avg_update_interval(self, asset: str) -> Optional[float]:
"""Get average oracle update interval in seconds."""
intervals = self._update_intervals.get(asset, [])
if not intervals:
return None
return sum(intervals) / len(intervals)
def get_estimated_lag(self, asset: str) -> Optional[float]:
"""Estimate current oracle lag (time since last known update)."""
last_ts = self._last_oracle_timestamps.get(asset)
if last_ts is None:
return None
return time.time() - last_ts
def get_oracle_vs_cex_deviation(
self, asset: str, cex_price: float
) -> Optional[float]:
"""Calculate percentage deviation between oracle and CEX price."""
oracle_price = self._last_oracle_prices.get(asset)
if oracle_price is None or oracle_price <= 0:
return None
return (cex_price - oracle_price) / oracle_price * 100
async def poll_loop(self, interval: float = 5.0) -> None:
"""Continuously poll oracle prices."""
if not self._initialized:
log.warning("oracle_poll_not_initialized")
return
while True:
for asset in self.FEEDS:
try:
await self.get_oracle_price(asset)
except Exception:
log.exception("oracle_poll_error", asset=asset)
await asyncio.sleep(interval)
def get_stats(self) -> dict:
return {
asset: {
"last_price": self._last_oracle_prices.get(asset),
"last_round_id": self._last_oracle_round_ids.get(asset),
"avg_interval": round(avg, 2) if (avg := self.get_avg_update_interval(asset)) else None,
"estimated_lag": round(lag, 2) if (lag := self.get_estimated_lag(asset)) else None,
"initialized": self._initialized,
}
for asset in self.FEEDS
}

View File

@@ -0,0 +1,193 @@
"""Track 5-minute and 15-minute price windows for BTC, ETH, SOL.
Maintains six simultaneous windows (3 assets x 2 timeframes), capturing the
start price from the first CEX tick after each window opens and tracking the
current price throughout.
"""
from __future__ import annotations
import math
from typing import Callable, Optional
import structlog
from src.data.models import ActiveMarket, Asset, Timeframe, WindowState
log = structlog.get_logger(__name__)
# Timeframe durations in seconds.
_TIMEFRAME_SECONDS: dict[Timeframe, int] = {
Timeframe.FIVE_MIN: 5 * 60,
Timeframe.FIFTEEN_MIN: 15 * 60,
}
def _window_bounds(timestamp: float, timeframe: Timeframe) -> tuple[float, float]:
"""Return (start, end) UTC-aligned window boundaries for *timestamp*."""
interval = _TIMEFRAME_SECONDS[timeframe]
start = math.floor(timestamp / interval) * interval
return float(start), float(start + interval)
class WindowTracker:
"""Track clock-aligned price windows for multiple assets and timeframes.
Parameters
----------
assets:
Assets to track. Defaults to BTC, ETH, SOL.
timeframes:
Timeframes to track. Defaults to 5M and 15M.
"""
def __init__(
self,
assets: Optional[list[Asset]] = None,
timeframes: Optional[list[Timeframe]] = None,
) -> None:
self._assets: list[Asset] = assets or list(Asset)
self._timeframes: list[Timeframe] = timeframes or list(Timeframe)
# Keyed by (asset, timeframe).
self._windows: dict[tuple[Asset, Timeframe], WindowState] = {}
self._callbacks: list[Callable[[WindowState], None]] = []
# ------------------------------------------------------------------
# Public API
# ------------------------------------------------------------------
def update_price(self, asset: str, price: float, timestamp: float) -> None:
"""Process an incoming CEX price tick.
If *timestamp* falls within a new window that we haven't initialised
yet (or a different window from the one we're tracking), a fresh
``WindowState`` is created and any registered callbacks are fired.
The first tick inside a window sets ``start_price``. Every tick
updates ``current_price``.
"""
asset_enum = Asset(asset)
for tf in self._timeframes:
key = (asset_enum, tf)
win_start, win_end = _window_bounds(timestamp, tf)
existing = self._windows.get(key)
# Detect window transition (or first-ever window).
if existing is None or existing.window_start_time != win_start:
# Carry over any linked market from the previous window.
linked_market = existing.market if existing else None
new_window = WindowState(
asset=asset_enum,
timeframe=tf,
window_start_time=win_start,
window_end_time=win_end,
start_price=price,
current_price=price,
market=linked_market,
)
self._windows[key] = new_window
log.info(
"window_transition",
asset=asset_enum.value,
timeframe=tf.value,
window_start=win_start,
window_end=win_end,
start_price=price,
)
# Fire callbacks with the COMPLETED window so that
# subscribers can read its final price_change_pct.
# On the very first window there is no completed window
# to report, so fire with the new one instead.
completed = existing if existing is not None else new_window
self._fire_callbacks(completed)
else:
# Same window — update current price. If start_price was
# somehow not captured (shouldn't happen with this logic,
# but defensive), fill it with the first available price.
if existing.start_price is None:
existing.start_price = price
log.info(
"late_start_price",
asset=asset_enum.value,
timeframe=tf.value,
price=price,
)
existing.current_price = price
def get_window(self, asset: str, timeframe: str) -> Optional[WindowState]:
"""Return the current window state for an asset/timeframe pair."""
key = (Asset(asset), Timeframe(timeframe))
return self._windows.get(key)
def get_all_active_windows(self) -> list[WindowState]:
"""Return every currently tracked window."""
return list(self._windows.values())
def is_window_expired(self, asset: str, timeframe: str) -> bool:
"""Check whether the tracked window has already ended.
Returns ``True`` when the window end time is in the past **or**
when no window has been initialised for the given pair.
"""
import time
key = (Asset(asset), Timeframe(timeframe))
window = self._windows.get(key)
if window is None:
return True
return time.time() >= window.window_end_time
def link_market(
self, asset: str, timeframe: str, market: ActiveMarket
) -> None:
"""Associate a Polymarket market with the current window.
If no window exists yet for the pair, one is **not** created — the
market will be linked once the first price tick arrives and triggers
a window transition.
"""
key = (Asset(asset), Timeframe(timeframe))
window = self._windows.get(key)
if window is not None:
window.market = market
log.info(
"market_linked",
asset=asset,
timeframe=timeframe,
condition_id=market.condition_id,
)
else:
log.warning(
"link_market_no_window",
asset=asset,
timeframe=timeframe,
condition_id=market.condition_id,
)
def on_window_change(self, callback: Callable[[WindowState], None]) -> None:
"""Register a callback invoked whenever a new window starts.
The callback receives the newly-created ``WindowState``.
"""
self._callbacks.append(callback)
# ------------------------------------------------------------------
# Internals
# ------------------------------------------------------------------
def _fire_callbacks(self, window: WindowState) -> None:
for cb in self._callbacks:
try:
cb(window)
except Exception:
log.exception(
"window_change_callback_error",
asset=window.asset.value,
timeframe=window.timeframe.value,
)

7
src/risk/__init__.py Normal file
View File

@@ -0,0 +1,7 @@
"""Risk management modules."""
from src.risk.fee_calculator import FeeCalculator
from src.risk.position_sizer import PositionSizer
from src.risk.risk_manager import RiskManager
__all__ = ["PositionSizer", "RiskManager", "FeeCalculator"]

View File

@@ -0,0 +1,90 @@
"""Fee Calculator — computes trading fees for Polymarket CLOB.
Fee structure:
- 5-minute markets: max 1.56% taker fee
- 15-minute markets: max 3% taker fee
- Fee is applied to potential profit, not to cost basis
- Maker rebate program available for orders resting >= 3.5 seconds
"""
from __future__ import annotations
from src.config import FeesConfig
import structlog
log = structlog.get_logger()
class FeeCalculator:
"""Calculates expected fees for Polymarket trades."""
def __init__(self, fees_config: FeesConfig) -> None:
self.fees = fees_config
def taker_fee(self, timeframe: str, entry_price: float, size: int) -> float:
"""Calculate the taker fee for a trade.
The fee is applied to the potential profit (payout - cost),
not the cost itself.
Args:
timeframe: "5M" or "15M".
entry_price: Entry price per share (e.g., 0.50).
size: Number of shares.
Returns:
Fee in USD.
"""
fee_rate = self.fees.fee_for_timeframe(timeframe)
payout = size * 1.0 # $1 per share if winning
cost = entry_price * size
profit = payout - cost
return max(0, profit * fee_rate)
def breakeven_price(self, timeframe: str, entry_price: float) -> float:
"""Calculate the breakeven probability needed to cover fees.
If you buy at `entry_price`, you need the true probability
to exceed this value to be profitable after fees.
"""
fee_rate = self.fees.fee_for_timeframe(timeframe)
# Solving: prob * (1 - fee_rate) * (1 - entry_price) > entry_price * (1 - prob)
# prob > entry_price / (1 - fee_rate * (1 - entry_price))
denominator = 1.0 - fee_rate * (1.0 - entry_price)
if denominator <= 0:
return 1.0
return entry_price / denominator
def net_payout(self, timeframe: str, entry_price: float, size: int, won: bool) -> float:
"""Calculate net payout after fees.
Args:
timeframe: "5M" or "15M".
entry_price: Entry price per share.
size: Number of shares.
won: Whether the outcome was correct.
Returns:
Net payout in USD (can be negative for losses).
"""
cost = entry_price * size
if won:
gross_payout = size * 1.0
fee = self.taker_fee(timeframe, entry_price, size)
return gross_payout - cost - fee
else:
return -cost # Total loss, no fee on losing side
def expected_value(
self, timeframe: str, entry_price: float, estimated_prob: float, size: int
) -> float:
"""Calculate expected value of a trade.
EV = prob * net_win_payout + (1-prob) * net_loss
"""
win_payout = self.net_payout(timeframe, entry_price, size, won=True)
loss_payout = self.net_payout(timeframe, entry_price, size, won=False)
return estimated_prob * win_payout + (1 - estimated_prob) * loss_payout

View File

@@ -0,0 +1,90 @@
"""Kelly Criterion position sizer with risk adjustments."""
from __future__ import annotations
import math
from typing import Optional
import structlog
from src.config import FeesConfig, RiskConfig
log = structlog.get_logger()
class PositionSizer:
"""Determines optimal position size using Kelly Criterion
with multiple safety caps and adjustments.
"""
def __init__(self, risk_config: RiskConfig, fees_config: FeesConfig) -> None:
self.risk = risk_config
self.fees = fees_config
def calculate(
self,
estimated_prob: float,
poly_price: float,
timeframe: str,
balance: float,
current_exposure: float = 0.0,
) -> int:
"""Calculate position size in shares.
Args:
estimated_prob: Our estimated true probability (0-1).
poly_price: Polymarket entry price (0-1).
timeframe: "5M" or "15M" for fee lookup.
balance: Available balance in USD.
current_exposure: Current total exposure in USD.
Returns:
Number of shares to buy (0 if no trade).
"""
if poly_price <= 0 or poly_price >= 1.0 or estimated_prob <= 0:
return 0
fee_rate = self.fees.fee_for_timeframe(timeframe)
# Kelly Criterion: f* = (b*p - q) / b
b = (1.0 / poly_price) - 1.0 # Payout odds
p = estimated_prob
q = 1.0 - p
if b <= 0:
return 0
kelly_raw = (b * p - q) / b
if kelly_raw <= 0:
return 0
# Cap 1: Kelly fraction cap (default 25%)
kelly_adj = min(kelly_raw, self.risk.kelly_fraction_cap)
# Cap 2: Use half-Kelly for additional safety
kelly_adj *= 0.5
# Cap 3: Dollar size limits
dollar_size = balance * kelly_adj
dollar_size = min(dollar_size, self.risk.max_position_per_market_usd)
# Cap 4: Total exposure limit
remaining_capacity = self.risk.max_total_exposure_usd - current_exposure
if remaining_capacity <= 0:
return 0
dollar_size = min(dollar_size, remaining_capacity)
# Convert to shares
shares = int(dollar_size / poly_price)
log.debug(
"position_sized",
kelly_raw=round(kelly_raw, 4),
kelly_adj=round(kelly_adj, 4),
dollar_size=round(dollar_size, 2),
shares=shares,
prob=round(estimated_prob, 4),
price=poly_price,
)
return max(0, shares)

161
src/risk/risk_manager.py Normal file
View File

@@ -0,0 +1,161 @@
"""Risk Manager — enforces trading limits and circuit breakers.
Monitors exposure, daily PnL, and position counts to prevent
catastrophic losses.
"""
from __future__ import annotations
import time
from datetime import datetime, timezone
import structlog
from src.config import RiskConfig
from src.data.db import TradeDB
from src.execution.position_tracker import PositionTracker
log = structlog.get_logger()
class RiskManager:
"""Central risk management.
Tracks daily PnL, total exposure, and position count.
Can halt trading if limits are breached.
"""
def __init__(
self,
risk_config: RiskConfig,
position_tracker: PositionTracker,
trade_db: TradeDB,
) -> None:
self.risk = risk_config
self.tracker = position_tracker
self.db = trade_db
self._halted = False
self._halt_reason: str = ""
self._daily_pnl_cache: float = 0.0
self._last_pnl_check: float = 0.0
@property
def is_halted(self) -> bool:
return self._halted
@property
def halt_reason(self) -> str:
return self._halt_reason
def check_all(self) -> bool:
"""Run all risk checks. Returns True if trading is allowed."""
if self._halted:
return False
if not self._check_daily_loss():
return False
if not self._check_total_exposure():
return False
if not self._check_position_count():
return False
return True
def can_open_position(self, additional_usd: float) -> bool:
"""Check if a new position of given size can be opened."""
if self._halted:
return False
if not self.tracker.is_exposure_ok(additional_usd):
log.warning(
"risk_exposure_limit",
current=round(self.tracker.total_exposure, 2),
additional=round(additional_usd, 2),
limit=self.risk.max_total_exposure_usd,
)
return False
if self.tracker.position_count >= self.risk.max_concurrent_positions:
log.warning(
"risk_position_limit",
current=self.tracker.position_count,
limit=self.risk.max_concurrent_positions,
)
return False
return self.check_all()
# ------------------------------------------------------------------
# Individual checks
# ------------------------------------------------------------------
def _check_daily_loss(self) -> bool:
"""Check if daily loss limit has been breached."""
now = time.time()
# Cache daily PnL check (recompute every 10 seconds)
if now - self._last_pnl_check > 10:
self._daily_pnl_cache = self.db.get_today_pnl()
self._last_pnl_check = now
daily_pnl = self._daily_pnl_cache + self.tracker.total_unrealized_pnl
if daily_pnl < -self.risk.max_daily_loss_usd:
self._halt("daily_loss_limit_breached",
f"Daily PnL ${daily_pnl:.2f} < -${self.risk.max_daily_loss_usd}")
return False
return True
def _check_total_exposure(self) -> bool:
"""Check if total exposure is within limits."""
exposure = self.tracker.total_exposure
if exposure > self.risk.max_total_exposure_usd:
self._halt("exposure_limit_breached",
f"Exposure ${exposure:.2f} > ${self.risk.max_total_exposure_usd}")
return False
return True
def _check_position_count(self) -> bool:
"""Check if position count is within limits."""
count = self.tracker.position_count
if count > self.risk.max_concurrent_positions:
log.warning("position_count_exceeded", count=count, limit=self.risk.max_concurrent_positions)
return False
return True
# ------------------------------------------------------------------
# Halt/resume
# ------------------------------------------------------------------
def _halt(self, halt_event: str, reason: str) -> None:
"""Halt all trading."""
self._halted = True
self._halt_reason = reason
log.critical("trading_halted", halt_event=halt_event, reason=reason)
def resume(self) -> None:
"""Manually resume trading (use with caution)."""
log.warning("trading_resumed", previous_halt_reason=self._halt_reason)
self._halted = False
self._halt_reason = ""
# ------------------------------------------------------------------
# Reporting
# ------------------------------------------------------------------
def get_risk_summary(self) -> dict:
"""Return current risk state for dashboard/logging."""
return {
"halted": self._halted,
"halt_reason": self._halt_reason,
"daily_pnl": round(self._daily_pnl_cache, 2),
"total_exposure": round(self.tracker.total_exposure, 2),
"exposure_pct": round(
self.tracker.total_exposure / self.risk.max_total_exposure_usd * 100, 1
),
"positions": self.tracker.position_count,
"max_positions": self.risk.max_concurrent_positions,
"daily_loss_limit": self.risk.max_daily_loss_usd,
"exposure_limit": self.risk.max_total_exposure_usd,
}

13
src/strategy/__init__.py Normal file
View File

@@ -0,0 +1,13 @@
"""Strategy modules for the Polymarket Arbitrage Bot."""
from src.strategy.signal import SignalAggregator
from src.strategy.spread_capture import SpreadCaptureStrategy
from src.strategy.sum_to_one import SumToOneStrategy
from src.strategy.temporal_arb import TemporalArbStrategy
__all__ = [
"TemporalArbStrategy",
"SumToOneStrategy",
"SpreadCaptureStrategy",
"SignalAggregator",
]

196
src/strategy/signal.py Normal file
View File

@@ -0,0 +1,196 @@
"""Signal Aggregator — coordinates all strategies and manages signal flow.
Receives market data from feeds and window tracker, evaluates all active
strategies, and emits consolidated signals for the execution engine.
"""
from __future__ import annotations
import asyncio
import time
from typing import Callable, Optional
import structlog
from src.config import Config
from src.data.models import (
ActiveMarket,
Direction,
OrderBookSnapshot,
Signal,
WindowState,
)
from src.strategy.spread_capture import SpreadCaptureStrategy
from src.strategy.sum_to_one import SumToOneStrategy
from src.strategy.temporal_arb import TemporalArbStrategy
log = structlog.get_logger()
SignalCallback = Callable[[Signal], None]
class SignalAggregator:
"""Coordinates all strategies and emits trading signals.
Connects to the WindowTracker and PolymarketFeed to receive
real-time data, runs strategy evaluations, and dispatches
signals to registered handlers.
"""
def __init__(self, config: Config, balance: float = 10000.0) -> None:
self.config = config
self._callbacks: list[SignalCallback] = []
# Initialize strategies
self.temporal_arb = TemporalArbStrategy(
arb_config=config.temporal_arb,
risk_config=config.risk,
fees_config=config.fees,
balance=balance,
)
self.sum_to_one = SumToOneStrategy(
sto_config=config.sum_to_one,
risk_config=config.risk,
fees_config=config.fees,
balance=balance,
)
self.spread_capture = SpreadCaptureStrategy(
spread_config=config.spread_capture,
risk_config=config.risk,
fees_config=config.fees,
balance=balance,
)
# Rate limiting: don't evaluate the same market more than once per second
self._last_eval_time: dict[str, float] = {}
self._min_eval_interval = 0.5 # seconds
# Track recent signals to avoid duplicates
self._recent_signals: dict[str, float] = {}
self._signal_cooldown = 30.0 # seconds between signals for same market
def on_signal(self, callback: SignalCallback) -> None:
"""Register a callback for emitted signals."""
self._callbacks.append(callback)
def _emit(self, signal: Signal) -> None:
"""Dispatch a signal to all registered callbacks."""
# Dedup: don't emit same direction/asset/timeframe within cooldown
key = f"{signal.asset.value}:{signal.timeframe.value}:{signal.direction.value}"
now = time.time()
if key in self._recent_signals:
if now - self._recent_signals[key] < self._signal_cooldown:
log.debug("signal_deduplicated", key=key)
return
self._recent_signals[key] = now
for cb in self._callbacks:
try:
cb(signal)
except Exception:
log.exception("signal_callback_error")
async def on_price_tick(
self,
symbol: str,
cex_price: float,
window: Optional[WindowState],
orderbooks: dict[str, OrderBookSnapshot],
) -> None:
"""Called on each CEX price tick with current window and orderbook state.
Evaluates temporal arb and sum-to-one strategies.
"""
if window is None or window.market is None:
return
if window.start_price is None:
return
market = window.market
timeframe = window.timeframe.value
# Rate limiting
eval_key = f"{symbol}:{timeframe}"
now = time.time()
if eval_key in self._last_eval_time:
if now - self._last_eval_time[eval_key] < self._min_eval_interval:
return
self._last_eval_time[eval_key] = now
# Get Polymarket prices from orderbooks
up_book = orderbooks.get(market.up_token_id)
down_book = orderbooks.get(market.down_token_id)
poly_up_ask = up_book.best_ask if up_book else None
poly_down_ask = down_book.best_ask if down_book else None
# --- Temporal Arbitrage ---
if self.config.temporal_arb.enabled:
signal = await self.temporal_arb.evaluate(
symbol=symbol,
cex_price=cex_price,
window_start_price=window.start_price,
window_end_time=window.window_end_time,
poly_up_ask=poly_up_ask,
poly_down_ask=poly_down_ask,
up_token_id=market.up_token_id,
down_token_id=market.down_token_id,
timeframe=timeframe,
)
if signal:
self._emit(signal)
# --- Sum-to-One Arbitrage ---
if self.config.sum_to_one.enabled:
opp = await self.sum_to_one.evaluate(
asset=symbol,
timeframe=timeframe,
poly_up_ask=poly_up_ask,
poly_down_ask=poly_down_ask,
up_token_id=market.up_token_id,
down_token_id=market.down_token_id,
)
if opp:
up_sig, down_sig = self.sum_to_one.generate_signals(opp)
self._emit(up_sig)
self._emit(down_sig)
# --- Spread Capture ---
if self.config.spread_capture.enabled:
for token_id, book in [(market.up_token_id, up_book), (market.down_token_id, down_book)]:
if book:
quote = await self.spread_capture.evaluate(
asset=symbol,
timeframe=timeframe,
token_id=token_id,
orderbook=book,
)
# Spread quotes handled differently — not emitted as signals
def update_balance(self, balance: float) -> None:
"""Update balance across all strategies."""
self.temporal_arb.update_balance(balance)
self.sum_to_one.update_balance(balance)
self.spread_capture.update_balance(balance)
def get_stats(self) -> dict:
"""Return combined statistics from all strategies."""
return {
"temporal_arb": {
"evaluations": self.temporal_arb.stats.total_evaluations,
"signals": self.temporal_arb.stats.signals_generated,
"avg_edge": round(self.temporal_arb.stats.avg_edge, 4),
"by_asset": dict(self.temporal_arb.stats.signals_by_asset),
},
"sum_to_one": {
"evaluations": self.sum_to_one.stats.total_evaluations,
"opportunities": self.sum_to_one.stats.opportunities_found,
"total_net_profit": round(self.sum_to_one.stats.total_net_profit, 4),
},
"spread_capture": {
"evaluations": self.spread_capture.stats.total_evaluations,
"quotes": self.spread_capture.stats.quotes_generated,
"avg_spread": round(self.spread_capture.stats.avg_spread, 4),
},
}

View File

@@ -0,0 +1,164 @@
"""Spread Capture / Market Making Strategy.
Places limit orders on both sides of the bid-ask spread to capture
the spread as profit. Requires careful inventory management.
"""
from __future__ import annotations
import time
from dataclasses import dataclass, field
from typing import Optional
import structlog
from src.config import FeesConfig, RiskConfig, SpreadCaptureConfig
from src.data.models import (
Asset,
Direction,
OrderBookSnapshot,
Signal,
Timeframe,
)
log = structlog.get_logger()
@dataclass
class SpreadQuote:
"""A pair of limit orders to capture the spread."""
asset: Asset
timeframe: Timeframe
token_id: str
bid_price: float
ask_price: float
size: int
spread: float
timestamp: float = field(default_factory=time.time)
@dataclass
class SpreadStats:
total_evaluations: int = 0
quotes_generated: int = 0
avg_spread: float = 0.0
_total_spread: float = 0.0
class SpreadCaptureStrategy:
"""Market making strategy that places limit orders on both sides
of the spread. Profits from the bid-ask spread minus fees.
Note: Requires orders to rest for >=3.5 seconds for maker rebate eligibility.
"""
def __init__(
self,
spread_config: SpreadCaptureConfig,
risk_config: RiskConfig,
fees_config: FeesConfig,
balance: float = 10000.0,
) -> None:
self.spread = spread_config
self.risk = risk_config
self.fees = fees_config
self.balance = balance
self.stats = SpreadStats()
# Active quotes per token to avoid over-quoting
self._active_quotes: dict[str, SpreadQuote] = {}
async def evaluate(
self,
asset: str,
timeframe: str,
token_id: str,
orderbook: OrderBookSnapshot,
) -> Optional[SpreadQuote]:
"""Evaluate whether to place spread-capture quotes on this token.
Returns a SpreadQuote if the spread is wide enough to be profitable.
"""
self.stats.total_evaluations += 1
if not self.spread.enabled:
return None
if orderbook.best_bid is None or orderbook.best_ask is None:
return None
current_spread = orderbook.best_ask - orderbook.best_bid
if current_spread < self.spread.spread_target:
return None # Spread too tight
# Place our orders inside the current spread
# Improve by 1 cent on each side
our_bid = orderbook.best_bid + 0.01
our_ask = orderbook.best_ask - 0.01
our_spread = our_ask - our_bid
if our_spread <= 0:
return None
# Check profitability after fees (maker gets rebate, but conservative estimate)
taker_fee = self.fees.fee_for_timeframe(timeframe)
# As maker, fee is lower or zero (rebate), but we budget for taker as worst case
estimated_profit_per_share = our_spread - (taker_fee * 2)
if estimated_profit_per_share <= 0:
return None
# Size: conservative for market making
max_dollar = min(
self.balance * 0.10, # 10% of balance per quote
self.risk.max_position_per_market_usd * 0.5,
)
size = int(max_dollar / our_ask)
if size <= 0:
return None
# Check if we already have an active quote
if token_id in self._active_quotes:
old = self._active_quotes[token_id]
# Only update if spread changed significantly
if abs(old.bid_price - our_bid) < 0.005 and abs(old.ask_price - our_ask) < 0.005:
return None
quote = SpreadQuote(
asset=Asset(asset),
timeframe=Timeframe(timeframe),
token_id=token_id,
bid_price=our_bid,
ask_price=our_ask,
size=size,
spread=our_spread,
)
self._active_quotes[token_id] = quote
self.stats.quotes_generated += 1
self.stats._total_spread += our_spread
self.stats.avg_spread = self.stats._total_spread / self.stats.quotes_generated
log.info(
"spread_quote",
asset=asset,
timeframe=timeframe,
bid=our_bid,
ask=our_ask,
spread=round(our_spread, 4),
size=size,
)
return quote
def remove_quote(self, token_id: str) -> None:
"""Remove an active quote (e.g., after fill or cancel)."""
self._active_quotes.pop(token_id, None)
def get_active_quotes(self) -> dict[str, SpreadQuote]:
return dict(self._active_quotes)
def update_balance(self, new_balance: float) -> None:
self.balance = new_balance

178
src/strategy/sum_to_one.py Normal file
View File

@@ -0,0 +1,178 @@
"""Sum-to-One Arbitrage Strategy.
When YES + NO best ask prices sum to less than $1.00 (minus fees),
buy both sides for a guaranteed risk-free profit regardless of outcome.
"""
from __future__ import annotations
import time
from dataclasses import dataclass, field
from typing import Optional
import structlog
from src.config import FeesConfig, RiskConfig, SumToOneConfig
from src.data.models import Asset, Direction, Signal, Timeframe
log = structlog.get_logger()
@dataclass
class SumToOneOpportunity:
"""Represents a sum-to-one arbitrage opportunity."""
asset: Asset
timeframe: Timeframe
up_ask: float
down_ask: float
total_cost: float # up_ask + down_ask
gross_profit: float # 1.0 - total_cost
fee: float # Total fees for both sides
net_profit: float # gross_profit - fee
up_token_id: str
down_token_id: str
timestamp: float = field(default_factory=time.time)
@dataclass
class SumToOneStats:
total_evaluations: int = 0
opportunities_found: int = 0
total_net_profit: float = 0.0
class SumToOneStrategy:
"""Sum-to-One arbitrage: buy both YES and NO when their combined
ask price is less than $1.00 after fees."""
def __init__(
self,
sto_config: SumToOneConfig,
risk_config: RiskConfig,
fees_config: FeesConfig,
balance: float = 10000.0,
) -> None:
self.sto = sto_config
self.risk = risk_config
self.fees = fees_config
self.balance = balance
self.stats = SumToOneStats()
async def evaluate(
self,
asset: str,
timeframe: str,
poly_up_ask: Optional[float],
poly_down_ask: Optional[float],
up_token_id: str,
down_token_id: str,
) -> Optional[SumToOneOpportunity]:
"""Check if a sum-to-one arb opportunity exists.
Returns an opportunity if YES ask + NO ask + fees < $1.00.
"""
self.stats.total_evaluations += 1
if poly_up_ask is None or poly_down_ask is None:
return None
if poly_up_ask <= 0 or poly_down_ask <= 0:
return None
if poly_up_ask >= 1.0 or poly_down_ask >= 1.0:
return None
total_cost = poly_up_ask + poly_down_ask
# Already sums to >= $1.00 — no arb
if total_cost >= 1.0:
return None
gross_profit = 1.0 - total_cost
# Fee: taker fee applies to BOTH purchases
taker_fee = self.fees.fee_for_timeframe(timeframe)
# Fee is calculated on the payout, not the cost
# For each side, fee = taker_fee * (payout - cost) when it wins
# Simplified: total fee ≈ taker_fee * 1.0 (since one side pays out $1)
total_fee = taker_fee * 1.0
net_profit = gross_profit - total_fee
if net_profit < self.sto.min_spread_after_fee:
return None
opp = SumToOneOpportunity(
asset=Asset(asset),
timeframe=Timeframe(timeframe),
up_ask=poly_up_ask,
down_ask=poly_down_ask,
total_cost=total_cost,
gross_profit=gross_profit,
fee=total_fee,
net_profit=net_profit,
up_token_id=up_token_id,
down_token_id=down_token_id,
)
self.stats.opportunities_found += 1
self.stats.total_net_profit += net_profit
log.info(
"sum_to_one_opportunity",
asset=asset,
timeframe=timeframe,
up_ask=poly_up_ask,
down_ask=poly_down_ask,
total_cost=round(total_cost, 4),
net_profit=round(net_profit, 4),
)
return opp
def calculate_size(self, opportunity: SumToOneOpportunity) -> int:
"""Calculate position size for a sum-to-one trade.
Since this is risk-free, we can size more aggressively,
limited only by available liquidity and max position.
"""
# Max dollar exposure per side
max_per_side = min(
self.balance * 0.5, # Don't commit more than 50% of balance
self.risk.max_position_per_market_usd,
)
# Shares are constrained by the more expensive side
max_price = max(opportunity.up_ask, opportunity.down_ask)
shares = int(max_per_side / max_price)
return max(0, shares)
def generate_signals(self, opportunity: SumToOneOpportunity) -> tuple[Signal, Signal]:
"""Generate a pair of BUY signals (one for UP, one for DOWN)."""
size = self.calculate_size(opportunity)
up_signal = Signal(
direction=Direction.UP,
asset=opportunity.asset,
timeframe=opportunity.timeframe,
token_id=opportunity.up_token_id,
price=opportunity.up_ask,
size=size,
edge=opportunity.net_profit,
estimated_prob=0.5, # Direction-agnostic
)
down_signal = Signal(
direction=Direction.DOWN,
asset=opportunity.asset,
timeframe=opportunity.timeframe,
token_id=opportunity.down_token_id,
price=opportunity.down_ask,
size=size,
edge=opportunity.net_profit,
estimated_prob=0.5,
)
return up_signal, down_signal
def update_balance(self, new_balance: float) -> None:
self.balance = new_balance

View File

@@ -0,0 +1,297 @@
"""Temporal Arbitrage Strategy — core strategy module.
Exploits the delay between CEX price movements and Polymarket oracle updates.
When Binance price confirms a direction but Polymarket odds haven't caught up,
buy the confirmed direction at a discount.
"""
from __future__ import annotations
import math
import time
from dataclasses import dataclass, field
from typing import Optional
import structlog
from src.config import FeesConfig, RiskConfig, TemporalArbConfig
from src.data.models import Asset, Direction, Signal, Timeframe
log = structlog.get_logger()
@dataclass
class StrategyStats:
"""Running statistics for the strategy."""
total_evaluations: int = 0
signals_generated: int = 0
signals_by_asset: dict[str, int] = field(default_factory=lambda: {"BTC": 0, "ETH": 0, "SOL": 0})
total_edge: float = 0.0
@property
def avg_edge(self) -> float:
return self.total_edge / self.signals_generated if self.signals_generated > 0 else 0.0
class TemporalArbStrategy:
"""Core temporal arbitrage strategy.
Monitors CEX price vs Polymarket odds and generates buy signals when
the CEX price has confirmed a direction but Polymarket hasn't adjusted.
"""
def __init__(
self,
arb_config: TemporalArbConfig,
risk_config: RiskConfig,
fees_config: FeesConfig,
balance: float = 10000.0,
) -> None:
self.arb = arb_config
self.risk = risk_config
self.fees = fees_config
self.balance = balance
self.stats = StrategyStats()
# ------------------------------------------------------------------
# Core evaluation
# ------------------------------------------------------------------
async def evaluate(
self,
symbol: str,
cex_price: float,
window_start_price: float,
window_end_time: float,
poly_up_ask: Optional[float],
poly_down_ask: Optional[float],
up_token_id: str,
down_token_id: str,
timeframe: str,
) -> Optional[Signal]:
"""Evaluate whether to enter a temporal arb position.
Called on every CEX price tick. Returns a Signal if conditions are met,
otherwise None.
"""
self.stats.total_evaluations += 1
if window_start_price <= 0:
return None
# 1. Price direction & magnitude
price_change_pct = (cex_price - window_start_price) / window_start_price * 100
if abs(price_change_pct) < self.arb.min_price_move_pct:
return None # Not enough movement to confirm direction
# 2. Direction
direction = Direction.UP if price_change_pct > 0 else Direction.DOWN
# 3. Time remaining
now = time.time()
time_remaining = window_end_time - now
if time_remaining < self.arb.exit_before_resolution_sec:
return None # Too close to resolution
# 4. Polymarket price
poly_price = poly_up_ask if direction == Direction.UP else poly_down_ask
if poly_price is None or poly_price <= 0 or poly_price >= 1.0:
return None
if poly_price > self.arb.max_poly_entry_price:
return None # Risk/reward insufficient
# 5. Probability estimation
total_window = self._total_window_seconds(timeframe)
estimated_prob = self.estimate_probability(
price_change_pct, time_remaining, total_window
)
# 6. Edge calculation (after fees)
taker_fee = self.fees.fee_for_timeframe(timeframe)
edge = estimated_prob - poly_price - taker_fee
if edge < self.arb.min_edge:
return None # Insufficient edge
# 7. Position sizing
asset = Asset(symbol)
size = self.calculate_kelly_size(
edge=edge,
price=poly_price,
balance=self.balance,
max_size=self.risk.max_position_per_market_usd,
)
if size <= 0:
return None
# 8. Build signal
token_id = up_token_id if direction == Direction.UP else down_token_id
tf = Timeframe(timeframe)
signal = Signal(
direction=direction,
asset=asset,
timeframe=tf,
token_id=token_id,
price=poly_price,
size=size,
edge=edge,
estimated_prob=estimated_prob,
)
# Update stats
self.stats.signals_generated += 1
self.stats.signals_by_asset[symbol] = self.stats.signals_by_asset.get(symbol, 0) + 1
self.stats.total_edge += edge
log.info(
"signal_generated",
asset=symbol,
direction=direction.value,
timeframe=timeframe,
price_change_pct=round(price_change_pct, 4),
poly_price=poly_price,
estimated_prob=round(estimated_prob, 4),
edge=round(edge, 4),
size=size,
time_remaining=round(time_remaining, 1),
)
return signal
# ------------------------------------------------------------------
# Probability model
# ------------------------------------------------------------------
def estimate_probability(
self,
price_change_pct: float,
time_remaining: float,
total_window_sec: float,
) -> float:
"""Estimate the true probability that the current direction holds at resolution.
Multi-factor model:
1. Base probability from price magnitude
2. Time decay: as window nears end, momentum is more confirmed
3. Volatility: larger moves are harder to reverse
"""
abs_change = abs(price_change_pct)
# Factor 1: Base probability from price magnitude
# 0.15% → ~70%, 0.3% → ~82%, 0.5% → ~90%, 1.0%+ → ~95%
base_prob = 0.55 + abs_change * 1.0 # 1.0 scaling factor
base_prob = min(base_prob, 0.95)
# Factor 2: Time decay — more time elapsed = more confirmation
# If 80% of window has passed and price is still in this direction,
# the probability of reversal is lower
elapsed_fraction = max(0, 1.0 - (time_remaining / total_window_sec))
# Sigmoid-like boost: ramps up in the last 40% of the window
time_factor = 1.0 + 0.08 * max(0, elapsed_fraction - 0.6) / 0.4
time_factor = min(time_factor, 1.08)
# Factor 3: Volatility / momentum strength
# Very large moves (>0.5%) get extra confidence
if abs_change > 0.5:
vol_boost = min(0.05, (abs_change - 0.5) * 0.1)
else:
vol_boost = 0.0
final_prob = base_prob * time_factor + vol_boost
return min(0.95, max(0.50, final_prob))
# ------------------------------------------------------------------
# Position sizing
# ------------------------------------------------------------------
def calculate_kelly_size(
self,
edge: float,
price: float,
balance: float,
max_size: float,
) -> int:
"""Kelly Criterion position sizing.
f* = (b*p - q) / b
where b = (1/price - 1), p = estimated_prob, q = 1-p
"""
if price <= 0 or price >= 1.0:
return 0
b = (1.0 / price) - 1.0 # Payout odds
p = edge + price # estimated_prob (edge = prob - price - fee)
q = 1.0 - p
if b <= 0:
return 0
kelly_fraction = (b * p - q) / b
kelly_fraction = max(0.0, min(kelly_fraction, self.risk.kelly_fraction_cap))
dollar_size = min(balance * kelly_fraction, max_size)
shares = int(dollar_size / price)
return max(0, shares)
# ------------------------------------------------------------------
# Exit logic
# ------------------------------------------------------------------
def should_exit_early(
self,
entry_direction: Direction,
entry_price: float,
current_poly_price: float,
cex_price: float,
window_start_price: float,
time_remaining: float,
) -> bool:
"""Determine if we should exit a position early.
Exit if:
1. Price has reversed and edge has disappeared
2. Polymarket price has risen enough to take profit
3. Very close to resolution with negative PnL trajectory
"""
if window_start_price <= 0:
return False
current_change_pct = (cex_price - window_start_price) / window_start_price * 100
# Direction reversed significantly
if entry_direction == Direction.UP and current_change_pct < -0.05:
log.warning("exit_signal_reversal", direction="UP", change_pct=round(current_change_pct, 4))
return True
if entry_direction == Direction.DOWN and current_change_pct > 0.05:
log.warning("exit_signal_reversal", direction="DOWN", change_pct=round(current_change_pct, 4))
return True
# Take profit: price has moved significantly in our favor
if current_poly_price > 0.90 and current_poly_price > entry_price * 1.3:
log.info("exit_signal_take_profit", current_price=current_poly_price, entry_price=entry_price)
return True
# Close to resolution with thin margin
if time_remaining < 10 and abs(current_change_pct) < 0.05:
log.warning("exit_signal_thin_margin", time_remaining=round(time_remaining, 1))
return True
return False
# ------------------------------------------------------------------
# Helpers
# ------------------------------------------------------------------
@staticmethod
def _total_window_seconds(timeframe: str) -> float:
"""Return total seconds for a given timeframe."""
return {"5M": 300.0, "15M": 900.0}.get(timeframe, 300.0)
def update_balance(self, new_balance: float) -> None:
"""Update available balance for position sizing."""
self.balance = new_balance

14
src/utils/__init__.py Normal file
View File

@@ -0,0 +1,14 @@
"""Utility modules for the Polymarket Arbitrage Bot."""
from .logger import get_logger, log_timing, log_trade, setup_logging
from .metrics import MetricsCollector
from .telegram import TelegramNotifier
__all__ = [
"get_logger",
"log_timing",
"log_trade",
"setup_logging",
"MetricsCollector",
"TelegramNotifier",
]

153
src/utils/logger.py Normal file
View File

@@ -0,0 +1,153 @@
"""Structured logging setup using structlog for the Polymarket Arbitrage Bot.
Provides consistent, structured logging across all modules with support
for dev-friendly console output and production JSON rendering.
"""
from __future__ import annotations
import logging
import sys
import time
from contextlib import contextmanager
from typing import Any, Generator
import structlog
def setup_logging(log_level: str = "INFO") -> None:
"""Configure structlog with processors, renderers, and stdlib integration.
Parameters
----------
log_level:
Root log level (e.g. ``"DEBUG"``, ``"INFO"``, ``"WARNING"``).
Also accepts lowercase variants.
"""
level = getattr(logging, log_level.upper(), logging.INFO)
# Shared processors applied to every log entry
shared_processors: list[structlog.types.Processor] = [
structlog.contextvars.merge_contextvars,
structlog.stdlib.add_log_level,
structlog.stdlib.add_logger_name,
structlog.processors.CallsiteParameterAdder(
[
structlog.processors.CallsiteParameter.FILENAME,
structlog.processors.CallsiteParameter.FUNC_NAME,
structlog.processors.CallsiteParameter.LINENO,
]
),
structlog.processors.TimeStamper(fmt="iso"),
structlog.processors.StackInfoRenderer(),
structlog.processors.UnicodeDecoder(),
]
# Choose renderer based on whether we're attached to a terminal (dev) or
# running headless / in production.
if sys.stderr.isatty():
renderer: structlog.types.Processor = structlog.dev.ConsoleRenderer(
colors=True,
)
else:
renderer = structlog.processors.JSONRenderer()
structlog.configure(
processors=[
*shared_processors,
structlog.stdlib.ProcessorFormatter.wrap_for_formatter,
],
logger_factory=structlog.stdlib.LoggerFactory(),
wrapper_class=structlog.stdlib.BoundLogger,
cache_logger_on_first_use=True,
)
# Configure stdlib root logger so third-party libraries also go through
# structlog's formatting pipeline.
formatter = structlog.stdlib.ProcessorFormatter(
processors=[
structlog.stdlib.ProcessorFormatter.remove_processors_meta,
renderer,
],
)
handler = logging.StreamHandler(sys.stderr)
handler.setFormatter(formatter)
root_logger = logging.getLogger()
root_logger.handlers.clear()
root_logger.addHandler(handler)
root_logger.setLevel(level)
def get_logger(name: str) -> structlog.stdlib.BoundLogger:
"""Return a bound structlog logger for *name*.
Parameters
----------
name:
Typically ``__name__`` of the calling module.
"""
return structlog.get_logger(name)
def log_trade(logger: structlog.stdlib.BoundLogger, trade_data: dict[str, Any]) -> None:
"""Emit a structured trade event.
Extracts key fields from *trade_data* and logs them as a single
structured event for easy querying and dashboarding.
Parameters
----------
logger:
A bound structlog logger.
trade_data:
Dictionary with trade details (e.g. ``id``, ``asset``, ``direction``,
``entry_price``, ``fill_price``, ``size``, ``pnl``, ``status``).
"""
logger.info(
"trade_event",
trade_id=trade_data.get("id"),
asset=trade_data.get("asset"),
direction=trade_data.get("direction"),
token_id=trade_data.get("token_id"),
entry_price=trade_data.get("entry_price"),
fill_price=trade_data.get("fill_price"),
size=trade_data.get("size"),
fee=trade_data.get("fee"),
pnl=trade_data.get("pnl"),
status=trade_data.get("status"),
edge=trade_data.get("edge"),
)
@contextmanager
def log_timing(
logger: structlog.stdlib.BoundLogger,
operation: str,
) -> Generator[None, None, None]:
"""Context manager that logs the elapsed wall-clock time of *operation*.
Usage::
with log_timing(logger, "fetch_orderbook"):
await fetch_orderbook()
Parameters
----------
logger:
A bound structlog logger.
operation:
Human-readable label for the timed block.
"""
start = time.perf_counter()
logger.debug("timing_start", operation=operation)
try:
yield
finally:
elapsed_ms = (time.perf_counter() - start) * 1000.0
logger.info(
"timing_end",
operation=operation,
elapsed_ms=round(elapsed_ms, 2),
)

116
src/utils/metrics.py Normal file
View File

@@ -0,0 +1,116 @@
"""Metrics collector for the performance dashboard."""
from __future__ import annotations
import time
from collections import deque
from dataclasses import dataclass, field
from typing import Optional
@dataclass
class TradeMetric:
"""Single trade metric for time-series tracking."""
timestamp: float
asset: str
direction: str
timeframe: str
entry_price: float
size: int
edge: float
pnl: Optional[float] = None
won: Optional[bool] = None
class MetricsCollector:
"""Collects and aggregates trading metrics for the dashboard.
Maintains rolling windows and aggregated stats for real-time display.
"""
def __init__(self, max_history: int = 10000) -> None:
self._trades: deque[TradeMetric] = deque(maxlen=max_history)
self._pnl_series: deque[tuple[float, float]] = deque(maxlen=max_history)
self._start_time = time.time()
# Running aggregates
self.total_volume: float = 0.0
self.total_fees: float = 0.0
def record_trade(self, metric: TradeMetric) -> None:
self._trades.append(metric)
self.total_volume += metric.entry_price * metric.size
def record_pnl(self, pnl: float) -> None:
self._pnl_series.append((time.time(), pnl))
def record_fee(self, fee: float) -> None:
self.total_fees += fee
# ------------------------------------------------------------------
# Aggregations
# ------------------------------------------------------------------
def get_recent_trades(self, n: int = 50) -> list[TradeMetric]:
return list(self._trades)[-n:]
def get_pnl_series(self) -> list[tuple[float, float]]:
return list(self._pnl_series)
def get_hourly_stats(self) -> dict:
"""Aggregate stats for the last hour."""
cutoff = time.time() - 3600
recent = [t for t in self._trades if t.timestamp > cutoff]
wins = sum(1 for t in recent if t.won is True)
losses = sum(1 for t in recent if t.won is False)
total_pnl = sum(t.pnl for t in recent if t.pnl is not None)
return {
"trades": len(recent),
"wins": wins,
"losses": losses,
"win_rate": wins / (wins + losses) * 100 if (wins + losses) > 0 else 0,
"pnl": round(total_pnl, 2),
}
def get_asset_breakdown(self) -> dict[str, dict]:
"""PnL and trade count per asset."""
breakdown: dict[str, dict] = {}
for t in self._trades:
if t.asset not in breakdown:
breakdown[t.asset] = {"trades": 0, "wins": 0, "pnl": 0.0}
breakdown[t.asset]["trades"] += 1
if t.won is True:
breakdown[t.asset]["wins"] += 1
if t.pnl is not None:
breakdown[t.asset]["pnl"] += t.pnl
for asset in breakdown:
total = breakdown[asset]["trades"]
wins = breakdown[asset]["wins"]
breakdown[asset]["win_rate"] = round(wins / total * 100, 1) if total > 0 else 0
breakdown[asset]["pnl"] = round(breakdown[asset]["pnl"], 2)
return breakdown
def get_uptime(self) -> float:
"""Return uptime in seconds."""
return time.time() - self._start_time
def get_summary(self) -> dict:
total_trades = len(self._trades)
wins = sum(1 for t in self._trades if t.won is True)
losses = sum(1 for t in self._trades if t.won is False)
total_pnl = sum(t.pnl for t in self._trades if t.pnl is not None)
return {
"total_trades": total_trades,
"wins": wins,
"losses": losses,
"win_rate": round(wins / (wins + losses) * 100, 1) if (wins + losses) > 0 else 0,
"total_pnl": round(total_pnl, 2),
"total_volume": round(self.total_volume, 2),
"total_fees": round(self.total_fees, 2),
"uptime_hours": round(self.get_uptime() / 3600, 2),
}

152
src/utils/telegram.py Normal file
View File

@@ -0,0 +1,152 @@
"""Telegram notification bot for trade alerts and daily summaries."""
from __future__ import annotations
import asyncio
from typing import Optional
import structlog
from src.config import NotificationConfig
log = structlog.get_logger()
class TelegramNotifier:
"""Sends trading notifications via Telegram Bot API.
Uses aiohttp directly (no python-telegram-bot dependency for async)
for lightweight non-blocking notification delivery.
"""
BASE_URL = "https://api.telegram.org/bot{token}/sendMessage"
def __init__(self, config: NotificationConfig) -> None:
self.config = config
self._enabled = config.telegram_enabled and bool(config.telegram_token) and bool(config.telegram_chat_id)
self._session = None
if not self._enabled:
log.info("telegram_disabled", reason="missing token or chat_id")
async def _get_session(self):
if self._session is None:
import aiohttp
self._session = aiohttp.ClientSession()
return self._session
async def close(self) -> None:
if self._session:
await self._session.close()
self._session = None
async def send(self, message: str, parse_mode: str = "HTML") -> bool:
"""Send a message to the configured Telegram chat."""
if not self._enabled:
return False
url = self.BASE_URL.format(token=self.config.telegram_token)
payload = {
"chat_id": self.config.telegram_chat_id,
"text": message,
"parse_mode": parse_mode,
}
try:
session = await self._get_session()
async with session.post(url, json=payload, timeout=10) as resp:
if resp.status == 200:
return True
else:
body = await resp.text()
log.warning("telegram_send_failed", status=resp.status, body=body[:200])
return False
except Exception:
log.exception("telegram_send_error")
return False
# ------------------------------------------------------------------
# Convenience methods
# ------------------------------------------------------------------
async def notify_trade(
self,
asset: str,
direction: str,
timeframe: str,
price: float,
size: int,
edge: float,
) -> None:
"""Send a trade notification."""
if not self.config.notify_on_trade:
return
msg = (
f"🔔 <b>Trade Signal</b>\n"
f"Asset: <b>{asset}</b> | {direction}\n"
f"Timeframe: {timeframe}\n"
f"Price: {price:.2f} | Size: {size}\n"
f"Edge: {edge:.2%}"
)
await self.send(msg)
async def notify_fill(
self,
asset: str,
direction: str,
fill_price: float,
fill_size: int,
trade_id: str,
) -> None:
"""Send a fill notification."""
if not self.config.notify_on_trade:
return
msg = (
f"✅ <b>Order Filled</b>\n"
f"Asset: {asset} | {direction}\n"
f"Fill: {fill_price:.2f} × {fill_size}\n"
f"Trade ID: <code>{trade_id}</code>"
)
await self.send(msg)
async def notify_daily_summary(
self,
date: str,
total_trades: int,
wins: int,
losses: int,
pnl: float,
fees: float,
volume: float,
) -> None:
"""Send daily summary."""
if not self.config.notify_on_daily_summary:
return
win_rate = wins / total_trades * 100 if total_trades > 0 else 0
emoji = "📈" if pnl >= 0 else "📉"
msg = (
f"{emoji} <b>Daily Summary — {date}</b>\n"
f"Trades: {total_trades} (W:{wins} / L:{losses})\n"
f"Win Rate: {win_rate:.1f}%\n"
f"PnL: <b>${pnl:+.2f}</b>\n"
f"Fees: ${fees:.2f}\n"
f"Volume: ${volume:,.0f}"
)
await self.send(msg)
async def notify_error(self, error_msg: str) -> None:
"""Send an error alert."""
if not self.config.notify_on_error:
return
msg = f"🚨 <b>Error Alert</b>\n<code>{error_msg[:500]}</code>"
await self.send(msg)
async def notify_halt(self, reason: str) -> None:
"""Send a trading halt alert."""
msg = f"🛑 <b>TRADING HALTED</b>\nReason: {reason}"
await self.send(msg)