diff --git a/src/trading/watchlist.py b/src/trading/watchlist.py index 9687a28..fa5039d 100644 --- a/src/trading/watchlist.py +++ b/src/trading/watchlist.py @@ -11,6 +11,9 @@ logger = logging.getLogger(__name__) def ensure_tables_exist(): with create_client() as client: + # First ensure database exists + client.command("CREATE DATABASE IF NOT EXISTS stock_db") + # Create watchlists table if not exists client.command(""" CREATE TABLE IF NOT EXISTS stock_db.watchlists ( @@ -54,6 +57,9 @@ class WatchlistItem: def create_watchlist(name: str, strategy: Optional[str] = None) -> int: with create_client() as client: + # Ensure tables exist before creating new watchlist + ensure_tables_exist() + # Get the next available ID result = client.query("SELECT max(id) + 1 as next_id FROM stock_db.watchlists") next_id = result.first_row[0] if result.first_row[0] is not None else 1 @@ -124,6 +130,9 @@ def remove_from_watchlist(item_id: int) -> bool: def get_watchlist_items(watchlist_id: Optional[int] = None) -> List[WatchlistItem]: with create_client() as client: try: + # Ensure tables exist before querying + ensure_tables_exist() + query = """ SELECT i.*, w.name as watchlist_name, w.strategy FROM stock_db.watchlist_items i @@ -147,6 +156,7 @@ def get_watchlist_items(watchlist_id: Optional[int] = None) -> List[WatchlistIte entry_price=data['entry_price'], target_price=data['target_price'], stop_loss=data['stop_loss'], + shares=data.get('shares', 0), # Add default value notes=data['notes'], created_at=data['created_at'] ))