diff --git a/controllers/concept.go b/controllers/concept.go index 3f61c4d..6daa655 100644 --- a/controllers/concept.go +++ b/controllers/concept.go @@ -10,6 +10,7 @@ import ( "strconv" "github.com/gin-gonic/gin" + "github.com/uc-cdis/cohort-middleware/middlewares" "github.com/uc-cdis/cohort-middleware/models" "github.com/uc-cdis/cohort-middleware/utils" ) @@ -87,7 +88,8 @@ func (u ConceptController) RetrieveInfoBySourceIdAndConceptTypes(c *gin.Context) func (u ConceptController) RetrieveBreakdownStatsBySourceIdAndCohortId(c *gin.Context) { sourceId, cohortId, err := utils.ParseSourceAndCohortId(c) - if err != nil { + validRequest := middlewares.TeamProjectValidationForCohort(c, cohortId) + if err != nil || !validRequest { log.Printf("Error: %s", err.Error()) c.JSON(http.StatusBadRequest, gin.H{"message": "bad request", "error": err.Error()}) c.Abort() @@ -112,7 +114,8 @@ func (u ConceptController) RetrieveBreakdownStatsBySourceIdAndCohortId(c *gin.Co func (u ConceptController) RetrieveBreakdownStatsBySourceIdAndCohortIdAndVariables(c *gin.Context) { sourceId, cohortId, conceptIds, cohortPairs, err := utils.ParseSourceIdAndCohortIdAndVariablesList(c) - if err != nil { + validRequest := middlewares.TeamProjectValidation(c, cohortId, cohortPairs) + if err != nil || !validRequest { log.Printf("Error: %s", err.Error()) c.JSON(http.StatusBadRequest, gin.H{"message": "bad request", "error": err.Error()}) c.Abort() @@ -169,7 +172,9 @@ func generateRowForVariable(variableName string, breakdownConceptValuesToPeopleC func (u ConceptController) RetrieveAttritionTable(c *gin.Context) { sourceId, cohortId, conceptIdsAndCohortPairs, err := utils.ParseSourceIdAndCohortIdAndVariablesAsSingleList(c) - if err != nil { + _, cohortPairs := utils.GetConceptIdsAndCohortPairsAsSeparateLists(conceptIdsAndCohortPairs) + validRequest := middlewares.TeamProjectValidation(c, cohortId, cohortPairs) + if err != nil || !validRequest { log.Printf("Error: %s", err.Error()) c.JSON(http.StatusBadRequest, gin.H{"message": "bad request", "error": err.Error()}) c.Abort() diff --git a/middlewares/auth.go b/middlewares/auth.go index 96d0912..1787706 100644 --- a/middlewares/auth.go +++ b/middlewares/auth.go @@ -8,6 +8,8 @@ import ( "github.com/gin-gonic/gin" "github.com/uc-cdis/cohort-middleware/config" + "github.com/uc-cdis/cohort-middleware/models" + "github.com/uc-cdis/cohort-middleware/utils" ) func AuthMiddleware() gin.HandlerFunc { @@ -22,10 +24,11 @@ func AuthMiddleware() gin.HandlerFunc { } return func(ctx *gin.Context) { - req, err := PrepareNewArboristRequest(ctx, c.GetString("arborist_endpoint")) + req, err := PrepareNewArboristRequest(ctx) if err != nil { ctx.AbortWithStatus(500) log.Printf("Error while preparing Arborist request: %s", err.Error()) + return } client := &http.Client{} // send the request to Arborist: @@ -44,9 +47,22 @@ func AuthMiddleware() gin.HandlerFunc { } // this function will take the request from the given ctx, validated it for the presence of an "Authorization / Bearer" token -// and then return the URL that can be used to consult Arborist regarding access permissions. This function +// and then return the URL that can be used to consult Arborist regarding cohort-middleware access permissions. This function // returns an error if "Authorization / Bearer" token is missing in ctx -func PrepareNewArboristRequest(ctx *gin.Context, arboristEndpoint string) (*http.Request, error) { +func PrepareNewArboristRequest(ctx *gin.Context) (*http.Request, error) { + + resourcePath := fmt.Sprintf("/cohort-middleware%s", ctx.Request.URL.Path) + service := "cohort-middleware" + + return PrepareNewArboristRequestForResourceAndService(ctx, resourcePath, service) +} + +// this function will take the request from the given ctx, validated it for the presence of an "Authorization / Bearer" token +// and then return the URL that can be used to consult Arborist regarding access permissions for the given +// resource path and service. +func PrepareNewArboristRequestForResourceAndService(ctx *gin.Context, resourcePath string, service string) (*http.Request, error) { + c := config.GetConfig() + arboristEndpoint := c.GetString("arborist_endpoint") // validate: authorization := ctx.Request.Header.Get("Authorization") if authorization == "" { @@ -54,11 +70,10 @@ func PrepareNewArboristRequest(ctx *gin.Context, arboristEndpoint string) (*http } // build up the request URL string: - resourcePath := fmt.Sprintf("/cohort-middleware%s", ctx.Request.URL.Path) arboristAuth := fmt.Sprintf("%s/auth/proxy?resource=%s&service=%s&method=%s", arboristEndpoint, resourcePath, - "cohort-middleware", + service, "access") // make request object / validate URL: @@ -71,3 +86,56 @@ func PrepareNewArboristRequest(ctx *gin.Context, arboristEndpoint string) (*http req.Header.Set("Authorization", authorization) return req, nil } + +func HasAccessToAtLeastOne(ctx *gin.Context, teamProjects []string) bool { + + // query Arborist and return as soon as one of the teamProjects access check returns 200: + for _, teamProject := range teamProjects { + teamProjectAsResourcePath := teamProject + teamProjectAccessService := "atlas-argo-wrapper-and-cohort-middleware" + + req, err := PrepareNewArboristRequestForResourceAndService(ctx, teamProjectAsResourcePath, teamProjectAccessService) + if err != nil { + ctx.AbortWithStatus(500) + panic("Error while preparing Arborist request") + } + client := &http.Client{} + // send the request to Arborist: + resp, _ := client.Do(req) + + // arborist will return with 200 if the user has been granted access to the cohort-middleware URL in ctx: + if resp.StatusCode == 200 { + return true + } else { + // unauthorized or otherwise: + log.Printf("Got response status %d from Arborist...", resp.StatusCode) + } + } + return false +} + +func TeamProjectValidationForCohort(ctx *gin.Context, cohortDefinitionId int) bool { + filterCohortPairs := []utils.CustomDichotomousVariableDef{} + return TeamProjectValidation(ctx, cohortDefinitionId, filterCohortPairs) +} + +// "team project" related checks: +// (1) check if the request contains any cohorts and if all cohorts belong to the same "team project" +// (2) check if the user has permission in the "team project" +// Returns true if both checks above pass, false otherwise. +func TeamProjectValidation(ctx *gin.Context, cohortDefinitionId int, filterCohortPairs []utils.CustomDichotomousVariableDef) bool { + + uniqueCohortDefinitionIdsList := utils.GetUniqueCohortDefinitionIdsListFromRequest(cohortDefinitionId, filterCohortPairs) + cohortDefinitionModel := new(models.CohortDefinition) + teamProjects, _ := cohortDefinitionModel.GetTeamProjectsThatMatchAllCohortDefinitionIds(uniqueCohortDefinitionIdsList) + if len(teamProjects) == 0 { + log.Printf("Invalid request error: could not find a 'team project' that is associated to ALL the cohorts present in this request") + return false + } + if !HasAccessToAtLeastOne(ctx, teamProjects) { + log.Printf("Invalid request error: user does not have access to any of the 'team projects' associated with the cohorts in this request") + return false + } + // passed both tests: + return true +} diff --git a/models/cohortdefinition.go b/models/cohortdefinition.go index 9893f07..b905ce1 100644 --- a/models/cohortdefinition.go +++ b/models/cohortdefinition.go @@ -67,6 +67,23 @@ func (h CohortDefinition) GetAllCohortDefinitions() ([]*CohortDefinition, error) return cohortDefinition, meta_result.Error } +// Returns any "team project" entries that are matched to _each and every one_ of the +// cohort definition ids found in uniqueCohortDefinitionIdsList. +func (h CohortDefinition) GetTeamProjectsThatMatchAllCohortDefinitionIds(uniqueCohortDefinitionIdsList []int) ([]string, error) { + + db2 := db.GetAtlasDB().Db + var teamProjects []string + // Find any roles that are paired to each and every one of the cohort_definition_id values. + // Roles that ony match part of the values are filtered out by the having(count) clause: + query := db2.Table(db.GetAtlasDB().Schema+".cohort_definition_sec_role"). + Select("sec_role_name"). + Where("cohort_definition_id in (?)", uniqueCohortDefinitionIdsList). + Group("sec_role_name"). + Having("count(DISTINCT cohort_definition_id) = ?", len(uniqueCohortDefinitionIdsList)). + Scan(&teamProjects) + return teamProjects, query.Error +} + // Get the list of cohort_definition ids for a given "team project" (where "team project" is basically // a security role name of one of the roles in Atlas/WebAPI database). func (h CohortDefinition) GetCohortDefinitionIdsForTeamProject(teamProject string) ([]int, error) { diff --git a/models/concept.go b/models/concept.go index f8801f5..ddd999d 100644 --- a/models/concept.go +++ b/models/concept.go @@ -138,6 +138,7 @@ func (h Concept) RetrieveBreakdownStatsBySourceIdAndCohortId(sourceId int, cohor // {ConceptValue: "B", NPersonsInCohortWithValue: N-M-X}, // where X is the number of persons that have NO value or just a "null" value for one or more of the ids in the given filterConceptIds. func (h Concept) RetrieveBreakdownStatsBySourceIdAndCohortIdAndConceptIdsAndCohortPairs(sourceId int, cohortDefinitionId int, filterConceptIds []int64, filterCohortPairs []utils.CustomDichotomousVariableDef, breakdownConceptId int64) ([]*ConceptBreakdown, error) { + var dataSourceModel = new(Source) omopDataSource := dataSourceModel.GetDataSource(sourceId, Omop) resultsDataSource := dataSourceModel.GetDataSource(sourceId, Results) @@ -150,8 +151,6 @@ func (h Concept) RetrieveBreakdownStatsBySourceIdAndCohortIdAndConceptIdsAndCoho Where("observation.observation_concept_id = ?", breakdownConceptId). Where(GetConceptValueNotNullCheckBasedOnConceptType("observation", sourceId, breakdownConceptId)) - // note: here we pass empty []utils.CustomDichotomousVariableDef{} instead of filterCohortPairs, since we already use the SQL generated by QueryFilterByCohortPairsHelper above, - // which is a better performing SQL in this particular scenario: query = QueryFilterByConceptIdsHelper(query, sourceId, filterConceptIds, omopDataSource, resultsDataSource.Schema, "observation") query, cancel := utils.AddTimeoutToQuery(query) diff --git a/tests/models_tests/models_test.go b/tests/models_tests/models_test.go index 9e3b280..5a9506d 100644 --- a/tests/models_tests/models_test.go +++ b/tests/models_tests/models_test.go @@ -564,6 +564,38 @@ func TestRetrieveBreakdownStatsBySourceIdAndCohortIdWithResultsWithOnePersonTwoH } } +func TestGetTeamProjectsThatMatchAllCohortDefinitionIdsOnlyDefaultMatch(t *testing.T) { + setUp(t) + cohortDefinitionId := 2 + filterCohortPairs := []utils.CustomDichotomousVariableDef{ + { + CohortId1: smallestCohort.Id, + CohortId2: largestCohort.Id, + ProvidedName: "test"}, + } + uniqueCohortDefinitionIdsList := utils.GetUniqueCohortDefinitionIdsListFromRequest(cohortDefinitionId, filterCohortPairs) + teamProjects, _ := cohortDefinitionModel.GetTeamProjectsThatMatchAllCohortDefinitionIds(uniqueCohortDefinitionIdsList) + if len(teamProjects) != 1 || teamProjects[0] != "defaultteamproject" { + t.Errorf("Expected to find only defaultteamproject") + } +} + +func TestGetTeamProjectsThatMatchAllCohortDefinitionIds(t *testing.T) { + setUp(t) + cohortDefinitionId := 2 + filterCohortPairs := []utils.CustomDichotomousVariableDef{ + { + CohortId1: 2, + CohortId2: 2, + ProvidedName: "test"}, + } + uniqueCohortDefinitionIdsList := utils.GetUniqueCohortDefinitionIdsListFromRequest(cohortDefinitionId, filterCohortPairs) + teamProjects, _ := cohortDefinitionModel.GetTeamProjectsThatMatchAllCohortDefinitionIds(uniqueCohortDefinitionIdsList) + if len(teamProjects) != 2 { + t.Errorf("Expected to find two 'team projects' matching the cohort list, found %s", teamProjects) + } +} + func TestGetCohortDefinitionIdsForTeamProject(t *testing.T) { setUp(t) testTeamProject := "teamprojectX" diff --git a/utils/parsing.go b/utils/parsing.go index e29c8c6..bebcde4 100644 --- a/utils/parsing.go +++ b/utils/parsing.go @@ -275,3 +275,28 @@ func ParseSourceIdAndCohortIdAndVariablesAsSingleList(c *gin.Context) (int, int, } return sourceId, cohortId, conceptIdsAndCohortPairs, nil } + +func MakeUnique(input []int) []int { + uniqueMap := make(map[int]bool) + var uniqueList []int + + for _, num := range input { + if !uniqueMap[num] { + uniqueMap[num] = true + uniqueList = append(uniqueList, num) + } + } + return uniqueList +} + +func GetUniqueCohortDefinitionIdsListFromRequest(cohortDefinitionId int, filterCohortPairs []CustomDichotomousVariableDef) []int { + var idsList []int + idsList = append(idsList, cohortDefinitionId) + if len(filterCohortPairs) > 0 { + for _, filterCohortPair := range filterCohortPairs { + idsList = append(idsList, filterCohortPair.CohortId1, filterCohortPair.CohortId2) // TODO - rename CohortId1/2 here to cohortDefinition1/2 + } + } + uniqueIdsList := MakeUnique(idsList) + return uniqueIdsList +}