diff --git a/src/main/java/com/pitchain/common/collector/RoleRequestCollector.java b/src/main/java/com/pitchain/common/collector/RoleRequestCollector.java index 210836e..c4b795c 100644 --- a/src/main/java/com/pitchain/common/collector/RoleRequestCollector.java +++ b/src/main/java/com/pitchain/common/collector/RoleRequestCollector.java @@ -3,7 +3,6 @@ import com.pitchain.common.annotation.RequiredRole; import com.pitchain.common.constant.MemberRole; import org.springframework.context.ApplicationContext; -import org.springframework.core.annotation.AnnotatedElementUtils; import org.springframework.http.HttpMethod; import org.springframework.stereotype.Component; import org.springframework.web.bind.annotation.RequestMethod; @@ -37,8 +36,7 @@ public RoleRequestCollector(ApplicationContext applicationContext) { private void collectRoleUris(Map.Entry entry, Map>> roleUriMap) { HandlerMethod handlerMethod = entry.getValue(); - RequiredRole requiredRole = AnnotatedElementUtils.findMergedAnnotation(handlerMethod.getMethod(), RequiredRole.class); - + RequiredRole requiredRole = handlerMethod.getMethod().getAnnotation(RequiredRole.class); if (requiredRole != null) { MemberRole memberRole = requiredRole.value(); Map> uriMap = roleUriMap.get(memberRole); diff --git a/src/main/java/com/pitchain/common/config/SecurityConfig.java b/src/main/java/com/pitchain/common/config/SecurityConfig.java index 55511b7..d5401cc 100644 --- a/src/main/java/com/pitchain/common/config/SecurityConfig.java +++ b/src/main/java/com/pitchain/common/config/SecurityConfig.java @@ -1,7 +1,11 @@ package com.pitchain.common.config; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.pitchain.common.apiPayload.CustomResponse; +import com.pitchain.common.apiPayload.ErrorStatus; import com.pitchain.common.collector.RoleRequestCollector; import com.pitchain.common.filter.JwtAuthenticationFilter; +import jakarta.servlet.http.HttpServletResponse; import lombok.RequiredArgsConstructor; import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Configuration; @@ -11,12 +15,15 @@ import org.springframework.security.config.http.SessionCreationPolicy; import org.springframework.security.crypto.bcrypt.BCryptPasswordEncoder; import org.springframework.security.crypto.password.PasswordEncoder; +import org.springframework.security.web.AuthenticationEntryPoint; import org.springframework.security.web.SecurityFilterChain; +import org.springframework.security.web.access.AccessDeniedHandler; import org.springframework.security.web.authentication.UsernamePasswordAuthenticationFilter; import org.springframework.web.cors.CorsConfiguration; import org.springframework.web.cors.CorsConfigurationSource; import org.springframework.web.cors.UrlBasedCorsConfigurationSource; +import java.io.IOException; import java.util.Arrays; import java.util.List; @@ -29,17 +36,6 @@ public class SecurityConfig { private final JwtAuthenticationFilter jwtAuthenticationFilter; private final RoleRequestCollector roleRequestCollector; - public static final String[] whitelist = { - "/oauth**", - "/resources/**", "/favicon.ico", // resource - "/swagger-ui/**", "/api-docs/**", "/v3/api-docs**", "/v3/api-docs/**", // swagger - "/dev/**", // 개발용, - "/health-check", // health check - - "/members/tokens", // 공통 유저 - "/companies", "/companies/login", "/companies/emails"// 회사 - }; - @Bean public PasswordEncoder passwordEncoder() { return new BCryptPasswordEncoder(); @@ -55,17 +51,43 @@ public SecurityFilterChain filterChain(HttpSecurity http) throws Exception { http.addFilterBefore(jwtAuthenticationFilter, UsernamePasswordAuthenticationFilter.class); http.authorizeHttpRequests(auth -> { - auth.requestMatchers(whitelist).permitAll(); roleRequestCollector.getRoleUriMap().forEach((role, methodUriMap) -> { methodUriMap.forEach((httpMethod, uriSet) -> { auth.requestMatchers(httpMethod, uriSet.toArray(new String[0])).hasAnyAuthority(role.getRoles()); }); }); + auth.requestMatchers(SWAGGER_PATTERNS).permitAll(); + auth.requestMatchers(STATIC_RESOURCES_PATTERNS).permitAll(); + auth.requestMatchers(PUBLIC_ENDPOINTS).permitAll(); auth.anyRequest().authenticated(); }); + http.exceptionHandling(exception -> { + exception.authenticationEntryPoint(customAuthenticationEntryPoint()); + exception.accessDeniedHandler(customAccessDeniedHandler()); + }); return http.build(); } + private static final String[] SWAGGER_PATTERNS = { + "/swagger-ui/**", + "/v3/api-docs/**", + }; + + private static final String[] STATIC_RESOURCES_PATTERNS = { + "/img/**", + "/css/**", + "/js/**", + "/favicon.ico", + }; + + private static final String[] PUBLIC_ENDPOINTS = { + "/health-check", // health check + "/oauth**", + "/members/tokens", "/members/emails", // 공통 유저 + "/companies", "/companies/login", // 회사 + "/dev/**", // 개발용 + }; + @Bean public CorsConfigurationSource corsConfigurationSource() { CorsConfiguration config = new CorsConfiguration(); @@ -79,4 +101,30 @@ public CorsConfigurationSource corsConfigurationSource() { source.registerCorsConfiguration("/**", config); return source; } + + @Bean + public AuthenticationEntryPoint customAuthenticationEntryPoint() { + return (request, response, authException) -> { + final String message = "유효한 인증 정보가 없거나, 존재하지 않는 API를 요청하셨습니다."; + writeErrorResponse(response, message); + }; + } + + @Bean + public AccessDeniedHandler customAccessDeniedHandler() { + return (request, response, accessDeniedException) -> { + final String message = "요청하신 API에 대한 접근 권한이 없습니다."; + writeErrorResponse(response, message); + }; + } + + private static void writeErrorResponse(HttpServletResponse response, String message) throws IOException { + ErrorStatus errorStatus = ErrorStatus._FORBIDDEN; + CustomResponse customResponse = CustomResponse.onFailure(errorStatus.getCode(), message); + + response.setStatus(errorStatus.getHttpStatus().value()); + response.setContentType("application/json"); + response.setCharacterEncoding("UTF-8"); + response.getWriter().write(new ObjectMapper().writeValueAsString(customResponse)); + } } diff --git a/src/main/java/com/pitchain/common/constant/MemberRole.java b/src/main/java/com/pitchain/common/constant/MemberRole.java index 80bacdf..67ec0b1 100644 --- a/src/main/java/com/pitchain/common/constant/MemberRole.java +++ b/src/main/java/com/pitchain/common/constant/MemberRole.java @@ -8,17 +8,17 @@ import org.springframework.security.core.authority.SimpleGrantedAuthority; import java.util.Collection; -import java.util.List; +import java.util.EnumSet; @Getter @RequiredArgsConstructor public enum MemberRole { - MEMBER(new String[]{"ROLE_INDIVIDUAL", "ROLE_COMPANY"}), - INDIVIDUAL(new String[]{"ROLE_INDIVIDUAL"}), - COMPANY(new String[]{"ROLE_COMPANY"}), + MEMBER(EnumSet.of(RoleType.INDIVIDUAL, RoleType.COMPANY)), + INDIVIDUAL(EnumSet.of(RoleType.INDIVIDUAL)), + COMPANY(EnumSet.of(RoleType.COMPANY)), ; - private final String[] roles; + private final EnumSet roleTypes; public static MemberRole toEnum(String role) { try { @@ -29,8 +29,23 @@ public static MemberRole toEnum(String role) { } public Collection getAuthorities() { - return List.of(roles).stream() - .map(SimpleGrantedAuthority::new) + return roleTypes.stream() + .map(roleType -> new SimpleGrantedAuthority(roleType.getRoleName())) .toList(); } + + public String[] getRoles() { + return roleTypes.stream().map(RoleType::getRoleName).toArray(String[]::new); + } + + @Getter + @RequiredArgsConstructor + public enum RoleType { + INDIVIDUAL("ROLE_INDIVIDUAL"), + COMPANY("ROLE_COMPANY"), + ; + + private final String roleName; + } + } diff --git a/src/main/java/com/pitchain/common/filter/JwtAuthenticationFilter.java b/src/main/java/com/pitchain/common/filter/JwtAuthenticationFilter.java index 53e76da..4f26ebb 100644 --- a/src/main/java/com/pitchain/common/filter/JwtAuthenticationFilter.java +++ b/src/main/java/com/pitchain/common/filter/JwtAuthenticationFilter.java @@ -1,12 +1,11 @@ package com.pitchain.common.filter; -import com.auth0.jwt.interfaces.DecodedJWT; -import com.pitchain.common.apiPayload.ErrorStatus; +import com.auth0.jwt.exceptions.JWTVerificationException; import com.pitchain.common.constant.MemberRole; import com.pitchain.common.constant.TokenType; -import com.pitchain.common.exception.GeneralException; -import com.pitchain.common.security.MemberDetails; import com.pitchain.common.redis.RedisTokenUtil; +import com.pitchain.common.security.MemberClaims; +import com.pitchain.common.security.MemberDetails; import jakarta.servlet.FilterChain; import jakarta.servlet.ServletException; import jakarta.servlet.http.HttpServletRequest; @@ -17,7 +16,6 @@ import org.springframework.security.core.Authentication; import org.springframework.security.core.context.SecurityContextHolder; import org.springframework.stereotype.Component; -import org.springframework.util.PatternMatchUtils; import org.springframework.web.filter.OncePerRequestFilter; import java.io.IOException; @@ -28,37 +26,42 @@ public class JwtAuthenticationFilter extends OncePerRequestFilter { private final RedisTokenUtil redisTokenUtil; - public static final String[] whitelist = { - "/oauth2/**", - "/resources/**", "/favicon.ico", // resource - "/swagger-ui/**", "/api-docs/**", "/v3/api-docs**", "/v3/api-docs/**", // swagger - "/health-check", // health check - "/dev/**", // 개발용, - "/members/tokens", // 공통 유저 - "/companies", "/companies/login", "/companies/emails"// 회사 - }; - @Override - protected boolean shouldNotFilter(HttpServletRequest request) { - return PatternMatchUtils.simpleMatch(whitelist, request.getRequestURI()); + protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response, FilterChain filterChain) throws ServletException, IOException { + try { + verifyAccessToken(request); + } catch (JWTVerificationException e1) { + try { + verifyRefreshToken(request, response); + } catch (JWTVerificationException e2) { + filterChain.doFilter(request, response); + return; + } + } + + filterChain.doFilter(request, response); } - @Override - protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response, FilterChain filterChain) throws ServletException, IOException { - String token = redisTokenUtil.extractToken(request, TokenType.ACCESS_TOKEN); + private void verifyAccessToken(HttpServletRequest request) { + String accessToken = redisTokenUtil.extractToken(request, TokenType.ACCESS_TOKEN); + MemberClaims claim = redisTokenUtil.getClaim(accessToken); + setAuthentication(claim); + } - if (token == null) - throw new GeneralException(ErrorStatus.TOKEN_MISSING); + private void verifyRefreshToken(HttpServletRequest request, HttpServletResponse response) { + String refreshToken = redisTokenUtil.extractToken(request, TokenType.REFRESH_TOKEN); + MemberClaims claims = redisTokenUtil.getClaim(refreshToken); + redisTokenUtil.reissueToken(response, claims); + setAuthentication(claims); + } - DecodedJWT decodedJWT = redisTokenUtil.decodedJWT(token); - Long id = decodedJWT.getClaim("id").asLong(); - String role = decodedJWT.getClaim("role").asString(); + private void setAuthentication(MemberClaims claims) { + Long id = claims.getId(); + MemberRole memberRole = claims.getMemberRole(); - MemberRole memberRole = MemberRole.toEnum(role); MemberDetails memberDetails = new MemberDetails(id, memberRole); Authentication authentication = new UsernamePasswordAuthenticationToken(memberDetails, null, memberRole.getAuthorities()); SecurityContextHolder.getContext().setAuthentication(authentication); - doFilter(request, response, filterChain); } } diff --git a/src/main/java/com/pitchain/common/redis/RedisTokenUtil.java b/src/main/java/com/pitchain/common/redis/RedisTokenUtil.java index f375936..9360229 100644 --- a/src/main/java/com/pitchain/common/redis/RedisTokenUtil.java +++ b/src/main/java/com/pitchain/common/redis/RedisTokenUtil.java @@ -11,6 +11,7 @@ import com.pitchain.common.exception.GeneralException; import com.pitchain.common.security.MemberClaims; import jakarta.servlet.http.HttpServletRequest; +import jakarta.servlet.http.HttpServletResponse; import lombok.Getter; import lombok.RequiredArgsConstructor; import lombok.extern.slf4j.Slf4j; @@ -49,7 +50,6 @@ public String issueAccessToken(Long memberId, MemberRole memberRole) { .sign(Algorithm.HMAC512(secretKey)); } - // todo 프로토타입 시연을 위한 임시 메소드 public String issueAccessTokenWithoutExpiration(Long memberId, MemberRole memberRole) { return JWT.create() .withSubject(ACCESS_TOKEN_SUBJECT) @@ -119,14 +119,21 @@ public DecodedJWT decodedJWT(String accessToken) { } } - public MemberClaims getClaim(String refreshToken) { - try { - DecodedJWT decodedJWT = JWT.require(Algorithm.HMAC512(secretKey)).build().verify(refreshToken); - Long id = decodedJWT.getClaim("id").asLong(); - MemberRole memberRole = MemberRole.valueOf(decodedJWT.getClaim("role").asString()); - return new MemberClaims(id, memberRole); - } catch (JWTVerificationException e) { - throw new GeneralException(ErrorStatus._UNAUTHORIZED); - } + public MemberClaims getClaim(String token) { + DecodedJWT decodedJWT = JWT.require(Algorithm.HMAC512(secretKey)).build().verify(token); + Long id = decodedJWT.getClaim("id").asLong(); + MemberRole memberRole = MemberRole.toEnum(decodedJWT.getClaim("role").asString()); + return new MemberClaims(id, memberRole); + } + + public void reissueToken(HttpServletResponse response, MemberClaims claims) { + Long userId = claims.getId(); + MemberRole memberRole = claims.getMemberRole(); + + String accessToken = issueAccessToken(userId, memberRole); + String refreshToken = issueRefreshToken(userId, memberRole); + + response.setHeader(accessHeader, BEARER + accessToken); + response.setHeader(refreshHeader, BEARER + refreshToken); } } diff --git a/src/main/java/com/pitchain/investment/presentation/InvestmentController.java b/src/main/java/com/pitchain/investment/presentation/InvestmentController.java index ff9db87..b0ca692 100644 --- a/src/main/java/com/pitchain/investment/presentation/InvestmentController.java +++ b/src/main/java/com/pitchain/investment/presentation/InvestmentController.java @@ -27,7 +27,7 @@ public void addInvestment(@PathVariable("bmId") Long bmId, investmentService.addInvestment(bmId, memberDetails, amount); } - @GetMapping("/{bmId}/investment") + @GetMapping("/{bmId}/investments") public InvestmentStatusRes getInvestmentStatus(@PathVariable("bmId") Long bmId) { InvestmentStatusRes investmentStatusRes = investmentService.getInvestmentStatus(bmId); return investmentStatusRes;