refactor: use Fetcher and add reference tickers
This commit is contained in:
parent
600d6b76f6
commit
c11c2bccf4
@ -1,77 +1,120 @@
|
||||
from datetime import datetime, timedelta
|
||||
from .taapi import TaapiClient
|
||||
from typing import List
|
||||
from .models import TickerOHLCV
|
||||
import yfinance as yf
|
||||
import pandas as pd
|
||||
|
||||
|
||||
class Client:
|
||||
def __init__(self, taapi_key: str):
|
||||
self._taapi = TaapiClient(taapi_key)
|
||||
|
||||
class Fetcher:
|
||||
@staticmethod
|
||||
def ticker_data_for(ticker: str, date: datetime) -> TickerOHLCV | None:
|
||||
# Set end date to next day to ensure we get the target date
|
||||
start_date = date.strftime("%Y-%m-%d")
|
||||
end_date = (date + timedelta(days=1)).strftime("%Y-%m-%d")
|
||||
def ticker_data_for(
|
||||
ticker: str, date: datetime, end_date: datetime | None = None
|
||||
) -> List[TickerOHLCV]:
|
||||
"""
|
||||
Fetch OHLCV data for a ticker for a date range.
|
||||
|
||||
Args:
|
||||
ticker: Stock ticker symbol
|
||||
date: Start date (inclusive)
|
||||
end_date: End date (inclusive). If None, fetches only the start date.
|
||||
|
||||
Returns:
|
||||
List of TickerOHLCV records, one per trading day in the range.
|
||||
Returns empty list if no data available.
|
||||
"""
|
||||
start_date_str = date.strftime("%Y-%m-%d")
|
||||
|
||||
if not end_date:
|
||||
end_date_str = (date + timedelta(days=1)).strftime("%Y-%m-%d")
|
||||
else:
|
||||
end_date_str = (end_date + timedelta(days=1)).strftime("%Y-%m-%d")
|
||||
|
||||
try:
|
||||
data = yf.download(
|
||||
ticker,
|
||||
start=start_date,
|
||||
end=end_date,
|
||||
start=start_date_str,
|
||||
end=end_date_str,
|
||||
auto_adjust=True,
|
||||
progress=False,
|
||||
)
|
||||
|
||||
if data.empty:
|
||||
return None
|
||||
return []
|
||||
|
||||
row = data.iloc[0]
|
||||
results: List[TickerOHLCV] = []
|
||||
|
||||
open_price = (
|
||||
float(row["Open"].iloc[0])
|
||||
if isinstance(row["Open"], pd.Series)
|
||||
else float(row["Open"])
|
||||
)
|
||||
high = (
|
||||
float(row["High"].iloc[0])
|
||||
if isinstance(row["High"], pd.Series)
|
||||
else float(row["High"])
|
||||
)
|
||||
low = (
|
||||
float(row["Low"].iloc[0])
|
||||
if isinstance(row["Low"], pd.Series)
|
||||
else float(row["Low"])
|
||||
)
|
||||
close = (
|
||||
float(row["Close"].iloc[0])
|
||||
if isinstance(row["Close"], pd.Series)
|
||||
else float(row["Close"])
|
||||
)
|
||||
volume = (
|
||||
int(row["Volume"].iloc[0])
|
||||
if isinstance(row["Volume"], pd.Series)
|
||||
else int(row["Volume"])
|
||||
if "Volume" in row
|
||||
else 0
|
||||
)
|
||||
for idx, row in data.iterrows():
|
||||
# Extract datetime from index
|
||||
if hasattr(idx, "to_pydatetime"):
|
||||
row_date = idx.to_pydatetime()
|
||||
elif isinstance(idx, datetime):
|
||||
row_date = idx
|
||||
else:
|
||||
row_date = pd.to_datetime(idx).to_pydatetime()
|
||||
|
||||
# Calculate average price
|
||||
avg = (high + low) / 2.0
|
||||
# Extract values (handle both Series and scalar)
|
||||
def safe_extract(value):
|
||||
if isinstance(value, pd.Series):
|
||||
return float(value.iloc[0])
|
||||
return float(value)
|
||||
|
||||
return TickerOHLCV(
|
||||
ticker=ticker,
|
||||
date=date,
|
||||
open=round(open_price, 2),
|
||||
high=round(high, 2),
|
||||
low=round(low, 2),
|
||||
close=round(close, 2),
|
||||
avg=round(avg, 2),
|
||||
volume=volume,
|
||||
)
|
||||
open_price = safe_extract(row["Open"])
|
||||
high = safe_extract(row["High"])
|
||||
low = safe_extract(row["Low"])
|
||||
close = safe_extract(row["Close"])
|
||||
volume = int(safe_extract(row.get("Volume", 0)))
|
||||
|
||||
avg = (high + low) / 2.0
|
||||
|
||||
ohlcv = TickerOHLCV(
|
||||
ticker=ticker,
|
||||
date=row_date,
|
||||
open=round(open_price, 2),
|
||||
high=round(high, 2),
|
||||
low=round(low, 2),
|
||||
close=round(close, 2),
|
||||
avg=round(avg, 2),
|
||||
volume=volume,
|
||||
)
|
||||
results.append(ohlcv)
|
||||
|
||||
return results
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error fetching data for {ticker} on {start_date}: {str(e)}")
|
||||
print(
|
||||
f"Error fetching data for {ticker} from {start_date_str} to {end_date_str}: {str(e)}"
|
||||
)
|
||||
return []
|
||||
|
||||
return None
|
||||
@staticmethod
|
||||
def ticker_data_for_single_day(ticker: str, date: datetime) -> TickerOHLCV | None:
|
||||
"""
|
||||
Fetch OHLCV data for a single day (backward compatibility).
|
||||
|
||||
Args:
|
||||
ticker: Stock ticker symbol
|
||||
date: The specific date
|
||||
|
||||
Returns:
|
||||
TickerOHLCV record for that day, or None if not available
|
||||
"""
|
||||
results = Fetcher.ticker_data_for(ticker, date, end_date=None)
|
||||
|
||||
return results[0] if results else None
|
||||
|
||||
@staticmethod
|
||||
def ticker_data_for_range(
|
||||
ticker: str, start_date: datetime, end_date: datetime
|
||||
) -> List[TickerOHLCV]:
|
||||
"""
|
||||
Fetch OHLCV data for a date range (explicit range method).
|
||||
|
||||
Args:
|
||||
ticker: Stock ticker symbol
|
||||
start_date: Start date (inclusive)
|
||||
end_date: End date (inclusive)
|
||||
|
||||
Returns:
|
||||
List of TickerOHLCV records for the range
|
||||
"""
|
||||
return Fetcher.ticker_data_for(ticker, start_date, end_date)
|
||||
|
@ -50,6 +50,7 @@ class IndicatorService:
|
||||
|
||||
Returns:
|
||||
IndicatorsData instance ready to save to database, or None if insufficient data
|
||||
or if any calculated values are NaN
|
||||
"""
|
||||
# Fetch historical OHLCV data from database
|
||||
start_date: datetime = target_date - timedelta(
|
||||
@ -59,30 +60,40 @@ class IndicatorService:
|
||||
ticker=ticker, start_date=start_date, end_date=target_date
|
||||
)
|
||||
|
||||
# Verify we have enough data
|
||||
if len(ohlcv_records) < 30:
|
||||
# Verify we have enough data (minimum 40 trading days for MACD + ADX)
|
||||
if len(ohlcv_records) < 40:
|
||||
return None
|
||||
|
||||
# Convert to numpy arrays for TA-Lib
|
||||
arrays: OHLCVArrays = self._convert_to_arrays(ohlcv_records)
|
||||
|
||||
# Calculate all indicator categories
|
||||
momentum: MomentumIndicators = self._calculate_momentum(arrays)
|
||||
volatility: VolatilityIndicators = self._calculate_volatility(arrays)
|
||||
trend: TrendIndicators = self._calculate_trend(arrays)
|
||||
volume: VolumeIndicators = self._calculate_volume(arrays)
|
||||
momentum: MomentumIndicators | None = self._calculate_momentum(arrays)
|
||||
volatility: VolatilityIndicators | None = self._calculate_volatility(arrays)
|
||||
trend: TrendIndicators | None = self._calculate_trend(arrays)
|
||||
volume: VolumeIndicators | None = self._calculate_volume(arrays)
|
||||
support_resistance: SupportResistanceIndicators = (
|
||||
self._calculate_support_resistance(arrays)
|
||||
)
|
||||
market_regime: MarketRegimeIndicators = self._calculate_market_regime(arrays)
|
||||
market_regime: MarketRegimeIndicators | None = self._calculate_market_regime(
|
||||
arrays
|
||||
)
|
||||
|
||||
# Skip if any indicator group returned None (contains NaN values)
|
||||
if (
|
||||
momentum is None
|
||||
or volatility is None
|
||||
or trend is None
|
||||
or volume is None
|
||||
or market_regime is None
|
||||
):
|
||||
return None
|
||||
|
||||
# Build and return strongly-typed IndicatorsData
|
||||
return IndicatorsData(
|
||||
ticker=ticker,
|
||||
date=target_date,
|
||||
|
||||
# Momentum (7 indicators)
|
||||
|
||||
rsi_14=momentum.rsi_14,
|
||||
rsi_20=momentum.rsi_20,
|
||||
macd_line=momentum.macd_line,
|
||||
@ -90,40 +101,30 @@ class IndicatorService:
|
||||
macd_histogram=momentum.macd_histogram,
|
||||
stoch_k=momentum.stoch_k,
|
||||
stoch_d=momentum.stoch_d,
|
||||
|
||||
# Volatility (6 indicators)
|
||||
|
||||
bb_upper=volatility.bb_upper,
|
||||
bb_middle=volatility.bb_middle,
|
||||
bb_lower=volatility.bb_lower,
|
||||
bb_width=volatility.bb_width,
|
||||
bb_percent=volatility.bb_percent,
|
||||
atr_14=volatility.atr_14,
|
||||
|
||||
# Trend (4 indicators)
|
||||
|
||||
adx_14=trend.adx_14,
|
||||
di_plus=trend.di_plus,
|
||||
di_minus=trend.di_minus,
|
||||
sar=trend.sar,
|
||||
|
||||
# Volume (3 indicators)
|
||||
|
||||
obv=volume.obv,
|
||||
obv_sma_20=volume.obv_sma_20,
|
||||
volume_roc_5=volume.volume_roc_5,
|
||||
|
||||
# Support/Resistance (6 indicators)
|
||||
|
||||
fib_236=support_resistance.fib_236,
|
||||
fib_382=support_resistance.fib_382,
|
||||
fib_618=support_resistance.fib_618,
|
||||
pivot_point=support_resistance.pivot_point,
|
||||
resistance_1=support_resistance.resistance_1,
|
||||
support_1=support_resistance.support_1,
|
||||
|
||||
# Market Regime (2 indicators)
|
||||
|
||||
cci_20=market_regime.cci_20,
|
||||
williams_r_14=market_regime.williams_r_14,
|
||||
)
|
||||
@ -180,13 +181,14 @@ class IndicatorService:
|
||||
) -> int:
|
||||
"""
|
||||
Calculate and save indicators for multiple dates efficiently.
|
||||
Uses upsert logic to handle existing records.
|
||||
|
||||
Args:
|
||||
ticker: Stock ticker symbol
|
||||
dates: List of dates to calculate indicators for
|
||||
|
||||
Returns:
|
||||
Number of indicator records successfully saved
|
||||
Number of indicator records successfully saved/updated
|
||||
"""
|
||||
indicators_list: List[IndicatorsData] = self.bulk_calculate_indicators(
|
||||
ticker=ticker, dates=dates
|
||||
@ -195,7 +197,14 @@ class IndicatorService:
|
||||
if not indicators_list:
|
||||
return 0
|
||||
|
||||
return self._crud.create_indicators_bulk(indicators_list)
|
||||
# Use upsert for each indicator to avoid UNIQUE constraint violations
|
||||
saved_count = 0
|
||||
for indicators in indicators_list:
|
||||
result = self._crud.upsert_indicators(indicators)
|
||||
if result:
|
||||
saved_count += 1
|
||||
|
||||
return saved_count
|
||||
|
||||
# ========================================================================
|
||||
# PRIVATE: Data Conversion
|
||||
|
48
populate.py
48
populate.py
@ -7,8 +7,7 @@ from paperone.utils import (
|
||||
get_last_n_trading_days,
|
||||
)
|
||||
from paperone.database import TradingDataCRUD
|
||||
from os import environ
|
||||
from paperone.client import Client
|
||||
from paperone.client import Fetcher
|
||||
from rich.progress import track
|
||||
from datetime import datetime
|
||||
|
||||
@ -16,46 +15,49 @@ load_dotenv()
|
||||
|
||||
DB_FILE = "trading_data.db"
|
||||
|
||||
|
||||
def main() -> NoReturn:
|
||||
api_key = environ.get("API_KEY")
|
||||
|
||||
if not api_key:
|
||||
print("API_KEY not set")
|
||||
exit(0)
|
||||
|
||||
client = Client(api_key)
|
||||
fetcher = Fetcher()
|
||||
crud = TradingDataCRUD(f"sqlite:///{DB_FILE}")
|
||||
date = datetime.now()
|
||||
days_range = 360
|
||||
days_range = 360 * 10
|
||||
|
||||
tickers = [
|
||||
"AAPL",
|
||||
"ARM",
|
||||
"^VIX",
|
||||
"AMD",
|
||||
"GOOG",
|
||||
"META",
|
||||
"MNDY",
|
||||
"MSFT",
|
||||
"NVDA",
|
||||
"VOO",
|
||||
"VT",
|
||||
]
|
||||
|
||||
# Loop through tickers and fetch data
|
||||
for ticker in tickers:
|
||||
reference_tickers = [
|
||||
"VOO", # US
|
||||
"VT", # global
|
||||
"GLD", # gold
|
||||
"^VIX", # volatility
|
||||
"VNQ", # real estate
|
||||
"DBC", # commodities
|
||||
"VGK", # EU
|
||||
"VPL", # Asia
|
||||
"VWO", # Emerging
|
||||
"BND", # Global bonds
|
||||
]
|
||||
|
||||
for ticker in tickers + reference_tickers:
|
||||
trading_days = get_last_n_trading_days(date, days_range)
|
||||
|
||||
for day in track(trading_days, description=f"→ {ticker}"):
|
||||
existing = crud.get_ohlcv(ticker, day)
|
||||
if existing:
|
||||
continue
|
||||
start_date = trading_days[0]
|
||||
end_date = trading_days[-1]
|
||||
|
||||
ticker_data = client.ticker_data_for(ticker, day)
|
||||
if not ticker_data:
|
||||
continue
|
||||
all_ohlcv_data = fetcher.ticker_data_for(ticker, start_date, end_date)
|
||||
for ohlcv in track(all_ohlcv_data, description=f"→ {ticker}"):
|
||||
existing = crud.get_ohlcv(ticker, ohlcv.date)
|
||||
|
||||
crud.create_ohlcv(ticker_data)
|
||||
if not existing:
|
||||
crud.create_ohlcv(ohlcv)
|
||||
|
||||
exit(0)
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user