Skip to content

Commit dde9e87

Browse files
authored
added CLI flag for creating API keys (tensorzero#4336)
* added CLI flag for API keys * use preexisting constants for API key gen
1 parent 90e9640 commit dde9e87

File tree

4 files changed

+116
-1
lines changed

4 files changed

+116
-1
lines changed

Cargo.lock

Lines changed: 1 addition & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

gateway/Cargo.toml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,8 @@ pin-project = { workspace = true }
2727
futures = { workspace = true }
2828
tokio-stream = { workspace = true }
2929
tensorzero-auth = { path = "../internal/tensorzero-auth" }
30+
sqlx_alpha = { package = "sqlx", version = "0.9.0-alpha.1", features = ["postgres", "runtime-tokio"] }
31+
secrecy = { workspace = true }
3032

3133
[lints]
3234
workspace = true
@@ -43,4 +45,4 @@ tracing-test = "0.2"
4345
serde_json = { workspace = true }
4446
reqwest-eventsource = { workspace = true }
4547
futures = { workspace = true }
46-
secrecy = { workspace = true }
48+
secrecy = { workspace = true }

gateway/src/main.rs

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
use clap::{Args, Parser};
22
use futures::{FutureExt, StreamExt};
33
use mimalloc::MiMalloc;
4+
use secrecy::ExposeSecret;
45
use std::fmt::Display;
56
use std::future::{Future, IntoFuture};
67
use std::io::ErrorKind;
@@ -12,6 +13,7 @@ use tokio::signal;
1213
use tokio_stream::wrappers::IntervalStream;
1314
use tower_http::metrics::in_flight_requests::InFlightRequestsCounter;
1415

16+
use tensorzero_auth::constants::{DEFAULT_ORGANIZATION, DEFAULT_WORKSPACE};
1517
use tensorzero_core::config::{Config, ConfigFileGlob};
1618
use tensorzero_core::db::clickhouse::migration_manager::manual_run_clickhouse_migrations;
1719
use tensorzero_core::db::postgres::{manual_run_postgres_migrations, PostgresConnectionInfo};
@@ -57,6 +59,34 @@ struct MigrationCommands {
5759
/// Run PostgreSQL migrations manually then exit.
5860
#[arg(long)]
5961
run_postgres_migrations: bool,
62+
63+
/// Create an API key then exit.
64+
#[arg(long)]
65+
create_api_key: bool,
66+
}
67+
68+
#[expect(clippy::print_stdout)]
69+
fn print_key(key: &secrecy::SecretString) {
70+
println!("{}", key.expose_secret());
71+
}
72+
73+
async fn handle_create_api_key() -> Result<(), Box<dyn std::error::Error>> {
74+
// Read the Postgres URL from the environment
75+
let postgres_url = std::env::var("TENSORZERO_POSTGRES_URL")
76+
.map_err(|_| "TENSORZERO_POSTGRES_URL environment variable not set")?;
77+
78+
// Create connection pool (alpha version for tensorzero-auth)
79+
let pool = sqlx_alpha::PgPool::connect(&postgres_url).await?;
80+
81+
// Create the key with default organization and workspace
82+
let key =
83+
tensorzero_auth::postgres::create_key(DEFAULT_ORGANIZATION, DEFAULT_WORKSPACE, None, &pool)
84+
.await?;
85+
86+
// Print only the API key to stdout for easy machine parsing
87+
print_key(&key);
88+
89+
Ok(())
6090
}
6191

6292
#[tokio::main]
@@ -70,6 +100,14 @@ async fn main() {
70100
.expect_pretty("Failed to set up logs");
71101

72102
let git_sha = tensorzero_core::built_info::GIT_COMMIT_HASH_SHORT.unwrap_or("unknown");
103+
104+
if args.migration_commands.create_api_key {
105+
handle_create_api_key()
106+
.await
107+
.expect_pretty("Failed to create API key");
108+
return;
109+
}
110+
73111
if args.migration_commands.run_clickhouse_migrations {
74112
manual_run_clickhouse_migrations()
75113
.await

gateway/tests/auth.rs

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,21 @@
11
#![allow(clippy::print_stdout)]
2+
use std::process::Stdio;
23
use std::str::FromStr;
34

45
use http::{Method, StatusCode};
56
use serde_json::json;
67
use tensorzero::test_helpers::make_embedded_gateway_with_config_and_postgres;
78
use tensorzero_auth::key::TensorZeroApiKey;
89
use tensorzero_core::endpoints::status::TENSORZERO_VERSION;
10+
use tokio::process::Command;
911

1012
use crate::common::start_gateway_on_random_port;
1113
use secrecy::ExposeSecret;
1214

1315
mod common;
1416

17+
const GATEWAY_PATH: &str = env!("CARGO_BIN_EXE_gateway");
18+
1519
#[tokio::test]
1620
async fn test_tensorzero_auth_enabled() {
1721
let child_data = start_gateway_on_random_port(
@@ -302,3 +306,73 @@ async fn test_tensorzero_missing_auth() {
302306
assert_eq!(status, StatusCode::UNAUTHORIZED);
303307
}
304308
}
309+
310+
#[tokio::test]
311+
async fn test_create_api_key_cli() {
312+
// This test verifies that the --create-api-key CLI command works correctly
313+
let output = Command::new(GATEWAY_PATH)
314+
.args(["--create-api-key"])
315+
.stdout(Stdio::piped())
316+
.stderr(Stdio::piped())
317+
.output()
318+
.await
319+
.unwrap();
320+
321+
assert!(
322+
output.status.success(),
323+
"CLI command failed with stderr: {}",
324+
String::from_utf8_lossy(&output.stderr)
325+
);
326+
327+
let stdout = String::from_utf8(output.stdout).unwrap();
328+
let api_key = stdout.trim();
329+
330+
// Verify the key has the correct format
331+
assert!(
332+
api_key.starts_with("sk-t0-"),
333+
"API key should start with 'sk-t0-', got: {api_key}"
334+
);
335+
336+
// Verify the key can be parsed
337+
let parsed_key = TensorZeroApiKey::parse(api_key);
338+
assert!(
339+
parsed_key.is_ok(),
340+
"API key should be valid, got error: {:?}",
341+
parsed_key.err()
342+
);
343+
344+
// Verify the key works for authentication
345+
let child_data = start_gateway_on_random_port(
346+
"
347+
[gateway.auth]
348+
enabled = true
349+
",
350+
None,
351+
)
352+
.await;
353+
354+
let inference_response = reqwest::Client::new()
355+
.post(format!("http://{}/inference", child_data.addr))
356+
.header(http::header::AUTHORIZATION, format!("Bearer {api_key}"))
357+
.json(&json!({
358+
"model_name": "dummy::good",
359+
"input": {
360+
"messages": [
361+
{
362+
"role": "user",
363+
"content": "Hello, world!",
364+
}
365+
]
366+
}
367+
}))
368+
.send()
369+
.await
370+
.unwrap();
371+
372+
let status = inference_response.status();
373+
assert_eq!(
374+
status,
375+
StatusCode::OK,
376+
"Created API key should work for authentication"
377+
);
378+
}

0 commit comments

Comments
 (0)