Browse Source

Merges BaseBackend and BaseDictBackend into one BaseBackend

Ask Solem 12 years ago
parent
commit
dd632b2721

+ 2 - 2
celery/backends/amqp.py

@@ -21,7 +21,7 @@ from celery import states
 from celery.exceptions import TimeoutError
 from celery.utils.log import get_logger
 
-from .base import BaseDictBackend
+from .base import BaseBackend
 
 logger = get_logger(__name__)
 
@@ -37,7 +37,7 @@ def repair_uuid(s):
     return '%s-%s-%s-%s-%s' % (s[:8], s[8:12], s[12:16], s[16:20], s[20:])
 
 
-class AMQPBackend(BaseDictBackend):
+class AMQPBackend(BaseBackend):
     """Publishes results by sending messages."""
     Exchange = Exchange
     Queue = Queue

+ 58 - 117
celery/backends/base.py

@@ -7,8 +7,6 @@
 
     - :class:`BaseBackend` defines the interface.
 
-    - :class:`BaseDictBackend` assumes the fields are stored in a dict.
-
     - :class:`KeyValueStoreBackend` is a common base class
       using K/V semantics like _get and _put.
 
@@ -46,7 +44,6 @@ def unpickle_backend(cls, args, kwargs):
 
 
 class BaseBackend(object):
-    """Base backend class."""
     READY_STATES = states.READY_STATES
     UNREADY_STATES = states.UNREADY_STATES
     EXCEPTION_STATES = states.EXCEPTION_STATES
@@ -61,44 +58,15 @@ class BaseBackend(object):
     #: If true the backend must implement :meth:`get_many`.
     supports_native_join = False
 
-    def __init__(self, *args, **kwargs):
+    def __init__(self, app=None, serializer=None, max_cached_results=None, **kwargs):
         from celery.app import app_or_default
-        self.app = app_or_default(kwargs.get('app'))
-        self.serializer = kwargs.get('serializer',
-                                     self.app.conf.CELERY_RESULT_SERIALIZER)
+        self.app = app_or_default(app)
+        self.serializer = serializer or self.app.conf.CELERY_RESULT_SERIALIZER
         (self.content_type,
          self.content_encoding,
          self.encoder) = serialization.registry._encoders[self.serializer]
-
-    def encode(self, data):
-        _, _, payload = serialization.encode(data, serializer=self.serializer)
-        return payload
-
-    def decode(self, payload):
-        payload = is_py3k and payload or str(payload)
-        return serialization.decode(payload,
-                                    content_type=self.content_type,
-                                    content_encoding=self.content_encoding)
-
-    def prepare_expires(self, value, type=None):
-        if value is None:
-            value = self.app.conf.CELERY_TASK_RESULT_EXPIRES
-        if isinstance(value, timedelta):
-            value = timeutils.timedelta_seconds(value)
-        if value is not None and type:
-            return type(value)
-        return value
-
-    def encode_result(self, result, status):
-        if status in self.EXCEPTION_STATES and isinstance(result, Exception):
-            return self.prepare_exception(result)
-        else:
-            return self.prepare_value(result)
-
-    def store_result(self, task_id, result, status, traceback=None):
-        """Store the result and status of a task."""
-        raise NotImplementedError(
-                'store_result is not supported by this backend.')
+        self._cache = LRUCache(limit=max_cached_results or
+                                      self.app.conf.CELERY_MAX_CACHED_RESULTS)
 
     def mark_as_started(self, task_id, **meta):
         """Mark a task as started"""
@@ -140,8 +108,15 @@ class BaseBackend(object):
         """Prepare value for storage."""
         return result
 
-    def forget(self, task_id):
-        raise NotImplementedError('backend does not implement forget.')
+    def encode(self, data):
+        _, _, payload = serialization.encode(data, serializer=self.serializer)
+        return payload
+
+    def decode(self, payload):
+        payload = is_py3k and payload or str(payload)
+        return serialization.decode(payload,
+                                    content_type=self.content_type,
+                                    content_encoding=self.content_encoding)
 
     def wait_for(self, task_id, timeout=None, propagate=True, interval=0.5):
         """Wait for task and return its result.
@@ -172,85 +147,23 @@ class BaseBackend(object):
             if timeout and time_elapsed >= timeout:
                 raise TimeoutError('The operation timed out.')
 
-    def cleanup(self):
-        """Backend cleanup. Is run by
-        :class:`celery.task.DeleteExpiredTaskMetaTask`."""
-        pass
-
-    def process_cleanup(self):
-        """Cleanup actions to do at the end of a task worker process."""
-        pass
-
-    def get_status(self, task_id):
-        """Get the status of a task."""
-        raise NotImplementedError(
-                'get_status is not supported by this backend.')
-
-    def get_result(self, task_id):
-        """Get the result of a task."""
-        raise NotImplementedError(
-                'get_result is not supported by this backend.')
-
-    def get_children(self, task_id):
-        raise NotImplementedError(
-                'get_children is not supported by this backend.')
-
-    def get_traceback(self, task_id):
-        """Get the traceback for a failed task."""
-        raise NotImplementedError(
-                'get_traceback is not supported by this backend.')
-
-    def save_group(self, group_id, result):
-        """Store the result and status of a task."""
-        raise NotImplementedError(
-                'save_group is not supported by this backend.')
-
-    def restore_group(self, group_id, cache=True):
-        """Get the result of a group."""
-        raise NotImplementedError(
-                'restore_group is not supported by this backend.')
-
-    def delete_group(self, group_id):
-        raise NotImplementedError(
-                'delete_group is not supported by this backend.')
-
-    def reload_task_result(self, task_id):
-        """Reload task result, even if it has been previously fetched."""
-        raise NotImplementedError(
-                'reload_task_result is not supported by this backend.')
-
-    def reload_group_result(self, task_id):
-        """Reload group result, even if it has been previously fetched."""
-        raise NotImplementedError(
-                'reload_group_result is not supported by this backend.')
-
-    def on_chord_part_return(self, task, propagate=False):
-        pass
-
-    def fallback_chord_unlock(self, group_id, body, result=None, **kwargs):
-        kwargs['result'] = [r.id for r in result]
-        self.app.tasks['celery.chord_unlock'].apply_async((group_id, body, ),
-                                                          kwargs, countdown=1)
-    on_chord_apply = fallback_chord_unlock
-
-    def current_task_children(self):
-        current = current_task()
-        if current:
-            return [r.serializable() for r in current.request.children]
-
-    def __reduce__(self, args=(), kwargs={}):
-        return (unpickle_backend, (self.__class__, args, kwargs))
-
-
-class BaseDictBackend(BaseBackend):
+    def prepare_expires(self, value, type=None):
+        if value is None:
+            value = self.app.conf.CELERY_TASK_RESULT_EXPIRES
+        if isinstance(value, timedelta):
+            value = timeutils.timedelta_seconds(value)
+        if value is not None and type:
+            return type(value)
+        return value
 
-    def __init__(self, *args, **kwargs):
-        super(BaseDictBackend, self).__init__(*args, **kwargs)
-        self._cache = LRUCache(limit=kwargs.get('max_cached_results') or
-                                 self.app.conf.CELERY_MAX_CACHED_RESULTS)
+    def encode_result(self, result, status):
+        if status in self.EXCEPTION_STATES and isinstance(result, Exception):
+            return self.prepare_exception(result)
+        else:
+            return self.prepare_value(result)
 
     def store_result(self, task_id, result, status, traceback=None, **kwargs):
-        """Store task result and status."""
+        """Update task state and result."""
         result = self.encode_result(result, status)
         return self._store_result(task_id, result, status, traceback, **kwargs)
 
@@ -297,12 +210,13 @@ class BaseDictBackend(BaseBackend):
         return meta
 
     def reload_task_result(self, task_id):
+        """Reload task result, even if it has been previously fetched."""
         self._cache[task_id] = self.get_task_meta(task_id, cache=False)
 
     def reload_group_result(self, group_id):
+        """Reload group result, even if it has been previously fetched."""
         self._cache[group_id] = self.get_group_meta(group_id,
                                                     cache=False)
-
     def get_group_meta(self, group_id, cache=True):
         if cache:
             try:
@@ -329,8 +243,35 @@ class BaseDictBackend(BaseBackend):
         self._cache.pop(group_id, None)
         return self._delete_group(group_id)
 
+    def cleanup(self):
+        """Backend cleanup. Is run by
+        :class:`celery.task.DeleteExpiredTaskMetaTask`."""
+        pass
+
+    def process_cleanup(self):
+        """Cleanup actions to do at the end of a task worker process."""
+        pass
+
+    def on_chord_part_return(self, task, propagate=False):
+        pass
+
+    def fallback_chord_unlock(self, group_id, body, result=None, **kwargs):
+        kwargs['result'] = [r.id for r in result]
+        self.app.tasks['celery.chord_unlock'].apply_async((group_id, body, ),
+                                                          kwargs, countdown=1)
+    on_chord_apply = fallback_chord_unlock
+
+    def current_task_children(self):
+        current = current_task()
+        if current:
+            return [r.serializable() for r in current.request.children]
+
+    def __reduce__(self, args=(), kwargs={}):
+        return (unpickle_backend, (self.__class__, args, kwargs))
+BaseDictBackend = BaseBackend  # XXX compat
+
 
-class KeyValueStoreBackend(BaseDictBackend):
+class KeyValueStoreBackend(BaseBackend):
     task_keyprefix = ensure_bytes('celery-task-meta-')
     group_keyprefix = ensure_bytes('celery-taskset-meta-')
     chord_keyprefix = ensure_bytes('chord-unlock-')

+ 1 - 1
celery/backends/cache.py

@@ -77,7 +77,7 @@ class CacheBackend(KeyValueStoreBackend):
     implements_incr = True
 
     def __init__(self, expires=None, backend=None, options={}, **kwargs):
-        super(CacheBackend, self).__init__(self, **kwargs)
+        super(CacheBackend, self).__init__(**kwargs)
 
         self.options = dict(self.app.conf.CELERY_CACHE_BACKEND_OPTIONS,
                             **options)

+ 2 - 2
celery/backends/cassandra.py

@@ -23,12 +23,12 @@ from celery.exceptions import ImproperlyConfigured
 from celery.utils.log import get_logger
 from celery.utils.timeutils import maybe_timedelta, timedelta_seconds
 
-from .base import BaseDictBackend
+from .base import BaseBackend
 
 logger = get_logger(__name__)
 
 
-class CassandraBackend(BaseDictBackend):
+class CassandraBackend(BaseBackend):
     """Highly fault tolerant Cassandra backend.
 
     .. attribute:: servers

+ 2 - 2
celery/backends/database/__init__.py

@@ -14,7 +14,7 @@ from celery import states
 from celery.exceptions import ImproperlyConfigured
 from celery.utils.timeutils import maybe_timedelta
 
-from celery.backends.base import BaseDictBackend
+from celery.backends.base import BaseBackend
 
 from .models import Task, TaskSet
 from .session import ResultSession
@@ -49,7 +49,7 @@ def retry(fun):
     return _inner
 
 
-class DatabaseBackend(BaseDictBackend):
+class DatabaseBackend(BaseBackend):
     """The database result backend."""
     # ResultSet.iterate should sleep this much between each pool,
     # to not bombard the database with queries.

+ 2 - 2
celery/backends/mongodb.py

@@ -29,7 +29,7 @@ from celery import states
 from celery.exceptions import ImproperlyConfigured
 from celery.utils.timeutils import maybe_timedelta
 
-from .base import BaseDictBackend
+from .base import BaseBackend
 
 
 class Bunch(object):
@@ -38,7 +38,7 @@ class Bunch(object):
         self.__dict__.update(kw)
 
 
-class MongoBackend(BaseDictBackend):
+class MongoBackend(BaseBackend):
     mongodb_host = 'localhost'
     mongodb_port = 27017
     mongodb_user = None

+ 12 - 53
celery/tests/backends/test_base.py

@@ -16,8 +16,11 @@ from celery.utils.serialization import UnpickleableExceptionWrapper
 from celery.utils.serialization import get_pickleable_exception as gpe
 
 from celery import states
-from celery.backends.base import BaseBackend, KeyValueStoreBackend
-from celery.backends.base import BaseDictBackend, DisabledBackend
+from celery.backends.base import (
+    BaseBackend,
+    KeyValueStoreBackend,
+    DisabledBackend,
+)
 from celery.utils import uuid
 
 from celery.tests.utils import Case
@@ -48,53 +51,9 @@ class test_serialization(Case):
 
 class test_BaseBackend_interface(Case):
 
-    def test_get_status(self):
-        with self.assertRaises(NotImplementedError):
-            b.get_status('SOMExx-N0Nex1stant-IDxx-')
-
     def test__forget(self):
         with self.assertRaises(NotImplementedError):
-            b.forget('SOMExx-N0Nex1stant-IDxx-')
-
-    def test_get_children(self):
-        with self.assertRaises(NotImplementedError):
-            b.get_children('SOMExx-N0Nex1stant-IDxx-')
-
-    def test_store_result(self):
-        with self.assertRaises(NotImplementedError):
-            b.store_result('SOMExx-N0nex1stant-IDxx-', 42, states.SUCCESS)
-
-    def test_mark_as_started(self):
-        with self.assertRaises(NotImplementedError):
-            b.mark_as_started('SOMExx-N0nex1stant-IDxx-')
-
-    def test_reload_task_result(self):
-        with self.assertRaises(NotImplementedError):
-            b.reload_task_result('SOMExx-N0nex1stant-IDxx-')
-
-    def test_reload_group_result(self):
-        with self.assertRaises(NotImplementedError):
-            b.reload_group_result('SOMExx-N0nex1stant-IDxx-')
-
-    def test_get_result(self):
-        with self.assertRaises(NotImplementedError):
-            b.get_result('SOMExx-N0nex1stant-IDxx-')
-
-    def test_restore_group(self):
-        with self.assertRaises(NotImplementedError):
-            b.restore_group('SOMExx-N0nex1stant-IDxx-')
-
-    def test_delete_group(self):
-        with self.assertRaises(NotImplementedError):
-            b.delete_group('SOMExx-N0nex1stant-IDxx-')
-
-    def test_save_group(self):
-        with self.assertRaises(NotImplementedError):
-            b.save_group('SOMExx-N0nex1stant-IDxx-', 'blergh')
-
-    def test_get_traceback(self):
-        with self.assertRaises(NotImplementedError):
-            b.get_traceback('SOMExx-N0nex1stant-IDxx-')
+            b._forget('SOMExx-N0Nex1stant-IDxx-')
 
     def test_forget(self):
         with self.assertRaises(NotImplementedError):
@@ -163,7 +122,7 @@ class KVBackend(KeyValueStoreBackend):
 
     def __init__(self, *args, **kwargs):
         self.db = {}
-        super(KVBackend, self).__init__(KeyValueStoreBackend)
+        super(KVBackend, self).__init__()
 
     def get(self, key):
         return self.db.get(key)
@@ -181,10 +140,10 @@ class KVBackend(KeyValueStoreBackend):
         self.db.pop(key, None)
 
 
-class DictBackend(BaseDictBackend):
+class DictBackend(BaseBackend):
 
     def __init__(self, *args, **kwargs):
-        BaseDictBackend.__init__(self, *args, **kwargs)
+        BaseBackend.__init__(self, *args, **kwargs)
         self._data = {'can-delete': {'result': 'foo'}}
 
     def _restore_group(self, group_id):
@@ -199,7 +158,7 @@ class DictBackend(BaseDictBackend):
         self._data.pop(group_id, None)
 
 
-class test_BaseDictBackend(Case):
+class test_BaseBackend_dict(Case):
 
     def setUp(self):
         self.b = DictBackend()
@@ -217,13 +176,13 @@ class test_BaseDictBackend(Case):
         self.assertEqual(str(e), "'foo'")
 
     def test_save_group(self):
-        b = BaseDictBackend()
+        b = BaseBackend()
         b._save_group = Mock()
         b.save_group('foofoo', 'xxx')
         b._save_group.assert_called_with('foofoo', 'xxx')
 
     def test_forget_interface(self):
-        b = BaseDictBackend()
+        b = BaseBackend()
         with self.assertRaises(NotImplementedError):
             b.forget('foo')