Commit c1e31bc

fix: normalize totals cache keys for async hits
1 parent cf88551 · commit c1e31bc

2 files changed: +125 / -12 lines
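
Why the normalization matters for cache keys: the totals query's `row_limit` is set to `None` before the query is executed and its result cached, so any later lookup must compute the cache key from the same normalized query object. If the async/`force_cached` path skips that normalization, it hashes a different query spec and misses the cached entry. A minimal, hypothetical sketch of the mismatch (the `cache_key` helper below is illustrative only, not Superset's actual `QueryObject.cache_key()`):

```python
import hashlib
import json
from typing import Any


def cache_key(query: dict[str, Any]) -> str:
    """Hypothetical stand-in for a query cache key: hash the query spec."""
    return hashlib.md5(json.dumps(query, sort_keys=True).encode()).hexdigest()


totals = {"columns": [], "metrics": ["sales"], "row_limit": 1000}

# Key used when the result was written (row_limit normalized before execution).
key_written = cache_key({**totals, "row_limit": None})

# Key computed on a later force_cached / async read without normalization.
key_read = cache_key(totals)

assert key_written != key_read  # the lookup misses the cached entry
```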

superset/common/query_context_processor.py (28 additions, 11 deletions)
@@ -18,7 +18,7 @@
 
 import logging
 import re
-from typing import Any, cast, ClassVar, TYPE_CHECKING
+from typing import Any, cast, ClassVar, Sequence, TYPE_CHECKING
 
 import pandas as pd
 from flask import current_app
@@ -251,9 +251,13 @@ def get_data(
 
         return df.to_dict(orient="records")
 
-    def ensure_totals_available(self) -> None:
-        queries_needing_totals = []
-        totals_queries = []
+    def _prepare_contribution_totals(self) -> tuple[list[int], int | None]:
+        """
+        Identify contribution queries and normalize the totals query so cache keys
+        align with cached results.
+        """
+        queries_needing_totals: list[int] = []
+        totals_idx: int | None = None
 
         for i, query in enumerate(self._query_context.queries):
             needs_totals = any(
@@ -267,17 +271,28 @@ def ensure_totals_available(self) -> None:
             is_totals_query = (
                 not query.columns and query.metrics and not query.post_processing
             )
-            if is_totals_query:
-                totals_queries.append(i)
+            if is_totals_query and totals_idx is None:
+                totals_idx = i
+
+        if queries_needing_totals and totals_idx is not None:
+            totals_query = self._query_context.queries[totals_idx]
+            totals_query.row_limit = None
+
+        return queries_needing_totals, totals_idx
 
-        if not queries_needing_totals or not totals_queries:
+    def ensure_totals_available(
+        self,
+        queries_needing_totals: Sequence[int] | None = None,
+        totals_idx: int | None = None,
+    ) -> None:
+        if queries_needing_totals is None or totals_idx is None:
+            queries_needing_totals, totals_idx = self._prepare_contribution_totals()
+
+        if not queries_needing_totals or totals_idx is None:
             return
 
-        totals_idx = totals_queries[0]
         totals_query = self._query_context.queries[totals_idx]
 
-        totals_query.row_limit = None
-
         result = self._query_context.get_query_result(totals_query)
         df = result.df
 
@@ -299,10 +314,12 @@ def get_payload(
     ) -> dict[str, Any]:
        """Returns the query results with both metadata and data"""
 
+        queries_needing_totals, totals_idx = self._prepare_contribution_totals()
+
         # Skip ensure_totals_available when force_cached=True
         # This prevents recalculating contribution_totals from cached results
         if not force_cached:
-            self.ensure_totals_available()
+            self.ensure_totals_available(queries_needing_totals, totals_idx)
 
         # Update cache_values to reflect modifications made by
         # ensure_totals_available()
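
The pattern adopted above splits the old `ensure_totals_available()` into a preparation step that always runs and only mutates the in-memory query spec, and an execution step that is skipped when serving from cache. The following is a self-contained sketch of that split under hypothetical `Query`/`Processor` stand-ins for `QueryObject`/`QueryContextProcessor`; it mirrors the structure, not Superset's actual code:

```python
from dataclasses import dataclass, field
from typing import Any


@dataclass
class Query:
    columns: list[str]
    metrics: list[str]
    post_processing: list[dict[str, Any]] = field(default_factory=list)
    row_limit: int | None = 1000


@dataclass
class Processor:
    queries: list[Query]
    executed: list[int] = field(default_factory=list)

    def prepare_totals(self) -> tuple[list[int], int | None]:
        """Identify contribution queries and normalize the totals query; no execution."""
        needing = [
            i
            for i, q in enumerate(self.queries)
            if any(pp.get("operation") == "contribution" for pp in q.post_processing)
        ]
        totals_idx = next(
            (
                i
                for i, q in enumerate(self.queries)
                if not q.columns and q.metrics and not q.post_processing
            ),
            None,
        )
        if needing and totals_idx is not None:
            self.queries[totals_idx].row_limit = None  # keeps the cache key stable
        return needing, totals_idx

    def ensure_totals(self, needing: list[int], totals_idx: int | None) -> None:
        """Execute the totals query only when a contribution query needs it."""
        if not needing or totals_idx is None:
            return
        self.executed.append(totals_idx)  # stand-in for get_query_result()

    def get_payload(self, force_cached: bool = False) -> None:
        # Normalization runs on every path so cache keys always line up ...
        needing, totals_idx = self.prepare_totals()
        # ... but the totals query is only executed when not serving from cache.
        if not force_cached:
            self.ensure_totals(needing, totals_idx)


proc = Processor([
    Query(["region"], ["sales"], [{"operation": "contribution"}]),
    Query([], ["sales"]),
])
proc.get_payload(force_cached=True)
assert proc.queries[1].row_limit is None  # totals query normalized ...
assert proc.executed == []                # ... but never executed
```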

tests/unit_tests/common/test_query_context_processor.py (97 additions, 1 deletion)
@@ -15,13 +15,15 @@
 # specific language governing permissions and limitations
 # under the License.
 
+from typing import Any
 from unittest.mock import MagicMock, patch
 
 import numpy as np
 import pandas as pd
 import pytest
 
-from superset.common.chart_data import ChartDataResultFormat
+from superset.common.chart_data import ChartDataResultFormat, ChartDataResultType
+from superset.common.db_query_status import QueryStatus
 from superset.common.query_context_processor import QueryContextProcessor
 from superset.utils.core import GenericDataType
 
@@ -1217,3 +1219,97 @@ def test_cache_key_non_contribution_post_processing_unchanged():
     assert query1.cache_key() != query2.cache_key(), (
         "Cache keys should differ for different non-contribution post_processing"
     )
+
+
+def test_force_cached_normalizes_totals_query_row_limit():
+    """
+    When fetching from cache (force_cached=True), the totals query should still be
+    normalized so its cache key matches the cached entry, but the totals query should
+    not be executed.
+    """
+    from superset.common.query_object import QueryObject
+
+    mock_datasource = MagicMock()
+    mock_datasource.uid = "test_datasource"
+    mock_datasource.column_names = ["region", "sales"]
+    mock_datasource.cache_timeout = None
+    mock_datasource.changed_on = None
+    mock_datasource.get_extra_cache_keys.return_value = []
+    mock_datasource.database.extra = "{}"
+    mock_datasource.database.impersonate_user = False
+    mock_datasource.database.db_engine_spec.get_impersonation_key.return_value = None
+
+    totals_query = QueryObject(
+        datasource=mock_datasource,
+        columns=[],
+        metrics=["sales"],
+        row_limit=1000,
+    )
+    main_query = QueryObject(
+        datasource=mock_datasource,
+        columns=["region"],
+        metrics=["sales"],
+        row_limit=1000,
+        post_processing=[{"operation": "contribution", "options": {}}],
+    )
+
+    totals_query.validate = MagicMock()
+    main_query.validate = MagicMock()
+
+    captured_limits: list[int | None] = []
+
+    def totals_cache_key(**kwargs: Any) -> str:
+        captured_limits.append(totals_query.row_limit)
+        return "totals-cache-key"
+
+    totals_query.cache_key = totals_cache_key
+    main_query.cache_key = lambda **kwargs: "main-cache-key"
+
+    mock_query_context = MagicMock()
+    mock_query_context.force = False
+    mock_query_context.datasource = mock_datasource
+    mock_query_context.queries = [main_query, totals_query]
+    mock_query_context.result_type = ChartDataResultType.FULL
+    mock_query_context.result_format = ChartDataResultFormat.JSON
+    mock_query_context.cache_values = {
+        "queries": [main_query.to_dict(), totals_query.to_dict()]
+    }
+    mock_query_context.get_query_result = MagicMock()
+
+    processor = QueryContextProcessor(mock_query_context)
+    processor._qc_datasource = mock_datasource
+    mock_query_context.get_df_payload = processor.get_df_payload
+    mock_query_context.get_data = processor.get_data
+
+    with patch(
+        "superset.common.query_context_processor.security_manager"
+    ) as mock_security_manager:
+        mock_security_manager.get_rls_cache_key.return_value = None
+
+        with patch(
+            "superset.common.query_context_processor.QueryCacheManager"
+        ) as mock_cache_manager:
+
+            def cache_get(*args: Any, **kwargs: Any) -> Any:
+                df = pd.DataFrame({"region": ["North"], "sales": [100]})
+                cache = MagicMock()
+                cache.is_loaded = True
+                cache.df = df
+                cache.query = "SELECT 1"
+                cache.error_message = None
+                cache.status = QueryStatus.SUCCESS
+                cache.applied_template_filters = []
+                cache.applied_filter_columns = []
+                cache.rejected_filter_columns = []
+                cache.annotation_data = {}
+                cache.is_cached = True
+                cache.sql_rowcount = len(df)
+                cache.cache_dttm = "2024-01-01T00:00:00"
+                return cache
+
+            mock_cache_manager.get.side_effect = cache_get
+
+            processor.get_payload(cache_query_context=False, force_cached=True)
+
+    assert captured_limits == [None], "Totals query should be normalized before caching"
+    mock_query_context.get_query_result.assert_not_called()
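
The regression test drives `get_payload(force_cached=True)` against a fully mocked `QueryCacheManager` and checks two things: the totals query's `row_limit` has already been normalized to `None` by the time its cache key is computed, and `get_query_result` is never called on the cached path. It should run in isolation with something like `pytest tests/unit_tests/common/test_query_context_processor.py -k force_cached_normalizes_totals_query_row_limit` (the exact invocation may depend on the repository's test setup).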
