Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
import com.pitchain.common.annotation.RequiredRole;
import com.pitchain.common.constant.MemberRole;
import org.springframework.context.ApplicationContext;
import org.springframework.core.annotation.AnnotatedElementUtils;
import org.springframework.http.HttpMethod;
import org.springframework.stereotype.Component;
import org.springframework.web.bind.annotation.RequestMethod;
Expand Down Expand Up @@ -37,8 +36,7 @@ public RoleRequestCollector(ApplicationContext applicationContext) {

private void collectRoleUris(Map.Entry<RequestMappingInfo, HandlerMethod> entry, Map<MemberRole, Map<HttpMethod, Set<String>>> roleUriMap) {
HandlerMethod handlerMethod = entry.getValue();
RequiredRole requiredRole = AnnotatedElementUtils.findMergedAnnotation(handlerMethod.getMethod(), RequiredRole.class);

RequiredRole requiredRole = handlerMethod.getMethod().getAnnotation(RequiredRole.class);
if (requiredRole != null) {
MemberRole memberRole = requiredRole.value();
Map<HttpMethod, Set<String>> uriMap = roleUriMap.get(memberRole);
Expand Down
72 changes: 60 additions & 12 deletions src/main/java/com/pitchain/common/config/SecurityConfig.java
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
package com.pitchain.common.config;

import com.fasterxml.jackson.databind.ObjectMapper;
import com.pitchain.common.apiPayload.CustomResponse;
import com.pitchain.common.apiPayload.ErrorStatus;
import com.pitchain.common.collector.RoleRequestCollector;
import com.pitchain.common.filter.JwtAuthenticationFilter;
import jakarta.servlet.http.HttpServletResponse;
import lombok.RequiredArgsConstructor;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
Expand All @@ -11,12 +15,15 @@
import org.springframework.security.config.http.SessionCreationPolicy;
import org.springframework.security.crypto.bcrypt.BCryptPasswordEncoder;
import org.springframework.security.crypto.password.PasswordEncoder;
import org.springframework.security.web.AuthenticationEntryPoint;
import org.springframework.security.web.SecurityFilterChain;
import org.springframework.security.web.access.AccessDeniedHandler;
import org.springframework.security.web.authentication.UsernamePasswordAuthenticationFilter;
import org.springframework.web.cors.CorsConfiguration;
import org.springframework.web.cors.CorsConfigurationSource;
import org.springframework.web.cors.UrlBasedCorsConfigurationSource;

import java.io.IOException;
import java.util.Arrays;
import java.util.List;

Expand All @@ -29,17 +36,6 @@ public class SecurityConfig {
private final JwtAuthenticationFilter jwtAuthenticationFilter;
private final RoleRequestCollector roleRequestCollector;

public static final String[] whitelist = {
"/oauth**",
"/resources/**", "/favicon.ico", // resource
"/swagger-ui/**", "/api-docs/**", "/v3/api-docs**", "/v3/api-docs/**", // swagger
"/dev/**", // 개발용,
"/health-check", // health check

"/members/tokens", // 공통 유저
"/companies", "/companies/login", "/companies/emails"// 회사
};

@Bean
public PasswordEncoder passwordEncoder() {
return new BCryptPasswordEncoder();
Expand All @@ -55,17 +51,43 @@ public SecurityFilterChain filterChain(HttpSecurity http) throws Exception {
http.addFilterBefore(jwtAuthenticationFilter, UsernamePasswordAuthenticationFilter.class);

http.authorizeHttpRequests(auth -> {
auth.requestMatchers(whitelist).permitAll();
roleRequestCollector.getRoleUriMap().forEach((role, methodUriMap) -> {
methodUriMap.forEach((httpMethod, uriSet) -> {
auth.requestMatchers(httpMethod, uriSet.toArray(new String[0])).hasAnyAuthority(role.getRoles());
});
});
auth.requestMatchers(SWAGGER_PATTERNS).permitAll();
auth.requestMatchers(STATIC_RESOURCES_PATTERNS).permitAll();
auth.requestMatchers(PUBLIC_ENDPOINTS).permitAll();
auth.anyRequest().authenticated();
});
http.exceptionHandling(exception -> {
exception.authenticationEntryPoint(customAuthenticationEntryPoint());
exception.accessDeniedHandler(customAccessDeniedHandler());
});
return http.build();
}

private static final String[] SWAGGER_PATTERNS = {
"/swagger-ui/**",
"/v3/api-docs/**",
};

private static final String[] STATIC_RESOURCES_PATTERNS = {
"/img/**",
"/css/**",
"/js/**",
"/favicon.ico",
};

private static final String[] PUBLIC_ENDPOINTS = {
"/health-check", // health check
"/oauth**",
"/members/tokens", "/members/emails", // 공통 유저
"/companies", "/companies/login", // 회사
"/dev/**", // 개발용
};

@Bean
public CorsConfigurationSource corsConfigurationSource() {
CorsConfiguration config = new CorsConfiguration();
Expand All @@ -79,4 +101,30 @@ public CorsConfigurationSource corsConfigurationSource() {
source.registerCorsConfiguration("/**", config);
return source;
}

@Bean
public AuthenticationEntryPoint customAuthenticationEntryPoint() {
return (request, response, authException) -> {
final String message = "유효한 인증 정보가 없거나, 존재하지 않는 API를 요청하셨습니다.";
writeErrorResponse(response, message);
};
}

@Bean
public AccessDeniedHandler customAccessDeniedHandler() {
return (request, response, accessDeniedException) -> {
final String message = "요청하신 API에 대한 접근 권한이 없습니다.";
writeErrorResponse(response, message);
};
}

private static void writeErrorResponse(HttpServletResponse response, String message) throws IOException {
ErrorStatus errorStatus = ErrorStatus._FORBIDDEN;
CustomResponse customResponse = CustomResponse.onFailure(errorStatus.getCode(), message);

response.setStatus(errorStatus.getHttpStatus().value());
response.setContentType("application/json");
response.setCharacterEncoding("UTF-8");
response.getWriter().write(new ObjectMapper().writeValueAsString(customResponse));
}
}
29 changes: 22 additions & 7 deletions src/main/java/com/pitchain/common/constant/MemberRole.java
Original file line number Diff line number Diff line change
Expand Up @@ -8,17 +8,17 @@
import org.springframework.security.core.authority.SimpleGrantedAuthority;

import java.util.Collection;
import java.util.List;
import java.util.EnumSet;

@Getter
@RequiredArgsConstructor
public enum MemberRole {
MEMBER(new String[]{"ROLE_INDIVIDUAL", "ROLE_COMPANY"}),
INDIVIDUAL(new String[]{"ROLE_INDIVIDUAL"}),
COMPANY(new String[]{"ROLE_COMPANY"}),
MEMBER(EnumSet.of(RoleType.INDIVIDUAL, RoleType.COMPANY)),
INDIVIDUAL(EnumSet.of(RoleType.INDIVIDUAL)),
COMPANY(EnumSet.of(RoleType.COMPANY)),
;

private final String[] roles;
private final EnumSet<RoleType> roleTypes;

public static MemberRole toEnum(String role) {
try {
Expand All @@ -29,8 +29,23 @@ public static MemberRole toEnum(String role) {
}

public Collection<? extends GrantedAuthority> getAuthorities() {
return List.of(roles).stream()
.map(SimpleGrantedAuthority::new)
return roleTypes.stream()
.map(roleType -> new SimpleGrantedAuthority(roleType.getRoleName()))
.toList();
}

public String[] getRoles() {
return roleTypes.stream().map(RoleType::getRoleName).toArray(String[]::new);
}

@Getter
@RequiredArgsConstructor
public enum RoleType {
INDIVIDUAL("ROLE_INDIVIDUAL"),
COMPANY("ROLE_COMPANY"),
;

private final String roleName;
}

}
Original file line number Diff line number Diff line change
@@ -1,12 +1,11 @@
package com.pitchain.common.filter;

import com.auth0.jwt.interfaces.DecodedJWT;
import com.pitchain.common.apiPayload.ErrorStatus;
import com.auth0.jwt.exceptions.JWTVerificationException;
import com.pitchain.common.constant.MemberRole;
import com.pitchain.common.constant.TokenType;
import com.pitchain.common.exception.GeneralException;
import com.pitchain.common.security.MemberDetails;
import com.pitchain.common.redis.RedisTokenUtil;
import com.pitchain.common.security.MemberClaims;
import com.pitchain.common.security.MemberDetails;
import jakarta.servlet.FilterChain;
import jakarta.servlet.ServletException;
import jakarta.servlet.http.HttpServletRequest;
Expand All @@ -17,7 +16,6 @@
import org.springframework.security.core.Authentication;
import org.springframework.security.core.context.SecurityContextHolder;
import org.springframework.stereotype.Component;
import org.springframework.util.PatternMatchUtils;
import org.springframework.web.filter.OncePerRequestFilter;

import java.io.IOException;
Expand All @@ -28,37 +26,42 @@
public class JwtAuthenticationFilter extends OncePerRequestFilter {
private final RedisTokenUtil redisTokenUtil;

public static final String[] whitelist = {
"/oauth2/**",
"/resources/**", "/favicon.ico", // resource
"/swagger-ui/**", "/api-docs/**", "/v3/api-docs**", "/v3/api-docs/**", // swagger
"/health-check", // health check
"/dev/**", // 개발용,
"/members/tokens", // 공통 유저
"/companies", "/companies/login", "/companies/emails"// 회사
};

@Override
protected boolean shouldNotFilter(HttpServletRequest request) {
return PatternMatchUtils.simpleMatch(whitelist, request.getRequestURI());
protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response, FilterChain filterChain) throws ServletException, IOException {
try {
verifyAccessToken(request);
} catch (JWTVerificationException e1) {
try {
verifyRefreshToken(request, response);
} catch (JWTVerificationException e2) {
filterChain.doFilter(request, response);
return;
}
}

filterChain.doFilter(request, response);
}

@Override
protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response, FilterChain filterChain) throws ServletException, IOException {
String token = redisTokenUtil.extractToken(request, TokenType.ACCESS_TOKEN);
private void verifyAccessToken(HttpServletRequest request) {
String accessToken = redisTokenUtil.extractToken(request, TokenType.ACCESS_TOKEN);
MemberClaims claim = redisTokenUtil.getClaim(accessToken);
setAuthentication(claim);
}

if (token == null)
throw new GeneralException(ErrorStatus.TOKEN_MISSING);
private void verifyRefreshToken(HttpServletRequest request, HttpServletResponse response) {
String refreshToken = redisTokenUtil.extractToken(request, TokenType.REFRESH_TOKEN);
MemberClaims claims = redisTokenUtil.getClaim(refreshToken);
redisTokenUtil.reissueToken(response, claims);
setAuthentication(claims);
}

DecodedJWT decodedJWT = redisTokenUtil.decodedJWT(token);
Long id = decodedJWT.getClaim("id").asLong();
String role = decodedJWT.getClaim("role").asString();
private void setAuthentication(MemberClaims claims) {
Long id = claims.getId();
MemberRole memberRole = claims.getMemberRole();

MemberRole memberRole = MemberRole.toEnum(role);
MemberDetails memberDetails = new MemberDetails(id, memberRole);
Authentication authentication = new UsernamePasswordAuthenticationToken(memberDetails, null, memberRole.getAuthorities());

SecurityContextHolder.getContext().setAuthentication(authentication);
doFilter(request, response, filterChain);
}
}
27 changes: 17 additions & 10 deletions src/main/java/com/pitchain/common/redis/RedisTokenUtil.java
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import com.pitchain.common.exception.GeneralException;
import com.pitchain.common.security.MemberClaims;
import jakarta.servlet.http.HttpServletRequest;
import jakarta.servlet.http.HttpServletResponse;
import lombok.Getter;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
Expand Down Expand Up @@ -49,7 +50,6 @@ public String issueAccessToken(Long memberId, MemberRole memberRole) {
.sign(Algorithm.HMAC512(secretKey));
}

// todo 프로토타입 시연을 위한 임시 메소드
public String issueAccessTokenWithoutExpiration(Long memberId, MemberRole memberRole) {
return JWT.create()
.withSubject(ACCESS_TOKEN_SUBJECT)
Expand Down Expand Up @@ -119,14 +119,21 @@ public DecodedJWT decodedJWT(String accessToken) {
}
}

public MemberClaims getClaim(String refreshToken) {
try {
DecodedJWT decodedJWT = JWT.require(Algorithm.HMAC512(secretKey)).build().verify(refreshToken);
Long id = decodedJWT.getClaim("id").asLong();
MemberRole memberRole = MemberRole.valueOf(decodedJWT.getClaim("role").asString());
return new MemberClaims(id, memberRole);
} catch (JWTVerificationException e) {
throw new GeneralException(ErrorStatus._UNAUTHORIZED);
}
public MemberClaims getClaim(String token) {
DecodedJWT decodedJWT = JWT.require(Algorithm.HMAC512(secretKey)).build().verify(token);
Long id = decodedJWT.getClaim("id").asLong();
MemberRole memberRole = MemberRole.toEnum(decodedJWT.getClaim("role").asString());
return new MemberClaims(id, memberRole);
}

public void reissueToken(HttpServletResponse response, MemberClaims claims) {
Long userId = claims.getId();
MemberRole memberRole = claims.getMemberRole();

String accessToken = issueAccessToken(userId, memberRole);
String refreshToken = issueRefreshToken(userId, memberRole);

response.setHeader(accessHeader, BEARER + accessToken);
response.setHeader(refreshHeader, BEARER + refreshToken);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ public void addInvestment(@PathVariable("bmId") Long bmId,
investmentService.addInvestment(bmId, memberDetails, amount);
}

@GetMapping("/{bmId}/investment")
@GetMapping("/{bmId}/investments")
public InvestmentStatusRes getInvestmentStatus(@PathVariable("bmId") Long bmId) {
InvestmentStatusRes investmentStatusRes = investmentService.getInvestmentStatus(bmId);
return investmentStatusRes;
Expand Down