Skip to content

Commit 61e31ed

Browse files
committed
Add router to read and write in the cache
1 parent 1d6c306 commit 61e31ed

File tree

4 files changed

+86
-28
lines changed

4 files changed

+86
-28
lines changed

django_mongodb_backend/cache.py

+29-16
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from django.core.cache.backends.db import Options
66
from django.db import connections, router
77
from django.utils.functional import cached_property
8+
from pymongo import IndexModel
89
from pymongo.errors import DuplicateKeyError
910

1011

@@ -41,20 +42,29 @@ class CacheEntry:
4142
self.cache_model_class = CacheEntry
4243

4344
def create_indexes(self):
44-
self.collection.create_index("expires_at", expireAfterSeconds=0)
45-
self.collection.create_index("key", unique=True)
45+
expires_index = IndexModel("expires_at", expireAfterSeconds=0)
46+
key_index = IndexModel("key", unique=True)
47+
self.collection_to_write.create_indexes([expires_index, key_index])
4648

4749
@cached_property
4850
def serializer(self):
4951
return MongoSerializer(self.pickle_protocol)
5052

5153
@property
52-
def _db(self):
54+
def _db_to_read(self):
5355
return connections[router.db_for_read(self.cache_model_class)]
5456

5557
@property
56-
def collection(self):
57-
return self._db.get_collection(self._collection_name)
58+
def collection_to_read(self):
59+
return self._db_to_read.get_collection(self._collection_name)
60+
61+
@property
62+
def _db_to_write(self):
63+
return connections[router.db_for_write(self.cache_model_class)]
64+
65+
@property
66+
def collection_to_write(self):
67+
return self._db_to_read.get_collection(self._collection_name)
5868

5969
def get(self, key, default=None, version=None):
6070
result = self.get_many([key], version)
@@ -71,18 +81,18 @@ def get_many(self, keys, version=None):
7181
if not keys:
7282
return {}
7383
keys_map = {self.make_and_validate_key(key, version=version): key for key in keys}
74-
with self.collection.find(
84+
with self.collection_to_read.find(
7585
{"key": {"$in": tuple(keys_map)}, **self._filter_expired(expired=False)}
7686
) as cursor:
7787
return {keys_map[row["key"]]: self.serializer.loads(row["value"]) for row in cursor}
7888

7989
def set(self, key, value, timeout=DEFAULT_TIMEOUT, version=None):
8090
key = self.make_and_validate_key(key, version=version)
8191
serialized_data = self.serializer.dumps(value)
82-
num = self.collection.count_documents({})
92+
num = self.collection_to_write.count_documents({})
8393
if num >= self._max_entries:
8494
self._cull(num)
85-
return self.collection.update_one(
95+
return self.collection_to_write.update_one(
8696
{"key": key},
8797
{
8898
"$set": {
@@ -97,11 +107,11 @@ def set(self, key, value, timeout=DEFAULT_TIMEOUT, version=None):
97107
def add(self, key, value, timeout=DEFAULT_TIMEOUT, version=None):
98108
key = self.make_and_validate_key(key, version=version)
99109
serialized_data = self.serializer.dumps(value)
100-
num = self.collection.count_documents({})
110+
num = self.collection_to_write.count_documents({})
101111
if num >= self._max_entries:
102112
self._cull(num)
103113
try:
104-
self.collection.update_one(
114+
self.collection_to_write.update_one(
105115
{"key": key, **self._filter_expired(expired=True)},
106116
{
107117
"$set": {
@@ -124,7 +134,7 @@ def _cull(self, num):
124134
try:
125135
# Delete the first expiration date.
126136
deleted_from = next(
127-
self.collection.aggregate(
137+
self.collection_to_write.aggregate(
128138
[
129139
{"$sort": {"expires_at": -1, "key": 1}},
130140
{"$skip": keep_num},
@@ -136,7 +146,7 @@ def _cull(self, num):
136146
except StopIteration:
137147
pass
138148
else:
139-
self.collection.delete_many(
149+
self.collection_to_write.delete_many(
140150
{
141151
"$or": [
142152
{"expires_at": {"$lt": deleted_from["expires_at"]}},
@@ -152,7 +162,7 @@ def _cull(self, num):
152162

153163
def touch(self, key, timeout=DEFAULT_TIMEOUT, version=None):
154164
key = self.make_and_validate_key(key, version=version)
155-
res = self.collection.update_one(
165+
res = self.collection_to_write.update_one(
156166
{"key": key}, {"$set": {"expires_at": self._get_expiration_time(timeout)}}
157167
)
158168
return res.matched_count > 0
@@ -173,13 +183,16 @@ def _delete_many(self, keys, version=None):
173183
if not keys:
174184
return False
175185
keys = tuple(self.make_and_validate_key(key, version=version) for key in keys)
176-
return bool(self.collection.delete_many({"key": {"$in": keys}}).deleted_count)
186+
return bool(self.collection_to_write.delete_many({"key": {"$in": keys}}).deleted_count)
177187

178188
def has_key(self, key, version=None):
179189
key = self.make_and_validate_key(key, version=version)
180190
return (
181-
self.collection.count_documents({"key": key, **self._filter_expired(expired=False)}) > 0
191+
self.collection_to_read.count_documents(
192+
{"key": key, **self._filter_expired(expired=False)}
193+
)
194+
> 0
182195
)
183196

184197
def clear(self):
185-
self.collection.delete_many({})
198+
self.collection_to_write.delete_many({})

django_mongodb_backend/creation.py

+5-5
Original file line numberDiff line numberDiff line change
@@ -25,9 +25,9 @@ def create_test_db(self, *args, **kwargs):
2525
# Create cache collections
2626
for cache_alias in settings.CACHES:
2727
cache = caches[cache_alias]
28-
connection = cache._db
29-
if cache._collection_name in connection.introspection.table_names():
30-
continue
31-
cache = MongoDBCache(cache._collection_name, {})
32-
cache.create_indexes()
28+
if isinstance(cache, MongoDBCache):
29+
connection = cache._db_to_write
30+
if cache._collection_name in connection.introspection.table_names():
31+
continue
32+
cache.create_indexes()
3333
return test_database_name

django_mongodb_backend/management/commands/createcachecollection.py

+4-5
Original file line numberDiff line numberDiff line change
@@ -43,15 +43,14 @@ def handle(self, *collection_names, **options):
4343
for cache_alias in settings.CACHES:
4444
cache = caches[cache_alias]
4545
if isinstance(cache, MongoDBCache):
46-
self.check_collection(db, cache._collection_name)
46+
self.check_collection(db, cache)
4747

48-
def check_collection(self, database, collection_name):
49-
cache = MongoDBCache(collection_name, {})
48+
def check_collection(self, database, cache):
5049
if not router.allow_migrate_model(database, cache.cache_model_class):
5150
return
5251
connection = connections[database]
53-
if collection_name in connection.introspection.table_names():
52+
if cache._collection_name in connection.introspection.table_names():
5453
if self.verbosity > 0:
55-
self.stdout.write("Cache collection '%s' already exists." % collection_name)
54+
self.stdout.write("Cache collection '%s' already exists." % cache._collection_name)
5655
return
5756
cache.create_indexes()

tests/cache_/tests.py

+48-2
Original file line numberDiff line numberDiff line change
@@ -929,7 +929,7 @@ def test_get_or_set_racing(self):
929929
self.assertEqual(cache.get_or_set("key", "default"), "default")
930930

931931
def test_collection_has_indexes(self):
932-
indexes = list(cache.collection.list_indexes())
932+
indexes = list(cache.collection_to_read.list_indexes())
933933
self.assertTrue(
934934
any(
935935
index["key"] == SON([("expires_at", 1)]) and index.get("expireAfterSeconds") == 0
@@ -962,7 +962,7 @@ def setUp(self):
962962
self.addCleanup(self.drop_collection)
963963

964964
def drop_collection(self):
965-
cache.collection.drop()
965+
cache.collection_to_write.drop()
966966

967967
def create_cache_collection(self):
968968
management.call_command("createcachecollection", verbosity=0)
@@ -971,3 +971,49 @@ def create_cache_collection(self):
971971
@override_settings(USE_TZ=True)
972972
class DBCacheWithTimeZoneTests(DBCacheTests):
973973
pass
974+
975+
976+
class DBCacheRouter:
977+
"""A router that puts the cache table on the 'other' database."""
978+
979+
def db_for_read(self, model, **hints):
980+
if model._meta.app_label == "django_cache":
981+
return "other"
982+
return None
983+
984+
def db_for_write(self, model, **hints):
985+
if model._meta.app_label == "django_cache":
986+
return "other"
987+
return None
988+
989+
def allow_migrate(self, db, app_label, **hints):
990+
if app_label == "django_cache":
991+
return db == "other"
992+
return None
993+
994+
995+
@override_settings(
996+
CACHES={
997+
"default": {
998+
"BACKEND": "django_mongodb_backend.cache.MongoDBCache",
999+
"LOCATION": "my_cache_table",
1000+
},
1001+
},
1002+
)
1003+
@modify_settings(
1004+
INSTALLED_APPS={"prepend": "django_mongodb_backend"},
1005+
)
1006+
class CreateCacheTableForDBCacheTests(TestCase):
1007+
databases = {"default", "other"}
1008+
1009+
@override_settings(DATABASE_ROUTERS=[DBCacheRouter()])
1010+
def test_createcachetable_observes_database_router(self):
1011+
# cache table should not be created on 'default'
1012+
with self.assertNumQueries(0, using="default"):
1013+
management.call_command("createcachecollection", database="default", verbosity=0)
1014+
# cache table should be created on 'other'
1015+
# Queries:
1016+
# 1: Create indexes
1017+
num = 1
1018+
with self.assertNumQueries(num, using="other"):
1019+
management.call_command("createcachecollection", database="other", verbosity=0)

0 commit comments

Comments
 (0)