115 lines
3.4 KiB
Python
115 lines
3.4 KiB
Python
|
|
"""Historical data loader for backtesting.
|
||
|
|
|
||
|
|
Fetches OHLCV data from exchanges or loads from CSV files.
|
||
|
|
"""
|
||
|
|
|
||
|
|
from __future__ import annotations
|
||
|
|
|
||
|
|
import asyncio
|
||
|
|
from pathlib import Path
|
||
|
|
from typing import Optional
|
||
|
|
|
||
|
|
import pandas as pd
|
||
|
|
from loguru import logger
|
||
|
|
|
||
|
|
from execution.exchange_client import ExchangeClient
|
||
|
|
|
||
|
|
|
||
|
|
class DataLoader:
    """Load historical OHLCV data for backtesting.

    Candles can be fetched from an exchange through an injected client, or
    read from and written to CSV files.
    """

    # Upper bound on the number of candles requested per exchange call.
    _MAX_BATCH_SIZE = 500
    # Column names recognised as the timestamp column, in priority order.
    _TIMESTAMP_COLUMNS = ("timestamp", "date", "datetime", "time")
    # Columns every OHLCV CSV must provide.
    _REQUIRED_COLUMNS = ("open", "high", "low", "close", "volume")

    def __init__(self, exchange_client: ExchangeClient | None = None):
        """Create a loader.

        Args:
            exchange_client: Client used by ``fetch_from_exchange``. May be
                omitted when only the CSV helpers are needed.
        """
        self._client = exchange_client

    async def fetch_from_exchange(
        self,
        symbol: str,
        timeframe: str = "1h",
        since: str | None = None,
        limit: int = 1000,
    ) -> pd.DataFrame:
        """Fetch historical data from an exchange.

        Requests are issued in batches of at most ``_MAX_BATCH_SIZE`` candles.
        After each batch the cursor moves to one millisecond past the newest
        candle received so far. When ``since`` is omitted, the first request
        leaves the starting point to the client.

        Args:
            symbol: Trading pair (e.g., "BTC/USDT").
            timeframe: Candle timeframe.
            since: ISO date string for start (e.g., "2025-01-01").
            limit: Max candles to fetch.

        Returns:
            The fetched candles with duplicate index entries removed (last
            occurrence kept) and sorted by index, or an empty DataFrame when
            the client returned no data.

        Raises:
            RuntimeError: If the loader was created without an exchange client.
        """
        if self._client is None:
            raise RuntimeError("ExchangeClient required for live data fetch")

        if not await self._client.is_connected():
            await self._client.connect()

        current_since = None
        if since:
            current_since = int(pd.Timestamp(since).timestamp() * 1000)

        batches: list[pd.DataFrame] = []
        remaining = limit

        while remaining > 0:
            batch_limit = min(remaining, self._MAX_BATCH_SIZE)
            df = await self._client.fetch_ohlcv(
                symbol, timeframe, since=current_since, limit=batch_limit
            )
            if df.empty:
                break
            batches.append(df)
            remaining -= len(df)

            # Advance past the newest candle of the batch. Using the maximum
            # (not the last row) keeps the cursor moving forward even if the
            # client hands back rows in descending order.
            next_since = int(df.index.max().timestamp() * 1000) + 1
            if current_since is not None and next_since <= current_since:
                # The cursor did not advance; asking again would only return
                # the same candles.
                break
            current_since = next_since

            # A short batch means the exchange has nothing more to offer.
            if len(df) < batch_limit:
                break

        if not batches:
            return pd.DataFrame()

        result = pd.concat(batches)
        result = result[~result.index.duplicated(keep="last")]
        logger.info("Loaded {} candles for {} {}", len(result), symbol, timeframe)
        return result.sort_index()

    @staticmethod
    def _parse_timestamps(values: pd.Series) -> pd.Series:
        """Convert a timestamp column to datetimes.

        Numeric columns are treated as Unix epoch values. Their unit is
        inferred from the magnitude of the largest value (heuristic: seconds
        below 1e11, milliseconds below 1e14, microseconds below 1e17,
        nanoseconds otherwise). Everything else is handed to
        ``pd.to_datetime`` unchanged.
        """
        is_epoch = pd.api.types.is_numeric_dtype(
            values
        ) and not pd.api.types.is_bool_dtype(values)
        if not is_epoch:
            return pd.to_datetime(values)

        magnitude = values.abs().max()
        if magnitude < 1e11:
            unit = "s"
        elif magnitude < 1e14:
            unit = "ms"
        elif magnitude < 1e17:
            unit = "us"
        else:
            unit = "ns"
        return pd.to_datetime(values, unit=unit)

    @staticmethod
    def load_from_csv(file_path: str) -> pd.DataFrame:
        """Load OHLCV data from a CSV file.

        Expected columns: timestamp (or date), open, high, low, close, volume

        The first column named ``timestamp``, ``date``, ``datetime`` or
        ``time`` is parsed to datetimes (numeric values are read as Unix
        epochs) and becomes the index; rows are then sorted chronologically.
        If no such column exists a warning is logged and the default index is
        kept.

        Args:
            file_path: Path of the CSV file to read.

        Returns:
            The loaded candles.

        Raises:
            FileNotFoundError: If ``file_path`` does not exist.
            ValueError: If any of the required OHLCV columns is missing.
        """
        path = Path(file_path)
        if not path.exists():
            raise FileNotFoundError(f"CSV file not found: {file_path}")

        df = pd.read_csv(path)

        missing = set(DataLoader._REQUIRED_COLUMNS) - set(df.columns)
        if missing:
            raise ValueError(f"CSV missing required columns: {missing}")

        ts_col = next(
            (col for col in DataLoader._TIMESTAMP_COLUMNS if col in df.columns),
            None,
        )
        if ts_col is None:
            logger.warning(
                "No timestamp column ({}) found in {}; keeping the default index",
                ", ".join(DataLoader._TIMESTAMP_COLUMNS),
                file_path,
            )
        else:
            df[ts_col] = DataLoader._parse_timestamps(df[ts_col])
            # Stable sort so rows sharing a timestamp keep their file order.
            df = df.set_index(ts_col).sort_index(kind="mergesort")

        logger.info("Loaded {} candles from {}", len(df), file_path)
        return df

    @staticmethod
    def save_to_csv(df: pd.DataFrame, file_path: str) -> None:
        """Save OHLCV DataFrame to CSV.

        Missing parent directories are created. An unnamed ``DatetimeIndex``
        is written under the header ``timestamp`` so that ``load_from_csv``
        can find it again; a named index keeps its own name.

        Args:
            df: Candles to write.
            file_path: Destination path of the CSV file.
        """
        path = Path(file_path)
        path.parent.mkdir(parents=True, exist_ok=True)

        index_label = None  # None lets pandas use the index's own name.
        if (
            isinstance(df.index, pd.DatetimeIndex)
            and df.index.name is None
            and "timestamp" not in df.columns
        ):
            index_label = "timestamp"

        df.to_csv(path, index_label=index_label)
        logger.info("Saved {} candles to {}", len(df), file_path)
|