|
15 | 15 | # specific language governing permissions and limitations |
16 | 16 | # under the License. |
17 | 17 |
|
| 18 | +from typing import Any |
18 | 19 | from unittest.mock import MagicMock, patch |
19 | 20 |
|
20 | 21 | import numpy as np |
21 | 22 | import pandas as pd |
22 | 23 | import pytest |
23 | 24 |
|
24 | | -from superset.common.chart_data import ChartDataResultFormat |
| 25 | +from superset.common.chart_data import ChartDataResultFormat, ChartDataResultType |
| 26 | +from superset.common.db_query_status import QueryStatus |
25 | 27 | from superset.common.query_context_processor import QueryContextProcessor |
26 | 28 | from superset.utils.core import GenericDataType |
27 | 29 |
|
@@ -1217,3 +1219,97 @@ def test_cache_key_non_contribution_post_processing_unchanged(): |
1217 | 1219 | assert query1.cache_key() != query2.cache_key(), ( |
1218 | 1220 | "Cache keys should differ for different non-contribution post_processing" |
1219 | 1221 | ) |
| 1222 | + |
| 1223 | + |
def test_force_cached_normalizes_totals_query_row_limit():
    """
    Fetching from cache with force_cached=True must still normalize the totals
    query (row_limit -> None) so its cache key matches the stored entry, while
    never actually executing the totals query against the database.
    """
    # Imported locally to mirror the module's lazy-import convention for
    # QueryObject inside tests.
    from superset.common.query_object import QueryObject

    datasource = MagicMock()
    datasource.uid = "test_datasource"
    datasource.column_names = ["region", "sales"]
    datasource.cache_timeout = None
    datasource.changed_on = None
    datasource.get_extra_cache_keys.return_value = []
    datasource.database.extra = "{}"
    datasource.database.impersonate_user = False
    datasource.database.db_engine_spec.get_impersonation_key.return_value = None

    totals_query = QueryObject(
        datasource=datasource,
        columns=[],
        metrics=["sales"],
        row_limit=1000,
    )
    main_query = QueryObject(
        datasource=datasource,
        columns=["region"],
        metrics=["sales"],
        row_limit=1000,
        post_processing=[{"operation": "contribution", "options": {}}],
    )

    # Validation hits the (mocked) datasource; stub it out for both queries.
    totals_query.validate = MagicMock()
    main_query.validate = MagicMock()

    # Record the totals query's row_limit at the moment its cache key is
    # computed — this is what the normalization assertion below inspects.
    observed_row_limits: list[int | None] = []

    def record_totals_cache_key(**_kwargs: Any) -> str:
        observed_row_limits.append(totals_query.row_limit)
        return "totals-cache-key"

    totals_query.cache_key = record_totals_cache_key
    main_query.cache_key = lambda **kwargs: "main-cache-key"

    query_context = MagicMock()
    query_context.force = False
    query_context.datasource = datasource
    query_context.queries = [main_query, totals_query]
    query_context.result_type = ChartDataResultType.FULL
    query_context.result_format = ChartDataResultFormat.JSON
    query_context.cache_values = {
        "queries": [main_query.to_dict(), totals_query.to_dict()]
    }
    query_context.get_query_result = MagicMock()

    processor = QueryContextProcessor(query_context)
    processor._qc_datasource = datasource
    # Route the context's payload helpers through the real processor so the
    # code under test drives the full get_payload -> get_df_payload path.
    query_context.get_df_payload = processor.get_df_payload
    query_context.get_data = processor.get_data

    def loaded_cache_entry(*_args: Any, **_kwargs: Any) -> Any:
        # Every cache lookup returns a fully-loaded entry, so no query runs.
        frame = pd.DataFrame({"region": ["North"], "sales": [100]})
        return MagicMock(
            is_loaded=True,
            df=frame,
            query="SELECT 1",
            error_message=None,
            status=QueryStatus.SUCCESS,
            applied_template_filters=[],
            applied_filter_columns=[],
            rejected_filter_columns=[],
            annotation_data={},
            is_cached=True,
            sql_rowcount=len(frame),
            cache_dttm="2024-01-01T00:00:00",
        )

    with (
        patch(
            "superset.common.query_context_processor.security_manager"
        ) as security_manager_mock,
        patch(
            "superset.common.query_context_processor.QueryCacheManager"
        ) as cache_manager_mock,
    ):
        security_manager_mock.get_rls_cache_key.return_value = None
        cache_manager_mock.get.side_effect = loaded_cache_entry

        processor.get_payload(cache_query_context=False, force_cached=True)

    assert observed_row_limits == [None], "Totals query should be normalized before caching"
    query_context.get_query_result.assert_not_called()
0 commit comments