From 8c3ee9552a909e5cb7b859aa9c9396e18ebd4649 Mon Sep 17 00:00:00 2001 From: pieterlukasse Date: Wed, 6 Dec 2023 19:11:49 +0100 Subject: [PATCH 1/7] feat: integrate extra 'team project' validation for concept endpoints --- controllers/concept.go | 11 +++-- middlewares/auth.go | 78 +++++++++++++++++++++++++++++-- models/cohortdefinition.go | 17 +++++++ models/concept.go | 3 +- tests/models_tests/models_test.go | 32 +++++++++++++ utils/parsing.go | 25 ++++++++++ 6 files changed, 156 insertions(+), 10 deletions(-) diff --git a/controllers/concept.go b/controllers/concept.go index 3f61c4d5..6daa6551 100644 --- a/controllers/concept.go +++ b/controllers/concept.go @@ -10,6 +10,7 @@ import ( "strconv" "github.com/gin-gonic/gin" + "github.com/uc-cdis/cohort-middleware/middlewares" "github.com/uc-cdis/cohort-middleware/models" "github.com/uc-cdis/cohort-middleware/utils" ) @@ -87,7 +88,8 @@ func (u ConceptController) RetrieveInfoBySourceIdAndConceptTypes(c *gin.Context) func (u ConceptController) RetrieveBreakdownStatsBySourceIdAndCohortId(c *gin.Context) { sourceId, cohortId, err := utils.ParseSourceAndCohortId(c) - if err != nil { + validRequest := middlewares.TeamProjectValidationForCohort(c, cohortId) + if err != nil || !validRequest { log.Printf("Error: %s", err.Error()) c.JSON(http.StatusBadRequest, gin.H{"message": "bad request", "error": err.Error()}) c.Abort() @@ -112,7 +114,8 @@ func (u ConceptController) RetrieveBreakdownStatsBySourceIdAndCohortId(c *gin.Co func (u ConceptController) RetrieveBreakdownStatsBySourceIdAndCohortIdAndVariables(c *gin.Context) { sourceId, cohortId, conceptIds, cohortPairs, err := utils.ParseSourceIdAndCohortIdAndVariablesList(c) - if err != nil { + validRequest := middlewares.TeamProjectValidation(c, cohortId, cohortPairs) + if err != nil || !validRequest { log.Printf("Error: %s", err.Error()) c.JSON(http.StatusBadRequest, gin.H{"message": "bad request", "error": err.Error()}) c.Abort() @@ -169,7 +172,9 @@ func generateRowForVariable(variableName string, breakdownConceptValuesToPeopleC func (u ConceptController) RetrieveAttritionTable(c *gin.Context) { sourceId, cohortId, conceptIdsAndCohortPairs, err := utils.ParseSourceIdAndCohortIdAndVariablesAsSingleList(c) - if err != nil { + _, cohortPairs := utils.GetConceptIdsAndCohortPairsAsSeparateLists(conceptIdsAndCohortPairs) + validRequest := middlewares.TeamProjectValidation(c, cohortId, cohortPairs) + if err != nil || !validRequest { log.Printf("Error: %s", err.Error()) c.JSON(http.StatusBadRequest, gin.H{"message": "bad request", "error": err.Error()}) c.Abort() diff --git a/middlewares/auth.go b/middlewares/auth.go index 96d09129..1787706c 100644 --- a/middlewares/auth.go +++ b/middlewares/auth.go @@ -8,6 +8,8 @@ import ( "github.com/gin-gonic/gin" "github.com/uc-cdis/cohort-middleware/config" + "github.com/uc-cdis/cohort-middleware/models" + "github.com/uc-cdis/cohort-middleware/utils" ) func AuthMiddleware() gin.HandlerFunc { @@ -22,10 +24,11 @@ func AuthMiddleware() gin.HandlerFunc { } return func(ctx *gin.Context) { - req, err := PrepareNewArboristRequest(ctx, c.GetString("arborist_endpoint")) + req, err := PrepareNewArboristRequest(ctx) if err != nil { ctx.AbortWithStatus(500) log.Printf("Error while preparing Arborist request: %s", err.Error()) + return } client := &http.Client{} // send the request to Arborist: @@ -44,9 +47,22 @@ func AuthMiddleware() gin.HandlerFunc { } // this function will take the request from the given ctx, validated it for the presence of an "Authorization / Bearer" token -// and then return the URL that can be used to consult Arborist regarding access permissions. This function +// and then return the URL that can be used to consult Arborist regarding cohort-middleware access permissions. This function // returns an error if "Authorization / Bearer" token is missing in ctx -func PrepareNewArboristRequest(ctx *gin.Context, arboristEndpoint string) (*http.Request, error) { +func PrepareNewArboristRequest(ctx *gin.Context) (*http.Request, error) { + + resourcePath := fmt.Sprintf("/cohort-middleware%s", ctx.Request.URL.Path) + service := "cohort-middleware" + + return PrepareNewArboristRequestForResourceAndService(ctx, resourcePath, service) +} + +// this function will take the request from the given ctx, validated it for the presence of an "Authorization / Bearer" token +// and then return the URL that can be used to consult Arborist regarding access permissions for the given +// resource path and service. +func PrepareNewArboristRequestForResourceAndService(ctx *gin.Context, resourcePath string, service string) (*http.Request, error) { + c := config.GetConfig() + arboristEndpoint := c.GetString("arborist_endpoint") // validate: authorization := ctx.Request.Header.Get("Authorization") if authorization == "" { @@ -54,11 +70,10 @@ func PrepareNewArboristRequest(ctx *gin.Context, arboristEndpoint string) (*http } // build up the request URL string: - resourcePath := fmt.Sprintf("/cohort-middleware%s", ctx.Request.URL.Path) arboristAuth := fmt.Sprintf("%s/auth/proxy?resource=%s&service=%s&method=%s", arboristEndpoint, resourcePath, - "cohort-middleware", + service, "access") // make request object / validate URL: @@ -71,3 +86,56 @@ func PrepareNewArboristRequest(ctx *gin.Context, arboristEndpoint string) (*http req.Header.Set("Authorization", authorization) return req, nil } + +func HasAccessToAtLeastOne(ctx *gin.Context, teamProjects []string) bool { + + // query Arborist and return as soon as one of the teamProjects access check returns 200: + for _, teamProject := range teamProjects { + teamProjectAsResourcePath := teamProject + teamProjectAccessService := "atlas-argo-wrapper-and-cohort-middleware" + + req, err := PrepareNewArboristRequestForResourceAndService(ctx, teamProjectAsResourcePath, teamProjectAccessService) + if err != nil { + ctx.AbortWithStatus(500) + panic("Error while preparing Arborist request") + } + client := &http.Client{} + // send the request to Arborist: + resp, _ := client.Do(req) + + // arborist will return with 200 if the user has been granted access to the cohort-middleware URL in ctx: + if resp.StatusCode == 200 { + return true + } else { + // unauthorized or otherwise: + log.Printf("Got response status %d from Arborist...", resp.StatusCode) + } + } + return false +} + +func TeamProjectValidationForCohort(ctx *gin.Context, cohortDefinitionId int) bool { + filterCohortPairs := []utils.CustomDichotomousVariableDef{} + return TeamProjectValidation(ctx, cohortDefinitionId, filterCohortPairs) +} + +// "team project" related checks: +// (1) check if the request contains any cohorts and if all cohorts belong to the same "team project" +// (2) check if the user has permission in the "team project" +// Returns true if both checks above pass, false otherwise. +func TeamProjectValidation(ctx *gin.Context, cohortDefinitionId int, filterCohortPairs []utils.CustomDichotomousVariableDef) bool { + + uniqueCohortDefinitionIdsList := utils.GetUniqueCohortDefinitionIdsListFromRequest(cohortDefinitionId, filterCohortPairs) + cohortDefinitionModel := new(models.CohortDefinition) + teamProjects, _ := cohortDefinitionModel.GetTeamProjectsThatMatchAllCohortDefinitionIds(uniqueCohortDefinitionIdsList) + if len(teamProjects) == 0 { + log.Printf("Invalid request error: could not find a 'team project' that is associated to ALL the cohorts present in this request") + return false + } + if !HasAccessToAtLeastOne(ctx, teamProjects) { + log.Printf("Invalid request error: user does not have access to any of the 'team projects' associated with the cohorts in this request") + return false + } + // passed both tests: + return true +} diff --git a/models/cohortdefinition.go b/models/cohortdefinition.go index 9893f070..b905ce13 100644 --- a/models/cohortdefinition.go +++ b/models/cohortdefinition.go @@ -67,6 +67,23 @@ func (h CohortDefinition) GetAllCohortDefinitions() ([]*CohortDefinition, error) return cohortDefinition, meta_result.Error } +// Returns any "team project" entries that are matched to _each and every one_ of the +// cohort definition ids found in uniqueCohortDefinitionIdsList. +func (h CohortDefinition) GetTeamProjectsThatMatchAllCohortDefinitionIds(uniqueCohortDefinitionIdsList []int) ([]string, error) { + + db2 := db.GetAtlasDB().Db + var teamProjects []string + // Find any roles that are paired to each and every one of the cohort_definition_id values. + // Roles that ony match part of the values are filtered out by the having(count) clause: + query := db2.Table(db.GetAtlasDB().Schema+".cohort_definition_sec_role"). + Select("sec_role_name"). + Where("cohort_definition_id in (?)", uniqueCohortDefinitionIdsList). + Group("sec_role_name"). + Having("count(DISTINCT cohort_definition_id) = ?", len(uniqueCohortDefinitionIdsList)). + Scan(&teamProjects) + return teamProjects, query.Error +} + // Get the list of cohort_definition ids for a given "team project" (where "team project" is basically // a security role name of one of the roles in Atlas/WebAPI database). func (h CohortDefinition) GetCohortDefinitionIdsForTeamProject(teamProject string) ([]int, error) { diff --git a/models/concept.go b/models/concept.go index f8801f50..ddd999d2 100644 --- a/models/concept.go +++ b/models/concept.go @@ -138,6 +138,7 @@ func (h Concept) RetrieveBreakdownStatsBySourceIdAndCohortId(sourceId int, cohor // {ConceptValue: "B", NPersonsInCohortWithValue: N-M-X}, // where X is the number of persons that have NO value or just a "null" value for one or more of the ids in the given filterConceptIds. func (h Concept) RetrieveBreakdownStatsBySourceIdAndCohortIdAndConceptIdsAndCohortPairs(sourceId int, cohortDefinitionId int, filterConceptIds []int64, filterCohortPairs []utils.CustomDichotomousVariableDef, breakdownConceptId int64) ([]*ConceptBreakdown, error) { + var dataSourceModel = new(Source) omopDataSource := dataSourceModel.GetDataSource(sourceId, Omop) resultsDataSource := dataSourceModel.GetDataSource(sourceId, Results) @@ -150,8 +151,6 @@ func (h Concept) RetrieveBreakdownStatsBySourceIdAndCohortIdAndConceptIdsAndCoho Where("observation.observation_concept_id = ?", breakdownConceptId). Where(GetConceptValueNotNullCheckBasedOnConceptType("observation", sourceId, breakdownConceptId)) - // note: here we pass empty []utils.CustomDichotomousVariableDef{} instead of filterCohortPairs, since we already use the SQL generated by QueryFilterByCohortPairsHelper above, - // which is a better performing SQL in this particular scenario: query = QueryFilterByConceptIdsHelper(query, sourceId, filterConceptIds, omopDataSource, resultsDataSource.Schema, "observation") query, cancel := utils.AddTimeoutToQuery(query) diff --git a/tests/models_tests/models_test.go b/tests/models_tests/models_test.go index 9e3b2804..5a9506d3 100644 --- a/tests/models_tests/models_test.go +++ b/tests/models_tests/models_test.go @@ -564,6 +564,38 @@ func TestRetrieveBreakdownStatsBySourceIdAndCohortIdWithResultsWithOnePersonTwoH } } +func TestGetTeamProjectsThatMatchAllCohortDefinitionIdsOnlyDefaultMatch(t *testing.T) { + setUp(t) + cohortDefinitionId := 2 + filterCohortPairs := []utils.CustomDichotomousVariableDef{ + { + CohortId1: smallestCohort.Id, + CohortId2: largestCohort.Id, + ProvidedName: "test"}, + } + uniqueCohortDefinitionIdsList := utils.GetUniqueCohortDefinitionIdsListFromRequest(cohortDefinitionId, filterCohortPairs) + teamProjects, _ := cohortDefinitionModel.GetTeamProjectsThatMatchAllCohortDefinitionIds(uniqueCohortDefinitionIdsList) + if len(teamProjects) != 1 || teamProjects[0] != "defaultteamproject" { + t.Errorf("Expected to find only defaultteamproject") + } +} + +func TestGetTeamProjectsThatMatchAllCohortDefinitionIds(t *testing.T) { + setUp(t) + cohortDefinitionId := 2 + filterCohortPairs := []utils.CustomDichotomousVariableDef{ + { + CohortId1: 2, + CohortId2: 2, + ProvidedName: "test"}, + } + uniqueCohortDefinitionIdsList := utils.GetUniqueCohortDefinitionIdsListFromRequest(cohortDefinitionId, filterCohortPairs) + teamProjects, _ := cohortDefinitionModel.GetTeamProjectsThatMatchAllCohortDefinitionIds(uniqueCohortDefinitionIdsList) + if len(teamProjects) != 2 { + t.Errorf("Expected to find two 'team projects' matching the cohort list, found %s", teamProjects) + } +} + func TestGetCohortDefinitionIdsForTeamProject(t *testing.T) { setUp(t) testTeamProject := "teamprojectX" diff --git a/utils/parsing.go b/utils/parsing.go index e29c8c67..bebcde40 100644 --- a/utils/parsing.go +++ b/utils/parsing.go @@ -275,3 +275,28 @@ func ParseSourceIdAndCohortIdAndVariablesAsSingleList(c *gin.Context) (int, int, } return sourceId, cohortId, conceptIdsAndCohortPairs, nil } + +func MakeUnique(input []int) []int { + uniqueMap := make(map[int]bool) + var uniqueList []int + + for _, num := range input { + if !uniqueMap[num] { + uniqueMap[num] = true + uniqueList = append(uniqueList, num) + } + } + return uniqueList +} + +func GetUniqueCohortDefinitionIdsListFromRequest(cohortDefinitionId int, filterCohortPairs []CustomDichotomousVariableDef) []int { + var idsList []int + idsList = append(idsList, cohortDefinitionId) + if len(filterCohortPairs) > 0 { + for _, filterCohortPair := range filterCohortPairs { + idsList = append(idsList, filterCohortPair.CohortId1, filterCohortPair.CohortId2) // TODO - rename CohortId1/2 here to cohortDefinition1/2 + } + } + uniqueIdsList := MakeUnique(idsList) + return uniqueIdsList +} From fb06d28dc23fa61ca55a699e630aa28ac15c81ef Mon Sep 17 00:00:00 2001 From: pieterlukasse Date: Fri, 8 Dec 2023 18:29:30 +0100 Subject: [PATCH 2/7] fix: fix PrepareNewArboristRequest tests --- config/mocktest.yaml | 1 + tests/middlewares_tests/middlewares_test.go | 11 +++++++---- 2 files changed, 8 insertions(+), 4 deletions(-) create mode 100644 config/mocktest.yaml diff --git a/config/mocktest.yaml b/config/mocktest.yaml new file mode 100644 index 00000000..792f2776 --- /dev/null +++ b/config/mocktest.yaml @@ -0,0 +1 @@ +arborist_endpoint: 'https://arboristdummyurl' diff --git a/tests/middlewares_tests/middlewares_test.go b/tests/middlewares_tests/middlewares_test.go index 1a6e6390..3b6aa2c1 100644 --- a/tests/middlewares_tests/middlewares_test.go +++ b/tests/middlewares_tests/middlewares_test.go @@ -8,6 +8,7 @@ import ( "testing" "github.com/gin-gonic/gin" + "github.com/uc-cdis/cohort-middleware/config" "github.com/uc-cdis/cohort-middleware/middlewares" "github.com/uc-cdis/cohort-middleware/tests" ) @@ -42,6 +43,7 @@ func tearDown() { func TestPrepareNewArboristRequest(t *testing.T) { setUp(t) + config.Init("mocktest") requestContext := new(gin.Context) requestContext.Params = append(requestContext.Params, gin.Param{Key: "Authorization", Value: "dummy_token_value"}) requestContext.Writer = new(tests.CustomResponseWriter) @@ -51,8 +53,7 @@ func TestPrepareNewArboristRequest(t *testing.T) { } u, _ := url.Parse("https://some-cohort-middl-server/api/abc/123") requestContext.Request.URL = u - arboristEndpoint := "https://arboristdummyurl" - resultArboristRequest, error := middlewares.PrepareNewArboristRequest(requestContext, arboristEndpoint) + resultArboristRequest, error := middlewares.PrepareNewArboristRequest(requestContext) expectedResult := "resource=/cohort-middleware/api/abc/123&service=cohort-middleware&method=access" // check if expected result URL was produced: @@ -63,12 +64,14 @@ func TestPrepareNewArboristRequest(t *testing.T) { func TestPrepareNewArboristRequestMissingToken(t *testing.T) { setUp(t) + config.Init("mocktest") requestContext := new(gin.Context) requestContext.Params = append(requestContext.Params, gin.Param{Key: "Abc", Value: "def"}) requestContext.Writer = new(tests.CustomResponseWriter) requestContext.Request = new(http.Request) - arboristEndpoint := "https://arboristdummyurl" - _, error := middlewares.PrepareNewArboristRequest(requestContext, arboristEndpoint) + u, _ := url.Parse("https://some-cohort-middl-server/api/abc/123") + requestContext.Request.URL = u + _, error := middlewares.PrepareNewArboristRequest(requestContext) // Params above are wrong, so request should abort: if error.Error() != "missing Authorization header" { From 317e13221fdb580aa583c26c8f0756a3361ad8f1 Mon Sep 17 00:00:00 2001 From: pieterlukasse Date: Fri, 8 Dec 2023 18:34:20 +0100 Subject: [PATCH 3/7] feat: move "team project" checks to TeamProjectAuthzI ...and add both real and mock implementations. This makes this part easier to test. --- controllers/concept.go | 15 +++-- middlewares/auth.go | 55 --------------- middlewares/teamprojectauthz.go | 74 +++++++++++++++++++++ models/cohortdefinition.go | 2 + server/router.go | 2 +- tests/controllers_tests/controllers_test.go | 20 +++++- 6 files changed, 106 insertions(+), 62 deletions(-) create mode 100644 middlewares/teamprojectauthz.go diff --git a/controllers/concept.go b/controllers/concept.go index 6daa6551..7d97568d 100644 --- a/controllers/concept.go +++ b/controllers/concept.go @@ -18,10 +18,15 @@ import ( type ConceptController struct { conceptModel models.ConceptI cohortDefinitionModel models.CohortDefinitionI + teamProjectAuthz middlewares.TeamProjectAuthzI } -func NewConceptController(conceptModel models.ConceptI, cohortDefinitionModel models.CohortDefinitionI) ConceptController { - return ConceptController{conceptModel: conceptModel, cohortDefinitionModel: cohortDefinitionModel} +func NewConceptController(conceptModel models.ConceptI, cohortDefinitionModel models.CohortDefinitionI, teamProjectAuthz middlewares.TeamProjectAuthzI) ConceptController { + return ConceptController{ + conceptModel: conceptModel, + cohortDefinitionModel: cohortDefinitionModel, + teamProjectAuthz: teamProjectAuthz, + } } func (u ConceptController) RetriveAllBySourceId(c *gin.Context) { @@ -88,7 +93,7 @@ func (u ConceptController) RetrieveInfoBySourceIdAndConceptTypes(c *gin.Context) func (u ConceptController) RetrieveBreakdownStatsBySourceIdAndCohortId(c *gin.Context) { sourceId, cohortId, err := utils.ParseSourceAndCohortId(c) - validRequest := middlewares.TeamProjectValidationForCohort(c, cohortId) + validRequest := u.teamProjectAuthz.TeamProjectValidationForCohort(c, cohortId) if err != nil || !validRequest { log.Printf("Error: %s", err.Error()) c.JSON(http.StatusBadRequest, gin.H{"message": "bad request", "error": err.Error()}) @@ -114,7 +119,7 @@ func (u ConceptController) RetrieveBreakdownStatsBySourceIdAndCohortId(c *gin.Co func (u ConceptController) RetrieveBreakdownStatsBySourceIdAndCohortIdAndVariables(c *gin.Context) { sourceId, cohortId, conceptIds, cohortPairs, err := utils.ParseSourceIdAndCohortIdAndVariablesList(c) - validRequest := middlewares.TeamProjectValidation(c, cohortId, cohortPairs) + validRequest := u.teamProjectAuthz.TeamProjectValidation(c, cohortId, cohortPairs) if err != nil || !validRequest { log.Printf("Error: %s", err.Error()) c.JSON(http.StatusBadRequest, gin.H{"message": "bad request", "error": err.Error()}) @@ -173,7 +178,7 @@ func generateRowForVariable(variableName string, breakdownConceptValuesToPeopleC func (u ConceptController) RetrieveAttritionTable(c *gin.Context) { sourceId, cohortId, conceptIdsAndCohortPairs, err := utils.ParseSourceIdAndCohortIdAndVariablesAsSingleList(c) _, cohortPairs := utils.GetConceptIdsAndCohortPairsAsSeparateLists(conceptIdsAndCohortPairs) - validRequest := middlewares.TeamProjectValidation(c, cohortId, cohortPairs) + validRequest := u.teamProjectAuthz.TeamProjectValidation(c, cohortId, cohortPairs) if err != nil || !validRequest { log.Printf("Error: %s", err.Error()) c.JSON(http.StatusBadRequest, gin.H{"message": "bad request", "error": err.Error()}) diff --git a/middlewares/auth.go b/middlewares/auth.go index 1787706c..9c756be6 100644 --- a/middlewares/auth.go +++ b/middlewares/auth.go @@ -8,8 +8,6 @@ import ( "github.com/gin-gonic/gin" "github.com/uc-cdis/cohort-middleware/config" - "github.com/uc-cdis/cohort-middleware/models" - "github.com/uc-cdis/cohort-middleware/utils" ) func AuthMiddleware() gin.HandlerFunc { @@ -86,56 +84,3 @@ func PrepareNewArboristRequestForResourceAndService(ctx *gin.Context, resourcePa req.Header.Set("Authorization", authorization) return req, nil } - -func HasAccessToAtLeastOne(ctx *gin.Context, teamProjects []string) bool { - - // query Arborist and return as soon as one of the teamProjects access check returns 200: - for _, teamProject := range teamProjects { - teamProjectAsResourcePath := teamProject - teamProjectAccessService := "atlas-argo-wrapper-and-cohort-middleware" - - req, err := PrepareNewArboristRequestForResourceAndService(ctx, teamProjectAsResourcePath, teamProjectAccessService) - if err != nil { - ctx.AbortWithStatus(500) - panic("Error while preparing Arborist request") - } - client := &http.Client{} - // send the request to Arborist: - resp, _ := client.Do(req) - - // arborist will return with 200 if the user has been granted access to the cohort-middleware URL in ctx: - if resp.StatusCode == 200 { - return true - } else { - // unauthorized or otherwise: - log.Printf("Got response status %d from Arborist...", resp.StatusCode) - } - } - return false -} - -func TeamProjectValidationForCohort(ctx *gin.Context, cohortDefinitionId int) bool { - filterCohortPairs := []utils.CustomDichotomousVariableDef{} - return TeamProjectValidation(ctx, cohortDefinitionId, filterCohortPairs) -} - -// "team project" related checks: -// (1) check if the request contains any cohorts and if all cohorts belong to the same "team project" -// (2) check if the user has permission in the "team project" -// Returns true if both checks above pass, false otherwise. -func TeamProjectValidation(ctx *gin.Context, cohortDefinitionId int, filterCohortPairs []utils.CustomDichotomousVariableDef) bool { - - uniqueCohortDefinitionIdsList := utils.GetUniqueCohortDefinitionIdsListFromRequest(cohortDefinitionId, filterCohortPairs) - cohortDefinitionModel := new(models.CohortDefinition) - teamProjects, _ := cohortDefinitionModel.GetTeamProjectsThatMatchAllCohortDefinitionIds(uniqueCohortDefinitionIdsList) - if len(teamProjects) == 0 { - log.Printf("Invalid request error: could not find a 'team project' that is associated to ALL the cohorts present in this request") - return false - } - if !HasAccessToAtLeastOne(ctx, teamProjects) { - log.Printf("Invalid request error: user does not have access to any of the 'team projects' associated with the cohorts in this request") - return false - } - // passed both tests: - return true -} diff --git a/middlewares/teamprojectauthz.go b/middlewares/teamprojectauthz.go new file mode 100644 index 00000000..002decf4 --- /dev/null +++ b/middlewares/teamprojectauthz.go @@ -0,0 +1,74 @@ +package middlewares + +import ( + "log" + "net/http" + + "github.com/gin-gonic/gin" + "github.com/uc-cdis/cohort-middleware/models" + "github.com/uc-cdis/cohort-middleware/utils" +) + +type TeamProjectAuthzI interface { + TeamProjectValidationForCohort(ctx *gin.Context, cohortDefinitionId int) bool + TeamProjectValidation(ctx *gin.Context, cohortDefinitionId int, filterCohortPairs []utils.CustomDichotomousVariableDef) bool +} + +type TeamProjectAuthz struct { + cohortDefinitionModel models.CohortDefinitionI +} + +func NewTeamProjectAuthz(cohortDefinitionModel models.CohortDefinitionI) TeamProjectAuthz { + return TeamProjectAuthz{cohortDefinitionModel: cohortDefinitionModel} +} +func hasAccessToAtLeastOne(ctx *gin.Context, teamProjects []string) bool { + + // query Arborist and return as soon as one of the teamProjects access check returns 200: + for _, teamProject := range teamProjects { + teamProjectAsResourcePath := teamProject + teamProjectAccessService := "atlas-argo-wrapper-and-cohort-middleware" + + req, err := PrepareNewArboristRequestForResourceAndService(ctx, teamProjectAsResourcePath, teamProjectAccessService) + if err != nil { + ctx.AbortWithStatus(500) + panic("Error while preparing Arborist request") + } + client := &http.Client{} + // send the request to Arborist: + resp, _ := client.Do(req) + + // arborist will return with 200 if the user has been granted access to the cohort-middleware URL in ctx: + if resp.StatusCode == 200 { + return true + } else { + // unauthorized or otherwise: + log.Printf("Got response status %d from Arborist...", resp.StatusCode) + } + } + return false +} + +func (u TeamProjectAuthz) TeamProjectValidationForCohort(ctx *gin.Context, cohortDefinitionId int) bool { + filterCohortPairs := []utils.CustomDichotomousVariableDef{} + return u.TeamProjectValidation(ctx, cohortDefinitionId, filterCohortPairs) +} + +// "team project" related checks: +// (1) check if the request contains any cohorts and if all cohorts belong to the same "team project" +// (2) check if the user has permission in the "team project" +// Returns true if both checks above pass, false otherwise. +func (u TeamProjectAuthz) TeamProjectValidation(ctx *gin.Context, cohortDefinitionId int, filterCohortPairs []utils.CustomDichotomousVariableDef) bool { + + uniqueCohortDefinitionIdsList := utils.GetUniqueCohortDefinitionIdsListFromRequest(cohortDefinitionId, filterCohortPairs) + teamProjects, _ := u.cohortDefinitionModel.GetTeamProjectsThatMatchAllCohortDefinitionIds(uniqueCohortDefinitionIdsList) + if len(teamProjects) == 0 { + log.Printf("Invalid request error: could not find a 'team project' that is associated to ALL the cohorts present in this request") + return false + } + if !hasAccessToAtLeastOne(ctx, teamProjects) { + log.Printf("Invalid request error: user does not have access to any of the 'team projects' associated with the cohorts in this request") + return false + } + // passed both tests: + return true +} diff --git a/models/cohortdefinition.go b/models/cohortdefinition.go index b905ce13..7013a96c 100644 --- a/models/cohortdefinition.go +++ b/models/cohortdefinition.go @@ -15,6 +15,8 @@ type CohortDefinitionI interface { GetAllCohortDefinitions() ([]*CohortDefinition, error) GetAllCohortDefinitionsAndStatsOrderBySizeDesc(sourceId int, teamProject string) ([]*CohortDefinitionStats, error) GetCohortName(cohortId int) (string, error) + GetCohortDefinitionIdsForTeamProject(teamProject string) ([]int, error) + GetTeamProjectsThatMatchAllCohortDefinitionIds(uniqueCohortDefinitionIdsList []int) ([]string, error) } type CohortDefinition struct { diff --git a/server/router.go b/server/router.go index 1357cbf9..4d7281cf 100644 --- a/server/router.go +++ b/server/router.go @@ -33,7 +33,7 @@ func NewRouter() *gin.Engine { authorized.GET("/cohortdefinition-stats/by-source-id/:sourceid/by-team-project/:teamproject", cohortdefinitions.RetriveStatsBySourceIdAndTeamProject) // concept endpoints: - concepts := controllers.NewConceptController(*new(models.Concept), *new(models.CohortDefinition)) + concepts := controllers.NewConceptController(*new(models.Concept), *new(models.CohortDefinition), *new(middlewares.TeamProjectAuthz)) authorized.GET("/concept/by-source-id/:sourceid", concepts.RetriveAllBySourceId) authorized.POST("/concept/by-source-id/:sourceid", concepts.RetrieveInfoBySourceIdAndConceptIds) authorized.POST("/concept/by-source-id/:sourceid/by-type", concepts.RetrieveInfoBySourceIdAndConceptTypes) diff --git a/tests/controllers_tests/controllers_test.go b/tests/controllers_tests/controllers_test.go index 6a3f6ffe..ef59cd49 100644 --- a/tests/controllers_tests/controllers_test.go +++ b/tests/controllers_tests/controllers_test.go @@ -96,6 +96,14 @@ type dummyCohortDefinitionDataModel struct{} var dummyModelReturnError bool = false +func (h dummyCohortDefinitionDataModel) GetCohortDefinitionIdsForTeamProject(teamProject string) ([]int, error) { + return []int{1}, nil +} + +func (h dummyCohortDefinitionDataModel) GetTeamProjectsThatMatchAllCohortDefinitionIds(uniqueCohortDefinitionIdsList []int) ([]string, error) { + return []string{"test"}, nil +} + func (h dummyCohortDefinitionDataModel) GetCohortName(cohortId int) (string, error) { return "dummy cohort name", nil } @@ -127,7 +135,17 @@ func (h dummyCohortDefinitionDataModel) GetAllCohortDefinitions() ([]*models.Coh return nil, nil } -var conceptController = controllers.NewConceptController(*new(dummyConceptDataModel), *new(dummyCohortDefinitionDataModel)) +type dummyTeamProjectAuthz struct{} + +func (h dummyTeamProjectAuthz) TeamProjectValidationForCohort(ctx *gin.Context, cohortDefinitionId int) bool { + return true +} + +func (h dummyTeamProjectAuthz) TeamProjectValidation(ctx *gin.Context, cohortDefinitionId int, filterCohortPairs []utils.CustomDichotomousVariableDef) bool { + return true +} + +var conceptController = controllers.NewConceptController(*new(dummyConceptDataModel), *new(dummyCohortDefinitionDataModel), *new(dummyTeamProjectAuthz)) type dummyConceptDataModel struct{} From 522afa0afbbf73672272e3fe6bee36992cf212bc Mon Sep 17 00:00:00 2001 From: pieterlukasse Date: Fri, 8 Dec 2023 20:16:43 +0100 Subject: [PATCH 4/7] feat: add tests for TeamProjectAuthz --- middlewares/teamprojectauthz.go | 22 ++-- server/router.go | 5 +- tests/middlewares_tests/middlewares_test.go | 106 ++++++++++++++++++++ 3 files changed, 125 insertions(+), 8 deletions(-) diff --git a/middlewares/teamprojectauthz.go b/middlewares/teamprojectauthz.go index 002decf4..90a2e0d8 100644 --- a/middlewares/teamprojectauthz.go +++ b/middlewares/teamprojectauthz.go @@ -14,14 +14,22 @@ type TeamProjectAuthzI interface { TeamProjectValidation(ctx *gin.Context, cohortDefinitionId int, filterCohortPairs []utils.CustomDichotomousVariableDef) bool } +type HttpClientI interface { + Do(req *http.Request) (*http.Response, error) +} + type TeamProjectAuthz struct { cohortDefinitionModel models.CohortDefinitionI + httpClient HttpClientI } -func NewTeamProjectAuthz(cohortDefinitionModel models.CohortDefinitionI) TeamProjectAuthz { - return TeamProjectAuthz{cohortDefinitionModel: cohortDefinitionModel} +func NewTeamProjectAuthz(cohortDefinitionModel models.CohortDefinitionI, httpClient HttpClientI) TeamProjectAuthz { + return TeamProjectAuthz{ + cohortDefinitionModel: cohortDefinitionModel, + httpClient: httpClient, + } } -func hasAccessToAtLeastOne(ctx *gin.Context, teamProjects []string) bool { +func (u TeamProjectAuthz) HasAccessToAtLeastOne(ctx *gin.Context, teamProjects []string) bool { // query Arborist and return as soon as one of the teamProjects access check returns 200: for _, teamProject := range teamProjects { @@ -33,16 +41,16 @@ func hasAccessToAtLeastOne(ctx *gin.Context, teamProjects []string) bool { ctx.AbortWithStatus(500) panic("Error while preparing Arborist request") } - client := &http.Client{} // send the request to Arborist: - resp, _ := client.Do(req) + resp, _ := u.httpClient.Do(req) + log.Printf("Got response status %d from Arborist...", resp.StatusCode) // arborist will return with 200 if the user has been granted access to the cohort-middleware URL in ctx: if resp.StatusCode == 200 { return true } else { // unauthorized or otherwise: - log.Printf("Got response status %d from Arborist...", resp.StatusCode) + log.Printf("Status %d does NOT give access to team project...", resp.StatusCode) } } return false @@ -65,7 +73,7 @@ func (u TeamProjectAuthz) TeamProjectValidation(ctx *gin.Context, cohortDefiniti log.Printf("Invalid request error: could not find a 'team project' that is associated to ALL the cohorts present in this request") return false } - if !hasAccessToAtLeastOne(ctx, teamProjects) { + if !u.HasAccessToAtLeastOne(ctx, teamProjects) { log.Printf("Invalid request error: user does not have access to any of the 'team projects' associated with the cohorts in this request") return false } diff --git a/server/router.go b/server/router.go index 4d7281cf..884118f5 100644 --- a/server/router.go +++ b/server/router.go @@ -1,6 +1,8 @@ package server import ( + "net/http" + "github.com/gin-gonic/gin" "github.com/uc-cdis/cohort-middleware/controllers" "github.com/uc-cdis/cohort-middleware/middlewares" @@ -33,7 +35,8 @@ func NewRouter() *gin.Engine { authorized.GET("/cohortdefinition-stats/by-source-id/:sourceid/by-team-project/:teamproject", cohortdefinitions.RetriveStatsBySourceIdAndTeamProject) // concept endpoints: - concepts := controllers.NewConceptController(*new(models.Concept), *new(models.CohortDefinition), *new(middlewares.TeamProjectAuthz)) + concepts := controllers.NewConceptController(*new(models.Concept), *new(models.CohortDefinition), + middlewares.NewTeamProjectAuthz(*new(models.CohortDefinition), &http.Client{})) authorized.GET("/concept/by-source-id/:sourceid", concepts.RetriveAllBySourceId) authorized.POST("/concept/by-source-id/:sourceid", concepts.RetrieveInfoBySourceIdAndConceptIds) authorized.POST("/concept/by-source-id/:sourceid/by-type", concepts.RetrieveInfoBySourceIdAndConceptTypes) diff --git a/tests/middlewares_tests/middlewares_test.go b/tests/middlewares_tests/middlewares_test.go index 3b6aa2c1..427dd71e 100644 --- a/tests/middlewares_tests/middlewares_test.go +++ b/tests/middlewares_tests/middlewares_test.go @@ -10,6 +10,7 @@ import ( "github.com/gin-gonic/gin" "github.com/uc-cdis/cohort-middleware/config" "github.com/uc-cdis/cohort-middleware/middlewares" + "github.com/uc-cdis/cohort-middleware/models" "github.com/uc-cdis/cohort-middleware/tests" ) @@ -78,3 +79,108 @@ func TestPrepareNewArboristRequestMissingToken(t *testing.T) { t.Errorf("Expected error") } } + +type dummyHttpClient struct { + statusCode int + nrCalls int +} + +func (h *dummyHttpClient) Do(req *http.Request) (*http.Response, error) { + h.nrCalls++ + return &http.Response{StatusCode: h.statusCode}, nil +} + +type dummyCohortDefinitionDataModel struct{} + +func (h dummyCohortDefinitionDataModel) GetCohortDefinitionIdsForTeamProject(teamProject string) ([]int, error) { + return nil, nil +} + +func (h dummyCohortDefinitionDataModel) GetTeamProjectsThatMatchAllCohortDefinitionIds(uniqueCohortDefinitionIdsList []int) ([]string, error) { + // dummy switch just to support two test scenarios: + if uniqueCohortDefinitionIdsList[0] == 0 { + return nil, nil + } else { + return []string{"teamProject1", "teamProject2"}, nil + } +} + +func (h dummyCohortDefinitionDataModel) GetCohortName(cohortId int) (string, error) { + return "dummy cohort name", nil +} + +func (h dummyCohortDefinitionDataModel) GetAllCohortDefinitionsAndStatsOrderBySizeDesc(sourceId int, teamProject string) ([]*models.CohortDefinitionStats, error) { + return nil, nil +} +func (h dummyCohortDefinitionDataModel) GetCohortDefinitionById(id int) (*models.CohortDefinition, error) { + return nil, nil +} +func (h dummyCohortDefinitionDataModel) GetCohortDefinitionByName(name string) (*models.CohortDefinition, error) { + return nil, nil +} +func (h dummyCohortDefinitionDataModel) GetAllCohortDefinitions() ([]*models.CohortDefinition, error) { + return nil, nil +} + +func TestTeamProjectValidation(t *testing.T) { + setUp(t) + config.Init("mocktest") + arboristAuthzResponseCode := 200 + dummyHttpClient := &dummyHttpClient{statusCode: arboristAuthzResponseCode} + teamProjectAuthz := middlewares.NewTeamProjectAuthz(*new(dummyCohortDefinitionDataModel), + dummyHttpClient) + requestContext := new(gin.Context) + requestContext.Request = new(http.Request) + requestContext.Request.Header = map[string][]string{ + "Authorization": {"dummy_token_value"}, + } + result := teamProjectAuthz.TeamProjectValidation(requestContext, 1, nil) + if result == false { + t.Errorf("Expected TeamProjectValidation result to be 'true'") + } + if dummyHttpClient.nrCalls != 1 { + t.Errorf("Expected dummyHttpClient to have been only once") + } +} + +func TestTeamProjectValidationArborist401(t *testing.T) { + setUp(t) + config.Init("mocktest") + arboristAuthzResponseCode := 401 + dummyHttpClient := &dummyHttpClient{statusCode: arboristAuthzResponseCode} + teamProjectAuthz := middlewares.NewTeamProjectAuthz(*new(dummyCohortDefinitionDataModel), + dummyHttpClient) + requestContext := new(gin.Context) + requestContext.Request = new(http.Request) + requestContext.Request.Header = map[string][]string{ + "Authorization": {"dummy_token_value"}, + } + result := teamProjectAuthz.TeamProjectValidation(requestContext, 1, nil) + if result == true { + t.Errorf("Expected TeamProjectValidation result to be 'false'") + } + if dummyHttpClient.nrCalls <= 1 { + t.Errorf("Expected dummyHttpClient to have been called more than once") + } +} + +func TestTeamProjectValidationNoTeamProjectMatchingAllCohortDefinitions(t *testing.T) { + setUp(t) + config.Init("mocktest") + arboristAuthzResponseCode := 200 + dummyHttpClient := &dummyHttpClient{statusCode: arboristAuthzResponseCode} + teamProjectAuthz := middlewares.NewTeamProjectAuthz(*new(dummyCohortDefinitionDataModel), + dummyHttpClient) + requestContext := new(gin.Context) + requestContext.Request = new(http.Request) + requestContext.Request.Header = map[string][]string{ + "Authorization": {"dummy_token_value"}, + } + result := teamProjectAuthz.TeamProjectValidation(requestContext, 0, nil) + if result == true { + t.Errorf("Expected TeamProjectValidation result to be 'false'") + } + if dummyHttpClient.nrCalls > 0 { + t.Errorf("Expected dummyHttpClient to NOT have been called") + } +} From e76c6bc5e2d6a77cd9b332518abd377cbd3fc706 Mon Sep 17 00:00:00 2001 From: pieterlukasse Date: Mon, 11 Dec 2023 16:38:11 +0100 Subject: [PATCH 5/7] fix: fix the usage of TeamProjectValidation ...and add respective tests --- controllers/concept.go | 33 +++++++++-- tests/controllers_tests/controllers_test.go | 63 +++++++++++++++++++++ 2 files changed, 90 insertions(+), 6 deletions(-) diff --git a/controllers/concept.go b/controllers/concept.go index 7d97568d..f3bf9072 100644 --- a/controllers/concept.go +++ b/controllers/concept.go @@ -93,13 +93,20 @@ func (u ConceptController) RetrieveInfoBySourceIdAndConceptTypes(c *gin.Context) func (u ConceptController) RetrieveBreakdownStatsBySourceIdAndCohortId(c *gin.Context) { sourceId, cohortId, err := utils.ParseSourceAndCohortId(c) - validRequest := u.teamProjectAuthz.TeamProjectValidationForCohort(c, cohortId) - if err != nil || !validRequest { + if err != nil { log.Printf("Error: %s", err.Error()) c.JSON(http.StatusBadRequest, gin.H{"message": "bad request", "error": err.Error()}) c.Abort() return } + validAccessRequest := u.teamProjectAuthz.TeamProjectValidationForCohort(c, cohortId) + if !validAccessRequest { + log.Printf("Error: invalid request") + c.JSON(http.StatusBadRequest, gin.H{"message": "access denied"}) + c.Abort() + return + } + breakdownConceptId, err := utils.ParseBigNumericArg(c, "breakdownconceptid") if err != nil { log.Printf("Error: %s", err.Error()) @@ -119,13 +126,20 @@ func (u ConceptController) RetrieveBreakdownStatsBySourceIdAndCohortId(c *gin.Co func (u ConceptController) RetrieveBreakdownStatsBySourceIdAndCohortIdAndVariables(c *gin.Context) { sourceId, cohortId, conceptIds, cohortPairs, err := utils.ParseSourceIdAndCohortIdAndVariablesList(c) - validRequest := u.teamProjectAuthz.TeamProjectValidation(c, cohortId, cohortPairs) - if err != nil || !validRequest { + if err != nil { log.Printf("Error: %s", err.Error()) c.JSON(http.StatusBadRequest, gin.H{"message": "bad request", "error": err.Error()}) c.Abort() return } + validAccessRequest := u.teamProjectAuthz.TeamProjectValidation(c, cohortId, cohortPairs) + if !validAccessRequest { + log.Printf("Error: invalid request") + c.JSON(http.StatusBadRequest, gin.H{"message": "access denied"}) + c.Abort() + return + } + breakdownConceptId, err := utils.ParseBigNumericArg(c, "breakdownconceptid") if err != nil { log.Printf("Error: %s", err.Error()) @@ -178,13 +192,20 @@ func generateRowForVariable(variableName string, breakdownConceptValuesToPeopleC func (u ConceptController) RetrieveAttritionTable(c *gin.Context) { sourceId, cohortId, conceptIdsAndCohortPairs, err := utils.ParseSourceIdAndCohortIdAndVariablesAsSingleList(c) _, cohortPairs := utils.GetConceptIdsAndCohortPairsAsSeparateLists(conceptIdsAndCohortPairs) - validRequest := u.teamProjectAuthz.TeamProjectValidation(c, cohortId, cohortPairs) - if err != nil || !validRequest { + if err != nil { log.Printf("Error: %s", err.Error()) c.JSON(http.StatusBadRequest, gin.H{"message": "bad request", "error": err.Error()}) c.Abort() return } + validAccessRequest := u.teamProjectAuthz.TeamProjectValidation(c, cohortId, cohortPairs) + if !validAccessRequest { + log.Printf("Error: invalid request") + c.JSON(http.StatusBadRequest, gin.H{"message": "access denied"}) + c.Abort() + return + } + breakdownConceptId, err := utils.ParseBigNumericArg(c, "breakdownconceptid") if err != nil { log.Printf("Error: %s", err.Error()) diff --git a/tests/controllers_tests/controllers_test.go b/tests/controllers_tests/controllers_test.go index ef59cd49..5b9b4988 100644 --- a/tests/controllers_tests/controllers_test.go +++ b/tests/controllers_tests/controllers_test.go @@ -145,7 +145,18 @@ func (h dummyTeamProjectAuthz) TeamProjectValidation(ctx *gin.Context, cohortDef return true } +type dummyFailingTeamProjectAuthz struct{} + +func (h dummyFailingTeamProjectAuthz) TeamProjectValidationForCohort(ctx *gin.Context, cohortDefinitionId int) bool { + return false +} + +func (h dummyFailingTeamProjectAuthz) TeamProjectValidation(ctx *gin.Context, cohortDefinitionId int, filterCohortPairs []utils.CustomDichotomousVariableDef) bool { + return false +} + var conceptController = controllers.NewConceptController(*new(dummyConceptDataModel), *new(dummyCohortDefinitionDataModel), *new(dummyTeamProjectAuthz)) +var conceptControllerWithFailingTeamProjectAuthz = controllers.NewConceptController(*new(dummyConceptDataModel), *new(dummyCohortDefinitionDataModel), *new(dummyFailingTeamProjectAuthz)) type dummyConceptDataModel struct{} @@ -461,6 +472,34 @@ func TestRetriveByIdModelError(t *testing.T) { } } +func TestRetrieveBreakdownStatsBySourceIdAndCohortId(t *testing.T) { + setUp(t) + requestContext := new(gin.Context) + requestContext.Params = append(requestContext.Params, gin.Param{Key: "sourceid", Value: "1"}) + requestContext.Params = append(requestContext.Params, gin.Param{Key: "cohortid", Value: "1"}) + requestContext.Params = append(requestContext.Params, gin.Param{Key: "breakdownconceptid", Value: "1"}) + + requestContext.Writer = new(tests.CustomResponseWriter) + conceptController.RetrieveBreakdownStatsBySourceIdAndCohortId(requestContext) + result := requestContext.Writer.(*tests.CustomResponseWriter) + log.Printf("result: %s", result) + // expect result with dummy data: + if !strings.Contains(result.CustomResponseWriterOut, "persons_in_cohort_with_value") { + t.Errorf("Expected data in result") + } + + // the same request should fail if the teamProject authorization fails: + conceptControllerWithFailingTeamProjectAuthz.RetrieveBreakdownStatsBySourceIdAndCohortId(requestContext) + result = requestContext.Writer.(*tests.CustomResponseWriter) + // expect error: + if !strings.Contains(result.CustomResponseWriterOut, "access denied") { + t.Errorf("Expected 'access denied' as result") + } + if !requestContext.IsAborted() { + t.Errorf("Expected request to be aborted") + } +} + func TestRetrieveBreakdownStatsBySourceIdAndCohortIdAndVariables(t *testing.T) { setUp(t) requestContext := new(gin.Context) @@ -479,6 +518,18 @@ func TestRetrieveBreakdownStatsBySourceIdAndCohortIdAndVariables(t *testing.T) { if !strings.Contains(result.CustomResponseWriterOut, "persons_in_cohort_with_value") { t.Errorf("Expected data in result") } + + // the same request should fail if the teamProject authorization fails: + requestContext.Request.Body = io.NopCloser(strings.NewReader(requestBody)) + conceptControllerWithFailingTeamProjectAuthz.RetrieveBreakdownStatsBySourceIdAndCohortIdAndVariables(requestContext) + result = requestContext.Writer.(*tests.CustomResponseWriter) + // expect error: + if !strings.Contains(result.CustomResponseWriterOut, "access denied") { + t.Errorf("Expected 'access denied' as result") + } + if !requestContext.IsAborted() { + t.Errorf("Expected request to be aborted") + } } func TestRetrieveBreakdownStatsBySourceIdAndCohortIdAndVariablesModelError(t *testing.T) { @@ -899,4 +950,16 @@ func TestRetrieveAttritionTable(t *testing.T) { } i++ } + + // the same request should fail if the teamProject authorization fails: + requestContext.Request.Body = io.NopCloser(strings.NewReader(requestBody)) + conceptControllerWithFailingTeamProjectAuthz.RetrieveAttritionTable(requestContext) + result = requestContext.Writer.(*tests.CustomResponseWriter) + // expect error: + if !strings.Contains(result.CustomResponseWriterOut, "access denied") { + t.Errorf("Expected 'access denied' as result") + } + if !requestContext.IsAborted() { + t.Errorf("Expected request to be aborted") + } } From cc06ba8c7ee781a7d6507ebe3101a3d988ffb4d9 Mon Sep 17 00:00:00 2001 From: pieterlukasse Date: Mon, 11 Dec 2023 16:57:30 +0100 Subject: [PATCH 6/7] fix: fix minor code issues --- controllers/concept.go | 2 +- middlewares/teamprojectauthz.go | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/controllers/concept.go b/controllers/concept.go index f3bf9072..e8520a7d 100644 --- a/controllers/concept.go +++ b/controllers/concept.go @@ -191,13 +191,13 @@ func generateRowForVariable(variableName string, breakdownConceptValuesToPeopleC func (u ConceptController) RetrieveAttritionTable(c *gin.Context) { sourceId, cohortId, conceptIdsAndCohortPairs, err := utils.ParseSourceIdAndCohortIdAndVariablesAsSingleList(c) - _, cohortPairs := utils.GetConceptIdsAndCohortPairsAsSeparateLists(conceptIdsAndCohortPairs) if err != nil { log.Printf("Error: %s", err.Error()) c.JSON(http.StatusBadRequest, gin.H{"message": "bad request", "error": err.Error()}) c.Abort() return } + _, cohortPairs := utils.GetConceptIdsAndCohortPairsAsSeparateLists(conceptIdsAndCohortPairs) validAccessRequest := u.teamProjectAuthz.TeamProjectValidation(c, cohortId, cohortPairs) if !validAccessRequest { log.Printf("Error: invalid request") diff --git a/middlewares/teamprojectauthz.go b/middlewares/teamprojectauthz.go index 90a2e0d8..c7a4230e 100644 --- a/middlewares/teamprojectauthz.go +++ b/middlewares/teamprojectauthz.go @@ -29,7 +29,7 @@ func NewTeamProjectAuthz(cohortDefinitionModel models.CohortDefinitionI, httpCli httpClient: httpClient, } } -func (u TeamProjectAuthz) HasAccessToAtLeastOne(ctx *gin.Context, teamProjects []string) bool { +func (u TeamProjectAuthz) hasAccessToAtLeastOne(ctx *gin.Context, teamProjects []string) bool { // query Arborist and return as soon as one of the teamProjects access check returns 200: for _, teamProject := range teamProjects { @@ -73,7 +73,7 @@ func (u TeamProjectAuthz) TeamProjectValidation(ctx *gin.Context, cohortDefiniti log.Printf("Invalid request error: could not find a 'team project' that is associated to ALL the cohorts present in this request") return false } - if !u.HasAccessToAtLeastOne(ctx, teamProjects) { + if !u.hasAccessToAtLeastOne(ctx, teamProjects) { log.Printf("Invalid request error: user does not have access to any of the 'team projects' associated with the cohorts in this request") return false } From faed7c22a96df3913afbf2e642d2e126041764cb Mon Sep 17 00:00:00 2001 From: pieterlukasse Date: Thu, 14 Dec 2023 20:19:55 +0100 Subject: [PATCH 7/7] feat: improve variable name to reflect its real nature ...which is cohort_definition_id...there is no such thing as "cohort id" --- controllers/cohortdata.go | 6 +- models/helper.go | 4 +- tests/controllers_tests/controllers_test.go | 36 +++--- tests/models_tests/models_test.go | 116 ++++++++++---------- tests/utils_tests/utils_test.go | 6 +- utils/parsing.go | 14 +-- 6 files changed, 91 insertions(+), 91 deletions(-) diff --git a/controllers/cohortdata.go b/controllers/cohortdata.go index 3c61fc30..b93ccce2 100644 --- a/controllers/cohortdata.go +++ b/controllers/cohortdata.go @@ -111,7 +111,7 @@ func generateCohortPairsHeaders(cohortPairs []utils.CustomDichotomousVariableDef cohortPairsHeaders := []string{} for _, cohortPair := range cohortPairs { - cohortPairsHeaders = append(cohortPairsHeaders, utils.GetCohortPairKey(cohortPair.CohortId1, cohortPair.CohortId2)) + cohortPairsHeaders = append(cohortPairsHeaders, utils.GetCohortPairKey(cohortPair.CohortDefinitionId1, cohortPair.CohortDefinitionId2)) } return cohortPairsHeaders @@ -298,8 +298,8 @@ func (u CohortDataController) RetrievePeopleIdAndCohort(sourceId int, cohortId i */ personIdToCSVValues := make(map[int64]map[string]string) for _, cohortPair := range cohortPairs { - firstCohortDefinitionId := cohortPair.CohortId1 - secondCohortDefinitionId := cohortPair.CohortId2 + firstCohortDefinitionId := cohortPair.CohortDefinitionId1 + secondCohortDefinitionId := cohortPair.CohortDefinitionId2 cohortPairKey := utils.GetCohortPairKey(firstCohortDefinitionId, secondCohortDefinitionId) firstCohortPeopleData, err1 := u.cohortDataModel.RetrieveDataByOriginalCohortAndNewCohort(sourceId, cohortId, firstCohortDefinitionId) diff --git a/models/helper.go b/models/helper.go index f4639b71..fd02fee4 100644 --- a/models/helper.go +++ b/models/helper.go @@ -44,7 +44,7 @@ func QueryFilterByCohortPairsHelper(filterCohortPairs []utils.CustomDichotomousV "UNION " + "SELECT subject_id FROM " + resultsDataSource.Schema + ".cohort WHERE cohort_definition_id=? " + ")" - idsList = append(idsList, filterCohortPair.CohortId1, filterCohortPair.CohortId2) + idsList = append(idsList, filterCohortPair.CohortDefinitionId1, filterCohortPair.CohortDefinitionId2) } // EXCEPTs section: for _, filterCohortPair := range filterCohortPairs { @@ -54,7 +54,7 @@ func QueryFilterByCohortPairsHelper(filterCohortPairs []utils.CustomDichotomousV "INTERSECT " + "SELECT subject_id FROM " + resultsDataSource.Schema + ".cohort WHERE cohort_definition_id=? " + ")" - idsList = append(idsList, filterCohortPair.CohortId1, filterCohortPair.CohortId2) + idsList = append(idsList, filterCohortPair.CohortDefinitionId1, filterCohortPair.CohortDefinitionId2) } } unionAndIntersectSQL = unionAndIntersectSQL + diff --git a/tests/controllers_tests/controllers_test.go b/tests/controllers_tests/controllers_test.go index 5b9b4988..f5f29a07 100644 --- a/tests/controllers_tests/controllers_test.go +++ b/tests/controllers_tests/controllers_test.go @@ -717,14 +717,14 @@ func TestGetAttritionRowForConceptIdsAndCohortPairs(t *testing.T) { int64(1234), int64(5678), utils.CustomDichotomousVariableDef{ - CohortId1: 1, - CohortId2: 2, - ProvidedName: "testA12"}, + CohortDefinitionId1: 1, + CohortDefinitionId2: 2, + ProvidedName: "testA12"}, int64(2090006880), utils.CustomDichotomousVariableDef{ - CohortId1: 3, - CohortId2: 4, - ProvidedName: "testB34"}, + CohortDefinitionId1: 3, + CohortDefinitionId2: 4, + ProvidedName: "testB34"}, } result, _ := conceptController.GetAttritionRowForConceptIdsAndCohortPairs(sourceId, cohortId, conceptIdsAndCohortPairs, breakdownConceptId, sortedConceptValues) @@ -771,9 +771,9 @@ func TestGenerateCompleteCSV(t *testing.T) { cohortPairs := []utils.CustomDichotomousVariableDef{ { - CohortId1: 2, - CohortId2: 3, - ProvidedName: "test"}, + CohortDefinitionId1: 2, + CohortDefinitionId2: 3, + ProvidedName: "test"}, } b := controllers.GenerateCompleteCSV(partialCsv, personIdToCSVValues, cohortPairs) @@ -798,9 +798,9 @@ func TestRetrievePeopleIdAndCohort(t *testing.T) { cohortId := 1 cohortPairs := []utils.CustomDichotomousVariableDef{ { - CohortId1: 2, - CohortId2: 3, - ProvidedName: "test"}, + CohortDefinitionId1: 2, + CohortDefinitionId2: 3, + ProvidedName: "test"}, } cohortData := []*models.PersonConceptAndValue{ @@ -839,9 +839,9 @@ func TestRetrievePeopleIdAndCohortNonExistingCohortPair(t *testing.T) { cohortId := 1 cohortPairs := []utils.CustomDichotomousVariableDef{ { - CohortId1: 4, - CohortId2: 5, - ProvidedName: "test"}, + CohortDefinitionId1: 4, + CohortDefinitionId2: 5, + ProvidedName: "test"}, } cohortData := []*models.PersonConceptAndValue{ @@ -880,9 +880,9 @@ func TestRetrievePeopleIdAndCohortOverlappingCohortPair(t *testing.T) { cohortId := 1 cohortPairs := []utils.CustomDichotomousVariableDef{ { - CohortId1: 1, - CohortId2: 1, - ProvidedName: "test"}, + CohortDefinitionId1: 1, + CohortDefinitionId2: 1, + ProvidedName: "test"}, } cohortData := []*models.PersonConceptAndValue{ diff --git a/tests/models_tests/models_test.go b/tests/models_tests/models_test.go index b860ced6..e8539d43 100644 --- a/tests/models_tests/models_test.go +++ b/tests/models_tests/models_test.go @@ -249,9 +249,9 @@ func TestQueryFilterByCohortPairsHelper(t *testing.T) { // smallestCohort and largestCohort do not overlap... filterCohortPairs := []utils.CustomDichotomousVariableDef{ { - CohortId1: smallestCohort.Id, - CohortId2: largestCohort.Id, - ProvidedName: "test"}, + CohortDefinitionId1: smallestCohort.Id, + CohortDefinitionId2: largestCohort.Id, + ProvidedName: "test"}, } resultsDataSource := tests.GetResultsDataSource() var subjectIds []*SubjectId @@ -267,13 +267,13 @@ func TestQueryFilterByCohortPairsHelper(t *testing.T) { // now add a pair that overlaps with largestCohort: filterCohortPairs = []utils.CustomDichotomousVariableDef{ { - CohortId1: smallestCohort.Id, - CohortId2: largestCohort.Id, - ProvidedName: "test"}, + CohortDefinitionId1: smallestCohort.Id, + CohortDefinitionId2: largestCohort.Id, + ProvidedName: "test"}, { - CohortId1: extendedCopyOfSecondLargestCohort.Id, - CohortId2: largestCohort.Id, - ProvidedName: "test"}, + CohortDefinitionId1: extendedCopyOfSecondLargestCohort.Id, + CohortDefinitionId2: largestCohort.Id, + ProvidedName: "test"}, } subjectIds = []*SubjectId{} population = largestCohort @@ -289,13 +289,13 @@ func TestQueryFilterByCohortPairsHelper(t *testing.T) { // order doesn't matter: filterCohortPairs = []utils.CustomDichotomousVariableDef{ { - CohortId1: extendedCopyOfSecondLargestCohort.Id, - CohortId2: largestCohort.Id, - ProvidedName: "test"}, + CohortDefinitionId1: extendedCopyOfSecondLargestCohort.Id, + CohortDefinitionId2: largestCohort.Id, + ProvidedName: "test"}, { - CohortId1: smallestCohort.Id, - CohortId2: largestCohort.Id, - ProvidedName: "test"}, + CohortDefinitionId1: smallestCohort.Id, + CohortDefinitionId2: largestCohort.Id, + ProvidedName: "test"}, } subjectIds = []*SubjectId{} population = largestCohort @@ -311,9 +311,9 @@ func TestQueryFilterByCohortPairsHelper(t *testing.T) { // now test with two other cohorts that overlap: filterCohortPairs = []utils.CustomDichotomousVariableDef{ { - CohortId1: secondLargestCohort.Id, - CohortId2: extendedCopyOfSecondLargestCohort.Id, - ProvidedName: "test"}, + CohortDefinitionId1: secondLargestCohort.Id, + CohortDefinitionId2: extendedCopyOfSecondLargestCohort.Id, + ProvidedName: "test"}, } subjectIds = []*SubjectId{} population = extendedCopyOfSecondLargestCohort @@ -329,13 +329,13 @@ func TestQueryFilterByCohortPairsHelper(t *testing.T) { // now add in the largestCohort as a pair of extendedCopyOfSecondLargestCohort to the mix above: filterCohortPairs = []utils.CustomDichotomousVariableDef{ { - CohortId1: secondLargestCohort.Id, - CohortId2: extendedCopyOfSecondLargestCohort.Id, - ProvidedName: "test"}, + CohortDefinitionId1: secondLargestCohort.Id, + CohortDefinitionId2: extendedCopyOfSecondLargestCohort.Id, + ProvidedName: "test"}, { - CohortId1: largestCohort.Id, - CohortId2: extendedCopyOfSecondLargestCohort.Id, - ProvidedName: "test"}, + CohortDefinitionId1: largestCohort.Id, + CohortDefinitionId2: extendedCopyOfSecondLargestCohort.Id, + ProvidedName: "test"}, } subjectIds = []*SubjectId{} population = extendedCopyOfSecondLargestCohort @@ -376,9 +376,9 @@ func TestQueryFilterByCohortPairsHelper(t *testing.T) { // should return 0: filterCohortPairs = []utils.CustomDichotomousVariableDef{ { - CohortId1: largestCohort.Id, - CohortId2: largestCohort.Id, - ProvidedName: "test"}, + CohortDefinitionId1: largestCohort.Id, + CohortDefinitionId2: largestCohort.Id, + ProvidedName: "test"}, } subjectIds = []*SubjectId{} population = largestCohort @@ -386,7 +386,7 @@ func TestQueryFilterByCohortPairsHelper(t *testing.T) { query = models.QueryFilterByCohortPairsHelper(filterCohortPairs, resultsDataSource, population.Id, "unionAndIntersect"). Select("subject_id") _ = query.Scan(&subjectIds) - // in this case we expect overlap the size to be 0, since the pair is composed of the same cohort in CohortId1 and CohortId2 and their overlap is excluded: + // in this case we expect overlap the size to be 0, since the pair is composed of the same cohort in CohortDefinitionId1 and CohortDefinitionId2 and their overlap is excluded: if len(subjectIds) != 0 { t.Errorf("Expected 0 overlap, found %d", len(subjectIds)) } @@ -394,9 +394,9 @@ func TestQueryFilterByCohortPairsHelper(t *testing.T) { // should return 0: filterCohortPairs = []utils.CustomDichotomousVariableDef{ { - CohortId1: thirdLargestCohort.Id, - CohortId2: largestCohort.Id, - ProvidedName: "test"}, + CohortDefinitionId1: thirdLargestCohort.Id, + CohortDefinitionId2: largestCohort.Id, + ProvidedName: "test"}, } subjectIds = []*SubjectId{} population = smallestCohort @@ -417,9 +417,9 @@ func TestRetrieveBreakdownStatsBySourceIdAndCohortIdAndConceptIdsAndTwoCohortPai // setting the largest and smallest cohorts here as a pair: filterCohortPairs := []utils.CustomDichotomousVariableDef{ { - CohortId1: smallestCohort.Id, - CohortId2: largestCohort.Id, - ProvidedName: "test"}, + CohortDefinitionId1: smallestCohort.Id, + CohortDefinitionId2: largestCohort.Id, + ProvidedName: "test"}, } breakdownConceptId := hareConceptId // not normally the case...but we'll use the same here just for the test... stats, _ := conceptModel.RetrieveBreakdownStatsBySourceIdAndCohortIdAndConceptIdsAndCohortPairs(testSourceId, @@ -439,13 +439,13 @@ func TestRetrieveBreakdownStatsBySourceIdAndCohortIdAndConceptIdsAndTwoCohortPai // and because of an overlaping person found in the two cohorts of the new pair. filterCohortPairs = []utils.CustomDichotomousVariableDef{ { - CohortId1: smallestCohort.Id, - CohortId2: largestCohort.Id, - ProvidedName: "test"}, + CohortDefinitionId1: smallestCohort.Id, + CohortDefinitionId2: largestCohort.Id, + ProvidedName: "test"}, { - CohortId1: secondLargestCohort.Id, - CohortId2: extendedCopyOfSecondLargestCohort.Id, - ProvidedName: "test2"}, + CohortDefinitionId1: secondLargestCohort.Id, + CohortDefinitionId2: extendedCopyOfSecondLargestCohort.Id, + ProvidedName: "test2"}, } stats, _ = conceptModel.RetrieveBreakdownStatsBySourceIdAndCohortIdAndConceptIdsAndCohortPairs(testSourceId, populationCohort.Id, filterIds, filterCohortPairs, breakdownConceptId) @@ -464,9 +464,9 @@ func TestRetrieveBreakdownStatsBySourceIdAndCohortIdAndConceptIdsAndCohortPairsW // setting the same cohort id here (artificial...but just to check if that returns the same value as when this filter is not there): filterCohortPairs := []utils.CustomDichotomousVariableDef{ { - CohortId1: secondLargestCohort.Id, - CohortId2: extendedCopyOfSecondLargestCohort.Id, - ProvidedName: "test"}, + CohortDefinitionId1: secondLargestCohort.Id, + CohortDefinitionId2: extendedCopyOfSecondLargestCohort.Id, + ProvidedName: "test"}, } breakdownConceptId := hareConceptId // not normally the case...but we'll use the same here just for the test... stats, _ := conceptModel.RetrieveBreakdownStatsBySourceIdAndCohortIdAndConceptIdsAndCohortPairs(testSourceId, @@ -499,9 +499,9 @@ func TestRetrieveBreakdownStatsBySourceIdAndCohortIdAndConceptIdsAndCohortPairsW // setting the same cohort id here (artificial...normally it should be two different ids): filterCohortPairs = []utils.CustomDichotomousVariableDef{ { - CohortId1: smallestCohort.Id, - CohortId2: largestCohort.Id, - ProvidedName: "test"}, + CohortDefinitionId1: smallestCohort.Id, + CohortDefinitionId2: largestCohort.Id, + ProvidedName: "test"}, } stats3, _ := conceptModel.RetrieveBreakdownStatsBySourceIdAndCohortIdAndConceptIdsAndCohortPairs(testSourceId, secondLargestCohort.Id, filterIds, filterCohortPairs, breakdownConceptId) @@ -569,9 +569,9 @@ func TestGetTeamProjectsThatMatchAllCohortDefinitionIdsOnlyDefaultMatch(t *testi cohortDefinitionId := 2 filterCohortPairs := []utils.CustomDichotomousVariableDef{ { - CohortId1: smallestCohort.Id, - CohortId2: largestCohort.Id, - ProvidedName: "test"}, + CohortDefinitionId1: smallestCohort.Id, + CohortDefinitionId2: largestCohort.Id, + ProvidedName: "test"}, } uniqueCohortDefinitionIdsList := utils.GetUniqueCohortDefinitionIdsListFromRequest(cohortDefinitionId, filterCohortPairs) teamProjects, _ := cohortDefinitionModel.GetTeamProjectsThatMatchAllCohortDefinitionIds(uniqueCohortDefinitionIdsList) @@ -585,9 +585,9 @@ func TestGetTeamProjectsThatMatchAllCohortDefinitionIds(t *testing.T) { cohortDefinitionId := 2 filterCohortPairs := []utils.CustomDichotomousVariableDef{ { - CohortId1: 2, - CohortId2: 2, - ProvidedName: "test"}, + CohortDefinitionId1: 2, + CohortDefinitionId2: 2, + ProvidedName: "test"}, } uniqueCohortDefinitionIdsList := utils.GetUniqueCohortDefinitionIdsListFromRequest(cohortDefinitionId, filterCohortPairs) teamProjects, _ := cohortDefinitionModel.GetTeamProjectsThatMatchAllCohortDefinitionIds(uniqueCohortDefinitionIdsList) @@ -710,9 +710,9 @@ func TestRetrieveHistogramDataBySourceIdAndCohortIdAndConceptIdsAndCohortPairs(t // now filter on the extendedCopyOfSecondLargestCohort filterCohortPairs = []utils.CustomDichotomousVariableDef{ { - CohortId1: smallestCohort.Id, - CohortId2: extendedCopyOfSecondLargestCohort.Id, - ProvidedName: "test"}, + CohortDefinitionId1: smallestCohort.Id, + CohortDefinitionId2: extendedCopyOfSecondLargestCohort.Id, + ProvidedName: "test"}, } // then we expect histogram data for the overlapping population only (which is 5 for extendedCopyOfSecondLargestCohort and largestCohort): data, _ = cohortDataModel.RetrieveHistogramDataBySourceIdAndCohortIdAndConceptIdsAndCohortPairs(testSourceId, largestCohort.Id, histogramConceptId, filterConceptIds, filterCohortPairs) @@ -840,9 +840,9 @@ func TestRetrieveCohortOverlapStatsWithoutFilteringOnConceptValue(t *testing.T) controlCohortId = largestCohort.Id // to ensure we get largestCohort as initial overlap, just repeat the same here... filterCohortPairs = []utils.CustomDichotomousVariableDef{ { - CohortId1: smallestCohort.Id, - CohortId2: extendedCopyOfSecondLargestCohort.Id, - ProvidedName: "test"}, + CohortDefinitionId1: smallestCohort.Id, + CohortDefinitionId2: extendedCopyOfSecondLargestCohort.Id, + ProvidedName: "test"}, } // then we expect overlap of 5 for extendedCopyOfSecondLargestCohort and largestCohort: stats, _ = cohortDataModel.RetrieveCohortOverlapStatsWithoutFilteringOnConceptValue(testSourceId, caseCohortId, controlCohortId, diff --git a/tests/utils_tests/utils_test.go b/tests/utils_tests/utils_test.go index c5fa4c32..4498c6ab 100644 --- a/tests/utils_tests/utils_test.go +++ b/tests/utils_tests/utils_test.go @@ -66,9 +66,9 @@ func TestParsePrefixedConceptIdsAndDichotomousIds(t *testing.T) { expectedCohortPairs := []utils.CustomDichotomousVariableDef{ { - CohortId1: 1, - CohortId2: 3, - ProvidedName: "test"}, + CohortDefinitionId1: 1, + CohortDefinitionId2: 3, + ProvidedName: "test"}, } for i, cohortPair := range cohortPairs { diff --git a/utils/parsing.go b/utils/parsing.go index bebcde40..693ce39c 100644 --- a/utils/parsing.go +++ b/utils/parsing.go @@ -94,9 +94,9 @@ type ConceptTypes struct { // fields that define a custom dichotomous variable: type CustomDichotomousVariableDef struct { - CohortId1 int - CohortId2 int - ProvidedName string + CohortDefinitionId1 int + CohortDefinitionId2 int + ProvidedName string } func GetCohortPairKey(firstCohortDefinitionId int, secondCohortDefinitionId int) string { @@ -144,9 +144,9 @@ func ParseConceptIdsAndDichotomousDefsAsSingleList(c *gin.Context) ([]interface{ providedName = variable["provided_name"].(string) } customDichotomousVariableDef := CustomDichotomousVariableDef{ - CohortId1: cohortPair[0], - CohortId2: cohortPair[1], - ProvidedName: providedName, + CohortDefinitionId1: cohortPair[0], + CohortDefinitionId2: cohortPair[1], + ProvidedName: providedName, } conceptIdsAndCohortPairs = append(conceptIdsAndCohortPairs, customDichotomousVariableDef) } @@ -294,7 +294,7 @@ func GetUniqueCohortDefinitionIdsListFromRequest(cohortDefinitionId int, filterC idsList = append(idsList, cohortDefinitionId) if len(filterCohortPairs) > 0 { for _, filterCohortPair := range filterCohortPairs { - idsList = append(idsList, filterCohortPair.CohortId1, filterCohortPair.CohortId2) // TODO - rename CohortId1/2 here to cohortDefinition1/2 + idsList = append(idsList, filterCohortPair.CohortDefinitionId1, filterCohortPair.CohortDefinitionId2) } } uniqueIdsList := MakeUnique(idsList)