Skip to content

Commit

Permalink
Merge pull request #185 from stanford-oval/costorm-integration
Browse files Browse the repository at this point in the history
Costorm integration
Yucheng-Jiang authored Sep 25, 2024
2 parents 33a03a3 + efac123 commit 564a507
Showing 45 changed files with 5,191 additions and 270 deletions.
207 changes: 166 additions & 41 deletions README.md

Large diffs are not rendered by default.

Binary file added assets/co-storm-workflow.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
241 changes: 241 additions & 0 deletions examples/costorm_examples/run_costorm_gpt.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,241 @@
"""
Co-STORM pipeline powered by GPT-4o/4o-mini and Bing search engine.
You need to set up the following environment variables to run this script:
- OPENAI_API_KEY: OpenAI API key
- OPENAI_API_TYPE: OpenAI API type (e.g., 'openai' or 'azure')
- AZURE_API_BASE: Azure API base URL if using Azure API
- AZURE_API_VERSION: Azure API version if using Azure API
- BING_SEARCH_API_KEY: Biang search API key; BING_SEARCH_API_KEY: Bing Search API key, SERPER_API_KEY: Serper API key, BRAVE_API_KEY: Brave API key, or TAVILY_API_KEY: Tavily API key
Output will be structured as below
args.output_dir/
log.json # Log of information-seeking conversation
report.txt # Final article generated
"""

import os
import json
from argparse import ArgumentParser
from knowledge_storm.collaborative_storm.engine import CollaborativeStormLMConfigs, RunnerArgument, CoStormRunner
from knowledge_storm.collaborative_storm.modules.callback import LocalConsolePrintCallBackHandler
from knowledge_storm.lm import OpenAIModel, AzureOpenAIModel
from knowledge_storm.logging_wrapper import LoggingWrapper
from knowledge_storm.rm import YouRM, BingSearch, BraveRM, SerperRM, DuckDuckGoSearchRM, TavilySearchRM, SearXNG
from knowledge_storm.utils import load_api_key


def main(args):
load_api_key(toml_file_path='secrets.toml')
lm_config: CollaborativeStormLMConfigs = CollaborativeStormLMConfigs()
openai_kwargs = {
"api_key": os.getenv("OPENAI_API_KEY"),
"api_provider": "openai",
"temperature": 1.0,
"top_p": 0.9,
"api_base": None,
} if os.getenv('OPENAI_API_TYPE') == 'openai' else {
"api_key": os.getenv("AZURE_API_KEY"),
"temperature": 1.0,
"top_p": 0.9,
"api_base": os.getenv("AZURE_API_BASE"),
"api_version": os.getenv("AZURE_API_VERSION"),
}

ModelClass = OpenAIModel if os.getenv('OPENAI_API_TYPE') == 'openai' else AzureOpenAIModel
# If you are using Azure service, make sure the model name matches your own deployed model name.
# The default name here is only used for demonstration and may not match your case.
gpt_4o_mini_model_name = 'gpt-4o-mini'
gpt_4o_model_name = 'gpt-4o'
if os.getenv('OPENAI_API_TYPE') == 'azure':
openai_kwargs['api_base'] = os.getenv('AZURE_API_BASE')
openai_kwargs['api_version'] = os.getenv('AZURE_API_VERSION')

# STORM is a LM system so different components can be powered by different models.
# For a good balance between cost and quality, you can choose a cheaper/faster model for conv_simulator_lm
# which is used to split queries, synthesize answers in the conversation. We recommend using stronger models
# for outline_gen_lm which is responsible for organizing the collected information, and article_gen_lm
# which is responsible for generating sections with citations.
question_answering_lm = ModelClass(model=gpt_4o_model_name, max_tokens=1000, **openai_kwargs)
discourse_manage_lm = ModelClass(model=gpt_4o_model_name, max_tokens=500, **openai_kwargs)
utterance_polishing_lm = ModelClass(model=gpt_4o_model_name, max_tokens=2000, **openai_kwargs)
warmstart_outline_gen_lm = ModelClass(model=gpt_4o_model_name, max_tokens=500, **openai_kwargs)
question_asking_lm = ModelClass(model=gpt_4o_model_name, max_tokens=300, **openai_kwargs)
knowledge_base_lm = ModelClass(model=gpt_4o_model_name, max_tokens=1000, **openai_kwargs)

lm_config.set_question_answering_lm(question_answering_lm)
lm_config.set_discourse_manage_lm(discourse_manage_lm)
lm_config.set_utterance_polishing_lm(utterance_polishing_lm)
lm_config.set_warmstart_outline_gen_lm(warmstart_outline_gen_lm)
lm_config.set_question_asking_lm(question_asking_lm)
lm_config.set_knowledge_base_lm(knowledge_base_lm)

topic = input('Topic: ')
runner_argument = RunnerArgument(
topic=topic,
retrieve_top_k=args.retrieve_top_k,
max_search_queries=args.max_search_queries,
total_conv_turn=args.total_conv_turn,
max_search_thread=args.max_search_thread,
max_search_queries_per_turn=args.max_search_queries_per_turn,
warmstart_max_num_experts=args.warmstart_max_num_experts,
warmstart_max_turn_per_experts=args.warmstart_max_turn_per_experts,
warmstart_max_thread=args.warmstart_max_thread,
max_thread_num=args.max_thread_num,
max_num_round_table_experts=args.max_num_round_table_experts,
moderator_override_N_consecutive_answering_turn=args.moderator_override_N_consecutive_answering_turn,
node_expansion_trigger_count=args.node_expansion_trigger_count)
logging_wrapper = LoggingWrapper(lm_config)
callback_handler = LocalConsolePrintCallBackHandler() if args.enable_log_print else None

# Co-STORM is a knowledge curation system which consumes information from the retrieval module.
# Currently, the information source is the Internet and we use search engine API as the retrieval module.
match args.retriever:
case 'bing':
rm = BingSearch(bing_search_api=os.getenv('BING_SEARCH_API_KEY'), k=runner_argument.retrieve_top_k)
case 'you':
rm = YouRM(ydc_api_key=os.getenv('YDC_API_KEY'), k=runner_argument.retrieve_top_k)
case 'brave':
rm = BraveRM(brave_search_api_key=os.getenv('BRAVE_API_KEY'), k=runner_argument.retrieve_top_k)
case 'duckduckgo':
rm = DuckDuckGoSearchRM(k=runner_argument.retrieve_top_k, safe_search='On', region='us-en')
case 'serper':
rm = SerperRM(serper_search_api_key=os.getenv('SERPER_API_KEY'), query_params={'autocorrect': True, 'num': 10, 'page': 1})
case 'tavily':
rm = TavilySearchRM(tavily_search_api_key=os.getenv('TAVILY_API_KEY'), k=runner_argument.retrieve_top_k, include_raw_content=True)
case 'searxng':
rm = SearXNG(searxng_api_key=os.getenv('SEARXNG_API_KEY'), k=runner_argument.retrieve_top_k)
case _:
raise ValueError(f'Invalid retriever: {args.retriever}. Choose either "bing", "you", "brave", "duckduckgo", "serper", "tavily", or "searxng"')

costorm_runner = CoStormRunner(lm_config=lm_config,
runner_argument=runner_argument,
logging_wrapper=logging_wrapper,
rm=rm,
callback_handler=callback_handler)

# warm start the system
costorm_runner.warm_start()

# Below is an example of how users may interact with Co-STORM to seek information together
# In actual deployment, we suggest allowing the user to decide whether to observe the agent utterance or inject a turn

# observing Co-STORM LLM agent utterance for 5 turns
for _ in range(1):
conv_turn = costorm_runner.step()
print(f"**{conv_turn.role}**: {conv_turn.utterance}\n")

# active engaging by injecting your utterance
your_utterance = input('Your utterance: ')
costorm_runner.step(user_utterance=your_utterance)

# continue observing
conv_turn = costorm_runner.step()
print(f"**{conv_turn.role}**: {conv_turn.utterance}\n")

# generate report
costorm_runner.knowledge_base.reogranize()
article = costorm_runner.generate_report()

# save results
os.makedirs(args.output_dir, exist_ok=True)

# Save article
with open(os.path.join(args.output_dir, "report.md"), "w") as f:
f.write(article)

# Save logging
log_dump = costorm_runner.dump_logging_and_reset()
with open(os.path.join(args.output_dir, "log.json"), "w") as f:
json.dump(log_dump, f, indent=2)


if __name__ == '__main__':
parser = ArgumentParser()
# global arguments
parser.add_argument('--output-dir', type=str, default='./results/co-storm',
help='Directory to store the outputs.')
parser.add_argument('--retriever', type=str, choices=['bing', 'you', 'brave', 'serper', 'duckduckgo', 'tavily', 'searxng'],
help='The search engine API to use for retrieving information.')
# hyperparameters for co-storm
parser.add_argument(
'--retrieve_top_k',
type=int,
default=10,
help='Retrieve top k results for each query in retriever.'
)
parser.add_argument(
'--max_search_queries',
type=int,
default=2,
help='Maximum number of search queries to consider for each question.'
)
parser.add_argument(
'--total_conv_turn',
type=int,
default=20,
help='Maximum number of turns in conversation.'
)
parser.add_argument(
'--max_search_thread',
type=int,
default=5,
help='Maximum number of parallel threads for retriever.'
)
parser.add_argument(
'--max_search_queries_per_turn',
type=int,
default=3,
help='Maximum number of search queries to consider in each turn.'
)
parser.add_argument(
'--warmstart_max_num_experts',
type=int,
default=3,
help='Max number of experts in perspective-guided QA during warm start.'
)
parser.add_argument(
'--warmstart_max_turn_per_experts',
type=int,
default=2,
help='Max number of turns per perspective during warm start.'
)
parser.add_argument(
'--warmstart_max_thread',
type=int,
default=3,
help='Max number of threads for parallel perspective-guided QA during warm start.'
)
parser.add_argument(
'--max_thread_num',
type=int,
default=10,
help=("Maximum number of threads to use. "
"Consider reducing it if you keep getting 'Exceed rate limit' errors when calling the LM API.")
)
parser.add_argument(
'--max_num_round_table_experts',
type=int,
default=2,
help='Max number of active experts in round table discussion.'
)
parser.add_argument(
'--moderator_override_N_consecutive_answering_turn',
type=int,
default=3,
help=('Number of consecutive expert answering turns before the moderator overrides the conversation.')
)
parser.add_argument(
'--node_expansion_trigger_count',
type=int,
default=10,
help='Trigger node expansion for nodes that contain more than N snippets.'
)

# Boolean flags
parser.add_argument(
'--enable_log_print',
action='store_true',
help='If set, enable console log print.'
)

main(parser.parse_args())
12 changes: 6 additions & 6 deletions examples/README.md → examples/storm_examples/README.md
Original file line number Diff line number Diff line change
@@ -11,7 +11,7 @@ We host a number of example scripts for various customization of STORM (e.g., us
2. Run the following command under the root directory of the repository:

```
python examples/run_storm_wiki_mistral.py \
python examples/storm_examples/run_storm_wiki_mistral.py \
--url $URL \
--port $PORT \
--output-dir $OUTPUT_DIR \
@@ -50,7 +50,7 @@ By default, STORM is grounded on the Internet using the search engine, but it ca
To create the vector store offline, run
```
python examples/run_storm_wiki_gpt_with_VectorRM.py \
python examples/storm_examples/run_storm_wiki_gpt_with_VectorRM.py \
--output-dir $OUTPUT_DIR \
--vector-db-mode offline \
--offline-vector-db-dir $OFFLINE_VECTOR_DB_DIR \
@@ -65,7 +65,7 @@ By default, STORM is grounded on the Internet using the search engine, but it ca
To create the vector store online on a Qdrant server, run
```
python examples/run_storm_wiki_gpt_with_VectorRM.py \
python examples/storm_examples/run_storm_wiki_gpt_with_VectorRM.py \
--output-dir $OUTPUT_DIR \
--vector-db-mode online \
--online-vector-db-url $ONLINE_VECTOR_DB_URL \
@@ -83,12 +83,12 @@ By default, STORM is grounded on the Internet using the search engine, but it ca
- Run the following command under the root directory to downsample the dataset by filtering papers with terms `[cs.CV]` and get a csv file that match the format mentioned above.
```
python examples/helper/process_kaggle_arxiv_abstract_dataset.py --input-path $PATH_TO_THE_DOWNLOADED_FILE --output-path $PATH_TO_THE_PROCESSED_CSV
python examples/storm_examples/helper/process_kaggle_arxiv_abstract_dataset.py --input-path $PATH_TO_THE_DOWNLOADED_FILE --output-path $PATH_TO_THE_PROCESSED_CSV
```
- Run the following command to run STORM grounding on the processed dataset. You can input a topic related to computer vision (e.g., "The progress of multimodal models in computer vision") to see the generated article. (Note that the generated article may not include enough details since the quick test only use the abstracts of arxiv papers.)
```
python examples/run_storm_wiki_gpt_with_VectorRM.py \
python examples/storm_examples/run_storm_wiki_gpt_with_VectorRM.py \
--output-dir $OUTPUT_DIR \
--vector-db-mode offline \
--offline-vector-db-dir $OFFLINE_VECTOR_DB_DIR \
@@ -102,7 +102,7 @@ By default, STORM is grounded on the Internet using the search engine, but it ca
- For a quicker run, you can also download the pre-embedded vector store directly from [here](https://drive.google.com/file/d/1bijFkw5BKU7bqcmXMhO-5hg2fdKAL9bf/view?usp=share_link).
```
python examples/run_storm_wiki_gpt_with_VectorRM.py \
python examples/storm_examples/run_storm_wiki_gpt_with_VectorRM.py \
--output-dir $OUTPUT_DIR \
--vector-db-mode offline \
--offline-vector-db-dir $DOWNLOADED_VECTOR_DB_DR \
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
Original file line number Diff line number Diff line change
@@ -27,10 +27,8 @@
"""

import os
import sys
from argparse import ArgumentParser

sys.path.append('./')
from knowledge_storm import STORMWikiRunnerArguments, STORMWikiRunner, STORMWikiLMConfigs
from knowledge_storm.rm import VectorRM
from knowledge_storm.lm import OpenAIModel, AzureOpenAIModel
Original file line number Diff line number Diff line change
@@ -18,17 +18,10 @@
"""

import os
import sys
import re
import logging
from argparse import ArgumentParser

from knowledge_storm import STORMWikiRunnerArguments, STORMWikiRunner, STORMWikiLMConfigs
# Get the absolute path to the directory containing lm.py
lm_path = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', 'knowledge_storm'))

# Add this path to sys.path
sys.path.insert(0, lm_path)

# Now import lm directly
import lm
File renamed without changes.
Original file line number Diff line number Diff line change
@@ -21,7 +21,6 @@

from dspy import Example

sys.path.append('./src')
from knowledge_storm.lm import OllamaClient
from knowledge_storm.rm import YouRM, BingSearch, BraveRM, SerperRM, DuckDuckGoSearchRM, TavilySearchRM, SearXNG
from knowledge_storm import STORMWikiRunnerArguments, STORMWikiRunner, STORMWikiLMConfigs
File renamed without changes.
15 changes: 9 additions & 6 deletions knowledge_storm/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
from .storm_wiki.engine import (
STORMWikiLMConfigs,
STORMWikiRunnerArguments,
STORMWikiRunner,
)
from .storm_wiki import *
from .collaborative_storm import *
from .encoder import *
from .interface import *
from .lm import *
from .rm import *
from .utils import *
from .dataclass import *

__version__ = "0.2.8"
__version__ = "1.0.0"
2 changes: 2 additions & 0 deletions knowledge_storm/collaborative_storm/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
from .modules import *
from .engine import *
745 changes: 745 additions & 0 deletions knowledge_storm/collaborative_storm/engine.py

Large diffs are not rendered by default.

8 changes: 8 additions & 0 deletions knowledge_storm/collaborative_storm/modules/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
from .article_generation import *
from .grounded_question_answering import *
from .grounded_question_generation import *
from .information_insertion_module import *
from .simulate_user import *
from .warmstart_hierarchical_chat import *
from .knowledge_base_summary import *
from .costorm_expert_utterance_generator import *
123 changes: 123 additions & 0 deletions knowledge_storm/collaborative_storm/modules/article_generation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
import dspy
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import Set, Union

from .collaborative_storm_utils import clean_up_section
from ...dataclass import KnowledgeBase, KnowledgeNode


class ArticleGenerationModule(dspy.Module):
"""Use the information collected from the information-seeking conversation to write a section."""

def __init__(
self,
engine: Union[dspy.dsp.LM, dspy.dsp.HFModel],
):
super().__init__()
self.write_section = dspy.Predict(WriteSection)
self.engine = engine

def _get_cited_information_string(
self,
all_citation_index: Set[int],
knowledge_base: KnowledgeBase,
max_words: int = 1500,
):
information = []
cur_word_count = 0
for index in sorted(list(all_citation_index)):
info = knowledge_base.info_uuid_to_info_dict[index]
snippet = info.snippets[0]
info_text = f"[{index}]: {snippet} (Question: {info.meta['question']}. Query: {info.meta['query']})"
cur_snippet_length = len(info_text.split())
if cur_snippet_length + cur_word_count > max_words:
break
cur_word_count += cur_snippet_length
information.append(info_text)
return "\n".join(information)

def gen_section(
self, topic: str, node: KnowledgeNode, knowledge_base: KnowledgeBase
):
if node is None or len(node.content) == 0:
return ""
if (
node.synthesize_output is not None
and node.synthesize_output
and not node.need_regenerate_synthesize_output
):
return node.synthesize_output
all_citation_index = node.collect_all_content()
information = self._get_cited_information_string(
all_citation_index=all_citation_index, knowledge_base=knowledge_base
)
with dspy.settings.context(lm=self.engine):
synthesize_output = clean_up_section(
self.write_section(
topic=topic, info=information, section=node.name
).output
)
node.synthesize_output = synthesize_output
node.need_regenerate_synthesize_output = False
return node.synthesize_output

def forward(self, knowledge_base: KnowledgeBase):
all_nodes = knowledge_base.collect_all_nodes()
node_to_paragraph = {}

# Define a function to generate paragraphs for nodes
def _node_generate_paragraph(node):
node_gen_paragraph = self.gen_section(
topic=knowledge_base.topic, node=node, knowledge_base=knowledge_base
)
lines = node_gen_paragraph.split("\n")
if lines[0].strip().replace("*", "").replace("#", "") == node.name:
lines = lines[1:]
node_gen_paragraph = "\n".join(lines)
path = " -> ".join(node.get_path_from_root())
return path, node_gen_paragraph

with ThreadPoolExecutor(max_workers=5) as executor:
# Submit all tasks
future_to_node = {
executor.submit(_node_generate_paragraph, node): node
for node in all_nodes
}

# Collect the results as they complete
for future in as_completed(future_to_node):
path, node_gen_paragraph = future.result()
node_to_paragraph[path] = node_gen_paragraph

def helper(cur_root, level):
to_return = []
if cur_root is not None:
hash_tag = "#" * level + " "
cur_path = " -> ".join(cur_root.get_path_from_root())
node_gen_paragraph = node_to_paragraph[cur_path]
to_return.append(f"{hash_tag}{cur_root.name}\n{node_gen_paragraph}")
for child in cur_root.children:
to_return.extend(helper(child, level + 1))
return to_return

to_return = []
for child in knowledge_base.root.children:
to_return.extend(helper(child, level=1))

return "\n".join(to_return)


class WriteSection(dspy.Signature):
"""Write a Wikipedia section based on the collected information. You will be given the topic, the section you are writing and relevant information.
Each information will be provided with the raw content along with question and query lead to that information.
Here is the format of your writing:
Use [1], [2], ..., [n] in line (for example, "The capital of the United States is Washington, D.C.[1][3]."). You DO NOT need to include a References or Sources section to list the sources at the end.
"""

info = dspy.InputField(prefix="The collected information:\n", format=str)
topic = dspy.InputField(prefix="The topic of the page: ", format=str)
section = dspy.InputField(prefix="The section you need to write: ", format=str)
output = dspy.OutputField(
prefix="Write the section with proper inline citations (Start your writing. Don't include the page title, section name, or try to write other sections. Do not start the section with topic name.):\n",
format=str,
)
110 changes: 110 additions & 0 deletions knowledge_storm/collaborative_storm/modules/callback.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
from typing import List
from ...interface import Information


class BaseCallbackHandler:
"""Base callback handler to manage callbacks from the Co-STORM pipeline."""

def on_turn_policy_planning_start(self, **kwargs):
"""Run when the turn policy planning begins, before deciding the direction or goal for the next conversation turn."""
pass

def on_expert_action_planning_start(self, **kwargs):
"""Run when the expert action planning begins, preparing to determine the actions that each expert should take."""
pass

def on_expert_action_planning_end(self, **kwargs):
"""Run when the expert action planning ends, after deciding the actions that each expert should take."""
pass

def on_expert_information_collection_start(self, **kwargs):
"""Run when the expert information collection starts, start gathering all necessary data from selected sources."""
pass

def on_expert_information_collection_end(self, info: List[Information], **kwargs):
"""Run when the expert information collection ends, after gathering all necessary data from selected sources."""
pass

def on_expert_utterance_generation_end(self, **kwargs):
"""Run when the expert utterance generation ends, before creating responses or statements from each expert."""
pass

def on_expert_utterance_polishing_start(self, **kwargs):
"""Run when the expert utterance polishing begins, to refine and improve the clarity and coherence of generated content."""
pass

def on_mindmap_insert_start(self, **kwargs):
"""Run when the process of inserting new information into the mindmap starts."""
pass

def on_mindmap_insert_end(self, **kwargs):
"""Run when the process of inserting new information into the mindmap ends."""
pass

def on_mindmap_reorg_start(self, **kwargs):
"""Run when the reorganization of the mindmap begins, to restructure and optimize the flow of information."""
pass

def on_expert_list_update_start(self, **kwargs):
"""Run when the expert list update starts, to modify or refresh the list of active experts."""
pass

def on_article_generation_start(self, **kwargs):
"""Run when the article generation process begins, to compile and format the final article content."""
pass

def on_warmstart_update(self, message, **kwargs):
"""Run when the warm start process has update."""
pass


class LocalConsolePrintCallBackHandler(BaseCallbackHandler):
def __init__(self):
pass

def on_turn_policy_planning_start(self, **kwargs):
"""Run when the turn policy planning begins, before deciding the direction or goal for the next conversation turn."""
print("Start planning next expert; inspect mind map; inspect system state.")

def on_expert_action_planning_start(self, **kwargs):
"""Run when the expert action planning begins, preparing to determine the actions that each expert should take."""
print("Reviewing discourse history; Deciding utterance intent.")

def on_expert_information_collection_start(self, **kwargs):
"""Run when the expert information collection ends, after gathering all necessary data from selected sources."""
print("Start searching with the search engine; browsing collected information.")

def on_expert_information_collection_end(self, info: List[Information], **kwargs):
"""Run when the expert information collection ends, after gathering all necessary data from selected sources."""
if info:
urls = [i.url for i in info]
information_string = "\n".join([f"Finish browsing {url}" for url in urls])
print(information_string)

def on_expert_utterance_generation_end(self, **kwargs):
"""Run when the expert utterance generation ends, before creating responses or statements from each expert."""
print("Finish generating utterance from collected information.")

def on_expert_utterance_polishing_start(self, **kwargs):
"""Run when the expert utterance polishing begins, to refine and improve the clarity and coherence of generated content."""
print("Start polishing utterance.")

def on_mindmap_insert_start(self, **kwargs):
"""Run when the process of inserting new information into the mindmap starts."""
print("Start inserting information into mind map.")

def on_mindmap_insert_end(self, **kwargs):
"""Run when the process of inserting new information into the mindmap ends."""
print("Finish inserting information into mind map.")

def on_mindmap_reorg_start(self, **kwargs):
"""Run when the reorganization of the mindmap begins, to restructure and optimize the flow of information."""
print("Start re-organizing mind map.")

def on_expert_list_update_start(self, **kwargs):
"""Run when the expert list update starts, to modify or refresh the list of active experts."""
print("Start updating expert candidates.")

def on_warmstart_update(self, message, **kwargs):
"""Run when the warm start process has update."""
print(f"Warm start update: {message}")
381 changes: 381 additions & 0 deletions knowledge_storm/collaborative_storm/modules/co_storm_agents.py

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
@@ -0,0 +1,261 @@
import dspy
import os
import re
import sys
import toml
from typing import List, Tuple, Dict, Optional, TYPE_CHECKING

if TYPE_CHECKING:
from ..engine import RunnerArgument
from ...interface import Information, Retriever, LMConfigs
from ...logging_wrapper import LoggingWrapper
from ...rm import BingSearch


def extract_storm_info_snippet(info: Information, snippet_index: int) -> Information:
"""
Constructs a new Information instance with only the specified snippet index.
Args:
storm_info (Information): The original Information instance.
snippet_index (int): The index of the snippet to retain.
Returns:
Information: A new Information instance with only the specified snippet.
"""
if snippet_index < 0 or snippet_index >= len(info.snippets):
raise ValueError("Snippet index out of range")

new_snippets = [info.snippets[snippet_index]]
new_storm_info = Information(
info.url, info.description, new_snippets, info.title, info.meta
)
return new_storm_info


def format_search_results(
searched_results: List[Information],
info_max_num_words: int = 1000,
mode: str = "brief",
) -> Tuple[str, Dict[int, Information]]:
"""
Constructs a string from a list of search results with a specified word limit and returns a mapping of indices to Information.
Args:
searched_results (List[Information]): List of Information objects to process.
info_max_num_words (int, optional): Maximum number of words allowed in the output string. Defaults to 1000.
mode (str, optional): Mode of summarization. 'brief' takes only the first snippet of each Information.
'extensive' adds snippets iteratively until the word limit is reached. Defaults to 'brief'.
Returns:
Tuple[str, Dict[int, Information]]:
- Formatted string with search results, constrained by the word limit.
- Dictionary mapping indices to the corresponding Information objects.
"""
total_length = 0

extracted_snippet_queue = []
max_snippets = (
max(len(info.snippets) for info in searched_results) if searched_results else 0
)
max_snippets = 1 if mode == "brief" else max_snippets
abort = False
included_snippets = set()
for i in range(max_snippets):
for info in searched_results:
if i < len(info.snippets) and not abort:
cur_snippet = info.snippets[i]
cur_snippet_len = len(info.snippets[i].split())
if total_length + cur_snippet_len > info_max_num_words:
abort = True
break
if cur_snippet not in included_snippets:
included_snippets.add(cur_snippet)
info = extract_storm_info_snippet(info, snippet_index=i)
extracted_snippet_queue.append(info)
total_length += cur_snippet_len
output = []
index_mapping = {}
for idx, info in enumerate(extracted_snippet_queue):
output.append(f"[{idx + 1}]: {info.snippets[0]}")
index_mapping[idx + 1] = info
assert -1 not in index_mapping
return "\n".join(output), index_mapping


def extract_cited_storm_info(
response: str, index_to_storm_info: Dict[int, Information]
) -> Dict[int, Information]:
"""
Extracts a sub-dictionary of Information instances that are cited in the response.
Args:
response (str): The response string containing inline citations like [1], [2], etc.
index_to_storm_info (Dict[int, Information]): A dictionary mapping indices to Information instances.
Returns:
Dict[int, Information]: A sub-dictionary with only the indices that appear in the response.
"""
cited_indices = set(map(int, re.findall(r"\[(\d+)\]", response)))
cited_storm_info = {
index: info
for index, info in index_to_storm_info.items()
if index in cited_indices
}
return cited_storm_info


def trim_output_after_hint(response: str, hint: str) -> str:
"""
Trims the output string to only keep the substring after the given hint (not including the hint).
Args:
response (str): The original output string.
hint (str): The hint string after which the substring should be kept.
Returns:
str: The trimmed output string, or the original string if the hint is not found.
"""
if hint in response:
start_index = response.find(hint) + len(hint)
return response[start_index:].strip()
return response.strip("\n")


def separate_citations(text: str) -> str:
"""
Separates multiple citations within square brackets into individual citations.
Args:
text (str): The input string containing citations.
Returns:
str: The string with separated citations.
"""

# Define a function to process each match
def replace_citations(match):
citations = match.group(1).split(",")
return "".join(f"[{citation.strip()}]" for citation in citations)

# Use regular expressions to find and replace citations
pattern = re.compile(r"\[(\d+(?:,\s*\d+)*)\]")
return pattern.sub(replace_citations, text)


def extract_and_remove_citations(text: str) -> Tuple[str, List[int]]:
"""
Removes single inline citations from the input string and returns the modified string and a list of citation integers.
Args:
text (str): The input string containing citations.
Returns:
Tuple[str, List[int]]: The string after removal of citations and a list of citation integers.
"""
citations = []

# Define a function to process each match
def extract_citation(match):
citation = int(match.group(1))
citations.append(citation)
return ""

# Use regular expressions to find and replace citations
pattern = re.compile(r"\[(\d+)\]")
modified_text = pattern.sub(extract_citation, text)

return modified_text, citations


def keep_first_and_last_paragraph(text: str) -> str:
"""
Processes the input text to keep the first and last paragraphs and replace
the middle paragraphs with '[content omitted due to space limit]'.
Args:
text (str): The input text containing paragraphs separated by '\n\n'.
Returns:
str: The processed text.
"""
paragraphs = text.split("\n\n")

if len(paragraphs) <= 3:
return text

first_paragraph = paragraphs[0]
last_paragraph = "\n\n".join(paragraphs[-2:])
return (
f"{first_paragraph}\n\n[content omitted due to space limit]\n\n{last_paragraph}"
)


def clean_up_section(text):
"""Clean up a section:
1. Remove uncompleted sentences (usually due to output token limitation).
2. Deduplicate individual groups of citations.
3. Remove unnecessary summary."""

paragraphs = text.split("\n")
output_paragraphs = []
summary_sec_flag = False
for p in paragraphs:
p = p.strip()
if len(p) == 0:
continue
if not p.startswith("#"):
p = separate_citations(p)
if summary_sec_flag:
if p.startswith("#"):
summary_sec_flag = False
else:
continue
if (
p.startswith("Overall")
or p.startswith("In summary")
or p.startswith("In conclusion")
):
continue
if "# Summary" in p or "# Conclusion" in p:
summary_sec_flag = True
continue
output_paragraphs.append(p)

return "\n\n".join(output_paragraphs) # Join with '\n\n' for markdown format.


def load_api_key(toml_file_path):
try:
with open(toml_file_path, "r") as file:
data = toml.load(file)
except FileNotFoundError:
print(f"File not found: {toml_file_path}", file=sys.stderr)
return
except toml.TomlDecodeError:
print(f"Error decoding TOML file: {toml_file_path}", file=sys.stderr)
return
# Set environment variables
for key, value in data.items():
os.environ[key] = str(value)


def _get_answer_question_module_instance(
lm_config: LMConfigs,
runner_argument: "RunnerArgument",
logging_wrapper: LoggingWrapper,
rm: Optional[dspy.Retrieve] = None,
):
from .grounded_question_answering import AnswerQuestionModule

# configure retriever
if rm is None:
rm = BingSearch(k=runner_argument.retrieve_top_k)
retriever = Retriever(rm=rm, max_thread=runner_argument.max_search_thread)
# return AnswerQuestionModule instance
return AnswerQuestionModule(
retriever=retriever,
max_search_queries=runner_argument.max_search_queries,
question_answering_lm=lm_config.question_answering_lm,
logging_wrapper=logging_wrapper,
)
Original file line number Diff line number Diff line change
@@ -0,0 +1,160 @@
import dspy
from typing import Union

from .callback import BaseCallbackHandler
from .collaborative_storm_utils import (
trim_output_after_hint,
extract_and_remove_citations,
keep_first_and_last_paragraph,
)

from .grounded_question_answering import AnswerQuestionModule
from .grounded_question_generation import ConvertUtteranceStyle
from ...dataclass import ConversationTurn
from ...logging_wrapper import LoggingWrapper


class GenExpertActionPlanning(dspy.Signature):
"""
You are an invited speaker in the round table conversation. Your task is to make a very short note to your assistant to help you prepare for your turn in the conversation.
You will be given the topic we are discussing, your expertise, and the conversation history.
Take a look at conversation history, especially last few turns, then let your assistant prepare the material for you with one of following ways.
1. Original Question: Initiates a new question to other speakers.
2. Further Details: Provides additional information.
3. Information Request: Requests information from other speakers.
4. Potential Answer: Offers a possible solution or answer.
Strictly follow this format: [type of contribution]: [one sentence description]. For example, Original Question: [description]
"""

topic = dspy.InputField(prefix="topic of discussion: ", format=str)
expert = dspy.InputField(prefix="You are inivited as: ", format=str)
summary = dspy.InputField(prefix="Discussion history: \n", format=str)
last_utterance = dspy.InputField(
prefix="Last utterance in the conversation: \n", format=str
)
resposne = dspy.OutputField(
prefix="Now give your note. Start with one of [Original Question, Further Details, Information Request, Potential Answer] with one sentence description\n",
format=str,
)


class CoStormExpertUtteranceGenerationModule(dspy.Module):
def __init__(
self,
action_planning_lm: Union[dspy.dsp.LM, dspy.dsp.HFModel],
utterance_polishing_lm: Union[dspy.dsp.LM, dspy.dsp.HFModel],
answer_question_module: AnswerQuestionModule,
logging_wrapper: LoggingWrapper,
callback_handler: BaseCallbackHandler = None,
):
self.action_planning_lm = action_planning_lm
self.utterance_polishing_lm = utterance_polishing_lm
self.expert_action = dspy.Predict(GenExpertActionPlanning)
self.change_style = dspy.Predict(ConvertUtteranceStyle)
self.answer_question_module = answer_question_module
self.logging_wrapper = logging_wrapper
self.callback_handler = callback_handler

def parse_action(self, action):
action_types = [
"Original Question",
"Further Details",
"Information Request",
"Potential Answer",
]
for action_type in action_types:
if f"{action_type}:" in action:
return action_type, trim_output_after_hint(action, f"{action_type}:")
elif f"[{action_type}]:" in action:
return action_type, trim_output_after_hint(action, f"[{action_type}]:")
return "Undefined", ""

def polish_utterance(
self, conversation_turn: ConversationTurn, last_conv_turn: ConversationTurn
):
# change utterance style
action_type = conversation_turn.utterance_type
with self.logging_wrapper.log_event(
"RoundTableConversationModule.ConvertUtteranceStyle"
):
with dspy.settings.context(
lm=self.utterance_polishing_lm, show_guidelines=False
):
action_string = (
f"{action_type} about: {conversation_turn.claim_to_make}"
)
if action_type in ["Original Question", "Information Request"]:
action_string = f"{action_type}"
last_expert_utterance_wo_citation, _ = extract_and_remove_citations(
last_conv_turn.utterance
)
trimmed_last_expert_utterance = keep_first_and_last_paragraph(
last_expert_utterance_wo_citation
)
utterance = self.change_style(
expert=conversation_turn.role,
action=action_string,
prev=trimmed_last_expert_utterance,
content=conversation_turn.raw_utterance,
).utterance
conversation_turn.utterance = utterance

def forward(
self,
topic: str,
current_expert: str,
conversation_summary: str,
last_conv_turn: ConversationTurn,
):
last_utterance, _ = extract_and_remove_citations(last_conv_turn.utterance)
if last_conv_turn.utterance_type in [
"Original Question",
"Information Request",
]:
action_type = "Potential Answer"
action_content = last_utterance
else:
with self.logging_wrapper.log_event(
"CoStormExpertUtteranceGenerationModule: GenExpertActionPlanning"
):
with dspy.settings.context(
lm=self.action_planning_lm, show_guidelines=False
):
action = self.expert_action(
topic=topic,
expert=current_expert,
summary=conversation_summary,
last_utterance=last_utterance,
).resposne
action_type, action_content = self.parse_action(action)

if self.callback_handler is not None:
self.callback_handler.on_expert_action_planning_end()
# get response
conversation_turn = ConversationTurn(
role=current_expert, raw_utterance="", utterance_type=action_type
)

if action_type == "Undefined":
raise Exception(f"unexpected output: {action}")
elif action_type in ["Further Details", "Potential Answer"]:
with self.logging_wrapper.log_event(
"RoundTableConversationModule: QuestionAnswering"
):
grounded_answer = self.answer_question_module(
topic=topic,
question=action_content,
mode="brief",
style="conversational and concise",
callback_handler=self.callback_handler,
)
conversation_turn.claim_to_make = action_content
conversation_turn.raw_utterance = grounded_answer.response
conversation_turn.queries = grounded_answer.queries
conversation_turn.raw_retrieved_info = grounded_answer.raw_retrieved_info
conversation_turn.cited_info = grounded_answer.cited_info
elif action_type in ["Original Question", "Information Request"]:
conversation_turn.raw_utterance = action_content

return dspy.Prediction(conversation_turn=conversation_turn)
83 changes: 83 additions & 0 deletions knowledge_storm/collaborative_storm/modules/expert_generation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
import dspy
import re
from typing import Union


class GenerateExpertGeneral(dspy.Signature):
"""You need to select a group of diverse experts who will be suitable to be invited to a roundtable discussion on the given topic.
Each expert should represent a different perspective, role, or affiliation related to this topic.
You can use the background information provided about the topic for inspiration. For each expert, add a description of their expertise and what they will focus on during the discussion.
No need to include speakers name in the output.
Strictly follow format below:
1. [speaker 1 role]: [speaker 1 short description]
2. [speaker 2 role]: [speaker 2 short description]
"""

topic = dspy.InputField(prefix="Topic of interest:", format=str)
background_info = dspy.InputField(
prefix="Background information about the topic:\n", format=str
)
topN = dspy.InputField(prefix="Number of speakers needed: ", format=str)
experts = dspy.OutputField(format=str)


class GenerateExpertWithFocus(dspy.Signature):
"""
You need to select a group of speakers who will be suitable to have roundtable discussion on the [topic] of specific [focus].
You may consider inviting speakers having opposite stands on the topic; speakers representing different interest parties; Ensure that the selected speakers are directly connected to the specific context and scenario provided.
For example, if the discussion focus is about a recent event at a specific university, consider inviting students, faculty members, journalists covering the event, university officials, and local community members.
Use the background information provided about the topic for inspiration. For each speaker, add a description of their interests and what they will focus on during the discussion.
No need to include speakers name in the output.
Strictly follow format below:
1. [speaker 1 role]: [speaker 1 short description]
2. [speaker 2 role]: [speaker 2 short description]
"""

topic = dspy.InputField(prefix="Topic of interest:", format=str)
background_info = dspy.InputField(prefix="Background information:\n", format=str)
focus = dspy.InputField(prefix="Discussion focus: ", format=str)
topN = dspy.InputField(prefix="Number of speakers needed: ", format=str)
experts = dspy.OutputField(format=str)


class GenerateExpertModule(dspy.Module):
def __init__(self, engine: Union[dspy.dsp.LM, dspy.dsp.HFModel]):
self.engine = engine
self.generate_expert_general = dspy.Predict(GenerateExpertGeneral)
self.generate_expert_w_focus = dspy.ChainOfThought(GenerateExpertWithFocus)

def trim_background(self, background: str, max_words: int = 100):
words = background.split()
cur_len = len(words)
if cur_len <= max_words:
return background
trimmed_words = words[: min(cur_len, max_words)]
trimmed_background = " ".join(trimmed_words)
return f"{trimmed_background} [rest content omitted]."

def forward(
self, topic: str, num_experts: int, background_info: str = "", focus: str = ""
):
with dspy.settings.context(lm=self.engine, show_guidelines=False):
if not focus:
output = self.generate_expert_general(
topic=topic, background_info=background_info, topN=num_experts
).experts
else:
background_info = self.trim_background(
background=background_info, max_words=100
)
output = self.generate_expert_w_focus(
topic=topic,
background_info=background_info,
focus=focus,
topN=num_experts,
).experts
output = output.replace("*", "").replace("[", "").replace("]", "")
expert_list = []
for s in output.split("\n"):
match = re.search(r"\d+\.\s*(.*)", s)
if match:
expert_list.append(match.group(1))
expert_list = [expert.strip() for expert in expert_list if expert.strip()]
return dspy.Prediction(experts=expert_list, raw_output=output)
Original file line number Diff line number Diff line change
@@ -0,0 +1,163 @@
import dspy
from typing import Union, List

from .callback import BaseCallbackHandler
from .collaborative_storm_utils import (
trim_output_after_hint,
format_search_results,
extract_cited_storm_info,
separate_citations,
)
from ...logging_wrapper import LoggingWrapper
from ...utils import ArticleTextProcessing
from ...interface import Information


class QuestionToQuery(dspy.Signature):
"""You want to answer the question or support a claim using Google search. What do you type in the search box?
The question is raised in a round table discussion on a topic. The question may or may not focus on the topic itself.
Write the queries you will use in the following format:
- query 1
- query 2
...
- query n"""

topic = dspy.InputField(prefix="Topic context:", format=str)
question = dspy.InputField(
prefix="I want to collect information about: ", format=str
)
queries = dspy.OutputField(prefix="Queries: \n", format=str)


class AnswerQuestion(dspy.Signature):
"""You are an expert who can use information effectively. You have gathered the related information and will now use the information to form a response.
Make your response as informative as possible and make sure every sentence is supported by the gathered information.
If [Gathered information] is not directly related to the [Topic] and [Question], provide the most relevant answer you can based on the available information, and explain any limitations or gaps.
Use [1], [2], ..., [n] in line (for example, "The capital of the United States is Washington, D.C.[1][3].").
You DO NOT need to include a References or Sources section to list the sources at the end. The style of writing should be formal.
"""

topic = dspy.InputField(prefix="Topic you are discussing about:", format=str)
question = dspy.InputField(prefix="You want to provide insight on: ", format=str)
info = dspy.InputField(prefix="Gathered information:\n", format=str)
style = dspy.InputField(prefix="Style of your response should be:", format=str)
answer = dspy.OutputField(
prefix="Now give your response. (Try to use as many different sources as possible and do not hallucinate.)",
format=str,
)


class AnswerQuestionModule(dspy.Module):
def __init__(
self,
retriever: dspy.Retrieve,
max_search_queries: int,
question_answering_lm: Union[dspy.dsp.LM, dspy.dsp.HFModel],
logging_wrapper: LoggingWrapper,
):
super().__init__()
self.question_answering_lm = question_answering_lm
self.question_to_query = dspy.Predict(QuestionToQuery)
self.answer_question = dspy.Predict(AnswerQuestion)
self.retriever = retriever
self.max_search_queries = max_search_queries
self.logging_wrapper = logging_wrapper

def retrieve_information(self, topic, question):
# decompose question to queries
with self.logging_wrapper.log_event(
f"AnswerQuestionModule.question_to_query ({hash(question)})"
):
with dspy.settings.context(lm=self.question_answering_lm):
queries = self.question_to_query(topic=topic, question=question).queries
queries = trim_output_after_hint(queries, hint="Queries:")
queries = [
q.replace("-", "").strip().strip('"').strip('"').strip()
for q in queries.split("\n")
]
queries = queries[: self.max_search_queries]
self.logging_wrapper.add_query_count(count=len(queries))
with self.logging_wrapper.log_event(
f"AnswerQuestionModule.retriever.retrieve ({hash(question)})"
):
# retrieve information using retriever
searched_results: List[Information] = self.retriever.retrieve(
list(set(queries)), exclude_urls=[]
)
# update storm information meta to include the question
for storm_info in searched_results:
storm_info.meta["question"] = question
return queries, searched_results

def forward(
self,
topic: str,
question: str,
mode: str = "brief",
style: str = "conversational",
callback_handler: BaseCallbackHandler = None,
):
"""
Processes a topic and question to generate a response with relevant information and citations.
Args:
topic (str): The topic of interest.
question (str): The specific question related to the topic.
mode (str, optional): Mode of summarization. 'brief' takes only the first snippet of each Information.
'extensive' adds snippets iteratively until the word limit is reached. Defaults to 'brief'.
Returns:
dspy.Prediction: An object containing the following:
- question (str): the question to answer
- queries (List[str]): List of query strings used for information retrieval.
- raw_retrieved_info (List[Information]): List of Information instances retrieved.
- cited_info (Dict[int, Information]): Dictionary of cited Information instances, indexed by their citation number.
- response (str): The generated response string with inline citations.
"""
# retrieve information
if callback_handler is not None:
callback_handler.on_expert_information_collection_start()
queries, searched_results = self.retrieve_information(
topic=topic, question=question
)
if callback_handler is not None:
callback_handler.on_expert_information_collection_end(searched_results)
# format information string for answer generation
info_text, index_to_information_mapping = format_search_results(
searched_results, mode=mode
)
answer = "Sorry, there is insufficient information to answer the question."
# generate answer to the question
if info_text:
with self.logging_wrapper.log_event(
f"AnswerQuestionModule.answer_question ({hash(question)})"
):
with dspy.settings.context(
lm=self.question_answering_lm, show_guidelines=False
):
answer = self.answer_question(
topic=topic, question=question, info=info_text, style=style
).answer
answer = ArticleTextProcessing.remove_uncompleted_sentences_with_citations(
answer
)
answer = trim_output_after_hint(
answer,
hint="Now give your response. (Try to use as many different sources as possible and do not hallucinate.)",
)
# enforce single citation index bracket. [1, 2] -> [1][2]
answer = separate_citations(answer)
if callback_handler is not None:
callback_handler.on_expert_utterance_generation_end()
# construct cited search result
cited_searched_results = extract_cited_storm_info(
response=answer, index_to_storm_info=index_to_information_mapping
)

return dspy.Prediction(
question=question,
queries=queries,
raw_retrieved_info=searched_results,
cited_info=cited_searched_results,
response=answer,
)
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
"""
This module handles question generation within the Co-STORM framework, specifically designed to support the Moderator role.
The Moderator generates insightful, thought-provoking questions that introduce new directions into the conversation.
By leveraging uncited or unused snippets of information retrieved during the discussion, the Moderator ensures the conversation remains dynamic and avoids repetitive or overly niche topics.
For more detailed information, refer to Section 3.5 of the Co-STORM paper: https://www.arxiv.org/pdf/2408.15232.
"""

import dspy
from typing import List, Union

from .collaborative_storm_utils import (
format_search_results,
extract_and_remove_citations,
keep_first_and_last_paragraph,
extract_cited_storm_info,
)
from ...dataclass import ConversationTurn, KnowledgeBase
from ...interface import Information


class KnowledgeBaseSummmary(dspy.Signature):
"""Your job is to give brief summary of what's been discussed in a roundtable conversation. Contents are themantically organized into hierarchical sections.
You will be presented with these sections where "#" denotes level of section.
"""

topic = dspy.InputField(prefix="topic: ", format=str)
structure = dspy.InputField(prefix="Tree structure: \n", format=str)
output = dspy.OutputField(prefix="Now give brief summary:\n", format=str)


class ConvertUtteranceStyle(dspy.Signature):
"""
You are an invited speaker in the round table conversation.
Your task is to make the question or the response more conversational and engaging to facilicate the flow of conversation.
Note that this is ongoing conversation so no need to have welcoming and concluding words. Previous speaker utterance is provided only for making the conversation more natural.
Note that do not hallucinate and keep the citation index like [1] as it is. Also,
"""

expert = dspy.InputField(prefix="You are inivited as: ", format=str)
action = dspy.InputField(
prefix="You want to contribute to conversation by: ", format=str
)
prev = dspy.InputField(prefix="Previous speaker said: ", format=str)
content = dspy.InputField(
prefix="Question or response you want to say: ", format=str
)
utterance = dspy.OutputField(
prefix="Your utterance (keep the information as much as you can with citations, prefer shorter answers without loss of information): ",
format=str,
)


class GroundedQuestionGeneration(dspy.Signature):
"""Your job is to find next discussion focus in a roundtable conversation. You will be given previous conversation summary and some information that might assist you discover new discussion focus.
Note that the new discussion focus should bring new angle and perspective to the discussion and avoid repetition. The new discussion focus should be grounded on the available information and push the boundaries of the current discussion for broader exploration.
The new discussion focus should have natural flow from last utterance in the conversation.
Use [1][2] in line to ground your question.
"""

topic = dspy.InputField(prefix="topic: ", format=str)
summary = dspy.InputField(prefix="Discussion history: \n", format=str)
information = dspy.InputField(prefix="Available information: \n", format=str)
last_utterance = dspy.InputField(
prefix="Last utterance in the conversation: \n", format=str
)
output = dspy.OutputField(
prefix="Now give next discussion focus in the format of one sentence question:\n",
format=str,
)


class GroundedQuestionGenerationModule(dspy.Module):
def __init__(self, engine: Union[dspy.dsp.LM, dspy.dsp.HFModel]):
self.engine = engine
self.gen_focus = dspy.Predict(GroundedQuestionGeneration)
self.polish_style = dspy.Predict(ConvertUtteranceStyle)
self.gen_summary = dspy.Predict(KnowledgeBaseSummmary)

def forward(
self,
topic: str,
knowledge_base: KnowledgeBase,
last_conv_turn: ConversationTurn,
unused_snippets: List[Information],
):
information, index_to_information_mapping = format_search_results(
unused_snippets, info_max_num_words=1000
)
summary = knowledge_base.get_knowledge_base_summary()
last_utterance, _ = extract_and_remove_citations(last_conv_turn.utterance)
with dspy.settings.context(lm=self.engine, show_guidelines=False):
raw_utterance = self.gen_focus(
topic=topic,
summary=summary,
information=information,
last_utterance=last_utterance,
).output
utterance = self.polish_style(
expert="Roundtable conversation moderator",
action="Raising a new question by natural transit from previous utterance.",
prev=keep_first_and_last_paragraph(last_utterance),
content=raw_utterance,
).utterance
cited_searched_results = extract_cited_storm_info(
response=utterance, index_to_storm_info=index_to_information_mapping
)
return dspy.Prediction(
raw_utterance=raw_utterance,
utterance=utterance,
cited_info=cited_searched_results,
)

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
import dspy
from typing import Union
from ...dataclass import KnowledgeBase


class KnowledgeBaseSummmary(dspy.Signature):
"""Your job is to give brief summary of what's been discussed in a roundtable conversation. Contents are themantically organized into hierarchical sections.
You will be presented with these sections where "#" denotes level of section.
"""

topic = dspy.InputField(prefix="topic: ", format=str)
structure = dspy.InputField(prefix="Tree structure: \n", format=str)
output = dspy.OutputField(prefix="Now give brief summary:\n", format=str)


class KnowledgeBaseSummaryModule(dspy.Module):
def __init__(self, engine: Union[dspy.dsp.LM, dspy.dsp.HFModel]):
self.engine = engine
self.gen_summary = dspy.Predict(KnowledgeBaseSummmary)

def forward(self, knowledge_base: KnowledgeBase):
structure = knowledge_base.get_node_hierarchy_string(
include_indent=False,
include_full_path=False,
include_hash_tag=True,
include_node_content_count=False,
)
with dspy.settings.context(lm=self.engine, show_guidelines=False):
summary = self.gen_summary(
topic=knowledge_base.topic, structure=structure
).output
return summary
37 changes: 37 additions & 0 deletions knowledge_storm/collaborative_storm/modules/simulate_user.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
import dspy
from typing import List, Union

from .collaborative_storm_utils import extract_and_remove_citations
from ...dataclass import ConversationTurn
from ...storm_wiki.modules.knowledge_curation import AskQuestionWithPersona


class GenSimulatedUserUtterance(dspy.Module):
def __init__(self, engine: Union[dspy.dsp.LM, dspy.dsp.HFModel]):
self.engine = engine
self.ask_qeustion = dspy.Predict(AskQuestionWithPersona)

def gen_conv_history_string(self, conversation_turns: List[ConversationTurn]):
conv_history = []
total_turns = len(conversation_turns)

for i, turn in enumerate(conversation_turns):
utterance, _ = extract_and_remove_citations(turn.utterance)
if i >= total_turns - 4:
conv_history.append(f"{turn.role}: {utterance}")
else:
if turn.claim_to_make:
conv_history.append(f"{turn.role}: {turn.claim_to_make}")
else:
conv_history.append(f"{turn.role}: {utterance}")

return "\n".join(conv_history)

def forward(self, topic: str, intent: str, conv_history: List[ConversationTurn]):
conv_history_string = self.gen_conv_history_string(conv_history)
with dspy.settings.context(lm=self.engine, show_guidelines=False):
return self.ask_qeustion(
topic=topic,
persona=f"researcher with interest in {intent}",
conv=conv_history_string,
).question

Large diffs are not rendered by default.

849 changes: 849 additions & 0 deletions knowledge_storm/dataclass.py

Large diffs are not rendered by default.

169 changes: 169 additions & 0 deletions knowledge_storm/encoder.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,169 @@
import requests
import os
from typing import List, Tuple, Union, Optional, Dict, Literal
import numpy as np

from concurrent.futures import ThreadPoolExecutor, as_completed


class EmbeddingModel:
def __init__():
pass

def get_embedding(self, text: str) -> Tuple[np.ndarray, int]:
raise Exception("Not implemented")


class OpenAIEmbeddingModel(EmbeddingModel):
def __init__(self, model: str = "text-embedding-3-small", api_key: str = None):
if not api_key:
self.api_key = os.getenv("OPENAI_API_KEY")

self.url = "https://api.openai.com/v1/embeddings"
self.headers = {
"Content-Type": "application/json",
"Authorization": f"Bearer {self.api_key}",
}
self.model = model

def get_embedding(self, text: str) -> Tuple[np.ndarray, int]:
data = {"input": text, "model": "text-embedding-3-small"}

response = requests.post(self.url, headers=self.headers, json=data)
if response.status_code == 200:
data = response.json()
embedding = np.array(data["data"][0]["embedding"])
token = data["usage"]["prompt_tokens"]
return embedding, token
else:
response.raise_for_status()


class OpenAIEmbeddingModel(EmbeddingModel):
def __init__(self, model: str = "text-embedding-3-small", api_key: str = None):
if not api_key:
api_key = os.getenv("OPENAI_API_KEY")

self.url = "https://api.openai.com/v1/embeddings"
self.headers = {
"Content-Type": "application/json",
"Authorization": f"Bearer {api_key}",
}
self.model = model

def get_embedding(self, text: str) -> Tuple[np.ndarray, int]:
data = {"input": text, "model": self.model}

response = requests.post(self.url, headers=self.headers, json=data)
if response.status_code == 200:
data = response.json()
embedding = np.array(data["data"][0]["embedding"])
token = data["usage"]["prompt_tokens"]
return embedding, token
else:
response.raise_for_status()


class TogetherEmbeddingModel:
def __init__(self, model: str = "BAAI/bge-large-en-v1.5", api_key: str = None):
import together

self.model = model
if not api_key:
api_key = os.getenv("TOGETHER_API_KEY")
self.together_client = together.Together(api_key=api_key)

def get_embedding(self, text: str) -> Tuple[np.ndarray, int]:
response = self.together_client.embeddings.create(input=text, model=self.model)
return response.data[0].embedding, -1


class AzureOpenAIEmbeddingModel:
def __init__(self, model: str = "text-embedding-3-small", api_key: str = None):
from openai import AzureOpenAI

self.model = model
if not api_key:
api_key = os.getenv("AZURE_API_KEY")

self.client = AzureOpenAI(
api_key=api_key,
api_version=os.getenv("AZURE_API_VERSION"),
azure_endpoint=os.getenv("AZURE_API_BASE"),
)

def get_embedding(self, text: str) -> Tuple[np.ndarray, int]:
response = self.client.embeddings.create(input=text, model=self.model)

embedding = np.array(response.data[0].embedding)
token = response.usage.prompt_tokens
return embedding, token


def get_text_embeddings(
texts: Union[str, List[str]],
max_workers: int = 5,
embedding_cache: Optional[Dict[str, np.ndarray]] = None,
) -> Tuple[np.ndarray, int]:
"""
Get text embeddings using OpenAI's text-embedding-3-small model.
Args:
texts (Union[str, List[str]]): A single text string or a list of text strings to embed.
max_workers (int): The maximum number of workers for parallel processing.
api_key (str): The API key for accessing OpenAI's services.
embedding_cache (Optional[Dict[str, np.ndarray]]): A cache to store previously computed embeddings.
Returns:
Tuple[np.ndarray, int]: The 2D array of embeddings and the total token usage.
"""
embedding_model = None
encoder_type = os.getenv("ENCODER_API_TYPE")
if encoder_type and encoder_type == "openai":
embedding_model = OpenAIEmbeddingModel()
elif encoder_type and encoder_type == "azure":
embedding_model = AzureOpenAIEmbeddingModel()
elif encoder_type == encoder_type == "together":
embedding_model = TogetherEmbeddingModel()
else:
raise Exception(
"No valid encoder type is provided. Check <repo root>/secrets.toml for the field ENCODER_API_TYPE"
)

def fetch_embedding(text: str) -> Tuple[str, np.ndarray, int]:
if embedding_cache is not None and text in embedding_cache:
return (
text,
embedding_cache[text],
0,
) # Returning 0 tokens since no API call is made
embedding, token_usage = embedding_model.get_embedding(text)
return text, embedding, token_usage

if isinstance(texts, str):
_, embedding, tokens = fetch_embedding(texts)
return np.array(embedding), tokens

embeddings = []
total_tokens = 0

with ThreadPoolExecutor(max_workers=max_workers) as executor:
futures = {executor.submit(fetch_embedding, text): text for text in texts}

for future in as_completed(futures):
try:
text, embedding, tokens = future.result()
embeddings.append((text, embedding, tokens))
total_tokens += tokens
except Exception as e:
print(f"An error occurred for text: {futures[future]}")
print(e)

# Sort results to match the order of the input texts
embeddings.sort(key=lambda x: texts.index(x[0]))
if embedding_cache is not None:
for text, embedding, _ in embeddings:
embedding_cache[text] = embedding
embeddings = [result[1] for result in embeddings]

return np.array(embeddings), total_tokens
227 changes: 188 additions & 39 deletions knowledge_storm/interface.py
Original file line number Diff line number Diff line change
@@ -1,27 +1,23 @@
import concurrent.futures
import dspy
import functools
import hashlib
import json
import logging
import time
from abc import ABC, abstractmethod
from collections import OrderedDict
from typing import Dict, List, Optional, Union
from typing import Dict, List, Optional, Union, TYPE_CHECKING

from .utils import ArticleTextProcessing

logging.basicConfig(
level=logging.INFO, format="%(name)s : %(levelname)-8s : %(message)s"
)
logger = logging.getLogger(__name__)


class Information(ABC):
"""Abstract base class to represent basic information.
Attributes:
uuid (str): The unique identifier for the information.
meta (dict): The meta information associated with the information.
"""

def __init__(self, uuid, meta={}):
self.uuid = uuid
self.meta = meta
if TYPE_CHECKING:
from .logging_wrapper import LoggingWrapper


class InformationTable(ABC):
@@ -42,6 +38,101 @@ def retrieve_information(**kwargs):
pass


class Information:
"""Class to represent detailed information.
Inherits from Information to include a unique identifier (URL), and extends
it with a description, snippets, and title of the storm information.
Attributes:
description (str): Brief description.
snippets (list): List of brief excerpts or snippets.
title (str): The title or headline of the information.
url (str): The unique URL (serving as UUID) of the information.
"""

def __init__(self, url, description, snippets, title, meta=None):
"""Initialize the Information object with detailed attributes.
Args:
url (str): The unique URL serving as the identifier for the information.
description (str): Detailed description.
snippets (list): List of brief excerpts or snippet.
title (str): The title or headline of the information.
"""
self.description = description
self.snippets = snippets
self.title = title
self.url = url
self.meta = meta if meta is not None else {}
self.citation_uuid = -1

def __hash__(self):
return hash(
(
self.url,
tuple(sorted(self.snippets)),
)
)

def __eq__(self, other):
if not isinstance(other, Information):
return False
return (
self.url == other.url
and set(self.snippets) == set(other.snippets)
and self._meta_str() == other._meta_str()
)

def __hash__(self):
return int(
self._md5_hash((self.url, tuple(sorted(self.snippets)), self._meta_str())),
16,
)

def _meta_str(self):
"""Generate a string representation of relevant meta information."""
return f"Question: {self.meta.get('question', '')}, Query: {self.meta.get('query', '')}"

def _md5_hash(self, value):
"""Generate an MD5 hash for a given value."""
if isinstance(value, (dict, list, tuple)):
value = json.dumps(value, sort_keys=True)
return hashlib.md5(str(value).encode("utf-8")).hexdigest()

@classmethod
def from_dict(cls, info_dict):
"""Create a Information object from a dictionary.
Usage: info = Information.from_dict(storm_info_dict)
Args:
info_dict (dict): A dictionary containing keys 'url', 'description',
'snippets', and 'title' corresponding to the object's attributes.
Returns:
Information: An instance of Information.
"""
info = cls(
url=info_dict["url"],
description=info_dict["description"],
snippets=info_dict["snippets"],
title=info_dict["title"],
meta=info_dict.get("meta", None),
)
info.citation_uuid = int(info_dict.get("citation_uuid", -1))
return info

def to_dict(self):
return {
"url": self.url,
"description": self.description,
"snippets": self.snippets,
"title": self.title,
"meta": self.meta,
"citation_uuid": self.citation_uuid,
}


class ArticleSectionNode:
"""
The ArticleSectionNode is the dataclass for handling the section of the article.
@@ -166,7 +257,7 @@ def prune_empty_nodes(self, node=None):
return node


class Retriever(ABC):
class Retriever:
"""
An abstract base class for retriever modules. It provides a template for retrieving information based on a query.
@@ -175,19 +266,14 @@ class Retriever(ABC):
The retrieval model/search engine used for each part should be declared with a suffix '_rm' in the attribute name.
"""

def __init__(self, search_top_k):
self.search_top_k = search_top_k

def update_search_top_k(self, k):
self.search_top_k = k
def __init__(self, rm: dspy.Retrieve, max_thread: int = 1):
self.max_thread = max_thread
self.rm = rm

def collect_and_reset_rm_usage(self):
combined_usage = []
for attr_name in self.__dict__:
if "_rm" in attr_name and hasattr(
getattr(self, attr_name), "get_usage_and_reset"
):
combined_usage.append(getattr(self, attr_name).get_usage_and_reset())
if hasattr(getattr(self, "rm"), "get_usage_and_reset"):
combined_usage.append(getattr(self, "rm").get_usage_and_reset())

name_to_usage = {}
for usage in combined_usage:
@@ -199,21 +285,38 @@ def collect_and_reset_rm_usage(self):

return name_to_usage

@abstractmethod
def retrieve(self, query: Union[str, List[str]], **kwargs) -> List[Information]:
"""
Retrieves information based on a query.
This method must be implemented by subclasses to specify how information is retrieved.
Args:
query (Union[str, List[str]]): The query or list of queries to retrieve information for.
**kwargs: Additional keyword arguments that might be necessary for the retrieval process.
Returns:
List[Information]: A list of Information objects retrieved based on the query.
"""
pass
def retrieve(
self, query: Union[str, List[str]], exclude_urls: List[str] = []
) -> List[Information]:
queries = query if isinstance(query, list) else [query]
to_return = []

def process_query(q):
retrieved_data_list = self.rm(
query_or_queries=[q], exclude_urls=exclude_urls
)
local_to_return = []
for data in retrieved_data_list:
for i in range(len(data["snippets"])):
# STORM generate the article with citations. We do not consider multi-hop citations.
# Remove citations in the source to avoid confusion.
data["snippets"][i] = ArticleTextProcessing.remove_citations(
data["snippets"][i]
)
storm_info = Information.from_dict(data)
storm_info.meta["query"] = q
local_to_return.append(storm_info)
return local_to_return

with concurrent.futures.ThreadPoolExecutor(
max_workers=self.max_thread
) as executor:
results = list(executor.map(process_query, queries))

for result in results:
to_return.extend(result)

return to_return


class KnowledgeCurationModule(ABC):
@@ -458,3 +561,49 @@ def reset(self):
self.time = {}
self.lm_cost = {}
self.rm_cost = {}


class Agent(ABC):
"""
Interface for STORM and Co-STORM LLM agent
This class must be implemented by any subclass of `Agent` to define how the agent generates an utterance.
The generated utterance can be influenced by the conversation history, knowledge base, and any additional parameters passed via `kwargs`.
The implementation should align with the specific role and perspective of the agent, as defined by the agent's topic, role name, and role description.
Args:
knowledge_base (KnowledgeBase): The current knowledge base (e.g., mind map in Co-STORM) that contains the accumulated information relevant to the conversation.
conversation_history (List[ConversationTurn]): A list of past conversation turns, providing context for generating the next utterance.
The agent can refer to this history to maintain continuity and relevance in the conversation.
logging_wrapper (LoggingWrapper): A wrapper used for logging important events during the utterance generation process.
**kwargs: Additional arguments that can be passed to the method for more specialized utterance generation behavior depending on the agent's specific implementation.
Returns:
ConversationTurn: A new conversation turn generated by the agent, containing the agent's response, including the role, utterance type, and relevant information from the knowledge base.
Notes:
- Subclasses of `Agent` should define the exact strategy for generating the utterance, which could involve interacting with a language model, retrieving relevant knowledge, or following specific conversational policies.
- The agent's role, perspective, and the knowledge base content will influence how the utterance is formulated.
"""

from .dataclass import KnowledgeBase, ConversationTurn

def __init__(self, topic: str, role_name: str, role_description: str):
self.topic = topic
self.role_name = role_name
self.role_description = role_description

def get_role_description(self):
if self.role_description:
return f"{self.role_name}: {self.role_description}"
return self.role_name

@abstractmethod
def generate_utterance(
self,
knowledge_base: KnowledgeBase,
conversation_history: List[ConversationTurn],
logging_wrapper: "LoggingWrapper",
**kwargs,
):
pass
44 changes: 24 additions & 20 deletions knowledge_storm/lm.py
Original file line number Diff line number Diff line change
@@ -24,7 +24,7 @@ class OpenAIModel(dspy.OpenAI):

def __init__(
self,
model: str = "gpt-3.5-turbo-instruct",
model: str = "gpt-4o-mini",
api_key: Optional[str] = None,
model_type: Literal["chat", "text"] = None,
**kwargs,
@@ -211,7 +211,7 @@ def __init__(
self,
api_base: Optional[str] = None,
api_version: Optional[str] = None,
model: str = "gpt-3.5-turbo-instruct",
model: str = "gpt-4o-mini",
api_key: Optional[str] = None,
model_type: Literal["chat", "text"] = "chat",
**kwargs,
@@ -674,21 +674,28 @@ class TogetherClient(dspy.HFModel):
def __init__(
self,
model,
api_key: Optional[str] = None,
apply_tokenizer_chat_template=False,
hf_tokenizer_name=None,
model_type: Literal["chat", "text"] = "chat",
**kwargs,
):
"""Copied from dspy/dsp/modules/hf_client.py with the support of applying tokenizer chat template."""

super().__init__(model=model, is_client=True)
self.session = requests.Session()
self.api_base = (
"https://api.together.xyz/v1/completions"
if os.getenv("TOGETHER_API_BASE") is None
else os.getenv("TOGETHER_API_BASE")
self.api_key = api_key = (
os.environ.get("TOGETHER_API_KEY") if api_key is None else api_key
)
self.token = os.getenv("TOGETHER_API_KEY")
self.model = model
self.model_type = model_type
if os.getenv("TOGETHER_API_BASE") is None:
if self.model_type == "chat":
self.api_base = "https://api.together.xyz/v1/chat/completions"
else:
self.api_base = "https://api.together.xyz/v1/completions"
else:
self.api_base = os.getenv("TOGETHER_API_BASE")

# self.use_inst_template = False
# if any(keyword in self.model.lower() for keyword in ["inst", "instruct"]):
@@ -705,12 +712,12 @@ def __init__(
stop_default = "\n\n---"

self.kwargs = {
"temperature": 0.0,
"max_tokens": 512,
"top_p": 1,
"top_k": 20,
"temperature": kwargs.get("temperature", 0.0),
"max_tokens": min(kwargs.get("max_tokens", 4096), 4096),
"top_p": kwargs.get("top_p", 1.0),
"top_k": kwargs.get("top_k", 1),
"repetition_penalty": 1,
"n": 1,
"n": kwargs.pop("n", kwargs.pop("num_generations", 1)),
"stop": stop_default if "stop" not in kwargs else kwargs["stop"],
**kwargs,
}
@@ -745,9 +752,7 @@ def get_usage_and_reset(self):
max_time=1000,
on_backoff=backoff_hdlr,
)
def _generate(self, prompt, use_chat_api=False, **kwargs):
url = f"{self.api_base}"

def _generate(self, prompt, **kwargs):
kwargs = {**self.kwargs, **kwargs}

stop = kwargs.get("stop")
@@ -762,8 +767,7 @@ def _generate(self, prompt, use_chat_api=False, **kwargs):
)
# prompt = f"[INST]{prompt}[/INST]" if self.use_inst_template else prompt

if use_chat_api:
url = f"{self.api_base}/chat/completions"
if self.model_type == "chat":
messages = [
{
"role": "system",
@@ -793,13 +797,13 @@ def _generate(self, prompt, use_chat_api=False, **kwargs):
"stop": stop,
}

headers = {"Authorization": f"Bearer {self.token}"}
headers = {"Authorization": f"Bearer {self.api_key}"}

with self.session.post(url, headers=headers, json=body) as resp:
with self.session.post(self.api_base, headers=headers, json=body) as resp:
resp_json = resp.json()
# Log the token usage from the Together API response.
self.log_usage(resp_json)
if use_chat_api:
if self.model_type == "chat":
# completions = [resp_json['output'].get('choices', [])[0].get('message', {}).get('content', "")]
completions = [
resp_json.get("choices", [])[0]
212 changes: 212 additions & 0 deletions knowledge_storm/logging_wrapper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,212 @@
from contextlib import contextmanager
import time
import pytz
from datetime import datetime

# Define California timezone
CALIFORNIA_TZ = pytz.timezone("America/Los_Angeles")


class EventLog:
def __init__(self, event_name):
self.event_name = event_name
self.start_time = None
self.end_time = None
self.child_events = {}

def record_start_time(self):
self.start_time = datetime.now(
pytz.utc
) # Store in UTC for consistent timezone conversion

def record_end_time(self):
self.end_time = datetime.now(
pytz.utc
) # Store in UTC for consistent timezone conversion

def get_total_time(self):
if self.start_time and self.end_time:
return (self.end_time - self.start_time).total_seconds()
return 0

def get_start_time(self):
if self.start_time:
# Format to milliseconds
return self.start_time.astimezone(CALIFORNIA_TZ).strftime(
"%Y-%m-%d %H:%M:%S.%f"
)[:-3]
return None

def get_end_time(self):
if self.end_time:
# Format to milliseconds
return self.end_time.astimezone(CALIFORNIA_TZ).strftime(
"%Y-%m-%d %H:%M:%S.%f"
)[:-3]
return None

def add_child_event(self, child_event):
self.child_events[child_event.event_name] = child_event

def get_child_events(self):
return self.child_events


class LoggingWrapper:
def __init__(self, lm_config):
self.logging_dict = {}
self.lm_config = lm_config
self.current_pipeline_stage = None
self.event_stack = []
self.pipeline_stage_active = False

def _pipeline_stage_start(self, pipeline_stage: str):
if self.pipeline_stage_active:
raise RuntimeError(
"A pipeline stage is already active. End the current stage before starting a new one."
)

self.current_pipeline_stage = pipeline_stage
self.logging_dict[pipeline_stage] = {
"time_usage": {},
"lm_usage": {},
"lm_history": [],
"query_count": 0,
}
self.pipeline_stage_active = True

def _event_start(self, event_name: str):
if not self.pipeline_stage_active:
raise RuntimeError("No pipeline stage is currently active.")

if not self.event_stack and self.current_pipeline_stage:
# Top-level event (directly under the pipeline stage)
if (
event_name
not in self.logging_dict[self.current_pipeline_stage]["time_usage"]
):
event = EventLog(event_name=event_name)
event.record_start_time()
self.logging_dict[self.current_pipeline_stage]["time_usage"][
event_name
] = event
self.event_stack.append(event)
else:
self.logging_dict[self.current_pipeline_stage]["time_usage"][
event_name
].record_start_time()
elif self.event_stack:
# Nested event (under another event)
parent_event = self.event_stack[-1]
if event_name not in parent_event.get_child_events():
event = EventLog(event_name=event_name)
event.record_start_time()
parent_event.add_child_event(event)
self.logging_dict[self.current_pipeline_stage]["time_usage"][
event_name
] = event
self.event_stack.append(event)
else:
parent_event.get_child_events()[event_name].record_start_time()
else:
raise RuntimeError(
"Cannot start an event without an active pipeline stage or parent event."
)

def _event_end(self, event_name: str):
if not self.pipeline_stage_active:
raise RuntimeError("No pipeline stage is currently active.")

if not self.event_stack:
raise RuntimeError("No parent event is currently active.")

if self.event_stack:
current_event_log = self.event_stack[-1]
if event_name in current_event_log.get_child_events():
current_event_log.get_child_events()[event_name].record_end_time()
elif (
event_name
in self.logging_dict[self.current_pipeline_stage]["time_usage"]
):
self.logging_dict[self.current_pipeline_stage]["time_usage"][
event_name
].record_end_time()
else:
raise AssertionError(
f"Failure to record end time for event {event_name}. Start time is not recorded."
)
if current_event_log.event_name == event_name:
self.event_stack.pop()
else:
raise RuntimeError("Cannot end an event without an active parent event.")

def _pipeline_stage_end(self):
if not self.pipeline_stage_active:
raise RuntimeError("No pipeline stage is currently active to end.")

self.logging_dict[self.current_pipeline_stage][
"lm_usage"
] = self.lm_config.collect_and_reset_lm_usage()
self.logging_dict[self.current_pipeline_stage][
"lm_history"
] = self.lm_config.collect_and_reset_lm_history()
self.pipeline_stage_active = False

def add_query_count(self, count):
if not self.pipeline_stage_active:
raise RuntimeError(
"No pipeline stage is currently active to add query count."
)

self.logging_dict[self.current_pipeline_stage]["query_count"] += count

@contextmanager
def log_event(self, event_name):
if not self.pipeline_stage_active:
raise RuntimeError("No pipeline stage is currently active.")

self._event_start(event_name)
yield
self._event_end(event_name)

@contextmanager
def log_pipeline_stage(self, pipeline_stage):
if self.pipeline_stage_active:
print(
"A pipeline stage is already active, ending the current stage safely."
)
self._pipeline_stage_end()

start_time = time.time()
try:
self._pipeline_stage_start(pipeline_stage)
yield
except Exception as e:
print(f"Error occurred during pipeline stage '{pipeline_stage}': {e}")
finally:
self.logging_dict[self.current_pipeline_stage]["total_wall_time"] = (
time.time() - start_time
)
self._pipeline_stage_end()

def dump_logging_and_reset(self, reset_logging=True):
log_dump = {}
for pipeline_stage, pipeline_log in self.logging_dict.items():
time_stamp_log = {
event_name: {
"total_time_seconds": event.get_total_time(),
"start_time": event.get_start_time(),
"end_time": event.get_end_time(),
}
for event_name, event in pipeline_log["time_usage"].items()
}
log_dump[pipeline_stage] = {
"time_usage": time_stamp_log,
"lm_usage": pipeline_log["lm_usage"],
"lm_history": pipeline_log["lm_history"],
"query_count": pipeline_log["query_count"],
"total_wall_time": pipeline_log["total_wall_time"],
}
if reset_logging:
self.logging_dict.clear()
return log_dump
136 changes: 112 additions & 24 deletions knowledge_storm/rm.py
Original file line number Diff line number Diff line change
@@ -333,10 +333,79 @@ def forward(self, query_or_queries: Union[str, List[str]], exclude_urls: List[st
return collected_results


class StanfordOvalArxivRM(dspy.Retrieve):
"""[Alpha] This retrieval class is for internal use only, not intended for the public."""

def __init__(self, endpoint, k=3):
super().__init__(k=k)
self.endpoint = endpoint
self.usage = 0

def get_usage_and_reset(self):
usage = self.usage
self.usage = 0

return {"CS224vArxivRM": usage}

def _retrieve(self, query: str):
payload = {"query": query, "num_blocks": self.k}

response = requests.post(
self.endpoint, json=payload, headers={"Content-Type": "application/json"}
)

# Check if the request was successful
if response.status_code == 200:
data = response.json()[0]
results = []
for i in range(len(data["title"])):
result = {
"title": data["title"][i],
"url": data["title"][i],
"snippets": [data["text"][i]],
"description": "N/A",
"meta": {"section_title": data["full_section_title"][i]},
}
results.append(result)

return results
else:
raise Exception(
f"Error: Unable to retrieve results. Status code: {response.status_code}"
)

def forward(
self, query_or_queries: Union[str, List[str]], exclude_urls: List[str] = []
):
collected_results = []
queries = (
[query_or_queries]
if isinstance(query_or_queries, str)
else query_or_queries
)

for query in queries:
try:
results = self._retrieve(query)
collected_results.extend(results)
except Exception as e:
logging.error(f"Error occurs when searching query {query}: {e}")
return collected_results


class SerperRM(dspy.Retrieve):
"""Retrieve information from custom queries using Serper.dev."""

def __init__(self, serper_search_api_key=None, query_params=None):
def __init__(
self,
serper_search_api_key=None,
k=3,
query_params=None,
ENABLE_EXTRA_SNIPPET_EXTRACTION=False,
min_char_count: int = 150,
snippet_chunk_size: int = 1000,
webpage_helper_max_threads=10,
):
"""Args:
serper_search_api_key str: API key to run serper, can be found by creating an account on https://serper.dev/
query_params (dict or list of dict): parameters in dictionary or list of dictionaries that has a max size of 100 that will be used to query.
@@ -355,9 +424,21 @@ def __init__(self, serper_search_api_key=None, query_params=None):
qdr:m str: Date time range for past month.
qdr:y str: Date time range for past year.
"""
super().__init__()
super().__init__(k=k)
self.usage = 0
self.query_params = query_params
self.query_params = None
self.ENABLE_EXTRA_SNIPPET_EXTRACTION = ENABLE_EXTRA_SNIPPET_EXTRACTION
self.webpage_helper = WebPageHelper(
min_char_count=min_char_count,
snippet_chunk_size=snippet_chunk_size,
max_thread_num=webpage_helper_max_threads,
)

if query_params is None:
self.query_params = {"num": k, "autocorrect": True, "page": 1}
else:
self.query_params = query_params
self.query_params.update({"num": k})
self.serper_search_api_key = serper_search_api_key
if not self.serper_search_api_key and not os.environ.get("SERPER_API_KEY"):
raise RuntimeError(
@@ -435,34 +516,41 @@ def forward(self, query_or_queries: Union[str, List[str]], exclude_urls: List[st
# Array of dictionaries that will be used by Storm to create the jsons
collected_results = []

if self.ENABLE_EXTRA_SNIPPET_EXTRACTION:
urls = []
for result in self.results:
organic_results = result.get("organic", [])
for organic in organic_results:
url = organic.get("link")
if url:
urls.append(url)
valid_url_to_snippets = self.webpage_helper.urls_to_snippets(urls)
else:
valid_url_to_snippets = {}

for result in self.results:
try:
# An array of dictionaries that contains the snippets, title of the document and url that will be used.
organic_results = result.get("organic")

knowledge_graph = result.get("knowledgeGraph")
for organic in organic_results:
snippets = []
snippets.append(organic.get("snippet"))
if knowledge_graph != None:
collected_results.append(
{
"snippets": snippets,
"title": organic.get("title"),
"url": organic.get("link"),
"description": knowledge_graph.get("description"),
}
)
else:
# Common for knowledge graph to be None, set description to empty string
collected_results.append(
{
"snippets": snippets,
"title": organic.get("title"),
"url": organic.get("link"),
"description": "",
}
snippets = [organic.get("snippet")]
if self.ENABLE_EXTRA_SNIPPET_EXTRACTION:
snippets.extend(
valid_url_to_snippets.get(url, {}).get("snippets", [])
)
collected_results.append(
{
"snippets": snippets,
"title": organic.get("title"),
"url": organic.get("link"),
"description": (
knowledge_graph.get("description")
if knowledge_graph is not None
else ""
),
}
)
except:
continue

48 changes: 41 additions & 7 deletions knowledge_storm/storm_wiki/engine.py
Original file line number Diff line number Diff line change
@@ -12,10 +12,9 @@
from .modules.knowledge_curation import StormKnowledgeCurationModule
from .modules.outline_generation import StormOutlineGenerationModule
from .modules.persona_generator import StormPersonaGenerator
from .modules.retriever import StormRetriever
from .modules.storm_dataclass import StormInformationTable, StormArticle
from ..interface import Engine, LMConfigs
from ..lm import OpenAIModel
from ..interface import Engine, LMConfigs, Retriever
from ..lm import OpenAIModel, AzureOpenAIModel
from ..utils import FileIOHelper, makeStringRed, truncate_filename


@@ -39,26 +38,35 @@ def __init__(self):
def init_openai_model(
self,
openai_api_key: str,
azure_api_key: str,
openai_type: Literal["openai", "azure"],
api_base: Optional[str] = None,
api_version: Optional[str] = None,
temperature: Optional[float] = 1.0,
top_p: Optional[float] = 0.9,
):
"""Legacy: Corresponding to the original setup in the NAACL'24 paper."""
azure_kwargs = {
"api_key": azure_api_key,
"temperature": temperature,
"top_p": top_p,
"api_base": api_base,
"api_version": api_version,
}

openai_kwargs = {
"api_key": openai_api_key,
"api_provider": openai_type,
"api_provider": "openai",
"temperature": temperature,
"top_p": top_p,
"api_base": None,
}
if openai_type and openai_type == "openai":
self.conv_simulator_lm = OpenAIModel(
model="gpt-3.5-turbo-instruct", max_tokens=500, **openai_kwargs
model="gpt-4o-mini-2024-07-18", max_tokens=500, **openai_kwargs
)
self.question_asker_lm = OpenAIModel(
model="gpt-3.5-turbo", max_tokens=500, **openai_kwargs
model="gpt-4o-mini-2024-07-18", max_tokens=500, **openai_kwargs
)
# 1/12/2024: Update gpt-4 to gpt-4-1106-preview. (Currently keep the original setup when using azure.)
self.outline_gen_lm = OpenAIModel(
@@ -70,6 +78,32 @@ def init_openai_model(
self.article_polish_lm = OpenAIModel(
model="gpt-4o-2024-05-13", max_tokens=4000, **openai_kwargs
)
elif openai_type and openai_type == "azure":
self.conv_simulator_lm = OpenAIModel(
model="gpt-4o-mini-2024-07-18", max_tokens=500, **openai_kwargs
)
self.question_asker_lm = AzureOpenAIModel(
model="gpt-4o-mini-2024-07-18",
max_tokens=500,
**azure_kwargs,
model_type="chat",
)
# use combination of openai and azure-openai as azure-openai does not support gpt-4 in standard deployment
self.outline_gen_lm = AzureOpenAIModel(
model="gpt-4o", max_tokens=400, **azure_kwargs, model_type="chat"
)
self.article_gen_lm = AzureOpenAIModel(
model="gpt-4o-mini-2024-07-18",
max_tokens=700,
**azure_kwargs,
model_type="chat",
)
self.article_polish_lm = AzureOpenAIModel(
model="gpt-4o-mini-2024-07-18",
max_tokens=4000,
**azure_kwargs,
model_type="chat",
)
else:
logging.warning(
"No valid OpenAI API provider is provided. Cannot use default LLM configurations."
@@ -145,7 +179,7 @@ def __init__(
self.args = args
self.lm_configs = lm_configs

self.retriever = StormRetriever(rm=rm, k=self.args.retrieve_top_k)
self.retriever = Retriever(rm=rm, max_thread=self.args.max_thread_num)
storm_persona_generator = StormPersonaGenerator(
self.lm_configs.question_asker_lm
)
12 changes: 4 additions & 8 deletions knowledge_storm/storm_wiki/modules/article_generation.py
Original file line number Diff line number Diff line change
@@ -7,8 +7,8 @@
import dspy

from .callback import BaseCallbackHandler
from .storm_dataclass import StormInformationTable, StormArticle, StormInformation
from ...interface import ArticleGenerationModule
from .storm_dataclass import StormInformationTable, StormArticle
from ...interface import ArticleGenerationModule, Information
from ...utils import ArticleTextProcessing


@@ -33,7 +33,7 @@ def __init__(
def generate_section(
self, topic, section_name, information_table, section_outline, section_query
):
collected_info: List[StormInformation] = []
collected_info: List[Information] = []
if information_table is not None:
collected_info = information_table.retrieve_information(
queries=section_query, search_top_k=self.retrieve_top_k
@@ -143,11 +143,7 @@ def __init__(self, engine: Union[dspy.dsp.LM, dspy.dsp.HFModel]):
self.engine = engine

def forward(
self,
topic: str,
outline: str,
section: str,
collected_info: List[StormInformation],
self, topic: str, outline: str, section: str, collected_info: List[Information]
):
info = ""
for idx, storm_info in enumerate(collected_info):
6 changes: 4 additions & 2 deletions knowledge_storm/storm_wiki/modules/article_polish.py
Original file line number Diff line number Diff line change
@@ -85,14 +85,16 @@ def __init__(
self.polish_page = dspy.Predict(PolishPage)

def forward(self, topic: str, draft_page: str, polish_whole_page: bool = True):
with dspy.settings.context(lm=self.write_lead_engine):
# NOTE: Change show_guidelines to false to make the generation more robust to different LM families.
with dspy.settings.context(lm=self.write_lead_engine, show_guidelines=False):
lead_section = self.write_lead(
topic=topic, draft_page=draft_page
).lead_section
if "The lead section:" in lead_section:
lead_section = lead_section.split("The lead section:")[1].strip()
if polish_whole_page:
with dspy.settings.context(lm=self.polish_engine):
# NOTE: Change show_guidelines to false to make the generation more robust to different LM families.
with dspy.settings.context(lm=self.polish_engine, show_guidelines=False):
page = self.polish_page(draft_page=draft_page).page
else:
page = draft_page
11 changes: 5 additions & 6 deletions knowledge_storm/storm_wiki/modules/knowledge_curation.py
Original file line number Diff line number Diff line change
@@ -8,8 +8,8 @@

from .callback import BaseCallbackHandler
from .persona_generator import StormPersonaGenerator
from .storm_dataclass import DialogueTurn, StormInformationTable, StormInformation
from ...interface import KnowledgeCurationModule, Retriever
from .storm_dataclass import DialogueTurn, StormInformationTable
from ...interface import KnowledgeCurationModule, Retriever, Information
from ...utils import ArticleTextProcessing

try:
@@ -166,7 +166,7 @@ class QuestionToQuery(dspy.Signature):

class AnswerQuestion(dspy.Signature):
"""You are an expert who can use information effectively. You are chatting with a Wikipedia writer who wants to write a Wikipedia page on topic you know. You have gathered the related information and will now use the information to form a response.
Make your response as informative as possible and make sure every sentence is supported by the gathered information. If [Gathered information] is not related to he [Topic] and [Question], output "Sorry, I don't have enough information to answer the question.".
Make your response as informative as possible, ensuring that every sentence is supported by the gathered information. If the [gathered information] is not directly related to the [topic] or [question], provide the most relevant answer based on the available information. If no appropriate answer can be formulated, respond with, “I cannot answer this question based on the available information,” and explain any limitations or gaps.
"""

topic = dspy.InputField(prefix="Topic you are discussing about:", format=str)
@@ -196,14 +196,13 @@ def __init__(
super().__init__()
self.generate_queries = dspy.Predict(QuestionToQuery)
self.retriever = retriever
self.retriever.update_search_top_k(search_top_k)
self.answer_question = dspy.Predict(AnswerQuestion)
self.engine = engine
self.max_search_queries = max_search_queries
self.search_top_k = search_top_k

def forward(self, topic: str, question: str, ground_truth_url: str):
with dspy.settings.context(lm=self.engine):
with dspy.settings.context(lm=self.engine, show_guidelines=False):
# Identify: Break down question into queries.
queries = self.generate_queries(topic=topic, question=question).queries
queries = [
@@ -212,7 +211,7 @@ def forward(self, topic: str, question: str, ground_truth_url: str):
]
queries = queries[: self.max_search_queries]
# Search
searched_results: List[StormInformation] = self.retriever.retrieve(
searched_results: List[Information] = self.retriever.retrieve(
list(set(queries)), exclude_urls=[ground_truth_url]
)
if len(searched_results) > 0:
24 changes: 0 additions & 24 deletions knowledge_storm/storm_wiki/modules/retriever.py
Original file line number Diff line number Diff line change
@@ -3,7 +3,6 @@

import dspy

from .storm_dataclass import StormInformation
from ...interface import Retriever, Information
from ...utils import ArticleTextProcessing

@@ -232,26 +231,3 @@ def is_valid_wikipedia_source(url):
return False

return True


class StormRetriever(Retriever):
def __init__(self, rm: dspy.Retrieve, k=3):
super().__init__(search_top_k=k)
self._rm = rm
if hasattr(rm, "is_valid_source"):
rm.is_valid_source = is_valid_wikipedia_source

def retrieve(
self, query: Union[str, List[str]], exclude_urls: List[str] = []
) -> List[Information]:
retrieved_data_list = self._rm(
query_or_queries=query, exclude_urls=exclude_urls
)
for data in retrieved_data_list:
for i in range(len(data["snippets"])):
# STORM generate the article with citations. We do not consider multi-hop citations.
# Remove citations in the source to avoid confusion.
data["snippets"][i] = ArticleTextProcessing.remove_citations(
data["snippets"][i]
)
return [StormInformation.from_dict(data) for data in retrieved_data_list]
75 changes: 9 additions & 66 deletions knowledge_storm/storm_wiki/modules/storm_dataclass.py
Original file line number Diff line number Diff line change
@@ -11,69 +11,13 @@
from ...utils import ArticleTextProcessing, FileIOHelper


class StormInformation(Information):
"""Class to represent detailed information.
Inherits from Information to include a unique identifier (URL), and extends
it with a description, snippets, and title of the storm information.
Attributes:
description (str): Brief description.
snippets (list): List of brief excerpts or snippets.
title (str): The title or headline of the information.
url (str): The unique URL (serving as UUID) of the information.
"""

def __init__(self, uuid, description, snippets, title):
"""Initialize the StormInformation object with detailed attributes.
Args:
uuid (str): The unique URL serving as the identifier for the information.
description (str): Detailed description.
snippets (list): List of brief excerpts or snippet.
title (str): The title or headline of the information.
"""
super().__init__(uuid=uuid, meta={})
self.description = description
self.snippets = snippets
self.title = title
self.url = self.uuid

@classmethod
def from_dict(cls, info_dict):
"""Create a StormInformation object from a dictionary.
Usage: storm_info = StormInformation.from_dict(storm_info_dict)
Args:
info_dict (dict): A dictionary containing keys 'uuid', 'description',
'snippets', and 'title' corresponding to the object's attributes.
Returns:
StormInformation: An instance of StormInformation.
"""
return cls(
info_dict["url"],
info_dict["description"],
info_dict["snippets"],
info_dict["title"],
)

def to_dict(self):
return {
"url": self.uuid,
"description": self.description,
"snippets": self.snippets,
"title": self.title,
}


class DialogueTurn:
def __init__(
self,
agent_utterance: str = None,
user_utterance: str = None,
search_queries: Optional[List[str]] = None,
search_results: Optional[List[Union[StormInformation, Dict]]] = None,
search_results: Optional[List[Union[Information, Dict]]] = None,
):
self.agent_utterance = agent_utterance
self.user_utterance = user_utterance
@@ -83,15 +27,14 @@ def __init__(
if self.search_results:
for idx in range(len(self.search_results)):
if type(self.search_results[idx]) == dict:
self.search_results[idx] = StormInformation.from_dict(
self.search_results[idx] = Information.from_dict(
self.search_results[idx]
)

def log(self):
"""
Returns a json object that contains all information inside `self`
"""

return OrderedDict(
{
"agent_utterance": self.agent_utterance,
@@ -115,14 +58,14 @@ class StormInformationTable(InformationTable):
def __init__(self, conversations=List[Tuple[str, List[DialogueTurn]]]):
super().__init__()
self.conversations = conversations
self.url_to_info: Dict[str, StormInformation] = (
self.url_to_info: Dict[str, Information] = (
StormInformationTable.construct_url_to_info(self.conversations)
)

@staticmethod
def construct_url_to_info(
conversations: List[Tuple[str, List[DialogueTurn]]]
) -> Dict[str, StormInformation]:
) -> Dict[str, Information]:
url_to_info = {}

for persona, conv in conversations:
@@ -177,7 +120,7 @@ def prepare_table_for_retrieval(self):

def retrieve_information(
self, queries: Union[List[str], str], search_top_k
) -> List[StormInformation]:
) -> List[Information]:
selected_urls = []
selected_snippets = []
if type(queries) is str:
@@ -231,13 +174,13 @@ def find_section(
return None

def _merge_new_info_to_references(
self, new_info_list: List[StormInformation], index_to_keep=None
self, new_info_list: List[Information], index_to_keep=None
) -> Dict[int, int]:
"""
Merges new storm information into existing references and updates the citation index mapping.
Args:
new_info_list (List[StormInformation]): A list of dictionaries representing new storm information.
new_info_list (List[Information]): A list of dictionaries representing new storm information.
index_to_keep (List[int]): A list of index of the new_info_list to keep. If none, keep all.
Returns:
@@ -308,7 +251,7 @@ def insert_or_create_section(
def update_section(
self,
current_section_content: str,
current_section_info_list: List[StormInformation],
current_section_info_list: List[Information],
parent_section_name: Optional[str] = None,
) -> Optional[ArticleSectionNode]:
"""
@@ -552,7 +495,7 @@ def from_string(cls, topic_name: str, article_text: str, references: dict):
article = cls(topic_name=topic_name)
article.insert_or_create_section(article_dict=article_dict)
for url in list(references["url_to_info"]):
references["url_to_info"][url] = StormInformation.from_dict(
references["url_to_info"][url] = Information.from_dict(
references["url_to_info"][url]
)
article.reference = references
Loading

0 comments on commit 564a507

Please sign in to comment.