task.py 2.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100
  1. from carrot.connection import DjangoAMQPConnection
  2. from celery.log import setup_logger
  3. from celery.registry import tasks
  4. from celery.messaging import TaskPublisher, TaskConsumer
  5. from django.core.cache import cache
  6. from datetime import timedelta
  7. def delay_task(task_name, **kwargs):
  8. if task_name not in tasks:
  9. raise tasks.NotRegistered(
  10. "Task with name %s not registered in the task registry." % (
  11. task_name))
  12. publisher = TaskPublisher(connection=DjangoAMQPConnection)
  13. task_id = publisher.delay_task(task_name, **kwargs)
  14. publisher.close()
  15. return task_id
  16. def discard_all():
  17. consumer = TaskConsumer(connection=DjangoAMQPConnection)
  18. discarded_count = consumer.discard_all()
  19. consumer.close()
  20. return discarded_count
  21. def gen_task_done_cache_key(task_id):
  22. return "celery-task-done-marker-%s" % task_id
  23. def mark_as_done(task_id, result):
  24. if result is None:
  25. result = True
  26. cache_key = gen_task_done_cache_key(task_id)
  27. cache.set(cache_key, result)
  28. def is_done(task_id):
  29. cache_key = gen_task_done_cache_key(task_id)
  30. return cache.get(cache_key)
  31. class Task(object):
  32. name = None
  33. type = "regular"
  34. def __init__(self):
  35. if not self.name:
  36. raise NotImplementedError("Tasks must define a name attribute.")
  37. def __call__(self, **kwargs):
  38. return self.run(**kwargs)
  39. def run(self, **kwargs):
  40. raise NotImplementedError("Tasks must define a run method.")
  41. def get_logger(self, **kwargs):
  42. """Get a process-aware logger object."""
  43. return setup_logger(**kwargs)
  44. def get_publisher(self):
  45. """Get a celery task message publisher."""
  46. return TaskPublisher(connection=DjangoAMQPConnection)
  47. def get_consumer(self):
  48. """Get a celery task message consumer."""
  49. return TaskConsumer(connection=DjangoAMQPConnection)
  50. @classmethod
  51. def delay(cls, **kwargs):
  52. return delay_task(cls.name, **kwargs)
  53. class PeriodicTask(Task):
  54. run_every = timedelta(days=1)
  55. type = "periodic"
  56. def __init__(self):
  57. if not self.run_every:
  58. raise NotImplementedError(
  59. "Periodic tasks must have a run_every attribute")
  60. # If run_every is a integer, convert it to timedelta seconds.
  61. if isinstance(self.run_every, int):
  62. self.run_every = timedelta(seconds=self.run_every)
  63. super(PeriodicTask, self).__init__()
  64. class TestTask(Task):
  65. name = "celery-test-task"
  66. def run(self, some_arg, **kwargs):
  67. logger = self.get_logger(**kwargs)
  68. logger.info("TestTask got some_arg=%s" % some_arg)
  69. def after(self, task_id):
  70. logger = self.get_logger(**kwargs)
  71. logger.info("TestTask with id %s was successfully executed." % task_id)
  72. tasks.register(TestTask)