Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feat: integrate Arborist validation for team project for cohort data endpoints AND remove unused endpoints #83

Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 35 additions & 3 deletions controllers/cohortdata.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,16 +9,21 @@ import (
"strconv"

"github.com/gin-gonic/gin"
"github.com/uc-cdis/cohort-middleware/middlewares"
"github.com/uc-cdis/cohort-middleware/models"
"github.com/uc-cdis/cohort-middleware/utils"
)

type CohortDataController struct {
cohortDataModel models.CohortDataI
cohortDataModel models.CohortDataI
teamProjectAuthz middlewares.TeamProjectAuthzI
}

func NewCohortDataController(cohortDataModel models.CohortDataI) CohortDataController {
return CohortDataController{cohortDataModel: cohortDataModel}
func NewCohortDataController(cohortDataModel models.CohortDataI, teamProjectAuthz middlewares.TeamProjectAuthzI) CohortDataController {
return CohortDataController{
cohortDataModel: cohortDataModel,
teamProjectAuthz: teamProjectAuthz,
}
}

func (u CohortDataController) RetrieveHistogramForCohortIdAndConceptId(c *gin.Context) {
Expand All @@ -44,6 +49,15 @@ func (u CohortDataController) RetrieveHistogramForCohortIdAndConceptId(c *gin.Co
cohortId, _ := strconv.Atoi(cohortIdStr)
histogramConceptId, _ := strconv.ParseInt(histogramIdStr, 10, 64)

uniqueCohortDefinitionIdsList := utils.GetUniqueCohortDefinitionIdsListFromRequest([]int{cohortId}, cohortPairs)
pieterlukasse marked this conversation as resolved.
Show resolved Hide resolved
validAccessRequest := u.teamProjectAuthz.TeamProjectValidationForCohortIdsList(c, uniqueCohortDefinitionIdsList)
if !validAccessRequest {
log.Printf("Error: invalid request")
c.JSON(http.StatusBadRequest, gin.H{"message": "access denied"})
c.Abort()
return
}

cohortData, err := u.cohortDataModel.RetrieveHistogramDataBySourceIdAndCohortIdAndConceptIdsAndCohortPairs(sourceId, cohortId, histogramConceptId, filterConceptIds, cohortPairs)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"message": "Error retrieving concept details", "error": err.Error()})
Expand Down Expand Up @@ -85,6 +99,15 @@ func (u CohortDataController) RetrieveDataBySourceIdAndCohortIdAndVariables(c *g
sourceId, _ := strconv.Atoi(sourceIdStr)
cohortId, _ := strconv.Atoi(cohortIdStr)

uniqueCohortDefinitionIdsList := utils.GetUniqueCohortDefinitionIdsListFromRequest([]int{cohortId}, cohortPairs)
validAccessRequest := u.teamProjectAuthz.TeamProjectValidationForCohortIdsList(c, uniqueCohortDefinitionIdsList)
if !validAccessRequest {
log.Printf("Error: invalid request")
c.JSON(http.StatusBadRequest, gin.H{"message": "access denied"})
c.Abort()
return
}

// call model method:
cohortData, err := u.cohortDataModel.RetrieveDataBySourceIdAndCohortIdAndConceptIdsOrderedByPersonId(sourceId, cohortId, conceptIds)
if err != nil {
Expand Down Expand Up @@ -230,6 +253,15 @@ func (u CohortDataController) RetrieveCohortOverlapStatsWithoutFilteringOnConcep
controlCohortId, errors[2] = utils.ParseNumericArg(c, "controlcohortid")
conceptIds, cohortPairs, errors[3] = utils.ParseConceptIdsAndDichotomousDefs(c)

uniqueCohortDefinitionIdsList := utils.GetUniqueCohortDefinitionIdsListFromRequest([]int{caseCohortId, controlCohortId}, cohortPairs)
validAccessRequest := u.teamProjectAuthz.TeamProjectValidationForCohortIdsList(c, uniqueCohortDefinitionIdsList)
if !validAccessRequest {
log.Printf("Error: invalid request")
c.JSON(http.StatusBadRequest, gin.H{"message": "access denied"})
c.Abort()
return
}

if utils.ContainsNonNil(errors) {
c.JSON(http.StatusBadRequest, gin.H{"message": "bad request"})
c.Abort()
Expand Down
46 changes: 0 additions & 46 deletions controllers/cohortdefinition.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package controllers

import (
"net/http"
"strconv"

"github.com/gin-gonic/gin"
"github.com/uc-cdis/cohort-middleware/models"
Expand All @@ -17,51 +16,6 @@ func NewCohortDefinitionController(cohortDefinitionModel models.CohortDefinition
return CohortDefinitionController{cohortDefinitionModel: cohortDefinitionModel}
}

func (u CohortDefinitionController) RetriveById(c *gin.Context) {
cohortDefinitionId := c.Param("id")

if cohortDefinitionId != "" {
cohortDefinitionId, _ := strconv.Atoi(cohortDefinitionId)
cohortDefinition, err := u.cohortDefinitionModel.GetCohortDefinitionById(cohortDefinitionId)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"message": "Error retrieving cohortDefinition", "error": err.Error()})
c.Abort()
return
}
c.JSON(http.StatusOK, gin.H{"cohort_definition": cohortDefinition})
return
}
c.JSON(http.StatusBadRequest, gin.H{"message": "bad request"})
c.Abort()
}

func (u CohortDefinitionController) RetriveByName(c *gin.Context) {
cohortDefinitionName := c.Param("name")

if cohortDefinitionName != "" {
cohortDefinition, err := u.cohortDefinitionModel.GetCohortDefinitionByName(cohortDefinitionName)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"message": "Error retrieving cohortDefinition", "error": err.Error()})
c.Abort()
return
}
c.JSON(http.StatusOK, gin.H{"CohortDefinition": cohortDefinition})
return
}
c.JSON(http.StatusBadRequest, gin.H{"message": "bad request"})
c.Abort()
}

func (u CohortDefinitionController) RetriveAll(c *gin.Context) {
cohortDefinitions, err := u.cohortDefinitionModel.GetAllCohortDefinitions()
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"message": "Error retrieving cohortDefinition", "error": err.Error()})
c.Abort()
return
}
c.JSON(http.StatusOK, gin.H{"cohort_definitions": cohortDefinitions})
}

func (u CohortDefinitionController) RetriveStatsBySourceIdAndTeamProject(c *gin.Context) {
// This method returns ALL cohortdefinition entries with cohort size statistics (for a given source)

Expand Down
13 changes: 9 additions & 4 deletions middlewares/teamprojectauthz.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (
type TeamProjectAuthzI interface {
TeamProjectValidationForCohort(ctx *gin.Context, cohortDefinitionId int) bool
TeamProjectValidation(ctx *gin.Context, cohortDefinitionId int, filterCohortPairs []utils.CustomDichotomousVariableDef) bool
TeamProjectValidationForCohortIdsList(ctx *gin.Context, uniqueCohortDefinitionIdsList []int) bool
}

type HttpClientI interface {
Expand Down Expand Up @@ -61,13 +62,17 @@ func (u TeamProjectAuthz) TeamProjectValidationForCohort(ctx *gin.Context, cohor
return u.TeamProjectValidation(ctx, cohortDefinitionId, filterCohortPairs)
}

func (u TeamProjectAuthz) TeamProjectValidation(ctx *gin.Context, cohortDefinitionId int, filterCohortPairs []utils.CustomDichotomousVariableDef) bool {

uniqueCohortDefinitionIdsList := utils.GetUniqueCohortDefinitionIdsListFromRequest([]int{cohortDefinitionId}, filterCohortPairs)
return u.TeamProjectValidationForCohortIdsList(ctx, uniqueCohortDefinitionIdsList)
}

// "team project" related checks:
// (1) check if the request contains any cohorts and if all cohorts belong to the same "team project"
// (1) check if all cohorts belong to the same "team project"
// (2) check if the user has permission in the "team project"
// Returns true if both checks above pass, false otherwise.
func (u TeamProjectAuthz) TeamProjectValidation(ctx *gin.Context, cohortDefinitionId int, filterCohortPairs []utils.CustomDichotomousVariableDef) bool {

uniqueCohortDefinitionIdsList := utils.GetUniqueCohortDefinitionIdsListFromRequest(cohortDefinitionId, filterCohortPairs)
func (u TeamProjectAuthz) TeamProjectValidationForCohortIdsList(ctx *gin.Context, uniqueCohortDefinitionIdsList []int) bool {
teamProjects, _ := u.cohortDefinitionModel.GetTeamProjectsThatMatchAllCohortDefinitionIds(uniqueCohortDefinitionIdsList)
if len(teamProjects) == 0 {
log.Printf("Invalid request error: could not find a 'team project' that is associated to ALL the cohorts present in this request")
Expand Down
5 changes: 1 addition & 4 deletions server/router.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,6 @@ func NewRouter() *gin.Engine {
authorized.GET("/sources", source.RetriveAll)

cohortdefinitions := controllers.NewCohortDefinitionController(*new(models.CohortDefinition))
authorized.GET("/cohortdefinition/by-id/:id", cohortdefinitions.RetriveById)
authorized.GET("/cohortdefinition/by-name/:name", cohortdefinitions.RetriveByName)
authorized.GET("/cohortdefinitions", cohortdefinitions.RetriveAll)
authorized.GET("/cohortdefinition-stats/by-source-id/:sourceid/by-team-project/:teamproject", cohortdefinitions.RetriveStatsBySourceIdAndTeamProject)

// concept endpoints:
Expand All @@ -46,7 +43,7 @@ func NewRouter() *gin.Engine {
authorized.POST("/concept-stats/by-source-id/:sourceid/by-cohort-definition-id/:cohortid/breakdown-by-concept-id/:breakdownconceptid/csv", concepts.RetrieveAttritionTable)

// cohort stats and checks:
cohortData := controllers.NewCohortDataController(*new(models.CohortData))
cohortData := controllers.NewCohortDataController(*new(models.CohortData), middlewares.NewTeamProjectAuthz(*new(models.CohortDefinition), &http.Client{}))
// :casecohortid/:controlcohortid are just labels here and have no special meaning. Could also just be :cohortAId/:cohortBId here:
authorized.POST("/cohort-stats/check-overlap/by-source-id/:sourceid/by-cohort-definition-ids/:casecohortid/:controlcohortid", cohortData.RetrieveCohortOverlapStatsWithoutFilteringOnConceptValue)

Expand Down
86 changes: 46 additions & 40 deletions tests/controllers_tests/controllers_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,8 @@ func tearDown() {
log.Println("teardown for test")
}

var cohortDataController = controllers.NewCohortDataController(*new(dummyCohortDataModel))
var cohortDataController = controllers.NewCohortDataController(*new(dummyCohortDataModel), *new(dummyTeamProjectAuthz))
var cohortDataControllerWithFailingTeamProjectAuthz = controllers.NewCohortDataController(*new(dummyCohortDataModel), *new(dummyFailingTeamProjectAuthz))

// instance of the controller that talks to the regular model implementation (that needs a real DB):
var cohortDefinitionControllerNeedsDb = controllers.NewCohortDefinitionController(*new(models.CohortDefinition))
Expand Down Expand Up @@ -145,6 +146,10 @@ func (h dummyTeamProjectAuthz) TeamProjectValidation(ctx *gin.Context, cohortDef
return true
}

func (h dummyTeamProjectAuthz) TeamProjectValidationForCohortIdsList(ctx *gin.Context, uniqueCohortDefinitionIdsList []int) bool {
return true
}

type dummyFailingTeamProjectAuthz struct{}

func (h dummyFailingTeamProjectAuthz) TeamProjectValidationForCohort(ctx *gin.Context, cohortDefinitionId int) bool {
Expand All @@ -155,6 +160,10 @@ func (h dummyFailingTeamProjectAuthz) TeamProjectValidation(ctx *gin.Context, co
return false
}

func (h dummyFailingTeamProjectAuthz) TeamProjectValidationForCohortIdsList(ctx *gin.Context, uniqueCohortDefinitionIdsList []int) bool {
return false
}

var conceptController = controllers.NewConceptController(*new(dummyConceptDataModel), *new(dummyCohortDefinitionDataModel), *new(dummyTeamProjectAuthz))
var conceptControllerWithFailingTeamProjectAuthz = controllers.NewConceptController(*new(dummyConceptDataModel), *new(dummyCohortDefinitionDataModel), *new(dummyFailingTeamProjectAuthz))

Expand Down Expand Up @@ -258,6 +267,18 @@ func TestRetrieveHistogramForCohortIdAndConceptIdWithCorrectParams(t *testing.T)
if !strings.Contains(result.CustomResponseWriterOut, "bins") {
t.Errorf("Expected output starting with 'bins,...'")
}

// the same request should fail if the teamProject authorization fails:
requestContext.Request.Body = io.NopCloser(strings.NewReader(requestBody))
cohortDataControllerWithFailingTeamProjectAuthz.RetrieveHistogramForCohortIdAndConceptId(requestContext)
result = requestContext.Writer.(*tests.CustomResponseWriter)
// expect error:
if !strings.Contains(result.CustomResponseWriterOut, "access denied") {
t.Errorf("Expected 'access denied' as result")
}
if !requestContext.IsAborted() {
t.Errorf("Expected request to be aborted")
}
}

func TestRetrieveDataBySourceIdAndCohortIdAndVariablesWrongParams(t *testing.T) {
Expand Down Expand Up @@ -290,6 +311,18 @@ func TestRetrieveDataBySourceIdAndCohortIdAndVariablesCorrectParams(t *testing.T
if !strings.Contains(result.CustomResponseWriterOut, "sample.id,") {
t.Errorf("Expected output starting with 'sample.id,...'")
}

// the same request should fail if the teamProject authorization fails:
requestContext.Request.Body = io.NopCloser(strings.NewReader(requestBody))
cohortDataControllerWithFailingTeamProjectAuthz.RetrieveDataBySourceIdAndCohortIdAndVariables(requestContext)
result = requestContext.Writer.(*tests.CustomResponseWriter)
// expect error:
if !strings.Contains(result.CustomResponseWriterOut, "access denied") {
t.Errorf("Expected 'access denied' as result")
}
if !requestContext.IsAborted() {
t.Errorf("Expected request to be aborted")
}
}

func TestRetrieveCohortOverlapStatsWithoutFilteringOnConceptValue(t *testing.T) {
Expand All @@ -312,6 +345,18 @@ func TestRetrieveCohortOverlapStatsWithoutFilteringOnConceptValue(t *testing.T)
if !strings.Contains(result.CustomResponseWriterOut, "case_control_overlap") {
t.Errorf("Expected output containing 'case_control_overlap...'")
}

// the same request should fail if the teamProject authorization fails:
requestContext.Request.Body = io.NopCloser(strings.NewReader(requestBody))
cohortDataControllerWithFailingTeamProjectAuthz.RetrieveCohortOverlapStatsWithoutFilteringOnConceptValue(requestContext)
result = requestContext.Writer.(*tests.CustomResponseWriter)
// expect error:
if !strings.Contains(result.CustomResponseWriterOut, "access denied") {
t.Errorf("Expected 'access denied' as result")
}
if !requestContext.IsAborted() {
t.Errorf("Expected request to be aborted")
}
}

func TestRetrieveCohortOverlapStatsWithoutFilteringOnConceptValueBadRequest(t *testing.T) {
Expand Down Expand Up @@ -433,45 +478,6 @@ func TestRetriveStatsBySourceIdAndTeamProject(t *testing.T) {
}
}

func TestRetriveByIdWrongParam(t *testing.T) {
setUp(t)
requestContext := new(gin.Context)
requestContext.Params = append(requestContext.Params, gin.Param{Key: "Abc", Value: "def"})
requestContext.Writer = new(tests.CustomResponseWriter)
cohortDefinitionController.RetriveById(requestContext)
// Params above are wrong, so request should abort:
if !requestContext.IsAborted() {
t.Errorf("Expected aborted request")
}
}

func TestRetriveById(t *testing.T) {
setUp(t)
requestContext := new(gin.Context)
requestContext.Params = append(requestContext.Params, gin.Param{Key: "id", Value: "1"})
requestContext.Writer = new(tests.CustomResponseWriter)
cohortDefinitionController.RetriveById(requestContext)
result := requestContext.Writer.(*tests.CustomResponseWriter)
log.Printf("result: %s", result)
// expect result with dummy data:
if !strings.Contains(result.CustomResponseWriterOut, "test 1") {
t.Errorf("Expected data in result")
}
}

func TestRetriveByIdModelError(t *testing.T) {
setUp(t)
requestContext := new(gin.Context)
requestContext.Params = append(requestContext.Params, gin.Param{Key: "id", Value: "1"})
requestContext.Writer = new(tests.CustomResponseWriter)
// set flag to let mock model layer return error instead of mock data:
dummyModelReturnError = true
cohortDefinitionController.RetriveById(requestContext)
if !requestContext.IsAborted() {
t.Errorf("Expected aborted request")
}
}

func TestRetrieveBreakdownStatsBySourceIdAndCohortId(t *testing.T) {
setUp(t)
requestContext := new(gin.Context)
Expand Down
4 changes: 2 additions & 2 deletions tests/models_tests/models_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -573,7 +573,7 @@ func TestGetTeamProjectsThatMatchAllCohortDefinitionIdsOnlyDefaultMatch(t *testi
CohortDefinitionId2: largestCohort.Id,
ProvidedName: "test"},
}
uniqueCohortDefinitionIdsList := utils.GetUniqueCohortDefinitionIdsListFromRequest(cohortDefinitionId, filterCohortPairs)
uniqueCohortDefinitionIdsList := utils.GetUniqueCohortDefinitionIdsListFromRequest([]int{cohortDefinitionId}, filterCohortPairs)
teamProjects, _ := cohortDefinitionModel.GetTeamProjectsThatMatchAllCohortDefinitionIds(uniqueCohortDefinitionIdsList)
if len(teamProjects) != 1 || teamProjects[0] != "defaultteamproject" {
t.Errorf("Expected to find only defaultteamproject")
Expand All @@ -589,7 +589,7 @@ func TestGetTeamProjectsThatMatchAllCohortDefinitionIds(t *testing.T) {
CohortDefinitionId2: 2,
ProvidedName: "test"},
}
uniqueCohortDefinitionIdsList := utils.GetUniqueCohortDefinitionIdsListFromRequest(cohortDefinitionId, filterCohortPairs)
uniqueCohortDefinitionIdsList := utils.GetUniqueCohortDefinitionIdsListFromRequest([]int{cohortDefinitionId}, filterCohortPairs)
teamProjects, _ := cohortDefinitionModel.GetTeamProjectsThatMatchAllCohortDefinitionIds(uniqueCohortDefinitionIdsList)
if len(teamProjects) != 2 {
t.Errorf("Expected to find two 'team projects' matching the cohort list, found %s", teamProjects)
Expand Down
4 changes: 2 additions & 2 deletions utils/parsing.go
Original file line number Diff line number Diff line change
Expand Up @@ -289,9 +289,9 @@ func MakeUnique(input []int) []int {
return uniqueList
}

func GetUniqueCohortDefinitionIdsListFromRequest(cohortDefinitionId int, filterCohortPairs []CustomDichotomousVariableDef) []int {
func GetUniqueCohortDefinitionIdsListFromRequest(cohortDefinitionIds []int, filterCohortPairs []CustomDichotomousVariableDef) []int {
var idsList []int
idsList = append(idsList, cohortDefinitionId)
idsList = append(idsList, cohortDefinitionIds...)
if len(filterCohortPairs) > 0 {
for _, filterCohortPair := range filterCohortPairs {
idsList = append(idsList, filterCohortPair.CohortDefinitionId1, filterCohortPair.CohortDefinitionId2)
Expand Down
Loading