diff --git a/src/main/java/com/example/solidconnection/chat/config/WebSocketHandshakeInterceptor.java b/src/main/java/com/example/solidconnection/chat/config/WebSocketHandshakeInterceptor.java index 9e8aafe2d..e4af7a412 100644 --- a/src/main/java/com/example/solidconnection/chat/config/WebSocketHandshakeInterceptor.java +++ b/src/main/java/com/example/solidconnection/chat/config/WebSocketHandshakeInterceptor.java @@ -19,10 +19,9 @@ public boolean beforeHandshake(ServerHttpRequest request, ServerHttpResponse res if (principal != null) { attributes.put("user", principal); - return true; } - return false; + return true; } @Override diff --git a/src/main/java/com/example/solidconnection/security/filter/TokenAuthenticationFilter.java b/src/main/java/com/example/solidconnection/security/filter/TokenAuthenticationFilter.java index 0a81b860e..8c8dc8f30 100644 --- a/src/main/java/com/example/solidconnection/security/filter/TokenAuthenticationFilter.java +++ b/src/main/java/com/example/solidconnection/security/filter/TokenAuthenticationFilter.java @@ -28,16 +28,21 @@ public class TokenAuthenticationFilter extends OncePerRequestFilter { public void doFilterInternal(@NonNull HttpServletRequest request, @NonNull HttpServletResponse response, @NonNull FilterChain filterChain) throws ServletException, IOException { - Optional token = authorizationHeaderParser.parseToken(request); - if (token.isEmpty()) { - filterChain.doFilter(request, response); - return; - } + Optional resolvedToken = resolveToken(request); - TokenAuthentication authToken = new TokenAuthentication(token.get()); - Authentication auth = authenticationManager.authenticate(authToken); - SecurityContextHolder.getContext().setAuthentication(auth); + resolvedToken.filter(token -> !token.isBlank()).ifPresent(token -> { + TokenAuthentication authToken = new TokenAuthentication(token); + Authentication auth = authenticationManager.authenticate(authToken); + SecurityContextHolder.getContext().setAuthentication(auth); + }); filterChain.doFilter(request, response); } + + private Optional resolveToken(HttpServletRequest request) { + if (request.getRequestURI().startsWith("/connect")) { + return Optional.ofNullable(request.getParameter("token")); + } + return authorizationHeaderParser.parseToken(request); + } } diff --git a/src/test/java/com/example/solidconnection/websocket/WebSocketStompIntegrationTest.java b/src/test/java/com/example/solidconnection/websocket/WebSocketStompIntegrationTest.java index 978bfd717..b39c91ece 100644 --- a/src/test/java/com/example/solidconnection/websocket/WebSocketStompIntegrationTest.java +++ b/src/test/java/com/example/solidconnection/websocket/WebSocketStompIntegrationTest.java @@ -2,8 +2,7 @@ import static java.util.concurrent.TimeUnit.SECONDS; import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.ThrowableAssert.catchThrowable; -import static org.junit.jupiter.api.Assertions.assertAll; +import static org.assertj.core.api.Assertions.assertThatThrownBy; import com.example.solidconnection.auth.service.AccessToken; import com.example.solidconnection.auth.service.AuthTokenProvider; @@ -11,9 +10,8 @@ import com.example.solidconnection.siteuser.fixture.SiteUserFixture; import com.example.solidconnection.support.TestContainerSpringBootTest; import java.util.List; -import java.util.concurrent.ArrayBlockingQueue; -import java.util.concurrent.BlockingQueue; import java.util.concurrent.ExecutionException; +import java.util.concurrent.TimeUnit; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.DisplayName; @@ -22,11 +20,9 @@ import org.springframework.beans.factory.annotation.Autowired; import org.springframework.boot.test.web.server.LocalServerPort; import org.springframework.messaging.converter.MappingJackson2MessageConverter; -import org.springframework.messaging.simp.stomp.StompHeaders; import org.springframework.messaging.simp.stomp.StompSession; import org.springframework.messaging.simp.stomp.StompSessionHandlerAdapter; import org.springframework.web.client.HttpClientErrorException; -import org.springframework.web.socket.WebSocketHttpHeaders; import org.springframework.web.socket.client.standard.StandardWebSocketClient; import org.springframework.web.socket.messaging.WebSocketStompClient; import org.springframework.web.socket.sockjs.client.SockJsClient; @@ -67,48 +63,29 @@ void tearDown() { @Nested class WebSocket_핸드셰이크_및_STOMP_세션_수립_테스트 { - private final BlockingQueue transportErrorQueue = new ArrayBlockingQueue<>(1); - - private final StompSessionHandlerAdapter sessionHandler = new StompSessionHandlerAdapter() { - @Override - public void handleTransportError(StompSession session, Throwable exception) { - transportErrorQueue.add(exception); - } - }; - @Test void 인증된_사용자는_핸드셰이크를_성공한다() throws Exception { // given SiteUser user = siteUserFixture.사용자(); AccessToken accessToken = authTokenProvider.generateAccessToken(user); - - WebSocketHttpHeaders handshakeHeaders = new WebSocketHttpHeaders(); - handshakeHeaders.add("Authorization", "Bearer " + accessToken.token()); + String tokenUrl = url + "?token=" + accessToken.token(); // when - stompSession = stompClient.connectAsync(url, handshakeHeaders, new StompHeaders(), sessionHandler).get(5, SECONDS); + stompSession = stompClient.connectAsync(tokenUrl, new StompSessionHandlerAdapter() { + }).get(5, SECONDS); // then - assertAll( - () -> assertThat(stompSession).isNotNull(), - () -> assertThat(transportErrorQueue).isEmpty() - ); + assertThat(stompSession.isConnected()).isTrue(); } @Test void 인증되지_않은_사용자는_핸드셰이크를_실패한다() { - // when - Throwable thrown = catchThrowable(() -> { - stompSession = stompClient.connectAsync(url, new WebSocketHttpHeaders(), new StompHeaders(), sessionHandler).get(5, SECONDS); - }); - - // then - assertAll( - () -> assertThat(thrown) - .isInstanceOf(ExecutionException.class) - .hasCauseInstanceOf(HttpClientErrorException.Unauthorized.class), - () -> assertThat(transportErrorQueue).hasSize(1) - ); + // when & then + assertThatThrownBy(() -> { + stompClient.connectAsync(url, new StompSessionHandlerAdapter() { + }).get(5, TimeUnit.SECONDS); + }).isInstanceOf(ExecutionException.class) + .hasCauseInstanceOf(HttpClientErrorException.Unauthorized.class); } } }