|
15 | 15 | # specific language governing permissions and limitations |
16 | 16 | # under the License. |
17 | 17 |
|
| 18 | +from typing import Any |
18 | 19 | from unittest.mock import MagicMock, patch |
19 | 20 |
|
20 | 21 | import numpy as np |
21 | 22 | import pandas as pd |
22 | 23 | import pytest |
23 | 24 |
|
24 | | -from superset.common.chart_data import ChartDataResultFormat |
| 25 | +from superset.common.chart_data import ChartDataResultFormat, ChartDataResultType |
| 26 | +from superset.common.db_query_status import QueryStatus |
25 | 27 | from superset.common.query_context_processor import QueryContextProcessor |
26 | 28 | from superset.utils.core import GenericDataType |
27 | 29 |
|
@@ -1066,3 +1068,248 @@ def test_cache_values_sync_after_ensure_totals_available(): |
1066 | 1068 | # Verify that the main query row_limit is still 1000 (only totals query |
1067 | 1069 | # should be modified) |
1068 | 1070 | assert updated_cache_queries[0]["row_limit"] == 1000 |
| 1071 | + |
| 1072 | + |
def test_cache_key_excludes_contribution_totals():
    """
    Verify that cache_key() ignores contribution_totals in post_processing.

    contribution_totals is computed at runtime by ensure_totals_available() and
    varies per request, so including it in the cache key would cause mismatches
    between workers that compute different totals for the same query.
    """
    from superset.common.query_object import QueryObject

    datasource = MagicMock()
    datasource.uid = "test_datasource"
    datasource.database.extra = "{}"
    datasource.get_extra_cache_keys.return_value = []

    def build_query(include_totals: bool) -> QueryObject:
        # Identical contribution queries except for the runtime-only totals.
        options = {
            "columns": ["sales", "profit"],
            "rename_columns": ["%sales", "%profit"],
        }
        if include_totals:
            options["contribution_totals"] = {"sales": 1000.0, "profit": 200.0}
        return QueryObject(
            datasource=datasource,
            columns=["region"],
            metrics=["sales", "profit"],
            post_processing=[{"operation": "contribution", "options": options}],
        )

    # Cache keys should be identical since contribution_totals is excluded
    cache_key_with = build_query(include_totals=True).cache_key()
    cache_key_without = build_query(include_totals=False).cache_key()

    assert cache_key_with == cache_key_without, (
        "Cache keys should match regardless of contribution_totals. "
        f"With totals: {cache_key_with}, Without totals: {cache_key_without}"
    )
| 1129 | + |
| 1130 | + |
def test_cache_key_preserves_other_post_processing_options():
    """
    Verify that cache_key() strips only contribution_totals and still reflects
    every other post_processing option.
    """
    from superset.common.query_object import QueryObject

    datasource = MagicMock()
    datasource.uid = "test_datasource"
    datasource.database.extra = "{}"
    datasource.get_extra_cache_keys.return_value = []

    def contribution_query(renamed_column: str) -> QueryObject:
        # Queries identical except for the rename target passed in.
        return QueryObject(
            datasource=datasource,
            columns=["region"],
            metrics=["sales"],
            post_processing=[
                {
                    "operation": "contribution",
                    "options": {
                        "columns": ["sales"],
                        "rename_columns": [renamed_column],
                        "contribution_totals": {"sales": 1000.0},
                    },
                }
            ],
        )

    first = contribution_query("%sales")
    second = contribution_query("%sales_pct")  # differs only in rename_columns

    # Cache keys should differ because rename_columns is different
    assert first.cache_key() != second.cache_key(), (
        "Cache keys should differ when other post_processing options differ"
    )
| 1180 | + |
| 1181 | + |
def test_cache_key_non_contribution_post_processing_unchanged():
    """
    Verify that non-contribution post_processing operations still feed into the
    cache key unchanged.
    """
    from superset.common.query_object import QueryObject

    datasource = MagicMock()
    datasource.uid = "test_datasource"
    datasource.database.extra = "{}"
    datasource.get_extra_cache_keys.return_value = []

    def pivot_query(aggregate: str) -> QueryObject:
        # Non-contribution (pivot) query parameterized by its aggregate.
        return QueryObject(
            datasource=datasource,
            columns=["region"],
            metrics=["sales"],
            post_processing=[
                {
                    "operation": "pivot",
                    "options": {
                        "columns": ["region"],
                        "aggregates": {"sales": aggregate},
                    },
                }
            ],
        )

    # Cache keys should differ because the aggregates option is different
    assert pivot_query("sum").cache_key() != pivot_query("mean").cache_key(), (
        "Cache keys should differ for different non-contribution post_processing"
    )
| 1222 | + |
| 1223 | + |
def test_force_cached_normalizes_totals_query_row_limit():
    """
    When fetching from cache (force_cached=True), the totals query should still
    be normalized so its cache key matches the cached entry, but the totals
    query itself must never be executed.
    """
    from superset.common.query_object import QueryObject

    datasource = MagicMock()
    datasource.uid = "test_datasource"
    datasource.column_names = ["region", "sales"]
    datasource.cache_timeout = None
    datasource.changed_on = None
    datasource.get_extra_cache_keys.return_value = []
    datasource.database.extra = "{}"
    datasource.database.impersonate_user = False
    datasource.database.db_engine_spec.get_impersonation_key.return_value = None

    totals_query = QueryObject(
        datasource=datasource,
        columns=[],
        metrics=["sales"],
        row_limit=1000,
    )
    main_query = QueryObject(
        datasource=datasource,
        columns=["region"],
        metrics=["sales"],
        row_limit=1000,
        post_processing=[{"operation": "contribution", "options": {}}],
    )
    totals_query.validate = MagicMock()
    main_query.validate = MagicMock()

    # Record the row_limit the totals query carries at cache-key time.
    observed_limits: list[int | None] = []

    def record_totals_cache_key(**kwargs: Any) -> str:
        observed_limits.append(totals_query.row_limit)
        return "totals-cache-key"

    totals_query.cache_key = record_totals_cache_key
    main_query.cache_key = lambda **kwargs: "main-cache-key"

    query_context = MagicMock()
    query_context.force = False
    query_context.datasource = datasource
    query_context.queries = [main_query, totals_query]
    query_context.result_type = ChartDataResultType.FULL
    query_context.result_format = ChartDataResultFormat.JSON
    query_context.cache_values = {
        "queries": [main_query.to_dict(), totals_query.to_dict()]
    }
    query_context.get_query_result = MagicMock()

    processor = QueryContextProcessor(query_context)
    processor._qc_datasource = datasource
    query_context.get_df_payload = processor.get_df_payload
    query_context.get_data = processor.get_data

    with patch(
        "superset.common.query_context_processor.security_manager"
    ) as security_manager_mock:
        security_manager_mock.get_rls_cache_key.return_value = None

        with patch(
            "superset.common.query_context_processor.QueryCacheManager"
        ) as cache_manager_mock:

            def fake_cache_get(*args: Any, **kwargs: Any) -> Any:
                # Simulate a fully loaded cache hit for every query.
                frame = pd.DataFrame({"region": ["North"], "sales": [100]})
                entry = MagicMock()
                entry.is_loaded = True
                entry.df = frame
                entry.query = "SELECT 1"
                entry.error_message = None
                entry.status = QueryStatus.SUCCESS
                entry.applied_template_filters = []
                entry.applied_filter_columns = []
                entry.rejected_filter_columns = []
                entry.annotation_data = {}
                entry.is_cached = True
                entry.sql_rowcount = len(frame)
                entry.cache_dttm = "2024-01-01T00:00:00"
                return entry

            cache_manager_mock.get.side_effect = fake_cache_get

            processor.get_payload(cache_query_context=False, force_cached=True)

    assert observed_limits == [None], "Totals query should be normalized before caching"
    query_context.get_query_result.assert_not_called()
0 commit comments