diff --git a/README.md b/README.md index 6ed52d96..915b3616 100644 --- a/README.md +++ b/README.md @@ -118,7 +118,7 @@ cd tests/setup_local_db/ JSON summary data endpoints: ```bash curl http://localhost:8080/sources | python -m json.tool -curl http://localhost:8080/cohortdefinition-stats/by-source-id/1 | python -m json.tool +curl "http://localhost:8080/cohortdefinition-stats/by-source-id/1/by-team-project?team-project=test" | python -m json.tool curl http://localhost:8080/concept/by-source-id/1 | python -m json.tool curl -d '{"ConceptIds":[2000000324,2000006885]}' -H "Content-Type: application/json" -X POST http://localhost:8080/concept/by-source-id/1 | python -m json.tool curl -d '{"ConceptTypes":["Measurement","Person"]}' -H "Content-Type: application/json" -X POST http://localhost:8080/concept/by-source-id/1/by-type | python -m json.tool diff --git a/controllers/cohortdata.go b/controllers/cohortdata.go index b93ccce2..b0b7ae10 100644 --- a/controllers/cohortdata.go +++ b/controllers/cohortdata.go @@ -9,16 +9,21 @@ import ( "strconv" "github.com/gin-gonic/gin" + "github.com/uc-cdis/cohort-middleware/middlewares" "github.com/uc-cdis/cohort-middleware/models" "github.com/uc-cdis/cohort-middleware/utils" ) type CohortDataController struct { - cohortDataModel models.CohortDataI + cohortDataModel models.CohortDataI + teamProjectAuthz middlewares.TeamProjectAuthzI } -func NewCohortDataController(cohortDataModel models.CohortDataI) CohortDataController { - return CohortDataController{cohortDataModel: cohortDataModel} +func NewCohortDataController(cohortDataModel models.CohortDataI, teamProjectAuthz middlewares.TeamProjectAuthzI) CohortDataController { + return CohortDataController{ + cohortDataModel: cohortDataModel, + teamProjectAuthz: teamProjectAuthz, + } } func (u CohortDataController) RetrieveHistogramForCohortIdAndConceptId(c *gin.Context) { @@ -44,6 +49,14 @@ func (u CohortDataController) RetrieveHistogramForCohortIdAndConceptId(c *gin.Co cohortId, _ := strconv.Atoi(cohortIdStr) histogramConceptId, _ := strconv.ParseInt(histogramIdStr, 10, 64) + validAccessRequest := u.teamProjectAuthz.TeamProjectValidation(c, []int{cohortId}, cohortPairs) + if !validAccessRequest { + log.Printf("Error: invalid request") + c.JSON(http.StatusBadRequest, gin.H{"message": "access denied"}) + c.Abort() + return + } + cohortData, err := u.cohortDataModel.RetrieveHistogramDataBySourceIdAndCohortIdAndConceptIdsAndCohortPairs(sourceId, cohortId, histogramConceptId, filterConceptIds, cohortPairs) if err != nil { c.JSON(http.StatusInternalServerError, gin.H{"message": "Error retrieving concept details", "error": err.Error()}) @@ -85,6 +98,14 @@ func (u CohortDataController) RetrieveDataBySourceIdAndCohortIdAndVariables(c *g sourceId, _ := strconv.Atoi(sourceIdStr) cohortId, _ := strconv.Atoi(cohortIdStr) + validAccessRequest := u.teamProjectAuthz.TeamProjectValidation(c, []int{cohortId}, cohortPairs) + if !validAccessRequest { + log.Printf("Error: invalid request") + c.JSON(http.StatusBadRequest, gin.H{"message": "access denied"}) + c.Abort() + return + } + // call model method: cohortData, err := u.cohortDataModel.RetrieveDataBySourceIdAndCohortIdAndConceptIdsOrderedByPersonId(sourceId, cohortId, conceptIds) if err != nil { @@ -230,6 +251,14 @@ func (u CohortDataController) RetrieveCohortOverlapStatsWithoutFilteringOnConcep controlCohortId, errors[2] = utils.ParseNumericArg(c, "controlcohortid") conceptIds, cohortPairs, errors[3] = utils.ParseConceptIdsAndDichotomousDefs(c) + validAccessRequest := u.teamProjectAuthz.TeamProjectValidation(c, []int{caseCohortId, controlCohortId}, cohortPairs) + if !validAccessRequest { + log.Printf("Error: invalid request") + c.JSON(http.StatusBadRequest, gin.H{"message": "access denied"}) + c.Abort() + return + } + if utils.ContainsNonNil(errors) { c.JSON(http.StatusBadRequest, gin.H{"message": "bad request"}) c.Abort() diff --git a/controllers/cohortdefinition.go b/controllers/cohortdefinition.go index e1638381..4d9b338b 100644 --- a/controllers/cohortdefinition.go +++ b/controllers/cohortdefinition.go @@ -2,7 +2,6 @@ package controllers import ( "net/http" - "strconv" "github.com/gin-gonic/gin" "github.com/uc-cdis/cohort-middleware/models" @@ -17,56 +16,10 @@ func NewCohortDefinitionController(cohortDefinitionModel models.CohortDefinition return CohortDefinitionController{cohortDefinitionModel: cohortDefinitionModel} } -func (u CohortDefinitionController) RetriveById(c *gin.Context) { - cohortDefinitionId := c.Param("id") - - if cohortDefinitionId != "" { - cohortDefinitionId, _ := strconv.Atoi(cohortDefinitionId) - cohortDefinition, err := u.cohortDefinitionModel.GetCohortDefinitionById(cohortDefinitionId) - if err != nil { - c.JSON(http.StatusInternalServerError, gin.H{"message": "Error retrieving cohortDefinition", "error": err.Error()}) - c.Abort() - return - } - c.JSON(http.StatusOK, gin.H{"cohort_definition": cohortDefinition}) - return - } - c.JSON(http.StatusBadRequest, gin.H{"message": "bad request"}) - c.Abort() -} - -func (u CohortDefinitionController) RetriveByName(c *gin.Context) { - cohortDefinitionName := c.Param("name") - - if cohortDefinitionName != "" { - cohortDefinition, err := u.cohortDefinitionModel.GetCohortDefinitionByName(cohortDefinitionName) - if err != nil { - c.JSON(http.StatusInternalServerError, gin.H{"message": "Error retrieving cohortDefinition", "error": err.Error()}) - c.Abort() - return - } - c.JSON(http.StatusOK, gin.H{"CohortDefinition": cohortDefinition}) - return - } - c.JSON(http.StatusBadRequest, gin.H{"message": "bad request"}) - c.Abort() -} - -func (u CohortDefinitionController) RetriveAll(c *gin.Context) { - cohortDefinitions, err := u.cohortDefinitionModel.GetAllCohortDefinitions() - if err != nil { - c.JSON(http.StatusInternalServerError, gin.H{"message": "Error retrieving cohortDefinition", "error": err.Error()}) - c.Abort() - return - } - c.JSON(http.StatusOK, gin.H{"cohort_definitions": cohortDefinitions}) -} - func (u CohortDefinitionController) RetriveStatsBySourceIdAndTeamProject(c *gin.Context) { // This method returns ALL cohortdefinition entries with cohort size statistics (for a given source) - sourceId, err1 := utils.ParseNumericArg(c, "sourceid") - teamProject := c.Param("teamproject") + teamProject := c.Query("team-project") if teamProject == "" { c.JSON(http.StatusInternalServerError, gin.H{"message": "Error while parsing request", "error": "team-project is a mandatory parameter but was found to be empty!"}) c.Abort() diff --git a/controllers/concept.go b/controllers/concept.go index e8520a7d..78fd7d86 100644 --- a/controllers/concept.go +++ b/controllers/concept.go @@ -132,7 +132,7 @@ func (u ConceptController) RetrieveBreakdownStatsBySourceIdAndCohortIdAndVariabl c.Abort() return } - validAccessRequest := u.teamProjectAuthz.TeamProjectValidation(c, cohortId, cohortPairs) + validAccessRequest := u.teamProjectAuthz.TeamProjectValidation(c, []int{cohortId}, cohortPairs) if !validAccessRequest { log.Printf("Error: invalid request") c.JSON(http.StatusBadRequest, gin.H{"message": "access denied"}) @@ -198,7 +198,7 @@ func (u ConceptController) RetrieveAttritionTable(c *gin.Context) { return } _, cohortPairs := utils.GetConceptIdsAndCohortPairsAsSeparateLists(conceptIdsAndCohortPairs) - validAccessRequest := u.teamProjectAuthz.TeamProjectValidation(c, cohortId, cohortPairs) + validAccessRequest := u.teamProjectAuthz.TeamProjectValidation(c, []int{cohortId}, cohortPairs) if !validAccessRequest { log.Printf("Error: invalid request") c.JSON(http.StatusBadRequest, gin.H{"message": "access denied"}) diff --git a/middlewares/teamprojectauthz.go b/middlewares/teamprojectauthz.go index c7a4230e..3b044ffb 100644 --- a/middlewares/teamprojectauthz.go +++ b/middlewares/teamprojectauthz.go @@ -11,7 +11,8 @@ import ( type TeamProjectAuthzI interface { TeamProjectValidationForCohort(ctx *gin.Context, cohortDefinitionId int) bool - TeamProjectValidation(ctx *gin.Context, cohortDefinitionId int, filterCohortPairs []utils.CustomDichotomousVariableDef) bool + TeamProjectValidation(ctx *gin.Context, cohortDefinitionIds []int, filterCohortPairs []utils.CustomDichotomousVariableDef) bool + TeamProjectValidationForCohortIdsList(ctx *gin.Context, uniqueCohortDefinitionIdsList []int) bool } type HttpClientI interface { @@ -58,16 +59,20 @@ func (u TeamProjectAuthz) hasAccessToAtLeastOne(ctx *gin.Context, teamProjects [ func (u TeamProjectAuthz) TeamProjectValidationForCohort(ctx *gin.Context, cohortDefinitionId int) bool { filterCohortPairs := []utils.CustomDichotomousVariableDef{} - return u.TeamProjectValidation(ctx, cohortDefinitionId, filterCohortPairs) + return u.TeamProjectValidation(ctx, []int{cohortDefinitionId}, filterCohortPairs) +} + +func (u TeamProjectAuthz) TeamProjectValidation(ctx *gin.Context, cohortDefinitionIds []int, filterCohortPairs []utils.CustomDichotomousVariableDef) bool { + + uniqueCohortDefinitionIdsList := utils.GetUniqueCohortDefinitionIdsListFromRequest(cohortDefinitionIds, filterCohortPairs) + return u.TeamProjectValidationForCohortIdsList(ctx, uniqueCohortDefinitionIdsList) } // "team project" related checks: -// (1) check if the request contains any cohorts and if all cohorts belong to the same "team project" +// (1) check if all cohorts belong to the same "team project" // (2) check if the user has permission in the "team project" // Returns true if both checks above pass, false otherwise. -func (u TeamProjectAuthz) TeamProjectValidation(ctx *gin.Context, cohortDefinitionId int, filterCohortPairs []utils.CustomDichotomousVariableDef) bool { - - uniqueCohortDefinitionIdsList := utils.GetUniqueCohortDefinitionIdsListFromRequest(cohortDefinitionId, filterCohortPairs) +func (u TeamProjectAuthz) TeamProjectValidationForCohortIdsList(ctx *gin.Context, uniqueCohortDefinitionIdsList []int) bool { teamProjects, _ := u.cohortDefinitionModel.GetTeamProjectsThatMatchAllCohortDefinitionIds(uniqueCohortDefinitionIdsList) if len(teamProjects) == 0 { log.Printf("Invalid request error: could not find a 'team project' that is associated to ALL the cohorts present in this request") diff --git a/server/router.go b/server/router.go index 884118f5..e7252d66 100644 --- a/server/router.go +++ b/server/router.go @@ -29,10 +29,7 @@ func NewRouter() *gin.Engine { authorized.GET("/sources", source.RetriveAll) cohortdefinitions := controllers.NewCohortDefinitionController(*new(models.CohortDefinition)) - authorized.GET("/cohortdefinition/by-id/:id", cohortdefinitions.RetriveById) - authorized.GET("/cohortdefinition/by-name/:name", cohortdefinitions.RetriveByName) - authorized.GET("/cohortdefinitions", cohortdefinitions.RetriveAll) - authorized.GET("/cohortdefinition-stats/by-source-id/:sourceid/by-team-project/:teamproject", cohortdefinitions.RetriveStatsBySourceIdAndTeamProject) + authorized.GET("/cohortdefinition-stats/by-source-id/:sourceid/by-team-project", cohortdefinitions.RetriveStatsBySourceIdAndTeamProject) // concept endpoints: concepts := controllers.NewConceptController(*new(models.Concept), *new(models.CohortDefinition), @@ -46,7 +43,7 @@ func NewRouter() *gin.Engine { authorized.POST("/concept-stats/by-source-id/:sourceid/by-cohort-definition-id/:cohortid/breakdown-by-concept-id/:breakdownconceptid/csv", concepts.RetrieveAttritionTable) // cohort stats and checks: - cohortData := controllers.NewCohortDataController(*new(models.CohortData)) + cohortData := controllers.NewCohortDataController(*new(models.CohortData), middlewares.NewTeamProjectAuthz(*new(models.CohortDefinition), &http.Client{})) // :casecohortid/:controlcohortid are just labels here and have no special meaning. Could also just be :cohortAId/:cohortBId here: authorized.POST("/cohort-stats/check-overlap/by-source-id/:sourceid/by-cohort-definition-ids/:casecohortid/:controlcohortid", cohortData.RetrieveCohortOverlapStatsWithoutFilteringOnConceptValue) diff --git a/tests/controllers_tests/controllers_test.go b/tests/controllers_tests/controllers_test.go index f5f29a07..6dcae92a 100644 --- a/tests/controllers_tests/controllers_test.go +++ b/tests/controllers_tests/controllers_test.go @@ -5,6 +5,7 @@ import ( "io" "log" "net/http" + "net/url" "os" "reflect" "strconv" @@ -49,7 +50,8 @@ func tearDown() { log.Println("teardown for test") } -var cohortDataController = controllers.NewCohortDataController(*new(dummyCohortDataModel)) +var cohortDataController = controllers.NewCohortDataController(*new(dummyCohortDataModel), *new(dummyTeamProjectAuthz)) +var cohortDataControllerWithFailingTeamProjectAuthz = controllers.NewCohortDataController(*new(dummyCohortDataModel), *new(dummyFailingTeamProjectAuthz)) // instance of the controller that talks to the regular model implementation (that needs a real DB): var cohortDefinitionControllerNeedsDb = controllers.NewCohortDefinitionController(*new(models.CohortDefinition)) @@ -110,9 +112,9 @@ func (h dummyCohortDefinitionDataModel) GetCohortName(cohortId int) (string, err func (h dummyCohortDefinitionDataModel) GetAllCohortDefinitionsAndStatsOrderBySizeDesc(sourceId int, teamProject string) ([]*models.CohortDefinitionStats, error) { cohortDefinitionStats := []*models.CohortDefinitionStats{ - {Id: 1, CohortSize: 10, Name: "name1"}, - {Id: 2, CohortSize: 22, Name: "name2"}, - {Id: 3, CohortSize: 33, Name: "name3"}, + {Id: 1, CohortSize: 10, Name: "name1_" + teamProject}, // just concatenate teamProject here, so we can assert on it in a later test... teamprojects are otherwise not really part of cohort names + {Id: 2, CohortSize: 22, Name: "name2_" + teamProject}, + {Id: 3, CohortSize: 33, Name: "name3_" + teamProject}, } return cohortDefinitionStats, nil } @@ -141,7 +143,11 @@ func (h dummyTeamProjectAuthz) TeamProjectValidationForCohort(ctx *gin.Context, return true } -func (h dummyTeamProjectAuthz) TeamProjectValidation(ctx *gin.Context, cohortDefinitionId int, filterCohortPairs []utils.CustomDichotomousVariableDef) bool { +func (h dummyTeamProjectAuthz) TeamProjectValidation(ctx *gin.Context, cohortDefinitionIds []int, filterCohortPairs []utils.CustomDichotomousVariableDef) bool { + return true +} + +func (h dummyTeamProjectAuthz) TeamProjectValidationForCohortIdsList(ctx *gin.Context, uniqueCohortDefinitionIdsList []int) bool { return true } @@ -151,7 +157,11 @@ func (h dummyFailingTeamProjectAuthz) TeamProjectValidationForCohort(ctx *gin.Co return false } -func (h dummyFailingTeamProjectAuthz) TeamProjectValidation(ctx *gin.Context, cohortDefinitionId int, filterCohortPairs []utils.CustomDichotomousVariableDef) bool { +func (h dummyFailingTeamProjectAuthz) TeamProjectValidation(ctx *gin.Context, cohortDefinitionIds []int, filterCohortPairs []utils.CustomDichotomousVariableDef) bool { + return false +} + +func (h dummyFailingTeamProjectAuthz) TeamProjectValidationForCohortIdsList(ctx *gin.Context, uniqueCohortDefinitionIdsList []int) bool { return false } @@ -258,6 +268,18 @@ func TestRetrieveHistogramForCohortIdAndConceptIdWithCorrectParams(t *testing.T) if !strings.Contains(result.CustomResponseWriterOut, "bins") { t.Errorf("Expected output starting with 'bins,...'") } + + // the same request should fail if the teamProject authorization fails: + requestContext.Request.Body = io.NopCloser(strings.NewReader(requestBody)) + cohortDataControllerWithFailingTeamProjectAuthz.RetrieveHistogramForCohortIdAndConceptId(requestContext) + result = requestContext.Writer.(*tests.CustomResponseWriter) + // expect error: + if !strings.Contains(result.CustomResponseWriterOut, "access denied") { + t.Errorf("Expected 'access denied' as result") + } + if !requestContext.IsAborted() { + t.Errorf("Expected request to be aborted") + } } func TestRetrieveDataBySourceIdAndCohortIdAndVariablesWrongParams(t *testing.T) { @@ -290,6 +312,18 @@ func TestRetrieveDataBySourceIdAndCohortIdAndVariablesCorrectParams(t *testing.T if !strings.Contains(result.CustomResponseWriterOut, "sample.id,") { t.Errorf("Expected output starting with 'sample.id,...'") } + + // the same request should fail if the teamProject authorization fails: + requestContext.Request.Body = io.NopCloser(strings.NewReader(requestBody)) + cohortDataControllerWithFailingTeamProjectAuthz.RetrieveDataBySourceIdAndCohortIdAndVariables(requestContext) + result = requestContext.Writer.(*tests.CustomResponseWriter) + // expect error: + if !strings.Contains(result.CustomResponseWriterOut, "access denied") { + t.Errorf("Expected 'access denied' as result") + } + if !requestContext.IsAborted() { + t.Errorf("Expected request to be aborted") + } } func TestRetrieveCohortOverlapStatsWithoutFilteringOnConceptValue(t *testing.T) { @@ -312,6 +346,18 @@ func TestRetrieveCohortOverlapStatsWithoutFilteringOnConceptValue(t *testing.T) if !strings.Contains(result.CustomResponseWriterOut, "case_control_overlap") { t.Errorf("Expected output containing 'case_control_overlap...'") } + + // the same request should fail if the teamProject authorization fails: + requestContext.Request.Body = io.NopCloser(strings.NewReader(requestBody)) + cohortDataControllerWithFailingTeamProjectAuthz.RetrieveCohortOverlapStatsWithoutFilteringOnConceptValue(requestContext) + result = requestContext.Writer.(*tests.CustomResponseWriter) + // expect error: + if !strings.Contains(result.CustomResponseWriterOut, "access denied") { + t.Errorf("Expected 'access denied' as result") + } + if !requestContext.IsAborted() { + t.Errorf("Expected request to be aborted") + } } func TestRetrieveCohortOverlapStatsWithoutFilteringOnConceptValueBadRequest(t *testing.T) { @@ -385,7 +431,8 @@ func TestRetriveStatsBySourceIdAndTeamProjectDbPanic(t *testing.T) { setUp(t) requestContext := new(gin.Context) requestContext.Params = append(requestContext.Params, gin.Param{Key: "sourceid", Value: strconv.Itoa(tests.GetTestSourceId())}) - requestContext.Params = append(requestContext.Params, gin.Param{Key: "teamproject", Value: "dummy-team-project"}) + requestContext.Request = &http.Request{URL: &url.URL{}} + requestContext.Request.URL.RawQuery = "team-project=dummy-team-project" requestContext.Writer = new(tests.CustomResponseWriter) defer func() { @@ -420,58 +467,22 @@ func TestRetriveStatsBySourceIdAndTeamProject(t *testing.T) { setUp(t) requestContext := new(gin.Context) requestContext.Params = append(requestContext.Params, gin.Param{Key: "sourceid", Value: strconv.Itoa(tests.GetTestSourceId())}) - requestContext.Params = append(requestContext.Params, gin.Param{Key: "teamproject", Value: "dummy-team-project"}) + //requestContext.Params = append(requestContext.Params, gin.Param{Key: "teamproject", Value: "dummy-team-project"}) + requestContext.Request = &http.Request{URL: &url.URL{}} + teamProject := "/test/dummyname/dummy-team-project" + requestContext.Request.URL.RawQuery = "team-project=" + teamProject requestContext.Writer = new(tests.CustomResponseWriter) cohortDefinitionController.RetriveStatsBySourceIdAndTeamProject(requestContext) result := requestContext.Writer.(*tests.CustomResponseWriter) log.Printf("result: %s", result) // expect result with all of the dummy data: - if !strings.Contains(result.CustomResponseWriterOut, "name1") || - !strings.Contains(result.CustomResponseWriterOut, "name2") || - !strings.Contains(result.CustomResponseWriterOut, "name3") { + if !strings.Contains(result.CustomResponseWriterOut, "name1_"+teamProject) || + !strings.Contains(result.CustomResponseWriterOut, "name2_"+teamProject) || + !strings.Contains(result.CustomResponseWriterOut, "name3_"+teamProject) { t.Errorf("Expected 3 rows in result") } } -func TestRetriveByIdWrongParam(t *testing.T) { - setUp(t) - requestContext := new(gin.Context) - requestContext.Params = append(requestContext.Params, gin.Param{Key: "Abc", Value: "def"}) - requestContext.Writer = new(tests.CustomResponseWriter) - cohortDefinitionController.RetriveById(requestContext) - // Params above are wrong, so request should abort: - if !requestContext.IsAborted() { - t.Errorf("Expected aborted request") - } -} - -func TestRetriveById(t *testing.T) { - setUp(t) - requestContext := new(gin.Context) - requestContext.Params = append(requestContext.Params, gin.Param{Key: "id", Value: "1"}) - requestContext.Writer = new(tests.CustomResponseWriter) - cohortDefinitionController.RetriveById(requestContext) - result := requestContext.Writer.(*tests.CustomResponseWriter) - log.Printf("result: %s", result) - // expect result with dummy data: - if !strings.Contains(result.CustomResponseWriterOut, "test 1") { - t.Errorf("Expected data in result") - } -} - -func TestRetriveByIdModelError(t *testing.T) { - setUp(t) - requestContext := new(gin.Context) - requestContext.Params = append(requestContext.Params, gin.Param{Key: "id", Value: "1"}) - requestContext.Writer = new(tests.CustomResponseWriter) - // set flag to let mock model layer return error instead of mock data: - dummyModelReturnError = true - cohortDefinitionController.RetriveById(requestContext) - if !requestContext.IsAborted() { - t.Errorf("Expected aborted request") - } -} - func TestRetrieveBreakdownStatsBySourceIdAndCohortId(t *testing.T) { setUp(t) requestContext := new(gin.Context) diff --git a/tests/middlewares_tests/middlewares_test.go b/tests/middlewares_tests/middlewares_test.go index 427dd71e..883d596a 100644 --- a/tests/middlewares_tests/middlewares_test.go +++ b/tests/middlewares_tests/middlewares_test.go @@ -134,7 +134,7 @@ func TestTeamProjectValidation(t *testing.T) { requestContext.Request.Header = map[string][]string{ "Authorization": {"dummy_token_value"}, } - result := teamProjectAuthz.TeamProjectValidation(requestContext, 1, nil) + result := teamProjectAuthz.TeamProjectValidation(requestContext, []int{1}, nil) if result == false { t.Errorf("Expected TeamProjectValidation result to be 'true'") } @@ -155,7 +155,7 @@ func TestTeamProjectValidationArborist401(t *testing.T) { requestContext.Request.Header = map[string][]string{ "Authorization": {"dummy_token_value"}, } - result := teamProjectAuthz.TeamProjectValidation(requestContext, 1, nil) + result := teamProjectAuthz.TeamProjectValidation(requestContext, []int{1}, nil) if result == true { t.Errorf("Expected TeamProjectValidation result to be 'false'") } @@ -176,7 +176,7 @@ func TestTeamProjectValidationNoTeamProjectMatchingAllCohortDefinitions(t *testi requestContext.Request.Header = map[string][]string{ "Authorization": {"dummy_token_value"}, } - result := teamProjectAuthz.TeamProjectValidation(requestContext, 0, nil) + result := teamProjectAuthz.TeamProjectValidation(requestContext, []int{0}, nil) if result == true { t.Errorf("Expected TeamProjectValidation result to be 'false'") } diff --git a/tests/models_tests/models_test.go b/tests/models_tests/models_test.go index e8539d43..8a95ff2a 100644 --- a/tests/models_tests/models_test.go +++ b/tests/models_tests/models_test.go @@ -573,7 +573,7 @@ func TestGetTeamProjectsThatMatchAllCohortDefinitionIdsOnlyDefaultMatch(t *testi CohortDefinitionId2: largestCohort.Id, ProvidedName: "test"}, } - uniqueCohortDefinitionIdsList := utils.GetUniqueCohortDefinitionIdsListFromRequest(cohortDefinitionId, filterCohortPairs) + uniqueCohortDefinitionIdsList := utils.GetUniqueCohortDefinitionIdsListFromRequest([]int{cohortDefinitionId}, filterCohortPairs) teamProjects, _ := cohortDefinitionModel.GetTeamProjectsThatMatchAllCohortDefinitionIds(uniqueCohortDefinitionIdsList) if len(teamProjects) != 1 || teamProjects[0] != "defaultteamproject" { t.Errorf("Expected to find only defaultteamproject") @@ -589,7 +589,7 @@ func TestGetTeamProjectsThatMatchAllCohortDefinitionIds(t *testing.T) { CohortDefinitionId2: 2, ProvidedName: "test"}, } - uniqueCohortDefinitionIdsList := utils.GetUniqueCohortDefinitionIdsListFromRequest(cohortDefinitionId, filterCohortPairs) + uniqueCohortDefinitionIdsList := utils.GetUniqueCohortDefinitionIdsListFromRequest([]int{cohortDefinitionId}, filterCohortPairs) teamProjects, _ := cohortDefinitionModel.GetTeamProjectsThatMatchAllCohortDefinitionIds(uniqueCohortDefinitionIdsList) if len(teamProjects) != 2 { t.Errorf("Expected to find two 'team projects' matching the cohort list, found %s", teamProjects) diff --git a/utils/parsing.go b/utils/parsing.go index 693ce39c..c76c6f36 100644 --- a/utils/parsing.go +++ b/utils/parsing.go @@ -289,9 +289,9 @@ func MakeUnique(input []int) []int { return uniqueList } -func GetUniqueCohortDefinitionIdsListFromRequest(cohortDefinitionId int, filterCohortPairs []CustomDichotomousVariableDef) []int { +func GetUniqueCohortDefinitionIdsListFromRequest(cohortDefinitionIds []int, filterCohortPairs []CustomDichotomousVariableDef) []int { var idsList []int - idsList = append(idsList, cohortDefinitionId) + idsList = append(idsList, cohortDefinitionIds...) if len(filterCohortPairs) > 0 { for _, filterCohortPair := range filterCohortPairs { idsList = append(idsList, filterCohortPair.CohortDefinitionId1, filterCohortPair.CohortDefinitionId2)