Skip to content

Commit

Permalink
Merge pull request #6 from lmangani/auth
Browse files Browse the repository at this point in the history
API Authentication (basic or x-header)
  • Loading branch information
lmangani authored Oct 15, 2024
2 parents 29a7049 + bc6afe2 commit ec47a4c
Show file tree
Hide file tree
Showing 2 changed files with 90 additions and 11 deletions.
39 changes: 31 additions & 8 deletions docs/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -32,16 +32,39 @@ LOAD httpserver;
```

### 🔌 Usage
Start the HTTP server providing the `host` and `port` parameters
Start the HTTP server providing the `host`, `port` and `auth` parameters.<br>
> If you want no authhentication, just pass an empty string.
#### Basic Auth
```sql
D SELECT httpserve_start('0.0.0.0',9999);
┌─────────────────────────────────────┐
│ httpserve_start('0.0.0.0', 9999) │
varchar
├─────────────────────────────────────┤
│ HTTP server started on 0.0.0.0:9999
└─────────────────────────────────────┘
D SELECT httpserve_start('localhost', 9999, 'user:pass');

┌───────────────────────────────────────────────┐
│ httpserve_start('0.0.0.0', 9999, 'user:pass') │
varchar
├───────────────────────────────────────────────┤
│ HTTP server started on 0.0.0.0:9999
└───────────────────────────────────────────────┘
```
```bash
curl -X POST -d "SELECT 'hello', version()" "http://user:pass@localhost:9999/"
```

#### Token Auth
```sql
SELECT httpserve_start('localhost', 9999, 'supersecretkey');

┌───────────────────────────────────────────────┐
│ httpserve_start('0.0.0.0', 9999, 'secretkey') │
varchar
├───────────────────────────────────────────────┤
│ HTTP server started on 0.0.0.0:9999
└───────────────────────────────────────────────┘
```
```
curl -X POST --header "X-API-Key: supersecretkey" -d "SELECT 'hello', version()" "http://localhost:9999/"
```


#### 👉 QUERY UI
Browse to your endpoint and use the built-in quackplay interface _(experimental)_
Expand Down
62 changes: 59 additions & 3 deletions src/httpserver_extension.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ struct HttpServerState {
std::atomic<bool> is_running;
DatabaseInstance* db_instance;
unique_ptr<Allocator> allocator;
std::string auth_token;

HttpServerState() : is_running(false), db_instance(nullptr) {}
};
Expand Down Expand Up @@ -129,6 +130,51 @@ static std::string ConvertResultToJSON(MaterializedQueryResult &result, ReqStats
return json_output;
}

// New: Base64 decoding function
std::string base64_decode(const std::string &in) {
std::string out;
std::vector<int> T(256, -1);
for (int i = 0; i < 64; i++)
T["ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/"[i]] = i;

int val = 0, valb = -8;
for (unsigned char c : in) {
if (T[c] == -1) break;
val = (val << 6) + T[c];
valb += 6;
if (valb >= 0) {
out.push_back(char((val >> valb) & 0xFF));
valb -= 8;
}
}
return out;
}

// Auth Check
bool IsAuthenticated(const duckdb_httplib_openssl::Request& req) {
if (global_state.auth_token.empty()) {
return true; // No authentication required if no token is set
}

// Check for X-API-Key header
auto api_key = req.get_header_value("X-API-Key");
if (!api_key.empty() && api_key == global_state.auth_token) {
return true;
}

// Check for Basic Auth
auto auth = req.get_header_value("Authorization");
if (!auth.empty() && auth.compare(0, 6, "Basic ") == 0) {
std::string decoded_auth = base64_decode(auth.substr(6));
if (decoded_auth == global_state.auth_token) {
return true;
}
}

return false;
}


// Convert the query result to NDJSON (JSONEachRow) format
static std::string ConvertResultToNDJSON(MaterializedQueryResult &result) {
std::string ndjson_output;
Expand Down Expand Up @@ -208,6 +254,13 @@ static void HandleQuery(const string& query, duckdb_httplib_openssl::Response& r
void HandleHttpRequest(const duckdb_httplib_openssl::Request& req, duckdb_httplib_openssl::Response& res) {
std::string query;

// Check authentication
if (!IsAuthenticated(req)) {
res.status = 401;
res.set_content("Unauthorized", "text/plain");
return;
}

// CORS allow
res.set_header("Access-Control-Allow-Origin", "*");
res.set_header("Access-Control-Allow-Methods", "GET, POST, OPTIONS, PUT");
Expand Down Expand Up @@ -295,14 +348,15 @@ void HandleHttpRequest(const duckdb_httplib_openssl::Request& req, duckdb_httpli
}
}

void HttpServerStart(DatabaseInstance& db, string_t host, int32_t port) {
void HttpServerStart(DatabaseInstance& db, string_t host, int32_t port, string_t auth = string_t()) {
if (global_state.is_running) {
throw IOException("HTTP server is already running");
}

global_state.db_instance = &db;
global_state.server = make_uniq<duckdb_httplib_openssl::Server>();
global_state.is_running = true;
global_state.auth_token = auth.GetString();

// CORS Preflight
global_state.server->Options("/",
Expand Down Expand Up @@ -359,17 +413,19 @@ static void HttpServerCleanup() {

static void LoadInternal(DatabaseInstance &instance) {
auto httpserve_start = ScalarFunction("httpserve_start",
{LogicalType::VARCHAR, LogicalType::INTEGER},
{LogicalType::VARCHAR, LogicalType::INTEGER, LogicalType::VARCHAR},
LogicalType::VARCHAR,
[&](DataChunk &args, ExpressionState &state, Vector &result) {
auto &host_vector = args.data[0];
auto &port_vector = args.data[1];
auto &auth_vector = args.data[2];

UnaryExecutor::Execute<string_t, string_t>(
host_vector, result, args.size(),
[&](string_t host) {
auto port = ((int32_t*)port_vector.GetData())[0];
HttpServerStart(instance, host, port);
auto auth = ((string_t*)auth_vector.GetData())[0];
HttpServerStart(instance, host, port, auth);
return StringVector::AddString(result, "HTTP server started on " + host.GetString() + ":" + std::to_string(port));
});
});
Expand Down

0 comments on commit ec47a4c

Please sign in to comment.