diff --git a/django_redis/cache.py b/django_redis/cache.py index ebf4d14b..eec06480 100644 --- a/django_redis/cache.py +++ b/django_redis/cache.py @@ -183,3 +183,59 @@ def close(self, **kwargs): @omit_exception def touch(self, *args, **kwargs): return self.client.touch(*args, **kwargs) + + @omit_exception + def sadd(self, *args, **kwargs): + return self.client.sadd(*args, **kwargs) + + @omit_exception + def scard(self, *args, **kwargs): + return self.client.scard(*args, **kwargs) + + @omit_exception + def sdiff(self, *args, **kwargs): + return self.client.sdiff(*args, **kwargs) + + @omit_exception + def sdiffstore(self, *args, **kwargs): + return self.client.sdiffstore(*args, **kwargs) + + @omit_exception + def sinter(self, *args, **kwargs): + return self.client.sinter(*args, **kwargs) + + @omit_exception + def sinterstore(self, *args, **kwargs): + return self.client.sinterstore(*args, **kwargs) + + @omit_exception + def sismember(self, *args, **kwargs): + return self.client.sismember(*args, **kwargs) + + @omit_exception + def smembers(self, *args, **kwargs): + return self.client.smembers(*args, **kwargs) + + @omit_exception + def smove(self, *args, **kwargs): + return self.client.smove(*args, **kwargs) + + @omit_exception + def spop(self, *args, **kwargs): + return self.client.spop(*args, **kwargs) + + @omit_exception + def srandmember(self, *args, **kwargs): + return self.client.srandmember(*args, **kwargs) + + @omit_exception + def srem(self, *args, **kwargs): + return self.client.srem(*args, **kwargs) + + @omit_exception + def sunion(self, *args, **kwargs): + return self.client.sunion(*args, **kwargs) + + @omit_exception + def sunionstore(self, *args, **kwargs): + return self.client.sunionstore(*args, **kwargs) diff --git a/django_redis/client/default.py b/django_redis/client/default.py index 6886b46b..75cd7240 100644 --- a/django_redis/client/default.py +++ b/django_redis/client/default.py @@ -776,3 +776,196 @@ def touch( # Convert to milliseconds timeout = int(timeout * 1000) return bool(client.pexpire(key, timeout)) + + def sadd( + self, + key: Any, + *values: Any, + version: Optional[int] = None, + client: Optional[Redis] = None, + ) -> int: + if client is None: + client = self.get_client(write=True) + + key = self.make_key(key, version=version) + values = [self.encode(value) for value in values] + return int(client.sadd(key, *values)) + + def scard( + self, + key: Any, + version: Optional[int] = None, + client: Optional[Redis] = None, + ) -> int: + if client is None: + client = self.get_client(write=False) + + key = self.make_key(key, version=version) + return int(client.scard(key)) + + def sdiff( + self, + *keys, + version: Optional[int] = None, + client: Optional[Redis] = None, + ) -> set: + if client is None: + client = self.get_client(write=False) + + keys = [self.make_key(key, version=version) for key in keys] + return {self.decode(value) for value in client.sdiff(*keys)} + + def sdiffstore( + self, + dest: Any, + *keys, + version: Optional[int] = None, + client: Optional[Redis] = None, + ) -> int: + if client is None: + client = self.get_client(write=True) + + dest = self.make_key(dest, version=version) + keys = [self.make_key(key, version=version) for key in keys] + return int(client.sdiffstore(dest, *keys)) + + + def sinter( + self, + *keys, + version: Optional[int] = None, + client: Optional[Redis] = None, + ) -> set: + if client is None: + client = self.get_client(write=False) + + keys = [self.make_key(key, version=version) for key in keys] + return {self.decode(value) for value in client.sinter(*keys)} + + def sinterstore( + self, + dest: Any, + *keys, + version: Optional[int] = None, + client: Optional[Redis] = None, + ) -> int: + if client is None: + client = self.get_client(write=True) + + dest = self.make_key(dest, version=version) + keys = [self.make_key(key, version=version) for key in keys] + return int(client.sinterstore(dest, *keys)) + + def sismember( + self, + key: Any, + member: Any, + version: Optional[int] = None, + client: Optional[Redis] = None, + ) -> bool: + if client is None: + client = self.get_client(write=False) + + key = self.make_key(key, version=version) + member = self.encode(member) + return bool(client.sismember(key, member)) + + def smembers( + self, + key: Any, + version: Optional[int] = None, + client: Optional[Redis] = None, + ) -> set: + if client is None: + client = self.get_client(write=False) + + key = self.make_key(key, version=version) + return {self.decode(value) for value in client.smembers(key)} + + def smove( + self, + source: Any, + destination: Any, + member: Any, + version: Optional[int] = None, + client: Optional[Redis] = None, + ) -> bool: + if client is None: + client = self.get_client(write=True) + + source = self.make_key(source, version=version) + destination = self.make_key(destination) + member = self.encode(member) + return bool(client.smove(source, destination, member)) + + def spop( + self, + key: Any, + count: Optional[int] = None, + version: Optional[int] = None, + client: Optional[Redis] = None, + ) -> Union[set, Any]: + if client is None: + client = self.get_client(write=True) + + key = self.make_key(key, version=version) + result = client.spop(key, count) + if type(result) == list: + return {self.decode(value) for value in result} + return self.decode(result) + + def srandmember( + self, + key: Any, + count: Optional[int] = None, + version: Optional[int] = None, + client: Optional[Redis] = None, + ) -> Union[set, Any]: + if client is None: + client = self.get_client(write=False) + + key = self.make_key(key, version=version) + result = client.srandmember(key, count) + if type(result) == list: + return {self.decode(value) for value in result} + return self.decode(result) + + def srem( + self, + key: Any, + *members, + version: Optional[int] = None, + client: Optional[Redis] = None, + ) -> int: + if client is None: + client = self.get_client(write=True) + + key = self.make_key(key, version=version) + members = [self.decode(member) for member in members] + return int(client.srem(key, *members)) + + def sunion( + self, + *keys, + version: Optional[int] = None, + client: Optional[Redis] = None, + ) -> set: + if client is None: + client = self.get_client(write=False) + + keys = [self.make_key(key, version=version) for key in keys] + return {self.decode(value) for value in client.sunion(*keys)} + + def sunionstore( + self, + destination: Any, + *keys, + version: Optional[int] = None, + client: Optional[Redis] = None, + ) -> int: + if client is None: + client = self.get_client(write=True) + + destination = self.make_key(destination, version=version) + keys = [self.make_key(key, version=version) for key in keys] + return int(client.sunionstore(destination, *keys)) diff --git a/tests/test_backend.py b/tests/test_backend.py index 165779ce..9d1e30c1 100644 --- a/tests/test_backend.py +++ b/tests/test_backend.py @@ -771,3 +771,80 @@ def test_clear(self, cache: RedisCache): cache.clear() value_from_cache_after_clear = cache.get("foo") assert value_from_cache_after_clear is None + + def test_sadd(self, cache: RedisCache): + assert cache.sadd("foo", "bar") == 1 + assert cache.smembers("foo") == {"bar"} + + def test_scard(self, cache: RedisCache): + cache.sadd("foo", "bar", "bar2") + assert cache.scard("foo") == 2 + + def test_sdiff(self, cache: RedisCache): + cache.sadd("foo1", "bar1", "bar2") + cache.sadd("foo2", "bar2", "bar3") + assert cache.sdiff("foo1", "foo2") == {"bar1"} + + def test_sdiffstore(self, cache: RedisCache): + cache.sadd("foo1", "bar1", "bar2") + cache.sadd("foo2", "bar2", "bar3") + assert cache.sdiffstore("foo3", "foo1", "foo2") == 1 + assert cache.smembers("foo3") == {"bar1"} + + def test_sinter(self, cache: RedisCache): + cache.sadd("foo1", "bar1", "bar2") + cache.sadd("foo2", "bar2", "bar3") + assert cache.sinter("foo1", "foo2") == {"bar2"} + + def test_interstore(self, cache: RedisCache): + cache.sadd("foo1", "bar1", "bar2") + cache.sadd("foo2", "bar2", "bar3") + assert cache.sinterstore("foo3", "foo1", "foo2") == 1 + assert cache.smembers("foo3") == {"bar2"} + + def test_sismember(self, cache: RedisCache): + cache.sadd("foo", "bar") + assert cache.sismember("foo", "bar") is True + assert cache.sismember("foo", "bar2") is False + + def test_smove(self, cache: RedisCache): + cache.sadd("foo1", "bar1", "bar2") + cache.sadd("foo2", "bar2", "bar3") + assert cache.smove("foo1", "foo2", "bar1") is True + assert cache.smove("foo1", "foo2", "bar4") is False + assert cache.smembers("foo1") == {"bar2"} + assert cache.smembers("foo2") == {"bar1", "bar2", "bar3"} + + def test_spop_default_count(self, cache: RedisCache): + cache.sadd("foo", "bar1", "bar2") + assert cache.spop("foo") in {"bar1", "bar2"} + assert cache.smembers("foo") in {{"bar1"}, {"bar2"}} + + def test_spop(self, cache: RedisCache): + cache.sadd("foo", "bar1", "bar2") + assert cache.spop("foo", 1) in {{"bar1"}, {"bar2"}} + assert cache.smembers("foo") in {{"bar1"}, {"bar2"}} + + def test_srandmember_default_count(self, cache: RedisCache): + cache.sadd("foo", "bar1", "bar2") + assert cache.srandmember("foo") in {"bar1", "bar2"} + + def test_srandmember(self, cache: RedisCache): + cache.sadd("foo", "bar1", "bar2") + assert cache.srandmember("foo", 1) in {{"bar1"}, {"bar2"}} + + def test_srem(self, cache: RedisCache): + cache.sadd("foo", "bar1", "bar2") + assert cache.srem("foo", "bar1") == 1 + assert cache.srem("foo", "bar3") == 0 + + def test_sunion(self, cache: RedisCache): + cache.sadd("foo1", "bar1", "bar2") + cache.sadd("foo2", "bar2", "bar3") + assert cache.sunion("foo1", "foo2") == {"bar1", "bar2", "bar3"} + + def test_sunionstore(self, cache: RedisCache): + cache.sadd("foo1", "bar1", "bar2") + cache.sadd("foo2", "bar2", "bar3") + assert cache.sunionstore("foo3", "foo1", "foo2") == 3 + assert cache.smembers("foo3") == {"bar1", "bar2", "bar3"}