Skip to content

Commit

Permalink
feat: integrate extra 'team project' validation for concept endpoints
Browse files Browse the repository at this point in the history
  • Loading branch information
pieterlukasse committed Dec 6, 2023
1 parent d009aa7 commit 9b94400
Show file tree
Hide file tree
Showing 6 changed files with 116 additions and 5 deletions.
11 changes: 8 additions & 3 deletions controllers/concept.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
"strconv"

"github.com/gin-gonic/gin"
"github.com/uc-cdis/cohort-middleware/middlewares"
"github.com/uc-cdis/cohort-middleware/models"
"github.com/uc-cdis/cohort-middleware/utils"
)
Expand Down Expand Up @@ -87,7 +88,8 @@ func (u ConceptController) RetrieveInfoBySourceIdAndConceptTypes(c *gin.Context)

func (u ConceptController) RetrieveBreakdownStatsBySourceIdAndCohortId(c *gin.Context) {
sourceId, cohortId, err := utils.ParseSourceAndCohortId(c)
if err != nil {
validRequest := middlewares.TeamProjectValidationForCohort(cohortId)
if err != nil || !validRequest {
log.Printf("Error: %s", err.Error())
c.JSON(http.StatusBadRequest, gin.H{"message": "bad request", "error": err.Error()})
c.Abort()
Expand All @@ -112,7 +114,8 @@ func (u ConceptController) RetrieveBreakdownStatsBySourceIdAndCohortId(c *gin.Co

func (u ConceptController) RetrieveBreakdownStatsBySourceIdAndCohortIdAndVariables(c *gin.Context) {
sourceId, cohortId, conceptIds, cohortPairs, err := utils.ParseSourceIdAndCohortIdAndVariablesList(c)
if err != nil {
validRequest := middlewares.TeamProjectValidation(cohortId, cohortPairs)
if err != nil || !validRequest {
log.Printf("Error: %s", err.Error())
c.JSON(http.StatusBadRequest, gin.H{"message": "bad request", "error": err.Error()})
c.Abort()
Expand Down Expand Up @@ -169,7 +172,9 @@ func generateRowForVariable(variableName string, breakdownConceptValuesToPeopleC

func (u ConceptController) RetrieveAttritionTable(c *gin.Context) {
sourceId, cohortId, conceptIdsAndCohortPairs, err := utils.ParseSourceIdAndCohortIdAndVariablesAsSingleList(c)
if err != nil {
_, cohortPairs := utils.GetConceptIdsAndCohortPairsAsSeparateLists(conceptIdsAndCohortPairs)
validRequest := middlewares.TeamProjectValidation(cohortId, cohortPairs)
if err != nil || !validRequest {
log.Printf("Error: %s", err.Error())
c.JSON(http.StatusBadRequest, gin.H{"message": "bad request", "error": err.Error()})
c.Abort()
Expand Down
33 changes: 33 additions & 0 deletions middlewares/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ import (

"github.com/gin-gonic/gin"
"github.com/uc-cdis/cohort-middleware/config"
"github.com/uc-cdis/cohort-middleware/models"
"github.com/uc-cdis/cohort-middleware/utils"
)

func AuthMiddleware() gin.HandlerFunc {
Expand All @@ -26,6 +28,7 @@ func AuthMiddleware() gin.HandlerFunc {
if err != nil {
ctx.AbortWithStatus(500)
log.Printf("Error while preparing Arborist request: %s", err.Error())
return
}
client := &http.Client{}
// send the request to Arborist:
Expand Down Expand Up @@ -71,3 +74,33 @@ func PrepareNewArboristRequest(ctx *gin.Context, arboristEndpoint string) (*http
req.Header.Set("Authorization", authorization)
return req, nil
}

func HasAccessToAtLeastOne(teamProjects []string) bool {
// TODO - query Arborist
return true
}

// "team project" related checks:
// (1) check if the request contains any cohorts and if all cohorts belong to the same "team project"
// (2) check if the user has permission in the "team project"
// Returns true if both checks above pass, false otherwise.
func TeamProjectValidationForCohort(cohortDefinitionId int) bool {
filterCohortPairs := []utils.CustomDichotomousVariableDef{}
return TeamProjectValidation(cohortDefinitionId, filterCohortPairs)
}
func TeamProjectValidation(cohortDefinitionId int, filterCohortPairs []utils.CustomDichotomousVariableDef) bool {

uniqueCohortDefinitionIdsList := utils.GetUniqueCohortDefinitionIdsListFromRequest(cohortDefinitionId, filterCohortPairs)
cohortDefinitionModel := new(models.CohortDefinition)
teamProjects, _ := cohortDefinitionModel.GetTeamProjectsThatMatchAllCohortDefinitionIds(uniqueCohortDefinitionIdsList)
if len(teamProjects) == 0 {
log.Printf("Invalid request error: could not find a 'team project' that is associated to ALL the cohorts present in this request")
return false
}
if !HasAccessToAtLeastOne(teamProjects) {
log.Printf("Invalid request error: user does not have access to any of the 'team projects' associated with the cohorts in this request")
return false
}
// passed both tests:
return true
}
17 changes: 17 additions & 0 deletions models/cohortdefinition.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,23 @@ func (h CohortDefinition) GetAllCohortDefinitions() ([]*CohortDefinition, error)
return cohortDefinition, meta_result.Error
}

// Returns any "team project" entries that are matched to _each and every one_ of the
// cohort definition ids found in uniqueCohortDefinitionIdsList.
func (h CohortDefinition) GetTeamProjectsThatMatchAllCohortDefinitionIds(uniqueCohortDefinitionIdsList []int) ([]string, error) {

db2 := db.GetAtlasDB().Db
var teamProjects []string
// Find any roles that are paired to each and every one of the cohort_definition_id values.
// Roles that ony match part of the values are filtered out by the having(count) clause:
query := db2.Table(db.GetAtlasDB().Schema+".cohort_definition_sec_role").
Select("sec_role_name").
Where("cohort_definition_id in (?)", uniqueCohortDefinitionIdsList).
Group("sec_role_name").
Having("count(DISTINCT cohort_definition_id) = ?", len(uniqueCohortDefinitionIdsList)).
Scan(&teamProjects)
return teamProjects, query.Error
}

// Get the list of cohort_definition ids for a given "team project" (where "team project" is basically
// a security role name of one of the roles in Atlas/WebAPI database).
func (h CohortDefinition) GetCohortDefinitionIdsForTeamProject(teamProject string) ([]int, error) {
Expand Down
3 changes: 1 addition & 2 deletions models/concept.go
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,7 @@ func (h Concept) RetrieveBreakdownStatsBySourceIdAndCohortId(sourceId int, cohor
// {ConceptValue: "B", NPersonsInCohortWithValue: N-M-X},
// where X is the number of persons that have NO value or just a "null" value for one or more of the ids in the given filterConceptIds.
func (h Concept) RetrieveBreakdownStatsBySourceIdAndCohortIdAndConceptIdsAndCohortPairs(sourceId int, cohortDefinitionId int, filterConceptIds []int64, filterCohortPairs []utils.CustomDichotomousVariableDef, breakdownConceptId int64) ([]*ConceptBreakdown, error) {

var dataSourceModel = new(Source)
omopDataSource := dataSourceModel.GetDataSource(sourceId, Omop)
resultsDataSource := dataSourceModel.GetDataSource(sourceId, Results)
Expand All @@ -150,8 +151,6 @@ func (h Concept) RetrieveBreakdownStatsBySourceIdAndCohortIdAndConceptIdsAndCoho
Where("observation.observation_concept_id = ?", breakdownConceptId).
Where(GetConceptValueNotNullCheckBasedOnConceptType("observation", sourceId, breakdownConceptId))

// note: here we pass empty []utils.CustomDichotomousVariableDef{} instead of filterCohortPairs, since we already use the SQL generated by QueryFilterByCohortPairsHelper above,
// which is a better performing SQL in this particular scenario:
query = QueryFilterByConceptIdsHelper(query, sourceId, filterConceptIds, omopDataSource, resultsDataSource.Schema, "observation")

query, cancel := utils.AddTimeoutToQuery(query)
Expand Down
32 changes: 32 additions & 0 deletions tests/models_tests/models_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -564,6 +564,38 @@ func TestRetrieveBreakdownStatsBySourceIdAndCohortIdWithResultsWithOnePersonTwoH
}
}

func TestGetTeamProjectsThatMatchAllCohortDefinitionIdsOnlyDefaultMatch(t *testing.T) {
setUp(t)
cohortDefinitionId := 2
filterCohortPairs := []utils.CustomDichotomousVariableDef{
{
CohortId1: smallestCohort.Id,
CohortId2: largestCohort.Id,
ProvidedName: "test"},
}
uniqueCohortDefinitionIdsList := utils.GetUniqueCohortDefinitionIdsListFromRequest(cohortDefinitionId, filterCohortPairs)
teamProjects, _ := cohortDefinitionModel.GetTeamProjectsThatMatchAllCohortDefinitionIds(uniqueCohortDefinitionIdsList)
if len(teamProjects) != 1 || teamProjects[0] != "defaultteamproject" {
t.Errorf("Expected to find only defaultteamproject")
}
}

func TestGetTeamProjectsThatMatchAllCohortDefinitionIds(t *testing.T) {
setUp(t)
cohortDefinitionId := 2
filterCohortPairs := []utils.CustomDichotomousVariableDef{
{
CohortId1: 2,
CohortId2: 2,
ProvidedName: "test"},
}
uniqueCohortDefinitionIdsList := utils.GetUniqueCohortDefinitionIdsListFromRequest(cohortDefinitionId, filterCohortPairs)
teamProjects, _ := cohortDefinitionModel.GetTeamProjectsThatMatchAllCohortDefinitionIds(uniqueCohortDefinitionIdsList)
if len(teamProjects) != 2 {
t.Errorf("Expected to find two 'team projects' matching the cohort list, found %s", teamProjects)
}
}

func TestGetCohortDefinitionIdsForTeamProject(t *testing.T) {
setUp(t)
testTeamProject := "teamprojectX"
Expand Down
25 changes: 25 additions & 0 deletions utils/parsing.go
Original file line number Diff line number Diff line change
Expand Up @@ -275,3 +275,28 @@ func ParseSourceIdAndCohortIdAndVariablesAsSingleList(c *gin.Context) (int, int,
}
return sourceId, cohortId, conceptIdsAndCohortPairs, nil
}

func MakeUnique(input []int) []int {
uniqueMap := make(map[int]bool)
var uniqueList []int

for _, num := range input {
if !uniqueMap[num] {
uniqueMap[num] = true
uniqueList = append(uniqueList, num)
}
}
return uniqueList
}

func GetUniqueCohortDefinitionIdsListFromRequest(cohortDefinitionId int, filterCohortPairs []CustomDichotomousVariableDef) []int {
var idsList []int
idsList = append(idsList, cohortDefinitionId)
if len(filterCohortPairs) > 0 {
for _, filterCohortPair := range filterCohortPairs {
idsList = append(idsList, filterCohortPair.CohortId1, filterCohortPair.CohortId2) // TODO - rename CohortId1/2 here to cohortDefinition1/2
}
}
uniqueIdsList := MakeUnique(idsList)
return uniqueIdsList
}

0 comments on commit 9b94400

Please sign in to comment.