Skip to content

Commit 2d39e7f

Browse files
committed
fix: normalize totals cache keys for async hits
1 parent 0c87034 commit 2d39e7f

File tree

2 files changed

+279
-12
lines changed

2 files changed

+279
-12
lines changed

superset/common/query_context_processor.py

Lines changed: 31 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818

1919
import logging
2020
import re
21-
from typing import Any, cast, ClassVar, TYPE_CHECKING
21+
from typing import Any, cast, ClassVar, Sequence, TYPE_CHECKING
2222

2323
import pandas as pd
2424
from flask import current_app
@@ -251,9 +251,13 @@ def get_data(
251251

252252
return df.to_dict(orient="records")
253253

254-
def ensure_totals_available(self) -> None:
255-
queries_needing_totals = []
256-
totals_queries = []
254+
def _prepare_contribution_totals(self) -> tuple[list[int], int | None]:
255+
"""
256+
Identify contribution queries and normalize the totals query so cache keys
257+
align with cached results.
258+
"""
259+
queries_needing_totals: list[int] = []
260+
totals_idx: int | None = None
257261

258262
for i, query in enumerate(self._query_context.queries):
259263
needs_totals = any(
@@ -267,17 +271,28 @@ def ensure_totals_available(self) -> None:
267271
is_totals_query = (
268272
not query.columns and query.metrics and not query.post_processing
269273
)
270-
if is_totals_query:
271-
totals_queries.append(i)
274+
if is_totals_query and totals_idx is None:
275+
totals_idx = i
276+
277+
if totals_idx is not None:
278+
totals_query = self._query_context.queries[totals_idx]
279+
totals_query.row_limit = None
280+
281+
return queries_needing_totals, totals_idx
272282

273-
if not queries_needing_totals or not totals_queries:
283+
def ensure_totals_available(
284+
self,
285+
queries_needing_totals: Sequence[int] | None = None,
286+
totals_idx: int | None = None,
287+
) -> None:
288+
if queries_needing_totals is None or totals_idx is None:
289+
queries_needing_totals, totals_idx = self._prepare_contribution_totals()
290+
291+
if not queries_needing_totals or totals_idx is None:
274292
return
275293

276-
totals_idx = totals_queries[0]
277294
totals_query = self._query_context.queries[totals_idx]
278295

279-
totals_query.row_limit = None
280-
281296
result = self._query_context.get_query_result(totals_query)
282297
df = result.df
283298

@@ -299,7 +314,12 @@ def get_payload(
299314
) -> dict[str, Any]:
300315
"""Returns the query results with both metadata and data"""
301316

302-
self.ensure_totals_available()
317+
queries_needing_totals, totals_idx = self._prepare_contribution_totals()
318+
319+
# Skip ensure_totals_available when force_cached=True
320+
# This prevents recalculating contribution_totals from cached results
321+
if not force_cached:
322+
self.ensure_totals_available(queries_needing_totals, totals_idx)
303323

304324
# Update cache_values to reflect modifications made by ensure_totals_available()
305325
# This ensures cache keys are generated from the actual query state

tests/unit_tests/common/test_query_context_processor.py

Lines changed: 248 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,13 +15,15 @@
1515
# specific language governing permissions and limitations
1616
# under the License.
1717

18+
from typing import Any
1819
from unittest.mock import MagicMock, patch
1920

2021
import numpy as np
2122
import pandas as pd
2223
import pytest
2324

24-
from superset.common.chart_data import ChartDataResultFormat
25+
from superset.common.chart_data import ChartDataResultFormat, ChartDataResultType
26+
from superset.common.db_query_status import QueryStatus
2527
from superset.common.query_context_processor import QueryContextProcessor
2628
from superset.utils.core import GenericDataType
2729

@@ -1066,3 +1068,248 @@ def test_cache_values_sync_after_ensure_totals_available():
10661068
# Verify that the main query row_limit is still 1000 (only totals query
10671069
# should be modified)
10681070
assert updated_cache_queries[0]["row_limit"] == 1000
1071+
1072+
1073+
def test_cache_key_excludes_contribution_totals():
1074+
"""
1075+
Test that cache_key() excludes contribution_totals from post_processing.
1076+
1077+
contribution_totals is computed at runtime by ensure_totals_available() and
1078+
varies per request. Including it in the cache key would cause mismatches
1079+
between workers that compute different totals for the same query.
1080+
"""
1081+
from superset.common.query_object import QueryObject
1082+
1083+
mock_datasource = MagicMock()
1084+
mock_datasource.uid = "test_datasource"
1085+
mock_datasource.database.extra = "{}"
1086+
mock_datasource.get_extra_cache_keys.return_value = []
1087+
1088+
# Create query with contribution post-processing that includes contribution_totals
1089+
query_with_totals = QueryObject(
1090+
datasource=mock_datasource,
1091+
columns=["region"],
1092+
metrics=["sales", "profit"],
1093+
post_processing=[
1094+
{
1095+
"operation": "contribution",
1096+
"options": {
1097+
"columns": ["sales", "profit"],
1098+
"rename_columns": ["%sales", "%profit"],
1099+
"contribution_totals": {"sales": 1000.0, "profit": 200.0},
1100+
},
1101+
}
1102+
],
1103+
)
1104+
1105+
# Create identical query without contribution_totals
1106+
query_without_totals = QueryObject(
1107+
datasource=mock_datasource,
1108+
columns=["region"],
1109+
metrics=["sales", "profit"],
1110+
post_processing=[
1111+
{
1112+
"operation": "contribution",
1113+
"options": {
1114+
"columns": ["sales", "profit"],
1115+
"rename_columns": ["%sales", "%profit"],
1116+
},
1117+
}
1118+
],
1119+
)
1120+
1121+
# Cache keys should be identical since contribution_totals is excluded
1122+
cache_key_with = query_with_totals.cache_key()
1123+
cache_key_without = query_without_totals.cache_key()
1124+
1125+
assert cache_key_with == cache_key_without, (
1126+
"Cache keys should match regardless of contribution_totals. "
1127+
f"With totals: {cache_key_with}, Without totals: {cache_key_without}"
1128+
)
1129+
1130+
1131+
def test_cache_key_preserves_other_post_processing_options():
1132+
"""
1133+
Test that cache_key() only excludes contribution_totals, not other options.
1134+
"""
1135+
from superset.common.query_object import QueryObject
1136+
1137+
mock_datasource = MagicMock()
1138+
mock_datasource.uid = "test_datasource"
1139+
mock_datasource.database.extra = "{}"
1140+
mock_datasource.get_extra_cache_keys.return_value = []
1141+
1142+
# Create query with contribution post-processing
1143+
query1 = QueryObject(
1144+
datasource=mock_datasource,
1145+
columns=["region"],
1146+
metrics=["sales"],
1147+
post_processing=[
1148+
{
1149+
"operation": "contribution",
1150+
"options": {
1151+
"columns": ["sales"],
1152+
"rename_columns": ["%sales"],
1153+
"contribution_totals": {"sales": 1000.0},
1154+
},
1155+
}
1156+
],
1157+
)
1158+
1159+
# Create query with different rename_columns
1160+
query2 = QueryObject(
1161+
datasource=mock_datasource,
1162+
columns=["region"],
1163+
metrics=["sales"],
1164+
post_processing=[
1165+
{
1166+
"operation": "contribution",
1167+
"options": {
1168+
"columns": ["sales"],
1169+
"rename_columns": ["%sales_pct"], # Different!
1170+
"contribution_totals": {"sales": 1000.0},
1171+
},
1172+
}
1173+
],
1174+
)
1175+
1176+
# Cache keys should differ because rename_columns is different
1177+
assert query1.cache_key() != query2.cache_key(), (
1178+
"Cache keys should differ when other post_processing options differ"
1179+
)
1180+
1181+
1182+
def test_cache_key_non_contribution_post_processing_unchanged():
1183+
"""
1184+
Test that non-contribution post_processing operations are unchanged in cache key.
1185+
"""
1186+
from superset.common.query_object import QueryObject
1187+
1188+
mock_datasource = MagicMock()
1189+
mock_datasource.uid = "test_datasource"
1190+
mock_datasource.database.extra = "{}"
1191+
mock_datasource.get_extra_cache_keys.return_value = []
1192+
1193+
# Create query with non-contribution post-processing
1194+
query1 = QueryObject(
1195+
datasource=mock_datasource,
1196+
columns=["region"],
1197+
metrics=["sales"],
1198+
post_processing=[
1199+
{
1200+
"operation": "pivot",
1201+
"options": {"columns": ["region"], "aggregates": {"sales": "sum"}},
1202+
}
1203+
],
1204+
)
1205+
1206+
query2 = QueryObject(
1207+
datasource=mock_datasource,
1208+
columns=["region"],
1209+
metrics=["sales"],
1210+
post_processing=[
1211+
{
1212+
"operation": "pivot",
1213+
"options": {"columns": ["region"], "aggregates": {"sales": "mean"}},
1214+
}
1215+
],
1216+
)
1217+
1218+
# Cache keys should differ because aggregates option is different
1219+
assert query1.cache_key() != query2.cache_key(), (
1220+
"Cache keys should differ for different non-contribution post_processing"
1221+
)
1222+
1223+
1224+
def test_force_cached_normalizes_totals_query_row_limit():
1225+
"""
1226+
When fetching from cache (force_cached=True), the totals query should still be
1227+
normalized so its cache key matches the cached entry, but the totals query should
1228+
not be executed.
1229+
"""
1230+
from superset.common.query_object import QueryObject
1231+
1232+
mock_datasource = MagicMock()
1233+
mock_datasource.uid = "test_datasource"
1234+
mock_datasource.column_names = ["region", "sales"]
1235+
mock_datasource.cache_timeout = None
1236+
mock_datasource.changed_on = None
1237+
mock_datasource.get_extra_cache_keys.return_value = []
1238+
mock_datasource.database.extra = "{}"
1239+
mock_datasource.database.impersonate_user = False
1240+
mock_datasource.database.db_engine_spec.get_impersonation_key.return_value = None
1241+
1242+
totals_query = QueryObject(
1243+
datasource=mock_datasource,
1244+
columns=[],
1245+
metrics=["sales"],
1246+
row_limit=1000,
1247+
)
1248+
main_query = QueryObject(
1249+
datasource=mock_datasource,
1250+
columns=["region"],
1251+
metrics=["sales"],
1252+
row_limit=1000,
1253+
post_processing=[{"operation": "contribution", "options": {}}],
1254+
)
1255+
1256+
totals_query.validate = MagicMock()
1257+
main_query.validate = MagicMock()
1258+
1259+
captured_limits: list[int | None] = []
1260+
1261+
def totals_cache_key(**kwargs: Any) -> str:
1262+
captured_limits.append(totals_query.row_limit)
1263+
return "totals-cache-key"
1264+
1265+
totals_query.cache_key = totals_cache_key
1266+
main_query.cache_key = lambda **kwargs: "main-cache-key"
1267+
1268+
mock_query_context = MagicMock()
1269+
mock_query_context.force = False
1270+
mock_query_context.datasource = mock_datasource
1271+
mock_query_context.queries = [main_query, totals_query]
1272+
mock_query_context.result_type = ChartDataResultType.FULL
1273+
mock_query_context.result_format = ChartDataResultFormat.JSON
1274+
mock_query_context.cache_values = {
1275+
"queries": [main_query.to_dict(), totals_query.to_dict()]
1276+
}
1277+
mock_query_context.get_query_result = MagicMock()
1278+
1279+
processor = QueryContextProcessor(mock_query_context)
1280+
processor._qc_datasource = mock_datasource
1281+
mock_query_context.get_df_payload = processor.get_df_payload
1282+
mock_query_context.get_data = processor.get_data
1283+
1284+
with patch(
1285+
"superset.common.query_context_processor.security_manager"
1286+
) as mock_security_manager:
1287+
mock_security_manager.get_rls_cache_key.return_value = None
1288+
1289+
with patch(
1290+
"superset.common.query_context_processor.QueryCacheManager"
1291+
) as mock_cache_manager:
1292+
1293+
def cache_get(*args: Any, **kwargs: Any) -> Any:
1294+
df = pd.DataFrame({"region": ["North"], "sales": [100]})
1295+
cache = MagicMock()
1296+
cache.is_loaded = True
1297+
cache.df = df
1298+
cache.query = "SELECT 1"
1299+
cache.error_message = None
1300+
cache.status = QueryStatus.SUCCESS
1301+
cache.applied_template_filters = []
1302+
cache.applied_filter_columns = []
1303+
cache.rejected_filter_columns = []
1304+
cache.annotation_data = {}
1305+
cache.is_cached = True
1306+
cache.sql_rowcount = len(df)
1307+
cache.cache_dttm = "2024-01-01T00:00:00"
1308+
return cache
1309+
1310+
mock_cache_manager.get.side_effect = cache_get
1311+
1312+
processor.get_payload(cache_query_context=False, force_cached=True)
1313+
1314+
assert captured_limits == [None], "Totals query should be normalized before caching"
1315+
mock_query_context.get_query_result.assert_not_called()

0 commit comments

Comments (0)