diff --git a/src/db/db_connection.py b/src/db/db_connection.py index 4865c26..7e21222 100644 --- a/src/db/db_connection.py +++ b/src/db/db_connection.py @@ -1,10 +1,9 @@ import logging import os import time -from queue import Queue -from threading import Lock import clickhouse_connect from dotenv import load_dotenv +from contextlib import contextmanager # Load environment variables from .env file load_dotenv() @@ -13,117 +12,43 @@ load_dotenv() logging.basicConfig(level=logging.WARNING) logger = logging.getLogger(__name__) -# Connection pool settings -MAX_POOL_SIZE = 5 -connection_pool = Queue(maxsize=MAX_POOL_SIZE) -pool_lock = Lock() - -def create_base_client(): - """ - Create a base ClickHouse client without session management - """ - clickhouse_password = os.getenv("CLICKHOUSE_PASSWORD") - if not clickhouse_password: - raise ValueError("CLICKHOUSE_PASSWORD environment variable not set.") - - max_retries = 5 - retry_delay = 1 - - last_exception = None - - for attempt in range(max_retries): - try: - client = clickhouse_connect.get_client( - host="clickhouse.abellana.work", - port=443, - username="default", - password=clickhouse_password, - secure=True, - connect_timeout=10, - send_receive_timeout=300, - query_limit=0, - compress=True, - settings={ - 'enable_http_compression': 1, - 'database': 'stock_db', - 'max_execution_time': 300, - 'mutations_sync': 0 - } - ) - - # Test the connection - try: - client.query('SELECT 1') - logger.debug(f"Successfully established connection") - return client - except Exception as e: - logger.error(f"Connection test failed: {str(e)}") - raise - - except Exception as e: - last_exception = e - error_message = str(e) - - if attempt < max_retries - 1: - wait_time = retry_delay * (2 ** attempt) - logger.warning(f"Connection attempt {attempt + 1} failed: {error_message}") - logger.warning(f"Retrying in {wait_time} seconds...") - time.sleep(wait_time) - continue - - logger.error(f"Failed to establish connection after {max_retries} attempts: {error_message}") - raise last_exception - -def initialize_pool(): - """Initialize the connection pool""" - while not connection_pool.full(): - try: - client = create_base_client() - connection_pool.put(client) - except Exception as e: - logger.error(f"Error initializing pool connection: {str(e)}") - break - +@contextmanager def create_client(): """ - Get a client from the pool or create a new one if needed. - Maintains the original function name for compatibility. + Context manager for creating and properly closing ClickHouse connections """ - # Initialize pool if empty - with pool_lock: - if connection_pool.empty(): - initialize_pool() - + client = None try: - # Try to get a connection from the pool - client = connection_pool.get_nowait() + clickhouse_password = os.getenv("CLICKHOUSE_PASSWORD") + if not clickhouse_password: + raise ValueError("CLICKHOUSE_PASSWORD environment variable not set.") + + client = clickhouse_connect.get_client( + host="clickhouse.abellana.work", + port=443, + username="default", + password=clickhouse_password, + secure=True, + connect_timeout=10, + send_receive_timeout=300, + query_limit=0, + compress=True, + settings={ + 'enable_http_compression': 1, + 'database': 'stock_db', + 'max_execution_time': 300, + 'mutations_sync': 0 + } + ) + + yield client - # Test the connection - try: - client.query('SELECT 1') - return client - except Exception: - # Connection is dead, create a new one - logger.warning("Replacing dead connection in pool") - client = create_base_client() - return client - - except Exception: - # Pool is empty or other error, create new connection - logger.warning("Creating new connection (pool empty or error)") - return create_base_client() - -def return_client(client): - """ - Return a client to the pool if possible - """ - try: - if not connection_pool.full(): - connection_pool.put(client) - else: - client.close() except Exception as e: - logger.warning(f"Error returning client to pool: {str(e)}") - -# Export the create_client function -__all__ = ['create_client'] + logger.error(f"Connection error: {str(e)}") + raise + finally: + if client: + try: + client.close() + except Exception as e: + logger.warning(f"Error closing connection: {str(e)}") diff --git a/src/utils/data_utils.py b/src/utils/data_utils.py index 7c0c9cf..6bab95a 100644 --- a/src/utils/data_utils.py +++ b/src/utils/data_utils.py @@ -73,13 +73,12 @@ def get_stock_data(ticker: str, start_date: datetime, end_date: datetime, interv pd.DataFrame: Resampled DataFrame with OHLCV data """ try: - client = create_client() - - # Expand window to get enough data for calculations - start_date = start_date - timedelta(days=90) - - # Base query to get raw data at finest granularity - query = f""" + with create_client() as client: + # Expand window to get enough data for calculations + start_date = start_date - timedelta(days=90) + + # Base query to get raw data at finest granularity + query = f""" SELECT toDateTime(window_start/1000000000) as date, open,