feat(database): add trading data CRUD operations
This commit is contained in:
parent
b4cc94a444
commit
f5aa0e5848
476
paperone/database.py
Normal file
476
paperone/database.py
Normal file
@ -0,0 +1,476 @@
|
|||||||
|
from sqlmodel import SQLModel, Session, create_engine, select
|
||||||
|
from typing import List, Optional, Tuple
|
||||||
|
from datetime import datetime
|
||||||
|
from contextlib import contextmanager
|
||||||
|
from .models import TickerOHLCV, IndicatorsData
|
||||||
|
from .entities import TimeSeriesTickerData
|
||||||
|
|
||||||
|
|
||||||
|
class TradingDataCRUD:
|
||||||
|
"""
|
||||||
|
CRUD operations for trading data with SQLModel/SQLite.
|
||||||
|
Handles OHLCV data and technical indicators with proper session management.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, database_url: str = "sqlite:///trading_data.db"):
|
||||||
|
"""
|
||||||
|
Initialize the CRUD manager with database connection.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
database_url: SQLite database URL (default: local file)
|
||||||
|
"""
|
||||||
|
self.engine = create_engine(
|
||||||
|
database_url,
|
||||||
|
echo=False, # Set to True for SQL query debugging
|
||||||
|
connect_args={"check_same_thread": False}, # Needed for SQLite
|
||||||
|
)
|
||||||
|
self._create_tables()
|
||||||
|
|
||||||
|
def _create_tables(self):
|
||||||
|
"""Create all tables if they don't exist."""
|
||||||
|
SQLModel.metadata.create_all(self.engine)
|
||||||
|
|
||||||
|
@contextmanager
|
||||||
|
def get_session(self):
|
||||||
|
"""Context manager for database sessions with automatic cleanup."""
|
||||||
|
session = Session(self.engine)
|
||||||
|
try:
|
||||||
|
yield session
|
||||||
|
session.commit()
|
||||||
|
except Exception:
|
||||||
|
session.rollback()
|
||||||
|
raise
|
||||||
|
finally:
|
||||||
|
session.close()
|
||||||
|
|
||||||
|
# ========================================================================
|
||||||
|
# CREATE Operations
|
||||||
|
# ========================================================================
|
||||||
|
|
||||||
|
def create_ohlcv(self, ohlcv: TickerOHLCV) -> TickerOHLCV:
|
||||||
|
"""
|
||||||
|
Insert a single OHLCV record.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
ohlcv: TickerOHLCV instance to insert
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The inserted TickerOHLCV record
|
||||||
|
"""
|
||||||
|
with self.get_session() as session:
|
||||||
|
session.add(ohlcv)
|
||||||
|
session.commit()
|
||||||
|
session.refresh(ohlcv)
|
||||||
|
return ohlcv
|
||||||
|
|
||||||
|
def create_ohlcv_bulk(self, ohlcv_list: List[TickerOHLCV]) -> int:
|
||||||
|
"""
|
||||||
|
Bulk insert OHLCV records (more efficient for large datasets).
|
||||||
|
|
||||||
|
Args:
|
||||||
|
ohlcv_list: List of TickerOHLCV instances
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Number of records inserted
|
||||||
|
"""
|
||||||
|
with self.get_session() as session:
|
||||||
|
session.add_all(ohlcv_list)
|
||||||
|
session.commit()
|
||||||
|
return len(ohlcv_list)
|
||||||
|
|
||||||
|
def create_indicators(self, indicators: IndicatorsData) -> IndicatorsData:
|
||||||
|
"""
|
||||||
|
Insert a single indicators record.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
indicators: IndicatorsData instance to insert
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The inserted IndicatorsData record
|
||||||
|
"""
|
||||||
|
with self.get_session() as session:
|
||||||
|
session.add(indicators)
|
||||||
|
session.commit()
|
||||||
|
session.refresh(indicators)
|
||||||
|
return indicators
|
||||||
|
|
||||||
|
def create_indicators_bulk(self, indicators_list: List[IndicatorsData]) -> int:
|
||||||
|
"""
|
||||||
|
Bulk insert indicators records.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
indicators_list: List of IndicatorsData instances
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Number of records inserted
|
||||||
|
"""
|
||||||
|
with self.get_session() as session:
|
||||||
|
session.add_all(indicators_list)
|
||||||
|
session.commit()
|
||||||
|
return len(indicators_list)
|
||||||
|
|
||||||
|
# ========================================================================
|
||||||
|
# READ Operations
|
||||||
|
# ========================================================================
|
||||||
|
|
||||||
|
def get_ohlcv(self, ticker: str, date: datetime) -> Optional[TickerOHLCV]:
|
||||||
|
"""
|
||||||
|
Get a single OHLCV record by ticker and date.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
ticker: Stock ticker symbol
|
||||||
|
date: Trading date
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
TickerOHLCV record or None if not found
|
||||||
|
"""
|
||||||
|
with self.get_session() as session:
|
||||||
|
statement = select(TickerOHLCV).where(
|
||||||
|
TickerOHLCV.ticker == ticker, TickerOHLCV.date == date
|
||||||
|
)
|
||||||
|
return session.exec(statement).first()
|
||||||
|
|
||||||
|
def get_ohlcv_range(
|
||||||
|
self, ticker: str, start_date: datetime, end_date: datetime
|
||||||
|
) -> List[TickerOHLCV]:
|
||||||
|
"""
|
||||||
|
Get OHLCV records for a ticker within a date range.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
ticker: Stock ticker symbol
|
||||||
|
start_date: Start of date range (inclusive)
|
||||||
|
end_date: End of date range (inclusive)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of TickerOHLCV records, sorted by date
|
||||||
|
"""
|
||||||
|
with self.get_session() as session:
|
||||||
|
statement = (
|
||||||
|
select(TickerOHLCV)
|
||||||
|
.where(
|
||||||
|
TickerOHLCV.ticker == ticker,
|
||||||
|
TickerOHLCV.date >= start_date,
|
||||||
|
TickerOHLCV.date <= end_date,
|
||||||
|
)
|
||||||
|
.order_by(TickerOHLCV.date)
|
||||||
|
)
|
||||||
|
return list(session.exec(statement).all())
|
||||||
|
|
||||||
|
def get_ohlcv_latest(self, ticker: str, limit: int = 1) -> List[TickerOHLCV]:
|
||||||
|
"""
|
||||||
|
Get the most recent OHLCV record(s) for a ticker.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
ticker: Stock ticker symbol
|
||||||
|
limit: Number of recent records to retrieve
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of most recent TickerOHLCV records, sorted by date descending
|
||||||
|
"""
|
||||||
|
with self.get_session() as session:
|
||||||
|
statement = (
|
||||||
|
select(TickerOHLCV)
|
||||||
|
.where(TickerOHLCV.ticker == ticker)
|
||||||
|
.order_by(TickerOHLCV.date.desc())
|
||||||
|
.limit(limit)
|
||||||
|
)
|
||||||
|
return list(session.exec(statement).all())
|
||||||
|
|
||||||
|
def get_indicators(self, ticker: str, date: datetime) -> Optional[IndicatorsData]:
|
||||||
|
"""
|
||||||
|
Get indicators for a specific ticker and date.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
ticker: Stock ticker symbol
|
||||||
|
date: Trading date
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
IndicatorsData record or None if not found
|
||||||
|
"""
|
||||||
|
with self.get_session() as session:
|
||||||
|
statement = select(IndicatorsData).where(
|
||||||
|
IndicatorsData.ticker == ticker, IndicatorsData.date == date
|
||||||
|
)
|
||||||
|
return session.exec(statement).first()
|
||||||
|
|
||||||
|
def get_indicators_range(
|
||||||
|
self, ticker: str, start_date: datetime, end_date: datetime
|
||||||
|
) -> List[IndicatorsData]:
|
||||||
|
"""
|
||||||
|
Get indicators for a ticker within a date range.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
ticker: Stock ticker symbol
|
||||||
|
start_date: Start of date range (inclusive)
|
||||||
|
end_date: End of date range (inclusive)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of IndicatorsData records, sorted by date
|
||||||
|
"""
|
||||||
|
with self.get_session() as session:
|
||||||
|
statement = (
|
||||||
|
select(IndicatorsData)
|
||||||
|
.where(
|
||||||
|
IndicatorsData.ticker == ticker,
|
||||||
|
IndicatorsData.date >= start_date,
|
||||||
|
IndicatorsData.date <= end_date,
|
||||||
|
)
|
||||||
|
.order_by(IndicatorsData.date)
|
||||||
|
)
|
||||||
|
return list(session.exec(statement).all())
|
||||||
|
|
||||||
|
def get_ohlcv_with_indicators(
|
||||||
|
self, ticker: str, date: datetime
|
||||||
|
) -> Optional[Tuple[TickerOHLCV, IndicatorsData]]:
|
||||||
|
"""
|
||||||
|
Get OHLCV and indicators together for a specific ticker and date.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
ticker: Stock ticker symbol
|
||||||
|
date: Trading date
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple of (TickerOHLCV, IndicatorsData) or None if not found
|
||||||
|
"""
|
||||||
|
with self.get_session() as session:
|
||||||
|
ohlcv_statement = select(TickerOHLCV).where(
|
||||||
|
TickerOHLCV.ticker == ticker, TickerOHLCV.date == date
|
||||||
|
)
|
||||||
|
ohlcv = session.exec(ohlcv_statement).first()
|
||||||
|
|
||||||
|
if not ohlcv:
|
||||||
|
return None
|
||||||
|
|
||||||
|
indicators_statement = select(IndicatorsData).where(
|
||||||
|
IndicatorsData.ticker == ticker, IndicatorsData.date == date
|
||||||
|
)
|
||||||
|
indicators = session.exec(indicators_statement).first()
|
||||||
|
|
||||||
|
if not indicators:
|
||||||
|
return None
|
||||||
|
|
||||||
|
return (ohlcv, indicators)
|
||||||
|
|
||||||
|
def get_time_series(
|
||||||
|
self, ticker: str, start_date: datetime, end_date: datetime
|
||||||
|
) -> TimeSeriesTickerData:
|
||||||
|
"""
|
||||||
|
Get time series data for building TimeSeriesTickerData.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
ticker: Stock ticker symbol
|
||||||
|
start_date: Start of date range
|
||||||
|
end_date: End of date range
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
TimeSeriesTickerData instance with OHLCV data
|
||||||
|
"""
|
||||||
|
ohlcv_list = self.get_ohlcv_range(ticker, start_date, end_date)
|
||||||
|
return TimeSeriesTickerData.build_time_series_ticker_data(ticker, ohlcv_list)
|
||||||
|
|
||||||
|
# ========================================================================
|
||||||
|
# UPDATE Operations
|
||||||
|
# ========================================================================
|
||||||
|
|
||||||
|
def update_ohlcv(
|
||||||
|
self, ticker: str, date: datetime, **kwargs
|
||||||
|
) -> Optional[TickerOHLCV]:
|
||||||
|
"""
|
||||||
|
Update OHLCV fields for a specific record.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
ticker: Stock ticker symbol
|
||||||
|
date: Trading date
|
||||||
|
**kwargs: Fields to update (e.g., close=150.0, volume=1000000)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Updated TickerOHLCV record or None if not found
|
||||||
|
"""
|
||||||
|
with self.get_session() as session:
|
||||||
|
statement = select(TickerOHLCV).where(
|
||||||
|
TickerOHLCV.ticker == ticker, TickerOHLCV.date == date
|
||||||
|
)
|
||||||
|
ohlcv = session.exec(statement).first()
|
||||||
|
|
||||||
|
if not ohlcv:
|
||||||
|
return None
|
||||||
|
|
||||||
|
for key, value in kwargs.items():
|
||||||
|
if hasattr(ohlcv, key):
|
||||||
|
setattr(ohlcv, key, value)
|
||||||
|
|
||||||
|
session.add(ohlcv)
|
||||||
|
session.commit()
|
||||||
|
session.refresh(ohlcv)
|
||||||
|
return ohlcv
|
||||||
|
|
||||||
|
def update_indicators(
|
||||||
|
self, ticker: str, date: datetime, **kwargs
|
||||||
|
) -> Optional[IndicatorsData]:
|
||||||
|
"""
|
||||||
|
Update indicator fields for a specific record.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
ticker: Stock ticker symbol
|
||||||
|
date: Trading date
|
||||||
|
**kwargs: Fields to update (e.g., rsi_14=55.2, macd_line=1.8)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Updated IndicatorsData record or None if not found
|
||||||
|
"""
|
||||||
|
with self.get_session() as session:
|
||||||
|
statement = select(IndicatorsData).where(
|
||||||
|
IndicatorsData.ticker == ticker, IndicatorsData.date == date
|
||||||
|
)
|
||||||
|
indicators = session.exec(statement).first()
|
||||||
|
|
||||||
|
if not indicators:
|
||||||
|
return None
|
||||||
|
|
||||||
|
for key, value in kwargs.items():
|
||||||
|
if hasattr(indicators, key):
|
||||||
|
setattr(indicators, key, value)
|
||||||
|
|
||||||
|
session.add(indicators)
|
||||||
|
session.commit()
|
||||||
|
session.refresh(indicators)
|
||||||
|
return indicators
|
||||||
|
|
||||||
|
# ========================================================================
|
||||||
|
# UPSERT Operations (Insert or Update)
|
||||||
|
# ========================================================================
|
||||||
|
|
||||||
|
def upsert_ohlcv(self, ohlcv: TickerOHLCV) -> TickerOHLCV:
|
||||||
|
"""
|
||||||
|
Insert or update OHLCV record if it already exists.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
ohlcv: TickerOHLCV instance
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The upserted TickerOHLCV record
|
||||||
|
"""
|
||||||
|
existing = self.get_ohlcv(ohlcv.ticker, ohlcv.date)
|
||||||
|
|
||||||
|
if existing:
|
||||||
|
return self.update_ohlcv(
|
||||||
|
ohlcv.ticker,
|
||||||
|
ohlcv.date,
|
||||||
|
open=ohlcv.open,
|
||||||
|
close=ohlcv.close,
|
||||||
|
low=ohlcv.low,
|
||||||
|
high=ohlcv.high,
|
||||||
|
avg=ohlcv.avg,
|
||||||
|
volume=ohlcv.volume,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
return self.create_ohlcv(ohlcv)
|
||||||
|
|
||||||
|
def upsert_indicators(self, indicators: IndicatorsData) -> IndicatorsData:
|
||||||
|
"""
|
||||||
|
Insert or update indicators record if it already exists.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
indicators: IndicatorsData instance
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The upserted IndicatorsData record
|
||||||
|
"""
|
||||||
|
existing = self.get_indicators(indicators.ticker, indicators.date)
|
||||||
|
|
||||||
|
if existing:
|
||||||
|
# Get all indicator fields dynamically
|
||||||
|
indicator_fields = {
|
||||||
|
k: v
|
||||||
|
for k, v in indicators.__dict__.items()
|
||||||
|
if k not in ["ticker", "date", "ohlcv", "_sa_instance_state"]
|
||||||
|
}
|
||||||
|
return self.update_indicators(
|
||||||
|
indicators.ticker, indicators.date, **indicator_fields
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
return self.create_indicators(indicators)
|
||||||
|
|
||||||
|
# ========================================================================
|
||||||
|
# DELETE Operations
|
||||||
|
# ========================================================================
|
||||||
|
|
||||||
|
def delete_ohlcv(self, ticker: str, date: datetime) -> bool:
|
||||||
|
"""
|
||||||
|
Delete an OHLCV record.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
ticker: Stock ticker symbol
|
||||||
|
date: Trading date
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if deleted, False if not found
|
||||||
|
"""
|
||||||
|
with self.get_session() as session:
|
||||||
|
statement = select(TickerOHLCV).where(
|
||||||
|
TickerOHLCV.ticker == ticker, TickerOHLCV.date == date
|
||||||
|
)
|
||||||
|
ohlcv = session.exec(statement).first()
|
||||||
|
|
||||||
|
if not ohlcv:
|
||||||
|
return False
|
||||||
|
|
||||||
|
session.delete(ohlcv)
|
||||||
|
session.commit()
|
||||||
|
return True
|
||||||
|
|
||||||
|
def delete_indicators(self, ticker: str, date: datetime) -> bool:
|
||||||
|
"""
|
||||||
|
Delete an indicators record.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
ticker: Stock ticker symbol
|
||||||
|
date: Trading date
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if deleted, False if not found
|
||||||
|
"""
|
||||||
|
with self.get_session() as session:
|
||||||
|
statement = select(IndicatorsData).where(
|
||||||
|
IndicatorsData.ticker == ticker, IndicatorsData.date == date
|
||||||
|
)
|
||||||
|
indicators = session.exec(statement).first()
|
||||||
|
|
||||||
|
if not indicators:
|
||||||
|
return False
|
||||||
|
|
||||||
|
session.delete(indicators)
|
||||||
|
session.commit()
|
||||||
|
return True
|
||||||
|
|
||||||
|
def delete_ticker_data(self, ticker: str) -> Tuple[int, int]:
|
||||||
|
"""
|
||||||
|
Delete all data (OHLCV and indicators) for a ticker.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
ticker: Stock ticker symbol
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple of (ohlcv_deleted_count, indicators_deleted_count)
|
||||||
|
"""
|
||||||
|
with self.get_session() as session:
|
||||||
|
# Delete indicators first (due to foreign key constraint)
|
||||||
|
indicators_statement = select(IndicatorsData).where(
|
||||||
|
IndicatorsData.ticker == ticker
|
||||||
|
)
|
||||||
|
indicators_list = session.exec(indicators_statement).all()
|
||||||
|
indicators_count = len(list(indicators_list))
|
||||||
|
|
||||||
|
for ind in indicators_list:
|
||||||
|
session.delete(ind)
|
||||||
|
|
||||||
|
# Delete OHLCV
|
||||||
|
ohlcv_statement = select(TickerOHLCV).where(TickerOHLCV.ticker == ticker)
|
||||||
|
ohlcv_list = session.exec(ohlcv_statement).all()
|
||||||
|
ohlcv_count = len(list(ohlcv_list))
|
||||||
|
|
||||||
|
for ohlcv in ohlcv_list:
|
||||||
|
session.delete(ohlcv)
|
||||||
|
|
||||||
|
session.commit()
|
||||||
|
return (ohlcv_count, indicators_count)
|
Loading…
x
Reference in New Issue
Block a user