"""Risk management engine.

Controls position sizing, daily loss limits, drawdown monitoring, and trade
approval to protect capital.
"""

from __future__ import annotations

from dataclasses import dataclass, field
from datetime import date, datetime
from typing import Dict, List, Optional

from loguru import logger

from config import settings


# Upper bound on the per-trade risk fraction used by ``approve_trade``,
# applied regardless of how ``max_risk_per_trade`` is configured.
_HARD_RISK_CAP = 0.05


@dataclass
class RiskApproval:
    """Result of a trade approval check."""

    approved: bool
    reason: str = ""  # rejection reason; empty when approved
    position_size: float = 0.0  # approved size in units; 0.0 when rejected
    risk_amount: float = 0.0  # loss if the stop loss is hit at position_size; 0.0 when rejected


class RiskManager:
    """Central risk management engine.

    Each limit comes from the matching constructor argument, or from
    ``config.settings`` when that argument is ``None``:

    - Per-trade risk fraction (``MAX_RISK_PER_TRADE``), additionally
      hard-capped at 5% by :meth:`approve_trade`.
    - Daily max loss as a fraction of balance (``MAX_DAILY_LOSS``).
    - Max concurrent positions (``MAX_CONCURRENT_POSITIONS``).
    - Max drawdown from peak equity (``MAX_DRAWDOWN``). New trades are
      rejected once it is breached. :meth:`check_drawdown` reports the breach,
      but the caller is responsible for invoking :meth:`emergency_stop`.
    - Max leverage (``MAX_LEVERAGE``) is stored on the instance but is not
      enforced anywhere in this class.

    NOTE(review): earlier docs quoted defaults of 2% / 5% / 3 / 3x / 15%;
    confirm those against ``config.settings``.
    """

    def __init__(
        self,
        max_risk_per_trade: float | None = None,
        max_daily_loss: float | None = None,
        max_concurrent_positions: int | None = None,
        max_leverage: int | None = None,
        max_drawdown: float | None = None,
    ):
        """Create a risk manager.

        Any limit passed as ``None`` is read from ``settings``. Explicit falsy
        values such as ``0`` are honoured instead of being replaced by the
        settings default.
        """
        self.max_risk_per_trade = (
            settings.MAX_RISK_PER_TRADE
            if max_risk_per_trade is None
            else max_risk_per_trade
        )
        self.max_daily_loss = (
            settings.MAX_DAILY_LOSS if max_daily_loss is None else max_daily_loss
        )
        self.max_concurrent_positions = (
            settings.MAX_CONCURRENT_POSITIONS
            if max_concurrent_positions is None
            else max_concurrent_positions
        )
        self.max_leverage = (
            settings.MAX_LEVERAGE if max_leverage is None else max_leverage
        )
        self.max_drawdown = (
            settings.MAX_DRAWDOWN if max_drawdown is None else max_drawdown
        )

        # Internal state
        self._daily_pnl: Dict[str, float] = {}  # ISO date string -> cumulative realised PnL
        self._open_positions: int = 0  # maintained by on_position_opened / on_position_closed
        self._peak_equity: float = 0.0  # highest equity passed to update_equity
        self._is_stopped: bool = False  # set by emergency_stop, cleared by reset_emergency

    # ------------------------------------------------------------------
    # Trade approval
    # ------------------------------------------------------------------

    def approve_trade(
        self,
        entry_price: float,
        stop_loss: float,
        balance: float,
        current_open_positions: int | None = None,
    ) -> RiskApproval:
        """Decide whether a new trade is allowed.

        Checks, in order:
        1. Bot not in emergency-stop state
        2. Daily loss limit not exceeded
        3. Concurrent position limit not exceeded
        4. Drawdown limit not exceeded
        5. Calculated position size is positive

        Args:
            entry_price: Planned entry price.
            stop_loss: Planned stop-loss price.
            balance: Current account balance. Also used as the equity value
                for the drawdown check.
            current_open_positions: Open-position count to check against the
                limit. Defaults to the internally tracked counter.

        Returns:
            A ``RiskApproval``. On rejection, ``approved`` is False and
            ``reason`` says why. On approval, ``risk_amount`` is the loss taken
            if the stop is hit at the approved size. This can be less than
            ``balance * risk_pct`` when the size was capped.
        """
        if self._is_stopped:
            return RiskApproval(False, "Bot is in emergency stop mode")

        # Daily loss check: only a net loss for today counts, measured as a
        # fraction of the *current* balance.
        daily = self.get_daily_pnl()
        if daily < 0 and abs(daily) >= balance * self.max_daily_loss:
            return RiskApproval(False, f"Daily loss limit reached: {daily:.2f}")

        # Concurrent positions
        open_pos = (
            current_open_positions
            if current_open_positions is not None
            else self._open_positions
        )
        if open_pos >= self.max_concurrent_positions:
            return RiskApproval(
                False,
                f"Max concurrent positions reached: {open_pos}/{self.max_concurrent_positions}",
            )

        # Drawdown from peak equity (skipped until a peak has been recorded)
        drawdown = self._drawdown_from_peak(balance)
        if drawdown is not None and drawdown >= self.max_drawdown:
            return RiskApproval(
                False, f"Max drawdown reached: {drawdown:.2%} >= {self.max_drawdown:.2%}"
            )

        # Position sizing
        risk_pct = min(self.max_risk_per_trade, _HARD_RISK_CAP)
        position_size = self.calculate_position_size(
            balance, entry_price, stop_loss, risk_pct
        )
        if position_size <= 0:
            return RiskApproval(False, "Calculated position size is zero or negative")

        # Report the risk actually taken. calculate_position_size may have
        # capped the size, in which case balance * risk_pct would overstate it.
        risk_amount = min(
            balance * risk_pct, position_size * abs(entry_price - stop_loss)
        )

        logger.info(
            "Trade APPROVED: size={:.6f}, risk={:.2f} ({:.1%})",
            position_size,
            risk_amount,
            risk_pct,
        )
        return RiskApproval(
            approved=True,
            position_size=position_size,
            risk_amount=risk_amount,
        )

    # ------------------------------------------------------------------
    # Position sizing
    # ------------------------------------------------------------------

    def calculate_position_size(
        self,
        balance: float,
        entry_price: float,
        stop_loss: float,
        risk_pct: float | None = None,
    ) -> float:
        """Calculate position size based on risk percentage.

        Formula: size = (balance * risk_pct) / |entry - stop_loss|

        The result is then capped so its notional value (size * entry_price)
        does not exceed ``balance / max_concurrent_positions``. This leaves
        room for the maximum number of concurrent positions.

        Args:
            balance: Account balance to size against.
            entry_price: Planned entry price.
            stop_loss: Planned stop-loss price.
            risk_pct: Fraction of balance to risk. ``None`` means
                ``self.max_risk_per_trade``. An explicit ``0`` yields size 0.

        Returns:
            The position size rounded to 8 decimal places. Returns 0.0 when
            the stop equals the entry, the entry price is 0, or
            ``max_concurrent_positions`` is not positive.
        """
        risk = self.max_risk_per_trade if risk_pct is None else risk_pct
        price_risk = abs(entry_price - stop_loss)
        # Degenerate inputs would divide by zero below; report "no position".
        if price_risk == 0 or entry_price == 0 or self.max_concurrent_positions <= 0:
            return 0.0

        risk_amount = balance * risk
        size = risk_amount / price_risk

        # Cap so that total cost doesn't exceed this position's share of the
        # balance (the balance is reserved across concurrent positions).
        max_per_position = balance / self.max_concurrent_positions
        if size * entry_price > max_per_position:
            size = max_per_position / entry_price

        return round(size, 8)

    # ------------------------------------------------------------------
    # PnL tracking
    # ------------------------------------------------------------------

    def update_daily_pnl(self, pnl: float) -> None:
        """Add realised *pnl* to today's running total (local calendar date)."""
        today = date.today().isoformat()
        self._daily_pnl[today] = self._daily_pnl.get(today, 0.0) + pnl
        logger.debug("Daily PnL updated: {} = {:.2f}", today, self._daily_pnl[today])

    def update_equity(self, equity: float) -> None:
        """Raise the recorded peak equity if *equity* exceeds it (never lowers it)."""
        if equity > self._peak_equity:
            self._peak_equity = equity

    def get_daily_pnl(self) -> float:
        """Return today's cumulative realised PnL (0.0 if nothing recorded)."""
        today = date.today().isoformat()
        return self._daily_pnl.get(today, 0.0)

    # ------------------------------------------------------------------
    # Drawdown
    # ------------------------------------------------------------------

    def _drawdown_from_peak(self, equity: float) -> Optional[float]:
        """Return the fractional drop of *equity* below peak, or None if no peak yet."""
        if self._peak_equity <= 0:
            return None
        return (self._peak_equity - equity) / self._peak_equity

    def check_drawdown(self, current_equity: float) -> bool:
        """Return True if max drawdown has been breached.

        Always False until ``update_equity`` has recorded a positive peak.
        This only reports the breach; it does not call ``emergency_stop``.
        """
        drawdown = self._drawdown_from_peak(current_equity)
        return drawdown is not None and drawdown >= self.max_drawdown

    # ------------------------------------------------------------------
    # Emergency stop
    # ------------------------------------------------------------------

    def emergency_stop(self) -> None:
        """Activate emergency stop -- ``approve_trade`` rejects every trade until reset."""
        self._is_stopped = True
        logger.critical("EMERGENCY STOP activated")

    def reset_emergency(self) -> None:
        """Clear emergency stop state."""
        self._is_stopped = False
        logger.warning("Emergency stop cleared")

    @property
    def is_stopped(self) -> bool:
        """Whether emergency stop is currently active."""
        return self._is_stopped

    # ------------------------------------------------------------------
    # Position tracking helpers
    # ------------------------------------------------------------------

    def on_position_opened(self) -> None:
        """Increment the internal open-position counter."""
        self._open_positions += 1

    def on_position_closed(self) -> None:
        """Decrement the internal open-position counter, never below zero."""
        self._open_positions = max(0, self._open_positions - 1)

    @property
    def open_position_count(self) -> int:
        """Internally tracked number of open positions."""
        return self._open_positions