deploy: 2026-03-20 07:49
This commit is contained in:
114
backtest/data_loader.py
Normal file
114
backtest/data_loader.py
Normal file
@@ -0,0 +1,114 @@
|
||||
"""Historical data loader for backtesting.
|
||||
|
||||
Fetches OHLCV data from exchanges or loads from CSV files.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
import pandas as pd
|
||||
from loguru import logger
|
||||
|
||||
from execution.exchange_client import ExchangeClient
|
||||
|
||||
|
||||
class DataLoader:
|
||||
"""Load historical OHLCV data for backtesting."""
|
||||
|
||||
def __init__(self, exchange_client: ExchangeClient | None = None):
|
||||
self._client = exchange_client
|
||||
|
||||
async def fetch_from_exchange(
|
||||
self,
|
||||
symbol: str,
|
||||
timeframe: str = "1h",
|
||||
since: str | None = None,
|
||||
limit: int = 1000,
|
||||
) -> pd.DataFrame:
|
||||
"""Fetch historical data from an exchange.
|
||||
|
||||
Args:
|
||||
symbol: Trading pair (e.g., "BTC/USDT").
|
||||
timeframe: Candle timeframe.
|
||||
since: ISO date string for start (e.g., "2025-01-01").
|
||||
limit: Max candles to fetch.
|
||||
"""
|
||||
if self._client is None:
|
||||
raise RuntimeError("ExchangeClient required for live data fetch")
|
||||
|
||||
if not await self._client.is_connected():
|
||||
await self._client.connect()
|
||||
|
||||
since_ts = None
|
||||
if since:
|
||||
since_ts = int(pd.Timestamp(since).timestamp() * 1000)
|
||||
|
||||
# Fetch in chunks if needed
|
||||
all_data: list = []
|
||||
remaining = limit
|
||||
current_since = since_ts
|
||||
|
||||
while remaining > 0:
|
||||
batch_limit = min(remaining, 500)
|
||||
df = await self._client.fetch_ohlcv(
|
||||
symbol, timeframe, since=current_since, limit=batch_limit
|
||||
)
|
||||
if df.empty:
|
||||
break
|
||||
all_data.append(df)
|
||||
remaining -= len(df)
|
||||
# Move since to after the last candle
|
||||
current_since = int(df.index[-1].timestamp() * 1000) + 1
|
||||
|
||||
if len(df) < batch_limit:
|
||||
break
|
||||
|
||||
if not all_data:
|
||||
return pd.DataFrame()
|
||||
|
||||
result = pd.concat(all_data)
|
||||
result = result[~result.index.duplicated(keep="last")]
|
||||
logger.info("Loaded {} candles for {} {}", len(result), symbol, timeframe)
|
||||
return result.sort_index()
|
||||
|
||||
@staticmethod
|
||||
def load_from_csv(file_path: str) -> pd.DataFrame:
|
||||
"""Load OHLCV data from a CSV file.
|
||||
|
||||
Expected columns: timestamp (or date), open, high, low, close, volume
|
||||
"""
|
||||
path = Path(file_path)
|
||||
if not path.exists():
|
||||
raise FileNotFoundError(f"CSV file not found: {file_path}")
|
||||
|
||||
df = pd.read_csv(file_path)
|
||||
|
||||
# Detect timestamp column
|
||||
ts_col = None
|
||||
for col in ["timestamp", "date", "datetime", "time"]:
|
||||
if col in df.columns:
|
||||
ts_col = col
|
||||
break
|
||||
|
||||
if ts_col:
|
||||
df[ts_col] = pd.to_datetime(df[ts_col])
|
||||
df.set_index(ts_col, inplace=True)
|
||||
|
||||
required = {"open", "high", "low", "close", "volume"}
|
||||
missing = required - set(df.columns)
|
||||
if missing:
|
||||
raise ValueError(f"CSV missing required columns: {missing}")
|
||||
|
||||
logger.info("Loaded {} candles from {}", len(df), file_path)
|
||||
return df
|
||||
|
||||
@staticmethod
|
||||
def save_to_csv(df: pd.DataFrame, file_path: str) -> None:
|
||||
"""Save OHLCV DataFrame to CSV."""
|
||||
path = Path(file_path)
|
||||
path.parent.mkdir(parents=True, exist_ok=True)
|
||||
df.to_csv(file_path)
|
||||
logger.info("Saved {} candles to {}", len(df), file_path)
|
||||
Reference in New Issue
Block a user