6 changes: 4 additions & 2 deletions topicer/base.py
@@ -60,23 +60,25 @@ async def discover_topics_dense(self, texts: Sequence[TextChunk], n: int | None
...

@abstractmethod
async def discover_topics_in_db_sparse(self, db_request: DBRequest, n: int | None = None) -> DiscoveredTopicsSparse:
async def discover_topics_in_db_sparse(self, db_request: DBRequest, n: int | None = None, db_embeddings: bool | None = None) -> DiscoveredTopicsSparse:
"""
Discover topics based on a database request and return a sparse representation.

:param db_request: Database request to fetch texts for topic discovery.
:param n: Optional number of topics to propose; if None, the default value is used.
:param db_embeddings: Whether to fetch pre-computed text embeddings from the database; if None, the implementation's default is used.
:return: DiscoveredTopicsSparse
"""
...

@abstractmethod
async def discover_topics_in_db_dense(self, db_request: DBRequest, n: int | None = None) -> DiscoveredTopics:
async def discover_topics_in_db_dense(self, db_request: DBRequest, n: int | None = None, db_embeddings: bool | None = None) -> DiscoveredTopics:
"""
Discover topics based on a database request and return a dense representation.

:param db_request: Database request to fetch texts for topic discovery.
:param n: Optional number of topics to propose; if None, the default value is used.
:param db_embeddings: Whether to fetch pre-computed text embeddings from the database; if None, the implementation's default is used.
:return: DiscoveredTopics
"""
...
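
The new flag is deliberately tri-state; a schematic of the intended semantics (inferred from the docstrings and the FastTopicDiscovery implementation below):

# db_embeddings=True   -> fetch pre-computed embeddings from the database
# db_embeddings=False  -> embed the texts with the configured embedding service
# db_embeddings=None   -> fall back to the implementation's default
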
76 changes: 62 additions & 14 deletions topicer/topic_discovery/fast_topic.py
@@ -1,7 +1,7 @@
import asyncio
import copy
import random
from typing import Sequence
from typing import Sequence, Any

import numpy as np
from classconfig import ConfigurableMixin, CreatableMixin, ConfigurableValue, ConfigurableSubclassFactory
@@ -11,12 +11,14 @@
from ruamel.yaml.scalarstring import LiteralScalarString
from topmost.preprocess.preprocess import Tokenizer, Preprocess

from topicer.base import BaseTopicer, MissingServiceError, BaseEmbeddingService
from topicer.schemas import DBRequest, DiscoveredTopics, TextChunk, DiscoveredTopicsSparse, Topic, Tag, \
TextChunkWithTagSpanProposals
from topicer.topic_discovery.time_text_detector import TimeTextDetector, CzechTimeTextDetector
from topicer.utils.template import TemplateTransformer, Template
from topicer.utils.tokenizers import CzechLemmatizedTokenizer
from sklearn.preprocessing import normalize


class EmbeddingServiceWrapper:
@@ -44,6 +46,36 @@ def encode(self, docs: list[str], normalize_embeddings: bool | None = None, show
return embeddings


class PrecomputedEmbeddings:
"""
A wrapper to pass pre-computed embeddings to FASTopic.
"""

def __init__(self, embeddings: NDArray):
"""
:param embeddings: Pre-computed embeddings matrix.
"""
self.embeddings = embeddings

def encode(self, docs: list[str], normalize_embeddings: bool | None = None,
show_progress_bar: bool = False) -> NDArray:
"""
Returns the stored pre-computed embeddings for the given documents.

:param docs: list of documents; used only to check that their count matches the stored embeddings
:param normalize_embeddings: whether to L2-normalize the embeddings; if None or False, they are returned as stored
:param show_progress_bar: this parameter is ignored, present for compatibility
:return: embeddings matrix
"""
assert len(docs) == len(self.embeddings), "The number of documents and embeddings must match."

if normalize_embeddings:
return normalize(self.embeddings, norm='l2', axis=1)

return self.embeddings
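
A minimal usage sketch (illustrative only, not part of the change; values and dimensions are made up). FASTopic only needs an object exposing this encode() interface, so the wrapper can stand in for a real encoder:

import numpy as np

docs = ["first document", "second document"]
embeddings = np.random.rand(2, 384).astype(np.float32)  # e.g. vectors loaded from a DB
wrapper = PrecomputedEmbeddings(embeddings=embeddings)
vecs = wrapper.encode(docs, normalize_embeddings=True)  # rows L2-normalized
assert vecs.shape == (2, 384)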


class GenerateTopicNameResponse(BaseModel):
"""
Schema for the response from the topic name generation LLM.
Expand Down Expand Up @@ -225,6 +257,11 @@ class FastTopicDiscovery(BaseTopicer, ConfigurableMixin, CreatableMixin):
user_default=42,
voluntary=True
)
use_db_embeddings: bool = ConfigurableValue(
desc="Whether to use database embeddings if available when fetching texts from the database.",
user_default=False,
voluntary=True
)
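
With this option, DB embeddings can be switched on globally in the configuration rather than per call; a hypothetical YAML fragment (key names assumed to mirror the attribute names):

use_db_embeddings: true  # per-call db_embeddings=... still takes precedence
random_seed: 42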

def __post_init__(self):
if self.random_seed is not None:
@@ -264,38 +301,50 @@ async def discover_topics_dense(self, texts: Sequence[TextChunk], n: int | None
sparse=False
)

async def discover_topics_in_db_sparse(self, db_request: DBRequest, n: int | None = None) -> DiscoveredTopicsSparse:
if self.db_connection is None:
raise MissingServiceError("DB connection has to be set for DB topic discovery.")
async def discover_topics_in_db_sparse(self, db_request: DBRequest, n: int | None = None, db_embeddings: bool | None = None) -> DiscoveredTopicsSparse:
return await self.discover_topics_in_db(db_request=db_request, n=n, sparse=True, db_embeddings=db_embeddings)

texts = self.db_connection.get_text_chunks(db_request)
return await self.discover_topics_sparse(texts=texts, n=n)
async def discover_topics_in_db_dense(self, db_request: DBRequest, n: int | None = None, db_embeddings: bool | None = None) -> DiscoveredTopics:
return await self.discover_topics_in_db(db_request=db_request, n=n, sparse=False, db_embeddings=db_embeddings)

async def discover_topics_in_db_dense(self, db_request: DBRequest, n: int | None = None) -> DiscoveredTopics:
async def discover_topics_in_db(self, db_request: DBRequest, n: int | None = None, sparse: bool = False, db_embeddings: bool | None = None) -> DiscoveredTopics | DiscoveredTopicsSparse:
if self.db_connection is None:
raise MissingServiceError("DB connection has to be set for DB topic discovery.")

texts = self.db_connection.get_text_chunks(db_request)
return await self.discover_topics_dense(texts=texts, n=n)
embedder = None
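# Per-call db_embeddings overrides the configured default; None defers to use_db_embeddings.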
if db_embeddings or (db_embeddings is None and self.use_db_embeddings):
embeddings = self.db_connection.get_embeddings(texts)
embedder = PrecomputedEmbeddings(embeddings=embeddings)

texts = self.truncate_texts(texts, self.max_char_length)
top_words, doc_topic_dist = await self._get_topics(texts=texts, n=n, embedder=embedder)
return await self._process_topics(
top_words=top_words,
doc_topic_dist=doc_topic_dist,
texts=texts,
sparse=sparse
)
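
Both public DB entry points now delegate to this shared helper; a sketch of the call sites (assuming a configured FastTopicDiscovery instance `topicer` and a prepared DBRequest `request`):

dense = await topicer.discover_topics_in_db_dense(request, n=10, db_embeddings=True)
sparse = await topicer.discover_topics_in_db_sparse(request)  # db_embeddings=None -> use_db_embeddings decides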

async def propose_tags(self, text_chunk: TextChunk, tags: list[Tag]) -> TextChunkWithTagSpanProposals:
raise NotImplementedError()

async def propose_tags_in_db(self, tag: Tag, db_request: DBRequest) -> list[TextChunkWithTagSpanProposals]:
raise NotImplementedError()

async def _get_topics(self, texts: Sequence[TextChunk], n: int | None = None) -> tuple[list[str], NDArray]:
async def _get_topics(self, texts: Sequence[TextChunk], n: int | None = None, embedder: str | Any = None) -> tuple[list[str], NDArray]:
"""
Discovers topics using the FASTopic model.

:param texts: Sequence of TextChunk objects.
:param n: Optional number of topics to discover. If None, uses the default from configuration.
:param embedder: Optional embedder to use. If None, uses the embedding service wrapper.
:return: Tuple containing:
- List of top words for each topic.
- Document-topic distribution matrix.
"""

model = await asyncio.to_thread(self._create_fastopic_model, n=n)
model = await asyncio.to_thread(self._create_fastopic_model, n=n, embedder=embedder)
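# model creation above and fit_transform below are CPU-bound, so both run in worker threads to keep the event loop responsive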

top_words, doc_topic_dist = await asyncio.to_thread(
model.fit_transform,
@@ -320,19 +369,20 @@ def truncate_texts(texts: Sequence[TextChunk], max_char_length: int) -> Sequence

return texts

def _create_fastopic_model(self, n: int | None = None) -> FASTopic:
def _create_fastopic_model(self, n: int | None = None, embedder: str | Any = None) -> FASTopic:
"""
Creates and configures the FASTopic model.

:param n: Optional number of topics to discover. If None, uses the default from configuration.
:param embedder: Optional embedder to use. If None, uses the embedding service wrapper.
:return: Configured FASTopic model.
"""
preprocessing = Preprocess(tokenizer=self.tokenizer, vocab_size=self.vocab_size)
model = FASTopic(
num_topics=self.n_topics if n is None else n,
preprocess=preprocessing,
num_top_words=self.topic_rep_size,
doc_embed_model=self.embedding_service_wrapper,
doc_embed_model=self.embedding_service_wrapper if embedder is None else embedder,
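# an explicitly passed embedder (e.g. PrecomputedEmbeddings) takes precedence over the default service wrapper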
verbose=self.verbose,
normalize_embeddings=self.embedding_service.normalize_embeddings if hasattr(self.embedding_service, 'normalize_embeddings') else False,
)
@@ -408,7 +458,6 @@ async def _generate_topic_names(self, top_words: list[list[str]], top_docs_per_t
explanation of the topic name
"""


text_chunks = []
for words, docs in zip(top_words, top_docs_per_topic):
prompt = self.generate_topic_name_prompt.render({
@@ -417,7 +466,6 @@
})
text_chunks.append(prompt)


api_output = await self.llm_service.process_text_chunks_structured(
text_chunks=text_chunks,
instruction=self.generate_topic_name_system_prompt.render({}),