Skip to content

Commit 4918592

Browse files
committed
Add RSocketServiceMethod support for suspending functions
See #34868 Signed-off-by: Dmitry Sulman <[email protected]>
1 parent 9a10b04 commit 4918592

File tree

3 files changed

+164
-3
lines changed

3 files changed

+164
-3
lines changed

spring-messaging/src/main/java/org/springframework/messaging/rsocket/service/RSocketServiceMethod.java

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,9 +27,11 @@
2727

2828
import org.jspecify.annotations.Nullable;
2929
import org.reactivestreams.Publisher;
30+
import reactor.core.publisher.Flux;
3031
import reactor.core.publisher.Mono;
3132

3233
import org.springframework.core.DefaultParameterNameDiscoverer;
34+
import org.springframework.core.KotlinDetector;
3335
import org.springframework.core.MethodParameter;
3436
import org.springframework.core.ParameterizedTypeReference;
3537
import org.springframework.core.ReactiveAdapter;
@@ -54,6 +56,8 @@
5456
*/
5557
final class RSocketServiceMethod {
5658

59+
private static final String COROUTINES_FLOW_CLASS_NAME = "kotlinx.coroutines.flow.Flow";
60+
5761
private final Method method;
5862

5963
private final MethodParameter[] parameters;
@@ -82,6 +86,10 @@ private static MethodParameter[] initMethodParameters(Method method) {
8286
if (count == 0) {
8387
return new MethodParameter[0];
8488
}
89+
if (KotlinDetector.isSuspendingFunction(method)) {
90+
count -= 1;
91+
}
92+
8593
DefaultParameterNameDiscoverer nameDiscoverer = new DefaultParameterNameDiscoverer();
8694
MethodParameter[] parameters = new MethodParameter[count];
8795
for (int i = 0; i < count; i++) {
@@ -129,10 +137,19 @@ private static Function<RSocketRequestValues, Object> initResponseFunction(
129137

130138
MethodParameter returnParam = new MethodParameter(method, -1);
131139
Class<?> returnType = returnParam.getParameterType();
140+
boolean isFlowReturnType = COROUTINES_FLOW_CLASS_NAME.equals(returnType.getName());
141+
boolean isUnwrapped = KotlinDetector.isSuspendingFunction(method) && !isFlowReturnType;
142+
if (isUnwrapped) {
143+
returnType = Mono.class;
144+
}
145+
else if (isFlowReturnType) {
146+
returnType = Flux.class;
147+
}
148+
132149
ReactiveAdapter reactiveAdapter = reactiveRegistry.getAdapter(returnType);
133150

134151
MethodParameter actualParam = (reactiveAdapter != null ? returnParam.nested() : returnParam.nestedIfOptional());
135-
Class<?> actualType = actualParam.getNestedParameterType();
152+
Class<?> actualType = isUnwrapped ? actualParam.getParameterType() : actualParam.getNestedParameterType();
136153

137154
Function<RSocketRequestValues, Publisher<?>> responseFunction;
138155
if (ClassUtils.isVoidType(actualType) || (reactiveAdapter != null && reactiveAdapter.isNoValue())) {
@@ -147,7 +164,8 @@ else if (reactiveAdapter == null) {
147164
}
148165
else {
149166
ParameterizedTypeReference<?> payloadType =
150-
ParameterizedTypeReference.forType(actualParam.getNestedGenericParameterType());
167+
ParameterizedTypeReference.forType(isUnwrapped ? actualParam.getGenericParameterType() :
168+
actualParam.getNestedGenericParameterType());
151169

152170
responseFunction = values -> (
153171
reactiveAdapter.isMultiValue() ?

spring-messaging/src/main/java/org/springframework/messaging/rsocket/service/RSocketServiceProxyFactory.java

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131

3232
import org.springframework.aop.framework.ProxyFactory;
3333
import org.springframework.aop.framework.ReflectiveMethodInvocation;
34+
import org.springframework.core.KotlinDetector;
3435
import org.springframework.core.MethodIntrospector;
3536
import org.springframework.core.ReactiveAdapterRegistry;
3637
import org.springframework.core.annotation.AnnotatedElementUtils;
@@ -246,7 +247,9 @@ private ServiceMethodInterceptor(List<RSocketServiceMethod> methods) {
246247
Method method = invocation.getMethod();
247248
RSocketServiceMethod serviceMethod = this.serviceMethods.get(method);
248249
if (serviceMethod != null) {
249-
return serviceMethod.invoke(invocation.getArguments());
250+
@Nullable Object[] arguments = KotlinDetector.isSuspendingFunction(method) ?
251+
resolveCoroutinesArguments(invocation.getArguments()) : invocation.getArguments();
252+
return serviceMethod.invoke(arguments);
250253
}
251254
if (method.isDefault()) {
252255
if (invocation instanceof ReflectiveMethodInvocation reflectiveMethodInvocation) {
@@ -256,6 +259,12 @@ private ServiceMethodInterceptor(List<RSocketServiceMethod> methods) {
256259
}
257260
throw new IllegalStateException("Unexpected method invocation: " + method);
258261
}
262+
263+
private static Object[] resolveCoroutinesArguments(@Nullable Object[] args) {
264+
Object[] functionArgs = new Object[args.length - 1];
265+
System.arraycopy(args, 0, functionArgs, 0, args.length - 1);
266+
return functionArgs;
267+
}
259268
}
260269

261270
}
Lines changed: 134 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,134 @@
1+
/*
2+
* Copyright 2002-present the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package org.springframework.messaging.rsocket.service
18+
19+
import io.rsocket.util.DefaultPayload
20+
import kotlinx.coroutines.flow.Flow
21+
import kotlinx.coroutines.flow.flowOf
22+
import kotlinx.coroutines.flow.map
23+
import kotlinx.coroutines.flow.toList
24+
import kotlinx.coroutines.reactive.asFlow
25+
import kotlinx.coroutines.runBlocking
26+
import org.assertj.core.api.Assertions.assertThat
27+
import org.junit.jupiter.api.BeforeEach
28+
import org.junit.jupiter.api.Test
29+
import org.springframework.messaging.rsocket.RSocketRequester
30+
import org.springframework.messaging.rsocket.RSocketStrategies
31+
import org.springframework.messaging.rsocket.TestRSocket
32+
import org.springframework.util.MimeTypeUtils.TEXT_PLAIN
33+
import reactor.core.publisher.Flux
34+
import reactor.core.publisher.Mono
35+
36+
/**
37+
* Kotlin tests for [RSocketServiceMethod].
38+
*
39+
* @author Dmitry Sulman
40+
*/
41+
class RSocketServiceMethodKotlinTests {
42+
43+
private lateinit var rsocket: TestRSocket
44+
45+
private lateinit var proxyFactory: RSocketServiceProxyFactory
46+
47+
@BeforeEach
48+
fun setUp() {
49+
rsocket = TestRSocket()
50+
val requester = RSocketRequester.wrap(rsocket, TEXT_PLAIN, TEXT_PLAIN, RSocketStrategies.create())
51+
proxyFactory = RSocketServiceProxyFactory.builder(requester).build()
52+
}
53+
54+
@Test
55+
fun fireAndForget(): Unit = runBlocking {
56+
val service = proxyFactory.createClient(SuspendingFunctionsService::class.java)
57+
58+
val requestPayload = "request"
59+
service.fireAndForget(requestPayload)
60+
61+
assertThat(rsocket.savedMethodName).isEqualTo("fireAndForget")
62+
assertThat(rsocket.savedPayload?.metadataUtf8).isEqualTo("ff")
63+
assertThat(rsocket.savedPayload?.dataUtf8).isEqualTo(requestPayload)
64+
}
65+
66+
@Test
67+
fun requestResponse(): Unit = runBlocking {
68+
val service = proxyFactory.createClient(SuspendingFunctionsService::class.java)
69+
70+
val requestPayload = "request"
71+
val responsePayload = "response"
72+
rsocket.setPayloadMonoToReturn(Mono.just(DefaultPayload.create(responsePayload)))
73+
val response = service.requestResponse(requestPayload)
74+
75+
assertThat(response).isEqualTo(responsePayload)
76+
assertThat(rsocket.savedMethodName).isEqualTo("requestResponse")
77+
assertThat(rsocket.savedPayload?.metadataUtf8).isEqualTo("rr")
78+
assertThat(rsocket.savedPayload?.dataUtf8).isEqualTo(requestPayload)
79+
}
80+
81+
@Test
82+
fun requestStream(): Unit = runBlocking {
83+
val service = proxyFactory.createClient(SuspendingFunctionsService::class.java)
84+
85+
val requestPayload = "request"
86+
val responsePayload1 = "response1"
87+
val responsePayload2 = "response2"
88+
rsocket.setPayloadFluxToReturn(
89+
Flux.just(DefaultPayload.create(responsePayload1), DefaultPayload.create(responsePayload2)))
90+
val response = service.requestStream(requestPayload).toList()
91+
92+
assertThat(response).containsExactly(responsePayload1, responsePayload2)
93+
assertThat(rsocket.savedMethodName).isEqualTo("requestStream")
94+
assertThat(rsocket.savedPayload?.metadataUtf8).isEqualTo("rs")
95+
assertThat(rsocket.savedPayload?.dataUtf8).isEqualTo(requestPayload)
96+
}
97+
98+
@Test
99+
fun requestChannel(): Unit = runBlocking {
100+
val service = proxyFactory.createClient(SuspendingFunctionsService::class.java)
101+
102+
val requestPayload1 = "request1"
103+
val requestPayload2 = "request2"
104+
val responsePayload1 = "response1"
105+
val responsePayload2 = "response2"
106+
rsocket.setPayloadFluxToReturn(
107+
Flux.just(DefaultPayload.create(responsePayload1), DefaultPayload.create(responsePayload2)))
108+
val response = service.requestChannel(flowOf(requestPayload1, requestPayload2)).toList()
109+
110+
assertThat(response).containsExactly(responsePayload1, responsePayload2)
111+
assertThat(rsocket.savedMethodName).isEqualTo("requestChannel")
112+
113+
val savedPayloads = rsocket.savedPayloadFlux
114+
?.asFlow()
115+
?.map { it.dataUtf8 }
116+
?.toList()
117+
assertThat(savedPayloads).containsExactly(requestPayload1, requestPayload2)
118+
}
119+
120+
private interface SuspendingFunctionsService {
121+
122+
@RSocketExchange("ff")
123+
suspend fun fireAndForget(input: String)
124+
125+
@RSocketExchange("rr")
126+
suspend fun requestResponse(input: String): String
127+
128+
@RSocketExchange("rs")
129+
suspend fun requestStream(input: String): Flow<String>
130+
131+
@RSocketExchange("rc")
132+
suspend fun requestChannel(input: Flow<String>): Flow<String>
133+
}
134+
}

0 commit comments

Comments
 (0)