123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149 |
- from datetime import datetime
- from itertools import count
- from billiard.utils.functional import wraps
- from django.db import models
- from django.db import transaction
- from django.db.models.query import QuerySet
def transaction_retry(max_retries=1):
    """Decorator factory that retries a database operation on failure.

    The wrapped callable is attempted once and then, on any exception,
    retried up to ``max_retries`` more times. Before each retry the current
    transaction is rolled back with ``transaction.rollback_unless_managed()``.
    When the retries are exhausted the exception is re-raised.

    Callers of the decorated function may override the retry limit per call
    with the ``exception_retry_count`` keyword argument. It is consumed here
    and is not forwarded to the wrapped function.

    :param max_retries: default number of retries after the first attempt.
    """
    def decorator(fun):

        @wraps(fun)
        def _retrying(*args, **kwargs):
            limit = kwargs.pop("exception_retry_count", max_retries)
            attempt = 0
            while True:
                try:
                    return fun(*args, **kwargs)
                except Exception:  # pragma: no cover
                    # Which exception surfaces depends on the DB backend.
                    # E.g. psycopg2 raises if an operation breaks the
                    # transaction, and nothing (such as saving a task
                    # result) can succeed until that transaction is rolled
                    # back.
                    if attempt >= limit:
                        raise
                    transaction.rollback_unless_managed()
                    attempt += 1

        return _retrying

    return decorator
def update_model_with_dict(obj, fields):
    """Assign every ``name -> value`` pair in ``fields`` onto ``obj`` and save it.

    :param obj: object to update; must provide a ``save()`` method
        (e.g. a model instance).
    :param fields: mapping of attribute names to their new values.
    :returns: ``obj`` itself, after ``obj.save()`` has been called.
    """
    # Use a plain loop: the assignments are side effects. The previous list
    # comprehension built and discarded a list of ``None`` values.
    for attr_name, attr_value in fields.items():
        setattr(obj, attr_name, attr_value)
    obj.save()
    return obj
class ExtendedQuerySet(QuerySet):
    """QuerySet adding an ``update_or_create`` helper."""

    def update_or_create(self, **kwargs):
        """Get or create the object matching ``kwargs``, updating it if it existed.

        All of ``kwargs`` (including any ``defaults`` entry) is forwarded
        unchanged to ``get_or_create``. If an existing object was returned
        rather than a new one, the ``defaults`` values and the remaining
        keyword arguments are assigned onto it and it is saved. Lookup
        keywords win over ``defaults`` on key collisions.

        :returns: the model instance; the ``created`` flag is not returned.
        """
        obj, created = self.get_or_create(**kwargs)
        if created:
            return obj

        updates = dict(kwargs.pop("defaults", {}))
        updates.update(kwargs)
        update_model_with_dict(obj, updates)
        return obj
class ExtendedManager(models.Manager):
    """Manager whose querysets are :class:`ExtendedQuerySet` instances."""

    def get_query_set(self):
        """Return a new :class:`ExtendedQuerySet` for this manager's model."""
        # NOTE(review): only the model is passed; no database alias is
        # forwarded -- confirm this is intended if multiple DBs are used.
        return ExtendedQuerySet(self.model)

    def update_or_create(self, **kwargs):
        """Delegate to :meth:`ExtendedQuerySet.update_or_create`."""
        queryset = self.get_query_set()
        return queryset.update_or_create(**kwargs)
class ResultManager(ExtendedManager):
    """Base manager for result models that carry a ``date_done`` field."""

    def get_all_expired(self):
        """Get all expired results.

        A result is expired when its ``date_done`` is earlier than
        ``datetime.now()`` minus ``celery.conf.TASK_RESULT_EXPIRES``.
        """
        # Imported inside the method rather than at module level --
        # presumably to avoid an import-time dependency/cycle; TODO confirm.
        from celery import conf
        # NOTE(review): assumes TASK_RESULT_EXPIRES is a timedelta so it can
        # be subtracted from a datetime -- confirm in celery.conf.
        expires = conf.TASK_RESULT_EXPIRES
        return self.filter(date_done__lt=datetime.now() - expires)

    def delete_expired(self):
        """Delete all expired results handled by this manager.

        This base manager is shared by the task and taskset result
        managers. What gets deleted is whatever :meth:`get_all_expired`
        selects for this manager's model, not only taskset results.
        """
        self.get_all_expired().delete()
class TaskManager(ResultManager):
    """Manager for :class:`celery.models.Task` models."""

    @transaction_retry(max_retries=1)
    def get_task(self, task_id):
        """Return the task meta for ``task_id``, creating it if it does not exist.

        :param task_id: id of the task to look up.
        :keyword exception_retry_count: How many times to retry by
            transaction rollback on exception. This could theoretically
            happen in a race condition if another worker is trying to
            create the same task. The default is to retry once.
        """
        # get_or_create returns (instance, created); only the instance is needed.
        return self.get_or_create(task_id=task_id)[0]

    @transaction_retry(max_retries=2)
    def store_result(self, task_id, result, status, traceback=None):
        """Store the result and status of a task, creating or updating its row.

        :param task_id: task id
        :param result: The return value of the task, or an exception
            instance raised by the task.
        :param status: Task status. See
            :meth:`celery.result.AsyncResult.get_status` for a list of
            possible status values.
        :keyword traceback: The traceback at the point of exception (if the
            task failed).
        :keyword exception_retry_count: How many times to retry by
            transaction rollback on exception. This could theoretically
            happen in a race condition if another worker is trying to
            create the same task. The default is to retry twice.
        :returns: the stored task instance.
        """
        meta = {"status": status,
                "result": result,
                "traceback": traceback}
        return self.update_or_create(task_id=task_id, defaults=meta)
class TaskSetManager(ResultManager):
    """Manager for :class:`celery.models.TaskSet` models."""

    @transaction_retry(max_retries=1)
    def restore_taskset(self, taskset_id):
        """Return the taskset meta for ``taskset_id``, or ``None`` if there is none.

        :param taskset_id: id of the taskset to look up.
        """
        try:
            found = self.get(taskset_id=taskset_id)
        except self.model.DoesNotExist:
            found = None
        return found

    @transaction_retry(max_retries=2)
    def store_result(self, taskset_id, result):
        """Store the result of a taskset, creating or updating its row.

        :param taskset_id: task set id
        :param result: The return value of the taskset
        :returns: the stored taskset instance.
        """
        meta = {"result": result}
        return self.update_or_create(taskset_id=taskset_id, defaults=meta)
|