-
Notifications
You must be signed in to change notification settings - Fork 618
Expand file tree
/
Copy pathgrpc_client.rs
More file actions
357 lines (307 loc) · 12.2 KB
/
grpc_client.rs
File metadata and controls
357 lines (307 loc) · 12.2 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
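//! gRPC client for the OpenShell server: fetches sandbox policy, provider
//! environment, and inference route bundles, and pushes policy updates,
//! policy analysis submissions, and policy load status back to the server.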
use std::collections::HashMap;
use std::time::Duration;
use miette::{IntoDiagnostic, Result, WrapErr};
use openshell_core::proto::{
DenialSummary, GetInferenceBundleRequest, GetInferenceBundleResponse, GetSandboxPolicyRequest,
GetSandboxProviderEnvironmentRequest, PolicyStatus, ReportPolicyStatusRequest,
SandboxPolicy as ProtoSandboxPolicy, SubmitPolicyAnalysisRequest, UpdateSandboxPolicyRequest,
inference_client::InferenceClient, open_shell_client::OpenShellClient,
};
use tonic::transport::{Certificate, Channel, ClientTlsConfig, Endpoint, Identity};
use tracing::debug;
/// Create a channel to the OpenShell server.
///
/// When the endpoint uses `https://`, mTLS is configured using these env vars:
/// - `OPENSHELL_TLS_CA` -- path to the CA certificate
/// - `OPENSHELL_TLS_CERT` -- path to the client certificate
/// - `OPENSHELL_TLS_KEY` -- path to the client private key
///
/// When the endpoint uses `http://`, a plaintext connection is used (for
/// deployments where TLS is disabled, e.g. behind a Cloudflare Tunnel).
async fn connect_channel(endpoint: &str) -> Result<Channel> {
let mut ep = Endpoint::from_shared(endpoint.to_string())
.into_diagnostic()
.wrap_err("invalid gRPC endpoint")?
.connect_timeout(Duration::from_secs(10))
.http2_keep_alive_interval(Duration::from_secs(10))
.keep_alive_while_idle(true)
.keep_alive_timeout(Duration::from_secs(20));
let tls_enabled = endpoint.starts_with("https://");
if tls_enabled {
let ca_path = std::env::var("OPENSHELL_TLS_CA")
.into_diagnostic()
.wrap_err("OPENSHELL_TLS_CA is required")?;
let cert_path = std::env::var("OPENSHELL_TLS_CERT")
.into_diagnostic()
.wrap_err("OPENSHELL_TLS_CERT is required")?;
let key_path = std::env::var("OPENSHELL_TLS_KEY")
.into_diagnostic()
.wrap_err("OPENSHELL_TLS_KEY is required")?;
let ca_pem = std::fs::read(&ca_path)
.into_diagnostic()
.wrap_err_with(|| format!("failed to read CA cert from {ca_path}"))?;
let cert_pem = std::fs::read(&cert_path)
.into_diagnostic()
.wrap_err_with(|| format!("failed to read client cert from {cert_path}"))?;
let key_pem = std::fs::read(&key_path)
.into_diagnostic()
.wrap_err_with(|| format!("failed to read client key from {key_path}"))?;
let tls_config = ClientTlsConfig::new()
.ca_certificate(Certificate::from_pem(ca_pem))
.identity(Identity::from_pem(cert_pem, key_pem));
ep = ep
.tls_config(tls_config)
.into_diagnostic()
.wrap_err("failed to configure TLS")?;
}
ep.connect()
.await
.into_diagnostic()
.wrap_err("failed to connect to OpenShell server")
}
/// Open a channel to `endpoint` and wrap it in an [`OpenShellClient`].
///
/// Whether mTLS or plaintext is used depends on the endpoint scheme; see
/// [`connect_channel`].
async fn connect(endpoint: &str) -> Result<OpenShellClient<Channel>> {
    connect_channel(endpoint).await.map(OpenShellClient::new)
}
/// Connect to the OpenShell server and fetch the policy for `sandbox_id`.
///
/// Yields `Ok(Some(policy))` when the server has a policy configured, and
/// `Ok(None)` when the sandbox was created without one (the sandbox should
/// then discover a policy from disk or fall back to the restrictive default).
///
/// # Errors
///
/// Fails when the connection cannot be established or the RPC fails.
pub async fn fetch_policy(endpoint: &str, sandbox_id: &str) -> Result<Option<ProtoSandboxPolicy>> {
    debug!(endpoint = %endpoint, sandbox_id = %sandbox_id, "Connecting to OpenShell server");
    let mut openshell = connect(endpoint).await?;

    debug!("Connected, fetching sandbox policy");
    let policy = fetch_policy_with_client(&mut openshell, sandbox_id).await?;
    Ok(policy)
}
/// Fetch the policy for `sandbox_id` over an already-connected client.
///
/// # Errors
///
/// Fails when the RPC fails, or when the server reports a non-zero version
/// without attaching a policy.
async fn fetch_policy_with_client(
    client: &mut OpenShellClient<Channel>,
    sandbox_id: &str,
) -> Result<Option<ProtoSandboxPolicy>> {
    let request = GetSandboxPolicyRequest {
        sandbox_id: sandbox_id.to_owned(),
    };
    let reply = client
        .get_sandbox_policy(request)
        .await
        .into_diagnostic()?
        .into_inner();

    match (reply.policy, reply.version) {
        (Some(policy), _) => Ok(Some(policy)),
        // Version 0 with no policy means the sandbox was created without one.
        (None, 0) => Ok(None),
        (None, _) => Err(miette::miette!(
            "Server returned non-zero version but empty policy"
        )),
    }
}
/// Push `policy` to the server for the sandbox named `sandbox`, reusing an
/// already-connected client.
///
/// The policy is cloned into the request; the response body is discarded.
///
/// # Errors
///
/// Fails when the update RPC fails.
async fn sync_policy_with_client(
    client: &mut OpenShellClient<Channel>,
    sandbox: &str,
    policy: &ProtoSandboxPolicy,
) -> Result<()> {
    let request = UpdateSandboxPolicyRequest {
        name: sandbox.to_owned(),
        policy: Some(policy.clone()),
    };

    client
        .update_sandbox_policy(request)
        .await
        .map(|_response| ())
        .into_diagnostic()
        .wrap_err("failed to sync policy to server")
}
/// Sync a locally-discovered policy and return the server's canonical copy,
/// using a single gRPC connection.
///
/// The discovered policy is pushed under the sandbox name `sandbox`, then the
/// policy is fetched back by `sandbox_id` over the same channel rather than
/// opening a separate connection for each step.
///
/// # Errors
///
/// Fails when connecting, syncing, or re-fetching fails, or when the server
/// still reports no policy after the sync.
pub async fn discover_and_sync_policy(
    endpoint: &str,
    sandbox_id: &str,
    sandbox: &str,
    discovered_policy: &ProtoSandboxPolicy,
) -> Result<ProtoSandboxPolicy> {
    debug!(
        endpoint = %endpoint,
        sandbox_id = %sandbox_id,
        sandbox = %sandbox,
        "Syncing discovered policy and re-fetching canonical version"
    );
    let mut openshell = connect(endpoint).await?;

    // Step 1: hand the discovered policy to the gateway.
    sync_policy_with_client(&mut openshell, sandbox, discovered_policy).await?;

    // Step 2: read it back so the caller gets the canonical version/hash.
    match fetch_policy_with_client(&mut openshell, sandbox_id).await? {
        Some(canonical) => Ok(canonical),
        None => Err(miette::miette!(
            "Server still returned no policy after sync — this is a bug"
        )),
    }
}
/// Connect to the gateway and push an enriched policy for `sandbox`.
///
/// Used by the supervisor to push baseline-path-enriched policies so the
/// gateway stores the effective policy users see via `openshell sandbox get`.
///
/// # Errors
///
/// Fails when the connection cannot be established or the update RPC fails.
pub async fn sync_policy(endpoint: &str, sandbox: &str, policy: &ProtoSandboxPolicy) -> Result<()> {
    debug!(endpoint = %endpoint, sandbox = %sandbox, "Syncing enriched policy to gateway");
    let mut openshell = connect(endpoint).await?;
    sync_policy_with_client(&mut openshell, sandbox, policy).await?;
    Ok(())
}
/// Provider environment fetched from the server, indexed by provider type.
pub struct ProviderEnvironment {
    /// Env vars indexed by provider type (e.g. `"anthropic"` -> `{"ANTHROPIC_API_KEY": "sk-..."}`).
    pub by_type: HashMap<String, HashMap<String, String>>,
}

impl ProviderEnvironment {
    /// Flatten all provider env vars into a single map for injection into the
    /// child process.
    ///
    /// Collisions are resolved deterministically: provider types are visited
    /// in ascending (byte-wise) order of their name, and the first provider
    /// to define an env var keeps it. Relying on `HashMap` iteration order
    /// here would let the winning value change from one run to the next.
    pub fn flatten(self) -> HashMap<String, String> {
        let mut providers: Vec<(String, HashMap<String, String>)> =
            self.by_type.into_iter().collect();
        // Map keys are unique, so an unstable sort cannot reorder equal names.
        providers.sort_unstable_by(|(left, _), (right, _)| left.cmp(right));

        let mut flat = HashMap::new();
        for (_provider_type, env) in providers {
            for (key, value) in env {
                // First writer wins; later providers never overwrite.
                flat.entry(key).or_insert(value);
            }
        }
        flat
    }

    /// Returns the provider types present, sorted ascending so the result is
    /// stable across calls and runs.
    pub fn provider_types(&self) -> Vec<String> {
        let mut types: Vec<String> = self.by_type.keys().cloned().collect();
        types.sort_unstable();
        types
    }

    /// Check if a specific provider type is present.
    pub fn has_provider_type(&self, provider_type: &str) -> bool {
        self.by_type.contains_key(provider_type)
    }
}
/// Fetch provider environment variables for a sandbox from OpenShell server via gRPC.
///
/// Returns provider credentials indexed by provider type. Use
/// [`ProviderEnvironment::flatten`] to merge into a single env var map for
/// injection into the child process.
pub async fn fetch_provider_environment(
endpoint: &str,
sandbox_id: &str,
) -> Result<ProviderEnvironment> {
debug!(endpoint = %endpoint, sandbox_id = %sandbox_id, "Fetching provider environment");
let mut client = connect(endpoint).await?;
let response = client
.get_sandbox_provider_environment(GetSandboxProviderEnvironmentRequest {
sandbox_id: sandbox_id.to_string(),
})
.await
.into_diagnostic()?;
let inner = response.into_inner();
let by_type = inner
.providers
.into_iter()
.map(|(provider_type, entry)| (provider_type, entry.environment))
.collect();
Ok(ProviderEnvironment { by_type })
}
/// A reusable gRPC client for the OpenShell service.
///
/// Wraps a tonic channel connected once and reused for policy polling
/// and status reporting, avoiding per-request TLS handshake overhead.
///
/// The wrapper is `Clone`; its methods take `&self` and work on a clone of
/// the inner client for each call.
#[derive(Clone)]
pub struct CachedOpenShellClient {
/// Generated OpenShell client over the shared channel. Cloned per RPC --
/// presumably because the generated RPC methods need `&mut self` (TODO confirm).
client: OpenShellClient<Channel>,
}
/// Policy poll result returned by [`CachedOpenShellClient::poll_policy`].
///
/// All three fields are taken from a single `get_sandbox_policy` response.
pub struct PolicyPollResult {
/// Policy carried in the response. `poll_policy` returns an error rather
/// than building this struct when the response has no policy.
pub policy: ProtoSandboxPolicy,
/// Policy version as reported in the response.
pub version: u32,
/// Policy hash as reported in the response.
pub policy_hash: String,
}
impl CachedOpenShellClient {
    /// Connect to `endpoint` once and keep the resulting client for reuse.
    ///
    /// # Errors
    ///
    /// Fails when the channel cannot be established.
    pub async fn connect(endpoint: &str) -> Result<Self> {
        debug!(endpoint = %endpoint, "Connecting openshell gRPC client for policy polling");
        connect_channel(endpoint).await.map(|channel| Self {
            client: OpenShellClient::new(channel),
        })
    }

    /// Get a clone of the underlying tonic client for direct RPC calls.
    pub fn raw_client(&self) -> OpenShellClient<Channel> {
        self.client.clone()
    }

    /// Poll for the current sandbox policy, its version, and its hash.
    ///
    /// # Errors
    ///
    /// Fails when the RPC fails or the response carries no policy.
    pub async fn poll_policy(&self, sandbox_id: &str) -> Result<PolicyPollResult> {
        let mut rpc = self.client.clone();
        let request = GetSandboxPolicyRequest {
            sandbox_id: sandbox_id.to_owned(),
        };
        let reply = rpc
            .get_sandbox_policy(request)
            .await
            .into_diagnostic()?
            .into_inner();

        match reply.policy {
            Some(policy) => Ok(PolicyPollResult {
                policy,
                version: reply.version,
                policy_hash: reply.policy_hash,
            }),
            None => Err(miette::miette!("Server returned empty policy")),
        }
    }

    /// Submit denial summaries and proposed policy chunks for policy analysis.
    ///
    /// The response body is discarded.
    ///
    /// # Errors
    ///
    /// Fails when the RPC fails.
    pub async fn submit_policy_analysis(
        &self,
        sandbox_name: &str,
        summaries: Vec<DenialSummary>,
        proposed_chunks: Vec<openshell_core::proto::PolicyChunk>,
        analysis_mode: &str,
    ) -> Result<()> {
        let request = SubmitPolicyAnalysisRequest {
            name: sandbox_name.to_owned(),
            summaries,
            proposed_chunks,
            analysis_mode: analysis_mode.to_owned(),
        };

        let mut rpc = self.client.clone();
        rpc.submit_policy_analysis(request)
            .await
            .into_diagnostic()?;
        Ok(())
    }

    /// Report policy load status back to the server.
    ///
    /// `loaded` selects between [`PolicyStatus::Loaded`] and
    /// [`PolicyStatus::Failed`]; `error_msg` is sent as the load error text
    /// in either case.
    ///
    /// # Errors
    ///
    /// Fails when the RPC fails.
    pub async fn report_policy_status(
        &self,
        sandbox_id: &str,
        version: u32,
        loaded: bool,
        error_msg: &str,
    ) -> Result<()> {
        let outcome = if loaded {
            PolicyStatus::Loaded
        } else {
            PolicyStatus::Failed
        };
        let request = ReportPolicyStatusRequest {
            sandbox_id: sandbox_id.to_owned(),
            version,
            status: outcome.into(),
            load_error: error_msg.to_owned(),
        };

        let mut rpc = self.client.clone();
        rpc.report_policy_status(request)
            .await
            .into_diagnostic()?;
        Ok(())
    }
}
/// Connect to the server and fetch the resolved inference route bundle.
///
/// Uses the same channel setup as the OpenShell client (see
/// [`connect_channel`]) but talks to the inference service.
///
/// # Errors
///
/// Fails when the channel cannot be established or the RPC fails.
pub async fn fetch_inference_bundle(endpoint: &str) -> Result<GetInferenceBundleResponse> {
    debug!(endpoint = %endpoint, "Fetching inference route bundle");
    let mut inference = InferenceClient::new(connect_channel(endpoint).await?);

    inference
        .get_inference_bundle(GetInferenceBundleRequest {})
        .await
        .map(|reply| reply.into_inner())
        .into_diagnostic()
}