Skip to content

Commit

Permalink
Merge pull request #91 from uc-cdis/fix/remove_unnecessary_joins_on_o…
Browse files Browse the repository at this point in the history
…bservation

Feat: remove unnecessary joins on observation
  • Loading branch information
pieterlukasse authored Mar 11, 2024
2 parents f439c23 + 6b46606 commit 5a20d45
Show file tree
Hide file tree
Showing 8 changed files with 64 additions and 41 deletions.
4 changes: 2 additions & 2 deletions controllers/cohortdata.go
Original file line number Diff line number Diff line change
Expand Up @@ -241,7 +241,7 @@ func populateConceptValue(row []string, cohortItem models.PersonConceptAndValue,
return row
}

func (u CohortDataController) RetrieveCohortOverlapStatsWithoutFilteringOnConceptValue(c *gin.Context) {
func (u CohortDataController) RetrieveCohortOverlapStats(c *gin.Context) {
errors := make([]error, 4)
var sourceId, caseCohortId, controlCohortId int
var conceptIds []int64
Expand All @@ -264,7 +264,7 @@ func (u CohortDataController) RetrieveCohortOverlapStatsWithoutFilteringOnConcep
c.Abort()
return
}
overlapStats, err := u.cohortDataModel.RetrieveCohortOverlapStatsWithoutFilteringOnConceptValue(sourceId, caseCohortId,
overlapStats, err := u.cohortDataModel.RetrieveCohortOverlapStats(sourceId, caseCohortId,
controlCohortId, conceptIds, cohortPairs)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"message": "Error retrieving stats", "error": err.Error()})
Expand Down
14 changes: 6 additions & 8 deletions models/cohortdata.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ import (

type CohortDataI interface {
RetrieveDataBySourceIdAndCohortIdAndConceptIdsOrderedByPersonId(sourceId int, cohortDefinitionId int, conceptIds []int64) ([]*PersonConceptAndValue, error)
RetrieveCohortOverlapStatsWithoutFilteringOnConceptValue(sourceId int, caseCohortId int, controlCohortId int, otherFilterConceptIds []int64, filterCohortPairs []utils.CustomDichotomousVariableDef) (CohortOverlapStats, error)
RetrieveCohortOverlapStats(sourceId int, caseCohortId int, controlCohortId int, otherFilterConceptIds []int64, filterCohortPairs []utils.CustomDichotomousVariableDef) (CohortOverlapStats, error)
RetrieveDataByOriginalCohortAndNewCohort(sourceId int, originalCohortDefinitionId int, cohortDefinitionId int) ([]*PersonIdAndCohort, error)
RetrieveHistogramDataBySourceIdAndCohortIdAndConceptIdsAndCohortPairs(sourceId int, cohortDefinitionId int, histogramConceptId int64, filterConceptIds []int64, filterCohortPairs []utils.CustomDichotomousVariableDef) ([]*PersonConceptAndValue, error)
}
Expand Down Expand Up @@ -101,7 +101,7 @@ func (h CohortData) RetrieveHistogramDataBySourceIdAndCohortIdAndConceptIdsAndCo
Where("observation.observation_concept_id = ?", histogramConceptId).
Where("observation.value_as_number is not null")

query = QueryFilterByConceptIdsHelper(query, sourceId, filterConceptIds, omopDataSource, resultsDataSource.Schema, "observation")
query = QueryFilterByConceptIdsHelper(query, sourceId, filterConceptIds, omopDataSource, resultsDataSource.Schema, "unionAndIntersect.subject_id")

query, cancel := utils.AddTimeoutToQuery(query)
defer cancel()
Expand All @@ -110,22 +110,20 @@ func (h CohortData) RetrieveHistogramDataBySourceIdAndCohortIdAndConceptIdsAndCo
}

// Basically the same as the method above, but without the extra filtering on filterConceptId and filterConceptValue:
func (h CohortData) RetrieveCohortOverlapStatsWithoutFilteringOnConceptValue(sourceId int, caseCohortId int, controlCohortId int,
otherFilterConceptIds []int64, filterCohortPairs []utils.CustomDichotomousVariableDef) (CohortOverlapStats, error) {
func (h CohortData) RetrieveCohortOverlapStats(sourceId int, caseCohortId int, controlCohortId int,
filterConceptIds []int64, filterCohortPairs []utils.CustomDichotomousVariableDef) (CohortOverlapStats, error) {

var dataSourceModel = new(Source)
omopDataSource := dataSourceModel.GetDataSource(sourceId, Omop)
resultsDataSource := dataSourceModel.GetDataSource(sourceId, Results)

// count persons that are in the intersection of both case and control cohorts, filtering on filterConceptValue:
var cohortOverlapStats CohortOverlapStats
query := QueryFilterByCohortPairsHelper(filterCohortPairs, resultsDataSource, caseCohortId, "case_cohort_unionedAndIntersectedWithFilters").
Select("count(distinct(case_cohort_unionedAndIntersectedWithFilters.subject_id)) as case_control_overlap").
Joins("INNER JOIN " + resultsDataSource.Schema + ".cohort as control_cohort ON control_cohort.subject_id = case_cohort_unionedAndIntersectedWithFilters.subject_id") // this one allows for the intersection between case and control and the assessment of the overlap

if len(otherFilterConceptIds) > 0 {
query = query.Joins("INNER JOIN " + omopDataSource.Schema + ".observation_continuous as observation" + omopDataSource.GetViewDirective() + " ON control_cohort.subject_id = observation.person_id")
query = QueryFilterByConceptIdsHelper(query, sourceId, otherFilterConceptIds, omopDataSource, resultsDataSource.Schema, "observation")
if len(filterConceptIds) > 0 {
query = QueryFilterByConceptIdsHelper(query, sourceId, filterConceptIds, omopDataSource, resultsDataSource.Schema, "control_cohort.subject_id")
}
query = query.Where("control_cohort.cohort_definition_id = ?", controlCohortId)
query, cancel := utils.AddTimeoutToQuery(query)
Expand Down
2 changes: 1 addition & 1 deletion models/concept.go
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ func (h Concept) RetrieveBreakdownStatsBySourceIdAndCohortIdAndConceptIdsAndCoho
Where("observation.observation_concept_id = ?", breakdownConceptId).
Where(GetConceptValueNotNullCheckBasedOnConceptType("observation", sourceId, breakdownConceptId))

query = QueryFilterByConceptIdsHelper(query, sourceId, filterConceptIds, omopDataSource, resultsDataSource.Schema, "observation")
query = QueryFilterByConceptIdsHelper(query, sourceId, filterConceptIds, omopDataSource, resultsDataSource.Schema, "unionAndIntersect.subject_id")

query, cancel := utils.AddTimeoutToQuery(query)
defer cancel()
Expand Down
4 changes: 2 additions & 2 deletions models/helper.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,13 @@ import (
// * It was added here to make it reusable, given these filters need to be added to many of the queries that take in
// a list of filters in the form of concept ids.
func QueryFilterByConceptIdsHelper(query *gorm.DB, sourceId int, filterConceptIds []int64,
omopDataSource *utils.DbAndSchema, resultSchemaName string, mainObservationTableAlias string) *gorm.DB {
omopDataSource *utils.DbAndSchema, resultSchemaName string, personIdFieldForObservationJoin string) *gorm.DB {
// iterate over the filterConceptIds, adding a new INNER JOIN and filters for each, so that the resulting set is the
// set of persons that have a non-null value for each and every one of the concepts:
for i, filterConceptId := range filterConceptIds {
observationTableAlias := fmt.Sprintf("observation_filter_%d", i)
log.Printf("Adding extra INNER JOIN with alias %s", observationTableAlias)
query = query.Joins("INNER JOIN "+omopDataSource.Schema+".observation_continuous as "+observationTableAlias+omopDataSource.GetViewDirective()+" ON "+observationTableAlias+".person_id = "+mainObservationTableAlias+".person_id").
query = query.Joins("INNER JOIN "+omopDataSource.Schema+".observation_continuous as "+observationTableAlias+omopDataSource.GetViewDirective()+" ON "+observationTableAlias+".person_id = "+personIdFieldForObservationJoin).
Where(observationTableAlias+".observation_concept_id = ?", filterConceptId).
Where(GetConceptValueNotNullCheckBasedOnConceptType(observationTableAlias, sourceId, filterConceptId))
}
Expand Down
2 changes: 1 addition & 1 deletion server/router.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ func NewRouter() *gin.Engine {
// cohort stats and checks:
cohortData := controllers.NewCohortDataController(*new(models.CohortData), middlewares.NewTeamProjectAuthz(*new(models.CohortDefinition), &http.Client{}))
// :casecohortid/:controlcohortid are just labels here and have no special meaning. Could also just be :cohortAId/:cohortBId here:
authorized.POST("/cohort-stats/check-overlap/by-source-id/:sourceid/by-cohort-definition-ids/:casecohortid/:controlcohortid", cohortData.RetrieveCohortOverlapStatsWithoutFilteringOnConceptValue)
authorized.POST("/cohort-stats/check-overlap/by-source-id/:sourceid/by-cohort-definition-ids/:casecohortid/:controlcohortid", cohortData.RetrieveCohortOverlapStats)

// full data endpoints:
authorized.POST("/cohort-data/by-source-id/:sourceid/by-cohort-definition-id/:cohortid", cohortData.RetrieveDataBySourceIdAndCohortIdAndVariables)
Expand Down
12 changes: 6 additions & 6 deletions tests/controllers_tests/controllers_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ func (h dummyCohortDataModel) RetrieveHistogramDataBySourceIdAndCohortIdAndConce
return cohortData, nil
}

func (h dummyCohortDataModel) RetrieveCohortOverlapStatsWithoutFilteringOnConceptValue(sourceId int, caseCohortId int, controlCohortId int,
func (h dummyCohortDataModel) RetrieveCohortOverlapStats(sourceId int, caseCohortId int, controlCohortId int,
otherFilterConceptIds []int64, filterCohortPairs []utils.CustomDichotomousVariableDef) (models.CohortOverlapStats, error) {
var zeroOverlap models.CohortOverlapStats
return zeroOverlap, nil
Expand Down Expand Up @@ -335,7 +335,7 @@ func TestRetrieveDataBySourceIdAndCohortIdAndVariablesCorrectParams(t *testing.T
}
}

func TestRetrieveCohortOverlapStatsWithoutFilteringOnConceptValue(t *testing.T) {
func TestRetrieveCohortOverlapStats(t *testing.T) {
setUp(t)
requestContext := new(gin.Context)
requestContext.Params = append(requestContext.Params, gin.Param{Key: "sourceid", Value: strconv.Itoa(tests.GetTestSourceId())})
Expand All @@ -346,7 +346,7 @@ func TestRetrieveCohortOverlapStatsWithoutFilteringOnConceptValue(t *testing.T)
requestBody := "{\"variables\":[{\"variable_type\": \"concept\", \"concept_id\": 2000000324},{\"variable_type\": \"concept\", \"concept_id\": 2000006885}]}"
requestContext.Request.Body = io.NopCloser(strings.NewReader(requestBody))

cohortDataController.RetrieveCohortOverlapStatsWithoutFilteringOnConceptValue(requestContext)
cohortDataController.RetrieveCohortOverlapStats(requestContext)
// Params above are correct, so request should NOT abort:
if requestContext.IsAborted() {
t.Errorf("Did not expect this request to abort")
Expand All @@ -358,7 +358,7 @@ func TestRetrieveCohortOverlapStatsWithoutFilteringOnConceptValue(t *testing.T)

// the same request should fail if the teamProject authorization fails:
requestContext.Request.Body = io.NopCloser(strings.NewReader(requestBody))
cohortDataControllerWithFailingTeamProjectAuthz.RetrieveCohortOverlapStatsWithoutFilteringOnConceptValue(requestContext)
cohortDataControllerWithFailingTeamProjectAuthz.RetrieveCohortOverlapStats(requestContext)
result = requestContext.Writer.(*tests.CustomResponseWriter)
// expect error:
if !strings.Contains(result.CustomResponseWriterOut, "access denied") {
Expand All @@ -369,13 +369,13 @@ func TestRetrieveCohortOverlapStatsWithoutFilteringOnConceptValue(t *testing.T)
}
}

func TestRetrieveCohortOverlapStatsWithoutFilteringOnConceptValueBadRequest(t *testing.T) {
func TestRetrieveCohortOverlapStatsBadRequest(t *testing.T) {
setUp(t)
requestContext := new(gin.Context)
requestContext.Params = append(requestContext.Params, gin.Param{Key: "sourceid", Value: strconv.Itoa(tests.GetTestSourceId())})
requestContext.Writer = new(tests.CustomResponseWriter)

cohortDataController.RetrieveCohortOverlapStatsWithoutFilteringOnConceptValue(requestContext)
cohortDataController.RetrieveCohortOverlapStats(requestContext)
// Params above are incorrect, so request should abort:
if !requestContext.IsAborted() {
t.Errorf("Expected this request to abort")
Expand Down
65 changes: 44 additions & 21 deletions tests/models_tests/models_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -281,8 +281,8 @@ func TestQueryFilterByCohortPairsHelper(t *testing.T) {
query = models.QueryFilterByCohortPairsHelper(filterCohortPairs, resultsDataSource, population.Id, "unionAndIntersect").
Select("subject_id")
_ = query.Scan(&subjectIds)
// in this case we expect overlap the size of the largestCohort-5:
if len(subjectIds) != (largestCohort.CohortSize - 5) {
// in this case we expect overlap the size of the largestCohort-6 (where 6 is the size of the overlap between extendedCopyOfSecondLargestCohort and largestCohort):
if len(subjectIds) != (largestCohort.CohortSize - 6) {
t.Errorf("Expected %d overlap, found %d", largestCohort.CohortSize-5, len(subjectIds))
}

Expand All @@ -304,7 +304,7 @@ func TestQueryFilterByCohortPairsHelper(t *testing.T) {
Select("subject_id")
_ = query.Scan(&subjectIds)
// in this case we expect same as previous test above:
if len(subjectIds) != (largestCohort.CohortSize - 5) {
if len(subjectIds) != (largestCohort.CohortSize - 6) {
t.Errorf("Expected %d overlap, found %d", largestCohort.CohortSize-5, len(subjectIds))
}

Expand Down Expand Up @@ -495,8 +495,10 @@ func TestRetrieveBreakdownStatsBySourceIdAndCohortIdAndConceptIdsAndCohortPairsW
if len(stats) > len(stats2) {
t.Errorf("First query is more restrictive, so its stats should not be larger than stats2 of second query. Got %d and %d", len(stats), len(stats2))
}
// test filtering with smallest cohort, lenght should be 1, since that's the size of the smallest cohort:
// setting the same cohort id here (artificial...normally it should be two different ids):

// test filtering with secondLargestCohort, smallest and largestCohort.
// Lenght of result set should be 2 persons (one HIS, one ASN), since there is a overlap of 1 between secondLargestCohort and smallest cohort,
// and overlap of 2 between secondLargestCohort and largestCohort, BUT only 1 has a HARE value:
filterCohortPairs = []utils.CustomDichotomousVariableDef{
{
CohortDefinitionId1: smallestCohort.Id,
Expand All @@ -506,7 +508,14 @@ func TestRetrieveBreakdownStatsBySourceIdAndCohortIdAndConceptIdsAndCohortPairsW
stats3, _ := conceptModel.RetrieveBreakdownStatsBySourceIdAndCohortIdAndConceptIdsAndCohortPairs(testSourceId,
secondLargestCohort.Id, filterIds, filterCohortPairs, breakdownConceptId)
if len(stats3) != 2 {
t.Errorf("Expected only two items in resultset, found %d", len(stats))
t.Errorf("Expected only two items in resultset, found %d", len(stats3))
}
countPersons := 0
for _, stat := range stats3 {
countPersons = countPersons + stat.NpersonsInCohortWithValue
}
if countPersons != 2 {
t.Errorf("Expected only two persons in resultset, found %d", countPersons)
}
}

Expand Down Expand Up @@ -731,9 +740,9 @@ func TestRetrieveHistogramDataBySourceIdAndCohortIdAndConceptIdsAndCohortPairs(t
filterConceptIds := []int64{}
filterCohortPairs := []utils.CustomDichotomousVariableDef{}
data, _ := cohortDataModel.RetrieveHistogramDataBySourceIdAndCohortIdAndConceptIdsAndCohortPairs(testSourceId, largestCohort.Id, histogramConceptId, filterConceptIds, filterCohortPairs)
// everyone in the largestCohort has the histogramConceptId:
if len(data) != largestCohort.CohortSize {
t.Errorf("expected 10 or more histogram data but got %d", len(data))
// everyone in the largestCohort has the histogramConceptId, but one person has NULL in the value_as_number:
if len(data) != largestCohort.CohortSize-1 {
t.Errorf("expected %d histogram data but got %d", largestCohort.CohortSize, len(data))
}

// now filter on the extendedCopyOfSecondLargestCohort
Expand Down Expand Up @@ -766,15 +775,15 @@ func TestQueryFilterByConceptIdsHelper(t *testing.T) {
// Subtest1: correct alias "observation":
query := omopDataSource.Db.Table(omopDataSource.Schema + ".observation_continuous as observation" + omopDataSource.GetViewDirective()).
Select("observation.person_id")
query = models.QueryFilterByConceptIdsHelper(query, testSourceId, filterConceptIds, omopDataSource, "", "observation")
query = models.QueryFilterByConceptIdsHelper(query, testSourceId, filterConceptIds, omopDataSource, "", "observation.person_id")
meta_result := query.Scan(&personIds)
if meta_result.Error != nil {
t.Errorf("Did NOT expect an error")
}
// Subtest2: incorrect alias "observation"...should fail:
query = omopDataSource.Db.Table(omopDataSource.Schema + ".observation_continuous as observationWRONG").
Select("*")
query = models.QueryFilterByConceptIdsHelper(query, testSourceId, filterConceptIds, omopDataSource, "", "observation")
query = models.QueryFilterByConceptIdsHelper(query, testSourceId, filterConceptIds, omopDataSource, "", "observation.person_id")
meta_result = query.Scan(&personIds)
if meta_result.Error == nil {
t.Errorf("Expected an error")
Expand Down Expand Up @@ -850,14 +859,14 @@ func TestErrorForRetrieveDataBySourceIdAndCohortIdAndConceptIdsOrderedByPersonId
tests.FixSomething(models.Results, "cohort", "cohort_definition_id")
}

func TestRetrieveCohortOverlapStatsWithoutFilteringOnConceptValue(t *testing.T) {
func TestRetrieveCohortOverlapStats(t *testing.T) {
// Tests if we get the expected overlap
setUp(t)
caseCohortId := secondLargestCohort.Id
controlCohortId := secondLargestCohort.Id // to ensure we get some overlap, just repeat the same here...
otherFilterConceptIds := []int64{}
filterCohortPairs := []utils.CustomDichotomousVariableDef{}
stats, _ := cohortDataModel.RetrieveCohortOverlapStatsWithoutFilteringOnConceptValue(testSourceId, caseCohortId, controlCohortId,
stats, _ := cohortDataModel.RetrieveCohortOverlapStats(testSourceId, caseCohortId, controlCohortId,
otherFilterConceptIds, filterCohortPairs)
// basic test:
if stats.CaseControlOverlap != int64(secondLargestCohort.CohortSize) {
Expand All @@ -873,11 +882,11 @@ func TestRetrieveCohortOverlapStatsWithoutFilteringOnConceptValue(t *testing.T)
CohortDefinitionId2: extendedCopyOfSecondLargestCohort.Id,
ProvidedName: "test"},
}
// then we expect overlap of 5 for extendedCopyOfSecondLargestCohort and largestCohort:
stats, _ = cohortDataModel.RetrieveCohortOverlapStatsWithoutFilteringOnConceptValue(testSourceId, caseCohortId, controlCohortId,
// then we expect overlap of 6 for extendedCopyOfSecondLargestCohort and largestCohort:
stats, _ = cohortDataModel.RetrieveCohortOverlapStats(testSourceId, caseCohortId, controlCohortId,
otherFilterConceptIds, filterCohortPairs)
if stats.CaseControlOverlap != 5 {
t.Errorf("Expected nr persons to be %d, found %d", 5, stats.CaseControlOverlap)
if stats.CaseControlOverlap != 6 {
t.Errorf("Expected nr persons to be %d, found %d", 6, stats.CaseControlOverlap)
}

// extra test: different parameters that should return the same as above ^:
Expand All @@ -886,10 +895,10 @@ func TestRetrieveCohortOverlapStatsWithoutFilteringOnConceptValue(t *testing.T)
filterCohortPairs = []utils.CustomDichotomousVariableDef{}
otherFilterConceptIds = []int64{histogramConceptId} // extra filter, to cover this part of the code...
// then we expect overlap of 5 for extendedCopyOfSecondLargestCohort and largestCohort (the filter on histogramConceptId should not matter
// since all in largestCohort have an observation for this concept id):
stats2, _ := cohortDataModel.RetrieveCohortOverlapStatsWithoutFilteringOnConceptValue(testSourceId, caseCohortId, controlCohortId,
// since all in largestCohort have an observation for this concept id except one person who has it but has value_as_number as NULL):
stats2, _ := cohortDataModel.RetrieveCohortOverlapStats(testSourceId, caseCohortId, controlCohortId,
otherFilterConceptIds, filterCohortPairs)
if stats2.CaseControlOverlap != stats.CaseControlOverlap {
if stats2.CaseControlOverlap != stats.CaseControlOverlap-1 {
t.Errorf("Expected nr persons to be %d, found %d", stats.CaseControlOverlap, stats2.CaseControlOverlap)
}

Expand All @@ -898,7 +907,7 @@ func TestRetrieveCohortOverlapStatsWithoutFilteringOnConceptValue(t *testing.T)
otherFilterConceptIds = []int64{histogramConceptId, dummyContinuousConceptId}
// all other arguments are the same as test above, and we expect overlap of 0, showing the otherFilterConceptIds
// had the expected effect:
stats3, _ := cohortDataModel.RetrieveCohortOverlapStatsWithoutFilteringOnConceptValue(testSourceId, caseCohortId, controlCohortId,
stats3, _ := cohortDataModel.RetrieveCohortOverlapStats(testSourceId, caseCohortId, controlCohortId,
otherFilterConceptIds, filterCohortPairs)
if stats3.CaseControlOverlap != 0 {
t.Errorf("Expected nr persons to be 0, found %d", stats3.CaseControlOverlap)
Expand Down Expand Up @@ -1014,3 +1023,17 @@ func TestAddTimeoutToQuery(t *testing.T) {
t.Errorf("Expected result and NO error")
}
}

func TestPersonConceptAndCountString(t *testing.T) {
a := models.PersonConceptAndCount{
PersonId: 1,
ConceptId: 2,
Count: 3,
}

expected := "(person_id=1, concept_id=2, count=3)"
if a.String() != expected {
t.Errorf("Expected %s, found %s", expected, a.String())
}

}
Loading

0 comments on commit 5a20d45

Please sign in to comment.