@@ -5,9 +5,9 @@ use crate::validation::ValidationError;
55use crate :: {
66 BestOfSequence , ChatCompletion , ChatCompletionChoice , ChatCompletionChunk , ChatCompletionDelta ,
77 ChatCompletionLogprobs , ChatRequest , CompatGenerateRequest , Details , ErrorResponse ,
8- FinishReason , GenerateParameters , GenerateRequest , GenerateResponse , HubModelInfo ,
8+ FinishReason , GenerateParameters , GenerateRequest , GenerateResponse , GrammarType , HubModelInfo ,
99 HubTokenizerConfig , Infer , Info , Message , PrefillToken , SimpleToken , StreamDetails ,
10- StreamResponse , Token , TokenizeResponse , Validation ,
10+ StreamResponse , Token , TokenizeResponse , Validation , VertexRequest , VertexResponse ,
1111} ;
1212use axum:: extract:: Extension ;
1313use axum:: http:: { HeaderMap , Method , StatusCode } ;
@@ -16,8 +16,10 @@ use axum::response::{IntoResponse, Response};
1616use axum:: routing:: { get, post} ;
1717use axum:: { http, Json , Router } ;
1818use axum_tracing_opentelemetry:: middleware:: OtelAxumLayer ;
19+ use futures:: stream:: FuturesUnordered ;
1920use futures:: stream:: StreamExt ;
2021use futures:: Stream ;
22+ use futures:: TryStreamExt ;
2123use metrics_exporter_prometheus:: { Matcher , PrometheusBuilder , PrometheusHandle } ;
2224use std:: convert:: Infallible ;
2325use std:: net:: SocketAddr ;
@@ -693,6 +695,97 @@ async fn chat_completions(
693695 }
694696}
695697
698+ /// Generate tokens from Vertex request
699+ #[ utoipa:: path(
700+ post,
701+ tag = "Text Generation Inference" ,
702+ path = "/vertex" ,
703+ request_body = VertexRequest ,
704+ responses(
705+ ( status = 200 , description = "Generated Text" , body = VertexResponse ) ,
706+ ( status = 424 , description = "Generation Error" , body = ErrorResponse ,
707+ example = json ! ( { "error" : "Request failed during generation" } ) ) ,
708+ ( status = 429 , description = "Model is overloaded" , body = ErrorResponse ,
709+ example = json ! ( { "error" : "Model is overloaded" } ) ) ,
710+ ( status = 422 , description = "Input validation error" , body = ErrorResponse ,
711+ example = json ! ( { "error" : "Input validation error" } ) ) ,
712+ ( status = 500 , description = "Incomplete generation" , body = ErrorResponse ,
713+ example = json ! ( { "error" : "Incomplete generation" } ) ) ,
714+ )
715+ ) ]
716+ #[ instrument(
717+ skip_all,
718+ fields(
719+ total_time,
720+ validation_time,
721+ queue_time,
722+ inference_time,
723+ time_per_token,
724+ seed,
725+ )
726+ ) ]
727+ async fn vertex_compatibility (
728+ Extension ( infer) : Extension < Infer > ,
729+ Extension ( compute_type) : Extension < ComputeType > ,
730+ Json ( req) : Json < VertexRequest > ,
731+ ) -> Result < Response , ( StatusCode , Json < ErrorResponse > ) > {
732+ metrics:: increment_counter!( "tgi_request_count" ) ;
733+
734+ // check that theres at least one instance
735+ if req. instances . is_empty ( ) {
736+ return Err ( (
737+ StatusCode :: UNPROCESSABLE_ENTITY ,
738+ Json ( ErrorResponse {
739+ error : "Input validation error" . to_string ( ) ,
740+ error_type : "Input validation error" . to_string ( ) ,
741+ } ) ,
742+ ) ) ;
743+ }
744+
745+ // Process all instances
746+ let predictions = req
747+ . instances
748+ . iter ( )
749+ . map ( |instance| {
750+ let generate_request = GenerateRequest {
751+ inputs : instance. inputs . clone ( ) ,
752+ parameters : GenerateParameters {
753+ do_sample : true ,
754+ max_new_tokens : instance. parameters . as_ref ( ) . and_then ( |p| p. max_new_tokens ) ,
755+ seed : instance. parameters . as_ref ( ) . and_then ( |p| p. seed ) ,
756+ details : true ,
757+ decoder_input_details : true ,
758+ ..Default :: default ( )
759+ } ,
760+ } ;
761+
762+ async {
763+ generate (
764+ Extension ( infer. clone ( ) ) ,
765+ Extension ( compute_type. clone ( ) ) ,
766+ Json ( generate_request) ,
767+ )
768+ . await
769+ . map ( |( _, Json ( generation) ) | generation. generated_text )
770+ . map_err ( |_| {
771+ (
772+ StatusCode :: INTERNAL_SERVER_ERROR ,
773+ Json ( ErrorResponse {
774+ error : "Incomplete generation" . into ( ) ,
775+ error_type : "Incomplete generation" . into ( ) ,
776+ } ) ,
777+ )
778+ } )
779+ }
780+ } )
781+ . collect :: < FuturesUnordered < _ > > ( )
782+ . try_collect :: < Vec < _ > > ( )
783+ . await ?;
784+
785+ let response = VertexResponse { predictions } ;
786+ Ok ( ( HeaderMap :: new ( ) , Json ( response) ) . into_response ( ) )
787+ }
788+
696789/// Tokenize inputs
697790#[ utoipa:: path(
698791 post,
@@ -818,6 +911,7 @@ pub async fn run(
818911 StreamResponse ,
819912 StreamDetails ,
820913 ErrorResponse ,
914+ GrammarType ,
821915 )
822916 ) ,
823917 tags(
@@ -942,8 +1036,30 @@ pub async fn run(
9421036 docker_label : option_env ! ( "DOCKER_LABEL" ) ,
9431037 } ;
9441038
1039+ // Define VertextApiDoc conditionally only if the "google" feature is enabled
1040+ #[ cfg( feature = "google" ) ]
1041+ #[ derive( OpenApi ) ]
1042+ #[ openapi(
1043+ paths( vertex_compatibility) ,
1044+ components( schemas( VertexInstance , VertexRequest , VertexResponse ) )
1045+ ) ]
1046+ struct VertextApiDoc ;
1047+
1048+ let doc = {
1049+ // avoid `mut` if possible
1050+ #[ cfg( feature = "google" ) ]
1051+ {
1052+ // limiting mutability to the smallest scope necessary
1053+ let mut doc = doc;
1054+ doc. merge ( VertextApiDoc :: openapi ( ) ) ;
1055+ doc
1056+ }
1057+ #[ cfg( not( feature = "google" ) ) ]
1058+ ApiDoc :: openapi ( )
1059+ } ;
1060+
9451061 // Configure Swagger UI
946- let swagger_ui = SwaggerUi :: new ( "/docs" ) . url ( "/api-doc/openapi.json" , ApiDoc :: openapi ( ) ) ;
1062+ let swagger_ui = SwaggerUi :: new ( "/docs" ) . url ( "/api-doc/openapi.json" , doc ) ;
9471063
9481064 // Define base and health routes
9491065 let base_routes = Router :: new ( )
@@ -953,6 +1069,7 @@ pub async fn run(
9531069 . route ( "/generate" , post ( generate) )
9541070 . route ( "/generate_stream" , post ( generate_stream) )
9551071 . route ( "/v1/chat/completions" , post ( chat_completions) )
1072+ . route ( "/vertex" , post ( vertex_compatibility) )
9561073 . route ( "/tokenize" , post ( tokenize) )
9571074 . route ( "/health" , get ( health) )
9581075 . route ( "/ping" , get ( health) )
@@ -969,10 +1086,27 @@ pub async fn run(
9691086 ComputeType ( std:: env:: var ( "COMPUTE_TYPE" ) . unwrap_or ( "gpu+optimized" . to_string ( ) ) ) ;
9701087
9711088 // Combine routes and layers
972- let app = Router :: new ( )
1089+ let mut app = Router :: new ( )
9731090 . merge ( swagger_ui)
9741091 . merge ( base_routes)
975- . merge ( aws_sagemaker_route)
1092+ . merge ( aws_sagemaker_route) ;
1093+
1094+ #[ cfg( feature = "google" ) ]
1095+ {
1096+ tracing:: info!( "Built with `google` feature" ) ;
1097+ tracing:: info!(
1098+ "Environment variables `AIP_PREDICT_ROUTE` and `AIP_HEALTH_ROUTE` will be respected."
1099+ ) ;
1100+ if let Ok ( env_predict_route) = std:: env:: var ( "AIP_PREDICT_ROUTE" ) {
1101+ app = app. route ( & env_predict_route, post ( vertex_compatibility) ) ;
1102+ }
1103+ if let Ok ( env_health_route) = std:: env:: var ( "AIP_HEALTH_ROUTE" ) {
1104+ app = app. route ( & env_health_route, get ( health) ) ;
1105+ }
1106+ }
1107+
1108+ // add layers after routes
1109+ app = app
9761110 . layer ( Extension ( info) )
9771111 . layer ( Extension ( health_ext. clone ( ) ) )
9781112 . layer ( Extension ( compat_return_full_text) )
0 commit comments