# Recovered from patch f5aa0e58489e53088b01661d3519c951a1dc9a39
# Author: Giulio De Pasquale, Sat, 18 Oct 2025 09:22:23 +0100
# Subject: feat(database): add trading data CRUD operations
# Target file: paperone/database.py
"""CRUD helpers for OHLCV and indicator rows persisted through SQLModel."""

from contextlib import contextmanager
from datetime import datetime
from typing import Iterator, List, Optional, Tuple

from sqlmodel import SQLModel, Session, create_engine, select

from .entities import TimeSeriesTickerData
from .models import TickerOHLCV, IndicatorsData


class TradingDataCRUD:
    """
    CRUD operations for trading data with SQLModel/SQLite.

    Every public method opens its own short-lived session through
    ``get_session``; records returned to the caller are therefore detached
    from any session by the time the caller sees them.
    """

    def __init__(self, database_url: str = "sqlite:///trading_data.db"):
        """
        Initialize the CRUD manager with a database connection.

        Creates the engine and then creates every table registered on
        ``SQLModel.metadata`` that does not exist yet.

        Args:
            database_url: Database URL (default: local SQLite file)
        """
        # ``check_same_thread`` is a sqlite3-specific connect argument; other
        # DBAPI drivers may reject unknown connection options, so it is only
        # passed when the URL targets SQLite.
        if database_url.startswith("sqlite"):
            connect_args = {"check_same_thread": False}
        else:
            connect_args = {}

        self.engine = create_engine(
            database_url,
            echo=False,  # Set to True for SQL query debugging
            connect_args=connect_args,
        )
        self._create_tables()

    def _create_tables(self):
        """Create all tables if they don't exist."""
        SQLModel.metadata.create_all(self.engine)

    @contextmanager
    def get_session(self) -> Iterator[Session]:
        """
        Context manager for database sessions with automatic cleanup.

        Commits when the ``with`` body finishes normally, rolls back and
        re-raises on any exception, and always closes the session.

        Yields:
            An open Session bound to this instance's engine
        """
        # expire_on_commit=False keeps already-loaded column values readable
        # after the commit/close below. With the default (True) every object
        # handed back to callers would be expired and detached, and reading
        # any attribute would fail because no session is left to reload it.
        session = Session(self.engine, expire_on_commit=False)
        try:
            yield session
            session.commit()
        except Exception:
            session.rollback()
            raise
        finally:
            session.close()

    # ========================================================================
    # Internal helpers (shared by the OHLCV and indicators variants)
    # ========================================================================

    @staticmethod
    def _fetch_one(session: Session, model, ticker: str, date: datetime):
        """Return the first ``model`` row matching ticker and date, or None."""
        statement = select(model).where(model.ticker == ticker, model.date == date)
        return session.exec(statement).first()

    @staticmethod
    def _fetch_range(
        session: Session, model, ticker: str, start_date: datetime, end_date: datetime
    ) -> list:
        """Return ``model`` rows for ticker with start <= date <= end, by date."""
        statement = (
            select(model)
            .where(
                model.ticker == ticker,
                model.date >= start_date,
                model.date <= end_date,
            )
            .order_by(model.date)
        )
        return list(session.exec(statement).all())

    @staticmethod
    def _apply_fields(record, values: dict) -> None:
        """Set each key of ``values`` on ``record``; unknown names are skipped."""
        for key, value in values.items():
            # Names the record does not already expose are ignored on purpose
            # (best-effort update) rather than raising.
            if hasattr(record, key):
                setattr(record, key, value)

    @staticmethod
    def _persist(session: Session, record):
        """Add ``record``, commit, reload it from the database and return it."""
        session.add(record)
        session.commit()
        session.refresh(record)
        return record

    # ========================================================================
    # CREATE Operations
    # ========================================================================

    def create_ohlcv(self, ohlcv: TickerOHLCV) -> TickerOHLCV:
        """
        Insert a single OHLCV record.

        Args:
            ohlcv: TickerOHLCV instance to insert

        Returns:
            The inserted TickerOHLCV record (refreshed after commit)
        """
        with self.get_session() as session:
            return self._persist(session, ohlcv)

    def create_ohlcv_bulk(self, ohlcv_list: List[TickerOHLCV]) -> int:
        """
        Bulk insert OHLCV records in one session and one commit.

        Args:
            ohlcv_list: List of TickerOHLCV instances

        Returns:
            Number of records passed in (``len(ohlcv_list)``)
        """
        with self.get_session() as session:
            session.add_all(ohlcv_list)
            session.commit()
            return len(ohlcv_list)

    def create_indicators(self, indicators: IndicatorsData) -> IndicatorsData:
        """
        Insert a single indicators record.

        Args:
            indicators: IndicatorsData instance to insert

        Returns:
            The inserted IndicatorsData record (refreshed after commit)
        """
        with self.get_session() as session:
            return self._persist(session, indicators)

    def create_indicators_bulk(self, indicators_list: List[IndicatorsData]) -> int:
        """
        Bulk insert indicators records in one session and one commit.

        Args:
            indicators_list: List of IndicatorsData instances

        Returns:
            Number of records passed in (``len(indicators_list)``)
        """
        with self.get_session() as session:
            session.add_all(indicators_list)
            session.commit()
            return len(indicators_list)

    # ========================================================================
    # READ Operations
    # ========================================================================

    def get_ohlcv(self, ticker: str, date: datetime) -> Optional[TickerOHLCV]:
        """
        Get a single OHLCV record by ticker and date.

        Args:
            ticker: Stock ticker symbol
            date: Trading date

        Returns:
            TickerOHLCV record or None if not found
        """
        with self.get_session() as session:
            return self._fetch_one(session, TickerOHLCV, ticker, date)

    def get_ohlcv_range(
        self, ticker: str, start_date: datetime, end_date: datetime
    ) -> List[TickerOHLCV]:
        """
        Get OHLCV records for a ticker within a date range.

        Args:
            ticker: Stock ticker symbol
            start_date: Start of date range (inclusive)
            end_date: End of date range (inclusive)

        Returns:
            List of TickerOHLCV records, sorted by date ascending
        """
        with self.get_session() as session:
            return self._fetch_range(session, TickerOHLCV, ticker, start_date, end_date)

    def get_ohlcv_latest(self, ticker: str, limit: int = 1) -> List[TickerOHLCV]:
        """
        Get the most recent OHLCV record(s) for a ticker.

        Args:
            ticker: Stock ticker symbol
            limit: Number of recent records to retrieve

        Returns:
            List of most recent TickerOHLCV records, sorted by date descending
        """
        with self.get_session() as session:
            statement = (
                select(TickerOHLCV)
                .where(TickerOHLCV.ticker == ticker)
                .order_by(TickerOHLCV.date.desc())
                .limit(limit)
            )
            return list(session.exec(statement).all())

    def get_indicators(self, ticker: str, date: datetime) -> Optional[IndicatorsData]:
        """
        Get indicators for a specific ticker and date.

        Args:
            ticker: Stock ticker symbol
            date: Trading date

        Returns:
            IndicatorsData record or None if not found
        """
        with self.get_session() as session:
            return self._fetch_one(session, IndicatorsData, ticker, date)

    def get_indicators_range(
        self, ticker: str, start_date: datetime, end_date: datetime
    ) -> List[IndicatorsData]:
        """
        Get indicators for a ticker within a date range.

        Args:
            ticker: Stock ticker symbol
            start_date: Start of date range (inclusive)
            end_date: End of date range (inclusive)

        Returns:
            List of IndicatorsData records, sorted by date ascending
        """
        with self.get_session() as session:
            return self._fetch_range(
                session, IndicatorsData, ticker, start_date, end_date
            )

    def get_ohlcv_with_indicators(
        self, ticker: str, date: datetime
    ) -> Optional[Tuple[TickerOHLCV, IndicatorsData]]:
        """
        Get OHLCV and indicators together for a specific ticker and date.

        Both rows are read in the same session. The indicators query is
        skipped when no OHLCV row exists.

        Args:
            ticker: Stock ticker symbol
            date: Trading date

        Returns:
            Tuple of (TickerOHLCV, IndicatorsData), or None if either row
            is missing
        """
        with self.get_session() as session:
            ohlcv = self._fetch_one(session, TickerOHLCV, ticker, date)
            if ohlcv is None:
                return None

            indicators = self._fetch_one(session, IndicatorsData, ticker, date)
            if indicators is None:
                return None

            return (ohlcv, indicators)

    def get_time_series(
        self, ticker: str, start_date: datetime, end_date: datetime
    ) -> TimeSeriesTickerData:
        """
        Get time series data for building TimeSeriesTickerData.

        Args:
            ticker: Stock ticker symbol
            start_date: Start of date range (inclusive, see get_ohlcv_range)
            end_date: End of date range (inclusive, see get_ohlcv_range)

        Returns:
            Result of ``TimeSeriesTickerData.build_time_series_ticker_data``
            for the OHLCV rows in the range
        """
        ohlcv_list = self.get_ohlcv_range(ticker, start_date, end_date)
        return TimeSeriesTickerData.build_time_series_ticker_data(ticker, ohlcv_list)

    # ========================================================================
    # UPDATE Operations
    # ========================================================================

    def update_ohlcv(
        self, ticker: str, date: datetime, **kwargs
    ) -> Optional[TickerOHLCV]:
        """
        Update OHLCV fields for a specific record.

        Keyword names that the record does not expose as attributes are
        silently ignored.

        Args:
            ticker: Stock ticker symbol
            date: Trading date
            **kwargs: Fields to update (e.g., close=150.0, volume=1000000)

        Returns:
            Updated TickerOHLCV record or None if not found
        """
        with self.get_session() as session:
            ohlcv = self._fetch_one(session, TickerOHLCV, ticker, date)
            if ohlcv is None:
                return None

            self._apply_fields(ohlcv, kwargs)
            return self._persist(session, ohlcv)

    def update_indicators(
        self, ticker: str, date: datetime, **kwargs
    ) -> Optional[IndicatorsData]:
        """
        Update indicator fields for a specific record.

        Keyword names that the record does not expose as attributes are
        silently ignored.

        Args:
            ticker: Stock ticker symbol
            date: Trading date
            **kwargs: Fields to update (e.g., rsi_14=55.2, macd_line=1.8)

        Returns:
            Updated IndicatorsData record or None if not found
        """
        with self.get_session() as session:
            indicators = self._fetch_one(session, IndicatorsData, ticker, date)
            if indicators is None:
                return None

            self._apply_fields(indicators, kwargs)
            return self._persist(session, indicators)

    # ========================================================================
    # UPSERT Operations (Insert or Update)
    # ========================================================================

    def upsert_ohlcv(self, ohlcv: TickerOHLCV) -> TickerOHLCV:
        """
        Insert the OHLCV record, or update the stored one if it already exists.

        The existence check and the write happen in a single session, so the
        method always returns a record (never None).

        Args:
            ohlcv: TickerOHLCV instance

        Returns:
            The upserted TickerOHLCV record: the stored row when it already
            existed, otherwise the instance that was passed in
        """
        with self.get_session() as session:
            existing = self._fetch_one(session, TickerOHLCV, ohlcv.ticker, ohlcv.date)
            if existing is None:
                return self._persist(session, ohlcv)

            self._apply_fields(
                existing,
                {
                    "open": ohlcv.open,
                    "close": ohlcv.close,
                    "low": ohlcv.low,
                    "high": ohlcv.high,
                    "avg": ohlcv.avg,
                    "volume": ohlcv.volume,
                },
            )
            return self._persist(session, existing)

    def upsert_indicators(self, indicators: IndicatorsData) -> IndicatorsData:
        """
        Insert the indicators record, or update the stored one if it exists.

        The existence check and the write happen in a single session, so the
        method always returns a record (never None).

        Args:
            indicators: IndicatorsData instance

        Returns:
            The upserted IndicatorsData record: the stored row when it already
            existed, otherwise the instance that was passed in
        """
        # Copy every instance attribute except the lookup keys, the "ohlcv"
        # entry and SQLAlchemy's internal state marker.
        # NOTE(review): "ohlcv" is presumably a relationship attribute on
        # IndicatorsData -- confirm against the model definition.
        skipped = ("ticker", "date", "ohlcv", "_sa_instance_state")
        indicator_fields = {
            name: value
            for name, value in indicators.__dict__.items()
            if name not in skipped
        }

        with self.get_session() as session:
            existing = self._fetch_one(
                session, IndicatorsData, indicators.ticker, indicators.date
            )
            if existing is None:
                return self._persist(session, indicators)

            self._apply_fields(existing, indicator_fields)
            return self._persist(session, existing)

    # ========================================================================
    # DELETE Operations
    # ========================================================================

    def delete_ohlcv(self, ticker: str, date: datetime) -> bool:
        """
        Delete an OHLCV record.

        Args:
            ticker: Stock ticker symbol
            date: Trading date

        Returns:
            True if deleted, False if not found
        """
        with self.get_session() as session:
            ohlcv = self._fetch_one(session, TickerOHLCV, ticker, date)
            if ohlcv is None:
                return False

            session.delete(ohlcv)
            session.commit()
            return True

    def delete_indicators(self, ticker: str, date: datetime) -> bool:
        """
        Delete an indicators record.

        Args:
            ticker: Stock ticker symbol
            date: Trading date

        Returns:
            True if deleted, False if not found
        """
        with self.get_session() as session:
            indicators = self._fetch_one(session, IndicatorsData, ticker, date)
            if indicators is None:
                return False

            session.delete(indicators)
            session.commit()
            return True

    def delete_ticker_data(self, ticker: str) -> Tuple[int, int]:
        """
        Delete all data (OHLCV and indicators) for a ticker.

        Both deletions are committed together in one session.

        Args:
            ticker: Stock ticker symbol

        Returns:
            Tuple of (ohlcv_deleted_count, indicators_deleted_count)
        """
        with self.get_session() as session:
            # Indicators are marked for deletion first; the original author
            # notes this is because of a foreign key constraint.
            indicator_rows = session.exec(
                select(IndicatorsData).where(IndicatorsData.ticker == ticker)
            ).all()
            for row in indicator_rows:
                session.delete(row)

            ohlcv_rows = session.exec(
                select(TickerOHLCV).where(TickerOHLCV.ticker == ticker)
            ).all()
            for row in ohlcv_rows:
                session.delete(row)

            session.commit()
            return (len(ohlcv_rows), len(indicator_rows))