strategy.py 4.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133
  1. # -*- coding: utf-8 -*-
  2. """Task execution strategy (optimization)."""
  3. from __future__ import absolute_import, unicode_literals
  4. import logging
  5. from kombu.async.timer import to_timestamp
  6. from kombu.five import buffer_t
  7. from celery.exceptions import InvalidTaskError
  8. from celery.utils.log import get_logger
  9. from celery.utils.saferepr import saferepr
  10. from celery.utils.time import timezone
  11. from .request import Request, create_request_cls
  12. from .state import task_reserved
  13. __all__ = ['default']
  14. logger = get_logger(__name__)
  15. def proto1_to_proto2(message, body):
  16. """Converts Task message protocol 1 arguments to protocol 2.
  17. Returns:
  18. Tuple: of ``(body, headers, already_decoded_status, utc)``
  19. """
  20. try:
  21. args, kwargs = body['args'], body['kwargs']
  22. kwargs.items
  23. except KeyError:
  24. raise InvalidTaskError('Message does not have args/kwargs')
  25. except AttributeError:
  26. raise InvalidTaskError(
  27. 'Task keyword arguments must be a mapping',
  28. )
  29. body.update(
  30. argsrepr=saferepr(args),
  31. kwargsrepr=saferepr(kwargs),
  32. headers=message.headers,
  33. )
  34. try:
  35. body['group'] = body['taskset']
  36. except KeyError:
  37. pass
  38. embed = {
  39. 'callbacks': body.get('callbacks'),
  40. 'errbacks': body.get('errbacks'),
  41. 'chord': body.get('chord'),
  42. 'chain': None,
  43. }
  44. return (args, kwargs, embed), body, True, body.get('utc', True)
  45. def default(task, app, consumer,
  46. info=logger.info, error=logger.error, task_reserved=task_reserved,
  47. to_system_tz=timezone.to_system, bytes=bytes, buffer_t=buffer_t,
  48. proto1_to_proto2=proto1_to_proto2):
  49. hostname = consumer.hostname
  50. eventer = consumer.event_dispatcher
  51. connection_errors = consumer.connection_errors
  52. _does_info = logger.isEnabledFor(logging.INFO)
  53. events = eventer and eventer.enabled
  54. send_event = eventer.send
  55. call_at = consumer.timer.call_at
  56. apply_eta_task = consumer.apply_eta_task
  57. rate_limits_enabled = not consumer.disable_rate_limits
  58. get_bucket = consumer.task_buckets.__getitem__
  59. handle = consumer.on_task_request
  60. limit_task = consumer._limit_task
  61. body_can_be_buffer = consumer.pool.body_can_be_buffer
  62. Req = create_request_cls(Request, task, consumer.pool, hostname, eventer)
  63. revoked_tasks = consumer.controller.state.revoked
  64. def task_message_handler(message, body, ack, reject, callbacks,
  65. to_timestamp=to_timestamp):
  66. if body is None:
  67. body, headers, decoded, utc = (
  68. message.body, message.headers, False, True,
  69. )
  70. if not body_can_be_buffer:
  71. body = bytes(body) if isinstance(body, buffer_t) else body
  72. else:
  73. body, headers, decoded, utc = proto1_to_proto2(message, body)
  74. req = Req(
  75. message,
  76. on_ack=ack, on_reject=reject, app=app, hostname=hostname,
  77. eventer=eventer, task=task, connection_errors=connection_errors,
  78. body=body, headers=headers, decoded=decoded, utc=utc,
  79. )
  80. if _does_info:
  81. info('Received task: %s', req)
  82. if (req.expires or req.id in revoked_tasks) and req.revoked():
  83. return
  84. if events:
  85. send_event(
  86. 'task-received',
  87. uuid=req.id, name=req.name,
  88. args=req.argsrepr, kwargs=req.kwargsrepr,
  89. root_id=req.root_id, parent_id=req.parent_id,
  90. retries=req.request_dict.get('retries', 0),
  91. eta=req.eta and req.eta.isoformat(),
  92. expires=req.expires and req.expires.isoformat(),
  93. )
  94. if req.eta:
  95. try:
  96. if req.utc:
  97. eta = to_timestamp(to_system_tz(req.eta))
  98. else:
  99. eta = to_timestamp(req.eta, timezone.local)
  100. except OverflowError as exc:
  101. error("Couldn't convert ETA %r to timestamp: %r. Task: %r",
  102. req.eta, exc, req.info(safe=True), exc_info=True)
  103. req.acknowledge()
  104. else:
  105. consumer.qos.increment_eventually()
  106. call_at(eta, apply_eta_task, (req,), priority=6)
  107. else:
  108. if rate_limits_enabled:
  109. bucket = get_bucket(task.name)
  110. if bucket:
  111. return limit_task(req, bucket, 1)
  112. task_reserved(req)
  113. if callbacks:
  114. [callback(req) for callback in callbacks]
  115. handle(req)
  116. return task_message_handler