refactor: use Fetcher and add reference tickers

This commit is contained in:
Giulio De Pasquale 2025-10-18 10:55:36 +01:00
parent 600d6b76f6
commit c11c2bccf4
3 changed files with 152 additions and 98 deletions

View File

@ -1,68 +1,74 @@
from datetime import datetime, timedelta
from .taapi import TaapiClient
from typing import List
from .models import TickerOHLCV
import yfinance as yf
import pandas as pd
class Client:
def __init__(self, taapi_key: str):
self._taapi = TaapiClient(taapi_key)
class Fetcher:
@staticmethod
def ticker_data_for(ticker: str, date: datetime) -> TickerOHLCV | None:
# Set end date to next day to ensure we get the target date
start_date = date.strftime("%Y-%m-%d")
end_date = (date + timedelta(days=1)).strftime("%Y-%m-%d")
def ticker_data_for(
ticker: str, date: datetime, end_date: datetime | None = None
) -> List[TickerOHLCV]:
"""
Fetch OHLCV data for a ticker for a date range.
Args:
ticker: Stock ticker symbol
date: Start date (inclusive)
end_date: End date (inclusive). If None, fetches only the start date.
Returns:
List of TickerOHLCV records, one per trading day in the range.
Returns empty list if no data available.
"""
start_date_str = date.strftime("%Y-%m-%d")
if not end_date:
end_date_str = (date + timedelta(days=1)).strftime("%Y-%m-%d")
else:
end_date_str = (end_date + timedelta(days=1)).strftime("%Y-%m-%d")
try:
data = yf.download(
ticker,
start=start_date,
end=end_date,
start=start_date_str,
end=end_date_str,
auto_adjust=True,
progress=False,
)
if data.empty:
return None
return []
row = data.iloc[0]
results: List[TickerOHLCV] = []
open_price = (
float(row["Open"].iloc[0])
if isinstance(row["Open"], pd.Series)
else float(row["Open"])
)
high = (
float(row["High"].iloc[0])
if isinstance(row["High"], pd.Series)
else float(row["High"])
)
low = (
float(row["Low"].iloc[0])
if isinstance(row["Low"], pd.Series)
else float(row["Low"])
)
close = (
float(row["Close"].iloc[0])
if isinstance(row["Close"], pd.Series)
else float(row["Close"])
)
volume = (
int(row["Volume"].iloc[0])
if isinstance(row["Volume"], pd.Series)
else int(row["Volume"])
if "Volume" in row
else 0
)
for idx, row in data.iterrows():
# Extract datetime from index
if hasattr(idx, "to_pydatetime"):
row_date = idx.to_pydatetime()
elif isinstance(idx, datetime):
row_date = idx
else:
row_date = pd.to_datetime(idx).to_pydatetime()
# Extract values (handle both Series and scalar)
def safe_extract(value):
if isinstance(value, pd.Series):
return float(value.iloc[0])
return float(value)
open_price = safe_extract(row["Open"])
high = safe_extract(row["High"])
low = safe_extract(row["Low"])
close = safe_extract(row["Close"])
volume = int(safe_extract(row.get("Volume", 0)))
# Calculate average price
avg = (high + low) / 2.0
return TickerOHLCV(
ohlcv = TickerOHLCV(
ticker=ticker,
date=date,
date=row_date,
open=round(open_price, 2),
high=round(high, 2),
low=round(low, 2),
@ -70,8 +76,45 @@ class Client:
avg=round(avg, 2),
volume=volume,
)
results.append(ohlcv)
return results
except Exception as e:
print(f"Error fetching data for {ticker} on {start_date}: {str(e)}")
print(
f"Error fetching data for {ticker} from {start_date_str} to {end_date_str}: {str(e)}"
)
return []
return None
@staticmethod
def ticker_data_for_single_day(ticker: str, date: datetime) -> TickerOHLCV | None:
    """
    Fetch the OHLCV record for one day (kept for backward compatibility).

    Delegates to ``Fetcher.ticker_data_for`` with ``end_date=None`` and keeps
    only the first record of the list it returns.

    Args:
        ticker: Stock ticker symbol
        date: The specific date to look up

    Returns:
        The first TickerOHLCV record returned for that date, or None when the
        lookup produced no records
    """
    records = Fetcher.ticker_data_for(ticker, date, end_date=None)
    # An empty list means nothing was returned for the requested day.
    if not records:
        return None
    return records[0]
@staticmethod
def ticker_data_for_range(
    ticker: str, start_date: datetime, end_date: datetime
) -> List[TickerOHLCV]:
    """
    Fetch OHLCV records for an explicit date range.

    Thin wrapper that forwards its arguments unchanged to
    ``Fetcher.ticker_data_for``.

    Args:
        ticker: Stock ticker symbol
        start_date: First date of the range (inclusive)
        end_date: Last date of the range (inclusive)

    Returns:
        The list of TickerOHLCV records produced by
        ``Fetcher.ticker_data_for`` for the range
    """
    records = Fetcher.ticker_data_for(ticker, start_date, end_date)
    return records

View File

@ -50,6 +50,7 @@ class IndicatorService:
Returns:
IndicatorsData instance ready to save to database, or None if insufficient data
or if any calculated values are NaN
"""
# Fetch historical OHLCV data from database
start_date: datetime = target_date - timedelta(
@ -59,30 +60,40 @@ class IndicatorService:
ticker=ticker, start_date=start_date, end_date=target_date
)
# Verify we have enough data
if len(ohlcv_records) < 30:
# Verify we have enough data (minimum 40 trading days for MACD + ADX)
if len(ohlcv_records) < 40:
return None
# Convert to numpy arrays for TA-Lib
arrays: OHLCVArrays = self._convert_to_arrays(ohlcv_records)
# Calculate all indicator categories
momentum: MomentumIndicators = self._calculate_momentum(arrays)
volatility: VolatilityIndicators = self._calculate_volatility(arrays)
trend: TrendIndicators = self._calculate_trend(arrays)
volume: VolumeIndicators = self._calculate_volume(arrays)
momentum: MomentumIndicators | None = self._calculate_momentum(arrays)
volatility: VolatilityIndicators | None = self._calculate_volatility(arrays)
trend: TrendIndicators | None = self._calculate_trend(arrays)
volume: VolumeIndicators | None = self._calculate_volume(arrays)
support_resistance: SupportResistanceIndicators = (
self._calculate_support_resistance(arrays)
)
market_regime: MarketRegimeIndicators = self._calculate_market_regime(arrays)
market_regime: MarketRegimeIndicators | None = self._calculate_market_regime(
arrays
)
# Skip if any indicator group returned None (contains NaN values)
if (
momentum is None
or volatility is None
or trend is None
or volume is None
or market_regime is None
):
return None
# Build and return strongly-typed IndicatorsData
return IndicatorsData(
ticker=ticker,
date=target_date,
# Momentum (7 indicators)
rsi_14=momentum.rsi_14,
rsi_20=momentum.rsi_20,
macd_line=momentum.macd_line,
@ -90,40 +101,30 @@ class IndicatorService:
macd_histogram=momentum.macd_histogram,
stoch_k=momentum.stoch_k,
stoch_d=momentum.stoch_d,
# Volatility (6 indicators)
bb_upper=volatility.bb_upper,
bb_middle=volatility.bb_middle,
bb_lower=volatility.bb_lower,
bb_width=volatility.bb_width,
bb_percent=volatility.bb_percent,
atr_14=volatility.atr_14,
# Trend (4 indicators)
adx_14=trend.adx_14,
di_plus=trend.di_plus,
di_minus=trend.di_minus,
sar=trend.sar,
# Volume (3 indicators)
obv=volume.obv,
obv_sma_20=volume.obv_sma_20,
volume_roc_5=volume.volume_roc_5,
# Support/Resistance (6 indicators)
fib_236=support_resistance.fib_236,
fib_382=support_resistance.fib_382,
fib_618=support_resistance.fib_618,
pivot_point=support_resistance.pivot_point,
resistance_1=support_resistance.resistance_1,
support_1=support_resistance.support_1,
# Market Regime (2 indicators)
cci_20=market_regime.cci_20,
williams_r_14=market_regime.williams_r_14,
)
@ -180,13 +181,14 @@ class IndicatorService:
) -> int:
"""
Calculate and save indicators for multiple dates efficiently.
Uses upsert logic to handle existing records.
Args:
ticker: Stock ticker symbol
dates: List of dates to calculate indicators for
Returns:
Number of indicator records successfully saved
Number of indicator records successfully saved/updated
"""
indicators_list: List[IndicatorsData] = self.bulk_calculate_indicators(
ticker=ticker, dates=dates
@ -195,7 +197,14 @@ class IndicatorService:
if not indicators_list:
return 0
return self._crud.create_indicators_bulk(indicators_list)
# Use upsert for each indicator to avoid UNIQUE constraint violations
saved_count = 0
for indicators in indicators_list:
result = self._crud.upsert_indicators(indicators)
if result:
saved_count += 1
return saved_count
# ========================================================================
# PRIVATE: Data Conversion

View File

@ -7,8 +7,7 @@ from paperone.utils import (
get_last_n_trading_days,
)
from paperone.database import TradingDataCRUD
from os import environ
from paperone.client import Client
from paperone.client import Fetcher
from rich.progress import track
from datetime import datetime
@ -16,46 +15,49 @@ load_dotenv()
DB_FILE = "trading_data.db"
def main() -> NoReturn:
api_key = environ.get("API_KEY")
if not api_key:
print("API_KEY not set")
exit(0)
client = Client(api_key)
fetcher = Fetcher()
crud = TradingDataCRUD(f"sqlite:///{DB_FILE}")
date = datetime.now()
days_range = 360
days_range = 360 * 10
tickers = [
"AAPL",
"ARM",
"^VIX",
"AMD",
"GOOG",
"META",
"MNDY",
"MSFT",
"NVDA",
"VOO",
"VT",
]
# Loop through tickers and fetch data
for ticker in tickers:
reference_tickers = [
"VOO", # US
"VT", # global
"GLD", # gold
"^VIX", # volatility
"VNQ", # real estate
"DBC", # commodities
"VGK", # EU
"VPL", # Asia
"VWO", # Emerging
"BND", # Global bonds
]
for ticker in tickers + reference_tickers:
trading_days = get_last_n_trading_days(date, days_range)
for day in track(trading_days, description=f"{ticker}"):
existing = crud.get_ohlcv(ticker, day)
if existing:
continue
start_date = trading_days[0]
end_date = trading_days[-1]
ticker_data = client.ticker_data_for(ticker, day)
if not ticker_data:
continue
all_ohlcv_data = fetcher.ticker_data_for(ticker, start_date, end_date)
for ohlcv in track(all_ohlcv_data, description=f"{ticker}"):
existing = crud.get_ohlcv(ticker, ohlcv.date)
crud.create_ohlcv(ticker_data)
if not existing:
crud.create_ohlcv(ohlcv)
exit(0)