# task.py
import pickle
import time
import traceback
import uuid
from datetime import timedelta

from carrot.connection import DjangoAMQPConnection
from django.core.cache import cache

from celery.backends import default_backend
from celery.backends.base import TimeoutError
from celery.conf import TASK_META_USE_DB
from celery.datastructures import PositionQueue
from celery.log import setup_logger
from celery.messaging import TaskPublisher, TaskConsumer
from celery.models import TaskMeta
from celery.registry import tasks
  15. class BaseAsyncResult(object):
  16. """Base class for pending result, takes ``backend`` argument."""
  17. def __init__(self, task_id, backend):
  18. self.task_id = task_id
  19. self.backend = backend
  20. def is_done(self):
  21. """Returns ``True`` if the task executed successfully."""
  22. return self.backend.is_done(self.task_id)
  23. def get(self):
  24. """Alias to ``wait``."""
  25. return self.wait()
  26. def wait(self, timeout=None):
  27. """Return the result when it arrives.
  28. If timeout is not ``None`` and the result does not arrive within
  29. ``timeout`` seconds then ``celery.backends.base.TimeoutError`` is
  30. raised. If the remote call raised an exception then that exception
  31. will be reraised by get()."""
  32. return self.backend.wait_for(self.task_id, timeout=timeout)
  33. def ready(self):
  34. """Returns ``True`` if the task executed successfully, or raised
  35. an exception. If the task is still pending, or is waiting for retry
  36. then ``False`` is returned."""
  37. status = self.backend.get_status(self.task_id)
  38. return status != "PENDING" or status != "RETRY"
  39. def successful(self):
  40. """Alias to ``is_done``."""
  41. return self.is_done()
  42. def __str__(self):
  43. """str(self) -> self.task_id"""
  44. return self.task_id
  45. def __repr__(self):
  46. return "<AsyncResult: %s>" % self.task_id
  47. @property
  48. def result(self):
  49. """The tasks resulting value."""
  50. if self.status == "DONE" or self.status == "FAILURE":
  51. return self.backend.get_result(self.task_id)
  52. return None
  53. @property
  54. def status(self):
  55. """The current status of the task."""
  56. return self.backend.get_status(self.task_id)
  57. class AsyncResult(BaseAsyncResult):
  58. """Pending task result using the default backend."""
  59. def __init__(self, task_id):
  60. super(AsyncResult, self).__init__(task_id, backend=default_backend)
  61. def delay_task(task_name, *args, **kwargs):
  62. """Delay a task for execution by the ``celery`` daemon.
  63. Examples
  64. --------
  65. >>> delay_task("update_record", name="George Constanza", age=32)
  66. """
  67. if task_name not in tasks:
  68. raise tasks.NotRegistered(
  69. "Task with name %s not registered in the task registry." % (
  70. task_name))
  71. publisher = TaskPublisher(connection=DjangoAMQPConnection())
  72. task_id = publisher.delay_task(task_name, *args, **kwargs)
  73. publisher.close()
  74. return AsyncResult(task_id)
  75. def discard_all():
  76. """Discard all waiting tasks.
  77. This will ignore all tasks waiting for execution, and they will
  78. be deleted from the messaging server.
  79. Returns the number of tasks discarded.
  80. """
  81. consumer = TaskConsumer(connection=DjangoAMQPConnection())
  82. discarded_count = consumer.discard_all()
  83. consumer.close()
  84. return discarded_count
  85. def mark_as_done(task_id, result):
  86. """Mark task as done (executed)."""
  87. return default_backend.mark_as_done(task_id, result)
def mark_as_failure(task_id, exc):
    """Mark task as failed, storing ``exc`` as the result."""
    return default_backend.mark_as_failure(task_id, exc)
  91. def is_done(task_id):
  92. """Returns ``True`` if task with ``task_id`` has been executed."""
  93. return default_backend.is_done(task_id)
class Task(object):
    """A task that can be delayed for execution by the ``celery`` daemon.

    All subclasses of ``Task`` has to define the ``name`` attribute, which is
    the name of the task that can be passed to ``celery.task.delay_task``,
    it also has to define the ``run`` method, which is the actual method the
    ``celery`` daemon executes.

    Examples
    --------

    This is a simple task just logging a message,

    >>> from celery.task import tasks, Task
    >>> class MyTask(Task):
    ...     name = "mytask"
    ...
    ...     def run(self, some_arg=None, **kwargs):
    ...         logger = self.get_logger(**kwargs)
    ...         logger.info("Running MyTask with arg some_arg=%s" %
    ...                     some_arg)
    ...         return 42
    ... tasks.register(MyTask)

    You can delay the task using the classmethod ``delay``...

    >>> result = MyTask.delay(some_arg="foo")
    >>> result.status # after some time
    'DONE'
    >>> result.result
    42

    ...or using the ``celery.task.delay_task`` function, by passing the
    name of the task.

    >>> from celery.task import delay_task
    >>> delay_task(MyTask.name, some_arg="foo")
    """
    # Registry name of the task; subclasses must set this.
    name = None
    # Task kind; periodic tasks override this with "periodic".
    type = "regular"
    max_retries = 0 # unlimited
    # Delay between retry attempts.
    retry_interval = timedelta(seconds=2)
    # Whether failed runs should be retried automatically.
    auto_retry = False

    def __init__(self):
        # Fail fast at instantiation time if a subclass forgot ``name``.
        if not self.name:
            raise NotImplementedError("Tasks must define a name attribute.")

    def __call__(self, *args, **kwargs):
        """The ``__call__`` is called when you do ``Task().run()`` and calls
        the ``run`` method. It also catches any exceptions and logs them."""
        return self.run(*args, **kwargs)

    def run(self, *args, **kwargs):
        """The actual task. All subclasses of :class:`Task` must define
        the run method, if not a ``NotImplementedError`` exception is raised.
        """
        raise NotImplementedError("Tasks must define a run method.")

    def get_logger(self, **kwargs):
        """Get a process-aware logger object."""
        return setup_logger(**kwargs)

    def get_publisher(self):
        """Get a celery task message publisher."""
        return TaskPublisher(connection=DjangoAMQPConnection())

    def get_consumer(self):
        """Get a celery task message consumer."""
        return TaskConsumer(connection=DjangoAMQPConnection())

    def requeue(self, task_id, args, kwargs):
        """Re-publish this task under the same ``task_id`` with the given
        arguments."""
        self.get_publisher().requeue_task(self.name, task_id, args, kwargs)

    def retry(self, task_id, args, kwargs):
        # NOTE(review): ``retry_queue`` is not defined anywhere in this
        # module, so calling retry() raises NameError as written — confirm
        # where the retry queue is supposed to come from.
        retry_queue.put(self.name, task_id, args, kwargs)

    @classmethod
    def delay(cls, *args, **kwargs):
        """Delay this task for execution by the ``celery`` daemon(s)."""
        return delay_task(cls.name, *args, **kwargs)
  158. class TaskSet(object):
  159. """A task containing several subtasks, making it possible
  160. to track how many, or when all of the tasks are completed.
  161. Example Usage
  162. --------------
  163. >>> from djangofeeds.tasks import RefreshFeedTask
  164. >>> taskset = TaskSet(RefreshFeedTask, args=[
  165. ... {"feed_url": "http://cnn.com/rss"},
  166. ... {"feed_url": "http://bbc.com/rss"},
  167. ... {"feed_url": "http://xkcd.com/rss"}])
  168. >>> taskset_id, subtask_ids = taskset.run()
  169. """
  170. def __init__(self, task, args):
  171. """``task`` can be either a fully qualified task name, or a task
  172. class, args is a list of arguments for the subtasks.
  173. """
  174. try:
  175. task_name = task.name
  176. except AttributeError:
  177. task_name = task
  178. self.task_name = task_name
  179. self.arguments = args
  180. self.total = len(args)
  181. def run(self):
  182. """Run all tasks in the taskset.
  183. Returns a tuple with the taskset id, and a list of subtask id's.
  184. Examples
  185. --------
  186. >>> ts = RefreshFeeds([
  187. ... ["http://foo.com/rss", {}],
  188. ... ["http://bar.com/rss", {}],
  189. ... )
  190. >>> taskset_id, subtask_ids = ts.run()
  191. >>> taskset_id
  192. "d2c9b261-8eff-4bfb-8459-1e1b72063514"
  193. >>> subtask_ids
  194. ["b4996460-d959-49c8-aeb9-39c530dcde25",
  195. "598d2d18-ab86-45ca-8b4f-0779f5d6a3cb"]
  196. >>> time.sleep(10)
  197. >>> is_done(taskset_id)
  198. True
  199. """
  200. taskset_id = str(uuid.uuid4())
  201. publisher = TaskPublisher(connection=DjangoAMQPConnection())
  202. subtask_ids = []
  203. for arg, kwarg in self.arguments:
  204. subtask_id = publisher.delay_task_in_set(task_name=self.task_name,
  205. taskset_id=taskset_id,
  206. task_args=arg,
  207. task_kwargs=kwarg)
  208. subtask_ids.append(subtask_id)
  209. publisher.close()
  210. return taskset_id, subtask_ids
  211. def xget(self):
  212. taskset_id, subtask_ids = self.run()
  213. results = dict([(task_id, AsyncResult(task_id))
  214. for task_id in subtask_ids])
  215. while results:
  216. for pending_result in results:
  217. if pending_result.status == "DONE":
  218. yield pending_result.result
  219. elif pending_result.status == "FAILURE":
  220. raise pending_result.result
  221. def join(self, timeout=None):
  222. time_start = time.time()
  223. taskset_id, subtask_ids = self.run()
  224. pending_results = map(AsyncResult, subtask_ids)
  225. results = PositionQueue(length=len(subtask_ids))
  226. while True:
  227. for i, pending_result in enumerate(pending_results):
  228. if pending_result.status == "DONE":
  229. results[i] = pending_result.result
  230. elif pending_result.status == "FAILURE":
  231. raise pending_result.result
  232. if results.is_full():
  233. return list(results)
  234. if timeout and time.time() > time_start + timeout:
  235. raise TimeOutError("The map operation timed out.")
  236. @classmethod
  237. def remote_execute(cls, func, args):
  238. pickled = pickle.dumps(func)
  239. arguments = [[[pickled, arg, {}], {}] for arg in args]
  240. return cls(ExecuteRemoteTask, arguments)
  241. @classmethod
  242. def map(cls, func, args, timeout=None):
  243. remote_task = cls.remote_execute(func, args)
  244. return remote_task.join(timeout=timeout)
  245. @classmethod
  246. def map_async(cls, func, args, timeout=None):
  247. serfunc = pickle.dumps(func)
  248. return AsynchronousMapTask.delay(serfunc, args, timeout=timeout)
  249. def dmap(func, args, timeout=None):
  250. """Distribute processing of the arguments and collect the results.
  251. Example
  252. --------
  253. >>> from celery.task import map
  254. >>> import operator
  255. >>> dmap(operator.add, [[2, 2], [4, 4], [8, 8]])
  256. [4, 8, 16]
  257. """
  258. return TaskSet.map(func, args, timeout=timeout)
  259. class AsynchronousMapTask(Task):
  260. name = "celery.map_async"
  261. def run(self, serfunc, args, **kwargs):
  262. timeout = kwargs.get("timeout")
  263. logger = self.get_logger(**kwargs)
  264. logger.info("<<<<<<< ASYNCMAP: %s(%s)" % (serfunc, args))
  265. return TaskSet.map(pickle.loads(serfunc), args, timeout=timeout)
  266. tasks.register(AsynchronousMapTask)
  267. def dmap_async(func, args, timeout=None):
  268. """Distribute processing of the arguments and collect the results
  269. asynchronously. Returns a :class:`AsyncResult` object.
  270. Example
  271. --------
  272. >>> from celery.task import dmap_async
  273. >>> import operator
  274. >>> presult = dmap_async(operator.add, [[2, 2], [4, 4], [8, 8]])
  275. >>> presult
  276. <AsyncResult: 373550e8-b9a0-4666-bc61-ace01fa4f91d>
  277. >>> presult.status
  278. 'DONE'
  279. >>> presult.result
  280. [4, 8, 16]
  281. """
  282. return TaskSet.map_async(func, args, timeout=timeout)
  283. class PeriodicTask(Task):
  284. """A periodic task is a task that behaves like a cron job.
  285. The ``run_every`` attribute defines how often the task is run (its
  286. interval), it can be either a ``datetime.timedelta`` object or a integer
  287. specifying the time in seconds.
  288. You have to register the periodic task in the task registry.
  289. Examples
  290. --------
  291. >>> from celery.task import tasks, PeriodicTask
  292. >>> from datetime import timedelta
  293. >>> class MyPeriodicTask(PeriodicTask):
  294. ... name = "my_periodic_task"
  295. ... run_every = timedelta(seconds=30)
  296. ...
  297. ... def run(self, **kwargs):
  298. ... logger = self.get_logger(**kwargs)
  299. ... logger.info("Running MyPeriodicTask")
  300. >>> tasks.register(MyPeriodicTask)
  301. """
  302. run_every = timedelta(days=1)
  303. type = "periodic"
  304. def __init__(self):
  305. if not self.run_every:
  306. raise NotImplementedError(
  307. "Periodic tasks must have a run_every attribute")
  308. # If run_every is a integer, convert it to timedelta seconds.
  309. if isinstance(self.run_every, int):
  310. self.run_every = timedelta(seconds=self.run_every)
  311. super(PeriodicTask, self).__init__()
  312. class DeleteExpiredTaskMetaTask(PeriodicTask):
  313. """A periodic task that deletes expired task metadata every day.
  314. It's only registered if ``settings.CELERY_TASK_META_USE_DB`` is set.
  315. """
  316. name = "celery.delete_expired_task_meta"
  317. run_every = timedelta(days=1)
  318. def run(self, **kwargs):
  319. logger = self.get_logger(**kwargs)
  320. logger.info("Deleting expired task meta objects...")
  321. default_backend.cleanup()
  322. tasks.register(DeleteExpiredTaskMetaTask)
  323. class ExecuteRemoteTask(Task):
  324. name = "celery.execute_remote"
  325. def run(self, ser_callable, fargs, fkwargs, **kwargs):
  326. callable_ = pickle.loads(ser_callable)
  327. return callable_(*fargs, **fkwargs)
  328. tasks.register(ExecuteRemoteTask)
  329. def execute_remote(func, *args, **kwargs):
  330. return ExecuteRemoteTask.delay(pickle.dumps(func), args, kwargs)
  331. class SumTask(Task):
  332. name = "celery.sum_task"
  333. def run(self, *numbers, **kwargs):
  334. return sum(numbers)
  335. tasks.register(SumTask)