Skip to content

Commit

Permalink
fix: enforce RLS on resource selector search
Browse files Browse the repository at this point in the history
  • Loading branch information
adityathebe authored and moshloop committed Jan 14, 2025
1 parent a23c74a commit e3df2bb
Show file tree
Hide file tree
Showing 3 changed files with 48 additions and 50 deletions.
51 changes: 3 additions & 48 deletions catalog/controllers.go
Original file line number Diff line number Diff line change
@@ -1,21 +1,16 @@
package catalog

import (
"encoding/json"
"fmt"
"net/http"

"github.com/flanksource/commons/logger"
"github.com/flanksource/duty/api"
"github.com/flanksource/duty/context"
"github.com/flanksource/duty/query"
"github.com/flanksource/incident-commander/auth"
echoSrv "github.com/flanksource/incident-commander/echo"
"github.com/flanksource/incident-commander/rbac"
"github.com/flanksource/incident-commander/rbac/policy"
"github.com/labstack/echo/v4"
"github.com/lib/pq"
"go.opentelemetry.io/otel/trace"
)

func init() {
Expand All @@ -26,11 +21,11 @@ func RegisterRoutes(e *echo.Echo) {
logger.Infof("Registering /catalog routes")

apiGroup := e.Group("/catalog", rbac.Catalog(policy.ActionRead))
apiGroup.POST("/summary", SearchConfigSummary, rlsMiddleware)
apiGroup.POST("/summary", SearchConfigSummary, echoSrv.RLSMiddleware)

apiGroup.POST("/changes", SearchCatalogChanges, rlsMiddleware)
apiGroup.POST("/changes", SearchCatalogChanges, echoSrv.RLSMiddleware)
// Deprecated. Use POST
apiGroup.GET("/changes", SearchCatalogChanges, rlsMiddleware)
apiGroup.GET("/changes", SearchCatalogChanges, echoSrv.RLSMiddleware)
}

func SearchCatalogChanges(c echo.Context) error {
Expand Down Expand Up @@ -64,43 +59,3 @@ func SearchConfigSummary(c echo.Context) error {

return c.JSON(http.StatusOK, response)
}

func rlsMiddleware(next echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) error {
ctx := c.Request().Context().(context.Context)

rlsPayload, err := auth.GetRLSPayload(ctx)
if err != nil {
return err
}

if rlsPayload.Disable {
return next(c)
}

rlsJSON, err := json.Marshal(rlsPayload)
if err != nil {
return err
}

err = ctx.Transaction(func(txCtx context.Context, _ trace.Span) error {
if err := txCtx.DB().Exec("SET LOCAL ROLE postgrest_api").Error; err != nil {
return err
}

// NOTE: SET statements in PostgreSQL do not support parameterized queries, so we must use fmt.Sprintf
// to inject the rlsJSON safely using pq.QuoteLiteral.
rlsSet := fmt.Sprintf(`SET LOCAL request.jwt.claims TO %s`, pq.QuoteLiteral(string(rlsJSON)))
if err := txCtx.DB().Exec(rlsSet).Error; err != nil {
return err
}

// set the context with the tx
c.SetRequest(c.Request().WithContext(txCtx))

return next(c)
})

return err
}
}
2 changes: 1 addition & 1 deletion echo/search.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ func SearchResources(c echov4.Context) error {

var request query.SearchResourcesRequest
if err := json.NewDecoder(c.Request().Body).Decode(&request); err != nil {
return api.WriteError(c, api.Errorf(api.EINVALID, err.Error()))
return api.WriteError(c, api.Errorf(api.EINVALID, "%s", err.Error()))
}

response, err := query.SearchResources(ctx, request)
Expand Down
45 changes: 44 additions & 1 deletion echo/serve.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package echo
import (
gocontext "context"
"crypto/tls"
"encoding/json"
"fmt"
"net"
"net/http"
Expand Down Expand Up @@ -32,9 +33,11 @@ import (
"github.com/labstack/echo-contrib/echoprometheus"
echov4 "github.com/labstack/echo/v4"
"github.com/labstack/echo/v4/middleware"
"github.com/lib/pq"
prom "github.com/prometheus/client_golang/prometheus"
"go.opentelemetry.io/contrib/instrumentation/github.com/labstack/echo/otelecho"
"go.opentelemetry.io/otel/attribute"
"go.opentelemetry.io/otel/trace"
)

const (
Expand Down Expand Up @@ -114,7 +117,7 @@ func New(ctx context.Context) *echov4.Echo {
Forward(ctx, e, "/kubeproxy", "https://kubernetes.default.svc", KubeProxyTokenMiddleware)

e.GET("/properties", dutyEcho.Properties)
e.POST("/resources/search", SearchResources, rbac.Authorization(policy.ObjectCatalog, policy.ActionRead))
e.POST("/resources/search", SearchResources, rbac.Authorization(policy.ObjectCatalog, policy.ActionRead), RLSMiddleware)

e.GET("/metrics", echoprometheus.NewHandlerWithConfig(echoprometheus.HandlerConfig{
Gatherer: prom.DefaultGatherer,
Expand Down Expand Up @@ -306,3 +309,43 @@ func Start(e *echov4.Echo, httpPort int) {
logger.Fatalf("Failed to start server: %v", err)
}
}

func RLSMiddleware(next echov4.HandlerFunc) echov4.HandlerFunc {
return func(c echov4.Context) error {
ctx := c.Request().Context().(context.Context)

rlsPayload, err := auth.GetRLSPayload(ctx)
if err != nil {
return err
}

if rlsPayload.Disable {
return next(c)
}

rlsJSON, err := json.Marshal(rlsPayload)
if err != nil {
return err
}

err = ctx.Transaction(func(txCtx context.Context, _ trace.Span) error {
if err := txCtx.DB().Exec("SET LOCAL ROLE postgrest_api").Error; err != nil {
return err
}

// NOTE: SET statements in PostgreSQL do not support parameterized queries, so we must use fmt.Sprintf
// to inject the rlsJSON safely using pq.QuoteLiteral.
rlsSet := fmt.Sprintf(`SET LOCAL request.jwt.claims TO %s`, pq.QuoteLiteral(string(rlsJSON)))
if err := txCtx.DB().Exec(rlsSet).Error; err != nil {
return err
}

// set the context with the tx
c.SetRequest(c.Request().WithContext(txCtx))

return next(c)
})

return err
}
}

0 comments on commit e3df2bb

Please sign in to comment.