Skip to content

Commit

Permalink
Add deployments (#18)
Browse files Browse the repository at this point in the history
* deployment behavior

* fix deployments

* version not necessary

* fix behavior for deployments

* remove unused import

* update tests

* add readme explanation

* update version to 1.1.1
  • Loading branch information
cbh123 authored Sep 16, 2023
1 parent 1cea0e9 commit 162e6b1
Show file tree
Hide file tree
Showing 7 changed files with 138 additions and 4 deletions.
13 changes: 12 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ Install by adding `replicate` to your list of dependencies in `mix.exs`:
```elixir
def deps do
[
{:replicate, "~> 1.1.0"}
{:replicate, "~> 1.1.1"}
]
end
```
Expand Down Expand Up @@ -219,4 +219,15 @@ iex> {{_, 200, 'OK'}, _headers, body} = resp
iex> File.write!("babadook_watercolor.jpg", body)
```

## Create prediction from deployment

Deployments allow you to control the configuration of a model with a private, fixed API endpoint. You can control the version of the model, the hardware it runs on, and how it scales.

Once you create a deployment on Replicate, you can make predictions like this:

```elixir
iex> {:ok, deployment} = Replicate.Deployments.get("test/model")
iex> {:ok, prediction} = Replicate.Deployments.create_prediction(deployment, %{prompt: "a 19th century portrait of a wombat gentleman"})
```

# replicate-elixir
85 changes: 85 additions & 0 deletions lib/deployments.ex
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
defmodule Replicate.Deployments do
@moduledoc """
Documentation for `Predictions`.
"""
@behaviour Replicate.Deployments.Behaviour
@replicate_client Application.compile_env(:replicate, :replicate_client, Replicate.Client)

alias Replicate.Deployments.Deployment
alias Replicate.Predictions.Prediction

@doc """
Gets a deployment by name, in the format `owner/model-name`.
## Examples
```
iex> {:ok, deployment} = Replicate.Deployments.get("test/model")
iex> deployment.username
"test"
iex> Replicate.Predictions.get("not_a_real_id")
{:error, "Not found"}
```
"""
def get(name) do
[owner, model_name] = String.split(name, "/")
{:ok, %Deployment{username: owner, name: model_name}}
end

@doc """
Create a new prediction with the deployment. The input parameter should be a map of the model inputs.
## Examples
```
iex> {:ok, deployment} = Replicate.Deployments.get("test/model")
iex> {:ok, prediction} = Replicate.Deployments.create_prediction(deployment, %{prompt: "a 19th century portrait of a wombat gentleman"})
iex> prediction.status
"starting"
```
"""
def create_prediction(
%Deployment{username: username, name: name},
input,
webhook \\ nil,
webhook_completed \\ nil,
webhook_event_filter \\ nil,
stream \\ nil
) do
webhook_parameters =
%{
"webhook" => webhook,
"webhook_completed" => webhook_completed,
"webhook_event_filter" => webhook_event_filter,
"stream" => stream
}
|> Enum.filter(fn {_key, value} -> !is_nil(value) end)
|> Enum.into(%{})

body =
%{
"input" => input |> Enum.into(%{})
}
|> Map.merge(webhook_parameters)
|> Jason.encode!()

@replicate_client.request(:post, "/v1/deployments/#{username}/#{name}/predictions", body)
|> parse_response()
end

defp parse_response({:ok, json_body}) do
body =
json_body
|> Jason.decode!()
|> string_to_atom()

{:ok, struct(Prediction, body)}
end

defp parse_response({:error, message}), do: {:error, message}

defp string_to_atom(body) do
for {k, v} <- body, into: %{}, do: {String.to_atom(k), v}
end
end
17 changes: 17 additions & 0 deletions lib/deployments/behaviour.ex
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
defmodule Replicate.Deployments.Behaviour do
@moduledoc """
Documentation for the Deployment Behaviour.
"""
alias Replicate.Deployments.Deployment

@callback get(String.t()) :: {:ok, Deployment.t()} | {:error, String.t()}
@callback create_prediction(
Deployment.t(),
input :: %{string: any},
webhook :: list(String.t()),
webhook_completed :: list(String.t()),
webook_event_filter :: list(String.t()),
stream :: boolean()
) ::
{:ok, Replicate.Predictions.Prediction.t()} | {:error, String.t()}
end
9 changes: 9 additions & 0 deletions lib/deployments/deployment.ex
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
defmodule Replicate.Deployments.Deployment do
@moduledoc """
`Deployment` struct.
"""
defstruct [
:username,
:name
]
end
4 changes: 2 additions & 2 deletions lib/mock_client.ex
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ defmodule Replicate.MockClient do
],
urls: %{
"get" => "https://api.replicate.com/v1/predictions/1234",
"cancel" => "https://api.replicate.com/v1/predictions/1234/cancel",
"cancel" => "https://api.replicate.com/v1/predictions/1234/cancel"
}
}
@stub_prediction2 %{
Expand All @@ -27,7 +27,7 @@ defmodule Replicate.MockClient do
],
urls: %{
"get" => "https://api.replicate.com/v1/predictions/1235",
"cancel" => "https://api.replicate.com/v1/predictions/1235/cancel",
"cancel" => "https://api.replicate.com/v1/predictions/1235/cancel"
}
}

Expand Down
2 changes: 1 addition & 1 deletion mix.exs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ defmodule Replicate.MixProject do
def project do
[
app: :replicate,
version: "1.1.0",
version: "1.1.1",
elixir: "~> 1.14",
start_permanent: Mix.env() == :prod,
start_permanent: Mix.env() == :prod,
Expand Down
12 changes: 12 additions & 0 deletions test/replicate_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ defmodule ReplicateTest do
doctest Replicate
doctest Replicate.Predictions
doctest Replicate.Models
doctest Replicate.Deployments

# Make sure mocks are verified when the test exits
setup :verify_on_exit!
Expand Down Expand Up @@ -79,4 +80,15 @@ defmodule ReplicateTest do
assert first_version.id == "v1"
assert first_version.cog_version == "0.3.0"
end

test "create a deployment prediction" do
{:ok, deployment} = Replicate.Deployments.get("test/model")

{:ok, prediction} =
Replicate.Deployments.create_prediction(deployment, %{
prompt: "a 19th century portrait of a wombat gentleman"
})

assert prediction.status == "starting"
end
end

0 comments on commit 162e6b1

Please sign in to comment.