Skip to content
Merged

Dev #37

Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
98 changes: 98 additions & 0 deletions aws_manage_parameter_store.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
import argparse
import sys
from typing import List, Optional

import boto3
import botocore


def list_parameters(client, prefix: Optional[str], recursive: bool) -> List[str]:
    """Return the names of SSM parameters, optionally scoped to a path prefix.

    Args:
        client: A boto3 SSM client.
        prefix: Path prefix (e.g. ``/my/app``); when falsy, every parameter is listed.
        recursive: With a prefix, whether to descend into child paths.

    Returns:
        Parameter names in the order the API returned them.
    """
    if prefix:
        pages = client.get_paginator('get_parameters_by_path').paginate(
            Path=prefix, Recursive=recursive, WithDecryption=False)
    else:
        # Without a prefix, describe_parameters enumerates metadata (incl. Name)
        # for every parameter in the account/region.
        pages = client.get_paginator('describe_parameters').paginate()
    return [entry['Name'] for page in pages for entry in page.get('Parameters', [])]


# The delete_parameters API accepts at most 10 names per call.
essm_delete_batch_size = 10

def delete_parameters(client, names: List[str]) -> None:
    """Delete the given SSM parameters in API-sized batches, printing outcomes.

    Args:
        client: A boto3 SSM client.
        names: Parameter names to delete.
    """
    start = 0
    while start < len(names):
        batch = names[start:start + essm_delete_batch_size]
        start += essm_delete_batch_size
        outcome = client.delete_parameters(Names=batch)
        removed = outcome.get('DeletedParameters', [])
        rejected = outcome.get('InvalidParameters', [])
        if removed:
            print(f"Deleted: {', '.join(removed)}")
        if rejected:
            print(f"Invalid (not found or no access): {', '.join(rejected)}")


def manage_parameter_store(region: Optional[str], prefix: Optional[str], recursive: bool, force: bool, dry_run: bool) -> None:
    """List SSM Parameter Store entries and (optionally) delete them.

    Args:
        region: AWS region; falls back to the environment's default when None.
        prefix: Optional path prefix limiting the operation's scope.
        recursive: With a prefix, include all child paths.
        force: Skip the interactive confirmation prompt.
        dry_run: List only; perform no deletions.

    Exits the process with status 1 on AWS errors, 130 on interrupt.
    """
    try:
        ssm = boto3.client('ssm', region_name=region)
        # The client resolves the region from the environment when none was given.
        effective_region = region or ssm.meta.region_name
        if prefix:
            scope_desc = f"prefix '{prefix}' (recursive={recursive})"
        else:
            scope_desc = "all parameters"
        print(f"Fetching {scope_desc} from AWS SSM Parameter Store in region: {effective_region}...")

        found = list_parameters(ssm, prefix=prefix, recursive=recursive)
        if not found:
            print("No parameters found.")
            return

        print(f"\nFound {len(found)} parameters.")
        for name in found:
            print(f"  - {name}")

        if dry_run:
            print("\nDry run: no deletions will be performed.")
            return

        if not force:
            reply = input("\nProceed to delete ALL listed parameters? Type 'yes' to confirm: ").strip().lower()
            if reply != 'yes':
                print("Aborting. No parameters deleted.")
                return

        print("\nDeleting parameters...")
        delete_parameters(ssm, found)
        print("\nParameter Store cleanup finished.")

    except botocore.exceptions.NoCredentialsError:
        print("AWS credentials not found. Please configure your credentials.")
        sys.exit(1)
    except botocore.exceptions.ClientError as e:
        print(f"AWS client error: {e}")
        sys.exit(1)
    except KeyboardInterrupt:
        print("Interrupted.")
        sys.exit(130)
    except Exception as e:
        print(f"Unexpected error: {e}")
        sys.exit(1)


if __name__ == '__main__':
    # CLI entry point: parse flags, then run the list/delete workflow.
    cli = argparse.ArgumentParser(description='List and delete AWS SSM Parameter Store parameters.')
    cli.add_argument('--region', type=str, help='AWS region to use. Defaults to your environment configuration if omitted.')
    cli.add_argument('--prefix', type=str, help='Optional path prefix to filter parameters, e.g. /my/app. If omitted, operates on all parameters.')
    cli.add_argument('--recursive', action='store_true', help='When used with --prefix, include all child paths recursively.')
    cli.add_argument('--force', action='store_true', help='Do not prompt for confirmation; delete immediately.')
    cli.add_argument('--dry-run', action='store_true', help='Only list parameters; do not delete.')
    opts = cli.parse_args()

    manage_parameter_store(
        region=opts.region,
        prefix=opts.prefix,
        recursive=opts.recursive,
        force=opts.force,
        dry_run=opts.dry_run,
    )
92 changes: 92 additions & 0 deletions azure_manage_key_vault.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
import argparse
import sys
from typing import Iterable, List, Optional

from azure.identity import DefaultAzureCredential
from azure.keyvault.secrets import SecretClient
from azure.core.exceptions import HttpResponseError


def list_secret_names(client: SecretClient, name_prefix: Optional[str]) -> List[str]:
    """Collect secret names from the vault, optionally filtered by prefix.

    Args:
        client: Key Vault secret client.
        name_prefix: When given, only names starting with this prefix are kept.

    Returns:
        Matching secret names in listing order.
    """
    all_props: Iterable = client.list_properties_of_secrets()
    if name_prefix:
        return [props.name for props in all_props if props.name.startswith(name_prefix)]
    return [props.name for props in all_props]


def delete_secrets(client: SecretClient, names: List[str], purge: bool) -> None:
    """Soft-delete each named secret, optionally purging it afterwards.

    Args:
        client: Key Vault secret client.
        names: Secret names to delete.
        purge: When True, attempt a purge after each successful soft-delete.

    Errors on individual secrets are printed and do not stop the loop.
    """
    for secret_name in names:
        try:
            print(f"Deleting secret: {secret_name}")
            # begin_delete_secret returns a poller; wait for the soft-delete to finish.
            client.begin_delete_secret(secret_name).wait()
            print(f"Deleted (soft-delete) secret: {secret_name}")
            if not purge:
                continue
            try:
                client.purge_deleted_secret(secret_name)
                print(f"Purged secret: {secret_name}")
            except HttpResponseError as e:
                # Purge can be rejected when soft-delete is disabled or permissions are missing.
                print(f"Could not purge {secret_name}: {e}")
        except HttpResponseError as e:
            print(f"Error deleting {secret_name}: {e}")


def manage_key_vault(vault_url: str, name_prefix: Optional[str], force: bool, dry_run: bool, purge: bool) -> None:
    """List Azure Key Vault secrets and (optionally) delete them.

    Args:
        vault_url: Full vault URL, e.g. ``https://myvault.vault.azure.net/``.
        name_prefix: Optional name prefix limiting which secrets are touched.
        force: Skip the interactive confirmation prompt.
        dry_run: List only; perform no deletions.
        purge: After soft-delete, also purge each secret.

    Exits the process with status 1 on errors, 130 on interrupt.
    """
    try:
        # DefaultAzureCredential tries env vars, managed identity, CLI login, etc.
        client = SecretClient(vault_url=vault_url, credential=DefaultAzureCredential())
        if name_prefix:
            scope_desc = f"with name prefix '{name_prefix}'"
        else:
            scope_desc = "(all secrets)"
        print(f"Fetching secrets from Key Vault: {vault_url} {scope_desc} ...")

        matched = list_secret_names(client, name_prefix=name_prefix)
        if not matched:
            print("No secrets found.")
            return

        print(f"\nFound {len(matched)} secrets:")
        for secret_name in matched:
            print(f"  - {secret_name}")

        if dry_run:
            print("\nDry run: no deletions will be performed.")
            return

        if not force:
            reply = input("\nProceed to delete ALL listed secrets? Type 'yes' to confirm: ").strip().lower()
            if reply != 'yes':
                print("Aborting. No secrets deleted.")
                return

        print("\nDeleting secrets...")
        delete_secrets(client, matched, purge=purge)
        print("\nKey Vault cleanup finished.")

    except KeyboardInterrupt:
        print("Interrupted.")
        sys.exit(130)
    except Exception as e:
        print(f"Unexpected error: {e}")
        sys.exit(1)


if __name__ == "__main__":
parser = argparse.ArgumentParser(description="List and delete Azure Key Vault secrets.")
parser.add_argument("--vault-url", required=True, help="Key Vault URL, e.g., https://myvault.vault.azure.net/")
parser.add_argument("--name-prefix", help="Optional name prefix filter for secrets.")
parser.add_argument("--force", action="store_true", help="Do not prompt for confirmation; delete immediately.")
parser.add_argument("--dry-run", action="store_true", help="Only list secrets; do not delete.")
parser.add_argument("--purge", action="store_true", help="After deletion, purge secrets (if soft-delete enabled).")
args = parser.parse_args()

manage_key_vault(
vault_url=args.vault_url,
name_prefix=args.name_prefix,
force=args.force,
dry_run=args.dry_run,
purge=args.purge,
)
2 changes: 1 addition & 1 deletion examples/mcp_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ async def run() :
# Call a tool
result = await mcp_client.call_tool("add", {"a": 69, "b": 420})
print(f'The result of 69 + 420 is: {result["content"][-1]["text"]}')

# Call a premium tool
result = await mcp_client.call_tool("multiply", {"a": 69, "b": 420})
print(f'The result of 69 * 420 is: {result["content"][-1]["text"]}')
Expand Down
3 changes: 3 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,8 @@ cli = [
"azure-identity>=1.23.0",
"PyYAML>=6.0",
"nostr-relay>=1.14",
"azure-identity>=1.24.0",
"azure-keyvault-secrets>=4.10.0",
]
all = [
"langchain>=0.3.25",
Expand All @@ -135,6 +137,7 @@ all = [
"google-cloud-run>=0.10.18",
"azure-mgmt-containerinstance>=10.1.0",
"azure-identity>=1.23.0",
"azure-keyvault-secrets>=4.10.0",
"PyYAML>=6.0",
"nostr-relay>=1.14",
"langgraph-checkpoint-postgres>=2.0.21",
Expand Down
14 changes: 12 additions & 2 deletions src/agentstr/agents/agentstr.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from agentstr.commands.base import Commands
from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver
from langgraph.checkpoint.sqlite.aio import AsyncSqliteSaver
from langchain_core.tools import BaseTool
from agentstr.logger import get_logger

logger = get_logger(__name__)
Expand Down Expand Up @@ -50,7 +51,9 @@ def __init__(self,
llm_model_name: str | None = None,
llm_base_url: str | None = None,
llm_api_key: str | None = None,
agent_callable: Callable[[ChatInput], ChatOutput | str] | None = None):
agent_callable: Callable[[ChatInput], ChatOutput | str] | None = None,
tools: list[BaseTool] | None = None,
recipient_pubkey: str | None = None):
"""Initializes the AgentstrAgent.

Args:
Expand All @@ -70,6 +73,8 @@ def __init__(self,
llm_base_url: The base URL for the language model (or use environment variable LLM_BASE_URL).
llm_api_key: The API key for the language model (or use environment variable LLM_API_KEY).
agent_callable: A callable for non-streaming responses (overrides default LLM response).
tools: A list of Langgraph tools for the agent.
recipient_pubkey: The public key to listen for direct messages from.
"""
self.nostr_client = nostr_client or NostrClient()
self.nostr_mcp_clients = nostr_mcp_clients.copy() if nostr_mcp_clients else []
Expand All @@ -89,6 +94,9 @@ def __init__(self,
self.llm_base_url = llm_base_url or os.getenv("LLM_BASE_URL")
self.llm_api_key = llm_api_key or os.getenv("LLM_API_KEY")
self.agent_callable = agent_callable
self.tools = tools or []
self.recipient_pubkey = recipient_pubkey or os.getenv('RECIPIENT_PUBKEY')

if self.agent_callable is None:
# Require LLM
self._check_llm_vars()
Expand Down Expand Up @@ -145,6 +153,7 @@ async def _create_agent_server(self, checkpointer: AsyncPostgresSaver | AsyncSql
all_tools = []
for nostr_mcp_client in self.nostr_mcp_clients:
all_tools.extend(await to_langgraph_tools(nostr_mcp_client))
all_tools.extend(self.tools)

all_skills = [skill for skills in [await nostr_mcp_client.get_skills() for nostr_mcp_client in self.nostr_mcp_clients] for skill in skills]

Expand Down Expand Up @@ -182,7 +191,8 @@ async def _create_agent_server(self, checkpointer: AsyncPostgresSaver | AsyncSql
server = NostrAgentServer(nostr_client=self.nostr_client,
nostr_agent=nostr_agent,
db=self.database,
commands=self.commands)
commands=self.commands,
recipient_pubkey=self.recipient_pubkey)

return server

Expand Down
24 changes: 18 additions & 6 deletions src/agentstr/agents/nostr_agent_server.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,14 @@
import asyncio
from collections.abc import Callable
from typing import Any, Literal
import uuid
import json
import os
import time

from pynostr.event import Event
from datetime import datetime, timezone, timedelta

from agentstr.agents.nostr_agent import NostrAgent
from agentstr.database import Database, BaseDatabase
from agentstr.models import AgentCard, ChatInput, ChatOutput, Message, User, NoteFilters
from agentstr.models import ChatInput, ChatOutput, Message, User, NoteFilters
from agentstr.commands.base import Commands
from agentstr.commands.commands import DefaultCommands
from agentstr.logger import get_logger
Expand Down Expand Up @@ -42,7 +41,8 @@ def __init__(self,
nwc_str: str | None = None,
db: BaseDatabase | None = None,
note_filters: NoteFilters | None = None,
commands: Commands | None = None):
commands: Commands | None = None,
recipient_pubkey: str | None = None):
"""
Initialize a NostrAgentServer.

Expand All @@ -56,6 +56,7 @@ def __init__(self,
db (BaseDatabase, optional): Database for persisting messages and user state.
note_filters (NoteFilters, optional): Filters for subscribing to specific Nostr notes/events.
commands (Commands, optional): Custom command handler. If not provided, uses DefaultCommands.
recipient_pubkey (str, optional): The public key to listen for direct messages from.
"""
self.client = nostr_client or (nostr_mcp_client.client if nostr_mcp_client else NostrClient(relays=relays, private_key=private_key, nwc_str=nwc_str))
self.nostr_agent = nostr_agent
Expand All @@ -67,6 +68,7 @@ def __init__(self,
if self.nostr_agent.agent_card.nostr_relays is None:
self.nostr_agent.agent_card.nostr_relays = self.client.relays
self.commands = commands or DefaultCommands(db=self.db, nostr_client=self.client, agent_card=nostr_agent.agent_card)
self.recipient_pubkey = recipient_pubkey

async def _save_input(self, chat_input: ChatInput):
"""
Expand Down Expand Up @@ -345,6 +347,16 @@ async def _direct_message_callback(self, event: Event, message: str):
history = await self.db.get_messages(thread_id=thread_id, user_id=user_id)
logger.debug(f"Message history: {history}")

# Check for latest thread_id
if len(history) > 0:
latest_thread_id = history[-1].thread_id
latest_created_at = history[-1].created_at
new_thread_refresh_seconds = os.getenv("NEW_THREAD_REFRESH_SECONDS", 3600) # default 1 hour
if latest_created_at < datetime.now(timezone.utc) - timedelta(seconds=new_thread_refresh_seconds):
logger.info(f"New thread detected: {latest_thread_id} != {thread_id} or {latest_created_at} < {datetime.now(timezone.utc) - timedelta(seconds=new_thread_refresh_seconds)}")
thread_id = uuid.uuid4().hex
await self.db.set_current_thread_id(user_id=user_id, thread_id=thread_id)

# Create chat input
chat_input = ChatInput(
message=message,
Expand Down Expand Up @@ -385,5 +397,5 @@ async def start(self):
# Start direct message listener
tasks = []
logger.info(f"Starting message listener for {self.client.public_key.bech32()}")
tasks.append(self.client.direct_message_listener(callback=self._direct_message_callback))
tasks.append(self.client.direct_message_listener(callback=self._direct_message_callback, recipient_pubkey=self.recipient_pubkey))
await asyncio.gather(*tasks)
19 changes: 17 additions & 2 deletions src/agentstr/relays/relay.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import json
import time
import uuid
import random
from collections.abc import Callable
import traceback

Expand Down Expand Up @@ -202,9 +203,15 @@ async def direct_message_listener(self, filters: Filters, callback: Callable[[Ev
subscription = create_subscription(filters)
logger.debug(f"Sending DM subscription: {json.dumps(subscription)}")
latest_timestamp = filters.since or get_timestamp()
# Exponential backoff settings for reconnect attempts
initial_backoff = 0.5
max_backoff = 30.0
backoff = initial_backoff
while True:
try:
async with connect(self.relay) as ws:
# Reset backoff on successful (re)connection
backoff = initial_backoff
await ws.send(json.dumps(subscription))
while True:
response = await ws.recv()
Expand All @@ -225,11 +232,19 @@ async def direct_message_listener(self, filters: Filters, callback: Callable[[Ev
await callback(dm.event, dm.message)
except Exception as e:
logger.error(f"Error in direct_message_listener callback: {e}")
logger.error(traceback.format_exc())
await asyncio.sleep(0)
except asyncio.CancelledError:
# Allow cooperative cancellation
logger.debug("direct_message_listener task cancelled")
raise
except Exception as e:
logger.warning(f"Connection closed in direct_message_listener at {int(time.time())} trying again: {e}")
# Move the window forward to avoid re-processing
filters.since = latest_timestamp + 1
subscription = create_subscription(filters)
logger.debug(f"Sending DM subscription: {json.dumps(subscription)}")
await asyncio.sleep(0)
# Exponential backoff with jitter
jitter = random.uniform(0, backoff * 0.1)
sleep_for = min(max_backoff, backoff) + jitter
await asyncio.sleep(sleep_for)
backoff = min(max_backoff, backoff * 2)
Loading