diff --git a/libs/langgraph/tests/test_pregel.py b/libs/langgraph/tests/test_pregel.py index 67536deed3..570f725215 100644 --- a/libs/langgraph/tests/test_pregel.py +++ b/libs/langgraph/tests/test_pregel.py @@ -5491,3 +5491,241 @@ def invoke_sub_agent(state: AgentState): "invoke_sub_agent": {"input": True}, }, ] + + +def test_map_reduce() -> None: + """Test map reduce with Graph API.""" + import operator + from typing import Annotated, NotRequired, TypedDict + + class OverallState(TypedDict): + subjects: list + # Notice here we use the operator.add + # This is because we want combine all the jokes we generate + # from individual nodes back into one list - this is essentially + # the "reduce" part + jokes_with_grades: Annotated[list, operator.add] + best_selected_joke: str + + # This will be the state of the node that we will "map" all + # subjects to in order to generate a joke + class JokeState(TypedDict): + subject: str + joke: NotRequired[str] + joke_grade: NotRequired[int] + + def generate_jokes(state: OverallState): + # Distribute the subjects to the generate_joke node + return [Send("generate_joke", {"subject": s}) for s in state["subjects"]] + + def generate_joke(state: JokeState): + # Come up with a joke and continue to grade the joke + funny_joke = f"Funny joke about {state['subject']}" + return [Send("grade_joke", {"joke": funny_joke, "subject": state["subject"]})] + + def grade_joke(state: JokeState): + # Grade the joke and then send it to the count_grade node which will reduce + if state["subject"].startswith("a"): + grade = 10 + elif state["subject"].startswith("b"): + grade = 9 + else: + grade = 1 + return { + "jokes_with_grades": [ + { + "joke": state["joke"], + "grade": grade, + } + ] + } + + def count_grade(state: OverallState): + max_score = 0 + best_joke = "" + for joke in state["jokes_with_grades"]: + if joke["grade"] > max_score: + max_score = joke["grade"] + best_joke = joke["joke"] + return {"best_selected_joke": best_joke} + + builder = StateGraph(OverallState) + builder.add_conditional_edges("__start__", generate_jokes, ["generate_joke"]) + builder.add_conditional_edges("generate_joke", grade_joke, ["count_grade"]) + builder.add_node(grade_joke) + builder.add_node(count_grade) + builder.add_edge("count_grade", END) + graph = builder.compile() + result = graph.invoke({"subjects": ["apple", "banana", "carrot"]}) + assert result == { + "subjects": ["apple", "banana", "carrot"], + "jokes_with_grades": [ + {"joke": "Funny joke about apple", "grade": 10}, + {"joke": "Funny joke about banana", "grade": 9}, + {"joke": "Funny joke about carrot", "grade": 1}, + ], + "best_selected_joke": "Funny joke about apple", + } + + +def test_map_reduce_v2() -> None: + """Test map reduce with Graph API.""" + import operator + from typing import Annotated, NotRequired, TypedDict + from langgraph.graph import StateGraph + from langgraph.types import Send, Command + + class OverallState(TypedDict): + subjects: list + # Notice here we use the operator.add + # This is because we want combine all the jokes we generate + # from individual nodes back into one list - this is essentially + # the "reduce" part + jokes_with_grades: Annotated[list, operator.add] + best_selected_joke: str + + # This will be the state of the node that we will "map" all + # subjects to in order to generate a joke + class JokeState(TypedDict): + subject: str + joke: NotRequired[str] + joke_grade: NotRequired[int] + + def generate_jokes(state: OverallState): + # Distribute the subjects to the generate_joke node + return [Send("generate_joke", {"subject": s}) for s in state["subjects"]] + + def generate_joke(state: JokeState): + # Come up with a joke and continue to grade the joke + funny_joke = f"Funny joke about {state['subject']}" + return Command( + goto=[Send("grade_joke", {"joke": funny_joke, "subject": state["subject"]})] + ) + + def grade_joke(state: JokeState): + # Grade the joke and then send it to the count_grade node which will reduce + if state["subject"].startswith("a"): + grade = 10 + elif state["subject"].startswith("b"): + grade = 9 + else: + grade = 1 + return { + "jokes_with_grades": [ + { + "joke": state["joke"], + "grade": grade, + } + ] + } + + def count_grade(state: OverallState): + max_score = 0 + best_joke = "" + for joke in state["jokes_with_grades"]: + if joke["grade"] > max_score: + max_score = joke["grade"] + best_joke = joke["joke"] + return {"best_selected_joke": best_joke} + + builder = StateGraph(OverallState) + builder.add_node(generate_joke) + builder.add_node(grade_joke) + builder.add_node(count_grade) + builder.add_conditional_edges("__start__", generate_jokes, ["generate_joke"]) + builder.add_edge("grade_joke", "count_grade") + builder.add_edge("count_grade", END) + graph = builder.compile() + result = graph.invoke({"subjects": ["apple", "banana", "carrot"]}) + assert result == { + "subjects": ["apple", "banana", "carrot"], + "jokes_with_grades": [ + {"joke": "Funny joke about apple", "grade": 10}, + {"joke": "Funny joke about banana", "grade": 9}, + {"joke": "Funny joke about carrot", "grade": 1}, + ], + "best_selected_joke": "Funny joke about apple", + } + + +def test_map_reduce_v3() -> None: + """Test map reduce with Graph API.""" + import operator + from typing import Annotated, NotRequired, TypedDict + from langgraph.graph import StateGraph + from langgraph.types import Send, Command + + class OverallState(TypedDict): + subjects: list + # Notice here we use the operator.add + # This is because we want combine all the jokes we generate + # from individual nodes back into one list - this is essentially + # the "reduce" part + jokes_with_grades: Annotated[list, operator.add] + best_selected_joke: str + + # This will be the state of the node that we will "map" all + # subjects to in order to generate a joke + class JokeState(TypedDict): + subject: str + joke: NotRequired[str] + joke_grade: NotRequired[int] + + def generate_jokes(state: OverallState): + # Distribute the subjects to the generate_joke node + sends = [Send("generate_joke", {"subject": s}) for s in state["subjects"]] + return Command(goto=sends) + + def generate_joke(state: JokeState): + # Come up with a joke and continue to grade the joke + funny_joke = f"Funny joke about {state['subject']}" + return Command( + goto=[Send("grade_joke", {"joke": funny_joke, "subject": state["subject"]})] + ) + + def grade_joke(state: JokeState): + # Grade the joke and then send it to the count_grade node which will reduce + if state["subject"].startswith("a"): + grade = 10 + elif state["subject"].startswith("b"): + grade = 9 + else: + grade = 1 + return { + "jokes_with_grades": [ + { + "joke": state["joke"], + "grade": grade, + } + ] + } + + def count_grade(state: OverallState): + max_score = 0 + best_joke = "" + for joke in state["jokes_with_grades"]: + if joke["grade"] > max_score: + max_score = joke["grade"] + best_joke = joke["joke"] + return {"best_selected_joke": best_joke} + + builder = StateGraph(OverallState) + builder.add_node(generate_jokes) + builder.add_node(generate_joke) + builder.add_node(grade_joke) + builder.add_node(count_grade) + builder.add_edge("__start__", "generate_jokes") + builder.add_edge("generate_jokes", "generate_joke") + builder.add_edge("grade_joke", "count_grade") + builder.add_edge("count_grade", END) + graph = builder.compile() + result = graph.invoke({"subjects": ["apple", "banana", "carrot"]}) + assert result == { + "subjects": ["apple", "banana", "carrot"], + "jokes_with_grades": [ + {"joke": "Funny joke about apple", "grade": 10}, + {"joke": "Funny joke about banana", "grade": 9}, + {"joke": "Funny joke about carrot", "grade": 1}, + ], + "best_selected_joke": "Funny joke about apple", + }