140 lines
5.5 KiB
Python
140 lines
5.5 KiB
Python
from dataclasses import dataclass
|
|
from datetime import datetime
|
|
from typing import List, Optional
|
|
from db.db_connection import create_client
|
|
import logging
|
|
|
|
# Module-level logging setup: give the root logger a default INFO-level
# configuration, then log everything in this module through a logger named
# after the module.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(name=__name__)
|
|
|
|
@dataclass()
class WatchlistItem:
    """One ticker on a watchlist together with its trade plan.

    The four price/ticker fields are required. The remaining fields are
    optional and default to ``None``. ``get_watchlist_items`` populates
    ``notes``, ``watchlist_id``, ``id`` and ``created_at`` from the stored row.
    ``add_to_watchlist`` reads ``ticker``, the three prices and ``notes``, but
    takes the watchlist id from its own argument.
    """

    # Required trade-plan values (field order is the positional constructor order).
    ticker: str
    entry_price: float
    target_price: float
    stop_loss: float

    # Optional metadata; the storage-side values are filled in when read back.
    notes: Optional[str] = None
    watchlist_id: Optional[int] = None
    id: Optional[int] = None
    created_at: Optional[datetime] = None
|
|
|
|
def create_watchlist(name: str, strategy: Optional[str] = None) -> int:
    """Insert a new row into ``stock_db.watchlists`` and return its id.

    The id is computed client-side as ``max(id) + 1`` over the table, or 1 when
    that aggregate comes back as ``None``.
    NOTE(review): the id lookup and the insert are separate statements, so two
    concurrent callers can compute the same id.

    Args:
        name: Display name of the watchlist.
        strategy: Optional strategy label; a falsy value is stored as ``''``.

    Returns:
        The id written for the new row.
    """
    with create_client() as client:
        max_plus_one = client.query(
            "SELECT max(id) + 1 as next_id FROM stock_db.watchlists"
        ).first_row[0]
        new_id = 1 if max_plus_one is None else max_plus_one

        new_row = (new_id, name, strategy or '', datetime.now())
        client.insert(
            'stock_db.watchlists',
            [new_row],
            column_names=['id', 'name', 'strategy', 'created_at'],
        )
        return new_id
|
|
|
|
def get_watchlists() -> List[dict]:
    """Return every row of ``stock_db.watchlists``, ordered by ``name``.

    Each row is returned as a dict mapping column name to value.
    """
    with create_client() as client:
        result = client.query("SELECT * FROM stock_db.watchlists ORDER BY name")
        header = result.column_names
        return [dict(zip(header, values)) for values in result.result_rows]
|
|
|
|
def add_to_watchlist(watchlist_id: int, item: WatchlistItem) -> bool:
    """Insert ``item`` as a new row of ``stock_db.watchlist_items``.

    The row id is computed client-side as ``max(id) + 1`` over the table
    (1 when that aggregate comes back as ``None``).
    NOTE(review): the id lookup and the insert are separate statements, so two
    concurrent callers can compute the same id.

    The ``watchlist_id`` argument is what gets stored; ``item.watchlist_id``
    and ``item.id`` are ignored. A ``None`` price is stored as ``0.0`` and a
    missing note as ``''``. After the insert, the new row is queried back and
    logged for diagnostics.

    Args:
        watchlist_id: Id of the watchlist the item is added to.
        item: Ticker and trade-plan values to store.

    Returns:
        True if every step succeeded; False if any step raised (the error is
        logged together with its traceback).
    """
    with create_client() as client:
        try:
            max_plus_one = client.query(
                "SELECT max(id) + 1 as next_id FROM stock_db.watchlist_items"
            ).first_row[0]
            next_id = int(max_plus_one) if max_plus_one is not None else 1

            logger.info("Adding item to watchlist: ID=%s, WatchlistID=%s, Ticker=%s",
                        next_id, watchlist_id, item.ticker)
            logger.info("Entry Price: %s, Target: %s, Stop: %s",
                        item.entry_price, item.target_price, item.stop_loss)

            # Coerce to the Python types matching the declared column types
            # below; a missing price is stored as 0.0.
            entry_price = float(item.entry_price) if item.entry_price is not None else 0.0
            target_price = float(item.target_price) if item.target_price is not None else 0.0
            stop_loss = float(item.stop_loss) if item.stop_loss is not None else 0.0

            data = [(
                next_id,
                int(watchlist_id),
                str(item.ticker),
                entry_price,
                target_price,
                stop_loss,
                str(item.notes or ''),
                datetime.now(),
            )]
            logger.info("Attempting to insert data: %s", data)

            client.insert(
                'stock_db.watchlist_items',
                data,
                column_names=['id', 'watchlist_id', 'ticker', 'entry_price',
                              'target_price', 'stop_loss', 'notes', 'created_at'],
                # NOTE(review): assumes create_client() returns a
                # clickhouse-connect Client. Its insert() takes type-name
                # strings via column_type_names; column_types expects
                # ClickHouseType objects, so passing strings there fails.
                column_type_names=['UInt32', 'UInt32', 'String', 'Float64',
                                   'Float64', 'Float64', 'String', 'DateTime'],
            )

            # Diagnostic read-back. next_id is a Python int, so interpolating
            # it into the SQL text cannot inject anything.
            verify_result = client.query(
                f"SELECT * FROM stock_db.watchlist_items WHERE id = {next_id}"
            )
            logger.info("Verification query result: %s", verify_result.result_rows)

            return True
        except Exception as e:
            logger.exception("Error adding to watchlist: %s", e)
            return False
|
|
|
|
def remove_from_watchlist(item_id: int) -> bool:
    """Delete the row of ``stock_db.watchlist_items`` whose id is ``item_id``.

    The delete is issued as an ``ALTER TABLE ... DELETE`` statement.
    NOTE(review): in ClickHouse this is a mutation, which by default is
    applied asynchronously. ``True`` therefore means the statement was
    accepted, not necessarily that the row is already gone.

    Args:
        item_id: Id of the item to delete. It must be an integer, or a string
            holding an integer.

    Returns:
        True if the statement was sent without error. False otherwise,
        including a non-integral ``item_id``; the error is logged with its
        traceback.
    """
    with create_client() as client:
        try:
            # Round-trip through str() so only integral ids reach the SQL text.
            # int("12") works, while int("12.5") or int("1 OR 1=1") raises
            # ValueError. This also avoids int(12.5) silently truncating to 12
            # and deleting the wrong row.
            safe_id = int(str(item_id))
            client.command(
                f"ALTER TABLE stock_db.watchlist_items DELETE WHERE id = {safe_id}"
            )
            return True
        except Exception as e:
            # Log with the module logger like the other functions here,
            # instead of printing to stdout.
            logger.exception("Error removing from watchlist: %s", e)
            return False
|
|
|
|
def get_watchlist_items(watchlist_id: Optional[int] = None) -> List[WatchlistItem]:
    """Fetch watchlist items, newest ``created_at`` first.

    Items are read from ``stock_db.watchlist_items`` joined (``JOIN``) with
    ``stock_db.watchlists`` on the watchlist id.

    Args:
        watchlist_id: When not ``None``, only items of this watchlist are
            returned. ``None`` (the default) returns items of all watchlists.

    Returns:
        One ``WatchlistItem`` per result row. An empty list is returned if the
        query fails or ``watchlist_id`` is not integral; the error is logged
        with its traceback.
    """
    with create_client() as client:
        try:
            query = """
            SELECT i.*, w.name as watchlist_name, w.strategy
            FROM stock_db.watchlist_items i
            JOIN stock_db.watchlists w ON w.id = i.watchlist_id
            """
            # Compare against None rather than truthiness, so an explicit id of
            # 0 filters instead of silently returning every watchlist's items.
            if watchlist_id is not None:
                # Round-trip through str() so only integral ids reach the SQL
                # text; anything else raises ValueError, handled below.
                query += f" WHERE i.watchlist_id = {int(str(watchlist_id))}"
            query += " ORDER BY i.created_at DESC"

            logger.info("Executing query: %s", query)
            result = client.query(query)

            columns = result.column_names
            items = []
            for row in result.result_rows:
                data = dict(zip(columns, row))
                logger.info("Processing row: %s", data)
                items.append(WatchlistItem(
                    id=data['id'],
                    watchlist_id=data['watchlist_id'],
                    ticker=data['ticker'],
                    entry_price=data['entry_price'],
                    target_price=data['target_price'],
                    stop_loss=data['stop_loss'],
                    notes=data['notes'],
                    created_at=data['created_at'],
                ))
            return items
        except Exception as e:
            logger.exception("Error getting watchlist items: %s", e)
            return []
|