From 012fa0e071f326616f48b0857d876d5a81c86621 Mon Sep 17 00:00:00 2001 From: "Bobby (aider)" Date: Sat, 8 Feb 2025 12:13:59 -0800 Subject: [PATCH] refactor: Move `get_stock_data` to utils module for shared usage --- src/screener/t_atr_ema_v2.py | 2 +- src/screener/t_sunnyband.py | 107 ------------------------------- src/utils/__init__.py | 3 + src/utils/data_utils.py | 120 +++++++++++++++++++++++++++++++++++ 4 files changed, 124 insertions(+), 108 deletions(-) create mode 100644 src/utils/__init__.py create mode 100644 src/utils/data_utils.py diff --git a/src/screener/t_atr_ema_v2.py b/src/screener/t_atr_ema_v2.py index 9bbdf5a..0ee2a50 100644 --- a/src/screener/t_atr_ema_v2.py +++ b/src/screener/t_atr_ema_v2.py @@ -3,7 +3,7 @@ import pandas as pd import os from db.db_connection import create_client from trading.position_calculator import PositionCalculator -from screener.t_sunnyband import get_stock_data +from utils.data_utils import get_stock_data from screener.user_input import get_interval_choice from indicators.three_atr_ema import ThreeATREMAIndicator diff --git a/src/screener/t_sunnyband.py b/src/screener/t_sunnyband.py index ddb70f0..0c2ce7a 100644 --- a/src/screener/t_sunnyband.py +++ b/src/screener/t_sunnyband.py @@ -8,113 +8,6 @@ from trading.position_calculator import PositionCalculator from screener.user_input import get_interval_choice -def get_stock_data(ticker: str, start_date: datetime, end_date: datetime, interval: str) -> pd.DataFrame: - """Fetch stock data from the database with enhanced fallback logic""" - try: - client = create_client() - - # Expand window to 90 days for more data robustness - start_date = start_date - timedelta(days=90) - - # First try primary data source - if interval == "daily": - table = "stock_prices_daily" - else: - table = "stock_prices" - - # Unified query format - query = f""" - SELECT - toDateTime(window_start/1000000000) as date, - open, - high, - low, - close, - volume - FROM stock_db.stock_prices - WHERE ticker = '{ticker}' - AND window_start BETWEEN - {int(start_date.timestamp() * 1e9)} AND - {int(end_date.timestamp() * 1e9)} - AND toYear(toDateTime(window_start/1000000000)) <= toYear(now()) - AND toYear(toDateTime(window_start/1000000000)) >= (toYear(now()) - 1) - ORDER BY date ASC - """ - - result = client.query(query) - - # Fallback to intraday data if needed - if not result.result_rows and interval == "daily": - # Try building daily bars from intraday data - print(f"⚠️ No daily data for {ticker}, resampling from intraday data") - intraday_query = f""" - SELECT - toDateTime(window_start/1000000000) as date, - first_value(open) AS open, - max(high) AS high, - min(low) AS low, - last_value(close) AS close, - sum(volume) AS volume - FROM stock_db.stock_prices - WHERE ticker = '{ticker}' - AND window_start BETWEEN - {int(start_date.timestamp() * 1e9)} AND - {int(end_date.timestamp() * 1e9)} - AND toYear(toDateTime(window_start/1000000000)) <= toYear(now()) - AND toYear(toDateTime(window_start/1000000000)) >= (toYear(now()) - 1) - GROUP BY date - ORDER BY date ASC - """ - result = client.query(intraday_query) - - # Fallback to different intervals if still empty - if not result.result_rows: - # Try alternate data sources - print(f"⚠️ No {interval} data for {ticker}, trying weekly") - weekly_query = f""" - SELECT - toStartOfWeek(window_start) AS date, - first_value(open) AS open, - max(high) AS high, - min(low) AS low, - last_value(close) AS close, - sum(volume) AS volume - FROM stock_db.stock_prices - WHERE ticker = '{ticker}' - GROUP BY date - ORDER BY date ASC - """ - result = client.query(weekly_query) - - if not result.result_rows: - return pd.DataFrame() - - df = pd.DataFrame( - result.result_rows, - columns=['date', 'open', 'high', 'low', 'close', 'volume'] - ) - - # Convert numeric columns - numeric_columns = ['open', 'high', 'low', 'close', 'volume'] - for col in numeric_columns: - df[col] = pd.to_numeric(df[col], errors='coerce') - - # Handle null values - if df['close'].isnull().any(): - print(f"Warning: Found null values in close prices") - df = df.dropna(subset=['close']) - - if df.empty or 'close' not in df.columns: - return pd.DataFrame() - - if df['date'].dtype == object: - df['date'] = pd.to_datetime(df['date']) - - return df - - except Exception as e: - print(f"Error fetching {ticker} data: {str(e)}") - return pd.DataFrame() def get_valid_tickers(min_price: float, max_price: float, min_volume: int, interval: str) -> list: """Get tickers that meet the price and volume criteria""" diff --git a/src/utils/__init__.py b/src/utils/__init__.py new file mode 100644 index 0000000..04477bd --- /dev/null +++ b/src/utils/__init__.py @@ -0,0 +1,3 @@ +from .data_utils import get_stock_data + +__all__ = ['get_stock_data'] diff --git a/src/utils/data_utils.py b/src/utils/data_utils.py new file mode 100644 index 0000000..3174d80 --- /dev/null +++ b/src/utils/data_utils.py @@ -0,0 +1,120 @@ +import pandas as pd +from datetime import datetime, timedelta +from db.db_connection import create_client + +def get_stock_data(ticker: str, start_date: datetime, end_date: datetime, interval: str) -> pd.DataFrame: + """ + Fetch stock data from the database with enhanced fallback logic + + Args: + ticker (str): Stock ticker symbol + start_date (datetime): Start date for data fetch + end_date (datetime): End date for data fetch + interval (str): Time interval for data ('daily', '5min', etc.) + + Returns: + pd.DataFrame: DataFrame with OHLCV data + """ + try: + client = create_client() + + # Expand window to 90 days for more data robustness + start_date = start_date - timedelta(days=90) + + # First try primary data source + if interval == "daily": + table = "stock_prices_daily" + else: + table = "stock_prices" + + # Unified query format + query = f""" + SELECT + toDateTime(window_start/1000000000) as date, + open, + high, + low, + close, + volume + FROM stock_db.stock_prices + WHERE ticker = '{ticker}' + AND window_start BETWEEN + {int(start_date.timestamp() * 1e9)} AND + {int(end_date.timestamp() * 1e9)} + AND toYear(toDateTime(window_start/1000000000)) <= toYear(now()) + AND toYear(toDateTime(window_start/1000000000)) >= (toYear(now()) - 1) + ORDER BY date ASC + """ + + result = client.query(query) + + # Fallback to intraday data if needed + if not result.result_rows and interval == "daily": + print(f"⚠️ No daily data for {ticker}, resampling from intraday data") + intraday_query = f""" + SELECT + toDateTime(window_start/1000000000) as date, + first_value(open) AS open, + max(high) AS high, + min(low) AS low, + last_value(close) AS close, + sum(volume) AS volume + FROM stock_db.stock_prices + WHERE ticker = '{ticker}' + AND window_start BETWEEN + {int(start_date.timestamp() * 1e9)} AND + {int(end_date.timestamp() * 1e9)} + AND toYear(toDateTime(window_start/1000000000)) <= toYear(now()) + AND toYear(toDateTime(window_start/1000000000)) >= (toYear(now()) - 1) + GROUP BY date + ORDER BY date ASC + """ + result = client.query(intraday_query) + + # Fallback to different intervals if still empty + if not result.result_rows: + print(f"⚠️ No {interval} data for {ticker}, trying weekly") + weekly_query = f""" + SELECT + toStartOfWeek(window_start) AS date, + first_value(open) AS open, + max(high) AS high, + min(low) AS low, + last_value(close) AS close, + sum(volume) AS volume + FROM stock_db.stock_prices + WHERE ticker = '{ticker}' + GROUP BY date + ORDER BY date ASC + """ + result = client.query(weekly_query) + + if not result.result_rows: + return pd.DataFrame() + + df = pd.DataFrame( + result.result_rows, + columns=['date', 'open', 'high', 'low', 'close', 'volume'] + ) + + # Convert numeric columns + numeric_columns = ['open', 'high', 'low', 'close', 'volume'] + for col in numeric_columns: + df[col] = pd.to_numeric(df[col], errors='coerce') + + # Handle null values + if df['close'].isnull().any(): + print(f"Warning: Found null values in close prices") + df = df.dropna(subset=['close']) + + if df.empty or 'close' not in df.columns: + return pd.DataFrame() + + if df['date'].dtype == object: + df['date'] = pd.to_datetime(df['date']) + + return df + + except Exception as e: + print(f"Error fetching {ticker} data: {str(e)}") + return pd.DataFrame()