update 03-22 09:28
This commit is contained in:
0
src/__init__.py
Normal file
0
src/__init__.py
Normal file
262
src/config.py
Normal file
262
src/config.py
Normal file
@@ -0,0 +1,262 @@
|
||||
"""Configuration loader for the Polymarket Temporal Arbitrage Bot.
|
||||
|
||||
Loads environment variables from .env and typed settings from config.toml,
|
||||
exposing them through nested dataclasses with validated, typed access.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import structlog
|
||||
import toml
|
||||
from dotenv import load_dotenv
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Nested config dataclasses
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class GeneralConfig:
|
||||
mode: str = "paper"
|
||||
assets: list[str] = field(default_factory=lambda: ["BTC", "ETH", "SOL"])
|
||||
timeframes: list[str] = field(default_factory=lambda: ["5M", "15M"])
|
||||
log_level: str = "INFO"
|
||||
starting_balance: float = 500.0
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class TemporalArbConfig:
|
||||
enabled: bool = True
|
||||
min_price_move_pct: float = 0.15
|
||||
max_poly_entry_price: float = 0.65
|
||||
min_edge: float = 0.20
|
||||
exit_before_resolution_sec: int = 5
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class SumToOneConfig:
|
||||
enabled: bool = True
|
||||
min_spread_after_fee: float = 0.02
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class SpreadCaptureConfig:
|
||||
enabled: bool = False
|
||||
spread_target: float = 0.04
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class RiskConfig:
|
||||
max_position_per_market_usd: float = 5000.0
|
||||
max_total_exposure_usd: float = 20000.0
|
||||
max_daily_loss_usd: float = 2000.0
|
||||
kelly_fraction_cap: float = 0.25
|
||||
max_concurrent_positions: int = 6
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class FeesConfig:
|
||||
taker_fee_5m: float = 0.0156
|
||||
taker_fee_15m: float = 0.03
|
||||
|
||||
def fee_for_timeframe(self, timeframe: str) -> float:
|
||||
"""Return the taker fee for a given timeframe string (e.g. '5M')."""
|
||||
mapping = {
|
||||
"5M": self.taker_fee_5m,
|
||||
"15M": self.taker_fee_15m,
|
||||
}
|
||||
if timeframe not in mapping:
|
||||
raise ValueError(f"Unknown timeframe '{timeframe}'; expected one of {list(mapping)}")
|
||||
return mapping[timeframe]
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class BinanceConfig:
|
||||
ws_url: str = "wss://stream.binance.com:9443/stream"
|
||||
symbols: list[str] = field(default_factory=lambda: ["btcusdt", "ethusdt", "solusdt"])
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class PolymarketConfig:
|
||||
clob_url: str = "https://clob.polymarket.com"
|
||||
gamma_url: str = "https://gamma-api.polymarket.com"
|
||||
data_url: str = "https://data-api.polymarket.com"
|
||||
ws_url: str = "wss://ws-subscriptions-clob.polymarket.com/ws/"
|
||||
chain_id: int = 137
|
||||
signature_type: int = 2
|
||||
|
||||
# Credentials sourced from environment variables
|
||||
private_key: str = ""
|
||||
proxy_wallet: str = ""
|
||||
api_key: str = ""
|
||||
api_secret: str = ""
|
||||
api_passphrase: str = ""
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class NotificationConfig:
|
||||
telegram_enabled: bool = True
|
||||
telegram_token: str = ""
|
||||
telegram_chat_id: str = ""
|
||||
notify_on_trade: bool = True
|
||||
notify_on_daily_summary: bool = True
|
||||
notify_on_error: bool = True
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Top-level Config
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class Config:
|
||||
general: GeneralConfig = field(default_factory=GeneralConfig)
|
||||
temporal_arb: TemporalArbConfig = field(default_factory=TemporalArbConfig)
|
||||
sum_to_one: SumToOneConfig = field(default_factory=SumToOneConfig)
|
||||
spread_capture: SpreadCaptureConfig = field(default_factory=SpreadCaptureConfig)
|
||||
risk: RiskConfig = field(default_factory=RiskConfig)
|
||||
fees: FeesConfig = field(default_factory=FeesConfig)
|
||||
binance: BinanceConfig = field(default_factory=BinanceConfig)
|
||||
polymarket: PolymarketConfig = field(default_factory=PolymarketConfig)
|
||||
notifications: NotificationConfig = field(default_factory=NotificationConfig)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _build_dataclass(cls: type, raw: dict[str, Any]) -> Any:
|
||||
"""Instantiate a frozen dataclass, silently ignoring unknown keys."""
|
||||
valid_fields = {f.name for f in cls.__dataclass_fields__.values()}
|
||||
filtered = {k: v for k, v in raw.items() if k in valid_fields}
|
||||
return cls(**filtered)
|
||||
|
||||
|
||||
def _setup_logging(level_name: str) -> None:
|
||||
"""Configure structlog with human-readable console output."""
|
||||
level = getattr(logging, level_name.upper(), logging.INFO)
|
||||
|
||||
structlog.configure(
|
||||
processors=[
|
||||
structlog.contextvars.merge_contextvars,
|
||||
structlog.processors.add_log_level,
|
||||
structlog.processors.StackInfoRenderer(),
|
||||
structlog.dev.set_exc_info,
|
||||
structlog.processors.TimeStamper(fmt="iso"),
|
||||
structlog.dev.ConsoleRenderer(),
|
||||
],
|
||||
wrapper_class=structlog.make_filtering_bound_logger(level),
|
||||
context_class=dict,
|
||||
logger_factory=structlog.PrintLoggerFactory(),
|
||||
cache_logger_on_first_use=True,
|
||||
)
|
||||
|
||||
# Also align stdlib logging so third-party libraries respect the level
|
||||
logging.basicConfig(
|
||||
format="%(message)s",
|
||||
stream=sys.stdout,
|
||||
level=level,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Public API
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def load_config(config_path: str = "config.toml") -> Config:
|
||||
"""Load configuration from *config_path* and environment variables.
|
||||
|
||||
1. Reads ``.env`` (if present) into the process environment.
|
||||
2. Parses ``config.toml`` for all trading parameters.
|
||||
3. Overlays sensitive credentials from environment variables.
|
||||
4. Configures structured logging via *structlog*.
|
||||
|
||||
Returns a fully-populated, immutable :class:`Config` instance.
|
||||
"""
|
||||
# --- .env ---------------------------------------------------------------
|
||||
env_path = Path(config_path).parent / ".env"
|
||||
load_dotenv(dotenv_path=env_path if env_path.exists() else None)
|
||||
|
||||
# --- config.toml --------------------------------------------------------
|
||||
config_file = Path(config_path)
|
||||
if not config_file.exists():
|
||||
raise FileNotFoundError(f"Config file not found: {config_file.resolve()}")
|
||||
|
||||
raw: dict[str, Any] = toml.load(config_file)
|
||||
|
||||
# --- Build nested configs -----------------------------------------------
|
||||
general = _build_dataclass(GeneralConfig, raw.get("general", {}))
|
||||
|
||||
strategy = raw.get("strategy", {})
|
||||
temporal_arb = _build_dataclass(TemporalArbConfig, strategy.get("temporal_arb", {}))
|
||||
sum_to_one = _build_dataclass(SumToOneConfig, strategy.get("sum_to_one", {}))
|
||||
spread_capture = _build_dataclass(SpreadCaptureConfig, strategy.get("spread_capture", {}))
|
||||
|
||||
risk = _build_dataclass(RiskConfig, raw.get("risk", {}))
|
||||
fees = _build_dataclass(FeesConfig, raw.get("fees", {}))
|
||||
|
||||
exchanges = raw.get("exchange", {})
|
||||
binance = _build_dataclass(BinanceConfig, exchanges.get("binance", {}))
|
||||
|
||||
# Polymarket: merge file config with env-var credentials
|
||||
poly_raw = dict(exchanges.get("polymarket", {}))
|
||||
poly_raw.update({
|
||||
"private_key": os.getenv("POLYMARKET_PRIVATE_KEY", ""),
|
||||
"proxy_wallet": os.getenv("POLYMARKET_PROXY_WALLET", ""),
|
||||
"api_key": os.getenv("POLYMARKET_API_KEY", ""),
|
||||
"api_secret": os.getenv("POLYMARKET_API_SECRET", ""),
|
||||
"api_passphrase": os.getenv("POLYMARKET_API_PASSPHRASE", ""),
|
||||
})
|
||||
polymarket = _build_dataclass(PolymarketConfig, poly_raw)
|
||||
|
||||
# Notifications: merge file config with env-var tokens
|
||||
notif_raw = dict(raw.get("notifications", {}))
|
||||
notif_raw.setdefault("telegram_token", os.getenv("TELEGRAM_BOT_TOKEN", ""))
|
||||
notif_raw.setdefault("telegram_chat_id", os.getenv("TELEGRAM_CHAT_ID", ""))
|
||||
# Env vars override empty strings from the config file
|
||||
if not notif_raw.get("telegram_token"):
|
||||
notif_raw["telegram_token"] = os.getenv("TELEGRAM_BOT_TOKEN", "")
|
||||
if not notif_raw.get("telegram_chat_id"):
|
||||
notif_raw["telegram_chat_id"] = os.getenv("TELEGRAM_CHAT_ID", "")
|
||||
notifications = _build_dataclass(NotificationConfig, notif_raw)
|
||||
|
||||
# --- Assemble top-level config ------------------------------------------
|
||||
config = Config(
|
||||
general=general,
|
||||
temporal_arb=temporal_arb,
|
||||
sum_to_one=sum_to_one,
|
||||
spread_capture=spread_capture,
|
||||
risk=risk,
|
||||
fees=fees,
|
||||
binance=binance,
|
||||
polymarket=polymarket,
|
||||
notifications=notifications,
|
||||
)
|
||||
|
||||
# --- Logging ------------------------------------------------------------
|
||||
_setup_logging(config.general.log_level)
|
||||
|
||||
log = structlog.get_logger()
|
||||
log.info(
|
||||
"config_loaded",
|
||||
mode=config.general.mode,
|
||||
assets=config.general.assets,
|
||||
timeframes=config.general.timeframes,
|
||||
strategies_enabled=[
|
||||
name
|
||||
for name, enabled in [
|
||||
("temporal_arb", config.temporal_arb.enabled),
|
||||
("sum_to_one", config.sum_to_one.enabled),
|
||||
("spread_capture", config.spread_capture.enabled),
|
||||
]
|
||||
if enabled
|
||||
],
|
||||
)
|
||||
|
||||
return config
|
||||
5
src/data/__init__.py
Normal file
5
src/data/__init__.py
Normal file
@@ -0,0 +1,5 @@
|
||||
"""Data layer for the Polymarket Arbitrage Bot."""
|
||||
|
||||
from .db import TradeDB
|
||||
|
||||
__all__ = ["TradeDB"]
|
||||
323
src/data/db.py
Normal file
323
src/data/db.py
Normal file
@@ -0,0 +1,323 @@
|
||||
"""SQLite trade-log database backed by sqlite-utils.
|
||||
|
||||
Provides persistent storage for trades, window snapshots, and daily
|
||||
performance summaries with lightweight query helpers.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import time
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any, Optional
|
||||
|
||||
import structlog
|
||||
from sqlite_utils import Database
|
||||
|
||||
from src.data.models import Trade, WindowState
|
||||
|
||||
logger = structlog.get_logger(__name__)
|
||||
|
||||
|
||||
class TradeDB:
|
||||
"""Thin wrapper around a SQLite database for trade logging and analytics.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
db_path:
|
||||
Path to the SQLite database file. Use ``":memory:"`` for tests.
|
||||
"""
|
||||
|
||||
def __init__(self, db_path: str = "trades.db") -> None:
|
||||
self._db_path = db_path
|
||||
self._db = Database(db_path)
|
||||
self._log = logger.bind(component="TradeDB", db=db_path)
|
||||
self._ensure_tables()
|
||||
self._log.info("database_ready")
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Schema
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _ensure_tables(self) -> None:
|
||||
"""Create tables if they do not already exist."""
|
||||
|
||||
if "trades" not in self._db.table_names():
|
||||
self._db["trades"].create(
|
||||
{
|
||||
"id": str,
|
||||
"asset": str,
|
||||
"timeframe": str,
|
||||
"direction": str,
|
||||
"token_id": str,
|
||||
"entry_price": float,
|
||||
"fill_price": float,
|
||||
"size": int,
|
||||
"fee": float,
|
||||
"pnl": float,
|
||||
"status": str,
|
||||
"signal_edge": float,
|
||||
"signal_prob": float,
|
||||
"created_at": float,
|
||||
"updated_at": float,
|
||||
},
|
||||
pk="id",
|
||||
)
|
||||
self._log.debug("table_created", table="trades")
|
||||
|
||||
if "window_snapshots" not in self._db.table_names():
|
||||
self._db["window_snapshots"].create(
|
||||
{
|
||||
"id": int,
|
||||
"asset": str,
|
||||
"timeframe": str,
|
||||
"start_price": float,
|
||||
"end_price": float,
|
||||
"price_change_pct": float,
|
||||
"window_start": float,
|
||||
"window_end": float,
|
||||
"market_condition_id": str,
|
||||
"created_at": float,
|
||||
},
|
||||
pk="id",
|
||||
)
|
||||
self._log.debug("table_created", table="window_snapshots")
|
||||
|
||||
if "daily_summary" not in self._db.table_names():
|
||||
self._db["daily_summary"].create(
|
||||
{
|
||||
"date": str,
|
||||
"total_trades": int,
|
||||
"wins": int,
|
||||
"losses": int,
|
||||
"total_pnl": float,
|
||||
"total_fees": float,
|
||||
"total_volume": float,
|
||||
"best_trade_pnl": float,
|
||||
"worst_trade_pnl": float,
|
||||
},
|
||||
pk="date",
|
||||
)
|
||||
self._log.debug("table_created", table="daily_summary")
|
||||
|
||||
if "balance_history" not in self._db.table_names():
|
||||
self._db["balance_history"].create(
|
||||
{
|
||||
"id": int,
|
||||
"timestamp": float,
|
||||
"balance": float,
|
||||
"pnl": float,
|
||||
"event": str,
|
||||
},
|
||||
pk="id",
|
||||
)
|
||||
self._log.debug("table_created", table="balance_history")
|
||||
|
||||
if "oracle_snapshots" not in self._db.table_names():
|
||||
self._db["oracle_snapshots"].create(
|
||||
{
|
||||
"id": int,
|
||||
"timestamp": float,
|
||||
"asset": str,
|
||||
"oracle_price": float,
|
||||
"cex_price": float,
|
||||
"deviation_pct": float,
|
||||
"oracle_lag_sec": float,
|
||||
"oracle_round_id": str,
|
||||
},
|
||||
pk="id",
|
||||
)
|
||||
self._log.debug("table_created", table="oracle_snapshots")
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Trade CRUD
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def log_trade(self, trade: Trade) -> None:
|
||||
"""Insert a new trade record derived from a :class:`Trade` dataclass."""
|
||||
row = {
|
||||
"id": trade.id,
|
||||
"asset": trade.signal.asset.value,
|
||||
"timeframe": trade.signal.timeframe.value,
|
||||
"direction": trade.signal.direction.value,
|
||||
"token_id": trade.signal.token_id,
|
||||
"entry_price": trade.signal.price,
|
||||
"fill_price": trade.fill_price,
|
||||
"size": trade.signal.size,
|
||||
"fee": trade.fee,
|
||||
"pnl": trade.pnl,
|
||||
"status": trade.status.value,
|
||||
"signal_edge": trade.signal.edge,
|
||||
"signal_prob": trade.signal.estimated_prob,
|
||||
"created_at": trade.created_at,
|
||||
"updated_at": trade.updated_at,
|
||||
}
|
||||
self._db["trades"].insert(row)
|
||||
self._log.info("trade_logged", trade_id=trade.id, asset=row["asset"])
|
||||
|
||||
def update_trade(self, trade_id: str, **fields: Any) -> None:
|
||||
"""Update arbitrary fields on an existing trade row."""
|
||||
fields["updated_at"] = time.time()
|
||||
self._db["trades"].update(trade_id, fields)
|
||||
self._log.info("trade_updated", trade_id=trade_id, fields=list(fields.keys()))
|
||||
|
||||
def get_trades(
|
||||
self,
|
||||
start_time: Optional[float] = None,
|
||||
end_time: Optional[float] = None,
|
||||
asset: Optional[str] = None,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Return trades matching optional time-range and asset filters."""
|
||||
clauses: list[str] = []
|
||||
params: list[Any] = []
|
||||
|
||||
if start_time is not None:
|
||||
clauses.append("created_at >= ?")
|
||||
params.append(start_time)
|
||||
if end_time is not None:
|
||||
clauses.append("created_at <= ?")
|
||||
params.append(end_time)
|
||||
if asset is not None:
|
||||
clauses.append("asset = ?")
|
||||
params.append(asset)
|
||||
|
||||
where = " AND ".join(clauses) if clauses else "1=1"
|
||||
sql = f"SELECT * FROM trades WHERE {where} ORDER BY created_at DESC"
|
||||
return list(self._db.execute(sql, params).fetchall())
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Window snapshots
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def log_window(self, window: WindowState) -> None:
|
||||
"""Snapshot a completed window into the database."""
|
||||
row = {
|
||||
"asset": window.asset.value,
|
||||
"timeframe": window.timeframe.value,
|
||||
"start_price": window.start_price,
|
||||
"end_price": window.current_price,
|
||||
"price_change_pct": window.price_change_pct,
|
||||
"window_start": window.window_start_time,
|
||||
"window_end": window.window_end_time,
|
||||
"market_condition_id": (
|
||||
window.market.condition_id if window.market else None
|
||||
),
|
||||
"created_at": time.time(),
|
||||
}
|
||||
self._db["window_snapshots"].insert(row)
|
||||
self._log.info(
|
||||
"window_logged",
|
||||
asset=row["asset"],
|
||||
timeframe=row["timeframe"],
|
||||
change_pct=row["price_change_pct"],
|
||||
)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Daily summary helpers
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def get_daily_summary(self, date: str) -> dict[str, Any]:
|
||||
"""Compute (but do not store) daily stats for *date* (``YYYY-MM-DD``)."""
|
||||
sql = """
|
||||
SELECT
|
||||
COUNT(*) AS total_trades,
|
||||
SUM(CASE WHEN pnl > 0 THEN 1 ELSE 0 END) AS wins,
|
||||
SUM(CASE WHEN pnl <= 0 THEN 1 ELSE 0 END) AS losses,
|
||||
COALESCE(SUM(pnl), 0.0) AS total_pnl,
|
||||
COALESCE(SUM(fee), 0.0) AS total_fees,
|
||||
COALESCE(SUM(size * fill_price), 0.0) AS total_volume,
|
||||
COALESCE(MAX(pnl), 0.0) AS best_trade_pnl,
|
||||
COALESCE(MIN(pnl), 0.0) AS worst_trade_pnl
|
||||
FROM trades
|
||||
WHERE date(created_at, 'unixepoch') = ?
|
||||
"""
|
||||
row = self._db.execute(sql, [date]).fetchone()
|
||||
return {
|
||||
"date": date,
|
||||
"total_trades": row[0],
|
||||
"wins": row[1],
|
||||
"losses": row[2],
|
||||
"total_pnl": row[3],
|
||||
"total_fees": row[4],
|
||||
"total_volume": row[5],
|
||||
"best_trade_pnl": row[6],
|
||||
"worst_trade_pnl": row[7],
|
||||
}
|
||||
|
||||
def update_daily_summary(self, date: str) -> None:
|
||||
"""Recalculate and upsert the daily summary row for *date*."""
|
||||
summary = self.get_daily_summary(date)
|
||||
self._db["daily_summary"].upsert(summary, pk="date")
|
||||
self._log.info("daily_summary_updated", date=date, pnl=summary["total_pnl"])
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Quick-access aggregates
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Balance history
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def log_balance(self, balance: float, pnl: float, event: str = "update") -> None:
|
||||
"""Record a balance snapshot."""
|
||||
self._db["balance_history"].insert({
|
||||
"timestamp": time.time(),
|
||||
"balance": balance,
|
||||
"pnl": pnl,
|
||||
"event": event,
|
||||
})
|
||||
|
||||
def log_oracle(
|
||||
self,
|
||||
asset: str,
|
||||
oracle_price: float,
|
||||
cex_price: float,
|
||||
deviation_pct: float,
|
||||
oracle_lag_sec: float,
|
||||
oracle_round_id: int = 0,
|
||||
) -> None:
|
||||
"""Record an oracle price snapshot."""
|
||||
self._db["oracle_snapshots"].insert({
|
||||
"timestamp": time.time(),
|
||||
"asset": asset,
|
||||
"oracle_price": oracle_price,
|
||||
"cex_price": cex_price,
|
||||
"deviation_pct": round(deviation_pct, 4),
|
||||
"oracle_lag_sec": round(oracle_lag_sec, 1),
|
||||
"oracle_round_id": str(oracle_round_id),
|
||||
})
|
||||
|
||||
def get_latest_balance(self) -> Optional[float]:
|
||||
"""Return the most recent balance, or None."""
|
||||
row = self._db.execute(
|
||||
"SELECT balance FROM balance_history ORDER BY timestamp DESC LIMIT 1"
|
||||
).fetchone()
|
||||
return float(row[0]) if row else None
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Quick-access aggregates
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def get_total_pnl(self) -> float:
|
||||
"""Return the all-time cumulative PnL."""
|
||||
row = self._db.execute("SELECT COALESCE(SUM(pnl), 0.0) FROM trades").fetchone()
|
||||
return float(row[0])
|
||||
|
||||
def get_today_pnl(self) -> float:
|
||||
"""Return today's cumulative PnL (UTC)."""
|
||||
today = datetime.now(timezone.utc).strftime("%Y-%m-%d")
|
||||
row = self._db.execute(
|
||||
"SELECT COALESCE(SUM(pnl), 0.0) FROM trades "
|
||||
"WHERE date(created_at, 'unixepoch') = ?",
|
||||
[today],
|
||||
).fetchone()
|
||||
return float(row[0])
|
||||
|
||||
def get_today_trade_count(self) -> int:
|
||||
"""Return the number of trades placed today (UTC)."""
|
||||
today = datetime.now(timezone.utc).strftime("%Y-%m-%d")
|
||||
row = self._db.execute(
|
||||
"SELECT COUNT(*) FROM trades "
|
||||
"WHERE date(created_at, 'unixepoch') = ?",
|
||||
[today],
|
||||
).fetchone()
|
||||
return int(row[0])
|
||||
136
src/data/models.py
Normal file
136
src/data/models.py
Normal file
@@ -0,0 +1,136 @@
|
||||
"""Core data models used throughout the Polymarket Arbitrage Bot."""
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
from typing import Optional
|
||||
import time
|
||||
|
||||
|
||||
class Asset(str, Enum):
|
||||
BTC = "BTC"
|
||||
ETH = "ETH"
|
||||
SOL = "SOL"
|
||||
|
||||
|
||||
class Timeframe(str, Enum):
|
||||
FIVE_MIN = "5M"
|
||||
FIFTEEN_MIN = "15M"
|
||||
|
||||
|
||||
class Direction(str, Enum):
|
||||
UP = "UP"
|
||||
DOWN = "DOWN"
|
||||
|
||||
|
||||
class OrderSide(str, Enum):
|
||||
BUY = "BUY"
|
||||
SELL = "SELL"
|
||||
|
||||
|
||||
class TradeStatus(str, Enum):
|
||||
PENDING = "PENDING"
|
||||
FILLED = "FILLED"
|
||||
PARTIALLY_FILLED = "PARTIALLY_FILLED"
|
||||
CANCELLED = "CANCELLED"
|
||||
FAILED = "FAILED"
|
||||
|
||||
|
||||
@dataclass
|
||||
class ActiveMarket:
|
||||
condition_id: str
|
||||
up_token_id: str
|
||||
down_token_id: str
|
||||
asset: Asset
|
||||
timeframe: Timeframe
|
||||
end_date: str # ISO format resolution time
|
||||
question: str
|
||||
description: str = ""
|
||||
|
||||
|
||||
@dataclass
|
||||
class WindowState:
|
||||
asset: Asset
|
||||
timeframe: Timeframe
|
||||
window_start_time: float # Unix timestamp
|
||||
window_end_time: float # Unix timestamp
|
||||
start_price: Optional[float] = None
|
||||
current_price: Optional[float] = None
|
||||
market: Optional[ActiveMarket] = None
|
||||
|
||||
@property
|
||||
def time_remaining(self) -> float:
|
||||
return max(0, self.window_end_time - time.time())
|
||||
|
||||
@property
|
||||
def price_change_pct(self) -> Optional[float]:
|
||||
if self.start_price and self.current_price and self.start_price > 0:
|
||||
return (self.current_price - self.start_price) / self.start_price * 100
|
||||
return None
|
||||
|
||||
|
||||
@dataclass
|
||||
class Signal:
|
||||
direction: Direction
|
||||
asset: Asset
|
||||
timeframe: Timeframe
|
||||
token_id: str
|
||||
price: float # Polymarket entry price
|
||||
size: int # Number of shares
|
||||
edge: float # Estimated edge after fees
|
||||
estimated_prob: float # Estimated true probability
|
||||
timestamp: float = field(default_factory=time.time)
|
||||
|
||||
|
||||
@dataclass
|
||||
class Trade:
|
||||
id: str
|
||||
signal: Signal
|
||||
order_id: str = ""
|
||||
status: TradeStatus = TradeStatus.PENDING
|
||||
fill_price: float = 0.0
|
||||
fill_size: int = 0
|
||||
fee: float = 0.0
|
||||
pnl: float = 0.0
|
||||
created_at: float = field(default_factory=time.time)
|
||||
updated_at: float = field(default_factory=time.time)
|
||||
|
||||
|
||||
@dataclass
|
||||
class Position:
|
||||
market_id: str
|
||||
asset: Asset
|
||||
timeframe: Timeframe
|
||||
direction: Direction
|
||||
token_id: str
|
||||
size: int
|
||||
avg_price: float
|
||||
current_value: float = 0.0
|
||||
unrealized_pnl: float = 0.0
|
||||
|
||||
|
||||
@dataclass
|
||||
class OrderBookLevel:
|
||||
price: float
|
||||
size: float
|
||||
|
||||
|
||||
@dataclass
|
||||
class OrderBookSnapshot:
|
||||
token_id: str
|
||||
bids: list[OrderBookLevel] = field(default_factory=list)
|
||||
asks: list[OrderBookLevel] = field(default_factory=list)
|
||||
timestamp: float = field(default_factory=time.time)
|
||||
|
||||
@property
|
||||
def best_bid(self) -> Optional[float]:
|
||||
return self.bids[0].price if self.bids else None
|
||||
|
||||
@property
|
||||
def best_ask(self) -> Optional[float]:
|
||||
return self.asks[0].price if self.asks else None
|
||||
|
||||
@property
|
||||
def spread(self) -> Optional[float]:
|
||||
if self.best_bid is not None and self.best_ask is not None:
|
||||
return self.best_ask - self.best_bid
|
||||
return None
|
||||
7
src/execution/__init__.py
Normal file
7
src/execution/__init__.py
Normal file
@@ -0,0 +1,7 @@
|
||||
"""Execution engine modules."""
|
||||
|
||||
from src.execution.clob_client import ClobClientWrapper
|
||||
from src.execution.order_manager import OrderManager
|
||||
from src.execution.position_tracker import PositionTracker
|
||||
|
||||
__all__ = ["ClobClientWrapper", "OrderManager", "PositionTracker"]
|
||||
235
src/execution/clob_client.py
Normal file
235
src/execution/clob_client.py
Normal file
@@ -0,0 +1,235 @@
|
||||
"""Polymarket CLOB API wrapper with EIP-712 authentication.
|
||||
|
||||
Handles order creation, cancellation, and position queries via
|
||||
the py-clob-client SDK.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from typing import Any, Optional
|
||||
|
||||
import structlog
|
||||
|
||||
from src.config import PolymarketConfig
|
||||
|
||||
log = structlog.get_logger()
|
||||
|
||||
|
||||
class ClobClientWrapper:
|
||||
"""Async-friendly wrapper around py-clob-client.
|
||||
|
||||
The underlying SDK is synchronous, so heavy calls are delegated
|
||||
to a thread-pool executor to avoid blocking the event loop.
|
||||
"""
|
||||
|
||||
def __init__(self, config: PolymarketConfig) -> None:
|
||||
self.config = config
|
||||
self._client: Any = None
|
||||
self._initialized = False
|
||||
|
||||
async def initialize(self) -> None:
|
||||
"""Initialize the CLOB client with credentials.
|
||||
|
||||
Must be called before any trading operations.
|
||||
"""
|
||||
if self._initialized:
|
||||
return
|
||||
|
||||
if not self.config.private_key:
|
||||
log.warning("clob_client_no_key", msg="No private key — running in read-only mode")
|
||||
return
|
||||
|
||||
loop = asyncio.get_event_loop()
|
||||
try:
|
||||
self._client = await loop.run_in_executor(None, self._create_client)
|
||||
self._initialized = True
|
||||
log.info("clob_client_initialized")
|
||||
except Exception:
|
||||
log.exception("clob_client_init_failed")
|
||||
|
||||
def _create_client(self) -> Any:
|
||||
"""Create the synchronous CLOB client (runs in executor)."""
|
||||
from py_clob_client.client import ClobClient
|
||||
|
||||
client = ClobClient(
|
||||
host=self.config.clob_url,
|
||||
key=self.config.private_key,
|
||||
chain_id=self.config.chain_id,
|
||||
signature_type=self.config.signature_type,
|
||||
funder=self.config.proxy_wallet if self.config.proxy_wallet else None,
|
||||
)
|
||||
|
||||
# Set API credentials if available
|
||||
if self.config.api_key:
|
||||
client.set_api_creds(client.create_or_derive_api_creds())
|
||||
else:
|
||||
# Generate new API key
|
||||
api_creds = client.create_api_key()
|
||||
client.set_api_creds(api_creds)
|
||||
log.info("clob_api_key_created")
|
||||
|
||||
return client
|
||||
|
||||
@property
|
||||
def is_ready(self) -> bool:
|
||||
return self._initialized and self._client is not None
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Order operations
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def place_limit_order(
|
||||
self,
|
||||
token_id: str,
|
||||
price: float,
|
||||
size: int,
|
||||
side: str = "BUY",
|
||||
) -> Optional[dict]:
|
||||
"""Place a limit order on the CLOB.
|
||||
|
||||
Args:
|
||||
token_id: The CLOB token ID (Up or Down outcome).
|
||||
price: Limit price (0.01 to 0.99).
|
||||
size: Number of shares.
|
||||
side: "BUY" or "SELL".
|
||||
|
||||
Returns:
|
||||
Order response dict or None on failure.
|
||||
"""
|
||||
if not self.is_ready:
|
||||
log.error("clob_not_ready", msg="Client not initialized")
|
||||
return None
|
||||
|
||||
loop = asyncio.get_event_loop()
|
||||
try:
|
||||
result = await loop.run_in_executor(
|
||||
None, self._place_order_sync, token_id, price, size, side
|
||||
)
|
||||
log.info(
|
||||
"order_placed",
|
||||
token_id=token_id[:16],
|
||||
price=price,
|
||||
size=size,
|
||||
side=side,
|
||||
order_id=result.get("orderID", ""),
|
||||
)
|
||||
return result
|
||||
except Exception:
|
||||
log.exception("order_place_failed", token_id=token_id[:16], price=price, size=size)
|
||||
return None
|
||||
|
||||
def _place_order_sync(self, token_id: str, price: float, size: int, side: str) -> dict:
|
||||
from py_clob_client.clob_types import OrderArgs, OrderType
|
||||
|
||||
order_args = OrderArgs(
|
||||
token_id=token_id,
|
||||
price=price,
|
||||
size=size,
|
||||
side=side,
|
||||
)
|
||||
signed_order = self._client.create_order(order_args)
|
||||
return self._client.post_order(signed_order, orderType=OrderType.GTC)
|
||||
|
||||
async def place_market_order(
|
||||
self,
|
||||
token_id: str,
|
||||
size: int,
|
||||
side: str = "BUY",
|
||||
) -> Optional[dict]:
|
||||
"""Place a market order (FOK at worst available price)."""
|
||||
if not self.is_ready:
|
||||
return None
|
||||
|
||||
loop = asyncio.get_event_loop()
|
||||
try:
|
||||
result = await loop.run_in_executor(
|
||||
None, self._place_market_order_sync, token_id, size, side
|
||||
)
|
||||
log.info("market_order_placed", token_id=token_id[:16], size=size, side=side)
|
||||
return result
|
||||
except Exception:
|
||||
log.exception("market_order_failed")
|
||||
return None
|
||||
|
||||
def _place_market_order_sync(self, token_id: str, size: int, side: str) -> dict:
|
||||
from py_clob_client.clob_types import OrderArgs, OrderType
|
||||
|
||||
# Use price 0.99 for BUY (worst case) or 0.01 for SELL
|
||||
price = 0.99 if side == "BUY" else 0.01
|
||||
order_args = OrderArgs(
|
||||
token_id=token_id,
|
||||
price=price,
|
||||
size=size,
|
||||
side=side,
|
||||
)
|
||||
signed_order = self._client.create_order(order_args)
|
||||
return self._client.post_order(signed_order, orderType=OrderType.FOK)
|
||||
|
||||
async def cancel_order(self, order_id: str) -> bool:
|
||||
"""Cancel a specific order by ID."""
|
||||
if not self.is_ready:
|
||||
return False
|
||||
|
||||
loop = asyncio.get_event_loop()
|
||||
try:
|
||||
await loop.run_in_executor(None, self._client.cancel, order_id)
|
||||
log.info("order_cancelled", order_id=order_id)
|
||||
return True
|
||||
except Exception:
|
||||
log.exception("order_cancel_failed", order_id=order_id)
|
||||
return False
|
||||
|
||||
async def cancel_all_orders(self) -> bool:
|
||||
"""Cancel all open orders."""
|
||||
if not self.is_ready:
|
||||
return False
|
||||
|
||||
loop = asyncio.get_event_loop()
|
||||
try:
|
||||
await loop.run_in_executor(None, self._client.cancel_all)
|
||||
log.info("all_orders_cancelled")
|
||||
return True
|
||||
except Exception:
|
||||
log.exception("cancel_all_failed")
|
||||
return False
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Query operations
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def get_order(self, order_id: str) -> Optional[dict]:
|
||||
"""Fetch order status by ID."""
|
||||
if not self.is_ready:
|
||||
return None
|
||||
|
||||
loop = asyncio.get_event_loop()
|
||||
try:
|
||||
return await loop.run_in_executor(None, self._client.get_order, order_id)
|
||||
except Exception:
|
||||
log.exception("get_order_failed", order_id=order_id)
|
||||
return None
|
||||
|
||||
async def get_open_orders(self) -> list[dict]:
|
||||
"""Fetch all open orders."""
|
||||
if not self.is_ready:
|
||||
return []
|
||||
|
||||
loop = asyncio.get_event_loop()
|
||||
try:
|
||||
result = await loop.run_in_executor(None, self._client.get_orders)
|
||||
return result if isinstance(result, list) else []
|
||||
except Exception:
|
||||
log.exception("get_open_orders_failed")
|
||||
return []
|
||||
|
||||
async def get_orderbook(self, token_id: str) -> Optional[dict]:
|
||||
"""Fetch current orderbook for a token."""
|
||||
loop = asyncio.get_event_loop()
|
||||
try:
|
||||
return await loop.run_in_executor(
|
||||
None, self._client.get_order_book, token_id
|
||||
)
|
||||
except Exception:
|
||||
log.exception("get_orderbook_failed", token_id=token_id[:16])
|
||||
return None
|
||||
243
src/execution/order_manager.py
Normal file
243
src/execution/order_manager.py
Normal file
@@ -0,0 +1,243 @@
|
||||
"""Order Manager — handles order lifecycle from signal to fill.
|
||||
|
||||
Manages order creation, tracking, modification, and cancellation
|
||||
with position limit enforcement.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import time
|
||||
import uuid
|
||||
from typing import Callable, Optional
|
||||
|
||||
import structlog
|
||||
|
||||
from src.config import Config
|
||||
from src.data.models import Direction, Signal, Trade, TradeStatus
|
||||
from src.execution.clob_client import ClobClientWrapper
|
||||
|
||||
log = structlog.get_logger()
|
||||
|
||||
|
||||
class OrderManager:
|
||||
"""Manages the full lifecycle of orders from signal to fill.
|
||||
|
||||
Enforces position limits, tracks pending/active orders,
|
||||
and handles cancellation near resolution time.
|
||||
"""
|
||||
|
||||
def __init__(self, config: Config, clob_client: ClobClientWrapper) -> None:
|
||||
self.config = config
|
||||
self.clob = clob_client
|
||||
self._pending_orders: dict[str, Trade] = {}
|
||||
self._active_trades: dict[str, Trade] = {}
|
||||
self._on_fill_callbacks: list[Callable[[Trade], None]] = []
|
||||
self._running = False
|
||||
|
||||
def on_fill(self, callback: Callable[[Trade], None]) -> None:
|
||||
"""Register callback for when an order is filled."""
|
||||
self._on_fill_callbacks.append(callback)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Order submission
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def submit_signal(self, signal: Signal) -> Optional[Trade]:
|
||||
"""Submit a signal for execution.
|
||||
|
||||
Validates against risk limits, creates the order, and tracks it.
|
||||
Returns a Trade object if the order was submitted.
|
||||
"""
|
||||
# Check concurrent position limit
|
||||
active_count = len(self._active_trades) + len(self._pending_orders)
|
||||
if active_count >= self.config.risk.max_concurrent_positions:
|
||||
log.warning(
|
||||
"position_limit_reached",
|
||||
active=active_count,
|
||||
limit=self.config.risk.max_concurrent_positions,
|
||||
)
|
||||
return None
|
||||
|
||||
# Check per-market exposure
|
||||
market_exposure = self._get_market_exposure(signal.token_id)
|
||||
new_exposure = signal.price * signal.size
|
||||
if market_exposure + new_exposure > self.config.risk.max_position_per_market_usd:
|
||||
log.warning(
|
||||
"market_exposure_limit",
|
||||
token_id=signal.token_id[:16],
|
||||
current=round(market_exposure, 2),
|
||||
new=round(new_exposure, 2),
|
||||
)
|
||||
return None
|
||||
|
||||
# Create trade record
|
||||
trade = Trade(
|
||||
id=str(uuid.uuid4())[:8],
|
||||
signal=signal,
|
||||
status=TradeStatus.PENDING,
|
||||
)
|
||||
|
||||
# Submit order to CLOB
|
||||
result = await self.clob.place_limit_order(
|
||||
token_id=signal.token_id,
|
||||
price=signal.price,
|
||||
size=signal.size,
|
||||
side="BUY",
|
||||
)
|
||||
|
||||
if result is None:
|
||||
trade.status = TradeStatus.FAILED
|
||||
log.error("order_submission_failed", trade_id=trade.id)
|
||||
return trade
|
||||
|
||||
trade.order_id = result.get("orderID", "")
|
||||
self._pending_orders[trade.id] = trade
|
||||
|
||||
log.info(
|
||||
"order_submitted",
|
||||
trade_id=trade.id,
|
||||
order_id=trade.order_id,
|
||||
asset=signal.asset.value,
|
||||
direction=signal.direction.value,
|
||||
price=signal.price,
|
||||
size=signal.size,
|
||||
)
|
||||
|
||||
return trade
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Order monitoring
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def check_order_status(self, trade: Trade) -> Trade:
|
||||
"""Check the status of a pending order and update trade accordingly."""
|
||||
if not trade.order_id:
|
||||
return trade
|
||||
|
||||
order_info = await self.clob.get_order(trade.order_id)
|
||||
if order_info is None:
|
||||
return trade
|
||||
|
||||
status = order_info.get("status", "").upper()
|
||||
|
||||
if status in ("MATCHED", "FILLED"):
|
||||
trade.status = TradeStatus.FILLED
|
||||
trade.fill_price = float(order_info.get("price", trade.signal.price))
|
||||
trade.fill_size = int(order_info.get("size_matched", trade.signal.size))
|
||||
trade.fee = self._calculate_fee(trade)
|
||||
trade.updated_at = time.time()
|
||||
|
||||
# Move from pending to active
|
||||
self._pending_orders.pop(trade.id, None)
|
||||
self._active_trades[trade.id] = trade
|
||||
|
||||
for cb in self._on_fill_callbacks:
|
||||
try:
|
||||
cb(trade)
|
||||
except Exception:
|
||||
log.exception("fill_callback_error")
|
||||
|
||||
log.info(
|
||||
"order_filled",
|
||||
trade_id=trade.id,
|
||||
fill_price=trade.fill_price,
|
||||
fill_size=trade.fill_size,
|
||||
fee=round(trade.fee, 4),
|
||||
)
|
||||
|
||||
elif status == "CANCELLED":
|
||||
trade.status = TradeStatus.CANCELLED
|
||||
trade.updated_at = time.time()
|
||||
self._pending_orders.pop(trade.id, None)
|
||||
|
||||
return trade
|
||||
|
||||
async def monitor_loop(self, interval: float = 2.0) -> None:
|
||||
"""Continuously monitor pending orders for fills."""
|
||||
self._running = True
|
||||
while self._running:
|
||||
for trade_id in list(self._pending_orders.keys()):
|
||||
trade = self._pending_orders.get(trade_id)
|
||||
if trade:
|
||||
await self.check_order_status(trade)
|
||||
await asyncio.sleep(interval)
|
||||
|
||||
def stop(self) -> None:
|
||||
self._running = False
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Cancellation
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def cancel_pending(self, trade_id: str) -> bool:
|
||||
"""Cancel a specific pending order."""
|
||||
trade = self._pending_orders.get(trade_id)
|
||||
if not trade or not trade.order_id:
|
||||
return False
|
||||
|
||||
success = await self.clob.cancel_order(trade.order_id)
|
||||
if success:
|
||||
trade.status = TradeStatus.CANCELLED
|
||||
trade.updated_at = time.time()
|
||||
self._pending_orders.pop(trade_id, None)
|
||||
|
||||
return success
|
||||
|
||||
async def cancel_all_pending(self) -> int:
|
||||
"""Cancel all pending orders. Returns count of cancelled orders."""
|
||||
cancelled = 0
|
||||
for trade_id in list(self._pending_orders.keys()):
|
||||
if await self.cancel_pending(trade_id):
|
||||
cancelled += 1
|
||||
return cancelled
|
||||
|
||||
async def cancel_expiring_orders(self, seconds_before_resolution: float = 5.0) -> int:
|
||||
"""Cancel orders that are too close to market resolution."""
|
||||
cancelled = 0
|
||||
now = time.time()
|
||||
|
||||
for trade_id, trade in list(self._pending_orders.items()):
|
||||
# Check if the signal's window is about to resolve
|
||||
# (Signal doesn't store window_end_time, but we can infer from context)
|
||||
if trade.created_at + 300 - now < seconds_before_resolution: # Rough heuristic
|
||||
if await self.cancel_pending(trade_id):
|
||||
cancelled += 1
|
||||
log.info("expiring_order_cancelled", trade_id=trade_id)
|
||||
|
||||
return cancelled
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Queries
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def get_pending_trades(self) -> list[Trade]:
|
||||
return list(self._pending_orders.values())
|
||||
|
||||
def get_active_trades(self) -> list[Trade]:
|
||||
return list(self._active_trades.values())
|
||||
|
||||
def get_all_trades(self) -> list[Trade]:
|
||||
return self.get_pending_trades() + self.get_active_trades()
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _get_market_exposure(self, token_id: str) -> float:
|
||||
"""Calculate current dollar exposure for a given token."""
|
||||
exposure = 0.0
|
||||
for trade in list(self._pending_orders.values()) + list(self._active_trades.values()):
|
||||
if trade.signal.token_id == token_id:
|
||||
exposure += trade.signal.price * trade.signal.size
|
||||
return exposure
|
||||
|
||||
def _calculate_fee(self, trade: Trade) -> float:
|
||||
"""Calculate the taker fee for a filled trade."""
|
||||
tf = trade.signal.timeframe.value
|
||||
fee_rate = self.config.fees.fee_for_timeframe(tf)
|
||||
# Fee is on the potential payout (winning amount)
|
||||
payout = trade.fill_size * 1.0 # $1 per share if winning
|
||||
cost = trade.fill_price * trade.fill_size
|
||||
profit = payout - cost
|
||||
return max(0, profit * fee_rate)
|
||||
211
src/execution/position_tracker.py
Normal file
211
src/execution/position_tracker.py
Normal file
@@ -0,0 +1,211 @@
|
||||
"""Position Tracker — real-time position & PnL tracking.
|
||||
|
||||
Monitors active positions, calculates unrealized PnL based on
|
||||
current Polymarket prices, and enforces exposure limits.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import time
|
||||
from typing import Optional
|
||||
|
||||
import structlog
|
||||
|
||||
from src.config import Config
|
||||
from src.data.models import (
|
||||
Asset,
|
||||
Direction,
|
||||
OrderBookSnapshot,
|
||||
Position,
|
||||
Timeframe,
|
||||
Trade,
|
||||
TradeStatus,
|
||||
)
|
||||
|
||||
log = structlog.get_logger()
|
||||
|
||||
|
||||
class PositionTracker:
|
||||
"""Tracks all open positions and computes real-time PnL.
|
||||
|
||||
Positions are created from filled trades and closed when
|
||||
the underlying market resolves or the position is sold.
|
||||
"""
|
||||
|
||||
def __init__(self, config: Config) -> None:
|
||||
self.config = config
|
||||
self._positions: dict[str, Position] = {} # keyed by token_id
|
||||
self._realized_pnl: float = 0.0
|
||||
self._total_fees: float = 0.0
|
||||
self._trade_count: int = 0
|
||||
self._win_count: int = 0
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Position management
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def open_position(self, trade: Trade) -> Position:
|
||||
"""Open or add to a position from a filled trade."""
|
||||
token_id = trade.signal.token_id
|
||||
signal = trade.signal
|
||||
|
||||
if token_id in self._positions:
|
||||
# Add to existing position
|
||||
pos = self._positions[token_id]
|
||||
total_cost = pos.avg_price * pos.size + trade.fill_price * trade.fill_size
|
||||
new_size = pos.size + trade.fill_size
|
||||
pos.avg_price = total_cost / new_size if new_size > 0 else 0
|
||||
pos.size = new_size
|
||||
else:
|
||||
pos = Position(
|
||||
market_id=token_id,
|
||||
asset=signal.asset,
|
||||
timeframe=signal.timeframe,
|
||||
direction=signal.direction,
|
||||
token_id=token_id,
|
||||
size=trade.fill_size,
|
||||
avg_price=trade.fill_price,
|
||||
)
|
||||
self._positions[token_id] = pos
|
||||
|
||||
self._trade_count += 1
|
||||
self._total_fees += trade.fee
|
||||
|
||||
log.info(
|
||||
"position_opened",
|
||||
token_id=token_id[:16],
|
||||
asset=signal.asset.value,
|
||||
direction=signal.direction.value,
|
||||
size=pos.size,
|
||||
avg_price=round(pos.avg_price, 4),
|
||||
)
|
||||
|
||||
return pos
|
||||
|
||||
def close_position(self, token_id: str, resolution_price: float) -> Optional[float]:
|
||||
"""Close a position at resolution.
|
||||
|
||||
Args:
|
||||
token_id: The token ID of the position.
|
||||
resolution_price: 1.0 if the outcome won, 0.0 if lost.
|
||||
|
||||
Returns:
|
||||
Realized PnL for this position, or None if no position.
|
||||
"""
|
||||
pos = self._positions.pop(token_id, None)
|
||||
if pos is None:
|
||||
return None
|
||||
|
||||
payout = resolution_price * pos.size
|
||||
cost = pos.avg_price * pos.size
|
||||
pnl = payout - cost
|
||||
|
||||
self._realized_pnl += pnl
|
||||
if pnl > 0:
|
||||
self._win_count += 1
|
||||
|
||||
log.info(
|
||||
"position_closed",
|
||||
token_id=token_id[:16],
|
||||
asset=pos.asset.value,
|
||||
direction=pos.direction.value,
|
||||
size=pos.size,
|
||||
avg_price=round(pos.avg_price, 4),
|
||||
payout=round(payout, 2),
|
||||
pnl=round(pnl, 2),
|
||||
result="WIN" if pnl > 0 else "LOSS",
|
||||
)
|
||||
|
||||
return pnl
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Mark-to-market
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def update_mark(self, token_id: str, current_price: float) -> None:
|
||||
"""Update mark-to-market for a position."""
|
||||
pos = self._positions.get(token_id)
|
||||
if pos is None:
|
||||
return
|
||||
|
||||
pos.current_value = current_price * pos.size
|
||||
pos.unrealized_pnl = pos.current_value - (pos.avg_price * pos.size)
|
||||
|
||||
def update_from_orderbooks(self, orderbooks: dict[str, OrderBookSnapshot]) -> None:
|
||||
"""Update all positions from latest orderbook data."""
|
||||
for token_id, pos in self._positions.items():
|
||||
book = orderbooks.get(token_id)
|
||||
if book and book.best_bid is not None:
|
||||
self.update_mark(token_id, book.best_bid)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Queries
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def get_position(self, token_id: str) -> Optional[Position]:
|
||||
return self._positions.get(token_id)
|
||||
|
||||
def get_all_positions(self) -> list[Position]:
|
||||
return list(self._positions.values())
|
||||
|
||||
def get_positions_by_asset(self, asset: str) -> list[Position]:
|
||||
a = Asset(asset)
|
||||
return [p for p in self._positions.values() if p.asset == a]
|
||||
|
||||
@property
|
||||
def total_unrealized_pnl(self) -> float:
|
||||
return sum(p.unrealized_pnl for p in self._positions.values())
|
||||
|
||||
@property
|
||||
def total_realized_pnl(self) -> float:
|
||||
return self._realized_pnl
|
||||
|
||||
@property
|
||||
def total_pnl(self) -> float:
|
||||
return self._realized_pnl + self.total_unrealized_pnl
|
||||
|
||||
@property
|
||||
def total_fees(self) -> float:
|
||||
return self._total_fees
|
||||
|
||||
@property
|
||||
def total_exposure(self) -> float:
|
||||
return sum(p.avg_price * p.size for p in self._positions.values())
|
||||
|
||||
@property
|
||||
def position_count(self) -> int:
|
||||
return len(self._positions)
|
||||
|
||||
@property
|
||||
def trade_count(self) -> int:
|
||||
return self._trade_count
|
||||
|
||||
@property
|
||||
def win_rate(self) -> float:
|
||||
closed = self._trade_count - len(self._positions)
|
||||
return self._win_count / closed if closed > 0 else 0.0
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Risk checks
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def is_exposure_ok(self, additional_usd: float = 0) -> bool:
|
||||
"""Check if total exposure is within limits."""
|
||||
return (self.total_exposure + additional_usd) <= self.config.risk.max_total_exposure_usd
|
||||
|
||||
def is_daily_loss_ok(self, daily_pnl: float) -> bool:
|
||||
"""Check if daily loss is within limits."""
|
||||
return daily_pnl > -self.config.risk.max_daily_loss_usd
|
||||
|
||||
def get_summary(self) -> dict:
|
||||
"""Return a summary dict for logging/display."""
|
||||
return {
|
||||
"positions": self.position_count,
|
||||
"total_exposure": round(self.total_exposure, 2),
|
||||
"unrealized_pnl": round(self.total_unrealized_pnl, 2),
|
||||
"realized_pnl": round(self.total_realized_pnl, 2),
|
||||
"total_pnl": round(self.total_pnl, 2),
|
||||
"total_fees": round(self.total_fees, 2),
|
||||
"trades": self.trade_count,
|
||||
"win_rate": round(self.win_rate * 100, 1),
|
||||
}
|
||||
6
src/feeds/__init__.py
Normal file
6
src/feeds/__init__.py
Normal file
@@ -0,0 +1,6 @@
|
||||
"""Price feed integrations for the Polymarket Arbitrage Bot."""
|
||||
|
||||
from .binance_ws import BinanceFeed, PriceCallback
|
||||
from .polymarket_ws import PolymarketFeed
|
||||
|
||||
__all__ = ["BinanceFeed", "PolymarketFeed", "PriceCallback"]
|
||||
250
src/feeds/binance_ws.py
Normal file
250
src/feeds/binance_ws.py
Normal file
@@ -0,0 +1,250 @@
|
||||
"""Binance WebSocket price feed for real-time trade data.
|
||||
|
||||
Connects to Binance combined streams for BTC, ETH, and SOL trade data,
|
||||
providing low-latency price updates via an async callback pattern.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import time
|
||||
from typing import Callable, Optional
|
||||
|
||||
import structlog
|
||||
import websockets
|
||||
from websockets.exceptions import ConnectionClosed, InvalidStatusCode
|
||||
|
||||
logger = structlog.get_logger(__name__)
|
||||
|
||||
# Stream symbol mapping: Binance stream name -> canonical symbol
|
||||
_STREAM_TO_SYMBOL: dict[str, str] = {
|
||||
"btcusdt@trade": "BTC",
|
||||
"ethusdt@trade": "ETH",
|
||||
"solusdt@trade": "SOL",
|
||||
}
|
||||
|
||||
_SYMBOL_TO_STREAM: dict[str, str] = {v: k for k, v in _STREAM_TO_SYMBOL.items()}
|
||||
|
||||
PriceCallback = Callable[[str, float, float, float], None]
|
||||
|
||||
_DEFAULT_URL = (
|
||||
"wss://stream.binance.com:9443/stream"
|
||||
"?streams=btcusdt@trade/ethusdt@trade/solusdt@trade"
|
||||
)
|
||||
|
||||
_PING_INTERVAL_S = 30
|
||||
_RECONNECT_BASE_S = 1.0
|
||||
_RECONNECT_MAX_S = 30.0
|
||||
|
||||
|
||||
class BinanceFeed:
|
||||
"""Async Binance WebSocket price feed with auto-reconnect.
|
||||
|
||||
Usage::
|
||||
|
||||
feed = BinanceFeed()
|
||||
feed.subscribe(my_callback)
|
||||
await feed.start() # runs until stop() is called
|
||||
"""
|
||||
|
||||
def __init__(self, url: str = _DEFAULT_URL) -> None:
|
||||
self._url = url
|
||||
self._callbacks: list[PriceCallback] = []
|
||||
self._latest_prices: dict[str, float] = {}
|
||||
self._last_recv_ts: dict[str, float] = {}
|
||||
self._connected = False
|
||||
self._ws: Optional[websockets.WebSocketClientProtocol] = None
|
||||
self._stop_event: asyncio.Event = asyncio.Event()
|
||||
self._tasks: list[asyncio.Task[None]] = []
|
||||
self._reconnect_delay = _RECONNECT_BASE_S
|
||||
self._log = logger.bind(component="BinanceFeed")
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Public API
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
@property
|
||||
def is_connected(self) -> bool:
|
||||
"""Return ``True`` when the WebSocket connection is open."""
|
||||
return self._connected
|
||||
|
||||
def subscribe(self, callback: PriceCallback) -> None:
|
||||
"""Register a price-update callback.
|
||||
|
||||
The callback signature is::
|
||||
|
||||
callback(symbol: str, price: float, timestamp: float, volume: float)
|
||||
"""
|
||||
if callback not in self._callbacks:
|
||||
self._callbacks.append(callback)
|
||||
self._log.info("callback_subscribed", total=len(self._callbacks))
|
||||
|
||||
def get_latest_price(self, symbol: str) -> Optional[float]:
|
||||
"""Return the most recent trade price for *symbol*, or ``None``."""
|
||||
return self._latest_prices.get(symbol.upper())
|
||||
|
||||
async def start(self) -> None:
|
||||
"""Open the WebSocket and begin consuming messages.
|
||||
|
||||
Automatically reconnects on disconnection with exponential backoff.
|
||||
Blocks until :meth:`stop` is called.
|
||||
"""
|
||||
self._stop_event.clear()
|
||||
self._log.info("feed_starting", url=self._url)
|
||||
|
||||
while not self._stop_event.is_set():
|
||||
try:
|
||||
await self._connect_and_listen()
|
||||
except (ConnectionClosed, ConnectionError, InvalidStatusCode, OSError) as exc:
|
||||
self._set_disconnected()
|
||||
self._log.warning(
|
||||
"connection_lost",
|
||||
error=str(exc),
|
||||
reconnect_in=self._reconnect_delay,
|
||||
)
|
||||
except asyncio.CancelledError:
|
||||
self._log.info("feed_cancelled")
|
||||
break
|
||||
except Exception:
|
||||
self._set_disconnected()
|
||||
self._log.exception(
|
||||
"unexpected_error",
|
||||
reconnect_in=self._reconnect_delay,
|
||||
)
|
||||
|
||||
if self._stop_event.is_set():
|
||||
break
|
||||
|
||||
# Wait with exponential backoff, but allow early exit via stop().
|
||||
try:
|
||||
await asyncio.wait_for(
|
||||
self._stop_event.wait(),
|
||||
timeout=self._reconnect_delay,
|
||||
)
|
||||
break # stop_event was set during the wait
|
||||
except asyncio.TimeoutError:
|
||||
pass
|
||||
|
||||
self._reconnect_delay = min(
|
||||
self._reconnect_delay * 2, _RECONNECT_MAX_S
|
||||
)
|
||||
|
||||
await self._cleanup()
|
||||
self._log.info("feed_stopped")
|
||||
|
||||
async def stop(self) -> None:
|
||||
"""Signal the feed to shut down gracefully."""
|
||||
self._log.info("feed_stop_requested")
|
||||
self._stop_event.set()
|
||||
|
||||
for task in self._tasks:
|
||||
task.cancel()
|
||||
|
||||
if self._ws is not None:
|
||||
await self._ws.close()
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Internals
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def _connect_and_listen(self) -> None:
|
||||
"""Establish a WebSocket connection and consume messages."""
|
||||
async with websockets.connect(
|
||||
self._url,
|
||||
ping_interval=None, # we manage our own heartbeat
|
||||
close_timeout=5,
|
||||
) as ws:
|
||||
self._ws = ws
|
||||
self._connected = True
|
||||
self._reconnect_delay = _RECONNECT_BASE_S
|
||||
self._log.info("connected")
|
||||
|
||||
ping_task = asyncio.create_task(self._heartbeat(ws))
|
||||
self._tasks.append(ping_task)
|
||||
|
||||
try:
|
||||
async for raw_msg in ws:
|
||||
if self._stop_event.is_set():
|
||||
break
|
||||
self._handle_message(raw_msg)
|
||||
finally:
|
||||
ping_task.cancel()
|
||||
try:
|
||||
await ping_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
self._tasks = [t for t in self._tasks if t is not ping_task]
|
||||
|
||||
async def _heartbeat(self, ws: websockets.WebSocketClientProtocol) -> None:
|
||||
"""Send a ping frame every ``_PING_INTERVAL_S`` seconds."""
|
||||
try:
|
||||
while True:
|
||||
await asyncio.sleep(_PING_INTERVAL_S)
|
||||
pong = await ws.ping()
|
||||
await asyncio.wait_for(pong, timeout=10)
|
||||
self._log.debug("heartbeat_ok")
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
except Exception as exc:
|
||||
self._log.warning("heartbeat_failed", error=str(exc))
|
||||
|
||||
def _handle_message(self, raw: str | bytes) -> None:
|
||||
"""Parse a combined-stream message and dispatch callbacks."""
|
||||
try:
|
||||
msg = json.loads(raw)
|
||||
except (json.JSONDecodeError, TypeError) as exc:
|
||||
self._log.warning("json_parse_error", error=str(exc), raw=raw[:200])
|
||||
return
|
||||
|
||||
stream: str | None = msg.get("stream")
|
||||
data: dict | None = msg.get("data")
|
||||
if stream is None or data is None:
|
||||
self._log.debug("ignored_message", keys=list(msg.keys()))
|
||||
return
|
||||
|
||||
symbol = _STREAM_TO_SYMBOL.get(stream)
|
||||
if symbol is None:
|
||||
self._log.debug("unknown_stream", stream=stream)
|
||||
return
|
||||
|
||||
try:
|
||||
price = float(data["p"])
|
||||
volume = float(data["q"])
|
||||
# Binance trade timestamp is in milliseconds
|
||||
timestamp = float(data["T"]) / 1000.0
|
||||
except (KeyError, ValueError, TypeError) as exc:
|
||||
self._log.warning(
|
||||
"trade_parse_error",
|
||||
stream=stream,
|
||||
error=str(exc),
|
||||
data=data,
|
||||
)
|
||||
return
|
||||
|
||||
self._latest_prices[symbol] = price
|
||||
self._last_recv_ts[symbol] = time.time()
|
||||
|
||||
for cb in self._callbacks:
|
||||
try:
|
||||
cb(symbol, price, timestamp, volume)
|
||||
except Exception:
|
||||
self._log.exception("callback_error", symbol=symbol)
|
||||
|
||||
def _set_disconnected(self) -> None:
|
||||
self._connected = False
|
||||
self._ws = None
|
||||
|
||||
async def _cleanup(self) -> None:
|
||||
"""Cancel lingering tasks and close the socket."""
|
||||
for task in self._tasks:
|
||||
task.cancel()
|
||||
self._tasks.clear()
|
||||
|
||||
if self._ws is not None:
|
||||
try:
|
||||
await self._ws.close()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
self._set_disconnected()
|
||||
303
src/feeds/polymarket_ws.py
Normal file
303
src/feeds/polymarket_ws.py
Normal file
@@ -0,0 +1,303 @@
|
||||
"""Polymarket CLOB WebSocket feed for real-time orderbook data.
|
||||
|
||||
Connects to the Polymarket WebSocket subscriptions endpoint, subscribes
|
||||
to token-level orderbook channels, and maintains the latest
|
||||
:class:`OrderBookSnapshot` per token with auto-reconnect.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import time
|
||||
from typing import Callable, Optional
|
||||
|
||||
import structlog
|
||||
import websockets
|
||||
from websockets.exceptions import ConnectionClosed, InvalidStatusCode
|
||||
|
||||
from src.data.models import OrderBookLevel, OrderBookSnapshot
|
||||
|
||||
logger = structlog.get_logger(__name__)
|
||||
|
||||
OrderBookCallback = Callable[[str, OrderBookSnapshot], None]
|
||||
|
||||
_WS_URL = "wss://ws-subscriptions-clob.polymarket.com/ws/"
|
||||
_RECONNECT_BASE_S = 1.0
|
||||
_RECONNECT_MAX_S = 60.0
|
||||
_PING_INTERVAL_S = 30
|
||||
|
||||
|
||||
class PolymarketFeed:
|
||||
"""Async Polymarket CLOB WebSocket feed with auto-reconnect.
|
||||
|
||||
Usage::
|
||||
|
||||
feed = PolymarketFeed()
|
||||
feed.on_orderbook_update(my_callback)
|
||||
feed.subscribe_market("0xabc...")
|
||||
await feed.start()
|
||||
"""
|
||||
|
||||
def __init__(self, url: str = _WS_URL) -> None:
|
||||
self._url = url
|
||||
self._callbacks: list[OrderBookCallback] = []
|
||||
self._subscriptions: set[str] = set()
|
||||
self._orderbooks: dict[str, OrderBookSnapshot] = {}
|
||||
self._connected = False
|
||||
self._ws: Optional[websockets.WebSocketClientProtocol] = None
|
||||
self._stop_event: asyncio.Event = asyncio.Event()
|
||||
self._tasks: list[asyncio.Task[None]] = []
|
||||
self._reconnect_delay = _RECONNECT_BASE_S
|
||||
self._log = logger.bind(component="PolymarketFeed")
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Public API
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
@property
|
||||
def is_connected(self) -> bool:
|
||||
"""Return ``True`` when the WebSocket connection is open."""
|
||||
return self._connected
|
||||
|
||||
def subscribe_market(self, token_id: str) -> None:
|
||||
"""Subscribe to orderbook updates for *token_id*.
|
||||
|
||||
If the WebSocket is already connected the subscription message is
|
||||
sent immediately; otherwise it will be sent on the next connect.
|
||||
"""
|
||||
self._subscriptions.add(token_id)
|
||||
self._log.info("market_subscribed", token_id=token_id)
|
||||
|
||||
if self._ws is not None and self._connected:
|
||||
asyncio.ensure_future(self._send_subscribe(token_id))
|
||||
|
||||
def unsubscribe_market(self, token_id: str) -> None:
|
||||
"""Unsubscribe from orderbook updates for *token_id*."""
|
||||
self._subscriptions.discard(token_id)
|
||||
self._orderbooks.pop(token_id, None)
|
||||
self._log.info("market_unsubscribed", token_id=token_id)
|
||||
|
||||
if self._ws is not None and self._connected:
|
||||
asyncio.ensure_future(self._send_unsubscribe(token_id))
|
||||
|
||||
def on_orderbook_update(self, callback: OrderBookCallback) -> None:
|
||||
"""Register a callback invoked on every orderbook update.
|
||||
|
||||
Signature::
|
||||
|
||||
callback(token_id: str, snapshot: OrderBookSnapshot)
|
||||
"""
|
||||
if callback not in self._callbacks:
|
||||
self._callbacks.append(callback)
|
||||
self._log.info("callback_registered", total=len(self._callbacks))
|
||||
|
||||
def get_orderbook(self, token_id: str) -> Optional[OrderBookSnapshot]:
|
||||
"""Return the latest cached orderbook for *token_id*, or ``None``."""
|
||||
return self._orderbooks.get(token_id)
|
||||
|
||||
async def start(self) -> None:
|
||||
"""Connect and begin receiving orderbook data.
|
||||
|
||||
Blocks until :meth:`stop` is called. Automatically reconnects
|
||||
with exponential backoff on connection failures.
|
||||
"""
|
||||
self._stop_event.clear()
|
||||
self._log.info("feed_starting", url=self._url)
|
||||
|
||||
while not self._stop_event.is_set():
|
||||
try:
|
||||
await self._connect_and_listen()
|
||||
except (ConnectionClosed, ConnectionError, InvalidStatusCode, OSError) as exc:
|
||||
self._set_disconnected()
|
||||
self._log.warning(
|
||||
"connection_lost",
|
||||
error=str(exc),
|
||||
reconnect_in=self._reconnect_delay,
|
||||
)
|
||||
except asyncio.CancelledError:
|
||||
self._log.info("feed_cancelled")
|
||||
break
|
||||
except Exception:
|
||||
self._set_disconnected()
|
||||
self._log.exception(
|
||||
"unexpected_error",
|
||||
reconnect_in=self._reconnect_delay,
|
||||
)
|
||||
|
||||
if self._stop_event.is_set():
|
||||
break
|
||||
|
||||
# Exponential backoff, interruptible by stop().
|
||||
try:
|
||||
await asyncio.wait_for(
|
||||
self._stop_event.wait(),
|
||||
timeout=self._reconnect_delay,
|
||||
)
|
||||
break
|
||||
except asyncio.TimeoutError:
|
||||
pass
|
||||
|
||||
self._reconnect_delay = min(
|
||||
self._reconnect_delay * 2, _RECONNECT_MAX_S
|
||||
)
|
||||
|
||||
await self._cleanup()
|
||||
self._log.info("feed_stopped")
|
||||
|
||||
async def stop(self) -> None:
|
||||
"""Signal the feed to shut down gracefully."""
|
||||
self._log.info("feed_stop_requested")
|
||||
self._stop_event.set()
|
||||
|
||||
for task in self._tasks:
|
||||
task.cancel()
|
||||
|
||||
if self._ws is not None:
|
||||
await self._ws.close()
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Internals
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def _connect_and_listen(self) -> None:
|
||||
"""Establish WebSocket and consume messages."""
|
||||
async with websockets.connect(
|
||||
self._url,
|
||||
ping_interval=None,
|
||||
close_timeout=5,
|
||||
) as ws:
|
||||
self._ws = ws
|
||||
self._connected = True
|
||||
self._reconnect_delay = _RECONNECT_BASE_S
|
||||
self._log.info("connected")
|
||||
|
||||
# Re-subscribe to all tracked markets on (re)connect.
|
||||
for token_id in self._subscriptions:
|
||||
await self._send_subscribe(token_id)
|
||||
|
||||
ping_task = asyncio.create_task(self._heartbeat(ws))
|
||||
self._tasks.append(ping_task)
|
||||
|
||||
try:
|
||||
async for raw_msg in ws:
|
||||
if self._stop_event.is_set():
|
||||
break
|
||||
self._handle_message(raw_msg)
|
||||
finally:
|
||||
ping_task.cancel()
|
||||
try:
|
||||
await ping_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
self._tasks = [t for t in self._tasks if t is not ping_task]
|
||||
|
||||
async def _heartbeat(self, ws: websockets.WebSocketClientProtocol) -> None:
|
||||
"""Send periodic pings to keep the connection alive."""
|
||||
try:
|
||||
while True:
|
||||
await asyncio.sleep(_PING_INTERVAL_S)
|
||||
pong = await ws.ping()
|
||||
await asyncio.wait_for(pong, timeout=10)
|
||||
self._log.debug("heartbeat_ok")
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
except Exception as exc:
|
||||
self._log.warning("heartbeat_failed", error=str(exc))
|
||||
|
||||
async def _send_subscribe(self, token_id: str) -> None:
|
||||
"""Send a subscribe message for a token's orderbook channel."""
|
||||
if self._ws is None:
|
||||
return
|
||||
msg = {
|
||||
"type": "subscribe",
|
||||
"channel": "book",
|
||||
"assets_ids": [token_id],
|
||||
}
|
||||
await self._ws.send(json.dumps(msg))
|
||||
self._log.debug("subscribe_sent", token_id=token_id)
|
||||
|
||||
async def _send_unsubscribe(self, token_id: str) -> None:
|
||||
"""Send an unsubscribe message for a token's orderbook channel."""
|
||||
if self._ws is None:
|
||||
return
|
||||
msg = {
|
||||
"type": "unsubscribe",
|
||||
"channel": "book",
|
||||
"assets_ids": [token_id],
|
||||
}
|
||||
await self._ws.send(json.dumps(msg))
|
||||
self._log.debug("unsubscribe_sent", token_id=token_id)
|
||||
|
||||
def _handle_message(self, raw: str | bytes) -> None:
|
||||
"""Parse an incoming message and update the local orderbook cache."""
|
||||
try:
|
||||
msg = json.loads(raw)
|
||||
except (json.JSONDecodeError, TypeError) as exc:
|
||||
self._log.warning("json_parse_error", error=str(exc))
|
||||
return
|
||||
|
||||
msg_type = msg.get("type") or msg.get("event_type")
|
||||
|
||||
# Handle book snapshot or delta messages
|
||||
if msg_type in ("book", "book_snapshot", "book_delta"):
|
||||
self._process_book_message(msg)
|
||||
elif msg_type == "error":
|
||||
self._log.error("server_error", message=msg.get("message"))
|
||||
else:
|
||||
self._log.debug("ignored_message", msg_type=msg_type)
|
||||
|
||||
def _process_book_message(self, msg: dict) -> None:
|
||||
"""Extract bids/asks from a book message and update state."""
|
||||
# Polymarket sends asset_id at the top level of book messages.
|
||||
token_id: str | None = msg.get("asset_id") or msg.get("market")
|
||||
if token_id is None:
|
||||
self._log.debug("book_message_missing_token", keys=list(msg.keys()))
|
||||
return
|
||||
|
||||
bids = [
|
||||
OrderBookLevel(price=float(b["price"]), size=float(b["size"]))
|
||||
for b in msg.get("bids", [])
|
||||
]
|
||||
asks = [
|
||||
OrderBookLevel(price=float(a["price"]), size=float(a["size"]))
|
||||
for a in msg.get("asks", [])
|
||||
]
|
||||
|
||||
# Sort: bids descending, asks ascending
|
||||
bids.sort(key=lambda lvl: lvl.price, reverse=True)
|
||||
asks.sort(key=lambda lvl: lvl.price)
|
||||
|
||||
snapshot = OrderBookSnapshot(
|
||||
token_id=token_id,
|
||||
bids=bids,
|
||||
asks=asks,
|
||||
timestamp=time.time(),
|
||||
)
|
||||
|
||||
self._orderbooks[token_id] = snapshot
|
||||
|
||||
# Dispatch to registered callbacks
|
||||
for cb in self._callbacks:
|
||||
try:
|
||||
cb(token_id, snapshot)
|
||||
except Exception:
|
||||
self._log.exception("callback_error", token_id=token_id)
|
||||
|
||||
def _set_disconnected(self) -> None:
|
||||
self._connected = False
|
||||
self._ws = None
|
||||
|
||||
async def _cleanup(self) -> None:
|
||||
"""Cancel lingering tasks and close the socket."""
|
||||
for task in self._tasks:
|
||||
task.cancel()
|
||||
self._tasks.clear()
|
||||
|
||||
if self._ws is not None:
|
||||
try:
|
||||
await self._ws.close()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
self._set_disconnected()
|
||||
762
src/main.py
Normal file
762
src/main.py
Normal file
@@ -0,0 +1,762 @@
|
||||
"""Main entrypoint for the Polymarket Temporal Arbitrage Bot.
|
||||
|
||||
Ties together all components: market discovery, window tracking, Binance and
|
||||
Polymarket price feeds, strategy evaluation, order execution, position tracking,
|
||||
risk management, and notifications in a single async event loop.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import signal
|
||||
import sys
|
||||
import time
|
||||
from datetime import datetime, timezone
|
||||
from typing import Optional
|
||||
|
||||
import structlog
|
||||
|
||||
from src.config import Config, load_config
|
||||
from src.data import TradeDB
|
||||
from src.data.models import (
|
||||
ActiveMarket,
|
||||
Asset,
|
||||
Direction,
|
||||
OrderBookSnapshot,
|
||||
Signal,
|
||||
Timeframe,
|
||||
Trade,
|
||||
TradeStatus,
|
||||
WindowState,
|
||||
)
|
||||
from src.execution.clob_client import ClobClientWrapper
|
||||
from src.execution.order_manager import OrderManager
|
||||
from src.execution.position_tracker import PositionTracker
|
||||
from src.feeds import BinanceFeed, PolymarketFeed
|
||||
from src.market import MarketDiscovery, WindowTracker
|
||||
from src.risk.risk_manager import RiskManager
|
||||
from src.strategy import SignalAggregator
|
||||
from src.utils import get_logger, setup_logging
|
||||
from src.utils.telegram import TelegramNotifier
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Constants
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_STATUS_INTERVAL_S: float = 30.0
|
||||
_DISCOVERY_INTERVAL_S: float = 30.0
|
||||
_ORDER_MONITOR_INTERVAL_S: float = 2.0
|
||||
_BALANCE_SNAPSHOT_INTERVAL_S: float = 60.0
|
||||
_DAILY_SUMMARY_HOUR_UTC: int = 0 # midnight UTC
|
||||
|
||||
_BANNER = r"""
|
||||
____ _ _ _ _ _
|
||||
| _ \ ___ | |_ _ _ __ ___ __ _ _ __| | _______| |_ / \ _ __| |__
|
||||
| |_) / _ \| | | | | '_ ` _ \ / _` | '__| |/ / _ \ __/ / _ \ | '__| '_ \
|
||||
| __/ (_) | | |_| | | | | | | (_| | | | < __/ |_ / ___ \| | | |_) |
|
||||
|_| \___/|_|\__, |_| |_| |_|\__,_|_| |_|\_\___|\__| /_/ \_\_| |_.__/
|
||||
|___/
|
||||
"""
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _print_startup_banner(cfg: Config) -> None:
|
||||
"""Print a startup banner summarising the active configuration."""
|
||||
print(_BANNER)
|
||||
strategies = []
|
||||
if cfg.temporal_arb.enabled:
|
||||
strategies.append("temporal_arb")
|
||||
if cfg.sum_to_one.enabled:
|
||||
strategies.append("sum_to_one")
|
||||
if cfg.spread_capture.enabled:
|
||||
strategies.append("spread_capture")
|
||||
|
||||
print(f" Mode: {cfg.general.mode}")
|
||||
print(f" Assets: {', '.join(cfg.general.assets)}")
|
||||
print(f" Timeframes: {', '.join(cfg.general.timeframes)}")
|
||||
print(f" Strategies: {', '.join(strategies) or 'none'}")
|
||||
print(f" Log level: {cfg.general.log_level}")
|
||||
print(f" Binance WS: {cfg.binance.ws_url}")
|
||||
print(f" Polymarket: {cfg.polymarket.ws_url}")
|
||||
print()
|
||||
|
||||
|
||||
def _link_markets_to_tracker(
|
||||
markets: list[ActiveMarket],
|
||||
tracker: WindowTracker,
|
||||
log: structlog.stdlib.BoundLogger,
|
||||
) -> None:
|
||||
"""Link each discovered market to the corresponding window in the tracker."""
|
||||
for mkt in markets:
|
||||
tracker.link_market(
|
||||
asset=mkt.asset.value,
|
||||
timeframe=mkt.timeframe.value,
|
||||
market=mkt,
|
||||
)
|
||||
log.info(
|
||||
"market_linked_to_window",
|
||||
asset=mkt.asset.value,
|
||||
timeframe=mkt.timeframe.value,
|
||||
condition_id=mkt.condition_id,
|
||||
)
|
||||
|
||||
|
||||
def _subscribe_market_tokens(
|
||||
markets: list[ActiveMarket],
|
||||
poly_feed: PolymarketFeed,
|
||||
log: structlog.stdlib.BoundLogger,
|
||||
) -> None:
|
||||
"""Subscribe the Polymarket feed to token IDs from discovered markets."""
|
||||
for mkt in markets:
|
||||
poly_feed.subscribe_market(mkt.up_token_id)
|
||||
poly_feed.subscribe_market(mkt.down_token_id)
|
||||
log.debug(
|
||||
"poly_tokens_subscribed",
|
||||
condition_id=mkt.condition_id,
|
||||
up_token=mkt.up_token_id[:12],
|
||||
down_token=mkt.down_token_id[:12],
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Core application
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class ArbBot:
|
||||
"""Full-featured arbitrage bot with strategy execution and risk management."""
|
||||
|
||||
def __init__(self, config: Config) -> None:
|
||||
self._cfg = config
|
||||
self._log = get_logger("arb_bot")
|
||||
|
||||
# Components
|
||||
self._db: Optional[TradeDB] = None
|
||||
self._discovery: Optional[MarketDiscovery] = None
|
||||
self._tracker: Optional[WindowTracker] = None
|
||||
self._binance_feed: Optional[BinanceFeed] = None
|
||||
self._poly_feed: Optional[PolymarketFeed] = None
|
||||
self._clob: Optional[ClobClientWrapper] = None
|
||||
self._order_manager: Optional[OrderManager] = None
|
||||
self._position_tracker: Optional[PositionTracker] = None
|
||||
self._risk_manager: Optional[RiskManager] = None
|
||||
self._signal_aggregator: Optional[SignalAggregator] = None
|
||||
self._telegram: Optional[TelegramNotifier] = None
|
||||
|
||||
# Discovered markets cache
|
||||
self._active_markets: list[ActiveMarket] = []
|
||||
|
||||
# Orderbook state
|
||||
self._orderbooks: dict[str, OrderBookSnapshot] = {}
|
||||
|
||||
# Balance tracking
|
||||
self._balance: float = config.general.starting_balance
|
||||
|
||||
# Shutdown coordination
|
||||
self._shutdown_event = asyncio.Event()
|
||||
|
||||
# Daily summary tracking
|
||||
self._last_daily_summary_date: str = ""
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Callbacks
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _on_binance_price(
|
||||
self, symbol: str, price: float, timestamp: float, volume: float
|
||||
) -> None:
|
||||
"""Callback invoked for each Binance trade tick."""
|
||||
if self._tracker is not None:
|
||||
self._tracker.update_price(symbol, price, timestamp)
|
||||
|
||||
# Strategy evaluation on each tick
|
||||
if self._signal_aggregator and self._tracker:
|
||||
window = self._tracker.get_window(symbol, "5M")
|
||||
if window:
|
||||
asyncio.get_event_loop().call_soon(
|
||||
lambda s=symbol, p=price, w=window: asyncio.ensure_future(
|
||||
self._evaluate_strategies(s, p, w)
|
||||
)
|
||||
)
|
||||
# Also evaluate 15M windows
|
||||
window_15m = self._tracker.get_window(symbol, "15M")
|
||||
if window_15m:
|
||||
asyncio.get_event_loop().call_soon(
|
||||
lambda s=symbol, p=price, w=window_15m: asyncio.ensure_future(
|
||||
self._evaluate_strategies(s, p, w)
|
||||
)
|
||||
)
|
||||
|
||||
async def _evaluate_strategies(
|
||||
self, symbol: str, cex_price: float, window: WindowState
|
||||
) -> None:
|
||||
"""Evaluate strategies for a given symbol and window."""
|
||||
if not self._signal_aggregator:
|
||||
return
|
||||
try:
|
||||
await self._signal_aggregator.on_price_tick(
|
||||
symbol=symbol,
|
||||
cex_price=cex_price,
|
||||
window=window,
|
||||
orderbooks=self._orderbooks,
|
||||
)
|
||||
except Exception:
|
||||
self._log.exception("strategy_eval_error", symbol=symbol)
|
||||
|
||||
def _on_orderbook_update(
|
||||
self, token_id: str, snapshot: OrderBookSnapshot
|
||||
) -> None:
|
||||
"""Callback invoked for each Polymarket orderbook update."""
|
||||
self._orderbooks[token_id] = snapshot
|
||||
|
||||
# Update position mark-to-market
|
||||
if self._position_tracker:
|
||||
self._position_tracker.update_from_orderbooks(self._orderbooks)
|
||||
|
||||
def _on_window_change(self, window: WindowState) -> None:
|
||||
"""Callback fired when a price window transitions."""
|
||||
self._log.info(
|
||||
"window_changed",
|
||||
asset=window.asset.value,
|
||||
timeframe=window.timeframe.value,
|
||||
start_price=window.start_price,
|
||||
window_start=window.window_start_time,
|
||||
window_end=window.window_end_time,
|
||||
)
|
||||
# Snapshot the previous window to DB
|
||||
if self._db is not None and window.start_price is not None:
|
||||
try:
|
||||
self._db.log_window(window)
|
||||
except Exception:
|
||||
self._log.exception("window_snapshot_failed")
|
||||
|
||||
# Resolve positions for this window (window ended = resolution)
|
||||
if self._position_tracker and window.current_price and window.start_price:
|
||||
self._resolve_window_positions(window)
|
||||
|
||||
def _resolve_window_positions(self, window: WindowState) -> None:
|
||||
"""Close positions whose window just resolved."""
|
||||
if not window.market or not self._position_tracker:
|
||||
return
|
||||
|
||||
actual_direction = (
|
||||
Direction.UP
|
||||
if window.current_price and window.start_price
|
||||
and window.current_price > window.start_price
|
||||
else Direction.DOWN
|
||||
)
|
||||
|
||||
market = window.market
|
||||
for token_id, is_up in [
|
||||
(market.up_token_id, True),
|
||||
(market.down_token_id, False),
|
||||
]:
|
||||
pos = self._position_tracker.get_position(token_id)
|
||||
if pos is None:
|
||||
continue
|
||||
|
||||
won = (is_up and actual_direction == Direction.UP) or (
|
||||
not is_up and actual_direction == Direction.DOWN
|
||||
)
|
||||
resolution_price = 1.0 if won else 0.0
|
||||
pnl = self._position_tracker.close_position(token_id, resolution_price)
|
||||
|
||||
if pnl is not None:
|
||||
self._balance += pnl
|
||||
if self._db:
|
||||
self._db.update_trade(
|
||||
pos.market_id,
|
||||
pnl=pnl,
|
||||
status=TradeStatus.FILLED.value,
|
||||
)
|
||||
self._db.log_balance(self._balance, self._position_tracker.total_pnl)
|
||||
|
||||
# Notify
|
||||
if self._telegram:
|
||||
result = "WIN" if pnl > 0 else "LOSS"
|
||||
asyncio.ensure_future(
|
||||
self._telegram.send(
|
||||
f"{'✅' if pnl > 0 else '❌'} <b>Position Resolved</b>\n"
|
||||
f"Asset: {pos.asset.value} | {pos.direction.value}\n"
|
||||
f"Result: {result} | PnL: ${pnl:+.2f}\n"
|
||||
f"Balance: ${self._balance:.2f}"
|
||||
)
|
||||
)
|
||||
|
||||
self._log.info(
|
||||
"position_resolved",
|
||||
asset=pos.asset.value,
|
||||
direction=pos.direction.value,
|
||||
won=won,
|
||||
pnl=round(pnl, 2),
|
||||
balance=round(self._balance, 2),
|
||||
)
|
||||
|
||||
# Update strategy balance
|
||||
if self._signal_aggregator:
|
||||
self._signal_aggregator.update_balance(self._balance)
|
||||
|
||||
def _on_new_markets(self, markets: list[ActiveMarket]) -> None:
|
||||
"""Callback invoked when MarketDiscovery finds new markets."""
|
||||
self._active_markets.extend(markets)
|
||||
if self._tracker is not None:
|
||||
_link_markets_to_tracker(markets, self._tracker, self._log)
|
||||
if self._poly_feed is not None:
|
||||
_subscribe_market_tokens(markets, self._poly_feed, self._log)
|
||||
|
||||
def _on_signal(self, signal: Signal) -> None:
|
||||
"""Handle a trading signal from the strategy aggregator."""
|
||||
if not self._risk_manager or not self._order_manager:
|
||||
return
|
||||
|
||||
# Risk check
|
||||
additional_usd = signal.price * signal.size
|
||||
if not self._risk_manager.can_open_position(additional_usd):
|
||||
self._log.warning(
|
||||
"signal_rejected_risk",
|
||||
asset=signal.asset.value,
|
||||
direction=signal.direction.value,
|
||||
reason=self._risk_manager.halt_reason or "risk_limit",
|
||||
)
|
||||
if self._risk_manager.is_halted and self._telegram:
|
||||
asyncio.ensure_future(
|
||||
self._telegram.notify_halt(self._risk_manager.halt_reason)
|
||||
)
|
||||
return
|
||||
|
||||
# Submit order
|
||||
asyncio.ensure_future(self._execute_signal(signal))
|
||||
|
||||
async def _execute_signal(self, signal: Signal) -> None:
|
||||
"""Execute a signal through the order manager."""
|
||||
if not self._order_manager or not self._position_tracker:
|
||||
return
|
||||
|
||||
trade = await self._order_manager.submit_signal(signal)
|
||||
if trade is None:
|
||||
return
|
||||
|
||||
if trade.status == TradeStatus.FAILED:
|
||||
self._log.error("trade_failed", asset=signal.asset.value)
|
||||
return
|
||||
|
||||
# Log trade to DB
|
||||
if self._db:
|
||||
self._db.log_trade(trade)
|
||||
|
||||
# Notify via telegram
|
||||
if self._telegram:
|
||||
await self._telegram.notify_trade(
|
||||
asset=signal.asset.value,
|
||||
direction=signal.direction.value,
|
||||
timeframe=signal.timeframe.value,
|
||||
price=signal.price,
|
||||
size=signal.size,
|
||||
edge=signal.edge,
|
||||
)
|
||||
|
||||
self._log.info(
|
||||
"trade_submitted",
|
||||
trade_id=trade.id,
|
||||
asset=signal.asset.value,
|
||||
direction=signal.direction.value,
|
||||
price=signal.price,
|
||||
size=signal.size,
|
||||
edge=round(signal.edge, 4),
|
||||
)
|
||||
|
||||
def _on_fill(self, trade: Trade) -> None:
|
||||
"""Handle a filled order."""
|
||||
if not self._position_tracker:
|
||||
return
|
||||
|
||||
# Open position
|
||||
self._position_tracker.open_position(trade)
|
||||
|
||||
# Update DB
|
||||
if self._db:
|
||||
self._db.update_trade(
|
||||
trade.id,
|
||||
fill_price=trade.fill_price,
|
||||
fill_size=trade.fill_size,
|
||||
fee=trade.fee,
|
||||
status=TradeStatus.FILLED.value,
|
||||
)
|
||||
|
||||
# Notify
|
||||
if self._telegram:
|
||||
asyncio.ensure_future(
|
||||
self._telegram.notify_fill(
|
||||
asset=trade.signal.asset.value,
|
||||
direction=trade.signal.direction.value,
|
||||
fill_price=trade.fill_price,
|
||||
fill_size=trade.fill_size,
|
||||
trade_id=trade.id,
|
||||
)
|
||||
)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Background loops
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def _status_loop(self) -> None:
|
||||
"""Log a status summary every ``_STATUS_INTERVAL_S`` seconds."""
|
||||
while not self._shutdown_event.is_set():
|
||||
try:
|
||||
await asyncio.wait_for(
|
||||
self._shutdown_event.wait(),
|
||||
timeout=_STATUS_INTERVAL_S,
|
||||
)
|
||||
break
|
||||
except asyncio.TimeoutError:
|
||||
pass
|
||||
|
||||
windows = (
|
||||
self._tracker.get_all_active_windows()
|
||||
if self._tracker
|
||||
else []
|
||||
)
|
||||
|
||||
latest_prices: dict[str, Optional[float]] = {}
|
||||
if self._binance_feed is not None:
|
||||
for sym in self._cfg.general.assets:
|
||||
latest_prices[sym] = self._binance_feed.get_latest_price(sym)
|
||||
|
||||
position_summary = (
|
||||
self._position_tracker.get_summary()
|
||||
if self._position_tracker
|
||||
else {}
|
||||
)
|
||||
risk_summary = (
|
||||
self._risk_manager.get_risk_summary()
|
||||
if self._risk_manager
|
||||
else {}
|
||||
)
|
||||
strategy_stats = (
|
||||
self._signal_aggregator.get_stats()
|
||||
if self._signal_aggregator
|
||||
else {}
|
||||
)
|
||||
|
||||
self._log.info(
|
||||
"status",
|
||||
binance_connected=(
|
||||
self._binance_feed.is_connected
|
||||
if self._binance_feed
|
||||
else False
|
||||
),
|
||||
polymarket_connected=(
|
||||
self._poly_feed.is_connected
|
||||
if self._poly_feed
|
||||
else False
|
||||
),
|
||||
active_windows=len(windows),
|
||||
active_markets=len(self._active_markets),
|
||||
latest_prices=latest_prices,
|
||||
balance=round(self._balance, 2),
|
||||
positions=position_summary,
|
||||
risk=risk_summary,
|
||||
strategies=strategy_stats,
|
||||
today_trades=(
|
||||
self._db.get_today_trade_count() if self._db else 0
|
||||
),
|
||||
today_pnl=(
|
||||
self._db.get_today_pnl() if self._db else 0.0
|
||||
),
|
||||
)
|
||||
|
||||
async def _balance_snapshot_loop(self) -> None:
|
||||
"""Periodically snapshot balance to DB."""
|
||||
while not self._shutdown_event.is_set():
|
||||
try:
|
||||
await asyncio.wait_for(
|
||||
self._shutdown_event.wait(),
|
||||
timeout=_BALANCE_SNAPSHOT_INTERVAL_S,
|
||||
)
|
||||
break
|
||||
except asyncio.TimeoutError:
|
||||
pass
|
||||
|
||||
if self._db and self._position_tracker:
|
||||
self._db.log_balance(
|
||||
self._balance,
|
||||
self._position_tracker.total_pnl,
|
||||
)
|
||||
|
||||
async def _daily_summary_loop(self) -> None:
|
||||
"""Send daily summary at midnight UTC."""
|
||||
while not self._shutdown_event.is_set():
|
||||
try:
|
||||
await asyncio.wait_for(
|
||||
self._shutdown_event.wait(),
|
||||
timeout=60.0,
|
||||
)
|
||||
break
|
||||
except asyncio.TimeoutError:
|
||||
pass
|
||||
|
||||
now_utc = datetime.now(timezone.utc)
|
||||
today = now_utc.strftime("%Y-%m-%d")
|
||||
|
||||
if (
|
||||
now_utc.hour == _DAILY_SUMMARY_HOUR_UTC
|
||||
and today != self._last_daily_summary_date
|
||||
):
|
||||
self._last_daily_summary_date = today
|
||||
# Summarize yesterday
|
||||
yesterday = (
|
||||
datetime.now(timezone.utc).replace(hour=0, minute=0, second=0)
|
||||
)
|
||||
yesterday_str = yesterday.strftime("%Y-%m-%d")
|
||||
|
||||
if self._db:
|
||||
self._db.update_daily_summary(yesterday_str)
|
||||
summary = self._db.get_daily_summary(yesterday_str)
|
||||
|
||||
if self._telegram and summary["total_trades"] > 0:
|
||||
await self._telegram.notify_daily_summary(
|
||||
date=yesterday_str,
|
||||
total_trades=summary["total_trades"],
|
||||
wins=summary["wins"],
|
||||
losses=summary["losses"],
|
||||
pnl=summary["total_pnl"],
|
||||
fees=summary["total_fees"],
|
||||
volume=summary["total_volume"],
|
||||
)
|
||||
|
||||
async def _order_expiry_loop(self) -> None:
|
||||
"""Cancel orders close to resolution."""
|
||||
while not self._shutdown_event.is_set():
|
||||
try:
|
||||
await asyncio.wait_for(
|
||||
self._shutdown_event.wait(),
|
||||
timeout=5.0,
|
||||
)
|
||||
break
|
||||
except asyncio.TimeoutError:
|
||||
pass
|
||||
|
||||
if self._order_manager:
|
||||
try:
|
||||
cancelled = await self._order_manager.cancel_expiring_orders(
|
||||
seconds_before_resolution=self._cfg.temporal_arb.exit_before_resolution_sec,
|
||||
)
|
||||
if cancelled > 0:
|
||||
self._log.info("expiring_orders_cancelled", count=cancelled)
|
||||
except Exception:
|
||||
self._log.exception("expiry_loop_error")
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Lifecycle
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def start(self) -> None:
|
||||
"""Initialise all components, then run feeds concurrently."""
|
||||
self._log.info("bot_initialising")
|
||||
|
||||
# 1. Database
|
||||
db_path = "trades.db"
|
||||
if self._cfg.general.mode == "paper":
|
||||
db_path = "paper_trades.db"
|
||||
self._db = TradeDB(db_path=db_path)
|
||||
|
||||
# Record initial balance
|
||||
self._db.log_balance(self._balance, 0.0, event="start")
|
||||
|
||||
# 2. Market discovery
|
||||
self._discovery = MarketDiscovery(
|
||||
on_new_markets=self._on_new_markets,
|
||||
)
|
||||
self._log.info("running_initial_discovery")
|
||||
initial_markets = await self._discovery.discover()
|
||||
self._active_markets = initial_markets
|
||||
self._log.info(
|
||||
"initial_discovery_complete",
|
||||
count=len(initial_markets),
|
||||
)
|
||||
|
||||
# 3. Window tracker
|
||||
assets = [Asset(a) for a in self._cfg.general.assets]
|
||||
timeframes = [Timeframe(t) for t in self._cfg.general.timeframes]
|
||||
self._tracker = WindowTracker(assets=assets, timeframes=timeframes)
|
||||
self._tracker.on_window_change(self._on_window_change)
|
||||
|
||||
# Link initial markets to windows
|
||||
_link_markets_to_tracker(initial_markets, self._tracker, self._log)
|
||||
|
||||
# 4. CLOB client (for live mode)
|
||||
self._clob = ClobClientWrapper(self._cfg.polymarket)
|
||||
if self._cfg.general.mode == "live":
|
||||
await self._clob.initialize()
|
||||
self._log.info("clob_client_ready", is_ready=self._clob.is_ready)
|
||||
|
||||
# 5. Execution components
|
||||
self._position_tracker = PositionTracker(self._cfg)
|
||||
self._order_manager = OrderManager(self._cfg, self._clob)
|
||||
self._order_manager.on_fill(self._on_fill)
|
||||
self._risk_manager = RiskManager(
|
||||
self._cfg.risk, self._position_tracker, self._db
|
||||
)
|
||||
|
||||
# 6. Strategy aggregator
|
||||
self._signal_aggregator = SignalAggregator(
|
||||
self._cfg, balance=self._balance
|
||||
)
|
||||
self._signal_aggregator.on_signal(self._on_signal)
|
||||
|
||||
# 7. Telegram
|
||||
self._telegram = TelegramNotifier(self._cfg.notifications)
|
||||
await self._telegram.send(
|
||||
f"🚀 <b>Polymarket Arb Bot Started</b>\n"
|
||||
f"Mode: {self._cfg.general.mode}\n"
|
||||
f"Balance: ${self._balance:.2f}\n"
|
||||
f"Assets: {', '.join(self._cfg.general.assets)}"
|
||||
)
|
||||
|
||||
# 8. Binance feed
|
||||
ws_url = self._cfg.binance.ws_url
|
||||
streams = "/".join(
|
||||
f"{s}@trade" for s in self._cfg.binance.symbols
|
||||
)
|
||||
full_url = f"{ws_url}?streams={streams}"
|
||||
self._binance_feed = BinanceFeed(url=full_url)
|
||||
self._binance_feed.subscribe(self._on_binance_price)
|
||||
|
||||
# 9. Polymarket feed
|
||||
self._poly_feed = PolymarketFeed(url=self._cfg.polymarket.ws_url)
|
||||
self._poly_feed.on_orderbook_update(self._on_orderbook_update)
|
||||
|
||||
# Subscribe to discovered market tokens
|
||||
_subscribe_market_tokens(
|
||||
initial_markets, self._poly_feed, self._log
|
||||
)
|
||||
|
||||
self._log.info("bot_starting_feeds")
|
||||
|
||||
# 10. Run all tasks concurrently
|
||||
tasks = [
|
||||
self._binance_feed.start(),
|
||||
self._poly_feed.start(),
|
||||
self._discovery.discover_loop(
|
||||
interval_sec=_DISCOVERY_INTERVAL_S,
|
||||
),
|
||||
self._status_loop(),
|
||||
self._balance_snapshot_loop(),
|
||||
self._daily_summary_loop(),
|
||||
self._order_expiry_loop(),
|
||||
]
|
||||
|
||||
# Add order monitor for live mode
|
||||
if self._cfg.general.mode == "live" and self._clob.is_ready:
|
||||
tasks.append(
|
||||
self._order_manager.monitor_loop(
|
||||
interval=_ORDER_MONITOR_INTERVAL_S
|
||||
)
|
||||
)
|
||||
|
||||
try:
|
||||
await asyncio.gather(*tasks)
|
||||
except asyncio.CancelledError:
|
||||
self._log.info("gather_cancelled")
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
"""Gracefully stop all components."""
|
||||
self._log.info("bot_shutting_down")
|
||||
self._shutdown_event.set()
|
||||
|
||||
# Cancel all pending orders
|
||||
if self._order_manager:
|
||||
cancelled = await self._order_manager.cancel_all_pending()
|
||||
self._log.info("shutdown_orders_cancelled", count=cancelled)
|
||||
self._order_manager.stop()
|
||||
|
||||
if self._binance_feed is not None:
|
||||
await self._binance_feed.stop()
|
||||
|
||||
if self._poly_feed is not None:
|
||||
await self._poly_feed.stop()
|
||||
|
||||
# Final balance snapshot
|
||||
if self._db and self._position_tracker:
|
||||
self._db.log_balance(
|
||||
self._balance,
|
||||
self._position_tracker.total_pnl,
|
||||
event="shutdown",
|
||||
)
|
||||
# Update daily summary
|
||||
today = datetime.now(timezone.utc).strftime("%Y-%m-%d")
|
||||
self._db.update_daily_summary(today)
|
||||
|
||||
# Notify shutdown
|
||||
if self._telegram:
|
||||
position_summary = (
|
||||
self._position_tracker.get_summary()
|
||||
if self._position_tracker
|
||||
else {}
|
||||
)
|
||||
await self._telegram.send(
|
||||
f"🛑 <b>Bot Stopped</b>\n"
|
||||
f"Balance: ${self._balance:.2f}\n"
|
||||
f"Total PnL: ${position_summary.get('total_pnl', 0):.2f}\n"
|
||||
f"Trades: {position_summary.get('trades', 0)}"
|
||||
)
|
||||
await self._telegram.close()
|
||||
|
||||
self._log.info("bot_stopped")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Signal handling and entry point
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _install_signal_handlers(
|
||||
loop: asyncio.AbstractEventLoop, bot: ArbBot
|
||||
) -> None:
|
||||
"""Register SIGINT / SIGTERM handlers that trigger graceful shutdown."""
|
||||
for sig in (signal.SIGINT, signal.SIGTERM):
|
||||
try:
|
||||
loop.add_signal_handler(
|
||||
sig,
|
||||
lambda: asyncio.ensure_future(bot.shutdown()),
|
||||
)
|
||||
except NotImplementedError:
|
||||
pass
|
||||
|
||||
|
||||
async def async_main(config: Optional[Config] = None) -> None:
|
||||
"""Async entry point — load config, build the bot, and run."""
|
||||
if config is None:
|
||||
config = load_config()
|
||||
setup_logging(config.general.log_level)
|
||||
|
||||
_print_startup_banner(config)
|
||||
|
||||
bot = ArbBot(config)
|
||||
|
||||
loop = asyncio.get_running_loop()
|
||||
_install_signal_handlers(loop, bot)
|
||||
|
||||
try:
|
||||
await bot.start()
|
||||
except KeyboardInterrupt:
|
||||
pass
|
||||
finally:
|
||||
await bot.shutdown()
|
||||
|
||||
|
||||
def main() -> None:
|
||||
"""Synchronous entry point for ``python src/main.py``."""
|
||||
try:
|
||||
asyncio.run(async_main())
|
||||
except KeyboardInterrupt:
|
||||
print("\nShutdown complete.")
|
||||
sys.exit(0)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
6
src/market/__init__.py
Normal file
6
src/market/__init__.py
Normal file
@@ -0,0 +1,6 @@
|
||||
"""Market discovery and window tracking."""
|
||||
|
||||
from src.market.discovery import MarketDiscovery
|
||||
from src.market.window_tracker import WindowTracker
|
||||
|
||||
__all__ = ["MarketDiscovery", "WindowTracker"]
|
||||
275
src/market/discovery.py
Normal file
275
src/market/discovery.py
Normal file
@@ -0,0 +1,275 @@
|
||||
"""Active 5min/15min Up/Down crypto market discovery via Gamma API."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import re
|
||||
from typing import Callable, Optional
|
||||
|
||||
import aiohttp
|
||||
import structlog
|
||||
|
||||
from src.data.models import ActiveMarket, Asset, Timeframe
|
||||
|
||||
logger = structlog.get_logger(__name__)
|
||||
|
||||
GAMMA_API_URL = "https://gamma-api.polymarket.com/events"
|
||||
GAMMA_QUERY_PARAMS = {
|
||||
"tag": "crypto",
|
||||
"active": "true",
|
||||
"closed": "false",
|
||||
"limit": "100",
|
||||
}
|
||||
|
||||
SUPPORTED_ASSETS: dict[str, Asset] = {
|
||||
"BTC": Asset.BTC,
|
||||
"ETH": Asset.ETH,
|
||||
"SOL": Asset.SOL,
|
||||
}
|
||||
|
||||
TIMEFRAME_PATTERNS: dict[str, Timeframe] = {
|
||||
"5 Min": Timeframe.FIVE_MIN,
|
||||
"5 min": Timeframe.FIVE_MIN,
|
||||
"5Min": Timeframe.FIVE_MIN,
|
||||
"15 Min": Timeframe.FIFTEEN_MIN,
|
||||
"15 min": Timeframe.FIFTEEN_MIN,
|
||||
"15Min": Timeframe.FIFTEEN_MIN,
|
||||
}
|
||||
|
||||
|
||||
def _extract_asset(title: str) -> Optional[Asset]:
|
||||
"""Extract the crypto asset from an event/market title."""
|
||||
for keyword, asset in SUPPORTED_ASSETS.items():
|
||||
if keyword in title:
|
||||
return asset
|
||||
return None
|
||||
|
||||
|
||||
def _extract_timeframe(title: str) -> Optional[Timeframe]:
|
||||
"""Extract the timeframe from an event/market title."""
|
||||
for pattern, timeframe in TIMEFRAME_PATTERNS.items():
|
||||
if pattern in title:
|
||||
return timeframe
|
||||
return None
|
||||
|
||||
|
||||
def _extract_token_ids(market: dict) -> tuple[str, str] | None:
|
||||
"""Extract (up_token_id, down_token_id) from a market dict.
|
||||
|
||||
The Gamma API returns token IDs as a JSON-encoded list or a
|
||||
``clobTokenIds`` field. The first token is conventionally the
|
||||
"Up" outcome, the second is "Down". We also inspect ``outcomes``
|
||||
to verify ordering when available.
|
||||
"""
|
||||
tokens: list[str] = []
|
||||
outcomes: list[str] = []
|
||||
|
||||
# Token IDs
|
||||
raw_tokens = market.get("clobTokenIds")
|
||||
if isinstance(raw_tokens, list):
|
||||
tokens = [str(t) for t in raw_tokens]
|
||||
elif isinstance(raw_tokens, str):
|
||||
# Sometimes returned as JSON-encoded string "[\"0xabc\",\"0xdef\"]"
|
||||
try:
|
||||
import json
|
||||
tokens = [str(t) for t in json.loads(raw_tokens)]
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
return None
|
||||
|
||||
# Outcomes
|
||||
raw_outcomes = market.get("outcomes")
|
||||
if isinstance(raw_outcomes, list):
|
||||
outcomes = [str(o).upper().strip() for o in raw_outcomes]
|
||||
elif isinstance(raw_outcomes, str):
|
||||
try:
|
||||
import json
|
||||
outcomes = [str(o).upper().strip() for o in json.loads(raw_outcomes)]
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
pass
|
||||
|
||||
if len(tokens) < 2:
|
||||
return None
|
||||
|
||||
# Determine which token is Up and which is Down
|
||||
up_token: str = tokens[0]
|
||||
down_token: str = tokens[1]
|
||||
|
||||
if len(outcomes) >= 2:
|
||||
if outcomes[0] == "DOWN" and outcomes[1] == "UP":
|
||||
up_token, down_token = tokens[1], tokens[0]
|
||||
|
||||
return up_token, down_token
|
||||
|
||||
|
||||
class MarketDiscovery:
|
||||
"""Discovers active 5-min and 15-min Up/Down crypto markets on Polymarket.
|
||||
|
||||
Uses the Gamma API to fetch events tagged as crypto, then filters for
|
||||
short-duration binary Up/Down markets on BTC, ETH, or SOL.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
session: Optional[aiohttp.ClientSession] = None,
|
||||
on_new_markets: Optional[Callable[[list[ActiveMarket]], None]] = None,
|
||||
) -> None:
|
||||
self._external_session = session
|
||||
self._session: Optional[aiohttp.ClientSession] = session
|
||||
self._on_new_markets = on_new_markets
|
||||
self._seen_condition_ids: set[str] = set()
|
||||
self._log = logger.bind(component="MarketDiscovery")
|
||||
|
||||
async def _ensure_session(self) -> aiohttp.ClientSession:
|
||||
if self._session is None or self._session.closed:
|
||||
self._session = aiohttp.ClientSession()
|
||||
return self._session
|
||||
|
||||
async def _close_session_if_owned(self) -> None:
|
||||
"""Close the HTTP session only if we created it ourselves."""
|
||||
if self._external_session is None and self._session is not None and not self._session.closed:
|
||||
await self._session.close()
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Core discovery
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def discover(self) -> list[ActiveMarket]:
|
||||
"""Fetch events from the Gamma API and return matching ActiveMarket instances."""
|
||||
session = await self._ensure_session()
|
||||
self._log.info("gamma_api_fetch_start")
|
||||
|
||||
try:
|
||||
async with session.get(GAMMA_API_URL, params=GAMMA_QUERY_PARAMS, timeout=aiohttp.ClientTimeout(total=15)) as resp:
|
||||
if resp.status == 429:
|
||||
retry_after = float(resp.headers.get("Retry-After", "5"))
|
||||
self._log.warning("gamma_api_rate_limited", retry_after=retry_after)
|
||||
await asyncio.sleep(retry_after)
|
||||
return []
|
||||
|
||||
if resp.status != 200:
|
||||
self._log.error("gamma_api_http_error", status=resp.status, reason=resp.reason)
|
||||
return []
|
||||
|
||||
try:
|
||||
data = await resp.json()
|
||||
except (aiohttp.ContentTypeError, ValueError) as exc:
|
||||
self._log.error("gamma_api_json_parse_error", error=str(exc))
|
||||
return []
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
self._log.error("gamma_api_timeout")
|
||||
return []
|
||||
except aiohttp.ClientError as exc:
|
||||
self._log.error("gamma_api_client_error", error=str(exc))
|
||||
return []
|
||||
|
||||
events: list[dict] = data if isinstance(data, list) else []
|
||||
markets = self._filter_and_parse(events)
|
||||
self._log.info("gamma_api_fetch_done", total_events=len(events), matched_markets=len(markets))
|
||||
return markets
|
||||
|
||||
def _filter_and_parse(self, events: list[dict]) -> list[ActiveMarket]:
|
||||
"""Filter events and their sub-markets, returning ActiveMarket instances."""
|
||||
results: list[ActiveMarket] = []
|
||||
|
||||
for event in events:
|
||||
title: str = event.get("title", "")
|
||||
|
||||
# Must be an Up or Down style market
|
||||
if not re.search(r"[Uu]p\s+or\s+[Dd]own", title):
|
||||
continue
|
||||
|
||||
# Must match a supported timeframe
|
||||
timeframe = _extract_timeframe(title)
|
||||
if timeframe is None:
|
||||
continue
|
||||
|
||||
# Must be for a supported asset
|
||||
asset = _extract_asset(title)
|
||||
if asset is None:
|
||||
continue
|
||||
|
||||
# Process each sub-market within the event
|
||||
sub_markets: list[dict] = event.get("markets", [])
|
||||
if not sub_markets:
|
||||
# Some API shapes embed market data at the event level
|
||||
sub_markets = [event]
|
||||
|
||||
for mkt in sub_markets:
|
||||
if not mkt.get("active", False):
|
||||
continue
|
||||
if not mkt.get("enableOrderBook", False):
|
||||
continue
|
||||
|
||||
condition_id: str = mkt.get("conditionId", "") or mkt.get("condition_id", "")
|
||||
if not condition_id:
|
||||
continue
|
||||
|
||||
token_pair = _extract_token_ids(mkt)
|
||||
if token_pair is None:
|
||||
self._log.debug("skip_market_missing_tokens", condition_id=condition_id)
|
||||
continue
|
||||
|
||||
up_token, down_token = token_pair
|
||||
end_date: str = mkt.get("endDate", "") or mkt.get("end_date_iso", "") or ""
|
||||
|
||||
results.append(
|
||||
ActiveMarket(
|
||||
condition_id=condition_id,
|
||||
up_token_id=up_token,
|
||||
down_token_id=down_token,
|
||||
asset=asset,
|
||||
timeframe=timeframe,
|
||||
end_date=end_date,
|
||||
question=mkt.get("question", title),
|
||||
description=mkt.get("description", ""),
|
||||
)
|
||||
)
|
||||
|
||||
return results
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Continuous discovery loop
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def discover_loop(self, interval_sec: float = 30) -> None:
|
||||
"""Continuously discover markets and invoke the callback with new ones.
|
||||
|
||||
Runs indefinitely. Each iteration sleeps for *interval_sec* seconds
|
||||
after completing a fetch cycle. Previously seen ``condition_id`` values
|
||||
are cached so the callback only receives genuinely new markets.
|
||||
"""
|
||||
self._log.info("discover_loop_start", interval_sec=interval_sec)
|
||||
|
||||
try:
|
||||
while True:
|
||||
try:
|
||||
all_markets = await self.discover()
|
||||
new_markets = [
|
||||
m for m in all_markets
|
||||
if m.condition_id not in self._seen_condition_ids
|
||||
]
|
||||
|
||||
for m in new_markets:
|
||||
self._seen_condition_ids.add(m.condition_id)
|
||||
|
||||
if new_markets:
|
||||
self._log.info("new_markets_found", count=len(new_markets))
|
||||
if self._on_new_markets is not None:
|
||||
self._on_new_markets(new_markets)
|
||||
|
||||
except Exception:
|
||||
self._log.exception("discover_loop_iteration_error")
|
||||
|
||||
await asyncio.sleep(interval_sec)
|
||||
finally:
|
||||
await self._close_session_if_owned()
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Utilities
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def reset_cache(self) -> None:
|
||||
"""Clear the set of previously seen condition IDs."""
|
||||
self._seen_condition_ids.clear()
|
||||
self._log.info("cache_cleared")
|
||||
186
src/market/oracle.py
Normal file
186
src/market/oracle.py
Normal file
@@ -0,0 +1,186 @@
|
||||
"""Chainlink Oracle monitor — tracks oracle update latency via web3.
|
||||
|
||||
Monitors the delay between real CEX prices and Chainlink oracle updates
|
||||
on Polygon, which is the core edge in temporal arbitrage.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import time
|
||||
from typing import Optional
|
||||
|
||||
import structlog
|
||||
|
||||
log = structlog.get_logger()
|
||||
|
||||
# Chainlink AggregatorV3Interface ABI (latestRoundData only)
|
||||
AGGREGATOR_ABI = [
|
||||
{
|
||||
"inputs": [],
|
||||
"name": "latestRoundData",
|
||||
"outputs": [
|
||||
{"name": "roundId", "type": "uint80"},
|
||||
{"name": "answer", "type": "int256"},
|
||||
{"name": "startedAt", "type": "uint256"},
|
||||
{"name": "updatedAt", "type": "uint256"},
|
||||
{"name": "answeredInRound", "type": "uint80"},
|
||||
],
|
||||
"stateMutability": "view",
|
||||
"type": "function",
|
||||
},
|
||||
{
|
||||
"inputs": [],
|
||||
"name": "decimals",
|
||||
"outputs": [{"name": "", "type": "uint8"}],
|
||||
"stateMutability": "view",
|
||||
"type": "function",
|
||||
},
|
||||
]
|
||||
|
||||
|
||||
class OracleMonitor:
|
||||
"""Monitors Chainlink oracle update frequency and latency.
|
||||
|
||||
The oracle typically updates every ~10-30 seconds or on 0.5% deviation.
|
||||
Tracking this helps calibrate the temporal arbitrage opportunity window.
|
||||
"""
|
||||
|
||||
# Chainlink price feed addresses on Polygon
|
||||
FEEDS = {
|
||||
"BTC": "0xc907E116054Ad103354f2D350FD2514433D57F6f",
|
||||
"ETH": "0xF9680D99D6C9589e2a93a78A04A279e509205945",
|
||||
"SOL": "0x10C8264C0935b3B9870013e057f330Ff3e9C56dC",
|
||||
}
|
||||
|
||||
def __init__(self, rpc_url: Optional[str] = None) -> None:
|
||||
self._rpc_url = rpc_url or "https://polygon.drpc.org"
|
||||
self._w3 = None
|
||||
self._contracts: dict[str, object] = {}
|
||||
self._decimals: dict[str, int] = {}
|
||||
self._last_oracle_prices: dict[str, float] = {}
|
||||
self._last_oracle_timestamps: dict[str, float] = {}
|
||||
self._last_oracle_round_ids: dict[str, int] = {}
|
||||
self._update_intervals: dict[str, list[float]] = {
|
||||
"BTC": [], "ETH": [], "SOL": []
|
||||
}
|
||||
self._initialized = False
|
||||
|
||||
async def initialize(self) -> bool:
|
||||
"""Initialize web3 connection and contract instances."""
|
||||
try:
|
||||
from web3 import Web3
|
||||
self._w3 = Web3(Web3.HTTPProvider(self._rpc_url))
|
||||
|
||||
if not self._w3.is_connected():
|
||||
log.error("oracle_web3_not_connected", rpc_url=self._rpc_url)
|
||||
return False
|
||||
|
||||
for asset, address in self.FEEDS.items():
|
||||
checksum = self._w3.to_checksum_address(address)
|
||||
contract = self._w3.eth.contract(
|
||||
address=checksum, abi=AGGREGATOR_ABI
|
||||
)
|
||||
self._contracts[asset] = contract
|
||||
self._decimals[asset] = contract.functions.decimals().call()
|
||||
|
||||
self._initialized = True
|
||||
log.info(
|
||||
"oracle_initialized",
|
||||
rpc_url=self._rpc_url,
|
||||
assets=list(self._contracts.keys()),
|
||||
)
|
||||
return True
|
||||
|
||||
except ImportError:
|
||||
log.warning("oracle_web3_not_installed", msg="pip install web3")
|
||||
return False
|
||||
except Exception:
|
||||
log.exception("oracle_init_failed")
|
||||
return False
|
||||
|
||||
async def get_oracle_price(self, asset: str) -> Optional[float]:
|
||||
"""Fetch the latest oracle price for an asset from Chainlink."""
|
||||
if not self._initialized or asset not in self._contracts:
|
||||
return self._last_oracle_prices.get(asset)
|
||||
|
||||
try:
|
||||
contract = self._contracts[asset]
|
||||
result = contract.functions.latestRoundData().call()
|
||||
|
||||
round_id, answer, started_at, updated_at, answered_in_round = result
|
||||
decimals = self._decimals.get(asset, 8)
|
||||
price = answer / (10 ** decimals)
|
||||
|
||||
# Record the update
|
||||
self.record_oracle_update(asset, price, updated_at)
|
||||
self._last_oracle_round_ids[asset] = round_id
|
||||
|
||||
return price
|
||||
|
||||
except Exception:
|
||||
log.exception("oracle_fetch_failed", asset=asset)
|
||||
return self._last_oracle_prices.get(asset)
|
||||
|
||||
def record_oracle_update(self, asset: str, price: float, timestamp: float) -> None:
|
||||
"""Record an observed oracle price update."""
|
||||
prev_ts = self._last_oracle_timestamps.get(asset)
|
||||
if prev_ts is not None and timestamp > prev_ts:
|
||||
interval = timestamp - prev_ts
|
||||
intervals = self._update_intervals[asset]
|
||||
intervals.append(interval)
|
||||
# Keep last 100 intervals
|
||||
if len(intervals) > 100:
|
||||
intervals.pop(0)
|
||||
|
||||
self._last_oracle_prices[asset] = price
|
||||
self._last_oracle_timestamps[asset] = timestamp
|
||||
|
||||
def get_avg_update_interval(self, asset: str) -> Optional[float]:
|
||||
"""Get average oracle update interval in seconds."""
|
||||
intervals = self._update_intervals.get(asset, [])
|
||||
if not intervals:
|
||||
return None
|
||||
return sum(intervals) / len(intervals)
|
||||
|
||||
def get_estimated_lag(self, asset: str) -> Optional[float]:
|
||||
"""Estimate current oracle lag (time since last known update)."""
|
||||
last_ts = self._last_oracle_timestamps.get(asset)
|
||||
if last_ts is None:
|
||||
return None
|
||||
return time.time() - last_ts
|
||||
|
||||
def get_oracle_vs_cex_deviation(
|
||||
self, asset: str, cex_price: float
|
||||
) -> Optional[float]:
|
||||
"""Calculate percentage deviation between oracle and CEX price."""
|
||||
oracle_price = self._last_oracle_prices.get(asset)
|
||||
if oracle_price is None or oracle_price <= 0:
|
||||
return None
|
||||
return (cex_price - oracle_price) / oracle_price * 100
|
||||
|
||||
async def poll_loop(self, interval: float = 5.0) -> None:
|
||||
"""Continuously poll oracle prices."""
|
||||
if not self._initialized:
|
||||
log.warning("oracle_poll_not_initialized")
|
||||
return
|
||||
|
||||
while True:
|
||||
for asset in self.FEEDS:
|
||||
try:
|
||||
await self.get_oracle_price(asset)
|
||||
except Exception:
|
||||
log.exception("oracle_poll_error", asset=asset)
|
||||
await asyncio.sleep(interval)
|
||||
|
||||
def get_stats(self) -> dict:
|
||||
return {
|
||||
asset: {
|
||||
"last_price": self._last_oracle_prices.get(asset),
|
||||
"last_round_id": self._last_oracle_round_ids.get(asset),
|
||||
"avg_interval": round(avg, 2) if (avg := self.get_avg_update_interval(asset)) else None,
|
||||
"estimated_lag": round(lag, 2) if (lag := self.get_estimated_lag(asset)) else None,
|
||||
"initialized": self._initialized,
|
||||
}
|
||||
for asset in self.FEEDS
|
||||
}
|
||||
193
src/market/window_tracker.py
Normal file
193
src/market/window_tracker.py
Normal file
@@ -0,0 +1,193 @@
|
||||
"""Track 5-minute and 15-minute price windows for BTC, ETH, SOL.
|
||||
|
||||
Maintains six simultaneous windows (3 assets x 2 timeframes), capturing the
|
||||
start price from the first CEX tick after each window opens and tracking the
|
||||
current price throughout.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import math
|
||||
from typing import Callable, Optional
|
||||
|
||||
import structlog
|
||||
|
||||
from src.data.models import ActiveMarket, Asset, Timeframe, WindowState
|
||||
|
||||
log = structlog.get_logger(__name__)
|
||||
|
||||
# Timeframe durations in seconds.
|
||||
_TIMEFRAME_SECONDS: dict[Timeframe, int] = {
|
||||
Timeframe.FIVE_MIN: 5 * 60,
|
||||
Timeframe.FIFTEEN_MIN: 15 * 60,
|
||||
}
|
||||
|
||||
|
||||
def _window_bounds(timestamp: float, timeframe: Timeframe) -> tuple[float, float]:
|
||||
"""Return (start, end) UTC-aligned window boundaries for *timestamp*."""
|
||||
interval = _TIMEFRAME_SECONDS[timeframe]
|
||||
start = math.floor(timestamp / interval) * interval
|
||||
return float(start), float(start + interval)
|
||||
|
||||
|
||||
class WindowTracker:
|
||||
"""Track clock-aligned price windows for multiple assets and timeframes.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
assets:
|
||||
Assets to track. Defaults to BTC, ETH, SOL.
|
||||
timeframes:
|
||||
Timeframes to track. Defaults to 5M and 15M.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
assets: Optional[list[Asset]] = None,
|
||||
timeframes: Optional[list[Timeframe]] = None,
|
||||
) -> None:
|
||||
self._assets: list[Asset] = assets or list(Asset)
|
||||
self._timeframes: list[Timeframe] = timeframes or list(Timeframe)
|
||||
|
||||
# Keyed by (asset, timeframe).
|
||||
self._windows: dict[tuple[Asset, Timeframe], WindowState] = {}
|
||||
self._callbacks: list[Callable[[WindowState], None]] = []
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Public API
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def update_price(self, asset: str, price: float, timestamp: float) -> None:
|
||||
"""Process an incoming CEX price tick.
|
||||
|
||||
If *timestamp* falls within a new window that we haven't initialised
|
||||
yet (or a different window from the one we're tracking), a fresh
|
||||
``WindowState`` is created and any registered callbacks are fired.
|
||||
|
||||
The first tick inside a window sets ``start_price``. Every tick
|
||||
updates ``current_price``.
|
||||
"""
|
||||
asset_enum = Asset(asset)
|
||||
|
||||
for tf in self._timeframes:
|
||||
key = (asset_enum, tf)
|
||||
win_start, win_end = _window_bounds(timestamp, tf)
|
||||
|
||||
existing = self._windows.get(key)
|
||||
|
||||
# Detect window transition (or first-ever window).
|
||||
if existing is None or existing.window_start_time != win_start:
|
||||
# Carry over any linked market from the previous window.
|
||||
linked_market = existing.market if existing else None
|
||||
|
||||
new_window = WindowState(
|
||||
asset=asset_enum,
|
||||
timeframe=tf,
|
||||
window_start_time=win_start,
|
||||
window_end_time=win_end,
|
||||
start_price=price,
|
||||
current_price=price,
|
||||
market=linked_market,
|
||||
)
|
||||
self._windows[key] = new_window
|
||||
|
||||
log.info(
|
||||
"window_transition",
|
||||
asset=asset_enum.value,
|
||||
timeframe=tf.value,
|
||||
window_start=win_start,
|
||||
window_end=win_end,
|
||||
start_price=price,
|
||||
)
|
||||
|
||||
# Fire callbacks with the COMPLETED window so that
|
||||
# subscribers can read its final price_change_pct.
|
||||
# On the very first window there is no completed window
|
||||
# to report, so fire with the new one instead.
|
||||
completed = existing if existing is not None else new_window
|
||||
self._fire_callbacks(completed)
|
||||
else:
|
||||
# Same window — update current price. If start_price was
|
||||
# somehow not captured (shouldn't happen with this logic,
|
||||
# but defensive), fill it with the first available price.
|
||||
if existing.start_price is None:
|
||||
existing.start_price = price
|
||||
log.info(
|
||||
"late_start_price",
|
||||
asset=asset_enum.value,
|
||||
timeframe=tf.value,
|
||||
price=price,
|
||||
)
|
||||
existing.current_price = price
|
||||
|
||||
def get_window(self, asset: str, timeframe: str) -> Optional[WindowState]:
|
||||
"""Return the current window state for an asset/timeframe pair."""
|
||||
key = (Asset(asset), Timeframe(timeframe))
|
||||
return self._windows.get(key)
|
||||
|
||||
def get_all_active_windows(self) -> list[WindowState]:
|
||||
"""Return every currently tracked window."""
|
||||
return list(self._windows.values())
|
||||
|
||||
def is_window_expired(self, asset: str, timeframe: str) -> bool:
|
||||
"""Check whether the tracked window has already ended.
|
||||
|
||||
Returns ``True`` when the window end time is in the past **or**
|
||||
when no window has been initialised for the given pair.
|
||||
"""
|
||||
import time
|
||||
|
||||
key = (Asset(asset), Timeframe(timeframe))
|
||||
window = self._windows.get(key)
|
||||
if window is None:
|
||||
return True
|
||||
return time.time() >= window.window_end_time
|
||||
|
||||
def link_market(
|
||||
self, asset: str, timeframe: str, market: ActiveMarket
|
||||
) -> None:
|
||||
"""Associate a Polymarket market with the current window.
|
||||
|
||||
If no window exists yet for the pair, one is **not** created — the
|
||||
market will be linked once the first price tick arrives and triggers
|
||||
a window transition.
|
||||
"""
|
||||
key = (Asset(asset), Timeframe(timeframe))
|
||||
window = self._windows.get(key)
|
||||
if window is not None:
|
||||
window.market = market
|
||||
log.info(
|
||||
"market_linked",
|
||||
asset=asset,
|
||||
timeframe=timeframe,
|
||||
condition_id=market.condition_id,
|
||||
)
|
||||
else:
|
||||
log.warning(
|
||||
"link_market_no_window",
|
||||
asset=asset,
|
||||
timeframe=timeframe,
|
||||
condition_id=market.condition_id,
|
||||
)
|
||||
|
||||
def on_window_change(self, callback: Callable[[WindowState], None]) -> None:
|
||||
"""Register a callback invoked whenever a new window starts.
|
||||
|
||||
The callback receives the newly-created ``WindowState``.
|
||||
"""
|
||||
self._callbacks.append(callback)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Internals
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _fire_callbacks(self, window: WindowState) -> None:
|
||||
for cb in self._callbacks:
|
||||
try:
|
||||
cb(window)
|
||||
except Exception:
|
||||
log.exception(
|
||||
"window_change_callback_error",
|
||||
asset=window.asset.value,
|
||||
timeframe=window.timeframe.value,
|
||||
)
|
||||
7
src/risk/__init__.py
Normal file
7
src/risk/__init__.py
Normal file
@@ -0,0 +1,7 @@
|
||||
"""Risk management modules."""
|
||||
|
||||
from src.risk.fee_calculator import FeeCalculator
|
||||
from src.risk.position_sizer import PositionSizer
|
||||
from src.risk.risk_manager import RiskManager
|
||||
|
||||
__all__ = ["PositionSizer", "RiskManager", "FeeCalculator"]
|
||||
90
src/risk/fee_calculator.py
Normal file
90
src/risk/fee_calculator.py
Normal file
@@ -0,0 +1,90 @@
|
||||
"""Fee Calculator — computes trading fees for Polymarket CLOB.
|
||||
|
||||
Fee structure:
|
||||
- 5-minute markets: max 1.56% taker fee
|
||||
- 15-minute markets: max 3% taker fee
|
||||
- Fee is applied to potential profit, not to cost basis
|
||||
- Maker rebate program available for orders resting >= 3.5 seconds
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from src.config import FeesConfig
|
||||
|
||||
import structlog
|
||||
|
||||
log = structlog.get_logger()
|
||||
|
||||
|
||||
class FeeCalculator:
|
||||
"""Calculates expected fees for Polymarket trades."""
|
||||
|
||||
def __init__(self, fees_config: FeesConfig) -> None:
|
||||
self.fees = fees_config
|
||||
|
||||
def taker_fee(self, timeframe: str, entry_price: float, size: int) -> float:
|
||||
"""Calculate the taker fee for a trade.
|
||||
|
||||
The fee is applied to the potential profit (payout - cost),
|
||||
not the cost itself.
|
||||
|
||||
Args:
|
||||
timeframe: "5M" or "15M".
|
||||
entry_price: Entry price per share (e.g., 0.50).
|
||||
size: Number of shares.
|
||||
|
||||
Returns:
|
||||
Fee in USD.
|
||||
"""
|
||||
fee_rate = self.fees.fee_for_timeframe(timeframe)
|
||||
payout = size * 1.0 # $1 per share if winning
|
||||
cost = entry_price * size
|
||||
profit = payout - cost
|
||||
return max(0, profit * fee_rate)
|
||||
|
||||
def breakeven_price(self, timeframe: str, entry_price: float) -> float:
|
||||
"""Calculate the breakeven probability needed to cover fees.
|
||||
|
||||
If you buy at `entry_price`, you need the true probability
|
||||
to exceed this value to be profitable after fees.
|
||||
"""
|
||||
fee_rate = self.fees.fee_for_timeframe(timeframe)
|
||||
# Solving: prob * (1 - fee_rate) * (1 - entry_price) > entry_price * (1 - prob)
|
||||
# prob > entry_price / (1 - fee_rate * (1 - entry_price))
|
||||
denominator = 1.0 - fee_rate * (1.0 - entry_price)
|
||||
if denominator <= 0:
|
||||
return 1.0
|
||||
return entry_price / denominator
|
||||
|
||||
def net_payout(self, timeframe: str, entry_price: float, size: int, won: bool) -> float:
|
||||
"""Calculate net payout after fees.
|
||||
|
||||
Args:
|
||||
timeframe: "5M" or "15M".
|
||||
entry_price: Entry price per share.
|
||||
size: Number of shares.
|
||||
won: Whether the outcome was correct.
|
||||
|
||||
Returns:
|
||||
Net payout in USD (can be negative for losses).
|
||||
"""
|
||||
cost = entry_price * size
|
||||
|
||||
if won:
|
||||
gross_payout = size * 1.0
|
||||
fee = self.taker_fee(timeframe, entry_price, size)
|
||||
return gross_payout - cost - fee
|
||||
else:
|
||||
return -cost # Total loss, no fee on losing side
|
||||
|
||||
def expected_value(
|
||||
self, timeframe: str, entry_price: float, estimated_prob: float, size: int
|
||||
) -> float:
|
||||
"""Calculate expected value of a trade.
|
||||
|
||||
EV = prob * net_win_payout + (1-prob) * net_loss
|
||||
"""
|
||||
win_payout = self.net_payout(timeframe, entry_price, size, won=True)
|
||||
loss_payout = self.net_payout(timeframe, entry_price, size, won=False)
|
||||
|
||||
return estimated_prob * win_payout + (1 - estimated_prob) * loss_payout
|
||||
90
src/risk/position_sizer.py
Normal file
90
src/risk/position_sizer.py
Normal file
@@ -0,0 +1,90 @@
|
||||
"""Kelly Criterion position sizer with risk adjustments."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import math
|
||||
from typing import Optional
|
||||
|
||||
import structlog
|
||||
|
||||
from src.config import FeesConfig, RiskConfig
|
||||
|
||||
log = structlog.get_logger()
|
||||
|
||||
|
||||
class PositionSizer:
|
||||
"""Determines optimal position size using Kelly Criterion
|
||||
with multiple safety caps and adjustments.
|
||||
"""
|
||||
|
||||
def __init__(self, risk_config: RiskConfig, fees_config: FeesConfig) -> None:
|
||||
self.risk = risk_config
|
||||
self.fees = fees_config
|
||||
|
||||
def calculate(
|
||||
self,
|
||||
estimated_prob: float,
|
||||
poly_price: float,
|
||||
timeframe: str,
|
||||
balance: float,
|
||||
current_exposure: float = 0.0,
|
||||
) -> int:
|
||||
"""Calculate position size in shares.
|
||||
|
||||
Args:
|
||||
estimated_prob: Our estimated true probability (0-1).
|
||||
poly_price: Polymarket entry price (0-1).
|
||||
timeframe: "5M" or "15M" for fee lookup.
|
||||
balance: Available balance in USD.
|
||||
current_exposure: Current total exposure in USD.
|
||||
|
||||
Returns:
|
||||
Number of shares to buy (0 if no trade).
|
||||
"""
|
||||
if poly_price <= 0 or poly_price >= 1.0 or estimated_prob <= 0:
|
||||
return 0
|
||||
|
||||
fee_rate = self.fees.fee_for_timeframe(timeframe)
|
||||
|
||||
# Kelly Criterion: f* = (b*p - q) / b
|
||||
b = (1.0 / poly_price) - 1.0 # Payout odds
|
||||
p = estimated_prob
|
||||
q = 1.0 - p
|
||||
|
||||
if b <= 0:
|
||||
return 0
|
||||
|
||||
kelly_raw = (b * p - q) / b
|
||||
if kelly_raw <= 0:
|
||||
return 0
|
||||
|
||||
# Cap 1: Kelly fraction cap (default 25%)
|
||||
kelly_adj = min(kelly_raw, self.risk.kelly_fraction_cap)
|
||||
|
||||
# Cap 2: Use half-Kelly for additional safety
|
||||
kelly_adj *= 0.5
|
||||
|
||||
# Cap 3: Dollar size limits
|
||||
dollar_size = balance * kelly_adj
|
||||
dollar_size = min(dollar_size, self.risk.max_position_per_market_usd)
|
||||
|
||||
# Cap 4: Total exposure limit
|
||||
remaining_capacity = self.risk.max_total_exposure_usd - current_exposure
|
||||
if remaining_capacity <= 0:
|
||||
return 0
|
||||
dollar_size = min(dollar_size, remaining_capacity)
|
||||
|
||||
# Convert to shares
|
||||
shares = int(dollar_size / poly_price)
|
||||
|
||||
log.debug(
|
||||
"position_sized",
|
||||
kelly_raw=round(kelly_raw, 4),
|
||||
kelly_adj=round(kelly_adj, 4),
|
||||
dollar_size=round(dollar_size, 2),
|
||||
shares=shares,
|
||||
prob=round(estimated_prob, 4),
|
||||
price=poly_price,
|
||||
)
|
||||
|
||||
return max(0, shares)
|
||||
161
src/risk/risk_manager.py
Normal file
161
src/risk/risk_manager.py
Normal file
@@ -0,0 +1,161 @@
|
||||
"""Risk Manager — enforces trading limits and circuit breakers.
|
||||
|
||||
Monitors exposure, daily PnL, and position counts to prevent
|
||||
catastrophic losses.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import time
|
||||
from datetime import datetime, timezone
|
||||
|
||||
import structlog
|
||||
|
||||
from src.config import RiskConfig
|
||||
from src.data.db import TradeDB
|
||||
from src.execution.position_tracker import PositionTracker
|
||||
|
||||
log = structlog.get_logger()
|
||||
|
||||
|
||||
class RiskManager:
|
||||
"""Central risk management.
|
||||
|
||||
Tracks daily PnL, total exposure, and position count.
|
||||
Can halt trading if limits are breached.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
risk_config: RiskConfig,
|
||||
position_tracker: PositionTracker,
|
||||
trade_db: TradeDB,
|
||||
) -> None:
|
||||
self.risk = risk_config
|
||||
self.tracker = position_tracker
|
||||
self.db = trade_db
|
||||
|
||||
self._halted = False
|
||||
self._halt_reason: str = ""
|
||||
self._daily_pnl_cache: float = 0.0
|
||||
self._last_pnl_check: float = 0.0
|
||||
|
||||
@property
|
||||
def is_halted(self) -> bool:
|
||||
return self._halted
|
||||
|
||||
@property
|
||||
def halt_reason(self) -> str:
|
||||
return self._halt_reason
|
||||
|
||||
def check_all(self) -> bool:
|
||||
"""Run all risk checks. Returns True if trading is allowed."""
|
||||
if self._halted:
|
||||
return False
|
||||
|
||||
if not self._check_daily_loss():
|
||||
return False
|
||||
if not self._check_total_exposure():
|
||||
return False
|
||||
if not self._check_position_count():
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
def can_open_position(self, additional_usd: float) -> bool:
|
||||
"""Check if a new position of given size can be opened."""
|
||||
if self._halted:
|
||||
return False
|
||||
|
||||
if not self.tracker.is_exposure_ok(additional_usd):
|
||||
log.warning(
|
||||
"risk_exposure_limit",
|
||||
current=round(self.tracker.total_exposure, 2),
|
||||
additional=round(additional_usd, 2),
|
||||
limit=self.risk.max_total_exposure_usd,
|
||||
)
|
||||
return False
|
||||
|
||||
if self.tracker.position_count >= self.risk.max_concurrent_positions:
|
||||
log.warning(
|
||||
"risk_position_limit",
|
||||
current=self.tracker.position_count,
|
||||
limit=self.risk.max_concurrent_positions,
|
||||
)
|
||||
return False
|
||||
|
||||
return self.check_all()
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Individual checks
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _check_daily_loss(self) -> bool:
|
||||
"""Check if daily loss limit has been breached."""
|
||||
now = time.time()
|
||||
# Cache daily PnL check (recompute every 10 seconds)
|
||||
if now - self._last_pnl_check > 10:
|
||||
self._daily_pnl_cache = self.db.get_today_pnl()
|
||||
self._last_pnl_check = now
|
||||
|
||||
daily_pnl = self._daily_pnl_cache + self.tracker.total_unrealized_pnl
|
||||
|
||||
if daily_pnl < -self.risk.max_daily_loss_usd:
|
||||
self._halt("daily_loss_limit_breached",
|
||||
f"Daily PnL ${daily_pnl:.2f} < -${self.risk.max_daily_loss_usd}")
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
def _check_total_exposure(self) -> bool:
|
||||
"""Check if total exposure is within limits."""
|
||||
exposure = self.tracker.total_exposure
|
||||
if exposure > self.risk.max_total_exposure_usd:
|
||||
self._halt("exposure_limit_breached",
|
||||
f"Exposure ${exposure:.2f} > ${self.risk.max_total_exposure_usd}")
|
||||
return False
|
||||
return True
|
||||
|
||||
def _check_position_count(self) -> bool:
|
||||
"""Check if position count is within limits."""
|
||||
count = self.tracker.position_count
|
||||
if count > self.risk.max_concurrent_positions:
|
||||
log.warning("position_count_exceeded", count=count, limit=self.risk.max_concurrent_positions)
|
||||
return False
|
||||
return True
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Halt/resume
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _halt(self, halt_event: str, reason: str) -> None:
|
||||
"""Halt all trading."""
|
||||
self._halted = True
|
||||
self._halt_reason = reason
|
||||
log.critical("trading_halted", halt_event=halt_event, reason=reason)
|
||||
|
||||
def resume(self) -> None:
|
||||
"""Manually resume trading (use with caution)."""
|
||||
log.warning("trading_resumed", previous_halt_reason=self._halt_reason)
|
||||
self._halted = False
|
||||
self._halt_reason = ""
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Reporting
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def get_risk_summary(self) -> dict:
|
||||
"""Return current risk state for dashboard/logging."""
|
||||
return {
|
||||
"halted": self._halted,
|
||||
"halt_reason": self._halt_reason,
|
||||
"daily_pnl": round(self._daily_pnl_cache, 2),
|
||||
"total_exposure": round(self.tracker.total_exposure, 2),
|
||||
"exposure_pct": round(
|
||||
self.tracker.total_exposure / self.risk.max_total_exposure_usd * 100, 1
|
||||
),
|
||||
"positions": self.tracker.position_count,
|
||||
"max_positions": self.risk.max_concurrent_positions,
|
||||
"daily_loss_limit": self.risk.max_daily_loss_usd,
|
||||
"exposure_limit": self.risk.max_total_exposure_usd,
|
||||
}
|
||||
13
src/strategy/__init__.py
Normal file
13
src/strategy/__init__.py
Normal file
@@ -0,0 +1,13 @@
|
||||
"""Strategy modules for the Polymarket Arbitrage Bot."""
|
||||
|
||||
from src.strategy.signal import SignalAggregator
|
||||
from src.strategy.spread_capture import SpreadCaptureStrategy
|
||||
from src.strategy.sum_to_one import SumToOneStrategy
|
||||
from src.strategy.temporal_arb import TemporalArbStrategy
|
||||
|
||||
__all__ = [
|
||||
"TemporalArbStrategy",
|
||||
"SumToOneStrategy",
|
||||
"SpreadCaptureStrategy",
|
||||
"SignalAggregator",
|
||||
]
|
||||
196
src/strategy/signal.py
Normal file
196
src/strategy/signal.py
Normal file
@@ -0,0 +1,196 @@
|
||||
"""Signal Aggregator — coordinates all strategies and manages signal flow.
|
||||
|
||||
Receives market data from feeds and window tracker, evaluates all active
|
||||
strategies, and emits consolidated signals for the execution engine.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import time
|
||||
from typing import Callable, Optional
|
||||
|
||||
import structlog
|
||||
|
||||
from src.config import Config
|
||||
from src.data.models import (
|
||||
ActiveMarket,
|
||||
Direction,
|
||||
OrderBookSnapshot,
|
||||
Signal,
|
||||
WindowState,
|
||||
)
|
||||
from src.strategy.spread_capture import SpreadCaptureStrategy
|
||||
from src.strategy.sum_to_one import SumToOneStrategy
|
||||
from src.strategy.temporal_arb import TemporalArbStrategy
|
||||
|
||||
log = structlog.get_logger()
|
||||
|
||||
SignalCallback = Callable[[Signal], None]
|
||||
|
||||
|
||||
class SignalAggregator:
|
||||
"""Coordinates all strategies and emits trading signals.
|
||||
|
||||
Connects to the WindowTracker and PolymarketFeed to receive
|
||||
real-time data, runs strategy evaluations, and dispatches
|
||||
signals to registered handlers.
|
||||
"""
|
||||
|
||||
def __init__(self, config: Config, balance: float = 10000.0) -> None:
|
||||
self.config = config
|
||||
self._callbacks: list[SignalCallback] = []
|
||||
|
||||
# Initialize strategies
|
||||
self.temporal_arb = TemporalArbStrategy(
|
||||
arb_config=config.temporal_arb,
|
||||
risk_config=config.risk,
|
||||
fees_config=config.fees,
|
||||
balance=balance,
|
||||
)
|
||||
self.sum_to_one = SumToOneStrategy(
|
||||
sto_config=config.sum_to_one,
|
||||
risk_config=config.risk,
|
||||
fees_config=config.fees,
|
||||
balance=balance,
|
||||
)
|
||||
self.spread_capture = SpreadCaptureStrategy(
|
||||
spread_config=config.spread_capture,
|
||||
risk_config=config.risk,
|
||||
fees_config=config.fees,
|
||||
balance=balance,
|
||||
)
|
||||
|
||||
# Rate limiting: don't evaluate the same market more than once per second
|
||||
self._last_eval_time: dict[str, float] = {}
|
||||
self._min_eval_interval = 0.5 # seconds
|
||||
|
||||
# Track recent signals to avoid duplicates
|
||||
self._recent_signals: dict[str, float] = {}
|
||||
self._signal_cooldown = 30.0 # seconds between signals for same market
|
||||
|
||||
def on_signal(self, callback: SignalCallback) -> None:
|
||||
"""Register a callback for emitted signals."""
|
||||
self._callbacks.append(callback)
|
||||
|
||||
def _emit(self, signal: Signal) -> None:
|
||||
"""Dispatch a signal to all registered callbacks."""
|
||||
# Dedup: don't emit same direction/asset/timeframe within cooldown
|
||||
key = f"{signal.asset.value}:{signal.timeframe.value}:{signal.direction.value}"
|
||||
now = time.time()
|
||||
if key in self._recent_signals:
|
||||
if now - self._recent_signals[key] < self._signal_cooldown:
|
||||
log.debug("signal_deduplicated", key=key)
|
||||
return
|
||||
|
||||
self._recent_signals[key] = now
|
||||
|
||||
for cb in self._callbacks:
|
||||
try:
|
||||
cb(signal)
|
||||
except Exception:
|
||||
log.exception("signal_callback_error")
|
||||
|
||||
async def on_price_tick(
|
||||
self,
|
||||
symbol: str,
|
||||
cex_price: float,
|
||||
window: Optional[WindowState],
|
||||
orderbooks: dict[str, OrderBookSnapshot],
|
||||
) -> None:
|
||||
"""Called on each CEX price tick with current window and orderbook state.
|
||||
|
||||
Evaluates temporal arb and sum-to-one strategies.
|
||||
"""
|
||||
if window is None or window.market is None:
|
||||
return
|
||||
if window.start_price is None:
|
||||
return
|
||||
|
||||
market = window.market
|
||||
timeframe = window.timeframe.value
|
||||
|
||||
# Rate limiting
|
||||
eval_key = f"{symbol}:{timeframe}"
|
||||
now = time.time()
|
||||
if eval_key in self._last_eval_time:
|
||||
if now - self._last_eval_time[eval_key] < self._min_eval_interval:
|
||||
return
|
||||
self._last_eval_time[eval_key] = now
|
||||
|
||||
# Get Polymarket prices from orderbooks
|
||||
up_book = orderbooks.get(market.up_token_id)
|
||||
down_book = orderbooks.get(market.down_token_id)
|
||||
|
||||
poly_up_ask = up_book.best_ask if up_book else None
|
||||
poly_down_ask = down_book.best_ask if down_book else None
|
||||
|
||||
# --- Temporal Arbitrage ---
|
||||
if self.config.temporal_arb.enabled:
|
||||
signal = await self.temporal_arb.evaluate(
|
||||
symbol=symbol,
|
||||
cex_price=cex_price,
|
||||
window_start_price=window.start_price,
|
||||
window_end_time=window.window_end_time,
|
||||
poly_up_ask=poly_up_ask,
|
||||
poly_down_ask=poly_down_ask,
|
||||
up_token_id=market.up_token_id,
|
||||
down_token_id=market.down_token_id,
|
||||
timeframe=timeframe,
|
||||
)
|
||||
if signal:
|
||||
self._emit(signal)
|
||||
|
||||
# --- Sum-to-One Arbitrage ---
|
||||
if self.config.sum_to_one.enabled:
|
||||
opp = await self.sum_to_one.evaluate(
|
||||
asset=symbol,
|
||||
timeframe=timeframe,
|
||||
poly_up_ask=poly_up_ask,
|
||||
poly_down_ask=poly_down_ask,
|
||||
up_token_id=market.up_token_id,
|
||||
down_token_id=market.down_token_id,
|
||||
)
|
||||
if opp:
|
||||
up_sig, down_sig = self.sum_to_one.generate_signals(opp)
|
||||
self._emit(up_sig)
|
||||
self._emit(down_sig)
|
||||
|
||||
# --- Spread Capture ---
|
||||
if self.config.spread_capture.enabled:
|
||||
for token_id, book in [(market.up_token_id, up_book), (market.down_token_id, down_book)]:
|
||||
if book:
|
||||
quote = await self.spread_capture.evaluate(
|
||||
asset=symbol,
|
||||
timeframe=timeframe,
|
||||
token_id=token_id,
|
||||
orderbook=book,
|
||||
)
|
||||
# Spread quotes handled differently — not emitted as signals
|
||||
|
||||
def update_balance(self, balance: float) -> None:
|
||||
"""Update balance across all strategies."""
|
||||
self.temporal_arb.update_balance(balance)
|
||||
self.sum_to_one.update_balance(balance)
|
||||
self.spread_capture.update_balance(balance)
|
||||
|
||||
def get_stats(self) -> dict:
|
||||
"""Return combined statistics from all strategies."""
|
||||
return {
|
||||
"temporal_arb": {
|
||||
"evaluations": self.temporal_arb.stats.total_evaluations,
|
||||
"signals": self.temporal_arb.stats.signals_generated,
|
||||
"avg_edge": round(self.temporal_arb.stats.avg_edge, 4),
|
||||
"by_asset": dict(self.temporal_arb.stats.signals_by_asset),
|
||||
},
|
||||
"sum_to_one": {
|
||||
"evaluations": self.sum_to_one.stats.total_evaluations,
|
||||
"opportunities": self.sum_to_one.stats.opportunities_found,
|
||||
"total_net_profit": round(self.sum_to_one.stats.total_net_profit, 4),
|
||||
},
|
||||
"spread_capture": {
|
||||
"evaluations": self.spread_capture.stats.total_evaluations,
|
||||
"quotes": self.spread_capture.stats.quotes_generated,
|
||||
"avg_spread": round(self.spread_capture.stats.avg_spread, 4),
|
||||
},
|
||||
}
|
||||
164
src/strategy/spread_capture.py
Normal file
164
src/strategy/spread_capture.py
Normal file
@@ -0,0 +1,164 @@
|
||||
"""Spread Capture / Market Making Strategy.
|
||||
|
||||
Places limit orders on both sides of the bid-ask spread to capture
|
||||
the spread as profit. Requires careful inventory management.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Optional
|
||||
|
||||
import structlog
|
||||
|
||||
from src.config import FeesConfig, RiskConfig, SpreadCaptureConfig
|
||||
from src.data.models import (
|
||||
Asset,
|
||||
Direction,
|
||||
OrderBookSnapshot,
|
||||
Signal,
|
||||
Timeframe,
|
||||
)
|
||||
|
||||
log = structlog.get_logger()
|
||||
|
||||
|
||||
@dataclass
|
||||
class SpreadQuote:
|
||||
"""A pair of limit orders to capture the spread."""
|
||||
asset: Asset
|
||||
timeframe: Timeframe
|
||||
token_id: str
|
||||
bid_price: float
|
||||
ask_price: float
|
||||
size: int
|
||||
spread: float
|
||||
timestamp: float = field(default_factory=time.time)
|
||||
|
||||
|
||||
@dataclass
|
||||
class SpreadStats:
|
||||
total_evaluations: int = 0
|
||||
quotes_generated: int = 0
|
||||
avg_spread: float = 0.0
|
||||
_total_spread: float = 0.0
|
||||
|
||||
|
||||
class SpreadCaptureStrategy:
|
||||
"""Market making strategy that places limit orders on both sides
|
||||
of the spread. Profits from the bid-ask spread minus fees.
|
||||
|
||||
Note: Requires orders to rest for >=3.5 seconds for maker rebate eligibility.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
spread_config: SpreadCaptureConfig,
|
||||
risk_config: RiskConfig,
|
||||
fees_config: FeesConfig,
|
||||
balance: float = 10000.0,
|
||||
) -> None:
|
||||
self.spread = spread_config
|
||||
self.risk = risk_config
|
||||
self.fees = fees_config
|
||||
self.balance = balance
|
||||
self.stats = SpreadStats()
|
||||
|
||||
# Active quotes per token to avoid over-quoting
|
||||
self._active_quotes: dict[str, SpreadQuote] = {}
|
||||
|
||||
async def evaluate(
|
||||
self,
|
||||
asset: str,
|
||||
timeframe: str,
|
||||
token_id: str,
|
||||
orderbook: OrderBookSnapshot,
|
||||
) -> Optional[SpreadQuote]:
|
||||
"""Evaluate whether to place spread-capture quotes on this token.
|
||||
|
||||
Returns a SpreadQuote if the spread is wide enough to be profitable.
|
||||
"""
|
||||
self.stats.total_evaluations += 1
|
||||
|
||||
if not self.spread.enabled:
|
||||
return None
|
||||
|
||||
if orderbook.best_bid is None or orderbook.best_ask is None:
|
||||
return None
|
||||
|
||||
current_spread = orderbook.best_ask - orderbook.best_bid
|
||||
|
||||
if current_spread < self.spread.spread_target:
|
||||
return None # Spread too tight
|
||||
|
||||
# Place our orders inside the current spread
|
||||
# Improve by 1 cent on each side
|
||||
our_bid = orderbook.best_bid + 0.01
|
||||
our_ask = orderbook.best_ask - 0.01
|
||||
our_spread = our_ask - our_bid
|
||||
|
||||
if our_spread <= 0:
|
||||
return None
|
||||
|
||||
# Check profitability after fees (maker gets rebate, but conservative estimate)
|
||||
taker_fee = self.fees.fee_for_timeframe(timeframe)
|
||||
# As maker, fee is lower or zero (rebate), but we budget for taker as worst case
|
||||
estimated_profit_per_share = our_spread - (taker_fee * 2)
|
||||
|
||||
if estimated_profit_per_share <= 0:
|
||||
return None
|
||||
|
||||
# Size: conservative for market making
|
||||
max_dollar = min(
|
||||
self.balance * 0.10, # 10% of balance per quote
|
||||
self.risk.max_position_per_market_usd * 0.5,
|
||||
)
|
||||
size = int(max_dollar / our_ask)
|
||||
|
||||
if size <= 0:
|
||||
return None
|
||||
|
||||
# Check if we already have an active quote
|
||||
if token_id in self._active_quotes:
|
||||
old = self._active_quotes[token_id]
|
||||
# Only update if spread changed significantly
|
||||
if abs(old.bid_price - our_bid) < 0.005 and abs(old.ask_price - our_ask) < 0.005:
|
||||
return None
|
||||
|
||||
quote = SpreadQuote(
|
||||
asset=Asset(asset),
|
||||
timeframe=Timeframe(timeframe),
|
||||
token_id=token_id,
|
||||
bid_price=our_bid,
|
||||
ask_price=our_ask,
|
||||
size=size,
|
||||
spread=our_spread,
|
||||
)
|
||||
|
||||
self._active_quotes[token_id] = quote
|
||||
self.stats.quotes_generated += 1
|
||||
self.stats._total_spread += our_spread
|
||||
self.stats.avg_spread = self.stats._total_spread / self.stats.quotes_generated
|
||||
|
||||
log.info(
|
||||
"spread_quote",
|
||||
asset=asset,
|
||||
timeframe=timeframe,
|
||||
bid=our_bid,
|
||||
ask=our_ask,
|
||||
spread=round(our_spread, 4),
|
||||
size=size,
|
||||
)
|
||||
|
||||
return quote
|
||||
|
||||
def remove_quote(self, token_id: str) -> None:
|
||||
"""Remove an active quote (e.g., after fill or cancel)."""
|
||||
self._active_quotes.pop(token_id, None)
|
||||
|
||||
def get_active_quotes(self) -> dict[str, SpreadQuote]:
|
||||
return dict(self._active_quotes)
|
||||
|
||||
def update_balance(self, new_balance: float) -> None:
|
||||
self.balance = new_balance
|
||||
178
src/strategy/sum_to_one.py
Normal file
178
src/strategy/sum_to_one.py
Normal file
@@ -0,0 +1,178 @@
|
||||
"""Sum-to-One Arbitrage Strategy.
|
||||
|
||||
When YES + NO best ask prices sum to less than $1.00 (minus fees),
|
||||
buy both sides for a guaranteed risk-free profit regardless of outcome.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Optional
|
||||
|
||||
import structlog
|
||||
|
||||
from src.config import FeesConfig, RiskConfig, SumToOneConfig
|
||||
from src.data.models import Asset, Direction, Signal, Timeframe
|
||||
|
||||
log = structlog.get_logger()
|
||||
|
||||
|
||||
@dataclass
|
||||
class SumToOneOpportunity:
|
||||
"""Represents a sum-to-one arbitrage opportunity."""
|
||||
asset: Asset
|
||||
timeframe: Timeframe
|
||||
up_ask: float
|
||||
down_ask: float
|
||||
total_cost: float # up_ask + down_ask
|
||||
gross_profit: float # 1.0 - total_cost
|
||||
fee: float # Total fees for both sides
|
||||
net_profit: float # gross_profit - fee
|
||||
up_token_id: str
|
||||
down_token_id: str
|
||||
timestamp: float = field(default_factory=time.time)
|
||||
|
||||
|
||||
@dataclass
|
||||
class SumToOneStats:
|
||||
total_evaluations: int = 0
|
||||
opportunities_found: int = 0
|
||||
total_net_profit: float = 0.0
|
||||
|
||||
|
||||
class SumToOneStrategy:
|
||||
"""Sum-to-One arbitrage: buy both YES and NO when their combined
|
||||
ask price is less than $1.00 after fees."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
sto_config: SumToOneConfig,
|
||||
risk_config: RiskConfig,
|
||||
fees_config: FeesConfig,
|
||||
balance: float = 10000.0,
|
||||
) -> None:
|
||||
self.sto = sto_config
|
||||
self.risk = risk_config
|
||||
self.fees = fees_config
|
||||
self.balance = balance
|
||||
self.stats = SumToOneStats()
|
||||
|
||||
async def evaluate(
|
||||
self,
|
||||
asset: str,
|
||||
timeframe: str,
|
||||
poly_up_ask: Optional[float],
|
||||
poly_down_ask: Optional[float],
|
||||
up_token_id: str,
|
||||
down_token_id: str,
|
||||
) -> Optional[SumToOneOpportunity]:
|
||||
"""Check if a sum-to-one arb opportunity exists.
|
||||
|
||||
Returns an opportunity if YES ask + NO ask + fees < $1.00.
|
||||
"""
|
||||
self.stats.total_evaluations += 1
|
||||
|
||||
if poly_up_ask is None or poly_down_ask is None:
|
||||
return None
|
||||
if poly_up_ask <= 0 or poly_down_ask <= 0:
|
||||
return None
|
||||
if poly_up_ask >= 1.0 or poly_down_ask >= 1.0:
|
||||
return None
|
||||
|
||||
total_cost = poly_up_ask + poly_down_ask
|
||||
|
||||
# Already sums to >= $1.00 — no arb
|
||||
if total_cost >= 1.0:
|
||||
return None
|
||||
|
||||
gross_profit = 1.0 - total_cost
|
||||
|
||||
# Fee: taker fee applies to BOTH purchases
|
||||
taker_fee = self.fees.fee_for_timeframe(timeframe)
|
||||
# Fee is calculated on the payout, not the cost
|
||||
# For each side, fee = taker_fee * (payout - cost) when it wins
|
||||
# Simplified: total fee ≈ taker_fee * 1.0 (since one side pays out $1)
|
||||
total_fee = taker_fee * 1.0
|
||||
|
||||
net_profit = gross_profit - total_fee
|
||||
|
||||
if net_profit < self.sto.min_spread_after_fee:
|
||||
return None
|
||||
|
||||
opp = SumToOneOpportunity(
|
||||
asset=Asset(asset),
|
||||
timeframe=Timeframe(timeframe),
|
||||
up_ask=poly_up_ask,
|
||||
down_ask=poly_down_ask,
|
||||
total_cost=total_cost,
|
||||
gross_profit=gross_profit,
|
||||
fee=total_fee,
|
||||
net_profit=net_profit,
|
||||
up_token_id=up_token_id,
|
||||
down_token_id=down_token_id,
|
||||
)
|
||||
|
||||
self.stats.opportunities_found += 1
|
||||
self.stats.total_net_profit += net_profit
|
||||
|
||||
log.info(
|
||||
"sum_to_one_opportunity",
|
||||
asset=asset,
|
||||
timeframe=timeframe,
|
||||
up_ask=poly_up_ask,
|
||||
down_ask=poly_down_ask,
|
||||
total_cost=round(total_cost, 4),
|
||||
net_profit=round(net_profit, 4),
|
||||
)
|
||||
|
||||
return opp
|
||||
|
||||
def calculate_size(self, opportunity: SumToOneOpportunity) -> int:
|
||||
"""Calculate position size for a sum-to-one trade.
|
||||
|
||||
Since this is risk-free, we can size more aggressively,
|
||||
limited only by available liquidity and max position.
|
||||
"""
|
||||
# Max dollar exposure per side
|
||||
max_per_side = min(
|
||||
self.balance * 0.5, # Don't commit more than 50% of balance
|
||||
self.risk.max_position_per_market_usd,
|
||||
)
|
||||
|
||||
# Shares are constrained by the more expensive side
|
||||
max_price = max(opportunity.up_ask, opportunity.down_ask)
|
||||
shares = int(max_per_side / max_price)
|
||||
|
||||
return max(0, shares)
|
||||
|
||||
def generate_signals(self, opportunity: SumToOneOpportunity) -> tuple[Signal, Signal]:
|
||||
"""Generate a pair of BUY signals (one for UP, one for DOWN)."""
|
||||
size = self.calculate_size(opportunity)
|
||||
|
||||
up_signal = Signal(
|
||||
direction=Direction.UP,
|
||||
asset=opportunity.asset,
|
||||
timeframe=opportunity.timeframe,
|
||||
token_id=opportunity.up_token_id,
|
||||
price=opportunity.up_ask,
|
||||
size=size,
|
||||
edge=opportunity.net_profit,
|
||||
estimated_prob=0.5, # Direction-agnostic
|
||||
)
|
||||
|
||||
down_signal = Signal(
|
||||
direction=Direction.DOWN,
|
||||
asset=opportunity.asset,
|
||||
timeframe=opportunity.timeframe,
|
||||
token_id=opportunity.down_token_id,
|
||||
price=opportunity.down_ask,
|
||||
size=size,
|
||||
edge=opportunity.net_profit,
|
||||
estimated_prob=0.5,
|
||||
)
|
||||
|
||||
return up_signal, down_signal
|
||||
|
||||
def update_balance(self, new_balance: float) -> None:
|
||||
self.balance = new_balance
|
||||
297
src/strategy/temporal_arb.py
Normal file
297
src/strategy/temporal_arb.py
Normal file
@@ -0,0 +1,297 @@
|
||||
"""Temporal Arbitrage Strategy — core strategy module.
|
||||
|
||||
Exploits the delay between CEX price movements and Polymarket oracle updates.
|
||||
When Binance price confirms a direction but Polymarket odds haven't caught up,
|
||||
buy the confirmed direction at a discount.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import math
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Optional
|
||||
|
||||
import structlog
|
||||
|
||||
from src.config import FeesConfig, RiskConfig, TemporalArbConfig
|
||||
from src.data.models import Asset, Direction, Signal, Timeframe
|
||||
|
||||
log = structlog.get_logger()
|
||||
|
||||
|
||||
@dataclass
|
||||
class StrategyStats:
|
||||
"""Running statistics for the strategy."""
|
||||
total_evaluations: int = 0
|
||||
signals_generated: int = 0
|
||||
signals_by_asset: dict[str, int] = field(default_factory=lambda: {"BTC": 0, "ETH": 0, "SOL": 0})
|
||||
total_edge: float = 0.0
|
||||
|
||||
@property
|
||||
def avg_edge(self) -> float:
|
||||
return self.total_edge / self.signals_generated if self.signals_generated > 0 else 0.0
|
||||
|
||||
|
||||
class TemporalArbStrategy:
|
||||
"""Core temporal arbitrage strategy.
|
||||
|
||||
Monitors CEX price vs Polymarket odds and generates buy signals when
|
||||
the CEX price has confirmed a direction but Polymarket hasn't adjusted.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
arb_config: TemporalArbConfig,
|
||||
risk_config: RiskConfig,
|
||||
fees_config: FeesConfig,
|
||||
balance: float = 10000.0,
|
||||
) -> None:
|
||||
self.arb = arb_config
|
||||
self.risk = risk_config
|
||||
self.fees = fees_config
|
||||
self.balance = balance
|
||||
self.stats = StrategyStats()
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Core evaluation
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def evaluate(
|
||||
self,
|
||||
symbol: str,
|
||||
cex_price: float,
|
||||
window_start_price: float,
|
||||
window_end_time: float,
|
||||
poly_up_ask: Optional[float],
|
||||
poly_down_ask: Optional[float],
|
||||
up_token_id: str,
|
||||
down_token_id: str,
|
||||
timeframe: str,
|
||||
) -> Optional[Signal]:
|
||||
"""Evaluate whether to enter a temporal arb position.
|
||||
|
||||
Called on every CEX price tick. Returns a Signal if conditions are met,
|
||||
otherwise None.
|
||||
"""
|
||||
self.stats.total_evaluations += 1
|
||||
|
||||
if window_start_price <= 0:
|
||||
return None
|
||||
|
||||
# 1. Price direction & magnitude
|
||||
price_change_pct = (cex_price - window_start_price) / window_start_price * 100
|
||||
|
||||
if abs(price_change_pct) < self.arb.min_price_move_pct:
|
||||
return None # Not enough movement to confirm direction
|
||||
|
||||
# 2. Direction
|
||||
direction = Direction.UP if price_change_pct > 0 else Direction.DOWN
|
||||
|
||||
# 3. Time remaining
|
||||
now = time.time()
|
||||
time_remaining = window_end_time - now
|
||||
if time_remaining < self.arb.exit_before_resolution_sec:
|
||||
return None # Too close to resolution
|
||||
|
||||
# 4. Polymarket price
|
||||
poly_price = poly_up_ask if direction == Direction.UP else poly_down_ask
|
||||
if poly_price is None or poly_price <= 0 or poly_price >= 1.0:
|
||||
return None
|
||||
|
||||
if poly_price > self.arb.max_poly_entry_price:
|
||||
return None # Risk/reward insufficient
|
||||
|
||||
# 5. Probability estimation
|
||||
total_window = self._total_window_seconds(timeframe)
|
||||
estimated_prob = self.estimate_probability(
|
||||
price_change_pct, time_remaining, total_window
|
||||
)
|
||||
|
||||
# 6. Edge calculation (after fees)
|
||||
taker_fee = self.fees.fee_for_timeframe(timeframe)
|
||||
edge = estimated_prob - poly_price - taker_fee
|
||||
|
||||
if edge < self.arb.min_edge:
|
||||
return None # Insufficient edge
|
||||
|
||||
# 7. Position sizing
|
||||
asset = Asset(symbol)
|
||||
size = self.calculate_kelly_size(
|
||||
edge=edge,
|
||||
price=poly_price,
|
||||
balance=self.balance,
|
||||
max_size=self.risk.max_position_per_market_usd,
|
||||
)
|
||||
|
||||
if size <= 0:
|
||||
return None
|
||||
|
||||
# 8. Build signal
|
||||
token_id = up_token_id if direction == Direction.UP else down_token_id
|
||||
tf = Timeframe(timeframe)
|
||||
|
||||
signal = Signal(
|
||||
direction=direction,
|
||||
asset=asset,
|
||||
timeframe=tf,
|
||||
token_id=token_id,
|
||||
price=poly_price,
|
||||
size=size,
|
||||
edge=edge,
|
||||
estimated_prob=estimated_prob,
|
||||
)
|
||||
|
||||
# Update stats
|
||||
self.stats.signals_generated += 1
|
||||
self.stats.signals_by_asset[symbol] = self.stats.signals_by_asset.get(symbol, 0) + 1
|
||||
self.stats.total_edge += edge
|
||||
|
||||
log.info(
|
||||
"signal_generated",
|
||||
asset=symbol,
|
||||
direction=direction.value,
|
||||
timeframe=timeframe,
|
||||
price_change_pct=round(price_change_pct, 4),
|
||||
poly_price=poly_price,
|
||||
estimated_prob=round(estimated_prob, 4),
|
||||
edge=round(edge, 4),
|
||||
size=size,
|
||||
time_remaining=round(time_remaining, 1),
|
||||
)
|
||||
|
||||
return signal
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Probability model
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def estimate_probability(
|
||||
self,
|
||||
price_change_pct: float,
|
||||
time_remaining: float,
|
||||
total_window_sec: float,
|
||||
) -> float:
|
||||
"""Estimate the true probability that the current direction holds at resolution.
|
||||
|
||||
Multi-factor model:
|
||||
1. Base probability from price magnitude
|
||||
2. Time decay: as window nears end, momentum is more confirmed
|
||||
3. Volatility: larger moves are harder to reverse
|
||||
"""
|
||||
abs_change = abs(price_change_pct)
|
||||
|
||||
# Factor 1: Base probability from price magnitude
|
||||
# 0.15% → ~70%, 0.3% → ~82%, 0.5% → ~90%, 1.0%+ → ~95%
|
||||
base_prob = 0.55 + abs_change * 1.0 # 1.0 scaling factor
|
||||
base_prob = min(base_prob, 0.95)
|
||||
|
||||
# Factor 2: Time decay — more time elapsed = more confirmation
|
||||
# If 80% of window has passed and price is still in this direction,
|
||||
# the probability of reversal is lower
|
||||
elapsed_fraction = max(0, 1.0 - (time_remaining / total_window_sec))
|
||||
# Sigmoid-like boost: ramps up in the last 40% of the window
|
||||
time_factor = 1.0 + 0.08 * max(0, elapsed_fraction - 0.6) / 0.4
|
||||
time_factor = min(time_factor, 1.08)
|
||||
|
||||
# Factor 3: Volatility / momentum strength
|
||||
# Very large moves (>0.5%) get extra confidence
|
||||
if abs_change > 0.5:
|
||||
vol_boost = min(0.05, (abs_change - 0.5) * 0.1)
|
||||
else:
|
||||
vol_boost = 0.0
|
||||
|
||||
final_prob = base_prob * time_factor + vol_boost
|
||||
return min(0.95, max(0.50, final_prob))
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Position sizing
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def calculate_kelly_size(
|
||||
self,
|
||||
edge: float,
|
||||
price: float,
|
||||
balance: float,
|
||||
max_size: float,
|
||||
) -> int:
|
||||
"""Kelly Criterion position sizing.
|
||||
|
||||
f* = (b*p - q) / b
|
||||
where b = (1/price - 1), p = estimated_prob, q = 1-p
|
||||
"""
|
||||
if price <= 0 or price >= 1.0:
|
||||
return 0
|
||||
|
||||
b = (1.0 / price) - 1.0 # Payout odds
|
||||
p = edge + price # estimated_prob (edge = prob - price - fee)
|
||||
q = 1.0 - p
|
||||
|
||||
if b <= 0:
|
||||
return 0
|
||||
|
||||
kelly_fraction = (b * p - q) / b
|
||||
kelly_fraction = max(0.0, min(kelly_fraction, self.risk.kelly_fraction_cap))
|
||||
|
||||
dollar_size = min(balance * kelly_fraction, max_size)
|
||||
shares = int(dollar_size / price)
|
||||
|
||||
return max(0, shares)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Exit logic
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def should_exit_early(
|
||||
self,
|
||||
entry_direction: Direction,
|
||||
entry_price: float,
|
||||
current_poly_price: float,
|
||||
cex_price: float,
|
||||
window_start_price: float,
|
||||
time_remaining: float,
|
||||
) -> bool:
|
||||
"""Determine if we should exit a position early.
|
||||
|
||||
Exit if:
|
||||
1. Price has reversed and edge has disappeared
|
||||
2. Polymarket price has risen enough to take profit
|
||||
3. Very close to resolution with negative PnL trajectory
|
||||
"""
|
||||
if window_start_price <= 0:
|
||||
return False
|
||||
|
||||
current_change_pct = (cex_price - window_start_price) / window_start_price * 100
|
||||
|
||||
# Direction reversed significantly
|
||||
if entry_direction == Direction.UP and current_change_pct < -0.05:
|
||||
log.warning("exit_signal_reversal", direction="UP", change_pct=round(current_change_pct, 4))
|
||||
return True
|
||||
if entry_direction == Direction.DOWN and current_change_pct > 0.05:
|
||||
log.warning("exit_signal_reversal", direction="DOWN", change_pct=round(current_change_pct, 4))
|
||||
return True
|
||||
|
||||
# Take profit: price has moved significantly in our favor
|
||||
if current_poly_price > 0.90 and current_poly_price > entry_price * 1.3:
|
||||
log.info("exit_signal_take_profit", current_price=current_poly_price, entry_price=entry_price)
|
||||
return True
|
||||
|
||||
# Close to resolution with thin margin
|
||||
if time_remaining < 10 and abs(current_change_pct) < 0.05:
|
||||
log.warning("exit_signal_thin_margin", time_remaining=round(time_remaining, 1))
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
@staticmethod
|
||||
def _total_window_seconds(timeframe: str) -> float:
|
||||
"""Return total seconds for a given timeframe."""
|
||||
return {"5M": 300.0, "15M": 900.0}.get(timeframe, 300.0)
|
||||
|
||||
def update_balance(self, new_balance: float) -> None:
|
||||
"""Update available balance for position sizing."""
|
||||
self.balance = new_balance
|
||||
14
src/utils/__init__.py
Normal file
14
src/utils/__init__.py
Normal file
@@ -0,0 +1,14 @@
|
||||
"""Utility modules for the Polymarket Arbitrage Bot."""
|
||||
|
||||
from .logger import get_logger, log_timing, log_trade, setup_logging
|
||||
from .metrics import MetricsCollector
|
||||
from .telegram import TelegramNotifier
|
||||
|
||||
__all__ = [
|
||||
"get_logger",
|
||||
"log_timing",
|
||||
"log_trade",
|
||||
"setup_logging",
|
||||
"MetricsCollector",
|
||||
"TelegramNotifier",
|
||||
]
|
||||
153
src/utils/logger.py
Normal file
153
src/utils/logger.py
Normal file
@@ -0,0 +1,153 @@
|
||||
"""Structured logging setup using structlog for the Polymarket Arbitrage Bot.
|
||||
|
||||
Provides consistent, structured logging across all modules with support
|
||||
for dev-friendly console output and production JSON rendering.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import sys
|
||||
import time
|
||||
from contextlib import contextmanager
|
||||
from typing import Any, Generator
|
||||
|
||||
import structlog
|
||||
|
||||
|
||||
def setup_logging(log_level: str = "INFO") -> None:
|
||||
"""Configure structlog with processors, renderers, and stdlib integration.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
log_level:
|
||||
Root log level (e.g. ``"DEBUG"``, ``"INFO"``, ``"WARNING"``).
|
||||
Also accepts lowercase variants.
|
||||
"""
|
||||
level = getattr(logging, log_level.upper(), logging.INFO)
|
||||
|
||||
# Shared processors applied to every log entry
|
||||
shared_processors: list[structlog.types.Processor] = [
|
||||
structlog.contextvars.merge_contextvars,
|
||||
structlog.stdlib.add_log_level,
|
||||
structlog.stdlib.add_logger_name,
|
||||
structlog.processors.CallsiteParameterAdder(
|
||||
[
|
||||
structlog.processors.CallsiteParameter.FILENAME,
|
||||
structlog.processors.CallsiteParameter.FUNC_NAME,
|
||||
structlog.processors.CallsiteParameter.LINENO,
|
||||
]
|
||||
),
|
||||
structlog.processors.TimeStamper(fmt="iso"),
|
||||
structlog.processors.StackInfoRenderer(),
|
||||
structlog.processors.UnicodeDecoder(),
|
||||
]
|
||||
|
||||
# Choose renderer based on whether we're attached to a terminal (dev) or
|
||||
# running headless / in production.
|
||||
if sys.stderr.isatty():
|
||||
renderer: structlog.types.Processor = structlog.dev.ConsoleRenderer(
|
||||
colors=True,
|
||||
)
|
||||
else:
|
||||
renderer = structlog.processors.JSONRenderer()
|
||||
|
||||
structlog.configure(
|
||||
processors=[
|
||||
*shared_processors,
|
||||
structlog.stdlib.ProcessorFormatter.wrap_for_formatter,
|
||||
],
|
||||
logger_factory=structlog.stdlib.LoggerFactory(),
|
||||
wrapper_class=structlog.stdlib.BoundLogger,
|
||||
cache_logger_on_first_use=True,
|
||||
)
|
||||
|
||||
# Configure stdlib root logger so third-party libraries also go through
|
||||
# structlog's formatting pipeline.
|
||||
formatter = structlog.stdlib.ProcessorFormatter(
|
||||
processors=[
|
||||
structlog.stdlib.ProcessorFormatter.remove_processors_meta,
|
||||
renderer,
|
||||
],
|
||||
)
|
||||
|
||||
handler = logging.StreamHandler(sys.stderr)
|
||||
handler.setFormatter(formatter)
|
||||
|
||||
root_logger = logging.getLogger()
|
||||
root_logger.handlers.clear()
|
||||
root_logger.addHandler(handler)
|
||||
root_logger.setLevel(level)
|
||||
|
||||
|
||||
def get_logger(name: str) -> structlog.stdlib.BoundLogger:
|
||||
"""Return a bound structlog logger for *name*.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
name:
|
||||
Typically ``__name__`` of the calling module.
|
||||
"""
|
||||
return structlog.get_logger(name)
|
||||
|
||||
|
||||
def log_trade(logger: structlog.stdlib.BoundLogger, trade_data: dict[str, Any]) -> None:
|
||||
"""Emit a structured trade event.
|
||||
|
||||
Extracts key fields from *trade_data* and logs them as a single
|
||||
structured event for easy querying and dashboarding.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
logger:
|
||||
A bound structlog logger.
|
||||
trade_data:
|
||||
Dictionary with trade details (e.g. ``id``, ``asset``, ``direction``,
|
||||
``entry_price``, ``fill_price``, ``size``, ``pnl``, ``status``).
|
||||
"""
|
||||
logger.info(
|
||||
"trade_event",
|
||||
trade_id=trade_data.get("id"),
|
||||
asset=trade_data.get("asset"),
|
||||
direction=trade_data.get("direction"),
|
||||
token_id=trade_data.get("token_id"),
|
||||
entry_price=trade_data.get("entry_price"),
|
||||
fill_price=trade_data.get("fill_price"),
|
||||
size=trade_data.get("size"),
|
||||
fee=trade_data.get("fee"),
|
||||
pnl=trade_data.get("pnl"),
|
||||
status=trade_data.get("status"),
|
||||
edge=trade_data.get("edge"),
|
||||
)
|
||||
|
||||
|
||||
@contextmanager
|
||||
def log_timing(
|
||||
logger: structlog.stdlib.BoundLogger,
|
||||
operation: str,
|
||||
) -> Generator[None, None, None]:
|
||||
"""Context manager that logs the elapsed wall-clock time of *operation*.
|
||||
|
||||
Usage::
|
||||
|
||||
with log_timing(logger, "fetch_orderbook"):
|
||||
await fetch_orderbook()
|
||||
|
||||
Parameters
|
||||
----------
|
||||
logger:
|
||||
A bound structlog logger.
|
||||
operation:
|
||||
Human-readable label for the timed block.
|
||||
"""
|
||||
start = time.perf_counter()
|
||||
logger.debug("timing_start", operation=operation)
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
elapsed_ms = (time.perf_counter() - start) * 1000.0
|
||||
logger.info(
|
||||
"timing_end",
|
||||
operation=operation,
|
||||
elapsed_ms=round(elapsed_ms, 2),
|
||||
)
|
||||
116
src/utils/metrics.py
Normal file
116
src/utils/metrics.py
Normal file
@@ -0,0 +1,116 @@
|
||||
"""Metrics collector for the performance dashboard."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import time
|
||||
from collections import deque
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Optional
|
||||
|
||||
|
||||
@dataclass
|
||||
class TradeMetric:
|
||||
"""Single trade metric for time-series tracking."""
|
||||
timestamp: float
|
||||
asset: str
|
||||
direction: str
|
||||
timeframe: str
|
||||
entry_price: float
|
||||
size: int
|
||||
edge: float
|
||||
pnl: Optional[float] = None
|
||||
won: Optional[bool] = None
|
||||
|
||||
|
||||
class MetricsCollector:
|
||||
"""Collects and aggregates trading metrics for the dashboard.
|
||||
|
||||
Maintains rolling windows and aggregated stats for real-time display.
|
||||
"""
|
||||
|
||||
def __init__(self, max_history: int = 10000) -> None:
|
||||
self._trades: deque[TradeMetric] = deque(maxlen=max_history)
|
||||
self._pnl_series: deque[tuple[float, float]] = deque(maxlen=max_history)
|
||||
self._start_time = time.time()
|
||||
|
||||
# Running aggregates
|
||||
self.total_volume: float = 0.0
|
||||
self.total_fees: float = 0.0
|
||||
|
||||
def record_trade(self, metric: TradeMetric) -> None:
|
||||
self._trades.append(metric)
|
||||
self.total_volume += metric.entry_price * metric.size
|
||||
|
||||
def record_pnl(self, pnl: float) -> None:
|
||||
self._pnl_series.append((time.time(), pnl))
|
||||
|
||||
def record_fee(self, fee: float) -> None:
|
||||
self.total_fees += fee
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Aggregations
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def get_recent_trades(self, n: int = 50) -> list[TradeMetric]:
|
||||
return list(self._trades)[-n:]
|
||||
|
||||
def get_pnl_series(self) -> list[tuple[float, float]]:
|
||||
return list(self._pnl_series)
|
||||
|
||||
def get_hourly_stats(self) -> dict:
|
||||
"""Aggregate stats for the last hour."""
|
||||
cutoff = time.time() - 3600
|
||||
recent = [t for t in self._trades if t.timestamp > cutoff]
|
||||
|
||||
wins = sum(1 for t in recent if t.won is True)
|
||||
losses = sum(1 for t in recent if t.won is False)
|
||||
total_pnl = sum(t.pnl for t in recent if t.pnl is not None)
|
||||
|
||||
return {
|
||||
"trades": len(recent),
|
||||
"wins": wins,
|
||||
"losses": losses,
|
||||
"win_rate": wins / (wins + losses) * 100 if (wins + losses) > 0 else 0,
|
||||
"pnl": round(total_pnl, 2),
|
||||
}
|
||||
|
||||
def get_asset_breakdown(self) -> dict[str, dict]:
|
||||
"""PnL and trade count per asset."""
|
||||
breakdown: dict[str, dict] = {}
|
||||
for t in self._trades:
|
||||
if t.asset not in breakdown:
|
||||
breakdown[t.asset] = {"trades": 0, "wins": 0, "pnl": 0.0}
|
||||
breakdown[t.asset]["trades"] += 1
|
||||
if t.won is True:
|
||||
breakdown[t.asset]["wins"] += 1
|
||||
if t.pnl is not None:
|
||||
breakdown[t.asset]["pnl"] += t.pnl
|
||||
|
||||
for asset in breakdown:
|
||||
total = breakdown[asset]["trades"]
|
||||
wins = breakdown[asset]["wins"]
|
||||
breakdown[asset]["win_rate"] = round(wins / total * 100, 1) if total > 0 else 0
|
||||
breakdown[asset]["pnl"] = round(breakdown[asset]["pnl"], 2)
|
||||
|
||||
return breakdown
|
||||
|
||||
def get_uptime(self) -> float:
|
||||
"""Return uptime in seconds."""
|
||||
return time.time() - self._start_time
|
||||
|
||||
def get_summary(self) -> dict:
|
||||
total_trades = len(self._trades)
|
||||
wins = sum(1 for t in self._trades if t.won is True)
|
||||
losses = sum(1 for t in self._trades if t.won is False)
|
||||
total_pnl = sum(t.pnl for t in self._trades if t.pnl is not None)
|
||||
|
||||
return {
|
||||
"total_trades": total_trades,
|
||||
"wins": wins,
|
||||
"losses": losses,
|
||||
"win_rate": round(wins / (wins + losses) * 100, 1) if (wins + losses) > 0 else 0,
|
||||
"total_pnl": round(total_pnl, 2),
|
||||
"total_volume": round(self.total_volume, 2),
|
||||
"total_fees": round(self.total_fees, 2),
|
||||
"uptime_hours": round(self.get_uptime() / 3600, 2),
|
||||
}
|
||||
152
src/utils/telegram.py
Normal file
152
src/utils/telegram.py
Normal file
@@ -0,0 +1,152 @@
|
||||
"""Telegram notification bot for trade alerts and daily summaries."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from typing import Optional
|
||||
|
||||
import structlog
|
||||
|
||||
from src.config import NotificationConfig
|
||||
|
||||
log = structlog.get_logger()
|
||||
|
||||
|
||||
class TelegramNotifier:
|
||||
"""Sends trading notifications via Telegram Bot API.
|
||||
|
||||
Uses aiohttp directly (no python-telegram-bot dependency for async)
|
||||
for lightweight non-blocking notification delivery.
|
||||
"""
|
||||
|
||||
BASE_URL = "https://api.telegram.org/bot{token}/sendMessage"
|
||||
|
||||
def __init__(self, config: NotificationConfig) -> None:
|
||||
self.config = config
|
||||
self._enabled = config.telegram_enabled and bool(config.telegram_token) and bool(config.telegram_chat_id)
|
||||
self._session = None
|
||||
|
||||
if not self._enabled:
|
||||
log.info("telegram_disabled", reason="missing token or chat_id")
|
||||
|
||||
async def _get_session(self):
|
||||
if self._session is None:
|
||||
import aiohttp
|
||||
self._session = aiohttp.ClientSession()
|
||||
return self._session
|
||||
|
||||
async def close(self) -> None:
|
||||
if self._session:
|
||||
await self._session.close()
|
||||
self._session = None
|
||||
|
||||
async def send(self, message: str, parse_mode: str = "HTML") -> bool:
|
||||
"""Send a message to the configured Telegram chat."""
|
||||
if not self._enabled:
|
||||
return False
|
||||
|
||||
url = self.BASE_URL.format(token=self.config.telegram_token)
|
||||
payload = {
|
||||
"chat_id": self.config.telegram_chat_id,
|
||||
"text": message,
|
||||
"parse_mode": parse_mode,
|
||||
}
|
||||
|
||||
try:
|
||||
session = await self._get_session()
|
||||
async with session.post(url, json=payload, timeout=10) as resp:
|
||||
if resp.status == 200:
|
||||
return True
|
||||
else:
|
||||
body = await resp.text()
|
||||
log.warning("telegram_send_failed", status=resp.status, body=body[:200])
|
||||
return False
|
||||
except Exception:
|
||||
log.exception("telegram_send_error")
|
||||
return False
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Convenience methods
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def notify_trade(
|
||||
self,
|
||||
asset: str,
|
||||
direction: str,
|
||||
timeframe: str,
|
||||
price: float,
|
||||
size: int,
|
||||
edge: float,
|
||||
) -> None:
|
||||
"""Send a trade notification."""
|
||||
if not self.config.notify_on_trade:
|
||||
return
|
||||
|
||||
msg = (
|
||||
f"🔔 <b>Trade Signal</b>\n"
|
||||
f"Asset: <b>{asset}</b> | {direction}\n"
|
||||
f"Timeframe: {timeframe}\n"
|
||||
f"Price: {price:.2f} | Size: {size}\n"
|
||||
f"Edge: {edge:.2%}"
|
||||
)
|
||||
await self.send(msg)
|
||||
|
||||
async def notify_fill(
|
||||
self,
|
||||
asset: str,
|
||||
direction: str,
|
||||
fill_price: float,
|
||||
fill_size: int,
|
||||
trade_id: str,
|
||||
) -> None:
|
||||
"""Send a fill notification."""
|
||||
if not self.config.notify_on_trade:
|
||||
return
|
||||
|
||||
msg = (
|
||||
f"✅ <b>Order Filled</b>\n"
|
||||
f"Asset: {asset} | {direction}\n"
|
||||
f"Fill: {fill_price:.2f} × {fill_size}\n"
|
||||
f"Trade ID: <code>{trade_id}</code>"
|
||||
)
|
||||
await self.send(msg)
|
||||
|
||||
async def notify_daily_summary(
|
||||
self,
|
||||
date: str,
|
||||
total_trades: int,
|
||||
wins: int,
|
||||
losses: int,
|
||||
pnl: float,
|
||||
fees: float,
|
||||
volume: float,
|
||||
) -> None:
|
||||
"""Send daily summary."""
|
||||
if not self.config.notify_on_daily_summary:
|
||||
return
|
||||
|
||||
win_rate = wins / total_trades * 100 if total_trades > 0 else 0
|
||||
emoji = "📈" if pnl >= 0 else "📉"
|
||||
|
||||
msg = (
|
||||
f"{emoji} <b>Daily Summary — {date}</b>\n"
|
||||
f"Trades: {total_trades} (W:{wins} / L:{losses})\n"
|
||||
f"Win Rate: {win_rate:.1f}%\n"
|
||||
f"PnL: <b>${pnl:+.2f}</b>\n"
|
||||
f"Fees: ${fees:.2f}\n"
|
||||
f"Volume: ${volume:,.0f}"
|
||||
)
|
||||
await self.send(msg)
|
||||
|
||||
async def notify_error(self, error_msg: str) -> None:
|
||||
"""Send an error alert."""
|
||||
if not self.config.notify_on_error:
|
||||
return
|
||||
|
||||
msg = f"🚨 <b>Error Alert</b>\n<code>{error_msg[:500]}</code>"
|
||||
await self.send(msg)
|
||||
|
||||
async def notify_halt(self, reason: str) -> None:
|
||||
"""Send a trading halt alert."""
|
||||
msg = f"🛑 <b>TRADING HALTED</b>\nReason: {reason}"
|
||||
await self.send(msg)
|
||||
Reference in New Issue
Block a user