From 3e98ba4e9d404a0e22bba41ffd2a26ba04a7a678 Mon Sep 17 00:00:00 2001
From: "Bobby (aider)"
Date: Wed, 12 Feb 2025 19:50:00 -0800
Subject: [PATCH] refactor: Resolve circular import by creating common_utils module

---
# Reviewer notes. Commentary only: text between "---" and the first
# "diff --git" line is not part of the commit message or of the change.
#
# What the patch does
#   * Adds src/utils/common_utils.py, which defines get_user_input,
#     get_stock_data and get_qualified_stocks.
#   * scanner_utils.py: deletes its own get_user_input and imports all three
#     helpers from utils.common_utils; it no longer imports from
#     utils.data_utils.
#   * data_utils.py: no longer imports from utils.scanner_utils and imports
#     the three helpers from utils.common_utils instead. Before this patch
#     data_utils imported scanner_utils and scanner_utils imported data_utils,
#     which is the cycle named in the subject.
#   * The get_user_input added here has the same statements as the one
#     removed from scanner_utils.py; only the Args/Returns part of its
#     docstring is dropped.
#
# Behaviour of the added helpers, as written
#   * get_user_input: loops until input parses; returns None on q/quit/exit,
#     and on empty input when allow_empty is true; for input_type == bool it
#     returns True only for y/yes/true/1.
#   * get_stock_data: queries from 90 days before start_date, resamples to
#     the rule chosen from `interval` (unrecognised values fall back to
#     '1D'), then keeps rows with start_date - 1 day <= date <= end_date
#     (the mask adds 89 days to the already shifted start). Prints and
#     returns an empty DataFrame on any exception.
#   * get_qualified_stocks: returns a list of 5-tuples
#     (ticker, last_close, total_volume, last_update, stock_type); prints and
#     returns [] on any exception.
#
# NOTE(review): data_utils.py drops its imports of create_client,
#   get_interval_choice, get_date_range and initialize_scanner. The rest of
#   that module is not shown here -- confirm none of them is still used.
# NOTE(review): `Optional[any]` refers to the builtin any(), not typing.Any.
# NOTE(review): `import os` is not used anywhere in common_utils.py.
# NOTE(review): both queries are built with f-strings; `ticker` is quoted but
#   not escaped. If it can come from user input, use the client's query
#   parameter binding instead -- TODO confirm the source of `ticker`.
# NOTE(review): the resample aliases '5T'/'15T'/'30T'/'1H' are reported as
#   deprecated by recent pandas releases -- confirm against the pinned
#   pandas version.
# NOTE(review): the 89-day offset in the mask looks like it is meant to
#   include the day before start_date -- confirm it is not an off-by-one.
# NOTE(review): leading whitespace in the hunks below was re-created by
#   convention; verify it against the target files before applying.
 src/utils/common_utils.py  | 172 +++++++++++++++++++++++++++++++++++++
 src/utils/data_utils.py    |   4 +-
 src/utils/scanner_utils.py |  30 +------
 3 files changed, 174 insertions(+), 32 deletions(-)
 create mode 100644 src/utils/common_utils.py

diff --git a/src/utils/common_utils.py b/src/utils/common_utils.py
new file mode 100644
index 0000000..feede68
--- /dev/null
+++ b/src/utils/common_utils.py
@@ -0,0 +1,172 @@
+import os
+import pandas as pd
+from datetime import datetime, timedelta
+from typing import Optional
+from db.db_connection import create_client
+
+def get_user_input(prompt: str, input_type: type = str, allow_empty: bool = False) -> Optional[any]:
+    """
+    Get user input with escape option
+    """
+    while True:
+        value = input(f"{prompt} (q to quit): ").strip()
+
+        if value.lower() in ['q', 'quit', 'exit']:
+            return None
+
+        if not value and allow_empty:
+            return None
+
+        try:
+            if input_type == bool:
+                return value.lower() in ['y', 'yes', 'true', '1']
+            return input_type(value)
+        except ValueError:
+            print(f"Please enter a valid {input_type.__name__}")
+
+def get_stock_data(ticker: str, start_date: datetime, end_date: datetime, interval: str) -> pd.DataFrame:
+    """
+    Fetch and resample stock data based on the chosen interval
+    """
+    try:
+        with create_client() as client:
+            # Expand window to get enough data for calculations
+            start_date = start_date - timedelta(days=90)
+
+            query = f"""
+                SELECT
+                    toDateTime(window_start/1000000000) as date,
+                    open,
+                    high,
+                    low,
+                    close,
+                    volume
+                FROM stock_db.stock_prices
+                WHERE ticker = '{ticker}'
+                AND window_start BETWEEN
+                    {int(start_date.timestamp() * 1e9)} AND
+                    {int(end_date.timestamp() * 1e9)}
+                AND toDateTime(window_start/1000000000) <= now()
+                ORDER BY date ASC
+            """
+
+            result = client.query(query)
+
+            if not result.result_rows:
+                return pd.DataFrame()
+
+            df = pd.DataFrame(
+                result.result_rows,
+                columns=['date', 'open', 'high', 'low', 'close', 'volume']
+            )
+
+            numeric_columns = ['open', 'high', 'low', 'close', 'volume']
+            for col in numeric_columns:
+                df[col] = pd.to_numeric(df[col], errors='coerce')
+
+            df['date'] = pd.to_datetime(df['date'])
+            df.set_index('date', inplace=True)
+
+            if interval == 'daily':
+                rule = '1D'
+            elif interval == '5min':
+                rule = '5T'
+            elif interval == '15min':
+                rule = '15T'
+            elif interval == '30min':
+                rule = '30T'
+            elif interval == '1hour':
+                rule = '1H'
+            else:
+                rule = '1D'
+
+            resampled = df.resample(rule).agg({
+                'open': 'first',
+                'high': 'max',
+                'low': 'min',
+                'close': 'last',
+                'volume': 'sum'
+            }).dropna()
+
+            resampled.reset_index(inplace=True)
+
+            mask = (resampled['date'] >= start_date + timedelta(days=89)) & (resampled['date'] <= end_date)
+            resampled = resampled.loc[mask]
+
+            if resampled['close'].isnull().any():
+                print(f"Warning: Found null values in close prices")
+                resampled = resampled.dropna(subset=['close'])
+
+            if resampled.empty or 'close' not in resampled.columns:
+                return pd.DataFrame()
+
+            return resampled
+
+    except Exception as e:
+        print(f"Error fetching {ticker} data: {str(e)}")
+        return pd.DataFrame()
+
+def get_qualified_stocks(start_date: datetime, end_date: datetime, min_price: float, max_price: float, min_volume: int) -> list:
+    """
+    Get qualified stocks based on price and volume criteria within date range
+    """
+    try:
+        start_ts = int(start_date.timestamp() * 1000000000)
+        end_ts = int(end_date.timestamp() * 1000000000)
+
+        with create_client() as client:
+            query = f"""
+                WITH filtered_data AS (
+                    SELECT
+                        sp.ticker,
+                        sp.window_start,
+                        sp.close,
+                        sp.volume,
+                        t.type as stock_type,
+                        toDateTime(toDateTime(sp.window_start/1000000000)) as trade_date
+                    FROM stock_db.stock_prices sp
+                    JOIN stock_db.stock_tickers t ON sp.ticker = t.ticker
+                    WHERE window_start BETWEEN {start_ts} AND {end_ts}
+                    AND toDateTime(window_start/1000000000) <= now()
+                    AND close BETWEEN {min_price} AND {max_price}
+                    AND volume >= {min_volume}
+                ),
+                daily_data AS (
+                    SELECT
+                        ticker,
+                        stock_type,
+                        toDate(trade_date) as date,
+                        argMax(close, window_start) as daily_close,
+                        sum(volume) as daily_volume
+                    FROM filtered_data
+                    GROUP BY ticker, stock_type, toDate(trade_date)
+                ),
+                latest_data AS (
+                    SELECT
+                        ticker,
+                        any(stock_type) as stock_type,
+                        argMax(daily_close, date) as last_close,
+                        sum(daily_volume) as total_volume,
+                        max(toUnixTimestamp(date)) as last_update
+                    FROM daily_data
+                    GROUP BY ticker
+                    HAVING last_close BETWEEN {min_price} AND {max_price}
+                )
+                SELECT
+                    ticker,
+                    last_close,
+                    total_volume,
+                    last_update,
+                    stock_type
+                FROM latest_data
+                ORDER BY ticker
+            """
+
+            result = client.query(query)
+            qualified_stocks = [(row[0], row[1], row[2], row[3], row[4]) for row in result.result_rows]
+
+            return qualified_stocks
+
+    except Exception as e:
+        print(f"Error getting qualified stocks: {str(e)}")
+        return []
diff --git a/src/utils/data_utils.py b/src/utils/data_utils.py
index 2f2be0c..f93e2d7 100644
--- a/src/utils/data_utils.py
+++ b/src/utils/data_utils.py
@@ -2,10 +2,8 @@ import os
 import pandas as pd
 import yfinance as yf
 from datetime import datetime, timedelta
-from db.db_connection import create_client
-from screener.user_input import get_interval_choice, get_date_range
 from trading.position_calculator import PositionCalculator
-from utils.scanner_utils import initialize_scanner, get_user_input
+from utils.common_utils import get_user_input, get_stock_data, get_qualified_stocks
 from typing import Optional
 
 def get_float_input(prompt: str) -> Optional[float]:
diff --git a/src/utils/scanner_utils.py b/src/utils/scanner_utils.py
index 123a4f2..dedc251 100644
--- a/src/utils/scanner_utils.py
+++ b/src/utils/scanner_utils.py
@@ -1,37 +1,9 @@
 from datetime import datetime, timedelta
-from utils.data_utils import get_stock_data, get_qualified_stocks
+from utils.common_utils import get_user_input, get_stock_data, get_qualified_stocks
 from screener.user_input import get_interval_choice, get_date_range
 from trading.position_calculator import PositionCalculator
 from typing import Optional
 
-def get_user_input(prompt: str, input_type: type = str, allow_empty: bool = False) -> Optional[any]:
-    """
-    Get user input with escape option
-
-    Args:
-        prompt (str): Input prompt to display
-        input_type (type): Expected input type (str, float, int)
-        allow_empty (bool): Whether to allow empty input
-
-    Returns:
-        Optional[any]: Converted input value or None if user wants to exit
-    """
-    while True:
-        value = input(f"{prompt} (q to quit): ").strip()
-
-        if value.lower() in ['q', 'quit', 'exit']:
-            return None
-
-        if not value and allow_empty:
-            return None
-
-        try:
-            if input_type == bool:
-                return value.lower() in ['y', 'yes', 'true', '1']
-            return input_type(value)
-        except ValueError:
-            print(f"Please enter a valid {input_type.__name__}")
-
 def initialize_scanner(min_price: float, max_price: float, min_volume: int,
                        portfolio_size: float = None, interval: str = "1d",
                        start_date: datetime = None, end_date: datetime = None) -> tuple: