Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feat: integrate extra 'team project' validation for concept endpoints #82

1 change: 1 addition & 0 deletions config/mocktest.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
arborist_endpoint: 'https://arboristdummyurl'
6 changes: 3 additions & 3 deletions controllers/cohortdata.go
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ func generateCohortPairsHeaders(cohortPairs []utils.CustomDichotomousVariableDef
cohortPairsHeaders := []string{}

for _, cohortPair := range cohortPairs {
cohortPairsHeaders = append(cohortPairsHeaders, utils.GetCohortPairKey(cohortPair.CohortId1, cohortPair.CohortId2))
cohortPairsHeaders = append(cohortPairsHeaders, utils.GetCohortPairKey(cohortPair.CohortDefinitionId1, cohortPair.CohortDefinitionId2))
}

return cohortPairsHeaders
Expand Down Expand Up @@ -298,8 +298,8 @@ func (u CohortDataController) RetrievePeopleIdAndCohort(sourceId int, cohortId i
*/
personIdToCSVValues := make(map[int64]map[string]string)
for _, cohortPair := range cohortPairs {
firstCohortDefinitionId := cohortPair.CohortId1
secondCohortDefinitionId := cohortPair.CohortId2
firstCohortDefinitionId := cohortPair.CohortDefinitionId1
secondCohortDefinitionId := cohortPair.CohortDefinitionId2
cohortPairKey := utils.GetCohortPairKey(firstCohortDefinitionId, secondCohortDefinitionId)

firstCohortPeopleData, err1 := u.cohortDataModel.RetrieveDataByOriginalCohortAndNewCohort(sourceId, cohortId, firstCohortDefinitionId)
Expand Down
35 changes: 33 additions & 2 deletions controllers/concept.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,17 +10,23 @@ import (
"strconv"

"github.com/gin-gonic/gin"
"github.com/uc-cdis/cohort-middleware/middlewares"
"github.com/uc-cdis/cohort-middleware/models"
"github.com/uc-cdis/cohort-middleware/utils"
)

type ConceptController struct {
conceptModel models.ConceptI
cohortDefinitionModel models.CohortDefinitionI
teamProjectAuthz middlewares.TeamProjectAuthzI
}

func NewConceptController(conceptModel models.ConceptI, cohortDefinitionModel models.CohortDefinitionI) ConceptController {
return ConceptController{conceptModel: conceptModel, cohortDefinitionModel: cohortDefinitionModel}
func NewConceptController(conceptModel models.ConceptI, cohortDefinitionModel models.CohortDefinitionI, teamProjectAuthz middlewares.TeamProjectAuthzI) ConceptController {
return ConceptController{
conceptModel: conceptModel,
cohortDefinitionModel: cohortDefinitionModel,
teamProjectAuthz: teamProjectAuthz,
}
}

func (u ConceptController) RetriveAllBySourceId(c *gin.Context) {
Expand Down Expand Up @@ -93,6 +99,14 @@ func (u ConceptController) RetrieveBreakdownStatsBySourceIdAndCohortId(c *gin.Co
c.Abort()
return
}
validAccessRequest := u.teamProjectAuthz.TeamProjectValidationForCohort(c, cohortId)
if !validAccessRequest {
log.Printf("Error: invalid request")
c.JSON(http.StatusBadRequest, gin.H{"message": "access denied"})
c.Abort()
return
}

breakdownConceptId, err := utils.ParseBigNumericArg(c, "breakdownconceptid")
if err != nil {
log.Printf("Error: %s", err.Error())
Expand All @@ -118,6 +132,14 @@ func (u ConceptController) RetrieveBreakdownStatsBySourceIdAndCohortIdAndVariabl
c.Abort()
return
}
validAccessRequest := u.teamProjectAuthz.TeamProjectValidation(c, cohortId, cohortPairs)
if !validAccessRequest {
log.Printf("Error: invalid request")
c.JSON(http.StatusBadRequest, gin.H{"message": "access denied"})
c.Abort()
return
}

breakdownConceptId, err := utils.ParseBigNumericArg(c, "breakdownconceptid")
if err != nil {
log.Printf("Error: %s", err.Error())
Expand Down Expand Up @@ -175,6 +197,15 @@ func (u ConceptController) RetrieveAttritionTable(c *gin.Context) {
c.Abort()
return
}
_, cohortPairs := utils.GetConceptIdsAndCohortPairsAsSeparateLists(conceptIdsAndCohortPairs)
validAccessRequest := u.teamProjectAuthz.TeamProjectValidation(c, cohortId, cohortPairs)
if !validAccessRequest {
log.Printf("Error: invalid request")
c.JSON(http.StatusBadRequest, gin.H{"message": "access denied"})
c.Abort()
return
}

breakdownConceptId, err := utils.ParseBigNumericArg(c, "breakdownconceptid")
if err != nil {
log.Printf("Error: %s", err.Error())
Expand Down
23 changes: 18 additions & 5 deletions middlewares/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,11 @@ func AuthMiddleware() gin.HandlerFunc {
}

return func(ctx *gin.Context) {
req, err := PrepareNewArboristRequest(ctx, c.GetString("arborist_endpoint"))
req, err := PrepareNewArboristRequest(ctx)
if err != nil {
ctx.AbortWithStatus(500)
log.Printf("Error while preparing Arborist request: %s", err.Error())
return
}
client := &http.Client{}
// send the request to Arborist:
Expand All @@ -44,21 +45,33 @@ func AuthMiddleware() gin.HandlerFunc {
}

// this function will take the request from the given ctx, validated it for the presence of an "Authorization / Bearer" token
// and then return the URL that can be used to consult Arborist regarding access permissions. This function
// and then return the URL that can be used to consult Arborist regarding cohort-middleware access permissions. This function
// returns an error if "Authorization / Bearer" token is missing in ctx
func PrepareNewArboristRequest(ctx *gin.Context, arboristEndpoint string) (*http.Request, error) {
func PrepareNewArboristRequest(ctx *gin.Context) (*http.Request, error) {

resourcePath := fmt.Sprintf("/cohort-middleware%s", ctx.Request.URL.Path)
service := "cohort-middleware"

return PrepareNewArboristRequestForResourceAndService(ctx, resourcePath, service)
}

// this function will take the request from the given ctx, validated it for the presence of an "Authorization / Bearer" token
// and then return the URL that can be used to consult Arborist regarding access permissions for the given
// resource path and service.
func PrepareNewArboristRequestForResourceAndService(ctx *gin.Context, resourcePath string, service string) (*http.Request, error) {
c := config.GetConfig()
arboristEndpoint := c.GetString("arborist_endpoint")
// validate:
authorization := ctx.Request.Header.Get("Authorization")
if authorization == "" {
return nil, errors.New("missing Authorization header")
}

// build up the request URL string:
resourcePath := fmt.Sprintf("/cohort-middleware%s", ctx.Request.URL.Path)
arboristAuth := fmt.Sprintf("%s/auth/proxy?resource=%s&service=%s&method=%s",
arboristEndpoint,
resourcePath,
"cohort-middleware",
service,
"access")

// make request object / validate URL:
Expand Down
82 changes: 82 additions & 0 deletions middlewares/teamprojectauthz.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
package middlewares

import (
"log"
"net/http"

"github.com/gin-gonic/gin"
"github.com/uc-cdis/cohort-middleware/models"
"github.com/uc-cdis/cohort-middleware/utils"
)

type TeamProjectAuthzI interface {
TeamProjectValidationForCohort(ctx *gin.Context, cohortDefinitionId int) bool
TeamProjectValidation(ctx *gin.Context, cohortDefinitionId int, filterCohortPairs []utils.CustomDichotomousVariableDef) bool
}

type HttpClientI interface {
Do(req *http.Request) (*http.Response, error)
}

type TeamProjectAuthz struct {
cohortDefinitionModel models.CohortDefinitionI
httpClient HttpClientI
}

func NewTeamProjectAuthz(cohortDefinitionModel models.CohortDefinitionI, httpClient HttpClientI) TeamProjectAuthz {
return TeamProjectAuthz{
cohortDefinitionModel: cohortDefinitionModel,
httpClient: httpClient,
}
}
func (u TeamProjectAuthz) hasAccessToAtLeastOne(ctx *gin.Context, teamProjects []string) bool {

// query Arborist and return as soon as one of the teamProjects access check returns 200:
for _, teamProject := range teamProjects {
teamProjectAsResourcePath := teamProject
teamProjectAccessService := "atlas-argo-wrapper-and-cohort-middleware"

req, err := PrepareNewArboristRequestForResourceAndService(ctx, teamProjectAsResourcePath, teamProjectAccessService)
if err != nil {
ctx.AbortWithStatus(500)
panic("Error while preparing Arborist request")
}
// send the request to Arborist:
resp, _ := u.httpClient.Do(req)
log.Printf("Got response status %d from Arborist...", resp.StatusCode)

// arborist will return with 200 if the user has been granted access to the cohort-middleware URL in ctx:
if resp.StatusCode == 200 {
return true
} else {
// unauthorized or otherwise:
log.Printf("Status %d does NOT give access to team project...", resp.StatusCode)
}
}
return false
}

func (u TeamProjectAuthz) TeamProjectValidationForCohort(ctx *gin.Context, cohortDefinitionId int) bool {
filterCohortPairs := []utils.CustomDichotomousVariableDef{}
return u.TeamProjectValidation(ctx, cohortDefinitionId, filterCohortPairs)
}

// "team project" related checks:
// (1) check if the request contains any cohorts and if all cohorts belong to the same "team project"
// (2) check if the user has permission in the "team project"
// Returns true if both checks above pass, false otherwise.
func (u TeamProjectAuthz) TeamProjectValidation(ctx *gin.Context, cohortDefinitionId int, filterCohortPairs []utils.CustomDichotomousVariableDef) bool {

uniqueCohortDefinitionIdsList := utils.GetUniqueCohortDefinitionIdsListFromRequest(cohortDefinitionId, filterCohortPairs)
teamProjects, _ := u.cohortDefinitionModel.GetTeamProjectsThatMatchAllCohortDefinitionIds(uniqueCohortDefinitionIdsList)
if len(teamProjects) == 0 {
log.Printf("Invalid request error: could not find a 'team project' that is associated to ALL the cohorts present in this request")
return false
}
if !u.hasAccessToAtLeastOne(ctx, teamProjects) {
log.Printf("Invalid request error: user does not have access to any of the 'team projects' associated with the cohorts in this request")
return false
}
// passed both tests:
return true
}
19 changes: 19 additions & 0 deletions models/cohortdefinition.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@ type CohortDefinitionI interface {
GetAllCohortDefinitions() ([]*CohortDefinition, error)
GetAllCohortDefinitionsAndStatsOrderBySizeDesc(sourceId int, teamProject string) ([]*CohortDefinitionStats, error)
GetCohortName(cohortId int) (string, error)
GetCohortDefinitionIdsForTeamProject(teamProject string) ([]int, error)
GetTeamProjectsThatMatchAllCohortDefinitionIds(uniqueCohortDefinitionIdsList []int) ([]string, error)
}

type CohortDefinition struct {
Expand Down Expand Up @@ -67,6 +69,23 @@ func (h CohortDefinition) GetAllCohortDefinitions() ([]*CohortDefinition, error)
return cohortDefinition, meta_result.Error
}

// Returns any "team project" entries that are matched to _each and every one_ of the
// cohort definition ids found in uniqueCohortDefinitionIdsList.
func (h CohortDefinition) GetTeamProjectsThatMatchAllCohortDefinitionIds(uniqueCohortDefinitionIdsList []int) ([]string, error) {

db2 := db.GetAtlasDB().Db
var teamProjects []string
// Find any roles that are paired to each and every one of the cohort_definition_id values.
// Roles that ony match part of the values are filtered out by the having(count) clause:
query := db2.Table(db.GetAtlasDB().Schema+".cohort_definition_sec_role").
Select("sec_role_name").
Where("cohort_definition_id in (?)", uniqueCohortDefinitionIdsList).
Group("sec_role_name").
Having("count(DISTINCT cohort_definition_id) = ?", len(uniqueCohortDefinitionIdsList)).
Scan(&teamProjects)
return teamProjects, query.Error
}

// Get the list of cohort_definition ids for a given "team project" (where "team project" is basically
// a security role name of one of the roles in Atlas/WebAPI database).
func (h CohortDefinition) GetCohortDefinitionIdsForTeamProject(teamProject string) ([]int, error) {
Expand Down
3 changes: 1 addition & 2 deletions models/concept.go
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,7 @@ func (h Concept) RetrieveBreakdownStatsBySourceIdAndCohortId(sourceId int, cohor
// {ConceptValue: "B", NPersonsInCohortWithValue: N-M-X},
// where X is the number of persons that have NO value or just a "null" value for one or more of the ids in the given filterConceptIds.
func (h Concept) RetrieveBreakdownStatsBySourceIdAndCohortIdAndConceptIdsAndCohortPairs(sourceId int, cohortDefinitionId int, filterConceptIds []int64, filterCohortPairs []utils.CustomDichotomousVariableDef, breakdownConceptId int64) ([]*ConceptBreakdown, error) {

var dataSourceModel = new(Source)
omopDataSource := dataSourceModel.GetDataSource(sourceId, Omop)
resultsDataSource := dataSourceModel.GetDataSource(sourceId, Results)
Expand All @@ -150,8 +151,6 @@ func (h Concept) RetrieveBreakdownStatsBySourceIdAndCohortIdAndConceptIdsAndCoho
Where("observation.observation_concept_id = ?", breakdownConceptId).
Where(GetConceptValueNotNullCheckBasedOnConceptType("observation", sourceId, breakdownConceptId))

// note: here we pass empty []utils.CustomDichotomousVariableDef{} instead of filterCohortPairs, since we already use the SQL generated by QueryFilterByCohortPairsHelper above,
// which is a better performing SQL in this particular scenario:
query = QueryFilterByConceptIdsHelper(query, sourceId, filterConceptIds, omopDataSource, resultsDataSource.Schema, "observation")

query, cancel := utils.AddTimeoutToQuery(query)
Expand Down
4 changes: 2 additions & 2 deletions models/helper.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ func QueryFilterByCohortPairsHelper(filterCohortPairs []utils.CustomDichotomousV
"UNION " +
"SELECT subject_id FROM " + resultsDataSource.Schema + ".cohort WHERE cohort_definition_id=? " +
")"
idsList = append(idsList, filterCohortPair.CohortId1, filterCohortPair.CohortId2)
idsList = append(idsList, filterCohortPair.CohortDefinitionId1, filterCohortPair.CohortDefinitionId2)
}
// EXCEPTs section:
for _, filterCohortPair := range filterCohortPairs {
Expand All @@ -54,7 +54,7 @@ func QueryFilterByCohortPairsHelper(filterCohortPairs []utils.CustomDichotomousV
"INTERSECT " +
"SELECT subject_id FROM " + resultsDataSource.Schema + ".cohort WHERE cohort_definition_id=? " +
")"
idsList = append(idsList, filterCohortPair.CohortId1, filterCohortPair.CohortId2)
idsList = append(idsList, filterCohortPair.CohortDefinitionId1, filterCohortPair.CohortDefinitionId2)
}
}
unionAndIntersectSQL = unionAndIntersectSQL +
Expand Down
5 changes: 4 additions & 1 deletion server/router.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package server

import (
"net/http"

"github.com/gin-gonic/gin"
"github.com/uc-cdis/cohort-middleware/controllers"
"github.com/uc-cdis/cohort-middleware/middlewares"
Expand Down Expand Up @@ -33,7 +35,8 @@ func NewRouter() *gin.Engine {
authorized.GET("/cohortdefinition-stats/by-source-id/:sourceid/by-team-project/:teamproject", cohortdefinitions.RetriveStatsBySourceIdAndTeamProject)

// concept endpoints:
concepts := controllers.NewConceptController(*new(models.Concept), *new(models.CohortDefinition))
concepts := controllers.NewConceptController(*new(models.Concept), *new(models.CohortDefinition),
middlewares.NewTeamProjectAuthz(*new(models.CohortDefinition), &http.Client{}))
authorized.GET("/concept/by-source-id/:sourceid", concepts.RetriveAllBySourceId)
authorized.POST("/concept/by-source-id/:sourceid", concepts.RetrieveInfoBySourceIdAndConceptIds)
authorized.POST("/concept/by-source-id/:sourceid/by-type", concepts.RetrieveInfoBySourceIdAndConceptTypes)
Expand Down
Loading
Loading