diff --git a/src/migrations/add_watchlist_tables.py b/src/migrations/add_watchlist_tables.py index 2d48902..c55d629 100644 --- a/src/migrations/add_watchlist_tables.py +++ b/src/migrations/add_watchlist_tables.py @@ -1,41 +1,53 @@ import sys from pathlib import Path +import logging # Add the src directory to the Python path src_path = str(Path(__file__).parent.parent) sys.path.append(src_path) +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + from db.db_connection import create_client def create_watchlist_tables(): with create_client() as client: - # Create watchlists table - client.command(""" - CREATE TABLE IF NOT EXISTS stock_db.watchlists ( - id UInt32, - name String, - strategy String, - created_at DateTime DEFAULT now() - ) - ENGINE = MergeTree() - ORDER BY (id) - """) - - # Create watchlist items table - client.command(""" - CREATE TABLE IF NOT EXISTS stock_db.watchlist_items ( - id UInt32, - watchlist_id UInt32, - ticker String, - entry_price Float64, - target_price Float64, - stop_loss Float64, - notes String, - created_at DateTime DEFAULT now() - ) - ENGINE = MergeTree() - ORDER BY (id, watchlist_id) - """) + try: + # Create watchlists table + logger.info("Creating watchlists table...") + client.command(""" + CREATE TABLE IF NOT EXISTS stock_db.watchlists ( + id UInt32, + name String, + strategy String, + created_at DateTime DEFAULT now() + ) + ENGINE = MergeTree() + ORDER BY (id) + """) + + # Create watchlist items table + logger.info("Creating watchlist_items table...") + client.command(""" + CREATE TABLE IF NOT EXISTS stock_db.watchlist_items ( + id UInt32, + watchlist_id UInt32, + ticker String, + entry_price Float64, + target_price Float64, + stop_loss Float64, + notes String, + created_at DateTime DEFAULT now() + ) + ENGINE = MergeTree() + ORDER BY (id, watchlist_id) + """) + + logger.info("Tables created successfully") + except Exception as e: + logger.error(f"Error creating tables: {e}", exc_info=True) + raise if __name__ == "__main__": create_watchlist_tables() diff --git a/src/pages/trading/trading_system_page.py b/src/pages/trading/trading_system_page.py index 5afeece..7887fe9 100644 --- a/src/pages/trading/trading_system_page.py +++ b/src/pages/trading/trading_system_page.py @@ -184,20 +184,23 @@ def trading_system_page(): st.metric("Confidence Level", f"{confidence_level}%") # Add to watchlist option - if st.button("Add to Watch List"): - watchlists = get_watchlists() - if not watchlists: - st.warning("No watch lists available. Create one in the Watch Lists tab.") - else: + st.divider() + st.subheader("Add to Watch List") + watchlists = get_watchlists() + if not watchlists: + st.warning("No watch lists available. Create one in the Watch Lists tab.") + else: + col1, col2 = st.columns([3, 1]) + with col1: selected_list = st.selectbox( "Select Watch List", options=[(w['id'], w['name']) for w in watchlists], format_func=lambda x: x[1] ) - notes = st.text_area("Notes") - - if st.button("Confirm Add to Watch List"): + + with col2: + if st.button("Add to Watch List", key="add_to_watchlist"): item = WatchlistItem( ticker=ticker, entry_price=entry_price, @@ -207,6 +210,7 @@ def trading_system_page(): ) if add_to_watchlist(selected_list[0], item): st.success(f"Added {ticker} to watch list!") + st.rerun() else: st.error("Failed to add to watch list") diff --git a/src/trading/watchlist.py b/src/trading/watchlist.py index ec23fb4..4732722 100644 --- a/src/trading/watchlist.py +++ b/src/trading/watchlist.py @@ -2,6 +2,10 @@ from dataclasses import dataclass from datetime import datetime from typing import List, Optional from db.db_connection import create_client +import logging + +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) @dataclass class WatchlistItem: @@ -37,14 +41,30 @@ def add_to_watchlist(watchlist_id: int, item: WatchlistItem) -> bool: result = client.query("SELECT max(id) + 1 as next_id FROM stock_db.watchlist_items") next_id = result.first_row[0] if result.first_row[0] is not None else 1 - client.insert('stock_db.watchlist_items', - [(next_id, watchlist_id, item.ticker, item.entry_price, - item.target_price, item.stop_loss, item.notes or '', datetime.now())], - column_names=['id', 'watchlist_id', 'ticker', 'entry_price', - 'target_price', 'stop_loss', 'notes', 'created_at']) + logger.info(f"Adding item to watchlist: ID={next_id}, WatchlistID={watchlist_id}, Ticker={item.ticker}") + + data = [( + next_id, + watchlist_id, + item.ticker, + float(item.entry_price), + float(item.target_price), + float(item.stop_loss), + item.notes or '', + datetime.now() + )] + + logger.info(f"Insert {data}") + + client.insert( + 'stock_db.watchlist_items', + data, + column_names=['id', 'watchlist_id', 'ticker', 'entry_price', + 'target_price', 'stop_loss', 'notes', 'created_at'] + ) return True except Exception as e: - print(f"Error adding to watchlist: {e}") + logger.error(f"Error adding to watchlist: {e}", exc_info=True) return False def remove_from_watchlist(item_id: int) -> bool: @@ -61,27 +81,34 @@ def remove_from_watchlist(item_id: int) -> bool: def get_watchlist_items(watchlist_id: Optional[int] = None) -> List[WatchlistItem]: with create_client() as client: - query = """ - SELECT i.*, w.name as watchlist_name, w.strategy - FROM stock_db.watchlist_items i - JOIN stock_db.watchlists w ON w.id = i.watchlist_id - """ - if watchlist_id: - query += f" WHERE watchlist_id = {watchlist_id}" - query += " ORDER BY i.created_at DESC" - - result = client.query(query) - items = [] - for row in result.result_rows: - data = dict(zip(result.column_names, row)) - items.append(WatchlistItem( - id=data['id'], - watchlist_id=data['watchlist_id'], - ticker=data['ticker'], - entry_price=data['entry_price'], - target_price=data['target_price'], - stop_loss=data['stop_loss'], - notes=data['notes'], - created_at=data['created_at'] - )) - return items + try: + query = """ + SELECT i.*, w.name as watchlist_name, w.strategy + FROM stock_db.watchlist_items i + JOIN stock_db.watchlists w ON w.id = i.watchlist_id + """ + if watchlist_id: + query += f" WHERE i.watchlist_id = {watchlist_id}" + query += " ORDER BY i.created_at DESC" + + logger.info(f"Executing query: {query}") + + result = client.query(query) + items = [] + for row in result.result_rows: + data = dict(zip(result.column_names, row)) + logger.info(f"Processing row: {data}") + items.append(WatchlistItem( + id=data['id'], + watchlist_id=data['watchlist_id'], + ticker=data['ticker'], + entry_price=data['entry_price'], + target_price=data['target_price'], + stop_loss=data['stop_loss'], + notes=data['notes'], + created_at=data['created_at'] + )) + return items + except Exception as e: + logger.error(f"Error getting watchlist items: {e}", exc_info=True) + return []