refactor: Replace connection pool with context manager for ClickHouse connections

Bobby (aider) 2025-02-08 20:52:36 -08:00
parent 68e665a0c2
commit 865438bbf8
2 changed files with 41 additions and 117 deletions
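Callers now acquire a connection with a with block instead of pairing create_client() with return_client(). A minimal usage sketch, assuming the module is importable as clickhouse_client and a stock_prices table exists (both names are illustrative; only create_client comes from this commit):

from clickhouse_client import create_client  # module name is an assumption; the page does not show the filename

def count_rows() -> int:
    # create_client() is now a @contextmanager: it opens a connection,
    # yields it, and closes it in a finally block even if the query raises.
    with create_client() as client:
        # stock_prices is an illustrative table name, not taken from the commit
        result = client.query("SELECT count() FROM stock_prices")
        return result.result_rows[0][0]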

View File

@@ -1,10 +1,9 @@
import logging
import os
import time
from queue import Queue
from threading import Lock
import clickhouse_connect
from dotenv import load_dotenv
from contextlib import contextmanager
# Load environment variables from .env file
load_dotenv()
@@ -13,117 +12,43 @@ load_dotenv()
logging.basicConfig(level=logging.WARNING)
logger = logging.getLogger(__name__)
# Connection pool settings
MAX_POOL_SIZE = 5
connection_pool = Queue(maxsize=MAX_POOL_SIZE)
pool_lock = Lock()
def create_base_client():
"""
Create a base ClickHouse client without session management
"""
clickhouse_password = os.getenv("CLICKHOUSE_PASSWORD")
if not clickhouse_password:
raise ValueError("CLICKHOUSE_PASSWORD environment variable not set.")
max_retries = 5
retry_delay = 1
last_exception = None
for attempt in range(max_retries):
try:
client = clickhouse_connect.get_client(
host="clickhouse.abellana.work",
port=443,
username="default",
password=clickhouse_password,
secure=True,
connect_timeout=10,
send_receive_timeout=300,
query_limit=0,
compress=True,
settings={
'enable_http_compression': 1,
'database': 'stock_db',
'max_execution_time': 300,
'mutations_sync': 0
}
)
# Test the connection
try:
client.query('SELECT 1')
logger.debug(f"Successfully established connection")
return client
except Exception as e:
logger.error(f"Connection test failed: {str(e)}")
raise
except Exception as e:
last_exception = e
error_message = str(e)
if attempt < max_retries - 1:
wait_time = retry_delay * (2 ** attempt)
logger.warning(f"Connection attempt {attempt + 1} failed: {error_message}")
logger.warning(f"Retrying in {wait_time} seconds...")
time.sleep(wait_time)
continue
logger.error(f"Failed to establish connection after {max_retries} attempts: {error_message}")
raise last_exception
def initialize_pool():
"""Initialize the connection pool"""
while not connection_pool.full():
try:
client = create_base_client()
connection_pool.put(client)
except Exception as e:
logger.error(f"Error initializing pool connection: {str(e)}")
break
@contextmanager
def create_client():
"""
Get a client from the pool or create a new one if needed.
Maintains the original function name for compatibility.
Context manager for creating and properly closing ClickHouse connections
"""
# Initialize pool if empty
with pool_lock:
if connection_pool.empty():
initialize_pool()
client = None
try:
# Try to get a connection from the pool
client = connection_pool.get_nowait()
clickhouse_password = os.getenv("CLICKHOUSE_PASSWORD")
if not clickhouse_password:
raise ValueError("CLICKHOUSE_PASSWORD environment variable not set.")
# Test the connection
try:
client.query('SELECT 1')
return client
except Exception:
# Connection is dead, create a new one
logger.warning("Replacing dead connection in pool")
client = create_base_client()
return client
client = clickhouse_connect.get_client(
host="clickhouse.abellana.work",
port=443,
username="default",
password=clickhouse_password,
secure=True,
connect_timeout=10,
send_receive_timeout=300,
query_limit=0,
compress=True,
settings={
'enable_http_compression': 1,
'database': 'stock_db',
'max_execution_time': 300,
'mutations_sync': 0
}
)
except Exception:
# Pool is empty or other error, create new connection
logger.warning("Creating new connection (pool empty or error)")
return create_base_client()
yield client
def return_client(client):
"""
Return a client to the pool if possible
"""
try:
if not connection_pool.full():
connection_pool.put(client)
else:
client.close()
except Exception as e:
logger.warning(f"Error returning client to pool: {str(e)}")
# Export the create_client function
__all__ = ['create_client']
logger.error(f"Connection error: {str(e)}")
raise
finally:
if client:
try:
client.close()
except Exception as e:
logger.warning(f"Error closing connection: {str(e)}")

View File

@@ -73,13 +73,12 @@ def get_stock_data(ticker: str, start_date: datetime, end_date: datetime, interv
pd.DataFrame: Resampled DataFrame with OHLCV data
"""
try:
client = create_client()
with create_client() as client:
# Expand window to get enough data for calculations
start_date = start_date - timedelta(days=90)
# Expand window to get enough data for calculations
start_date = start_date - timedelta(days=90)
# Base query to get raw data at finest granularity
query = f"""
# Base query to get raw data at finest granularity
query = f"""
SELECT
toDateTime(window_start/1000000000) as date,
open,