Skip to content

Commit

Permalink
fix: spread embeds failing when using the "count()" aggregate without…
Browse files Browse the repository at this point in the history
… a field - @laurenceisla

- Fixed "column reference <col> is ambiguous" error when selecting "?select=...table(col,count())"
- Fixed "column <json_aggregate>.<alias> does not exist" error when selecting "?select=...table(aias:count())"
  • Loading branch information
laurenceisla committed Aug 21, 2024
1 parent 63e6168 commit e44fbdd
Show file tree
Hide file tree
Showing 7 changed files with 100 additions and 14 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@ This project adheres to [Semantic Versioning](http://semver.org/).

- #3693, Prevent spread embedding to allow aggregates when they are disabled - @laurenceisla
- #3693, A nested spread embedding now correctly groups by the fields of its top parent relationship - @laurenceisla
- #3693, Fix spread embedding errors when using the `count()` aggregate without a field - @laurenceisla
+ Fixed `"column reference <col> is ambiguous"` error when selecting `?select=...table(col,count())`
+ Fixed `"column <json_aggregate>.<alias> does not exist"` error when selecting `?select=...table(aias:count())`

### Changed

Expand Down
36 changes: 24 additions & 12 deletions src/PostgREST/Plan.hs
Original file line number Diff line number Diff line change
Expand Up @@ -270,7 +270,7 @@ data ResolverContext = ResolverContext
}

resolveColumnField :: Column -> CoercibleField
resolveColumnField col = CoercibleField (colName col) mempty False (colNominalType col) Nothing (colDefault col)
resolveColumnField col = CoercibleField (colName col) mempty False (colNominalType col) Nothing (colDefault col) False

resolveTableFieldName :: Table -> FieldName -> CoercibleField
resolveTableFieldName table fieldName =
Expand Down Expand Up @@ -380,11 +380,19 @@ addAliases = Right . fmap addAliasToPlan

aliasSelectField :: CoercibleSelectField -> CoercibleSelectField
aliasSelectField field@CoercibleSelectField{csField=fieldDetails, csAggFunction=aggFun, csAlias=alias}
| isJust alias || isJust aggFun = field
| isJust alias = field
| isJust aggFun = fieldAliasForSpreadAgg field
| isJsonKeyPath fieldDetails, Just key <- lastJsonKey fieldDetails = field { csAlias = Just key }
| isTransformPath fieldDetails = field { csAlias = Just (cfName fieldDetails) }
| otherwise = field

-- A request like: `/top_table?select=...middle_table(...nested_table(count()))` will `SELECT` the full row instead of `*`,
-- because doing a `COUNT(*)` in `top_table` would not return the desired results.
-- So we use the "count" alias if none is present since the field name won't be selected.
fieldAliasForSpreadAgg field
| cfFullRow (csField field) = field { csAlias = Just "count" }
| otherwise = field

isJsonKeyPath CoercibleField{cfJsonPath=(_: _)} = True
isJsonKeyPath _ = False

Expand Down Expand Up @@ -428,26 +436,30 @@ expandStars ctx rPlanTree = Right $ expandStarsForReadPlan False rPlanTree
adjustContext context fromQI _ = context{qi=fromQI}

expandStarsForTable :: ResolverContext -> Bool -> ReadPlan -> ReadPlan
expandStarsForTable ctx@ResolverContext{representations, outputType} hasAgg rp@ReadPlan{select=selectFields}
expandStarsForTable ctx@ResolverContext{representations, outputType} hasAgg rp@ReadPlan{select=selectFields, relIsSpread=isSpread}
-- We expand if either of the below are true:
-- * We have a '*' select AND there is an aggregate function in this ReadPlan's sub-tree.
-- * We have a '*' select AND the target table has at least one data representation.
-- We ignore any '*' selects that have an aggregate function attached (i.e for COUNT(*)).
| hasStarSelect && (hasAgg || hasDataRepresentation) = rp{select = concatMap (expandStarSelectField knownColumns) selectFields}
-- We ignore '*' selects that have an aggregate function attached, unless it's a `COUNT(*)` for a Spread Embed,
-- we tag it as "full row" in that case.
| hasStarSelect && (hasAgg || hasDataRepresentation) = rp{select = concatMap (expandStarSelectField isSpread knownColumns) selectFields}
| otherwise = rp
where
hasStarSelect = "*" `elem` map (cfName . csField) filteredSelectFields
filteredSelectFields = filter (isNothing . csAggFunction) selectFields
filteredSelectFields = filter (shouldExpandOrTag . csAggFunction) selectFields
shouldExpandOrTag aggFunc = isNothing aggFunc || (isSpread && aggFunc == Just Count)
hasDataRepresentation = any hasOutputRep knownColumns
knownColumns = knownColumnsInContext ctx

hasOutputRep :: Column -> Bool
hasOutputRep col = HM.member (colNominalType col, outputType) representations

expandStarSelectField :: [Column] -> CoercibleSelectField -> [CoercibleSelectField]
expandStarSelectField columns sel@CoercibleSelectField{csField=CoercibleField{cfName="*", cfJsonPath=[]}, csAggFunction=Nothing} =
expandStarSelectField :: Bool -> [Column] -> CoercibleSelectField -> [CoercibleSelectField]
expandStarSelectField _ columns sel@CoercibleSelectField{csField=CoercibleField{cfName="*", cfJsonPath=[]}, csAggFunction=Nothing} =
map (\col -> sel { csField = withOutputFormat ctx $ resolveColumnField col }) columns
expandStarSelectField _ selectField = [selectField]
expandStarSelectField True _ sel@CoercibleSelectField{csField=fld@CoercibleField{cfName="*", cfJsonPath=[]}, csAggFunction=Just Count} =
[sel { csField = fld { cfFullRow = True } }]
expandStarSelectField _ _ selectField = [selectField]

-- | Enforces the `max-rows` config on the result
treeRestrictRange :: Maybe Integer -> Action -> ReadPlanTree -> Either ApiRequestError ReadPlanTree
Expand Down Expand Up @@ -823,7 +835,7 @@ addRelatedOrders (Node rp@ReadPlan{order,from} forest) = do
-- where_ = [
-- CoercibleStmnt (
-- CoercibleFilter {
-- field = CoercibleField {cfName = "projects", cfJsonPath = [], cfToJson=False, cfIRType = "", cfTransform = Nothing, cfDefault = Nothing},
-- field = CoercibleField {cfName = "projects", cfJsonPath = [], cfToJson=False, cfIRType = "", cfTransform = Nothing, cfDefault = Nothing, cfFullRow = False},
-- opExpr = op
-- }
-- )
Expand All @@ -839,7 +851,7 @@ addRelatedOrders (Node rp@ReadPlan{order,from} forest) = do
-- Don't do anything to the filter if there's no embedding (a subtree) on projects. Assume it's a normal filter.
--
-- >>> ReadPlan.where_ . rootLabel <$> addNullEmbedFilters (readPlanTree nullOp [])
-- Right [CoercibleStmnt (CoercibleFilter {field = CoercibleField {cfName = "projects", cfJsonPath = [], cfToJson = False, cfIRType = "", cfTransform = Nothing, cfDefault = Nothing}, opExpr = OpExpr True (Is TriNull)})]
-- Right [CoercibleStmnt (CoercibleFilter {field = CoercibleField {cfName = "projects", cfJsonPath = [], cfToJson = False, cfIRType = "", cfTransform = Nothing, cfDefault = Nothing, cfFullRow = False}, opExpr = OpExpr True (Is TriNull)})]
--
-- If there's an embedding on projects, then change the filter to use the internal aggregate name (`clients_projects_1`) so the filter can succeed later.
--
Expand All @@ -858,7 +870,7 @@ addNullEmbedFilters (Node rp@ReadPlan{where_=curLogic} forest) = do
newNullFilters rPlans = \case
(CoercibleExpr b lOp trees) ->
CoercibleExpr b lOp <$> (newNullFilters rPlans `traverse` trees)
flt@(CoercibleStmnt (CoercibleFilter (CoercibleField fld [] _ _ _ _) opExpr)) ->
flt@(CoercibleStmnt (CoercibleFilter (CoercibleField fld [] _ _ _ _ _) opExpr)) ->
let foundRP = find (\ReadPlan{relName, relAlias} -> fld == fromMaybe relName relAlias) rPlans in
case (foundRP, opExpr) of
(Just ReadPlan{relAggAlias}, OpExpr b (Is TriNull)) -> Right $ CoercibleStmnt $ CoercibleFilterNullEmbed b relAggAlias
Expand Down
3 changes: 2 additions & 1 deletion src/PostgREST/Plan/Types.hs
Original file line number Diff line number Diff line change
Expand Up @@ -39,10 +39,11 @@ data CoercibleField = CoercibleField
, cfIRType :: Text -- ^ The native Postgres type of the field, the intermediate (IR) type before mapping.
, cfTransform :: Maybe TransformerProc -- ^ The optional mapping from irType -> targetType.
, cfDefault :: Maybe Text
, cfFullRow :: Bool -- ^ True if the field represents the whole selected row. Used in spread rels: instead of COUNT(*), it does a COUNT(<row>) in order to not mix with other spreaded resources.
} deriving (Eq, Show)

unknownField :: FieldName -> JsonPath -> CoercibleField
unknownField name path = CoercibleField name path False "" Nothing Nothing
unknownField name path = CoercibleField name path False "" Nothing Nothing False

-- | Like an API request LogicTree, but with coercible field information.
data CoercibleLogicTree
Expand Down
2 changes: 1 addition & 1 deletion src/PostgREST/Query/QueryBuilder.hs
Original file line number Diff line number Diff line change
Expand Up @@ -206,7 +206,7 @@ callPlanToQuery (FunctionCall qi params arguments returnsScalar returnsSetOfScal
KeyParams [] -> "FROM " <> callIt mempty
KeyParams prms -> case arguments of
DirectArgs args -> "FROM " <> callIt (fmtArgs prms args)
JsonArgs json -> fromJsonBodyF json ((\p -> CoercibleField (ppName p) mempty False (ppTypeMaxLength p) Nothing Nothing) <$> prms) False True False <> ", " <>
JsonArgs json -> fromJsonBodyF json ((\p -> CoercibleField (ppName p) mempty False (ppTypeMaxLength p) Nothing Nothing False) <$> prms) False True False <> ", " <>
"LATERAL " <> callIt (fmtParams prms)

callIt :: SQL.Snippet -> SQL.Snippet
Expand Down
1 change: 1 addition & 0 deletions src/PostgREST/Query/SqlFragment.hs
Original file line number Diff line number Diff line change
Expand Up @@ -252,6 +252,7 @@ pgFmtCallUnary :: Text -> SQL.Snippet -> SQL.Snippet
pgFmtCallUnary f x = SQL.sql (encodeUtf8 f) <> "(" <> x <> ")"

pgFmtField :: QualifiedIdentifier -> CoercibleField -> SQL.Snippet
pgFmtField table CoercibleField{cfFullRow=True} = fromQi table
pgFmtField table CoercibleField{cfName=fn, cfJsonPath=[]} = pgFmtColumn table fn
pgFmtField table CoercibleField{cfName=fn, cfToJson=doToJson, cfJsonPath=jp} | doToJson = "to_jsonb(" <> pgFmtColumn table fn <> ")" <> pgFmtJsonPath jp
| otherwise = pgFmtColumn table fn <> pgFmtJsonPath jp
Expand Down
67 changes: 67 additions & 0 deletions test/spec/Feature/Query/AggregateFunctionsSpec.hs
Original file line number Diff line number Diff line change
Expand Up @@ -224,6 +224,73 @@ allowed =
{"name": "Sarah", "process_supervisor": [{"category": "Batch", "cost_sum": 180.00}]}]|]
{ matchHeaders = [matchContentTypeJson] }

context "supports count() aggregate without specifying a field" $ do
it "works by itself in the embedded resource" $ do
get "/process_supervisor?select=supervisor_id,...processes(count())&order=supervisor_id" `shouldRespondWith`
[json|[
{"supervisor_id": 1, "count": 2},
{"supervisor_id": 2, "count": 2},
{"supervisor_id": 3, "count": 3},
{"supervisor_id": 4, "count": 1}]|]
{ matchHeaders = [matchContentTypeJson] }
get "/process_supervisor?select=supervisor_id,...processes(processes_count:count())&order=supervisor_id" `shouldRespondWith`
[json|[
{"supervisor_id": 1, "processes_count": 2},
{"supervisor_id": 2, "processes_count": 2},
{"supervisor_id": 3, "processes_count": 3},
{"supervisor_id": 4, "processes_count": 1}]|]
{ matchHeaders = [matchContentTypeJson] }
it "works alongside other columns in the embedded resource" $ do
get "/process_supervisor?select=...supervisors(id,count())&order=supervisors(id)" `shouldRespondWith`
[json|[
{"id": 1, "count": 2},
{"id": 2, "count": 2},
{"id": 3, "count": 3},
{"id": 4, "count": 1}]|]
{ matchHeaders = [matchContentTypeJson] }
get "/process_supervisor?select=...supervisors(supervisor:id,supervisor_count:count())&order=supervisors(supervisor)" `shouldRespondWith`
[json|[
{"supervisor": 1, "supervisor_count": 2},
{"supervisor": 2, "supervisor_count": 2},
{"supervisor": 3, "supervisor_count": 3},
{"supervisor": 4, "supervisor_count": 1}]|]
{ matchHeaders = [matchContentTypeJson] }
it "works on nested resources" $ do
get "/process_supervisor?select=supervisor_id,...processes(...process_costs(count()))&order=supervisor_id" `shouldRespondWith`
[json|[
{"supervisor_id": 1, "count": 2},
{"supervisor_id": 2, "count": 2},
{"supervisor_id": 3, "count": 2},
{"supervisor_id": 4, "count": 1}]|]
{ matchHeaders = [matchContentTypeJson] }
get "/process_supervisor?select=supervisor:supervisor_id,...processes(...process_costs(process_costs_count:count()))&order=supervisor_id" `shouldRespondWith`
[json|[
{"supervisor": 1, "process_costs_count": 2},
{"supervisor": 2, "process_costs_count": 2},
{"supervisor": 3, "process_costs_count": 2},
{"supervisor": 4, "process_costs_count": 1}]|]
{ matchHeaders = [matchContentTypeJson] }
it "works on nested resources grouped by spreaded fields" $ do
get "/process_supervisor?select=...processes(factory_id,...process_costs(count()))&order=processes(factory_id)" `shouldRespondWith`
[json|[
{"factory_id": 1, "count": 2},
{"factory_id": 2, "count": 4},
{"factory_id": 3, "count": 1}]|]
{ matchHeaders = [matchContentTypeJson] }
get "/process_supervisor?select=...processes(factory:factory_id,...process_costs(process_costs_count:count()))&order=processes(factory)" `shouldRespondWith`
[json|[
{"factory": 1, "process_costs_count": 2},
{"factory": 2, "process_costs_count": 4},
{"factory": 3, "process_costs_count": 1}]|]
{ matchHeaders = [matchContentTypeJson] }
it "works on different levels of the nested resources at the same time" $
get "/process_supervisor?select=...processes(factory:factory_id,processes_count:count(),...process_costs(process_costs_count:count()))&order=processes(factory)" `shouldRespondWith`
[json|[
{"factory": 1, "processes_count": 2, "process_costs_count": 2},
{"factory": 2, "processes_count": 4, "process_costs_count": 4},
{"factory": 3, "processes_count": 2, "process_costs_count": 1}]|]
{ matchHeaders = [matchContentTypeJson] }

disallowed :: SpecWith ((), Application)
disallowed =
describe "attempting to use an aggregate when aggregate functions are disallowed" $ do
Expand Down
2 changes: 2 additions & 0 deletions test/spec/fixtures/data.sql
Original file line number Diff line number Diff line change
Expand Up @@ -900,6 +900,7 @@ INSERT INTO processes VALUES (2, 'Process A2', 1, 2);
INSERT INTO processes VALUES (3, 'Process B1', 2, 1);
INSERT INTO processes VALUES (4, 'Process B2', 2, 1);
INSERT INTO processes VALUES (5, 'Process C1', 3, 2);
INSERT INTO processes VALUES (6, 'Process C2', 3, 2);

TRUNCATE TABLE process_costs CASCADE;
INSERT INTO process_costs VALUES (1, 150.00);
Expand All @@ -922,3 +923,4 @@ INSERT INTO process_supervisor VALUES (3, 4);
INSERT INTO process_supervisor VALUES (4, 1);
INSERT INTO process_supervisor VALUES (4, 2);
INSERT INTO process_supervisor VALUES (5, 3);
INSERT INTO process_supervisor VALUES (6, 3);

0 comments on commit e44fbdd

Please sign in to comment.