diff --git a/Makefile b/Makefile index 155b7ea1..477f0ffe 100644 --- a/Makefile +++ b/Makefile @@ -5,6 +5,7 @@ INTEG_API_INVOKE := RestApiUrl HttpApiUrl INTEG_EXTENSIONS := extension-fn extension-trait logs-trait # Using musl to run extensions on both AL1 and AL2 INTEG_ARCH := x86_64-unknown-linux-musl +RIE_MAX_CONCURRENCY ?= 4 define uppercase $(shell sed -r 's/(^|-)(\w)/\U\2/g' <<< $(1)) @@ -111,4 +112,8 @@ fmt: cargo +nightly fmt --all test-rie: - ./scripts/test-rie.sh $(EXAMPLE) \ No newline at end of file + ./scripts/test-rie.sh $(EXAMPLE) + +# Run RIE in Lambda Managed Instance (LMI) mode with concurrent polling. +test-rie-lmi: + RIE_MAX_CONCURRENCY=$(RIE_MAX_CONCURRENCY) ./scripts/test-rie.sh $(EXAMPLE) diff --git a/examples/basic-lambda/src/main.rs b/examples/basic-lambda/src/main.rs index d3f2a3cd..396c3afd 100644 --- a/examples/basic-lambda/src/main.rs +++ b/examples/basic-lambda/src/main.rs @@ -28,7 +28,10 @@ async fn main() -> Result<(), Error> { tracing::init_default_subscriber(); let func = service_fn(my_handler); - lambda_runtime::run(func).await?; + if let Err(err) = lambda_runtime::run(func).await { + eprintln!("run error: {:?}", err); + return Err(err); + } Ok(()) } diff --git a/lambda-http/src/lib.rs b/lambda-http/src/lib.rs index 60e279c7..d82ff0d0 100644 --- a/lambda-http/src/lib.rs +++ b/lambda-http/src/lib.rs @@ -102,7 +102,7 @@ use std::{ }; mod streaming; -pub use streaming::{run_with_streaming_response, StreamAdapter}; +pub use streaming::{run_with_streaming_response, run_with_streaming_response_concurrent, StreamAdapter}; /// Type alias for `http::Request`s with a fixed [`Body`](enum.Body.html) type pub type Request = http::Request; @@ -151,6 +151,18 @@ pub struct Adapter<'a, R, S> { _phantom_data: PhantomData<&'a R>, } +impl<'a, R, S> Clone for Adapter<'a, R, S> +where + S: Clone, +{ + fn clone(&self) -> Self { + Self { + service: self.service.clone(), + _phantom_data: PhantomData, + } + } +} + impl<'a, R, S, E> From for Adapter<'a, R, S> where S: Service, @@ -203,6 +215,24 @@ where lambda_runtime::run(Adapter::from(handler)).await } +/// Starts the Lambda Rust runtime in a mode that is compatible with +/// Lambda Managed Instances (concurrent invocations). +/// +/// When `AWS_LAMBDA_MAX_CONCURRENCY` is set to a value greater than 1, this +/// will use a concurrent `/next` polling loop with a bounded number of +/// in-flight handler tasks. When the environment variable is unset or `<= 1`, +/// it falls back to the same sequential behavior as [`run`], so the same +/// handler can run on both classic Lambda and Lambda Managed Instances. +pub async fn run_concurrent(handler: S) -> Result<(), Error> +where + S: Service + Clone + Send + 'static, + S::Future: Send + 'static, + R: IntoResponse + Send + Sync + 'static, + E: std::fmt::Debug + Into + Send + 'static, +{ + lambda_runtime::run_concurrent(Adapter::from(handler)).await +} + #[cfg(test)] mod test_adapter { use std::task::{Context, Poll}; diff --git a/lambda-http/src/streaming.rs b/lambda-http/src/streaming.rs index ed61c773..a729206c 100644 --- a/lambda-http/src/streaming.rs +++ b/lambda-http/src/streaming.rs @@ -10,7 +10,7 @@ pub use http::{self, Response}; use http_body::Body; use lambda_runtime::{ tower::{ - util::{MapRequest, MapResponse}, + util::{BoxCloneService, MapRequest, MapResponse}, ServiceBuilder, ServiceExt, }, Diagnostic, @@ -93,14 +93,33 @@ where B::Error: Into + Send + Debug, { ServiceBuilder::new() - .map_request(|req: LambdaEvent| { - let event: Request = req.payload.into(); - event.with_lambda_context(req.context) - }) + .map_request(event_to_request as fn(LambdaEvent) -> Request) .service(handler) .map_response(into_stream_response) } +/// Builds a streaming-aware Tower service from a `Service` that can be +/// cloned and sent across tasks. This is used by the concurrent HTTP entrypoint. +#[allow(clippy::type_complexity)] +fn into_stream_service_boxed( + handler: S, +) -> BoxCloneService, StreamResponse>, E> +where + S: Service, Error = E> + Clone + Send + 'static, + S::Future: Send + 'static, + E: Debug + Into + Send + 'static, + B: Body + Unpin + Send + 'static, + B::Data: Into + Send, + B::Error: Into + Send + Debug, +{ + let svc = ServiceBuilder::new() + .map_request(event_to_request as fn(LambdaEvent) -> Request) + .service(handler) + .map_response(into_stream_response); + + BoxCloneService::new(svc) +} + /// Converts an `http::Response` into a streaming Lambda response. fn into_stream_response(res: Response) -> StreamResponse> where @@ -128,6 +147,11 @@ where } } +fn event_to_request(req: LambdaEvent) -> Request { + let event: Request = req.payload.into(); + event.with_lambda_context(req.context) +} + /// Runs the Lambda runtime with a handler that returns **streaming** HTTP /// responses. /// @@ -147,6 +171,24 @@ where lambda_runtime::run(into_stream_service(handler)).await } +/// Runs the Lambda runtime with a handler that returns **streaming** HTTP +/// responses, in a mode that is compatible with Lambda Managed Instances. +/// +/// This uses a cloneable, boxed service internally so it can be driven by the +/// concurrent runtime. When `AWS_LAMBDA_MAX_CONCURRENCY` is not set or `<= 1`, +/// it falls back to the same sequential behavior as [`run_with_streaming_response`]. +pub async fn run_with_streaming_response_concurrent(handler: S) -> Result<(), Error> +where + S: Service, Error = E> + Clone + Send + 'static, + S::Future: Send + 'static, + E: Debug + Into + Send + 'static, + B: Body + Unpin + Send + 'static, + B::Data: Into + Send, + B::Error: Into + Send + Debug, +{ + lambda_runtime::run_concurrent(into_stream_service_boxed(handler)).await +} + pin_project_lite::pin_project! { #[non_exhaustive] pub struct BodyStream { diff --git a/lambda-runtime-api-client/src/lib.rs b/lambda-runtime-api-client/src/lib.rs index 3df616ab..86cc715f 100644 --- a/lambda-runtime-api-client/src/lib.rs +++ b/lambda-runtime-api-client/src/lib.rs @@ -41,6 +41,7 @@ impl Client { ClientBuilder { connector: HttpConnector::new(), uri: None, + pool_size: None, } } } @@ -59,11 +60,16 @@ impl Client { self.client.request(req).map_err(Into::into).boxed() } - /// Create a new client with a given base URI and HTTP connector. - fn with(base: Uri, connector: HttpConnector) -> Self { - let client = hyper_util::client::legacy::Client::builder(hyper_util::rt::TokioExecutor::new()) - .http1_max_buf_size(1024 * 1024) - .build(connector); + /// Create a new client with a given base URI, HTTP connector, and optional pool size hint. + fn with(base: Uri, connector: HttpConnector, pool_size: Option) -> Self { + let mut builder = hyper_util::client::legacy::Client::builder(hyper_util::rt::TokioExecutor::new()); + builder.http1_max_buf_size(1024 * 1024); + + if let Some(size) = pool_size { + builder.pool_max_idle_per_host(size); + } + + let client = builder.build(connector); Self { base, client } } @@ -94,6 +100,7 @@ impl Client { pub struct ClientBuilder { connector: HttpConnector, uri: Option, + pool_size: Option, } impl ClientBuilder { @@ -102,6 +109,7 @@ impl ClientBuilder { ClientBuilder { connector, uri: self.uri, + pool_size: self.pool_size, } } @@ -111,6 +119,14 @@ impl ClientBuilder { Self { uri: Some(uri), ..self } } + /// Provide a pool size hint for the underlying Hyper client. + pub fn with_pool_size(self, pool_size: usize) -> Self { + Self { + pool_size: Some(pool_size), + ..self + } + } + /// Create the new client to interact with the Runtime API. pub fn build(self) -> Result { let uri = match self.uri { @@ -120,7 +136,7 @@ impl ClientBuilder { uri.try_into().expect("Unable to convert to URL") } }; - Ok(Client::with(uri, self.connector)) + Ok(Client::with(uri, self.connector, self.pool_size)) } } @@ -182,4 +198,17 @@ mod tests { &req.uri().to_string() ); } + + #[test] + fn builder_accepts_pool_size() { + let base = "http://localhost:9001"; + let expected: Uri = base.parse().unwrap(); + let client = Client::builder() + .with_pool_size(4) + .with_endpoint(base.parse().unwrap()) + .build() + .unwrap(); + + assert_eq!(client.base, expected); + } } diff --git a/lambda-runtime/src/layers/api_client.rs b/lambda-runtime/src/layers/api_client.rs index d44a84f2..7113ee0a 100644 --- a/lambda-runtime/src/layers/api_client.rs +++ b/lambda-runtime/src/layers/api_client.rs @@ -44,6 +44,18 @@ where } } +impl Clone for RuntimeApiClientService +where + S: Clone, +{ + fn clone(&self) -> Self { + Self { + inner: self.inner.clone(), + client: self.client.clone(), + } + } +} + #[pin_project(project = RuntimeApiClientFutureProj)] pub enum RuntimeApiClientFuture { First(#[pin] F, Arc), diff --git a/lambda-runtime/src/layers/api_response.rs b/lambda-runtime/src/layers/api_response.rs index 453f8b4c..5bb3c96f 100644 --- a/lambda-runtime/src/layers/api_response.rs +++ b/lambda-runtime/src/layers/api_response.rs @@ -51,6 +51,27 @@ impl Clone + for RuntimeApiResponseService< + S, + EventPayload, + Response, + BufferedResponse, + StreamingResponse, + StreamItem, + StreamError, + > +where + S: Clone, +{ + fn clone(&self) -> Self { + Self { + inner: self.inner.clone(), + _phantom: PhantomData, + } + } +} + impl Service for RuntimeApiResponseService< S, diff --git a/lambda-runtime/src/layers/trace.rs b/lambda-runtime/src/layers/trace.rs index e93927b1..4a3ad3d9 100644 --- a/lambda-runtime/src/layers/trace.rs +++ b/lambda-runtime/src/layers/trace.rs @@ -25,6 +25,7 @@ impl Layer for TracingLayer { } /// Tower service returned by [TracingLayer]. +#[derive(Clone)] pub struct TracingService { inner: S, } diff --git a/lambda-runtime/src/lib.rs b/lambda-runtime/src/lib.rs index cbcd0a9e..610b608f 100644 --- a/lambda-runtime/src/lib.rs +++ b/lambda-runtime/src/lib.rs @@ -39,8 +39,12 @@ pub use lambda_runtime_api_client::tracing; /// Types available to a Lambda function. mod types; +#[cfg(all(unix, feature = "graceful-shutdown"))] +use crate::runtime::SHUTDOWN_NOTIFY; use requests::EventErrorRequest; pub use runtime::{LambdaInvocation, Runtime}; +#[cfg(all(unix, feature = "graceful-shutdown"))] +use std::time::Duration; pub use types::{Context, FunctionResponse, IntoFunctionResponse, LambdaEvent, MetadataPrelude, StreamResponse}; /// Error type that lambdas may result in @@ -59,6 +63,9 @@ pub struct Config { pub log_stream: String, /// The name of the Amazon CloudWatch Logs group for the function. pub log_group: String, + /// Maximum concurrent invocations for Lambda managed-concurrency environments. + /// Populated from `AWS_LAMBDA_MAX_CONCURRENCY` when present. + pub max_concurrency: Option, } type RefConfig = Arc; @@ -75,8 +82,17 @@ impl Config { version: env::var("AWS_LAMBDA_FUNCTION_VERSION").expect("Missing AWS_LAMBDA_FUNCTION_VERSION env var"), log_stream: env::var("AWS_LAMBDA_LOG_STREAM_NAME").unwrap_or_default(), log_group: env::var("AWS_LAMBDA_LOG_GROUP_NAME").unwrap_or_default(), + max_concurrency: env::var("AWS_LAMBDA_MAX_CONCURRENCY") + .ok() + .and_then(|v| v.parse::().ok()) + .filter(|&c| c > 0), } } + + /// Returns true if concurrent runtime mode should be enabled. + pub fn is_concurrent(&self) -> bool { + self.max_concurrency.map(|c| c > 1).unwrap_or(false) + } } /// Return a new [`ServiceFn`] with a closure that takes an event and context as separate arguments. @@ -126,6 +142,30 @@ where runtime.run().await } +/// Starts the Lambda Rust runtime in a mode that is compatible with +/// Lambda Managed Instances (concurrent invocations). +/// +/// When `AWS_LAMBDA_MAX_CONCURRENCY` is set to a value greater than 1, this +/// will use a concurrent `/next` polling loop with a bounded number of +/// in-flight handler tasks. When the environment variable is unset or `<= 1`, +/// it falls back to the same sequential behavior as [`run`], so the same +/// handler can run on both classic Lambda and Lambda Managed Instances. +pub async fn run_concurrent(handler: F) -> Result<(), Error> +where + F: Service, Response = R> + Clone + Send + 'static, + F::Future: Future> + Send + 'static, + F::Error: Into + fmt::Debug, + A: for<'de> Deserialize<'de> + Send + 'static, + R: IntoFunctionResponse + Send + 'static, + B: Serialize + Send + 'static, + S: Stream> + Unpin + Send + 'static, + D: Into + Send + 'static, + E: Into + Send + Debug + 'static, +{ + let runtime = Runtime::new(handler).layer(layers::TracingLayer::new()); + runtime.run_concurrent().await +} + /// Spawns a task that will be execute a provided async closure when the process /// receives unix graceful shutdown signals. If the closure takes longer than 500ms /// to execute, an unhandled `SIGKILL` signal might be received. @@ -211,14 +251,26 @@ where eprintln!("[runtime] Graceful shutdown in progress ..."); shutdown_hook().await; eprintln!("[runtime] Graceful shutdown completed"); - std::process::exit(0); + if let Some(tx) = SHUTDOWN_NOTIFY.get() { + let _ = tx.send(true); + tokio::time::sleep(Duration::from_millis(500)).await; + std::process::exit(0); + } else { + std::process::exit(0); + } }, _sigterm = sigterm.recv()=> { eprintln!("[runtime] SIGTERM received"); eprintln!("[runtime] Graceful shutdown in progress ..."); shutdown_hook().await; eprintln!("[runtime] Graceful shutdown completed"); - std::process::exit(0); + if let Some(tx) = SHUTDOWN_NOTIFY.get() { + let _ = tx.send(true); + tokio::time::sleep(Duration::from_millis(500)).await; + std::process::exit(0); + } else { + std::process::exit(0); + } }, } }; diff --git a/lambda-runtime/src/runtime.rs b/lambda-runtime/src/runtime.rs index 517ee64f..e6a8e13a 100644 --- a/lambda-runtime/src/runtime.rs +++ b/lambda-runtime/src/runtime.rs @@ -4,13 +4,20 @@ use crate::{ types::{invoke_request_id, IntoFunctionResponse, LambdaEvent}, Config, Context, Diagnostic, }; +use futures::stream::FuturesUnordered; use http_body_util::BodyExt; use lambda_runtime_api_client::{BoxError, Client as ApiClient}; use serde::{Deserialize, Serialize}; -use std::{env, fmt::Debug, future::Future, sync::Arc}; +use std::{ + env, + fmt::Debug, + future::Future, + sync::{Arc, OnceLock}, +}; +use tokio::sync::{watch, Semaphore}; use tokio_stream::{Stream, StreamExt}; use tower::{Layer, Service, ServiceExt}; -use tracing::trace; +use tracing::{error, trace, warn}; /* ----------------------------------------- INVOCATION ---------------------------------------- */ @@ -55,6 +62,11 @@ pub struct Runtime { client: Arc, } +/// Global shutdown notifier used by concurrent runtime to coordinate graceful termination. +pub(crate) static SHUTDOWN_NOTIFY: OnceLock> = OnceLock::new(); +/// One-time marker to log X-Ray behavior in concurrent mode. +static XRAY_LOGGED: OnceLock<()> = OnceLock::new(); + impl Runtime< RuntimeApiClientService< @@ -92,7 +104,13 @@ where pub fn new(handler: F) -> Self { trace!("Loading config from env"); let config = Arc::new(Config::from_env()); - let client = Arc::new(ApiClient::builder().build().expect("Unable to create a runtime client")); + let pool_size = config.max_concurrency.unwrap_or(1).max(1) as usize; + let client = Arc::new( + ApiClient::builder() + .with_pool_size(pool_size) + .build() + .expect("Unable to create a runtime client"), + ); Self { service: wrap_handler(handler, client.clone()), config, @@ -137,6 +155,152 @@ impl Runtime { } } +impl Runtime +where + S: Service + Clone + Send + 'static, + S::Future: Send, +{ + /// Start the runtime in concurrent mode when configured for Lambda managed-concurrency. + /// + /// If `AWS_LAMBDA_MAX_CONCURRENCY` is not set or is `<= 1`, this falls back to the + /// sequential `run_with_incoming` loop so that the same handler can run on both + /// classic Lambda and Lambda Managed Instances. + pub async fn run_concurrent(self) -> Result<(), BoxError> { + if self.config.is_concurrent() { + let max_concurrency = self.config.max_concurrency.unwrap_or(1); + Self::run_concurrent_inner(self.service, self.config, self.client, max_concurrency).await + } else { + let incoming = incoming(&self.client); + Self::run_with_incoming(self.service, self.config, incoming).await + } + } + + /// Concurrent processing using windowed long-polls (for Lambda managed-concurrency). + async fn run_concurrent_inner( + service: S, + config: Arc, + client: Arc, + max_concurrency: u32, + ) -> Result<(), BoxError> { + let limit = max_concurrency as usize; + let semaphore = Arc::new(Semaphore::new(limit)); + let mut polls = FuturesUnordered::new(); + let mut handlers = FuturesUnordered::new(); + let mut paused_polls = 0usize; + // Bound total spawned tasks (running + waiting on permits) + let max_spawned_tasks = limit * 2; + let (shutdown_tx, mut shutdown_rx) = watch::channel(false); + let _ = SHUTDOWN_NOTIFY.set(shutdown_tx); + let mut shutting_down = false; + + for _ in 0..limit { + polls.push(next_event_future(client.clone())); + } + + loop { + tokio::select! { + _ = shutdown_rx.changed(), if !shutting_down => { + if *shutdown_rx.borrow() { + shutting_down = true; + trace!("Shutdown requested; draining handlers and stopping new polls"); + polls.clear(); + } + } + + Some(result) = futures::StreamExt::next(&mut polls) => { + let event = match result { + Ok(event) => event, + Err(e) => { + warn!(error = %e, "Error polling /next, retrying"); + polls.push(next_event_future(client.clone())); + continue; + } + }; + + let (parts, incoming) = event.into_parts(); + + #[cfg(debug_assertions)] + if parts.status == http::StatusCode::NO_CONTENT { + if paused_polls == 0 && !shutting_down { + polls.push(next_event_future(client.clone())); + } else { + paused_polls = paused_polls.saturating_sub(1); + } + continue; + } + + let at_cap = handlers.len() >= max_spawned_tasks; + if !at_cap && !shutting_down { + polls.push(next_event_future(client.clone())); + } else { + paused_polls += 1; + } + + // Collect body before spawning to release the HTTP connection earlier. + let request_id = invoke_request_id(&parts.headers)?.to_owned(); + let body = incoming.collect().await?.to_bytes(); + let mut svc = service.clone(); + let cfg = config.clone(); + let sem = semaphore.clone(); + + handlers.push(tokio::spawn(async move { + // Permit acquired inside task (keeps event loop non-blocking) + let _permit = sem.acquire_owned().await?; + + let context = match Context::new(&request_id, cfg, &parts.headers) { + Ok(ctx) => ctx, + Err(err) => { + error!(request_id = %request_id, error = %err, "Context::new failed"); + return Err(err); + } + }; + + // Inform users that X-Ray is available via context, not env var, in concurrent mode. + XRAY_LOGGED.get_or_init(|| { + trace!("Concurrent mode: _X_AMZN_TRACE_ID is not set; use context.xray_trace_id"); + }); + + let invocation = LambdaInvocation { parts, body, context }; + + trace!(request_id = %request_id, "Processing invocation"); + let ready = match svc.ready().await { + Ok(r) => r, + Err(err) => { + error!(request_id = %request_id, error = %err, "Service not ready"); + return Err(err); + } + }; + if let Err(err) = ready.call(invocation).await { + error!(request_id = %request_id, error = %err, "Handler call failed"); + return Err(err); + } + trace!(request_id = %request_id, "Invocation completed"); + Ok::<(), BoxError>(()) + })); + } + + Some(result) = futures::StreamExt::next(&mut handlers) => { + result??; + + if paused_polls > 0 && handlers.len() < max_spawned_tasks && !shutting_down { + paused_polls -= 1; + polls.push(next_event_future(client.clone())); + } + + if shutting_down && handlers.is_empty() { + trace!("All handlers drained after shutdown"); + break; + } + } + + else => break, + } + } + + Ok(()) + } +} + impl Runtime where S: Service, @@ -233,6 +397,12 @@ fn incoming( } } +/// Creates a future that polls the `/next` endpoint. +async fn next_event_future(client: Arc) -> Result, BoxError> { + let req = NextEventRequest.into_req()?; + client.call(req).await +} + fn amzn_trace_env(ctx: &Context) { match &ctx.xray_trace_id { Some(trace_id) => env::set_var("_X_AMZN_TRACE_ID", trace_id), @@ -456,6 +626,7 @@ mod endpoint_tests { version: "1".to_string(), log_stream: "test_stream".to_string(), log_group: "test_log".to_string(), + max_concurrency: None, }); let client = Arc::new(client); @@ -485,4 +656,58 @@ mod endpoint_tests { }) .await } + + #[test] + fn config_parses_max_concurrency() { + // Preserve existing env values + let prev_fn = env::var("AWS_LAMBDA_FUNCTION_NAME").ok(); + let prev_mem = env::var("AWS_LAMBDA_FUNCTION_MEMORY_SIZE").ok(); + let prev_ver = env::var("AWS_LAMBDA_FUNCTION_VERSION").ok(); + let prev_log_stream = env::var("AWS_LAMBDA_LOG_STREAM_NAME").ok(); + let prev_log_group = env::var("AWS_LAMBDA_LOG_GROUP_NAME").ok(); + let prev_max = env::var("AWS_LAMBDA_MAX_CONCURRENCY").ok(); + + env::set_var("AWS_LAMBDA_FUNCTION_NAME", "test_fn"); + env::set_var("AWS_LAMBDA_FUNCTION_MEMORY_SIZE", "128"); + env::set_var("AWS_LAMBDA_FUNCTION_VERSION", "1"); + env::set_var("AWS_LAMBDA_LOG_STREAM_NAME", "test_stream"); + env::set_var("AWS_LAMBDA_LOG_GROUP_NAME", "test_log"); + env::set_var("AWS_LAMBDA_MAX_CONCURRENCY", "4"); + + let cfg = Config::from_env(); + assert_eq!(cfg.max_concurrency, Some(4)); + assert!(cfg.is_concurrent()); + + // Restore env + if let Some(v) = prev_fn { + env::set_var("AWS_LAMBDA_FUNCTION_NAME", v); + } else { + env::remove_var("AWS_LAMBDA_FUNCTION_NAME"); + } + if let Some(v) = prev_mem { + env::set_var("AWS_LAMBDA_FUNCTION_MEMORY_SIZE", v); + } else { + env::remove_var("AWS_LAMBDA_FUNCTION_MEMORY_SIZE"); + } + if let Some(v) = prev_ver { + env::set_var("AWS_LAMBDA_FUNCTION_VERSION", v); + } else { + env::remove_var("AWS_LAMBDA_FUNCTION_VERSION"); + } + if let Some(v) = prev_log_stream { + env::set_var("AWS_LAMBDA_LOG_STREAM_NAME", v); + } else { + env::remove_var("AWS_LAMBDA_LOG_STREAM_NAME"); + } + if let Some(v) = prev_log_group { + env::set_var("AWS_LAMBDA_LOG_GROUP_NAME", v); + } else { + env::remove_var("AWS_LAMBDA_LOG_GROUP_NAME"); + } + if let Some(v) = prev_max { + env::set_var("AWS_LAMBDA_MAX_CONCURRENCY", v); + } else { + env::remove_var("AWS_LAMBDA_MAX_CONCURRENCY"); + } + } } diff --git a/lambda-runtime/src/types.rs b/lambda-runtime/src/types.rs index 5e5f487a..03cbfad0 100644 --- a/lambda-runtime/src/types.rs +++ b/lambda-runtime/src/types.rs @@ -104,13 +104,23 @@ impl Context { /// and the incoming request data. pub fn new(request_id: &str, env_config: RefConfig, headers: &HeaderMap) -> Result { let client_context: Option = if let Some(value) = headers.get("lambda-runtime-client-context") { - serde_json::from_str(value.to_str()?)? + let raw = value.to_str()?; + if raw.is_empty() { + None + } else { + Some(serde_json::from_str(raw)?) + } } else { None }; let identity: Option = if let Some(value) = headers.get("lambda-runtime-cognito-identity") { - serde_json::from_str(value.to_str()?)? + let raw = value.to_str()?; + if raw.is_empty() { + None + } else { + Some(serde_json::from_str(raw)?) + } } else { None }; diff --git a/scripts/test-rie.sh b/scripts/test-rie.sh index 911cb390..8d3c9320 100755 --- a/scripts/test-rie.sh +++ b/scripts/test-rie.sh @@ -2,12 +2,19 @@ set -euo pipefail EXAMPLE=${1:-basic-lambda} +# Optional: set RIE_MAX_CONCURRENCY to enable LMI mode (emulates AWS_LAMBDA_MAX_CONCURRENCY) +RIE_MAX_CONCURRENCY=${RIE_MAX_CONCURRENCY:-} echo "Building Docker image with RIE for example: $EXAMPLE..." docker build -f Dockerfile.rie --build-arg EXAMPLE=$EXAMPLE -t rust-lambda-rie-test . echo "Starting RIE container on port 9000..." -docker run -p 9000:8080 rust-lambda-rie-test & +if [ -n "$RIE_MAX_CONCURRENCY" ]; then + echo "Enabling LMI mode with AWS_LAMBDA_MAX_CONCURRENCY=$RIE_MAX_CONCURRENCY" + docker run -p 9000:8080 -e AWS_LAMBDA_MAX_CONCURRENCY="$RIE_MAX_CONCURRENCY" rust-lambda-rie-test & +else + docker run -p 9000:8080 rust-lambda-rie-test & +fi CONTAINER_PID=$! echo "Container started. Test with:" @@ -19,4 +26,4 @@ fi echo "" echo "Press Ctrl+C to stop the container." -wait $CONTAINER_PID \ No newline at end of file +wait $CONTAINER_PID