Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feat: integrate extra 'team project' validation for concept endpoints #82

1 change: 1 addition & 0 deletions config/mocktest.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
arborist_endpoint: 'https://arboristdummyurl'
20 changes: 15 additions & 5 deletions controllers/concept.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,17 +10,23 @@ import (
"strconv"

"github.com/gin-gonic/gin"
"github.com/uc-cdis/cohort-middleware/middlewares"
"github.com/uc-cdis/cohort-middleware/models"
"github.com/uc-cdis/cohort-middleware/utils"
)

type ConceptController struct {
conceptModel models.ConceptI
cohortDefinitionModel models.CohortDefinitionI
teamProjectAuthz middlewares.TeamProjectAuthzI
}

func NewConceptController(conceptModel models.ConceptI, cohortDefinitionModel models.CohortDefinitionI) ConceptController {
return ConceptController{conceptModel: conceptModel, cohortDefinitionModel: cohortDefinitionModel}
func NewConceptController(conceptModel models.ConceptI, cohortDefinitionModel models.CohortDefinitionI, teamProjectAuthz middlewares.TeamProjectAuthzI) ConceptController {
return ConceptController{
conceptModel: conceptModel,
cohortDefinitionModel: cohortDefinitionModel,
teamProjectAuthz: teamProjectAuthz,
}
}

func (u ConceptController) RetriveAllBySourceId(c *gin.Context) {
Expand Down Expand Up @@ -87,7 +93,8 @@ func (u ConceptController) RetrieveInfoBySourceIdAndConceptTypes(c *gin.Context)

func (u ConceptController) RetrieveBreakdownStatsBySourceIdAndCohortId(c *gin.Context) {
sourceId, cohortId, err := utils.ParseSourceAndCohortId(c)
if err != nil {
validRequest := u.teamProjectAuthz.TeamProjectValidationForCohort(c, cohortId)
if err != nil || !validRequest {
log.Printf("Error: %s", err.Error())
c.JSON(http.StatusBadRequest, gin.H{"message": "bad request", "error": err.Error()})
c.Abort()
Expand All @@ -112,7 +119,8 @@ func (u ConceptController) RetrieveBreakdownStatsBySourceIdAndCohortId(c *gin.Co

func (u ConceptController) RetrieveBreakdownStatsBySourceIdAndCohortIdAndVariables(c *gin.Context) {
sourceId, cohortId, conceptIds, cohortPairs, err := utils.ParseSourceIdAndCohortIdAndVariablesList(c)
if err != nil {
validRequest := u.teamProjectAuthz.TeamProjectValidation(c, cohortId, cohortPairs)
if err != nil || !validRequest {
pieterlukasse marked this conversation as resolved.
Show resolved Hide resolved
log.Printf("Error: %s", err.Error())
c.JSON(http.StatusBadRequest, gin.H{"message": "bad request", "error": err.Error()})
c.Abort()
Expand Down Expand Up @@ -169,7 +177,9 @@ func generateRowForVariable(variableName string, breakdownConceptValuesToPeopleC

func (u ConceptController) RetrieveAttritionTable(c *gin.Context) {
sourceId, cohortId, conceptIdsAndCohortPairs, err := utils.ParseSourceIdAndCohortIdAndVariablesAsSingleList(c)
if err != nil {
_, cohortPairs := utils.GetConceptIdsAndCohortPairsAsSeparateLists(conceptIdsAndCohortPairs)
validRequest := u.teamProjectAuthz.TeamProjectValidation(c, cohortId, cohortPairs)
if err != nil || !validRequest {
log.Printf("Error: %s", err.Error())
c.JSON(http.StatusBadRequest, gin.H{"message": "bad request", "error": err.Error()})
c.Abort()
Expand Down
23 changes: 18 additions & 5 deletions middlewares/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,11 @@ func AuthMiddleware() gin.HandlerFunc {
}

return func(ctx *gin.Context) {
req, err := PrepareNewArboristRequest(ctx, c.GetString("arborist_endpoint"))
req, err := PrepareNewArboristRequest(ctx)
if err != nil {
ctx.AbortWithStatus(500)
log.Printf("Error while preparing Arborist request: %s", err.Error())
return
}
client := &http.Client{}
// send the request to Arborist:
Expand All @@ -44,21 +45,33 @@ func AuthMiddleware() gin.HandlerFunc {
}

// this function will take the request from the given ctx, validated it for the presence of an "Authorization / Bearer" token
// and then return the URL that can be used to consult Arborist regarding access permissions. This function
// and then return the URL that can be used to consult Arborist regarding cohort-middleware access permissions. This function
// returns an error if "Authorization / Bearer" token is missing in ctx
func PrepareNewArboristRequest(ctx *gin.Context, arboristEndpoint string) (*http.Request, error) {
func PrepareNewArboristRequest(ctx *gin.Context) (*http.Request, error) {

resourcePath := fmt.Sprintf("/cohort-middleware%s", ctx.Request.URL.Path)
service := "cohort-middleware"

return PrepareNewArboristRequestForResourceAndService(ctx, resourcePath, service)
}

// this function will take the request from the given ctx, validated it for the presence of an "Authorization / Bearer" token
// and then return the URL that can be used to consult Arborist regarding access permissions for the given
// resource path and service.
func PrepareNewArboristRequestForResourceAndService(ctx *gin.Context, resourcePath string, service string) (*http.Request, error) {
c := config.GetConfig()
arboristEndpoint := c.GetString("arborist_endpoint")
// validate:
authorization := ctx.Request.Header.Get("Authorization")
if authorization == "" {
return nil, errors.New("missing Authorization header")
}

// build up the request URL string:
resourcePath := fmt.Sprintf("/cohort-middleware%s", ctx.Request.URL.Path)
arboristAuth := fmt.Sprintf("%s/auth/proxy?resource=%s&service=%s&method=%s",
arboristEndpoint,
resourcePath,
"cohort-middleware",
service,
"access")

// make request object / validate URL:
Expand Down
82 changes: 82 additions & 0 deletions middlewares/teamprojectauthz.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
package middlewares

import (
"log"
"net/http"

"github.com/gin-gonic/gin"
"github.com/uc-cdis/cohort-middleware/models"
"github.com/uc-cdis/cohort-middleware/utils"
)

type TeamProjectAuthzI interface {
TeamProjectValidationForCohort(ctx *gin.Context, cohortDefinitionId int) bool
TeamProjectValidation(ctx *gin.Context, cohortDefinitionId int, filterCohortPairs []utils.CustomDichotomousVariableDef) bool
}

type HttpClientI interface {
Do(req *http.Request) (*http.Response, error)
}

type TeamProjectAuthz struct {
cohortDefinitionModel models.CohortDefinitionI
httpClient HttpClientI
}

func NewTeamProjectAuthz(cohortDefinitionModel models.CohortDefinitionI, httpClient HttpClientI) TeamProjectAuthz {
return TeamProjectAuthz{
cohortDefinitionModel: cohortDefinitionModel,
httpClient: httpClient,
}
}
func (u TeamProjectAuthz) HasAccessToAtLeastOne(ctx *gin.Context, teamProjects []string) bool {

// query Arborist and return as soon as one of the teamProjects access check returns 200:
for _, teamProject := range teamProjects {
teamProjectAsResourcePath := teamProject
teamProjectAccessService := "atlas-argo-wrapper-and-cohort-middleware"

req, err := PrepareNewArboristRequestForResourceAndService(ctx, teamProjectAsResourcePath, teamProjectAccessService)
if err != nil {
ctx.AbortWithStatus(500)
panic("Error while preparing Arborist request")
}
// send the request to Arborist:
resp, _ := u.httpClient.Do(req)
log.Printf("Got response status %d from Arborist...", resp.StatusCode)

// arborist will return with 200 if the user has been granted access to the cohort-middleware URL in ctx:
if resp.StatusCode == 200 {
return true
} else {
// unauthorized or otherwise:
log.Printf("Status %d does NOT give access to team project...", resp.StatusCode)
}
}
return false
}

func (u TeamProjectAuthz) TeamProjectValidationForCohort(ctx *gin.Context, cohortDefinitionId int) bool {
filterCohortPairs := []utils.CustomDichotomousVariableDef{}
return u.TeamProjectValidation(ctx, cohortDefinitionId, filterCohortPairs)
}

// "team project" related checks:
// (1) check if the request contains any cohorts and if all cohorts belong to the same "team project"
// (2) check if the user has permission in the "team project"
// Returns true if both checks above pass, false otherwise.
func (u TeamProjectAuthz) TeamProjectValidation(ctx *gin.Context, cohortDefinitionId int, filterCohortPairs []utils.CustomDichotomousVariableDef) bool {

uniqueCohortDefinitionIdsList := utils.GetUniqueCohortDefinitionIdsListFromRequest(cohortDefinitionId, filterCohortPairs)
teamProjects, _ := u.cohortDefinitionModel.GetTeamProjectsThatMatchAllCohortDefinitionIds(uniqueCohortDefinitionIdsList)
if len(teamProjects) == 0 {
log.Printf("Invalid request error: could not find a 'team project' that is associated to ALL the cohorts present in this request")
return false
}
if !u.HasAccessToAtLeastOne(ctx, teamProjects) {
log.Printf("Invalid request error: user does not have access to any of the 'team projects' associated with the cohorts in this request")
return false
}
// passed both tests:
return true
}
19 changes: 19 additions & 0 deletions models/cohortdefinition.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@ type CohortDefinitionI interface {
GetAllCohortDefinitions() ([]*CohortDefinition, error)
GetAllCohortDefinitionsAndStatsOrderBySizeDesc(sourceId int, teamProject string) ([]*CohortDefinitionStats, error)
GetCohortName(cohortId int) (string, error)
GetCohortDefinitionIdsForTeamProject(teamProject string) ([]int, error)
GetTeamProjectsThatMatchAllCohortDefinitionIds(uniqueCohortDefinitionIdsList []int) ([]string, error)
}

type CohortDefinition struct {
Expand Down Expand Up @@ -67,6 +69,23 @@ func (h CohortDefinition) GetAllCohortDefinitions() ([]*CohortDefinition, error)
return cohortDefinition, meta_result.Error
}

// Returns any "team project" entries that are matched to _each and every one_ of the
// cohort definition ids found in uniqueCohortDefinitionIdsList.
func (h CohortDefinition) GetTeamProjectsThatMatchAllCohortDefinitionIds(uniqueCohortDefinitionIdsList []int) ([]string, error) {

db2 := db.GetAtlasDB().Db
var teamProjects []string
// Find any roles that are paired to each and every one of the cohort_definition_id values.
// Roles that ony match part of the values are filtered out by the having(count) clause:
query := db2.Table(db.GetAtlasDB().Schema+".cohort_definition_sec_role").
Select("sec_role_name").
Where("cohort_definition_id in (?)", uniqueCohortDefinitionIdsList).
Group("sec_role_name").
Having("count(DISTINCT cohort_definition_id) = ?", len(uniqueCohortDefinitionIdsList)).
Scan(&teamProjects)
return teamProjects, query.Error
}

// Get the list of cohort_definition ids for a given "team project" (where "team project" is basically
// a security role name of one of the roles in Atlas/WebAPI database).
func (h CohortDefinition) GetCohortDefinitionIdsForTeamProject(teamProject string) ([]int, error) {
Expand Down
3 changes: 1 addition & 2 deletions models/concept.go
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,7 @@ func (h Concept) RetrieveBreakdownStatsBySourceIdAndCohortId(sourceId int, cohor
// {ConceptValue: "B", NPersonsInCohortWithValue: N-M-X},
// where X is the number of persons that have NO value or just a "null" value for one or more of the ids in the given filterConceptIds.
func (h Concept) RetrieveBreakdownStatsBySourceIdAndCohortIdAndConceptIdsAndCohortPairs(sourceId int, cohortDefinitionId int, filterConceptIds []int64, filterCohortPairs []utils.CustomDichotomousVariableDef, breakdownConceptId int64) ([]*ConceptBreakdown, error) {

var dataSourceModel = new(Source)
omopDataSource := dataSourceModel.GetDataSource(sourceId, Omop)
resultsDataSource := dataSourceModel.GetDataSource(sourceId, Results)
Expand All @@ -150,8 +151,6 @@ func (h Concept) RetrieveBreakdownStatsBySourceIdAndCohortIdAndConceptIdsAndCoho
Where("observation.observation_concept_id = ?", breakdownConceptId).
Where(GetConceptValueNotNullCheckBasedOnConceptType("observation", sourceId, breakdownConceptId))

// note: here we pass empty []utils.CustomDichotomousVariableDef{} instead of filterCohortPairs, since we already use the SQL generated by QueryFilterByCohortPairsHelper above,
// which is a better performing SQL in this particular scenario:
query = QueryFilterByConceptIdsHelper(query, sourceId, filterConceptIds, omopDataSource, resultsDataSource.Schema, "observation")

query, cancel := utils.AddTimeoutToQuery(query)
Expand Down
5 changes: 4 additions & 1 deletion server/router.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package server

import (
"net/http"

"github.com/gin-gonic/gin"
"github.com/uc-cdis/cohort-middleware/controllers"
"github.com/uc-cdis/cohort-middleware/middlewares"
Expand Down Expand Up @@ -33,7 +35,8 @@ func NewRouter() *gin.Engine {
authorized.GET("/cohortdefinition-stats/by-source-id/:sourceid/by-team-project/:teamproject", cohortdefinitions.RetriveStatsBySourceIdAndTeamProject)

// concept endpoints:
concepts := controllers.NewConceptController(*new(models.Concept), *new(models.CohortDefinition))
concepts := controllers.NewConceptController(*new(models.Concept), *new(models.CohortDefinition),
middlewares.NewTeamProjectAuthz(*new(models.CohortDefinition), &http.Client{}))
authorized.GET("/concept/by-source-id/:sourceid", concepts.RetriveAllBySourceId)
authorized.POST("/concept/by-source-id/:sourceid", concepts.RetrieveInfoBySourceIdAndConceptIds)
authorized.POST("/concept/by-source-id/:sourceid/by-type", concepts.RetrieveInfoBySourceIdAndConceptTypes)
Expand Down
20 changes: 19 additions & 1 deletion tests/controllers_tests/controllers_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,14 @@ type dummyCohortDefinitionDataModel struct{}

var dummyModelReturnError bool = false

func (h dummyCohortDefinitionDataModel) GetCohortDefinitionIdsForTeamProject(teamProject string) ([]int, error) {
return []int{1}, nil
}

func (h dummyCohortDefinitionDataModel) GetTeamProjectsThatMatchAllCohortDefinitionIds(uniqueCohortDefinitionIdsList []int) ([]string, error) {
return []string{"test"}, nil
}

func (h dummyCohortDefinitionDataModel) GetCohortName(cohortId int) (string, error) {
return "dummy cohort name", nil
}
Expand Down Expand Up @@ -127,7 +135,17 @@ func (h dummyCohortDefinitionDataModel) GetAllCohortDefinitions() ([]*models.Coh
return nil, nil
}

var conceptController = controllers.NewConceptController(*new(dummyConceptDataModel), *new(dummyCohortDefinitionDataModel))
type dummyTeamProjectAuthz struct{}

func (h dummyTeamProjectAuthz) TeamProjectValidationForCohort(ctx *gin.Context, cohortDefinitionId int) bool {
return true
}

func (h dummyTeamProjectAuthz) TeamProjectValidation(ctx *gin.Context, cohortDefinitionId int, filterCohortPairs []utils.CustomDichotomousVariableDef) bool {
return true
pieterlukasse marked this conversation as resolved.
Show resolved Hide resolved
}

var conceptController = controllers.NewConceptController(*new(dummyConceptDataModel), *new(dummyCohortDefinitionDataModel), *new(dummyTeamProjectAuthz))

type dummyConceptDataModel struct{}

Expand Down
Loading
Loading