feat: Add user input escape mechanism with consistent quit options
This commit is contained in:
parent
9ccccfce97
commit
3dbaed0399
@ -5,6 +5,7 @@ import pytz
|
|||||||
from zoneinfo import ZoneInfo
|
from zoneinfo import ZoneInfo
|
||||||
from db.db_connection import create_client
|
from db.db_connection import create_client
|
||||||
from trading.position_calculator import PositionCalculator
|
from trading.position_calculator import PositionCalculator
|
||||||
|
from utils.data_utils import get_user_input
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class TradeEntry:
|
class TradeEntry:
|
||||||
@ -73,7 +74,7 @@ def validate_market_time(dt: datetime) -> tuple[datetime, bool]:
|
|||||||
|
|
||||||
return dt, False
|
return dt, False
|
||||||
|
|
||||||
def get_datetime_input(prompt: str, default: datetime = None) -> datetime:
|
def get_datetime_input(prompt: str, default: datetime = None) -> Optional[datetime]:
|
||||||
"""Get date and time input in Pacific time"""
|
"""Get date and time input in Pacific time"""
|
||||||
pacific = pytz.timezone('US/Pacific')
|
pacific = pytz.timezone('US/Pacific')
|
||||||
|
|
||||||
@ -81,7 +82,10 @@ def get_datetime_input(prompt: str, default: datetime = None) -> datetime:
|
|||||||
try:
|
try:
|
||||||
if default:
|
if default:
|
||||||
print(f"Press Enter for current time ({default.strftime('%Y-%m-%d %H:%M')})")
|
print(f"Press Enter for current time ({default.strftime('%Y-%m-%d %H:%M')})")
|
||||||
date_str = input(f"{prompt} (YYYY-MM-DD HH:MM): ").strip()
|
date_str = input(f"{prompt} (YYYY-MM-DD HH:MM, q to quit): ").strip()
|
||||||
|
|
||||||
|
if date_str.lower() in ['q', 'quit', 'exit']:
|
||||||
|
return None
|
||||||
|
|
||||||
if not date_str and default:
|
if not date_str and default:
|
||||||
dt = default
|
dt = default
|
||||||
@ -167,14 +171,18 @@ def get_position_summary(ticker: str) -> dict:
|
|||||||
'target_price', 'stop_loss', 'strategy']
|
'target_price', 'stop_loss', 'strategy']
|
||||||
return [dict(zip(columns, row)) for row in result]
|
return [dict(zip(columns, row)) for row in result]
|
||||||
|
|
||||||
def get_order_type() -> str:
|
def get_order_type() -> Optional[str]:
|
||||||
"""Get order type from user"""
|
"""Get order type from user"""
|
||||||
while True:
|
while True:
|
||||||
print("\nOrder Type:")
|
print("\nOrder Type:")
|
||||||
print("1. Market")
|
print("1. Market")
|
||||||
print("2. Limit")
|
print("2. Limit")
|
||||||
choice = input("Select order type (1-2): ")
|
print("q. Quit")
|
||||||
if choice == "1":
|
choice = input("Select order type (1-2, q to quit): ")
|
||||||
|
|
||||||
|
if choice.lower() in ['q', 'quit', 'exit']:
|
||||||
|
return None
|
||||||
|
elif choice == "1":
|
||||||
return "Market"
|
return "Market"
|
||||||
elif choice == "2":
|
elif choice == "2":
|
||||||
return "Limit"
|
return "Limit"
|
||||||
@ -283,7 +291,10 @@ def journal_menu():
|
|||||||
choice = input("\nSelect an option (1-5): ")
|
choice = input("\nSelect an option (1-5): ")
|
||||||
|
|
||||||
if choice == "1":
|
if choice == "1":
|
||||||
ticker = input("Enter ticker symbol: ").upper()
|
ticker = get_user_input("Enter ticker symbol:", str)
|
||||||
|
if ticker is None:
|
||||||
|
continue
|
||||||
|
ticker = ticker.upper()
|
||||||
|
|
||||||
# Show existing positions for this ticker
|
# Show existing positions for this ticker
|
||||||
existing_positions = get_position_summary(ticker)
|
existing_positions = get_position_summary(ticker)
|
||||||
@ -296,9 +307,14 @@ def journal_menu():
|
|||||||
print(f"First Entry: {pos['first_entry']}")
|
print(f"First Entry: {pos['first_entry']}")
|
||||||
print(f"Number of Orders: {pos['num_orders']}")
|
print(f"Number of Orders: {pos['num_orders']}")
|
||||||
|
|
||||||
add_to_existing = input("\nAdd to existing position? (y/n): ").lower() == 'y'
|
add_to_existing = get_user_input("Add to existing position? (y/n):", bool)
|
||||||
|
if add_to_existing is None:
|
||||||
|
continue
|
||||||
|
|
||||||
if add_to_existing:
|
if add_to_existing:
|
||||||
position_id = input("Enter Position ID: ")
|
position_id = get_user_input("Enter Position ID:", str)
|
||||||
|
if position_id is None:
|
||||||
|
continue
|
||||||
else:
|
else:
|
||||||
position_id = generate_position_id(ticker)
|
position_id = generate_position_id(ticker)
|
||||||
else:
|
else:
|
||||||
@ -306,10 +322,20 @@ def journal_menu():
|
|||||||
|
|
||||||
# Get entry date/time with market hours validation
|
# Get entry date/time with market hours validation
|
||||||
entry_date = get_datetime_input("Enter entry date and time", default=datetime.now())
|
entry_date = get_datetime_input("Enter entry date and time", default=datetime.now())
|
||||||
|
if entry_date is None:
|
||||||
|
continue
|
||||||
|
|
||||||
shares = int(input("Enter number of shares: "))
|
shares = get_user_input("Enter number of shares:", int)
|
||||||
entry_price = float(input("Enter entry price: "))
|
if shares is None:
|
||||||
|
continue
|
||||||
|
|
||||||
|
entry_price = get_user_input("Enter entry price:", float)
|
||||||
|
if entry_price is None:
|
||||||
|
continue
|
||||||
|
|
||||||
order_type = get_order_type()
|
order_type = get_order_type()
|
||||||
|
if order_type is None:
|
||||||
|
continue
|
||||||
|
|
||||||
# If adding to existing position, get target/stop from existing
|
# If adding to existing position, get target/stop from existing
|
||||||
if existing_positions and add_to_existing:
|
if existing_positions and add_to_existing:
|
||||||
|
|||||||
@ -5,12 +5,38 @@ from db.db_connection import create_client
|
|||||||
from screener.user_input import get_interval_choice, get_date_range
|
from screener.user_input import get_interval_choice, get_date_range
|
||||||
from trading.position_calculator import PositionCalculator
|
from trading.position_calculator import PositionCalculator
|
||||||
|
|
||||||
def get_float_input(prompt: str) -> float:
|
from typing import Optional
|
||||||
|
|
||||||
|
def get_user_input(prompt: str, input_type: type = str, allow_empty: bool = False) -> Optional[any]:
|
||||||
|
"""
|
||||||
|
Get user input with escape option
|
||||||
|
|
||||||
|
Args:
|
||||||
|
prompt (str): Input prompt to display
|
||||||
|
input_type (type): Expected input type (str, float, int)
|
||||||
|
allow_empty (bool): Whether to allow empty input
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Optional[any]: Converted input value or None if user wants to exit
|
||||||
|
"""
|
||||||
while True:
|
while True:
|
||||||
|
value = input(f"{prompt} (q to quit): ").strip()
|
||||||
|
|
||||||
|
if value.lower() in ['q', 'quit', 'exit']:
|
||||||
|
return None
|
||||||
|
|
||||||
|
if not value and allow_empty:
|
||||||
|
return None
|
||||||
|
|
||||||
try:
|
try:
|
||||||
return float(input(prompt))
|
if input_type == bool:
|
||||||
|
return value.lower() in ['y', 'yes', 'true', '1']
|
||||||
|
return input_type(value)
|
||||||
except ValueError:
|
except ValueError:
|
||||||
print("Please enter a valid number")
|
print(f"Please enter a valid {input_type.__name__}")
|
||||||
|
|
||||||
|
def get_float_input(prompt: str) -> Optional[float]:
|
||||||
|
return get_user_input(prompt, float)
|
||||||
|
|
||||||
def validate_signal_date(signal_date: datetime) -> datetime:
|
def validate_signal_date(signal_date: datetime) -> datetime:
|
||||||
"""
|
"""
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user