feat(database): add trading data CRUD operations
This commit is contained in:
parent
b4cc94a444
commit
f5aa0e5848
476
paperone/database.py
Normal file
476
paperone/database.py
Normal file
@ -0,0 +1,476 @@
|
||||
from sqlmodel import SQLModel, Session, create_engine, select
|
||||
from typing import List, Optional, Tuple
|
||||
from datetime import datetime
|
||||
from contextlib import contextmanager
|
||||
from .models import TickerOHLCV, IndicatorsData
|
||||
from .entities import TimeSeriesTickerData
|
||||
|
||||
|
||||
class TradingDataCRUD:
|
||||
"""
|
||||
CRUD operations for trading data with SQLModel/SQLite.
|
||||
Handles OHLCV data and technical indicators with proper session management.
|
||||
"""
|
||||
|
||||
def __init__(self, database_url: str = "sqlite:///trading_data.db"):
|
||||
"""
|
||||
Initialize the CRUD manager with database connection.
|
||||
|
||||
Args:
|
||||
database_url: SQLite database URL (default: local file)
|
||||
"""
|
||||
self.engine = create_engine(
|
||||
database_url,
|
||||
echo=False, # Set to True for SQL query debugging
|
||||
connect_args={"check_same_thread": False}, # Needed for SQLite
|
||||
)
|
||||
self._create_tables()
|
||||
|
||||
def _create_tables(self):
|
||||
"""Create all tables if they don't exist."""
|
||||
SQLModel.metadata.create_all(self.engine)
|
||||
|
||||
@contextmanager
|
||||
def get_session(self):
|
||||
"""Context manager for database sessions with automatic cleanup."""
|
||||
session = Session(self.engine)
|
||||
try:
|
||||
yield session
|
||||
session.commit()
|
||||
except Exception:
|
||||
session.rollback()
|
||||
raise
|
||||
finally:
|
||||
session.close()
|
||||
|
||||
# ========================================================================
|
||||
# CREATE Operations
|
||||
# ========================================================================
|
||||
|
||||
def create_ohlcv(self, ohlcv: TickerOHLCV) -> TickerOHLCV:
|
||||
"""
|
||||
Insert a single OHLCV record.
|
||||
|
||||
Args:
|
||||
ohlcv: TickerOHLCV instance to insert
|
||||
|
||||
Returns:
|
||||
The inserted TickerOHLCV record
|
||||
"""
|
||||
with self.get_session() as session:
|
||||
session.add(ohlcv)
|
||||
session.commit()
|
||||
session.refresh(ohlcv)
|
||||
return ohlcv
|
||||
|
||||
def create_ohlcv_bulk(self, ohlcv_list: List[TickerOHLCV]) -> int:
|
||||
"""
|
||||
Bulk insert OHLCV records (more efficient for large datasets).
|
||||
|
||||
Args:
|
||||
ohlcv_list: List of TickerOHLCV instances
|
||||
|
||||
Returns:
|
||||
Number of records inserted
|
||||
"""
|
||||
with self.get_session() as session:
|
||||
session.add_all(ohlcv_list)
|
||||
session.commit()
|
||||
return len(ohlcv_list)
|
||||
|
||||
def create_indicators(self, indicators: IndicatorsData) -> IndicatorsData:
|
||||
"""
|
||||
Insert a single indicators record.
|
||||
|
||||
Args:
|
||||
indicators: IndicatorsData instance to insert
|
||||
|
||||
Returns:
|
||||
The inserted IndicatorsData record
|
||||
"""
|
||||
with self.get_session() as session:
|
||||
session.add(indicators)
|
||||
session.commit()
|
||||
session.refresh(indicators)
|
||||
return indicators
|
||||
|
||||
def create_indicators_bulk(self, indicators_list: List[IndicatorsData]) -> int:
|
||||
"""
|
||||
Bulk insert indicators records.
|
||||
|
||||
Args:
|
||||
indicators_list: List of IndicatorsData instances
|
||||
|
||||
Returns:
|
||||
Number of records inserted
|
||||
"""
|
||||
with self.get_session() as session:
|
||||
session.add_all(indicators_list)
|
||||
session.commit()
|
||||
return len(indicators_list)
|
||||
|
||||
# ========================================================================
|
||||
# READ Operations
|
||||
# ========================================================================
|
||||
|
||||
def get_ohlcv(self, ticker: str, date: datetime) -> Optional[TickerOHLCV]:
|
||||
"""
|
||||
Get a single OHLCV record by ticker and date.
|
||||
|
||||
Args:
|
||||
ticker: Stock ticker symbol
|
||||
date: Trading date
|
||||
|
||||
Returns:
|
||||
TickerOHLCV record or None if not found
|
||||
"""
|
||||
with self.get_session() as session:
|
||||
statement = select(TickerOHLCV).where(
|
||||
TickerOHLCV.ticker == ticker, TickerOHLCV.date == date
|
||||
)
|
||||
return session.exec(statement).first()
|
||||
|
||||
def get_ohlcv_range(
|
||||
self, ticker: str, start_date: datetime, end_date: datetime
|
||||
) -> List[TickerOHLCV]:
|
||||
"""
|
||||
Get OHLCV records for a ticker within a date range.
|
||||
|
||||
Args:
|
||||
ticker: Stock ticker symbol
|
||||
start_date: Start of date range (inclusive)
|
||||
end_date: End of date range (inclusive)
|
||||
|
||||
Returns:
|
||||
List of TickerOHLCV records, sorted by date
|
||||
"""
|
||||
with self.get_session() as session:
|
||||
statement = (
|
||||
select(TickerOHLCV)
|
||||
.where(
|
||||
TickerOHLCV.ticker == ticker,
|
||||
TickerOHLCV.date >= start_date,
|
||||
TickerOHLCV.date <= end_date,
|
||||
)
|
||||
.order_by(TickerOHLCV.date)
|
||||
)
|
||||
return list(session.exec(statement).all())
|
||||
|
||||
def get_ohlcv_latest(self, ticker: str, limit: int = 1) -> List[TickerOHLCV]:
|
||||
"""
|
||||
Get the most recent OHLCV record(s) for a ticker.
|
||||
|
||||
Args:
|
||||
ticker: Stock ticker symbol
|
||||
limit: Number of recent records to retrieve
|
||||
|
||||
Returns:
|
||||
List of most recent TickerOHLCV records, sorted by date descending
|
||||
"""
|
||||
with self.get_session() as session:
|
||||
statement = (
|
||||
select(TickerOHLCV)
|
||||
.where(TickerOHLCV.ticker == ticker)
|
||||
.order_by(TickerOHLCV.date.desc())
|
||||
.limit(limit)
|
||||
)
|
||||
return list(session.exec(statement).all())
|
||||
|
||||
def get_indicators(self, ticker: str, date: datetime) -> Optional[IndicatorsData]:
|
||||
"""
|
||||
Get indicators for a specific ticker and date.
|
||||
|
||||
Args:
|
||||
ticker: Stock ticker symbol
|
||||
date: Trading date
|
||||
|
||||
Returns:
|
||||
IndicatorsData record or None if not found
|
||||
"""
|
||||
with self.get_session() as session:
|
||||
statement = select(IndicatorsData).where(
|
||||
IndicatorsData.ticker == ticker, IndicatorsData.date == date
|
||||
)
|
||||
return session.exec(statement).first()
|
||||
|
||||
def get_indicators_range(
|
||||
self, ticker: str, start_date: datetime, end_date: datetime
|
||||
) -> List[IndicatorsData]:
|
||||
"""
|
||||
Get indicators for a ticker within a date range.
|
||||
|
||||
Args:
|
||||
ticker: Stock ticker symbol
|
||||
start_date: Start of date range (inclusive)
|
||||
end_date: End of date range (inclusive)
|
||||
|
||||
Returns:
|
||||
List of IndicatorsData records, sorted by date
|
||||
"""
|
||||
with self.get_session() as session:
|
||||
statement = (
|
||||
select(IndicatorsData)
|
||||
.where(
|
||||
IndicatorsData.ticker == ticker,
|
||||
IndicatorsData.date >= start_date,
|
||||
IndicatorsData.date <= end_date,
|
||||
)
|
||||
.order_by(IndicatorsData.date)
|
||||
)
|
||||
return list(session.exec(statement).all())
|
||||
|
||||
def get_ohlcv_with_indicators(
|
||||
self, ticker: str, date: datetime
|
||||
) -> Optional[Tuple[TickerOHLCV, IndicatorsData]]:
|
||||
"""
|
||||
Get OHLCV and indicators together for a specific ticker and date.
|
||||
|
||||
Args:
|
||||
ticker: Stock ticker symbol
|
||||
date: Trading date
|
||||
|
||||
Returns:
|
||||
Tuple of (TickerOHLCV, IndicatorsData) or None if not found
|
||||
"""
|
||||
with self.get_session() as session:
|
||||
ohlcv_statement = select(TickerOHLCV).where(
|
||||
TickerOHLCV.ticker == ticker, TickerOHLCV.date == date
|
||||
)
|
||||
ohlcv = session.exec(ohlcv_statement).first()
|
||||
|
||||
if not ohlcv:
|
||||
return None
|
||||
|
||||
indicators_statement = select(IndicatorsData).where(
|
||||
IndicatorsData.ticker == ticker, IndicatorsData.date == date
|
||||
)
|
||||
indicators = session.exec(indicators_statement).first()
|
||||
|
||||
if not indicators:
|
||||
return None
|
||||
|
||||
return (ohlcv, indicators)
|
||||
|
||||
def get_time_series(
|
||||
self, ticker: str, start_date: datetime, end_date: datetime
|
||||
) -> TimeSeriesTickerData:
|
||||
"""
|
||||
Get time series data for building TimeSeriesTickerData.
|
||||
|
||||
Args:
|
||||
ticker: Stock ticker symbol
|
||||
start_date: Start of date range
|
||||
end_date: End of date range
|
||||
|
||||
Returns:
|
||||
TimeSeriesTickerData instance with OHLCV data
|
||||
"""
|
||||
ohlcv_list = self.get_ohlcv_range(ticker, start_date, end_date)
|
||||
return TimeSeriesTickerData.build_time_series_ticker_data(ticker, ohlcv_list)
|
||||
|
||||
# ========================================================================
|
||||
# UPDATE Operations
|
||||
# ========================================================================
|
||||
|
||||
def update_ohlcv(
|
||||
self, ticker: str, date: datetime, **kwargs
|
||||
) -> Optional[TickerOHLCV]:
|
||||
"""
|
||||
Update OHLCV fields for a specific record.
|
||||
|
||||
Args:
|
||||
ticker: Stock ticker symbol
|
||||
date: Trading date
|
||||
**kwargs: Fields to update (e.g., close=150.0, volume=1000000)
|
||||
|
||||
Returns:
|
||||
Updated TickerOHLCV record or None if not found
|
||||
"""
|
||||
with self.get_session() as session:
|
||||
statement = select(TickerOHLCV).where(
|
||||
TickerOHLCV.ticker == ticker, TickerOHLCV.date == date
|
||||
)
|
||||
ohlcv = session.exec(statement).first()
|
||||
|
||||
if not ohlcv:
|
||||
return None
|
||||
|
||||
for key, value in kwargs.items():
|
||||
if hasattr(ohlcv, key):
|
||||
setattr(ohlcv, key, value)
|
||||
|
||||
session.add(ohlcv)
|
||||
session.commit()
|
||||
session.refresh(ohlcv)
|
||||
return ohlcv
|
||||
|
||||
def update_indicators(
|
||||
self, ticker: str, date: datetime, **kwargs
|
||||
) -> Optional[IndicatorsData]:
|
||||
"""
|
||||
Update indicator fields for a specific record.
|
||||
|
||||
Args:
|
||||
ticker: Stock ticker symbol
|
||||
date: Trading date
|
||||
**kwargs: Fields to update (e.g., rsi_14=55.2, macd_line=1.8)
|
||||
|
||||
Returns:
|
||||
Updated IndicatorsData record or None if not found
|
||||
"""
|
||||
with self.get_session() as session:
|
||||
statement = select(IndicatorsData).where(
|
||||
IndicatorsData.ticker == ticker, IndicatorsData.date == date
|
||||
)
|
||||
indicators = session.exec(statement).first()
|
||||
|
||||
if not indicators:
|
||||
return None
|
||||
|
||||
for key, value in kwargs.items():
|
||||
if hasattr(indicators, key):
|
||||
setattr(indicators, key, value)
|
||||
|
||||
session.add(indicators)
|
||||
session.commit()
|
||||
session.refresh(indicators)
|
||||
return indicators
|
||||
|
||||
# ========================================================================
|
||||
# UPSERT Operations (Insert or Update)
|
||||
# ========================================================================
|
||||
|
||||
def upsert_ohlcv(self, ohlcv: TickerOHLCV) -> TickerOHLCV:
|
||||
"""
|
||||
Insert or update OHLCV record if it already exists.
|
||||
|
||||
Args:
|
||||
ohlcv: TickerOHLCV instance
|
||||
|
||||
Returns:
|
||||
The upserted TickerOHLCV record
|
||||
"""
|
||||
existing = self.get_ohlcv(ohlcv.ticker, ohlcv.date)
|
||||
|
||||
if existing:
|
||||
return self.update_ohlcv(
|
||||
ohlcv.ticker,
|
||||
ohlcv.date,
|
||||
open=ohlcv.open,
|
||||
close=ohlcv.close,
|
||||
low=ohlcv.low,
|
||||
high=ohlcv.high,
|
||||
avg=ohlcv.avg,
|
||||
volume=ohlcv.volume,
|
||||
)
|
||||
else:
|
||||
return self.create_ohlcv(ohlcv)
|
||||
|
||||
def upsert_indicators(self, indicators: IndicatorsData) -> IndicatorsData:
|
||||
"""
|
||||
Insert or update indicators record if it already exists.
|
||||
|
||||
Args:
|
||||
indicators: IndicatorsData instance
|
||||
|
||||
Returns:
|
||||
The upserted IndicatorsData record
|
||||
"""
|
||||
existing = self.get_indicators(indicators.ticker, indicators.date)
|
||||
|
||||
if existing:
|
||||
# Get all indicator fields dynamically
|
||||
indicator_fields = {
|
||||
k: v
|
||||
for k, v in indicators.__dict__.items()
|
||||
if k not in ["ticker", "date", "ohlcv", "_sa_instance_state"]
|
||||
}
|
||||
return self.update_indicators(
|
||||
indicators.ticker, indicators.date, **indicator_fields
|
||||
)
|
||||
else:
|
||||
return self.create_indicators(indicators)
|
||||
|
||||
# ========================================================================
|
||||
# DELETE Operations
|
||||
# ========================================================================
|
||||
|
||||
def delete_ohlcv(self, ticker: str, date: datetime) -> bool:
|
||||
"""
|
||||
Delete an OHLCV record.
|
||||
|
||||
Args:
|
||||
ticker: Stock ticker symbol
|
||||
date: Trading date
|
||||
|
||||
Returns:
|
||||
True if deleted, False if not found
|
||||
"""
|
||||
with self.get_session() as session:
|
||||
statement = select(TickerOHLCV).where(
|
||||
TickerOHLCV.ticker == ticker, TickerOHLCV.date == date
|
||||
)
|
||||
ohlcv = session.exec(statement).first()
|
||||
|
||||
if not ohlcv:
|
||||
return False
|
||||
|
||||
session.delete(ohlcv)
|
||||
session.commit()
|
||||
return True
|
||||
|
||||
def delete_indicators(self, ticker: str, date: datetime) -> bool:
|
||||
"""
|
||||
Delete an indicators record.
|
||||
|
||||
Args:
|
||||
ticker: Stock ticker symbol
|
||||
date: Trading date
|
||||
|
||||
Returns:
|
||||
True if deleted, False if not found
|
||||
"""
|
||||
with self.get_session() as session:
|
||||
statement = select(IndicatorsData).where(
|
||||
IndicatorsData.ticker == ticker, IndicatorsData.date == date
|
||||
)
|
||||
indicators = session.exec(statement).first()
|
||||
|
||||
if not indicators:
|
||||
return False
|
||||
|
||||
session.delete(indicators)
|
||||
session.commit()
|
||||
return True
|
||||
|
||||
def delete_ticker_data(self, ticker: str) -> Tuple[int, int]:
|
||||
"""
|
||||
Delete all data (OHLCV and indicators) for a ticker.
|
||||
|
||||
Args:
|
||||
ticker: Stock ticker symbol
|
||||
|
||||
Returns:
|
||||
Tuple of (ohlcv_deleted_count, indicators_deleted_count)
|
||||
"""
|
||||
with self.get_session() as session:
|
||||
# Delete indicators first (due to foreign key constraint)
|
||||
indicators_statement = select(IndicatorsData).where(
|
||||
IndicatorsData.ticker == ticker
|
||||
)
|
||||
indicators_list = session.exec(indicators_statement).all()
|
||||
indicators_count = len(list(indicators_list))
|
||||
|
||||
for ind in indicators_list:
|
||||
session.delete(ind)
|
||||
|
||||
# Delete OHLCV
|
||||
ohlcv_statement = select(TickerOHLCV).where(TickerOHLCV.ticker == ticker)
|
||||
ohlcv_list = session.exec(ohlcv_statement).all()
|
||||
ohlcv_count = len(list(ohlcv_list))
|
||||
|
||||
for ohlcv in ohlcv_list:
|
||||
session.delete(ohlcv)
|
||||
|
||||
session.commit()
|
||||
return (ohlcv_count, indicators_count)
|
Loading…
x
Reference in New Issue
Block a user