paperone/paperone/database.py
2025-10-18 09:22:23 +01:00

477 lines
15 KiB
Python

from sqlmodel import SQLModel, Session, create_engine, select
from typing import List, Optional, Tuple
from datetime import datetime
from contextlib import contextmanager
from .models import TickerOHLCV, IndicatorsData
from .entities import TimeSeriesTickerData
class TradingDataCRUD:
    """
    CRUD operations for trading data with SQLModel/SQLite.

    Handles OHLCV data and technical indicators. Every public method opens its
    own short-lived session through ``get_session``, which commits once on
    success and rolls back on error. Model instances returned to callers are
    detached from that (closed) session, but sessions are created with
    ``expire_on_commit=False`` so attribute values loaded while the session
    was open remain readable.
    """

    # Columns copied from the incoming record onto an existing row by
    # ``upsert_ohlcv`` (same set the original update path used).
    _OHLCV_UPDATE_FIELDS = ("open", "close", "low", "high", "avg", "volume")

    # Attributes ``upsert_indicators`` never copies onto an existing row: the
    # lookup key itself and the ``ohlcv`` attribute (excluded by the original
    # code; presumably a relationship — confirm against models.py).
    _INDICATORS_EXCLUDED_FIELDS = frozenset({"ticker", "date", "ohlcv"})

    def __init__(
        self, database_url: str = "sqlite:///trading_data.db", *, echo: bool = False
    ):
        """
        Initialize the CRUD manager with a database connection.

        Args:
            database_url: SQLAlchemy database URL (default: local SQLite file).
            echo: If True, log every SQL statement (useful for debugging).
        """
        # ``check_same_thread`` is a sqlite3-only connect argument; other DBAPI
        # drivers reject unknown connect args, so only send it for SQLite.
        connect_args = (
            {"check_same_thread": False} if database_url.startswith("sqlite") else {}
        )
        self.engine = create_engine(
            database_url,
            echo=echo,
            connect_args=connect_args,
        )
        self._create_tables()

    def _create_tables(self):
        """Create all tables registered on ``SQLModel.metadata`` if missing."""
        SQLModel.metadata.create_all(self.engine)

    @contextmanager
    def get_session(self):
        """
        Yield a session that commits on success and rolls back on error.

        The session is always closed afterwards. ``expire_on_commit=False`` is
        essential here: with the default (True), the commit below would expire
        every instance the caller is about to receive, and ``close()`` would
        then detach them, so the first attribute access on a returned record
        would raise ``DetachedInstanceError``.
        """
        session = Session(self.engine, expire_on_commit=False)
        try:
            yield session
            session.commit()
        except Exception:
            session.rollback()
            raise
        finally:
            session.close()

    # ========================================================================
    # Private helpers
    # ========================================================================

    @staticmethod
    def _find_ohlcv(
        session: Session, ticker: str, date: datetime
    ) -> Optional[TickerOHLCV]:
        """Return the OHLCV row for (ticker, date) within ``session``, or None."""
        statement = select(TickerOHLCV).where(
            TickerOHLCV.ticker == ticker, TickerOHLCV.date == date
        )
        return session.exec(statement).first()

    @staticmethod
    def _find_indicators(
        session: Session, ticker: str, date: datetime
    ) -> Optional[IndicatorsData]:
        """Return the indicators row for (ticker, date) within ``session``, or None."""
        statement = select(IndicatorsData).where(
            IndicatorsData.ticker == ticker, IndicatorsData.date == date
        )
        return session.exec(statement).first()

    @staticmethod
    def _apply_updates(record, values: dict) -> None:
        """Set each ``values`` item on ``record``; keys it lacks are ignored."""
        for key, value in values.items():
            if hasattr(record, key):
                setattr(record, key, value)

    # ========================================================================
    # CREATE Operations
    # ========================================================================

    def create_ohlcv(self, ohlcv: TickerOHLCV) -> TickerOHLCV:
        """
        Insert a single OHLCV record.

        Args:
            ohlcv: TickerOHLCV instance to insert.

        Returns:
            The inserted TickerOHLCV record, refreshed from the database.
        """
        with self.get_session() as session:
            session.add(ohlcv)
            # Flush (not commit) so refresh can read the inserted row back;
            # get_session commits once when the block exits.
            session.flush()
            session.refresh(ohlcv)
            return ohlcv

    def create_ohlcv_bulk(self, ohlcv_list: List[TickerOHLCV]) -> int:
        """
        Insert many OHLCV records in a single transaction.

        Args:
            ohlcv_list: List of TickerOHLCV instances.

        Returns:
            Number of records inserted.
        """
        with self.get_session() as session:
            session.add_all(ohlcv_list)
        return len(ohlcv_list)

    def create_indicators(self, indicators: IndicatorsData) -> IndicatorsData:
        """
        Insert a single indicators record.

        Args:
            indicators: IndicatorsData instance to insert.

        Returns:
            The inserted IndicatorsData record, refreshed from the database.
        """
        with self.get_session() as session:
            session.add(indicators)
            session.flush()
            session.refresh(indicators)
            return indicators

    def create_indicators_bulk(self, indicators_list: List[IndicatorsData]) -> int:
        """
        Insert many indicators records in a single transaction.

        Args:
            indicators_list: List of IndicatorsData instances.

        Returns:
            Number of records inserted.
        """
        with self.get_session() as session:
            session.add_all(indicators_list)
        return len(indicators_list)

    # ========================================================================
    # READ Operations
    # ========================================================================

    def get_ohlcv(self, ticker: str, date: datetime) -> Optional[TickerOHLCV]:
        """
        Get a single OHLCV record by ticker and date.

        Args:
            ticker: Stock ticker symbol.
            date: Trading date.

        Returns:
            TickerOHLCV record, or None if not found.
        """
        with self.get_session() as session:
            return self._find_ohlcv(session, ticker, date)

    def get_ohlcv_range(
        self, ticker: str, start_date: datetime, end_date: datetime
    ) -> List[TickerOHLCV]:
        """
        Get OHLCV records for a ticker within a date range.

        Args:
            ticker: Stock ticker symbol.
            start_date: Start of date range (inclusive).
            end_date: End of date range (inclusive).

        Returns:
            List of TickerOHLCV records sorted by ascending date.
        """
        with self.get_session() as session:
            statement = (
                select(TickerOHLCV)
                .where(
                    TickerOHLCV.ticker == ticker,
                    TickerOHLCV.date >= start_date,
                    TickerOHLCV.date <= end_date,
                )
                .order_by(TickerOHLCV.date)
            )
            return list(session.exec(statement).all())

    def get_ohlcv_latest(self, ticker: str, limit: int = 1) -> List[TickerOHLCV]:
        """
        Get the most recent OHLCV record(s) for a ticker.

        Args:
            ticker: Stock ticker symbol.
            limit: Maximum number of recent records to retrieve.

        Returns:
            List of the most recent TickerOHLCV records, newest first.
        """
        with self.get_session() as session:
            statement = (
                select(TickerOHLCV)
                .where(TickerOHLCV.ticker == ticker)
                .order_by(TickerOHLCV.date.desc())
                .limit(limit)
            )
            return list(session.exec(statement).all())

    def get_indicators(self, ticker: str, date: datetime) -> Optional[IndicatorsData]:
        """
        Get indicators for a specific ticker and date.

        Args:
            ticker: Stock ticker symbol.
            date: Trading date.

        Returns:
            IndicatorsData record, or None if not found.
        """
        with self.get_session() as session:
            return self._find_indicators(session, ticker, date)

    def get_indicators_range(
        self, ticker: str, start_date: datetime, end_date: datetime
    ) -> List[IndicatorsData]:
        """
        Get indicators for a ticker within a date range.

        Args:
            ticker: Stock ticker symbol.
            start_date: Start of date range (inclusive).
            end_date: End of date range (inclusive).

        Returns:
            List of IndicatorsData records sorted by ascending date.
        """
        with self.get_session() as session:
            statement = (
                select(IndicatorsData)
                .where(
                    IndicatorsData.ticker == ticker,
                    IndicatorsData.date >= start_date,
                    IndicatorsData.date <= end_date,
                )
                .order_by(IndicatorsData.date)
            )
            return list(session.exec(statement).all())

    def get_ohlcv_with_indicators(
        self, ticker: str, date: datetime
    ) -> Optional[Tuple[TickerOHLCV, IndicatorsData]]:
        """
        Get OHLCV and indicators together for a specific ticker and date.

        Args:
            ticker: Stock ticker symbol.
            date: Trading date.

        Returns:
            Tuple of (TickerOHLCV, IndicatorsData), or None if either record
            is missing.
        """
        with self.get_session() as session:
            ohlcv = self._find_ohlcv(session, ticker, date)
            if ohlcv is None:
                return None
            indicators = self._find_indicators(session, ticker, date)
            if indicators is None:
                return None
            return (ohlcv, indicators)

    def get_time_series(
        self, ticker: str, start_date: datetime, end_date: datetime
    ) -> TimeSeriesTickerData:
        """
        Build a TimeSeriesTickerData from stored OHLCV rows.

        Args:
            ticker: Stock ticker symbol.
            start_date: Start of date range (inclusive).
            end_date: End of date range (inclusive).

        Returns:
            TimeSeriesTickerData built from the date-sorted OHLCV records.
        """
        ohlcv_list = self.get_ohlcv_range(ticker, start_date, end_date)
        return TimeSeriesTickerData.build_time_series_ticker_data(ticker, ohlcv_list)

    # ========================================================================
    # UPDATE Operations
    # ========================================================================

    def update_ohlcv(
        self, ticker: str, date: datetime, **kwargs
    ) -> Optional[TickerOHLCV]:
        """
        Update OHLCV fields for a specific record.

        Args:
            ticker: Stock ticker symbol.
            date: Trading date.
            **kwargs: Fields to update (e.g. close=150.0, volume=1000000).
                Names the model does not have are silently ignored.

        Returns:
            Updated TickerOHLCV record, or None if not found.
        """
        with self.get_session() as session:
            ohlcv = self._find_ohlcv(session, ticker, date)
            if ohlcv is None:
                return None
            self._apply_updates(ohlcv, kwargs)
            session.add(ohlcv)
            session.flush()
            session.refresh(ohlcv)
            return ohlcv

    def update_indicators(
        self, ticker: str, date: datetime, **kwargs
    ) -> Optional[IndicatorsData]:
        """
        Update indicator fields for a specific record.

        Args:
            ticker: Stock ticker symbol.
            date: Trading date.
            **kwargs: Fields to update (e.g. rsi_14=55.2, macd_line=1.8).
                Names the model does not have are silently ignored.

        Returns:
            Updated IndicatorsData record, or None if not found.
        """
        with self.get_session() as session:
            indicators = self._find_indicators(session, ticker, date)
            if indicators is None:
                return None
            self._apply_updates(indicators, kwargs)
            session.add(indicators)
            session.flush()
            session.refresh(indicators)
            return indicators

    # ========================================================================
    # UPSERT Operations (Insert or Update)
    # ========================================================================

    def upsert_ohlcv(self, ohlcv: TickerOHLCV) -> TickerOHLCV:
        """
        Insert an OHLCV record, or update the existing row with its values.

        Lookup and write happen in one session/transaction, so the row
        cannot disappear between the existence check and the update.

        Args:
            ohlcv: TickerOHLCV instance.

        Returns:
            The stored row: the updated existing record if one matched
            (ticker, date), otherwise ``ohlcv`` itself after insertion.
        """
        with self.get_session() as session:
            existing = self._find_ohlcv(session, ohlcv.ticker, ohlcv.date)
            if existing is None:
                session.add(ohlcv)
                target = ohlcv
            else:
                self._apply_updates(
                    existing,
                    {name: getattr(ohlcv, name) for name in self._OHLCV_UPDATE_FIELDS},
                )
                target = existing
            session.flush()
            session.refresh(target)
            return target

    def upsert_indicators(self, indicators: IndicatorsData) -> IndicatorsData:
        """
        Insert an indicators record, or update the existing row with its values.

        On update, every public attribute in the incoming instance's
        ``__dict__`` is copied across except the (ticker, date) key and
        ``ohlcv``. Underscore-prefixed attributes such as
        ``_sa_instance_state`` are ORM/pydantic internals and are skipped.

        Args:
            indicators: IndicatorsData instance.

        Returns:
            The stored row: the updated existing record if one matched
            (ticker, date), otherwise ``indicators`` itself after insertion.
        """
        with self.get_session() as session:
            existing = self._find_indicators(
                session, indicators.ticker, indicators.date
            )
            if existing is None:
                session.add(indicators)
                target = indicators
            else:
                indicator_fields = {
                    key: value
                    for key, value in indicators.__dict__.items()
                    if key not in self._INDICATORS_EXCLUDED_FIELDS
                    and not key.startswith("_")
                }
                self._apply_updates(existing, indicator_fields)
                target = existing
            session.flush()
            session.refresh(target)
            return target

    # ========================================================================
    # DELETE Operations
    # ========================================================================

    def delete_ohlcv(self, ticker: str, date: datetime) -> bool:
        """
        Delete an OHLCV record.

        Args:
            ticker: Stock ticker symbol.
            date: Trading date.

        Returns:
            True if deleted, False if not found.
        """
        with self.get_session() as session:
            ohlcv = self._find_ohlcv(session, ticker, date)
            if ohlcv is None:
                return False
            session.delete(ohlcv)
            return True

    def delete_indicators(self, ticker: str, date: datetime) -> bool:
        """
        Delete an indicators record.

        Args:
            ticker: Stock ticker symbol.
            date: Trading date.

        Returns:
            True if deleted, False if not found.
        """
        with self.get_session() as session:
            indicators = self._find_indicators(session, ticker, date)
            if indicators is None:
                return False
            session.delete(indicators)
            return True

    def delete_ticker_data(self, ticker: str) -> Tuple[int, int]:
        """
        Delete all data (OHLCV and indicators) for a ticker in one transaction.

        Args:
            ticker: Stock ticker symbol.

        Returns:
            Tuple of (ohlcv_deleted_count, indicators_deleted_count).
        """
        with self.get_session() as session:
            indicators_rows = session.exec(
                select(IndicatorsData).where(IndicatorsData.ticker == ticker)
            ).all()
            for row in indicators_rows:
                session.delete(row)
            # Emit the indicator DELETEs before touching OHLCV rows, so the
            # dependent rows are gone first (the original comment cites a
            # foreign key between the tables).
            session.flush()

            ohlcv_rows = session.exec(
                select(TickerOHLCV).where(TickerOHLCV.ticker == ticker)
            ).all()
            for row in ohlcv_rows:
                session.delete(row)

            return (len(ohlcv_rows), len(indicators_rows))