From e5758426640b4a54061dd46698deb48cb085bcd7 Mon Sep 17 00:00:00 2001 From: "Bobby (aider)" Date: Tue, 11 Feb 2025 19:00:24 -0800 Subject: [PATCH] feat: Add trade linking and performance metrics to trading plans --- src/streamlit_app.py | 61 +++++++++++++++++++++++++++++++++++ src/trading/journal.py | 3 +- src/trading/trading_plan.py | 64 +++++++++++++++++++++++++++++++++++++ 3 files changed, 127 insertions(+), 1 deletion(-) diff --git a/src/streamlit_app.py b/src/streamlit_app.py index 202cb72..209a688 100644 --- a/src/streamlit_app.py +++ b/src/streamlit_app.py @@ -1143,6 +1143,67 @@ def trading_plan_page(): st.query_params.update(rerun=True) except Exception as e: st.error(f"Error deleting plan: {str(e)}") + + # Add Trade Management section + st.subheader("Trade Management") + + # Get current trades for this plan + plan_trades = get_plan_trades(plan.id) + + # Display current trades + if plan_trades: + st.write("Current Trades:") + for trade in plan_trades: + with st.expander(f"{trade['ticker']} - {trade['entry_date']}"): + col1, col2 = st.columns(2) + with col1: + st.write(f"Entry: ${trade['entry_price']:.2f}") + st.write(f"Shares: {trade['shares']}") + with col2: + if trade['exit_price']: + pl = (trade['exit_price'] - trade['entry_price']) * trade['shares'] + st.write(f"Exit: ${trade['exit_price']:.2f}") + st.write(f"P/L: ${pl:.2f}") + + # Get available trades + with create_client() as client: + query = """ + SELECT id, ticker, entry_date, entry_price, shares, exit_price, exit_date + FROM stock_db.trades + WHERE plan_id IS NULL + ORDER BY entry_date DESC + """ + result = client.query(query) + available_trades = [dict(zip( + ['id', 'ticker', 'entry_date', 'entry_price', 'shares', 'exit_price', 'exit_date'], + row + )) for row in result.result_rows] + + if available_trades: + st.write("Link Existing Trades:") + selected_trades = st.multiselect( + "Select trades to link to this plan", + options=[t['id'] for t in available_trades], + format_func=lambda x: next( + f"{t['ticker']} - {t['entry_date']} - ${t['entry_price']:.2f}" + for t in available_trades if t['id'] == x + ) + ) + + if selected_trades and st.button("Link Selected Trades"): + if link_trades_to_plan(plan.id, selected_trades): + st.success("Trades linked successfully!") + + # Calculate and update metrics + metrics = calculate_plan_metrics(plan.id) + plan.win_rate = metrics['win_rate'] + plan.average_return_per_trade = metrics['average_return'] + plan.profit_factor = metrics['profit_factor'] + update_trading_plan(plan) + + st.query_params.update(rerun=True) + else: + st.error("Error linking trades") else: st.info("No plans available to edit") diff --git a/src/trading/journal.py b/src/trading/journal.py index 5791aa1..04a52aa 100644 --- a/src/trading/journal.py +++ b/src/trading/journal.py @@ -278,7 +278,8 @@ def create_trades_table(): exit_date Nullable(DateTime), exit_reason Nullable(String), notes Nullable(String), - created_at DateTime DEFAULT now() + created_at DateTime DEFAULT now(), + plan_id Nullable(UInt32) ) ENGINE = MergeTree() ORDER BY (position_id, id, entry_date) """ diff --git a/src/trading/trading_plan.py b/src/trading/trading_plan.py index 98fa388..dbff08b 100644 --- a/src/trading/trading_plan.py +++ b/src/trading/trading_plan.py @@ -323,6 +323,70 @@ def delete_trading_plan(plan_id: int) -> bool: print(f"Error deleting trading plan: {e}") return False +def link_trades_to_plan(plan_id: int, trade_ids: List[int]) -> bool: + """Link existing trades to a trading plan""" + with create_client() as client: + try: + # Update trades to link them to the plan + trade_ids_str = ", ".join(map(str, trade_ids)) + query = f""" + ALTER TABLE stock_db.trades + UPDATE plan_id = {plan_id} + WHERE id IN ({trade_ids_str}) + """ + client.command(query) + return True + except Exception as e: + print(f"Error linking trades to plan: {e}") + return False + +def get_plan_trades(plan_id: int) -> List[dict]: + """Get all trades associated with a trading plan""" + with create_client() as client: + query = f""" + SELECT * + FROM stock_db.trades + WHERE plan_id = {plan_id} + ORDER BY entry_date DESC + """ + result = client.query(query) + return [dict(zip( + ['id', 'position_id', 'ticker', 'entry_date', 'shares', 'entry_price', + 'target_price', 'stop_loss', 'strategy', 'order_type', 'direction', + 'followed_rules', 'entry_reason', 'exit_price', 'exit_date', + 'exit_reason', 'notes', 'created_at', 'plan_id'], + row + )) for row in result.result_rows] + +def calculate_plan_metrics(plan_id: int) -> dict: + """Calculate performance metrics for a trading plan""" + trades = get_plan_trades(plan_id) + if not trades: + return {} + + total_trades = len(trades) + winning_trades = sum(1 for t in trades if t['exit_price'] and + (t['exit_price'] - t['entry_price']) * t['shares'] > 0) + total_profit = sum((t['exit_price'] - t['entry_price']) * t['shares'] + for t in trades if t['exit_price']) + + gross_profits = sum((t['exit_price'] - t['entry_price']) * t['shares'] + for t in trades if t['exit_price'] and + (t['exit_price'] - t['entry_price']) * t['shares'] > 0) + + gross_losses = abs(sum((t['exit_price'] - t['entry_price']) * t['shares'] + for t in trades if t['exit_price'] and + (t['exit_price'] - t['entry_price']) * t['shares'] < 0)) + + return { + 'total_trades': total_trades, + 'winning_trades': winning_trades, + 'win_rate': (winning_trades / total_trades * 100) if total_trades > 0 else 0, + 'total_profit': total_profit, + 'average_return': total_profit / total_trades if total_trades > 0 else 0, + 'profit_factor': gross_profits / gross_losses if gross_losses > 0 else float('inf') + } + def update_trading_plan(plan: TradingPlan) -> bool: """Update an existing trading plan""" if not plan.id: