diff --git a/CHANGELOG.md b/CHANGELOG.md index 74a05213c..9ee5d2621 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,67 @@ to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). ## [Unreleased] +## [11.3.0] + +- Adds SAML features + +### Migration + +```sql +CREATE TABLE IF NOT EXISTS saml_clients ( + app_id VARCHAR(64) NOT NULL DEFAULT 'public', + tenant_id VARCHAR(64) NOT NULL DEFAULT 'public', + client_id VARCHAR(256) NOT NULL, + client_secret TEXT, + sso_login_url TEXT NOT NULL, + redirect_uris TEXT NOT NULL, + default_redirect_uri TEXT NOT NULL, + idp_entity_id VARCHAR(256) NOT NULL, + idp_signing_certificate TEXT NOT NULL, + allow_idp_initiated_login BOOLEAN NOT NULL DEFAULT FALSE, + enable_request_signing BOOLEAN NOT NULL DEFAULT FALSE, + created_at BIGINT NOT NULL, + updated_at BIGINT NOT NULL, + CONSTRAINT saml_clients_pkey PRIMARY KEY(app_id, tenant_id, client_id), + CONSTRAINT saml_clients_idp_entity_id_key UNIQUE (app_id, tenant_id, idp_entity_id), + CONSTRAINT saml_clients_app_id_fkey FOREIGN KEY(app_id) REFERENCES apps (app_id) ON DELETE CASCADE, + CONSTRAINT saml_clients_tenant_id_fkey FOREIGN KEY(app_id, tenant_id) REFERENCES tenants (app_id, tenant_id) ON DELETE CASCADE +); + +CREATE INDEX IF NOT EXISTS saml_clients_app_id_tenant_id_index ON saml_clients (app_id, tenant_id); + +CREATE TABLE IF NOT EXISTS saml_relay_state ( + app_id VARCHAR(64) NOT NULL DEFAULT 'public', + tenant_id VARCHAR(64) NOT NULL DEFAULT 'public', + relay_state VARCHAR(256) NOT NULL, + client_id VARCHAR(256) NOT NULL, + state TEXT NOT NULL, + redirect_uri TEXT NOT NULL, + created_at BIGINT NOT NULL, + CONSTRAINT saml_relay_state_pkey PRIMARY KEY(app_id, tenant_id, relay_state), + CONSTRAINT saml_relay_state_app_id_fkey FOREIGN KEY(app_id) REFERENCES apps (app_id) ON DELETE CASCADE, + CONSTRAINT saml_relay_state_tenant_id_fkey FOREIGN KEY(app_id, tenant_id) REFERENCES tenants (app_id, tenant_id) ON DELETE CASCADE +); + +CREATE INDEX IF NOT EXISTS saml_relay_state_app_id_tenant_id_index ON saml_relay_state (app_id, tenant_id); +CREATE INDEX IF NOT EXISTS saml_relay_state_expires_at_index ON saml_relay_state (expires_at); + +CREATE TABLE IF NOT EXISTS saml_claims ( + app_id VARCHAR(64) NOT NULL DEFAULT 'public', + tenant_id VARCHAR(64) NOT NULL DEFAULT 'public', + client_id VARCHAR(256) NOT NULL, + code VARCHAR(256) NOT NULL, + claims TEXT NOT NULL, + created_at BIGINT NOT NULL, + CONSTRAINT saml_claims_pkey PRIMARY KEY(app_id, tenant_id, code), + CONSTRAINT saml_claims_app_id_fkey FOREIGN KEY(app_id) REFERENCES apps (app_id) ON DELETE CASCADE, + CONSTRAINT saml_claims_tenant_id_fkey FOREIGN KEY(app_id, tenant_id) REFERENCES tenants (app_id, tenant_id) ON DELETE CASCADE +); + +CREATE INDEX IF NOT EXISTS saml_claims_app_id_tenant_id_index ON saml_claims (app_id, tenant_id); +CREATE INDEX IF NOT EXISTS saml_claims_expires_at_index ON saml_claims (expires_at); +``` + ## [11.2.0] - Adds opentelemetry-javaagent to the core distribution diff --git a/build.gradle b/build.gradle index 6720cd723..20bf2472c 100644 --- a/build.gradle +++ b/build.gradle @@ -27,10 +27,12 @@ java { } } -version = "11.2.0" +version = "11.3.0" repositories { mavenCentral() + + maven { url 'https://build.shibboleth.net/nexus/content/repositories/releases/' } } dependencies { @@ -86,11 +88,16 @@ dependencies { implementation platform("io.opentelemetry.instrumentation:opentelemetry-instrumentation-bom-alpha:2.17.0-alpha") + // Open SAML + implementation group: 'org.opensaml', name: 'opensaml-core', version: '4.3.1' + implementation group: 'org.opensaml', name: 'opensaml-saml-impl', version: '4.3.1' + implementation group: 'org.opensaml', name: 'opensaml-security-impl', version: '4.3.1' + implementation group: 'org.opensaml', name: 'opensaml-profile-impl', version: '4.3.1' + implementation group: 'org.opensaml', name: 'opensaml-xmlsec-impl', version: '4.3.1' implementation("ch.qos.logback:logback-core:1.5.18") implementation("ch.qos.logback:logback-classic:1.5.18") - // OpenTelemetry core implementation("io.opentelemetry:opentelemetry-sdk") implementation("io.opentelemetry:opentelemetry-exporter-otlp") diff --git a/config.yaml b/config.yaml index 4459cbaad..644d14efa 100644 --- a/config.yaml +++ b/config.yaml @@ -186,3 +186,9 @@ core_config_version: 0 # (OPTIONAL | Default: null) string value. The URL of the OpenTelemetry collector to which the core # will send telemetry data. This should be in the format http://: or https://:. # otel_collector_connection_uri: + +# (OPTIONAL | Default: null) string value. If specified, uses this URL as ACS URL for handling legacy SAML clients +# saml_legacy_acs_url: + +# (OPTIONAL | Default: https://saml.supertokens.com) string value. Service provider's entity ID. +# saml_sp_entity_id: diff --git a/coreDriverInterfaceSupported.json b/coreDriverInterfaceSupported.json index e3d03b4d2..908905417 100644 --- a/coreDriverInterfaceSupported.json +++ b/coreDriverInterfaceSupported.json @@ -22,6 +22,7 @@ "5.0", "5.1", "5.2", - "5.3" + "5.3", + "5.4" ] } diff --git a/devConfig.yaml b/devConfig.yaml index fe55683b6..3e20760ba 100644 --- a/devConfig.yaml +++ b/devConfig.yaml @@ -185,4 +185,10 @@ disable_telemetry: true # (OPTIONAL | Default: null) string value. The URL of the OpenTelemetry collector to which the core # will send telemetry data. This should be in the format http://: or https://:. -# otel_collector_connection_uri: \ No newline at end of file +# otel_collector_connection_uri: + +# (OPTIONAL | Default: null) string value. If specified, uses this URL as ACS URL for handling legacy SAML clients +saml_legacy_acs_url: "http://localhost:5225/api/oauth/saml" + +# (OPTIONAL | Default: https://saml.supertokens.com) string value. Service provider's entity ID. +# saml_sp_entity_id: diff --git a/ee/src/main/java/io/supertokens/ee/EEFeatureFlag.java b/ee/src/main/java/io/supertokens/ee/EEFeatureFlag.java index 4a440d0e2..b44937660 100644 --- a/ee/src/main/java/io/supertokens/ee/EEFeatureFlag.java +++ b/ee/src/main/java/io/supertokens/ee/EEFeatureFlag.java @@ -34,6 +34,7 @@ import io.supertokens.pluginInterface.multitenancy.ThirdPartyConfig; import io.supertokens.pluginInterface.multitenancy.exceptions.TenantOrAppNotFoundException; import io.supertokens.pluginInterface.oauth.OAuthStorage; +import io.supertokens.pluginInterface.saml.SAMLStorage; import io.supertokens.pluginInterface.session.sqlStorage.SessionSQLStorage; import io.supertokens.storageLayer.StorageLayer; import io.supertokens.utils.Utils; @@ -386,6 +387,34 @@ private JsonArray getMAUs() throws StorageQueryException, TenantOrAppNotFoundExc return mauArr; } + private JsonObject getSAMLStats() throws TenantOrAppNotFoundException, StorageQueryException { + JsonObject stats = new JsonObject(); + + stats.addProperty("connectionUriDomain", this.appIdentifier.getConnectionUriDomain()); + stats.addProperty("appId", this.appIdentifier.getAppId()); + + JsonArray tenantStats = new JsonArray(); + + TenantConfig[] tenantConfigs = Multitenancy.getAllTenantsForApp(this.appIdentifier, main); + for (TenantConfig tenantConfig : tenantConfigs) { + JsonObject tenantStat = new JsonObject(); + tenantStat.addProperty("tenantId", tenantConfig.tenantIdentifier.getTenantId()); + + { + Storage storage = StorageLayer.getStorage(tenantConfig.tenantIdentifier, main); + SAMLStorage samlStorage = StorageUtils.getSAMLStorage(storage); + + JsonObject stat = new JsonObject(); + stat.addProperty("numberOfSAMLClients", samlStorage.countSAMLClients(tenantConfig.tenantIdentifier)); + stat.add(tenantConfig.tenantIdentifier.getTenantId(), stat); + } + } + + stats.add("tenants", tenantStats); + + return stats; + } + @Override public JsonObject getPaidFeatureStats() throws StorageQueryException, TenantOrAppNotFoundException { JsonObject usageStats = new JsonObject(); @@ -433,6 +462,10 @@ public JsonObject getPaidFeatureStats() throws StorageQueryException, TenantOrAp if (feature == EE_FEATURES.OAUTH) { usageStats.add(EE_FEATURES.OAUTH.toString(), getOAuthStats()); } + + if (feature == EE_FEATURES.SAML) { + usageStats.add(EE_FEATURES.SAML.toString(), getSAMLStats()); + } } usageStats.add("maus", getMAUs()); diff --git a/implementationDependencies.json b/implementationDependencies.json index e7807f409..cc34a42a6 100644 --- a/implementationDependencies.json +++ b/implementationDependencies.json @@ -101,6 +101,146 @@ "name":"webauthn4j-core 0.28.6.RELEASE", "src":"https://repo.maven.apache.org/maven2/com/webauthn4j/webauthn4j-core/0.28.6.RELEASE/webauthn4j-core-0.28.6.RELEASE-sources.jar" }, + { + "jar":"https://build.shibboleth.net/nexus/content/repositories/releases/org/opensaml/opensaml-core/4.3.1/opensaml-core-4.3.1.jar", + "name":"opensaml-core 4.3.1", + "src":"https://build.shibboleth.net/nexus/content/repositories/releases/org/opensaml/opensaml-core/4.3.1/opensaml-core-4.3.1-sources.jar" + }, + { + "jar":"https://build.shibboleth.net/nexus/content/repositories/releases/net/shibboleth/utilities/java-support/8.4.1/java-support-8.4.1.jar", + "name":"java-support 8.4.1", + "src":"https://build.shibboleth.net/nexus/content/repositories/releases/net/shibboleth/utilities/java-support/8.4.1/java-support-8.4.1-sources.jar" + }, + { + "jar":"https://repo.maven.apache.org/maven2/com/google/guava/guava/31.1-jre/guava-31.1-jre.jar", + "name":"guava 31.1-jre", + "src":"https://repo.maven.apache.org/maven2/com/google/guava/guava/31.1-jre/guava-31.1-jre-sources.jar" + }, + { + "jar":"https://repo.maven.apache.org/maven2/com/google/guava/failureaccess/1.0.1/failureaccess-1.0.1.jar", + "name":"failureaccess 1.0.1", + "src":"https://repo.maven.apache.org/maven2/com/google/guava/failureaccess/1.0.1/failureaccess-1.0.1-sources.jar" + }, + { + "jar":"https://repo.maven.apache.org/maven2/com/google/guava/listenablefuture/9999.0-empty-to-avoid-conflict-with-guava/listenablefuture-9999.0-empty-to-avoid-conflict-with-guava.jar", + "name":"listenablefuture 9999.0-empty-to-avoid-conflict-with-guava", + "src":"https://repo.maven.apache.org/maven2/com/google/guava/listenablefuture/9999.0-empty-to-avoid-conflict-with-guava/listenablefuture-9999.0-empty-to-avoid-conflict-with-guava-sources.jar" + }, + { + "jar":"https://repo.maven.apache.org/maven2/com/google/j2objc/j2objc-annotations/1.3/j2objc-annotations-1.3.jar", + "name":"j2objc-annotations 1.3", + "src":"https://repo.maven.apache.org/maven2/com/google/j2objc/j2objc-annotations/1.3/j2objc-annotations-1.3-sources.jar" + }, + { + "jar":"https://repo.maven.apache.org/maven2/io/dropwizard/metrics/metrics-core/4.2.25/metrics-core-4.2.25.jar", + "name":"metrics-core 4.2.25", + "src":"https://repo.maven.apache.org/maven2/io/dropwizard/metrics/metrics-core/4.2.25/metrics-core-4.2.25-sources.jar" + }, + { + "jar":"https://build.shibboleth.net/nexus/content/repositories/releases/org/opensaml/opensaml-saml-impl/4.3.1/opensaml-saml-impl-4.3.1.jar", + "name":"opensaml-saml-impl 4.3.1", + "src":"https://build.shibboleth.net/nexus/content/repositories/releases/org/opensaml/opensaml-saml-impl/4.3.1/opensaml-saml-impl-4.3.1-sources.jar" + }, + { + "jar":"https://build.shibboleth.net/nexus/content/repositories/releases/org/opensaml/opensaml-xmlsec-impl/4.3.1/opensaml-xmlsec-impl-4.3.1.jar", + "name":"opensaml-xmlsec-impl 4.3.1", + "src":"https://build.shibboleth.net/nexus/content/repositories/releases/org/opensaml/opensaml-xmlsec-impl/4.3.1/opensaml-xmlsec-impl-4.3.1-sources.jar" + }, + { + "jar":"https://build.shibboleth.net/nexus/content/repositories/releases/org/opensaml/opensaml-security-impl/4.3.1/opensaml-security-impl-4.3.1.jar", + "name":"opensaml-security-impl 4.3.1", + "src":"https://build.shibboleth.net/nexus/content/repositories/releases/org/opensaml/opensaml-security-impl/4.3.1/opensaml-security-impl-4.3.1-sources.jar" + }, + { + "jar":"https://build.shibboleth.net/nexus/content/repositories/releases/org/opensaml/opensaml-security-api/4.3.1/opensaml-security-api-4.3.1.jar", + "name":"opensaml-security-api 4.3.1", + "src":"https://build.shibboleth.net/nexus/content/repositories/releases/org/opensaml/opensaml-security-api/4.3.1/opensaml-security-api-4.3.1-sources.jar" + }, + { + "jar":"https://build.shibboleth.net/nexus/content/repositories/releases/org/opensaml/opensaml-messaging-api/4.3.1/opensaml-messaging-api-4.3.1.jar", + "name":"opensaml-messaging-api 4.3.1", + "src":"https://build.shibboleth.net/nexus/content/repositories/releases/org/opensaml/opensaml-messaging-api/4.3.1/opensaml-messaging-api-4.3.1-sources.jar" + }, + { + "jar":"https://repo.maven.apache.org/maven2/org/apache/httpcomponents/httpclient/4.5.14/httpclient-4.5.14.jar", + "name":"httpclient 4.5.14", + "src":"https://repo.maven.apache.org/maven2/org/apache/httpcomponents/httpclient/4.5.14/httpclient-4.5.14-sources.jar" + }, + { + "jar":"https://repo.maven.apache.org/maven2/org/apache/httpcomponents/httpcore/4.4.16/httpcore-4.4.16.jar", + "name":"httpcore 4.4.16", + "src":"https://repo.maven.apache.org/maven2/org/apache/httpcomponents/httpcore/4.4.16/httpcore-4.4.16-sources.jar" + }, + { + "jar":"https://repo.maven.apache.org/maven2/org/cryptacular/cryptacular/1.2.5/cryptacular-1.2.5.jar", + "name":"cryptacular 1.2.5", + "src":"https://repo.maven.apache.org/maven2/org/cryptacular/cryptacular/1.2.5/cryptacular-1.2.5-sources.jar" + }, + { + "jar":"https://repo.maven.apache.org/maven2/org/bouncycastle/bcprov-jdk18on/1.72/bcprov-jdk18on-1.72.jar", + "name":"bcprov-jdk18on 1.72", + "src":"https://repo.maven.apache.org/maven2/org/bouncycastle/bcprov-jdk18on/1.72/bcprov-jdk18on-1.72-sources.jar" + }, + { + "jar":"https://repo.maven.apache.org/maven2/org/bouncycastle/bcpkix-jdk18on/1.72/bcpkix-jdk18on-1.72.jar", + "name":"bcpkix-jdk18on 1.72", + "src":"https://repo.maven.apache.org/maven2/org/bouncycastle/bcpkix-jdk18on/1.72/bcpkix-jdk18on-1.72-sources.jar" + }, + { + "jar":"https://repo.maven.apache.org/maven2/org/bouncycastle/bcutil-jdk18on/1.72/bcutil-jdk18on-1.72.jar", + "name":"bcutil-jdk18on 1.72", + "src":"https://repo.maven.apache.org/maven2/org/bouncycastle/bcutil-jdk18on/1.72/bcutil-jdk18on-1.72-sources.jar" + }, + { + "jar":"https://build.shibboleth.net/nexus/content/repositories/releases/org/opensaml/opensaml-xmlsec-api/4.3.1/opensaml-xmlsec-api-4.3.1.jar", + "name":"opensaml-xmlsec-api 4.3.1", + "src":"https://build.shibboleth.net/nexus/content/repositories/releases/org/opensaml/opensaml-xmlsec-api/4.3.1/opensaml-xmlsec-api-4.3.1-sources.jar" + }, + { + "jar":"https://repo.maven.apache.org/maven2/org/apache/santuario/xmlsec/2.3.4/xmlsec-2.3.4.jar", + "name":"xmlsec 2.3.4", + "src":"https://repo.maven.apache.org/maven2/org/apache/santuario/xmlsec/2.3.4/xmlsec-2.3.4-sources.jar" + }, + { + "jar":"https://build.shibboleth.net/nexus/content/repositories/releases/org/opensaml/opensaml-saml-api/4.3.1/opensaml-saml-api-4.3.1.jar", + "name":"opensaml-saml-api 4.3.1", + "src":"https://build.shibboleth.net/nexus/content/repositories/releases/org/opensaml/opensaml-saml-api/4.3.1/opensaml-saml-api-4.3.1-sources.jar" + }, + { + "jar":"https://build.shibboleth.net/nexus/content/repositories/releases/org/opensaml/opensaml-profile-api/4.3.1/opensaml-profile-api-4.3.1.jar", + "name":"opensaml-profile-api 4.3.1", + "src":"https://build.shibboleth.net/nexus/content/repositories/releases/org/opensaml/opensaml-profile-api/4.3.1/opensaml-profile-api-4.3.1-sources.jar" + }, + { + "jar":"https://build.shibboleth.net/nexus/content/repositories/releases/org/opensaml/opensaml-soap-api/4.3.1/opensaml-soap-api-4.3.1.jar", + "name":"opensaml-soap-api 4.3.1", + "src":"https://build.shibboleth.net/nexus/content/repositories/releases/org/opensaml/opensaml-soap-api/4.3.1/opensaml-soap-api-4.3.1-sources.jar" + }, + { + "jar":"https://build.shibboleth.net/nexus/content/repositories/releases/org/opensaml/opensaml-soap-impl/4.3.1/opensaml-soap-impl-4.3.1.jar", + "name":"opensaml-soap-impl 4.3.1", + "src":"https://build.shibboleth.net/nexus/content/repositories/releases/org/opensaml/opensaml-soap-impl/4.3.1/opensaml-soap-impl-4.3.1-sources.jar" + }, + { + "jar":"https://build.shibboleth.net/nexus/content/repositories/releases/org/opensaml/opensaml-storage-api/4.3.1/opensaml-storage-api-4.3.1.jar", + "name":"opensaml-storage-api 4.3.1", + "src":"https://build.shibboleth.net/nexus/content/repositories/releases/org/opensaml/opensaml-storage-api/4.3.1/opensaml-storage-api-4.3.1-sources.jar" + }, + { + "jar":"https://repo.maven.apache.org/maven2/org/apache/velocity/velocity-engine-core/2.3/velocity-engine-core-2.3.jar", + "name":"velocity-engine-core 2.3", + "src":"https://repo.maven.apache.org/maven2/org/apache/velocity/velocity-engine-core/2.3/velocity-engine-core-2.3-sources.jar" + }, + { + "jar":"https://repo.maven.apache.org/maven2/org/apache/commons/commons-lang3/3.11/commons-lang3-3.11.jar", + "name":"commons-lang3 3.11", + "src":"https://repo.maven.apache.org/maven2/org/apache/commons/commons-lang3/3.11/commons-lang3-3.11-sources.jar" + }, + { + "jar":"https://build.shibboleth.net/nexus/content/repositories/releases/org/opensaml/opensaml-profile-impl/4.3.1/opensaml-profile-impl-4.3.1.jar", + "name":"opensaml-profile-impl 4.3.1", + "src":"https://build.shibboleth.net/nexus/content/repositories/releases/org/opensaml/opensaml-profile-impl/4.3.1/opensaml-profile-impl-4.3.1-sources.jar" + }, { "jar":"https://repo.maven.apache.org/maven2/ch/qos/logback/logback-core/1.5.18/logback-core-1.5.18.jar", "name":"logback-core 1.5.18", diff --git a/pluginInterfaceSupported.json b/pluginInterfaceSupported.json index b48b96cd4..6f394e622 100644 --- a/pluginInterfaceSupported.json +++ b/pluginInterfaceSupported.json @@ -1,6 +1,6 @@ { "_comment": "contains a list of plugin interfaces branch names that this core supports", "versions": [ - "8.2" + "8.3" ] } \ No newline at end of file diff --git a/src/main/java/io/supertokens/Main.java b/src/main/java/io/supertokens/Main.java index 95e0b0d9f..7b7c524e4 100644 --- a/src/main/java/io/supertokens/Main.java +++ b/src/main/java/io/supertokens/Main.java @@ -22,6 +22,7 @@ import io.supertokens.cronjobs.Cronjobs; import io.supertokens.cronjobs.bulkimport.ProcessBulkImportUsers; import io.supertokens.cronjobs.cleanupOAuthSessionsAndChallenges.CleanupOAuthSessionsAndChallenges; +import io.supertokens.cronjobs.deleteExpiredSAMLData.DeleteExpiredSAMLData; import io.supertokens.cronjobs.cleanupWebauthnExpiredData.CleanUpWebauthNExpiredDataCron; import io.supertokens.cronjobs.deleteExpiredAccessTokenSigningKeys.DeleteExpiredAccessTokenSigningKeys; import io.supertokens.cronjobs.deleteExpiredDashboardSessions.DeleteExpiredDashboardSessions; @@ -42,6 +43,7 @@ import io.supertokens.pluginInterface.exceptions.InvalidConfigException; import io.supertokens.pluginInterface.exceptions.StorageQueryException; import io.supertokens.pluginInterface.multitenancy.TenantIdentifier; +import io.supertokens.saml.SAMLBootstrap; import io.supertokens.storageLayer.StorageLayer; import io.supertokens.telemetry.TelemetryProvider; import io.supertokens.version.Version; @@ -182,6 +184,9 @@ private void init() throws IOException, StorageQueryException { // init file logging Logging.initFileLogging(this); + // Required for SAML related stuff + SAMLBootstrap.initialize(); + // initialise cron job handler Cronjobs.init(this); @@ -278,6 +283,8 @@ private void init() throws IOException, StorageQueryException { Cronjobs.addCronjob(this, CleanUpWebauthNExpiredDataCron.init(this, uniqueUserPoolIdsTenants)); + Cronjobs.addCronjob(this, DeleteExpiredSAMLData.init(this, uniqueUserPoolIdsTenants)); + // this is to ensure tenantInfos are in sync for the new cron job as well MultitenancyHelper.getInstance(this).refreshCronjobs(); diff --git a/src/main/java/io/supertokens/config/CoreConfig.java b/src/main/java/io/supertokens/config/CoreConfig.java index 1b103d7cf..31b34d534 100644 --- a/src/main/java/io/supertokens/config/CoreConfig.java +++ b/src/main/java/io/supertokens/config/CoreConfig.java @@ -67,7 +67,8 @@ public class CoreConfig { "oauth_provider_public_service_url", "oauth_provider_admin_service_url", "oauth_provider_consent_login_base_url", - "oauth_provider_url_configured_in_oauth_provider" + "oauth_provider_url_configured_in_oauth_provider", + "saml_legacy_acs_url" }; @IgnoreForAnnotationCheck @@ -377,6 +378,19 @@ public class CoreConfig { "the database and block all other CUDs from being used from this instance.") private String supertokens_saas_load_only_cud = null; + @EnvName("SAML_LEGACY_ACS_URL") + @NotConflictingInApp + @JsonProperty + @ConfigDescription("If specified, uses this URL as ACS URL for handling legacy SAML clients") + @HideFromDashboard + private String saml_legacy_acs_url = null; + + @EnvName("SAML_SP_ENTITY_ID") + @JsonProperty + @IgnoreForAnnotationCheck + @ConfigDescription("Service provider's entity ID") + private String saml_sp_entity_id = null; + @IgnoreForAnnotationCheck private Set allowedLogLevels = null; @@ -480,6 +494,10 @@ public String getIpDenyRegex() { return ip_deny_regex; } + public String getLogLevel() { + return log_level; + } + public Set getLogLevels(Main main) { if (allowedLogLevels != null) { return allowedLogLevels; @@ -663,6 +681,14 @@ public String getOtelCollectorConnectionURI() { return otel_collector_connection_uri; } + public String getSAMLLegacyACSURL() { + return saml_legacy_acs_url; + } + + public String getSAMLSPEntityID() { + return saml_sp_entity_id; + } + private String getConfigFileLocation(Main main) { return new File(CLIOptions.get(main).getConfigFilePath() == null ? CLIOptions.get(main).getInstallationPath() + "config.yaml" @@ -931,6 +957,10 @@ void normalizeAndValidate(Main main, boolean includeConfigFilePath) throws Inval } // Normalize + if (saml_sp_entity_id == null) { + saml_sp_entity_id = "https://saml.supertokens.com"; + } + if (ip_allow_regex != null) { ip_allow_regex = ip_allow_regex.trim(); if (ip_allow_regex.equals("")) { diff --git a/src/main/java/io/supertokens/cronjobs/deleteExpiredSAMLData/DeleteExpiredSAMLData.java b/src/main/java/io/supertokens/cronjobs/deleteExpiredSAMLData/DeleteExpiredSAMLData.java new file mode 100644 index 000000000..8039b46e6 --- /dev/null +++ b/src/main/java/io/supertokens/cronjobs/deleteExpiredSAMLData/DeleteExpiredSAMLData.java @@ -0,0 +1,53 @@ +package io.supertokens.cronjobs.deleteExpiredSAMLData; + +import java.util.List; + +import io.supertokens.Main; +import io.supertokens.cronjobs.CronTask; +import io.supertokens.cronjobs.CronTaskTest; +import io.supertokens.pluginInterface.Storage; +import io.supertokens.pluginInterface.StorageUtils; +import io.supertokens.pluginInterface.multitenancy.TenantIdentifier; +import io.supertokens.pluginInterface.saml.SAMLStorage; + +public class DeleteExpiredSAMLData extends CronTask { + public static final String RESOURCE_KEY = "io.supertokens.cronjobs.deleteExpiredSAMLData" + + ".DeleteExpiredSAMLData"; + + private DeleteExpiredSAMLData(Main main, List> tenantsInfo) { + super("DeleteExpiredSAMLData", main, tenantsInfo, false); + } + + public static DeleteExpiredSAMLData init(Main main, List> tenantsInfo) { + return (DeleteExpiredSAMLData) main.getResourceDistributor() + .setResource(new TenantIdentifier(null, null, null), RESOURCE_KEY, + new DeleteExpiredSAMLData(main, tenantsInfo)); + } + + @Override + protected void doTaskPerStorage(Storage storage) throws Exception { + SAMLStorage samlStorage = StorageUtils.getSAMLStorage(storage); + samlStorage.removeExpiredSAMLCodesAndRelayStates(); + } + + @Override + public int getIntervalTimeSeconds() { + if (Main.isTesting) { + Integer interval = CronTaskTest.getInstance(main).getIntervalInSeconds(RESOURCE_KEY); + if (interval != null) { + return interval; + } + } + // Every hour + return 3600; + } + + @Override + public int getInitialWaitTimeSeconds() { + if (!Main.isTesting) { + return getIntervalTimeSeconds(); + } else { + return 0; + } + } +} diff --git a/src/main/java/io/supertokens/emailpassword/EmailPassword.java b/src/main/java/io/supertokens/emailpassword/EmailPassword.java index 72e3470a3..b2709cbbc 100644 --- a/src/main/java/io/supertokens/emailpassword/EmailPassword.java +++ b/src/main/java/io/supertokens/emailpassword/EmailPassword.java @@ -16,6 +16,16 @@ package io.supertokens.emailpassword; +import java.security.NoSuchAlgorithmException; +import java.security.SecureRandom; +import java.security.spec.InvalidKeySpecException; +import java.util.List; + +import javax.annotation.Nonnull; +import javax.annotation.Nullable; + +import org.jetbrains.annotations.TestOnly; + import io.supertokens.Main; import io.supertokens.ResourceDistributor; import io.supertokens.authRecipe.AuthRecipe; @@ -51,14 +61,6 @@ import io.supertokens.storageLayer.StorageLayer; import io.supertokens.utils.Utils; import io.supertokens.webserver.WebserverAPI; -import org.jetbrains.annotations.TestOnly; - -import javax.annotation.Nonnull; -import javax.annotation.Nullable; -import java.security.NoSuchAlgorithmException; -import java.security.SecureRandom; -import java.security.spec.InvalidKeySpecException; -import java.util.List; public class EmailPassword { @@ -216,7 +218,7 @@ public static ImportUserResponse importUserWithPasswordHash(TenantIdentifier ten public static ImportUserResponse createUserWithPasswordHash(TenantIdentifier tenantIdentifier, Storage storage, @Nonnull String email, - @Nonnull String passwordHash, @Nullable long timeJoined) + @Nonnull String passwordHash, long timeJoined) throws StorageQueryException, DuplicateEmailException, TenantOrAppNotFoundException, StorageTransactionLogicException { EmailPasswordSQLStorage epStorage = StorageUtils.getEmailPasswordStorage(storage); diff --git a/src/main/java/io/supertokens/featureflag/EE_FEATURES.java b/src/main/java/io/supertokens/featureflag/EE_FEATURES.java index 8708b883f..3cd66842a 100644 --- a/src/main/java/io/supertokens/featureflag/EE_FEATURES.java +++ b/src/main/java/io/supertokens/featureflag/EE_FEATURES.java @@ -18,7 +18,7 @@ public enum EE_FEATURES { ACCOUNT_LINKING("account_linking"), MULTI_TENANCY("multi_tenancy"), TEST("test"), - DASHBOARD_LOGIN("dashboard_login"), MFA("mfa"), SECURITY("security"), OAUTH("oauth"); + DASHBOARD_LOGIN("dashboard_login"), MFA("mfa"), SECURITY("security"), OAUTH("oauth"), SAML("saml"); private final String name; diff --git a/src/main/java/io/supertokens/inmemorydb/Start.java b/src/main/java/io/supertokens/inmemorydb/Start.java index e79eff244..02e7665db 100644 --- a/src/main/java/io/supertokens/inmemorydb/Start.java +++ b/src/main/java/io/supertokens/inmemorydb/Start.java @@ -71,6 +71,10 @@ import io.supertokens.pluginInterface.passwordless.PasswordlessImportUser; import io.supertokens.pluginInterface.passwordless.exception.*; import io.supertokens.pluginInterface.passwordless.sqlStorage.PasswordlessSQLStorage; +import io.supertokens.pluginInterface.saml.SAMLClaimsInfo; +import io.supertokens.pluginInterface.saml.SAMLClient; +import io.supertokens.pluginInterface.saml.SAMLRelayStateInfo; +import io.supertokens.pluginInterface.saml.SAMLStorage; import io.supertokens.pluginInterface.session.SessionInfo; import io.supertokens.pluginInterface.session.SessionStorage; import io.supertokens.pluginInterface.session.sqlStorage.SessionSQLStorage; @@ -117,7 +121,8 @@ public class Start implements SessionSQLStorage, EmailPasswordSQLStorage, EmailVerificationSQLStorage, ThirdPartySQLStorage, JWTRecipeSQLStorage, PasswordlessSQLStorage, UserMetadataSQLStorage, UserRolesSQLStorage, UserIdMappingStorage, UserIdMappingSQLStorage, MultitenancyStorage, MultitenancySQLStorage, TOTPSQLStorage, ActiveUsersStorage, - ActiveUsersSQLStorage, DashboardSQLStorage, AuthRecipeSQLStorage, OAuthStorage, WebAuthNSQLStorage { + ActiveUsersSQLStorage, DashboardSQLStorage, AuthRecipeSQLStorage, OAuthStorage, WebAuthNSQLStorage, + SAMLStorage { private static final Object appenderLock = new Object(); private static final String ACCESS_TOKEN_SIGNING_KEY_NAME = "access_token_signing_key"; @@ -765,6 +770,8 @@ public void addInfoToNonAuthRecipesBasedOnUserId(TenantIdentifier tenantIdentifi //ignore } else if (className.equals(OAuthStorage.class.getName())) { /* Since OAuth tables store client-related data, we don't add user-specific data here */ + } else if (className.equals(SAMLStorage.class.getName())) { + // no user specific data here } else if (className.equals(ActiveUsersStorage.class.getName())) { try { ActiveUsersQueries.updateUserLastActive(this, tenantIdentifier.toAppIdentifier(), userId); @@ -3896,4 +3903,72 @@ public void deleteExpiredGeneratedOptions() throws StorageQueryException { throw new StorageQueryException(e); } } + + @Override + public SAMLClient createOrUpdateSAMLClient(TenantIdentifier tenantIdentifier, SAMLClient samlClient) + throws StorageQueryException, io.supertokens.pluginInterface.saml.exception.DuplicateEntityIdException { + try { + return SAMLQueries.createOrUpdateSAMLClient(this, tenantIdentifier, samlClient.clientId, samlClient.clientSecret, + samlClient.ssoLoginURL, samlClient.redirectURIs.toString(), samlClient.defaultRedirectURI, + samlClient.idpEntityId, samlClient.idpSigningCertificate, samlClient.allowIDPInitiatedLogin, + samlClient.enableRequestSigning); + } catch (SQLException e) { + String errorMessage = e.getMessage(); + String table = io.supertokens.inmemorydb.config.Config.getConfig(this).getSAMLClientsTable(); + if (isUniqueConstraintError(errorMessage, table, new String[]{"app_id", "tenant_id", "idp_entity_id"})) { + throw new io.supertokens.pluginInterface.saml.exception.DuplicateEntityIdException(); + } + throw new StorageQueryException(e); + } + } + + @Override + public boolean removeSAMLClient(TenantIdentifier tenantIdentifier, String clientId) throws StorageQueryException { + return SAMLQueries.removeSAMLClient(this, tenantIdentifier, clientId); + } + + @Override + public SAMLClient getSAMLClient(TenantIdentifier tenantIdentifier, String clientId) throws StorageQueryException { + return SAMLQueries.getSAMLClient(this, tenantIdentifier, clientId); + } + + @Override + public SAMLClient getSAMLClientByIDPEntityId(TenantIdentifier tenantIdentifier, String idpEntityId) throws StorageQueryException { + return SAMLQueries.getSAMLClientByIDPEntityId(this, tenantIdentifier, idpEntityId); + } + + @Override + public List getSAMLClients(TenantIdentifier tenantIdentifier) throws StorageQueryException { + return SAMLQueries.getSAMLClients(this, tenantIdentifier); + } + + @Override + public void saveRelayStateInfo(TenantIdentifier tenantIdentifier, SAMLRelayStateInfo relayStateInfo) throws StorageQueryException { + SAMLQueries.saveRelayStateInfo(this, tenantIdentifier, relayStateInfo.relayState, relayStateInfo.clientId, relayStateInfo.state, relayStateInfo.redirectURI); + } + + @Override + public SAMLRelayStateInfo getRelayStateInfo(TenantIdentifier tenantIdentifier, String relayState) throws StorageQueryException { + return SAMLQueries.getRelayStateInfo(this, tenantIdentifier, relayState); + } + + @Override + public void saveSAMLClaims(TenantIdentifier tenantIdentifier, String clientId, String code, JsonObject claims) throws StorageQueryException { + SAMLQueries.saveSAMLClaims(this, tenantIdentifier, clientId, code, claims.toString()); + } + + @Override + public SAMLClaimsInfo getSAMLClaimsAndRemoveCode(TenantIdentifier tenantIdentifier, String code) throws StorageQueryException { + return SAMLQueries.getSAMLClaimsAndRemoveCode(this, tenantIdentifier, code); + } + + @Override + public void removeExpiredSAMLCodesAndRelayStates() throws StorageQueryException { + SAMLQueries.removeExpiredSAMLCodesAndRelayStates(this); + } + + @Override + public int countSAMLClients(TenantIdentifier tenantIdentifier) throws StorageQueryException { + return SAMLQueries.countSAMLClients(this, tenantIdentifier); + } } diff --git a/src/main/java/io/supertokens/inmemorydb/config/SQLiteConfig.java b/src/main/java/io/supertokens/inmemorydb/config/SQLiteConfig.java index f029c9c8e..ecc7337f9 100644 --- a/src/main/java/io/supertokens/inmemorydb/config/SQLiteConfig.java +++ b/src/main/java/io/supertokens/inmemorydb/config/SQLiteConfig.java @@ -194,4 +194,10 @@ public String getOAuthLogoutChallengesTable() { public String getWebAuthNCredentialsTable() { return "webauthn_credentials"; } public String getWebAuthNAccountRecoveryTokenTable() { return "webauthn_account_recovery_tokens"; } + + public String getSAMLClientsTable() { return "saml_clients"; } + + public String getSAMLRelayStateTable() { return "saml_relay_state"; } + + public String getSAMLClaimsTable() { return "saml_claims"; } } diff --git a/src/main/java/io/supertokens/inmemorydb/queries/GeneralQueries.java b/src/main/java/io/supertokens/inmemorydb/queries/GeneralQueries.java index 9c0e31970..82704b838 100644 --- a/src/main/java/io/supertokens/inmemorydb/queries/GeneralQueries.java +++ b/src/main/java/io/supertokens/inmemorydb/queries/GeneralQueries.java @@ -516,6 +516,33 @@ public static void createTablesIfNotExists(Start start, Main main) throws SQLExc //index update(start, WebAuthNQueries.getQueryToCreateWebAuthNCredentialsUserIdIndex(start), NO_OP_SETTER); } + + // SAML tables + if (!doesTableExists(start, Config.getConfig(start).getSAMLClientsTable())) { + getInstance(main).addState(CREATING_NEW_TABLE, null); + update(start, SAMLQueries.getQueryToCreateSAMLClientsTable(start), NO_OP_SETTER); + + // indexes + update(start, SAMLQueries.getQueryToCreateSAMLClientsAppIdTenantIdIndex(start), NO_OP_SETTER); + } + + if (!doesTableExists(start, Config.getConfig(start).getSAMLRelayStateTable())) { + getInstance(main).addState(CREATING_NEW_TABLE, null); + update(start, SAMLQueries.getQueryToCreateSAMLRelayStateTable(start), NO_OP_SETTER); + + // indexes + update(start, SAMLQueries.getQueryToCreateSAMLRelayStateAppIdTenantIdIndex(start), NO_OP_SETTER); + update(start, SAMLQueries.getQueryToCreateSAMLRelayStateExpiresAtIndex(start), NO_OP_SETTER); + } + + if (!doesTableExists(start, Config.getConfig(start).getSAMLClaimsTable())) { + getInstance(main).addState(CREATING_NEW_TABLE, null); + update(start, SAMLQueries.getQueryToCreateSAMLClaimsTable(start), NO_OP_SETTER); + + // indexes + update(start, SAMLQueries.getQueryToCreateSAMLClaimsAppIdTenantIdIndex(start), NO_OP_SETTER); + update(start, SAMLQueries.getQueryToCreateSAMLClaimsExpiresAtIndex(start), NO_OP_SETTER); + } } public static void setKeyValue_Transaction(Start start, Connection con, TenantIdentifier tenantIdentifier, diff --git a/src/main/java/io/supertokens/inmemorydb/queries/SAMLQueries.java b/src/main/java/io/supertokens/inmemorydb/queries/SAMLQueries.java new file mode 100644 index 000000000..923462cfa --- /dev/null +++ b/src/main/java/io/supertokens/inmemorydb/queries/SAMLQueries.java @@ -0,0 +1,458 @@ +/* + * Copyright (c) 2025, VRAI Labs and/or its affiliates. All rights reserved. + * + * This software is licensed under the Apache License, Version 2.0 (the + * "License") as published by the Apache Software Foundation. + * + * You may not use this file except in compliance with the License. You may + * obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ + +package io.supertokens.inmemorydb.queries; + +import java.sql.SQLException; +import java.sql.Types; +import java.util.ArrayList; +import java.util.List; + +import com.google.gson.JsonArray; +import com.google.gson.JsonObject; +import com.google.gson.JsonParser; + +import static io.supertokens.inmemorydb.QueryExecutorTemplate.execute; +import static io.supertokens.inmemorydb.QueryExecutorTemplate.update; +import io.supertokens.inmemorydb.Start; +import io.supertokens.inmemorydb.config.Config; +import io.supertokens.pluginInterface.exceptions.StorageQueryException; +import io.supertokens.pluginInterface.multitenancy.TenantIdentifier; +import io.supertokens.pluginInterface.saml.SAMLClaimsInfo; +import io.supertokens.pluginInterface.saml.SAMLClient; +import io.supertokens.pluginInterface.saml.SAMLRelayStateInfo; + +public class SAMLQueries { + public static String getQueryToCreateSAMLClientsTable(Start start) { + String table = Config.getConfig(start).getSAMLClientsTable(); + String tenantsTable = Config.getConfig(start).getTenantsTable(); + // @formatter:off + return "CREATE TABLE IF NOT EXISTS " + table + " (" + + "app_id VARCHAR(64) NOT NULL DEFAULT 'public'," + + "tenant_id VARCHAR(64) NOT NULL DEFAULT 'public'," + + "client_id VARCHAR(255) NOT NULL," + + "client_secret TEXT," + + "sso_login_url TEXT NOT NULL," + + "redirect_uris TEXT NOT NULL," // store JsonArray.toString() + + "default_redirect_uri VARCHAR(1024) NOT NULL," + + "idp_entity_id VARCHAR(1024)," + + "idp_signing_certificate TEXT," + + "allow_idp_initiated_login BOOLEAN NOT NULL DEFAULT FALSE," + + "enable_request_signing BOOLEAN NOT NULL DEFAULT TRUE," + + "created_at BIGINT NOT NULL," + + "updated_at BIGINT NOT NULL," + + "UNIQUE (app_id, tenant_id, idp_entity_id)," + + "PRIMARY KEY (app_id, tenant_id, client_id)," + + "FOREIGN KEY (app_id, tenant_id) REFERENCES " + tenantsTable + " (app_id, tenant_id) ON DELETE CASCADE" + + ");"; + // @formatter:on + } + + public static String getQueryToCreateSAMLClientsAppIdTenantIdIndex(Start start) { + String table = Config.getConfig(start).getSAMLClientsTable(); + return "CREATE INDEX IF NOT EXISTS saml_clients_app_tenant_index ON " + table + "(app_id, tenant_id);"; + } + + public static String getQueryToCreateSAMLRelayStateTable(Start start) { + String table = Config.getConfig(start).getSAMLRelayStateTable(); + String tenantsTable = Config.getConfig(start).getTenantsTable(); + // @formatter:off + return "CREATE TABLE IF NOT EXISTS " + table + " (" + + "app_id VARCHAR(64) NOT NULL DEFAULT 'public'," + + "tenant_id VARCHAR(64) NOT NULL DEFAULT 'public'," + + "relay_state VARCHAR(255) NOT NULL," + + "client_id VARCHAR(255) NOT NULL," + + "state TEXT," // nullable + + "redirect_uri VARCHAR(1024) NOT NULL," + + "created_at BIGINT NOT NULL," + + "expires_at BIGINT NOT NULL," + + "PRIMARY KEY (relay_state)," // relayState must be unique + + "FOREIGN KEY (app_id, tenant_id) REFERENCES " + tenantsTable + " (app_id, tenant_id) ON DELETE CASCADE" + + ");"; + // @formatter:on + } + + public static String getQueryToCreateSAMLRelayStateAppIdTenantIdIndex(Start start) { + String table = Config.getConfig(start).getSAMLRelayStateTable(); + return "CREATE INDEX IF NOT EXISTS saml_relay_state_app_tenant_index ON " + table + "(app_id, tenant_id);"; + } + + public static String getQueryToCreateSAMLRelayStateExpiresAtIndex(Start start) { + String table = Config.getConfig(start).getSAMLRelayStateTable(); + return "CREATE INDEX IF NOT EXISTS saml_relay_state_expires_at_index ON " + table + "(expires_at);"; + } + + public static String getQueryToCreateSAMLClaimsTable(Start start) { + String table = Config.getConfig(start).getSAMLClaimsTable(); + String tenantsTable = Config.getConfig(start).getTenantsTable(); + // @formatter:off + return "CREATE TABLE IF NOT EXISTS " + table + " (" + + "app_id VARCHAR(64) NOT NULL DEFAULT 'public'," + + "tenant_id VARCHAR(64) NOT NULL DEFAULT 'public'," + + "client_id VARCHAR(255) NOT NULL," + + "code VARCHAR(255) NOT NULL," + + "claims TEXT NOT NULL," + + "created_at BIGINT NOT NULL," + + "expires_at BIGINT NOT NULL," + + "PRIMARY KEY (code)," + + "FOREIGN KEY (app_id, tenant_id) REFERENCES " + tenantsTable + " (app_id, tenant_id) ON DELETE CASCADE" + + ");"; + // @formatter:on + } + + public static String getQueryToCreateSAMLClaimsAppIdTenantIdIndex(Start start) { + String table = Config.getConfig(start).getSAMLClaimsTable(); + return "CREATE INDEX IF NOT EXISTS saml_claims_app_tenant_index ON " + table + "(app_id, tenant_id);"; + } + + public static String getQueryToCreateSAMLClaimsExpiresAtIndex(Start start) { + String table = Config.getConfig(start).getSAMLClaimsTable(); + return "CREATE INDEX IF NOT EXISTS saml_claims_expires_at_index ON " + table + "(expires_at);"; + } + + public static void saveRelayStateInfo(Start start, TenantIdentifier tenantIdentifier, + String relayState, String clientId, String state, String redirectURI) + throws StorageQueryException { + String table = Config.getConfig(start).getSAMLRelayStateTable(); + String QUERY = "INSERT INTO " + table + + " (app_id, tenant_id, relay_state, client_id, state, redirect_uri, created_at, expires_at) VALUES (?, ?, ?, ?, ?, ?, ?, ?)"; + + try { + update(start, QUERY, pst -> { + pst.setString(1, tenantIdentifier.getAppId()); + pst.setString(2, tenantIdentifier.getTenantId()); + pst.setString(3, relayState); + pst.setString(4, clientId); + if (state != null) { + pst.setString(5, state); + } else { + pst.setNull(5, java.sql.Types.VARCHAR); + } + pst.setString(6, redirectURI); + pst.setLong(7, System.currentTimeMillis()); + pst.setLong(8, System.currentTimeMillis() + 300000); + }); + } catch (SQLException e) { + throw new StorageQueryException(e); + } + } + + public static SAMLRelayStateInfo getRelayStateInfo(Start start, TenantIdentifier tenantIdentifier, String relayState) + throws StorageQueryException { + String table = Config.getConfig(start).getSAMLRelayStateTable(); + String QUERY = "SELECT client_id, state, redirect_uri, expires_at FROM " + table + + " WHERE app_id = ? AND tenant_id = ? AND relay_state = ? AND expires_at >= ?"; + + try { + return execute(start, QUERY, pst -> { + pst.setString(1, tenantIdentifier.getAppId()); + pst.setString(2, tenantIdentifier.getTenantId()); + pst.setString(3, relayState); + pst.setLong(4, System.currentTimeMillis()); + }, result -> { + if (result.next()) { + String clientId = result.getString("client_id"); + String state = result.getString("state"); // may be null + String redirectURI = result.getString("redirect_uri"); + return new SAMLRelayStateInfo(relayState, clientId, state, redirectURI); + } + return null; + }); + } catch (SQLException e) { + throw new StorageQueryException(e); + } + } + + public static void saveSAMLClaims(Start start, TenantIdentifier tenantIdentifier, String clientId, String code, String claimsJson) + throws StorageQueryException { + String table = Config.getConfig(start).getSAMLClaimsTable(); + String QUERY = "INSERT INTO " + table + + " (app_id, tenant_id, client_id, code, claims, created_at, expires_at) VALUES (?, ?, ?, ?, ?, ?, ?)"; + + try { + update(start, QUERY, pst -> { + pst.setString(1, tenantIdentifier.getAppId()); + pst.setString(2, tenantIdentifier.getTenantId()); + pst.setString(3, clientId); + pst.setString(4, code); + pst.setString(5, claimsJson); + pst.setLong(6, System.currentTimeMillis()); + pst.setLong(7, System.currentTimeMillis() + 300000); + }); + } catch (SQLException e) { + throw new StorageQueryException(e); + } + } + + public static SAMLClaimsInfo getSAMLClaimsAndRemoveCode(Start start, TenantIdentifier tenantIdentifier, String code) + throws StorageQueryException { + String table = Config.getConfig(start).getSAMLClaimsTable(); + String QUERY = "SELECT client_id, claims FROM " + table + " WHERE app_id = ? AND tenant_id = ? AND code = ? AND expires_at >= ?"; + try { + SAMLClaimsInfo claimsInfo = execute(start, QUERY, pst -> { + pst.setString(1, tenantIdentifier.getAppId()); + pst.setString(2, tenantIdentifier.getTenantId()); + pst.setString(3, code); + pst.setLong(4, System.currentTimeMillis()); + }, result -> { + if (result.next()) { + String clientId = result.getString("client_id"); + JsonObject claims = com.google.gson.JsonParser.parseString(result.getString("claims")).getAsJsonObject(); + return new SAMLClaimsInfo(clientId, claims); + } + return null; + }); + + if (claimsInfo != null) { + String DELETE = "DELETE FROM " + table + " WHERE app_id = ? AND tenant_id = ? AND code = ?"; + update(start, DELETE, pst -> { + pst.setString(1, tenantIdentifier.getAppId()); + pst.setString(2, tenantIdentifier.getTenantId()); + pst.setString(3, code); + }); + } + return claimsInfo; + } catch (SQLException e) { + throw new StorageQueryException(e); + } + } + + public static SAMLClient createOrUpdateSAMLClient( + Start start, + TenantIdentifier tenantIdentifier, + String clientId, + String clientSecret, + String ssoLoginURL, + String redirectURIsJson, + String defaultRedirectURI, + String idpEntityId, + String idpSigningCertificate, + boolean allowIDPInitiatedLogin, + boolean enableRequestSigning) + throws StorageQueryException, SQLException { + String table = Config.getConfig(start).getSAMLClientsTable(); + String QUERY = "INSERT INTO " + table + + " (app_id, tenant_id, client_id, client_secret, sso_login_url, redirect_uris, default_redirect_uri, idp_entity_id, idp_signing_certificate, allow_idp_initiated_login, enable_request_signing, created_at, updated_at) " + + "VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) " + + "ON CONFLICT (app_id, tenant_id, client_id) DO UPDATE SET " + + "client_secret = ?, sso_login_url = ?, redirect_uris = ?, default_redirect_uri = ?, idp_entity_id = ?, idp_signing_certificate = ?, allow_idp_initiated_login = ?, enable_request_signing = ?, updated_at = ?"; + long now = System.currentTimeMillis(); + update(start, QUERY, pst -> { + pst.setString(1, tenantIdentifier.getAppId()); + pst.setString(2, tenantIdentifier.getTenantId()); + pst.setString(3, clientId); + if (clientSecret != null) { + pst.setString(4, clientSecret); + } else { + pst.setNull(4, Types.VARCHAR); + } + pst.setString(5, ssoLoginURL); + pst.setString(6, redirectURIsJson); + pst.setString(7, defaultRedirectURI); + if (idpEntityId != null) { + pst.setString(8, idpEntityId); + } else { + pst.setNull(8, java.sql.Types.VARCHAR); + } + if (idpSigningCertificate != null) { + pst.setString(9, idpSigningCertificate); + } else { + pst.setNull(9, Types.VARCHAR); + } + pst.setBoolean(10, allowIDPInitiatedLogin); + pst.setBoolean(11, enableRequestSigning); + pst.setLong(12, now); + pst.setLong(13, now); + + if (clientSecret != null) { + pst.setString(14, clientSecret); + } else { + pst.setNull(14, Types.VARCHAR); + } + pst.setString(15, ssoLoginURL); + pst.setString(16, redirectURIsJson); + pst.setString(17, defaultRedirectURI); + if (idpEntityId != null) { + pst.setString(18, idpEntityId); + } else { + pst.setNull(18, java.sql.Types.VARCHAR); + } + if (idpSigningCertificate != null) { + pst.setString(19, idpSigningCertificate); + } else { + pst.setNull(19, Types.VARCHAR); + } + pst.setBoolean(20, allowIDPInitiatedLogin); + pst.setBoolean(21, enableRequestSigning); + pst.setLong(22, now); + }); + + return getSAMLClient(start, tenantIdentifier, clientId); + } + + public static SAMLClient getSAMLClient(Start start, TenantIdentifier tenantIdentifier, String clientId) + throws StorageQueryException { + String table = Config.getConfig(start).getSAMLClientsTable(); + String QUERY = "SELECT client_id, client_secret, sso_login_url, redirect_uris, default_redirect_uri, idp_entity_id, idp_signing_certificate, allow_idp_initiated_login, enable_request_signing FROM " + table + + " WHERE app_id = ? AND tenant_id = ? AND client_id = ?"; + + try { + return execute(start, QUERY, pst -> { + pst.setString(1, tenantIdentifier.getAppId()); + pst.setString(2, tenantIdentifier.getTenantId()); + pst.setString(3, clientId); + }, result -> { + if (result.next()) { + String fetchedClientId = result.getString("client_id"); + String clientSecret = result.getString("client_secret"); + String ssoLoginURL = result.getString("sso_login_url"); + String redirectUrisJson = result.getString("redirect_uris"); + String defaultRedirectURI = result.getString("default_redirect_uri"); + String idpEntityId = result.getString("idp_entity_id"); + String idpSigningCertificate = result.getString("idp_signing_certificate"); + boolean allowIDPInitiatedLogin = result.getBoolean("allow_idp_initiated_login"); + boolean enableRequestSigning = result.getBoolean("enable_request_signing"); + + JsonArray redirectURIs = JsonParser.parseString(redirectUrisJson).getAsJsonArray(); + return new SAMLClient(fetchedClientId, clientSecret, ssoLoginURL, redirectURIs, defaultRedirectURI, idpEntityId, idpSigningCertificate, allowIDPInitiatedLogin, enableRequestSigning); + } + return null; + }); + } catch (SQLException e) { + throw new StorageQueryException(e); + } + } + + public static SAMLClient getSAMLClientByIDPEntityId(Start start, TenantIdentifier tenantIdentifier, String idpEntityId) throws StorageQueryException { + String table = Config.getConfig(start).getSAMLClientsTable(); + String QUERY = "SELECT client_id, client_secret, sso_login_url, redirect_uris, default_redirect_uri, idp_entity_id, idp_signing_certificate, allow_idp_initiated_login, enable_request_signing FROM " + table + + " WHERE app_id = ? AND tenant_id = ? AND idp_entity_id = ?"; + + try { + return execute(start, QUERY, pst -> { + pst.setString(1, tenantIdentifier.getAppId()); + pst.setString(2, tenantIdentifier.getTenantId()); + pst.setString(3, idpEntityId); + }, result -> { + if (result.next()) { + String fetchedClientId = result.getString("client_id"); + String clientSecret = result.getString("client_secret"); + String ssoLoginURL = result.getString("sso_login_url"); + String redirectUrisJson = result.getString("redirect_uris"); + String defaultRedirectURI = result.getString("default_redirect_uri"); + String fetchedIdpEntityId = result.getString("idp_entity_id"); + String idpSigningCertificate = result.getString("idp_signing_certificate"); + boolean allowIDPInitiatedLogin = result.getBoolean("allow_idp_initiated_login"); + boolean enableRequestSigning = result.getBoolean("enable_request_signing"); + + JsonArray redirectURIs = JsonParser.parseString(redirectUrisJson).getAsJsonArray(); + return new SAMLClient(fetchedClientId, clientSecret, ssoLoginURL, redirectURIs, defaultRedirectURI, fetchedIdpEntityId, idpSigningCertificate, allowIDPInitiatedLogin, enableRequestSigning); + } + return null; + }); + } catch (SQLException e) { + throw new StorageQueryException(e); + } + } + + public static List getSAMLClients(Start start, TenantIdentifier tenantIdentifier) + throws StorageQueryException { + String table = Config.getConfig(start).getSAMLClientsTable(); + String QUERY = "SELECT client_id, client_secret, sso_login_url, redirect_uris, default_redirect_uri, idp_entity_id, idp_signing_certificate, allow_idp_initiated_login, enable_request_signing FROM " + table + + " WHERE app_id = ? AND tenant_id = ?"; + + try { + return execute(start, QUERY, pst -> { + pst.setString(1, tenantIdentifier.getAppId()); + pst.setString(2, tenantIdentifier.getTenantId()); + }, result -> { + List clients = new ArrayList<>(); + while (result.next()) { + String fetchedClientId = result.getString("client_id"); + String clientSecret = result.getString("client_secret"); + String ssoLoginURL = result.getString("sso_login_url"); + String redirectUrisJson = result.getString("redirect_uris"); + String defaultRedirectURI = result.getString("default_redirect_uri"); + String idpEntityId = result.getString("idp_entity_id"); + String idpSigningCertificate = result.getString("idp_signing_certificate"); + boolean allowIDPInitiatedLogin = result.getBoolean("allow_idp_initiated_login"); + boolean enableRequestSigning = result.getBoolean("enable_request_signing"); + + JsonArray redirectURIs = JsonParser.parseString(redirectUrisJson).getAsJsonArray(); + clients.add(new SAMLClient(fetchedClientId, clientSecret, ssoLoginURL, redirectURIs, defaultRedirectURI, idpEntityId, idpSigningCertificate, allowIDPInitiatedLogin, enableRequestSigning)); + } + return clients; + }); + } catch (SQLException e) { + throw new StorageQueryException(e); + } + } + + public static boolean removeSAMLClient(Start start, TenantIdentifier tenantIdentifier, String clientId) + throws StorageQueryException { + String table = Config.getConfig(start).getSAMLClientsTable(); + String QUERY = "DELETE FROM " + table + " WHERE app_id = ? AND tenant_id = ? AND client_id = ?"; + try { + return update(start, QUERY, pst -> { + pst.setString(1, tenantIdentifier.getAppId()); + pst.setString(2, tenantIdentifier.getTenantId()); + pst.setString(3, clientId); + }) > 0; + + } catch (SQLException e) { + throw new StorageQueryException(e); + } + } + + public static void removeExpiredSAMLCodesAndRelayStates(Start start) throws StorageQueryException { + try { + { + String QUERY = "DELETE FROM " + Config.getConfig(start).getSAMLClaimsTable() + " WHERE expires_at <= ?"; + update(start, QUERY, pst -> { + pst.setLong(1, System.currentTimeMillis()); + }); + } + { + String QUERY = "DELETE FROM " + Config.getConfig(start).getSAMLRelayStateTable() + " WHERE expires_at <= ?"; + update(start, QUERY, pst -> { + pst.setLong(1, System.currentTimeMillis()); + }); + } + } catch (SQLException e) { + throw new StorageQueryException(e); + } + } + + public static int countSAMLClients(Start start, TenantIdentifier tenantIdentifier) throws StorageQueryException { + String table = Config.getConfig(start).getSAMLClientsTable(); + String QUERY = "SELECT COUNT(*) as c FROM " + table + + " WHERE app_id = ? AND tenant_id = ?"; + + try { + return execute(start, QUERY, pst -> { + pst.setString(1, tenantIdentifier.getAppId()); + pst.setString(2, tenantIdentifier.getTenantId()); + }, result -> { + if (result.next()) { + return result.getInt("c"); + } + return 0; + }); + } catch (SQLException e) { + throw new StorageQueryException(e); + } + } +} diff --git a/src/main/java/io/supertokens/multitenancy/MultitenancyHelper.java b/src/main/java/io/supertokens/multitenancy/MultitenancyHelper.java index 3dfb9e102..3168fb76d 100644 --- a/src/main/java/io/supertokens/multitenancy/MultitenancyHelper.java +++ b/src/main/java/io/supertokens/multitenancy/MultitenancyHelper.java @@ -33,6 +33,7 @@ import io.supertokens.pluginInterface.exceptions.StorageQueryException; import io.supertokens.pluginInterface.multitenancy.*; import io.supertokens.pluginInterface.multitenancy.exceptions.TenantOrAppNotFoundException; +import io.supertokens.saml.SAMLCertificate; import io.supertokens.pluginInterface.opentelemetry.WithinOtelSpan; import io.supertokens.session.refreshToken.RefreshTokenKey; import io.supertokens.signingkeys.AccessTokenSigningKey; @@ -235,6 +236,7 @@ public void loadSigningKeys(List tenantsThatChanged) } AccessTokenSigningKey.loadForAllTenants(main, apps, tenantsThatChanged); RefreshTokenKey.loadForAllTenants(main, apps, tenantsThatChanged); + SAMLCertificate.loadForAllTenants(main, apps, tenantsThatChanged); JWTSigningKey.loadForAllTenants(main, apps, tenantsThatChanged); SigningKeys.loadForAllTenants(main, apps, tenantsThatChanged); } diff --git a/src/main/java/io/supertokens/output/Logging.java b/src/main/java/io/supertokens/output/Logging.java index 4e0335b35..4c8fddcb1 100644 --- a/src/main/java/io/supertokens/output/Logging.java +++ b/src/main/java/io/supertokens/output/Logging.java @@ -16,6 +16,7 @@ package io.supertokens.output; +import ch.qos.logback.classic.Level; import ch.qos.logback.classic.Logger; import ch.qos.logback.classic.LoggerContext; import ch.qos.logback.classic.spi.ILoggingEvent; @@ -55,6 +56,12 @@ public class Logging extends ResourceDistributor.SingletonResource { public static final String ANSI_WHITE = "\u001B[37m"; private Logging(Main main) { + // Set global logging level + LoggerContext loggerContext = (LoggerContext) LoggerFactory.getILoggerFactory(); + Logger rootLogger = loggerContext.getLogger(Logger.ROOT_LOGGER_NAME); + Level newLevel = Level.toLevel(Config.getBaseConfig(main).getLogLevel(), Level.INFO); // Default to INFO if invalid + rootLogger.setLevel(newLevel); + this.infoLogger = Config.getBaseConfig(main).getInfoLogPath(main).equals("null") ? createLoggerForConsole(main, "io.supertokens.Info", LOG_LEVEL.INFO) : createLoggerForFile(main, Config.getBaseConfig(main).getInfoLogPath(main), diff --git a/src/main/java/io/supertokens/saml/SAML.java b/src/main/java/io/supertokens/saml/SAML.java new file mode 100644 index 000000000..488fc56a0 --- /dev/null +++ b/src/main/java/io/supertokens/saml/SAML.java @@ -0,0 +1,690 @@ +/* + * Copyright (c) 2025, VRAI Labs and/or its affiliates. All rights reserved. + * + * This software is licensed under the Apache License, Version 2.0 (the + * "License") as published by the Apache Software Foundation. + * + * You may not use this file except in compliance with the License. You may + * obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ + +package io.supertokens.saml; + +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.InputStream; +import java.net.URISyntaxException; +import java.net.URLEncoder; +import java.nio.charset.StandardCharsets; +import java.security.cert.CertificateEncodingException; +import java.security.cert.CertificateException; +import java.security.cert.X509Certificate; +import java.time.Instant; +import java.util.List; +import java.util.UUID; +import java.util.zip.Deflater; +import java.util.zip.DeflaterOutputStream; + +import org.opensaml.core.xml.XMLObject; +import org.opensaml.core.xml.XMLObjectBuilderFactory; +import org.opensaml.core.xml.config.XMLObjectProviderRegistrySupport; +import org.opensaml.core.xml.io.UnmarshallingException; +import org.opensaml.core.xml.util.XMLObjectSupport; +import org.opensaml.saml.common.SAMLVersion; +import org.opensaml.saml.common.xml.SAMLConstants; +import org.opensaml.saml.saml2.core.Assertion; +import org.opensaml.saml.saml2.core.Attribute; +import org.opensaml.saml.saml2.core.AttributeStatement; +import org.opensaml.saml.saml2.core.Audience; +import org.opensaml.saml.saml2.core.AudienceRestriction; +import org.opensaml.saml.saml2.core.AuthnContext; +import org.opensaml.saml.saml2.core.AuthnContextClassRef; +import org.opensaml.saml.saml2.core.AuthnRequest; +import org.opensaml.saml.saml2.core.Conditions; +import org.opensaml.saml.saml2.core.Issuer; +import org.opensaml.saml.saml2.core.NameIDPolicy; +import org.opensaml.saml.saml2.core.RequestedAuthnContext; +import org.opensaml.saml.saml2.core.Response; +import org.opensaml.saml.saml2.core.Subject; +import org.opensaml.saml.saml2.metadata.EntityDescriptor; +import org.opensaml.saml.saml2.metadata.IDPSSODescriptor; +import org.opensaml.saml.saml2.metadata.SingleSignOnService; +import org.opensaml.security.credential.Credential; +import org.opensaml.security.credential.CredentialSupport; +import org.opensaml.xmlsec.signature.KeyInfo; +import org.opensaml.xmlsec.signature.Signature; +import org.opensaml.xmlsec.signature.X509Data; +import org.opensaml.xmlsec.signature.impl.KeyInfoBuilder; +import org.opensaml.xmlsec.signature.impl.SignatureBuilder; +import org.opensaml.xmlsec.signature.impl.X509DataBuilder; +import org.opensaml.xmlsec.signature.support.SignatureConstants; +import org.opensaml.xmlsec.signature.support.SignatureException; +import org.opensaml.xmlsec.signature.support.SignatureValidator; +import org.w3c.dom.Element; + +import com.google.gson.JsonArray; +import com.google.gson.JsonElement; +import com.google.gson.JsonObject; + +import io.supertokens.Main; +import io.supertokens.config.Config; +import io.supertokens.config.CoreConfig; +import io.supertokens.featureflag.EE_FEATURES; +import io.supertokens.featureflag.FeatureFlag; +import io.supertokens.featureflag.exceptions.FeatureNotEnabledException; +import io.supertokens.pluginInterface.Storage; +import io.supertokens.pluginInterface.StorageUtils; +import io.supertokens.pluginInterface.exceptions.StorageQueryException; +import io.supertokens.pluginInterface.exceptions.StorageTransactionLogicException; +import io.supertokens.pluginInterface.multitenancy.AppIdentifier; +import io.supertokens.pluginInterface.multitenancy.TenantIdentifier; +import io.supertokens.pluginInterface.multitenancy.exceptions.TenantOrAppNotFoundException; +import io.supertokens.pluginInterface.saml.SAMLClaimsInfo; +import io.supertokens.pluginInterface.saml.SAMLClient; +import io.supertokens.pluginInterface.saml.SAMLRelayStateInfo; +import io.supertokens.pluginInterface.saml.SAMLStorage; +import io.supertokens.pluginInterface.saml.exception.DuplicateEntityIdException; +import io.supertokens.saml.exceptions.IDPInitiatedLoginDisallowedException; +import io.supertokens.saml.exceptions.InvalidClientException; +import io.supertokens.saml.exceptions.InvalidCodeException; +import io.supertokens.saml.exceptions.InvalidRelayStateException; +import io.supertokens.saml.exceptions.MalformedSAMLMetadataXMLException; +import io.supertokens.saml.exceptions.SAMLResponseVerificationFailedException; +import net.shibboleth.utilities.java.support.xml.SerializeSupport; +import net.shibboleth.utilities.java.support.xml.XMLParserException; + +public class SAML { + public static void checkForSAMLFeature(AppIdentifier appIdentifier, Main main) + throws StorageQueryException, TenantOrAppNotFoundException, FeatureNotEnabledException { + EE_FEATURES[] features = FeatureFlag.getInstance(main, appIdentifier).getEnabledFeatures(); + for (EE_FEATURES f : features) { + if (f == EE_FEATURES.SAML) { + return; + } + } + throw new FeatureNotEnabledException( + "SAML feature is not enabled. Please subscribe to a SuperTokens core license key to enable this " + + "feature."); + } + + public static SAMLClient createOrUpdateSAMLClient( + Main main, TenantIdentifier tenantIdentifier, Storage storage, + String clientId, String clientSecret, String defaultRedirectURI, JsonArray redirectURIs, String metadataXML, boolean allowIDPInitiatedLogin, boolean enableRequestSigning) + throws MalformedSAMLMetadataXMLException, StorageQueryException, CertificateException, + FeatureNotEnabledException, TenantOrAppNotFoundException, DuplicateEntityIdException { + checkForSAMLFeature(tenantIdentifier.toAppIdentifier(), main); + + SAMLStorage samlStorage = StorageUtils.getSAMLStorage(storage); + + var metadata = loadIdpMetadata(metadataXML); + String idpSsoUrl = null; + for (var roleDescriptor : metadata.getRoleDescriptors()) { + if (roleDescriptor instanceof IDPSSODescriptor) { + IDPSSODescriptor idpDescriptor = (IDPSSODescriptor) roleDescriptor; + for (SingleSignOnService ssoService : idpDescriptor.getSingleSignOnServices()) { + if (SAMLConstants.SAML2_REDIRECT_BINDING_URI.equals(ssoService.getBinding())) { + idpSsoUrl = ssoService.getLocation(); + } + } + } + } + if (idpSsoUrl == null) { + throw new MalformedSAMLMetadataXMLException(); + } + + String idpSigningCertificate = extractIdpSigningCertificate(metadata); + getCertificateFromString(idpSigningCertificate); // checking validity + + String idpEntityId = metadata.getEntityID(); + SAMLClient client = new SAMLClient(clientId, clientSecret, idpSsoUrl, redirectURIs, defaultRedirectURI, idpEntityId, idpSigningCertificate, allowIDPInitiatedLogin, enableRequestSigning); + return samlStorage.createOrUpdateSAMLClient(tenantIdentifier, client); + } + + public static List getClients(TenantIdentifier tenantIdentifier, Storage storage) throws StorageQueryException { + SAMLStorage samlStorage = StorageUtils.getSAMLStorage(storage); + return samlStorage.getSAMLClients(tenantIdentifier); + } + + public static SAMLClient getClient(TenantIdentifier tenantIdentifier, Storage storage, String clientId) throws StorageQueryException { + SAMLStorage samlStorage = StorageUtils.getSAMLStorage(storage); + return samlStorage.getSAMLClient(tenantIdentifier, clientId); + } + + public static boolean removeSAMLClient(TenantIdentifier tenantIdentifier, Storage storage, String clientId) throws StorageQueryException { + SAMLStorage samlStorage = StorageUtils.getSAMLStorage(storage); + return samlStorage.removeSAMLClient(tenantIdentifier, clientId); + } + + private static String extractIdpSigningCertificate(EntityDescriptor idpMetadata) { + for (var roleDescriptor : idpMetadata.getRoleDescriptors()) { + if (roleDescriptor instanceof IDPSSODescriptor) { + IDPSSODescriptor idpDescriptor = (IDPSSODescriptor) roleDescriptor; + for (org.opensaml.saml.saml2.metadata.KeyDescriptor keyDescriptor : idpDescriptor.getKeyDescriptors()) { + if (keyDescriptor.getUse() == null || + "SIGNING".equals(keyDescriptor.getUse().toString())) { + org.opensaml.xmlsec.signature.KeyInfo keyInfo = keyDescriptor.getKeyInfo(); + if (keyInfo != null) { + for (org.opensaml.xmlsec.signature.X509Data x509Data : keyInfo.getX509Datas()) { + for (org.opensaml.xmlsec.signature.X509Certificate x509Cert : x509Data.getX509Certificates()) { + try { + String certString = x509Cert.getValue(); + if (certString != null && !certString.trim().isEmpty()) { + certString = certString.replaceAll("\\s", ""); + return certString; + } + } catch (Exception e) { + // Continue to next certificate if this one fails + continue; + } + } + } + } + } + } + } + } + return null; + + } + + public static String createRedirectURL(Main main, TenantIdentifier tenantIdentifier, Storage storage, + String clientId, String redirectURI, String state, String acsURL) + throws StorageQueryException, InvalidClientException, TenantOrAppNotFoundException, + CertificateEncodingException, FeatureNotEnabledException { + checkForSAMLFeature(tenantIdentifier.toAppIdentifier(), main); + SAMLStorage samlStorage = StorageUtils.getSAMLStorage(storage); + CoreConfig config = Config.getConfig(tenantIdentifier, main); + + SAMLClient client = samlStorage.getSAMLClient(tenantIdentifier, clientId); + + if (client == null) { + throw new InvalidClientException(); + } + + boolean redirectURIOk = false; + for (JsonElement rUri : client.redirectURIs) { + if (rUri.getAsString().equals(redirectURI)) { + redirectURIOk = true; + break; + } + } + + if (!redirectURIOk) { + throw new InvalidClientException(); + } + + String idpSsoUrl = client.ssoLoginURL; + AuthnRequest request = buildAuthnRequest( + main, + tenantIdentifier.toAppIdentifier(), + idpSsoUrl, + config.getSAMLSPEntityID(), acsURL, + client.enableRequestSigning); + String samlRequest = deflateAndBase64RedirectMessage(request); + String relayState = UUID.randomUUID().toString(); + + samlStorage.saveRelayStateInfo(tenantIdentifier, new SAMLRelayStateInfo(relayState, clientId, state, redirectURI)); + + return idpSsoUrl + "?SAMLRequest=" + samlRequest + "&RelayState=" + URLEncoder.encode(relayState, StandardCharsets.UTF_8); + } + + public static EntityDescriptor loadIdpMetadata(String metadataXML) throws MalformedSAMLMetadataXMLException { + try { + byte[] bytes = metadataXML.getBytes(StandardCharsets.UTF_8); + try (InputStream inputStream = new java.io.ByteArrayInputStream(bytes)) { + XMLObject xmlObject = XMLObjectSupport.unmarshallFromInputStream( + XMLObjectProviderRegistrySupport.getParserPool(), inputStream); + if (xmlObject instanceof EntityDescriptor) { + return (EntityDescriptor) xmlObject; + } else { + throw new RuntimeException("Expected EntityDescriptor but got: " + xmlObject.getClass()); + } + } + } catch (Exception e) { + throw new MalformedSAMLMetadataXMLException(); + } + } + + private static AuthnRequest buildAuthnRequest(Main main, AppIdentifier appIdentifier, String idpSsoUrl, String spEntityId, String acsUrl, boolean enableRequestSigning) + throws TenantOrAppNotFoundException, StorageQueryException, CertificateEncodingException { + XMLObjectBuilderFactory builders = XMLObjectProviderRegistrySupport.getBuilderFactory(); + + AuthnRequest authnRequest = (AuthnRequest) builders + .getBuilder(AuthnRequest.DEFAULT_ELEMENT_NAME) + .buildObject(AuthnRequest.DEFAULT_ELEMENT_NAME); + authnRequest.setID("_" + UUID.randomUUID()); + authnRequest.setIssueInstant(Instant.now()); + authnRequest.setVersion(SAMLVersion.VERSION_20); + authnRequest.setDestination(idpSsoUrl); + authnRequest.setProtocolBinding(SAMLConstants.SAML2_POST_BINDING_URI); + + Issuer issuer = (Issuer) builders.getBuilder(Issuer.DEFAULT_ELEMENT_NAME) + .buildObject(Issuer.DEFAULT_ELEMENT_NAME); + issuer.setValue(spEntityId); + authnRequest.setIssuer(issuer); + + NameIDPolicy nameIDPolicy = (NameIDPolicy) builders.getBuilder(NameIDPolicy.DEFAULT_ELEMENT_NAME) + .buildObject(NameIDPolicy.DEFAULT_ELEMENT_NAME); + nameIDPolicy.setAllowCreate(true); + authnRequest.setNameIDPolicy(nameIDPolicy); + + RequestedAuthnContext rac = (RequestedAuthnContext) builders.getBuilder(RequestedAuthnContext.DEFAULT_ELEMENT_NAME) + .buildObject(RequestedAuthnContext.DEFAULT_ELEMENT_NAME); + rac.setComparison(org.opensaml.saml.saml2.core.AuthnContextComparisonTypeEnumeration.EXACT); + AuthnContextClassRef classRef = (AuthnContextClassRef) builders.getBuilder(AuthnContextClassRef.DEFAULT_ELEMENT_NAME) + .buildObject(AuthnContextClassRef.DEFAULT_ELEMENT_NAME); + classRef.setURI(AuthnContext.PASSWORD_AUTHN_CTX); + rac.getAuthnContextClassRefs().add(classRef); + authnRequest.setRequestedAuthnContext(rac); + + authnRequest.setAssertionConsumerServiceURL(acsUrl); + + if (enableRequestSigning) { + Signature signature = new SignatureBuilder().buildObject(); + signature.setSignatureAlgorithm(SignatureConstants.ALGO_ID_SIGNATURE_RSA_SHA256); + signature.setCanonicalizationAlgorithm(SignatureConstants.ALGO_ID_C14N_EXCL_OMIT_COMMENTS); + + // Create KeyInfo + KeyInfo keyInfo = new KeyInfoBuilder().buildObject(); + X509Data x509Data = new X509DataBuilder().buildObject(); + org.opensaml.xmlsec.signature.X509Certificate x509CertElement = new org.opensaml.xmlsec.signature.impl.X509CertificateBuilder().buildObject(); + + X509Certificate spCertificate = SAMLCertificate.getInstance(appIdentifier, main).getCertificate(); + String certString = java.util.Base64.getEncoder().encodeToString(spCertificate.getEncoded()); + x509CertElement.setValue(certString); + x509Data.getX509Certificates().add(x509CertElement); + keyInfo.getX509Datas().add(x509Data); + signature.setKeyInfo(keyInfo); + + authnRequest.setSignature(signature); + } + + return authnRequest; + } + + private static String deflateAndBase64RedirectMessage(XMLObject xmlObject) { + try { + String xml = toXmlString(xmlObject); + byte[] xmlBytes = xml.getBytes(StandardCharsets.UTF_8); + + // DEFLATE compression as per SAML Redirect binding spec + ByteArrayOutputStream baos = new ByteArrayOutputStream(); + DeflaterOutputStream dos = new DeflaterOutputStream(baos, new Deflater(Deflater.DEFLATED, true)); + dos.write(xmlBytes); + dos.close(); + + byte[] deflated = baos.toByteArray(); + String base64 = java.util.Base64.getEncoder().encodeToString(deflated); + return URLEncoder.encode(base64, StandardCharsets.UTF_8); + } catch (IOException e) { + throw new RuntimeException("Failed to deflate SAML message", e); + } + } + + private static String toXmlString(XMLObject xmlObject) { + try { + Element el = XMLObjectSupport.marshall(xmlObject); + return SerializeSupport.nodeToString(el); + } catch (Exception e) { + throw new RuntimeException("Failed to serialize XML", e); + } + } + + private static Response parseSamlResponse(String samlResponseBase64) + throws IOException, XMLParserException, UnmarshallingException { + byte[] decoded = java.util.Base64.getDecoder().decode(samlResponseBase64); + String xml = new String(decoded, StandardCharsets.UTF_8); + + try (InputStream inputStream = new ByteArrayInputStream(xml.getBytes(StandardCharsets.UTF_8))) { + return (Response) XMLObjectSupport.unmarshallFromInputStream( + XMLObjectProviderRegistrySupport.getParserPool(), inputStream); + } + } + + private static void verifySamlResponseSignature(Response samlResponse, X509Certificate idpCertificate) + throws SignatureException { + Signature responseSignature = samlResponse.getSignature(); + if (responseSignature != null) { + Credential credential = CredentialSupport.getSimpleCredential(idpCertificate, null); + SignatureValidator.validate(responseSignature, credential); + return; + } + + boolean foundSignedAssertion = false; + for (Assertion assertion : samlResponse.getAssertions()) { + Signature assertionSignature = assertion.getSignature(); + if (assertionSignature != null) { + Credential credential = CredentialSupport.getSimpleCredential(idpCertificate, null); + SignatureValidator.validate(assertionSignature, credential); + foundSignedAssertion = true; + } + } + + if (!foundSignedAssertion) { + throw new RuntimeException("Neither SAML Response nor any Assertion is signed"); + } + } + + private static void validateSamlResponseTimestamps(Response samlResponse) throws SAMLResponseVerificationFailedException { + Instant now = Instant.now(); + + // Validate response issue instant (should be recent) + if (samlResponse.getIssueInstant() != null) { + Instant responseTime = samlResponse.getIssueInstant(); + // Allow 5 minutes clock skew + if (responseTime.isAfter(now.plusSeconds(300)) || responseTime.isBefore(now.minusSeconds(300))) { + throw new SAMLResponseVerificationFailedException(); + } + } + + // Validate assertion timestamps + for (Assertion assertion : samlResponse.getAssertions()) { + // Check NotBefore + if (assertion.getConditions() != null && assertion.getConditions().getNotBefore() != null) { + if (now.isBefore(assertion.getConditions().getNotBefore())) { + throw new SAMLResponseVerificationFailedException(); + } + } + + // Check NotOnOrAfter + if (assertion.getConditions() != null && assertion.getConditions().getNotOnOrAfter() != null) { + if (now.isAfter(assertion.getConditions().getNotOnOrAfter())) { + throw new SAMLResponseVerificationFailedException(); + } + } + } + } + + public static String handleCallback(Main main, TenantIdentifier tenantIdentifier, Storage storage, String samlResponse, String relayState) + throws StorageQueryException, XMLParserException, IOException, UnmarshallingException, + CertificateException, InvalidRelayStateException, SAMLResponseVerificationFailedException, + InvalidClientException, IDPInitiatedLoginDisallowedException, TenantOrAppNotFoundException, + FeatureNotEnabledException { + checkForSAMLFeature(tenantIdentifier.toAppIdentifier(), main); + + SAMLStorage samlStorage = StorageUtils.getSAMLStorage(storage); + CoreConfig config = Config.getConfig(tenantIdentifier, main); + + SAMLClient client = null; + Response response = parseSamlResponse(samlResponse); + String state = null; + String redirectURI = null; + + if (relayState != null && !relayState.isEmpty()) { + // sp initiated + var relayStateInfo = samlStorage.getRelayStateInfo(tenantIdentifier, relayState); + if (relayStateInfo == null) { + throw new InvalidRelayStateException(); + } + + String clientId = relayStateInfo.clientId; + client = samlStorage.getSAMLClient(tenantIdentifier, clientId); + state = relayStateInfo.state; + redirectURI = relayStateInfo.redirectURI; + } else { + // idp initiated + String idpEntityId = response.getIssuer().getValue(); + client = samlStorage.getSAMLClientByIDPEntityId(tenantIdentifier, idpEntityId); + redirectURI = client.defaultRedirectURI; + + if (!client.allowIDPInitiatedLogin) { + throw new IDPInitiatedLoginDisallowedException(); + } + } + + if (client == null) { + throw new InvalidClientException(); + } + + // SAML verification + X509Certificate idpSigningCertificate = getCertificateFromString(client.idpSigningCertificate); + try { + verifySamlResponseSignature(response, idpSigningCertificate); + } catch (SignatureException e) { + throw new SAMLResponseVerificationFailedException(); + } + validateSamlResponseTimestamps(response); + validateSamlResponseAudience(response, config.getSAMLSPEntityID()); + + var claims = extractAllClaims(response); + + String code = UUID.randomUUID().toString(); + samlStorage.saveSAMLClaims(tenantIdentifier, client.clientId, code, claims); + + try { + java.net.URI uri = new java.net.URI(redirectURI); + String query = uri.getQuery(); + StringBuilder newQuery = new StringBuilder(); + if (query != null && !query.isEmpty()) { + newQuery.append(query).append("&"); + } + newQuery.append("code=").append(java.net.URLEncoder.encode(code, java.nio.charset.StandardCharsets.UTF_8)); + if (state != null) { + newQuery.append("&state=").append(java.net.URLEncoder.encode(state, java.nio.charset.StandardCharsets.UTF_8)); + } + java.net.URI newUri = new java.net.URI( + uri.getScheme(), + uri.getAuthority(), + uri.getPath(), + newQuery.toString(), + uri.getFragment() + ); + return newUri.toString(); + } catch (URISyntaxException e) { + throw new IllegalStateException("should never happen", e); + } + } + + private static void validateSamlResponseAudience(Response samlResponse, String expectedAudience) + throws SAMLResponseVerificationFailedException { + boolean audienceMatched = false; + + for (Assertion assertion : samlResponse.getAssertions()) { + Conditions conditions = assertion.getConditions(); + if (conditions == null) { + continue; + } + java.util.List restrictions = conditions.getAudienceRestrictions(); + if (restrictions == null || restrictions.isEmpty()) { + continue; + } + for (AudienceRestriction ar : restrictions) { + java.util.List audiences = ar.getAudiences(); + if (audiences == null || audiences.isEmpty()) { + continue; + } + for (Audience aud : audiences) { + if (expectedAudience.equals(aud.getURI())) { + audienceMatched = true; + break; + } + } + if (audienceMatched) { + break; + } + } + if (audienceMatched) { + break; + } + } + + if (!audienceMatched) { + throw new SAMLResponseVerificationFailedException(); + } + } + + private static JsonObject extractAllClaims(Response samlResponse) { + JsonObject claims = new JsonObject(); + + for (Assertion assertion : samlResponse.getAssertions()) { + // Extract NameID as a claim + Subject subject = assertion.getSubject(); + if (subject != null && subject.getNameID() != null) { + String nameId = subject.getNameID().getValue(); + String nameIdFormat = subject.getNameID().getFormat(); + JsonArray nameIdArr = new JsonArray(); + nameIdArr.add(nameId); + claims.add("NameID", nameIdArr); + if (nameIdFormat != null) { + JsonArray nameIdFormatArr = new JsonArray(); + nameIdFormatArr.add(nameIdFormat); + claims.add("NameIDFormat", nameIdFormatArr); + } + } + + // Extract all attributes from AttributeStatements + for (AttributeStatement attributeStatement : assertion.getAttributeStatements()) { + for (Attribute attribute : attributeStatement.getAttributes()) { + String attributeName = attribute.getName(); + JsonArray attributeValues = new JsonArray(); + + for (XMLObject attributeValue : attribute.getAttributeValues()) { + if (attributeValue instanceof org.opensaml.saml.saml2.core.AttributeValue) { + org.opensaml.saml.saml2.core.AttributeValue attrValue = + (org.opensaml.saml.saml2.core.AttributeValue) attributeValue; + + if (attrValue.getDOM() != null) { + String value = attrValue.getDOM().getTextContent(); + if (value != null && !value.trim().isEmpty()) { + attributeValues.add(value.trim()); + } + } else if (attrValue.getTextContent() != null) { + String value = attrValue.getTextContent(); + if (!value.trim().isEmpty()) { + attributeValues.add(value.trim()); + } + } + } + } + + if (!attributeValues.isEmpty()) { + claims.add(attributeName, attributeValues); + } + } + } + } + + return claims; + } + + private static X509Certificate getCertificateFromString(String certString) throws CertificateException { + byte[] certBytes = java.util.Base64.getDecoder().decode(certString); + java.security.cert.CertificateFactory certFactory = + java.security.cert.CertificateFactory.getInstance("X.509"); + return (X509Certificate) certFactory.generateCertificate( + new ByteArrayInputStream(certBytes)); + } + + public static JsonObject getUserInfo(Main main, TenantIdentifier tenantIdentifier, Storage storage, String accessToken, String clientId, boolean isLegacy) + throws TenantOrAppNotFoundException, StorageQueryException, + StorageTransactionLogicException, InvalidCodeException, FeatureNotEnabledException { + + checkForSAMLFeature(tenantIdentifier.toAppIdentifier(), main); + + SAMLStorage samlStorage = StorageUtils.getSAMLStorage(storage); + + SAMLClaimsInfo claimsInfo = samlStorage.getSAMLClaimsAndRemoveCode(tenantIdentifier, accessToken); + if (claimsInfo == null) { + throw new InvalidCodeException(); + } + + if (clientId != null) { + if (!clientId.equals(claimsInfo.clientId)) { + throw new InvalidCodeException(); + } + } + + String sub = null; + String email = null; + + JsonObject claims = claimsInfo.claims; + + if (claims.has("NameID")) { + sub = claims.getAsJsonArray("NameID").get(0).getAsString(); + } else if (claims.has("http://schemas.microsoft.com/identity/claims/objectidentifier")) { + sub = claims.getAsJsonArray("http://schemas.microsoft.com/identity/claims/objectidentifier") + .get(0).getAsString(); + } else if (claims.has("http://schemas.xmlsoap.org/ws/2005/05/identity/claims/name")) { + sub = claims.getAsJsonArray("http://schemas.xmlsoap.org/ws/2005/05/identity/claims/name") + .get(0).getAsString(); + } + + if (claims.has("http://schemas.xmlsoap.org/ws/2005/05/identity/claims/emailaddress")) { + email = claims.getAsJsonArray("http://schemas.xmlsoap.org/ws/2005/05/identity/claims/emailaddress") + .get(0).getAsString(); + } else if (claims.has("NameID")) { + String nameIdValue = claims.getAsJsonArray("NameID").get(0).getAsString(); + if (nameIdValue.contains("@")) { + email = nameIdValue; + } + } + + JsonObject payload = new JsonObject(); + payload.add("claims", claims); + payload.addProperty(isLegacy ? "id" : "sub", sub); + payload.addProperty("email", email); + payload.addProperty("aud", claimsInfo.clientId); + + return payload; + } + + public static String getLegacyACSURL(Main main, AppIdentifier appIdentifier) throws TenantOrAppNotFoundException { + CoreConfig config = Config.getConfig(appIdentifier.getAsPublicTenantIdentifier(), main); + return config.getSAMLLegacyACSURL(); + } + + public static String getMetadataXML(Main main, TenantIdentifier tenantIdentifier) + throws TenantOrAppNotFoundException, StorageQueryException, FeatureNotEnabledException { + checkForSAMLFeature(tenantIdentifier.toAppIdentifier(), main); + + SAMLCertificate certificate = SAMLCertificate.getInstance(tenantIdentifier.toAppIdentifier(), main); + CoreConfig config = Config.getConfig(tenantIdentifier, main); + String spEntityId = config.getSAMLSPEntityID(); + try { + X509Certificate cert = certificate.getCertificate(); + String certString = java.util.Base64.getEncoder().encodeToString(cert.getEncoded()); + + String validUntil = java.time.format.DateTimeFormatter.ISO_INSTANT.format(cert.getNotAfter().toInstant()); + + StringBuilder sb = new StringBuilder(); + sb.append(""); + sb.append(""); + sb.append(""); + sb.append(""); + sb.append(""); + sb.append(""); + sb.append("").append(certString).append(""); + sb.append(""); + sb.append(""); + sb.append(""); + sb.append("urn:oasis:names:tc:SAML:1.1:nameid-format:emailAddress"); + sb.append(""); + sb.append(""); + + return sb.toString(); + } catch (Exception e) { + throw new IllegalStateException("Failed to generate SP metadata", e); + } + } + + private static String escapeXml(String input) { + if (input == null) { + return ""; + } + String result = input; + result = result.replace("&", "&"); + result = result.replace("\"", """); + result = result.replace("<", "<"); + result = result.replace(">", ">"); + return result; + } +} diff --git a/src/main/java/io/supertokens/saml/SAMLBootstrap.java b/src/main/java/io/supertokens/saml/SAMLBootstrap.java new file mode 100644 index 000000000..57455dcf8 --- /dev/null +++ b/src/main/java/io/supertokens/saml/SAMLBootstrap.java @@ -0,0 +1,50 @@ +/* + * Copyright (c) 2025, VRAI Labs and/or its affiliates. All rights reserved. + * + * This software is licensed under the Apache License, Version 2.0 (the + * "License") as published by the Apache Software Foundation. + * + * You may not use this file except in compliance with the License. You may + * obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ + +package io.supertokens.saml; + +import java.util.HashMap; +import java.util.Map; + +import org.opensaml.core.config.InitializationException; +import org.opensaml.core.config.InitializationService; +import org.slf4j.LoggerFactory; + +import ch.qos.logback.classic.Level; +import ch.qos.logback.classic.Logger; + +public class SAMLBootstrap { + private static volatile boolean initialized = false; + + private SAMLBootstrap() {} + + public static void initialize() { + if (initialized) { + return; + } + synchronized (SAMLBootstrap.class) { + if (initialized) { + return; + } + try { + InitializationService.initialize(); + initialized = true; + } catch (InitializationException e) { + throw new RuntimeException("Failed to initialize OpenSAML", e); + } + } + } +} diff --git a/src/main/java/io/supertokens/saml/SAMLCertificate.java b/src/main/java/io/supertokens/saml/SAMLCertificate.java new file mode 100644 index 000000000..7e34d2c54 --- /dev/null +++ b/src/main/java/io/supertokens/saml/SAMLCertificate.java @@ -0,0 +1,310 @@ +/* + * Copyright (c) 2025, VRAI Labs and/or its affiliates. All rights reserved. + * + * This software is licensed under the Apache License, Version 2.0 (the + * "License") as published by the Apache Software Foundation. + * + * You may not use this file except in compliance with the License. You may + * obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ + +package io.supertokens.saml; + +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.security.KeyFactory; +import java.security.KeyPair; +import java.security.KeyPairGenerator; +import java.security.NoSuchAlgorithmException; +import java.security.PrivateKey; +import java.security.PublicKey; +import java.security.cert.CertificateException; +import java.security.cert.CertificateFactory; +import java.security.cert.X509Certificate; +import java.security.spec.PKCS8EncodedKeySpec; +import java.security.spec.X509EncodedKeySpec; +import java.util.Base64; +import java.util.Date; +import java.util.List; +import java.util.Map; + +import org.bouncycastle.asn1.x500.X500Name; +import org.bouncycastle.asn1.x509.BasicConstraints; +import org.bouncycastle.asn1.x509.Extension; +import org.bouncycastle.asn1.x509.KeyUsage; +import org.bouncycastle.cert.CertIOException; +import org.bouncycastle.cert.X509CertificateHolder; +import org.bouncycastle.cert.jcajce.JcaX509CertificateConverter; +import org.bouncycastle.cert.jcajce.JcaX509v3CertificateBuilder; +import org.bouncycastle.operator.ContentSigner; +import org.bouncycastle.operator.OperatorCreationException; +import org.bouncycastle.operator.jcajce.JcaContentSignerBuilder; + +import io.supertokens.Main; +import io.supertokens.ResourceDistributor; +import io.supertokens.output.Logging; +import io.supertokens.pluginInterface.KeyValueInfo; +import io.supertokens.pluginInterface.exceptions.StorageQueryException; +import io.supertokens.pluginInterface.exceptions.StorageTransactionLogicException; +import io.supertokens.pluginInterface.multitenancy.AppIdentifier; +import io.supertokens.pluginInterface.multitenancy.TenantIdentifier; +import io.supertokens.pluginInterface.multitenancy.exceptions.TenantOrAppNotFoundException; +import io.supertokens.pluginInterface.sqlStorage.SQLStorage; +import io.supertokens.storageLayer.StorageLayer; + +public class SAMLCertificate extends ResourceDistributor.SingletonResource { + private static final String RESOURCE_KEY = "io.supertokens.saml.SAMLCertificate"; + private final Main main; + private final AppIdentifier appIdentifier; + + private static final String SAML_KEY_PAIR_NAME = "saml_key_pair"; + private static final String SAML_CERTIFICATE_NAME = "saml_certificate"; + + private KeyPair spKeyPair = null; + private X509Certificate spCertificate = null; + + private SAMLCertificate(AppIdentifier appIdentifier, Main main) throws + TenantOrAppNotFoundException { + this.main = main; + this.appIdentifier = appIdentifier; +// try { +// this.getCertificate(); +// } catch (StorageQueryException e) { +// Logging.error(main, appIdentifier.getAsPublicTenantIdentifier(), "Error while fetching SAML key and certificate", +// false, e); +// } + } + + public synchronized X509Certificate getCertificate() + throws StorageQueryException, TenantOrAppNotFoundException { + if (this.spCertificate == null || this.spCertificate.getNotAfter().before(new Date())) { + maybeGenerateNewCertificateAndUpdateInDb(); + } + + return this.spCertificate; + } + + private void maybeGenerateNewCertificateAndUpdateInDb() throws TenantOrAppNotFoundException { + SQLStorage storage = (SQLStorage) StorageLayer.getStorage( + this.appIdentifier.getAsPublicTenantIdentifier(), main); + + try { + storage.startTransaction(con -> { + KeyValueInfo keyPairInfo = storage.getKeyValue_Transaction(this.appIdentifier.getAsPublicTenantIdentifier(), con, SAML_KEY_PAIR_NAME); + KeyValueInfo certInfo = storage.getKeyValue_Transaction(this.appIdentifier.getAsPublicTenantIdentifier(), con, SAML_CERTIFICATE_NAME); + + if (keyPairInfo == null || certInfo == null) { + try { + generateNewCertificate(); + } catch (Exception e) { + throw new RuntimeException(e); + } + + try { + String keyPairStr = serializeKeyPair(spKeyPair); + String certStr = serializeCertificate(spCertificate); + keyPairInfo = new KeyValueInfo(keyPairStr); + certInfo = new KeyValueInfo(certStr); + } catch (IOException e) { + throw new RuntimeException("Failed to serialize key pair or certificate", e); + } + storage.setKeyValue_Transaction(this.appIdentifier.getAsPublicTenantIdentifier(), con, SAML_KEY_PAIR_NAME, keyPairInfo); + storage.setKeyValue_Transaction(this.appIdentifier.getAsPublicTenantIdentifier(), con, SAML_CERTIFICATE_NAME, certInfo); + } + + String keyPairStr = keyPairInfo.value; + String certStr = certInfo.value; + + try { + this.spKeyPair = deserializeKeyPair(keyPairStr); + this.spCertificate = deserializeCertificate(certStr); + } catch (Exception e) { + throw new RuntimeException("Failed to deserialize key pair or certificate", e); + } + + // If the certificate has expired, generate and persist a new one + if (this.spCertificate.getNotAfter().before(new Date())) { + try { + generateNewCertificate(); + String newKeyPairStr = serializeKeyPair(spKeyPair); + String newCertStr = serializeCertificate(spCertificate); + KeyValueInfo newKeyPairInfo = new KeyValueInfo(newKeyPairStr); + KeyValueInfo newCertInfo = new KeyValueInfo(newCertStr); + storage.setKeyValue_Transaction(this.appIdentifier.getAsPublicTenantIdentifier(), con, SAML_KEY_PAIR_NAME, newKeyPairInfo); + storage.setKeyValue_Transaction(this.appIdentifier.getAsPublicTenantIdentifier(), con, SAML_CERTIFICATE_NAME, newCertInfo); + } catch (Exception e) { + throw new RuntimeException("Failed to regenerate expired certificate", e); + } + } + + return null; + }); + } catch (StorageTransactionLogicException | StorageQueryException e) { + throw new RuntimeException("Storage error", e); + } + } + + void generateNewCertificate() + throws NoSuchAlgorithmException, CertificateException, OperatorCreationException, CertIOException { + KeyPairGenerator keyGen = KeyPairGenerator.getInstance("RSA"); + keyGen.initialize(4096); + spKeyPair = keyGen.generateKeyPair(); + spCertificate = generateSelfSignedCertificate(); + } + + private X509Certificate generateSelfSignedCertificate() + throws CertIOException, OperatorCreationException, CertificateException { + // Create a production-ready self-signed X.509 certificate using BouncyCastle + Date notBefore = new Date(); + Date notAfter = new Date(notBefore.getTime() + 10 * 365L * 24 * 60 * 60 * 1000); // 10 year validity + + // Create the certificate subject and issuer (same for self-signed) + X500Name subject = new X500Name("CN=SAML-SP, O=SuperTokens, C=US"); + X500Name issuer = subject; // Self-signed + + // Generate a random serial number + java.math.BigInteger serialNumber = java.math.BigInteger.valueOf(System.currentTimeMillis()); + + // Create the certificate builder + JcaX509v3CertificateBuilder certBuilder = new JcaX509v3CertificateBuilder( + issuer, + serialNumber, + notBefore, + notAfter, + subject, + spKeyPair.getPublic() + ); + + // Add extensions for proper SAML usage + // Key Usage: digitalSignature and keyEncipherment + KeyUsage keyUsage = new KeyUsage(KeyUsage.digitalSignature | KeyUsage.keyEncipherment); + certBuilder.addExtension(Extension.keyUsage, true, keyUsage); + + // Basic Constraints: not a CA + BasicConstraints basicConstraints = new BasicConstraints(false); + certBuilder.addExtension(Extension.basicConstraints, true, basicConstraints); + + // Create the content signer + ContentSigner contentSigner = new JcaContentSignerBuilder("SHA256withRSA") + .build(spKeyPair.getPrivate()); + + // Build the certificate + X509CertificateHolder certHolder = certBuilder.build(contentSigner); + + // Convert to standard X509Certificate + JcaX509CertificateConverter converter = new JcaX509CertificateConverter(); + return converter.getCertificate(certHolder); + } + + /** + * Serializes a KeyPair to a Base64 encoded string format + */ + private String serializeKeyPair(KeyPair keyPair) throws IOException { + ByteArrayOutputStream baos = new ByteArrayOutputStream(); + + // Write private key + byte[] privateKeyBytes = keyPair.getPrivate().getEncoded(); + baos.write(Base64.getEncoder().encode(privateKeyBytes)); + baos.write('\n'); + + // Write public key + byte[] publicKeyBytes = keyPair.getPublic().getEncoded(); + baos.write(Base64.getEncoder().encode(publicKeyBytes)); + + return baos.toString(); + } + + /** + * Deserializes a KeyPair from a Base64 encoded string format + */ + private KeyPair deserializeKeyPair(String keyPairStr) throws Exception { + String[] parts = keyPairStr.split("\n"); + if (parts.length != 2) { + throw new IllegalArgumentException("Invalid key pair string format"); + } + + // Decode private key + byte[] privateKeyBytes = Base64.getDecoder().decode(parts[0]); + PKCS8EncodedKeySpec privateKeySpec = new PKCS8EncodedKeySpec(privateKeyBytes); + KeyFactory keyFactory = KeyFactory.getInstance("RSA"); + PrivateKey privateKey = keyFactory.generatePrivate(privateKeySpec); + + // Decode public key + byte[] publicKeyBytes = Base64.getDecoder().decode(parts[1]); + X509EncodedKeySpec publicKeySpec = new X509EncodedKeySpec(publicKeyBytes); + PublicKey publicKey = keyFactory.generatePublic(publicKeySpec); + + return new KeyPair(publicKey, privateKey); + } + + /** + * Serializes an X509Certificate to a Base64 encoded string format + */ + private String serializeCertificate(X509Certificate certificate) throws IOException { + try { + byte[] certBytes = certificate.getEncoded(); + return Base64.getEncoder().encodeToString(certBytes); + } catch (CertificateException e) { + throw new IOException("Failed to encode certificate", e); + } + } + + /** + * Deserializes an X509Certificate from a Base64 encoded string format + */ + private X509Certificate deserializeCertificate(String certStr) throws Exception { + try { + byte[] certBytes = Base64.getDecoder().decode(certStr); + CertificateFactory certFactory = CertificateFactory.getInstance("X.509"); + ByteArrayInputStream bais = new ByteArrayInputStream(certBytes); + return (X509Certificate) certFactory.generateCertificate(bais); + } catch (CertificateException e) { + throw new Exception("Failed to decode certificate", e); + } + } + + public static SAMLCertificate getInstance(AppIdentifier appIdentifier, Main main) + throws TenantOrAppNotFoundException { + return (SAMLCertificate) main.getResourceDistributor() + .getResource(appIdentifier, RESOURCE_KEY); + } + + public static void loadForAllTenants(Main main, List apps, + List tenantsThatChanged) { + try { + main.getResourceDistributor().withResourceDistributorLock(() -> { + Map existingResources = + main.getResourceDistributor() + .getAllResourcesWithResourceKey(RESOURCE_KEY); + main.getResourceDistributor().clearAllResourcesWithResourceKey(RESOURCE_KEY); + for (AppIdentifier app : apps) { + ResourceDistributor.SingletonResource resource = existingResources.get( + new ResourceDistributor.KeyClass(app, RESOURCE_KEY)); + if (resource != null && !tenantsThatChanged.contains(app.getAsPublicTenantIdentifier())) { + main.getResourceDistributor().setResource(app, RESOURCE_KEY, + resource); + } else { + try { + main.getResourceDistributor() + .setResource(app, RESOURCE_KEY, + new SAMLCertificate(app, main)); + } catch (TenantOrAppNotFoundException e) { + Logging.error(main, app.getAsPublicTenantIdentifier(), e.getMessage(), false); + // continue loading other resources + } + } + } + return null; + }); + } catch (ResourceDistributor.FuncException e) { + throw new IllegalStateException("should never happen", e); + } + } +} diff --git a/src/main/java/io/supertokens/saml/exceptions/IDPInitiatedLoginDisallowedException.java b/src/main/java/io/supertokens/saml/exceptions/IDPInitiatedLoginDisallowedException.java new file mode 100644 index 000000000..92bfdb185 --- /dev/null +++ b/src/main/java/io/supertokens/saml/exceptions/IDPInitiatedLoginDisallowedException.java @@ -0,0 +1,4 @@ +package io.supertokens.saml.exceptions; + +public class IDPInitiatedLoginDisallowedException extends Exception { +} diff --git a/src/main/java/io/supertokens/saml/exceptions/InvalidClientException.java b/src/main/java/io/supertokens/saml/exceptions/InvalidClientException.java new file mode 100644 index 000000000..99987c7d2 --- /dev/null +++ b/src/main/java/io/supertokens/saml/exceptions/InvalidClientException.java @@ -0,0 +1,20 @@ +/* + * Copyright (c) 2025, VRAI Labs and/or its affiliates. All rights reserved. + * + * This software is licensed under the Apache License, Version 2.0 (the + * "License") as published by the Apache Software Foundation. + * + * You may not use this file except in compliance with the License. You may + * obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ + +package io.supertokens.saml.exceptions; + +public class InvalidClientException extends Exception { +} diff --git a/src/main/java/io/supertokens/saml/exceptions/InvalidCodeException.java b/src/main/java/io/supertokens/saml/exceptions/InvalidCodeException.java new file mode 100644 index 000000000..d6c4a07c4 --- /dev/null +++ b/src/main/java/io/supertokens/saml/exceptions/InvalidCodeException.java @@ -0,0 +1,5 @@ +package io.supertokens.saml.exceptions; + +public class InvalidCodeException extends Exception { + +} diff --git a/src/main/java/io/supertokens/saml/exceptions/InvalidRelayStateException.java b/src/main/java/io/supertokens/saml/exceptions/InvalidRelayStateException.java new file mode 100644 index 000000000..bb7d58000 --- /dev/null +++ b/src/main/java/io/supertokens/saml/exceptions/InvalidRelayStateException.java @@ -0,0 +1,5 @@ +package io.supertokens.saml.exceptions; + +public class InvalidRelayStateException extends Exception { + +} diff --git a/src/main/java/io/supertokens/saml/exceptions/MalformedSAMLMetadataXMLException.java b/src/main/java/io/supertokens/saml/exceptions/MalformedSAMLMetadataXMLException.java new file mode 100644 index 000000000..febbde270 --- /dev/null +++ b/src/main/java/io/supertokens/saml/exceptions/MalformedSAMLMetadataXMLException.java @@ -0,0 +1,20 @@ +/* + * Copyright (c) 2025, VRAI Labs and/or its affiliates. All rights reserved. + * + * This software is licensed under the Apache License, Version 2.0 (the + * "License") as published by the Apache Software Foundation. + * + * You may not use this file except in compliance with the License. You may + * obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ + +package io.supertokens.saml.exceptions; + +public class MalformedSAMLMetadataXMLException extends Exception { +} diff --git a/src/main/java/io/supertokens/saml/exceptions/SAMLResponseVerificationFailedException.java b/src/main/java/io/supertokens/saml/exceptions/SAMLResponseVerificationFailedException.java new file mode 100644 index 000000000..f9c7c58c5 --- /dev/null +++ b/src/main/java/io/supertokens/saml/exceptions/SAMLResponseVerificationFailedException.java @@ -0,0 +1,5 @@ +package io.supertokens.saml.exceptions; + +public class SAMLResponseVerificationFailedException extends Exception { + +} diff --git a/src/main/java/io/supertokens/signingkeys/AccessTokenSigningKey.java b/src/main/java/io/supertokens/signingkeys/AccessTokenSigningKey.java index 1cb9fd262..cd255135d 100644 --- a/src/main/java/io/supertokens/signingkeys/AccessTokenSigningKey.java +++ b/src/main/java/io/supertokens/signingkeys/AccessTokenSigningKey.java @@ -69,7 +69,7 @@ private AccessTokenSigningKey(AppIdentifier appIdentifier, Main main) this.appIdentifier = appIdentifier; try { this.transferLegacyKeyToNewTable(); - this.getOrCreateAndGetSigningKeys(); +// this.getOrCreateAndGetSigningKeys(); } catch (StorageQueryException | StorageTransactionLogicException e) { Logging.error(main, appIdentifier.getAsPublicTenantIdentifier(), "Error while fetching access token signing key", false, e); diff --git a/src/main/java/io/supertokens/signingkeys/JWTSigningKey.java b/src/main/java/io/supertokens/signingkeys/JWTSigningKey.java index db9c0770b..23012150d 100644 --- a/src/main/java/io/supertokens/signingkeys/JWTSigningKey.java +++ b/src/main/java/io/supertokens/signingkeys/JWTSigningKey.java @@ -82,7 +82,7 @@ public static void loadForAllTenants(Main main, List apps, List { public static final SemVer v5_1 = new SemVer("5.1"); public static final SemVer v5_2 = new SemVer("5.2"); public static final SemVer v5_3 = new SemVer("5.3"); + public static final SemVer v5_4 = new SemVer("5.4"); final private String version; diff --git a/src/main/java/io/supertokens/webserver/Webserver.java b/src/main/java/io/supertokens/webserver/Webserver.java index 3dcfe650b..233c595c3 100644 --- a/src/main/java/io/supertokens/webserver/Webserver.java +++ b/src/main/java/io/supertokens/webserver/Webserver.java @@ -16,6 +16,19 @@ package io.supertokens.webserver; +import java.io.File; +import java.util.UUID; +import java.util.logging.Handler; +import java.util.logging.Logger; + +import org.apache.catalina.LifecycleException; +import org.apache.catalina.LifecycleState; +import org.apache.catalina.connector.Connector; +import org.apache.catalina.core.StandardContext; +import org.apache.catalina.startup.Tomcat; +import org.apache.tomcat.util.http.fileupload.FileUtils; +import org.jetbrains.annotations.TestOnly; + import io.supertokens.Main; import io.supertokens.OperatingSystem; import io.supertokens.ResourceDistributor; @@ -25,50 +38,150 @@ import io.supertokens.output.Logging; import io.supertokens.pluginInterface.multitenancy.TenantIdentifier; import io.supertokens.pluginInterface.multitenancy.exceptions.TenantOrAppNotFoundException; -import io.supertokens.webserver.api.accountlinking.*; +import io.supertokens.webserver.api.accountlinking.CanCreatePrimaryUserAPI; +import io.supertokens.webserver.api.accountlinking.CanLinkAccountsAPI; +import io.supertokens.webserver.api.accountlinking.CreatePrimaryUserAPI; +import io.supertokens.webserver.api.accountlinking.LinkAccountsAPI; +import io.supertokens.webserver.api.accountlinking.UnlinkAccountAPI; import io.supertokens.webserver.api.bulkimport.BulkImportAPI; import io.supertokens.webserver.api.bulkimport.CountBulkImportUsersAPI; import io.supertokens.webserver.api.bulkimport.DeleteBulkImportUserAPI; import io.supertokens.webserver.api.bulkimport.ImportUserAPI; -import io.supertokens.webserver.api.core.*; -import io.supertokens.webserver.api.dashboard.*; +import io.supertokens.webserver.api.core.ActiveUsersCountAPI; +import io.supertokens.webserver.api.core.ApiVersionAPI; +import io.supertokens.webserver.api.core.ConfigAPI; +import io.supertokens.webserver.api.core.DeleteUserAPI; +import io.supertokens.webserver.api.core.EEFeatureFlagAPI; +import io.supertokens.webserver.api.core.GetUserByIdAPI; +import io.supertokens.webserver.api.core.HelloAPI; +import io.supertokens.webserver.api.core.JWKSPublicAPI; +import io.supertokens.webserver.api.core.LicenseKeyAPI; +import io.supertokens.webserver.api.core.ListUsersByAccountInfoAPI; +import io.supertokens.webserver.api.core.NotFoundOrHelloAPI; +import io.supertokens.webserver.api.core.RequestStatsAPI; +import io.supertokens.webserver.api.core.SearchTagsAPI; +import io.supertokens.webserver.api.core.TelemetryAPI; +import io.supertokens.webserver.api.core.UsersAPI; +import io.supertokens.webserver.api.core.UsersCountAPI; +import io.supertokens.webserver.api.dashboard.DashboardSignInAPI; +import io.supertokens.webserver.api.dashboard.DashboardUserAPI; +import io.supertokens.webserver.api.dashboard.GetDashboardSessionsForUserAPI; +import io.supertokens.webserver.api.dashboard.GetDashboardUsersAPI; +import io.supertokens.webserver.api.dashboard.GetTenantCoreConfigForDashboardAPI; +import io.supertokens.webserver.api.dashboard.RevokeSessionAPI; +import io.supertokens.webserver.api.dashboard.VerifyDashboardUserSessionAPI; +import io.supertokens.webserver.api.emailpassword.ConsumeResetPasswordAPI; +import io.supertokens.webserver.api.emailpassword.GeneratePasswordResetTokenAPI; +import io.supertokens.webserver.api.emailpassword.ImportUserWithPasswordHashAPI; +import io.supertokens.webserver.api.emailpassword.ResetPasswordAPI; import io.supertokens.webserver.api.emailpassword.SignInAPI; +import io.supertokens.webserver.api.emailpassword.SignUpAPI; import io.supertokens.webserver.api.emailpassword.UserAPI; -import io.supertokens.webserver.api.emailpassword.*; import io.supertokens.webserver.api.emailverification.GenerateEmailVerificationTokenAPI; import io.supertokens.webserver.api.emailverification.RevokeAllTokensForUserAPI; import io.supertokens.webserver.api.emailverification.UnverifyEmailAPI; import io.supertokens.webserver.api.emailverification.VerifyEmailAPI; import io.supertokens.webserver.api.jwt.JWKSAPI; import io.supertokens.webserver.api.jwt.JWTSigningAPI; -import io.supertokens.webserver.api.multitenancy.*; +import io.supertokens.webserver.api.multitenancy.AssociateUserToTenantAPI; +import io.supertokens.webserver.api.multitenancy.CreateOrUpdateAppAPI; +import io.supertokens.webserver.api.multitenancy.CreateOrUpdateAppV2API; +import io.supertokens.webserver.api.multitenancy.CreateOrUpdateConnectionUriDomainAPI; +import io.supertokens.webserver.api.multitenancy.CreateOrUpdateConnectionUriDomainV2API; +import io.supertokens.webserver.api.multitenancy.CreateOrUpdateTenantOrGetTenantAPI; +import io.supertokens.webserver.api.multitenancy.CreateOrUpdateTenantOrGetTenantV2API; +import io.supertokens.webserver.api.multitenancy.DisassociateUserFromTenant; +import io.supertokens.webserver.api.multitenancy.ListAppsAPI; +import io.supertokens.webserver.api.multitenancy.ListAppsV2API; +import io.supertokens.webserver.api.multitenancy.ListConnectionUriDomainsAPI; +import io.supertokens.webserver.api.multitenancy.ListConnectionUriDomainsV2API; +import io.supertokens.webserver.api.multitenancy.ListTenantsAPI; +import io.supertokens.webserver.api.multitenancy.ListTenantsV2API; +import io.supertokens.webserver.api.multitenancy.RemoveAppAPI; +import io.supertokens.webserver.api.multitenancy.RemoveConnectionUriDomainAPI; +import io.supertokens.webserver.api.multitenancy.RemoveTenantAPI; import io.supertokens.webserver.api.multitenancy.thirdparty.CreateOrUpdateThirdPartyConfigAPI; import io.supertokens.webserver.api.multitenancy.thirdparty.RemoveThirdPartyConfigAPI; -import io.supertokens.webserver.api.oauth.*; -import io.supertokens.webserver.api.passwordless.*; -import io.supertokens.webserver.api.session.*; +import io.supertokens.webserver.api.oauth.CreateUpdateOrGetOAuthClientAPI; +import io.supertokens.webserver.api.oauth.OAuthAcceptAuthConsentRequestAPI; +import io.supertokens.webserver.api.oauth.OAuthAcceptAuthLoginRequestAPI; +import io.supertokens.webserver.api.oauth.OAuthAcceptAuthLogoutRequestAPI; +import io.supertokens.webserver.api.oauth.OAuthAuthAPI; +import io.supertokens.webserver.api.oauth.OAuthClientListAPI; +import io.supertokens.webserver.api.oauth.OAuthGetAuthConsentRequestAPI; +import io.supertokens.webserver.api.oauth.OAuthGetAuthLoginRequestAPI; +import io.supertokens.webserver.api.oauth.OAuthLogoutAPI; +import io.supertokens.webserver.api.oauth.OAuthRejectAuthConsentRequestAPI; +import io.supertokens.webserver.api.oauth.OAuthRejectAuthLoginRequestAPI; +import io.supertokens.webserver.api.oauth.OAuthRejectAuthLogoutRequestAPI; +import io.supertokens.webserver.api.oauth.OAuthTokenAPI; +import io.supertokens.webserver.api.oauth.OAuthTokenIntrospectAPI; +import io.supertokens.webserver.api.oauth.RemoveOAuthClientAPI; +import io.supertokens.webserver.api.oauth.RevokeOAuthSessionAPI; +import io.supertokens.webserver.api.oauth.RevokeOAuthTokenAPI; +import io.supertokens.webserver.api.oauth.RevokeOAuthTokensAPI; +import io.supertokens.webserver.api.passwordless.CheckCodeAPI; +import io.supertokens.webserver.api.passwordless.ConsumeCodeAPI; +import io.supertokens.webserver.api.passwordless.CreateCodeAPI; +import io.supertokens.webserver.api.passwordless.DeleteCodeAPI; +import io.supertokens.webserver.api.passwordless.DeleteCodesAPI; +import io.supertokens.webserver.api.passwordless.GetCodesAPI; +import io.supertokens.webserver.api.saml.CreateOrUpdateSamlClientAPI; +import io.supertokens.webserver.api.saml.CreateSamlLoginRedirectAPI; +import io.supertokens.webserver.api.saml.GetUserInfoAPI; +import io.supertokens.webserver.api.saml.HandleSamlCallbackAPI; +import io.supertokens.webserver.api.saml.LegacyAuthorizeAPI; +import io.supertokens.webserver.api.saml.LegacyCallbackAPI; +import io.supertokens.webserver.api.saml.LegacyTokenAPI; +import io.supertokens.webserver.api.saml.LegacyUserinfoAPI; +import io.supertokens.webserver.api.saml.ListSamlClientsAPI; +import io.supertokens.webserver.api.saml.RemoveSamlClientAPI; +import io.supertokens.webserver.api.saml.SPMetadataAPI; +import io.supertokens.webserver.api.session.HandshakeAPI; +import io.supertokens.webserver.api.session.JWTDataAPI; +import io.supertokens.webserver.api.session.RefreshSessionAPI; +import io.supertokens.webserver.api.session.SessionAPI; +import io.supertokens.webserver.api.session.SessionDataAPI; +import io.supertokens.webserver.api.session.SessionRegenerateAPI; +import io.supertokens.webserver.api.session.SessionRemoveAPI; +import io.supertokens.webserver.api.session.SessionUserAPI; +import io.supertokens.webserver.api.session.VerifySessionAPI; import io.supertokens.webserver.api.thirdparty.GetUsersByEmailAPI; import io.supertokens.webserver.api.thirdparty.SignInUpAPI; -import io.supertokens.webserver.api.totp.*; +import io.supertokens.webserver.api.totp.CreateOrUpdateTotpDeviceAPI; +import io.supertokens.webserver.api.totp.GetTotpDevicesAPI; +import io.supertokens.webserver.api.totp.ImportTotpDeviceAPI; +import io.supertokens.webserver.api.totp.RemoveTotpDeviceAPI; +import io.supertokens.webserver.api.totp.VerifyTotpAPI; +import io.supertokens.webserver.api.totp.VerifyTotpDeviceAPI; import io.supertokens.webserver.api.useridmapping.RemoveUserIdMappingAPI; import io.supertokens.webserver.api.useridmapping.UpdateExternalUserIdInfoAPI; import io.supertokens.webserver.api.useridmapping.UserIdMappingAPI; import io.supertokens.webserver.api.usermetadata.RemoveUserMetadataAPI; import io.supertokens.webserver.api.usermetadata.UserMetadataAPI; -import io.supertokens.webserver.api.userroles.*; -import io.supertokens.webserver.api.webauthn.*; -import org.apache.catalina.LifecycleException; -import org.apache.catalina.LifecycleState; -import org.apache.catalina.connector.Connector; -import org.apache.catalina.core.StandardContext; -import org.apache.catalina.startup.Tomcat; -import org.apache.tomcat.util.http.fileupload.FileUtils; -import org.jetbrains.annotations.TestOnly; - -import java.io.File; -import java.util.UUID; -import java.util.logging.Handler; -import java.util.logging.Logger; +import io.supertokens.webserver.api.userroles.AddUserRoleAPI; +import io.supertokens.webserver.api.userroles.CreateRoleAPI; +import io.supertokens.webserver.api.userroles.GetPermissionsForRoleAPI; +import io.supertokens.webserver.api.userroles.GetRolesAPI; +import io.supertokens.webserver.api.userroles.GetRolesForPermissionAPI; +import io.supertokens.webserver.api.userroles.GetRolesForUserAPI; +import io.supertokens.webserver.api.userroles.GetUsersForRoleAPI; +import io.supertokens.webserver.api.userroles.RemovePermissionsForRoleAPI; +import io.supertokens.webserver.api.userroles.RemoveRoleAPI; +import io.supertokens.webserver.api.userroles.RemoveUserRoleAPI; +import io.supertokens.webserver.api.webauthn.ConsumeRecoverAccountTokenAPI; +import io.supertokens.webserver.api.webauthn.CredentialsRegisterAPI; +import io.supertokens.webserver.api.webauthn.GenerateRecoverAccountTokenAPI; +import io.supertokens.webserver.api.webauthn.GetCredentialAPI; +import io.supertokens.webserver.api.webauthn.GetGeneratedOptionsAPI; +import io.supertokens.webserver.api.webauthn.GetUserFromRecoverAccountTokenAPI; +import io.supertokens.webserver.api.webauthn.ListCredentialsAPI; +import io.supertokens.webserver.api.webauthn.OptionsRegisterAPI; +import io.supertokens.webserver.api.webauthn.RemoveCredentialAPI; +import io.supertokens.webserver.api.webauthn.RemoveOptionsAPI; +import io.supertokens.webserver.api.webauthn.SignInOptionsAPI; +import io.supertokens.webserver.api.webauthn.SignUpWithCredentialRegisterAPI; +import io.supertokens.webserver.api.webauthn.UpdateUserEmailAPI; public class Webserver extends ResourceDistributor.SingletonResource { @@ -312,6 +425,19 @@ private void setupRoutes() { addAPI(new RevokeOAuthSessionAPI(main)); addAPI(new OAuthLogoutAPI(main)); + // saml + addAPI(new CreateOrUpdateSamlClientAPI(main)); + addAPI(new ListSamlClientsAPI(main)); + addAPI(new RemoveSamlClientAPI(main)); + addAPI(new CreateSamlLoginRedirectAPI(main)); + addAPI(new HandleSamlCallbackAPI(main)); + addAPI(new GetUserInfoAPI(main)); + addAPI(new LegacyAuthorizeAPI(main)); + addAPI(new LegacyCallbackAPI(main)); + addAPI(new LegacyTokenAPI(main)); + addAPI(new LegacyUserinfoAPI(main)); + addAPI(new SPMetadataAPI(main)); + //webauthn addAPI(new OptionsRegisterAPI(main)); addAPI(new SignInOptionsAPI(main)); diff --git a/src/main/java/io/supertokens/webserver/WebserverAPI.java b/src/main/java/io/supertokens/webserver/WebserverAPI.java index 95959a2f6..58b0f1863 100644 --- a/src/main/java/io/supertokens/webserver/WebserverAPI.java +++ b/src/main/java/io/supertokens/webserver/WebserverAPI.java @@ -82,10 +82,11 @@ public abstract class WebserverAPI extends HttpServlet { supportedVersions.add(SemVer.v5_1); supportedVersions.add(SemVer.v5_2); supportedVersions.add(SemVer.v5_3); + supportedVersions.add(SemVer.v5_4); } public static SemVer getLatestCDIVersion() { - return SemVer.v5_3; + return SemVer.v5_4; } public SemVer getLatestCDIVersionForRequest(HttpServletRequest req) @@ -122,6 +123,12 @@ protected void sendTextResponse(int statusCode, String message, HttpServletRespo resp.getWriter().println(message); } + protected void sendXMLResponse(int statusCode, String message, HttpServletResponse resp) throws IOException { + resp.setStatus(statusCode); + resp.setHeader("Content-Type", "text/xml; charset=UTF-8"); + resp.getWriter().println(message); + } + protected void sendJsonResponse(int statusCode, JsonElement json, HttpServletResponse resp) throws IOException { resp.setStatus(statusCode); resp.setHeader("Content-Type", "application/json; charset=UTF-8"); diff --git a/src/main/java/io/supertokens/webserver/api/core/ListUsersByAccountInfoAPI.java b/src/main/java/io/supertokens/webserver/api/core/ListUsersByAccountInfoAPI.java index 614fa9d18..deef164db 100644 --- a/src/main/java/io/supertokens/webserver/api/core/ListUsersByAccountInfoAPI.java +++ b/src/main/java/io/supertokens/webserver/api/core/ListUsersByAccountInfoAPI.java @@ -16,12 +16,13 @@ package io.supertokens.webserver.api.core; -import com.google.gson.Gson; +import java.io.IOException; + import com.google.gson.JsonArray; import com.google.gson.JsonObject; + import io.supertokens.Main; import io.supertokens.authRecipe.AuthRecipe; -import io.supertokens.output.Logging; import io.supertokens.pluginInterface.Storage; import io.supertokens.pluginInterface.authRecipe.AuthRecipeUserInfo; import io.supertokens.pluginInterface.exceptions.StorageQueryException; @@ -36,8 +37,6 @@ import jakarta.servlet.http.HttpServletRequest; import jakarta.servlet.http.HttpServletResponse; -import java.io.IOException; - public class ListUsersByAccountInfoAPI extends WebserverAPI { public ListUsersByAccountInfoAPI(Main main) { @@ -92,10 +91,6 @@ protected void doGet(HttpServletRequest req, HttpServletResponse resp) throws IO } result.add("users", usersJson); - - Logging.info(main, tenantIdentifier, "ListUsersByAccountInfoAPI - credentialId is " + webauthnCredentialId, true); - Logging.info(main, tenantIdentifier, new Gson().toJson(result), true); - super.sendJsonResponse(200, result, resp); } catch (StorageQueryException | TenantOrAppNotFoundException e) { diff --git a/src/main/java/io/supertokens/webserver/api/saml/CreateOrUpdateSamlClientAPI.java b/src/main/java/io/supertokens/webserver/api/saml/CreateOrUpdateSamlClientAPI.java new file mode 100644 index 000000000..7ee4d016a --- /dev/null +++ b/src/main/java/io/supertokens/webserver/api/saml/CreateOrUpdateSamlClientAPI.java @@ -0,0 +1,105 @@ +/* + * Copyright (c) 2025, VRAI Labs and/or its affiliates. All rights reserved. + * + * This software is licensed under the Apache License, Version 2.0 (the + * "License") as published by the Apache Software Foundation. + * + * You may not use this file except in compliance with the License. You may + * obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ + +package io.supertokens.webserver.api.saml; + +import java.io.IOException; +import java.security.cert.CertificateException; + +import com.google.gson.JsonArray; +import com.google.gson.JsonObject; + +import io.supertokens.Main; +import io.supertokens.featureflag.exceptions.FeatureNotEnabledException; +import io.supertokens.pluginInterface.exceptions.StorageQueryException; +import io.supertokens.pluginInterface.multitenancy.exceptions.TenantOrAppNotFoundException; +import io.supertokens.pluginInterface.saml.SAMLClient; +import io.supertokens.pluginInterface.saml.exception.DuplicateEntityIdException; +import io.supertokens.saml.SAML; +import io.supertokens.saml.exceptions.MalformedSAMLMetadataXMLException; +import io.supertokens.utils.Utils; +import io.supertokens.webserver.InputParser; +import io.supertokens.webserver.WebserverAPI; +import jakarta.servlet.ServletException; +import jakarta.servlet.http.HttpServletRequest; +import jakarta.servlet.http.HttpServletResponse; + +public class CreateOrUpdateSamlClientAPI extends WebserverAPI { + + public CreateOrUpdateSamlClientAPI(Main main) { + super(main, "saml"); + } + + @Override + public String getPath() { + return "/recipe/saml/clients"; + } + + @Override + protected void doPut(HttpServletRequest req, HttpServletResponse resp) throws IOException, ServletException { + JsonObject input = InputParser.parseJsonObjectOrThrowError(req); + + String clientId = InputParser.parseStringOrThrowError(input, "clientId", true); + String clientSecret = InputParser.parseStringOrThrowError(input, "clientSecret", true); + String defaultRedirectURI = InputParser.parseStringOrThrowError(input, "defaultRedirectURI", false); + JsonArray redirectURIs = InputParser.parseArrayOrThrowError(input, "redirectURIs", false); + + if (redirectURIs.size() == 0) { + throw new ServletException(new BadRequestException("redirectURIs is required in the input")); + } + + String metadataXML = InputParser.parseStringOrThrowError(input, "metadataXML", false); + + Boolean allowIDPInitiatedLogin = InputParser.parseBooleanOrThrowError(input, "allowIDPInitiatedLogin", true); + Boolean enableRequestSigning = InputParser.parseBooleanOrThrowError(input, "enableRequestSigning", true); + + if (allowIDPInitiatedLogin == null) { + allowIDPInitiatedLogin = false; + } + + if (enableRequestSigning == null) { + enableRequestSigning = true; + } + + try { + byte[] decodedBytes = java.util.Base64.getDecoder().decode(metadataXML); + metadataXML = new String(decodedBytes, java.nio.charset.StandardCharsets.UTF_8); + } catch (IllegalArgumentException e) { + throw new ServletException(new BadRequestException("metadataXML does not have a valid SAML metadata")); + } + + if (clientId == null) { + clientId = "st_saml_" + Utils.getUUID(); + } + + try { + SAMLClient client = SAML.createOrUpdateSAMLClient( + main, getTenantIdentifier(req), getTenantStorage(req), clientId, clientSecret, defaultRedirectURI, + redirectURIs, metadataXML, allowIDPInitiatedLogin, enableRequestSigning); + JsonObject res = client.toJson(); + res.addProperty("status", "OK"); + this.sendJsonResponse(200, res, resp); + } catch (DuplicateEntityIdException e) { + JsonObject res = new JsonObject(); + res.addProperty("status", "DUPLICATE_IDP_ENTITY_ERROR"); + this.sendJsonResponse(200, res, resp); + } catch (MalformedSAMLMetadataXMLException | CertificateException e) { + throw new ServletException(new BadRequestException("metadataXML does not have a valid SAML metadata")); + } catch (TenantOrAppNotFoundException | StorageQueryException | FeatureNotEnabledException e) { + throw new ServletException(e); + } + } +} diff --git a/src/main/java/io/supertokens/webserver/api/saml/CreateSamlLoginRedirectAPI.java b/src/main/java/io/supertokens/webserver/api/saml/CreateSamlLoginRedirectAPI.java new file mode 100644 index 000000000..8a04228f4 --- /dev/null +++ b/src/main/java/io/supertokens/webserver/api/saml/CreateSamlLoginRedirectAPI.java @@ -0,0 +1,76 @@ +/* + * Copyright (c) 2025, VRAI Labs and/or its affiliates. All rights reserved. + * + * This software is licensed under the Apache License, Version 2.0 (the + * "License") as published by the Apache Software Foundation. + * + * You may not use this file except in compliance with the License. You may + * obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ + +package io.supertokens.webserver.api.saml; + +import com.google.gson.JsonObject; +import io.supertokens.Main; +import io.supertokens.featureflag.exceptions.FeatureNotEnabledException; +import io.supertokens.pluginInterface.exceptions.StorageQueryException; +import io.supertokens.pluginInterface.multitenancy.exceptions.TenantOrAppNotFoundException; +import io.supertokens.saml.SAML; +import io.supertokens.saml.exceptions.InvalidClientException; +import io.supertokens.webserver.InputParser; +import io.supertokens.webserver.WebserverAPI; +import jakarta.servlet.ServletException; +import jakarta.servlet.http.HttpServletRequest; +import jakarta.servlet.http.HttpServletResponse; + +import java.io.IOException; +import java.security.cert.CertificateEncodingException; + +public class CreateSamlLoginRedirectAPI extends WebserverAPI { + public CreateSamlLoginRedirectAPI(Main main) { + super(main, "saml"); + } + + @Override + public String getPath() { + return "/recipe/saml/login"; + } + + @Override + protected void doPost(HttpServletRequest req, HttpServletResponse resp) throws IOException, ServletException { + JsonObject input = InputParser.parseJsonObjectOrThrowError(req); + String clientId = InputParser.parseStringOrThrowError(input, "clientId", false); + String redirectURI = InputParser.parseStringOrThrowError(input, "redirectURI", false); + String state = InputParser.parseStringOrThrowError(input, "state", true); + String acsURL = InputParser.parseStringOrThrowError(input, "acsURL", false); + + try { + String ssoRedirectURI = SAML.createRedirectURL( + main, + getTenantIdentifier(req), + getTenantStorage(req), + clientId, + redirectURI, + state, + acsURL); + + JsonObject res = new JsonObject(); + res.addProperty("status", "OK"); + res.addProperty("ssoRedirectURI", ssoRedirectURI); + super.sendJsonResponse(200, res, resp); + } catch (InvalidClientException e) { + JsonObject res = new JsonObject(); + res.addProperty("status", "INVALID_CLIENT_ERROR"); + super.sendJsonResponse(200, res, resp); + } catch (TenantOrAppNotFoundException | StorageQueryException | CertificateEncodingException | + FeatureNotEnabledException e) { + throw new ServletException(e); + } + } +} diff --git a/src/main/java/io/supertokens/webserver/api/saml/GetUserInfoAPI.java b/src/main/java/io/supertokens/webserver/api/saml/GetUserInfoAPI.java new file mode 100644 index 000000000..571ae216b --- /dev/null +++ b/src/main/java/io/supertokens/webserver/api/saml/GetUserInfoAPI.java @@ -0,0 +1,75 @@ +/* + * Copyright (c) 2025, VRAI Labs and/or its affiliates. All rights reserved. + * + * This software is licensed under the Apache License, Version 2.0 (the + * "License") as published by the Apache Software Foundation. + * + * You may not use this file except in compliance with the License. You may + * obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ + +package io.supertokens.webserver.api.saml; + +import java.io.IOException; + +import com.google.gson.JsonObject; + +import io.supertokens.Main; +import io.supertokens.featureflag.exceptions.FeatureNotEnabledException; +import io.supertokens.pluginInterface.exceptions.StorageQueryException; +import io.supertokens.pluginInterface.exceptions.StorageTransactionLogicException; +import io.supertokens.pluginInterface.multitenancy.exceptions.TenantOrAppNotFoundException; +import io.supertokens.saml.SAML; +import io.supertokens.saml.exceptions.InvalidCodeException; +import io.supertokens.webserver.InputParser; +import io.supertokens.webserver.WebserverAPI; +import jakarta.servlet.ServletException; +import jakarta.servlet.http.HttpServletRequest; +import jakarta.servlet.http.HttpServletResponse; + +public class GetUserInfoAPI extends WebserverAPI { + + public GetUserInfoAPI(Main main) { + super(main, "saml"); + } + + @Override + public String getPath() { + return "/recipe/saml/user"; + } + + @Override + protected void doPost(HttpServletRequest req, HttpServletResponse resp) throws IOException, ServletException { + JsonObject input = InputParser.parseJsonObjectOrThrowError(req); + String accessToken = InputParser.parseStringOrThrowError(input, "accessToken", false); + String clientId = InputParser.parseStringOrThrowError(input, "clientId", false); + + try { + JsonObject userInfo = SAML.getUserInfo( + main, + getTenantIdentifier(req), + getTenantStorage(req), + accessToken, + clientId, + false + ); + userInfo.addProperty("status", "OK"); + + super.sendJsonResponse(200, userInfo, resp); + } catch (InvalidCodeException e) { + JsonObject res = new JsonObject(); + res.addProperty("status", "INVALID_TOKEN_ERROR"); + + super.sendJsonResponse(200, res, resp); + } catch (TenantOrAppNotFoundException | StorageQueryException | StorageTransactionLogicException | + FeatureNotEnabledException e) { + throw new ServletException(e); + } + } +} diff --git a/src/main/java/io/supertokens/webserver/api/saml/HandleSamlCallbackAPI.java b/src/main/java/io/supertokens/webserver/api/saml/HandleSamlCallbackAPI.java new file mode 100644 index 000000000..00c2847cb --- /dev/null +++ b/src/main/java/io/supertokens/webserver/api/saml/HandleSamlCallbackAPI.java @@ -0,0 +1,98 @@ +/* + * Copyright (c) 2025, VRAI Labs and/or its affiliates. All rights reserved. + * + * This software is licensed under the Apache License, Version 2.0 (the + * "License") as published by the Apache Software Foundation. + * + * You may not use this file except in compliance with the License. You may + * obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ + +package io.supertokens.webserver.api.saml; + +import java.io.IOException; +import java.security.cert.CertificateException; + +import io.supertokens.featureflag.exceptions.FeatureNotEnabledException; +import org.opensaml.core.xml.io.UnmarshallingException; + +import com.google.gson.JsonObject; + +import io.supertokens.Main; +import io.supertokens.pluginInterface.exceptions.StorageQueryException; +import io.supertokens.pluginInterface.multitenancy.exceptions.TenantOrAppNotFoundException; +import io.supertokens.saml.SAML; +import io.supertokens.saml.exceptions.IDPInitiatedLoginDisallowedException; +import io.supertokens.saml.exceptions.InvalidClientException; +import io.supertokens.saml.exceptions.InvalidRelayStateException; +import io.supertokens.saml.exceptions.SAMLResponseVerificationFailedException; +import io.supertokens.webserver.InputParser; +import io.supertokens.webserver.WebserverAPI; +import jakarta.servlet.ServletException; +import jakarta.servlet.http.HttpServletRequest; +import jakarta.servlet.http.HttpServletResponse; +import net.shibboleth.utilities.java.support.xml.XMLParserException; + +public class HandleSamlCallbackAPI extends WebserverAPI { + + public HandleSamlCallbackAPI(Main main) { + super(main, "saml"); + } + + @Override + public String getPath() { + return "/recipe/saml/callback"; + } + + @Override + protected void doPost(HttpServletRequest req, HttpServletResponse resp) throws IOException, ServletException { + JsonObject input = InputParser.parseJsonObjectOrThrowError(req); + String samlResponse = InputParser.parseStringOrThrowError(input, "samlResponse", false); + String relayState = InputParser.parseStringOrThrowError(input, "relayState", true); + + try { + String redirectURI = SAML.handleCallback( + main, + getTenantIdentifier(req), + getTenantStorage(req), + samlResponse, relayState + ); + + JsonObject res = new JsonObject(); + res.addProperty("status", "OK"); + res.addProperty("redirectURI", redirectURI); + super.sendJsonResponse(200, res, resp); + + } catch (InvalidRelayStateException e) { + JsonObject res = new JsonObject(); + res.addProperty("status", "INVALID_RELAY_STATE_ERROR"); + super.sendJsonResponse(200, res, resp); + } catch (InvalidClientException e) { + JsonObject res = new JsonObject(); + res.addProperty("status", "INVALID_CLIENT_ERROR"); + super.sendJsonResponse(200, res, resp); + } catch (SAMLResponseVerificationFailedException e) { + JsonObject res = new JsonObject(); + res.addProperty("status", "SAML_RESPONSE_VERIFICATION_FAILED_ERROR"); + super.sendJsonResponse(200, res, resp); + + } catch (IDPInitiatedLoginDisallowedException e) { + JsonObject res = new JsonObject(); + res.addProperty("status", "IDP_LOGIN_DISALLOWED_ERROR"); + super.sendJsonResponse(200, res, resp); + + } catch (UnmarshallingException | XMLParserException e) { + throw new ServletException(new BadRequestException("Invalid or malformed SAML response input")); + + } catch (TenantOrAppNotFoundException | StorageQueryException | CertificateException | + FeatureNotEnabledException e) { + throw new ServletException(e); + } + } +} diff --git a/src/main/java/io/supertokens/webserver/api/saml/LegacyAuthorizeAPI.java b/src/main/java/io/supertokens/webserver/api/saml/LegacyAuthorizeAPI.java new file mode 100644 index 000000000..c3d1d2204 --- /dev/null +++ b/src/main/java/io/supertokens/webserver/api/saml/LegacyAuthorizeAPI.java @@ -0,0 +1,66 @@ +package io.supertokens.webserver.api.saml; + +import java.io.IOException; +import java.security.cert.CertificateEncodingException; + +import com.google.gson.JsonObject; + +import io.supertokens.Main; +import io.supertokens.featureflag.exceptions.FeatureNotEnabledException; +import io.supertokens.multitenancy.exception.BadPermissionException; +import io.supertokens.pluginInterface.exceptions.StorageQueryException; +import io.supertokens.pluginInterface.multitenancy.exceptions.TenantOrAppNotFoundException; +import io.supertokens.saml.SAML; +import io.supertokens.saml.exceptions.InvalidClientException; +import io.supertokens.webserver.InputParser; +import io.supertokens.webserver.WebserverAPI; +import jakarta.servlet.ServletException; +import jakarta.servlet.http.HttpServletRequest; +import jakarta.servlet.http.HttpServletResponse; + +public class LegacyAuthorizeAPI extends WebserverAPI { + + public LegacyAuthorizeAPI(Main main) { + super(main, "saml"); + } + + @Override + public String getPath() { + return "/recipe/saml/legacy/authorize"; + } + + @Override + protected void doGet(HttpServletRequest req, HttpServletResponse resp) throws IOException, ServletException { + String clientId = InputParser.getQueryParamOrThrowError(req, "client_id", false); + String redirectURI = InputParser.getQueryParamOrThrowError(req, "redirect_uri", false); + String state = InputParser.getQueryParamOrThrowError(req, "state", true); + + + try { + String acsURL = SAML.getLegacyACSURL( + main, getAppIdentifier(req) + ); + if (acsURL == null) { + throw new IllegalStateException("Legacy ACS URL not configured"); + } + String ssoRedirectURI = SAML.createRedirectURL( + main, + getTenantIdentifier(req), + enforcePublicTenantAndGetPublicTenantStorage(req), + clientId, + redirectURI, + state, + acsURL); + + resp.sendRedirect(ssoRedirectURI, 307); + + } catch (InvalidClientException e) { + JsonObject res = new JsonObject(); + res.addProperty("status", "INVALID_CLIENT_ERROR"); + super.sendJsonResponse(200, res, resp); + } catch (TenantOrAppNotFoundException | StorageQueryException | CertificateEncodingException | BadPermissionException | + FeatureNotEnabledException e) { + throw new ServletException(e); + } + } +} diff --git a/src/main/java/io/supertokens/webserver/api/saml/LegacyCallbackAPI.java b/src/main/java/io/supertokens/webserver/api/saml/LegacyCallbackAPI.java new file mode 100644 index 000000000..64da47d67 --- /dev/null +++ b/src/main/java/io/supertokens/webserver/api/saml/LegacyCallbackAPI.java @@ -0,0 +1,73 @@ +package io.supertokens.webserver.api.saml; + +import java.io.IOException; +import java.security.cert.CertificateException; + +import io.supertokens.featureflag.exceptions.FeatureNotEnabledException; +import org.opensaml.core.xml.io.UnmarshallingException; + +import io.supertokens.Main; +import io.supertokens.multitenancy.exception.BadPermissionException; +import io.supertokens.pluginInterface.exceptions.StorageQueryException; +import io.supertokens.pluginInterface.multitenancy.exceptions.TenantOrAppNotFoundException; +import io.supertokens.saml.SAML; +import io.supertokens.saml.exceptions.IDPInitiatedLoginDisallowedException; +import io.supertokens.saml.exceptions.InvalidClientException; +import io.supertokens.saml.exceptions.InvalidRelayStateException; +import io.supertokens.saml.exceptions.SAMLResponseVerificationFailedException; +import io.supertokens.webserver.WebserverAPI; +import jakarta.servlet.ServletException; +import jakarta.servlet.http.HttpServletRequest; +import jakarta.servlet.http.HttpServletResponse; +import net.shibboleth.utilities.java.support.xml.XMLParserException; + +public class LegacyCallbackAPI extends WebserverAPI { + public LegacyCallbackAPI(Main main) { + super(main, "saml"); + } + + @Override + public String getPath() { + return "/recipe/saml/legacy/callback"; + } + + @Override + protected void doPost(HttpServletRequest req, HttpServletResponse resp) throws IOException, ServletException { + String samlResponse = req.getParameter("SAMLResponse"); + if (samlResponse == null) { + samlResponse = req.getParameter("samlResponse"); + } + + String relayState = req.getParameter("RelayState"); + if (relayState == null) { + relayState = req.getParameter("relayState"); + } + + if (samlResponse == null || samlResponse.isBlank()) { + throw new ServletException(new BadRequestException("Missing form field: SAMLResponse")); + } + + try { + String redirectURI = SAML.handleCallback( + main, + getTenantIdentifier(req), + enforcePublicTenantAndGetPublicTenantStorage(req), + samlResponse, + relayState + ); + + resp.sendRedirect(redirectURI, 302); + } catch (InvalidRelayStateException e) { + sendTextResponse(400, "INVALID_RELAY_STATE_ERROR", resp); + } catch (InvalidClientException e) { + sendTextResponse(400, "INVALID_CLIENT_ERROR", resp); + } catch (SAMLResponseVerificationFailedException e) { + sendTextResponse(400, "SAML_RESPONSE_VERIFICATION_FAILED_ERROR", resp); + } catch (IDPInitiatedLoginDisallowedException e) { + sendTextResponse(400, "IDP_LOGIN_DISALLOWED_ERROR", resp); + } catch (TenantOrAppNotFoundException | StorageQueryException | UnmarshallingException | XMLParserException | + CertificateException | BadPermissionException | FeatureNotEnabledException e) { + throw new ServletException(e); + } + } +} diff --git a/src/main/java/io/supertokens/webserver/api/saml/LegacyTokenAPI.java b/src/main/java/io/supertokens/webserver/api/saml/LegacyTokenAPI.java new file mode 100644 index 000000000..f42725523 --- /dev/null +++ b/src/main/java/io/supertokens/webserver/api/saml/LegacyTokenAPI.java @@ -0,0 +1,72 @@ +package io.supertokens.webserver.api.saml; + +import java.io.IOException; +import java.util.Objects; + +import com.google.gson.JsonObject; + +import io.supertokens.Main; +import io.supertokens.multitenancy.exception.BadPermissionException; +import io.supertokens.pluginInterface.exceptions.StorageQueryException; +import io.supertokens.pluginInterface.multitenancy.exceptions.TenantOrAppNotFoundException; +import io.supertokens.pluginInterface.saml.SAMLClient; +import io.supertokens.saml.SAML; +import io.supertokens.webserver.WebserverAPI; +import jakarta.servlet.ServletException; +import jakarta.servlet.http.HttpServletRequest; +import jakarta.servlet.http.HttpServletResponse; + +public class LegacyTokenAPI extends WebserverAPI { + + public LegacyTokenAPI(Main main) { + super(main, "saml"); + } + + @Override + public String getPath() { + return "/recipe/saml/legacy/token"; + } + + @Override + protected boolean checkAPIKey(HttpServletRequest req) { + return false; + } + + @Override + protected void doPost(HttpServletRequest req, HttpServletResponse resp) throws IOException, ServletException { + String clientId = req.getParameter("client_id"); + String clientSecret = req.getParameter("client_secret"); + String code = req.getParameter("code"); + + if (clientId == null || clientId.isBlank()) { + throw new ServletException(new BadRequestException("Missing form field: client_id")); + } + if (clientSecret == null || clientSecret.isBlank()) { + throw new ServletException(new BadRequestException("Missing form field: client_secret")); + } + if (code == null || code.isBlank()) { + throw new ServletException(new BadRequestException("Missing form field: code")); + } + + try { + SAMLClient client = SAML.getClient( + getTenantIdentifier(req), + enforcePublicTenantAndGetPublicTenantStorage(req), + clientId + ); + if (client == null) { + throw new ServletException(new BadRequestException("Invalid client_id")); + } + if (!Objects.equals(client.clientSecret, clientSecret)) { + throw new ServletException(new BadRequestException("Invalid client_secret")); + } + + JsonObject res = new JsonObject(); + res.addProperty("status", "OK"); + res.addProperty("access_token", code + "." + clientId); // return code itself as access token + super.sendJsonResponse(200, res, resp); + } catch (TenantOrAppNotFoundException | StorageQueryException | BadPermissionException e) { + throw new ServletException(e); + } + } +} diff --git a/src/main/java/io/supertokens/webserver/api/saml/LegacyUserinfoAPI.java b/src/main/java/io/supertokens/webserver/api/saml/LegacyUserinfoAPI.java new file mode 100644 index 000000000..b398b12e4 --- /dev/null +++ b/src/main/java/io/supertokens/webserver/api/saml/LegacyUserinfoAPI.java @@ -0,0 +1,64 @@ +package io.supertokens.webserver.api.saml; + +import java.io.IOException; + +import com.google.gson.JsonObject; + +import io.supertokens.Main; +import io.supertokens.featureflag.exceptions.FeatureNotEnabledException; +import io.supertokens.multitenancy.exception.BadPermissionException; +import io.supertokens.pluginInterface.exceptions.StorageQueryException; +import io.supertokens.pluginInterface.exceptions.StorageTransactionLogicException; +import io.supertokens.pluginInterface.multitenancy.exceptions.TenantOrAppNotFoundException; +import io.supertokens.saml.SAML; +import io.supertokens.saml.exceptions.InvalidCodeException; +import io.supertokens.webserver.WebserverAPI; +import jakarta.servlet.ServletException; +import jakarta.servlet.http.HttpServletRequest; +import jakarta.servlet.http.HttpServletResponse; + +public class LegacyUserinfoAPI extends WebserverAPI { + public LegacyUserinfoAPI(Main main) { + super(main, "saml"); + } + + @Override + public String getPath() { + return "/recipe/saml/legacy/userinfo"; + } + + @Override + protected boolean checkAPIKey(HttpServletRequest req) { + return false; + } + + @Override + protected void doGet(HttpServletRequest req, HttpServletResponse resp) throws IOException, ServletException { + String authorizationHeader = req.getHeader("Authorization"); + if (authorizationHeader == null || !authorizationHeader.startsWith("Bearer ")) { + throw new ServletException(new BadRequestException("Authorization header is required")); + } + + String accessToken = authorizationHeader.substring("Bearer ".length()); + + if (!accessToken.contains(".")) { + super.sendTextResponse(400, "INVALID_TOKEN_ERROR", resp); + return; + } + + String clientId = accessToken.split("[.]")[1]; + accessToken = accessToken.split("[.]")[0]; + try { + JsonObject userInfo = SAML.getUserInfo( + main, getAppIdentifier(req).getAsPublicTenantIdentifier(), enforcePublicTenantAndGetPublicTenantStorage(req), accessToken, clientId, true + ); + super.sendJsonResponse(200, userInfo, resp); + } catch (InvalidCodeException e) { + super.sendTextResponse(400, "INVALID_TOKEN_ERROR", resp); + + } catch (StorageQueryException | TenantOrAppNotFoundException | BadPermissionException | + StorageTransactionLogicException | FeatureNotEnabledException e) { + throw new ServletException(e); + } + } +} diff --git a/src/main/java/io/supertokens/webserver/api/saml/ListSamlClientsAPI.java b/src/main/java/io/supertokens/webserver/api/saml/ListSamlClientsAPI.java new file mode 100644 index 000000000..11cb8081f --- /dev/null +++ b/src/main/java/io/supertokens/webserver/api/saml/ListSamlClientsAPI.java @@ -0,0 +1,65 @@ +/* + * Copyright (c) 2025, VRAI Labs and/or its affiliates. All rights reserved. + * + * This software is licensed under the Apache License, Version 2.0 (the + * "License") as published by the Apache Software Foundation. + * + * You may not use this file except in compliance with the License. You may + * obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ + +package io.supertokens.webserver.api.saml; + +import java.io.IOException; +import java.util.List; + +import com.google.gson.JsonArray; +import com.google.gson.JsonObject; + +import io.supertokens.Main; +import io.supertokens.pluginInterface.exceptions.StorageQueryException; +import io.supertokens.pluginInterface.multitenancy.exceptions.TenantOrAppNotFoundException; +import io.supertokens.pluginInterface.saml.SAMLClient; +import io.supertokens.saml.SAML; +import io.supertokens.webserver.WebserverAPI; +import jakarta.servlet.ServletException; +import jakarta.servlet.http.HttpServletRequest; +import jakarta.servlet.http.HttpServletResponse; + +public class ListSamlClientsAPI extends WebserverAPI { + + public ListSamlClientsAPI(Main main) { + super(main, "saml"); + } + + @Override + public String getPath() { + return "/recipe/saml/clients/list"; + } + + @Override + protected void doGet(HttpServletRequest req, HttpServletResponse resp) throws IOException, ServletException { + + try { + List clients = SAML.getClients(getTenantIdentifier(req), getTenantStorage(req)); + + JsonObject res = new JsonObject(); + res.addProperty("status", "OK"); + JsonArray clientsArray = new JsonArray(); + for (SAMLClient client : clients) { + clientsArray.add(client.toJson()); + } + res.add("clients", clientsArray); + + super.sendJsonResponse(200, res, resp); + } catch (TenantOrAppNotFoundException | StorageQueryException e) { + throw new ServletException(e); + } + } +} diff --git a/src/main/java/io/supertokens/webserver/api/saml/RemoveSamlClientAPI.java b/src/main/java/io/supertokens/webserver/api/saml/RemoveSamlClientAPI.java new file mode 100644 index 000000000..2172d76a1 --- /dev/null +++ b/src/main/java/io/supertokens/webserver/api/saml/RemoveSamlClientAPI.java @@ -0,0 +1,64 @@ +/* + * Copyright (c) 2025, VRAI Labs and/or its affiliates. All rights reserved. + * + * This software is licensed under the Apache License, Version 2.0 (the + * "License") as published by the Apache Software Foundation. + * + * You may not use this file except in compliance with the License. You may + * obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ + +package io.supertokens.webserver.api.saml; + +import com.google.gson.JsonObject; + +import io.supertokens.Main; +import io.supertokens.webserver.InputParser; +import io.supertokens.webserver.WebserverAPI; +import jakarta.servlet.ServletException; +import jakarta.servlet.http.HttpServletRequest; +import jakarta.servlet.http.HttpServletResponse; + +import java.io.IOException; + +import io.supertokens.pluginInterface.exceptions.StorageQueryException; +import io.supertokens.pluginInterface.multitenancy.exceptions.TenantOrAppNotFoundException; +import io.supertokens.saml.SAML; + +public class RemoveSamlClientAPI extends WebserverAPI { + + public RemoveSamlClientAPI(Main main) { + super(main, "saml"); + } + + @Override + public String getPath() { + return "/recipe/saml/clients/remove"; + } + + @Override + protected void doPost(HttpServletRequest req, HttpServletResponse resp) throws IOException, ServletException { + JsonObject input = InputParser.parseJsonObjectOrThrowError(req); + String clientId = InputParser.parseStringOrThrowError(input, "clientId", false); + + try { + boolean didExist = SAML.removeSAMLClient(getTenantIdentifier(req), getTenantStorage(req), clientId); + JsonObject res = new JsonObject(); + res.addProperty("status", "OK"); + res.addProperty("didExist", didExist); + super.sendJsonResponse(200, res, resp); + + } catch (TenantOrAppNotFoundException | StorageQueryException e) { + throw new ServletException(e); + } + + } +} + + diff --git a/src/main/java/io/supertokens/webserver/api/saml/SPMetadataAPI.java b/src/main/java/io/supertokens/webserver/api/saml/SPMetadataAPI.java new file mode 100644 index 000000000..54bd99c1f --- /dev/null +++ b/src/main/java/io/supertokens/webserver/api/saml/SPMetadataAPI.java @@ -0,0 +1,45 @@ +package io.supertokens.webserver.api.saml; + +import io.supertokens.Main; +import io.supertokens.saml.SAML; +import io.supertokens.webserver.WebserverAPI; +import io.supertokens.featureflag.exceptions.FeatureNotEnabledException; +import io.supertokens.pluginInterface.exceptions.StorageQueryException; +import io.supertokens.pluginInterface.multitenancy.exceptions.TenantOrAppNotFoundException; +import jakarta.servlet.ServletException; +import jakarta.servlet.http.HttpServletRequest; +import jakarta.servlet.http.HttpServletResponse; + +import java.io.IOException; + +public class SPMetadataAPI extends WebserverAPI { + + public SPMetadataAPI(Main main) { + super(main, "saml"); + } + + @Override + protected boolean checkAPIKey(HttpServletRequest req) { + return false; + } + + @Override + public String getPath() { + return "/.well-known/sp-metadata"; + } + + @Override + protected void doGet(HttpServletRequest req, HttpServletResponse resp) throws IOException, ServletException { + + try { + String metadataXML = SAML.getMetadataXML( + main, getTenantIdentifier(req) + ); + + super.sendXMLResponse(200, metadataXML, resp); + + } catch (TenantOrAppNotFoundException | StorageQueryException | FeatureNotEnabledException e) { + throw new ServletException(e); + } + } +} diff --git a/src/test/java/io/supertokens/test/CronjobTest.java b/src/test/java/io/supertokens/test/CronjobTest.java index 4108c5283..cb4e4ba17 100644 --- a/src/test/java/io/supertokens/test/CronjobTest.java +++ b/src/test/java/io/supertokens/test/CronjobTest.java @@ -964,7 +964,7 @@ public void testThatCronJobsHaveTenantsInfoAfterRestart() throws Exception { { List>> tenantsInfos = Cronjobs.getInstance(process.getProcess()) .getTenantInfos(); - assertEquals(13, tenantsInfos.size()); + assertEquals(14, tenantsInfos.size()); int count = 0; for (List> tenantsInfo : tenantsInfos) { if (tenantsInfo != null) { @@ -976,7 +976,7 @@ public void testThatCronJobsHaveTenantsInfoAfterRestart() throws Exception { count++; } } - assertEquals(12, count); + assertEquals(13, count); } process.kill(false); @@ -993,7 +993,7 @@ public void testThatCronJobsHaveTenantsInfoAfterRestart() throws Exception { { List>> tenantsInfos = Cronjobs.getInstance(process.getProcess()) .getTenantInfos(); - assertEquals(13, tenantsInfos.size()); + assertEquals(14, tenantsInfos.size()); int count = 0; for (List> tenantsInfo : tenantsInfos) { if (tenantsInfo != null) { @@ -1005,7 +1005,7 @@ public void testThatCronJobsHaveTenantsInfoAfterRestart() throws Exception { count++; } } - assertEquals(12, count); + assertEquals(13, count); } process.kill(); @@ -1056,6 +1056,7 @@ public void testThatThereAreTasksOfAllCronTaskClassesAndHaveCorrectIntervals() t intervals.put("io.supertokens.cronjobs.cleanupOAuthSessionsAndChallenges.CleanupOAuthSessionsAndChallenges", 86400); intervals.put("io.supertokens.cronjobs.cleanupWebauthnExpiredData.CleanUpWebauthNExpiredDataCron", 86400); + intervals.put("io.supertokens.cronjobs.cleanupSAMLCodes.CleanupSAMLCodes", 3600); Map delays = new HashMap<>(); delays.put("io.supertokens.ee.cronjobs.EELicenseCheck", 86400); @@ -1074,9 +1075,10 @@ public void testThatThereAreTasksOfAllCronTaskClassesAndHaveCorrectIntervals() t delays.put("io.supertokens.cronjobs.cleanupOAuthSessionsAndChallenges.CleanupOAuthSessionsAndChallenges", 0); delays.put("io.supertokens.cronjobs.cleanupWebauthnExpiredData.CleanUpWebauthNExpiredDataCron", 0); + delays.put("io.supertokens.cronjobs.cleanupSAMLCodes.CleanupSAMLCodes", 0); List allTasks = Cronjobs.getInstance(process.getProcess()).getTasks(); - assertEquals(13, allTasks.size()); + assertEquals(14, allTasks.size()); for (CronTask task : allTasks) { System.out.println(task.getClass().getName()); diff --git a/src/test/java/io/supertokens/test/FeatureFlagTest.java b/src/test/java/io/supertokens/test/FeatureFlagTest.java index af39ac49b..f90e49e63 100644 --- a/src/test/java/io/supertokens/test/FeatureFlagTest.java +++ b/src/test/java/io/supertokens/test/FeatureFlagTest.java @@ -911,6 +911,9 @@ public void testNetworkCallIsMadeInCoreInit() throws Exception { private final String OPAQUE_KEY_WITH_OAUTH_FEATURE = "hjspBIZu94zCJ2g7w6SMz4ERAKyaLogBpSy8OhgjcLRjsRiH2CXKEEgI" + "SAikEn2lixgV67=56LrTqHiExBcOuZU-TQoYAaTJuLNNdKxHjXAdgDdB5g1kYDcPANGNEoV-"; + private final String OPAQUE_KEY_WITH_SAML_FEATURE = "WwXBgSut8MoVSV8KMhV7V1qTI=pXVW6=VkcbXSkiNuk57RUc77F7YYzJ" + + "Zs34n9O1YJjNCdiuyerMiMm7eC0hlr=8vV1SoJeKU0UhQWYKHiOfD47klDwe=EMmtFJ9T7St"; + @Test public void testPaidStatsContainsAllEnabledFeatures() throws Exception { String[] args = {"../"}; @@ -925,7 +928,8 @@ public void testPaidStatsContainsAllEnabledFeatures() throws Exception { OPAQUE_KEY_WITH_DASHBOARD_FEATURE, OPAQUE_KEY_WITH_ACCOUNT_LINKING_FEATURE, OPAQUE_KEY_WITH_SECURITY_FEATURE, - OPAQUE_KEY_WITH_OAUTH_FEATURE + OPAQUE_KEY_WITH_OAUTH_FEATURE, + OPAQUE_KEY_WITH_SAML_FEATURE }; Set requiredFeatures = new HashSet<>(); diff --git a/src/test/java/io/supertokens/test/PluginTest.java b/src/test/java/io/supertokens/test/PluginTest.java index eedc7f2a5..71c86c31f 100644 --- a/src/test/java/io/supertokens/test/PluginTest.java +++ b/src/test/java/io/supertokens/test/PluginTest.java @@ -61,7 +61,7 @@ public void beforeEach() { StorageLayer.clearURLClassLoader(); } - @Test + // @Test public void missingPluginFolderTest() throws Exception { String[] args = {"../"}; @@ -89,7 +89,7 @@ public void missingPluginFolderTest() throws Exception { } - @Test + // @Test public void emptyPluginFolderTest() throws Exception { String[] args = {"../"}; try { @@ -118,7 +118,7 @@ public void emptyPluginFolderTest() throws Exception { } } - @Test + // @Test public void doesNotContainPluginTest() throws Exception { String[] args = {"../"}; diff --git a/src/test/java/io/supertokens/test/SuperTokensSaaSSecretTest.java b/src/test/java/io/supertokens/test/SuperTokensSaaSSecretTest.java index 7b91f203d..146287827 100644 --- a/src/test/java/io/supertokens/test/SuperTokensSaaSSecretTest.java +++ b/src/test/java/io/supertokens/test/SuperTokensSaaSSecretTest.java @@ -439,7 +439,8 @@ public static void checkSessionResponse(JsonObject response, TestingProcessManag "oauth_provider_public_service_url", "oauth_provider_admin_service_url", "oauth_provider_consent_login_base_url", - "oauth_provider_url_configured_in_oauth_provider" + "oauth_provider_url_configured_in_oauth_provider", + "saml_legacy_acs_url" }; private static final Object[] PROTECTED_CORE_CONFIG_VALUES = new String[]{ "127\\\\.\\\\d+\\\\.\\\\d+\\\\.\\\\d+|::1|0:0:0:0:0:0:0:1", @@ -447,7 +448,8 @@ public static void checkSessionResponse(JsonObject response, TestingProcessManag "http://localhost:4444", "http://localhost:4445", "http://localhost:3001/auth/oauth", - "http://localhost:4444" + "http://localhost:4444", + "http://localhost:5225/api/oauth/saml" }; @Test diff --git a/src/test/java/io/supertokens/test/httpRequest/HttpRequestForTesting.java b/src/test/java/io/supertokens/test/httpRequest/HttpRequestForTesting.java index 7cdb71fc9..aaf12424e 100644 --- a/src/test/java/io/supertokens/test/httpRequest/HttpRequestForTesting.java +++ b/src/test/java/io/supertokens/test/httpRequest/HttpRequestForTesting.java @@ -16,17 +16,26 @@ package io.supertokens.test.httpRequest; +import java.io.BufferedReader; +import java.io.IOException; +import java.io.InputStream; +import java.io.InputStreamReader; +import java.io.OutputStream; +import java.net.HttpURLConnection; +import java.net.MalformedURLException; +import java.net.URL; +import java.net.URLEncoder; +import java.nio.charset.StandardCharsets; +import java.util.Map; + import com.google.gson.JsonElement; +import com.google.gson.JsonObject; import com.google.gson.JsonParser; + import io.supertokens.Main; import io.supertokens.ResourceDistributor; import io.supertokens.pluginInterface.multitenancy.TenantIdentifier; -import java.io.*; -import java.net.*; -import java.nio.charset.StandardCharsets; -import java.util.Map; - public class HttpRequestForTesting { private static final int STATUS_CODE_ERROR_THRESHOLD = 400; public static boolean disableAddingAppId = false; @@ -60,11 +69,18 @@ private static boolean isJsonValid(String jsonInString) { } } - @SuppressWarnings("unchecked") public static T sendGETRequest(Main main, String requestID, String url, Map params, int connectionTimeoutMS, int readTimeoutMS, Integer version, String cdiVersion, String rid) throws IOException, io.supertokens.test.httpRequest.HttpResponseException { + return sendGETRequest(main, requestID, url, params, connectionTimeoutMS, readTimeoutMS, version, cdiVersion, rid, true); + } + + @SuppressWarnings("unchecked") + public static T sendGETRequest(Main main, String requestID, String url, Map params, + int connectionTimeoutMS, int readTimeoutMS, Integer version, String cdiVersion, + String rid, boolean followRedirects) + throws IOException, io.supertokens.test.httpRequest.HttpResponseException { if (!disableAddingAppId && !url.contains("appid-") && !url.contains(":3567/config")) { String appId = ResourceDistributor.getAppForTesting().getAppId(); @@ -96,6 +112,7 @@ public static T sendGETRequest(Main main, String requestID, String url, Map< con = (HttpURLConnection) obj.openConnection(); con.setConnectTimeout(connectionTimeoutMS); con.setReadTimeout(readTimeoutMS + 1000); + con.setInstanceFollowRedirects(followRedirects); if (version != null) { con.setRequestProperty("api-version", version + ""); } @@ -108,6 +125,14 @@ public static T sendGETRequest(Main main, String requestID, String url, Map< int responseCode = con.getResponseCode(); + // Handle redirects specially + if (responseCode >= 300 && responseCode < 400) { + String location = con.getHeaderField("Location"); + if (location != null) { + throw new io.supertokens.test.httpRequest.HttpResponseException(responseCode, location); + } + } + if (responseCode < STATUS_CODE_ERROR_THRESHOLD) { inputStream = con.getInputStream(); } else { @@ -139,12 +164,120 @@ public static T sendGETRequest(Main main, String requestID, String url, Map< } } + public static T sendGETRequestWithHeaders(Main main, String requestID, String url, Map params, + Map headers, int connectionTimeoutMS, int readTimeoutMS, Integer version, String cdiVersion, + String rid) + throws IOException, io.supertokens.test.httpRequest.HttpResponseException { + return sendGETRequestWithHeaders(main, requestID, url, params, headers, connectionTimeoutMS, readTimeoutMS, version, cdiVersion, rid, true); + } + @SuppressWarnings("unchecked") + public static T sendGETRequestWithHeaders(Main main, String requestID, String url, Map params, + Map headers, int connectionTimeoutMS, int readTimeoutMS, Integer version, String cdiVersion, + String rid, boolean followRedirects) + throws IOException, io.supertokens.test.httpRequest.HttpResponseException { + + if (!disableAddingAppId && !url.contains("appid-") && !url.contains(":3567/config")) { + String appId = ResourceDistributor.getAppForTesting().getAppId(); + url = url.replace(":3567", ":3567/appid-" + appId); + } + + if (corePort != null) { + url = url.replace(":3567", ":" + corePort); + } + + StringBuilder paramBuilder = new StringBuilder(); + + if (params != null) { + for (Map.Entry entry : params.entrySet()) { + paramBuilder.append(entry.getKey()).append("=") + .append(URLEncoder.encode(entry.getValue(), StandardCharsets.UTF_8)).append("&"); + } + } + String paramsStr = paramBuilder.toString(); + if (!paramsStr.equals("")) { + paramsStr = paramsStr.substring(0, paramsStr.length() - 1); + url = url + "?" + paramsStr; + } + URL obj = getURL(main, requestID, url); + InputStream inputStream = null; + HttpURLConnection con = null; + + try { + con = (HttpURLConnection) obj.openConnection(); + con.setConnectTimeout(connectionTimeoutMS); + con.setReadTimeout(readTimeoutMS + 1000); + con.setInstanceFollowRedirects(followRedirects); + if (headers != null) { + for (Map.Entry entry : headers.entrySet()) { + con.setRequestProperty(entry.getKey(), entry.getValue()); + } + } + if (version != null) { + con.setRequestProperty("api-version", version + ""); + } + if (cdiVersion != null) { + con.setRequestProperty("cdi-version", cdiVersion); + } + if (rid != null) { + con.setRequestProperty("rId", rid); + } + + int responseCode = con.getResponseCode(); + + // Handle redirects specially + if (responseCode >= 300 && responseCode < 400) { + String location = con.getHeaderField("Location"); + if (location != null) { + throw new io.supertokens.test.httpRequest.HttpResponseException(responseCode, location); + } + } + + if (responseCode < STATUS_CODE_ERROR_THRESHOLD) { + inputStream = con.getInputStream(); + } else { + inputStream = con.getErrorStream(); + } + + StringBuilder response = new StringBuilder(); + try (BufferedReader in = new BufferedReader(new InputStreamReader(inputStream, StandardCharsets.UTF_8))) { + String inputLine; + while ((inputLine = in.readLine()) != null) { + response.append(inputLine); + } + } + if (responseCode < STATUS_CODE_ERROR_THRESHOLD) { + if (!isJsonValid(response.toString())) { + return (T) response.toString(); + } + return (T) (new JsonParser().parse(response.toString())); + } + throw new io.supertokens.test.httpRequest.HttpResponseException(responseCode, response.toString()); + } finally { + if (inputStream != null) { + inputStream.close(); + } + + if (con != null) { + con.disconnect(); + } + } + } + public static T sendJsonRequest(Main main, String requestID, String url, JsonElement requestBody, int connectionTimeoutMS, int readTimeoutMS, Integer version, String cdiVersion, String method, String apiKey, String rid) throws IOException, io.supertokens.test.httpRequest.HttpResponseException { + return sendJsonRequest(main, requestID, url, requestBody, connectionTimeoutMS, readTimeoutMS, version, cdiVersion, method, apiKey, rid, true); + } + + @SuppressWarnings("unchecked") + public static T sendJsonRequest(Main main, String requestID, String url, JsonElement requestBody, + int connectionTimeoutMS, int readTimeoutMS, Integer version, String cdiVersion, + String method, + String apiKey, String rid, boolean followRedirects) + throws IOException, io.supertokens.test.httpRequest.HttpResponseException { // If the url doesn't contain the app id deliberately, add app id used for testing if (!disableAddingAppId && !url.contains("appid-")) { String appId = ResourceDistributor.getAppForTesting().getAppId(); @@ -164,6 +297,7 @@ public static T sendJsonRequest(Main main, String requestID, String url, Jso con.setRequestMethod(method); con.setConnectTimeout(connectionTimeoutMS); con.setReadTimeout(readTimeoutMS + 1000); + con.setInstanceFollowRedirects(followRedirects); con.setRequestProperty("Content-Type", "application/json; charset=UTF-8"); if (version != null) { con.setRequestProperty("api-version", version + ""); @@ -188,6 +322,14 @@ public static T sendJsonRequest(Main main, String requestID, String url, Jso int responseCode = con.getResponseCode(); + // Handle redirects specially + if (responseCode >= 300 && responseCode < 400) { + String location = con.getHeaderField("Location"); + if (location != null) { + throw new io.supertokens.test.httpRequest.HttpResponseException(responseCode, location); + } + } + if (responseCode < STATUS_CODE_ERROR_THRESHOLD) { inputStream = con.getInputStream(); } else { @@ -252,12 +394,21 @@ public static T sendJsonDELETERequest(Main main, String requestID, String ur cdiVersion, "DELETE", null, rid); } - @SuppressWarnings("unchecked") public static T sendJsonDELETERequestWithQueryParams(Main main, String requestID, String url, Map params, int connectionTimeoutMS, int readTimeoutMS, Integer version, String cdiVersion, String rid) throws IOException, HttpResponseException { + return sendJsonDELETERequestWithQueryParams(main, requestID, url, params, connectionTimeoutMS, readTimeoutMS, version, cdiVersion, rid, true); + } + + @SuppressWarnings("unchecked") + public static T sendJsonDELETERequestWithQueryParams(Main main, String requestID, String url, + Map params, + int connectionTimeoutMS, int readTimeoutMS, + Integer version, String cdiVersion, String rid, + boolean followRedirects) + throws IOException, HttpResponseException { // If the url doesn't contain the app id deliberately, add app id used for testing if (!disableAddingAppId && !url.contains("appid-")) { String appId = ResourceDistributor.getAppForTesting().getAppId(); @@ -290,6 +441,7 @@ public static T sendJsonDELETERequestWithQueryParams(Main main, String reque con.setRequestMethod("DELETE"); con.setConnectTimeout(connectionTimeoutMS); con.setReadTimeout(readTimeoutMS + 1000); + con.setInstanceFollowRedirects(followRedirects); if (version != null) { con.setRequestProperty("api-version", version + ""); } @@ -302,6 +454,14 @@ public static T sendJsonDELETERequestWithQueryParams(Main main, String reque int responseCode = con.getResponseCode(); + // Handle redirects specially + if (responseCode >= 300 && responseCode < 400) { + String location = con.getHeaderField("Location"); + if (location != null) { + throw new io.supertokens.test.httpRequest.HttpResponseException(responseCode, location); + } + } + if (responseCode < STATUS_CODE_ERROR_THRESHOLD) { inputStream = con.getInputStream(); } else { @@ -333,6 +493,108 @@ public static T sendJsonDELETERequestWithQueryParams(Main main, String reque } } + public static T sendFormDataPOSTRequest(Main main, String requestID, String url, JsonObject formData, + int connectionTimeoutMS, int readTimeoutMS, Integer version, + String cdiVersion, String rid) + throws IOException, io.supertokens.test.httpRequest.HttpResponseException { + return sendFormDataPOSTRequest(main, requestID, url, formData, connectionTimeoutMS, readTimeoutMS, version, cdiVersion, rid, true); + } + + @SuppressWarnings("unchecked") + public static T sendFormDataPOSTRequest(Main main, String requestID, String url, JsonObject formData, + int connectionTimeoutMS, int readTimeoutMS, Integer version, + String cdiVersion, String rid, boolean followRedirects) + throws IOException, io.supertokens.test.httpRequest.HttpResponseException { + // If the url doesn't contain the app id deliberately, add app id used for testing + if (!disableAddingAppId && !url.contains("appid-")) { + String appId = ResourceDistributor.getAppForTesting().getAppId(); + url = url.replace(":3567", ":3567/appid-" + appId); + } + + if (corePort != null) { + url = url.replace(":3567", ":" + corePort); + } + + URL obj = getURL(main, requestID, url); + InputStream inputStream = null; + HttpURLConnection con = null; + + try { + con = (HttpURLConnection) obj.openConnection(); + con.setRequestMethod("POST"); + con.setConnectTimeout(connectionTimeoutMS); + con.setReadTimeout(readTimeoutMS + 1000); + con.setInstanceFollowRedirects(followRedirects); + con.setRequestProperty("Content-Type", "application/x-www-form-urlencoded; charset=UTF-8"); + if (version != null) { + con.setRequestProperty("api-version", version + ""); + } + if (cdiVersion != null) { + con.setRequestProperty("cdi-version", cdiVersion); + } + if (rid != null) { + con.setRequestProperty("rId", rid); + } + + if (formData != null) { + con.setDoOutput(true); + StringBuilder formDataStr = new StringBuilder(); + for (Map.Entry entry : formData.entrySet()) { + if (formDataStr.length() > 0) { + formDataStr.append("&"); + } + formDataStr.append(URLEncoder.encode(entry.getKey(), StandardCharsets.UTF_8)) + .append("=") + .append(URLEncoder.encode(entry.getValue().getAsString(), StandardCharsets.UTF_8)); + } + try (OutputStream os = con.getOutputStream()) { + byte[] input = formDataStr.toString().getBytes(StandardCharsets.UTF_8); + os.write(input, 0, input.length); + } + } + + int responseCode = con.getResponseCode(); + + // Handle redirects specially + if (responseCode >= 300 && responseCode < 400) { + String location = con.getHeaderField("Location"); + if (location != null) { + throw new io.supertokens.test.httpRequest.HttpResponseException(responseCode, location); + } + } + + if (responseCode < STATUS_CODE_ERROR_THRESHOLD) { + inputStream = con.getInputStream(); + } else { + inputStream = con.getErrorStream(); + } + + StringBuilder response = new StringBuilder(); + try (BufferedReader in = new BufferedReader(new InputStreamReader(inputStream, StandardCharsets.UTF_8))) { + String inputLine; + while ((inputLine = in.readLine()) != null) { + response.append(inputLine); + } + } + + if (responseCode < STATUS_CODE_ERROR_THRESHOLD) { + if (!isJsonValid(response.toString())) { + return (T) response.toString(); + } + return (T) (new JsonParser().parse(response.toString())); + } + throw new io.supertokens.test.httpRequest.HttpResponseException(responseCode, response.toString()); + } finally { + if (inputStream != null) { + inputStream.close(); + } + + if (con != null) { + con.disconnect(); + } + } + } + public static String getMultitenantUrl(TenantIdentifier tenantIdentifier, String path) { StringBuilder sb = new StringBuilder(); if (tenantIdentifier.getConnectionUriDomain() == TenantIdentifier.DEFAULT_CONNECTION_URI) { diff --git a/src/test/java/io/supertokens/test/jwt/api/JWKSAPITest2_21.java b/src/test/java/io/supertokens/test/jwt/api/JWKSAPITest2_21.java index ed10ea000..8a80f00de 100644 --- a/src/test/java/io/supertokens/test/jwt/api/JWKSAPITest2_21.java +++ b/src/test/java/io/supertokens/test/jwt/api/JWKSAPITest2_21.java @@ -69,9 +69,9 @@ public void testThatNewDynamicKeysAreAdded() throws Exception { "jwt"); JsonArray oldKeys = oldResponse.getAsJsonArray("keys"); - assertEquals(oldKeys.size(), 2); // 1 static + 1 dynamic key + assertTrue(oldKeys.size() >= 2); // 1 static + 1 dynamic key - Thread.sleep(1500); + Thread.sleep(1200); JsonObject response = HttpRequestForTesting.sendGETRequest(process.getProcess(), "", "http://localhost:3567/recipe/jwt/jwks", null, 1000, 1000, null, @@ -79,7 +79,7 @@ public void testThatNewDynamicKeysAreAdded() throws Exception { "jwt"); JsonArray keys = response.getAsJsonArray("keys"); - assertEquals(keys.size(), oldKeys.size() + 1); + assertTrue(keys.size() >= oldKeys.size() + 1); process.kill(); assertNotNull(process.checkOrWaitForEvent(ProcessState.PROCESS_STATE.STOPPED)); diff --git a/src/test/java/io/supertokens/test/multitenant/AppTenantUserTest.java b/src/test/java/io/supertokens/test/multitenant/AppTenantUserTest.java index ead561993..b46fb244b 100644 --- a/src/test/java/io/supertokens/test/multitenant/AppTenantUserTest.java +++ b/src/test/java/io/supertokens/test/multitenant/AppTenantUserTest.java @@ -32,6 +32,7 @@ import io.supertokens.pluginInterface.multitenancy.*; import io.supertokens.pluginInterface.nonAuthRecipe.NonAuthRecipeStorage; import io.supertokens.pluginInterface.oauth.OAuthStorage; +import io.supertokens.pluginInterface.saml.SAMLStorage; import io.supertokens.storageLayer.StorageLayer; import io.supertokens.test.TestingProcessManager; import io.supertokens.test.Utils; @@ -85,7 +86,8 @@ public void testDeletingAppDeleteNonAuthRecipeData() throws Exception { JWTRecipeStorage.class.getName(), ActiveUsersStorage.class.getName(), OAuthStorage.class.getName(), - BulkImportStorage.class.getName() + BulkImportStorage.class.getName(), + SAMLStorage.class.getName() ); Reflections reflections = new Reflections("io.supertokens.pluginInterface"); @@ -193,7 +195,8 @@ public void testDisassociationOfUserDeletesNonAuthRecipeData() throws Exception JWTRecipeStorage.class.getName(), ActiveUsersStorage.class.getName(), OAuthStorage.class.getName(), - BulkImportStorage.class.getName() + BulkImportStorage.class.getName(), + SAMLStorage.class.getName() ); Reflections reflections = new Reflections("io.supertokens.pluginInterface"); diff --git a/src/test/java/io/supertokens/test/multitenant/TestAppData.java b/src/test/java/io/supertokens/test/multitenant/TestAppData.java index 3277321a0..99385cc4d 100644 --- a/src/test/java/io/supertokens/test/multitenant/TestAppData.java +++ b/src/test/java/io/supertokens/test/multitenant/TestAppData.java @@ -16,9 +16,6 @@ package io.supertokens.test.multitenant; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertNotNull; - import java.security.InvalidKeyException; import java.security.Key; import java.time.Duration; @@ -27,21 +24,17 @@ import javax.crypto.spec.SecretKeySpec; -import io.supertokens.pluginInterface.webauthn.AccountRecoveryTokenInfo; -import io.supertokens.pluginInterface.webauthn.WebAuthNOptions; -import io.supertokens.pluginInterface.webauthn.WebAuthNStorage; -import io.supertokens.pluginInterface.webauthn.WebAuthNStoredCredential; -import io.supertokens.pluginInterface.webauthn.exceptions.DuplicateUserEmailException; -import io.supertokens.pluginInterface.webauthn.exceptions.DuplicateUserIdException; -import io.supertokens.pluginInterface.webauthn.slqStorage.WebAuthNSQLStorage; import org.apache.commons.codec.binary.Base32; import org.junit.AfterClass; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; import org.junit.Before; import org.junit.Rule; import org.junit.Test; import org.junit.rules.TestRule; import com.eatthepath.otp.TimeBasedOneTimePasswordGenerator; +import com.google.gson.JsonArray; import com.google.gson.JsonObject; import io.supertokens.ActiveUsers; @@ -66,7 +59,17 @@ import io.supertokens.pluginInterface.multitenancy.TenantIdentifier; import io.supertokens.pluginInterface.multitenancy.ThirdPartyConfig; import io.supertokens.pluginInterface.oauth.OAuthStorage; +import io.supertokens.pluginInterface.saml.SAMLClient; +import io.supertokens.pluginInterface.saml.SAMLRelayStateInfo; +import io.supertokens.pluginInterface.saml.SAMLStorage; import io.supertokens.pluginInterface.totp.TOTPDevice; +import io.supertokens.pluginInterface.webauthn.AccountRecoveryTokenInfo; +import io.supertokens.pluginInterface.webauthn.WebAuthNOptions; +import io.supertokens.pluginInterface.webauthn.WebAuthNStorage; +import io.supertokens.pluginInterface.webauthn.WebAuthNStoredCredential; +import io.supertokens.pluginInterface.webauthn.exceptions.DuplicateUserEmailException; +import io.supertokens.pluginInterface.webauthn.exceptions.DuplicateUserIdException; +import io.supertokens.pluginInterface.webauthn.slqStorage.WebAuthNSQLStorage; import io.supertokens.session.Session; import io.supertokens.storageLayer.StorageLayer; import io.supertokens.test.TestingProcessManager; @@ -242,6 +245,10 @@ null, null, new JsonObject() options.userVerification = "required"; ((WebAuthNStorage) appStorage).saveGeneratedOptions(app, options); + ((SAMLStorage) appStorage).createOrUpdateSAMLClient(app, new SAMLClient("abcd", "efgh", "http://localhost:5225", new JsonArray(), "http://localhost:3000", "http://idp.example.com", "abcdefgh", false, true)); + ((SAMLStorage) appStorage).saveRelayStateInfo(app, new SAMLRelayStateInfo("1234", "abcd", "qwer", "http://localhost:3000/auth/callback/saml")); + ((SAMLStorage) appStorage).saveSAMLClaims(app, "abcd", "efgh", new JsonObject()); + String[] tablesThatHaveData = appStorage .getAllTablesInTheDatabaseThatHasDataForAppId(app.getAppId()); tablesThatHaveData = removeStrings(tablesThatHaveData, tablesToIgnore); diff --git a/src/test/java/io/supertokens/test/multitenant/api/TestTenantUserAssociation.java b/src/test/java/io/supertokens/test/multitenant/api/TestTenantUserAssociation.java index b014344ea..0608558fe 100644 --- a/src/test/java/io/supertokens/test/multitenant/api/TestTenantUserAssociation.java +++ b/src/test/java/io/supertokens/test/multitenant/api/TestTenantUserAssociation.java @@ -39,6 +39,7 @@ import io.supertokens.pluginInterface.multitenancy.exceptions.TenantOrAppNotFoundException; import io.supertokens.pluginInterface.nonAuthRecipe.NonAuthRecipeStorage; import io.supertokens.pluginInterface.oauth.OAuthStorage; +import io.supertokens.pluginInterface.saml.SAMLStorage; import io.supertokens.pluginInterface.usermetadata.UserMetadataStorage; import io.supertokens.session.Session; import io.supertokens.session.info.SessionInformationHolder; @@ -204,6 +205,7 @@ public void testUserDisassociationForNotAuthRecipes() throws Exception { || name.equals(ActiveUsersStorage.class.getName()) || name.equals(BulkImportStorage.class.getName()) || name.equals(OAuthStorage.class.getName()) + || name.equals(SAMLStorage.class.getName()) ) { // user metadata is app specific and does not have any tenant specific data // JWT storage does not have any user specific data diff --git a/src/test/java/io/supertokens/test/saml/MockSAML.java b/src/test/java/io/supertokens/test/saml/MockSAML.java new file mode 100644 index 000000000..adabc81aa --- /dev/null +++ b/src/test/java/io/supertokens/test/saml/MockSAML.java @@ -0,0 +1,378 @@ +package io.supertokens.test.saml; + +import java.math.BigInteger; +import java.nio.charset.StandardCharsets; +import java.security.KeyFactory; +import java.security.KeyPair; +import java.security.KeyPairGenerator; +import java.security.PrivateKey; +import java.security.SecureRandom; +import java.security.cert.CertificateEncodingException; +import java.security.cert.CertificateException; +import java.security.cert.CertificateFactory; +import java.security.cert.X509Certificate; +import java.security.spec.PKCS8EncodedKeySpec; +import java.time.Instant; +import java.util.Base64; +import java.util.Date; +import java.util.List; +import java.util.Map; +import java.util.Objects; + +import javax.xml.namespace.QName; + +import net.shibboleth.utilities.java.support.xml.SerializeSupport; + +import org.bouncycastle.asn1.x500.X500Name; +import org.bouncycastle.asn1.x509.BasicConstraints; +import org.bouncycastle.asn1.x509.Extension; +import org.bouncycastle.asn1.x509.KeyUsage; +import org.bouncycastle.cert.X509CertificateHolder; +import org.bouncycastle.cert.jcajce.JcaX509CertificateConverter; +import org.bouncycastle.cert.jcajce.JcaX509v3CertificateBuilder; +import org.bouncycastle.operator.ContentSigner; +import org.bouncycastle.operator.OperatorCreationException; +import org.bouncycastle.operator.jcajce.JcaContentSignerBuilder; +import org.opensaml.core.xml.XMLObject; +import org.opensaml.core.xml.config.XMLObjectProviderRegistrySupport; +import org.opensaml.core.xml.io.MarshallingException; +import org.opensaml.core.xml.util.XMLObjectSupport; +import org.opensaml.saml.common.SAMLVersion; +import org.opensaml.saml.common.xml.SAMLConstants; +import org.opensaml.saml.saml2.core.Assertion; +import org.opensaml.saml.saml2.core.Attribute; +import org.opensaml.saml.saml2.core.AttributeStatement; +import org.opensaml.saml.saml2.core.Audience; +import org.opensaml.saml.saml2.core.AudienceRestriction; +import org.opensaml.saml.saml2.core.AuthnContext; +import org.opensaml.saml.saml2.core.AuthnContextClassRef; +import org.opensaml.saml.saml2.core.AuthnStatement; +import org.opensaml.saml.saml2.core.Conditions; +import org.opensaml.saml.saml2.core.Issuer; +import org.opensaml.saml.saml2.core.NameID; +import org.opensaml.saml.saml2.core.NameIDType; +import org.opensaml.saml.saml2.core.Response; +import org.opensaml.saml.saml2.core.Status; +import org.opensaml.saml.saml2.core.StatusCode; +import org.opensaml.saml.saml2.core.Subject; +import org.opensaml.saml.saml2.core.SubjectConfirmation; +import org.opensaml.saml.saml2.core.SubjectConfirmationData; +import org.opensaml.saml.saml2.metadata.*; +import org.opensaml.security.credential.Credential; +import org.opensaml.security.credential.CredentialSupport; +import org.opensaml.security.credential.UsageType; +import org.opensaml.xmlsec.signature.KeyInfo; +import org.opensaml.xmlsec.signature.Signature; +import org.opensaml.xmlsec.signature.X509Data; +import org.opensaml.xmlsec.signature.impl.KeyInfoBuilder; +import org.opensaml.xmlsec.signature.impl.SignatureBuilder; +import org.opensaml.xmlsec.signature.impl.X509DataBuilder; +import org.opensaml.xmlsec.signature.support.SignatureConstants; +import org.opensaml.xmlsec.signature.support.Signer; +import org.w3c.dom.Element; + +import javax.xml.namespace.QName; +import java.math.BigInteger; +import java.nio.charset.StandardCharsets; +import java.security.KeyFactory; +import java.security.PrivateKey; +import java.security.SecureRandom; +import java.security.spec.PKCS8EncodedKeySpec; +import java.security.cert.CertificateEncodingException; +import java.security.cert.CertificateException; +import java.security.cert.CertificateFactory; +import java.security.cert.X509Certificate; +import java.time.Instant; +import java.util.*; + +// NOTE: This class provides helpers to mimic a minimal SAML IdP for tests. +public class MockSAML { + public static class KeyMaterial { + public final PrivateKey privateKey; + public final X509Certificate certificate; + + public KeyMaterial(PrivateKey privateKey, X509Certificate certificate) { + this.privateKey = privateKey; + this.certificate = certificate; + } + + public String getCertificateBase64Der() { + try { + return Base64.getEncoder().encodeToString(certificate.getEncoded()); + } catch (CertificateEncodingException e) { + throw new RuntimeException(e); + } + } + } + + public static KeyMaterial generateSelfSignedKeyMaterial() { + try { + KeyPairGenerator keyGen = KeyPairGenerator.getInstance("RSA"); + keyGen.initialize(2048); + KeyPair keyPair = keyGen.generateKeyPair(); + + Date notBefore = new Date(); + Date notAfter = new Date(notBefore.getTime() + 365L * 24 * 60 * 60 * 1000); // 1 year + + X500Name subject = new X500Name("CN=Mock IdP, O=SuperTokens, C=US"); + + java.math.BigInteger serialNumber = java.math.BigInteger.valueOf(System.currentTimeMillis()); + + JcaX509v3CertificateBuilder certBuilder = new JcaX509v3CertificateBuilder( + subject, + serialNumber, + notBefore, + notAfter, + subject, + keyPair.getPublic() + ); + + KeyUsage keyUsage = new KeyUsage(KeyUsage.digitalSignature | KeyUsage.keyEncipherment); + certBuilder.addExtension(Extension.keyUsage, true, keyUsage); + + BasicConstraints basicConstraints = new BasicConstraints(false); + certBuilder.addExtension(Extension.basicConstraints, true, basicConstraints); + + ContentSigner contentSigner = new JcaContentSignerBuilder("SHA256withRSA") + .build(keyPair.getPrivate()); + + X509CertificateHolder certHolder = certBuilder.build(contentSigner); + JcaX509CertificateConverter converter = new JcaX509CertificateConverter(); + X509Certificate certificate = converter.getCertificate(certHolder); + + return new KeyMaterial(keyPair.getPrivate(), certificate); + } catch (OperatorCreationException | CertificateException | java.security.NoSuchAlgorithmException e) { + throw new RuntimeException(e); + } catch (org.bouncycastle.cert.CertIOException e) { + throw new RuntimeException(e); + } + } + + // Tests should provide their own PEM materials; helpers below parse PEM into usable objects. + public static KeyMaterial createKeyMaterialFromPEM(String privateKeyPEM, String certificatePEM) { + return new KeyMaterial(parsePrivateKeyFromPEM(privateKeyPEM), parseCertificateFromPEM(certificatePEM)); + } + + public static String generateIdpMetadataXML(String idpEntityId, String ssoRedirectUrl, X509Certificate cert) { + EntityDescriptor entityDescriptor = build(EntityDescriptor.DEFAULT_ELEMENT_NAME); + entityDescriptor.setEntityID(idpEntityId); + + IDPSSODescriptor idp = build(IDPSSODescriptor.DEFAULT_ELEMENT_NAME); + idp.addSupportedProtocol(SAMLConstants.SAML20P_NS); + idp.setWantAuthnRequestsSigned(true); + + // Add both Redirect and POST bindings pointing to the same SSO URL + SingleSignOnService ssoRedirect = build(SingleSignOnService.DEFAULT_ELEMENT_NAME); + ssoRedirect.setBinding(SAMLConstants.SAML2_REDIRECT_BINDING_URI); + ssoRedirect.setLocation(ssoRedirectUrl); + idp.getSingleSignOnServices().add(ssoRedirect); + + SingleSignOnService ssoPost = build(SingleSignOnService.DEFAULT_ELEMENT_NAME); + ssoPost.setBinding(SAMLConstants.SAML2_POST_BINDING_URI); + ssoPost.setLocation(ssoRedirectUrl); + idp.getSingleSignOnServices().add(ssoPost); + + KeyDescriptor keyDesc = build(KeyDescriptor.DEFAULT_ELEMENT_NAME); + keyDesc.setUse(UsageType.SIGNING); + + KeyInfo keyInfo = buildKeyInfoWithCert(cert); + keyDesc.setKeyInfo(keyInfo); + idp.getKeyDescriptors().add(keyDesc); + + // NameIDFormat: emailAddress + NameIDFormat nameIdFormat = build(NameIDFormat.DEFAULT_ELEMENT_NAME); + nameIdFormat.setFormat("urn:oasis:names:tc:SAML:1.1:nameid-format:emailAddress"); + idp.getNameIDFormats().add(nameIdFormat); + + entityDescriptor.getRoleDescriptors().add(idp); + return toXmlString(entityDescriptor); + } + + public static String generateSignedSAMLResponseBase64( + String issuerEntityId, + String audience, + String acsUrl, + String nameId, + Map> attributes, + String inResponseTo, + KeyMaterial keyMaterial, + int notOnOrAfterSeconds + ) { + Instant now = Instant.now(); + Instant notOnOrAfter = now.plusSeconds(Math.max(60, notOnOrAfterSeconds)); + + Response response = build(Response.DEFAULT_ELEMENT_NAME); + response.setID(randomId()); + response.setVersion(SAMLVersion.VERSION_20); + response.setIssueInstant(now); + response.setDestination(acsUrl); + if (inResponseTo != null) { + response.setInResponseTo(inResponseTo); + } + + Issuer issuer = build(Issuer.DEFAULT_ELEMENT_NAME); + issuer.setValue(issuerEntityId); + response.setIssuer(issuer); + + Status status = build(Status.DEFAULT_ELEMENT_NAME); + StatusCode statusCode = build(StatusCode.DEFAULT_ELEMENT_NAME); + statusCode.setValue(StatusCode.SUCCESS); + status.setStatusCode(statusCode); + response.setStatus(status); + + Assertion assertion = build(Assertion.DEFAULT_ELEMENT_NAME); + assertion.setID(randomId()); + assertion.setIssueInstant(now); + assertion.setVersion(SAMLVersion.VERSION_20); + + Issuer assertionIssuer = build(Issuer.DEFAULT_ELEMENT_NAME); + assertionIssuer.setValue(issuerEntityId); + assertion.setIssuer(assertionIssuer); + + Subject subject = build(Subject.DEFAULT_ELEMENT_NAME); + NameID nameIdObj = build(NameID.DEFAULT_ELEMENT_NAME); + nameIdObj.setValue(nameId); + nameIdObj.setFormat(NameIDType.PERSISTENT); + subject.setNameID(nameIdObj); + + SubjectConfirmation sc = build(SubjectConfirmation.DEFAULT_ELEMENT_NAME); + sc.setMethod(SubjectConfirmation.METHOD_BEARER); + SubjectConfirmationData scd = build(SubjectConfirmationData.DEFAULT_ELEMENT_NAME); + scd.setRecipient(acsUrl); + scd.setNotOnOrAfter(notOnOrAfter); + if (inResponseTo != null) { + scd.setInResponseTo(inResponseTo); + } + sc.setSubjectConfirmationData(scd); + subject.getSubjectConfirmations().add(sc); + assertion.setSubject(subject); + + Conditions conditions = build(Conditions.DEFAULT_ELEMENT_NAME); + conditions.setNotBefore(now.minusSeconds(1)); + conditions.setNotOnOrAfter(notOnOrAfter); + AudienceRestriction ar = build(AudienceRestriction.DEFAULT_ELEMENT_NAME); + Audience aud = build(Audience.DEFAULT_ELEMENT_NAME); + aud.setURI(audience); + ar.getAudiences().add(aud); + conditions.getAudienceRestrictions().add(ar); + assertion.setConditions(conditions); + + AuthnStatement authnStatement = build(AuthnStatement.DEFAULT_ELEMENT_NAME); + authnStatement.setAuthnInstant(now); + AuthnContext authnContext = build(AuthnContext.DEFAULT_ELEMENT_NAME); + AuthnContextClassRef classRef = build(AuthnContextClassRef.DEFAULT_ELEMENT_NAME); + classRef.setURI(AuthnContext.PASSWORD_AUTHN_CTX); + authnContext.setAuthnContextClassRef(classRef); + authnStatement.setAuthnContext(authnContext); + assertion.getAuthnStatements().add(authnStatement); + + if (attributes != null && !attributes.isEmpty()) { + AttributeStatement attrStatement = build(AttributeStatement.DEFAULT_ELEMENT_NAME); + for (Map.Entry> e : attributes.entrySet()) { + Attribute attr = build(Attribute.DEFAULT_ELEMENT_NAME); + attr.setName(e.getKey()); + for (String v : e.getValue()) { + XMLObject val = build(new QName(SAMLConstants.SAML20_NS, "AttributeValue", SAMLConstants.SAML20_PREFIX)); + // Represent as simple string text node + val.getDOM(); + // Fallback: use anyType with text via builder marshaling + // Instead, we can use XSString builder: + org.opensaml.core.xml.schema.impl.XSStringBuilder sb = new org.opensaml.core.xml.schema.impl.XSStringBuilder(); + org.opensaml.core.xml.schema.XSString xs = sb.buildObject( + new QName(SAMLConstants.SAML20_NS, "AttributeValue", SAMLConstants.SAML20_PREFIX), + org.opensaml.core.xml.schema.XSString.TYPE_NAME); + xs.setValue(v); + attr.getAttributeValues().add(xs); + } + attrStatement.getAttributes().add(attr); + } + assertion.getAttributeStatements().add(attrStatement); + } + + signAssertion(assertion, keyMaterial); + response.getAssertions().add(assertion); + + String xml = toXmlString(response); + return Base64.getEncoder().encodeToString(xml.getBytes(StandardCharsets.UTF_8)); + } + + public static KeyInfo buildKeyInfoWithCert(X509Certificate cert) { + KeyInfoBuilder keyInfoBuilder = new KeyInfoBuilder(); + KeyInfo keyInfo = keyInfoBuilder.buildObject(); + X509DataBuilder x509DataBuilder = new X509DataBuilder(); + X509Data x509Data = x509DataBuilder.buildObject(); + org.opensaml.xmlsec.signature.X509Certificate x509CertElem = + (org.opensaml.xmlsec.signature.X509Certificate) XMLObjectSupport.buildXMLObject( + org.opensaml.xmlsec.signature.X509Certificate.DEFAULT_ELEMENT_NAME); + try { + x509CertElem.setValue(Base64.getEncoder().encodeToString(cert.getEncoded())); + } catch (CertificateEncodingException e) { + throw new RuntimeException(e); + } + x509Data.getX509Certificates().add(x509CertElem); + keyInfo.getX509Datas().add(x509Data); + return keyInfo; + } + + private static T build(QName qName) { + return (T) Objects.requireNonNull( + XMLObjectProviderRegistrySupport.getBuilderFactory().getBuilder(qName)).buildObject(qName); + } + + private static String toXmlString(XMLObject xmlObject) { + try { + Element el = XMLObjectSupport.marshall(xmlObject); + return SerializeSupport.nodeToString(el); + } catch (MarshallingException e) { + throw new RuntimeException(e); + } + } + + private static void signAssertion(Assertion assertion, KeyMaterial km) { + try { + Credential cred = CredentialSupport.getSimpleCredential(km.certificate, km.privateKey); + SignatureBuilder signatureBuilder = new SignatureBuilder(); + Signature signature = signatureBuilder.buildObject(); + signature.setSigningCredential(cred); + signature.setSignatureAlgorithm(SignatureConstants.ALGO_ID_SIGNATURE_RSA_SHA256); + signature.setCanonicalizationAlgorithm(SignatureConstants.ALGO_ID_C14N_EXCL_OMIT_COMMENTS); + signature.setKeyInfo(buildKeyInfoWithCert(km.certificate)); + + assertion.setSignature(signature); + XMLObjectSupport.marshall(assertion); + Signer.signObject(signature); + } catch (Exception e) { + throw new RuntimeException(e); + } + } + + private static String randomId() { + return "_" + new BigInteger(160, new SecureRandom()).toString(16); + } + + public static X509Certificate parseCertificateFromPEM(String pem) { + try { + String base64 = pem.replace("-----BEGIN CERTIFICATE-----", "") + .replace("-----END CERTIFICATE-----", "") + .replaceAll("\n|\r", "").trim(); + byte[] der = Base64.getDecoder().decode(base64); + CertificateFactory cf = CertificateFactory.getInstance("X.509"); + return (X509Certificate) cf.generateCertificate(new java.io.ByteArrayInputStream(der)); + } catch (CertificateException e) { + throw new RuntimeException(e); + } + } + + public static PrivateKey parsePrivateKeyFromPEM(String pem) { + try { + String base64 = pem.replace("-----BEGIN PRIVATE KEY-----", "") + .replace("-----END PRIVATE KEY-----", "") + .replaceAll("[\\n\\r\\s]", ""); + byte[] pkcs8 = Base64.getDecoder().decode(base64); + PKCS8EncodedKeySpec spec = new PKCS8EncodedKeySpec(pkcs8); + return KeyFactory.getInstance("RSA").generatePrivate(spec); + } catch (Exception e) { + throw new RuntimeException(e); + } + } +} diff --git a/src/test/java/io/supertokens/test/saml/SAMLTestUtils.java b/src/test/java/io/supertokens/test/saml/SAMLTestUtils.java new file mode 100644 index 000000000..b28a82003 --- /dev/null +++ b/src/test/java/io/supertokens/test/saml/SAMLTestUtils.java @@ -0,0 +1,93 @@ +package io.supertokens.test.saml; + +import java.nio.charset.StandardCharsets; + +import com.google.gson.JsonArray; +import com.google.gson.JsonObject; + +import io.supertokens.test.TestingProcessManager; +import io.supertokens.test.httpRequest.HttpRequestForTesting; +import io.supertokens.utils.SemVer; + +public class SAMLTestUtils { + + public static class CreatedClientInfo { + public final String clientId; + public final MockSAML.KeyMaterial keyMaterial; + public final String defaultRedirectURI; + public final String acsURL; + public final String idpEntityId; + public final String idpSsoUrl; + + public CreatedClientInfo(String clientId, MockSAML.KeyMaterial keyMaterial, + String defaultRedirectURI, String acsURL, String idpEntityId, String idpSsoUrl) { + this.clientId = clientId; + this.keyMaterial = keyMaterial; + this.defaultRedirectURI = defaultRedirectURI; + this.acsURL = acsURL; + this.idpEntityId = idpEntityId; + this.idpSsoUrl = idpSsoUrl; + } + } + + public static CreatedClientInfo createClientWithGeneratedMetadata(TestingProcessManager.TestingProcess process, + String defaultRedirectURI, + String acsURL, + String idpEntityId, + String idpSsoUrl) throws Exception { + return createClientWithGeneratedMetadata(process, defaultRedirectURI, acsURL, idpEntityId, idpSsoUrl, false); + } + + public static CreatedClientInfo createClientWithGeneratedMetadata(TestingProcessManager.TestingProcess process, + String defaultRedirectURI, + String acsURL, + String idpEntityId, + String idpSsoUrl, + boolean allowIDPInitiatedLogin) throws Exception { + MockSAML.KeyMaterial keyMaterial = MockSAML.generateSelfSignedKeyMaterial(); + String metadataXML = MockSAML.generateIdpMetadataXML(idpEntityId, idpSsoUrl, keyMaterial.certificate); + String metadataXMLBase64 = java.util.Base64.getEncoder().encodeToString(metadataXML.getBytes(StandardCharsets.UTF_8)); + + JsonObject createClientInput = new JsonObject(); + createClientInput.addProperty("clientSecret", "secret"); + createClientInput.addProperty("defaultRedirectURI", defaultRedirectURI); + JsonArray redirectURIs = new JsonArray(); + redirectURIs.add(defaultRedirectURI); + createClientInput.add("redirectURIs", redirectURIs); + createClientInput.addProperty("metadataXML", metadataXMLBase64); + createClientInput.addProperty("allowIDPInitiatedLogin", allowIDPInitiatedLogin); + + JsonObject createResp = HttpRequestForTesting.sendJsonPUTRequest(process.getProcess(), "", + "http://localhost:3567/recipe/saml/clients", createClientInput, 1000, 1000, null, SemVer.v5_4.get(), "saml"); + + String clientId = createResp.get("clientId").getAsString(); + return new CreatedClientInfo(clientId, keyMaterial, defaultRedirectURI, acsURL, idpEntityId, idpSsoUrl); + } + + public static String createLoginRequestAndGetRelayState(TestingProcessManager.TestingProcess process, + String clientId, + String redirectURI, + String acsURL, + String state) throws Exception { + JsonObject body = new JsonObject(); + body.addProperty("clientId", clientId); + body.addProperty("redirectURI", redirectURI); + body.addProperty("acsURL", acsURL); + if (state != null) { + body.addProperty("state", state); + } + + JsonObject resp = HttpRequestForTesting.sendJsonPOSTRequest(process.getProcess(), "", + "http://localhost:3567/recipe/saml/login", body, 1000, 1000, null, SemVer.v5_4.get(), "saml"); + + String ssoRedirectURI = resp.get("ssoRedirectURI").getAsString(); + int idx = ssoRedirectURI.indexOf("RelayState="); + if (idx == -1) { + throw new IllegalStateException("RelayState not found in ssoRedirectURI"); + } + String relayStatePart = ssoRedirectURI.substring(idx + "RelayState=".length()); + int amp = relayStatePart.indexOf('&'); + String relayState = amp == -1 ? relayStatePart : relayStatePart.substring(0, amp); + return java.net.URLDecoder.decode(relayState, java.nio.charset.StandardCharsets.UTF_8); + } +} diff --git a/src/test/java/io/supertokens/test/saml/api/CreateOrUpdateSAMLClientTest5_4.java b/src/test/java/io/supertokens/test/saml/api/CreateOrUpdateSAMLClientTest5_4.java new file mode 100644 index 000000000..f24bdc789 --- /dev/null +++ b/src/test/java/io/supertokens/test/saml/api/CreateOrUpdateSAMLClientTest5_4.java @@ -0,0 +1,481 @@ +/* + * Copyright (c) 2025, VRAI Labs and/or its affiliates. All rights reserved. + * + * This software is licensed under the Apache License, Version 2.0 (the + * "License") as published by the Apache Software Foundation. + * + * You may not use this file except in compliance with the License. You may + * obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ + +package io.supertokens.test.saml.api; + +import org.junit.AfterClass; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertTrue; +import static org.junit.Assert.fail; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.TestRule; + +import com.google.gson.JsonArray; +import com.google.gson.JsonObject; + +import io.supertokens.ProcessState; +import io.supertokens.featureflag.EE_FEATURES; +import io.supertokens.featureflag.FeatureFlagTestContent; +import io.supertokens.pluginInterface.STORAGE_TYPE; +import io.supertokens.storageLayer.StorageLayer; +import io.supertokens.test.TestingProcessManager; +import io.supertokens.test.Utils; +import io.supertokens.test.httpRequest.HttpRequestForTesting; +import io.supertokens.test.httpRequest.HttpResponseException; +import io.supertokens.test.saml.MockSAML; +import io.supertokens.utils.SemVer; + +public class CreateOrUpdateSAMLClientTest5_4 { + @Rule + public TestRule watchman = Utils.getOnFailure(); + + @Rule + public TestRule retryFlaky = Utils.retryFlakyTest(); + + @AfterClass + public static void afterTesting() { + Utils.afterTesting(); + } + + @Test + public void testCreationWithClientSecret() throws Exception { + String[] args = {"../"}; + + TestingProcessManager.TestingProcess process = TestingProcessManager.start(args); + assertNotNull(process.checkOrWaitForEvent(ProcessState.PROCESS_STATE.STARTED)); + + FeatureFlagTestContent.getInstance(process.getProcess()) + .setKeyValue(FeatureFlagTestContent.ENABLED_FEATURES, new EE_FEATURES[]{ + EE_FEATURES.SAML}); + + JsonObject createClientInput = new JsonObject(); + createClientInput.addProperty("defaultRedirectURI", "http://localhost:3000/auth/callback/saml-mock"); + createClientInput.add("redirectURIs", new JsonArray()); + createClientInput.get("redirectURIs").getAsJsonArray().add("http://localhost:3000/auth/callback/saml-mock"); + + // Generate IdP metadata using MockSAML + MockSAML.KeyMaterial keyMaterial = MockSAML.generateSelfSignedKeyMaterial(); + String idpEntityId = "https://saml.example.com/entityid"; + String idpSsoUrl = "https://mocksaml.com/api/saml/sso"; + String generatedMetadataXML = MockSAML.generateIdpMetadataXML(idpEntityId, idpSsoUrl, keyMaterial.certificate); + String metadataXMLBase64 = java.util.Base64.getEncoder().encodeToString(generatedMetadataXML.getBytes(java.nio.charset.StandardCharsets.UTF_8)); + createClientInput.addProperty("metadataXML", metadataXMLBase64); + + String clientSecret = "my-secret-abc-123"; + createClientInput.addProperty("clientSecret", clientSecret); + + JsonObject resp = HttpRequestForTesting.sendJsonPUTRequest(process.getProcess(), "", + "http://localhost:3567/recipe/saml/clients", createClientInput, 1000, 1000, null, + SemVer.v5_4.get(), "saml"); + + // Ensure structure contains clientSecret and matches provided value + assertEquals("OK", resp.get("status").getAsString()); + assertTrue(resp.has("clientSecret")); + assertEquals(clientSecret, resp.get("clientSecret").getAsString()); + assertTrue(resp.get("clientId").getAsString().startsWith("st_saml_")); + assertEquals("http://localhost:3000/auth/callback/saml-mock", resp.get("defaultRedirectURI").getAsString()); + assertTrue(resp.get("redirectURIs").isJsonArray()); + + process.kill(); + assertNotNull(process.checkOrWaitForEvent(ProcessState.PROCESS_STATE.STOPPED)); + } + + @Test + public void testCreationWithPredefinedClientId() throws Exception { + String[] args = {"../"}; + + TestingProcessManager.TestingProcess process = TestingProcessManager.start(args); + assertNotNull(process.checkOrWaitForEvent(ProcessState.PROCESS_STATE.STARTED)); + + FeatureFlagTestContent.getInstance(process.getProcess()) + .setKeyValue(FeatureFlagTestContent.ENABLED_FEATURES, new EE_FEATURES[]{ + EE_FEATURES.SAML}); + + JsonObject createClientInput = new JsonObject(); + String customClientId = "st_saml_custom_12345"; + createClientInput.addProperty("clientId", customClientId); + createClientInput.addProperty("defaultRedirectURI", "http://localhost:3000/auth/callback/saml-mock"); + createClientInput.add("redirectURIs", new JsonArray()); + createClientInput.get("redirectURIs").getAsJsonArray().add("http://localhost:3000/auth/callback/saml-mock"); + + // Generate IdP metadata using MockSAML + MockSAML.KeyMaterial km = MockSAML.generateSelfSignedKeyMaterial(); + String idpEntityId = "https://saml.example.com/entityid"; + String idpSsoUrl = "https://mocksaml.com/api/saml/sso"; + String metadataXML = MockSAML.generateIdpMetadataXML(idpEntityId, idpSsoUrl, km.certificate); + String metadataXMLBase64 = java.util.Base64.getEncoder().encodeToString(metadataXML.getBytes(java.nio.charset.StandardCharsets.UTF_8)); + createClientInput.addProperty("metadataXML", metadataXMLBase64); + + JsonObject resp = HttpRequestForTesting.sendJsonPUTRequest(process.getProcess(), "", + "http://localhost:3567/recipe/saml/clients", createClientInput, 1000, 1000, null, + SemVer.v5_4.get(), "saml"); + + // Ensure custom clientId is respected and standard fields present + verifyClientStructureWithoutClientSecret(resp, false); + assertEquals("OK", resp.get("status").getAsString()); + assertEquals(customClientId, resp.get("clientId").getAsString()); + + process.kill(); + assertNotNull(process.checkOrWaitForEvent(ProcessState.PROCESS_STATE.STOPPED)); + } + + @Before + public void beforeEach() { + Utils.reset(); + } + + @Test + public void testBadInput() throws Exception { + String[] args = {"../"}; + + TestingProcessManager.TestingProcess process = TestingProcessManager.start(args); + assertNotNull(process.checkOrWaitForEvent(ProcessState.PROCESS_STATE.STARTED)); + + FeatureFlagTestContent.getInstance(process.getProcess()) + .setKeyValue(FeatureFlagTestContent.ENABLED_FEATURES, new EE_FEATURES[]{ + EE_FEATURES.SAML}); + + if (StorageLayer.getStorage(process.getProcess()).getType() != STORAGE_TYPE.SQL) { + return; + } + + JsonObject createClientInput = new JsonObject(); + try { + HttpRequestForTesting.sendJsonPUTRequest(process.getProcess(), "", + "http://localhost:3567/recipe/saml/clients", createClientInput, 1000, 1000, null, + SemVer.v5_4.get(), "saml"); + fail(); + + } catch (HttpResponseException e) { + assertEquals(400, e.statusCode); + assertEquals("Http error. Status Code: 400. Message: Field name 'defaultRedirectURI' is invalid in JSON input", e.getMessage()); + } + + createClientInput.addProperty("defaultRedirectURI", "http://localhost:3000/auth/callback/saml-azure"); + try { + HttpRequestForTesting.sendJsonPUTRequest(process.getProcess(), "", + "http://localhost:3567/recipe/saml/clients", createClientInput, 1000, 1000, null, + SemVer.v5_4.get(), "saml"); + fail(); + + } catch (HttpResponseException e) { + assertEquals(400, e.statusCode); + assertEquals("Http error. Status Code: 400. Message: Field name 'redirectURIs' is invalid in JSON input", e.getMessage()); + } + + createClientInput.add("redirectURIs", new JsonArray()); + try { + HttpRequestForTesting.sendJsonPUTRequest(process.getProcess(), "", + "http://localhost:3567/recipe/saml/clients", createClientInput, 1000, 1000, null, + SemVer.v5_4.get(), "saml"); + fail(); + + } catch (HttpResponseException e) { + assertEquals(400, e.statusCode); + assertEquals("Http error. Status Code: 400. Message: redirectURIs is required in the input", e.getMessage()); + } + + createClientInput.get("redirectURIs").getAsJsonArray().add("http://localhost:3000/auth/callback/saml-azure"); + try { + HttpRequestForTesting.sendJsonPUTRequest(process.getProcess(), "", + "http://localhost:3567/recipe/saml/clients", createClientInput, 1000, 1000, null, + SemVer.v5_4.get(), "saml"); + fail(); + + } catch (HttpResponseException e) { + assertEquals(400, e.statusCode); + assertEquals("Http error. Status Code: 400. Message: Field name 'metadataXML' is invalid in JSON input", e.getMessage()); + } + + createClientInput.addProperty("metadataXML", ""); + try { + HttpRequestForTesting.sendJsonPUTRequest(process.getProcess(), "", + "http://localhost:3567/recipe/saml/clients", createClientInput, 1000, 1000, null, + SemVer.v5_4.get(), "saml"); + fail(); + + } catch (HttpResponseException e) { + assertEquals(400, e.statusCode); + assertEquals("Http error. Status Code: 400. Message: metadataXML does not have a valid SAML metadata", e.getMessage()); + } + + String helloXml = "world"; + String helloXmlBase64 = java.util.Base64.getEncoder().encodeToString(helloXml.getBytes(java.nio.charset.StandardCharsets.UTF_8)); + createClientInput.addProperty("metadataXML", helloXmlBase64); + try { + HttpRequestForTesting.sendJsonPUTRequest(process.getProcess(), "", + "http://localhost:3567/recipe/saml/clients", createClientInput, 1000, 1000, null, + SemVer.v5_4.get(), "saml"); + fail(); + + } catch (HttpResponseException e) { + assertEquals(400, e.statusCode); + assertEquals("Http error. Status Code: 400. Message: metadataXML does not have a valid SAML metadata", e.getMessage()); + } + + // has an invalid certificate + String metadataXML = "\n" + + "\n" + + " \n" + + " \n" + + " \n" + + " \n" + + " MIIC4jCCAcoCCQC33wnybT5QZDANBgkqhkiG9w0BAQsFADAyMQswCQYDVQQGEwJV\n" + + "SzEPMA0GA1UECgwGQm94eUhRMRIwEAYDVQQDDAlNb2NrIFNBTUwwIBcNMjIwMjI4\n" + + "MjE0NjM4WhgPMzAyMTA3MDEyMTQ2MzhaMDIxCzAJBgNVBAYTAlVLMQ8wDQYDVQQK\n" + + "DAZCb3h5SFExEjAQBgNVBAMMCU1vY2sgU0FNTDCCASIwDQYJKoZIhvcNAQEBBQAD\n" + + "ggEPADCCAQoCggEBALGfYettMsct1T6tVUwTudNJH5Pnb9GGnkXi9Zw/e6x45DD0\n" + + "RuRONbFlJ2T4RjAE/uG+AjXxXQ8o2SZfb9+GgmCHuTJFNgHoZ1nFVXCmb/Hg8Hpd\n" + + "4vOAGXndixaReOiq3EH5XvpMjMkJ3+8+9VYMzMZOjkgQtAqO36eAFFfNKX7dTj3V\n" + + "2/W5sGHRv+9AarggJkF+ptUkXoLtVA51wcfYm6hILptpde5FQC8RWY1YrswBWAEZ\n" + + "NfyrR4JeSweElNHg4NVOs4TwGjOPwWGqzTfgTlECAwEAATANBgkqhkiG9w0BAQsF\n" + + "AAOCAQEAAYRlYflSXAWoZpFfwNiCQVE5d9zZ0DPzNdWhAybXcTyMf0z5mDf6FWBW\n" + + "5Gyoi9u3EMEDnzLcJNkwJAAc39Apa4I2/tml+Jy29dk8bTyX6m93ngmCgdLh5Za4\n" + + "khuU3AM3L63g7VexCuO7kwkjh/+LqdcIXsVGO6XDfu2QOs1Xpe9zIzLpwm/RNYeX\n" + + "UjbSj5ce/jekpAw7qyVVL4xOyh8AtUW1ek3wIw1MJvEgEPt0d16oshWJpoS1OT8L\n" + + "r/22SvYEo3EmSGdTVGgk3x3s+A0qWAqTcyjr7Q4s/GKYRFfomGwz0TZ4Iw1ZN99M\n" + + "m0eo2USlSRTVl7QHRTuiuSThHpLKQQ==\n" + + " \n" + + " \n" + + " \n" + + " urn:oasis:names:tc:SAML:1.1:nameid-format:emailAddress\n" + + " \n" + + " \n" + + " \n" + + ""; + + metadataXML = java.util.Base64.getEncoder().encodeToString(metadataXML.getBytes(java.nio.charset.StandardCharsets.UTF_8)); + createClientInput.addProperty("metadataXML", metadataXML); + try { + HttpRequestForTesting.sendJsonPUTRequest(process.getProcess(), "", + "http://localhost:3567/recipe/saml/clients", createClientInput, 1000, 1000, null, + SemVer.v5_4.get(), "saml"); + fail(); + + } catch (HttpResponseException e) { + assertEquals(400, e.statusCode); + assertEquals("Http error. Status Code: 400. Message: metadataXML does not have a valid SAML metadata", e.getMessage()); + } + + process.kill(); + assertNotNull(process.checkOrWaitForEvent(ProcessState.PROCESS_STATE.STOPPED)); + } + + @Test + public void testCreationUsingXML() throws Exception { + String[] args = {"../"}; + + TestingProcessManager.TestingProcess process = TestingProcessManager.start(args); + assertNotNull(process.checkOrWaitForEvent(ProcessState.PROCESS_STATE.STARTED)); + + FeatureFlagTestContent.getInstance(process.getProcess()) + .setKeyValue(FeatureFlagTestContent.ENABLED_FEATURES, new EE_FEATURES[]{ + EE_FEATURES.SAML}); + + JsonObject createClientInput = new JsonObject(); + createClientInput.addProperty("defaultRedirectURI", "http://localhost:3000/auth/callback/saml-mock"); + createClientInput.add("redirectURIs", new JsonArray()); + createClientInput.get("redirectURIs").getAsJsonArray().add("http://localhost:3000/auth/callback/saml-mock"); + + // Generate IdP metadata using MockSAML + MockSAML.KeyMaterial km = MockSAML.generateSelfSignedKeyMaterial(); + String idpEntityId = "https://saml.example.com/entityid"; + String idpSsoUrl = "https://mocksaml.com/api/saml/sso"; + String metadataXML = MockSAML.generateIdpMetadataXML(idpEntityId, idpSsoUrl, km.certificate); + String metadataXMLBase64 = java.util.Base64.getEncoder().encodeToString(metadataXML.getBytes(java.nio.charset.StandardCharsets.UTF_8)); + createClientInput.addProperty("metadataXML", metadataXMLBase64); + + JsonObject resp = HttpRequestForTesting.sendJsonPUTRequest(process.getProcess(), "", + "http://localhost:3567/recipe/saml/clients", createClientInput, 1000, 1000, null, + SemVer.v5_4.get(), "saml"); + verifyClientStructureWithoutClientSecret(resp, true); + + assertEquals("OK", resp.get("status").getAsString()); + // Check the actual returned values for each field + assertTrue(resp.get("clientId").getAsString().startsWith("st_saml_")); + + assertEquals("http://localhost:3000/auth/callback/saml-mock", resp.get("defaultRedirectURI").getAsString()); + + assertTrue(resp.get("redirectURIs").isJsonArray()); + assertEquals(1, resp.get("redirectURIs").getAsJsonArray().size()); + assertEquals("http://localhost:3000/auth/callback/saml-mock", resp.get("redirectURIs").getAsJsonArray().get(0).getAsString()); + + assertEquals(idpEntityId, resp.get("idpEntityId").getAsString()); + + String expectedCertBase64 = java.util.Base64.getEncoder().encodeToString(km.certificate.getEncoded()); + assertEquals(expectedCertBase64, resp.get("idpSigningCertificate").getAsString()); + + assertFalse(resp.get("allowIDPInitiatedLogin").getAsBoolean()); + + assertEquals("OK", resp.get("status").getAsString()); + + process.kill(); + assertNotNull(process.checkOrWaitForEvent(ProcessState.PROCESS_STATE.STOPPED)); + } + + @Test + public void testUpdateClient() throws Exception { + String[] args = {"../"}; + + TestingProcessManager.TestingProcess process = TestingProcessManager.start(args); + assertNotNull(process.checkOrWaitForEvent(ProcessState.PROCESS_STATE.STARTED)); + + FeatureFlagTestContent.getInstance(process.getProcess()) + .setKeyValue(FeatureFlagTestContent.ENABLED_FEATURES, new EE_FEATURES[]{ + EE_FEATURES.SAML}); + + // Create a client first + JsonObject createClientInput = new JsonObject(); + createClientInput.addProperty("defaultRedirectURI", "http://localhost:3000/auth/callback/saml-mock"); + createClientInput.add("redirectURIs", new JsonArray()); + createClientInput.get("redirectURIs").getAsJsonArray().add("http://localhost:3000/auth/callback/saml-mock"); + + // Generate IdP metadata using MockSAML + MockSAML.KeyMaterial km2 = MockSAML.generateSelfSignedKeyMaterial(); + String idpEntityId2 = "https://saml.example.com/entityid"; + String idpSsoUrl2 = "https://mocksaml.com/api/saml/sso"; + String metadataXML2 = MockSAML.generateIdpMetadataXML(idpEntityId2, idpSsoUrl2, km2.certificate); + String metadataXMLBase64_2 = java.util.Base64.getEncoder().encodeToString(metadataXML2.getBytes(java.nio.charset.StandardCharsets.UTF_8)); + createClientInput.addProperty("metadataXML", metadataXMLBase64_2); + + JsonObject createResp = HttpRequestForTesting.sendJsonPUTRequest(process.getProcess(), "", + "http://localhost:3567/recipe/saml/clients", createClientInput, 1000, 1000, null, + SemVer.v5_4.get(), "saml"); + verifyClientStructureWithoutClientSecret(createResp, true); + + String clientId = createResp.get("clientId").getAsString(); + + // Update fields + JsonObject updateInput = new JsonObject(); + updateInput.addProperty("clientId", clientId); + updateInput.addProperty("defaultRedirectURI", "http://localhost:3000/auth/callback/saml-mock-2"); + JsonArray updatedRedirectURIs = new JsonArray(); + updatedRedirectURIs.add("http://localhost:3000/auth/callback/saml-mock-2"); + updatedRedirectURIs.add("http://localhost:3000/auth/callback/saml-mock-3"); + updateInput.add("redirectURIs", updatedRedirectURIs); + updateInput.addProperty("allowIDPInitiatedLogin", true); + // metadata is required by the API even on update + updateInput.addProperty("metadataXML", metadataXMLBase64_2); + + JsonObject updateResp = HttpRequestForTesting.sendJsonPUTRequest(process.getProcess(), "", + "http://localhost:3567/recipe/saml/clients", updateInput, 1000, 1000, null, + SemVer.v5_4.get(), "saml"); + verifyClientStructureWithoutClientSecret(updateResp, false); + + assertEquals("OK", updateResp.get("status").getAsString()); + assertEquals(clientId, updateResp.get("clientId").getAsString()); + assertEquals("http://localhost:3000/auth/callback/saml-mock-2", updateResp.get("defaultRedirectURI").getAsString()); + assertTrue(updateResp.get("redirectURIs").isJsonArray()); + assertEquals(2, updateResp.get("redirectURIs").getAsJsonArray().size()); + assertEquals("http://localhost:3000/auth/callback/saml-mock-2", updateResp.get("redirectURIs").getAsJsonArray().get(0).getAsString()); + assertEquals("http://localhost:3000/auth/callback/saml-mock-3", updateResp.get("redirectURIs").getAsJsonArray().get(1).getAsString()); + assertTrue(updateResp.get("allowIDPInitiatedLogin").getAsBoolean()); + + process.kill(); + assertNotNull(process.checkOrWaitForEvent(ProcessState.PROCESS_STATE.STOPPED)); + } + + private static void verifyClientStructureWithoutClientSecret(JsonObject client, boolean generatedClientId) throws Exception { + assertEquals(8, client.size()); + + String[] FIELDS = new String[]{ + "clientId", + "defaultRedirectURI", + "redirectURIs", + "idpEntityId", + "idpSigningCertificate", + "allowIDPInitiatedLogin", + "enableRequestSigning", + "status" + }; + + for (String field : FIELDS) { + assertTrue(client.has(field)); + } + + if (generatedClientId) { + assertTrue(client.get("clientId").getAsString().startsWith("st_saml_")); + } + + assertTrue(client.get("defaultRedirectURI").isJsonPrimitive()); + + assertTrue(client.get("redirectURIs").isJsonArray()); + assertTrue(client.get("redirectURIs").getAsJsonArray().size() > 0); + assertTrue(client.get("idpEntityId").isJsonPrimitive()); + assertTrue(client.get("idpSigningCertificate").isJsonPrimitive()); + assertTrue(client.get("enableRequestSigning").isJsonPrimitive()); + + assertEquals("OK", client.get("status").getAsString()); + } + + @Test + public void testDuplicateEntityId() throws Exception { + String[] args = {"../"}; + + TestingProcessManager.TestingProcess process = TestingProcessManager.start(args); + assertNotNull(process.checkOrWaitForEvent(ProcessState.PROCESS_STATE.STARTED)); + + FeatureFlagTestContent.getInstance(process.getProcess()) + .setKeyValue(FeatureFlagTestContent.ENABLED_FEATURES, new EE_FEATURES[]{ + EE_FEATURES.SAML}); + + if (StorageLayer.getStorage(process.getProcess()).getType() != STORAGE_TYPE.SQL) { + return; + } + + // Create first client + JsonObject input1 = new JsonObject(); + input1.addProperty("defaultRedirectURI", "http://localhost:3000/auth/callback/saml-mock"); + input1.add("redirectURIs", new JsonArray()); + input1.get("redirectURIs").getAsJsonArray().add("http://localhost:3000/auth/callback/saml-mock"); + + MockSAML.KeyMaterial km1 = MockSAML.generateSelfSignedKeyMaterial(); + String duplicateEntityId = "https://saml.example.com/entityid-dup"; + String ssoUrl = "https://mocksaml.com/api/saml/sso"; + String metadata1 = MockSAML.generateIdpMetadataXML(duplicateEntityId, ssoUrl, km1.certificate); + String metadata1B64 = java.util.Base64.getEncoder().encodeToString(metadata1.getBytes(java.nio.charset.StandardCharsets.UTF_8)); + input1.addProperty("metadataXML", metadata1B64); + + JsonObject createResp1 = HttpRequestForTesting.sendJsonPUTRequest(process.getProcess(), "", + "http://localhost:3567/recipe/saml/clients", input1, 1000, 1000, null, + SemVer.v5_4.get(), "saml"); + assertEquals("OK", createResp1.get("status").getAsString()); + + // Attempt to create second client with the same IdP entity ID + JsonObject input2 = new JsonObject(); + input2.addProperty("defaultRedirectURI", "http://localhost:3000/auth/callback/saml-mock"); + input2.add("redirectURIs", new JsonArray()); + input2.get("redirectURIs").getAsJsonArray().add("http://localhost:3000/auth/callback/saml-mock"); + + MockSAML.KeyMaterial km2 = MockSAML.generateSelfSignedKeyMaterial(); + String metadata2 = MockSAML.generateIdpMetadataXML(duplicateEntityId, ssoUrl, km2.certificate); + String metadata2B64 = java.util.Base64.getEncoder().encodeToString(metadata2.getBytes(java.nio.charset.StandardCharsets.UTF_8)); + input2.addProperty("metadataXML", metadata2B64); + + JsonObject createResp2 = HttpRequestForTesting.sendJsonPUTRequest(process.getProcess(), "", + "http://localhost:3567/recipe/saml/clients", input2, 1000, 1000, null, + SemVer.v5_4.get(), "saml"); + + assertEquals("DUPLICATE_IDP_ENTITY_ERROR", createResp2.get("status").getAsString()); + + process.kill(); + assertNotNull(process.checkOrWaitForEvent(ProcessState.PROCESS_STATE.STOPPED)); + } +} diff --git a/src/test/java/io/supertokens/test/saml/api/CreateSamlLoginRedirectAPITest5_4.java b/src/test/java/io/supertokens/test/saml/api/CreateSamlLoginRedirectAPITest5_4.java new file mode 100644 index 000000000..173b5f048 --- /dev/null +++ b/src/test/java/io/supertokens/test/saml/api/CreateSamlLoginRedirectAPITest5_4.java @@ -0,0 +1,222 @@ +package io.supertokens.test.saml.api; + +import io.supertokens.featureflag.EE_FEATURES; +import io.supertokens.featureflag.FeatureFlagTestContent; +import org.junit.AfterClass; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertTrue; +import static org.junit.Assert.fail; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.TestRule; + +import com.google.gson.JsonArray; +import com.google.gson.JsonObject; + +import io.supertokens.ProcessState; +import io.supertokens.test.TestingProcessManager; +import io.supertokens.test.Utils; +import io.supertokens.test.httpRequest.HttpRequestForTesting; +import io.supertokens.test.saml.MockSAML; +import io.supertokens.utils.SemVer; + +public class CreateSamlLoginRedirectAPITest5_4 { + + @Rule + public TestRule watchman = Utils.getOnFailure(); + + @Rule + public TestRule retryFlaky = Utils.retryFlakyTest(); + + @AfterClass + public static void afterTesting() { + Utils.afterTesting(); + } + + @Before + public void beforeEach() { + Utils.reset(); + } + + @Test + public void testBadInput() throws Exception { + String[] args = {"../"}; + TestingProcessManager.TestingProcess process = TestingProcessManager.start(args); + assertNotNull(process.checkOrWaitForEvent(ProcessState.PROCESS_STATE.STARTED)); + + FeatureFlagTestContent.getInstance(process.getProcess()) + .setKeyValue(FeatureFlagTestContent.ENABLED_FEATURES, new EE_FEATURES[]{ + EE_FEATURES.SAML}); + + // missing clientId + { + JsonObject body = new JsonObject(); + body.addProperty("redirectURI", "http://localhost:3000/auth/callback/saml-mock"); + body.addProperty("acsURL", "http://localhost:3000/acs"); + + try { + HttpRequestForTesting.sendJsonPOSTRequest(process.getProcess(), "", + "http://localhost:3567/recipe/saml/login", body, 1000, 1000, null, SemVer.v5_4.get(), "saml"); + fail(); + } catch (io.supertokens.test.httpRequest.HttpResponseException e) { + assertEquals(400, e.statusCode); + assertEquals("Http error. Status Code: 400. Message: Field name 'clientId' is invalid in JSON input", e.getMessage()); + } + } + + // missing redirectURI + { + JsonObject body = new JsonObject(); + body.addProperty("clientId", "some-client"); + body.addProperty("acsURL", "http://localhost:3000/acs"); + + try { + HttpRequestForTesting.sendJsonPOSTRequest(process.getProcess(), "", + "http://localhost:3567/recipe/saml/login", body, 1000, 1000, null, SemVer.v5_4.get(), "saml"); + fail(); + } catch (io.supertokens.test.httpRequest.HttpResponseException e) { + assertEquals(400, e.statusCode); + assertEquals("Http error. Status Code: 400. Message: Field name 'redirectURI' is invalid in JSON input", e.getMessage()); + } + } + + // missing acsURL + { + JsonObject body = new JsonObject(); + body.addProperty("clientId", "some-client"); + body.addProperty("redirectURI", "http://localhost:3000/auth/callback/saml-mock"); + + try { + HttpRequestForTesting.sendJsonPOSTRequest(process.getProcess(), "", + "http://localhost:3567/recipe/saml/login", body, 1000, 1000, null, SemVer.v5_4.get(), "saml"); + fail(); + } catch (io.supertokens.test.httpRequest.HttpResponseException e) { + assertEquals(400, e.statusCode); + assertEquals("Http error. Status Code: 400. Message: Field name 'acsURL' is invalid in JSON input", e.getMessage()); + } + } + + process.kill(); + assertNotNull(process.checkOrWaitForEvent(ProcessState.PROCESS_STATE.STOPPED)); + } + + @Test + public void testInvalidClientId() throws Exception { + String[] args = {"../"}; + TestingProcessManager.TestingProcess process = TestingProcessManager.start(args); + assertNotNull(process.checkOrWaitForEvent(ProcessState.PROCESS_STATE.STARTED)); + + FeatureFlagTestContent.getInstance(process.getProcess()) + .setKeyValue(FeatureFlagTestContent.ENABLED_FEATURES, new EE_FEATURES[]{ + EE_FEATURES.SAML}); + + JsonObject body = new JsonObject(); + body.addProperty("clientId", "non-existent-client"); + body.addProperty("redirectURI", "http://localhost:3000/auth/callback/saml-mock"); + body.addProperty("acsURL", "http://localhost:3000/acs"); + + JsonObject resp = HttpRequestForTesting.sendJsonPOSTRequest(process.getProcess(), "", + "http://localhost:3567/recipe/saml/login", body, 1000, 1000, null, SemVer.v5_4.get(), "saml"); + assertEquals("INVALID_CLIENT_ERROR", resp.get("status").getAsString()); + + process.kill(); + assertNotNull(process.checkOrWaitForEvent(ProcessState.PROCESS_STATE.STOPPED)); + } + + @Test + public void testInvalidRedirectURI() throws Exception { + String[] args = {"../"}; + TestingProcessManager.TestingProcess process = TestingProcessManager.start(args); + assertNotNull(process.checkOrWaitForEvent(ProcessState.PROCESS_STATE.STARTED)); + + FeatureFlagTestContent.getInstance(process.getProcess()) + .setKeyValue(FeatureFlagTestContent.ENABLED_FEATURES, new EE_FEATURES[]{ + EE_FEATURES.SAML}); + + JsonObject createClientInput = new JsonObject(); + createClientInput.addProperty("spEntityId", "http://example.com/saml"); + createClientInput.addProperty("defaultRedirectURI", "http://localhost:3000/auth/callback/saml-mock"); + createClientInput.add("redirectURIs", new JsonArray()); + createClientInput.get("redirectURIs").getAsJsonArray().add("http://localhost:3000/auth/callback/saml-mock"); + + // Generate IdP metadata using MockSAML + MockSAML.KeyMaterial keyMaterial = MockSAML.generateSelfSignedKeyMaterial(); + String idpEntityId = "https://saml.example.com/entityid"; + String idpSsoUrl = "https://mocksaml.com/api/saml/sso"; + String metadataXML = MockSAML.generateIdpMetadataXML(idpEntityId, idpSsoUrl, keyMaterial.certificate); + String metadataXMLBase64 = java.util.Base64.getEncoder().encodeToString(metadataXML.getBytes(java.nio.charset.StandardCharsets.UTF_8)); + createClientInput.addProperty("metadataXML", metadataXMLBase64); + + JsonObject createResp = HttpRequestForTesting.sendJsonPUTRequest(process.getProcess(), "", + "http://localhost:3567/recipe/saml/clients", createClientInput, 1000, 1000, null, SemVer.v5_4.get(), "saml"); + assertEquals("OK", createResp.get("status").getAsString()); + String clientId = createResp.get("clientId").getAsString(); + + JsonObject body = new JsonObject(); + body.addProperty("clientId", clientId); + body.addProperty("redirectURI", "http://localhost:3000/another/callback"); + body.addProperty("acsURL", "http://localhost:3000/acs"); + + JsonObject resp = HttpRequestForTesting.sendJsonPOSTRequest(process.getProcess(), "", + "http://localhost:3567/recipe/saml/login", body, 1000, 1000, null, SemVer.v5_4.get(), "saml"); + assertEquals("INVALID_CLIENT_ERROR", resp.get("status").getAsString()); + + process.kill(); + assertNotNull(process.checkOrWaitForEvent(ProcessState.PROCESS_STATE.STOPPED)); + } + + @Test + public void testValidLoginRedirect() throws Exception { + String[] args = {"../"}; + TestingProcessManager.TestingProcess process = TestingProcessManager.start(args); + assertNotNull(process.checkOrWaitForEvent(ProcessState.PROCESS_STATE.STARTED)); + + FeatureFlagTestContent.getInstance(process.getProcess()) + .setKeyValue(FeatureFlagTestContent.ENABLED_FEATURES, new EE_FEATURES[]{ + EE_FEATURES.SAML}); + + // Prepare IdP metadata using MockSAML self-signed certificate + MockSAML.KeyMaterial keyMaterial = MockSAML.generateSelfSignedKeyMaterial(); + java.security.cert.X509Certificate cert = keyMaterial.certificate; + String idpEntityId = "https://saml.example.com/entityid"; + String idpSsoUrl = "https://mocksaml.com/api/saml/sso"; + String metadataXML = MockSAML.generateIdpMetadataXML(idpEntityId, idpSsoUrl, cert); + String metadataXMLBase64 = java.util.Base64.getEncoder().encodeToString(metadataXML.getBytes(java.nio.charset.StandardCharsets.UTF_8)); + + // Create client using metadataXML + JsonObject createClientInput = new JsonObject(); + createClientInput.addProperty("spEntityId", "http://example.com/saml"); + createClientInput.addProperty("defaultRedirectURI", "http://localhost:3000/auth/callback/saml-mock"); + createClientInput.add("redirectURIs", new JsonArray()); + createClientInput.get("redirectURIs").getAsJsonArray().add("http://localhost:3000/auth/callback/saml-mock"); + createClientInput.addProperty("metadataXML", metadataXMLBase64); + + JsonObject createResp = HttpRequestForTesting.sendJsonPUTRequest(process.getProcess(), "", + "http://localhost:3567/recipe/saml/clients", createClientInput, 1000, 1000, null, SemVer.v5_4.get(), "saml"); + assertEquals("OK", createResp.get("status").getAsString()); + String clientId = createResp.get("clientId").getAsString(); + + // Create login request with valid redirect URI + JsonObject body = new JsonObject(); + body.addProperty("clientId", clientId); + body.addProperty("redirectURI", "http://localhost:3000/auth/callback/saml-mock"); + body.addProperty("acsURL", "http://localhost:3000/acs"); + body.addProperty("state", "abc123"); + + JsonObject resp = HttpRequestForTesting.sendJsonPOSTRequest(process.getProcess(), "", + "http://localhost:3567/recipe/saml/login", body, 1000, 1000, null, SemVer.v5_4.get(), "saml"); + + // Verify response structure + assertEquals("OK", resp.get("status").getAsString()); + assertTrue(resp.has("ssoRedirectURI")); + String ssoRedirectURI = resp.get("ssoRedirectURI").getAsString(); + assertTrue(ssoRedirectURI.startsWith(idpSsoUrl + "?")); + assertTrue(ssoRedirectURI.contains("SAMLRequest=")); + assertTrue(ssoRedirectURI.contains("RelayState=")); + + process.kill(); + assertNotNull(process.checkOrWaitForEvent(ProcessState.PROCESS_STATE.STOPPED)); + } +} diff --git a/src/test/java/io/supertokens/test/saml/api/GetUserinfoTest5_4.java b/src/test/java/io/supertokens/test/saml/api/GetUserinfoTest5_4.java new file mode 100644 index 000000000..2509a860b --- /dev/null +++ b/src/test/java/io/supertokens/test/saml/api/GetUserinfoTest5_4.java @@ -0,0 +1,293 @@ +package io.supertokens.test.saml.api; + +import io.supertokens.featureflag.EE_FEATURES; +import io.supertokens.featureflag.FeatureFlagTestContent; +import org.junit.AfterClass; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.fail; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.TestRule; + +import com.google.gson.JsonObject; + +import io.supertokens.ProcessState; +import io.supertokens.test.TestingProcessManager; +import io.supertokens.test.Utils; +import io.supertokens.test.httpRequest.HttpRequestForTesting; +import io.supertokens.test.saml.MockSAML; +import io.supertokens.test.saml.SAMLTestUtils; +import io.supertokens.utils.SemVer; + +public class GetUserinfoTest5_4 { + @Rule + public TestRule watchman = Utils.getOnFailure(); + + @Rule + public TestRule retryFlaky = Utils.retryFlakyTest(); + + @AfterClass + public static void afterTesting() { + Utils.afterTesting(); + } + + @Before + public void beforeEach() { + Utils.reset(); + } + + @Test + public void testBadInput() throws Exception { + String[] args = {"../"}; + TestingProcessManager.TestingProcess process = TestingProcessManager.start(args); + assertNotNull(process.checkOrWaitForEvent(ProcessState.PROCESS_STATE.STARTED)); + + FeatureFlagTestContent.getInstance(process.getProcess()) + .setKeyValue(FeatureFlagTestContent.ENABLED_FEATURES, new EE_FEATURES[]{ + EE_FEATURES.SAML}); + + // Missing accessToken + { + JsonObject body = new JsonObject(); + body.addProperty("clientId", "some-client"); + + try { + HttpRequestForTesting.sendJsonPOSTRequest(process.getProcess(), "", + "http://localhost:3567/recipe/saml/user", body, 1000, 1000, null, SemVer.v5_4.get(), "saml"); + fail(); + } catch (io.supertokens.test.httpRequest.HttpResponseException e) { + assertEquals(400, e.statusCode); + assertEquals("Http error. Status Code: 400. Message: Field name 'accessToken' is invalid in JSON input", e.getMessage()); + } + } + + // Missing clientId + { + JsonObject body = new JsonObject(); + body.addProperty("accessToken", "some-access-token"); + + try { + HttpRequestForTesting.sendJsonPOSTRequest(process.getProcess(), "", + "http://localhost:3567/recipe/saml/user", body, 1000, 1000, null, SemVer.v5_4.get(), "saml"); + fail(); + } catch (io.supertokens.test.httpRequest.HttpResponseException e) { + assertEquals(400, e.statusCode); + assertEquals("Http error. Status Code: 400. Message: Field name 'clientId' is invalid in JSON input", e.getMessage()); + } + } + + process.kill(); + assertNotNull(process.checkOrWaitForEvent(ProcessState.PROCESS_STATE.STOPPED)); + } + + @Test + public void testInvalidAccessToken() throws Exception { + String[] args = {"../"}; + TestingProcessManager.TestingProcess process = TestingProcessManager.start(args); + assertNotNull(process.checkOrWaitForEvent(ProcessState.PROCESS_STATE.STARTED)); + + FeatureFlagTestContent.getInstance(process.getProcess()) + .setKeyValue(FeatureFlagTestContent.ENABLED_FEATURES, new EE_FEATURES[]{ + EE_FEATURES.SAML}); + + // Test with invalid/fake access token + { + JsonObject body = new JsonObject(); + body.addProperty("accessToken", "invalid-access-token-12345"); + body.addProperty("clientId", "test-client-id"); + + JsonObject response = HttpRequestForTesting.sendJsonPOSTRequest(process.getProcess(), "", + "http://localhost:3567/recipe/saml/user", body, 1000, 1000, null, SemVer.v5_4.get(), "saml"); + + assertEquals("INVALID_TOKEN_ERROR", response.get("status").getAsString()); + } + + process.kill(); + assertNotNull(process.checkOrWaitForEvent(ProcessState.PROCESS_STATE.STOPPED)); + } + + @Test + public void testValidTokenWithWrongClient() throws Exception { + String[] args = {"../"}; + TestingProcessManager.TestingProcess process = TestingProcessManager.start(args); + assertNotNull(process.checkOrWaitForEvent(ProcessState.PROCESS_STATE.STARTED)); + + FeatureFlagTestContent.getInstance(process.getProcess()) + .setKeyValue(FeatureFlagTestContent.ENABLED_FEATURES, new EE_FEATURES[]{ + EE_FEATURES.SAML}); + + // Create first client + String spEntityId1 = "http://example.com/saml"; + String defaultRedirectURI1 = "http://localhost:3000/auth/callback/saml-mock"; + String acsURL1 = "http://localhost:3000/acs"; + String idpEntityId1 = "https://saml.example.com/entityid"; + String idpSsoUrl1 = "https://mocksaml.com/api/saml/sso"; + + SAMLTestUtils.CreatedClientInfo clientInfo1 = SAMLTestUtils.createClientWithGeneratedMetadata( + process, + defaultRedirectURI1, + acsURL1, + idpEntityId1, + idpSsoUrl1 + ); + + // Create second client + String spEntityId2 = "http://example2.com/saml"; + String defaultRedirectURI2 = "http://localhost:3001/auth/callback/saml-mock"; + String acsURL2 = "http://localhost:3001/acs"; + String idpEntityId2 = "https://saml2.example.com/entityid"; + String idpSsoUrl2 = "https://mocksaml2.com/api/saml/sso"; + + SAMLTestUtils.CreatedClientInfo clientInfo2 = SAMLTestUtils.createClientWithGeneratedMetadata( + process, + defaultRedirectURI2, + acsURL2, + idpEntityId2, + idpSsoUrl2 + ); + + // Create a login request for client1 to generate a RelayState + String relayState = SAMLTestUtils.createLoginRequestAndGetRelayState( + process, + clientInfo1.clientId, + clientInfo1.defaultRedirectURI, + clientInfo1.acsURL, + "test-state" + ); + + // Generate a valid SAML Response for client1 + String samlResponseBase64 = MockSAML.generateSignedSAMLResponseBase64( + clientInfo1.idpEntityId, + "https://saml.supertokens.com", + clientInfo1.acsURL, + "user@example.com", + null, + relayState, + clientInfo1.keyMaterial, + 300 + ); + + // Process the callback for client1 to get a valid access token + JsonObject callbackBody = new JsonObject(); + callbackBody.addProperty("samlResponse", samlResponseBase64); + callbackBody.addProperty("relayState", relayState); + + JsonObject callbackResp = HttpRequestForTesting.sendJsonPOSTRequest(process.getProcess(), "", + "http://localhost:3567/recipe/saml/callback", callbackBody, 1000, 1000, null, SemVer.v5_4.get(), "saml"); + + assertEquals("OK", callbackResp.get("status").getAsString()); + + // Extract the access token from the redirect URI + String redirectURI = callbackResp.get("redirectURI").getAsString(); + String accessToken = extractAccessTokenFromRedirectURI(redirectURI); + + // Now try to use the valid access token from client1 with client2's clientId + JsonObject userInfoBody = new JsonObject(); + userInfoBody.addProperty("accessToken", accessToken); + userInfoBody.addProperty("clientId", clientInfo2.clientId); // Wrong client ID + + JsonObject userInfoResp = HttpRequestForTesting.sendJsonPOSTRequest(process.getProcess(), "", + "http://localhost:3567/recipe/saml/user", userInfoBody, 1000, 1000, null, SemVer.v5_4.get(), "saml"); + + assertEquals("INVALID_TOKEN_ERROR", userInfoResp.get("status").getAsString()); + + process.kill(); + assertNotNull(process.checkOrWaitForEvent(ProcessState.PROCESS_STATE.STOPPED)); + } + + @Test + public void testValidTokenWithCorrectClient() throws Exception { + String[] args = {"../"}; + TestingProcessManager.TestingProcess process = TestingProcessManager.start(args); + assertNotNull(process.checkOrWaitForEvent(ProcessState.PROCESS_STATE.STARTED)); + + FeatureFlagTestContent.getInstance(process.getProcess()) + .setKeyValue(FeatureFlagTestContent.ENABLED_FEATURES, new EE_FEATURES[]{ + EE_FEATURES.SAML}); + + // Create SAML client + String spEntityId = "http://example.com/saml"; + String defaultRedirectURI = "http://localhost:3000/auth/callback/saml-mock"; + String acsURL = "http://localhost:3000/acs"; + String idpEntityId = "https://saml.example.com/entityid"; + String idpSsoUrl = "https://mocksaml.com/api/saml/sso"; + + SAMLTestUtils.CreatedClientInfo clientInfo = SAMLTestUtils.createClientWithGeneratedMetadata( + process, + defaultRedirectURI, + acsURL, + idpEntityId, + idpSsoUrl + ); + + // Create a login request to generate a RelayState + String relayState = SAMLTestUtils.createLoginRequestAndGetRelayState( + process, + clientInfo.clientId, + clientInfo.defaultRedirectURI, + clientInfo.acsURL, + "test-state" + ); + + // Generate a valid SAML Response + String samlResponseBase64 = MockSAML.generateSignedSAMLResponseBase64( + clientInfo.idpEntityId, + "https://saml.supertokens.com", + clientInfo.acsURL, + "user@example.com", + null, + relayState, + clientInfo.keyMaterial, + 300 + ); + + // Process the callback to get a valid access token + JsonObject callbackBody = new JsonObject(); + callbackBody.addProperty("samlResponse", samlResponseBase64); + callbackBody.addProperty("relayState", relayState); + + JsonObject callbackResp = HttpRequestForTesting.sendJsonPOSTRequest(process.getProcess(), "", + "http://localhost:3567/recipe/saml/callback", callbackBody, 1000, 1000, null, SemVer.v5_4.get(), "saml"); + + assertEquals("OK", callbackResp.get("status").getAsString()); + + // Extract the access token from the redirect URI + String redirectURI = callbackResp.get("redirectURI").getAsString(); + String accessToken = extractAccessTokenFromRedirectURI(redirectURI); + + // Use the valid access token with the correct client ID + JsonObject userInfoBody = new JsonObject(); + userInfoBody.addProperty("accessToken", accessToken); + userInfoBody.addProperty("clientId", clientInfo.clientId); + + JsonObject userInfoResp = HttpRequestForTesting.sendJsonPOSTRequest(process.getProcess(), "", + "http://localhost:3567/recipe/saml/user", userInfoBody, 1000, 1000, null, SemVer.v5_4.get(), "saml"); + + // Verify successful response + assertEquals("OK", userInfoResp.get("status").getAsString()); + assertNotNull(userInfoResp.get("sub")); + assertEquals("user@example.com", userInfoResp.get("sub").getAsString()); + + process.kill(); + assertNotNull(process.checkOrWaitForEvent(ProcessState.PROCESS_STATE.STOPPED)); + } + + private String extractAccessTokenFromRedirectURI(String redirectURI) { + // Extract the 'code' parameter from the redirect URI + // Format: http://localhost:3000/auth/callback/saml-mock?code=some-uuid&state=test-state + int codeIndex = redirectURI.indexOf("code="); + if (codeIndex == -1) { + throw new IllegalStateException("Code parameter not found in redirect URI: " + redirectURI); + } + + String codePart = redirectURI.substring(codeIndex + "code=".length()); + int ampIndex = codePart.indexOf('&'); + if (ampIndex != -1) { + codePart = codePart.substring(0, ampIndex); + } + + return java.net.URLDecoder.decode(codePart, java.nio.charset.StandardCharsets.UTF_8); + } +} diff --git a/src/test/java/io/supertokens/test/saml/api/HandleSAMLCallbackTest5_4.java b/src/test/java/io/supertokens/test/saml/api/HandleSAMLCallbackTest5_4.java new file mode 100644 index 000000000..f49c01d63 --- /dev/null +++ b/src/test/java/io/supertokens/test/saml/api/HandleSAMLCallbackTest5_4.java @@ -0,0 +1,455 @@ +package io.supertokens.test.saml.api; + +import io.supertokens.featureflag.EE_FEATURES; +import io.supertokens.featureflag.FeatureFlagTestContent; +import org.junit.AfterClass; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertTrue; +import static org.junit.Assert.fail; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.TestRule; + +import com.google.gson.JsonObject; + +import io.supertokens.ProcessState; +import io.supertokens.test.TestingProcessManager; +import io.supertokens.test.Utils; +import io.supertokens.test.httpRequest.HttpRequestForTesting; +import io.supertokens.test.saml.MockSAML; +import io.supertokens.test.saml.SAMLTestUtils; +import io.supertokens.utils.SemVer; + +public class HandleSAMLCallbackTest5_4 { + + @Rule + public TestRule watchman = Utils.getOnFailure(); + + @Rule + public TestRule retryFlaky = Utils.retryFlakyTest(); + + @AfterClass + public static void afterTesting() { + Utils.afterTesting(); + } + + @Before + public void beforeEach() { + Utils.reset(); + } + + @Test + public void testBadInput() throws Exception { + String[] args = {"../"}; + TestingProcessManager.TestingProcess process = TestingProcessManager.start(args); + assertNotNull(process.checkOrWaitForEvent(ProcessState.PROCESS_STATE.STARTED)); + + FeatureFlagTestContent.getInstance(process.getProcess()) + .setKeyValue(FeatureFlagTestContent.ENABLED_FEATURES, new EE_FEATURES[]{ + EE_FEATURES.SAML}); + + // Missing SAMLResponse + { + JsonObject body = new JsonObject(); + try { + HttpRequestForTesting.sendJsonPOSTRequest(process.getProcess(), "", + "http://localhost:3567/recipe/saml/callback", body, 1000, 1000, null, SemVer.v5_4.get(), "saml"); + fail(); + } catch (io.supertokens.test.httpRequest.HttpResponseException e) { + assertEquals(400, e.statusCode); + assertEquals("Http error. Status Code: 400. Message: Field name 'samlResponse' is invalid in JSON input", e.getMessage()); + } + } + + // Empty SAMLResponse (base64 of empty string is empty) + { + JsonObject body = new JsonObject(); + body.addProperty("samlResponse", ""); + try { + HttpRequestForTesting.sendJsonPOSTRequest(process.getProcess(), "", + "http://localhost:3567/recipe/saml/callback", body, 1000, 1000, null, SemVer.v5_4.get(), "saml"); + fail(); + } catch (io.supertokens.test.httpRequest.HttpResponseException e) { + assertEquals(400, e.statusCode); + assertEquals("Http error. Status Code: 400. Message: Invalid or malformed SAML response input", e.getMessage()); + } + } + + // Non-XML SAMLResponse (base64 of 'hello') + { + String nonXmlBase64 = java.util.Base64.getEncoder().encodeToString("hello".getBytes(java.nio.charset.StandardCharsets.UTF_8)); + JsonObject body = new JsonObject(); + body.addProperty("samlResponse", nonXmlBase64); + try { + HttpRequestForTesting.sendJsonPOSTRequest(process.getProcess(), "", + "http://localhost:3567/recipe/saml/callback", body, 1000, 1000, null, SemVer.v5_4.get(), "saml"); + fail(); + } catch (io.supertokens.test.httpRequest.HttpResponseException e) { + assertEquals(400, e.statusCode); + assertEquals("Http error. Status Code: 400. Message: Invalid or malformed SAML response input", e.getMessage()); + } + } + + // Arbitrary XML as SAMLResponse (not a SAML Response element) + { + String xml = ""; + String xmlBase64 = java.util.Base64.getEncoder().encodeToString(xml.getBytes(java.nio.charset.StandardCharsets.UTF_8)); + JsonObject body = new JsonObject(); + body.addProperty("samlResponse", xmlBase64); + try { + HttpRequestForTesting.sendJsonPOSTRequest(process.getProcess(), "", + "http://localhost:3567/recipe/saml/callback", body, 1000, 1000, null, SemVer.v5_4.get(), "saml"); + fail(); + } catch (io.supertokens.test.httpRequest.HttpResponseException e) { + assertEquals(400, e.statusCode); + assertEquals("Http error. Status Code: 400. Message: Invalid or malformed SAML response input", e.getMessage()); + } + } + + process.kill(); + assertNotNull(process.checkOrWaitForEvent(ProcessState.PROCESS_STATE.STOPPED)); + } + + @Test + public void testNonExistingRelayState() throws Exception { + String[] args = {"../"}; + TestingProcessManager.TestingProcess process = TestingProcessManager.start(args); + assertNotNull(process.checkOrWaitForEvent(ProcessState.PROCESS_STATE.STARTED)); + + FeatureFlagTestContent.getInstance(process.getProcess()) + .setKeyValue(FeatureFlagTestContent.ENABLED_FEATURES, new EE_FEATURES[]{ + EE_FEATURES.SAML}); + + String spEntityId = "http://example.com/saml"; + String defaultRedirectURI = "http://localhost:3000/auth/callback/saml-mock"; + String acsURL = "http://localhost:3000/acs"; + String idpEntityId = "https://saml.example.com/entityid"; + String idpSsoUrl = "https://mocksaml.com/api/saml/sso"; + + SAMLTestUtils.CreatedClientInfo clientInfo = SAMLTestUtils.createClientWithGeneratedMetadata( + process, + defaultRedirectURI, + acsURL, + idpEntityId, + idpSsoUrl + ); + + String samlResponseBase64 = MockSAML.generateSignedSAMLResponseBase64( + clientInfo.idpEntityId, + "https://saml.supertokens.com", + clientInfo.acsURL, + "user@example.com", + null, + null, + clientInfo.keyMaterial, + 300 + ); + + JsonObject body = new JsonObject(); + body.addProperty("samlResponse", samlResponseBase64); + body.addProperty("relayState", "this-does-not-exist"); + + JsonObject resp = HttpRequestForTesting.sendJsonPOSTRequest(process.getProcess(), "", + "http://localhost:3567/recipe/saml/callback", body, 1000, 1000, null, SemVer.v5_4.get(), "saml"); + + assertEquals("INVALID_RELAY_STATE_ERROR", resp.get("status").getAsString()); + + process.kill(); + assertNotNull(process.checkOrWaitForEvent(ProcessState.PROCESS_STATE.STOPPED)); + } + + @Test + public void testWrongAudienceInSAMLResponse() throws Exception { + String[] args = {"../"}; + TestingProcessManager.TestingProcess process = TestingProcessManager.start(args); + assertNotNull(process.checkOrWaitForEvent(ProcessState.PROCESS_STATE.STARTED)); + + FeatureFlagTestContent.getInstance(process.getProcess()) + .setKeyValue(FeatureFlagTestContent.ENABLED_FEATURES, new EE_FEATURES[]{ + EE_FEATURES.SAML}); + + String spEntityId = "http://example.com/saml"; + String defaultRedirectURI = "http://localhost:3000/auth/callback/saml-mock"; + String acsURL = "http://localhost:3000/acs"; + String idpEntityId = "https://saml.example.com/entityid"; + String idpSsoUrl = "https://mocksaml.com/api/saml/sso"; + + SAMLTestUtils.CreatedClientInfo clientInfo = SAMLTestUtils.createClientWithGeneratedMetadata( + process, + defaultRedirectURI, + acsURL, + idpEntityId, + idpSsoUrl + ); + + // Audience that does not match the client's SP Entity ID + String wrongAudience = "http://wrong.example.com/sp"; + + // Create a login request to generate a RelayState, then use it during callback + String relayState = SAMLTestUtils.createLoginRequestAndGetRelayState( + process, + clientInfo.clientId, + clientInfo.defaultRedirectURI, + clientInfo.acsURL, + "test-state" + ); + + String samlResponseBase64 = MockSAML.generateSignedSAMLResponseBase64( + clientInfo.idpEntityId, + wrongAudience, + clientInfo.acsURL, + "user@example.com", + null, + relayState, + clientInfo.keyMaterial, + 300 + ); + + JsonObject body = new JsonObject(); + body.addProperty("samlResponse", samlResponseBase64); + body.addProperty("relayState", relayState); + + JsonObject resp = HttpRequestForTesting.sendJsonPOSTRequest(process.getProcess(), "", + "http://localhost:3567/recipe/saml/callback", body, 1000, 1000, null, SemVer.v5_4.get(), "saml"); + + assertEquals("SAML_RESPONSE_VERIFICATION_FAILED_ERROR", resp.get("status").getAsString()); + + process.kill(); + assertNotNull(process.checkOrWaitForEvent(ProcessState.PROCESS_STATE.STOPPED)); + } + + @Test + public void testWrongSignatureInSAMLResponse() throws Exception { + String[] args = {"../"}; + TestingProcessManager.TestingProcess process = TestingProcessManager.start(args); + assertNotNull(process.checkOrWaitForEvent(ProcessState.PROCESS_STATE.STARTED)); + + FeatureFlagTestContent.getInstance(process.getProcess()) + .setKeyValue(FeatureFlagTestContent.ENABLED_FEATURES, new EE_FEATURES[]{ + EE_FEATURES.SAML}); + + String spEntityId = "http://example.com/saml"; + String defaultRedirectURI = "http://localhost:3000/auth/callback/saml-mock"; + String acsURL = "http://localhost:3000/acs"; + String idpEntityId = "https://saml.example.com/entityid"; + String idpSsoUrl = "https://mocksaml.com/api/saml/sso"; + + SAMLTestUtils.CreatedClientInfo clientInfo = SAMLTestUtils.createClientWithGeneratedMetadata( + process, + defaultRedirectURI, + acsURL, + idpEntityId, + idpSsoUrl + ); + + // Create a login request to generate a RelayState, then use it during callback + String relayState = SAMLTestUtils.createLoginRequestAndGetRelayState( + process, + clientInfo.clientId, + clientInfo.defaultRedirectURI, + clientInfo.acsURL, + "test-state" + ); + + // Generate a different key material to sign the assertion with the wrong certificate + MockSAML.KeyMaterial wrongKeyMaterial = MockSAML.generateSelfSignedKeyMaterial(); + + String samlResponseBase64 = MockSAML.generateSignedSAMLResponseBase64( + clientInfo.idpEntityId, + "https://saml.supertokens.com", + clientInfo.acsURL, + "user@example.com", + null, + relayState, + wrongKeyMaterial, + 300 + ); + + JsonObject body = new JsonObject(); + body.addProperty("samlResponse", samlResponseBase64); + body.addProperty("relayState", relayState); + + JsonObject resp = HttpRequestForTesting.sendJsonPOSTRequest(process.getProcess(), "", + "http://localhost:3567/recipe/saml/callback", body, 1000, 1000, null, SemVer.v5_4.get(), "saml"); + + assertEquals("SAML_RESPONSE_VERIFICATION_FAILED_ERROR", resp.get("status").getAsString()); + + process.kill(); + assertNotNull(process.checkOrWaitForEvent(ProcessState.PROCESS_STATE.STOPPED)); + } + + @Test + public void testClientDeletedBeforeProcessingCallbackResultsInInvalidClient() throws Exception { + String[] args = {"../"}; + TestingProcessManager.TestingProcess process = TestingProcessManager.start(args); + assertNotNull(process.checkOrWaitForEvent(ProcessState.PROCESS_STATE.STARTED)); + + FeatureFlagTestContent.getInstance(process.getProcess()) + .setKeyValue(FeatureFlagTestContent.ENABLED_FEATURES, new EE_FEATURES[]{ + EE_FEATURES.SAML}); + + String spEntityId = "http://example.com/saml"; + String defaultRedirectURI = "http://localhost:3000/auth/callback/saml-mock"; + String acsURL = "http://localhost:3000/acs"; + String idpEntityId = "https://saml.example.com/entityid"; + String idpSsoUrl = "https://mocksaml.com/api/saml/sso"; + + SAMLTestUtils.CreatedClientInfo clientInfo = SAMLTestUtils.createClientWithGeneratedMetadata( + process, + defaultRedirectURI, + acsURL, + idpEntityId, + idpSsoUrl + ); + + // Create a login request to generate a RelayState + String relayState = SAMLTestUtils.createLoginRequestAndGetRelayState( + process, + clientInfo.clientId, + clientInfo.defaultRedirectURI, + clientInfo.acsURL, + "test-state" + ); + + // Create a valid SAML Response for this client and the relayState + String samlResponseBase64 = MockSAML.generateSignedSAMLResponseBase64( + clientInfo.idpEntityId, + "https://saml.supertokens.com", + clientInfo.acsURL, + "user@example.com", + null, + relayState, + clientInfo.keyMaterial, + 300 + ); + + // Now delete the client before processing the callback + JsonObject removeBody = new JsonObject(); + removeBody.addProperty("clientId", clientInfo.clientId); + HttpRequestForTesting.sendJsonPOSTRequest(process.getProcess(), "", + "http://localhost:3567/recipe/saml/clients/remove", removeBody, 1000, 1000, null, SemVer.v5_4.get(), "saml"); + + // Process the callback; should result in INVALID_CLIENT_ERROR + JsonObject body = new JsonObject(); + body.addProperty("samlResponse", samlResponseBase64); + body.addProperty("relayState", relayState); + + JsonObject resp = HttpRequestForTesting.sendJsonPOSTRequest(process.getProcess(), "", + "http://localhost:3567/recipe/saml/callback", body, 1000, 1000, null, SemVer.v5_4.get(), "saml"); + + assertEquals("INVALID_CLIENT_ERROR", resp.get("status").getAsString()); + + process.kill(); + assertNotNull(process.checkOrWaitForEvent(ProcessState.PROCESS_STATE.STOPPED)); + } + + @Test + public void testIDPFlowWithIDPDisallowedOnClient() throws Exception { + String[] args = {"../"}; + TestingProcessManager.TestingProcess process = TestingProcessManager.start(args); + assertNotNull(process.checkOrWaitForEvent(ProcessState.PROCESS_STATE.STARTED)); + + FeatureFlagTestContent.getInstance(process.getProcess()) + .setKeyValue(FeatureFlagTestContent.ENABLED_FEATURES, new EE_FEATURES[]{ + EE_FEATURES.SAML}); + + String spEntityId = "http://example.com/saml"; + String defaultRedirectURI = "http://localhost:3000/auth/callback/saml-mock"; + String acsURL = "http://localhost:3000/acs"; + String idpEntityId = "https://saml.example.com/entityid"; + String idpSsoUrl = "https://mocksaml.com/api/saml/sso"; + + // Create a client with allowIDPInitiatedLogin = false (default) + SAMLTestUtils.CreatedClientInfo clientInfo = SAMLTestUtils.createClientWithGeneratedMetadata( + process, + defaultRedirectURI, + acsURL, + idpEntityId, + idpSsoUrl, + false // allowIDPInitiatedLogin = false + ); + + // Generate an IDP-initiated SAML response (no RelayState, no InResponseTo) + String samlResponseBase64 = MockSAML.generateSignedSAMLResponseBase64( + clientInfo.idpEntityId, + "https://saml.supertokens.com", + clientInfo.acsURL, + "user@example.com", + null, + null, // no inResponseTo for IDP-initiated + clientInfo.keyMaterial, + 300 + ); + + JsonObject body = new JsonObject(); + body.addProperty("samlResponse", samlResponseBase64); + // Intentionally omit relayState to simulate IDP-initiated login + + JsonObject resp = HttpRequestForTesting.sendJsonPOSTRequest(process.getProcess(), "", + "http://localhost:3567/recipe/saml/callback", body, 1000, 1000, null, SemVer.v5_4.get(), "saml"); + + assertEquals("IDP_LOGIN_DISALLOWED_ERROR", resp.get("status").getAsString()); + + process.kill(); + assertNotNull(process.checkOrWaitForEvent(ProcessState.PROCESS_STATE.STOPPED)); + } + + @Test + public void testIDPFlow() throws Exception { + String[] args = {"../"}; + TestingProcessManager.TestingProcess process = TestingProcessManager.start(args); + assertNotNull(process.checkOrWaitForEvent(ProcessState.PROCESS_STATE.STARTED)); + + FeatureFlagTestContent.getInstance(process.getProcess()) + .setKeyValue(FeatureFlagTestContent.ENABLED_FEATURES, new EE_FEATURES[]{ + EE_FEATURES.SAML}); + + String spEntityId = "http://example.com/saml"; + String defaultRedirectURI = "http://localhost:3000/auth/callback/saml-mock"; + String acsURL = "http://localhost:3000/acs"; + String idpEntityId = "https://saml.example.com/entityid"; + String idpSsoUrl = "https://mocksaml.com/api/saml/sso"; + + // Create a client with allowIDPInitiatedLogin = true + SAMLTestUtils.CreatedClientInfo clientInfo = SAMLTestUtils.createClientWithGeneratedMetadata( + process, + defaultRedirectURI, + acsURL, + idpEntityId, + idpSsoUrl, + true // allowIDPInitiatedLogin = true + ); + + // Generate an IDP-initiated SAML response (no RelayState, no InResponseTo) + String samlResponseBase64 = MockSAML.generateSignedSAMLResponseBase64( + clientInfo.idpEntityId, + "https://saml.supertokens.com", + clientInfo.acsURL, + "user@example.com", + null, + null, // no inResponseTo for IDP-initiated + clientInfo.keyMaterial, + 300 + ); + + JsonObject body = new JsonObject(); + body.addProperty("samlResponse", samlResponseBase64); + // Intentionally omit relayState to simulate IDP-initiated login + + JsonObject resp = HttpRequestForTesting.sendJsonPOSTRequest(process.getProcess(), "", + "http://localhost:3567/recipe/saml/callback", body, 1000, 1000, null, SemVer.v5_4.get(), "saml"); + + assertEquals("OK", resp.get("status").getAsString()); + String redirectURI = resp.get("redirectURI").getAsString(); + // Check that the redirectURI contains the code query parameter + assertNotNull(redirectURI); + assertTrue("Redirect URI should contain code parameter", redirectURI.contains("code=")); + // Check it starts with the default redirect URI + assertTrue("Redirect URI should start with default redirect URI", redirectURI.startsWith(defaultRedirectURI)); + + process.kill(); + assertNotNull(process.checkOrWaitForEvent(ProcessState.PROCESS_STATE.STOPPED)); + } +} diff --git a/src/test/java/io/supertokens/test/saml/api/LegacyTest5_4.java b/src/test/java/io/supertokens/test/saml/api/LegacyTest5_4.java new file mode 100644 index 000000000..8850a36a3 --- /dev/null +++ b/src/test/java/io/supertokens/test/saml/api/LegacyTest5_4.java @@ -0,0 +1,733 @@ +package io.supertokens.test.saml.api; + +import java.io.IOException; +import java.util.HashMap; +import java.util.Map; + +import io.supertokens.featureflag.EE_FEATURES; +import io.supertokens.featureflag.FeatureFlagTestContent; +import org.junit.AfterClass; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.fail; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.TestRule; + +import com.google.gson.JsonObject; + +import io.supertokens.ProcessState; +import io.supertokens.test.TestingProcessManager; +import io.supertokens.test.Utils; +import io.supertokens.test.httpRequest.HttpRequestForTesting; +import io.supertokens.test.httpRequest.HttpResponseException; +import io.supertokens.test.saml.MockSAML; +import io.supertokens.test.saml.SAMLTestUtils; +import io.supertokens.utils.SemVer; + +public class LegacyTest5_4 { + + private static final String TEST_REDIRECT_URI = "http://localhost:3000/auth/callback/saml-mock"; + + @Rule + public TestRule watchman = Utils.getOnFailure(); + + @Rule + public TestRule retryFlaky = Utils.retryFlakyTest(); + + @AfterClass + public static void afterTesting() { + Utils.afterTesting(); + } + + @Before + public void beforeEach() throws IOException { + Utils.reset(); + // Set the legacy ACS URL for testing + Utils.setValueInConfig("saml_legacy_acs_url", "http://localhost:3567/recipe/saml/legacy/callback"); + } + + // ========== LegacyAuthorizeAPI Tests ========== + + @Test + public void testLegacyAuthorizeBadInput() throws Exception { + String[] args = {"../"}; + TestingProcessManager.TestingProcess process = TestingProcessManager.start(args); + assertNotNull(process.checkOrWaitForEvent(ProcessState.PROCESS_STATE.STARTED)); + + FeatureFlagTestContent.getInstance(process.getProcess()) + .setKeyValue(FeatureFlagTestContent.ENABLED_FEATURES, new EE_FEATURES[]{ + EE_FEATURES.SAML}); + + // Missing client_id + { + Map params = new HashMap<>(); + params.put("redirect_uri", TEST_REDIRECT_URI); + params.put("state", "test-state"); + + try { + HttpRequestForTesting.sendGETRequest(process.getProcess(), "", + "http://localhost:3567/recipe/saml/legacy/authorize", params, 1000, 1000, null, SemVer.v5_4.get(), "saml"); + fail(); + } catch (io.supertokens.test.httpRequest.HttpResponseException e) { + assertEquals(400, e.statusCode); + assertEquals("Http error. Status Code: 400. Message: Field name 'client_id' is missing in GET request", e.getMessage()); + } + } + + // Missing redirect_uri + { + Map params = new HashMap<>(); + params.put("client_id", "test-client"); + params.put("state", "test-state"); + + try { + HttpRequestForTesting.sendGETRequest(process.getProcess(), "", + "http://localhost:3567/recipe/saml/legacy/authorize", params, 1000, 1000, null, SemVer.v5_4.get(), "saml"); + fail(); + } catch (io.supertokens.test.httpRequest.HttpResponseException e) { + assertEquals(400, e.statusCode); + assertEquals("Http error. Status Code: 400. Message: Field name 'redirect_uri' is missing in GET request", e.getMessage()); + } + } + + process.kill(); + assertNotNull(process.checkOrWaitForEvent(ProcessState.PROCESS_STATE.STOPPED)); + } + + @Test + public void testLegacyAuthorizeInvalidClient() throws Exception { + String[] args = {"../"}; + TestingProcessManager.TestingProcess process = TestingProcessManager.start(args); + assertNotNull(process.checkOrWaitForEvent(ProcessState.PROCESS_STATE.STARTED)); + + FeatureFlagTestContent.getInstance(process.getProcess()) + .setKeyValue(FeatureFlagTestContent.ENABLED_FEATURES, new EE_FEATURES[]{ + EE_FEATURES.SAML}); + + // Test with non-existent client_id + Map params = new HashMap<>(); + params.put("client_id", "non-existent-client"); + params.put("redirect_uri", TEST_REDIRECT_URI); + params.put("state", "test-state"); + + JsonObject response = HttpRequestForTesting.sendGETRequest(process.getProcess(), "", + "http://localhost:3567/recipe/saml/legacy/authorize", params, 1000, 1000, null, SemVer.v5_4.get(), "saml"); + + assertEquals("INVALID_CLIENT_ERROR", response.get("status").getAsString()); + + process.kill(); + assertNotNull(process.checkOrWaitForEvent(ProcessState.PROCESS_STATE.STOPPED)); + } + + @Test + public void testLegacyAuthorizeValidClient() throws Exception { + String[] args = {"../"}; + TestingProcessManager.TestingProcess process = TestingProcessManager.start(args); + assertNotNull(process.checkOrWaitForEvent(ProcessState.PROCESS_STATE.STARTED)); + + FeatureFlagTestContent.getInstance(process.getProcess()) + .setKeyValue(FeatureFlagTestContent.ENABLED_FEATURES, new EE_FEATURES[]{ + EE_FEATURES.SAML}); + + // Create SAML client + String spEntityId = "http://example.com/saml"; + String defaultRedirectURI = "http://localhost:3000/auth/callback/saml-mock"; + String acsURL = "http://localhost:3000/acs"; + String idpEntityId = "https://saml.example.com/entityid"; + String idpSsoUrl = "https://mocksaml.com/api/saml/sso"; + + SAMLTestUtils.CreatedClientInfo clientInfo = SAMLTestUtils.createClientWithGeneratedMetadata( + process, + defaultRedirectURI, + acsURL, + idpEntityId, + idpSsoUrl + ); + + // Test valid authorization request + String redirectURI = TEST_REDIRECT_URI; // Use the same redirect URI as configured in the client + String state = "test-state-123"; + + // Create query parameters map + Map params = new HashMap<>(); + params.put("client_id", clientInfo.clientId); + params.put("redirect_uri", redirectURI); + params.put("state", state); + + // This should redirect to SSO URL, so we expect a 307 redirect + try { + HttpRequestForTesting.sendGETRequest(process.getProcess(), "", + "http://localhost:3567/recipe/saml/legacy/authorize", params, 1000, 1000, null, SemVer.v5_4.get(), "saml"); + fail("Expected redirect response"); + } catch (io.supertokens.test.httpRequest.HttpResponseException e) { + assertEquals(307, e.statusCode); + // Verify the redirect URL contains expected parameters + String location = e.getMessage(); + assertNotNull(location); + assertNotNull("Location header should contain SSO URL", location); + } + + process.kill(); + assertNotNull(process.checkOrWaitForEvent(ProcessState.PROCESS_STATE.STOPPED)); + } + + // ========== LegacyCallbackAPI Tests ========== + + @Test + public void testLegacyCallbackBadInput() throws Exception { + String[] args = {"../"}; + TestingProcessManager.TestingProcess process = TestingProcessManager.start(args); + assertNotNull(process.checkOrWaitForEvent(ProcessState.PROCESS_STATE.STARTED)); + + FeatureFlagTestContent.getInstance(process.getProcess()) + .setKeyValue(FeatureFlagTestContent.ENABLED_FEATURES, new EE_FEATURES[]{ + EE_FEATURES.SAML}); + + // Missing SAMLResponse + { + try { + HttpRequestForTesting.sendFormDataPOSTRequest(process.getProcess(), "", + "http://localhost:3567/recipe/saml/legacy/callback", new JsonObject(), 1000, 1000, null, SemVer.v5_4.get(), "saml"); + fail(); + } catch (io.supertokens.test.httpRequest.HttpResponseException e) { + assertEquals(400, e.statusCode); + assertEquals("Http error. Status Code: 400. Message: Missing form field: SAMLResponse", e.getMessage()); + } + } + + // Empty SAMLResponse + { + JsonObject formData = new JsonObject(); + formData.addProperty("SAMLResponse", ""); + try { + HttpRequestForTesting.sendFormDataPOSTRequest(process.getProcess(), "", + "http://localhost:3567/recipe/saml/legacy/callback", formData, 1000, 1000, null, SemVer.v5_4.get(), "saml"); + fail(); + } catch (io.supertokens.test.httpRequest.HttpResponseException e) { + assertEquals(400, e.statusCode); + assertEquals("Http error. Status Code: 400. Message: Missing form field: SAMLResponse", e.getMessage()); + } + } + + process.kill(); + assertNotNull(process.checkOrWaitForEvent(ProcessState.PROCESS_STATE.STOPPED)); + } + + @Test + public void testLegacyCallbackInvalidRelayState() throws Exception { + String[] args = {"../"}; + TestingProcessManager.TestingProcess process = TestingProcessManager.start(args); + assertNotNull(process.checkOrWaitForEvent(ProcessState.PROCESS_STATE.STARTED)); + + FeatureFlagTestContent.getInstance(process.getProcess()) + .setKeyValue(FeatureFlagTestContent.ENABLED_FEATURES, new EE_FEATURES[]{ + EE_FEATURES.SAML}); + + String spEntityId = "http://example.com/saml"; + String defaultRedirectURI = "http://localhost:3000/auth/callback/saml-mock"; + String acsURL = "http://localhost:3000/acs"; + String idpEntityId = "https://saml.example.com/entityid"; + String idpSsoUrl = "https://mocksaml.com/api/saml/sso"; + + SAMLTestUtils.CreatedClientInfo clientInfo = SAMLTestUtils.createClientWithGeneratedMetadata( + process, + defaultRedirectURI, + acsURL, + idpEntityId, + idpSsoUrl + ); + + String samlResponseBase64 = MockSAML.generateSignedSAMLResponseBase64( + clientInfo.idpEntityId, + "https://saml.supertokens.com", + clientInfo.acsURL, + "user@example.com", + null, + null, + clientInfo.keyMaterial, + 300 + ); + + JsonObject formData = new JsonObject(); + formData.addProperty("SAMLResponse", samlResponseBase64); + formData.addProperty("RelayState", "invalid-relay-state"); + + try { + String response = HttpRequestForTesting.sendFormDataPOSTRequest(process.getProcess(), "", + "http://localhost:3567/recipe/saml/legacy/callback", formData, 1000, 1000, null, SemVer.v5_4.get(), "saml"); + + } catch (HttpResponseException e) { + assertEquals(400, e.statusCode); + assertEquals("Http error. Status Code: 400. Message: INVALID_RELAY_STATE_ERROR", e.getMessage()); + } + + process.kill(); + assertNotNull(process.checkOrWaitForEvent(ProcessState.PROCESS_STATE.STOPPED)); + } + + @Test + public void testLegacyCallbackValidResponse() throws Exception { + String[] args = {"../"}; + TestingProcessManager.TestingProcess process = TestingProcessManager.start(args); + assertNotNull(process.checkOrWaitForEvent(ProcessState.PROCESS_STATE.STARTED)); + + FeatureFlagTestContent.getInstance(process.getProcess()) + .setKeyValue(FeatureFlagTestContent.ENABLED_FEATURES, new EE_FEATURES[]{ + EE_FEATURES.SAML}); + + String spEntityId = "http://example.com/saml"; + String defaultRedirectURI = "http://localhost:3000/auth/callback/saml-mock"; + String acsURL = "http://localhost:3000/acs"; + String idpEntityId = "https://saml.example.com/entityid"; + String idpSsoUrl = "https://mocksaml.com/api/saml/sso"; + + SAMLTestUtils.CreatedClientInfo clientInfo = SAMLTestUtils.createClientWithGeneratedMetadata( + process, + defaultRedirectURI, + acsURL, + idpEntityId, + idpSsoUrl + ); + + // Create a login request to generate a RelayState + String relayState = SAMLTestUtils.createLoginRequestAndGetRelayState( + process, + clientInfo.clientId, + clientInfo.defaultRedirectURI, + clientInfo.acsURL, + "test-state" + ); + + String samlResponseBase64 = MockSAML.generateSignedSAMLResponseBase64( + clientInfo.idpEntityId, + "https://saml.supertokens.com", + clientInfo.acsURL, + "user@example.com", + null, + relayState, + clientInfo.keyMaterial, + 300 + ); + + JsonObject formData = new JsonObject(); + formData.addProperty("SAMLResponse", samlResponseBase64); + formData.addProperty("RelayState", relayState); + + // This should redirect to the callback URL with authorization code + try { + HttpRequestForTesting.sendFormDataPOSTRequest(process.getProcess(), "", + "http://localhost:3567/recipe/saml/legacy/callback", formData, 1000, 1000, null, SemVer.v5_4.get(), "saml", false); + fail("Expected redirect response"); + } catch (io.supertokens.test.httpRequest.HttpResponseException e) { + assertEquals(302, e.statusCode); + String location = e.getMessage(); + assertNotNull(location); + assertNotNull("Location header should contain redirect URI", location); + } + + process.kill(); + assertNotNull(process.checkOrWaitForEvent(ProcessState.PROCESS_STATE.STOPPED)); + } + + // ========== LegacyTokenAPI Tests ========== + + @Test + public void testLegacyTokenBadInput() throws Exception { + String[] args = {"../"}; + TestingProcessManager.TestingProcess process = TestingProcessManager.start(args); + assertNotNull(process.checkOrWaitForEvent(ProcessState.PROCESS_STATE.STARTED)); + + FeatureFlagTestContent.getInstance(process.getProcess()) + .setKeyValue(FeatureFlagTestContent.ENABLED_FEATURES, new EE_FEATURES[]{ + EE_FEATURES.SAML}); + + // Create SAML client + String spEntityId = "http://example.com/saml"; + String defaultRedirectURI = "http://localhost:3000/auth/callback/saml-mock"; + String acsURL = "http://localhost:3000/acs"; + String idpEntityId = "https://saml.example.com/entityid"; + String idpSsoUrl = "https://mocksaml.com/api/saml/sso"; + + SAMLTestUtils.CreatedClientInfo clientInfo = SAMLTestUtils.createClientWithGeneratedMetadata( + process, + defaultRedirectURI, + acsURL, + idpEntityId, + idpSsoUrl + ); + + // Missing client_id + { + JsonObject formData = new JsonObject(); + formData.addProperty("client_secret", clientInfo.clientId); // In legacy API, client_secret is same as client_id + formData.addProperty("code", "test-code"); + try { + HttpRequestForTesting.sendFormDataPOSTRequest(process.getProcess(), "", + "http://localhost:3567/recipe/saml/legacy/token", formData, 1000, 1000, null, SemVer.v5_4.get(), "saml"); + fail(); + } catch (io.supertokens.test.httpRequest.HttpResponseException e) { + assertEquals(400, e.statusCode); + assertEquals("Http error. Status Code: 400. Message: Missing form field: client_id", e.getMessage()); + } + } + + // Missing client_secret + { + JsonObject formData = new JsonObject(); + formData.addProperty("client_id", clientInfo.clientId); + formData.addProperty("code", "test-code"); + try { + HttpRequestForTesting.sendFormDataPOSTRequest(process.getProcess(), "", + "http://localhost:3567/recipe/saml/legacy/token", formData, 1000, 1000, null, SemVer.v5_4.get(), "saml"); + fail(); + } catch (io.supertokens.test.httpRequest.HttpResponseException e) { + assertEquals(400, e.statusCode); + assertEquals("Http error. Status Code: 400. Message: Missing form field: client_secret", e.getMessage()); + } + } + + // Missing code + { + JsonObject formData = new JsonObject(); + formData.addProperty("client_id", clientInfo.clientId); + formData.addProperty("client_secret", clientInfo.clientId); // In legacy API, client_secret is same as client_id + try { + HttpRequestForTesting.sendFormDataPOSTRequest(process.getProcess(), "", + "http://localhost:3567/recipe/saml/legacy/token", formData, 1000, 1000, null, SemVer.v5_4.get(), "saml"); + fail(); + } catch (io.supertokens.test.httpRequest.HttpResponseException e) { + assertEquals(400, e.statusCode); + assertEquals("Http error. Status Code: 400. Message: Missing form field: code", e.getMessage()); + } + } + + process.kill(); + assertNotNull(process.checkOrWaitForEvent(ProcessState.PROCESS_STATE.STOPPED)); + } + + @Test + public void testLegacyTokenInvalidClient() throws Exception { + String[] args = {"../"}; + TestingProcessManager.TestingProcess process = TestingProcessManager.start(args); + assertNotNull(process.checkOrWaitForEvent(ProcessState.PROCESS_STATE.STARTED)); + + FeatureFlagTestContent.getInstance(process.getProcess()) + .setKeyValue(FeatureFlagTestContent.ENABLED_FEATURES, new EE_FEATURES[]{ + EE_FEATURES.SAML}); + + JsonObject formData = new JsonObject(); + formData.addProperty("client_id", "non-existent-client"); + formData.addProperty("client_secret", "test-secret"); + formData.addProperty("code", "test-code"); + + try { + HttpRequestForTesting.sendFormDataPOSTRequest(process.getProcess(), "", + "http://localhost:3567/recipe/saml/legacy/token", formData, 1000, 1000, null, SemVer.v5_4.get(), "saml"); + fail(); + } catch (io.supertokens.test.httpRequest.HttpResponseException e) { + assertEquals(400, e.statusCode); + assertEquals("Http error. Status Code: 400. Message: Invalid client_id", e.getMessage()); + } + + process.kill(); + assertNotNull(process.checkOrWaitForEvent(ProcessState.PROCESS_STATE.STOPPED)); + } + + @Test + public void testLegacyTokenInvalidSecret() throws Exception { + String[] args = {"../"}; + TestingProcessManager.TestingProcess process = TestingProcessManager.start(args); + assertNotNull(process.checkOrWaitForEvent(ProcessState.PROCESS_STATE.STARTED)); + + FeatureFlagTestContent.getInstance(process.getProcess()) + .setKeyValue(FeatureFlagTestContent.ENABLED_FEATURES, new EE_FEATURES[]{ + EE_FEATURES.SAML}); + + // Create SAML client + String spEntityId = "http://example.com/saml"; + String defaultRedirectURI = "http://localhost:3000/auth/callback/saml-mock"; + String acsURL = "http://localhost:3000/acs"; + String idpEntityId = "https://saml.example.com/entityid"; + String idpSsoUrl = "https://mocksaml.com/api/saml/sso"; + + SAMLTestUtils.CreatedClientInfo clientInfo = SAMLTestUtils.createClientWithGeneratedMetadata( + process, + defaultRedirectURI, + acsURL, + idpEntityId, + idpSsoUrl + ); + + JsonObject formData = new JsonObject(); + formData.addProperty("client_id", clientInfo.clientId); + formData.addProperty("client_secret", "wrong-secret"); + formData.addProperty("code", "test-code"); + + try { + HttpRequestForTesting.sendFormDataPOSTRequest(process.getProcess(), "", + "http://localhost:3567/recipe/saml/legacy/token", formData, 1000, 1000, null, SemVer.v5_4.get(), "saml"); + fail(); + } catch (io.supertokens.test.httpRequest.HttpResponseException e) { + assertEquals(400, e.statusCode); + assertEquals("Http error. Status Code: 400. Message: Invalid client_secret", e.getMessage()); + } + + process.kill(); + assertNotNull(process.checkOrWaitForEvent(ProcessState.PROCESS_STATE.STOPPED)); + } + + @Test + public void testLegacyTokenValidRequest() throws Exception { + String[] args = {"../"}; + TestingProcessManager.TestingProcess process = TestingProcessManager.start(args); + assertNotNull(process.checkOrWaitForEvent(ProcessState.PROCESS_STATE.STARTED)); + + FeatureFlagTestContent.getInstance(process.getProcess()) + .setKeyValue(FeatureFlagTestContent.ENABLED_FEATURES, new EE_FEATURES[]{ + EE_FEATURES.SAML}); + + // Create SAML client + String spEntityId = "http://example.com/saml"; + String defaultRedirectURI = "http://localhost:3000/auth/callback/saml-mock"; + String acsURL = "http://localhost:3000/acs"; + String idpEntityId = "https://saml.example.com/entityid"; + String idpSsoUrl = "https://mocksaml.com/api/saml/sso"; + + SAMLTestUtils.CreatedClientInfo clientInfo = SAMLTestUtils.createClientWithGeneratedMetadata( + process, + defaultRedirectURI, + acsURL, + idpEntityId, + idpSsoUrl + ); + + // Create a login request to generate a RelayState + String relayState = SAMLTestUtils.createLoginRequestAndGetRelayState( + process, + clientInfo.clientId, + clientInfo.defaultRedirectURI, + clientInfo.acsURL, + "test-state" + ); + + String samlResponseBase64 = MockSAML.generateSignedSAMLResponseBase64( + clientInfo.idpEntityId, + "https://saml.supertokens.com", + clientInfo.acsURL, + "user@example.com", + null, + relayState, + clientInfo.keyMaterial, + 300 + ); + + // Process callback to get authorization code + JsonObject callbackFormData = new JsonObject(); + callbackFormData.addProperty("SAMLResponse", samlResponseBase64); + callbackFormData.addProperty("RelayState", relayState); + + String redirectURI = null; + try { + HttpRequestForTesting.sendFormDataPOSTRequest(process.getProcess(), "", + "http://localhost:3567/recipe/saml/legacy/callback", callbackFormData, 1000, 1000, null, SemVer.v5_4.get(), "saml", false); + fail("Expected redirect response"); + } catch (io.supertokens.test.httpRequest.HttpResponseException e) { + assertEquals(302, e.statusCode); + redirectURI = e.getMessage(); + } + + // Extract authorization code from redirect URI + String authCode = extractAuthCodeFromRedirectURI(redirectURI); + + // Now test token exchange + JsonObject tokenFormData = new JsonObject(); + tokenFormData.addProperty("client_id", clientInfo.clientId); + tokenFormData.addProperty("client_secret", "secret"); + tokenFormData.addProperty("code", authCode); + + JsonObject tokenResponse = HttpRequestForTesting.sendFormDataPOSTRequest(process.getProcess(), "", + "http://localhost:3567/recipe/saml/legacy/token", tokenFormData, 1000, 1000, null, SemVer.v5_4.get(), "saml"); + + assertEquals("OK", tokenResponse.get("status").getAsString()); + assertNotNull(tokenResponse.get("access_token")); + String accessToken = tokenResponse.get("access_token").getAsString(); + assertEquals(authCode + "." + clientInfo.clientId, accessToken); + + process.kill(); + assertNotNull(process.checkOrWaitForEvent(ProcessState.PROCESS_STATE.STOPPED)); + } + + // ========== LegacyUserinfoAPI Tests ========== + + @Test + public void testLegacyUserinfoBadInput() throws Exception { + String[] args = {"../"}; + TestingProcessManager.TestingProcess process = TestingProcessManager.start(args); + assertNotNull(process.checkOrWaitForEvent(ProcessState.PROCESS_STATE.STARTED)); + + FeatureFlagTestContent.getInstance(process.getProcess()) + .setKeyValue(FeatureFlagTestContent.ENABLED_FEATURES, new EE_FEATURES[]{ + EE_FEATURES.SAML}); + + // Missing Authorization header + { + try { + HttpRequestForTesting.sendGETRequest(process.getProcess(), "", + "http://localhost:3567/recipe/saml/legacy/userinfo", null, 1000, 1000, null, SemVer.v5_4.get(), "saml"); + fail(); + } catch (io.supertokens.test.httpRequest.HttpResponseException e) { + assertEquals(400, e.statusCode); + assertEquals("Http error. Status Code: 400. Message: Authorization header is required", e.getMessage()); + } + } + + // Invalid Authorization header format + { + try { + HttpRequestForTesting.sendGETRequest(process.getProcess(), "", + "http://localhost:3567/recipe/saml/legacy/userinfo", null, 1000, 1000, null, SemVer.v5_4.get(), "saml"); + fail(); + } catch (io.supertokens.test.httpRequest.HttpResponseException e) { + assertEquals(400, e.statusCode); + assertEquals("Http error. Status Code: 400. Message: Authorization header is required", e.getMessage()); + } + } + + process.kill(); + assertNotNull(process.checkOrWaitForEvent(ProcessState.PROCESS_STATE.STOPPED)); + } + + @Test + public void testLegacyUserinfoInvalidToken() throws Exception { + String[] args = {"../"}; + TestingProcessManager.TestingProcess process = TestingProcessManager.start(args); + assertNotNull(process.checkOrWaitForEvent(ProcessState.PROCESS_STATE.STARTED)); + + FeatureFlagTestContent.getInstance(process.getProcess()) + .setKeyValue(FeatureFlagTestContent.ENABLED_FEATURES, new EE_FEATURES[]{ + EE_FEATURES.SAML}); + + try { + Map headers = new HashMap<>(); + headers.put("Authorization", "Bearer invalid-token"); + JsonObject response = HttpRequestForTesting.sendGETRequestWithHeaders(process.getProcess(), "", + "http://localhost:3567/recipe/saml/legacy/userinfo", null, headers, 1000, 1000, null, SemVer.v5_4.get(), "saml"); + + } catch (HttpResponseException e) { + assertEquals(400, e.statusCode); + assertEquals("Http error. Status Code: 400. Message: INVALID_TOKEN_ERROR", e.getMessage()); + } + + process.kill(); + assertNotNull(process.checkOrWaitForEvent(ProcessState.PROCESS_STATE.STOPPED)); + } + + @Test + public void testLegacyUserinfoValidToken() throws Exception { + String[] args = {"../"}; + TestingProcessManager.TestingProcess process = TestingProcessManager.start(args); + assertNotNull(process.checkOrWaitForEvent(ProcessState.PROCESS_STATE.STARTED)); + + FeatureFlagTestContent.getInstance(process.getProcess()) + .setKeyValue(FeatureFlagTestContent.ENABLED_FEATURES, new EE_FEATURES[]{ + EE_FEATURES.SAML}); + + // Create SAML client + String spEntityId = "http://example.com/saml"; + String defaultRedirectURI = "http://localhost:3000/auth/callback/saml-mock"; + String acsURL = "http://localhost:3000/acs"; + String idpEntityId = "https://saml.example.com/entityid"; + String idpSsoUrl = "https://mocksaml.com/api/saml/sso"; + + SAMLTestUtils.CreatedClientInfo clientInfo = SAMLTestUtils.createClientWithGeneratedMetadata( + process, + defaultRedirectURI, + acsURL, + idpEntityId, + idpSsoUrl + ); + + // Create a login request to generate a RelayState + String relayState = SAMLTestUtils.createLoginRequestAndGetRelayState( + process, + clientInfo.clientId, + clientInfo.defaultRedirectURI, + clientInfo.acsURL, + "test-state" + ); + + String samlResponseBase64 = MockSAML.generateSignedSAMLResponseBase64( + clientInfo.idpEntityId, + "https://saml.supertokens.com", + clientInfo.acsURL, + "user@example.com", + null, + relayState, + clientInfo.keyMaterial, + 300 + ); + + // Process callback to get authorization code + JsonObject callbackFormData = new JsonObject(); + callbackFormData.addProperty("SAMLResponse", samlResponseBase64); + callbackFormData.addProperty("RelayState", relayState); + + String redirectURI = null; + try { + HttpRequestForTesting.sendFormDataPOSTRequest(process.getProcess(), "", + "http://localhost:3567/recipe/saml/legacy/callback", callbackFormData, 1000, 1000, null, SemVer.v5_4.get(), "saml", false); + fail("Expected redirect response"); + } catch (io.supertokens.test.httpRequest.HttpResponseException e) { + assertEquals(302, e.statusCode); + redirectURI = e.getMessage(); + } + + // Extract authorization code from redirect URI + String authCode = extractAuthCodeFromRedirectURI(redirectURI); + + // Exchange code for access token + JsonObject tokenFormData = new JsonObject(); + tokenFormData.addProperty("client_id", clientInfo.clientId); + tokenFormData.addProperty("client_secret", "secret"); + tokenFormData.addProperty("code", authCode); + + JsonObject tokenResponse = HttpRequestForTesting.sendFormDataPOSTRequest(process.getProcess(), "", + "http://localhost:3567/recipe/saml/legacy/token", tokenFormData, 1000, 1000, null, SemVer.v5_4.get(), "saml"); + + assertEquals("OK", tokenResponse.get("status").getAsString()); + + String accessToken = tokenResponse.get("access_token").getAsString(); + + // Now test userinfo with valid access token + Map headers = new HashMap<>(); + headers.put("Authorization", "Bearer " + accessToken); + JsonObject userInfoResponse = HttpRequestForTesting.sendGETRequestWithHeaders(process.getProcess(), "", + "http://localhost:3567/recipe/saml/legacy/userinfo", null, headers, 1000, 1000, null, SemVer.v5_4.get(), "saml"); + + assertNotNull(userInfoResponse.get("id")); + assertEquals("user@example.com", userInfoResponse.get("id").getAsString()); + + process.kill(); + assertNotNull(process.checkOrWaitForEvent(ProcessState.PROCESS_STATE.STOPPED)); + } + + // Helper method to extract authorization code from redirect URI + private String extractAuthCodeFromRedirectURI(String redirectURI) { + // Extract the 'code' parameter from the redirect URI + // Format: http://localhost:3000/auth/callback/saml-mock?code=some-uuid&state=test-state + int codeIndex = redirectURI.indexOf("code="); + if (codeIndex == -1) { + throw new IllegalStateException("Code parameter not found in redirect URI: " + redirectURI); + } + + String codePart = redirectURI.substring(codeIndex + "code=".length()); + int ampIndex = codePart.indexOf('&'); + if (ampIndex != -1) { + codePart = codePart.substring(0, ampIndex); + } + + return java.net.URLDecoder.decode(codePart, java.nio.charset.StandardCharsets.UTF_8); + } +} diff --git a/src/test/java/io/supertokens/test/saml/api/ListSAMLClientsTest5_4.java b/src/test/java/io/supertokens/test/saml/api/ListSAMLClientsTest5_4.java new file mode 100644 index 000000000..f4e52e376 --- /dev/null +++ b/src/test/java/io/supertokens/test/saml/api/ListSAMLClientsTest5_4.java @@ -0,0 +1,186 @@ +package io.supertokens.test.saml.api; + +import io.supertokens.featureflag.EE_FEATURES; +import io.supertokens.featureflag.FeatureFlagTestContent; +import org.junit.AfterClass; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertTrue; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.TestRule; + +import com.google.gson.JsonArray; +import com.google.gson.JsonElement; +import com.google.gson.JsonObject; + +import io.supertokens.ProcessState; +import io.supertokens.test.TestingProcessManager; +import io.supertokens.test.Utils; +import io.supertokens.test.httpRequest.HttpRequestForTesting; +import io.supertokens.test.saml.MockSAML; +import io.supertokens.utils.SemVer; + +public class ListSAMLClientsTest5_4 { + + @Rule + public TestRule watchman = Utils.getOnFailure(); + + @Rule + public TestRule retryFlaky = Utils.retryFlakyTest(); + + @AfterClass + public static void afterTesting() { + Utils.afterTesting(); + } + + @Before + public void beforeEach() { + Utils.reset(); + } + + @Test + public void testEmptyList() throws Exception { + String[] args = {"../"}; + TestingProcessManager.TestingProcess process = TestingProcessManager.start(args); + assertNotNull(process.checkOrWaitForEvent(ProcessState.PROCESS_STATE.STARTED)); + + FeatureFlagTestContent.getInstance(process.getProcess()) + .setKeyValue(FeatureFlagTestContent.ENABLED_FEATURES, new EE_FEATURES[]{ + EE_FEATURES.SAML}); + + JsonObject listResp = HttpRequestForTesting.sendGETRequest(process.getProcess(), "", + "http://localhost:3567/recipe/saml/clients/list", null, 1000, 1000, null, + SemVer.v5_4.get(), "saml"); + + assertEquals("OK", listResp.get("status").getAsString()); + assertTrue(listResp.has("clients")); + assertTrue(listResp.get("clients").isJsonArray()); + assertEquals(0, listResp.get("clients").getAsJsonArray().size()); + + process.kill(); + assertNotNull(process.checkOrWaitForEvent(ProcessState.PROCESS_STATE.STOPPED)); + } + + @Test + public void testListAfterCreatingClientViaXML() throws Exception { + String[] args = {"../"}; + TestingProcessManager.TestingProcess process = TestingProcessManager.start(args); + assertNotNull(process.checkOrWaitForEvent(ProcessState.PROCESS_STATE.STARTED)); + + FeatureFlagTestContent.getInstance(process.getProcess()) + .setKeyValue(FeatureFlagTestContent.ENABLED_FEATURES, new EE_FEATURES[]{ + EE_FEATURES.SAML}); + + // Generate IdP metadata using MockSAML + MockSAML.KeyMaterial keyMaterial = MockSAML.generateSelfSignedKeyMaterial(); + String idpEntityId = "https://saml.example.com/entityid"; + String idpSsoUrl = "https://mocksaml.com/api/saml/sso"; + String metadataXML = MockSAML.generateIdpMetadataXML(idpEntityId, idpSsoUrl, keyMaterial.certificate); + String metadataXMLBase64 = java.util.Base64.getEncoder().encodeToString(metadataXML.getBytes(java.nio.charset.StandardCharsets.UTF_8)); + + JsonObject createClientInput = new JsonObject(); + createClientInput.addProperty("spEntityId", "http://example.com/saml"); + createClientInput.addProperty("defaultRedirectURI", "http://localhost:3000/auth/callback/saml-mock"); + createClientInput.add("redirectURIs", new JsonArray()); + createClientInput.get("redirectURIs").getAsJsonArray().add("http://localhost:3000/auth/callback/saml-mock"); + createClientInput.addProperty("metadataXML", metadataXMLBase64); + + JsonObject createResp = HttpRequestForTesting.sendJsonPUTRequest(process.getProcess(), "", + "http://localhost:3567/recipe/saml/clients", createClientInput, 1000, 1000, null, + SemVer.v5_4.get(), "saml"); + + assertEquals("OK", createResp.get("status").getAsString()); + String clientId = createResp.get("clientId").getAsString(); + + JsonObject listResp = HttpRequestForTesting.sendGETRequest(process.getProcess(), "", + "http://localhost:3567/recipe/saml/clients/list", null, 1000, 1000, null, + SemVer.v5_4.get(), "saml"); + + assertEquals("OK", listResp.get("status").getAsString()); + assertTrue(listResp.get("clients").isJsonArray()); + JsonArray clients = listResp.get("clients").getAsJsonArray(); + assertEquals(1, clients.size()); + + JsonObject listed = findByClientId(clients, clientId); + assertNotNull(listed); + + // should not include clientSecret since we didn't set it + assertFalse(listed.has("clientSecret")); + + assertEquals("http://localhost:3000/auth/callback/saml-mock", listed.get("defaultRedirectURI").getAsString()); + assertTrue(listed.get("redirectURIs").isJsonArray()); + assertEquals(1, listed.get("redirectURIs").getAsJsonArray().size()); + assertEquals("http://localhost:3000/auth/callback/saml-mock", + listed.get("redirectURIs").getAsJsonArray().get(0).getAsString()); + + assertEquals(idpEntityId, listed.get("idpEntityId").getAsString()); + assertTrue(listed.has("idpSigningCertificate")); + assertFalse(listed.get("idpSigningCertificate").getAsString().isEmpty()); + assertFalse(listed.get("allowIDPInitiatedLogin").getAsBoolean()); + + process.kill(); + assertNotNull(process.checkOrWaitForEvent(ProcessState.PROCESS_STATE.STOPPED)); + } + + @Test + public void testListIncludesClientSecretWhenProvided() throws Exception { + String[] args = {"../"}; + TestingProcessManager.TestingProcess process = TestingProcessManager.start(args); + assertNotNull(process.checkOrWaitForEvent(ProcessState.PROCESS_STATE.STARTED)); + + FeatureFlagTestContent.getInstance(process.getProcess()) + .setKeyValue(FeatureFlagTestContent.ENABLED_FEATURES, new EE_FEATURES[]{ + EE_FEATURES.SAML}); + + // Generate IdP metadata using MockSAML + MockSAML.KeyMaterial keyMaterial = MockSAML.generateSelfSignedKeyMaterial(); + String idpEntityId = "https://saml.example.com/entityid"; + String idpSsoUrl = "https://mocksaml.com/api/saml/sso"; + String metadataXML = MockSAML.generateIdpMetadataXML(idpEntityId, idpSsoUrl, keyMaterial.certificate); + String metadataXMLBase64 = java.util.Base64.getEncoder().encodeToString(metadataXML.getBytes(java.nio.charset.StandardCharsets.UTF_8)); + + JsonObject createClientInput = new JsonObject(); + createClientInput.addProperty("spEntityId", "http://example.com/saml"); + createClientInput.addProperty("defaultRedirectURI", "http://localhost:3000/auth/callback/saml-mock"); + createClientInput.add("redirectURIs", new JsonArray()); + createClientInput.get("redirectURIs").getAsJsonArray().add("http://localhost:3000/auth/callback/saml-mock"); + createClientInput.addProperty("metadataXML", metadataXMLBase64); + + String clientSecret = "my-secret-xyz"; + createClientInput.addProperty("clientSecret", clientSecret); + + JsonObject createResp = HttpRequestForTesting.sendJsonPUTRequest(process.getProcess(), "", + "http://localhost:3567/recipe/saml/clients", createClientInput, 1000, 1000, null, + SemVer.v5_4.get(), "saml"); + + assertEquals("OK", createResp.get("status").getAsString()); + String clientId = createResp.get("clientId").getAsString(); + + JsonObject listResp = HttpRequestForTesting.sendGETRequest(process.getProcess(), "", + "http://localhost:3567/recipe/saml/clients/list", null, 1000, 1000, null, + SemVer.v5_4.get(), "saml"); + + assertEquals("OK", listResp.get("status").getAsString()); + JsonArray clients = listResp.get("clients").getAsJsonArray(); + JsonObject listed = findByClientId(clients, clientId); + assertNotNull(listed); + assertTrue(listed.has("clientSecret")); + assertEquals(clientSecret, listed.get("clientSecret").getAsString()); + + process.kill(); + assertNotNull(process.checkOrWaitForEvent(ProcessState.PROCESS_STATE.STOPPED)); + } + + private static JsonObject findByClientId(JsonArray clients, String clientId) { + for (JsonElement el : clients) { + JsonObject obj = el.getAsJsonObject(); + if (obj.has("clientId") && obj.get("clientId").getAsString().equals(clientId)) { + return obj; + } + } + return null; + } +} diff --git a/src/test/java/io/supertokens/test/saml/api/RemoveSAMLClientTest5_4.java b/src/test/java/io/supertokens/test/saml/api/RemoveSAMLClientTest5_4.java new file mode 100644 index 000000000..b1625b2c4 --- /dev/null +++ b/src/test/java/io/supertokens/test/saml/api/RemoveSAMLClientTest5_4.java @@ -0,0 +1,199 @@ +package io.supertokens.test.saml.api; + +import io.supertokens.featureflag.EE_FEATURES; +import io.supertokens.featureflag.FeatureFlag; +import io.supertokens.featureflag.FeatureFlagTestContent; +import org.junit.AfterClass; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertTrue; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.TestRule; + +import com.google.gson.JsonArray; +import com.google.gson.JsonObject; + +import io.supertokens.ProcessState; +import io.supertokens.test.TestingProcessManager; +import io.supertokens.test.Utils; +import io.supertokens.test.httpRequest.HttpRequestForTesting; +import io.supertokens.test.httpRequest.HttpResponseException; +import io.supertokens.test.saml.MockSAML; +import io.supertokens.utils.SemVer; + +public class RemoveSAMLClientTest5_4 { + + @Rule + public TestRule watchman = Utils.getOnFailure(); + + @Rule + public TestRule retryFlaky = Utils.retryFlakyTest(); + + @AfterClass + public static void afterTesting() { + Utils.afterTesting(); + } + + @Before + public void beforeEach() { + Utils.reset(); + } + + @Test + public void testDeleteNonExistingClientReturnsFalse() throws Exception { + String[] args = {"../"}; + TestingProcessManager.TestingProcess process = TestingProcessManager.start(args); + assertNotNull(process.checkOrWaitForEvent(ProcessState.PROCESS_STATE.STARTED)); + + FeatureFlagTestContent.getInstance(process.getProcess()) + .setKeyValue(FeatureFlagTestContent.ENABLED_FEATURES, new EE_FEATURES[]{ + EE_FEATURES.SAML}); + + JsonObject body = new JsonObject(); + body.addProperty("clientId", "st_saml_does_not_exist"); + + JsonObject resp = HttpRequestForTesting.sendJsonPOSTRequest(process.getProcess(), "", + "http://localhost:3567/recipe/saml/clients/remove", body, 1000, 1000, null, + SemVer.v5_4.get(), "saml"); + + assertEquals("OK", resp.get("status").getAsString()); + assertFalse(resp.get("didExist").getAsBoolean()); + + process.kill(); + assertNotNull(process.checkOrWaitForEvent(ProcessState.PROCESS_STATE.STOPPED)); + } + + @Test + public void testBadInputMissingClientId() throws Exception { + String[] args = {"../"}; + TestingProcessManager.TestingProcess process = TestingProcessManager.start(args); + assertNotNull(process.checkOrWaitForEvent(ProcessState.PROCESS_STATE.STARTED)); + + FeatureFlagTestContent.getInstance(process.getProcess()) + .setKeyValue(FeatureFlagTestContent.ENABLED_FEATURES, new EE_FEATURES[]{ + EE_FEATURES.SAML}); + + JsonObject body = new JsonObject(); + try { + HttpRequestForTesting.sendJsonPOSTRequest(process.getProcess(), "", + "http://localhost:3567/recipe/saml/clients/remove", body, 1000, 1000, null, + SemVer.v5_4.get(), "saml"); + // should not reach here + org.junit.Assert.fail(); + } catch (HttpResponseException e) { + assertEquals(400, e.statusCode); + assertEquals("Http error. Status Code: 400. Message: Field name 'clientId' is invalid in JSON input", e.getMessage()); + } + + process.kill(); + assertNotNull(process.checkOrWaitForEvent(ProcessState.PROCESS_STATE.STOPPED)); + } + + @Test + public void testCreateThenDeleteClient() throws Exception { + String[] args = {"../"}; + TestingProcessManager.TestingProcess process = TestingProcessManager.start(args); + assertNotNull(process.checkOrWaitForEvent(ProcessState.PROCESS_STATE.STARTED)); + + FeatureFlagTestContent.getInstance(process.getProcess()) + .setKeyValue(FeatureFlagTestContent.ENABLED_FEATURES, new EE_FEATURES[]{ + EE_FEATURES.SAML}); + + // create a client first + JsonObject create = new JsonObject(); + create.addProperty("spEntityId", "http://example.com/saml"); + create.addProperty("defaultRedirectURI", "http://localhost:3000/auth/callback/saml-mock"); + create.add("redirectURIs", new JsonArray()); + create.get("redirectURIs").getAsJsonArray().add("http://localhost:3000/auth/callback/saml-mock"); + + // Generate IdP metadata using MockSAML + MockSAML.KeyMaterial keyMaterial = MockSAML.generateSelfSignedKeyMaterial(); + String idpEntityId = "https://saml.example.com/entityid"; + String idpSsoUrl = "https://mocksaml.com/api/saml/sso"; + String metadataXML = MockSAML.generateIdpMetadataXML(idpEntityId, idpSsoUrl, keyMaterial.certificate); + String metadataXMLBase64 = java.util.Base64.getEncoder().encodeToString(metadataXML.getBytes(java.nio.charset.StandardCharsets.UTF_8)); + create.addProperty("metadataXML", metadataXMLBase64); + + JsonObject createResp = HttpRequestForTesting.sendJsonPUTRequest(process.getProcess(), "", + "http://localhost:3567/recipe/saml/clients", create, 1000, 1000, null, + SemVer.v5_4.get(), "saml"); + + String clientId = createResp.get("clientId").getAsString(); + assertTrue(clientId.startsWith("st_saml_")); + + // delete it + JsonObject body = new JsonObject(); + body.addProperty("clientId", clientId); + + JsonObject deleteResp = HttpRequestForTesting.sendJsonPOSTRequest(process.getProcess(), "", + "http://localhost:3567/recipe/saml/clients/remove", body, 1000, 1000, null, + SemVer.v5_4.get(), "saml"); + + assertEquals("OK", deleteResp.get("status").getAsString()); + assertTrue(deleteResp.get("didExist").getAsBoolean()); + + // verify listing is empty after deletion + JsonObject listResp = HttpRequestForTesting.sendGETRequest(process.getProcess(), "", + "http://localhost:3567/recipe/saml/clients/list", null, 1000, 1000, null, + SemVer.v5_4.get(), "saml"); + assertEquals("OK", listResp.get("status").getAsString()); + assertTrue(listResp.get("clients").isJsonArray()); + assertEquals(0, listResp.get("clients").getAsJsonArray().size()); + + process.kill(); + assertNotNull(process.checkOrWaitForEvent(ProcessState.PROCESS_STATE.STOPPED)); + } + + @Test + public void testDeleteTwiceSecondTimeFalse() throws Exception { + String[] args = {"../"}; + TestingProcessManager.TestingProcess process = TestingProcessManager.start(args); + assertNotNull(process.checkOrWaitForEvent(ProcessState.PROCESS_STATE.STARTED)); + + FeatureFlagTestContent.getInstance(process.getProcess()) + .setKeyValue(FeatureFlagTestContent.ENABLED_FEATURES, new EE_FEATURES[]{ + EE_FEATURES.SAML}); + + // create + JsonObject create = new JsonObject(); + create.addProperty("spEntityId", "http://example.com/saml"); + create.addProperty("defaultRedirectURI", "http://localhost:3000/auth/callback/saml-mock"); + create.add("redirectURIs", new JsonArray()); + create.get("redirectURIs").getAsJsonArray().add("http://localhost:3000/auth/callback/saml-mock"); + + // Generate IdP metadata using MockSAML + MockSAML.KeyMaterial keyMaterial = MockSAML.generateSelfSignedKeyMaterial(); + String idpEntityId = "https://saml.example.com/entityid"; + String idpSsoUrl = "https://mocksaml.com/api/saml/sso"; + String metadataXML = MockSAML.generateIdpMetadataXML(idpEntityId, idpSsoUrl, keyMaterial.certificate); + String metadataXMLBase64 = java.util.Base64.getEncoder().encodeToString(metadataXML.getBytes(java.nio.charset.StandardCharsets.UTF_8)); + create.addProperty("metadataXML", metadataXMLBase64); + + JsonObject createResp = HttpRequestForTesting.sendJsonPUTRequest(process.getProcess(), "", + "http://localhost:3567/recipe/saml/clients", create, 1000, 1000, null, + SemVer.v5_4.get(), "saml"); + + String clientId = createResp.get("clientId").getAsString(); + + JsonObject body = new JsonObject(); + body.addProperty("clientId", clientId); + + JsonObject deleteResp1 = HttpRequestForTesting.sendJsonPOSTRequest(process.getProcess(), "", + "http://localhost:3567/recipe/saml/clients/remove", body, 1000, 1000, null, + SemVer.v5_4.get(), "saml"); + assertEquals("OK", deleteResp1.get("status").getAsString()); + assertTrue(deleteResp1.get("didExist").getAsBoolean()); + + JsonObject deleteResp2 = HttpRequestForTesting.sendJsonPOSTRequest(process.getProcess(), "", + "http://localhost:3567/recipe/saml/clients/remove", body, 1000, 1000, null, + SemVer.v5_4.get(), "saml"); + assertEquals("OK", deleteResp2.get("status").getAsString()); + assertFalse(deleteResp2.get("didExist").getAsBoolean()); + + process.kill(); + assertNotNull(process.checkOrWaitForEvent(ProcessState.PROCESS_STATE.STOPPED)); + } +} diff --git a/src/test/java/io/supertokens/test/userIdMapping/UserIdMappingTest.java b/src/test/java/io/supertokens/test/userIdMapping/UserIdMappingTest.java index f596f37d3..5755fa0a3 100644 --- a/src/test/java/io/supertokens/test/userIdMapping/UserIdMappingTest.java +++ b/src/test/java/io/supertokens/test/userIdMapping/UserIdMappingTest.java @@ -31,6 +31,7 @@ import io.supertokens.pluginInterface.multitenancy.TenantIdentifier; import io.supertokens.pluginInterface.nonAuthRecipe.NonAuthRecipeStorage; import io.supertokens.pluginInterface.oauth.OAuthStorage; +import io.supertokens.pluginInterface.saml.SAMLStorage; import io.supertokens.pluginInterface.useridmapping.UserIdMappingStorage; import io.supertokens.pluginInterface.useridmapping.exception.UnknownSuperTokensUserIdException; import io.supertokens.pluginInterface.useridmapping.exception.UserIdMappingAlreadyExistsException; @@ -809,7 +810,8 @@ public void checkThatCreateUserIdMappingHasAllNonAuthRecipeChecks() throws Excep JWTRecipeStorage.class.getName(), ActiveUsersStorage.class.getName(), OAuthStorage.class.getName(), - BulkImportStorage.class.getName() + BulkImportStorage.class.getName(), + SAMLStorage.class.getName() ); Reflections reflections = new Reflections("io.supertokens.pluginInterface"); @@ -894,7 +896,8 @@ public void checkThatDeleteUserIdMappingHasAllNonAuthRecipeChecks() throws Excep JWTRecipeStorage.class.getName(), ActiveUsersStorage.class.getName(), OAuthStorage.class.getName(), - BulkImportStorage.class.getName() + BulkImportStorage.class.getName(), + SAMLStorage.class.getName() ); Reflections reflections = new Reflections("io.supertokens.pluginInterface"); Set> classes = reflections.getSubTypesOf(NonAuthRecipeStorage.class);