sets.py 2.9 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091
  1. # -*- coding: utf-8 -*-
  2. from __future__ import absolute_import
  3. from __future__ import with_statement
  4. from celery.app.state import get_current_task
  5. from celery.canvas import subtask, maybe_subtask # noqa
  6. from celery.utils import uuid
  7. from celery.utils.compat import UserList
  8. class TaskSet(UserList):
  9. """A task containing several subtasks, making it possible
  10. to track how many, or when all of the tasks have been completed.
  11. :param tasks: A list of :class:`subtask` instances.
  12. Example::
  13. >>> urls = ("http://cnn.com/rss", "http://bbc.co.uk/rss")
  14. >>> s = TaskSet(refresh_feed.s(url) for url in urls)
  15. >>> taskset_result = s.apply_async()
  16. >>> list_of_return_values = taskset_result.join() # *expensive*
  17. """
  18. _app = None
  19. def __init__(self, tasks=None, app=None, Publisher=None):
  20. self._app = app or self._app
  21. self.data = [maybe_subtask(t) for t in tasks or []]
  22. self._Publisher = Publisher
  23. def apply_async(self, connection=None, connect_timeout=None,
  24. publisher=None, taskset_id=None):
  25. """Apply TaskSet."""
  26. app = self.app
  27. if app.conf.CELERY_ALWAYS_EAGER:
  28. return self.apply(taskset_id=taskset_id)
  29. with app.default_connection(connection, connect_timeout) as conn:
  30. setid = taskset_id or uuid()
  31. pub = publisher or self.Publisher(connection=conn)
  32. try:
  33. results = self._async_results(setid, pub)
  34. finally:
  35. if not publisher: # created by us.
  36. pub.close()
  37. result = app.TaskSetResult(setid, results)
  38. parent = get_current_task()
  39. if parent:
  40. parent.request.children.append(result)
  41. return result
  42. def _async_results(self, taskset_id, publisher):
  43. return [task.apply_async(taskset_id=taskset_id, publisher=publisher)
  44. for task in self.tasks]
  45. def apply(self, taskset_id=None):
  46. """Applies the TaskSet locally by blocking until all tasks return."""
  47. setid = taskset_id or uuid()
  48. return self.app.TaskSetResult(setid, self._sync_results(setid))
  49. def _sync_results(self, taskset_id):
  50. return [task.apply(taskset_id=taskset_id) for task in self.tasks]
  51. @property
  52. def total(self):
  53. """Number of subtasks in this TaskSet."""
  54. return len(self)
  55. def _get_app(self):
  56. return self._app or self.data[0].type._get_app()
  57. def _set_app(self, app):
  58. self._app = app
  59. app = property(_get_app, _set_app)
  60. def _get_tasks(self):
  61. return self.data
  62. def _set_tasks(self, tasks):
  63. self.data = tasks
  64. tasks = property(_get_tasks, _set_tasks)
  65. def _get_Publisher(self):
  66. return self._Publisher or self.app.amqp.TaskPublisher
  67. def _set_Publisher(self, Publisher):
  68. self._Publisher = Publisher
  69. Publisher = property(_get_Publisher, _set_Publisher)