diff --git a/api/handler/evaluation.go b/api/handler/evaluation.go index 6c607c95..8e0e6561 100644 --- a/api/handler/evaluation.go +++ b/api/handler/evaluation.go @@ -1,6 +1,7 @@ package handler import ( + "errors" "fmt" "log/slog" "strconv" @@ -8,6 +9,7 @@ import ( "github.com/gin-gonic/gin" "opencsg.com/csghub-server/api/httpbase" "opencsg.com/csghub-server/common/config" + "opencsg.com/csghub-server/common/errorx" "opencsg.com/csghub-server/common/types" "opencsg.com/csghub-server/component" ) @@ -95,8 +97,13 @@ func (h *EvaluationHandler) GetEvaluation(ctx *gin.Context) { evaluation, err := h.evaluation.GetEvaluation(ctx.Request.Context(), *req) if err != nil { slog.Error("Failed to get evaluation job", slog.Any("error", err)) - httpbase.ServerError(ctx, err) - return + if errors.Is(err, errorx.ErrForbidden) { + httpbase.ForbiddenError(ctx, err) + return + } else { + httpbase.ServerError(ctx, err) + return + } } httpbase.OK(ctx, evaluation) diff --git a/common/types/user.go b/common/types/user.go index 46c15fb6..e599edb7 100644 --- a/common/types/user.go +++ b/common/types/user.go @@ -190,6 +190,15 @@ type User struct { Tags []RepoTag `json:"tags,omitempty"` } +func (u User) IsAdmin() bool { + for _, role := range u.Roles { + if role == "admin" || role == "super_user" { + return true + } + } + return false +} + type UserLikesRequest struct { Username string `json:"username"` RepoID int64 `json:"repo_id"` diff --git a/component/evaluation.go b/component/evaluation.go index 0e9456d3..71398490 100644 --- a/component/evaluation.go +++ b/component/evaluation.go @@ -10,8 +10,10 @@ import ( "opencsg.com/csghub-server/builder/deploy" "opencsg.com/csghub-server/builder/deploy/common" + "opencsg.com/csghub-server/builder/rpc" "opencsg.com/csghub-server/builder/store/database" "opencsg.com/csghub-server/common/config" + "opencsg.com/csghub-server/common/errorx" "opencsg.com/csghub-server/common/types" ) @@ -29,6 +31,7 @@ type evaluationComponentImpl struct { config *config.Config accountingComponent AccountingComponent repoComponent RepoComponent + userSvcClient rpc.UserSvcClient } type EvaluationComponent interface { @@ -60,6 +63,10 @@ func NewEvaluationComponent(config *config.Config) (EvaluationComponent, error) return nil, fmt.Errorf("failed to create repo component, %w", err) } c.accountingComponent = ac + c.userSvcClient = rpc.NewUserSvcHttpClient( + fmt.Sprintf("%s:%d", config.User.Host, config.User.Port), + rpc.AuthWithApiKey(config.APIToken), + ) return c, nil } @@ -205,6 +212,15 @@ func (c *evaluationComponentImpl) GetEvaluation(ctx context.Context, req types.E if err != nil { return nil, fmt.Errorf("fail to get evaluation result, %w", err) } + if wf.Username != req.Username { + userInfo, err := c.userSvcClient.GetUserByName(ctx, req.Username) + if err != nil { + return nil, fmt.Errorf("failed to get user info for %s, %w", req.Username, err) + } + if !userInfo.IsAdmin() { + return nil, errorx.ErrForbidden + } + } var repoTags []types.RepoTags for _, path := range wf.Datasets { ds, err := c.datasetStore.FindByOriginPath(ctx, path) diff --git a/component/evaluation_ce_test.go b/component/evaluation_ce_test.go index 626b9093..41588d3e 100644 --- a/component/evaluation_ce_test.go +++ b/component/evaluation_ce_test.go @@ -4,12 +4,14 @@ package component import ( "context" + "database/sql" "encoding/json" "testing" "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" "opencsg.com/csghub-server/builder/store/database" + "opencsg.com/csghub-server/common/errorx" "opencsg.com/csghub-server/common/types" ) @@ -177,6 +179,138 @@ func TestEvaluationComponent_GetEvaluation(t *testing.T) { require.Nil(t, err) } +func TestEvaluationComponent_GetEvaluation_AccessControl(t *testing.T) { + t.Run("owner is req user", func(t *testing.T) { + ctx := context.TODO() + c := initializeTestEvaluationComponent(ctx, t) + c.config.Argo.QuotaGPUNumber = "1" + req := types.EvaluationGetReq{ + Username: "test", + ID: 1, + } + c.mocks.stores.WorkflowMock().EXPECT().FindByID(ctx, int64(1)).Return(database.ArgoWorkflow{ + ID: 1, + RepoIds: []string{"Rowan/hellaswag"}, + Datasets: []string{"deleted/dataset"}, + RepoType: "model", + Username: "test", + TaskName: "test", + TaskId: "test", + TaskType: "evaluation", + Status: "Succeed", + }, nil) + c.mocks.stores.DatasetMock().EXPECT().FindByOriginPath(ctx, "deleted/dataset").Return(nil, sql.ErrNoRows) + e, err := c.GetEvaluation(ctx, req) + require.NotNil(t, e) + require.Equal(t, "test", e.TaskName) + require.Nil(t, err) + require.Len(t, e.Datasets, 1) + require.True(t, e.Datasets[0].Deleted) + require.Equal(t, "deleted/dataset", e.Datasets[0].RepoId) + }) + t.Run("owner is different from req user and req user is not admin", func(t *testing.T) { + ctx := context.TODO() + c := initializeTestEvaluationComponent(ctx, t) + c.config.Argo.QuotaGPUNumber = "1" + req := types.EvaluationGetReq{ + Username: "otheruser", + ID: 1, + } + c.mocks.stores.WorkflowMock().EXPECT().FindByID(ctx, int64(1)).Return(database.ArgoWorkflow{ + ID: 1, + RepoIds: []string{"Rowan/hellaswag"}, + Datasets: []string{"Rowan/hellaswag"}, + RepoType: "model", + Username: "test", + TaskName: "test", + TaskId: "test", + TaskType: "evaluation", + Status: "Succeed", + }, nil) + c.mocks.userSvcClient.EXPECT().GetUserByName(ctx, "otheruser").Return(&types.User{ + Roles: []string{"user"}, + Username: "otheruser", + UUID: "otheruser", + ID: 2, + }, nil) + e, err := c.GetEvaluation(ctx, req) + require.Equal(t, err, errorx.ErrForbidden) + require.Nil(t, e) + }) + t.Run("owner is different from req user and req user is admin", func(t *testing.T) { + ctx := context.TODO() + c := initializeTestEvaluationComponent(ctx, t) + c.config.Argo.QuotaGPUNumber = "1" + req := types.EvaluationGetReq{ + Username: "otheruser", + ID: 1, + } + c.mocks.stores.WorkflowMock().EXPECT().FindByID(ctx, int64(1)).Return(database.ArgoWorkflow{ + ID: 1, + RepoIds: []string{"Rowan/hellaswag"}, + Datasets: []string{"Rowan/hellaswag"}, + RepoType: "model", + Username: "test", + TaskName: "test", + TaskId: "test", + TaskType: "evaluation", + Status: "Succeed", + }, nil) + c.mocks.userSvcClient.EXPECT().GetUserByName(ctx, "otheruser").Return(&types.User{ + Roles: []string{"admin"}, + Username: "otheruser", + UUID: "otheruser", + ID: 2, + }, nil) + c.mocks.stores.DatasetMock().EXPECT().FindByOriginPath(ctx, "Rowan/hellaswag").Return(&database.Dataset{ + Repository: &database.Repository{ + Path: "Rowan/hellaswag", + Tags: []database.Tag{ + { + Name: "test", + Category: "test", + Group: "test", + Scope: "test", + BuiltIn: true, + }, + }, + }, + }, nil) + e, err := c.GetEvaluation(ctx, req) + require.NotNil(t, e) + require.Equal(t, "test", e.TaskName) + require.Nil(t, err) + require.Len(t, e.Datasets, 1) + require.False(t, e.Datasets[0].Deleted) + require.Equal(t, "Rowan/hellaswag", e.Datasets[0].RepoId) + }) + t.Run("owner is different from req user and get user info fails", func(t *testing.T) { + ctx := context.TODO() + c := initializeTestEvaluationComponent(ctx, t) + c.config.Argo.QuotaGPUNumber = "1" + req := types.EvaluationGetReq{ + Username: "otheruser", + ID: 1, + } + c.mocks.stores.WorkflowMock().EXPECT().FindByID(ctx, int64(1)).Return(database.ArgoWorkflow{ + ID: 1, + RepoIds: []string{"Rowan/hellaswag"}, + Datasets: []string{"Rowan/hellaswag"}, + RepoType: "model", + Username: "test", + TaskName: "test", + TaskId: "test", + TaskType: "evaluation", + Status: "Succeed", + }, nil) + c.mocks.userSvcClient.EXPECT().GetUserByName(ctx, "otheruser").Return(nil, errorx.ErrRemoteServiceFail) + e, err := c.GetEvaluation(ctx, req) + require.Nil(t, e) + require.Error(t, err) + require.Contains(t, err.Error(), errorx.ErrRemoteServiceFail.Error()) + }) +} + func TestEvaluationComponent_DeleteEvaluation(t *testing.T) { ctx := context.TODO() c := initializeTestEvaluationComponent(ctx, t) diff --git a/component/wire_gen_test.go b/component/wire_gen_test.go index dfb45ce8..77bad107 100644 --- a/component/wire_gen_test.go +++ b/component/wire_gen_test.go @@ -1458,7 +1458,15 @@ func initializeTestEvaluationComponent(ctx context.Context, t interface { mockDeployer := deploy.NewMockDeployer(t) mockAccountingComponent := component.NewMockAccountingComponent(t) mockRepoComponent := component.NewMockRepoComponent(t) - componentEvaluationComponentImpl := NewTestEvaluationComponent(config, mockStores, mockDeployer, mockAccountingComponent, mockRepoComponent) + mockUserSvcClient := rpc.NewMockUserSvcClient(t) + componentEvaluationComponentImpl := NewTestEvaluationComponent( + config, + mockStores, + mockDeployer, + mockAccountingComponent, + mockRepoComponent, + mockUserSvcClient, + ) mockTagComponent := component.NewMockTagComponent(t) mockSpaceComponent := component.NewMockSpaceComponent(t) mockRuntimeArchitectureComponent := component.NewMockRuntimeArchitectureComponent(t) @@ -1472,7 +1480,6 @@ func initializeTestEvaluationComponent(ctx context.Context, t interface { sensitive: mockSensitiveComponent, } mockGitServer := gitserver.NewMockGitServer(t) - mockUserSvcClient := rpc.NewMockUserSvcClient(t) mockClient := s3.NewMockClient(t) mockMirrorServer := mirrorserver.NewMockMirrorServer(t) mockCache := cache.NewMockCache(t) diff --git a/component/wireset.go b/component/wireset.go index 29390646..36d9b5ab 100644 --- a/component/wireset.go +++ b/component/wireset.go @@ -551,7 +551,14 @@ func NewTestClusterComponent(config *config.Config, deployer deploy.Deployer, st var ClusterComponentSet = wire.NewSet(NewTestClusterComponent) -func NewTestEvaluationComponent(config *config.Config, stores *tests.MockStores, deployer deploy.Deployer, accountingComponent AccountingComponent, repoComponent RepoComponent) *evaluationComponentImpl { +func NewTestEvaluationComponent( + config *config.Config, + stores *tests.MockStores, + deployer deploy.Deployer, + accountingComponent AccountingComponent, + repoComponent RepoComponent, + userSvcClient rpc.UserSvcClient, +) *evaluationComponentImpl { return &evaluationComponentImpl{ deployer: deployer, userStore: stores.User, @@ -566,6 +573,7 @@ func NewTestEvaluationComponent(config *config.Config, stores *tests.MockStores, config: config, accountingComponent: accountingComponent, repoComponent: repoComponent, + userSvcClient: userSvcClient, } }