strategy.py 4.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139
  1. # -*- coding: utf-8 -*-
  2. """
  3. celery.worker.strategy
  4. ~~~~~~~~~~~~~~~~~~~~~~
  5. Task execution strategy (optimization).
  6. """
  7. from __future__ import absolute_import, unicode_literals
  8. import logging
  9. from kombu.async.timer import to_timestamp
  10. from kombu.five import buffer_t
  11. from celery.exceptions import InvalidTaskError
  12. from celery.utils.log import get_logger
  13. from celery.utils.saferepr import saferepr
  14. from celery.utils.timeutils import timezone
  15. from .request import Request, create_request_cls
  16. from .state import task_reserved
# Public API of this module: only the ``default`` strategy factory.
__all__ = ['default']
# Module-level logger; the strategy code in this file uses its ``info`` /
# ``error`` methods and ``isEnabledFor`` level check.
logger = get_logger(__name__)
  19. def proto1_to_proto2(message, body):
  20. """Converts Task message protocol 1 arguments to protocol 2.
  21. Returns tuple of ``(body, headers, already_decoded_status, utc)``
  22. """
  23. try:
  24. args, kwargs = body['args'], body['kwargs']
  25. kwargs.items
  26. except KeyError:
  27. raise InvalidTaskError('Message does not have args/kwargs')
  28. except AttributeError:
  29. raise InvalidTaskError(
  30. 'Task keyword arguments must be a mapping',
  31. )
  32. body.update(
  33. argsrepr=saferepr(args),
  34. kwargsrepr=saferepr(kwargs),
  35. headers=message.headers,
  36. )
  37. try:
  38. body['group'] = body['taskset']
  39. except KeyError:
  40. pass
  41. embed = {
  42. 'callbacks': body.get('callbacks'),
  43. 'errbacks': body.get('errbacks'),
  44. 'chord': body.get('chord'),
  45. 'chain': None,
  46. }
  47. return (args, kwargs, embed), body, True, body.get('utc', True)
  48. def default(task, app, consumer,
  49. info=logger.info, error=logger.error, task_reserved=task_reserved,
  50. to_system_tz=timezone.to_system, bytes=bytes, buffer_t=buffer_t,
  51. proto1_to_proto2=proto1_to_proto2):
  52. hostname = consumer.hostname
  53. eventer = consumer.event_dispatcher
  54. connection_errors = consumer.connection_errors
  55. _does_info = logger.isEnabledFor(logging.INFO)
  56. events = eventer and eventer.enabled
  57. send_event = eventer.send
  58. call_at = consumer.timer.call_at
  59. apply_eta_task = consumer.apply_eta_task
  60. rate_limits_enabled = not consumer.disable_rate_limits
  61. get_bucket = consumer.task_buckets.__getitem__
  62. handle = consumer.on_task_request
  63. limit_task = consumer._limit_task
  64. body_can_be_buffer = consumer.pool.body_can_be_buffer
  65. Req = create_request_cls(Request, task, consumer.pool, hostname, eventer)
  66. revoked_tasks = consumer.controller.state.revoked
  67. def task_message_handler(message, body, ack, reject, callbacks,
  68. to_timestamp=to_timestamp):
  69. if body is None:
  70. body, headers, decoded, utc = (
  71. message.body, message.headers, False, True,
  72. )
  73. if not body_can_be_buffer:
  74. body = bytes(body) if isinstance(body, buffer_t) else body
  75. else:
  76. body, headers, decoded, utc = proto1_to_proto2(message, body)
  77. req = Req(
  78. message,
  79. on_ack=ack, on_reject=reject, app=app, hostname=hostname,
  80. eventer=eventer, task=task, connection_errors=connection_errors,
  81. body=body, headers=headers, decoded=decoded, utc=utc,
  82. )
  83. if _does_info:
  84. info('Received task: %s', req)
  85. if (req.expires or req.id in revoked_tasks) and req.revoked():
  86. return
  87. if events:
  88. send_event(
  89. 'task-received',
  90. uuid=req.id, name=req.name,
  91. args=req.argsrepr, kwargs=req.kwargsrepr,
  92. root_id=req.root_id, parent_id=req.parent_id,
  93. retries=req.request_dict.get('retries', 0),
  94. eta=req.eta and req.eta.isoformat(),
  95. expires=req.expires and req.expires.isoformat(),
  96. )
  97. if req.eta:
  98. try:
  99. if req.utc:
  100. eta = to_timestamp(to_system_tz(req.eta))
  101. else:
  102. eta = to_timestamp(req.eta, timezone.local)
  103. except OverflowError as exc:
  104. error("Couldn't convert eta %s to timestamp: %r. Task: %r",
  105. req.eta, exc, req.info(safe=True), exc_info=True)
  106. req.acknowledge()
  107. else:
  108. consumer.qos.increment_eventually()
  109. call_at(eta, apply_eta_task, (req,), priority=6)
  110. else:
  111. if rate_limits_enabled:
  112. bucket = get_bucket(task.name)
  113. if bucket:
  114. return limit_task(req, bucket, 1)
  115. task_reserved(req)
  116. if callbacks:
  117. [callback(req) for callback in callbacks]
  118. handle(req)
  119. return task_message_handler