Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions baselines/Genie/genie.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from dotenv import load_dotenv
from tqdm.asyncio import tqdm as tqdm_async

from graphgen.models import OpenAIModel
from graphgen.models import OpenAIClient
from graphgen.utils import compute_content_hash, create_event_loop

PROMPT_TEMPLATE = """Instruction: Given the next [document], create a [question] and [answer] pair that are grounded \
Expand Down Expand Up @@ -59,7 +59,7 @@ def _post_process(content: str) -> tuple:

@dataclass
class Genie:
llm_client: OpenAIModel = None
llm_client: OpenAIClient = None
max_concurrent: int = 1000

def generate(self, docs: List[List[dict]]) -> List[dict]:
Expand Down Expand Up @@ -121,7 +121,7 @@ async def process_chunk(content: str):

load_dotenv()

llm_client = OpenAIModel(
llm_client = OpenAIClient(
model_name=os.getenv("SYNTHESIZER_MODEL"),
api_key=os.getenv("SYNTHESIZER_API_KEY"),
base_url=os.getenv("SYNTHESIZER_BASE_URL"),
Expand Down
6 changes: 3 additions & 3 deletions baselines/LongForm/longform.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from dotenv import load_dotenv
from tqdm.asyncio import tqdm as tqdm_async

from graphgen.models import OpenAIModel
from graphgen.models import OpenAIClient
from graphgen.utils import compute_content_hash, create_event_loop

PROMPT_TEMPLATE = """Instruction: X
Expand All @@ -23,7 +23,7 @@

@dataclass
class LongForm:
llm_client: OpenAIModel = None
llm_client: OpenAIClient = None
max_concurrent: int = 1000

def generate(self, docs: List[List[dict]]) -> List[dict]:
Expand Down Expand Up @@ -88,7 +88,7 @@ async def process_chunk(content: str):

load_dotenv()

llm_client = OpenAIModel(
llm_client = OpenAIClient(
model_name=os.getenv("SYNTHESIZER_MODEL"),
api_key=os.getenv("SYNTHESIZER_API_KEY"),
base_url=os.getenv("SYNTHESIZER_BASE_URL"),
Expand Down
6 changes: 3 additions & 3 deletions baselines/SELF-QA/self-qa.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from dotenv import load_dotenv
from tqdm.asyncio import tqdm as tqdm_async

from graphgen.models import OpenAIModel
from graphgen.models import OpenAIClient
from graphgen.utils import compute_content_hash, create_event_loop

INSTRUCTION_GENERATION_PROMPT = """The background knowledge is:
Expand Down Expand Up @@ -58,7 +58,7 @@ def _post_process_answers(content: str) -> tuple:

@dataclass
class SelfQA:
llm_client: OpenAIModel = None
llm_client: OpenAIClient = None
max_concurrent: int = 100

def generate(self, docs: List[List[dict]]) -> List[dict]:
Expand Down Expand Up @@ -155,7 +155,7 @@ async def process_chunk(content: str):

load_dotenv()

llm_client = OpenAIModel(
llm_client = OpenAIClient(
model_name=os.getenv("SYNTHESIZER_MODEL"),
api_key=os.getenv("SYNTHESIZER_API_KEY"),
base_url=os.getenv("SYNTHESIZER_BASE_URL"),
Expand Down
6 changes: 3 additions & 3 deletions baselines/Wrap/wrap.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from dotenv import load_dotenv
from tqdm.asyncio import tqdm as tqdm_async

from graphgen.models import OpenAIModel
from graphgen.models import OpenAIClient
from graphgen.utils import compute_content_hash, create_event_loop

PROMPT_TEMPLATE = """A chat between a curious user and an artificial intelligence assistant.
Expand Down Expand Up @@ -46,7 +46,7 @@ def _post_process(content: str) -> list:

@dataclass
class Wrap:
llm_client: OpenAIModel = None
llm_client: OpenAIClient = None
max_concurrent: int = 1000

def generate(self, docs: List[List[dict]]) -> List[dict]:
Expand Down Expand Up @@ -108,7 +108,7 @@ async def process_chunk(content: str):

load_dotenv()

llm_client = OpenAIModel(
llm_client = OpenAIClient(
model_name=os.getenv("SYNTHESIZER_MODEL"),
api_key=os.getenv("SYNTHESIZER_API_KEY"),
base_url=os.getenv("SYNTHESIZER_BASE_URL"),
Expand Down
12 changes: 12 additions & 0 deletions graphgen/bases/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
from .base_kg_builder import BaseKGBuilder
from .base_llm_client import BaseLLMClient
from .base_reader import BaseReader
from .base_splitter import BaseSplitter
from .base_storage import (
BaseGraphStorage,
BaseKVStorage,
BaseListStorage,
StorageNameSpace,
)
from .base_tokenizer import BaseTokenizer
from .datatypes import Chunk, QAPair, Token
41 changes: 41 additions & 0 deletions graphgen/bases/base_kg_builder.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
from abc import ABC, abstractmethod
from collections import defaultdict
from dataclasses import dataclass, field
from typing import Dict, List, Tuple

from graphgen.bases.base_llm_client import BaseLLMClient
from graphgen.bases.base_storage import BaseGraphStorage
from graphgen.bases.datatypes import Chunk


@dataclass
class BaseKGBuilder(ABC):
    """
    Abstract base for knowledge-graph builders.

    Holds the graph storage to write into and the LLM client to query, plus
    two per-instance accumulators for extracted data. Subclasses provide the
    asynchronous extraction and node-merging steps.
    """

    # Graph storage backend this builder operates on.
    kg_instance: BaseGraphStorage
    # LLM client available to the extraction / merge steps.
    llm_client: BaseLLMClient

    # Accumulators created fresh per instance; missing keys start as empty lists.
    # NOTE(review): presumably keyed by node name / (source, target) pair — confirm
    # against the concrete builders.
    _nodes: Dict[str, List[dict]] = field(default_factory=lambda: defaultdict(list))
    _edges: Dict[Tuple[str, str], List[dict]] = field(
        default_factory=lambda: defaultdict(list)
    )

    def build(self, chunks: List[Chunk]) -> None:
        """Placeholder entry point; the base implementation does nothing."""
        return None

    @abstractmethod
    async def extract_all(self, chunks: List[Chunk]) -> None:
        """Extract nodes and edges from every chunk in ``chunks``."""
        raise NotImplementedError

    @abstractmethod
    async def extract(
        self, chunk: Chunk
    ) -> Tuple[Dict[str, List[dict]], Dict[Tuple[str, str], List[dict]]]:
        """Extract nodes and edges from one chunk, returned as ``(nodes, edges)``."""
        raise NotImplementedError

    @abstractmethod
    async def merge_nodes(
        self, nodes_data: Dict[str, List[dict]], kg_instance: BaseGraphStorage, llm
    ) -> None:
        """Merge the extracted ``nodes_data`` into the knowledge graph ``kg_instance``."""
        raise NotImplementedError
74 changes: 74 additions & 0 deletions graphgen/bases/base_llm_client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
from __future__ import annotations

import abc
import re
from typing import Any, List, Optional

from graphgen.bases.base_tokenizer import BaseTokenizer
from graphgen.bases.datatypes import Token


class BaseLLMClient(abc.ABC):
    """
    LLM client base class, agnostic to specific backends (OpenAI / Ollama / ...).

    Stores common sampling parameters and an optional tokenizer. Concrete
    clients implement the three asynchronous generation methods.
    """

    def __init__(
        self,
        *,
        system_prompt: str = "",
        temperature: float = 0.0,
        max_tokens: int = 4096,
        repetition_penalty: float = 1.05,
        top_p: float = 0.95,
        top_k: int = 50,
        tokenizer: Optional[BaseTokenizer] = None,
        **kwargs: Any,
    ):
        """
        Store the generation settings on the instance.

        All arguments are keyword-only. Any additional keyword arguments are
        attached to the instance as attributes of the same name.
        """
        self.system_prompt = system_prompt
        self.temperature = temperature
        self.max_tokens = max_tokens
        self.repetition_penalty = repetition_penalty
        self.top_p = top_p
        self.top_k = top_k
        self.tokenizer = tokenizer

        # Backend-specific options are kept as plain attributes.
        for k, v in kwargs.items():
            setattr(self, k, v)

    @abc.abstractmethod
    async def generate_answer(
        self, text: str, history: Optional[List[str]] = None, **extra: Any
    ) -> str:
        """Generate answer from the model."""
        raise NotImplementedError

    @abc.abstractmethod
    async def generate_topk_per_token(
        self, text: str, history: Optional[List[str]] = None, **extra: Any
    ) -> List[Token]:
        """Generate top-k tokens for the next token prediction."""
        raise NotImplementedError

    @abc.abstractmethod
    async def generate_inputs_prob(
        self, text: str, history: Optional[List[str]] = None, **extra: Any
    ) -> List[Token]:
        """Generate probabilities for each token in the input."""
        raise NotImplementedError

    def count_tokens(self, text: str) -> int:
        """
        Count the number of tokens in the text.

        Raises:
            ValueError: if no tokenizer was supplied to the client.
        """
        if self.tokenizer is None:
            raise ValueError("Tokenizer is not set. Please provide a tokenizer to use count_tokens.")
        return len(self.tokenizer.encode(text))

    @staticmethod
    def filter_think_tags(text: str, think_tag: str = "think") -> str:
        """
        Remove ``<think_tag>...</think_tag>`` sections from the text.

        Every (non-greedy, possibly multi-line) span between an opening and a
        closing tag is removed together with the tags themselves. If nothing
        is left after removal, the original text (stripped) is returned
        instead of an empty string.
        """
        # Escape the tag so regex metacharacters in it are matched literally
        # (e.g. think_tag="a.b" must not match "<axb>").
        tag = re.escape(think_tag)
        think_pattern = re.compile(rf"<{tag}>.*?</{tag}>", re.DOTALL)
        filtered_text = think_pattern.sub("", text).strip()
        return filtered_text if filtered_text else text.strip()
44 changes: 44 additions & 0 deletions graphgen/bases/base_tokenizer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
from __future__ import annotations

from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import List


@dataclass
class BaseTokenizer(ABC):
    """
    Abstract tokenizer interface.

    Subclasses implement ``encode`` and ``decode``; token counting and
    fixed-size chunking are built on top of those two methods.
    """

    # Identifier of the tokenizer / encoding to use.
    model_name: str = "cl100k_base"

    @abstractmethod
    def encode(self, text: str) -> List[int]:
        """Encode text -> token ids."""
        raise NotImplementedError

    @abstractmethod
    def decode(self, token_ids: List[int]) -> str:
        """Decode token ids -> text."""
        raise NotImplementedError

    def count_tokens(self, text: str) -> int:
        """Return the number of token ids ``encode`` produces for ``text``."""
        return len(self.encode(text))

    def chunk_by_token_size(
        self,
        content: str,
        *,
        overlap_token_size: int = 128,
        max_token_size: int = 1024,
    ) -> List[dict]:
        """
        Split ``content`` into overlapping windows of at most ``max_token_size`` tokens.

        Consecutive windows start ``max_token_size - overlap_token_size`` tokens
        apart. Each result is a dict with keys ``"tokens"`` (window length),
        ``"content"`` (decoded, stripped text) and ``"chunk_order_index"``.
        Empty content yields an empty list.

        Raises:
            ValueError: if ``overlap_token_size`` is negative or not smaller
                than ``max_token_size``.
        """
        # A non-positive step would either crash inside range() or silently
        # return no chunks; a negative overlap would skip tokens between
        # windows. Reject both up front with a clear message.
        if not 0 <= overlap_token_size < max_token_size:
            raise ValueError(
                "overlap_token_size must satisfy 0 <= overlap_token_size < max_token_size "
                f"(got overlap_token_size={overlap_token_size}, max_token_size={max_token_size})"
            )

        tokens = self.encode(content)
        step = max_token_size - overlap_token_size
        results = []
        for index, start in enumerate(range(0, len(tokens), step)):
            chunk_ids = tokens[start : start + max_token_size]
            results.append(
                {
                    "tokens": len(chunk_ids),
                    "content": self.decode(chunk_ids).strip(),
                    "chunk_order_index": index,
                }
            )
        return results
14 changes: 14 additions & 0 deletions graphgen/bases/datatypes.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import math
from dataclasses import dataclass, field
from typing import List, Union


@dataclass
Expand All @@ -16,3 +18,15 @@ class QAPair:

question: str
answer: str


@dataclass
class Token:
    """A token together with its probability and optional extra statistics."""

    # Token text.
    text: str
    # Probability assigned to this token.
    prob: float
    # Alternative candidates for this position; a fresh empty list per instance.
    top_candidates: List = field(default_factory=list)
    # Optional perplexity value; None when not provided.
    ppl: Union[float, None] = field(default=None)

    @property
    def logprob(self) -> float:
        """
        Natural logarithm of ``prob``.

        Returns ``-inf`` when ``prob`` is exactly zero (which can occur after
        floating-point underflow) instead of raising a math domain error.

        Raises:
            ValueError: if ``prob`` is negative.
        """
        if self.prob == 0:
            return -math.inf
        return math.log(self.prob)
Loading