diff --git a/.github/workflows/go_test.yaml b/.github/workflows/go_test.yaml
index f3ed44e..2bd17a6 100644
--- a/.github/workflows/go_test.yaml
+++ b/.github/workflows/go_test.yaml
@@ -118,6 +118,8 @@ jobs:
               echo "Service failed to start"
               echo "Logs:"
               docker compose logs spark4-thrift-server
+              docker compose logs tls-proxy
+              docker inspect --format "{{json .State.Health }}" "$(docker compose ps -aq tls-proxy)" || true
               exit 1
             fi
           fi
@@ -230,6 +232,8 @@ jobs:
               echo "Service failed to start"
               echo "Logs:"
               docker compose logs ${{ matrix.service_name }}
+              docker compose logs tls-proxy
+              docker inspect --format "{{json .State.Health }}" "$(docker compose ps -aq tls-proxy)" || true
               exit 1
             fi
           fi
diff --git a/go/ci/docker/caddy/Caddyfile b/go/ci/docker/caddy/Caddyfile
new file mode 100644
index 0000000..5bdd7ff
--- /dev/null
+++ b/go/ci/docker/caddy/Caddyfile
@@ -0,0 +1,35 @@
+# Copyright (c) 2026 ADBC Drivers Contributors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+{
+	auto_https off
+}
+
+https://:15003 {
+	reverse_proxy h2c://spark4-connect-server:15002
+
+	tls /etc/caddy/certs/server.crt /etc/caddy/certs/server.key
+}
+
+https://:8999 {
+	reverse_proxy spark35-livy:8998
+
+	tls /etc/caddy/certs/server.crt /etc/caddy/certs/server.key
+}
+
+https://:10002 {
+	reverse_proxy spark4-thrifthttp-server:10001
+
+	tls /etc/caddy/certs/server.crt /etc/caddy/certs/server.key
+}
diff --git a/go/client.go b/go/client.go
index ad91d43..a9923fd 100644
--- a/go/client.go
+++ b/go/client.go
@@ -52,6 +52,9 @@ func parseOptionsFromUri(uri *url.URL, options map[string]string) error {
 	}
 
 	for key, values := range queryValues {
+		if key == "validateservercertificate" {
+			key = "validate_server_certificate"
+		}
 		fullKey := fmt.Sprintf("spark.%s", key)
 		if len(values) != 1 {
 			return adbc.Error{
@@ -127,6 +130,23 @@ func parseIntegerOption(key string, options map[string]string, defaultValue uint
 	return uint16(intOpt), nil
 }
 
+// parseBoolOption removes key from options and parses it as a boolean; absent keys yield defaultValue.
+func parseBoolOption(key string, options map[string]string, defaultValue bool) (bool, error) {
+	opt, ok := options[key]
+	if !ok {
+		return defaultValue, nil
+	}
+	delete(options, key)
+	switch strings.ToLower(strings.TrimSpace(opt)) {
+	case "true", "1", "yes", "on":
+		return true, nil
+	case "false", "0", "no", "off":
+		return false, nil
+	default:
+		return false, sparkbase.InvalidOptionErr(key, opt)
+	}
+}
+
 // initializeAws sets up AWS configuration for SigV4 authentication
 func awsConfigFromOptions(ctx context.Context, options map[string]string) (aws.Config, error) {
 	// Check if explicit credentials are provided
@@ -176,10 +196,23 @@ func livyOptsFromOptions(ctx context.Context, options map[string]string) (livyim
 		return livyOpts, err
 	}
 
-	// TODO: come up with a better way to do this
-	// Allow explicit http://
+	tls, err := parseBoolOption(OptionUseTls, options, false)
+	if err != nil {
+		return livyOpts, err
+	}
+
+	validateServerCertificate, err := parseBoolOption(OptionValidateServerCertificate, options, true)
+	if err != nil {
+		return livyOpts, err
+	}
+	livyOpts.ValidateServerCertificate = validateServerCertificate
+
 	if !strings.Contains(host, "://") {
-		host = fmt.Sprintf("http://%s", host)
+		if tls {
+			host = fmt.Sprintf("https://%s", host)
+		} else {
+			host = fmt.Sprintf("http://%s", host)
+		}
 	}
 
 	livyOpts.BaseURL = host
@@ -258,6 +291,16 @@ func connectOptsFromOptions(options map[string]string) (connectimpl.ConnectionOp
 		connectOpts.Username = username
 	}
 
+	// XXX: ignored, because spark-connect-go doesn't let you configure this
+	_, err = parseBoolOption(OptionUseTls, options, false)
+	if err != nil {
+		return connectOpts, err
+	}
+	_, err = parseBoolOption(OptionValidateServerCertificate, options, true)
+	if err != nil {
+		return connectOpts, err
+	}
+
 	authType, ok := options[OptionAuthType]
 	if !ok {
 		return connectOpts, sparkbase.MissingRequiredOptionErr(OptionAuthType)
@@ -296,6 +339,18 @@ func thriftOptsFromOptions(options map[string]string) (thriftimpl.ConnectionOpts
 	}
 	thriftOpts.Host = host
 
+	tls, err := parseBoolOption(OptionUseTls, options, false)
+	if err != nil {
+		return thriftOpts, err
+	}
+	thriftOpts.Tls = tls
+
+	validateServerCertificate, err := parseBoolOption(OptionValidateServerCertificate, options, true)
+	if err != nil {
+		return thriftOpts, err
+	}
+	thriftOpts.ValidateServerCertificate = validateServerCertificate
+
 	switch authType {
 	case OptionValueAuthTypeNoSasl:
 		thriftOpts.Auth = thriftimpl.NoSasl
diff --git a/go/compose.yaml b/go/compose.yaml
index 1af0cd1..667ebff 100644
--- a/go/compose.yaml
+++ b/go/compose.yaml
@@ -207,6 +207,8 @@ services:
         condition: service_healthy
       minio-init:
         condition: service_completed_successfully
+      tls-proxy:
+        condition: service_started
     environment:
       - SPARK_MASTER=spark://spark35-master:7077
     ports:
@@ -309,6 +311,44 @@ services:
       retries: 10
       start_period: 45s
 
+  tls-proxy-init:
+    image: alpine:3.21
+    command:
+      - sh
+      - -c
+      - |
+        apk add --no-cache openssl &&
+        mkdir -p /certs &&
+        if [ ! -f /certs/server.crt ]; then
+          openssl req -x509 -newkey rsa:2048 -nodes \
+            -keyout /certs/server.key \
+            -out /certs/server.crt \
+            -days 3650 \
+            -subj "/CN=tls-proxy" \
+            -addext "subjectAltName=DNS:tls-proxy,DNS:localhost"
+        fi
+    volumes:
+      - ./.data/tls:/certs
+
+  tls-proxy:
+    image: caddy:2-alpine
+    depends_on:
+      tls-proxy-init:
+        condition: service_completed_successfully
+    ports:
+      - "15003:15003"  # Spark Connect gRPC over TLS
+      - "8999:8999"  # Livy REST API over TLS
+      - "10002:10002"  # Thrift HTTP over TLS
+    volumes:
+      - ./ci/docker/caddy/Caddyfile:/etc/caddy/Caddyfile:ro
+      - ./.data/tls:/etc/caddy/certs:ro
+    healthcheck:
+      test: ["CMD-SHELL", "wget -q -O /dev/null http://127.0.0.1:2019/config/ || exit 1"]
+      interval: 5s
+      timeout: 5s
+      retries: 10
+      start_period: 5s
+
   spark4-connect-server:
     build:
       context: ./ci/docker/
@@ -323,6 +363,8 @@ services:
         condition: service_healthy
       minio-init:
         condition: service_completed_successfully
+      tls-proxy:
+        condition: service_started
     environment:
       SPARK_MASTER: spark://spark4-master:7077
     ports:
@@ -353,6 +395,8 @@ services:
         condition: service_healthy
       minio-init:
         condition: service_completed_successfully
+      tls-proxy:
+        condition: service_started
     environment:
       SPARK_MASTER: spark://spark4-master:7077
       SPARK_SERVER_TYPE: thrifthttp
diff --git a/go/docs/spark.md b/go/docs/spark.md
index 6d58610..5666b80 100644
--- a/go/docs/spark.md
+++ b/go/docs/spark.md
@@ -104,6 +104,16 @@ These parameters can be specified in the URI as query parameters, or as connecti
   Currently only `sql` is tested/supported.
   :::
 
+`spark.tls` (query parameter: `tls`)
+: **Type** boolean. **Default**: false.
+
+  Whether to use TLS for connecting. Only applies to `livy` and `thrift+http`. For `livy`, an explicit `http://` or `https://` scheme in the host takes precedence. For `connect` the option is accepted but ignored: the connection uses TLS when token authentication is used and plaintext otherwise.
+
+`spark.validate_server_certificate` (query parameter: `validateservercertificate`)
+: **Type** boolean. **Default**: true.
+
+  Whether to validate the server's TLS certificate. Should only be disabled for development/testing. Has no effect for `connect`, which always validates the certificate.
+
 ## Limitations
 
 Different backends have limitations; some limitations related to data type support are also noted further below.
diff --git a/go/internal/livyimpl/client.go b/go/internal/livyimpl/client.go
index 2665737..60354e6 100644
--- a/go/internal/livyimpl/client.go
+++ b/go/internal/livyimpl/client.go
@@ -18,6 +18,7 @@ import (
 	"bytes"
 	"context"
 	"crypto/sha256"
+	"crypto/tls"
 	"encoding/hex"
 	"encoding/json"
 	"fmt"
@@ -56,13 +57,14 @@ type ConnectionOpts struct {
 	SessionKind SessionKind
 	AuthType    AuthType
 
-	BaseURL                 string
-	HttpTimeoutSeconds      uint
-	HeartbeatTimeoutSeconds uint
-	QueryTimeoutSeconds     uint
-	Username                string
-	Password                string
-	SessionTtl              string
+	BaseURL                   string
+	HttpTimeoutSeconds        uint
+	HeartbeatTimeoutSeconds   uint
+	QueryTimeoutSeconds       uint
+	Username                  string
+	Password                  string
+	SessionTtl                string
+	ValidateServerCertificate bool
 
 	AwsConfig aws.Config
 }
@@ -88,10 +90,23 @@ type livyClient struct {
 
 // NewClient creates a new SparkClient over Livy client
 func NewClient(ctx context.Context, opts ConnectionOpts, sessionConfig map[string]string) (sparkbase.SparkClient, error) {
+	httpClient := &http.Client{
+		Timeout: time.Duration(float64(opts.HttpTimeoutSeconds) * float64(time.Second)),
+	}
+	if !opts.ValidateServerCertificate {
+		// Start from the default transport (when it is the stock type) so proxy,
+		// timeout, and HTTP/2 defaults are kept; only verification is disabled.
+		insecure := &http.Transport{}
+		if def, ok := http.DefaultTransport.(*http.Transport); ok {
+			insecure = def.Clone()
+		}
+		insecure.TLSClientConfig = &tls.Config{InsecureSkipVerify: true}
+		httpClient.Transport = insecure
+	}
 	client := &livyClient{
 		sessionID:        -1,
 		baseURL:          opts.BaseURL,
-		httpClient:       &http.Client{Timeout: time.Duration(float64(opts.HttpTimeoutSeconds) * float64(time.Second))},
+		httpClient:       httpClient,
 		queryTimeout:     time.Duration(float64(opts.QueryTimeoutSeconds) * float64(time.Second)),
 		heartbeatTimeout: time.Duration(float64(opts.HeartbeatTimeoutSeconds) * float64(time.Second)),
 		authType:         opts.AuthType,
diff --git a/go/internal/thriftimpl/client.go b/go/internal/thriftimpl/client.go
index f8dfd87..2416d04 100644
--- a/go/internal/thriftimpl/client.go
+++ b/go/internal/thriftimpl/client.go
@@ -16,9 +16,11 @@ package thriftimpl
 
 import (
 	"context"
+	"crypto/tls"
 	"encoding/base64"
 	"errors"
 	"fmt"
+	"net/http"
 	"strings"
 
 	"github.com/adbc-drivers/apache/go/internal/hiveserver2"
@@ -51,8 +53,10 @@ type ConnectionOpts struct {
 
 	Catalog string
 
-	Username string
-	Password string
+	Username                  string
+	Password                  string
+	Tls                       bool
+	ValidateServerCertificate bool
 
 	Host string
 }
@@ -104,9 +108,24 @@ func NewClient(ctx context.Context, opts ConnectionOpts) (sparkbase.SparkClient,
 	switch opts.Transport {
 	case Http:
 		transportName = "HTTP"
-		// TODO(lidavidm): TLS, configurable HTTP path
-		uri := "http://" + opts.Host + "/cliservice"
-		transport, err = thrift.NewTHttpClient(uri)
+		// TODO(lidavidm): configurable HTTP path
+		scheme := "http"
+		if opts.Tls {
+			scheme = "https"
+		}
+		uri := scheme + "://" + opts.Host + "/cliservice"
+		httpClientOptions := thrift.THttpClientOptions{}
+		if !opts.ValidateServerCertificate {
+			// Start from the default transport (when it is the stock type) so proxy,
+			// timeout, and HTTP/2 defaults are kept; only verification is disabled.
+			insecure := &http.Transport{}
+			if def, ok := http.DefaultTransport.(*http.Transport); ok {
+				insecure = def.Clone()
+			}
+			insecure.TLSClientConfig = &tls.Config{InsecureSkipVerify: true}
+			httpClientOptions.Client = &http.Client{Transport: insecure}
+		}
+		transport, err = thrift.NewTHttpClientWithOptions(uri, httpClientOptions)
 		if err != nil {
 			return nil, sparkbase.ErrToAdbcErr(adbc.StatusIO, err, "could not open HTTP thrift client")
 		}
diff --git a/go/options.go b/go/options.go
index 6230515..92cbe3e 100644
--- a/go/options.go
+++ b/go/options.go
@@ -36,7 +36,12 @@ const (
 	// OptionAuthType specifies the authentication method used by the driver
 	OptionAuthType = "spark.auth_type"
 	// OptionSchema specifies the default schema to connect to
 	OptionSchema = "spark.schema"
+	// OptionUseTls specifies whether to connect using TLS (boolean, default false)
+	OptionUseTls = "spark.tls"
+	// OptionValidateServerCertificate specifies whether to validate the
+	// server's TLS certificate (boolean, default true)
+	OptionValidateServerCertificate = "spark.validate_server_certificate"
 
 	// Spark Configuration Prefix
 	// Options starting with this prefix are passed to the Spark session configuration
diff --git a/go/validation/tests/test_auth.py b/go/validation/tests/test_auth.py
index d581e39..77bf842 100644
--- a/go/validation/tests/test_auth.py
+++ b/go/validation/tests/test_auth.py
@@ -90,7 +90,7 @@ def test_auth(subtests, driver, driver_path):
     for replacement, error_message in cases:
         new_uri = uri.replace(orig, replacement)
 
-        if replacement == "auth_type=nosasl":
+        if replacement in ("auth_type=none", "auth_type=nosasl"):
             kwargs = {}
         else:
             kwargs = {
@@ -108,3 +108,86 @@ def test_auth(subtests, driver, driver_path):
                 ) as conn:
                     with conn.cursor() as cursor:
                         cursor.execute("SELECT 1")
+
+
+def test_tls(subtests, driver, driver_path):
+    if driver.short_version.endswith("-connect"):
+        # Spark Connect is "special" and forces plaintext if you don't have a
+        # token and TLS if you do
+        uri = os.environ["SPARK_CONNECT_URI"].replace("15002", "15003")
+        uri = uri.replace("auth_type=none", "auth_type=token")
+        uri += "&tls=true&validateservercertificate=false"
+        # XXX: there is no way to skip certificate checking for spark-connect-go
+
+        with pytest.raises(
+            adbc_driver_manager.Error, match="failed to verify certificate"
+        ):
+            with adbc_driver_manager.dbapi.connect(
+                driver=driver_path,
+                uri=uri,
+                autocommit=True,
+                db_kwargs={
+                    "username": "spark",
+                    "password": "spark",
+                },
+            ) as conn:
+                with conn.cursor() as cursor:
+                    cursor.execute("SELECT 1")
+
+        return
+
+    elif driver.short_version.endswith("-thrift"):
+        return
+    elif driver.short_version.endswith("-thrifthttp"):
+        uri = os.environ["SPARK_THRIFTHTTP_URI"].replace("10001", "10002")
+        uri += "&tls=true&validateservercertificate=false"
+    elif driver.short_version.endswith("-livy"):
+        uri = os.environ["SPARK_LIVY_URI"].replace("8998", "8999")
+        uri += "&tls=true&validateservercertificate=false"
+    else:
+        raise NotImplementedError(driver.short_version)
+
+    with adbc_driver_manager.dbapi.connect(
+        driver=driver_path,
+        uri=uri,
+        autocommit=True,
+        db_kwargs={
+            "username": "spark",
+            "password": "spark",
+        },
+    ) as conn:
+        with conn.cursor() as cursor:
+            cursor.execute("SELECT 1")
+            assert cursor.fetchall() == [(1,)]
+
+
+def test_tls_verify(subtests, driver, driver_path):
+    if driver.short_version.endswith("-connect"):
+        # Spark Connect is "special" and forces plaintext if you don't have a
+        # token and TLS if you do
+        uri = os.environ["SPARK_CONNECT_URI"].replace("15002", "15003")
+        uri = uri.replace("auth_type=none", "auth_type=token")
+        uri += "&tls=true"
+    elif driver.short_version.endswith("-thrift"):
+        return
+    elif driver.short_version.endswith("-thrifthttp"):
+        uri = os.environ["SPARK_THRIFTHTTP_URI"].replace("10001", "10002")
+        uri += "&tls=true"
+    elif driver.short_version.endswith("-livy"):
+        uri = os.environ["SPARK_LIVY_URI"].replace("8998", "8999")
+        uri += "&tls=true"
+    else:
+        raise NotImplementedError(driver.short_version)
+
+    with pytest.raises(adbc_driver_manager.Error, match="failed to verify certificate"):
+        with adbc_driver_manager.dbapi.connect(
+            driver=driver_path,
+            uri=uri,
+            autocommit=True,
+            db_kwargs={
+                "username": "spark",
+                "password": "spark",
+            },
+        ) as conn:
+            with conn.cursor() as cursor:
+                cursor.execute("SELECT 1")