diff --git a/models/trading.py b/models/trading.py index 23a6524a..a7ee0d7d 100644 --- a/models/trading.py +++ b/models/trading.py @@ -190,6 +190,7 @@ class PortfolioStateFilterRequest(BaseModel): account_names: Optional[List[str]] = Field(default=None, description="List of account names to filter by") connector_names: Optional[List[str]] = Field(default=None, description="List of connector names to filter by") skip_gateway: bool = Field(default=False, description="Skip Gateway wallet balance updates for faster CEX-only queries") + refresh: bool = Field(default=False, description="If True, refresh balances before returning. If False, return cached state") class PortfolioHistoryFilterRequest(TimeRangePaginationParams): diff --git a/routers/portfolio.py b/routers/portfolio.py index 82b9fad2..65bec864 100644 --- a/routers/portfolio.py +++ b/routers/portfolio.py @@ -29,13 +29,21 @@ async def get_portfolio_state( - account_names: Optional list of account names to filter by - connector_names: Optional list of connector names to filter by - skip_gateway: If True, skip Gateway wallet balance updates for faster CEX-only queries + - refresh: If True, refresh balances before returning. If False (default), return cached state Returns: Dict containing account states with connector balances and token information """ - await accounts_service.update_account_state(skip_gateway=filter_request.skip_gateway) + # Only refresh balances if explicitly requested + if filter_request.refresh: + await accounts_service.update_account_state( + skip_gateway=filter_request.skip_gateway, + account_names=filter_request.account_names, + connector_names=filter_request.connector_names + ) + all_states = accounts_service.get_accounts_state() - + # Apply account name filter first if filter_request.account_names: filtered_states = {} @@ -43,7 +51,7 @@ async def get_portfolio_state( if account_name in all_states: filtered_states[account_name] = all_states[account_name] all_states = filtered_states - + # Apply connector filter if specified if filter_request.connector_names: for account_name, account_data in all_states.items(): @@ -54,7 +62,7 @@ async def get_portfolio_state( filtered_connectors[connector_name] = account_data[connector_name] # Replace account_data with only filtered connectors all_states[account_name] = filtered_connectors - + return all_states diff --git a/services/accounts_service.py b/services/accounts_service.py index 2cb5773b..e12a98ee 100644 --- a/services/accounts_service.py +++ b/services/accounts_service.py @@ -1,6 +1,5 @@ import asyncio import logging -import time from datetime import datetime, timezone from decimal import Decimal from typing import Dict, List, Optional @@ -40,9 +39,8 @@ class AccountsService: } potential_wrapped_tokens = ["ETH", "SOL", "BNB", "POL", "AVAX", "FTM", "ONE", "GLMR", "MOVR"] - # Cache for storing last successful prices by trading pair with timestamps + # Cache for storing last successful prices by trading pair _last_known_prices = {} - _price_update_interval = 60 # Update prices every 60 seconds def __init__(self, account_update_interval: int = 5, @@ -299,11 +297,19 @@ async def _initialize_price_tracking(self, account_name: str, connector_name: st except Exception as e: logger.error(f"Error initializing price tracking for {connector_name} in account {account_name}: {e}") - async def update_account_state(self, skip_gateway: bool = False): - """Update account state for all connectors and optionally Gateway wallets. + async def update_account_state( + self, + skip_gateway: bool = False, + account_names: Optional[List[str]] = None, + connector_names: Optional[List[str]] = None + ): + """Update account state for filtered connectors and optionally Gateway wallets. Args: skip_gateway: If True, skip Gateway wallet balance updates for faster CEX-only queries. + account_names: If provided, only update these accounts. If None, update all accounts. + connector_names: If provided, only update these connectors. If None, update all connectors. + For Gateway, this filters by chain-network (e.g., 'solana-mainnet-beta'). """ all_connectors = self.connector_manager.get_all_connectors() @@ -312,9 +318,17 @@ async def update_account_state(self, skip_gateway: bool = False): task_meta = [] # (account_name, connector_name) for account_name, connectors in all_connectors.items(): + # Filter by account_names if specified + if account_names and account_name not in account_names: + continue + if account_name not in self.accounts_state: self.accounts_state[account_name] = {} for connector_name, connector in connectors.items(): + # Filter by connector_names if specified + if connector_names and connector_name not in connector_names: + continue + tasks.append(self._get_connector_tokens_info(connector, connector_name)) task_meta.append((account_name, connector_name)) @@ -322,7 +336,12 @@ async def update_account_state(self, skip_gateway: bool = False): if skip_gateway: results = await asyncio.gather(*tasks, return_exceptions=True) else: - results = await asyncio.gather(*tasks, self._update_gateway_balances(), return_exceptions=True) + # Pass connector_names filter to gateway for chain-network filtering + results = await asyncio.gather( + *tasks, + self._update_gateway_balances(chain_networks=connector_names), + return_exceptions=True + ) # Remove gateway result from processing (it handles its own state internally) results = results[:-1] @@ -1005,7 +1024,25 @@ async def get_connector_instance(self, account_name: str, connector_name: str): raise HTTPException(status_code=404, detail=f"Connector '{connector_name}' not found for account '{account_name}'") return await self.connector_manager.get_connector(account_name, connector_name) - + + async def _get_perpetual_connector(self, account_name: str, connector_name: str): + """ + Get a perpetual connector instance with validation. + + Args: + account_name: Name of the account + connector_name: Name of the connector (must be perpetual) + + Returns: + Perpetual connector instance + + Raises: + HTTPException: If connector is not perpetual or not found + """ + if "_perpetual" not in connector_name: + raise HTTPException(status_code=400, detail=f"Connector '{connector_name}' is not a perpetual connector") + return await self.get_connector_instance(account_name, connector_name) + async def get_active_orders(self, account_name: str, connector_name: str) -> Dict[str, any]: """ Get active orders for a specific connector. @@ -1049,30 +1086,25 @@ async def cancel_order(self, account_name: str, connector_name: str, client_orde logger.error(f"Failed to initiate cancellation for order {client_order_id}: {e}") raise HTTPException(status_code=500, detail=f"Failed to initiate order cancellation: {str(e)}") - async def set_leverage(self, account_name: str, connector_name: str, + async def set_leverage(self, account_name: str, connector_name: str, trading_pair: str, leverage: int) -> Dict[str, str]: """ Set leverage for a specific trading pair on a perpetual connector. - + Args: account_name: Name of the account connector_name: Name of the connector (must be perpetual) trading_pair: Trading pair to set leverage for leverage: Leverage value (typically 1-125) - + Returns: Dictionary with success status and message - + Raises: HTTPException: If account/connector not found, not perpetual, or operation fails """ - # Validate this is a perpetual connector - if "_perpetual" not in connector_name: - raise HTTPException(status_code=400, detail=f"Connector '{connector_name}' is not a perpetual connector") - - connector = await self.get_connector_instance(account_name, connector_name) - - # Check if connector has leverage functionality + connector = await self._get_perpetual_connector(account_name, connector_name) + if not hasattr(connector, '_execute_set_leverage'): raise HTTPException(status_code=400, detail=f"Connector '{connector_name}' does not support leverage setting") @@ -1086,28 +1118,24 @@ async def set_leverage(self, account_name: str, connector_name: str, logger.error(f"Failed to set leverage for {trading_pair} to {leverage}: {e}") raise HTTPException(status_code=500, detail=f"Failed to set leverage: {str(e)}") - async def set_position_mode(self, account_name: str, connector_name: str, + async def set_position_mode(self, account_name: str, connector_name: str, position_mode: PositionMode) -> Dict[str, str]: """ Set position mode for a perpetual connector. - + Args: account_name: Name of the account connector_name: Name of the connector (must be perpetual) position_mode: PositionMode.HEDGE or PositionMode.ONEWAY - + Returns: Dictionary with success status and message - + Raises: HTTPException: If account/connector not found, not perpetual, or operation fails """ - # Validate this is a perpetual connector - if "_perpetual" not in connector_name: - raise HTTPException(status_code=400, detail=f"Connector '{connector_name}' is not a perpetual connector") - - connector = await self.get_connector_instance(account_name, connector_name) - + connector = await self._get_perpetual_connector(account_name, connector_name) + # Check if the requested position mode is supported supported_modes = connector.supported_position_modes() if position_mode not in supported_modes: @@ -1135,24 +1163,19 @@ async def set_position_mode(self, account_name: str, connector_name: str, async def get_position_mode(self, account_name: str, connector_name: str) -> Dict[str, str]: """ Get current position mode for a perpetual connector. - + Args: account_name: Name of the account connector_name: Name of the connector (must be perpetual) - + Returns: Dictionary with current position mode - + Raises: HTTPException: If account/connector not found, not perpetual, or operation fails """ - # Validate this is a perpetual connector - if "_perpetual" not in connector_name: - raise HTTPException(status_code=400, detail=f"Connector '{connector_name}' is not a perpetual connector") - - connector = await self.get_connector_instance(account_name, connector_name) - - # Check if connector has position mode functionality + connector = await self._get_perpetual_connector(account_name, connector_name) + if not hasattr(connector, 'position_mode'): raise HTTPException(status_code=400, detail=f"Connector '{connector_name}' does not support position mode") @@ -1263,24 +1286,19 @@ async def get_trades(self, account_name: Optional[str] = None, connector_name: O async def get_account_positions(self, account_name: str, connector_name: str) -> List[Dict]: """ Get current positions for a specific perpetual connector. - + Args: account_name: Name of the account connector_name: Name of the connector (must be perpetual) - + Returns: List of position dictionaries - + Raises: HTTPException: If account/connector not found or not perpetual """ - # Validate this is a perpetual connector - if "_perpetual" not in connector_name: - raise HTTPException(status_code=400, detail=f"Connector '{connector_name}' is not a perpetual connector") - - connector = await self.get_connector_instance(account_name, connector_name) - - # Check if connector has account_positions property + connector = await self._get_perpetual_connector(account_name, connector_name) + if not hasattr(connector, 'account_positions'): raise HTTPException(status_code=400, detail=f"Connector '{connector_name}' does not support position tracking") @@ -1382,8 +1400,14 @@ async def get_total_funding_fees(self, account_name: str, connector_name: str, # Gateway Wallet Management Methods # ============================================ - async def _update_gateway_balances(self): - """Update Gateway wallet balances in master_account state.""" + async def _update_gateway_balances(self, chain_networks: Optional[List[str]] = None): + """Update Gateway wallet balances in master_account state. + + Args: + chain_networks: If provided, only update these chain-network combinations + (e.g., ['solana-mainnet-beta', 'ethereum-mainnet']). + If None, update all available chain-networks. + """ try: # Check if Gateway is available if not await self.gateway_client.ping(): @@ -1443,20 +1467,25 @@ async def _update_gateway_balances(self): # Create tasks for all networks for this wallet for network in networks: + chain_network_key = f"{chain}-{network}" + + # Filter by chain_networks if specified + if chain_networks: + if chain_network_key not in chain_networks: + continue + balance_tasks.append(self.get_gateway_balances(chain, address, network=network)) task_metadata.append((chain, network, address)) # Execute all balance queries in parallel if balance_tasks: - t_zero = time.time() results = await asyncio.gather(*balance_tasks, return_exceptions=True) - duration = time.time() - t_zero # Build set of active chain-network keys from current wallets active_chain_networks = {f"{chain}-{network}" for chain, network, _ in task_metadata} # Process results - for idx, (result, (chain, network, address)) in enumerate(zip(results, task_metadata)): + for result, (chain, network, address) in zip(results, task_metadata): chain_network = f"{chain}-{network}" if isinstance(result, Exception): @@ -1470,20 +1499,23 @@ async def _update_gateway_balances(self): # Store empty list to indicate we checked this network self.accounts_state["master_account"][chain_network] = [] - # Remove stale gateway chain-network keys (wallets that were deleted) - # Gateway keys follow pattern: chain-network (e.g., "solana-mainnet-beta", "ethereum-mainnet") - stale_keys = [] - for key in self.accounts_state["master_account"]: - # Check if key looks like a gateway chain-network (contains hyphen and matches chain pattern) - if "-" in key and key not in active_chain_networks: - # Verify it's a gateway key by checking if chain part matches known chains - chain_part = key.split("-")[0] - if chain_part in chain_networks_map: - stale_keys.append(key) - - for key in stale_keys: - logger.info(f"Removing stale Gateway balance data for {key} (wallet no longer exists)") - del self.accounts_state["master_account"][key] + # Only remove stale keys if we're doing a full update (no filter) + # When filtering, we don't want to remove keys that weren't in the filter + if not chain_networks: + # Remove stale gateway chain-network keys (wallets that were deleted) + # Gateway keys follow pattern: chain-network (e.g., "solana-mainnet-beta", "ethereum-mainnet") + stale_keys = [] + for key in self.accounts_state["master_account"]: + # Check if key looks like a gateway chain-network (contains hyphen and matches chain pattern) + if "-" in key and key not in active_chain_networks: + # Verify it's a gateway key by checking if chain part matches known chains + chain_part = key.split("-")[0] + if chain_part in chain_networks_map: + stale_keys.append(key) + + for key in stale_keys: + logger.info(f"Removing stale Gateway balance data for {key} (wallet no longer exists)") + del self.accounts_state["master_account"][key] except Exception as e: logger.error(f"Error updating Gateway balances: {e}") @@ -1605,13 +1637,11 @@ async def get_gateway_balances(self, chain: str, address: str, network: Optional # Get prices using rate sources (similar to _get_connector_tokens_info) unique_tokens = [b["token"] for b in balances_list] - connector_name = f"gateway_{chain}-{network}" # Try to get cached prices first # Try USDT first (more common in CEX like Binance), then USDC (common in DEX) prices_from_cache = {} tokens_need_update = [] - token_to_trading_pair = {} # Maps token -> trading_pair that has the price if self.market_data_feed_manager: for token in unique_tokens: @@ -1625,7 +1655,6 @@ async def get_gateway_balances(self, chain: str, address: str, network: Optional cached_price = self.market_data_feed_manager.market_data_provider.get_rate(trading_pair) if cached_price > 0: prices_from_cache[token] = cached_price - token_to_trading_pair[token] = trading_pair found_price = True break except Exception: @@ -1660,7 +1689,6 @@ async def get_gateway_balances(self, chain: str, address: str, network: Optional for token, price in fetched_prices.items(): if price > 0: prices_from_cache[token] = price - token_to_trading_pair[token] = f"{token}-USDC" except Exception as e: logger.warning(f"Error fetching immediate gateway prices: {e}") @@ -1738,7 +1766,7 @@ async def _fetch_gateway_prices_immediate(self, chain: str, network: str, base_asset=token, quote_asset="USDC", amount=Decimal("1"), - side=TradeType.BUY + side=TradeType.SELL ) tasks.append(task) task_tokens.append(token) @@ -1765,8 +1793,7 @@ async def _fetch_gateway_prices_immediate(self, chain: str, network: str, return prices def get_unwrapped_token(self, token: str) -> str: - """Get the unwrapped version of a wrapped token symbol.""" - for pw in self.potential_wrapped_tokens: - if token in pw: - return pw + """Get the unwrapped version of a wrapped token symbol (e.g., WSOL -> SOL).""" + if token.startswith("W") and token[1:] in self.potential_wrapped_tokens: + return token[1:] return token \ No newline at end of file