Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions lib/req_llm/embedding.ex
Original file line number Diff line number Diff line change
Expand Up @@ -174,10 +174,13 @@ defmodule ReqLLM.Embedding do
def embed(model_spec, input, opts \\ [])

def embed(model_spec, text, opts) when is_binary(text) do
{plugins, opts} = Keyword.pop(opts, :req_plugins, [])

with {:ok, model} <- validate_model(model_spec),
:ok <- validate_input(text),
{:ok, provider_module} <- ReqLLM.provider(model.provider),
{:ok, request} <- provider_module.prepare_request(:embedding, model, text, opts),
request = apply_plugins(request, plugins),
{:ok, %Req.Response{status: status, body: decoded_response}} when status in 200..299 <-
Req.request(request) do
extract_single_embedding(decoded_response)
Expand All @@ -196,10 +199,13 @@ defmodule ReqLLM.Embedding do
end

def embed(model_spec, texts, opts) when is_list(texts) do
{plugins, opts} = Keyword.pop(opts, :req_plugins, [])

with {:ok, model} <- validate_model(model_spec),
:ok <- validate_input(texts),
{:ok, provider_module} <- ReqLLM.provider(model.provider),
{:ok, request} <- provider_module.prepare_request(:embedding, model, texts, opts),
request = apply_plugins(request, plugins),
{:ok, %Req.Response{status: status, body: decoded_response}} when status in 200..299 <-
Req.request(request) do
extract_multiple_embeddings(decoded_response)
Expand All @@ -217,6 +223,10 @@ defmodule ReqLLM.Embedding do
end
end

# Threads the request through each plugin in list order; an empty plugin
# list returns the request unchanged. Each plugin is a 1-arity function
# that receives a `Req.Request` and returns a (possibly modified) one.
defp apply_plugins(request, []), do: request
defp apply_plugins(request, [plugin | rest]), do: apply_plugins(plugin.(request), rest)

# Rejects an empty string up front so no provider request is built for it.
# Non-empty input is handled by sibling clauses not visible in this hunk.
defp validate_input("") do
  {:error, ReqLLM.Error.Invalid.Parameter.exception(parameter: "text: cannot be empty")}
end
Expand Down
14 changes: 14 additions & 0 deletions lib/req_llm/generation.ex
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,8 @@ defmodule ReqLLM.Generation do
* `:tool_choice` - Tool choice strategy
* `:system_prompt` - System prompt to prepend
* `:provider_options` - Provider-specific options
* `:req_plugins` - List of functions to transform the `Req.Request` before execution.
Each function receives a `Req.Request` and returns a modified `Req.Request`.

## Examples

Expand All @@ -70,9 +72,12 @@ defmodule ReqLLM.Generation do
keyword()
) :: {:ok, Response.t()} | {:error, term()}
def generate_text(model_spec, messages, opts \\ []) do
{plugins, opts} = Keyword.pop(opts, :req_plugins, [])

with {:ok, model} <- ReqLLM.model(model_spec),
{:ok, provider_module} <- ReqLLM.provider(model.provider),
{:ok, request} <- provider_module.prepare_request(:chat, model, messages, opts),
request = apply_plugins(request, plugins),
{:ok, %Req.Response{status: status, body: decoded_response}} when status in 200..299 <-
Req.request(request) do
{:ok, decoded_response}
Expand Down Expand Up @@ -213,6 +218,8 @@ defmodule ReqLLM.Generation do
* `:frequency_penalty` - Penalize new tokens based on frequency
* `:system_prompt` - System prompt to prepend
* `:provider_options` - Provider-specific options
* `:req_plugins` - List of functions to transform the `Req.Request` before execution.
Each function receives a `Req.Request` and returns a modified `Req.Request`.

## Examples

Expand All @@ -232,12 +239,15 @@ defmodule ReqLLM.Generation do
keyword()
) :: {:ok, Response.t()} | {:error, term()}
def generate_object(model_spec, messages, object_schema, opts \\ []) do
{plugins, opts} = Keyword.pop(opts, :req_plugins, [])

with {:ok, model} <- ReqLLM.model(model_spec),
{:ok, provider_module} <- ReqLLM.provider(model.provider),
{:ok, compiled_schema} <- ReqLLM.Schema.compile(object_schema),
opts_with_schema = Keyword.put(opts, :compiled_schema, compiled_schema),
{:ok, request} <-
provider_module.prepare_request(:object, model, messages, opts_with_schema),
request = apply_plugins(request, plugins),
{:ok, %Req.Response{status: status, body: decoded_response}} when status in 200..299 <-
Req.request(request) do
# For models with json.strict = false, coerce response types to match schema
Expand Down Expand Up @@ -393,6 +403,10 @@ defmodule ReqLLM.Generation do

defp coerce_value(value, _type), do: value

# Folds the plugin list over the request: each plugin is a 1-arity function
# applied to the accumulated `Req.Request`; `[]` leaves the request as-is.
defp apply_plugins(request, plugins) do
  Enum.reduce(plugins, request, & &1.(&2))
end

@doc """
Streams structured data generation using an AI model with schema validation.

Expand Down
7 changes: 7 additions & 0 deletions lib/req_llm/images.ex
Original file line number Diff line number Diff line change
Expand Up @@ -108,10 +108,13 @@ defmodule ReqLLM.Images do
keyword()
) :: {:ok, Response.t()} | {:error, term()}
def generate_image(model_spec, prompt_or_messages, opts \\ []) do
{plugins, opts} = Keyword.pop(opts, :req_plugins, [])

with {:ok, model} <- ReqLLM.model(model_spec),
{:ok, provider_module} <- ReqLLM.provider(model.provider),
{:ok, request} <-
provider_module.prepare_request(:image, model, prompt_or_messages, opts),
request = apply_plugins(request, plugins),
{:ok, %Req.Response{status: status, body: response}} when status in 200..299 <-
Req.request(request) do
{:ok, response}
Expand All @@ -129,6 +132,10 @@ defmodule ReqLLM.Images do
end
end

# Applies each plugin function to the request in order, piping the result of
# one plugin into the next; an empty list is the identity transformation.
defp apply_plugins(request, plugins) do
  Enum.reduce(plugins, request, fn transform, acc -> transform.(acc) end)
end

@doc """
Returns a list of model specs that likely support image generation.

Expand Down
120 changes: 120 additions & 0 deletions test/req_llm/generation_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -321,4 +321,124 @@ defmodule ReqLLM.GenerationTest do
assert %StreamResponse{} = response
end
end

describe "req_plugins option" do
test "generate_text/3 applies plugins to the request" do
  parent = self()

  # Plugin under test: stamps a custom header onto the outgoing request.
  plugin = fn req ->
    Req.Request.put_header(req, "x-custom-header", "custom-value")
  end

  # Stub echoes the received header back to the test process so we can
  # assert the plugin actually ran before the HTTP call.
  Req.Test.stub(ReqLLM.GenerationTestPlugins, fn conn ->
    send(parent, {:custom_header, Plug.Conn.get_req_header(conn, "x-custom-header")})

    Req.Test.json(conn, %{
      "id" => "cmpl_test_123",
      "model" => "gpt-4o-mini-2024-07-18",
      "choices" => [
        %{
          "message" => %{"role" => "assistant", "content" => "Hello!"}
        }
      ],
      "usage" => %{"prompt_tokens" => 10, "completion_tokens" => 5, "total_tokens" => 15}
    })
  end)

  {:ok, response} =
    Generation.generate_text(
      "openai:gpt-4o-mini",
      "Hello",
      req_plugins: [plugin],
      req_http_options: [plug: {Req.Test, ReqLLM.GenerationTestPlugins}]
    )

  assert %Response{} = response
  assert_receive {:custom_header, ["custom-value"]}
end

test "generate_text/3 applies multiple plugins in order" do
  parent = self()

  # Two plugins, each adding a distinct header; both must be visible on the
  # request that reaches the HTTP layer.
  plugins = [
    fn req -> Req.Request.put_header(req, "x-first-header", "first-value") end,
    fn req -> Req.Request.put_header(req, "x-second-header", "second-value") end
  ]

  Req.Test.stub(ReqLLM.GenerationTestPluginsOrder, fn conn ->
    send(
      parent,
      {:headers, Plug.Conn.get_req_header(conn, "x-first-header"),
       Plug.Conn.get_req_header(conn, "x-second-header")}
    )

    Req.Test.json(conn, %{
      "id" => "cmpl_test_123",
      "model" => "gpt-4o-mini-2024-07-18",
      "choices" => [
        %{
          "message" => %{"role" => "assistant", "content" => "Hello!"}
        }
      ],
      "usage" => %{"prompt_tokens" => 10, "completion_tokens" => 5, "total_tokens" => 15}
    })
  end)

  {:ok, _response} =
    Generation.generate_text(
      "openai:gpt-4o-mini",
      "Hello",
      req_plugins: plugins,
      req_http_options: [plug: {Req.Test, ReqLLM.GenerationTestPluginsOrder}]
    )

  assert_receive {:headers, ["first-value"], ["second-value"]}
end

test "generate_object/4 applies plugins to the request" do
  parent = self()

  # Plugin under test: adds a header that the stub reports back to us.
  plugin = fn req ->
    Req.Request.put_header(req, "x-object-header", "object-value")
  end

  person_schema = [
    name: [type: :string, required: true],
    age: [type: :integer, required: true]
  ]

  Req.Test.stub(ReqLLM.GenerationTestObjectPlugins, fn conn ->
    send(parent, {:object_header, Plug.Conn.get_req_header(conn, "x-object-header")})

    Req.Test.json(conn, %{
      "id" => "cmpl_test_123",
      "model" => "gpt-4o-mini-2024-07-18",
      "choices" => [
        %{
          "message" => %{
            "role" => "assistant",
            "content" => Jason.encode!(%{"name" => "Alice", "age" => 30})
          }
        }
      ],
      "usage" => %{"prompt_tokens" => 10, "completion_tokens" => 5, "total_tokens" => 15}
    })
  end)

  {:ok, response} =
    Generation.generate_object(
      "openai:gpt-4o-mini",
      "Generate a person",
      person_schema,
      req_plugins: [plugin],
      req_http_options: [plug: {Req.Test, ReqLLM.GenerationTestObjectPlugins}]
    )

  assert %Response{} = response
  assert_receive {:object_header, ["object-value"]}
end
end
end