diff --git a/baselines/Genie/genie.py b/baselines/Genie/genie.py index 75e713e..1dfd8e9 100644 --- a/baselines/Genie/genie.py +++ b/baselines/Genie/genie.py @@ -10,7 +10,7 @@ from dotenv import load_dotenv from tqdm.asyncio import tqdm as tqdm_async -from graphgen.models import OpenAIModel +from graphgen.models import OpenAIClient from graphgen.utils import compute_content_hash, create_event_loop PROMPT_TEMPLATE = """Instruction: Given the next [document], create a [question] and [answer] pair that are grounded \ @@ -59,7 +59,7 @@ def _post_process(content: str) -> tuple: @dataclass class Genie: - llm_client: OpenAIModel = None + llm_client: OpenAIClient = None max_concurrent: int = 1000 def generate(self, docs: List[List[dict]]) -> List[dict]: @@ -121,7 +121,7 @@ async def process_chunk(content: str): load_dotenv() - llm_client = OpenAIModel( + llm_client = OpenAIClient( model_name=os.getenv("SYNTHESIZER_MODEL"), api_key=os.getenv("SYNTHESIZER_API_KEY"), base_url=os.getenv("SYNTHESIZER_BASE_URL"), diff --git a/baselines/LongForm/longform.py b/baselines/LongForm/longform.py index 31feb01..8467556 100644 --- a/baselines/LongForm/longform.py +++ b/baselines/LongForm/longform.py @@ -11,7 +11,7 @@ from dotenv import load_dotenv from tqdm.asyncio import tqdm as tqdm_async -from graphgen.models import OpenAIModel +from graphgen.models import OpenAIClient from graphgen.utils import compute_content_hash, create_event_loop PROMPT_TEMPLATE = """Instruction: X @@ -23,7 +23,7 @@ @dataclass class LongForm: - llm_client: OpenAIModel = None + llm_client: OpenAIClient = None max_concurrent: int = 1000 def generate(self, docs: List[List[dict]]) -> List[dict]: @@ -88,7 +88,7 @@ async def process_chunk(content: str): load_dotenv() - llm_client = OpenAIModel( + llm_client = OpenAIClient( model_name=os.getenv("SYNTHESIZER_MODEL"), api_key=os.getenv("SYNTHESIZER_API_KEY"), base_url=os.getenv("SYNTHESIZER_BASE_URL"), diff --git a/baselines/SELF-QA/self-qa.py b/baselines/SELF-QA/self-qa.py index 8ee0307..1f96cff 100644 --- a/baselines/SELF-QA/self-qa.py +++ b/baselines/SELF-QA/self-qa.py @@ -10,7 +10,7 @@ from dotenv import load_dotenv from tqdm.asyncio import tqdm as tqdm_async -from graphgen.models import OpenAIModel +from graphgen.models import OpenAIClient from graphgen.utils import compute_content_hash, create_event_loop INSTRUCTION_GENERATION_PROMPT = """The background knowledge is: @@ -58,7 +58,7 @@ def _post_process_answers(content: str) -> tuple: @dataclass class SelfQA: - llm_client: OpenAIModel = None + llm_client: OpenAIClient = None max_concurrent: int = 100 def generate(self, docs: List[List[dict]]) -> List[dict]: @@ -155,7 +155,7 @@ async def process_chunk(content: str): load_dotenv() - llm_client = OpenAIModel( + llm_client = OpenAIClient( model_name=os.getenv("SYNTHESIZER_MODEL"), api_key=os.getenv("SYNTHESIZER_API_KEY"), base_url=os.getenv("SYNTHESIZER_BASE_URL"), diff --git a/baselines/Wrap/wrap.py b/baselines/Wrap/wrap.py index 3f71b2f..cecbaad 100644 --- a/baselines/Wrap/wrap.py +++ b/baselines/Wrap/wrap.py @@ -10,7 +10,7 @@ from dotenv import load_dotenv from tqdm.asyncio import tqdm as tqdm_async -from graphgen.models import OpenAIModel +from graphgen.models import OpenAIClient from graphgen.utils import compute_content_hash, create_event_loop PROMPT_TEMPLATE = """A chat between a curious user and an artificial intelligence assistant. 
@@ -46,7 +46,7 @@ def _post_process(content: str) -> list: @dataclass class Wrap: - llm_client: OpenAIModel = None + llm_client: OpenAIClient = None max_concurrent: int = 1000 def generate(self, docs: List[List[dict]]) -> List[dict]: @@ -108,7 +108,7 @@ async def process_chunk(content: str): load_dotenv() - llm_client = OpenAIModel( + llm_client = OpenAIClient( model_name=os.getenv("SYNTHESIZER_MODEL"), api_key=os.getenv("SYNTHESIZER_API_KEY"), base_url=os.getenv("SYNTHESIZER_BASE_URL"), diff --git a/graphgen/bases/__init__.py b/graphgen/bases/__init__.py index e69de29..30b0014 100644 --- a/graphgen/bases/__init__.py +++ b/graphgen/bases/__init__.py @@ -0,0 +1,12 @@ +from .base_kg_builder import BaseKGBuilder +from .base_llm_client import BaseLLMClient +from .base_reader import BaseReader +from .base_splitter import BaseSplitter +from .base_storage import ( + BaseGraphStorage, + BaseKVStorage, + BaseListStorage, + StorageNameSpace, +) +from .base_tokenizer import BaseTokenizer +from .datatypes import Chunk, QAPair, Token diff --git a/graphgen/bases/base_kg_builder.py b/graphgen/bases/base_kg_builder.py new file mode 100644 index 0000000..91c3df6 --- /dev/null +++ b/graphgen/bases/base_kg_builder.py @@ -0,0 +1,41 @@ +from abc import ABC, abstractmethod +from collections import defaultdict +from dataclasses import dataclass, field +from typing import Dict, List, Tuple + +from graphgen.bases.base_llm_client import BaseLLMClient +from graphgen.bases.base_storage import BaseGraphStorage +from graphgen.bases.datatypes import Chunk + + +@dataclass +class BaseKGBuilder(ABC): + kg_instance: BaseGraphStorage + llm_client: BaseLLMClient + + _nodes: Dict[str, List[dict]] = field(default_factory=lambda: defaultdict(list)) + _edges: Dict[Tuple[str, str], List[dict]] = field( + default_factory=lambda: defaultdict(list) + ) + + def build(self, chunks: List[Chunk]) -> None: + pass + + @abstractmethod + async def extract_all(self, chunks: List[Chunk]) -> None: + """Extract nodes and edges from all chunks.""" + raise NotImplementedError + + @abstractmethod + async def extract( + self, chunk: Chunk + ) -> Tuple[Dict[str, List[dict]], Dict[Tuple[str, str], List[dict]]]: + """Extract nodes and edges from a single chunk.""" + raise NotImplementedError + + @abstractmethod + async def merge_nodes( + self, nodes_data: Dict[str, List[dict]], kg_instance: BaseGraphStorage, llm + ) -> None: + """Merge extracted nodes into the knowledge graph.""" + raise NotImplementedError diff --git a/graphgen/bases/base_llm_client.py b/graphgen/bases/base_llm_client.py new file mode 100644 index 0000000..fdb8f8f --- /dev/null +++ b/graphgen/bases/base_llm_client.py @@ -0,0 +1,74 @@ +from __future__ import annotations + +import abc +import re +from typing import Any, List, Optional + +from graphgen.bases.base_tokenizer import BaseTokenizer +from graphgen.bases.datatypes import Token + + +class BaseLLMClient(abc.ABC): + """ + LLM client base class, agnostic to specific backends (OpenAI / Ollama / ...). 
+    """
+
+    def __init__(
+        self,
+        *,
+        system_prompt: str = "",
+        temperature: float = 0.0,
+        max_tokens: int = 4096,
+        repetition_penalty: float = 1.05,
+        top_p: float = 0.95,
+        top_k: int = 50,
+        tokenizer: Optional[BaseTokenizer] = None,
+        **kwargs: Any,
+    ):
+        self.system_prompt = system_prompt
+        self.temperature = temperature
+        self.max_tokens = max_tokens
+        self.repetition_penalty = repetition_penalty
+        self.top_p = top_p
+        self.top_k = top_k
+        self.tokenizer = tokenizer
+
+        for k, v in kwargs.items():
+            setattr(self, k, v)
+
+    @abc.abstractmethod
+    async def generate_answer(
+        self, text: str, history: Optional[List[str]] = None, **extra: Any
+    ) -> str:
+        """Generate answer from the model."""
+        raise NotImplementedError
+
+    @abc.abstractmethod
+    async def generate_topk_per_token(
+        self, text: str, history: Optional[List[str]] = None, **extra: Any
+    ) -> List[Token]:
+        """Generate top-k tokens for the next token prediction."""
+        raise NotImplementedError
+
+    @abc.abstractmethod
+    async def generate_inputs_prob(
+        self, text: str, history: Optional[List[str]] = None, **extra: Any
+    ) -> List[Token]:
+        """Generate probabilities for each token in the input."""
+        raise NotImplementedError
+
+    def count_tokens(self, text: str) -> int:
+        """Count the number of tokens in the text."""
+        if self.tokenizer is None:
+            raise ValueError("Tokenizer is not set. Please provide a tokenizer to use count_tokens.")
+        return len(self.tokenizer.encode(text))
+
+    @staticmethod
+    def filter_think_tags(text: str, think_tag: str = "think") -> str:
+        """
+        Remove <think> tags from the text.
+        If the text contains <think> and </think>, it removes everything between them and the tags themselves.
+        """
+        think_pattern = re.compile(rf"<{think_tag}>.*?</{think_tag}>", re.DOTALL)
+        filtered_text = think_pattern.sub("", text).strip()
+        return filtered_text if filtered_text else text.strip()
diff --git a/graphgen/bases/base_tokenizer.py b/graphgen/bases/base_tokenizer.py
new file mode 100644
index 0000000..958b142
--- /dev/null
+++ b/graphgen/bases/base_tokenizer.py
@@ -0,0 +1,44 @@
+from __future__ import annotations
+
+from abc import ABC, abstractmethod
+from dataclasses import dataclass
+from typing import List
+
+
+@dataclass
+class BaseTokenizer(ABC):
+    model_name: str = "cl100k_base"
+
+    @abstractmethod
+    def encode(self, text: str) -> List[int]:
+        """Encode text -> token ids."""
+        raise NotImplementedError
+
+    @abstractmethod
+    def decode(self, token_ids: List[int]) -> str:
+        """Decode token ids -> text."""
+        raise NotImplementedError
+
+    def count_tokens(self, text: str) -> int:
+        return len(self.encode(text))
+
+    def chunk_by_token_size(
+        self,
+        content: str,
+        *,
+        overlap_token_size: int = 128,
+        max_token_size: int = 1024,
+    ) -> List[dict]:
+        tokens = self.encode(content)
+        results = []
+        step = max_token_size - overlap_token_size
+        for index, start in enumerate(range(0, len(tokens), step)):
+            chunk_ids = tokens[start : start + max_token_size]
+            results.append(
+                {
+                    "tokens": len(chunk_ids),
+                    "content": self.decode(chunk_ids).strip(),
+                    "chunk_order_index": index,
+                }
+            )
+        return results
diff --git a/graphgen/bases/datatypes.py b/graphgen/bases/datatypes.py
index 4cdc9d2..5a32126 100644
--- a/graphgen/bases/datatypes.py
+++ b/graphgen/bases/datatypes.py
@@ -1,4 +1,6 @@
+import math
 from dataclasses import dataclass, field
+from typing import List, Union
 
 
 @dataclass
@@ -16,3 +18,15 @@ class QAPair:
 
     question: str
     answer: str
+
+
+@dataclass
+class Token:
+    text: str
+    prob: float
+    top_candidates: List = field(default_factory=list)
+    
ppl: Union[float, None] = field(default=None) + + @property + def logprob(self) -> float: + return math.log(self.prob) diff --git a/graphgen/graphgen.py b/graphgen/graphgen.py index ab2be0d..8abb0b4 100644 --- a/graphgen/graphgen.py +++ b/graphgen/graphgen.py @@ -2,10 +2,9 @@ import os import time from dataclasses import dataclass, field -from typing import Dict, List, Union, cast +from typing import Dict, cast import gradio as gr -from tqdm.asyncio import tqdm as tqdm_async from graphgen.bases.base_storage import StorageNameSpace from graphgen.bases.datatypes import Chunk @@ -13,27 +12,25 @@ JsonKVStorage, JsonListStorage, NetworkXStorage, - OpenAIModel, + OpenAIClient, Tokenizer, TraverseStrategy, - read_file, - split_chunks, ) - -from .operators import ( +from graphgen.operators import ( + chunk_documents, extract_kg, generate_cot, judge_statement, quiz, + read_files, search_all, traverse_graph_for_aggregated, traverse_graph_for_atomic, traverse_graph_for_multi_hop, ) -from .utils import ( +from graphgen.utils import ( + async_to_sync_method, compute_content_hash, - create_event_loop, - detect_main_language, format_generation_results, logger, ) @@ -49,8 +46,8 @@ class GraphGen: # llm tokenizer_instance: Tokenizer = None - synthesizer_llm_client: OpenAIModel = None - trainee_llm_client: OpenAIModel = None + synthesizer_llm_client: OpenAIClient = None + trainee_llm_client: OpenAIClient = None # search search_config: dict = field( @@ -67,17 +64,17 @@ def __post_init__(self): self.tokenizer_instance: Tokenizer = Tokenizer( model_name=self.config["tokenizer"] ) - self.synthesizer_llm_client: OpenAIModel = OpenAIModel( + self.synthesizer_llm_client: OpenAIClient = OpenAIClient( model_name=os.getenv("SYNTHESIZER_MODEL"), api_key=os.getenv("SYNTHESIZER_API_KEY"), base_url=os.getenv("SYNTHESIZER_BASE_URL"), - tokenizer_instance=self.tokenizer_instance, + tokenizer=self.tokenizer_instance, ) - self.trainee_llm_client: OpenAIModel = OpenAIModel( + self.trainee_llm_client: OpenAIClient = OpenAIClient( model_name=os.getenv("TRAINEE_MODEL"), api_key=os.getenv("TRAINEE_API_KEY"), base_url=os.getenv("TRAINEE_BASE_URL"), - tokenizer_instance=self.tokenizer_instance, + tokenizer=self.tokenizer_instance, ) self.search_config = self.config["search"] @@ -111,15 +108,23 @@ def __post_init__(self): namespace="qa", ) - async def async_split_chunks(self, data: List[Union[List, Dict]]) -> dict: - # TODO: configurable whether to use coreference resolution + @async_to_sync_method + async def insert(self): + """ + insert chunks into the graph + """ + input_file = self.config["read"]["input_file"] + + # Step 1: Read files + data = read_files(input_file) if len(data) == 0: - return {} + logger.warning("No data to process") + return - inserting_chunks = {} - assert isinstance(data, list) and isinstance(data[0], dict) + # TODO: configurable whether to use coreference resolution - # compute hash for each document + # Step 2: Split chunks and filter existing ones + assert isinstance(data, list) and isinstance(data[0], dict) new_docs = { compute_content_hash(doc["content"], prefix="doc-"): { "content": doc["content"] @@ -128,38 +133,19 @@ async def async_split_chunks(self, data: List[Union[List, Dict]]) -> dict: } _add_doc_keys = await self.full_docs_storage.filter_keys(list(new_docs.keys())) new_docs = {k: v for k, v in new_docs.items() if k in _add_doc_keys} + if len(new_docs) == 0: logger.warning("All docs are already in the storage") - return {} + return logger.info("[New Docs] inserting %d docs", len(new_docs)) - 
cur_index = 1 - doc_number = len(new_docs) - async for doc_key, doc in tqdm_async( - new_docs.items(), desc="[1/4]Chunking documents", unit="doc" - ): - doc_language = detect_main_language(doc["content"]) - text_chunks = split_chunks( - doc["content"], - language=doc_language, - chunk_size=self.config["split"]["chunk_size"], - chunk_overlap=self.config["split"]["chunk_overlap"], - ) - - chunks = { - compute_content_hash(txt, prefix="chunk-"): { - "content": txt, - "full_doc_id": doc_key, - "length": len(self.tokenizer_instance.encode_string(txt)), - "language": doc_language, - } - for txt in text_chunks - } - inserting_chunks.update(chunks) - - if self.progress_bar is not None: - self.progress_bar(cur_index / doc_number, f"Chunking {doc_key}") - cur_index += 1 + inserting_chunks = await chunk_documents( + new_docs, + self.config["split"]["chunk_size"], + self.config["split"]["chunk_overlap"], + self.tokenizer_instance, + self.progress_bar, + ) _add_chunk_keys = await self.text_chunks_storage.filter_keys( list(inserting_chunks.keys()) @@ -167,29 +153,16 @@ async def async_split_chunks(self, data: List[Union[List, Dict]]) -> dict: inserting_chunks = { k: v for k, v in inserting_chunks.items() if k in _add_chunk_keys } - await self.full_docs_storage.upsert(new_docs) - await self.text_chunks_storage.upsert(inserting_chunks) - - return inserting_chunks - - def insert(self): - loop = create_event_loop() - loop.run_until_complete(self.async_insert()) - - async def async_insert(self): - """ - insert chunks into the graph - """ - - input_file = self.config["read"]["input_file"] - data = read_file(input_file) - inserting_chunks = await self.async_split_chunks(data) if len(inserting_chunks) == 0: logger.warning("All chunks are already in the storage") return + logger.info("[New Chunks] inserting %d chunks", len(inserting_chunks)) + await self.full_docs_storage.upsert(new_docs) + await self.text_chunks_storage.upsert(inserting_chunks) + # Step 3: Extract entities and relations from chunks logger.info("[Entity and Relation Extraction]...") _add_entities_and_relations = await extract_kg( llm_client=self.synthesizer_llm_client, @@ -219,11 +192,8 @@ async def _insert_done(self): tasks.append(cast(StorageNameSpace, storage_instance).index_done_callback()) await asyncio.gather(*tasks) - def search(self): - loop = create_event_loop() - loop.run_until_complete(self.async_search()) - - async def async_search(self): + @async_to_sync_method + async def search(self): logger.info( "Search is %s", "enabled" if self.search_config["enabled"] else "disabled" ) @@ -257,13 +227,10 @@ async def async_search(self): ] ) # TODO: fix insert after search - await self.async_insert() - - def quiz(self): - loop = create_event_loop() - loop.run_until_complete(self.async_quiz()) + await self.insert() - async def async_quiz(self): + @async_to_sync_method + async def quiz(self): max_samples = self.config["quiz_and_judge_strategy"]["quiz_samples"] await quiz( self.synthesizer_llm_client, @@ -273,11 +240,8 @@ async def async_quiz(self): ) await self.rephrase_storage.index_done_callback() - def judge(self): - loop = create_event_loop() - loop.run_until_complete(self.async_judge()) - - async def async_judge(self): + @async_to_sync_method + async def judge(self): re_judge = self.config["quiz_and_judge_strategy"]["re_judge"] _update_relations = await judge_statement( self.trainee_llm_client, @@ -287,11 +251,8 @@ async def async_judge(self): ) await _update_relations.index_done_callback() - def traverse(self): - loop = create_event_loop() 
- loop.run_until_complete(self.async_traverse()) - - async def async_traverse(self): + @async_to_sync_method + async def traverse(self): output_data_type = self.config["output_data_type"] if output_data_type == "atomic": @@ -331,11 +292,8 @@ async def async_traverse(self): await self.qa_storage.upsert(results) await self.qa_storage.index_done_callback() - def generate_reasoning(self, method_params): - loop = create_event_loop() - loop.run_until_complete(self.async_generate_reasoning(method_params)) - - async def async_generate_reasoning(self, method_params): + @async_to_sync_method + async def generate_reasoning(self, method_params): results = await generate_cot( self.graph_storage, self.synthesizer_llm_client, @@ -349,11 +307,8 @@ async def async_generate_reasoning(self, method_params): await self.qa_storage.upsert(results) await self.qa_storage.index_done_callback() - def clear(self): - loop = create_event_loop() - loop.run_until_complete(self.async_clear()) - - async def async_clear(self): + @async_to_sync_method + async def clear(self): await self.full_docs_storage.drop() await self.text_chunks_storage.drop() await self.search_storage.drop() diff --git a/graphgen/models/__init__.py b/graphgen/models/__init__.py index 9650909..cea2fc4 100644 --- a/graphgen/models/__init__.py +++ b/graphgen/models/__init__.py @@ -3,15 +3,15 @@ from .evaluate.mtld_evaluator import MTLDEvaluator from .evaluate.reward_evaluator import RewardEvaluator from .evaluate.uni_evaluator import UniEvaluator -from .llm.openai_model import OpenAIModel -from .llm.tokenizer import Tokenizer -from .llm.topk_token_model import Token, TopkTokenModel -from .reader import read_file +from .llm.openai_client import OpenAIClient +from .llm.topk_token_model import TopkTokenModel +from .reader import CsvReader, JsonlReader, JsonReader, TxtReader from .search.db.uniprot_search import UniProtSearch from .search.kg.wiki_search import WikiSearch from .search.web.bing_search import BingSearch from .search.web.google_search import GoogleSearch -from .splitter import split_chunks +from .splitter import ChineseRecursiveTextSplitter, RecursiveCharacterSplitter from .storage.json_storage import JsonKVStorage, JsonListStorage from .storage.networkx_storage import NetworkXStorage from .strategy.travserse_strategy import TraverseStrategy +from .tokenizer import Tokenizer diff --git a/graphgen/models/evaluate/length_evaluator.py b/graphgen/models/evaluate/length_evaluator.py index bf7cc48..9aa6c7c 100644 --- a/graphgen/models/evaluate/length_evaluator.py +++ b/graphgen/models/evaluate/length_evaluator.py @@ -2,7 +2,7 @@ from graphgen.bases.datatypes import QAPair from graphgen.models.evaluate.base_evaluator import BaseEvaluator -from graphgen.models.llm.tokenizer import Tokenizer +from graphgen.models.tokenizer import Tokenizer from graphgen.utils import create_event_loop @@ -18,5 +18,5 @@ async def evaluate_single(self, pair: QAPair) -> float: return await loop.run_in_executor(None, self._calculate_length, pair.answer) def _calculate_length(self, text: str) -> float: - tokens = self.tokenizer.encode_string(text) + tokens = self.tokenizer.encode(text) return len(tokens) diff --git a/graphgen/models/kg_builder/NetworkXKGBuilder.py b/graphgen/models/kg_builder/NetworkXKGBuilder.py new file mode 100644 index 0000000..9067363 --- /dev/null +++ b/graphgen/models/kg_builder/NetworkXKGBuilder.py @@ -0,0 +1,18 @@ +from dataclasses import dataclass + +from graphgen.bases import BaseKGBuilder + + +@dataclass +class NetworkXKGBuilder(BaseKGBuilder): + 
def build(self, chunks): + pass + + async def extract_all(self, chunks): + pass + + async def extract(self, chunk): + pass + + async def merge_nodes(self, nodes_data, kg_instance, llm): + pass diff --git a/graphgen/operators/kg/__init__.py b/graphgen/models/kg_builder/__init__.py similarity index 100% rename from graphgen/operators/kg/__init__.py rename to graphgen/models/kg_builder/__init__.py diff --git a/graphgen/models/llm/ollama_client.py b/graphgen/models/llm/ollama_client.py new file mode 100644 index 0000000..5d6e5d2 --- /dev/null +++ b/graphgen/models/llm/ollama_client.py @@ -0,0 +1,21 @@ +# TODO: implement ollama client +from typing import Any, List, Optional + +from graphgen.bases import BaseLLMClient, Token + + +class OllamaClient(BaseLLMClient): + async def generate_answer( + self, text: str, history: Optional[List[str]] = None, **extra: Any + ) -> str: + pass + + async def generate_topk_per_token( + self, text: str, history: Optional[List[str]] = None, **extra: Any + ) -> List[Token]: + pass + + async def generate_inputs_prob( + self, text: str, history: Optional[List[str]] = None, **extra: Any + ) -> List[Token]: + pass diff --git a/graphgen/models/llm/openai_model.py b/graphgen/models/llm/openai_client.py similarity index 70% rename from graphgen/models/llm/openai_model.py rename to graphgen/models/llm/openai_client.py index 2c04432..9f0d276 100644 --- a/graphgen/models/llm/openai_model.py +++ b/graphgen/models/llm/openai_client.py @@ -1,7 +1,5 @@ import math -import re -from dataclasses import dataclass, field -from typing import Dict, List, Optional +from typing import Any, Dict, List, Optional import openai from openai import APIConnectionError, APITimeoutError, AsyncOpenAI, RateLimitError @@ -12,9 +10,9 @@ wait_exponential, ) +from graphgen.bases.base_llm_client import BaseLLMClient +from graphgen.bases.datatypes import Token from graphgen.models.llm.limitter import RPM, TPM -from graphgen.models.llm.tokenizer import Tokenizer -from graphgen.models.llm.topk_token_model import Token, TopkTokenModel def get_top_response_tokens(response: openai.ChatCompletion) -> List[Token]: @@ -30,32 +28,33 @@ def get_top_response_tokens(response: openai.ChatCompletion) -> List[Token]: return tokens -def filter_think_tags(text: str) -> str: - """ - Remove tags from the text. - If the text contains and , it removes everything between them and the tags themselves. 
- """ - think_pattern = re.compile(r".*?", re.DOTALL) - filtered_text = think_pattern.sub("", text).strip() - return filtered_text if filtered_text else text.strip() - - -@dataclass -class OpenAIModel(TopkTokenModel): - model_name: str = "gpt-4o-mini" - api_key: str = None - base_url: str = None - - system_prompt: str = "" - json_mode: bool = False - seed: int = None - - token_usage: list = field(default_factory=list) - request_limit: bool = False - rpm: RPM = field(default_factory=lambda: RPM(rpm=1000)) - tpm: TPM = field(default_factory=lambda: TPM(tpm=50000)) - - tokenizer_instance: Tokenizer = field(default_factory=Tokenizer) +class OpenAIClient(BaseLLMClient): + def __init__( + self, + *, + model_name: str = "gpt-4o-mini", + api_key: Optional[str] = None, + base_url: Optional[str] = None, + json_mode: bool = False, + seed: Optional[int] = None, + topk_per_token: int = 5, # number of topk tokens to generate for each token + request_limit: bool = False, + **kwargs: Any, + ): + super().__init__(**kwargs) + self.model_name = model_name + self.api_key = api_key + self.base_url = base_url + self.json_mode = json_mode + self.seed = seed + self.topk_per_token = topk_per_token + + self.token_usage: list = [] + self.request_limit = request_limit + self.rpm = RPM(rpm=1000) + self.tpm = TPM(tpm=50000) + + self.__post_init__() def __post_init__(self): assert self.api_key is not None, "Please provide api key to access openai api." @@ -66,7 +65,7 @@ def __post_init__(self): def _pre_generate(self, text: str, history: List[str]) -> Dict: kwargs = { "temperature": self.temperature, - "top_p": self.topp, + "top_p": self.top_p, "max_tokens": self.max_tokens, } if self.seed: @@ -94,7 +93,10 @@ def _pre_generate(self, text: str, history: List[str]) -> Dict: ), ) async def generate_topk_per_token( - self, text: str, history: Optional[List[str]] = None + self, + text: str, + history: Optional[List[str]] = None, + **extra: Any, ) -> List[Token]: kwargs = self._pre_generate(text, history) if self.topk_per_token > 0: @@ -120,16 +122,16 @@ async def generate_topk_per_token( ), ) async def generate_answer( - self, text: str, history: Optional[List[str]] = None, temperature: int = 0 + self, + text: str, + history: Optional[List[str]] = None, + **extra: Any, ) -> str: kwargs = self._pre_generate(text, history) - kwargs["temperature"] = temperature prompt_tokens = 0 for message in kwargs["messages"]: - prompt_tokens += len( - self.tokenizer_instance.encode_string(message["content"]) - ) + prompt_tokens += len(self.tokenizer.encode(message["content"])) estimated_tokens = prompt_tokens + kwargs["max_tokens"] if self.request_limit: @@ -147,9 +149,10 @@ async def generate_answer( "total_tokens": completion.usage.total_tokens, } ) - return filter_think_tags(completion.choices[0].message.content) + return self.filter_think_tags(completion.choices[0].message.content) async def generate_inputs_prob( - self, text: str, history: Optional[List[str]] = None + self, text: str, history: Optional[List[str]] = None, **extra: Any ) -> List[Token]: + """Generate probabilities for each token in the input.""" raise NotImplementedError diff --git a/graphgen/models/llm/tokenizer.py b/graphgen/models/llm/tokenizer.py deleted file mode 100644 index 6a1c4b2..0000000 --- a/graphgen/models/llm/tokenizer.py +++ /dev/null @@ -1,73 +0,0 @@ -from dataclasses import dataclass -from typing import List -import tiktoken - -try: - from transformers import AutoTokenizer - TRANSFORMERS_AVAILABLE = True -except ImportError: - AutoTokenizer = None - 
TRANSFORMERS_AVAILABLE = False - - -def get_tokenizer(tokenizer_name: str = "cl100k_base"): - """ - Get a tokenizer instance by name. - - :param tokenizer_name: tokenizer name, tiktoken encoding name or Hugging Face model name - :return: tokenizer instance - """ - if tokenizer_name in tiktoken.list_encoding_names(): - return tiktoken.get_encoding(tokenizer_name) - if TRANSFORMERS_AVAILABLE: - try: - return AutoTokenizer.from_pretrained(tokenizer_name) - except Exception as e: - raise ValueError(f"Failed to load tokenizer from Hugging Face: {e}") from e - else: - raise ValueError("Hugging Face Transformers is not available, please install it first.") - -@dataclass -class Tokenizer: - model_name: str = "cl100k_base" - - def __post_init__(self): - self.tokenizer = get_tokenizer(self.model_name) - - def encode_string(self, text: str) -> List[int]: - """ - Encode text to tokens - - :param text - :return: tokens - """ - return self.tokenizer.encode(text) - - def decode_tokens(self, tokens: List[int]) -> str: - """ - Decode tokens to text - - :param tokens - :return: text - """ - return self.tokenizer.decode(tokens) - - def chunk_by_token_size( - self, content: str, overlap_token_size=128, max_token_size=1024 - ): - tokens = self.encode_string(content) - results = [] - for index, start in enumerate( - range(0, len(tokens), max_token_size - overlap_token_size) - ): - chunk_content = self.decode_tokens( - tokens[start : start + max_token_size] - ) - results.append( - { - "tokens": min(max_token_size, len(tokens) - start), - "content": chunk_content.strip(), - "chunk_order_index": index, - } - ) - return results diff --git a/graphgen/models/llm/topk_token_model.py b/graphgen/models/llm/topk_token_model.py index b7595cb..94719cf 100644 --- a/graphgen/models/llm/topk_token_model.py +++ b/graphgen/models/llm/topk_token_model.py @@ -1,18 +1,7 @@ -import math -from dataclasses import dataclass, field -from typing import List, Union, Optional +from dataclasses import dataclass +from typing import List, Optional - -@dataclass -class Token: - text: str - prob: float - top_candidates: List = field(default_factory=list) - ppl: Union[float, None] = field(default=None) - - @property - def logprob(self) -> float: - return math.log(self.prob) +from graphgen.bases import Token @dataclass @@ -34,14 +23,18 @@ async def generate_topk_per_token(self, text: str) -> List[Token]: """ raise NotImplementedError - async def generate_inputs_prob(self, text: str, history: Optional[List[str]] = None) -> List[Token]: + async def generate_inputs_prob( + self, text: str, history: Optional[List[str]] = None + ) -> List[Token]: """ Generate prob and text for each token of the input text. This function is used to visualize the ppl. """ raise NotImplementedError - async def generate_answer(self, text: str, history: Optional[List[str]] = None) -> str: + async def generate_answer( + self, text: str, history: Optional[List[str]] = None + ) -> str: """ Generate answer from the model. 
""" diff --git a/graphgen/models/reader/__init__.py b/graphgen/models/reader/__init__.py index fde3962..0fca903 100644 --- a/graphgen/models/reader/__init__.py +++ b/graphgen/models/reader/__init__.py @@ -2,21 +2,3 @@ from .json_reader import JsonReader from .jsonl_reader import JsonlReader from .txt_reader import TxtReader - -_MAPPING = { - "jsonl": JsonlReader, - "json": JsonReader, - "txt": TxtReader, - "csv": CsvReader, -} - - -def read_file(file_path: str): - suffix = file_path.split(".")[-1] - if suffix in _MAPPING: - reader = _MAPPING[suffix]() - else: - raise ValueError( - f"Unsupported file format: {suffix}. Supported formats are: {list(_MAPPING.keys())}" - ) - return reader.read(file_path) diff --git a/graphgen/models/splitter/__init__.py b/graphgen/models/splitter/__init__.py index 0743654..4f8a427 100644 --- a/graphgen/models/splitter/__init__.py +++ b/graphgen/models/splitter/__init__.py @@ -1,31 +1,4 @@ -from functools import lru_cache -from typing import Union - from .recursive_character_splitter import ( ChineseRecursiveTextSplitter, RecursiveCharacterSplitter, ) - -_MAPPING = { - "en": RecursiveCharacterSplitter, - "zh": ChineseRecursiveTextSplitter, -} - -SplitterT = Union[RecursiveCharacterSplitter, ChineseRecursiveTextSplitter] - - -@lru_cache(maxsize=None) -def _get_splitter(language: str, frozen_kwargs: frozenset) -> SplitterT: - cls = _MAPPING[language] - kwargs = dict(frozen_kwargs) - return cls(**kwargs) - - -def split_chunks(text: str, language: str = "en", **kwargs) -> list: - if language not in _MAPPING: - raise ValueError( - f"Unsupported language: {language}. " - f"Supported languages are: {list(_MAPPING.keys())}" - ) - splitter = _get_splitter(language, frozenset(kwargs.items())) - return splitter.split_text(text) diff --git a/graphgen/models/tokenizer/__init__.py b/graphgen/models/tokenizer/__init__.py new file mode 100644 index 0000000..43c7e25 --- /dev/null +++ b/graphgen/models/tokenizer/__init__.py @@ -0,0 +1,51 @@ +from dataclasses import dataclass, field +from typing import List + +from graphgen.bases import BaseTokenizer + +from .hf_tokenizer import HFTokenizer +from .tiktoken_tokenizer import TiktokenTokenizer + +try: + from transformers import AutoTokenizer + + _HF_AVAILABLE = True +except ImportError: + _HF_AVAILABLE = False + + +def get_tokenizer_impl(tokenizer_name: str = "cl100k_base") -> BaseTokenizer: + import tiktoken + + if tokenizer_name in tiktoken.list_encoding_names(): + return TiktokenTokenizer(model_name=tokenizer_name) + + # 2. HuggingFace + if _HF_AVAILABLE: + return HFTokenizer(model_name=tokenizer_name) + + raise ValueError( + f"Unknown tokenizer {tokenizer_name} and HuggingFace not available." + ) + + +@dataclass +class Tokenizer(BaseTokenizer): + """ + Encapsulates different tokenization implementations based on the specified model name. 
+ """ + + model_name: str = "cl100k_base" + _impl: BaseTokenizer = field(init=False, repr=False) + + def __post_init__(self): + self._impl = get_tokenizer_impl(self.model_name) + + def encode(self, text: str) -> List[int]: + return self._impl.encode(text) + + def decode(self, token_ids: List[int]) -> str: + return self._impl.decode(token_ids) + + def count_tokens(self, text: str) -> int: + return self._impl.count_tokens(text) diff --git a/graphgen/models/tokenizer/hf_tokenizer.py b/graphgen/models/tokenizer/hf_tokenizer.py new file mode 100644 index 0000000..e5511a9 --- /dev/null +++ b/graphgen/models/tokenizer/hf_tokenizer.py @@ -0,0 +1,18 @@ +from dataclasses import dataclass +from typing import List + +from transformers import AutoTokenizer + +from graphgen.bases import BaseTokenizer + + +@dataclass +class HFTokenizer(BaseTokenizer): + def __post_init__(self): + self.enc = AutoTokenizer.from_pretrained(self.model_name) + + def encode(self, text: str) -> List[int]: + return self.enc.encode(text, add_special_tokens=False) + + def decode(self, token_ids: List[int]) -> str: + return self.enc.decode(token_ids, skip_special_tokens=True) diff --git a/graphgen/models/tokenizer/tiktoken_tokenizer.py b/graphgen/models/tokenizer/tiktoken_tokenizer.py new file mode 100644 index 0000000..3c84edd --- /dev/null +++ b/graphgen/models/tokenizer/tiktoken_tokenizer.py @@ -0,0 +1,18 @@ +from dataclasses import dataclass +from typing import List + +import tiktoken + +from graphgen.bases import BaseTokenizer + + +@dataclass +class TiktokenTokenizer(BaseTokenizer): + def __post_init__(self): + self.enc = tiktoken.get_encoding(self.model_name) + + def encode(self, text: str) -> List[int]: + return self.enc.encode(text) + + def decode(self, token_ids: List[int]) -> str: + return self.enc.decode(token_ids) diff --git a/graphgen/operators/__init__.py b/graphgen/operators/__init__.py index b332970..5c98bc9 100644 --- a/graphgen/operators/__init__.py +++ b/graphgen/operators/__init__.py @@ -1,22 +1,13 @@ +from graphgen.operators.build_kg.extract_kg import extract_kg from graphgen.operators.generate.generate_cot import generate_cot -from graphgen.operators.kg.extract_kg import extract_kg from graphgen.operators.search.search_all import search_all from .judge import judge_statement from .quiz import quiz +from .read import read_files +from .split import chunk_documents from .traverse_graph import ( traverse_graph_for_aggregated, traverse_graph_for_atomic, traverse_graph_for_multi_hop, ) - -__all__ = [ - "extract_kg", - "quiz", - "judge_statement", - "search_all", - "traverse_graph_for_aggregated", - "traverse_graph_for_atomic", - "traverse_graph_for_multi_hop", - "generate_cot", -] diff --git a/graphgen/operators/build_kg/__init__.py b/graphgen/operators/build_kg/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/graphgen/operators/build_kg/extract_kg.py b/graphgen/operators/build_kg/extract_kg.py new file mode 100644 index 0000000..4f508f2 --- /dev/null +++ b/graphgen/operators/build_kg/extract_kg.py @@ -0,0 +1,127 @@ +import re +from collections import defaultdict +from typing import List + +import gradio as gr + +from graphgen.bases.base_storage import BaseGraphStorage +from graphgen.bases.datatypes import Chunk +from graphgen.models import OpenAIClient, Tokenizer +from graphgen.operators.build_kg.merge_kg import merge_edges, merge_nodes +from graphgen.templates import KG_EXTRACTION_PROMPT +from graphgen.utils import ( + detect_if_chinese, + handle_single_entity_extraction, + 
handle_single_relationship_extraction,
+    logger,
+    pack_history_conversations,
+    run_concurrent,
+    split_string_by_multi_markers,
+)
+
+
+# pylint: disable=too-many-statements
+async def extract_kg(
+    llm_client: OpenAIClient,
+    kg_instance: BaseGraphStorage,
+    tokenizer_instance: Tokenizer,
+    chunks: List[Chunk],
+    progress_bar: gr.Progress = None,
+):
+    """
+    :param llm_client: Synthesizer LLM model to extract entities and relationships
+    :param kg_instance
+    :param tokenizer_instance
+    :param chunks
+    :param progress_bar: Gradio progress bar to show the progress of the extraction
+    :return:
+    """
+
+    async def _process_single_content(chunk: Chunk, max_loop: int = 3):
+        chunk_id = chunk.id
+        content = chunk.content
+        if detect_if_chinese(content):
+            language = "Chinese"
+        else:
+            language = "English"
+        KG_EXTRACTION_PROMPT["FORMAT"]["language"] = language
+
+        hint_prompt = KG_EXTRACTION_PROMPT[language]["TEMPLATE"].format(
+            **KG_EXTRACTION_PROMPT["FORMAT"], input_text=content
+        )
+
+        final_result = await llm_client.generate_answer(hint_prompt)
+        logger.info("First result: %s", final_result)
+
+        history = pack_history_conversations(hint_prompt, final_result)
+        for loop_index in range(max_loop):
+            if_loop_result = await llm_client.generate_answer(
+                text=KG_EXTRACTION_PROMPT[language]["IF_LOOP"], history=history
+            )
+            if_loop_result = if_loop_result.strip().strip('"').strip("'").lower()
+            if if_loop_result != "yes":
+                break
+
+            glean_result = await llm_client.generate_answer(
+                text=KG_EXTRACTION_PROMPT[language]["CONTINUE"], history=history
+            )
+            logger.info("Loop %s glean: %s", loop_index, glean_result)
+
+            history += pack_history_conversations(
+                KG_EXTRACTION_PROMPT[language]["CONTINUE"], glean_result
+            )
+            final_result += glean_result
+            if loop_index == max_loop - 1:
+                break
+
+        records = split_string_by_multi_markers(
+            final_result,
+            [
+                KG_EXTRACTION_PROMPT["FORMAT"]["record_delimiter"],
+                KG_EXTRACTION_PROMPT["FORMAT"]["completion_delimiter"],
+            ],
+        )
+
+        nodes = defaultdict(list)
+        edges = defaultdict(list)
+
+        for record in records:
+            record = re.search(r"\((.*)\)", record)
+            if record is None:
+                continue
+            record = record.group(1)  # extract the content inside the parentheses
+            record_attributes = split_string_by_multi_markers(
+                record, [KG_EXTRACTION_PROMPT["FORMAT"]["tuple_delimiter"]]
+            )
+
+            entity = await handle_single_entity_extraction(record_attributes, chunk_id)
+            if entity is not None:
+                nodes[entity["entity_name"]].append(entity)
+                continue
+            relation = await handle_single_relationship_extraction(
+                record_attributes, chunk_id
+            )
+            if relation is not None:
+                edges[(relation["src_id"], relation["tgt_id"])].append(relation)
+        return dict(nodes), dict(edges)
+
+    results = await run_concurrent(
+        _process_single_content,
+        chunks,
+        desc="[2/4]Extracting entities and relationships from chunks",
+        unit="chunk",
+        progress_bar=progress_bar,
+    )
+
+    nodes = defaultdict(list)
+    edges = defaultdict(list)
+    for n, e in results:
+        for k, v in n.items():
+            nodes[k].extend(v)
+        for k, v in e.items():
+            edges[tuple(sorted(k))].extend(v)
+
+    await merge_nodes(nodes, kg_instance, llm_client, tokenizer_instance)
+    await merge_edges(edges, kg_instance, llm_client, tokenizer_instance)
+
+    return kg_instance
diff --git a/graphgen/operators/kg/merge_kg.py b/graphgen/operators/build_kg/merge_kg.py
similarity index 94%
rename from graphgen/operators/kg/merge_kg.py
rename to graphgen/operators/build_kg/merge_kg.py
index fca35f3..45249c5 100644
--- a/graphgen/operators/kg/merge_kg.py
+++ b/graphgen/operators/build_kg/merge_kg.py
@@ -3,8 +3,8 @@
 
 from tqdm.asyncio import tqdm as tqdm_async
 
-from graphgen.bases.base_storage import BaseGraphStorage
-from graphgen.models import Tokenizer, TopkTokenModel
+from graphgen.bases import BaseGraphStorage, BaseLLMClient
+from graphgen.models import Tokenizer
 from graphgen.templates import KG_EXTRACTION_PROMPT, KG_SUMMARIZATION_PROMPT
 from graphgen.utils import detect_main_language, logger
 from graphgen.utils.format import split_string_by_multi_markers
@@ -13,7 +13,7 @@
 async def _handle_kg_summary(
     entity_or_relation_name: str,
     description: str,
-    llm_client: TopkTokenModel,
+    llm_client: BaseLLMClient,
     tokenizer_instance: Tokenizer,
     max_summary_tokens: int = 200,
 ) -> str:
@@ -34,11 +34,11 @@ async def _handle_kg_summary(
         language = "Chinese"
     KG_EXTRACTION_PROMPT["FORMAT"]["language"] = language
 
-    tokens = tokenizer_instance.encode_string(description)
+    tokens = tokenizer_instance.encode(description)
     if len(tokens) < max_summary_tokens:
         return description
 
-    use_description = tokenizer_instance.decode_tokens(tokens[:max_summary_tokens])
+    use_description = tokenizer_instance.decode(tokens[:max_summary_tokens])
     prompt = KG_SUMMARIZATION_PROMPT[language]["TEMPLATE"].format(
         entity_name=entity_or_relation_name,
         description_list=use_description.split("<SEP>"),
@@ -54,7 +54,7 @@
 async def merge_nodes(
     nodes_data: dict,
     kg_instance: BaseGraphStorage,
-    llm_client: TopkTokenModel,
+    llm_client: BaseLLMClient,
     tokenizer_instance: Tokenizer,
     max_concurrent: int = 1000,
 ):
@@ -131,7 +131,7 @@
 async def merge_edges(
     edges_data: dict,
     kg_instance: BaseGraphStorage,
-    llm_client: TopkTokenModel,
+    llm_client: BaseLLMClient,
     tokenizer_instance: Tokenizer,
     max_concurrent: int = 1000,
 ):
diff --git a/graphgen/operators/kg/split_kg.py b/graphgen/operators/build_kg/split_kg.py
similarity index 100%
rename from graphgen/operators/kg/split_kg.py
rename to graphgen/operators/build_kg/split_kg.py
diff --git a/graphgen/operators/generate/generate_cot.py b/graphgen/operators/generate/generate_cot.py
index b87bce2..e96635a 100644
--- a/graphgen/operators/generate/generate_cot.py
+++ b/graphgen/operators/generate/generate_cot.py
@@ -3,14 +3,14 @@
 
 from tqdm.asyncio import tqdm as tqdm_async
 
-from graphgen.models import CommunityDetector, NetworkXStorage, OpenAIModel
+from graphgen.models import CommunityDetector, NetworkXStorage, OpenAIClient
 from graphgen.templates import COT_GENERATION_PROMPT, COT_TEMPLATE_DESIGN_PROMPT
 from graphgen.utils import compute_content_hash, detect_main_language
 
 
 async def generate_cot(
     graph_storage: NetworkXStorage,
-    synthesizer_llm_client: OpenAIModel,
+    synthesizer_llm_client: OpenAIClient,
     method_params: Dict = None,
 ):
     method = method_params.get("method", "leiden")
diff --git a/graphgen/operators/judge.py b/graphgen/operators/judge.py
index 61e9d33..d1b0e86 100644
--- a/graphgen/operators/judge.py
+++ b/graphgen/operators/judge.py
@@ -3,13 +3,13 @@
 
 from tqdm.asyncio import tqdm as tqdm_async
 
-from graphgen.models import JsonKVStorage, NetworkXStorage, OpenAIModel
+from graphgen.models import JsonKVStorage, NetworkXStorage, OpenAIClient
 from graphgen.templates import STATEMENT_JUDGEMENT_PROMPT
 from graphgen.utils import logger, yes_no_loss_entropy
 
 
 async def judge_statement(  # pylint: disable=too-many-statements
-    trainee_llm_client: OpenAIModel,
+    trainee_llm_client: OpenAIClient,
     graph_storage: NetworkXStorage,
     rephrase_storage: 
JsonKVStorage, re_judge: bool = False, diff --git a/graphgen/operators/kg/extract_kg.py b/graphgen/operators/kg/extract_kg.py deleted file mode 100644 index ed64f22..0000000 --- a/graphgen/operators/kg/extract_kg.py +++ /dev/null @@ -1,152 +0,0 @@ -import asyncio -import re -from collections import defaultdict -from typing import List - -import gradio as gr -from tqdm.asyncio import tqdm as tqdm_async - -from graphgen.bases.base_storage import BaseGraphStorage -from graphgen.bases.datatypes import Chunk -from graphgen.models import OpenAIModel, Tokenizer -from graphgen.operators.kg.merge_kg import merge_edges, merge_nodes -from graphgen.templates import KG_EXTRACTION_PROMPT -from graphgen.utils import ( - detect_if_chinese, - handle_single_entity_extraction, - handle_single_relationship_extraction, - logger, - pack_history_conversations, - split_string_by_multi_markers, -) - - -# pylint: disable=too-many-statements -async def extract_kg( - llm_client: OpenAIModel, - kg_instance: BaseGraphStorage, - tokenizer_instance: Tokenizer, - chunks: List[Chunk], - progress_bar: gr.Progress = None, - max_concurrent: int = 1000, -): - """ - :param llm_client: Synthesizer LLM model to extract entities and relationships - :param kg_instance - :param tokenizer_instance - :param chunks - :param progress_bar: Gradio progress bar to show the progress of the extraction - :param max_concurrent - :return: - """ - - semaphore = asyncio.Semaphore(max_concurrent) - - async def _process_single_content(chunk: Chunk, max_loop: int = 3): - async with semaphore: - chunk_id = chunk.id - content = chunk.content - if detect_if_chinese(content): - language = "Chinese" - else: - language = "English" - KG_EXTRACTION_PROMPT["FORMAT"]["language"] = language - - hint_prompt = KG_EXTRACTION_PROMPT[language]["TEMPLATE"].format( - **KG_EXTRACTION_PROMPT["FORMAT"], input_text=content - ) - - final_result = await llm_client.generate_answer(hint_prompt) - logger.info("First result: %s", final_result) - - history = pack_history_conversations(hint_prompt, final_result) - for loop_index in range(max_loop): - if_loop_result = await llm_client.generate_answer( - text=KG_EXTRACTION_PROMPT[language]["IF_LOOP"], history=history - ) - if_loop_result = if_loop_result.strip().strip('"').strip("'").lower() - if if_loop_result != "yes": - break - - glean_result = await llm_client.generate_answer( - text=KG_EXTRACTION_PROMPT[language]["CONTINUE"], history=history - ) - logger.info("Loop %s glean: %s", loop_index, glean_result) - - history += pack_history_conversations( - KG_EXTRACTION_PROMPT[language]["CONTINUE"], glean_result - ) - final_result += glean_result - if loop_index == max_loop - 1: - break - - records = split_string_by_multi_markers( - final_result, - [ - KG_EXTRACTION_PROMPT["FORMAT"]["record_delimiter"], - KG_EXTRACTION_PROMPT["FORMAT"]["completion_delimiter"], - ], - ) - - nodes = defaultdict(list) - edges = defaultdict(list) - - for record in records: - record = re.search(r"\((.*)\)", record) - if record is None: - continue - record = record.group(1) # 提取括号内的内容 - record_attributes = split_string_by_multi_markers( - record, [KG_EXTRACTION_PROMPT["FORMAT"]["tuple_delimiter"]] - ) - - entity = await handle_single_entity_extraction( - record_attributes, chunk_id - ) - if entity is not None: - nodes[entity["entity_name"]].append(entity) - continue - relation = await handle_single_relationship_extraction( - record_attributes, chunk_id - ) - if relation is not None: - edges[(relation["src_id"], relation["tgt_id"])].append(relation) - 
return dict(nodes), dict(edges) - - results = [] - chunk_number = len(chunks) - async for result in tqdm_async( - asyncio.as_completed([_process_single_content(c) for c in chunks]), - total=len(chunks), - desc="[2/4]Extracting entities and relationships from chunks", - unit="chunk", - ): - try: - if progress_bar is not None: - progress_bar( - len(results) / chunk_number, - desc="[3/4]Extracting entities and relationships from chunks", - ) - results.append(await result) - if progress_bar is not None and len(results) == chunk_number: - progress_bar( - 1, desc="[3/4]Extracting entities and relationships from chunks" - ) - except Exception as e: # pylint: disable=broad-except - logger.error( - "Error occurred while extracting entities and relationships from chunks: %s", - e, - ) - - nodes = defaultdict(list) - edges = defaultdict(list) - for n, e in results: - for k, v in n.items(): - nodes[k].extend(v) - for k, v in e.items(): - edges[tuple(sorted(k))].extend(v) - - await merge_nodes(nodes, kg_instance, llm_client, tokenizer_instance) - await merge_edges(edges, kg_instance, llm_client, tokenizer_instance) - - return kg_instance diff --git a/graphgen/operators/preprocess/resolute_coreference.py b/graphgen/operators/preprocess/resolute_coreference.py index e3c498d..a4da6a8 100644 --- a/graphgen/operators/preprocess/resolute_coreference.py +++ b/graphgen/operators/preprocess/resolute_coreference.py @@ -1,13 +1,13 @@ from typing import List from graphgen.bases.datatypes import Chunk -from graphgen.models import OpenAIModel +from graphgen.models import OpenAIClient from graphgen.templates import COREFERENCE_RESOLUTION_PROMPT from graphgen.utils import detect_main_language async def resolute_coreference( - llm_client: OpenAIModel, chunks: List[Chunk] + llm_client: OpenAIClient, chunks: List[Chunk] ) -> List[Chunk]: """ Resolute conference diff --git a/graphgen/operators/quiz.py b/graphgen/operators/quiz.py index 36edddb..a8623bf 100644 --- a/graphgen/operators/quiz.py +++ b/graphgen/operators/quiz.py @@ -2,17 +2,19 @@ from collections import defaultdict from tqdm.asyncio import tqdm as tqdm_async -from graphgen.models import JsonKVStorage, OpenAIModel, NetworkXStorage -from graphgen.utils import logger, detect_main_language + +from graphgen.models import JsonKVStorage, NetworkXStorage, OpenAIClient from graphgen.templates import DESCRIPTION_REPHRASING_PROMPT +from graphgen.utils import detect_main_language, logger async def quiz( - synth_llm_client: OpenAIModel, - graph_storage: NetworkXStorage, - rephrase_storage: JsonKVStorage, - max_samples: int = 1, - max_concurrent: int = 1000) -> JsonKVStorage: + synth_llm_client: OpenAIClient, + graph_storage: NetworkXStorage, + rephrase_storage: JsonKVStorage, + max_samples: int = 1, + max_concurrent: int = 1000, +) -> JsonKVStorage: """ Get all edges and quiz them @@ -26,11 +28,7 @@ async def quiz( semaphore = asyncio.Semaphore(max_concurrent) - async def _process_single_quiz( - des: str, - prompt: str, - gt: str - ): + async def _process_single_quiz(des: str, prompt: str, gt: str): async with semaphore: try: # 如果在rephrase_storage中已经存在,直接取出 @@ -39,16 +37,14 @@ async def _process_single_quiz( return None new_description = await synth_llm_client.generate_answer( - prompt, - temperature=1 + prompt, temperature=1 ) - return {des: [(new_description, gt)]} + return {des: [(new_description, gt)]} - except Exception as e: # pylint: disable=broad-except + except Exception as e: # pylint: disable=broad-except logger.error("Error when quizzing description %s: %s", 
des, e) return None - edges = await graph_storage.get_all_edges() nodes = await graph_storage.get_all_nodes() @@ -60,41 +56,59 @@ async def _process_single_quiz( description = edge_data["description"] language = "English" if detect_main_language(description) == "en" else "Chinese" - results[description] = [(description, 'yes')] + results[description] = [(description, "yes")] for i in range(max_samples): if i > 0: tasks.append( - _process_single_quiz(description, - DESCRIPTION_REPHRASING_PROMPT[language]['TEMPLATE'].format( - input_sentence=description), 'yes') + _process_single_quiz( + description, + DESCRIPTION_REPHRASING_PROMPT[language]["TEMPLATE"].format( + input_sentence=description + ), + "yes", + ) ) - tasks.append(_process_single_quiz(description, - DESCRIPTION_REPHRASING_PROMPT[language]['ANTI_TEMPLATE'].format( - input_sentence=description), 'no')) + tasks.append( + _process_single_quiz( + description, + DESCRIPTION_REPHRASING_PROMPT[language]["ANTI_TEMPLATE"].format( + input_sentence=description + ), + "no", + ) + ) for node in nodes: node_data = node[1] description = node_data["description"] language = "English" if detect_main_language(description) == "en" else "Chinese" - results[description] = [(description, 'yes')] + results[description] = [(description, "yes")] for i in range(max_samples): if i > 0: tasks.append( - _process_single_quiz(description, - DESCRIPTION_REPHRASING_PROMPT[language]['TEMPLATE'].format( - input_sentence=description), 'yes') + _process_single_quiz( + description, + DESCRIPTION_REPHRASING_PROMPT[language]["TEMPLATE"].format( + input_sentence=description + ), + "yes", + ) + ) + tasks.append( + _process_single_quiz( + description, + DESCRIPTION_REPHRASING_PROMPT[language]["ANTI_TEMPLATE"].format( + input_sentence=description + ), + "no", ) - tasks.append(_process_single_quiz(description, - DESCRIPTION_REPHRASING_PROMPT[language]['ANTI_TEMPLATE'].format( - input_sentence=description), 'no')) + ) for result in tqdm_async( - asyncio.as_completed(tasks), - total=len(tasks), - desc="Quizzing descriptions" + asyncio.as_completed(tasks), total=len(tasks), desc="Quizzing descriptions" ): new_result = await result if new_result: @@ -105,5 +119,4 @@ async def _process_single_quiz( results[key] = list(set(value)) await rephrase_storage.upsert({key: results[key]}) - return rephrase_storage diff --git a/graphgen/operators/read/__init__.py b/graphgen/operators/read/__init__.py new file mode 100644 index 0000000..075ae93 --- /dev/null +++ b/graphgen/operators/read/__init__.py @@ -0,0 +1 @@ +from .read_files import read_files diff --git a/graphgen/operators/read/read_files.py b/graphgen/operators/read/read_files.py new file mode 100644 index 0000000..e1d13a2 --- /dev/null +++ b/graphgen/operators/read/read_files.py @@ -0,0 +1,19 @@ +from graphgen.models import CsvReader, JsonlReader, JsonReader, TxtReader + +_MAPPING = { + "jsonl": JsonlReader, + "json": JsonReader, + "txt": TxtReader, + "csv": CsvReader, +} + + +def read_files(file_path: str): + suffix = file_path.split(".")[-1] + if suffix in _MAPPING: + reader = _MAPPING[suffix]() + else: + raise ValueError( + f"Unsupported file format: {suffix}. 
Supported formats are: {list(_MAPPING.keys())}" + ) + return reader.read(file_path) diff --git a/graphgen/operators/split/__init__.py b/graphgen/operators/split/__init__.py new file mode 100644 index 0000000..2afc738 --- /dev/null +++ b/graphgen/operators/split/__init__.py @@ -0,0 +1 @@ +from .split_chunks import chunk_documents diff --git a/graphgen/operators/split/split_chunks.py b/graphgen/operators/split/split_chunks.py new file mode 100644 index 0000000..caba96a --- /dev/null +++ b/graphgen/operators/split/split_chunks.py @@ -0,0 +1,76 @@ +from functools import lru_cache +from typing import Union + +from tqdm.asyncio import tqdm as tqdm_async + +from graphgen.models import ( + ChineseRecursiveTextSplitter, + RecursiveCharacterSplitter, + Tokenizer, +) +from graphgen.utils import compute_content_hash, detect_main_language + +_MAPPING = { + "en": RecursiveCharacterSplitter, + "zh": ChineseRecursiveTextSplitter, +} + +SplitterT = Union[RecursiveCharacterSplitter, ChineseRecursiveTextSplitter] + + +@lru_cache(maxsize=None) +def _get_splitter(language: str, frozen_kwargs: frozenset) -> SplitterT: + cls = _MAPPING[language] + kwargs = dict(frozen_kwargs) + return cls(**kwargs) + + +def split_chunks(text: str, language: str = "en", **kwargs) -> list: + if language not in _MAPPING: + raise ValueError( + f"Unsupported language: {language}. " + f"Supported languages are: {list(_MAPPING.keys())}" + ) + splitter = _get_splitter(language, frozenset(kwargs.items())) + return splitter.split_text(text) + + +async def chunk_documents( + new_docs: dict, + chunk_size: int = 1024, + chunk_overlap: int = 100, + tokenizer_instance: Tokenizer = None, + progress_bar=None, +) -> dict: + inserting_chunks = {} + cur_index = 1 + doc_number = len(new_docs) + async for doc_key, doc in tqdm_async( + new_docs.items(), desc="[1/4]Chunking documents", unit="doc" + ): + doc_language = detect_main_language(doc["content"]) + text_chunks = split_chunks( + doc["content"], + language=doc_language, + chunk_size=chunk_size, + chunk_overlap=chunk_overlap, + ) + + chunks = { + compute_content_hash(txt, prefix="chunk-"): { + "content": txt, + "full_doc_id": doc_key, + "length": len(tokenizer_instance.encode(txt)) + if tokenizer_instance + else len(txt), + "language": doc_language, + } + for txt in text_chunks + } + inserting_chunks.update(chunks) + + if progress_bar is not None: + progress_bar(cur_index / doc_number, f"Chunking {doc_key}") + cur_index += 1 + + return inserting_chunks diff --git a/graphgen/operators/traverse_graph.py b/graphgen/operators/traverse_graph.py index 16e2d25..ff3faab 100644 --- a/graphgen/operators/traverse_graph.py +++ b/graphgen/operators/traverse_graph.py @@ -6,11 +6,11 @@ from graphgen.models import ( JsonKVStorage, NetworkXStorage, - OpenAIModel, + OpenAIClient, Tokenizer, TraverseStrategy, ) -from graphgen.operators.kg.split_kg import get_batches_with_strategy +from graphgen.operators.build_kg.split_kg import get_batches_with_strategy from graphgen.templates import ( ANSWER_REPHRASING_PROMPT, MULTI_HOP_GENERATION_PROMPT, @@ -30,7 +30,7 @@ async def handle_edge(edge: tuple) -> tuple: if "length" not in edge[2]: edge[2]["length"] = len( await asyncio.get_event_loop().run_in_executor( - None, tokenizer.encode_string, edge[2]["description"] + None, tokenizer.encode, edge[2]["description"] ) ) return edge @@ -40,7 +40,7 @@ async def handle_node(node: dict) -> dict: if "length" not in node[1]: node[1]["length"] = len( await asyncio.get_event_loop().run_in_executor( - None, tokenizer.encode_string, 
node[1]["description"] + None, tokenizer.encode, node[1]["description"] ) ) return node @@ -161,7 +161,7 @@ def _post_process_synthetic_data(data): async def traverse_graph_for_aggregated( - llm_client: OpenAIModel, + llm_client: OpenAIClient, tokenizer: Tokenizer, graph_storage: NetworkXStorage, traverse_strategy: TraverseStrategy, @@ -310,7 +310,7 @@ async def _process_single_batch( # pylint: disable=too-many-branches, too-many-statements async def traverse_graph_for_atomic( - llm_client: OpenAIModel, + llm_client: OpenAIClient, tokenizer: Tokenizer, graph_storage: NetworkXStorage, traverse_strategy: TraverseStrategy, @@ -426,7 +426,7 @@ async def _generate_question(node_or_edge: tuple): async def traverse_graph_for_multi_hop( - llm_client: OpenAIModel, + llm_client: OpenAIClient, tokenizer: Tokenizer, graph_storage: NetworkXStorage, traverse_strategy: TraverseStrategy, diff --git a/graphgen/utils/__init__.py b/graphgen/utils/__init__.py index a3bf496..d56ca73 100644 --- a/graphgen/utils/__init__.py +++ b/graphgen/utils/__init__.py @@ -13,3 +13,5 @@ from .help_nltk import NLTKHelper from .log import logger, parse_log, set_logger from .loop import create_event_loop +from .run_concurrent import run_concurrent +from .wrap import async_to_sync_method diff --git a/graphgen/utils/calculate_confidence.py b/graphgen/utils/calculate_confidence.py index 1b596d9..663e2e4 100644 --- a/graphgen/utils/calculate_confidence.py +++ b/graphgen/utils/calculate_confidence.py @@ -1,34 +1,41 @@ import math from typing import List -from graphgen.models.llm.topk_token_model import Token + +from graphgen.bases.datatypes import Token + def preprocess_tokens(tokens: List[Token]) -> List[Token]: """Preprocess tokens for calculating confidence.""" tokens = [x for x in tokens if x.prob > 0] return tokens + def joint_probability(tokens: List[Token]) -> float: """Calculate joint probability of a list of tokens.""" tokens = preprocess_tokens(tokens) logprob_sum = sum(x.logprob for x in tokens) return math.exp(logprob_sum / len(tokens)) + def min_prob(tokens: List[Token]) -> float: """Calculate the minimum probability of a list of tokens.""" tokens = preprocess_tokens(tokens) return min(x.prob for x in tokens) + def average_prob(tokens: List[Token]) -> float: """Calculate the average probability of a list of tokens.""" tokens = preprocess_tokens(tokens) return sum(x.prob for x in tokens) / len(tokens) + def average_confidence(tokens: List[Token]) -> float: """Calculate the average confidence of a list of tokens.""" tokens = preprocess_tokens(tokens) confidence = [x.prob / sum(y.prob for y in x.top_candidates[:5]) for x in tokens] return sum(confidence) / len(tokens) + def yes_no_loss(tokens_list: List[List[Token]], ground_truth: List[str]) -> float: """Calculate the loss for yes/no question.""" losses = [] @@ -41,7 +48,10 @@ def yes_no_loss(tokens_list: List[List[Token]], ground_truth: List[str]) -> floa losses.append(token.prob) return sum(losses) / len(losses) -def yes_no_loss_entropy(tokens_list: List[List[Token]], ground_truth: List[str]) -> float: + +def yes_no_loss_entropy( + tokens_list: List[List[Token]], ground_truth: List[str] +) -> float: """Calculate the loss for yes/no question using entropy.""" losses = [] for i, tokens in enumerate(tokens_list): diff --git a/graphgen/utils/run_concurrent.py b/graphgen/utils/run_concurrent.py new file mode 100644 index 0000000..43eaf8c --- /dev/null +++ b/graphgen/utils/run_concurrent.py @@ -0,0 +1,38 @@ +import asyncio +from typing import Awaitable, Callable, List, 
Optional, TypeVar + +import gradio as gr +from tqdm.asyncio import tqdm as tqdm_async + +from graphgen.utils.log import logger + +T = TypeVar("T") +R = TypeVar("R") + + +async def run_concurrent( + coro_fn: Callable[[T], Awaitable[R]], + items: List[T], + *, + desc: str = "processing", + unit: str = "item", + progress_bar: Optional[gr.Progress] = None, +) -> List[R]: + tasks = [asyncio.create_task(coro_fn(it)) for it in items] + + results = await tqdm_async.gather(*tasks, desc=desc, unit=unit) + + ok_results = [] + for idx, res in enumerate(results): + if isinstance(res, Exception): + logger.exception("Task failed: %s", res) + if progress_bar: + progress_bar((idx + 1) / len(items), desc=desc) + continue + ok_results.append(res) + if progress_bar: + progress_bar((idx + 1) / len(items), desc=desc) + + if progress_bar: + progress_bar(1.0, desc=desc) + return ok_results diff --git a/graphgen/utils/wrap.py b/graphgen/utils/wrap.py new file mode 100644 index 0000000..57776f2 --- /dev/null +++ b/graphgen/utils/wrap.py @@ -0,0 +1,13 @@ +from functools import wraps +from typing import Any, Callable + +from .loop import create_event_loop + + +def async_to_sync_method(func: Callable) -> Callable: + @wraps(func) + def wrapper(self, *args, **kwargs) -> Any: + loop = create_event_loop() + return loop.run_until_complete(func(self, *args, **kwargs)) + + return wrapper diff --git a/webui/app.py b/webui/app.py index 79f4ce6..8e2ec9f 100644 --- a/webui/app.py +++ b/webui/app.py @@ -9,7 +9,7 @@ from dotenv import load_dotenv from graphgen.graphgen import GraphGen -from graphgen.models import OpenAIModel, Tokenizer +from graphgen.models import OpenAIClient, Tokenizer from graphgen.models.llm.limitter import RPM, TPM from graphgen.utils import set_logger from webui.base import WebuiParams @@ -41,7 +41,7 @@ def init_graph_gen(config: dict, env: dict) -> GraphGen: graph_gen = GraphGen(working_dir=working_dir, config=config) # Set up LLM clients - graph_gen.synthesizer_llm_client = OpenAIModel( + graph_gen.synthesizer_llm_client = OpenAIClient( model_name=env.get("SYNTHESIZER_MODEL", ""), base_url=env.get("SYNTHESIZER_BASE_URL", ""), api_key=env.get("SYNTHESIZER_API_KEY", ""), @@ -50,7 +50,7 @@ def init_graph_gen(config: dict, env: dict) -> GraphGen: tpm=TPM(env.get("TPM", 50000)), ) - graph_gen.trainee_llm_client = OpenAIModel( + graph_gen.trainee_llm_client = OpenAIClient( model_name=env.get("TRAINEE_MODEL", ""), base_url=env.get("TRAINEE_BASE_URL", ""), api_key=env.get("TRAINEE_API_KEY", ""), diff --git a/webui/utils/count_tokens.py b/webui/utils/count_tokens.py index 210bd26..82b5522 100644 --- a/webui/utils/count_tokens.py +++ b/webui/utils/count_tokens.py @@ -45,7 +45,7 @@ def count_tokens(file, tokenizer_name, data_frame): content = item.get("content", "") else: content = item - token_count += len(tokenizer.encode_string(content)) + token_count += len(tokenizer.encode(content)) _update_data = [[str(token_count), str(token_count * 50), "N/A"]]