Commit df23062

improve endpoint support (#1577)
Small PR to add a new interface endpoint behind a feature flag.
1 parent d19c768 commit df23062

File tree

router/Cargo.toml
router/src/lib.rs
router/src/main.rs
router/src/server.rs

4 files changed: +170 −7 lines


router/Cargo.toml

Lines changed: 1 addition & 0 deletions
```diff
@@ -52,3 +52,4 @@ vergen = { version = "8.2.5", features = ["build", "git", "gitcl"] }
 [features]
 default = ["ngrok"]
 ngrok = ["dep:ngrok"]
+google = []
```
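The new `google` feature is empty: it pulls in no dependencies and acts purely as a compile-time switch, enabled with `cargo build --features google`. A minimal sketch of how such a flag gates code, assuming a crate that declares the same `google = []` feature in its own `[features]` table:

```rust
// Sketch: an empty Cargo feature drives conditional compilation.
fn main() {
    // Items and statements can be compiled in or out entirely...
    #[cfg(feature = "google")]
    println!("built with the `google` feature");

    // ...or branched on with cfg!(), which expands to a compile-time
    // bool; both arms must still type-check either way.
    if cfg!(feature = "google") {
        // Vertex-specific setup would go here.
    } else {
        // Default setup.
    }
}
```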

router/src/lib.rs

Lines changed: 21 additions & 2 deletions
```diff
@@ -20,6 +20,25 @@ pub(crate) type GenerateStreamResponse = (
     UnboundedReceiverStream<Result<InferStreamResponse, InferError>>,
 );
 
+#[derive(Clone, Deserialize, ToSchema)]
+pub(crate) struct VertexInstance {
+    #[schema(example = "What is Deep Learning?")]
+    pub inputs: String,
+    #[schema(nullable = true, default = "null", example = "null")]
+    pub parameters: Option<GenerateParameters>,
+}
+
+#[derive(Deserialize, ToSchema)]
+pub(crate) struct VertexRequest {
+    #[serde(rename = "instances")]
+    pub instances: Vec<VertexInstance>,
+}
+
+#[derive(Clone, Deserialize, ToSchema, Serialize)]
+pub(crate) struct VertexResponse {
+    pub predictions: Vec<String>,
+}
+
 /// Hub type
 #[derive(Clone, Debug, Deserialize)]
 pub struct HubModelInfo {
@@ -70,7 +89,7 @@ mod json_object_or_string_to_string {
     }
 }
 
-#[derive(Clone, Debug, Deserialize)]
+#[derive(Clone, Debug, Deserialize, ToSchema)]
 #[serde(tag = "type", content = "value")]
 pub(crate) enum GrammarType {
     #[serde(
@@ -153,7 +172,7 @@ pub struct Info {
     pub docker_label: Option<&'static str>,
 }
 
-#[derive(Clone, Debug, Deserialize, ToSchema)]
+#[derive(Clone, Debug, Deserialize, ToSchema, Default)]
 pub(crate) struct GenerateParameters {
     #[serde(default)]
     #[schema(exclusive_minimum = 0, nullable = true, default = "null", example = 1)]
```
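The three new types mirror the Vertex AI prediction protocol, which wraps prompts in an `instances` array and answers with a parallel `predictions` array. A self-contained sketch of that round trip, assuming `serde` and `serde_json` as dependencies and substituting `serde_json::Value` for the router's `GenerateParameters`:

```rust
use serde::{Deserialize, Serialize};

// Illustrative mirrors of the router's pub(crate) types (which also
// derive ToSchema for the OpenAPI spec).
#[derive(Deserialize)]
struct VertexInstance {
    inputs: String,
    parameters: Option<serde_json::Value>, // GenerateParameters in the router
}

#[derive(Deserialize)]
struct VertexRequest {
    instances: Vec<VertexInstance>,
}

#[derive(Serialize)]
struct VertexResponse {
    predictions: Vec<String>,
}

fn main() -> serde_json::Result<()> {
    // Vertex AI sends one JSON body holding a batch of instances.
    let body = r#"{
        "instances": [
            {"inputs": "What is Deep Learning?", "parameters": {"max_new_tokens": 20}}
        ]
    }"#;
    let req: VertexRequest = serde_json::from_str(body)?;
    assert_eq!(req.instances[0].inputs, "What is Deep Learning?");

    // The response is a parallel array of generated texts.
    let resp = VertexResponse { predictions: vec!["...".to_string()] };
    println!("{}", serde_json::to_string(&resp)?);
    Ok(())
}
```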

router/src/main.rs

Lines changed: 9 additions & 0 deletions
```diff
@@ -328,6 +328,15 @@ async fn main() -> Result<(), RouterError> {
     tracing::info!("Setting max batch total tokens to {max_supported_batch_total_tokens}");
     tracing::info!("Connected");
 
+    // Determine the server port based on the feature and environment variable.
+    let port = if cfg!(feature = "google") {
+        std::env::var("AIP_HTTP_PORT")
+            .map(|aip_http_port| aip_http_port.parse::<u16>().unwrap_or(port))
+            .unwrap_or(port)
+    } else {
+        port
+    };
+
     let addr = match hostname.parse() {
         Ok(ip) => SocketAddr::new(ip, port),
         Err(_) => {
```
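On Vertex AI, the serving container must listen on the port the platform injects via `AIP_HTTP_PORT`; the change above falls back to the CLI-provided port whenever the variable is unset or does not parse. A sketch of the same rule factored into a helper (the router inlines it; `resolve_port` is a hypothetical name):

```rust
// Sketch of the port-resolution rule: AIP_HTTP_PORT wins when the
// `google` feature is enabled and the value parses as a u16;
// otherwise the CLI port is kept.
fn resolve_port(cli_port: u16) -> u16 {
    if cfg!(feature = "google") {
        std::env::var("AIP_HTTP_PORT")
            .ok()
            .and_then(|v| v.parse::<u16>().ok())
            .unwrap_or(cli_port)
    } else {
        cli_port
    }
}

fn main() {
    // With the feature off, or AIP_HTTP_PORT unset/invalid, this prints 3000.
    println!("listening on port {}", resolve_port(3000));
}
```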

router/src/server.rs

Lines changed: 139 additions & 5 deletions
```diff
@@ -5,9 +5,9 @@ use crate::validation::ValidationError;
 use crate::{
     BestOfSequence, ChatCompletion, ChatCompletionChoice, ChatCompletionChunk, ChatCompletionDelta,
     ChatCompletionLogprobs, ChatRequest, CompatGenerateRequest, Details, ErrorResponse,
-    FinishReason, GenerateParameters, GenerateRequest, GenerateResponse, HubModelInfo,
+    FinishReason, GenerateParameters, GenerateRequest, GenerateResponse, GrammarType, HubModelInfo,
     HubTokenizerConfig, Infer, Info, Message, PrefillToken, SimpleToken, StreamDetails,
-    StreamResponse, Token, TokenizeResponse, Validation,
+    StreamResponse, Token, TokenizeResponse, Validation, VertexRequest, VertexResponse,
 };
 use axum::extract::Extension;
 use axum::http::{HeaderMap, Method, StatusCode};
@@ -16,8 +16,10 @@ use axum::response::{IntoResponse, Response};
 use axum::routing::{get, post};
 use axum::{http, Json, Router};
 use axum_tracing_opentelemetry::middleware::OtelAxumLayer;
+use futures::stream::FuturesUnordered;
 use futures::stream::StreamExt;
 use futures::Stream;
+use futures::TryStreamExt;
 use metrics_exporter_prometheus::{Matcher, PrometheusBuilder, PrometheusHandle};
 use std::convert::Infallible;
 use std::net::SocketAddr;
```
```diff
@@ -693,6 +695,97 @@ async fn chat_completions(
     }
 }
 
+/// Generate tokens from Vertex request
+#[utoipa::path(
+    post,
+    tag = "Text Generation Inference",
+    path = "/vertex",
+    request_body = VertexRequest,
+    responses(
+        (status = 200, description = "Generated Text", body = VertexResponse),
+        (status = 424, description = "Generation Error", body = ErrorResponse,
+            example = json ! ({"error": "Request failed during generation"})),
+        (status = 429, description = "Model is overloaded", body = ErrorResponse,
+            example = json ! ({"error": "Model is overloaded"})),
+        (status = 422, description = "Input validation error", body = ErrorResponse,
+            example = json ! ({"error": "Input validation error"})),
+        (status = 500, description = "Incomplete generation", body = ErrorResponse,
+            example = json ! ({"error": "Incomplete generation"})),
+    )
+)]
+#[instrument(
+    skip_all,
+    fields(
+        total_time,
+        validation_time,
+        queue_time,
+        inference_time,
+        time_per_token,
+        seed,
+    )
+)]
+async fn vertex_compatibility(
+    Extension(infer): Extension<Infer>,
+    Extension(compute_type): Extension<ComputeType>,
+    Json(req): Json<VertexRequest>,
+) -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
+    metrics::increment_counter!("tgi_request_count");
+
+    // check that there's at least one instance
+    if req.instances.is_empty() {
+        return Err((
+            StatusCode::UNPROCESSABLE_ENTITY,
+            Json(ErrorResponse {
+                error: "Input validation error".to_string(),
+                error_type: "Input validation error".to_string(),
+            }),
+        ));
+    }
+
+    // Process all instances
+    let predictions = req
+        .instances
+        .iter()
+        .map(|instance| {
+            let generate_request = GenerateRequest {
+                inputs: instance.inputs.clone(),
+                parameters: GenerateParameters {
+                    do_sample: true,
+                    max_new_tokens: instance.parameters.as_ref().and_then(|p| p.max_new_tokens),
+                    seed: instance.parameters.as_ref().and_then(|p| p.seed),
+                    details: true,
+                    decoder_input_details: true,
+                    ..Default::default()
+                },
+            };
+
+            async {
+                generate(
+                    Extension(infer.clone()),
+                    Extension(compute_type.clone()),
+                    Json(generate_request),
+                )
+                .await
+                .map(|(_, Json(generation))| generation.generated_text)
+                .map_err(|_| {
+                    (
+                        StatusCode::INTERNAL_SERVER_ERROR,
+                        Json(ErrorResponse {
+                            error: "Incomplete generation".into(),
+                            error_type: "Incomplete generation".into(),
+                        }),
+                    )
+                })
+            }
+        })
+        .collect::<FuturesUnordered<_>>()
+        .try_collect::<Vec<_>>()
+        .await?;
+
+    let response = VertexResponse { predictions };
+    Ok((HeaderMap::new(), Json(response)).into_response())
+}
+
 /// Tokenize inputs
 #[utoipa::path(
     post,
```
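`vertex_compatibility` maps each instance to its own future and drives them all concurrently with `FuturesUnordered`; `try_collect` short-circuits on the first error. One consequence worth noting: `FuturesUnordered` yields results in completion order, not submission order. A self-contained sketch of the pattern, assuming the `futures` and `tokio` crates and a hypothetical `generate_one` standing in for the real `generate` handler:

```rust
use futures::stream::{FuturesUnordered, TryStreamExt};

// Hypothetical stand-in for the `generate` handler call.
async fn generate_one(input: String) -> Result<String, String> {
    Ok(format!("echo: {input}"))
}

#[tokio::main]
async fn main() -> Result<(), String> {
    let instances = vec!["a".to_string(), "b".to_string(), "c".to_string()];

    let predictions: Vec<String> = instances
        .into_iter()
        .map(generate_one) // build the futures; nothing runs yet
        .collect::<FuturesUnordered<_>>() // poll them all concurrently
        .try_collect() // stop at the first Err, else gather every Ok
        .await?;

    assert_eq!(predictions.len(), 3);
    Ok(())
}
```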
```diff
@@ -818,6 +911,7 @@ pub async fn run(
             StreamResponse,
             StreamDetails,
             ErrorResponse,
+            GrammarType,
         )
     ),
     tags(
```
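Registering `GrammarType` as a schema component pairs with the new `ToSchema` derive in lib.rs; its serde attributes make it adjacently tagged, so it deserializes from JSON of the form `{"type": ..., "value": ...}`. A sketch of that representation with illustrative variant names (the router's actual variants are not shown in this diff):

```rust
use serde::Deserialize;

// Adjacently tagged enum: the tag and the payload live in two sibling
// fields, matching #[serde(tag = "type", content = "value")] above.
#[derive(Debug, Deserialize)]
#[serde(tag = "type", content = "value", rename_all = "snake_case")]
enum Grammar {
    Regex(String),
    Json(serde_json::Value),
}

fn main() -> serde_json::Result<()> {
    let g: Grammar = serde_json::from_str(r#"{"type": "regex", "value": "[a-z]+"}"#)?;
    println!("{g:?}");
    Ok(())
}
```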
```diff
@@ -942,8 +1036,30 @@ pub async fn run(
         docker_label: option_env!("DOCKER_LABEL"),
     };
 
+    // Define VertextApiDoc conditionally only if the "google" feature is enabled
+    #[cfg(feature = "google")]
+    #[derive(OpenApi)]
+    #[openapi(
+        paths(vertex_compatibility),
+        components(schemas(VertexInstance, VertexRequest, VertexResponse))
+    )]
+    struct VertextApiDoc;
+
+    let doc = {
+        // avoid `mut` if possible
+        #[cfg(feature = "google")]
+        {
+            // limiting mutability to the smallest scope necessary
+            let mut doc = ApiDoc::openapi();
+            doc.merge(VertextApiDoc::openapi());
+            doc
+        }
+        #[cfg(not(feature = "google"))]
+        ApiDoc::openapi()
+    };
+
     // Configure Swagger UI
-    let swagger_ui = SwaggerUi::new("/docs").url("/api-doc/openapi.json", ApiDoc::openapi());
+    let swagger_ui = SwaggerUi::new("/docs").url("/api-doc/openapi.json", doc);
 
     // Define base and health routes
     let base_routes = Router::new()
@@ -953,6 +1069,7 @@ pub async fn run(
         .route("/generate", post(generate))
         .route("/generate_stream", post(generate_stream))
         .route("/v1/chat/completions", post(chat_completions))
+        .route("/vertex", post(vertex_compatibility))
         .route("/tokenize", post(tokenize))
         .route("/health", get(health))
         .route("/ping", get(health))
@@ -969,10 +1086,27 @@ pub async fn run(
         ComputeType(std::env::var("COMPUTE_TYPE").unwrap_or("gpu+optimized".to_string()));
 
     // Combine routes and layers
-    let app = Router::new()
+    let mut app = Router::new()
         .merge(swagger_ui)
         .merge(base_routes)
-        .merge(aws_sagemaker_route)
+        .merge(aws_sagemaker_route);
+
+    #[cfg(feature = "google")]
+    {
+        tracing::info!("Built with `google` feature");
+        tracing::info!(
+            "Environment variables `AIP_PREDICT_ROUTE` and `AIP_HEALTH_ROUTE` will be respected."
+        );
+        if let Ok(env_predict_route) = std::env::var("AIP_PREDICT_ROUTE") {
+            app = app.route(&env_predict_route, post(vertex_compatibility));
+        }
+        if let Ok(env_health_route) = std::env::var("AIP_HEALTH_ROUTE") {
+            app = app.route(&env_health_route, get(health));
+        }
+    }
+
+    // add layers after routes
+    app = app
         .layer(Extension(info))
         .layer(Extension(health_ext.clone()))
         .layer(Extension(compat_return_full_text))
```