refactor: Migrate watchlist functionality to ClickHouse database
This commit is contained in:
parent
c00893360e
commit
ca843a370e
@ -1,36 +1,41 @@
|
|||||||
|
import sys
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
# Add the src directory to the Python path
|
||||||
|
src_path = str(Path(__file__).parent.parent)
|
||||||
|
sys.path.append(src_path)
|
||||||
|
|
||||||
from db.db_connection import create_client
|
from db.db_connection import create_client
|
||||||
|
|
||||||
def create_watchlist_tables():
|
def create_watchlist_tables():
|
||||||
with create_client() as client:
|
with create_client() as client:
|
||||||
cursor = client.cursor()
|
|
||||||
|
|
||||||
# Create watchlists table
|
# Create watchlists table
|
||||||
cursor.execute("""
|
client.command("""
|
||||||
CREATE TABLE IF NOT EXISTS watchlists (
|
CREATE TABLE IF NOT EXISTS stock_db.watchlists (
|
||||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
id UInt32,
|
||||||
name TEXT NOT NULL UNIQUE,
|
name String,
|
||||||
strategy TEXT,
|
strategy String,
|
||||||
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
|
created_at DateTime DEFAULT now()
|
||||||
)
|
)
|
||||||
|
ENGINE = MergeTree()
|
||||||
|
ORDER BY (id)
|
||||||
""")
|
""")
|
||||||
|
|
||||||
# Create watchlist items table
|
# Create watchlist items table
|
||||||
cursor.execute("""
|
client.command("""
|
||||||
CREATE TABLE IF NOT EXISTS watchlist_items (
|
CREATE TABLE IF NOT EXISTS stock_db.watchlist_items (
|
||||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
id UInt32,
|
||||||
watchlist_id INTEGER,
|
watchlist_id UInt32,
|
||||||
ticker TEXT NOT NULL,
|
ticker String,
|
||||||
entry_price REAL,
|
entry_price Float64,
|
||||||
target_price REAL,
|
target_price Float64,
|
||||||
stop_loss REAL,
|
stop_loss Float64,
|
||||||
notes TEXT,
|
notes String,
|
||||||
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
created_at DateTime DEFAULT now()
|
||||||
FOREIGN KEY (watchlist_id) REFERENCES watchlists(id),
|
|
||||||
UNIQUE(watchlist_id, ticker)
|
|
||||||
)
|
)
|
||||||
|
ENGINE = MergeTree()
|
||||||
|
ORDER BY (id, watchlist_id)
|
||||||
""")
|
""")
|
||||||
|
|
||||||
client.commit()
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
create_watchlist_tables()
|
create_watchlist_tables()
|
||||||
|
|||||||
@ -16,56 +16,72 @@ class WatchlistItem:
|
|||||||
|
|
||||||
def create_watchlist(name: str, strategy: Optional[str] = None) -> int:
|
def create_watchlist(name: str, strategy: Optional[str] = None) -> int:
|
||||||
with create_client() as client:
|
with create_client() as client:
|
||||||
cursor = client.cursor()
|
# Get the next available ID
|
||||||
cursor.execute(
|
result = client.query("SELECT max(id) + 1 as next_id FROM stock_db.watchlists")
|
||||||
"INSERT INTO watchlists (name, strategy) VALUES (?, ?)",
|
next_id = result.first_row[0] if result.first_row[0] is not None else 1
|
||||||
(name, strategy)
|
|
||||||
)
|
client.insert('stock_db.watchlists',
|
||||||
return cursor.lastrowid
|
[(next_id, name, strategy or '', datetime.now())],
|
||||||
|
column_names=['id', 'name', 'strategy', 'created_at'])
|
||||||
|
return next_id
|
||||||
|
|
||||||
def get_watchlists() -> List[dict]:
|
def get_watchlists() -> List[dict]:
|
||||||
with create_client() as client:
|
with create_client() as client:
|
||||||
cursor = client.cursor()
|
result = client.query("SELECT * FROM stock_db.watchlists ORDER BY name")
|
||||||
cursor.execute("SELECT * FROM watchlists ORDER BY name")
|
return [dict(zip(result.column_names, row)) for row in result.result_rows]
|
||||||
return [dict(row) for row in cursor.fetchall()]
|
|
||||||
|
|
||||||
def add_to_watchlist(watchlist_id: int, item: WatchlistItem) -> bool:
|
def add_to_watchlist(watchlist_id: int, item: WatchlistItem) -> bool:
|
||||||
with create_client() as client:
|
with create_client() as client:
|
||||||
cursor = client.cursor()
|
|
||||||
try:
|
try:
|
||||||
cursor.execute("""
|
# Get the next available ID
|
||||||
INSERT INTO watchlist_items
|
result = client.query("SELECT max(id) + 1 as next_id FROM stock_db.watchlist_items")
|
||||||
(watchlist_id, ticker, entry_price, target_price, stop_loss, notes)
|
next_id = result.first_row[0] if result.first_row[0] is not None else 1
|
||||||
VALUES (?, ?, ?, ?, ?, ?)
|
|
||||||
""", (watchlist_id, item.ticker, item.entry_price, item.target_price,
|
client.insert('stock_db.watchlist_items',
|
||||||
item.stop_loss, item.notes))
|
[(next_id, watchlist_id, item.ticker, item.entry_price,
|
||||||
client.commit()
|
item.target_price, item.stop_loss, item.notes or '', datetime.now())],
|
||||||
|
column_names=['id', 'watchlist_id', 'ticker', 'entry_price',
|
||||||
|
'target_price', 'stop_loss', 'notes', 'created_at'])
|
||||||
return True
|
return True
|
||||||
except Exception:
|
except Exception as e:
|
||||||
|
print(f"Error adding to watchlist: {e}")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def remove_from_watchlist(item_id: int) -> bool:
|
def remove_from_watchlist(item_id: int) -> bool:
|
||||||
with create_client() as client:
|
with create_client() as client:
|
||||||
cursor = client.cursor()
|
try:
|
||||||
cursor.execute("DELETE FROM watchlist_items WHERE id = ?", (item_id,))
|
client.command(f"""
|
||||||
return cursor.rowcount > 0
|
ALTER TABLE stock_db.watchlist_items
|
||||||
|
DELETE WHERE id = {item_id}
|
||||||
|
""")
|
||||||
|
return True
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error removing from watchlist: {e}")
|
||||||
|
return False
|
||||||
|
|
||||||
def get_watchlist_items(watchlist_id: Optional[int] = None) -> List[WatchlistItem]:
|
def get_watchlist_items(watchlist_id: Optional[int] = None) -> List[WatchlistItem]:
|
||||||
with create_client() as client:
|
with create_client() as client:
|
||||||
cursor = client.cursor()
|
query = """
|
||||||
|
SELECT i.*, w.name as watchlist_name, w.strategy
|
||||||
|
FROM stock_db.watchlist_items i
|
||||||
|
JOIN stock_db.watchlists w ON w.id = i.watchlist_id
|
||||||
|
"""
|
||||||
if watchlist_id:
|
if watchlist_id:
|
||||||
cursor.execute("""
|
query += f" WHERE watchlist_id = {watchlist_id}"
|
||||||
SELECT i.*, w.name as watchlist_name, w.strategy
|
query += " ORDER BY i.created_at DESC"
|
||||||
FROM watchlist_items i
|
|
||||||
JOIN watchlists w ON w.id = i.watchlist_id
|
result = client.query(query)
|
||||||
WHERE watchlist_id = ?
|
items = []
|
||||||
ORDER BY i.created_at DESC
|
for row in result.result_rows:
|
||||||
""", (watchlist_id,))
|
data = dict(zip(result.column_names, row))
|
||||||
else:
|
items.append(WatchlistItem(
|
||||||
cursor.execute("""
|
id=data['id'],
|
||||||
SELECT i.*, w.name as watchlist_name, w.strategy
|
watchlist_id=data['watchlist_id'],
|
||||||
FROM watchlist_items i
|
ticker=data['ticker'],
|
||||||
JOIN watchlists w ON w.id = i.watchlist_id
|
entry_price=data['entry_price'],
|
||||||
ORDER BY i.created_at DESC
|
target_price=data['target_price'],
|
||||||
""")
|
stop_loss=data['stop_loss'],
|
||||||
return [WatchlistItem(**row) for row in cursor.fetchall()]
|
notes=data['notes'],
|
||||||
|
created_at=data['created_at']
|
||||||
|
))
|
||||||
|
return items
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user