Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions ak-py/src/agentkernel/core/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,3 +28,4 @@
from .service import AgentService
from .hooks import PreHook, PostHook
from .util.key_value_cache import KeyValueCache
from .multimodal import MultimodalContextHook, MultimodalMemoryHook, MultimodalModuleMixin
20 changes: 20 additions & 0 deletions ak-py/src/agentkernel/core/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,23 @@ class _GmailConfig(BaseModel):
label_filter: str = Field(default="INBOX", description="Gmail label to monitor (e.g., INBOX, UNREAD)")


class _MultimodalConfig(BaseModel):
    """
    Configuration for multimodal memory.

    Controls whether image/file attachments are persisted in the session cache and
    how many of them, and for how long, are kept for follow-up requests.
    """

    enabled: bool = Field(
        default=True,
        description="Enable multimodal memory. When enabled, images and files are persisted "
        "in session cache and automatically injected into follow-up requests.",
    )
    # Negative counts have no meaning, so they are rejected when the configuration is loaded.
    max_attachments: int = Field(
        default=5,
        ge=0,
        description="Maximum number of recent attachments to keep in context for follow-up questions",
    )
    # A negative lifetime has no meaning, so it is rejected when the configuration is loaded.
    attachment_ttl: int = Field(
        default=604800,
        ge=0,
        description="Time-to-live for attachments in seconds (default: 604800 = 1 week, matches session TTL)",
    )


class _TraceConfig(BaseModel):
enabled: bool = Field(default=False, description="Enable tracing")
type: str = Field(default="langfuse", pattern="^(langfuse|openllmetry)$")
Expand Down Expand Up @@ -161,6 +178,9 @@ class AKConfig(YamlBaseSettingsModified):
description="Telegram Bot related configurations", default_factory=_TelegramConfig
)
gmail: _GmailConfig = Field(description="Gmail related configurations", default_factory=_GmailConfig)
multimodal: _MultimodalConfig = Field(
description="Multimodal memory related configurations", default_factory=_MultimodalConfig
)

trace: _TraceConfig = Field(description="Tracing related configurations", default_factory=_TraceConfig)
test: _TestConfig = Field(description="Test related configurations", default_factory=_TestConfig)
Expand Down
296 changes: 296 additions & 0 deletions ak-py/src/agentkernel/core/multimodal.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,296 @@
"""
Shared multimodal hooks for all Agent Kernel frameworks.

This module provides framework-agnostic hooks for multimodal conversation memory,
so that images and files sent earlier in a session remain available to follow-up requests.

Architecture:
- MultimodalContextHook (Pre-hook): Loads previous attachments from cache and
injects them into the current request so the model has context.
- MultimodalMemoryHook (Post-hook): Saves new attachments to non-volatile cache
(Redis/DynamoDB) for persistence across requests.
- MultimodalModuleMixin: Mixin class for framework modules to auto-register hooks.

Example flow:
User: [sends image] "What's in this photo?"
→ Post-hook saves image to cache
Bot: "I see a red car."

User: "What color is the house in the background?"
→ Pre-hook loads previous image and injects it
→ Model receives: [previous image] + "What color is the house..."
Bot: "The house appears to be white."

Configuration via environment variables (settings live under AKConfig.multimodal):
AK_MULTIMODAL__ENABLED=true # Enable/disable (default: true)
AK_MULTIMODAL__MAX_ATTACHMENTS=5 # Max attachments in context
AK_MULTIMODAL__ATTACHMENT_TTL=604800 # Expiry in seconds (1 week)
"""

import logging
import time
from typing import TYPE_CHECKING

from .hooks import PostHook, PreHook
from .model import AgentRequestFile, AgentRequestImage

if TYPE_CHECKING:
from .base import Agent
from .model import AgentReply, AgentRequest
from .session import Session

# Default cap on the number of recent attachments kept in context.
# Used as the default `max_attachments` argument of both hooks (configurable).
DEFAULT_MAX_ATTACHMENTS = 5
# Default attachment lifetime in seconds: 604800 s = 7 days (1 week) - matches session TTL.
# Used as the default `ttl_seconds` argument of MultimodalContextHook.
DEFAULT_ATTACHMENT_TTL = 604800


class MultimodalContextHook(PreHook):
    """
    Pre-hook that loads previous images/files from non-volatile cache and injects
    them ahead of the current requests.

    Only cache entries whose key starts with ``attachment_image_`` or
    ``attachment_file_`` are considered. Entries older than ``ttl_seconds`` are
    skipped, and at most ``max_attachments`` of the most recent remaining entries
    are injected.
    """

    # Key prefixes of attachment records in the non-volatile cache.
    _KEY_PREFIXES = ("attachment_image_", "attachment_file_")
    # Cache key of the fallback index listing attachment keys.
    _INDEX_KEY = "_attachment_index"

    def __init__(self, max_attachments: int = DEFAULT_MAX_ATTACHMENTS, ttl_seconds: int = DEFAULT_ATTACHMENT_TTL):
        """
        Initialize the context hook.

        :param max_attachments: Maximum number of recent attachments to inject (default: 5).
            A value of 0 or less disables injection.
        :param ttl_seconds: Time-to-live for attachments in seconds (default: 604800 = 1 week)
        """
        self._log = logging.getLogger("ak.hooks.multimodal_context")
        self._max_attachments = max_attachments
        self._ttl_seconds = ttl_seconds

    async def on_run(self, session: "Session", agent: "Agent", requests: list["AgentRequest"]) -> list["AgentRequest"]:
        """
        Load previous attachments from cache and inject into the request.

        :param session: The current session
        :param agent: The agent instance
        :param requests: List of current agent requests
        :return: The original ``requests`` object when nothing is injected, otherwise a new
            list with the previous attachments (oldest first) followed by the current requests
        """
        if not session:
            return requests

        # A non-positive limit means "inject nothing". Without this guard the slice below
        # would misbehave: `attachments[-0:]` is the whole list, not an empty one.
        if self._max_attachments <= 0:
            return requests

        nv_cache = session.get_non_volatile_cache()
        attachments = self._load_attachments(nv_cache, time.time())
        if not attachments:
            return requests

        # Sort by timestamp and keep the most recent ones.
        # Use secondary sort by key for consistent ordering when timestamps match.
        attachments.sort(key=lambda att: (att["timestamp"], att["key"]))
        recent = attachments[-self._max_attachments :]

        injected_requests = []
        for att in recent:
            request = self._build_request(att["data"])
            if request is not None:
                injected_requests.append(request)

        if not injected_requests:
            return requests

        self._log.info(f"Injected {len(injected_requests)} previous attachment(s) into context")
        # Prepend previous attachments, then add current requests
        return injected_requests + list(requests)

    def _load_attachments(self, nv_cache, current_time: float) -> list[dict]:
        """
        Collect all non-expired attachment records from the cache.

        :param nv_cache: The session's non-volatile cache
        :param current_time: Reference time (seconds since epoch) for the TTL check
        :return: List of dicts with ``key``, ``data`` and ``timestamp`` entries
        """
        keys_to_check = nv_cache.keys() if hasattr(nv_cache, "keys") else []

        if not keys_to_check:
            # Try to retrieve using the index if available
            index_data = nv_cache.get(self._INDEX_KEY)
            if index_data:
                keys_to_check = index_data.get("keys", [])

        attachments = []
        for key in keys_to_check:
            if not key.startswith(self._KEY_PREFIXES):
                continue

            data = nv_cache.get(key)
            if not data:
                continue

            # Records without a timestamp are treated as written at epoch 0.
            timestamp = data.get("timestamp", 0)
            if current_time - timestamp > self._ttl_seconds:
                # Expired, skip (could also delete here)
                self._log.debug(f"Skipping expired attachment: {key}")
                continue

            attachments.append({"key": key, "data": data, "timestamp": timestamp})

        return attachments

    @staticmethod
    def _build_request(data: dict) -> "AgentRequest | None":
        """
        Rebuild a request object from a cached attachment record.

        :param data: Cached record with ``type``, ``data``, ``mime_type`` and ``name`` entries
        :return: An image/file request marked as injected, or None for an unknown ``type``
        """
        att_type = data.get("type")

        # `or` instead of a `.get` default: the stored record can contain "name": None,
        # in which case a `.get` default would never apply.
        if att_type == "image":
            request = AgentRequestImage(
                image_data=data.get("data"),
                mime_type=data.get("mime_type"),
                name=data.get("name") or "previous_image",
            )
        elif att_type == "file":
            request = AgentRequestFile(
                file_data=data.get("data"),
                mime_type=data.get("mime_type"),
                name=data.get("name") or "previous_file",
            )
        else:
            return None

        # Mark as injected so post-hook doesn't re-save it
        request._injected = True
        return request

    def name(self) -> str:
        """Return hook name"""
        return "MultimodalContextHook"


class MultimodalMemoryHook(PostHook):
    """
    Post-hook that stores images and files in non-volatile cache after agent processes them.

    This hook ensures multimodal attachments are persisted in the session cache
    (Redis or DynamoDB) so they're available in subsequent requests via MultimodalContextHook.

    Each attachment is written under ``attachment_<type>_<timestamp>_<request index>`` and
    its key is appended to the ``_attachment_index`` record, which is trimmed to the most
    recent ``2 * max_attachments`` keys.

    Works with all frameworks: OpenAI, ADK, LangGraph, CrewAI
    """

    # Cache key of the index listing attachment keys.
    _INDEX_KEY = "_attachment_index"

    def __init__(self, max_attachments: int = DEFAULT_MAX_ATTACHMENTS):
        """
        Initialize the memory hook.

        :param max_attachments: Maximum number of attachments to keep in cache (default: 5).
            A value of 0 or less disables storing.
        """
        self._log = logging.getLogger("ak.hooks.multimodal_memory")
        self._max_attachments = max_attachments

    async def on_run(
        self, session: "Session", requests: list["AgentRequest"], agent: "Agent", agent_reply: "AgentReply"
    ) -> "AgentReply":
        """
        Save multimodal attachments to non-volatile cache after agent processes them.

        :param session: The current session
        :param requests: List of agent requests (may contain AgentRequestImage/AgentRequestFile)
        :param agent: The agent instance
        :param agent_reply: The agent's reply
        :return: The agent_reply unchanged
        """
        # A non-positive limit means "keep nothing". Without this guard the index trim
        # below would use `keys[-0:]`, which keeps every key instead of none.
        if not session or self._max_attachments <= 0:
            return agent_reply

        nv_cache = session.get_non_volatile_cache()
        timestamp = time.time()
        new_keys = []

        # Store each image/file with timestamp; the request index keeps keys unique
        # within one run, where all records share the same timestamp.
        for i, req in enumerate(requests):
            record = self._to_record(req, timestamp)
            if record is None:
                continue

            cache_key = f"attachment_{record['type']}_{timestamp}_{i}"
            nv_cache.set(cache_key, record)
            new_keys.append(cache_key)
            self._log.debug(f"Stored {record['type']} to cache: {cache_key}")

        if new_keys:
            # Update attachment index for cache implementations without keys() method.
            # Tolerate a missing or malformed index record instead of raising KeyError.
            index_data = nv_cache.get(self._INDEX_KEY) or {}
            known_keys = list(index_data.get("keys", []))
            known_keys.extend(new_keys)
            # Keep only recent keys.
            # NOTE(review): keys trimmed from the index are not removed from the cache here.
            index_data["keys"] = known_keys[-self._max_attachments * 2 :]
            nv_cache.set(self._INDEX_KEY, index_data)

            self._log.info(f"Stored {len(new_keys)} multimodal attachment(s) to non-volatile cache")

        return agent_reply

    @staticmethod
    def _to_record(req: "AgentRequest", timestamp: float) -> "dict | None":
        """
        Convert a request into the dict stored in the cache.

        :param req: A request from the current run
        :param timestamp: Storage time (seconds since epoch) recorded in the entry
        :return: The cache record, or None when the request is not an image/file or was
            injected by the pre-hook (and is therefore already stored)
        """
        if isinstance(req, AgentRequestImage):
            att_type, payload = "image", req.image_data
        elif isinstance(req, AgentRequestFile):
            att_type, payload = "file", req.file_data
        else:
            return None

        # Skip if this is an injected previous attachment
        if getattr(req, "_injected", False):
            return None

        return {
            "type": att_type,
            "data": payload,
            "mime_type": req.mime_type,
            "name": getattr(req, "name", None),
            "timestamp": timestamp,
        }

    def name(self) -> str:
        """Return hook name"""
        return "MultimodalMemoryHook"


class MultimodalModuleMixin:
    """
    Mixin class that provides multimodal memory hook registration for framework modules.

    The host class is expected to provide ``get_agent(name)``, which is used to look up
    the wrapped agent the hooks are attached to.

    Usage:
        class OpenAIModule(MultimodalModuleMixin, Module):
            def __init__(self, agents, runner=None):
                super().__init__()
                # ... setup ...
                self._register_multimodal_hooks(agents)

    Configuration is read from AKConfig.multimodal settings:
    - enabled: bool (default: True)
    - max_attachments: int (default: 5)
    - attachment_ttl: int (default: 604800 = 1 week)
    """

    def _register_multimodal_hooks(self, agents: list) -> None:
        """
        Register multimodal memory hooks for all agents if enabled in config.

        One MultimodalContextHook and one MultimodalMemoryHook are created and shared by
        every agent. Agents without a ``name``/``role`` attribute, or for which
        ``get_agent`` returns nothing, are skipped.

        :param agents: List of agents (framework-specific agent objects)
        """
        # Imported inside the method rather than at module level
        # (presumably to avoid a circular import - TODO confirm).
        from .config import AKConfig

        config = AKConfig.get().multimodal
        if not config.enabled:
            return

        # Create hooks with config values
        context_hook = MultimodalContextHook(max_attachments=config.max_attachments, ttl_seconds=config.attachment_ttl)
        memory_hook = MultimodalMemoryHook(max_attachments=config.max_attachments)

        # Register hooks for each agent by attaching to the wrapped agent
        for agent in agents:
            # Get agent name from raw agent (different frameworks use different attributes)
            agent_name = getattr(agent, "name", None) or getattr(agent, "role", None)
            if not agent_name:
                continue

            # Get wrapped agent from module
            wrapped_agent = self.get_agent(agent_name)
            if wrapped_agent:
                # Attach hooks to wrapped agent (which has the methods)
                wrapped_agent.attach_pre_hooks([context_hook])
                wrapped_agent.attach_post_hooks([memory_hook])
9 changes: 8 additions & 1 deletion ak-py/src/agentkernel/framework/adk/adk.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from ...core import Runner as BaseRunner
from ...core import Session
from ...core.config import AKConfig
from ...core.multimodal import MultimodalModuleMixin
from ...trace import Trace

FRAMEWORK = "adk"
Expand Down Expand Up @@ -204,9 +205,12 @@ def get_a2a_card(self):
pass


class GoogleADKModule(Module):
class GoogleADKModule(MultimodalModuleMixin, Module):
"""
GoogleADKModule class provides a module for Google ADK-based agents.

When multimodal_memory is enabled (default), this module automatically registers
hooks to provide ChatGPT-like conversation memory for images and files.
"""

def __init__(self, agents: list[BaseAgent], runner: GoogleADKRunner = None):
Expand All @@ -224,6 +228,9 @@ def __init__(self, agents: list[BaseAgent], runner: GoogleADKRunner = None):
self.runner = GoogleADKRunner()
self.load(agents)

# Auto-register multimodal memory hooks (from MultimodalModuleMixin)
self._register_multimodal_hooks(agents)

def _wrap(self, agent: BaseAgent, agents: List[BaseAgent]) -> AKBaseAgent:
"""
Wraps the provided agent in a GoogleADKAgent instance.
Expand Down
Loading
Loading