Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 10 additions & 7 deletions docs/usage/interface-guide.md
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ writer = EmbeddingWriterInterface(
backend=backend,
metric_type=MetricType.COSINE,
embedding_client=embedding_client,
omop_cdm_engine=cdm_engine, # optional; required for embed_and_upsert_concepts
omop_cdm_engine=cdm_engine, # optional; used to enrich search results
)
```

Expand All @@ -71,13 +71,16 @@ is safe and returns the existing record.
### Generate and store embeddings

```python
# Generate embeddings from CDM concepts and upsert in one step.
# omop_cdm_engine is used to fetch domain_id, vocabulary_id, standard_concept,
# and invalid_reason from the CDM and store them as filter metadata.
writer.embed_and_upsert_concepts(
# Fetch candidate concepts from the CDM, then pass the returned rows back as
# concept_meta so filter columns can be stored alongside the embeddings.
missing = writer.get_concepts_without_embedding(
omop_cdm_engine=cdm_engine,
concept_ids=(1, 2, 3),
concept_texts=("Hypertension", "Diabetes mellitus", "Aspirin"),
)

writer.embed_and_upsert_concepts(
concept_ids=tuple(missing.keys()),
concept_texts=tuple(row.concept_name for row in missing.values()),
concept_meta=missing,
)
```

Expand Down
3 changes: 1 addition & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,7 @@ classifiers = [

dependencies = [
"numpy>=1.26",
"omop-alchemy>=0.5.7",
"orm-loader>=0.3.15",
"omop-alchemy>=0.6.3",
Comment thread
nicoloesch marked this conversation as resolved.
"sqlalchemy>=2.0.45",
"typing-extensions>=4.15.0",
"sqlite-vec>=0.1.9",
Expand Down
80 changes: 39 additions & 41 deletions src/omop_emb/cli/cli_embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,8 @@
from dotenv import load_dotenv
from tqdm import tqdm

from orm_loader.helpers import create_db

from .utils import configure_logging_level, resolve_omop_cdm_engine
from omop_emb.utils.cdm import check_concept_cdm
from omop_emb.backends.index_config import index_config_from_index_type
from omop_emb.backends import resolve_backend
from omop_emb.config import IndexType, MetricType
Expand Down Expand Up @@ -104,11 +103,6 @@ def add_embeddings(
help="Limit the number of concepts to embed. Useful for testing.",
rich_help_panel="Concept Filters",
)] = None,
cdm_batch_size: Annotated[int, typer.Option(
"--cdm-batch-size",
help="Batch size for fetching concept metadata from the CDM during ingestion. Adjust if you encounter performance issues or database limits during ingestion.",
rich_help_panel="CDM Fetch Options",
)] = 50_000,
verbosity: Annotated[int, typer.Option(
"--verbose", "-v", count=True,
help="Increase verbosity (up to two levels)",
Expand All @@ -124,7 +118,6 @@ def add_embeddings(

backend = resolve_backend()
omop_cdm_engine = resolve_omop_cdm_engine()
create_db(omop_cdm_engine)

embedding_client = EmbeddingClient(
model=model,
Expand All @@ -139,38 +132,49 @@ def add_embeddings(
metric_type=MetricType.COSINE,
embedding_client=embedding_client,
)
embedding_writer.register_model()

concept_filter = EmbeddingConceptFilter(
require_standard=standard_only,
domains=tuple(domains) if domains else None,
vocabularies=tuple(vocabularies) if vocabularies else None,
)
check_concept_cdm(omop_cdm_engine)

missing = embedding_writer.get_concepts_without_embedding(
omop_cdm_engine=omop_cdm_engine,
concept_filter=concept_filter,
)
if num_embeddings is not None:
missing = dict(list(missing.items())[:num_embeddings])
try:
embedding_writer.register_model()

total_concepts = len(missing)
typer.echo(f"Total concepts to process: {total_concepts:,}")
# Filter concepts
concept_filter = EmbeddingConceptFilter(
require_standard=standard_only,
domains=tuple(domains) if domains else None,
vocabularies=tuple(vocabularies) if vocabularies else None,
)
n_missing = embedding_writer.count_concepts_without_embedding(
omop_cdm_engine=omop_cdm_engine,
concept_filter=concept_filter,
)
n_total = min(n_missing, num_embeddings) if num_embeddings is not None else n_missing
typer.echo(f"Total concepts to process: {n_total:,}")

from itertools import batched as _batched
with tqdm(total=total_concepts, desc="Processing", unit="concept") as pbar:
for batch in _batched(missing.items(), batch_size):
batch_dict = dict(batch)
embedding_writer.embed_and_upsert_concepts(
n_processed = 0
with tqdm(total=n_total, desc="Processing", unit="concept") as pbar:
for batch_dict in embedding_writer.get_concepts_without_embedding_batched(
omop_cdm_engine=omop_cdm_engine,
concept_ids=tuple(batch_dict.keys()),
concept_texts=tuple(batch_dict.values()),
concept_filter=concept_filter,
batch_size=batch_size,
cdm_batch_size=cdm_batch_size,
)
pbar.update(len(batch_dict))

logger.info("Completed embedding generation and storage.")
limit=num_embeddings,
):
embedding_writer.embed_and_upsert_concepts(
concept_ids=tuple(batch_dict.keys()),
concept_texts=tuple(row.concept_name for row in batch_dict.values()),
concept_meta=batch_dict,
batch_size=batch_size,
)
n_processed += len(batch_dict)
pbar.update(len(batch_dict))

typer.echo(f"Processed {n_processed:,} concepts.")
logger.info("Completed embedding generation and storage.")
except Exception as e:
logger.exception(f"Error during embedding generation and storage.\n{e}")
if not embedding_writer.has_any_embeddings():
logger.info("No embeddings were stored. Cleaning up model registration.")
embedding_writer.delete_model()
raise typer.Exit(code=1)


@app.command()
Expand Down Expand Up @@ -305,11 +309,6 @@ def add_embeddings_with_index(
help="Limit the number of concepts to embed. Useful for testing.",
rich_help_panel="Concept Filters",
)] = None,
cdm_batch_size: Annotated[int, typer.Option(
"--cdm-batch-size",
help="Batch size for fetching concept metadata from the CDM during ingestion. Adjust if you encounter performance issues or database limits during ingestion.",
rich_help_panel="CDM Fetch Options",
)] = 50_000,
index_hnsw_num_neighbors: Annotated[Optional[int], typer.Option(
"--index-hnsw-num-neighbors",
help="HNSW: number of neighbors per graph node.",
Expand Down Expand Up @@ -344,7 +343,6 @@ def add_embeddings_with_index(
vocabularies=vocabularies,
domains=domains,
num_embeddings=num_embeddings,
cdm_batch_size=cdm_batch_size,
verbosity=verbosity,
)

Expand Down
10 changes: 4 additions & 6 deletions src/omop_emb/cli/cli_legacy.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,11 @@
from tqdm import tqdm

import sqlalchemy as sa
from sqlalchemy.orm import sessionmaker

from .utils import configure_logging_level
from omop_emb.backends import resolve_backend
from omop_emb.backends.index_config import FlatIndexConfig
from omop_emb.config import MetricType, ProviderType
from omop_emb.interface import _fetch_cdm_concepts_for_ingestion
from omop_emb.utils.cdm import fetch_cdm_concepts_for_ingestion
from omop_emb.utils.embedding_utils import ConceptEmbeddingRecord

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -132,7 +130,6 @@ def add_embeddings_from_h5(
typer.echo(f"Registered model '{model}' ({dimensions}d, metric={metric_type.value}).")

cdm_engine = sa.create_engine(omop_cdm_db_url, future=True, echo=False)
cdm_factory = sessionmaker(cdm_engine)

n_batches = (total + batch_size - 1) // batch_size
typer.echo(f"Ingesting {total:,} embeddings in {n_batches} batch(es) of {batch_size:,}...")
Expand All @@ -142,8 +139,9 @@ def add_embeddings_from_h5(
end = min(start + batch_size, total)
batch_cids: np.ndarray = np.asarray(cid_ds[start:end])
batch_emb = np.asarray(emb_ds[start:end], dtype=np.float32)
meta = _fetch_cdm_concepts_for_ingestion(
{int(cid) for cid in batch_cids}, cdm_factory,
meta = fetch_cdm_concepts_for_ingestion(
{int(cid) for cid in batch_cids},
cdm_engine,
batch_size=cdm_batch_size,
)
records = []
Expand Down
Loading
Loading