Skip to content

Commit 26bb0b4

Browse files
committed
Fix RemoteS3Connection for rewritten bucket and key
1 parent 2ab4bb4 commit 26bb0b4

File tree

4 files changed

+192
-27
lines changed

4 files changed

+192
-27
lines changed

trino-aws-proxy/src/main/java/io/trino/aws/proxy/server/rest/TrinoS3ProxyClient.java

Lines changed: 27 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -112,28 +112,44 @@ public void shutDown()
112112
}
113113
}
114114

115-
public void proxyRequest(Optional<Identity> identity, SigningMetadata signingMetadata, ParsedS3Request request, AsyncResponse asyncResponse,
115+
public void proxyRequest(Optional<Identity> identity, SigningMetadata signingMetadata, ParsedS3Request originalRequest, AsyncResponse asyncResponse,
116116
RequestLoggingSession requestLoggingSession)
117117
{
118-
SecurityResponse securityResponse = s3SecurityController.apply(request, identity);
118+
SecurityResponse securityResponse = s3SecurityController.apply(originalRequest, identity);
119119
if (securityResponse instanceof Failure(var error)) {
120-
log.debug("SecurityController check failed. AccessKey: %s, Request: %s, SecurityResponse: %s", signingMetadata.credential().accessKey(), request, securityResponse);
120+
log.debug("SecurityController check failed. AccessKey: %s, Request: %s, SecurityResponse: %s", signingMetadata.credential().accessKey(), originalRequest, securityResponse);
121121
requestLoggingSession.logError("request.security.fail.credentials", signingMetadata.credential());
122-
requestLoggingSession.logError("request.security.fail.request", request);
122+
requestLoggingSession.logError("request.security.fail.request", originalRequest);
123123
requestLoggingSession.logError("request.security.fail.error", error);
124124

125125
throw new WebApplicationException(Response.Status.UNAUTHORIZED);
126126
}
127127

128-
Optional<S3RewriteResult> rewriteResult = s3RequestRewriter.rewrite(identity, signingMetadata, request);
129-
String targetBucket = rewriteResult.map(S3RewriteResult::finalRequestBucket).orElse(request.bucketName());
130-
String targetKey = rewriteResult
128+
Optional<S3RewriteResult> rewriteResult = s3RequestRewriter.rewrite(identity, signingMetadata, originalRequest);
129+
String bucket = rewriteResult.map(S3RewriteResult::finalRequestBucket).orElse(originalRequest.bucketName());
130+
String key = rewriteResult
131+
.map(S3RewriteResult::finalRequestKey)
132+
.orElse(originalRequest.keyInBucket());
133+
String path = rewriteResult
131134
.map(S3RewriteResult::finalRequestKey)
132135
.map(SdkHttpUtils::urlEncodeIgnoreSlashes)
133-
.orElse(request.rawPath());
136+
.orElse(originalRequest.rawPath());
137+
138+
ParsedS3Request request = new ParsedS3Request(
139+
originalRequest.requestId(),
140+
originalRequest.requestAuthorization(),
141+
originalRequest.requestDate(),
142+
bucket,
143+
key,
144+
originalRequest.requestHeaders(),
145+
originalRequest.queryParameters(),
146+
originalRequest.httpVerb(),
147+
path,
148+
originalRequest.rawQuery(),
149+
originalRequest.requestContent());
134150

135151
RemoteRequestWithPresignedURIs remoteRequest = remoteS3ConnectionController.withRemoteConnection(signingMetadata, identity, request, (remoteCredential, remoteS3Facade) -> {
136-
URI remoteUri = remoteS3Facade.buildEndpoint(uriBuilder(request.queryParameters()), targetKey, targetBucket, request.requestAuthorization().region());
152+
URI remoteUri = remoteS3Facade.buildEndpoint(uriBuilder(request.queryParameters()), request.rawPath(), request.bucketName(), request.requestAuthorization().region());
137153

138154
Request.Builder remoteRequestBuilder = new Request.Builder()
139155
.setMethod(request.httpVerb())
@@ -161,7 +177,8 @@ public void proxyRequest(Optional<Identity> identity, SigningMetadata signingMet
161177

162178
Map<String, URI> presignedUrls;
163179
if (generatePresignedUrlsOnHead && request.httpVerb().equalsIgnoreCase("HEAD")) {
164-
presignedUrls = s3PresignController.buildPresignedRemoteUrls(identity, remoteSigningMetadata, request, targetRequestTimestamp, remoteUri);
180+
// Presigned URLs are generated for the ORIGINAL key and bucket, not the rewritten ones
181+
presignedUrls = s3PresignController.buildPresignedRemoteUrls(identity, remoteSigningMetadata, originalRequest, targetRequestTimestamp, remoteUri);
165182
}
166183
else {
167184
presignedUrls = ImmutableMap.of();
Lines changed: 143 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,143 @@
1+
/*
2+
* Licensed under the Apache License, Version 2.0 (the "License");
3+
* you may not use this file except in compliance with the License.
4+
* You may obtain a copy of the License at
5+
*
6+
* http://www.apache.org/licenses/LICENSE-2.0
7+
*
8+
* Unless required by applicable law or agreed to in writing, software
9+
* distributed under the License is distributed on an "AS IS" BASIS,
10+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11+
* See the License for the specific language governing permissions and
12+
* limitations under the License.
13+
*/
14+
package io.trino.aws.proxy.server.remote.provider;
15+
16+
import com.google.common.collect.ImmutableList;
17+
import com.google.inject.Inject;
18+
import io.trino.aws.proxy.server.testing.RequestRewriteUtil;
19+
import io.trino.aws.proxy.server.testing.TestingTrinoAwsProxyServer;
20+
import io.trino.aws.proxy.server.testing.containers.S3Container.ForS3Container;
21+
import io.trino.aws.proxy.server.testing.harness.BuilderFilter;
22+
import io.trino.aws.proxy.server.testing.harness.TrinoAwsProxyTest;
23+
import io.trino.aws.proxy.server.testing.harness.TrinoAwsProxyTestCommonModules.WithConfiguredBuckets;
24+
import io.trino.aws.proxy.spi.credentials.Identity;
25+
import io.trino.aws.proxy.spi.remote.RemoteS3Connection;
26+
import io.trino.aws.proxy.spi.remote.RemoteS3Connection.StaticRemoteS3Connection;
27+
import io.trino.aws.proxy.spi.remote.RemoteS3ConnectionProvider;
28+
import io.trino.aws.proxy.spi.rest.ParsedS3Request;
29+
import io.trino.aws.proxy.spi.signing.SigningMetadata;
30+
import org.junit.jupiter.api.AfterEach;
31+
import org.junit.jupiter.api.Test;
32+
import software.amazon.awssdk.core.ResponseInputStream;
33+
import software.amazon.awssdk.core.sync.RequestBody;
34+
import software.amazon.awssdk.services.s3.S3Client;
35+
import software.amazon.awssdk.services.s3.model.GetObjectRequest;
36+
import software.amazon.awssdk.services.s3.model.GetObjectResponse;
37+
import software.amazon.awssdk.services.s3.model.PutObjectRequest;
38+
39+
import java.io.IOException;
40+
import java.util.ArrayList;
41+
import java.util.List;
42+
import java.util.Optional;
43+
44+
import static com.google.inject.Scopes.SINGLETON;
45+
import static io.trino.aws.proxy.server.testing.TestingUtil.LOREM_IPSUM;
46+
import static io.trino.aws.proxy.server.testing.TestingUtil.TESTING_REMOTE_CREDENTIAL;
47+
import static io.trino.aws.proxy.spi.plugin.TrinoAwsProxyServerBinding.remoteS3ConnectionProviderModule;
48+
import static java.util.Objects.requireNonNull;
49+
import static org.assertj.core.api.Assertions.assertThat;
50+
51+
@TrinoAwsProxyTest(filters = {WithConfiguredBuckets.class, RequestRewriteUtil.Filter.class, TestRemoteS3ConnectionProviderWithRewriter.Filter.class})
52+
public class TestRemoteS3ConnectionProviderWithRewriter
53+
{
54+
public static class Filter
55+
implements BuilderFilter
56+
{
57+
@Override
58+
public TestingTrinoAwsProxyServer.Builder filter(TestingTrinoAwsProxyServer.Builder builder)
59+
{
60+
return builder
61+
.withoutTestingRemoteS3ConnectionProvider()
62+
.addModule(remoteS3ConnectionProviderModule("rewrite-test", DelegateRemoteS3ConnectionProvider.class,
63+
binder -> binder.bind(DelegateRemoteS3ConnectionProvider.class).in(SINGLETON)))
64+
.withProperty("remote-s3-connection-provider.type", "rewrite-test");
65+
}
66+
}
67+
68+
public static class DelegateRemoteS3ConnectionProvider
69+
implements RemoteS3ConnectionProvider
70+
{
71+
private RemoteS3ConnectionProvider delegate;
72+
73+
private final List<RemoteS3ConnectionProviderArgs> callArgs = new ArrayList<>();
74+
75+
@Override
76+
public Optional<RemoteS3Connection> remoteConnection(SigningMetadata signingMetadata, Optional<Identity> identity, ParsedS3Request request)
77+
{
78+
callArgs.add(new RemoteS3ConnectionProviderArgs(signingMetadata, identity, request));
79+
return delegate.remoteConnection(signingMetadata, identity, request);
80+
}
81+
82+
public void setDelegate(RemoteS3ConnectionProvider delegate)
83+
{
84+
this.delegate = requireNonNull(delegate, "delegate is null");
85+
}
86+
87+
public List<RemoteS3ConnectionProviderArgs> getCallArgs()
88+
{
89+
return callArgs;
90+
}
91+
92+
public void reset()
93+
{
94+
callArgs.clear();
95+
delegate = null;
96+
}
97+
}
98+
99+
public record RemoteS3ConnectionProviderArgs(SigningMetadata signingMetadata, Optional<Identity> identity, ParsedS3Request request) {}
100+
101+
private final S3Client s3Client;
102+
private final S3Client storageClient;
103+
private final DelegateRemoteS3ConnectionProvider delegateRemoteS3ConnectionProvider;
104+
private final List<String> buckets;
105+
106+
@Inject
107+
public TestRemoteS3ConnectionProviderWithRewriter(
108+
S3Client s3Client,
109+
@ForS3Container S3Client storageClient,
110+
DelegateRemoteS3ConnectionProvider delegateRemoteS3ConnectionProvider,
111+
@ForS3Container List<String> buckets)
112+
{
113+
this.s3Client = requireNonNull(s3Client, "s3Client is null");
114+
this.storageClient = requireNonNull(storageClient, "storageClient is null");
115+
this.delegateRemoteS3ConnectionProvider = requireNonNull(delegateRemoteS3ConnectionProvider, "delegateRemoteS3ConnectionProvider is null");
116+
this.buckets = ImmutableList.copyOf(buckets);
117+
}
118+
119+
@AfterEach
120+
public void cleanup()
121+
{
122+
delegateRemoteS3ConnectionProvider.reset();
123+
}
124+
125+
@Test
126+
public void testRemoteS3ConnectionRetrievedWithRewrittenRequest()
127+
throws IOException
128+
{
129+
String bucket = buckets.getFirst();
130+
131+
storageClient.putObject(PutObjectRequest.builder().bucket("redirected-" + bucket).key("redirected-test_key_1337").build(), RequestBody.fromString(LOREM_IPSUM));
132+
133+
delegateRemoteS3ConnectionProvider.setDelegate((_, _, _) -> Optional.of(new StaticRemoteS3Connection(TESTING_REMOTE_CREDENTIAL)));
134+
135+
ResponseInputStream<GetObjectResponse> resp = s3Client.getObject(GetObjectRequest.builder().bucket(bucket).key("test_key_1337").build());
136+
assertThat(resp.readAllBytes()).asString().isEqualTo(LOREM_IPSUM);
137+
138+
assertThat(delegateRemoteS3ConnectionProvider.getCallArgs()).hasSize(1).first().satisfies(args -> {
139+
assertThat(args.request().bucketName()).isEqualTo("redirected-" + bucket);
140+
assertThat(args.request().keyInBucket()).isEqualTo("redirected-test_key_1337");
141+
});
142+
}
143+
}

trino-aws-proxy/src/test/java/io/trino/aws/proxy/server/testing/TestingCredentialsRolesProvider.java

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,13 +13,15 @@
1313
*/
1414
package io.trino.aws.proxy.server.testing;
1515

16+
import com.google.inject.Inject;
1617
import io.trino.aws.proxy.spi.credentials.AssumedRoleProvider;
1718
import io.trino.aws.proxy.spi.credentials.Credential;
1819
import io.trino.aws.proxy.spi.credentials.CredentialsProvider;
1920
import io.trino.aws.proxy.spi.credentials.EmulatedAssumedRole;
2021
import io.trino.aws.proxy.spi.credentials.Identity;
2122
import io.trino.aws.proxy.spi.credentials.IdentityCredential;
2223
import io.trino.aws.proxy.spi.remote.RemoteS3Connection;
24+
import io.trino.aws.proxy.spi.remote.RemoteS3Connection.StaticRemoteS3Connection;
2325
import io.trino.aws.proxy.spi.remote.RemoteS3ConnectionProvider;
2426
import io.trino.aws.proxy.spi.rest.ParsedS3Request;
2527
import io.trino.aws.proxy.spi.signing.SigningMetadata;
@@ -33,6 +35,8 @@
3335
import java.util.concurrent.atomic.AtomicInteger;
3436

3537
import static com.google.common.base.Preconditions.checkState;
38+
import static io.trino.aws.proxy.server.testing.TestingUtil.TESTING_IDENTITY_CREDENTIAL;
39+
import static io.trino.aws.proxy.server.testing.TestingUtil.TESTING_REMOTE_CREDENTIAL;
3640
import static java.util.Objects.requireNonNull;
3741

3842
/**
@@ -60,6 +64,13 @@ private record Session(Credential sessionCredential, String originalEmulatedAcce
6064
}
6165
}
6266

67+
@Inject
68+
public TestingCredentialsRolesProvider()
69+
{
70+
addCredentials(TESTING_IDENTITY_CREDENTIAL);
71+
setDefaultRemoteConnection(new StaticRemoteS3Connection(TESTING_REMOTE_CREDENTIAL));
72+
}
73+
6374
@Override
6475
public Optional<RemoteS3Connection> remoteConnection(SigningMetadata signingMetadata, Optional<Identity> identity, ParsedS3Request request)
6576
{

trino-aws-proxy/src/test/java/io/trino/aws/proxy/server/testing/TestingTrinoAwsProxyServer.java

Lines changed: 11 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@
1616
import com.google.common.collect.ImmutableList;
1717
import com.google.common.collect.ImmutableMap;
1818
import com.google.common.collect.ImmutableSet;
19-
import com.google.inject.Inject;
2019
import com.google.inject.Injector;
2120
import com.google.inject.Key;
2221
import com.google.inject.Module;
@@ -41,7 +40,6 @@
4140
import io.trino.aws.proxy.server.testing.containers.S3Container.ForS3Container;
4241
import io.trino.aws.proxy.spi.credentials.Credential;
4342
import io.trino.aws.proxy.spi.credentials.IdentityCredential;
44-
import io.trino.aws.proxy.spi.remote.RemoteS3Connection.StaticRemoteS3Connection;
4543

4644
import java.io.Closeable;
4745
import java.util.Collection;
@@ -98,6 +96,7 @@ public static class Builder
9896
private boolean v4PySparkContainerAdded;
9997
private boolean opaContainerAdded;
10098
private boolean addTestingCredentialsRoleProviders = true;
99+
private boolean addTestingRemoteS3CredentialsProvider = true;
101100

102101
public Builder addModule(Module module)
103102
{
@@ -200,6 +199,12 @@ public Builder withoutTestingCredentialsRoleProviders()
200199
return this;
201200
}
202201

202+
public Builder withoutTestingRemoteS3ConnectionProvider()
203+
{
204+
addTestingRemoteS3CredentialsProvider = false;
205+
return this;
206+
}
207+
203208
public Builder withOpaContainer()
204209
{
205210
if (opaContainerAdded) {
@@ -214,33 +219,22 @@ public Builder withOpaContainer()
214219
public TestingTrinoAwsProxyServer buildAndStart()
215220
{
216221
if (addTestingCredentialsRoleProviders) {
217-
if (mockS3ContainerAdded) {
218-
modules.add(binder -> binder.bind(TestingCredentialsInitializer.class).asEagerSingleton());
219-
}
220-
221222
addModule(credentialsProviderModule("testing", TestingCredentialsRolesProvider.class, (binder) -> binder.bind(TestingCredentialsRolesProvider.class).in(Scopes.SINGLETON)));
222223
withProperty("credentials-provider.type", "testing");
223224
addModule(assumedRoleProviderModule("testing", TestingCredentialsRolesProvider.class, (binder) -> binder.bind(TestingCredentialsRolesProvider.class).in(Scopes.SINGLETON)));
224225
withProperty("assumed-role-provider.type", "testing");
226+
}
227+
228+
if (addTestingRemoteS3CredentialsProvider) {
225229
addModule(remoteS3ConnectionProviderModule("testing", TestingCredentialsRolesProvider.class,
226-
binder -> binder.bind(TestingCredentialsInitializer.class).in(Scopes.SINGLETON)));
230+
binder -> binder.bind(TestingCredentialsRolesProvider.class).in(Scopes.SINGLETON)));
227231
withProperty("remote-s3-connection-provider.type", "testing");
228232
}
229233

230234
return start(modules.build(), properties.buildKeepingLast());
231235
}
232236
}
233237

234-
static class TestingCredentialsInitializer
235-
{
236-
@Inject
237-
TestingCredentialsInitializer(TestingCredentialsRolesProvider credentialsController)
238-
{
239-
credentialsController.addCredentials(TESTING_IDENTITY_CREDENTIAL);
240-
credentialsController.setDefaultRemoteConnection(new StaticRemoteS3Connection(TESTING_REMOTE_CREDENTIAL));
241-
}
242-
}
243-
244238
private static TestingTrinoAwsProxyServer start(Collection<Module> extraModules, Map<String, String> properties)
245239
{
246240
ImmutableList.Builder<Module> modules = ImmutableList.<Module>builder()

0 commit comments

Comments
 (0)