Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 11 additions & 2 deletions bin/gateway/src/main.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
mod http_utils;
mod logger;
mod pipeline;
mod planner_service;
mod shared_state;

use crate::{
Expand All @@ -10,12 +11,15 @@ use crate::{
request_id::{RequestIdGenerator, REQUEST_ID_HEADER_NAME},
},
logger::{configure_logging, LoggingFormat},
planner_service::{
planner_service_handler, supergraph_schema_handler, supergraph_version_handler,
},
shared_state::GatewaySharedState,
};
use axum::{
body::Body,
http::Method,
routing::{any_service, get},
routing::{any_service, get, post},
Router,
};
use http::Request;
Expand Down Expand Up @@ -78,7 +82,8 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
)
});
let parsed_schema = parse_schema(&supergraph_sdl);
let gateway_shared_state = GatewaySharedState::new(parsed_schema);
let supergraph_version = env::var("SUPERGRAPH_VERSION").unwrap_or_default();
let gateway_shared_state = GatewaySharedState::new(parsed_schema, supergraph_version);

let pipeline = ServiceBuilder::new()
.layer(Extension(gateway_shared_state.clone()))
Expand Down Expand Up @@ -115,7 +120,11 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {

let app = Router::new()
.route("/graphql", any_service(pipeline))
.route("/supergraph/version", get(supergraph_version_handler))
.route("/supergraph/schema", get(supergraph_schema_handler))
.route("/build-query-plan", post(planner_service_handler))
.route("/health", get(health_check_handler))
.layer(Extension(gateway_shared_state.clone()))
.layer(
CorsLayer::new()
.allow_methods([Method::GET, Method::POST, Method::OPTIONS])
Expand Down
102 changes: 102 additions & 0 deletions bin/gateway/src/planner_service.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
use std::{collections::HashSet, sync::Arc};

use axum::{response::IntoResponse, Extension, Json};
use graphql_tools::validation::validate::validate;
use http::{header::CONTENT_TYPE, StatusCode};
use query_planner::{
ast::{document::NormalizedDocument, normalization::normalize_operation},
graph::{PlannerOverrideContext, PERCENTAGE_SCALE_FACTOR},
planner::plan_nodes::QueryPlan,
utils::parsing::safe_parse_operation,
};
use rand::Rng;
use serde::Deserialize;
use sonic_rs::json;

use crate::{pipeline::error::PipelineErrorVariant, shared_state::GatewaySharedState};

pub async fn supergraph_version_handler(
state: Extension<Arc<GatewaySharedState>>,
) -> impl IntoResponse {
json!({
"version": state.supergraph_version
})
.to_string()
}

pub async fn supergraph_schema_handler(
state: Extension<Arc<GatewaySharedState>>,
) -> impl IntoResponse {
state.sdl.clone()
}

#[derive(Deserialize)]
pub struct PlannerServiceJsonInput {
#[serde(rename = "operationName")]
pub operation_name: Option<String>,
pub query: String,
}

pub async fn planner_service_handler(
state: Extension<Arc<GatewaySharedState>>,
body: Json<PlannerServiceJsonInput>,
) -> impl IntoResponse {
match plan(&body.0, &state).await {
Ok((plan, normalized_document)) => (
StatusCode::OK,
[(CONTENT_TYPE, "application/json")],
json!({
"plan": plan,
"normalizedOperation": normalized_document.operation.to_string()
})
.to_string(),
),
Err(err) => (
err.default_status_code(false),
[(CONTENT_TYPE, "application/json")],
json!({
"error": err.graphql_error_message()
})
.to_string(),
),
}
}

async fn plan(
input: &PlannerServiceJsonInput,
state: &GatewaySharedState,
) -> Result<(QueryPlan, NormalizedDocument), PipelineErrorVariant> {
let parsed_operation =
safe_parse_operation(&input.query).map_err(PipelineErrorVariant::FailedToParseOperation)?;
let consumer_schema_ast = &state.planner.consumer_schema.document;
let validation_errors = validate(
consumer_schema_ast,
&parsed_operation,
&state.validation_plan,
);

if !validation_errors.is_empty() {
return Err(PipelineErrorVariant::ValidationErrors(Arc::new(
validation_errors,
)));
}

let normalized_operation = normalize_operation(
&state.planner.supergraph,
&parsed_operation,
input.operation_name.as_deref(),
)
.map_err(PipelineErrorVariant::NormalizationError)?;

let request_override_context = PlannerOverrideContext::new(
HashSet::new(),
rand::rng().random_range(0..=(100 * PERCENTAGE_SCALE_FACTOR)),
);

let plan = state
.planner
.plan_from_normalized_operation(&normalized_operation.operation, request_override_context)
.map_err(PipelineErrorVariant::PlannerError)?;

Ok((plan, normalized_operation))
}
9 changes: 8 additions & 1 deletion bin/gateway/src/shared_state.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,15 @@ pub struct GatewaySharedState {
pub validate_cache: Cache<u64, Arc<Vec<ValidationError>>>,
pub parse_cache: Cache<u64, Arc<graphql_parser::query::Document<'static, String>>>,
pub normalize_cache: Cache<u64, Arc<GraphQLNormalizationPayload>>,
pub supergraph_version: String,
pub sdl: String,
}

impl GatewaySharedState {
pub fn new(parsed_supergraph_sdl: Document<'static, String>) -> Arc<Self> {
pub fn new(
parsed_supergraph_sdl: Document<'static, String>,
supergraph_version: String,
) -> Arc<Self> {
let supergraph_state = SupergraphState::new(&parsed_supergraph_sdl);
let planner =
Planner::new_from_supergraph(&parsed_supergraph_sdl).expect("failed to create planner");
Expand All @@ -52,6 +57,8 @@ impl GatewaySharedState {
validate_cache: moka::future::Cache::new(1000),
parse_cache: moka::future::Cache::new(1000),
normalize_cache: moka::future::Cache::new(1000),
supergraph_version,
sdl: parsed_supergraph_sdl.to_string(),
})
}
}
Loading