feat: Add user input escape mechanism with consistent quit options

This commit is contained in:
Bobby (aider) 2025-02-10 09:58:25 -08:00
parent 9ccccfce97
commit 3dbaed0399
2 changed files with 65 additions and 13 deletions

View File

@ -5,6 +5,7 @@ import pytz
from zoneinfo import ZoneInfo
from db.db_connection import create_client
from trading.position_calculator import PositionCalculator
from utils.data_utils import get_user_input
@dataclass
class TradeEntry:
@ -73,7 +74,7 @@ def validate_market_time(dt: datetime) -> tuple[datetime, bool]:
return dt, False
def get_datetime_input(prompt: str, default: datetime = None) -> datetime:
def get_datetime_input(prompt: str, default: datetime = None) -> Optional[datetime]:
"""Get date and time input in Pacific time"""
pacific = pytz.timezone('US/Pacific')
@ -81,7 +82,10 @@ def get_datetime_input(prompt: str, default: datetime = None) -> datetime:
try:
if default:
print(f"Press Enter for current time ({default.strftime('%Y-%m-%d %H:%M')})")
date_str = input(f"{prompt} (YYYY-MM-DD HH:MM): ").strip()
date_str = input(f"{prompt} (YYYY-MM-DD HH:MM, q to quit): ").strip()
if date_str.lower() in ['q', 'quit', 'exit']:
return None
if not date_str and default:
dt = default
@ -167,14 +171,18 @@ def get_position_summary(ticker: str) -> dict:
'target_price', 'stop_loss', 'strategy']
return [dict(zip(columns, row)) for row in result]
def get_order_type() -> str:
def get_order_type() -> Optional[str]:
"""Get order type from user"""
while True:
print("\nOrder Type:")
print("1. Market")
print("2. Limit")
choice = input("Select order type (1-2): ")
if choice == "1":
print("q. Quit")
choice = input("Select order type (1-2, q to quit): ")
if choice.lower() in ['q', 'quit', 'exit']:
return None
elif choice == "1":
return "Market"
elif choice == "2":
return "Limit"
@ -283,7 +291,10 @@ def journal_menu():
choice = input("\nSelect an option (1-5): ")
if choice == "1":
ticker = input("Enter ticker symbol: ").upper()
ticker = get_user_input("Enter ticker symbol:", str)
if ticker is None:
continue
ticker = ticker.upper()
# Show existing positions for this ticker
existing_positions = get_position_summary(ticker)
@ -296,9 +307,14 @@ def journal_menu():
print(f"First Entry: {pos['first_entry']}")
print(f"Number of Orders: {pos['num_orders']}")
add_to_existing = input("\nAdd to existing position? (y/n): ").lower() == 'y'
add_to_existing = get_user_input("Add to existing position? (y/n):", bool)
if add_to_existing is None:
continue
if add_to_existing:
position_id = input("Enter Position ID: ")
position_id = get_user_input("Enter Position ID:", str)
if position_id is None:
continue
else:
position_id = generate_position_id(ticker)
else:
@ -306,10 +322,20 @@ def journal_menu():
# Get entry date/time with market hours validation
entry_date = get_datetime_input("Enter entry date and time", default=datetime.now())
if entry_date is None:
continue
shares = int(input("Enter number of shares: "))
entry_price = float(input("Enter entry price: "))
shares = get_user_input("Enter number of shares:", int)
if shares is None:
continue
entry_price = get_user_input("Enter entry price:", float)
if entry_price is None:
continue
order_type = get_order_type()
if order_type is None:
continue
# If adding to existing position, get target/stop from existing
if existing_positions and add_to_existing:

View File

@ -5,12 +5,38 @@ from db.db_connection import create_client
from screener.user_input import get_interval_choice, get_date_range
from trading.position_calculator import PositionCalculator
def get_float_input(prompt: str) -> float:
from typing import Optional
def get_user_input(prompt: str, input_type: type = str, allow_empty: bool = False) -> Optional[any]:
"""
Get user input with escape option
Args:
prompt (str): Input prompt to display
input_type (type): Expected input type (str, float, int)
allow_empty (bool): Whether to allow empty input
Returns:
Optional[any]: Converted input value or None if user wants to exit
"""
while True:
value = input(f"{prompt} (q to quit): ").strip()
if value.lower() in ['q', 'quit', 'exit']:
return None
if not value and allow_empty:
return None
try:
return float(input(prompt))
if input_type == bool:
return value.lower() in ['y', 'yes', 'true', '1']
return input_type(value)
except ValueError:
print("Please enter a valid number")
print(f"Please enter a valid {input_type.__name__}")
def get_float_input(prompt: str) -> Optional[float]:
return get_user_input(prompt, float)
def validate_signal_date(signal_date: datetime) -> datetime:
"""