Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feat: integrate Arborist validation for team project for cohort data endpoints AND remove unused endpoints #83

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ cd tests/setup_local_db/
JSON summary data endpoints:
```bash
curl http://localhost:8080/sources | python -m json.tool
curl http://localhost:8080/cohortdefinition-stats/by-source-id/1 | python -m json.tool
curl "http://localhost:8080/cohortdefinition-stats/by-source-id/1/by-team-project?team-project=test" | python -m json.tool
curl http://localhost:8080/concept/by-source-id/1 | python -m json.tool
curl -d '{"ConceptIds":[2000000324,2000006885]}' -H "Content-Type: application/json" -X POST http://localhost:8080/concept/by-source-id/1 | python -m json.tool
curl -d '{"ConceptTypes":["Measurement","Person"]}' -H "Content-Type: application/json" -X POST http://localhost:8080/concept/by-source-id/1/by-type | python -m json.tool
Expand Down
35 changes: 32 additions & 3 deletions controllers/cohortdata.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,16 +9,21 @@ import (
"strconv"

"github.com/gin-gonic/gin"
"github.com/uc-cdis/cohort-middleware/middlewares"
"github.com/uc-cdis/cohort-middleware/models"
"github.com/uc-cdis/cohort-middleware/utils"
)

type CohortDataController struct {
cohortDataModel models.CohortDataI
cohortDataModel models.CohortDataI
teamProjectAuthz middlewares.TeamProjectAuthzI
}

func NewCohortDataController(cohortDataModel models.CohortDataI) CohortDataController {
return CohortDataController{cohortDataModel: cohortDataModel}
func NewCohortDataController(cohortDataModel models.CohortDataI, teamProjectAuthz middlewares.TeamProjectAuthzI) CohortDataController {
return CohortDataController{
cohortDataModel: cohortDataModel,
teamProjectAuthz: teamProjectAuthz,
}
}

func (u CohortDataController) RetrieveHistogramForCohortIdAndConceptId(c *gin.Context) {
Expand All @@ -44,6 +49,14 @@ func (u CohortDataController) RetrieveHistogramForCohortIdAndConceptId(c *gin.Co
cohortId, _ := strconv.Atoi(cohortIdStr)
histogramConceptId, _ := strconv.ParseInt(histogramIdStr, 10, 64)

validAccessRequest := u.teamProjectAuthz.TeamProjectValidation(c, []int{cohortId}, cohortPairs)
if !validAccessRequest {
log.Printf("Error: invalid request")
c.JSON(http.StatusBadRequest, gin.H{"message": "access denied"})
c.Abort()
return
}

cohortData, err := u.cohortDataModel.RetrieveHistogramDataBySourceIdAndCohortIdAndConceptIdsAndCohortPairs(sourceId, cohortId, histogramConceptId, filterConceptIds, cohortPairs)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"message": "Error retrieving concept details", "error": err.Error()})
Expand Down Expand Up @@ -85,6 +98,14 @@ func (u CohortDataController) RetrieveDataBySourceIdAndCohortIdAndVariables(c *g
sourceId, _ := strconv.Atoi(sourceIdStr)
cohortId, _ := strconv.Atoi(cohortIdStr)

validAccessRequest := u.teamProjectAuthz.TeamProjectValidation(c, []int{cohortId}, cohortPairs)
if !validAccessRequest {
log.Printf("Error: invalid request")
c.JSON(http.StatusBadRequest, gin.H{"message": "access denied"})
c.Abort()
return
}

// call model method:
cohortData, err := u.cohortDataModel.RetrieveDataBySourceIdAndCohortIdAndConceptIdsOrderedByPersonId(sourceId, cohortId, conceptIds)
if err != nil {
Expand Down Expand Up @@ -230,6 +251,14 @@ func (u CohortDataController) RetrieveCohortOverlapStatsWithoutFilteringOnConcep
controlCohortId, errors[2] = utils.ParseNumericArg(c, "controlcohortid")
conceptIds, cohortPairs, errors[3] = utils.ParseConceptIdsAndDichotomousDefs(c)

validAccessRequest := u.teamProjectAuthz.TeamProjectValidation(c, []int{caseCohortId, controlCohortId}, cohortPairs)
if !validAccessRequest {
log.Printf("Error: invalid request")
c.JSON(http.StatusBadRequest, gin.H{"message": "access denied"})
c.Abort()
return
}

if utils.ContainsNonNil(errors) {
c.JSON(http.StatusBadRequest, gin.H{"message": "bad request"})
c.Abort()
Expand Down
49 changes: 1 addition & 48 deletions controllers/cohortdefinition.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package controllers

import (
"net/http"
"strconv"

"github.com/gin-gonic/gin"
"github.com/uc-cdis/cohort-middleware/models"
Expand All @@ -17,56 +16,10 @@ func NewCohortDefinitionController(cohortDefinitionModel models.CohortDefinition
return CohortDefinitionController{cohortDefinitionModel: cohortDefinitionModel}
}

func (u CohortDefinitionController) RetriveById(c *gin.Context) {
cohortDefinitionId := c.Param("id")

if cohortDefinitionId != "" {
cohortDefinitionId, _ := strconv.Atoi(cohortDefinitionId)
cohortDefinition, err := u.cohortDefinitionModel.GetCohortDefinitionById(cohortDefinitionId)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"message": "Error retrieving cohortDefinition", "error": err.Error()})
c.Abort()
return
}
c.JSON(http.StatusOK, gin.H{"cohort_definition": cohortDefinition})
return
}
c.JSON(http.StatusBadRequest, gin.H{"message": "bad request"})
c.Abort()
}

func (u CohortDefinitionController) RetriveByName(c *gin.Context) {
cohortDefinitionName := c.Param("name")

if cohortDefinitionName != "" {
cohortDefinition, err := u.cohortDefinitionModel.GetCohortDefinitionByName(cohortDefinitionName)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"message": "Error retrieving cohortDefinition", "error": err.Error()})
c.Abort()
return
}
c.JSON(http.StatusOK, gin.H{"CohortDefinition": cohortDefinition})
return
}
c.JSON(http.StatusBadRequest, gin.H{"message": "bad request"})
c.Abort()
}

func (u CohortDefinitionController) RetriveAll(c *gin.Context) {
cohortDefinitions, err := u.cohortDefinitionModel.GetAllCohortDefinitions()
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"message": "Error retrieving cohortDefinition", "error": err.Error()})
c.Abort()
return
}
c.JSON(http.StatusOK, gin.H{"cohort_definitions": cohortDefinitions})
}

func (u CohortDefinitionController) RetriveStatsBySourceIdAndTeamProject(c *gin.Context) {
// This method returns ALL cohortdefinition entries with cohort size statistics (for a given source)

sourceId, err1 := utils.ParseNumericArg(c, "sourceid")
teamProject := c.Param("teamproject")
teamProject := c.Query("team-project")
if teamProject == "" {
c.JSON(http.StatusInternalServerError, gin.H{"message": "Error while parsing request", "error": "team-project is a mandatory parameter but was found to be empty!"})
c.Abort()
Expand Down
4 changes: 2 additions & 2 deletions controllers/concept.go
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ func (u ConceptController) RetrieveBreakdownStatsBySourceIdAndCohortIdAndVariabl
c.Abort()
return
}
validAccessRequest := u.teamProjectAuthz.TeamProjectValidation(c, cohortId, cohortPairs)
validAccessRequest := u.teamProjectAuthz.TeamProjectValidation(c, []int{cohortId}, cohortPairs)
if !validAccessRequest {
log.Printf("Error: invalid request")
c.JSON(http.StatusBadRequest, gin.H{"message": "access denied"})
Expand Down Expand Up @@ -198,7 +198,7 @@ func (u ConceptController) RetrieveAttritionTable(c *gin.Context) {
return
}
_, cohortPairs := utils.GetConceptIdsAndCohortPairsAsSeparateLists(conceptIdsAndCohortPairs)
validAccessRequest := u.teamProjectAuthz.TeamProjectValidation(c, cohortId, cohortPairs)
validAccessRequest := u.teamProjectAuthz.TeamProjectValidation(c, []int{cohortId}, cohortPairs)
if !validAccessRequest {
log.Printf("Error: invalid request")
c.JSON(http.StatusBadRequest, gin.H{"message": "access denied"})
Expand Down
17 changes: 11 additions & 6 deletions middlewares/teamprojectauthz.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@ import (

type TeamProjectAuthzI interface {
TeamProjectValidationForCohort(ctx *gin.Context, cohortDefinitionId int) bool
TeamProjectValidation(ctx *gin.Context, cohortDefinitionId int, filterCohortPairs []utils.CustomDichotomousVariableDef) bool
TeamProjectValidation(ctx *gin.Context, cohortDefinitionIds []int, filterCohortPairs []utils.CustomDichotomousVariableDef) bool
TeamProjectValidationForCohortIdsList(ctx *gin.Context, uniqueCohortDefinitionIdsList []int) bool
}

type HttpClientI interface {
Expand Down Expand Up @@ -58,16 +59,20 @@ func (u TeamProjectAuthz) hasAccessToAtLeastOne(ctx *gin.Context, teamProjects [

func (u TeamProjectAuthz) TeamProjectValidationForCohort(ctx *gin.Context, cohortDefinitionId int) bool {
filterCohortPairs := []utils.CustomDichotomousVariableDef{}
return u.TeamProjectValidation(ctx, cohortDefinitionId, filterCohortPairs)
return u.TeamProjectValidation(ctx, []int{cohortDefinitionId}, filterCohortPairs)
}

func (u TeamProjectAuthz) TeamProjectValidation(ctx *gin.Context, cohortDefinitionIds []int, filterCohortPairs []utils.CustomDichotomousVariableDef) bool {

uniqueCohortDefinitionIdsList := utils.GetUniqueCohortDefinitionIdsListFromRequest(cohortDefinitionIds, filterCohortPairs)
return u.TeamProjectValidationForCohortIdsList(ctx, uniqueCohortDefinitionIdsList)
}

// "team project" related checks:
// (1) check if the request contains any cohorts and if all cohorts belong to the same "team project"
// (1) check if all cohorts belong to the same "team project"
// (2) check if the user has permission in the "team project"
// Returns true if both checks above pass, false otherwise.
func (u TeamProjectAuthz) TeamProjectValidation(ctx *gin.Context, cohortDefinitionId int, filterCohortPairs []utils.CustomDichotomousVariableDef) bool {

uniqueCohortDefinitionIdsList := utils.GetUniqueCohortDefinitionIdsListFromRequest(cohortDefinitionId, filterCohortPairs)
func (u TeamProjectAuthz) TeamProjectValidationForCohortIdsList(ctx *gin.Context, uniqueCohortDefinitionIdsList []int) bool {
teamProjects, _ := u.cohortDefinitionModel.GetTeamProjectsThatMatchAllCohortDefinitionIds(uniqueCohortDefinitionIdsList)
if len(teamProjects) == 0 {
log.Printf("Invalid request error: could not find a 'team project' that is associated to ALL the cohorts present in this request")
Expand Down
7 changes: 2 additions & 5 deletions server/router.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,7 @@ func NewRouter() *gin.Engine {
authorized.GET("/sources", source.RetriveAll)

cohortdefinitions := controllers.NewCohortDefinitionController(*new(models.CohortDefinition))
authorized.GET("/cohortdefinition/by-id/:id", cohortdefinitions.RetriveById)
authorized.GET("/cohortdefinition/by-name/:name", cohortdefinitions.RetriveByName)
authorized.GET("/cohortdefinitions", cohortdefinitions.RetriveAll)
authorized.GET("/cohortdefinition-stats/by-source-id/:sourceid/by-team-project/:teamproject", cohortdefinitions.RetriveStatsBySourceIdAndTeamProject)
authorized.GET("/cohortdefinition-stats/by-source-id/:sourceid/by-team-project", cohortdefinitions.RetriveStatsBySourceIdAndTeamProject)

// concept endpoints:
concepts := controllers.NewConceptController(*new(models.Concept), *new(models.CohortDefinition),
Expand All @@ -46,7 +43,7 @@ func NewRouter() *gin.Engine {
authorized.POST("/concept-stats/by-source-id/:sourceid/by-cohort-definition-id/:cohortid/breakdown-by-concept-id/:breakdownconceptid/csv", concepts.RetrieveAttritionTable)

// cohort stats and checks:
cohortData := controllers.NewCohortDataController(*new(models.CohortData))
cohortData := controllers.NewCohortDataController(*new(models.CohortData), middlewares.NewTeamProjectAuthz(*new(models.CohortDefinition), &http.Client{}))
// :casecohortid/:controlcohortid are just labels here and have no special meaning. Could also just be :cohortAId/:cohortBId here:
authorized.POST("/cohort-stats/check-overlap/by-source-id/:sourceid/by-cohort-definition-ids/:casecohortid/:controlcohortid", cohortData.RetrieveCohortOverlapStatsWithoutFilteringOnConceptValue)

Expand Down
Loading
Loading