task.py 9.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301
  1. from carrot.connection import DjangoAMQPConnection
  2. from celery.log import setup_logger
  3. from celery.conf import TASK_META_USE_DB
  4. from celery.registry import tasks
  5. from celery.messaging import TaskPublisher, TaskConsumer
  6. from celery.models import TaskMeta
  7. from django.core.cache import cache
  8. from datetime import timedelta
  9. from celery.backends import default_backend
  10. import uuid
  11. import traceback
  12. class BasePendingResult(object):
  13. """Base class for pending result, takes ``backend`` argument."""
  14. def __init__(self, task_id, backend):
  15. self.task_id = task_id
  16. self.backend = backend
  17. def __str__(self):
  18. return self.task_id
  19. def __repr__(self):
  20. return "<Job: %s>" % self.task_id
  21. def is_done(self):
  22. return self.backend.is_done(self.task_id)
  23. def wait_for(self):
  24. return self.backend.wait_for(self.task_id)
  25. @property
  26. def result(self):
  27. if self.status == "DONE":
  28. return self.backend.get_result(self.task_id)
  29. return None
  30. @property
  31. def status(self):
  32. return self.backend.get_status(self.task_id)
  33. class PendingResult(BasePendingResult):
  34. """Pending task result using the default backend."""
  35. def __init__(self, task_id):
  36. super(PendingResult, self).__init__(task_id, backend=default_backend)
  37. def delay_task(task_name, *args, **kwargs):
  38. """Delay a task for execution by the ``celery`` daemon.
  39. Examples
  40. --------
  41. >>> delay_task("update_record", name="George Constanza", age=32)
  42. """
  43. if task_name not in tasks:
  44. raise tasks.NotRegistered(
  45. "Task with name %s not registered in the task registry." % (
  46. task_name))
  47. publisher = TaskPublisher(connection=DjangoAMQPConnection())
  48. task_id = publisher.delay_task(task_name, *args, **kwargs)
  49. publisher.close()
  50. return PendingResult(task_id)
  51. def discard_all():
  52. """Discard all waiting tasks.
  53. This will ignore all tasks waiting for execution, and they will
  54. be deleted from the messaging server.
  55. Returns the number of tasks discarded.
  56. """
  57. consumer = TaskConsumer(connection=DjangoAMQPConnection())
  58. discarded_count = consumer.discard_all()
  59. consumer.close()
  60. return discarded_count
  61. def mark_as_done(task_id, result):
  62. """Mark task as done (executed)."""
  63. return default_backend.mark_as_done(task_id, result)
  64. def is_done(task_id):
  65. """Returns ``True`` if task with ``task_id`` has been executed."""
  66. return default_backend.is_done(task_id)
  67. class Task(object):
  68. """A task that can be delayed for execution by the ``celery`` daemon.
  69. All subclasses of ``Task`` has to define the ``name`` attribute, which is
  70. the name of the task that can be passed to ``celery.task.delay_task``,
  71. it also has to define the ``run`` method, which is the actual method the
  72. ``celery`` daemon executes.
  73. Examples
  74. --------
  75. This is a simple task just logging a message,
  76. >>> from celery.task import tasks, Task
  77. >>> class MyTask(Task):
  78. ... name = "mytask"
  79. ...
  80. ... def run(self, some_arg=None, **kwargs):
  81. ... logger = self.get_logger(**kwargs)
  82. ... logger.info("Running MyTask with arg some_arg=%s" %
  83. ... some_arg))
  84. ... tasks.register(MyTask)
  85. You can delay the task using the classmethod ``delay``...
  86. >>> MyTask.delay(some_arg="foo")
  87. ...or using the ``celery.task.delay_task`` function, by passing the
  88. name of the task.
  89. >>> from celery.task import delay_task
  90. >>> delay_task(MyTask.name, some_arg="foo")
  91. """
  92. name = None
  93. type = "regular"
  94. max_retries = 0 # unlimited
  95. retry_interval = timedelta(seconds=2)
  96. auto_retry = False
  97. def __init__(self):
  98. if not self.name:
  99. raise NotImplementedError("Tasks must define a name attribute.")
  100. def __call__(self, *args, **kwargs):
  101. """The ``__call__`` is called when you do ``Task().run()`` and calls
  102. the ``run`` method. It also catches any exceptions and logs them."""
  103. try:
  104. retval = self.run(*args, **kwargs)
  105. except Exception, e:
  106. logger = self.get_logger(**kwargs)
  107. logger.critical("Task got exception %s: %s\n%s" % (
  108. e.__class__, e, traceback.format_exc()))
  109. self.handle_exception(e, args, kwargs)
  110. if self.auto_retry:
  111. self.retry(kwargs["task_id"], args, kwargs)
  112. return
  113. else:
  114. return retval
  115. def run(self, *args, **kwargs):
  116. """The actual task. All subclasses of :class:`Task` must define
  117. the run method, if not a ``NotImplementedError`` exception is raised.
  118. """
  119. raise NotImplementedError("Tasks must define a run method.")
  120. def get_logger(self, **kwargs):
  121. """Get a process-aware logger object."""
  122. return setup_logger(**kwargs)
  123. def get_publisher(self):
  124. """Get a celery task message publisher."""
  125. return TaskPublisher(connection=DjangoAMQPConnection())
  126. def get_consumer(self):
  127. """Get a celery task message consumer."""
  128. return TaskConsumer(connection=DjangoAMQPConnection())
  129. def requeue(self, task_id, args, kwargs):
  130. self.get_publisher().requeue_task(self.name, task_id, args, kwargs)
  131. def retry(self, task_id, args, kwargs):
  132. retry_queue.put(self.name, task_id, args, kwargs)
  133. def handle_exception(self, exception, retry_args, retry_kwargs):
  134. pass
  135. @classmethod
  136. def delay(cls, *args, **kwargs):
  137. """Delay this task for execution by the ``celery`` daemon(s)."""
  138. return delay_task(cls.name, *args, **kwargs)
  139. class TaskSet(object):
  140. """A task containing several subtasks, making it possible
  141. to track how many, or when all of the tasks are completed.
  142. Example Usage
  143. --------------
  144. >>> from djangofeeds.tasks import RefreshFeedTask
  145. >>> taskset = TaskSet(RefreshFeedTask, args=[
  146. ... {"feed_url": "http://cnn.com/rss"},
  147. ... {"feed_url": "http://bbc.com/rss"},
  148. ... {"feed_url": "http://xkcd.com/rss"}])
  149. >>> taskset_id, subtask_ids = taskset.run()
  150. """
  151. def __init__(self, task, args):
  152. """``task`` can be either a fully qualified task name, or a task
  153. class, args is a list of arguments for the subtasks.
  154. """
  155. try:
  156. task_name = task.name
  157. except AttributeError:
  158. task_name = task
  159. self.task_name = task_name
  160. self.arguments = args
  161. self.total = len(args)
  162. def run(self):
  163. """Run all tasks in the taskset.
  164. Returns a tuple with the taskset id, and a list of subtask id's.
  165. Examples
  166. --------
  167. >>> ts = RefreshFeeds(["http://foo.com/rss", http://bar.com/rss"])
  168. >>> taskset_id, subtask_ids = ts.run()
  169. >>> taskset_id
  170. "d2c9b261-8eff-4bfb-8459-1e1b72063514"
  171. >>> subtask_ids
  172. ["b4996460-d959-49c8-aeb9-39c530dcde25",
  173. "598d2d18-ab86-45ca-8b4f-0779f5d6a3cb"]
  174. >>> time.sleep(10)
  175. >>> is_done(taskset_id)
  176. True
  177. """
  178. taskset_id = str(uuid.uuid4())
  179. publisher = TaskPublisher(connection=DjangoAMQPConnection())
  180. subtask_ids = []
  181. for arg in self.arguments:
  182. subtask_id = publisher.delay_task_in_set(task_name=self.task_name,
  183. taskset_id=taskset_id,
  184. task_args=[],
  185. task_kwargs=arg)
  186. subtask_ids.append(subtask_id)
  187. publisher.close()
  188. return taskset_id, subtask_ids
  189. class PeriodicTask(Task):
  190. """A periodic task is a task that behaves like a cron job.
  191. The ``run_every`` attribute defines how often the task is run (its
  192. interval), it can be either a ``datetime.timedelta`` object or a integer
  193. specifying the time in seconds.
  194. You have to register the periodic task in the task registry.
  195. Examples
  196. --------
  197. >>> from celery.task import tasks, PeriodicTask
  198. >>> from datetime import timedelta
  199. >>> class MyPeriodicTask(PeriodicTask):
  200. ... name = "my_periodic_task"
  201. ... run_every = timedelta(seconds=30)
  202. ...
  203. ... def run(self, **kwargs):
  204. ... logger = self.get_logger(**kwargs)
  205. ... logger.info("Running MyPeriodicTask")
  206. >>> tasks.register(MyPeriodicTask)
  207. """
  208. run_every = timedelta(days=1)
  209. type = "periodic"
  210. def __init__(self):
  211. if not self.run_every:
  212. raise NotImplementedError(
  213. "Periodic tasks must have a run_every attribute")
  214. # If run_every is a integer, convert it to timedelta seconds.
  215. if isinstance(self.run_every, int):
  216. self.run_every = timedelta(seconds=self.run_every)
  217. super(PeriodicTask, self).__init__()
  218. class DeleteExpiredTaskMetaTask(PeriodicTask):
  219. """A periodic task that deletes expired task metadata every day.
  220. It's only registered if ``settings.CELERY_TASK_META_USE_DB`` is set.
  221. """
  222. name = "celery.delete_expired_task_meta"
  223. run_every = timedelta(days=1)
  224. def run(self, **kwargs):
  225. logger = self.get_logger(**kwargs)
  226. logger.info("Deleting expired task meta objects...")
  227. default_backend.cleanup()
  228. tasks.register(DeleteExpiredTaskMetaTask)