Skip to content

Commit

Permalink
feat: add tests for TeamProjectAuthz
Browse files Browse the repository at this point in the history
  • Loading branch information
pieterlukasse committed Dec 8, 2023
1 parent 317e132 commit 522afa0
Show file tree
Hide file tree
Showing 3 changed files with 125 additions and 8 deletions.
22 changes: 15 additions & 7 deletions middlewares/teamprojectauthz.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,22 @@ type TeamProjectAuthzI interface {
TeamProjectValidation(ctx *gin.Context, cohortDefinitionId int, filterCohortPairs []utils.CustomDichotomousVariableDef) bool
}

type HttpClientI interface {
Do(req *http.Request) (*http.Response, error)
}

type TeamProjectAuthz struct {
cohortDefinitionModel models.CohortDefinitionI
httpClient HttpClientI
}

func NewTeamProjectAuthz(cohortDefinitionModel models.CohortDefinitionI) TeamProjectAuthz {
return TeamProjectAuthz{cohortDefinitionModel: cohortDefinitionModel}
func NewTeamProjectAuthz(cohortDefinitionModel models.CohortDefinitionI, httpClient HttpClientI) TeamProjectAuthz {
return TeamProjectAuthz{
cohortDefinitionModel: cohortDefinitionModel,
httpClient: httpClient,
}
}
func hasAccessToAtLeastOne(ctx *gin.Context, teamProjects []string) bool {
func (u TeamProjectAuthz) HasAccessToAtLeastOne(ctx *gin.Context, teamProjects []string) bool {

// query Arborist and return as soon as one of the teamProjects access check returns 200:
for _, teamProject := range teamProjects {
Expand All @@ -33,16 +41,16 @@ func hasAccessToAtLeastOne(ctx *gin.Context, teamProjects []string) bool {
ctx.AbortWithStatus(500)
panic("Error while preparing Arborist request")
}
client := &http.Client{}
// send the request to Arborist:
resp, _ := client.Do(req)
resp, _ := u.httpClient.Do(req)
log.Printf("Got response status %d from Arborist...", resp.StatusCode)

// arborist will return with 200 if the user has been granted access to the cohort-middleware URL in ctx:
if resp.StatusCode == 200 {
return true
} else {
// unauthorized or otherwise:
log.Printf("Got response status %d from Arborist...", resp.StatusCode)
log.Printf("Status %d does NOT give access to team project...", resp.StatusCode)
}
}
return false
Expand All @@ -65,7 +73,7 @@ func (u TeamProjectAuthz) TeamProjectValidation(ctx *gin.Context, cohortDefiniti
log.Printf("Invalid request error: could not find a 'team project' that is associated to ALL the cohorts present in this request")
return false
}
if !hasAccessToAtLeastOne(ctx, teamProjects) {
if !u.HasAccessToAtLeastOne(ctx, teamProjects) {
log.Printf("Invalid request error: user does not have access to any of the 'team projects' associated with the cohorts in this request")
return false
}
Expand Down
5 changes: 4 additions & 1 deletion server/router.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package server

import (
"net/http"

"github.com/gin-gonic/gin"
"github.com/uc-cdis/cohort-middleware/controllers"
"github.com/uc-cdis/cohort-middleware/middlewares"
Expand Down Expand Up @@ -33,7 +35,8 @@ func NewRouter() *gin.Engine {
authorized.GET("/cohortdefinition-stats/by-source-id/:sourceid/by-team-project/:teamproject", cohortdefinitions.RetriveStatsBySourceIdAndTeamProject)

// concept endpoints:
concepts := controllers.NewConceptController(*new(models.Concept), *new(models.CohortDefinition), *new(middlewares.TeamProjectAuthz))
concepts := controllers.NewConceptController(*new(models.Concept), *new(models.CohortDefinition),
middlewares.NewTeamProjectAuthz(*new(models.CohortDefinition), &http.Client{}))
authorized.GET("/concept/by-source-id/:sourceid", concepts.RetriveAllBySourceId)
authorized.POST("/concept/by-source-id/:sourceid", concepts.RetrieveInfoBySourceIdAndConceptIds)
authorized.POST("/concept/by-source-id/:sourceid/by-type", concepts.RetrieveInfoBySourceIdAndConceptTypes)
Expand Down
106 changes: 106 additions & 0 deletions tests/middlewares_tests/middlewares_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
"github.com/gin-gonic/gin"
"github.com/uc-cdis/cohort-middleware/config"
"github.com/uc-cdis/cohort-middleware/middlewares"
"github.com/uc-cdis/cohort-middleware/models"
"github.com/uc-cdis/cohort-middleware/tests"
)

Expand Down Expand Up @@ -78,3 +79,108 @@ func TestPrepareNewArboristRequestMissingToken(t *testing.T) {
t.Errorf("Expected error")
}
}

type dummyHttpClient struct {
statusCode int
nrCalls int
}

func (h *dummyHttpClient) Do(req *http.Request) (*http.Response, error) {
h.nrCalls++
return &http.Response{StatusCode: h.statusCode}, nil
}

type dummyCohortDefinitionDataModel struct{}

func (h dummyCohortDefinitionDataModel) GetCohortDefinitionIdsForTeamProject(teamProject string) ([]int, error) {
return nil, nil
}

func (h dummyCohortDefinitionDataModel) GetTeamProjectsThatMatchAllCohortDefinitionIds(uniqueCohortDefinitionIdsList []int) ([]string, error) {
// dummy switch just to support two test scenarios:
if uniqueCohortDefinitionIdsList[0] == 0 {
return nil, nil
} else {
return []string{"teamProject1", "teamProject2"}, nil
}
}

func (h dummyCohortDefinitionDataModel) GetCohortName(cohortId int) (string, error) {
return "dummy cohort name", nil
}

func (h dummyCohortDefinitionDataModel) GetAllCohortDefinitionsAndStatsOrderBySizeDesc(sourceId int, teamProject string) ([]*models.CohortDefinitionStats, error) {
return nil, nil
}
func (h dummyCohortDefinitionDataModel) GetCohortDefinitionById(id int) (*models.CohortDefinition, error) {
return nil, nil
}
func (h dummyCohortDefinitionDataModel) GetCohortDefinitionByName(name string) (*models.CohortDefinition, error) {
return nil, nil
}
func (h dummyCohortDefinitionDataModel) GetAllCohortDefinitions() ([]*models.CohortDefinition, error) {
return nil, nil
}

func TestTeamProjectValidation(t *testing.T) {
setUp(t)
config.Init("mocktest")
arboristAuthzResponseCode := 200
dummyHttpClient := &dummyHttpClient{statusCode: arboristAuthzResponseCode}
teamProjectAuthz := middlewares.NewTeamProjectAuthz(*new(dummyCohortDefinitionDataModel),
dummyHttpClient)
requestContext := new(gin.Context)
requestContext.Request = new(http.Request)
requestContext.Request.Header = map[string][]string{
"Authorization": {"dummy_token_value"},
}
result := teamProjectAuthz.TeamProjectValidation(requestContext, 1, nil)
if result == false {
t.Errorf("Expected TeamProjectValidation result to be 'true'")
}
if dummyHttpClient.nrCalls != 1 {
t.Errorf("Expected dummyHttpClient to have been only once")
}
}

func TestTeamProjectValidationArborist401(t *testing.T) {
setUp(t)
config.Init("mocktest")
arboristAuthzResponseCode := 401
dummyHttpClient := &dummyHttpClient{statusCode: arboristAuthzResponseCode}
teamProjectAuthz := middlewares.NewTeamProjectAuthz(*new(dummyCohortDefinitionDataModel),
dummyHttpClient)
requestContext := new(gin.Context)
requestContext.Request = new(http.Request)
requestContext.Request.Header = map[string][]string{
"Authorization": {"dummy_token_value"},
}
result := teamProjectAuthz.TeamProjectValidation(requestContext, 1, nil)
if result == true {
t.Errorf("Expected TeamProjectValidation result to be 'false'")
}
if dummyHttpClient.nrCalls <= 1 {
t.Errorf("Expected dummyHttpClient to have been called more than once")
}
}

func TestTeamProjectValidationNoTeamProjectMatchingAllCohortDefinitions(t *testing.T) {
setUp(t)
config.Init("mocktest")
arboristAuthzResponseCode := 200
dummyHttpClient := &dummyHttpClient{statusCode: arboristAuthzResponseCode}
teamProjectAuthz := middlewares.NewTeamProjectAuthz(*new(dummyCohortDefinitionDataModel),
dummyHttpClient)
requestContext := new(gin.Context)
requestContext.Request = new(http.Request)
requestContext.Request.Header = map[string][]string{
"Authorization": {"dummy_token_value"},
}
result := teamProjectAuthz.TeamProjectValidation(requestContext, 0, nil)
if result == true {
t.Errorf("Expected TeamProjectValidation result to be 'false'")
}
if dummyHttpClient.nrCalls > 0 {
t.Errorf("Expected dummyHttpClient to NOT have been called")
}
}

0 comments on commit 522afa0

Please sign in to comment.