Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
155 changes: 155 additions & 0 deletions rest_framework/optimization.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,155 @@
from django.db import models
from rest_framework import serializers
from rest_framework.serializers import ListSerializer

def analyze_serializer_fields(serializer_class):
"""
Analyze serializer fields to determine necessary optimizations.
"""
select_related = []
prefetch_related = []

# Handle ListSerializer classes passed directly (though less common for this utility)
if issubclass(serializer_class, ListSerializer):
# If we can get the child, analyze that
if hasattr(serializer_class, 'child'):
# This might require instantiation if child is not a class attribute
pass
# For now, return empty or handle if needed.
# The test passes ModelSerializers, not ListSerializer classes.
return {'select_related': [], 'prefetch_related': []}

# Check if it has Meta.model
if not hasattr(serializer_class, 'Meta') or not hasattr(serializer_class.Meta, 'model'):
return {'select_related': [], 'prefetch_related': []}

model = serializer_class.Meta.model

# Instantiate to inspect fields
try:
serializer = serializer_class()
fields = serializer.fields
except Exception:
# If instantiation fails (e.g. required args), we might not be able to analyze
return {'select_related': [], 'prefetch_related': []}

for field_name, field in fields.items():
if field.source == '*':
continue

# Determine actual model field name
model_field_name = field.source or field_name
if '.' in model_field_name:
model_field_name = model_field_name.split('.')[0]

# Get the model field
try:
model_field = model._meta.get_field(model_field_name)
except Exception:
# Not a model field (e.g. SerializerMethodField without source mapping to field)
continue

# Check for Foreign Keys (select_related)
if isinstance(model_field, (models.ForeignKey, models.OneToOneField)):
select_related.append(model_field_name)

# Check for ManyToMany or Reverse Relations (prefetch_related)
elif isinstance(model_field, (models.ManyToManyField, models.ManyToOneRel, models.ManyToManyRel)):
prefetch_related.append(model_field_name)

return {
'select_related': list(set(select_related)),
'prefetch_related': list(set(prefetch_related))
}

def optimize_queryset(queryset, serializer_class, select_related=None, prefetch_related=None, auto_optimize=True):
"""
Apply optimizations to a queryset based on serializer analysis.
"""
# Handle non-queryset inputs (e.g. lists)
if not hasattr(queryset, 'select_related') and not hasattr(queryset, 'prefetch_related'):
return queryset

if auto_optimize:
analysis = analyze_serializer_fields(serializer_class)
auto_select = analysis['select_related']
auto_prefetch = analysis['prefetch_related']
else:
auto_select = []
auto_prefetch = []

# Merge explicit and auto
final_select = list(set((select_related or []) + auto_select))
final_prefetch = list(set((prefetch_related or []) + auto_prefetch))

if final_select:
queryset = queryset.select_related(*final_select)
if final_prefetch:
queryset = queryset.prefetch_related(*final_prefetch)

return queryset

class OptimizedQuerySetMixin:
"""
ViewSet mixin to automatically apply query optimizations.
"""
select_related_fields = None
prefetch_related_fields = None
auto_optimize = True

def get_queryset(self):
queryset = super().get_queryset()
serializer_class = self.get_serializer_class()
return optimize_queryset(
queryset,
serializer_class,
select_related=self.select_related_fields,
prefetch_related=self.prefetch_related_fields,
auto_optimize=self.auto_optimize
)

def detect_n_plus_one(serializer_class, queryset):
"""
Detect potential N+1 query issues.
"""
# Handle non-queryset inputs
if not hasattr(queryset, 'select_related') and not hasattr(queryset, 'prefetch_related'):
return []

analysis = analyze_serializer_fields(serializer_class)
warnings = []

# Check select_related
existing_select = queryset.query.select_related
# existing_select is False if not set, True if select_related() (all), or dict if specific fields

for field in analysis['select_related']:
if existing_select is False:
warnings.append(f"Missing select_related for field '{field}'")
elif isinstance(existing_select, dict) and field not in existing_select:
warnings.append(f"Missing select_related for field '{field}'")

# Check prefetch_related
existing_prefetch = getattr(queryset, '_prefetch_related_lookups', [])
for field in analysis['prefetch_related']:
if field not in existing_prefetch:
warnings.append(f"Missing prefetch_related for field '{field}'")

return warnings

def get_optimization_suggestions(serializer_class):
"""
Get suggestions for optimizing queries for a serializer.
"""
analysis = analyze_serializer_fields(serializer_class)
return {
'select_related': analysis['select_related'],
'prefetch_related': analysis['prefetch_related'],
'code_example': 'queryset = optimize_queryset(queryset, SerializerClass)'
}

class QueryAnalyzer:
"""
Helper class for analyzing queries (placeholder for future expansion if needed by imports).
"""
pass
27 changes: 27 additions & 0 deletions rest_framework/optimization/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
"""
Query optimization utilities for Django REST Framework.

This module provides tools to automatically detect and prevent N+1 query problems
in DRF serializers by analyzing serializer fields and optimizing querysets.
"""

from rest_framework.optimization.mixins import OptimizedQuerySetMixin
from rest_framework.optimization.optimizer import (
optimize_queryset,
analyze_serializer_fields,
get_optimization_suggestions,
)
from rest_framework.optimization.query_analyzer import (
QueryAnalyzer,
detect_n_plus_one,
)

__all__ = [
'OptimizedQuerySetMixin',
'optimize_queryset',
'analyze_serializer_fields',
'get_optimization_suggestions',
'QueryAnalyzer',
'detect_n_plus_one',
]

118 changes: 118 additions & 0 deletions rest_framework/optimization/middleware.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
"""
Middleware for detecting N+1 queries in development mode.

This middleware can be added to Django's MIDDLEWARE setting to automatically
detect and warn about N+1 query problems during development.
"""

import warnings
from django.conf import settings
from django.db import connection
from django.utils.deprecation import MiddlewareMixin


class QueryOptimizationMiddleware(MiddlewareMixin):
"""
Middleware that detects potential N+1 queries in development mode.

This middleware tracks database queries and warns when patterns that
suggest N+1 queries are detected.

Usage:
Add to MIDDLEWARE in settings.py:

MIDDLEWARE = [
...
'rest_framework.optimization.middleware.QueryOptimizationMiddleware',
]

Settings:
- QUERY_OPTIMIZATION_WARN_THRESHOLD: Number of similar queries to trigger warning (default: 5)
"""

def __init__(self, get_response):
self.get_response = get_response
self.warn_threshold = getattr(
settings,
'QUERY_OPTIMIZATION_WARN_THRESHOLD',
5
)
super().__init__(get_response)

def process_request(self, request):
"""Reset query tracking for each request."""
if settings.DEBUG:
connection.queries_log.clear()
return None

def process_response(self, request, response):
"""Analyze queries and warn about potential N+1 issues."""
if not settings.DEBUG:
return response

try:
# In Django 5.2+, use queries_log, fallback to queries for older versions
if hasattr(connection, 'queries_log'):
queries = connection.queries_log
else:
queries = getattr(connection, 'queries', [])

if len(queries) > self.warn_threshold:
# Analyze queries for patterns
self._analyze_queries(queries, request)
except Exception as e:
# Don't break the request if analysis fails
import traceback
if settings.DEBUG:
# Only log in DEBUG mode to avoid noise
warnings.warn(f"Query optimization middleware error: {e}", UserWarning)

return response

def _analyze_queries(self, queries, request):
"""Analyze queries for N+1 patterns."""
# Group queries by SQL pattern
query_patterns = {}
for query in queries:
# Handle both dict format (old Django) and string format (new Django)
if isinstance(query, dict):
sql = query.get('sql', '')
elif isinstance(query, str):
sql = query
else:
# Django 5.2+ might use a different format
sql = str(query)

if not sql:
continue

# Normalize SQL (remove values, keep structure)
normalized = self._normalize_sql(sql)
if normalized not in query_patterns:
query_patterns[normalized] = []
query_patterns[normalized].append(query)

# Warn about patterns that appear many times (potential N+1)
for pattern, query_list in query_patterns.items():
if len(query_list) >= self.warn_threshold:
# Check if it's a SELECT query (not INSERT/UPDATE/DELETE)
if 'SELECT' in pattern.upper():
warnings.warn(
f"Potential N+1 query detected: {len(query_list)} similar queries "
f"executed for pattern: {pattern[:100]}... "
f"Consider using select_related() or prefetch_related().",
UserWarning
)

def _normalize_sql(self, sql):
"""Normalize SQL by removing values and keeping structure."""
import re
# Remove quoted strings
sql = re.sub(r"'[^']*'", "'?'", sql)
sql = re.sub(r'"[^"]*"', '"?"', sql)
# Remove numbers
sql = re.sub(r'\b\d+\b', '?', sql)
# Normalize whitespace
sql = ' '.join(sql.split())
return sql

Loading
Loading