diff --git a/lib/req_llm/embedding.ex b/lib/req_llm/embedding.ex index 80dd1c86..3860d658 100644 --- a/lib/req_llm/embedding.ex +++ b/lib/req_llm/embedding.ex @@ -174,10 +174,13 @@ defmodule ReqLLM.Embedding do def embed(model_spec, input, opts \\ []) def embed(model_spec, text, opts) when is_binary(text) do + {plugins, opts} = Keyword.pop(opts, :req_plugins, []) + with {:ok, model} <- validate_model(model_spec), :ok <- validate_input(text), {:ok, provider_module} <- ReqLLM.provider(model.provider), {:ok, request} <- provider_module.prepare_request(:embedding, model, text, opts), + request = apply_plugins(request, plugins), {:ok, %Req.Response{status: status, body: decoded_response}} when status in 200..299 <- Req.request(request) do extract_single_embedding(decoded_response) @@ -196,10 +199,13 @@ defmodule ReqLLM.Embedding do end def embed(model_spec, texts, opts) when is_list(texts) do + {plugins, opts} = Keyword.pop(opts, :req_plugins, []) + with {:ok, model} <- validate_model(model_spec), :ok <- validate_input(texts), {:ok, provider_module} <- ReqLLM.provider(model.provider), {:ok, request} <- provider_module.prepare_request(:embedding, model, texts, opts), + request = apply_plugins(request, plugins), {:ok, %Req.Response{status: status, body: decoded_response}} when status in 200..299 <- Req.request(request) do extract_multiple_embeddings(decoded_response) @@ -217,6 +223,10 @@ defmodule ReqLLM.Embedding do end end + defp apply_plugins(request, plugins) do + Enum.reduce(plugins, request, fn plugin, req -> plugin.(req) end) + end + defp validate_input("") do {:error, ReqLLM.Error.Invalid.Parameter.exception(parameter: "text: cannot be empty")} end diff --git a/lib/req_llm/generation.ex b/lib/req_llm/generation.ex index a12b0f61..02708ea0 100644 --- a/lib/req_llm/generation.ex +++ b/lib/req_llm/generation.ex @@ -51,6 +51,8 @@ defmodule ReqLLM.Generation do * `:tool_choice` - Tool choice strategy * `:system_prompt` - System prompt to prepend * `:provider_options` - Provider-specific options + * `:req_plugins` - List of functions to transform the `Req.Request` before execution. + Each function receives a `Req.Request` and returns a modified `Req.Request`. ## Examples @@ -70,9 +72,12 @@ defmodule ReqLLM.Generation do keyword() ) :: {:ok, Response.t()} | {:error, term()} def generate_text(model_spec, messages, opts \\ []) do + {plugins, opts} = Keyword.pop(opts, :req_plugins, []) + with {:ok, model} <- ReqLLM.model(model_spec), {:ok, provider_module} <- ReqLLM.provider(model.provider), {:ok, request} <- provider_module.prepare_request(:chat, model, messages, opts), + request = apply_plugins(request, plugins), {:ok, %Req.Response{status: status, body: decoded_response}} when status in 200..299 <- Req.request(request) do {:ok, decoded_response} @@ -213,6 +218,8 @@ defmodule ReqLLM.Generation do * `:frequency_penalty` - Penalize new tokens based on frequency * `:system_prompt` - System prompt to prepend * `:provider_options` - Provider-specific options + * `:req_plugins` - List of functions to transform the `Req.Request` before execution. + Each function receives a `Req.Request` and returns a modified `Req.Request`. ## Examples @@ -232,12 +239,15 @@ defmodule ReqLLM.Generation do keyword() ) :: {:ok, Response.t()} | {:error, term()} def generate_object(model_spec, messages, object_schema, opts \\ []) do + {plugins, opts} = Keyword.pop(opts, :req_plugins, []) + with {:ok, model} <- ReqLLM.model(model_spec), {:ok, provider_module} <- ReqLLM.provider(model.provider), {:ok, compiled_schema} <- ReqLLM.Schema.compile(object_schema), opts_with_schema = Keyword.put(opts, :compiled_schema, compiled_schema), {:ok, request} <- provider_module.prepare_request(:object, model, messages, opts_with_schema), + request = apply_plugins(request, plugins), {:ok, %Req.Response{status: status, body: decoded_response}} when status in 200..299 <- Req.request(request) do # For models with json.strict = false, coerce response types to match schema @@ -393,6 +403,10 @@ defmodule ReqLLM.Generation do defp coerce_value(value, _type), do: value + defp apply_plugins(request, plugins) do + Enum.reduce(plugins, request, fn plugin, req -> plugin.(req) end) + end + @doc """ Streams structured data generation using an AI model with schema validation. diff --git a/lib/req_llm/images.ex b/lib/req_llm/images.ex index f0abb8a0..ac323b8c 100644 --- a/lib/req_llm/images.ex +++ b/lib/req_llm/images.ex @@ -108,10 +108,13 @@ defmodule ReqLLM.Images do keyword() ) :: {:ok, Response.t()} | {:error, term()} def generate_image(model_spec, prompt_or_messages, opts \\ []) do + {plugins, opts} = Keyword.pop(opts, :req_plugins, []) + with {:ok, model} <- ReqLLM.model(model_spec), {:ok, provider_module} <- ReqLLM.provider(model.provider), {:ok, request} <- provider_module.prepare_request(:image, model, prompt_or_messages, opts), + request = apply_plugins(request, plugins), {:ok, %Req.Response{status: status, body: response}} when status in 200..299 <- Req.request(request) do {:ok, response} @@ -129,6 +132,10 @@ defmodule ReqLLM.Images do end end + defp apply_plugins(request, plugins) do + Enum.reduce(plugins, request, fn plugin, req -> plugin.(req) end) + end + @doc """ Returns a list of model specs that likely support image generation. diff --git a/test/req_llm/generation_test.exs b/test/req_llm/generation_test.exs index c5c9a205..8565f49b 100644 --- a/test/req_llm/generation_test.exs +++ b/test/req_llm/generation_test.exs @@ -321,4 +321,124 @@ defmodule ReqLLM.GenerationTest do assert %StreamResponse{} = response end end + + describe "req_plugins option" do + test "generate_text/3 applies plugins to the request" do + test_pid = self() + + Req.Test.stub(ReqLLM.GenerationTestPlugins, fn conn -> + custom_header = Plug.Conn.get_req_header(conn, "x-custom-header") + send(test_pid, {:custom_header, custom_header}) + + Req.Test.json(conn, %{ + "id" => "cmpl_test_123", + "model" => "gpt-4o-mini-2024-07-18", + "choices" => [ + %{ + "message" => %{"role" => "assistant", "content" => "Hello!"} + } + ], + "usage" => %{"prompt_tokens" => 10, "completion_tokens" => 5, "total_tokens" => 15} + }) + end) + + add_header_plugin = fn req -> + Req.Request.put_header(req, "x-custom-header", "custom-value") + end + + {:ok, response} = + Generation.generate_text( + "openai:gpt-4o-mini", + "Hello", + req_plugins: [add_header_plugin], + req_http_options: [plug: {Req.Test, ReqLLM.GenerationTestPlugins}] + ) + + assert %Response{} = response + assert_receive {:custom_header, ["custom-value"]} + end + + test "generate_text/3 applies multiple plugins in order" do + test_pid = self() + + Req.Test.stub(ReqLLM.GenerationTestPluginsOrder, fn conn -> + header1 = Plug.Conn.get_req_header(conn, "x-first-header") + header2 = Plug.Conn.get_req_header(conn, "x-second-header") + send(test_pid, {:headers, header1, header2}) + + Req.Test.json(conn, %{ + "id" => "cmpl_test_123", + "model" => "gpt-4o-mini-2024-07-18", + "choices" => [ + %{ + "message" => %{"role" => "assistant", "content" => "Hello!"} + } + ], + "usage" => %{"prompt_tokens" => 10, "completion_tokens" => 5, "total_tokens" => 15} + }) + end) + + first_plugin = fn req -> + Req.Request.put_header(req, "x-first-header", "first-value") + end + + second_plugin = fn req -> + Req.Request.put_header(req, "x-second-header", "second-value") + end + + {:ok, _response} = + Generation.generate_text( + "openai:gpt-4o-mini", + "Hello", + req_plugins: [first_plugin, second_plugin], + req_http_options: [plug: {Req.Test, ReqLLM.GenerationTestPluginsOrder}] + ) + + assert_receive {:headers, ["first-value"], ["second-value"]} + end + + test "generate_object/4 applies plugins to the request" do + test_pid = self() + + Req.Test.stub(ReqLLM.GenerationTestObjectPlugins, fn conn -> + custom_header = Plug.Conn.get_req_header(conn, "x-object-header") + send(test_pid, {:object_header, custom_header}) + + Req.Test.json(conn, %{ + "id" => "cmpl_test_123", + "model" => "gpt-4o-mini-2024-07-18", + "choices" => [ + %{ + "message" => %{ + "role" => "assistant", + "content" => Jason.encode!(%{"name" => "Alice", "age" => 30}) + } + } + ], + "usage" => %{"prompt_tokens" => 10, "completion_tokens" => 5, "total_tokens" => 15} + }) + end) + + add_header_plugin = fn req -> + Req.Request.put_header(req, "x-object-header", "object-value") + end + + schema = [ + name: [type: :string, required: true], + age: [type: :integer, required: true] + ] + + {:ok, response} = + Generation.generate_object( + "openai:gpt-4o-mini", + "Generate a person", + schema, + req_plugins: [add_header_plugin], + req_http_options: [plug: {Req.Test, ReqLLM.GenerationTestObjectPlugins}] + ) + + assert %Response{} = response + assert_receive {:object_header, ["object-value"]} + end + end end