sets.py 2.4 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970
  1. # -*- coding: utf-8 -*-
  2. from __future__ import absolute_import
  3. from __future__ import with_statement
  4. from celery.app import app_or_default
  5. from celery.app.state import get_current_task
  6. from celery.canvas import subtask, maybe_subtask # noqa
  7. from celery.utils import uuid
  8. from celery.utils.compat import UserList
  9. class TaskSet(UserList):
  10. """A task containing several subtasks, making it possible
  11. to track how many, or when all of the tasks have been completed.
  12. :param tasks: A list of :class:`subtask` instances.
  13. Example::
  14. >>> urls = ("http://cnn.com/rss", "http://bbc.co.uk/rss")
  15. >>> s = TaskSet(refresh_feed.s(url) for url in urls)
  16. >>> taskset_result = s.apply_async()
  17. >>> list_of_return_values = taskset_result.join() # *expensive*
  18. """
  19. app = None
  20. def __init__(self, tasks=None, app=None, Publisher=None):
  21. self.app = app_or_default(app or self.app)
  22. self.data = [maybe_subtask(t) for t in tasks or []]
  23. self.total = len(self.tasks)
  24. self.Publisher = Publisher or self.app.amqp.TaskProducer
  25. def apply_async(self, connection=None, connect_timeout=None,
  26. publisher=None, taskset_id=None):
  27. """Apply TaskSet."""
  28. app = self.app
  29. if app.conf.CELERY_ALWAYS_EAGER:
  30. return self.apply(taskset_id=taskset_id)
  31. with app.default_connection(connection, connect_timeout) as conn:
  32. setid = taskset_id or uuid()
  33. pub = publisher or self.Publisher(conn)
  34. results = self._async_results(setid, pub)
  35. result = app.TaskSetResult(setid, results)
  36. parent = get_current_task()
  37. if parent:
  38. parent.request.children.append(result)
  39. return result
  40. def _async_results(self, taskset_id, publisher):
  41. return [task.apply_async(taskset_id=taskset_id, publisher=publisher)
  42. for task in self.tasks]
  43. def apply(self, taskset_id=None):
  44. """Applies the TaskSet locally by blocking until all tasks return."""
  45. setid = taskset_id or uuid()
  46. return self.app.TaskSetResult(setid, self._sync_results(setid))
  47. def _sync_results(self, taskset_id):
  48. return [task.apply(taskset_id=taskset_id) for task in self.tasks]
  49. def _get_tasks(self):
  50. return self.data
  51. def _set_tasks(self, tasks):
  52. self.data = tasks
  53. tasks = property(_get_tasks, _set_tasks)