diff --git a/src/trading/journal.py b/src/trading/journal.py index dbfd728..f1d6e1d 100644 --- a/src/trading/journal.py +++ b/src/trading/journal.py @@ -5,6 +5,7 @@ import pytz from zoneinfo import ZoneInfo from db.db_connection import create_client from trading.position_calculator import PositionCalculator +from utils.data_utils import get_user_input @dataclass class TradeEntry: @@ -73,7 +74,7 @@ def validate_market_time(dt: datetime) -> tuple[datetime, bool]: return dt, False -def get_datetime_input(prompt: str, default: datetime = None) -> datetime: +def get_datetime_input(prompt: str, default: datetime = None) -> Optional[datetime]: """Get date and time input in Pacific time""" pacific = pytz.timezone('US/Pacific') @@ -81,7 +82,10 @@ def get_datetime_input(prompt: str, default: datetime = None) -> datetime: try: if default: print(f"Press Enter for current time ({default.strftime('%Y-%m-%d %H:%M')})") - date_str = input(f"{prompt} (YYYY-MM-DD HH:MM): ").strip() + date_str = input(f"{prompt} (YYYY-MM-DD HH:MM, q to quit): ").strip() + + if date_str.lower() in ['q', 'quit', 'exit']: + return None if not date_str and default: dt = default @@ -167,14 +171,18 @@ def get_position_summary(ticker: str) -> dict: 'target_price', 'stop_loss', 'strategy'] return [dict(zip(columns, row)) for row in result] -def get_order_type() -> str: +def get_order_type() -> Optional[str]: """Get order type from user""" while True: print("\nOrder Type:") print("1. Market") print("2. Limit") - choice = input("Select order type (1-2): ") - if choice == "1": + print("q. Quit") + choice = input("Select order type (1-2, q to quit): ") + + if choice.lower() in ['q', 'quit', 'exit']: + return None + elif choice == "1": return "Market" elif choice == "2": return "Limit" @@ -283,7 +291,10 @@ def journal_menu(): choice = input("\nSelect an option (1-5): ") if choice == "1": - ticker = input("Enter ticker symbol: ").upper() + ticker = get_user_input("Enter ticker symbol:", str) + if ticker is None: + continue + ticker = ticker.upper() # Show existing positions for this ticker existing_positions = get_position_summary(ticker) @@ -296,9 +307,14 @@ def journal_menu(): print(f"First Entry: {pos['first_entry']}") print(f"Number of Orders: {pos['num_orders']}") - add_to_existing = input("\nAdd to existing position? (y/n): ").lower() == 'y' + add_to_existing = get_user_input("Add to existing position? (y/n):", bool) + if add_to_existing is None: + continue + if add_to_existing: - position_id = input("Enter Position ID: ") + position_id = get_user_input("Enter Position ID:", str) + if position_id is None: + continue else: position_id = generate_position_id(ticker) else: @@ -306,10 +322,20 @@ def journal_menu(): # Get entry date/time with market hours validation entry_date = get_datetime_input("Enter entry date and time", default=datetime.now()) + if entry_date is None: + continue - shares = int(input("Enter number of shares: ")) - entry_price = float(input("Enter entry price: ")) + shares = get_user_input("Enter number of shares:", int) + if shares is None: + continue + + entry_price = get_user_input("Enter entry price:", float) + if entry_price is None: + continue + order_type = get_order_type() + if order_type is None: + continue # If adding to existing position, get target/stop from existing if existing_positions and add_to_existing: diff --git a/src/utils/data_utils.py b/src/utils/data_utils.py index 325cc4b..1177de7 100644 --- a/src/utils/data_utils.py +++ b/src/utils/data_utils.py @@ -5,12 +5,38 @@ from db.db_connection import create_client from screener.user_input import get_interval_choice, get_date_range from trading.position_calculator import PositionCalculator -def get_float_input(prompt: str) -> float: +from typing import Optional + +def get_user_input(prompt: str, input_type: type = str, allow_empty: bool = False) -> Optional[any]: + """ + Get user input with escape option + + Args: + prompt (str): Input prompt to display + input_type (type): Expected input type (str, float, int) + allow_empty (bool): Whether to allow empty input + + Returns: + Optional[any]: Converted input value or None if user wants to exit + """ while True: + value = input(f"{prompt} (q to quit): ").strip() + + if value.lower() in ['q', 'quit', 'exit']: + return None + + if not value and allow_empty: + return None + try: - return float(input(prompt)) + if input_type == bool: + return value.lower() in ['y', 'yes', 'true', '1'] + return input_type(value) except ValueError: - print("Please enter a valid number") + print(f"Please enter a valid {input_type.__name__}") + +def get_float_input(prompt: str) -> Optional[float]: + return get_user_input(prompt, float) def validate_signal_date(signal_date: datetime) -> datetime: """