From ca843a370ed3f197d8fe5abcaafb7bbf464fbf00 Mon Sep 17 00:00:00 2001 From: "Bobby (aider)" Date: Mon, 17 Feb 2025 16:12:34 -0800 Subject: [PATCH] refactor: Migrate watchlist functionality to ClickHouse database --- src/migrations/add_watchlist_tables.py | 49 +++++++------- src/trading/watchlist.py | 90 +++++++++++++++----------- 2 files changed, 80 insertions(+), 59 deletions(-) diff --git a/src/migrations/add_watchlist_tables.py b/src/migrations/add_watchlist_tables.py index 60093a8..2d48902 100644 --- a/src/migrations/add_watchlist_tables.py +++ b/src/migrations/add_watchlist_tables.py @@ -1,36 +1,41 @@ +import sys +from pathlib import Path + +# Add the src directory to the Python path +src_path = str(Path(__file__).parent.parent) +sys.path.append(src_path) + from db.db_connection import create_client def create_watchlist_tables(): with create_client() as client: - cursor = client.cursor() - # Create watchlists table - cursor.execute(""" - CREATE TABLE IF NOT EXISTS watchlists ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - name TEXT NOT NULL UNIQUE, - strategy TEXT, - created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP + client.command(""" + CREATE TABLE IF NOT EXISTS stock_db.watchlists ( + id UInt32, + name String, + strategy String, + created_at DateTime DEFAULT now() ) + ENGINE = MergeTree() + ORDER BY (id) """) # Create watchlist items table - cursor.execute(""" - CREATE TABLE IF NOT EXISTS watchlist_items ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - watchlist_id INTEGER, - ticker TEXT NOT NULL, - entry_price REAL, - target_price REAL, - stop_loss REAL, - notes TEXT, - created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, - FOREIGN KEY (watchlist_id) REFERENCES watchlists(id), - UNIQUE(watchlist_id, ticker) + client.command(""" + CREATE TABLE IF NOT EXISTS stock_db.watchlist_items ( + id UInt32, + watchlist_id UInt32, + ticker String, + entry_price Float64, + target_price Float64, + stop_loss Float64, + notes String, + created_at DateTime DEFAULT now() ) + ENGINE = MergeTree() + ORDER BY (id, watchlist_id) """) - - client.commit() if __name__ == "__main__": create_watchlist_tables() diff --git a/src/trading/watchlist.py b/src/trading/watchlist.py index 1e0c7ed..ec23fb4 100644 --- a/src/trading/watchlist.py +++ b/src/trading/watchlist.py @@ -16,56 +16,72 @@ class WatchlistItem: def create_watchlist(name: str, strategy: Optional[str] = None) -> int: with create_client() as client: - cursor = client.cursor() - cursor.execute( - "INSERT INTO watchlists (name, strategy) VALUES (?, ?)", - (name, strategy) - ) - return cursor.lastrowid + # Get the next available ID + result = client.query("SELECT max(id) + 1 as next_id FROM stock_db.watchlists") + next_id = result.first_row[0] if result.first_row[0] is not None else 1 + + client.insert('stock_db.watchlists', + [(next_id, name, strategy or '', datetime.now())], + column_names=['id', 'name', 'strategy', 'created_at']) + return next_id def get_watchlists() -> List[dict]: with create_client() as client: - cursor = client.cursor() - cursor.execute("SELECT * FROM watchlists ORDER BY name") - return [dict(row) for row in cursor.fetchall()] + result = client.query("SELECT * FROM stock_db.watchlists ORDER BY name") + return [dict(zip(result.column_names, row)) for row in result.result_rows] def add_to_watchlist(watchlist_id: int, item: WatchlistItem) -> bool: with create_client() as client: - cursor = client.cursor() try: - cursor.execute(""" - INSERT INTO watchlist_items - (watchlist_id, ticker, entry_price, target_price, stop_loss, notes) - VALUES (?, ?, ?, ?, ?, ?) - """, (watchlist_id, item.ticker, item.entry_price, item.target_price, - item.stop_loss, item.notes)) - client.commit() + # Get the next available ID + result = client.query("SELECT max(id) + 1 as next_id FROM stock_db.watchlist_items") + next_id = result.first_row[0] if result.first_row[0] is not None else 1 + + client.insert('stock_db.watchlist_items', + [(next_id, watchlist_id, item.ticker, item.entry_price, + item.target_price, item.stop_loss, item.notes or '', datetime.now())], + column_names=['id', 'watchlist_id', 'ticker', 'entry_price', + 'target_price', 'stop_loss', 'notes', 'created_at']) return True - except Exception: + except Exception as e: + print(f"Error adding to watchlist: {e}") return False def remove_from_watchlist(item_id: int) -> bool: with create_client() as client: - cursor = client.cursor() - cursor.execute("DELETE FROM watchlist_items WHERE id = ?", (item_id,)) - return cursor.rowcount > 0 + try: + client.command(f""" + ALTER TABLE stock_db.watchlist_items + DELETE WHERE id = {item_id} + """) + return True + except Exception as e: + print(f"Error removing from watchlist: {e}") + return False def get_watchlist_items(watchlist_id: Optional[int] = None) -> List[WatchlistItem]: with create_client() as client: - cursor = client.cursor() + query = """ + SELECT i.*, w.name as watchlist_name, w.strategy + FROM stock_db.watchlist_items i + JOIN stock_db.watchlists w ON w.id = i.watchlist_id + """ if watchlist_id: - cursor.execute(""" - SELECT i.*, w.name as watchlist_name, w.strategy - FROM watchlist_items i - JOIN watchlists w ON w.id = i.watchlist_id - WHERE watchlist_id = ? - ORDER BY i.created_at DESC - """, (watchlist_id,)) - else: - cursor.execute(""" - SELECT i.*, w.name as watchlist_name, w.strategy - FROM watchlist_items i - JOIN watchlists w ON w.id = i.watchlist_id - ORDER BY i.created_at DESC - """) - return [WatchlistItem(**row) for row in cursor.fetchall()] + query += f" WHERE watchlist_id = {watchlist_id}" + query += " ORDER BY i.created_at DESC" + + result = client.query(query) + items = [] + for row in result.result_rows: + data = dict(zip(result.column_names, row)) + items.append(WatchlistItem( + id=data['id'], + watchlist_id=data['watchlist_id'], + ticker=data['ticker'], + entry_price=data['entry_price'], + target_price=data['target_price'], + stop_loss=data['stop_loss'], + notes=data['notes'], + created_at=data['created_at'] + )) + return items