Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 28 additions & 11 deletions superset/common/query_context_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@

import logging
import re
from typing import Any, cast, ClassVar, TYPE_CHECKING
from typing import Any, cast, ClassVar, Sequence, TYPE_CHECKING

import pandas as pd
from flask import current_app
Expand Down Expand Up @@ -251,9 +251,13 @@ def get_data(

return df.to_dict(orient="records")

def ensure_totals_available(self) -> None:
queries_needing_totals = []
totals_queries = []
def _prepare_contribution_totals(self) -> tuple[list[int], int | None]:
"""
Identify contribution queries and normalize the totals query so cache keys
align with cached results.
"""
queries_needing_totals: list[int] = []
totals_idx: int | None = None

for i, query in enumerate(self._query_context.queries):
needs_totals = any(
Expand All @@ -267,17 +271,28 @@ def ensure_totals_available(self) -> None:
is_totals_query = (
not query.columns and query.metrics and not query.post_processing
)
if is_totals_query:
totals_queries.append(i)
if is_totals_query and totals_idx is None:
totals_idx = i

if queries_needing_totals and totals_idx is not None:
totals_query = self._query_context.queries[totals_idx]
totals_query.row_limit = None
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I understand this is outside the scope of this PR (we were already doing this previously), but do you know why the query arrives here with a row_limit set? I would expect the client/chart to send the query correctly.


return queries_needing_totals, totals_idx

if not queries_needing_totals or not totals_queries:
def ensure_totals_available(
self,
queries_needing_totals: Sequence[int] | None = None,
totals_idx: int | None = None,
) -> None:
if queries_needing_totals is None or totals_idx is None:
queries_needing_totals, totals_idx = self._prepare_contribution_totals()

if not queries_needing_totals or totals_idx is None:
return

totals_idx = totals_queries[0]
totals_query = self._query_context.queries[totals_idx]

totals_query.row_limit = None

result = self._query_context.get_query_result(totals_query)
df = result.df

Expand All @@ -299,10 +314,12 @@ def get_payload(
) -> dict[str, Any]:
"""Returns the query results with both metadata and data"""

queries_needing_totals, totals_idx = self._prepare_contribution_totals()

# Skip ensure_totals_available when force_cached=True
# This prevents recalculating contribution_totals from cached results
if not force_cached:
self.ensure_totals_available()
self.ensure_totals_available(queries_needing_totals, totals_idx)

# Update cache_values to reflect modifications made by
# ensure_totals_available()
Expand Down
98 changes: 97 additions & 1 deletion tests/unit_tests/common/test_query_context_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,15 @@
# specific language governing permissions and limitations
# under the License.

from typing import Any
from unittest.mock import MagicMock, patch

import numpy as np
import pandas as pd
import pytest

from superset.common.chart_data import ChartDataResultFormat
from superset.common.chart_data import ChartDataResultFormat, ChartDataResultType
from superset.common.db_query_status import QueryStatus
from superset.common.query_context_processor import QueryContextProcessor
from superset.utils.core import GenericDataType

Expand Down Expand Up @@ -1217,3 +1219,97 @@ def test_cache_key_non_contribution_post_processing_unchanged():
assert query1.cache_key() != query2.cache_key(), (
"Cache keys should differ for different non-contribution post_processing"
)


def test_force_cached_normalizes_totals_query_row_limit():
    """
    When fetching from cache (force_cached=True), the totals query should still be
    normalized so its cache key matches the cached entry, but the totals query should
    not be executed.
    """
    # Imported locally so the rest of this test module does not pay for (or
    # depend on) this import when other tests run.
    from superset.common.query_object import QueryObject

    # Datasource stub: only the attributes read during cache-key computation
    # and payload assembly are configured here.
    mock_datasource = MagicMock()
    mock_datasource.uid = "test_datasource"
    mock_datasource.column_names = ["region", "sales"]
    mock_datasource.cache_timeout = None
    mock_datasource.changed_on = None
    mock_datasource.get_extra_cache_keys.return_value = []
    mock_datasource.database.extra = "{}"
    mock_datasource.database.impersonate_user = False
    mock_datasource.database.db_engine_spec.get_impersonation_key.return_value = None

    # A "totals" query: no columns, metrics only, no post-processing. Its
    # non-None row_limit is what the normalization under test should clear.
    totals_query = QueryObject(
        datasource=mock_datasource,
        columns=[],
        metrics=["sales"],
        row_limit=1000,
    )
    # The main query carries a contribution post-processing step, which is
    # what marks it as needing totals.
    main_query = QueryObject(
        datasource=mock_datasource,
        columns=["region"],
        metrics=["sales"],
        row_limit=1000,
        post_processing=[{"operation": "contribution", "options": {}}],
    )

    # Skip QueryObject validation; these stubs are not fully-formed queries.
    totals_query.validate = MagicMock()
    main_query.validate = MagicMock()

    # Record the totals query's row_limit at the moment its cache key is
    # computed, to prove normalization happens before the cache lookup.
    captured_limits: list[int | None] = []

    def totals_cache_key(**kwargs: Any) -> str:
        captured_limits.append(totals_query.row_limit)
        return "totals-cache-key"

    totals_query.cache_key = totals_cache_key
    main_query.cache_key = lambda **kwargs: "main-cache-key"

    # Query-context stub wired with both queries; get_query_result is a plain
    # MagicMock so we can assert later that it was never invoked.
    mock_query_context = MagicMock()
    mock_query_context.force = False
    mock_query_context.datasource = mock_datasource
    mock_query_context.queries = [main_query, totals_query]
    mock_query_context.result_type = ChartDataResultType.FULL
    mock_query_context.result_format = ChartDataResultFormat.JSON
    mock_query_context.cache_values = {
        "queries": [main_query.to_dict(), totals_query.to_dict()]
    }
    mock_query_context.get_query_result = MagicMock()

    processor = QueryContextProcessor(mock_query_context)
    processor._qc_datasource = mock_datasource
    # Route the context's payload helpers through the real processor so the
    # code under test executes its actual logic against the mocked context.
    mock_query_context.get_df_payload = processor.get_df_payload
    mock_query_context.get_data = processor.get_data

    with patch(
        "superset.common.query_context_processor.security_manager"
    ) as mock_security_manager:
        mock_security_manager.get_rls_cache_key.return_value = None

        with patch(
            "superset.common.query_context_processor.QueryCacheManager"
        ) as mock_cache_manager:

            # Every cache lookup reports a hit with a one-row DataFrame, so
            # no query should ever be executed against the datasource.
            def cache_get(*args: Any, **kwargs: Any) -> Any:
                df = pd.DataFrame({"region": ["North"], "sales": [100]})
                cache = MagicMock()
                cache.is_loaded = True
                cache.df = df
                cache.query = "SELECT 1"
                cache.error_message = None
                cache.status = QueryStatus.SUCCESS
                cache.applied_template_filters = []
                cache.applied_filter_columns = []
                cache.rejected_filter_columns = []
                cache.annotation_data = {}
                cache.is_cached = True
                cache.sql_rowcount = len(df)
                cache.cache_dttm = "2024-01-01T00:00:00"
                return cache

            mock_cache_manager.get.side_effect = cache_get

            processor.get_payload(cache_query_context=False, force_cached=True)

    # row_limit must already be None when the totals cache key is computed,
    # and the totals query must never have been executed.
    assert captured_limits == [None], "Totals query should be normalized before caching"
    mock_query_context.get_query_result.assert_not_called()
Loading