Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
"""Add bank_id column to memory_links for direct filtering

The stats endpoint JOINs memory_links to memory_units just to filter by
bank_id. With millions of links this takes 18+ seconds. Adding bank_id
directly to memory_links lets Postgres push the filter down before the JOIN.

Revision ID: c5d6e7f8a9b0
Revises: b3c4d5e6f7a8
Create Date: 2026-03-26
"""

from collections.abc import Sequence

from alembic import context, op

# Alembic revision identifiers, consumed by the migration framework.
revision: str = "c5d6e7f8a9b0"
# The parent revision this migration applies on top of.
down_revision: str | Sequence[str] | None = "b3c4d5e6f7a8"
# No branch labels or cross-branch dependencies for this migration.
branch_labels: str | Sequence[str] | None = None
depends_on: str | Sequence[str] | None = None


def _get_schema_prefix() -> str:
    """Return a quoted '"<schema>".' prefix for SQL identifiers.

    Reads the optional ``target_schema`` Alembic config option; when it is
    unset (or empty) an empty string is returned so table names resolve in
    the default search path.
    """
    target = context.config.get_main_option("target_schema")
    if not target:
        return ""
    return f'"{target}".'


def upgrade() -> None:
    """Add, backfill, enforce, and index a bank_id column on memory_links.

    Raw SQL is used (rather than ``op.add_column`` etc.) so the optional
    target_schema prefix from the Alembic config can be applied directly.
    All steps are idempotent where Postgres supports it (IF NOT EXISTS).
    """
    schema = _get_schema_prefix()

    # 1. Add nullable column (nullable first so the backfill can run).
    op.execute(f"ALTER TABLE {schema}memory_links ADD COLUMN IF NOT EXISTS bank_id TEXT")

    # 2. Backfill from memory_units via from_unit_id.
    # NOTE(review): this runs as one UPDATE in a single transaction; on tens
    # of millions of links consider batching by id range if lock time or WAL
    # volume becomes a problem during the migration window.
    op.execute(f"""
        UPDATE {schema}memory_links ml
        SET bank_id = mu.bank_id
        FROM {schema}memory_units mu
        WHERE ml.from_unit_id = mu.id
          AND ml.bank_id IS NULL
    """)

    # 3. Enforce NOT NULL. This will (intentionally) fail the migration if
    # any link's from_unit_id has no matching memory_units row — such orphan
    # links must be cleaned up before this migration can complete.
    op.execute(f"ALTER TABLE {schema}memory_links ALTER COLUMN bank_id SET NOT NULL")

    # 4. Create the index the stats query depends on. The engine filters
    # with `WHERE ml.bank_id = $1` and groups by link_type; without this
    # index the filter is still a full-table scan and the whole point of
    # adding the column (avoiding the 18s query) is lost.
    op.execute(
        f"CREATE INDEX IF NOT EXISTS idx_memory_links_bank_link_type "
        f"ON {schema}memory_links (bank_id, link_type)"
    )



def downgrade() -> None:
    """Revert the migration by dropping memory_links.bank_id.

    Postgres drops any indexes built on the column together with the
    column itself, so no separate index cleanup is needed.
    """
    prefix = _get_schema_prefix()
    op.execute(f"ALTER TABLE {prefix}memory_links DROP COLUMN IF EXISTS bank_id")
8 changes: 5 additions & 3 deletions hindsight-api-slim/hindsight_api/engine/memory_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -5826,14 +5826,16 @@ async def get_bank_stats(
bank_id,
)

# Single query for all link stats — avoids triple join on memory_links (can be 21M+ rows).
# link_counts and link_counts_by_fact_type are derived in Python from the breakdown.
# Link stats — filter on ml.bank_id (indexed) instead of joining through mu.bank_id.
# With the idx_memory_links_bank_link_type index this turns a full-table hash join
# into an indexed scan + PK lookups. link_counts and link_counts_by_fact_type are
# derived in Python from the breakdown.
link_breakdown_stats = await conn.fetch(
f"""
SELECT mu.fact_type, ml.link_type, COUNT(*) as count
FROM {fq_table("memory_links")} ml
JOIN {fq_table("memory_units")} mu ON ml.from_unit_id = mu.id
WHERE mu.bank_id = $1
WHERE ml.bank_id = $1
GROUP BY mu.fact_type, ml.link_type
""",
bank_id,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -97,15 +97,16 @@ async def process_entities_batch(
return entity_links


async def insert_entity_links_batch(conn, entity_links: list[EntityLink]) -> None:
async def insert_entity_links_batch(conn, entity_links: list[EntityLink], bank_id: str) -> None:
"""
Insert entity links in batch.

Args:
conn: Database connection
entity_links: List of EntityLink objects
bank_id: Bank identifier (stored directly on memory_links for fast filtering)
"""
if not entity_links:
return

await link_utils.insert_entity_links_batch(conn, entity_links)
await link_utils.insert_entity_links_batch(conn, entity_links, bank_id)
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ async def create_semantic_links_batch(conn, bank_id: str, unit_ids: list[str], e
return await link_utils.create_semantic_links_batch(conn, bank_id, unit_ids, embeddings, log_buffer=[])


async def create_causal_links_batch(conn, unit_ids: list[str], facts: list[ProcessedFact]) -> int:
async def create_causal_links_batch(conn, bank_id: str, unit_ids: list[str], facts: list[ProcessedFact]) -> int:
"""
Create causal links between facts.

Expand Down Expand Up @@ -94,6 +94,6 @@ async def create_causal_links_batch(conn, unit_ids: list[str], facts: list[Proce
else:
causal_relations_per_fact.append([])

link_count = await link_utils.create_causal_links_batch(conn, unit_ids, causal_relations_per_fact)
link_count = await link_utils.create_causal_links_batch(conn, bank_id, unit_ids, causal_relations_per_fact)

return link_count
43 changes: 26 additions & 17 deletions hindsight-api-slim/hindsight_api/engine/retain/link_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -495,16 +495,18 @@ async def create_temporal_links_batch_per_fact(

if links:
insert_start = time_mod.time()
# Add bank_id to each tuple for direct filtering (avoids expensive JOIN in stats)
links_with_bank = [(*link, bank_id) for link in links]
# Batch inserts to avoid timeout on large batches
BATCH_SIZE = 1000
for batch_start in range(0, len(links), BATCH_SIZE):
for batch_start in range(0, len(links_with_bank), BATCH_SIZE):
await conn.executemany(
f"""
INSERT INTO {fq_table("memory_links")} (from_unit_id, to_unit_id, link_type, weight, entity_id)
VALUES ($1, $2, $3, $4, $5)
INSERT INTO {fq_table("memory_links")} (from_unit_id, to_unit_id, link_type, weight, entity_id, bank_id)
VALUES ($1, $2, $3, $4, $5, $6)
ON CONFLICT (from_unit_id, to_unit_id, link_type, COALESCE(entity_id, '00000000-0000-0000-0000-000000000000'::uuid)) DO NOTHING
""",
links[batch_start : batch_start + BATCH_SIZE],
links_with_bank[batch_start : batch_start + BATCH_SIZE],
)
_log(log_buffer, f" [7.4] Insert {len(links)} temporal links: {time_mod.time() - insert_start:.3f}s")

Expand Down Expand Up @@ -627,16 +629,18 @@ async def create_semantic_links_batch(

if all_links:
insert_start = time_mod.time()
# Add bank_id to each tuple for direct filtering (avoids expensive JOIN in stats)
all_links_with_bank = [(*link, bank_id) for link in all_links]
# Batch inserts to avoid timeout on large batches
BATCH_SIZE = 1000
for batch_start in range(0, len(all_links), BATCH_SIZE):
for batch_start in range(0, len(all_links_with_bank), BATCH_SIZE):
await conn.executemany(
f"""
INSERT INTO {fq_table("memory_links")} (from_unit_id, to_unit_id, link_type, weight, entity_id)
VALUES ($1, $2, $3, $4, $5)
INSERT INTO {fq_table("memory_links")} (from_unit_id, to_unit_id, link_type, weight, entity_id, bank_id)
VALUES ($1, $2, $3, $4, $5, $6)
ON CONFLICT (from_unit_id, to_unit_id, link_type, COALESCE(entity_id, '00000000-0000-0000-0000-000000000000'::uuid)) DO NOTHING
""",
all_links[batch_start : batch_start + BATCH_SIZE],
all_links_with_bank[batch_start : batch_start + BATCH_SIZE],
)
_log(
log_buffer, f" [8.3] Insert {len(all_links)} semantic links: {time_mod.time() - insert_start:.3f}s"
Expand All @@ -652,7 +656,7 @@ async def create_semantic_links_batch(
raise


async def insert_entity_links_batch(conn, links: list[EntityLink], chunk_size: int = 5000):
async def insert_entity_links_batch(conn, links: list[EntityLink], bank_id: str, chunk_size: int = 5000):
"""
Insert all entity links using COPY to temp table + chunked INSERT for reliability.

Expand All @@ -663,6 +667,7 @@ async def insert_entity_links_batch(conn, links: list[EntityLink], chunk_size: i
Args:
conn: Database connection
links: List of EntityLink objects
bank_id: Bank identifier (stored directly on memory_links for fast filtering)
chunk_size: Number of rows per INSERT chunk (default 5000)
"""
if not links:
Expand All @@ -681,7 +686,8 @@ async def insert_entity_links_batch(conn, links: list[EntityLink], chunk_size: i
to_unit_id uuid,
link_type text,
weight float,
entity_id uuid
entity_id uuid,
bank_id text
) ON COMMIT DROP
""")
logger.debug(f" [9.1] Create temp table: {time_mod.time() - create_start:.3f}s")
Expand All @@ -693,15 +699,17 @@ async def insert_entity_links_batch(conn, links: list[EntityLink], chunk_size: i

# Convert EntityLink objects to tuples for COPY
convert_start = time_mod.time()
records = [(link.from_unit_id, link.to_unit_id, link.link_type, link.weight, link.entity_id) for link in links]
records = [
(link.from_unit_id, link.to_unit_id, link.link_type, link.weight, link.entity_id, bank_id) for link in links
]
logger.debug(f" [9.3] Convert {len(records)} records: {time_mod.time() - convert_start:.3f}s")

# Bulk load using COPY (fastest method)
copy_start = time_mod.time()
await conn.copy_records_to_table(
"_temp_entity_links",
records=records,
columns=["from_unit_id", "to_unit_id", "link_type", "weight", "entity_id"],
columns=["from_unit_id", "to_unit_id", "link_type", "weight", "entity_id", "bank_id"],
)
logger.debug(f" [9.4] COPY {len(records)} records to temp table: {time_mod.time() - copy_start:.3f}s")

Expand All @@ -713,8 +721,8 @@ async def insert_entity_links_batch(conn, links: list[EntityLink], chunk_size: i
chunk_end = chunk_start + chunk_size
await conn.execute(
f"""
INSERT INTO {fq_table("memory_links")} (from_unit_id, to_unit_id, link_type, weight, entity_id)
SELECT from_unit_id, to_unit_id, link_type, weight, entity_id
INSERT INTO {fq_table("memory_links")} (from_unit_id, to_unit_id, link_type, weight, entity_id, bank_id)
SELECT from_unit_id, to_unit_id, link_type, weight, entity_id, bank_id
FROM _temp_entity_links
WHERE _row_num > $1 AND _row_num <= $2
ON CONFLICT (from_unit_id, to_unit_id, link_type, COALESCE(entity_id, '00000000-0000-0000-0000-000000000000'::uuid)) DO NOTHING
Expand All @@ -729,6 +737,7 @@ async def insert_entity_links_batch(conn, links: list[EntityLink], chunk_size: i

async def create_causal_links_batch(
conn,
bank_id: str,
unit_ids: list[str],
causal_relations_per_fact: list[list[dict]],
) -> int:
Expand Down Expand Up @@ -795,15 +804,15 @@ async def create_causal_links_batch(
# Add the causal link
# link_type is the relation_type (e.g., "causes", "caused_by")
# weight is the strength of the relationship
links.append((from_unit_id, to_unit_id, relation_type, strength, None))
links.append((from_unit_id, to_unit_id, relation_type, strength, None, bank_id))

if links:
insert_start = time_mod.time()
try:
await conn.executemany(
f"""
INSERT INTO {fq_table("memory_links")} (from_unit_id, to_unit_id, link_type, weight, entity_id)
VALUES ($1, $2, $3, $4, $5)
INSERT INTO {fq_table("memory_links")} (from_unit_id, to_unit_id, link_type, weight, entity_id, bank_id)
VALUES ($1, $2, $3, $4, $5, $6)
ON CONFLICT (from_unit_id, to_unit_id, link_type, COALESCE(entity_id, '00000000-0000-0000-0000-000000000000'::uuid)) DO NOTHING
""",
links,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -158,14 +158,14 @@ async def _insert_facts_and_links(
# Insert entity links
step_start = time.time()
if entity_links:
await entity_processing.insert_entity_links_batch(conn, entity_links)
await entity_processing.insert_entity_links_batch(conn, entity_links, bank_id)
log_buffer.append(
f" Entity links: {len(entity_links) if entity_links else 0} links in {time.time() - step_start:.3f}s"
)

# Create causal links
step_start = time.time()
causal_link_count = await link_creation.create_causal_links_batch(conn, unit_ids, processed_facts)
causal_link_count = await link_creation.create_causal_links_batch(conn, bank_id, unit_ids, processed_facts)
log_buffer.append(f" Causal links: {causal_link_count} links in {time.time() - step_start:.3f}s")

# Map results back to original content items
Expand Down
Loading