strategy.py 5.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157
  1. # -*- coding: utf-8 -*-
  2. """Task execution strategy (optimization)."""
  3. from __future__ import absolute_import, unicode_literals
  4. import logging
  5. from kombu.async.timer import to_timestamp
  6. from kombu.five import buffer_t
  7. from celery.exceptions import InvalidTaskError
  8. from celery.utils.imports import symbol_by_name
  9. from celery.utils.log import get_logger
  10. from celery.utils.saferepr import saferepr
  11. from celery.utils.time import timezone
  12. from .request import create_request_cls
  13. from .state import task_reserved
  14. __all__ = ('default',)
  15. logger = get_logger(__name__)
  16. # pylint: disable=redefined-outer-name
  17. # We cache globals and attribute lookups, so disable this warning.
  18. def proto1_to_proto2(message, body):
  19. """Convert Task message protocol 1 arguments to protocol 2.
  20. Returns:
  21. Tuple: of ``(body, headers, already_decoded_status, utc)``
  22. """
  23. try:
  24. args, kwargs = body.get('args', ()), body.get('kwargs', {})
  25. kwargs.items # pylint: disable=pointless-statement
  26. except KeyError:
  27. raise InvalidTaskError('Message does not have args/kwargs')
  28. except AttributeError:
  29. raise InvalidTaskError(
  30. 'Task keyword arguments must be a mapping',
  31. )
  32. body.update(
  33. argsrepr=saferepr(args),
  34. kwargsrepr=saferepr(kwargs),
  35. headers=message.headers,
  36. )
  37. try:
  38. body['group'] = body['taskset']
  39. except KeyError:
  40. pass
  41. embed = {
  42. 'callbacks': body.get('callbacks'),
  43. 'errbacks': body.get('errbacks'),
  44. 'chord': body.get('chord'),
  45. 'chain': None,
  46. }
  47. return (args, kwargs, embed), body, True, body.get('utc', True)
  48. def default(task, app, consumer,
  49. info=logger.info, error=logger.error, task_reserved=task_reserved,
  50. to_system_tz=timezone.to_system, bytes=bytes, buffer_t=buffer_t,
  51. proto1_to_proto2=proto1_to_proto2):
  52. """Default task execution strategy.
  53. Note:
  54. Strategies are here as an optimization, so sadly
  55. it's not very easy to override.
  56. """
  57. hostname = consumer.hostname
  58. connection_errors = consumer.connection_errors
  59. _does_info = logger.isEnabledFor(logging.INFO)
  60. # task event related
  61. # (optimized to avoid calling request.send_event)
  62. eventer = consumer.event_dispatcher
  63. events = eventer and eventer.enabled
  64. send_event = eventer.send
  65. task_sends_events = events and task.send_events
  66. call_at = consumer.timer.call_at
  67. apply_eta_task = consumer.apply_eta_task
  68. rate_limits_enabled = not consumer.disable_rate_limits
  69. get_bucket = consumer.task_buckets.__getitem__
  70. handle = consumer.on_task_request
  71. limit_task = consumer._limit_task
  72. limit_post_eta = consumer._limit_post_eta
  73. body_can_be_buffer = consumer.pool.body_can_be_buffer
  74. Request = symbol_by_name(task.Request)
  75. Req = create_request_cls(Request, task, consumer.pool, hostname, eventer)
  76. revoked_tasks = consumer.controller.state.revoked
  77. def task_message_handler(message, body, ack, reject, callbacks,
  78. to_timestamp=to_timestamp):
  79. if body is None:
  80. body, headers, decoded, utc = (
  81. message.body, message.headers, False, app.uses_utc_timezone(),
  82. )
  83. if not body_can_be_buffer:
  84. body = bytes(body) if isinstance(body, buffer_t) else body
  85. else:
  86. body, headers, decoded, utc = proto1_to_proto2(message, body)
  87. req = Req(
  88. message,
  89. on_ack=ack, on_reject=reject, app=app, hostname=hostname,
  90. eventer=eventer, task=task, connection_errors=connection_errors,
  91. body=body, headers=headers, decoded=decoded, utc=utc,
  92. )
  93. if _does_info:
  94. info('Received task: %s', req)
  95. if (req.expires or req.id in revoked_tasks) and req.revoked():
  96. return
  97. if task_sends_events:
  98. send_event(
  99. 'task-received',
  100. uuid=req.id, name=req.name,
  101. args=req.argsrepr, kwargs=req.kwargsrepr,
  102. root_id=req.root_id, parent_id=req.parent_id,
  103. retries=req.request_dict.get('retries', 0),
  104. eta=req.eta and req.eta.isoformat(),
  105. expires=req.expires and req.expires.isoformat(),
  106. )
  107. bucket = None
  108. eta = None
  109. if req.eta:
  110. try:
  111. if req.utc:
  112. eta = to_timestamp(to_system_tz(req.eta))
  113. else:
  114. eta = to_timestamp(req.eta, app.timezone)
  115. except (OverflowError, ValueError) as exc:
  116. error("Couldn't convert ETA %r to timestamp: %r. Task: %r",
  117. req.eta, exc, req.info(safe=True), exc_info=True)
  118. req.reject(requeue=False)
  119. if rate_limits_enabled:
  120. bucket = get_bucket(task.name)
  121. if eta and bucket:
  122. consumer.qos.increment_eventually()
  123. return call_at(eta, limit_post_eta, (req, bucket, 1),
  124. priority=6)
  125. if eta:
  126. consumer.qos.increment_eventually()
  127. call_at(eta, apply_eta_task, (req,), priority=6)
  128. return task_message_handler
  129. if bucket:
  130. return limit_task(req, bucket, 1)
  131. task_reserved(req)
  132. if callbacks:
  133. [callback(req) for callback in callbacks]
  134. handle(req)
  135. return task_message_handler