task.py 5.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174
  1. from carrot.connection import DjangoAMQPConnection
  2. from celery.log import setup_logger
  3. from celery.conf import TASK_META_USE_DB
  4. from celery.registry import tasks
  5. from celery.messaging import TaskPublisher, TaskConsumer
  6. from celery.models import TaskMeta
  7. from django.core.cache import cache
  8. from datetime import timedelta
  9. import uuid
  10. import traceback
  11. __all__ = ["delay_task", "discard_all", "gen_task_done_cache_key",
  12. "mark_as_done", "is_done", "Task", "PeriodicTask", "TestTask"]
  13. def delay_task(task_name, **kwargs):
  14. if task_name not in tasks:
  15. raise tasks.NotRegistered(
  16. "Task with name %s not registered in the task registry." % (
  17. task_name))
  18. publisher = TaskPublisher(connection=DjangoAMQPConnection)
  19. task_id = publisher.delay_task(task_name, **kwargs)
  20. publisher.close()
  21. return task_id
  22. def discard_all():
  23. consumer = TaskConsumer(connection=DjangoAMQPConnection)
  24. discarded_count = consumer.discard_all()
  25. consumer.close()
  26. return discarded_count
  27. def gen_task_done_cache_key(task_id):
  28. return "celery-task-done-marker-%s" % task_id
  29. def mark_as_done(task_id, result):
  30. if result is None:
  31. result = True
  32. if TASK_META_USE_DB:
  33. TaskMeta.objects.mark_as_done(task_id)
  34. else:
  35. cache_key = gen_task_done_cache_key(task_id)
  36. cache.set(cache_key, result)
  37. def is_done(task_id):
  38. if TASK_META_USE_DB:
  39. return TaskMeta.objects.is_done(task_id)
  40. else:
  41. cache_key = gen_task_done_cache_key(task_id)
  42. return cache.get(cache_key)
  43. class Task(object):
  44. name = None
  45. type = "regular"
  46. def __init__(self):
  47. if not self.name:
  48. raise NotImplementedError("Tasks must define a name attribute.")
  49. def __call__(self, **kwargs):
  50. try:
  51. retval = self.run(**kwargs)
  52. except Exception, e:
  53. logger = self.get_logger(**kwargs)
  54. logger.critical("Task got exception %s: %s\n%s" % (
  55. e.__class__, e, traceback.format_exc()))
  56. return
  57. else:
  58. return retval
  59. def run(self, **kwargs):
  60. raise NotImplementedError("Tasks must define a run method.")
  61. def get_logger(self, **kwargs):
  62. """Get a process-aware logger object."""
  63. return setup_logger(**kwargs)
  64. def get_publisher(self):
  65. """Get a celery task message publisher."""
  66. return TaskPublisher(connection=DjangoAMQPConnection)
  67. def get_consumer(self):
  68. """Get a celery task message consumer."""
  69. return TaskConsumer(connection=DjangoAMQPConnection)
  70. @classmethod
  71. def delay(cls, **kwargs):
  72. return delay_task(cls.name, **kwargs)
  73. class TaskSet(object):
  74. """A task containing several subtasks, making it possible
  75. to track how many, or when all of the tasks are completed.
  76. Example Usage
  77. --------------
  78. >>> from djangofeeds.tasks import RefreshFeedTask
  79. >>> taskset = TaskSet(RefreshFeedTask, args=[
  80. ... {"feed_url": "http://cnn.com/rss"},
  81. ... {"feed_url": "http://bbc.com/rss"},
  82. ... {"feed_url": "http://xkcd.com/rss"}])
  83. >>> taskset_id = taskset.delay()
  84. """
  85. def __init__(self, task, args):
  86. """``task`` can be either a fully qualified task name, or a task
  87. class, args is a list of arguments for the subtasks.
  88. """
  89. try:
  90. task_name = task.name
  91. except AttributeError:
  92. task_name = task
  93. self.task_name = task_name
  94. self.arguments = args
  95. self.total = len(args)
  96. def run(self):
  97. taskset_id = str(uuid.uuid4())
  98. publisher = TaskPublisher(connection=DjangoAMQPConnection)
  99. subtask_ids = []
  100. for arg in self.arguments:
  101. subtask_id = publisher.delay_task_in_set(task_name=self.task_name,
  102. taskset_id=taskset_id,
  103. task_kwargs=arg)
  104. subtask_ids.append(subtask_id)
  105. publisher.close()
  106. return taskset_id, subtask_ids
  107. class PeriodicTask(Task):
  108. run_every = timedelta(days=1)
  109. type = "periodic"
  110. def __init__(self):
  111. if not self.run_every:
  112. raise NotImplementedError(
  113. "Periodic tasks must have a run_every attribute")
  114. # If run_every is a integer, convert it to timedelta seconds.
  115. if isinstance(self.run_every, int):
  116. self.run_every = timedelta(seconds=self.run_every)
  117. super(PeriodicTask, self).__init__()
  118. class TestTask(Task):
  119. name = "celery.test_task"
  120. def run(self, some_arg, **kwargs):
  121. logger = self.get_logger(**kwargs)
  122. logger.info("TestTask got some_arg=%s" % some_arg)
  123. tasks.register(TestTask)
  124. class DeleteExpiredTaskMetaTask(PeriodicTask):
  125. name = "celery.delete_expired_task_meta"
  126. run_every = timedelta(days=1)
  127. def run(self, **kwargs):
  128. logger = self.get_logger(**kwargs)
  129. logger.info("Deleting expired task meta objects...")
  130. TaskMeta.objects.delete_expired()
  131. if TASK_META_USE_DB:
  132. tasks.register(DeleteExpiredTaskMetaTask)