Compare commits
No commits in common. "main" and "all_tradeable" have entirely different histories.
main
...
all_tradea
@ -1,15 +0,0 @@
|
||||
__pycache__
|
||||
*.pyc
|
||||
*.pyo
|
||||
*.pyd
|
||||
.Python
|
||||
env/
|
||||
venv/
|
||||
.env
|
||||
*.log
|
||||
.git
|
||||
.gitignore
|
||||
data/
|
||||
reports/
|
||||
.aider.model.settings.yml
|
||||
.aider.chat.history.md
|
||||
4
.gitignore
vendored
4
.gitignore
vendored
@ -18,7 +18,3 @@ reports/
|
||||
.DS_Store
|
||||
Thumbs.db
|
||||
.aider*
|
||||
|
||||
# Docker
|
||||
.docker/
|
||||
docker-compose.override.yml
|
||||
|
||||
15
Dockerfile
15
Dockerfile
@ -1,15 +0,0 @@
|
||||
FROM python:3.11-slim
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
RUN apt-get update && apt-get install -y \
|
||||
build-essential \
|
||||
libta-lib-dev \
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
|
||||
COPY requirements.txt .
|
||||
RUN pip install --no-cache-dir -r requirements.txt
|
||||
|
||||
COPY . .
|
||||
EXPOSE 8501
|
||||
CMD ["streamlit", "run", "src/streamlit_app.py", "--server.headless=true", "--server.port=8501"]
|
||||
@ -1,12 +0,0 @@
|
||||
version: '3.8'
|
||||
|
||||
services:
|
||||
web:
|
||||
build: .
|
||||
ports:
|
||||
- "8501:8501"
|
||||
volumes:
|
||||
- .:/app
|
||||
environment:
|
||||
- STREAMLIT_SERVER_PORT=8501
|
||||
- STREAMLIT_SERVER_ADDRESS=0.0.0.0
|
||||
@ -13,28 +13,3 @@ pytest
|
||||
|
||||
# Time zone handling
|
||||
pytz
|
||||
tzdata
|
||||
|
||||
# Real-time data
|
||||
yfinance
|
||||
|
||||
# Web interface
|
||||
streamlit>=1.24.0
|
||||
plotly>=5.13.0
|
||||
streamlit-option-menu>=0.3.2
|
||||
|
||||
# Backtesting
|
||||
backtesting
|
||||
pandas-ta
|
||||
|
||||
# Technical Analysis
|
||||
ta
|
||||
ta-lib
|
||||
|
||||
# AI and Machine Learning
|
||||
tensorflow>=2.10.0
|
||||
scikit-learn>=1.0.0
|
||||
matplotlib>=3.5.0
|
||||
|
||||
# Add this to your requirements.txt
|
||||
pandas_market_calendars>=3.0.0
|
||||
|
||||
94
src/app.py
94
src/app.py
@ -1,94 +0,0 @@
|
||||
import streamlit as st
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
from datetime import datetime, timedelta
|
||||
import pytz
|
||||
import time
|
||||
import os
|
||||
import logging
|
||||
|
||||
# Import page modules
|
||||
from pages.dashboard.dashboard_page import dashboard_page
|
||||
from pages.screener.screener_page import screener_page
|
||||
from pages.screener.technical_scanner_page import technical_scanner_page
|
||||
from pages.journal.trading_journal_page import trading_journal_page
|
||||
from pages.trading.trading_system_page import trading_system_page
|
||||
from pages.analysis.monte_carlo_page import monte_carlo_page
|
||||
from pages.analysis.ai_forecast_page import ai_forecast_page
|
||||
from pages.backtesting.backtesting_page import backtesting_page
|
||||
from pages.settings.settings_page import settings_page
|
||||
|
||||
# Configure logging
|
||||
logging.basicConfig(level=logging.INFO,
|
||||
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Set page config
|
||||
st.set_page_config(
|
||||
page_title="Trading Dashboard",
|
||||
page_icon="📈",
|
||||
layout="wide",
|
||||
initial_sidebar_state="expanded"
|
||||
)
|
||||
|
||||
# Define page mapping
|
||||
pages = {
|
||||
"Dashboard": dashboard_page,
|
||||
"Stock Screener": {
|
||||
"Basic Screener": screener_page,
|
||||
"Technical Scanner": technical_scanner_page
|
||||
},
|
||||
"Trading Journal": trading_journal_page,
|
||||
"Trading System": trading_system_page,
|
||||
"Analysis": {
|
||||
"Monte Carlo Simulation": monte_carlo_page,
|
||||
"AI Stock Forecast": ai_forecast_page
|
||||
},
|
||||
"Backtesting": backtesting_page,
|
||||
"Settings": settings_page
|
||||
}
|
||||
|
||||
def main():
|
||||
# Sidebar navigation
|
||||
st.sidebar.title("Navigation")
|
||||
|
||||
# Handle nested pages
|
||||
selected_page = None
|
||||
selected_subpage = None
|
||||
|
||||
# First level navigation
|
||||
main_page = st.sidebar.radio(
|
||||
"Select Page",
|
||||
options=list(pages.keys())
|
||||
)
|
||||
|
||||
# Check if the selected page has subpages
|
||||
if isinstance(pages[main_page], dict):
|
||||
# Second level navigation
|
||||
selected_subpage = st.sidebar.radio(
|
||||
f"Select {main_page} Option",
|
||||
options=list(pages[main_page].keys())
|
||||
)
|
||||
selected_page = pages[main_page][selected_subpage]
|
||||
else:
|
||||
selected_page = pages[main_page]
|
||||
|
||||
# Add a separator
|
||||
st.sidebar.markdown("---")
|
||||
|
||||
# Add app info in sidebar
|
||||
st.sidebar.info(
|
||||
"""
|
||||
**Trading Dashboard App**
|
||||
Version 1.0
|
||||
|
||||
A comprehensive trading toolkit for analysis,
|
||||
journaling, and decision making.
|
||||
"""
|
||||
)
|
||||
|
||||
# Display the selected page
|
||||
selected_page()
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@ -4,7 +4,6 @@ import time
|
||||
import clickhouse_connect
|
||||
from dotenv import load_dotenv
|
||||
from contextlib import contextmanager
|
||||
from datetime import datetime
|
||||
|
||||
# Load environment variables from .env file
|
||||
load_dotenv()
|
||||
@ -53,75 +52,3 @@ def create_client():
|
||||
client.close()
|
||||
except Exception as e:
|
||||
logger.warning(f"Error closing connection: {str(e)}")
|
||||
|
||||
def get_current_prices(tickers):
|
||||
"""Fetch current prices for the given tickers from ClickHouse using window_start."""
|
||||
try:
|
||||
with create_client() as client:
|
||||
# Get the current time
|
||||
now = datetime.now()
|
||||
print(f"Current datetime: {now}")
|
||||
|
||||
# First check if there's any data for today
|
||||
today_check_query = """
|
||||
SELECT count() as count
|
||||
FROM stock_db.stock_prices
|
||||
WHERE window_start >= toUnixTimestamp(today()) * 1000000000
|
||||
"""
|
||||
today_result = client.query(today_check_query)
|
||||
today_count = today_result.result_rows[0][0]
|
||||
print(f"Number of records for today: {today_count}")
|
||||
|
||||
# If we have data for today, use it
|
||||
if today_count > 0:
|
||||
print("Using today's data")
|
||||
query = f"""
|
||||
SELECT ticker,
|
||||
argMax(close, window_start) as close,
|
||||
max(window_start) as last_update
|
||||
FROM stock_db.stock_prices
|
||||
WHERE ticker IN ({','.join([f"'{t}'" for t in tickers])})
|
||||
AND window_start >= toUnixTimestamp(today()) * 1000000000
|
||||
GROUP BY ticker
|
||||
"""
|
||||
else:
|
||||
# Otherwise get the most recent data
|
||||
print("No data for today, using most recent data")
|
||||
query = f"""
|
||||
SELECT ticker,
|
||||
argMax(close, window_start) as close,
|
||||
max(window_start) as last_update
|
||||
FROM stock_db.stock_prices
|
||||
WHERE ticker IN ({','.join([f"'{t}'" for t in tickers])})
|
||||
GROUP BY ticker
|
||||
"""
|
||||
|
||||
print(f"Executing query: {query}")
|
||||
result = client.query(query)
|
||||
|
||||
# Process results with timestamp information
|
||||
prices = {}
|
||||
timestamps = {}
|
||||
for row in result.result_rows:
|
||||
ticker = row[0]
|
||||
price = row[1]
|
||||
timestamp = row[2]
|
||||
prices[ticker] = price
|
||||
timestamps[ticker] = timestamp
|
||||
|
||||
# Convert timestamps to readable format
|
||||
readable_times = {}
|
||||
for ticker, ts in timestamps.items():
|
||||
try:
|
||||
dt = datetime.fromtimestamp(ts / 1e9) # Convert from nanoseconds
|
||||
readable_times[ticker] = dt.strftime('%Y-%m-%d %H:%M:%S')
|
||||
except Exception as e:
|
||||
readable_times[ticker] = f"Error: {str(e)}"
|
||||
|
||||
print(f"Retrieved prices: {prices}")
|
||||
print(f"Last updated: {readable_times}")
|
||||
|
||||
return prices
|
||||
except Exception as e:
|
||||
print(f"Error fetching current prices: {str(e)}")
|
||||
return {}
|
||||
|
||||
143
src/main.py
143
src/main.py
@ -2,38 +2,125 @@ import warnings
|
||||
from urllib3.exceptions import NotOpenSSLWarning
|
||||
warnings.filterwarnings('ignore', category=NotOpenSSLWarning)
|
||||
|
||||
from trading.menu import print_main_menu, print_technical_scanner_menu
|
||||
from trading.journal import journal_menu
|
||||
from screener.scanner_controller import run_technical_scanner
|
||||
from screener.canslim_controller import run_canslim_screener
|
||||
from trading.main import main as trading_main
|
||||
import datetime
|
||||
from screener.data_fetcher import validate_date_range, fetch_financial_data, get_stocks_in_time_range
|
||||
from screener.t_sunnyband import run_sunny_scanner
|
||||
from screener.t_atr_ema import run_atr_ema_scanner
|
||||
from screener.c_canslim import check_quarterly_earnings, check_return_on_equity, check_sales_growth
|
||||
from screener.a_canslim import check_annual_eps_growth
|
||||
from screener.l_canslim import check_industry_leadership
|
||||
from screener.i_canslim import check_institutional_sponsorship
|
||||
from screener.csv_appender import append_scores_to_csv
|
||||
from screener.screeners import SCREENERS
|
||||
from screener.user_input import get_user_screener_selection, get_interval_choice
|
||||
from indicators.three_atr_ema import ThreeATREMAIndicator
|
||||
|
||||
def get_float_input(prompt: str) -> float:
|
||||
while True:
|
||||
try:
|
||||
return float(input(prompt))
|
||||
except ValueError:
|
||||
print("Please enter a valid number")
|
||||
|
||||
def get_scanner_parameters():
|
||||
"""Get user input for scanner parameters"""
|
||||
min_price = get_float_input("Enter minimum stock price ($): ")
|
||||
max_price = get_float_input("Enter maximum stock price ($): ")
|
||||
min_volume = int(input("Enter minimum volume: "))
|
||||
portfolio_size = get_float_input("Enter portfolio size ($) or 0 to skip position sizing: ")
|
||||
return min_price, max_price, min_volume, portfolio_size
|
||||
|
||||
def main():
|
||||
while True:
|
||||
print_main_menu()
|
||||
choice = input("\nSelect an option (1-5): ")
|
||||
print("\nStock Analysis System")
|
||||
print("1. Run CANSLIM Screener")
|
||||
print("2. Run Technical Scanners (SunnyBands/ATR-EMA)")
|
||||
print("3. Launch Trading System")
|
||||
print("4. Exit")
|
||||
|
||||
choice = input("\nSelect an option (1-4): ")
|
||||
|
||||
if choice == "1":
|
||||
# 1️⃣ Ask user for start and end date
|
||||
user_start_date = input("Enter start date (YYYY-MM-DD): ")
|
||||
user_end_date = input("Enter end date (YYYY-MM-DD): ")
|
||||
|
||||
# 2️⃣ Validate and adjust date range if needed
|
||||
start_date, end_date = validate_date_range(user_start_date, user_end_date, required_quarters=4)
|
||||
selected_screeners = get_user_screener_selection()
|
||||
symbol_list = get_stocks_in_time_range(start_date, end_date)
|
||||
|
||||
if choice == "1":
|
||||
run_canslim_screener()
|
||||
|
||||
elif choice == "2":
|
||||
print_technical_scanner_menu()
|
||||
scanner_choice = input("\nEnter your choice (1-3): ")
|
||||
run_technical_scanner(scanner_choice)
|
||||
|
||||
elif choice == "3":
|
||||
trading_main()
|
||||
|
||||
elif choice == "4":
|
||||
journal_menu()
|
||||
|
||||
elif choice == "5":
|
||||
print("Exiting...")
|
||||
break
|
||||
|
||||
else:
|
||||
print("Invalid choice. Please try again.")
|
||||
if not symbol_list:
|
||||
print("No stocks found within the given date range.")
|
||||
return
|
||||
|
||||
print(f"Processing {len(symbol_list)} stocks within the given date range...\n")
|
||||
|
||||
for symbol in symbol_list:
|
||||
data = fetch_financial_data(symbol, start_date, end_date)
|
||||
|
||||
if not data:
|
||||
print(f"⚠️ Warning: No data returned for {symbol}. Assigning default score.\n")
|
||||
scores = {screener: 0.25 for category in selected_screeners for screener in selected_screeners[category]}
|
||||
else:
|
||||
scores = {}
|
||||
|
||||
for category, screeners in selected_screeners.items():
|
||||
for screener, threshold in screeners.items():
|
||||
if screener == "EPS_Score":
|
||||
scores[screener] = check_quarterly_earnings(data.get("quarterly_eps", []))
|
||||
elif screener == "Annual_EPS_Score":
|
||||
scores[screener] = check_annual_eps_growth(data.get("annual_eps", []))
|
||||
elif screener == "Sales_Score":
|
||||
scores[screener] = check_sales_growth(data.get("sales_growth", []))
|
||||
elif screener == "ROE_Score":
|
||||
scores[screener] = check_return_on_equity(data.get("roe", []))
|
||||
elif screener == "L_Score":
|
||||
scores[screener] = check_industry_leadership(symbol)
|
||||
print(f"🟢 {symbol} - L_Score: {scores[screener]}")
|
||||
elif screener == "I_Score":
|
||||
scores[screener] = check_institutional_sponsorship(symbol)
|
||||
print(f"🏢 {symbol} - I_Score: {scores[screener]}")
|
||||
|
||||
if isinstance(threshold, (int, float)):
|
||||
scores[screener] = scores[screener] >= threshold
|
||||
|
||||
scores["Total_Score"] = sum(scores.values())
|
||||
append_scores_to_csv(symbol, scores)
|
||||
|
||||
print("✅ Scores saved in data/metrics/stock_scores.csv\n")
|
||||
|
||||
elif choice == "2":
|
||||
print("\nTechnical Scanner Options:")
|
||||
print("1. SunnyBands Scanner")
|
||||
print("2. Standard ATR-EMA Scanner")
|
||||
print("3. Enhanced ATR-EMA v2 Scanner") # NEW OPTION
|
||||
|
||||
scanner_choice = input("\nEnter your choice (1-3): ")
|
||||
|
||||
# Get parameters first for all scanners
|
||||
min_price, max_price, min_volume, portfolio_size = get_scanner_parameters()
|
||||
|
||||
if scanner_choice == "1":
|
||||
from screener.t_sunnyband import run_sunny_scanner
|
||||
run_sunny_scanner(min_price, max_price, min_volume, portfolio_size)
|
||||
elif scanner_choice == "2":
|
||||
from screener.t_atr_ema import run_atr_ema_scanner
|
||||
run_atr_ema_scanner(min_price, max_price, min_volume, portfolio_size)
|
||||
elif scanner_choice == "3": # NEW CASE
|
||||
from screener.t_atr_ema_v2 import run_atr_ema_scanner_v2
|
||||
run_atr_ema_scanner_v2(min_price, max_price, min_volume, portfolio_size)
|
||||
else:
|
||||
print("Invalid choice. Returning to main menu.")
|
||||
|
||||
elif choice == "3":
|
||||
from trading.main import main as trading_main
|
||||
trading_main()
|
||||
|
||||
elif choice == "4":
|
||||
print("Exiting...")
|
||||
return
|
||||
else:
|
||||
print("Invalid choice. Please try again.")
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
@ -1,85 +0,0 @@
|
||||
from db.db_connection import create_client
|
||||
from datetime import datetime
|
||||
|
||||
def migrate_trades():
|
||||
"""Add direction field to existing trades"""
|
||||
with create_client() as client:
|
||||
# First get all trades
|
||||
query = "SELECT * FROM stock_db.trades"
|
||||
trades = client.query(query).result_rows
|
||||
|
||||
# Get column names
|
||||
columns = ['id', 'position_id', 'ticker', 'entry_date', 'shares', 'entry_price',
|
||||
'target_price', 'stop_loss', 'strategy', 'order_type', 'followed_rules',
|
||||
'entry_reason', 'exit_price', 'exit_date', 'exit_reason', 'notes', 'created_at']
|
||||
|
||||
# Create new table with direction field
|
||||
client.command("""
|
||||
CREATE TABLE IF NOT EXISTS stock_db.trades_new (
|
||||
id UInt32,
|
||||
position_id String,
|
||||
ticker String,
|
||||
entry_date DateTime,
|
||||
shares UInt32,
|
||||
entry_price Float64,
|
||||
target_price Nullable(Float64),
|
||||
stop_loss Nullable(Float64),
|
||||
strategy Nullable(String),
|
||||
order_type String,
|
||||
direction String,
|
||||
followed_rules Nullable(UInt8),
|
||||
entry_reason Nullable(String),
|
||||
exit_price Nullable(Float64),
|
||||
exit_date Nullable(DateTime),
|
||||
exit_reason Nullable(String),
|
||||
notes Nullable(String),
|
||||
created_at DateTime DEFAULT now()
|
||||
) ENGINE = MergeTree()
|
||||
ORDER BY (position_id, id, entry_date)
|
||||
""")
|
||||
|
||||
# Migrate data
|
||||
for trade in trades:
|
||||
trade_dict = dict(zip(columns, trade))
|
||||
|
||||
# Determine direction based on exit_price
|
||||
direction = 'sell' if trade_dict['exit_price'] else 'buy'
|
||||
|
||||
# Insert into new table
|
||||
query = f"""
|
||||
INSERT INTO stock_db.trades_new (
|
||||
id, position_id, ticker, entry_date, shares, entry_price, target_price, stop_loss,
|
||||
strategy, order_type, direction, followed_rules, entry_reason, exit_price, exit_date,
|
||||
exit_reason, notes, created_at
|
||||
) VALUES (
|
||||
{trade_dict['id']},
|
||||
'{trade_dict['position_id']}',
|
||||
'{trade_dict['ticker']}',
|
||||
'{trade_dict['entry_date'].strftime('%Y-%m-%d %H:%M:%S')}',
|
||||
{trade_dict['shares']},
|
||||
{trade_dict['entry_price']},
|
||||
{trade_dict['target_price'] if trade_dict['target_price'] else 'NULL'},
|
||||
{trade_dict['stop_loss'] if trade_dict['stop_loss'] else 'NULL'},
|
||||
{f"'{trade_dict['strategy']}'" if trade_dict['strategy'] else 'NULL'},
|
||||
'{trade_dict['order_type']}',
|
||||
'{direction}',
|
||||
{1 if trade_dict['followed_rules'] else 0},
|
||||
{f"'{trade_dict['entry_reason']}'" if trade_dict['entry_reason'] else 'NULL'},
|
||||
{trade_dict['exit_price'] if trade_dict['exit_price'] else 'NULL'},
|
||||
{f"'{trade_dict['exit_date'].strftime('%Y-%m-%d %H:%M:%S')}'" if trade_dict['exit_date'] else 'NULL'},
|
||||
{f"'{trade_dict['exit_reason']}'" if trade_dict['exit_reason'] else 'NULL'},
|
||||
{f"'{trade_dict['notes']}'" if trade_dict['notes'] else 'NULL'},
|
||||
'{trade_dict['created_at'].strftime('%Y-%m-%d %H:%M:%S')}'
|
||||
)
|
||||
"""
|
||||
client.command(query)
|
||||
|
||||
# Rename tables
|
||||
client.command("RENAME TABLE stock_db.trades TO stock_db.trades_backup")
|
||||
client.command("RENAME TABLE stock_db.trades_new TO stock_db.trades")
|
||||
|
||||
print("Migration completed successfully!")
|
||||
print("Old table backed up as trades_backup")
|
||||
|
||||
if __name__ == "__main__":
|
||||
migrate_trades()
|
||||
@ -1,53 +0,0 @@
|
||||
import sys
|
||||
from pathlib import Path
|
||||
import logging
|
||||
|
||||
# Add the src directory to the Python path
|
||||
src_path = str(Path(__file__).parent.parent)
|
||||
sys.path.append(src_path)
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
from db.db_connection import create_client
|
||||
|
||||
def create_watchlist_tables():
|
||||
with create_client() as client:
|
||||
try:
|
||||
# Create watchlists table
|
||||
logger.info("Creating watchlists table...")
|
||||
client.command("""
|
||||
CREATE TABLE IF NOT EXISTS stock_db.watchlists (
|
||||
id UInt32,
|
||||
name String,
|
||||
strategy String,
|
||||
created_at DateTime DEFAULT now()
|
||||
)
|
||||
ENGINE = MergeTree()
|
||||
ORDER BY (id)
|
||||
""")
|
||||
|
||||
# Create watchlist items table
|
||||
logger.info("Creating watchlist_items table...")
|
||||
client.command("""
|
||||
CREATE TABLE IF NOT EXISTS stock_db.watchlist_items (
|
||||
id UInt32,
|
||||
watchlist_id UInt32,
|
||||
ticker String,
|
||||
entry_price Float64,
|
||||
target_price Float64,
|
||||
stop_loss Float64,
|
||||
notes String,
|
||||
created_at DateTime DEFAULT now()
|
||||
)
|
||||
ENGINE = MergeTree()
|
||||
ORDER BY (id, watchlist_id)
|
||||
""")
|
||||
|
||||
logger.info("Tables created successfully")
|
||||
except Exception as e:
|
||||
logger.error(f"Error creating tables: {e}", exc_info=True)
|
||||
raise
|
||||
|
||||
if __name__ == "__main__":
|
||||
create_watchlist_tables()
|
||||
@ -1,657 +0,0 @@
|
||||
import streamlit as st
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
import plotly.graph_objects as go
|
||||
from plotly.subplots import make_subplots
|
||||
from datetime import datetime, timedelta
|
||||
from utils.common_utils import get_stock_data
|
||||
from sklearn.preprocessing import MinMaxScaler
|
||||
from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score
|
||||
import tensorflow as tf
|
||||
from tensorflow.keras.models import Sequential
|
||||
from tensorflow.keras.layers import LSTM, Dense, Dropout
|
||||
import matplotlib.pyplot as plt
|
||||
import io
|
||||
import base64
|
||||
import tempfile
|
||||
import os
|
||||
from pandas.tseries.offsets import BDay
|
||||
import pandas_market_calendars as mcal
|
||||
|
||||
class AIForecaster:
|
||||
def __init__(self, data: pd.DataFrame, forecast_days: int, lookback_window: int = 60):
|
||||
"""
|
||||
Initialize AI Forecaster
|
||||
|
||||
Args:
|
||||
data (pd.DataFrame): Historical price data
|
||||
forecast_days (int): Number of days to forecast
|
||||
lookback_window (int): Number of past days to use for prediction
|
||||
"""
|
||||
# Make a copy and ensure column names are standardized
|
||||
self.data = data.copy()
|
||||
|
||||
# Standardize column names (convert to lowercase first, then capitalize)
|
||||
self.data.columns = [col.lower() for col in self.data.columns]
|
||||
|
||||
# Map common column names to standard format
|
||||
column_mapping = {
|
||||
'open': 'Open',
|
||||
'high': 'High',
|
||||
'low': 'Low',
|
||||
'close': 'Close',
|
||||
'volume': 'Volume',
|
||||
'adj close': 'Adj Close',
|
||||
'adj_close': 'Adj Close'
|
||||
}
|
||||
|
||||
# Rename columns
|
||||
self.data.rename(columns={k: v for k, v in column_mapping.items() if k in self.data.columns}, inplace=True)
|
||||
|
||||
# Ensure data is sorted by date
|
||||
if 'date' in self.data.columns:
|
||||
self.data.rename(columns={'date': 'Date'}, inplace=True)
|
||||
self.data = self.data.sort_values('Date')
|
||||
|
||||
self.forecast_days = forecast_days
|
||||
self.lookback_window = lookback_window
|
||||
|
||||
# Check if Close column exists
|
||||
if 'Close' not in self.data.columns:
|
||||
raise ValueError(f"Required column 'Close' not found. Available columns: {list(self.data.columns)}")
|
||||
|
||||
self.last_price = self.data['Close'].iloc[-1]
|
||||
self.scaler = MinMaxScaler(feature_range=(0, 1))
|
||||
|
||||
# Features to use for prediction
|
||||
self.features = ['Open', 'High', 'Low', 'Close', 'Volume']
|
||||
self.available_features = [f for f in self.features if f in self.data.columns]
|
||||
|
||||
# Add technical indicators
|
||||
self.add_technical_indicators()
|
||||
|
||||
def add_technical_indicators(self):
|
||||
"""Add technical indicators to the dataset"""
|
||||
# Moving averages
|
||||
self.data['MA5'] = self.data['Close'].rolling(window=5).mean()
|
||||
self.data['MA20'] = self.data['Close'].rolling(window=20).mean()
|
||||
|
||||
# Relative Strength Index (RSI)
|
||||
delta = self.data['Close'].diff()
|
||||
gain = delta.where(delta > 0, 0).rolling(window=14).mean()
|
||||
loss = -delta.where(delta < 0, 0).rolling(window=14).mean()
|
||||
rs = gain / loss
|
||||
self.data['RSI'] = 100 - (100 / (1 + rs))
|
||||
|
||||
# MACD
|
||||
self.data['EMA12'] = self.data['Close'].ewm(span=12, adjust=False).mean()
|
||||
self.data['EMA26'] = self.data['Close'].ewm(span=26, adjust=False).mean()
|
||||
self.data['MACD'] = self.data['EMA12'] - self.data['EMA26']
|
||||
self.data['Signal'] = self.data['MACD'].ewm(span=9, adjust=False).mean()
|
||||
|
||||
# Bollinger Bands
|
||||
self.data['BB_Middle'] = self.data['Close'].rolling(window=20).mean()
|
||||
self.data['BB_Std'] = self.data['Close'].rolling(window=20).std()
|
||||
self.data['BB_Upper'] = self.data['BB_Middle'] + 2 * self.data['BB_Std']
|
||||
self.data['BB_Lower'] = self.data['BB_Middle'] - 2 * self.data['BB_Std']
|
||||
|
||||
# Drop NaN values
|
||||
self.data = self.data.dropna()
|
||||
|
||||
# Add additional features to available features
|
||||
additional_features = ['MA5', 'MA20', 'RSI', 'MACD', 'Signal']
|
||||
self.available_features.extend(additional_features)
|
||||
|
||||
def prepare_data(self):
|
||||
"""Prepare data for LSTM model"""
|
||||
# Select features
|
||||
dataset = self.data[self.available_features].values
|
||||
|
||||
# Scale the data
|
||||
scaled_data = self.scaler.fit_transform(dataset)
|
||||
|
||||
# Create sequences for LSTM
|
||||
X, y = [], []
|
||||
for i in range(self.lookback_window, len(scaled_data)):
|
||||
X.append(scaled_data[i-self.lookback_window:i])
|
||||
y.append(scaled_data[i, self.available_features.index('Close')])
|
||||
|
||||
X, y = np.array(X), np.array(y)
|
||||
|
||||
# Split data into train and test sets (80% train, 20% test)
|
||||
train_size = int(len(X) * 0.8)
|
||||
X_train, X_test = X[:train_size], X[train_size:]
|
||||
y_train, y_test = y[:train_size], y[train_size:]
|
||||
|
||||
return X_train, y_train, X_test, y_test, scaled_data
|
||||
|
||||
def build_model(self, X_train):
|
||||
"""Build and compile LSTM model"""
|
||||
model = Sequential()
|
||||
|
||||
# Input layer
|
||||
model.add(LSTM(units=50, return_sequences=True,
|
||||
input_shape=(X_train.shape[1], X_train.shape[2])))
|
||||
model.add(Dropout(0.2))
|
||||
|
||||
# Hidden layers
|
||||
model.add(LSTM(units=50, return_sequences=True))
|
||||
model.add(Dropout(0.2))
|
||||
|
||||
model.add(LSTM(units=50))
|
||||
model.add(Dropout(0.2))
|
||||
|
||||
# Output layer
|
||||
model.add(Dense(units=1))
|
||||
|
||||
# Compile model
|
||||
model.compile(optimizer='adam', loss='mean_squared_error')
|
||||
|
||||
return model
|
||||
|
||||
def train_model(self, verbose=1):
|
||||
"""Train the LSTM model and make predictions"""
|
||||
try:
|
||||
# Prepare data
|
||||
X_train, y_train, X_test, y_test, scaled_data = self.prepare_data()
|
||||
|
||||
# Build model
|
||||
model = self.build_model(X_train)
|
||||
|
||||
# Train model with a callback to prevent training from hanging
|
||||
early_stopping = tf.keras.callbacks.EarlyStopping(
|
||||
monitor='val_loss',
|
||||
patience=10,
|
||||
restore_best_weights=True
|
||||
)
|
||||
|
||||
history = model.fit(
|
||||
X_train, y_train,
|
||||
epochs=50,
|
||||
batch_size=32,
|
||||
validation_data=(X_test, y_test),
|
||||
verbose=verbose,
|
||||
callbacks=[early_stopping]
|
||||
)
|
||||
|
||||
# Make predictions on test data
|
||||
y_pred = model.predict(X_test)
|
||||
|
||||
# Prepare for inverse scaling
|
||||
pred_copy = np.repeat(y_pred, len(self.available_features)).reshape(-1, len(self.available_features))
|
||||
|
||||
# Only the Close price column needs to be replaced
|
||||
close_idx = self.available_features.index('Close')
|
||||
|
||||
# Create a dummy array with the same shape as scaled_data[-len(y_test):]
|
||||
dummy = scaled_data[-len(y_test):].copy()
|
||||
|
||||
# Replace only the Close column with predictions
|
||||
dummy[:, close_idx] = y_pred.flatten()
|
||||
|
||||
# Inverse transform
|
||||
y_pred_actual = self.scaler.inverse_transform(dummy)[:, close_idx]
|
||||
y_test_actual = self.scaler.inverse_transform(scaled_data[-len(y_test):])[:, close_idx]
|
||||
|
||||
# Calculate metrics
|
||||
mse = mean_squared_error(y_test_actual, y_pred_actual)
|
||||
mae = mean_absolute_error(y_test_actual, y_pred_actual)
|
||||
rmse = np.sqrt(mse)
|
||||
r2 = r2_score(y_test_actual, y_pred_actual)
|
||||
|
||||
metrics = {
|
||||
'mse': mse,
|
||||
'mae': mae,
|
||||
'rmse': rmse,
|
||||
'r2': r2
|
||||
}
|
||||
|
||||
# Forecast future prices
|
||||
last_sequence = scaled_data[-self.lookback_window:]
|
||||
future_predictions = []
|
||||
|
||||
# Create a copy for forecasting
|
||||
current_sequence = last_sequence.copy()
|
||||
|
||||
def get_trading_days(start_date, num_days):
|
||||
"""Get future trading days using NYSE calendar"""
|
||||
nyse = mcal.get_calendar('NYSE')
|
||||
end_date = start_date + timedelta(days=num_days * 2) # Look ahead enough to find required trading days
|
||||
schedule = nyse.schedule(start_date=start_date, end_date=end_date)
|
||||
trading_days = mcal.date_range(schedule, frequency='1D')
|
||||
return trading_days[:num_days]
|
||||
|
||||
# Get the last date from the data
|
||||
last_date = self.data.index[-1] if isinstance(self.data.index, pd.DatetimeIndex) else pd.to_datetime(self.data['Date'].iloc[-1])
|
||||
|
||||
# Generate future trading days
|
||||
forecast_dates = get_trading_days(last_date + timedelta(days=1), self.forecast_days)
|
||||
|
||||
for _ in range(self.forecast_days):
|
||||
# Reshape for prediction
|
||||
current_batch = current_sequence.reshape(1, self.lookback_window, len(self.available_features))
|
||||
|
||||
# Predict next value
|
||||
next_pred = model.predict(current_batch)[0][0]
|
||||
|
||||
# Create a dummy row with the last known values
|
||||
dummy_row = current_sequence[-1].copy()
|
||||
|
||||
# Update the Close price with our prediction
|
||||
dummy_row[close_idx] = next_pred
|
||||
|
||||
# Add to predictions
|
||||
future_predictions.append(dummy_row)
|
||||
|
||||
# Update sequence by removing first row and adding the new prediction
|
||||
current_sequence = np.vstack([current_sequence[1:], dummy_row])
|
||||
|
||||
# Convert predictions to actual values
|
||||
future_predictions = np.array(future_predictions)
|
||||
future_prices = self.scaler.inverse_transform(future_predictions)[:, close_idx]
|
||||
|
||||
# Create forecast DataFrame
|
||||
forecast_df = pd.DataFrame({
|
||||
'Date': forecast_dates,
|
||||
'Predicted_Close': future_prices
|
||||
})
|
||||
|
||||
# Create historical predictions for plotting
|
||||
historical_predictions = model.predict(np.array(X_test))
|
||||
|
||||
# Prepare for inverse scaling (same as before)
|
||||
hist_dummy = scaled_data[-len(y_test):].copy()
|
||||
hist_dummy[:, close_idx] = historical_predictions.flatten()
|
||||
historical_actual = self.scaler.inverse_transform(hist_dummy)[:, close_idx]
|
||||
|
||||
# Get dates for historical predictions
|
||||
if isinstance(self.data.index, pd.DatetimeIndex):
|
||||
historical_dates = self.data.index[-len(y_test):]
|
||||
else:
|
||||
historical_dates = pd.to_datetime(self.data['Date'].iloc[-len(y_test):])
|
||||
|
||||
# Create historical predictions DataFrame
|
||||
historical_df = pd.DataFrame({
|
||||
'Date': historical_dates,
|
||||
'Actual_Close': y_test_actual,
|
||||
'Predicted_Close': historical_actual
|
||||
})
|
||||
|
||||
return model, forecast_df, historical_df, metrics, history
|
||||
except Exception as e:
|
||||
raise Exception(f"Error during model training: {str(e)}")
|
||||
|
||||
def plot_forecast(self, forecast_df, historical_df, metrics):
|
||||
"""Create an interactive plot of forecast results"""
|
||||
fig = make_subplots(
|
||||
rows=2, cols=1,
|
||||
subplot_titles=('Price Forecast', 'Prediction Error'),
|
||||
vertical_spacing=0.15,
|
||||
row_heights=[0.7, 0.3]
|
||||
)
|
||||
|
||||
# Plot historical actual prices
|
||||
fig.add_trace(
|
||||
go.Scatter(
|
||||
x=historical_df['Date'],
|
||||
y=historical_df['Actual_Close'],
|
||||
name='Actual Price',
|
||||
line=dict(color='blue', width=2)
|
||||
),
|
||||
row=1, col=1
|
||||
)
|
||||
|
||||
# Plot historical predicted prices
|
||||
fig.add_trace(
|
||||
go.Scatter(
|
||||
x=historical_df['Date'],
|
||||
y=historical_df['Predicted_Close'],
|
||||
name='Model Fit',
|
||||
line=dict(color='green', width=2)
|
||||
),
|
||||
row=1, col=1
|
||||
)
|
||||
|
||||
# Plot future predictions
|
||||
fig.add_trace(
|
||||
go.Scatter(
|
||||
x=forecast_df['Date'],
|
||||
y=forecast_df['Predicted_Close'],
|
||||
name='Price Forecast',
|
||||
line=dict(color='red', width=2)
|
||||
),
|
||||
row=1, col=1
|
||||
)
|
||||
|
||||
# Add confidence intervals (simple approach: ±RMSE)
|
||||
rmse = metrics['rmse']
|
||||
fig.add_trace(
|
||||
go.Scatter(
|
||||
x=forecast_df['Date'],
|
||||
y=forecast_df['Predicted_Close'] + rmse,
|
||||
name='Upper Bound',
|
||||
line=dict(color='rgba(255,0,0,0.3)', dash='dash'),
|
||||
showlegend=False
|
||||
),
|
||||
row=1, col=1
|
||||
)
|
||||
|
||||
fig.add_trace(
|
||||
go.Scatter(
|
||||
x=forecast_df['Date'],
|
||||
y=forecast_df['Predicted_Close'] - rmse,
|
||||
name='Lower Bound',
|
||||
line=dict(color='rgba(255,0,0,0.3)', dash='dash'),
|
||||
fill='tonexty',
|
||||
showlegend=False
|
||||
),
|
||||
row=1, col=1
|
||||
)
|
||||
|
||||
# Plot prediction error
|
||||
error = historical_df['Actual_Close'] - historical_df['Predicted_Close']
|
||||
fig.add_trace(
|
||||
go.Bar(
|
||||
x=historical_df['Date'],
|
||||
y=error,
|
||||
name='Prediction Error',
|
||||
marker_color='orange'
|
||||
),
|
||||
row=2, col=1
|
||||
)
|
||||
|
||||
# Add zero line for error
|
||||
fig.add_trace(
|
||||
go.Scatter(
|
||||
x=[historical_df['Date'].min(), historical_df['Date'].max()],
|
||||
y=[0, 0],
|
||||
mode='lines',
|
||||
line=dict(color='white', dash='dash'),
|
||||
showlegend=False
|
||||
),
|
||||
row=2, col=1
|
||||
)
|
||||
|
||||
# Update layout
|
||||
fig.update_layout(
|
||||
title='AI Stock Price Forecast',
|
||||
showlegend=True,
|
||||
height=800,
|
||||
template='plotly_dark',
|
||||
plot_bgcolor='rgba(0,0,0,0)',
|
||||
paper_bgcolor='rgba(0,0,0,0)'
|
||||
)
|
||||
|
||||
# Update axes
|
||||
fig.update_xaxes(showgrid=True, gridwidth=1, gridcolor='rgba(128,128,128,0.2)')
|
||||
fig.update_yaxes(showgrid=True, gridwidth=1, gridcolor='rgba(128,128,128,0.2)')
|
||||
|
||||
return fig
|
||||
|
||||
def plot_training_history(self, history):
|
||||
"""Plot training and validation loss"""
|
||||
plt.figure(figsize=(10, 6))
|
||||
plt.plot(history.history['loss'], label='Training Loss')
|
||||
plt.plot(history.history['val_loss'], label='Validation Loss')
|
||||
plt.title('Model Loss')
|
||||
plt.ylabel('Loss')
|
||||
plt.xlabel('Epoch')
|
||||
plt.legend(loc='upper right')
|
||||
plt.grid(True)
|
||||
|
||||
# Convert plot to image
|
||||
buf = io.BytesIO()
|
||||
plt.savefig(buf, format='png', bbox_inches='tight')
|
||||
buf.seek(0)
|
||||
img_str = base64.b64encode(buf.read()).decode('utf-8')
|
||||
plt.close()
|
||||
|
||||
return img_str
|
||||
|
||||
def ai_forecast_page():
|
||||
st.title("AI Stock Price Forecasting")
|
||||
|
||||
# Input parameters
|
||||
col1, col2 = st.columns(2)
|
||||
|
||||
with col1:
|
||||
ticker = st.text_input("Enter Ticker Symbol", value="AAPL").upper()
|
||||
start_date = st.date_input(
|
||||
"Start Date (for historical data)",
|
||||
value=datetime.now() - timedelta(days=365)
|
||||
)
|
||||
end_date = st.date_input("End Date", value=datetime.now())
|
||||
|
||||
with col2:
|
||||
forecast_days = st.number_input(
|
||||
"Forecast Horizon (Days)",
|
||||
min_value=5,
|
||||
max_value=90,
|
||||
value=30
|
||||
)
|
||||
lookback_window = st.number_input(
|
||||
"Lookback Window (Days)",
|
||||
min_value=10,
|
||||
max_value=120,
|
||||
value=60,
|
||||
help="Number of past days used to predict the next day"
|
||||
)
|
||||
|
||||
# Advanced options
|
||||
with st.expander("Advanced Options"):
|
||||
model_type = st.selectbox(
|
||||
"Model Type",
|
||||
options=["LSTM", "GRU", "Bidirectional LSTM"],
|
||||
index=0,
|
||||
help="Type of neural network to use"
|
||||
)
|
||||
|
||||
training_epochs = st.slider(
|
||||
"Training Epochs",
|
||||
min_value=10,
|
||||
max_value=200,
|
||||
value=50,
|
||||
help="Number of training iterations"
|
||||
)
|
||||
|
||||
batch_size = st.selectbox(
|
||||
"Batch Size",
|
||||
options=[8, 16, 32, 64, 128],
|
||||
index=2,
|
||||
help="Number of samples processed before model update"
|
||||
)
|
||||
|
||||
if st.button("Generate Forecast"):
|
||||
# Create a progress bar
|
||||
progress_bar = st.progress(0)
|
||||
status_text = st.empty()
|
||||
|
||||
with st.spinner('Training AI model and generating forecast...'):
|
||||
try:
|
||||
# Update status
|
||||
status_text.text("Loading historical data...")
|
||||
progress_bar.progress(10)
|
||||
|
||||
# Get historical data
|
||||
df = get_stock_data(
|
||||
ticker,
|
||||
datetime.combine(start_date, datetime.min.time()),
|
||||
datetime.combine(end_date, datetime.min.time()),
|
||||
'1d' # Daily data
|
||||
)
|
||||
|
||||
if df.empty:
|
||||
st.error("No data available for the selected period")
|
||||
return
|
||||
|
||||
# Filter out non-trading days (weekends and holidays)
|
||||
nyse = mcal.get_calendar('NYSE')
|
||||
schedule = nyse.schedule(start_date=start_date, end_date=end_date)
|
||||
trading_days = mcal.date_range(schedule, frequency='1D')
|
||||
trading_days_dates = [d.date() for d in trading_days]
|
||||
|
||||
# Ensure the dataframe has a datetime index
|
||||
if 'date' in df.columns:
|
||||
df['date'] = pd.to_datetime(df['date'])
|
||||
df = df.set_index('date')
|
||||
elif 'Date' in df.columns:
|
||||
df['Date'] = pd.to_datetime(df['Date'])
|
||||
df = df.set_index('Date')
|
||||
|
||||
# Keep only trading days
|
||||
df = df[df.index.map(lambda x: x.date() in trading_days_dates)]
|
||||
|
||||
# Update status
|
||||
status_text.text("Preparing data and adding technical indicators...")
|
||||
progress_bar.progress(30)
|
||||
|
||||
# Add debug information
|
||||
with st.expander("Debug Information"):
|
||||
st.write("DataFrame columns:", df.columns.tolist())
|
||||
st.write("DataFrame head:", df.head())
|
||||
st.write("DataFrame shape:", df.shape)
|
||||
|
||||
# Initialize forecaster
|
||||
forecaster = AIForecaster(df, forecast_days, lookback_window)
|
||||
|
||||
# Update status
|
||||
status_text.text("Training neural network model (this may take a few minutes)...")
|
||||
progress_bar.progress(50)
|
||||
|
||||
# Train model and get predictions
|
||||
model, forecast_df, historical_df, metrics, history = forecaster.train_model()
|
||||
|
||||
# Update status
|
||||
status_text.text("Generating forecast visualization...")
|
||||
progress_bar.progress(80)
|
||||
|
||||
# Display results
|
||||
col1, col2 = st.columns([3, 1])
|
||||
|
||||
with col1:
|
||||
# Plot forecast
|
||||
fig = forecaster.plot_forecast(forecast_df, historical_df, metrics)
|
||||
st.plotly_chart(fig, use_container_width=True)
|
||||
|
||||
with col2:
|
||||
st.subheader("Forecast Metrics")
|
||||
|
||||
# Price metrics
|
||||
st.write("##### Current & Forecast")
|
||||
# Handle different column name formats
|
||||
close_col = 'close' if 'close' in df.columns else 'Close'
|
||||
current_price = df[close_col].iloc[-1]
|
||||
forecast_price = forecast_df['Predicted_Close'].iloc[-1]
|
||||
price_change = ((forecast_price - current_price) / current_price) * 100
|
||||
|
||||
st.metric("Current Price", f"${current_price:.2f}")
|
||||
st.metric(
|
||||
f"Forecast ({forecast_days} days)",
|
||||
f"${forecast_price:.2f}",
|
||||
f"{price_change:.1f}%",
|
||||
delta_color="normal" if price_change >= 0 else "inverse"
|
||||
)
|
||||
|
||||
# Model accuracy metrics
|
||||
st.write("##### Model Accuracy")
|
||||
st.metric("RMSE", f"${metrics['rmse']:.2f}")
|
||||
st.metric("MAE", f"${metrics['mae']:.2f}")
|
||||
st.metric("R² Score", f"{metrics['r2']:.3f}")
|
||||
|
||||
# Show training history
|
||||
st.subheader("Model Training History")
|
||||
history_img = forecaster.plot_training_history(history)
|
||||
st.image(f"data:image/png;base64,{history_img}", use_container_width=True)
|
||||
|
||||
# Show forecast table
|
||||
st.subheader("Detailed Forecast")
|
||||
forecast_df['Date'] = forecast_df['Date'].dt.strftime('%Y-%m-%d')
|
||||
forecast_df['Predicted_Close'] = forecast_df['Predicted_Close'].round(2)
|
||||
forecast_df.columns = ['Date', 'Predicted Price']
|
||||
|
||||
# Calculate daily changes
|
||||
forecast_df['Daily Change %'] = [0] + [
|
||||
((forecast_df['Predicted Price'].iloc[i] - forecast_df['Predicted Price'].iloc[i-1]) /
|
||||
forecast_df['Predicted Price'].iloc[i-1] * 100).round(2)
|
||||
for i in range(1, len(forecast_df))
|
||||
]
|
||||
|
||||
# Calculate cumulative change from current price
|
||||
forecast_df['Cumulative Change %'] = [
|
||||
((price - current_price) / current_price * 100).round(2)
|
||||
for price in forecast_df['Predicted Price']
|
||||
]
|
||||
|
||||
st.dataframe(forecast_df, use_container_width=True)
|
||||
|
||||
# Add download buttons
|
||||
col1, col2 = st.columns(2)
|
||||
with col1:
|
||||
csv = forecast_df.to_csv(index=False)
|
||||
st.download_button(
|
||||
label="Download Forecast Data",
|
||||
data=csv,
|
||||
file_name=f'ai_forecast_{ticker}_{datetime.now().strftime("%Y%m%d")}.csv',
|
||||
mime='text/csv'
|
||||
)
|
||||
|
||||
with col2:
|
||||
# Combine historical and forecast for full dataset
|
||||
historical_df['Date'] = historical_df['Date'].dt.strftime('%Y-%m-%d')
|
||||
full_data = pd.concat([
|
||||
historical_df[['Date', 'Actual_Close', 'Predicted_Close']],
|
||||
forecast_df[['Date', 'Predicted Price']].rename(
|
||||
columns={'Predicted Price': 'Predicted_Close'}
|
||||
)
|
||||
])
|
||||
full_data['Actual_Close'] = full_data['Actual_Close'].round(2)
|
||||
full_data['Predicted_Close'] = full_data['Predicted_Close'].round(2)
|
||||
|
||||
csv_full = full_data.to_csv(index=False)
|
||||
st.download_button(
|
||||
label="Download Complete Dataset",
|
||||
data=csv_full,
|
||||
file_name=f'ai_forecast_full_{ticker}_{datetime.now().strftime("%Y%m%d")}.csv',
|
||||
mime='text/csv'
|
||||
)
|
||||
|
||||
# After displaying all the results, add a button to save the model
|
||||
if 'model' in locals():
|
||||
try:
|
||||
# Create a temporary file to save the model
|
||||
with tempfile.NamedTemporaryFile(suffix='.h5', delete=False) as tmp:
|
||||
temp_path = tmp.name
|
||||
|
||||
# Save the model to the temporary file
|
||||
model.save(temp_path)
|
||||
|
||||
# Read the saved model file
|
||||
with open(temp_path, 'rb') as f:
|
||||
model_data = f.read()
|
||||
|
||||
# Clean up the temporary file
|
||||
if os.path.exists(temp_path):
|
||||
os.remove(temp_path)
|
||||
|
||||
# Provide download button
|
||||
st.download_button(
|
||||
label="Download Trained Model",
|
||||
data=model_data,
|
||||
file_name=f'ai_model_{ticker}_{datetime.now().strftime("%Y%m%d")}.h5',
|
||||
mime='application/octet-stream'
|
||||
)
|
||||
except Exception as e:
|
||||
st.warning(f"Could not save model: {str(e)}")
|
||||
st.info("You can still use the forecast results even though the model couldn't be saved.")
|
||||
|
||||
# Complete progress
|
||||
progress_bar.progress(100)
|
||||
status_text.text("Forecast complete!")
|
||||
|
||||
except Exception as e:
|
||||
st.error(f"Error during forecast: {str(e)}")
|
||||
import traceback
|
||||
st.code(traceback.format_exc())
|
||||
|
||||
if __name__ == "__main__":
|
||||
ai_forecast_page()
|
||||
@ -1,329 +0,0 @@
|
||||
import streamlit as st
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
import yfinance as yf
|
||||
from datetime import datetime, timedelta
|
||||
from utils.common_utils import get_stock_data
|
||||
import plotly.graph_objects as go
|
||||
from plotly.subplots import make_subplots
|
||||
from typing import Tuple, List, Dict
|
||||
from scipy import stats
|
||||
|
||||
class MonteCarloSimulator:
|
||||
def __init__(self, data: pd.DataFrame, num_simulations: int, time_horizon: int):
|
||||
"""
|
||||
Initialize Monte Carlo simulator
|
||||
|
||||
Args:
|
||||
data (pd.DataFrame): Historical price data
|
||||
num_simulations (int): Number of simulation paths
|
||||
time_horizon (int): Number of days to simulate
|
||||
"""
|
||||
# Make a copy and standardize column names
|
||||
self.data = data.copy()
|
||||
self.data.columns = [col.capitalize() for col in self.data.columns]
|
||||
|
||||
self.num_simulations = num_simulations
|
||||
self.time_horizon = time_horizon
|
||||
self.returns = np.log(self.data['Close'] / self.data['Close'].shift(1)).dropna()
|
||||
self.last_price = self.data['Close'].iloc[-1]
|
||||
self.drift = self.returns.mean()
|
||||
self.volatility = self.returns.std()
|
||||
|
||||
def run_simulation(self) -> np.ndarray:
|
||||
"""Run Monte Carlo simulation and return paths"""
|
||||
# Generate random walks
|
||||
daily_returns = np.random.normal(
|
||||
(self.drift + (self.volatility ** 2) / 2),
|
||||
self.volatility,
|
||||
size=(self.time_horizon, self.num_simulations)
|
||||
)
|
||||
|
||||
# Calculate price paths
|
||||
price_paths = np.zeros_like(daily_returns)
|
||||
price_paths[0] = self.last_price
|
||||
for t in range(1, self.time_horizon):
|
||||
price_paths[t] = price_paths[t-1] * np.exp(daily_returns[t])
|
||||
|
||||
return price_paths
|
||||
|
||||
def calculate_target_price(self, confidence_level: float = 95) -> float:
|
||||
"""
|
||||
Calculate target price based on Monte Carlo simulation and confidence level
|
||||
|
||||
Args:
|
||||
confidence_level (float): Confidence level for target price (default: 95)
|
||||
|
||||
Returns:
|
||||
float: Recommended target price
|
||||
"""
|
||||
# Run simulation
|
||||
paths = self.run_simulation()
|
||||
|
||||
# Get the prices at specified timeframe
|
||||
final_prices = paths[-1]
|
||||
|
||||
# Calculate the price that represents the upside potential
|
||||
target_price = np.percentile(final_prices, confidence_level)
|
||||
|
||||
return target_price
|
||||
|
||||
def calculate_stop_loss(self, risk_percentage: float) -> float:
|
||||
"""
|
||||
Calculate stop loss price based on Monte Carlo simulation and desired risk percentage
|
||||
|
||||
Args:
|
||||
risk_percentage (float): Maximum risk percentage willing to take
|
||||
|
||||
Returns:
|
||||
float: Recommended stop loss price
|
||||
"""
|
||||
# Run a quick simulation
|
||||
paths = self.run_simulation()
|
||||
|
||||
# Get the first day's simulated prices
|
||||
first_day_prices = paths[1] # Using day 1 instead of day 0 to see immediate movement
|
||||
|
||||
# Calculate the price that represents the risk percentage loss
|
||||
potential_losses = (first_day_prices - self.last_price) / self.last_price * 100
|
||||
|
||||
# Find the price level that corresponds to our risk percentage
|
||||
stop_loss_percentile = np.percentile(potential_losses, risk_percentage)
|
||||
stop_loss_price = self.last_price * (1 + stop_loss_percentile / 100)
|
||||
|
||||
return stop_loss_price
|
||||
|
||||
def calculate_metrics(self, paths: np.ndarray) -> Dict:
|
||||
"""Calculate key metrics from simulation results"""
|
||||
final_prices = paths[-1]
|
||||
returns = (final_prices - self.last_price) / self.last_price
|
||||
|
||||
metrics = {
|
||||
'Expected Price': np.mean(final_prices),
|
||||
'Median Price': np.median(final_prices),
|
||||
'Std Dev': np.std(final_prices),
|
||||
'Skewness': stats.skew(final_prices),
|
||||
'Kurtosis': stats.kurtosis(final_prices),
|
||||
'95% CI Lower': np.percentile(final_prices, 2.5),
|
||||
'95% CI Upper': np.percentile(final_prices, 97.5),
|
||||
'Probability Above Current': np.mean(final_prices > self.last_price) * 100,
|
||||
'Expected Return': np.mean(returns) * 100,
|
||||
'VaR (95%)': np.percentile(returns, 5) * 100,
|
||||
'CVaR (95%)': np.mean(returns[returns <= np.percentile(returns, 5)]) * 100
|
||||
}
|
||||
return metrics
|
||||
|
||||
def create_simulation_plot(paths: np.ndarray, dates: pd.DatetimeIndex,
|
||||
ticker: str, last_price: float) -> go.Figure:
|
||||
"""Create an interactive plot of simulation results"""
|
||||
fig = make_subplots(
|
||||
rows=2, cols=1,
|
||||
subplot_titles=('Price Paths', 'Price Distribution at End Date'),
|
||||
vertical_spacing=0.15,
|
||||
row_heights=[0.7, 0.3]
|
||||
)
|
||||
|
||||
# Plot all simulation paths with low opacity
|
||||
for i in range(paths.shape[1]):
|
||||
fig.add_trace(
|
||||
go.Scatter(
|
||||
x=dates,
|
||||
y=paths[:, i],
|
||||
name=f'Simulation {i+1}',
|
||||
line=dict(color='orange', width=0.1),
|
||||
opacity=0.1,
|
||||
showlegend=False
|
||||
),
|
||||
row=1, col=1
|
||||
)
|
||||
|
||||
# Plot confidence intervals
|
||||
percentiles = np.percentile(paths, [5, 25, 50, 75, 95], axis=1)
|
||||
|
||||
# Add upper bound (blue)
|
||||
fig.add_trace(
|
||||
go.Scatter(
|
||||
x=dates,
|
||||
y=percentiles[4],
|
||||
name='Upper 95% Confidence',
|
||||
line=dict(color='blue', width=2)
|
||||
),
|
||||
row=1, col=1
|
||||
)
|
||||
|
||||
# Add lower bound (red)
|
||||
fig.add_trace(
|
||||
go.Scatter(
|
||||
x=dates,
|
||||
y=percentiles[0],
|
||||
name='Lower 95% Confidence',
|
||||
line=dict(color='red', width=2)
|
||||
),
|
||||
row=1, col=1
|
||||
)
|
||||
|
||||
# Add median path
|
||||
fig.add_trace(
|
||||
go.Scatter(
|
||||
x=dates,
|
||||
y=percentiles[2],
|
||||
name='Median Path',
|
||||
line=dict(color='white', width=2)
|
||||
),
|
||||
row=1, col=1
|
||||
)
|
||||
|
||||
# Add starting price line
|
||||
fig.add_trace(
|
||||
go.Scatter(
|
||||
x=dates,
|
||||
y=[last_price] * len(dates),
|
||||
name='Current Price',
|
||||
line=dict(color='green', dash='dash', width=2)
|
||||
),
|
||||
row=1, col=1
|
||||
)
|
||||
|
||||
# Add histogram of final prices
|
||||
fig.add_trace(
|
||||
go.Histogram(
|
||||
x=paths[-1],
|
||||
name='Final Price Distribution',
|
||||
nbinsx=50,
|
||||
marker_color='orange'
|
||||
),
|
||||
row=2, col=1
|
||||
)
|
||||
|
||||
# Update layout
|
||||
fig.update_layout(
|
||||
title=f'Monte Carlo Simulation - {ticker}',
|
||||
showlegend=True,
|
||||
height=800,
|
||||
template='plotly_dark', # Dark theme for better visibility
|
||||
plot_bgcolor='rgba(0,0,0,0)', # Transparent background
|
||||
paper_bgcolor='rgba(0,0,0,0)'
|
||||
)
|
||||
|
||||
# Update axes
|
||||
fig.update_xaxes(showgrid=True, gridwidth=1, gridcolor='rgba(128,128,128,0.2)')
|
||||
fig.update_yaxes(showgrid=True, gridwidth=1, gridcolor='rgba(128,128,128,0.2)')
|
||||
|
||||
return fig
|
||||
|
||||
def monte_carlo_page():
|
||||
st.title("Monte Carlo Price Simulation")
|
||||
|
||||
# Input parameters
|
||||
col1, col2 = st.columns(2)
|
||||
|
||||
with col1:
|
||||
ticker = st.text_input("Enter Ticker Symbol", value="AAPL").upper()
|
||||
start_date = st.date_input(
|
||||
"Start Date (for historical data)",
|
||||
value=datetime.now() - timedelta(days=365)
|
||||
)
|
||||
end_date = st.date_input("End Date", value=datetime.now())
|
||||
|
||||
with col2:
|
||||
num_simulations = st.number_input(
|
||||
"Number of Simulations",
|
||||
min_value=100,
|
||||
max_value=10000,
|
||||
value=1000,
|
||||
step=100
|
||||
)
|
||||
time_horizon = st.number_input(
|
||||
"Time Horizon (Days)",
|
||||
min_value=5,
|
||||
max_value=365,
|
||||
value=30
|
||||
)
|
||||
confidence_level = st.slider(
|
||||
"Confidence Level (%)",
|
||||
min_value=80,
|
||||
max_value=99,
|
||||
value=95
|
||||
)
|
||||
|
||||
if st.button("Run Simulation"):
|
||||
with st.spinner('Running Monte Carlo simulation...'):
|
||||
try:
|
||||
# Get minute-level historical data
|
||||
df = get_stock_data(
|
||||
ticker,
|
||||
datetime.combine(start_date, datetime.min.time()),
|
||||
datetime.combine(end_date, datetime.min.time()),
|
||||
'1m' # Get minute data
|
||||
)
|
||||
|
||||
if df.empty:
|
||||
st.error("No data available for the selected period")
|
||||
return
|
||||
|
||||
# Initialize simulator
|
||||
simulator = MonteCarloSimulator(df, num_simulations, time_horizon)
|
||||
|
||||
# Run simulation
|
||||
paths = simulator.run_simulation()
|
||||
|
||||
# Calculate metrics
|
||||
metrics = simulator.calculate_metrics(paths)
|
||||
|
||||
# Generate future dates for plotting
|
||||
future_dates = pd.date_range(
|
||||
start=end_date,
|
||||
periods=time_horizon,
|
||||
freq='B' # Business days
|
||||
)
|
||||
|
||||
# Display results
|
||||
col1, col2 = st.columns([3, 1])
|
||||
|
||||
with col1:
|
||||
# Plot results
|
||||
fig = create_simulation_plot(
|
||||
paths, future_dates, ticker, simulator.last_price
|
||||
)
|
||||
st.plotly_chart(fig, use_container_width=True)
|
||||
|
||||
with col2:
|
||||
st.subheader("Simulation Metrics")
|
||||
|
||||
# Price metrics
|
||||
st.write("##### Price Projections")
|
||||
st.metric("Expected Price", f"${metrics['Expected Price']:.2f}")
|
||||
st.metric("95% CI", f"${metrics['95% CI Lower']:.2f} - ${metrics['95% CI Upper']:.2f}")
|
||||
|
||||
# Risk metrics
|
||||
st.write("##### Risk Metrics")
|
||||
st.metric("Expected Return", f"{metrics['Expected Return']:.1f}%")
|
||||
st.metric("Value at Risk (95%)", f"{abs(metrics['VaR (95%)']):.1f}%")
|
||||
st.metric("Conditional VaR", f"{abs(metrics['CVaR (95%)']):.1f}%")
|
||||
|
||||
# Distribution metrics
|
||||
st.write("##### Distribution Metrics")
|
||||
st.metric("Standard Deviation", f"${metrics['Std Dev']:.2f}")
|
||||
st.metric("Skewness", f"{metrics['Skewness']:.2f}")
|
||||
st.metric("Kurtosis", f"{metrics['Kurtosis']:.2f}")
|
||||
|
||||
# Add download button for simulation results
|
||||
final_prices_df = pd.DataFrame({
|
||||
'Simulation': range(1, num_simulations + 1),
|
||||
'Final_Price': paths[-1],
|
||||
'Return': (paths[-1] - simulator.last_price) / simulator.last_price * 100
|
||||
})
|
||||
|
||||
csv = final_prices_df.to_csv(index=False)
|
||||
st.download_button(
|
||||
label="Download Simulation Results",
|
||||
data=csv,
|
||||
file_name=f'monte_carlo_{ticker}_{datetime.now().strftime("%Y%m%d")}.csv',
|
||||
mime='text/csv'
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
st.error(f"Error during simulation: {str(e)}")
|
||||
|
||||
if __name__ == "__main__":
|
||||
monte_carlo_page()
|
||||
@ -1 +0,0 @@
|
||||
# Initialize backtesting package
|
||||
@ -1,891 +0,0 @@
|
||||
import streamlit as st
|
||||
import pandas_ta as ta
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
from utils.common_utils import get_qualified_stocks
|
||||
from backtesting import Backtest, Strategy
|
||||
from typing import Dict, List, Union
|
||||
import itertools
|
||||
from datetime import datetime, timedelta
|
||||
from utils.common_utils import get_stock_data
|
||||
|
||||
class DynamicStrategy(Strategy):
|
||||
"""Dynamic strategy class that can be configured through the UI"""
|
||||
|
||||
def init(self):
|
||||
# Will be populated with indicator calculations
|
||||
self.indicators = {}
|
||||
|
||||
# Initialize all selected indicators
|
||||
for ind_name, ind_config in self.indicator_configs.items():
|
||||
if ind_config['type'] == 'SMA':
|
||||
def sma_calc(x):
|
||||
series = pd.Series(x)
|
||||
result = ta.sma(series, length=int(ind_config['params']['length']))
|
||||
return result.fillna(method='ffill').fillna(0).values
|
||||
self.indicators[ind_name] = self.I(sma_calc, self.data.Close)
|
||||
|
||||
elif ind_config['type'] == 'EMA':
|
||||
def ema_calc(x):
|
||||
series = pd.Series(x)
|
||||
result = ta.ema(series, length=int(ind_config['params']['length']))
|
||||
return result.fillna(method='ffill').fillna(0).values
|
||||
self.indicators[ind_name] = self.I(ema_calc, self.data.Close)
|
||||
|
||||
elif ind_config['type'] == 'RSI':
|
||||
def rsi_calc(x):
|
||||
series = pd.Series(x)
|
||||
result = ta.rsi(series, length=int(ind_config['params']['length']))
|
||||
return result.fillna(method='ffill').fillna(0).values
|
||||
self.indicators[ind_name] = self.I(rsi_calc, self.data.Close)
|
||||
|
||||
elif ind_config['type'] == 'MACD':
|
||||
fast = int(ind_config['params']['fast'])
|
||||
slow = int(ind_config['params']['slow'])
|
||||
signal = int(ind_config['params']['signal'])
|
||||
|
||||
def macd_calc(x):
|
||||
series = pd.Series(x)
|
||||
result = ta.macd(series, fast=fast, slow=slow, signal=signal)
|
||||
if result is None or result.empty:
|
||||
return np.zeros(len(x)), np.zeros(len(x))
|
||||
|
||||
macd_col = result.columns[0]
|
||||
signal_col = result.columns[2]
|
||||
|
||||
macd_vals = result[macd_col].fillna(method='ffill').fillna(0).values
|
||||
signal_vals = result[signal_col].fillna(method='ffill').fillna(0).values
|
||||
|
||||
return macd_vals, signal_vals
|
||||
|
||||
macd_vals = self.I(macd_calc, self.data.Close)
|
||||
self.indicators[f"{ind_name}_macd"] = macd_vals[0]
|
||||
self.indicators[f"{ind_name}_signal"] = macd_vals[1]
|
||||
|
||||
elif ind_config['type'] == 'BB':
|
||||
length = int(ind_config['params']['length'])
|
||||
std = float(ind_config['params']['std'])
|
||||
|
||||
def bb_calc(x):
|
||||
# Ensure input is a pandas Series with proper index
|
||||
series = pd.Series(x)
|
||||
|
||||
# Calculate BB using pandas-ta
|
||||
bb_result = ta.bbands(series, length=length, std=std)
|
||||
|
||||
if bb_result is None or bb_result.empty:
|
||||
return np.zeros(len(x)), np.zeros(len(x)), np.zeros(len(x))
|
||||
|
||||
# Get the column names from bb_result
|
||||
upper_col = bb_result.columns[2] # BBU
|
||||
middle_col = bb_result.columns[1] # BBM
|
||||
lower_col = bb_result.columns[0] # BBL
|
||||
|
||||
# Extract and process the values using ffill() instead of fillna(method='ffill')
|
||||
upper = bb_result[upper_col].ffill().fillna(0).values
|
||||
middle = bb_result[middle_col].ffill().fillna(0).values
|
||||
lower = bb_result[lower_col].ffill().fillna(0).values
|
||||
|
||||
# Add debug print to verify values
|
||||
if len(upper) > 0:
|
||||
print(f"Debug BB values - Upper: {upper[-1]:.2f}, Middle: {middle[-1]:.2f}, Lower: {lower[-1]:.2f}")
|
||||
|
||||
return upper, middle, lower
|
||||
|
||||
bb_vals = self.I(bb_calc, self.data.Close)
|
||||
self.indicators[f"{ind_name}_upper"] = bb_vals[0]
|
||||
self.indicators[f"{ind_name}_middle"] = bb_vals[1]
|
||||
self.indicators[f"{ind_name}_lower"] = bb_vals[2]
|
||||
|
||||
elif ind_config['type'] == 'WMA':
|
||||
def wma_calc(x):
|
||||
series = pd.Series(x)
|
||||
result = ta.wma(series, length=int(ind_config['params']['length']))
|
||||
return result.fillna(method='ffill').fillna(0).values
|
||||
self.indicators[ind_name] = self.I(wma_calc, self.data.Close)
|
||||
|
||||
elif ind_config['type'] == 'DEMA':
|
||||
def dema_calc(x):
|
||||
series = pd.Series(x)
|
||||
result = ta.dema(series, length=int(ind_config['params']['length']))
|
||||
return result.fillna(method='ffill').fillna(0).values
|
||||
self.indicators[ind_name] = self.I(dema_calc, self.data.Close)
|
||||
|
||||
elif ind_config['type'] == 'TEMA':
|
||||
def tema_calc(x):
|
||||
series = pd.Series(x)
|
||||
result = ta.tema(series, length=int(ind_config['params']['length']))
|
||||
return result.fillna(method='ffill').fillna(0).values
|
||||
self.indicators[ind_name] = self.I(tema_calc, self.data.Close)
|
||||
|
||||
elif ind_config['type'] == 'HMA':
|
||||
def hma_calc(x):
|
||||
series = pd.Series(x)
|
||||
result = ta.hma(series, length=int(ind_config['params']['length']))
|
||||
return result.fillna(method='ffill').fillna(0).values
|
||||
self.indicators[ind_name] = self.I(hma_calc, self.data.Close)
|
||||
|
||||
elif ind_config['type'] == 'VWAP':
|
||||
def vwap_calc(high, low, close, volume):
|
||||
result = ta.vwap(high, low, close, volume, length=int(ind_config['params']['length']))
|
||||
return result.fillna(method='ffill').fillna(0).values
|
||||
self.indicators[ind_name] = self.I(vwap_calc, self.data.High, self.data.Low, self.data.Close, self.data.Volume)
|
||||
|
||||
elif ind_config['type'] == 'Stochastic':
|
||||
def stoch_calc(high, low, close):
|
||||
result = ta.stoch(high, low, close,
|
||||
k=int(ind_config['params']['k']),
|
||||
d=int(ind_config['params']['d']),
|
||||
smooth_k=int(ind_config['params']['smooth_k']))
|
||||
return (result['STOCHk_14_3_3'].fillna(method='ffill').fillna(0).values,
|
||||
result['STOCHd_14_3_3'].fillna(method='ffill').fillna(0).values)
|
||||
stoch_vals = self.I(stoch_calc, self.data.High, self.data.Low, self.data.Close)
|
||||
self.indicators[f"{ind_name}_k"] = stoch_vals[0]
|
||||
self.indicators[f"{ind_name}_d"] = stoch_vals[1]
|
||||
|
||||
elif ind_config['type'] == 'ADX':
|
||||
def adx_calc(high, low, close):
|
||||
result = ta.adx(high=pd.Series(high),
|
||||
low=pd.Series(low),
|
||||
close=pd.Series(close),
|
||||
length=int(ind_config['params']['length']),
|
||||
lensig=14) # Adding lensig parameter
|
||||
if isinstance(result, pd.DataFrame):
|
||||
# ADX is typically the third column in the result
|
||||
return result.iloc[:, 2].fillna(method='ffill').fillna(0).values
|
||||
else:
|
||||
# If result is a Series, return it directly
|
||||
return result.fillna(method='ffill').fillna(0).values
|
||||
self.indicators[ind_name] = self.I(adx_calc, self.data.High, self.data.Low, self.data.Close)
|
||||
|
||||
elif ind_config['type'] == 'CCI':
|
||||
def cci_calc(high, low, close):
|
||||
result = ta.cci(high, low, close, length=int(ind_config['params']['length']))
|
||||
return result.fillna(method='ffill').fillna(0).values
|
||||
self.indicators[ind_name] = self.I(cci_calc, self.data.High, self.data.Low, self.data.Close)
|
||||
|
||||
elif ind_config['type'] == 'MFI':
|
||||
def mfi_calc(high, low, close, volume):
|
||||
result = ta.mfi(high, low, close, volume, length=int(ind_config['params']['length']))
|
||||
return result.fillna(method='ffill').fillna(0).values
|
||||
self.indicators[ind_name] = self.I(mfi_calc, self.data.High, self.data.Low, self.data.Close, self.data.Volume)
|
||||
|
||||
elif ind_config['type'] == 'Williams%R':
|
||||
def willr_calc(high, low, close):
|
||||
result = ta.willr(high, low, close, length=int(ind_config['params']['length']))
|
||||
return result.fillna(method='ffill').fillna(0).values
|
||||
self.indicators[ind_name] = self.I(willr_calc, self.data.High, self.data.Low, self.data.Close)
|
||||
|
||||
def next(self):
|
||||
price = self.data.Close[-1]
|
||||
|
||||
# Example trading logic using indicators
|
||||
for ind_name, ind_config in self.indicator_configs.items():
|
||||
if ind_config['type'] == 'SMA':
|
||||
current_sma = self.indicators[ind_name][-1]
|
||||
print(f"SMA - Price: {price:.2f}, SMA: {current_sma:.2f}")
|
||||
if price > current_sma and not self.position:
|
||||
stop_loss = price * 0.93
|
||||
print(f"BUY signal - Price above SMA with stop loss at {stop_loss:.2f}")
|
||||
self.buy(sl=stop_loss)
|
||||
elif price < current_sma and self.position:
|
||||
print(f"SELL signal - Price below SMA")
|
||||
self.position.close()
|
||||
|
||||
elif ind_config['type'] == 'MACD':
|
||||
macd = self.indicators[f"{ind_name}_macd"][-1]
|
||||
signal = self.indicators[f"{ind_name}_signal"][-1]
|
||||
prev_macd = self.indicators[f"{ind_name}_macd"][-2]
|
||||
prev_signal = self.indicators[f"{ind_name}_signal"][-2]
|
||||
|
||||
print(f"MACD - MACD: {macd:.2f}, Signal: {signal:.2f}")
|
||||
if macd > signal and prev_macd <= prev_signal and not self.position:
|
||||
stop_loss = price * 0.93
|
||||
print(f"BUY signal - MACD crossover with stop loss at {stop_loss:.2f}")
|
||||
self.buy(sl=stop_loss)
|
||||
elif macd < signal and prev_macd >= prev_signal and self.position:
|
||||
print(f"SELL signal - MACD crossunder")
|
||||
self.position.close()
|
||||
|
||||
elif ind_config['type'] == 'BB':
|
||||
upper = self.indicators[f"{ind_name}_upper"][-1]
|
||||
lower = self.indicators[f"{ind_name}_lower"][-1]
|
||||
|
||||
print(f"BB - Price: {price:.2f}, Upper: {upper:.2f}, Lower: {lower:.2f}")
|
||||
if price <= lower and not self.position:
|
||||
stop_loss = price * 0.93
|
||||
print(f"BUY signal - Price at lower band with stop loss at {stop_loss:.2f}")
|
||||
self.buy(sl=stop_loss)
|
||||
elif price >= upper and self.position:
|
||||
print(f"SELL signal - Price at upper band")
|
||||
self.position.close()
|
||||
|
||||
elif ind_config['type'] == 'RSI':
|
||||
rsi = self.indicators[ind_name][-1]
|
||||
print(f"RSI - Value: {rsi:.2f}")
|
||||
|
||||
if rsi < 30 and not self.position:
|
||||
stop_loss = price * 0.93
|
||||
print(f"BUY signal - RSI oversold with stop loss at {stop_loss:.2f}")
|
||||
self.buy(sl=stop_loss)
|
||||
elif rsi > 70 and self.position:
|
||||
print(f"SELL signal - RSI overbought")
|
||||
self.position.close()
|
||||
|
||||
elif ind_config['type'] == 'EMA':
|
||||
current_ema = self.indicators[ind_name][-1]
|
||||
print(f"EMA - Price: {price:.2f}, EMA: {current_ema:.2f}")
|
||||
if price > current_ema and not self.position:
|
||||
stop_loss = price * 0.93
|
||||
print(f"BUY signal - Price above EMA with stop loss at {stop_loss:.2f}")
|
||||
self.buy(sl=stop_loss)
|
||||
elif price < current_ema and self.position:
|
||||
print(f"SELL signal - Price below EMA")
|
||||
self.position.close()
|
||||
|
||||
elif ind_config['type'] == 'WMA':
|
||||
current_wma = self.indicators[ind_name][-1]
|
||||
print(f"WMA - Price: {price:.2f}, WMA: {current_wma:.2f}")
|
||||
if price > current_wma and not self.position:
|
||||
stop_loss = price * 0.93
|
||||
print(f"BUY signal - Price above WMA with stop loss at {stop_loss:.2f}")
|
||||
self.buy(sl=stop_loss)
|
||||
elif price < current_wma and self.position:
|
||||
print(f"SELL signal - Price below WMA")
|
||||
self.position.close()
|
||||
|
||||
elif ind_config['type'] == 'DEMA':
|
||||
current_dema = self.indicators[ind_name][-1]
|
||||
print(f"DEMA - Price: {price:.2f}, DEMA: {current_dema:.2f}")
|
||||
if price > current_dema and not self.position:
|
||||
stop_loss = price * 0.93
|
||||
print(f"BUY signal - Price above DEMA with stop loss at {stop_loss:.2f}")
|
||||
self.buy(sl=stop_loss)
|
||||
elif price < current_dema and self.position:
|
||||
print(f"SELL signal - Price below DEMA")
|
||||
self.position.close()
|
||||
|
||||
elif ind_config['type'] == 'TEMA':
|
||||
current_tema = self.indicators[ind_name][-1]
|
||||
print(f"TEMA - Price: {price:.2f}, TEMA: {current_tema:.2f}")
|
||||
if price > current_tema and not self.position:
|
||||
stop_loss = price * 0.93
|
||||
print(f"BUY signal - Price above TEMA with stop loss at {stop_loss:.2f}")
|
||||
self.buy(sl=stop_loss)
|
||||
elif price < current_tema and self.position:
|
||||
print(f"SELL signal - Price below TEMA")
|
||||
self.position.close()
|
||||
|
||||
elif ind_config['type'] == 'HMA':
|
||||
current_hma = self.indicators[ind_name][-1]
|
||||
print(f"HMA - Price: {price:.2f}, HMA: {current_hma:.2f}")
|
||||
if price > current_hma and not self.position:
|
||||
stop_loss = price * 0.93
|
||||
print(f"BUY signal - Price above HMA with stop loss at {stop_loss:.2f}")
|
||||
self.buy(sl=stop_loss)
|
||||
elif price < current_hma and self.position:
|
||||
print(f"SELL signal - Price below HMA")
|
||||
self.position.close()
|
||||
|
||||
elif ind_config['type'] == 'VWAP':
|
||||
current_vwap = self.indicators[ind_name][-1]
|
||||
print(f"VWAP - Price: {price:.2f}, VWAP: {current_vwap:.2f}")
|
||||
if price > current_vwap and not self.position:
|
||||
stop_loss = price * 0.93
|
||||
print(f"BUY signal - Price above VWAP with stop loss at {stop_loss:.2f}")
|
||||
self.buy(sl=stop_loss)
|
||||
elif price < current_vwap and self.position:
|
||||
print(f"SELL signal - Price below VWAP")
|
||||
self.position.close()
|
||||
|
||||
elif ind_config['type'] == 'Stochastic':
|
||||
k = self.indicators[f"{ind_name}_k"][-1]
|
||||
d = self.indicators[f"{ind_name}_d"][-1]
|
||||
prev_k = self.indicators[f"{ind_name}_k"][-2]
|
||||
prev_d = self.indicators[f"{ind_name}_d"][-2]
|
||||
|
||||
print(f"Stochastic - K: {k:.2f}, D: {d:.2f}")
|
||||
if k > d and prev_k <= prev_d and k < 20 and not self.position:
|
||||
stop_loss = price * 0.93
|
||||
print(f"BUY signal - Stochastic crossover in oversold territory with stop loss at {stop_loss:.2f}")
|
||||
self.buy(sl=stop_loss)
|
||||
elif k < d and prev_k >= prev_d and k > 80 and self.position:
|
||||
print(f"SELL signal - Stochastic crossunder in overbought territory")
|
||||
self.position.close()
|
||||
|
||||
elif ind_config['type'] == 'ADX':
|
||||
adx = self.indicators[ind_name][-1]
|
||||
print(f"ADX - Value: {adx:.2f}")
|
||||
if adx > 25 and not self.position:
|
||||
stop_loss = price * 0.93
|
||||
print(f"BUY signal - Strong trend detected with stop loss at {stop_loss:.2f}")
|
||||
self.buy(sl=stop_loss)
|
||||
elif adx < 20 and self.position:
|
||||
print(f"SELL signal - Weak trend")
|
||||
self.position.close()
|
||||
|
||||
elif ind_config['type'] == 'CCI':
|
||||
cci = self.indicators[ind_name][-1]
|
||||
print(f"CCI - Value: {cci:.2f}")
|
||||
if cci < -100 and not self.position:
|
||||
stop_loss = price * 0.93
|
||||
print(f"BUY signal - CCI oversold with stop loss at {stop_loss:.2f}")
|
||||
self.buy(sl=stop_loss)
|
||||
elif cci > 100 and self.position:
|
||||
print(f"SELL signal - CCI overbought")
|
||||
self.position.close()
|
||||
|
||||
elif ind_config['type'] == 'MFI':
|
||||
mfi = self.indicators[ind_name][-1]
|
||||
print(f"MFI - Value: {mfi:.2f}")
|
||||
if mfi < 20 and not self.position:
|
||||
stop_loss = price * 0.93
|
||||
print(f"BUY signal - MFI oversold with stop loss at {stop_loss:.2f}")
|
||||
self.buy(sl=stop_loss)
|
||||
elif mfi > 80 and self.position:
|
||||
print(f"SELL signal - MFI overbought")
|
||||
self.position.close()
|
||||
|
||||
elif ind_config['type'] == 'Williams%R':
|
||||
willr = self.indicators[ind_name][-1]
|
||||
print(f"Williams%R - Value: {willr:.2f}")
|
||||
if willr < -80 and not self.position:
|
||||
stop_loss = price * 0.93
|
||||
print(f"BUY signal - Williams%R oversold with stop loss at {stop_loss:.2f}")
|
||||
self.buy(sl=stop_loss)
|
||||
elif willr > -20 and self.position:
|
||||
print(f"SELL signal - Williams%R overbought")
|
||||
self.position.close()
|
||||
|
||||
def get_available_indicators() -> Dict:
|
||||
"""Returns available indicators and their parameters"""
|
||||
return {
|
||||
'SMA': {
|
||||
'params': ['length'],
|
||||
'defaults': {'length': 20},
|
||||
'ranges': {'length': (5, 200)}
|
||||
},
|
||||
'EMA': {
|
||||
'params': ['length'],
|
||||
'defaults': {'length': 20},
|
||||
'ranges': {'length': (5, 200)}
|
||||
},
|
||||
'RSI': {
|
||||
'params': ['length'],
|
||||
'defaults': {'length': 14},
|
||||
'ranges': {'length': (5, 30)}
|
||||
},
|
||||
'MACD': {
|
||||
'params': ['fast', 'slow', 'signal'],
|
||||
'defaults': {'fast': 12, 'slow': 26, 'signal': 9},
|
||||
'ranges': {'fast': (8, 20), 'slow': (20, 40), 'signal': (5, 15)}
|
||||
},
|
||||
'BB': {
|
||||
'params': ['length', 'std'],
|
||||
'defaults': {'length': 20, 'std': 2.0},
|
||||
'ranges': {'length': (10, 50), 'std': (1.5, 3.0)}
|
||||
},
|
||||
'WMA': {
|
||||
'params': ['length'],
|
||||
'defaults': {'length': 20},
|
||||
'ranges': {'length': (5, 200)}
|
||||
},
|
||||
'DEMA': {
|
||||
'params': ['length'],
|
||||
'defaults': {'length': 20},
|
||||
'ranges': {'length': (5, 200)}
|
||||
},
|
||||
'TEMA': {
|
||||
'params': ['length'],
|
||||
'defaults': {'length': 20},
|
||||
'ranges': {'length': (5, 200)}
|
||||
},
|
||||
'HMA': {
|
||||
'params': ['length'],
|
||||
'defaults': {'length': 20},
|
||||
'ranges': {'length': (5, 200)}
|
||||
},
|
||||
'VWAP': {
|
||||
'params': ['length'],
|
||||
'defaults': {'length': 14},
|
||||
'ranges': {'length': (1, 30)}
|
||||
},
|
||||
'Stochastic': {
|
||||
'params': ['k', 'd', 'smooth_k'],
|
||||
'defaults': {'k': 14, 'd': 3, 'smooth_k': 3},
|
||||
'ranges': {'k': (5, 30), 'd': (2, 10), 'smooth_k': (2, 10)}
|
||||
},
|
||||
'ADX': {
|
||||
'params': ['length'],
|
||||
'defaults': {'length': 14},
|
||||
'ranges': {'length': (5, 30)}
|
||||
},
|
||||
'CCI': {
|
||||
'params': ['length'],
|
||||
'defaults': {'length': 20},
|
||||
'ranges': {'length': (10, 50)}
|
||||
},
|
||||
'MFI': {
|
||||
'params': ['length'],
|
||||
'defaults': {'length': 14},
|
||||
'ranges': {'length': (5, 30)}
|
||||
},
|
||||
'Williams%R': {
|
||||
'params': ['length'],
|
||||
'defaults': {'length': 14},
|
||||
'ranges': {'length': (5, 30)}
|
||||
}
|
||||
}
|
||||
|
||||
def prepare_data_for_backtest(df: pd.DataFrame) -> pd.DataFrame:
|
||||
"""Prepare the dataframe for backtesting"""
|
||||
print("\nPreparing data for backtest...")
|
||||
print(f"Initial dataframe shape: {df.shape}")
|
||||
print(f"Initial columns: {df.columns.tolist()}")
|
||||
|
||||
# Ensure the dataframe has the required columns
|
||||
required_columns = ['Open', 'High', 'Low', 'Close', 'Volume']
|
||||
|
||||
# Rename columns if needed
|
||||
df = df.copy()
|
||||
df.columns = [c.capitalize() for c in df.columns]
|
||||
print(f"Columns after capitalization: {df.columns.tolist()}")
|
||||
|
||||
# Set the date as index if it's not already
|
||||
if 'Date' in df.columns:
|
||||
df.set_index('Date', inplace=True)
|
||||
print("Set Date as index")
|
||||
|
||||
# Verify all required columns exist
|
||||
missing_cols = [col for col in required_columns if col not in df.columns]
|
||||
if missing_cols:
|
||||
print(f"WARNING: Missing columns: {missing_cols}")
|
||||
raise ValueError(f"Missing required columns: {missing_cols}")
|
||||
|
||||
print(f"Final dataframe shape: {df.shape}")
|
||||
|
||||
return df
|
||||
|
||||
def backtesting_page():
|
||||
st.title("Strategy Backtesting")
|
||||
|
||||
# Create two columns for the main layout
|
||||
left_col, right_col = st.columns([2, 3])
|
||||
|
||||
with left_col:
|
||||
st.subheader("Backtest Settings")
|
||||
|
||||
# Date range selection
|
||||
col1, col2 = st.columns(2)
|
||||
with col1:
|
||||
start_date = st.date_input("Start Date",
|
||||
value=datetime.now() - timedelta(days=365))
|
||||
with col2:
|
||||
end_date = st.date_input("End Date")
|
||||
|
||||
# Convert dates to datetime objects
|
||||
start_datetime = datetime.combine(start_date, datetime.min.time())
|
||||
end_datetime = datetime.combine(end_date, datetime.min.time())
|
||||
|
||||
# Add radio button for test mode
|
||||
test_mode = st.radio("Testing Mode", ["Single Ticker", "Multiple Tickers", "All Available Tickers"])
|
||||
|
||||
if test_mode == "Single Ticker":
|
||||
# Single ticker input
|
||||
ticker = st.text_input("Enter Ticker Symbol", value="AAPL").upper()
|
||||
tickers = [ticker]
|
||||
elif test_mode == "Multiple Tickers":
|
||||
# Multiple ticker input
|
||||
ticker_input = st.text_area(
|
||||
"Enter Ticker Symbols (one per line)",
|
||||
value="AAPL\nMSFT\nGOOG"
|
||||
)
|
||||
tickers = [t.strip().upper() for t in ticker_input.split('\n') if t.strip()]
|
||||
else: # All Available Tickers
|
||||
st.subheader("Filter Settings")
|
||||
min_price = st.number_input("Minimum Price", value=5.0)
|
||||
max_price = st.number_input("Maximum Price", value=1000.0)
|
||||
min_volume = st.number_input("Minimum Volume", value=100000)
|
||||
|
||||
# Get all qualified stocks based on filters
|
||||
try:
|
||||
qualified_stocks = get_qualified_stocks(
|
||||
start_date=start_datetime,
|
||||
end_date=end_datetime,
|
||||
min_price=min_price,
|
||||
max_price=max_price,
|
||||
min_volume=min_volume
|
||||
)
|
||||
st.info(f"Found {len(qualified_stocks)} qualified stocks for testing")
|
||||
tickers = qualified_stocks
|
||||
except Exception as e:
|
||||
st.error(f"Error getting qualified stocks: {str(e)}")
|
||||
tickers = []
|
||||
|
||||
# Add performance filters for multiple and all tickers modes
|
||||
if test_mode in ["Multiple Tickers", "All Available Tickers"]:
|
||||
st.subheader("Performance Filters")
|
||||
col1, col2, col3 = st.columns(3)
|
||||
with col1:
|
||||
min_return = st.number_input("Minimum Return (%)", value=10.0)
|
||||
with col2:
|
||||
min_sharpe = st.number_input("Minimum Sharpe Ratio", value=1.0)
|
||||
with col3:
|
||||
max_drawdown = st.number_input("Maximum Drawdown (%)", value=-20.0)
|
||||
|
||||
# Add batch size control for processing
|
||||
batch_size = st.number_input("Batch Size (tickers per batch)", value=50, min_value=1)
|
||||
|
||||
# Add progress tracking
|
||||
progress_bar = st.progress(0)
|
||||
status_text = st.empty()
|
||||
|
||||
# Indicator selection
|
||||
available_indicators = get_available_indicators()
|
||||
selected_indicators = st.multiselect(
|
||||
"Select Technical Indicators",
|
||||
options=list(available_indicators.keys())
|
||||
)
|
||||
|
||||
# Parameter input/optimization section
|
||||
optimize = st.checkbox("Optimize Parameters")
|
||||
|
||||
indicator_settings = {}
|
||||
for ind_name in selected_indicators:
|
||||
st.subheader(f"{ind_name} Settings")
|
||||
ind_config = available_indicators[ind_name]
|
||||
|
||||
if optimize:
|
||||
params = {}
|
||||
for param in ind_config['params']:
|
||||
param_range = ind_config['ranges'][param]
|
||||
col1, col2, col3 = st.columns(3)
|
||||
with col1:
|
||||
min_val = st.number_input(f"{param} Min",
|
||||
value=float(param_range[0]))
|
||||
with col2:
|
||||
max_val = st.number_input(f"{param} Max",
|
||||
value=float(param_range[1]))
|
||||
with col3:
|
||||
step = st.number_input(f"{param} Step",
|
||||
value=1.0 if param == 'std' else 1)
|
||||
params[param] = {'min': min_val, 'max': max_val, 'step': step}
|
||||
else:
|
||||
params = {}
|
||||
for param in ind_config['params']:
|
||||
default_val = ind_config['defaults'][param]
|
||||
params[param] = st.number_input(f"{param}",
|
||||
value=float(default_val))
|
||||
|
||||
indicator_settings[ind_name] = {
|
||||
'type': ind_name,
|
||||
'params': params
|
||||
}
|
||||
|
||||
if st.button("Run Backtest"):
|
||||
with st.spinner('Running backtest...'):
|
||||
start_datetime = datetime.combine(start_date, datetime.min.time())
|
||||
end_datetime = datetime.combine(end_date, datetime.min.time())
|
||||
|
||||
if test_mode == "Single Ticker":
|
||||
# Single ticker logic
|
||||
df = get_stock_data(ticker, start_datetime, end_datetime, 'daily')
|
||||
if df.empty:
|
||||
st.error("No data available for the selected period")
|
||||
return
|
||||
|
||||
try:
|
||||
df = prepare_data_for_backtest(df)
|
||||
if optimize:
|
||||
results = run_optimization(df, indicator_settings)
|
||||
with right_col:
|
||||
display_optimization_results(results)
|
||||
else:
|
||||
results = run_single_backtest(df, indicator_settings)
|
||||
with right_col:
|
||||
display_backtest_results(results)
|
||||
except Exception as e:
|
||||
st.error(f"Error during backtest: {str(e)}")
|
||||
|
||||
else:
|
||||
# Multiple ticker or All Available Tickers logic
|
||||
try:
|
||||
total_tickers = len(tickers)
|
||||
processed_tickers = []
|
||||
all_results = []
|
||||
|
||||
for i in range(0, total_tickers, batch_size):
|
||||
batch = tickers[i:i+batch_size]
|
||||
status_text.text(f"Processing batch {i//batch_size + 1} of {(total_tickers + batch_size - 1)//batch_size}")
|
||||
|
||||
results_df = run_multi_ticker_backtest(
|
||||
batch, start_datetime, end_datetime, indicator_settings
|
||||
)
|
||||
|
||||
if not results_df.empty:
|
||||
all_results.append(results_df)
|
||||
processed_tickers.extend(batch)
|
||||
|
||||
# Update progress
|
||||
progress = min((i + batch_size) / total_tickers, 1.0)
|
||||
progress_bar.progress(progress)
|
||||
|
||||
if all_results:
|
||||
# Combine all results
|
||||
results_df = pd.concat(all_results, ignore_index=True)
|
||||
|
||||
# Apply performance filters
|
||||
filtered_df = results_df[
|
||||
(results_df['Return [%]'] >= min_return) &
|
||||
(results_df['Sharpe Ratio'] >= min_sharpe) &
|
||||
(results_df['Max Drawdown [%]'] >= max_drawdown)
|
||||
]
|
||||
|
||||
with right_col:
|
||||
st.subheader("Multi-Ticker Results")
|
||||
|
||||
# Display summary statistics
|
||||
st.write("### Summary Statistics")
|
||||
summary = pd.DataFrame({
|
||||
'Metric': [
|
||||
'Total Tickers Tested',
|
||||
'Successful Tests',
|
||||
'Average Return',
|
||||
'Average Sharpe',
|
||||
'Average Drawdown',
|
||||
'Success Rate'
|
||||
],
|
||||
'Value': [
|
||||
f"{len(processed_tickers)}",
|
||||
f"{len(results_df)}",
|
||||
f"{results_df['Return [%]'].mean():.2f}%",
|
||||
f"{results_df['Sharpe Ratio'].mean():.2f}",
|
||||
f"{results_df['Max Drawdown [%]'].mean():.2f}%",
|
||||
f"{(len(filtered_df) / len(results_df) * 100):.1f}%"
|
||||
]
|
||||
})
|
||||
st.table(summary)
|
||||
|
||||
# Display full results
|
||||
st.write("### All Results")
|
||||
st.dataframe(results_df.sort_values('Return [%]', ascending=False))
|
||||
|
||||
# Display filtered results
|
||||
st.write("### Filtered Results (Meeting Criteria)")
|
||||
st.dataframe(filtered_df.sort_values('Return [%]', ascending=False))
|
||||
|
||||
# Create a downloadable CSV
|
||||
csv = results_df.to_csv(index=False)
|
||||
st.download_button(
|
||||
"Download Results CSV",
|
||||
csv,
|
||||
"backtest_results.csv",
|
||||
"text/csv",
|
||||
key='download-csv'
|
||||
)
|
||||
else:
|
||||
st.error("No valid results were generated from any ticker")
|
||||
|
||||
except Exception as e:
|
||||
st.error(f"Error during multi-ticker backtest: {str(e)}")
|
||||
st.error("Full error details: " + str(e))
|
||||
|
||||
def run_optimization(df: pd.DataFrame, indicator_settings: Dict) -> List:
|
||||
"""Run optimization with different parameter combinations"""
|
||||
param_combinations = generate_param_combinations(indicator_settings)
|
||||
print(f"Generated {len(param_combinations)} parameter combinations")
|
||||
results = []
|
||||
|
||||
for i, params in enumerate(param_combinations):
|
||||
print(f"\nTesting combination {i+1}/{len(param_combinations)}")
|
||||
print(f"Parameters: {params}")
|
||||
|
||||
# Configure strategy with current parameters
|
||||
DynamicStrategy.indicator_configs = params
|
||||
|
||||
# Run backtest
|
||||
bt = Backtest(df, DynamicStrategy, cash=100000, commission=.002)
|
||||
stats = bt.run()
|
||||
|
||||
print(f"Results - Return: {stats['Return [%]']:.2f}%, "
|
||||
f"Sharpe: {stats['Sharpe Ratio']:.2f}, "
|
||||
f"Drawdown: {stats['Max. Drawdown [%]']:.2f}%")
|
||||
|
||||
results.append({
|
||||
'parameters': params,
|
||||
'Return [%]': stats['Return [%]'],
|
||||
'Sharpe Ratio': stats['Sharpe Ratio'],
|
||||
'Max Drawdown [%]': stats['Max. Drawdown [%]'], # Updated key
|
||||
'Win Rate [%]': stats['Win Rate [%]']
|
||||
})
|
||||
|
||||
return results
|
||||
|
||||
def run_single_backtest(df: pd.DataFrame, indicator_settings: Dict) -> Dict:
|
||||
"""Run a single backtest with fixed parameters"""
|
||||
DynamicStrategy.indicator_configs = indicator_settings
|
||||
bt = Backtest(df, DynamicStrategy, cash=100000, commission=.002)
|
||||
return bt.run()
|
||||
|
||||
def generate_param_combinations(settings: Dict) -> List[Dict]:
|
||||
"""Generate all possible parameter combinations for optimization"""
|
||||
param_space = {}
|
||||
for ind_name, ind_config in settings.items():
|
||||
for param_name, param_range in ind_config['params'].items():
|
||||
if isinstance(param_range, dict): # Optimization mode
|
||||
values = np.arange(
|
||||
param_range['min'],
|
||||
param_range['max'] + param_range['step'],
|
||||
param_range['step']
|
||||
)
|
||||
param_space[f"{ind_name}_{param_name}"] = values
|
||||
|
||||
# Generate all combinations
|
||||
keys = list(param_space.keys())
|
||||
values = list(param_space.values())
|
||||
combinations = list(itertools.product(*values))
|
||||
|
||||
# Convert to list of parameter dictionaries
|
||||
result = []
|
||||
for combo in combinations:
|
||||
params = {}
|
||||
for key, value in zip(keys, combo):
|
||||
ind_name, param_name = key.rsplit('_', 1)
|
||||
if ind_name not in params:
|
||||
params[ind_name] = {'type': ind_name, 'params': {}}
|
||||
params[ind_name]['params'][param_name] = value
|
||||
result.append(params)
|
||||
|
||||
return result
|
||||
|
||||
def display_optimization_results(results: List):
|
||||
"""Display optimization results in a formatted table"""
|
||||
df_results = pd.DataFrame(results)
|
||||
|
||||
# Format parameters column for better display
|
||||
df_results['parameters'] = df_results['parameters'].apply(str)
|
||||
|
||||
st.subheader("Optimization Results")
|
||||
st.dataframe(df_results.sort_values('Return [%]', ascending=False))
|
||||
|
||||
def run_multi_ticker_backtest(tickers: list, start_date: datetime, end_date: datetime, indicator_settings: Dict) -> pd.DataFrame:
|
||||
"""Run backtest across multiple tickers and aggregate results"""
|
||||
all_results = []
|
||||
|
||||
if not indicator_settings:
|
||||
print("Error: No indicators selected")
|
||||
raise ValueError("Please select at least one indicator")
|
||||
|
||||
# Convert indicator settings to the correct format
|
||||
processed_settings = {}
|
||||
for ind_name, ind_config in indicator_settings.items():
|
||||
processed_settings[ind_name] = {
|
||||
'type': ind_config['type'],
|
||||
'params': {}
|
||||
}
|
||||
# Handle both optimization and non-optimization cases
|
||||
if isinstance(ind_config['params'], dict):
|
||||
for param_name, param_value in ind_config['params'].items():
|
||||
# Handle both direct values and range dictionaries
|
||||
if isinstance(param_value, dict):
|
||||
# Use the minimum value from the range for non-optimization run
|
||||
processed_settings[ind_name]['params'][param_name] = float(param_value['min'])
|
||||
else:
|
||||
processed_settings[ind_name]['params'][param_name] = float(param_value)
|
||||
|
||||
print(f"Processed indicator settings: {processed_settings}")
|
||||
|
||||
for ticker_data in tickers:
|
||||
try:
|
||||
# Extract ticker symbol from tuple if it's a tuple, otherwise use as is
|
||||
ticker = ticker_data[0] if isinstance(ticker_data, tuple) else ticker_data
|
||||
|
||||
print(f"\nTesting strategy on {ticker}")
|
||||
df = get_stock_data(ticker, start_date, end_date, 'daily')
|
||||
|
||||
if df is None or df.empty:
|
||||
print(f"No data available for {ticker}")
|
||||
continue
|
||||
|
||||
print(f"Data shape for {ticker}: {df.shape}")
|
||||
print(f"Date range: {df.index.min()} to {df.index.max()}")
|
||||
print(f"Columns: {df.columns.tolist()}")
|
||||
|
||||
try:
|
||||
df = prepare_data_for_backtest(df)
|
||||
except Exception as e:
|
||||
print(f"Error preparing data for {ticker}: {str(e)}")
|
||||
continue
|
||||
|
||||
print(f"Prepared data shape: {df.shape}")
|
||||
print(f"Final columns: {df.columns.tolist()}")
|
||||
|
||||
# Run backtest
|
||||
try:
|
||||
# Set the indicator configs before creating the Backtest instance
|
||||
DynamicStrategy.indicator_configs = processed_settings
|
||||
print(f"Strategy configured with settings: {processed_settings}")
|
||||
|
||||
bt = Backtest(df, DynamicStrategy, cash=100000, commission=.002)
|
||||
stats = bt.run()
|
||||
|
||||
print(f"Backtest completed for {ticker}")
|
||||
print("Available stats keys:", stats.keys())
|
||||
|
||||
# Store results with error handling for each metric
|
||||
result = {
|
||||
'Ticker': ticker,
|
||||
'Return [%]': float(stats['Return [%]']), # Use direct key access
|
||||
'Sharpe Ratio': float(stats['Sharpe Ratio']),
|
||||
'Max Drawdown [%]': float(stats['Max. Drawdown [%]']),
|
||||
'Win Rate [%]': float(stats['Win Rate [%]']),
|
||||
'Number of Trades': int(stats['# Trades'])
|
||||
}
|
||||
|
||||
# Verify results are valid
|
||||
if any(pd.isna(val) for val in result.values()):
|
||||
print(f"Warning: NaN values in results for {ticker}")
|
||||
continue
|
||||
|
||||
all_results.append(result)
|
||||
|
||||
print(f"Success - {ticker} results:")
|
||||
for key, value in result.items():
|
||||
print(f"{key}: {value}")
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error during backtest for {ticker}: {str(e)}")
|
||||
print(f"Current indicator settings: {processed_settings}")
|
||||
continue
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error processing {ticker}: {str(e)}")
|
||||
continue
|
||||
|
||||
if not all_results:
|
||||
print("\nNo valid results were generated. Debug information:")
|
||||
print(f"Number of tickers processed: {len(tickers)}")
|
||||
print(f"Indicator settings used: {processed_settings}")
|
||||
print(f"Date range: {start_date} to {end_date}")
|
||||
raise ValueError("No valid results were generated from any ticker")
|
||||
|
||||
return pd.DataFrame(all_results)
|
||||
|
||||
def display_backtest_results(results: Dict):
|
||||
"""Display single backtest results with metrics and plots"""
|
||||
st.subheader("Backtest Results")
|
||||
|
||||
# Display key metrics
|
||||
col1, col2, col3, col4 = st.columns(4)
|
||||
with col1:
|
||||
st.metric("Return", f"{results.get('Return [%]', 0):.2f}%")
|
||||
with col2:
|
||||
st.metric("Sharpe Ratio", f"{results.get('Sharpe Ratio', 0):.2f}")
|
||||
with col3:
|
||||
st.metric("Max Drawdown", f"{results.get('Max. Drawdown [%]', 0):.2f}%")
|
||||
with col4:
|
||||
st.metric("Win Rate", f"{results.get('Win Rate [%]', 0):.2f}%")
|
||||
|
||||
# Display full results in expandable section
|
||||
with st.expander("See detailed results"):
|
||||
st.write(results)
|
||||
@ -1 +0,0 @@
|
||||
# Trading journal package
|
||||
File diff suppressed because it is too large
Load Diff
@ -1,112 +0,0 @@
|
||||
import streamlit as st
|
||||
|
||||
def strategy_guide_page():
|
||||
st.header("Trading Strategy Guide")
|
||||
|
||||
# Pre-Market Section
|
||||
with st.expander("1. Pre-Market Preparation (Night Before)", expanded=True):
|
||||
st.markdown("""
|
||||
* Run Screeners:
|
||||
- SunnyBands Screener
|
||||
- Heikin Ashi Screener
|
||||
* Review Each Chart For:
|
||||
- Volume patterns
|
||||
- Trend direction
|
||||
- SunnyBands indicator alignment
|
||||
* Add strong setups to watchlist
|
||||
* For each watchlist stock:
|
||||
- Use Monte Carlo Position Calculator
|
||||
- Pre-calculate optimal position sizes
|
||||
- Note potential entry zones
|
||||
* Set price alerts for:
|
||||
- Entry zones
|
||||
- Key SunnyBands levels
|
||||
""")
|
||||
|
||||
# Morning Trading Section
|
||||
with st.expander("2. Trading Strategy (6:30-7:30 AM PST)"):
|
||||
st.markdown("""
|
||||
**Objective:** Capture early momentum plays
|
||||
|
||||
**Best Setups:**
|
||||
1. SunnyBands 15-Minute Setups
|
||||
- Clear band crossovers
|
||||
- Volume confirmation
|
||||
- Trend alignment
|
||||
2. Heikin Ashi Confirmation
|
||||
- Strong candle patterns
|
||||
- Matches SunnyBands direction
|
||||
3. Volume-Based Entry
|
||||
- Above average volume
|
||||
- Clear direction
|
||||
|
||||
**Rules:**
|
||||
* Use Monte Carlo simulator for each entry:
|
||||
- Calculate optimal stop loss
|
||||
- Determine target price
|
||||
- Set position size
|
||||
* Set firm stop loss orders
|
||||
* Place alerts at target prices
|
||||
* When target alert hits:
|
||||
- Switch to trailing stop loss
|
||||
""")
|
||||
|
||||
# Midday Trading Section
|
||||
with st.expander("2. Trading Strategy (12:00-1:00 PM PST)"):
|
||||
st.markdown("""
|
||||
**Objective:** Identify clear trends for swing trades
|
||||
|
||||
**Best Setups:**
|
||||
1. SunnyBands Continuation
|
||||
- Holding within bands
|
||||
- Strong volume support
|
||||
2. Band Reversal Entry
|
||||
- Clear band crossover
|
||||
- Volume confirmation
|
||||
3. Momentum Continuation
|
||||
- Strong Heikin Ashi signals
|
||||
- Band support/resistance tests
|
||||
|
||||
**Rules:**
|
||||
* Recheck Monte Carlo projections
|
||||
* Adjust stops based on new calculations
|
||||
* Monitor volume for trend strength
|
||||
* Keep trailing stops on winning positions
|
||||
""")
|
||||
|
||||
# Execution Rules Section
|
||||
with st.expander("4. Trade Execution Rules"):
|
||||
st.markdown("""
|
||||
* Pre-Entry:
|
||||
- Run Monte Carlo simulation
|
||||
- Calculate position size
|
||||
- Determine stop loss and target
|
||||
|
||||
* Entry Rules:
|
||||
- Enter only at planned levels
|
||||
- Use limit orders when possible
|
||||
- Confirm volume and direction
|
||||
|
||||
* Position Management:
|
||||
- Set hard stop loss immediately
|
||||
- Place target price alerts
|
||||
- Convert to trailing stop when target hits
|
||||
|
||||
* Risk Management:
|
||||
- Follow Monte Carlo stop loss levels
|
||||
- Use calculated position sizes only
|
||||
- No overriding system calculations
|
||||
""")
|
||||
|
||||
# Review Section
|
||||
with st.expander("5. Post-Trading Review"):
|
||||
st.markdown("""
|
||||
* Journal every trade with:
|
||||
- Entry/exit points
|
||||
- Monte Carlo projections vs actual
|
||||
- SunnyBands signal accuracy
|
||||
* Review screener effectiveness
|
||||
* Analyze stop loss performance
|
||||
* Check target price hit rates
|
||||
* Evaluate trailing stop effectiveness
|
||||
""")
|
||||
@ -1,116 +0,0 @@
|
||||
import streamlit as st
|
||||
import pandas as pd
|
||||
from datetime import datetime
|
||||
from screener.canslim_controller import run_canslim_screener
|
||||
from db.db_connection import create_client
|
||||
from utils.report_utils import load_scanner_reports
|
||||
|
||||
|
||||
def canslim_screener_page():
|
||||
st.header("CANSLIM Screener")
|
||||
|
||||
# Create tabs for scanner and reports
|
||||
scanner_tab, reports_tab = st.tabs(["Run Scanner", "View Reports"])
|
||||
|
||||
with scanner_tab:
|
||||
# Date range selection
|
||||
col1, col2 = st.columns(2)
|
||||
with col1:
|
||||
start_date = st.date_input("Start Date")
|
||||
with col2:
|
||||
end_date = st.date_input("End Date")
|
||||
|
||||
# CANSLIM criteria selection
|
||||
st.subheader("Select Screening Criteria")
|
||||
|
||||
c_criteria = st.expander("Current Quarterly Earnings (C)")
|
||||
with c_criteria:
|
||||
eps_threshold = st.slider("EPS Growth Threshold (%)", 0, 100, 25)
|
||||
sales_threshold = st.slider("Sales Growth Threshold (%)", 0, 100, 25)
|
||||
roe_threshold = st.slider("ROE Threshold (%)", 0, 50, 17)
|
||||
|
||||
a_criteria = st.expander("Annual Earnings Growth (A)")
|
||||
with a_criteria:
|
||||
annual_eps_threshold = st.slider("Annual EPS Growth Threshold (%)", 0, 100, 25)
|
||||
|
||||
l_criteria = st.expander("Industry Leadership (L)")
|
||||
with l_criteria:
|
||||
use_l = st.checkbox("Check Industry Leadership", value=True)
|
||||
l_threshold = st.slider("Industry Leadership Score Threshold", 0.0, 1.0, 0.7)
|
||||
|
||||
i_criteria = st.expander("Institutional Sponsorship (I)")
|
||||
with i_criteria:
|
||||
use_i = st.checkbox("Check Institutional Sponsorship", value=True)
|
||||
i_threshold = st.slider("Institutional Sponsorship Score Threshold", 0.0, 1.0, 0.7)
|
||||
|
||||
if st.button("Run CANSLIM Screener"):
|
||||
with st.spinner("Running CANSLIM screener..."):
|
||||
try:
|
||||
# Prepare selected screeners dictionary
|
||||
selected_screeners = {
|
||||
"C": {
|
||||
"EPS_Score": eps_threshold / 100,
|
||||
"Sales_Score": sales_threshold / 100,
|
||||
"ROE_Score": roe_threshold / 100
|
||||
},
|
||||
"A": {
|
||||
"Annual_EPS_Score": annual_eps_threshold / 100
|
||||
}
|
||||
}
|
||||
|
||||
if use_l:
|
||||
selected_screeners["L"] = {"L_Score": l_threshold}
|
||||
if use_i:
|
||||
selected_screeners["I"] = {"I_Score": i_threshold}
|
||||
|
||||
# Convert dates to strings for the screener
|
||||
start_str = start_date.strftime("%Y-%m-%d")
|
||||
end_str = end_date.strftime("%Y-%m-%d")
|
||||
|
||||
# Run the screener
|
||||
st.session_state.screener_params = {
|
||||
"start_date": start_str,
|
||||
"end_date": end_str,
|
||||
"selected_screeners": selected_screeners
|
||||
}
|
||||
|
||||
# Modify run_canslim_screener to accept parameters
|
||||
run_canslim_screener()
|
||||
st.success("Screening complete! Check the Reports tab for results.")
|
||||
|
||||
except Exception as e:
|
||||
st.error(f"Error running CANSLIM screener: {str(e)}")
|
||||
|
||||
with reports_tab:
|
||||
st.subheader("CANSLIM Reports")
|
||||
|
||||
reports = load_scanner_reports(scanner_type="canslim")
|
||||
if reports:
|
||||
# Create a selectbox to choose the report
|
||||
selected_report = st.selectbox(
|
||||
"Select Report",
|
||||
options=reports,
|
||||
format_func=lambda x: f"{x['name']} ({x['created'].strftime('%Y-%m-%d %H:%M')})",
|
||||
key="canslim_scanner_report"
|
||||
)
|
||||
|
||||
if selected_report:
|
||||
try:
|
||||
# Load and display the CSV
|
||||
df = pd.read_csv(selected_report['path'])
|
||||
|
||||
# Add download button
|
||||
st.download_button(
|
||||
label="Download Report",
|
||||
data=df.to_csv(index=False),
|
||||
file_name=selected_report['name'],
|
||||
mime='text/csv'
|
||||
)
|
||||
|
||||
# Display the dataframe
|
||||
st.dataframe(df)
|
||||
|
||||
except Exception as e:
|
||||
st.error(f"Error loading report: {str(e)}")
|
||||
else:
|
||||
st.info("No CANSLIM reports found")
|
||||
@ -1,135 +0,0 @@
|
||||
import streamlit as st
|
||||
from screener.scanner_controller import run_technical_scanner
|
||||
from utils.report_utils import load_scanner_reports
|
||||
from screener.t_candlestick import CANDLESTICK_PATTERNS
|
||||
import pandas as pd
|
||||
|
||||
def technical_scanner_page():
|
||||
st.header("Technical Scanner")
|
||||
|
||||
# Create tabs for scanner and reports
|
||||
scanner_tab, reports_tab = st.tabs(["Run Scanner", "View Reports"])
|
||||
|
||||
with scanner_tab:
|
||||
scanner_type = st.selectbox(
|
||||
"Select Scanner",
|
||||
["SunnyBands", "ATR-EMA", "ATR-EMA v2", "Heikin-Ashi", "Candlestick", "Sunny-SMA"],
|
||||
key="tech_scanner_type"
|
||||
)
|
||||
|
||||
# Add candlestick pattern selection when Candlestick scanner is chosen
|
||||
selected_patterns = None
|
||||
if scanner_type == "Candlestick":
|
||||
selected_patterns = st.multiselect(
|
||||
"Select Candlestick Patterns",
|
||||
options=list(CANDLESTICK_PATTERNS.keys()),
|
||||
default=[],
|
||||
format_func=lambda x: CANDLESTICK_PATTERNS[x]['description'],
|
||||
help="Choose which candlestick patterns to scan for"
|
||||
)
|
||||
|
||||
# Add interval selection
|
||||
interval = st.selectbox(
|
||||
"Select Time Interval",
|
||||
["Daily", "5 minute", "15 minute", "30 minute", "1 hour"],
|
||||
key="interval_select"
|
||||
)
|
||||
|
||||
# Convert interval to format expected by scanner
|
||||
interval_map = {
|
||||
"Daily": "1d",
|
||||
"5 minute": "5m",
|
||||
"15 minute": "15m",
|
||||
"30 minute": "30m",
|
||||
"1 hour": "1h"
|
||||
}
|
||||
selected_interval = interval_map[interval]
|
||||
|
||||
# Date range selection
|
||||
date_col1, date_col2 = st.columns(2)
|
||||
with date_col1:
|
||||
start_date = st.date_input("Start Date")
|
||||
with date_col2:
|
||||
end_date = st.date_input("End Date")
|
||||
|
||||
col1, col2 = st.columns(2)
|
||||
with col1:
|
||||
min_price = st.number_input("Minimum Price", value=5.0, step=0.1)
|
||||
max_price = st.number_input("Maximum Price", value=100.0, step=0.1)
|
||||
|
||||
with col2:
|
||||
min_volume = st.number_input("Minimum Volume", value=500000, step=100000)
|
||||
portfolio_size = st.number_input("Portfolio Size", value=100000.0, step=1000.0)
|
||||
|
||||
if st.button("Run Scanner"):
|
||||
with st.spinner("Running scanner..."):
|
||||
try:
|
||||
scanner_args = {
|
||||
"scanner_choice": scanner_type.lower().replace(" ", "_"),
|
||||
"start_date": start_date.strftime("%Y-%m-%d"),
|
||||
"end_date": end_date.strftime("%Y-%m-%d"),
|
||||
"min_price": min_price,
|
||||
"max_price": max_price,
|
||||
"min_volume": min_volume,
|
||||
"portfolio_size": portfolio_size,
|
||||
"interval": selected_interval
|
||||
}
|
||||
|
||||
# Add selected patterns if using candlestick scanner
|
||||
if scanner_type == "Candlestick":
|
||||
scanner_args["selected_patterns"] = selected_patterns
|
||||
|
||||
signals = run_technical_scanner(**scanner_args)
|
||||
if signals:
|
||||
st.success(f"Found {len(signals)} signals")
|
||||
# Create a summary table
|
||||
summary_df = pd.DataFrame(signals)[['ticker', 'entry_price', 'target_price', 'stop_loss']]
|
||||
summary_df = summary_df.round(2) # Round numeric columns to 2 decimal places
|
||||
st.dataframe(summary_df)
|
||||
|
||||
# Add a download button for the full results
|
||||
full_df = pd.DataFrame(signals)
|
||||
st.download_button(
|
||||
label="Download Full Results",
|
||||
data=full_df.to_csv(index=False),
|
||||
file_name="scanner_results.csv",
|
||||
mime='text/csv'
|
||||
)
|
||||
else:
|
||||
st.info("No signals found")
|
||||
except Exception as e:
|
||||
st.error(f"Error running scanner: {str(e)}")
|
||||
|
||||
with reports_tab:
|
||||
st.subheader("Scanner Reports")
|
||||
|
||||
reports = load_scanner_reports(scanner_type="technical")
|
||||
if reports:
|
||||
# Create a selectbox to choose the report
|
||||
selected_report = st.selectbox(
|
||||
"Select Report",
|
||||
options=reports,
|
||||
format_func=lambda x: f"{x['name']} ({x['created'].strftime('%Y-%m-%d %H:%M')})",
|
||||
key="tech_scanner_report"
|
||||
)
|
||||
|
||||
if selected_report:
|
||||
try:
|
||||
# Load and display the CSV
|
||||
df = pd.read_csv(selected_report['path'])
|
||||
|
||||
# Add download button
|
||||
st.download_button(
|
||||
label="Download Report",
|
||||
data=df.to_csv(index=False),
|
||||
file_name=selected_report['name'],
|
||||
mime='text/csv'
|
||||
)
|
||||
|
||||
# Display the dataframe
|
||||
st.dataframe(df)
|
||||
|
||||
except Exception as e:
|
||||
st.error(f"Error loading report: {str(e)}")
|
||||
else:
|
||||
st.info("No scanner reports found")
|
||||
@ -1 +0,0 @@
|
||||
# Trading pages package
|
||||
@ -1,483 +0,0 @@
|
||||
import streamlit as st
|
||||
from db.db_connection import create_client
|
||||
from trading.trading_plan import (
|
||||
delete_trading_plan,
|
||||
TradingPlan, PlanStatus, Timeframe, MarketFocus, TradeFrequency,
|
||||
save_trading_plan, get_trading_plan, get_all_trading_plans,
|
||||
update_trading_plan, get_plan_trades,
|
||||
link_trades_to_plan, calculate_plan_metrics, unlink_trades_from_plan
|
||||
)
|
||||
|
||||
def trading_plan_page():
|
||||
st.header("Trading Plans")
|
||||
|
||||
# Create tabs for different plan operations
|
||||
list_tab, add_tab, edit_tab = st.tabs(["View Plans", "Add Plan", "Edit Plan"])
|
||||
|
||||
with list_tab:
|
||||
st.subheader("Trading Plans")
|
||||
plans = get_all_trading_plans()
|
||||
|
||||
if plans:
|
||||
for plan in plans:
|
||||
with st.expander(f"{plan.plan_name} ({plan.status.value})"):
|
||||
col1, col2 = st.columns(2)
|
||||
|
||||
with col1:
|
||||
st.markdown("### Basic Information")
|
||||
st.write(f"**Plan Name:** {plan.plan_name}")
|
||||
st.write(f"**Status:** {plan.status.value}")
|
||||
st.write(f"**Author:** {plan.plan_author}")
|
||||
st.write(f"**Version:** {plan.strategy_version}")
|
||||
st.write(f"**Created:** {plan.created_at}")
|
||||
st.write(f"**Updated:** {plan.updated_at}")
|
||||
|
||||
st.markdown("### Market Details")
|
||||
st.write(f"**Timeframe:** {plan.timeframe.value}")
|
||||
st.write(f"**Market:** {plan.market_focus.value}")
|
||||
st.write(f"**Frequency:** {plan.trade_frequency.value}")
|
||||
if plan.sector_focus:
|
||||
st.write(f"**Sector Focus:** {plan.sector_focus}")
|
||||
|
||||
with col2:
|
||||
st.markdown("### Risk Parameters")
|
||||
st.write(f"**Stop Loss:** {plan.stop_loss}%")
|
||||
st.write(f"**Profit Target:** {plan.profit_target}%")
|
||||
st.write(f"**Risk/Reward Ratio:** {plan.risk_reward_ratio}")
|
||||
st.write(f"**Position Size:** {plan.position_sizing}%")
|
||||
st.write(f"**Risk per Trade:** {plan.total_risk_per_trade}%")
|
||||
st.write(f"**Max Portfolio Risk:** {plan.max_portfolio_risk}%")
|
||||
st.write(f"**Max Drawdown:** {plan.maximum_drawdown}%")
|
||||
st.write(f"**Max Trades/Day:** {plan.max_trades_per_day}")
|
||||
st.write(f"**Max Trades/Week:** {plan.max_trades_per_week}")
|
||||
|
||||
st.markdown("### Performance Metrics")
|
||||
if any([plan.win_rate, plan.average_return_per_trade, plan.profit_factor]):
|
||||
col3, col4 = st.columns(2)
|
||||
with col3:
|
||||
if plan.win_rate:
|
||||
st.write(f"**Win Rate:** {plan.win_rate}%")
|
||||
if plan.average_return_per_trade:
|
||||
st.write(f"**Avg Return/Trade:** {plan.average_return_per_trade}%")
|
||||
with col4:
|
||||
if plan.profit_factor:
|
||||
st.write(f"**Profit Factor:** {plan.profit_factor}")
|
||||
|
||||
st.markdown("### Linked Trades")
|
||||
plan_trades = get_plan_trades(plan.id)
|
||||
if plan_trades:
|
||||
total_pl = 0
|
||||
winning_trades = 0
|
||||
total_trades = len(plan_trades)
|
||||
|
||||
st.markdown("#### Trade Statistics")
|
||||
col1, col2 = st.columns(2)
|
||||
with col1:
|
||||
st.write(f"**Total Trades:** {total_trades}")
|
||||
st.write(f"**Winning Trades:** {winning_trades}")
|
||||
if total_trades > 0:
|
||||
st.write(f"**Win Rate:** {(winning_trades/total_trades)*100:.2f}%")
|
||||
|
||||
with col2:
|
||||
st.write(f"**Total P/L:** ${total_pl:.2f}")
|
||||
if total_trades > 0:
|
||||
st.write(f"**Average P/L per Trade:** ${total_pl/total_trades:.2f}")
|
||||
|
||||
st.markdown("#### Individual Trades")
|
||||
for trade in plan_trades:
|
||||
st.markdown("---")
|
||||
cols = st.columns(3)
|
||||
with cols[0]:
|
||||
st.write(f"**{trade['ticker']} - {trade['entry_date']}**")
|
||||
st.write(f"**Direction:** {trade['direction']}")
|
||||
if trade['strategy']:
|
||||
st.write(f"**Strategy:** {trade['strategy']}")
|
||||
|
||||
with cols[1]:
|
||||
st.write(f"**Entry:** ${trade['entry_price']:.2f}")
|
||||
st.write(f"**Shares:** {trade['shares']}")
|
||||
|
||||
with cols[2]:
|
||||
if trade['exit_price']:
|
||||
pl = (trade['exit_price'] - trade['entry_price']) * trade['shares']
|
||||
total_pl += pl
|
||||
if pl > 0:
|
||||
winning_trades += 1
|
||||
st.write(f"**Exit:** ${trade['exit_price']:.2f}")
|
||||
st.write(f"**P/L:** ${pl:.2f}")
|
||||
st.write(f"**Exit Date:** {trade['exit_date']}")
|
||||
else:
|
||||
st.write("**Status:** Open")
|
||||
else:
|
||||
st.info("No trades linked to this plan")
|
||||
|
||||
st.markdown("### Strategy Details")
|
||||
st.write("**Entry Criteria:**")
|
||||
st.write(plan.entry_criteria)
|
||||
|
||||
st.write("**Exit Criteria:**")
|
||||
st.write(plan.exit_criteria)
|
||||
|
||||
st.write("**Entry Confirmation:**")
|
||||
st.write(plan.entry_confirmation)
|
||||
|
||||
st.write("**Market Conditions:**")
|
||||
st.write(plan.market_conditions)
|
||||
|
||||
st.write("**Technical Indicators:**")
|
||||
st.write(plan.indicators_used)
|
||||
|
||||
st.markdown("### Risk Management")
|
||||
st.write("**Drawdown Adjustments:**")
|
||||
st.write(plan.adjustments_for_drawdown)
|
||||
|
||||
st.write("**Risk Controls:**")
|
||||
st.write(plan.risk_controls)
|
||||
|
||||
if plan.fundamental_criteria:
|
||||
st.markdown("### Fundamental Analysis")
|
||||
st.write(plan.fundamental_criteria)
|
||||
|
||||
if plan.options_strategy_details:
|
||||
st.markdown("### Options Strategy")
|
||||
st.write(plan.options_strategy_details)
|
||||
|
||||
if plan.improvements_needed:
|
||||
st.markdown("### Areas for Improvement")
|
||||
st.write(plan.improvements_needed)
|
||||
|
||||
if plan.trade_review_notes:
|
||||
st.markdown("### Trade Review Notes")
|
||||
st.write(plan.trade_review_notes)
|
||||
|
||||
if plan.future_testing_ideas:
|
||||
st.markdown("### Future Testing Ideas")
|
||||
st.write(plan.future_testing_ideas)
|
||||
|
||||
if plan.historical_backtest_results:
|
||||
st.markdown("### Historical Backtest Results")
|
||||
st.write(plan.historical_backtest_results)
|
||||
|
||||
if plan.real_trade_performance:
|
||||
st.markdown("### Real Trading Performance")
|
||||
st.write(plan.real_trade_performance)
|
||||
else:
|
||||
st.info("No trading plans found")
|
||||
|
||||
with add_tab:
|
||||
st.subheader("Create New Trading Plan")
|
||||
|
||||
col1, col2 = st.columns(2)
|
||||
with col1:
|
||||
plan_name = st.text_input("Plan Name", key="add_plan_name")
|
||||
status = st.selectbox("Status", [s.value for s in PlanStatus], key="add_status")
|
||||
timeframe = st.selectbox("Timeframe", [t.value for t in Timeframe], key="add_timeframe")
|
||||
market_focus = st.selectbox("Market Focus", [m.value for m in MarketFocus], key="add_market_focus")
|
||||
|
||||
with col2:
|
||||
trade_frequency = st.selectbox("Trade Frequency", [f.value for f in TradeFrequency], key="add_trade_frequency")
|
||||
plan_author = st.text_input("Author", key="add_plan_author")
|
||||
strategy_version = st.number_input("Version", min_value=1, value=1, key="add_strategy_version")
|
||||
|
||||
st.subheader("Risk Parameters")
|
||||
col1, col2 = st.columns(2)
|
||||
with col1:
|
||||
stop_loss = st.number_input("Stop Loss %", min_value=0.1, value=7.0, key="add_stop_loss")
|
||||
profit_target = st.number_input("Profit Target %", min_value=0.1, value=21.0, key="add_profit_target")
|
||||
risk_reward_ratio = profit_target / stop_loss if stop_loss > 0 else 0
|
||||
st.write(f"Risk:Reward Ratio: {risk_reward_ratio:.2f}")
|
||||
|
||||
with col2:
|
||||
position_sizing = st.number_input("Position Size %", min_value=0.1, value=5.0, key="add_position_sizing")
|
||||
total_risk_per_trade = st.number_input("Risk per Trade %", min_value=0.1, value=1.0, key="add_total_risk_per_trade")
|
||||
max_portfolio_risk = st.number_input("Max Portfolio Risk %", min_value=0.1, value=5.0, key="add_max_portfolio_risk")
|
||||
|
||||
st.subheader("Trade Rules")
|
||||
col1, col2 = st.columns(2)
|
||||
with col1:
|
||||
max_trades_per_day = st.number_input("Max Trades per Day", min_value=1, value=3, key="add_max_trades_per_day")
|
||||
max_trades_per_week = st.number_input("Max Trades per Week", min_value=1, value=15, key="add_max_trades_per_week")
|
||||
maximum_drawdown = st.number_input("Maximum Drawdown %", min_value=0.1, value=20.0, key="add_maximum_drawdown")
|
||||
|
||||
st.subheader("Strategy Details")
|
||||
entry_criteria = st.text_area("Entry Criteria", key="add_entry_criteria")
|
||||
exit_criteria = st.text_area("Exit Criteria", key="add_exit_criteria")
|
||||
entry_confirmation = st.text_area("Entry Confirmation", key="add_entry_confirmation")
|
||||
market_conditions = st.text_area("Market Conditions", key="add_market_conditions")
|
||||
indicators_used = st.text_area("Technical Indicators", key="add_indicators_used")
|
||||
|
||||
st.subheader("Risk Management")
|
||||
adjustments_for_drawdown = st.text_area("Drawdown Adjustments", key="add_adjustments_for_drawdown")
|
||||
risk_controls = st.text_area("Risk Controls", key="add_risk_controls")
|
||||
|
||||
st.subheader("Additional Information")
|
||||
col1, col2 = st.columns(2)
|
||||
with col1:
|
||||
sector_focus = st.text_input("Sector Focus (optional)", key="add_sector_focus")
|
||||
fundamental_criteria = st.text_area("Fundamental Criteria (optional)", key="add_fundamental_criteria")
|
||||
|
||||
with col2:
|
||||
options_strategy_details = st.text_area("Options Strategy Details (optional)", key="add_options_strategy_details")
|
||||
improvements_needed = st.text_area("Improvements Needed (optional)", key="add_improvements_needed")
|
||||
|
||||
if st.button("Create Trading Plan", key="create_plan_button"):
|
||||
try:
|
||||
plan = TradingPlan(
|
||||
plan_name=plan_name,
|
||||
status=PlanStatus(status),
|
||||
timeframe=Timeframe(timeframe),
|
||||
market_focus=MarketFocus(market_focus),
|
||||
trade_frequency=TradeFrequency(trade_frequency),
|
||||
entry_criteria=entry_criteria,
|
||||
exit_criteria=exit_criteria,
|
||||
stop_loss=stop_loss,
|
||||
profit_target=profit_target,
|
||||
risk_reward_ratio=risk_reward_ratio,
|
||||
entry_confirmation=entry_confirmation,
|
||||
position_sizing=position_sizing,
|
||||
maximum_drawdown=maximum_drawdown,
|
||||
max_trades_per_day=max_trades_per_day,
|
||||
max_trades_per_week=max_trades_per_week,
|
||||
total_risk_per_trade=total_risk_per_trade,
|
||||
max_portfolio_risk=max_portfolio_risk,
|
||||
adjustments_for_drawdown=adjustments_for_drawdown,
|
||||
risk_controls=risk_controls,
|
||||
market_conditions=market_conditions,
|
||||
indicators_used=indicators_used,
|
||||
plan_author=plan_author,
|
||||
strategy_version=strategy_version,
|
||||
sector_focus=sector_focus,
|
||||
fundamental_criteria=fundamental_criteria,
|
||||
options_strategy_details=options_strategy_details,
|
||||
improvements_needed=improvements_needed
|
||||
)
|
||||
|
||||
save_trading_plan(plan)
|
||||
st.success("Trading plan created successfully!")
|
||||
st.query_params.update(rerun=True)
|
||||
except Exception as e:
|
||||
st.error(f"Error creating trading plan: {str(e)}")
|
||||
|
||||
with edit_tab:
|
||||
st.subheader("Edit Trading Plan")
|
||||
plans = get_all_trading_plans()
|
||||
|
||||
if plans:
|
||||
selected_plan_id = st.selectbox(
|
||||
"Select Plan to Edit",
|
||||
options=[plan.id for plan in plans],
|
||||
format_func=lambda x: next(p.plan_name for p in plans if p.id == x),
|
||||
key="edit_plan_select"
|
||||
)
|
||||
|
||||
if selected_plan_id:
|
||||
plan = get_trading_plan(selected_plan_id)
|
||||
if plan:
|
||||
col1, col2 = st.columns(2)
|
||||
with col1:
|
||||
plan_name = st.text_input("Plan Name", value=plan.plan_name, key="edit_plan_name")
|
||||
status = st.selectbox("Status", [s.value for s in PlanStatus], index=[s.value for s in PlanStatus].index(plan.status.value), key="edit_status")
|
||||
timeframe = st.selectbox("Timeframe", [t.value for t in Timeframe], index=[t.value for t in Timeframe].index(plan.timeframe.value), key="edit_timeframe")
|
||||
market_focus = st.selectbox("Market Focus", [m.value for m in MarketFocus], index=[m.value for m in MarketFocus].index(plan.market_focus.value), key="edit_market_focus")
|
||||
|
||||
with col2:
|
||||
trade_frequency = st.selectbox("Trade Frequency", [f.value for f in TradeFrequency], index=[f.value for f in TradeFrequency].index(plan.trade_frequency.value), key="edit_trade_frequency")
|
||||
plan_author = st.text_input("Author", value=plan.plan_author, key="edit_plan_author")
|
||||
strategy_version = st.number_input("Version", min_value=1, value=plan.strategy_version, key="edit_strategy_version")
|
||||
|
||||
st.subheader("Risk Parameters")
|
||||
col1, col2 = st.columns(2)
|
||||
with col1:
|
||||
stop_loss = st.number_input("Stop Loss %", min_value=0.1, value=plan.stop_loss, key="edit_stop_loss")
|
||||
profit_target = st.number_input("Profit Target %", min_value=0.1, value=plan.profit_target, key="edit_profit_target")
|
||||
risk_reward_ratio = profit_target / stop_loss if stop_loss > 0 else 0
|
||||
st.write(f"Risk:Reward Ratio: {risk_reward_ratio:.2f}")
|
||||
|
||||
with col2:
|
||||
position_sizing = st.number_input("Position Size %", min_value=0.1, value=plan.position_sizing, key="edit_position_sizing")
|
||||
total_risk_per_trade = st.number_input("Risk per Trade %", min_value=0.1, value=plan.total_risk_per_trade, key="edit_total_risk_per_trade")
|
||||
max_portfolio_risk = st.number_input("Max Portfolio Risk %", min_value=0.1, value=plan.max_portfolio_risk, key="edit_max_portfolio_risk")
|
||||
|
||||
st.subheader("Trade Rules")
|
||||
col1, col2 = st.columns(2)
|
||||
with col1:
|
||||
max_trades_per_day = st.number_input("Max Trades per Day", min_value=1, value=plan.max_trades_per_day, key="edit_max_trades_per_day")
|
||||
max_trades_per_week = st.number_input("Max Trades per Week", min_value=1, value=plan.max_trades_per_week, key="edit_max_trades_per_week")
|
||||
maximum_drawdown = st.number_input("Maximum Drawdown %", min_value=0.1, value=plan.maximum_drawdown, key="edit_maximum_drawdown")
|
||||
|
||||
st.subheader("Strategy Details")
|
||||
entry_criteria = st.text_area("Entry Criteria", value=plan.entry_criteria, key="edit_entry_criteria")
|
||||
exit_criteria = st.text_area("Exit Criteria", value=plan.exit_criteria, key="edit_exit_criteria")
|
||||
entry_confirmation = st.text_area("Entry Confirmation", value=plan.entry_confirmation, key="edit_entry_confirmation")
|
||||
market_conditions = st.text_area("Market Conditions", value=plan.market_conditions, key="edit_market_conditions")
|
||||
indicators_used = st.text_area("Technical Indicators", value=plan.indicators_used, key="edit_indicators_used")
|
||||
|
||||
st.subheader("Risk Management")
|
||||
adjustments_for_drawdown = st.text_area("Drawdown Adjustments", value=plan.adjustments_for_drawdown, key="edit_adjustments_for_drawdown")
|
||||
risk_controls = st.text_area("Risk Controls", value=plan.risk_controls, key="edit_risk_controls")
|
||||
|
||||
st.subheader("Additional Information")
|
||||
col1, col2 = st.columns(2)
|
||||
with col1:
|
||||
sector_focus = st.text_input("Sector Focus (optional)", value=plan.sector_focus, key="edit_sector_focus")
|
||||
fundamental_criteria = st.text_area("Fundamental Criteria (optional)", value=plan.fundamental_criteria, key="edit_fundamental_criteria")
|
||||
|
||||
with col2:
|
||||
options_strategy_details = st.text_area("Options Strategy Details (optional)", value=plan.options_strategy_details, key="edit_options_strategy_details")
|
||||
improvements_needed = st.text_area("Improvements Needed (optional)", value=plan.improvements_needed, key="edit_improvements_needed")
|
||||
|
||||
if st.button("Update Plan", key="update_plan_button"):
|
||||
try:
|
||||
plan.plan_name = plan_name
|
||||
plan.status = PlanStatus(status)
|
||||
plan.timeframe = Timeframe(timeframe)
|
||||
plan.market_focus = MarketFocus(market_focus)
|
||||
plan.trade_frequency = TradeFrequency(trade_frequency)
|
||||
plan.plan_author = plan_author
|
||||
plan.strategy_version = strategy_version
|
||||
plan.stop_loss = stop_loss
|
||||
plan.profit_target = profit_target
|
||||
plan.position_sizing = position_sizing
|
||||
plan.total_risk_per_trade = total_risk_per_trade
|
||||
plan.max_portfolio_risk = max_portfolio_risk
|
||||
plan.max_trades_per_day = max_trades_per_day
|
||||
plan.max_trades_per_week = max_trades_per_week
|
||||
plan.maximum_drawdown = maximum_drawdown
|
||||
plan.entry_criteria = entry_criteria
|
||||
plan.exit_criteria = exit_criteria
|
||||
plan.entry_confirmation = entry_confirmation
|
||||
plan.market_conditions = market_conditions
|
||||
plan.indicators_used = indicators_used
|
||||
plan.adjustments_for_drawdown = adjustments_for_drawdown
|
||||
plan.risk_controls = risk_controls
|
||||
plan.sector_focus = sector_focus
|
||||
plan.fundamental_criteria = fundamental_criteria
|
||||
plan.options_strategy_details = options_strategy_details
|
||||
plan.improvements_needed = improvements_needed
|
||||
|
||||
update_trading_plan(plan)
|
||||
st.success("Plan updated successfully!")
|
||||
st.query_params.update(rerun=True)
|
||||
except Exception as e:
|
||||
st.error(f"Error updating plan: {str(e)}")
|
||||
|
||||
if st.button("Delete Plan", key="delete_plan_button"):
|
||||
try:
|
||||
delete_trading_plan(plan.id)
|
||||
st.success("Plan deleted successfully!")
|
||||
st.query_params.update(rerun=True)
|
||||
except Exception as e:
|
||||
st.error(f"Error deleting plan: {str(e)}")
|
||||
|
||||
st.subheader("Trade Management")
|
||||
|
||||
plan_trades = get_plan_trades(plan.id)
|
||||
|
||||
if plan_trades:
|
||||
st.write("Current Trades:")
|
||||
for trade in plan_trades:
|
||||
with st.expander(f"{trade['ticker']} - {trade['entry_date']}"):
|
||||
col1, col2 = st.columns(2)
|
||||
with col1:
|
||||
st.write(f"Entry: ${trade['entry_price']:.2f}")
|
||||
st.write(f"Shares: {trade['shares']}")
|
||||
with col2:
|
||||
if trade['exit_price']:
|
||||
pl = (trade['exit_price'] - trade['entry_price']) * trade['shares']
|
||||
st.write(f"Exit: ${trade['exit_price']:.2f}")
|
||||
st.write(f"P/L: ${pl:.2f}")
|
||||
|
||||
if st.button("Unlink Trade", key=f"unlink_trade_{trade['id']}"):
|
||||
try:
|
||||
query = """
|
||||
ALTER TABLE stock_db.trades
|
||||
UPDATE plan_id = NULL
|
||||
WHERE id = %(trade_id)s
|
||||
"""
|
||||
with create_client() as client:
|
||||
client.command(query, {'trade_id': trade['id']})
|
||||
|
||||
metrics = calculate_plan_metrics(plan.id)
|
||||
plan.win_rate = metrics['win_rate']
|
||||
plan.average_return_per_trade = metrics['average_return']
|
||||
plan.profit_factor = metrics['profit_factor']
|
||||
update_trading_plan(plan)
|
||||
|
||||
st.success(f"Trade unlinked successfully!")
|
||||
st.query_params.update(rerun=True)
|
||||
except Exception as e:
|
||||
st.error(f"Error unlinking trade: {str(e)}")
|
||||
|
||||
if st.button("Unlink All Trades", key=f"unlink_all_trades_{plan.id}"):
|
||||
try:
|
||||
if unlink_trades_from_plan(plan.id):
|
||||
plan.win_rate = None
|
||||
plan.average_return_per_trade = None
|
||||
plan.profit_factor = None
|
||||
update_trading_plan(plan)
|
||||
|
||||
st.success("All trades unlinked successfully!")
|
||||
st.query_params.update(rerun=True)
|
||||
else:
|
||||
st.error("Error unlinking trades")
|
||||
except Exception as e:
|
||||
st.error(f"Error unlinking trades: {str(e)}")
|
||||
|
||||
with create_client() as client:
|
||||
query = """
|
||||
SELECT
|
||||
id,
|
||||
ticker,
|
||||
entry_date,
|
||||
entry_price,
|
||||
shares,
|
||||
exit_price,
|
||||
exit_date,
|
||||
direction,
|
||||
strategy,
|
||||
CASE
|
||||
WHEN exit_price IS NOT NULL
|
||||
THEN (exit_price - entry_price) * shares
|
||||
ELSE NULL
|
||||
END as profit_loss
|
||||
FROM stock_db.trades
|
||||
WHERE plan_id IS NULL
|
||||
ORDER BY entry_date DESC
|
||||
"""
|
||||
result = client.query(query)
|
||||
available_trades = [dict(zip(
|
||||
['id', 'ticker', 'entry_date', 'entry_price', 'shares',
|
||||
'exit_price', 'exit_date', 'direction', 'strategy', 'profit_loss'],
|
||||
row
|
||||
)) for row in result.result_rows]
|
||||
|
||||
if available_trades:
|
||||
st.write("Link Existing Trades:")
|
||||
selected_trades = st.multiselect(
|
||||
"Select trades to link to this plan",
|
||||
options=[t['id'] for t in available_trades],
|
||||
format_func=lambda x: next(
|
||||
f"{t['ticker']} - {t['entry_date']} - ${t['entry_price']:.2f} "
|
||||
f"({t['direction']}) - {t['strategy']} "
|
||||
f"{'[Closed]' if t['exit_price'] else '[Open]'} "
|
||||
f"{'P/L: $' + format(t['profit_loss'], '.2f') if t['profit_loss'] is not None else ''}"
|
||||
for t in available_trades if t['id'] == x
|
||||
),
|
||||
key=f"link_trades_{plan.id}"
|
||||
)
|
||||
|
||||
if selected_trades and st.button("Link Selected Trades", key=f"link_trades_button_{plan.id}"):
|
||||
if link_trades_to_plan(plan.id, selected_trades):
|
||||
st.success("Trades linked successfully!")
|
||||
|
||||
metrics = calculate_plan_metrics(plan.id)
|
||||
plan.win_rate = metrics['win_rate']
|
||||
plan.average_return_per_trade = metrics['average_return']
|
||||
plan.profit_factor = metrics['profit_factor']
|
||||
update_trading_plan(plan)
|
||||
|
||||
st.query_params.update(rerun=True)
|
||||
else:
|
||||
st.error("Error linking trades")
|
||||
else:
|
||||
st.info("No plans available to edit")
|
||||
@ -1,766 +0,0 @@
|
||||
import logging
|
||||
import streamlit as st
|
||||
from trading.journal import get_latest_portfolio_value, get_open_trades_summary
|
||||
from trading.position_calculator import PositionCalculator
|
||||
from utils.data_utils import get_current_prices
|
||||
from pages.analysis.monte_carlo_page import MonteCarloSimulator
|
||||
from datetime import datetime, timedelta
|
||||
import time
|
||||
from utils.common_utils import get_stock_data
|
||||
from trading.watchlist import (
|
||||
create_watchlist, get_watchlists, add_to_watchlist,
|
||||
remove_from_watchlist, get_watchlist_items, WatchlistItem,
|
||||
ensure_tables_exist
|
||||
)
|
||||
from db.db_connection import create_client
|
||||
|
||||
# Configure logging
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
def trading_system_page():
|
||||
# Initialize session state
|
||||
if 'prefill_watchlist' not in st.session_state:
|
||||
st.session_state.prefill_watchlist = None
|
||||
|
||||
st.header("Trading System")
|
||||
|
||||
# Create tabs
|
||||
tab1, tab2, tab3 = st.tabs(["Position Calculator", "Prop Firm Calculator", "Watch Lists"])
|
||||
|
||||
# Tab 1: Position Calculator
|
||||
with tab1:
|
||||
st.subheader("Position Calculator")
|
||||
|
||||
# Get latest portfolio value and open trades for total portfolio calculation
|
||||
portfolio_data = get_latest_portfolio_value()
|
||||
cash_balance = portfolio_data['cash_balance'] if portfolio_data else 0
|
||||
|
||||
# Calculate total portfolio value including open positions
|
||||
open_summary = get_open_trades_summary()
|
||||
total_position_value = 0
|
||||
total_paper_pl = 0
|
||||
|
||||
if open_summary:
|
||||
# Get current prices for all open positions
|
||||
unique_tickers = list(set(summary['ticker'] for summary in open_summary))
|
||||
current_prices = get_current_prices(unique_tickers)
|
||||
|
||||
# Calculate total invested value and paper P/L
|
||||
for summary in open_summary:
|
||||
ticker = summary['ticker']
|
||||
current_price = current_prices.get(ticker, 0)
|
||||
shares = summary['total_shares']
|
||||
avg_entry = summary['avg_entry_price']
|
||||
|
||||
position_value = current_price * shares if current_price else avg_entry * shares
|
||||
total_position_value += position_value
|
||||
total_paper_pl += (current_price - avg_entry) * shares if current_price else 0
|
||||
|
||||
total_portfolio_value = cash_balance + total_position_value
|
||||
|
||||
# Display portfolio information
|
||||
col1, col2, col3 = st.columns(3)
|
||||
with col1:
|
||||
st.info(f"Available Cash: ${cash_balance:,.2f}")
|
||||
with col2:
|
||||
st.info(f"Positions Value: ${total_position_value:,.2f}")
|
||||
with col3:
|
||||
st.info(f"Total Portfolio Value: ${total_portfolio_value:,.2f}")
|
||||
|
||||
col1, col2 = st.columns(2)
|
||||
with col1:
|
||||
account_size = st.number_input("Account Size ($)",
|
||||
min_value=0.0,
|
||||
value=total_portfolio_value,
|
||||
step=1000.0,
|
||||
key="personal_account_size")
|
||||
risk_percentage = st.number_input("Risk Percentage (%)",
|
||||
min_value=0.1,
|
||||
max_value=100.0,
|
||||
value=1.0,
|
||||
step=0.1,
|
||||
key="personal_risk_percentage")
|
||||
use_monte_carlo = st.checkbox("Use Monte Carlo for Analysis", value=True, key="personal_monte_carlo")
|
||||
if use_monte_carlo:
|
||||
days_out = st.number_input("Days to Project",
|
||||
min_value=1,
|
||||
max_value=30,
|
||||
value=5,
|
||||
help="Number of days to project for target price",
|
||||
key="personal_days_out")
|
||||
confidence_level = st.slider("Confidence Level (%)",
|
||||
min_value=80,
|
||||
max_value=99,
|
||||
value=95,
|
||||
key="personal_confidence_level")
|
||||
|
||||
with col2:
|
||||
ticker = st.text_input("Ticker Symbol", value="", key="personal_ticker").upper()
|
||||
entry_price = st.number_input("Entry Price ($)", min_value=0.01, step=0.01, key="personal_entry_price")
|
||||
if not use_monte_carlo:
|
||||
target_price = st.number_input("Target Price ($)", min_value=0.01, step=0.01, key="personal_target_price")
|
||||
|
||||
if st.button("Calculate Position", key="personal_calculate"):
|
||||
try:
|
||||
if not ticker:
|
||||
st.error("Please enter a ticker symbol")
|
||||
return
|
||||
|
||||
# Get historical data for Monte Carlo simulation
|
||||
if use_monte_carlo:
|
||||
with st.spinner("Calculating optimal stop loss..."):
|
||||
df = get_stock_data(
|
||||
ticker,
|
||||
datetime.now() - timedelta(days=30), # Last 30 days of data
|
||||
datetime.now(),
|
||||
'1m' # Minute data for more accurate simulation
|
||||
)
|
||||
|
||||
if df.empty:
|
||||
st.error("No data available for the selected ticker")
|
||||
return
|
||||
|
||||
# Initialize Monte Carlo simulator
|
||||
simulator = MonteCarloSimulator(df, num_simulations=1000, time_horizon=days_out)
|
||||
|
||||
# Calculate stop loss and target prices
|
||||
stop_loss_price = simulator.calculate_stop_loss(risk_percentage)
|
||||
target_price = simulator.calculate_target_price(confidence_level)
|
||||
|
||||
# Calculate stop loss percentage
|
||||
stop_loss_percentage = abs((stop_loss_price - entry_price) / entry_price * 100)
|
||||
else:
|
||||
stop_loss_percentage = 7.0 # Default value if not using Monte Carlo
|
||||
|
||||
calculator = PositionCalculator(
|
||||
account_size=account_size,
|
||||
risk_percentage=risk_percentage,
|
||||
stop_loss_percentage=stop_loss_percentage
|
||||
)
|
||||
|
||||
position = calculator.calculate_position_size(entry_price, target_price)
|
||||
|
||||
# Calculate maximum shares possible with available cash
|
||||
max_shares_by_cash = int(cash_balance / entry_price) if entry_price > 0 else 0
|
||||
|
||||
# Adjust shares based on available cash
|
||||
recommended_shares = min(position['shares'], max_shares_by_cash)
|
||||
|
||||
col1, col2 = st.columns(2)
|
||||
with col1:
|
||||
if recommended_shares < position['shares']:
|
||||
st.warning(
|
||||
f"Position size limited by available cash.\n"
|
||||
f"Ideal shares: {position['shares']:,}\n"
|
||||
f"Maximum affordable shares: {recommended_shares:,}"
|
||||
)
|
||||
position_value = recommended_shares * entry_price
|
||||
risk_amount = position['risk_amount'] * (recommended_shares / position['shares'])
|
||||
|
||||
st.metric("Recommended Shares", f"{recommended_shares:,}")
|
||||
st.metric("Position Value", f"${position_value:,.2f}")
|
||||
st.metric("Risk Amount", f"${risk_amount:,.2f}")
|
||||
else:
|
||||
st.metric("Number of Shares", f"{position['shares']:,}")
|
||||
st.metric("Position Value", f"${position['position_value']:,.2f}")
|
||||
st.metric("Risk Amount", f"${position['risk_amount']:,.2f}")
|
||||
|
||||
with col2:
|
||||
st.metric("Stop Loss Price", f"${position['stop_loss']:.2f}")
|
||||
st.metric("Potential Loss", f"${position['potential_loss']:,.2f}")
|
||||
if 'potential_profit' in position:
|
||||
potential_profit = (target_price - entry_price) * recommended_shares
|
||||
risk_reward = abs(potential_profit / (position['stop_loss'] - entry_price) / recommended_shares) if recommended_shares > 0 else 0
|
||||
st.metric("Potential Profit", f"${potential_profit:,.2f}")
|
||||
st.metric("Risk/Reward Ratio", f"{risk_reward:.2f}")
|
||||
|
||||
# Show percentage of cash being used
|
||||
if recommended_shares > 0:
|
||||
cash_usage = (recommended_shares * entry_price / cash_balance) * 100
|
||||
portfolio_usage = (recommended_shares * entry_price / total_portfolio_value) * 100
|
||||
st.info(
|
||||
f"This position would use:\n"
|
||||
f"- {cash_usage:.1f}% of available cash\n"
|
||||
f"- {portfolio_usage:.1f}% of total portfolio"
|
||||
)
|
||||
|
||||
# Add Monte Carlo metrics if used
|
||||
if use_monte_carlo:
|
||||
st.subheader("Monte Carlo Analysis")
|
||||
col1, col2, col3 = st.columns(3)
|
||||
with col1:
|
||||
st.metric("Stop Loss Price", f"${stop_loss_price:.2f}")
|
||||
st.metric("Stop Loss %", f"{stop_loss_percentage:.2f}%")
|
||||
with col2:
|
||||
st.metric("Target Price", f"${target_price:.2f}")
|
||||
st.metric("Target %", f"{((target_price - entry_price) / entry_price * 100):.2f}%")
|
||||
with col3:
|
||||
st.metric("Days Projected", f"{days_out}")
|
||||
st.metric("Confidence Level", f"{confidence_level}%")
|
||||
|
||||
# Add to watchlist option
|
||||
st.divider()
|
||||
st.subheader("Save to Watch List")
|
||||
if st.button("Prepare for Watch List", key="personal_prepare_watchlist"):
|
||||
st.session_state.prefill_watchlist = {
|
||||
'ticker': ticker,
|
||||
'entry_price': float(entry_price),
|
||||
'target_price': float(target_price),
|
||||
'stop_loss': float(position['stop_loss']),
|
||||
'shares': recommended_shares
|
||||
}
|
||||
st.success("Details saved! Switch to Watch Lists tab to complete adding to your watch list.")
|
||||
|
||||
except Exception as e:
|
||||
st.error(f"Error calculating position: {str(e)}")
|
||||
|
||||
# Tab 2: Prop Firm Calculator
|
||||
with tab2:
|
||||
st.subheader("Prop Firm Calculator")
|
||||
|
||||
# Prop firm parameters
|
||||
st.markdown("### Prop Firm Parameters")
|
||||
col1, col2 = st.columns(2)
|
||||
|
||||
with col1:
|
||||
buying_power = st.number_input(
|
||||
"Buying Power ($)",
|
||||
min_value=1000.0,
|
||||
value=20000.0,
|
||||
step=1000.0,
|
||||
help="Total capital allocated by the prop firm"
|
||||
)
|
||||
|
||||
max_daily_loss = st.number_input(
|
||||
"Daily Loss Limit ($)",
|
||||
min_value=100.0,
|
||||
value=300.0,
|
||||
step=50.0,
|
||||
help="Maximum loss allowed in a single trading day"
|
||||
)
|
||||
|
||||
max_total_loss = st.number_input(
|
||||
"Max Total Loss ($)",
|
||||
min_value=100.0,
|
||||
value=900.0,
|
||||
step=50.0,
|
||||
help="Maximum total loss allowed during the evaluation period"
|
||||
)
|
||||
|
||||
with col2:
|
||||
evaluation_days = st.number_input(
|
||||
"Evaluation Period (Days)",
|
||||
min_value=5,
|
||||
value=45,
|
||||
step=1,
|
||||
help="Number of days in the evaluation period"
|
||||
)
|
||||
|
||||
risk_percentage = st.number_input(
|
||||
"Risk Percentage per Trade (%)",
|
||||
min_value=0.1,
|
||||
max_value=100.0,
|
||||
value=1.0,
|
||||
step=0.1,
|
||||
help="Percentage of account to risk on each trade"
|
||||
)
|
||||
|
||||
max_position_size = st.number_input(
|
||||
"Max Position Size (%)",
|
||||
min_value=1.0,
|
||||
max_value=100.0,
|
||||
value=15.0,
|
||||
step=1.0,
|
||||
help="Maximum percentage of buying power for a single position"
|
||||
)
|
||||
|
||||
# In the prop firm calculator section, add this after the max_position_size input
|
||||
day_trading_mode = st.checkbox(
|
||||
"Day Trading Mode",
|
||||
value=True,
|
||||
help="Enable specific settings for day trading"
|
||||
)
|
||||
|
||||
if day_trading_mode:
|
||||
max_loss_per_trade = st.number_input(
|
||||
"Max Loss Per Trade ($)",
|
||||
min_value=10.0,
|
||||
max_value=max_daily_loss,
|
||||
value=50.0,
|
||||
step=10.0,
|
||||
help="Maximum dollar amount to risk on a single trade"
|
||||
)
|
||||
|
||||
# Calculate how many trades you can take at max loss
|
||||
max_trades_at_full_loss = int(max_daily_loss / max_loss_per_trade)
|
||||
st.info(f"You can take up to {max_trades_at_full_loss} trades at maximum loss before hitting daily limit")
|
||||
|
||||
# Add position scaling parameters to the top form
|
||||
use_scaling = st.checkbox(
|
||||
"Enable Position Scaling",
|
||||
value=True,
|
||||
help="Calculate how to add to winning positions"
|
||||
)
|
||||
|
||||
if use_scaling:
|
||||
col1, col2 = st.columns(2)
|
||||
|
||||
with col1:
|
||||
num_entries = st.number_input(
|
||||
"Number of Entry Points",
|
||||
min_value=2,
|
||||
max_value=5,
|
||||
value=3,
|
||||
help="How many times you want to add to your position"
|
||||
)
|
||||
|
||||
scaling_method = st.selectbox(
|
||||
"Scaling Method",
|
||||
options=["Equal Size", "Increasing Size", "Decreasing Size"],
|
||||
index=0,
|
||||
help="How to distribute position size across entries"
|
||||
)
|
||||
|
||||
with col2:
|
||||
price_increment = st.number_input(
|
||||
"Price Movement Between Entries (%)",
|
||||
min_value=0.1,
|
||||
max_value=10.0,
|
||||
value=1.0,
|
||||
step=0.1,
|
||||
help="How much the price should move before adding to position"
|
||||
)
|
||||
|
||||
max_scaling_factor = st.number_input(
|
||||
"Maximum Position Scaling Factor",
|
||||
min_value=1.0,
|
||||
max_value=10.0,
|
||||
value=3.0,
|
||||
step=0.5,
|
||||
help="Maximum multiple of initial position size"
|
||||
)
|
||||
|
||||
# Position calculator
|
||||
st.markdown("### Position Calculator")
|
||||
col1, col2 = st.columns(2)
|
||||
|
||||
with col1:
|
||||
ticker = st.text_input("Ticker Symbol", value="", key="prop_ticker").upper()
|
||||
entry_price = st.number_input("Entry Price ($)", min_value=0.01, step=0.01, key="prop_entry_price")
|
||||
|
||||
use_monte_carlo = st.checkbox("Use Monte Carlo for Analysis", value=True, key="prop_monte_carlo")
|
||||
if use_monte_carlo:
|
||||
days_out = st.number_input(
|
||||
"Days to Project",
|
||||
min_value=1,
|
||||
max_value=30,
|
||||
value=5,
|
||||
help="Number of days to project for target price",
|
||||
key="prop_days_out"
|
||||
)
|
||||
confidence_level = st.slider(
|
||||
"Confidence Level (%)",
|
||||
min_value=80,
|
||||
max_value=99,
|
||||
value=95,
|
||||
key="prop_confidence_level"
|
||||
)
|
||||
|
||||
with col2:
|
||||
if not use_monte_carlo:
|
||||
target_price = st.number_input("Target Price ($)", min_value=0.01, step=0.01, key="prop_target_price")
|
||||
stop_loss_price = st.number_input("Stop Loss Price ($)", min_value=0.01, step=0.01, key="prop_stop_loss")
|
||||
|
||||
# Calculate daily risk limit as a percentage
|
||||
daily_risk_pct = (max_daily_loss / buying_power) * 100
|
||||
st.info(f"Daily risk limit: {daily_risk_pct:.2f}% of account")
|
||||
|
||||
# Calculate total risk limit as a percentage
|
||||
total_risk_pct = (max_total_loss / buying_power) * 100
|
||||
st.info(f"Total risk limit: {total_risk_pct:.2f}% of account")
|
||||
|
||||
if st.button("Calculate Position", key="prop_calculate"):
|
||||
try:
|
||||
if not ticker:
|
||||
st.error("Please enter a ticker symbol")
|
||||
return
|
||||
|
||||
# Get historical data for Monte Carlo simulation
|
||||
if use_monte_carlo:
|
||||
with st.spinner("Calculating optimal stop loss..."):
|
||||
df = get_stock_data(
|
||||
ticker,
|
||||
datetime.now() - timedelta(days=30), # Last 30 days of data
|
||||
datetime.now(),
|
||||
'1m' # Minute data for more accurate simulation
|
||||
)
|
||||
|
||||
if df.empty:
|
||||
st.error("No data available for the selected ticker")
|
||||
return
|
||||
|
||||
# Initialize Monte Carlo simulator
|
||||
simulator = MonteCarloSimulator(df, num_simulations=1000, time_horizon=days_out)
|
||||
|
||||
# Calculate stop loss and target prices
|
||||
stop_loss_price = simulator.calculate_stop_loss(risk_percentage)
|
||||
target_price = simulator.calculate_target_price(confidence_level)
|
||||
|
||||
# Calculate stop loss percentage
|
||||
stop_loss_percentage = abs((stop_loss_price - entry_price) / entry_price * 100)
|
||||
else:
|
||||
stop_loss_percentage = abs((stop_loss_price - entry_price) / entry_price * 100)
|
||||
|
||||
# Calculate position size based on risk percentage
|
||||
calculator = PositionCalculator(
|
||||
account_size=buying_power,
|
||||
risk_percentage=risk_percentage,
|
||||
stop_loss_percentage=stop_loss_percentage
|
||||
)
|
||||
|
||||
position = calculator.calculate_position_size(entry_price, target_price)
|
||||
|
||||
# Calculate maximum shares based on daily loss limit
|
||||
max_shares_by_daily_loss = int(max_daily_loss / abs(entry_price - stop_loss_price)) if entry_price != stop_loss_price else 0
|
||||
|
||||
# Calculate maximum shares based on position size limit
|
||||
max_position_value = buying_power * (max_position_size / 100)
|
||||
max_shares_by_position_limit = int(max_position_value / entry_price) if entry_price > 0 else 0
|
||||
|
||||
# Calculate shares based on fixed dollar risk
|
||||
max_shares_by_fixed_risk = int(max_loss_per_trade / abs(entry_price - stop_loss_price)) if entry_price != stop_loss_price else 0
|
||||
|
||||
# Update recommended shares calculation
|
||||
recommended_shares = min(
|
||||
position['shares'],
|
||||
max_shares_by_daily_loss,
|
||||
max_shares_by_position_limit,
|
||||
max_shares_by_fixed_risk
|
||||
)
|
||||
|
||||
# Calculate position metrics
|
||||
position_value = recommended_shares * entry_price
|
||||
max_loss = abs(entry_price - stop_loss_price) * recommended_shares
|
||||
potential_profit = abs(target_price - entry_price) * recommended_shares
|
||||
risk_reward = potential_profit / max_loss if max_loss > 0 else 0
|
||||
|
||||
# Display results
|
||||
st.markdown("### Position Results")
|
||||
|
||||
col1, col2, col3 = st.columns(3)
|
||||
with col1:
|
||||
st.metric("Recommended Shares", f"{recommended_shares:,}")
|
||||
st.metric("Position Value", f"${position_value:,.2f}")
|
||||
st.metric("% of Buying Power", f"{(position_value/buying_power*100):.2f}%")
|
||||
|
||||
with col2:
|
||||
st.metric("Stop Loss Price", f"${stop_loss_price:.2f}")
|
||||
st.metric("Maximum Loss", f"${max_loss:.2f}")
|
||||
st.metric("% of Daily Limit", f"{(max_loss/max_daily_loss*100):.2f}%")
|
||||
|
||||
with col3:
|
||||
st.metric("Target Price", f"${target_price:.2f}")
|
||||
st.metric("Potential Profit", f"${potential_profit:.2f}")
|
||||
st.metric("Risk/Reward Ratio", f"{risk_reward:.2f}")
|
||||
|
||||
# Show constraint information
|
||||
st.subheader("Position Constraints")
|
||||
|
||||
constraints = {
|
||||
"Risk-based position size": position['shares'],
|
||||
"Daily loss limit": max_shares_by_daily_loss,
|
||||
"Maximum position size": max_shares_by_position_limit,
|
||||
"Fixed dollar risk per trade": max_shares_by_fixed_risk
|
||||
}
|
||||
|
||||
# Determine which constraint is active
|
||||
active_constraint = min(constraints, key=constraints.get)
|
||||
|
||||
for constraint, shares in constraints.items():
|
||||
if constraint == active_constraint:
|
||||
st.warning(f"**{constraint}**: {shares:,} shares (active constraint)")
|
||||
else:
|
||||
st.info(f"{constraint}: {shares:,} shares")
|
||||
|
||||
# Add Monte Carlo metrics if used
|
||||
if use_monte_carlo:
|
||||
st.subheader("Monte Carlo Analysis")
|
||||
col1, col2, col3 = st.columns(3)
|
||||
with col1:
|
||||
st.metric("Stop Loss Price", f"${stop_loss_price:.2f}")
|
||||
st.metric("Stop Loss %", f"{stop_loss_percentage:.2f}%")
|
||||
with col2:
|
||||
st.metric("Target Price", f"${target_price:.2f}")
|
||||
st.metric("Target %", f"{((target_price - entry_price) / entry_price * 100):.2f}%")
|
||||
with col3:
|
||||
st.metric("Days Projected", f"{days_out}")
|
||||
st.metric("Confidence Level", f"{confidence_level}%")
|
||||
|
||||
# Add to watchlist option
|
||||
st.divider()
|
||||
st.subheader("Save to Watch List")
|
||||
if st.button("Prepare for Watch List", key="prop_prepare_watchlist"):
|
||||
st.session_state.prefill_watchlist = {
|
||||
'ticker': ticker,
|
||||
'entry_price': float(entry_price),
|
||||
'target_price': float(target_price),
|
||||
'stop_loss': float(stop_loss_price),
|
||||
'shares': recommended_shares
|
||||
}
|
||||
st.success("Details saved! Switch to Watch Lists tab to complete adding to your watch list.")
|
||||
|
||||
# Add this after the position results section in the prop firm calculator tab
|
||||
if use_scaling and recommended_shares > 0:
|
||||
st.divider()
|
||||
st.subheader("Position Scaling Strategy")
|
||||
|
||||
# Calculate scaling strategy
|
||||
initial_shares = recommended_shares
|
||||
max_total_shares = int(initial_shares * max_scaling_factor)
|
||||
|
||||
# Calculate shares for each entry point
|
||||
entry_points = []
|
||||
|
||||
# First entry is the initial position
|
||||
entry_points.append({
|
||||
"entry_num": 1,
|
||||
"price": entry_price,
|
||||
"shares": initial_shares,
|
||||
"value": initial_shares * entry_price,
|
||||
"cumulative_shares": initial_shares,
|
||||
"cumulative_value": initial_shares * entry_price
|
||||
})
|
||||
|
||||
# Calculate remaining entry points
|
||||
for i in range(2, num_entries + 1):
|
||||
# Calculate price for this entry
|
||||
price_movement = (price_increment / 100) * entry_price * (i - 1)
|
||||
entry_price_i = entry_price + price_movement
|
||||
|
||||
# Calculate shares for this entry based on scaling method
|
||||
if scaling_method == "Equal Size":
|
||||
shares_i = initial_shares
|
||||
elif scaling_method == "Increasing Size":
|
||||
shares_i = int(initial_shares * (1 + (i - 1) * 0.5))
|
||||
else: # Decreasing Size
|
||||
shares_i = int(initial_shares * (1 - (i - 1) * 0.25))
|
||||
shares_i = max(shares_i, int(initial_shares * 0.25)) # Ensure minimum size
|
||||
|
||||
# Ensure we don't exceed max total shares
|
||||
cumulative_shares = entry_points[-1]["cumulative_shares"] + shares_i
|
||||
if cumulative_shares > max_total_shares:
|
||||
shares_i = max_total_shares - entry_points[-1]["cumulative_shares"]
|
||||
if shares_i <= 0:
|
||||
break
|
||||
|
||||
# Add entry point
|
||||
entry_points.append({
|
||||
"entry_num": i,
|
||||
"price": entry_price_i,
|
||||
"shares": shares_i,
|
||||
"value": shares_i * entry_price_i,
|
||||
"cumulative_shares": cumulative_shares,
|
||||
"cumulative_value": entry_points[-1]["cumulative_value"] + (shares_i * entry_price_i)
|
||||
})
|
||||
|
||||
# Display scaling strategy
|
||||
st.markdown("#### Ladder Entry Strategy")
|
||||
|
||||
# Create a table for the entry points
|
||||
data = []
|
||||
for ep in entry_points:
|
||||
data.append([
|
||||
f"Entry #{ep['entry_num']}",
|
||||
f"${ep['price']:.2f}",
|
||||
f"{ep['shares']:,}",
|
||||
f"${ep['value']:.2f}",
|
||||
f"{ep['cumulative_shares']:,}",
|
||||
f"${ep['cumulative_value']:.2f}"
|
||||
])
|
||||
|
||||
st.table({
|
||||
"Entry Point": [row[0] for row in data],
|
||||
"Price": [row[1] for row in data],
|
||||
"Shares": [row[2] for row in data],
|
||||
"Position Value": [row[3] for row in data],
|
||||
"Cumulative Shares": [row[4] for row in data],
|
||||
"Cumulative Value": [row[5] for row in data]
|
||||
})
|
||||
|
||||
# Calculate average entry price after all entries
|
||||
if entry_points:
|
||||
final_entry = entry_points[-1]
|
||||
|
||||
# Calculate weighted average entry price correctly
|
||||
total_value = 0
|
||||
total_shares = 0
|
||||
|
||||
for entry in entry_points:
|
||||
total_value += entry["price"] * entry["shares"]
|
||||
total_shares += entry["shares"]
|
||||
|
||||
avg_entry = total_value / total_shares if total_shares > 0 else 0
|
||||
|
||||
max_position_value = final_entry["cumulative_value"]
|
||||
max_position_pct = (max_position_value / buying_power) * 100
|
||||
|
||||
col1, col2 = st.columns(2)
|
||||
with col1:
|
||||
st.metric("Final Position Size", f"{final_entry['cumulative_shares']:,} shares")
|
||||
st.metric("Average Entry Price", f"${avg_entry:.2f}")
|
||||
|
||||
with col2:
|
||||
st.metric("Total Position Value", f"${max_position_value:.2f}")
|
||||
st.metric("% of Buying Power", f"{max_position_pct:.2f}%")
|
||||
|
||||
# Risk analysis for the scaled position
|
||||
max_loss_scaled = (avg_entry - stop_loss_price) * final_entry["cumulative_shares"]
|
||||
max_profit_scaled = (target_price - avg_entry) * final_entry["cumulative_shares"]
|
||||
risk_reward_scaled = max_profit_scaled / max_loss_scaled if max_loss_scaled > 0 else 0
|
||||
|
||||
st.markdown("#### Risk Analysis for Scaled Position")
|
||||
col1, col2 = st.columns(2)
|
||||
with col1:
|
||||
st.metric("Maximum Loss", f"${max_loss_scaled:.2f}")
|
||||
st.metric("% of Daily Limit", f"{(max_loss_scaled/max_daily_loss*100):.2f}%")
|
||||
|
||||
with col2:
|
||||
st.metric("Potential Profit", f"${max_profit_scaled:.2f}")
|
||||
st.metric("Risk/Reward Ratio", f"{risk_reward_scaled:.2f}")
|
||||
|
||||
except Exception as e:
|
||||
st.error(f"Error calculating position: {str(e)}")
|
||||
logger.exception("Error in prop firm calculator")
|
||||
|
||||
# Tab 3: Watch Lists
|
||||
with tab3:
|
||||
st.subheader("Watch Lists")
|
||||
|
||||
# Create new watch list
|
||||
with st.expander("Create New Watch List"):
|
||||
new_list_name = st.text_input("Watch List Name")
|
||||
strategy = st.selectbox(
|
||||
"Strategy",
|
||||
options=["SunnyBand", "Heikin Ashi", "Three ATR EMA", "Other"],
|
||||
index=3
|
||||
)
|
||||
if st.button("Create Watch List"):
|
||||
if new_list_name:
|
||||
create_watchlist(new_list_name, strategy)
|
||||
st.success(f"Created watch list: {new_list_name}")
|
||||
else:
|
||||
st.error("Please enter a watch list name")
|
||||
|
||||
# Add new item section
|
||||
with st.expander("Add New Item", expanded=bool(st.session_state.prefill_watchlist)):
|
||||
watchlists = get_watchlists()
|
||||
if watchlists:
|
||||
selected_list = st.selectbox(
|
||||
"Select Watch List",
|
||||
options=[(w['id'], w['name']) for w in watchlists],
|
||||
format_func=lambda x: x[1]
|
||||
)
|
||||
|
||||
# Use prefilled data if available
|
||||
prefill = st.session_state.prefill_watchlist or {}
|
||||
|
||||
# Debug output to verify prefill data
|
||||
if prefill:
|
||||
st.write("Debug - Prefill ", prefill)
|
||||
|
||||
col1, col2 = st.columns(2)
|
||||
with col1:
|
||||
ticker = st.text_input("Ticker", value=str(prefill.get('ticker', '')), key="watchlist_ticker")
|
||||
entry_price = st.number_input("Entry Price",
|
||||
value=float(prefill.get('entry_price') or 0.0),
|
||||
min_value=0.0,
|
||||
step=0.01,
|
||||
format="%.2f",
|
||||
key="watchlist_entry_price")
|
||||
shares = st.number_input("Shares",
|
||||
value=int(prefill.get('shares') or 0),
|
||||
min_value=0,
|
||||
step=1,
|
||||
key="watchlist_shares")
|
||||
with col2:
|
||||
target_price = st.number_input("Target Price",
|
||||
value=float(prefill.get('target_price') or 0.0),
|
||||
min_value=0.0,
|
||||
step=0.01,
|
||||
format="%.2f",
|
||||
key="watchlist_target_price")
|
||||
stop_loss = st.number_input("Stop Loss",
|
||||
value=float(prefill.get('stop_loss') or 0.0),
|
||||
min_value=0.0,
|
||||
step=0.01,
|
||||
format="%.2f",
|
||||
key="watchlist_stop_loss")
|
||||
|
||||
notes = st.text_area("Notes")
|
||||
|
||||
if st.button("Add to Watch List"):
|
||||
try:
|
||||
ensure_tables_exist()
|
||||
item = WatchlistItem(
|
||||
ticker=ticker,
|
||||
entry_price=entry_price,
|
||||
target_price=target_price,
|
||||
stop_loss=stop_loss,
|
||||
shares=shares,
|
||||
notes=notes
|
||||
)
|
||||
|
||||
success = add_to_watchlist(selected_list[0], item)
|
||||
|
||||
if success:
|
||||
st.success(f"Added {ticker} to watchlist!")
|
||||
# Clear the prefill data
|
||||
st.session_state.prefill_watchlist = None
|
||||
time.sleep(1)
|
||||
st.rerun()
|
||||
else:
|
||||
st.error("Failed to add to watchlist")
|
||||
|
||||
except Exception as e:
|
||||
st.error(f"Error: {str(e)}")
|
||||
logger.exception("Error adding to watchlist")
|
||||
else:
|
||||
st.warning("Please create a watch list first")
|
||||
|
||||
# Display watch lists
|
||||
st.subheader("Current Watch Lists")
|
||||
watchlists = get_watchlists()
|
||||
if watchlists:
|
||||
selected_watchlist = st.selectbox(
|
||||
"Select Watch List to View",
|
||||
options=[(w['id'], w['name']) for w in watchlists],
|
||||
format_func=lambda x: x[1],
|
||||
key="view_watchlist" # Add a unique key to avoid conflicts
|
||||
)
|
||||
|
||||
items = get_watchlist_items(selected_watchlist[0])
|
||||
if items:
|
||||
for item in items:
|
||||
with st.container():
|
||||
col1, col2, col3, col4, col5, col6 = st.columns([2, 2, 2, 2, 1, 1])
|
||||
with col1:
|
||||
st.write(f"**{item.ticker}**")
|
||||
with col2:
|
||||
st.write(f"Entry: ${item.entry_price:.2f}")
|
||||
with col3:
|
||||
st.write(f"Target: ${item.target_price:.2f}")
|
||||
with col4:
|
||||
st.write(f"Stop: ${item.stop_loss:.2f}")
|
||||
with col5:
|
||||
st.write(f"Shares: {item.shares:,}")
|
||||
with col6:
|
||||
if st.button("Remove", key=f"remove_{item.id}"):
|
||||
if remove_from_watchlist(item.id):
|
||||
st.rerun()
|
||||
if item.notes:
|
||||
st.info(item.notes)
|
||||
st.divider()
|
||||
else:
|
||||
st.info("No items in this watch list")
|
||||
else:
|
||||
st.info("No watch lists created yet")
|
||||
@ -1,12 +0,0 @@
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
# Add the src directory to the Python path
|
||||
src_path = str(Path(__file__).parent)
|
||||
if src_path not in sys.path:
|
||||
sys.path.append(src_path)
|
||||
|
||||
from migrations.add_direction_field import migrate_trades
|
||||
|
||||
if __name__ == "__main__":
|
||||
migrate_trades()
|
||||
@ -1 +1,10 @@
|
||||
# Empty file
|
||||
# Add explicit imports for scanner modules
|
||||
from .t_atr_ema import run_atr_ema_scanner
|
||||
from .t_atr_ema_v2 import run_atr_ema_scanner_v2
|
||||
from .t_sunnyband import run_sunny_scanner
|
||||
|
||||
__all__ = [
|
||||
'run_atr_ema_scanner',
|
||||
'run_atr_ema_scanner_v2',
|
||||
'run_sunny_scanner'
|
||||
]
|
||||
|
||||
@ -1,74 +0,0 @@
|
||||
from screener.data_fetcher import validate_date_range, fetch_financial_data, get_stocks_in_time_range
|
||||
from screener.user_input import get_user_screener_selection
|
||||
from screener.c_canslim import check_quarterly_earnings, check_return_on_equity, check_sales_growth
|
||||
from screener.a_canslim import check_annual_eps_growth
|
||||
from screener.l_canslim import check_industry_leadership
|
||||
from screener.i_canslim import check_institutional_sponsorship
|
||||
from screener.csv_appender import append_scores_to_csv
|
||||
|
||||
def run_canslim_screener(start_date=None, end_date=None, selected_screeners=None):
|
||||
"""Run the CANSLIM screener"""
|
||||
if start_date is None or end_date is None:
|
||||
start_date = input("Enter start date (YYYY-MM-DD): ")
|
||||
end_date = input("Enter end date (YYYY-MM-DD): ")
|
||||
|
||||
if selected_screeners is None:
|
||||
selected_screeners = get_user_screener_selection()
|
||||
|
||||
start_date, end_date = validate_date_range(start_date, end_date, required_quarters=4)
|
||||
symbol_list = get_stocks_in_time_range(start_date, end_date)
|
||||
|
||||
if not symbol_list:
|
||||
print("No stocks found within the given date range.")
|
||||
return
|
||||
|
||||
print(f"Processing {len(symbol_list)} stocks within the given date range...\n")
|
||||
|
||||
for symbol in symbol_list:
|
||||
data = fetch_financial_data(symbol, start_date, end_date)
|
||||
process_symbol(symbol, data, selected_screeners)
|
||||
|
||||
print("✅ Scores saved in data/metrics/stock_scores.csv\n")
|
||||
|
||||
def process_symbol(symbol, data, selected_screeners):
|
||||
"""Process individual symbol for CANSLIM screening"""
|
||||
if not data: # Add the condition to check if data is empty
|
||||
print(f"⚠️ Warning: No data returned for {symbol}. Assigning default score.\n")
|
||||
scores = {screener: 0.25 for category in selected_screeners
|
||||
for screener in selected_screeners[category]}
|
||||
else:
|
||||
scores = calculate_scores(symbol, data, selected_screeners)
|
||||
|
||||
scores["Total_Score"] = sum(scores.values())
|
||||
append_scores_to_csv(symbol, scores)
|
||||
|
||||
def calculate_scores(symbol, data, selected_screeners):
|
||||
"""Calculate scores for each selected screener"""
|
||||
scores = {}
|
||||
for category, screeners in selected_screeners.items():
|
||||
for screener, threshold in screeners.items():
|
||||
scores[screener] = get_screener_score(
|
||||
screener, data, symbol, threshold
|
||||
)
|
||||
return scores
|
||||
|
||||
def get_screener_score(screener, data, symbol, threshold):
|
||||
"""Get score for specific screener"""
|
||||
if screener == "EPS_Score":
|
||||
score = check_quarterly_earnings(data.get("quarterly_eps", []))
|
||||
elif screener == "Annual_EPS_Score":
|
||||
score = check_annual_eps_growth(data.get("annual_eps", []))
|
||||
elif screener == "Sales_Score":
|
||||
score = check_sales_growth(data.get("sales_growth", []))
|
||||
elif screener == "ROE_Score":
|
||||
score = check_return_on_equity(data.get("roe", []))
|
||||
elif screener == "L_Score":
|
||||
score = check_industry_leadership(symbol)
|
||||
print(f"🟢 {symbol} - L_Score: {score}")
|
||||
elif screener == "I_Score":
|
||||
score = check_institutional_sponsorship(symbol)
|
||||
print(f"🏢 {symbol} - I_Score: {score}")
|
||||
else:
|
||||
score = 0
|
||||
|
||||
return score >= threshold if isinstance(threshold, (int, float)) else score
|
||||
@ -1,43 +0,0 @@
|
||||
from datetime import datetime
|
||||
from screener.t_sunnyband import run_sunny_scanner
|
||||
from screener.t_atr_ema import run_atr_ema_scanner
|
||||
from screener.t_atr_ema_v2 import run_atr_ema_scanner_v2
|
||||
from screener.t_heikinashi import run_heikin_ashi_scanner
|
||||
from screener.t_candlestick import run_candlestick_scanner
|
||||
from screener.t_sunny_sma import run_sunny_sma_scanner
|
||||
|
||||
def run_technical_scanner(scanner_choice: str, start_date: str, end_date: str,
|
||||
min_price: float, max_price: float, min_volume: int,
|
||||
portfolio_size: float, interval: str = "1d",
|
||||
selected_patterns: list = None):
|
||||
"""
|
||||
Run the selected technical scanner with provided parameters
|
||||
|
||||
Args:
|
||||
scanner_choice (str): Type of scanner to run
|
||||
start_date (str): Start date in YYYY-MM-DD format
|
||||
end_date (str): End date in YYYY-MM-DD format
|
||||
min_price (float): Minimum stock price
|
||||
max_price (float): Maximum stock price
|
||||
min_volume (int): Minimum volume
|
||||
portfolio_size (float): Portfolio size for position sizing
|
||||
interval (str): Time interval for data (default: "1d")
|
||||
"""
|
||||
# Convert string dates to datetime objects
|
||||
start_dt = datetime.strptime(start_date, "%Y-%m-%d")
|
||||
end_dt = datetime.strptime(end_date, "%Y-%m-%d")
|
||||
|
||||
scanner_map = {
|
||||
"sunnybands": lambda: run_sunny_scanner(min_price, max_price, min_volume, portfolio_size, interval, start_dt, end_dt),
|
||||
"atr-ema": lambda: run_atr_ema_scanner(min_price, max_price, min_volume, portfolio_size, interval, start_dt, end_dt),
|
||||
"atr-ema_v2": lambda: run_atr_ema_scanner_v2(min_price, max_price, min_volume, portfolio_size, interval, start_dt, end_dt),
|
||||
"heikin-ashi": lambda: run_heikin_ashi_scanner(min_price, max_price, min_volume, portfolio_size, interval, start_dt, end_dt),
|
||||
"candlestick": lambda: run_candlestick_scanner(min_price, max_price, min_volume, portfolio_size, interval, start_dt, end_dt, selected_patterns),
|
||||
"sunny-sma": lambda: run_sunny_sma_scanner(min_price, max_price, min_volume, portfolio_size, interval, start_dt, end_dt)
|
||||
}
|
||||
|
||||
scanner_func = scanner_map.get(scanner_choice)
|
||||
if scanner_func:
|
||||
return scanner_func()
|
||||
else:
|
||||
raise ValueError(f"Invalid scanner choice: {scanner_choice}")
|
||||
@ -1,7 +1,10 @@
|
||||
from datetime import datetime
|
||||
from screener.user_input import get_interval_choice, get_date_range
|
||||
import os
|
||||
from datetime import datetime, timedelta
|
||||
import pandas as pd
|
||||
from utils.scanner_utils import initialize_scanner, process_signal_data
|
||||
from utils.data_utils import get_stock_data, validate_signal_date, print_signal, save_signals_to_csv
|
||||
from db.db_connection import create_client
|
||||
from trading.position_calculator import PositionCalculator
|
||||
from utils.data_utils import get_stock_data, validate_signal_date, print_signal, save_signals_to_csv, get_qualified_stocks
|
||||
from indicators.three_atr_ema import ThreeATREMAIndicator
|
||||
|
||||
def check_entry_signal(df: pd.DataFrame) -> list:
|
||||
@ -57,50 +60,79 @@ def check_entry_signal(df: pd.DataFrame) -> list:
|
||||
|
||||
return signals
|
||||
|
||||
def run_atr_ema_scanner(min_price: float, max_price: float, min_volume: int,
|
||||
portfolio_size: float = None, interval: str = "1d",
|
||||
start_date: datetime = None, end_date: datetime = None) -> None:
|
||||
def run_atr_ema_scanner(min_price: float, max_price: float, min_volume: int, portfolio_size: float = None) -> None:
|
||||
print(f"\nScanning for stocks ${min_price:.2f}-${max_price:.2f} with min volume {min_volume:,}")
|
||||
|
||||
# Get time interval
|
||||
interval = get_interval_choice()
|
||||
|
||||
start_date, end_date = get_date_range()
|
||||
start_ts = int(start_date.timestamp() * 1000000000)
|
||||
end_ts = int(end_date.timestamp() * 1000000000)
|
||||
|
||||
try:
|
||||
# Initialize scanner components with all parameters
|
||||
interval, start_date, end_date, qualified_stocks, calculator = initialize_scanner(
|
||||
min_price=min_price,
|
||||
max_price=max_price,
|
||||
min_volume=min_volume,
|
||||
portfolio_size=portfolio_size,
|
||||
interval=interval,
|
||||
start_date=start_date,
|
||||
end_date=end_date
|
||||
)
|
||||
qualified_stocks = get_qualified_stocks(start_date, end_date, min_price, max_price, min_volume)
|
||||
|
||||
if not qualified_stocks:
|
||||
print("No stocks found matching criteria.")
|
||||
return
|
||||
|
||||
print(f"\nFound {len(qualified_stocks)} stocks matching criteria")
|
||||
|
||||
# Initialize indicators
|
||||
indicator = ThreeATREMAIndicator()
|
||||
calculator = None
|
||||
if portfolio_size and portfolio_size > 0:
|
||||
calculator = PositionCalculator(
|
||||
account_size=portfolio_size,
|
||||
risk_percentage=1.0,
|
||||
stop_loss_percentage=7.0 # Explicitly set 7% stop
|
||||
)
|
||||
|
||||
bullish_signals = []
|
||||
|
||||
for ticker, current_price, current_volume, last_update, stock_type in qualified_stocks:
|
||||
for ticker, current_price, current_volume, last_update in qualified_stocks:
|
||||
try:
|
||||
# Get historical data based on interval
|
||||
df = get_stock_data(ticker, start_date, end_date, interval)
|
||||
|
||||
if df.empty or len(df) < 21: # Need at least 21 bars for EMA
|
||||
if df.empty or len(df) < 50: # Need at least 50 bars for the indicator
|
||||
continue
|
||||
|
||||
results = indicator.calculate(df)
|
||||
|
||||
# Check for signals throughout the date range
|
||||
signals = check_entry_signal(df)
|
||||
for signal, signal_date, signal_data in signals:
|
||||
signal_data['date'] = signal_date
|
||||
entry_data = process_signal_data(
|
||||
ticker, signal_data, current_volume,
|
||||
last_update, stock_type, calculator
|
||||
)
|
||||
entry_data = {
|
||||
'ticker': ticker,
|
||||
'entry_price': signal_data['price'],
|
||||
'target_price': signal_data['ema'],
|
||||
'volume': current_volume,
|
||||
'signal_date': signal_date,
|
||||
'last_update': datetime.fromtimestamp(last_update/1000000000)
|
||||
}
|
||||
|
||||
if calculator:
|
||||
position = calculator.calculate_position_size(entry_data['entry_price'])
|
||||
potential_profit = (entry_data['target_price'] - entry_data['entry_price']) * position['shares']
|
||||
entry_data.update({
|
||||
'shares': position['shares'],
|
||||
'position_size': position['position_value'],
|
||||
'stop_loss': position['stop_loss'],
|
||||
'risk_amount': position['potential_loss'],
|
||||
'profit_amount': potential_profit,
|
||||
'risk_reward_ratio': abs(potential_profit / position['potential_loss']) if position['potential_loss'] != 0 else 0
|
||||
})
|
||||
|
||||
bullish_signals.append(entry_data)
|
||||
print_signal(entry_data, "🟢")
|
||||
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error processing {ticker}: {str(e)}")
|
||||
continue
|
||||
|
||||
save_signals_to_csv(bullish_signals, 'atr_ema')
|
||||
return bullish_signals
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error during scan: {str(e)}")
|
||||
return []
|
||||
|
||||
@ -3,11 +3,7 @@ import pandas as pd
|
||||
import os
|
||||
from db.db_connection import create_client
|
||||
from trading.position_calculator import PositionCalculator
|
||||
from utils.data_utils import (
|
||||
get_stock_data, validate_signal_date, print_signal,
|
||||
save_signals_to_csv, get_qualified_stocks
|
||||
)
|
||||
from utils.scanner_utils import initialize_scanner, process_signal_data
|
||||
from utils.data_utils import get_stock_data, validate_signal_date, print_signal, save_signals_to_csv, get_qualified_stocks
|
||||
from screener.user_input import get_interval_choice, get_date_range
|
||||
from indicators.three_atr_ema import ThreeATREMAIndicator
|
||||
|
||||
@ -74,18 +70,35 @@ def run_atr_ema_scanner_v2(min_price: float, max_price: float, min_volume: int,
|
||||
min_volume (int): Minimum trading volume
|
||||
portfolio_size (float, optional): Portfolio size for position sizing
|
||||
"""
|
||||
print(f"\nScanning for stocks ${min_price:.2f}-${max_price:.2f} with min volume {min_volume:,}")
|
||||
|
||||
interval = get_interval_choice()
|
||||
|
||||
start_date, end_date = get_date_range()
|
||||
start_ts = int(start_date.timestamp() * 1000000000)
|
||||
end_ts = int(end_date.timestamp() * 1000000000)
|
||||
|
||||
try:
|
||||
# Initialize scanner components
|
||||
interval, start_date, end_date, qualified_stocks, calculator = initialize_scanner(
|
||||
min_price, max_price, min_volume, portfolio_size
|
||||
)
|
||||
qualified_stocks = get_qualified_stocks(start_date, end_date, min_price, max_price, min_volume)
|
||||
|
||||
if not qualified_stocks:
|
||||
print("No stocks found matching criteria.")
|
||||
return
|
||||
|
||||
|
||||
print(f"\nFound {len(qualified_stocks)} stocks matching criteria")
|
||||
|
||||
# Initialize position calculator if portfolio size provided
|
||||
calculator = None
|
||||
if portfolio_size and portfolio_size > 0:
|
||||
calculator = PositionCalculator(
|
||||
account_size=portfolio_size,
|
||||
risk_percentage=1.0,
|
||||
stop_loss_percentage=7.0
|
||||
)
|
||||
|
||||
entry_signals = []
|
||||
|
||||
for ticker, current_price, current_volume, last_update, stock_type in qualified_stocks:
|
||||
for ticker, current_price, current_volume, last_update in qualified_stocks:
|
||||
try:
|
||||
df = get_stock_data(ticker, start_date, end_date, interval)
|
||||
|
||||
@ -94,11 +107,27 @@ def run_atr_ema_scanner_v2(min_price: float, max_price: float, min_volume: int,
|
||||
|
||||
signals = check_entry_signal(df)
|
||||
for signal, signal_date, signal_data in signals:
|
||||
signal_data['date'] = signal_date
|
||||
entry_data = process_signal_data(
|
||||
ticker, signal_data, current_volume,
|
||||
last_update, stock_type, calculator
|
||||
)
|
||||
entry_data = {
|
||||
'ticker': ticker,
|
||||
'entry_price': signal_data['price'],
|
||||
'target_price': signal_data['ema'],
|
||||
'volume': current_volume,
|
||||
'signal_date': signal_date,
|
||||
'last_update': datetime.fromtimestamp(last_update/1000000000)
|
||||
}
|
||||
|
||||
if calculator:
|
||||
position = calculator.calculate_position_size(entry_data['entry_price'])
|
||||
potential_profit = (entry_data['target_price'] - entry_data['entry_price']) * position['shares']
|
||||
entry_data.update({
|
||||
'shares': position['shares'],
|
||||
'position_size': position['position_value'],
|
||||
'stop_loss': position['stop_loss'],
|
||||
'risk_amount': position['potential_loss'],
|
||||
'profit_amount': potential_profit,
|
||||
'risk_reward_ratio': abs(potential_profit / position['potential_loss']) if position['potential_loss'] != 0 else 0
|
||||
})
|
||||
|
||||
entry_signals.append(entry_data)
|
||||
print_signal(entry_data)
|
||||
|
||||
@ -107,8 +136,6 @@ def run_atr_ema_scanner_v2(min_price: float, max_price: float, min_volume: int,
|
||||
continue
|
||||
|
||||
save_signals_to_csv(entry_signals, 'atr_ema_v2')
|
||||
return entry_signals
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error during scan: {str(e)}")
|
||||
return []
|
||||
|
||||
@ -1,183 +0,0 @@
|
||||
import pandas as pd
|
||||
from datetime import datetime, timedelta
|
||||
import talib
|
||||
from db.db_connection import create_client
|
||||
from utils.data_utils import (
|
||||
get_stock_data, validate_signal_date, print_signal,
|
||||
save_signals_to_csv, get_qualified_stocks
|
||||
)
|
||||
from utils.scanner_utils import initialize_scanner, process_signal_data
|
||||
from trading.position_calculator import PositionCalculator
|
||||
|
||||
# Dictionary mapping pattern names to their functions and descriptions
|
||||
CANDLESTICK_PATTERNS = {
|
||||
'BULLISH_ENGULFING': {
|
||||
'function': talib.CDLENGULFING,
|
||||
'description': 'Bullish Engulfing Pattern'
|
||||
},
|
||||
'HAMMER': {
|
||||
'function': talib.CDLHAMMER,
|
||||
'description': 'Hammer Pattern'
|
||||
},
|
||||
'MORNING_STAR': {
|
||||
'function': talib.CDLMORNINGSTAR,
|
||||
'description': 'Morning Star Pattern'
|
||||
},
|
||||
'PIERCING_LINE': {
|
||||
'function': talib.CDLPIERCING,
|
||||
'description': 'Piercing Line Pattern'
|
||||
},
|
||||
'THREE_WHITE_SOLDIERS': {
|
||||
'function': talib.CDL3WHITESOLDIERS,
|
||||
'description': 'Three White Soldiers Pattern'
|
||||
},
|
||||
'MORNING_DOJI_STAR': {
|
||||
'function': talib.CDLMORNINGDOJISTAR,
|
||||
'description': 'Morning Doji Star Pattern'
|
||||
},
|
||||
'DRAGONFLY_DOJI': {
|
||||
'function': talib.CDLDRAGONFLYDOJI,
|
||||
'description': 'Dragonfly Doji Pattern'
|
||||
},
|
||||
'HARAMI': {
|
||||
'function': talib.CDLHARAMI,
|
||||
'description': 'Bullish Harami Pattern'
|
||||
},
|
||||
'INVERTED_HAMMER': {
|
||||
'function': talib.CDLINVERTEDHAMMER,
|
||||
'description': 'Inverted Hammer Pattern'
|
||||
},
|
||||
'THREE_INSIDE_UP': {
|
||||
'function': talib.CDL3INSIDE,
|
||||
'description': 'Three Inside Up Pattern'
|
||||
},
|
||||
'THREE_OUTSIDE_UP': {
|
||||
'function': talib.CDL3OUTSIDE,
|
||||
'description': 'Three Outside Up Pattern'
|
||||
},
|
||||
'BELT_HOLD': {
|
||||
'function': talib.CDLBELTHOLD,
|
||||
'description': 'Bullish Belt Hold Pattern'
|
||||
},
|
||||
'LADDER_BOTTOM': {
|
||||
'function': talib.CDLLADDERBOTTOM,
|
||||
'description': 'Ladder Bottom Pattern'
|
||||
},
|
||||
'MATCHING_LOW': {
|
||||
'function': talib.CDLMATCHINGLOW,
|
||||
'description': 'Matching Low Pattern'
|
||||
},
|
||||
'UNIQUE_THREE_RIVER': {
|
||||
'function': talib.CDLUNIQUE3RIVER,
|
||||
'description': 'Unique Three River Bottom Pattern'
|
||||
}
|
||||
}
|
||||
|
||||
def check_entry_signal(df: pd.DataFrame, selected_patterns: list = None) -> list:
|
||||
"""
|
||||
Check for bullish candlestick patterns across the entire date range
|
||||
|
||||
Args:
|
||||
df (pd.DataFrame): DataFrame with OHLCV data
|
||||
selected_patterns (list): List of patterns to scan for
|
||||
|
||||
Returns:
|
||||
list: List of tuples (signal, date, signal_data) for each signal found
|
||||
"""
|
||||
if len(df) < 14: # Need minimum bars for pattern recognition
|
||||
return []
|
||||
|
||||
signals = []
|
||||
|
||||
# Use selected patterns or all patterns if none selected
|
||||
patterns_to_scan = {k: v for k, v in CANDLESTICK_PATTERNS.items()
|
||||
if selected_patterns is None or k in selected_patterns}
|
||||
|
||||
# Calculate patterns
|
||||
pattern_signals = {}
|
||||
for pattern_name, pattern_info in patterns_to_scan.items():
|
||||
pattern_signals[pattern_name] = pattern_info['function'](
|
||||
df['open'].values,
|
||||
df['high'].values,
|
||||
df['low'].values,
|
||||
df['close'].values
|
||||
)
|
||||
|
||||
# Look for signals across all candles
|
||||
for i in range(14, len(df)): # Start after lookback period
|
||||
found_patterns = []
|
||||
|
||||
for pattern_name, pattern_values in pattern_signals.items():
|
||||
# Check if we have a bullish signal (value > 0)
|
||||
if pattern_values[i] > 0:
|
||||
found_patterns.append(CANDLESTICK_PATTERNS[pattern_name]['description'])
|
||||
|
||||
if found_patterns:
|
||||
signal_data = {
|
||||
'price': df.iloc[i]['close'],
|
||||
'patterns': ', '.join(found_patterns),
|
||||
'pattern_count': len(found_patterns)
|
||||
}
|
||||
signals.append((True, df.iloc[i]['date'], signal_data))
|
||||
|
||||
return signals
|
||||
|
||||
def run_candlestick_scanner(min_price: float, max_price: float, min_volume: int,
|
||||
portfolio_size: float = None, interval: str = "1d",
|
||||
start_date: datetime = None, end_date: datetime = None,
|
||||
selected_patterns: list = None) -> None:
|
||||
"""
|
||||
Run candlestick pattern scanner to find bullish patterns
|
||||
"""
|
||||
try:
|
||||
# Initialize scanner components
|
||||
interval, start_date, end_date, qualified_stocks, calculator = initialize_scanner(
|
||||
min_price=min_price,
|
||||
max_price=max_price,
|
||||
min_volume=min_volume,
|
||||
portfolio_size=portfolio_size,
|
||||
interval=interval,
|
||||
start_date=start_date,
|
||||
end_date=end_date
|
||||
)
|
||||
|
||||
if not qualified_stocks:
|
||||
return
|
||||
|
||||
bullish_signals = []
|
||||
|
||||
for ticker, current_price, current_volume, last_update, stock_type in qualified_stocks:
|
||||
try:
|
||||
df = get_stock_data(ticker, start_date, end_date, interval)
|
||||
|
||||
if df.empty or len(df) < 14: # Need minimum bars
|
||||
continue
|
||||
|
||||
signals = check_entry_signal(df, selected_patterns)
|
||||
for signal, signal_date, signal_data in signals:
|
||||
entry_data = {
|
||||
'ticker': ticker,
|
||||
'signal_date': signal_date,
|
||||
'entry_price': signal_data['price'],
|
||||
'patterns': signal_data['patterns'],
|
||||
'pattern_count': signal_data['pattern_count'],
|
||||
'volume': current_volume,
|
||||
'last_update': last_update,
|
||||
'stock_type': stock_type
|
||||
}
|
||||
bullish_signals.append(entry_data)
|
||||
|
||||
# Custom print for candlestick signals
|
||||
print(f"🕯️ {ticker}: {entry_data['patterns']} on {signal_date.strftime('%Y-%m-%d')} "
|
||||
f"at ${signal_data['price']:.2f}")
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error processing {ticker}: {str(e)}")
|
||||
continue
|
||||
|
||||
save_signals_to_csv(bullish_signals, 'candlestick')
|
||||
return bullish_signals
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error during scan: {str(e)}")
|
||||
return []
|
||||
@ -1,119 +0,0 @@
|
||||
import pandas as pd
|
||||
from datetime import datetime, timedelta
|
||||
from db.db_connection import create_client
|
||||
from utils.data_utils import (
|
||||
get_stock_data, validate_signal_date, print_signal,
|
||||
save_signals_to_csv, get_qualified_stocks
|
||||
)
|
||||
from utils.scanner_utils import initialize_scanner, process_signal_data
|
||||
from trading.position_calculator import PositionCalculator
|
||||
|
||||
def calculate_heikin_ashi(df: pd.DataFrame) -> pd.DataFrame:
|
||||
"""Calculate Heikin Ashi candles from regular OHLC data"""
|
||||
ha_close = (df['open'] + df['high'] + df['low'] + df['close']) / 4
|
||||
ha_open = pd.Series(index=df.index)
|
||||
ha_open.iloc[0] = df['open'].iloc[0]
|
||||
for i in range(1, len(df)):
|
||||
ha_open.iloc[i] = (ha_open.iloc[i-1] + ha_close.iloc[i-1]) / 2
|
||||
ha_high = df[['high', 'open', 'close']].max(axis=1)
|
||||
ha_low = df[['low', 'open', 'close']].min(axis=1)
|
||||
|
||||
return pd.DataFrame({
|
||||
'ha_open': ha_open,
|
||||
'ha_high': ha_high,
|
||||
'ha_low': ha_low,
|
||||
'ha_close': ha_close
|
||||
})
|
||||
|
||||
def check_entry_signal(df: pd.DataFrame) -> list:
|
||||
"""
|
||||
Check for bullish Heikin Ashi signals
|
||||
|
||||
Args:
|
||||
df (pd.DataFrame): DataFrame with price data
|
||||
|
||||
Returns:
|
||||
list: List of tuples (signal, date, signal_data) for each signal found
|
||||
"""
|
||||
if len(df) < 3: # Need at least 3 bars for comparison
|
||||
return []
|
||||
|
||||
# Calculate Heikin Ashi values
|
||||
ha_df = calculate_heikin_ashi(df)
|
||||
signals = []
|
||||
|
||||
# Start from index 2 to compare with previous candles
|
||||
for i in range(2, len(df)):
|
||||
current = ha_df.iloc[i]
|
||||
prev = ha_df.iloc[i-1]
|
||||
prev2 = ha_df.iloc[i-2]
|
||||
|
||||
# Bullish signal conditions:
|
||||
# 1. Current candle is bullish (close > open)
|
||||
# 2. Previous candle was bullish
|
||||
# 3. Previous to previous candle was bearish (transition point)
|
||||
if (current['ha_close'] > current['ha_open'] and
|
||||
prev['ha_close'] > prev['ha_open'] and
|
||||
prev2['ha_close'] < prev2['ha_open']):
|
||||
|
||||
signal_data = {
|
||||
'price': df.iloc[i]['close'],
|
||||
'ha_open': current['ha_open'],
|
||||
'ha_close': current['ha_close'],
|
||||
'ha_high': current['ha_high'],
|
||||
'ha_low': current['ha_low']
|
||||
}
|
||||
signals.append((True, df.iloc[i]['date'], signal_data))
|
||||
|
||||
return signals
|
||||
|
||||
def run_heikin_ashi_scanner(min_price: float, max_price: float, min_volume: int,
|
||||
portfolio_size: float = None, interval: str = "1d",
|
||||
start_date: datetime = None, end_date: datetime = None) -> None:
|
||||
"""
|
||||
Run Heikin Ashi scanner to find bullish reversal patterns
|
||||
"""
|
||||
try:
|
||||
# Initialize scanner components
|
||||
interval, start_date, end_date, qualified_stocks, calculator = initialize_scanner(
|
||||
min_price=min_price,
|
||||
max_price=max_price,
|
||||
min_volume=min_volume,
|
||||
portfolio_size=portfolio_size,
|
||||
interval=interval,
|
||||
start_date=start_date,
|
||||
end_date=end_date
|
||||
)
|
||||
|
||||
if not qualified_stocks:
|
||||
return
|
||||
|
||||
bullish_signals = []
|
||||
|
||||
for ticker, current_price, current_volume, last_update, stock_type in qualified_stocks:
|
||||
try:
|
||||
df = get_stock_data(ticker, start_date, end_date, interval)
|
||||
|
||||
if df.empty or len(df) < 3: # Need at least 3 bars
|
||||
continue
|
||||
|
||||
signals = check_entry_signal(df)
|
||||
for signal, signal_date, signal_data in signals:
|
||||
signal_data['date'] = signal_date
|
||||
entry_data = process_signal_data(
|
||||
ticker, signal_data, current_volume,
|
||||
last_update, stock_type, calculator
|
||||
)
|
||||
bullish_signals.append(entry_data)
|
||||
print_signal(entry_data, "🟢")
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error processing {ticker}: {str(e)}")
|
||||
continue
|
||||
|
||||
save_signals_to_csv(bullish_signals, 'heikin_ashi')
|
||||
return bullish_signals
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error during scan: {str(e)}")
|
||||
return []
|
||||
@ -1,125 +0,0 @@
|
||||
import pandas as pd
|
||||
import talib
|
||||
from datetime import datetime
|
||||
from utils.data_utils import (
|
||||
get_stock_data, validate_signal_date, print_signal,
|
||||
save_signals_to_csv, get_qualified_stocks
|
||||
)
|
||||
from utils.scanner_utils import initialize_scanner, process_signal_data
|
||||
from indicators.sunny_bands import SunnyBands
|
||||
|
||||
def check_entry_signal(df: pd.DataFrame) -> list:
|
||||
"""
|
||||
Check for entry signals based on combined Sunny Bands and SMA strategy
|
||||
|
||||
Conditions:
|
||||
1. 21 SMA below lowest Sunny Band
|
||||
2. Price crosses above 21 SMA
|
||||
3. Price still below Sunny Bands
|
||||
|
||||
Args:
|
||||
df (pd.DataFrame): DataFrame with price data
|
||||
|
||||
Returns:
|
||||
list: List of tuples (signal, date, signal_data) for each signal found
|
||||
"""
|
||||
if len(df) < 21: # Need at least 21 bars for SMA
|
||||
return []
|
||||
|
||||
# Calculate Sunny Bands
|
||||
sunny = SunnyBands()
|
||||
sunny_results = sunny.calculate(df)
|
||||
|
||||
# Calculate 21 day SMA
|
||||
sma21 = talib.SMA(df['close'].values, timeperiod=21)
|
||||
|
||||
signals = []
|
||||
|
||||
# Start from index 21 to ensure we have enough data for SMA
|
||||
for i in range(21, len(df)):
|
||||
current = df.iloc[i]
|
||||
prev = df.iloc[i-1]
|
||||
current_bands = sunny_results.iloc[i]
|
||||
current_sma = sma21[i]
|
||||
|
||||
# Check conditions:
|
||||
# 1. SMA below lower band
|
||||
sma_below_band = current_sma < current_bands['lower_band']
|
||||
|
||||
# 2. Price crosses above SMA
|
||||
price_cross_sma = (current['close'] > current_sma) and (prev['close'] < sma21[i-1])
|
||||
|
||||
# 3. Price still below lower band
|
||||
price_below_band = current['close'] < current_bands['lower_band']
|
||||
|
||||
if sma_below_band and price_cross_sma and price_below_band:
|
||||
signal_data = {
|
||||
'price': current['close'],
|
||||
'sma21': current_sma,
|
||||
'upper_band': current_bands['upper_band'],
|
||||
'lower_band': current_bands['lower_band'],
|
||||
'dma': current_bands['dma']
|
||||
}
|
||||
signals.append((True, current['date'], signal_data))
|
||||
|
||||
return signals
|
||||
|
||||
def run_sunny_sma_scanner(min_price: float, max_price: float, min_volume: int,
|
||||
portfolio_size: float = None, interval: str = "1d",
|
||||
start_date: datetime = None, end_date: datetime = None) -> None:
|
||||
"""
|
||||
Run scanner combining Sunny Bands and 21 SMA strategy
|
||||
"""
|
||||
try:
|
||||
# Initialize scanner components
|
||||
interval, start_date, end_date, qualified_stocks, calculator = initialize_scanner(
|
||||
min_price=min_price,
|
||||
max_price=max_price,
|
||||
min_volume=min_volume,
|
||||
portfolio_size=portfolio_size,
|
||||
interval=interval,
|
||||
start_date=start_date,
|
||||
end_date=end_date
|
||||
)
|
||||
|
||||
if not qualified_stocks:
|
||||
return
|
||||
|
||||
bullish_signals = []
|
||||
|
||||
for ticker, current_price, current_volume, last_update, stock_type in qualified_stocks:
|
||||
try:
|
||||
df = get_stock_data(ticker, start_date, end_date, interval)
|
||||
|
||||
if df.empty or len(df) < 50: # Need at least 50 bars for the indicators
|
||||
continue
|
||||
|
||||
signals = check_entry_signal(df)
|
||||
for signal, signal_date, signal_data in signals:
|
||||
# Custom print for Sunny-SMA signals
|
||||
print(f"🌞 {ticker}: SMA-21 Cross at ${signal_data['price']:.2f} on {signal_date.strftime('%Y-%m-%d')}")
|
||||
print(f" SMA: ${signal_data['sma21']:.2f}")
|
||||
print(f" Lower Band: ${signal_data['lower_band']:.2f}")
|
||||
|
||||
entry_data = {
|
||||
'ticker': ticker,
|
||||
'signal_date': signal_date,
|
||||
'entry_price': signal_data['price'],
|
||||
'sma21': signal_data['sma21'],
|
||||
'lower_band': signal_data['lower_band'],
|
||||
'volume': current_volume,
|
||||
'last_update': last_update,
|
||||
'stock_type': stock_type
|
||||
}
|
||||
bullish_signals.append(entry_data)
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error processing {ticker}: {str(e)}")
|
||||
continue
|
||||
|
||||
save_signals_to_csv(bullish_signals, 'sunny_sma')
|
||||
return bullish_signals
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error during scan: {str(e)}")
|
||||
return []
|
||||
@ -1,11 +1,8 @@
|
||||
import pandas as pd
|
||||
import os
|
||||
import numpy as np
|
||||
from datetime import datetime, timedelta
|
||||
import pandas as pd
|
||||
from db.db_connection import create_client
|
||||
from utils.data_utils import (
|
||||
get_stock_data, validate_signal_date, print_signal,
|
||||
save_signals_to_csv, get_qualified_stocks
|
||||
)
|
||||
from utils.scanner_utils import initialize_scanner, process_signal_data
|
||||
from indicators.sunny_bands import SunnyBands
|
||||
from trading.position_calculator import PositionCalculator
|
||||
from screener.user_input import get_interval_choice, get_date_range
|
||||
@ -202,50 +199,86 @@ def view_stock_details(ticker: str, interval: str, start_date: datetime, end_dat
|
||||
except Exception as e:
|
||||
print(f"Error analyzing {ticker}: {str(e)}")
|
||||
|
||||
def run_sunny_scanner(min_price: float, max_price: float, min_volume: int,
|
||||
portfolio_size: float = None, interval: str = "1d",
|
||||
start_date: datetime = None, end_date: datetime = None) -> None:
|
||||
def run_sunny_scanner(min_price: float, max_price: float, min_volume: int, portfolio_size: float = None) -> None:
|
||||
print(f"\nScanning for stocks ${min_price:.2f}-${max_price:.2f} with min volume {min_volume:,}")
|
||||
|
||||
interval = get_interval_choice()
|
||||
|
||||
# Get date range from user input
|
||||
start_date, end_date = get_date_range()
|
||||
|
||||
# First get qualified stocks from database
|
||||
# Convert dates to Unix timestamp in nanoseconds
|
||||
end_ts = int(end_date.timestamp() * 1000000000)
|
||||
start_ts = int(start_date.timestamp() * 1000000000)
|
||||
|
||||
try:
|
||||
# Initialize scanner components with all parameters
|
||||
interval, start_date, end_date, qualified_stocks, calculator = initialize_scanner(
|
||||
min_price=min_price,
|
||||
max_price=max_price,
|
||||
min_volume=min_volume,
|
||||
portfolio_size=portfolio_size,
|
||||
interval=interval,
|
||||
start_date=start_date,
|
||||
end_date=end_date
|
||||
)
|
||||
qualified_stocks = get_qualified_stocks(start_date, end_date, min_price, max_price, min_volume)
|
||||
|
||||
if not qualified_stocks:
|
||||
print("No stocks found matching criteria.")
|
||||
return
|
||||
|
||||
bullish_signals = []
|
||||
|
||||
print(f"\nFound {len(qualified_stocks)} stocks matching criteria")
|
||||
|
||||
for ticker, current_price, current_volume, last_update, stock_type in qualified_stocks:
|
||||
# Initialize indicators
|
||||
sunny = SunnyBands()
|
||||
calculator = None
|
||||
if portfolio_size and portfolio_size > 0:
|
||||
calculator = PositionCalculator(
|
||||
account_size=portfolio_size,
|
||||
risk_percentage=1.0,
|
||||
stop_loss_percentage=7.0 # Explicit 7% stop loss
|
||||
)
|
||||
|
||||
bullish_signals = []
|
||||
bearish_signals = []
|
||||
|
||||
# Process each qualified stock
|
||||
for ticker, current_price, current_volume, last_update in qualified_stocks:
|
||||
try:
|
||||
# Get historical data based on interval
|
||||
df = get_stock_data(ticker, start_date, end_date, interval)
|
||||
|
||||
if df.empty or len(df) < 50: # Need at least 50 bars for the indicator
|
||||
continue
|
||||
|
||||
# Check for signals throughout the date range
|
||||
signals = check_entry_signal(df)
|
||||
for signal, signal_date, signal_data in signals:
|
||||
signal_data['date'] = signal_date
|
||||
entry_data = process_signal_data(
|
||||
ticker, signal_data, current_volume,
|
||||
last_update, stock_type, calculator
|
||||
)
|
||||
bullish_signals.append(entry_data)
|
||||
print_signal(entry_data, "🟢")
|
||||
if calculator:
|
||||
try:
|
||||
position = calculator.calculate_position_size(
|
||||
entry_price=signal_data['price'],
|
||||
target_price=signal_data['upper_band']
|
||||
)
|
||||
if position['shares'] > 0:
|
||||
entry_data = {
|
||||
'ticker': ticker,
|
||||
'entry_price': signal_data['price'],
|
||||
'target_price': signal_data['upper_band'],
|
||||
'signal_date': signal_date,
|
||||
'volume': current_volume,
|
||||
'last_update': datetime.fromtimestamp(last_update/1000000000),
|
||||
'shares': position['shares'],
|
||||
'position_size': position['position_value'],
|
||||
'stop_loss': signal_data['price'] * 0.93, # 7% stop loss
|
||||
'risk_amount': position['potential_loss'],
|
||||
'profit_amount': position['potential_profit'],
|
||||
'risk_reward_ratio': position['risk_reward_ratio']
|
||||
}
|
||||
bullish_signals.append(entry_data)
|
||||
print_signal(entry_data, "🟢")
|
||||
|
||||
except ValueError as e:
|
||||
print(f"Skipping {ticker} position: {str(e)}")
|
||||
continue
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error processing {ticker}: {str(e)}")
|
||||
continue
|
||||
|
||||
save_signals_to_csv(bullish_signals, 'sunny')
|
||||
return bullish_signals
|
||||
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error during scan: {str(e)}")
|
||||
return []
|
||||
|
||||
@ -1,66 +0,0 @@
|
||||
import streamlit as st
|
||||
import pandas as pd
|
||||
from datetime import datetime
|
||||
import pytz
|
||||
from db.db_connection import create_client
|
||||
from pages.journal.trading_journal_page import trading_journal_page, format_datetime
|
||||
from pages.screener.technical_scanner_page import technical_scanner_page
|
||||
from pages.trading.trading_system_page import trading_system_page
|
||||
from pages.trading.trading_plan_page import trading_plan_page
|
||||
from pages.backtesting.backtesting_page import backtesting_page
|
||||
from pages.analysis.monte_carlo_page import monte_carlo_page
|
||||
from pages.analysis.ai_forecast_page import ai_forecast_page
|
||||
from trading.journal import (
|
||||
create_trades_table, get_open_trades, get_trade_history,
|
||||
get_latest_portfolio_value, update_portfolio_value
|
||||
)
|
||||
from trading.trading_plan import (
|
||||
create_trading_plan_table,
|
||||
)
|
||||
from pages.rules.strategy_guide_page import strategy_guide_page
|
||||
from pages.screener.canslim_screener_page import canslim_screener_page, load_scanner_reports
|
||||
|
||||
def init_session_state():
|
||||
"""Initialize session state variables"""
|
||||
if 'page' not in st.session_state:
|
||||
st.session_state.page = 'Trading Journal'
|
||||
|
||||
def main():
|
||||
st.set_page_config(page_title="Trading System", layout="wide")
|
||||
init_session_state()
|
||||
|
||||
# Sidebar navigation
|
||||
st.sidebar.title("Navigation")
|
||||
st.session_state.page = st.sidebar.radio(
|
||||
"Go to",
|
||||
["Strategy Guide", "Trading Journal", "Technical Scanner", "CANSLIM Screener",
|
||||
"Trading System", "Trading Plans", "Backtesting", "Monte Carlo Analysis",
|
||||
"AI Stock Forecast"]
|
||||
)
|
||||
|
||||
# Create necessary tables
|
||||
create_trades_table()
|
||||
create_trading_plan_table()
|
||||
|
||||
# Display selected page
|
||||
if st.session_state.page == "Strategy Guide":
|
||||
strategy_guide_page()
|
||||
elif st.session_state.page == "Trading Journal":
|
||||
trading_journal_page()
|
||||
elif st.session_state.page == "Technical Scanner":
|
||||
technical_scanner_page()
|
||||
elif st.session_state.page == "CANSLIM Screener":
|
||||
canslim_screener_page()
|
||||
elif st.session_state.page == "Trading System":
|
||||
trading_system_page()
|
||||
elif st.session_state.page == "Trading Plans":
|
||||
trading_plan_page()
|
||||
elif st.session_state.page == "Backtesting":
|
||||
backtesting_page()
|
||||
elif st.session_state.page == "Monte Carlo Analysis":
|
||||
monte_carlo_page()
|
||||
elif st.session_state.page == "AI Stock Forecast":
|
||||
ai_forecast_page()
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@ -1,983 +0,0 @@
|
||||
from datetime import datetime, timedelta
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
import pytz
|
||||
from zoneinfo import ZoneInfo
|
||||
import yfinance as yf
|
||||
from db.db_connection import create_client
|
||||
from trading.position_calculator import PositionCalculator
|
||||
from utils.data_utils import get_user_input, get_current_prices
|
||||
|
||||
def handle_sell_order(ticker: str, shares_to_sell: int, exit_price: float, exit_date: datetime,
|
||||
order_type: str, followed_rules: bool, exit_reason: str, notes: Optional[str] = None) -> bool:
|
||||
"""
|
||||
Handle sell order using FIFO logic
|
||||
|
||||
Args:
|
||||
ticker (str): Stock ticker
|
||||
shares_to_sell (int): Number of shares to sell
|
||||
exit_price (float): Exit price per share
|
||||
exit_date (datetime): Exit date and time
|
||||
order_type (str): Order type (Market/Limit)
|
||||
followed_rules (bool): Whether trading rules were followed
|
||||
exit_reason (str): Reason for exit
|
||||
notes (Optional[str]): Additional notes
|
||||
|
||||
Returns:
|
||||
bool: True if sell order was processed successfully
|
||||
"""
|
||||
with create_client() as client:
|
||||
# Get open positions for this ticker ordered by entry date (FIFO)
|
||||
query = f"""
|
||||
SELECT id, shares, entry_price, position_id
|
||||
FROM stock_db.trades
|
||||
WHERE ticker = '{ticker}'
|
||||
AND exit_price IS NULL
|
||||
ORDER BY entry_date ASC
|
||||
"""
|
||||
result = client.query(query).result_rows
|
||||
|
||||
if not result:
|
||||
print(f"No open positions found for {ticker}")
|
||||
return False
|
||||
|
||||
remaining_shares = shares_to_sell
|
||||
positions = [dict(zip(['id', 'shares', 'entry_price', 'position_id'], row)) for row in result]
|
||||
total_available_shares = sum(pos['shares'] for pos in positions)
|
||||
|
||||
if shares_to_sell > total_available_shares:
|
||||
print(f"Error: Attempting to sell {shares_to_sell} shares but only {total_available_shares} available")
|
||||
return False
|
||||
|
||||
for position in positions:
|
||||
if remaining_shares <= 0:
|
||||
break
|
||||
|
||||
shares_from_position = min(remaining_shares, position['shares'])
|
||||
|
||||
if shares_from_position == position['shares']:
|
||||
# Close entire position
|
||||
update_trade(
|
||||
trade_id=position['id'],
|
||||
updates={
|
||||
'exit_price': exit_price,
|
||||
'exit_date': exit_date,
|
||||
'followed_rules': 1 if followed_rules else 0,
|
||||
'exit_reason': exit_reason,
|
||||
'notes': notes
|
||||
}
|
||||
)
|
||||
else:
|
||||
# Split position: update original position with remaining shares
|
||||
# and create a new closed position for the sold shares
|
||||
new_position_shares = position['shares'] - shares_from_position
|
||||
|
||||
# Update original position with reduced shares
|
||||
client.command(f"""
|
||||
ALTER TABLE stock_db.trades
|
||||
UPDATE shares = {new_position_shares}
|
||||
WHERE id = {position['id']}
|
||||
""")
|
||||
|
||||
# Create new record for the sold portion
|
||||
trade = TradeEntry(
|
||||
ticker=ticker,
|
||||
entry_date=datetime.now(), # Use original entry date?
|
||||
shares=shares_from_position,
|
||||
entry_price=position['entry_price'],
|
||||
target_price=0, # Not relevant for closed portion
|
||||
stop_loss=0, # Not relevant for closed portion
|
||||
strategy="FIFO_SPLIT",
|
||||
order_type=order_type,
|
||||
position_id=position['position_id'],
|
||||
followed_rules=followed_rules,
|
||||
exit_price=exit_price,
|
||||
exit_date=exit_date,
|
||||
exit_reason=exit_reason,
|
||||
notes=notes
|
||||
)
|
||||
add_trade(trade)
|
||||
|
||||
remaining_shares -= shares_from_position
|
||||
|
||||
return True
|
||||
|
||||
def create_portfolio_table():
|
||||
with create_client() as client:
|
||||
query = """
|
||||
CREATE TABLE IF NOT EXISTS stock_db.portfolio_history (
|
||||
id UInt32,
|
||||
date DateTime,
|
||||
total_value Float64,
|
||||
cash_balance Float64,
|
||||
notes Nullable(String),
|
||||
created_at DateTime DEFAULT now()
|
||||
) ENGINE = MergeTree()
|
||||
ORDER BY (date, id)
|
||||
"""
|
||||
client.command(query)
|
||||
|
||||
def update_portfolio_value(total_value: float, cash_balance: float, notes: Optional[str] = None):
|
||||
with create_client() as client:
|
||||
query = f"""
|
||||
INSERT INTO stock_db.portfolio_history (
|
||||
id, date, total_value, cash_balance, notes
|
||||
) VALUES (
|
||||
{generate_id()},
|
||||
'{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}',
|
||||
{total_value},
|
||||
{cash_balance},
|
||||
{f"'{notes}'" if notes else 'NULL'}
|
||||
)
|
||||
"""
|
||||
client.command(query)
|
||||
|
||||
def get_latest_portfolio_value() -> Optional[dict]:
|
||||
with create_client() as client:
|
||||
query = """
|
||||
SELECT total_value, cash_balance, date
|
||||
FROM stock_db.portfolio_history
|
||||
ORDER BY date DESC
|
||||
LIMIT 1
|
||||
"""
|
||||
result = client.query(query).result_rows
|
||||
if result:
|
||||
return {
|
||||
'total_value': result[0][0],
|
||||
'cash_balance': result[0][1],
|
||||
'date': result[0][2]
|
||||
}
|
||||
return None
|
||||
|
||||
@dataclass
|
||||
class TradeEntry:
|
||||
ticker: str
|
||||
entry_date: datetime
|
||||
shares: int
|
||||
entry_price: float
|
||||
target_price: Optional[float] # Made optional since sell orders don't need it
|
||||
stop_loss: Optional[float] # Made optional since sell orders don't need it
|
||||
strategy: Optional[str] # Made optional since sell orders might not need it
|
||||
order_type: str # Market/Limit
|
||||
direction: str # 'buy' or 'sell'
|
||||
followed_rules: Optional[bool] = None
|
||||
entry_reason: Optional[str] = None
|
||||
exit_price: Optional[float] = None
|
||||
exit_date: Optional[datetime] = None
|
||||
exit_reason: Optional[str] = None
|
||||
notes: Optional[str] = None
|
||||
position_id: Optional[str] = None
|
||||
|
||||
def __post_init__(self):
|
||||
# For sell orders, set exit_price and exit_date
|
||||
if self.direction.lower() == 'sell':
|
||||
self.exit_price = self.entry_price
|
||||
self.exit_date = self.entry_date
|
||||
self.entry_price = None # Set to NULL for sells
|
||||
self.entry_date = None # Set to NULL for sells
|
||||
|
||||
@property
|
||||
def expected_profit_loss(self) -> Optional[float]:
|
||||
if self.direction == 'buy' and self.target_price:
|
||||
return (self.target_price - self.entry_price) * self.shares
|
||||
return None
|
||||
|
||||
@property
|
||||
def max_loss(self) -> Optional[float]:
|
||||
if self.direction == 'buy' and self.stop_loss:
|
||||
return (self.stop_loss - self.entry_price) * self.shares
|
||||
return None
|
||||
|
||||
@property
|
||||
def actual_profit_loss(self) -> Optional[float]:
|
||||
if self.exit_price:
|
||||
return (self.exit_price - self.entry_price) * self.shares
|
||||
return None
|
||||
|
||||
def get_market_hours(date: datetime) -> tuple:
|
||||
"""Get market open/close times in Eastern for given date"""
|
||||
eastern = pytz.timezone('US/Eastern')
|
||||
date_eastern = date.astimezone(eastern)
|
||||
|
||||
market_open = eastern.localize(
|
||||
datetime.combine(date_eastern.date(), datetime.strptime("09:30", "%H:%M").time())
|
||||
)
|
||||
market_close = eastern.localize(
|
||||
datetime.combine(date_eastern.date(), datetime.strptime("16:00", "%H:%M").time())
|
||||
)
|
||||
return market_open, market_close
|
||||
|
||||
def validate_market_time(dt: datetime) -> tuple[datetime, bool]:
|
||||
"""
|
||||
Validate if time is during market hours, adjust if needed
|
||||
Returns: (adjusted_datetime, was_adjusted)
|
||||
"""
|
||||
pacific = pytz.timezone('US/Pacific')
|
||||
eastern = pytz.timezone('US/Eastern')
|
||||
|
||||
# Ensure datetime is timezone-aware
|
||||
if dt.tzinfo is None:
|
||||
dt = pacific.localize(dt)
|
||||
|
||||
dt_eastern = dt.astimezone(eastern)
|
||||
market_open, market_close = get_market_hours(dt_eastern)
|
||||
|
||||
if dt_eastern < market_open:
|
||||
return market_open.astimezone(pacific), True
|
||||
elif dt_eastern > market_close:
|
||||
return market_close.astimezone(pacific), True
|
||||
|
||||
return dt, False
|
||||
|
||||
def get_datetime_input(prompt: str, default: datetime = None) -> Optional[datetime]:
|
||||
"""Get date and time input in Pacific time"""
|
||||
pacific = pytz.timezone('US/Pacific')
|
||||
|
||||
while True:
|
||||
try:
|
||||
if default:
|
||||
print(f"Press Enter for current time ({default.strftime('%Y-%m-%d %H:%M')})")
|
||||
date_str = input(f"{prompt} (YYYY-MM-DD HH:MM, q to quit): ").strip()
|
||||
|
||||
if date_str.lower() in ['q', 'quit', 'exit']:
|
||||
return None
|
||||
|
||||
if not date_str and default:
|
||||
dt = default
|
||||
else:
|
||||
dt = datetime.strptime(date_str, "%Y-%m-%d %H:%M")
|
||||
|
||||
# Make datetime timezone-aware (Pacific)
|
||||
dt = pacific.localize(dt)
|
||||
|
||||
# Validate market hours
|
||||
adjusted_dt, was_adjusted = validate_market_time(dt)
|
||||
if was_adjusted:
|
||||
print(f"\nWarning: Time adjusted to market hours (Eastern)")
|
||||
print(f"Original (Pacific): {dt.strftime('%Y-%m-%d %H:%M %Z')}")
|
||||
print(f"Adjusted (Pacific): {adjusted_dt.strftime('%Y-%m-%d %H:%M %Z')}")
|
||||
print(f"Adjusted (Eastern): {adjusted_dt.astimezone(pytz.timezone('US/Eastern')).strftime('%Y-%m-%d %H:%M %Z')}")
|
||||
if input("Accept adjusted time? (y/n): ").lower() != 'y':
|
||||
continue
|
||||
|
||||
return adjusted_dt
|
||||
|
||||
except ValueError:
|
||||
print("Invalid format. Please use YYYY-MM-D HH:MM")
|
||||
|
||||
def create_trades_table():
|
||||
with create_client() as client:
|
||||
query = """
|
||||
CREATE TABLE IF NOT EXISTS stock_db.trades (
|
||||
id UInt32,
|
||||
position_id String,
|
||||
ticker String,
|
||||
entry_date DateTime,
|
||||
shares UInt32,
|
||||
entry_price Float64,
|
||||
target_price Nullable(Float64),
|
||||
stop_loss Nullable(Float64),
|
||||
strategy Nullable(String),
|
||||
order_type String,
|
||||
direction String,
|
||||
followed_rules Nullable(UInt8),
|
||||
entry_reason Nullable(String),
|
||||
exit_price Nullable(Float64),
|
||||
exit_date Nullable(DateTime),
|
||||
exit_reason Nullable(String),
|
||||
notes Nullable(String),
|
||||
created_at DateTime DEFAULT now(),
|
||||
plan_id Nullable(UInt32)
|
||||
) ENGINE = MergeTree()
|
||||
ORDER BY (position_id, id, entry_date)
|
||||
"""
|
||||
client.command(query)
|
||||
|
||||
def generate_id() -> int:
|
||||
"""Generate a unique ID for the trade"""
|
||||
return int(datetime.now().timestamp() * 1000)
|
||||
|
||||
def generate_position_id(ticker: str, entry_date: datetime = None) -> str:
|
||||
"""
|
||||
Generate a unique position ID for grouping related trades
|
||||
|
||||
Args:
|
||||
ticker (str): Stock ticker symbol
|
||||
entry_date (datetime, optional): Entry date for the trade
|
||||
|
||||
Returns:
|
||||
str: Position ID in format TICKER_YYYYMMDDHHMMSS
|
||||
"""
|
||||
if entry_date is None:
|
||||
entry_date = datetime.now()
|
||||
timestamp = entry_date.strftime("%Y%m%d%H%M%S")
|
||||
return f"{ticker}_{timestamp}"
|
||||
|
||||
def get_position_summary(ticker: str) -> dict:
|
||||
"""Get summary of existing positions for a ticker"""
|
||||
print(f"\n=== Getting Position Summary for {ticker} ===") # Debug
|
||||
with create_client() as client:
|
||||
query = f"""
|
||||
SELECT
|
||||
position_id,
|
||||
sum(shares) as total_shares,
|
||||
avg(entry_price) as avg_entry_price,
|
||||
min(entry_date) as first_entry,
|
||||
max(entry_date) as last_entry,
|
||||
count() as num_orders,
|
||||
any(target_price) as target_price,
|
||||
any(stop_loss) as stop_loss,
|
||||
any(strategy) as strategy
|
||||
FROM stock_db.trades
|
||||
WHERE ticker = '{ticker}'
|
||||
AND exit_price IS NULL
|
||||
GROUP BY position_id
|
||||
ORDER BY first_entry DESC
|
||||
"""
|
||||
print(f"Executing query: {query}") # Debug
|
||||
result = client.query(query).result_rows
|
||||
print(f"Query returned {len(result)} rows") # Debug
|
||||
|
||||
columns = ['position_id', 'total_shares', 'avg_entry_price',
|
||||
'first_entry', 'last_entry', 'num_orders',
|
||||
'target_price', 'stop_loss', 'strategy']
|
||||
positions = [dict(zip(columns, row)) for row in result]
|
||||
print(f"Processed positions: {positions}") # Debug
|
||||
return positions
|
||||
|
||||
def get_order_type() -> Optional[str]:
|
||||
"""Get order type from user"""
|
||||
while True:
|
||||
print("\nOrder Type:")
|
||||
print("1. Market")
|
||||
print("2. Limit")
|
||||
print("q. Quit")
|
||||
choice = input("Select order type (1-2, q to quit): ")
|
||||
|
||||
if choice.lower() in ['q', 'quit', 'exit']:
|
||||
return None
|
||||
elif choice == "1":
|
||||
return "Market"
|
||||
elif choice == "2":
|
||||
return "Limit"
|
||||
print("Invalid choice. Please try again.")
|
||||
|
||||
def add_trade(trade: TradeEntry):
|
||||
"""Add a new trade to the database"""
|
||||
with create_client() as client:
|
||||
# For sell orders, use exit_price and exit_date instead of entry_price and entry_date
|
||||
query = f"""
|
||||
INSERT INTO stock_db.trades (
|
||||
id, position_id, ticker, entry_date, shares, entry_price, target_price,
|
||||
stop_loss, strategy, order_type, direction, followed_rules, entry_reason,
|
||||
exit_price, exit_date, exit_reason, notes
|
||||
) VALUES (
|
||||
{generate_id()},
|
||||
'{trade.position_id}',
|
||||
'{trade.ticker}',
|
||||
{f"'{trade.entry_date.strftime('%Y-%m-%d %H:%M:%S')}'" if trade.entry_date else 'NULL'},
|
||||
{trade.shares},
|
||||
{trade.entry_price if trade.entry_price else 'NULL'},
|
||||
{trade.target_price if trade.target_price else 'NULL'},
|
||||
{trade.stop_loss if trade.stop_loss else 'NULL'},
|
||||
{f"'{trade.strategy}'" if trade.strategy else 'NULL'},
|
||||
'{trade.order_type}',
|
||||
'{trade.direction.lower()}',
|
||||
{1 if trade.followed_rules else 0},
|
||||
{f"'{trade.entry_reason}'" if trade.entry_reason else 'NULL'},
|
||||
{trade.exit_price if trade.exit_price else 'NULL'},
|
||||
{f"'{trade.exit_date.strftime('%Y-%m-%d %H:%M:%S')}'" if trade.exit_date else 'NULL'},
|
||||
{f"'{trade.exit_reason}'" if trade.exit_reason else 'NULL'},
|
||||
{f"'{trade.notes}'" if trade.notes else 'NULL'}
|
||||
)
|
||||
"""
|
||||
client.command(query)
|
||||
|
||||
def update_trade(trade_id: int, updates: dict):
|
||||
"""
|
||||
Update trade details
|
||||
|
||||
Args:
|
||||
trade_id (int): ID of trade to update
|
||||
updates (dict): Dictionary of fields and values to update
|
||||
"""
|
||||
with create_client() as client:
|
||||
# If trying to update entry_date, we need to delete and reinsert
|
||||
if 'entry_date' in updates:
|
||||
# First get the full trade data
|
||||
query = f"SELECT * FROM stock_db.trades WHERE id = {trade_id}"
|
||||
result = client.query(query).result_rows
|
||||
if not result:
|
||||
raise Exception("Trade not found")
|
||||
|
||||
# Delete the existing trade
|
||||
client.command(f"ALTER TABLE stock_db.trades DELETE WHERE id = {trade_id}")
|
||||
|
||||
# Prepare the new trade data
|
||||
columns = ['id', 'position_id', 'ticker', 'entry_date', 'shares', 'entry_price',
|
||||
'target_price', 'stop_loss', 'strategy', 'order_type', 'followed_rules',
|
||||
'entry_reason', 'exit_price', 'exit_date', 'exit_reason', 'notes', 'created_at']
|
||||
trade_data = dict(zip(columns, result[0]))
|
||||
trade_data.update(updates)
|
||||
|
||||
# Insert the updated trade
|
||||
query = f"""
|
||||
INSERT INTO stock_db.trades (
|
||||
id, position_id, ticker, entry_date, shares, entry_price, target_price,
|
||||
stop_loss, strategy, order_type, followed_rules, entry_reason, exit_price,
|
||||
exit_date, exit_reason, notes
|
||||
) VALUES (
|
||||
{trade_id},
|
||||
'{trade_data['position_id']}',
|
||||
'{trade_data['ticker']}',
|
||||
'{trade_data['entry_date'].strftime('%Y-%m-%d %H:%M:%S')}',
|
||||
{trade_data['shares']},
|
||||
{trade_data['entry_price']},
|
||||
{trade_data['target_price']},
|
||||
{trade_data['stop_loss']},
|
||||
'{trade_data['strategy']}',
|
||||
'{trade_data['order_type']}',
|
||||
{1 if trade_data['followed_rules'] else 0},
|
||||
{f"'{trade_data['entry_reason']}'" if trade_data['entry_reason'] else 'NULL'},
|
||||
{trade_data['exit_price'] if trade_data['exit_price'] else 'NULL'},
|
||||
{f"'{trade_data['exit_date'].strftime('%Y-%m-%d %H:%M:%S')}'" if trade_data['exit_date'] else 'NULL'},
|
||||
{f"'{trade_data['exit_reason']}'" if trade_data['exit_reason'] else 'NULL'},
|
||||
{f"'{trade_data['notes']}'" if trade_data['notes'] else 'NULL'}
|
||||
)
|
||||
"""
|
||||
client.command(query)
|
||||
else:
|
||||
# For non-key columns, we can use regular UPDATE
|
||||
update_statements = []
|
||||
for field, value in updates.items():
|
||||
if isinstance(value, str):
|
||||
update_statements.append(f"{field} = '{value}'")
|
||||
elif isinstance(value, datetime):
|
||||
update_statements.append(f"{field} = '{value.strftime('%Y-%m-%d %H:%M:%S')}'")
|
||||
elif value is None:
|
||||
update_statements.append(f"{field} = NULL")
|
||||
else:
|
||||
update_statements.append(f"{field} = {value}")
|
||||
|
||||
update_clause = ", ".join(update_statements)
|
||||
|
||||
query = f"""
|
||||
ALTER TABLE stock_db.trades
|
||||
UPDATE {update_clause}
|
||||
WHERE id = {trade_id}
|
||||
"""
|
||||
client.command(query)
|
||||
|
||||
def get_open_trades_summary() -> dict:
|
||||
"""Get summary of all open trades grouped by ticker"""
|
||||
print("\n=== Fetching Open Trades Summary ===") # Debug
|
||||
with create_client() as client:
|
||||
query = """
|
||||
WITH position_totals AS (
|
||||
SELECT
|
||||
ticker,
|
||||
position_id,
|
||||
sum(CASE
|
||||
WHEN direction = 'sell' THEN -shares
|
||||
ELSE shares
|
||||
END) as net_shares
|
||||
FROM stock_db.trades
|
||||
GROUP BY ticker, position_id
|
||||
HAVING net_shares > 0
|
||||
)
|
||||
SELECT
|
||||
t.ticker,
|
||||
sum(CASE
|
||||
WHEN t.direction = 'sell' THEN -t.shares
|
||||
ELSE t.shares
|
||||
END) as total_shares,
|
||||
avg(t.entry_price) as avg_entry_price,
|
||||
min(t.entry_date) as first_entry,
|
||||
max(t.entry_date) as last_entry,
|
||||
count() as num_orders,
|
||||
groupArray(t.position_id) as position_ids
|
||||
FROM stock_db.trades t
|
||||
INNER JOIN position_totals pt
|
||||
ON t.ticker = pt.ticker
|
||||
AND t.position_id = pt.position_id
|
||||
GROUP BY t.ticker
|
||||
HAVING total_shares > 0
|
||||
ORDER BY t.ticker ASC
|
||||
"""
|
||||
print(f"Executing summary query: {query}") # Debug
|
||||
try:
|
||||
result = client.query(query).result_rows
|
||||
print(f"Summary query returned {len(result)} rows") # Debug
|
||||
|
||||
columns = ['ticker', 'total_shares', 'avg_entry_price',
|
||||
'first_entry', 'last_entry', 'num_orders', 'position_ids']
|
||||
summaries = [dict(zip(columns, row)) for row in result]
|
||||
print(f"Processed summaries: {summaries}") # Debug
|
||||
return summaries
|
||||
except Exception as e:
|
||||
print(f"Error in get_open_trades_summary: {str(e)}") # Debug
|
||||
raise
|
||||
|
||||
def get_open_trades():
|
||||
print("\n=== Fetching Open Trades ===") # Debug
|
||||
with create_client() as client:
|
||||
query = """
|
||||
WITH position_totals AS (
|
||||
SELECT
|
||||
ticker,
|
||||
position_id,
|
||||
sum(CASE
|
||||
WHEN direction = 'sell' THEN -shares
|
||||
ELSE shares
|
||||
END) as net_shares
|
||||
FROM stock_db.trades
|
||||
GROUP BY ticker, position_id
|
||||
HAVING net_shares > 0
|
||||
)
|
||||
SELECT t.*
|
||||
FROM stock_db.trades t
|
||||
INNER JOIN position_totals pt
|
||||
ON t.ticker = pt.ticker
|
||||
AND t.position_id = pt.position_id
|
||||
ORDER BY t.entry_date DESC
|
||||
"""
|
||||
print(f"Executing query: {query}") # Debug
|
||||
try:
|
||||
result = client.query(query).result_rows
|
||||
print(f"Query returned {len(result)} rows") # Debug
|
||||
|
||||
columns = ['id', 'position_id', 'ticker', 'entry_date', 'shares', 'entry_price', 'target_price',
|
||||
'stop_loss', 'strategy', 'order_type', 'followed_rules', 'entry_reason', 'exit_price',
|
||||
'exit_date', 'exit_reason', 'notes', 'created_at']
|
||||
trades = [dict(zip(columns, row)) for row in result]
|
||||
print(f"Processed trades: {trades}") # Debug
|
||||
return trades
|
||||
except Exception as e:
|
||||
print(f"Error in get_open_trades: {str(e)}") # Debug
|
||||
raise
|
||||
|
||||
|
||||
def delete_trade(trade_id: int) -> bool:
|
||||
"""
|
||||
Delete a trade from the database
|
||||
|
||||
Args:
|
||||
trade_id (int): ID of trade to delete
|
||||
|
||||
Returns:
|
||||
bool: True if deletion was successful
|
||||
"""
|
||||
try:
|
||||
with create_client() as client:
|
||||
query = f"""
|
||||
ALTER TABLE stock_db.trades
|
||||
DELETE WHERE id = {trade_id}
|
||||
"""
|
||||
client.command(query)
|
||||
return True
|
||||
except Exception as e:
|
||||
print(f"Error deleting trade: {e}")
|
||||
return False
|
||||
|
||||
def get_trade_history(limit: int = 50):
|
||||
with create_client() as client:
|
||||
query = f"""
|
||||
SELECT
|
||||
id, position_id, ticker, entry_date, shares, entry_price,
|
||||
target_price, stop_loss, strategy, order_type, direction,
|
||||
followed_rules, entry_reason, exit_price, exit_date,
|
||||
exit_reason, notes, created_at,
|
||||
groupArray(ticker) OVER (PARTITION BY position_id) as related_tickers,
|
||||
groupArray(id) OVER (PARTITION BY position_id) as related_ids
|
||||
FROM stock_db.trades
|
||||
ORDER BY entry_date DESC
|
||||
LIMIT {limit}
|
||||
"""
|
||||
result = client.query(query).result_rows
|
||||
columns = ['id', 'position_id', 'ticker', 'entry_date', 'shares',
|
||||
'entry_price', 'target_price', 'stop_loss', 'strategy',
|
||||
'order_type', 'direction', 'followed_rules', 'entry_reason',
|
||||
'exit_price', 'exit_date', 'exit_reason', 'notes',
|
||||
'created_at', 'related_tickers', 'related_ids']
|
||||
return [dict(zip(columns, row)) for row in result]
|
||||
|
||||
def journal_menu():
|
||||
"""Trading journal menu interface"""
|
||||
create_trades_table() # Ensure table exists
|
||||
|
||||
while True:
|
||||
print("\nTrading Journal")
|
||||
print("1. Add New Trade")
|
||||
print("2. Update Existing Trade")
|
||||
print("3. View Open Trades")
|
||||
print("4. View Trade History")
|
||||
print("5. Delete Trade")
|
||||
print("6. Return to Main Menu")
|
||||
|
||||
choice = input("\nSelect an option (1-5): ")
|
||||
|
||||
if choice == "1":
|
||||
ticker = get_user_input("Enter ticker symbol:", str)
|
||||
if ticker is None:
|
||||
continue
|
||||
ticker = ticker.upper()
|
||||
|
||||
# Ask if this is a buy or sell order
|
||||
print("\nOrder Direction:")
|
||||
print("1. Buy")
|
||||
print("2. Sell")
|
||||
direction = get_user_input("Select direction (1-2):", str)
|
||||
if direction is None:
|
||||
continue
|
||||
|
||||
if direction not in ["1", "2"]:
|
||||
print("Invalid direction")
|
||||
continue
|
||||
|
||||
if direction == "2": # Sell order
|
||||
shares = get_user_input("Enter number of shares to sell:", int)
|
||||
if shares is None:
|
||||
continue
|
||||
|
||||
exit_price = get_user_input("Enter exit price:", float)
|
||||
if exit_price is None:
|
||||
continue
|
||||
|
||||
exit_date = get_datetime_input("Enter exit date and time", default=datetime.now())
|
||||
if exit_date is None:
|
||||
continue
|
||||
|
||||
order_type = get_order_type()
|
||||
if order_type is None:
|
||||
continue
|
||||
|
||||
followed_rules = get_user_input("Did you follow your rules? (y/n):", bool)
|
||||
if followed_rules is None:
|
||||
continue
|
||||
|
||||
exit_reason = input("Enter exit reason: ")
|
||||
notes = input("Additional notes (optional): ") or None
|
||||
|
||||
if handle_sell_order(
|
||||
ticker=ticker,
|
||||
shares_to_sell=shares,
|
||||
exit_price=exit_price,
|
||||
exit_date=exit_date,
|
||||
order_type=order_type,
|
||||
followed_rules=followed_rules,
|
||||
exit_reason=exit_reason,
|
||||
notes=notes
|
||||
):
|
||||
print("Sell order processed successfully!")
|
||||
continue
|
||||
|
||||
# Show existing positions for this ticker
|
||||
existing_positions = get_position_summary(ticker)
|
||||
if existing_positions:
|
||||
print(f"\nExisting {ticker} Positions:")
|
||||
for pos in existing_positions:
|
||||
print(f"\nPosition ID: {pos['position_id']}")
|
||||
print(f"Total Shares: {pos['total_shares']}")
|
||||
print(f"Average Entry: ${pos['avg_entry_price']:.2f}")
|
||||
print(f"First Entry: {pos['first_entry']}")
|
||||
print(f"Number of Orders: {pos['num_orders']}")
|
||||
|
||||
add_to_existing = get_user_input("Add to existing position? (y/n):", bool)
|
||||
if add_to_existing is None:
|
||||
continue
|
||||
|
||||
if add_to_existing:
|
||||
position_id = get_user_input("Enter Position ID:", str)
|
||||
if position_id is None:
|
||||
continue
|
||||
else:
|
||||
position_id = generate_position_id(ticker)
|
||||
else:
|
||||
position_id = generate_position_id(ticker)
|
||||
|
||||
# Get entry date/time with market hours validation
|
||||
entry_date = get_datetime_input("Enter entry date and time", default=datetime.now())
|
||||
if entry_date is None:
|
||||
continue
|
||||
|
||||
shares = get_user_input("Enter number of shares:", int)
|
||||
if shares is None:
|
||||
continue
|
||||
|
||||
entry_price = get_user_input("Enter entry price:", float)
|
||||
if entry_price is None:
|
||||
continue
|
||||
|
||||
order_type = get_order_type()
|
||||
if order_type is None:
|
||||
continue
|
||||
|
||||
# If adding to existing position, get target/stop from existing
|
||||
if existing_positions and add_to_existing:
|
||||
# Use existing target and stop loss
|
||||
target_price = float(input("Enter new target price (or press Enter to keep existing): ") or
|
||||
existing_positions[0]['target_price'])
|
||||
stop_loss = float(input("Enter new stop loss (or press Enter to keep existing): ") or
|
||||
existing_positions[0]['stop_loss'])
|
||||
strategy = input("Enter strategy name (or press Enter to keep existing): ") or existing_positions[0]['strategy']
|
||||
else:
|
||||
target_price = float(input("Enter target price: "))
|
||||
stop_loss = float(input("Enter stop loss: "))
|
||||
strategy = input("Enter strategy name: ")
|
||||
|
||||
followed_rules = input("Did you follow your rules? (y/n): ").lower() == 'y'
|
||||
entry_reason = input("Enter entry reason (optional): ") or None
|
||||
notes = input("Additional notes (optional): ") or None
|
||||
|
||||
trade = TradeEntry(
|
||||
ticker=ticker,
|
||||
entry_date=entry_date,
|
||||
shares=shares,
|
||||
entry_price=entry_price,
|
||||
target_price=target_price,
|
||||
stop_loss=stop_loss,
|
||||
strategy=strategy,
|
||||
order_type=order_type,
|
||||
position_id=position_id,
|
||||
followed_rules=followed_rules,
|
||||
entry_reason=entry_reason,
|
||||
notes=notes
|
||||
)
|
||||
|
||||
add_trade(trade)
|
||||
|
||||
# Show updated position summary
|
||||
updated_positions = get_position_summary(ticker)
|
||||
if updated_positions:
|
||||
pos = updated_positions[0] # Get the most recent position
|
||||
print(f"\nUpdated Position Summary for {ticker}:")
|
||||
print(f"Total Shares: {pos['total_shares']}")
|
||||
print(f"Average Entry: ${pos['avg_entry_price']:.2f}")
|
||||
print(f"Expected Profit: ${(target_price - pos['avg_entry_price']) * pos['total_shares']:.2f}")
|
||||
print(f"Maximum Loss: ${(stop_loss - pos['avg_entry_price']) * pos['total_shares']:.2f}")
|
||||
|
||||
print("Trade added successfully!")
|
||||
|
||||
elif choice == "2":
|
||||
open_trades = get_open_trades()
|
||||
if not open_trades:
|
||||
print("No open trades to update.")
|
||||
continue
|
||||
|
||||
print("\nOpen Trades:")
|
||||
for trade in open_trades:
|
||||
print(f"{trade['id']}: {trade['ticker']} - Entered at ${trade['entry_price']}")
|
||||
|
||||
print("\nOpen Trades:")
|
||||
for trade in open_trades:
|
||||
print(f"\nID: {trade['id']}")
|
||||
print(f"Ticker: {trade['ticker']}")
|
||||
print(f"Entry Date: {trade['entry_date']}")
|
||||
print(f"Shares: {trade['shares']}")
|
||||
print(f"Entry Price: ${trade['entry_price']}")
|
||||
print(f"Target: ${trade['target_price']}")
|
||||
print(f"Stop Loss: ${trade['stop_loss']}")
|
||||
print(f"Strategy: {trade['strategy']}")
|
||||
print("-" * 40)
|
||||
|
||||
trade_id = get_user_input("\nEnter trade ID to update:", int)
|
||||
if trade_id is None:
|
||||
continue
|
||||
|
||||
# Find the trade to update
|
||||
trade_to_update = next((t for t in open_trades if t['id'] == trade_id), None)
|
||||
if not trade_to_update:
|
||||
print("Trade not found.")
|
||||
continue
|
||||
|
||||
print("\nUpdate Trade Fields")
|
||||
print("Leave blank to keep existing value")
|
||||
|
||||
updates = {}
|
||||
|
||||
# Entry date
|
||||
new_date = get_datetime_input("Enter new entry date and time (blank to keep):", None)
|
||||
if new_date:
|
||||
updates['entry_date'] = new_date
|
||||
|
||||
# Shares
|
||||
new_shares = get_user_input("Enter new number of shares:", int, allow_empty=True)
|
||||
if new_shares is not None:
|
||||
updates['shares'] = new_shares
|
||||
|
||||
# Entry price
|
||||
new_entry = get_user_input("Enter new entry price:", float, allow_empty=True)
|
||||
if new_entry is not None:
|
||||
updates['entry_price'] = new_entry
|
||||
|
||||
# Target price
|
||||
new_target = get_user_input("Enter new target price:", float, allow_empty=True)
|
||||
if new_target is not None:
|
||||
updates['target_price'] = new_target
|
||||
|
||||
# Stop loss
|
||||
new_stop = get_user_input("Enter new stop loss:", float, allow_empty=True)
|
||||
if new_stop is not None:
|
||||
updates['stop_loss'] = new_stop
|
||||
|
||||
# Strategy
|
||||
new_strategy = input("Enter new strategy (blank to keep): ").strip()
|
||||
if new_strategy:
|
||||
updates['strategy'] = new_strategy
|
||||
|
||||
# Order type
|
||||
if input("Update order type? (y/n): ").lower() == 'y':
|
||||
new_order_type = get_order_type()
|
||||
if new_order_type:
|
||||
updates['order_type'] = new_order_type
|
||||
|
||||
# Notes
|
||||
new_notes = input("Enter new notes (blank to keep): ").strip()
|
||||
if new_notes:
|
||||
updates['notes'] = new_notes
|
||||
|
||||
if updates:
|
||||
try:
|
||||
update_trade(trade_id, updates)
|
||||
print("Trade updated successfully!")
|
||||
except Exception as e:
|
||||
print(f"Error updating trade: {e}")
|
||||
else:
|
||||
print("No updates provided.")
|
||||
|
||||
elif choice == "3":
|
||||
open_trades = get_open_trades()
|
||||
open_summary = get_open_trades_summary()
|
||||
|
||||
if not open_trades:
|
||||
print("No open trades found.")
|
||||
else:
|
||||
# Get current prices for all open positions
|
||||
unique_tickers = list(set(summary['ticker'] for summary in open_summary))
|
||||
current_prices = get_current_prices(unique_tickers)
|
||||
|
||||
print("\n=== Open Trades Summary ===")
|
||||
total_portfolio_value = 0
|
||||
total_paper_pl = 0
|
||||
|
||||
for summary in open_summary:
|
||||
ticker = summary['ticker']
|
||||
avg_entry = summary['avg_entry_price']
|
||||
current_price = current_prices.get(ticker)
|
||||
|
||||
stop_loss = avg_entry * 0.93 # 7% stop loss
|
||||
total_shares = summary['total_shares']
|
||||
position_value = avg_entry * total_shares
|
||||
max_loss = (avg_entry - stop_loss) * total_shares
|
||||
|
||||
print(f"\n{ticker} Summary:")
|
||||
print(f"Total Shares: {total_shares}")
|
||||
print(f"Average Entry: ${avg_entry:.2f}")
|
||||
print(f"Total Position Value: ${position_value:.2f}")
|
||||
print(f"Combined Stop Loss (7%): ${stop_loss:.2f}")
|
||||
print(f"Maximum Loss at Stop: ${max_loss:.2f}")
|
||||
|
||||
if current_price:
|
||||
current_value = current_price * total_shares
|
||||
paper_pl = (current_price - avg_entry) * total_shares
|
||||
pl_percentage = (paper_pl / position_value) * 100
|
||||
total_portfolio_value += current_value
|
||||
total_paper_pl += paper_pl
|
||||
|
||||
print(f"Current Price: ${current_price:.2f}")
|
||||
print(f"Current Value: ${current_value:.2f}")
|
||||
print(f"Paper P/L: ${paper_pl:.2f} ({pl_percentage:.2f}%)")
|
||||
|
||||
print(f"Number of Orders: {summary['num_orders']}")
|
||||
print(f"Position Duration: {summary['last_entry'] - summary['first_entry']}")
|
||||
print("-" * 50)
|
||||
|
||||
if total_portfolio_value > 0:
|
||||
print(f"\nTotal Portfolio Value: ${total_portfolio_value:.2f}")
|
||||
print(f"Total Paper P/L: ${total_paper_pl:.2f}")
|
||||
print(f"Overall P/L %: {(total_paper_pl / (total_portfolio_value - total_paper_pl)) * 100:.2f}%")
|
||||
|
||||
print("\n=== Individual Trades ===")
|
||||
for trade in open_trades:
|
||||
ticker = trade['ticker']
|
||||
current_price = current_prices.get(ticker)
|
||||
|
||||
print(f"\nTicker: {ticker}")
|
||||
print(f"Position ID: {trade['position_id']}")
|
||||
print(f"Entry Date: {trade['entry_date']}")
|
||||
print(f"Shares: {trade['shares']}")
|
||||
print(f"Entry Price: ${trade['entry_price']}")
|
||||
print(f"Target: ${trade['target_price']}")
|
||||
print(f"Stop Loss: ${trade['stop_loss']}")
|
||||
print(f"Strategy: {trade['strategy']}")
|
||||
print(f"Order Type: {trade['order_type']}")
|
||||
|
||||
if current_price:
|
||||
paper_pl = (current_price - trade['entry_price']) * trade['shares']
|
||||
pl_percentage = (paper_pl / (trade['entry_price'] * trade['shares'])) * 100
|
||||
print(f"Current Price: ${current_price:.2f}")
|
||||
print(f"Paper P/L: ${paper_pl:.2f} ({pl_percentage:.2f}%)")
|
||||
|
||||
if trade['entry_reason']:
|
||||
print(f"Entry Reason: {trade['entry_reason']}")
|
||||
if trade['notes']:
|
||||
print(f"Notes: {trade['notes']}")
|
||||
print("-" * 40)
|
||||
|
||||
elif choice == "4":
|
||||
history = get_trade_history()
|
||||
if not history:
|
||||
print("No trade history found.")
|
||||
else:
|
||||
print("\nTrade History:")
|
||||
for trade in history:
|
||||
profit_loss = (trade['exit_price'] - trade['entry_price']) * trade['shares'] if trade['exit_price'] else None
|
||||
print(f"\nTicker: {trade['ticker']}")
|
||||
print(f"Entry: ${trade['entry_price']} on {trade['entry_date']}")
|
||||
if trade['exit_price']:
|
||||
print(f"Exit: ${trade['exit_price']} on {trade['exit_date']}")
|
||||
print(f"P/L: ${profit_loss:.2f}")
|
||||
print(f"Strategy: {trade['strategy']}")
|
||||
if trade['notes']:
|
||||
print(f"Notes: {trade['notes']}")
|
||||
print("-" * 40)
|
||||
|
||||
elif choice == "5":
|
||||
# Show all trades (both open and closed)
|
||||
print("\nAll Trades:")
|
||||
with create_client() as client:
|
||||
query = "SELECT * FROM stock_db.trades ORDER BY entry_date DESC"
|
||||
result = client.query(query).result_rows
|
||||
columns = ['id', 'position_id', 'ticker', 'entry_date', 'shares', 'entry_price',
|
||||
'target_price', 'stop_loss', 'strategy', 'order_type', 'followed_rules',
|
||||
'entry_reason', 'exit_price', 'exit_date', 'exit_reason', 'notes', 'created_at']
|
||||
trades = [dict(zip(columns, row)) for row in result]
|
||||
|
||||
for trade in trades:
|
||||
print(f"\nID: {trade['id']}")
|
||||
print(f"Ticker: {trade['ticker']}")
|
||||
print(f"Entry Date: {trade['entry_date']}")
|
||||
print(f"Shares: {trade['shares']}")
|
||||
print(f"Entry Price: ${trade['entry_price']}")
|
||||
if trade['exit_price']:
|
||||
print(f"Exit Price: ${trade['exit_price']}")
|
||||
print(f"Exit Date: {trade['exit_date']}")
|
||||
pl = (trade['exit_price'] - trade['entry_price']) * trade['shares']
|
||||
print(f"P/L: ${pl:.2f}")
|
||||
print("-" * 40)
|
||||
|
||||
trade_id = get_user_input("\nEnter trade ID to delete:", int)
|
||||
if trade_id is None:
|
||||
continue
|
||||
|
||||
confirm = input(f"Are you sure you want to delete trade {trade_id}? (y/n): ").lower()
|
||||
if confirm == 'y':
|
||||
if delete_trade(trade_id):
|
||||
print("Trade deleted successfully!")
|
||||
else:
|
||||
print("Failed to delete trade.")
|
||||
else:
|
||||
print("Delete cancelled.")
|
||||
|
||||
elif choice == "6":
|
||||
break
|
||||
@ -1,80 +1,31 @@
|
||||
from datetime import datetime
|
||||
from trading.position_calculator import PositionCalculator
|
||||
from trading.portfolio import Portfolio, Position
|
||||
from trading.journal import (TradeEntry, add_trade, get_open_trades,
|
||||
create_portfolio_table, update_portfolio_value,
|
||||
get_latest_portfolio_value, get_datetime_input,
|
||||
get_order_type)
|
||||
from position_calculator import PositionCalculator
|
||||
from portfolio import Portfolio, Position
|
||||
|
||||
def get_float_input(prompt: str) -> float:
|
||||
while True:
|
||||
try:
|
||||
value = input(prompt)
|
||||
if value.lower() in ['q', 'quit', 'exit']:
|
||||
return None
|
||||
return float(value)
|
||||
return float(input(prompt))
|
||||
except ValueError:
|
||||
print("Please enter a valid number")
|
||||
|
||||
def main():
|
||||
# Initialize tables
|
||||
create_portfolio_table()
|
||||
# Get initial portfolio value from user
|
||||
portfolio_value = get_float_input("Enter your portfolio value: $")
|
||||
|
||||
# Get latest portfolio value or ask for initial value
|
||||
portfolio_data = get_latest_portfolio_value()
|
||||
if portfolio_data:
|
||||
portfolio_value = portfolio_data['total_value']
|
||||
cash_balance = portfolio_data['cash_balance']
|
||||
print(f"\nCurrent Portfolio Value: ${portfolio_value:.2f}")
|
||||
print(f"Cash Balance: ${cash_balance:.2f}")
|
||||
print(f"Last Updated: {portfolio_data['date']}")
|
||||
|
||||
if input("\nUpdate portfolio value? (y/n): ").lower() == 'y':
|
||||
new_value = get_float_input("Enter new portfolio value: $")
|
||||
new_cash = get_float_input("Enter new cash balance: $")
|
||||
notes = input("Notes (optional): ")
|
||||
if new_value and new_cash:
|
||||
portfolio_value = new_value
|
||||
cash_balance = new_cash
|
||||
update_portfolio_value(portfolio_value, cash_balance, notes)
|
||||
else:
|
||||
portfolio_value = get_float_input("Enter your initial portfolio value: $")
|
||||
cash_balance = get_float_input("Enter your cash balance: $")
|
||||
if portfolio_value and cash_balance:
|
||||
update_portfolio_value(portfolio_value, cash_balance, "Initial portfolio setup")
|
||||
|
||||
if not portfolio_value or not cash_balance:
|
||||
return
|
||||
|
||||
# Initialize calculator with current portfolio value
|
||||
calculator = PositionCalculator(account_size=portfolio_value)
|
||||
|
||||
# Initialize portfolio
|
||||
# Initialize portfolio and position calculator
|
||||
portfolio = Portfolio()
|
||||
calculator = PositionCalculator(account_size=portfolio_value) # Use user's portfolio value
|
||||
|
||||
# Load existing open trades into portfolio
|
||||
open_trades = get_open_trades()
|
||||
for trade in open_trades:
|
||||
position = Position(
|
||||
symbol=trade['ticker'],
|
||||
entry_date=trade['entry_date'],
|
||||
entry_price=trade['entry_price'],
|
||||
shares=trade['shares'],
|
||||
stop_loss=trade['stop_loss'],
|
||||
target_price=trade['target_price']
|
||||
)
|
||||
portfolio.add_position(position)
|
||||
|
||||
while True:
|
||||
print("\nTrading Management System")
|
||||
print("1. Calculate Position Size")
|
||||
print("2. Add Position")
|
||||
print("3. View Portfolio")
|
||||
print("4. Remove Position")
|
||||
print("5. Update Portfolio Value")
|
||||
print("6. Exit")
|
||||
print("5. Exit")
|
||||
|
||||
choice = input("\nSelect an option (1-6): ")
|
||||
choice = input("\nSelect an option (1-5): ")
|
||||
|
||||
if choice == "1":
|
||||
try:
|
||||
@ -145,17 +96,6 @@ def main():
|
||||
print(f"\nRemoved position: {symbol}")
|
||||
|
||||
elif choice == "5":
|
||||
new_value = get_float_input("Enter new portfolio value: $")
|
||||
new_cash = get_float_input("Enter new cash balance: $")
|
||||
notes = input("Notes (optional): ")
|
||||
if new_value and new_cash:
|
||||
portfolio_value = new_value
|
||||
cash_balance = new_cash
|
||||
update_portfolio_value(portfolio_value, cash_balance, notes)
|
||||
print(f"\nPortfolio value updated: ${portfolio_value:.2f}")
|
||||
print(f"Cash balance updated: ${cash_balance:.2f}")
|
||||
|
||||
elif choice == "6":
|
||||
print("\nExiting Trading Management System")
|
||||
break
|
||||
|
||||
|
||||
@ -1,13 +0,0 @@
|
||||
def print_main_menu():
|
||||
print("\nStock Analysis System")
|
||||
print("1. Run CANSLIM Screener")
|
||||
print("2. Run Technical Scanners (SunnyBands/ATR-EMA)")
|
||||
print("3. Launch Trading System")
|
||||
print("4. Trading Journal")
|
||||
print("5. Exit")
|
||||
|
||||
def print_technical_scanner_menu():
|
||||
print("\nTechnical Scanner Options:")
|
||||
print("1. SunnyBands Scanner")
|
||||
print("2. Standard ATR-EMA Scanner")
|
||||
print("3. Enhanced ATR-EMA v2 Scanner")
|
||||
@ -14,30 +14,21 @@ class Position:
|
||||
@property
|
||||
def current_value(self) -> float:
|
||||
# TODO: Implement real-time price fetching
|
||||
if self.entry_price is None or self.shares is None:
|
||||
return 0.0
|
||||
return self.shares * self.entry_price
|
||||
|
||||
@property
|
||||
def potential_profit(self) -> float:
|
||||
"""Calculate potential profit at target price"""
|
||||
if self.target_price is None or self.entry_price is None or self.shares is None:
|
||||
return 0.0
|
||||
return (self.target_price - self.entry_price) * self.shares
|
||||
|
||||
@property
|
||||
def potential_loss(self) -> float:
|
||||
"""Calculate potential loss at stop loss"""
|
||||
if self.stop_loss is None or self.entry_price is None or self.shares is None:
|
||||
return 0.0
|
||||
return (self.stop_loss - self.entry_price) * self.shares
|
||||
|
||||
@property
|
||||
def risk_reward_ratio(self) -> float:
|
||||
"""Calculate risk/reward ratio"""
|
||||
if (self.target_price is None or self.entry_price is None or
|
||||
self.stop_loss is None):
|
||||
return 0.0
|
||||
potential_gain = self.target_price - self.entry_price
|
||||
potential_risk = self.entry_price - self.stop_loss
|
||||
return abs(potential_gain / potential_risk) if potential_risk != 0 else 0
|
||||
|
||||
@ -1,546 +0,0 @@
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from typing import Optional, Dict, Any, List
|
||||
from dataclasses import dataclass
|
||||
from db.db_connection import create_client
|
||||
|
||||
class PlanStatus(Enum):
|
||||
ACTIVE = 'active'
|
||||
ARCHIVED = 'archived'
|
||||
TESTING = 'testing'
|
||||
DEPRECATED = 'deprecated'
|
||||
|
||||
class Timeframe(Enum):
|
||||
DAILY = 'daily'
|
||||
WEEKLY = 'weekly'
|
||||
HOURLY = 'hourly'
|
||||
MIN_15 = '15-min'
|
||||
MIN_30 = '30-min'
|
||||
MIN_5 = '5-min'
|
||||
|
||||
class MarketFocus(Enum):
|
||||
STOCKS = 'stocks'
|
||||
CRYPTO = 'crypto'
|
||||
FOREX = 'forex'
|
||||
OPTIONS = 'options'
|
||||
FUTURES = 'futures'
|
||||
|
||||
class TradeFrequency(Enum):
|
||||
DAILY = 'daily'
|
||||
WEEKLY = 'weekly'
|
||||
MONTHLY = 'monthly'
|
||||
AS_NEEDED = 'as-needed'
|
||||
|
||||
@dataclass
|
||||
class TradingPlan:
|
||||
# General Info
|
||||
plan_name: str
|
||||
status: PlanStatus
|
||||
timeframe: Timeframe
|
||||
market_focus: MarketFocus
|
||||
entry_criteria: str
|
||||
exit_criteria: str
|
||||
stop_loss: float
|
||||
profit_target: float
|
||||
risk_reward_ratio: float
|
||||
trade_frequency: TradeFrequency
|
||||
market_conditions: str
|
||||
indicators_used: str
|
||||
entry_confirmation: str
|
||||
position_sizing: float
|
||||
maximum_drawdown: float
|
||||
max_trades_per_day: int
|
||||
max_trades_per_week: int
|
||||
total_risk_per_trade: float
|
||||
max_portfolio_risk: float
|
||||
adjustments_for_drawdown: str
|
||||
risk_controls: str
|
||||
plan_author: str
|
||||
|
||||
# Fields with default values must come after required fields
|
||||
created_at: datetime = None
|
||||
updated_at: datetime = None
|
||||
id: int = None
|
||||
strategy_version: int = 1
|
||||
win_rate: Optional[float] = None
|
||||
average_return_per_trade: Optional[float] = None
|
||||
profit_factor: Optional[float] = None
|
||||
historical_backtest_results: Optional[str] = None
|
||||
real_trade_performance: Optional[str] = None
|
||||
improvements_needed: Optional[str] = None
|
||||
trade_review_notes: Optional[str] = None
|
||||
future_testing_ideas: Optional[str] = None
|
||||
sector_focus: Optional[str] = None
|
||||
fundamental_criteria: Optional[str] = None
|
||||
options_strategy_details: Optional[str] = None
|
||||
|
||||
def create_trading_plan_table():
|
||||
"""Create the trading plans table if it doesn't exist"""
|
||||
with create_client() as client:
|
||||
try:
|
||||
# First check if table exists
|
||||
check_query = """
|
||||
SELECT name
|
||||
FROM system.tables
|
||||
WHERE database = currentDatabase()
|
||||
AND name = 'trading_plans'
|
||||
"""
|
||||
result = client.query(check_query)
|
||||
|
||||
if result.result_rows:
|
||||
return # Table already exists, silently return
|
||||
|
||||
# Create new table with a structure that supports updates
|
||||
query = """
|
||||
CREATE TABLE IF NOT EXISTS trading_plans
|
||||
(
|
||||
id UInt32,
|
||||
plan_name String,
|
||||
status String,
|
||||
created_at DateTime,
|
||||
updated_at DateTime,
|
||||
timeframe String,
|
||||
market_focus String,
|
||||
entry_criteria String,
|
||||
exit_criteria String,
|
||||
stop_loss Float64,
|
||||
profit_target Float64,
|
||||
risk_reward_ratio Float64,
|
||||
trade_frequency String,
|
||||
market_conditions String,
|
||||
indicators_used String,
|
||||
entry_confirmation String,
|
||||
position_sizing Float64,
|
||||
maximum_drawdown Float64,
|
||||
max_trades_per_day UInt32,
|
||||
max_trades_per_week UInt32,
|
||||
total_risk_per_trade Float64,
|
||||
max_portfolio_risk Float64,
|
||||
adjustments_for_drawdown String,
|
||||
risk_controls String,
|
||||
win_rate Float64,
|
||||
average_return_per_trade Float64,
|
||||
profit_factor Float64,
|
||||
historical_backtest_results String,
|
||||
real_trade_performance String,
|
||||
improvements_needed String,
|
||||
strategy_version UInt32,
|
||||
plan_author String,
|
||||
trade_review_notes String,
|
||||
future_testing_ideas String,
|
||||
sector_focus String,
|
||||
fundamental_criteria String,
|
||||
options_strategy_details String
|
||||
)
|
||||
ENGINE = ReplacingMergeTree()
|
||||
ORDER BY id
|
||||
"""
|
||||
client.command(query)
|
||||
print("Table 'trading_plans' created successfully.")
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error creating table: {e}")
|
||||
|
||||
def save_trading_plan(plan: TradingPlan) -> int:
|
||||
"""Save a trading plan to the database"""
|
||||
with create_client() as client:
|
||||
if not plan.id:
|
||||
# Generate new ID for new plans
|
||||
result = client.query("SELECT max(id) FROM trading_plans")
|
||||
max_id = result.result_rows[0][0] if result.result_rows else 0
|
||||
plan.id = (max_id or 0) + 1
|
||||
plan.created_at = datetime.now()
|
||||
|
||||
plan.updated_at = datetime.now()
|
||||
|
||||
query = """
|
||||
INSERT INTO trading_plans VALUES (
|
||||
%(id)s, %(plan_name)s, %(status)s, %(created_at)s, %(updated_at)s,
|
||||
%(timeframe)s, %(market_focus)s, %(entry_criteria)s, %(exit_criteria)s,
|
||||
%(stop_loss)s, %(profit_target)s, %(risk_reward_ratio)s, %(trade_frequency)s,
|
||||
%(market_conditions)s, %(indicators_used)s, %(entry_confirmation)s,
|
||||
%(position_sizing)s, %(maximum_drawdown)s, %(max_trades_per_day)s,
|
||||
%(max_trades_per_week)s, %(total_risk_per_trade)s, %(max_portfolio_risk)s,
|
||||
%(adjustments_for_drawdown)s, %(risk_controls)s, %(win_rate)s,
|
||||
%(average_return_per_trade)s, %(profit_factor)s, %(historical_backtest_results)s,
|
||||
%(real_trade_performance)s, %(improvements_needed)s, %(strategy_version)s,
|
||||
%(plan_author)s, %(trade_review_notes)s, %(future_testing_ideas)s,
|
||||
%(sector_focus)s, %(fundamental_criteria)s, %(options_strategy_details)s
|
||||
)
|
||||
"""
|
||||
|
||||
params = {
|
||||
'id': plan.id,
|
||||
'plan_name': plan.plan_name,
|
||||
'status': plan.status.value,
|
||||
'created_at': plan.created_at,
|
||||
'updated_at': plan.updated_at,
|
||||
'timeframe': plan.timeframe.value,
|
||||
'market_focus': plan.market_focus.value,
|
||||
'entry_criteria': plan.entry_criteria,
|
||||
'exit_criteria': plan.exit_criteria,
|
||||
'stop_loss': plan.stop_loss,
|
||||
'profit_target': plan.profit_target,
|
||||
'risk_reward_ratio': plan.risk_reward_ratio,
|
||||
'trade_frequency': plan.trade_frequency.value,
|
||||
'market_conditions': plan.market_conditions,
|
||||
'indicators_used': plan.indicators_used,
|
||||
'entry_confirmation': plan.entry_confirmation,
|
||||
'position_sizing': plan.position_sizing,
|
||||
'maximum_drawdown': plan.maximum_drawdown,
|
||||
'max_trades_per_day': plan.max_trades_per_day,
|
||||
'max_trades_per_week': plan.max_trades_per_week,
|
||||
'total_risk_per_trade': plan.total_risk_per_trade,
|
||||
'max_portfolio_risk': plan.max_portfolio_risk,
|
||||
'adjustments_for_drawdown': plan.adjustments_for_drawdown,
|
||||
'risk_controls': plan.risk_controls,
|
||||
'win_rate': plan.win_rate,
|
||||
'average_return_per_trade': plan.average_return_per_trade,
|
||||
'profit_factor': plan.profit_factor,
|
||||
'historical_backtest_results': plan.historical_backtest_results,
|
||||
'real_trade_performance': plan.real_trade_performance,
|
||||
'improvements_needed': plan.improvements_needed,
|
||||
'strategy_version': plan.strategy_version,
|
||||
'plan_author': plan.plan_author,
|
||||
'trade_review_notes': plan.trade_review_notes,
|
||||
'future_testing_ideas': plan.future_testing_ideas,
|
||||
'sector_focus': plan.sector_focus,
|
||||
'fundamental_criteria': plan.fundamental_criteria,
|
||||
'options_strategy_details': plan.options_strategy_details
|
||||
}
|
||||
|
||||
client.command(query, params)
|
||||
return plan.id
|
||||
|
||||
def get_trading_plan(plan_id: int) -> Optional[TradingPlan]:
|
||||
"""Get a trading plan by ID"""
|
||||
with create_client() as client:
|
||||
result = client.query("""
|
||||
SELECT * FROM trading_plans WHERE id = %(id)s
|
||||
""", {'id': plan_id})
|
||||
|
||||
rows = result.result_rows
|
||||
|
||||
if not rows:
|
||||
return None
|
||||
|
||||
plan = rows[0]
|
||||
return TradingPlan(
|
||||
id=plan[0],
|
||||
plan_name=plan[1],
|
||||
status=PlanStatus(plan[2]),
|
||||
created_at=plan[3],
|
||||
updated_at=plan[4],
|
||||
timeframe=Timeframe(plan[5]),
|
||||
market_focus=MarketFocus(plan[6]),
|
||||
entry_criteria=plan[7],
|
||||
exit_criteria=plan[8],
|
||||
stop_loss=plan[9],
|
||||
profit_target=plan[10],
|
||||
risk_reward_ratio=plan[11],
|
||||
trade_frequency=TradeFrequency(plan[12]),
|
||||
market_conditions=plan[13],
|
||||
indicators_used=plan[14],
|
||||
entry_confirmation=plan[15],
|
||||
position_sizing=plan[16],
|
||||
maximum_drawdown=plan[17],
|
||||
max_trades_per_day=plan[18],
|
||||
max_trades_per_week=plan[19],
|
||||
total_risk_per_trade=plan[20],
|
||||
max_portfolio_risk=plan[21],
|
||||
adjustments_for_drawdown=plan[22],
|
||||
risk_controls=plan[23],
|
||||
win_rate=plan[24],
|
||||
average_return_per_trade=plan[25],
|
||||
profit_factor=plan[26],
|
||||
historical_backtest_results=plan[27],
|
||||
real_trade_performance=plan[28],
|
||||
improvements_needed=plan[29],
|
||||
strategy_version=plan[30],
|
||||
plan_author=plan[31],
|
||||
trade_review_notes=plan[32],
|
||||
future_testing_ideas=plan[33],
|
||||
sector_focus=plan[34],
|
||||
fundamental_criteria=plan[35],
|
||||
options_strategy_details=plan[36]
|
||||
)
|
||||
|
||||
def get_all_trading_plans(status: Optional[PlanStatus] = None) -> List[TradingPlan]:
|
||||
"""Get all trading plans, optionally filtered by status"""
|
||||
with create_client() as client:
|
||||
try:
|
||||
query = "SELECT * FROM trading_plans"
|
||||
params = {}
|
||||
|
||||
if status:
|
||||
query += " WHERE status = %(status)s"
|
||||
params['status'] = status.value
|
||||
|
||||
query += " ORDER BY updated_at DESC"
|
||||
|
||||
results = client.query(query, params)
|
||||
rows = results.result_rows
|
||||
return [TradingPlan(
|
||||
id=row[0],
|
||||
plan_name=row[1],
|
||||
status=PlanStatus(row[2]),
|
||||
created_at=row[3],
|
||||
updated_at=row[4],
|
||||
timeframe=Timeframe(row[5]),
|
||||
market_focus=MarketFocus(row[6]),
|
||||
entry_criteria=row[7],
|
||||
exit_criteria=row[8],
|
||||
stop_loss=row[9],
|
||||
profit_target=row[10],
|
||||
risk_reward_ratio=row[11],
|
||||
trade_frequency=TradeFrequency(row[12]),
|
||||
market_conditions=row[13],
|
||||
indicators_used=row[14],
|
||||
entry_confirmation=row[15],
|
||||
position_sizing=row[16],
|
||||
maximum_drawdown=row[17],
|
||||
max_trades_per_day=row[18],
|
||||
max_trades_per_week=row[19],
|
||||
total_risk_per_trade=row[20],
|
||||
max_portfolio_risk=row[21],
|
||||
adjustments_for_drawdown=row[22],
|
||||
risk_controls=row[23],
|
||||
win_rate=row[24],
|
||||
average_return_per_trade=row[25],
|
||||
profit_factor=row[26],
|
||||
historical_backtest_results=row[27],
|
||||
real_trade_performance=row[28],
|
||||
improvements_needed=row[29],
|
||||
strategy_version=row[30],
|
||||
plan_author=row[31],
|
||||
trade_review_notes=row[32],
|
||||
future_testing_ideas=row[33],
|
||||
sector_focus=row[34],
|
||||
fundamental_criteria=row[35],
|
||||
options_strategy_details=row[36]
|
||||
) for row in rows]
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error retrieving trading plans: {e}")
|
||||
return []
|
||||
|
||||
def unlink_trades_from_plan(plan_id: int) -> bool:
|
||||
"""Unlink all trades from a trading plan"""
|
||||
with create_client() as client:
|
||||
try:
|
||||
# First update the plan's metrics to NULL
|
||||
plan_update_query = """
|
||||
ALTER TABLE trading_plans
|
||||
UPDATE
|
||||
win_rate = NULL,
|
||||
average_return_per_trade = NULL,
|
||||
profit_factor = NULL
|
||||
WHERE id = %(plan_id)s
|
||||
"""
|
||||
client.command(plan_update_query, {'plan_id': plan_id})
|
||||
|
||||
# Then unlink the trades
|
||||
trades_update_query = """
|
||||
ALTER TABLE stock_db.trades
|
||||
UPDATE plan_id = NULL
|
||||
WHERE plan_id = %(plan_id)s
|
||||
"""
|
||||
client.command(trades_update_query, {'plan_id': plan_id})
|
||||
return True
|
||||
except Exception as e:
|
||||
print(f"Error unlinking trades from plan: {e}")
|
||||
return False
|
||||
|
||||
def delete_trading_plan(plan_id: int) -> bool:
|
||||
"""Delete a trading plan by ID"""
|
||||
with create_client() as client:
|
||||
try:
|
||||
# First unlink all trades
|
||||
unlink_trades_from_plan(plan_id)
|
||||
|
||||
# Then delete the plan
|
||||
query = "ALTER TABLE trading_plans DELETE WHERE id = %(id)s"
|
||||
client.command(query, {'id': plan_id})
|
||||
return True
|
||||
except Exception as e:
|
||||
print(f"Error deleting trading plan: {e}")
|
||||
return False
|
||||
|
||||
def link_trades_to_plan(plan_id: int, trade_ids: List[int]) -> bool:
|
||||
"""Link existing trades to a trading plan"""
|
||||
with create_client() as client:
|
||||
try:
|
||||
# Format the trade_ids properly for the IN clause
|
||||
trade_ids_str = ", ".join(str(id) for id in trade_ids)
|
||||
query = f"""
|
||||
ALTER TABLE stock_db.trades
|
||||
UPDATE plan_id = {plan_id}
|
||||
WHERE id IN ({trade_ids_str})
|
||||
"""
|
||||
client.command(query)
|
||||
return True
|
||||
except Exception as e:
|
||||
print(f"Error linking trades to plan: {e}")
|
||||
return False
|
||||
|
||||
def get_plan_trades(plan_id: int) -> List[dict]:
|
||||
"""Get all trades associated with a trading plan"""
|
||||
with create_client() as client:
|
||||
try:
|
||||
# First check if plan_id column exists
|
||||
check_query = """
|
||||
SELECT name
|
||||
FROM system.columns
|
||||
WHERE database = 'stock_db'
|
||||
AND table = 'trades'
|
||||
AND name = 'plan_id'
|
||||
"""
|
||||
result = client.query(check_query)
|
||||
|
||||
if not result.result_rows:
|
||||
# Add plan_id column if it doesn't exist
|
||||
alter_query = """
|
||||
ALTER TABLE stock_db.trades
|
||||
ADD COLUMN IF NOT EXISTS plan_id Nullable(UInt32)
|
||||
"""
|
||||
client.command(alter_query)
|
||||
print("Added plan_id column to trades table")
|
||||
|
||||
# Now query the trades
|
||||
query = """
|
||||
SELECT *
|
||||
FROM stock_db.trades
|
||||
WHERE plan_id = %(plan_id)s
|
||||
ORDER BY entry_date DESC
|
||||
"""
|
||||
result = client.query(query, {'plan_id': plan_id})
|
||||
return [dict(zip(
|
||||
['id', 'position_id', 'ticker', 'entry_date', 'shares', 'entry_price',
|
||||
'target_price', 'stop_loss', 'strategy', 'order_type', 'direction',
|
||||
'followed_rules', 'entry_reason', 'exit_price', 'exit_date',
|
||||
'exit_reason', 'notes', 'created_at', 'plan_id'],
|
||||
row
|
||||
)) for row in result.result_rows]
|
||||
except Exception as e:
|
||||
print(f"Error in get_plan_trades: {e}")
|
||||
return []
|
||||
|
||||
def calculate_plan_metrics(plan_id: int) -> dict:
|
||||
"""Calculate performance metrics for a trading plan"""
|
||||
trades = get_plan_trades(plan_id)
|
||||
if not trades:
|
||||
return {}
|
||||
|
||||
total_trades = len(trades)
|
||||
winning_trades = sum(1 for t in trades if t['exit_price'] and
|
||||
(t['exit_price'] - t['entry_price']) * t['shares'] > 0)
|
||||
total_profit = sum((t['exit_price'] - t['entry_price']) * t['shares']
|
||||
for t in trades if t['exit_price'])
|
||||
|
||||
gross_profits = sum((t['exit_price'] - t['entry_price']) * t['shares']
|
||||
for t in trades if t['exit_price'] and
|
||||
(t['exit_price'] - t['entry_price']) * t['shares'] > 0)
|
||||
|
||||
gross_losses = abs(sum((t['exit_price'] - t['entry_price']) * t['shares']
|
||||
for t in trades if t['exit_price'] and
|
||||
(t['exit_price'] - t['entry_price']) * t['shares'] < 0))
|
||||
|
||||
return {
|
||||
'total_trades': total_trades,
|
||||
'winning_trades': winning_trades,
|
||||
'win_rate': (winning_trades / total_trades * 100) if total_trades > 0 else 0,
|
||||
'total_profit': total_profit,
|
||||
'average_return': total_profit / total_trades if total_trades > 0 else 0,
|
||||
'profit_factor': gross_profits / gross_losses if gross_losses > 0 else float('inf')
|
||||
}
|
||||
|
||||
def update_trading_plan(plan: TradingPlan) -> bool:
|
||||
"""Update an existing trading plan"""
|
||||
if not plan.id:
|
||||
raise ValueError("Cannot update plan without ID")
|
||||
|
||||
with create_client() as client:
|
||||
plan.updated_at = datetime.now()
|
||||
|
||||
query = """
|
||||
ALTER TABLE trading_plans
|
||||
UPDATE
|
||||
plan_name = %(plan_name)s,
|
||||
status = %(status)s,
|
||||
updated_at = %(updated_at)s,
|
||||
timeframe = %(timeframe)s,
|
||||
market_focus = %(market_focus)s,
|
||||
entry_criteria = %(entry_criteria)s,
|
||||
exit_criteria = %(exit_criteria)s,
|
||||
stop_loss = %(stop_loss)s,
|
||||
profit_target = %(profit_target)s,
|
||||
risk_reward_ratio = %(risk_reward_ratio)s,
|
||||
trade_frequency = %(trade_frequency)s,
|
||||
market_conditions = %(market_conditions)s,
|
||||
indicators_used = %(indicators_used)s,
|
||||
entry_confirmation = %(entry_confirmation)s,
|
||||
position_sizing = %(position_sizing)s,
|
||||
maximum_drawdown = %(maximum_drawdown)s,
|
||||
max_trades_per_day = %(max_trades_per_day)s,
|
||||
max_trades_per_week = %(max_trades_per_week)s,
|
||||
total_risk_per_trade = %(total_risk_per_trade)s,
|
||||
max_portfolio_risk = %(max_portfolio_risk)s,
|
||||
adjustments_for_drawdown = %(adjustments_for_drawdown)s,
|
||||
risk_controls = %(risk_controls)s,
|
||||
win_rate = %(win_rate)s,
|
||||
average_return_per_trade = %(average_return_per_trade)s,
|
||||
profit_factor = %(profit_factor)s,
|
||||
historical_backtest_results = %(historical_backtest_results)s,
|
||||
real_trade_performance = %(real_trade_performance)s,
|
||||
improvements_needed = %(improvements_needed)s,
|
||||
strategy_version = %(strategy_version)s,
|
||||
plan_author = %(plan_author)s,
|
||||
trade_review_notes = %(trade_review_notes)s,
|
||||
future_testing_ideas = %(future_testing_ideas)s,
|
||||
sector_focus = %(sector_focus)s,
|
||||
fundamental_criteria = %(fundamental_criteria)s,
|
||||
options_strategy_details = %(options_strategy_details)s
|
||||
WHERE id = %(id)s
|
||||
"""
|
||||
|
||||
params = {
|
||||
'id': plan.id,
|
||||
'plan_name': plan.plan_name,
|
||||
'status': plan.status.value,
|
||||
'updated_at': plan.updated_at,
|
||||
'timeframe': plan.timeframe.value,
|
||||
'market_focus': plan.market_focus.value,
|
||||
'entry_criteria': plan.entry_criteria,
|
||||
'exit_criteria': plan.exit_criteria,
|
||||
'stop_loss': plan.stop_loss,
|
||||
'profit_target': plan.profit_target,
|
||||
'risk_reward_ratio': plan.risk_reward_ratio,
|
||||
'trade_frequency': plan.trade_frequency.value,
|
||||
'market_conditions': plan.market_conditions,
|
||||
'indicators_used': plan.indicators_used,
|
||||
'entry_confirmation': plan.entry_confirmation,
|
||||
'position_sizing': plan.position_sizing,
|
||||
'maximum_drawdown': plan.maximum_drawdown,
|
||||
'max_trades_per_day': plan.max_trades_per_day,
|
||||
'max_trades_per_week': plan.max_trades_per_week,
|
||||
'total_risk_per_trade': plan.total_risk_per_trade,
|
||||
'max_portfolio_risk': plan.max_portfolio_risk,
|
||||
'adjustments_for_drawdown': plan.adjustments_for_drawdown,
|
||||
'risk_controls': plan.risk_controls,
|
||||
'win_rate': plan.win_rate,
|
||||
'average_return_per_trade': plan.average_return_per_trade,
|
||||
'profit_factor': plan.profit_factor,
|
||||
'historical_backtest_results': plan.historical_backtest_results,
|
||||
'real_trade_performance': plan.real_trade_performance,
|
||||
'improvements_needed': plan.improvements_needed,
|
||||
'strategy_version': plan.strategy_version,
|
||||
'plan_author': plan.plan_author,
|
||||
'trade_review_notes': plan.trade_review_notes,
|
||||
'future_testing_ideas': plan.future_testing_ideas,
|
||||
'sector_focus': plan.sector_focus,
|
||||
'fundamental_criteria': plan.fundamental_criteria,
|
||||
'options_strategy_details': plan.options_strategy_details
|
||||
}
|
||||
|
||||
client.command(query, params)
|
||||
return True
|
||||
@ -1,166 +0,0 @@
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from typing import List, Optional
|
||||
from db.db_connection import create_client
|
||||
import logging
|
||||
import streamlit as st
|
||||
import time
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
def ensure_tables_exist():
|
||||
with create_client() as client:
|
||||
# First ensure database exists
|
||||
client.command("CREATE DATABASE IF NOT EXISTS stock_db")
|
||||
|
||||
# Create watchlists table if not exists
|
||||
client.command("""
|
||||
CREATE TABLE IF NOT EXISTS stock_db.watchlists (
|
||||
id UInt32,
|
||||
name String,
|
||||
strategy String,
|
||||
created_at DateTime
|
||||
)
|
||||
ENGINE = MergeTree()
|
||||
ORDER BY (id)
|
||||
""")
|
||||
|
||||
# Create watchlist_items table if not exists
|
||||
client.command("""
|
||||
CREATE TABLE IF NOT EXISTS stock_db.watchlist_items (
|
||||
id UInt32,
|
||||
watchlist_id UInt32,
|
||||
ticker String,
|
||||
entry_price Float64,
|
||||
target_price Float64,
|
||||
stop_loss Float64,
|
||||
shares Int32,
|
||||
notes String,
|
||||
created_at DateTime
|
||||
)
|
||||
ENGINE = MergeTree()
|
||||
ORDER BY (id)
|
||||
""")
|
||||
|
||||
@dataclass
|
||||
class WatchlistItem:
|
||||
ticker: str
|
||||
entry_price: float
|
||||
target_price: float
|
||||
stop_loss: float
|
||||
shares: Optional[int] = None
|
||||
notes: Optional[str] = None
|
||||
watchlist_id: Optional[int] = None
|
||||
id: Optional[int] = None
|
||||
created_at: Optional[datetime] = None
|
||||
|
||||
def create_watchlist(name: str, strategy: Optional[str] = None) -> int:
|
||||
with create_client() as client:
|
||||
# Ensure tables exist before creating new watchlist
|
||||
ensure_tables_exist()
|
||||
|
||||
# Get the next available ID
|
||||
result = client.query("SELECT max(id) + 1 as next_id FROM stock_db.watchlists")
|
||||
next_id = result.first_row[0] if result.first_row[0] is not None else 1
|
||||
|
||||
client.insert('stock_db.watchlists',
|
||||
[(next_id, name, strategy or '', datetime.now())],
|
||||
column_names=['id', 'name', 'strategy', 'created_at'])
|
||||
return next_id
|
||||
|
||||
def get_watchlists() -> List[dict]:
|
||||
with create_client() as client:
|
||||
result = client.query("SELECT * FROM stock_db.watchlists ORDER BY name")
|
||||
return [dict(zip(result.column_names, row)) for row in result.result_rows]
|
||||
|
||||
def add_to_watchlist(watchlist_id: int, item: WatchlistItem) -> bool:
|
||||
with create_client() as client:
|
||||
try:
|
||||
# Get next ID
|
||||
result = client.query("SELECT max(id) + 1 as next_id FROM stock_db.watchlist_items")
|
||||
next_id = result.first_row[0] if result.first_row[0] is not None else 1
|
||||
|
||||
# Insert the item
|
||||
data = [(
|
||||
next_id,
|
||||
watchlist_id,
|
||||
item.ticker,
|
||||
item.entry_price,
|
||||
item.target_price,
|
||||
item.stop_loss,
|
||||
item.shares or 0,
|
||||
item.notes or '',
|
||||
datetime.now()
|
||||
)]
|
||||
|
||||
client.insert(
|
||||
'stock_db.watchlist_items',
|
||||
data,
|
||||
column_names=[
|
||||
'id', 'watchlist_id', 'ticker', 'entry_price',
|
||||
'target_price', 'stop_loss', 'shares', 'notes', 'created_at'
|
||||
]
|
||||
)
|
||||
|
||||
# Verify the insert
|
||||
verify_result = client.query(f"""
|
||||
SELECT * FROM stock_db.watchlist_items
|
||||
WHERE id = {next_id}
|
||||
""")
|
||||
|
||||
return len(verify_result.result_rows) > 0
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error adding to watchlist: {e}", exc_info=True)
|
||||
raise
|
||||
|
||||
def remove_from_watchlist(item_id: int) -> bool:
|
||||
with create_client() as client:
|
||||
try:
|
||||
client.command(f"""
|
||||
ALTER TABLE stock_db.watchlist_items
|
||||
DELETE WHERE id = {item_id}
|
||||
""")
|
||||
return True
|
||||
except Exception as e:
|
||||
print(f"Error removing from watchlist: {e}")
|
||||
return False
|
||||
|
||||
def get_watchlist_items(watchlist_id: Optional[int] = None) -> List[WatchlistItem]:
|
||||
with create_client() as client:
|
||||
try:
|
||||
# Ensure tables exist before querying
|
||||
ensure_tables_exist()
|
||||
|
||||
query = """
|
||||
SELECT i.*, w.name as watchlist_name, w.strategy
|
||||
FROM stock_db.watchlist_items i
|
||||
JOIN stock_db.watchlists w ON w.id = i.watchlist_id
|
||||
"""
|
||||
if watchlist_id:
|
||||
query += f" WHERE i.watchlist_id = {watchlist_id}"
|
||||
query += " ORDER BY i.created_at DESC"
|
||||
|
||||
logger.info(f"Executing query: {query}")
|
||||
|
||||
result = client.query(query)
|
||||
items = []
|
||||
for row in result.result_rows:
|
||||
data = dict(zip(result.column_names, row))
|
||||
logger.info(f"Processing row: {data}")
|
||||
items.append(WatchlistItem(
|
||||
id=data['id'],
|
||||
watchlist_id=data['watchlist_id'],
|
||||
ticker=data['ticker'],
|
||||
entry_price=data['entry_price'],
|
||||
target_price=data['target_price'],
|
||||
stop_loss=data['stop_loss'],
|
||||
shares=data.get('shares', 0), # Add default value
|
||||
notes=data['notes'],
|
||||
created_at=data['created_at']
|
||||
))
|
||||
return items
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting watchlist items: {e}", exc_info=True)
|
||||
return []
|
||||
@ -1 +1,3 @@
|
||||
# Empty file
|
||||
from .data_utils import get_stock_data
|
||||
|
||||
__all__ = ['get_stock_data']
|
||||
|
||||
@ -1,172 +0,0 @@
|
||||
import os
|
||||
import pandas as pd
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Optional
|
||||
from db.db_connection import create_client
|
||||
|
||||
def get_user_input(prompt: str, input_type: type = str, allow_empty: bool = False) -> Optional[any]:
|
||||
"""
|
||||
Get user input with escape option
|
||||
"""
|
||||
while True:
|
||||
value = input(f"{prompt} (q to quit): ").strip()
|
||||
|
||||
if value.lower() in ['q', 'quit', 'exit']:
|
||||
return None
|
||||
|
||||
if not value and allow_empty:
|
||||
return None
|
||||
|
||||
try:
|
||||
if input_type == bool:
|
||||
return value.lower() in ['y', 'yes', 'true', '1']
|
||||
return input_type(value)
|
||||
except ValueError:
|
||||
print(f"Please enter a valid {input_type.__name__}")
|
||||
|
||||
def get_stock_data(ticker: str, start_date: datetime, end_date: datetime, interval: str) -> pd.DataFrame:
|
||||
"""
|
||||
Fetch and resample stock data based on the chosen interval
|
||||
"""
|
||||
try:
|
||||
with create_client() as client:
|
||||
# Expand window to get enough data for calculations
|
||||
start_date = start_date - timedelta(days=90)
|
||||
|
||||
query = f"""
|
||||
SELECT
|
||||
toDateTime(window_start/1000000000) as date,
|
||||
open,
|
||||
high,
|
||||
low,
|
||||
close,
|
||||
volume
|
||||
FROM stock_db.stock_prices
|
||||
WHERE ticker = '{ticker}'
|
||||
AND window_start BETWEEN
|
||||
{int(start_date.timestamp() * 1e9)} AND
|
||||
{int(end_date.timestamp() * 1e9)}
|
||||
AND toDateTime(window_start/1000000000) <= now()
|
||||
ORDER BY date ASC
|
||||
"""
|
||||
|
||||
result = client.query(query)
|
||||
|
||||
if not result.result_rows:
|
||||
return pd.DataFrame()
|
||||
|
||||
df = pd.DataFrame(
|
||||
result.result_rows,
|
||||
columns=['date', 'open', 'high', 'low', 'close', 'volume']
|
||||
)
|
||||
|
||||
numeric_columns = ['open', 'high', 'low', 'close', 'volume']
|
||||
for col in numeric_columns:
|
||||
df[col] = pd.to_numeric(df[col], errors='coerce')
|
||||
|
||||
df['date'] = pd.to_datetime(df['date'])
|
||||
df.set_index('date', inplace=True)
|
||||
|
||||
if interval == 'daily':
|
||||
rule = '1D'
|
||||
elif interval == '5min':
|
||||
rule = '5T'
|
||||
elif interval == '15min':
|
||||
rule = '15T'
|
||||
elif interval == '30min':
|
||||
rule = '30T'
|
||||
elif interval == '1hour':
|
||||
rule = '1H'
|
||||
else:
|
||||
rule = '1D'
|
||||
|
||||
resampled = df.resample(rule).agg({
|
||||
'open': 'first',
|
||||
'high': 'max',
|
||||
'low': 'min',
|
||||
'close': 'last',
|
||||
'volume': 'sum'
|
||||
}).dropna()
|
||||
|
||||
resampled.reset_index(inplace=True)
|
||||
|
||||
mask = (resampled['date'] >= start_date + timedelta(days=89)) & (resampled['date'] <= end_date)
|
||||
resampled = resampled.loc[mask]
|
||||
|
||||
if resampled['close'].isnull().any():
|
||||
print(f"Warning: Found null values in close prices")
|
||||
resampled = resampled.dropna(subset=['close'])
|
||||
|
||||
if resampled.empty or 'close' not in resampled.columns:
|
||||
return pd.DataFrame()
|
||||
|
||||
return resampled
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error fetching {ticker} data: {str(e)}")
|
||||
return pd.DataFrame()
|
||||
|
||||
def get_qualified_stocks(start_date: datetime, end_date: datetime, min_price: float, max_price: float, min_volume: int) -> list:
|
||||
"""
|
||||
Get qualified stocks based on price and volume criteria within date range
|
||||
"""
|
||||
try:
|
||||
start_ts = int(start_date.timestamp() * 1000000000)
|
||||
end_ts = int(end_date.timestamp() * 1000000000)
|
||||
|
||||
with create_client() as client:
|
||||
query = f"""
|
||||
WITH filtered_data AS (
|
||||
SELECT
|
||||
sp.ticker,
|
||||
sp.window_start,
|
||||
sp.close,
|
||||
sp.volume,
|
||||
t.type as stock_type,
|
||||
toDateTime(toDateTime(sp.window_start/1000000000)) as trade_date
|
||||
FROM stock_db.stock_prices sp
|
||||
JOIN stock_db.stock_tickers t ON sp.ticker = t.ticker
|
||||
WHERE window_start BETWEEN {start_ts} AND {end_ts}
|
||||
AND toDateTime(window_start/1000000000) <= now()
|
||||
AND close BETWEEN {min_price} AND {max_price}
|
||||
AND volume >= {min_volume}
|
||||
),
|
||||
daily_data AS (
|
||||
SELECT
|
||||
ticker,
|
||||
stock_type,
|
||||
toDate(trade_date) as date,
|
||||
argMax(close, window_start) as daily_close,
|
||||
sum(volume) as daily_volume
|
||||
FROM filtered_data
|
||||
GROUP BY ticker, stock_type, toDate(trade_date)
|
||||
),
|
||||
latest_data AS (
|
||||
SELECT
|
||||
ticker,
|
||||
any(stock_type) as stock_type,
|
||||
argMax(daily_close, date) as last_close,
|
||||
sum(daily_volume) as total_volume,
|
||||
max(toUnixTimestamp(date)) as last_update
|
||||
FROM daily_data
|
||||
GROUP BY ticker
|
||||
HAVING last_close BETWEEN {min_price} AND {max_price}
|
||||
)
|
||||
SELECT
|
||||
ticker,
|
||||
last_close,
|
||||
total_volume,
|
||||
last_update,
|
||||
stock_type
|
||||
FROM latest_data
|
||||
ORDER BY ticker
|
||||
"""
|
||||
|
||||
result = client.query(query)
|
||||
qualified_stocks = [(row[0], row[1], row[2], row[3], row[4]) for row in result.result_rows]
|
||||
|
||||
return qualified_stocks
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error getting qualified stocks: {str(e)}")
|
||||
return []
|
||||
@ -1,53 +1,16 @@
|
||||
import os
|
||||
import os
|
||||
import pandas as pd
|
||||
from datetime import datetime, timedelta
|
||||
from trading.position_calculator import PositionCalculator
|
||||
from utils.common_utils import get_user_input, get_stock_data, get_qualified_stocks
|
||||
from typing import Optional
|
||||
from db.db_connection import create_client
|
||||
|
||||
def get_float_input(prompt: str) -> Optional[float]:
|
||||
return get_user_input(prompt, float)
|
||||
|
||||
def get_current_prices(tickers: list) -> dict:
|
||||
"""Get current prices for multiple tickers using yfinance"""
|
||||
if not tickers:
|
||||
return {}
|
||||
|
||||
prices = {}
|
||||
try:
|
||||
with create_client() as client:
|
||||
for ticker in tickers:
|
||||
query = f"""
|
||||
SELECT close
|
||||
FROM stock_db.stock_prices_daily
|
||||
WHERE ticker = '{ticker}'
|
||||
ORDER BY date DESC
|
||||
LIMIT 1
|
||||
"""
|
||||
result = client.query(query)
|
||||
|
||||
try:
|
||||
prices[ticker] = result.result_rows[0][0]
|
||||
except KeyError:
|
||||
prices[ticker] = 0.0
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error fetching prices: {e}")
|
||||
# If there's an error, set price to 0.0
|
||||
for ticker in tickers:
|
||||
if ticker not in prices:
|
||||
prices[ticker] = 0.0
|
||||
|
||||
return prices
|
||||
|
||||
def validate_signal_date(signal_date: datetime) -> datetime:
|
||||
"""
|
||||
Validate and adjust signal date if needed
|
||||
|
||||
|
||||
Args:
|
||||
signal_date (datetime): Signal date to validate
|
||||
|
||||
|
||||
Returns:
|
||||
datetime: Valid signal date (not in future)
|
||||
"""
|
||||
@ -56,30 +19,99 @@ def validate_signal_date(signal_date: datetime) -> datetime:
|
||||
return current_date
|
||||
return signal_date
|
||||
|
||||
def print_signal(signal_dict, signal_type: str = "🔍") -> None:
|
||||
def print_signal(signal_data: dict, signal_type: str = "🔍") -> None:
|
||||
"""
|
||||
Print standardized signal output
|
||||
|
||||
|
||||
Args:
|
||||
signal_data (dict): Dictionary containing signal information
|
||||
signal_type (str): Emoji indicator for signal type (default: 🔍)
|
||||
"""
|
||||
try:
|
||||
print(f"\n{signal_type} {signal_dict['ticker']} ({signal_dict['stock_type']}) @ ${signal_dict['entry_price']:.2f} on {signal_dict['signal_date'].strftime('%Y-%m-%d %H:%M')}")
|
||||
print(f" Size: {signal_dict['shares']} shares (${signal_dict['position_size']:.2f})")
|
||||
print(f" Stop: ${signal_dict['stop_loss']:.2f} (7%) | Target: ${signal_dict['target_price']:.2f}")
|
||||
print(f" Risk/Reward: 1:{signal_dict['risk_reward_ratio']:.1f} | Risk: ${abs(signal_dict['risk_amount']):.2f}")
|
||||
print(f" Potential Profit: ${signal_dict['profit_amount']:.2f}")
|
||||
print(f"\n{signal_type} {signal_data['ticker']} @ ${signal_data['entry_price']:.2f} on {signal_data['signal_date'].strftime('%Y-%m-%d %H:%M')}")
|
||||
print(f" Size: {signal_data['shares']} shares (${signal_data['position_size']:.2f})")
|
||||
print(f" Stop: ${signal_data['stop_loss']:.2f} (7%) | Target: ${signal_data['target_price']:.2f}")
|
||||
print(f" Risk/Reward: 1:{signal_data['risk_reward_ratio']:.1f} | Risk: ${abs(signal_data['risk_amount']):.2f}")
|
||||
print(f" Potential Profit: ${signal_data['profit_amount']:.2f}")
|
||||
except KeyError as e:
|
||||
print(f"Error printing signal for {signal_dict.get('ticker', 'Unknown')}: Missing key {e}")
|
||||
print(f"Error printing signal for {signal_data.get('ticker', 'Unknown')}: Missing key {e}")
|
||||
# Print available keys for debugging
|
||||
print(f"Available keys: {list(signal_dict.keys())}")
|
||||
print(f"Available keys: {list(signal_data.keys())}")
|
||||
|
||||
def get_qualified_stocks(start_date: datetime, end_date: datetime, min_price: float, max_price: float, min_volume: int) -> list:
|
||||
"""
|
||||
Get qualified stocks based on price and volume criteria within date range
|
||||
|
||||
Args:
|
||||
start_date (datetime): Start date for data fetch
|
||||
end_date (datetime): End date for data fetch
|
||||
min_price (float): Minimum stock price
|
||||
max_price (float): Maximum stock price
|
||||
min_volume (int): Minimum trading volume
|
||||
|
||||
Returns:
|
||||
list: List of tuples (ticker, price, volume, last_update)
|
||||
"""
|
||||
try:
|
||||
start_ts = int(start_date.timestamp() * 1000000000)
|
||||
end_ts = int(end_date.timestamp() * 1000000000)
|
||||
|
||||
with create_client() as client:
|
||||
query = f"""
|
||||
WITH filtered_data AS (
|
||||
SELECT
|
||||
ticker,
|
||||
window_start,
|
||||
close,
|
||||
volume,
|
||||
toDateTime(toDateTime(window_start/1000000000)) as trade_date
|
||||
FROM stock_db.stock_prices
|
||||
WHERE window_start BETWEEN {start_ts} AND {end_ts}
|
||||
AND toDateTime(window_start/1000000000) <= now()
|
||||
AND close BETWEEN {min_price} AND {max_price}
|
||||
AND volume >= {min_volume}
|
||||
),
|
||||
daily_data AS (
|
||||
SELECT
|
||||
ticker,
|
||||
toDate(trade_date) as date,
|
||||
argMax(close, window_start) as daily_close,
|
||||
sum(volume) as daily_volume
|
||||
FROM filtered_data
|
||||
GROUP BY ticker, toDate(trade_date)
|
||||
),
|
||||
latest_data AS (
|
||||
SELECT
|
||||
ticker,
|
||||
argMax(daily_close, date) as last_close,
|
||||
sum(daily_volume) as total_volume,
|
||||
max(toUnixTimestamp(date)) as last_update
|
||||
FROM daily_data
|
||||
GROUP BY ticker
|
||||
HAVING last_close BETWEEN {min_price} AND {max_price}
|
||||
)
|
||||
SELECT
|
||||
ticker,
|
||||
last_close,
|
||||
total_volume,
|
||||
last_update
|
||||
FROM latest_data
|
||||
ORDER BY ticker
|
||||
"""
|
||||
|
||||
result = client.query(query)
|
||||
qualified_stocks = [(row[0], row[1], row[2], row[3]) for row in result.result_rows]
|
||||
|
||||
return qualified_stocks
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error getting qualified stocks: {str(e)}")
|
||||
return []
|
||||
|
||||
def save_signals_to_csv(signals: list, scanner_name: str) -> None:
|
||||
"""
|
||||
Save signals to CSV file with standardized format and naming
|
||||
|
||||
|
||||
Args:
|
||||
signals (list): List of signal dictionaries
|
||||
scanner_name (str): Name of the scanner for file naming
|
||||
@ -87,12 +119,113 @@ def save_signals_to_csv(signals: list, scanner_name: str) -> None:
|
||||
if not signals:
|
||||
print("\nNo signals found")
|
||||
return
|
||||
|
||||
|
||||
output_dir = 'reports'
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
output_date = datetime.now().strftime("%Y%m%d_%H%M")
|
||||
output_file = f'{output_dir}/{scanner_name}_{output_date}.csv'
|
||||
|
||||
|
||||
df_signals = pd.DataFrame(signals)
|
||||
df_signals.to_csv(output_file, index=False)
|
||||
print(f"\nSaved {len(signals)} signals to {output_file}")
|
||||
|
||||
def get_stock_data(ticker: str, start_date: datetime, end_date: datetime, interval: str) -> pd.DataFrame:
|
||||
"""
|
||||
Fetch and resample stock data based on the chosen interval
|
||||
|
||||
Args:
|
||||
ticker (str): Stock ticker symbol
|
||||
start_date (datetime): Start date for data fetch
|
||||
end_date (datetime): End date for data fetch
|
||||
interval (str): Time interval for data ('daily', '5min', '15min', '30min', '1hour')
|
||||
|
||||
Returns:
|
||||
pd.DataFrame: Resampled DataFrame with OHLCV data
|
||||
"""
|
||||
try:
|
||||
with create_client() as client:
|
||||
# Expand window to get enough data for calculations
|
||||
start_date = start_date - timedelta(days=90)
|
||||
|
||||
# Base query to get raw data at finest granularity
|
||||
query = f"""
|
||||
SELECT
|
||||
toDateTime(window_start/1000000000) as date,
|
||||
open,
|
||||
high,
|
||||
low,
|
||||
close,
|
||||
volume
|
||||
FROM stock_db.stock_prices
|
||||
WHERE ticker = '{ticker}'
|
||||
AND window_start BETWEEN
|
||||
{int(start_date.timestamp() * 1e9)} AND
|
||||
{int(end_date.timestamp() * 1e9)}
|
||||
AND toDateTime(window_start/1000000000) <= now()
|
||||
ORDER BY date ASC
|
||||
"""
|
||||
|
||||
result = client.query(query)
|
||||
|
||||
if not result.result_rows:
|
||||
return pd.DataFrame()
|
||||
|
||||
# Create base DataFrame
|
||||
df = pd.DataFrame(
|
||||
result.result_rows,
|
||||
columns=['date', 'open', 'high', 'low', 'close', 'volume']
|
||||
)
|
||||
|
||||
# Convert numeric columns
|
||||
numeric_columns = ['open', 'high', 'low', 'close', 'volume']
|
||||
for col in numeric_columns:
|
||||
df[col] = pd.to_numeric(df[col], errors='coerce')
|
||||
|
||||
# Convert date column
|
||||
df['date'] = pd.to_datetime(df['date'])
|
||||
|
||||
# Set date as index for resampling
|
||||
df.set_index('date', inplace=True)
|
||||
|
||||
# Resample based on interval
|
||||
if interval == 'daily':
|
||||
rule = '1D'
|
||||
elif interval == '5min':
|
||||
rule = '5T'
|
||||
elif interval == '15min':
|
||||
rule = '15T'
|
||||
elif interval == '30min':
|
||||
rule = '30T'
|
||||
elif interval == '1hour':
|
||||
rule = '1H'
|
||||
else:
|
||||
rule = '1D' # Default to daily
|
||||
|
||||
resampled = df.resample(rule).agg({
|
||||
'open': 'first',
|
||||
'high': 'max',
|
||||
'low': 'min',
|
||||
'close': 'last',
|
||||
'volume': 'sum'
|
||||
}).dropna()
|
||||
|
||||
# Reset index to get date as column
|
||||
resampled.reset_index(inplace=True)
|
||||
|
||||
# Filter to requested date range
|
||||
mask = (resampled['date'] >= start_date + timedelta(days=89)) & (resampled['date'] <= end_date)
|
||||
resampled = resampled.loc[mask]
|
||||
|
||||
# Handle null values
|
||||
if resampled['close'].isnull().any():
|
||||
print(f"Warning: Found null values in close prices")
|
||||
resampled = resampled.dropna(subset=['close'])
|
||||
|
||||
if resampled.empty or 'close' not in resampled.columns:
|
||||
return pd.DataFrame()
|
||||
|
||||
return resampled
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error fetching {ticker} data: {str(e)}")
|
||||
return pd.DataFrame()
|
||||
|
||||
@ -1,54 +0,0 @@
|
||||
import os
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
|
||||
def load_scanner_reports(scanner_type: str = None):
|
||||
"""
|
||||
Load and return available scanner reports
|
||||
|
||||
Args:
|
||||
scanner_type (str, optional): Filter reports by scanner type (e.g., 'technical', 'canslim')
|
||||
|
||||
Returns:
|
||||
list: List of dictionaries containing report information
|
||||
"""
|
||||
reports = []
|
||||
|
||||
# Get absolute path to reports directory
|
||||
reports_dir = Path.cwd() / "reports"
|
||||
|
||||
# Create reports directory if it doesn't exist
|
||||
reports_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
if reports_dir.exists():
|
||||
# Print debug info
|
||||
print(f"Scanning directory: {reports_dir}")
|
||||
|
||||
for file in reports_dir.glob("*.csv"):
|
||||
print(f"Found file: {file.name}") # Debug print
|
||||
|
||||
# Get file creation time
|
||||
created = datetime.fromtimestamp(file.stat().st_ctime)
|
||||
|
||||
# If scanner_type is specified, apply appropriate filtering
|
||||
if scanner_type == "canslim":
|
||||
if not file.name.startswith("canslim_"):
|
||||
print(f"Skipping {file.name} - not a CANSLIM report") # Debug print
|
||||
continue
|
||||
elif scanner_type == "non_canslim":
|
||||
if file.name.startswith("canslim_"):
|
||||
print(f"Skipping {file.name} - is a CANSLIM report") # Debug print
|
||||
continue
|
||||
|
||||
reports.append({
|
||||
'name': file.name,
|
||||
'path': str(file),
|
||||
'created': created
|
||||
})
|
||||
print(f"Added report: {file.name}") # Debug print
|
||||
|
||||
# Print final count
|
||||
print(f"Found {len(reports)} matching reports")
|
||||
|
||||
# Sort by creation time, newest first
|
||||
return sorted(reports, key=lambda x: x['created'], reverse=True)
|
||||
@ -1,85 +0,0 @@
|
||||
from datetime import datetime, timedelta
|
||||
from utils.common_utils import get_user_input, get_stock_data, get_qualified_stocks
|
||||
from screener.user_input import get_interval_choice, get_date_range
|
||||
from trading.position_calculator import PositionCalculator
|
||||
from typing import Optional
|
||||
|
||||
def initialize_scanner(min_price: float, max_price: float, min_volume: int,
|
||||
portfolio_size: float = None, interval: str = "1d",
|
||||
start_date: datetime = None, end_date: datetime = None) -> tuple:
|
||||
"""
|
||||
Initialize common scanner components
|
||||
|
||||
Args:
|
||||
min_price (float): Minimum stock price
|
||||
max_price (float): Maximum stock price
|
||||
min_volume (int): Minimum volume threshold
|
||||
portfolio_size (float, optional): Portfolio size for position calculations
|
||||
interval (str, optional): Time interval for data (default: "1d")
|
||||
start_date (datetime, optional): Start date for scanning
|
||||
end_date (datetime, optional): End date for scanning
|
||||
"""
|
||||
print(f"\nScanning for stocks ${min_price:.2f}-${max_price:.2f} with min volume {min_volume:,}")
|
||||
|
||||
if not start_date or not end_date:
|
||||
raise ValueError("start_date and end_date must be provided")
|
||||
|
||||
qualified_stocks = get_qualified_stocks(start_date, end_date, min_price, max_price, min_volume)
|
||||
|
||||
if not qualified_stocks:
|
||||
print("No stocks found matching criteria.")
|
||||
return None, None, None, None, None
|
||||
|
||||
print(f"\nFound {len(qualified_stocks)} stocks matching criteria")
|
||||
|
||||
# Initialize position calculator if portfolio size provided
|
||||
calculator = None
|
||||
if portfolio_size and portfolio_size > 0:
|
||||
calculator = PositionCalculator(
|
||||
account_size=portfolio_size,
|
||||
risk_percentage=1.0,
|
||||
stop_loss_percentage=7.0
|
||||
)
|
||||
|
||||
return interval, start_date, end_date, qualified_stocks, calculator
|
||||
|
||||
def process_signal_data(ticker: str, signal_data: dict, current_volume: int,
|
||||
last_update: int, stock_type: str, calculator: PositionCalculator = None) -> dict:
|
||||
"""
|
||||
Process and format signal data consistently
|
||||
"""
|
||||
entry_price = signal_data['price']
|
||||
|
||||
# Determine target price based on signal type
|
||||
if 'ha_close' in signal_data: # Heikin Ashi signal
|
||||
# Use a 2:1 reward-to-risk ratio for Heikin Ashi
|
||||
stop_loss_pct = 0.07 # 7% stop loss
|
||||
stop_distance = entry_price * stop_loss_pct
|
||||
target_price = entry_price + (stop_distance * 2) # 2x the stop distance
|
||||
else:
|
||||
# Handle other signal types (ATR-EMA or Sunny Bands)
|
||||
target_price = signal_data.get('ema', signal_data.get('upper_band'))
|
||||
|
||||
entry_data = {
|
||||
'ticker': ticker,
|
||||
'entry_price': entry_price,
|
||||
'target_price': target_price,
|
||||
'volume': current_volume,
|
||||
'signal_date': signal_data.get('date', datetime.now()),
|
||||
'stock_type': stock_type,
|
||||
'last_update': datetime.fromtimestamp(last_update/1000000000)
|
||||
}
|
||||
|
||||
if calculator:
|
||||
position = calculator.calculate_position_size(entry_price)
|
||||
potential_profit = (target_price - entry_price) * position['shares']
|
||||
entry_data.update({
|
||||
'shares': position['shares'],
|
||||
'position_size': position['position_value'],
|
||||
'stop_loss': position['stop_loss'],
|
||||
'risk_amount': position['potential_loss'],
|
||||
'profit_amount': potential_profit,
|
||||
'risk_reward_ratio': abs(potential_profit / position['potential_loss']) if position['potential_loss'] != 0 else 0
|
||||
})
|
||||
|
||||
return entry_data
|
||||
Loading…
Reference in New Issue
Block a user