diff --git a/celery_unique.py b/celery_unique.py deleted file mode 100644 index 8461e92..0000000 --- a/celery_unique.py +++ /dev/null @@ -1,174 +0,0 @@ -# -*- coding: utf-8 -*- -from __future__ import absolute_import -from __future__ import print_function -from __future__ import unicode_literals - -import datetime - - -UNIQUE_REDIS_KEY_PREFIX = 'celery_unique' - - -class UniqueTaskMixin(object): - abstract = True - unique_key = None - redis_client = None - - def apply_async(self, args=None, kwargs=None, task_id=None, producer=None, link=None, link_error=None, **options): - """Apply tasks asynchronously by sending a message. - - This method serves either as a wrapper for `celery.Task.apply_async()` or, if the task decorator - was configured with a `redis_client`, callable `unique_key` and `apply_async()` was called with - either an `eta` or `countdown` argument, the task will be treated as unique. In these cases, - this method will first revoke any extant task which matches the same unique key configuration - before proceeding to publish the task. Before returning, a unique task's identifying unique key - will be saved to Redis as a key, with its task id (provided by the newly-created `AsyncResult` instance) - serving as the value. 
- - @see `celery.Task.apply_async()` - """ - should_handle_as_unique_task = ( - callable(self.unique_key) - and ('eta' in options.keys() or 'countdown' in options.keys()) - and self.redis_client is not None - ) - - if should_handle_as_unique_task: - # Generate the unique redis key and revoke any task that shares the same key (if one exists) - unique_redis_key = self._make_redis_key(args, kwargs) - self._revoke_extant_unique_task_if_exists(unique_redis_key) - - # Pass the task along to Celery for publishing and intercept the AsyncResult return value - rv = super(UniqueTaskMixin, self).apply_async(args, kwargs, task_id, producer, link, link_error, **options) - - if should_handle_as_unique_task: - # Create a Redis key/value pair to serve as a tracking record for the newly-created task. - # The new record will be given a TTL that allows it to expire (approximately) at the same time - # that the task is executed. - ttl = self._make_ttl_for_unique_task_record(options) - self._create_unique_task_record(unique_redis_key, rv.task_id, ttl) - - return rv - - def _make_redis_key(self, callback_args, callback_kwargs): - """Creates a key used to identify the task's unique configuration in Redis. - - @note All positional arguments and/or keyword arguments sent to the task are applied identically to - the task's bound `unique_key` callable. - - @param callback_args: The positional arguments which will be passed to the task when it executes - @type callback_args: list | tuple - @param callback_kwargs: The keyword arguments which will be passed to the task when it executes - @type callback_kwargs: dict - - @return: The key which will be used to find any extant version of this task which, if found, - will by revoked. Keys are built by using three colon-delimited components: - 1. A global prefix used to identify that the key/value pair in Redis was created to track - a unique Celery task (by default, this is "celery_unique") - 2. 
The name of the task (usually the Python dot-notation path to the function) - 3. The value produced by the `key_generator` callable when supplied with the task's callback - arguments. - @rtype: unicode - """ - # Get the unbound lambda used to create `self.unique_key` if the inner function exists - key_generator = self.unique_key.__func__ if hasattr(self.unique_key, '__func__') else self.unique_key - - # Create and return the redis key with the generated unique key suffix - return '{prefix}:{task_name}:{unique_key}'.format( - prefix=UNIQUE_REDIS_KEY_PREFIX, - task_name=self.name, - unique_key=key_generator( - *(callback_args or ()), - **(callback_kwargs or {}) - ) - ) - - def _revoke_extant_unique_task_if_exists(self, redis_key): - """Given a Redis key, deletes the corresponding record if one exists. - - @param redis_key: The string (potentially) used by Redis as the key for the record - @type redis_key: str | unicode - """ - task_id = self.redis_client.get(redis_key) - if task_id is not None: - self.app.AsyncResult(task_id).revoke() - self.redis_client.delete(redis_key) - - def _create_unique_task_record(self, redis_key, task_id, ttl): - """Creates a new Redis key/value pair for the recently-published unique task. - - @param redis_key: The unique key which identifies the task and its configuration (expected to be produced - by the `UniqueTaskMixin._make_redis_key()` method). - @type redis_key: str | unicode - @param task_id: The ID of the recently-published unique task, which will be used as the Redis value - @param ttl: The TTL for the Redis record, which should be (approximately) equal to the number of seconds - remaining until the earliest time that the task is expected to be executed by Celery. 
- """ - self.redis_client.set(redis_key, task_id, ex=ttl) - - @staticmethod - def _make_ttl_for_unique_task_record(task_options): - """Given the options provided to `apply_async()` as keyword arguments, determines the appropriate - TTL to ensure that a unique task record in Redis expires (approximately) at the same time as the earliest - time that the task is expected to be executed by Celery. - - The TTL value will be determined by examining the following values, in order of preference: - - The `eta` keyword argument passed to `apply_async()`, if any. If this value is found, - then the TTL will be the number of seconds between now and the ETA datetime. - - The `countdown` keyword argument passed to `apply_async()`, which will theoretically always - exist if `eta` was not provided. If this value is used, the TTL will be equal. - - Additionally, if an `expires` keyword argument was passed, and its value represents (either as an integer - or timedelta) a shorter duration of time than the values provided by `eta` or `countdown`, the TTL will be - reduced to the value of `countdown`. - - Finally, the TTL value returned by this method will always be greater than or equal to 1, in order to ensure - compatibility with Redis' TTL requirements, and that a record produced for a nonexistent task will only - live for a maximum of 1 second. 
- - @param task_options: The values passed as additional keyword arguments to `apply_async()` - @type task_options: dict - - @return: The TTL (in seconds) for the Redis record to-be-created - @rtype: int - """ - # Set a default TTL as 1 second (in case actual TTL already occurred) - ttl_seconds = 1 - - option_keys = task_options.keys() - if 'eta' in option_keys: - # Get the difference between the ETA and now (relative to the ETA's timezone) - ttl_seconds = int( - (task_options['eta'] - datetime.datetime.now(tz=task_options['eta'].tzinfo)).total_seconds() - ) - elif 'countdown' in option_keys: - ttl_seconds = task_options['countdown'] - - if 'expires' in option_keys: - if isinstance(task_options['expires'], datetime.datetime): - # Get the difference between the countdown and now (relative to the countdown's timezone) - seconds_until_expiry = int( - (task_options['expires'] - datetime.datetime.now(tz=task_options['expires'].tzinfo)).total_seconds() - ) - else: - seconds_until_expiry = task_options['expires'] - if seconds_until_expiry < ttl_seconds: - ttl_seconds = seconds_until_expiry - - if ttl_seconds <= 0: - ttl_seconds = 1 - - return ttl_seconds - - -def unique_task_factory(task_cls): - """Creates a new, abstract Celery Task class that enables properly-configured Celery tasks to uniquely exist. - - @param task_cls: The original base class which should used with UniqueTaskMixin to produce a new Celery task - base class. - @type task_cls: type - - @return: The new Celery task base class with unique task-handling functionality mixed in. 
- @rtype: type - """ - return type(str('UniqueTask'), (UniqueTaskMixin, task_cls), {}) diff --git a/celery_unique/__init__.py b/celery_unique/__init__.py new file mode 100644 index 0000000..5082f23 --- /dev/null +++ b/celery_unique/__init__.py @@ -0,0 +1,8 @@ +from celery_unique.core import unique_task_factory +from celery_unique.core import UniqueTaskMixin + + +__all__ = [ + 'unique_task_factory', + 'UniqueTaskMixin', +] diff --git a/celery_unique/backends.py b/celery_unique/backends.py new file mode 100644 index 0000000..eccd62d --- /dev/null +++ b/celery_unique/backends.py @@ -0,0 +1,83 @@ +class BaseBackend: + """ + Abstract reference backend. + + An abstract backend that defines the interface that other backends must + implement. + """ + + def create_task_record(self, key, task_id, ttl): # pragma: no cover + """ + Creates a new record for the recently-published unique task. + + :param str key: The unique key which identifies the task and its + configuration. + :param str task_id: The ID of the recently-published unique task. + :param ttl: The TTL for the record, which should be (approximately) + equal to the number of seconds remaining until the earliest time + that the task is expected to be executed by Celery. + """ + raise NotImplementedError() + + def get_task_id(self, key): # pragma: no cover + """ + Returns the task_id for an exiting task + """ + raise NotImplementedError() + + def revoke_extant_task(self, key): # pragma: no cover + """ + Deletes a task for a given key. + + This deletes both the task and the cache entry. + + :param key: The string (potentially) used by the backend as the key for + the record. + :type redis_key: str | unicode + """ + raise NotImplementedError() + + +class RedisBackend(BaseBackend): + """ + A uniqueness backend that uses redis as a key/value store. + + See :class:`~.BaseBackend` for documentation on indivitual methods. 
+ """ + + def __init__(self, redis_client): + self.redis_client = redis_client + + def get_task_id(self, key): + task_id = self.redis_client.get(key) + return task_id.decode() if task_id else None + + def revoke_extant_task(self, key): + self.redis_client.delete(key) + + def create_task_record(self, key, task_id, ttl): + self.redis_client.set(key, task_id, ex=ttl) + + +class InMemoryBackend(BaseBackend): + """ + Dummy backend which uses an in-memory store. + + This is a dummy backend which uses an in-memory store. It is mostly + suitable for development and testing, and should not be using in production + environments. + + See :class:`~.BaseBackend` for documentation on indivitual methods. + """ + + def __init__(self): + self.tasks = {} + + def get_task_id(self, key): + return self.tasks.get(key, None) + + def revoke_extant_task(self, key): + self.tasks.pop(key, None) + + def create_task_record(self, key, task_id, ttl=None): + self.tasks[key] = task_id diff --git a/celery_unique/core.py b/celery_unique/core.py new file mode 100644 index 0000000..cc6beed --- /dev/null +++ b/celery_unique/core.py @@ -0,0 +1,199 @@ +# -*- coding: utf-8 -*- +from __future__ import absolute_import +from __future__ import print_function +from __future__ import unicode_literals + +import datetime + +UNIQUE_KEY_PREFIX = 'celery_unique' + + +class UniqueTaskMixin(object): + abstract = True + unique_key = None + unique_backend = None + + def apply_async(self, args=(), kwargs={}, task_id=None, **options): + """ + Apply tasks asynchronously by sending a message. + + This method serves either as a wrapper for `celery.Task.apply_async()` + or, if the task decorator was configured with a `unique_backend`, + callable `unique_key` and `apply_async()` was called with either an + `eta` or `countdown` argument, the task will be treated as unique. + + In these cases, this method will first revoke any extant task which + matches the same unique key configuration before proceeding to publish + the task. 
Before returning, a unique task's identifying unique key + will be saved to Redis as a key, with its task id (provided by the + newly-created `AsyncResult` instance) serving as the value. + + See ``celery.Task.apply_async()`` + + :param func unique_key: Function used to generate a unique key to + identify this task. The function will take receive the same args + and kwargs the task is passed. + :param UniquenessBackend backend: A backend to use to cache queued + tasks and to determine is a task is unique or not. + """ + should_handle_as_unique_task = self._handle_as_unique(options) + + if should_handle_as_unique_task: + # Generate the unique key and revoke any task that shares the same + # key (if one exists) + key = self.make_key(args, kwargs) + self._revoke_extant_task(key) + + # Pass the task along to Celery for publishing and intercept the + # AsyncResult return value + rv = super(UniqueTaskMixin, self).apply_async(args, kwargs, task_id, **options) + + if should_handle_as_unique_task: + # Inform the backend of this tasks. + # The new record will be given a TTL that allows it to expire + # (approximately) at the same time that the task is executed. + ttl = self._make_ttl_for_unique_task_record(options) + self.unique_backend.create_task_record(key, rv.task_id, ttl) + + return rv + + def _handle_as_unique(self, options): + """ + Determines if a task should be handles as unique. + + :param dict options: The options dict passed to `apply_async`. + """ + return ( + callable(self.unique_key) + and ('eta' in options or 'countdown' in options) + and self.unique_backend is not None + ) + + def _revoke_extant_task(self, key): + """ + Removes an extant task + + Not that this removed the key from the store regardless of whether the + task is still queued or not. 
+ """ + task_id = self.unique_backend.get_task_id(key) + if task_id is not None: + self.app.AsyncResult(task_id).revoke() + self.unique_backend.revoke_extant_task(key) + + def make_key(self, callback_args, callback_kwargs): + """ + Creates a key used to identify the task's unique configuration in Redis. + + Note: All positional arguments and/or keyword arguments sent to the + task are applied identically to the task's bound `unique_key` callable. + + :param callback_args: The positional arguments which will be passed to + the task when it executes @type callback_args: list | tuple + :param callback_kwargs: The keyword arguments which will be passed to + the task when it executes @type callback_kwargs: dict + + :return: The key which will be used to find any extant version of this + task which, if found, will by revoked. Keys are built by using three + colon-delimited components:: + + 1. A global prefix used to identify that the key/value pair in + the backend was created to track + a unique Celery task (by default, this is "celery_unique") + 2. The name of the task (usually the Python dot-notation path to + the function) + 3. The value produced by the `key_generator` callable when supplied + with the task's callback arguments. + + :rtype: unicode + """ + + # Get the unbound lambda used to create `self.unique_key` if the inner + # function exists + if hasattr(self.unique_key, '__func__'): + key_generator = self.unique_key.__func__ + else: + key_generator = self.unique_key + + # Create and return the redis key with the generated unique key suffix + return '{prefix}:{task_name}:{unique_key}'.format( + prefix=UNIQUE_KEY_PREFIX, + task_name=self.name, + unique_key=key_generator(*callback_args, **callback_kwargs), + ) + + @staticmethod + def _make_ttl_for_unique_task_record(task_options): + """ + Calculate an aproximate TTL for an enqueued task. 
+ + Given the options provided to `apply_async()` as keyword arguments, + determines the appropriate TTL to ensure that a unique task record + expires (approximately) at the same time as the earliest time that the + task is expected to be executed by Celery. + + The TTL value will be determined by examining the following values, in + order of preference: + - The `eta` keyword argument passed to `apply_async()`, if any. If + this value is found, then the TTL will be the number of seconds + between now and the ETA datetime. + - The `countdown` keyword argument passed to `apply_async()`, which + will theoretically always exist if `eta` was not provided. If + this value is used, the TTL will be equal. + + Additionally, if an `expires` keyword argument was passed, and its + value represents (either as an integer or timedelta) a shorter duration + of time than the values provided by `eta` or `countdown`, the TTL will + be reduced to the value of `countdown`. + + Finally, the TTL value returned by this method will always be greater + than or equal to 1, in order to ensure compatibility with different + backend's TTL requirements, and that a record produced for a + nonexistent task will only live for a maximum of 1 second. 
+ + :param dict task_options: The values passed as additional keyword + arguments to `apply_async()` + + :return: The TTL (in seconds) for the Redis record to-be-created + :rtype: int + """ + # Set a default TTL as 1 second (in case actual TTL already occurred) + ttl_seconds = 1 + + option_keys = task_options.keys() + if 'eta' in option_keys: + # Get the difference between the ETA and now (relative to the ETA's timezone) + ttl_seconds = int( + (task_options['eta'] - datetime.datetime.now(tz=task_options['eta'].tzinfo)).total_seconds() + ) + elif 'countdown' in option_keys: + ttl_seconds = task_options['countdown'] + + if 'expires' in option_keys: + if isinstance(task_options['expires'], datetime.datetime): + # Get the difference between the countdown and now (relative to the countdown's timezone) + seconds_until_expiry = int( + (task_options['expires'] - datetime.datetime.now(tz=task_options['expires'].tzinfo)).total_seconds() + ) + else: + seconds_until_expiry = task_options['expires'] + if seconds_until_expiry < ttl_seconds: + ttl_seconds = seconds_until_expiry + + if ttl_seconds <= 0: + ttl_seconds = 1 + + return ttl_seconds + + +def unique_task_factory(task_cls): + """Creates a new, abstract Celery Task class that enables properly-configured Celery tasks to uniquely exist. + + @param task_cls: The original base class which should used with UniqueTaskMixin to produce a new Celery task + base class. + @type task_cls: type + + @return: The new Celery task base class with unique task-handling functionality mixed in. 
+ @rtype: type + """ + return type(str('UniqueTask'), (UniqueTaskMixin, task_cls), {}) diff --git a/tests/celery_unique/__init__.py b/tests/celery_unique/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/celery_unique/test_backends.py b/tests/celery_unique/test_backends.py new file mode 100644 index 0000000..663e5a3 --- /dev/null +++ b/tests/celery_unique/test_backends.py @@ -0,0 +1,123 @@ +from unittest import TestCase + +from mockredis import mock_redis_client + +from celery_unique.backends import InMemoryBackend, RedisBackend + + +class RedisBackendTestCase(TestCase): + """Base class used to test RedisBackend methods.""" + + def setUp(self): + self.redis_client = mock_redis_client() + self.backend = RedisBackend(self.redis_client) + + +class RedisBackendCreateTaskRecordTestCase(RedisBackendTestCase): + + def test_record_created(self): + task_id = '6bf813e5-74aa-4f12-a308-7e0a4bec5916' + task_key = '5jPpAKs0HOzdBjAZud6yQiL' + task_ttl = 10 + + self.backend.create_task_record(task_key, task_id, task_ttl) + + self.assertEqual(self.redis_client.get(task_key).decode(), task_id) + self.assertEqual(self.backend.get_task_id(task_key), task_id) + + self.assertLessEqual(self.redis_client.ttl(task_key), task_ttl) + + +class RedisBackendRevokeExtantRecordTestCase(RedisBackendTestCase): + + def test_revoke_existing(self): + task_id = '6bf813e5-74aa-4f12-a308-7e0a4bec5916' + task_key = '5jPpAKs0HOzdBjAZud6yQiL' + + self.redis_client.set(task_key, task_id) + self.assertEqual(self.redis_client.keys(), [task_key.encode()]) + + self.backend.revoke_extant_task(task_key) + self.assertEqual(self.redis_client.keys(), []) + + def test_revoke_inexistant(self): + # Note: this test also validates that we handle revoking inexistant + # tasks without any exception: + task_key = '5jPpAKs0HOzdBjAZud6yQiL' + + self.assertEqual(self.redis_client.keys(), []) + self.backend.revoke_extant_task(task_key) + self.assertEqual(self.redis_client.keys(), []) + + +class 
RedisBackendGetTaskIdTestCase(RedisBackendTestCase): + + def test_get_existing(self): + task_id = '6bf813e5-74aa-4f12-a308-7e0a4bec5916' + task_key = '5jPpAKs0HOzdBjAZud6yQiL' + + self.redis_client.set(task_key, task_id) + self.assertEqual(self.backend.get_task_id(task_key), task_id) + + def test_get_inexisting(self): + task_key = '5jPpAKs0HOzdBjAZud6yQiL' + + self.assertEqual(self.redis_client.keys(), []) + self.assertIsNone(self.backend.get_task_id(task_key)) + + +class InMemoryBackendTestCase(TestCase): + """Base class used to test InMemoryBackend methods.""" + + def setUp(self): + self.backend = InMemoryBackend() + + +class InMemoryBackendCreateTaskRecordTestCase(InMemoryBackendTestCase): + + def test_record_created(self): + task_id = '6bf813e5-74aa-4f12-a308-7e0a4bec5916' + task_key = '5jPpAKs0HOzdBjAZud6yQiL' + + self.backend.create_task_record(task_key, task_id) + + self.assertEqual(self.backend.tasks.get(task_key), task_id) + self.assertEqual(self.backend.get_task_id(task_key), task_id) + + +class InMemoryBackendRevokeExtantRecordTestCase(InMemoryBackendTestCase): + + def test_revoke_existing(self): + task_id = '6bf813e5-74aa-4f12-a308-7e0a4bec5916' + task_key = '5jPpAKs0HOzdBjAZud6yQiL' + + self.backend.tasks[task_key] = task_id + self.assertEqual(self.backend.tasks, {task_key: task_id}) + + self.backend.revoke_extant_task(task_key) + self.assertEqual(self.backend.tasks, {}) + + def test_revoke_inexistant(self): + # Note: this test also validates that we handle revoking inexistant + # tasks without any exception: + task_key = '5jPpAKs0HOzdBjAZud6yQiL' + + self.assertEqual(self.backend.tasks, {}) + self.backend.revoke_extant_task(task_key) + self.assertEqual(self.backend.tasks, {}) + + +class InMemoryBackendGetTaskIdTestCase(InMemoryBackendTestCase): + + def test_get_existing(self): + task_id = '6bf813e5-74aa-4f12-a308-7e0a4bec5916' + task_key = '5jPpAKs0HOzdBjAZud6yQiL' + + self.backend.tasks[task_key] = task_id + 
self.assertEqual(self.backend.get_task_id(task_key), task_id) + + def test_get_inexisting(self): + task_key = '5jPpAKs0HOzdBjAZud6yQiL' + + self.assertEqual(self.backend.tasks, {}) + self.assertIsNone(self.backend.get_task_id(task_key)) diff --git a/tests/test_celery_unique.py b/tests/celery_unique/test_core.py similarity index 62% rename from tests/test_celery_unique.py rename to tests/celery_unique/test_core.py index b754bd7..9fad853 100644 --- a/tests/test_celery_unique.py +++ b/tests/celery_unique/test_core.py @@ -4,7 +4,7 @@ import datetime import inspect -import unittest +from unittest import TestCase try: from unittest import mock @@ -17,15 +17,12 @@ from mockredis import mock_redis_client import celery_unique +from celery_unique.backends import RedisBackend from tests import helpers -class CeleryUniqueTestCase(unittest.TestCase): - def setUp(self): - self.redis_client = mock_redis_client() +class UniqueTaskFactoryTestCase(TestCase): - -class UniqueTaskFactoryTestCase(CeleryUniqueTestCase): def test_return_value_is_class(self): test_cls = celery_unique.unique_task_factory(helpers.SimpleFakeTaskBase) self.assertTrue(inspect.isclass(test_cls)) @@ -45,125 +42,117 @@ def test_return_value_class_method_resolution_order(self): self.assertListEqual(test_cls.mro(), [test_cls, celery_unique.UniqueTaskMixin, given_task_class, object]) -class UniqueTaskMixinTestCase(CeleryUniqueTestCase): +class UniqueTaskMixinTestCase(TestCase): + """Helper base class used to test UniqueTaskMixin methods.""" + def setUp(self): super(UniqueTaskMixinTestCase, self).setUp() + self.redis_client = mock_redis_client() self.test_cls = celery_unique.unique_task_factory(helpers.SimpleFakeTaskBase) self.test_cls.name = 'A_TASK_NAME' - self.test_cls.redis_client = self.redis_client + self.redis_client = self.redis_client + self.test_cls.unique_backend = RedisBackend(self.redis_client) + +class UniqueTaskMixinMakeKeyTestCase(UniqueTaskMixinTestCase): -class 
UniqueTaskMixinMakeRedisKeyTestCase(UniqueTaskMixinTestCase): def test_with_static_unique_key_lambda(self): - self.assertEqual(celery_unique.UNIQUE_REDIS_KEY_PREFIX, 'celery_unique') - test_instance = self.test_cls() - test_instance.unique_key = lambda *args, **kwargs: 'A_UNIQUE_KEY' - redis_key = test_instance._make_redis_key((), {}) - self.assertEqual(redis_key, 'celery_unique:A_TASK_NAME:A_UNIQUE_KEY') + task = self.test_cls() + task.unique_key = lambda *args, **kwargs: 'A_UNIQUE_KEY' + + self.assertEqual( + task.make_key((), {}), + 'celery_unique:A_TASK_NAME:A_UNIQUE_KEY', + ) def test_with_dynamic_unique_key_lambda(self): - test_instance = self.test_cls() - test_instance.unique_key = lambda *args, **kwargs: '{}.{}'.format(args[0], kwargs['four']) - redis_key = test_instance._make_redis_key(callback_args=(1, 2), callback_kwargs={'three': 3, 'four': 4}) - self.assertEqual(redis_key, 'celery_unique:A_TASK_NAME:1.4') + task = self.test_cls() + task.unique_key = lambda *args, **kwargs: '{}.{}'.format( + args[0], + kwargs['four'], + ) + self.assertEqual( + task.make_key( + callback_args=(1, 2), + callback_kwargs={'three': 3, 'four': 4} + ), + 'celery_unique:A_TASK_NAME:1.4', + ) -class UniqueTaskMixinRevokeExtantUniqueTaskIfExistsTestCase(UniqueTaskMixinTestCase): + +class UniqueTaskMixinRevokeExtantTaskTestCase(UniqueTaskMixinTestCase): def setUp(self): - super(UniqueTaskMixinRevokeExtantUniqueTaskIfExistsTestCase, self).setUp() + super(UniqueTaskMixinRevokeExtantTaskTestCase, self).setUp() self.mock_celery_app = mock.Mock() self.mock_async_result = mock.Mock() self.mock_celery_app.AsyncResult.return_value = self.mock_async_result self.test_instance = self.test_cls() self.test_unique_task_key = 'celery_unique:A_TASK_NAME:1.4' - self.assertIsNone(self.test_instance.redis_client.get(self.test_unique_task_key)) + self.assertIsNone(self.redis_client.get(self.test_unique_task_key)) self.test_instance.app = self.mock_celery_app def 
test_does_not_revoke_when_not_found_in_redis(self): - self.test_instance._revoke_extant_unique_task_if_exists(self.test_unique_task_key) + # TODO: continue refactor + self.test_instance._revoke_extant_task(self.test_unique_task_key) self.assertFalse(self.mock_async_result.called) self.assertFalse(self.test_instance.app.AsyncResult.called) def test_does_not_delete_from_redis_when_not_found_in_redis(self): - with mock.patch.object(self.test_instance.redis_client, 'delete') as mock_redis_client_delete: - self.test_instance._revoke_extant_unique_task_if_exists(self.test_unique_task_key) - self.assertFalse(mock_redis_client_delete.called) + with mock.patch.object(self.redis_client, 'delete') as mock_delete: + self.test_instance._revoke_extant_task(self.test_unique_task_key) + self.assertFalse(mock_delete.called) def test_revokes_async_result_when_found_in_redis(self): test_task_id = uuid().encode() - self.test_instance.redis_client.set(self.test_unique_task_key, test_task_id) - self.assertEqual(self.test_instance.redis_client.get(self.test_unique_task_key), test_task_id) - self.test_instance._revoke_extant_unique_task_if_exists(self.test_unique_task_key) - self.assertIsNone(self.mock_celery_app.AsyncResult.assert_called_once_with(test_task_id)) + self.redis_client.set(self.test_unique_task_key, test_task_id) + self.assertEqual(self.redis_client.get(self.test_unique_task_key), test_task_id) + self.test_instance._revoke_extant_task(self.test_unique_task_key) + self.assertEqual(self.mock_celery_app.AsyncResult.call_count, 1) + self.assertEqual( + self.mock_celery_app.AsyncResult.call_args, + mock.call(test_task_id.decode()), + ) self.assertIsNone(self.mock_async_result.revoke.assert_called_once_with()) def test_deletes_from_redis_when_found_in_redis(self): test_task_id = uuid().encode() - self.test_instance.redis_client.set(self.test_unique_task_key, test_task_id) - self.assertIsNotNone(self.test_instance.redis_client.get(self.test_unique_task_key)) - 
self.assertEqual(self.test_instance.redis_client.get(self.test_unique_task_key), test_task_id) - self.test_instance._revoke_extant_unique_task_if_exists(self.test_unique_task_key) - self.assertNotEqual(self.test_instance.redis_client.get(self.test_unique_task_key), test_task_id) - self.assertIsNone(self.test_instance.redis_client.get(self.test_unique_task_key)) - - -class UniqueTaskMixinCreateUniqueTaskRecordTestCase(UniqueTaskMixinTestCase): - def setUp(self): - super(UniqueTaskMixinCreateUniqueTaskRecordTestCase, self).setUp() - self.test_unique_task_key = 'celery_unique:A_TASK_NAME:1.4' - self.test_instance = self.test_cls() - self.assertIsNone(self.test_instance.redis_client.get(self.test_unique_task_key)) - - def test_redis_record_created(self): - test_task_id = uuid().encode() - test_ttl_seconds = 10 - self.test_instance._create_unique_task_record(self.test_unique_task_key, test_task_id, test_ttl_seconds) - redis_value = self.test_instance.redis_client.get(self.test_unique_task_key) - self.assertIsNotNone(redis_value) - self.assertEqual(redis_value, test_task_id) - redis_record_ttl = self.test_instance.redis_client.ttl(self.test_unique_task_key) - self.assertLessEqual(redis_record_ttl, test_ttl_seconds) - - def test_redis_set_method_called_with_expected_arguments(self): - test_task_id = uuid().encode() - test_ttl_seconds = 10 - with mock.patch.object(self.test_instance.redis_client, 'set') as mock_redis_client_set: - self.test_instance._create_unique_task_record(self.test_unique_task_key, test_task_id, test_ttl_seconds) - self.assertIsNone( - mock_redis_client_set.assert_called_once_with(self.test_unique_task_key, test_task_id, ex=test_ttl_seconds) - ) + self.redis_client.set(self.test_unique_task_key, test_task_id) + self.assertIsNotNone(self.redis_client.get(self.test_unique_task_key)) + self.assertEqual(self.redis_client.get(self.test_unique_task_key), test_task_id) + self.test_instance._revoke_extant_task(self.test_unique_task_key) + 
self.assertNotEqual(self.redis_client.get(self.test_unique_task_key), test_task_id) + self.assertIsNone(self.redis_client.get(self.test_unique_task_key)) class UniqueTaskMixinMakeTTLForUniqueTaskRecordTestCase(UniqueTaskMixinTestCase): + + @freeze_time('2017-05-05') def test_ttl_is_difference_between_now_and_eta_if_eta_in_task_options_without_expiry(self): test_current_datetime_now_value = datetime.datetime.now() test_task_options = {'eta': test_current_datetime_now_value + datetime.timedelta(days=1)} expected_ttl = int((test_task_options['eta'] - test_current_datetime_now_value).total_seconds()) self.assertGreater(expected_ttl, 0) - with mock.patch.object(celery_unique, 'datetime', mock.Mock(wraps=datetime)) as mocked_datetime: - mocked_datetime.datetime.now.return_value = test_current_datetime_now_value - actual_ttl = celery_unique.UniqueTaskMixin._make_ttl_for_unique_task_record(test_task_options) + actual_ttl = celery_unique.UniqueTaskMixin._make_ttl_for_unique_task_record(test_task_options) self.assertEqual(actual_ttl, expected_ttl) + @freeze_time('2017-05-05') def test_ttl_is_difference_between_now_and_eta_if_eta_in_task_options_without_expiry_can_be_timezone_aware(self): test_current_datetime_now_value = datetime.datetime.now() test_task_options = {'eta': test_current_datetime_now_value + datetime.timedelta(days=1)} expected_ttl = int((test_task_options['eta'] - test_current_datetime_now_value).total_seconds()) self.assertGreater(expected_ttl, 0) - with mock.patch.object(celery_unique, 'datetime', mock.Mock(wraps=datetime)) as mocked_datetime: - mocked_datetime.datetime.now.return_value = test_current_datetime_now_value - actual_ttl = celery_unique.UniqueTaskMixin._make_ttl_for_unique_task_record(test_task_options) + actual_ttl = celery_unique.UniqueTaskMixin._make_ttl_for_unique_task_record(test_task_options) self.assertEqual(actual_ttl, expected_ttl) + @freeze_time('2017-05-05') def 
test_ttl_defaults_to_1_if_eta_before_now_in_task_options_without_expiry(self): test_current_datetime_now_value = datetime.datetime.now() test_task_options = {'eta': test_current_datetime_now_value - datetime.timedelta(days=1)} expected_default_ttl = 1 would_be_ttl = int((test_task_options['eta'] - test_current_datetime_now_value).total_seconds()) self.assertLessEqual(would_be_ttl, 0) - with mock.patch.object(celery_unique, 'datetime', mock.Mock(wraps=datetime)) as mocked_datetime: - mocked_datetime.datetime.now.return_value = test_current_datetime_now_value - actual_ttl = celery_unique.UniqueTaskMixin._make_ttl_for_unique_task_record(test_task_options) + actual_ttl = celery_unique.UniqueTaskMixin._make_ttl_for_unique_task_record(test_task_options) self.assertEqual(actual_ttl, expected_default_ttl) def test_ttl_is_countdown_if_countdown_in_task_options_without_expiry(self): @@ -231,50 +220,27 @@ def setUp(self): self.test_unique_redis_key = 'celery_unique:A_TASK_NAME:1.4' self.unique_key_lambda = lambda a, b, c, d: '{}.{}'.format(a, d) - def test_does_not_handle_as_unique_task_when_unique_key_is_not_callable(self): - test_instance = self.test_cls() - test_instance.app = mock.Mock() - self.assertFalse(callable(test_instance.unique_key)) - self.assertIsNotNone(test_instance.redis_client) - with mock.patch.object(self.redis_client, 'get') as mock_redis_client_get: - rs = test_instance.apply_async( + def test_does_not_handle_as_unique_task_when_not_applicable(self): + task = self.test_cls() + task._handle_as_unique = lambda options: False + + with mock.patch.object(task, 'unique_backend') as mock_backend: + rs = task.apply_async( args=(1, 2, 3, 4), eta=datetime.datetime.now() + datetime.timedelta(days=1) ) - self.assertIsInstance(rs, AsyncResult) - self.assertFalse(mock_redis_client_get.called) - def test_does_not_handle_as_unique_task_when_no_eta_or_countdown_in_options(self): - test_instance = self.test_cls() - test_instance.app = mock.Mock() - 
test_instance.unique_key = self.unique_key_lambda - self.assertTrue(callable(test_instance.unique_key)) - self.assertIsNotNone(test_instance.redis_client) - with mock.patch.object(self.redis_client, 'get') as mock_redis_client_get: - rs = test_instance.apply_async(args=(1, 2, 3, 4)) self.assertIsInstance(rs, AsyncResult) - self.assertFalse(mock_redis_client_get.called) - def test_does_not_handle_as_unique_task_when_redis_client_is_None(self): - test_instance = self.test_cls() - test_instance.app = mock.Mock() - test_instance.redis_client = None - test_instance.unique_key = self.unique_key_lambda - self.assertTrue(callable(test_instance.unique_key)) - self.assertIsNone(test_instance.redis_client) - with mock.patch.object(self.redis_client, 'get') as mock_redis_client_get: - rs = test_instance.apply_async( - args=(1, 2, 3, 4), - eta=datetime.datetime.now() + datetime.timedelta(days=1) - ) - self.assertIsInstance(rs, AsyncResult) - self.assertFalse(mock_redis_client_get.called) + # No key made or record created in the backend: + self.assertFalse(mock_backend.make_key.called) + self.assertFalse(mock_backend.create_task_record.called) def test_attempts_to_revoke_extant_task_when_eta_is_given_with_no_countdown(self): test_instance = self.test_cls() test_instance.unique_key = self.unique_key_lambda self.assertTrue(callable(test_instance.unique_key)) - self.assertIsNotNone(test_instance.redis_client) + self.assertIsNotNone(self.redis_client) with mock.patch.object(self.redis_client, 'get') as mock_redis_client_get: mock_redis_client_get.return_value = None rs = test_instance.apply_async( @@ -288,7 +254,7 @@ def test_attempts_to_revoke_extant_task_when_countdown_is_given_with_no_eta(self test_instance = self.test_cls() test_instance.unique_key = self.unique_key_lambda self.assertTrue(callable(test_instance.unique_key)) - self.assertIsNotNone(test_instance.redis_client) + self.assertIsNotNone(self.redis_client) with mock.patch.object(self.redis_client, 'get') as 
mock_redis_client_get: mock_redis_client_get.return_value = None rs = test_instance.apply_async( @@ -307,7 +273,8 @@ def test_revokes_extant_task_when_one_exists(self): test_instance.app = mock.Mock() test_instance.unique_key = self.unique_key_lambda self.assertTrue(callable(test_instance.unique_key)) - self.assertIsNotNone(test_instance.redis_client) + + self.assertIsNotNone(test_instance.unique_backend) rs = test_instance.apply_async( args=(1, 2, 3, 4), countdown=100 @@ -324,7 +291,7 @@ def test_creates_new_task_record_when_extant_task_exists(self): test_instance.app = mock.Mock() test_instance.unique_key = self.unique_key_lambda self.assertTrue(callable(test_instance.unique_key)) - self.assertIsNotNone(test_instance.redis_client) + self.assertIsNotNone(self.redis_client) rs = test_instance.apply_async( args=(1, 2, 3, 4), countdown=100 @@ -335,14 +302,46 @@ def test_creates_new_task_record_when_extant_task_exists(self): def test_creates_new_task_record_when_no_extant_task_exists(self): self.assertIsNone(self.redis_client.get(self.test_unique_redis_key)) + test_instance = self.test_cls() test_instance.app = mock.Mock() test_instance.unique_key = self.unique_key_lambda + self.assertTrue(callable(test_instance.unique_key)) - self.assertIsNotNone(test_instance.redis_client) + self.assertIsNotNone(test_instance.unique_backend) rs = test_instance.apply_async( args=(1, 2, 3, 4), countdown=100 ) self.assertIsInstance(rs, AsyncResult) - self.assertEqual(self.redis_client.get(self.test_unique_redis_key), rs.task_id.encode()) + self.assertEqual( + self.redis_client.get(self.test_unique_redis_key), + rs.task_id.encode(), + ) + + +class UniqueTaskMixinHandleAsUniqueTestCase(UniqueTaskMixinTestCase): + + def test_is_unique(self): + task = self.test_cls() + self.assertIsNotNone(task.unique_backend) + task.unique_key = lambda *a, **kw: 'the_key' + self.assertEqual(task._handle_as_unique({'eta': 10}), True) + + def test_no_key_func(self): + task = self.test_cls() + 
self.assertIsNotNone(task.unique_backend) + self.assertIsNone(task.unique_key) + self.assertEqual(task._handle_as_unique({'eta': 10}), False) + + def test_no_eta_or_countdown(self): + task = self.test_cls() + task.unique_key = lambda *a, **kw: 'the_key' + self.assertIsNotNone(task.unique_backend) + self.assertEqual(task._handle_as_unique({}), False) + + def test_no_backend(self): + task = self.test_cls() + task.unique_backend = None + task.unique_key = lambda *a, **kw: 'the_key' + self.assertEqual(task._handle_as_unique({'eta': 10}), False)