Pārlūkot izejas kodu

Refactored celery.managers

Ask Solem 15 gadi atpakaļ
vecāks
revīzija
29620126ca
3 mainītis faili ar 94 papildinājumiem un 75 dzēšanām
  1. 94 72
      celery/managers.py
  2. 0 1
      celery/tests/test_task.py
  3. 0 2
      celery/views.py

+ 94 - 72
celery/managers.py

@@ -1,15 +1,90 @@
-"""celery.managers"""
 from datetime import datetime
+from itertools import count
+
+from billiard.utils.functional import wraps
+
 from django.db import models
 from django.db import transaction
+from django.db.models.query import QuerySet
+
+
+def transaction_retry(max_retries=1):
+    """Decorator for methods doing database operations.
+
+    If the database operation fails, it will retry the operation
+    at most ``max_retries`` times.
+
+    """
+    def _outer(fun):
+
+        @wraps(fun)
+        def _inner(*args, **kwargs):
+            _max_retries = kwargs.pop("exception_retry_count", max_retries)
+            for retries in count(0):
+                try:
+                    return fun(*args, **kwargs)
+                except Exception: # pragma: no cover
+                    # Depending on the database backend used we can experience
+                    # various exceptions. E.g. psycopg2 raises an exception
+                    # if some operation breaks the transaction, so saving
+                    # the task result won't be possible until we rollback
+                    # the transaction.
+                    if retries >= _max_retries:
+                        raise
+                    transaction.rollback_unless_managed()
+
+        return _inner
+
+    return _outer
+
+
+def update_model_with_dict(obj, fields):
+    [setattr(obj, attr_name, attr_value)
+        for attr_name, attr_value in fields.items()]
+    obj.save()
+    return obj
+
+
+class ExtendedQuerySet(QuerySet):
+
+    def update_or_create(self, **kwargs):
+        obj, created = self.get_or_create(**kwargs)
+
+        if not created:
+            fields = dict(kwargs.pop("defaults", {}))
+            fields.update(kwargs)
+            update_model_with_dict(obj, fields)
+
+        return obj
+
+
+class ExtendedManager(models.Manager):
+
+    def get_query_set(self):
+        return ExtendedQuerySet(self.model)
+
+    def update_or_create(self, **kwargs):
+        return self.get_query_set().update_or_create(**kwargs)
+
+
+class ResultManager(ExtendedManager):
 
-from celery.conf import TASK_RESULT_EXPIRES
+    def get_all_expired(self):
+        """Get all expired task results."""
+        from celery import conf
+        expires = conf.TASK_RESULT_EXPIRES
+        return self.filter(date_done__lt=datetime.now() - expires)
+
+    def delete_expired(self):
+        """Delete all expired taskset results."""
+        self.get_all_expired().delete()
 
 
-class TaskManager(models.Manager):
+class TaskManager(ResultManager):
     """Manager for :class:`celery.models.Task` models."""
 
-    def get_task(self, task_id, exception_retry_count=1):
+    @transaction_retry(max_retries=1)
+    def get_task(self, task_id):
         """Get task meta for task by ``task_id``.
 
         :keyword exception_retry_count: How many times to retry by
@@ -18,32 +93,15 @@ class TaskManager(models.Manager):
             create the same task. The default is to retry once.
 
         """
-        try:
-            task, created = self.get_or_create(task_id=task_id)
-        except Exception: # pragma: no cover
-            # We don't have a map of the different exceptions backends can
-            # throw, so we have to catch everything.
-            if exception_retry_count > 0:
-                transaction.rollback_unless_managed()
-                return self.get_task(task_id, exception_retry_count - 1)
-            else:
-                raise
+        task, created = self.get_or_create(task_id=task_id)
         return task
 
     def is_successful(self, task_id):
         """Returns ``True`` if the task was executed successfully."""
         return self.get_task(task_id).status == "SUCCESS"
 
-    def get_all_expired(self):
-        """Get all expired task results."""
-        return self.filter(date_done__lt=datetime.now() - TASK_RESULT_EXPIRES)
-
-    def delete_expired(self):
-        """Delete all expired task results."""
-        self.get_all_expired().delete()
-
-    def store_result(self, task_id, result, status, traceback=None,
-            exception_retry_count=2):
+    @transaction_retry(max_retries=2)
+    def store_result(self, task_id, result, status, traceback=None):
         """Store the result and status of a task.
 
         :param task_id: task id
@@ -64,40 +122,17 @@ class TaskManager(models.Manager):
             create the same task. The default is to retry twice.
 
         """
-        try:
-            task, created = self.get_or_create(task_id=task_id, defaults={
-                                                "status": status,
-                                                "result": result,
-                                                "traceback": traceback})
-            if not created:
-                task.status = status
-                task.result = result
-                task.traceback = traceback
-                task.save()
-        except Exception: # pragma: no cover
-            # depending on the database backend we can get various exceptions.
-            # for excample, psycopg2 raises an exception if some operation
-            # breaks transaction, and saving task result won't be possible
-            # until we rollback transaction
-            if exception_retry_count > 0:
-                transaction.rollback_unless_managed()
-                self.store_result(task_id, result, status, traceback,
-                                  exception_retry_count - 1)
-            else:
-                raise
-
-
-class TaskSetManager(models.Manager):
-    """Manager for :class:`celery.models.TaskSet` models."""
+        return self.update_or_create(task_id=task_id, defaults={
+                                        "status": status,
+                                        "result": result,
+                                        "traceback": traceback})
 
-    def get_all_expired(self):
-        """Get all expired taskset results."""
-        return self.filter(date_done__lt=datetime.now() - TASK_RESULT_EXPIRES)
 
-    def delete_expired(self):
-        """Delete all expired taskset results."""
-        self.get_all_expired().delete()
+class TaskSetManager(ResultManager):
+    """Manager for :class:`celery.models.TaskSet` models."""
 
+
+    @transaction_retry(max_retries=1)
     def get_taskset(self, taskset_id):
         """Get taskset meta for task by ``taskset_id``."""
         try:
@@ -105,7 +140,8 @@ class TaskSetManager(models.Manager):
         except self.model.DoesNotExist:
             return None
 
-    def store_result(self, taskset_id, result, exception_retry=True):
+    @transaction_retry(max_retries=2)
+    def store_result(self, taskset_id, result):
         """Store the result of a taskset.
 
         :param taskset_id: task set id
@@ -113,19 +149,5 @@ class TaskSetManager(models.Manager):
         :param result: The return value of the taskset
 
         """
-        try:
-            taskset, created = self.get_or_create(taskset_id=taskset_id, defaults={
-                                                "result": result})
-            if not created:
-                taskset.result = result
-                taskset.save()
-        except Exception, exc:
-            # depending on the database backend we can get various exceptions.
-            # for excample, psycopg2 raises an exception if some operation
-            # breaks transaction, and saving task result won't be possible
-            # until we rollback transaction
-            if exception_retry:
-                transaction.rollback_unless_managed()
-                self.store_result(taskset_id, result, False)
-            else:
-                raise
+        return self.update_or_create(taskset_id=taskset_id,
+                                     defaults={"result": result})

+ 0 - 1
celery/tests/test_task.py

@@ -340,7 +340,6 @@ class TestTaskSet(unittest.TestCase):
         taskset_id = taskset_res.taskset_id
         for subtask in subtasks:
             m = consumer.fetch().payload
-            print("M: %s" % m)
             self.assertEquals(m.get("taskset"), taskset_id)
             self.assertEquals(m.get("task"), IncrementCounterTask.name)
             self.assertEquals(m.get("id"), subtask.task_id)

+ 0 - 2
celery/views.py

@@ -32,7 +32,6 @@ def task_view(task):
     return _applier
 
 
-
 def apply(request, task_name):
     """View applying a task.
 
@@ -73,7 +72,6 @@ def task_status(request, task_id):
             mimetype="application/json")
 
 
-
 def task_webhook(fun):
     """Decorator turning a function into a task webhook.