Skip to content

feat: migrate OrgExport RPC to ConnectRPC server #1080

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 9 commits into from
Jul 30, 2025
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 26 additions & 0 deletions internal/api/v1beta1connect/organization_billing.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"github.com/raystack/frontier/core/aggregates/orgbilling"
frontierv1beta1 "github.com/raystack/frontier/proto/v1beta1"
"github.com/raystack/salt/rql"
"google.golang.org/genproto/googleapis/api/httpbody"
"google.golang.org/protobuf/types/known/timestamppb"
)

Expand Down Expand Up @@ -58,6 +59,31 @@ func (h *ConnectHandler) SearchOrganizations(ctx context.Context, request *conne
}), nil
}

func (h *ConnectHandler) ExportOrganizations(ctx context.Context, request *connect.Request[frontierv1beta1.ExportOrganizationsRequest], stream *connect.ServerStream[httpbody.HttpBody]) error {
orgBillingDataBytes, contentType, err := h.orgBillingService.Export(ctx)
if err != nil {
return nil
}

chunkSize := 1024 * 200 // 200KB

for i := 0; i < len(orgBillingDataBytes); i += chunkSize {
end := min(i+chunkSize, len(orgBillingDataBytes))

chunk := orgBillingDataBytes[i:end]
msg := &httpbody.HttpBody{
ContentType: contentType,
Data: chunk,
}

err := stream.Send(msg)
if err != nil {
return connect.NewError(connect.CodeInternal, ErrInternalServerError)
}
}
return nil
}

func transformProtoToRQL(q *frontierv1beta1.RQLRequest) (*rql.Query, error) {
filters := make([]rql.Filter, 0)
for _, filter := range q.GetFilters() {
Expand Down
67 changes: 49 additions & 18 deletions pkg/server/connect_interceptors/authentication.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,26 +9,57 @@ import (
"github.com/raystack/frontier/internal/api/v1beta1connect"
)

func UnaryAuthenticationCheck(h *v1beta1connect.ConnectHandler) connect.UnaryInterceptorFunc {
interceptor := func(next connect.UnaryFunc) connect.UnaryFunc {
return connect.UnaryFunc(func(ctx context.Context, req connect.AnyRequest) (connect.AnyResponse, error) {
if authenticationSkipList[req.Spec().Procedure] {
return next(ctx, req)
}

principal, err := h.GetLoggedInPrincipal(ctx)
if err != nil {
return nil, err
}
ctx = authenticate.SetContextWithPrincipal(ctx, &principal)
ctx = audit.SetContextWithActor(ctx, audit.Actor{
ID: principal.ID,
Type: principal.Type,
})
type AuthenticationInterceptor struct {
h *v1beta1connect.ConnectHandler
}

func NewAuthenticationInterceptor(h *v1beta1connect.ConnectHandler) *AuthenticationInterceptor {
return &AuthenticationInterceptor{h}
}

func (i *AuthenticationInterceptor) WrapUnary(next connect.UnaryFunc) connect.UnaryFunc {
return connect.UnaryFunc(func(ctx context.Context, req connect.AnyRequest) (connect.AnyResponse, error) {
if authenticationSkipList[req.Spec().Procedure] {
return next(ctx, req)
}

principal, err := i.h.GetLoggedInPrincipal(ctx)
if err != nil {
return nil, err
}
ctx = authenticate.SetContextWithPrincipal(ctx, &principal)
ctx = audit.SetContextWithActor(ctx, audit.Actor{
ID: principal.ID,
Type: principal.Type,
})
return next(ctx, req)
})
}

func (i *AuthenticationInterceptor) WrapStreamingClient(next connect.StreamingClientFunc) connect.StreamingClientFunc {
return connect.StreamingClientFunc(func(ctx context.Context, spec connect.Spec) connect.StreamingClientConn {
conn := next(ctx, spec)
return conn
})
}

func (i *AuthenticationInterceptor) WrapStreamingHandler(next connect.StreamingHandlerFunc) connect.StreamingHandlerFunc {
return connect.StreamingHandlerFunc(func(ctx context.Context, conn connect.StreamingHandlerConn) error {
if authenticationSkipList[conn.Spec().Procedure] {
return next(ctx, conn)
}

principal, err := i.h.GetLoggedInPrincipal(ctx)
if err != nil {
return err
}
ctx = authenticate.SetContextWithPrincipal(ctx, &principal)
ctx = audit.SetContextWithActor(ctx, audit.Actor{
ID: principal.ID,
Type: principal.Type,
})
}
return connect.UnaryInterceptorFunc(interceptor)
return next(ctx, conn)
})
}

// authenticationSkipList stores path to skip authentication, by default its enabled for all requests
Expand Down
78 changes: 53 additions & 25 deletions pkg/server/connect_interceptors/authorization.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
"github.com/raystack/frontier/core/preference"
"github.com/raystack/frontier/core/relation"
"github.com/raystack/frontier/internal/bootstrap/schema"
"github.com/raystack/frontier/internal/metrics"

Check failure on line 16 in pkg/server/connect_interceptors/authorization.go

View workflow job for this annotation

GitHub Actions / golangci

"github.com/raystack/frontier/internal/metrics" imported and not used (typecheck)

Check failure on line 16 in pkg/server/connect_interceptors/authorization.go

View workflow job for this annotation

GitHub Actions / regression

"github.com/raystack/frontier/internal/metrics" imported and not used

Check failure on line 16 in pkg/server/connect_interceptors/authorization.go

View workflow job for this annotation

GitHub Actions / unit

"github.com/raystack/frontier/internal/metrics" imported and not used

Check failure on line 16 in pkg/server/connect_interceptors/authorization.go

View workflow job for this annotation

GitHub Actions / unit

"github.com/raystack/frontier/internal/metrics" imported and not used

Check failure on line 16 in pkg/server/connect_interceptors/authorization.go

View workflow job for this annotation

GitHub Actions / smoke

"github.com/raystack/frontier/internal/metrics" imported and not used
frontierv1beta1 "github.com/raystack/frontier/proto/v1beta1"
)

Expand All @@ -22,35 +22,63 @@
ErrDeniedInvalidArgs = connect.NewError(connect.CodePermissionDenied, errors.New("invalid arguments"))
)

// UnaryAuthorizationCheck returns a unary server interceptor that checks for authorization
func UnaryAuthorizationCheck(h *v1beta1connect.ConnectHandler) connect.UnaryInterceptorFunc {
interceptor := func(next connect.UnaryFunc) connect.UnaryFunc {
return connect.UnaryFunc(func(ctx context.Context, req connect.AnyRequest) (connect.AnyResponse, error) {
// check if authorization needs to be skipped
if authorizationSkipEndpoints[req.Spec().Procedure] {
return next(ctx, req)
}
type AuthorizationInterceptor struct {
h *v1beta1connect.ConnectHandler
}

if metrics.ServiceOprLatency != nil {
promCollect := metrics.ServiceOprLatency("authenticate", "UnaryAuthorizationCheck")
defer promCollect()
}
func (a *AuthorizationInterceptor) WrapStreamingClient(next connect.StreamingClientFunc) connect.StreamingClientFunc {
return connect.StreamingClientFunc(func(ctx context.Context, spec connect.Spec) connect.StreamingClientConn {
conn := next(ctx, spec)
return conn
})
}

// apply authorization rules
azFunc, azVerifier := authorizationValidationMap[req.Spec().Procedure]
if !azVerifier {
// deny access if not configured by default
// return nil, connect.NewError(codes.Unauthenticated, "unauthorized access")
return nil, connect.NewError(connect.CodePermissionDenied, v1beta1connect.ErrUnauthorized)
}
if err := azFunc(ctx, h, req); err != nil {
return nil, err
}
func (a *AuthorizationInterceptor) WrapStreamingHandler(next connect.StreamingHandlerFunc) connect.StreamingHandlerFunc {
return connect.StreamingHandlerFunc(func(ctx context.Context, conn connect.StreamingHandlerConn) error {
// check if authorization needs to be skipped
if authorizationSkipEndpoints[conn.Spec().Procedure] {
return next(ctx, conn)
}

// apply authorization rules
azFunc, azVerifier := authorizationValidationMap[conn.Spec().Procedure]
if !azVerifier {
// deny access if not configured by default
// return nil, connect.NewError(codes.Unauthenticated, "unauthorized access")
return connect.NewError(connect.CodePermissionDenied, v1beta1connect.ErrUnauthorized)
}
if err := azFunc(ctx, a.h, nil); err != nil {
return err
}

return next(ctx, conn)
})
}

func (a *AuthorizationInterceptor) WrapUnary(next connect.UnaryFunc) connect.UnaryFunc {
return connect.UnaryFunc(func(ctx context.Context, req connect.AnyRequest) (connect.AnyResponse, error) {
// check if authorization needs to be skipped
if authorizationSkipEndpoints[req.Spec().Procedure] {
return next(ctx, req)
})
}
return connect.UnaryInterceptorFunc(interceptor)
}

// apply authorization rules
azFunc, azVerifier := authorizationValidationMap[req.Spec().Procedure]
if !azVerifier {
// deny access if not configured by default
// return nil, connect.NewError(codes.Unauthenticated, "unauthorized access")
return nil, connect.NewError(connect.CodePermissionDenied, v1beta1connect.ErrUnauthorized)
}
if err := azFunc(ctx, a.h, req); err != nil {
return nil, err
}

return next(ctx, req)
})
}

func NewAuthorizationInterceptor(h *v1beta1connect.ConnectHandler) *AuthorizationInterceptor {
return &AuthorizationInterceptor{h}
}

// authorizationSkipEndpoints stores path to skip authorization, by default its enabled for all requests
Expand Down
112 changes: 107 additions & 5 deletions pkg/server/connect_interceptors/session.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,22 +10,123 @@ import (
"github.com/lestrrat-go/jwx/v2/jwt"
"github.com/raystack/frontier/core/authenticate"

"github.com/raystack/frontier/internal/api/v1beta1connect"
"github.com/raystack/frontier/pkg/server/consts"

"connectrpc.com/connect"
"github.com/gorilla/securecookie"
"google.golang.org/grpc/metadata"
)

type Session struct {
type SessionInterceptor struct {
// TODO(kushsharma): server should be able to rotate encryption keys of codec
// use secure cookie EncodeMulti/DecodeMulti
cookieCodec securecookie.Codec
conf authenticate.SessionConfig
h *v1beta1connect.ConnectHandler
}

func (s *SessionInterceptor) WrapStreamingClient(next connect.StreamingClientFunc) connect.StreamingClientFunc {
return connect.StreamingClientFunc(func(ctx context.Context, spec connect.Spec) connect.StreamingClientConn {
conn := next(ctx, spec)
return conn
})
}

func (s *SessionInterceptor) WrapStreamingHandler(next connect.StreamingHandlerFunc) connect.StreamingHandlerFunc {
return connect.StreamingHandlerFunc(func(ctx context.Context, conn connect.StreamingHandlerConn) error {
// parse and process cookies
incomingMD := metadata.MD{}
if s.cookieCodec != nil {
if mdCookies := conn.RequestHeader().Get("cookie"); len(mdCookies) > 0 {
header := http.Header{}
header.Add("Cookie", mdCookies)
request := http.Request{Header: header}
for _, requestCookie := range request.Cookies() {
// check if cookie is session cookie
if requestCookie.Name == consts.SessionRequestKey {
var sessionID string
// extract and decode session from cookie
if err := s.cookieCodec.Decode(requestCookie.Name, requestCookie.Value, &sessionID); err == nil {
// pass cookie in context
incomingMD.Set(consts.SessionIDGatewayKey, strings.TrimSpace(sessionID))
}
}
}
}
}

// pass user token if in token header as gateway context
if userToken := conn.RequestHeader().Values(consts.UserTokenRequestKey); len(userToken) > 0 {
incomingMD.Set(consts.UserTokenGatewayKey, strings.TrimSpace(userToken[0]))
}
// check if the same token is part of Authorization header
if authHeader := conn.RequestHeader().Values("authorization"); len(authHeader) > 0 {
tokenVal := strings.TrimSpace(strings.TrimPrefix(authHeader[0], "Bearer "))
if token, err := jwt.ParseInsecure([]byte(tokenVal)); err == nil {
if token.JwtID() != "" && token.Expiration().After(time.Now().UTC()) {
incomingMD.Set(consts.UserTokenGatewayKey, tokenVal)
}
}
secretVal := strings.TrimSpace(strings.TrimPrefix(authHeader[0], "Basic "))
if len(secretVal) > 0 {
incomingMD.Set(consts.UserSecretGatewayKey, secretVal)
}
}

ctx = metadata.NewIncomingContext(ctx, incomingMD)
return next(ctx, conn)
})
}

func (s *SessionInterceptor) WrapUnary(next connect.UnaryFunc) connect.UnaryFunc {
return connect.UnaryFunc(func(ctx context.Context, req connect.AnyRequest) (resp connect.AnyResponse, err error) {
// parse and process cookies
incomingMD := metadata.MD{}
if s.cookieCodec != nil {
if mdCookies := req.Header().Get("cookie"); len(mdCookies) > 0 {
header := http.Header{}
header.Add("Cookie", mdCookies)
request := http.Request{Header: header}
for _, requestCookie := range request.Cookies() {
// check if cookie is session cookie
if requestCookie.Name == consts.SessionRequestKey {
var sessionID string
// extract and decode session from cookie
if err = s.cookieCodec.Decode(requestCookie.Name, requestCookie.Value, &sessionID); err == nil {
// pass cookie in context
incomingMD.Set(consts.SessionIDGatewayKey, strings.TrimSpace(sessionID))
}
}
}
}
}

// pass user token if in token header as gateway context
if userToken := req.Header().Values(consts.UserTokenRequestKey); len(userToken) > 0 {
incomingMD.Set(consts.UserTokenGatewayKey, strings.TrimSpace(userToken[0]))
}
// check if the same token is part of Authorization header
if authHeader := req.Header().Values("authorization"); len(authHeader) > 0 {
tokenVal := strings.TrimSpace(strings.TrimPrefix(authHeader[0], "Bearer "))
if token, err := jwt.ParseInsecure([]byte(tokenVal)); err == nil {
if token.JwtID() != "" && token.Expiration().After(time.Now().UTC()) {
incomingMD.Set(consts.UserTokenGatewayKey, tokenVal)
}
}
secretVal := strings.TrimSpace(strings.TrimPrefix(authHeader[0], "Basic "))
if len(secretVal) > 0 {
incomingMD.Set(consts.UserSecretGatewayKey, secretVal)
}
}

ctx = metadata.NewIncomingContext(ctx, incomingMD)
return next(ctx, req)
})
}

// UnaryConnectResponseInterceptor adds session cookie to response if session id is present in header
func (s Session) UnaryConnectResponseInterceptor() connect.UnaryInterceptorFunc {
func (s *SessionInterceptor) UnaryConnectResponseInterceptor() connect.UnaryInterceptorFunc {
interceptor := func(next connect.UnaryFunc) connect.UnaryFunc {
return connect.UnaryFunc(func(ctx context.Context, req connect.AnyRequest) (connect.AnyResponse, error) {
// Let the handler and other interceptors run first
Expand Down Expand Up @@ -85,11 +186,12 @@ func (s Session) UnaryConnectResponseInterceptor() connect.UnaryInterceptorFunc
return connect.UnaryInterceptorFunc(interceptor)
}

func NewSession(cookieCutter securecookie.Codec, conf authenticate.SessionConfig) *Session {
return &Session{
func NewSessionInterceptor(cookieCutter securecookie.Codec, conf authenticate.SessionConfig, h *v1beta1connect.ConnectHandler) *SessionInterceptor {
return &SessionInterceptor{
// could be nil if not configured by user
cookieCodec: cookieCutter,
conf: conf,
h: h,
}
}

Expand All @@ -108,7 +210,7 @@ func CookieSameSite(name string) http.SameSite {

// UnaryConnectRequestHeadersAnnotator converts session cookies set in grpc metadata to context
// this requires decrypting the cookie and setting it as context
func (s Session) UnaryConnectRequestHeadersAnnotator() connect.UnaryInterceptorFunc {
func (s *SessionInterceptor) UnaryConnectRequestHeadersAnnotator() connect.UnaryInterceptorFunc {
interceptor := func(next connect.UnaryFunc) connect.UnaryFunc {
return connect.UnaryFunc(func(ctx context.Context, req connect.AnyRequest) (resp connect.AnyResponse, err error) {
// parse and process cookies
Expand Down
14 changes: 8 additions & 6 deletions pkg/server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -154,8 +154,6 @@ func ServeConnect(ctx context.Context, logger log.Logger, cfg Config, deps api.D
frontierService := v1beta1connect.NewConnectHandler(deps, cfg.Authentication)

sessionCookieCutter := getSessionCookieCutter(cfg.Authentication.Session.BlockSecretKey, cfg.Authentication.Session.HashSecretKey, logger)
sessionMiddleware := connectinterceptors.NewSession(sessionCookieCutter, cfg.Authentication.Session)

// grpcZapLogger := zap.Must(zap.NewProduction())
grpcZapLogger := zap.NewExample().Sugar()
loggerZap, ok := logger.(*log.Zap)
Expand Down Expand Up @@ -189,13 +187,17 @@ func ServeConnect(ctx context.Context, logger log.Logger, cfg Config, deps api.D
return err
}

authNInterceptor := connectinterceptors.NewAuthenticationInterceptor(frontierService)
authZInterceptor := connectinterceptors.NewAuthorizationInterceptor(frontierService)
sessionInterceptor := connectinterceptors.NewSessionInterceptor(sessionCookieCutter, cfg.Authentication.Session, frontierService)

interceptors := connect.WithInterceptors(
connectinterceptors.UnaryConnectLoggerInterceptor(grpcZapLogger.Desugar(), loggerOpts),
otelInterceptor,
sessionMiddleware.UnaryConnectRequestHeadersAnnotator(),
connectinterceptors.UnaryAuthenticationCheck(frontierService),
connectinterceptors.UnaryAuthorizationCheck(frontierService),
sessionMiddleware.UnaryConnectResponseInterceptor())
sessionInterceptor,
authNInterceptor,
authZInterceptor,
sessionInterceptor.UnaryConnectResponseInterceptor())

// Initialize connect handlers
frontierPath, frontierHandler := frontierv1beta1connect.NewFrontierServiceHandler(frontierService, interceptors)
Expand Down
Loading