Files
crypto_news/backtest/data_loader.py

115 lines
3.4 KiB
Python
Raw Permalink Normal View History

2026-03-20 07:49:42 +09:00
"""Historical data loader for backtesting.
Fetches OHLCV data from exchanges or loads from CSV files.
"""
from __future__ import annotations
import asyncio
from pathlib import Path
from typing import Optional
import pandas as pd
from loguru import logger
from execution.exchange_client import ExchangeClient
class DataLoader:
    """Load historical OHLCV data for backtesting.

    Data comes either from a live exchange through an ``ExchangeClient`` or
    from CSV files (for example ones written by ``save_to_csv``).
    """

    # Candidate names for the timestamp column in CSV input, checked in order.
    _TIMESTAMP_COLUMNS = ("timestamp", "date", "datetime", "time")
    # Columns every OHLCV frame must provide.
    _REQUIRED_COLUMNS = frozenset({"open", "high", "low", "close", "volume"})
    # Seconds per timeframe unit, using the common exchange suffixes:
    # "m" is minute, "M" is a month approximated as 30 days, "y" is 365 days.
    _TIMEFRAME_UNIT_SECONDS = {
        "s": 1,
        "m": 60,
        "h": 3600,
        "d": 86400,
        "w": 604800,
        "M": 2592000,
        "y": 31536000,
    }

    def __init__(self, exchange_client: ExchangeClient | None = None):
        """Create a loader.

        Args:
            exchange_client: Client used by ``fetch_from_exchange``; may be
                omitted when only CSV loading/saving is needed.
        """
        self._client = exchange_client

    async def fetch_from_exchange(
        self,
        symbol: str,
        timeframe: str = "1h",
        since: str | None = None,
        limit: int = 1000,
        batch_size: int = 500,
    ) -> pd.DataFrame:
        """Fetch historical data from an exchange.

        Candles are requested in batches of at most ``batch_size`` and
        paginated forward in time until ``limit`` candles have been collected
        or the exchange returns a short/empty batch.

        Args:
            symbol: Trading pair (e.g., "BTC/USDT").
            timeframe: Candle timeframe (e.g., "1m", "1h", "1d").
            since: ISO date string for start (e.g., "2025-01-01"); naive
                values are converted with ``pd.Timestamp.timestamp`` (UTC).
                When omitted, the most recent ~``limit`` candles are fetched.
            limit: Max candles to fetch.
            batch_size: Max candles requested per exchange call.

        Returns:
            DataFrame sorted by its index with duplicate timestamps removed
            (last occurrence kept); empty if the exchange returned nothing.

        Raises:
            RuntimeError: If the loader has no ExchangeClient.
            ValueError: If ``batch_size`` is not positive.
        """
        if self._client is None:
            raise RuntimeError("ExchangeClient required for live data fetch")
        if batch_size <= 0:
            raise ValueError(f"batch_size must be positive, got {batch_size}")
        if not await self._client.is_connected():
            await self._client.connect()

        current_since: int | None = None
        if since:
            current_since = int(pd.Timestamp(since).timestamp() * 1000)
        elif limit > batch_size:
            # Without a start time the exchange only returns the newest
            # candles, and paginating forward from the newest candle yields
            # nothing more. Anchor the start so ``limit`` candles end now.
            try:
                timeframe_ms = self._timeframe_to_ms(timeframe)
            except ValueError:
                logger.warning(
                    "Cannot parse timeframe {}; fetching only the most recent candles",
                    timeframe,
                )
            else:
                now_ms = int(pd.Timestamp.now(tz="UTC").timestamp() * 1000)
                current_since = now_ms - limit * timeframe_ms

        batches: list[pd.DataFrame] = []
        remaining = limit
        while remaining > 0:
            batch_limit = min(remaining, batch_size)
            df = await self._client.fetch_ohlcv(
                symbol, timeframe, since=current_since, limit=batch_limit
            )
            if df.empty:
                break
            batches.append(df)
            remaining -= len(df)
            # Resume 1 ms after the last candle so it is not fetched twice.
            current_since = int(df.index[-1].timestamp() * 1000) + 1
            # A short batch means the exchange has no more data in range.
            if len(df) < batch_limit:
                break

        if not batches:
            return pd.DataFrame()
        result = pd.concat(batches)
        result = result[~result.index.duplicated(keep="last")]
        logger.info("Loaded {} candles for {} {}", len(result), symbol, timeframe)
        return result.sort_index()

    @classmethod
    def _timeframe_to_ms(cls, timeframe: str) -> int:
        """Convert a timeframe string such as "15m" or "4h" to milliseconds.

        Raises:
            ValueError: If the unit suffix is unknown or the amount is not a
                positive integer.
        """
        amount_text, unit = timeframe[:-1], timeframe[-1:]
        unit_seconds = cls._TIMEFRAME_UNIT_SECONDS.get(unit)
        if unit_seconds is None or not amount_text.isdecimal():
            raise ValueError(f"Unsupported timeframe: {timeframe!r}")
        amount = int(amount_text)
        if amount <= 0:
            raise ValueError(f"Unsupported timeframe: {timeframe!r}")
        return amount * unit_seconds * 1000

    @staticmethod
    def load_from_csv(file_path: str) -> pd.DataFrame:
        """Load OHLCV data from a CSV file.

        Expected columns: timestamp (or date/datetime/time), open, high, low,
        close, volume. These names are matched case-insensitively. When a
        timestamp column exists it becomes the index, which is sorted
        ascending with duplicate timestamps removed (last occurrence kept).

        Raises:
            FileNotFoundError: If ``file_path`` does not exist.
            ValueError: If a required OHLCV column is missing.
        """
        path = Path(file_path)
        if not path.exists():
            raise FileNotFoundError(f"CSV file not found: {file_path}")
        df = DataLoader._normalize_ohlcv(pd.read_csv(path))
        logger.info("Loaded {} candles from {}", len(df), file_path)
        return df

    @classmethod
    def _normalize_ohlcv(cls, df: pd.DataFrame) -> pd.DataFrame:
        """Return ``df`` indexed by time and validated as OHLCV (input untouched).

        Column names that match a known timestamp/OHLCV name after stripping
        whitespace and lowercasing are renamed to that lowercase name, unless
        the exact lowercase name already exists. Other columns are kept as-is.

        Raises:
            ValueError: If a required OHLCV column is missing.
        """
        known = set(cls._TIMESTAMP_COLUMNS) | cls._REQUIRED_COLUMNS
        renames: dict = {}
        for col in df.columns:
            canonical = str(col).strip().lower()
            if (
                canonical in known
                and col != canonical
                and canonical not in df.columns
                and canonical not in renames.values()
            ):
                renames[col] = canonical
        if renames:
            df = df.rename(columns=renames)

        ts_col = next((c for c in cls._TIMESTAMP_COLUMNS if c in df.columns), None)
        if ts_col is not None:
            # assign() returns a new frame, so the caller's frame is not mutated.
            df = df.assign(**{ts_col: pd.to_datetime(df[ts_col])}).set_index(ts_col)
            # Match fetch_from_exchange: one row per timestamp, chronological.
            df = df[~df.index.duplicated(keep="last")].sort_index()

        missing = {col for col in cls._REQUIRED_COLUMNS if col not in df.columns}
        if missing:
            raise ValueError(f"CSV missing required columns: {missing}")
        return df

    @staticmethod
    def save_to_csv(df: pd.DataFrame, file_path: str) -> None:
        """Save OHLCV DataFrame to CSV, creating parent directories as needed.

        An unnamed DatetimeIndex is written under a "timestamp" header;
        otherwise pandas would emit an empty header that reads back as
        "Unnamed: 0" and ``load_from_csv`` would not recognise it.
        """
        path = Path(file_path)
        path.parent.mkdir(parents=True, exist_ok=True)
        index_label = None
        if isinstance(df.index, pd.DatetimeIndex) and df.index.name is None:
            index_label = "timestamp"
        df.to_csv(path, index_label=index_label)
        logger.info("Saved {} candles to {}", len(df), file_path)