feat: Add timezone-aware datetime input with market hours validation
This commit is contained in:
parent
25a664e5bb
commit
0ce4bb4486
@ -1,6 +1,8 @@
|
|||||||
from datetime import datetime
|
from datetime import datetime, timedelta
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
import pytz
|
||||||
|
from zoneinfo import ZoneInfo
|
||||||
from db.db_connection import create_client
|
from db.db_connection import create_client
|
||||||
from trading.position_calculator import PositionCalculator
|
from trading.position_calculator import PositionCalculator
|
||||||
|
|
||||||
@ -36,6 +38,74 @@ class TradeEntry:
|
|||||||
return (self.exit_price - self.entry_price) * self.shares
|
return (self.exit_price - self.entry_price) * self.shares
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
def get_market_hours(date: datetime) -> tuple:
|
||||||
|
"""Get market open/close times in Eastern for given date"""
|
||||||
|
eastern = pytz.timezone('US/Eastern')
|
||||||
|
date_eastern = date.astimezone(eastern)
|
||||||
|
|
||||||
|
market_open = eastern.localize(
|
||||||
|
datetime.combine(date_eastern.date(), datetime.strptime("09:30", "%H:%M").time())
|
||||||
|
)
|
||||||
|
market_close = eastern.localize(
|
||||||
|
datetime.combine(date_eastern.date(), datetime.strptime("16:00", "%H:%M").time())
|
||||||
|
)
|
||||||
|
return market_open, market_close
|
||||||
|
|
||||||
|
def validate_market_time(dt: datetime) -> tuple[datetime, bool]:
|
||||||
|
"""
|
||||||
|
Validate if time is during market hours, adjust if needed
|
||||||
|
Returns: (adjusted_datetime, was_adjusted)
|
||||||
|
"""
|
||||||
|
pacific = pytz.timezone('US/Pacific')
|
||||||
|
eastern = pytz.timezone('US/Eastern')
|
||||||
|
|
||||||
|
# Ensure datetime is timezone-aware
|
||||||
|
if dt.tzinfo is None:
|
||||||
|
dt = pacific.localize(dt)
|
||||||
|
|
||||||
|
dt_eastern = dt.astimezone(eastern)
|
||||||
|
market_open, market_close = get_market_hours(dt_eastern)
|
||||||
|
|
||||||
|
if dt_eastern < market_open:
|
||||||
|
return market_open.astimezone(pacific), True
|
||||||
|
elif dt_eastern > market_close:
|
||||||
|
return market_close.astimezone(pacific), True
|
||||||
|
|
||||||
|
return dt, False
|
||||||
|
|
||||||
|
def get_datetime_input(prompt: str, default: datetime = None) -> datetime:
|
||||||
|
"""Get date and time input in Pacific time"""
|
||||||
|
pacific = pytz.timezone('US/Pacific')
|
||||||
|
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
if default:
|
||||||
|
print(f"Press Enter for current time ({default.strftime('%Y-%m-%d %H:%M')})")
|
||||||
|
date_str = input(f"{prompt} (YYYY-MM-DD HH:MM): ").strip()
|
||||||
|
|
||||||
|
if not date_str and default:
|
||||||
|
dt = default
|
||||||
|
else:
|
||||||
|
dt = datetime.strptime(date_str, "%Y-%m-%d %H:%M")
|
||||||
|
|
||||||
|
# Make datetime timezone-aware (Pacific)
|
||||||
|
dt = pacific.localize(dt)
|
||||||
|
|
||||||
|
# Validate market hours
|
||||||
|
adjusted_dt, was_adjusted = validate_market_time(dt)
|
||||||
|
if was_adjusted:
|
||||||
|
print(f"\nWarning: Time adjusted to market hours (Eastern)")
|
||||||
|
print(f"Original (Pacific): {dt.strftime('%Y-%m-%d %H:%M %Z')}")
|
||||||
|
print(f"Adjusted (Pacific): {adjusted_dt.strftime('%Y-%m-%d %H:%M %Z')}")
|
||||||
|
print(f"Adjusted (Eastern): {adjusted_dt.astimezone(pytz.timezone('US/Eastern')).strftime('%Y-%m-%d %H:%M %Z')}")
|
||||||
|
if input("Accept adjusted time? (y/n): ").lower() != 'y':
|
||||||
|
continue
|
||||||
|
|
||||||
|
return adjusted_dt
|
||||||
|
|
||||||
|
except ValueError:
|
||||||
|
print("Invalid format. Please use YYYY-MM-D HH:MM")
|
||||||
|
|
||||||
def create_trades_table():
|
def create_trades_table():
|
||||||
with create_client() as client:
|
with create_client() as client:
|
||||||
query = """
|
query = """
|
||||||
@ -208,6 +278,9 @@ def journal_menu():
|
|||||||
else:
|
else:
|
||||||
position_id = generate_position_id(ticker)
|
position_id = generate_position_id(ticker)
|
||||||
|
|
||||||
|
# Get entry date/time with market hours validation
|
||||||
|
entry_date = get_datetime_input("Enter entry date and time", default=datetime.now())
|
||||||
|
|
||||||
shares = int(input("Enter number of shares: "))
|
shares = int(input("Enter number of shares: "))
|
||||||
entry_price = float(input("Enter entry price: "))
|
entry_price = float(input("Enter entry price: "))
|
||||||
order_type = get_order_type()
|
order_type = get_order_type()
|
||||||
@ -270,6 +343,10 @@ def journal_menu():
|
|||||||
|
|
||||||
trade_id = int(input("\nEnter trade ID to update: "))
|
trade_id = int(input("\nEnter trade ID to update: "))
|
||||||
exit_price = float(input("Enter exit price: "))
|
exit_price = float(input("Enter exit price: "))
|
||||||
|
|
||||||
|
# Get exit date/time with market hours validation
|
||||||
|
exit_date = get_datetime_input("Enter exit date and time", default=datetime.now())
|
||||||
|
|
||||||
followed_rules = input("Did you follow your rules? (y/n): ").lower() == 'y'
|
followed_rules = input("Did you follow your rules? (y/n): ").lower() == 'y'
|
||||||
exit_reason = input("Enter exit reason: ")
|
exit_reason = input("Enter exit reason: ")
|
||||||
notes = input("Additional notes (optional): ") or None
|
notes = input("Additional notes (optional): ") or None
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user