Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

AI prototype with new bindings #543

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions worker-sys/src/types.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
mod ai;
mod context;
#[cfg(feature = "d1")]
mod d1;
Expand All @@ -14,6 +15,7 @@ mod socket;
mod tls_client_auth;
mod websocket_pair;

pub use ai::*;
pub use context::*;
#[cfg(feature = "d1")]
pub use d1::*;
Expand Down
12 changes: 12 additions & 0 deletions worker-sys/src/types/ai.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
use js_sys::Promise;
use wasm_bindgen::prelude::*;

#[wasm_bindgen]
extern "C" {
#[wasm_bindgen(extends=::js_sys::Object, js_name=Ai)]
#[derive(Debug, Clone, PartialEq, Eq)]
pub type Ai;

#[wasm_bindgen(structural, method, js_class=Ai, js_name=run)]
pub fn run(this: &Ai, model: &str, input: JsValue) -> Promise;
}
121 changes: 121 additions & 0 deletions worker/src/ai.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
use crate::{env::EnvBinding, send::SendFuture};
use crate::{Error, Result};
use serde::de::DeserializeOwned;
use serde::{Deserialize, Serialize};
use wasm_bindgen::{JsCast, JsValue};
use wasm_bindgen_futures::JsFuture;
use worker_sys::{console_log, Ai as AiSys};

pub struct Ai(AiSys);

#[derive(Deserialize)]
struct TextEmbeddingOutput {
data: Vec<Vec<f64>>,
}

#[derive(Serialize)]
pub struct TextEmbeddingInput<'a> {
text: Vec<&'a str>,
}

impl Ai {
pub async fn run<T: Serialize, U: DeserializeOwned>(&self, model: &str, input: T) -> Result<U> {
let fut = SendFuture::new(JsFuture::from(
self.0.run(model, serde_wasm_bindgen::to_value(&input)?),
));
match fut.await {
Ok(output) => Ok(serde_wasm_bindgen::from_value(output)?),
Err(err) => Err(Error::from(err)),
}
}

pub async fn embed<'a, S: AsRef<str> + 'a, T: IntoIterator<Item = S>>(
&self,
model: &str,
input: T,
) -> Result<Vec<Vec<f64>>> {
let iter = input.into_iter();
let items: Vec<S> = iter.collect();
let text = items.iter().map(|s| s.as_ref()).collect();
let arg = TextEmbeddingInput { text };
self.run(model, arg)
.await
.map(|out: TextEmbeddingOutput| out.data)
}
}

unsafe impl Sync for Ai {}
unsafe impl Send for Ai {}

impl From<AiSys> for Ai {
fn from(inner: AiSys) -> Self {
Self(inner)
}
}

impl AsRef<JsValue> for Ai {
fn as_ref(&self) -> &JsValue {
&self.0
}
}

impl From<Ai> for JsValue {
fn from(database: Ai) -> Self {
JsValue::from(database.0)
}
}

impl JsCast for Ai {
fn instanceof(val: &JsValue) -> bool {
val.is_instance_of::<AiSys>()
}

fn unchecked_from_js(val: JsValue) -> Self {
Self(val.into())
}

fn unchecked_from_js_ref(val: &JsValue) -> &Self {
unsafe { &*(val as *const JsValue as *const Self) }
}
}

impl EnvBinding for Ai {
const TYPE_NAME: &'static str = "Ai";

// Workaround for Miniflare D1 Beta
fn get(val: JsValue) -> Result<Self> {
let obj = js_sys::Object::from(val);
console_log!("{}", obj.constructor().name());
if obj.constructor().name() == Self::TYPE_NAME {
Ok(obj.unchecked_into())
} else {
Err(format!(
"Binding cannot be cast to the type {} from {}",
Self::TYPE_NAME,
obj.constructor().name()
)
.into())
}
}
}

#[cfg(test)]
mod test {
use super::Ai;
use wasm_bindgen::JsCast;

#[test]
#[allow(unused_must_use)]
fn text_embedding_input_from() {
let ai: Ai = js_sys::Object::new().unchecked_into();

let s: &str = "foo";

ai.embed("foo-model", [s]);
ai.embed("foo-model", [s.to_owned()]);
ai.embed("foo-model", [&(s.to_owned())]);
ai.embed("foo-model", &[s]);
ai.embed("foo-model", vec![s]);
ai.embed("foo-model", &vec![s]);
}
}
6 changes: 6 additions & 0 deletions worker/src/env.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ use crate::error::Error;
use crate::Queue;
use crate::{durable::ObjectNamespace, Bucket, DynamicDispatcher, Fetcher, Result};

use crate::Ai;

use js_sys::Object;
use wasm_bindgen::{prelude::*, JsCast, JsValue};
use worker_kv::KvStore;
Expand Down Expand Up @@ -34,6 +36,10 @@ impl Env {
}
}

pub fn ai(&self, binding: &str) -> Result<Ai> {
self.get_binding::<Ai>(binding)
}

/// Access Secret value bindings added to your Worker via the UI or `wrangler`:
/// <https://developers.cloudflare.com/workers/cli-wrangler/commands#secret>
pub fn secret(&self, binding: &str) -> Result<Secret> {
Expand Down
2 changes: 2 additions & 0 deletions worker/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,7 @@ pub use worker_sys;
pub use worker_sys::{console_debug, console_error, console_log, console_warn};

pub use crate::abort::*;
pub use crate::ai::*;
pub use crate::cache::{Cache, CacheDeletionOutcome, CacheKey};
pub use crate::context::Context;
pub use crate::cors::Cors;
Expand Down Expand Up @@ -152,6 +153,7 @@ pub use crate::streams::*;
pub use crate::websocket::*;

mod abort;
mod ai;
mod cache;
mod cf;
mod context;
Expand Down
Loading