diff --git a/.gitignore b/.gitignore index 7c67755..38efd11 100644 --- a/.gitignore +++ b/.gitignore @@ -161,3 +161,4 @@ cython_debug/ # Telegram Bot Persistence condor_bot_data.pickle condor_bot_data.pickle.bak +.DS_Store diff --git a/Makefile b/Makefile index 33833a4..525b74c 100644 --- a/Makefile +++ b/Makefile @@ -1,15 +1,21 @@ -.PHONY: uninstall install run deploy setup stop +.PHONY: help setup install uninstall run deploy stop test lint setup-chrome -# Check if conda is available -ifeq (, $(shell which conda)) - $(error "Conda is not found in PATH. Please install Conda or add it to your PATH.") -endif +help: + @echo "Condor Bot - Available Commands" + @echo "" + @echo " make setup - Interactive setup (creates .env file)" + @echo " make install - Setup + create conda environment" + @echo " make run - Run the bot locally" + @echo " make deploy - Run with Docker Compose" + @echo " make stop - Stop Docker containers" + @echo " make test - Run tests" + @echo " make lint - Run black + isort" + @echo " make uninstall - Remove conda environment" -uninstall: - conda env remove -n condor -y - -stop: - docker compose down +# Interactive setup (creates .env file) +setup: + chmod +x setup-environment.sh + ./setup-environment.sh # Install conda environment install: @@ -19,16 +25,35 @@ install: else \ conda env create -f environment.yml; \ fi + $(MAKE) setup-chrome -# Docker setup -setup: - chmod +x setup-environment.sh - ./setup-environment.sh +# Install Chrome for Kaleido (must run after conda env is created) +setup-chrome: + @echo "Installing Chrome for Plotly image generation..." + @conda run -n condor python -c "import kaleido; kaleido.get_chrome_sync()" 2>/dev/null || \ + echo "Chrome installation skipped (not required for basic usage)" # Run locally (dev mode) -run: +run: conda run --no-capture-output -n condor python main.py # Deploy with Docker deploy: docker compose up -d + +# Stop Docker containers +stop: + docker compose down + +# Run tests +test: + conda run -n condor pytest + +# Lint and format code +lint: + conda run -n condor black . + conda run -n condor isort . + +# Remove conda environment +uninstall: + conda env remove -n condor -y diff --git a/README.md b/README.md index 7fa6343..a664dd0 100644 --- a/README.md +++ b/README.md @@ -13,25 +13,19 @@ A Telegram bot for monitoring and trading with Hummingbot via the Backend API. ## Quick Start -**Prerequisites:** Python 3.11+, Conda, Hummingbot Backend API running, Telegram Bot Token +**Prerequisites:** Python 3.12+, Conda, Hummingbot Backend API running, Telegram Bot Token ```bash -# clone repo git clone https://github.com/hummingbot/condor.git cd condor -# environment setup -conda env create -f environment.yml -conda activate condor - -# Configure -cp .env.example .env -# Edit .env with your credentials: -# - TELEGRAM_TOKEN -# - TELEGRAM_ALLOWED_IDS -# - OPENAI_API_KEY (optional, for AI features) - -# Run -python main.py + +# Option 1: Local Python +make install # Interactive setup + conda environment +make run # Start the bot + +# Option 2: Docker +make setup # Interactive configuration +make deploy # Start with Docker Compose ``` ## Commands @@ -40,73 +34,55 @@ python main.py |---------|-------------| | `/portfolio` | Portfolio dashboard with PNL indicators, holdings, and graphs | | `/bots` | All active bots with status and metrics | -| `/bots ` | Specific bot details | -| `/clob_trading` | CLOB trading menu (place orders, manage positions) | -| `/dex_trading` | DEX trading menu (swaps, pools, liquidity positions) | -| `/config` | Configuration menu (servers, API keys) | -| `/trade ` | AI assistant (disabled, coming soon) | +| `/trade` | CEX trading menu (spot & perpetual orders, positions) | +| `/swap` | DEX swap trading (quotes, execution, history) | +| `/lp` | DEX liquidity pool management (positions, pools) | +| `/routines` | Auto-discoverable Python scripts with scheduling | +| `/config` | Configuration menu (servers, API keys, Gateway, admin) | ## Architecture ``` Telegram → Condor Bot → Hummingbot Backend API → Trading Bots ↘ Gateway → DEX Protocols - ↘ GPT-4o → MCP (Docker) → Hummingbot API (future) ``` -### Direct API Commands -- `/portfolio`, `/bots`, `/clob_trading`, `/dex_trading` use direct API calls via `hummingbot_api_client` -- Fast, reliable, interactive button menus - -### AI Assistant (Future) -- `/trade` uses GPT-4o + MCP server with access to all Hummingbot tools -- Natural language interface for market data, portfolio queries, and more +All commands use direct API calls via `hummingbot_api_client` with interactive button menus. ## Project Structure ``` condor/ ├── handlers/ # Telegram command handlers -│ ├── bots.py # /bots command │ ├── portfolio.py # /portfolio command with dashboard -│ ├── trade_ai.py # /trade AI assistant (disabled) -│ ├── clob/ # CLOB trading module -│ │ ├── __init__.py # Main command, callback router -│ │ ├── menu.py # Trading menu with overview -│ │ ├── place_order.py # Order placement flow -│ │ ├── leverage.py # Leverage/position mode config -│ │ ├── orders.py # Order search/cancel -│ │ ├── positions.py # Position management -│ │ └── account.py # Account switching -│ ├── dex/ # DEX trading module +│ ├── bots/ # Bot monitoring module +│ │ ├── __init__.py # /bots command +│ │ ├── menu.py # Bot status display +│ │ └── controllers/ # Bot controller configs +│ ├── cex/ # CEX trading module (/trade) │ │ ├── __init__.py # Main command, callback router -│ │ ├── menu.py # Trading menu with balances -│ │ ├── swap_quote.py # Get swap quotes -│ │ ├── swap_execute.py # Execute swaps -│ │ ├── swap_history.py # Swap history/status -│ │ └── pools.py # Pool discovery & LP management -│ └── config/ # Configuration module -│ ├── __init__.py # /config command -│ ├── servers.py # API server management -│ ├── api_keys.py # Exchange credentials -│ └── user_preferences.py # User preference storage +│ │ ├── trade.py # Order placement +│ │ ├── orders.py # Order management +│ │ └── positions.py # Position tracking +│ ├── dex/ # DEX trading module (/swap, /lp) +│ │ ├── __init__.py # Main commands, callback router +│ │ ├── swap.py # Quote, execute, history +│ │ ├── liquidity.py # LP positions management +│ │ └── pools.py # Pool info and discovery +│ ├── config/ # Configuration module (/config) +│ │ ├── __init__.py # Main command +│ │ ├── servers.py # API server management +│ │ ├── api_keys.py # Exchange credentials +│ │ └── gateway/ # Gateway configuration +│ ├── routines/ # Routines module (/routines) +│ │ └── __init__.py # Script discovery and execution +│ └── admin/ # Admin panel (via /config) +├── routines/ # User-defined automation scripts ├── utils/ # Utilities -│ ├── auth.py # @restricted decorator -│ ├── telegram_formatters.py # Message formatting -│ ├── portfolio_graphs.py # Dashboard chart generation -│ └── trading_data.py # Data aggregation helpers -├── servers/ # Server management -│ ├── server_manager.py # Server CRUD & client pool -│ └── servers.yml # Server configuration +│ ├── auth.py # @restricted, @admin_required decorators +│ └── telegram_formatters.py # Message formatting +├── config_manager.py # Unified config (servers, users, permissions) ├── hummingbot_api_client/ # API client library -├── hummingbot_mcp/ # MCP AI tools (Docker) -├── flows/ # Documentation -│ ├── bots_flow.txt # /bots command flow -│ ├── portfolio_flow.txt # /portfolio command flow -│ ├── clob_trading_flow.txt # CLOB trading flow -│ ├── dex_trading_flow.txt # DEX trading flow -│ ├── config_flow.txt # Configuration flow -│ └── common_patterns.txt # Shared patterns └── main.py # Entry point ``` @@ -121,7 +97,7 @@ condor/ - **Dashboard** - Combined chart with value history, token distribution, account breakdown - **Settings** - Configure time period (1d, 3d, 7d, 14d, 30d) -### CLOB Trading (`/clob_trading`) +### CEX Trading (`/trade`) - **Overview** - Account balances, positions, orders at a glance - **Place Orders** - Interactive menu with dual input (buttons + direct text) - Toggle: side, order type, position mode @@ -131,15 +107,24 @@ condor/ - **Search Orders** - View/filter/cancel orders - **Manage Positions** - View, trade, close positions with confirmation -### DEX Trading (`/dex_trading`) +### DEX Swaps (`/swap`) - **Gateway Balances** - Token balances across DEX wallets - **Swap Quote** - Get quotes before executing - **Execute Swap** - Perform swaps with slippage control - **Quick Swap** - Repeat last swap with minimal input + +### Liquidity Pools (`/lp`) - **Pool Discovery** - Search pools by connector and token - **Pool Info** - Detailed pool stats with liquidity charts - **LP Positions** - Manage CLMM positions (add, close, collect fees) +### Routines (`/routines`) +- **Auto-Discovery** - Python scripts auto-discovered from `routines/` folder +- **Pydantic Config** - Type-safe configuration with descriptions +- **One-shot Scripts** - Run once, optionally schedule (interval or daily) +- **Continuous Scripts** - Long-running tasks with start/stop control +- **Multi-instance** - Run multiple instances with different configs + ### Configuration (`/config`) - **API Servers** - Add, modify, delete Hummingbot Backend API servers - Real-time status checking (online/offline/auth error) @@ -155,15 +140,15 @@ condor/ Preferences are automatically saved and persist across sessions: - **Portfolio** - Graph time period (days, interval) -- **CLOB** - Active account, last order parameters +- **CEX** - Active account, last order parameters - **DEX** - Default network/connector, last swap parameters - **General** - Active server ## Security -- **User ID Whitelist** - Only `TELEGRAM_ALLOWED_IDS` can access +- **Admin Whitelist** - Only `ADMIN_USER_ID` has initial access +- **Role-Based Access** - Admin, User, Pending, Blocked roles - **@restricted Decorator** - Applied to all command handlers -- **Environment Credentials** - API keys stored in `.env` - **Secret Masking** - Passwords hidden in UI ## Configuration Files @@ -171,32 +156,50 @@ Preferences are automatically saved and persist across sessions: ### `.env` ```bash TELEGRAM_TOKEN=your_bot_token -TELEGRAM_ALLOWED_IDS=123456789,987654321 +ADMIN_USER_ID=123456789 OPENAI_API_KEY=sk-... # Optional, for AI features ``` -### `servers.yml` +### `config.yml` (auto-created on first run) ```yaml -default_server: main servers: main: host: localhost port: 8000 username: admin password: admin - enabled: true +default_server: main +admin_id: 123456789 +users: {} +server_access: {} +chat_defaults: {} +audit_log: [] ``` ## Troubleshooting | Issue | Solution | |-------|----------| -| Bot not responding | Check `TELEGRAM_TOKEN` and `TELEGRAM_ALLOWED_IDS` | +| Bot not responding | Check `TELEGRAM_TOKEN` and `ADMIN_USER_ID` in `.env` | +| Access pending | Admin must approve user via /config > Admin Panel | | Commands failing | Verify Hummingbot API is running | | Connection refused | Check server host:port in `/config` | | Auth error | Verify server credentials | | DEX features unavailable | Ensure Gateway is configured and running | +## Docker Deployment + +```bash +# Setup and run with Docker +make setup # Interactive configuration +docker compose up -d +``` + +Volumes mounted: +- `condor_bot_data.pickle` - User preferences and state +- `config.yml` - Server and permission configuration +- `routines/` - Custom automation scripts + ## Development ### Flow Documentation diff --git a/config_manager.py b/config_manager.py new file mode 100644 index 0000000..ff16184 --- /dev/null +++ b/config_manager.py @@ -0,0 +1,805 @@ +""" +Unified Configuration Manager for Condor Bot. +Manages servers, users, permissions, and settings in a single config.yml file. +""" + +import logging +import time +from enum import Enum +from pathlib import Path +from typing import Optional, Dict, Tuple, Any + +import yaml +from aiohttp import ClientTimeout + +logger = logging.getLogger(__name__) + + +class UserRole(str, Enum): + """User roles in the system""" + ADMIN = "admin" + USER = "user" + PENDING = "pending" + BLOCKED = "blocked" + + +class ServerPermission(str, Enum): + """Permission levels for server access""" + OWNER = "owner" + TRADER = "trader" + VIEWER = "viewer" + + +PERMISSION_HIERARCHY = { + ServerPermission.VIEWER: 0, + ServerPermission.TRADER: 1, + ServerPermission.OWNER: 2, +} + + +class ConfigManager: + """ + Unified configuration manager for Condor Bot. + Handles servers, users, permissions, and chat defaults in a single YAML file. + Uses singleton pattern - access via ConfigManager.instance() + """ + + VERSION = 1 + MAX_AUDIT_LOG_ENTRIES = 500 + + _instance: Optional['ConfigManager'] = None + + def __init__(self, config_path: str = "config.yml"): + self.config_path = Path(config_path) + self.audit_log_path = Path("audit_log.yml") + self._data: dict = {} + self._audit_log: list = [] + self._clients: Dict[str, Tuple[Any, float]] = {} # server_name -> (client, connect_time) + self._client_ttl = 300 # 5 minutes + self._load_config() + self._load_audit_log() + + @classmethod + def instance(cls, config_path: str = "config.yml") -> 'ConfigManager': + """Get the singleton instance.""" + if cls._instance is None: + cls._instance = cls(config_path) + return cls._instance + + @classmethod + def reset_instance(cls) -> None: + """Reset the singleton (for testing).""" + cls._instance = None + + def _get_admin_from_env(self) -> Optional[int]: + """Get admin user ID from environment.""" + from utils.config import ADMIN_USER_ID + return ADMIN_USER_ID + + def _load_config(self): + """Load configuration from YAML file.""" + if not self.config_path.exists(): + self._init_default_config() + return + + try: + with open(self.config_path, 'r') as f: + self._data = yaml.safe_load(f) or {} + + # Ensure all sections exist + self._data.setdefault('servers', {}) + self._data.setdefault('default_server', None) + self._data.setdefault('users', {}) + self._data.setdefault('server_access', {}) + self._data.setdefault('chat_defaults', {}) + # Migrate audit_log from config.yml to separate file (one-time) + if 'audit_log' in self._data: + self._audit_log = self._data.pop('audit_log') + self._save_audit_log() + self._save_config() # Save config without audit_log + + # Always trust admin_id from env + admin_id = self._get_admin_from_env() + if admin_id: + self._data['admin_id'] = admin_id + self._ensure_admin_user(admin_id) + + logger.info(f"Loaded config from {self.config_path}") + except Exception as e: + logger.error(f"Failed to load config: {e}") + self._init_default_config() + + def _init_default_config(self): + """Initialize with default configuration.""" + admin_id = self._get_admin_from_env() + self._data = { + 'servers': {}, + 'default_server': None, + 'admin_id': admin_id, + 'users': {}, + 'server_access': {}, + 'chat_defaults': {}, + 'version': self.VERSION + } + self._audit_log = [] + if admin_id: + self._ensure_admin_user(admin_id) + self._save_config() + logger.info(f"Created new config at {self.config_path}") + + def _ensure_admin_user(self, admin_id: int): + """Ensure admin user exists in users dict.""" + if admin_id not in self._data['users']: + self._data['users'][admin_id] = { + 'user_id': admin_id, + 'role': UserRole.ADMIN.value, + 'created_at': time.time(), + 'notes': 'Primary admin from ADMIN_USER_ID' + } + self._save_config() + + def _save_config(self): + """Save configuration to YAML file.""" + try: + data = { + 'servers': self._data.get('servers', {}), + 'default_server': self._data.get('default_server'), + 'admin_id': self._data.get('admin_id'), + 'users': self._data.get('users', {}), + 'server_access': self._data.get('server_access', {}), + 'chat_defaults': self._data.get('chat_defaults', {}), + 'version': self._data.get('version', self.VERSION) + } + with open(self.config_path, 'w') as f: + yaml.dump(data, f, default_flow_style=False, sort_keys=False, allow_unicode=True) + logger.debug(f"Saved config to {self.config_path}") + except Exception as e: + logger.error(f"Failed to save config: {e}") + raise + + def _load_audit_log(self): + """Load audit log from separate file.""" + if not self.audit_log_path.exists(): + self._audit_log = [] + return + + try: + with open(self.audit_log_path, 'r') as f: + data = yaml.safe_load(f) or {} + self._audit_log = data.get('entries', []) + logger.debug(f"Loaded {len(self._audit_log)} audit log entries") + except Exception as e: + logger.error(f"Failed to load audit log: {e}") + self._audit_log = [] + + def _save_audit_log(self): + """Save audit log to separate file.""" + try: + # Trim to max entries + if len(self._audit_log) > self.MAX_AUDIT_LOG_ENTRIES: + self._audit_log = self._audit_log[-self.MAX_AUDIT_LOG_ENTRIES:] + + data = {'entries': self._audit_log} + with open(self.audit_log_path, 'w') as f: + yaml.dump(data, f, default_flow_style=False, sort_keys=False, allow_unicode=True) + logger.debug(f"Saved {len(self._audit_log)} audit log entries") + except Exception as e: + logger.error(f"Failed to save audit log: {e}") + + def reload(self): + """Reload configuration from file.""" + self._load_config() + self._load_audit_log() + + @property + def admin_id(self) -> Optional[int]: + return self._data.get('admin_id') + + # ========================================================================= + # SERVER MANAGEMENT + # ========================================================================= + + def list_servers(self) -> Dict[str, dict]: + """List all configured servers.""" + return self._data.get('servers', {}).copy() + + def get_server(self, name: str) -> Optional[dict]: + """Get a specific server configuration.""" + return self._data.get('servers', {}).get(name) + + def add_server(self, name: str, host: str, port: int, username: str, + password: str, owner_id: int = None) -> bool: + """Add a new server.""" + servers = self._data['servers'] + if name in servers: + logger.error(f"Server '{name}' already exists") + return False + + servers[name] = { + 'host': host, + 'port': port, + 'username': username, + 'password': password + } + + # Register ownership + if owner_id: + self.register_server_owner(name, owner_id) + + self._save_config() + logger.info(f"Added server '{name}'") + return True + + def modify_server(self, name: str, host: str = None, port: int = None, + username: str = None, password: str = None) -> bool: + """Modify an existing server.""" + servers = self._data['servers'] + if name not in servers: + logger.error(f"Server '{name}' not found") + return False + + # Clear cached client + if name in self._clients: + del self._clients[name] + + if host is not None: + servers[name]['host'] = host + if port is not None: + servers[name]['port'] = port + if username is not None: + servers[name]['username'] = username + if password is not None: + servers[name]['password'] = password + + self._save_config() + logger.info(f"Modified server '{name}'") + return True + + def delete_server(self, name: str, actor_id: int = None) -> bool: + """Delete a server.""" + servers = self._data['servers'] + if name not in servers: + logger.error(f"Server '{name}' not found") + return False + + # Clear cached client + if name in self._clients: + del self._clients[name] + + del servers[name] + + # Unregister from access control + if name in self._data['server_access']: + del self._data['server_access'][name] + + self._save_config() + logger.info(f"Deleted server '{name}'") + return True + + def get_default_server(self) -> Optional[str]: + """Get the default server name.""" + return self._data.get('default_server') + + def set_default_server(self, name: str) -> bool: + """Set the default server.""" + if name not in self._data['servers']: + logger.error(f"Server '{name}' not found") + return False + + self._data['default_server'] = name + self._save_config() + logger.info(f"Set default server to '{name}'") + return True + + async def get_client(self, name: str = None): + """Get or create API client for a server.""" + from hummingbot_api_client import HummingbotAPIClient + + if name is None: + name = self.get_default_server() + if not name: + if self._data['servers']: + name = list(self._data['servers'].keys())[0] + else: + raise ValueError("No servers configured") + + if name not in self._data['servers']: + raise ValueError(f"Server '{name}' not found") + + # Return cached client if fresh + if name in self._clients: + client, connect_time = self._clients[name] + if time.time() - connect_time < self._client_ttl: + self._clients[name] = (client, time.time()) # Refresh + return client + else: + try: + await client.close() + except: + pass + del self._clients[name] + + # Create new client + server = self._data['servers'][name] + base_url = f"http://{server['host']}:{server['port']}" + client = HummingbotAPIClient( + base_url=base_url, + username=server['username'], + password=server['password'], + timeout=ClientTimeout(total=60, connect=10) + ) + + try: + await client.init() + await client.accounts.list_accounts() + self._clients[name] = (client, time.time()) + logger.info(f"Connected to server '{name}' at {base_url}") + return client + except Exception as e: + await client.close() + logger.error(f"Failed to connect to '{name}': {e}") + raise + + async def get_client_for_chat(self, chat_id: int, user_id: int = None, preferred_server: str = None): + """Get the API client for a user's preferred or first accessible server. + + Priority: + 1. preferred_server (from user preferences/context) if accessible + 2. chat_defaults[chat_id] if accessible + 3. First accessible server for the user + 4. If no user_id, use chat default or any available server + """ + if user_id: + accessible = self.get_accessible_servers(user_id) + if not accessible: + raise ValueError("No servers available. Ask the admin to share a server with you.") + + # 1. User's preferred server if accessible + if preferred_server and preferred_server in accessible: + return await self.get_client(preferred_server) + + # 2. Chat's default server if accessible + chat_default = self._data.get('chat_defaults', {}).get(chat_id) + if chat_default and chat_default in accessible: + return await self.get_client(chat_default) + + # 3. First accessible server + return await self.get_client(accessible[0]) + + # No user_id - use chat default with proper fallbacks + server_name = self.get_chat_default_server(chat_id) + if not server_name: + raise ValueError("No servers configured") + return await self.get_client(server_name) + + async def check_server_status(self, name: str) -> dict: + """Check if a server is online.""" + from hummingbot_api_client import HummingbotAPIClient + + if name not in self._data['servers']: + return {"status": "error", "message": "Server not found"} + + server = self._data['servers'][name] + base_url = f"http://{server['host']}:{server['port']}" + + client = HummingbotAPIClient( + base_url=base_url, + username=server['username'], + password=server['password'], + timeout=ClientTimeout(total=3, connect=2) + ) + + try: + await client.init() + await client.accounts.list_accounts() + return {"status": "online", "message": "Connected and authenticated"} + except Exception as e: + error_msg = str(e) + if "401" in error_msg: + return {"status": "auth_error", "message": "Invalid credentials"} + elif "timeout" in error_msg.lower(): + return {"status": "offline", "message": "Connection timeout"} + elif "connect" in error_msg.lower(): + return {"status": "offline", "message": "Cannot reach server"} + else: + return {"status": "error", "message": f"Error: {error_msg[:80]}"} + finally: + try: + await client.close() + except: + pass + + async def close_all_clients(self): + """Close all cached client connections.""" + for name, (client, _) in list(self._clients.items()): + try: + await client.close() + logger.info(f"Closed connection to '{name}'") + except Exception as e: + logger.error(f"Error closing client '{name}': {e}") + self._clients.clear() + + # ========================================================================= + # USER MANAGEMENT + # ========================================================================= + + def get_user(self, user_id: int) -> Optional[dict]: + """Get user record.""" + return self._data.get('users', {}).get(user_id) + + def get_user_role(self, user_id: int) -> Optional[UserRole]: + """Get user's role.""" + user = self.get_user(user_id) + if user: + try: + return UserRole(user['role']) + except ValueError: + return None + return None + + def is_admin(self, user_id: int) -> bool: + return self.get_user_role(user_id) == UserRole.ADMIN + + def is_approved(self, user_id: int) -> bool: + role = self.get_user_role(user_id) + return role in (UserRole.ADMIN, UserRole.USER) + + def register_pending(self, user_id: int, username: str = None) -> bool: + """Register a new pending user.""" + users = self._data['users'] + if user_id in users: + return False + + users[user_id] = { + 'user_id': user_id, + 'username': username, + 'role': UserRole.PENDING.value, + 'created_at': time.time() + } + self._audit('user_registered', 'user', str(user_id), user_id) + self._save_config() + logger.info(f"Registered pending user {user_id}") + return True + + def approve_user(self, user_id: int, admin_id: int) -> bool: + """Approve a pending user.""" + users = self._data['users'] + if user_id not in users: + return False + if users[user_id]['role'] == UserRole.BLOCKED.value: + return False + + users[user_id]['role'] = UserRole.USER.value + users[user_id]['approved_by'] = admin_id + users[user_id]['approved_at'] = time.time() + + self._audit('user_approved', 'user', str(user_id), admin_id) + self._save_config() + logger.info(f"User {user_id} approved by {admin_id}") + return True + + def reject_user(self, user_id: int, admin_id: int) -> bool: + """Reject a pending user.""" + users = self._data['users'] + if user_id not in users or users[user_id]['role'] != UserRole.PENDING.value: + return False + + del users[user_id] + self._audit('user_rejected', 'user', str(user_id), admin_id) + self._save_config() + return True + + def block_user(self, user_id: int, admin_id: int) -> bool: + """Block a user.""" + users = self._data['users'] + if user_id not in users or user_id == admin_id: + return False + if users[user_id]['role'] == UserRole.ADMIN.value: + return False + + users[user_id]['role'] = UserRole.BLOCKED.value + self._audit('user_blocked', 'user', str(user_id), admin_id) + self._save_config() + return True + + def unblock_user(self, user_id: int, admin_id: int) -> bool: + """Unblock a user (sets to pending).""" + users = self._data['users'] + if user_id not in users or users[user_id]['role'] != UserRole.BLOCKED.value: + return False + + users[user_id]['role'] = UserRole.PENDING.value + self._audit('user_unblocked', 'user', str(user_id), admin_id) + self._save_config() + return True + + def get_pending_users(self) -> list: + return [u for u in self._data.get('users', {}).values() + if u.get('role') == UserRole.PENDING.value] + + def get_all_users(self) -> list: + return list(self._data.get('users', {}).values()) + + # ========================================================================= + # SERVER ACCESS CONTROL + # ========================================================================= + + def register_server_owner(self, server_name: str, owner_id: int) -> bool: + """Register server ownership.""" + access = self._data['server_access'] + if server_name in access: + return False + + access[server_name] = { + 'owner_id': owner_id, + 'created_at': time.time(), + 'shared_with': {} + } + self._audit('server_registered', 'server', server_name, owner_id) + self._save_config() + return True + + def ensure_server_registered(self, server_name: str, default_owner_id: int = None) -> bool: + """Ensure server is registered in access control.""" + if server_name in self._data['server_access']: + return True + + owner_id = default_owner_id or self.admin_id + if owner_id: + self._data['server_access'][server_name] = { + 'owner_id': owner_id, + 'created_at': time.time(), + 'shared_with': {} + } + self._save_config() + return True + return False + + def get_server_owner(self, server_name: str) -> Optional[int]: + access = self._data.get('server_access', {}).get(server_name) + return access.get('owner_id') if access else None + + def get_server_permission(self, user_id: int, server_name: str) -> Optional[ServerPermission]: + """Get user's permission level for a server.""" + if self.is_admin(user_id): + return ServerPermission.OWNER + + access = self._data.get('server_access', {}).get(server_name) + if not access: + return None + + if access.get('owner_id') == user_id: + return ServerPermission.OWNER + + perm_str = access.get('shared_with', {}).get(user_id) + if perm_str: + try: + return ServerPermission(perm_str) + except ValueError: + return None + return None + + def has_server_access(self, user_id: int, server_name: str, + min_permission: ServerPermission = ServerPermission.VIEWER) -> bool: + perm = self.get_server_permission(user_id, server_name) + if perm is None: + return False + return PERMISSION_HIERARCHY.get(perm, 0) >= PERMISSION_HIERARCHY.get(min_permission, 0) + + def share_server(self, server_name: str, owner_id: int, + target_user_id: int, permission: ServerPermission) -> bool: + """Share a server with another user.""" + access = self._data.get('server_access', {}).get(server_name) + if not access: + return False + if access.get('owner_id') != owner_id and not self.is_admin(owner_id): + return False + if target_user_id == access.get('owner_id'): + return False + if not self.is_approved(target_user_id): + return False + + access.setdefault('shared_with', {})[target_user_id] = permission.value + self._audit('server_shared', 'server', server_name, owner_id, + {'target_user': target_user_id, 'permission': permission.value}) + self._save_config() + return True + + def revoke_server_access(self, server_name: str, owner_id: int, target_user_id: int) -> bool: + """Revoke a user's access to a server.""" + access = self._data.get('server_access', {}).get(server_name) + if not access: + return False + if access.get('owner_id') != owner_id and not self.is_admin(owner_id): + return False + + shared = access.get('shared_with', {}) + if target_user_id not in shared: + return False + + del shared[target_user_id] + self._audit('server_access_revoked', 'server', server_name, owner_id, + {'target_user': target_user_id}) + self._save_config() + return True + + def get_server_shared_users(self, server_name: str) -> list: + """Get list of users a server is shared with.""" + access = self._data.get('server_access', {}).get(server_name) + if not access: + return [] + + result = [] + for user_id, perm_str in access.get('shared_with', {}).items(): + try: + result.append((user_id, ServerPermission(perm_str))) + except ValueError: + pass + return result + + def get_accessible_servers(self, user_id: int) -> list: + """Get all servers a user can access.""" + if self.is_admin(user_id): + return list(self._data.get('server_access', {}).keys()) + + accessible = [] + for server_name, access in self._data.get('server_access', {}).items(): + if access.get('owner_id') == user_id: + accessible.append(server_name) + elif user_id in access.get('shared_with', {}): + accessible.append(server_name) + return accessible + + def get_owned_servers(self, user_id: int) -> list: + return [s for s, a in self._data.get('server_access', {}).items() + if a.get('owner_id') == user_id] + + def get_shared_servers(self, user_id: int) -> list: + """Get servers shared with user (not owned).""" + result = [] + for server_name, access in self._data.get('server_access', {}).items(): + if access.get('owner_id') == user_id: + continue + perm_str = access.get('shared_with', {}).get(user_id) + if perm_str: + try: + result.append((server_name, ServerPermission(perm_str))) + except ValueError: + pass + return result + + def list_accessible_servers(self, user_id: int) -> Dict[str, dict]: + """List servers accessible by a user with their configs.""" + if self.is_admin(user_id): + # Auto-register unregistered servers for admin + for name in self._data['servers']: + self.ensure_server_registered(name, self.admin_id) + return self._data['servers'].copy() + + accessible = {} + for name in self.get_accessible_servers(user_id): + if name in self._data['servers']: + accessible[name] = self._data['servers'][name] + return accessible + + # ========================================================================= + # CHAT DEFAULTS + # ========================================================================= + + def get_chat_default_server(self, chat_id: int) -> Optional[str]: + """Get the default server for a chat.""" + server = self._data.get('chat_defaults', {}).get(chat_id) + if server and server in self._data['servers']: + return server + # Fallback to global default + default = self.get_default_server() + if default and default in self._data['servers']: + return default + # Last resort: first server + if self._data['servers']: + return list(self._data['servers'].keys())[0] + return None + + def set_chat_default_server(self, chat_id: int, server_name: str) -> bool: + """Set the default server for a chat.""" + if server_name not in self._data['servers']: + return False + self._data.setdefault('chat_defaults', {})[chat_id] = server_name + self._save_config() + return True + + def clear_chat_default_server(self, chat_id: int) -> bool: + """Clear the default server for a chat.""" + defaults = self._data.get('chat_defaults', {}) + if chat_id in defaults: + del defaults[chat_id] + self._save_config() + return True + return False + + def get_chat_server_info(self, chat_id: int) -> dict: + """Get server info for a chat.""" + per_chat = self._data.get('chat_defaults', {}).get(chat_id) + if per_chat and per_chat in self._data['servers']: + return { + "server": per_chat, + "is_per_chat": True, + "global_default": self.get_default_server() + } + return { + "server": self.get_default_server(), + "is_per_chat": False, + "global_default": self.get_default_server() + } + + # ========================================================================= + # AUDIT LOG + # ========================================================================= + + def _audit(self, action: str, target_type: str, target_id: str, + actor_id: int, details: dict = None): + self._audit_log.append({ + 'timestamp': time.time(), + 'actor_id': actor_id, + 'action': action, + 'target_type': target_type, + 'target_id': target_id, + 'details': details + }) + self._save_audit_log() + + def get_audit_log(self, limit: int = 50) -> list: + return list(reversed(self._audit_log))[:limit] + + +# Convenience functions +def get_config_manager() -> ConfigManager: + """Get the ConfigManager singleton instance.""" + return ConfigManager.instance() + + +def get_effective_server(chat_id: int, user_data: dict = None) -> str | None: + """Get the effective default server for a chat, checking both user_data and config.yml. + + Priority: + 1. user_data active_server (from pickle, fast in-memory) + 2. chat_defaults from config.yml (persistent across hard kills) + 3. None if nothing configured + + Args: + chat_id: The chat ID + user_data: Optional user_data dict from context + + Returns: + Server name or None + """ + from handlers.config.user_preferences import get_active_server + + # First check user_data (pickle - might be lost on hard kill) + if user_data: + active = get_active_server(user_data) + if active: + return active + + # Fall back to chat_defaults in config.yml (always persisted) + cm = get_config_manager() + chat_default = cm._data.get('chat_defaults', {}).get(chat_id) + if chat_default and chat_default in cm._data.get('servers', {}): + # Also sync back to user_data for fast future access + if user_data is not None: + from handlers.config.user_preferences import set_active_server + set_active_server(user_data, chat_default) + return chat_default + + return None + + +async def get_client(chat_id: int, user_id: int = None, context=None): + """Get the API client for the user's preferred server.""" + preferred_server = None + if context is not None: + if user_id is None: + user_id = context.user_data.get('_user_id') + preferred_server = get_effective_server(chat_id, context.user_data) + + return await get_config_manager().get_client_for_chat(chat_id, user_id, preferred_server) diff --git a/docker-compose.yml b/docker-compose.yml index ac1ecac..9c9e2bd 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -4,37 +4,17 @@ services: container_name: condor-bot restart: unless-stopped environment: - TELEGRAM_TOKEN: your_bot_token - AUTHORIZED_USERS: 123,1234 + TELEGRAM_TOKEN: ${TELEGRAM_TOKEN} + ADMIN_USER_ID: ${ADMIN_USER_ID} env_file: - .env volumes: # Persist bot data (user preferences, trading context, etc.) - ./condor_bot_data.pickle:/app/condor_bot_data.pickle - # Mount servers config - - ./servers.yml:/app/servers.yml + # Mount config (servers, users, permissions) + - ./config.yml:/app/config.yml # Mount routines for auto-discovery - ./routines:/app/routines network_mode: host labels: - "com.centurylinklabs.watchtower.enable=true" - - watchtower: - image: containrrr/watchtower - container_name: watchtower - restart: unless-stopped - volumes: - - /var/run/docker.sock:/var/run/docker.sock - environment: - # Only update containers with the watchtower label - WATCHTOWER_LABEL_ENABLE: "true" - # Check for updates every 5 minutes (300 seconds) - WATCHTOWER_POLL_INTERVAL: 300 - # Remove old images after updating - WATCHTOWER_CLEANUP: "true" - # Telegram notifications - WATCHTOWER_NOTIFICATIONS: shoutrrr - WATCHTOWER_NOTIFICATION_URL: telegram://${TELEGRAM_TOKEN}@telegram?chats=${TELEGRAM_ALLOWED_IDS} - env_file: - - .env - command: --interval 300 diff --git a/handlers/__init__.py b/handlers/__init__.py index ea4bcfd..56f5f59 100644 --- a/handlers/__init__.py +++ b/handlers/__init__.py @@ -5,6 +5,31 @@ from telegram.ext import ContextTypes +def is_gateway_network(connector_name: str) -> bool: + """ + Check if a connector name is a Gateway network (DEX) vs a CEX connector. + + Gateway networks: solana-mainnet-beta, ethereum-mainnet, base, arbitrum, etc. + CEX connectors: binance, binance_perpetual, hyperliquid, kucoin, etc. + """ + if not connector_name: + return False + + connector_lower = connector_name.lower() + + # Known Gateway network patterns + gateway_patterns = [ + 'solana', 'ethereum', 'base', 'arbitrum', 'polygon', + 'optimism', 'avalanche', 'mainnet', 'devnet', 'testnet' + ] + + for pattern in gateway_patterns: + if pattern in connector_lower: + return True + + return False + + def clear_all_input_states(context: ContextTypes.DEFAULT_TYPE) -> None: """ Clear ALL input-related states from user context. @@ -90,3 +115,14 @@ def clear_all_input_states(context: ContextTypes.DEFAULT_TYPE) -> None: # Routines states context.user_data.pop("routines_state", None) context.user_data.pop("routines_editing", None) + + # Signals states + context.user_data.pop("signals_state", None) + context.user_data.pop("signals_editing", None) + + # Access share states + context.user_data.pop("sharing_server", None) + context.user_data.pop("awaiting_share_user_id", None) + context.user_data.pop("share_target_user_id", None) + context.user_data.pop("share_message_id", None) + context.user_data.pop("share_chat_id", None) diff --git a/handlers/admin/__init__.py b/handlers/admin/__init__.py new file mode 100644 index 0000000..0c6ed68 --- /dev/null +++ b/handlers/admin/__init__.py @@ -0,0 +1,484 @@ +""" +Admin panel for user and access management. +Only accessible by admin users. +""" + +import logging +from datetime import datetime + +from telegram import Update, InlineKeyboardButton, InlineKeyboardMarkup +from telegram.ext import ContextTypes + +from utils.auth import admin_required +from config_manager import ( + get_config_manager, + UserRole, +) +from utils.telegram_formatters import escape_markdown_v2 + +logger = logging.getLogger(__name__) + + +def _get_admin_menu_keyboard() -> InlineKeyboardMarkup: + """Build the admin menu keyboard.""" + keyboard = [ + [ + InlineKeyboardButton("👥 Pending Requests", callback_data="admin:pending"), + InlineKeyboardButton("📋 All Users", callback_data="admin:users"), + ], + [ + InlineKeyboardButton("📜 Audit Log", callback_data="admin:audit"), + InlineKeyboardButton("📊 Stats", callback_data="admin:stats"), + ], + [ + InlineKeyboardButton("« Close", callback_data="config_close"), + ], + ] + return InlineKeyboardMarkup(keyboard) + + +def _format_user_role_badge(role: str) -> str: + """Get role badge emoji.""" + badges = { + UserRole.ADMIN.value: "👑", + UserRole.USER.value: "✓", + UserRole.PENDING.value: "⏳", + UserRole.BLOCKED.value: "🚫", + } + return badges.get(role, "?") + + +def _format_timestamp(ts: float) -> str: + """Format timestamp for display.""" + if not ts: + return "N/A" + dt = datetime.fromtimestamp(ts) + return dt.strftime("%Y-%m-%d %H:%M") + + +@admin_required +async def admin_command(update: Update, context: ContextTypes.DEFAULT_TYPE) -> None: + """Handle /admin command - show admin panel.""" + from handlers import clear_all_input_states + clear_all_input_states(context) + + cm = get_config_manager() + pending_count = len(cm.get_pending_users()) + total_users = len(cm.get_all_users()) + + message = ( + "🔐 *Admin Panel*\n\n" + f"👥 Total Users: {total_users}\n" + f"⏳ Pending Requests: {pending_count}\n\n" + "Select an option below:" + ) + + await update.message.reply_text( + message, + parse_mode="MarkdownV2", + reply_markup=_get_admin_menu_keyboard() + ) + + +@admin_required +async def admin_callback_handler(update: Update, context: ContextTypes.DEFAULT_TYPE) -> None: + """Handle admin panel callbacks.""" + query = update.callback_query + + data = query.data + + # Handle config_admin entry point (from /config menu) + if data == "config_admin": + await query.answer() + await _show_admin_menu(query, context) + return + + await query.answer() + action = data.split(":", 1)[1] if ":" in data else data + + if action == "menu" or action == "back": + await _show_admin_menu(query, context) + elif action == "pending": + await _show_pending_users(query, context) + elif action == "users": + await _show_all_users(query, context) + elif action == "audit": + await _show_audit_log(query, context) + elif action == "stats": + await _show_stats(query, context) + elif action.startswith("approve_"): + user_id = int(action.replace("approve_", "")) + await _approve_user(query, context, user_id) + elif action.startswith("reject_"): + user_id = int(action.replace("reject_", "")) + await _reject_user(query, context, user_id) + elif action.startswith("block_"): + user_id = int(action.replace("block_", "")) + await _block_user(query, context, user_id) + elif action.startswith("unblock_"): + user_id = int(action.replace("unblock_", "")) + await _unblock_user(query, context, user_id) + elif action.startswith("user_"): + user_id = int(action.replace("user_", "")) + await _show_user_details(query, context, user_id) + + +async def _show_admin_menu(query, context: ContextTypes.DEFAULT_TYPE) -> None: + """Show main admin menu.""" + cm = get_config_manager() + pending_count = len(cm.get_pending_users()) + total_users = len(cm.get_all_users()) + + message = ( + "🔐 *Admin Panel*\n\n" + f"👥 Total Users: {total_users}\n" + f"⏳ Pending Requests: {pending_count}\n\n" + "Select an option below:" + ) + + await query.edit_message_text( + message, + parse_mode="MarkdownV2", + reply_markup=_get_admin_menu_keyboard() + ) + + +async def _show_pending_users(query, context: ContextTypes.DEFAULT_TYPE) -> None: + """Show pending user approval list.""" + cm = get_config_manager() + pending = cm.get_pending_users() + + if not pending: + message = ( + "👥 *Pending Requests*\n\n" + "No pending access requests\\." + ) + keyboard = [[InlineKeyboardButton("🔙 Back", callback_data="admin:back")]] + await query.edit_message_text( + message, + parse_mode="MarkdownV2", + reply_markup=InlineKeyboardMarkup(keyboard) + ) + return + + message = f"👥 *Pending Requests* \\({len(pending)}\\)\n\n" + + keyboard = [] + for user in pending: + user_id = user['user_id'] + username = user.get('username') or 'N/A' + created = _format_timestamp(user.get('created_at', 0)) + + message += f"• `{user_id}` \\(@{escape_markdown_v2(username)}\\)\n" + message += f" Requested: {escape_markdown_v2(created)}\n\n" + + keyboard.append([ + InlineKeyboardButton(f"✓ Approve {user_id}", callback_data=f"admin:approve_{user_id}"), + InlineKeyboardButton(f"✕ Reject", callback_data=f"admin:reject_{user_id}"), + ]) + + keyboard.append([InlineKeyboardButton("🔙 Back", callback_data="admin:back")]) + + await query.edit_message_text( + message, + parse_mode="MarkdownV2", + reply_markup=InlineKeyboardMarkup(keyboard) + ) + + +async def _show_all_users(query, context: ContextTypes.DEFAULT_TYPE) -> None: + """Show all users list.""" + cm = get_config_manager() + users = cm.get_all_users() + + if not users: + message = "📋 *All Users*\n\nNo users registered\\." + keyboard = [[InlineKeyboardButton("🔙 Back", callback_data="admin:back")]] + await query.edit_message_text( + message, + parse_mode="MarkdownV2", + reply_markup=InlineKeyboardMarkup(keyboard) + ) + return + + # Group by role + by_role = {} + for user in users: + role = user.get('role', 'unknown') + by_role.setdefault(role, []).append(user) + + message = f"📋 *All Users* \\({len(users)}\\)\n\n" + + # Show in order: admin, user, pending, blocked + role_order = [UserRole.ADMIN.value, UserRole.USER.value, UserRole.PENDING.value, UserRole.BLOCKED.value] + + keyboard = [] + for role in role_order: + role_users = by_role.get(role, []) + if not role_users: + continue + + badge = _format_user_role_badge(role) + message += f"*{badge} {role.title()}* \\({len(role_users)}\\)\n" + + for user in role_users[:5]: # Limit to 5 per role in message + user_id = user['user_id'] + username = user.get('username') or 'N/A' + message += f" • `{user_id}` @{escape_markdown_v2(username)}\n" + + keyboard.append([ + InlineKeyboardButton( + f"{badge} {user_id} (@{username[:10]})", + callback_data=f"admin:user_{user_id}" + ) + ]) + + if len(role_users) > 5: + message += f" _\\.\\.\\. and {len(role_users) - 5} more_\n" + + message += "\n" + + keyboard.append([InlineKeyboardButton("🔙 Back", callback_data="admin:back")]) + + await query.edit_message_text( + message, + parse_mode="MarkdownV2", + reply_markup=InlineKeyboardMarkup(keyboard) + ) + + +async def _show_user_details(query, context: ContextTypes.DEFAULT_TYPE, user_id: int) -> None: + """Show details for a specific user.""" + cm = get_config_manager() + user = cm.get_user(user_id) + + if not user: + await query.answer("User not found", show_alert=True) + return + + role = user.get('role', 'unknown') + username = user.get('username') or 'N/A' + created = _format_timestamp(user.get('created_at', 0)) + approved_at = _format_timestamp(user.get('approved_at')) + approved_by = user.get('approved_by') + notes = user.get('notes') or 'None' + + badge = _format_user_role_badge(role) + + message = ( + f"👤 *User Details*\n\n" + f"*ID:* `{user_id}`\n" + f"*Username:* @{escape_markdown_v2(username)}\n" + f"*Role:* {badge} {escape_markdown_v2(role.title())}\n" + f"*Created:* {escape_markdown_v2(created)}\n" + ) + + if approved_at != "N/A": + message += f"*Approved:* {escape_markdown_v2(approved_at)}\n" + if approved_by: + message += f"*Approved By:* `{approved_by}`\n" + if notes != 'None': + message += f"*Notes:* {escape_markdown_v2(notes)}\n" + + # Show servers owned by user + owned_servers = cm.get_owned_servers(user_id) + if owned_servers: + message += f"\n*Owned Servers:* {len(owned_servers)}\n" + for s in owned_servers[:3]: + message += f" • {escape_markdown_v2(s)}\n" + if len(owned_servers) > 3: + message += f" _\\.\\.\\. and {len(owned_servers) - 3} more_\n" + + # Show shared servers + shared_servers = cm.get_shared_servers(user_id) + if shared_servers: + message += f"\n*Shared Access:* {len(shared_servers)}\n" + for s, perm in shared_servers[:3]: + message += f" • {escape_markdown_v2(s)} \\({perm.value}\\)\n" + + # Build action buttons based on role + keyboard = [] + admin_id = cm.admin_id + + if role == UserRole.PENDING.value: + keyboard.append([ + InlineKeyboardButton("✓ Approve", callback_data=f"admin:approve_{user_id}"), + InlineKeyboardButton("✕ Reject", callback_data=f"admin:reject_{user_id}"), + ]) + elif role == UserRole.BLOCKED.value: + keyboard.append([ + InlineKeyboardButton("🔓 Unblock", callback_data=f"admin:unblock_{user_id}"), + ]) + elif role == UserRole.USER.value and user_id != admin_id: + keyboard.append([ + InlineKeyboardButton("🚫 Block", callback_data=f"admin:block_{user_id}"), + ]) + + keyboard.append([InlineKeyboardButton("🔙 Back to Users", callback_data="admin:users")]) + + await query.edit_message_text( + message, + parse_mode="MarkdownV2", + reply_markup=InlineKeyboardMarkup(keyboard) + ) + + +async def _approve_user(query, context: ContextTypes.DEFAULT_TYPE, user_id: int) -> None: + """Approve a pending user.""" + cm = get_config_manager() + admin_id = query.from_user.id + + if cm.approve_user(user_id, admin_id): + # Notify the user + try: + await context.bot.send_message( + chat_id=user_id, + text=( + "✅ *Access Approved\\!*\n\n" + "Your access request has been approved\\.\n" + "Use /start to begin\\." + ), + parse_mode="MarkdownV2" + ) + except Exception as e: + logger.warning(f"Failed to notify user {user_id} of approval: {e}") + + await query.answer("User approved!", show_alert=True) + else: + await query.answer("Failed to approve user", show_alert=True) + + # Refresh pending list + await _show_pending_users(query, context) + + +async def _reject_user(query, context: ContextTypes.DEFAULT_TYPE, user_id: int) -> None: + """Reject a pending user.""" + cm = get_config_manager() + admin_id = query.from_user.id + + if cm.reject_user(user_id, admin_id): + await query.answer("User rejected", show_alert=True) + else: + await query.answer("Failed to reject user", show_alert=True) + + # Refresh pending list + await _show_pending_users(query, context) + + +async def _block_user(query, context: ContextTypes.DEFAULT_TYPE, user_id: int) -> None: + """Block a user.""" + cm = get_config_manager() + admin_id = query.from_user.id + + if cm.block_user(user_id, admin_id): + await query.answer("User blocked", show_alert=True) + else: + await query.answer("Failed to block user", show_alert=True) + + # Show user details + await _show_user_details(query, context, user_id) + + +async def _unblock_user(query, context: ContextTypes.DEFAULT_TYPE, user_id: int) -> None: + """Unblock a user.""" + cm = get_config_manager() + admin_id = query.from_user.id + + if cm.unblock_user(user_id, admin_id): + await query.answer("User unblocked (now pending)", show_alert=True) + else: + await query.answer("Failed to unblock user", show_alert=True) + + # Show user details + await _show_user_details(query, context, user_id) + + +async def _show_audit_log(query, context: ContextTypes.DEFAULT_TYPE) -> None: + """Show recent audit log entries.""" + cm = get_config_manager() + entries = cm.get_audit_log(limit=10) + + if not entries: + message = "📜 *Audit Log*\n\nNo entries yet\\." + keyboard = [[InlineKeyboardButton("🔙 Back", callback_data="admin:back")]] + await query.edit_message_text( + message, + parse_mode="MarkdownV2", + reply_markup=InlineKeyboardMarkup(keyboard) + ) + return + + message = "📜 *Audit Log* \\(Recent 10\\)\n\n" + + for entry in entries: + ts = _format_timestamp(entry.get('timestamp', 0)) + action = entry.get('action', 'unknown') + actor = entry.get('actor_id', 0) + target_type = entry.get('target_type', '') + target_id = entry.get('target_id', '') + + # Format action nicely + action_display = action.replace('_', ' ').title() + + message += f"• *{escape_markdown_v2(ts)}*\n" + message += f" {escape_markdown_v2(action_display)}\n" + message += f" By: `{actor}` \\| {escape_markdown_v2(target_type)}: `{escape_markdown_v2(str(target_id))}`\n\n" + + keyboard = [[InlineKeyboardButton("🔙 Back", callback_data="admin:back")]] + + await query.edit_message_text( + message, + parse_mode="MarkdownV2", + reply_markup=InlineKeyboardMarkup(keyboard) + ) + + +async def _show_stats(query, context: ContextTypes.DEFAULT_TYPE) -> None: + """Show system statistics.""" + cm = get_config_manager() + from config_manager import get_config_manager + + users = cm.get_all_users() + servers = list(get_config_manager().list_servers().keys()) + + # Count by role + role_counts = {} + for user in users: + role = user.get('role', 'unknown') + role_counts[role] = role_counts.get(role, 0) + 1 + + # Count servers by owner + server_owners = {} + for server_name in servers: + owner = cm.get_server_owner(server_name) + if owner: + server_owners[owner] = server_owners.get(owner, 0) + 1 + + message = ( + "📊 *System Statistics*\n\n" + f"*Users*\n" + f" 👑 Admins: {role_counts.get(UserRole.ADMIN.value, 0)}\n" + f" ✓ Approved: {role_counts.get(UserRole.USER.value, 0)}\n" + f" ⏳ Pending: {role_counts.get(UserRole.PENDING.value, 0)}\n" + f" 🚫 Blocked: {role_counts.get(UserRole.BLOCKED.value, 0)}\n\n" + f"*Servers*\n" + f" Total: {len(servers)}\n" + f" With owners: {len(server_owners)}\n\n" + f"*Audit Log*\n" + f" Entries: {len(cm.cm.get('audit_log', []))}\n" + ) + + keyboard = [[InlineKeyboardButton("🔙 Back", callback_data="admin:back")]] + + await query.edit_message_text( + message, + parse_mode="MarkdownV2", + reply_markup=InlineKeyboardMarkup(keyboard) + ) + + +__all__ = [ + 'admin_command', + 'admin_callback_handler', + '_show_admin_menu', +] diff --git a/handlers/bots/__init__.py b/handlers/bots/__init__.py index fca3044..ac92790 100644 --- a/handlers/bots/__init__.py +++ b/handlers/bots/__init__.py @@ -16,7 +16,7 @@ from telegram import Update from telegram.ext import ContextTypes, CallbackQueryHandler, MessageHandler, filters -from utils.auth import restricted +from utils.auth import restricted, hummingbot_api_required from handlers import clear_all_input_states # Import submodule handlers @@ -28,6 +28,9 @@ show_controller_detail, handle_stop_controller, handle_confirm_stop_controller, + handle_start_controller, + handle_confirm_start_controller, + handle_clone_controller, handle_quick_stop_controller, handle_quick_start_controller, handle_stop_bot, @@ -66,11 +69,13 @@ handle_cfg_edit_save, handle_cfg_edit_save_all, handle_cfg_edit_cancel, + handle_cfg_branch, show_new_grid_strike_form, show_new_pmm_mister_form, show_config_form, handle_set_field, handle_toggle_side, + handle_toggle_position_mode, handle_cycle_order_type, handle_select_connector, process_field_input, @@ -86,7 +91,6 @@ process_deploy_field_input, handle_execute_deploy, # Progressive deploy flow - show_deploy_progressive_form, handle_deploy_progressive_input, handle_deploy_use_default, handle_deploy_skip_field, @@ -128,13 +132,17 @@ handle_gs_review_back, handle_gs_edit_price, process_gs_wizard_input, + handle_gs_pair_select, # PMM Mister wizard handle_pmm_wizard_connector, handle_pmm_wizard_pair, + handle_pmm_pair_select, handle_pmm_wizard_leverage, handle_pmm_wizard_allocation, + handle_pmm_wizard_amount, handle_pmm_wizard_spreads, handle_pmm_wizard_tp, + handle_pmm_back, handle_pmm_save, handle_pmm_review_back, handle_pmm_edit_id, @@ -143,9 +151,11 @@ handle_pmm_edit_advanced, handle_pmm_adv_setting, process_pmm_wizard_input, + # Custom config upload + show_upload_config_prompt, + handle_upload_cancel, + handle_config_file_upload, ) -from ._shared import clear_bots_state, SIDE_LONG, SIDE_SHORT - # Archived bots handlers from .archived import ( show_archived_menu, @@ -154,7 +164,6 @@ show_bot_chart, handle_generate_report, handle_archived_refresh, - clear_archived_state, ) logger = logging.getLogger(__name__) @@ -165,6 +174,7 @@ # ============================================ @restricted +@hummingbot_api_required async def bots_command(update: Update, context: ContextTypes.DEFAULT_TYPE) -> None: """ Handle /bots command - Display bots dashboard @@ -193,7 +203,7 @@ async def bots_command(update: Update, context: ContextTypes.DEFAULT_TYPE) -> No from ._shared import get_bots_client try: - client = await get_bots_client(chat_id) + client, _ = await get_bots_client(chat_id, context.user_data) bot_status = await client.bot_orchestration.get_bot_status(bot_name) response_message = format_bot_status(bot_status) await msg.reply_text(response_message, parse_mode="MarkdownV2") @@ -305,6 +315,16 @@ async def bots_callback_handler(update: Update, context: ContextTypes.DEFAULT_TY elif main_action == "cfg_edit_cancel": await handle_cfg_edit_cancel(update, context) + elif main_action == "cfg_branch": + await handle_cfg_branch(update, context) + + # Custom config upload + elif main_action == "upload_config": + await show_upload_config_prompt(update, context) + + elif main_action == "upload_cancel": + await handle_upload_cancel(update, context) + elif main_action == "noop": pass # Do nothing - used for pagination display button @@ -330,6 +350,9 @@ async def bots_callback_handler(update: Update, context: ContextTypes.DEFAULT_TY elif main_action == "toggle_side": await handle_toggle_side(update, context) + elif main_action == "toggle_position_mode": + await handle_toggle_position_mode(update, context) + elif main_action == "cycle_order_type": if len(action_parts) > 1: order_type_key = action_parts[1] # 'open' or 'tp' @@ -400,7 +423,8 @@ async def bots_callback_handler(update: Update, context: ContextTypes.DEFAULT_TY elif main_action == "select_image": if len(action_parts) > 1: - image = action_parts[1] + # Rejoin parts to preserve colons in image tag (e.g., "hummingbot:development") + image = ":".join(action_parts[1:]) await handle_select_image(update, context, image) elif main_action == "select_name": @@ -425,6 +449,11 @@ async def bots_callback_handler(update: Update, context: ContextTypes.DEFAULT_TY pair = action_parts[1] await handle_gs_wizard_pair(update, context, pair) + elif main_action == "gs_pair_select": + if len(action_parts) > 1: + pair = action_parts[1] + await handle_gs_pair_select(update, context, pair) + elif main_action == "gs_side": if len(action_parts) > 1: side_str = action_parts[1] @@ -517,6 +546,11 @@ async def bots_callback_handler(update: Update, context: ContextTypes.DEFAULT_TY pair = action_parts[1] await handle_pmm_wizard_pair(update, context, pair) + elif main_action == "pmm_pair_select": + if len(action_parts) > 1: + pair = action_parts[1] + await handle_pmm_pair_select(update, context, pair) + elif main_action == "pmm_leverage": if len(action_parts) > 1: leverage = int(action_parts[1]) @@ -527,6 +561,11 @@ async def bots_callback_handler(update: Update, context: ContextTypes.DEFAULT_TY allocation = float(action_parts[1]) await handle_pmm_wizard_allocation(update, context, allocation) + elif main_action == "pmm_amount": + if len(action_parts) > 1: + amount = float(action_parts[1]) + await handle_pmm_wizard_amount(update, context, amount) + elif main_action == "pmm_spreads": if len(action_parts) > 1: spreads = action_parts[1] @@ -537,6 +576,11 @@ async def bots_callback_handler(update: Update, context: ContextTypes.DEFAULT_TY tp = float(action_parts[1]) await handle_pmm_wizard_tp(update, context, tp) + elif main_action == "pmm_back": + if len(action_parts) > 1: + target = action_parts[1] + await handle_pmm_back(update, context, target) + elif main_action == "pmm_save": await handle_pmm_save(update, context) @@ -602,6 +646,17 @@ async def bots_callback_handler(update: Update, context: ContextTypes.DEFAULT_TY elif main_action == "confirm_stop_ctrl": await handle_confirm_stop_controller(update, context) + # Start controller (uses context) + elif main_action == "start_ctrl": + await handle_start_controller(update, context) + + elif main_action == "confirm_start_ctrl": + await handle_confirm_start_controller(update, context) + + # Clone controller (PMM Mister only) + elif main_action == "clone_ctrl": + await handle_clone_controller(update, context) + # Quick stop/start controller (from bot detail view) elif main_action == "stop_ctrl_quick": if len(action_parts) > 1: @@ -764,10 +819,28 @@ def get_bots_message_handler(): ) +@restricted +async def bots_document_handler(update: Update, context: ContextTypes.DEFAULT_TYPE) -> None: + """Handle document uploads for bots module (e.g., config file uploads)""" + # Only process if we're expecting a config upload + if context.user_data.get("bots_state") == "awaiting_config_upload": + await handle_config_file_upload(update, context) + + +def get_bots_document_handler(): + """Get the document handler for bots module""" + return MessageHandler( + filters.Document.ALL, + bots_document_handler + ) + + __all__ = [ 'bots_command', 'bots_callback_handler', 'bots_message_handler', + 'bots_document_handler', 'get_bots_callback_handler', 'get_bots_message_handler', + 'get_bots_document_handler', ] diff --git a/handlers/bots/_shared.py b/handlers/bots/_shared.py index d6d9829..435b993 100644 --- a/handlers/bots/_shared.py +++ b/handlers/bots/_shared.py @@ -28,21 +28,22 @@ from .controllers import SUPPORTED_CONTROLLERS, get_controller from .controllers.grid_strike import ( - DEFAULTS as GRID_STRIKE_DEFAULTS, SIDE_LONG, SIDE_SHORT, ORDER_TYPE_MARKET, ORDER_TYPE_LIMIT, ORDER_TYPE_LIMIT_MAKER, ORDER_TYPE_LABELS, - WIZARD_STEPS as GS_WIZARD_STEPS, - calculate_auto_prices, generate_chart as _gs_generate_chart, generate_id as _gs_generate_id, + calculate_auto_prices, + DEFAULTS as GRID_STRIKE_DEFAULTS, + FIELD_ORDER as GRID_STRIKE_FIELD_ORDER, + EDITABLE_FIELDS as GS_EDITABLE_FIELDS, ) # Convert ControllerField objects to dicts for backwards compatibility -from .controllers.grid_strike import FIELDS as _GS_FIELDS, FIELD_ORDER as GRID_STRIKE_FIELD_ORDER +from .controllers.grid_strike import FIELDS as _GS_FIELDS GRID_STRIKE_FIELDS = { name: { "label": field.label, @@ -58,41 +59,59 @@ # SERVER CLIENT HELPER # ============================================ -async def get_bots_client(chat_id: Optional[int] = None): +async def get_bots_client(chat_id: Optional[int] = None, user_data: Optional[Dict] = None) -> Tuple[Any, str]: """Get the API client for bot operations Args: - chat_id: Optional chat ID to get per-chat server. If None, uses global default. + chat_id: Optional chat ID (legacy, not used for server selection) + user_data: Optional user_data dict to get user's preferred server and user_id Returns: - Client instance with bot_orchestration and controller endpoints + Tuple of (client, server_name) - client has bot_orchestration and controller endpoints Raises: - ValueError: If no enabled servers available + ValueError: If no accessible servers are available for the user """ - from servers import server_manager + from config_manager import get_config_manager - servers = server_manager.list_servers() - enabled_servers = [name for name, cfg in servers.items() if cfg.get("enabled", True)] + cm = get_config_manager() - if not enabled_servers: - raise ValueError("No enabled API servers available") + # Get user_id from user_data for access control + user_id = user_data.get('_user_id') if user_data else None - # Use per-chat server if chat_id provided, otherwise global default - if chat_id is not None: - default_server = server_manager.get_default_server_for_chat(chat_id) + # Get servers the user has access to (not all servers) + if user_id: + accessible_servers = cm.get_accessible_servers(user_id) + # Filter to only enabled servers + all_servers = cm.list_servers() + enabled_accessible = [s for s in accessible_servers if all_servers.get(s, {}).get("enabled", True)] else: - default_server = server_manager.get_default_server() - - if default_server and default_server in enabled_servers: - server_name = default_server + # Fallback for legacy calls without user_data - use all enabled servers + # This should not happen in normal operation + logger.warning("get_bots_client called without user_data - cannot verify server access") + all_servers = cm.list_servers() + enabled_accessible = [name for name, cfg in all_servers.items() if cfg.get("enabled", True)] + + if not enabled_accessible: + raise ValueError("No accessible API servers available. Please configure server access.") + + # Use user's preferred server if valid + preferred = None + if user_data: + from handlers.config.user_preferences import get_active_server + preferred = get_active_server(user_data) + + # Only use preferred server if user has access to it + if preferred and preferred in enabled_accessible: + server_name = preferred + elif enabled_accessible: + server_name = enabled_accessible[0] else: - server_name = enabled_servers[0] - - logger.info(f"Bots using server: {server_name}" + (f" (chat_id={chat_id})" if chat_id else "")) - client = await server_manager.get_client(server_name) + raise ValueError("No accessible API servers available") - return client + logger.info(f"Bots using server: {server_name} (user_id: {user_id})") + client = await cm.get_client(server_name) + return client, server_name # ============================================ @@ -308,10 +327,22 @@ async def get_available_cex_connectors( user_data: dict, client, account_name: str = "master_account", - ttl: int = 300 + ttl: int = 300, + server_name: str = "default" ) -> List[str]: - """Get available CEX connectors with caching.""" - cache_key = f"available_cex_connectors_{account_name}" + """Get available CEX connectors with caching. + + Args: + user_data: context.user_data dict + client: API client instance + account_name: Account name to check credentials for + ttl: Cache time-to-live in seconds + server_name: Server name to include in cache key (prevents cross-server cache pollution) + + Returns: + List of available CEX connector names + """ + cache_key = f"available_cex_connectors_{server_name}_{account_name}" return await cached_call( user_data, cache_key, diff --git a/handlers/bots/archived.py b/handlers/bots/archived.py index db480c6..018b313 100644 --- a/handlers/bots/archived.py +++ b/handlers/bots/archived.py @@ -288,7 +288,7 @@ async def show_archived_menu(update: Update, context: ContextTypes.DEFAULT_TYPE, chat_id = update.effective_chat.id try: - client = await get_bots_client(chat_id) + client, _ = await get_bots_client(chat_id, context.user_data) # Fetch databases (with caching) - only healthy databases cache_key = "archived_databases" @@ -449,7 +449,7 @@ async def show_archived_detail(update: Update, context: ContextTypes.DEFAULT_TYP context.user_data["archived_current_db"] = db_path context.user_data["archived_current_idx"] = db_index - client = await get_bots_client(chat_id) + client, _ = await get_bots_client(chat_id, context.user_data) # Fetch summary summary = await fetch_database_summary(client, db_path) @@ -575,7 +575,7 @@ async def show_timeline_chart(update: Update, context: ContextTypes.DEFAULT_TYPE summaries = context.user_data.get("archived_summaries", {}) if not databases: - client = await get_bots_client(chat_id) + client, _ = await get_bots_client(chat_id, context.user_data) all_databases = await fetch_archived_databases(client) # Filter to only healthy databases databases = await get_healthy_databases(client, all_databases) @@ -586,7 +586,7 @@ async def show_timeline_chart(update: Update, context: ContextTypes.DEFAULT_TYPE ) return - client = await get_bots_client(chat_id) + client, _ = await get_bots_client(chat_id, context.user_data) bots_data = [] # Import chart functions @@ -671,7 +671,7 @@ async def show_bot_chart(update: Update, context: ContextTypes.DEFAULT_TYPE, db_ parse_mode="MarkdownV2" ) - client = await get_bots_client(chat_id) + client, _ = await get_bots_client(chat_id, context.user_data) # Fetch summary and ALL trades summary = await fetch_database_summary(client, db_path) @@ -754,7 +754,7 @@ async def handle_generate_report(update: Update, context: ContextTypes.DEFAULT_T # Import report generation from .archived_report import save_full_report - client = await get_bots_client(chat_id) + client, _ = await get_bots_client(chat_id, context.user_data) json_path, png_path = await save_full_report(client, db_path) # Update message with success diff --git a/handlers/bots/archived_chart.py b/handlers/bots/archived_chart.py index aec5608..ec0a8c8 100644 --- a/handlers/bots/archived_chart.py +++ b/handlers/bots/archived_chart.py @@ -94,9 +94,9 @@ def calculate_pnl_from_trades(trades: List[Dict[str, Any]]) -> Dict[str, Any]: """ Calculate realized PnL from a list of trades using position tracking. - For perpetual futures: - - OPEN trades establish positions (long or short) - - CLOSE trades realize PnL + Supports two modes: + 1. Perpetual futures: Uses OPEN/CLOSE position tracking + 2. Spot/Market Making (NIL positions): Uses average cost basis inventory tracking Args: trades: List of trade dicts with timestamp, trading_pair, trade_type, @@ -119,6 +119,130 @@ def calculate_pnl_from_trades(trades: List[Dict[str, Any]]) -> Dict[str, Any]: "total_volume": 0, } + # Detect if this is OPEN/CLOSE mode or NIL mode (market making) + position_types = set(t.get("position", "").upper() for t in trades) + has_open_close = "OPEN" in position_types or "CLOSE" in position_types + is_nil_mode = "NIL" in position_types and not has_open_close + + if is_nil_mode: + return _calculate_pnl_average_cost(trades) + else: + return _calculate_pnl_open_close(trades) + + +def _calculate_pnl_average_cost(trades: List[Dict[str, Any]]) -> Dict[str, Any]: + """ + Calculate PnL using average cost basis for spot/market making trades. + + This handles NIL position trades where: + - BUY adds to inventory at that price + - SELL realizes PnL based on weighted average cost of inventory + """ + # Track inventory per trading pair using average cost + # inventory = {amount: float, total_cost: float} + inventory: Dict[str, Dict[str, float]] = {} + + pnl_by_pair: Dict[str, float] = defaultdict(float) + cumulative_pnl: List[Dict[str, Any]] = [] + running_pnl = 0.0 + total_fees = 0.0 + total_volume = 0.0 + + # Debug counters + buy_count = 0 + sell_count = 0 + realized_trades = 0 + + # Sort trades by timestamp + sorted_trades = sorted(trades, key=lambda t: t.get("timestamp", 0)) + + for trade in sorted_trades: + pair = trade.get("trading_pair", "Unknown") + amount = float(trade.get("amount", 0)) + price = float(trade.get("price", 0)) + trade_type = trade.get("trade_type", "").upper() + fee = float(trade.get("trade_fee_in_quote", 0)) + timestamp = trade.get("timestamp", 0) + + total_fees += fee + total_volume += amount * price + + # Parse timestamp for cumulative chart + ts = _parse_timestamp(timestamp) + + # Initialize inventory for this pair if needed + if pair not in inventory: + inventory[pair] = {"amount": 0.0, "total_cost": 0.0} + + inv = inventory[pair] + + if trade_type == "BUY": + buy_count += 1 + # Add to inventory at this price + inv["amount"] += amount + inv["total_cost"] += amount * price + + elif trade_type == "SELL": + sell_count += 1 + # Realize PnL if we have inventory + if inv["amount"] > 0: + realized_trades += 1 + # Calculate average cost of inventory + avg_cost = inv["total_cost"] / inv["amount"] if inv["amount"] > 0 else 0 + + # Determine how much we can actually sell from inventory + sell_amount = min(amount, inv["amount"]) + + # PnL = (sell_price - avg_cost) * amount - fee + pnl = (price - avg_cost) * sell_amount - fee + + pnl_by_pair[pair] += pnl + running_pnl += pnl + + # Reduce inventory + if sell_amount >= inv["amount"]: + # Fully depleted + inv["amount"] = 0.0 + inv["total_cost"] = 0.0 + else: + # Partially depleted - reduce proportionally + ratio = sell_amount / inv["amount"] + inv["amount"] -= sell_amount + inv["total_cost"] -= inv["total_cost"] * ratio + else: + # Short selling (no inventory) - track as negative PnL for now + # This means we're selling something we don't have (going short) + # For simplicity, just count fees + running_pnl -= fee + + # Record cumulative PnL point for charting + if ts: + cumulative_pnl.append({ + "timestamp": ts, + "pnl": running_pnl, + "pair": pair, + }) + + logger.info(f"PnL calculation (avg cost): {len(trades)} trades, {buy_count} BUY, {sell_count} SELL, " + f"{realized_trades} realized, total_pnl=${running_pnl:.4f}") + + return { + "total_pnl": running_pnl, + "total_fees": total_fees, + "pnl_by_pair": dict(pnl_by_pair), + "cumulative_pnl": cumulative_pnl, + "total_volume": total_volume, + } + + +def _calculate_pnl_open_close(trades: List[Dict[str, Any]]) -> Dict[str, Any]: + """ + Calculate PnL using OPEN/CLOSE position tracking for perpetual futures. + + For perpetual futures: + - OPEN trades establish positions (long or short) + - CLOSE trades realize PnL + """ # Track positions per trading pair # position = {amount: float, total_cost: float, direction: int (1=long, -1=short)} positions: Dict[str, Dict[str, Any]] = {} @@ -212,7 +336,7 @@ def calculate_pnl_from_trades(trades: List[Dict[str, Any]]) -> Dict[str, Any]: "pair": pair, }) - logger.info(f"PnL calculation: {len(trades)} trades, {open_count} OPEN, {close_count} CLOSE, " + logger.info(f"PnL calculation (open/close): {len(trades)} trades, {open_count} OPEN, {close_count} CLOSE, " f"{close_with_position} CLOSE with matching position, total_pnl=${running_pnl:.4f}") return { diff --git a/handlers/bots/controller_handlers.py b/handlers/bots/controller_handlers.py index 7145e5c..cc1e176 100644 --- a/handlers/bots/controller_handlers.py +++ b/handlers/bots/controller_handlers.py @@ -13,8 +13,9 @@ """ import asyncio +import copy import logging -from typing import Dict, Any, List, Optional +from typing import List from telegram import Update, InlineKeyboardButton, InlineKeyboardMarkup from telegram.error import BadRequest @@ -27,7 +28,6 @@ get_controller_config, set_controller_config, init_new_controller_config, - format_controller_config_summary, format_config_field_value, get_available_cex_connectors, fetch_current_price, @@ -35,11 +35,10 @@ calculate_auto_prices, generate_config_id, generate_candles_chart, - SUPPORTED_CONTROLLERS, GRID_STRIKE_DEFAULTS, GRID_STRIKE_FIELDS, GRID_STRIKE_FIELD_ORDER, - GS_WIZARD_STEPS, + GS_EDITABLE_FIELDS, SIDE_LONG, SIDE_SHORT, ORDER_TYPE_MARKET, @@ -47,18 +46,20 @@ ORDER_TYPE_LIMIT_MAKER, ORDER_TYPE_LABELS, ) +from .controllers.pmm_mister import ( + FIELDS as PMM_FIELDS, + FIELD_ORDER as PMM_FIELD_ORDER, +) from .controllers.grid_strike.grid_analysis import ( calculate_natr, - calculate_price_stats, suggest_grid_params, generate_theoretical_grid, - format_grid_summary, ) from handlers.cex._shared import ( - fetch_cex_balances, get_cex_balances, - fetch_trading_rules, get_trading_rules, + validate_trading_pair, + get_correct_pair_format, ) logger = logging.getLogger(__name__) @@ -152,7 +153,7 @@ async def show_controller_configs_menu(update: Update, context: ContextTypes.DEF chat_id = update.effective_chat.id try: - client = await get_bots_client(chat_id) + client, _ = await get_bots_client(chat_id, context.user_data) configs = await client.controllers.list_controller_configs() # Store all configs @@ -269,6 +270,7 @@ async def show_controller_configs_menu(update: Update, context: ContextTypes.DEF ]) keyboard.append([ + InlineKeyboardButton("📤 Upload", callback_data="bots:upload_config"), InlineKeyboardButton("⬅️ Back", callback_data="bots:main_menu"), ]) @@ -446,7 +448,7 @@ async def handle_cfg_delete_execute(update: Update, context: ContextTypes.DEFAUL ) # Delete each config - client = await get_bots_client(chat_id) + client, _ = await get_bots_client(chat_id, context.user_data) deleted = [] failed = [] @@ -548,23 +550,35 @@ async def handle_cfg_edit_loop(update: Update, context: ContextTypes.DEFAULT_TYP def _get_editable_config_fields(config: dict) -> dict: - """Extract editable fields from a controller config""" + """Extract editable fields from a controller config using centralized field definitions""" controller_type = config.get("controller_name", "grid_strike") tp_cfg = config.get("triple_barrier_config", {}) take_profit = tp_cfg.get("take_profit", 0.0001) if isinstance(tp_cfg, dict) else 0.0001 if "grid_strike" in controller_type: - return { - "start_price": config.get("start_price", 0), - "end_price": config.get("end_price", 0), - "limit_price": config.get("limit_price", 0), - "total_amount_quote": config.get("total_amount_quote", 0), - "max_open_orders": config.get("max_open_orders", 3), - "max_orders_per_batch": config.get("max_orders_per_batch", 1), - "min_spread_between_orders": config.get("min_spread_between_orders", 0.0001), - "activation_bounds": config.get("activation_bounds", 0.01), - "take_profit": take_profit, - } + # Use centralized GS_EDITABLE_FIELDS for consistency between wizard and edit views + result = {} + for field_name in GS_EDITABLE_FIELDS: + if field_name == "take_profit": + result[field_name] = take_profit + else: + default_val = GRID_STRIKE_DEFAULTS.get(field_name, "") + result[field_name] = config.get(field_name, default_val) + return result + elif "pmm" in controller_type: + # Use centralized PMM_FIELDS and PMM_FIELD_ORDER for consistency + # between config creation and editing + from .controllers.pmm_mister import DEFAULTS as PMM_DEFAULTS + result = {} + for field_name in PMM_FIELD_ORDER: + # Skip 'id' - it's shown in the header already + if field_name == "id": + continue + if field_name in PMM_FIELDS: + # Get value from config, fallback to PMM_DEFAULTS + default_val = PMM_DEFAULTS.get(field_name, "") + result[field_name] = config.get(field_name, default_val) + return result # Default fields for other controller types return { "total_amount_quote": config.get("total_amount_quote", 0), @@ -572,7 +586,7 @@ def _get_editable_config_fields(config: dict) -> dict: } -async def show_cfg_edit_form(update: Update, context: ContextTypes.DEFAULT_TYPE) -> None: +async def show_cfg_edit_form(update: Update, context: ContextTypes.DEFAULT_TYPE, status_msg: str = None) -> None: """Show edit form for current config in bulk edit format (key=value)""" query = update.callback_query @@ -605,10 +619,23 @@ async def show_cfg_edit_form(update: Update, context: ContextTypes.DEFAULT_TYPE) context.user_data["cfg_edit_chat_id"] = query.message.chat_id # Build message with key=value format - lines = [f"*Edit Config* \\({current_idx + 1}/{total}\\)", ""] + header = f"*Edit Config* \\({current_idx + 1}/{total}\\)" + if status_msg: + header += f" — {escape_markdown_v2(status_msg)}" + lines = [header, ""] lines.append(f"`{escape_markdown_v2(config_id)}`") lines.append("") + # Add context info for Grid Strike (connector, trading pair, side) + controller_type = config.get("controller_name", "") + if "grid_strike" in controller_type: + connector = config.get("connector_name", "") + pair = config.get("trading_pair", "") + side = config.get("side", SIDE_LONG) + side_str = "LONG" if side == SIDE_LONG else "SHORT" + lines.append(f"*{escape_markdown_v2(pair)}* {side_str} on {escape_markdown_v2(connector)}") + lines.append("") + # Build config text for display (each line copyable) for key, value in editable_fields.items(): lines.append(f"`{key}={value}`") @@ -627,6 +654,11 @@ async def show_cfg_edit_form(update: Update, context: ContextTypes.DEFAULT_TYPE) nav_row.append(InlineKeyboardButton("Next ▶️", callback_data="bots:cfg_edit_next")) keyboard.append(nav_row) + # Branch button row + keyboard.append([ + InlineKeyboardButton("🔀 Branch", callback_data="bots:cfg_branch"), + ]) + # Final row keyboard.append([ InlineKeyboardButton("💾 Save All & Exit", callback_data="bots:cfg_edit_save_all"), @@ -772,6 +804,7 @@ async def process_cfg_edit_input(update: Update, context: ContextTypes.DEFAULT_T return # Apply updates to config + old_config_id = config.get("id", "") for key, value in updates.items(): if key == "take_profit": if "triple_barrier_config" not in config: @@ -780,10 +813,36 @@ async def process_cfg_edit_input(update: Update, context: ContextTypes.DEFAULT_T else: config[key] = value - # Store modified config - config_id = config.get("id") + # Auto-update ID if connector_name or trading_pair changed + if "connector_name" in updates or "trading_pair" in updates: + # Extract sequence number from old ID + parts = old_config_id.split("_", 1) + seq_num = parts[0] if parts and parts[0].isdigit() else "001" + + # Determine controller type abbreviation + controller_name = config.get("controller_name", "") + if controller_name == "grid_strike": + type_abbrev = "gs" + elif controller_name == "pmm_mister": + type_abbrev = "pmm" + else: + type_abbrev = parts[1].split("_")[0] if len(parts) > 1 and "_" in parts[1] else "cfg" + + # Build new ID with current values + connector = config.get("connector_name", "unknown") + conn_clean = connector.replace("_perpetual", "").replace("_spot", "") + pair = config.get("trading_pair", "UNKNOWN").upper() + new_config_id = f"{seq_num}_{type_abbrev}_{conn_clean}_{pair}" + + config["id"] = new_config_id + else: + new_config_id = old_config_id + + # Store modified config (remove old key if ID changed) modified = context.user_data.get("cfg_edit_modified", {}) - modified[config_id] = config + if old_config_id != new_config_id and old_config_id in modified: + del modified[old_config_id] + modified[new_config_id] = config context.user_data["cfg_edit_modified"] = modified # Update in edit loop @@ -793,19 +852,73 @@ async def process_cfg_edit_input(update: Update, context: ContextTypes.DEFAULT_T configs_to_edit[current_idx] = config # Update editable fields for display - context.user_data["cfg_editable_fields"] = _get_editable_config_fields(config) + editable_fields = _get_editable_config_fields(config) + context.user_data["cfg_editable_fields"] = editable_fields - # Format updated fields - updated_lines = [f"`{escape_markdown_v2(k)}` \\= `{escape_markdown_v2(str(v))}`" for k, v in updates.items()] + # Try to delete the user's input message + try: + await update.message.delete() + except Exception: + pass - keyboard = [[InlineKeyboardButton("✅ Continue", callback_data="bots:cfg_edit_form")]] + # Rebuild the edit form with updated values + total = len(configs_to_edit) + config_id = config.get("id", "unknown") - await update.get_bot().send_message( - chat_id=chat_id, - text=f"✅ *Updated*\n\n" + "\n".join(updated_lines) + "\n\n_Tap to continue editing_", - parse_mode="MarkdownV2", - reply_markup=InlineKeyboardMarkup(keyboard) - ) + lines = [f"*Edit Config* \\({current_idx + 1}/{total}\\)", ""] + lines.append(f"`{escape_markdown_v2(config_id)}`") + lines.append("") + + # Add context info for Grid Strike (connector, trading pair, side) + controller_type = config.get("controller_name", "") + if "grid_strike" in controller_type: + connector = config.get("connector_name", "") + pair = config.get("trading_pair", "") + side = config.get("side", SIDE_LONG) + side_str = "LONG" if side == SIDE_LONG else "SHORT" + lines.append(f"*{escape_markdown_v2(pair)}* {side_str} on {escape_markdown_v2(connector)}") + lines.append("") + + for key, value in editable_fields.items(): + lines.append(f"`{key}={value}`") + lines.append("") + lines.append("✏️ _Send `key=value` to update_") + + # Build keyboard + keyboard = [] + nav_row = [] + if current_idx > 0: + nav_row.append(InlineKeyboardButton("◀️ Prev", callback_data="bots:cfg_edit_prev")) + nav_row.append(InlineKeyboardButton(f"💾 Save", callback_data="bots:cfg_edit_save")) + if current_idx < total - 1: + nav_row.append(InlineKeyboardButton("Next ▶️", callback_data="bots:cfg_edit_next")) + keyboard.append(nav_row) + keyboard.append([InlineKeyboardButton("🔀 Branch", callback_data="bots:cfg_branch")]) + keyboard.append([ + InlineKeyboardButton("💾 Save All & Exit", callback_data="bots:cfg_edit_save_all"), + InlineKeyboardButton("❌ Cancel", callback_data="bots:cfg_edit_cancel"), + ]) + + # Edit the original message + message_id = context.user_data.get("cfg_edit_message_id") + if message_id: + try: + await update.get_bot().edit_message_text( + chat_id=chat_id, + message_id=message_id, + text="\n".join(lines), + parse_mode="MarkdownV2", + reply_markup=InlineKeyboardMarkup(keyboard) + ) + except Exception: + # If edit fails, send a new message + msg = await update.get_bot().send_message( + chat_id=chat_id, + text="\n".join(lines), + parse_mode="MarkdownV2", + reply_markup=InlineKeyboardMarkup(keyboard) + ) + context.user_data["cfg_edit_message_id"] = msg.message_id async def handle_cfg_edit_prev(update: Update, context: ContextTypes.DEFAULT_TYPE) -> None: @@ -838,15 +951,18 @@ async def handle_cfg_edit_save(update: Update, context: ContextTypes.DEFAULT_TYP config_id = config.get("id") try: - client = await get_bots_client(chat_id) + client, _ = await get_bots_client(chat_id, context.user_data) await client.controllers.create_or_update_controller_config(config_id, config) - await query.answer(f"✅ Saved {config_id[:20]}") + await query.answer() # Remove from modified since it's now saved modified = context.user_data.get("cfg_edit_modified", {}) modified.pop(config_id, None) context.user_data["cfg_edit_modified"] = modified + # Refresh form with saved status + await show_cfg_edit_form(update, context, status_msg="✅ Saved!") + except Exception as e: logger.error(f"Failed to save config {config_id}: {e}") await query.answer(f"❌ Save failed: {str(e)[:30]}", show_alert=True) @@ -874,7 +990,7 @@ async def handle_cfg_edit_save_all(update: Update, context: ContextTypes.DEFAULT parse_mode="MarkdownV2" ) - client = await get_bots_client(chat_id) + client, _ = await get_bots_client(chat_id, context.user_data) saved = [] failed = [] @@ -926,6 +1042,101 @@ async def handle_cfg_edit_cancel(update: Update, context: ContextTypes.DEFAULT_T await show_controller_configs_menu(update, context) +async def handle_cfg_branch(update: Update, context: ContextTypes.DEFAULT_TYPE) -> None: + """Branch (duplicate) the current config with a new ID""" + query = update.callback_query + chat_id = update.effective_chat.id + + configs_to_edit = context.user_data.get("cfg_edit_loop", []) + current_idx = context.user_data.get("cfg_edit_index", 0) + modified = context.user_data.get("cfg_edit_modified", {}) + + if not configs_to_edit or current_idx >= len(configs_to_edit): + await query.answer("No config to branch") + return + + # Get current config (use modified version if exists) + config = configs_to_edit[current_idx] + config_id = config.get("id", "unknown") + if config_id in modified: + config = modified[config_id] + + # Generate new ID by incrementing the sequence number + # Format: NNN_type_connector_pair -> increment NNN + old_id = config.get("id", "") + parts = old_id.split("_", 1) + + # Find highest sequence number across all configs from multiple sources + client, _ = await get_bots_client(chat_id, context.user_data) + + # Source 1: Fresh list from API + try: + api_configs = await client.controllers.list_controller_configs() + except Exception: + api_configs = [] + + # Source 2: Cached list in user_data (may have configs not yet saved) + cached_configs = context.user_data.get("controller_configs_list", []) + + max_num = 0 + + # Check all sources for highest sequence number + all_config_sources = [api_configs, cached_configs, configs_to_edit] + for config_list in all_config_sources: + for cfg in config_list: + cfg_id = cfg.get("id", "") if isinstance(cfg, dict) else "" + cfg_parts = cfg_id.split("_", 1) + if cfg_parts and cfg_parts[0].isdigit(): + max_num = max(max_num, int(cfg_parts[0])) + + # Also check modified config IDs (keys) + for cfg_id in modified.keys(): + cfg_parts = cfg_id.split("_", 1) + if cfg_parts and cfg_parts[0].isdigit(): + max_num = max(max_num, int(cfg_parts[0])) + + # Create new ID based on current config values + new_num = str(max_num + 1).zfill(3) + + # Determine controller type abbreviation + controller_name = config.get("controller_name", "") + if controller_name == "grid_strike": + type_abbrev = "gs" + elif controller_name == "pmm_mister": + type_abbrev = "pmm" + else: + # Fallback: try to extract from old ID + if len(parts) > 1: + type_abbrev = parts[1].split("_")[0] if "_" in parts[1] else parts[1] + else: + type_abbrev = "cfg" + + # Get connector and trading pair from current config values + connector = config.get("connector_name", "unknown") + conn_clean = connector.replace("_perpetual", "").replace("_spot", "") + pair = config.get("trading_pair", "UNKNOWN").upper() + + new_id = f"{new_num}_{type_abbrev}_{conn_clean}_{pair}" + + # Deep copy the config with new ID + new_config = copy.deepcopy(config) + new_config["id"] = new_id + + # Add to edit loop right after current config + configs_to_edit.insert(current_idx + 1, new_config) + context.user_data["cfg_edit_loop"] = configs_to_edit + + # Mark as modified so it gets saved + modified[new_id] = new_config + context.user_data["cfg_edit_modified"] = modified + + # Navigate to the new config + context.user_data["cfg_edit_index"] = current_idx + 1 + + await query.answer(f"Branched to {new_id}") + await show_cfg_edit_form(update, context) + + async def handle_configs_page(update: Update, context: ContextTypes.DEFAULT_TYPE, page: int) -> None: """Handle pagination for controller configs menu (legacy, redirects to cfg_page)""" controller_type = context.user_data.get("configs_controller_type") @@ -953,9 +1164,19 @@ async def show_new_grid_strike_form(update: Update, context: ContextTypes.DEFAUL query = update.callback_query chat_id = update.effective_chat.id + # Clear any cached market data from previous wizard runs + # This prevents showing stale data when starting a new grid for a different pair + gs_keys_to_clear = [ + "gs_current_price", "gs_candles", "gs_candles_interval", + "gs_chart_interval", "gs_natr", "gs_trading_rules", + "gs_theoretical_grid", "gs_market_data_ready", "gs_market_data_error" + ] + for key in gs_keys_to_clear: + context.user_data.pop(key, None) + # Fetch existing configs for sequence numbering try: - client = await get_bots_client(chat_id) + client, _ = await get_bots_client(chat_id, context.user_data) configs = await client.controllers.list_controller_configs() context.user_data["controller_configs_list"] = configs except Exception as e: @@ -979,15 +1200,19 @@ async def _show_wizard_connector_step(update: Update, context: ContextTypes.DEFA config = get_controller_config(context) try: - client = await get_bots_client(chat_id) - cex_connectors = await get_available_cex_connectors(context.user_data, client) + client, server_name = await get_bots_client(chat_id, context.user_data) + cex_connectors = await get_available_cex_connectors(context.user_data, client, server_name=server_name) if not cex_connectors: - keyboard = [[InlineKeyboardButton("Back", callback_data="bots:main_menu")]] + keyboard = [ + [InlineKeyboardButton("🔑 Configure API Keys", callback_data="config_api_keys")], + [InlineKeyboardButton("« Back", callback_data="bots:main_menu")] + ] await query.message.edit_text( r"*Grid Strike \- New Config*" + "\n\n" - r"No CEX connectors configured\." + "\n" - r"Please configure exchange credentials first\.", + r"⚠️ No CEX connectors available\." + "\n\n" + r"You need to connect API keys for an exchange to deploy strategies\." + "\n" + r"Click below to configure your API keys\.", parse_mode="MarkdownV2", reply_markup=InlineKeyboardMarkup(keyboard) ) @@ -1009,9 +1234,12 @@ async def _show_wizard_connector_step(update: Update, context: ContextTypes.DEFA await query.message.edit_text( r"*📈 Grid Strike \- Step 1*" + "\n\n" r"🏦 *Select Connector*" + "\n\n" + r"Grid Strike automatically places a grid of buy or sell orders within a set price range\." + "\n" + r"[📖 Strategy Guide](https://hummingbot.org/blog/strategy-guide-grid-strike/)" + "\n\n" r"Choose the exchange for this grid \(spot or perpetual\):", parse_mode="MarkdownV2", - reply_markup=InlineKeyboardMarkup(keyboard) + reply_markup=InlineKeyboardMarkup(keyboard), + disable_web_page_preview=True ) except Exception as e: @@ -1043,6 +1271,14 @@ async def handle_gs_wizard_pair(update: Update, context: ContextTypes.DEFAULT_TY chat_id = update.effective_chat.id config = get_controller_config(context) + # Clear old market data if pair changed (prevents stale data) + old_pair = config.get("trading_pair", "") + if old_pair and old_pair.upper() != pair.upper(): + for key in ["gs_current_price", "gs_candles", "gs_candles_interval", + "gs_natr", "gs_trading_rules", "gs_theoretical_grid", + "gs_market_data_ready", "gs_market_data_error"]: + context.user_data.pop(key, None) + config["trading_pair"] = pair.upper() set_controller_config(context, config) @@ -1172,6 +1408,10 @@ async def _show_wizard_leverage_step(update: Update, context: ContextTypes.DEFAU pair = config.get("trading_pair", "") side = "📈 LONG" if config.get("side") == SIDE_LONG else "📉 SHORT" + # Enable text input for leverage + context.user_data["bots_state"] = "gs_wizard_input" + context.user_data["gs_wizard_step"] = "leverage" + keyboard = [ [ InlineKeyboardButton("1x", callback_data="bots:gs_leverage:1"), @@ -1193,7 +1433,8 @@ async def _show_wizard_leverage_step(update: Update, context: ContextTypes.DEFAU await query.message.edit_text( r"*📈 Grid Strike \- Step 4/6*" + "\n\n" f"🏦 `{escape_markdown_v2(connector)}` \\| 🔗 `{escape_markdown_v2(pair)}` \\| {side}" + "\n\n" - r"⚡ *Select Leverage*", + r"⚡ *Select Leverage*" + "\n" + r"_Or type a value \(e\.g\. 2, 3x\)_", parse_mode="MarkdownV2", reply_markup=InlineKeyboardMarkup(keyboard) ) @@ -1234,7 +1475,7 @@ async def _show_wizard_amount_step(update: Update, context: ContextTypes.DEFAULT # Fetch balances for the connector balance_text = "" try: - client = await get_bots_client(chat_id) + client, _ = await get_bots_client(chat_id, context.user_data) balances = await get_cex_balances( context.user_data, client, "master_account", ttl=30 ) @@ -1378,6 +1619,29 @@ async def handle_gs_wizard_amount(update: Update, context: ContextTypes.DEFAULT_ await _show_wizard_prices_step(update, context) +def _calculate_min_order_amount(current_price: float, trading_rules: dict, default: float = 6.0) -> float: + """ + Calculate minimum order amount based on trading rules. + + The minimum is the greater of: + - min_notional_size from trading rules + - current_price * min_order_size (min base amount) + - the provided default + + Returns the calculated minimum order amount in quote currency. + """ + min_notional = trading_rules.get("min_notional_size", 0) or 0 + min_order_size = trading_rules.get("min_order_size", 0) or 0 + + # Calculate min from base amount requirement + min_from_base = current_price * min_order_size if min_order_size > 0 else 0 + + # Take the maximum of all constraints + calculated_min = max(default, min_notional, min_from_base) + + return calculated_min + + async def _show_wizard_prices_step(update: Update, context: ContextTypes.DEFAULT_TYPE, interval: str = None) -> None: """Wizard Step 6: Grid Configuration with prices, TP, spread, and grid analysis""" query = update.callback_query @@ -1389,9 +1653,9 @@ async def _show_wizard_prices_step(update: Update, context: ContextTypes.DEFAULT side = config.get("side", SIDE_LONG) total_amount = config.get("total_amount_quote", 1000) - # Get current interval (default 1m for better NATR calculation) + # Get current interval (default 5m for better NATR calculation) if interval is None: - interval = context.user_data.get("gs_chart_interval", "1m") + interval = context.user_data.get("gs_chart_interval", "5m") context.user_data["gs_chart_interval"] = interval # Check if we have pre-cached data from background fetch @@ -1400,7 +1664,7 @@ async def _show_wizard_prices_step(update: Update, context: ContextTypes.DEFAULT try: # If no cached data or interval changed, fetch now - cached_interval = context.user_data.get("gs_candles_interval", "1m") + cached_interval = context.user_data.get("gs_candles_interval", "5m") need_refetch = interval != cached_interval if not current_price or need_refetch: @@ -1427,7 +1691,7 @@ async def _show_wizard_prices_step(update: Update, context: ContextTypes.DEFAULT ) context.user_data["gs_wizard_message_id"] = loading_msg.message_id - client = await get_bots_client(chat_id) + client, _ = await get_bots_client(chat_id, context.user_data) current_price = await fetch_current_price(client, connector, pair) if current_price: @@ -1451,7 +1715,7 @@ async def _show_wizard_prices_step(update: Update, context: ContextTypes.DEFAULT await query.message.edit_text( r"*❌ Error*" + "\n\n" f"Could not fetch price for `{escape_markdown_v2(pair)}`\\.\n" - r"Please check the trading pair and try again\\.", + r"Please check the trading pair and try again\.", parse_mode="MarkdownV2", reply_markup=InlineKeyboardMarkup(keyboard) ) @@ -1461,7 +1725,7 @@ async def _show_wizard_prices_step(update: Update, context: ContextTypes.DEFAULT text=( r"*❌ Error*" + "\n\n" f"Could not fetch price for `{escape_markdown_v2(pair)}`\\.\n" - r"Please check the trading pair and try again\\." + r"Please check the trading pair and try again\." ), parse_mode="MarkdownV2", reply_markup=InlineKeyboardMarkup(keyboard) @@ -1475,6 +1739,13 @@ async def _show_wizard_prices_step(update: Update, context: ContextTypes.DEFAULT if candles_list: natr = calculate_natr(candles_list, period=14) context.user_data["gs_natr"] = natr + # Use the last candle's close price for better chart alignment + last_candle = candles_list[-1] if candles_list else None + if last_candle: + last_close = last_candle.get("close") or last_candle.get("c") + if last_close: + current_price = float(last_close) + context.user_data["gs_current_price"] = current_price # Get trading rules trading_rules = context.user_data.get("gs_trading_rules", {}) @@ -1505,12 +1776,15 @@ async def _show_wizard_prices_step(update: Update, context: ContextTypes.DEFAULT limit = config.get("limit_price") min_spread = config.get("min_spread_between_orders", 0.0001) take_profit = config.get("triple_barrier_config", {}).get("take_profit", 0.0001) - min_order_amount = config.get("min_order_amount_quote", max(6, min_notional)) + + # Calculate minimum order amount from trading rules + required_min_order = _calculate_min_order_amount(current_price, trading_rules, default=6.0) + min_order_amount = config.get("min_order_amount_quote", required_min_order) # Ensure min_order_amount respects exchange rules - if min_notional > min_order_amount: - config["min_order_amount_quote"] = min_notional - min_order_amount = min_notional + if min_order_amount < required_min_order: + config["min_order_amount_quote"] = required_min_order + min_order_amount = required_min_order # Generate config ID with sequence number (if not already set) if not config.get("id"): @@ -1560,6 +1834,9 @@ async def _show_wizard_prices_step(update: Update, context: ContextTypes.DEFAULT max_open_orders = config.get("max_open_orders", 3) order_frequency = config.get("order_frequency", 3) leverage = config.get("leverage", 1) + position_mode = config.get("position_mode", "ONEWAY") + coerce_tp_to_step = config.get("coerce_tp_to_step", False) + activation_bounds = config.get("activation_bounds", 0.01) side_value = config.get("side", SIDE_LONG) side_str_label = "LONG" if side_value == SIDE_LONG else "SHORT" @@ -1577,18 +1854,23 @@ async def _show_wizard_prices_step(update: Update, context: ContextTypes.DEFAULT rf"*📈 Grid Strike \- Step {final_step}/{final_step} \(Final\)*" + "\n\n" f"*{escape_markdown_v2(pair)}* {side_str_label}\n" f"Price: `{current_price:,.6g}` \\| Range: `{range_pct}` \\| NATR: `{natr_pct}`\n\n" + f"`connector_name={connector}`\n" + f"`trading_pair={pair}`\n" f"`total_amount_quote={total_amount:.0f}`\n" f"`start_price={start:.6g}`\n" f"`end_price={end:.6g}`\n" f"`limit_price={limit:.6g}`\n" f"`leverage={leverage}`\n" + f"`position_mode={position_mode}`\n" f"`take_profit={take_profit}`\n" + f"`coerce_tp_to_step={str(coerce_tp_to_step).lower()}`\n" f"`min_spread_between_orders={min_spread}`\n" f"`min_order_amount_quote={min_order_amount:.0f}`\n" - f"`max_open_orders={max_open_orders}`\n\n" + f"`max_open_orders={max_open_orders}`\n" + f"`activation_bounds={activation_bounds}`\n\n" f"{grid_valid} Grid: `{grid['num_levels']}` levels " f"\\(↓{grid.get('levels_below_current', 0)} ↑{grid.get('levels_above_current', 0)}\\) " - f"@ `${grid['amount_per_level']:.2f}`/lvl" + f"@ `${grid['amount_per_level']:.2f}`/lvl \\| step: `{grid.get('spread_pct', 0):.3f}%`" ) # Add warnings if any @@ -1815,18 +2097,23 @@ async def _show_wizard_review_step(update: Update, context: ContextTypes.DEFAULT pair = config.get("trading_pair", "") side = "LONG" if config.get("side") == SIDE_LONG else "SHORT" leverage = config.get("leverage", 1) + position_mode = config.get("position_mode", "ONEWAY") amount = config.get("total_amount_quote", 0) start_price = config.get("start_price", 0) end_price = config.get("end_price", 0) limit_price = config.get("limit_price", 0) tp = config.get("triple_barrier_config", {}).get("take_profit", 0.0001) + open_order_type = config.get("triple_barrier_config", {}).get("open_order_type", ORDER_TYPE_LIMIT_MAKER) + tp_order_type = config.get("triple_barrier_config", {}).get("take_profit_order_type", ORDER_TYPE_LIMIT_MAKER) keep_position = config.get("keep_position", True) activation_bounds = config.get("activation_bounds", 0.01) config_id = config.get("id", "") max_open_orders = config.get("max_open_orders", 3) max_orders_per_batch = config.get("max_orders_per_batch", 1) + order_frequency = config.get("order_frequency", 3) min_order_amount = config.get("min_order_amount_quote", 6) min_spread = config.get("min_spread_between_orders", 0.0001) + coerce_tp_to_step = config.get("coerce_tp_to_step", False) # Delete previous chart if exists chart_msg_id = context.user_data.pop("gs_chart_message_id", None) @@ -1850,15 +2137,20 @@ async def _show_wizard_review_step(update: Update, context: ContextTypes.DEFAULT f"trading_pair: {pair}\n" f"side: {side_value}\n" f"leverage: {leverage}\n" + f"position_mode: {position_mode}\n" f"total_amount_quote: {amount:.0f}\n" f"start_price: {start_price:.6g}\n" f"end_price: {end_price:.6g}\n" f"limit_price: {limit_price:.6g}\n" f"take_profit: {tp}\n" + f"open_order_type: {open_order_type}\n" + f"take_profit_order_type: {tp_order_type}\n" + f"coerce_tp_to_step: {str(coerce_tp_to_step).lower()}\n" f"keep_position: {str(keep_position).lower()}\n" f"activation_bounds: {activation_bounds}\n" f"max_open_orders: {max_open_orders}\n" f"max_orders_per_batch: {max_orders_per_batch}\n" + f"order_frequency: {order_frequency}\n" f"min_order_amount_quote: {min_order_amount}\n" f"min_spread_between_orders: {min_spread}" ) @@ -1926,18 +2218,23 @@ async def _update_wizard_message_for_review(update: Update, context: ContextType pair = config.get("trading_pair", "") side = "LONG" if config.get("side") == SIDE_LONG else "SHORT" leverage = config.get("leverage", 1) + position_mode = config.get("position_mode", "ONEWAY") amount = config.get("total_amount_quote", 0) start_price = config.get("start_price", 0) end_price = config.get("end_price", 0) limit_price = config.get("limit_price", 0) tp = config.get("triple_barrier_config", {}).get("take_profit", 0.0001) + open_order_type = config.get("triple_barrier_config", {}).get("open_order_type", ORDER_TYPE_LIMIT_MAKER) + tp_order_type = config.get("triple_barrier_config", {}).get("take_profit_order_type", ORDER_TYPE_LIMIT_MAKER) keep_position = config.get("keep_position", True) activation_bounds = config.get("activation_bounds", 0.01) config_id = config.get("id", "") max_open_orders = config.get("max_open_orders", 3) max_orders_per_batch = config.get("max_orders_per_batch", 1) + order_frequency = config.get("order_frequency", 3) min_order_amount = config.get("min_order_amount_quote", 6) min_spread = config.get("min_spread_between_orders", 0.0001) + coerce_tp_to_step = config.get("coerce_tp_to_step", False) # Build copyable config block with real YAML field names side_value = config.get("side", SIDE_LONG) @@ -1947,15 +2244,20 @@ async def _update_wizard_message_for_review(update: Update, context: ContextType f"trading_pair: {pair}\n" f"side: {side_value}\n" f"leverage: {leverage}\n" + f"position_mode: {position_mode}\n" f"total_amount_quote: {amount:.0f}\n" f"start_price: {start_price:.6g}\n" f"end_price: {end_price:.6g}\n" f"limit_price: {limit_price:.6g}\n" f"take_profit: {tp}\n" + f"open_order_type: {open_order_type}\n" + f"take_profit_order_type: {tp_order_type}\n" + f"coerce_tp_to_step: {str(coerce_tp_to_step).lower()}\n" f"keep_position: {str(keep_position).lower()}\n" f"activation_bounds: {activation_bounds}\n" f"max_open_orders: {max_open_orders}\n" f"max_orders_per_batch: {max_orders_per_batch}\n" + f"order_frequency: {order_frequency}\n" f"min_order_amount_quote: {min_order_amount}\n" f"min_spread_between_orders: {min_spread}" ) @@ -1986,7 +2288,8 @@ async def _update_wizard_message_for_review(update: Update, context: ContextType reply_markup=InlineKeyboardMarkup(keyboard) ) except Exception as e: - logger.error(f"Error updating review message: {e}") + logger.error(f"Error updating review message: {e}", exc_info=True) + logger.debug(f"Message text was: {message_text[:500]}") async def handle_gs_edit_id(update: Update, context: ContextTypes.DEFAULT_TYPE) -> None: @@ -2171,6 +2474,11 @@ async def handle_gs_edit_min_amt(update: Update, context: ContextTypes.DEFAULT_T current = config.get("min_order_amount_quote", 6) + # Calculate minimum required from trading rules + current_price = context.user_data.get("gs_current_price", 0) + trading_rules = context.user_data.get("gs_trading_rules", {}) + required_min = _calculate_min_order_amount(current_price, trading_rules, default=6.0) + keyboard = [ [InlineKeyboardButton("Cancel", callback_data="bots:gs_review_back")], ] @@ -2183,8 +2491,9 @@ async def handle_gs_edit_min_amt(update: Update, context: ContextTypes.DEFAULT_T msg = await context.bot.send_message( chat_id=chat_id, text=r"*Edit Min Order Amount*" + "\n\n" - f"Current: `{current}`" + "\n\n" - r"Enter new value \(e\.g\. 6\):", + f"Current: `{current}`\n" + f"Minimum: `{required_min:.2f}` \\(from trading rules\\)" + "\n\n" + r"Enter new value:", parse_mode="MarkdownV2", reply_markup=InlineKeyboardMarkup(keyboard) ) @@ -2286,7 +2595,7 @@ async def handle_gs_save(update: Update, context: ContextTypes.DEFAULT_TYPE) -> ) try: - client = await get_bots_client(chat_id) + client, _ = await get_bots_client(chat_id, context.user_data) result = await client.controllers.create_or_update_controller_config(config_id, config) # Clean up wizard state @@ -2345,7 +2654,7 @@ async def _background_fetch_market_data(context, config: dict, chat_id: int = No return try: - client = await get_bots_client(chat_id) + client, _ = await get_bots_client(chat_id, context.user_data) # Fetch current price current_price = await fetch_current_price(client, connector, pair) @@ -2353,10 +2662,10 @@ async def _background_fetch_market_data(context, config: dict, chat_id: int = No if current_price: context.user_data["gs_current_price"] = current_price - # Fetch candles (1m, 420 records) - consistent with default interval - candles = await fetch_candles(client, connector, pair, interval="1m", max_records=420) + # Fetch candles (5m, 420 records) - consistent with default interval + candles = await fetch_candles(client, connector, pair, interval="5m", max_records=420) context.user_data["gs_candles"] = candles - context.user_data["gs_candles_interval"] = "1m" + context.user_data["gs_candles_interval"] = "5m" context.user_data["gs_market_data_ready"] = True logger.info(f"Background fetch complete for {pair}: price={current_price}") @@ -2374,7 +2683,10 @@ async def process_gs_wizard_input(update: Update, context: ContextTypes.DEFAULT_ chat_id = update.effective_chat.id config = get_controller_config(context) + logger.debug(f"GS wizard input: step={step}, input={user_input[:50]}") + if not step: + logger.warning("GS wizard input called but no step set") return try: @@ -2390,6 +2702,32 @@ async def process_gs_wizard_input(update: Update, context: ContextTypes.DEFAULT_ if "-" not in pair: pair = pair.replace("/", "-").replace("_", "-") + connector = config.get("connector_name", "") + + # Validate trading pair exists on the connector + client, _ = await get_bots_client(chat_id, context.user_data) + is_valid, error_msg, suggestions = await validate_trading_pair( + context.user_data, client, connector, pair + ) + + if not is_valid: + # Show error with suggestions + await _show_gs_pair_suggestions(update, context, pair, error_msg, suggestions, connector) + return + + # Get correctly formatted pair from trading rules + trading_rules = await get_trading_rules(context.user_data, client, connector) + correct_pair = get_correct_pair_format(trading_rules, pair) + pair = correct_pair if correct_pair else pair + + # Clear old market data if pair changed (prevents stale data) + old_pair = config.get("trading_pair", "") + if old_pair and old_pair.upper() != pair.upper(): + for key in ["gs_current_price", "gs_candles", "gs_candles_interval", + "gs_natr", "gs_trading_rules", "gs_theoretical_grid", + "gs_market_data_ready", "gs_market_data_error"]: + context.user_data.pop(key, None) + config["trading_pair"] = pair set_controller_config(context, config) @@ -2416,6 +2754,12 @@ async def process_gs_wizard_input(update: Update, context: ContextTypes.DEFAULT_ if "=" in input_stripped: # Parse field=value format changes_made = False + chart_affecting_change = False # Track if chart needs regeneration + warning_msg = None + # Fields that affect the chart visualization + chart_fields = {"start_price", "start", "end_price", "end", "limit_price", "limit", + "connector_name", "trading_pair"} + for line in input_stripped.split("\n"): line = line.strip() if not line or "=" not in line: @@ -2425,6 +2769,10 @@ async def process_gs_wizard_input(update: Update, context: ContextTypes.DEFAULT_ field = field.strip().lower() value = value.strip() + # Check if this field affects the chart + if field in chart_fields: + chart_affecting_change = True + # Map field names and set values if field in ("start_price", "start"): config["start_price"] = float(value) @@ -2450,7 +2798,16 @@ async def process_gs_wizard_input(update: Update, context: ContextTypes.DEFAULT_ config["min_spread_between_orders"] = val changes_made = True elif field in ("min_order_amount_quote", "min_order_amount", "min_order", "min"): - config["min_order_amount_quote"] = float(value.replace("$", "")) + new_min_amt = float(value.replace("$", "")) + # Validate against trading rules + current_price = context.user_data.get("gs_current_price", 0) + trading_rules = context.user_data.get("gs_trading_rules", {}) + required_min = _calculate_min_order_amount(current_price, trading_rules, default=6.0) + if new_min_amt < required_min: + config["min_order_amount_quote"] = required_min + warning_msg = f"Min order must be >= ${required_min:.2f}" + else: + config["min_order_amount_quote"] = new_min_amt changes_made = True elif field in ("total_amount_quote", "total_amount", "amount"): config["total_amount_quote"] = float(value) @@ -2476,10 +2833,22 @@ async def process_gs_wizard_input(update: Update, context: ContextTypes.DEFAULT_ val = val / 100 config["activation_bounds"] = val changes_made = True + elif field in ("coerce_tp_to_step", "coerce_tp", "coerce"): + # Boolean field - accept true/false/1/0/yes/no + val_lower = value.lower() + config["coerce_tp_to_step"] = val_lower in ("true", "1", "yes", "on") + changes_made = True + elif field == "position_mode": + config["position_mode"] = value.upper() + changes_made = True if changes_made: set_controller_config(context, config) - await _update_wizard_message_for_prices_after_edit(update, context) + # Only regenerate chart if price/pair fields changed + if chart_affecting_change: + await _update_wizard_message_for_prices_after_edit(update, context) + else: + await _update_wizard_caption_only(update, context, warning_msg=warning_msg) else: raise ValueError(f"Unknown field: {field}") @@ -2491,7 +2860,7 @@ async def process_gs_wizard_input(update: Update, context: ContextTypes.DEFAULT_ config["triple_barrier_config"] = GRID_STRIKE_DEFAULTS["triple_barrier_config"].copy() config["triple_barrier_config"]["take_profit"] = tp_decimal set_controller_config(context, config) - await _update_wizard_message_for_prices_after_edit(update, context) + await _update_wizard_caption_only(update, context) elif input_lower.startswith("spread:"): # Min spread in percentage (e.g., spread:0.05 = 0.05% = 0.0005) @@ -2499,14 +2868,23 @@ async def process_gs_wizard_input(update: Update, context: ContextTypes.DEFAULT_ spread_decimal = spread_pct / 100 config["min_spread_between_orders"] = spread_decimal set_controller_config(context, config) - await _update_wizard_message_for_prices_after_edit(update, context) + await _update_wizard_caption_only(update, context) elif input_lower.startswith("min:"): # Min order amount in quote (e.g., min:10 = $10) min_amt = float(input_lower.replace("min:", "").replace("$", "").strip()) - config["min_order_amount_quote"] = min_amt + # Validate against trading rules + current_price = context.user_data.get("gs_current_price", 0) + trading_rules = context.user_data.get("gs_trading_rules", {}) + required_min = _calculate_min_order_amount(current_price, trading_rules, default=6.0) + warning_msg = None + if min_amt < required_min: + config["min_order_amount_quote"] = required_min + warning_msg = f"Min order must be >= ${required_min:.2f}" + else: + config["min_order_amount_quote"] = min_amt set_controller_config(context, config) - await _update_wizard_message_for_prices_after_edit(update, context) + await _update_wizard_caption_only(update, context, warning_msg=warning_msg) else: # Parse comma-separated prices: start,end,limit @@ -2549,6 +2927,21 @@ async def process_gs_wizard_input(update: Update, context: ContextTypes.DEFAULT_ context.user_data["gs_wizard_step"] = "prices" await _update_wizard_message_for_prices(update, context) + elif step == "leverage": + # Parse leverage - handle formats like "2", "2x", "2X", "10x" + lev_input = user_input.strip().lower().replace("x", "") + leverage = int(float(lev_input)) + + if leverage < 1: + raise ValueError("Leverage must be at least 1") + + config["leverage"] = leverage + set_controller_config(context, config) + + # Move to amount step + context.user_data["gs_wizard_step"] = "total_amount_quote" + await _update_wizard_message_for_amount(update, context) + elif step == "edit_id": new_id = user_input.strip() config["id"] = new_id @@ -2601,7 +2994,24 @@ async def process_gs_wizard_input(update: Update, context: ContextTypes.DEFAULT_ await _update_wizard_message_for_review(update, context) elif step == "edit_min_amt": - config["min_order_amount_quote"] = float(user_input) + new_min_amt = float(user_input) + # Validate against trading rules + current_price = context.user_data.get("gs_current_price", 0) + trading_rules = context.user_data.get("gs_trading_rules", {}) + required_min = _calculate_min_order_amount(current_price, trading_rules, default=6.0) + if new_min_amt < required_min: + config["min_order_amount_quote"] = required_min + # Send warning message + warn_msg = await update.message.reply_text( + f"Min order must be >= ${required_min:.2f}. Set to ${required_min:.2f}." + ) + await asyncio.sleep(3) + try: + await warn_msg.delete() + except: + pass + else: + config["min_order_amount_quote"] = new_min_amt set_controller_config(context, config) context.user_data["gs_wizard_step"] = "review" await _update_wizard_message_for_review(update, context) @@ -2621,15 +3031,20 @@ async def process_gs_wizard_input(update: Update, context: ContextTypes.DEFAULT_ "trading_pair": "trading_pair", "side": "side", "leverage": "leverage", + "position_mode": "position_mode", "total_amount_quote": "total_amount_quote", "start_price": "start_price", "end_price": "end_price", "limit_price": "limit_price", "take_profit": "triple_barrier_config.take_profit", + "open_order_type": "triple_barrier_config.open_order_type", + "take_profit_order_type": "triple_barrier_config.take_profit_order_type", + "coerce_tp_to_step": "coerce_tp_to_step", "keep_position": "keep_position", "activation_bounds": "activation_bounds", "max_open_orders": "max_open_orders", "max_orders_per_batch": "max_orders_per_batch", + "order_frequency": "order_frequency", "min_order_amount_quote": "min_order_amount_quote", "min_spread_between_orders": "min_spread_between_orders", } @@ -2660,13 +3075,26 @@ async def process_gs_wizard_input(update: Update, context: ContextTypes.DEFAULT_ config["side"] = SIDE_LONG else: config["side"] = SIDE_SHORT + elif key == "position_mode": + # Accept HEDGE or ONEWAY (case insensitive) + config["position_mode"] = "ONEWAY" if value.upper() == "ONEWAY" else "HEDGE" elif key == "keep_position": config["keep_position"] = value.lower() in ("true", "yes", "y", "1") + elif key == "coerce_tp_to_step": + config["coerce_tp_to_step"] = value.lower() in ("true", "yes", "y", "1") elif key == "take_profit": if "triple_barrier_config" not in config: config["triple_barrier_config"] = GRID_STRIKE_DEFAULTS["triple_barrier_config"].copy() config["triple_barrier_config"]["take_profit"] = float(value) - elif field in ["leverage", "max_open_orders", "max_orders_per_batch"]: + elif key == "open_order_type": + if "triple_barrier_config" not in config: + config["triple_barrier_config"] = GRID_STRIKE_DEFAULTS["triple_barrier_config"].copy() + config["triple_barrier_config"]["open_order_type"] = int(value) + elif key == "take_profit_order_type": + if "triple_barrier_config" not in config: + config["triple_barrier_config"] = GRID_STRIKE_DEFAULTS["triple_barrier_config"].copy() + config["triple_barrier_config"]["take_profit_order_type"] = int(value) + elif field in ["leverage", "max_open_orders", "max_orders_per_batch", "order_frequency"]: config[field] = int(value) elif field in ["total_amount_quote", "start_price", "end_price", "limit_price", "activation_bounds", "min_order_amount_quote", "min_spread_between_orders"]: @@ -2682,8 +3110,9 @@ async def process_gs_wizard_input(update: Update, context: ContextTypes.DEFAULT_ else: raise ValueError("No valid fields found") - except ValueError: + except ValueError as e: # Send error and let user try again + logger.warning(f"GS wizard input ValueError: {e}") error_msg = await update.message.reply_text( f"Invalid input. Please enter a valid value." ) @@ -2693,6 +3122,95 @@ async def process_gs_wizard_input(update: Update, context: ContextTypes.DEFAULT_ await error_msg.delete() except: pass + except Exception as e: + # Catch any other exceptions and log them + logger.error(f"GS wizard input error: {e}", exc_info=True) + try: + error_msg = await update.message.reply_text( + f"Error processing input: {str(e)[:100]}" + ) + await asyncio.sleep(3) + await error_msg.delete() + except: + pass + + +async def _show_gs_pair_suggestions( + update: Update, + context: ContextTypes.DEFAULT_TYPE, + input_pair: str, + error_msg: str, + suggestions: list, + connector: str +) -> None: + """Show trading pair suggestions when validation fails in grid strike wizard""" + config = get_controller_config(context) + message_id = context.user_data.get("gs_wizard_message_id") + chat_id = context.user_data.get("gs_wizard_chat_id") + + # Build suggestion message + help_text = f"❌ *{escape_markdown_v2(error_msg)}*\n\n" + + if suggestions: + help_text += "💡 *Did you mean:*\n" + else: + help_text += "_No similar pairs found\\._\n" + + # Build keyboard with suggestions + keyboard = [] + for pair in suggestions: + keyboard.append([InlineKeyboardButton( + f"📈 {pair}", + callback_data=f"bots:gs_pair_select:{pair}" + )]) + + keyboard.append([ + InlineKeyboardButton("⬅️ Back", callback_data="bots:gs_back:connector"), + InlineKeyboardButton("❌ Cancel", callback_data="bots:main_menu") + ]) + reply_markup = InlineKeyboardMarkup(keyboard) + + if message_id and chat_id: + try: + await context.bot.edit_message_text( + chat_id=chat_id, + message_id=message_id, + text=help_text, + parse_mode="MarkdownV2", + reply_markup=reply_markup + ) + except Exception as e: + logger.debug(f"Could not update wizard message: {e}") + else: + await update.effective_chat.send_message( + help_text, + parse_mode="MarkdownV2", + reply_markup=reply_markup + ) + + +async def handle_gs_pair_select(update: Update, context: ContextTypes.DEFAULT_TYPE, trading_pair: str) -> None: + """Handle selection of a suggested trading pair in grid strike wizard""" + config = get_controller_config(context) + chat_id = update.effective_chat.id + + # Clear old market data + for key in ["gs_current_price", "gs_candles", "gs_candles_interval", + "gs_natr", "gs_trading_rules", "gs_theoretical_grid", + "gs_market_data_ready", "gs_market_data_error"]: + context.user_data.pop(key, None) + + config["trading_pair"] = trading_pair + set_controller_config(context, config) + + # Start background fetch of market data + asyncio.create_task(_background_fetch_market_data(context, config, chat_id)) + + # Move to side step + context.user_data["gs_wizard_step"] = "side" + + # Update the wizard message + await _update_wizard_message_for_side(update, context) async def _update_wizard_message_for_side(update: Update, context: ContextTypes.DEFAULT_TYPE) -> None: @@ -2776,6 +3294,47 @@ async def delete(self): await _show_wizard_prices_step(fake_update, context) +async def _update_wizard_message_for_amount(update: Update, context: ContextTypes.DEFAULT_TYPE) -> None: + """Trigger amount step after leverage input""" + message_id = context.user_data.get("gs_wizard_message_id") + chat_id = context.user_data.get("gs_wizard_chat_id") + + if not message_id or not chat_id: + return + + # Create a fake query object to reuse _show_wizard_amount_step + class FakeChat: + def __init__(self, chat_id): + self.id = chat_id + + class FakeQuery: + def __init__(self, bot, chat_id, message_id): + self.message = FakeMessage(bot, chat_id, message_id) + + class FakeMessage: + def __init__(self, bot, chat_id, message_id): + self.chat_id = chat_id + self.message_id = message_id + self._bot = bot + + async def edit_text(self, text, **kwargs): + await self._bot.edit_message_text( + chat_id=self.chat_id, + message_id=self.message_id, + text=text, + **kwargs + ) + + async def delete(self): + await self._bot.delete_message(chat_id=self.chat_id, message_id=self.message_id) + + fake_update = type('FakeUpdate', (), { + 'callback_query': FakeQuery(context.bot, chat_id, message_id), + 'effective_chat': FakeChat(chat_id) + })() + await _show_wizard_amount_step(fake_update, context) + + async def _update_wizard_message_for_prices_after_edit(update: Update, context: ContextTypes.DEFAULT_TYPE) -> None: """Update prices display after editing prices - regenerate chart with new prices and grid analysis""" config = get_controller_config(context) @@ -2794,7 +3353,7 @@ async def _update_wizard_message_for_prices_after_edit(update: Update, context: limit = config.get("limit_price", 0) current_price = context.user_data.get("gs_current_price", 0) candles = context.user_data.get("gs_candles") - interval = context.user_data.get("gs_chart_interval", "1m") + interval = context.user_data.get("gs_chart_interval", "5m") total_amount = config.get("total_amount_quote", 1000) min_spread = config.get("min_spread_between_orders", 0.0001) take_profit = config.get("triple_barrier_config", {}).get("take_profit", 0.0001) @@ -2834,6 +3393,9 @@ async def _update_wizard_message_for_prices_after_edit(update: Update, context: max_open_orders = config.get("max_open_orders", 3) order_frequency = config.get("order_frequency", 3) leverage = config.get("leverage", 1) + position_mode = config.get("position_mode", "ONEWAY") + coerce_tp_to_step = config.get("coerce_tp_to_step", False) + activation_bounds = config.get("activation_bounds", 0.01) side_value = config.get("side", SIDE_LONG) side_str = "LONG" if side_value == SIDE_LONG else "SHORT" @@ -2846,18 +3408,23 @@ async def _update_wizard_message_for_prices_after_edit(update: Update, context: config_text = ( f"*{escape_markdown_v2(pair)}* {side_str}\n" f"Price: `{current_price:,.6g}` \\| Range: `{range_pct}` \\| NATR: `{natr_pct}`\n\n" + f"`connector_name={connector}`\n" + f"`trading_pair={pair}`\n" f"`total_amount_quote={total_amount:.0f}`\n" f"`start_price={start:.6g}`\n" f"`end_price={end:.6g}`\n" f"`limit_price={limit:.6g}`\n" f"`leverage={leverage}`\n" + f"`position_mode={position_mode}`\n" f"`take_profit={take_profit}`\n" + f"`coerce_tp_to_step={str(coerce_tp_to_step).lower()}`\n" f"`min_spread_between_orders={min_spread}`\n" f"`min_order_amount_quote={min_order_amount:.0f}`\n" - f"`max_open_orders={max_open_orders}`\n\n" + f"`max_open_orders={max_open_orders}`\n" + f"`activation_bounds={activation_bounds}`\n\n" f"{grid_valid} Grid: `{grid['num_levels']}` levels " f"\\(↓{grid.get('levels_below_current', 0)} ↑{grid.get('levels_above_current', 0)}\\) " - f"@ `${grid['amount_per_level']:.2f}`/lvl" + f"@ `${grid['amount_per_level']:.2f}`/lvl \\| step: `{grid.get('spread_pct', 0):.3f}%`" ) # Add warnings if any @@ -2913,10 +3480,130 @@ async def _update_wizard_message_for_prices_after_edit(update: Update, context: logger.error(f"Error updating prices message: {e}", exc_info=True) -async def handle_gs_edit_price(update: Update, context: ContextTypes.DEFAULT_TYPE, price_type: str) -> None: - """Handle price editing request""" - query = update.callback_query - config = get_controller_config(context) +async def _update_wizard_caption_only(update: Update, context: ContextTypes.DEFAULT_TYPE, warning_msg: str = None) -> None: + """ + Update only the caption of the chart message without regenerating the chart. + + Use this when changing fields that don't affect the visual representation + (e.g., min_order_amount, take_profit, activation_bounds, etc.) + """ + config = get_controller_config(context) + message_id = context.user_data.get("gs_wizard_message_id") + chat_id = context.user_data.get("gs_wizard_chat_id") + + if not message_id or not chat_id: + return + + connector = config.get("connector_name", "") + pair = config.get("trading_pair", "") + side = config.get("side", SIDE_LONG) + side_value = config.get("side", SIDE_LONG) + side_str = "LONG" if side_value == SIDE_LONG else "SHORT" + start = config.get("start_price", 0) + end = config.get("end_price", 0) + limit = config.get("limit_price", 0) + current_price = context.user_data.get("gs_current_price", 0) + interval = context.user_data.get("gs_chart_interval", "5m") + total_amount = config.get("total_amount_quote", 1000) + min_spread = config.get("min_spread_between_orders", 0.0001) + take_profit = config.get("triple_barrier_config", {}).get("take_profit", 0.0001) + min_order_amount = config.get("min_order_amount_quote", 6) + natr = context.user_data.get("gs_natr") + trading_rules = context.user_data.get("gs_trading_rules", {}) + + # Get config values + max_open_orders = config.get("max_open_orders", 3) + leverage = config.get("leverage", 1) + position_mode = config.get("position_mode", "ONEWAY") + coerce_tp_to_step = config.get("coerce_tp_to_step", False) + activation_bounds = config.get("activation_bounds", 0.01) + + # Regenerate theoretical grid with updated parameters + grid = generate_theoretical_grid( + start_price=start, + end_price=end, + min_spread=min_spread, + total_amount=total_amount, + min_order_amount=min_order_amount, + current_price=current_price, + side=side, + trading_rules=trading_rules, + ) + context.user_data["gs_theoretical_grid"] = grid + + # Build interval buttons with current one highlighted + interval_options = ["1m", "5m", "15m", "1h", "4h"] + interval_row = [] + for opt in interval_options: + label = f"✓ {opt}" if opt == interval else opt + interval_row.append(InlineKeyboardButton(label, callback_data=f"bots:gs_interval:{opt}")) + + keyboard = [ + interval_row, + [ + InlineKeyboardButton("💾 Save Config", callback_data="bots:gs_save"), + ], + [InlineKeyboardButton("❌ Cancel", callback_data="bots:main_menu")], + ] + + # Grid analysis info + grid_valid = "✓" if grid.get("valid") else "⚠️" + natr_pct = f"{natr*100:.2f}%" if natr else "N/A" + range_pct = f"{grid.get('grid_range_pct', 0):.2f}%" + + # Build config text with individually copyable key=value params + config_text = ( + f"*{escape_markdown_v2(pair)}* {side_str}\n" + f"Price: `{current_price:,.6g}` \\| Range: `{range_pct}` \\| NATR: `{natr_pct}`\n\n" + f"`connector_name={connector}`\n" + f"`trading_pair={pair}`\n" + f"`total_amount_quote={total_amount:.0f}`\n" + f"`start_price={start:.6g}`\n" + f"`end_price={end:.6g}`\n" + f"`limit_price={limit:.6g}`\n" + f"`leverage={leverage}`\n" + f"`position_mode={position_mode}`\n" + f"`take_profit={take_profit}`\n" + f"`coerce_tp_to_step={str(coerce_tp_to_step).lower()}`\n" + f"`min_spread_between_orders={min_spread}`\n" + f"`min_order_amount_quote={min_order_amount:.0f}`\n" + f"`max_open_orders={max_open_orders}`\n" + f"`activation_bounds={activation_bounds}`\n\n" + f"{grid_valid} Grid: `{grid['num_levels']}` levels " + f"\\(↓{grid.get('levels_below_current', 0)} ↑{grid.get('levels_above_current', 0)}\\) " + f"@ `${grid['amount_per_level']:.2f}`/lvl \\| step: `{grid.get('spread_pct', 0):.3f}%`" + ) + + # Add warnings if any + if grid.get("warnings"): + warnings_text = "\n".join(f"⚠️ {escape_markdown_v2(w)}" for w in grid["warnings"]) + config_text += f"\n{warnings_text}" + + # Add user warning message if provided + if warning_msg: + config_text += f"\n\n⚠️ {escape_markdown_v2(warning_msg)}" + + config_text += "\n\n_Edit: `field=value`_" + + try: + # Try to edit the caption of the existing photo message + await context.bot.edit_message_caption( + chat_id=chat_id, + message_id=message_id, + caption=config_text, + parse_mode="MarkdownV2", + reply_markup=InlineKeyboardMarkup(keyboard) + ) + except Exception as e: + # If editing caption fails (e.g., message is text not photo), fall back to full update + logger.warning(f"Caption edit failed, falling back to full update: {e}") + await _update_wizard_message_for_prices_after_edit(update, context) + + +async def handle_gs_edit_price(update: Update, context: ContextTypes.DEFAULT_TYPE, price_type: str) -> None: + """Handle price editing request""" + query = update.callback_query + config = get_controller_config(context) price_map = { "start": ("start_price", "Start Price"), @@ -3019,31 +3706,32 @@ async def show_config_form(update: Update, context: ContextTypes.DEFAULT_TYPE) - InlineKeyboardButton("Pair", callback_data="bots:set_field:trading_pair"), ]) - # Row 2: Side and Leverage + # Row 2: Side, Leverage, Position Mode keyboard.append([ InlineKeyboardButton("Side", callback_data="bots:toggle_side"), InlineKeyboardButton("Leverage", callback_data="bots:set_field:leverage"), - InlineKeyboardButton("Amount", callback_data="bots:set_field:total_amount_quote"), + InlineKeyboardButton("Pos Mode", callback_data="bots:toggle_position_mode"), ]) - # Row 3: Prices + # Row 3: Amount and Prices keyboard.append([ - InlineKeyboardButton("Start Price", callback_data="bots:set_field:start_price"), - InlineKeyboardButton("End Price", callback_data="bots:set_field:end_price"), - InlineKeyboardButton("Limit Price", callback_data="bots:set_field:limit_price"), + InlineKeyboardButton("Amount", callback_data="bots:set_field:total_amount_quote"), + InlineKeyboardButton("Start", callback_data="bots:set_field:start_price"), + InlineKeyboardButton("End", callback_data="bots:set_field:end_price"), ]) - # Row 4: Advanced + # Row 4: Limit Price and Order Settings keyboard.append([ + InlineKeyboardButton("Limit", callback_data="bots:set_field:limit_price"), InlineKeyboardButton("Max Orders", callback_data="bots:set_field:max_open_orders"), InlineKeyboardButton("Min Spread", callback_data="bots:set_field:min_spread_between_orders"), - InlineKeyboardButton("Take Profit", callback_data="bots:set_field:take_profit"), ]) - # Row 5: Order Types + # Row 5: Take Profit and Order Types keyboard.append([ - InlineKeyboardButton("Open Order Type", callback_data="bots:cycle_order_type:open"), - InlineKeyboardButton("TP Order Type", callback_data="bots:cycle_order_type:tp"), + InlineKeyboardButton("Take Profit", callback_data="bots:set_field:take_profit"), + InlineKeyboardButton("Open Type", callback_data="bots:cycle_order_type:open"), + InlineKeyboardButton("TP Type", callback_data="bots:cycle_order_type:tp"), ]) # Row 6: Actions @@ -3128,10 +3816,10 @@ async def show_connector_selector(update: Update, context: ContextTypes.DEFAULT_ chat_id = update.effective_chat.id try: - client = await get_bots_client(chat_id) + client, server_name = await get_bots_client(chat_id, context.user_data) # Get available CEX connectors (with cache) - cex_connectors = await get_available_cex_connectors(context.user_data, client) + cex_connectors = await get_available_cex_connectors(context.user_data, client, server_name=server_name) if not cex_connectors: await query.answer("No CEX connectors configured", show_alert=True) @@ -3208,7 +3896,7 @@ async def fetch_and_apply_market_data(update: Update, context: ContextTypes.DEFA return try: - client = await get_bots_client(chat_id) + client, _ = await get_bots_client(chat_id, context.user_data) # Show loading message await query.message.edit_text( @@ -3326,6 +4014,21 @@ async def handle_toggle_side(update: Update, context: ContextTypes.DEFAULT_TYPE) await show_config_form(update, context) +async def handle_toggle_position_mode(update: Update, context: ContextTypes.DEFAULT_TYPE) -> None: + """Toggle the position mode between HEDGE and ONEWAY""" + query = update.callback_query + config = get_controller_config(context) + + current_mode = config.get("position_mode", "ONEWAY") + new_mode = "ONEWAY" if current_mode == "HEDGE" else "HEDGE" + config["position_mode"] = new_mode + + set_controller_config(context, config) + + # Refresh the form + await show_config_form(update, context) + + async def handle_cycle_order_type(update: Update, context: ContextTypes.DEFAULT_TYPE, order_type_key: str) -> None: """Cycle the order type between Market, Limit, and Limit Maker @@ -3424,7 +4127,7 @@ async def process_field_input(update: Update, context: ContextTypes.DEFAULT_TYPE ) try: - client = await get_bots_client(chat_id) + client, _ = await get_bots_client(chat_id, context.user_data) connector = config.get("connector_name") pair = config.get("trading_pair") side = config.get("side", SIDE_LONG) @@ -3523,7 +4226,7 @@ async def handle_save_config(update: Update, context: ContextTypes.DEFAULT_TYPE) return try: - client = await get_bots_client(chat_id) + client, _ = await get_bots_client(chat_id, context.user_data) # Save to backend using config id as the config_name config_name = config.get("id", "") @@ -3646,7 +4349,7 @@ async def show_deploy_menu(update: Update, context: ContextTypes.DEFAULT_TYPE) - chat_id = update.effective_chat.id try: - client = await get_bots_client(chat_id) + client, _ = await get_bots_client(chat_id, context.user_data) configs = await client.controllers.list_controller_configs() if not configs: @@ -4205,7 +4908,7 @@ async def handle_execute_deploy(update: Update, context: ContextTypes.DEFAULT_TY ) try: - client = await get_bots_client(chat_id) + client, _ = await get_bots_client(chat_id, context.user_data) # Deploy using deploy_v2_controllers (this can take time) result = await client.bot_orchestration.deploy_v2_controllers( @@ -4339,9 +5042,11 @@ async def show_deploy_config_step(update: Update, context: ContextTypes.DEFAULT_ controllers_block, "```", "", - f"*Name:* `{escape_markdown_v2(instance_name)}`", - f"*Account:* `{escape_markdown_v2(creds)}`", - f"*Image:* `{escape_markdown_v2(image_short)}`", + r"*Configuration*", + "", + f" 📝 *Name:* `{escape_markdown_v2(instance_name)}`", + f" 👤 *Account:* `{escape_markdown_v2(creds)}`", + f" 🐳 *Image:* `{escape_markdown_v2(image_short)}`", "", r"_Tap buttons below to change settings_", ] @@ -4377,7 +5082,7 @@ async def handle_select_credentials(update: Update, context: ContextTypes.DEFAUL if creds == "_show": # Show available credentials profiles try: - client = await get_bots_client(chat_id) + client, _ = await get_bots_client(chat_id, context.user_data) available_creds = await _get_available_credentials(client) except Exception: available_creds = ["master_account"] @@ -4687,7 +5392,7 @@ async def process_deploy_custom_name_input(update: Update, context: ContextTypes logger.error(f"Error updating deploy message: {e}") try: - client = await get_bots_client(chat_id) + client, _ = await get_bots_client(chat_id, context.user_data) result = await client.bot_orchestration.deploy_v2_controllers( instance_name=custom_name, @@ -4766,12 +5471,8 @@ async def process_deploy_custom_name_input(update: Update, context: ContextTypes # ============================================ from .controllers.pmm_mister import ( - DEFAULTS as PMM_DEFAULTS, - WIZARD_STEPS as PMM_WIZARD_STEPS, validate_config as pmm_validate_config, generate_id as pmm_generate_id, - parse_spreads, - format_spreads, ) @@ -4781,7 +5482,7 @@ async def show_new_pmm_mister_form(update: Update, context: ContextTypes.DEFAULT chat_id = update.effective_chat.id try: - client = await get_bots_client(chat_id) + client, _ = await get_bots_client(chat_id, context.user_data) configs = await client.controllers.list_controller_configs() context.user_data["controller_configs_list"] = configs except Exception as e: @@ -4802,14 +5503,19 @@ async def _show_pmm_wizard_connector_step(update: Update, context: ContextTypes. chat_id = update.effective_chat.id try: - client = await get_bots_client(chat_id) - cex_connectors = await get_available_cex_connectors(context.user_data, client) + client, server_name = await get_bots_client(chat_id, context.user_data) + cex_connectors = await get_available_cex_connectors(context.user_data, client, server_name=server_name) if not cex_connectors: - keyboard = [[InlineKeyboardButton("Back", callback_data="bots:main_menu")]] + keyboard = [ + [InlineKeyboardButton("🔑 Configure API Keys", callback_data="config_api_keys")], + [InlineKeyboardButton("« Back", callback_data="bots:main_menu")] + ] await query.message.edit_text( r"*PMM Mister \- New Config*" + "\n\n" - r"No CEX connectors configured\.", + r"⚠️ No CEX connectors available\." + "\n\n" + r"You need to connect API keys for an exchange to deploy strategies\." + "\n" + r"Click below to configure your API keys\.", parse_mode="MarkdownV2", reply_markup=InlineKeyboardMarkup(keyboard) ) @@ -4828,7 +5534,7 @@ async def _show_pmm_wizard_connector_step(update: Update, context: ContextTypes. await query.message.edit_text( r"*📈 PMM Mister \- New Config*" + "\n\n" - r"*Step 1/7:* 🏦 Select Connector", + r"*Step 1/8:* 🏦 Select Connector", parse_mode="MarkdownV2", reply_markup=InlineKeyboardMarkup(keyboard) ) @@ -4881,12 +5587,15 @@ async def _show_pmm_wizard_pair_step(update: Update, context: ContextTypes.DEFAU row = [] if row: keyboard.append(row) - keyboard.append([InlineKeyboardButton("❌ Cancel", callback_data="bots:main_menu")]) + keyboard.append([ + InlineKeyboardButton("⬅️ Back", callback_data="bots:pmm_back:connector"), + InlineKeyboardButton("❌ Cancel", callback_data="bots:main_menu") + ]) await query.message.edit_text( r"*📈 PMM Mister \- New Config*" + "\n\n" f"*Connector:* `{escape_markdown_v2(connector)}`" + "\n\n" - r"*Step 2/7:* 🔗 Trading Pair" + "\n\n" + r"*Step 2/8:* 🔗 Trading Pair" + "\n\n" r"Select or type a pair:", parse_mode="MarkdownV2", reply_markup=InlineKeyboardMarkup(keyboard) @@ -4931,13 +5640,16 @@ async def _show_pmm_wizard_leverage_step(update: Update, context: ContextTypes.D InlineKeyboardButton("50x", callback_data="bots:pmm_leverage:50"), InlineKeyboardButton("75x", callback_data="bots:pmm_leverage:75"), ], - [InlineKeyboardButton("❌ Cancel", callback_data="bots:main_menu")], + [ + InlineKeyboardButton("⬅️ Back", callback_data="bots:pmm_back:pair"), + InlineKeyboardButton("❌ Cancel", callback_data="bots:main_menu") + ], ] await query.message.edit_text( r"*📈 PMM Mister \- New Config*" + "\n\n" f"🏦 `{escape_markdown_v2(connector)}` \\| 🔗 `{escape_markdown_v2(pair)}`" + "\n\n" - r"*Step 3/7:* ⚡ Leverage", + r"*Step 3/8:* ⚡ Leverage", parse_mode="MarkdownV2", reply_markup=InlineKeyboardMarkup(keyboard) ) @@ -4960,25 +5672,35 @@ async def _show_pmm_wizard_allocation_step(update: Update, context: ContextTypes pair = config.get("trading_pair", "") leverage = config.get("leverage", 20) + context.user_data["bots_state"] = "pmm_wizard_input" + context.user_data["pmm_wizard_step"] = "portfolio_allocation" + + # Back goes to leverage for perpetual, or pair for spot + back_target = "leverage" if connector.endswith("_perpetual") else "pair" + keyboard = [ [ InlineKeyboardButton("1%", callback_data="bots:pmm_alloc:0.01"), InlineKeyboardButton("2%", callback_data="bots:pmm_alloc:0.02"), - InlineKeyboardButton("5%", callback_data="bots:pmm_alloc:0.05"), + InlineKeyboardButton("3%", callback_data="bots:pmm_alloc:0.03"), ], [ + InlineKeyboardButton("5%", callback_data="bots:pmm_alloc:0.05"), InlineKeyboardButton("10%", callback_data="bots:pmm_alloc:0.1"), InlineKeyboardButton("20%", callback_data="bots:pmm_alloc:0.2"), - InlineKeyboardButton("50%", callback_data="bots:pmm_alloc:0.5"), ], - [InlineKeyboardButton("❌ Cancel", callback_data="bots:main_menu")], + [ + InlineKeyboardButton("⬅️ Back", callback_data=f"bots:pmm_back:{back_target}"), + InlineKeyboardButton("❌ Cancel", callback_data="bots:main_menu") + ], ] await query.message.edit_text( r"*📈 PMM Mister \- New Config*" + "\n\n" f"🏦 `{escape_markdown_v2(connector)}` \\| 🔗 `{escape_markdown_v2(pair)}`" + "\n" f"⚡ `{leverage}x`" + "\n\n" - r"*Step 4/7:* 💰 Portfolio Allocation", + r"*Step 4/8:* 💰 Portfolio Allocation" + "\n\n" + r"_Or type a custom value \(e\.g\. 3% or 0\.03\)_", parse_mode="MarkdownV2", reply_markup=InlineKeyboardMarkup(keyboard) ) @@ -4989,80 +5711,281 @@ async def handle_pmm_wizard_allocation(update: Update, context: ContextTypes.DEF config = get_controller_config(context) config["portfolio_allocation"] = allocation set_controller_config(context, config) - context.user_data["pmm_wizard_step"] = "spreads" - await _show_pmm_wizard_spreads_step(update, context) + context.user_data["pmm_wizard_step"] = "total_amount_quote" + await _show_pmm_wizard_amount_step(update, context) -async def _show_pmm_wizard_spreads_step(update: Update, context: ContextTypes.DEFAULT_TYPE) -> None: - """PMM Wizard Step 5: Spreads""" +async def _show_pmm_wizard_amount_step(update: Update, context: ContextTypes.DEFAULT_TYPE) -> None: + """PMM Wizard Step 5: Total Amount Quote""" query = update.callback_query + chat_id = update.effective_chat.id config = get_controller_config(context) + connector = config.get("connector_name", "") pair = config.get("trading_pair", "") leverage = config.get("leverage", 20) allocation = config.get("portfolio_allocation", 0.05) context.user_data["bots_state"] = "pmm_wizard_input" - context.user_data["pmm_wizard_step"] = "spreads" + context.user_data["pmm_wizard_step"] = "total_amount_quote" + + # Extract base and quote tokens from pair + base_token, quote_token = "", "" + if "-" in pair: + base_token, quote_token = pair.split("-", 1) + + # Fetch balances for the connector + balance_text = "" + try: + client, _ = await get_bots_client(chat_id, context.user_data) + balances = await get_cex_balances( + context.user_data, client, "master_account", ttl=30 + ) + + # Try to find connector balances with flexible matching + connector_balances = [] + connector_lower = connector.lower() + connector_base = connector_lower.replace("_perpetual", "").replace("_spot", "") + + for bal_connector, bal_list in balances.items(): + bal_lower = bal_connector.lower() + bal_base = bal_lower.replace("_perpetual", "").replace("_spot", "") + if bal_lower == connector_lower or bal_base == connector_base: + connector_balances = bal_list + break + + if connector_balances: + relevant_balances = [] + for bal in connector_balances: + token = bal.get("token", bal.get("asset", "")) + available = bal.get("units", bal.get("available_balance", bal.get("free", 0))) + value_usd = bal.get("value", 0) + if token and available: + try: + available_float = float(available) + if available_float > 0: + if token.upper() in [quote_token.upper(), base_token.upper()]: + relevant_balances.append((token, available_float, float(value_usd) if value_usd else None)) + except (ValueError, TypeError): + continue + + if relevant_balances: + bal_lines = [] + for token, available, value_usd in relevant_balances: + if available >= 1000: + amt_str = f"{available:,.0f}" + elif available >= 1: + amt_str = f"{available:,.2f}" + else: + amt_str = f"{available:,.6f}" + + if value_usd and value_usd >= 1: + bal_lines.append(f"{token}: {amt_str} (${value_usd:,.0f})") + else: + bal_lines.append(f"{token}: {amt_str}") + balance_text = "💼 *Available:* " + " \\| ".join( + escape_markdown_v2(b) for b in bal_lines + ) + "\n\n" + else: + balance_text = f"_No {escape_markdown_v2(quote_token)} balance on {escape_markdown_v2(connector)}_\n\n" + elif balances: + balance_text = f"_No {escape_markdown_v2(quote_token)} balance found_\n\n" + except Exception as e: + logger.warning(f"Could not fetch balances for PMM amount step: {e}") keyboard = [ - [InlineKeyboardButton("Tight: 0.02%, 0.1%", callback_data="bots:pmm_spreads:0.0002,0.001")], - [InlineKeyboardButton("Normal: 0.5%, 1%", callback_data="bots:pmm_spreads:0.005,0.01")], - [InlineKeyboardButton("Wide: 1%, 2%", callback_data="bots:pmm_spreads:0.01,0.02")], - [InlineKeyboardButton("❌ Cancel", callback_data="bots:main_menu")], + [ + InlineKeyboardButton("💵 100", callback_data="bots:pmm_amount:100"), + InlineKeyboardButton("💵 500", callback_data="bots:pmm_amount:500"), + InlineKeyboardButton("💵 1000", callback_data="bots:pmm_amount:1000"), + ], + [ + InlineKeyboardButton("💰 2000", callback_data="bots:pmm_amount:2000"), + InlineKeyboardButton("💰 5000", callback_data="bots:pmm_amount:5000"), + ], + [ + InlineKeyboardButton("⬅️ Back", callback_data="bots:pmm_back:allocation"), + InlineKeyboardButton("❌ Cancel", callback_data="bots:main_menu"), + ], ] - await query.message.edit_text( + message_text = ( r"*📈 PMM Mister \- New Config*" + "\n\n" f"🏦 `{escape_markdown_v2(connector)}` \\| 🔗 `{escape_markdown_v2(pair)}`" + "\n" - f"⚡ `{leverage}x` \\| 💰 `{allocation*100:.0f}%`" + "\n\n" - r"*Step 5/7:* 📊 Spreads" + "\n\n" - r"_Or type custom: `0\.01,0\.02`_", - parse_mode="MarkdownV2", - reply_markup=InlineKeyboardMarkup(keyboard) + f"⚡ `{leverage}x` \\| 💰 `{allocation*100:.1f}%`" + "\n\n" + + balance_text + + r"*Step 5/8:* 💵 Total Amount \(Quote\)" + "\n\n" + r"Select or type amount:" ) + try: + await query.message.edit_text( + message_text, + parse_mode="MarkdownV2", + reply_markup=InlineKeyboardMarkup(keyboard) + ) + except Exception: + pass + -async def handle_pmm_wizard_spreads(update: Update, context: ContextTypes.DEFAULT_TYPE, spreads: str) -> None: - """Handle spreads selection""" +async def handle_pmm_wizard_amount(update: Update, context: ContextTypes.DEFAULT_TYPE, amount: float) -> None: + """Handle amount selection in PMM wizard""" config = get_controller_config(context) - config["buy_spreads"] = spreads - config["sell_spreads"] = spreads + config["total_amount_quote"] = amount set_controller_config(context, config) - context.user_data["pmm_wizard_step"] = "take_profit" - await _show_pmm_wizard_tp_step(update, context) + context.user_data["pmm_wizard_step"] = "spreads" + await _show_pmm_wizard_spreads_step(update, context) -async def _show_pmm_wizard_tp_step(update: Update, context: ContextTypes.DEFAULT_TYPE) -> None: - """PMM Wizard Step 6: Take Profit""" +async def _show_pmm_wizard_spreads_step(update: Update, context: ContextTypes.DEFAULT_TYPE) -> None: + """PMM Wizard Step 6: Spreads""" query = update.callback_query config = get_controller_config(context) connector = config.get("connector_name", "") pair = config.get("trading_pair", "") leverage = config.get("leverage", 20) allocation = config.get("portfolio_allocation", 0.05) - spreads = config.get("buy_spreads", "0.0002,0.001") + + context.user_data["bots_state"] = "pmm_wizard_input" + context.user_data["pmm_wizard_step"] = "spreads" + + amount = config.get("total_amount_quote", 100) keyboard = [ + [InlineKeyboardButton("Tight: 0.02%, 0.1%", callback_data="bots:pmm_spreads:0.0002,0.001")], + [InlineKeyboardButton("Normal: 0.5%, 1%", callback_data="bots:pmm_spreads:0.005,0.01")], + [InlineKeyboardButton("Wide: 1%, 2%", callback_data="bots:pmm_spreads:0.01,0.02")], [ - InlineKeyboardButton("0.01%", callback_data="bots:pmm_tp:0.0001"), - InlineKeyboardButton("0.02%", callback_data="bots:pmm_tp:0.0002"), - InlineKeyboardButton("0.05%", callback_data="bots:pmm_tp:0.0005"), - ], - [ - InlineKeyboardButton("0.1%", callback_data="bots:pmm_tp:0.001"), - InlineKeyboardButton("0.2%", callback_data="bots:pmm_tp:0.002"), - InlineKeyboardButton("0.5%", callback_data="bots:pmm_tp:0.005"), + InlineKeyboardButton("⬅️ Back", callback_data="bots:pmm_back:amount"), + InlineKeyboardButton("❌ Cancel", callback_data="bots:main_menu") ], - [InlineKeyboardButton("❌ Cancel", callback_data="bots:main_menu")], ] await query.message.edit_text( r"*📈 PMM Mister \- New Config*" + "\n\n" f"🏦 `{escape_markdown_v2(connector)}` \\| 🔗 `{escape_markdown_v2(pair)}`" + "\n" - f"⚡ `{leverage}x` \\| 💰 `{allocation*100:.0f}%`" + "\n" + f"⚡ `{leverage}x` \\| 💰 `{allocation*100:.0f}%` \\| 💵 `{amount:,.0f}`" + "\n\n" + r"*Step 6/8:* 📊 Spreads" + "\n\n" + r"_Or type custom: `0\.01,0\.02`_", + parse_mode="MarkdownV2", + reply_markup=InlineKeyboardMarkup(keyboard) + ) + + +async def _show_pmm_wizard_spreads_step_msg(context: ContextTypes.DEFAULT_TYPE, chat_id: int, message_id: int, config: dict) -> None: + """Show spreads step via direct message edit (for text input flow)""" + connector = config.get("connector_name", "") + pair = config.get("trading_pair", "") + leverage = config.get("leverage", 20) + allocation = config.get("portfolio_allocation", 0.05) + amount = config.get("total_amount_quote", 100) + + context.user_data["bots_state"] = "pmm_wizard_input" + context.user_data["pmm_wizard_step"] = "spreads" + + keyboard = [ + [InlineKeyboardButton("Tight: 0.02%, 0.1%", callback_data="bots:pmm_spreads:0.0002,0.001")], + [InlineKeyboardButton("Normal: 0.5%, 1%", callback_data="bots:pmm_spreads:0.005,0.01")], + [InlineKeyboardButton("Wide: 1%, 2%", callback_data="bots:pmm_spreads:0.01,0.02")], + [ + InlineKeyboardButton("⬅️ Back", callback_data="bots:pmm_back:amount"), + InlineKeyboardButton("❌ Cancel", callback_data="bots:main_menu") + ], + ] + + await context.bot.edit_message_text( + chat_id=chat_id, message_id=message_id, + text=r"*📈 PMM Mister \- New Config*" + "\n\n" + f"🏦 `{escape_markdown_v2(connector)}` \\| 🔗 `{escape_markdown_v2(pair)}`" + "\n" + f"⚡ `{leverage}x` \\| 💰 `{allocation*100:.0f}%` \\| 💵 `{amount:,.0f}`" + "\n\n" + r"*Step 6/8:* 📊 Spreads" + "\n\n" + r"_Or type custom: `0\.01,0\.02`_", + parse_mode="MarkdownV2", + reply_markup=InlineKeyboardMarkup(keyboard) + ) + + +async def _show_pmm_wizard_amount_step_msg(context: ContextTypes.DEFAULT_TYPE, chat_id: int, message_id: int, config: dict) -> None: + """Show amount step via direct message edit (for text input flow)""" + connector = config.get("connector_name", "") + pair = config.get("trading_pair", "") + leverage = config.get("leverage", 20) + allocation = config.get("portfolio_allocation", 0.05) + + context.user_data["bots_state"] = "pmm_wizard_input" + context.user_data["pmm_wizard_step"] = "total_amount_quote" + + keyboard = [ + [ + InlineKeyboardButton("💵 100", callback_data="bots:pmm_amount:100"), + InlineKeyboardButton("💵 500", callback_data="bots:pmm_amount:500"), + InlineKeyboardButton("💵 1000", callback_data="bots:pmm_amount:1000"), + ], + [ + InlineKeyboardButton("💰 2000", callback_data="bots:pmm_amount:2000"), + InlineKeyboardButton("💰 5000", callback_data="bots:pmm_amount:5000"), + ], + [ + InlineKeyboardButton("⬅️ Back", callback_data="bots:pmm_back:allocation"), + InlineKeyboardButton("❌ Cancel", callback_data="bots:main_menu"), + ], + ] + + await context.bot.edit_message_text( + chat_id=chat_id, message_id=message_id, + text=r"*📈 PMM Mister \- New Config*" + "\n\n" + f"🏦 `{escape_markdown_v2(connector)}` \\| 🔗 `{escape_markdown_v2(pair)}`" + "\n" + f"⚡ `{leverage}x` \\| 💰 `{allocation*100:.1f}%`" + "\n\n" + r"*Step 5/8:* 💵 Total Amount \(Quote\)" + "\n\n" + r"Select or type amount:", + parse_mode="MarkdownV2", + reply_markup=InlineKeyboardMarkup(keyboard) + ) + + +async def handle_pmm_wizard_spreads(update: Update, context: ContextTypes.DEFAULT_TYPE, spreads: str) -> None: + """Handle spreads selection""" + config = get_controller_config(context) + config["buy_spreads"] = spreads + config["sell_spreads"] = spreads + set_controller_config(context, config) + context.user_data["pmm_wizard_step"] = "take_profit" + await _show_pmm_wizard_tp_step(update, context) + + +async def _show_pmm_wizard_tp_step(update: Update, context: ContextTypes.DEFAULT_TYPE) -> None: + """PMM Wizard Step 7: Take Profit""" + query = update.callback_query + config = get_controller_config(context) + connector = config.get("connector_name", "") + pair = config.get("trading_pair", "") + leverage = config.get("leverage", 20) + allocation = config.get("portfolio_allocation", 0.05) + spreads = config.get("buy_spreads", "0.0002,0.001") + + keyboard = [ + [ + InlineKeyboardButton("0.01%", callback_data="bots:pmm_tp:0.0001"), + InlineKeyboardButton("0.02%", callback_data="bots:pmm_tp:0.0002"), + InlineKeyboardButton("0.05%", callback_data="bots:pmm_tp:0.0005"), + ], + [ + InlineKeyboardButton("0.1%", callback_data="bots:pmm_tp:0.001"), + InlineKeyboardButton("0.2%", callback_data="bots:pmm_tp:0.002"), + InlineKeyboardButton("0.5%", callback_data="bots:pmm_tp:0.005"), + ], + [ + InlineKeyboardButton("⬅️ Back", callback_data="bots:pmm_back:spreads"), + InlineKeyboardButton("❌ Cancel", callback_data="bots:main_menu") + ], + ] + + await query.message.edit_text( + r"*📈 PMM Mister \- New Config*" + "\n\n" + f"🏦 `{escape_markdown_v2(connector)}` \\| 🔗 `{escape_markdown_v2(pair)}`" + "\n" + f"⚡ `{leverage}x` \\| 💰 `{allocation*100:.0f}%`" + "\n" f"📊 Spreads: `{escape_markdown_v2(spreads)}`" + "\n\n" - r"*Step 6/7:* 🎯 Take Profit", + r"*Step 7/8:* 🎯 Take Profit", parse_mode="MarkdownV2", reply_markup=InlineKeyboardMarkup(keyboard) ) @@ -5091,22 +6014,58 @@ async def _show_pmm_wizard_review_step(update: Update, context: ContextTypes.DEF context.user_data["bots_state"] = "pmm_wizard_input" context.user_data["pmm_wizard_step"] = "review" + # Format order types as string + tp_order_type = config.get('take_profit_order_type', "LIMIT_MAKER") + if isinstance(tp_order_type, int): + tp_order_type_str = ORDER_TYPE_LABELS.get(tp_order_type, "LIMIT_MAKER") + else: + tp_order_type_str = str(tp_order_type) + + open_order_type = config.get('open_order_type', "LIMIT") + if isinstance(open_order_type, int): + open_order_type_str = ORDER_TYPE_LABELS.get(open_order_type, "LIMIT") + else: + open_order_type_str = str(open_order_type) + + # Calculate amounts_pct based on spreads if not set + buy_spreads = config.get('buy_spreads', '0.0002,0.001') + sell_spreads = config.get('sell_spreads', '0.0002,0.001') + buy_amounts = config.get('buy_amounts_pct') + sell_amounts = config.get('sell_amounts_pct') + + if not buy_amounts: + num_buy_spreads = len(buy_spreads.split(',')) if buy_spreads else 1 + buy_amounts = ','.join(['1'] * num_buy_spreads) + if not sell_amounts: + num_sell_spreads = len(sell_spreads.split(',')) if sell_spreads else 1 + sell_amounts = ','.join(['1'] * num_sell_spreads) + # Build copyable config block config_block = ( f"id: {config.get('id', '')}\n" f"connector_name: {config.get('connector_name', '')}\n" f"trading_pair: {config.get('trading_pair', '')}\n" f"leverage: {config.get('leverage', 1)}\n" + f"position_mode: {config.get('position_mode', 'HEDGE')}\n" + f"total_amount_quote: {config.get('total_amount_quote', 100)}\n" f"portfolio_allocation: {config.get('portfolio_allocation', 0.05)}\n" - f"buy_spreads: {config.get('buy_spreads', '0.0002,0.001')}\n" - f"sell_spreads: {config.get('sell_spreads', '0.0002,0.001')}\n" + f"target_base_pct: {config.get('target_base_pct', 0.5)}\n" + f"min_base_pct: {config.get('min_base_pct', 0.4)}\n" + f"max_base_pct: {config.get('max_base_pct', 0.6)}\n" + f"buy_spreads: {buy_spreads}\n" + f"sell_spreads: {sell_spreads}\n" + f"buy_amounts_pct: {buy_amounts}\n" + f"sell_amounts_pct: {sell_amounts}\n" f"take_profit: {config.get('take_profit', 0.0001)}\n" - f"target_base_pct: {config.get('target_base_pct', 0.2)}\n" - f"min_base_pct: {config.get('min_base_pct', 0.1)}\n" - f"max_base_pct: {config.get('max_base_pct', 0.4)}\n" + f"take_profit_order_type: {tp_order_type_str}\n" + f"open_order_type: {open_order_type_str}\n" f"executor_refresh_time: {config.get('executor_refresh_time', 30)}\n" f"buy_cooldown_time: {config.get('buy_cooldown_time', 15)}\n" f"sell_cooldown_time: {config.get('sell_cooldown_time', 15)}\n" + f"buy_position_effectivization_time: {config.get('buy_position_effectivization_time', 3600)}\n" + f"sell_position_effectivization_time: {config.get('sell_position_effectivization_time', 3600)}\n" + f"min_buy_price_distance_pct: {config.get('min_buy_price_distance_pct', 0.003)}\n" + f"min_sell_price_distance_pct: {config.get('min_sell_price_distance_pct', 0.003)}\n" f"max_active_executors_by_level: {config.get('max_active_executors_by_level', 4)}" ) @@ -5114,14 +6073,15 @@ async def _show_pmm_wizard_review_step(update: Update, context: ContextTypes.DEF message_text = ( f"*{escape_markdown_v2(pair)}* \\- Review Config\n\n" f"```\n{config_block}\n```\n\n" - f"_To edit, send `field: value` lines:_\n" - f"`leverage: 20`\n" - f"`take_profit: 0.001`" + f"_To edit, send `field: value` lines_" ) keyboard = [ [InlineKeyboardButton("✅ Save Config", callback_data="bots:pmm_save")], - [InlineKeyboardButton("❌ Cancel", callback_data="bots:main_menu")], + [ + InlineKeyboardButton("⬅️ Back", callback_data="bots:pmm_back:tp"), + InlineKeyboardButton("❌ Cancel", callback_data="bots:main_menu") + ], ] await query.message.edit_text( @@ -5131,6 +6091,26 @@ async def _show_pmm_wizard_review_step(update: Update, context: ContextTypes.DEF ) +async def handle_pmm_back(update: Update, context: ContextTypes.DEFAULT_TYPE, target: str) -> None: + """Handle back navigation in PMM wizard""" + query = update.callback_query + + if target == "connector": + await _show_pmm_wizard_connector_step(update, context) + elif target == "pair": + await _show_pmm_wizard_pair_step(update, context) + elif target == "leverage": + await _show_pmm_wizard_leverage_step(update, context) + elif target == "allocation": + await _show_pmm_wizard_allocation_step(update, context) + elif target == "amount": + await _show_pmm_wizard_amount_step(update, context) + elif target == "spreads": + await _show_pmm_wizard_spreads_step(update, context) + elif target == "tp": + await _show_pmm_wizard_tp_step(update, context) + + async def handle_pmm_save(update: Update, context: ContextTypes.DEFAULT_TYPE) -> None: """Save PMM config""" query = update.callback_query @@ -5148,7 +6128,7 @@ async def handle_pmm_save(update: Update, context: ContextTypes.DEFAULT_TYPE) -> return try: - client = await get_bots_client(chat_id) + client, _ = await get_bots_client(chat_id, context.user_data) config_id = config.get("id", "") result = await client.controllers.create_or_update_controller_config(config_id, config) @@ -5242,19 +6222,19 @@ async def handle_pmm_edit_field(update: Update, context: ContextTypes.DEFAULT_TY [ InlineKeyboardButton("1%", callback_data="bots:pmm_set:allocation:0.01"), InlineKeyboardButton("2%", callback_data="bots:pmm_set:allocation:0.02"), - InlineKeyboardButton("5%", callback_data="bots:pmm_set:allocation:0.05"), + InlineKeyboardButton("3%", callback_data="bots:pmm_set:allocation:0.03"), ], [ + InlineKeyboardButton("5%", callback_data="bots:pmm_set:allocation:0.05"), InlineKeyboardButton("10%", callback_data="bots:pmm_set:allocation:0.1"), InlineKeyboardButton("20%", callback_data="bots:pmm_set:allocation:0.2"), - InlineKeyboardButton("50%", callback_data="bots:pmm_set:allocation:0.5"), ], [InlineKeyboardButton("❌ Cancel", callback_data="bots:pmm_review_back")], ] await query.message.edit_text( r"*Edit Portfolio Allocation*" + "\n\n" f"Current: `{config.get('portfolio_allocation', 0.05)*100:.0f}%`" + "\n\n" - r"_Or type a custom value \(e\.g\. 0\.15 for 15%\)_", + r"_Or type a custom value \(e\.g\. 3% or 0\.03\)_", parse_mode="MarkdownV2", reply_markup=InlineKeyboardMarkup(keyboard) ) @@ -5385,6 +6365,129 @@ async def handle_pmm_adv_setting(update: Update, context: ContextTypes.DEFAULT_T ) +async def _show_pmm_pair_suggestions( + update: Update, + context: ContextTypes.DEFAULT_TYPE, + input_pair: str, + error_msg: str, + suggestions: list, + connector: str +) -> None: + """Show trading pair suggestions when validation fails in PMM wizard""" + message_id = context.user_data.get("pmm_wizard_message_id") + chat_id = context.user_data.get("pmm_wizard_chat_id") + + # Build suggestion message + help_text = f"❌ *{escape_markdown_v2(error_msg)}*\n\n" + + if suggestions: + help_text += "💡 *Did you mean:*\n" + else: + help_text += "_No similar pairs found\\._\n" + + # Build keyboard with suggestions + keyboard = [] + for pair in suggestions: + keyboard.append([InlineKeyboardButton( + f"📈 {pair}", + callback_data=f"bots:pmm_pair_select:{pair}" + )]) + + keyboard.append([ + InlineKeyboardButton("⬅️ Back", callback_data="bots:pmm_back:connector"), + InlineKeyboardButton("❌ Cancel", callback_data="bots:main_menu") + ]) + reply_markup = InlineKeyboardMarkup(keyboard) + + if message_id and chat_id: + try: + await context.bot.edit_message_text( + chat_id=chat_id, + message_id=message_id, + text=help_text, + parse_mode="MarkdownV2", + reply_markup=reply_markup + ) + except Exception as e: + logger.debug(f"Could not update PMM wizard message: {e}") + else: + await update.effective_chat.send_message( + help_text, + parse_mode="MarkdownV2", + reply_markup=reply_markup + ) + + +async def handle_pmm_pair_select(update: Update, context: ContextTypes.DEFAULT_TYPE, trading_pair: str) -> None: + """Handle selection of a suggested trading pair in PMM wizard""" + config = get_controller_config(context) + message_id = context.user_data.get("pmm_wizard_message_id") + chat_id = context.user_data.get("pmm_wizard_chat_id") + + config["trading_pair"] = trading_pair + set_controller_config(context, config) + connector = config.get("connector_name", "") + + # Only ask for leverage on perpetual exchanges + if connector.endswith("_perpetual"): + context.user_data["pmm_wizard_step"] = "leverage" + keyboard = [ + [ + InlineKeyboardButton("1x", callback_data="bots:pmm_leverage:1"), + InlineKeyboardButton("5x", callback_data="bots:pmm_leverage:5"), + InlineKeyboardButton("10x", callback_data="bots:pmm_leverage:10"), + ], + [ + InlineKeyboardButton("20x", callback_data="bots:pmm_leverage:20"), + InlineKeyboardButton("50x", callback_data="bots:pmm_leverage:50"), + InlineKeyboardButton("75x", callback_data="bots:pmm_leverage:75"), + ], + [ + InlineKeyboardButton("⬅️ Back", callback_data="bots:pmm_back:pair"), + InlineKeyboardButton("❌ Cancel", callback_data="bots:main_menu") + ], + ] + await context.bot.edit_message_text( + chat_id=chat_id, message_id=message_id, + text=r"*📈 PMM Mister \- New Config*" + "\n\n" + f"🏦 `{escape_markdown_v2(connector)}` \\| 🔗 `{escape_markdown_v2(trading_pair)}`" + "\n\n" + r"*Step 3/8:* ⚡ Leverage", + parse_mode="MarkdownV2", + reply_markup=InlineKeyboardMarkup(keyboard) + ) + else: + # Spot exchange - set leverage to 1 and skip to allocation + config["leverage"] = 1 + set_controller_config(context, config) + context.user_data["bots_state"] = "pmm_wizard_input" + context.user_data["pmm_wizard_step"] = "portfolio_allocation" + keyboard = [ + [ + InlineKeyboardButton("1%", callback_data="bots:pmm_alloc:0.01"), + InlineKeyboardButton("2%", callback_data="bots:pmm_alloc:0.02"), + InlineKeyboardButton("3%", callback_data="bots:pmm_alloc:0.03"), + ], + [ + InlineKeyboardButton("5%", callback_data="bots:pmm_alloc:0.05"), + InlineKeyboardButton("10%", callback_data="bots:pmm_alloc:0.1"), + InlineKeyboardButton("20%", callback_data="bots:pmm_alloc:0.2"), + ], + [ + InlineKeyboardButton("⬅️ Back", callback_data="bots:pmm_back:pair"), + InlineKeyboardButton("❌ Cancel", callback_data="bots:main_menu") + ], + ] + await context.bot.edit_message_text( + chat_id=chat_id, message_id=message_id, + text=r"*📈 PMM Mister \- New Config*" + "\n\n" + f"🏦 `{escape_markdown_v2(connector)}` \\| 🔗 `{escape_markdown_v2(trading_pair)}`" + "\n\n" + r"*Step 4/8:* 💰 Portfolio Allocation" + "\n\n" + r"_Or type a custom value \(e\.g\. 3% or 0\.03\)_", + parse_mode="MarkdownV2", + reply_markup=InlineKeyboardMarkup(keyboard) + ) + + async def process_pmm_wizard_input(update: Update, context: ContextTypes.DEFAULT_TYPE, user_input: str) -> None: """Process text input during PMM wizard""" step = context.user_data.get("pmm_wizard_step", "") @@ -5398,10 +6501,31 @@ async def process_pmm_wizard_input(update: Update, context: ContextTypes.DEFAULT pass if step == "trading_pair": - config["trading_pair"] = user_input.upper() - set_controller_config(context, config) + pair = user_input.upper().strip() + if "-" not in pair: + pair = pair.replace("/", "-").replace("_", "-") + connector = config.get("connector_name", "") + # Validate trading pair exists on the connector + client, _ = await get_bots_client(chat_id, context.user_data) + is_valid, error_msg, suggestions = await validate_trading_pair( + context.user_data, client, connector, pair + ) + + if not is_valid: + # Show error with suggestions + await _show_pmm_pair_suggestions(update, context, pair, error_msg, suggestions, connector) + return + + # Get correctly formatted pair from trading rules + trading_rules = await get_trading_rules(context.user_data, client, connector) + correct_pair = get_correct_pair_format(trading_rules, pair) + pair = correct_pair if correct_pair else pair + + config["trading_pair"] = pair + set_controller_config(context, config) + # Only ask for leverage on perpetual exchanges if connector.endswith("_perpetual"): context.user_data["pmm_wizard_step"] = "leverage" @@ -5416,13 +6540,16 @@ async def process_pmm_wizard_input(update: Update, context: ContextTypes.DEFAULT InlineKeyboardButton("50x", callback_data="bots:pmm_leverage:50"), InlineKeyboardButton("75x", callback_data="bots:pmm_leverage:75"), ], - [InlineKeyboardButton("❌ Cancel", callback_data="bots:main_menu")], + [ + InlineKeyboardButton("⬅️ Back", callback_data="bots:pmm_back:pair"), + InlineKeyboardButton("❌ Cancel", callback_data="bots:main_menu") + ], ] await context.bot.edit_message_text( chat_id=chat_id, message_id=message_id, text=r"*📈 PMM Mister \- New Config*" + "\n\n" f"🏦 `{escape_markdown_v2(connector)}` \\| 🔗 `{escape_markdown_v2(config['trading_pair'])}`" + "\n\n" - r"*Step 3/7:* ⚡ Leverage", + r"*Step 3/8:* ⚡ Leverage", parse_mode="MarkdownV2", reply_markup=InlineKeyboardMarkup(keyboard) ) @@ -5430,25 +6557,114 @@ async def process_pmm_wizard_input(update: Update, context: ContextTypes.DEFAULT # Spot exchange - set leverage to 1 and skip to allocation config["leverage"] = 1 set_controller_config(context, config) + context.user_data["bots_state"] = "pmm_wizard_input" context.user_data["pmm_wizard_step"] = "portfolio_allocation" keyboard = [ [ InlineKeyboardButton("1%", callback_data="bots:pmm_alloc:0.01"), InlineKeyboardButton("2%", callback_data="bots:pmm_alloc:0.02"), - InlineKeyboardButton("5%", callback_data="bots:pmm_alloc:0.05"), + InlineKeyboardButton("3%", callback_data="bots:pmm_alloc:0.03"), ], [ + InlineKeyboardButton("5%", callback_data="bots:pmm_alloc:0.05"), InlineKeyboardButton("10%", callback_data="bots:pmm_alloc:0.1"), InlineKeyboardButton("20%", callback_data="bots:pmm_alloc:0.2"), - InlineKeyboardButton("50%", callback_data="bots:pmm_alloc:0.5"), ], - [InlineKeyboardButton("❌ Cancel", callback_data="bots:main_menu")], + [ + InlineKeyboardButton("⬅️ Back", callback_data="bots:pmm_back:pair"), + InlineKeyboardButton("❌ Cancel", callback_data="bots:main_menu") + ], ] await context.bot.edit_message_text( chat_id=chat_id, message_id=message_id, text=r"*📈 PMM Mister \- New Config*" + "\n\n" f"🏦 `{escape_markdown_v2(connector)}` \\| 🔗 `{escape_markdown_v2(config['trading_pair'])}`" + "\n\n" - r"*Step 4/7:* 💰 Portfolio Allocation", + r"*Step 4/8:* 💰 Portfolio Allocation" + "\n\n" + r"_Or type a custom value \(e\.g\. 3% or 0\.03\)_", + parse_mode="MarkdownV2", + reply_markup=InlineKeyboardMarkup(keyboard) + ) + + elif step == "portfolio_allocation": + # Parse allocation value (handle "3%" or "0.03" formats) + try: + val_str = user_input.strip().replace("%", "") + val = float(val_str) + if val > 1: # User entered percentage like "3" or "3%" + val = val / 100 + config["portfolio_allocation"] = val + set_controller_config(context, config) + context.user_data["pmm_wizard_step"] = "total_amount_quote" + await _show_pmm_wizard_amount_step_msg(context, chat_id, message_id, config) + except ValueError: + # Invalid input - show error and keep at same step + connector = config.get("connector_name", "") + pair = config.get("trading_pair", "") + leverage = config.get("leverage", 20) + back_target = "leverage" if connector.endswith("_perpetual") else "pair" + keyboard = [ + [ + InlineKeyboardButton("1%", callback_data="bots:pmm_alloc:0.01"), + InlineKeyboardButton("2%", callback_data="bots:pmm_alloc:0.02"), + InlineKeyboardButton("3%", callback_data="bots:pmm_alloc:0.03"), + ], + [ + InlineKeyboardButton("5%", callback_data="bots:pmm_alloc:0.05"), + InlineKeyboardButton("10%", callback_data="bots:pmm_alloc:0.1"), + InlineKeyboardButton("20%", callback_data="bots:pmm_alloc:0.2"), + ], + [ + InlineKeyboardButton("⬅️ Back", callback_data=f"bots:pmm_back:{back_target}"), + InlineKeyboardButton("❌ Cancel", callback_data="bots:main_menu") + ], + ] + await context.bot.edit_message_text( + chat_id=chat_id, message_id=message_id, + text=r"*📈 PMM Mister \- New Config*" + "\n\n" + f"🏦 `{escape_markdown_v2(connector)}` \\| 🔗 `{escape_markdown_v2(pair)}`" + "\n" + f"⚡ `{leverage}x`" + "\n\n" + r"*Step 4/8:* 💰 Portfolio Allocation" + "\n\n" + r"⚠️ _Invalid value\. Enter a percentage \(e\.g\. 3% or 0\.03\)_", + parse_mode="MarkdownV2", + reply_markup=InlineKeyboardMarkup(keyboard) + ) + + elif step == "total_amount_quote": + # Parse amount value + try: + amount = float(user_input.strip().replace(",", "")) + config["total_amount_quote"] = amount + set_controller_config(context, config) + context.user_data["pmm_wizard_step"] = "spreads" + await _show_pmm_wizard_spreads_step_msg(context, chat_id, message_id, config) + except ValueError: + # Invalid input - show error and keep at same step + connector = config.get("connector_name", "") + pair = config.get("trading_pair", "") + leverage = config.get("leverage", 20) + allocation = config.get("portfolio_allocation", 0.05) + keyboard = [ + [ + InlineKeyboardButton("💵 100", callback_data="bots:pmm_amount:100"), + InlineKeyboardButton("💵 500", callback_data="bots:pmm_amount:500"), + InlineKeyboardButton("💵 1000", callback_data="bots:pmm_amount:1000"), + ], + [ + InlineKeyboardButton("💰 2000", callback_data="bots:pmm_amount:2000"), + InlineKeyboardButton("💰 5000", callback_data="bots:pmm_amount:5000"), + ], + [ + InlineKeyboardButton("⬅️ Back", callback_data="bots:pmm_back:allocation"), + InlineKeyboardButton("❌ Cancel", callback_data="bots:main_menu"), + ], + ] + await context.bot.edit_message_text( + chat_id=chat_id, message_id=message_id, + text=r"*📈 PMM Mister \- New Config*" + "\n\n" + f"🏦 `{escape_markdown_v2(connector)}` \\| 🔗 `{escape_markdown_v2(pair)}`" + "\n" + f"⚡ `{leverage}x` \\| 💰 `{allocation*100:.1f}%`" + "\n\n" + r"*Step 5/8:* 💵 Total Amount \(Quote\)" + "\n\n" + r"⚠️ _Invalid value\. Enter a number \(e\.g\. 500\)_", parse_mode="MarkdownV2", reply_markup=InlineKeyboardMarkup(keyboard) ) @@ -5473,7 +6689,10 @@ async def process_pmm_wizard_input(update: Update, context: ContextTypes.DEFAULT InlineKeyboardButton("0.2%", callback_data="bots:pmm_tp:0.002"), InlineKeyboardButton("0.5%", callback_data="bots:pmm_tp:0.005"), ], - [InlineKeyboardButton("❌ Cancel", callback_data="bots:main_menu")], + [ + InlineKeyboardButton("⬅️ Back", callback_data="bots:pmm_back:spreads"), + InlineKeyboardButton("❌ Cancel", callback_data="bots:main_menu") + ], ] await context.bot.edit_message_text( chat_id=chat_id, message_id=message_id, @@ -5481,7 +6700,7 @@ async def process_pmm_wizard_input(update: Update, context: ContextTypes.DEFAULT f"🏦 `{escape_markdown_v2(connector)}` \\| 🔗 `{escape_markdown_v2(pair)}`" + "\n" f"⚡ `{leverage}x` \\| 💰 `{allocation*100:.0f}%`" + "\n" f"📊 Spreads: `{escape_markdown_v2(user_input.strip())}`" + "\n\n" - r"*Step 6/7:* 🎯 Take Profit", + r"*Step 7/8:* 🎯 Take Profit", parse_mode="MarkdownV2", reply_markup=InlineKeyboardMarkup(keyboard) ) @@ -5570,16 +6789,26 @@ async def process_pmm_wizard_input(update: Update, context: ContextTypes.DEFAULT "connector_name": ("connector_name", str), "trading_pair": ("trading_pair", str), "leverage": ("leverage", int), + "position_mode": ("position_mode", str), + "total_amount_quote": ("total_amount_quote", float), "portfolio_allocation": ("portfolio_allocation", float), - "buy_spreads": ("buy_spreads", str), - "sell_spreads": ("sell_spreads", str), - "take_profit": ("take_profit", float), "target_base_pct": ("target_base_pct", float), "min_base_pct": ("min_base_pct", float), "max_base_pct": ("max_base_pct", float), + "buy_spreads": ("buy_spreads", str), + "sell_spreads": ("sell_spreads", str), + "buy_amounts_pct": ("buy_amounts_pct", str), + "sell_amounts_pct": ("sell_amounts_pct", str), + "take_profit": ("take_profit", float), + "take_profit_order_type": ("take_profit_order_type", str), + "open_order_type": ("open_order_type", str), "executor_refresh_time": ("executor_refresh_time", int), "buy_cooldown_time": ("buy_cooldown_time", int), "sell_cooldown_time": ("sell_cooldown_time", int), + "buy_position_effectivization_time": ("buy_position_effectivization_time", int), + "sell_position_effectivization_time": ("sell_position_effectivization_time", int), + "min_buy_price_distance_pct": ("min_buy_price_distance_pct", float), + "min_sell_price_distance_pct": ("min_sell_price_distance_pct", float), "max_active_executors_by_level": ("max_active_executors_by_level", int), } @@ -5608,11 +6837,19 @@ async def process_pmm_wizard_input(update: Update, context: ContextTypes.DEFAULT if field_name in field_map: config_key, type_fn = field_map[field_name] try: - if type_fn == str: + # Special handling for order type fields + if config_key in ("take_profit_order_type", "open_order_type"): + # Normalize input to uppercase with underscores + normalized = value_str.upper().replace(" ", "_") + if normalized in ("LIMIT_MAKER", "LIMIT", "MARKET"): + config[config_key] = normalized + updated_fields.append(field_name) + elif type_fn == str: config[config_key] = value_str + updated_fields.append(field_name) else: config[config_key] = type_fn(value_str) - updated_fields.append(field_name) + updated_fields.append(field_name) except (ValueError, TypeError): pass @@ -5623,22 +6860,58 @@ async def process_pmm_wizard_input(update: Update, context: ContextTypes.DEFAULT async def _pmm_show_review(context, chat_id, message_id, config): """Helper to show review step with copyable config format""" + # Format order types as string + tp_order_type = config.get('take_profit_order_type', "LIMIT_MAKER") + if isinstance(tp_order_type, int): + tp_order_type_str = ORDER_TYPE_LABELS.get(tp_order_type, "LIMIT_MAKER") + else: + tp_order_type_str = str(tp_order_type) + + open_order_type = config.get('open_order_type', "LIMIT") + if isinstance(open_order_type, int): + open_order_type_str = ORDER_TYPE_LABELS.get(open_order_type, "LIMIT") + else: + open_order_type_str = str(open_order_type) + + # Calculate amounts_pct based on spreads if not set + buy_spreads = config.get('buy_spreads', '0.0002,0.001') + sell_spreads = config.get('sell_spreads', '0.0002,0.001') + buy_amounts = config.get('buy_amounts_pct') + sell_amounts = config.get('sell_amounts_pct') + + if not buy_amounts: + num_buy_spreads = len(buy_spreads.split(',')) if buy_spreads else 1 + buy_amounts = ','.join(['1'] * num_buy_spreads) + if not sell_amounts: + num_sell_spreads = len(sell_spreads.split(',')) if sell_spreads else 1 + sell_amounts = ','.join(['1'] * num_sell_spreads) + # Build copyable config block config_block = ( f"id: {config.get('id', '')}\n" f"connector_name: {config.get('connector_name', '')}\n" f"trading_pair: {config.get('trading_pair', '')}\n" f"leverage: {config.get('leverage', 1)}\n" + f"position_mode: {config.get('position_mode', 'HEDGE')}\n" + f"total_amount_quote: {config.get('total_amount_quote', 100)}\n" f"portfolio_allocation: {config.get('portfolio_allocation', 0.05)}\n" - f"buy_spreads: {config.get('buy_spreads', '0.0002,0.001')}\n" - f"sell_spreads: {config.get('sell_spreads', '0.0002,0.001')}\n" + f"target_base_pct: {config.get('target_base_pct', 0.5)}\n" + f"min_base_pct: {config.get('min_base_pct', 0.4)}\n" + f"max_base_pct: {config.get('max_base_pct', 0.6)}\n" + f"buy_spreads: {buy_spreads}\n" + f"sell_spreads: {sell_spreads}\n" + f"buy_amounts_pct: {buy_amounts}\n" + f"sell_amounts_pct: {sell_amounts}\n" f"take_profit: {config.get('take_profit', 0.0001)}\n" - f"target_base_pct: {config.get('target_base_pct', 0.2)}\n" - f"min_base_pct: {config.get('min_base_pct', 0.1)}\n" - f"max_base_pct: {config.get('max_base_pct', 0.4)}\n" + f"take_profit_order_type: {tp_order_type_str}\n" + f"open_order_type: {open_order_type_str}\n" f"executor_refresh_time: {config.get('executor_refresh_time', 30)}\n" f"buy_cooldown_time: {config.get('buy_cooldown_time', 15)}\n" f"sell_cooldown_time: {config.get('sell_cooldown_time', 15)}\n" + f"buy_position_effectivization_time: {config.get('buy_position_effectivization_time', 3600)}\n" + f"sell_position_effectivization_time: {config.get('sell_position_effectivization_time', 3600)}\n" + f"min_buy_price_distance_pct: {config.get('min_buy_price_distance_pct', 0.003)}\n" + f"min_sell_price_distance_pct: {config.get('min_sell_price_distance_pct', 0.003)}\n" f"max_active_executors_by_level: {config.get('max_active_executors_by_level', 4)}" ) @@ -5646,22 +6919,28 @@ async def _pmm_show_review(context, chat_id, message_id, config): message_text = ( f"*{escape_markdown_v2(pair)}* \\- Review Config\n\n" f"```\n{config_block}\n```\n\n" - f"_To edit, send `field: value` lines:_\n" - f"`leverage: 20`\n" - f"`take_profit: 0.001`" + f"_To edit, send `field: value` lines_" ) keyboard = [ [InlineKeyboardButton("✅ Save Config", callback_data="bots:pmm_save")], - [InlineKeyboardButton("❌ Cancel", callback_data="bots:main_menu")], + [ + InlineKeyboardButton("⬅️ Back", callback_data="bots:pmm_back:tp"), + InlineKeyboardButton("❌ Cancel", callback_data="bots:main_menu") + ], ] - await context.bot.edit_message_text( - chat_id=chat_id, message_id=message_id, - text=message_text, - parse_mode="MarkdownV2", - reply_markup=InlineKeyboardMarkup(keyboard) - ) + try: + await context.bot.edit_message_text( + chat_id=chat_id, message_id=message_id, + text=message_text, + parse_mode="MarkdownV2", + reply_markup=InlineKeyboardMarkup(keyboard) + ) + except Exception as e: + # Ignore "Message is not modified" error + if "Message is not modified" not in str(e): + raise async def _pmm_show_advanced(context, chat_id, message_id, config): @@ -5690,3 +6969,129 @@ async def _pmm_show_advanced(context, chat_id, message_id, config): parse_mode="MarkdownV2", reply_markup=InlineKeyboardMarkup(keyboard) ) + + +# ============================================ +# CUSTOM CONFIG UPLOAD +# ============================================ + +async def show_upload_config_prompt(update: Update, context: ContextTypes.DEFAULT_TYPE) -> None: + """Show message prompting user to upload a YAML config file""" + query = update.callback_query + + # Set state to expect file upload + context.user_data["bots_state"] = "awaiting_config_upload" + + message_text = ( + r"*Upload Custom Config*" + "\n\n" + r"Upload a YAML file \(`.yml` or `.yaml`\) with your controller configuration\." + "\n\n" + r"The file should contain a valid controller config with at least an `id` field\." + ) + + keyboard = [ + [InlineKeyboardButton("❌ Cancel", callback_data="bots:upload_cancel")], + ] + + await query.message.edit_text( + message_text, + parse_mode="MarkdownV2", + reply_markup=InlineKeyboardMarkup(keyboard) + ) + + +async def handle_upload_cancel(update: Update, context: ContextTypes.DEFAULT_TYPE) -> None: + """Cancel the upload and return to configs menu""" + clear_bots_state(context) + await show_controller_configs_menu(update, context) + + +async def handle_config_file_upload(update: Update, context: ContextTypes.DEFAULT_TYPE) -> None: + """Handle uploaded YAML config file""" + import yaml + + # Only process if we're expecting a config upload + if context.user_data.get("bots_state") != "awaiting_config_upload": + return + + chat_id = update.effective_chat.id + document = update.message.document + + # Check file extension + file_name = document.file_name or "" + if not file_name.lower().endswith(('.yml', '.yaml')): + await update.message.reply_text( + format_error_message("Please upload a YAML file (.yml or .yaml)"), + parse_mode="MarkdownV2" + ) + return + + try: + # Download the file + file = await context.bot.get_file(document.file_id) + file_bytes = await file.download_as_bytearray() + content = file_bytes.decode('utf-8') + + # Parse YAML + try: + config = yaml.safe_load(content) + except yaml.YAMLError as e: + await update.message.reply_text( + format_error_message(f"Invalid YAML file: {str(e)}"), + parse_mode="MarkdownV2" + ) + return + + if not isinstance(config, dict): + await update.message.reply_text( + format_error_message("YAML file must contain a dictionary/object"), + parse_mode="MarkdownV2" + ) + return + + # Validate minimum required field + config_id = config.get("id") + if not config_id: + await update.message.reply_text( + format_error_message("Config must have an 'id' field"), + parse_mode="MarkdownV2" + ) + return + + # Save to backend + client, _ = await get_bots_client(chat_id, context.user_data) + result = await client.controllers.create_or_update_controller_config(config_id, config) + + # Clear state + clear_bots_state(context) + + # Check result + if result.get("status") == "success" or "success" in str(result).lower(): + controller_name = config.get("controller_name", "unknown") + success_msg = ( + f"✅ *Config uploaded successfully\\!*\n\n" + f"ID: `{escape_markdown_v2(config_id)}`\n" + f"Type: `{escape_markdown_v2(controller_name)}`" + ) + keyboard = [ + [InlineKeyboardButton("📁 View Configs", callback_data="bots:controller_configs")], + [InlineKeyboardButton("⬅️ Back to Menu", callback_data="bots:main_menu")], + ] + await update.message.reply_text( + success_msg, + parse_mode="MarkdownV2", + reply_markup=InlineKeyboardMarkup(keyboard) + ) + else: + error_detail = result.get("message", result.get("error", str(result))) + await update.message.reply_text( + format_error_message(f"Failed to save config: {error_detail}"), + parse_mode="MarkdownV2" + ) + + except Exception as e: + logger.error(f"Error uploading config file: {e}", exc_info=True) + clear_bots_state(context) + await update.message.reply_text( + format_error_message(f"Failed to upload config: {str(e)}"), + parse_mode="MarkdownV2" + ) diff --git a/handlers/bots/controllers/__init__.py b/handlers/bots/controllers/__init__.py index 46d7d43..e716a12 100644 --- a/handlers/bots/controllers/__init__.py +++ b/handlers/bots/controllers/__init__.py @@ -9,7 +9,7 @@ - ID generation with chronological numbering """ -from typing import Any, Dict, List, Optional, Type +from typing import Dict, List, Optional, Type from ._base import BaseController, ControllerField from .grid_strike import GridStrikeController diff --git a/handlers/bots/controllers/grid_strike/__init__.py b/handlers/bots/controllers/grid_strike/__init__.py index fe93a89..70654a9 100644 --- a/handlers/bots/controllers/grid_strike/__init__.py +++ b/handlers/bots/controllers/grid_strike/__init__.py @@ -18,6 +18,7 @@ FIELDS, FIELD_ORDER, WIZARD_STEPS, + EDITABLE_FIELDS, SIDE_LONG, SIDE_SHORT, ORDER_TYPE_MARKET, @@ -98,6 +99,7 @@ def generate_id( "FIELDS", "FIELD_ORDER", "WIZARD_STEPS", + "EDITABLE_FIELDS", "SIDE_LONG", "SIDE_SHORT", "ORDER_TYPE_MARKET", diff --git a/handlers/bots/controllers/grid_strike/config.py b/handlers/bots/controllers/grid_strike/config.py index d3a30f3..443b0e6 100644 --- a/handlers/bots/controllers/grid_strike/config.py +++ b/handlers/bots/controllers/grid_strike/config.py @@ -34,7 +34,7 @@ "trading_pair": "", "side": SIDE_LONG, "leverage": 1, - "position_mode": "HEDGE", + "position_mode": "ONEWAY", "total_amount_quote": 1000, "min_order_amount_quote": 6, "start_price": 0.0, @@ -92,6 +92,14 @@ required=True, hint="e.g. 1, 5, 10" ), + "position_mode": ControllerField( + name="position_mode", + label="Position Mode", + type="str", + required=False, + hint="HEDGE or ONEWAY", + default="ONEWAY" + ), "total_amount_quote": ControllerField( name="total_amount_quote", label="Total Amount (Quote)", @@ -104,14 +112,14 @@ label="Start Price", type="float", required=True, - hint="Auto: -2% from current" + hint="Auto: -2% LONG, -6% SHORT" ), "end_price": ControllerField( name="end_price", label="End Price", type="float", required=True, - hint="Auto: +2% from current" + hint="Auto: +6% LONG, +2% SHORT" ), "limit_price": ControllerField( name="limit_price", @@ -213,7 +221,7 @@ # Field display order FIELD_ORDER: List[str] = [ - "id", "connector_name", "trading_pair", "side", "leverage", + "id", "connector_name", "trading_pair", "side", "leverage", "position_mode", "total_amount_quote", "start_price", "end_price", "limit_price", "max_open_orders", "max_orders_per_batch", "order_frequency", "min_order_amount_quote", "min_spread_between_orders", "take_profit", @@ -235,6 +243,26 @@ ] +# Editable fields for config editing +# This is the standard list shown in both wizard final step and edit views +EDITABLE_FIELDS: List[str] = [ + "connector_name", + "trading_pair", + "total_amount_quote", + "start_price", + "end_price", + "limit_price", + "leverage", + "position_mode", + "take_profit", + "coerce_tp_to_step", + "min_spread_between_orders", + "min_order_amount_quote", + "max_open_orders", + "activation_bounds", +] + + def validate_config(config: Dict[str, Any]) -> Tuple[bool, Optional[str]]: """ Validate a grid strike configuration. @@ -279,33 +307,36 @@ def validate_config(config: Dict[str, Any]) -> Tuple[bool, Optional[str]]: def calculate_auto_prices( current_price: float, side: int, - start_pct: float = 0.02, - end_pct: float = 0.02, + base_pct: float = 0.02, limit_pct: float = 0.03 ) -> Tuple[float, float, float]: """ Calculate start, end, and limit prices based on current price and side. - For LONG: - - start_price: current_price - 2% - - end_price: current_price + 2% - - limit_price: current_price - 3% + Uses a 3:1 ratio for the grid range: + + For LONG (buying grid below, selling above): + - start_price: current_price - 1x base_pct (buy zone starts here) + - end_price: current_price + 3x base_pct (grid extends up) + - limit_price: current_price - limit_pct (stop loss below start) - For SHORT: - - start_price: current_price - 2% - - end_price: current_price + 2% - - limit_price: current_price + 3% + For SHORT (selling grid above, buying below): + - start_price: current_price - 3x base_pct (grid extends down) + - end_price: current_price + 1x base_pct (sell zone ends here) + - limit_price: current_price + limit_pct (stop loss above end) Returns: Tuple of (start_price, end_price, limit_price) """ if side == SIDE_LONG: - start_price = current_price * (1 - start_pct) - end_price = current_price * (1 + end_pct) + # LONG: small range below (-1x), larger range above (+3x) + start_price = current_price * (1 - base_pct) + end_price = current_price * (1 + base_pct * 3) limit_price = current_price * (1 - limit_pct) else: # SHORT - start_price = current_price * (1 - start_pct) - end_price = current_price * (1 + end_pct) + # SHORT: larger range below (-3x), small range above (+1x) + start_price = current_price * (1 - base_pct * 3) + end_price = current_price * (1 + base_pct) limit_price = current_price * (1 + limit_pct) return ( diff --git a/handlers/bots/controllers/grid_strike/grid_analysis.py b/handlers/bots/controllers/grid_strike/grid_analysis.py index 98ad8b1..3863b66 100644 --- a/handlers/bots/controllers/grid_strike/grid_analysis.py +++ b/handlers/bots/controllers/grid_strike/grid_analysis.py @@ -8,7 +8,7 @@ - Grid metrics calculation """ -from typing import Any, Dict, List, Optional, Tuple +from typing import Any, Dict, List, Optional import logging logger = logging.getLogger(__name__) @@ -164,15 +164,20 @@ def suggest_grid_params( suggested_spread = max(suggested_spread, 0.0002) # At least 0.02% suggested_tp = max(suggested_tp, 0.0001) # At least 0.01% - # Calculate prices based on side + # Calculate prices based on side using 3:1 ratio + # Total range = 4 units (1 unit on one side, 3 units on the other) + unit = grid_range / 4 + if side == 1: # LONG - start_price = current_price * (1 - grid_range / 2) - end_price = current_price * (1 + grid_range / 2) - limit_price = start_price * (1 - grid_range / 3) # Stop below start + # LONG: small range below (-1 unit), larger range above (+3 units) + start_price = current_price * (1 - unit) + end_price = current_price * (1 + unit * 3) + limit_price = start_price * (1 - unit) # Stop below start else: # SHORT - start_price = current_price * (1 - grid_range / 2) - end_price = current_price * (1 + grid_range / 2) - limit_price = end_price * (1 + grid_range / 3) # Stop above end + # SHORT: larger range below (-3 units), small range above (+1 unit) + start_price = current_price * (1 - unit * 3) + end_price = current_price * (1 + unit) + limit_price = end_price * (1 + unit) # Stop above end # Estimate number of levels price_range = abs(end_price - start_price) diff --git a/handlers/bots/controllers/pmm_mister/config.py b/handlers/bots/controllers/pmm_mister/config.py index 5966e44..73a9b86 100644 --- a/handlers/bots/controllers/pmm_mister/config.py +++ b/handlers/bots/controllers/pmm_mister/config.py @@ -16,9 +16,9 @@ ORDER_TYPE_LIMIT_MAKER = 3 ORDER_TYPE_LABELS = { - ORDER_TYPE_MARKET: "Market", - ORDER_TYPE_LIMIT: "Limit", - ORDER_TYPE_LIMIT_MAKER: "Limit Maker", + ORDER_TYPE_MARKET: "MARKET", + ORDER_TYPE_LIMIT: "LIMIT", + ORDER_TYPE_LIMIT_MAKER: "LIMIT_MAKER", } @@ -31,23 +31,25 @@ "trading_pair": "", "leverage": 20, "position_mode": "HEDGE", + "total_amount_quote": 100, "portfolio_allocation": 0.05, - "target_base_pct": 0.2, - "min_base_pct": 0.1, - "max_base_pct": 0.4, + "target_base_pct": 0.5, + "min_base_pct": 0.4, + "max_base_pct": 0.6, "buy_spreads": "0.0002,0.001", "sell_spreads": "0.0002,0.001", - "buy_amounts_pct": "1,2", - "sell_amounts_pct": "1,2", + "buy_amounts_pct": None, # Auto-calculated: 1 per spread level + "sell_amounts_pct": None, # Auto-calculated: 1 per spread level "executor_refresh_time": 30, "buy_cooldown_time": 15, "sell_cooldown_time": 15, - "buy_position_effectivization_time": 60, - "sell_position_effectivization_time": 60, + "buy_position_effectivization_time": 3600, + "sell_position_effectivization_time": 3600, "min_buy_price_distance_pct": 0.003, "min_sell_price_distance_pct": 0.003, "take_profit": 0.0001, - "take_profit_order_type": ORDER_TYPE_LIMIT_MAKER, + "take_profit_order_type": "LIMIT_MAKER", # String format for API + "open_order_type": "LIMIT", # String format for API "max_active_executors_by_level": 4, "tick_mode": False, "candles_config": [], @@ -237,15 +239,40 @@ hint="Enable tick-based updates", default=False ), + "total_amount_quote": ControllerField( + name="total_amount_quote", + label="Total Amount (Quote)", + type="float", + required=False, + hint="Total amount in quote currency (e.g. 500 USDT)", + default=100 + ), + "open_order_type": ControllerField( + name="open_order_type", + label="Open Order Type", + type="str", + required=False, + hint="Order type for opening (LIMIT, LIMIT_MAKER, MARKET)", + default="LIMIT" + ), + "position_mode": ControllerField( + name="position_mode", + label="Position Mode", + type="str", + required=False, + hint="Position mode (HEDGE, ONEWAY)", + default="HEDGE" + ), } # Field display order FIELD_ORDER: List[str] = [ "id", "connector_name", "trading_pair", "leverage", - "portfolio_allocation", "target_base_pct", "min_base_pct", "max_base_pct", + "total_amount_quote", "portfolio_allocation", "position_mode", + "target_base_pct", "min_base_pct", "max_base_pct", "buy_spreads", "sell_spreads", "buy_amounts_pct", "sell_amounts_pct", - "take_profit", "take_profit_order_type", + "take_profit", "take_profit_order_type", "open_order_type", "executor_refresh_time", "buy_cooldown_time", "sell_cooldown_time", "buy_position_effectivization_time", "sell_position_effectivization_time", "min_buy_price_distance_pct", "min_sell_price_distance_pct", diff --git a/handlers/bots/menu.py b/handlers/bots/menu.py index 6efcd0a..7c7a26c 100644 --- a/handlers/bots/menu.py +++ b/handlers/bots/menu.py @@ -10,14 +10,14 @@ """ import logging -from typing import Dict, Any, Optional, List +from typing import Dict, Any from telegram import Update, InlineKeyboardButton, InlineKeyboardMarkup from telegram.error import BadRequest from telegram.ext import ContextTypes -from utils.telegram_formatters import format_active_bots, format_error_message, escape_markdown_v2, format_number -from ._shared import get_bots_client, clear_bots_state +from utils.telegram_formatters import format_active_bots, format_error_message, escape_markdown_v2, format_uptime +from ._shared import get_bots_client, clear_bots_state, set_controller_config logger = logging.getLogger(__name__) @@ -89,9 +89,33 @@ async def show_bots_menu(update: Update, context: ContextTypes.DEFAULT_TYPE) -> return try: - client = await get_bots_client(chat_id) + from config_manager import get_config_manager + + client, server_name = await get_bots_client(chat_id, context.user_data) + + # Check server status for indicator + try: + server_status_info = await get_config_manager().check_server_status(server_name) + server_status = server_status_info.get("status", "online") + except Exception: + server_status = "online" # Default to online if check fails + + status_emoji = {"online": "🟢", "offline": "🔴", "auth_error": "🟠", "error": "⚠️"}.get(server_status, "🟢") + bots_data = await client.bot_orchestration.get_active_bots_status() + # Fetch bot runs to get deployment times + bot_runs_map = {} + try: + bot_runs_data = await client.bot_orchestration.get_bot_runs() + if isinstance(bot_runs_data, dict) and "data" in bot_runs_data: + for run in bot_runs_data.get("data", []): + # Only include DEPLOYED bots (not ARCHIVED) + if run.get("deployment_status") == "DEPLOYED" and run.get("deployed_at"): + bot_runs_map[run.get("bot_name")] = run.get("deployed_at") + except Exception as e: + logger.debug(f"Could not fetch bot runs for uptime: {e}") + # Extract bots dictionary for building keyboard if isinstance(bots_data, dict) and "data" in bots_data: bots_dict = bots_data.get("data", {}) @@ -102,15 +126,17 @@ async def show_bots_menu(update: Update, context: ContextTypes.DEFAULT_TYPE) -> # Store bots data for later use context.user_data["active_bots_data"] = bots_data + context.user_data["bot_runs_map"] = bot_runs_map + context.user_data["current_server_name"] = server_name # Format the bot status message - status_message = format_active_bots(bots_data) + status_message = format_active_bots(bots_data, bot_runs=bot_runs_map) # Build the menu with bot buttons reply_markup = _build_main_menu_keyboard(bots_dict) - # Add header - header = r"*Bots Dashboard*" + "\n\n" + # Add header with server indicator + header = f"*Bots Dashboard* \\| _Server: {escape_markdown_v2(server_name)} {status_emoji}_\n\n" full_message = header + status_message if query: @@ -197,7 +223,7 @@ async def show_bot_detail(update: Update, context: ContextTypes.DEFAULT_TYPE, bo # If not in cache, fetch fresh data if not bot_info: - client = await get_bots_client(chat_id) + client, _ = await get_bots_client(chat_id, context.user_data) fresh_data = await client.bot_orchestration.get_active_bots_status() if isinstance(fresh_data, dict) and "data" in fresh_data: bot_info = fresh_data.get("data", {}).get(bot_name) @@ -219,13 +245,21 @@ async def show_bot_detail(update: Update, context: ContextTypes.DEFAULT_TYPE, bo # Truncate bot name for display display_name = bot_name[:45] + "..." if len(bot_name) > 45 else bot_name + # Get uptime if available + bot_runs_map = context.user_data.get("bot_runs_map", {}) + uptime_str = "" + if bot_name in bot_runs_map: + uptime = format_uptime(bot_runs_map[bot_name]) + if uptime: + uptime_str = f" ⏱️ {uptime}" + lines = [ f"*Bot Details*", "", - f"{status_emoji} `{escape_markdown_v2(display_name)}`", + f"{status_emoji} `{escape_markdown_v2(display_name)}`{uptime_str}", ] - # Controllers and performance - rich format + # Controllers and performance - table format performance = bot_info.get("performance", {}) controller_names = list(performance.keys()) @@ -241,6 +275,11 @@ async def show_bot_detail(update: Update, context: ContextTypes.DEFAULT_TYPE, bo total_realized = 0 total_unrealized = 0 + # Collect controller data for table + ctrl_rows = [] + all_positions = [] + all_closed = [] + for idx, (ctrl_name, ctrl_info) in enumerate(performance.items()): if not isinstance(ctrl_info, dict): continue @@ -258,26 +297,85 @@ async def show_bot_detail(update: Update, context: ContextTypes.DEFAULT_TYPE, bo total_realized += realized total_unrealized += unrealized - # Controller section - compact format - lines.append("━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━") - - # Controller name and status - ctrl_status_emoji = "▶️" if ctrl_status == "running" else "⏸️" - lines.append(f"{ctrl_status_emoji} *{escape_markdown_v2(ctrl_name)}*") - - # P&L + Volume in one line (compact) - pnl_emoji = "🟢" if pnl >= 0 else "🔴" - vol_str = f"{volume/1000:.1f}k" if volume >= 1000 else f"{volume:.0f}" - lines.append(f"{pnl_emoji} pnl: `{escape_markdown_v2(f'{pnl:+.2f}')}` \\(R: `{escape_markdown_v2(f'{realized:+.2f}')}` / U: `{escape_markdown_v2(f'{unrealized:+.2f}')}`\\) 📦 vol: `{escape_markdown_v2(vol_str)}`") - - # Open Positions section + ctrl_rows.append({ + "idx": idx, + "name": ctrl_name, + "status": ctrl_status, + "pnl": pnl, + "realized": realized, + "unrealized": unrealized, + "volume": volume, + }) + + # Collect positions with controller info positions = ctrl_perf.get("positions_summary", []) if positions: - lines.append("") - lines.append(f"*Open Positions* \\({len(positions)}\\)") - # Extract trading pair from controller name for display trading_pair = _extract_pair_from_name(ctrl_name) for pos in positions: + all_positions.append({"ctrl": ctrl_name, "pair": trading_pair, "pos": pos}) + + # Collect closed counts + close_counts = ctrl_perf.get("close_type_counts", {}) + if close_counts: + all_closed.append({"name": ctrl_name, "counts": close_counts}) + + # Build table header + lines.append("```") + lines.append("Controller PnL Vol") + lines.append("──────────────────────────── ──────── ───────") + + # Build table rows + for row in ctrl_rows: + status_char = "▶" if row["status"] == "running" else "⏸" + # Truncate controller name to fit table (max 27 chars) + name_display = row["name"][:27] if len(row["name"]) > 27 else row["name"] + name_padded = f"{status_char}{name_display}".ljust(28) + + pnl_str = f"{row['pnl']:+.2f}".rjust(8) + vol_str = f"{row['volume']/1000:.1f}k" if row["volume"] >= 1000 else f"{row['volume']:.0f}" + vol_str = vol_str.rjust(7) + + lines.append(f"{name_padded} {pnl_str} {vol_str}") + + # Total row + if len(ctrl_rows) > 1: + lines.append("──────────────────────────── ──────── ───────") + total_name = "TOTAL".ljust(28) + pnl_str = f"{total_pnl:+.2f}".rjust(8) + vol_str = f"{total_volume/1000:.1f}k" if total_volume >= 1000 else f"{total_volume:.0f}" + vol_str = vol_str.rjust(7) + lines.append(f"{total_name} {pnl_str} {vol_str}") + + lines.append("```") + + # Open Positions section (grouped by controller) - limit to avoid message too long + MAX_POSITIONS_DISPLAY = 8 + if all_positions: + lines.append("") + lines.append(f"*Open Positions* \\({len(all_positions)}\\)") + + # Group positions by controller + positions_by_ctrl = {} + for item in all_positions: + ctrl = item["ctrl"] + if ctrl not in positions_by_ctrl: + positions_by_ctrl[ctrl] = [] + positions_by_ctrl[ctrl].append(item) + + positions_shown = 0 + for ctrl_name, ctrl_positions in positions_by_ctrl.items(): + if positions_shown >= MAX_POSITIONS_DISPLAY: + remaining = len(all_positions) - positions_shown + lines.append(f"_\\.\\.\\.and {remaining} more_") + break + # Shorten controller name for display + short_ctrl = ctrl_name[:25] if len(ctrl_name) > 25 else ctrl_name + lines.append(f"_{escape_markdown_v2(short_ctrl)}_") + for item in ctrl_positions: + if positions_shown >= MAX_POSITIONS_DISPLAY: + break + pos = item["pos"] + trading_pair = item["pair"] side_raw = pos.get("side", "") is_long = "BUY" in str(side_raw).upper() side_emoji = "🟢" if is_long else "🔴" @@ -286,74 +384,69 @@ async def show_bot_detail(update: Update, context: ContextTypes.DEFAULT_TYPE, bo breakeven = pos.get("breakeven_price", 0) or 0 pos_value = amount * breakeven pos_unrealized = pos.get("unrealized_pnl_quote", 0) or 0 - - lines.append(f"📍 {escape_markdown_v2(trading_pair)} {side_emoji}{side_str} `${escape_markdown_v2(f'{pos_value:.2f}')}` @ `{escape_markdown_v2(f'{breakeven:.4f}')}` \\| U: `{escape_markdown_v2(f'{pos_unrealized:+.2f}')}`") - - # Closed Positions section - close_counts = ctrl_perf.get("close_type_counts", {}) - if close_counts: - total_closed = sum(close_counts.values()) + lines.append(f" 📍 {side_emoji}{side_str} `${escape_markdown_v2(f'{pos_value:.2f}')}` @ `{escape_markdown_v2(f'{breakeven:.4f}')}` \\| U: `{escape_markdown_v2(f'{pos_unrealized:+.2f}')}`") + positions_shown += 1 + + # Closed Positions section (combined) + if all_closed: + total_tp = total_sl = total_hold = total_early = total_insuf = 0 + for item in all_closed: + counts = item["counts"] + total_tp += _get_close_count(counts, "TAKE_PROFIT") + total_sl += _get_close_count(counts, "STOP_LOSS") + total_hold += _get_close_count(counts, "POSITION_HOLD") + total_early += _get_close_count(counts, "EARLY_STOP") + total_insuf += _get_close_count(counts, "INSUFFICIENT_BALANCE") + + total_closed_count = total_tp + total_sl + total_hold + total_early + total_insuf + if total_closed_count > 0: lines.append("") - lines.append(f"*Closed Positions* \\({total_closed}\\)") - - # Extract counts for each type - tp = _get_close_count(close_counts, "TAKE_PROFIT") - sl = _get_close_count(close_counts, "STOP_LOSS") - hold = _get_close_count(close_counts, "POSITION_HOLD") - early = _get_close_count(close_counts, "EARLY_STOP") - insuf = _get_close_count(close_counts, "INSUFFICIENT_BALANCE") - - # Row 1: TP | SL (if any) - row1_parts = [] - if tp > 0: - row1_parts.append(f"🎯 TP: `{tp}`") - if sl > 0: - row1_parts.append(f"🛑 SL: `{sl}`") - if row1_parts: - lines.append(" \\| ".join(row1_parts)) - - # Row 2: Hold | Early (if any) - row2_parts = [] - if hold > 0: - row2_parts.append(f"✋ Hold: `{hold}`") - if early > 0: - row2_parts.append(f"⚡ Early: `{early}`") - if row2_parts: - lines.append(" \\| ".join(row2_parts)) - - # Row 3: Insufficient balance (if any) - if insuf > 0: - lines.append(f"⚠️ Insuf\\. Balance: `{insuf}`") - - # Add controller button row: [✏️ controller_name] [▶️/⏸️] + lines.append(f"*Closed Positions* \\({total_closed_count}\\)") + + row_parts = [] + if total_tp > 0: + row_parts.append(f"🎯 TP: `{total_tp}`") + if total_sl > 0: + row_parts.append(f"🛑 SL: `{total_sl}`") + if total_hold > 0: + row_parts.append(f"✋ Hold: `{total_hold}`") + if total_early > 0: + row_parts.append(f"⚡ Early: `{total_early}`") + if total_insuf > 0: + row_parts.append(f"⚠️ Insuf: `{total_insuf}`") + + if row_parts: + lines.append(" \\| ".join(row_parts)) + + # Add controller buttons + for row in ctrl_rows: + idx = row["idx"] + ctrl_status = row["status"] + ctrl_name = row["name"] + toggle_emoji = "⏸" if ctrl_status == "running" else "▶️" toggle_action = "stop_ctrl_quick" if ctrl_status == "running" else "start_ctrl_quick" if idx < 8: # Max 8 controllers with buttons - # Use shortened name for button but keep it readable - btn_name = _shorten_controller_name(ctrl_name, 22) + # Use controller name directly, truncate if needed + btn_name = ctrl_name[:26] if len(ctrl_name) > 26 else ctrl_name keyboard.append([ InlineKeyboardButton(f"✏️ {btn_name}", callback_data=f"bots:ctrl_idx:{idx}"), InlineKeyboardButton(toggle_emoji, callback_data=f"bots:{toggle_action}:{idx}"), ]) - # Total summary (only if multiple controllers) - if len(performance) > 1: - lines.append("━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━") - pnl_emoji = "🟢" if total_pnl >= 0 else "🔴" - vol_total = f"{total_volume/1000:.1f}k" if total_volume >= 1000 else f"{total_volume:.0f}" - lines.append(f"*TOTAL* {pnl_emoji} pnl: `{escape_markdown_v2(f'{total_pnl:+.2f}')}` \\(R: `{escape_markdown_v2(f'{total_realized:+.2f}')}` / U: `{escape_markdown_v2(f'{total_unrealized:+.2f}')}`\\) 📦 vol: `{escape_markdown_v2(vol_total)}`") - # Error summary at the bottom error_logs = bot_info.get("error_logs", []) if error_logs: lines.append("") lines.append(f"⚠️ *{len(error_logs)} error\\(s\\):*") - # Show last 2 errors briefly - for err in error_logs[-2:]: + # Show last 3 errors with truncated message + for err in error_logs[-3:]: err_msg = err.get("msg", str(err)) if isinstance(err, dict) else str(err) - err_short = err_msg[:60] + "..." if len(err_msg) > 60 else err_msg - lines.append(f" `{escape_markdown_v2(err_short)}`") + # Truncate long error messages + if len(err_msg) > 80: + err_msg = err_msg[:77] + "..." + lines.append(f" `{escape_markdown_v2(err_msg)}`") # Bot-level actions keyboard.append([ @@ -368,6 +461,13 @@ async def show_bot_detail(update: Update, context: ContextTypes.DEFAULT_TYPE, bo reply_markup = InlineKeyboardMarkup(keyboard) + # Build message and ensure it doesn't exceed Telegram's limit + message_text = "\n".join(lines) + MAX_MESSAGE_LENGTH = 4000 # Leave some buffer below 4096 + if len(message_text) > MAX_MESSAGE_LENGTH: + # Truncate and add indicator + message_text = message_text[:MAX_MESSAGE_LENGTH - 50] + "\n\n_\\.\\.\\. truncated_" + try: # Check if current message is a photo (from controller detail view) if getattr(query.message, 'photo', None): @@ -377,13 +477,13 @@ async def show_bot_detail(update: Update, context: ContextTypes.DEFAULT_TYPE, bo except Exception: pass await query.message.chat.send_message( - "\n".join(lines), + message_text, parse_mode="MarkdownV2", reply_markup=reply_markup ) else: await query.message.edit_text( - "\n".join(lines), + message_text, parse_mode="MarkdownV2", reply_markup=reply_markup ) @@ -543,9 +643,10 @@ async def show_controller_detail(update: Update, context: ContextTypes.DEFAULT_T # Try to fetch controller config ctrl_config = None is_grid_strike = False + is_pmm_mister = False try: - client = await get_bots_client(chat_id) + client, _ = await get_bots_client(chat_id, context.user_data) configs = await client.controllers.get_bot_controller_configs(bot_name) # Find the matching config @@ -558,6 +659,7 @@ async def show_controller_detail(update: Update, context: ContextTypes.DEFAULT_T context.user_data["current_controller_config"] = ctrl_config controller_type = ctrl_config.get("controller_name", "") is_grid_strike = "grid_strike" in controller_type.lower() + is_pmm_mister = "pmm_mister" in controller_type.lower() except Exception as e: logger.warning(f"Could not fetch controller config: {e}") @@ -570,13 +672,12 @@ async def show_controller_detail(update: Update, context: ContextTypes.DEFAULT_T lines = [ f"{status_emoji} *{escape_markdown_v2(controller_name)}*", "", - f"{pnl_emoji} `{escape_markdown_v2(f'{pnl:+.2f}')}` \\| 💰 R: `{escape_markdown_v2(f'{realized:+.2f}')}` \\| 📊 U: `{escape_markdown_v2(f'{unrealized:+.2f}')}`", - f"📦 Vol: `{escape_markdown_v2(vol_str)}`", + f"{pnl_emoji} `{escape_markdown_v2(f'{pnl:+.2f}')}` \\| 💰 R: `{escape_markdown_v2(f'{realized:+.2f}')}` \\| 📊 U: `{escape_markdown_v2(f'{unrealized:+.2f}')}` \\| 📦 `{escape_markdown_v2(vol_str)}`", ] # Add editable config section if available - if ctrl_config and is_grid_strike: - editable_fields = _get_editable_controller_fields(ctrl_config) + if ctrl_config and (is_grid_strike or is_pmm_mister): + editable_fields = _get_editable_controller_fields(ctrl_config, is_pmm_mister) # Store for input processing context.user_data["ctrl_editable_fields"] = editable_fields @@ -596,18 +697,43 @@ async def show_controller_detail(update: Update, context: ContextTypes.DEFAULT_T lines.append("") lines.append("✏️ _Send `key=value` to update_") - # Build keyboard + # Build keyboard - show Start or Stop based on controller status keyboard = [] + is_running = ctrl_status == "running" if is_grid_strike and ctrl_config: - keyboard.append([ - InlineKeyboardButton("📊 Chart", callback_data="bots:ctrl_chart"), - InlineKeyboardButton("🛑 Stop", callback_data="bots:stop_ctrl"), - ]) + # Grid Strike: show Chart + Stop/Start + if is_running: + keyboard.append([ + InlineKeyboardButton("📊 Chart", callback_data="bots:ctrl_chart"), + InlineKeyboardButton("🛑 Stop", callback_data="bots:stop_ctrl"), + ]) + else: + keyboard.append([ + InlineKeyboardButton("📊 Chart", callback_data="bots:ctrl_chart"), + InlineKeyboardButton("▶️ Start", callback_data="bots:start_ctrl"), + ]) + elif is_pmm_mister and ctrl_config: + # PMM Mister: Stop/Start + Clone + if is_running: + keyboard.append([ + InlineKeyboardButton("🛑 Stop", callback_data="bots:stop_ctrl"), + InlineKeyboardButton("📋 Clone", callback_data="bots:clone_ctrl"), + ]) + else: + keyboard.append([ + InlineKeyboardButton("▶️ Start", callback_data="bots:start_ctrl"), + InlineKeyboardButton("📋 Clone", callback_data="bots:clone_ctrl"), + ]) else: - keyboard.append([ - InlineKeyboardButton("🛑 Stop Controller", callback_data="bots:stop_ctrl"), - ]) + if is_running: + keyboard.append([ + InlineKeyboardButton("🛑 Stop Controller", callback_data="bots:stop_ctrl"), + ]) + else: + keyboard.append([ + InlineKeyboardButton("▶️ Start Controller", callback_data="bots:start_ctrl"), + ]) keyboard.append([ InlineKeyboardButton("⬅️ Back", callback_data="bots:back_to_bot"), @@ -707,7 +833,7 @@ async def handle_confirm_stop_controller(update: Update, context: ContextTypes.D ) try: - client = await get_bots_client(chat_id) + client, _ = await get_bots_client(chat_id, context.user_data) # Stop controller by setting manual_kill_switch=True result = await client.controllers.update_bot_controller_config( @@ -716,7 +842,10 @@ async def handle_confirm_stop_controller(update: Update, context: ContextTypes.D config={"manual_kill_switch": True} ) - keyboard = [[InlineKeyboardButton("⬅️ Back to Bot", callback_data="bots:back_to_bot")]] + keyboard = [[ + InlineKeyboardButton("▶️ Restart", callback_data="bots:start_ctrl"), + InlineKeyboardButton("⬅️ Back to Bot", callback_data="bots:back_to_bot"), + ]] await query.message.edit_text( f"*Controller Stopped*\n\n`{escape_markdown_v2(short_name)}`", @@ -734,6 +863,152 @@ async def handle_confirm_stop_controller(update: Update, context: ContextTypes.D ) +async def handle_start_controller(update: Update, context: ContextTypes.DEFAULT_TYPE) -> None: + """Start/restart current controller - show confirmation""" + query = update.callback_query + + bot_name = context.user_data.get("current_bot_name") + controllers = context.user_data.get("current_controllers", []) + controller_idx = context.user_data.get("current_controller_idx") + + if not bot_name or controller_idx is None or controller_idx >= len(controllers): + await query.answer("Context lost", show_alert=True) + return + + controller_name = controllers[controller_idx] + short_name = _shorten_controller_name(controller_name, 30) + + keyboard = [ + [ + InlineKeyboardButton("✅ Yes, Start", callback_data="bots:confirm_start_ctrl"), + InlineKeyboardButton("❌ Cancel", callback_data=f"bots:ctrl_idx:{controller_idx}"), + ], + ] + + message_text = ( + f"*Start Controller?*\n\n" + f"`{escape_markdown_v2(short_name)}`\n\n" + f"This will resume the controller\\." + ) + + # Handle photo messages (from controller detail view with chart) + if getattr(query.message, 'photo', None): + try: + await query.message.delete() + except Exception: + pass + await query.message.chat.send_message( + message_text, + parse_mode="MarkdownV2", + reply_markup=InlineKeyboardMarkup(keyboard) + ) + else: + await query.message.edit_text( + message_text, + parse_mode="MarkdownV2", + reply_markup=InlineKeyboardMarkup(keyboard) + ) + + +async def handle_confirm_start_controller(update: Update, context: ContextTypes.DEFAULT_TYPE) -> None: + """Actually start the controller by setting manual_kill_switch=False""" + query = update.callback_query + chat_id = update.effective_chat.id + + bot_name = context.user_data.get("current_bot_name") + controllers = context.user_data.get("current_controllers", []) + controller_idx = context.user_data.get("current_controller_idx") + + if not bot_name or controller_idx is None or controller_idx >= len(controllers): + await query.answer("Context lost", show_alert=True) + return + + controller_name = controllers[controller_idx] + short_name = _shorten_controller_name(controller_name, 30) + + await query.message.edit_text( + f"Starting `{escape_markdown_v2(short_name)}`\\.\\.\\.", + parse_mode="MarkdownV2" + ) + + try: + client, _ = await get_bots_client(chat_id, context.user_data) + + # Start controller by setting manual_kill_switch=False + result = await client.controllers.update_bot_controller_config( + bot_name=bot_name, + controller_name=controller_name, + config={"manual_kill_switch": False} + ) + + keyboard = [[InlineKeyboardButton("⬅️ Back to Bot", callback_data="bots:back_to_bot")]] + + await query.message.edit_text( + f"*Controller Started*\n\n`{escape_markdown_v2(short_name)}`", + parse_mode="MarkdownV2", + reply_markup=InlineKeyboardMarkup(keyboard) + ) + + except Exception as e: + logger.error(f"Error starting controller: {e}", exc_info=True) + keyboard = [[InlineKeyboardButton("⬅️ Back", callback_data=f"bots:ctrl_idx:{controller_idx}")]] + await query.message.edit_text( + f"*Failed*\n\nError: {escape_markdown_v2(str(e)[:100])}", + parse_mode="MarkdownV2", + reply_markup=InlineKeyboardMarkup(keyboard) + ) + + +async def handle_clone_controller(update: Update, context: ContextTypes.DEFAULT_TYPE) -> None: + """Clone current controller config - opens PMM wizard in review mode""" + query = update.callback_query + chat_id = update.effective_chat.id + + ctrl_config = context.user_data.get("current_controller_config") + if not ctrl_config: + await query.answer("No config to clone", show_alert=True) + return + + controller_type = ctrl_config.get("controller_name", "") + if "pmm_mister" not in controller_type.lower(): + await query.answer("Clone only supported for PMM Mister", show_alert=True) + return + + await query.answer("Cloning config...") + + try: + # Fetch existing configs to generate new ID (use get_all to find max number) + client, _ = await get_bots_client(chat_id, context.user_data) + configs = await client.controllers.get_all_controller_configs() + context.user_data["controller_configs_list"] = configs + + # Import generate_id from pmm_mister + from .controllers.pmm_mister import generate_id as pmm_generate_id + + # Create a copy of the config + new_config = dict(ctrl_config) + + # Generate new ID + new_config["id"] = pmm_generate_id(new_config, configs) + + # Set the config for the wizard + set_controller_config(context, new_config) + + # Set up wizard state for review mode + context.user_data["bots_state"] = "pmm_wizard" + context.user_data["pmm_wizard_step"] = "review" + context.user_data["pmm_wizard_message_id"] = query.message.message_id + context.user_data["pmm_wizard_chat_id"] = query.message.chat_id + + # Import and show the review step + from .controller_handlers import _pmm_show_review + await _pmm_show_review(context, chat_id, query.message.message_id, new_config) + + except Exception as e: + logger.error(f"Error cloning controller: {e}", exc_info=True) + await query.answer(f"Error: {str(e)[:50]}", show_alert=True) + + async def handle_quick_stop_controller(update: Update, context: ContextTypes.DEFAULT_TYPE, controller_idx: int) -> None: """Quick stop controller from bot detail view (no confirmation)""" query = update.callback_query @@ -752,7 +1027,7 @@ async def handle_quick_stop_controller(update: Update, context: ContextTypes.DEF await query.answer(f"Stopping {short_name}...") try: - client = await get_bots_client(chat_id) + client, _ = await get_bots_client(chat_id, context.user_data) # Stop controller by setting manual_kill_switch=True await client.controllers.update_bot_controller_config( @@ -761,8 +1036,9 @@ async def handle_quick_stop_controller(update: Update, context: ContextTypes.DEF config={"manual_kill_switch": True} ) - # Refresh bot detail view + # Clear caches to force fresh data fetch context.user_data.pop("current_bot_info", None) + context.user_data.pop("active_bots_data", None) await show_bot_detail(update, context, bot_name) except Exception as e: @@ -788,7 +1064,7 @@ async def handle_quick_start_controller(update: Update, context: ContextTypes.DE await query.answer(f"Starting {short_name}...") try: - client = await get_bots_client(chat_id) + client, _ = await get_bots_client(chat_id, context.user_data) # Start controller by setting manual_kill_switch=False await client.controllers.update_bot_controller_config( @@ -797,8 +1073,9 @@ async def handle_quick_start_controller(update: Update, context: ContextTypes.DE config={"manual_kill_switch": False} ) - # Refresh bot detail view + # Clear caches to force fresh data fetch context.user_data.pop("current_bot_info", None) + context.user_data.pop("active_bots_data", None) await show_bot_detail(update, context, bot_name) except Exception as e: @@ -811,7 +1088,7 @@ async def handle_quick_start_controller(update: Update, context: ContextTypes.DE # ============================================ async def show_controller_chart(update: Update, context: ContextTypes.DEFAULT_TYPE) -> None: - """Generate and show OHLC chart for grid strike controller""" + """Generate and show OHLC chart for controller""" query = update.callback_query chat_id = update.effective_chat.id controller_idx = context.user_data.get("current_controller_idx", 0) @@ -821,6 +1098,10 @@ async def show_controller_chart(update: Update, context: ContextTypes.DEFAULT_TY await query.answer("Config not found", show_alert=True) return + # Detect controller type + controller_type = ctrl_config.get("controller_name", "") + is_pmm_mister = "pmm_mister" in controller_type.lower() + # Show loading message short_name = _shorten_controller_name(ctrl_config.get("id", ""), 30) loading_text = f"⏳ *Generating chart\\.\\.\\.*" @@ -831,7 +1112,7 @@ async def show_controller_chart(update: Update, context: ContextTypes.DEFAULT_TY pass try: - client = await get_bots_client(chat_id) + client, _ = await get_bots_client(chat_id, context.user_data) connector = ctrl_config.get("connector_name", "") pair = ctrl_config.get("trading_pair", "") @@ -848,24 +1129,37 @@ async def show_controller_chart(update: Update, context: ContextTypes.DEFAULT_TY ) current_price = prices.get("prices", {}).get(pair) - # Generate chart - from .controllers.grid_strike import generate_chart + # Generate chart based on controller type + if is_pmm_mister: + from .controllers.pmm_mister import generate_chart + else: + from .controllers.grid_strike import generate_chart chart_bytes = generate_chart(ctrl_config, candles, current_price) if chart_bytes: - # Build caption - side_val = ctrl_config.get("side", 1) - side_str = "LONG" if side_val == 1 else "SHORT" + # Build caption based on controller type leverage = ctrl_config.get("leverage", 1) - start_p = ctrl_config.get("start_price", 0) - end_p = ctrl_config.get("end_price", 0) - limit_p = ctrl_config.get("limit_price", 0) - - caption = ( - f"📊 *{escape_markdown_v2(pair)}* \\| {escape_markdown_v2(side_str)} {leverage}x\n" - f"Grid: `{escape_markdown_v2(f'{start_p:.6g}')}` → `{escape_markdown_v2(f'{end_p:.6g}')}`\n" - f"Limit: `{escape_markdown_v2(f'{limit_p:.6g}')}`" - ) + + if is_pmm_mister: + buy_spreads = ctrl_config.get("buy_spreads", "0.0002,0.001") + sell_spreads = ctrl_config.get("sell_spreads", "0.0002,0.001") + take_profit = ctrl_config.get("take_profit", 0.0001) + caption = ( + f"📊 *{escape_markdown_v2(pair)}* \\| PMM {leverage}x\n" + f"Buy: `{escape_markdown_v2(buy_spreads)}` \\| Sell: `{escape_markdown_v2(sell_spreads)}`\n" + f"TP: `{escape_markdown_v2(f'{take_profit:.4%}')}`" + ) + else: + side_val = ctrl_config.get("side", 1) + side_str = "LONG" if side_val == 1 else "SHORT" + start_p = ctrl_config.get("start_price", 0) + end_p = ctrl_config.get("end_price", 0) + limit_p = ctrl_config.get("limit_price", 0) + caption = ( + f"📊 *{escape_markdown_v2(pair)}* \\| {escape_markdown_v2(side_str)} {leverage}x\n" + f"Grid: `{escape_markdown_v2(f'{start_p:.6g}')}` → `{escape_markdown_v2(f'{end_p:.6g}')}`\n" + f"Limit: `{escape_markdown_v2(f'{limit_p:.6g}')}`" + ) keyboard = [[ InlineKeyboardButton("⬅️ Back", callback_data=f"bots:ctrl_idx:{controller_idx}"), @@ -916,9 +1210,11 @@ async def show_controller_edit(update: Update, context: ContextTypes.DEFAULT_TYP return controller_name = ctrl_config.get("id", "") + controller_type = ctrl_config.get("controller_name", "") + is_pmm_mister = "pmm_mister" in controller_type.lower() # Define editable fields with their current values - editable_fields = _get_editable_controller_fields(ctrl_config) + editable_fields = _get_editable_controller_fields(ctrl_config, is_pmm_mister) # Store editable fields in context for input processing context.user_data["ctrl_editable_fields"] = editable_fields @@ -973,21 +1269,61 @@ async def show_controller_edit(update: Update, context: ContextTypes.DEFAULT_TYP context.user_data["ctrl_edit_message_id"] = query.message.message_id -def _get_editable_controller_fields(ctrl_config: Dict[str, Any]) -> Dict[str, Any]: +def _get_editable_controller_fields(ctrl_config: Dict[str, Any], is_pmm_mister: bool = False) -> Dict[str, Any]: """Extract editable fields from controller config""" - tp_cfg = ctrl_config.get("triple_barrier_config", {}) - take_profit = tp_cfg.get("take_profit", 0.0001) if isinstance(tp_cfg, dict) else 0.0001 - - return { - "start_price": ctrl_config.get("start_price", 0), - "end_price": ctrl_config.get("end_price", 0), - "limit_price": ctrl_config.get("limit_price", 0), - "total_amount_quote": ctrl_config.get("total_amount_quote", 0), - "max_open_orders": ctrl_config.get("max_open_orders", 3), - "max_orders_per_batch": ctrl_config.get("max_orders_per_batch", 1), - "min_spread_between_orders": ctrl_config.get("min_spread_between_orders", 0.0001), - "take_profit": take_profit, - } + if is_pmm_mister: + # PMM Mister editable fields - match review step order + return { + # Identification fields + "id": ctrl_config.get("id", ""), + "connector_name": ctrl_config.get("connector_name", ""), + "trading_pair": ctrl_config.get("trading_pair", ""), + "leverage": ctrl_config.get("leverage", 20), + "position_mode": ctrl_config.get("position_mode", "HEDGE"), + # Amount settings + "total_amount_quote": ctrl_config.get("total_amount_quote", 100), + "portfolio_allocation": ctrl_config.get("portfolio_allocation", 0.05), + # Base percentages + "target_base_pct": ctrl_config.get("target_base_pct", 0.5), + "min_base_pct": ctrl_config.get("min_base_pct", 0.4), + "max_base_pct": ctrl_config.get("max_base_pct", 0.6), + # Spreads and amounts + "buy_spreads": ctrl_config.get("buy_spreads", "0.0002,0.001"), + "sell_spreads": ctrl_config.get("sell_spreads", "0.0002,0.001"), + "buy_amounts_pct": ctrl_config.get("buy_amounts_pct", "1,2"), + "sell_amounts_pct": ctrl_config.get("sell_amounts_pct", "1,2"), + # Take profit settings + "take_profit": ctrl_config.get("take_profit", 0.0001), + "take_profit_order_type": ctrl_config.get("take_profit_order_type", "LIMIT_MAKER"), + "open_order_type": ctrl_config.get("open_order_type", "LIMIT"), + # Timing settings + "executor_refresh_time": ctrl_config.get("executor_refresh_time", 30), + "buy_cooldown_time": ctrl_config.get("buy_cooldown_time", 15), + "sell_cooldown_time": ctrl_config.get("sell_cooldown_time", 15), + "buy_position_effectivization_time": ctrl_config.get("buy_position_effectivization_time", 3600), + "sell_position_effectivization_time": ctrl_config.get("sell_position_effectivization_time", 3600), + # Distance settings + "min_buy_price_distance_pct": ctrl_config.get("min_buy_price_distance_pct", 0.003), + "min_sell_price_distance_pct": ctrl_config.get("min_sell_price_distance_pct", 0.003), + # Executor settings + "max_active_executors_by_level": ctrl_config.get("max_active_executors_by_level", 4), + "tick_mode": ctrl_config.get("tick_mode", False), + } + else: + # Grid Strike editable fields + tp_cfg = ctrl_config.get("triple_barrier_config", {}) + take_profit = tp_cfg.get("take_profit", 0.0001) if isinstance(tp_cfg, dict) else 0.0001 + + return { + "start_price": ctrl_config.get("start_price", 0), + "end_price": ctrl_config.get("end_price", 0), + "limit_price": ctrl_config.get("limit_price", 0), + "total_amount_quote": ctrl_config.get("total_amount_quote", 0), + "max_open_orders": ctrl_config.get("max_open_orders", 3), + "max_orders_per_batch": ctrl_config.get("max_orders_per_batch", 1), + "min_spread_between_orders": ctrl_config.get("min_spread_between_orders", 0.0001), + "take_profit": take_profit, + } async def handle_controller_set_field(update: Update, context: ContextTypes.DEFAULT_TYPE, field_name: str) -> None: @@ -1088,7 +1424,7 @@ async def handle_controller_confirm_set(update: Update, context: ContextTypes.DE await query.answer("Updating...") try: - client = await get_bots_client(chat_id) + client, _ = await get_bots_client(chat_id, context.user_data) # Build config update if field_name == "take_profit": @@ -1256,7 +1592,7 @@ async def process_controller_field_input(update: Update, context: ContextTypes.D update_config[key] = value try: - client = await get_bots_client(chat_id) + client, _ = await get_bots_client(chat_id, context.user_data) # Apply the update result = await client.controllers.update_bot_controller_config( @@ -1265,7 +1601,16 @@ async def process_controller_field_input(update: Update, context: ContextTypes.D config=update_config ) - if result.get("status") == "success": + # Check for success - API may return status="success" or message containing "successfully" + result_status = result.get("status", "") + result_message = result.get("message", "") + is_success = ( + result_status == "success" or + "successfully" in str(result_message).lower() or + "updated" in str(result_message).lower() + ) + + if is_success: # Update local config cache for key, value in updates.items(): if key == "take_profit": @@ -1302,7 +1647,7 @@ async def process_controller_field_input(update: Update, context: ContextTypes.D reply_markup=InlineKeyboardMarkup(keyboard) ) else: - error_msg = result.get("message", "Update failed") + error_msg = result_message or "Update failed" keyboard = [[InlineKeyboardButton("⬅️ Back", callback_data=f"bots:ctrl_idx:{controller_idx}")]] if message_id: @@ -1385,7 +1730,7 @@ async def handle_confirm_stop_bot(update: Update, context: ContextTypes.DEFAULT_ ) try: - client = await get_bots_client(chat_id) + client, _ = await get_bots_client(chat_id, context.user_data) result = await client.bot_orchestration.stop_and_archive_bot( bot_name=bot_name, @@ -1457,7 +1802,7 @@ async def handle_refresh_controller(update: Update, context: ContextTypes.DEFAUL if bot_name: try: - client = await get_bots_client(chat_id) + client, _ = await get_bots_client(chat_id, context.user_data) fresh_data = await client.bot_orchestration.get_active_bots_status() if isinstance(fresh_data, dict) and "data" in fresh_data: bot_info = fresh_data.get("data", {}).get(bot_name) @@ -1477,8 +1822,8 @@ async def handle_refresh_controller(update: Update, context: ContextTypes.DEFAUL # VIEW LOGS # ============================================ -def _format_log_entry(log, max_msg_len: int = 55) -> str: - """Format a log entry with timestamp""" +def _format_log_entry(log) -> str: + """Format a log entry with timestamp - full message, no truncation""" if isinstance(log, dict): timestamp = log.get("timestamp", log.get("time", log.get("ts", ""))) msg = log.get("msg", log.get("message", str(log))) @@ -1486,6 +1831,9 @@ def _format_log_entry(log, max_msg_len: int = 55) -> str: timestamp = "" msg = str(log) + # Escape backticks in log messages to prevent breaking code blocks + msg = str(msg).replace("`", "'") + # Extract time portion (HH:MM:SS) from timestamp time_str = "" if timestamp: @@ -1502,16 +1850,13 @@ def _format_log_entry(log, max_msg_len: int = 55) -> str: elif len(ts) >= 8 and ":" in ts: time_str = ts[:8] - # Truncate message - msg = msg[:max_msg_len] if len(msg) > max_msg_len else msg - if time_str: return f"[{time_str}] {msg}" return msg async def show_bot_logs(update: Update, context: ContextTypes.DEFAULT_TYPE) -> None: - """Show recent logs for current bot with timestamps""" + """Show recent logs for current bot with timestamps - full messages""" query = update.callback_query bot_name = context.user_data.get("current_bot_name") @@ -1524,10 +1869,8 @@ async def show_bot_logs(update: Update, context: ContextTypes.DEFAULT_TYPE) -> N general_logs = bot_info.get("general_logs", []) error_logs = bot_info.get("error_logs", []) - display_name = bot_name[:25] + "..." if len(bot_name) > 25 else bot_name - lines = [ - f"*Logs: `{escape_markdown_v2(display_name)}`*", + f"*Logs: `{escape_markdown_v2(bot_name)}`*", "", ] @@ -1535,8 +1878,8 @@ async def show_bot_logs(update: Update, context: ContextTypes.DEFAULT_TYPE) -> N if error_logs: lines.append("*🔴 Errors:*") lines.append("```") - for log in error_logs[:5]: - entry = _format_log_entry(log, 50) + for log in error_logs[:10]: + entry = _format_log_entry(log) lines.append(entry) lines.append("```") lines.append("") @@ -1545,8 +1888,8 @@ async def show_bot_logs(update: Update, context: ContextTypes.DEFAULT_TYPE) -> N if general_logs: lines.append("*📋 Recent Activity:*") lines.append("```") - for log in general_logs[-10:]: - entry = _format_log_entry(log, 50) + for log in general_logs[-20:]: + entry = _format_log_entry(log) lines.append(entry) lines.append("```") else: @@ -1556,7 +1899,11 @@ async def show_bot_logs(update: Update, context: ContextTypes.DEFAULT_TYPE) -> N message = "\n".join(lines) if len(message) > 4000: - message = message[:4000] + "\n\\.\\.\\." + message = message[:4000] + # Check if we have an unclosed code block (odd number of ```) + if message.count("```") % 2 == 1: + message += "\n```" + message += "\n\\.\\.\\." await query.message.edit_text( message, diff --git a/handlers/cex/__init__.py b/handlers/cex/__init__.py index 2b18129..71cefee 100644 --- a/handlers/cex/__init__.py +++ b/handlers/cex/__init__.py @@ -13,7 +13,7 @@ from telegram import Update from telegram.ext import ContextTypes, CallbackQueryHandler, MessageHandler, filters -from utils.auth import restricted +from utils.auth import restricted, hummingbot_api_required from utils.telegram_formatters import format_error_message from handlers import clear_all_input_states @@ -34,7 +34,35 @@ def clear_cex_state(context: ContextTypes.DEFAULT_TYPE) -> None: context.user_data.pop("trade_menu_chat_id", None) +async def _handle_switch_to_dex( + update: Update, + context: ContextTypes.DEFAULT_TYPE, + network_id: str +) -> None: + """Switch from CEX to DEX trading""" + from handlers.config.user_preferences import ( + set_last_trade_connector, + get_dex_swap_defaults, + ) + from handlers.dex.swap import handle_swap as dex_handle_swap + + # Clear CEX state + clear_cex_state(context) + + # Save preference (DEX stores network ID) + set_last_trade_connector(context.user_data, "dex", network_id) + + # Set up DEX swap params with the selected network + defaults = get_dex_swap_defaults(context.user_data) + defaults["network"] = network_id + context.user_data["swap_params"] = defaults + + # Route to DEX swap menu + await dex_handle_swap(update, context) + + @restricted +@hummingbot_api_required async def trade_command(update: Update, context: ContextTypes.DEFAULT_TYPE) -> None: """ Handle /trade command - CEX trading interface with order books @@ -63,7 +91,7 @@ async def trade_command(update: Update, context: ContextTypes.DEFAULT_TYPE) -> N @restricted async def cex_callback_handler(update: Update, context: ContextTypes.DEFAULT_TYPE) -> None: """Handle inline button callbacks for CEX trading operations""" - from .menu import show_cex_menu, cancel_cex_loading_task + from .menu import cancel_cex_loading_task from .trade import ( handle_trade, handle_trade_refresh, @@ -79,6 +107,7 @@ async def cex_callback_handler(update: Update, context: ContextTypes.DEFAULT_TYP handle_trade_toggle_pos_mode, handle_trade_get_quote, handle_trade_execute, + handle_trade_pair_select, handle_close, ) from .orders import handle_search_orders, handle_cancel_order, handle_confirm_cancel_order @@ -122,6 +151,9 @@ async def cex_callback_handler(update: Update, context: ContextTypes.DEFAULT_TYP await handle_trade_connector_select(update, context, connector_name) elif action == "trade_set_pair": await handle_trade_set_pair(update, context) + elif action.startswith("trade_pair_select_"): + trading_pair = action.replace("trade_pair_select_", "") + await handle_trade_pair_select(update, context, trading_pair) elif action == "trade_set_amount": await handle_trade_set_amount(update, context) elif action == "trade_set_price": @@ -139,8 +171,6 @@ async def cex_callback_handler(update: Update, context: ContextTypes.DEFAULT_TYP # Orders elif action == "search_orders": - await handle_search_orders(update, context, status="OPEN") - elif action == "search_all": await handle_search_orders(update, context, status="ALL") elif action == "search_filled": await handle_search_orders(update, context, status="FILLED") @@ -170,6 +200,15 @@ async def cex_callback_handler(update: Update, context: ContextTypes.DEFAULT_TYP elif action == "close": await handle_close(update, context) + # Switch to DEX + elif action.startswith("switch_dex_"): + network_id = action.replace("switch_dex_", "") + await _handle_switch_to_dex(update, context, network_id) + + # No-op for separator buttons + elif action == "noop": + pass + else: await query.message.reply_text(f"Unknown action: {action}") diff --git a/handlers/cex/_shared.py b/handlers/cex/_shared.py index 71ecf29..571acc3 100644 --- a/handlers/cex/_shared.py +++ b/handlers/cex/_shared.py @@ -219,12 +219,13 @@ def get_cex_connectors(connectors: Dict[str, Any]) -> List[str]: # BALANCE FETCHING # ============================================ -async def fetch_cex_balances(client, account_name: str) -> Dict[str, List[Dict[str, Any]]]: +async def fetch_cex_balances(client, account_name: str, refresh: bool = False) -> Dict[str, List[Dict[str, Any]]]: """Fetch balances for all CEX connectors. Args: client: API client account_name: Account name to fetch balances for + refresh: If True, force fresh fetch from exchange (slow). Default False uses cached data. Returns: Dict of connector_name -> list of balances @@ -245,7 +246,7 @@ async def fetch_cex_balances(client, account_name: str) -> Dict[str, List[Dict[s portfolio_state = await client.portfolio.get_state( account_names=[account_name], connector_names=cex_connectors, - refresh=True, + refresh=refresh, ) # portfolio.get_state returns {account_name: {connector_name: [balances]}} @@ -272,7 +273,8 @@ async def get_cex_balances( user_data: dict, client, account_name: str, - ttl: int = DEFAULT_CACHE_TTL + ttl: int = DEFAULT_CACHE_TTL, + refresh: bool = False ) -> Dict[str, List[Dict[str, Any]]]: """Get CEX balances with caching. @@ -281,6 +283,7 @@ async def get_cex_balances( client: API client account_name: Account name ttl: Cache TTL in seconds + refresh: If True, force fresh fetch from exchange (slow). Default False uses cached data. Returns: Dict of connector_name -> list of balances @@ -292,7 +295,8 @@ async def get_cex_balances( fetch_cex_balances, ttl, client, - account_name + account_name, + refresh ) @@ -543,7 +547,8 @@ async def get_available_cex_connectors( user_data: dict, client, account_name: str = "master_account", - ttl: int = 300 # 5 min cache + ttl: int = 300, # 5 min cache + server_name: str = "default" ) -> List[str]: """Get available CEX connectors with caching. @@ -552,11 +557,12 @@ async def get_available_cex_connectors( client: API client account_name: Account name to check credentials for ttl: Cache TTL in seconds + server_name: Server name to include in cache key (prevents cross-server cache pollution) Returns: List of available CEX connector names """ - cache_key = f"available_cex_connectors_{account_name}" + cache_key = f"available_cex_connectors_{server_name}_{account_name}" return await cached_call( user_data, cache_key, @@ -582,3 +588,173 @@ def clear_cex_state(context) -> None: context.user_data.pop("place_order_params", None) context.user_data.pop("current_positions", None) context.user_data.pop("current_orders", None) + + +# ============================================ +# TRADING PAIR VALIDATION +# ============================================ + +def _calculate_similarity(s1: str, s2: str) -> float: + """Calculate similarity ratio between two strings using Levenshtein-like approach. + + Returns: + Similarity score between 0 and 1 (1 = exact match) + """ + s1, s2 = s1.upper(), s2.upper() + + if s1 == s2: + return 1.0 + + # Check for partial matches (base token match) + s1_parts = s1.replace("_", "-").split("-") + s2_parts = s2.replace("_", "-").split("-") + + # Exact base token match gets high score + if s1_parts and s2_parts and s1_parts[0] == s2_parts[0]: + # Same base token, check quote similarity + if len(s1_parts) > 1 and len(s2_parts) > 1: + # Both have quote tokens + quote_sim = _levenshtein_ratio(s1_parts[1], s2_parts[1]) + return 0.7 + (0.3 * quote_sim) + return 0.7 + + # Fall back to full string Levenshtein ratio + return _levenshtein_ratio(s1, s2) + + +def _levenshtein_ratio(s1: str, s2: str) -> float: + """Calculate Levenshtein similarity ratio between two strings.""" + if not s1 and not s2: + return 1.0 + if not s1 or not s2: + return 0.0 + + len1, len2 = len(s1), len(s2) + + # Create distance matrix + dp = [[0] * (len2 + 1) for _ in range(len1 + 1)] + + for i in range(len1 + 1): + dp[i][0] = i + for j in range(len2 + 1): + dp[0][j] = j + + for i in range(1, len1 + 1): + for j in range(1, len2 + 1): + cost = 0 if s1[i - 1] == s2[j - 1] else 1 + dp[i][j] = min( + dp[i - 1][j] + 1, # deletion + dp[i][j - 1] + 1, # insertion + dp[i - 1][j - 1] + cost # substitution + ) + + distance = dp[len1][len2] + max_len = max(len1, len2) + return 1 - (distance / max_len) + + +def find_similar_trading_pairs( + input_pair: str, + available_pairs: List[str], + limit: int = 4, + min_similarity: float = 0.3 +) -> List[str]: + """Find trading pairs similar to the input. + + Args: + input_pair: User's input trading pair + available_pairs: List of available trading pairs from trading rules + limit: Maximum number of suggestions to return + min_similarity: Minimum similarity score to include (0-1) + + Returns: + List of similar trading pairs, sorted by similarity (most similar first) + """ + input_normalized = input_pair.upper().replace("_", "-").replace("/", "-") + + # Calculate similarity for each available pair + scored_pairs = [] + for pair in available_pairs: + pair_normalized = pair.upper().replace("_", "-") + score = _calculate_similarity(input_normalized, pair_normalized) + if score >= min_similarity: + scored_pairs.append((pair, score)) + + # Sort by score (descending) and return top matches + scored_pairs.sort(key=lambda x: x[1], reverse=True) + return [pair for pair, _ in scored_pairs[:limit]] + + +async def validate_trading_pair( + user_data: dict, + client, + connector_name: str, + trading_pair: str, + ttl: int = 300 +) -> tuple[bool, Optional[str], List[str]]: + """Validate that a trading pair exists on a connector. + + Args: + user_data: context.user_data dict + client: API client + connector_name: Name of the connector + trading_pair: Trading pair to validate + ttl: Cache TTL for trading rules + + Returns: + Tuple of (is_valid, error_message, suggestions) + - is_valid: True if the pair exists + - error_message: Error message if invalid, None if valid + - suggestions: List of similar trading pairs if invalid, empty if valid + """ + # Normalize input + pair_normalized = trading_pair.upper().replace("_", "-").replace("/", "-") + + # Get trading rules for the connector + trading_rules = await get_trading_rules(user_data, client, connector_name, ttl) + + if not trading_rules: + # No rules available, can't validate - allow through + logger.warning(f"No trading rules available for {connector_name}, skipping validation") + return True, None, [] + + # Get all available pairs + available_pairs = list(trading_rules.keys()) + + # Check for exact match (case-insensitive, normalized) + available_normalized = {p.upper().replace("_", "-"): p for p in available_pairs} + if pair_normalized in available_normalized: + # Return the correctly formatted pair from the exchange + return True, None, [] + + # Pair not found - find suggestions + suggestions = find_similar_trading_pairs(pair_normalized, available_pairs, limit=4) + + error_msg = f"Trading pair '{trading_pair}' not found on {connector_name}" + + return False, error_msg, suggestions + + +def get_correct_pair_format( + trading_rules: Dict[str, Dict[str, Any]], + input_pair: str +) -> Optional[str]: + """Get the correctly formatted trading pair from trading rules. + + Args: + trading_rules: Dict of trading_pair -> rules + input_pair: User's input trading pair + + Returns: + Correctly formatted pair if found, None otherwise + """ + if not trading_rules: + return None + + pair_normalized = input_pair.upper().replace("_", "-").replace("/", "-") + + for pair in trading_rules.keys(): + if pair.upper().replace("_", "-") == pair_normalized: + return pair + + return None diff --git a/handlers/cex/orders.py b/handlers/cex/orders.py index 7f9e849..067fe55 100644 --- a/handlers/cex/orders.py +++ b/handlers/cex/orders.py @@ -12,63 +12,108 @@ logger = logging.getLogger(__name__) -async def handle_search_orders(update: Update, context: ContextTypes.DEFAULT_TYPE, status: str = "OPEN") -> None: - """Handle search orders operation""" +async def handle_search_orders(update: Update, context: ContextTypes.DEFAULT_TYPE, status: str = "ALL") -> None: + """Handle search orders operation + + Status options: + - ALL: All orders with open orders section at top (default) + - FILLED: Only filled orders + - CANCELLED: Only cancelled orders + """ try: - from servers import get_client + from config_manager import get_client + import asyncio chat_id = update.effective_chat.id - client = await get_client(chat_id) - - # Search for orders with specified status - if status == "OPEN": - # Use get_active_orders for real-time open orders from exchange - result = await client.trading.get_active_orders(limit=100) - status_label = "Open Orders" - elif status == "ALL": - result = await client.trading.search_orders(limit=100) - status_label = "All Orders" - else: - result = await client.trading.search_orders( - status=status, - limit=100 - ) - status_label = f"{status.title()} Orders" + client = await get_client(chat_id, context=context) + + keyboard = [] + + if status == "ALL": + # Fetch open orders (from active endpoint) and all orders in parallel + async def get_open(): + try: + result = await client.trading.get_active_orders(limit=50) + return result.get("data", []) + except Exception as e: + logger.warning(f"Error fetching open orders: {e}") + return [] + + async def get_all(): + try: + result = await client.trading.search_orders(limit=100) + return result.get("data", []) + except Exception as e: + logger.warning(f"Error fetching all orders: {e}") + return [] + + open_orders, all_orders = await asyncio.gather(get_open(), get_all()) + + # Store open orders for cancel operations + context.user_data["current_orders"] = open_orders + + # Build set of truly open order IDs for status correction + open_order_ids = { + o.get('client_order_id') or o.get('order_id') + for o in open_orders + } + + # Correct stale "OPEN" status in all_orders based on actual open orders + for order in all_orders: + order_id = order.get('client_order_id') or order.get('order_id') + if order.get('status') == 'OPEN' and order_id not in open_order_ids: + order['status'] = 'FILLED' # Most likely filled + + # Build message with sections + sections = [] + + # Open orders section with cancel buttons + if open_orders: + from utils.telegram_formatters import format_orders_table + open_table = format_orders_table(open_orders[:10]) + sections.append(f"*🟢 Open Orders* \\({len(open_orders)}\\)\n```\n{open_table}\n```") + + # Cancel buttons for open orders + for i, order in enumerate(open_orders[:3]): + pair = order.get('trading_pair', 'N/A') + side = order.get('trade_type', order.get('side', 'N/A')) + button_label = f"❌ Cancel {pair} {side}" + keyboard.append([InlineKeyboardButton(button_label, callback_data=f"cex:cancel_order:{i}")]) + else: + sections.append("*🟢 Open Orders*\n_No open orders_") - orders = result.get("data", []) + # All orders section + if all_orders: + from utils.telegram_formatters import format_orders_table + all_table = format_orders_table(all_orders) + sections.append(f"\n*📋 All Orders* \\({len(all_orders)}\\)\n```\n{all_table}\n```") - # Store orders for cancel operations - context.user_data["current_orders"] = orders + message = "\n".join(sections) - if not orders: - message = f"🔍 *{escape_markdown_v2(status_label)}*\n\nNo orders found\\." - keyboard = [] else: - from utils.telegram_formatters import format_orders_table - orders_table = format_orders_table(orders) - message = f"🔍 *{escape_markdown_v2(status_label)}* \\({len(orders)} found\\)\n\n```\n{orders_table}\n```" - - # Build keyboard with cancel buttons for open orders - keyboard = [] - if status == "OPEN": - for i, order in enumerate(orders[:5]): - pair = order.get('trading_pair', 'N/A') - side = order.get('trade_type', order.get('side', 'N/A')) - order_type = order.get('order_type', 'N/A') - button_label = f"❌ Cancel {pair} {side} {order_type}" - keyboard.append([InlineKeyboardButton(button_label, callback_data=f"cex:cancel_order:{i}")]) + # Specific status filter (FILLED, CANCELLED) + result = await client.trading.search_orders(status=status, limit=100) + orders = result.get("data", []) + context.user_data["current_orders"] = [] + status_label = f"{status.title()} Orders" + emoji = "✅" if status == "FILLED" else "❌" if status == "CANCELLED" else "📋" - if len(orders) > 5: - keyboard.append([InlineKeyboardButton("⋯ More Orders", callback_data="cex:orders_list")]) + if not orders: + message = f"{emoji} *{escape_markdown_v2(status_label)}*\n\n_No orders found_" + else: + from utils.telegram_formatters import format_orders_table + orders_table = format_orders_table(orders) + message = f"{emoji} *{escape_markdown_v2(status_label)}* \\({len(orders)}\\)\n\n```\n{orders_table}\n```" + + # Filter buttons - highlight current filter + def btn(label, action, current): + prefix = "• " if current else "" + return InlineKeyboardButton(f"{prefix}{label}", callback_data=f"cex:{action}") - # Filter buttons - keyboard.append([ - InlineKeyboardButton("Open Orders", callback_data="cex:search_orders"), - InlineKeyboardButton("All Orders", callback_data="cex:search_all"), - ]) keyboard.append([ - InlineKeyboardButton("Filled", callback_data="cex:search_filled"), - InlineKeyboardButton("Cancelled", callback_data="cex:search_cancelled") + btn("All", "search_orders", status == "ALL"), + btn("Filled", "search_filled", status == "FILLED"), + btn("Cancelled", "search_cancelled", status == "CANCELLED"), ]) keyboard.append([InlineKeyboardButton("« Back", callback_data="cex:trade")]) @@ -162,10 +207,10 @@ async def handle_confirm_cancel_order(update: Update, context: ContextTypes.DEFA if not client_order_id: raise ValueError("Order ID not found") - from servers import get_client + from config_manager import get_client chat_id = update.effective_chat.id - client = await get_client(chat_id) + client = await get_client(chat_id, context=context) # Cancel the order result = await client.trading.cancel_order( diff --git a/handlers/cex/positions.py b/handlers/cex/positions.py index 4d40d1d..884ac2d 100644 --- a/handlers/cex/positions.py +++ b/handlers/cex/positions.py @@ -15,10 +15,10 @@ async def handle_positions(update: Update, context: ContextTypes.DEFAULT_TYPE) -> None: """Handle get positions operation""" try: - from servers import get_client + from config_manager import get_client chat_id = update.effective_chat.id - client = await get_client(chat_id) + client = await get_client(chat_id, context=context) # Get all positions result = await client.trading.get_positions(limit=100) @@ -236,10 +236,10 @@ async def handle_confirm_close_position(update: Update, context: ContextTypes.DE # Determine the opposite side to close the position close_side = "SELL" if side in ["LONG", "BUY"] else "BUY" - from servers import get_client + from config_manager import get_client chat_id = update.effective_chat.id - client = await get_client(chat_id) + client = await get_client(chat_id, context=context) # Place market order to close position result = await client.trading.place_order( diff --git a/handlers/cex/trade.py b/handlers/cex/trade.py index a6a2e6d..be5a101 100644 --- a/handlers/cex/trade.py +++ b/handlers/cex/trade.py @@ -10,7 +10,6 @@ import asyncio import logging -from decimal import Decimal from telegram import Update, InlineKeyboardButton, InlineKeyboardMarkup from telegram.ext import ContextTypes @@ -18,10 +17,11 @@ from handlers.config.user_preferences import ( get_clob_account, get_clob_order_defaults, - get_clob_last_order, set_clob_last_order, + set_last_trade_connector, + get_all_enabled_networks, ) -from servers import get_client +from config_manager import get_client from ._shared import ( get_cached, set_cached, @@ -30,6 +30,8 @@ get_positions, get_trading_rules, get_available_cex_connectors, + validate_trading_pair, + get_correct_pair_format, ) from handlers.dex._shared import format_relative_time @@ -288,7 +290,8 @@ def _build_trade_keyboard(params: dict, is_perpetual: bool = False, def _build_trade_menu_text(user_data: dict, params: dict, balances: dict = None, positions: list = None, orders: list = None, trading_rules: dict = None, - current_price: float = None, quote_data: dict = None) -> str: + current_price: float = None, quote_data: dict = None, + server_name: str = None, server_status: str = "online") -> str: """Build the trade menu text content (swap.py style)""" connector = params.get('connector', 'binance_perpetual') trading_pair = params.get('trading_pair', 'BTC-USDT') @@ -299,8 +302,12 @@ def _build_trade_menu_text(user_data: dict, params: dict, else: base_token, quote_token = trading_pair, 'USDT' - # Build header - help_text = r"📝 *Trade*" + "\n\n" + # Build header with server indicator + if server_name: + status_emoji = {"online": "🟢", "offline": "🔴", "auth_error": "🟠", "error": "⚠️"}.get(server_status, "🟢") + help_text = f"📝 *Trade* \\| _Server: {escape_markdown_v2(server_name)} {status_emoji}_\n\n" + else: + help_text = r"📝 *Trade*" + "\n\n" # Show balances section (with loading placeholder) help_text += r"━━━ Balance ━━━" + "\n" @@ -434,12 +441,40 @@ async def show_trade_menu(update: Update, context: ContextTypes.DEFAULT_TYPE, send_new: bool = False, auto_fetch: bool = True, quote_data: dict = None) -> None: """Display the unified trade menu with balances and data""" + from config_manager import get_config_manager + from handlers.config.user_preferences import get_active_server + params = context.user_data.get("trade_params", {}) connector = params.get("connector", "binance_perpetual") trading_pair = params.get("trading_pair", "BTC-USDT") account = get_clob_account(context.user_data) is_perpetual = _is_perpetual_connector(connector) + # Get current server name and status - with proper access control + cm = get_config_manager() + user_id = context.user_data.get('_user_id') + + # Only use servers the user has access to + if user_id: + accessible_servers = cm.get_accessible_servers(user_id) + all_servers = cm.list_servers() + enabled_accessible = [s for s in accessible_servers if all_servers.get(s, {}).get("enabled", True)] + else: + logger.warning("show_trade_menu called without user_id - cannot verify server access") + all_servers = cm.list_servers() + enabled_accessible = [name for name, cfg in all_servers.items() if cfg.get("enabled", True)] + + preferred = get_active_server(context.user_data) + # Only use preferred if user has access to it + server_name = preferred if preferred and preferred in enabled_accessible else (enabled_accessible[0] if enabled_accessible else None) + server_status = "online" + if server_name: + try: + server_status_info = await cm.check_server_status(server_name) + server_status = server_status_info.get("status", "online") + except Exception: + pass + # Try to get cached data balances = get_cached(context.user_data, f"cex_balances_{account}", ttl=60) positions = get_cached(context.user_data, f"positions_{connector}", ttl=60) if is_perpetual else None @@ -460,7 +495,8 @@ async def show_trade_menu(update: Update, context: ContextTypes.DEFAULT_TYPE, # Build text and keyboard help_text = _build_trade_menu_text( - context.user_data, params, balances, positions, orders, trading_rules, current_price, quote_data + context.user_data, params, balances, positions, orders, trading_rules, current_price, quote_data, + server_name=server_name, server_status=server_status ) keyboard = _build_trade_keyboard(params, is_perpetual, leverage, position_mode) reply_markup = InlineKeyboardMarkup(keyboard) @@ -558,7 +594,7 @@ async def _fetch_trade_data_background( is_perpetual = _is_perpetual_connector(connector) try: - client = await get_client(chat_id) + client = await get_client(chat_id, context=context) except Exception as e: logger.warning(f"Could not get client for trade data: {e}") return @@ -566,7 +602,9 @@ async def _fetch_trade_data_background( # Define safe fetch functions async def fetch_balances_safe(): try: - return await get_cex_balances(context.user_data, client, account) + # Check if force refresh is needed (e.g., after trade execution) + force_refresh = context.user_data.pop("_force_cex_balance_refresh", False) + return await get_cex_balances(context.user_data, client, account, refresh=force_refresh) except Exception as e: logger.warning(f"Could not fetch balances: {e}") # Cache empty dict so display shows "No balance found" instead of "Loading..." @@ -753,7 +791,7 @@ async def handle_trade_get_quote(update: Update, context: ContextTypes.DEFAULT_T volume = float(str(amount).replace("$", "")) chat_id = update.effective_chat.id - client = await get_client(chat_id) + client = await get_client(chat_id, context=context) # If amount is in USD, we need to convert to base token volume if "$" in str(amount): @@ -882,35 +920,116 @@ async def handle_trade_toggle_position(update: Update, context: ContextTypes.DEF async def handle_trade_set_connector(update: Update, context: ContextTypes.DEFAULT_TYPE) -> None: - """Show available CEX connectors for selection""" - help_text = r"🔌 *Select Connector*" + """Show available CEX connectors and DEX networks for selection""" + from config_manager import get_config_manager + from handlers.config.user_preferences import get_active_server + chat_id = update.effective_chat.id keyboard = [] + # Get server name for cache keying - with proper access control + cm = get_config_manager() + user_id = context.user_data.get('_user_id') + + # Only use servers the user has access to + if user_id: + accessible_servers = cm.get_accessible_servers(user_id) + all_servers = cm.list_servers() + enabled_accessible = [s for s in accessible_servers if all_servers.get(s, {}).get("enabled", True)] + else: + all_servers = cm.list_servers() + enabled_accessible = [name for name, cfg in all_servers.items() if cfg.get("enabled", True)] + + preferred = get_active_server(context.user_data) + server_name = preferred if preferred and preferred in enabled_accessible else (enabled_accessible[0] if enabled_accessible else "default") + try: - chat_id = update.effective_chat.id - client = await get_client(chat_id) - cex_connectors = await get_available_cex_connectors(context.user_data, client) - - # Build buttons (2 per row) - row = [] - for connector in cex_connectors: - row.append(InlineKeyboardButton( - connector, - callback_data=f"cex:trade_connector_{connector}" - )) - if len(row) == 2: + client = await get_client(chat_id, context=context) + + # Fetch CEX connectors and DEX networks + from handlers import is_gateway_network + + async def get_cex_connectors_list(): + try: + return await get_available_cex_connectors(context.user_data, client, server_name=server_name) + except Exception as e: + logger.warning(f"Could not fetch CEX connectors: {e}") + return [] + + async def get_dex_networks(): + try: + response = await client.gateway.list_networks() + all_networks = response.get('networks', []) + + # Filter to only show networks enabled in user's wallet configurations + enabled_networks = get_all_enabled_networks(context.user_data) + if enabled_networks is None: + # No wallets configured, show all networks + return all_networks + + # Filter networks to only those enabled for user's wallets + filtered = [] + for network_item in all_networks: + if isinstance(network_item, dict): + network_id = network_item.get('network_id') or network_item.get('id') or str(network_item) + else: + network_id = str(network_item) + + if network_id in enabled_networks: + filtered.append(network_item) + + return filtered + except Exception as e: + logger.warning(f"Could not fetch DEX networks: {e}") + return [] + + cex_connectors, dex_networks = await asyncio.gather( + get_cex_connectors_list(), + get_dex_networks() + ) + + # CEX section + if cex_connectors: + keyboard.append([InlineKeyboardButton("━━ CEX ━━", callback_data="cex:noop")]) + row = [] + for connector in cex_connectors: + row.append(InlineKeyboardButton( + connector, + callback_data=f"cex:trade_connector_{connector}" + )) + if len(row) == 2: + keyboard.append(row) + row = [] + if row: keyboard.append(row) - row = [] - if row: - keyboard.append(row) - if not cex_connectors: - help_text += "\n\n_No CEX connectors available_" + # DEX section + if dex_networks: + keyboard.append([InlineKeyboardButton("━━ DEX ━━", callback_data="cex:noop")]) + row = [] + for network_item in dex_networks: + if isinstance(network_item, dict): + network_id = network_item.get('network_id') or network_item.get('id') or str(network_item) + else: + network_id = str(network_item) + row.append(InlineKeyboardButton( + network_id, + callback_data=f"cex:switch_dex_{network_id}" + )) + if len(row) == 2: + keyboard.append(row) + row = [] + if row: + keyboard.append(row) + + if not cex_connectors and not dex_networks: + help_text = r"🔄 *Select Connector*" + "\n\n" + r"_No connectors available\._" + else: + help_text = r"🔄 *Select Connector*" except Exception as e: logger.error(f"Error fetching connectors: {e}", exc_info=True) - help_text += "\n\n_Could not fetch connectors_" + help_text = r"🔄 *Select Connector*" + "\n\n" + r"_Could not fetch connectors_" keyboard.append([InlineKeyboardButton("« Back", callback_data="cex:trade")]) reply_markup = InlineKeyboardMarkup(keyboard) @@ -927,6 +1046,9 @@ async def handle_trade_connector_select(update: Update, context: ContextTypes.DE params = context.user_data.get("trade_params", {}) params["connector"] = connector_name + # Save unified preference for /trade command + set_last_trade_connector(context.user_data, "cex", connector_name) + _invalidate_trade_cache(context.user_data) invalidate_cache(context.user_data, "balances", "positions", "trading_rules") context.user_data["cex_state"] = "trade" @@ -1044,7 +1166,7 @@ async def handle_trade_toggle_pos_mode(update: Update, context: ContextTypes.DEF try: chat_id = update.effective_chat.id - client = await get_client(chat_id) + client = await get_client(chat_id, context=context) # Get current mode current_mode = context.user_data.get(_get_position_mode_cache_key(connector), "HEDGE") @@ -1094,7 +1216,7 @@ async def handle_trade_execute(update: Update, context: ContextTypes.DEFAULT_TYP raise ValueError("Price required for LIMIT orders") chat_id = update.effective_chat.id - client = await get_client(chat_id) + client = await get_client(chat_id, context=context) # Handle USD amount is_quote_amount = "$" in str(amount) @@ -1120,8 +1242,9 @@ async def handle_trade_execute(update: Update, context: ContextTypes.DEFAULT_TYP position_action=position_action, ) - # Invalidate cache + # Invalidate cache and flag for refresh on next fetch invalidate_cache(context.user_data, "balances", "orders", "positions") + context.user_data["_force_cex_balance_refresh"] = True # Save for quick repeat set_clob_last_order(context.user_data, { @@ -1227,7 +1350,10 @@ async def process_trade( context: ContextTypes.DEFAULT_TYPE, user_input: str ) -> None: - """Process trade from text input: pair side amount [type] [price] [position]""" + """Process trade from text input: pair side amount [type] [price] [position] + + This is a QUICK TRADE - it executes immediately, not just updates params. + """ try: parts = user_input.split() @@ -1237,6 +1363,7 @@ async def process_trade( # Get current connector from params params = context.user_data.get("trade_params", {}) connector = params.get("connector", "binance_perpetual") + account = get_clob_account(context.user_data) trading_pair = parts[0].upper() side = parts[1].upper() @@ -1245,7 +1372,53 @@ async def process_trade( price = parts[4] if len(parts) > 4 else None position_action = parts[5].upper() if len(parts) > 5 else "OPEN" - # Update params + # Validate + if side not in ("BUY", "SELL"): + raise ValueError(f"Invalid side: {side}. Use BUY or SELL") + if order_type not in ("MARKET", "LIMIT", "LIMIT_MAKER"): + raise ValueError(f"Invalid type: {order_type}. Use MARKET, LIMIT, or LIMIT_MAKER") + if order_type in ("LIMIT", "LIMIT_MAKER") and not price: + raise ValueError("Price required for LIMIT orders") + + # Delete user's input message + try: + await update.message.delete() + except Exception: + pass + + chat_id = update.effective_chat.id + client = await get_client(chat_id, context=context) + + # Handle USD amount + is_quote_amount = "$" in str(amount) + if is_quote_amount: + usd_value = float(str(amount).replace("$", "")) + prices = await client.market_data.get_prices( + connector_name=connector, + trading_pairs=trading_pair + ) + current_price = prices["prices"][trading_pair] + amount_float = usd_value / current_price + else: + amount_float = float(amount) + + # Execute the trade + result = await client.trading.place_order( + account_name=account, + connector_name=connector, + trading_pair=trading_pair, + trade_type=side, + amount=amount_float, + order_type=order_type, + price=float(price) if price and order_type in ["LIMIT", "LIMIT_MAKER"] else None, + position_action=position_action, + ) + + # Invalidate cache and flag for refresh on next fetch + invalidate_cache(context.user_data, "balances", "orders", "positions") + context.user_data["_force_cex_balance_refresh"] = True + + # Update params for next trade context.user_data["trade_params"] = { "connector": connector, "trading_pair": trading_pair, @@ -1256,14 +1429,67 @@ async def process_trade( "position_mode": position_action, } - _invalidate_trade_cache(context.user_data) - context.user_data["cex_state"] = "trade" + # Save for quick repeat + set_clob_last_order(context.user_data, { + "connector": connector, + "trading_pair": trading_pair, + "side": side, + "order_type": order_type, + "position_mode": position_action, + "amount": amount, + "price": price if price else "—", + }) - await _update_trade_menu_after_input(update, context) + # Build success message + order_info = escape_markdown_v2( + f"✅ Order placed!\n\n" + f"Pair: {trading_pair}\n" + f"Side: {side}\n" + f"Amount: {amount_float:.6f}\n" + f"Type: {order_type}" + ) + + if price and order_type in ["LIMIT", "LIMIT_MAKER"]: + order_info += escape_markdown_v2(f"\nPrice: {price}") + + if "order_id" in result: + order_info += escape_markdown_v2(f"\nOrder ID: {result['order_id']}") + + keyboard = [[InlineKeyboardButton("« Back to Trade", callback_data="cex:trade")]] + reply_markup = InlineKeyboardMarkup(keyboard) + + # Update the trade menu message with success + msg_id = context.user_data.get("trade_menu_message_id") + menu_chat_id = context.user_data.get("trade_menu_chat_id") + + if msg_id and menu_chat_id: + try: + await context.bot.edit_message_text( + chat_id=menu_chat_id, + message_id=msg_id, + text=order_info, + parse_mode="MarkdownV2", + reply_markup=reply_markup + ) + except Exception as e: + logger.debug(f"Could not update trade menu: {e}") + await update.effective_chat.send_message( + order_info, + parse_mode="MarkdownV2", + reply_markup=reply_markup + ) + else: + await update.effective_chat.send_message( + order_info, + parse_mode="MarkdownV2", + reply_markup=reply_markup + ) + + context.user_data["cex_state"] = "trade" except Exception as e: - logger.error(f"Error processing trade input: {e}", exc_info=True) - error_message = format_error_message(f"Invalid: {str(e)}") + logger.error(f"Error processing quick trade: {e}", exc_info=True) + error_message = format_error_message(f"Trade failed: {str(e)}") await update.message.reply_text(error_message, parse_mode="MarkdownV2") @@ -1272,21 +1498,116 @@ async def process_trade_set_pair( context: ContextTypes.DEFAULT_TYPE, user_input: str ) -> None: - """Process trading pair input""" + """Process trading pair input with validation against available markets""" try: params = context.user_data.get("trade_params", {}) - params["trading_pair"] = user_input.strip().upper() + connector = params.get("connector", "binance_perpetual") + pair_input = user_input.strip().upper().replace("_", "-").replace("/", "-") - _invalidate_trade_cache(context.user_data) - context.user_data["cex_state"] = "trade" + chat_id = update.effective_chat.id + client = await get_client(chat_id, context=context) - await _update_trade_menu_after_input(update, context) + # Validate trading pair exists on the connector + is_valid, error_msg, suggestions = await validate_trading_pair( + context.user_data, client, connector, pair_input + ) + + if is_valid: + # Get correctly formatted pair from trading rules + trading_rules = await get_trading_rules(context.user_data, client, connector) + correct_pair = get_correct_pair_format(trading_rules, pair_input) + params["trading_pair"] = correct_pair if correct_pair else pair_input + + _invalidate_trade_cache(context.user_data) + context.user_data["cex_state"] = "trade" + + await _update_trade_menu_after_input(update, context) + else: + # Show error with suggestions + await _show_pair_suggestions(update, context, pair_input, error_msg, suggestions) except Exception as e: + logger.error(f"Error processing trading pair: {e}", exc_info=True) error_message = format_error_message(f"Failed: {str(e)}") await update.message.reply_text(error_message, parse_mode="MarkdownV2") +async def _show_pair_suggestions( + update: Update, + context: ContextTypes.DEFAULT_TYPE, + input_pair: str, + error_msg: str, + suggestions: list +) -> None: + """Show trading pair suggestions when validation fails""" + # Delete user's input message + try: + await update.message.delete() + except Exception: + pass + + # Build suggestion message + help_text = f"❌ *{escape_markdown_v2(error_msg)}*\n\n" + + if suggestions: + help_text += "💡 *Did you mean:*\n" + else: + help_text += "_No similar pairs found\\._\n" + + # Build keyboard with suggestions + keyboard = [] + for pair in suggestions: + keyboard.append([InlineKeyboardButton( + f"📈 {pair}", + callback_data=f"cex:trade_pair_select_{pair}" + )]) + + keyboard.append([InlineKeyboardButton("« Back", callback_data="cex:trade")]) + reply_markup = InlineKeyboardMarkup(keyboard) + + # Update the stored trade menu message + msg_id = context.user_data.get("trade_menu_message_id") + menu_chat_id = context.user_data.get("trade_menu_chat_id") + + if msg_id and menu_chat_id: + try: + await context.bot.edit_message_text( + chat_id=menu_chat_id, + message_id=msg_id, + text=help_text, + parse_mode="MarkdownV2", + reply_markup=reply_markup + ) + except Exception as e: + logger.debug(f"Could not update trade menu: {e}") + await update.effective_chat.send_message( + help_text, + parse_mode="MarkdownV2", + reply_markup=reply_markup + ) + else: + await update.effective_chat.send_message( + help_text, + parse_mode="MarkdownV2", + reply_markup=reply_markup + ) + + +async def handle_trade_pair_select( + update: Update, + context: ContextTypes.DEFAULT_TYPE, + trading_pair: str +) -> None: + """Handle selection of a suggested trading pair""" + params = context.user_data.get("trade_params", {}) + params["trading_pair"] = trading_pair + + _invalidate_trade_cache(context.user_data) + context.user_data["cex_state"] = "trade" + + await show_trade_menu(update, context) + + async def process_trade_set_amount( update: Update, context: ContextTypes.DEFAULT_TYPE, @@ -1344,7 +1665,7 @@ async def process_trade_set_leverage( account = get_clob_account(context.user_data) chat_id = update.effective_chat.id - client = await get_client(chat_id) + client = await get_client(chat_id, context=context) # Set leverage on exchange await client.trading.set_leverage( diff --git a/handlers/config/__init__.py b/handlers/config/__init__.py index 0c304e2..6cbc8ae 100644 --- a/handlers/config/__init__.py +++ b/handlers/config/__init__.py @@ -23,74 +23,120 @@ def clear_config_state(context: ContextTypes.DEFAULT_TYPE) -> None: clear_all_input_states(context) -def _get_config_menu_markup_and_text(): - """ - Build the main config menu keyboard and message text - """ +def _get_start_menu_keyboard(is_admin: bool = False) -> InlineKeyboardMarkup: + """Build the start menu inline keyboard.""" keyboard = [ [ - InlineKeyboardButton("🔌 API Servers", callback_data="config_api_servers"), - InlineKeyboardButton("🔑 API Keys", callback_data="config_api_keys"), - InlineKeyboardButton("🌐 Gateway", callback_data="config_gateway"), - ], - [ - InlineKeyboardButton("❌ Cancel", callback_data="config_close"), + InlineKeyboardButton("🔌 Servers", callback_data="start:config_servers"), + InlineKeyboardButton("🔑 Keys", callback_data="start:config_keys"), + InlineKeyboardButton("🌐 Gateway", callback_data="start:config_gateway"), ], ] - reply_markup = InlineKeyboardMarkup(keyboard) - - message_text = ( - r"⚙️ *Configuration Menu*" + "\n\n" - r"Select a configuration category:" + "\n\n" - r"🔌 *API Servers* \- Manage Hummingbot API instances" + "\n" - r"🔑 *API Keys* \- Manage exchange credentials" + "\n" - r"🌐 *Gateway* \- Manage Gateway container and DEX configuration" - ) - - return reply_markup, message_text + if is_admin: + keyboard.append([InlineKeyboardButton("👑 Admin", callback_data="start:admin")]) + keyboard.append([InlineKeyboardButton("❌ Cancel", callback_data="start:cancel")]) + return InlineKeyboardMarkup(keyboard) -async def show_config_menu(query, context: ContextTypes.DEFAULT_TYPE) -> None: +async def _show_start_menu(query, context: ContextTypes.DEFAULT_TYPE) -> None: """ - Show the main config menu + Show the /start menu (replaces config menu for Back navigation). + Mirrors the logic from main.py start() but for callback query context. """ - reply_markup, message_text = _get_config_menu_markup_and_text() - - await query.message.edit_text( - message_text, - parse_mode="MarkdownV2", - reply_markup=reply_markup - ) + import asyncio + from config_manager import get_config_manager, get_effective_server + from handlers.config.server_context import get_gateway_status_info + from utils.telegram_formatters import escape_markdown_v2 + + user_id = query.from_user.id + chat_id = query.message.chat_id + cm = get_config_manager() + is_admin = cm.is_admin(user_id) + + # Get all servers and their statuses in parallel + servers = cm.list_servers() + active_server = get_effective_server(chat_id, context.user_data) or cm.get_default_server() + + server_statuses = {} + active_server_online = False + + if servers: + # Query all server statuses in parallel + status_tasks = [cm.check_server_status(name) for name in servers] + status_results = await asyncio.gather(*status_tasks, return_exceptions=True) + + for server_name, status_result in zip(servers, status_results): + if isinstance(status_result, Exception): + status = "error" + else: + status = status_result.get("status", "unknown") + server_statuses[server_name] = status + if server_name == active_server and status == "online": + active_server_online = True + + # Build servers list display + servers_display = "" + if servers: + for server_name in servers: + status = server_statuses.get(server_name, "unknown") + icon = "🟢" if status == "online" else "🔴" + is_active = " ⭐" if server_name == active_server else "" + server_escaped = escape_markdown_v2(server_name) + servers_display += f" {icon} `{server_escaped}`{is_active}\n" + else: + servers_display = " _No servers configured_\n" + + # Get gateway and accounts info only if active server is online + extra_info = "" + if active_server_online: + try: + gateway_header, _ = await get_gateway_status_info(chat_id, context.user_data) + extra_info += gateway_header + + client = await cm.get_client_for_chat(chat_id, preferred_server=active_server) + accounts = await client.accounts.list_accounts() + if accounts: + total_creds = 0 + for account in accounts: + try: + creds = await client.accounts.list_account_credentials(account_name=str(account)) + total_creds += len(creds) if creds else 0 + except Exception: + pass + accounts_escaped = escape_markdown_v2(str(len(accounts))) + creds_escaped = escape_markdown_v2(str(total_creds)) + extra_info += f"*Accounts:* {accounts_escaped} \\({creds_escaped} keys\\)\n" + except Exception as e: + logger.warning(f"Failed to get extra info: {e}") + + # Build the message + admin_badge = " 👑" if is_admin else "" + capabilities = """_Trade CEX/DEX, manage bots, monitor portfolio_""" + + # Offline help message + offline_help = "" + if not active_server_online and servers: + offline_help = """ +⚠️ *Active server is offline* +• Ensure `hummingbot\\-backend\\-api` is running +• Or select an online server below +""" + # Menu descriptions + menu_help = r""" +🔌 *Servers* \- Add/manage Hummingbot API servers +🔑 *Keys* \- Connect exchange API credentials +🌐 *Gateway* \- Deploy Gateway for DEX trading +""" -@restricted -async def config_command(update: Update, context: ContextTypes.DEFAULT_TYPE) -> None: - """ - Handle /config command - Show configuration options + reply_text = rf""" +🦅 *Condor*{admin_badge} +{capabilities} - Displays a menu with configuration categories: - - API Servers (Hummingbot instances) - - API Keys (Exchange API credentials) - - Gateway (Gateway container and DEX operations) - """ - # Clear all pending input states to prevent interference - clear_config_state(context) - - reply_markup, message_text = _get_config_menu_markup_and_text() - - # Handle both direct command and callback query invocations - if update.message: - await update.message.reply_text( - message_text, - parse_mode="MarkdownV2", - reply_markup=reply_markup - ) - elif update.callback_query: - await update.callback_query.message.reply_text( - message_text, - parse_mode="MarkdownV2", - reply_markup=reply_markup - ) +*Servers:* +{servers_display}{offline_help}{extra_info}{menu_help}""" + keyboard = _get_start_menu_keyboard(is_admin=is_admin) + await query.message.edit_text(reply_text, parse_mode="MarkdownV2", reply_markup=keyboard) async def config_callback_handler(update: Update, context: ContextTypes.DEFAULT_TYPE) -> None: @@ -109,7 +155,7 @@ async def config_callback_handler(update: Update, context: ContextTypes.DEFAULT_ await query.message.delete() return elif query.data == "config_back": - await show_config_menu(query, context) + await _show_start_menu(query, context) return # Route to appropriate sub-module based on callback data prefix @@ -119,6 +165,9 @@ async def config_callback_handler(update: Update, context: ContextTypes.DEFAULT_ await handle_api_keys_callback(update, context) elif query.data == "config_gateway" or query.data.startswith("gateway_"): await handle_gateway_callback(update, context) + elif query.data == "config_admin" or query.data.startswith("admin:"): + from handlers.admin import admin_callback_handler + await admin_callback_handler(update, context) # Create callback handler instance for registration @@ -126,7 +175,7 @@ def get_config_callback_handler(): """Get the callback query handler for config menu""" return CallbackQueryHandler( config_callback_handler, - pattern="^config_|^modify_field_|^add_server_|^api_server_|^api_key_|^gateway_" + pattern="^config_|^modify_field_|^add_server_|^api_server_|^api_key_|^gateway_|^admin:" ) @@ -191,6 +240,12 @@ async def handle_all_text_input(update: Update, context: ContextTypes.DEFAULT_TY await routines_message_handler(update, context) return + # 8. Check server share state + if context.user_data.get('awaiting_share_user_id'): + from handlers.config.servers import handle_share_user_id_input + await handle_share_user_id_input(update, context) + return + # No active state - ignore the message logger.debug(f"No active input state for message: {update.message.text[:50] if update.message else 'N/A'}...") diff --git a/handlers/config/api_keys.py b/handlers/config/api_keys.py index 6fc662e..3629b88 100644 --- a/handlers/config/api_keys.py +++ b/handlers/config/api_keys.py @@ -2,36 +2,72 @@ API Keys configuration management handlers """ +import asyncio import logging import base64 from telegram import Update, InlineKeyboardButton, InlineKeyboardMarkup from telegram.ext import ContextTypes from utils.telegram_formatters import escape_markdown_v2 +from utils.auth import restricted from .server_context import build_config_message_header, format_server_selection_needed +from .user_preferences import get_active_server logger = logging.getLogger(__name__) +# Default account name used for all API key operations +DEFAULT_ACCOUNT = "master_account" + + +@restricted +async def keys_command(update: Update, context: ContextTypes.DEFAULT_TYPE) -> None: + """Handle /keys command - show API keys configuration directly.""" + from handlers import clear_all_input_states + from utils.telegram_helpers import create_mock_query_from_message + + clear_all_input_states(context) + mock_query = await create_mock_query_from_message(update, "Loading API keys...") + await show_api_keys(mock_query, context) + + +async def get_default_account(client) -> str: + """ + Get the default account to use for API key operations. + Returns the first available account from the backend, or DEFAULT_ACCOUNT if none exist. + """ + try: + accounts = await client.accounts.list_accounts() + if accounts: + return str(accounts[0]) + except Exception: + pass + return DEFAULT_ACCOUNT + async def show_api_keys(query, context: ContextTypes.DEFAULT_TYPE) -> None: """ - Show API keys configuration with account selection + Show API keys configuration with Perpetual/Spot connector selection """ + # Clear bots state to prevent bots handler from intercepting API key input + # This is needed when navigating here from Grid Strike or PMM wizards + context.user_data.pop('bots_state', None) + try: - from servers import server_manager + from config_manager import get_config_manager - servers = server_manager.list_servers() + servers = get_config_manager().list_servers() if not servers: message_text = format_server_selection_needed() - keyboard = [[InlineKeyboardButton("« Back", callback_data="config_back")]] + keyboard = [[InlineKeyboardButton("« Close", callback_data="config_close")]] else: # Build header with server context chat_id = query.message.chat_id header, server_online, _ = await build_config_message_header( "🔑 API Keys", include_gateway=False, - chat_id=chat_id + chat_id=chat_id, + user_data=context.user_data ) if not server_online: @@ -39,65 +75,68 @@ async def show_api_keys(query, context: ContextTypes.DEFAULT_TYPE) -> None: header + "⚠️ _Server is offline\\. Cannot manage API keys\\._" ) - keyboard = [[InlineKeyboardButton("« Back", callback_data="config_back")]] + keyboard = [[InlineKeyboardButton("« Close", callback_data="config_close")]] else: # Get client from per-chat server - client = await server_manager.get_client_for_chat(chat_id) - accounts = await client.accounts.list_accounts() - - if not accounts: - message_text = ( - header + - "No accounts configured\\.\n\n" - "_Create accounts in Hummingbot first\\._" - ) - keyboard = [[InlineKeyboardButton("« Back", callback_data="config_back")]] + client = await get_config_manager().get_client_for_chat(chat_id, preferred_server=get_active_server(context.user_data)) + + # Get the default account to use + account_name = await get_default_account(client) + + # Get credentials for the account + try: + credentials = await client.accounts.list_account_credentials(account_name=account_name) + cred_list = credentials if credentials else [] + except Exception as e: + logger.warning(f"Failed to get credentials for {account_name}: {e}") + cred_list = [] + + # Separate credentials into perpetual and spot + perp_creds = [c for c in cred_list if c.endswith('_perpetual')] + spot_creds = [c for c in cred_list if not c.endswith('_perpetual')] + + # Store credentials in context for callback handling + context.user_data['api_key_current_account'] = account_name + context.user_data['api_key_credentials'] = cred_list + + # Build keyboard with credential buttons for deletion + keyboard = [] + + # Add perpetual credential buttons + if perp_creds: + for i, cred in enumerate(perp_creds): + # Store index for lookup + keyboard.append([ + InlineKeyboardButton(f"📈 {cred}", callback_data=f"api_key_manage:{i}") + ]) + + # Add spot credential buttons + if spot_creds: + for i, cred in enumerate(spot_creds): + # Offset index by perp count + idx = len(perp_creds) + i + keyboard.append([ + InlineKeyboardButton(f"💱 {cred}", callback_data=f"api_key_manage:{idx}") + ]) + + # Build message text + if cred_list: + creds_display = "_Tap a key to manage it\\._\n\n" else: - # Build account list with credentials info - account_lines = [] - for account in accounts: - account_name = str(account) - # Get credentials for this account - try: - credentials = await client.accounts.list_account_credentials(account_name=account_name) - cred_count = len(credentials) if credentials else 0 - - account_escaped = escape_markdown_v2(account_name) - if cred_count > 0: - creds_text = escape_markdown_v2(", ".join(credentials)) - account_lines.append(f"• *{account_escaped}* \\({cred_count} connected\\)\n _{creds_text}_") - else: - account_lines.append(f"• *{account_escaped}* \\(no credentials\\)") - except Exception as e: - logger.warning(f"Failed to get credentials for {account_name}: {e}") - account_escaped = escape_markdown_v2(account_name) - account_lines.append(f"• *{account_escaped}*") - - message_text = ( - header + - "\n".join(account_lines) + "\n\n" - "_Select an account to manage exchange credentials:_" - ) - - # Create account buttons in grid of 4 per row - # Use base64 encoding to avoid issues with special characters in account names - account_buttons = [] - for account in accounts: - account_name = str(account) - # Encode account name to avoid issues with underscores and special chars - encoded_name = base64.b64encode(account_name.encode()).decode() - account_buttons.append( - InlineKeyboardButton(account_name, callback_data=f"api_key_account:{encoded_name}") - ) - - # Organize into rows of max 4 columns - account_button_rows = [] - for i in range(0, len(account_buttons), 4): - account_button_rows.append(account_buttons[i:i+4]) - - keyboard = account_button_rows + [ - [InlineKeyboardButton("« Back", callback_data="config_back")] - ] + creds_display = "_No exchanges connected yet\\._\n\n" + + message_text = ( + header + + creds_display + + "_Select exchange type to add a new key:_" + ) + + # Add type selection buttons + keyboard.append([ + InlineKeyboardButton("➕ Perpetual", callback_data="api_key_type:perpetual"), + InlineKeyboardButton("➕ Spot", callback_data="api_key_type:spot") + ]) + keyboard.append([InlineKeyboardButton("« Close", callback_data="config_close")]) reply_markup = InlineKeyboardMarkup(keyboard) @@ -118,7 +157,121 @@ async def show_api_keys(query, context: ContextTypes.DEFAULT_TYPE) -> None: except Exception as e: logger.error(f"Error showing API keys: {e}", exc_info=True) error_text = f"❌ Error loading API keys: {escape_markdown_v2(str(e))}" - keyboard = [[InlineKeyboardButton("« Back", callback_data="config_back")]] + keyboard = [[InlineKeyboardButton("« Close", callback_data="config_close")]] + reply_markup = InlineKeyboardMarkup(keyboard) + await query.message.edit_text(error_text, parse_mode="MarkdownV2", reply_markup=reply_markup) + + +async def show_connectors_by_type(query, context: ContextTypes.DEFAULT_TYPE, connector_type: str) -> None: + """ + Show connectors filtered by type (perpetual or spot) + """ + try: + from config_manager import get_config_manager + + chat_id = query.message.chat_id + is_perpetual = connector_type == "perpetual" + type_label = "Perpetual" if is_perpetual else "Spot" + type_emoji = "📈" if is_perpetual else "💱" + + # Build header with server context + header, server_online, _ = await build_config_message_header( + f"🔑 {type_emoji} {type_label} Exchanges", + include_gateway=False, + chat_id=chat_id, + user_data=context.user_data + ) + + client = await get_config_manager().get_client_for_chat(chat_id, preferred_server=get_active_server(context.user_data)) + + # Get the default account to use + account_name = await get_default_account(client) + context.user_data['api_key_current_account'] = account_name + + # Get credentials for the account + try: + credentials = await client.accounts.list_account_credentials(account_name=account_name) + cred_list = credentials if credentials else [] + except Exception as e: + logger.warning(f"Failed to get credentials for {account_name}: {e}") + cred_list = [] + + # Filter credentials by type + if is_perpetual: + type_creds = [c for c in cred_list if c.endswith('_perpetual')] + else: + type_creds = [c for c in cred_list if not c.endswith('_perpetual')] + + # Store credentials in context for delete functionality + context.user_data['api_key_credentials'] = type_creds + context.user_data['api_key_connector_type'] = connector_type + + keyboard = [] + + if type_creds: + # Build list of connected credentials with delete buttons + cred_lines = ["*Connected:*"] + for i, cred in enumerate(type_creds): + cred_escaped = escape_markdown_v2(str(cred)) + cred_lines.append(f" ✅ {cred_escaped}") + # Add delete button for each credential + keyboard.append([ + InlineKeyboardButton(f"🗑 Delete {cred}", callback_data=f"api_key_delete_cred:{i}") + ]) + creds_display = "\n".join(cred_lines) + "\n\n" + else: + creds_display = "_No exchanges connected yet\\._\n\n" + + message_text = ( + header + + creds_display + + "_Select an exchange to configure:_\n" + ) + + # Get list of available connectors + all_connectors = await client.connectors.list_connectors() + + # Filter out testnet connectors and gateway connectors (those with '/' like "uniswap/ethereum") + connectors = [c for c in all_connectors if 'testnet' not in c.lower() and '/' not in c] + + # Filter by type + if is_perpetual: + connectors = [c for c in connectors if c.endswith('_perpetual')] + else: + connectors = [c for c in connectors if not c.endswith('_perpetual')] + + # Store connector list in context + context.user_data['api_key_connectors'] = connectors + + # Create connector buttons + connector_buttons = [] + for i, connector in enumerate(connectors): + # Use index instead of full names to keep callback_data short + connector_buttons.append( + InlineKeyboardButton(connector, callback_data=f"api_key_connector:{i}") + ) + + # Organize into rows of 2 columns for better readability + connector_button_rows = [] + for i in range(0, len(connector_buttons), 2): + connector_button_rows.append(connector_buttons[i:i+2]) + + keyboard = keyboard + connector_button_rows + [ + [InlineKeyboardButton("« Back", callback_data="api_key_back_to_accounts")] + ] + + reply_markup = InlineKeyboardMarkup(keyboard) + + await query.message.edit_text( + message_text, + parse_mode="MarkdownV2", + reply_markup=reply_markup + ) + + except Exception as e: + logger.error(f"Error showing connectors by type: {e}", exc_info=True) + error_text = f"❌ Error loading connectors: {escape_markdown_v2(str(e))}" + keyboard = [[InlineKeyboardButton("« Back", callback_data="api_key_back_to_accounts")]] reply_markup = InlineKeyboardMarkup(keyboard) await query.message.edit_text(error_text, parse_mode="MarkdownV2", reply_markup=reply_markup) @@ -129,24 +282,22 @@ async def handle_api_key_action(query, context: ContextTypes.DEFAULT_TYPE) -> No """ action_data = query.data.replace("api_key_", "") - if action_data.startswith("account:"): - # Decode base64 encoded account name - encoded_name = action_data.replace("account:", "") - try: - account_name = base64.b64decode(encoded_name.encode()).decode() - await show_account_credentials(query, context, account_name) - except Exception as e: - logger.error(f"Failed to decode account name: {e}") - await query.answer("❌ Invalid account name") + if action_data.startswith("type:"): + # Handle perpetual/spot selection + connector_type = action_data.replace("type:", "") + await show_connectors_by_type(query, context, connector_type) + elif action_data.startswith("account:"): + # Legacy handler - redirect to main API keys view + await show_api_keys(query, context) elif action_data.startswith("connector:"): # Format: connector:{index} - # Retrieve account and connector from context + # Retrieve connector from context (use account from context) try: connector_index = int(action_data.replace("connector:", "")) - account_name = context.user_data.get('api_key_current_account') connectors = context.user_data.get('api_key_connectors', []) + account_name = context.user_data.get('api_key_current_account', DEFAULT_ACCOUNT) - if account_name and 0 <= connector_index < len(connectors): + if 0 <= connector_index < len(connectors): connector_name = connectors[connector_index] await show_connector_config(query, context, account_name, connector_name) else: @@ -180,28 +331,36 @@ async def handle_api_key_action(query, context: ContextTypes.DEFAULT_TYPE) -> No else: await query.answer("Cannot go back") elif action_data.startswith("back_account:"): - # Back to specific account view - also clear API key config state + # Legacy handler - clear state and redirect to main API keys view context.user_data.pop('configuring_api_key', None) context.user_data.pop('awaiting_api_key_input', None) context.user_data.pop('api_key_config_data', None) - - encoded_name = action_data.replace("back_account:", "") - try: - account_name = base64.b64decode(encoded_name.encode()).decode() - await show_account_credentials(query, context, account_name) - except Exception as e: - logger.error(f"Failed to decode account name: {e}") - await query.answer("❌ Invalid account name") + await show_api_keys(query, context) elif action_data == "back_to_accounts": await show_api_keys(query, context) + elif action_data.startswith("manage:"): + # Show manage options for a credential (delete) + try: + cred_index = int(action_data.replace("manage:", "")) + credentials = context.user_data.get('api_key_credentials', []) + account_name = context.user_data.get('api_key_current_account', DEFAULT_ACCOUNT) + + if 0 <= cred_index < len(credentials): + connector_name = credentials[cred_index] + await show_credential_manage_menu(query, context, cred_index, connector_name) + else: + await query.answer("❌ Session expired, please try again") + except (ValueError, IndexError) as e: + logger.error(f"Failed to parse credential index: {e}") + await query.answer("❌ Invalid credential data") elif action_data.startswith("delete_cred:"): - # Handle credential deletion + # Handle credential deletion (use account from context) try: cred_index = int(action_data.replace("delete_cred:", "")) - account_name = context.user_data.get('api_key_current_account') credentials = context.user_data.get('api_key_credentials', []) + account_name = context.user_data.get('api_key_current_account', DEFAULT_ACCOUNT) - if account_name and 0 <= cred_index < len(credentials): + if 0 <= cred_index < len(credentials): connector_name = credentials[cred_index] # Show confirmation dialog await show_delete_credential_confirmation(query, context, account_name, connector_name) @@ -212,13 +371,13 @@ async def handle_api_key_action(query, context: ContextTypes.DEFAULT_TYPE) -> No logger.error(f"Failed to parse credential index: {e}") await query.answer("❌ Invalid credential data") elif action_data.startswith("delete_cred_confirm:"): - # Confirm credential deletion + # Confirm credential deletion (use account from context) try: cred_index = int(action_data.replace("delete_cred_confirm:", "")) - account_name = context.user_data.get('api_key_current_account') credentials = context.user_data.get('api_key_credentials', []) + account_name = context.user_data.get('api_key_current_account', DEFAULT_ACCOUNT) - if account_name and 0 <= cred_index < len(credentials): + if 0 <= cred_index < len(credentials): connector_name = credentials[cred_index] await delete_credential(query, context, account_name, connector_name) else: @@ -227,133 +386,71 @@ async def handle_api_key_action(query, context: ContextTypes.DEFAULT_TYPE) -> No logger.error(f"Failed to delete credential: {e}") await query.answer("❌ Failed to delete credential") elif action_data == "delete_cred_cancel": - # Cancel credential deletion - go back to account view - account_name = context.user_data.get('api_key_current_account') - if account_name: - await show_account_credentials(query, context, account_name) - else: - await show_api_keys(query, context) + # Cancel credential deletion - go back to keys menu + await show_api_keys(query, context) + elif action_data.startswith("select:"): + # Handle Literal type option selection + selected_value = action_data.replace("select:", "") + await _handle_field_value_selection(query, context, selected_value) + elif action_data == "skip": + # Skip optional field + await _handle_skip_optional_field(query, context) else: await query.answer("Unknown action") -async def show_account_credentials(query, context: ContextTypes.DEFAULT_TYPE, account_name: str) -> None: - """ - Show connected credentials for a specific account - """ - try: - from servers import server_manager - - chat_id = query.message.chat_id - - # Build header with server context - header, server_online, _ = await build_config_message_header( - f"🔑 API Keys", - include_gateway=False, - chat_id=chat_id - ) - - client = await server_manager.get_client_for_chat(chat_id) - - # Get list of connected credentials for this account - credentials = await client.accounts.list_account_credentials(account_name=account_name) - - account_escaped = escape_markdown_v2(account_name) - - # Store credentials in context for delete functionality - context.user_data['api_key_credentials'] = credentials if credentials else [] - - if not credentials: - message_text = ( - header + - f"*Account:* `{account_escaped}`\n\n" - "No exchange credentials connected\\.\n\n" - "Select an exchange below to add credentials:\n\n" - ) - keyboard = [] - else: - # Build list of connected credentials with delete buttons - cred_lines = [] - credential_buttons = [] - for i, cred in enumerate(credentials): - cred_escaped = escape_markdown_v2(str(cred)) - cred_lines.append(f" ✅ {cred_escaped}") - # Add delete button for each credential - credential_buttons.append([ - InlineKeyboardButton(f"🗑 Delete {cred}", callback_data=f"api_key_delete_cred:{i}") - ]) - - message_text = ( - header + - f"*Account:* `{account_escaped}`\n\n" - "*Connected Exchanges:*\n" - + "\n".join(cred_lines) + "\n\n" - "Select an exchange below to configure or delete:\n\n" - ) - keyboard = credential_buttons - - # Get list of available connectors - all_connectors = await client.connectors.list_connectors() - - # Filter out testnet connectors and gateway connectors (those with '/' like "uniswap/ethereum") - connectors = [c for c in all_connectors if 'testnet' not in c.lower() and '/' not in c] - - # Create connector buttons in grid of 3 per row (for better readability of long names) - # Store account name and connector list in context to avoid exceeding 64-byte callback_data limit - context.user_data['api_key_current_account'] = account_name - context.user_data['api_key_connectors'] = connectors - - connector_buttons = [] - for i, connector in enumerate(connectors): - # Use index instead of full names to keep callback_data short - connector_buttons.append( - InlineKeyboardButton(connector, callback_data=f"api_key_connector:{i}") - ) - - # Organize into rows of 2 columns for better readability - connector_button_rows = [] - for i in range(0, len(connector_buttons), 2): - connector_button_rows.append(connector_buttons[i:i+2]) - - keyboard = keyboard + connector_button_rows + [ - [InlineKeyboardButton("« Back to Accounts", callback_data="api_key_back_to_accounts")] - ] - - reply_markup = InlineKeyboardMarkup(keyboard) - - await query.message.edit_text( - message_text, - parse_mode="MarkdownV2", - reply_markup=reply_markup - ) - - except Exception as e: - logger.error(f"Error showing account credentials: {e}", exc_info=True) - error_text = f"❌ Error loading account credentials: {escape_markdown_v2(str(e))}" - keyboard = [[InlineKeyboardButton("« Back", callback_data="api_key_back_to_accounts")]] - reply_markup = InlineKeyboardMarkup(keyboard) - await query.message.edit_text(error_text, parse_mode="MarkdownV2", reply_markup=reply_markup) - - async def show_connector_config(query, context: ContextTypes.DEFAULT_TYPE, account_name: str, connector_name: str) -> None: """ Start progressive configuration flow for a specific connector """ try: - from servers import server_manager + from config_manager import get_config_manager chat_id = query.message.chat_id - client = await server_manager.get_client_for_chat(chat_id) + client = await get_config_manager().get_client_for_chat(chat_id, preferred_server=get_active_server(context.user_data)) # Get config map for this connector - config_fields = await client.connectors.get_config_map(connector_name) + try: + config_map = await client.connectors.get_config_map(connector_name) + except Exception as config_err: + # Fallback for older backends that don't support config-map endpoint + logger.warning(f"Failed to get config map for {connector_name}, using defaults: {config_err}") + config_map = None + + # Support both old format (list) and new format (dict with metadata) + if config_map is None: + # Fallback: use default fields for most exchange connectors + config_fields = [f"{connector_name}_api_key", f"{connector_name}_api_secret"] + field_metadata = { + f"{connector_name}_api_key": {'type': 'SecretStr', 'required': True}, + f"{connector_name}_api_secret": {'type': 'SecretStr', 'required': True} + } + elif isinstance(config_map, dict): + # New format: dict with field metadata + config_fields = list(config_map.keys()) + field_metadata = config_map + else: + # Old format: list of field names + config_fields = config_map + field_metadata = {} + + # Filter out fields that should be handled automatically + if connector_name == "xrpl": + # custom_markets expects a dict, we'll default to empty + config_fields = [f for f in config_fields if f != "custom_markets"] + field_metadata.pop("custom_markets", None) + + # Determine connector type for back navigation + connector_type = "perpetual" if connector_name.endswith('_perpetual') else "spot" # Initialize context storage for API key configuration context.user_data['configuring_api_key'] = True context.user_data['api_key_config_data'] = { 'account_name': account_name, 'connector_name': connector_name, + 'connector_type': connector_type, 'fields': config_fields, + 'field_metadata': field_metadata, 'values': {} } context.user_data['awaiting_api_key_input'] = config_fields[0] if config_fields else None @@ -362,15 +459,12 @@ async def show_connector_config(query, context: ContextTypes.DEFAULT_TYPE, accou if not config_fields: # No configuration needed - account_escaped = escape_markdown_v2(account_name) connector_escaped = escape_markdown_v2(connector_name) message_text = ( f"🔑 *Configure {connector_escaped}*\n\n" - f"Account: *{account_escaped}*\n\n" "✅ No configuration required for this connector\\." ) - encoded_account = base64.b64encode(account_name.encode()).decode() - keyboard = [[InlineKeyboardButton("« Back", callback_data=f"api_key_back_account:{encoded_account}")]] + keyboard = [[InlineKeyboardButton("« Back", callback_data=f"api_key_type:{connector_type}")]] reply_markup = InlineKeyboardMarkup(keyboard) await query.message.edit_text(message_text, parse_mode="MarkdownV2", reply_markup=reply_markup) return @@ -392,8 +486,8 @@ async def show_connector_config(query, context: ContextTypes.DEFAULT_TYPE, accou except Exception as e: logger.error(f"Error showing connector config: {e}", exc_info=True) error_text = f"❌ Error loading connector config: {escape_markdown_v2(str(e))}" - encoded_account = base64.b64encode(account_name.encode()).decode() - keyboard = [[InlineKeyboardButton("« Back", callback_data=f"api_key_back_account:{encoded_account}")]] + connector_type = "perpetual" if connector_name.endswith('_perpetual') else "spot" + keyboard = [[InlineKeyboardButton("« Back", callback_data=f"api_key_type:{connector_type}")]] reply_markup = InlineKeyboardMarkup(keyboard) await query.message.edit_text(error_text, parse_mode="MarkdownV2", reply_markup=reply_markup) @@ -422,6 +516,70 @@ async def handle_api_key_config_input(update: Update, context: ContextTypes.DEFA config_data = context.user_data.get('api_key_config_data', {}) values = config_data.get('values', {}) all_fields = config_data.get('fields', []) + field_metadata = config_data.get('field_metadata', {}) + + # Get field metadata for validation + field_meta = field_metadata.get(awaiting_field, {}) + field_type = field_meta.get('type', '') + + # Validate and convert based on field type + if field_type == 'Literal': + allowed_values = field_meta.get('allowed_values', []) + if allowed_values and new_value not in allowed_values: + # Send error message and don't advance + error_msg = await update.effective_chat.send_message( + f"❌ Invalid value. Please select one of: {', '.join(allowed_values)}" + ) + # Auto-delete error message after 3 seconds + await asyncio.sleep(3) + try: + await error_msg.delete() + except: + pass + return + elif field_type == 'bool': + # Convert string to boolean + lower_val = new_value.lower() + if lower_val in ('true', 'yes', '1'): + new_value = True + elif lower_val in ('false', 'no', '0'): + new_value = False + else: + error_msg = await update.effective_chat.send_message( + "❌ Invalid value. Please enter 'true' or 'false'" + ) + await asyncio.sleep(3) + try: + await error_msg.delete() + except: + pass + return + elif field_type == 'int': + try: + new_value = int(new_value) + except ValueError: + error_msg = await update.effective_chat.send_message( + "❌ Invalid value. Please enter an integer number" + ) + await asyncio.sleep(3) + try: + await error_msg.delete() + except: + pass + return + elif field_type == 'float': + try: + new_value = float(new_value) + except ValueError: + error_msg = await update.effective_chat.send_message( + "❌ Invalid value. Please enter a number" + ) + await asyncio.sleep(3) + try: + await error_msg.delete() + except: + pass + return # Store the value values[awaiting_field] = new_value @@ -452,11 +610,12 @@ async def submit_api_key_config(context: ContextTypes.DEFAULT_TYPE, bot, chat_id Submit the API key configuration to Hummingbot """ try: - from servers import server_manager + from config_manager import get_config_manager config_data = context.user_data.get('api_key_config_data', {}) account_name = config_data.get('account_name') connector_name = config_data.get('connector_name') + connector_type = config_data.get('connector_type', 'spot') values = config_data.get('values', {}) message_id = context.user_data.get('api_key_message_id') @@ -465,11 +624,9 @@ async def submit_api_key_config(context: ContextTypes.DEFAULT_TYPE, bot, chat_id return # Show "waiting for connection" message - account_escaped = escape_markdown_v2(account_name) connector_escaped = escape_markdown_v2(connector_name) waiting_message_text = ( f"⏳ *Connecting to {connector_escaped}*\n\n" - f"Account: *{account_escaped}*\n\n" "Please wait while we verify your credentials\\.\\.\\." ) @@ -481,7 +638,13 @@ async def submit_api_key_config(context: ContextTypes.DEFAULT_TYPE, bot, chat_id parse_mode="MarkdownV2" ) - client = await server_manager.get_client_for_chat(chat_id) + client = await get_config_manager().get_client_for_chat(chat_id, preferred_server=get_active_server(context.user_data)) + + # Handle special cases for certain connectors + if connector_name == "xrpl": + # XRPL connector expects custom_markets as a dict, default to empty + if "custom_markets" not in values or values.get("custom_markets") is None: + values["custom_markets"] = {} # Add credentials using the accounts API await client.accounts.add_credential( @@ -490,6 +653,9 @@ async def submit_api_key_config(context: ContextTypes.DEFAULT_TYPE, bot, chat_id credentials=values ) + # Store connector type before clearing context + saved_connector_type = connector_type + # Clear context data context.user_data.pop('configuring_api_key', None) context.user_data.pop('awaiting_api_key_input', None) @@ -515,9 +681,7 @@ async def submit_api_key_config(context: ContextTypes.DEFAULT_TYPE, bot, chat_id parse_mode="MarkdownV2" ) - # Create a mock query object to reuse the existing show_account_credentials function - # This automatically refreshes the account credentials view - import asyncio + # Create a mock query object to navigate back to connector type view from types import SimpleNamespace # Wait a moment to let the user see the success message @@ -537,16 +701,16 @@ async def submit_api_key_config(context: ContextTypes.DEFAULT_TYPE, bot, chat_id ) mock_query = SimpleNamespace(message=mock_message) - # Navigate back to account credentials view - await show_account_credentials(mock_query, context, account_name) + # Navigate back to connector type view + await show_connectors_by_type(mock_query, context, saved_connector_type) except Exception as e: logger.error(f"Error submitting API key config: {e}", exc_info=True) - # Get account name for back button before clearing state + # Get connector info for back button before clearing state config_data = context.user_data.get('api_key_config_data', {}) - account_name = config_data.get('account_name', '') connector_name = config_data.get('connector_name', '') + connector_type = config_data.get('connector_type', 'spot') message_id = context.user_data.get('api_key_message_id') # Clear context data so user can retry @@ -569,12 +733,8 @@ async def submit_api_key_config(context: ContextTypes.DEFAULT_TYPE, bot, chat_id else: error_text = f"❌ Error saving configuration: {escape_markdown_v2(error_str)}" - # Add back button to navigate back to account - if account_name: - encoded_account = base64.b64encode(account_name.encode()).decode() - keyboard = [[InlineKeyboardButton("« Back to Account", callback_data=f"api_key_back_account:{encoded_account}")]] - else: - keyboard = [[InlineKeyboardButton("« Back", callback_data="api_key_back_to_accounts")]] + # Add back button to navigate back to connector type view + keyboard = [[InlineKeyboardButton("« Back", callback_data=f"api_key_type:{connector_type}")]] reply_markup = InlineKeyboardMarkup(keyboard) # Try to edit existing message, fall back to sending new message @@ -610,10 +770,13 @@ async def delete_credential(query, context: ContextTypes.DEFAULT_TYPE, account_n Delete a credential for a specific account and connector """ try: - from servers import server_manager + from config_manager import get_config_manager chat_id = query.message.chat_id - client = await server_manager.get_client_for_chat(chat_id) + client = await get_config_manager().get_client_for_chat(chat_id, preferred_server=get_active_server(context.user_data)) + + # Determine connector type for back navigation + connector_type = context.user_data.get('api_key_connector_type', 'perpetual' if connector_name.endswith('_perpetual') else 'spot') # Delete the credential await client.accounts.delete_credential( @@ -622,14 +785,13 @@ async def delete_credential(query, context: ContextTypes.DEFAULT_TYPE, account_n ) # Show success message - account_escaped = escape_markdown_v2(account_name) connector_escaped = escape_markdown_v2(connector_name) message_text = ( f"✅ *Credential Deleted*\n\n" - f"The *{connector_escaped}* credentials have been removed from account *{account_escaped}*\\." + f"The *{connector_escaped}* credentials have been removed\\." ) - keyboard = [[InlineKeyboardButton("« Back to Account", callback_data=f"api_key_account:{base64.b64encode(account_name.encode()).decode()}")]] + keyboard = [[InlineKeyboardButton("« Back", callback_data=f"api_key_type:{connector_type}")]] reply_markup = InlineKeyboardMarkup(keyboard) await query.message.edit_text( @@ -643,25 +805,62 @@ async def delete_credential(query, context: ContextTypes.DEFAULT_TYPE, account_n logger.error(f"Error deleting credential: {e}", exc_info=True) error_text = f"❌ Error deleting credential: {escape_markdown_v2(str(e))}" - keyboard = [[InlineKeyboardButton("« Back to Account", callback_data=f"api_key_account:{base64.b64encode(account_name.encode()).decode()}")]] + connector_type = context.user_data.get('api_key_connector_type', 'perpetual' if connector_name.endswith('_perpetual') else 'spot') + keyboard = [[InlineKeyboardButton("« Back", callback_data=f"api_key_type:{connector_type}")]] reply_markup = InlineKeyboardMarkup(keyboard) await query.message.edit_text(error_text, parse_mode="MarkdownV2", reply_markup=reply_markup) await query.answer("❌ Failed to delete credential") +async def show_credential_manage_menu(query, context: ContextTypes.DEFAULT_TYPE, cred_index: int, connector_name: str) -> None: + """ + Show management options for a credential (currently just delete) + """ + from .server_context import build_config_message_header + + chat_id = query.message.chat_id + header, _, _ = await build_config_message_header( + "🔑 Manage API Key", + include_gateway=False, + chat_id=chat_id, + user_data=context.user_data + ) + + # Determine type emoji + is_perpetual = connector_name.endswith('_perpetual') + type_emoji = "📈" if is_perpetual else "💱" + connector_escaped = escape_markdown_v2(connector_name) + + message_text = ( + header + + f"*Exchange:* {type_emoji} {connector_escaped}\n\n" + "_What would you like to do?_" + ) + + keyboard = [ + [InlineKeyboardButton("🗑 Delete Key", callback_data=f"api_key_delete_cred:{cred_index}")], + [InlineKeyboardButton("« Back", callback_data="api_key_back_to_accounts")], + ] + reply_markup = InlineKeyboardMarkup(keyboard) + + await query.message.edit_text( + message_text, + parse_mode="MarkdownV2", + reply_markup=reply_markup + ) + + async def show_delete_credential_confirmation(query, context: ContextTypes.DEFAULT_TYPE, account_name: str, connector_name: str) -> None: """ Show confirmation dialog before deleting a credential """ - account_escaped = escape_markdown_v2(account_name) connector_escaped = escape_markdown_v2(connector_name) message_text = ( f"🗑 *Delete Credential*\n\n" - f"Account: *{account_escaped}*\n" f"Exchange: *{connector_escaped}*\n\n" - f"⚠️ This will remove the API credentials for *{connector_escaped}* from account *{account_escaped}*\\.\n\n" + f"⚠️ This will remove the API credentials for *{connector_escaped}*\\.\n\n" "Are you sure you want to delete this credential?" ) @@ -682,44 +881,199 @@ async def show_delete_credential_confirmation(query, context: ContextTypes.DEFAU ) +async def _handle_field_value_selection(query, context, selected_value: str) -> None: + """ + Handle selection of a value for a Literal or bool type field via inline button + """ + try: + awaiting_field = context.user_data.get('awaiting_api_key_input') + if not awaiting_field: + await query.answer("❌ No field awaiting input") + return + + config_data = context.user_data.get('api_key_config_data', {}) + values = config_data.get('values', {}) + all_fields = config_data.get('fields', []) + field_metadata = config_data.get('field_metadata', {}) + + # Convert value based on field type + field_meta = field_metadata.get(awaiting_field, {}) + if field_meta.get('type') == 'bool': + selected_value = selected_value.lower() == 'true' + + # Store the selected value + values[awaiting_field] = selected_value + config_data['values'] = values + context.user_data['api_key_config_data'] = config_data + + # Move to next field or submit + current_index = all_fields.index(awaiting_field) + + if current_index < len(all_fields) - 1: + # Move to next field + context.user_data['awaiting_api_key_input'] = all_fields[current_index + 1] + await query.answer(f"✅ {awaiting_field} = {selected_value}") + await _update_api_key_config_message(context, query.message.get_bot()) + else: + # All fields filled - submit configuration + context.user_data['awaiting_api_key_input'] = None + await query.answer("✅ Submitting configuration...") + chat_id = context.user_data.get('api_key_chat_id', query.message.chat_id) + await submit_api_key_config(context, query.message.get_bot(), chat_id) + + except Exception as e: + logger.error(f"Error handling field value selection: {e}", exc_info=True) + await query.answer("❌ Error processing selection") + + +async def _handle_skip_optional_field(query, context) -> None: + """ + Handle skipping an optional field + """ + try: + awaiting_field = context.user_data.get('awaiting_api_key_input') + if not awaiting_field: + await query.answer("❌ No field awaiting input") + return + + config_data = context.user_data.get('api_key_config_data', {}) + all_fields = config_data.get('fields', []) + field_metadata = config_data.get('field_metadata', {}) + + # Verify field is optional + field_meta = field_metadata.get(awaiting_field, {}) + if field_meta.get('required', True): + await query.answer("❌ This field is required") + return + + # Don't store a value - just move to next field + current_index = all_fields.index(awaiting_field) + + if current_index < len(all_fields) - 1: + # Move to next field + context.user_data['awaiting_api_key_input'] = all_fields[current_index + 1] + await query.answer(f"⏭ Skipped {awaiting_field}") + await _update_api_key_config_message(context, query.message.get_bot()) + else: + # All fields filled - submit configuration + context.user_data['awaiting_api_key_input'] = None + await query.answer("✅ Submitting configuration...") + chat_id = context.user_data.get('api_key_chat_id', query.message.chat_id) + await submit_api_key_config(context, query.message.get_bot(), chat_id) + + except Exception as e: + logger.error(f"Error handling skip field: {e}", exc_info=True) + await query.answer("❌ Error skipping field") + + +def _format_field_type_hint(field_meta: dict) -> str: + """ + Format a human-readable type hint from field metadata + """ + if not field_meta: + return "" + + field_type = field_meta.get('type', '') + required = field_meta.get('required', False) + allowed_values = field_meta.get('allowed_values', []) + + hints = [] + + # Type hint + if field_type == 'Literal' and allowed_values: + values_str = " | ".join(allowed_values) + hints.append(f"Options: {values_str}") + elif field_type == 'bool': + hints.append("true/false") + elif field_type == 'SecretStr': + hints.append("secret") + elif field_type == 'int': + hints.append("integer") + elif field_type == 'float': + hints.append("number") + elif field_type: + hints.append(field_type.lower()) + + # Required hint + if not required: + hints.append("optional") + + return ", ".join(hints) + + def _build_api_key_config_message(config_data: dict, current_field: str, all_fields: list) -> tuple: """ Build the progressive API key configuration message showing filled fields and current prompt Returns (message_text, reply_markup) """ - account_name = config_data.get('account_name', '') connector_name = config_data.get('connector_name', '') + connector_type = config_data.get('connector_type', 'spot') values = config_data.get('values', {}) + field_metadata = config_data.get('field_metadata', {}) - account_escaped = escape_markdown_v2(account_name) connector_escaped = escape_markdown_v2(connector_name) # Build the message showing progress lines = [f"🔑 *Configure {connector_escaped}*\n"] - lines.append(f"Account: *{account_escaped}*\n") for field in all_fields: + field_meta = field_metadata.get(field, {}) + field_escaped = escape_markdown_v2(field) + if field in values: # Field already filled - show value (mask if contains 'secret', 'key', or 'password') value = values[field] - if any(keyword in field.lower() for keyword in ['secret', 'key', 'password', 'passphrase']): + is_secret = any(keyword in field.lower() for keyword in ['secret', 'key', 'password', 'passphrase']) + if is_secret or field_meta.get('type') == 'SecretStr': value = '****' - field_escaped = escape_markdown_v2(field) value_escaped = escape_markdown_v2(str(value)) lines.append(f"*{field_escaped}:* `{value_escaped}` ✅") elif field == current_field: - # Current field being filled - field_escaped = escape_markdown_v2(field) - lines.append(f"*{field_escaped}:* _\\(awaiting input\\)_") + # Current field being filled - show with type hint + type_hint = _format_field_type_hint(field_meta) + if type_hint: + type_hint_escaped = escape_markdown_v2(type_hint) + lines.append(f"*{field_escaped}:* _\\(awaiting input\\)_") + lines.append(f" ↳ _{type_hint_escaped}_") + else: + lines.append(f"*{field_escaped}:* _\\(awaiting input\\)_") break else: - # Future field - show placeholder - field_escaped = escape_markdown_v2(field) - lines.append(f"*{field_escaped}:* \\_\\_\\_") + # Future field - show placeholder with optional indicator + is_optional = field_meta and not field_meta.get('required', True) + optional_marker = " \\(optional\\)" if is_optional else "" + lines.append(f"*{field_escaped}:*{optional_marker} \\_\\_\\_") message_text = "\n".join(lines) - # Build keyboard with back and cancel buttons + # Build keyboard + keyboard = [] + + # Add option buttons for Literal and bool types + current_field_meta = field_metadata.get(current_field, {}) + if current_field_meta.get('type') == 'Literal' and current_field_meta.get('allowed_values'): + allowed_values = current_field_meta['allowed_values'] + # Create buttons for each allowed value (max 2 per row) + option_buttons = [] + for value in allowed_values: + option_buttons.append( + InlineKeyboardButton(value, callback_data=f"api_key_select:{value}") + ) + # Arrange in rows of 2 + for i in range(0, len(option_buttons), 2): + keyboard.append(option_buttons[i:i+2]) + elif current_field_meta.get('type') == 'bool': + # Add true/false buttons for boolean fields + keyboard.append([ + InlineKeyboardButton("✓ true", callback_data="api_key_select:true"), + InlineKeyboardButton("✗ false", callback_data="api_key_select:false") + ]) + + # Add skip button for optional fields + if current_field_meta and not current_field_meta.get('required', True): + keyboard.append([InlineKeyboardButton("⏭ Skip (use default)", callback_data="api_key_skip")]) + + # Build navigation buttons buttons = [] # Add back button if not on first field @@ -727,11 +1081,10 @@ def _build_api_key_config_message(config_data: dict, current_field: str, all_fie if current_index > 0: buttons.append(InlineKeyboardButton("« Back", callback_data="api_key_config_back")) - # Always add cancel button - encoded_account = base64.b64encode(account_name.encode()).decode() - buttons.append(InlineKeyboardButton("❌ Cancel", callback_data=f"api_key_back_account:{encoded_account}")) + # Always add cancel button - navigate back to connector type view + buttons.append(InlineKeyboardButton("❌ Cancel", callback_data=f"api_key_type:{connector_type}")) - keyboard = [buttons] + keyboard.append(buttons) reply_markup = InlineKeyboardMarkup(keyboard) return message_text, reply_markup diff --git a/handlers/config/gateway/__init__.py b/handlers/config/gateway/__init__.py index 4bcb9e9..c4331ef 100644 --- a/handlers/config/gateway/__init__.py +++ b/handlers/config/gateway/__init__.py @@ -14,6 +14,24 @@ from telegram import Update from telegram.ext import ContextTypes +from utils.auth import restricted + + +@restricted +async def gateway_command(update: Update, context: ContextTypes.DEFAULT_TYPE) -> None: + """Handle /gateway command - show Gateway configuration directly.""" + from handlers import clear_all_input_states + from utils.telegram_helpers import create_mock_query_from_message + from .menu import show_gateway_menu + + clear_all_input_states(context) + # Clear trading states that might interfere + context.user_data.pop("dex_state", None) + context.user_data.pop("cex_state", None) + mock_query = await create_mock_query_from_message(update, "Loading Gateway...") + await show_gateway_menu(mock_query, context) + + # Import all submodule handlers from .deployment import ( start_deploy_gateway, @@ -22,6 +40,7 @@ stop_gateway, restart_gateway, show_gateway_logs, + handle_deployment_input, ) from .wallets import show_wallets_menu, handle_wallet_action, handle_wallet_input from .connectors import show_connectors_menu, handle_connector_action, handle_connector_config_input @@ -82,7 +101,9 @@ async def handle_gateway_callback(update: Update, context: ContextTypes.DEFAULT_ async def handle_gateway_input(update: Update, context: ContextTypes.DEFAULT_TYPE) -> None: """Route text input to the appropriate gateway module""" # Check which type of input we're awaiting - if context.user_data.get('awaiting_wallet_input'): + if context.user_data.get('awaiting_gateway_input'): + await handle_deployment_input(update, context) + elif context.user_data.get('awaiting_wallet_input'): await handle_wallet_input(update, context) elif context.user_data.get('awaiting_token_input'): await handle_token_input(update, context) @@ -95,6 +116,7 @@ async def handle_gateway_input(update: Update, context: ContextTypes.DEFAULT_TYP __all__ = [ + 'gateway_command', 'handle_gateway_callback', 'handle_gateway_input', ] diff --git a/handlers/config/gateway/_shared.py b/handlers/config/gateway/_shared.py index fd9387c..70c1727 100644 --- a/handlers/config/gateway/_shared.py +++ b/handlers/config/gateway/_shared.py @@ -4,8 +4,6 @@ import logging from typing import List, Dict, Any -from telegram import InlineKeyboardButton, InlineKeyboardMarkup -from telegram.ext import ContextTypes from utils.telegram_formatters import escape_markdown_v2 diff --git a/handlers/config/gateway/connectors.py b/handlers/config/gateway/connectors.py index 6977753..2973787 100644 --- a/handlers/config/gateway/connectors.py +++ b/handlers/config/gateway/connectors.py @@ -6,17 +6,18 @@ from telegram.ext import ContextTypes from ._shared import logger, escape_markdown_v2 +from ..user_preferences import get_active_server async def show_connectors_menu(query, context: ContextTypes.DEFAULT_TYPE) -> None: """Show DEX connectors configuration menu""" try: - from servers import server_manager + from config_manager import get_config_manager await query.answer("Loading connectors...") chat_id = query.message.chat_id - client = await server_manager.get_client_for_chat(chat_id) + client = await get_config_manager().get_client_for_chat(chat_id, preferred_server=get_active_server(context.user_data)) response = await client.gateway.list_connectors() connectors = response.get('connectors', []) @@ -113,10 +114,10 @@ async def handle_connector_action(query, context: ContextTypes.DEFAULT_TYPE) -> async def show_connector_details(query, context: ContextTypes.DEFAULT_TYPE, connector_name: str) -> None: """Show details and configuration for a specific connector""" try: - from servers import server_manager + from config_manager import get_config_manager chat_id = query.message.chat_id - client = await server_manager.get_client_for_chat(chat_id) + client = await get_config_manager().get_client_for_chat(chat_id, preferred_server=get_active_server(context.user_data)) response = await client.gateway.get_connector_config(connector_name) # Try to extract config - it might be directly in response or nested under 'config' @@ -179,10 +180,10 @@ async def show_connector_details(query, context: ContextTypes.DEFAULT_TYPE, conn async def start_connector_config_edit(query, context: ContextTypes.DEFAULT_TYPE, connector_name: str) -> None: """Start progressive configuration editing flow for a connector""" try: - from servers import server_manager + from config_manager import get_config_manager chat_id = query.message.chat_id - client = await server_manager.get_client_for_chat(chat_id) + client = await get_config_manager().get_client_for_chat(chat_id, preferred_server=get_active_server(context.user_data)) response = await client.gateway.get_connector_config(connector_name) # Extract config @@ -365,7 +366,7 @@ async def handle_connector_config_input(update: Update, context: ContextTypes.DE async def submit_connector_config(context: ContextTypes.DEFAULT_TYPE, bot, chat_id: int) -> None: """Submit the connector configuration to Gateway""" try: - from servers import server_manager + from config_manager import get_config_manager config_data = context.user_data.get('connector_config_data', {}) connector_name = config_data.get('connector_name') @@ -396,7 +397,7 @@ async def submit_connector_config(context: ContextTypes.DEFAULT_TYPE, bot, chat_ parse_mode="MarkdownV2" ) - client = await server_manager.get_client_for_chat(chat_id) + client = await get_config_manager().get_client_for_chat(chat_id, preferred_server=get_active_server(context.user_data)) # Update configuration using the gateway API await client.gateway.update_connector_config(connector_name, final_config) diff --git a/handlers/config/gateway/deployment.py b/handlers/config/gateway/deployment.py index 2b70f2f..7f41a43 100644 --- a/handlers/config/gateway/deployment.py +++ b/handlers/config/gateway/deployment.py @@ -2,10 +2,11 @@ Gateway deployment, lifecycle, and logs management """ -from telegram import Update, InlineKeyboardButton, InlineKeyboardMarkup +from telegram import InlineKeyboardButton, InlineKeyboardMarkup from telegram.ext import ContextTypes from ..server_context import build_config_message_header +from ..user_preferences import get_active_server from ._shared import logger, escape_markdown_v2 @@ -16,7 +17,8 @@ async def start_deploy_gateway(query, context: ContextTypes.DEFAULT_TYPE) -> Non header, server_online, _ = await build_config_message_header( "🚀 Deploy Gateway", include_gateway=False, - chat_id=chat_id + chat_id=chat_id, + user_data=context.user_data ) if not server_online: @@ -55,44 +57,144 @@ async def start_deploy_gateway(query, context: ContextTypes.DEFAULT_TYPE) -> Non async def deploy_gateway_with_image(query, context: ContextTypes.DEFAULT_TYPE) -> None: - """Deploy Gateway container with selected Docker image""" + """Prompt for passphrase before deploying Gateway""" try: - from servers import server_manager - from .menu import show_gateway_menu - # Extract image tag from callback data image_tag = query.data.replace("gateway_deploy_image_", "") docker_image = f"hummingbot/gateway:{image_tag}" - await query.answer("🚀 Deploying Gateway...") + # Store image and prompt for passphrase + context.user_data['gateway_deploy_image'] = docker_image + await prompt_passphrase(query, context) + + except Exception as e: + logger.error(f"Error starting deploy flow: {e}", exc_info=True) + await query.answer(f"❌ Error: {str(e)[:100]}") + +async def prompt_passphrase(query, context: ContextTypes.DEFAULT_TYPE) -> None: + """Prompt user to enter passphrase for Gateway deployment""" + try: chat_id = query.message.chat_id - client = await server_manager.get_client_for_chat(chat_id) + header, server_online, _ = await build_config_message_header( + "🔐 Gateway Passphrase", + include_gateway=False, + chat_id=chat_id, + user_data=context.user_data + ) + + docker_image = context.user_data.get('gateway_deploy_image', 'hummingbot/gateway:latest') + image_escaped = escape_markdown_v2(docker_image) + + context.user_data['awaiting_gateway_input'] = 'passphrase' + context.user_data['gateway_message_id'] = query.message.message_id + context.user_data['gateway_chat_id'] = query.message.chat_id + + message_text = ( + header + + f"*Image:* `{image_escaped}`\n\n" + "*Enter Gateway Passphrase:*\n\n" + "This passphrase is used by Gateway to encrypt stored wallet keys\\.\n\n" + "_Please send your passphrase as a message\\._" + ) + + keyboard = [[InlineKeyboardButton("« Cancel", callback_data="gateway_deploy")]] + reply_markup = InlineKeyboardMarkup(keyboard) + + await query.message.edit_text( + message_text, + parse_mode="MarkdownV2", + reply_markup=reply_markup + ) + await query.answer() + + except Exception as e: + logger.error(f"Error prompting passphrase: {e}", exc_info=True) + await query.answer(f"❌ Error: {str(e)[:100]}") + + +async def execute_gateway_deploy(context: ContextTypes.DEFAULT_TYPE, chat_id: int, message_id: int, docker_image: str, passphrase: str) -> None: + """Execute the actual Gateway deployment with provided config""" + from config_manager import get_config_manager + from .menu import show_gateway_menu + + try: + client = await get_config_manager().get_client_for_chat(chat_id, preferred_server=get_active_server(context.user_data)) # Gateway configuration config = { "image": docker_image, "port": 15888, - "passphrase": "a", + "passphrase": passphrase, "dev_mode": True, } + # Show deploying message + await context.bot.edit_message_text( + chat_id=chat_id, + message_id=message_id, + text="🚀 *Deploying Gateway\\.\\.\\.*\n\n_Please wait, this may take a moment\\._", + parse_mode="MarkdownV2" + ) + response = await client.gateway.start(config) - if response.get('status') == 'success' or response.get('status') == 'running': - await query.answer("✅ Gateway deployed successfully") + success = response.get('status') == 'success' or response.get('status') == 'running' + + if success: + result_text = "✅ *Gateway Deployed Successfully*\n\n_Returning to menu\\.\\.\\._" else: - await query.answer("⚠️ Gateway deployment may need verification") + result_text = "⚠️ *Gateway Deployment Completed*\n\n_Verifying status\\.\\.\\._" - # Refresh the gateway menu to show new status - await show_gateway_menu(query, context) + await context.bot.edit_message_text( + chat_id=chat_id, + message_id=message_id, + text=result_text, + parse_mode="MarkdownV2" + ) + + # Brief pause then show menu + import asyncio + await asyncio.sleep(1) + + # Create a mock query object for show_gateway_menu + class MockQuery: + def __init__(self, message): + self.message = message + + async def answer(self, *args, **kwargs): + pass + + class MockMessage: + def __init__(self, chat_id, message_id, bot): + self.chat_id = chat_id + self.message_id = message_id + self._bot = bot + + async def edit_text(self, text, **kwargs): + await self._bot.edit_message_text( + chat_id=self.chat_id, + message_id=self.message_id, + text=text, + **kwargs + ) + + mock_message = MockMessage(chat_id, message_id, context.bot) + mock_query = MockQuery(mock_message) + await show_gateway_menu(mock_query, context) except Exception as e: logger.error(f"Error deploying gateway: {e}", exc_info=True) - await query.answer(f"❌ Deployment failed: {str(e)[:100]}") - # Still refresh menu to show current state - from .menu import show_gateway_menu - await show_gateway_menu(query, context) + error_text = f"❌ *Deployment Failed*\n\n`{escape_markdown_v2(str(e))}`" + keyboard = [[InlineKeyboardButton("« Back", callback_data="config_gateway")]] + reply_markup = InlineKeyboardMarkup(keyboard) + await context.bot.edit_message_text( + chat_id=chat_id, + message_id=message_id, + text=error_text, + parse_mode="MarkdownV2", + reply_markup=reply_markup + ) async def prompt_custom_image(query, context: ContextTypes.DEFAULT_TYPE) -> None: @@ -102,7 +204,8 @@ async def prompt_custom_image(query, context: ContextTypes.DEFAULT_TYPE) -> None header, server_online, _ = await build_config_message_header( "✏️ Custom Gateway Image", include_gateway=False, - chat_id=chat_id + chat_id=chat_id, + user_data=context.user_data ) context.user_data['awaiting_gateway_input'] = 'custom_image' @@ -136,13 +239,13 @@ async def prompt_custom_image(query, context: ContextTypes.DEFAULT_TYPE) -> None async def stop_gateway(query, context: ContextTypes.DEFAULT_TYPE) -> None: """Stop Gateway container on the current server""" try: - from servers import server_manager + from config_manager import get_config_manager from .menu import show_gateway_menu await query.answer("⏹ Stopping Gateway...") chat_id = query.message.chat_id - client = await server_manager.get_client_for_chat(chat_id) + client = await get_config_manager().get_client_for_chat(chat_id, preferred_server=get_active_server(context.user_data)) response = await client.gateway.stop() if response.get('status') == 'success' or response.get('status') == 'stopped': @@ -163,7 +266,7 @@ async def stop_gateway(query, context: ContextTypes.DEFAULT_TYPE) -> None: async def restart_gateway(query, context: ContextTypes.DEFAULT_TYPE) -> None: """Restart Gateway container on the current server""" try: - from servers import server_manager + from config_manager import get_config_manager from .menu import show_gateway_menu import asyncio @@ -176,7 +279,8 @@ async def restart_gateway(query, context: ContextTypes.DEFAULT_TYPE) -> None: header, _, _ = await build_config_message_header( "🌐 Gateway Configuration", include_gateway=False, # Don't check status during restart - chat_id=chat_id + chat_id=chat_id, + user_data=context.user_data ) restarting_text = ( @@ -194,7 +298,7 @@ async def restart_gateway(query, context: ContextTypes.DEFAULT_TYPE) -> None: pass # Ignore if message can't be edited # Perform the restart - client = await server_manager.get_client_for_chat(chat_id) + client = await get_config_manager().get_client_for_chat(chat_id, preferred_server=get_active_server(context.user_data)) response = await client.gateway.restart() # Wait a moment for the restart to take effect @@ -240,12 +344,12 @@ async def restart_gateway(query, context: ContextTypes.DEFAULT_TYPE) -> None: async def show_gateway_logs(query, context: ContextTypes.DEFAULT_TYPE) -> None: """Show Gateway container logs""" try: - from servers import server_manager + from config_manager import get_config_manager await query.answer("📋 Loading logs...") chat_id = query.message.chat_id - client = await server_manager.get_client_for_chat(chat_id) + client = await get_config_manager().get_client_for_chat(chat_id, preferred_server=get_active_server(context.user_data)) response = await client.gateway.get_logs(tail=50) logs = response.get('logs', 'No logs available') @@ -295,3 +399,95 @@ async def show_gateway_logs(query, context: ContextTypes.DEFAULT_TYPE) -> None: keyboard = [[InlineKeyboardButton("« Back", callback_data="config_gateway")]] reply_markup = InlineKeyboardMarkup(keyboard) await query.message.edit_text(error_text, parse_mode="MarkdownV2", reply_markup=reply_markup) + + +async def handle_deployment_input(update, context) -> None: + """Handle text input during gateway deployment flow""" + + awaiting_field = context.user_data.get('awaiting_gateway_input') + if not awaiting_field: + return + + # Delete user's input message for security (passphrase shouldn't be visible) + try: + await update.message.delete() + except: + pass + + try: + message_id = context.user_data.get('gateway_message_id') + chat_id = context.user_data.get('gateway_chat_id') + + if awaiting_field == 'passphrase': + passphrase = update.message.text.strip() + docker_image = context.user_data.get('gateway_deploy_image', 'hummingbot/gateway:latest') + + # Clear context + context.user_data.pop('awaiting_gateway_input', None) + context.user_data.pop('gateway_deploy_image', None) + context.user_data.pop('gateway_message_id', None) + context.user_data.pop('gateway_chat_id', None) + + if not passphrase: + await update.get_bot().edit_message_text( + chat_id=chat_id, + message_id=message_id, + text="❌ Passphrase cannot be empty", + parse_mode="MarkdownV2" + ) + return + + # Execute deployment with provided passphrase + await execute_gateway_deploy(context, chat_id, message_id, docker_image, passphrase) + + elif awaiting_field == 'custom_image': + custom_image = update.message.text.strip() + + # Clear context + context.user_data.pop('awaiting_gateway_input', None) + context.user_data.pop('gateway_message_id', None) + context.user_data.pop('gateway_chat_id', None) + + if not custom_image: + await update.get_bot().edit_message_text( + chat_id=chat_id, + message_id=message_id, + text="❌ Image name cannot be empty", + parse_mode="MarkdownV2" + ) + return + + # Store custom image and prompt for passphrase + context.user_data['gateway_deploy_image'] = custom_image + context.user_data['gateway_message_id'] = message_id + context.user_data['gateway_chat_id'] = chat_id + + # Create mock query for prompt_passphrase + class MockQuery: + def __init__(self, message): + self.message = message + + async def answer(self, *args, **kwargs): + pass + + class MockMessage: + def __init__(self, chat_id, message_id, bot): + self.chat_id = chat_id + self.message_id = message_id + self._bot = bot + + async def edit_text(self, text, **kwargs): + await self._bot.edit_message_text( + chat_id=self.chat_id, + message_id=self.message_id, + text=text, + **kwargs + ) + + mock_message = MockMessage(chat_id, message_id, context.bot) + mock_query = MockQuery(mock_message) + await prompt_passphrase(mock_query, context) + + except Exception as e: + logger.error(f"Error handling deployment input: {e}", exc_info=True) + context.user_data.pop('awaiting_gateway_input', None) diff --git a/handlers/config/gateway/menu.py b/handlers/config/gateway/menu.py index 25330a9..6ef28ad 100644 --- a/handlers/config/gateway/menu.py +++ b/handlers/config/gateway/menu.py @@ -2,7 +2,7 @@ Gateway menu and server selection """ -from telegram import Update, InlineKeyboardButton, InlineKeyboardMarkup +from telegram import InlineKeyboardButton, InlineKeyboardMarkup from telegram.ext import ContextTypes from ..server_context import build_config_message_header, format_server_selection_needed @@ -14,20 +14,21 @@ async def show_gateway_menu(query, context: ContextTypes.DEFAULT_TYPE) -> None: Show gateway configuration menu with status for default server """ try: - from servers import server_manager + from config_manager import get_config_manager - servers = server_manager.list_servers() + servers = get_config_manager().list_servers() if not servers: message_text = format_server_selection_needed() - keyboard = [[InlineKeyboardButton("« Back", callback_data="config_back")]] + keyboard = [[InlineKeyboardButton("« Close", callback_data="config_close")]] else: # Build unified header with server and gateway info chat_id = query.message.chat_id header, server_online, gateway_running = await build_config_message_header( "🌐 Gateway Configuration", include_gateway=True, - chat_id=chat_id + chat_id=chat_id, + user_data=context.user_data ) message_text = header @@ -62,8 +63,8 @@ async def show_gateway_menu(query, context: ContextTypes.DEFAULT_TYPE) -> None: InlineKeyboardButton("🚀 Deploy Gateway", callback_data="gateway_deploy"), ]) - # Add back button - keyboard.append([InlineKeyboardButton("« Back", callback_data="config_back")]) + # Add close button + keyboard.append([InlineKeyboardButton("« Close", callback_data="config_close")]) reply_markup = InlineKeyboardMarkup(keyboard) @@ -90,7 +91,7 @@ async def show_gateway_menu(query, context: ContextTypes.DEFAULT_TYPE) -> None: except Exception as e: logger.error(f"Error showing gateway menu: {e}", exc_info=True) error_text = f"❌ Error loading gateway: {escape_markdown_v2(str(e))}" - keyboard = [[InlineKeyboardButton("« Back", callback_data="config_back")]] + keyboard = [[InlineKeyboardButton("« Close", callback_data="config_close")]] reply_markup = InlineKeyboardMarkup(keyboard) await query.message.edit_text(error_text, parse_mode="MarkdownV2", reply_markup=reply_markup) @@ -98,10 +99,10 @@ async def show_gateway_menu(query, context: ContextTypes.DEFAULT_TYPE) -> None: async def show_server_selection(query, context: ContextTypes.DEFAULT_TYPE) -> None: """Show server selection menu for gateway configuration""" try: - from servers import server_manager + from config_manager import get_config_manager - servers = server_manager.list_servers() - default_server = server_manager.get_default_server() + servers = get_config_manager().list_servers() + default_server = get_config_manager().get_default_server() message_text = ( "🔄 *Select Server*\n\n" @@ -138,13 +139,13 @@ async def show_server_selection(query, context: ContextTypes.DEFAULT_TYPE) -> No async def handle_server_selection(query, context: ContextTypes.DEFAULT_TYPE) -> None: """Handle server selection for gateway configuration""" try: - from servers import server_manager + from config_manager import get_config_manager server_name = query.data.replace("gateway_server_", "") # Set as default server temporarily for this session # Or we could store it in context for this specific flow - success = server_manager.set_default_server(server_name) + success = get_config_manager().set_default_server(server_name) if success: await query.answer(f"✅ Switched to {server_name}") diff --git a/handlers/config/gateway/networks.py b/handlers/config/gateway/networks.py index d277e06..3ebb62e 100644 --- a/handlers/config/gateway/networks.py +++ b/handlers/config/gateway/networks.py @@ -6,17 +6,18 @@ from telegram.ext import ContextTypes from ._shared import logger, escape_markdown_v2, extract_network_id +from ..user_preferences import get_active_server async def show_networks_menu(query, context: ContextTypes.DEFAULT_TYPE) -> None: """Show networks configuration menu""" try: - from servers import server_manager + from config_manager import get_config_manager await query.answer("Loading networks...") chat_id = query.message.chat_id - client = await server_manager.get_client_for_chat(chat_id) + client = await get_config_manager().get_client_for_chat(chat_id, preferred_server=get_active_server(context.user_data)) response = await client.gateway.list_networks() networks = response.get('networks', []) @@ -112,10 +113,10 @@ async def handle_network_action(query, context: ContextTypes.DEFAULT_TYPE) -> No async def show_network_details(query, context: ContextTypes.DEFAULT_TYPE, network_id: str) -> None: """Show network config in edit mode - user can copy/paste to change values""" try: - from servers import server_manager + from config_manager import get_config_manager chat_id = query.message.chat_id - client = await server_manager.get_client_for_chat(chat_id) + client = await get_config_manager().get_client_for_chat(chat_id, preferred_server=get_active_server(context.user_data)) response = await client.gateway.get_network_config(network_id) # Try to extract config - it might be directly in response or nested under 'config' @@ -258,7 +259,7 @@ async def handle_network_config_input(update: Update, context: ContextTypes.DEFA async def submit_network_config(context: ContextTypes.DEFAULT_TYPE, bot, chat_id: int) -> None: """Submit the network configuration to Gateway""" try: - from servers import server_manager + from config_manager import get_config_manager config_data = context.user_data.get('network_config_data', {}) network_id = config_data.get('network_id') @@ -292,7 +293,7 @@ async def submit_network_config(context: ContextTypes.DEFAULT_TYPE, bot, chat_id context.user_data.pop('awaiting_network_input', None) # Submit configuration to Gateway - client = await server_manager.get_client_for_chat(chat_id) + client = await get_config_manager().get_client_for_chat(chat_id, preferred_server=get_active_server(context.user_data)) await client.gateway.update_network_config(network_id, final_config) success_text = f"✅ Configuration saved for {escape_markdown_v2(network_id)}\\!" diff --git a/handlers/config/gateway/pools.py b/handlers/config/gateway/pools.py index 842881f..a89c743 100644 --- a/handlers/config/gateway/pools.py +++ b/handlers/config/gateway/pools.py @@ -7,18 +7,19 @@ from telegram.ext import ContextTypes from ._shared import logger, escape_markdown_v2, filter_pool_connectors, extract_network_id +from ..user_preferences import get_active_server from utils.telegram_formatters import resolve_token_address async def show_pools_menu(query, context: ContextTypes.DEFAULT_TYPE) -> None: """Show liquidity pools menu - select connector first""" try: - from servers import server_manager + from config_manager import get_config_manager await query.answer("Loading connectors...") chat_id = query.message.chat_id - client = await server_manager.get_client_for_chat(chat_id) + client = await get_config_manager().get_client_for_chat(chat_id, preferred_server=get_active_server(context.user_data)) response = await client.gateway.list_connectors() connectors = response.get('connectors', []) @@ -167,7 +168,7 @@ async def handle_pool_action(query, context: ContextTypes.DEFAULT_TYPE) -> None: async def show_pool_networks(query, context: ContextTypes.DEFAULT_TYPE, connector_name: str) -> None: """Show network selection for viewing pools - only connector-specific networks""" try: - from servers import server_manager + from config_manager import get_config_manager await query.answer("Loading networks...") @@ -178,7 +179,7 @@ async def show_pool_networks(query, context: ContextTypes.DEFAULT_TYPE, connecto if not connector_info: # Fallback: fetch connector info again if not in context chat_id = query.message.chat_id - client = await server_manager.get_client_for_chat(chat_id) + client = await get_config_manager().get_client_for_chat(chat_id, preferred_server=get_active_server(context.user_data)) response = await client.gateway.list_connectors() connectors = response.get('connectors', []) connector_info = next((c for c in connectors if c.get('name') == connector_name), None) @@ -248,12 +249,12 @@ async def show_pool_networks(query, context: ContextTypes.DEFAULT_TYPE, connecto async def show_connector_pools(query, context: ContextTypes.DEFAULT_TYPE, connector_name: str, network: str) -> None: """Show pools for a specific connector and network""" try: - from servers import server_manager + from config_manager import get_config_manager await query.answer("Loading pools...") chat_id = query.message.chat_id - client = await server_manager.get_client_for_chat(chat_id) + client = await get_config_manager().get_client_for_chat(chat_id, preferred_server=get_active_server(context.user_data)) pools = await client.gateway.list_pools(connector_name=connector_name, network=network) connector_escaped = escape_markdown_v2(connector_name) @@ -367,7 +368,7 @@ async def prompt_add_pool(query, context: ContextTypes.DEFAULT_TYPE, connector_n async def prompt_remove_pool(query, context: ContextTypes.DEFAULT_TYPE, connector_name: str, network: str) -> None: """Show list of pools to remove with numbered buttons""" try: - from servers import server_manager + from config_manager import get_config_manager connector_escaped = escape_markdown_v2(connector_name) network_escaped = escape_markdown_v2(network) @@ -375,7 +376,7 @@ async def prompt_remove_pool(query, context: ContextTypes.DEFAULT_TYPE, connecto chat_id = query.message.chat_id # Fetch pools to display as options - client = await server_manager.get_client_for_chat(chat_id) + client = await get_config_manager().get_client_for_chat(chat_id, preferred_server=get_active_server(context.user_data)) pools = await client.gateway.list_pools(connector_name=connector_name, network=network) if not pools: @@ -476,7 +477,7 @@ async def show_delete_pool_confirmation(query, context: ContextTypes.DEFAULT_TYP async def remove_pool(query, context: ContextTypes.DEFAULT_TYPE, connector_name: str, network: str, pool_address: str, pool_type: str) -> None: """Remove a pool from Gateway""" try: - from servers import server_manager + from config_manager import get_config_manager try: await query.answer("Removing pool...") @@ -484,7 +485,7 @@ async def remove_pool(query, context: ContextTypes.DEFAULT_TYPE, connector_name: pass # Mock query doesn't support answer chat_id = query.message.chat_id - client = await server_manager.get_client_for_chat(chat_id) + client = await get_config_manager().get_client_for_chat(chat_id, preferred_server=get_active_server(context.user_data)) await client.gateway.delete_pool(connector=connector_name, network=network, pool_type=pool_type, address=pool_address) connector_escaped = escape_markdown_v2(connector_name) @@ -533,7 +534,7 @@ async def handle_pool_input(update: Update, context: ContextTypes.DEFAULT_TYPE) pass try: - from servers import server_manager + from config_manager import get_config_manager from types import SimpleNamespace connector_name = context.user_data.get('pool_connector') @@ -583,7 +584,7 @@ async def handle_pool_input(update: Update, context: ContextTypes.DEFAULT_TYPE) ) try: - client = await server_manager.get_client_for_chat(chat_id) + client = await get_config_manager().get_client_for_chat(chat_id, preferred_server=get_active_server(context.user_data)) logger.info(f"Adding pool: connector={connector_name}, network={network}, " f"pool_type={pool_type}, base={base}, quote={quote}, address={address}, " diff --git a/handlers/config/gateway/tokens.py b/handlers/config/gateway/tokens.py index 0bba80e..7cd0206 100644 --- a/handlers/config/gateway/tokens.py +++ b/handlers/config/gateway/tokens.py @@ -7,6 +7,7 @@ from geckoterminal_py import GeckoTerminalAsyncClient from ._shared import logger, escape_markdown_v2, extract_network_id +from ..user_preferences import get_active_server # Gateway network ID -> GeckoTerminal network ID mapping @@ -31,12 +32,12 @@ async def show_tokens_menu(query, context: ContextTypes.DEFAULT_TYPE) -> None: """Show tokens menu - select network to view tokens""" try: - from servers import server_manager + from config_manager import get_config_manager await query.answer("Loading networks...") chat_id = query.message.chat_id - client = await server_manager.get_client_for_chat(chat_id) + client = await get_config_manager().get_client_for_chat(chat_id, preferred_server=get_active_server(context.user_data)) response = await client.gateway.list_networks() networks = response.get('networks', []) @@ -178,12 +179,12 @@ async def show_network_tokens(query, context: ContextTypes.DEFAULT_TYPE, network COLUMNS = 4 try: - from servers import server_manager + from config_manager import get_config_manager await query.answer("Loading tokens...") chat_id = query.message.chat_id - client = await server_manager.get_client_for_chat(chat_id) + client = await get_config_manager().get_client_for_chat(chat_id, preferred_server=get_active_server(context.user_data)) # Try to get tokens - the method might not exist in older versions try: @@ -334,12 +335,12 @@ async def prompt_add_token(query, context: ContextTypes.DEFAULT_TYPE, network_id async def prompt_remove_token(query, context: ContextTypes.DEFAULT_TYPE, network_id: str) -> None: """Show list of tokens to select for editing or removal""" try: - from servers import server_manager + from config_manager import get_config_manager await query.answer("Loading tokens...") chat_id = query.message.chat_id - client = await server_manager.get_client_for_chat(chat_id) + client = await get_config_manager().get_client_for_chat(chat_id, preferred_server=get_active_server(context.user_data)) # Get tokens for the network try: @@ -514,12 +515,12 @@ async def prompt_edit_token(query, context: ContextTypes.DEFAULT_TYPE, token_idx async def show_delete_token_confirmation(query, context: ContextTypes.DEFAULT_TYPE, network_id: str, token_address: str) -> None: """Show confirmation dialog before deleting a token""" try: - from servers import server_manager + from config_manager import get_config_manager chat_id = query.message.chat_id # Get token details to show in confirmation - client = await server_manager.get_client_for_chat(chat_id) + client = await get_config_manager().get_client_for_chat(chat_id, preferred_server=get_active_server(context.user_data)) # Try to get tokens - the method might not exist in older versions try: @@ -585,12 +586,12 @@ async def show_delete_token_confirmation(query, context: ContextTypes.DEFAULT_TY async def remove_token(query, context: ContextTypes.DEFAULT_TYPE, network_id: str, token_address: str) -> None: """Remove a token from Gateway""" try: - from servers import server_manager + from config_manager import get_config_manager await query.answer("Removing token...") chat_id = query.message.chat_id - client = await server_manager.get_client_for_chat(chat_id) + client = await get_config_manager().get_client_for_chat(chat_id, preferred_server=get_active_server(context.user_data)) await client.gateway.delete_token(network_id=network_id, token_address=token_address) network_escaped = escape_markdown_v2(network_id) @@ -641,7 +642,7 @@ async def handle_token_input(update: Update, context: ContextTypes.DEFAULT_TYPE) pass try: - from servers import server_manager + from config_manager import get_config_manager from types import SimpleNamespace network_id = context.user_data.get('token_network') @@ -724,7 +725,7 @@ async def handle_token_input(update: Update, context: ContextTypes.DEFAULT_TYPE) ) try: - client = await server_manager.get_client_for_chat(chat_id) + client = await get_config_manager().get_client_for_chat(chat_id, preferred_server=get_active_server(context.user_data)) await client.gateway.add_token( network_id=network_id, address=address, @@ -857,7 +858,7 @@ async def mock_answer(text=""): ) try: - client = await server_manager.get_client_for_chat(chat_id) + client = await get_config_manager().get_client_for_chat(chat_id, preferred_server=get_active_server(context.user_data)) # Delete old token first, then add with new values await client.gateway.delete_token(network_id=network_id, token_address=token_address) diff --git a/handlers/config/gateway/wallets.py b/handlers/config/gateway/wallets.py index ad5459f..f888ed5 100644 --- a/handlers/config/gateway/wallets.py +++ b/handlers/config/gateway/wallets.py @@ -12,6 +12,7 @@ remove_wallet_networks, get_default_networks_for_chain, get_all_networks_for_chain, + get_active_server, ) from ._shared import logger, escape_markdown_v2 @@ -19,12 +20,12 @@ async def show_wallets_menu(query, context: ContextTypes.DEFAULT_TYPE) -> None: """Show wallets management menu with list of connected wallets as clickable buttons""" try: - from servers import server_manager + from config_manager import get_config_manager await query.answer("Loading wallets...") chat_id = query.message.chat_id - client = await server_manager.get_client_for_chat(chat_id) + client = await get_config_manager().get_client_for_chat(chat_id, preferred_server=get_active_server(context.user_data)) # Get list of gateway wallets try: @@ -37,7 +38,8 @@ async def show_wallets_menu(query, context: ContextTypes.DEFAULT_TYPE) -> None: header, server_online, gateway_running = await build_config_message_header( "🔑 Wallet Management", include_gateway=True, - chat_id=chat_id + chat_id=chat_id, + user_data=context.user_data ) if not server_online: @@ -56,10 +58,13 @@ async def show_wallets_menu(query, context: ContextTypes.DEFAULT_TYPE) -> None: message_text = ( header + "_No wallets connected\\._\n\n" - "Add a wallet to get started\\." + "Add an existing wallet or create a new one\\." ) keyboard = [ - [InlineKeyboardButton("➕ Add Wallet", callback_data="gateway_wallet_add")], + [ + InlineKeyboardButton("➕ Add Wallet", callback_data="gateway_wallet_add"), + InlineKeyboardButton("🆕 Create Wallet", callback_data="gateway_wallet_create"), + ], [InlineKeyboardButton("« Back to Gateway", callback_data="config_gateway")] ] else: @@ -96,7 +101,10 @@ async def show_wallets_menu(query, context: ContextTypes.DEFAULT_TYPE) -> None: ]) keyboard = wallet_buttons + [ - [InlineKeyboardButton("➕ Add Wallet", callback_data="gateway_wallet_add")], + [ + InlineKeyboardButton("➕ Add Wallet", callback_data="gateway_wallet_add"), + InlineKeyboardButton("🆕 Create Wallet", callback_data="gateway_wallet_create"), + ], [ InlineKeyboardButton("🔄 Refresh", callback_data="gateway_wallets"), InlineKeyboardButton("« Back to Gateway", callback_data="config_gateway") @@ -131,6 +139,11 @@ async def handle_wallet_action(query, context: ContextTypes.DEFAULT_TYPE) -> Non if action_data == "add": await prompt_add_wallet_chain(query, context) + elif action_data == "create": + await prompt_create_wallet_chain(query, context) + elif action_data.startswith("create_chain_"): + chain = action_data.replace("create_chain_", "") + await create_wallet(query, context, chain) elif action_data == "remove": await prompt_remove_wallet_chain(query, context) elif action_data.startswith("view_"): @@ -241,7 +254,8 @@ async def prompt_add_wallet_chain(query, context: ContextTypes.DEFAULT_TYPE) -> header, server_online, gateway_running = await build_config_message_header( "➕ Add Wallet", include_gateway=True, - chat_id=chat_id + chat_id=chat_id, + user_data=context.user_data ) # Base blockchain chains (wallets are at blockchain level, not network level) @@ -279,6 +293,135 @@ async def prompt_add_wallet_chain(query, context: ContextTypes.DEFAULT_TYPE) -> await query.answer(f"❌ Error: {str(e)[:100]}") +async def prompt_create_wallet_chain(query, context: ContextTypes.DEFAULT_TYPE) -> None: + """Prompt user to select chain for creating a new wallet""" + try: + chat_id = query.message.chat_id + header, server_online, gateway_running = await build_config_message_header( + "🆕 Create Wallet", + include_gateway=True, + chat_id=chat_id, + user_data=context.user_data + ) + + # Base blockchain chains (wallets are at blockchain level, not network level) + supported_chains = ["ethereum", "solana"] + + message_text = ( + header + + "*Select Chain:*\n\n" + "_Choose which blockchain to create a new wallet for\\._\n\n" + "⚠️ *Note:* A new wallet with a fresh keypair will be generated\\. " + "Make sure to back up the private key from Gateway\\." + ) + + # Create chain buttons + chain_buttons = [] + for chain in supported_chains: + chain_display = chain.replace("-", " ").title() + chain_icon = "🟣" if chain == "solana" else "🔵" + chain_buttons.append([ + InlineKeyboardButton(f"{chain_icon} {chain_display}", callback_data=f"gateway_wallet_create_chain_{chain}") + ]) + + keyboard = chain_buttons + [ + [InlineKeyboardButton("« Back", callback_data="gateway_wallets")] + ] + + reply_markup = InlineKeyboardMarkup(keyboard) + + await query.message.edit_text( + message_text, + parse_mode="MarkdownV2", + reply_markup=reply_markup + ) + await query.answer() + + except Exception as e: + logger.error(f"Error prompting create wallet chain: {e}", exc_info=True) + await query.answer(f"❌ Error: {str(e)[:100]}") + + +async def create_wallet(query, context: ContextTypes.DEFAULT_TYPE, chain: str) -> None: + """Create a new wallet on the specified chain via Gateway""" + try: + from config_manager import get_config_manager + + await query.answer("Creating wallet...") + + chat_id = query.message.chat_id + client = await get_config_manager().get_client_for_chat(chat_id, preferred_server=get_active_server(context.user_data)) + + chain_escaped = escape_markdown_v2(chain.replace("-", " ").title()) + + # Show creating message + await query.message.edit_text( + f"⏳ *Creating {chain_escaped} Wallet*\n\n_Please wait\\.\\.\\._", + parse_mode="MarkdownV2" + ) + + # Create the wallet via Gateway API + response = await client.gateway.create_wallet(chain=chain, set_default=False) + + # Extract address from response + address = response.get('address', '') if isinstance(response, dict) else '' + + if not address: + raise ValueError("No address returned from wallet creation") + + # Set default networks for the new wallet + default_networks = get_default_networks_for_chain(chain) + set_wallet_networks(context.user_data, address, default_networks) + + # Store info for network selection flow + context.user_data['new_wallet_chain'] = chain + context.user_data['new_wallet_address'] = address + context.user_data['new_wallet_networks'] = list(default_networks) + context.user_data['new_wallet_message_id'] = query.message.message_id + context.user_data['new_wallet_chat_id'] = chat_id + + # Show success message with network selection prompt + display_addr = address[:10] + "..." + address[-8:] if len(address) > 20 else address + addr_escaped = escape_markdown_v2(display_addr) + + # Build network selection message + all_networks = get_all_networks_for_chain(chain) + network_buttons = [] + for net in all_networks: + is_enabled = net in default_networks + status = "✅" if is_enabled else "⬜" + net_display = net.replace("-", " ").title() + button_text = f"{status} {net_display}" + network_buttons.append([ + InlineKeyboardButton(button_text, callback_data=f"gateway_wallet_new_toggle_{net}") + ]) + + success_text = ( + f"✅ *Wallet Created Successfully*\n\n" + f"`{addr_escaped}`\n\n" + f"*Select Networks:*\n" + f"_Choose which networks to enable for balance queries\\._" + ) + + keyboard = network_buttons + [ + [InlineKeyboardButton("✓ Done", callback_data="gateway_wallet_new_net_done")] + ] + reply_markup = InlineKeyboardMarkup(keyboard) + + await query.message.edit_text( + success_text, + parse_mode="MarkdownV2", + reply_markup=reply_markup + ) + + except Exception as e: + logger.error(f"Error creating wallet: {e}", exc_info=True) + error_text = f"❌ Error creating wallet: {escape_markdown_v2(str(e))}" + keyboard = [[InlineKeyboardButton("« Back", callback_data="gateway_wallets")]] + reply_markup = InlineKeyboardMarkup(keyboard) + await query.message.edit_text(error_text, parse_mode="MarkdownV2", reply_markup=reply_markup) + + async def show_wallet_details(query, context: ContextTypes.DEFAULT_TYPE, chain: str, address: str) -> None: """Show details for a specific wallet with edit options""" try: @@ -286,7 +429,8 @@ async def show_wallet_details(query, context: ContextTypes.DEFAULT_TYPE, chain: header, server_online, gateway_running = await build_config_message_header( "🔑 Wallet Details", include_gateway=True, - chat_id=chat_id + chat_id=chat_id, + user_data=context.user_data ) chain_escaped = escape_markdown_v2(chain.title()) @@ -361,7 +505,8 @@ async def show_wallet_network_edit(query, context: ContextTypes.DEFAULT_TYPE, ch header, server_online, gateway_running = await build_config_message_header( "🌐 Edit Networks", include_gateway=True, - chat_id=chat_id + chat_id=chat_id, + user_data=context.user_data ) chain_escaped = escape_markdown_v2(chain.title()) @@ -462,7 +607,8 @@ async def prompt_add_wallet_private_key(query, context: ContextTypes.DEFAULT_TYP header, server_online, gateway_running = await build_config_message_header( f"➕ Add {chain.replace('-', ' ').title()} Wallet", include_gateway=True, - chat_id=chat_id + chat_id=chat_id, + user_data=context.user_data ) context.user_data['awaiting_wallet_input'] = 'add_wallet' @@ -499,10 +645,10 @@ async def prompt_add_wallet_private_key(query, context: ContextTypes.DEFAULT_TYP async def prompt_remove_wallet_chain(query, context: ContextTypes.DEFAULT_TYPE) -> None: """Prompt user to select chain for removing wallet""" try: - from servers import server_manager + from config_manager import get_config_manager chat_id = query.message.chat_id - client = await server_manager.get_client_for_chat(chat_id) + client = await get_config_manager().get_client_for_chat(chat_id, preferred_server=get_active_server(context.user_data)) # Get list of gateway wallets try: @@ -523,7 +669,8 @@ async def prompt_remove_wallet_chain(query, context: ContextTypes.DEFAULT_TYPE) header, server_online, gateway_running = await build_config_message_header( "➖ Remove Wallet", include_gateway=True, - chat_id=chat_id + chat_id=chat_id, + user_data=context.user_data ) message_text = ( @@ -561,10 +708,10 @@ async def prompt_remove_wallet_chain(query, context: ContextTypes.DEFAULT_TYPE) async def prompt_remove_wallet_address(query, context: ContextTypes.DEFAULT_TYPE, chain: str) -> None: """Prompt user to select wallet address to remove""" try: - from servers import server_manager + from config_manager import get_config_manager chat_id = query.message.chat_id - client = await server_manager.get_client_for_chat(chat_id) + client = await get_config_manager().get_client_for_chat(chat_id, preferred_server=get_active_server(context.user_data)) # Get wallets for this chain try: @@ -592,7 +739,8 @@ async def prompt_remove_wallet_address(query, context: ContextTypes.DEFAULT_TYPE header, server_online, gateway_running = await build_config_message_header( f"➖ Remove {chain.replace('-', ' ').title()} Wallet", include_gateway=True, - chat_id=chat_id + chat_id=chat_id, + user_data=context.user_data ) chain_escaped = escape_markdown_v2(chain.replace("-", " ").title()) @@ -635,12 +783,12 @@ async def prompt_remove_wallet_address(query, context: ContextTypes.DEFAULT_TYPE async def remove_wallet(query, context: ContextTypes.DEFAULT_TYPE, chain: str, address: str) -> None: """Remove a wallet from Gateway""" try: - from servers import server_manager + from config_manager import get_config_manager await query.answer("Removing wallet...") chat_id = query.message.chat_id - client = await server_manager.get_client_for_chat(chat_id) + client = await get_config_manager().get_client_for_chat(chat_id, preferred_server=get_active_server(context.user_data)) # Remove the wallet from Gateway await client.accounts.remove_gateway_wallet(chain=chain, address=address) @@ -705,8 +853,7 @@ async def handle_wallet_input(update: Update, context: ContextTypes.DEFAULT_TYPE return # Show adding message - from servers import server_manager - from types import SimpleNamespace + from config_manager import get_config_manager chain_escaped = escape_markdown_v2(chain.replace("-", " ").title()) if message_id and chat_id: @@ -718,7 +865,7 @@ async def handle_wallet_input(update: Update, context: ContextTypes.DEFAULT_TYPE ) try: - client = await server_manager.get_client_for_chat(chat_id) + client = await get_config_manager().get_client_for_chat(chat_id, preferred_server=get_active_server(context.user_data)) # Add the wallet response = await client.accounts.add_gateway_wallet(chain=chain, private_key=private_key) diff --git a/handlers/config/server_context.py b/handlers/config/server_context.py index f39d0de..fcac248 100644 --- a/handlers/config/server_context.py +++ b/handlers/config/server_context.py @@ -5,33 +5,33 @@ """ import logging -from typing import Optional, Dict, Any, Tuple +from typing import Tuple from utils.telegram_formatters import escape_markdown_v2 logger = logging.getLogger(__name__) -async def get_server_context_header(chat_id: int = None) -> Tuple[str, bool]: +async def get_server_context_header(user_data: dict = None) -> Tuple[str, bool]: """ Get a standardized server context header showing current server and status. Args: - chat_id: Optional chat ID to get per-chat server. If None, uses global default. + user_data: Optional user_data dict to get user's preferred server Returns: Tuple of (header_text: str, is_online: bool) - header_text: Formatted markdown text with server info and status - is_online: True if server is online, False otherwise """ try: - from servers import server_manager + from config_manager import get_config_manager - # Get default server (per-chat if chat_id provided) - if chat_id is not None: - default_server = server_manager.get_default_server_for_chat(chat_id) - else: - default_server = server_manager.get_default_server() - servers = server_manager.list_servers() + # Get user's preferred server + default_server = None + if user_data: + from handlers.config.user_preferences import get_active_server + default_server = get_active_server(user_data) + if not default_server: + default_server = get_config_manager().get_default_server() + servers = get_config_manager().list_servers() if not servers: return "⚠️ _No servers configured_\n", False @@ -40,12 +40,12 @@ async def get_server_context_header(chat_id: int = None) -> Tuple[str, bool]: return "⚠️ _No default server set_\n", False # Get server config - server_config = server_manager.get_server(default_server) + server_config = get_config_manager().get_server(default_server) if not server_config: return "⚠️ _Server configuration not found_\n", False # Check server status - status_result = await server_manager.check_server_status(default_server) + status_result = await get_config_manager().check_server_status(default_server) status = status_result.get("status", "unknown") # Format status with icon @@ -79,25 +79,25 @@ async def get_server_context_header(chat_id: int = None) -> Tuple[str, bool]: return f"⚠️ _Error loading server info: {escape_markdown_v2(str(e))}_\n", False -async def get_gateway_status_info(chat_id: int = None) -> Tuple[str, bool]: +async def get_gateway_status_info(chat_id: int = None, user_data: dict = None) -> Tuple[str, bool]: """ Get gateway status information for the current server. Args: - chat_id: Optional chat ID to get per-chat server. If None, uses global default. + chat_id: Optional chat ID for getting the API client + user_data: Optional user_data dict to get user's preferred server Returns: - Tuple of (status_text: str, is_running: bool) - status_text: Formatted markdown text with gateway status - is_running: True if gateway is running, False otherwise + Tuple of (gateway_info: str, is_running: bool) """ try: - from servers import server_manager + from config_manager import get_config_manager - if chat_id is not None: - client = await server_manager.get_client_for_chat(chat_id) - else: - client = await server_manager.get_default_client() + preferred = None + if user_data: + from handlers.config.user_preferences import get_active_server + preferred = get_active_server(user_data) + client = await get_config_manager().get_client_for_chat(chat_id, preferred_server=preferred) # Check gateway status try: @@ -131,39 +131,24 @@ async def get_gateway_status_info(chat_id: int = None) -> Tuple[str, bool]: async def build_config_message_header( title: str, include_gateway: bool = False, - chat_id: int = None + chat_id: int = None, + user_data: dict = None ) -> Tuple[str, bool, bool]: - """ - Build a standardized header for configuration messages. - - Args: - title: The title/heading for this config screen (will be bolded automatically) - include_gateway: Whether to include gateway status info - chat_id: Optional chat ID to get per-chat server. If None, uses global default. - - Returns: - Tuple of (header_text: str, server_online: bool, gateway_running: bool) - """ - # Escape and bold the title + """Build a standardized header for configuration messages.""" title_escaped = escape_markdown_v2(title) header = f"*{title_escaped}*\n\n" - # Add server context - server_context, server_online = await get_server_context_header(chat_id) + server_context, server_online = await get_server_context_header(user_data) header += server_context - # Add gateway status if requested (but only if server is online to avoid long timeouts) gateway_running = False - if include_gateway: - if server_online: - gateway_info, gateway_running = await get_gateway_status_info(chat_id) - header += gateway_info - else: - # Server is offline, skip gateway check to avoid timeout - header += f"*Gateway:* ⚪️ {escape_markdown_v2('N/A')}\n" + if include_gateway and server_online: + gateway_info, gateway_running = await get_gateway_status_info(chat_id, user_data) + header += gateway_info + elif include_gateway: + header += f"*Gateway:* ⚪️ {escape_markdown_v2('N/A')}\n" header += "\n" - return header, server_online, gateway_running diff --git a/handlers/config/servers.py b/handlers/config/servers.py index 0f66e18..95417d7 100644 --- a/handlers/config/servers.py +++ b/handlers/config/servers.py @@ -2,15 +2,28 @@ API Servers configuration handlers """ +import asyncio import logging from telegram import Update, InlineKeyboardButton, InlineKeyboardMarkup from telegram.ext import ContextTypes, ConversationHandler from utils.telegram_formatters import escape_markdown_v2 -from .server_context import build_config_message_header +from utils.auth import restricted logger = logging.getLogger(__name__) + +@restricted +async def servers_command(update: Update, context: ContextTypes.DEFAULT_TYPE) -> None: + """Handle /servers command - show API servers configuration directly.""" + from handlers import clear_all_input_states + from utils.telegram_helpers import create_mock_query_from_message + + clear_all_input_states(context) + mock_query = await create_mock_query_from_message(update, "Loading servers...") + await show_api_servers(mock_query, context) + + # Conversation states (ADD_SERVER_NAME, ADD_SERVER_HOST, ADD_SERVER_PORT, ADD_SERVER_USERNAME, ADD_SERVER_PASSWORD, ADD_SERVER_CONFIRM, @@ -33,18 +46,23 @@ async def handle_servers_callback(update: Update, context: ContextTypes.DEFAULT_ async def show_api_servers(query, context: ContextTypes.DEFAULT_TYPE) -> None: """ - Show API servers configuration with status and actions + Show API servers configuration with status and actions. + Only shows servers the user has access to. """ try: - from servers import server_manager + from config_manager import get_config_manager, ServerPermission - # Reload configuration from servers.yml to pick up any manual changes - await server_manager.reload_config() + # Reload configuration to pick up any manual changes + get_config_manager().reload() - servers = server_manager.list_servers() - chat_id = query.message.chat_id - # Use per-chat default if set, otherwise global default - default_server = server_manager.get_default_server_for_chat(chat_id) + user_id = query.from_user.id + cm = get_config_manager() + + # Get only accessible servers + servers = cm.list_accessible_servers(user_id) + # User's preferred server (checks both user_data and config.yml) + from config_manager import get_effective_server + default_server = get_effective_server(query.message.chat_id, context.user_data) if not servers: message_text = ( @@ -54,16 +72,23 @@ async def show_api_servers(query, context: ContextTypes.DEFAULT_TYPE) -> None: ) keyboard = [ [InlineKeyboardButton("➕ Add Server", callback_data="api_server_add")], - [InlineKeyboardButton("« Back", callback_data="config_back")] + [InlineKeyboardButton("« Close", callback_data="config_close")] ] else: # Build server list with status server_lines = [] server_buttons = [] + # Check all server statuses in parallel + server_names = list(servers.keys()) + status_tasks = [ + get_config_manager().check_server_status(name) for name in server_names + ] + status_results = await asyncio.gather(*status_tasks) + server_statuses = dict(zip(server_names, status_results)) + for server_name, server_config in servers.items(): - # Check server status - status_result = await server_manager.check_server_status(server_name) + status_result = server_statuses[server_name] # Choose status icon and detail message if status_result["status"] == "online": @@ -78,24 +103,33 @@ async def show_api_servers(query, context: ContextTypes.DEFAULT_TYPE) -> None: error_msg = escape_markdown_v2(status_result.get("message", "Offline")) status_detail = f" \\[{error_msg}\\]" else: - status_icon = "🟡" + status_icon = "🔴" error_msg = escape_markdown_v2(status_result.get("message", "Error")) status_detail = f" \\[{error_msg}\\]" # Default server indicator default_indicator = " ⭐️" if server_name == default_server else "" + # Permission badge + perm = cm.get_server_permission(user_id, server_name) + perm_badges = { + ServerPermission.OWNER: "👑", + ServerPermission.TRADER: "💱", + ServerPermission.VIEWER: "👁", + } + perm_badge = perm_badges.get(perm, "") + " " if perm else "" + url = f"{server_config['host']}:{server_config['port']}" url_escaped = escape_markdown_v2(url) name_escaped = escape_markdown_v2(server_name) server_lines.append( - f"{status_icon} *{name_escaped}*{default_indicator}{status_detail}\n" + f"{status_icon} {perm_badge}*{name_escaped}*{default_indicator}{status_detail}\n" f" `{url_escaped}`" ) # Add button for each server - button_text = f"{server_name}" + button_text = f"{perm_badge}{server_name}" if server_name == default_server: button_text += " ⭐️" server_buttons.append( @@ -118,7 +152,7 @@ async def show_api_servers(query, context: ContextTypes.DEFAULT_TYPE) -> None: [ InlineKeyboardButton("➕ Add Server", callback_data="api_server_add"), InlineKeyboardButton("🔄 Refresh", callback_data="config_api_servers"), - InlineKeyboardButton("« Back", callback_data="config_back") + InlineKeyboardButton("« Close", callback_data="config_close") ] ] @@ -139,7 +173,7 @@ async def show_api_servers(query, context: ContextTypes.DEFAULT_TYPE) -> None: except Exception as e: logger.error(f"Error showing API servers: {e}", exc_info=True) error_text = f"❌ Error loading API servers: {escape_markdown_v2(str(e))}" - keyboard = [[InlineKeyboardButton("« Back", callback_data="config_back")]] + keyboard = [[InlineKeyboardButton("« Close", callback_data="config_close")]] reply_markup = InlineKeyboardMarkup(keyboard) await query.message.edit_text(error_text, parse_mode="MarkdownV2", reply_markup=reply_markup) @@ -167,33 +201,87 @@ async def handle_api_server_action(query, context: ContextTypes.DEFAULT_TYPE) -> await confirm_delete_server(query, context, server_name) elif action_data == "cancel_delete": await show_api_servers(query, context) + # Server sharing actions + elif action_data.startswith("share_user_"): + # Format: share_user_{uid}_{server_name} + parts = action_data.replace("share_user_", "").split("_", 1) + if len(parts) == 2: + target_user_id = int(parts[0]) + server_name = parts[1] + await select_share_user(query, context, server_name, target_user_id) + else: + await query.answer("Invalid share action") + elif action_data.startswith("share_manual_"): + server_name = action_data.replace("share_manual_", "") + await start_manual_share_flow(query, context, server_name) + elif action_data.startswith("share_start_"): + server_name = action_data.replace("share_start_", "") + await start_share_flow(query, context, server_name) + elif action_data.startswith("share_cancel_"): + server_name = action_data.replace("share_cancel_", "") + # Clear sharing state + context.user_data.pop('sharing_server', None) + context.user_data.pop('awaiting_share_user_id', None) + context.user_data.pop('share_target_user_id', None) + context.user_data.pop('share_message_id', None) + context.user_data.pop('share_chat_id', None) + await show_server_sharing(query, context, server_name) + elif action_data.startswith("share_"): + server_name = action_data.replace("share_", "") + await show_server_sharing(query, context, server_name) + elif action_data.startswith("perm_trader_"): + server_name = action_data.replace("perm_trader_", "") + await set_share_permission(query, context, server_name, "trader") + elif action_data.startswith("perm_viewer_"): + server_name = action_data.replace("perm_viewer_", "") + await set_share_permission(query, context, server_name, "viewer") + elif action_data.startswith("revoke_"): + # Format: revoke_{user_id}_{server_name} + parts = action_data.replace("revoke_", "").split("_", 1) + if len(parts) == 2: + target_user_id = int(parts[0]) + server_name = parts[1] + await revoke_access(query, context, server_name, target_user_id) + else: + await query.answer("Invalid revoke action") else: await query.answer("Unknown action") async def show_server_details(query, context: ContextTypes.DEFAULT_TYPE, server_name: str) -> None: - """Show details and actions for a specific server""" + """Show details and actions for a specific server. + Actions are restricted based on user's permission level. + """ try: - from servers import server_manager + from config_manager import get_config_manager, ServerPermission # Clear any modify state when showing server details context.user_data.pop('modifying_server', None) context.user_data.pop('modifying_field', None) context.user_data.pop('awaiting_modify_input', None) - server = server_manager.get_server(server_name) + server = get_config_manager().get_server(server_name) if not server: await query.answer("❌ Server not found") return - chat_id = query.message.chat_id - chat_info = server_manager.get_chat_server_info(chat_id) - default_server = server_manager.get_default_server() - is_global_default = server_name == default_server - is_chat_default = chat_info.get("is_per_chat") and chat_info.get("server") == server_name + user_id = query.from_user.id + cm = get_config_manager() + + # Check user's permission level + perm = cm.get_server_permission(user_id, server_name) + if not perm: + await query.answer("❌ No access to this server") + return + + is_owner = perm == ServerPermission.OWNER + can_trade = perm in (ServerPermission.OWNER, ServerPermission.TRADER) + + from config_manager import get_effective_server + is_user_default = server_name == get_effective_server(query.message.chat_id, context.user_data) # Check status - status_result = await server_manager.check_server_status(server_name) + status_result = await get_config_manager().check_server_status(server_name) status = status_result["status"] message = status_result.get("message", "") @@ -209,40 +297,60 @@ async def show_server_details(query, context: ContextTypes.DEFAULT_TYPE, server_ name_escaped = escape_markdown_v2(server_name) host_escaped = escape_markdown_v2(server['host']) port_escaped = escape_markdown_v2(str(server['port'])) - username_escaped = escape_markdown_v2(server['username']) + + # Permission badge + perm_labels = { + ServerPermission.OWNER: "👑 Owner", + ServerPermission.TRADER: "💱 Trader", + ServerPermission.VIEWER: "👁 Viewer", + } + perm_label = perm_labels.get(perm, "Unknown") message_text = ( f"🔌 *Server: {name_escaped}*\n\n" f"*Status:* {status_text}\n" f"*Host:* `{host_escaped}`\n" f"*Port:* `{port_escaped}`\n" - f"*Username:* `{username_escaped}`\n" + f"*Access:* {escape_markdown_v2(perm_label)}\n" ) - # Show if this is the default for this chat - if is_chat_default: - message_text += "\n⭐️ _Default for this chat_" + # Only show username to owners + if is_owner: + username_escaped = escape_markdown_v2(server['username']) + message_text += f"*Username:* `{username_escaped}`\n" + + # Show if this is the user's default + if is_user_default: + message_text += "\n⭐️ _Your default server_" - message_text += "\n\n_You can modify or delete this server using the buttons below\\._" + # Different help text based on permission + if is_owner: + message_text += "\n\n_You can modify, share, or delete this server\\._" + elif can_trade: + message_text += "\n\n_You can use this server for trading\\._" + else: + message_text += "\n\n_You have view\\-only access to this server\\._" keyboard = [] - # Show Set as Default button only if not already default - if not is_chat_default: + # Show Set as Default button for traders and owners + if can_trade and not is_user_default: keyboard.append([InlineKeyboardButton("⭐️ Set as Default", callback_data=f"api_server_set_default_{server_name}")]) - # Add modification buttons in a row with 4 columns - keyboard.append([ - InlineKeyboardButton("🌐 Host", callback_data=f"modify_field_host_{server_name}"), - InlineKeyboardButton("🔌 Port", callback_data=f"modify_field_port_{server_name}"), - InlineKeyboardButton("👤 User", callback_data=f"modify_field_username_{server_name}"), - InlineKeyboardButton("🔑 Pass", callback_data=f"modify_field_password_{server_name}"), - ]) - - keyboard.extend([ - [InlineKeyboardButton("🗑 Delete", callback_data=f"api_server_delete_{server_name}")], - [InlineKeyboardButton("« Back to Servers", callback_data="config_api_servers")], - ]) + # Only owners can modify server settings + if is_owner: + keyboard.append([ + InlineKeyboardButton("🌐 Host", callback_data=f"modify_field_host_{server_name}"), + InlineKeyboardButton("🔌 Port", callback_data=f"modify_field_port_{server_name}"), + InlineKeyboardButton("👤 User", callback_data=f"modify_field_username_{server_name}"), + InlineKeyboardButton("🔑 Pass", callback_data=f"modify_field_password_{server_name}"), + ]) + keyboard.append([ + InlineKeyboardButton("📤 Share", callback_data=f"api_server_share_{server_name}"), + InlineKeyboardButton("🗑 Delete", callback_data=f"api_server_delete_{server_name}"), + ]) + + keyboard.append([InlineKeyboardButton("« Back to Servers", callback_data="config_api_servers")]) reply_markup = InlineKeyboardMarkup(keyboard) @@ -258,28 +366,25 @@ async def show_server_details(query, context: ContextTypes.DEFAULT_TYPE, server_ async def set_default_server(query, context: ContextTypes.DEFAULT_TYPE, server_name: str) -> None: - """Set server as default for this chat""" + """Set server as default for this user/chat""" try: - from servers import server_manager from handlers.dex._shared import invalidate_cache + from handlers.config.user_preferences import set_active_server + from config_manager import get_config_manager - chat_id = query.message.chat_id - success = server_manager.set_default_server_for_chat(chat_id, server_name) + # Save to user_data (in-memory, pickle persistence) + set_active_server(context.user_data, server_name) - if success: - # Invalidate ALL cached data since we're switching to a different server - # This ensures /lp, /swap, etc. will fetch fresh data from the new server - invalidate_cache(context.user_data, "all") - - # Store current server in user_data as fallback for background tasks - context.user_data["_current_server"] = server_name + # Also save to config.yml for immediate persistence (survives hard kills) + chat_id = query.message.chat_id + get_config_manager().set_chat_default_server(chat_id, server_name) - logger.info(f"Cache invalidated after switching to server '{server_name}'") + # Invalidate ALL cached data since we're switching to a different server + invalidate_cache(context.user_data, "all") + context.user_data["_current_server"] = server_name - await query.answer(f"✅ Set {server_name} as default for this chat") - await show_server_details(query, context, server_name) - else: - await query.answer("❌ Failed to set default server") + await query.answer(f"✅ Set {server_name} as your default server") + await show_server_details(query, context, server_name) except Exception as e: logger.error(f"Error setting default server: {e}", exc_info=True) @@ -310,17 +415,29 @@ async def confirm_delete_server(query, context: ContextTypes.DEFAULT_TYPE, serve async def delete_server(query, context: ContextTypes.DEFAULT_TYPE, server_name: str) -> None: - """Delete a server from configuration""" + """Delete a server from configuration. + Only owners can delete servers. + """ try: - from servers import server_manager + from config_manager import get_config_manager from handlers.dex._shared import invalidate_cache + from config_manager import get_config_manager, ServerPermission - # Check if this is the current chat's default server - chat_id = query.message.chat_id - current_default = server_manager.get_default_server_for_chat(chat_id) - was_current = (current_default == server_name) + user_id = query.from_user.id + cm = get_config_manager() + + # Check if user has owner permission + perm = cm.get_server_permission(user_id, server_name) + if perm != ServerPermission.OWNER: + await query.answer("❌ Only the owner can delete this server", show_alert=True) + return - success = server_manager.delete_server(server_name) + # Check if this is the user's current default server + from handlers.config.user_preferences import get_active_server + was_current = (get_active_server(context.user_data) == server_name) + + # Delete server and clean up permissions + success = get_config_manager().delete_server(server_name, actor_id=user_id) if success: # Invalidate cache if we deleted the server that was in use @@ -500,8 +617,8 @@ async def handle_add_server_input(update: Update, context: ContextTypes.DEFAULT_ awaiting_field not in server_data if awaiting_field == 'name': - from servers import server_manager - if new_value in server_manager.list_servers() and new_value != server_data.get('name'): + from config_manager import get_config_manager + if new_value in get_config_manager().list_servers() and new_value != server_data.get('name'): message_id = context.user_data.get('add_server_message_id') chat_id = context.user_data.get('add_server_chat_id') if message_id and chat_id: @@ -702,9 +819,10 @@ async def handle_add_server_callbacks(query, context: ContextTypes.DEFAULT_TYPE) async def confirm_add_server(query, context: ContextTypes.DEFAULT_TYPE) -> None: """Actually add the server to configuration""" try: - from servers import server_manager + from config_manager import get_config_manager server_data = context.user_data.get('adding_server', {}) + user_id = query.from_user.id required_fields = ['name', 'host', 'port', 'username', 'password'] for field in required_fields: @@ -712,12 +830,14 @@ async def confirm_add_server(query, context: ContextTypes.DEFAULT_TYPE) -> None: await query.answer(f"❌ Missing field: {field}") return - success = server_manager.add_server( + # Add server with ownership registration + success = get_config_manager().add_server( name=server_data['name'], host=server_data['host'], port=server_data['port'], username=server_data['username'], - password=server_data['password'] + password=server_data['password'], + owner_id=user_id ) if success: @@ -737,9 +857,9 @@ async def confirm_add_server(query, context: ContextTypes.DEFAULT_TYPE) -> None: async def start_modify_server(query, context: ContextTypes.DEFAULT_TYPE, server_name: str) -> int: """Start the modify server conversation""" try: - from servers import server_manager + from config_manager import get_config_manager - server = server_manager.get_server(server_name) + server = get_config_manager().get_server(server_name) if not server: await query.answer("❌ Server not found") return ConversationHandler.END @@ -792,8 +912,8 @@ async def handle_modify_field_selection(query, context: ContextTypes.DEFAULT_TYP context.user_data['modify_message_id'] = query.message.message_id context.user_data['modify_chat_id'] = query.message.chat_id - from servers import server_manager - server = server_manager.get_server(server_name) + from config_manager import get_config_manager + server = get_config_manager().get_server(server_name) if not server: await query.answer("❌ Server not found") return @@ -856,7 +976,7 @@ async def handle_modify_value_input(update: Update, context: ContextTypes.DEFAUL pass try: - from servers import server_manager + from config_manager import get_config_manager if not server_name or not field: await context.bot.send_message( @@ -885,7 +1005,7 @@ async def handle_modify_value_input(update: Update, context: ContextTypes.DEFAUL return kwargs = {field: new_value} - success = server_manager.modify_server(server_name, **kwargs) + success = get_config_manager().modify_server(server_name, **kwargs) # Clear modification state context.user_data.pop('modifying_server', None) @@ -895,9 +1015,9 @@ async def handle_modify_value_input(update: Update, context: ContextTypes.DEFAUL if success: logger.info(f"Successfully modified {field} for server {server_name}") - # Invalidate cache if this is the current chat's default server - current_default = server_manager.get_default_server_for_chat(chat_id) - if current_default == server_name: + # Invalidate cache if this is the user's current default server + from handlers.config.user_preferences import get_active_server + if get_active_server(context.user_data) == server_name: from handlers.dex._shared import invalidate_cache invalidate_cache(context.user_data, "all") logger.info(f"Cache invalidated after modifying current server '{server_name}'") @@ -987,3 +1107,417 @@ async def handle_server_input(update: Update, context: ContextTypes.DEFAULT_TYPE await handle_add_server_input(update, context) elif context.user_data.get('awaiting_modify_input'): await handle_modify_value_input(update, context) + elif context.user_data.get('awaiting_share_user_id'): + await handle_share_user_id_input(update, context) + + +# ==================== Server Sharing ==================== + +async def show_server_sharing(query, context: ContextTypes.DEFAULT_TYPE, server_name: str) -> None: + """Show sharing details and management for a server.""" + from config_manager import get_config_manager, ServerPermission + + user_id = query.from_user.id + cm = get_config_manager() + + # Check ownership + perm = cm.get_server_permission(user_id, server_name) + if perm != ServerPermission.OWNER: + await query.answer("Only the owner can manage sharing", show_alert=True) + return + + shared_users = cm.get_server_shared_users(server_name) + name_escaped = escape_markdown_v2(server_name) + + message = f"📤 *Share Server: {name_escaped}*\n\n" + + keyboard = [] + + if shared_users: + message += "*Shared with:*\n" + + perm_badges = { + ServerPermission.TRADER: "💱", + ServerPermission.VIEWER: "👁", + } + + for target_user_id, perm in shared_users: + target_user = cm.get_user(target_user_id) + username = target_user.get('username') if target_user else None + + badge = perm_badges.get(perm, "?") + if username: + message += f" {badge} `{target_user_id}` \\(@{escape_markdown_v2(username)}\\)\n" + else: + message += f" {badge} `{target_user_id}`\n" + + keyboard.append([ + InlineKeyboardButton( + f"🗑 Revoke {target_user_id}", + callback_data=f"api_server_revoke_{target_user_id}_{server_name}" + ) + ]) + + message += "\n" + else: + message += "_Not shared with anyone yet\\._\n\n" + + message += "_Enter a User ID below to share this server\\._" + + keyboard.append([ + InlineKeyboardButton("➕ Share with User ID", callback_data=f"api_server_share_start_{server_name}") + ]) + keyboard.append([InlineKeyboardButton("« Back", callback_data=f"api_server_view_{server_name}")]) + + await query.message.edit_text( + message, + parse_mode="MarkdownV2", + reply_markup=InlineKeyboardMarkup(keyboard) + ) + + +async def start_share_flow(query, context: ContextTypes.DEFAULT_TYPE, server_name: str) -> None: + """Start the share flow. + - Admin sees a list of approved users to pick from + - Regular users enter the user ID manually + """ + from config_manager import get_config_manager, ServerPermission, UserRole + + user_id = query.from_user.id + cm = get_config_manager() + + perm = cm.get_server_permission(user_id, server_name) + if perm != ServerPermission.OWNER: + await query.answer("Only the owner can share", show_alert=True) + return + + context.user_data['sharing_server'] = server_name + context.user_data['share_message_id'] = query.message.message_id + context.user_data['share_chat_id'] = query.message.chat_id + + name_escaped = escape_markdown_v2(server_name) + owner_id = cm.get_server_owner(server_name) + + # Get already shared users to exclude them + shared_users = cm.get_server_shared_users(server_name) + shared_user_ids = {uid for uid, _ in shared_users} + + # For admin: show list of approved users + if cm.is_admin(user_id): + approved_users = [ + u for u in cm.get_all_users() + if u.get('role') in (UserRole.USER.value, UserRole.ADMIN.value) + and u['user_id'] != owner_id + and u['user_id'] not in shared_user_ids + ] + + if not approved_users: + message = ( + f"📤 *Share Server: {name_escaped}*\n\n" + "_No approved users available to share with\\._\n\n" + "All approved users either already have access or are the owner\\." + ) + keyboard = [[InlineKeyboardButton("« Back", callback_data=f"api_server_share_{server_name}")]] + else: + message = ( + f"📤 *Share Server: {name_escaped}*\n\n" + "Select a user to share with:" + ) + keyboard = [] + for u in approved_users[:10]: # Limit to 10 users + uid = u['user_id'] + username = u.get('username') or 'N/A' + btn_text = f"@{username}" if username != 'N/A' else str(uid) + keyboard.append([ + InlineKeyboardButton(btn_text, callback_data=f"api_server_share_user_{uid}_{server_name}") + ]) + + if len(approved_users) > 10: + message += f"\n\n_Showing first 10 of {len(approved_users)} users_" + + keyboard.append([InlineKeyboardButton("✏️ Enter ID manually", callback_data=f"api_server_share_manual_{server_name}")]) + keyboard.append([InlineKeyboardButton("❌ Cancel", callback_data=f"api_server_share_cancel_{server_name}")]) + + context.user_data['awaiting_share_user_id'] = False + else: + # Regular users: manual entry + message = ( + f"📤 *Share Server: {name_escaped}*\n\n" + "Enter the *User ID* of the user you want to share with:\n\n" + "_The user must be approved to receive access\\._" + ) + keyboard = [[InlineKeyboardButton("❌ Cancel", callback_data=f"api_server_share_cancel_{server_name}")]] + context.user_data['awaiting_share_user_id'] = True + + await query.message.edit_text( + message, + parse_mode="MarkdownV2", + reply_markup=InlineKeyboardMarkup(keyboard) + ) + + +async def select_share_user(query, context: ContextTypes.DEFAULT_TYPE, server_name: str, target_user_id: int) -> None: + """Handle user selection from the list (admin flow).""" + from config_manager import get_config_manager, ServerPermission + + cm = get_config_manager() + user_id = query.from_user.id + + # Verify ownership + perm = cm.get_server_permission(user_id, server_name) + if perm != ServerPermission.OWNER: + await query.answer("Only the owner can share", show_alert=True) + return + + # Store target and ask for permission level + context.user_data['sharing_server'] = server_name + context.user_data['share_target_user_id'] = target_user_id + + target_user = cm.get_user(target_user_id) + username = target_user.get('username') if target_user else None + name_escaped = escape_markdown_v2(server_name) + + if username: + user_display = f"`{target_user_id}` \\(@{escape_markdown_v2(username)}\\)" + else: + user_display = f"`{target_user_id}`" + + message = ( + f"📤 *Share Server: {name_escaped}*\n\n" + f"Sharing with: {user_display}\n\n" + "Select the permission level:" + ) + + keyboard = [ + [InlineKeyboardButton("💱 Trader (can trade)", callback_data=f"api_server_perm_trader_{server_name}")], + [InlineKeyboardButton("👁 Viewer (read-only)", callback_data=f"api_server_perm_viewer_{server_name}")], + [InlineKeyboardButton("❌ Cancel", callback_data=f"api_server_share_cancel_{server_name}")], + ] + + await query.message.edit_text( + message, + parse_mode="MarkdownV2", + reply_markup=InlineKeyboardMarkup(keyboard) + ) + + +async def start_manual_share_flow(query, context: ContextTypes.DEFAULT_TYPE, server_name: str) -> None: + """Start manual entry flow for sharing (used when admin clicks 'Enter ID manually').""" + from config_manager import get_config_manager, ServerPermission + + user_id = query.from_user.id + cm = get_config_manager() + + perm = cm.get_server_permission(user_id, server_name) + if perm != ServerPermission.OWNER: + await query.answer("Only the owner can share", show_alert=True) + return + + context.user_data['sharing_server'] = server_name + context.user_data['awaiting_share_user_id'] = True + context.user_data['share_message_id'] = query.message.message_id + context.user_data['share_chat_id'] = query.message.chat_id + + name_escaped = escape_markdown_v2(server_name) + message = ( + f"📤 *Share Server: {name_escaped}*\n\n" + "Enter the *User ID* of the user you want to share with:\n\n" + "_The user must be approved to receive access\\._" + ) + + keyboard = [[InlineKeyboardButton("❌ Cancel", callback_data=f"api_server_share_cancel_{server_name}")]] + + await query.message.edit_text( + message, + parse_mode="MarkdownV2", + reply_markup=InlineKeyboardMarkup(keyboard) + ) + + +async def handle_share_user_id_input(update: Update, context: ContextTypes.DEFAULT_TYPE) -> None: + """Handle user ID input for sharing.""" + from config_manager import get_config_manager + + if not context.user_data.get('awaiting_share_user_id'): + return + + server_name = context.user_data.get('sharing_server') + if not server_name: + return + + try: + await update.message.delete() + except: + pass + + cm = get_config_manager() + user_id = update.effective_user.id + chat_id = update.effective_chat.id + message_id = context.user_data.get('share_message_id') + + # Parse target user ID + try: + target_user_id = int(update.message.text.strip()) + except ValueError: + if message_id: + await context.bot.edit_message_text( + chat_id=chat_id, + message_id=message_id, + text="❌ Invalid User ID\\. Please enter a valid number\\.\n\nEnter the User ID:", + parse_mode="MarkdownV2", + reply_markup=InlineKeyboardMarkup([[ + InlineKeyboardButton("❌ Cancel", callback_data=f"api_server_share_cancel_{server_name}") + ]]) + ) + return + + # Check if target user is approved + if not cm.is_approved(target_user_id): + if message_id: + await context.bot.edit_message_text( + chat_id=chat_id, + message_id=message_id, + text=f"❌ User `{target_user_id}` is not an approved user\\.\n\n" + "Only approved users can receive server access\\.\n\n" + "Enter a different User ID:", + parse_mode="MarkdownV2", + reply_markup=InlineKeyboardMarkup([[ + InlineKeyboardButton("❌ Cancel", callback_data=f"api_server_share_cancel_{server_name}") + ]]) + ) + return + + # Check if trying to share with self + owner_id = cm.get_server_owner(server_name) + if target_user_id == owner_id: + if message_id: + await context.bot.edit_message_text( + chat_id=chat_id, + message_id=message_id, + text="❌ You can't share with the owner\\.\n\nEnter a different User ID:", + parse_mode="MarkdownV2", + reply_markup=InlineKeyboardMarkup([[ + InlineKeyboardButton("❌ Cancel", callback_data=f"api_server_share_cancel_{server_name}") + ]]) + ) + return + + # Store target and ask for permission level + context.user_data['share_target_user_id'] = target_user_id + context.user_data['awaiting_share_user_id'] = False + + target_user = cm.get_user(target_user_id) + username = target_user.get('username') if target_user else None + name_escaped = escape_markdown_v2(server_name) + + if username: + user_display = f"`{target_user_id}` \\(@{escape_markdown_v2(username)}\\)" + else: + user_display = f"`{target_user_id}`" + + message = ( + f"📤 *Share Server: {name_escaped}*\n\n" + f"Sharing with: {user_display}\n\n" + "Select the permission level:" + ) + + keyboard = [ + [InlineKeyboardButton("💱 Trader (can trade)", callback_data=f"api_server_perm_trader_{server_name}")], + [InlineKeyboardButton("👁 Viewer (read-only)", callback_data=f"api_server_perm_viewer_{server_name}")], + [InlineKeyboardButton("❌ Cancel", callback_data=f"api_server_share_cancel_{server_name}")], + ] + + if message_id: + await context.bot.edit_message_text( + chat_id=chat_id, + message_id=message_id, + text=message, + parse_mode="MarkdownV2", + reply_markup=InlineKeyboardMarkup(keyboard) + ) + + +async def set_share_permission(query, context: ContextTypes.DEFAULT_TYPE, server_name: str, permission: str) -> None: + """Set the permission level and complete sharing.""" + from config_manager import get_config_manager, ServerPermission + + user_id = query.from_user.id + target_user_id = context.user_data.get('share_target_user_id') + + if not target_user_id: + await query.answer("Session expired. Please try again.", show_alert=True) + await show_api_servers(query, context) + return + + cm = get_config_manager() + + perm_map = { + 'trader': ServerPermission.TRADER, + 'viewer': ServerPermission.VIEWER, + } + perm = perm_map.get(permission) + + if not perm: + await query.answer("Invalid permission", show_alert=True) + return + + success = cm.share_server(server_name, user_id, target_user_id, perm) + + # Clean up state + context.user_data.pop('sharing_server', None) + context.user_data.pop('share_target_user_id', None) + context.user_data.pop('share_message_id', None) + context.user_data.pop('share_chat_id', None) + + if success: + # Notify target user + try: + perm_label = "Trader" if perm == ServerPermission.TRADER else "Viewer" + await context.bot.send_message( + chat_id=target_user_id, + text=( + f"📥 *Server Shared With You*\n\n" + f"You now have *{escape_markdown_v2(perm_label)}* access to server:\n" + f"`{escape_markdown_v2(server_name)}`\n\n" + f"Use /config \\> API Servers to access it\\." + ), + parse_mode="MarkdownV2" + ) + except Exception as e: + logger.warning(f"Failed to notify user {target_user_id} of share: {e}") + + await query.answer(f"Shared with {target_user_id}", show_alert=True) + await show_server_sharing(query, context, server_name) + else: + await query.answer("Failed to share server", show_alert=True) + await show_server_sharing(query, context, server_name) + + +async def revoke_access(query, context: ContextTypes.DEFAULT_TYPE, server_name: str, target_user_id: int) -> None: + """Revoke a user's access to a server.""" + from config_manager import get_config_manager + + user_id = query.from_user.id + cm = get_config_manager() + + success = cm.revoke_server_access(server_name, user_id, target_user_id) + + if success: + # Notify target user + try: + await context.bot.send_message( + chat_id=target_user_id, + text=( + f"🚫 *Access Revoked*\n\n" + f"Your access to server `{escape_markdown_v2(server_name)}` has been revoked\\." + ), + parse_mode="MarkdownV2" + ) + except Exception as e: + logger.warning(f"Failed to notify user {target_user_id} of revocation: {e}") + + await query.answer(f"Revoked access for {target_user_id}", show_alert=True) + else: + await query.answer("Failed to revoke access", show_alert=True) + + await show_server_sharing(query, context, server_name) diff --git a/handlers/config/user_preferences.py b/handlers/config/user_preferences.py index caccab6..8cecec9 100644 --- a/handlers/config/user_preferences.py +++ b/handlers/config/user_preferences.py @@ -50,6 +50,10 @@ DEFAULT_DEX_SIDE = "BUY" DEFAULT_DEX_AMOUNT = "1.0" +# Unified trade defaults +DEFAULT_TRADE_CONNECTOR_TYPE = "dex" # "cex" or "dex" +DEFAULT_TRADE_CONNECTOR_NAME = "solana-mainnet-beta" # For DEX: network ID, for CEX: connector name + # ============================================ # TYPE DEFINITIONS @@ -118,12 +122,23 @@ class GatewayPrefs(TypedDict, total=False): wallet_networks: Dict[str, list] # wallet_address -> list of enabled network IDs +class UnifiedTradePrefs(TypedDict, total=False): + """Unified trade preferences for /trade command. + + Tracks which connector type (CEX/DEX) and which specific connector + was last used, so the unified /trade command can show the right UI. + """ + last_connector_type: str # "cex" or "dex" + last_connector_name: str # e.g., "jupiter", "binance_perpetual" + + class UserPreferences(TypedDict, total=False): portfolio: PortfolioPrefs clob: CLOBPrefs dex: DEXPrefs general: GeneralPrefs gateway: GatewayPrefs + unified_trade: UnifiedTradePrefs # ============================================ @@ -156,6 +171,10 @@ def _get_default_preferences() -> UserPreferences: "gateway": { "wallet_networks": {}, # wallet_address -> list of enabled network IDs }, + "unified_trade": { + "last_connector_type": DEFAULT_TRADE_CONNECTOR_TYPE, + "last_connector_name": DEFAULT_TRADE_CONNECTOR_NAME, + }, } @@ -621,6 +640,56 @@ def get_all_enabled_networks(user_data: Dict) -> set: return all_networks if all_networks else None +# ============================================ +# PUBLIC API - UNIFIED TRADE +# ============================================ + +def get_unified_trade_prefs(user_data: Dict) -> UnifiedTradePrefs: + """Get unified trade preferences + + Returns: + Unified trade preferences with last_connector_type and last_connector_name + """ + _migrate_legacy_data(user_data) + prefs = _ensure_preferences(user_data) + return deepcopy(prefs.get("unified_trade", { + "last_connector_type": DEFAULT_TRADE_CONNECTOR_TYPE, + "last_connector_name": DEFAULT_TRADE_CONNECTOR_NAME, + })) + + +def get_last_trade_connector(user_data: Dict) -> tuple: + """Get last used trade connector type and name + + Returns: + Tuple of (connector_type, connector_name) + - For DEX: ("dex", "solana-mainnet-beta") - connector_name is the NETWORK ID + - For CEX: ("cex", "binance_perpetual") - connector_name is the connector + """ + prefs = get_unified_trade_prefs(user_data) + return ( + prefs.get("last_connector_type", DEFAULT_TRADE_CONNECTOR_TYPE), + prefs.get("last_connector_name", DEFAULT_TRADE_CONNECTOR_NAME), + ) + + +def set_last_trade_connector(user_data: Dict, connector_type: str, connector_name: str) -> None: + """Set last used trade connector + + Args: + user_data: User data dict + connector_type: "cex" or "dex" + connector_name: For DEX: network ID (e.g., "solana-mainnet-beta") + For CEX: connector name (e.g., "binance_perpetual") + """ + prefs = _ensure_preferences(user_data) + if "unified_trade" not in prefs: + prefs["unified_trade"] = {} + prefs["unified_trade"]["last_connector_type"] = connector_type + prefs["unified_trade"]["last_connector_name"] = connector_name + logger.info(f"Set last trade connector: {connector_type}:{connector_name}") + + # ============================================ # UTILITY FUNCTIONS # ============================================ diff --git a/handlers/dex/__init__.py b/handlers/dex/__init__.py index bc1a3cc..20cf7bd 100644 --- a/handlers/dex/__init__.py +++ b/handlers/dex/__init__.py @@ -8,7 +8,7 @@ - Quick trading with saved parameters - GeckoTerminal pool exploration with OHLCV charts -Structure: +Module Structure: - menu.py: Main DEX menu and help - swap.py: Unified swap (quote, execute, history with filters/pagination) - liquidity.py: Unified liquidity pools (balances, positions, history with filters/pagination) @@ -16,140 +16,31 @@ - pool_data.py: Pool data fetching utilities (OHLCV, liquidity bins) - geckoterminal.py: GeckoTerminal pool explorer with charts - visualizations.py: Chart generation (liquidity distribution, OHLCV candlesticks) +- lp_monitor_handlers.py: LP monitor alert handling (navigation, rebalance, fees) +- router.py: Callback and message routing - _shared.py: Shared utilities (caching, formatters, history filters) """ import logging + from telegram import Update -from telegram.ext import ContextTypes, CallbackQueryHandler, MessageHandler, filters +from telegram.ext import ContextTypes from utils.auth import restricted from handlers import clear_all_input_states -# Import submodule handlers -from .menu import show_dex_menu, handle_close, handle_refresh, cancel_dex_loading_task -# Unified swap module -from .swap import ( - handle_swap, - handle_swap_refresh, - show_swap_menu, - handle_swap_toggle_side, - handle_swap_set_connector, - handle_swap_connector_select, - handle_swap_set_network, - handle_swap_network_select, - handle_swap_set_pair, - handle_swap_set_amount, - handle_swap_set_slippage, - handle_swap_get_quote, - handle_swap_execute_confirm, - handle_swap_history, - handle_swap_status, - handle_swap_hist_filter_pair, - handle_swap_hist_filter_connector, - handle_swap_hist_filter_status, - handle_swap_hist_set_filter, - handle_swap_hist_page, - handle_swap_hist_clear, - process_swap, - process_swap_set_pair, - process_swap_set_amount, - process_swap_set_slippage, - process_swap_status, -) -from .pools import ( - handle_pool_info, - handle_pool_list, - handle_pool_select, - handle_pool_list_back, - handle_pool_detail_refresh, - handle_add_to_gateway, - handle_plot_liquidity, - handle_pool_ohlcv, - handle_pool_combined_chart, - handle_manage_positions, - handle_pos_view, - handle_pos_view_pool, - handle_pos_collect_fees, - handle_pos_close_confirm, - handle_pos_close_execute, - handle_position_list, - handle_add_position, - show_add_position_menu, - handle_pos_set_connector, - handle_pos_set_network, - handle_pos_set_pool, - handle_pos_set_lower, - handle_pos_set_upper, - handle_pos_set_base, - handle_pos_set_quote, - handle_pos_add_confirm, - handle_pos_use_max_range, - handle_pos_help, - handle_pos_toggle_strategy, - handle_pos_refresh, - process_pool_info, - process_pool_list, - process_position_list, - process_add_position, - process_pos_set_connector, - process_pos_set_network, - process_pos_set_pool, - process_pos_set_lower, - process_pos_set_upper, - process_pos_set_base, - process_pos_set_quote, -) -from .geckoterminal import ( - show_gecko_explore_menu, - handle_gecko_toggle_view, - handle_gecko_select_network, - handle_gecko_set_network, - handle_gecko_show_pools, - handle_gecko_refresh, - handle_gecko_trending, - show_trending_pools, - handle_gecko_top, - show_top_pools, - handle_gecko_new, - show_new_pools, - handle_gecko_networks, - show_network_menu, - handle_gecko_search, - handle_gecko_search_network, - handle_gecko_search_set_network, - process_gecko_search, - show_pool_detail, - show_gecko_charts_menu, - show_ohlcv_chart, - show_recent_trades, - show_gecko_liquidity, - show_gecko_combined, - handle_copy_address, - handle_gecko_token_info, - handle_gecko_token_search, - handle_gecko_token_add, - handle_back_to_list, - handle_gecko_add_liquidity, - handle_gecko_swap, - show_gecko_info, -) -# Unified liquidity module -from .liquidity import ( - handle_liquidity, - show_liquidity_menu, - handle_lp_refresh, - handle_lp_pos_view, - handle_lp_collect_all, - handle_lp_history, - handle_lp_hist_filter_pair, - handle_lp_hist_filter_connector, - handle_lp_hist_filter_status, - handle_lp_hist_set_filter, - handle_lp_hist_page, - handle_lp_hist_clear, +# Import router components +from .router import ( + dex_callback_handler, + dex_message_handler, + get_dex_callback_handler, + get_dex_message_handler, ) +# Import command handlers +from .swap import handle_swap +from .liquidity import handle_liquidity + logger = logging.getLogger(__name__) @@ -200,476 +91,17 @@ async def lp_command(update: Update, context: ContextTypes.DEFAULT_TYPE) -> None # ============================================ -# CALLBACK HANDLER -# ============================================ - -@restricted -async def dex_callback_handler(update: Update, context: ContextTypes.DEFAULT_TYPE) -> None: - """Handle inline button callbacks - Routes to appropriate sub-module""" - query = update.callback_query - await query.answer() - - try: - callback_parts = query.data.split(":", 1) - action = callback_parts[1] if len(callback_parts) > 1 else query.data - - # Cancel any pending menu loading task when navigating to a different action - # (show_dex_menu will cancel it internally anyway, so skip for main_menu) - if action != "main_menu": - cancel_dex_loading_task(context) - - # Only show typing for slow operations that need network calls - slow_actions = {"main_menu", "swap", "swap_refresh", "swap_get_quote", "swap_execute_confirm", "swap_history", - "swap_hist_clear", "swap_hist_filter_pair", "swap_hist_filter_connector", "swap_hist_filter_status", - "swap_hist_page_prev", "swap_hist_page_next", - "liquidity", "lp_refresh", "lp_history", "lp_collect_all", - "lp_hist_clear", "lp_hist_filter_pair", "lp_hist_filter_connector", "lp_hist_filter_status", - "lp_hist_page_prev", "lp_hist_page_next", - "pool_info", "pool_list", "manage_positions", "pos_add_confirm", "pos_close_exec", - "add_to_gateway", "pool_detail_refresh", - "gecko_networks", "gecko_trades", "gecko_show_pools", "gecko_refresh", "gecko_token_search", "gecko_token_add", - "gecko_explore", "gecko_swap", "gecko_info"} - # Also show typing for actions that start with these prefixes - slow_prefixes = ("gecko_trending_", "gecko_top_", "gecko_new_", "gecko_pool:", "gecko_ohlcv:", - "gecko_token:", "swap_hist_set_", "lp_hist_set_") - if action in slow_actions or action.startswith(slow_prefixes): - await query.message.reply_chat_action("typing") - - # Menu (legacy - redirect to swap) - if action == "main_menu": - await handle_swap(update, context) - - # Unified swap handlers - elif action == "swap": - await handle_swap(update, context) - elif action == "swap_refresh": - await handle_swap_refresh(update, context) - elif action == "swap_toggle_side": - await handle_swap_toggle_side(update, context) - elif action == "swap_set_connector": - await handle_swap_set_connector(update, context) - elif action.startswith("swap_connector_"): - connector_name = action.replace("swap_connector_", "") - await handle_swap_connector_select(update, context, connector_name) - elif action == "swap_set_network": - await handle_swap_set_network(update, context) - elif action.startswith("swap_network_"): - network_id = action.replace("swap_network_", "") - await handle_swap_network_select(update, context, network_id) - elif action == "swap_set_pair": - await handle_swap_set_pair(update, context) - elif action == "swap_set_amount": - await handle_swap_set_amount(update, context) - elif action == "swap_set_slippage": - await handle_swap_set_slippage(update, context) - elif action == "swap_get_quote": - await handle_swap_get_quote(update, context) - elif action == "swap_execute_confirm": - await handle_swap_execute_confirm(update, context) - elif action == "swap_history": - await handle_swap_history(update, context) - - # Swap history filter handlers - elif action == "swap_hist_filter_pair": - await handle_swap_hist_filter_pair(update, context) - elif action == "swap_hist_filter_connector": - await handle_swap_hist_filter_connector(update, context) - elif action == "swap_hist_filter_status": - await handle_swap_hist_filter_status(update, context) - elif action.startswith("swap_hist_set_pair_"): - value = action.replace("swap_hist_set_pair_", "") - await handle_swap_hist_set_filter(update, context, "pair", value) - elif action.startswith("swap_hist_set_connector_"): - value = action.replace("swap_hist_set_connector_", "") - await handle_swap_hist_set_filter(update, context, "connector", value) - elif action.startswith("swap_hist_set_status_"): - value = action.replace("swap_hist_set_status_", "") - await handle_swap_hist_set_filter(update, context, "status", value) - elif action == "swap_hist_page_prev": - await handle_swap_hist_page(update, context, "prev") - elif action == "swap_hist_page_next": - await handle_swap_hist_page(update, context, "next") - elif action == "swap_hist_clear": - await handle_swap_hist_clear(update, context) - - # Legacy swap handlers (redirect to unified) - elif action == "swap_quote": - await handle_swap(update, context) - elif action == "swap_execute": - await handle_swap(update, context) - elif action == "swap_search": - await handle_swap_history(update, context) - - # Status handler (still separate) - elif action == "swap_status": - await handle_swap_status(update, context) - - # Unified liquidity handlers - elif action == "liquidity": - await handle_liquidity(update, context) - elif action == "lp_refresh": - await handle_lp_refresh(update, context) - elif action.startswith("lp_pos_view:"): - pos_index = int(action.split(":")[1]) - await handle_lp_pos_view(update, context, pos_index) - elif action == "lp_collect_all": - await handle_lp_collect_all(update, context) - elif action == "lp_history": - await handle_lp_history(update, context) - - # LP history filter handlers - elif action == "lp_hist_filter_pair": - await handle_lp_hist_filter_pair(update, context) - elif action == "lp_hist_filter_connector": - await handle_lp_hist_filter_connector(update, context) - elif action == "lp_hist_filter_status": - await handle_lp_hist_filter_status(update, context) - elif action.startswith("lp_hist_set_pair_"): - value = action.replace("lp_hist_set_pair_", "") - await handle_lp_hist_set_filter(update, context, "pair", value) - elif action.startswith("lp_hist_set_connector_"): - value = action.replace("lp_hist_set_connector_", "") - await handle_lp_hist_set_filter(update, context, "connector", value) - elif action.startswith("lp_hist_set_status_"): - value = action.replace("lp_hist_set_status_", "") - await handle_lp_hist_set_filter(update, context, "status", value) - elif action == "lp_hist_page_prev": - await handle_lp_hist_page(update, context, "prev") - elif action == "lp_hist_page_next": - await handle_lp_hist_page(update, context, "next") - elif action == "lp_hist_clear": - await handle_lp_hist_clear(update, context) - - # No-op handler for page indicator buttons - elif action == "noop": - pass # Do nothing, just acknowledge the callback - - # Legacy - redirect to main LP menu - elif action == "explore_pools": - await handle_lp_refresh(update, context) - - # Pool handlers - elif action == "pool_info": - await handle_pool_info(update, context) - elif action == "pool_list": - await handle_pool_list(update, context) - elif action.startswith("pool_select:"): - pool_index = int(action.split(":")[1]) - await handle_pool_select(update, context, pool_index) - elif action == "pool_list_back": - await handle_pool_list_back(update, context) - elif action == "pool_detail_refresh": - await handle_pool_detail_refresh(update, context) - elif action.startswith("pool_tf:"): - # Format: pool_tf:timeframe - timeframe = action.split(":")[1] - await handle_pool_detail_refresh(update, context, timeframe=timeframe) - elif action == "add_to_gateway": - await handle_add_to_gateway(update, context) - elif action.startswith("plot_liquidity:"): - percentile = int(action.split(":")[1]) - await handle_plot_liquidity(update, context, percentile) - - # Manage positions (unified view) - elif action == "manage_positions": - await handle_manage_positions(update, context) - elif action.startswith("pos_view:"): - pos_index = action.split(":")[1] - await handle_pos_view(update, context, pos_index) - elif action.startswith("pos_view_tf:"): - # Format: pos_view_tf:pos_index:timeframe - parts = action.split(":") - pos_index = parts[1] - timeframe = parts[2] if len(parts) > 2 else "1h" - await handle_pos_view(update, context, pos_index, timeframe=timeframe) - elif action.startswith("pos_view_pool:"): - pos_index = action.split(":")[1] - await handle_pos_view_pool(update, context, pos_index) - elif action.startswith("pos_collect:"): - pos_index = action.split(":")[1] - await handle_pos_collect_fees(update, context, pos_index) - elif action.startswith("pos_close:"): - pos_index = action.split(":")[1] - await handle_pos_close_confirm(update, context, pos_index) - elif action.startswith("pos_close_exec:"): - pos_index = action.split(":")[1] - await handle_pos_close_execute(update, context, pos_index) - elif action == "position_list": - await handle_position_list(update, context) - - # Add position handlers - elif action == "add_position": - await handle_add_position(update, context) - elif action == "add_position_from_pool": - # Show loading indicator - await query.answer("Loading position form...") - # Pre-fill add position with selected pool - selected_pool = context.user_data.get("selected_pool", {}) - if selected_pool: - pool_address = selected_pool.get('pool_address', selected_pool.get('address', '')) - context.user_data["add_position_params"] = { - "connector": selected_pool.get('connector', 'meteora'), - "network": "solana-mainnet-beta", - "pool_address": pool_address, - "lower_price": "", - "upper_price": "", - "amount_base": "10%", # Default to 10% of balance - "amount_quote": "10%", # Default to 10% of balance - "strategy_type": "0", # Default strategy type (Spot) - } - await show_add_position_menu(update, context) - elif action.startswith("copy_pool:"): - # Show pool address for copying - send as message so user can easily copy - selected_pool = context.user_data.get("selected_pool", {}) - pool_address = selected_pool.get('pool_address', selected_pool.get('address', 'N/A')) - # Send as a code block message for easy copying (Telegram allows tap-to-copy on code blocks) - await query.answer("Address sent below ⬇️") - await query.message.reply_text( - f"`{pool_address}`", - parse_mode="Markdown" - ) - elif action == "pos_set_connector": - await handle_pos_set_connector(update, context) - elif action == "pos_set_network": - await handle_pos_set_network(update, context) - elif action == "pos_set_pool": - await handle_pos_set_pool(update, context) - elif action == "pos_set_lower": - await handle_pos_set_lower(update, context) - elif action == "pos_set_upper": - await handle_pos_set_upper(update, context) - elif action == "pos_set_base": - await handle_pos_set_base(update, context) - elif action == "pos_set_quote": - await handle_pos_set_quote(update, context) - elif action == "pos_add_confirm": - await handle_pos_add_confirm(update, context) - elif action == "pos_use_max_range": - await handle_pos_use_max_range(update, context) - elif action == "pos_help": - await handle_pos_help(update, context) - elif action == "pos_toggle_strategy": - await handle_pos_toggle_strategy(update, context) - elif action == "pos_refresh": - await handle_pos_refresh(update, context) - elif action.startswith("pos_tf:"): - # Format: pos_tf:timeframe - switch timeframe in add position menu - timeframe = action.split(":")[1] - await handle_pos_refresh(update, context, timeframe=timeframe) - - # GeckoTerminal explore handlers - elif action == "gecko_explore": - await show_gecko_explore_menu(update, context) - elif action == "gecko_toggle_view": - await handle_gecko_toggle_view(update, context) - elif action == "gecko_select_network": - await handle_gecko_select_network(update, context) - elif action.startswith("gecko_set_network:"): - network = action.split(":")[1] - await handle_gecko_set_network(update, context, network) - elif action == "gecko_show_pools": - await handle_gecko_show_pools(update, context) - elif action == "gecko_refresh": - await handle_gecko_refresh(update, context) - elif action == "gecko_trending": - await handle_gecko_trending(update, context) - elif action.startswith("gecko_trending_"): - network = action.replace("gecko_trending_", "") - network = None if network == "all" else network - await show_trending_pools(update, context, network) - elif action == "gecko_top": - await handle_gecko_top(update, context) - elif action.startswith("gecko_top_"): - network = action.replace("gecko_top_", "") - await show_top_pools(update, context, network) - elif action == "gecko_new": - await handle_gecko_new(update, context) - elif action.startswith("gecko_new_"): - network = action.replace("gecko_new_", "") - network = None if network == "all" else network - await show_new_pools(update, context, network) - elif action == "gecko_networks": - await handle_gecko_networks(update, context) - elif action.startswith("gecko_net_"): - network = action.replace("gecko_net_", "") - await show_network_menu(update, context, network) - elif action == "gecko_search": - await handle_gecko_search(update, context) - elif action == "gecko_search_network": - await handle_gecko_search_network(update, context) - elif action.startswith("gecko_search_set_net:"): - network = action.split(":")[1] - await handle_gecko_search_set_network(update, context, network) - elif action.startswith("gecko_pool:"): - pool_index = int(action.split(":")[1]) - await show_pool_detail(update, context, pool_index) - elif action == "gecko_charts": - await show_gecko_charts_menu(update, context) - elif action == "gecko_add_liquidity": - await handle_gecko_add_liquidity(update, context) - elif action.startswith("gecko_token:"): - token_type = action.split(":")[1] - await handle_gecko_token_info(update, context, token_type) - elif action == "gecko_token_search": - await handle_gecko_token_search(update, context) - elif action == "gecko_token_add": - await handle_gecko_token_add(update, context) - elif action == "gecko_swap": - await handle_gecko_swap(update, context) - elif action == "gecko_info": - await show_gecko_info(update, context) - elif action.startswith("gecko_ohlcv:"): - timeframe = action.split(":")[1] - await show_ohlcv_chart(update, context, timeframe) - elif action == "gecko_liquidity": - await show_gecko_liquidity(update, context) - elif action.startswith("gecko_combined:"): - timeframe = action.split(":")[1] - await show_gecko_combined(update, context, timeframe) - elif action == "gecko_trades": - await show_recent_trades(update, context) - elif action == "gecko_copy_addr": - await handle_copy_address(update, context) - elif action == "gecko_back_to_list": - await handle_back_to_list(update, context) - - # Pool OHLCV and combined chart handlers (for Meteora/CLMM pools) - elif action.startswith("pool_ohlcv:"): - parts = action.split(":") - timeframe = parts[1] - currency = parts[2] if len(parts) > 2 else "usd" - await handle_pool_ohlcv(update, context, timeframe, currency) - elif action.startswith("pool_combined:"): - parts = action.split(":") - timeframe = parts[1] - currency = parts[2] if len(parts) > 2 else "usd" - await handle_pool_combined_chart(update, context, timeframe, currency) - - # Refresh data - elif action == "refresh": - await handle_refresh(update, context) - - # Close menu - elif action == "close": - await handle_close(update, context) - - else: - await query.message.reply_text(f"Unknown action: {action}") - - except Exception as e: - # Ignore "message is not modified" errors - they're harmless - if "not modified" in str(e).lower(): - logger.debug(f"Message not modified (ignored): {e}") - return - - logger.error(f"Error in DEX callback handler: {e}", exc_info=True) - from utils.telegram_formatters import format_error_message - error_message = format_error_message(f"Operation failed: {str(e)}") - try: - await query.message.reply_text(error_message, parse_mode="MarkdownV2") - except Exception as reply_error: - logger.warning(f"Failed to send error message: {reply_error}") - - -# ============================================ -# MESSAGE HANDLER +# MODULE EXPORTS # ============================================ -@restricted -async def dex_message_handler(update: Update, context: ContextTypes.DEFAULT_TYPE) -> None: - """Handle user text input - Routes to appropriate processor""" - dex_state = context.user_data.get("dex_state") - - if not dex_state: - return - - user_input = update.message.text.strip() - logger.info(f"DEX message handler - state: {dex_state}, input: {user_input}") - - try: - # Only remove state for operations that complete (not parameter setting) - if dex_state in ["swap", "swap_status", "pool_info", "pool_list", "position_list", "add_position"]: - context.user_data.pop("dex_state", None) - - # Unified swap handlers - if dex_state == "swap": - await process_swap(update, context, user_input) - elif dex_state == "swap_set_pair": - await process_swap_set_pair(update, context, user_input) - elif dex_state == "swap_set_amount": - await process_swap_set_amount(update, context, user_input) - elif dex_state == "swap_set_slippage": - await process_swap_set_slippage(update, context, user_input) - - # Status handler - elif dex_state == "swap_status": - await process_swap_status(update, context, user_input) - - # Pool handlers - elif dex_state == "pool_info": - await process_pool_info(update, context, user_input) - elif dex_state == "pool_list": - await process_pool_list(update, context, user_input) - elif dex_state == "position_list": - await process_position_list(update, context, user_input) - - # Add position handlers - elif dex_state == "add_position": - await process_add_position(update, context, user_input) - elif dex_state == "pos_set_connector": - await process_pos_set_connector(update, context, user_input) - elif dex_state == "pos_set_network": - await process_pos_set_network(update, context, user_input) - elif dex_state == "pos_set_pool": - await process_pos_set_pool(update, context, user_input) - elif dex_state == "pos_set_lower": - await process_pos_set_lower(update, context, user_input) - elif dex_state == "pos_set_upper": - await process_pos_set_upper(update, context, user_input) - elif dex_state == "pos_set_base": - await process_pos_set_base(update, context, user_input) - elif dex_state == "pos_set_quote": - await process_pos_set_quote(update, context, user_input) - - # GeckoTerminal search handler - elif dex_state == "gecko_search": - await process_gecko_search(update, context, user_input) - - else: - await update.message.reply_text(f"Unknown state: {dex_state}") - - except Exception as e: - logger.error(f"Error processing DEX input: {e}", exc_info=True) - from utils.telegram_formatters import format_error_message - error_message = format_error_message(f"Failed to process input: {str(e)}") - await update.message.reply_text(error_message, parse_mode="MarkdownV2") - - -# ============================================ -# HANDLER FACTORIES -# ============================================ - -def get_dex_callback_handler(): - """Get the callback query handler for DEX menu""" - return CallbackQueryHandler( - dex_callback_handler, - pattern="^dex:" - ) - - -def get_dex_message_handler(): - """Returns the message handler""" - return MessageHandler( - filters.TEXT & ~filters.COMMAND, - dex_message_handler - ) - - __all__ = [ + # Commands 'swap_command', 'lp_command', + # Handlers 'dex_callback_handler', 'dex_message_handler', + # Handler factories 'get_dex_callback_handler', 'get_dex_message_handler', ] diff --git a/handlers/dex/_shared.py b/handlers/dex/_shared.py index 63cc7ec..04caf70 100644 --- a/handlers/dex/_shared.py +++ b/handlers/dex/_shared.py @@ -255,7 +255,7 @@ async def _refresh_loop(self, user_id: int, user_data: dict) -> None: try: # Use per-chat server if available chat_id = self._user_chat_ids.get(user_id) - client = await get_client(chat_id) + client = await get_client(chat_id, context=context) except Exception as e: logger.warning(f"Background refresh: couldn't get client: {e}") return @@ -322,7 +322,7 @@ async def wrapper(update, context, *args, **kwargs): # SERVER CLIENT HELPERS # ============================================ -from servers import get_client +from config_manager import get_client # ============================================ @@ -654,7 +654,7 @@ def clear_dex_state(context) -> None: # HISTORY FILTER & PAGINATION HELPERS # ============================================ -from dataclasses import dataclass, field +from dataclasses import dataclass from typing import Literal HistoryType = Literal["swap", "position"] diff --git a/handlers/dex/geckoterminal.py b/handlers/dex/geckoterminal.py index 69f7c7c..5399fdd 100644 --- a/handlers/dex/geckoterminal.py +++ b/handlers/dex/geckoterminal.py @@ -10,10 +10,8 @@ - Recent trades """ -import asyncio import io import logging -from typing import Optional, Dict, Any, List from datetime import datetime from telegram import Update, InlineKeyboardButton, InlineKeyboardMarkup @@ -21,7 +19,7 @@ from geckoterminal_py import GeckoTerminalAsyncClient from utils.telegram_formatters import escape_markdown_v2 -from ._shared import cached_call, set_cached, get_cached, clear_cache +from ._shared import cached_call from .visualizations import generate_ohlcv_chart, generate_liquidity_chart, generate_combined_chart from .pool_data import can_fetch_liquidity, get_connector_for_dex, fetch_liquidity_bins @@ -488,6 +486,20 @@ def _extract_pool_data(pool: dict) -> dict: base_token_address = _parse_token_address_from_id(base_token_id) quote_token_address = _parse_token_address_from_id(quote_token_id) + # Fallback: try direct keys for flattened DataFrames + if not base_token_address: + base_token_address = ( + _get_nested_value(attrs, "base_token_address") or + _get_nested_value(attrs, "base_token", "address") or + "" + ) + if not quote_token_address: + quote_token_address = ( + _get_nested_value(attrs, "quote_token_address") or + _get_nested_value(attrs, "quote_token", "address") or + "" + ) + # For flattened DataFrames, try direct keys first return { "id": pool.get("id", ""), @@ -737,12 +749,24 @@ async def show_trending_pools(update: Update, context: ContextTypes.DEFAULT_TYPE query = update.callback_query network_name = NETWORK_NAMES.get(network, network.title()) if network else "All Networks" + chat = query.message.chat - # Show loading - await query.message.edit_text( - f"🔥 *Trending \\- {escape_markdown_v2(network_name)}*\n\n_Loading\\.\\.\\._", - parse_mode="MarkdownV2" - ) + # Show loading - handle photo messages + if getattr(query.message, 'photo', None): + try: + await query.message.delete() + except Exception: + pass + loading_msg = await chat.send_message( + f"🔥 *Trending \\- {escape_markdown_v2(network_name)}*\n\n_Loading\\.\\.\\._", + parse_mode="MarkdownV2" + ) + else: + await query.message.edit_text( + f"🔥 *Trending \\- {escape_markdown_v2(network_name)}*\n\n_Loading\\.\\.\\._", + parse_mode="MarkdownV2" + ) + loading_msg = query.message try: cache_key = f"gecko_trending_{network or 'all'}" @@ -778,7 +802,7 @@ async def fetch_trending(): # Build keyboard with network/view controls reply_markup = _build_pool_list_keyboard(pools, context.user_data) - await query.message.edit_text( + await loading_msg.edit_text( header + table + footer, parse_mode="MarkdownV2", reply_markup=reply_markup @@ -786,14 +810,24 @@ async def fetch_trending(): except Exception as e: logger.error(f"Error fetching trending pools: {e}", exc_info=True) - await query.message.edit_text( - f"❌ Error fetching trending pools: {escape_markdown_v2(str(e))}", - parse_mode="MarkdownV2", - reply_markup=InlineKeyboardMarkup([ - [InlineKeyboardButton("🔄 Retry", callback_data="dex:gecko_refresh")], - [InlineKeyboardButton("« LP Menu", callback_data="dex:liquidity")] - ]) - ) + try: + await loading_msg.edit_text( + f"❌ Error fetching trending pools: {escape_markdown_v2(str(e))}", + parse_mode="MarkdownV2", + reply_markup=InlineKeyboardMarkup([ + [InlineKeyboardButton("🔄 Retry", callback_data="dex:gecko_refresh")], + [InlineKeyboardButton("« LP Menu", callback_data="dex:liquidity")] + ]) + ) + except Exception: + await chat.send_message( + f"❌ Error fetching trending pools: {escape_markdown_v2(str(e))}", + parse_mode="MarkdownV2", + reply_markup=InlineKeyboardMarkup([ + [InlineKeyboardButton("🔄 Retry", callback_data="dex:gecko_refresh")], + [InlineKeyboardButton("« LP Menu", callback_data="dex:liquidity")] + ]) + ) # ============================================ @@ -834,12 +868,24 @@ async def show_top_pools(update: Update, context: ContextTypes.DEFAULT_TYPE, net query = update.callback_query network_name = NETWORK_NAMES.get(network, network.title()) + chat = query.message.chat - # Show loading - await query.message.edit_text( - f"📈 *Top Pools \\- {escape_markdown_v2(network_name)}*\n\n_Loading\\.\\.\\._", - parse_mode="MarkdownV2" - ) + # Show loading - handle photo messages + if getattr(query.message, 'photo', None): + try: + await query.message.delete() + except Exception: + pass + loading_msg = await chat.send_message( + f"📈 *Top Pools \\- {escape_markdown_v2(network_name)}*\n\n_Loading\\.\\.\\._", + parse_mode="MarkdownV2" + ) + else: + await query.message.edit_text( + f"📈 *Top Pools \\- {escape_markdown_v2(network_name)}*\n\n_Loading\\.\\.\\._", + parse_mode="MarkdownV2" + ) + loading_msg = query.message try: cache_key = f"gecko_top_{network}" @@ -870,7 +916,7 @@ async def fetch_top(): # Build keyboard with network/view controls reply_markup = _build_pool_list_keyboard(pools, context.user_data) - await query.message.edit_text( + await loading_msg.edit_text( header + table + footer, parse_mode="MarkdownV2", reply_markup=reply_markup @@ -878,14 +924,24 @@ async def fetch_top(): except Exception as e: logger.error(f"Error fetching top pools: {e}", exc_info=True) - await query.message.edit_text( - f"❌ Error fetching top pools: {escape_markdown_v2(str(e))}", - parse_mode="MarkdownV2", - reply_markup=InlineKeyboardMarkup([ - [InlineKeyboardButton("🔄 Retry", callback_data="dex:gecko_refresh")], - [InlineKeyboardButton("« LP Menu", callback_data="dex:liquidity")] - ]) - ) + try: + await loading_msg.edit_text( + f"❌ Error fetching top pools: {escape_markdown_v2(str(e))}", + parse_mode="MarkdownV2", + reply_markup=InlineKeyboardMarkup([ + [InlineKeyboardButton("🔄 Retry", callback_data="dex:gecko_refresh")], + [InlineKeyboardButton("« LP Menu", callback_data="dex:liquidity")] + ]) + ) + except Exception: + await chat.send_message( + f"❌ Error fetching top pools: {escape_markdown_v2(str(e))}", + parse_mode="MarkdownV2", + reply_markup=InlineKeyboardMarkup([ + [InlineKeyboardButton("🔄 Retry", callback_data="dex:gecko_refresh")], + [InlineKeyboardButton("« LP Menu", callback_data="dex:liquidity")] + ]) + ) # ============================================ @@ -928,14 +984,26 @@ async def handle_gecko_new(update: Update, context: ContextTypes.DEFAULT_TYPE) - async def show_new_pools(update: Update, context: ContextTypes.DEFAULT_TYPE, network: str = None) -> None: """Fetch and display new pools""" query = update.callback_query + chat = query.message.chat network_name = NETWORK_NAMES.get(network, network.title()) if network else "All Networks" - # Show loading - await query.message.edit_text( - f"🆕 *New Pools \\- {escape_markdown_v2(network_name)}*\n\n_Loading\\.\\.\\._", - parse_mode="MarkdownV2" - ) + # Show loading - handle photo messages + if query.message.photo: + try: + await query.message.delete() + except Exception: + pass + loading_msg = await chat.send_message( + f"🆕 *New Pools \\- {escape_markdown_v2(network_name)}*\n\n_Loading\\.\\.\\._", + parse_mode="MarkdownV2" + ) + else: + await query.message.edit_text( + f"🆕 *New Pools \\- {escape_markdown_v2(network_name)}*\n\n_Loading\\.\\.\\._", + parse_mode="MarkdownV2" + ) + loading_msg = query.message try: cache_key = f"gecko_new_{network or 'all'}" @@ -970,7 +1038,7 @@ async def fetch_new(): # Build keyboard with network/view controls reply_markup = _build_pool_list_keyboard(pools, context.user_data) - await query.message.edit_text( + await loading_msg.edit_text( header + table + footer, parse_mode="MarkdownV2", reply_markup=reply_markup @@ -978,14 +1046,24 @@ async def fetch_new(): except Exception as e: logger.error(f"Error fetching new pools: {e}", exc_info=True) - await query.message.edit_text( - f"❌ Error fetching new pools: {escape_markdown_v2(str(e))}", - parse_mode="MarkdownV2", - reply_markup=InlineKeyboardMarkup([ - [InlineKeyboardButton("🔄 Retry", callback_data="dex:gecko_refresh")], - [InlineKeyboardButton("« LP Menu", callback_data="dex:liquidity")] - ]) - ) + try: + await loading_msg.edit_text( + f"❌ Error fetching new pools: {escape_markdown_v2(str(e))}", + parse_mode="MarkdownV2", + reply_markup=InlineKeyboardMarkup([ + [InlineKeyboardButton("🔄 Retry", callback_data="dex:gecko_refresh")], + [InlineKeyboardButton("« LP Menu", callback_data="dex:liquidity")] + ]) + ) + except Exception: + await chat.send_message( + f"❌ Error fetching new pools: {escape_markdown_v2(str(e))}", + parse_mode="MarkdownV2", + reply_markup=InlineKeyboardMarkup([ + [InlineKeyboardButton("🔄 Retry", callback_data="dex:gecko_refresh")], + [InlineKeyboardButton("« LP Menu", callback_data="dex:liquidity")] + ]) + ) # ============================================ @@ -1304,19 +1382,18 @@ async def _show_pool_chart(update: Update, context: ContextTypes.DEFAULT_TYPE, p await query.answer("Loading chart...") - # Show loading - handle photo messages (can't edit photo to text) - if getattr(query.message, 'photo', None): + chat = query.message.chat + + # Show loading - delete current message and send loading text + try: await query.message.delete() - loading_msg = await query.message.chat.send_message( - f"📈 *{escape_markdown_v2(pool_data['name'])}*\n\n_Loading chart\\.\\.\\._", - parse_mode="MarkdownV2" - ) - else: - await query.message.edit_text( - f"📈 *{escape_markdown_v2(pool_data['name'])}*\n\n_Loading chart\\.\\.\\._", - parse_mode="MarkdownV2" - ) - loading_msg = query.message + except Exception: + pass # Message may already be deleted + + loading_msg = await chat.send_message( + f"📈 *{escape_markdown_v2(pool_data['name'])}*\n\n_Loading chart\\.\\.\\._", + parse_mode="MarkdownV2" + ) try: network = pool_data["network"] @@ -1378,16 +1455,20 @@ async def _show_pool_chart(update: Update, context: ContextTypes.DEFAULT_TYPE, p await _show_pool_text_detail(loading_msg, context, pool_data) return - # Build caption with key info - caption_lines = [ - f"📈 *{escape_markdown_v2(pool_data['name'])}*", - ] + # Build caption with detailed info + pool_index = context.user_data.get("gecko_selected_pool_index", 0) + dex_id = pool_data.get("dex_id", "") + network = pool_data.get("network", "") + network_name = NETWORK_NAMES.get(network, network) + + caption_lines = [f"📈 *{escape_markdown_v2(pool_data['name'])}*"] - # Add price and change info + # Price line + price_parts = [] if pool_data.get("base_token_price_usd"): try: price = float(pool_data["base_token_price_usd"]) - caption_lines.append(f"💰 {escape_markdown_v2(_format_price(price))}") + price_parts.append(f"💰 {escape_markdown_v2(_format_price(price))}") except (ValueError, TypeError): pass @@ -1395,60 +1476,109 @@ async def _show_pool_chart(update: Update, context: ContextTypes.DEFAULT_TYPE, p if change_24h is not None: try: change = float(change_24h) - caption_lines.append(f"{escape_markdown_v2(_format_change(change))} 24h") + emoji = "🟢" if change >= 0 else "🔴" + price_parts.append(f"{emoji} {escape_markdown_v2(_format_change(change))} 24h") except (ValueError, TypeError): pass - vol_24h = pool_data.get("volume_24h") - if vol_24h: + if price_parts: + caption_lines.append(" • ".join(price_parts)) + + # Network/DEX line + caption_lines.append(f"🌐 {escape_markdown_v2(network_name)} • 🏦 {escape_markdown_v2(dex_id)}") + + # Price changes line + changes = [] + for period, key in [("1h", "price_change_1h"), ("6h", "price_change_6h"), ("24h", "price_change_24h")]: + change = pool_data.get(key) + if change is not None: + try: + changes.append(f"{period}: {_format_change(float(change))}") + except (ValueError, TypeError): + pass + if changes: + caption_lines.append(f"📊 {escape_markdown_v2(' | '.join(changes))}") + + # Volume line + vols = [] + for period, key in [("1h", "volume_1h"), ("6h", "volume_6h"), ("24h", "volume_24h")]: + vol = pool_data.get(key) + if vol: + try: + vols.append(f"{period}: {_format_volume(float(vol))}") + except (ValueError, TypeError): + pass + if vols: + caption_lines.append(f"📈 Vol {escape_markdown_v2(' | '.join(vols))}") + + # Market metrics line + metrics = [] + if pool_data.get("reserve_usd"): + try: + metrics.append(f"Liq: {_format_volume(float(pool_data['reserve_usd']))}") + except (ValueError, TypeError): + pass + if pool_data.get("fdv_usd"): + try: + metrics.append(f"FDV: {_format_volume(float(pool_data['fdv_usd']))}") + except (ValueError, TypeError): + pass + if pool_data.get("market_cap_usd"): try: - vol = float(vol_24h) - caption_lines.append(f"Vol: {escape_markdown_v2(_format_volume(vol))}") + metrics.append(f"MC: {_format_volume(float(pool_data['market_cap_usd']))}") except (ValueError, TypeError): pass + if metrics: + caption_lines.append(f"💎 {escape_markdown_v2(' | '.join(metrics))}") - # Join caption lines - use bullet separator instead of | to avoid escaping issues - if len(caption_lines) > 1: - caption = caption_lines[0] + "\n" + " • ".join(caption_lines[1:]) - else: - caption = caption_lines[0] + # Transactions line + txns = pool_data.get("transactions_24h", {}) + if txns: + buys = txns.get("buys", 0) + sells = txns.get("sells", 0) + if buys or sells: + caption_lines.append(f"🔄 24h Txns: {buys} buys / {sells} sells") - # Build keyboard with timeframe selection and action buttons - pool_index = context.user_data.get("gecko_selected_pool_index", 0) - dex_id = pool_data.get("dex_id", "") - network = pool_data.get("network", "") + caption = "\n".join(caption_lines) + + # Build keyboard - reorganized layout supports_liquidity = can_fetch_liquidity(dex_id, network) keyboard = [ - # Timeframe row + # Row 1: Timeframe [ - InlineKeyboardButton("1h" if timeframe != "1m" else "• 1h •", callback_data="dex:gecko_ohlcv:1m"), - InlineKeyboardButton("1d" if timeframe != "1h" else "• 1d •", callback_data="dex:gecko_ohlcv:1h"), - InlineKeyboardButton("7d" if timeframe != "1d" else "• 7d •", callback_data="dex:gecko_ohlcv:1d"), + InlineKeyboardButton("1h" if timeframe != "1m" else "• 1h •", callback_data="dex:gecko_pool_tf:1m"), + InlineKeyboardButton("1d" if timeframe != "1h" else "• 1d •", callback_data="dex:gecko_pool_tf:1h"), + InlineKeyboardButton("7d" if timeframe != "1d" else "• 7d •", callback_data="dex:gecko_pool_tf:1d"), ], - # Action row: Swap + Trades + Info + # Row 2: Swap | LP [ InlineKeyboardButton("💱 Swap", callback_data="dex:gecko_swap"), + InlineKeyboardButton("➕ Add LP", callback_data="dex:gecko_add_liquidity"), + ], + # Row 3: Trades | Liquidity Distribution + [ InlineKeyboardButton("📜 Trades", callback_data="dex:gecko_trades"), - InlineKeyboardButton("ℹ️ Info", callback_data="dex:gecko_info"), + InlineKeyboardButton("📊 Liquidity", callback_data="dex:gecko_liquidity"), + ], + # Row 4: Add to Gateway + [ + InlineKeyboardButton("🔗 Add to Gateway", callback_data="dex:gecko_add_tokens"), + ], + # Row 5: Refresh | Back + [ + InlineKeyboardButton("🔄 Refresh", callback_data=f"dex:gecko_pool:{pool_index}"), + InlineKeyboardButton("« Back", callback_data="dex:gecko_back_to_list"), ], ] - # Add liquidity button for supported DEXes - if supports_liquidity: - keyboard.append([ - InlineKeyboardButton("📊 Liquidity", callback_data="dex:gecko_liquidity"), - InlineKeyboardButton("➕ Add LP", callback_data="dex:gecko_add_liquidity"), - ]) - - keyboard.append([ - InlineKeyboardButton("🔄", callback_data=f"dex:gecko_pool:{pool_index}"), - InlineKeyboardButton("« Back", callback_data="dex:gecko_back_to_list"), - ]) - # Delete loading message and send photo - await loading_msg.delete() - await loading_msg.chat.send_photo( + try: + await loading_msg.delete() + except Exception: + pass # Message may already be deleted + + await chat.send_photo( photo=chart_buffer, caption=caption, parse_mode="MarkdownV2", @@ -1458,7 +1588,17 @@ async def _show_pool_chart(update: Update, context: ContextTypes.DEFAULT_TYPE, p except Exception as e: logger.error(f"Error generating pool chart: {e}", exc_info=True) # Fall back to text view on error - await _show_pool_text_detail(loading_msg, context, pool_data) + try: + await _show_pool_text_detail(loading_msg, context, pool_data) + except Exception: + # If loading_msg was deleted, send new error message + await chat.send_message( + f"❌ Error loading chart: {escape_markdown_v2(str(e))}", + parse_mode="MarkdownV2", + reply_markup=InlineKeyboardMarkup([ + [InlineKeyboardButton("« Back", callback_data="dex:gecko_back_to_list")] + ]) + ) async def _show_pool_text_detail(message, context: ContextTypes.DEFAULT_TYPE, pool_data: dict) -> None: @@ -2122,8 +2262,15 @@ async def show_recent_trades(update: Update, context: ContextTypes.DEFAULT_TYPE) await query.answer("Loading trades...") - # Show loading - await query.message.edit_text( + chat = query.message.chat + + # Show loading - delete current message (may be photo) and send text + try: + await query.message.delete() + except Exception: + pass + + loading_msg = await chat.send_message( r"📜 *Recent Trades*" + "\n\n" + r"_Loading\.\.\._", parse_mode="MarkdownV2" ) @@ -2154,7 +2301,7 @@ async def show_recent_trades(update: Update, context: ContextTypes.DEFAULT_TYPE) trades = _extract_pools_from_response(result, 20) if not trades: - await query.message.edit_text( + await loading_msg.edit_text( r"📜 *Recent Trades*" + "\n\n" + "_No recent trades found_", parse_mode="MarkdownV2", reply_markup=InlineKeyboardMarkup([ @@ -2257,7 +2404,7 @@ async def show_recent_trades(update: Update, context: ContextTypes.DEFAULT_TYPE) ], ] - await query.message.edit_text( + await loading_msg.edit_text( "\n".join(lines), parse_mode="MarkdownV2", reply_markup=InlineKeyboardMarkup(keyboard) @@ -2265,13 +2412,22 @@ async def show_recent_trades(update: Update, context: ContextTypes.DEFAULT_TYPE) except Exception as e: logger.error(f"Error fetching trades: {e}", exc_info=True) - await query.message.edit_text( - f"❌ Error loading trades: {escape_markdown_v2(str(e))}", - parse_mode="MarkdownV2", - reply_markup=InlineKeyboardMarkup([ - [InlineKeyboardButton("« Back", callback_data=f"dex:gecko_pool:{context.user_data.get('gecko_selected_pool_index', 0)}")] - ]) - ) + try: + await loading_msg.edit_text( + f"❌ Error loading trades: {escape_markdown_v2(str(e))}", + parse_mode="MarkdownV2", + reply_markup=InlineKeyboardMarkup([ + [InlineKeyboardButton("« Back", callback_data=f"dex:gecko_pool:{context.user_data.get('gecko_selected_pool_index', 0)}")] + ]) + ) + except Exception: + await chat.send_message( + f"❌ Error loading trades: {escape_markdown_v2(str(e))}", + parse_mode="MarkdownV2", + reply_markup=InlineKeyboardMarkup([ + [InlineKeyboardButton("« Back", callback_data=f"dex:gecko_pool:{context.user_data.get('gecko_selected_pool_index', 0)}")] + ]) + ) # ============================================ @@ -2591,9 +2747,9 @@ async def handle_gecko_token_add(update: Update, context: ContextTypes.DEFAULT_T await query.answer("Adding token to Gateway...") try: - from servers import server_manager + from config_manager import get_config_manager - client = await server_manager.get_default_client() + client = await get_config_manager().get_default_client() await client.gateway.add_token( network_id=gateway_network, address=address, @@ -2783,17 +2939,14 @@ async def handle_gecko_swap(update: Update, context: ContextTypes.DEFAULT_TYPE) context.user_data["swap_from_gecko"] = True context.user_data["swap_gecko_pool_index"] = context.user_data.get("gecko_selected_pool_index", 0) - # Delete current message (might be a photo) and show swap menu + # Delete current message (might be a photo) try: - if getattr(query.message, 'photo', None): - await query.message.delete() - else: - await query.message.delete() + await query.message.delete() except Exception: pass - # Show swap menu - await show_swap_menu(update, context) + # Show swap menu - send_new=True since we deleted the message + await show_swap_menu(update, context, send_new=True) async def show_gecko_info(update: Update, context: ContextTypes.DEFAULT_TYPE) -> None: @@ -2956,6 +3109,233 @@ async def show_gecko_info(update: Update, context: ContextTypes.DEFAULT_TYPE) -> ) +# ============================================ +# POOL DETAIL TIMEFRAME SWITCHING +# ============================================ + +async def handle_gecko_pool_tf(update: Update, context: ContextTypes.DEFAULT_TYPE, timeframe: str) -> None: + """Handle timeframe switching in pool detail view - maintains action buttons""" + pool_data = context.user_data.get("gecko_selected_pool") + if not pool_data: + await update.callback_query.answer("No pool selected") + return + + await _show_pool_chart(update, context, pool_data, timeframe) + + +# ============================================ +# ADD TOKENS TO GATEWAY +# ============================================ + +async def handle_gecko_add_tokens(update: Update, context: ContextTypes.DEFAULT_TYPE) -> None: + """Add pool tokens to Gateway configuration""" + import httpx + from geckoterminal_py import GeckoTerminalAsyncClient + from config_manager import get_config_manager + + query = update.callback_query + + pool_data = context.user_data.get("gecko_selected_pool") + if not pool_data: + await query.answer("No pool selected") + return + + # Get token addresses + base_addr = pool_data.get("base_token_address", "") + quote_addr = pool_data.get("quote_token_address", "") + gecko_network = pool_data.get("network", "solana") + pool_address = pool_data.get("address", "") + + # If no addresses in pool_data, fetch from GeckoTerminal API directly + if not base_addr and not quote_addr and pool_address: + await query.answer("Fetching token info...") + try: + # Use direct API call since geckoterminal_py doesn't have get_pool method + url = f"https://api.geckoterminal.com/api/v2/networks/{gecko_network}/pools/{pool_address}" + async with httpx.AsyncClient() as client: + response = await client.get(url, params={"include": "base_token,quote_token"}) + if response.status_code == 200: + result = response.json() + data = result.get('data', {}) + relationships = data.get('relationships', {}) + base_token_id = relationships.get('base_token', {}).get('data', {}).get('id', '') + quote_token_id = relationships.get('quote_token', {}).get('data', {}).get('id', '') + base_addr = _parse_token_address_from_id(base_token_id) + quote_addr = _parse_token_address_from_id(quote_token_id) + except Exception as e: + logger.warning(f"Failed to fetch pool info for tokens: {e}") + + if not base_addr and not quote_addr: + await query.answer("No token addresses available", show_alert=True) + return + + gecko_client = GeckoTerminalAsyncClient() + + # Map GeckoTerminal network to Gateway network + network_mapping = { + "solana": "solana-mainnet-beta", + "eth": "ethereum-mainnet", + "base": "base-mainnet", + "arbitrum": "arbitrum-one", + "bsc": "bsc-mainnet", + "polygon_pos": "polygon-mainnet", + } + gateway_network = network_mapping.get(gecko_network, gecko_network) + + await query.answer("Adding tokens to Gateway...") + + chat = query.message.chat + + # Show loading - delete current message and send loading text + try: + await query.message.delete() + except Exception: + pass + + loading_msg = await chat.send_message( + r"🔄 *Adding tokens to Gateway\.\.\.*", + parse_mode="MarkdownV2" + ) + + added_tokens = [] + errors = [] + + async def add_token_to_gateway(token_address: str) -> str: + """Fetch token info and add to gateway. Returns symbol or error indicator.""" + try: + # Fetch from GeckoTerminal + result = await gecko_client.get_specific_token_on_network(gecko_network, token_address) + + token_data = {} + if isinstance(result, dict): + data = result.get('data', result) if 'data' in result else result + token_data = data.get('attributes', data) if isinstance(data, dict) else {} + else: + # Try pandas DataFrame + try: + import pandas as pd + if isinstance(result, pd.DataFrame) and not result.empty: + token_data = result.to_dict('records')[0] + except ImportError: + pass + + if not token_data: + return None + + symbol = token_data.get('symbol', '???') + decimals = token_data.get('decimals', 9 if gecko_network == "solana" else 18) + name = token_data.get('name') + + # Add to gateway + client = await get_config_manager().get_default_client() + await client.gateway.add_token( + network_id=gateway_network, + address=token_address, + symbol=symbol, + decimals=decimals, + name=name + ) + return symbol + + except Exception as e: + error_str = str(e).lower() + if "already exists" in error_str or "duplicate" in error_str: + return "exists" + logger.warning(f"Failed to add token {token_address[:12]}...: {e}") + return None + + # Add both tokens + if base_addr: + result = await add_token_to_gateway(base_addr) + if result and result != "exists": + added_tokens.append(result) + elif result is None: + errors.append(f"base ({base_addr[:8]}...)") + + if quote_addr: + result = await add_token_to_gateway(quote_addr) + if result and result != "exists": + added_tokens.append(result) + elif result is None: + errors.append(f"quote ({quote_addr[:8]}...)") + + # Build result message + if added_tokens: + result_msg = f"✅ *Added:* {escape_markdown_v2(', '.join(added_tokens))}\n\n" + else: + result_msg = "ℹ️ _Tokens already in Gateway_\n\n" + + if errors: + result_msg += f"⚠️ Failed: {escape_markdown_v2(', '.join(errors))}\n\n" + + result_msg += r"⚠️ _Restart Gateway for changes to take effect_" + + # Add restart button + keyboard = [ + [InlineKeyboardButton("🔄 Restart Gateway", callback_data="dex:gecko_restart_gateway")], + [InlineKeyboardButton("« Back to Pool", callback_data=f"dex:gecko_pool:{context.user_data.get('gecko_selected_pool_index', 0)}")], + ] + + try: + await loading_msg.edit_text( + result_msg, + parse_mode="MarkdownV2", + reply_markup=InlineKeyboardMarkup(keyboard) + ) + except Exception: + await chat.send_message( + result_msg, + parse_mode="MarkdownV2", + reply_markup=InlineKeyboardMarkup(keyboard) + ) + + +async def handle_gecko_restart_gateway(update: Update, context: ContextTypes.DEFAULT_TYPE) -> None: + """Restart Gateway after adding tokens""" + from config_manager import get_config_manager + + query = update.callback_query + await query.answer("Restarting Gateway...") + + chat = query.message.chat + + try: + await query.message.delete() + except Exception: + pass + + loading_msg = await chat.send_message( + r"🔄 *Restarting Gateway\.\.\.*", + parse_mode="MarkdownV2" + ) + + try: + client = await get_config_manager().get_default_client() + await client.gateway.restart() + + # Wait a moment for restart + import asyncio + await asyncio.sleep(2) + + await loading_msg.edit_text( + r"✅ *Gateway restarted*" + "\n\n_New tokens are now available_", + parse_mode="MarkdownV2", + reply_markup=InlineKeyboardMarkup([ + [InlineKeyboardButton("« Back to Pool", callback_data=f"dex:gecko_pool:{context.user_data.get('gecko_selected_pool_index', 0)}")] + ]) + ) + + except Exception as e: + logger.error(f"Failed to restart gateway: {e}", exc_info=True) + await loading_msg.edit_text( + f"❌ Failed to restart Gateway: {escape_markdown_v2(str(e))}", + parse_mode="MarkdownV2", + reply_markup=InlineKeyboardMarkup([ + [InlineKeyboardButton("« Back to Pool", callback_data=f"dex:gecko_pool:{context.user_data.get('gecko_selected_pool_index', 0)}")] + ]) + ) + + # ============================================ # EXPORTS # ============================================ @@ -2993,4 +3373,7 @@ async def show_gecko_info(update: Update, context: ContextTypes.DEFAULT_TYPE) -> 'handle_gecko_add_liquidity', 'handle_gecko_swap', 'show_gecko_info', + 'handle_gecko_pool_tf', + 'handle_gecko_add_tokens', + 'handle_gecko_restart_gateway', ] diff --git a/handlers/dex/liquidity.py b/handlers/dex/liquidity.py index 13c9d1a..eeadb05 100644 --- a/handlers/dex/liquidity.py +++ b/handlers/dex/liquidity.py @@ -12,15 +12,12 @@ from telegram import Update, InlineKeyboardButton, InlineKeyboardMarkup from telegram.ext import ContextTypes -from utils.telegram_formatters import escape_markdown_v2, format_error_message, resolve_token_symbol, format_amount, KNOWN_TOKENS +from utils.telegram_formatters import escape_markdown_v2, format_error_message, resolve_token_symbol, KNOWN_TOKENS from utils.auth import gateway_required -from servers import get_client +from config_manager import get_client from ._shared import ( - get_cached, - set_cached, cached_call, invalidate_cache, - get_explorer_url, format_relative_time, get_history_filters, set_history_filters, @@ -479,7 +476,7 @@ async def show_liquidity_menu(update: Update, context: ContextTypes.DEFAULT_TYPE help_text = r"💧 *Liquidity Pools*" + "\n\n" try: - client = await get_client(chat_id) + client = await get_client(chat_id, context=context) # Fetch balances (cached) gateway_data = await cached_call( @@ -938,7 +935,7 @@ async def handle_lp_history(update: Update, context: ContextTypes.DEFAULT_TYPE, else: filters = get_history_filters(context.user_data, "position") - client = await get_client(chat_id) + client = await get_client(chat_id, context=context) if not hasattr(client, 'gateway_clmm'): error_message = format_error_message("Gateway CLMM not available") diff --git a/handlers/dex/lp_monitor_handlers.py b/handlers/dex/lp_monitor_handlers.py new file mode 100644 index 0000000..4ed6af4 --- /dev/null +++ b/handlers/dex/lp_monitor_handlers.py @@ -0,0 +1,622 @@ +""" +LP Monitor Alert Handlers + +Handles user interactions with LP monitor out-of-range alerts: +- Navigation between positions +- Position detail views +- Fee collection +- Position rebalancing (close + reopen) +- Out-of-range position filtering +""" + +import logging +from decimal import Decimal + +from telegram import InlineKeyboardButton, InlineKeyboardMarkup, Update +from telegram.ext import ContextTypes + +from utils.telegram_formatters import escape_markdown_v2, resolve_token_symbol + +logger = logging.getLogger(__name__) + + +# ============================================ +# POSITION FORMATTING HELPERS +# ============================================ + +def _format_price(value: float | str, decimals: int | None = None) -> str: + """Format a price value with appropriate decimal places.""" + try: + float_val = float(value) + if decimals is None: + decimals = 2 if float_val >= 1 else (6 if float_val >= 0.001 else 8) + return f"{float_val:.{decimals}f}" + except (ValueError, TypeError): + return str(value) + + +def _get_position_tokens(pos: dict, token_cache: dict) -> tuple[str, str, str]: + """Extract and resolve token symbols from position data.""" + base_token = pos.get('base_token', pos.get('token_a', '')) + quote_token = pos.get('quote_token', pos.get('token_b', '')) + base_symbol = resolve_token_symbol(base_token, token_cache) + quote_symbol = resolve_token_symbol(quote_token, token_cache) + pair = f"{base_symbol}-{quote_symbol}" + return base_symbol, quote_symbol, pair + + +def _get_positions_for_instance(positions_cache: dict, instance_id: str) -> list[dict]: + """Get all cached positions for a given LP monitor instance.""" + positions = [] + i = 0 + while True: + cache_key = f"lpm_{instance_id}_{i}" + if cache_key in positions_cache: + positions.append(positions_cache[cache_key]) + i += 1 + else: + break + return positions + + +# ============================================ +# NAVIGATION HANDLERS +# ============================================ + +async def handle_lpm_navigation( + update: Update, + context: ContextTypes.DEFAULT_TYPE, + instance_id: str, + new_index: int +) -> None: + """Handle navigation in LP monitor alert message.""" + query = update.callback_query + positions_cache = context.user_data.get("positions_cache", {}) + token_cache = context.user_data.get("token_cache", {}) + + positions = _get_positions_for_instance(positions_cache, instance_id) + if not positions: + await query.answer("Positions not found") + return + + # Clamp index to valid range + new_index = max(0, min(new_index, len(positions) - 1)) + pos = positions[new_index] + + # Get token info + base_symbol, quote_symbol, pair = _get_position_tokens(pos, token_cache) + connector = pos.get('connector', 'unknown') + + # Price info + lower = pos.get('lower_price', pos.get('price_lower', '')) + upper = pos.get('upper_price', pos.get('price_upper', '')) + current = pos.get('current_price', '') + + # Format range + range_str = "" + if lower and upper: + try: + lower_f = float(lower) + upper_f = float(upper) + decimals = 2 if lower_f >= 1 else (6 if lower_f >= 0.001 else 8) + range_str = f"Range: {lower_f:.{decimals}f} - {upper_f:.{decimals}f}" + except (ValueError, TypeError): + range_str = f"Range: {lower} - {upper}" + + # Format current price and direction + current_str = "" + direction = "" + if current: + try: + current_f = float(current) + lower_f = float(lower) if lower else 0 + upper_f = float(upper) if upper else 0 + decimals = 2 if current_f >= 1 else (6 if current_f >= 0.001 else 8) + current_str = f"Current: {current_f:.{decimals}f}" + if current_f < lower_f: + direction = "▼ Below range" + elif current_f > upper_f: + direction = "▲ Above range" + except (ValueError, TypeError): + current_str = f"Current: {current}" + + # Format value + pnl_summary = pos.get('pnl_summary', {}) + value = pnl_summary.get('current_lp_value_quote', 0) + value_str = "" + if value: + try: + value_str = f"Value: {float(value):.2f} {quote_symbol}" + except (ValueError, TypeError): + pass + + # Build message + total = len(positions) + header = f"🚨 *Out of Range* \\({new_index + 1}/{total}\\)" if total > 1 else "🚨 *Position Out of Range*" + lines = [header, "", f"*{escape_markdown_v2(pair)}* \\({escape_markdown_v2(connector)}\\)"] + + if direction: + lines.append(f"_{escape_markdown_v2(direction)}_") + if range_str: + lines.append(escape_markdown_v2(range_str)) + if current_str: + lines.append(escape_markdown_v2(current_str)) + if value_str: + lines.append(escape_markdown_v2(value_str)) + + text = "\n".join(lines) + + # Build keyboard + cache_key = f"lpm_{instance_id}_{new_index}" + keyboard = [] + + if total > 1: + nav_row = [] + if new_index > 0: + nav_row.append(InlineKeyboardButton("◀️ Prev", callback_data=f"dex:lpm_nav:{instance_id}:{new_index - 1}")) + nav_row.append(InlineKeyboardButton(f"{new_index + 1}/{total}", callback_data="dex:lpm_noop")) + if new_index < total - 1: + nav_row.append(InlineKeyboardButton("Next ▶️", callback_data=f"dex:lpm_nav:{instance_id}:{new_index + 1}")) + keyboard.append(nav_row) + + keyboard.append([ + InlineKeyboardButton("❌ Close", callback_data=f"dex:pos_close:{cache_key}"), + InlineKeyboardButton("⏭ Skip", callback_data=f"dex:lpm_skip:{cache_key}"), + InlineKeyboardButton("✅ Dismiss", callback_data=f"dex:lpm_dismiss:{instance_id}"), + ]) + + try: + await query.message.edit_text(text, parse_mode="MarkdownV2", reply_markup=InlineKeyboardMarkup(keyboard)) + except Exception as e: + if "not modified" not in str(e).lower(): + logger.warning(f"Failed to update LPM navigation: {e}") + + +async def handle_lpm_oor_navigation( + update: Update, + context: ContextTypes.DEFAULT_TYPE, + instance_id: str, + index: int +) -> None: + """Navigate only out-of-range positions.""" + from routines.lp_monitor import format_position_detail_view + + query = update.callback_query + positions_cache = context.user_data.get("positions_cache", {}) + token_cache = context.user_data.get("token_cache", {}) + token_prices = context.user_data.get("token_prices", {}) + + # Find all positions and filter to out-of-range only + all_positions = [] + i = 0 + while True: + cache_key = f"lpm_{instance_id}_{i}" + if cache_key in positions_cache: + all_positions.append((i, positions_cache[cache_key])) + i += 1 + else: + break + + oor_positions = [(orig_idx, pos) for orig_idx, pos in all_positions if pos.get('in_range') == 'OUT_OF_RANGE'] + + if not oor_positions: + await query.answer("No out-of-range positions") + return + + # Clamp index + index = max(0, min(index, len(oor_positions) - 1)) + orig_idx, pos = oor_positions[index] + + text, _ = format_position_detail_view( + pos, token_cache, token_prices, index, len(oor_positions), instance_id + ) + + # Custom keyboard for OOR navigation + cache_key = f"lpm_{instance_id}_{orig_idx}" + keyboard = [] + + nav_row = [] + if index > 0: + nav_row.append(InlineKeyboardButton("◀️ Prev", callback_data=f"dex:lpm_oor:{instance_id}:{index - 1}")) + nav_row.append(InlineKeyboardButton(f"⚠️ {index + 1}/{len(oor_positions)}", callback_data="dex:noop")) + if index < len(oor_positions) - 1: + nav_row.append(InlineKeyboardButton("Next ▶️", callback_data=f"dex:lpm_oor:{instance_id}:{index + 1}")) + if nav_row: + keyboard.append(nav_row) + + keyboard.append([ + InlineKeyboardButton("💰 Collect Fees", callback_data=f"dex:lpm_collect:{cache_key}"), + InlineKeyboardButton("❌ Close", callback_data=f"dex:pos_close:{cache_key}"), + ]) + keyboard.append([ + InlineKeyboardButton("🔄 Rebalance", callback_data=f"dex:lpm_rebalance:{cache_key}"), + ]) + keyboard.append([ + InlineKeyboardButton("« Back to List", callback_data=f"dex:lpm_dismiss:{instance_id}"), + ]) + + try: + await query.message.edit_text(text, parse_mode="MarkdownV2", reply_markup=InlineKeyboardMarkup(keyboard)) + except Exception as e: + if "not modified" not in str(e).lower(): + logger.warning(f"Failed to update OOR navigation: {e}") + + +# ============================================ +# DETAIL VIEW HANDLER +# ============================================ + +async def handle_lpm_detail( + update: Update, + context: ContextTypes.DEFAULT_TYPE, + instance_id: str, + index: int +) -> None: + """Handle position detail view with actions.""" + from routines.lp_monitor import format_position_detail_view + + query = update.callback_query + positions_cache = context.user_data.get("positions_cache", {}) + token_cache = context.user_data.get("token_cache", {}) + token_prices = context.user_data.get("token_prices", {}) + + positions = _get_positions_for_instance(positions_cache, instance_id) + if not positions: + await query.answer("Positions not found") + return + + # Clamp index + index = max(0, min(index, len(positions) - 1)) + pos = positions[index] + + text, reply_markup = format_position_detail_view( + pos, token_cache, token_prices, index, len(positions), instance_id + ) + + try: + await query.message.edit_text(text, parse_mode="MarkdownV2", reply_markup=reply_markup) + except Exception as e: + if "not modified" not in str(e).lower(): + logger.warning(f"Failed to update position detail: {e}") + + +# ============================================ +# FEE COLLECTION HANDLER +# ============================================ + +async def handle_lpm_collect_fees( + update: Update, + context: ContextTypes.DEFAULT_TYPE, + cache_key: str +) -> None: + """Collect fees for a position.""" + query = update.callback_query + chat_id = update.effective_chat.id + positions_cache = context.user_data.get("positions_cache", {}) + + pos = positions_cache.get(cache_key) + if not pos: + await query.answer("Position not found") + return + + await query.answer("Collecting fees...") + + try: + from config_manager import get_client + client = await get_client(chat_id, context=context) + if not client or not hasattr(client, 'gateway_clmm'): + await query.message.reply_text("❌ Gateway not available") + return + + # Get position details + position_address = pos.get('position_address', pos.get('nft_id', pos.get('address', ''))) + connector = pos.get('connector', 'meteora') + network = pos.get('network', 'solana-mainnet-beta') + + result = await client.gateway_clmm.collect_fees( + connector=connector, + network=network, + position_address=position_address + ) + + if result: + tx_hash = (result.get('tx_hash', '') or 'N/A')[:16] + await query.message.reply_text( + f"✅ *Fees collected*\nTx: `{escape_markdown_v2(tx_hash)}...`", + parse_mode="MarkdownV2" + ) + else: + await query.message.reply_text("❌ Failed: No response from gateway") + + except Exception as e: + logger.error(f"Failed to collect fees: {e}") + await query.message.reply_text(f"❌ Error: {str(e)[:100]}") + + +# ============================================ +# REBALANCE HANDLERS +# ============================================ + +async def handle_lpm_rebalance( + update: Update, + context: ContextTypes.DEFAULT_TYPE, + cache_key: str +) -> None: + """Start rebalance flow: show confirmation before close + reopen.""" + query = update.callback_query + positions_cache = context.user_data.get("positions_cache", {}) + token_cache = context.user_data.get("token_cache", {}) + + pos = positions_cache.get(cache_key) + if not pos: + await query.answer("Position not found") + return + + # Store position info for rebalance flow + context.user_data["rebalance_position"] = pos + context.user_data["rebalance_cache_key"] = cache_key + + # Get position details for confirmation + _, _, pair = _get_position_tokens(pos, token_cache) + lower = pos.get('lower_price', pos.get('price_lower', 0)) + upper = pos.get('upper_price', pos.get('price_upper', 0)) + + text = ( + f"🔄 *Rebalance Position*\n" + f"━━━━━━━━━━━━━━━━━━━━━\n" + f"*{escape_markdown_v2(pair)}*\n\n" + f"This will:\n" + f"1\\. Close the current position\n" + f"2\\. Open a new position with the same range\n" + f" \\({escape_markdown_v2(str(lower))} \\- {escape_markdown_v2(str(upper))}\\)\n" + f"3\\. Use Bid\\-Ask strategy \\(type 2\\)\n\n" + f"⚠️ *Are you sure?*" + ) + + keyboard = [ + [ + InlineKeyboardButton("✅ Confirm Rebalance", callback_data=f"dex:lpm_rebalance_confirm:{cache_key}"), + InlineKeyboardButton("❌ Cancel", callback_data=f"dex:lpm_dismiss:{cache_key.split('_')[1]}"), + ] + ] + + try: + await query.message.edit_text(text, parse_mode="MarkdownV2", reply_markup=InlineKeyboardMarkup(keyboard)) + except Exception as e: + logger.warning(f"Failed to show rebalance confirmation: {e}") + + +async def handle_lpm_rebalance_execute( + update: Update, + context: ContextTypes.DEFAULT_TYPE, + cache_key: str +) -> None: + """Execute the rebalance: close position and open new one with same range.""" + query = update.callback_query + chat_id = update.effective_chat.id + positions_cache = context.user_data.get("positions_cache", {}) + token_cache = context.user_data.get("token_cache", {}) + + pos = positions_cache.get(cache_key) + if not pos: + await query.answer("Position not found") + return + + await query.answer("Rebalancing position...") + + # Update message to show progress + await query.message.edit_text( + "🔄 *Rebalancing\\.\\.\\.*\n\nStep 1/3: Closing position\\.\\.\\.", + parse_mode="MarkdownV2" + ) + + try: + from config_manager import get_client + client = await get_client(chat_id, context=context) + if not client or not hasattr(client, 'gateway_clmm'): + await query.message.edit_text("❌ Gateway not available") + return + + # Get position details + position_address = pos.get('position_address', pos.get('nft_id', pos.get('address', ''))) + connector = pos.get('connector', 'meteora') + network = pos.get('network', 'solana-mainnet-beta') + pool_address = pos.get('pool_id', pos.get('pool_address', '')) + lower_price = pos.get('lower_price', pos.get('price_lower', 0)) + upper_price = pos.get('upper_price', pos.get('price_upper', 0)) + + # Step 1: Close the position + close_result = await client.gateway_clmm.close_position( + connector=connector, + network=network, + position_address=position_address + ) + + if not close_result: + await query.message.edit_text("❌ Failed to close position: No response from gateway") + return + + logger.info(f"Close position result: {close_result}") + + # Extract tx hash from various possible field names + close_tx = None + if isinstance(close_result, dict): + close_tx = close_result.get('tx_hash') or close_result.get('txHash') or \ + close_result.get('signature') or close_result.get('txSignature') + close_tx_display = f"`{escape_markdown_v2(close_tx[:20])}...`" if close_tx else "_pending_" + + # Update progress + await query.message.edit_text( + f"🔄 *Rebalancing\\.\\.\\.*\n\n" + f"✅ Step 1/3: Position closed\n" + f" Tx: {close_tx_display}\n\n" + f"Step 2/3: Getting withdrawn amounts\\.\\.\\.", + parse_mode="MarkdownV2" + ) + + # Get the withdrawn amounts from the close result + base_withdrawn = close_result.get('base_amount', close_result.get('amount_base', 0)) + quote_withdrawn = close_result.get('quote_amount', close_result.get('amount_quote', 0)) + + # Fallback to original position amounts if not in close result + if not base_withdrawn: + base_withdrawn = pos.get('base_token_amount', pos.get('amount_a', 0)) + if not quote_withdrawn: + quote_withdrawn = pos.get('quote_token_amount', pos.get('amount_b', 0)) + + # Update progress + await query.message.edit_text( + f"🔄 *Rebalancing\\.\\.\\.*\n\n" + f"✅ Step 1/3: Position closed\n" + f"✅ Step 2/3: Amounts ready\n\n" + f"Step 3/3: Opening new position\\.\\.\\.", + parse_mode="MarkdownV2" + ) + + # Step 3: Open new position with same range using bid-ask strategy (type 2) + extra_params = {"strategyType": 2} # Bid-Ask strategy + + open_result = await client.gateway_clmm.open_position( + connector=connector, + network=network, + pool_address=pool_address, + lower_price=Decimal(str(lower_price)), + upper_price=Decimal(str(upper_price)), + base_token_amount=float(base_withdrawn) if base_withdrawn else 0, + quote_token_amount=float(quote_withdrawn) if quote_withdrawn else 0, + extra_params=extra_params + ) + + if not open_result: + await query.message.edit_text( + f"⚠️ *Partial Rebalance*\n\n" + f"✅ Position closed\n" + f"❌ Failed to open new position: No response from gateway\n\n" + f"Your funds are in your wallet\\.", + parse_mode="MarkdownV2" + ) + return + + logger.info(f"Open position result: {open_result}") + + # Extract tx hash + open_tx = None + if isinstance(open_result, dict): + open_tx = open_result.get('tx_hash') or open_result.get('txHash') or \ + open_result.get('signature') or open_result.get('txSignature') + open_tx_display = f"`{escape_markdown_v2(open_tx[:20])}...`" if open_tx else "_pending_" + + # Get token symbols for display + _, quote_symbol, pair = _get_position_tokens(pos, token_cache) + + # Format price range for display + try: + lower_f = float(lower_price) + upper_f = float(upper_price) + decimals = 2 if lower_f >= 1 else 6 if lower_f >= 0.001 else 8 + lower_esc = escape_markdown_v2(f"{lower_f:.{decimals}f}") + upper_esc = escape_markdown_v2(f"{upper_f:.{decimals}f}") + range_display = f"{lower_esc} \\- {upper_esc}" + except (ValueError, TypeError): + range_display = f"{escape_markdown_v2(str(lower_price))} \\- {escape_markdown_v2(str(upper_price))}" + + # Build success message + lines = [ + f"✅ *Rebalance Complete*", + f"━━━━━━━━━━━━━━━━━━━━━", + f"*{escape_markdown_v2(pair)}*", + "", + f"✅ Old position closed", + f"✅ New position opened", + "", + f"Range: {range_display}", + f"Strategy: Bid\\-Ask", + ] + + if close_tx or open_tx: + lines.append("") + if close_tx: + lines.append(f"Close Tx: {close_tx_display}") + if open_tx: + lines.append(f"Open Tx: {open_tx_display}") + + await query.message.edit_text( + "\n".join(lines), + parse_mode="MarkdownV2" + ) + + except Exception as e: + logger.error(f"Failed to rebalance position: {e}", exc_info=True) + await query.message.edit_text(f"❌ Error: {str(e)[:200]}") + + +# ============================================ +# SKIP AND DISMISS HANDLERS +# ============================================ + +async def handle_lpm_skip( + update: Update, + context: ContextTypes.DEFAULT_TYPE, + cache_key: str +) -> None: + """Skip a position alert (remove from cache and dismiss).""" + query = update.callback_query + await query.answer("Skipped") + + # Remove position from cache + positions_cache = context.user_data.get("positions_cache", {}) + if cache_key in positions_cache: + del positions_cache[cache_key] + + try: + await query.message.edit_text( + "⏭ _Position skipped_", + parse_mode="MarkdownV2" + ) + except Exception: + pass + + +async def handle_lpm_dismiss(update: Update, context: ContextTypes.DEFAULT_TYPE) -> None: + """Dismiss the LP monitor alert message.""" + query = update.callback_query + await query.answer("Dismissed") + + try: + await query.message.delete() + except Exception: + try: + await query.message.edit_text( + "✅ _Alert dismissed_", + parse_mode="MarkdownV2" + ) + except Exception: + pass + + +# ============================================ +# COUNTDOWN HANDLERS +# ============================================ + +async def handle_lpm_cancel_countdown( + update: Update, + context: ContextTypes.DEFAULT_TYPE, + instance_id: str, + pos_id: str +) -> None: + """Cancel an active auto-close countdown.""" + query = update.callback_query + await query.answer("Countdown cancelled") + + # Signal cancellation via user_data + # The countdown task will check this flag and abort + cancel_key = f"lpm_countdown_{instance_id}_{pos_id}" + context.user_data[cancel_key] = "cancelled" + + try: + await query.message.edit_text( + "⏹ *Auto\\-close cancelled*\n\nPosition will remain open\\.", + parse_mode="MarkdownV2" + ) + except Exception as e: + logger.warning(f"Could not update countdown message: {e}") diff --git a/handlers/dex/menu.py b/handlers/dex/menu.py index 97c2e6b..c4cbcd5 100644 --- a/handlers/dex/menu.py +++ b/handlers/dex/menu.py @@ -13,7 +13,7 @@ from utils.telegram_formatters import escape_markdown_v2, resolve_token_symbol, KNOWN_TOKENS from handlers.config.user_preferences import get_dex_last_swap, get_all_enabled_networks -from servers import get_client +from config_manager import get_client from ._shared import cached_call, invalidate_cache logger = logging.getLogger(__name__) @@ -455,7 +455,7 @@ async def _load_menu_data_background( gateway_data = {"balances_by_network": {}, "lp_positions": [], "total_value": 0, "token_cache": {}} try: - client = await get_client(chat_id) + client = await get_client(chat_id, context=context) # Step 2: Fetch balances first (usually fast) and update UI immediately # When refresh=True, bypass local cache and tell API to refresh from exchanges @@ -546,13 +546,13 @@ async def show_dex_menu(update: Update, context: ContextTypes.DEFAULT_TYPE, refr Args: refresh: If True, force refresh balances from exchanges (bypasses API cache) """ - from servers import server_manager + from config_manager import get_config_manager # Cancel any existing loading task first cancel_dex_loading_task(context) # Get server name for display - server_name = server_manager.default_server or "unknown" + server_name = get_config_manager().default_server or "unknown" reply_markup = _build_menu_keyboard() diff --git a/handlers/dex/pool_data.py b/handlers/dex/pool_data.py index af5b35d..fa46dd0 100644 --- a/handlers/dex/pool_data.py +++ b/handlers/dex/pool_data.py @@ -12,7 +12,7 @@ from geckoterminal_py import GeckoTerminalAsyncClient -from servers import get_client +from config_manager import get_client from ._shared import get_cached, set_cached logger = logging.getLogger(__name__) @@ -238,7 +238,7 @@ async def fetch_liquidity_bins( if cached is not None: return cached.get('bins'), cached, None - client = await get_client(chat_id) + client = await get_client(chat_id, context=context) if not client: return None, None, "Gateway client not available" diff --git a/handlers/dex/pools.py b/handlers/dex/pools.py index 9a13915..b0da627 100644 --- a/handlers/dex/pools.py +++ b/handlers/dex/pools.py @@ -15,8 +15,8 @@ from utils.telegram_formatters import escape_markdown_v2, format_error_message, resolve_token_symbol, format_amount, KNOWN_TOKENS from handlers.config.user_preferences import set_dex_last_pool, get_dex_last_pool -from servers import get_client -from ._shared import get_cached, set_cached, cached_call, DEFAULT_CACHE_TTL, invalidate_cache +from config_manager import get_client +from ._shared import get_cached, set_cached, DEFAULT_CACHE_TTL, invalidate_cache from .visualizations import generate_liquidity_chart, generate_ohlcv_chart, generate_combined_chart, generate_aggregated_liquidity_chart from .pool_data import fetch_ohlcv, fetch_liquidity_bins, get_gecko_network @@ -41,7 +41,7 @@ async def get_token_cache_from_gateway(network: str = "solana-mainnet-beta", cha token_cache = dict(KNOWN_TOKENS) # Start with known tokens try: - client = await get_client(chat_id) + client = await get_client(chat_id, context=context) # Try to get tokens from Gateway if hasattr(client, 'gateway'): @@ -213,7 +213,7 @@ async def process_pool_info( raise ValueError(f"Unsupported connector '{connector}'. Use: {', '.join(supported_connectors)}") chat_id = update.effective_chat.id - client = await get_client(chat_id) + client = await get_client(chat_id, context=context) if not hasattr(client, 'gateway_clmm'): raise ValueError("Gateway CLMM not available") @@ -628,7 +628,7 @@ async def process_pool_list( if not chat_id: chat_id = update.effective_chat.id - client = await get_client(chat_id) + client = await get_client(chat_id, context=context) if not hasattr(client, 'gateway_clmm'): raise ValueError("Gateway CLMM not available") @@ -769,7 +769,7 @@ async def handle_plot_liquidity( try: chat_id = update.effective_chat.id - client = await get_client(chat_id) + client = await get_client(chat_id, context=context) # Fetch all pool infos in parallel with individual timeouts POOL_FETCH_TIMEOUT = 10 # seconds per pool @@ -988,7 +988,7 @@ async def _show_pool_detail( async def fetch_pool_info_task(): if pool_info is not None: return pool_info - client = await get_client(chat_id) + client = await get_client(chat_id, context=context) return await _fetch_pool_info(client, pool_address, connector) async def fetch_token_cache_task(): @@ -1175,7 +1175,7 @@ async def fetch_ohlcv_task(): balance_cache_key = f"token_balances_{network}_{base_symbol}_{quote_symbol}" balances = get_cached(context.user_data, balance_cache_key, ttl=DEFAULT_CACHE_TTL) if balances is None: - client = await get_client(chat_id) + client = await get_client(chat_id, context=context) balances = await _fetch_token_balances(client, network, base_symbol, quote_symbol) set_cached(context.user_data, balance_cache_key, balances) context.user_data["token_balances"] = balances @@ -1538,6 +1538,10 @@ async def handle_pool_detail_refresh(update: Update, context: ContextTypes.DEFAU await query.answer("Refreshing...") timeframe = "1h" # Default timeframe else: + # Store current timeframe in add_position_params to persist across param changes + if "add_position_params" not in context.user_data: + context.user_data["add_position_params"] = {} + context.user_data["add_position_params"]["timeframe"] = timeframe # Timeframe switch - show loading transition await query.answer(f"Loading {timeframe} candles...") @@ -1575,7 +1579,7 @@ async def handle_add_to_gateway(update: Update, context: ContextTypes.DEFAULT_TY to the Gateway configuration for the network. """ from geckoterminal_py import GeckoTerminalAsyncClient - from servers import server_manager + from config_manager import get_config_manager query = update.callback_query @@ -1630,7 +1634,7 @@ async def add_token_to_gateway(token_address: str) -> bool: name = attrs.get('name') # Add to gateway - client = await server_manager.get_default_client() + client = await get_config_manager().get_default_client() await client.gateway.add_token( network_id=network_id, address=token_address, @@ -1750,7 +1754,6 @@ async def handle_pool_ohlcv(update: Update, context: ContextTypes.DEFAULT_TYPE, timeframe: OHLCV timeframe (1m, 5m, 15m, 1h, 4h, 1d) currency: Price currency - "usd" or "token" (quote token) """ - from io import BytesIO from telegram import InputMediaPhoto query = update.callback_query @@ -1899,7 +1902,6 @@ async def handle_pool_combined_chart(update: Update, context: ContextTypes.DEFAU timeframe: OHLCV timeframe currency: Price currency - "usd" or "token" (quote token) """ - from io import BytesIO query = update.callback_query await query.answer("Loading combined chart...") @@ -2322,7 +2324,7 @@ async def handle_manage_positions(update: Update, context: ContextTypes.DEFAULT_ """Display manage positions menu with all active LP positions""" try: chat_id = update.effective_chat.id - client = await get_client(chat_id) + client = await get_client(chat_id, context=context) if not hasattr(client, 'gateway_clmm'): raise ValueError("Gateway CLMM not available") @@ -2535,7 +2537,7 @@ async def fetch_pool_info_task(): cached = get_cached(context.user_data, cache_key, ttl=DEFAULT_CACHE_TTL) if cached: return cached - client = await get_client(chat_id) + client = await get_client(chat_id, context=context) info = await _fetch_pool_info(client, pool_address, connector) if info: set_cached(context.user_data, cache_key, info) @@ -2788,7 +2790,7 @@ async def handle_pos_collect_fees(update: Update, context: ContextTypes.DEFAULT_ ) chat_id = update.effective_chat.id - client = await get_client(chat_id) + client = await get_client(chat_id, context=context) if not hasattr(client, 'gateway_clmm'): raise ValueError("Gateway CLMM not available") @@ -2932,7 +2934,7 @@ async def handle_pos_close_execute(update: Update, context: ContextTypes.DEFAULT ) chat_id = update.effective_chat.id - client = await get_client(chat_id) + client = await get_client(chat_id, context=context) if not hasattr(client, 'gateway_clmm'): raise ValueError("Gateway CLMM not available") @@ -2950,9 +2952,13 @@ async def handle_pos_close_execute(update: Update, context: ContextTypes.DEFAULT ) if result: - # Clear the positions cache to force fresh fetch - context.user_data.pop("positions_cache", None) + # Remove this specific position from cache, but keep others (for LP monitor alerts) + positions_cache = context.user_data.get("positions_cache", {}) + if pos_index in positions_cache: + del positions_cache[pos_index] + # Clear the full position list cache to force refresh on next view context.user_data.pop("all_positions", None) + context.user_data.pop("lp_positions_cache", None) # Also clear LP list cache pair = pos.get('trading_pair', 'Unknown') success_msg = escape_markdown_v2(f"✅ Position closed: {pair}") @@ -3019,7 +3025,7 @@ async def process_position_list( pool_address = parts[2] chat_id = update.effective_chat.id - client = await get_client(chat_id) + client = await get_client(chat_id, context=context) if not hasattr(client, 'gateway_clmm'): raise ValueError("Gateway CLMM not available") @@ -3609,7 +3615,7 @@ async def show_add_position_menu( balances = get_cached(context.user_data, balance_cache_key, ttl=DEFAULT_CACHE_TTL) if balances is None: chat_id = update.effective_chat.id - client = await get_client(chat_id) + client = await get_client(chat_id, context=context) balances = await _fetch_token_balances(client, network, base_symbol, quote_symbol) set_cached(context.user_data, balance_cache_key, balances) @@ -3993,7 +3999,7 @@ async def handle_pos_refresh(update: Update, context: ContextTypes.DEFAULT_TYPE, # Refetch pool info if pool_address: chat_id = update.effective_chat.id - client = await get_client(chat_id) + client = await get_client(chat_id, context=context) pool_info = await _fetch_pool_info(client, pool_address, connector) set_cached(context.user_data, pool_cache_key, pool_info) context.user_data["selected_pool_info"] = pool_info @@ -4324,7 +4330,7 @@ async def handle_pos_add_confirm(update: Update, context: ContextTypes.DEFAULT_T pass chat_id = update.effective_chat.id - client = await get_client(chat_id) + client = await get_client(chat_id, context=context) if not hasattr(client, 'gateway_clmm'): raise ValueError("Gateway CLMM not available") @@ -4599,7 +4605,7 @@ async def process_add_position( network = params.get("network", "solana-mainnet-beta") chat_id = update.effective_chat.id - client = await get_client(chat_id) + client = await get_client(chat_id, context=context) if not hasattr(client, 'gateway_clmm'): raise ValueError("Gateway CLMM not available") @@ -4704,13 +4710,26 @@ async def process_pos_set_lower( params["lower_price"] = user_input.strip() context.user_data["dex_state"] = "add_position" - success_msg = escape_markdown_v2(f"✅ Lower price set to: {user_input}") - await update.message.reply_text(success_msg, parse_mode="MarkdownV2") - # Refresh pool detail to show updated chart with range lines selected_pool = context.user_data.get("selected_pool", {}) if selected_pool: - await _show_pool_detail(update, context, selected_pool, from_callback=False) + # Show loading message + loading_msg = await update.message.reply_text( + escape_markdown_v2(f"✅ Lower set to: {user_input} - Updating chart..."), + parse_mode="MarkdownV2" + ) + timeframe = params.get("timeframe", "1h") + await _show_pool_detail(update, context, selected_pool, from_callback=False, timeframe=timeframe) + # Delete loading message after chart is shown + try: + await loading_msg.delete() + except Exception: + pass + else: + await update.message.reply_text( + escape_markdown_v2(f"✅ Lower price set to: {user_input}"), + parse_mode="MarkdownV2" + ) async def process_pos_set_upper( @@ -4723,13 +4742,26 @@ async def process_pos_set_upper( params["upper_price"] = user_input.strip() context.user_data["dex_state"] = "add_position" - success_msg = escape_markdown_v2(f"✅ Upper price set to: {user_input}") - await update.message.reply_text(success_msg, parse_mode="MarkdownV2") - # Refresh pool detail to show updated chart with range lines selected_pool = context.user_data.get("selected_pool", {}) if selected_pool: - await _show_pool_detail(update, context, selected_pool, from_callback=False) + # Show loading message + loading_msg = await update.message.reply_text( + escape_markdown_v2(f"✅ Upper set to: {user_input} - Updating chart..."), + parse_mode="MarkdownV2" + ) + timeframe = params.get("timeframe", "1h") + await _show_pool_detail(update, context, selected_pool, from_callback=False, timeframe=timeframe) + # Delete loading message after chart is shown + try: + await loading_msg.delete() + except Exception: + pass + else: + await update.message.reply_text( + escape_markdown_v2(f"✅ Upper price set to: {user_input}"), + parse_mode="MarkdownV2" + ) async def process_pos_set_base( @@ -4742,13 +4774,26 @@ async def process_pos_set_base( params["amount_base"] = user_input.strip() context.user_data["dex_state"] = "add_position" - success_msg = escape_markdown_v2(f"✅ Base amount set to: {user_input}") - await update.message.reply_text(success_msg, parse_mode="MarkdownV2") - # Refresh pool detail view selected_pool = context.user_data.get("selected_pool", {}) if selected_pool: - await _show_pool_detail(update, context, selected_pool, from_callback=False) + # Show loading message + loading_msg = await update.message.reply_text( + escape_markdown_v2(f"✅ Base set to: {user_input} - Updating chart..."), + parse_mode="MarkdownV2" + ) + timeframe = params.get("timeframe", "1h") + await _show_pool_detail(update, context, selected_pool, from_callback=False, timeframe=timeframe) + # Delete loading message after chart is shown + try: + await loading_msg.delete() + except Exception: + pass + else: + await update.message.reply_text( + escape_markdown_v2(f"✅ Base amount set to: {user_input}"), + parse_mode="MarkdownV2" + ) async def process_pos_set_quote( @@ -4761,10 +4806,23 @@ async def process_pos_set_quote( params["amount_quote"] = user_input.strip() context.user_data["dex_state"] = "add_position" - success_msg = escape_markdown_v2(f"✅ Quote amount set to: {user_input}") - await update.message.reply_text(success_msg, parse_mode="MarkdownV2") - # Refresh pool detail view selected_pool = context.user_data.get("selected_pool", {}) if selected_pool: - await _show_pool_detail(update, context, selected_pool, from_callback=False) + # Show loading message + loading_msg = await update.message.reply_text( + escape_markdown_v2(f"✅ Quote set to: {user_input} - Updating chart..."), + parse_mode="MarkdownV2" + ) + timeframe = params.get("timeframe", "1h") + await _show_pool_detail(update, context, selected_pool, from_callback=False, timeframe=timeframe) + # Delete loading message after chart is shown + try: + await loading_msg.delete() + except Exception: + pass + else: + await update.message.reply_text( + escape_markdown_v2(f"✅ Quote amount set to: {user_input}"), + parse_mode="MarkdownV2" + ) diff --git a/handlers/dex/router.py b/handlers/dex/router.py new file mode 100644 index 0000000..f3f9754 --- /dev/null +++ b/handlers/dex/router.py @@ -0,0 +1,741 @@ +""" +DEX Callback and Message Router + +Central routing for all DEX-related callback queries and text messages. +Dispatches to appropriate handlers based on action patterns. +""" + +import logging +from typing import Callable, Awaitable + +from telegram import Update +from telegram.ext import ContextTypes, CallbackQueryHandler, MessageHandler, filters + +from utils.auth import restricted +from utils.telegram_formatters import format_error_message + +from .menu import cancel_dex_loading_task + +# Import all handler modules +from .swap import ( + handle_swap, + handle_swap_refresh, + handle_swap_toggle_side, + handle_swap_set_connector, + handle_swap_connector_select, + handle_swap_set_network, + handle_swap_network_select, + handle_swap_set_pair, + handle_swap_set_amount, + handle_swap_set_slippage, + handle_swap_get_quote, + handle_swap_execute_confirm, + handle_swap_history, + handle_swap_status, + handle_swap_hist_filter_pair, + handle_swap_hist_filter_connector, + handle_swap_hist_filter_status, + handle_swap_hist_set_filter, + handle_swap_hist_page, + handle_swap_hist_clear, + process_swap, + process_swap_set_pair, + process_swap_set_amount, + process_swap_set_slippage, + process_swap_status, +) +from .pools import ( + handle_pool_info, + handle_pool_list, + handle_pool_select, + handle_pool_list_back, + handle_pool_detail_refresh, + handle_add_to_gateway, + handle_plot_liquidity, + handle_pool_ohlcv, + handle_pool_combined_chart, + handle_manage_positions, + handle_pos_view, + handle_pos_view_pool, + handle_pos_collect_fees, + handle_pos_close_confirm, + handle_pos_close_execute, + handle_position_list, + handle_add_position, + show_add_position_menu, + handle_pos_set_connector, + handle_pos_set_network, + handle_pos_set_pool, + handle_pos_set_lower, + handle_pos_set_upper, + handle_pos_set_base, + handle_pos_set_quote, + handle_pos_add_confirm, + handle_pos_use_max_range, + handle_pos_help, + handle_pos_toggle_strategy, + handle_pos_refresh, + process_pool_info, + process_pool_list, + process_position_list, + process_add_position, + process_pos_set_connector, + process_pos_set_network, + process_pos_set_pool, + process_pos_set_lower, + process_pos_set_upper, + process_pos_set_base, + process_pos_set_quote, +) +from .geckoterminal import ( + show_gecko_explore_menu, + handle_gecko_toggle_view, + handle_gecko_select_network, + handle_gecko_set_network, + handle_gecko_show_pools, + handle_gecko_refresh, + handle_gecko_trending, + show_trending_pools, + handle_gecko_top, + show_top_pools, + handle_gecko_new, + show_new_pools, + handle_gecko_networks, + show_network_menu, + handle_gecko_search, + handle_gecko_search_network, + handle_gecko_search_set_network, + process_gecko_search, + show_pool_detail, + show_gecko_charts_menu, + show_ohlcv_chart, + show_recent_trades, + show_gecko_liquidity, + show_gecko_combined, + handle_copy_address, + handle_gecko_token_info, + handle_gecko_token_search, + handle_gecko_token_add, + handle_back_to_list, + handle_gecko_add_liquidity, + handle_gecko_swap, + show_gecko_info, + handle_gecko_pool_tf, + handle_gecko_add_tokens, + handle_gecko_restart_gateway, +) +from .liquidity import ( + handle_liquidity, + handle_lp_refresh, + handle_lp_pos_view, + handle_lp_collect_all, + handle_lp_history, + handle_lp_hist_filter_pair, + handle_lp_hist_filter_connector, + handle_lp_hist_filter_status, + handle_lp_hist_set_filter, + handle_lp_hist_page, + handle_lp_hist_clear, +) +from .lp_monitor_handlers import ( + handle_lpm_navigation, + handle_lpm_detail, + handle_lpm_collect_fees, + handle_lpm_rebalance, + handle_lpm_rebalance_execute, + handle_lpm_oor_navigation, + handle_lpm_skip, + handle_lpm_dismiss, + handle_lpm_cancel_countdown, +) +from .menu import handle_close, handle_refresh + +logger = logging.getLogger(__name__) + +# Type alias for handler functions +HandlerFunc = Callable[[Update, ContextTypes.DEFAULT_TYPE], Awaitable[None]] + + +# ============================================ +# SLOW ACTIONS (require typing indicator) +# ============================================ + +SLOW_ACTIONS = frozenset({ + "main_menu", "swap", "swap_refresh", "swap_get_quote", "swap_execute_confirm", "swap_history", + "swap_hist_clear", "swap_hist_filter_pair", "swap_hist_filter_connector", "swap_hist_filter_status", + "swap_hist_page_prev", "swap_hist_page_next", + "liquidity", "lp_refresh", "lp_history", "lp_collect_all", + "lp_hist_clear", "lp_hist_filter_pair", "lp_hist_filter_connector", "lp_hist_filter_status", + "lp_hist_page_prev", "lp_hist_page_next", + "pool_info", "pool_list", "manage_positions", "pos_add_confirm", "pos_close_exec", + "add_to_gateway", "pool_detail_refresh", + "gecko_networks", "gecko_trades", "gecko_show_pools", "gecko_refresh", "gecko_token_search", "gecko_token_add", + "gecko_explore", "gecko_swap", "gecko_info", "gecko_add_tokens", "gecko_restart_gateway" +}) + +SLOW_PREFIXES = ( + "gecko_trending_", "gecko_top_", "gecko_new_", "gecko_pool:", "gecko_ohlcv:", + "gecko_pool_tf:", "gecko_token:", "swap_hist_set_", "lp_hist_set_", + "lpm_nav:", "lpm_cancel_countdown:", "pos_close:" +) + + +def _is_slow_action(action: str) -> bool: + """Check if an action requires a typing indicator.""" + return action in SLOW_ACTIONS or action.startswith(SLOW_PREFIXES) + + +# ============================================ +# CALLBACK ACTION DISPATCH TABLE +# ============================================ + +# Simple action -> handler mapping (no parameters) +SIMPLE_ACTIONS: dict[str, HandlerFunc] = { + # Menu + "main_menu": handle_swap, + "close": handle_close, + "refresh": handle_refresh, + "noop": lambda u, c: None, # No-op for page indicators + "lpm_noop": lambda u, c: None, + + # Swap + "swap": handle_swap, + "swap_refresh": handle_swap_refresh, + "swap_toggle_side": handle_swap_toggle_side, + "swap_set_connector": handle_swap_set_connector, + "swap_set_network": handle_swap_set_network, + "swap_set_pair": handle_swap_set_pair, + "swap_set_amount": handle_swap_set_amount, + "swap_set_slippage": handle_swap_set_slippage, + "swap_get_quote": handle_swap_get_quote, + "swap_execute_confirm": handle_swap_execute_confirm, + "swap_history": handle_swap_history, + "swap_status": handle_swap_status, + "swap_hist_filter_pair": handle_swap_hist_filter_pair, + "swap_hist_filter_connector": handle_swap_hist_filter_connector, + "swap_hist_filter_status": handle_swap_hist_filter_status, + "swap_hist_clear": handle_swap_hist_clear, + + # Legacy swap redirects + "swap_quote": handle_swap, + "swap_execute": handle_swap, + "swap_search": handle_swap_history, + + # Liquidity + "liquidity": handle_liquidity, + "lp_refresh": handle_lp_refresh, + "lp_collect_all": handle_lp_collect_all, + "lp_history": handle_lp_history, + "lp_hist_filter_pair": handle_lp_hist_filter_pair, + "lp_hist_filter_connector": handle_lp_hist_filter_connector, + "lp_hist_filter_status": handle_lp_hist_filter_status, + "lp_hist_clear": handle_lp_hist_clear, + + # Legacy liquidity redirects + "explore_pools": handle_lp_refresh, + + # Pool + "pool_info": handle_pool_info, + "pool_list": handle_pool_list, + "pool_list_back": handle_pool_list_back, + "pool_detail_refresh": handle_pool_detail_refresh, + "add_to_gateway": handle_add_to_gateway, + + # Position management + "manage_positions": handle_manage_positions, + "position_list": handle_position_list, + "add_position": handle_add_position, + "pos_set_connector": handle_pos_set_connector, + "pos_set_network": handle_pos_set_network, + "pos_set_pool": handle_pos_set_pool, + "pos_set_lower": handle_pos_set_lower, + "pos_set_upper": handle_pos_set_upper, + "pos_set_base": handle_pos_set_base, + "pos_set_quote": handle_pos_set_quote, + "pos_add_confirm": handle_pos_add_confirm, + "pos_use_max_range": handle_pos_use_max_range, + "pos_help": handle_pos_help, + "pos_toggle_strategy": handle_pos_toggle_strategy, + "pos_refresh": handle_pos_refresh, + + # GeckoTerminal + "gecko_explore": show_gecko_explore_menu, + "gecko_toggle_view": handle_gecko_toggle_view, + "gecko_select_network": handle_gecko_select_network, + "gecko_show_pools": handle_gecko_show_pools, + "gecko_refresh": handle_gecko_refresh, + "gecko_trending": handle_gecko_trending, + "gecko_top": handle_gecko_top, + "gecko_new": handle_gecko_new, + "gecko_networks": handle_gecko_networks, + "gecko_search": handle_gecko_search, + "gecko_search_network": handle_gecko_search_network, + "gecko_charts": show_gecko_charts_menu, + "gecko_add_liquidity": handle_gecko_add_liquidity, + "gecko_token_search": handle_gecko_token_search, + "gecko_token_add": handle_gecko_token_add, + "gecko_swap": handle_gecko_swap, + "gecko_info": show_gecko_info, + "gecko_liquidity": show_gecko_liquidity, + "gecko_trades": show_recent_trades, + "gecko_copy_addr": handle_copy_address, + "gecko_back_to_list": handle_back_to_list, + "gecko_add_tokens": handle_gecko_add_tokens, + "gecko_restart_gateway": handle_gecko_restart_gateway, +} + + +# ============================================ +# CEX SWITCH HANDLER +# ============================================ + +async def _handle_switch_to_cex( + update: Update, + context: ContextTypes.DEFAULT_TYPE, + connector_name: str +) -> None: + """Switch from DEX to CEX trading""" + from handlers.config.user_preferences import ( + set_last_trade_connector, + get_clob_order_defaults, + ) + from handlers.cex.trade import handle_trade as cex_handle_trade + + # Clear DEX state + context.user_data.pop("dex_state", None) + context.user_data.pop("swap_params", None) + + # Save preference and set up CEX trade + set_last_trade_connector(context.user_data, "cex", connector_name) + defaults = get_clob_order_defaults(context.user_data) + defaults["connector"] = connector_name + context.user_data["trade_params"] = defaults + + # Route to CEX trade menu + await cex_handle_trade(update, context) + + +# ============================================ +# PARAMETERIZED ACTION HANDLERS +# ============================================ + +async def _handle_parameterized_action( + update: Update, + context: ContextTypes.DEFAULT_TYPE, + action: str +) -> bool: + """ + Handle actions that require parameter extraction. + Returns True if action was handled, False otherwise. + """ + query = update.callback_query + + # Swap connector selection: swap_connector_{name} + if action.startswith("swap_connector_"): + connector_name = action.replace("swap_connector_", "") + await handle_swap_connector_select(update, context, connector_name) + return True + + # Swap network selection: swap_network_{id} + if action.startswith("swap_network_"): + network_id = action.replace("swap_network_", "") + await handle_swap_network_select(update, context, network_id) + return True + + # Switch to CEX: switch_cex_{connector} + if action.startswith("switch_cex_"): + connector_name = action.replace("switch_cex_", "") + await _handle_switch_to_cex(update, context, connector_name) + return True + + # Swap history filters + if action.startswith("swap_hist_set_pair_"): + value = action.replace("swap_hist_set_pair_", "") + await handle_swap_hist_set_filter(update, context, "pair", value) + return True + if action.startswith("swap_hist_set_connector_"): + value = action.replace("swap_hist_set_connector_", "") + await handle_swap_hist_set_filter(update, context, "connector", value) + return True + if action.startswith("swap_hist_set_status_"): + value = action.replace("swap_hist_set_status_", "") + await handle_swap_hist_set_filter(update, context, "status", value) + return True + + # Swap history pagination + if action == "swap_hist_page_prev": + await handle_swap_hist_page(update, context, "prev") + return True + if action == "swap_hist_page_next": + await handle_swap_hist_page(update, context, "next") + return True + + # LP history filters + if action.startswith("lp_hist_set_pair_"): + value = action.replace("lp_hist_set_pair_", "") + await handle_lp_hist_set_filter(update, context, "pair", value) + return True + if action.startswith("lp_hist_set_connector_"): + value = action.replace("lp_hist_set_connector_", "") + await handle_lp_hist_set_filter(update, context, "connector", value) + return True + if action.startswith("lp_hist_set_status_"): + value = action.replace("lp_hist_set_status_", "") + await handle_lp_hist_set_filter(update, context, "status", value) + return True + + # LP history pagination + if action == "lp_hist_page_prev": + await handle_lp_hist_page(update, context, "prev") + return True + if action == "lp_hist_page_next": + await handle_lp_hist_page(update, context, "next") + return True + + # LP position view: lp_pos_view:{index} + if action.startswith("lp_pos_view:"): + pos_index = int(action.split(":")[1]) + await handle_lp_pos_view(update, context, pos_index) + return True + + # Pool selection: pool_select:{index} + if action.startswith("pool_select:"): + pool_index = int(action.split(":")[1]) + await handle_pool_select(update, context, pool_index) + return True + + # Pool timeframe: pool_tf:{timeframe} + if action.startswith("pool_tf:"): + timeframe = action.split(":")[1] + await handle_pool_detail_refresh(update, context, timeframe=timeframe) + return True + + # Plot liquidity: plot_liquidity:{percentile} + if action.startswith("plot_liquidity:"): + percentile = int(action.split(":")[1]) + await handle_plot_liquidity(update, context, percentile) + return True + + # Position view: pos_view:{index} + if action.startswith("pos_view:"): + pos_index = action.split(":")[1] + await handle_pos_view(update, context, pos_index) + return True + + # Position view with timeframe: pos_view_tf:{index}:{timeframe} + if action.startswith("pos_view_tf:"): + parts = action.split(":") + pos_index = parts[1] + timeframe = parts[2] if len(parts) > 2 else "1h" + await handle_pos_view(update, context, pos_index, timeframe=timeframe) + return True + + # Position view pool: pos_view_pool:{index} + if action.startswith("pos_view_pool:"): + pos_index = action.split(":")[1] + await handle_pos_view_pool(update, context, pos_index) + return True + + # Position actions + if action.startswith("pos_collect:"): + pos_index = action.split(":")[1] + await handle_pos_collect_fees(update, context, pos_index) + return True + if action.startswith("pos_close:"): + pos_index = action.split(":")[1] + await handle_pos_close_confirm(update, context, pos_index) + return True + if action.startswith("pos_close_exec:"): + pos_index = action.split(":")[1] + await handle_pos_close_execute(update, context, pos_index) + return True + + # Position timeframe in add menu: pos_tf:{timeframe} + if action.startswith("pos_tf:"): + timeframe = action.split(":")[1] + await handle_pos_refresh(update, context, timeframe=timeframe) + return True + + # Copy pool address + if action.startswith("copy_pool:"): + selected_pool = context.user_data.get("selected_pool", {}) + pool_address = selected_pool.get('pool_address', selected_pool.get('address', 'N/A')) + await query.answer("Address sent below ⬇️") + await query.message.reply_text(f"`{pool_address}`", parse_mode="Markdown") + return True + + # Add position from pool + if action == "add_position_from_pool": + await query.answer("Loading position form...") + selected_pool = context.user_data.get("selected_pool", {}) + if selected_pool: + pool_address = selected_pool.get('pool_address', selected_pool.get('address', '')) + context.user_data["add_position_params"] = { + "connector": selected_pool.get('connector', 'meteora'), + "network": "solana-mainnet-beta", + "pool_address": pool_address, + "lower_price": "", + "upper_price": "", + "amount_base": "10%", + "amount_quote": "10%", + "strategy_type": "0", + } + await show_add_position_menu(update, context) + return True + + # GeckoTerminal handlers + if action.startswith("gecko_set_network:"): + network = action.split(":")[1] + await handle_gecko_set_network(update, context, network) + return True + if action.startswith("gecko_trending_"): + network = action.replace("gecko_trending_", "") + network = None if network == "all" else network + await show_trending_pools(update, context, network) + return True + if action.startswith("gecko_top_"): + network = action.replace("gecko_top_", "") + await show_top_pools(update, context, network) + return True + if action.startswith("gecko_new_"): + network = action.replace("gecko_new_", "") + network = None if network == "all" else network + await show_new_pools(update, context, network) + return True + if action.startswith("gecko_net_"): + network = action.replace("gecko_net_", "") + await show_network_menu(update, context, network) + return True + if action.startswith("gecko_search_set_net:"): + network = action.split(":")[1] + await handle_gecko_search_set_network(update, context, network) + return True + if action.startswith("gecko_pool:"): + pool_index = int(action.split(":")[1]) + await show_pool_detail(update, context, pool_index) + return True + if action.startswith("gecko_token:"): + token_type = action.split(":")[1] + await handle_gecko_token_info(update, context, token_type) + return True + if action.startswith("gecko_ohlcv:"): + timeframe = action.split(":")[1] + await show_ohlcv_chart(update, context, timeframe) + return True + if action.startswith("gecko_combined:"): + timeframe = action.split(":")[1] + await show_gecko_combined(update, context, timeframe) + return True + if action.startswith("gecko_pool_tf:"): + timeframe = action.split(":")[1] + await handle_gecko_pool_tf(update, context, timeframe) + return True + + # Pool OHLCV: pool_ohlcv:{timeframe}:{currency} + if action.startswith("pool_ohlcv:"): + parts = action.split(":") + timeframe = parts[1] + currency = parts[2] if len(parts) > 2 else "usd" + await handle_pool_ohlcv(update, context, timeframe, currency) + return True + + # Pool combined chart: pool_combined:{timeframe}:{currency} + if action.startswith("pool_combined:"): + parts = action.split(":") + timeframe = parts[1] + currency = parts[2] if len(parts) > 2 else "usd" + await handle_pool_combined_chart(update, context, timeframe, currency) + return True + + # LP Monitor handlers + if action.startswith("lpm_skip:"): + cache_key = action.split(":")[1] + await handle_lpm_skip(update, context, cache_key) + return True + if action.startswith("lpm_nav:"): + parts = action.split(":") + if len(parts) >= 3: + instance_id = parts[1] + new_index = int(parts[2]) + await handle_lpm_navigation(update, context, instance_id, new_index) + return True + if action.startswith("lpm_dismiss:"): + await handle_lpm_dismiss(update, context) + return True + if action.startswith("lpm_detail:"): + parts = action.split(":") + if len(parts) >= 3: + instance_id = parts[1] + index = int(parts[2]) + await handle_lpm_detail(update, context, instance_id, index) + return True + if action.startswith("lpm_collect:"): + cache_key = action.replace("lpm_collect:", "") + await handle_lpm_collect_fees(update, context, cache_key) + return True + if action.startswith("lpm_rebalance_confirm:"): + cache_key = action.replace("lpm_rebalance_confirm:", "") + await handle_lpm_rebalance_execute(update, context, cache_key) + return True + if action.startswith("lpm_rebalance:"): + cache_key = action.replace("lpm_rebalance:", "") + await handle_lpm_rebalance(update, context, cache_key) + return True + if action.startswith("lpm_oor:"): + parts = action.split(":") + if len(parts) >= 3: + instance_id = parts[1] + index = int(parts[2]) + await handle_lpm_oor_navigation(update, context, instance_id, index) + return True + if action.startswith("lpm_cancel_countdown:"): + parts = action.split(":") + if len(parts) >= 3: + instance_id = parts[1] + pos_id = parts[2] + await handle_lpm_cancel_countdown(update, context, instance_id, pos_id) + return True + + return False + + +# ============================================ +# MESSAGE STATE HANDLERS +# ============================================ + +# State -> processor function mapping +MESSAGE_STATE_HANDLERS: dict[str, Callable] = { + # Swap states + "swap": process_swap, + "swap_set_pair": process_swap_set_pair, + "swap_set_amount": process_swap_set_amount, + "swap_set_slippage": process_swap_set_slippage, + "swap_status": process_swap_status, + + # Pool states + "pool_info": process_pool_info, + "pool_list": process_pool_list, + "position_list": process_position_list, + + # Add position states + "add_position": process_add_position, + "pos_set_connector": process_pos_set_connector, + "pos_set_network": process_pos_set_network, + "pos_set_pool": process_pos_set_pool, + "pos_set_lower": process_pos_set_lower, + "pos_set_upper": process_pos_set_upper, + "pos_set_base": process_pos_set_base, + "pos_set_quote": process_pos_set_quote, + + # GeckoTerminal + "gecko_search": process_gecko_search, +} + +# States that should be cleared after processing +CLEAR_STATE_AFTER = frozenset({ + "swap", "swap_status", "pool_info", "pool_list", "position_list", "add_position" +}) + + +# ============================================ +# MAIN CALLBACK HANDLER +# ============================================ + +@restricted +async def dex_callback_handler(update: Update, context: ContextTypes.DEFAULT_TYPE) -> None: + """Handle inline button callbacks - Routes to appropriate sub-module.""" + query = update.callback_query + await query.answer() + + try: + # Parse action from callback data (format: dex:{action}) + callback_parts = query.data.split(":", 1) + action = callback_parts[1] if len(callback_parts) > 1 else query.data + + # Cancel any pending menu loading task when navigating away + if action != "main_menu": + cancel_dex_loading_task(context) + + # Show typing indicator for slow operations + if _is_slow_action(action): + await query.message.reply_chat_action("typing") + + # Try simple action dispatch first + if action in SIMPLE_ACTIONS: + handler = SIMPLE_ACTIONS[action] + if handler is not None: + await handler(update, context) + return + + # Try parameterized action handlers + if await _handle_parameterized_action(update, context, action): + return + + # Unknown action + await query.message.reply_text(f"Unknown action: {action}") + + except Exception as e: + # Ignore "message is not modified" errors - they're harmless + if "not modified" in str(e).lower(): + logger.debug(f"Message not modified (ignored): {e}") + return + + logger.error(f"Error in DEX callback handler: {e}", exc_info=True) + error_message = format_error_message(f"Operation failed: {str(e)}") + try: + await query.message.reply_text(error_message, parse_mode="MarkdownV2") + except Exception as reply_error: + logger.warning(f"Failed to send error message: {reply_error}") + + +# ============================================ +# MAIN MESSAGE HANDLER +# ============================================ + +@restricted +async def dex_message_handler(update: Update, context: ContextTypes.DEFAULT_TYPE) -> None: + """Handle user text input - Routes to appropriate processor.""" + dex_state = context.user_data.get("dex_state") + + if not dex_state: + return + + user_input = update.message.text.strip() + logger.info(f"DEX message handler - state: {dex_state}, input: {user_input}") + + try: + # Clear state for operations that complete + if dex_state in CLEAR_STATE_AFTER: + context.user_data.pop("dex_state", None) + + # Dispatch to appropriate handler + handler = MESSAGE_STATE_HANDLERS.get(dex_state) + if handler: + await handler(update, context, user_input) + else: + await update.message.reply_text(f"Unknown state: {dex_state}") + + except Exception as e: + logger.error(f"Error processing DEX input: {e}", exc_info=True) + error_message = format_error_message(f"Failed to process input: {str(e)}") + await update.message.reply_text(error_message, parse_mode="MarkdownV2") + + +# ============================================ +# HANDLER FACTORIES +# ============================================ + +def get_dex_callback_handler() -> CallbackQueryHandler: + """Get the callback query handler for DEX menu.""" + return CallbackQueryHandler( + dex_callback_handler, + pattern="^dex:" + ) + + +def get_dex_message_handler() -> MessageHandler: + """Returns the message handler for DEX text input.""" + return MessageHandler( + filters.TEXT & ~filters.COMMAND, + dex_message_handler + ) diff --git a/handlers/dex/swap.py b/handlers/dex/swap.py index 1120da8..2f054aa 100644 --- a/handlers/dex/swap.py +++ b/handlers/dex/swap.py @@ -17,19 +17,18 @@ from handlers.config.user_preferences import ( get_dex_swap_defaults, get_dex_connector, - get_dex_last_swap, set_dex_last_swap, get_all_enabled_networks, DEFAULT_DEX_NETWORK, + set_last_trade_connector, ) -from servers import get_client +from config_manager import get_client from ._shared import ( get_cached, set_cached, invalidate_cache, get_explorer_url, format_relative_time, - _format_amount, get_history_filters, set_history_filters, HistoryFilters, @@ -91,7 +90,14 @@ def _format_number(value, decimals: int = 2) -> str: return f"{num:.{decimals}f}" if abs(num) >= 0.01: return f"{num:.4f}" - return f"{num:.6f}" + if abs(num) >= 0.0001: + return f"{num:.6f}" + if abs(num) >= 0.000001: + return f"{num:.8f}" + if abs(num) >= 0.00000001: + return f"{num:.10f}" + # For extremely small numbers, use scientific notation + return f"{num:.2e}" except (ValueError, TypeError): return "—" @@ -266,12 +272,14 @@ async def _fetch_quotes_background( if not all([connector, network, trading_pair, amount]): return - client = await get_client(chat_id) + client = await get_client(chat_id, context=context) # Fetch balances in parallel with quotes async def fetch_balances_safe(): try: - balances = await _fetch_balances(client) + # Check if force refresh is needed (e.g., after swap execution) + force_refresh = context.user_data.pop("_force_balance_refresh", False) + balances = await _fetch_balances(client, refresh=force_refresh) if balances: set_cached(context.user_data, "gateway_balances", balances) return balances @@ -524,7 +532,7 @@ async def show_swap_menu(update: Update, context: ContextTypes.DEFAULT_TYPE, sen swaps = get_cached(context.user_data, "recent_swaps", ttl=60) if swaps is None: try: - client = await get_client(chat_id) + client = await get_client(chat_id, context=context) swaps = await _fetch_recent_swaps(client, limit=5) set_cached(context.user_data, "recent_swaps", swaps) except Exception as e: @@ -541,25 +549,50 @@ async def show_swap_menu(update: Update, context: ContextTypes.DEFAULT_TYPE, sen # Send or edit message message = None - if send_new or not update.callback_query: - message = await update.message.reply_text( + if send_new: + # Send new message - determine chat from callback query or message + if update.callback_query: + chat = update.callback_query.message.chat + else: + chat = update.message.chat + message = await chat.send_message( help_text, parse_mode="MarkdownV2", reply_markup=reply_markup, disable_web_page_preview=True ) + elif update.callback_query: + try: + await update.callback_query.message.edit_text( + help_text, + parse_mode="MarkdownV2", + reply_markup=reply_markup, + disable_web_page_preview=True + ) + except Exception as e: + # Ignore "Message is not modified" error + if "not modified" not in str(e).lower(): + raise + message = update.callback_query.message else: - await update.callback_query.message.edit_text( + message = await update.message.reply_text( help_text, parse_mode="MarkdownV2", reply_markup=reply_markup, disable_web_page_preview=True ) - message = update.callback_query.message - # Launch background quote fetch if no quote yet and auto_quote is enabled - if auto_quote and quote_result is None and message: - asyncio.create_task(_fetch_quotes_background(context, message, params, chat_id)) + # Store message for later editing (for text input processing) + if message: + context.user_data["swap_menu_message_id"] = message.message_id + context.user_data["swap_menu_chat_id"] = message.chat_id + + # Launch background fetch if needed (balances missing or quotes needed) + if message: + needs_balance_fetch = get_cached(context.user_data, "gateway_balances", ttl=120) is None + needs_quote_fetch = auto_quote and quote_result is None + if needs_balance_fetch or needs_quote_fetch: + asyncio.create_task(_fetch_quotes_background(context, message, params, chat_id)) # ============================================ @@ -573,6 +606,43 @@ def _invalidate_swap_quote(user_data: dict) -> None: del cache["swap_quote"] +async def _update_swap_menu_after_input(update: Update, context: ContextTypes.DEFAULT_TYPE) -> None: + """Update the swap menu message after text input (like trade.py pattern)""" + # Delete user's input message + try: + await update.message.delete() + except Exception: + pass + + # Edit the stored swap menu message + msg_id = context.user_data.get("swap_menu_message_id") + chat_id = context.user_data.get("swap_menu_chat_id") + + if msg_id and chat_id: + params = context.user_data.get("swap_params", {}) + quote_result = get_cached(context.user_data, "swap_quote", ttl=30) + + help_text = _build_swap_menu_text(context.user_data, params, quote_result) + keyboard = _build_swap_keyboard(params) + reply_markup = InlineKeyboardMarkup(keyboard) + + try: + await context.bot.edit_message_text( + chat_id=chat_id, + message_id=msg_id, + text=help_text, + parse_mode="MarkdownV2", + reply_markup=reply_markup, + disable_web_page_preview=True + ) + except Exception as e: + logger.debug(f"Could not update swap menu: {e}") + # Fallback: send new message + await show_swap_menu(update, context, send_new=True) + else: + await show_swap_menu(update, context, send_new=True) + + async def handle_swap_toggle_side(update: Update, context: ContextTypes.DEFAULT_TYPE) -> None: """Toggle between BUY and SELL""" params = context.user_data.get("swap_params", {}) @@ -589,7 +659,7 @@ async def handle_swap_set_connector(update: Update, context: ContextTypes.DEFAUL network = params.get("network", "solana-mainnet-beta") try: - client = await get_client(chat_id) + client = await get_client(chat_id, context=context) cache_key = "router_connectors" connectors = get_cached(context.user_data, cache_key, ttl=300) @@ -653,60 +723,84 @@ async def handle_swap_connector_select(update: Update, context: ContextTypes.DEF params["network"] = f"ethereum-{networks[0]}" break + # Save unified preference for /trade command (DEX stores network ID, not connector) + network = params.get("network", DEFAULT_DEX_NETWORK) + set_last_trade_connector(context.user_data, "dex", network) + _invalidate_swap_quote(context.user_data) context.user_data["dex_state"] = "swap" await show_swap_menu(update, context) async def handle_swap_set_network(update: Update, context: ContextTypes.DEFAULT_TYPE) -> None: - """Show available networks for selection""" + """Show available networks (DEX) and connectors (CEX) for selection""" chat_id = update.effective_chat.id try: - client = await get_client(chat_id) + client = await get_client(chat_id, context=context) - networks_cache_key = "gateway_networks" - networks = get_cached(context.user_data, networks_cache_key, ttl=300) - if networks is None: - networks = await _fetch_networks(client) - set_cached(context.user_data, networks_cache_key, networks) + # Fetch DEX networks and CEX connectors in parallel + from handlers import is_gateway_network - connectors_cache_key = "router_connectors" - connectors = get_cached(context.user_data, connectors_cache_key, ttl=300) - if connectors is None: - connectors = await _fetch_router_connectors(client) - set_cached(context.user_data, connectors_cache_key, connectors) + async def get_dex_networks(): + try: + return await _fetch_networks(client) + except Exception as e: + logger.warning(f"Could not fetch DEX networks: {e}") + return [] - # Filter to networks with routers - router_networks = set() - for c in connectors: - chain = c.get('chain', '') - for net in c.get('networks', []): - router_networks.add((chain, net)) + async def get_cex_connectors(): + try: + state = await client.portfolio.get_state() + cex = set() + for account_data in state.values(): + if isinstance(account_data, dict): + for connector_name in account_data.keys(): + if not is_gateway_network(connector_name): + cex.add(connector_name) + return sorted(cex) + except Exception as e: + logger.warning(f"Could not fetch CEX connectors: {e}") + return [] - available = [ - n for n in networks - if (n.get('chain', ''), n.get('network', '')) in router_networks - ] + import asyncio + networks, cex_connectors = await asyncio.gather(get_dex_networks(), get_cex_connectors()) - if not available: - help_text = r"🌐 *Select Network*" + "\n\n" + r"_No networks available\._" - keyboard = [[InlineKeyboardButton("« Back", callback_data="dex:swap")]] - else: - help_text = r"🌐 *Select Network*" + keyboard = [] - network_buttons = [] + # CEX section + if cex_connectors: + keyboard.append([InlineKeyboardButton("━━ CEX ━━", callback_data="dex:noop")]) row = [] - for n in available: - network_id = n.get('network_id', '') + for connector in cex_connectors: + row.append(InlineKeyboardButton(connector, callback_data=f"dex:switch_cex_{connector}")) + if len(row) == 2: + keyboard.append(row) + row = [] + if row: + keyboard.append(row) + + # DEX section + if networks: + keyboard.append([InlineKeyboardButton("━━ DEX ━━", callback_data="dex:noop")]) + row = [] + for network_item in networks: + if isinstance(network_item, dict): + network_id = network_item.get('network_id') or network_item.get('id') or str(network_item) + else: + network_id = str(network_item) row.append(InlineKeyboardButton(network_id, callback_data=f"dex:swap_network_{network_id}")) - if len(row) == 3: - network_buttons.append(row) + if len(row) == 2: + keyboard.append(row) row = [] if row: - network_buttons.append(row) + keyboard.append(row) - keyboard = network_buttons + [[InlineKeyboardButton("« Back", callback_data="dex:swap")]] + if not networks and not cex_connectors: + help_text = r"🔄 *Select Connector*" + "\n\n" + r"_No connectors available\._" + else: + help_text = r"🔄 *Select Connector*" + keyboard.append([InlineKeyboardButton("« Back", callback_data="dex:swap")]) reply_markup = InlineKeyboardMarkup(keyboard) await update.callback_query.message.edit_text( @@ -716,28 +810,43 @@ async def handle_swap_set_network(update: Update, context: ContextTypes.DEFAULT_ ) except Exception as e: - logger.error(f"Error showing networks: {e}", exc_info=True) - error_text = format_error_message(f"Error loading networks: {str(e)}") + logger.error(f"Error showing connectors: {e}", exc_info=True) + error_text = format_error_message(f"Error loading connectors: {str(e)}") await update.callback_query.message.edit_text(error_text, parse_mode="MarkdownV2") async def handle_swap_network_select(update: Update, context: ContextTypes.DEFAULT_TYPE, network_id: str) -> None: """Handle network selection from button""" + chat_id = update.effective_chat.id params = context.user_data.get("swap_params", {}) params["network"] = network_id - # Auto-update connector + # Auto-update connector based on network chain = network_id.split("-")[0] if network_id else "" - network = network_id.split("-", 1)[1] if "-" in network_id else "" + network_part = network_id.split("-", 1)[1] if "-" in network_id else "" + # Fetch connectors if not cached cache_key = "router_connectors" connectors = get_cached(context.user_data, cache_key, ttl=300) + if connectors is None: + try: + client = await get_client(chat_id, context=context) + connectors = await _fetch_router_connectors(client) + set_cached(context.user_data, cache_key, connectors) + except Exception as e: + logger.warning(f"Could not fetch connectors: {e}") + connectors = [] + + # Find matching connector for the network if connectors: for c in connectors: - if c.get('chain') == chain and network in c.get('networks', []): + if c.get('chain') == chain and network_part in c.get('networks', []): params["connector"] = c.get('name') break + # Save unified preference for /trade command (DEX stores network ID) + set_last_trade_connector(context.user_data, "dex", network_id) + _invalidate_swap_quote(context.user_data) context.user_data["dex_state"] = "swap" await show_swap_menu(update, context) @@ -759,7 +868,7 @@ async def handle_swap_set_pair(update: Update, context: ContextTypes.DEFAULT_TYP context.user_data["dex_state"] = "swap_set_pair" context.user_data["dex_previous_state"] = "swap" - await update.callback_query.message.reply_text( + await update.callback_query.message.edit_text( help_text, parse_mode="MarkdownV2", reply_markup=reply_markup @@ -782,7 +891,7 @@ async def handle_swap_set_amount(update: Update, context: ContextTypes.DEFAULT_T context.user_data["dex_state"] = "swap_set_amount" context.user_data["dex_previous_state"] = "swap" - await update.callback_query.message.reply_text( + await update.callback_query.message.edit_text( help_text, parse_mode="MarkdownV2", reply_markup=reply_markup @@ -805,7 +914,7 @@ async def handle_swap_set_slippage(update: Update, context: ContextTypes.DEFAULT context.user_data["dex_state"] = "swap_set_slippage" context.user_data["dex_previous_state"] = "swap" - await update.callback_query.message.reply_text( + await update.callback_query.message.edit_text( help_text, parse_mode="MarkdownV2", reply_markup=reply_markup @@ -831,12 +940,14 @@ async def handle_swap_get_quote(update: Update, context: ContextTypes.DEFAULT_TY if not all([connector, network, trading_pair, amount]): raise ValueError("Missing required parameters") - client = await get_client(chat_id) + client = await get_client(chat_id, context=context) # Fetch balances in parallel with quotes async def fetch_balances_safe(): try: - balances = await _fetch_balances(client) + # Check if force refresh is needed (e.g., after swap execution) + force_refresh = context.user_data.pop("_force_balance_refresh", False) + balances = await _fetch_balances(client, refresh=force_refresh) if balances: set_cached(context.user_data, "gateway_balances", balances) return balances @@ -967,7 +1078,7 @@ async def handle_swap_execute_confirm(update: Update, context: ContextTypes.DEFA ) chat_id = update.effective_chat.id - client = await get_client(chat_id) + client = await get_client(chat_id, context=context) if not hasattr(client, 'gateway_swap'): raise ValueError("Gateway swap not available") @@ -1003,8 +1114,11 @@ async def handle_swap_execute_confirm(update: Update, context: ContextTypes.DEFA if result is None: raise ValueError("Swap execution failed") - # Invalidate caches + # Invalidate caches - including swap_quote so background fetch runs when going back invalidate_cache(context.user_data, "balances", "swaps") + _invalidate_swap_quote(context.user_data) + # Flag to force refresh on next balance fetch (swap changed balances) + context.user_data["_force_balance_refresh"] = True # Save params set_dex_last_swap(context.user_data, { @@ -1070,7 +1184,7 @@ async def handle_swap_status(update: Update, context: ContextTypes.DEFAULT_TYPE) context.user_data["dex_state"] = "swap_status" - await update.callback_query.message.reply_text( + await update.callback_query.message.edit_text( help_text, parse_mode="MarkdownV2", reply_markup=reply_markup @@ -1085,7 +1199,7 @@ async def process_swap_status( """Process swap status check""" try: chat_id = update.effective_chat.id - client = await get_client(chat_id) + client = await get_client(chat_id, context=context) if not hasattr(client, 'gateway_swap'): raise ValueError("Gateway swap not available") @@ -1186,7 +1300,7 @@ async def handle_swap_history(update: Update, context: ContextTypes.DEFAULT_TYPE filters = get_history_filters(context.user_data, "swap") chat_id = update.effective_chat.id - client = await get_client(chat_id) + client = await get_client(chat_id, context=context) if not hasattr(client, 'gateway_swap'): error_message = format_error_message("Gateway swap not available") @@ -1457,14 +1571,12 @@ async def process_swap_set_pair( """Process trading pair input""" try: params = context.user_data.get("swap_params", {}) - params["trading_pair"] = user_input.strip() + params["trading_pair"] = user_input.strip().upper() _invalidate_swap_quote(context.user_data) context.user_data["dex_state"] = "swap" - success_msg = escape_markdown_v2(f"✅ Pair: {user_input}") - await update.message.reply_text(success_msg, parse_mode="MarkdownV2") - await show_swap_menu(update, context, send_new=True) + await _update_swap_menu_after_input(update, context) except Exception as e: logger.error(f"Error setting pair: {e}", exc_info=True) @@ -1497,12 +1609,7 @@ async def process_swap_set_amount( _invalidate_swap_quote(context.user_data) context.user_data["dex_state"] = "swap" - if is_quote_amount: - success_msg = escape_markdown_v2(f"✅ Amount: {amount_str} (quote)") - else: - success_msg = escape_markdown_v2(f"✅ Amount: {amount_str}") - await update.message.reply_text(success_msg, parse_mode="MarkdownV2") - await show_swap_menu(update, context, send_new=True) + await _update_swap_menu_after_input(update, context) except Exception as e: logger.error(f"Error setting amount: {e}", exc_info=True) @@ -1529,9 +1636,7 @@ async def process_swap_set_slippage( context.user_data["dex_state"] = "swap" - success_msg = escape_markdown_v2(f"✅ Slippage: {slippage_str}%") - await update.message.reply_text(success_msg, parse_mode="MarkdownV2") - await show_swap_menu(update, context, send_new=True) + await _update_swap_menu_after_input(update, context) except Exception as e: logger.error(f"Error setting slippage: {e}", exc_info=True) diff --git a/handlers/dex/visualizations.py b/handlers/dex/visualizations.py index 48c8464..974b96b 100644 --- a/handlers/dex/visualizations.py +++ b/handlers/dex/visualizations.py @@ -592,7 +592,6 @@ def generate_combined_chart( try: import plotly.graph_objects as go from plotly.subplots import make_subplots - import numpy as np if not ohlcv_data: logger.warning("No OHLCV data for combined chart") @@ -977,7 +976,6 @@ def generate_aggregated_liquidity_chart( try: import plotly.graph_objects as go from collections import defaultdict - import numpy as np if not pools_data: logger.warning("No pools_data provided to aggregated chart") diff --git a/handlers/portfolio.py b/handlers/portfolio.py index bc09cc1..cd73d0a 100644 --- a/handlers/portfolio.py +++ b/handlers/portfolio.py @@ -8,11 +8,10 @@ from telegram import Update, InlineKeyboardButton, InlineKeyboardMarkup from telegram.ext import ContextTypes, CallbackQueryHandler -from utils.auth import restricted +from utils.auth import restricted, hummingbot_api_required from utils.telegram_formatters import ( - format_portfolio_summary, - format_portfolio_state, format_portfolio_overview, + format_connector_detail, format_error_message, escape_markdown_v2 ) @@ -605,7 +604,83 @@ async def _fetch_dashboard_data(client, days: int, refresh: bool = False): return overview_data, history, token_distribution, accounts_distribution, pnl_history, graph_interval +def _get_connector_keys(balances: dict) -> list: + """ + Extract connector keys from balances for keyboard buttons. + + Args: + balances: Portfolio state {account: {connector: [holdings]}} + + Returns: + List of "account:connector" keys, sorted by total value descending + """ + if not balances: + return [] + + connector_values = [] + for account_name, account_data in balances.items(): + for connector_name, connector_balances in account_data.items(): + if connector_balances: + total = sum(b.get("value", 0) for b in connector_balances if b.get("value", 0) > 0) + if total > 0: + connector_values.append({ + "key": f"{account_name}:{connector_name}", + "value": total + }) + + # Sort by value descending + connector_values.sort(key=lambda x: x["value"], reverse=True) + return [c["key"] for c in connector_values] + + +def build_portfolio_keyboard(connector_keys: list, days: int) -> InlineKeyboardMarkup: + """ + Build keyboard with connector buttons and controls. + + Args: + connector_keys: List of "account:connector" keys + days: Current days setting for Settings button + + Returns: + InlineKeyboardMarkup with connector buttons + """ + keyboard = [] + + # Row(s) of connector buttons (max 2 per row for longer names) + if connector_keys: + connector_row = [] + for conn_key in connector_keys: + # Extract connector name for display (after the colon) + display_name = conn_key.split(":")[-1] + + connector_row.append( + InlineKeyboardButton(display_name, callback_data=f"portfolio:connector:{conn_key}") + ) + # Use max 2 per row to fit longer names like "solana-mainnet-beta" + if len(connector_row) == 2: + keyboard.append(connector_row) + connector_row = [] + if connector_row: + keyboard.append(connector_row) + + # Bottom row: Refresh + Settings + keyboard.append([ + InlineKeyboardButton("🔄 Refresh", callback_data="portfolio:refresh"), + InlineKeyboardButton(f"⚙️ Settings ({days}d)", callback_data="portfolio:settings") + ]) + + return InlineKeyboardMarkup(keyboard) + + +def build_connector_detail_keyboard() -> InlineKeyboardMarkup: + """Build keyboard for connector detail view with Back button.""" + return InlineKeyboardMarkup([[ + InlineKeyboardButton("⬅️ Back to Overview", callback_data="portfolio:back_overview") + ]]) + + @restricted +@hummingbot_api_required async def portfolio_command(update: Update, context: ContextTypes.DEFAULT_TYPE) -> None: """ Handle /portfolio command - Display comprehensive portfolio dashboard @@ -628,11 +703,11 @@ async def portfolio_command(update: Update, context: ContextTypes.DEFAULT_TYPE) return try: - from servers import server_manager + from config_manager import get_config_manager from utils.trading_data import get_lp_positions, get_perpetual_positions, get_active_orders, get_tokens_for_networks # Get first enabled server - servers = server_manager.list_servers() + servers = get_config_manager().list_servers() enabled_servers = [name for name, cfg in servers.items() if cfg.get("enabled", True)] if not enabled_servers: @@ -640,12 +715,10 @@ async def portfolio_command(update: Update, context: ContextTypes.DEFAULT_TYPE) await message.reply_text(error_message, parse_mode="MarkdownV2") return - # Use per-chat default server, falling back to global default - default_server = server_manager.get_default_server_for_chat(chat_id) - if default_server and default_server in enabled_servers: - server_name = default_server - else: - server_name = enabled_servers[0] + # Use user's preferred server + from handlers.config.user_preferences import get_active_server + preferred = get_active_server(context.user_data) + server_name = preferred if preferred and preferred in enabled_servers else enabled_servers[0] # Send initial loading message immediately text_msg = await message.reply_text( @@ -654,10 +727,10 @@ async def portfolio_command(update: Update, context: ContextTypes.DEFAULT_TYPE) parse_mode="MarkdownV2" ) - client = await server_manager.get_client(server_name) + client = await get_config_manager().get_client(server_name) # Check server status - server_status_info = await server_manager.check_server_status(server_name) + server_status_info = await get_config_manager().check_server_status(server_name) server_status = server_status_info.get("status", "online") # Get portfolio config @@ -667,8 +740,9 @@ async def portfolio_command(update: Update, context: ContextTypes.DEFAULT_TYPE) pnl_start_time = _calculate_start_time(30) graph_interval = _get_optimal_interval(days) - # Check if this is a refresh request (from callback) - refresh = context.user_data.pop("_portfolio_refresh", False) + # Always refresh balances for /portfolio command (CEX balances need real-time data) + refresh = True + context.user_data.pop("_portfolio_refresh", None) # Clear any stale flag # ======================================== # START ALL FETCHES IN PARALLEL @@ -692,6 +766,9 @@ async def portfolio_command(update: Update, context: ContextTypes.DEFAULT_TYPE) current_value = 0.0 token_cache = {} # Will be populated with Gateway tokens + # Track accounts_distribution for UI updates + accounts_distribution = None + # Helper to update UI async def update_ui(loading_text: str = None): nonlocal current_value @@ -712,18 +789,19 @@ async def update_ui(loading_text: str = None): 'lp_positions': lp_positions, 'active_orders': active_orders, } - message = format_portfolio_overview( + formatted_message = format_portfolio_overview( overview_data, server_name=server_name, server_status=server_status, pnl_indicators=pnl_indicators, changes_24h=changes_24h, - token_cache=token_cache + token_cache=token_cache, + accounts_distribution=accounts_distribution ) if loading_text: - message += f"\n_{escape_markdown_v2(loading_text)}_" + formatted_message += f"\n_{escape_markdown_v2(loading_text)}_" try: - await text_msg.edit_text(message, parse_mode="MarkdownV2") + await text_msg.edit_text(formatted_message, parse_mode="MarkdownV2") except Exception: pass @@ -784,7 +862,6 @@ async def update_ui(loading_text: str = None): # ======================================== history = None token_distribution = None - accounts_distribution = None try: graph_results = await asyncio.gather( @@ -810,12 +887,9 @@ async def update_ui(loading_text: str = None): accounts_distribution_data=accounts_distribution ) - # Create buttons row with Refresh and Settings - keyboard = [[ - InlineKeyboardButton("🔄 Refresh", callback_data="portfolio:refresh"), - InlineKeyboardButton(f"⚙️ Settings ({days}d)", callback_data="portfolio:settings") - ]] - reply_markup = InlineKeyboardMarkup(keyboard) + # Build keyboard with connector buttons + connector_keys = _get_connector_keys(balances) + reply_markup = build_portfolio_keyboard(connector_keys, days) # Send the dashboard image with buttons photo_msg = await message.reply_photo( @@ -824,7 +898,7 @@ async def update_ui(loading_text: str = None): reply_markup=reply_markup ) - # Store message IDs and data for later updates + # Store message IDs and data for later updates (including data for connector detail view) context.user_data["portfolio_text_message_id"] = text_msg.message_id context.user_data["portfolio_photo_message_id"] = photo_msg.message_id context.user_data["portfolio_chat_id"] = message.chat_id @@ -832,6 +906,12 @@ async def update_ui(loading_text: str = None): context.user_data["portfolio_server_name"] = server_name context.user_data["portfolio_server_status"] = server_status context.user_data["portfolio_current_value"] = current_value + # Cache data for connector detail callbacks + context.user_data["portfolio_balances"] = balances + context.user_data["portfolio_accounts_distribution"] = accounts_distribution + context.user_data["portfolio_changes_24h"] = changes_24h + context.user_data["portfolio_pnl_indicators"] = pnl_indicators + context.user_data["portfolio_connector_keys"] = connector_keys except Exception as e: logger.error(f"Error fetching portfolio: {e}", exc_info=True) @@ -874,6 +954,13 @@ async def portfolio_callback_handler(update: Update, context: ContextTypes.DEFAU except Exception: pass await refresh_portfolio_dashboard(update, context) + elif action.startswith("connector:"): + # Show detailed view for a specific connector + connector_key = action.replace("connector:", "") + await handle_connector_detail(update, context, connector_key) + elif action == "back_overview": + # Return to main portfolio overview + await handle_back_to_overview(update, context) else: logger.warning(f"Unknown portfolio action: {action}") @@ -898,11 +985,141 @@ async def handle_portfolio_refresh(update: Update, context: ContextTypes.DEFAULT await refresh_portfolio_dashboard(update, context, refresh=True) -async def refresh_portfolio_dashboard(update: Update, context: ContextTypes.DEFAULT_TYPE, refresh: bool = False) -> None: +async def handle_connector_detail(update: Update, context: ContextTypes.DEFAULT_TYPE, connector_key: str) -> None: + """ + Handle connector inspection - show tokens for specific connector. + + Args: + connector_key: "account:connector" format (e.g., "main:binance") + """ + query = update.callback_query + await query.answer() + + # Get cached data + balances = context.user_data.get("portfolio_balances") + changes_24h = context.user_data.get("portfolio_changes_24h") + total_value = context.user_data.get("portfolio_current_value", 0.0) + text_message_id = context.user_data.get("portfolio_text_message_id") + photo_message_id = context.user_data.get("portfolio_photo_message_id") + chat_id = context.user_data.get("portfolio_chat_id") + server_name = context.user_data.get("portfolio_server_name", "") + + if not balances or not text_message_id or not chat_id: + logger.warning("Missing cached data for connector detail view") + return + + try: + bot = query.get_bot() + + # Format connector detail message + detail_message = format_connector_detail( + balances=balances, + connector_key=connector_key, + changes_24h=changes_24h, + total_value=total_value + ) + + # Update text message with connector detail (no keyboard on text) + await bot.edit_message_text( + chat_id=chat_id, + message_id=text_message_id, + text=detail_message, + parse_mode="MarkdownV2" + ) + + # Update photo message keyboard to show "Back to Overview" button + if photo_message_id: + try: + await bot.edit_message_reply_markup( + chat_id=chat_id, + message_id=photo_message_id, + reply_markup=build_connector_detail_keyboard() + ) + except Exception as e: + logger.warning(f"Failed to update photo keyboard: {e}") + + # Store current view mode + context.user_data["portfolio_view_mode"] = f"connector:{connector_key}" + + except Exception as e: + logger.error(f"Error showing connector detail: {e}", exc_info=True) + + +async def handle_back_to_overview(update: Update, context: ContextTypes.DEFAULT_TYPE) -> None: + """Handle 'Back' button - return to main portfolio overview.""" + query = update.callback_query + await query.answer() + + # Get cached data + balances = context.user_data.get("portfolio_balances") + accounts_distribution = context.user_data.get("portfolio_accounts_distribution") + changes_24h = context.user_data.get("portfolio_changes_24h") + pnl_indicators = context.user_data.get("portfolio_pnl_indicators") + server_name = context.user_data.get("portfolio_server_name") + server_status = context.user_data.get("portfolio_server_status") + connector_keys = context.user_data.get("portfolio_connector_keys", []) + text_message_id = context.user_data.get("portfolio_text_message_id") + photo_message_id = context.user_data.get("portfolio_photo_message_id") + chat_id = context.user_data.get("portfolio_chat_id") + + if not text_message_id or not chat_id: + logger.warning("Missing message IDs for back to overview") + return + + try: + bot = query.get_bot() + config = get_portfolio_prefs(context.user_data) + days = config.get("days", 3) + + # Build overview data + overview_data = { + 'balances': balances, + 'perp_positions': {"positions": [], "total": 0}, + 'lp_positions': {"positions": [], "total": 0}, + 'active_orders': {"orders": [], "total": 0}, + } + + # Format overview message + overview_message = format_portfolio_overview( + overview_data, + server_name=server_name, + server_status=server_status, + pnl_indicators=pnl_indicators, + changes_24h=changes_24h, + accounts_distribution=accounts_distribution + ) + + # Update text message with overview (no keyboard on text) + await bot.edit_message_text( + chat_id=chat_id, + message_id=text_message_id, + text=overview_message, + parse_mode="MarkdownV2" + ) + + # Update photo message keyboard to show connector buttons + if photo_message_id: + try: + await bot.edit_message_reply_markup( + chat_id=chat_id, + message_id=photo_message_id, + reply_markup=build_portfolio_keyboard(connector_keys, days) + ) + except Exception as e: + logger.warning(f"Failed to update photo keyboard: {e}") + + # Clear view mode + context.user_data["portfolio_view_mode"] = "overview" + + except Exception as e: + logger.error(f"Error returning to overview: {e}", exc_info=True) + + +async def refresh_portfolio_dashboard(update: Update, context: ContextTypes.DEFAULT_TYPE, refresh: bool = True) -> None: """Refresh both the text message and photo with new settings Args: - refresh: If True, force refresh balances from exchanges (bypasses API cache) + refresh: If True, force refresh balances from exchanges (bypasses API cache). Defaults to True. """ query = update.callback_query bot = query.get_bot() @@ -916,21 +1133,19 @@ async def refresh_portfolio_dashboard(update: Update, context: ContextTypes.DEFA return try: - from servers import server_manager + from config_manager import get_config_manager from utils.trading_data import get_tokens_for_networks - # Use per-chat default server from server_manager - servers = server_manager.list_servers() + # Use user's preferred server + servers = get_config_manager().list_servers() enabled_servers = [name for name, cfg in servers.items() if cfg.get("enabled", True)] if not enabled_servers: return - default_server = server_manager.get_default_server_for_chat(chat_id) - if default_server and default_server in enabled_servers: - server_name = default_server - else: - server_name = enabled_servers[0] + from handlers.config.user_preferences import get_active_server + preferred = get_active_server(context.user_data) + server_name = preferred if preferred and preferred in enabled_servers else enabled_servers[0] # Update caption to show "Updating..." status try: @@ -942,8 +1157,8 @@ async def refresh_portfolio_dashboard(update: Update, context: ContextTypes.DEFA except Exception as e: logger.warning(f"Failed to update caption to 'Updating': {e}") - client = await server_manager.get_client(server_name) - server_status_info = await server_manager.check_server_status(server_name) + client = await get_config_manager().get_client(server_name) + server_status_info = await get_config_manager().check_server_status(server_name) server_status = server_status_info.get("status", "online") # Get current config (only days, interval is auto-calculated) @@ -991,21 +1206,26 @@ async def refresh_portfolio_dashboard(update: Update, context: ContextTypes.DEFA except Exception as e: logger.debug(f"Failed to fetch tokens for LP networks: {e}") + # Get balances for connector keys + balances = overview_data.get('balances') if overview_data else None + connector_keys = _get_connector_keys(balances) + # Update text message if we have it if text_message_id: - message = format_portfolio_overview( + formatted_message = format_portfolio_overview( overview_data, server_name=server_name, server_status=server_status, pnl_indicators=pnl_indicators, changes_24h=changes_24h, - token_cache=token_cache + token_cache=token_cache, + accounts_distribution=accounts_distribution ) try: await bot.edit_message_text( chat_id=chat_id, message_id=text_message_id, - text=message, + text=formatted_message, parse_mode="MarkdownV2" ) except Exception as e: @@ -1018,12 +1238,8 @@ async def refresh_portfolio_dashboard(update: Update, context: ContextTypes.DEFA accounts_distribution_data=accounts_distribution ) - # Create buttons row with Refresh and Settings - keyboard = [[ - InlineKeyboardButton("🔄 Refresh", callback_data="portfolio:refresh"), - InlineKeyboardButton(f"⚙️ Settings ({days}d)", callback_data="portfolio:settings") - ]] - reply_markup = InlineKeyboardMarkup(keyboard) + # Build keyboard with connector buttons + reply_markup = build_portfolio_keyboard(connector_keys, days) # Update photo with new image from telegram import InputMediaPhoto @@ -1042,6 +1258,12 @@ async def refresh_portfolio_dashboard(update: Update, context: ContextTypes.DEFA context.user_data["portfolio_server_name"] = server_name context.user_data["portfolio_server_status"] = server_status context.user_data["portfolio_current_value"] = current_value + # Cache data for connector detail callbacks + context.user_data["portfolio_balances"] = balances + context.user_data["portfolio_accounts_distribution"] = accounts_distribution + context.user_data["portfolio_changes_24h"] = changes_24h + context.user_data["portfolio_pnl_indicators"] = pnl_indicators + context.user_data["portfolio_connector_keys"] = connector_keys except Exception as e: logger.error(f"Failed to refresh portfolio dashboard: {e}", exc_info=True) diff --git a/handlers/routines/__init__.py b/handlers/routines/__init__.py index e3146c8..19b4f43 100644 --- a/handlers/routines/__init__.py +++ b/handlers/routines/__init__.py @@ -4,47 +4,98 @@ Features: - Auto-discovery of routines from routines/ folder - Text-based config editing (key=value) -- Interval routines: run repeatedly at configurable interval -- One-shot routines: run once (foreground or background) -- Multi-instance support for different configs +- Instance-based execution (each run has frozen config) +- One-shot routines: Run once, can be scheduled (interval/daily) +- Continuous routines: Run forever with internal loop until stopped """ import asyncio import hashlib import logging import time +from datetime import time as dt_time from telegram import Update, InlineKeyboardButton, InlineKeyboardMarkup from telegram.ext import ContextTypes, CallbackContext from handlers import clear_all_input_states -from routines.base import discover_routines, get_routine +from routines.base import discover_routines, get_routine, get_routine_by_state from utils.auth import restricted from utils.telegram_formatters import escape_markdown_v2 logger = logging.getLogger(__name__) -# Job metadata: {job_name: {start_time, config, routine_name}} -_job_info: dict[str, dict] = {} -# Last results: {chat_id: {key: {result, duration, end_time}}} -_last_results: dict[int, dict[str, dict]] = {} +# ============================================================================= +# Constants +# ============================================================================= + +SCHEDULE_PRESETS = [ + ("30s", 30), + ("1m", 60), + ("5m", 300), + ("15m", 900), + ("30m", 1800), + ("1h", 3600), +] + +DAILY_PRESETS = ["06:00", "09:00", "12:00", "18:00", "21:00"] + +# Global storage for continuous routine tasks (not persisted) +_continuous_tasks: dict[str, asyncio.Task] = {} # instance_id -> Task # ============================================================================= -# Utility Functions +# Storage Helpers # ============================================================================= -def _generate_instance_id(routine_name: str, config_dict: dict) -> str: - """Generate unique instance ID from routine name and config.""" - data = f"{routine_name}:{sorted(config_dict.items())}" - return hashlib.md5(data.encode()).hexdigest()[:8] +def _get_drafts(context: ContextTypes.DEFAULT_TYPE) -> dict: + """Get draft configs dict.""" + if "routine_drafts" not in context.user_data: + context.user_data["routine_drafts"] = {} + return context.user_data["routine_drafts"] -def _job_name(chat_id: int, routine_name: str, instance_id: str) -> str: - """Build job name for JobQueue.""" - return f"routine_{chat_id}_{routine_name}_{instance_id}" +def _get_instances(context: ContextTypes.DEFAULT_TYPE) -> dict: + """Get all instances dict.""" + if "routine_instances" not in context.user_data: + context.user_data["routine_instances"] = {} + return context.user_data["routine_instances"] + + +def _get_draft(context: ContextTypes.DEFAULT_TYPE, routine_name: str) -> dict: + """Get draft config for a routine, initializing from defaults if needed.""" + drafts = _get_drafts(context) + if routine_name not in drafts: + routine = get_routine(routine_name) + if routine: + drafts[routine_name] = routine.get_default_config().model_dump() + else: + drafts[routine_name] = {} + return drafts[routine_name] + + +def _set_draft(context: ContextTypes.DEFAULT_TYPE, routine_name: str, config: dict) -> None: + """Update draft config for a routine.""" + drafts = _get_drafts(context) + drafts[routine_name] = config + + +def _get_routine_instances(context: ContextTypes.DEFAULT_TYPE, routine_name: str) -> list[tuple[str, dict]]: + """Get all instances for a specific routine.""" + instances = _get_instances(context) + return [(iid, inst) for iid, inst in instances.items() if inst.get("routine_name") == routine_name] + + +def _generate_instance_id() -> str: + """Generate a short unique instance ID.""" + return hashlib.md5(f"{time.time()}{id(object())}".encode()).hexdigest()[:6] + + +# ============================================================================= +# Formatting Helpers +# ============================================================================= def _display_name(name: str) -> str: @@ -63,270 +114,396 @@ def _format_duration(seconds: float) -> str: return f"{int(seconds // 3600)}h {int((seconds % 3600) // 60)}m" -# ============================================================================= -# Instance Management -# ============================================================================= +def _format_schedule(schedule: dict) -> str: + """Format schedule as human-readable string.""" + stype = schedule.get("type", "once") + if stype == "once": + return "One-time" + elif stype == "interval": + secs = schedule.get("interval_sec", 60) + if secs < 60: + return f"Every {secs}s" + elif secs < 3600: + return f"Every {secs // 60}m" + else: + return f"Every {secs // 3600}h" + elif stype == "continuous": + return "Running" + elif stype == "daily": + return f"Daily @ {schedule.get('daily_time', '09:00')}" + return "Unknown" -def _get_instances( - context: ContextTypes.DEFAULT_TYPE, - chat_id: int, - routine_name: str | None = None, -) -> list[dict]: - """Get running instances for a chat, optionally filtered by routine.""" - prefix = f"routine_{chat_id}_" - instances = [] - - for job in context.job_queue.jobs(): - if not job.name or not job.name.startswith(prefix): - continue +def _format_ago(timestamp: float) -> str: + """Format timestamp as 'X ago' string.""" + diff = time.time() - timestamp + if diff < 60: + return f"{int(diff)}s ago" + elif diff < 3600: + return f"{int(diff // 60)}m ago" + elif diff < 86400: + return f"{int(diff // 3600)}h ago" + return f"{int(diff // 86400)}d ago" - parts = job.name.split("_") - if len(parts) < 4: - continue - rname = "_".join(parts[2:-1]) - inst_id = parts[-1] +def _config_preview(config: dict, max_items: int = 2) -> str: + """Get short config preview.""" + items = list(config.items())[:max_items] + return ", ".join(f"{k}={v}" for k, v in items) - if routine_name and rname != routine_name: - continue - info = _job_info.get(job.name, {}) - instances.append({ - "job_name": job.name, - "routine_name": rname, - "instance_id": inst_id, - "config": info.get("config", {}), - "start_time": info.get("start_time", time.time()), - }) +# ============================================================================= +# Job Management +# ============================================================================= - return instances +def _job_name(chat_id: int, instance_id: str) -> str: + """Build job name for JobQueue.""" + return f"routine_{chat_id}_{instance_id}" -def _stop_instance(context: ContextTypes.DEFAULT_TYPE, chat_id: int, job_name: str) -> bool: - """Stop a running instance. Returns True if stopped.""" - jobs = context.job_queue.get_jobs_by_name(job_name) - if not jobs: - return False - info = _job_info.pop(job_name, {}) - start_time = info.get("start_time", time.time()) - routine_name = info.get("routine_name", "unknown") - instance_id = job_name.split("_")[-1] +def _find_job(context: ContextTypes.DEFAULT_TYPE, chat_id: int, instance_id: str): + """Find a job by instance ID.""" + name = _job_name(chat_id, instance_id) + jobs = context.job_queue.get_jobs_by_name(name) + return jobs[0] if jobs else None - jobs[0].schedule_removal() - # Store result - _store_result(chat_id, routine_name, "(stopped)", time.time() - start_time, instance_id) +def _stop_instance(context: ContextTypes.DEFAULT_TYPE, chat_id: int, instance_id: str) -> bool: + """Stop a job/task and remove instance. Returns True if found.""" + # Try to stop JobQueue job (for scheduled one-shots) + job = _find_job(context, chat_id, instance_id) + if job: + job.schedule_removal() + logger.info(f"Removed scheduled job for instance {instance_id}") - # Clean up state - try: - user_data = context.application.user_data.get(chat_id, {}) - user_data.pop(f"{routine_name}_state_{chat_id}_{instance_id}", None) - except Exception: - pass + # Try to cancel asyncio task (for continuous routines) + task = _continuous_tasks.pop(instance_id, None) + if task and not task.done(): + task.cancel() + logger.info(f"Cancelled continuous task for instance {instance_id}") - return True + # Remove from instances + instances = _get_instances(context) + if instance_id in instances: + routine_name = instances[instance_id].get("routine_name", "unknown") + # Call routine cleanup if available + routine = get_routine(routine_name) + if routine and routine.cleanup_fn: + asyncio.create_task(routine.cleanup_fn(context, instance_id, chat_id)) -def _stop_all( - context: ContextTypes.DEFAULT_TYPE, - chat_id: int, - routine_name: str | None = None, -) -> int: - """Stop all instances, optionally filtered by routine. Returns count.""" - instances = _get_instances(context, chat_id, routine_name) - return sum(1 for i in instances if _stop_instance(context, chat_id, i["job_name"])) + del instances[instance_id] + logger.info(f"Stopped instance {instance_id} ({routine_name})") + return True + return False -# ============================================================================= -# Result Storage -# ============================================================================= +def _stop_all_routine(context: ContextTypes.DEFAULT_TYPE, chat_id: int, routine_name: str) -> int: + """Stop all instances of a routine. Returns count.""" + instances = _get_routine_instances(context, routine_name) + count = 0 + for iid, _ in instances: + if _stop_instance(context, chat_id, iid): + count += 1 + return count -def _store_result( - chat_id: int, +async def _execute_routine( + context: CallbackContext, + instance_id: str, routine_name: str, - result: str, - duration: float, - instance_id: str | None = None, -) -> None: - """Store execution result.""" - if chat_id not in _last_results: - _last_results[chat_id] = {} - - # Always store under routine_name for easy retrieval - _last_results[chat_id][routine_name] = { - "result": result, - "duration": duration, - "end_time": time.time(), - "instance_id": instance_id, - } - + config_dict: dict, + chat_id: int, +) -> tuple[str, float]: + """Execute a routine and return (result, duration).""" + routine = get_routine(routine_name) + if not routine: + return "Routine not found", 0 -def _get_result(chat_id: int, routine_name: str) -> dict | None: - """Get last result for a routine.""" - return _last_results.get(chat_id, {}).get(routine_name) + start = time.time() + # Prepare context for routine + context._chat_id = chat_id + context._instance_id = instance_id + context._user_data = context.application.user_data.get(chat_id, {}) -# ============================================================================= -# Job Callbacks -# ============================================================================= + try: + config = routine.config_class(**config_dict) + result = await routine.run_fn(config, context) + result_text = str(result)[:500] if result else "Completed" + except Exception as e: + result_text = f"Error: {e}" + logger.error(f"Routine {routine_name}[{instance_id}] failed: {e}") + duration = time.time() - start + return result_text, duration -async def _interval_callback(context: CallbackContext) -> None: - """Execute one iteration of an interval routine.""" - data = context.job.data or {} - routine_name = data["routine_name"] - chat_id = data["chat_id"] - config_dict = data["config_dict"] - instance_id = data["instance_id"] +async def _run_continuous_routine( + application, + instance_id: str, + routine_name: str, + config_dict: dict, + chat_id: int, +) -> None: + """Run a continuous routine as an asyncio task.""" routine = get_routine(routine_name) if not routine: + logger.error(f"Routine {routine_name} not found") return - # Prepare context for routine - user_data = context.application.user_data.get(chat_id, {}) - context._chat_id = chat_id - context._instance_id = instance_id - context._user_data = user_data - - job_name = context.job.name - info = _job_info.get(job_name, {}) - start_time = info.get("start_time", time.time()) + # Create a mock context for the routine + class MockContext: + def __init__(self): + self._chat_id = chat_id + self._instance_id = instance_id + # Ensure user_data dict exists in application + if chat_id not in application.user_data: + application.user_data[chat_id] = {} + self._user_data = application.user_data[chat_id] + self.bot = application.bot + self.application = application + + @property + def user_data(self): + """Provide user_data property for compatibility.""" + return self._user_data + + context = MockContext() try: config = routine.config_class(**config_dict) + logger.info(f"Starting continuous routine {routine_name}[{instance_id}]") result = await routine.run_fn(config, context) - result_text = str(result)[:500] if result else "Running..." - logger.debug(f"{routine_name}[{instance_id}]: {result_text[:50]}") + logger.info(f"Continuous routine {routine_name}[{instance_id}] ended: {result}") + except asyncio.CancelledError: + logger.info(f"Continuous routine {routine_name}[{instance_id}] cancelled") except Exception as e: - result_text = f"Error: {e}" - logger.error(f"{routine_name}[{instance_id}] error: {e}") + logger.error(f"Continuous routine {routine_name}[{instance_id}] error: {e}") - # Store result for display in detail view - _store_result(chat_id, routine_name, result_text, time.time() - start_time) + # Clean up instance when task ends + instances = application.user_data.get(chat_id, {}).get("routine_instances", {}) + if instance_id in instances: + del instances[instance_id] -async def _oneshot_callback(context: CallbackContext) -> None: - """Execute a one-shot routine and update UI or send message.""" +async def _interval_job_callback(context: CallbackContext) -> None: + """Job callback for interval-scheduled one-shot routines. Sends message each run.""" data = context.job.data or {} + instance_id = data["instance_id"] routine_name = data["routine_name"] - chat_id = data["chat_id"] config_dict = data["config_dict"] - instance_id = data["instance_id"] - msg_id = data.get("msg_id") - background = data.get("background", False) + chat_id = data["chat_id"] - job_name = context.job.name - info = _job_info.get(job_name, {}) - start_time = info.get("start_time", time.time()) + # Check if instance still exists (may have been stopped) + instances = context.application.user_data.get(chat_id, {}).get("routine_instances", {}) + if instance_id not in instances: + logger.warning(f"Instance {instance_id} no longer exists, skipping execution") + return - routine = get_routine(routine_name) - if not routine: + result, duration = await _execute_routine(context, instance_id, routine_name, config_dict, chat_id) + + # Re-check instance exists after execution (may have been stopped during run) + instances = context.application.user_data.get(chat_id, {}).get("routine_instances", {}) + if instance_id not in instances: + logger.warning(f"Instance {instance_id} was removed during execution") return - # Prepare context - user_data = context.application.user_data.get(chat_id, {}) - context._chat_id = chat_id - context._instance_id = instance_id - context._user_data = user_data + instances[instance_id]["last_run_at"] = time.time() + instances[instance_id]["last_result"] = result + instances[instance_id]["last_duration"] = duration + run_count = instances[instance_id].get("run_count", 0) + 1 + instances[instance_id]["run_count"] = run_count + # Send result message for scheduled one-shot routines + schedule = instances[instance_id].get("schedule", {}) + interval_str = _format_schedule(schedule) + icon = "✅" if not result.startswith("Error") else "❌" + text = ( + f"{icon} *{escape_markdown_v2(_display_name(routine_name))}* `{instance_id}`\n" + f"⏱️ {escape_markdown_v2(interval_str)} \\| Run \\#{run_count} \\| {escape_markdown_v2(_format_duration(duration))}\n\n" + f"```\n{result[:400]}\n```" + ) try: - config = routine.config_class(**config_dict) - result = await routine.run_fn(config, context) - result_text = str(result)[:500] if result else "Completed" - status = "completed" + await context.bot.send_message(chat_id=chat_id, text=text, parse_mode="MarkdownV2") except Exception as e: - result_text = f"Error: {e}" - status = "error" - logger.error(f"{routine_name}[{instance_id}] failed: {e}") + logger.error(f"Failed to send interval result: {e}") - duration = time.time() - start_time - _job_info.pop(job_name, None) - _store_result(chat_id, routine_name, result_text, duration, instance_id) + logger.info(f"Routine {routine_name}[{instance_id}] run #{run_count}: {result[:50]}...") + + +async def _oneshot_job_callback(context: CallbackContext) -> None: + """Job callback for one-time runs.""" + data = context.job.data or {} + instance_id = data["instance_id"] + routine_name = data["routine_name"] + config_dict = data["config_dict"] + chat_id = data["chat_id"] + msg_id = data.get("msg_id") + background = data.get("background", False) + + result, duration = await _execute_routine(context, instance_id, routine_name, config_dict, chat_id) + + # Remove one-shot instance after completion + instances = context.application.user_data.get(chat_id, {}).get("routine_instances", {}) + if instance_id in instances: + del instances[instance_id] if background: # Send result as new message - icon = "✅" if status == "completed" else "❌" + icon = "✅" if not result.startswith("Error") else "❌" text = ( f"{icon} *{escape_markdown_v2(_display_name(routine_name))}*\n" f"Duration: {escape_markdown_v2(_format_duration(duration))}\n\n" - f"```\n{result_text[:400]}\n```" + f"```\n{result[:400]}\n```" ) try: - await context.bot.send_message( - chat_id=chat_id, - text=text, - parse_mode="MarkdownV2", - ) + await context.bot.send_message(chat_id=chat_id, text=text, parse_mode="MarkdownV2") except Exception as e: - logger.error(f"Failed to send background result: {e}") - else: - # Update existing message - await _update_after_run(context, routine_name, chat_id, msg_id, config_dict, result_text, status) - + logger.error(f"Failed to send result: {e}") + elif msg_id: + # Update the detail view + await _refresh_detail_msg(context, chat_id, msg_id, routine_name, result, duration) -async def _update_after_run( - context: CallbackContext, - routine_name: str, - chat_id: int, - msg_id: int | None, - config_dict: dict, - result_text: str, - status: str, -) -> None: - """Update the routine detail message after execution.""" - if not msg_id: - return - routine = get_routine(routine_name) - if not routine: - return +async def _daily_job_callback(context: CallbackContext) -> None: + """Job callback for daily-scheduled routines.""" + data = context.job.data or {} + instance_id = data["instance_id"] + routine_name = data["routine_name"] + config_dict = data["config_dict"] + chat_id = data["chat_id"] - fields = routine.get_fields() - config_lines = [f"{k}={config_dict.get(k, v['default'])}" for k, v in fields.items()] + result, duration = await _execute_routine(context, instance_id, routine_name, config_dict, chat_id) - icon = "✅" if status == "completed" else "❌" - result_info = _get_result(chat_id, routine_name) - duration_str = _format_duration(result_info["duration"]) if result_info else "" + # Update instance state + instances = context.application.user_data.get(chat_id, {}).get("routine_instances", {}) + if instance_id in instances: + instances[instance_id]["last_run_at"] = time.time() + instances[instance_id]["last_result"] = result + instances[instance_id]["last_duration"] = duration + instances[instance_id]["run_count"] = instances[instance_id].get("run_count", 0) + 1 + # Send notification + icon = "✅" if not result.startswith("Error") else "❌" text = ( - f"⚡ *{escape_markdown_v2(_display_name(routine_name).upper())}*\n" - f"━━━━━━━━━━━━━━━━━━━━━\n" - f"_{escape_markdown_v2(routine.description)}_\n\n" - f"Status: ⚪ Ready\n\n" - f"┌─ Config ─────────────────\n" - f"```\n{chr(10).join(config_lines)}\n```\n" - f"└─ _✏️ send key\\=value to edit_\n\n" - f"┌─ {icon} Result ─ {escape_markdown_v2(duration_str)} ────\n" - f"```\n{result_text[:300]}\n```\n" - f"└─────────────────────────" + f"{icon} *Daily: {escape_markdown_v2(_display_name(routine_name))}*\n" + f"```\n{result[:400]}\n```" ) + try: + await context.bot.send_message(chat_id=chat_id, text=text, parse_mode="MarkdownV2") + except Exception as e: + logger.error(f"Failed to send daily result: {e}") - keyboard = [ - [ - InlineKeyboardButton("▶️ Run", callback_data=f"routines:run:{routine_name}"), - InlineKeyboardButton("🔄 Background", callback_data=f"routines:bg:{routine_name}"), - ], - [ - InlineKeyboardButton("❓ Help", callback_data=f"routines:help:{routine_name}"), - InlineKeyboardButton("« Back", callback_data="routines:menu"), - ], - ] - try: - await context.bot.edit_message_text( +def _create_scheduled_instance( + context: ContextTypes.DEFAULT_TYPE, + chat_id: int, + routine_name: str, + config_dict: dict, + schedule: dict, + msg_id: int | None = None, + background: bool = False, +) -> str: + """Create a scheduled one-shot instance. Returns instance_id.""" + instance_id = _generate_instance_id() + job_name_str = _job_name(chat_id, instance_id) + + # Store instance + instances = _get_instances(context) + instances[instance_id] = { + "routine_name": routine_name, + "config": config_dict.copy(), + "schedule": schedule.copy(), + "status": "running", + "created_at": time.time(), + "last_run_at": None, + "last_result": None, + "last_duration": None, + "run_count": 0, + } + + job_data = { + "instance_id": instance_id, + "routine_name": routine_name, + "config_dict": config_dict.copy(), # Copy to isolate from draft changes + "chat_id": chat_id, + "msg_id": msg_id, + "background": background, + } + + stype = schedule.get("type", "once") + + if stype == "once": + context.job_queue.run_once( + _oneshot_job_callback, + when=0.1, + data=job_data, + name=job_name_str, chat_id=chat_id, - message_id=msg_id, - text=text, - parse_mode="MarkdownV2", - reply_markup=InlineKeyboardMarkup(keyboard), ) - except Exception as e: - if "not modified" not in str(e).lower(): - logger.debug(f"Could not update message: {e}") + elif stype == "interval": + interval = schedule.get("interval_sec", 60) + context.job_queue.run_repeating( + _interval_job_callback, + interval=interval, + first=0.5, + data=job_data, + name=job_name_str, + chat_id=chat_id, + ) + elif stype == "daily": + time_str = schedule.get("daily_time", "09:00") + hour, minute = map(int, time_str.split(":")) + context.job_queue.run_daily( + _daily_job_callback, + time=dt_time(hour=hour, minute=minute), + data=job_data, + name=job_name_str, + chat_id=chat_id, + ) + + return instance_id + + +def _create_continuous_instance( + context: ContextTypes.DEFAULT_TYPE, + chat_id: int, + routine_name: str, + config_dict: dict, +) -> str: + """Create a continuous routine instance. Returns instance_id.""" + instance_id = _generate_instance_id() + + # Store instance + instances = _get_instances(context) + instances[instance_id] = { + "routine_name": routine_name, + "config": config_dict.copy(), + "schedule": {"type": "continuous"}, + "status": "running", + "created_at": time.time(), + "last_run_at": None, + "last_result": None, + "last_duration": None, + "run_count": 0, + } + + # Create and store asyncio task (copy config to isolate from draft changes) + frozen_config = config_dict.copy() + task = asyncio.create_task( + _run_continuous_routine( + context.application, + instance_id, + routine_name, + frozen_config, + chat_id, + ) + ) + _continuous_tasks[instance_id] = task + + return instance_id # ============================================================================= @@ -334,11 +511,35 @@ async def _update_after_run( # ============================================================================= +async def _edit_or_send(update: Update, text: str, reply_markup: InlineKeyboardMarkup) -> None: + """Edit message if callback, otherwise send new.""" + if update.callback_query: + try: + await update.callback_query.message.edit_text( + text, parse_mode="MarkdownV2", reply_markup=reply_markup + ) + except Exception as e: + if "not modified" not in str(e).lower(): + logger.warning(f"Edit failed: {e}") + else: + msg = update.message or update.callback_query.message + await msg.reply_text(text, parse_mode="MarkdownV2", reply_markup=reply_markup) + + async def _show_menu(update: Update, context: ContextTypes.DEFAULT_TYPE) -> None: """Show main routines menu.""" chat_id = update.effective_chat.id routines = discover_routines(force_reload=True) - all_instances = _get_instances(context, chat_id) + all_instances = _get_instances(context) + + # Count running instances per routine + running_counts = {} + for inst in all_instances.values(): + rname = inst.get("routine_name") + if inst.get("status") == "running": + running_counts[rname] = running_counts.get(rname, 0) + 1 + + total_running = sum(running_counts.values()) if not routines: text = ( @@ -351,28 +552,26 @@ async def _show_menu(update: Update, context: ContextTypes.DEFAULT_TYPE) -> None else: keyboard = [] - if all_instances: + if total_running > 0: keyboard.append([ - InlineKeyboardButton(f"📋 Running ({len(all_instances)})", callback_data="routines:tasks") + InlineKeyboardButton(f"📋 Running ({total_running})", callback_data="routines:tasks") ]) for name in sorted(routines.keys()): routine = routines[name] - count = len(_get_instances(context, chat_id, name)) + count = running_counts.get(name, 0) if count > 0: label = f"🟢 {_display_name(name)} ({count})" else: - icon = "🔄" if routine.is_interval else "⚡" + icon = "♾️" if routine.is_continuous else "⚡" label = f"{icon} {_display_name(name)}" keyboard.append([InlineKeyboardButton(label, callback_data=f"routines:select:{name}")]) keyboard.append([InlineKeyboardButton("🔄 Reload", callback_data="routines:reload")]) - running = len(all_instances) - status = f"🟢 {running} running" if running else "All idle" - + status = f"🟢 {total_running} running" if total_running else "All idle" text = ( "⚡ *ROUTINES*\n" "━━━━━━━━━━━━━━━━━━━━━\n\n" @@ -386,36 +585,40 @@ async def _show_menu(update: Update, context: ContextTypes.DEFAULT_TYPE) -> None async def _show_tasks(update: Update, context: ContextTypes.DEFAULT_TYPE) -> None: """Show all running tasks.""" chat_id = update.effective_chat.id - instances = _get_instances(context, chat_id) + instances = _get_instances(context) + + running = [(iid, inst) for iid, inst in instances.items() if inst.get("status") == "running"] - if not instances: + if not running: text = ( - "⚡ *RUNNING TASKS*\n" + "📋 *RUNNING TASKS*\n" "━━━━━━━━━━━━━━━━━━━━━\n\n" "No tasks running\\." ) keyboard = [[InlineKeyboardButton("« Back", callback_data="routines:menu")]] else: - lines = ["⚡ *RUNNING TASKS*", "━━━━━━━━━━━━━━━━━━━━━\n"] + lines = ["📋 *RUNNING TASKS*", "━━━━━━━━━━━━━━━━━━━━━\n"] keyboard = [] - for inst in instances: + for iid, inst in running: name = inst["routine_name"] - inst_id = inst["instance_id"] - duration = _format_duration(time.time() - inst["start_time"]) - config = inst["config"] - - lines.append(f"🟢 *{escape_markdown_v2(_display_name(name))}* `{inst_id}`") - lines.append(f" {escape_markdown_v2(duration)}") - - if config: - preview = ", ".join(f"{k}\\={v}" for k, v in list(config.items())[:2]) - lines.append(f" `{preview}`") + schedule = inst.get("schedule", {}) + created = inst.get("created_at", time.time()) + config = inst.get("config", {}) + run_count = inst.get("run_count", 0) + + lines.append(f"🟢 *{escape_markdown_v2(_display_name(name))}* `{iid}`") + lines.append(f" {escape_markdown_v2(_format_schedule(schedule))} \\| {escape_markdown_v2(_format_ago(created))}") + lines.append(f" Runs: {run_count} \\| `{escape_markdown_v2(_config_preview(config))}`") + + if inst.get("last_result"): + result_preview = inst["last_result"][:40].replace("\n", " ") + lines.append(f" └ {escape_markdown_v2(result_preview)}\\.\\.\\.") lines.append("") keyboard.append([ - InlineKeyboardButton(f"⏹ {_display_name(name)[:10]}[{inst_id}]", - callback_data=f"routines:stop:{inst['job_name']}") + InlineKeyboardButton(f"⏹ {_display_name(name)[:12]}[{iid}]", + callback_data=f"routines:stop:{iid}") ]) keyboard.append([InlineKeyboardButton("⏹ Stop All", callback_data="routines:stopall")]) @@ -426,60 +629,51 @@ async def _show_tasks(update: Update, context: ContextTypes.DEFAULT_TYPE) -> Non async def _show_detail(update: Update, context: ContextTypes.DEFAULT_TYPE, routine_name: str) -> None: - """Show routine configuration and controls.""" + """Show routine detail with draft config and running instances.""" chat_id = update.effective_chat.id routine = get_routine(routine_name) if not routine: - await update.callback_query.answer("Routine not found") + if update.callback_query: + await update.callback_query.answer("Routine not found") return - # Get or initialize config - config_key = f"routine_config_{routine_name}" - if config_key not in context.user_data: - context.user_data[config_key] = routine.get_default_config().model_dump() - - config = context.user_data[config_key] + # Get draft config + draft = _get_draft(context, routine_name) fields = routine.get_fields() - instances = _get_instances(context, chat_id, routine_name) + + # Get running instances for this routine + instances = _get_routine_instances(context, routine_name) + running = [(iid, inst) for iid, inst in instances if inst.get("status") == "running"] # Build config display - config_lines = [f"{k}={config.get(k, v['default'])}" for k, v in fields.items()] + config_lines = [f"{k}={draft.get(k, v['default'])}" for k, v in fields.items()] - # Status - if instances: - status = f"🟢 {len(instances)} running" + # Status line + if running: + status = f"🟢 {len(running)} running" else: status = "⚪ Ready" - # Instances section + # Type indicator + type_str = "♾️ Continuous" if routine.is_continuous else "⚡ One\\-shot" + + # Build instances section inst_section = "" - if instances: - inst_lines = ["\n┌─ Running ─────────────────"] - for inst in instances[:5]: - dur = _format_duration(time.time() - inst["start_time"]) - cfg = ", ".join(f"{k}={v}" for k, v in list(inst["config"].items())[:2]) - inst_lines.append(f"│ `{inst['instance_id']}` {escape_markdown_v2(cfg)} \\({escape_markdown_v2(dur)}\\)") - if len(instances) > 5: - inst_lines.append(f"│ _\\+{len(instances) - 5} more_") - inst_lines.append("└───────────────────────────") + if running: + inst_lines = ["\n┌─ Running Instances ────────"] + for iid, inst in running[:5]: + sched = _format_schedule(inst.get("schedule", {})) + ago = _format_ago(inst.get("created_at", time.time())) + cfg_prev = _config_preview(inst.get("config", {}), 1) + runs = inst.get("run_count", 0) + inst_lines.append(f"│ `{iid}` {escape_markdown_v2(cfg_prev)}") + inst_lines.append(f"│ {escape_markdown_v2(sched)} \\| {runs} runs \\| {escape_markdown_v2(ago)}") + if len(running) > 5: + inst_lines.append(f"│ _\\+{len(running) - 5} more_") + inst_lines.append("└────────────────────────────") inst_section = "\n".join(inst_lines) - # Result section - result_section = "" - last = _get_result(chat_id, routine_name) - if last: - icon = "❌" if last["result"].startswith("Error") else "✅" - dur = _format_duration(last["duration"]) - result_section = ( - f"\n\n┌─ {icon} Last ─ {escape_markdown_v2(dur)} ────\n" - f"```\n{last['result'][:250]}\n```\n" - f"└───────────────────────────" - ) - - # Type indicator - type_str = "🔄 Interval" if routine.is_interval else "⚡ One\\-shot" - text = ( f"⚡ *{escape_markdown_v2(_display_name(routine_name).upper())}*\n" f"━━━━━━━━━━━━━━━━━━━━━\n" @@ -490,37 +684,43 @@ async def _show_detail(update: Update, context: ContextTypes.DEFAULT_TYPE, routi f"```\n{chr(10).join(config_lines)}\n```\n" f"└─ _✏️ send key\\=value to edit_" f"{inst_section}" - f"{result_section}" ) - # Build keyboard - if routine.is_interval: + # Build keyboard based on routine type + if routine.is_continuous: + # Continuous routine - just start/stop keyboard = [ - [InlineKeyboardButton("▶️ Start", callback_data=f"routines:start:{routine_name}")], + [ + InlineKeyboardButton("▶️ Start", callback_data=f"routines:start:{routine_name}"), + ], ] else: + # One-shot routine - can be scheduled keyboard = [ [ InlineKeyboardButton("▶️ Run", callback_data=f"routines:run:{routine_name}"), InlineKeyboardButton("🔄 Background", callback_data=f"routines:bg:{routine_name}"), ], + [ + InlineKeyboardButton("⏱️ Schedule", callback_data=f"routines:sched:{routine_name}"), + ], ] - if instances: - keyboard.append([InlineKeyboardButton(f"⏹ Stop All ({len(instances)})", - callback_data=f"routines:stopall:{routine_name}")]) + if running: + keyboard.append([ + InlineKeyboardButton(f"⏹ Stop All ({len(running)})", callback_data=f"routines:stopall:{routine_name}") + ]) keyboard.append([ InlineKeyboardButton("❓ Help", callback_data=f"routines:help:{routine_name}"), InlineKeyboardButton("« Back", callback_data="routines:menu"), ]) - # Store state for config editing + # Store editing state context.user_data["routines_state"] = "editing" context.user_data["routines_editing"] = { "routine": routine_name, "fields": fields, - "config_key": config_key, } msg = update.callback_query.message if update.callback_query else None @@ -531,6 +731,72 @@ async def _show_detail(update: Update, context: ContextTypes.DEFAULT_TYPE, routi await _edit_or_send(update, text, InlineKeyboardMarkup(keyboard)) +async def _show_schedule_menu(update: Update, context: ContextTypes.DEFAULT_TYPE, routine_name: str) -> None: + """Show schedule options menu.""" + routine = get_routine(routine_name) + if not routine: + return + + text = ( + f"⏱️ *Schedule: {escape_markdown_v2(_display_name(routine_name))}*\n" + f"━━━━━━━━━━━━━━━━━━━━━\n\n" + f"Choose how often to run this routine\\.\n" + f"Config will be frozen at schedule time\\.\n" + f"Results will be sent as messages\\." + ) + + # Interval buttons - 3 per row + row1 = [ + InlineKeyboardButton(label, callback_data=f"routines:interval:{routine_name}:{secs}") + for label, secs in SCHEDULE_PRESETS[:3] + ] + row2 = [ + InlineKeyboardButton(label, callback_data=f"routines:interval:{routine_name}:{secs}") + for label, secs in SCHEDULE_PRESETS[3:6] + ] + + keyboard = [ + row1, + row2, + [InlineKeyboardButton("📅 Daily...", callback_data=f"routines:daily:{routine_name}")], + [InlineKeyboardButton("« Cancel", callback_data=f"routines:select:{routine_name}")], + ] + + await _edit_or_send(update, text, InlineKeyboardMarkup(keyboard)) + + +async def _show_daily_menu(update: Update, context: ContextTypes.DEFAULT_TYPE, routine_name: str) -> None: + """Show daily schedule time options.""" + text = ( + f"📅 *Daily Schedule: {escape_markdown_v2(_display_name(routine_name))}*\n" + f"━━━━━━━━━━━━━━━━━━━━━\n\n" + f"Choose a time \\(server timezone\\)\\.\n" + f"Or send custom time as `HH:MM`\\." + ) + + # Time preset buttons + row1 = [ + InlineKeyboardButton(t, callback_data=f"routines:dailyat:{routine_name}:{t}") + for t in DAILY_PRESETS[:3] + ] + row2 = [ + InlineKeyboardButton(t, callback_data=f"routines:dailyat:{routine_name}:{t}") + for t in DAILY_PRESETS[3:] + ] + + keyboard = [ + row1, + row2, + [InlineKeyboardButton("« Back", callback_data=f"routines:sched:{routine_name}")], + ] + + # Store state for custom time input + context.user_data["routines_state"] = "daily_time" + context.user_data["routines_editing"] = {"routine": routine_name} + + await _edit_or_send(update, text, InlineKeyboardMarkup(keyboard)) + + async def _show_help(update: Update, context: ContextTypes.DEFAULT_TYPE, routine_name: str) -> None: """Show field descriptions.""" routine = get_routine(routine_name) @@ -551,18 +817,113 @@ async def _show_help(update: Update, context: ContextTypes.DEFAULT_TYPE, routine await _edit_or_send(update, "\n".join(lines), InlineKeyboardMarkup(keyboard)) +async def _refresh_detail_msg( + context: CallbackContext, + chat_id: int, + msg_id: int, + routine_name: str, + result: str | None = None, + duration: float | None = None, +) -> None: + """Refresh the routine detail message after execution.""" + routine = get_routine(routine_name) + if not routine: + return + + user_data = context.application.user_data.get(chat_id, {}) + drafts = user_data.get("routine_drafts", {}) + draft = drafts.get(routine_name, {}) + + if not draft: + draft = routine.get_default_config().model_dump() + + fields = routine.get_fields() + config_lines = [f"{k}={draft.get(k, v['default'])}" for k, v in fields.items()] + + # Get instances + instances = user_data.get("routine_instances", {}) + running = [(iid, inst) for iid, inst in instances.items() + if inst.get("routine_name") == routine_name and inst.get("status") == "running"] + + status = f"🟢 {len(running)} running" if running else "⚪ Ready" + type_str = "♾️ Continuous" if routine.is_continuous else "⚡ One\\-shot" + + # Result section + result_section = "" + if result is not None: + icon = "❌" if result.startswith("Error") else "✅" + dur_str = _format_duration(duration) if duration else "" + result_section = ( + f"\n\n┌─ {icon} Result ─ {escape_markdown_v2(dur_str)} ────\n" + f"```\n{result[:250]}\n```\n" + f"└────────────────────────────" + ) + + text = ( + f"⚡ *{escape_markdown_v2(_display_name(routine_name).upper())}*\n" + f"━━━━━━━━━━━━━━━━━━━━━\n" + f"_{escape_markdown_v2(routine.description)}_\n" + f"{type_str}\n\n" + f"Status: {escape_markdown_v2(status)}\n\n" + f"┌─ Config ─────────────────\n" + f"```\n{chr(10).join(config_lines)}\n```\n" + f"└─ _✏️ send key\\=value to edit_" + f"{result_section}" + ) + + # Build keyboard based on routine type + if routine.is_continuous: + keyboard = [ + [ + InlineKeyboardButton("▶️ Start", callback_data=f"routines:start:{routine_name}"), + ], + ] + else: + keyboard = [ + [ + InlineKeyboardButton("▶️ Run", callback_data=f"routines:run:{routine_name}"), + InlineKeyboardButton("🔄 Background", callback_data=f"routines:bg:{routine_name}"), + ], + [ + InlineKeyboardButton("⏱️ Schedule", callback_data=f"routines:sched:{routine_name}"), + ], + ] + + if running: + keyboard.append([ + InlineKeyboardButton(f"⏹ Stop All ({len(running)})", callback_data=f"routines:stopall:{routine_name}") + ]) + + keyboard.append([ + InlineKeyboardButton("❓ Help", callback_data=f"routines:help:{routine_name}"), + InlineKeyboardButton("« Back", callback_data="routines:menu"), + ]) + + try: + await context.bot.edit_message_text( + chat_id=chat_id, + message_id=msg_id, + text=text, + parse_mode="MarkdownV2", + reply_markup=InlineKeyboardMarkup(keyboard), + ) + except Exception as e: + if "not modified" not in str(e).lower(): + logger.debug(f"Could not refresh: {e}") + + # ============================================================================= # Actions # ============================================================================= -async def _run_oneshot( +async def _run_once( update: Update, context: ContextTypes.DEFAULT_TYPE, routine_name: str, background: bool = False, ) -> None: - """Run a one-shot routine.""" + """Run one-shot routine once with current draft config.""" chat_id = update.effective_chat.id routine = get_routine(routine_name) @@ -570,50 +931,70 @@ async def _run_oneshot( await update.callback_query.answer("Routine not found") return - config_key = f"routine_config_{routine_name}" - config_dict = context.user_data.get(config_key, {}) + if routine.is_continuous: + await update.callback_query.answer("Use Start for continuous routines") + return + + draft = _get_draft(context, routine_name) try: - routine.config_class(**config_dict) + routine.config_class(**draft) except Exception as e: await update.callback_query.answer(f"Config error: {e}") return - instance_id = _generate_instance_id(routine_name, config_dict) - job = _job_name(chat_id, routine_name, instance_id) msg_id = context.user_data.get("routines_msg_id") + schedule = {"type": "once"} - _job_info[job] = { - "start_time": time.time(), - "config": config_dict, - "routine_name": routine_name, - } - - context.job_queue.run_once( - _oneshot_callback, - when=0.1, - data={ - "routine_name": routine_name, - "chat_id": chat_id, - "config_dict": config_dict, - "instance_id": instance_id, - "msg_id": msg_id, - "background": background, - }, - name=job, - chat_id=chat_id, - ) + _create_scheduled_instance(context, chat_id, routine_name, draft, schedule, msg_id, background) if background: await update.callback_query.answer("🔄 Running in background...") + await _show_detail(update, context, routine_name) else: + # Don't refresh detail view - job callback will update it with result await update.callback_query.answer("▶️ Running...") + +async def _start_continuous( + update: Update, + context: ContextTypes.DEFAULT_TYPE, + routine_name: str, +) -> None: + """Start a continuous routine.""" + chat_id = update.effective_chat.id + routine = get_routine(routine_name) + + if not routine: + await update.callback_query.answer("Routine not found") + return + + if not routine.is_continuous: + await update.callback_query.answer("Not a continuous routine") + return + + draft = _get_draft(context, routine_name) + + try: + routine.config_class(**draft) + except Exception as e: + await update.callback_query.answer(f"Config error: {e}") + return + + instance_id = _create_continuous_instance(context, chat_id, routine_name, draft) + logger.info(f"Started continuous routine {instance_id}: {routine_name}") + + await update.callback_query.answer("▶️ Started") await _show_detail(update, context, routine_name) -async def _start_interval(update: Update, context: ContextTypes.DEFAULT_TYPE, routine_name: str) -> None: - """Start an interval routine.""" +async def _start_interval( + update: Update, + context: ContextTypes.DEFAULT_TYPE, + routine_name: str, + interval_sec: int, +) -> None: + """Start one-shot routine with interval schedule.""" chat_id = update.effective_chat.id routine = get_routine(routine_name) @@ -621,49 +1002,58 @@ async def _start_interval(update: Update, context: ContextTypes.DEFAULT_TYPE, ro await update.callback_query.answer("Routine not found") return - config_key = f"routine_config_{routine_name}" - config_dict = context.user_data.get(config_key, {}) + draft = _get_draft(context, routine_name) try: - config = routine.config_class(**config_dict) + routine.config_class(**draft) except Exception as e: await update.callback_query.answer(f"Config error: {e}") return - instance_id = _generate_instance_id(routine_name, config_dict) - job = _job_name(chat_id, routine_name, instance_id) + schedule = {"type": "interval", "interval_sec": interval_sec} + instance_id = _create_scheduled_instance(context, chat_id, routine_name, draft, schedule) + logger.info(f"Created interval schedule {instance_id} for {routine_name}: every {interval_sec}s") - # Check duplicate - if context.job_queue.get_jobs_by_name(job): - await update.callback_query.answer("⚠️ Already running with this config") - await _show_detail(update, context, routine_name) + await update.callback_query.answer(f"⏱️ Scheduled every {_format_schedule(schedule)}") + await _show_detail(update, context, routine_name) + + +async def _start_daily( + update: Update, + context: ContextTypes.DEFAULT_TYPE, + routine_name: str, + time_str: str, +) -> None: + """Start routine with daily schedule.""" + chat_id = update.effective_chat.id + routine = get_routine(routine_name) + + if not routine: + await update.callback_query.answer("Routine not found") return - interval = getattr(config, "interval_sec", 5) - msg_id = context.user_data.get("routines_msg_id") + draft = _get_draft(context, routine_name) - _job_info[job] = { - "start_time": time.time(), - "config": config_dict, - "routine_name": routine_name, - } + try: + routine.config_class(**draft) + except Exception as e: + await update.callback_query.answer(f"Config error: {e}") + return - context.job_queue.run_repeating( - _interval_callback, - interval=interval, - first=0.1, - data={ - "routine_name": routine_name, - "chat_id": chat_id, - "config_dict": config_dict, - "instance_id": instance_id, - "msg_id": msg_id, - }, - name=job, - chat_id=chat_id, - ) + # Validate time format + try: + hour, minute = map(int, time_str.split(":")) + if not (0 <= hour < 24 and 0 <= minute < 60): + raise ValueError() + except (ValueError, AttributeError): + await update.callback_query.answer(f"Invalid time: {time_str}") + return + + schedule = {"type": "daily", "daily_time": time_str} + instance_id = _create_scheduled_instance(context, chat_id, routine_name, draft, schedule) + logger.info(f"Created daily schedule {instance_id} for {routine_name}: at {time_str}") - await update.callback_query.answer(f"🔄 Started (every {interval}s)") + await update.callback_query.answer(f"📅 Scheduled daily at {time_str}") await _show_detail(update, context, routine_name) @@ -677,9 +1067,8 @@ async def _process_config(update: Update, context: ContextTypes.DEFAULT_TYPE, te editing = context.user_data.get("routines_editing", {}) routine_name = editing.get("routine") fields = editing.get("fields", {}) - config_key = editing.get("config_key") - if not routine_name or not config_key: + if not routine_name: return # Delete user message @@ -688,6 +1077,14 @@ async def _process_config(update: Update, context: ContextTypes.DEFAULT_TYPE, te except Exception: pass + routine = get_routine(routine_name) + if not routine: + return + + if not fields: + fields = routine.get_fields() + + draft = _get_draft(context, routine_name) updates = {} errors = [] @@ -726,11 +1123,8 @@ async def _process_config(update: Update, context: ContextTypes.DEFAULT_TYPE, te asyncio.create_task(_delete_after(msg, 3)) return - if config_key not in context.user_data: - routine = get_routine(routine_name) - context.user_data[config_key] = routine.get_default_config().model_dump() - - context.user_data[config_key].update(updates) + draft.update(updates) + _set_draft(context, routine_name, draft) msg = await update.message.reply_text( f"✅ {', '.join(f'`{k}={v}`' for k, v in updates.items())}", @@ -738,87 +1132,66 @@ async def _process_config(update: Update, context: ContextTypes.DEFAULT_TYPE, te ) asyncio.create_task(_delete_after(msg, 2)) - await _refresh_detail(context, routine_name) - - -async def _refresh_detail(context: ContextTypes.DEFAULT_TYPE, routine_name: str) -> None: - """Refresh routine detail after config update.""" + # Refresh detail view msg_id = context.user_data.get("routines_msg_id") chat_id = context.user_data.get("routines_chat_id") + if msg_id and chat_id: + await _refresh_detail_msg(context, chat_id, msg_id, routine_name) - if not msg_id or not chat_id: - return - routine = get_routine(routine_name) - if not routine: +async def _process_daily_time(update: Update, context: ContextTypes.DEFAULT_TYPE, text: str) -> None: + """Process custom daily time input (HH:MM).""" + editing = context.user_data.get("routines_editing", {}) + routine_name = editing.get("routine") + + if not routine_name: return - config_key = f"routine_config_{routine_name}" - config = context.user_data.get(config_key, {}) - fields = routine.get_fields() - instances = _get_instances(context, chat_id, routine_name) + # Delete user message + try: + await update.message.delete() + except Exception: + pass - config_lines = [f"{k}={config.get(k, v['default'])}" for k, v in fields.items()] + # Validate time format + text = text.strip() + try: + hour, minute = map(int, text.split(":")) + if not (0 <= hour < 24 and 0 <= minute < 60): + raise ValueError() + time_str = f"{hour:02d}:{minute:02d}" + except (ValueError, AttributeError): + msg = await update.message.reply_text(f"❌ Invalid time. Use `HH:MM` format.", parse_mode="Markdown") + asyncio.create_task(_delete_after(msg, 3)) + return - status = f"🟢 {len(instances)} running" if instances else "⚪ Ready" - type_str = "🔄 Interval" if routine.is_interval else "⚡ One\\-shot" + # Create daily schedule + chat_id = update.effective_chat.id + routine = get_routine(routine_name) + if not routine: + return - # Result section - result_section = "" - last = _get_result(chat_id, routine_name) - if last: - icon = "❌" if last["result"].startswith("Error") else "✅" - dur = _format_duration(last["duration"]) - result_section = ( - f"\n\n┌─ {icon} Last ─ {escape_markdown_v2(dur)} ────\n" - f"```\n{last['result'][:250]}\n```\n" - f"└───────────────────────────" - ) + draft = _get_draft(context, routine_name) - text = ( - f"⚡ *{escape_markdown_v2(_display_name(routine_name).upper())}*\n" - f"━━━━━━━━━━━━━━━━━━━━━\n" - f"_{escape_markdown_v2(routine.description)}_\n" - f"{type_str}\n\n" - f"Status: {escape_markdown_v2(status)}\n\n" - f"┌─ Config ─────────────────\n" - f"```\n{chr(10).join(config_lines)}\n```\n" - f"└─ _✏️ send key\\=value to edit_" - f"{result_section}" - ) + try: + routine.config_class(**draft) + except Exception as e: + msg = await update.message.reply_text(f"❌ Config error: {e}") + asyncio.create_task(_delete_after(msg, 3)) + return - if routine.is_interval: - keyboard = [ - [InlineKeyboardButton("▶️ Start", callback_data=f"routines:start:{routine_name}")], - ] - else: - keyboard = [ - [ - InlineKeyboardButton("▶️ Run", callback_data=f"routines:run:{routine_name}"), - InlineKeyboardButton("🔄 Background", callback_data=f"routines:bg:{routine_name}"), - ], - ] + schedule = {"type": "daily", "daily_time": time_str} + _create_scheduled_instance(context, chat_id, routine_name, draft, schedule) - if instances: - keyboard.append([InlineKeyboardButton(f"⏹ Stop All ({len(instances)})", - callback_data=f"routines:stopall:{routine_name}")]) + msg = await update.message.reply_text(f"📅 Scheduled daily at {time_str}") + asyncio.create_task(_delete_after(msg, 2)) - keyboard.append([ - InlineKeyboardButton("❓ Help", callback_data=f"routines:help:{routine_name}"), - InlineKeyboardButton("« Back", callback_data="routines:menu"), - ]) + context.user_data["routines_state"] = "editing" - try: - await context.bot.edit_message_text( - chat_id=chat_id, - message_id=msg_id, - text=text, - parse_mode="MarkdownV2", - reply_markup=InlineKeyboardMarkup(keyboard), - ) - except Exception as e: - if "not modified" not in str(e).lower(): - logger.debug(f"Could not refresh: {e}") + # Refresh detail + msg_id = context.user_data.get("routines_msg_id") + if msg_id: + await _refresh_detail_msg(context, chat_id, msg_id, routine_name) # ============================================================================= @@ -826,21 +1199,6 @@ async def _refresh_detail(context: ContextTypes.DEFAULT_TYPE, routine_name: str) # ============================================================================= -async def _edit_or_send(update: Update, text: str, reply_markup: InlineKeyboardMarkup) -> None: - """Edit message if callback, otherwise send new.""" - if update.callback_query: - try: - await update.callback_query.message.edit_text( - text, parse_mode="MarkdownV2", reply_markup=reply_markup - ) - except Exception as e: - if "not modified" not in str(e).lower(): - logger.warning(f"Edit failed: {e}") - else: - msg = update.message or update.callback_query.message - await msg.reply_text(text, parse_mode="MarkdownV2", reply_markup=reply_markup) - - async def _delete_after(message, seconds: float) -> None: """Delete message after delay.""" await asyncio.sleep(seconds) @@ -895,29 +1253,56 @@ async def routines_callback_handler(update: Update, context: ContextTypes.DEFAUL await _show_detail(update, context, parts[2]) elif action == "run" and len(parts) >= 3: - await _run_oneshot(update, context, parts[2], background=False) + await _run_once(update, context, parts[2], background=False) elif action == "bg" and len(parts) >= 3: - await _run_oneshot(update, context, parts[2], background=True) + await _run_once(update, context, parts[2], background=True) elif action == "start" and len(parts) >= 3: - await _start_interval(update, context, parts[2]) + # Start continuous routine + await _start_continuous(update, context, parts[2]) + + elif action == "sched" and len(parts) >= 3: + await query.answer() + await _show_schedule_menu(update, context, parts[2]) + + elif action == "interval" and len(parts) >= 4: + routine_name = parts[2] + interval_sec = int(parts[3]) + await _start_interval(update, context, routine_name, interval_sec) + + elif action == "daily" and len(parts) >= 3: + await query.answer() + await _show_daily_menu(update, context, parts[2]) + + elif action == "dailyat" and len(parts) >= 4: + routine_name = parts[2] + time_str = parts[3] + await _start_daily(update, context, routine_name, time_str) elif action == "stop" and len(parts) >= 3: - job_name = ":".join(parts[2:]) - if _stop_instance(context, chat_id, job_name): + instance_id = parts[2] + logger.info(f"User {chat_id} stopping instance {instance_id}") + if _stop_instance(context, chat_id, instance_id): await query.answer("⏹ Stopped") else: await query.answer("Not found") await _show_tasks(update, context) elif action == "stopall" and len(parts) >= 3: - count = _stop_all(context, chat_id, parts[2]) + routine_name = parts[2] + logger.info(f"User {chat_id} stopping all instances of {routine_name}") + count = _stop_all_routine(context, chat_id, routine_name) await query.answer(f"⏹ Stopped {count}") - await _show_detail(update, context, parts[2]) + await _show_detail(update, context, routine_name) elif action == "stopall": - count = _stop_all(context, chat_id) + logger.info(f"User {chat_id} stopping ALL instances") + instances = _get_instances(context) + count = 0 + for iid in list(instances.keys()): + if _stop_instance(context, chat_id, iid): + count += 1 await query.answer(f"⏹ Stopped {count}") await _show_tasks(update, context) @@ -926,15 +1311,145 @@ async def routines_callback_handler(update: Update, context: ContextTypes.DEFAUL await _show_help(update, context, parts[2]) else: - await query.answer() + # Check for routine-specific callbacks + # Pattern: routines:{routine_name}:{action}:{params...} + routine = get_routine(action) + if routine and routine.callback_handler and len(parts) >= 3: + routine_action = parts[2] + routine_params = parts[3:] if len(parts) > 3 else [] + await routine.callback_handler(update, context, routine_action, routine_params) + else: + await query.answer() async def routines_message_handler(update: Update, context: ContextTypes.DEFAULT_TYPE) -> bool: - """Handle text input for config editing.""" - if context.user_data.get("routines_state") == "editing": + """Handle text input for config editing, daily time, or routine-specific messages.""" + state = context.user_data.get("routines_state") + + if state == "editing": await _process_config(update, context, update.message.text.strip()) return True + elif state == "daily_time": + await _process_daily_time(update, context, update.message.text.strip()) + return True + + # Check for routine-specific message handlers + routine = get_routine_by_state(state) + if routine and routine.message_handler: + return await routine.message_handler(update, context) + return False -__all__ = ["routines_command", "routines_callback_handler", "routines_message_handler"] +async def restore_scheduled_jobs(application) -> int: + """ + Restore scheduled jobs from persisted instances after bot restart. + Call this during application startup (post_init). + Returns count of restored jobs. + """ + restored = 0 + removed = 0 + + for chat_id, user_data in application.user_data.items(): + instances = user_data.get("routine_instances", {}) + if not instances: + continue + + to_remove = [] + + for instance_id, inst in instances.items(): + if inst.get("status") != "running": + continue + + routine_name = inst.get("routine_name") + config_dict = inst.get("config", {}) + schedule = inst.get("schedule", {}) + stype = schedule.get("type", "once") + + # Check if routine still exists + routine = get_routine(routine_name) + if not routine: + logger.warning(f"Routine {routine_name} no longer exists, removing instance {instance_id}") + to_remove.append(instance_id) + continue + + # One-time jobs that didn't complete - remove them + if stype == "once": + to_remove.append(instance_id) + continue + + # Continuous routines need to be restarted as asyncio tasks + if stype == "continuous": + try: + task = asyncio.create_task( + _run_continuous_routine( + application, + instance_id, + routine_name, + config_dict, + chat_id, + ) + ) + _continuous_tasks[instance_id] = task + restored += 1 + logger.info(f"Restored continuous routine {instance_id}: {routine_name}") + except Exception as e: + logger.error(f"Failed to restore continuous routine {instance_id}: {e}") + to_remove.append(instance_id) + continue + + # Re-create scheduled jobs + job_name_str = _job_name(chat_id, instance_id) + job_data = { + "instance_id": instance_id, + "routine_name": routine_name, + "config_dict": config_dict, + "chat_id": chat_id, + } + + try: + if stype == "interval": + interval = schedule.get("interval_sec", 60) + application.job_queue.run_repeating( + _interval_job_callback, + interval=interval, + first=min(interval, 10), + data=job_data, + name=job_name_str, + chat_id=chat_id, + ) + restored += 1 + logger.info(f"Restored interval job {instance_id} for {routine_name} (every {interval}s)") + + elif stype == "daily": + time_str = schedule.get("daily_time", "09:00") + hour, minute = map(int, time_str.split(":")) + application.job_queue.run_daily( + _daily_job_callback, + time=dt_time(hour=hour, minute=minute), + data=job_data, + name=job_name_str, + chat_id=chat_id, + ) + restored += 1 + logger.info(f"Restored daily job {instance_id} for {routine_name} (at {time_str})") + + else: + to_remove.append(instance_id) + + except Exception as e: + logger.error(f"Failed to restore job {instance_id}: {e}") + to_remove.append(instance_id) + + # Clean up stale instances + for instance_id in to_remove: + del instances[instance_id] + removed += 1 + + if restored > 0 or removed > 0: + logger.info(f"Routine jobs: restored {restored}, removed {removed} stale") + + return restored + + +__all__ = ["routines_command", "routines_callback_handler", "routines_message_handler", "restore_scheduled_jobs"] diff --git a/handlers/signals/__init__.py b/handlers/signals/__init__.py new file mode 100644 index 0000000..6ee0713 --- /dev/null +++ b/handlers/signals/__init__.py @@ -0,0 +1,1216 @@ +""" +Signals Handler - ML prediction pipelines via Telegram. + +Features: +- Auto-discovery of signals from signals/ folder +- Training and prediction pipelines +- Text-based config editing (key=value) +- Background execution for training +- Scheduling for predictions +- SQLite storage for prediction history +""" + +import asyncio +import hashlib +import logging +import time +from datetime import time as dt_time + +from telegram import Update, InlineKeyboardButton, InlineKeyboardMarkup +from telegram.ext import ContextTypes, CallbackContext + +from handlers import clear_all_input_states +from signals.base import discover_signals, get_signal, get_latest_model_path +from signals.db import get_signals_db +from utils.auth import restricted +from utils.telegram_formatters import escape_markdown_v2 + +logger = logging.getLogger(__name__) + + +# ============================================================================= +# Constants +# ============================================================================= + +SCHEDULE_PRESETS = [ + ("30s", 30), + ("1m", 60), + ("5m", 300), + ("15m", 900), + ("30m", 1800), + ("1h", 3600), +] + +DAILY_PRESETS = ["06:00", "09:00", "12:00", "18:00", "21:00"] + +# Global storage for running tasks (not persisted) +_running_tasks: dict[str, asyncio.Task] = {} # instance_id -> Task + + +# ============================================================================= +# Storage Helpers +# ============================================================================= + + +def _get_drafts(context: ContextTypes.DEFAULT_TYPE) -> dict: + """Get draft configs dict.""" + if "signals_drafts" not in context.user_data: + context.user_data["signals_drafts"] = {} + return context.user_data["signals_drafts"] + + +def _get_instances(context: ContextTypes.DEFAULT_TYPE) -> dict: + """Get all instances dict.""" + if "signals_instances" not in context.user_data: + context.user_data["signals_instances"] = {} + return context.user_data["signals_instances"] + + +def _get_draft( + context: ContextTypes.DEFAULT_TYPE, signal_name: str, pipeline: str +) -> dict: + """Get draft config for a signal pipeline, initializing from defaults if needed.""" + drafts = _get_drafts(context) + key = f"{signal_name}:{pipeline}" + if key not in drafts: + signal = get_signal(signal_name) + if signal: + pipe = signal.train_pipeline if pipeline == "train" else signal.predict_pipeline + if pipe: + drafts[key] = pipe.get_default_config().model_dump() + else: + drafts[key] = {} + else: + drafts[key] = {} + return drafts[key] + + +def _set_draft( + context: ContextTypes.DEFAULT_TYPE, signal_name: str, pipeline: str, config: dict +) -> None: + """Update draft config for a signal pipeline.""" + drafts = _get_drafts(context) + key = f"{signal_name}:{pipeline}" + drafts[key] = config + + +def _get_signal_instances( + context: ContextTypes.DEFAULT_TYPE, signal_name: str +) -> list[tuple[str, dict]]: + """Get all instances for a specific signal.""" + instances = _get_instances(context) + return [ + (iid, inst) + for iid, inst in instances.items() + if inst.get("signal_name") == signal_name + ] + + +def _generate_instance_id() -> str: + """Generate a short unique instance ID.""" + return hashlib.md5(f"{time.time()}{id(object())}".encode()).hexdigest()[:6] + + +# ============================================================================= +# Formatting Helpers +# ============================================================================= + + +def _display_name(name: str) -> str: + """Convert snake_case to Title Case.""" + return name.replace("_", " ").title() + + +def _format_duration(seconds: float) -> str: + """Format seconds as human-readable duration.""" + if seconds < 1: + return f"{seconds * 1000:.0f}ms" + if seconds < 60: + return f"{seconds:.1f}s" + if seconds < 3600: + return f"{int(seconds // 60)}m {int(seconds % 60)}s" + return f"{int(seconds // 3600)}h {int((seconds % 3600) // 60)}m" + + +def _format_schedule(schedule: dict) -> str: + """Format schedule as human-readable string.""" + stype = schedule.get("type", "once") + if stype == "once": + return "One-time" + elif stype == "interval": + secs = schedule.get("interval_sec", 60) + if secs < 60: + return f"Every {secs}s" + elif secs < 3600: + return f"Every {secs // 60}m" + else: + return f"Every {secs // 3600}h" + elif stype == "daily": + return f"Daily @ {schedule.get('daily_time', '09:00')}" + return "Unknown" + + +def _format_ago(timestamp: float) -> str: + """Format timestamp as 'X ago' string.""" + diff = time.time() - timestamp + if diff < 60: + return f"{int(diff)}s ago" + elif diff < 3600: + return f"{int(diff // 60)}m ago" + elif diff < 86400: + return f"{int(diff // 3600)}h ago" + return f"{int(diff // 86400)}d ago" + + +def _config_preview(config: dict, max_items: int = 2) -> str: + """Get short config preview.""" + items = list(config.items())[:max_items] + return ", ".join(f"{k}={v}" for k, v in items) + + +# ============================================================================= +# Job Management +# ============================================================================= + + +def _job_name(chat_id: int, instance_id: str) -> str: + """Build job name for JobQueue.""" + return f"signal_{chat_id}_{instance_id}" + + +def _find_job(context: ContextTypes.DEFAULT_TYPE, chat_id: int, instance_id: str): + """Find a job by instance ID.""" + name = _job_name(chat_id, instance_id) + jobs = context.job_queue.get_jobs_by_name(name) + return jobs[0] if jobs else None + + +def _stop_instance( + context: ContextTypes.DEFAULT_TYPE, chat_id: int, instance_id: str +) -> bool: + """Stop a job/task and remove instance. Returns True if found.""" + # Try to stop JobQueue job + job = _find_job(context, chat_id, instance_id) + if job: + job.schedule_removal() + logger.info(f"Removed scheduled job for instance {instance_id}") + + # Try to cancel asyncio task + task = _running_tasks.pop(instance_id, None) + if task and not task.done(): + task.cancel() + logger.info(f"Cancelled task for instance {instance_id}") + + # Remove from instances + instances = _get_instances(context) + if instance_id in instances: + signal_name = instances[instance_id].get("signal_name", "unknown") + del instances[instance_id] + logger.info(f"Stopped instance {instance_id} ({signal_name})") + return True + return False + + +# ============================================================================= +# Pipeline Execution +# ============================================================================= + + +async def _execute_pipeline( + context: ContextTypes.DEFAULT_TYPE, + instance_id: str, + signal_name: str, + pipeline: str, + config_dict: dict, + chat_id: int, +) -> tuple[str, float]: + """Execute a pipeline and return (result, duration).""" + signal = get_signal(signal_name) + if not signal: + return f"Signal {signal_name} not found", 0 + + pipe = signal.train_pipeline if pipeline == "train" else signal.predict_pipeline + if not pipe: + return f"Pipeline {pipeline} not found for {signal_name}", 0 + + # Prepare context + context._chat_id = chat_id + context._instance_id = instance_id + context._user_data = context.user_data if hasattr(context, 'user_data') else {} + + start = time.time() + try: + config = pipe.config_class(**config_dict) + result = await pipe.run_fn(config, context) + result_text = str(result)[:2000] # Truncate long results + except Exception as e: + logger.error(f"Pipeline {pipeline} failed: {e}", exc_info=True) + result_text = f"Error: {e}" + + duration = time.time() - start + return result_text, duration + + +async def _run_pipeline_background( + application, + instance_id: str, + signal_name: str, + pipeline: str, + config_dict: dict, + chat_id: int, +) -> None: + """Run a pipeline as a background task.""" + # Create a mock context + class MockContext: + def __init__(self): + self._chat_id = chat_id + self._instance_id = instance_id + self._user_data = application.user_data.get(chat_id, {}) + self.bot = application.bot + self.application = application + self.user_data = self._user_data + + context = MockContext() + + try: + result, duration = await _execute_pipeline( + context, instance_id, signal_name, pipeline, config_dict, chat_id + ) + + # Update instance + instances = application.user_data.get(chat_id, {}).get("signals_instances", {}) + if instance_id in instances: + instances[instance_id]["last_result"] = result + instances[instance_id]["last_duration"] = duration + instances[instance_id]["last_run_at"] = time.time() + instances[instance_id]["run_count"] = instances[instance_id].get("run_count", 0) + 1 + + # Send result message + result_preview = result[:500] if len(result) > 500 else result + await application.bot.send_message( + chat_id, + f"*{escape_markdown_v2(signal_name.upper())} \\- {pipeline.upper()}*\n\n" + f"```\n{escape_markdown_v2(result_preview)}\n```\n\n" + f"Duration: {escape_markdown_v2(_format_duration(duration))}", + parse_mode="MarkdownV2", + ) + + except asyncio.CancelledError: + logger.info(f"Pipeline {signal_name}:{pipeline} cancelled") + except Exception as e: + logger.error(f"Background pipeline failed: {e}", exc_info=True) + try: + await application.bot.send_message( + chat_id, f"Pipeline {signal_name}:{pipeline} failed: {e}" + ) + except Exception: + pass + + +def _create_background_instance( + context: ContextTypes.DEFAULT_TYPE, + chat_id: int, + signal_name: str, + pipeline: str, + config_dict: dict, +) -> str: + """Create a background execution instance. Returns instance_id.""" + instance_id = _generate_instance_id() + instances = _get_instances(context) + + instances[instance_id] = { + "signal_name": signal_name, + "pipeline": pipeline, + "config": config_dict.copy(), + "schedule": {"type": "once"}, + "status": "running", + "created_at": time.time(), + "run_count": 0, + } + + # Create task + task = asyncio.create_task( + _run_pipeline_background( + context.application, instance_id, signal_name, pipeline, config_dict, chat_id + ) + ) + _running_tasks[instance_id] = task + + return instance_id + + +# ============================================================================= +# Scheduled Execution +# ============================================================================= + + +async def _interval_job_callback(context: CallbackContext) -> None: + """Callback for interval-scheduled pipelines.""" + job_data = context.job.data + chat_id = job_data["chat_id"] + instance_id = job_data["instance_id"] + signal_name = job_data["signal_name"] + pipeline = job_data["pipeline"] + config_dict = job_data["config"] + + instances = context.application.user_data.get(chat_id, {}).get("signals_instances", {}) + if instance_id not in instances: + context.job.schedule_removal() + return + + # Run pipeline + result, duration = await _execute_pipeline( + context, instance_id, signal_name, pipeline, config_dict, chat_id + ) + + # Update instance + instances[instance_id]["last_result"] = result + instances[instance_id]["last_duration"] = duration + instances[instance_id]["last_run_at"] = time.time() + instances[instance_id]["run_count"] = instances[instance_id].get("run_count", 0) + 1 + + # Send result + result_preview = result[:300] if len(result) > 300 else result + try: + await context.bot.send_message( + chat_id, + f"*{escape_markdown_v2(signal_name.upper())} \\- {pipeline.upper()}*\n\n" + f"```\n{escape_markdown_v2(result_preview)}\n```", + parse_mode="MarkdownV2", + ) + except Exception as e: + logger.error(f"Failed to send result: {e}") + + +def _create_scheduled_instance( + context: ContextTypes.DEFAULT_TYPE, + chat_id: int, + signal_name: str, + pipeline: str, + config_dict: dict, + schedule: dict, +) -> str: + """Create a scheduled instance. Returns instance_id.""" + instance_id = _generate_instance_id() + instances = _get_instances(context) + + instances[instance_id] = { + "signal_name": signal_name, + "pipeline": pipeline, + "config": config_dict.copy(), + "schedule": schedule, + "status": "running", + "created_at": time.time(), + "run_count": 0, + } + + job_data = { + "chat_id": chat_id, + "instance_id": instance_id, + "signal_name": signal_name, + "pipeline": pipeline, + "config": config_dict.copy(), + } + + job_name = _job_name(chat_id, instance_id) + stype = schedule.get("type") + + if stype == "interval": + interval = schedule.get("interval_sec", 60) + context.job_queue.run_repeating( + _interval_job_callback, + interval=interval, + first=interval, + data=job_data, + name=job_name, + chat_id=chat_id, + ) + elif stype == "daily": + time_str = schedule.get("daily_time", "09:00") + hour, minute = map(int, time_str.split(":")) + context.job_queue.run_daily( + _interval_job_callback, + time=dt_time(hour=hour, minute=minute), + data=job_data, + name=job_name, + chat_id=chat_id, + ) + + return instance_id + + +# ============================================================================= +# UI Display +# ============================================================================= + + +async def _edit_or_send( + update: Update, text: str, reply_markup: InlineKeyboardMarkup +) -> None: + """Edit message if callback, otherwise send new.""" + if update.callback_query: + try: + await update.callback_query.message.edit_text( + text, parse_mode="MarkdownV2", reply_markup=reply_markup + ) + except Exception as e: + if "not modified" not in str(e).lower(): + logger.warning(f"Edit failed: {e}") + else: + msg = update.message or update.callback_query.message + await msg.reply_text(text, parse_mode="MarkdownV2", reply_markup=reply_markup) + + +async def _show_menu(update: Update, context: ContextTypes.DEFAULT_TYPE) -> None: + """Show main signals menu.""" + signals = discover_signals(force_reload=True) + instances = _get_instances(context) + db = get_signals_db() + + # Count running instances + running_count = sum(1 for inst in instances.values() if inst.get("status") == "running") + + if not signals: + text = ( + "📊 *SIGNALS*\n" + "━━━━━━━━━━━━━━━━━━━━━\n\n" + "No signals found\\.\n\n" + "Add folders to `signals/` with pipelines\\." + ) + keyboard = [[InlineKeyboardButton("🔄 Reload", callback_data="signals:reload")]] + else: + keyboard = [] + + for name in sorted(signals.keys()): + signal = signals[name] + pred_count = db.get_count(name) + model_path = get_latest_model_path(name) + + # Build label + icons = [] + if signal.has_train: + icons.append("🔧") + if signal.has_predict: + icons.append("📊") + + model_info = "✓" if model_path else "" + label = f"{''.join(icons)} {_display_name(name)} {model_info}" + + keyboard.append( + [InlineKeyboardButton(label, callback_data=f"signals:select:{name}")] + ) + + # Footer buttons + footer = [] + if running_count > 0: + footer.append( + InlineKeyboardButton(f"Running ({running_count})", callback_data="signals:tasks") + ) + footer.append(InlineKeyboardButton("🔄 Reload", callback_data="signals:reload")) + keyboard.append(footer) + + text = ( + "📊 *SIGNALS*\n" + "━━━━━━━━━━━━━━━━━━━━━\n\n" + "Select a signal to train or predict\\.\n\n" + f"🔧 Train \\| 📊 Predict \\| ✓ Has model" + ) + + await _edit_or_send(update, text, InlineKeyboardMarkup(keyboard)) + + +async def _show_detail( + update: Update, context: ContextTypes.DEFAULT_TYPE, signal_name: str +) -> None: + """Show signal detail view.""" + signal = get_signal(signal_name) + if not signal: + if update.callback_query: + await update.callback_query.answer("Signal not found") + return + + db = get_signals_db() + pred_count = db.get_count(signal_name) + model_path = get_latest_model_path(signal_name) + instances = _get_signal_instances(context, signal_name) + running = [(iid, inst) for iid, inst in instances if inst.get("status") == "running"] + + # Model info + if model_path: + model_info = f"Latest model: `{escape_markdown_v2(model_path.name)}`" + else: + model_info = "_No trained model yet_" + + text = ( + f"📊 *{escape_markdown_v2(_display_name(signal_name).upper())}*\n" + f"━━━━━━━━━━━━━━━━━━━━━\n" + f"_{escape_markdown_v2(signal.description)}_\n\n" + f"{model_info}\n" + f"Predictions: {pred_count} total\n" + ) + + if running: + text += f"\n🟢 {len(running)} running\n" + + keyboard = [] + + # Pipeline buttons + row = [] + if signal.has_train: + row.append(InlineKeyboardButton("🔧 Train/Eval", callback_data=f"signals:train:{signal_name}")) + if signal.has_predict: + row.append(InlineKeyboardButton("📊 Predict", callback_data=f"signals:predict:{signal_name}")) + if row: + keyboard.append(row) + + # History and stop buttons + row2 = [] + if pred_count > 0: + row2.append(InlineKeyboardButton("📜 History", callback_data=f"signals:history:{signal_name}")) + if running: + row2.append(InlineKeyboardButton(f"⏹ Stop ({len(running)})", callback_data=f"signals:stopall:{signal_name}")) + if row2: + keyboard.append(row2) + + keyboard.append([InlineKeyboardButton("« Back", callback_data="signals:menu")]) + + await _edit_or_send(update, text, InlineKeyboardMarkup(keyboard)) + + +async def _show_pipeline( + update: Update, + context: ContextTypes.DEFAULT_TYPE, + signal_name: str, + pipeline: str, +) -> None: + """Show pipeline config editor.""" + signal = get_signal(signal_name) + if not signal: + return + + pipe = signal.train_pipeline if pipeline == "train" else signal.predict_pipeline + if not pipe: + return + + # Set editing state + context.user_data["signals_state"] = "editing" + context.user_data["signals_editing"] = {"signal": signal_name, "pipeline": pipeline} + + draft = _get_draft(context, signal_name, pipeline) + fields = pipe.get_fields() + + # Build config display + config_lines = [f"{k}={draft.get(k, v['default'])}" for k, v in fields.items()] + + pipeline_label = "TRAIN" if pipeline == "train" else "PREDICT" + text = ( + f"📊 *{escape_markdown_v2(_display_name(signal_name).upper())} \\- {pipeline_label}*\n" + f"━━━━━━━━━━━━━━━━━━━━━\n" + f"_{escape_markdown_v2(pipe.description)}_\n\n" + f"```\n{chr(10).join(config_lines)}\n```\n\n" + f"_✏️ Send key\\=value to edit_" + ) + + keyboard = [] + + if pipeline == "train": + # Training always runs in background + keyboard.append([ + InlineKeyboardButton("🚀 Start Training", callback_data=f"signals:run_train:{signal_name}") + ]) + else: + # Predict can run immediately or be scheduled + keyboard.append([ + InlineKeyboardButton("▶️ Run", callback_data=f"signals:run_predict:{signal_name}"), + InlineKeyboardButton("🔄 Background", callback_data=f"signals:bg_predict:{signal_name}"), + ]) + keyboard.append([ + InlineKeyboardButton("⏱️ Schedule", callback_data=f"signals:sched:{signal_name}"), + ]) + + keyboard.append([ + InlineKeyboardButton("❓ Help", callback_data=f"signals:help:{signal_name}:{pipeline}"), + InlineKeyboardButton("« Back", callback_data=f"signals:select:{signal_name}"), + ]) + + await _edit_or_send(update, text, InlineKeyboardMarkup(keyboard)) + + +async def _show_schedule_menu( + update: Update, context: ContextTypes.DEFAULT_TYPE, signal_name: str +) -> None: + """Show schedule options for prediction.""" + text = ( + f"⏱️ *SCHEDULE: {escape_markdown_v2(_display_name(signal_name).upper())}*\n" + f"━━━━━━━━━━━━━━━━━━━━━\n\n" + "Choose interval or daily time\\." + ) + + keyboard = [] + + # Interval presets + row = [] + for label, secs in SCHEDULE_PRESETS: + row.append( + InlineKeyboardButton(label, callback_data=f"signals:interval:{signal_name}:{secs}") + ) + if len(row) == 3: + keyboard.append(row) + row = [] + if row: + keyboard.append(row) + + # Daily option + keyboard.append([ + InlineKeyboardButton("📅 Daily...", callback_data=f"signals:daily:{signal_name}") + ]) + + keyboard.append([InlineKeyboardButton("« Cancel", callback_data=f"signals:predict:{signal_name}")]) + + await _edit_or_send(update, text, InlineKeyboardMarkup(keyboard)) + + +async def _show_daily_menu( + update: Update, context: ContextTypes.DEFAULT_TYPE, signal_name: str +) -> None: + """Show daily time selection.""" + text = ( + f"📅 *DAILY SCHEDULE*\n" + f"━━━━━━━━━━━━━━━━━━━━━\n\n" + "Select time or send custom \\(HH:MM\\)\\." + ) + + context.user_data["signals_state"] = "daily_time" + context.user_data["signals_editing"] = {"signal": signal_name, "pipeline": "predict"} + + keyboard = [] + row = [] + for time_str in DAILY_PRESETS: + row.append( + InlineKeyboardButton(time_str, callback_data=f"signals:dailyat:{signal_name}:{time_str}") + ) + if len(row) == 3: + keyboard.append(row) + row = [] + if row: + keyboard.append(row) + + keyboard.append([InlineKeyboardButton("« Cancel", callback_data=f"signals:sched:{signal_name}")]) + + await _edit_or_send(update, text, InlineKeyboardMarkup(keyboard)) + + +async def _show_history( + update: Update, context: ContextTypes.DEFAULT_TYPE, signal_name: str +) -> None: + """Show prediction history.""" + db = get_signals_db() + predictions = db.get_predictions(signal_name, limit=10) + total = db.get_count(signal_name) + + if not predictions: + text = ( + f"📜 *{escape_markdown_v2(_display_name(signal_name).upper())} \\- HISTORY*\n" + f"━━━━━━━━━━━━━━━━━━━━━\n\n" + "No predictions yet\\." + ) + else: + lines = [ + f"📜 *{escape_markdown_v2(_display_name(signal_name).upper())} \\- HISTORY*", + "━━━━━━━━━━━━━━━━━━━━━", + f"_{total} predictions total_\n", + ] + + for pred in predictions: + time_str = pred.created_at.strftime("%m/%d %H:%M") + result_preview = pred.result[:50].replace("\n", " ") + lines.append(f"• `{escape_markdown_v2(time_str)}` {escape_markdown_v2(result_preview)}") + + text = "\n".join(lines) + + keyboard = [[InlineKeyboardButton("« Back", callback_data=f"signals:select:{signal_name}")]] + await _edit_or_send(update, text, InlineKeyboardMarkup(keyboard)) + + +async def _show_help( + update: Update, + context: ContextTypes.DEFAULT_TYPE, + signal_name: str, + pipeline: str, +) -> None: + """Show field descriptions for a pipeline.""" + signal = get_signal(signal_name) + if not signal: + return + + pipe = signal.train_pipeline if pipeline == "train" else signal.predict_pipeline + if not pipe: + return + + fields = pipe.get_fields() + lines = [ + f"❓ *{escape_markdown_v2(_display_name(signal_name).upper())} \\- HELP*", + "━━━━━━━━━━━━━━━━━━━━━\n", + ] + + for name, info in fields.items(): + lines.append(f"• `{name}` _{escape_markdown_v2(info['type'])}_") + lines.append(f" {escape_markdown_v2(info['description'])}") + lines.append(f" Default: `{escape_markdown_v2(str(info['default']))}`\n") + + text = "\n".join(lines) + keyboard = [[InlineKeyboardButton("« Back", callback_data=f"signals:{pipeline}:{signal_name}")]] + await _edit_or_send(update, text, InlineKeyboardMarkup(keyboard)) + + +async def _show_tasks(update: Update, context: ContextTypes.DEFAULT_TYPE) -> None: + """Show all running tasks.""" + instances = _get_instances(context) + running = [(iid, inst) for iid, inst in instances.items() if inst.get("status") == "running"] + + if not running: + text = ( + "📋 *RUNNING TASKS*\n" + "━━━━━━━━━━━━━━━━━━━━━\n\n" + "No tasks running\\." + ) + keyboard = [[InlineKeyboardButton("« Back", callback_data="signals:menu")]] + else: + lines = ["📋 *RUNNING TASKS*", "━━━━━━━━━━━━━━━━━━━━━\n"] + keyboard = [] + + for iid, inst in running: + name = inst["signal_name"] + pipeline = inst.get("pipeline", "predict") + schedule = inst.get("schedule", {}) + created = inst.get("created_at", time.time()) + run_count = inst.get("run_count", 0) + + lines.append(f"🟢 *{escape_markdown_v2(_display_name(name))}* `{iid}`") + lines.append(f" {pipeline} \\| {escape_markdown_v2(_format_schedule(schedule))} \\| {run_count} runs") + lines.append("") + + keyboard.append([ + InlineKeyboardButton(f"⏹ {_display_name(name)[:12]}[{iid}]", callback_data=f"signals:stop:{iid}") + ]) + + keyboard.append([InlineKeyboardButton("⏹ Stop All", callback_data="signals:stopall")]) + keyboard.append([InlineKeyboardButton("« Back", callback_data="signals:menu")]) + text = "\n".join(lines) + + await _edit_or_send(update, text, InlineKeyboardMarkup(keyboard)) + + +# ============================================================================= +# Action Handlers +# ============================================================================= + + +async def _run_train( + update: Update, context: ContextTypes.DEFAULT_TYPE, signal_name: str +) -> None: + """Start training in background.""" + query = update.callback_query + chat_id = update.effective_chat.id + + draft = _get_draft(context, signal_name, "train") + + await query.answer("🚀 Training started...") + + instance_id = _create_background_instance( + context, chat_id, signal_name, "train", draft + ) + + await query.message.reply_text( + f"🚀 *Training Started*\n\n" + f"Signal: {escape_markdown_v2(signal_name)}\n" + f"Instance: `{instance_id}`\n\n" + f"_This may take a while\\. You'll receive a message when complete\\._", + parse_mode="MarkdownV2", + ) + + +async def _run_predict( + update: Update, context: ContextTypes.DEFAULT_TYPE, signal_name: str, background: bool +) -> None: + """Run prediction.""" + query = update.callback_query + chat_id = update.effective_chat.id + + draft = _get_draft(context, signal_name, "predict") + + if background: + await query.answer("🔄 Running in background...") + instance_id = _create_background_instance( + context, chat_id, signal_name, "predict", draft + ) + await query.message.reply_text( + f"🔄 *Prediction Running*\n\n" + f"Instance: `{instance_id}`", + parse_mode="MarkdownV2", + ) + else: + await query.answer("▶️ Running...") + + # Run directly + result, duration = await _execute_pipeline( + context, "direct", signal_name, "predict", draft, chat_id + ) + + result_preview = result[:1000] if len(result) > 1000 else result + await query.message.reply_text( + f"📊 *{escape_markdown_v2(signal_name.upper())} \\- PREDICTION*\n\n" + f"```\n{escape_markdown_v2(result_preview)}\n```\n\n" + f"Duration: {escape_markdown_v2(_format_duration(duration))}", + parse_mode="MarkdownV2", + ) + + +async def _start_interval( + update: Update, context: ContextTypes.DEFAULT_TYPE, signal_name: str, interval_sec: int +) -> None: + """Start interval-scheduled prediction.""" + query = update.callback_query + chat_id = update.effective_chat.id + + draft = _get_draft(context, signal_name, "predict") + schedule = {"type": "interval", "interval_sec": interval_sec} + + instance_id = _create_scheduled_instance( + context, chat_id, signal_name, "predict", draft, schedule + ) + + await query.answer(f"⏱️ Scheduled every {_format_schedule(schedule)}") + await _show_detail(update, context, signal_name) + + +async def _start_daily( + update: Update, context: ContextTypes.DEFAULT_TYPE, signal_name: str, time_str: str +) -> None: + """Start daily-scheduled prediction.""" + query = update.callback_query + chat_id = update.effective_chat.id + + draft = _get_draft(context, signal_name, "predict") + schedule = {"type": "daily", "daily_time": time_str} + + instance_id = _create_scheduled_instance( + context, chat_id, signal_name, "predict", draft, schedule + ) + + await query.answer(f"📅 Scheduled daily at {time_str}") + await _show_detail(update, context, signal_name) + + +async def _process_config( + update: Update, context: ContextTypes.DEFAULT_TYPE, text: str +) -> None: + """Process key=value config input.""" + editing = context.user_data.get("signals_editing", {}) + signal_name = editing.get("signal") + pipeline = editing.get("pipeline") + + if not signal_name or not pipeline: + return + + signal = get_signal(signal_name) + if not signal: + return + + pipe = signal.train_pipeline if pipeline == "train" else signal.predict_pipeline + if not pipe: + return + + draft = _get_draft(context, signal_name, pipeline) + fields = pipe.get_fields() + + # Parse input lines + for line in text.strip().split("\n"): + line = line.strip() + if "=" not in line: + continue + + key, value = line.split("=", 1) + key = key.strip() + value = value.strip() + + if key not in fields: + continue + + # Type conversion + field_type = fields[key]["type"] + try: + if field_type == "int": + value = int(value) + elif field_type == "float": + value = float(value) + elif field_type == "bool": + value = value.lower() in ("true", "1", "yes") + draft[key] = value + except ValueError: + pass + + _set_draft(context, signal_name, pipeline, draft) + + # Delete user message and refresh view + try: + await update.message.delete() + except Exception: + pass + + await _show_pipeline(update, context, signal_name, pipeline) + + +async def _process_daily_time( + update: Update, context: ContextTypes.DEFAULT_TYPE, text: str +) -> None: + """Process custom daily time input.""" + editing = context.user_data.get("signals_editing", {}) + signal_name = editing.get("signal") + + if not signal_name: + return + + # Validate time format + try: + hour, minute = map(int, text.strip().split(":")) + if not (0 <= hour <= 23 and 0 <= minute <= 59): + raise ValueError + time_str = f"{hour:02d}:{minute:02d}" + except ValueError: + await update.message.reply_text("Invalid time format. Use HH:MM (e.g., 09:30)") + return + + context.user_data.pop("signals_state", None) + context.user_data.pop("signals_editing", None) + + # Delete user message + try: + await update.message.delete() + except Exception: + pass + + await _start_daily(update, context, signal_name, time_str) + + +# ============================================================================= +# Command and Callback Handlers +# ============================================================================= + + +@restricted +async def signals_command(update: Update, context: ContextTypes.DEFAULT_TYPE) -> None: + """Handle /signals command.""" + clear_all_input_states(context) + await _show_menu(update, context) + + +@restricted +async def signals_callback_handler( + update: Update, context: ContextTypes.DEFAULT_TYPE +) -> None: + """Handle callback queries.""" + query = update.callback_query + chat_id = update.effective_chat.id + parts = query.data.split(":") + + if len(parts) < 2: + await query.answer() + return + + action = parts[1] + + if action == "menu": + await query.answer() + context.user_data.pop("signals_state", None) + context.user_data.pop("signals_editing", None) + await _show_menu(update, context) + + elif action == "reload": + await query.answer("Reloading...") + discover_signals(force_reload=True) + await _show_menu(update, context) + + elif action == "tasks": + await query.answer() + await _show_tasks(update, context) + + elif action == "select" and len(parts) >= 3: + await query.answer() + context.user_data.pop("signals_state", None) + context.user_data.pop("signals_editing", None) + await _show_detail(update, context, parts[2]) + + elif action == "train" and len(parts) >= 3: + await query.answer() + await _show_pipeline(update, context, parts[2], "train") + + elif action == "predict" and len(parts) >= 3: + await query.answer() + await _show_pipeline(update, context, parts[2], "predict") + + elif action == "run_train" and len(parts) >= 3: + await _run_train(update, context, parts[2]) + + elif action == "run_predict" and len(parts) >= 3: + await _run_predict(update, context, parts[2], background=False) + + elif action == "bg_predict" and len(parts) >= 3: + await _run_predict(update, context, parts[2], background=True) + + elif action == "sched" and len(parts) >= 3: + await query.answer() + await _show_schedule_menu(update, context, parts[2]) + + elif action == "interval" and len(parts) >= 4: + await _start_interval(update, context, parts[2], int(parts[3])) + + elif action == "daily" and len(parts) >= 3: + await query.answer() + await _show_daily_menu(update, context, parts[2]) + + elif action == "dailyat" and len(parts) >= 4: + await _start_daily(update, context, parts[2], parts[3]) + + elif action == "history" and len(parts) >= 3: + await query.answer() + await _show_history(update, context, parts[2]) + + elif action == "help" and len(parts) >= 4: + await query.answer() + await _show_help(update, context, parts[2], parts[3]) + + elif action == "stop" and len(parts) >= 3: + instance_id = parts[2] + if _stop_instance(context, chat_id, instance_id): + await query.answer("⏹ Stopped") + else: + await query.answer("Not found") + await _show_tasks(update, context) + + elif action == "stopall" and len(parts) >= 3: + signal_name = parts[2] + instances = _get_signal_instances(context, signal_name) + count = 0 + for iid, _ in instances: + if _stop_instance(context, chat_id, iid): + count += 1 + await query.answer(f"⏹ Stopped {count}") + await _show_detail(update, context, signal_name) + + elif action == "stopall": + instances = _get_instances(context) + count = 0 + for iid in list(instances.keys()): + if _stop_instance(context, chat_id, iid): + count += 1 + await query.answer(f"⏹ Stopped {count}") + await _show_tasks(update, context) + + else: + await query.answer() + + +async def signals_message_handler( + update: Update, context: ContextTypes.DEFAULT_TYPE +) -> bool: + """Handle text input for config editing or daily time.""" + state = context.user_data.get("signals_state") + + if state == "editing": + await _process_config(update, context, update.message.text.strip()) + return True + elif state == "daily_time": + await _process_daily_time(update, context, update.message.text.strip()) + return True + + return False + + +async def restore_signal_jobs(application) -> int: + """ + Restore scheduled jobs from persisted instances after bot restart. + Call this during application startup (post_init). + Returns count of restored jobs. + """ + restored = 0 + + for chat_id, user_data in application.user_data.items(): + instances = user_data.get("signals_instances", {}) + if not instances: + continue + + to_remove = [] + + for instance_id, inst in instances.items(): + if inst.get("status") != "running": + continue + + signal_name = inst.get("signal_name") + pipeline = inst.get("pipeline", "predict") + config_dict = inst.get("config", {}) + schedule = inst.get("schedule", {}) + stype = schedule.get("type", "once") + + # Check if signal still exists + signal = get_signal(signal_name) + if not signal: + logger.warning(f"Signal {signal_name} no longer exists, removing instance {instance_id}") + to_remove.append(instance_id) + continue + + # Only restore scheduled jobs (not one-time) + if stype == "once": + to_remove.append(instance_id) + continue + + # Create mock context for job creation + class MockContext: + def __init__(self): + self.job_queue = application.job_queue + self.user_data = user_data + + mock_ctx = MockContext() + + job_data = { + "chat_id": chat_id, + "instance_id": instance_id, + "signal_name": signal_name, + "pipeline": pipeline, + "config": config_dict, + } + + job_name = _job_name(chat_id, instance_id) + + if stype == "interval": + interval = schedule.get("interval_sec", 60) + application.job_queue.run_repeating( + _interval_job_callback, + interval=interval, + first=interval, + data=job_data, + name=job_name, + chat_id=chat_id, + ) + restored += 1 + logger.info(f"Restored interval job for {signal_name}:{pipeline} [{instance_id}]") + + elif stype == "daily": + time_str = schedule.get("daily_time", "09:00") + hour, minute = map(int, time_str.split(":")) + application.job_queue.run_daily( + _interval_job_callback, + time=dt_time(hour=hour, minute=minute), + data=job_data, + name=job_name, + chat_id=chat_id, + ) + restored += 1 + logger.info(f"Restored daily job for {signal_name}:{pipeline} [{instance_id}]") + + # Clean up old instances + for iid in to_remove: + del instances[iid] + + logger.info(f"Restored {restored} signal jobs") + return restored diff --git a/handlers/trading/__init__.py b/handlers/trading/__init__.py new file mode 100644 index 0000000..965ef27 --- /dev/null +++ b/handlers/trading/__init__.py @@ -0,0 +1,221 @@ +""" +Unified Trading Entry Point + +Routes to DEX (swap.py) or CEX (trade.py) based on connector type. +Provides a single /trade command that works with both CEX and DEX connectors. +Uses portfolio connectors (ones with API keys configured). +""" + +import logging +from telegram import Update, InlineKeyboardButton, InlineKeyboardMarkup +from telegram.ext import ContextTypes + +from utils.auth import restricted, hummingbot_api_required +from utils.telegram_formatters import escape_markdown_v2 +from handlers import clear_all_input_states, is_gateway_network +from handlers.config.user_preferences import ( + get_last_trade_connector, + set_last_trade_connector, + get_clob_order_defaults, + get_dex_swap_defaults, +) +from handlers.cex.trade import handle_trade as cex_handle_trade +from handlers.dex.swap import handle_swap as dex_handle_swap + +logger = logging.getLogger(__name__) + + +def _format_network_display(network_id: str) -> str: + """Format network ID for button display. + + Examples: + solana-mainnet-beta -> Solana + ethereum-mainnet -> Ethereum + solana-devnet -> Solana Dev + """ + if not network_id: + return "Network" + + parts = network_id.split("-") + chain = parts[0].capitalize() + + if len(parts) > 1: + net = parts[1] + if net in ("mainnet", "mainnet-beta"): + return chain + elif net == "devnet": + return f"{chain} Dev" + elif net == "testnet": + return f"{chain} Test" + else: + return f"{chain} {net[:4]}" + + return chain + + +async def _get_portfolio_connectors(client) -> tuple: + """Get connectors from portfolio state, split by type. + + Returns: + (cex_connectors, gateway_networks) - both from portfolio.get_state() + """ + try: + state = await client.portfolio.get_state() + # state = {account_name: {connector_name: [balances]}} + + cex = set() + gateway = set() + + for account_data in state.values(): + if isinstance(account_data, dict): + for connector_name in account_data.keys(): + if is_gateway_network(connector_name): + gateway.add(connector_name) + else: + cex.add(connector_name) + + return sorted(cex), sorted(gateway) + except Exception as e: + logger.warning(f"Error fetching portfolio connectors: {e}") + return [], [] + + +@restricted +@hummingbot_api_required +async def trade_command(update: Update, context: ContextTypes.DEFAULT_TYPE) -> None: + """Unified /trade command - routes to CEX or DEX based on last-used connector""" + clear_all_input_states(context) + + # Get last-used connector type and name + connector_type, connector_name = get_last_trade_connector(context.user_data) + + if connector_type == "cex" and connector_name: + # Route to CEX trade + defaults = get_clob_order_defaults(context.user_data) + defaults["connector"] = connector_name + context.user_data["trade_params"] = defaults + await cex_handle_trade(update, context) + elif connector_type == "dex" and connector_name: + # Route to DEX swap with network pre-set + # For DEX, connector_name is actually the network (e.g., solana-mainnet-beta) + defaults = get_dex_swap_defaults(context.user_data) + defaults["network"] = connector_name + context.user_data["swap_params"] = defaults + await dex_handle_swap(update, context) + else: + # First time or no preference - show connector selector + await handle_unified_connector_select(update, context) + + +async def handle_unified_connector_select(update: Update, context: ContextTypes.DEFAULT_TYPE) -> None: + """Show all connectors (CEX + DEX networks) from portfolio for selection""" + chat_id = update.effective_chat.id + + try: + from config_manager import get_client + client = await get_client(chat_id, context=context) + + # Fetch connectors from portfolio (ones with API keys/wallets configured) + cex_connectors, gateway_networks = await _get_portfolio_connectors(client) + + # Build keyboard with groups + keyboard = [] + + # CEX section - connector names (binance, bybit_perpetual, etc.) + if cex_connectors: + keyboard.append([InlineKeyboardButton("━━ CEX ━━", callback_data="trade:noop")]) + row = [] + for connector in cex_connectors: + row.append(InlineKeyboardButton( + connector, + callback_data=f"trade:select_cex:{connector}" + )) + if len(row) == 2: + keyboard.append(row) + row = [] + if row: + keyboard.append(row) + + # DEX section - network names (solana-mainnet-beta, ethereum-mainnet, etc.) + if gateway_networks: + keyboard.append([InlineKeyboardButton("━━ DEX ━━", callback_data="trade:noop")]) + row = [] + for network in gateway_networks: + display = _format_network_display(network) + row.append(InlineKeyboardButton( + display, + callback_data=f"trade:select_dex:{network}" + )) + if len(row) == 2: + keyboard.append(row) + row = [] + if row: + keyboard.append(row) + + if not cex_connectors and not gateway_networks: + help_text = escape_markdown_v2( + "🔄 Select Connector\n\n" + "No connectors found in portfolio.\n" + "Add API keys via /config to get started." + ) + else: + help_text = r"🔄 *Select Connector*" + "\n\n" + r"Choose a trading connector:" + + reply_markup = InlineKeyboardMarkup(keyboard) + + if update.callback_query: + await update.callback_query.message.edit_text( + help_text, + parse_mode="MarkdownV2", + reply_markup=reply_markup + ) + else: + await update.message.reply_text( + help_text, + parse_mode="MarkdownV2", + reply_markup=reply_markup + ) + + except Exception as e: + logger.error(f"Error showing connector selector: {e}", exc_info=True) + error_text = escape_markdown_v2(f"Error loading connectors: {str(e)}") + if update.callback_query: + await update.callback_query.message.edit_text(error_text, parse_mode="MarkdownV2") + else: + await update.message.reply_text(error_text, parse_mode="MarkdownV2") + + +async def handle_select_cex_connector(update: Update, context: ContextTypes.DEFAULT_TYPE, connector_name: str) -> None: + """Handle CEX connector selection from unified selector""" + # Save preference + set_last_trade_connector(context.user_data, "cex", connector_name) + + # Pre-set connector and delegate to CEX trade + defaults = get_clob_order_defaults(context.user_data) + defaults["connector"] = connector_name + context.user_data["trade_params"] = defaults + + await cex_handle_trade(update, context) + + +async def handle_select_dex_network(update: Update, context: ContextTypes.DEFAULT_TYPE, network: str) -> None: + """Handle DEX network selection from unified selector""" + # Save preference - for DEX we store the network (e.g., solana-mainnet-beta) + set_last_trade_connector(context.user_data, "dex", network) + + # Pre-set network and delegate to DEX swap + defaults = get_dex_swap_defaults(context.user_data) + defaults["network"] = network + context.user_data["swap_params"] = defaults + + await dex_handle_swap(update, context) + + +async def handle_back(update: Update, context: ContextTypes.DEFAULT_TYPE) -> None: + """Handle back from connector selector - go to last used connector""" + connector_type, connector_name = get_last_trade_connector(context.user_data) + + if connector_type == "cex": + await cex_handle_trade(update, context) + else: + await dex_handle_swap(update, context) diff --git a/handlers/trading/router.py b/handlers/trading/router.py new file mode 100644 index 0000000..1205f5e --- /dev/null +++ b/handlers/trading/router.py @@ -0,0 +1,54 @@ +""" +Unified Trade Callback Router + +Handles trade:* callbacks for connector switching between CEX and DEX. +""" + +import logging +from telegram import Update +from telegram.ext import ContextTypes + +from utils.auth import restricted +from . import ( + handle_unified_connector_select, + handle_select_cex_connector, + handle_select_dex_network, + handle_back, +) + +logger = logging.getLogger(__name__) + + +@restricted +async def unified_trade_callback_handler(update: Update, context: ContextTypes.DEFAULT_TYPE) -> None: + """Handle trade:* callbacks for connector switching""" + query = update.callback_query + await query.answer() + + # Parse action from callback data + callback_parts = query.data.split(":", 1) + action = callback_parts[1] if len(callback_parts) > 1 else query.data + + logger.debug(f"Unified trade callback: {action}") + + # Route based on action + if action == "select_connector": + await handle_unified_connector_select(update, context) + + elif action.startswith("select_cex:"): + connector = action.replace("select_cex:", "") + await handle_select_cex_connector(update, context, connector) + + elif action.startswith("select_dex:"): + network = action.replace("select_dex:", "") + await handle_select_dex_network(update, context, network) + + elif action == "back": + await handle_back(update, context) + + elif action == "noop": + # No-op for separator buttons + pass + + else: + logger.warning(f"Unknown unified trade action: {action}") diff --git a/main.py b/main.py index 079aac2..da4a631 100644 --- a/main.py +++ b/main.py @@ -6,6 +6,7 @@ from pathlib import Path from telegram import Update, BotCommand, InlineKeyboardButton, InlineKeyboardMarkup +from telegram.error import NetworkError from telegram.ext import ( Application, CommandHandler, @@ -14,12 +15,6 @@ PicklePersistence, ) -from handlers.portfolio import portfolio_command, get_portfolio_callback_handler -from handlers.bots import bots_command, bots_callback_handler -from handlers.cex import trade_command, cex_callback_handler -from handlers.dex import swap_command, lp_command, dex_callback_handler -from handlers.config import config_command, get_config_callback_handler, get_modify_value_handler -from handlers.routines import routines_command, routines_callback_handler from handlers import clear_all_input_states from utils.auth import restricted from utils.config import TELEGRAM_TOKEN @@ -31,260 +26,86 @@ logger = logging.getLogger(__name__) -def _get_start_menu_keyboard() -> InlineKeyboardMarkup: +def _get_start_menu_keyboard(is_admin: bool = False) -> InlineKeyboardMarkup: """Build the start menu inline keyboard.""" keyboard = [ - [ - InlineKeyboardButton("📊 Portfolio", callback_data="start:portfolio"), - InlineKeyboardButton("🤖 Bots", callback_data="start:bots"), - ], - [ - InlineKeyboardButton("💱 Swap", callback_data="start:swap"), - InlineKeyboardButton("📊 Trade", callback_data="start:trade"), - InlineKeyboardButton("💧 LP", callback_data="start:lp"), - ], [ InlineKeyboardButton("🔌 Servers", callback_data="start:config_servers"), InlineKeyboardButton("🔑 Keys", callback_data="start:config_keys"), InlineKeyboardButton("🌐 Gateway", callback_data="start:config_gateway"), ], - [ - InlineKeyboardButton("❓ Help", callback_data="start:help"), - ], - ] - return InlineKeyboardMarkup(keyboard) - - -def _get_help_keyboard() -> InlineKeyboardMarkup: - """Build the help menu inline keyboard.""" - keyboard = [ - [ - InlineKeyboardButton("📊 Portfolio", callback_data="help:portfolio"), - InlineKeyboardButton("🤖 Bots", callback_data="help:bots"), - ], - [ - InlineKeyboardButton("💱 Swap", callback_data="help:swap"), - InlineKeyboardButton("📊 Trade", callback_data="help:trade"), - InlineKeyboardButton("💧 LP", callback_data="help:lp"), - ], - [ - InlineKeyboardButton("⚙️ Config", callback_data="help:config"), - ], - [ - InlineKeyboardButton("🔙 Back to Menu", callback_data="help:back"), - ], ] + if is_admin: + keyboard.append([InlineKeyboardButton("👑 Admin", callback_data="start:admin")]) + keyboard.append([InlineKeyboardButton("❌ Cancel", callback_data="start:cancel")]) return InlineKeyboardMarkup(keyboard) -HELP_TEXTS = { - "main": r""" -❓ *Help \- Command Guide* - -Select a command below to learn more about its features and usage: - -📊 *Portfolio* \- View holdings and performance -🤖 *Bots* \- Monitor trading bot status -💱 *Swap* \- Quick token swaps via DEX -📊 *Trade* \- Order book trading \(CEX/CLOB\) -💧 *LP* \- Liquidity pool management -⚙️ *Config* \- System configuration -""", - "portfolio": r""" -📊 *Portfolio Command* - -View your complete portfolio summary across all connected accounts\. - -*Features:* -• Real\-time balance overview by account -• PnL tracking with historical charts -• Holdings breakdown by asset -• Multi\-connector aggregation - -*Usage:* -• Tap the button or type `/portfolio` -• Use ⚙️ Settings to adjust the time period \(1d, 3d, 7d, 30d\) -• View performance graphs and detailed breakdowns - -*Tips:* -• Connect multiple accounts via Config to see aggregated portfolio -• PnL is calculated based on your configured time window -""", - "bots": r""" -🤖 *Bots Command* - -Monitor the status of all your active trading bots\. - -*Features:* -• View all running bot instances -• Check bot health and uptime -• See active strategies per bot -• Monitor trading activity - -*Usage:* -• Tap the button or type `/bots` -• View the status of each connected bot -• Check which strategies are currently active - -*Tips:* -• Ensure your API servers are properly configured in Config -• Bots must be running on connected Hummingbot instances -""", - "trade": r""" -📊 *Trade Command* - -Trade on Central Limit Order Book exchanges \(Spot \& Perpetual\)\. - -*Features:* -• Place market and limit orders -• Set leverage for perpetual trading -• View and manage open orders -• Monitor positions with PnL -• Quick account switching - -*Usage:* -• Tap the button or type `/trade` -• Select an account and connector -• Use the menu to place orders or view positions - -*Order Types:* -• 📝 *Place Order* \- Submit new orders -• ⚙️ *Set Leverage* \- Adjust perpetual leverage -• 🔍 *Orders Details* \- View/cancel open orders -• 📊 *Positions Details* \- Monitor active positions - -*Tips:* -• Always verify the selected account before trading -• Use limit orders for better price control -""", - "swap": r""" -💱 *Swap Command* - -Quick token swaps on Decentralized Exchanges via Gateway\. - -*Features:* -• Token swaps with real\-time quotes -• Multiple DEX router support -• Slippage configuration -• Swap history with status tracking - -*Usage:* -• Tap the button or type `/swap` -• Select network and token pair -• Get quote and execute - -*Operations:* -• 💰 *Quote* \- Get swap price estimates -• ✅ *Execute* \- Execute token swaps -• 🔍 *History* \- View past swaps - -*Tips:* -• Always check quotes before executing swaps -• Gateway must be running for DEX operations -""", - "lp": r""" -💧 *LP Command* - -Manage liquidity positions on CLMM pools\. - -*Features:* -• View LP positions with PnL -• Collect fees from positions -• Add/close positions -• Pool explorer with GeckoTerminal -• OHLCV charts and pool analytics - -*Usage:* -• Tap the button or type `/lp` -• View your positions or explore pools -• Manage fees and positions - -*Operations:* -• 📍 *Positions* \- View and manage LP positions -• 📋 *Pools* \- Browse available pools -• 🦎 *Explorer* \- GeckoTerminal pool discovery -• 📊 *Charts* \- View pool OHLCV data - -*Tips:* -• Monitor V/TVL ratio for pool activity -• Check APR and fee tiers before adding liquidity -""", - "config": r""" -⚙️ *Config Command* - -Configure your trading infrastructure and credentials\. - -*Sections:* - -🔌 *API Servers* -• Add/remove Hummingbot instances -• Configure connection endpoints -• Test server connectivity - -🔑 *API Keys* -• Manage exchange credentials -• Add new exchange API keys -• Securely store credentials - -🌐 *Gateway* -• Configure Gateway container -• Set up DEX chain connections -• Manage wallet credentials - -*Usage:* -• Tap the button or type `/config` -• Select the section you want to configure -• Follow the prompts to add or modify settings - -*Tips:* -• Keep your API keys secure -• Test connections after adding new servers -• Gateway is required for DEX trading -""", -} async def start(update: Update, context: ContextTypes.DEFAULT_TYPE) -> None: - """Start the conversation and display the main menu.""" - from utils.config import AUTHORIZED_USERS + """Start the conversation and display available commands (BotFather style).""" + from config_manager import get_config_manager, UserRole + from utils.auth import _notify_admin_new_user - chat_id = update.effective_chat.id user_id = update.effective_user.id username = update.effective_user.username or "No username" - # Check if user is authorized - if user_id not in AUTHORIZED_USERS: - reply_text = rf""" -🔒 *Access Restricted* + cm = get_config_manager() + role = cm.get_user_role(user_id) -You are not authorized to use this bot\. + # Handle blocked users + if role == UserRole.BLOCKED: + await update.message.reply_text("Access denied.") + return + + # Handle pending users + if role == UserRole.PENDING: + reply_text = f"""Access Pending -🆔 *Your Chat Info*: -📱 Chat ID: `{chat_id}` -👤 User ID: `{user_id}` +Your access request is awaiting admin approval. -Share this information with the bot administrator to request access\. -""" - await update.message.reply_text(reply_text, parse_mode="MarkdownV2") +Your Info: +User ID: {user_id} +Username: @{username} + +You will be notified when approved.""" + await update.message.reply_text(reply_text) return - # Clear all pending input states to prevent interference + # Handle new users - register as pending + if role is None: + is_new = cm.register_pending(user_id, username) + if is_new: + await _notify_admin_new_user(context, user_id, username) + + reply_text = f"""Access Request Submitted + +Your request has been sent to the admin for approval. + +Your Info: +User ID: {user_id} +Username: @{username} + +You will be notified when approved.""" + await update.message.reply_text(reply_text) + return + + # User is approved (USER or ADMIN role) clear_all_input_states(context) - reply_text = rf""" -🚀 *Welcome to Condor\!* 🦅 + reply_text = """I can help you create and manage trading bots on any CEX or DEX using Hummingbot API servers\\. + +See [this manual](https://hummingbot.org/condor/) if you're new to Condor\\. -Manage your trading bots efficiently and monitor their performance\. +You can control me by sending these commands: -🆔 *Your Chat Info*: -📱 Chat ID: `{chat_id}` -👤 User ID: `{user_id}` -🏷️ Username: `@{username}` +/keys \\- add exchange API keys +/portfolio \\- view balances across exchanges +/bots \\- deploy and manage trading bots +/trade \\- place CEX and DEX orders""" -Select a command below to get started: -""" - keyboard = _get_start_menu_keyboard() - await update.message.reply_text(reply_text, parse_mode="MarkdownV2", reply_markup=keyboard) + await update.message.reply_text(reply_text, parse_mode="MarkdownV2", disable_web_page_preview=True) @restricted @@ -296,19 +117,14 @@ async def start_callback_handler(update: Update, context: ContextTypes.DEFAULT_T data = query.data action = data.split(":")[1] if ":" in data else data - # Handle navigation to commands + # Handle cancel - delete the message + if action == "cancel": + await query.message.delete() + return + + # Handle navigation to config options if data.startswith("start:"): - if action == "portfolio": - await portfolio_command(update, context) - elif action == "bots": - await bots_command(update, context) - elif action == "trade": - await trade_command(update, context) - elif action == "swap": - await swap_command(update, context) - elif action == "lp": - await lp_command(update, context) - elif action == "config_servers": + if action == "config_servers": from handlers.config.servers import show_api_servers from handlers import clear_all_input_states clear_all_input_states(context) @@ -325,46 +141,11 @@ async def start_callback_handler(update: Update, context: ContextTypes.DEFAULT_T context.user_data.pop("dex_state", None) context.user_data.pop("cex_state", None) await show_gateway_menu(query, context) - elif action == "help": - await query.edit_message_text( - HELP_TEXTS["main"], - parse_mode="MarkdownV2", - reply_markup=_get_help_keyboard() - ) - - # Handle help submenu - elif data.startswith("help:"): - if action == "back": - # Go back to main start menu - chat_id = update.effective_chat.id - user_id = update.effective_user.id - username = update.effective_user.username or "No username" - - reply_text = rf""" -🚀 *Welcome to Condor\!* 🦅 - -Manage your trading bots efficiently and monitor their performance\. - -🆔 *Your Chat Info*: -📱 Chat ID: `{chat_id}` -👤 User ID: `{user_id}` -🏷️ Username: `@{username}` - -Select a command below to get started: -""" - await query.edit_message_text( - reply_text, - parse_mode="MarkdownV2", - reply_markup=_get_start_menu_keyboard() - ) - elif action in HELP_TEXTS: - # Show specific help with back button - keyboard = [[InlineKeyboardButton("🔙 Back to Help", callback_data="start:help")]] - await query.edit_message_text( - HELP_TEXTS[action], - parse_mode="MarkdownV2", - reply_markup=InlineKeyboardMarkup(keyboard) - ) + elif action == "admin": + from handlers.admin import _show_admin_menu + from handlers import clear_all_input_states + clear_all_input_states(context) + await _show_admin_menu(query, context) def reload_handlers(): @@ -375,6 +156,8 @@ def reload_handlers(): 'handlers.bots.menu', 'handlers.bots.controllers', 'handlers.bots._shared', + 'handlers.trading', + 'handlers.trading.router', 'handlers.cex', 'handlers.cex.menu', 'handlers.cex.trade', @@ -394,9 +177,11 @@ def reload_handlers(): 'handlers.config.gateway', 'handlers.config.user_preferences', 'handlers.routines', + 'handlers.admin', 'routines.base', 'utils.auth', 'utils.telegram_formatters', + 'config_manager', ] for module_name in modules_to_reload: @@ -409,10 +194,16 @@ def register_handlers(application: Application) -> None: """Register all command handlers.""" # Import fresh versions after reload from handlers.portfolio import portfolio_command, get_portfolio_callback_handler - from handlers.bots import bots_command, bots_callback_handler - from handlers.cex import trade_command, cex_callback_handler - from handlers.dex import swap_command, lp_command, dex_callback_handler - from handlers.config import config_command, get_config_callback_handler, get_modify_value_handler + from handlers.bots import bots_command, bots_callback_handler, get_bots_document_handler + from handlers.trading import trade_command as unified_trade_command + from handlers.trading.router import unified_trade_callback_handler + from handlers.cex import cex_callback_handler + from handlers.dex import lp_command, dex_callback_handler + from handlers.config import get_config_callback_handler, get_modify_value_handler + from handlers.config.servers import servers_command + from handlers.config.api_keys import keys_command + from handlers.config.gateway import gateway_command + from handlers.admin import admin_command from handlers.routines import routines_command, routines_callback_handler # Clear existing handlers @@ -422,14 +213,22 @@ def register_handlers(application: Application) -> None: application.add_handler(CommandHandler("start", start)) application.add_handler(CommandHandler("portfolio", portfolio_command)) application.add_handler(CommandHandler("bots", bots_command)) - application.add_handler(CommandHandler("swap", swap_command)) - application.add_handler(CommandHandler("trade", trade_command)) + application.add_handler(CommandHandler("trade", unified_trade_command)) # Unified trade (CEX + DEX) + application.add_handler(CommandHandler("swap", unified_trade_command)) # Alias for /trade application.add_handler(CommandHandler("lp", lp_command)) - application.add_handler(CommandHandler("config", config_command)) application.add_handler(CommandHandler("routines", routines_command)) + # Add configuration commands (direct access) + application.add_handler(CommandHandler("servers", servers_command)) + application.add_handler(CommandHandler("keys", keys_command)) + application.add_handler(CommandHandler("gateway", gateway_command)) + application.add_handler(CommandHandler("admin", admin_command)) + # Add callback query handler for start menu navigation - application.add_handler(CallbackQueryHandler(start_callback_handler, pattern="^(start:|help:)")) + application.add_handler(CallbackQueryHandler(start_callback_handler, pattern="^start:")) + + # Add unified trade callback handler BEFORE cex/dex handlers (for connector switching) + application.add_handler(CallbackQueryHandler(unified_trade_callback_handler, pattern="^trade:")) # Add callback query handlers for trading operations application.add_handler(CallbackQueryHandler(cex_callback_handler, pattern="^cex:")) @@ -437,6 +236,10 @@ def register_handlers(application: Application) -> None: application.add_handler(CallbackQueryHandler(bots_callback_handler, pattern="^bots:")) application.add_handler(CallbackQueryHandler(routines_callback_handler, pattern="^routines:")) + # Add admin callback handler + from handlers.admin import admin_callback_handler + application.add_handler(CallbackQueryHandler(admin_callback_handler, pattern="^admin:")) + # Add callback query handler for portfolio settings application.add_handler(get_portfolio_callback_handler()) @@ -449,23 +252,65 @@ def register_handlers(application: Application) -> None: # competing for the same filter. application.add_handler(get_modify_value_handler()) + # Add document handler for file uploads (e.g., config files in /bots) + application.add_handler(get_bots_document_handler()) + logger.info("Handlers registered successfully") +async def sync_server_permissions() -> None: + """ + Ensure all servers in config have permission entries. + Registers any unregistered servers with admin as owner. + """ + from config_manager import get_config_manager + + cm = get_config_manager() + for server_name in cm.list_servers(): + cm.ensure_server_registered(server_name) + + logger.info("Synced server permissions") + + async def post_init(application: Application) -> None: """Register bot commands after initialization.""" + from telegram import BotCommandScopeChat + from utils.config import ADMIN_USER_ID + + # Sync server permissions (ensures all servers have ownership entries) + await sync_server_permissions() + + # Public commands (all users) commands = [ - BotCommand("start", "Welcome message and quick commands overview"), - BotCommand("portfolio", "View detailed portfolio breakdown by account and connector"), - BotCommand("bots", "Check status of all active trading bots"), - BotCommand("swap", "Quick token swaps via DEX routers"), - BotCommand("trade", "Order book trading (CEX/CLOB) with limit orders"), - BotCommand("lp", "Liquidity pool management and explorer"), - BotCommand("config", "Configure API servers and credentials"), + BotCommand("start", "Welcome message and server status"), + BotCommand("portfolio", "View detailed portfolio breakdown"), + BotCommand("bots", "Check status of all trading bots"), + BotCommand("trade", "Unified trading - CEX orders and DEX swaps"), + BotCommand("lp", "Liquidity pool management"), BotCommand("routines", "Run configurable Python scripts"), + BotCommand("servers", "Manage Hummingbot API servers"), + BotCommand("keys", "Configure exchange API credentials"), + BotCommand("gateway", "Deploy Gateway for DEX trading"), ] await application.bot.set_my_commands(commands) + # Admin-only commands (visible only to admin user in their command menu) + if ADMIN_USER_ID: + admin_commands = commands + [ + BotCommand("admin", "Admin panel - manage users and access"), + ] + try: + await application.bot.set_my_commands( + admin_commands, + scope=BotCommandScopeChat(chat_id=int(ADMIN_USER_ID)) + ) + except Exception as e: + logger.warning(f"Failed to set admin-specific commands: {e}") + + # Restore scheduled routine jobs from persistence + from handlers.routines import restore_scheduled_jobs + await restore_scheduled_jobs(application) + # Start file watcher asyncio.create_task(watch_and_reload(application)) @@ -508,6 +353,15 @@ def get_persistence() -> PicklePersistence: return PicklePersistence(filepath=persistence_path) +async def error_handler(update: object, context: ContextTypes.DEFAULT_TYPE) -> None: + """Handle errors gracefully.""" + if isinstance(context.error, NetworkError): + logger.warning(f"Network error (will retry): {context.error}") + return + + logger.exception("Exception while handling an update:", exc_info=context.error) + + def main() -> None: """Run the bot.""" # Setup persistence to save user data, chat data, and bot data @@ -526,6 +380,9 @@ def main() -> None: # Register all handlers register_handlers(application) + # Register error handler + application.add_error_handler(error_handler) + # Run the bot application.run_polling(allowed_updates=Update.ALL_TYPES) diff --git a/requirements.txt b/requirements.txt index 2f98647..e8de60c 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,5 @@ python-telegram-bot[job-queue] -hummingbot-api-client==1.2.5 +hummingbot-api-client==1.2.6 python-dotenv pytest pre-commit diff --git a/routines/arb_check.py b/routines/arb_check.py index 6405ea1..aadb1b8 100644 --- a/routines/arb_check.py +++ b/routines/arb_check.py @@ -4,7 +4,7 @@ from pydantic import BaseModel, Field from telegram.ext import ContextTypes -from servers import get_client +from config_manager import get_client class Config(BaseModel): @@ -20,7 +20,7 @@ class Config(BaseModel): async def run(config: Config, context: ContextTypes.DEFAULT_TYPE) -> str: """Check arbitrage between CEX and DEX.""" chat_id = context._chat_id if hasattr(context, '_chat_id') else None - client = await get_client(chat_id) + client = await get_client(chat_id, context=context) if not client: return "No server available. Configure servers in /config." diff --git a/routines/base.py b/routines/base.py index 99c6d60..6ba0fb5 100644 --- a/routines/base.py +++ b/routines/base.py @@ -2,8 +2,9 @@ Base classes and discovery for routines. Routine Types: -- Interval: Has `interval_sec` field in Config → runs repeatedly at interval -- One-shot: No `interval_sec` field → runs once and returns result +- One-shot: Runs once and returns result. Can be scheduled externally. +- Continuous: Has CONTINUOUS = True. Contains internal loop (while True). + Runs forever until cancelled. Handles its own timing. """ import importlib @@ -26,27 +27,29 @@ def __init__( name: str, config_class: type[BaseModel], run_fn: Callable[[BaseModel, Any], Awaitable[str]], + is_continuous: bool = False, + callback_handler: Callable | None = None, + message_handler: Callable | None = None, + message_states: list[str] | None = None, + cleanup_fn: Callable | None = None, ): self.name = name self.config_class = config_class self.run_fn = run_fn + self._is_continuous = is_continuous + self.callback_handler = callback_handler + self.message_handler = message_handler + self.message_states = message_states or [] + self.cleanup_fn = cleanup_fn # Extract description from Config docstring doc = config_class.__doc__ or name self.description = doc.strip().split("\n")[0] @property - def is_interval(self) -> bool: - """Check if this is an interval routine (has interval_sec field).""" - return "interval_sec" in self.config_class.model_fields - - @property - def default_interval(self) -> int: - """Get default interval in seconds (only for interval routines).""" - if not self.is_interval: - return 0 - field = self.config_class.model_fields["interval_sec"] - return field.default if field.default is not None else 5 + def is_continuous(self) -> bool: + """Check if this is a continuous routine (has CONTINUOUS = True in module).""" + return self._is_continuous def get_default_config(self) -> BaseModel: """Create config instance with default values.""" @@ -73,6 +76,7 @@ def discover_routines(force_reload: bool = False) -> dict[str, RoutineInfo]: Each routine module needs: - Config: Pydantic BaseModel with optional docstring description - run(config, context) -> str: Async function that executes the routine + - CONTINUOUS = True (optional): Mark as continuous routine with internal loop Args: force_reload: Force reimport of all modules @@ -106,12 +110,26 @@ def discover_routines(force_reload: bool = False) -> dict[str, RoutineInfo]: logger.warning(f"Routine {file_path.stem}: missing Config or run") continue + # Check for CONTINUOUS flag + is_continuous = getattr(module, "CONTINUOUS", False) + + # Detect optional handlers + callback_handler = getattr(module, "handle_callback", None) + message_handler = getattr(module, "handle_message", None) + message_states = getattr(module, "MESSAGE_STATES", None) + cleanup_fn = getattr(module, "cleanup", None) + routines[file_path.stem] = RoutineInfo( name=file_path.stem, config_class=module.Config, run_fn=module.run, + is_continuous=is_continuous, + callback_handler=callback_handler, + message_handler=message_handler, + message_states=message_states, + cleanup_fn=cleanup_fn, ) - logger.debug(f"Discovered routine: {file_path.stem}") + logger.debug(f"Discovered routine: {file_path.stem} (continuous={is_continuous})") except Exception as e: logger.error(f"Failed to load routine {file_path.stem}: {e}") @@ -123,3 +141,11 @@ def discover_routines(force_reload: bool = False) -> dict[str, RoutineInfo]: def get_routine(name: str) -> RoutineInfo | None: """Get a specific routine by name.""" return discover_routines().get(name) + + +def get_routine_by_state(state: str) -> RoutineInfo | None: + """Find a routine that handles the given message state.""" + for routine in discover_routines().values(): + if state in routine.message_states: + return routine + return None diff --git a/routines/lp_monitor.py b/routines/lp_monitor.py new file mode 100644 index 0000000..2d49362 --- /dev/null +++ b/routines/lp_monitor.py @@ -0,0 +1,782 @@ +""" +LP Monitor - Monitor LP positions for out-of-range and rebalance opportunities. + +Features: +- Alerts when positions go out of range (with Close button) +- Alerts when positions return to range +- Rebalance suggestions when base asset % drops below threshold +- Periodic status reports +- Optional auto-close or auto-rebalance with 30s countdown +""" + +import asyncio +import logging +import time +from pydantic import BaseModel, Field +from telegram import InlineKeyboardButton, InlineKeyboardMarkup +from telegram.ext import ContextTypes + +from config_manager import get_client +from utils.telegram_formatters import escape_markdown_v2, resolve_token_symbol, KNOWN_TOKENS + +logger = logging.getLogger(__name__) + +CONTINUOUS = True + + +class Config(BaseModel): + """Monitor LP positions for out-of-range and rebalance opportunities.""" + + check_interval_sec: int = Field( + default=60, + description="Check interval in seconds", + ) + status_report_hours: float = Field( + default=4.0, + description="Send status report every N hours (0=off)", + ) + rebalance_base_pct: float = Field( + default=0.2, + description="Suggest rebalance when base asset <= N (0.2=20%)", + ) + auto_close_oor: bool = Field( + default=False, + description="Auto-close out-of-range positions (30s delay)", + ) + auto_rebalance: bool = Field( + default=False, + description="Auto-rebalance when triggered (30s delay)", + ) + + +# ============================================================================= +# Helpers +# ============================================================================= + +def _get_pos_id(pos: dict) -> str: + return pos.get('id') or pos.get('position_id') or pos.get('address', '') or pos.get('position_address', '') + + +def _format_price(price: float) -> str: + if price >= 1: + return f"{price:.2f}" + elif price >= 0.001: + return f"{price:.4f}" + else: + return f"{price:.6f}" + + +async def _fetch_token_prices(client) -> dict: + """Fetch token prices from portfolio.""" + prices = {} + try: + if hasattr(client, 'portfolio'): + result = await client.portfolio.get_state() + if result: + for account_data in result.values(): + for balances in account_data.values(): + if balances: + for b in balances: + if b.get("token") and b.get("price"): + prices[b["token"]] = b["price"] + except Exception: + pass + return prices + + +def _get_price(symbol: str, prices: dict, default: float = 0) -> float: + if symbol in prices: + return prices[symbol] + for k, v in prices.items(): + if k.lower() == symbol.lower(): + return v + # Wrapped variants + variants = {"sol": "wsol", "wsol": "sol", "eth": "weth", "weth": "eth"} + alt = variants.get(symbol.lower()) + if alt: + for k, v in prices.items(): + if k.lower() == alt: + return v + return default + + +def _calc_base_pct(pos: dict, token_prices: dict, base_symbol: str, quote_symbol: str) -> float: + """Calculate what percentage of position value is in base asset.""" + base_amt = float(pos.get('base_token_amount', pos.get('amount_a', 0)) or 0) + quote_amt = float(pos.get('quote_token_amount', pos.get('amount_b', 0)) or 0) + + base_price = _get_price(base_symbol, token_prices, 0) + quote_price = _get_price(quote_symbol, token_prices, 1.0) + + base_value = base_amt * base_price + quote_value = quote_amt * quote_price + total = base_value + quote_value + + if total <= 0: + return 0.5 # Default to 50% if can't calculate + return base_value / total + + +def _draw_liquidity_bar(base_pct: float, width: int = 20) -> str: + """Draw a simple liquidity distribution bar.""" + base_blocks = int(base_pct * width) + quote_blocks = width - base_blocks + return f"[{'█' * base_blocks}{'░' * quote_blocks}]" + + +# ============================================================================= +# Position Formatting +# ============================================================================= + +def _format_position( + pos: dict, + token_cache: dict, + token_prices: dict, + index: int = None, +) -> str: + """Format a position for display.""" + base_token = pos.get('base_token', pos.get('token_a', '')) + quote_token = pos.get('quote_token', pos.get('token_b', '')) + base_sym = resolve_token_symbol(base_token, token_cache) + quote_sym = resolve_token_symbol(quote_token, token_cache) + pair = f"{base_sym}-{quote_sym}" + + connector = pos.get('connector', 'unknown') + fee = pos.get('fee_tier', pos.get('fee', '')) + fee_str = f" {fee}%" if fee else "" + + # Range status + in_range = pos.get('in_range', '') + status = "🟢" if in_range == "IN_RANGE" else "🔴" if in_range == "OUT_OF_RANGE" else "⚪" + + # Prices + lower = float(pos.get('lower_price', pos.get('price_lower', 0)) or 0) + upper = float(pos.get('upper_price', pos.get('price_upper', 0)) or 0) + current = float(pos.get('current_price', 0) or 0) + + width_pct = ((upper - lower) / lower * 100) if lower > 0 else 0 + + # Value & PnL + quote_price = _get_price(quote_sym, token_prices, 1.0) + base_price = _get_price(base_sym, token_prices, 0) + + pnl_summary = pos.get('pnl_summary', {}) + pnl_usd = float(pnl_summary.get('total_pnl_quote', 0) or 0) * quote_price + value_usd = float(pnl_summary.get('current_lp_value_quote', 0) or 0) * quote_price + + base_fee = float(pos.get('base_fee_pending', 0) or 0) + quote_fee = float(pos.get('quote_fee_pending', 0) or 0) + fees_usd = (base_fee * base_price) + (quote_fee * quote_price) + + # Base asset percentage + base_pct = _calc_base_pct(pos, token_prices, base_sym, quote_sym) + liq_bar = _draw_liquidity_bar(base_pct) + + pnl_sign = "\\+" if pnl_usd >= 0 else "\\-" + prefix = f"{index}\\. " if index else "" + + lines = [ + f"{prefix}*{escape_markdown_v2(pair)}{escape_markdown_v2(fee_str)}* \\({escape_markdown_v2(connector.capitalize())}\\)", + f" {status} \\[{escape_markdown_v2(_format_price(lower))} \\- {escape_markdown_v2(_format_price(upper))}\\] \\({escape_markdown_v2(f'{width_pct:.0f}')}% width\\)", + f" Price: {escape_markdown_v2(_format_price(current))}", + f" {liq_bar} {escape_markdown_v2(f'{base_pct*100:.0f}')}% {escape_markdown_v2(base_sym)}", + f" 💰 ${escape_markdown_v2(f'{value_usd:.2f}')} \\| PnL: {pnl_sign}${escape_markdown_v2(f'{abs(pnl_usd):.2f}')} \\| 🎁 ${escape_markdown_v2(f'{fees_usd:.2f}')}", + ] + return "\n".join(lines) + + +# ============================================================================= +# Notifications +# ============================================================================= + +async def _notify_out_of_range( + context, chat_id: int, pos: dict, token_cache: dict, token_prices: dict, + instance_id: str, user_data: dict, auto_close: bool +) -> None: + """Notify when position goes out of range.""" + base_sym = resolve_token_symbol(pos.get('base_token', ''), token_cache) + quote_sym = resolve_token_symbol(pos.get('quote_token', ''), token_cache) + pair = f"{base_sym}-{quote_sym}" + + current = float(pos.get('current_price', 0) or 0) + lower = float(pos.get('lower_price', 0) or 0) + upper = float(pos.get('upper_price', 0) or 0) + + direction = "▼ Below" if current < lower else "▲ Above" + + # Store position for callbacks + pos_id = _get_pos_id(pos) + cache_key = f"lpm_{instance_id}_{pos_id[:8]}" + user_data.setdefault("positions_cache", {})[cache_key] = pos + + text = ( + f"🔴 *Position Out of Range*\n\n" + f"*{escape_markdown_v2(pair)}*\n" + f"_{escape_markdown_v2(direction)} range_\n" + f"Current: {escape_markdown_v2(_format_price(current))}\n" + f"Range: {escape_markdown_v2(_format_price(lower))} \\- {escape_markdown_v2(_format_price(upper))}" + ) + + keyboard = [[ + InlineKeyboardButton("❌ Close Position", callback_data=f"dex:pos_close:{cache_key}"), + InlineKeyboardButton("✅ Dismiss", callback_data=f"dex:lpm_dismiss:{instance_id}"), + ]] + + try: + msg = await context.bot.send_message( + chat_id=chat_id, text=text, parse_mode="MarkdownV2", + reply_markup=InlineKeyboardMarkup(keyboard) + ) + + if auto_close: + asyncio.create_task(_countdown_action( + context, chat_id, msg.message_id, pos, "close", + cache_key, instance_id, user_data, token_cache + )) + except Exception as e: + logger.error(f"Failed to send OOR notification: {e}") + + +async def _notify_back_in_range(context, chat_id: int, pos: dict, token_cache: dict) -> None: + """Notify when position returns to range.""" + base_sym = resolve_token_symbol(pos.get('base_token', ''), token_cache) + quote_sym = resolve_token_symbol(pos.get('quote_token', ''), token_cache) + pair = f"{base_sym}-{quote_sym}" + + current = float(pos.get('current_price', 0) or 0) + + text = ( + f"🟢 *Position Back in Range*\n\n" + f"*{escape_markdown_v2(pair)}*\n" + f"Current: {escape_markdown_v2(_format_price(current))}" + ) + + try: + await context.bot.send_message(chat_id=chat_id, text=text, parse_mode="MarkdownV2") + except Exception as e: + logger.error(f"Failed to send back-in-range notification: {e}") + + +async def _notify_rebalance( + context, chat_id: int, pos: dict, token_cache: dict, token_prices: dict, + base_pct: float, instance_id: str, user_data: dict, auto_rebalance: bool +) -> None: + """Notify when position should be rebalanced (base asset too low).""" + base_sym = resolve_token_symbol(pos.get('base_token', ''), token_cache) + quote_sym = resolve_token_symbol(pos.get('quote_token', ''), token_cache) + pair = f"{base_sym}-{quote_sym}" + + # Current distribution + current_bar = _draw_liquidity_bar(base_pct) + # Proposed bid-ask distribution (roughly 50/50 at current price) + proposed_bar = _draw_liquidity_bar(0.5) + + lower = float(pos.get('lower_price', 0) or 0) + upper = float(pos.get('upper_price', 0) or 0) + current = float(pos.get('current_price', 0) or 0) + + pos_id = _get_pos_id(pos) + cache_key = f"lpm_{instance_id}_{pos_id[:8]}" + user_data.setdefault("positions_cache", {})[cache_key] = pos + + text = ( + f"⚖️ *Rebalance Suggestion*\n\n" + f"*{escape_markdown_v2(pair)}*\n" + f"Base asset dropped to {escape_markdown_v2(f'{base_pct*100:.0f}')}%\n\n" + f"*Current Distribution:*\n" + f"`{current_bar}` {escape_markdown_v2(f'{base_pct*100:.0f}')}% {escape_markdown_v2(base_sym)}\n\n" + f"*Proposed \\(Bid\\-Ask\\):*\n" + f"`{proposed_bar}` 50% {escape_markdown_v2(base_sym)}\n\n" + f"Range: {escape_markdown_v2(_format_price(lower))} \\- {escape_markdown_v2(_format_price(upper))}\n" + f"Price: {escape_markdown_v2(_format_price(current))}" + ) + + keyboard = [[ + InlineKeyboardButton("🔄 Rebalance", callback_data=f"dex:lpm_rebal:{cache_key}"), + InlineKeyboardButton("✅ Dismiss", callback_data=f"dex:lpm_dismiss:{instance_id}"), + ]] + + try: + msg = await context.bot.send_message( + chat_id=chat_id, text=text, parse_mode="MarkdownV2", + reply_markup=InlineKeyboardMarkup(keyboard) + ) + + if auto_rebalance: + asyncio.create_task(_countdown_action( + context, chat_id, msg.message_id, pos, "rebalance", + cache_key, instance_id, user_data, token_cache + )) + except Exception as e: + logger.error(f"Failed to send rebalance notification: {e}") + + +# ============================================================================= +# Auto Actions with Countdown +# ============================================================================= + +async def _countdown_action( + context, chat_id: int, msg_id: int, pos: dict, action: str, + cache_key: str, instance_id: str, user_data: dict, token_cache: dict +) -> None: + """30s countdown before auto action. Can be cancelled.""" + base_sym = resolve_token_symbol(pos.get('base_token', ''), token_cache) + quote_sym = resolve_token_symbol(pos.get('quote_token', ''), token_cache) + pair = f"{base_sym}-{quote_sym}" + + cancel_key = f"cancel_{cache_key}" + action_name = "Close" if action == "close" else "Rebalance" + + for remaining in [30, 20, 10, 5]: + if user_data.get(cancel_key): + user_data.pop(cancel_key, None) + return + + text = ( + f"⏱️ *Auto\\-{action_name} in {remaining}s*\n\n" + f"*{escape_markdown_v2(pair)}*\n\n" + f"_Press Cancel to stop_" + ) + + keyboard = [[ + InlineKeyboardButton("❌ Cancel", callback_data=f"dex:lpm_cancel:{cache_key}"), + ]] + + try: + await context.bot.edit_message_text( + chat_id=chat_id, message_id=msg_id, text=text, + parse_mode="MarkdownV2", reply_markup=InlineKeyboardMarkup(keyboard) + ) + except Exception: + pass + + wait = 10 if remaining > 10 else 5 + await asyncio.sleep(wait) + + # Final check + if user_data.get(cancel_key): + user_data.pop(cancel_key, None) + return + + # Execute action + client = await get_client(chat_id, context=context) + if not client: + return + + if action == "close": + await _execute_close(context, chat_id, msg_id, pos, client, token_cache) + else: + await _execute_rebalance(context, chat_id, msg_id, pos, client, token_cache) + + +async def _execute_close(context, chat_id: int, msg_id: int, pos: dict, client, token_cache: dict) -> None: + """Execute position close.""" + base_sym = resolve_token_symbol(pos.get('base_token', ''), token_cache) + quote_sym = resolve_token_symbol(pos.get('quote_token', ''), token_cache) + pair = f"{base_sym}-{quote_sym}" + + connector = pos.get('connector', 'meteora') + network = pos.get('network', 'solana-mainnet-beta') + pos_addr = pos.get('position_address', pos.get('nft_id', '')) + + try: + await context.bot.edit_message_text( + chat_id=chat_id, message_id=msg_id, + text=f"⏳ *Closing\\.\\.\\.*\n\n*{escape_markdown_v2(pair)}*", + parse_mode="MarkdownV2" + ) + + result = await client.gateway_clmm.close_position( + connector=connector, network=network, position_address=pos_addr + ) + + if result: + tx = result.get('tx_hash', 'N/A')[:16] + base_amt = float(result.get('base_amount', 0) or 0) + quote_amt = float(result.get('quote_amount', 0) or 0) + + text = ( + f"✅ *Position Closed*\n\n" + f"*{escape_markdown_v2(pair)}*\n" + f"Received: {escape_markdown_v2(f'{base_amt:.6f}')} {escape_markdown_v2(base_sym)}\n" + f"Received: {escape_markdown_v2(f'{quote_amt:.6f}')} {escape_markdown_v2(quote_sym)}\n\n" + f"Tx: `{escape_markdown_v2(tx)}...`" + ) + else: + text = f"❌ *Close failed*\n\nNo response" + + await context.bot.edit_message_text( + chat_id=chat_id, message_id=msg_id, text=text, parse_mode="MarkdownV2" + ) + except Exception as e: + logger.error(f"Close failed: {e}") + try: + await context.bot.edit_message_text( + chat_id=chat_id, message_id=msg_id, + text=f"❌ *Close failed*\n\n{escape_markdown_v2(str(e)[:100])}", + parse_mode="MarkdownV2" + ) + except Exception: + pass + + +async def _execute_rebalance(context, chat_id: int, msg_id: int, pos: dict, client, token_cache: dict) -> None: + """Execute rebalance: close and reopen with bid-ask strategy.""" + from decimal import Decimal + + base_sym = resolve_token_symbol(pos.get('base_token', ''), token_cache) + quote_sym = resolve_token_symbol(pos.get('quote_token', ''), token_cache) + pair = f"{base_sym}-{quote_sym}" + + connector = pos.get('connector', 'meteora') + network = pos.get('network', 'solana-mainnet-beta') + pos_addr = pos.get('position_address', pos.get('nft_id', '')) + pool_addr = pos.get('pool_id', pos.get('pool_address', '')) + lower = pos.get('lower_price', pos.get('price_lower', 0)) + upper = pos.get('upper_price', pos.get('price_upper', 0)) + + try: + # Step 1: Close + await context.bot.edit_message_text( + chat_id=chat_id, message_id=msg_id, + text=f"🔄 *Rebalancing\\.\\.\\.*\n\n*{escape_markdown_v2(pair)}*\n\n1/2 Closing position\\.\\.\\.", + parse_mode="MarkdownV2" + ) + + close_result = await client.gateway_clmm.close_position( + connector=connector, network=network, position_address=pos_addr + ) + + if not close_result: + await context.bot.edit_message_text( + chat_id=chat_id, message_id=msg_id, + text=f"❌ *Rebalance failed*\n\nCould not close position", + parse_mode="MarkdownV2" + ) + return + + base_amt = float(close_result.get('base_amount', close_result.get('amount_base', 0)) or 0) + quote_amt = float(close_result.get('quote_amount', close_result.get('amount_quote', 0)) or 0) + + # Fallback to position amounts + if not base_amt: + base_amt = float(pos.get('base_token_amount', pos.get('amount_a', 0)) or 0) + if not quote_amt: + quote_amt = float(pos.get('quote_token_amount', pos.get('amount_b', 0)) or 0) + + # Step 2: Reopen with bid-ask strategy + await context.bot.edit_message_text( + chat_id=chat_id, message_id=msg_id, + text=f"🔄 *Rebalancing\\.\\.\\.*\n\n*{escape_markdown_v2(pair)}*\n\n✅ Closed\n2/2 Opening with bid\\-ask\\.\\.\\.", + parse_mode="MarkdownV2" + ) + + open_result = await client.gateway_clmm.open_position( + connector=connector, + network=network, + pool_address=pool_addr, + lower_price=Decimal(str(lower)), + upper_price=Decimal(str(upper)), + base_token_amount=base_amt, + quote_token_amount=quote_amt, + extra_params={"strategyType": 2} # Bid-Ask + ) + + if open_result: + tx = (open_result.get('tx_hash') or open_result.get('signature', 'N/A'))[:16] + text = ( + f"✅ *Rebalance Complete*\n\n" + f"*{escape_markdown_v2(pair)}*\n" + f"Strategy: Bid\\-Ask\n" + f"Range: {escape_markdown_v2(_format_price(float(lower)))} \\- {escape_markdown_v2(_format_price(float(upper)))}\n\n" + f"Tx: `{escape_markdown_v2(tx)}...`" + ) + else: + text = ( + f"⚠️ *Partial Rebalance*\n\n" + f"*{escape_markdown_v2(pair)}*\n" + f"Position closed but failed to reopen\\.\n" + f"Funds are in your wallet\\." + ) + + await context.bot.edit_message_text( + chat_id=chat_id, message_id=msg_id, text=text, parse_mode="MarkdownV2" + ) + except Exception as e: + logger.error(f"Rebalance failed: {e}") + try: + await context.bot.edit_message_text( + chat_id=chat_id, message_id=msg_id, + text=f"❌ *Rebalance failed*\n\n{escape_markdown_v2(str(e)[:100])}", + parse_mode="MarkdownV2" + ) + except Exception: + pass + + +# ============================================================================= +# Status Report +# ============================================================================= + +async def _send_status_report( + context, chat_id: int, positions: list, token_cache: dict, token_prices: dict, + instance_id: str, user_data: dict +) -> None: + """Send periodic status report.""" + if not positions: + await context.bot.send_message( + chat_id=chat_id, + text="📊 *LP Monitor*\n\nNo active positions\\.", + parse_mode="MarkdownV2" + ) + return + + # Calculate totals + total_value = 0 + total_pnl = 0 + total_fees = 0 + in_range_count = 0 + + for pos in positions: + base_sym = resolve_token_symbol(pos.get('base_token', ''), token_cache) + quote_sym = resolve_token_symbol(pos.get('quote_token', ''), token_cache) + quote_price = _get_price(quote_sym, token_prices, 1.0) + base_price = _get_price(base_sym, token_prices, 0) + + pnl_summary = pos.get('pnl_summary', {}) + total_pnl += float(pnl_summary.get('total_pnl_quote', 0) or 0) * quote_price + total_value += float(pnl_summary.get('current_lp_value_quote', 0) or 0) * quote_price + + base_fee = float(pos.get('base_fee_pending', 0) or 0) + quote_fee = float(pos.get('quote_fee_pending', 0) or 0) + total_fees += (base_fee * base_price) + (quote_fee * quote_price) + + if pos.get('in_range') == "IN_RANGE": + in_range_count += 1 + + oor_count = len(positions) - in_range_count + pnl_sign = "\\+" if total_pnl >= 0 else "\\-" + + lines = [ + f"📊 *LP Monitor Status*", + f"━━━━━━━━━━━━━━━━━━━━━", + f"💰 ${escape_markdown_v2(f'{total_value:.2f}')} \\| PnL: {pnl_sign}${escape_markdown_v2(f'{abs(total_pnl):.2f}')} \\| 🎁 ${escape_markdown_v2(f'{total_fees:.2f}')}", + f"✅ {in_range_count} in range \\| 🔴 {oor_count} out of range", + "", + ] + + # Format each position + for i, pos in enumerate(positions, 1): + lines.append(_format_position(pos, token_cache, token_prices, i)) + lines.append("") + + text = "\n".join(lines) + + # Store positions for callbacks and build action buttons + keyboard = [] + for i, pos in enumerate(positions): + pos_id = _get_pos_id(pos) + cache_key = f"lpm_{instance_id}_{pos_id[:8]}" + user_data.setdefault("positions_cache", {})[cache_key] = pos + + # Get pair name for button + base_sym = resolve_token_symbol(pos.get('base_token', ''), token_cache) + quote_sym = resolve_token_symbol(pos.get('quote_token', ''), token_cache) + pair = f"{base_sym}-{quote_sym}" + in_range = pos.get('in_range', '') + status = "🟢" if in_range == "IN_RANGE" else "🔴" + + # Add row with close and rebalance for each position + keyboard.append([ + InlineKeyboardButton(f"{status} {pair}", callback_data="noop"), + InlineKeyboardButton("❌ Close", callback_data=f"dex:pos_close:{cache_key}"), + InlineKeyboardButton("🔄 Rebalance", callback_data=f"dex:lpm_rebal:{cache_key}"), + ]) + + keyboard.append([InlineKeyboardButton("💰 Collect All Fees", callback_data=f"dex:lpm_collect_all:{instance_id}")]) + keyboard.append([InlineKeyboardButton("⏹ Stop Monitor", callback_data=f"routines:stop:{instance_id}")]) + + try: + await context.bot.send_message( + chat_id=chat_id, text=text, parse_mode="MarkdownV2", + reply_markup=InlineKeyboardMarkup(keyboard) + ) + except Exception as e: + logger.error(f"Failed to send status report: {e}") + + +# ============================================================================= +# Main Loop +# ============================================================================= + +async def run(config: Config, context: ContextTypes.DEFAULT_TYPE) -> str: + """Monitor LP positions.""" + logger.info(f"LP Monitor starting: interval={config.check_interval_sec}s, report={config.status_report_hours}h") + + chat_id = context._chat_id if hasattr(context, '_chat_id') else None + instance_id = getattr(context, '_instance_id', 'default') + + if not chat_id: + return "No chat_id" + + client = await get_client(chat_id, context=context) + if not client or not hasattr(client, 'gateway_clmm'): + return "Gateway not available" + + user_data = context.user_data if hasattr(context, 'user_data') else {} + + # State + state = { + "out_of_range": set(), # pos_ids currently OOR + "rebalance_notified": set(), # pos_ids we've notified for rebalance + "checks": 0, + "alerts": 0, + "start_time": time.time(), + "last_report": time.time(), + } + + # Start message + features = [] + if config.auto_close_oor: + features.append("Auto\\-close OOR") + if config.auto_rebalance: + features.append("Auto\\-rebalance") + if config.status_report_hours > 0: + features.append(f"Report every {config.status_report_hours}h") + + features_str = " \\| ".join(features) if features else "Manual mode" + + try: + await context.bot.send_message( + chat_id=chat_id, + text=( + f"🟢 *LP Monitor Started*\n" + f"━━━━━━━━━━━━━━━━━━━━━\n" + f"Checking every {config.check_interval_sec}s\n" + f"Rebalance trigger: ≤{escape_markdown_v2(f'{config.rebalance_base_pct*100:.0f}')}% base\n" + f"{features_str}" + ), + parse_mode="MarkdownV2" + ) + except Exception: + pass + + try: + while True: + try: + state["checks"] += 1 + + # Get client + client = await get_client(chat_id, context=context) + if not client or not hasattr(client, 'gateway_clmm'): + await asyncio.sleep(config.check_interval_sec) + continue + + # Fetch positions + result = await client.gateway_clmm.search_positions( + limit=100, offset=0, status="OPEN", refresh=True + ) + + if not result: + await asyncio.sleep(config.check_interval_sec) + continue + + positions = [p for p in result.get("data", []) + if p.get('status') != 'CLOSED' and + float(p.get('liquidity', p.get('current_liquidity', 1)) or 1) > 0] + + if not positions: + await asyncio.sleep(config.check_interval_sec) + continue + + # Build caches + token_cache = dict(KNOWN_TOKENS) + networks = set(p.get('network', 'solana-mainnet-beta') for p in positions) + if hasattr(client, 'gateway'): + for net in networks: + try: + resp = await client.gateway.get_network_tokens(net) + for t in (resp.get('tokens', []) if resp else []): + if t.get('address') and t.get('symbol'): + token_cache[t['address']] = t['symbol'] + except Exception: + pass + + token_prices = await _fetch_token_prices(client) + user_data["token_cache"] = token_cache + user_data["token_prices"] = token_prices + + # Check each position + current_oor = set() + + for pos in positions: + pos_id = _get_pos_id(pos) + in_range = pos.get('in_range', '') + + # Out-of-range detection + if in_range == "OUT_OF_RANGE": + current_oor.add(pos_id) + + if pos_id not in state["out_of_range"]: + # Newly out of range + state["alerts"] += 1 + await _notify_out_of_range( + context, chat_id, pos, token_cache, token_prices, + instance_id, user_data, config.auto_close_oor + ) + + elif pos_id in state["out_of_range"]: + # Back in range + await _notify_back_in_range(context, chat_id, pos, token_cache) + + # Rebalance check (only for in-range positions not already notified) + if in_range == "IN_RANGE" and pos_id not in state["rebalance_notified"]: + base_sym = resolve_token_symbol(pos.get('base_token', ''), token_cache) + quote_sym = resolve_token_symbol(pos.get('quote_token', ''), token_cache) + base_pct = _calc_base_pct(pos, token_prices, base_sym, quote_sym) + + if base_pct <= config.rebalance_base_pct: + state["rebalance_notified"].add(pos_id) + state["alerts"] += 1 + await _notify_rebalance( + context, chat_id, pos, token_cache, token_prices, + base_pct, instance_id, user_data, config.auto_rebalance + ) + + state["out_of_range"] = current_oor + + # Periodic status report + if config.status_report_hours > 0: + elapsed = time.time() - state["last_report"] + if elapsed >= config.status_report_hours * 3600: + state["last_report"] = time.time() + await _send_status_report( + context, chat_id, positions, token_cache, token_prices, + instance_id, user_data + ) + + # Log progress + if state["checks"] % 10 == 0: + logger.info(f"LP Monitor #{state['checks']}: {len(positions)} positions, {state['alerts']} alerts") + + except asyncio.CancelledError: + raise + except Exception as e: + logger.error(f"LP Monitor error: {e}", exc_info=True) + + await asyncio.sleep(config.check_interval_sec) + + except asyncio.CancelledError: + elapsed = int(time.time() - state["start_time"]) + mins, secs = divmod(elapsed, 60) + + try: + await context.bot.send_message( + chat_id=chat_id, + text=f"🔴 *LP Monitor Stopped*\n{mins}m {secs}s \\| {state['checks']} checks \\| {state['alerts']} alerts", + parse_mode="MarkdownV2" + ) + except Exception: + pass + + return f"Stopped: {mins}m {secs}s, {state['checks']} checks, {state['alerts']} alerts" diff --git a/routines/lp_tpsl.py b/routines/lp_tpsl.py new file mode 100644 index 0000000..2f3c468 --- /dev/null +++ b/routines/lp_tpsl.py @@ -0,0 +1,1001 @@ +"""LP Position TP/SL Monitor - Multi-position Take Profit and Stop Loss with runtime editing.""" + +import asyncio +import logging +import re +import time +from typing import NamedTuple, Literal +from pydantic import BaseModel, Field +from telegram import Update, InlineKeyboardButton, InlineKeyboardMarkup +from telegram.ext import ContextTypes + +from config_manager import get_client +from utils.telegram_formatters import escape_markdown_v2, resolve_token_symbol, KNOWN_TOKENS + +logger = logging.getLogger(__name__) + +# Mark as continuous routine - has internal loop +CONTINUOUS = True + +# Message states this routine handles (for generic routing) +MESSAGE_STATES = ["tpsl_interactive"] + + +# ============================================================================= +# HELPER FUNCTIONS +# ============================================================================= + +async def _fetch_token_prices(client) -> dict: + """Fetch token prices from portfolio state.""" + token_prices = {} + try: + if hasattr(client, 'portfolio'): + result = await client.portfolio.get_state() + if result: + for account_data in result.values(): + for balances in account_data.values(): + if balances: + for balance in balances: + token = balance.get("token", "") + price = balance.get("price", 0) + if token and price: + token_prices[token] = price + except Exception as e: + logger.debug(f"Could not fetch token prices: {e}") + return token_prices + + +def _get_price(symbol: str, token_prices: dict, default: float = 0) -> float: + """Get token price with fallbacks for wrapped variants.""" + if symbol in token_prices: + return token_prices[symbol] + symbol_lower = symbol.lower() + for key, price in token_prices.items(): + if key.lower() == symbol_lower: + return price + # Wrapped variants + variants = { + "sol": ["wsol"], "wsol": ["sol"], + "eth": ["weth"], "weth": ["eth"], + } + for variant in variants.get(symbol_lower, []): + for key, price in token_prices.items(): + if key.lower() == variant: + return price + return default + + +def _calculate_position_value_usd(pos: dict, token_cache: dict, token_prices: dict) -> float: + """Calculate position value in USD.""" + base_token = pos.get('base_token', pos.get('token_a', '')) + quote_token = pos.get('quote_token', pos.get('token_b', '')) + quote_symbol = resolve_token_symbol(quote_token, token_cache) + + quote_price = _get_price(quote_symbol, token_prices, 1.0) + + pnl_summary = pos.get('pnl_summary', {}) + current_lp_value_quote = float(pnl_summary.get('current_lp_value_quote', 0) or 0) + + return current_lp_value_quote * quote_price + + +def _get_user_data(context) -> dict: + """Get user_data from MockContext or regular context.""" + chat_id = context._chat_id if hasattr(context, '_chat_id') else None + if hasattr(context, 'application') and context.application: + app_user_data = context.application.user_data + if chat_id not in app_user_data: + app_user_data[chat_id] = {} + return app_user_data[chat_id] + return getattr(context, '_user_data', {}) + + +def _get_state(context, instance_id: str) -> dict: + """Get or initialize the TP/SL state for this instance.""" + user_data = _get_user_data(context) + key = f"lp_tpsl_{instance_id}" + if key not in user_data: + user_data[key] = { + "mode": "setup", + "available_positions": [], + "tracked_positions": {}, + "global_defaults": {"tp_pct": 10.0, "sl_pct": 10.0}, + "token_cache": {}, + "checks": 0, + "start_time": time.time(), + "last_check": None, + "status_msg_id": None, + } + return user_data[key] + + +async def _delete_after(msg, seconds: int): + """Delete a message after a delay.""" + await asyncio.sleep(seconds) + try: + await msg.delete() + except Exception: + pass + + +# ============================================================================= +# COMMAND PARSING +# ============================================================================= + +class ParsedCommand(NamedTuple): + type: Literal["select", "tp", "sl", "remove", "add", "status", "help", "unknown"] + position_num: int | None + value: float | None + value_type: str | None # "pct" or "usd" + + +def parse_tpsl_command(text: str) -> ParsedCommand: + """Parse runtime TP/SL command from user message.""" + text = text.strip().lower() + + # Position selection: just a number like "1" or "2" + if text.isdigit(): + return ParsedCommand("select", int(text), None, None) + + # Status command + if text in ("status", "s", "list", "ls"): + return ParsedCommand("status", None, None, None) + + # Add command + if text in ("add", "a", "+"): + return ParsedCommand("add", None, None, None) + + # Help command + if text in ("help", "h", "?"): + return ParsedCommand("help", None, None, None) + + # Remove command: "remove 2" or "rm 2" or "del 2" or "-2" + remove_match = re.match(r"(?:remove|rm|del|-)\s*(\d+)", text) + if remove_match: + return ParsedCommand("remove", int(remove_match.group(1)), None, None) + + # TP/SL with optional position: "1 tp=25%" or "tp=15%" or "sl=$50" + tpsl_match = re.match( + r"(?:(\d+)\s+)?(tp|sl)\s*=\s*(\$?)(\d+(?:\.\d+)?)(%?)", + text + ) + if tpsl_match: + pos_num = int(tpsl_match.group(1)) if tpsl_match.group(1) else None + cmd_type = tpsl_match.group(2) # "tp" or "sl" + is_usd = tpsl_match.group(3) == "$" + value = float(tpsl_match.group(4)) + is_pct = tpsl_match.group(5) == "%" + + value_type = "usd" if is_usd else "pct" + return ParsedCommand(cmd_type, pos_num, value, value_type) + + return ParsedCommand("unknown", None, None, None) + + +# ============================================================================= +# MESSAGE HANDLER (called from handlers/routines/__init__.py) +# ============================================================================= + +async def handle_tpsl_message(update: Update, context: ContextTypes.DEFAULT_TYPE) -> None: + """Process interactive commands while routine is running.""" + text = update.message.text.strip() + instance_id = context.user_data.get("tpsl_active_instance") + + if not instance_id: + return + + state = _get_state(context, instance_id) + cmd = parse_tpsl_command(text) + + # Delete user message for cleaner interface + try: + await update.message.delete() + except Exception: + pass + + chat_id = update.effective_chat.id + + if cmd.type == "select": + await _handle_position_select(context, chat_id, state, cmd.position_num, instance_id) + elif cmd.type == "tp": + await _handle_set_tp(context, chat_id, state, cmd, instance_id) + elif cmd.type == "sl": + await _handle_set_sl(context, chat_id, state, cmd, instance_id) + elif cmd.type == "remove": + await _handle_remove_position(context, chat_id, state, cmd.position_num, instance_id) + elif cmd.type == "add": + await _show_available_positions(context, chat_id, state, instance_id) + elif cmd.type == "status": + await _show_status(context, chat_id, state, instance_id) + elif cmd.type == "help": + await _show_help(context, chat_id) + else: + msg = await context.bot.send_message( + chat_id=chat_id, + text="Unknown command\\. Send `help` for available commands\\.", + parse_mode="MarkdownV2" + ) + asyncio.create_task(_delete_after(msg, 5)) + + +# ============================================================================= +# EXPORTED HANDLERS (for generic routing) +# ============================================================================= + + +async def handle_callback( + update: Update, + context: ContextTypes.DEFAULT_TYPE, + action: str, + params: list[str], +) -> None: + """ + Handle routine-specific callbacks. + + Callback patterns: + - routines:lp_tpsl:continue:{instance_id}:{pos_id_short} + - routines:lp_tpsl:remove:{instance_id}:{pos_id_short} + """ + query = update.callback_query + + if action == "continue" and len(params) >= 2: + instance_id, pos_id_short = params[0], params[1] + await query.answer("Continuing to monitor...") + state_key = f"lp_tpsl_{instance_id}" + if state_key in context.user_data: + state = context.user_data[state_key] + for pid, pdata in state.get("tracked_positions", {}).items(): + if pid.startswith(pos_id_short) or pid[:8] == pos_id_short: + pdata["triggered"] = None + break + try: + await query.message.delete() + except Exception: + pass + + elif action == "remove" and len(params) >= 2: + instance_id, pos_id_short = params[0], params[1] + await query.answer("Removed from monitor") + state_key = f"lp_tpsl_{instance_id}" + if state_key in context.user_data: + state = context.user_data[state_key] + tracked = state.get("tracked_positions", {}) + for pid in list(tracked.keys()): + if pid.startswith(pos_id_short) or pid[:8] == pos_id_short: + del tracked[pid] + break + try: + await query.message.delete() + except Exception: + pass + + else: + await query.answer() + + +async def handle_message(update: Update, context: ContextTypes.DEFAULT_TYPE) -> bool: + """Handle messages when routines_state is 'tpsl_interactive'.""" + instance_id = context.user_data.get("tpsl_active_instance") + if not instance_id: + return False + + await handle_tpsl_message(update, context) + return True + + +async def cleanup(context: ContextTypes.DEFAULT_TYPE, instance_id: str, chat_id: int) -> None: + """Clean up routine state when instance stops.""" + user_data = _get_user_data(context) + + # Clear interactive state if this instance was active + if user_data.get("tpsl_active_instance") == instance_id: + user_data.pop("routines_state", None) + user_data.pop("tpsl_active_instance", None) + + # Clear instance-specific state + user_data.pop(f"lp_tpsl_{instance_id}", None) + + +async def _handle_position_select(context, chat_id: int, state: dict, pos_num: int, instance_id: str): + """Handle position selection by number.""" + available = state.get("available_positions", []) + + if pos_num < 1 or pos_num > len(available): + msg = await context.bot.send_message( + chat_id=chat_id, + text=f"Invalid position number\\. Choose 1\\-{len(available)}\\.", + parse_mode="MarkdownV2" + ) + asyncio.create_task(_delete_after(msg, 5)) + return + + pos = available[pos_num - 1] + pos_id = pos.get('id') or pos.get('position_id') or pos.get('address', '') + + # Check if already tracked + if pos_id in state["tracked_positions"]: + msg = await context.bot.send_message( + chat_id=chat_id, + text="Position already being tracked\\!", + parse_mode="MarkdownV2" + ) + asyncio.create_task(_delete_after(msg, 3)) + return + + # Get position info + token_cache = state.get("token_cache", {}) + token_prices = state.get("token_prices", {}) + base_token = pos.get('base_token', pos.get('token_a', '')) + quote_token = pos.get('quote_token', pos.get('token_b', '')) + base_symbol = resolve_token_symbol(base_token, token_cache) + quote_symbol = resolve_token_symbol(quote_token, token_cache) + pair = f"{base_symbol}-{quote_symbol}" + connector = pos.get('connector', 'unknown') + network = pos.get('network', 'solana-mainnet-beta') + + entry_value = _calculate_position_value_usd(pos, token_cache, token_prices) + + # Get defaults + defaults = state.get("global_defaults", {}) + tp_pct = defaults.get("tp_pct", 10.0) + sl_pct = defaults.get("sl_pct", 10.0) + + # Add to tracked positions + state["tracked_positions"][pos_id] = { + "position_id": pos_id, + "pair": pair, + "connector": connector, + "network": network, + "base_token": base_token, + "quote_token": quote_token, + "entry_value_usd": entry_value, + "tp_pct": tp_pct, + "sl_pct": sl_pct, + "triggered": None, + "current_value_usd": entry_value, + "high_value_usd": entry_value, + "low_value_usd": entry_value, + "added_at": time.time(), + "display_num": len(state["tracked_positions"]) + 1, + } + + # Calculate TP/SL values + tp_value = entry_value * (1 + tp_pct / 100) + sl_value = entry_value * (1 - sl_pct / 100) + + # Store position for close button + user_data = _get_user_data(context) + if "positions_cache" not in user_data: + user_data["positions_cache"] = {} + cache_key = f"tpsl_{instance_id}_{pos_id[:8]}" + user_data["positions_cache"][cache_key] = pos + + await context.bot.send_message( + chat_id=chat_id, + text=( + f"\\#️⃣ *Added Position \\#{len(state['tracked_positions'])}*\n\n" + f"*{escape_markdown_v2(pair)}* \\({escape_markdown_v2(connector)}\\)\n" + f"Entry: ${escape_markdown_v2(f'{entry_value:.2f}')}\n\n" + f"📈 TP: \\+{escape_markdown_v2(f'{tp_pct:.0f}')}% \\(${escape_markdown_v2(f'{tp_value:.2f}')}\\)\n" + f"📉 SL: \\-{escape_markdown_v2(f'{sl_pct:.0f}')}% \\(${escape_markdown_v2(f'{sl_value:.2f}')}\\)\n\n" + f"_Send another number to add more, or `status` to view all_" + ), + parse_mode="MarkdownV2" + ) + + +async def _handle_set_tp(context, chat_id: int, state: dict, cmd: ParsedCommand, instance_id: str): + """Handle setting take profit.""" + tracked = state.get("tracked_positions", {}) + + if not tracked: + msg = await context.bot.send_message( + chat_id=chat_id, + text="No positions tracked yet\\. Add one first\\!", + parse_mode="MarkdownV2" + ) + asyncio.create_task(_delete_after(msg, 5)) + return + + value = cmd.value + value_type = cmd.value_type + + if cmd.position_num: + # Update specific position + pos_list = list(tracked.values()) + if cmd.position_num < 1 or cmd.position_num > len(pos_list): + msg = await context.bot.send_message( + chat_id=chat_id, + text=f"Invalid position number\\. Choose 1\\-{len(pos_list)}\\.", + parse_mode="MarkdownV2" + ) + asyncio.create_task(_delete_after(msg, 5)) + return + + pos_data = pos_list[cmd.position_num - 1] + if value_type == "usd": + # Convert USD to percentage + entry = pos_data["entry_value_usd"] + if entry > 0: + pos_data["tp_pct"] = ((value - entry) / entry) * 100 + else: + pos_data["tp_pct"] = value + + tp_value = pos_data["entry_value_usd"] * (1 + pos_data["tp_pct"] / 100) + await context.bot.send_message( + chat_id=chat_id, + text=f"\\#️⃣{cmd.position_num} *{escape_markdown_v2(pos_data['pair'])}* TP updated to \\+{escape_markdown_v2(f'{pos_data['tp_pct']:.1f}')}% \\(${escape_markdown_v2(f'{tp_value:.2f}')}\\)", + parse_mode="MarkdownV2" + ) + else: + # Update all positions + for pos_data in tracked.values(): + if value_type == "usd": + entry = pos_data["entry_value_usd"] + if entry > 0: + pos_data["tp_pct"] = ((value - entry) / entry) * 100 + else: + pos_data["tp_pct"] = value + + # Also update default + if value_type == "pct": + state["global_defaults"]["tp_pct"] = value + + await context.bot.send_message( + chat_id=chat_id, + text=f"All positions TP updated to \\+{escape_markdown_v2(f'{value:.1f}')}{'%' if value_type == 'pct' else ''}", + parse_mode="MarkdownV2" + ) + + +async def _handle_set_sl(context, chat_id: int, state: dict, cmd: ParsedCommand, instance_id: str): + """Handle setting stop loss.""" + tracked = state.get("tracked_positions", {}) + + if not tracked: + msg = await context.bot.send_message( + chat_id=chat_id, + text="No positions tracked yet\\. Add one first\\!", + parse_mode="MarkdownV2" + ) + asyncio.create_task(_delete_after(msg, 5)) + return + + value = cmd.value + value_type = cmd.value_type + + if cmd.position_num: + # Update specific position + pos_list = list(tracked.values()) + if cmd.position_num < 1 or cmd.position_num > len(pos_list): + msg = await context.bot.send_message( + chat_id=chat_id, + text=f"Invalid position number\\. Choose 1\\-{len(pos_list)}\\.", + parse_mode="MarkdownV2" + ) + asyncio.create_task(_delete_after(msg, 5)) + return + + pos_data = pos_list[cmd.position_num - 1] + if value_type == "usd": + # Convert USD to percentage (SL is a loss, so negative) + entry = pos_data["entry_value_usd"] + if entry > 0: + pos_data["sl_pct"] = ((entry - value) / entry) * 100 + else: + pos_data["sl_pct"] = value + + sl_value = pos_data["entry_value_usd"] * (1 - pos_data["sl_pct"] / 100) + await context.bot.send_message( + chat_id=chat_id, + text=f"\\#️⃣{cmd.position_num} *{escape_markdown_v2(pos_data['pair'])}* SL updated to \\-{escape_markdown_v2(f'{pos_data['sl_pct']:.1f}')}% \\(${escape_markdown_v2(f'{sl_value:.2f}')}\\)", + parse_mode="MarkdownV2" + ) + else: + # Update all positions + for pos_data in tracked.values(): + if value_type == "usd": + entry = pos_data["entry_value_usd"] + if entry > 0: + pos_data["sl_pct"] = ((entry - value) / entry) * 100 + else: + pos_data["sl_pct"] = value + + # Also update default + if value_type == "pct": + state["global_defaults"]["sl_pct"] = value + + await context.bot.send_message( + chat_id=chat_id, + text=f"All positions SL updated to \\-{escape_markdown_v2(f'{value:.1f}')}{'%' if value_type == 'pct' else ''}", + parse_mode="MarkdownV2" + ) + + +async def _handle_remove_position(context, chat_id: int, state: dict, pos_num: int, instance_id: str): + """Handle removing a position from tracking.""" + tracked = state.get("tracked_positions", {}) + + if not tracked: + msg = await context.bot.send_message( + chat_id=chat_id, + text="No positions tracked\\.", + parse_mode="MarkdownV2" + ) + asyncio.create_task(_delete_after(msg, 3)) + return + + pos_list = list(tracked.items()) + if pos_num < 1 or pos_num > len(pos_list): + msg = await context.bot.send_message( + chat_id=chat_id, + text=f"Invalid position number\\. Choose 1\\-{len(pos_list)}\\.", + parse_mode="MarkdownV2" + ) + asyncio.create_task(_delete_after(msg, 5)) + return + + pos_id, pos_data = pos_list[pos_num - 1] + pair = pos_data.get("pair", "Unknown") + + del state["tracked_positions"][pos_id] + + await context.bot.send_message( + chat_id=chat_id, + text=f"Removed *{escape_markdown_v2(pair)}* from tracking\\.", + parse_mode="MarkdownV2" + ) + + +async def _show_available_positions(context, chat_id: int, state: dict, instance_id: str): + """Show available positions for adding.""" + available = state.get("available_positions", []) + tracked = state.get("tracked_positions", {}) + token_cache = state.get("token_cache", {}) + token_prices = state.get("token_prices", {}) + + if not available: + await context.bot.send_message( + chat_id=chat_id, + text="No LP positions available\\.", + parse_mode="MarkdownV2" + ) + return + + lines = ["📋 *Available Positions*\n"] + + for i, pos in enumerate(available, 1): + pos_id = pos.get('id') or pos.get('position_id') or pos.get('address', '') + is_tracked = pos_id in tracked + + base_token = pos.get('base_token', pos.get('token_a', '')) + quote_token = pos.get('quote_token', pos.get('token_b', '')) + base_symbol = resolve_token_symbol(base_token, token_cache) + quote_symbol = resolve_token_symbol(quote_token, token_cache) + pair = f"{base_symbol}-{quote_symbol}" + + connector = pos.get('connector', 'unknown')[:3] + in_range = pos.get('in_range', '') + status_emoji = "🟢" if in_range == "IN_RANGE" else "🔴" if in_range == "OUT_OF_RANGE" else "⚪" + + value = _calculate_position_value_usd(pos, token_cache, token_prices) + + tracked_mark = " ✓" if is_tracked else "" + line = f"{i}\\. {escape_markdown_v2(pair)} \\({escape_markdown_v2(connector)}\\) {status_emoji} ${escape_markdown_v2(f'{value:.2f}')}{tracked_mark}" + lines.append(line) + + lines.append("") + lines.append("_Send position number to add \\(e\\.g\\. '1'\\)_") + + await context.bot.send_message( + chat_id=chat_id, + text="\n".join(lines), + parse_mode="MarkdownV2" + ) + + +async def _show_status(context, chat_id: int, state: dict, instance_id: str): + """Show current status of tracked positions.""" + tracked = state.get("tracked_positions", {}) + + if not tracked: + await context.bot.send_message( + chat_id=chat_id, + text="No positions being tracked\\. Send a position number to add one\\.", + parse_mode="MarkdownV2" + ) + return + + lines = [f"🎯 *Tracking {len(tracked)} position{'s' if len(tracked) > 1 else ''}*\n"] + + for i, (pos_id, pos_data) in enumerate(tracked.items(), 1): + pair = pos_data.get("pair", "Unknown") + connector = pos_data.get("connector", "?")[:3] + entry = pos_data.get("entry_value_usd", 0) + current = pos_data.get("current_value_usd", entry) + tp_pct = pos_data.get("tp_pct", 0) + sl_pct = pos_data.get("sl_pct", 0) + triggered = pos_data.get("triggered") + + change_pct = ((current - entry) / entry * 100) if entry > 0 else 0 + change_str = f"+{change_pct:.1f}%" if change_pct >= 0 else f"{change_pct:.1f}%" + + tp_value = entry * (1 + tp_pct / 100) + sl_value = entry * (1 - sl_pct / 100) + + status = "" + if triggered == "TP": + status = " 🎯 TP HIT" + elif triggered == "SL": + status = " 🚨 SL HIT" + + lines.append(f"*{i}\\. {escape_markdown_v2(pair)}* \\({escape_markdown_v2(connector)}\\){status}") + lines.append(f" Entry: ${escape_markdown_v2(f'{entry:.2f}')} → ${escape_markdown_v2(f'{current:.2f}')} \\({escape_markdown_v2(change_str)}\\)") + lines.append(f" 📈 TP: \\+{escape_markdown_v2(f'{tp_pct:.0f}')}% \\(${escape_markdown_v2(f'{tp_value:.2f}')}\\)") + lines.append(f" 📉 SL: \\-{escape_markdown_v2(f'{sl_pct:.0f}')}% \\(${escape_markdown_v2(f'{sl_value:.2f}')}\\)") + lines.append("") + + # Show commands reminder + lines.append("_Commands: `tp=15%`, `1 sl=$50`, `remove 1`, `add`_") + + await context.bot.send_message( + chat_id=chat_id, + text="\n".join(lines), + parse_mode="MarkdownV2" + ) + + +async def _show_help(context, chat_id: int): + """Show help message.""" + await context.bot.send_message( + chat_id=chat_id, + text=( + "🎯 *LP TP/SL Commands*\n\n" + "`1`, `2`, `3` \\- Add position by number\n" + "`tp=15%` \\- Set TP for all positions\n" + "`sl=10%` \\- Set SL for all positions\n" + "`sl=$50` \\- Set SL to $50 for all\n" + "`1 tp=25%` \\- Set TP for position \\#1\n" + "`2 sl=$100` \\- Set SL for position \\#2\n" + "`remove 1` \\- Stop tracking position \\#1\n" + "`add` \\- Show available positions\n" + "`status` \\- Show current status\n" + ), + parse_mode="MarkdownV2" + ) + + +# ============================================================================= +# MAIN ROUTINE +# ============================================================================= + +class Config(BaseModel): + """Multi-position LP TP/SL Monitor with interactive control.""" + + default_tp_pct: float = Field(default=10.0, description="Default take profit % (can override per position)") + default_sl_pct: float = Field(default=10.0, description="Default stop loss % (can override per position)") + interval_sec: int = Field(default=30, description="Check interval in seconds") + + +async def run(config: Config, context: ContextTypes.DEFAULT_TYPE) -> str: + """ + Multi-position LP TP/SL Monitor with interactive control. + + - Start with no positions tracked + - Send position numbers to add positions + - Send 'tp=15%' or 'sl=$50' to modify TP/SL + - Alerts when TP or SL is hit + """ + logger.info(f"LP TP/SL starting: TP={config.default_tp_pct}%, SL={config.default_sl_pct}%") + + chat_id = context._chat_id if hasattr(context, '_chat_id') else None + instance_id = getattr(context, '_instance_id', 'default') + + if not chat_id: + return "No chat_id available" + + # Get user_data and initialize state + user_data = _get_user_data(context) + state = _get_state(context, instance_id) + state["global_defaults"]["tp_pct"] = config.default_tp_pct + state["global_defaults"]["sl_pct"] = config.default_sl_pct + + # Enable interactive mode - this routes messages to our handler + user_data["routines_state"] = "tpsl_interactive" + user_data["tpsl_active_instance"] = instance_id + + # Get client + client = await get_client(chat_id, context=context) + if not client: + return "No server available" + + if not hasattr(client, 'gateway_clmm'): + return "Gateway CLMM not available" + + # Fetch available positions + await _refresh_available_positions(context, chat_id, state, client) + + if not state.get("available_positions"): + user_data.pop("routines_state", None) + user_data.pop("tpsl_active_instance", None) + await context.bot.send_message( + chat_id=chat_id, + text="No active LP positions found\\.", + parse_mode="MarkdownV2" + ) + return "No active positions" + + # Show initial setup message + await _show_setup_message(context, chat_id, state, config) + + try: + # Main monitoring loop + while True: + state["checks"] += 1 + state["last_check"] = time.time() + + # Re-read state (may have been modified by message handler) + state = _get_state(context, instance_id) + + # Check positions if any are tracked + if state["tracked_positions"]: + try: + client = await get_client(chat_id, context=context) + if client and hasattr(client, 'gateway_clmm'): + await _check_positions(context, chat_id, state, client, instance_id) + except Exception as e: + logger.error(f"Error checking positions: {e}") + + # Log periodically + if state["checks"] % 20 == 0: + tracked_count = len(state.get("tracked_positions", {})) + logger.info(f"LP TP/SL check #{state['checks']}: tracking {tracked_count} positions") + + await asyncio.sleep(config.interval_sec) + + except asyncio.CancelledError: + # Cleanup interactive state + user_data.pop("routines_state", None) + user_data.pop("tpsl_active_instance", None) + + # Build stop summary + elapsed = int(time.time() - state.get("start_time", time.time())) + mins, secs = divmod(elapsed, 60) + tracked_count = len(state.get("tracked_positions", {})) + triggers = sum(1 for p in state.get("tracked_positions", {}).values() if p.get("triggered")) + + # Clean up state + user_data.pop(f"lp_tpsl_{instance_id}", None) + + try: + await context.bot.send_message( + chat_id=chat_id, + text=( + f"🔴 *LP TP/SL Monitor Stopped*\n\n" + f"Duration: {mins}m {secs}s\n" + f"Positions tracked: {tracked_count}\n" + f"Triggers: {triggers}\n" + f"Checks: {state.get('checks', 0)}" + ), + parse_mode="MarkdownV2" + ) + except Exception: + pass + + return f"Stopped after {mins}m {secs}s, {tracked_count} positions, {triggers} triggers" + + +async def _refresh_available_positions(context, chat_id: int, state: dict, client): + """Fetch available LP positions from Gateway.""" + try: + result = await client.gateway_clmm.search_positions( + limit=100, + offset=0, + status="OPEN", + refresh=True + ) + + if not result: + state["available_positions"] = [] + return + + positions = result.get("data", []) + + # Filter to active positions with liquidity + active_positions = [] + for pos in positions: + if pos.get('status') == 'CLOSED': + continue + liq = pos.get('liquidity') or pos.get('current_liquidity') + if liq is not None: + try: + if float(liq) <= 0: + continue + except (ValueError, TypeError): + pass + active_positions.append(pos) + + state["available_positions"] = active_positions + + # Build token cache + token_cache = dict(KNOWN_TOKENS) + networks = list(set(pos.get('network', 'solana-mainnet-beta') for pos in active_positions)) + if networks and hasattr(client, 'gateway'): + for network in networks: + try: + resp = await client.gateway.get_network_tokens(network) + tokens = resp.get('tokens', []) if resp else [] + for token in tokens: + addr = token.get('address', '') + symbol = token.get('symbol', '') + if addr and symbol: + token_cache[addr] = symbol + except Exception: + pass + + state["token_cache"] = token_cache + + # Fetch token prices + state["token_prices"] = await _fetch_token_prices(client) + + except Exception as e: + logger.error(f"Error fetching positions: {e}") + state["available_positions"] = [] + + +async def _show_setup_message(context, chat_id: int, state: dict, config: Config): + """Show initial setup message with available positions.""" + available = state.get("available_positions", []) + token_cache = state.get("token_cache", {}) + token_prices = state.get("token_prices", {}) + + lines = [ + "🎯 *LP TP/SL Monitor Started*\n", + f"Default: TP \\+{config.default_tp_pct:.0f}% \\| SL \\-{config.default_sl_pct:.0f}%\n", + "📋 *Available Positions:*" + ] + + for i, pos in enumerate(available, 1): + base_token = pos.get('base_token', pos.get('token_a', '')) + quote_token = pos.get('quote_token', pos.get('token_b', '')) + base_symbol = resolve_token_symbol(base_token, token_cache) + quote_symbol = resolve_token_symbol(quote_token, token_cache) + pair = f"{base_symbol}-{quote_symbol}" + + connector = pos.get('connector', 'unknown')[:3] + in_range = pos.get('in_range', '') + status_emoji = "🟢" if in_range == "IN_RANGE" else "🔴" if in_range == "OUT_OF_RANGE" else "⚪" + + value = _calculate_position_value_usd(pos, token_cache, token_prices) + + line = f"{i}\\. {escape_markdown_v2(pair)} \\({escape_markdown_v2(connector)}\\) {status_emoji} ${escape_markdown_v2(f'{value:.2f}')}" + lines.append(line) + + lines.append("") + lines.append("_Send position number to add \\(e\\.g\\. '1'\\)_") + lines.append("_Send `help` for all commands_") + + await context.bot.send_message( + chat_id=chat_id, + text="\n".join(lines), + parse_mode="MarkdownV2" + ) + + +async def _check_positions(context, chat_id: int, state: dict, client, instance_id: str): + """Check all tracked positions for TP/SL triggers.""" + tracked = state.get("tracked_positions", {}) + if not tracked: + return + + # Refresh token prices + token_prices = await _fetch_token_prices(client) + state["token_prices"] = token_prices + + # Fetch current positions + try: + result = await client.gateway_clmm.search_positions( + limit=100, + offset=0, + status="OPEN", + refresh=True + ) + current_positions = { + (pos.get('id') or pos.get('position_id') or pos.get('address', '')): pos + for pos in result.get("data", []) if result + } + except Exception as e: + logger.error(f"Error fetching positions for check: {e}") + return + + token_cache = state.get("token_cache", {}) + user_data = _get_user_data(context) + + for pos_id, pos_data in list(tracked.items()): + # Skip already triggered + if pos_data.get("triggered"): + continue + + # Check if position still exists + current_pos = current_positions.get(pos_id) + if not current_pos: + # Position was closed externally + await context.bot.send_message( + chat_id=chat_id, + text=f"Position *{escape_markdown_v2(pos_data.get('pair', 'Unknown'))}* was closed externally\\.", + parse_mode="MarkdownV2" + ) + del tracked[pos_id] + continue + + # Calculate current value + current_value = _calculate_position_value_usd(current_pos, token_cache, token_prices) + pos_data["current_value_usd"] = current_value + pos_data["high_value_usd"] = max(pos_data.get("high_value_usd", current_value), current_value) + pos_data["low_value_usd"] = min(pos_data.get("low_value_usd", current_value), current_value) + + entry_value = pos_data.get("entry_value_usd", 0) + if entry_value <= 0: + continue + + change_pct = ((current_value - entry_value) / entry_value) * 100 + tp_pct = pos_data.get("tp_pct", 0) + sl_pct = pos_data.get("sl_pct", 0) + + # Check TP + if tp_pct > 0 and change_pct >= tp_pct: + pos_data["triggered"] = "TP" + await _send_trigger_alert(context, chat_id, pos_data, "TP", instance_id, user_data, current_pos) + + # Check SL + elif sl_pct > 0 and change_pct <= -sl_pct: + pos_data["triggered"] = "SL" + await _send_trigger_alert(context, chat_id, pos_data, "SL", instance_id, user_data, current_pos) + + +async def _send_trigger_alert(context, chat_id: int, pos_data: dict, trigger_type: str, instance_id: str, user_data: dict, current_pos: dict): + """Send TP/SL trigger alert with action buttons.""" + pair = pos_data.get("pair", "Unknown") + entry = pos_data.get("entry_value_usd", 0) + current = pos_data.get("current_value_usd", 0) + change_pct = ((current - entry) / entry * 100) if entry > 0 else 0 + target_pct = pos_data.get("tp_pct" if trigger_type == "TP" else "sl_pct", 0) + + # Store position for close button + pos_id = pos_data.get("position_id", "") + cache_key = f"tpsl_{instance_id}_{pos_id[:8]}" + if "positions_cache" not in user_data: + user_data["positions_cache"] = {} + user_data["positions_cache"][cache_key] = current_pos + + if trigger_type == "TP": + header = "🎯 *TAKE PROFIT HIT\\!*" + emoji = "🚀" + target_str = f"\\+{target_pct:.0f}%" + else: + header = "🚨 *STOP LOSS HIT\\!*" + emoji = "⚠️" + target_str = f"\\-{target_pct:.0f}%" + + change_str = f"+{change_pct:.1f}%" if change_pct >= 0 else f"{change_pct:.1f}%" + + keyboard = [[ + InlineKeyboardButton("✅ Close Position", callback_data=f"dex:pos_close:{cache_key}"), + InlineKeyboardButton("🔄 Continue", callback_data=f"routines:lp_tpsl:continue:{instance_id}:{pos_id[:8]}"), + ], [ + InlineKeyboardButton("❌ Remove from Monitor", callback_data=f"routines:lp_tpsl:remove:{instance_id}:{pos_id[:8]}"), + ]] + + await context.bot.send_message( + chat_id=chat_id, + text=( + f"{header}\n\n" + f"*{escape_markdown_v2(pair)}*\n" + f"Entry: ${escape_markdown_v2(f'{entry:.2f}')}\n" + f"Current: ${escape_markdown_v2(f'{current:.2f}')}\n" + f"Change: {escape_markdown_v2(change_str)}\n\n" + f"{emoji} Target {target_str} reached\\!" + ), + parse_mode="MarkdownV2", + reply_markup=InlineKeyboardMarkup(keyboard) + ) diff --git a/routines/price_monitor.py b/routines/price_monitor.py index 4974bfd..95da3d7 100644 --- a/routines/price_monitor.py +++ b/routines/price_monitor.py @@ -1,15 +1,19 @@ """Monitor price and alert on threshold.""" +import asyncio import logging import time from pydantic import BaseModel, Field from telegram.ext import ContextTypes -from servers import get_client +from config_manager import get_client from utils.telegram_formatters import escape_markdown_v2 logger = logging.getLogger(__name__) +# Mark as continuous routine - has internal loop +CONTINUOUS = True + class Config(BaseModel): """Live price monitor with configurable alerts.""" @@ -17,111 +21,123 @@ class Config(BaseModel): connector: str = Field(default="binance", description="CEX connector name") trading_pair: str = Field(default="BTC-USDT", description="Trading pair to monitor") threshold_pct: float = Field(default=1.0, description="Alert threshold in %") - interval_sec: int = Field(default=10, description="Refresh interval in seconds") + interval_sec: int = Field(default=10, description="Check interval in seconds") async def run(config: Config, context: ContextTypes.DEFAULT_TYPE) -> str: """ - Monitor price - single iteration. + Monitor price continuously. - Runs silently in background. Sends alert messages when threshold is crossed. - Returns status string for the routine handler to display. + This is a continuous routine - runs forever until cancelled. + Sends alert messages when threshold is crossed. """ chat_id = context._chat_id if hasattr(context, '_chat_id') else None - client = await get_client(chat_id) + instance_id = getattr(context, '_instance_id', 'default') + client = await get_client(chat_id, context=context) if not client: return "No server available" - # Get user_data and instance_id - user_data = getattr(context, '_user_data', None) or getattr(context, 'user_data', {}) - instance_id = getattr(context, '_instance_id', 'default') - - # State key for this routine instance - state_key = f"price_monitor_state_{chat_id}_{instance_id}" - - # Get or initialize state - state = user_data.get(state_key, {}) - - # Get current price + # State for tracking + state = { + "initial_price": None, + "last_price": None, + "high_price": None, + "low_price": None, + "alerts_sent": 0, + "updates": 0, + "start_time": time.time(), + } + + # Send start notification try: - prices = await client.market_data.get_prices( - connector_name=config.connector, - trading_pairs=config.trading_pair + pair_esc = escape_markdown_v2(config.trading_pair) + await context.bot.send_message( + chat_id=chat_id, + text=f"🟢 *Price Monitor Started*\n{pair_esc} @ {escape_markdown_v2(config.connector)}", + parse_mode="MarkdownV2" ) - current_price = prices["prices"].get(config.trading_pair) - if not current_price: - return f"No price for {config.trading_pair}" except Exception as e: - return f"Error: {e}" - - # Initialize state on first run - if not state: - state = { - "initial_price": current_price, - "last_price": current_price, - "high_price": current_price, - "low_price": current_price, - "alerts_sent": 0, - "updates": 0, - "start_time": time.time(), - } - user_data[state_key] = state - - # Send start notification - try: - pair_esc = escape_markdown_v2(config.trading_pair) - price_esc = escape_markdown_v2(f"${current_price:,.2f}") - await context.bot.send_message( - chat_id=chat_id, - text=f"🟢 *Price Monitor Started*\n{pair_esc}: `{price_esc}`", - parse_mode="MarkdownV2" - ) - except Exception: - pass - - # Update tracking - state["high_price"] = max(state["high_price"], current_price) - state["low_price"] = min(state["low_price"], current_price) - - # Calculate changes - change_from_last = ((current_price - state["last_price"]) / state["last_price"]) * 100 - change_from_start = ((current_price - state["initial_price"]) / state["initial_price"]) * 100 + logger.error(f"Failed to send start message: {e}") - # Check threshold for alert - if abs(change_from_last) >= config.threshold_pct: - direction = "📈" if change_from_last > 0 else "📉" - pair_esc = escape_markdown_v2(config.trading_pair) - price_esc = escape_markdown_v2(f"${current_price:,.2f}") - change_esc = escape_markdown_v2(f"{change_from_last:+.2f}%") + try: + # Main monitoring loop + while True: + try: + # Get current price + prices = await client.market_data.get_prices( + connector_name=config.connector, + trading_pairs=config.trading_pair + ) + current_price = prices["prices"].get(config.trading_pair) + + if not current_price: + await asyncio.sleep(config.interval_sec) + continue + + # Initialize on first price + if state["initial_price"] is None: + state["initial_price"] = current_price + state["last_price"] = current_price + state["high_price"] = current_price + state["low_price"] = current_price + + # Update tracking + state["high_price"] = max(state["high_price"], current_price) + state["low_price"] = min(state["low_price"], current_price) + state["updates"] += 1 + + # Calculate changes + change_from_last = ((current_price - state["last_price"]) / state["last_price"]) * 100 + + # Check threshold for alert + if abs(change_from_last) >= config.threshold_pct: + direction = "📈" if change_from_last > 0 else "📉" + pair_esc = escape_markdown_v2(config.trading_pair) + price_esc = escape_markdown_v2(f"${current_price:,.2f}") + change_esc = escape_markdown_v2(f"{change_from_last:+.2f}%") + + try: + await context.bot.send_message( + chat_id=chat_id, + text=( + f"{direction} *{pair_esc} Alert*\n" + f"Price: `{price_esc}`\n" + f"Change: `{change_esc}`" + ), + parse_mode="MarkdownV2" + ) + state["alerts_sent"] += 1 + except Exception: + pass + + # Update last price + state["last_price"] = current_price + + except asyncio.CancelledError: + raise # Re-raise to exit the loop + except Exception as e: + logger.error(f"Price monitor error: {e}") + + # Wait for next check + await asyncio.sleep(config.interval_sec) + + except asyncio.CancelledError: + # Send stop notification + elapsed = int(time.time() - state["start_time"]) + mins, secs = divmod(elapsed, 60) try: await context.bot.send_message( chat_id=chat_id, text=( - f"{direction} *{pair_esc} Alert*\n" - f"Price: `{price_esc}`\n" - f"Change: `{change_esc}`" + f"🔴 *Price Monitor Stopped*\n" + f"{escape_markdown_v2(config.trading_pair)}\n" + f"Duration: {mins}m {secs}s \\| Updates: {state['updates']} \\| Alerts: {state['alerts_sent']}" ), parse_mode="MarkdownV2" ) - state["alerts_sent"] += 1 except Exception: pass - # Update state - state["last_price"] = current_price - state["updates"] += 1 - user_data[state_key] = state - - # Build status string for handler display - elapsed = int(time.time() - state["start_time"]) - mins, secs = divmod(elapsed, 60) - - trend = "📈" if change_from_start > 0.01 else "📉" if change_from_start < -0.01 else "➡️" - - return ( - f"{trend} ${current_price:,.2f} ({change_from_start:+.2f}%)\n" - f"High: ${state['high_price']:,.2f} | Low: ${state['low_price']:,.2f}\n" - f"Updates: {state['updates']} | Alerts: {state['alerts_sent']} | {mins}m {secs}s" - ) + return f"Stopped after {mins}m {secs}s, {state['updates']} updates, {state['alerts_sent']} alerts" diff --git a/servers.py b/servers.py deleted file mode 100644 index 6dc854d..0000000 --- a/servers.py +++ /dev/null @@ -1,426 +0,0 @@ -""" -Simple API Server Manager with YAML configuration -Manages multiple Hummingbot API servers from servers.yml -""" - -import asyncio -import logging -import os -from pathlib import Path -from typing import Dict, Optional - -import yaml -from aiohttp import ClientTimeout -from hummingbot_api_client import HummingbotAPIClient - -logger = logging.getLogger(__name__) - - -class ServerManager: - """Manages multiple API servers from servers.yml configuration""" - - def __init__(self, config_path: str = "servers.yml"): - self.config_path = Path(config_path) - self.servers: Dict[str, dict] = {} - self.clients: Dict[str, HummingbotAPIClient] = {} - self.default_server: Optional[str] = None - self.per_chat_servers: Dict[int, str] = {} # chat_id -> server_name - self._load_config() - - def _load_config(self): - """Load servers configuration from YAML file""" - if not self.config_path.exists(): - logger.warning(f"Config file not found: {self.config_path}") - self.servers = {} - self.default_server = None - return - - try: - with open(self.config_path, 'r') as f: - config = yaml.safe_load(f) - self.servers = config.get('servers', {}) - self.default_server = config.get('default_server', None) - - # Load per-chat server defaults - per_chat_raw = config.get('per_chat_defaults', {}) - self.per_chat_servers = { - int(chat_id): server_name - for chat_id, server_name in per_chat_raw.items() - if server_name in self.servers - } - - # Validate default server exists - if self.default_server and self.default_server not in self.servers: - logger.warning(f"Default server '{self.default_server}' not found in servers list") - self.default_server = None - - logger.info(f"Loaded {len(self.servers)} servers from {self.config_path}") - if self.default_server: - logger.info(f"Default server: {self.default_server}") - if self.per_chat_servers: - logger.info(f"Loaded {len(self.per_chat_servers)} per-chat server defaults") - except Exception as e: - logger.error(f"Failed to load config: {e}") - self.servers = {} - self.default_server = None - self.per_chat_servers = {} - - def _save_config(self): - """Save servers configuration to YAML file""" - try: - config = {'servers': self.servers} - if self.default_server: - config['default_server'] = self.default_server - if self.per_chat_servers: - config['per_chat_defaults'] = self.per_chat_servers - with open(self.config_path, 'w') as f: - yaml.dump(config, f, default_flow_style=False, sort_keys=False) - logger.info(f"Saved configuration to {self.config_path}") - except Exception as e: - logger.error(f"Failed to save config: {e}") - raise - - def add_server(self, name: str, host: str, port: int, username: str, - password: str) -> bool: - """Add a new server to configuration""" - if name in self.servers: - logger.error(f"Server '{name}' already exists") - return False - - self.servers[name] = { - 'host': host, - 'port': port, - 'username': username, - 'password': password - } - self._save_config() - logger.info(f"Added server '{name}'") - return True - - def modify_server(self, name: str, host: Optional[str] = None, - port: Optional[int] = None, username: Optional[str] = None, - password: Optional[str] = None) -> bool: - """Modify an existing server configuration""" - if name not in self.servers: - logger.error(f"Server '{name}' not found") - return False - - # Close existing client if configuration is changing - if name in self.clients: - asyncio.create_task(self._close_client(name)) - - # Update only provided fields - if host is not None: - self.servers[name]['host'] = host - if port is not None: - self.servers[name]['port'] = port - if username is not None: - self.servers[name]['username'] = username - if password is not None: - self.servers[name]['password'] = password - - self._save_config() - logger.info(f"Modified server '{name}'") - return True - - def delete_server(self, name: str) -> bool: - """Delete a server from configuration and runtime""" - if name not in self.servers: - logger.error(f"Server '{name}' not found") - return False - - # Close and remove client if exists - if name in self.clients: - asyncio.create_task(self._close_client(name)) - - del self.servers[name] - self._save_config() - logger.info(f"Deleted server '{name}'") - return True - - def list_servers(self) -> Dict[str, dict]: - """List all configured servers""" - return self.servers.copy() - - def get_server(self, name: str) -> Optional[dict]: - """Get a specific server configuration""" - return self.servers.get(name) - - def set_default_server(self, name: str) -> bool: - """Set the default server""" - if name not in self.servers: - logger.error(f"Server '{name}' not found") - return False - - self.default_server = name - self._save_config() - logger.info(f"Set default server to '{name}'") - return True - - def get_default_server(self) -> Optional[str]: - """Get the default server name""" - return self.default_server - - def get_default_server_for_chat(self, chat_id: int) -> Optional[str]: - """Get the default server for a specific chat, falling back to global default""" - server = self.per_chat_servers.get(chat_id) - if server and server in self.servers: - return server - # Fallback to global default server - if self.default_server and self.default_server in self.servers: - return self.default_server - # Last resort: first available server - if self.servers: - return list(self.servers.keys())[0] - return None - - def set_default_server_for_chat(self, chat_id: int, server_name: str) -> bool: - """Set the default server for a specific chat""" - if server_name not in self.servers: - logger.error(f"Server '{server_name}' not found") - return False - - self.per_chat_servers[chat_id] = server_name - self._save_config() - logger.info(f"Set default server for chat {chat_id} to '{server_name}'") - return True - - def clear_default_server_for_chat(self, chat_id: int) -> bool: - """Clear the per-chat default server, reverting to global default""" - if chat_id in self.per_chat_servers: - del self.per_chat_servers[chat_id] - self._save_config() - logger.info(f"Cleared default server for chat {chat_id}") - return True - return False - - def get_chat_server_info(self, chat_id: int) -> dict: - """Get server info for a chat including whether it's using per-chat or global default""" - per_chat = self.per_chat_servers.get(chat_id) - if per_chat and per_chat in self.servers: - return { - "server": per_chat, - "is_per_chat": True, - "global_default": self.default_server - } - return { - "server": self.default_server, - "is_per_chat": False, - "global_default": self.default_server - } - - async def check_server_status(self, name: str) -> dict: - """ - Check if a server is online and responding using protected endpoint - Returns detailed status including authentication errors - """ - if name not in self.servers: - return {"status": "error", "message": "Server not found"} - - server = self.servers[name] - base_url = f"http://{server['host']}:{server['port']}" - - logger.debug(f"Checking status for '{name}' at {base_url} with username '{server['username']}'") - - # Create a temporary client for testing (don't cache it) - # Important: Do NOT use cached clients to ensure we test current credentials - # Use 3 second timeout for quick status checks - client = HummingbotAPIClient( - base_url=base_url, - username=server['username'], - password=server['password'], - timeout=ClientTimeout(total=3, connect=2) # Quick timeout for status check - ) - - try: - await client.init() - # Use protected endpoint to verify both connectivity and authentication - # This will raise 401 error if credentials are wrong - await client.accounts.list_accounts() - logger.info(f"Status check succeeded for '{name}' - server is online") - return {"status": "online", "message": "Connected and authenticated"} - except Exception as e: - error_msg = str(e) - logger.warning(f"Status check failed for '{name}': {error_msg}") - - # Categorize the error with clearer messages - if "401" in error_msg or "Incorrect username or password" in error_msg: - return {"status": "auth_error", "message": "Invalid credentials"} - elif "timeout" in error_msg.lower() or "TimeoutError" in error_msg: - return {"status": "offline", "message": "Connection timeout - server unreachable"} - elif "Connection" in error_msg or "Cannot connect" in error_msg or "ConnectionRefused" in error_msg: - return {"status": "offline", "message": "Cannot reach server"} - elif "ClientConnectorError" in error_msg or "getaddrinfo" in error_msg: - return {"status": "offline", "message": "Server unreachable or invalid host"} - else: - # Show first 80 chars of error for debugging - return {"status": "error", "message": f"Error: {error_msg[:80]}"} - finally: - # Always close the client - try: - await client.close() - except: - pass - - async def get_default_client(self) -> HummingbotAPIClient: - """Get the API client for the default server""" - if not self.default_server: - # If no default server, try to use the first available server - if not self.servers: - raise ValueError("No servers configured") - self.default_server = list(self.servers.keys())[0] - logger.info(f"No default server set, using '{self.default_server}'") - - return await self.get_client(self.default_server) - - async def get_client_for_chat(self, chat_id: int) -> HummingbotAPIClient: - """Get the API client for a specific chat's default server""" - server_name = self.get_default_server_for_chat(chat_id) - if not server_name: - # Fallback to first available server - if not self.servers: - raise ValueError("No servers configured") - server_name = list(self.servers.keys())[0] - logger.info(f"No default server for chat {chat_id}, using '{server_name}'") - - return await self.get_client(server_name) - - async def get_client(self, name: Optional[str] = None) -> HummingbotAPIClient: - """Get or create API client for a server. If name is None, uses default server.""" - if name is None: - return await self.get_default_client() - - if name not in self.servers: - raise ValueError(f"Server '{name}' not found") - - server = self.servers[name] - - # Return existing client if available - if name in self.clients: - return self.clients[name] - - # Create new client with longer timeout to handle slow operations - # (credential verification can take time as it connects to external exchanges) - base_url = f"http://{server['host']}:{server['port']}" - client = HummingbotAPIClient( - base_url=base_url, - username=server['username'], - password=server['password'], - timeout=ClientTimeout(total=60, connect=10) - ) - - try: - await client.init() - # Test connection - await client.accounts.list_accounts() - self.clients[name] = client - logger.info(f"Connected to server '{name}' at {base_url}") - return client - except Exception as e: - await client.close() - logger.error(f"Failed to connect to '{name}': {e}") - raise - - async def _close_client(self, name: str): - """Close a specific client connection""" - if name in self.clients: - try: - await self.clients[name].close() - logger.info(f"Closed connection to '{name}'") - except Exception as e: - logger.error(f"Error closing client '{name}': {e}") - finally: - del self.clients[name] - - async def close_all(self): - """Close all client connections""" - for name in list(self.clients.keys()): - await self._close_client(name) - - async def initialize_all(self): - """Initialize all servers""" - for name in self.servers.keys(): - try: - await self.get_client(name) - except Exception as e: - logger.warning(f"Failed to initialize '{name}': {e}") - - async def reload_config(self): - """Reload configuration from file and clear cached clients""" - # Close all existing clients since config may have changed - await self.close_all() - # Reload the configuration - self._load_config() - logger.info("Configuration reloaded from file") - - -# Global server manager instance -server_manager = ServerManager() - - -async def get_client(chat_id: int = None): - """Get the API client for the appropriate server. - - Args: - chat_id: Optional chat ID to get per-chat server. If None, uses 'local' as fallback. - - Returns: - HummingbotAPIClient instance - - Raises: - ValueError: If no servers are configured - """ - if chat_id is not None: - return await server_manager.get_client_for_chat(chat_id) - return await server_manager.get_default_client() - - -# Example usage -async def main(): - """Example usage of ServerManager""" - - # List all servers - print("\nConfigured servers:") - for name, config in server_manager.list_servers().items(): - print(f" {name}: {config['host']}:{config['port']}") - - # Add a new server - print("\nAdding new server...") - server_manager.add_server( - name="test", - host="localhost", - port=8081, - username="test_user", - password="test_pass" - ) - - # Modify a server - print("\nModifying server...") - server_manager.modify_server( - name="local", - port=8080 - ) - - # Initialize all servers - print("\nInitializing servers...") - await server_manager.initialize_all() - - # Get a specific client - try: - client = await server_manager.get_client("local") - accounts = await client.accounts.list_accounts() - print(f"\nConnected to 'local' server, accounts: {len(accounts)}") - except Exception as e: - print(f"\nFailed to connect: {e}") - - # Delete a server - print("\nDeleting test server...") - server_manager.delete_server("test") - - # Clean up - await server_manager.close_all() - - -if __name__ == "__main__": - logging.basicConfig(level=logging.INFO) - asyncio.run(main()) diff --git a/servers.yml b/servers.yml deleted file mode 100644 index 0b21593..0000000 --- a/servers.yml +++ /dev/null @@ -1,13 +0,0 @@ -servers: - remote: - host: 212.85.34.32 - port: 8000 - username: admin - password: admin - local: - host: localhost - port: 8000 - username: admin - password: admin -default_server: local - diff --git a/setup-environment.sh b/setup-environment.sh old mode 100644 new mode 100755 index 9772b10..5d2a885 --- a/setup-environment.sh +++ b/setup-environment.sh @@ -1,55 +1,71 @@ #!/bin/bash +# Configuration +ENV_FILE=".env" +DATA_DIR="data" + echo "===================================" -echo " Condor Bot Setup" +echo " Condor Bot Setup" echo "===================================" -echo "" -# Prompt for Telegram Bot Token -read -p "Enter your Telegram Bot Token: " telegram_token +# 1. Check if .env already exists +if [ -f "$ENV_FILE" ]; then + echo "" + echo ">> Found existing $ENV_FILE file." + echo ">> Credentials already exist. Skipping setup params." + echo "" +else + # 2. Prompt for Telegram Bot Token + echo "" + read -p "Enter your Telegram Bot Token: " telegram_token -# Prompt for Authorized User IDs -echo "" -echo "Enter the User IDs that are allowed to talk with the bot." -echo "Separate multiple User IDs with a comma (e.g., 12345,67890,23456)." -echo "(Tip: Run /start in the bot to see your User ID)" -read -p "User IDs: " user_ids + # 3. Prompt for Admin User ID + echo "" + echo "Enter your Telegram User ID (you will be the admin)." + echo "(Tip: Message @userinfobot on Telegram to get your ID)" + read -p "Admin User ID: " admin_id -# Prompt for Pydantic Gateway Key (optional) -echo "" -echo "Enter your Pydantic Gateway Key (optional, for AI features)." -echo "Press Enter to skip if not using AI features." -read -p "Pydantic Gateway Key: " pydantic_key - -# Remove spaces from user IDs -user_ids=$(echo $user_ids | tr -d '[:space:]') - -# Create or update .env file -echo "TELEGRAM_TOKEN=$telegram_token" > .env -echo "AUTHORIZED_USERS=$user_ids" >> .env -if [ -n "$pydantic_key" ]; then - echo "PYDANTIC_GATEWAY_KEY=$pydantic_key" >> .env -fi + # 4. Prompt for OpenAI API Key (optional) + echo "" + echo "Enter your OpenAI API Key (optional, for AI features)." + echo "Press Enter to skip if not using AI features." + read -p "OpenAI API Key: " openai_key -echo "" -echo ".env file created successfully!" + # Clean whitespaces from inputs + telegram_token=$(echo "$telegram_token" | tr -d '[:space:]') + admin_id=$(echo "$admin_id" | tr -d '[:space:]') + openai_key=$(echo "$openai_key" | tr -d '[:space:]') -echo "" -echo "Installing Chrome for Plotly image generation..." -plotly_get_chrome || kaleido_get_chrome || python -c "import kaleido; kaleido.get_chrome_sync()" 2>/dev/null || echo "Chrome installation skipped (not required for basic usage)" -echo "" -echo "Ensuring data directory exists for persistence..." -mkdir -p data + # 5. Create .env file + { + echo "TELEGRAM_TOKEN=$telegram_token" + echo "ADMIN_USER_ID=$admin_id" + if [ -n "$openai_key" ]; then + echo "OPENAI_API_KEY=$openai_key" + fi + } > "$ENV_FILE" + + echo "" + echo "✅ $ENV_FILE file created successfully!" +fi +# 6. Ensure data directory exists +if [ ! -d "$DATA_DIR" ]; then + echo "Ensuring $DATA_DIR directory exists for persistence..." + mkdir -p "$DATA_DIR" +fi + +# 7. Display Run Instructions echo "===================================" -echo " How to Run Condor" +echo " How to Run Condor" echo "===================================" echo "" echo "Option 1: Docker (Recommended)" -echo " docker compose up -d" +echo " make deploy" echo "" echo "Option 2: Local Python" -echo " make install" -echo " conda activate condor" -echo " python main.py" +echo " make run" echo "" +echo "On first run, config.yml will be auto-created." +echo "Use /config in the bot to add servers and manage access." +echo "===================================" diff --git a/utils/auth.py b/utils/auth.py index dce365b..9a3c11c 100644 --- a/utils/auth.py +++ b/utils/auth.py @@ -4,26 +4,171 @@ from telegram import Update, InlineKeyboardButton, InlineKeyboardMarkup from telegram.ext import ContextTypes -from utils.config import AUTHORIZED_USERS +from config_manager import ( + UserRole, + ServerPermission, + get_config_manager, +) logger = logging.getLogger(__name__) +async def _notify_admin_new_user(context: ContextTypes.DEFAULT_TYPE, user_id: int, username: str) -> None: + """Send notification to admin about new user request.""" + cm = get_config_manager() + admin_id = cm.admin_id + + if not admin_id: + return + + try: + message = ( + f"👤 *New Access Request*\n\n" + f"User ID: `{user_id}`\n" + f"Username: @{username or 'N/A'}\n\n" + f"Use /start \\> Admin Panel to approve or reject\\." + ) + await context.bot.send_message( + chat_id=admin_id, + text=message, + parse_mode="MarkdownV2" + ) + except Exception as e: + logger.warning(f"Failed to notify admin about new user: {e}") + + def restricted(func): + """ + Decorator that checks if user is approved. + New users are auto-registered as pending and admin is notified. + """ @wraps(func) async def wrapped( update: Update, context: ContextTypes.DEFAULT_TYPE, *args, **kwargs ): user_id = update.effective_user.id - if user_id not in AUTHORIZED_USERS: - print(f"Unauthorized access denied for {user_id}.") - await update.message.reply_text("You are not authorized to use this bot.") + username = update.effective_user.username + + cm = get_config_manager() + role = cm.get_user_role(user_id) + + # Handle blocked users + if role == UserRole.BLOCKED: + logger.info(f"Blocked user {user_id} attempted access") + if update.message: + await update.message.reply_text("🚫 Access denied.") + elif update.callback_query: + await update.callback_query.answer("Access denied", show_alert=True) + return + + # Handle pending users + if role == UserRole.PENDING: + if update.message: + await update.message.reply_text( + "⏳ Your access request is pending admin approval.\n" + "You will be notified when approved." + ) + elif update.callback_query: + await update.callback_query.answer("Access pending approval", show_alert=True) + return + + # Handle new users - register as pending + if role is None: + is_new = cm.register_pending(user_id, username) + if is_new: + # Notify admin + await _notify_admin_new_user(context, user_id, username) + + if update.message: + await update.message.reply_text( + "🔒 *Access Request Submitted*\n\n" + f"Your User ID: `{user_id}`\n\n" + "An admin will review your request\\. " + "You will be notified when approved\\.", + parse_mode="MarkdownV2" + ) + elif update.callback_query: + await update.callback_query.answer("Access request submitted", show_alert=True) return + + # User is approved (USER or ADMIN role) + # Store user_id in context for access control in subsequent calls + context.user_data['_user_id'] = user_id return await func(update, context, *args, **kwargs) return wrapped +def admin_required(func): + """Decorator that requires admin role.""" + @wraps(func) + async def wrapped( + update: Update, context: ContextTypes.DEFAULT_TYPE, *args, **kwargs + ): + user_id = update.effective_user.id + cm = get_config_manager() + + if not cm.is_admin(user_id): + if update.message: + await update.message.reply_text("🔐 Admin access required.") + elif update.callback_query: + await update.callback_query.answer("Admin access required", show_alert=True) + return + + return await func(update, context, *args, **kwargs) + + return wrapped + + +def server_access_required(min_permission: ServerPermission = ServerPermission.VIEWER): + """ + Decorator factory that checks server permission. + Server name is determined from context.user_data or per-chat default. + """ + def decorator(func): + @wraps(func) + async def wrapped( + update: Update, context: ContextTypes.DEFAULT_TYPE, *args, **kwargs + ): + from config_manager import get_config_manager + from handlers.config.user_preferences import get_active_server + + user_id = update.effective_user.id + cm = get_config_manager() + + # Get user's preferred server, fallback to first accessible + server_name = get_active_server(context.user_data) + if not server_name: + accessible = cm.get_accessible_servers(user_id) + server_name = accessible[0] if accessible else None + + if not server_name: + if update.message: + await update.message.reply_text("⚠️ No server configured.") + elif update.callback_query: + await update.callback_query.answer("No server configured", show_alert=True) + return + + # Check permission + if not cm.has_server_access(user_id, server_name, min_permission): + perm_name = min_permission.value.title() + if update.message: + await update.message.reply_text( + f"🚫 You don't have {perm_name} access to this server." + ) + elif update.callback_query: + await update.callback_query.answer( + f"No {perm_name} access to this server", + show_alert=True + ) + return + + return await func(update, context, *args, **kwargs) + + return wrapped + return decorator + + async def _send_service_unavailable_message( update: Update, title: str, @@ -74,26 +219,26 @@ async def wrapped( chat_id = update.effective_chat.id if update.effective_chat else None # Check server status first - server_header, server_online = await get_server_context_header(chat_id) + server_header, server_online = await get_server_context_header(context.user_data) if not server_online: await _send_service_unavailable_message( update, title="Server Offline", status_line="🔴 The API server is not reachable\\.", - instruction="Check your server configuration in /config \\> API Servers\\." + instruction="Check your server configuration in /start \\> API Servers\\." ) return # Check gateway status - _, gateway_running = await get_gateway_status_info(chat_id) + _, gateway_running = await get_gateway_status_info(chat_id, context.user_data) if not gateway_running: await _send_service_unavailable_message( update, title="Gateway Not Running", status_line="🔴 The Gateway is not deployed or not running on this server\\.", - instruction="Deploy the Gateway in /config \\> Gateway to use this feature\\." + instruction="Deploy the Gateway in /start \\> Gateway to use this feature\\." ) return @@ -105,7 +250,7 @@ async def wrapped( update, title="Service Unavailable", status_line="⚠️ Could not verify service status\\.", - instruction="Please try again or check /config for server status\\." + instruction="Please try again or check /start for server status\\." ) return @@ -129,18 +274,15 @@ async def wrapped( try: from handlers.config.server_context import get_server_context_header - # Get chat_id to use per-chat server default - chat_id = update.effective_chat.id if update.effective_chat else None - # Check server status - server_header, server_online = await get_server_context_header(chat_id) + server_header, server_online = await get_server_context_header(context.user_data) if not server_online: await _send_service_unavailable_message( update, title="API Server Offline", status_line="🔴 The Hummingbot API server is not reachable\\.", - instruction="Check your server configuration in /config \\> API Servers\\." + instruction="Check your server configuration in /start \\> API Servers\\." ) return @@ -152,7 +294,7 @@ async def wrapped( update, title="Service Unavailable", status_line="⚠️ Could not verify service status\\.", - instruction="Please try again or check /config for server status\\." + instruction="Please try again or check /start for server status\\." ) return diff --git a/utils/config.py b/utils/config.py index e95b49a..549a10b 100644 --- a/utils/config.py +++ b/utils/config.py @@ -4,9 +4,15 @@ load_dotenv() -AUTHORIZED_USERS = [ - int(user_id) for user_id in os.environ.get("AUTHORIZED_USERS", "").split(",") -] +# Primary admin user ID - this user has full control over the bot +# Set via ADMIN_USER_ID environment variable +ADMIN_USER_ID = None +_admin_id_str = os.environ.get("ADMIN_USER_ID", "").strip() +if _admin_id_str: + try: + ADMIN_USER_ID = int(_admin_id_str) + except ValueError: + pass TELEGRAM_TOKEN = os.environ.get("TELEGRAM_TOKEN") diff --git a/utils/portfolio_graphs.py b/utils/portfolio_graphs.py index ffeba3c..65301fd 100644 --- a/utils/portfolio_graphs.py +++ b/utils/portfolio_graphs.py @@ -4,7 +4,7 @@ import io import logging -from typing import Dict, Any, List, Optional +from typing import Dict, Any import plotly.graph_objects as go from datetime import datetime diff --git a/utils/telegram_formatters.py b/utils/telegram_formatters.py index fce9d45..fe1dfd1 100644 --- a/utils/telegram_formatters.py +++ b/utils/telegram_formatters.py @@ -3,9 +3,54 @@ """ from typing import Dict, Any, List, Optional +from datetime import datetime, timezone import re +def format_uptime(deployed_at: str) -> str: + """ + Format time elapsed since deployment as a compact uptime string. + + Args: + deployed_at: ISO format datetime string (e.g., "2025-12-24T22:22:50.879680+00:00") + + Returns: + Formatted uptime string (e.g., "2h 15m", "1d 5h", "3d") + """ + try: + # Parse the deployed_at timestamp + if deployed_at.endswith('Z'): + deployed_at = deployed_at[:-1] + '+00:00' + deploy_time = datetime.fromisoformat(deployed_at) + + # Get current time in UTC + now = datetime.now(timezone.utc) + + # Calculate the difference + delta = now - deploy_time + + total_seconds = int(delta.total_seconds()) + if total_seconds < 0: + return "0m" + + days = total_seconds // 86400 + hours = (total_seconds % 86400) // 3600 + minutes = (total_seconds % 3600) // 60 + + if days > 0: + if hours > 0: + return f"{days}d {hours}h" + return f"{days}d" + elif hours > 0: + if minutes > 0: + return f"{hours}h {minutes}m" + return f"{hours}h" + else: + return f"{minutes}m" + except Exception: + return "" + + def escape_markdown_v2(text: str) -> str: """ Escape special characters for Telegram MarkdownV2 @@ -330,7 +375,8 @@ def _shorten_controller_for_table(name: str, max_len: int = 28) -> str: def format_active_bots( bots_data: Dict[str, Any], server_name: Optional[str] = None, - server_status: Optional[str] = None + server_status: Optional[str] = None, + bot_runs: Optional[Dict[str, str]] = None ) -> str: """ Format active bots status for Telegram with clean table layout. @@ -339,11 +385,13 @@ def format_active_bots( bots_data: Active bots data from client.bot_orchestration.get_active_bots_status() server_name: Name of the server (optional) server_status: Status of the server (optional) + bot_runs: Dict mapping bot_name -> deployed_at ISO timestamp (optional) Returns: Formatted Telegram message """ message = "🤖 *Active Bots*\n\n" + bot_runs = bot_runs or {} # Handle different response formats # New format: {"status": "success", "data": {"bot_name": {...}}} @@ -378,7 +426,15 @@ def format_active_bots( # Truncate long bot names for display display_name = bot_name[:45] + "..." if len(bot_name) > 45 else bot_name - message += f"{status_emoji} `{escape_markdown_v2(display_name)}`\n" + + # Add uptime if available + uptime_str = "" + if bot_name in bot_runs: + uptime = format_uptime(bot_runs[bot_name]) + if uptime: + uptime_str = f" ⏱️ {uptime}" + + message += f"{status_emoji} `{escape_markdown_v2(display_name)}`{uptime_str}\n" # Performance is a dict of controller_name -> controller_info performance = bot_info.get("performance", {}) @@ -809,7 +865,9 @@ def _format_pnl_value(value: float) -> str: def format_lp_positions(positions_data: Dict[str, Any], token_cache: Optional[Dict[str, str]] = None) -> str: """ - Format LP (CLMM) positions for Telegram display, grouped by connector. + Format LP (CLMM) positions for Telegram display - compact summary with value and PNL. + + Only shows active positions with their value and PNL in a scannable format. Args: positions_data: Dictionary with 'positions' list and 'total' count @@ -823,83 +881,64 @@ def format_lp_positions(positions_data: Dict[str, Any], token_cache: Optional[Di token_cache = token_cache or {} if not positions or total == 0: - return "🏊 *LP Positions \\(CLMM\\)*\n_No active LP positions_\n" + return "" # Don't show section if no positions - message = f"🏊 *LP Positions \\(CLMM\\)* \\({escape_markdown_v2(str(total))}\\)\n" + # Calculate totals and filter active positions + total_value_usd = 0.0 + total_pnl_usd = 0.0 + active_count = 0 + out_of_range_count = 0 - # Group positions by connector - from collections import defaultdict - by_connector = defaultdict(list) for pos in positions: - connector = pos.get('connector', 'unknown') - by_connector[connector].append(pos) + in_range = pos.get('in_range', '') + if in_range == 'IN_RANGE': + active_count += 1 + elif in_range == 'OUT_OF_RANGE': + out_of_range_count += 1 - # Format each connector's positions - for connector, conn_positions in by_connector.items(): - # Get network/chain from first position - first_pos = conn_positions[0] - chain = _get_chain_from_network(first_pos.get('network', '')) + # Get value from pnl_summary (in quote token) + pnl_summary = pos.get('pnl_summary', {}) + current_value = pnl_summary.get('current_lp_value_quote', 0) + total_pnl = pnl_summary.get('total_pnl_quote', 0) - message += f"\n*{escape_markdown_v2(connector)}* \\({escape_markdown_v2(chain)}\\)\n" - - # Create table with PNL info - # Pair(10) Range(9) Status(3) BasePNL(8) QuotePNL(8) - table_content = "```\n" - table_content += f"{'Pair':<10} {'Range':<9} {'St':>2} {'B.PNL':>7} {'Q.PNL':>7}\n" - table_content += f"{'─'*10} {'─'*9} {'─'*2} {'─'*7} {'─'*7}\n" - - for pos in conn_positions[:8]: # Max 8 per connector - # Resolve token pair using cache - base_token = pos.get('base_token', '') - quote_token = pos.get('quote_token', '') - base_symbol = resolve_token_symbol(base_token, token_cache) - quote_symbol = resolve_token_symbol(quote_token, token_cache) - pair_str = f"{base_symbol[:4]}-{quote_symbol[:4]}" - - # Format range - lower_price = pos.get('lower_price', 0) - upper_price = pos.get('upper_price', 0) - if lower_price and upper_price: - try: - low = float(lower_price) - high = float(upper_price) - if low >= 1000: - range_str = f"{low/1000:.1f}k-{high/1000:.1f}k" - elif low >= 1: - range_str = f"{low:.2f}-{high:.2f}" - else: - range_str = f"{low:.3f}-{high:.3f}" - range_str = range_str[:8] - except: - range_str = "N/A" - else: - range_str = "N/A" - - # Status (IN_RANGE -> "IN", OUT_OF_RANGE -> "OUT") - in_range = pos.get('in_range', '') - if in_range == 'IN_RANGE': - status = "✓" - elif in_range == 'OUT_OF_RANGE': - status = "✗" - else: - status = "?" - - # PNL from pnl_summary - pnl_summary = pos.get('pnl_summary', {}) - base_pnl = pnl_summary.get('base_pnl') - quote_pnl = pnl_summary.get('quote_pnl') - - base_pnl_str = _format_pnl_value(base_pnl)[:7] if base_pnl is not None else "—" - quote_pnl_str = _format_pnl_value(quote_pnl)[:7] if quote_pnl is not None else "—" - - table_content += f"{pair_str:<10} {range_str:<9} {status:>2} {base_pnl_str:>7} {quote_pnl_str:>7}\n" - - if len(conn_positions) > 8: - table_content += f"... +{len(conn_positions) - 8} more\n" + try: + # For now, assume quote token is a stablecoin (value ~= $1) + # A more accurate approach would use token_prices + value_f = float(current_value) if current_value else 0 + pnl_f = float(total_pnl) if total_pnl else 0 + total_value_usd += value_f + total_pnl_usd += pnl_f + except (ValueError, TypeError): + pass + + # Build compact message + message = f"🏊 *LP Positions* \\({escape_markdown_v2(str(total))}\\)\n" + + # Summary line: 3 active 🟢 | 1 out 🔴 | Value: $1,234 | PnL: +$56 + parts = [] + if active_count > 0: + parts.append(f"{active_count} 🟢") + if out_of_range_count > 0: + parts.append(f"{out_of_range_count} 🔴") + + if total_value_usd > 0: + value_str = format_number(total_value_usd) + parts.append(f"Value: {value_str}") + + if total_pnl_usd != 0: + pnl_str = format_number(abs(total_pnl_usd)) + if total_pnl_usd >= 0: + parts.append(f"PnL: +{pnl_str}") + else: + parts.append(f"PnL: -{pnl_str}") - table_content += "```\n" - message += table_content + if parts: + summary = " \\| ".join([escape_markdown_v2(p) for p in parts]) + message += summary + else: + message += "_No value data available_" + message += "\n_Use /lp for details_\n\n" return message @@ -988,13 +1027,336 @@ def format_change_compact(value: Optional[float]) -> str: return f"{sign}{value:.1f}%" +def format_exchange_distribution( + accounts_distribution: Dict[str, Any], + changes_24h: Optional[Dict[str, Any]] = None, + total_value: float = 0.0 +) -> str: + """ + Format exchange/connector distribution as a compact table. + + Args: + accounts_distribution: From client.portfolio.get_accounts_distribution() + changes_24h: 24h change data with connector changes + total_value: Total portfolio value for percentage calculation + + Returns: + MarkdownV2 formatted string with exchange distribution table + """ + if not accounts_distribution: + return "" + + # Parse distribution data (supports both list and dict formats) + accounts_list = accounts_distribution.get("distribution", []) + accounts_dict = accounts_distribution.get("accounts", {}) + + # Build connector -> value mapping (aggregated across accounts) + connector_totals = {} # {connector: {"value": float, "account": str}} + connector_changes = changes_24h.get("connectors", {}) if changes_24h else {} + + if accounts_list: + for account_info in accounts_list: + account_name = account_info.get("account", account_info.get("name", "Unknown")) + connectors = account_info.get("connectors", {}) + + for connector_name, connector_value in connectors.items(): + if isinstance(connector_value, dict): + connector_value = connector_value.get("value", 0) + if isinstance(connector_value, str): + try: + connector_value = float(connector_value) + except (ValueError, TypeError): + connector_value = 0 + + key = f"{account_name}:{connector_name}" + connector_totals[key] = { + "value": float(connector_value), + "account": account_name, + "connector": connector_name + } + + elif accounts_dict: + for account_name, account_info in accounts_dict.items(): + connectors = account_info.get("connectors", {}) + for connector_name, connector_info in connectors.items(): + if isinstance(connector_info, dict): + value = connector_info.get("value", 0) + else: + value = connector_info + + if isinstance(value, str): + try: + value = float(value) + except (ValueError, TypeError): + value = 0 + + key = f"{account_name}:{connector_name}" + connector_totals[key] = { + "value": float(value), + "account": account_name, + "connector": connector_name + } + + if not connector_totals: + return "" + + # Sort by value descending + sorted_connectors = sorted(connector_totals.items(), key=lambda x: x[1]["value"], reverse=True) + + # Build table + message = "*Exchanges:*\n" + message += "```\n" + message += f"{'Exchange':<20} {'Value':<10} {'%':>6} {'24h':>8}\n" + message += f"{'─'*20} {'─'*10} {'─'*6} {'─'*8}\n" + + for key, data in sorted_connectors: + connector = data["connector"] + account = data["account"] + value = data["value"] + + if value < 1: # Skip tiny values + continue + + # Show full connector name (up to 19 chars) + display_name = connector[:19] if len(connector) > 19 else connector + + # Format value + value_str = format_number(value) + + # Calculate percentage + pct = (value / total_value * 100) if total_value > 0 else 0 + pct_str = f"{pct:.1f}%" if pct < 100 else f"{pct:.0f}%" + + # Get 24h change + conn_change = connector_changes.get(account, {}).get(connector, {}) + conn_pct = conn_change.get("pct_change") + if conn_pct is not None: + change_str = format_change_compact(conn_pct) + else: + change_str = "—" + + message += f"{display_name:<20} {value_str:<10} {pct_str:>6} {change_str:>8}\n" + + message += "```\n\n" + return message + + +def format_aggregated_tokens( + balances: Dict[str, Any], + changes_24h: Optional[Dict[str, Any]] = None, + total_value: float = 0.0, + max_tokens: int = 10 +) -> str: + """ + Format aggregated token holdings across all exchanges. + + Args: + balances: Portfolio state from get_state() {account: {connector: [holdings]}} + changes_24h: 24h change data with token price changes + total_value: Total portfolio value for percentage calculation + max_tokens: Maximum number of tokens to display + + Returns: + MarkdownV2 formatted string with token holdings table + """ + if not balances: + return "" + + # Aggregate tokens across all accounts/connectors + token_totals = {} # {token: {"units": float, "value": float}} + + for account_name, account_data in balances.items(): + for connector_name, connector_balances in account_data.items(): + if not connector_balances: + continue + for balance in connector_balances: + token = balance.get("token", "???") + units = balance.get("units", 0) + value = balance.get("value", 0) + + if isinstance(units, str): + try: + units = float(units) + except (ValueError, TypeError): + units = 0 + if isinstance(value, str): + try: + value = float(value) + except (ValueError, TypeError): + value = 0 + + if token not in token_totals: + token_totals[token] = {"units": 0.0, "value": 0.0} + + token_totals[token]["units"] += float(units) + token_totals[token]["value"] += float(value) + + if not token_totals: + return "" + + # Sort by value descending + sorted_tokens = sorted(token_totals.items(), key=lambda x: x[1]["value"], reverse=True) + + # Filter out tiny values + sorted_tokens = [(t, d) for t, d in sorted_tokens if d["value"] >= 1] + + if not sorted_tokens: + return "" + + # Get 24h changes + token_changes = changes_24h.get("tokens", {}) if changes_24h else {} + + # Build table + message = "*Token Holdings:*\n" + message += "```\n" + message += f"{'Token':<6} {'Price':<9} {'Value':<8} {'%':>5} {'24h':>7}\n" + message += f"{'─'*6} {'─'*9} {'─'*8} {'─'*5} {'─'*7}\n" + + for token, data in sorted_tokens[:max_tokens]: + units = data["units"] + value = data["value"] + + # Truncate token name + token_display = token[:5] if len(token) > 5 else token + + # Calculate price + price = value / units if units > 0 else 0 + price_str = format_price(price)[:9] + + # Format value + value_str = format_number(value)[:8] + + # Calculate percentage + pct = (value / total_value * 100) if total_value > 0 else 0 + pct_str = f"{pct:.1f}%" if pct < 100 else f"{pct:.0f}%" + + # Get 24h price change + token_change = token_changes.get(token, {}) + price_change = token_change.get("price_change") + if price_change is not None: + change_str = format_change_compact(price_change)[:7] + else: + change_str = "—" + + message += f"{token_display:<6} {price_str:<9} {value_str:<8} {pct_str:>5} {change_str:>7}\n" + + # Show count if more tokens exist + if len(sorted_tokens) > max_tokens: + message += f"\n... +{len(sorted_tokens) - max_tokens} more tokens\n" + + message += "```\n\n" + return message + + +def format_connector_detail( + balances: Dict[str, Any], + connector_key: str, + changes_24h: Optional[Dict[str, Any]] = None, + total_value: float = 0.0 +) -> str: + """ + Format detailed token holdings for a specific connector. + + Args: + balances: Portfolio state from get_state() + connector_key: "account:connector" identifier + changes_24h: 24h change data + total_value: Total portfolio value for percentage calculation + + Returns: + MarkdownV2 formatted string with connector-specific token table + """ + if not balances or not connector_key: + return "_No data available_" + + # Parse connector key + parts = connector_key.split(":", 1) + if len(parts) != 2: + return "_Invalid connector_" + + account_name, connector_name = parts + + # Get connector balances + account_data = balances.get(account_name, {}) + connector_balances = account_data.get(connector_name, []) + + if not connector_balances: + return f"_No holdings found for {escape_markdown_v2(connector_name)}_" + + # Calculate connector total + connector_total = sum(b.get("value", 0) for b in connector_balances if b.get("value", 0) > 0) + + # Get changes + token_changes = changes_24h.get("tokens", {}) if changes_24h else {} + connector_changes = changes_24h.get("connectors", {}) if changes_24h else {} + conn_change = connector_changes.get(account_name, {}).get(connector_name, {}) + conn_pct = conn_change.get("pct_change") + + # Build header + message = f"🏦 *{escape_markdown_v2(connector_name)}* " + message += f"\\| `{escape_markdown_v2(format_number(connector_total))}`" + if conn_pct is not None: + message += f" \\({escape_markdown_v2(format_change_compact(conn_pct))}\\)" + message += "\n" + message += f"_Account: {escape_markdown_v2(account_name)}_\n\n" + + # Sort balances by value + sorted_balances = sorted( + [b for b in connector_balances if b.get("value", 0) >= 1], + key=lambda x: x.get("value", 0), + reverse=True + ) + + if not sorted_balances: + message += "_No significant holdings_" + return message + + # Build table + message += "```\n" + message += f"{'Token':<6} {'Price':<9} {'Value':<8} {'%':>5} {'24h':>7}\n" + message += f"{'─'*6} {'─'*9} {'─'*8} {'─'*5} {'─'*7}\n" + + for balance in sorted_balances: + token = balance.get("token", "???") + units = balance.get("units", 0) + value = balance.get("value", 0) + + # Truncate token name + token_display = token[:5] if len(token) > 5 else token + + # Calculate price + price = value / units if units > 0 else 0 + price_str = format_price(price)[:9] + + # Format value + value_str = format_number(value)[:8] + + # Calculate percentage of portfolio + pct = (value / total_value * 100) if total_value > 0 else 0 + pct_str = f"{pct:.1f}%" if pct < 100 else f"{pct:.0f}%" + + # Get 24h price change + token_change = token_changes.get(token, {}) + price_change = token_change.get("price_change") + if price_change is not None: + change_str = format_change_compact(price_change)[:7] + else: + change_str = "—" + + message += f"{token_display:<6} {price_str:<9} {value_str:<8} {pct_str:>5} {change_str:>7}\n" + + message += "```\n" + return message + + def format_portfolio_overview( overview_data: Dict[str, Any], server_name: Optional[str] = None, server_status: Optional[str] = None, pnl_indicators: Optional[Dict[str, Optional[float]]] = None, changes_24h: Optional[Dict[str, Any]] = None, - token_cache: Optional[Dict[str, str]] = None + token_cache: Optional[Dict[str, str]] = None, + accounts_distribution: Optional[Dict[str, Any]] = None ) -> str: """ Format complete portfolio overview with all sections @@ -1010,6 +1372,7 @@ def format_portfolio_overview( pnl_indicators: Dict with pnl_24h, pnl_7d, pnl_30d percentages (optional) changes_24h: Dict with token and connector 24h changes (optional) token_cache: Dict mapping token addresses to symbols for LP position resolution (optional) + accounts_distribution: From get_accounts_distribution() for exchange breakdown (optional) Returns: Formatted Telegram message with all portfolio sections @@ -1028,146 +1391,67 @@ def format_portfolio_overview( else: message = "💼 *Portfolio Details*\n\n" - # Add PNL indicators bar if available - if pnl_indicators: - pnl_24h = pnl_indicators.get("pnl_24h") - pnl_7d = pnl_indicators.get("pnl_7d") - pnl_30d = pnl_indicators.get("pnl_30d") - detected_movements = pnl_indicators.get("detected_movements", []) - - # Only show if we have at least one value - if any(v is not None for v in [pnl_24h, pnl_7d, pnl_30d]): - pnl_parts = [] - if pnl_24h is not None: - pnl_parts.append(f"24h: `{escape_markdown_v2(format_pnl_indicator(pnl_24h))}`") - if pnl_7d is not None: - pnl_parts.append(f"7d: `{escape_markdown_v2(format_pnl_indicator(pnl_7d))}`") - if pnl_30d is not None: - pnl_parts.append(f"30d: `{escape_markdown_v2(format_pnl_indicator(pnl_30d))}`") - - if pnl_parts: - message += "📈 *PNL:* " + " \\| ".join(pnl_parts) + "\n" - - # Show detected movements if any (max 5 most recent) - if detected_movements: - message += f"_\\({len(detected_movements)} detected movement\\(s\\) adjusted\\)_\n" - message += "\n" - # ============================================ - # SECTION 1: BALANCES - Detailed tables by account and connector + # SECTION 1: TOTAL VALUE AND PNL # ============================================ - balances = overview_data.get('balances') - if balances: - total_value = 0.0 - all_balances = [] + balances = overview_data.get('balances') if overview_data else None - # Collect all balances with metadata - for account_name, account_data in balances.items(): - for connector_name, connector_balances in account_data.items(): + # Calculate total portfolio value + total_value = 0.0 + if balances: + for account_data in balances.values(): + for connector_balances in account_data.values(): if connector_balances: for balance in connector_balances: - token = balance.get("token", "???") - units = balance.get("units", 0) value = balance.get("value", 0) - - if value > 1: # Only show balances > $1 - all_balances.append({ - "account": account_name, - "connector": connector_name, - "token": token, - "units": units, - "value": value - }) + if value > 0: total_value += value - # Calculate percentages - for balance in all_balances: - balance["percentage"] = (balance["value"] / total_value * 100) if total_value > 0 else 0 + # Show total value with all PNL indicators on one line + pnl_24h = pnl_indicators.get("pnl_24h") if pnl_indicators else None + pnl_7d = pnl_indicators.get("pnl_7d") if pnl_indicators else None + pnl_30d = pnl_indicators.get("pnl_30d") if pnl_indicators else None + detected_movements = pnl_indicators.get("detected_movements", []) if pnl_indicators else [] - # Group by account and connector - from collections import defaultdict - grouped = defaultdict(lambda: defaultdict(list)) - - for balance in all_balances: - account = balance["account"] - connector = balance["connector"] - grouped[account][connector].append(balance) - - # Sort each group by value - for account in grouped: - for connector in grouped[account]: - grouped[account][connector].sort(key=lambda x: x["value"], reverse=True) - - # Get 24h changes data - token_changes = changes_24h.get("tokens", {}) if changes_24h else {} - connector_changes = changes_24h.get("connectors", {}) if changes_24h else {} - - # Build the balances section by iterating through accounts and connectors - for account, connectors in grouped.items(): - message += f"*Account:* {escape_markdown_v2(account)}\n" + if total_value > 0: + total_str = format_number(total_value) + line = f"💵 *Total:* `{escape_markdown_v2(total_str)}`" + if pnl_24h is not None: + line += f" \\({escape_markdown_v2(format_pnl_indicator(pnl_24h))} 24h\\)" + else: + line = f"💵 *Total:* `{escape_markdown_v2('$0.00')}`" - for connector, balances_list in connectors.items(): - # Calculate total value for this connector - connector_total = sum(balance["value"] for balance in balances_list) - connector_total_str = format_number(connector_total) + # Add 7d/30d PNL on the same line + pnl_parts = [] + if pnl_7d is not None: + pnl_parts.append(f"7d: {escape_markdown_v2(format_pnl_indicator(pnl_7d))}") + if pnl_30d is not None: + pnl_parts.append(f"30d: {escape_markdown_v2(format_pnl_indicator(pnl_30d))}") - # Get connector 24h change - conn_change = connector_changes.get(account, {}).get(connector, {}) - conn_pct = conn_change.get("pct_change") + if pnl_parts: + line += " 📈 " + " \\| ".join(pnl_parts) - if conn_pct is not None: - change_str = format_change_compact(conn_pct) - message += f" 🏦 *{escape_markdown_v2(connector)}* \\- `{escape_markdown_v2(connector_total_str)}` \\({escape_markdown_v2(change_str)}\\)\n\n" - else: - message += f" 🏦 *{escape_markdown_v2(connector)}* \\- `{escape_markdown_v2(connector_total_str)}`\n\n" - - # Start table - Token, Price, Value, %Total, 24h - table_content = "```\n" - table_content += f"{'Token':<6} {'Price':<8} {'Value':<7} {'%Tot':>5} {'24h':>6}\n" - table_content += f"{'─'*6} {'─'*8} {'─'*7} {'─'*5} {'─'*6}\n" - - for balance in balances_list: - token = balance["token"] - units = balance["units"] - value = balance["value"] - - # Calculate price per token - price = value / units if units > 0 else 0 - price_str = format_price(price)[:8] - - # Calculate percentage of total portfolio - pct = (value / total_value * 100) if total_value > 0 else 0 - pct_str = f"{pct:.0f}%" if pct >= 10 else f"{pct:.1f}%" - - value_str = format_number(value).replace('$', '')[:7] - - # Get 24h price change for this token - token_change = token_changes.get(token, {}) - price_change = token_change.get("price_change") - if price_change is not None: - sign = "+" if price_change >= 0 else "" - change_24h_str = f"{sign}{price_change:.1f}%"[:6] - else: - change_24h_str = "—" + message += line + "\n" - # Truncate long token names - token_display = token[:5] if len(token) > 5 else token + if detected_movements: + message += f"_\\({len(detected_movements)} movement\\(s\\) adjusted\\)_\n" - # Add row to table - table_content += f"{token_display:<6} {price_str:<8} {value_str:<7} {pct_str:>5} {change_24h_str:>6}\n" + message += "\n" - # Close table - table_content += "```\n\n" - message += table_content + # ============================================ + # SECTION 2: EXCHANGE DISTRIBUTION (compact) + # ============================================ + if accounts_distribution: + message += format_exchange_distribution(accounts_distribution, changes_24h, total_value) - # Show total - if total_value > 0: - message += f"💵 *Total Portfolio Value:* `{escape_markdown_v2(format_number(total_value))}`\n\n" - else: - message += f"💵 *Total Portfolio Value:* `{escape_markdown_v2('$0.00')}`\n\n" + # ============================================ + # SECTION 3: AGGREGATED TOKEN HOLDINGS (compact) + # ============================================ + if balances: + message += format_aggregated_tokens(balances, changes_24h, total_value, max_tokens=10) # ============================================ - # SECTION 2: PERPETUAL POSITIONS + # SECTION 4: PERPETUAL POSITIONS # ============================================ perp_positions = overview_data.get('perp_positions', {"positions": [], "total": 0}) message += format_perpetual_positions(perp_positions) diff --git a/utils/telegram_helpers.py b/utils/telegram_helpers.py new file mode 100644 index 0000000..01dfe26 --- /dev/null +++ b/utils/telegram_helpers.py @@ -0,0 +1,41 @@ +"""Telegram helper utilities for reusing callback-based handlers from commands.""" + + +class MockMessage: + """Wrapper around a real message to provide edit_text interface.""" + + def __init__(self, msg): + self._msg = msg + self.message_id = msg.message_id + self.chat_id = msg.chat_id + + async def edit_text(self, text, parse_mode=None, reply_markup=None): + return await self._msg.edit_text( + text, + parse_mode=parse_mode, + reply_markup=reply_markup + ) + + async def delete(self): + return await self._msg.delete() + + def get_bot(self): + return self._msg.get_bot() + + +class MockQuery: + """Mock query object to reuse callback-based functions from commands.""" + + def __init__(self, message, from_user): + self.message = message + self.from_user = from_user + self.data = "" + + async def answer(self, text="", show_alert=False): + pass + + +async def create_mock_query_from_message(update, initial_text="Loading..."): + """Create a mock query from a message update for reusing callback handlers.""" + msg = await update.message.reply_text(initial_text) + return MockQuery(MockMessage(msg), update.effective_user) diff --git a/utils/trading_data.py b/utils/trading_data.py index b93aec9..5de8de8 100644 --- a/utils/trading_data.py +++ b/utils/trading_data.py @@ -190,28 +190,39 @@ def _is_position_active(pos: Dict[str, Any]) -> bool: Returns: True if position appears to be active (has liquidity) """ + # Must not have CLOSED status + if pos.get('status') == 'CLOSED': + return False + # Check liquidity field (exact field name may vary) liquidity = pos.get('liquidity') or pos.get('current_liquidity') or pos.get('liq') if liquidity is not None: try: - if float(liquidity) <= 0: - return False + return float(liquidity) > 0 except (ValueError, TypeError): pass # Check if position has any token amounts remaining - base_amount = pos.get('base_amount') or pos.get('amount_base') or pos.get('token_a_amount') - quote_amount = pos.get('quote_amount') or pos.get('amount_quote') or pos.get('token_b_amount') + # Use the same field names as the /lp command (liquidity.py) + base_amount = ( + pos.get('base_token_amount') or pos.get('base_amount') or + pos.get('amount_base') or pos.get('token_a_amount') + ) + quote_amount = ( + pos.get('quote_token_amount') or pos.get('quote_amount') or + pos.get('amount_quote') or pos.get('token_b_amount') + ) - if base_amount is not None and quote_amount is not None: + # If we have token amount data, check if at least one is > 0 + if base_amount is not None or quote_amount is not None: try: - if float(base_amount) <= 0 and float(quote_amount) <= 0: - return False + return float(base_amount or 0) > 0 or float(quote_amount or 0) > 0 except (ValueError, TypeError): pass - # If we can't determine, assume it's active - return True + # If we can't determine from liquidity or token amounts, assume closed + # (positions should have at least one of these fields if active) + return False async def get_lp_positions(client, account_names: Optional[List[str]] = None) -> Dict[str, Any]: