diff --git a/src/trading/journal.py b/src/trading/journal.py index f0df50c..8110442 100644 --- a/src/trading/journal.py +++ b/src/trading/journal.py @@ -367,26 +367,71 @@ def update_trade(trade_id: int, updates: dict): updates (dict): Dictionary of fields and values to update """ with create_client() as client: - # Build update query dynamically from provided updates - update_statements = [] - for field, value in updates.items(): - if isinstance(value, str): - update_statements.append(f"{field} = '{value}'") - elif isinstance(value, datetime): - update_statements.append(f"{field} = '{value.strftime('%Y-%m-%d %H:%M:%S')}'") - elif value is None: - update_statements.append(f"{field} = NULL") - else: - update_statements.append(f"{field} = {value}") + # If trying to update entry_date, we need to delete and reinsert + if 'entry_date' in updates: + # First get the full trade data + query = f"SELECT * FROM stock_db.trades WHERE id = {trade_id}" + result = client.query(query).result_rows + if not result: + raise Exception("Trade not found") - update_clause = ", ".join(update_statements) - - query = f""" - ALTER TABLE stock_db.trades - UPDATE {update_clause} - WHERE id = {trade_id} - """ - client.command(query) + # Delete the existing trade + client.command(f"ALTER TABLE stock_db.trades DELETE WHERE id = {trade_id}") + + # Prepare the new trade data + columns = ['id', 'position_id', 'ticker', 'entry_date', 'shares', 'entry_price', + 'target_price', 'stop_loss', 'strategy', 'order_type', 'followed_rules', + 'entry_reason', 'exit_price', 'exit_date', 'exit_reason', 'notes', 'created_at'] + trade_data = dict(zip(columns, result[0])) + trade_data.update(updates) + + # Insert the updated trade + query = f""" + INSERT INTO stock_db.trades ( + id, position_id, ticker, entry_date, shares, entry_price, target_price, + stop_loss, strategy, order_type, followed_rules, entry_reason, exit_price, + exit_date, exit_reason, notes + ) VALUES ( + {trade_id}, + '{trade_data['position_id']}', + '{trade_data['ticker']}', + '{trade_data['entry_date'].strftime('%Y-%m-%d %H:%M:%S')}', + {trade_data['shares']}, + {trade_data['entry_price']}, + {trade_data['target_price']}, + {trade_data['stop_loss']}, + '{trade_data['strategy']}', + '{trade_data['order_type']}', + {1 if trade_data['followed_rules'] else 0}, + {f"'{trade_data['entry_reason']}'" if trade_data['entry_reason'] else 'NULL'}, + {trade_data['exit_price'] if trade_data['exit_price'] else 'NULL'}, + {f"'{trade_data['exit_date'].strftime('%Y-%m-%d %H:%M:%S')}'" if trade_data['exit_date'] else 'NULL'}, + {f"'{trade_data['exit_reason']}'" if trade_data['exit_reason'] else 'NULL'}, + {f"'{trade_data['notes']}'" if trade_data['notes'] else 'NULL'} + ) + """ + client.command(query) + else: + # For non-key columns, we can use regular UPDATE + update_statements = [] + for field, value in updates.items(): + if isinstance(value, str): + update_statements.append(f"{field} = '{value}'") + elif isinstance(value, datetime): + update_statements.append(f"{field} = '{value.strftime('%Y-%m-%d %H:%M:%S')}'") + elif value is None: + update_statements.append(f"{field} = NULL") + else: + update_statements.append(f"{field} = {value}") + + update_clause = ", ".join(update_statements) + + query = f""" + ALTER TABLE stock_db.trades + UPDATE {update_clause} + WHERE id = {trade_id} + """ + client.command(query) def get_open_trades_summary() -> dict: """Get summary of all open trades grouped by ticker"""