Ask Solem, 8 years ago
Parent
Current commit 3dcc606b7e
51 changed files with 2,488 additions and 1,583 deletions
  1. celery/__init__.py (+11 -4)
  2. celery/app/__init__.py (+8 -8)
  3. celery/app/amqp.py (+147 -78)
  4. celery/app/backends.py (+8 -4)
  5. celery/app/base.py (+271 -163)
  6. celery/app/defaults.py (+13 -6)
  7. celery/app/events.py (+12 -6)
  8. celery/app/log.py (+56 -26)
  9. celery/app/registry.py (+7 -16)
  10. celery/app/routes.py (+32 -12)
  11. celery/app/task.py (+190 -122)
  12. celery/app/trace.py (+17 -5)
  13. celery/app/utils.py (+61 -36)
  14. celery/backends/async.py (+80 -52)
  15. celery/backends/base.py (+264 -166)
  16. celery/backends/rpc.py (+71 -47)
  17. celery/beat.py (+9 -3)
  18. celery/bootsteps.py (+48 -47)
  19. celery/canvas.py (+254 -125)
  20. celery/concurrency/asynpool.py (+9 -2)
  21. celery/contrib/pytest.py (+15 -26)
  22. celery/events/state.py (+126 -82)
  23. celery/platforms.py (+0 -4)
  24. celery/schedules.py (+11 -2)
  25. celery/utils/collections.py (+13 -17)
  26. celery/utils/debug.py (+3 -3)
  27. celery/utils/functional.py (+5 -5)
  28. celery/utils/graph.py (+7 -7)
  29. celery/utils/imports.py (+6 -6)
  30. celery/utils/log.py (+1 -1)
  31. celery/utils/saferepr.py (+34 -17)
  32. celery/utils/serialization.py (+3 -3)
  33. celery/utils/static/__init__.py (+2 -4)
  34. celery/utils/sysinfo.py (+7 -4)
  35. celery/worker/autoscale.py (+27 -26)
  36. celery/worker/components.py (+28 -25)
  37. celery/worker/consumer/agent.py (+4 -2)
  38. celery/worker/consumer/connection.py (+8 -6)
  39. celery/worker/consumer/consumer.py (+113 -83)
  40. celery/worker/consumer/control.py (+3 -2)
  41. celery/worker/consumer/events.py (+14 -13)
  42. celery/worker/consumer/gossip.py (+48 -38)
  43. celery/worker/consumer/heart.py (+11 -6)
  44. celery/worker/consumer/mingle.py (+22 -13)
  45. celery/worker/consumer/tasks.py (+9 -7)
  46. celery/worker/control.py (+100 -58)
  47. celery/worker/heartbeat.py (+15 -9)
  48. celery/worker/request.py (+72 -47)
  49. celery/worker/state.py (+52 -40)
  50. celery/worker/strategy.py (+45 -14)
  51. celery/worker/worker.py (+116 -85)

+ 11 - 4
celery/__init__.py

@@ -8,7 +8,7 @@
 import os
 import re
 import sys
-from collections import namedtuple
+from typing import NamedTuple

 SERIES = 'latentcall'

@@ -30,9 +30,16 @@ __all__ = [

 VERSION_BANNER = '{0} ({1})'.format(__version__, SERIES)

-version_info_t = namedtuple('version_info_t', (
-    'major', 'minor', 'micro', 'releaselevel', 'serial',
-))
+
+class version_info_t(NamedTuple):
+    """Version information tuple."""
+
+    major: int
+    minor: int
+    micro: int
+    releaselevel: str
+    serial: str
+

 # bumpversion can only search for {current_version}
 # so we have to parse the version here.
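
The namedtuple here becomes a typed NamedTuple subclass with identical runtime behavior. A minimal standalone sketch of what the new class gives you (stdlib only; the sample values are illustrative):

    from typing import NamedTuple

    class version_info_t(NamedTuple):
        major: int
        minor: int
        micro: int
        releaselevel: str
        serial: str

    # Construction, field access and unpacking behave exactly like
    # collections.namedtuple; type checkers additionally see field types.
    v = version_info_t(4, 0, 0, 'final', '')
    assert v.major == 4
    assert tuple(v) == (4, 0, 0, 'final', '')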

+ 8 - 8
celery/app/__init__.py

@@ -1,11 +1,11 @@
-# -*- coding: utf-8 -*-
-"""Celery Application."""
+from typing import Callable, Union
 from celery.local import Proxy
 from celery import _state
 from celery._state import (
     app_or_default, enable_trace, disable_trace,
     push_current_task, pop_current_task,
 )
+from celery.types import AppT, TaskT
 from .base import Celery
 from .utils import AppPickler

@@ -19,12 +19,12 @@ __all__ = [
 default_app = Proxy(lambda: _state.default_app)


-def bugreport(app=None):
+def bugreport(app: AppT = None) -> str:
     """Return information useful in bug reports."""
     """Return information useful in bug reports."""
     return (app or _state.get_current_app()).bugreport()
     return (app or _state.get_current_app()).bugreport()
 
 
 
 
-def shared_task(*args, **kwargs):
+def shared_task(*args, **kwargs) -> Union[Callable, TaskT]:
     """Create shared task (decorator).
     """Create shared task (decorator).
 
 
     This can be used by library authors to create tasks that'll work
     This can be used by library authors to create tasks that'll work
@@ -48,10 +48,10 @@ def shared_task(*args, **kwargs):
         >>> add.app is app2
         True
     """
-    def create_shared_task(**options):
+    def create_shared_task(**options) -> Callable:

-        def __inner(fun):
-            name = options.get('name')
+        def __inner(fun: Callable) -> TaskT:
+            name: str = options.get('name')
             # Set as shared task so that unfinalized apps,
             # and future apps will register a copy of this task.
             _state.connect_on_app_finalize(
@@ -66,7 +66,7 @@ def shared_task(*args, **kwargs):

             # Return a proxy that always gets the task from the current
             # apps task registry.
-            def task_by_cons():
+            def task_by_cons() -> TaskT:
                 app = _state.get_current_app()
                 return app.tasks[
                     name or app.gen_task_name(fun.__name__, fun.__module__)
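
For reference, a short sketch of the annotated decorator in use, following the docstring's own example (the `add` task and the broker URLs are illustrative):

    from celery import Celery, shared_task

    @shared_task
    def add(x: int, y: int) -> int:
        return x + y

    app1 = Celery(broker='amqp://')    # hypothetical broker URLs
    app2 = Celery(broker='redis://')
    # The proxy returned by @shared_task resolves against the current app,
    # so the most recently created app owns the binding:
    assert add.app is app2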

+ 147 - 78
celery/app/amqp.py

@@ -2,17 +2,24 @@
 """Sending/Receiving Messages (Kombu integration)."""
 """Sending/Receiving Messages (Kombu integration)."""
 import numbers
 import numbers
 
 
-from collections import Mapping, namedtuple
-from datetime import timedelta
+from collections import Mapping
+from datetime import datetime, timedelta, tzinfo
+from typing import (
+    Any, Callable, MutableMapping, NamedTuple, Set, Sequence, Union,
+    cast,
+)
 from weakref import WeakValueDictionary

 from kombu import pools
 from kombu import Connection, Consumer, Exchange, Producer, Queue
 from kombu.common import Broadcast
+from kombu.types import ChannelT, ConsumerT, EntityT, ProducerT, ResourceT
 from kombu.utils.functional import maybe_list
 from kombu.utils.objects import cached_property

 from celery import signals
+from celery.events import EventDispatcher
+from celery.types import AppT, ResultT, RouterT, SignatureT
 from celery.utils.nodenames import anon_nodename
 from celery.utils.saferepr import saferepr
 from celery.utils.text import indent as textindent
@@ -31,11 +38,19 @@ QUEUE_FORMAT = """
 key={0.routing_key}
 """

-task_message = namedtuple('task_message',
-                          ('headers', 'properties', 'body', 'sent_event'))
+QueuesArgT = Union[Mapping[str, Queue], Sequence[Queue]]


-def utf8dict(d, encoding='utf-8'):
+class task_message(NamedTuple):
+    """Represents a task message that can be sent."""
+
+    headers: MutableMapping
+    properties: MutableMapping
+    body: Any
+    sent_event: Mapping
+
+
+def utf8dict(d: Mapping, encoding: str = 'utf-8') -> Mapping:
     return {k.decode(encoding) if isinstance(k, bytes) else k: v
             for k, v in d.items()}

@@ -54,11 +69,16 @@ class Queues(dict):

     #: If set, this is a subset of queues to consume from.
     #: The rest of the queues are then used for routing only.
-    _consume_from = None
-
-    def __init__(self, queues=None, default_exchange=None,
-                 create_missing=True, ha_policy=None, autoexchange=None,
-                 max_priority=None, default_routing_key=None):
+    _consume_from: Mapping[str, Queue] = None
+
+    def __init__(self,
+                 queues: QueuesArgT = None,
+                 default_exchange: str = None,
+                 create_missing: bool = True,
+                 ha_policy: Union[Sequence, str] = None,
+                 autoexchange: Callable[[str], Exchange] = None,
+                 max_priority: int = None,
+                 default_routing_key: str = None) -> None:
         dict.__init__(self)
         self.aliases = WeakValueDictionary()
         self.default_exchange = default_exchange
@@ -72,25 +92,25 @@ class Queues(dict):
         for name, q in (queues or {}).items():
             self.add(q) if isinstance(q, Queue) else self.add_compat(name, **q)

-    def __getitem__(self, name):
+    def __getitem__(self, name: str) -> Queue:
         try:
             return self.aliases[name]
         except KeyError:
             return dict.__getitem__(self, name)

-    def __setitem__(self, name, queue):
+    def __setitem__(self, name: str, queue: Queue) -> None:
         if self.default_exchange and not queue.exchange:
             queue.exchange = self.default_exchange
         dict.__setitem__(self, name, queue)
         if queue.alias:
             self.aliases[queue.alias] = queue

-    def __missing__(self, name):
+    def __missing__(self, name: str) -> Queue:
         if self.create_missing:
             return self.add(self.new_missing(name))
         raise KeyError(name)

-    def add(self, queue, **kwargs):
+    def add(self, queue: Union[Queue, str], **kwargs) -> Queue:
         """Add new queue.
         """Add new queue.
 
 
         The first argument can either be a :class:`kombu.Queue` instance,
         The first argument can either be a :class:`kombu.Queue` instance,
@@ -111,14 +131,14 @@ class Queues(dict):
             return self.add_compat(queue, **kwargs)
         return self._add(queue)

-    def add_compat(self, name, **options):
+    def add_compat(self, name: str, **options) -> Queue:
         # docs used to use binding_key as routing key
         options.setdefault('routing_key', options.get('binding_key'))
         if options['routing_key'] is None:
             options['routing_key'] = name
         return self._add(Queue.from_dict(name, **options))

-    def _add(self, queue):
+    def _add(self, queue: Queue) -> Queue:
         if not queue.routing_key:
             if queue.exchange is None or queue.exchange.name == '':
                 queue.exchange = self.default_exchange
@@ -134,18 +154,19 @@ class Queues(dict):
         self[queue.name] = queue
         return queue

-    def _set_ha_policy(self, args):
+    def _set_ha_policy(self, args: MutableMapping) -> None:
         policy = self.ha_policy
         if isinstance(policy, (list, tuple)):
-            return args.update({'x-ha-policy': 'nodes',
-                                'x-ha-policy-params': list(policy)})
-        args['x-ha-policy'] = policy
+            args.update({'x-ha-policy': 'nodes',
+                         'x-ha-policy-params': list(policy)})
+        else:
+            args['x-ha-policy'] = policy

-    def _set_max_priority(self, args):
+    def _set_max_priority(self, args: MutableMapping) -> None:
         if 'x-max-priority' not in args and self.max_priority is not None:
-            return args.update({'x-max-priority': self.max_priority})
+            args.update({'x-max-priority': self.max_priority})

-    def format(self, indent=0, indent_first=True):
+    def format(self, indent: int = 0, indent_first: bool = True) -> str:
         """Format routing table into string for log dumps."""
         """Format routing table into string for log dumps."""
         active = self.consume_from
         active = self.consume_from
         if not active:
         if not active:
@@ -156,7 +177,7 @@ class Queues(dict):
             return textindent('\n'.join(info), indent)
         return info[0] + '\n' + textindent('\n'.join(info[1:]), indent)

-    def select_add(self, queue, **kwargs):
+    def select_add(self, queue: Queue, **kwargs) -> Queue:
         """Add new task queue that'll be consumed from.
         """Add new task queue that'll be consumed from.
 
 
         The queue will be active even when a subset has been selected
         The queue will be active even when a subset has been selected
@@ -167,7 +188,7 @@ class Queues(dict):
             self._consume_from[q.name] = q
         return q

-    def select(self, include):
+    def select(self, include: Union[Sequence[str], str]) -> None:
         """Select a subset of currently defined queues to consume from.
         """Select a subset of currently defined queues to consume from.
 
 
         Arguments:
         Arguments:
@@ -178,7 +199,7 @@ class Queues(dict):
                 name: self[name] for name in maybe_list(include)
             }

-    def deselect(self, exclude):
+    def deselect(self, exclude: Union[Sequence[str], str]) -> None:
         """Deselect queues so that they won't be consumed from.
         """Deselect queues so that they won't be consumed from.
 
 
         Arguments:
         Arguments:
@@ -189,19 +210,20 @@ class Queues(dict):
             exclude = maybe_list(exclude)
             if self._consume_from is None:
                 # using selection
-                return self.select(k for k in self if k not in exclude)
-            # using all queues
-            for queue in exclude:
-                self._consume_from.pop(queue, None)
+                self.select(k for k in self if k not in exclude)
+            else:
+                # using all queues
+                for queue in exclude:
+                    self._consume_from.pop(queue, None)

-    def new_missing(self, name):
+    def new_missing(self, name: str) -> Queue:
         return Queue(name, self.autoexchange(name), name)

     @property
-    def consume_from(self):
+    def consume_from(self) -> Mapping[str, Queue]:
         if self._consume_from is not None:
             return self._consume_from
-        return self
+        return cast(Mapping[str, Queue], self)


 class AMQP:
@@ -221,13 +243,13 @@ class AMQP:

     #: Underlying producer pool instance automatically
     #: set by the :attr:`producer_pool`.
-    _producer_pool = None
+    _producer_pool: ResourceT = None

     # Exchange class/function used when defining automatic queues.
     # For example, you can use ``autoexchange = lambda n: None`` to use the
     # AMQP default exchange: a shortcut to bypass routing
     # and instead send directly to the queue named in the routing key.
-    autoexchange = None
+    autoexchange: Callable[[str], Exchange] = None

     #: Max size of positional argument representation used for
     #: logging purposes.
@@ -236,7 +258,9 @@ class AMQP:
     #: Max size of keyword argument representation used for logging purposes.
     kwargsrepr_maxsize = 1024

-    def __init__(self, app):
+    task_protocols: Mapping[int, Callable] = None
+
+    def __init__(self, app: AppT) -> None:
         self.app = app
         self.task_protocols = {
             1: self.as_task_v1,
@@ -244,15 +268,18 @@ class AMQP:
         }

     @cached_property
-    def create_task_message(self):
+    def create_task_message(self) -> Callable:
         return self.task_protocols[self.app.conf.task_protocol]

     @cached_property
-    def send_task_message(self):
+    def send_task_message(self) -> Callable:
         return self._create_task_sender()

-    def Queues(self, queues, create_missing=None, ha_policy=None,
-               autoexchange=None, max_priority=None):
+    def Queues(self, queues: QueuesArgT,
+               create_missing: bool = None,
+               ha_policy: Union[Sequence, str] = None,
+               autoexchange: Callable[[str], Exchange] = None,
+               max_priority: int = None) -> Queues:
         # Create new :class:`Queues` instance, using queue defaults
         # from the current configuration.
         conf = self.app.conf
@@ -267,23 +294,27 @@ class AMQP:
             queues = (Queue(conf.task_default_queue,
                             exchange=self.default_exchange,
                             routing_key=default_routing_key),)
-        autoexchange = (self.autoexchange if autoexchange is None
-                        else autoexchange)
+        autoexchange = (self.autoexchange
+                        if autoexchange is None else autoexchange)
         return self.queues_cls(
             queues, self.default_exchange, create_missing,
             ha_policy, autoexchange, max_priority, default_routing_key,
         )

-    def Router(self, queues=None, create_missing=None):
+    def Router(self, queues: Mapping[str, Queue] = None,
+               create_missing: bool = None) -> RouterT:
         """Return the current task router."""
         """Return the current task router."""
         return _routes.Router(self.routes, queues or self.queues,
         return _routes.Router(self.routes, queues or self.queues,
                               self.app.either('task_create_missing_queues',
                               self.app.either('task_create_missing_queues',
                                               create_missing), app=self.app)
                                               create_missing), app=self.app)
 
 
-    def flush_routes(self):
+    def flush_routes(self) -> None:
         self._rtable = _routes.prepare(self.app.conf.task_routes)

-    def TaskConsumer(self, channel, queues=None, accept=None, **kw):
+    def TaskConsumer(self, channel: ChannelT,
+                     queues: Mapping[str, Queue] = None,
+                     accept: Set[str] = None,
+                     **kw) -> ConsumerT:
         if accept is None:
             accept = self.app.conf.accept_content
         return self.Consumer(
@@ -292,14 +323,30 @@ class AMQP:
             **kw
         )

-    def as_task_v2(self, task_id, name, args=None, kwargs=None,
-                   countdown=None, eta=None, group_id=None,
-                   expires=None, retries=0, chord=None,
-                   callbacks=None, errbacks=None, reply_to=None,
-                   time_limit=None, soft_time_limit=None,
-                   create_sent_event=False, root_id=None, parent_id=None,
-                   shadow=None, chain=None, now=None, timezone=None,
-                   origin=None, argsrepr=None, kwargsrepr=None):
+    def as_task_v2(self, task_id: str, name: str, *,
+                   args: Sequence = None,
+                   kwargs: Mapping = None,
+                   countdown: float = None,
+                   eta: datetime = None,
+                   group_id: str = None,
+                   expires: Union[float, datetime] = None,
+                   retries: int = 0,
+                   chord: SignatureT = None,
+                   callbacks: Sequence[SignatureT] = None,
+                   errbacks: Sequence[SignatureT] = None,
+                   reply_to: str = None,
+                   time_limit: float = None,
+                   soft_time_limit: float = None,
+                   create_sent_event: bool = False,
+                   root_id: str = None,
+                   parent_id: str = None,
+                   shadow: str = None,
+                   chain: Sequence[SignatureT] = None,
+                   now: datetime = None,
+                   timezone: tzinfo = None,
+                   origin: str = None,
+                   argsrepr: str = None,
+                   kwargsrepr: str = None) -> task_message:
         args = args or ()
         kwargs = kwargs or {}
         if not isinstance(args, (list, tuple)):
@@ -372,13 +419,26 @@ class AMQP:
             } if create_sent_event else None,
         )

-    def as_task_v1(self, task_id, name, args=None, kwargs=None,
-                   countdown=None, eta=None, group_id=None,
-                   expires=None, retries=0,
-                   chord=None, callbacks=None, errbacks=None, reply_to=None,
-                   time_limit=None, soft_time_limit=None,
-                   create_sent_event=False, root_id=None, parent_id=None,
-                   shadow=None, now=None, timezone=None):
+    def as_task_v1(self, task_id, name,
+                   args: Sequence = None,
+                   kwargs: Mapping = None,
+                   countdown: float = None,
+                   eta: datetime = None,
+                   group_id: str = None,
+                   expires: Union[float, datetime] = None,
+                   retries: int = 0,
+                   chord: SignatureT = None,
+                   callbacks: Sequence[SignatureT] = None,
+                   errbacks: Sequence[SignatureT] = None,
+                   reply_to: str = None,
+                   time_limit: float = None,
+                   soft_time_limit: float = None,
+                   create_sent_event: bool = False,
+                   root_id: str = None,
+                   parent_id: str = None,
+                   shadow: str = None,
+                   now: datetime = None,
+                   timezone: tzinfo = None) -> task_message:
         args = args or ()
         kwargs = kwargs or {}
         utc = self.utc
@@ -436,12 +496,12 @@ class AMQP:
             } if create_sent_event else None,
         )

-    def _verify_seconds(self, s, what):
+    def _verify_seconds(self, s: float, what: str) -> float:
         if s < INT_MIN:
             raise ValueError('%s is out of range: %r' % (what, s))
         return s

-    def _create_task_sender(self):
+    def _create_task_sender(self) -> Callable:
         default_retry = self.app.conf.task_publish_retry
         default_policy = self.app.conf.task_publish_retry_policy
         default_delivery_mode = self.app.conf.task_default_delivery_mode
@@ -459,13 +519,22 @@ class AMQP:
         default_serializer = self.app.conf.task_serializer
         default_compressor = self.app.conf.result_compression

-        def send_task_message(producer, name, message,
-                              exchange=None, routing_key=None, queue=None,
-                              event_dispatcher=None,
-                              retry=None, retry_policy=None,
-                              serializer=None, delivery_mode=None,
-                              compression=None, declare=None,
-                              headers=None, exchange_type=None, **kwargs):
+        def send_task_message(
+                producer: ProducerT, name: str, message: task_message,
+                *,
+                exchange: Union[Exchange, str] = None,
+                routing_key: str = None,
+                queue: str = None,
+                event_dispatcher: EventDispatcher = None,
+                retry: bool = None,
+                retry_policy: Mapping[str, Any] = None,
+                serializer: str = None,
+                delivery_mode: Union[str, int] = None,
+                compression: str = None,
+                declare: Sequence[EntityT] = None,
+                headers: Mapping = None,
+                exchange_type: str = None,
+                **kwargs) -> ResultT:
             retry = default_retry if retry is None else retry
             headers2, properties, body, sent_event = message
             if headers:
@@ -547,30 +616,30 @@ class AMQP:
         return send_task_message

     @cached_property
-    def default_queue(self):
+    def default_queue(self) -> Queue:
         return self.queues[self.app.conf.task_default_queue]

     @cached_property
-    def queues(self):
+    def queues(self) -> Queues:
         """Queue name⇒ declaration mapping."""
         """Queue name⇒ declaration mapping."""
         return self.Queues(self.app.conf.task_queues)
         return self.Queues(self.app.conf.task_queues)
 
 
     @queues.setter  # noqa
     @queues.setter  # noqa
-    def queues(self, queues):
+    def queues(self, queues: QueuesArgT) -> Queues:
         return self.Queues(queues)

     @property
-    def routes(self):
+    def routes(self) -> Sequence[RouterT]:
         if self._rtable is None:
             self.flush_routes()
         return self._rtable

     @cached_property
-    def router(self):
+    def router(self) -> RouterT:
         return self.Router()

     @property
-    def producer_pool(self):
+    def producer_pool(self) -> ResourceT:
         if self._producer_pool is None:
             self._producer_pool = pools.producers[
                 self.app.connection_for_write()]
@@ -578,16 +647,16 @@ class AMQP:
         return self._producer_pool

     @cached_property
-    def default_exchange(self):
+    def default_exchange(self) -> Exchange:
         return Exchange(self.app.conf.task_default_exchange,
                         self.app.conf.task_default_exchange_type)

     @cached_property
-    def utc(self):
+    def utc(self) -> bool:
         return self.app.conf.enable_utc

     @cached_property
-    def _event_dispatcher(self):
+    def _event_dispatcher(self) -> EventDispatcher:
         # We call Dispatcher.publish with a custom producer
         # so don't need the diuspatcher to be enabled.
         return self.app.events.Dispatcher(enabled=False)
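
Queues keeps its dict semantics under the new annotations. A minimal sketch of the select/deselect flow (the queue names and exchange are illustrative):

    from kombu import Exchange, Queue
    from celery.app.amqp import Queues

    q = Queues(default_exchange=Exchange('tasks'), create_missing=True)
    q.add(Queue('default', routing_key='default'))
    q.add(Queue('priority', routing_key='priority'))

    q.select(['priority'])   # consume only from 'priority'
    q.deselect('priority')   # drop it from the consuming subset again
    assert 'default' in q    # the routing table still knows both queues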

+ 8 - 4
celery/app/backends.py

@@ -2,8 +2,10 @@
 """Backend selection."""
 """Backend selection."""
 import sys
 import sys
 import types
 import types
+from typing import Mapping, Tuple, Union
 from celery.exceptions import ImproperlyConfigured
 from celery._state import current_app
+from celery.types import LoaderT
 from celery.utils.imports import load_extension_class_names, symbol_by_name

 __all__ = ['by_name', 'by_url']
@@ -12,7 +14,7 @@ UNKNOWN_BACKEND = """
 Unknown result backend: {0!r}.  Did you spell that correctly? ({1!r})
 """

-BACKEND_ALIASES = {
+BACKEND_ALIASES: Mapping[str, str] = {
     'amqp': 'celery.backends.amqp:AMQPBackend',
     'rpc': 'celery.backends.rpc.RPCBackend',
     'cache': 'celery.backends.cache:CacheBackend',
@@ -31,8 +33,9 @@ BACKEND_ALIASES = {
 }


-def by_name(backend=None, loader=None,
-            extension_namespace='celery.result_backends'):
+def by_name(backend: Union[str, type] = None,
+            loader: LoaderT = None,
+            extension_namespace: str = 'celery.result_backends') -> type:
     """Get backend class by name/alias."""
     """Get backend class by name/alias."""
     backend = backend or 'disabled'
     backend = backend or 'disabled'
     loader = loader or current_app.loader
     loader = loader or current_app.loader
@@ -50,7 +53,8 @@ def by_name(backend=None, loader=None,
     return cls


-def by_url(backend=None, loader=None):
+def by_url(backend: Union[str, type] = None,
+           loader: LoaderT = None) -> Tuple[type, str]:
     """Get backend class by URL."""
     """Get backend class by URL."""
     url = None
     url = None
     if backend and '://' in backend:
     if backend and '://' in backend:
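
by_url splits a backend URL on its scheme and resolves the class through BACKEND_ALIASES via by_name. A hedged sketch of the lookup (the Redis URL is illustrative):

    from celery.app.backends import by_url

    # 'redis' is looked up in BACKEND_ALIASES; the full URL is returned
    # unchanged so the backend can later be configured from it.
    backend_cls, url = by_url('redis://localhost:6379/0')
    assert url == 'redis://localhost:6379/0'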

+ 271 - 163
celery/app/base.py

@@ -5,14 +5,22 @@ import threading
 import warnings

 from collections import UserDict, defaultdict, deque
+from datetime import datetime
 from operator import attrgetter
+from types import ModuleType
+from typing import (
+    Any, Callable, ContextManager, Dict, List,
+    Mapping, MutableMapping, Optional, Set, Sequence, Tuple, Union,
+)

+from amqp.types import SSLArg
 from kombu import pools
-from kombu.clocks import LamportClock
+from kombu.clocks import Clock, LamportClock
 from kombu.common import oid_from
 from kombu.utils.compat import register_after_fork
 from kombu.utils.objects import cached_property
 from kombu.utils.uuid import uuid
+from kombu.types import ConnectionT, ProducerT, ResourceT
 from vine import starpromise
 from vine.utils import wraps

@@ -27,6 +35,11 @@ from celery._state import (
 from celery.exceptions import AlwaysEagerIgnored, ImproperlyConfigured
 from celery.loaders import get_loader_cls
 from celery.local import PromiseProxy, maybe_evaluate
+from celery.types import (
+    AppT, AppAMQPT, AppControlT, AppEventsT, AppLogT,
+    BackendT, BeatT, LoaderT, ResultT, RouterT, ScheduleT,
+    SignalT, SignatureT, TaskT, TaskRegistryT, WorkerT,
+)
 from celery.utils import abstract
 from celery.utils.collections import AttributeDictMixin
 from celery.utils.dispatch import Signal
@@ -71,7 +84,7 @@ Example:
 """
 """
 
 
 
 
-def app_has_custom(app, attr):
+def app_has_custom(app: AppT, attr: str) -> bool:
     """Return true if app has customized method `attr`.
     """Return true if app has customized method `attr`.
 
 
     Note:
     Note:
@@ -83,14 +96,14 @@ def app_has_custom(app, attr):
                       monkey_patched=[__name__])


-def _unpickle_appattr(reverse_name, args):
+def _unpickle_appattr(reverse_name: str, args: Tuple) -> Any:
     """Unpickle app."""
     """Unpickle app."""
     # Given an attribute name and a list of args, gets
     # Given an attribute name and a list of args, gets
     # the attribute from the current app and calls it.
     # the attribute from the current app and calls it.
     return get_current_app()._rgetattr(reverse_name)(*args)
     return get_current_app()._rgetattr(reverse_name)(*args)
 
 
 
 
-def _after_fork_cleanup_app(app):
+def _after_fork_cleanup_app(app: AppT) -> None:
     # This is used with multiprocessing.register_after_fork,
     # so need to be at module level.
     try:
@@ -107,39 +120,41 @@ class PendingConfiguration(UserDict, AttributeDictMixin):
     # accessing any key will finalize the configuration,
     # replacing `app.conf` with a concrete settings object.

-    callback = None
-    _data = None
+    callback: Callable[[], MutableMapping] = None
+    _data: MutableMapping = None

-    def __init__(self, conf, callback):
+    def __init__(self,
+                 conf: MutableMapping,
+                 callback: Callable[[], MutableMapping]) -> None:
         object.__setattr__(self, '_data', conf)
         object.__setattr__(self, 'callback', callback)

-    def __setitem__(self, key, value):
+    def __setitem__(self, key: str, value: Any) -> None:
         self._data[key] = value

-    def clear(self):
+    def clear(self) -> None:
         self._data.clear()

-    def update(self, *args, **kwargs):
+    def update(self, *args, **kwargs) -> None:
         self._data.update(*args, **kwargs)

-    def setdefault(self, *args, **kwargs):
-        return self._data.setdefault(*args, **kwargs)
+    def setdefault(self, key: str, value: Any) -> Any:
+        return self._data.setdefault(key, value)

-    def __contains__(self, key):
+    def __contains__(self, key: str) -> bool:
         # XXX will not show finalized configuration
         # setdefault will cause `key in d` to happen,
         # so for setdefault to be lazy, so does contains.
         return key in self._data

-    def __len__(self):
+    def __len__(self) -> int:
         return len(self.data)

-    def __repr__(self):
+    def __repr__(self) -> str:
         return repr(self.data)

     @cached_property
-    def data(self):
+    def data(self) -> MutableMapping:
         return self.callback()


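PendingConfiguration defers building the real settings object: writes go straight to the underlying `_data` mapping, while the first read of `data` (a cached property) invokes `callback` to finalize. A stripped-down sketch of that pattern, independent of Celery:

    from functools import cached_property

    class LazyConf:
        def __init__(self, pending, finalize):
            self._pending = pending      # raw changes, written eagerly
            self._finalize = finalize    # builds the real mapping, once

        @cached_property
        def data(self):
            return self._finalize(self._pending)

    cfg = LazyConf({'x': 1}, lambda pending: {'defaults': True, **pending})
    # Nothing is finalized until the first attribute read:
    assert cfg.data == {'defaults': True, 'x': 1}
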
@@ -182,54 +197,71 @@ class Celery:
     SYSTEM = platforms.SYSTEM
     IS_macOS, IS_WINDOWS = platforms.IS_macOS, platforms.IS_WINDOWS

+    clock: Clock = None
+
     #: Name of the `__main__` module.  Required for standalone scripts.
     #:
     #: If set this will be used instead of `__main__` when automatically
     #: generating task names.
-    main = None
+    main: str = None

     #: Custom options for command-line programs.
     #: See :ref:`extending-commandoptions`
-    user_options = None
+    user_options: MutableMapping[str, Set] = None

     #: Custom bootsteps to extend and modify the worker.
     #: See :ref:`extending-bootsteps`.
-    steps = None
+    steps: MutableMapping[str, Set] = None

     builtin_fixups = BUILTIN_FIXUPS

-    amqp_cls = 'celery.app.amqp:AMQP'
-    backend_cls = None
-    events_cls = 'celery.app.events:Events'
-    loader_cls = None
-    log_cls = 'celery.app.log:Logging'
-    control_cls = 'celery.app.control:Control'
-    task_cls = 'celery.app.task:Task'
-    registry_cls = TaskRegistry
+    amqp_cls: Union[str, type] = 'celery.app.amqp:AMQP'
+    backend_cls: Union[str, type] = None
+    events_cls: Union[str, type] = 'celery.app.events:Events'
+    loader_cls: Union[str, type] = None
+    log_cls: Union[str, type] = 'celery.app.log:Logging'
+    control_cls: Union[str, type] = 'celery.app.control:Control'
+    task_cls: Union[str, type] = 'celery.app.task:Task'
+    registry_cls: Union[str, type] = TaskRegistry

-    _fixups = None
-    _pool = None
-    _conf = None
-    _after_fork_registered = False
+    _fixups: List = None
+    _pool: ResourceT = None
+    _conf: MutableMapping = None
+    _after_fork_registered: bool = False

     #: Signal sent when app is loading configuration.
-    on_configure = None
+    on_configure: SignatureT = None

     #: Signal sent after app has prepared the configuration.
-    on_after_configure = None
+    on_after_configure: SignalT = None

     #: Signal sent after app has been finalized.
-    on_after_finalize = None
+    on_after_finalize: SignalT = None

     #: Signal sent by every new process after fork.
-    on_after_fork = None
-
-    def __init__(self, main=None, loader=None, backend=None,
-                 amqp=None, events=None, log=None, control=None,
-                 set_as_current=True, tasks=None, broker=None, include=None,
-                 changes=None, config_source=None, fixups=None, task_cls=None,
-                 autofinalize=True, namespace=None, strict_typing=True,
-                 **kwargs):
+    on_after_fork: SignalT = None
+
+    def __init__(self,
+                 main: str = None,
+                 *,
+                 loader: Union[str, type] = None,
+                 backend: Union[str, type] = None,
+                 amqp: Union[str, type] = None,
+                 events: Union[str, type] = None,
+                 log: Union[str, type] = None,
+                 control: Union[str, type] = None,
+                 set_as_current: bool = True,
+                 tasks: Union[str, type] = None,
+                 broker: str = None,
+                 include: Sequence[str] = None,
+                 changes: MutableMapping = None,
+                 config_source: str = None,
+                 fixups: Sequence[Union[str, type]] = None,
+                 task_cls: Union[str, type] = None,
+                 autofinalize: bool = True,
+                 namespace: str = None,
+                 strict_typing: bool = True,
+                 **kwargs) -> None:
         self.clock = LamportClock()
         self.main = main
         self.amqp_cls = amqp or self.amqp_cls
@@ -299,7 +331,7 @@ class Celery:
         self.on_init()
         _register_app(self)

-    def _get_default_loader(self):
+    def _get_default_loader(self) -> Union[str, type]:
         # the --loader command-line argument sets the environment variable.
         return (
             os.environ.get('CELERY_LOADER') or
@@ -307,30 +339,30 @@ class Celery:
             'celery.loaders.app:AppLoader'
         )

-    def on_init(self):
+    def on_init(self) -> None:
         """Optional callback called at init."""
         """Optional callback called at init."""
-        pass
+        ...

-    def __autoset(self, key, value):
+    def __autoset(self, key: str, value: Any) -> None:
         if value:
             self._preconf[key] = value
             self._preconf_set_by_auto.add(key)

-    def set_current(self):
+    def set_current(self) -> None:
         """Make this the current app for this thread."""
         """Make this the current app for this thread."""
         _set_current_app(self)
         _set_current_app(self)
 
 
-    def set_default(self):
+    def set_default(self) -> None:
         """Make this the default app for all threads."""
         """Make this the default app for all threads."""
         set_default_app(self)
         set_default_app(self)
 
 
-    def _ensure_after_fork(self):
+    def _ensure_after_fork(self) -> None:
         if not self._after_fork_registered:
             self._after_fork_registered = True
             if register_after_fork is not None:
                 register_after_fork(self, _after_fork_cleanup_app)

-    def close(self):
+    def close(self) -> None:
         """Clean up after the application.
         """Clean up after the application.
 
 
         Only necessary for dynamically created apps, and you should
         Only necessary for dynamically created apps, and you should
@@ -344,25 +376,25 @@ class Celery:
         self._pool = None
         _deregister_app(self)

-    def start(self, argv=None):
+    def start(self, argv: List[str] = None) -> None:
         """Run :program:`celery` using `argv`.
         """Run :program:`celery` using `argv`.
 
 
         Uses :data:`sys.argv` if `argv` is not specified.
         Uses :data:`sys.argv` if `argv` is not specified.
         """
         """
-        return instantiate(
+        instantiate(
             'celery.bin.celery:CeleryCommand', app=self
         ).execute_from_commandline(argv)

-    def worker_main(self, argv=None):
+    def worker_main(self, argv: List[str] = None) -> None:
         """Run :program:`celery worker` using `argv`.
         """Run :program:`celery worker` using `argv`.
 
 
         Uses :data:`sys.argv` if `argv` is not specified.
         Uses :data:`sys.argv` if `argv` is not specified.
         """
         """
-        return instantiate(
+        instantiate(
             'celery.bin.worker:worker', app=self
         ).execute_from_commandline(argv)

-    def task(self, *args, **opts):
+    def task(self, *args, **opts) -> Union[Callable, TaskT]:
         """Decorator to create a task class out of any callable.
         """Decorator to create a task class out of any callable.
 
 
         Examples:
         Examples:
@@ -398,10 +430,15 @@ class Celery:
             from . import shared_task
             return shared_task(*args, lazy=False, **opts)

-        def inner_create_task_cls(shared=True, filter=None, lazy=True, **opts):
+        def inner_create_task_cls(
+                *,
+                shared: bool = True,
+                filter: Callable = None,
+                lazy: bool = True,
+                **opts) -> Callable:
             _filt = filter  # stupid 2to3

-            def _create_task_cls(fun):
+            def _create_task_cls(fun: Callable) -> TaskT:
                 if shared:
                     def cons(app):
                         return app._task_from_fun(fun, **opts)
@@ -430,7 +467,12 @@ class Celery:
                     sum([len(args), len(opts)])))
         return inner_create_task_cls(**opts)

-    def _task_from_fun(self, fun, name=None, base=None, bind=False, **options):
+    def _task_from_fun(self, fun: Callable,
+                       *,
+                       name: str = None,
+                       base: type = None,
+                       bind: bool = False,
+                       **options) -> TaskT:
         if not self.finalized and not self.autofinalize:
             raise RuntimeError('Contract breach: app not finalized')
         name = name or self.gen_task_name(fun.__name__, fun.__module__)
@@ -462,7 +504,7 @@ class Celery:
             if autoretry_for and not hasattr(task, '_orig_run'):

                 @wraps(task.run)
-                def run(*args, **kwargs):
+                def run(*args, **kwargs) -> Any:
                     try:
                         return task._orig_run(*args, **kwargs)
                     except autoretry_for as exc:
@@ -473,7 +515,7 @@ class Celery:
             task = self._tasks[name]
         return task

-    def register_task(self, task):
+    def register_task(self, task: TaskT) -> TaskT:
         """Utility for registering a task-based class.
         """Utility for registering a task-based class.
 
 
         Note:
         Note:
@@ -490,10 +532,10 @@ class Celery:
         task.bind(self)
         return task

-    def gen_task_name(self, name, module):
+    def gen_task_name(self, name: str, module: str) -> str:
         return gen_task_name(self, name, module)

-    def finalize(self, auto=False):
+    def finalize(self, auto: bool = False) -> None:
         """Finalize the app.
         """Finalize the app.
 
 
         This loads built-in tasks, evaluates pending task decorators,
         This loads built-in tasks, evaluates pending task decorators,
@@ -515,7 +557,7 @@ class Celery:

                 self.on_after_finalize.send(sender=self)

-    def add_defaults(self, fun):
+    def add_defaults(self, fun: Union[Callable, Mapping]) -> None:
         """Add default configuration from dict ``d``.
         """Add default configuration from dict ``d``.
 
 
         If the argument is a callable function then it will be regarded
         If the argument is a callable function then it will be regarded
@@ -536,11 +578,15 @@ class Celery:
         if not callable(fun):
             d, fun = fun, lambda: d
         if self.configured:
-            return self._conf.add_defaults(fun())
-        self._pending_defaults.append(fun)
+            self._conf.add_defaults(fun())
+        else:
+            self._pending_defaults.append(fun)

-    def config_from_object(self, obj,
-                           silent=False, force=False, namespace=None):
+    def config_from_object(self, obj: Any,
+                           *,
+                           silent: bool = False,
+                           force: bool = False,
+                           namespace: str = None) -> None:
         """Read configuration from object.
         """Read configuration from object.
 
 
         Object is either an actual object or the name of a module to import.
         Object is either an actual object or the name of a module to import.
@@ -561,9 +607,12 @@ class Celery:
         if force or self.configured:
             self._conf = None
             if self.loader.config_from_object(obj, silent=silent):
-                return self.conf
+                conf = self.conf  # noqa

-    def config_from_envvar(self, variable_name, silent=False, force=False):
+    def config_from_envvar(self, variable_name: str,
+                           *,
+                           silent: bool = False,
+                           force: bool = False) -> None:
         """Read configuration from environment variable.
         """Read configuration from environment variable.
 
 
         The value of the environment variable must be the name
         The value of the environment variable must be the name
@@ -574,20 +623,26 @@ class Celery:
             >>> celery.config_from_envvar('CELERY_CONFIG_MODULE')
         """
         module_name = os.environ.get(variable_name)
-        if not module_name:
-            if silent:
-                return False
+        if module_name:
+            self.config_from_object(module_name, silent=silent, force=force)
+        elif not silent:
             raise ImproperlyConfigured(
                 ERR_ENVVAR_NOT_SET.strip().format(variable_name))
-        return self.config_from_object(module_name, silent=silent, force=force)

-    def config_from_cmdline(self, argv, namespace='celery'):
+    def config_from_cmdline(self, argv: List[str],
+                            namespace: str = 'celery') -> None:
         self._conf.update(
             self.loader.cmdline_config_parser(argv, namespace)
         )

-    def setup_security(self, allowed_serializers=None, key=None, cert=None,
-                       store=None, digest='sha1', serializer='json'):
+    def setup_security(self,
+                       allowed_serializers: Set[str] = None,
+                       *,
+                       key: str = None,
+                       cert: str = None,
+                       store: str = None,
+                       digest: str = 'sha1',
+                       serializer: str = 'json') -> None:
         """Setup the message-signing serializer.
         """Setup the message-signing serializer.
 
 
         This will affect all application instances (a global operation).
         This will affect all application instances (a global operation).
@@ -612,11 +667,14 @@ class Celery:
                 the serializers supported.  Default is ``json``.
         """
         from celery.security import setup_security
-        return setup_security(allowed_serializers, key, cert,
-                              store, digest, serializer, app=self)
-
-    def autodiscover_tasks(self, packages=None,
-                           related_name='tasks', force=False):
+        setup_security(allowed_serializers, key, cert,
+                       store, digest, serializer, app=self)
+
+    def autodiscover_tasks(self,
+                           packages: Sequence[str] = None,
+                           *,
+                           related_name: str = 'tasks',
+                           force: bool = False) -> None:
         """Auto-discover task modules.
         """Auto-discover task modules.
 
 
         Searches a list of packages for a "tasks.py" module (or use
         Searches a list of packages for a "tasks.py" module (or use
@@ -655,37 +713,61 @@ class Celery:
                 to happen immediately.
                 to happen immediately.
         """
         """
         if force:
         if force:
-            return self._autodiscover_tasks(packages, related_name)
-        signals.import_modules.connect(starpromise(
-            self._autodiscover_tasks, packages, related_name,
-        ), weak=False, sender=self)
+            self._autodiscover_tasks(packages, related_name)
+        else:
+            signals.import_modules.connect(starpromise(
+                self._autodiscover_tasks, packages, related_name,
+            ), weak=False, sender=self)
 
 
-    def _autodiscover_tasks(self, packages, related_name, **kwargs):
+    def _autodiscover_tasks(
+            self, packages: Sequence[str], related_name: str,
+            **kwargs) -> None:
         if packages:
-            return self._autodiscover_tasks_from_names(packages, related_name)
-        return self._autodiscover_tasks_from_fixups(related_name)
+            self._autodiscover_tasks_from_names(packages, related_name)
+        else:
+            self._autodiscover_tasks_from_fixups(related_name)
 
-    def _autodiscover_tasks_from_names(self, packages, related_name):
+    def _autodiscover_tasks_from_names(
+            self, packages: Sequence[str], related_name: str) -> None:
         # packages argument can be lazy
-        return self.loader.autodiscover_tasks(
+        self.loader.autodiscover_tasks(
             packages() if callable(packages) else packages, related_name,
         )
 
-    def _autodiscover_tasks_from_fixups(self, related_name):
-        return self._autodiscover_tasks_from_names([
+    def _autodiscover_tasks_from_fixups(self, related_name: str) -> None:
+        self._autodiscover_tasks_from_names([
             pkg for fixup in self._fixups
             for pkg in fixup.autodiscover_tasks()
             if hasattr(fixup, 'autodiscover_tasks')
         ], related_name=related_name)
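
A sketch of the discovery API above (package names are hypothetical):

    app = Celery('proj')
    # imports proj.orders.tasks and proj.users.tasks when the worker
    # imports its task modules:
    app.autodiscover_tasks(['proj.orders', 'proj.users'])
    # ``related_name``/``force`` are now keyword-only:
    app.autodiscover_tasks(['proj.orders'], related_name='jobs', force=True)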
 
-    def send_task(self, name, args=None, kwargs=None, countdown=None,
-                  eta=None, task_id=None, producer=None, connection=None,
-                  router=None, result_cls=None, expires=None,
-                  link=None, link_error=None,
-                  add_to_parent=True, group_id=None, retries=0, chord=None,
-                  reply_to=None, time_limit=None, soft_time_limit=None,
-                  root_id=None, parent_id=None, route_name=None,
-                  shadow=None, chain=None, task_type=None, **options):
+    def send_task(self, name: str,
+                  args: Sequence = None,
+                  kwargs: Mapping = None,
+                  countdown: float = None,
+                  eta: datetime = None,
+                  task_id: str = None,
+                  producer: ProducerT = None,
+                  connection: ConnectionT = None,
+                  router: RouterT = None,
+                  result_cls: type = None,
+                  expires: Union[float, datetime] = None,
+                  link: Union[SignatureT, Sequence[SignatureT]] = None,
+                  link_error: Union[SignatureT, Sequence[SignatureT]] = None,
+                  add_to_parent: bool = True,
+                  group_id: str = None,
+                  retries: int = 0,
+                  chord: SignatureT = None,
+                  reply_to: str = None,
+                  time_limit: float = None,
+                  soft_time_limit: float = None,
+                  root_id: str = None,
+                  parent_id: str = None,
+                  route_name: str = None,
+                  shadow: str = None,
+                  chain: List[SignatureT] = None,
+                  task_type: TaskT = None,
+                  **options) -> ResultT:
         """Send task by name.
         """Send task by name.
 
 
         Supports the same arguments as :meth:`@-Task.apply_async`.
         Supports the same arguments as :meth:`@-Task.apply_async`.
@@ -737,7 +819,7 @@ class Celery:
                 parent.add_trail(result)
                 parent.add_trail(result)
         return result
         return result
 
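A minimal call against the signature above (task name and arguments are
illustrative; a worker must be consuming the queue for the result to arrive):

    result = app.send_task('proj.tasks.add', args=(2, 2), countdown=10)
    print(result.get(timeout=30))  # -> 4
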
-    def connection_for_read(self, url=None, **kwargs):
+    def connection_for_read(self, url: str = None, **kwargs) -> ConnectionT:
         """Establish connection used for consuming.
         """Establish connection used for consuming.
 
 
         See Also:
         See Also:
@@ -745,7 +827,7 @@ class Celery:
         """
         """
         return self._connection(url or self.conf.broker_read_url, **kwargs)
         return self._connection(url or self.conf.broker_read_url, **kwargs)
 
 
-    def connection_for_write(self, url=None, **kwargs):
+    def connection_for_write(self, url: str = None, **kwargs) -> ConnectionT:
         """Establish connection used for producing.
         """Establish connection used for producing.
 
 
         See Also:
         See Also:
@@ -753,11 +835,20 @@ class Celery:
         """
         """
         return self._connection(url or self.conf.broker_write_url, **kwargs)
         return self._connection(url or self.conf.broker_write_url, **kwargs)
 
 
-    def connection(self, hostname=None, userid=None, password=None,
-                   virtual_host=None, port=None, ssl=None,
-                   connect_timeout=None, transport=None,
-                   transport_options=None, heartbeat=None,
-                   login_method=None, failover_strategy=None, **kwargs):
+    def connection(self,
+                   hostname: str = None,
+                   userid: str = None,
+                   password: str = None,
+                   virtual_host: str = None,
+                   port: int = None,
+                   ssl: SSLArg = None,
+                   connect_timeout: float = None,
+                   transport: Any = None,
+                   transport_options: Mapping = None,
+                   heartbeat: float = None,
+                   login_method: str = None,
+                   failover_strategy: str = None,
+                   **kwargs) -> ConnectionT:
         """Establish a connection to the message broker.
         """Establish a connection to the message broker.
 
 
         Please use :meth:`connection_for_read` and
         Please use :meth:`connection_for_read` and
@@ -796,11 +887,19 @@ class Celery:
             **kwargs
             **kwargs
         )
         )
 
 
-    def _connection(self, url, userid=None, password=None,
-                    virtual_host=None, port=None, ssl=None,
-                    connect_timeout=None, transport=None,
-                    transport_options=None, heartbeat=None,
-                    login_method=None, failover_strategy=None, **kwargs):
+    def _connection(self, url: str,
+                    userid: str = None,
+                    password: str = None,
+                    virtual_host: str = None,
+                    port: int = None,
+                    ssl: SSLArg = None,
+                    connect_timeout: float = None,
+                    transport: Any = None,
+                    transport_options: Mapping = None,
+                    heartbeat: float = None,
+                    login_method: str = None,
+                    failover_strategy: str = None,
+                    **kwargs) -> ConnectionT:
         conf = self.conf
         return self.amqp.Connection(
             url,
@@ -824,13 +923,14 @@ class Celery:
         )
     broker_connection = connection
 
-    def _acquire_connection(self, pool=True):
+    def _acquire_connection(self, pool: bool = True) -> ConnectionT:
         """Helper for :meth:`connection_or_acquire`."""
         """Helper for :meth:`connection_or_acquire`."""
         if pool:
         if pool:
             return self.pool.acquire(block=True)
             return self.pool.acquire(block=True)
         return self.connection_for_write()
         return self.connection_for_write()
 
 
-    def connection_or_acquire(self, connection=None, pool=True, *_, **__):
+    def connection_or_acquire(self, connection: ConnectionT = None,
+                              pool: bool = True, *_, **__) -> ContextManager:
         """Context used to acquire a connection from the pool.
         """Context used to acquire a connection from the pool.
 
 
         For use within a :keyword:`with` statement to get a connection
         For use within a :keyword:`with` statement to get a connection
@@ -842,7 +942,8 @@ class Celery:
         """
         """
         return FallbackContext(connection, self._acquire_connection, pool=pool)
         return FallbackContext(connection, self._acquire_connection, pool=pool)
 
 
-    def producer_or_acquire(self, producer=None):
+    def producer_or_acquire(self,
+                            producer: ProducerT = None) -> ContextManager:
         """Context used to acquire a producer from the pool.
         """Context used to acquire a producer from the pool.
 
 
         For use within a :keyword:`with` statement to get a producer
         For use within a :keyword:`with` statement to get a producer
@@ -856,23 +957,23 @@ class Celery:
             producer, self.producer_pool.acquire, block=True,
             producer, self.producer_pool.acquire, block=True,
         )
         )
 
 
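Both helpers are context managers around the pools; a sketch:

    with app.connection_or_acquire() as conn:
        ...  # pooled connection, released on exit

    with app.producer_or_acquire() as producer:
        producer.publish({'hello': 'world'}, routing_key='celery',
                         serializer='json')
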
-    def prepare_config(self, c):
+    def prepare_config(self, c: Mapping) -> Mapping:
         """Prepare configuration before it is merged with the defaults."""
         """Prepare configuration before it is merged with the defaults."""
         return find_deprecated_settings(c)
         return find_deprecated_settings(c)
 
 
-    def now(self):
+    def now(self) -> datetime:
         """Return the current time and date as a datetime."""
         """Return the current time and date as a datetime."""
         return self.loader.now(utc=self.conf.enable_utc)
         return self.loader.now(utc=self.conf.enable_utc)
 
 
-    def select_queues(self, queues=None):
+    def select_queues(self, queues: Sequence[str] = None) -> None:
         """Select subset of queues.
         """Select subset of queues.
 
 
         Arguments:
         Arguments:
             queues (Sequence[str]): a list of queue names to keep.
             queues (Sequence[str]): a list of queue names to keep.
         """
         """
-        return self.amqp.queues.select(queues)
+        self.amqp.queues.select(queues)
 
 
-    def either(self, default_key, *defaults):
+    def either(self, default_key: str, *defaults) -> Any:
         """Get key from configuration or use default values.
         """Get key from configuration or use default values.
 
 
         Fallback to the value of a configuration key if none of the
         Fallback to the value of a configuration key if none of the
@@ -882,17 +983,17 @@ class Celery:
             first(None, defaults), starpromise(self.conf.get, default_key),
             first(None, defaults), starpromise(self.conf.get, default_key),
         ])
         ])
 
 
-    def bugreport(self):
+    def bugreport(self) -> str:
         """Return information useful in bug reports."""
         """Return information useful in bug reports."""
         return bugreport(self)
         return bugreport(self)
 
 
-    def _get_backend(self):
+    def _get_backend(self) -> BackendT:
         backend, url = backends.by_url(
             self.backend_cls or self.conf.result_backend,
             self.loader)
         return backend(app=self, url=url)
 
-    def _finalize_pending_conf(self):
+    def _finalize_pending_conf(self) -> MutableMapping:
         """Get config value by key and finalize loading the configuration.
         """Get config value by key and finalize loading the configuration.
 
 
         Note:
         Note:
@@ -902,7 +1003,7 @@ class Celery:
         conf = self._conf = self._load_config()
         conf = self._conf = self._load_config()
         return conf
         return conf
 
 
-    def _load_config(self):
+    def _load_config(self) -> MutableMapping:
         if isinstance(self.on_configure, Signal):
             self.on_configure.send(sender=self)
         else:
@@ -936,7 +1037,7 @@ class Celery:
         self.on_after_configure.send(sender=self, source=self._conf)
         return self._conf
 
-    def _after_fork(self):
+    def _after_fork(self) -> None:
         self._pool = None
         try:
             self.__dict__['amqp']._producer_pool = None
@@ -944,13 +1045,14 @@ class Celery:
             pass
         self.on_after_fork.send(sender=self)
 
-    def signature(self, *args, **kwargs):
+    def signature(self, *args, **kwargs) -> SignatureT:
         """Return a new :class:`~celery.Signature` bound to this app."""
         """Return a new :class:`~celery.Signature` bound to this app."""
         kwargs['app'] = self
         kwargs['app'] = self
         return self._canvas.signature(*args, **kwargs)
         return self._canvas.signature(*args, **kwargs)
 
 
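Signatures built here are bound to this app; a sketch (the task name is
hypothetical):

    sig = app.signature('proj.tasks.add', args=(2, 2), countdown=10)
    sig.apply_async()
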
-    def add_periodic_task(self, schedule, sig,
-                          args=(), kwargs=(), name=None, **opts):
+    def add_periodic_task(self, schedule: ScheduleT, sig: SignatureT,
+                          args: Tuple = (), kwargs: Dict = (),
+                          name: str = None, **opts) -> str:
         key, entry = self._sig_to_periodic_task_entry(
             schedule, sig, args, kwargs, name, **opts)
         if self.configured:
@@ -959,8 +1061,10 @@ class Celery:
             self._pending_periodic_tasks.append((key, entry))
         return key
 
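The documented pattern for this entry point, for reference:

    @app.on_after_configure.connect
    def setup_periodic_tasks(sender, **kwargs):
        sender.add_periodic_task(30.0, test.s('hello'),
                                 name='say hello every 30s')

    @app.task
    def test(arg):
        print(arg)
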
-    def _sig_to_periodic_task_entry(self, schedule, sig,
-                                    args=(), kwargs={}, name=None, **opts):
+    def _sig_to_periodic_task_entry(
+            self, schedule: ScheduleT, sig: SignatureT,
+            args: Tuple = (), kwargs: Dict = {},
+            name: str = None, **opts) -> Tuple[str, Mapping]:
         sig = (sig.clone(args, kwargs)
                if isinstance(sig, abstract.CallableSignature)
                else self.signature(sig.name, args, kwargs))
@@ -972,18 +1076,22 @@ class Celery:
             'options': dict(sig.options, **opts),
         }
 
-    def _add_periodic_task(self, key, entry):
+    def _add_periodic_task(self, key: str, entry: Mapping) -> None:
         self._conf.beat_schedule[key] = entry
 
-    def create_task_cls(self):
+    def create_task_cls(self) -> type:
         """Create a base task class bound to this app."""
         """Create a base task class bound to this app."""
         return self.subclass_with_self(
         return self.subclass_with_self(
             self.task_cls, name='Task', attribute='_app',
             self.task_cls, name='Task', attribute='_app',
             keep_reduce=True, abstract=True,
             keep_reduce=True, abstract=True,
         )
         )
 
 
-    def subclass_with_self(self, Class, name=None, attribute='app',
-                           reverse=None, keep_reduce=False, **kw):
+    def subclass_with_self(self, Class: type,
+                           name: str = None,
+                           attribute: str = 'app',
+                           reverse: str = None,
+                           keep_reduce: bool = False,
+                           **kw) -> type:
         """Subclass an app-compatible class.
         """Subclass an app-compatible class.
 
 
         App-compatible means that the class has a class attribute that
         App-compatible means that the class has a class attribute that
@@ -1017,24 +1125,24 @@ class Celery:
 
 
         return type(name or Class.__name__, (Class,), attrs)
         return type(name or Class.__name__, (Class,), attrs)
 
 
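A sketch of what "app-compatible" means here (the Scheduler class is
hypothetical):

    class Scheduler:
        app = None

        def __init__(self, app=None):
            self.app = app or self.app

    BoundScheduler = app.subclass_with_self(Scheduler, name='Scheduler')
    assert BoundScheduler().app is app
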
-    def _rgetattr(self, path):
+    def _rgetattr(self, path: str) -> Any:
         return attrgetter(path)(self)
 
-    def __enter__(self):
+    def __enter__(self) -> AppT:
         return self
 
-    def __exit__(self, *exc_info):
+    def __exit__(self, *exc_info) -> None:
         self.close()
 
-    def __repr__(self):
+    def __repr__(self) -> str:
         return '<{0} {1}>'.format(type(self).__name__, appstr(self))
 
-    def __reduce__(self):
+    def __reduce__(self) -> Tuple:
         if self._using_v1_reduce:
             return self.__reduce_v1__()
         return (_unpickle_app_v2, (self.__class__, self.__reduce_keys__()))
 
-    def __reduce_v1__(self):
+    def __reduce_v1__(self) -> Tuple:
         # Reduce only pickles the configuration changes,
         # so the default configuration doesn't have to be passed
         # between processes.
@@ -1043,7 +1151,7 @@ class Celery:
             (self.__class__, self.Pickler) + self.__reduce_args__(),
         )
 
-    def __reduce_keys__(self):
+    def __reduce_keys__(self) -> Mapping[str, Any]:
         """Keyword arguments used to reconstruct the object when unpickling."""
         """Keyword arguments used to reconstruct the object when unpickling."""
         return {
         return {
             'main': self.main,
             'main': self.main,
@@ -1061,7 +1169,7 @@ class Celery:
             'namespace': self.namespace,
             'namespace': self.namespace,
         }
         }
 
 
-    def __reduce_args__(self):
+    def __reduce_args__(self) -> Tuple:
         """Deprecated method, please use :meth:`__reduce_keys__` instead."""
         """Deprecated method, please use :meth:`__reduce_keys__` instead."""
         return (self.main, self._conf.changes if self.configured else {},
         return (self.main, self._conf.changes if self.configured else {},
                 self.loader_cls, self.backend_cls, self.amqp_cls,
                 self.loader_cls, self.backend_cls, self.amqp_cls,
@@ -1069,7 +1177,7 @@ class Celery:
                 False, self._config_source)
                 False, self._config_source)
 
 
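Since only configuration changes are pickled, an app round-trips cheaply
between processes; a sketch:

    import pickle

    app2 = pickle.loads(pickle.dumps(app))
    assert app2.main == app.main  # defaults are re-derived, not copied
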
     @cached_property
-    def Worker(self):
+    def Worker(self) -> WorkerT:
         """Worker application.
 
         See Also:
@@ -1078,7 +1186,7 @@ class Celery:
         return self.subclass_with_self('celery.apps.worker:Worker')
 
     @cached_property
-    def WorkController(self, **kwargs):
+    def WorkController(self, **kwargs) -> WorkerT:
         """Embeddable worker.
 
         See Also:
@@ -1087,7 +1195,7 @@ class Celery:
         return self.subclass_with_self('celery.worker:WorkController')
 
     @cached_property
-    def Beat(self, **kwargs):
+    def Beat(self, **kwargs) -> BeatT:
         """:program:`celery beat` scheduler application.
 
         See Also:
@@ -1096,7 +1204,7 @@ class Celery:
         return self.subclass_with_self('celery.apps.beat:Beat')
 
     @cached_property
-    def Task(self):
+    def Task(self) -> type:
         """Base task class for this app."""
         return self.create_task_cls()
 
@@ -1105,7 +1213,7 @@ class Celery:
         return prepare_annotations(self.conf.task_annotations)
 
     @cached_property
-    def AsyncResult(self):
+    def AsyncResult(self) -> type:
         """Create new result instance.
 
         See Also:
@@ -1114,11 +1222,11 @@ class Celery:
         return self.subclass_with_self('celery.result:AsyncResult')
 
     @cached_property
-    def ResultSet(self):
+    def ResultSet(self) -> type:
         return self.subclass_with_self('celery.result:ResultSet')
 
     @cached_property
-    def GroupResult(self):
+    def GroupResult(self) -> type:
         """Create new group result instance.
 
         See Also:
@@ -1127,7 +1235,7 @@ class Celery:
         return self.subclass_with_self('celery.result:GroupResult')
 
     @property
-    def pool(self):
+    def pool(self) -> ResourceT:
         """Broker connection pool: :class:`~@pool`.
 
         Note:
@@ -1141,12 +1249,12 @@ class Celery:
         return self._pool
 
     @property
-    def current_task(self):
+    def current_task(self) -> Optional[TaskT]:
         """Instance of task being executed, or :const:`None`."""
         return _task_stack.top
 
     @property
-    def current_worker_task(self):
+    def current_worker_task(self) -> Optional[TaskT]:
         """The task currently being executed by a worker or :const:`None`.
 
         Differs from :data:`current_task` in that it's not affected
@@ -1155,7 +1263,7 @@ class Celery:
         return get_current_worker_task()
 
     @cached_property
-    def oid(self):
+    def oid(self) -> str:
         """Universally unique identifier for this app."""
         # since 4.0: thread.get_ident() is not included when
         # generating the process id.  This is due to how the RPC
@@ -1164,53 +1272,53 @@ class Celery:
         return oid_from(self, threads=False)
 
     @cached_property
-    def amqp(self):
+    def amqp(self) -> AppAMQPT:
         """AMQP related functionality: :class:`~@amqp`."""
         return instantiate(self.amqp_cls, app=self)
 
     @cached_property
-    def backend(self):
+    def backend(self) -> BackendT:
         """Current backend instance."""
         return self._get_backend()
 
     @property
-    def conf(self):
+    def conf(self) -> MutableMapping:
         """Current configuration."""
         if self._conf is None:
             self._conf = self._load_config()
         return self._conf
 
     @conf.setter
-    def conf(self, d):  # noqa
+    def conf(self, d: MutableMapping) -> None:  # noqa
         self._conf = d
 
     @cached_property
-    def control(self):
+    def control(self) -> AppControlT:
         """Remote control: :class:`~@control`."""
         return instantiate(self.control_cls, app=self)
 
     @cached_property
-    def events(self):
+    def events(self) -> AppEventsT:
         """Consuming and sending events: :class:`~@events`."""
         return instantiate(self.events_cls, app=self)
 
     @cached_property
-    def loader(self):
+    def loader(self) -> LoaderT:
         """Current loader instance."""
         return get_loader_cls(self.loader_cls)(app=self)
 
     @cached_property
-    def log(self):
+    def log(self) -> AppLogT:
         """Logging: :class:`~@log`."""
         return instantiate(self.log_cls, app=self)
 
     @cached_property
-    def _canvas(self):
+    def _canvas(self) -> ModuleType:
         from celery import canvas
         return canvas
 
     @cached_property
-    def tasks(self):
+    def tasks(self) -> TaskRegistryT:
         """Task registry.
 
         Warning:

+ 13 - 6
celery/app/defaults.py

@@ -1,8 +1,9 @@
 # -*- coding: utf-8 -*-
 """Configuration introspection and defaults."""
 import sys
-from collections import deque, namedtuple
+from collections import deque
 from datetime import timedelta
+from typing import NamedTuple
 from celery.utils.functional import memoize
 from celery.utils.serialization import strtobool
 
@@ -31,7 +32,13 @@ OLD_NS = {'celery_{0}'}
 OLD_NS_BEAT = {'celerybeat_{0}'}
 OLD_NS_WORKER = {'celeryd_{0}'}
 
-searchresult = namedtuple('searchresult', ('namespace', 'key', 'type'))
+
+class find_result_t(NamedTuple):
+    """Return value of :func:`find`."""
+
+    namespace: str
+    key: str
+    type: 'Option'
 
 
 def Namespace(__old__=None, **options):
@@ -340,18 +347,18 @@ def find(name, namespace='celery'):
     # - Try specified name-space first.
     namespace = namespace.lower()
     try:
-        return searchresult(
+        return find_result_t(
             namespace, name.lower(), NAMESPACES[namespace][name.lower()],
         )
     except KeyError:
         # - Try all the other namespaces.
         for ns, opts in NAMESPACES.items():
             if ns.lower() == name.lower():
-                return searchresult(None, ns, opts)
+                return find_result_t(None, ns, opts)
             elif isinstance(opts, dict):
                 try:
-                    return searchresult(ns, name.lower(), opts[name.lower()])
+                    return find_result_t(ns, name.lower(), opts[name.lower()])
                 except KeyError:
                     pass
     # - See if name is a qualname last.
-    return searchresult(None, name.lower(), DEFAULTS[name.lower()])
+    return find_result_t(None, name.lower(), DEFAULTS[name.lower()])
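
The named fields make call sites more readable than the old positional
namedtuple; a sketch (the setting name is illustrative):

    from celery.app.defaults import find

    res = find('task_serializer')
    res.namespace, res.key, res.type  # unpacks like the old searchresult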

+ 12 - 6
celery/app/events.py

@@ -1,5 +1,8 @@
 """Implementation for the app.events shortcuts."""
 """Implementation for the app.events shortcuts."""
 from contextlib import contextmanager
 from contextlib import contextmanager
+from celery.events import EventDispatcher, EventReceiver
+from celery.events.state import State
+from celery.types import AppT
 from kombu.utils.objects import cached_property
 
 
@@ -10,27 +13,30 @@ class Events:
     dispatcher_cls = 'celery.events.dispatcher:EventDispatcher'
     state_cls = 'celery.events.state:State'
 
-    def __init__(self, app=None):
+    def __init__(self, app: AppT = None):
         self.app = app
 
     @cached_property
-    def Receiver(self):
+    def Receiver(self) -> EventReceiver:
         return self.app.subclass_with_self(
             self.receiver_cls, reverse='events.Receiver')
 
     @cached_property
-    def Dispatcher(self):
+    def Dispatcher(self) -> EventDispatcher:
         return self.app.subclass_with_self(
             self.dispatcher_cls, reverse='events.Dispatcher')
 
     @cached_property
-    def State(self):
+    def State(self) -> State:
         return self.app.subclass_with_self(
             self.state_cls, reverse='events.State')
 
     @contextmanager
-    def default_dispatcher(self, hostname=None, enabled=True,
-                           buffer_while_offline=False):
+    def default_dispatcher(
+            self,
+            hostname: str = None,
+            enabled: bool = True,
+            buffer_while_offline: bool = False) -> EventDispatcher:
         with self.app.amqp.producer_pool.acquire(block=True) as prod:
         with self.app.amqp.producer_pool.acquire(block=True) as prod:
             # pylint: disable=too-many-function-args
             # This is a property pylint...
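
The shortcuts wire everything to the app's producer pool; a sketch of
sending through the default dispatcher (event type and field are
illustrative):

    with app.events.default_dispatcher() as dispatcher:
        dispatcher.send('worker-custom-event', info='testing')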

+ 56 - 26
celery/app/log.py

@@ -12,6 +12,7 @@ import os
 import sys
 
 from logging.handlers import WatchedFileHandler
+from typing import Union
 
 from kombu.utils.encoding import set_default_encoding_file
 
@@ -19,6 +20,7 @@ from celery import signals
 from celery._state import get_current_task
 from celery.local import class_property
 from celery.platforms import isatty
+from celery.types import AppT
 from celery.utils.log import (
     get_logger, mlevel,
     ColorFormatter, LoggingProxy, get_multiprocessing_logger,
@@ -35,7 +37,7 @@ MP_LOG = os.environ.get('MP_LOG', False)
 class TaskFormatter(ColorFormatter):
     """Formatter for tasks, adding the task name and id."""
 
-    def format(self, record):
+    def format(self, record: logging.LogRecord) -> str:
         task = get_current_task()
         if task and task.request:
             record.__dict__.update(task_id=task.request.id,
@@ -54,15 +56,20 @@ class Logging:
     #: will do nothing.
     _setup = False
 
-    def __init__(self, app):
+    def __init__(self, app: AppT) -> None:
         self.app = app
         self.loglevel = mlevel(logging.WARN)
         self.format = self.app.conf.worker_log_format
         self.task_format = self.app.conf.worker_task_log_format
         self.colorize = self.app.conf.worker_log_color
 
-    def setup(self, loglevel=None, logfile=None, redirect_stdouts=False,
-              redirect_level='WARNING', colorize=None, hostname=None):
+    def setup(self,
+              loglevel: Union[str, int] = None,
+              logfile: str = None,
+              redirect_stdouts: bool = False,
+              redirect_level: str = 'WARNING',
+              colorize: bool = None,
+              hostname: str = None) -> bool:
         loglevel = mlevel(loglevel)
         handled = self.setup_logging_subsystem(
             loglevel, logfile, colorize=colorize, hostname=hostname,
@@ -76,7 +83,9 @@ class Logging:
         )
         return handled
 
-    def redirect_stdouts(self, loglevel=None, name='celery.redirected'):
+    def redirect_stdouts(self,
+                         loglevel: int = None,
+                         name: str = 'celery.redirected') -> None:
         self.redirect_stdouts_to_logger(
             get_logger(name), loglevel=loglevel
         )
@@ -85,10 +94,15 @@ class Logging:
             CELERY_LOG_REDIRECT_LEVEL=str(loglevel or ''),
         )
 
-    def setup_logging_subsystem(self, loglevel=None, logfile=None, format=None,
-                                colorize=None, hostname=None, **kwargs):
+    def setup_logging_subsystem(self,
+                                loglevel: Union[int, str] = None,
+                                logfile: str = None,
+                                format: str = None,
+                                colorize: bool = None,
+                                hostname: str = None,
+                                **kwargs) -> bool:
         if self.already_setup:
-            return
+            return False
         if logfile and hostname:
             logfile = node_format(logfile, hostname)
         Logging._setup = True
@@ -144,18 +158,28 @@ class Logging:
         os.environ.update(_MP_FORK_LOGLEVEL_=str(loglevel),
                           _MP_FORK_LOGFILE_=logfile_name,
                           _MP_FORK_LOGFORMAT_=format)
-        return receivers
-
-    def _configure_logger(self, logger, logfile, loglevel,
-                          format, colorize, **kwargs):
+        return bool(receivers)
+
+    def _configure_logger(self,
+                          logger: logging.Logger,
+                          logfile: str,
+                          loglevel: int,
+                          format: str,
+                          colorize: bool,
+                          **kwargs) -> None:
         if logger is not None:
             self.setup_handlers(logger, logfile, format,
                                 colorize, **kwargs)
             if loglevel:
                 logger.setLevel(loglevel)
 
-    def setup_task_loggers(self, loglevel=None, logfile=None, format=None,
-                           colorize=None, propagate=False, **kwargs):
+    def setup_task_loggers(self,
+                           loglevel: Union[str, int] = None,
+                           logfile: str = None,
+                           format: str = None,
+                           colorize: bool = None,
+                           propagate: bool = False,
+                           **kwargs) -> logging.Logger:
         """Setup the task logger.
         """Setup the task logger.
 
 
         If `logfile` is not specified, then `sys.stderr` is used.
         If `logfile` is not specified, then `sys.stderr` is used.
@@ -181,8 +205,10 @@ class Logging:
         )
         return logger
 
-    def redirect_stdouts_to_logger(self, logger, loglevel=None,
-                                   stdout=True, stderr=True):
+    def redirect_stdouts_to_logger(self, logger: logging.Logger,
+                                   loglevel: int = None,
+                                   stdout: bool = True,
+                                   stderr: bool = True) -> LoggingProxy:
         """Redirect :class:`sys.stdout` and :class:`sys.stderr` to logger.
         """Redirect :class:`sys.stdout` and :class:`sys.stderr` to logger.
 
 
         Arguments:
         Arguments:
@@ -197,7 +223,9 @@ class Logging:
             sys.stderr = proxy
             sys.stderr = proxy
         return proxy
         return proxy
 
 
-    def supports_color(self, colorize=None, logfile=None):
+    def supports_color(self,
+                       colorize: bool = None,
+                       logfile: str = None) -> bool:
         colorize = self.colorize if colorize is None else colorize
         if self.app.IS_WINDOWS:
             # Windows does not support ANSI color codes.
@@ -208,11 +236,12 @@ class Logging:
             return logfile is None and isatty(sys.stderr)
         return colorize
 
-    def colored(self, logfile=None, enabled=None):
+    def colored(self, logfile: str = None, enabled: bool = None) -> colored:
         return colored(enabled=self.supports_color(enabled, logfile))
 
-    def setup_handlers(self, logger, logfile, format, colorize,
-                       formatter=ColorFormatter, **kwargs):
+    def setup_handlers(self, logger: logging.Logger,
+                       logfile: str, format: str, colorize: bool,
+                       formatter=ColorFormatter, **kwargs) -> logging.Logger:
         if self._is_configured(logger):
             return logger
         handler = self._detect_handler(logfile)
@@ -220,30 +249,31 @@ class Logging:
         logger.addHandler(handler)
         return logger
 
-    def _detect_handler(self, logfile=None):
+    def _detect_handler(self, logfile: str = None) -> logging.Handler:
         """Create handler from filename, an open stream or `None` (stderr)."""
         """Create handler from filename, an open stream or `None` (stderr)."""
         logfile = sys.__stderr__ if logfile is None else logfile
         logfile = sys.__stderr__ if logfile is None else logfile
         if hasattr(logfile, 'write'):
         if hasattr(logfile, 'write'):
             return logging.StreamHandler(logfile)
             return logging.StreamHandler(logfile)
         return WatchedFileHandler(logfile)
         return WatchedFileHandler(logfile)
 
 
-    def _has_handler(self, logger):
+    def _has_handler(self, logger: logging.Logger) -> bool:
         return any(
             not isinstance(h, logging.NullHandler)
             for h in logger.handlers or []
         )
 
-    def _is_configured(self, logger):
+    def _is_configured(self, logger: logging.Logger) -> bool:
         return self._has_handler(logger) and not getattr(
             logger, '_rudimentary_setup', False)
 
-    def get_default_logger(self, name='celery', **kwargs):
+    def get_default_logger(self, name: str = 'celery',
+                           **kwargs) -> logging.Logger:
         return get_logger(name)
 
     @class_property
-    def already_setup(self):
+    def already_setup(self) -> bool:
         return self._setup
 
     @already_setup.setter  # noqa
-    def already_setup(self, was_setup):
+    def already_setup(self, was_setup: bool) -> None:
         self._setup = was_setup
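
Typical entry point for the API above; a sketch (values illustrative):

    app.log.setup(loglevel='INFO', logfile=None, redirect_stdouts=True)
    # a second call is a no-op: setup_logging_subsystem() now returns
    # False instead of None when logging was already configured.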

+ 7 - 16
celery/app/registry.py

@@ -2,8 +2,10 @@
 """Registry of available tasks."""
 """Registry of available tasks."""
 import inspect
 import inspect
 from importlib import import_module
 from importlib import import_module
+from typing import Any
 from celery._state import get_current_app
 from celery._state import get_current_app
 from celery.exceptions import NotRegistered, InvalidTaskError
 from celery.exceptions import NotRegistered, InvalidTaskError
+from celery.types import TaskT
 
 
 __all__ = ['TaskRegistry']
 
@@ -13,10 +15,10 @@ class TaskRegistry(dict):
 
 
     NotRegistered = NotRegistered
 
-    def __missing__(self, key):
+    def __missing__(self, key: str) -> Any:
         raise self.NotRegistered(key)
 
-    def register(self, task):
+    def register(self, task: TaskT) -> None:
         """Register a task in the task registry.
         """Register a task in the task registry.
 
 
         The task will be automatically instantiated if not already an
         The task will be automatically instantiated if not already an
@@ -28,7 +30,7 @@ class TaskRegistry(dict):
                     type(task).__name__))
         self[task.name] = inspect.isclass(task) and task() or task
 
-    def unregister(self, name):
+    def unregister(self, name: str) -> None:
         """Unregister task by name.
         """Unregister task by name.
 
 
         Arguments:
         Arguments:
@@ -43,23 +45,12 @@ class TaskRegistry(dict):
         except KeyError:
             raise self.NotRegistered(name)
 
-    # -- these methods are irrelevant now and will be removed in 4.0
-    def regular(self):
-        return self.filter_types('regular')
 
 
-    def periodic(self):
-        return self.filter_types('periodic')
-
-    def filter_types(self, type):
-        return {name: task for name, task in self.items()
-                if getattr(task, 'type', 'regular') == type}
-
-
-def _unpickle_task(name):
+def _unpickle_task(name: str) -> TaskT:
     return get_current_app().tasks[name]
 
 
-def _unpickle_task_v2(name, module=None):
+def _unpickle_task_v2(name: str, module: str = None) -> TaskT:
     if module:
     if module:
         import_module(module)
         import_module(module)
     return get_current_app().tasks[name]
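
Registration is usually implicit via the task decorator; the registry
itself behaves like a dict, e.g. (task name hypothetical):

    @app.task(name='demo.ping')
    def ping():
        return 'pong'

    app.tasks['demo.ping']        # lookup, raises NotRegistered if missing
    app.tasks.unregister('demo.ping')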
     return get_current_app().tasks[name]

+ 32 - 12
celery/app/routes.py

@@ -6,8 +6,10 @@ Contains utilities for working with task routers, (:setting:`task_routes`).
 import re
 import re
 import string
 import string
 from collections import Mapping, OrderedDict
 from collections import Mapping, OrderedDict
+from typing import Any, Callable, Sequence, Union, Tuple
 from kombu import Queue
 from kombu import Queue
 from celery.exceptions import QueueNotFound
 from celery.exceptions import QueueNotFound
+from celery.types import AppT, RouterT, TaskT
 from celery.utils.collections import lpmerge
 from celery.utils.collections import lpmerge
 from celery.utils.functional import maybe_evaluate, mlazy
 from celery.utils.functional import maybe_evaluate, mlazy
 from celery.utils.imports import symbol_by_name
 from celery.utils.imports import symbol_by_name
@@ -15,7 +17,8 @@ from celery.utils.imports import symbol_by_name
 __all__ = ['MapRoute', 'Router', 'prepare']
 __all__ = ['MapRoute', 'Router', 'prepare']
 
 
 
 
-def glob_to_re(glob, quote=string.punctuation.replace('*', '')):
+def glob_to_re(glob: str, *,
+               quote: str = string.punctuation.replace('*', '')) -> str:
     glob = ''.join('\\' + c if c in quote else c for c in glob)
     glob = ''.join('\\' + c if c in quote else c for c in glob)
     return glob.replace('*', '.+?')
     return glob.replace('*', '.+?')
 
 
@@ -23,7 +26,10 @@ def glob_to_re(glob, quote=string.punctuation.replace('*', '')):
 class MapRoute:
     """Creates a router out of a :class:`dict`."""
 
-    def __init__(self, map):
+    map: Sequence[Tuple[str, Any]] = None
+    patterns: Mapping = None
+
+    def __init__(self, map: Union[Mapping, Sequence[Tuple[str, Any]]]) -> None:
         map = map.items() if isinstance(map, Mapping) else map
         self.map = {}
         self.patterns = OrderedDict()
@@ -35,7 +41,7 @@ class MapRoute:
             else:
                 self.map[k] = v
 
-    def __call__(self, name, *args, **kwargs):
+    def __call__(self, name: str, *args, **kwargs) -> Mapping:
         try:
             return dict(self.map[name])
         except KeyError:
@@ -53,14 +59,19 @@ class MapRoute:
 class Router:
     """Route tasks based on the :setting:`task_routes` setting."""
 
-    def __init__(self, routes=None, queues=None,
-                 create_missing=False, app=None):
+    def __init__(self,
+                 routes: Sequence = None,
+                 queues: Mapping = None,
+                 create_missing: bool = False,
+                 app: AppT = None) -> None:
         self.app = app
         self.queues = {} if queues is None else queues
         self.routes = [] if routes is None else routes
         self.create_missing = create_missing
 
-    def route(self, options, name, args=(), kwargs={}, task_type=None):
+    def route(self, options: Mapping, name: str,
+              args: Sequence = (), kwargs: Mapping = {},
+              task_type: TaskT = None) -> Mapping:
         options = self.expand_destination(options)  # expands 'queue'
         if self.routes:
             route = self.lookup_route(name, args, kwargs, options, task_type)
@@ -71,7 +82,7 @@ class Router:
                               self.app.conf.task_default_queue), options)
         return options
 
-    def expand_destination(self, route):
+    def expand_destination(self, route: Union[str, Mapping]) -> Mapping:
         # Route can be a queue name: convenient for direct exchanges.
         if isinstance(route, str):
             queue, route = route, {}
@@ -91,15 +102,24 @@ class Router:
                         'Queue {0!r} missing from task_queues'.format(queue))
         return route
 
-    def lookup_route(self, name,
-                     args=None, kwargs=None, options=None, task_type=None):
+    def lookup_route(self, name: str,
+                     args: Sequence = None,
+                     kwargs: Mapping = None,
+                     options: Mapping = None,
+                     task_type: TaskT = None) -> Mapping:
         query = self.query_router
         for router in self.routes:
             route = query(router, name, args, kwargs, options, task_type)
             if route is not None:
                 return route
 
-    def query_router(self, router, task, args, kwargs, options, task_type):
+    def query_router(self,
+                     router: Union[RouterT, Callable],
+                     task: str,
+                     args: Sequence,
+                     kwargs: Mapping,
+                     options: Mapping,
+                     task_type: TaskT) -> None:
         router = maybe_evaluate(router)
         if hasattr(router, 'route_for_task'):
             # pre 4.0 router class
@@ -107,7 +127,7 @@ class Router:
         return router(task, args, kwargs, options, task=task_type)
 
 
-def expand_router_string(router):
+def expand_router_string(router: Any):
     router = symbol_by_name(router)
     if hasattr(router, 'route_for_task'):
         # need to instantiate pre 4.0 router classes
@@ -115,7 +135,7 @@ def expand_router_string(router):
     return router
 
 
-def prepare(routes):
+def prepare(routes: Any) -> Sequence[RouterT]:
     """Expand the :setting:`task_routes` setting."""
     """Expand the :setting:`task_routes` setting."""
     def expand_route(route):
     def expand_route(route):
         if isinstance(route, (Mapping, list, tuple)):
         if isinstance(route, (Mapping, list, tuple)):
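
For reference, the glob support above is what makes patterns like this
work in :setting:`task_routes` (queue name and pattern are illustrative):

    app.conf.task_routes = {
        'feed.tasks.*': {'queue': 'feeds'},  # expanded via glob_to_re()
    }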

+ 190 - 122
celery/app/task.py

@@ -2,8 +2,14 @@
 """Task implementation: request context and the task base class."""
 """Task implementation: request context and the task base class."""
 import sys
 import sys
 
 
+from datetime import datetime
+from typing import (
+    Any, Awaitable, Callable, Iterable, Mapping, Sequence, Tuple, Union,
+)
+
 from billiard.einfo import ExceptionInfo
 from kombu.exceptions import OperationalError
+from kombu.types import ProducerT
 from kombu.utils.uuid import uuid
 
 from celery import current_app, group
@@ -13,9 +19,13 @@ from celery.canvas import signature
 from celery.exceptions import Ignore, MaxRetriesExceededError, Reject, Retry
 from celery.local import class_property
 from celery.result import EagerResult
+from celery.types import (
+    AppT, BackendT, ResultT, SignatureT, TaskT, TracerT, WorkerConsumerT,
+)
 from celery.utils import abstract
 from celery.utils.functional import mattrgetter, maybe_list
 from celery.utils.imports import instantiate
+from celery.utils.threads import LocalStack
 
 
 from .annotations import resolve_all as resolve_all_annotations
 from .registry import _unpickle_task_v2
 from .registry import _unpickle_task_v2
@@ -40,13 +50,13 @@ R_INSTANCE = '<@task: {0.name} of {app}{flags}>'
 TaskType = type
 TaskType = type
 
 
 
 
-def _strflags(flags, default=''):
+def _strflags(flags: Sequence, default: str = '') -> str:
     if flags:
     if flags:
         return ' ({0})'.format(', '.join(flags))
         return ' ({0})'.format(', '.join(flags))
     return default
     return default
 
 
 
 
-def _reprtask(task, fmt=None, flags=None):
+def _reprtask(task: TaskT, fmt: str = None, flags: Sequence = None) -> str:
     flags = list(flags) if flags is not None else []
     flags = list(flags) if flags is not None else []
     if not fmt:
     if not fmt:
         fmt = R_BOUND_TASK if task._app else R_UNBOUND_TASK
         fmt = R_BOUND_TASK if task._app else R_UNBOUND_TASK
@@ -59,51 +69,54 @@ def _reprtask(task, fmt=None, flags=None):
 class Context:
 class Context:
     """Task request variables (Task.request)."""
     """Task request variables (Task.request)."""
 
 
-    logfile = None
-    loglevel = None
-    hostname = None
-    id = None
-    args = None
-    kwargs = None
-    retries = 0
-    eta = None
-    expires = None
-    is_eager = False
-    headers = None
-    delivery_info = None
-    reply_to = None
-    root_id = None
-    parent_id = None
-    correlation_id = None
-    taskset = None   # compat alias to group
-    group = None
-    chord = None
-    chain = None
-    utc = None
-    called_directly = True
-    callbacks = None
-    errbacks = None
-    timelimit = None
-    origin = None
-    _children = None   # see property
-    _protected = 0
-
-    def __init__(self, *args, **kwargs):
+    logfile: str = None
+    loglevel: int = None
+    hostname: str = None
+    id: str = None
+    args: Sequence = None
+    kwargs: Mapping = None
+    retries: int = 0
+    eta: datetime = None
+    expires: Union[float, datetime] = None
+    is_eager: bool = False
+    headers: Mapping = None
+    delivery_info: Mapping = None
+    reply_to: str = None
+    root_id: str = None
+    parent_id: str = None
+    correlation_id: str = None
+    # compat alias to group
+    taskset: str = None
+    group: str = None
+    chord: SignatureT = None
+    chain: Sequence[SignatureT] = None
+    utc: bool = None
+    called_directly: bool = True
+    callbacks: Sequence[SignatureT] = None
+    errbacks: Sequence[SignatureT] = None
+    timelimit: Tuple[float, float] = None
+    origin: str = None
+
+    # see property
+    _children: Sequence[ResultT] = None
+    _protected: int = 0
+
+    def __init__(self, *args, **kwargs) -> None:
         self.update(*args, **kwargs)

-    def update(self, *args, **kwargs):
-        return self.__dict__.update(*args, **kwargs)
+    def update(self, *args, **kwargs) -> None:
+        self.__dict__.update(*args, **kwargs)
 
-    def clear(self):
-        return self.__dict__.clear()
+    def clear(self) -> None:
+        self.__dict__.clear()
 
-    def get(self, key, default=None):
+    def get(self, key: str, default: Any = None) -> Any:
         return getattr(self, key, default)

-    def __repr__(self):
+    def __repr__(self) -> str:
         return '<Context: {0!r}>'.format(vars(self))

-    def as_execution_options(self):
+    def as_execution_options(self) -> Mapping:
         limit_hard, limit_soft = self.timelimit or (None, None)
         return {
             'task_id': self.id,
@@ -124,7 +137,7 @@ class Context:
         }

     @property
-    def children(self):
+    def children(self) -> Sequence[ResultT]:
         # children must be an empty list for every thread
         if self._children is None:
             self._children = []
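A minimal sketch (not part of this commit) of how the `Context` mapping-style helpers behave; the field values here are made up:

    from celery.app.task import Context

    ctx = Context(id='8a2f', args=(2, 2), kwargs={}, retries=1)
    assert ctx.get('retries') == 1       # attribute lookup via get()
    assert ctx.get('missing', 42) == 42  # absent keys fall back to default
    ctx.clear()                          # resets to the class-level defaults
    assert ctx.retries == 0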
@@ -141,7 +154,7 @@ class Task:
         is overridden).
     """

-    __trace__ = None
+    __trace__: TracerT = None
     __v2_compat__ = False  # set by old base in celery.task.base

     MaxRetriesExceededError = MaxRetriesExceededError
@@ -151,42 +164,42 @@ class Task:
     Strategy = 'celery.worker.strategy:default'

     #: This is the instance bound to if the task is a method of a class.
-    __self__ = None
+    __self__: Any = None

     #: The application instance associated with this task class.
-    _app = None
+    _app: AppT = None

     #: Name of the task.
-    name = None
+    name: str = None

     #: Enable argument checking.
     #: You can set this to false if you don't want the signature to be
     #: checked when calling the task.
     #: Defaults to :attr:`app.strict_typing <@Celery.strict_typing>`.
-    typing = None
+    typing: bool = None

     #: Maximum number of retries before giving up.  If set to :const:`None`,
     #: it will **never** stop retrying.
-    max_retries = 3
+    max_retries: int = 3

     #: Default time in seconds before a retry of the task should be
     #: executed.  3 minutes by default.
-    default_retry_delay = 3 * 60
+    default_retry_delay: float = 180.0

     #: Rate limit for this task type.  Examples: :const:`None` (no rate
     #: limit), `'100/s'` (hundred tasks a second), `'100/m'` (hundred tasks
     #: a minute), `'100/h'` (hundred tasks an hour)
-    rate_limit = None
+    rate_limit: Union[int, str] = None

     #: If enabled the worker won't store task state and return values
     #: for this task.  Defaults to the :setting:`task_ignore_result`
     #: setting.
-    ignore_result = None
+    ignore_result: bool = None

     #: If enabled the request will keep track of subtasks started by
     #: this task, and this information will be sent with the result
     #: (``result.children``).
-    trail = True
+    trail: bool = True

     #: If enabled the worker will send monitoring events related to
     #: this task (but only if the worker is configured to send
@@ -194,29 +207,29 @@ class Task:
     #: Note that this has no effect on the task-failure event case
     #: where a task is not registered (as it will have no task class
     #: to check this flag).
-    send_events = True
+    send_events: bool = True

     #: When enabled errors will be stored even if the task is otherwise
     #: configured to ignore results.
-    store_errors_even_if_ignored = None
+    store_errors_even_if_ignored: bool = None

     #: The name of a serializer that is registered with
     #: :mod:`kombu.serialization.registry`.  Default is `'pickle'`.
-    serializer = None
+    serializer: str = None

     #: Hard time limit.
     #: Defaults to the :setting:`task_time_limit` setting.
-    time_limit = None
+    time_limit: float = None

     #: Soft time limit.
     #: Defaults to the :setting:`task_soft_time_limit` setting.
-    soft_time_limit = None
+    soft_time_limit: float = None

     #: The result store backend used for this task.
-    backend = None
+    backend: BackendT = None

     #: If disabled this task won't be registered automatically.
-    autoregister = True
+    autoregister: bool = True

     #: If enabled the task will report its status as 'started' when the task
     #: is executed by a worker.  Disabled by default as the normal behavior
@@ -229,7 +242,7 @@ class Task:
     #:
     #: The application default can be overridden using the
     #: :setting:`task_track_started` setting.
-    track_started = None
+    track_started: bool = None

     #: When enabled messages for this task will be acknowledged **after**
     #: the task has been executed, and not *just before* (the
@@ -240,7 +253,7 @@ class Task:
     #:
     #: The application default can be overridden with the
     #: :setting:`task_acks_late` setting.
-    acks_late = None
+    acks_late: bool = None

     #: Even if :attr:`acks_late` is enabled, the worker will
     #: acknowledge tasks when the worker process executing them abruptly
@@ -252,7 +265,7 @@ class Task:
     #:
     #: Warning: Enabling this can cause message loops; make sure you know
     #: what you're doing.
-    reject_on_worker_lost = None
+    reject_on_worker_lost: bool = None

     #: Tuple of expected exceptions.
     #:
@@ -260,29 +273,29 @@ class Task:
     #: and that shouldn't be regarded as a real error by the worker.
     #: Currently this means that the state will be updated to an error
     #: state, but the worker won't log the event as an error.
-    throws = ()
+    throws: Tuple[type, ...] = ()

     #: Default task expiry time.
-    expires = None
+    expires: float = None

     #: Max length of result representation used in logs and events.
-    resultrepr_maxsize = 1024
+    resultrepr_maxsize: int = 1024

     #: Task request stack, the current request will be the topmost.
-    request_stack = None
+    request_stack: LocalStack = None

     #: Some may expect a request to exist even if the task hasn't been
     #: called.  This should probably be deprecated.
-    _default_request = None
+    _default_request: Context = None

     #: Deprecated attribute ``abstract`` here for compatibility.
-    abstract = True
+    abstract: bool = True
 
-    _exec_options = None
+    _exec_options: Mapping = None
 
-    __bound__ = False
+    __bound__: bool = False
 
-    from_config = (
+    from_config: Tuple[Tuple[str, str], ...] = (
         ('serializer', 'task_serializer'),
         ('rate_limit', 'task_default_rate_limit'),
         ('track_started', 'task_track_started'),
@@ -292,13 +305,14 @@ class Task:
         ('store_errors_even_if_ignored', 'task_store_errors_even_if_ignored'),
     )

-    _backend = None  # set by backend property.
+    # set by backend property.
+    _backend: BackendT = None

     # - Tasks are lazily bound, so that configuration is not set
     # - until the task is actually used

     @classmethod
-    def bind(cls, app):
+    def bind(cls, app: AppT) -> AppT:
         was_bound, cls.__bound__ = cls.__bound__, True
         cls._app = app
         conf = app.conf
@@ -324,17 +338,17 @@ class Task:
         return app

     @classmethod
-    def on_bound(cls, app):
+    def on_bound(cls, app: AppT) -> None:
         """Called when the task is bound to an app.
         """Called when the task is bound to an app.
 
 
         Note:
         Note:
             This class method can be defined to do additional actions when
             This class method can be defined to do additional actions when
             the task class is bound to an app.
             the task class is bound to an app.
         """
         """
-        pass
+        ...

     @classmethod
-    def _get_app(cls):
+    def _get_app(cls) -> AppT:
         if cls._app is None:
             cls._app = current_app
         if not cls.__bound__:
@@ -345,7 +359,7 @@ class Task:
     app = class_property(_get_app, bind)

     @classmethod
-    def annotate(cls):
+    def annotate(cls) -> None:
         for d in resolve_all_annotations(cls.app.annotations, cls):
             for key, value in d.items():
                 if key.startswith('@'):
@@ -354,7 +368,7 @@ class Task:
                     setattr(cls, key, value)

     @classmethod
-    def add_around(cls, attr, around):
+    def add_around(cls, attr: str, around: Callable) -> None:
         orig = getattr(cls, attr)
         if getattr(orig, '__wrapped__', None):
             orig = orig.__wrapped__
@@ -362,7 +376,7 @@ class Task:
         meth.__wrapped__ = orig
         setattr(cls, attr, meth)

-    def __call__(self, *args, **kwargs):
+    def __call__(self, *args, **kwargs) -> Any:
         _task_stack.push(self)
         self.push_request(args=args, kwargs=kwargs)
         try:
@@ -374,7 +388,7 @@ class Task:
             self.pop_request()
             _task_stack.pop()

-    def __reduce__(self):
+    def __reduce__(self) -> Tuple:
         # - tasks are pickled into the name of the task only, and the receiver
         # - simply grabs it from the local registry.
         # - in later versions the module of the task is also included,
@@ -384,14 +398,15 @@ class Task:
         mod = mod if mod and mod in sys.modules else None
         return (_unpickle_task_v2, (self.name, mod), None)

-    def run(self, *args, **kwargs):
+    def run(self, *args, **kwargs) -> Any:
         """The body of the task executed by workers."""
         """The body of the task executed by workers."""
         raise NotImplementedError('Tasks must define the run method.')
         raise NotImplementedError('Tasks must define the run method.')
 
 
-    def start_strategy(self, app, consumer, **kwargs):
+    def start_strategy(self, app: AppT, consumer: WorkerConsumerT,
+                       **kwargs) -> Callable:
         return instantiate(self.Strategy, self, app, consumer, **kwargs)

-    def delay(self, *args, **kwargs):
+    def delay(self, *args, **kwargs) -> ResultT:
         """Star argument version of :meth:`apply_async`.
         """Star argument version of :meth:`apply_async`.
 
 
         Does not support the extra options enabled by :meth:`apply_async`.
         Does not support the extra options enabled by :meth:`apply_async`.
@@ -404,8 +419,15 @@ class Task:
         """
         """
         return self.apply_async(args, kwargs)
         return self.apply_async(args, kwargs)
 
 
-    def apply_async(self, args=None, kwargs=None, task_id=None, producer=None,
-                    link=None, link_error=None, shadow=None, **options):
+    def apply_async(self,
+                    args: Sequence = None,
+                    kwargs: Mapping = None,
+                    task_id: str = None,
+                    producer: ProducerT = None,
+                    link: Sequence[SignatureT] = None,
+                    link_error: Sequence[SignatureT] = None,
+                    shadow: str = None,
+                    **options) -> ResultT:
         """Apply tasks asynchronously by sending a message.
         """Apply tasks asynchronously by sending a message.
 
 
         Arguments:
         Arguments:
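For reference, a short sketch of the calling conventions these two methods support; `add` is a hypothetical task and a configured broker is assumed:

    @app.task
    def add(x, y):
        return x + y

    add.delay(2, 2)                          # star-args shortcut
    add.apply_async((2, 2), countdown=10)    # same call with extra options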
@@ -526,7 +548,10 @@ class Task:
             **options
         )

-    def shadow_name(self, args, kwargs, options):
+    def shadow_name(self,
+                    args: Sequence,
+                    kwargs: Mapping,
+                    options: Mapping) -> str:
         """Override for custom task name in worker logs/monitoring.
         """Override for custom task name in worker logs/monitoring.
 
 
         Example:
         Example:
@@ -548,8 +573,12 @@ class Task:
         """
         """
         pass
         pass
 
 
-    def signature_from_request(self, request=None, args=None, kwargs=None,
-                               queue=None, **extra_options):
+    def signature_from_request(self,
+                               request: Context = None,
+                               args: Sequence = None,
+                               kwargs: Mapping = None,
+                               queue: str = None,
+                               **extra_options) -> SignatureT:
         request = self.request if request is None else request
         args = request.args if args is None else args
         kwargs = request.kwargs if kwargs is None else kwargs
@@ -570,8 +599,15 @@ class Task:
         )
     subtask_from_request = signature_from_request  # XXX compat

-    def retry(self, args=None, kwargs=None, exc=None, throw=True,
-              eta=None, countdown=None, max_retries=None, **options):
+    def retry(self,
+              args: Sequence = None,
+              kwargs: Mapping = None,
+              exc: Exception = None,
+              throw: bool = True,
+              eta: datetime = None,
+              countdown: float = None,
+              max_retries: int = None,
+              **options) -> None:
         """Retry the task.
         """Retry the task.
 
 
         Example:
         Example:
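A typical retry pattern against this signature (a sketch; the `requests` usage and the task name are illustrative only):

    import requests

    @app.task(bind=True, max_retries=5, default_retry_delay=10.0)
    def fetch(self, url):
        try:
            return requests.get(url).status_code
        except requests.RequestException as exc:
            # retry() raises a Retry exception, so nothing below it runs
            raise self.retry(exc=exc)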
@@ -680,10 +716,18 @@ class Task:
             raise ret
         return ret

-    def apply(self, args=None, kwargs=None,
-              link=None, link_error=None,
-              task_id=None, retries=None, throw=None,
-              logfile=None, loglevel=None, headers=None, **options):
+    def apply(self,
+              args: Sequence = None,
+              kwargs: Mapping = None,
+              link: Sequence[SignatureT] = None,
+              link_error: Sequence[SignatureT] = None,
+              task_id: str = None,
+              retries: int = None,
+              throw: bool = None,
+              logfile: str = None,
+              loglevel: int = None,
+              headers: Mapping = None,
+              **options) -> ResultT:
         """Execute this task locally, by blocking until the task returns.
         """Execute this task locally, by blocking until the task returns.
 
 
         Arguments:
         Arguments:
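Illustrative only: `apply` runs the task inline and wraps the outcome in an `EagerResult` (reusing the hypothetical `add` task from above):

    result = add.apply((2, 2))    # blocks until the task body returns
    assert result.get() == 4
    assert result.successful()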
@@ -735,7 +779,7 @@ class Task:
         state = states.SUCCESS if ret.info is None else ret.info.state
         return EagerResult(task_id, retval, state, traceback=tb)

-    def AsyncResult(self, task_id, **kwargs):
+    def AsyncResult(self, task_id: str, **kwargs) -> ResultT:
         """Get AsyncResult instance for this kind of task.
         """Get AsyncResult instance for this kind of task.
 
 
         Arguments:
         Arguments:
@@ -744,7 +788,8 @@ class Task:
         return self._get_app().AsyncResult(
             task_id, backend=self.backend, **kwargs)

-    def signature(self, args=None, *starargs, **starkwargs):
+    def signature(self, args: Sequence = None,
+                  *starargs, **starkwargs) -> SignatureT:
         """Create signature.
         """Create signature.
 
 
         Returns:
         Returns:
@@ -756,36 +801,39 @@ class Task:
         return signature(self, args, *starargs, **starkwargs)
     subtask = signature

-    def s(self, *args, **kwargs):
+    def s(self, *args, **kwargs) -> SignatureT:
         """Create signature.
         """Create signature.
 
 
         Shortcut for ``.s(*a, **k) -> .signature(a, k)``.
         Shortcut for ``.s(*a, **k) -> .signature(a, k)``.
         """
         """
         return self.signature(args, kwargs)
         return self.signature(args, kwargs)
 
 
-    def si(self, *args, **kwargs):
+    def si(self, *args, **kwargs) -> SignatureT:
         """Create immutable signature.
         """Create immutable signature.
 
 
         Shortcut for ``.si(*a, **k) -> .signature(a, k, immutable=True)``.
         Shortcut for ``.si(*a, **k) -> .signature(a, k, immutable=True)``.
         """
         """
         return self.signature(args, kwargs, immutable=True)
         return self.signature(args, kwargs, immutable=True)
 
 
-    def chunks(self, it, n):
+    def chunks(self, it: Iterable, n: int) -> SignatureT:
         """Create a :class:`~celery.canvas.chunks` task for this task."""
         """Create a :class:`~celery.canvas.chunks` task for this task."""
         from celery import chunks
         from celery import chunks
         return chunks(self.s(), it, n, app=self.app)
         return chunks(self.s(), it, n, app=self.app)
 
 
-    def map(self, it):
+    def map(self, it: Iterable) -> SignatureT:
         """Create a :class:`~celery.canvas.xmap` task from ``it``."""
         """Create a :class:`~celery.canvas.xmap` task from ``it``."""
         from celery import xmap
         from celery import xmap
         return xmap(self.s(), it, app=self.app)
         return xmap(self.s(), it, app=self.app)
 
 
-    def starmap(self, it):
+    def starmap(self, it: Iterable) -> SignatureT:
         """Create a :class:`~celery.canvas.xstarmap` task from ``it``."""
         """Create a :class:`~celery.canvas.xstarmap` task from ``it``."""
         from celery import xstarmap
         from celery import xstarmap
         return xstarmap(self.s(), it, app=self.app)
         return xstarmap(self.s(), it, app=self.app)
 
 
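How the signature shortcuts compose (a sketch reusing the hypothetical `add` task; a broker is assumed for `delay`):

    sig = add.s(2, 2)                        # partial signature: add(2, 2)
    sig_immutable = add.si(2, 2)             # immutable: ignores parent results
    res = (add.s(2, 2) | add.s(4)).delay()   # chain: (2 + 2) + 4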
-    def send_event(self, type_, retry=True, retry_policy=None, **fields):
+    def send_event(self, type_: str,
+                   retry: bool = True,
+                   retry_policy: Mapping = None,
+                   **fields) -> Awaitable:
         """Send monitoring event message.
         """Send monitoring event message.
 
 
         This can be used to add custom event types in :pypi:`Flower`
         This can be used to add custom event types in :pypi:`Flower`
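Sketch of emitting a custom monitoring event from a bound task; `task-progress` is a made-up event type that a custom monitor would have to handle:

    @app.task(bind=True)
    def long_running(self, n):
        for i in range(n):
            self.send_event('task-progress', current=i, total=n)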
@@ -811,7 +859,7 @@ class Task:
                 type_,
                 uuid=req.id, retry=retry, retry_policy=retry_policy, **fields)

-    def replace(self, sig):
+    def replace(self, sig: SignatureT) -> None:
         """Replace this task, with a new task inheriting the task id.
         """Replace this task, with a new task inheriting the task id.
 
 
         .. versionadded:: 4.0
         .. versionadded:: 4.0
@@ -851,7 +899,8 @@ class Task:
         sig.delay()
         raise Ignore('Replaced by new task')

-    def add_to_chord(self, sig, lazy=False):
+    def add_to_chord(self, sig: SignatureT,
+                     lazy: bool = False) -> Union[ResultT, SignatureT]:
         """Add signature to the chord the current task is a member of.
         """Add signature to the chord the current task is a member of.
 
 
         .. versionadded:: 4.0
         .. versionadded:: 4.0
@@ -871,7 +920,10 @@ class Task:
         self.backend.add_to_chord(self.request.group, result)
         return sig.delay() if not lazy else sig

-    def update_state(self, task_id=None, state=None, meta=None):
+    def update_state(self,
+                     task_id: str = None,
+                     state: str = None,
+                     meta: Mapping = None) -> None:
         """Update task state.
         """Update task state.
 
 
         Arguments:
         Arguments:
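The usual progress-reporting pattern built on this method (a sketch; `PROGRESS` is a custom, non-builtin state):

    @app.task(bind=True)
    def process(self, items):
        for i, item in enumerate(items):
            self.update_state(state='PROGRESS',
                              meta={'current': i, 'total': len(items)})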
@@ -884,7 +936,11 @@ class Task:
             task_id = self.request.id
         self.backend.store_result(task_id, meta, state)

-    def on_success(self, retval, task_id, args, kwargs):
+    def on_success(self,
+                   retval: Any,
+                   task_id: str,
+                   args: Sequence,
+                   kwargs: Mapping) -> None:
         """Success handler.
         """Success handler.
 
 
         Run by the worker if the task executes successfully.
         Run by the worker if the task executes successfully.
@@ -898,9 +954,14 @@ class Task:
         Returns:
             None: The return value of this handler is ignored.
         """
-        pass
-
-    def on_retry(self, exc, task_id, args, kwargs, einfo):
+        ...
+
+    def on_retry(self,
+                 exc: Exception,
+                 task_id: str,
+                 args: Sequence,
+                 kwargs: Mapping,
+                 einfo: ExceptionInfo) -> None:
         """Retry handler.
         """Retry handler.
 
 
         This is run by the worker when the task is to be retried.
         This is run by the worker when the task is to be retried.
@@ -915,9 +976,14 @@ class Task:
         Returns:
             None: The return value of this handler is ignored.
         """
-        pass
-
-    def on_failure(self, exc, task_id, args, kwargs, einfo):
+        ...
+
+    def on_failure(self,
+                   exc: Exception,
+                   task_id: str,
+                   args: Sequence,
+                   kwargs: Mapping,
+                   einfo: ExceptionInfo) -> None:
         """Error handler.
         """Error handler.
 
 
         This is run by the worker when the task fails.
         This is run by the worker when the task fails.
@@ -932,9 +998,11 @@ class Task:
         Returns:
             None: The return value of this handler is ignored.
         """
-        pass
+        ...
 
-    def after_return(self, status, retval, task_id, args, kwargs, einfo):
+    def after_return(self, status: str, retval: Any, task_id: str,
+                     args: Sequence, kwargs: Mapping,
+                     einfo: ExceptionInfo) -> None:
         """Handler called after the task returns.
         """Handler called after the task returns.
 
 
         Arguments:
         Arguments:
@@ -948,24 +1016,24 @@ class Task:
         Returns:
             None: The return value of this handler is ignored.
         """
-        pass
+        ...
 
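A hedged sketch of hooking these handlers from a custom base class (`app` is an assumed Celery instance):

    import logging

    from celery import Task

    logger = logging.getLogger(__name__)

    class AuditedTask(Task):

        def on_failure(self, exc, task_id, args, kwargs, einfo):
            # einfo is a billiard.einfo.ExceptionInfo holding the traceback
            logger.error('Task %s failed: %r', task_id, exc)

    @app.task(base=AuditedTask)
    def risky():
        raise RuntimeError('boom')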
-    def add_trail(self, result):
+    def add_trail(self, result: ResultT) -> ResultT:
         if self.trail:
             self.request.children.append(result)
         return result

-    def push_request(self, *args, **kwargs):
+    def push_request(self, *args, **kwargs) -> None:
         self.request_stack.push(Context(*args, **kwargs))

-    def pop_request(self):
+    def pop_request(self) -> None:
         self.request_stack.pop()

-    def __repr__(self):
+    def __repr__(self) -> str:
         """``repr(task)``."""
         """``repr(task)``."""
         return _reprtask(self, R_SELF_TASK if self.__self__ else R_INSTANCE)
         return _reprtask(self, R_SELF_TASK if self.__self__ else R_INSTANCE)
 
 
-    def _get_request(self):
+    def _get_request(self) -> Context:
         """Get current request object."""
         """Get current request object."""
         req = self.request_stack.top
         req = self.request_stack.top
         if req is None:
         if req is None:
@@ -977,23 +1045,23 @@ class Task:
         return req
     request = property(_get_request)

-    def _get_exec_options(self):
+    def _get_exec_options(self) -> Mapping:
         if self._exec_options is None:
             self._exec_options = extract_exec_options(self)
         return self._exec_options

     @property
-    def backend(self):
+    def backend(self) -> BackendT:
         backend = self._backend
         if backend is None:
             return self.app.backend
         return backend

     @backend.setter
-    def backend(self, value):  # noqa
+    def backend(self, value: BackendT) -> None:  # noqa
         self._backend = value

     @property
-    def __name__(self):
+    def __name__(self) -> str:
         return self.__class__.__name__
 BaseTask = Task  # noqa: E305 XXX compat alias

+ 17 - 5
celery/app/trace.py

@@ -19,8 +19,8 @@ import logging
 import os
 import sys

-from collections import namedtuple
 from time import monotonic
+from typing import Any, NamedTuple
 from warnings import warn

 from billiard.einfo import ExceptionInfo
@@ -79,9 +79,16 @@ LOG_RETRY = """\
 Task %(name)s[%(id)s] retry: %(exc)s\
 """

-log_policy_t = namedtuple(
-    'log_policy_t', ('format', 'description', 'severity', 'traceback', 'mail'),
-)
+
+class log_policy_t(NamedTuple):
+    """Describes the logging policy for a specific state."""
+
+    format: str
+    description: str
+    severity: int
+    traceback: int
+    mail: int
+
 
 log_policy_reject = log_policy_t(LOG_REJECTED, 'rejected', logging.WARN, 1, 1)
 log_policy_ignore = log_policy_t(LOG_IGNORED, 'ignored', logging.INFO, 0, 0)
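Behavior is unchanged by the `NamedTuple` rewrite; instances still support both attribute and index access (illustrative):

    policy = log_policy_t(LOG_RETRY, 'retry', logging.INFO, 1, 0)
    assert policy.severity == policy[2] == logging.INFO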
@@ -111,7 +118,12 @@ IGNORE_STATES = frozenset({IGNORED, RETRY, REJECTED})
 _localized = []
 _patched = {}

-trace_ok_t = namedtuple('trace_ok_t', ('retval', 'info', 'runtime', 'retstr'))
+trace_ok_t = NamedTuple('trace_ok_t', [
+    ('retval', Any),
+    ('info', 'TraceInfo'),
+    ('runtime', float),
+    ('retstr', str),
+])
 
 
 def task_has_custom(task, attr):

+ 61 - 36
celery/app/utils.py

@@ -4,21 +4,23 @@ import os
 import platform as _platform
 import re

-from collections import Mapping, namedtuple
+from collections import Mapping
 from copy import deepcopy
 from types import ModuleType
+from typing import Any, Callable, MutableMapping, NamedTuple, Set, Union
 
 from kombu.utils.url import maybe_sanitize_url

 from celery.exceptions import ImproperlyConfigured
 from celery.platforms import pyimplementation
+from celery.types import AppT
 from celery.utils.collections import ConfigurationView
 from celery.utils.text import pretty
 from celery.utils.imports import import_from_cwd, symbol_by_name, qualname

 from .defaults import (
     _TO_NEW_KEY, _TO_OLD_KEY, _OLD_DEFAULTS, _OLD_SETTING_KEYS,
-    DEFAULTS, SETTING_KEYS, find,
+    DEFAULTS, SETTING_KEYS, find, find_result_t,
 )

 __all__ = [
@@ -65,7 +67,7 @@ Or change all of the settings to use the new format :)
 FMT_REPLACE_SETTING = '{replace:<36} -> {with_}'


-def appstr(app):
+def appstr(app: AppT) -> str:
     """String used in __repr__ etc, to id app instances."""
     """String used in __repr__ etc, to id app instances."""
     return '{0}:{1:#x}'.format(app.main or '__main__', id(app))
     return '{0}:{1:#x}'.format(app.main or '__main__', id(app))
 
 
@@ -80,7 +82,7 @@ class Settings(ConfigurationView):
     """
     """
 
 
     @property
     @property
-    def broker_read_url(self):
+    def broker_read_url(self) -> str:
         return (
             os.environ.get('CELERY_BROKER_READ_URL') or
             self.get('broker_read_url') or
@@ -88,7 +90,7 @@ class Settings(ConfigurationView):
         )

     @property
-    def broker_write_url(self):
+    def broker_write_url(self) -> str:
         return (
             os.environ.get('CELERY_BROKER_WRITE_URL') or
             self.get('broker_write_url') or
@@ -96,40 +98,40 @@ class Settings(ConfigurationView):
         )

     @property
-    def broker_url(self):
+    def broker_url(self) -> str:
         return (
             os.environ.get('CELERY_BROKER_URL') or
             self.first('broker_url', 'broker_host')
         )

     @property
-    def task_default_exchange(self):
+    def task_default_exchange(self) -> str:
         return self.first(
             'task_default_exchange',
             'task_default_queue',
         )

     @property
-    def task_default_routing_key(self):
+    def task_default_routing_key(self) -> str:
         return self.first(
             'task_default_routing_key',
             'task_default_queue',
         )

     @property
-    def timezone(self):
+    def timezone(self) -> str:
         # this way we also support django's time zone.
         return self.first('timezone', 'time_zone')

-    def without_defaults(self):
+    def without_defaults(self) -> 'Settings':
         """Return the current configuration, but without defaults."""
         """Return the current configuration, but without defaults."""
         # the last stash is the default settings, so just skip that
         # the last stash is the default settings, so just skip that
         return Settings({}, self.maps[:-1])
         return Settings({}, self.maps[:-1])
 
 
-    def value_set_for(self, key):
+    def value_set_for(self, key: str) -> bool:
         return key in self.without_defaults()

-    def find_option(self, name, namespace=''):
+    def find_option(self, name: str, namespace: str = '') -> find_result_t:
         """Search for option by name.
         """Search for option by name.
 
 
         Example:
         Example:
@@ -146,11 +148,11 @@ class Settings(ConfigurationView):
         """
         """
         return find(name, namespace)
         return find(name, namespace)
 
 
-    def find_value_for_key(self, name, namespace='celery'):
+    def find_value_for_key(self, name: str, namespace: str = 'celery') -> Any:
         """Shortcut to ``get_by_parts(*find_option(name)[:-1])``."""
         """Shortcut to ``get_by_parts(*find_option(name)[:-1])``."""
         return self.get_by_parts(*self.find_option(name, namespace)[:-1])
         return self.get_by_parts(*self.find_option(name, namespace)[:-1])
 
 
-    def get_by_parts(self, *parts):
+    def get_by_parts(self, *parts) -> Any:
         """Return the current value for setting specified as a path.
         """Return the current value for setting specified as a path.
 
 
         Example:
         Example:
@@ -160,7 +162,7 @@ class Settings(ConfigurationView):
         """
         """
         return self['_'.join(part for part in parts if part)]
         return self['_'.join(part for part in parts if part)]
 
 
-    def finalize(self):
+    def finalize(self) -> None:
         # See PendingConfiguration in celery/app/base.py
         # first access will read actual configuration.
         try:
@@ -169,7 +171,9 @@ class Settings(ConfigurationView):
             pass
         return self

-    def table(self, with_defaults=False, censored=True):
+    def table(self, *,
+              with_defaults: bool = False,
+              censored: bool = True) -> Mapping:
         filt = filter_hidden_settings if censored else lambda v: v
         dict_members = dir(dict)
         self.finalize()
@@ -179,24 +183,29 @@ class Settings(ConfigurationView):
             if not k.startswith('_') and k not in dict_members
         })

-    def humanize(self, with_defaults=False, censored=True):
+    def humanize(self, *,
+                 with_defaults: bool = False,
+                 censored: bool = True) -> str:
         """Return a human readable text showing configuration changes."""
         """Return a human readable text showing configuration changes."""
         return '\n'.join(
         return '\n'.join(
             '{0}: {1}'.format(key, pretty(value, width=50))
             '{0}: {1}'.format(key, pretty(value, width=50))
             for key, value in self.table(with_defaults, censored).items())
             for key, value in self.table(with_defaults, censored).items())
 
 
 
 
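Note that `table()` and `humanize()` now take keyword-only flags, so positional calls such as `conf.humanize(False, True)` raise `TypeError` after this change:

    app.conf.humanize(with_defaults=False, censored=True)
    app.conf.table(censored=True)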
-def _new_key_to_old(key, convert=_TO_OLD_KEY.get):
+def _new_key_to_old(key: str, *, convert: Callable = _TO_OLD_KEY.get) -> str:
     return convert(key, key)


-def _old_key_to_new(key, convert=_TO_NEW_KEY.get):
+def _old_key_to_new(key: str, *, convert: Callable = _TO_NEW_KEY.get) -> str:
     return convert(key, key)


-_settings_info_t = namedtuple('settings_info_t', (
-    'defaults', 'convert', 'key_t', 'mix_error',
-))
+class _settings_info_t(NamedTuple):
+    defaults: Mapping
+    convert: Mapping
+    key_t: Callable
+    mix_error: str
+
 
 _settings_info = _settings_info_t(
     DEFAULTS, _TO_NEW_KEY, _old_key_to_new, E_MIX_OLD_INTO_NEW,
@@ -206,8 +215,12 @@ _old_settings_info = _settings_info_t(
 )


-def detect_settings(conf, preconf={}, ignore_keys=set(), prefix=None,
-                    all_keys=SETTING_KEYS, old_keys=_OLD_SETTING_KEYS):
+def detect_settings(conf: MutableMapping,
+                    preconf: Mapping = {},
+                    ignore_keys: Set = set(),
+                    prefix: str = None,
+                    all_keys: Set = SETTING_KEYS,
+                    old_keys: Set = _OLD_SETTING_KEYS) -> Settings:
     source = conf
     if conf is None:
         source, conf = preconf, {}
@@ -260,42 +273,51 @@ def detect_settings(conf, preconf={}, ignore_keys=set(), prefix=None,
 class AppPickler:
     """Old application pickler/unpickler (< 3.1)."""

-    def __call__(self, cls, *args):
+    def __call__(self, cls: type, *args) -> AppT:
         kwargs = self.build_kwargs(*args)
         app = self.construct(cls, **kwargs)
         self.prepare(app, **kwargs)
         return app

-    def prepare(self, app, **kwargs):
+    def prepare(self, app: AppT, **kwargs) -> None:
         app.conf.update(kwargs['changes'])

-    def build_kwargs(self, *args):
+    def build_kwargs(self, *args) -> Mapping:
         return self.build_standard_kwargs(*args)

-    def build_standard_kwargs(self, main, changes, loader, backend, amqp,
-                              events, log, control, accept_magic_kwargs,
-                              config_source=None):
+    def build_standard_kwargs(
+            self,
+            main: str,
+            changes: Mapping,
+            loader: Union[str, type],
+            backend: Union[str, type],
+            amqp: Union[str, type],
+            events: Union[str, type],
+            log: Union[str, type],
+            control: Union[str, type],
+            accept_magic_kwargs: bool,
+            config_source: str = None) -> Mapping:
         return dict(main=main, loader=loader, backend=backend, amqp=amqp,
                     changes=changes, events=events, log=log, control=control,
                     set_as_current=False,
                     config_source=config_source)

-    def construct(self, cls, **kwargs):
+    def construct(self, cls: Callable, **kwargs) -> Any:
         return cls(**kwargs)


-def _unpickle_app(cls, pickler, *args):
+def _unpickle_app(cls, pickler: type, *args) -> AppT:
     """Rebuild app for versions 2.5+."""
     """Rebuild app for versions 2.5+."""
     return pickler()(cls, *args)
     return pickler()(cls, *args)
 
 
 
 
-def _unpickle_app_v2(cls, kwargs):
+def _unpickle_app_v2(cls: type, kwargs: Mapping) -> AppT:
     """Rebuild app for versions 3.1+."""
     """Rebuild app for versions 3.1+."""
     kwargs['set_as_current'] = False
     kwargs['set_as_current'] = False
     return cls(**kwargs)
     return cls(**kwargs)
 
 
 
 
-def filter_hidden_settings(conf):
+def filter_hidden_settings(conf: Mapping) -> Mapping:
     """Filter sensitive settings."""
     """Filter sensitive settings."""
     def maybe_censor(key, value, mask='*' * 8):
     def maybe_censor(key, value, mask='*' * 8):
         if isinstance(value, Mapping):
         if isinstance(value, Mapping):
@@ -314,7 +336,7 @@ def filter_hidden_settings(conf):
     return {k: maybe_censor(k, v) for k, v in conf.items()}


-def bugreport(app):
+def bugreport(app: AppT) -> str:
     """Return a string containing information useful in bug-reports."""
     """Return a string containing information useful in bug-reports."""
     import billiard
     import billiard
     import celery
     import celery
@@ -344,7 +366,10 @@ def bugreport(app):
     )


-def find_app(app, symbol_by_name=symbol_by_name, imp=import_from_cwd):
+def find_app(app: AppT,
+             *,
+             symbol_by_name: Callable = symbol_by_name,
+             imp: Callable = import_from_cwd) -> AppT:
     """Find app by name."""
     """Find app by name."""
     from .base import Celery
     from .base import Celery
 
 

+ 80 - 52
celery/backends/async.py

@@ -4,14 +4,22 @@ import threading
 
 from collections import deque
 from time import monotonic, sleep
-from weakref import WeakKeyDictionary
+from typing import (
+    Any, AsyncIterator, Awaitable, Callable, Iterable, Iterator, Mapping,
+    Sequence, Set, Tuple,
+)
 from queue import Empty
+from weakref import WeakKeyDictionary
 
+from kombu.types import MessageT
 from kombu.utils.compat import detect_environment
 from kombu.utils.objects import cached_property

 from celery import states
 from celery.exceptions import TimeoutError
+from celery.types import AppT, BackendT, ResultT, ResultConsumerT
+from celery.utils.collections import BufferMap
+
+from .base import pending_results_t
 
 
 __all__ = [
     'AsyncBackendMixin', 'BaseResultConsumer', 'Drainer',
@@ -21,9 +29,9 @@ __all__ = [
 drainers = {}


-def register_drainer(name):
+def register_drainer(name: str) -> Callable:
     """Decorator used to register a new result drainer type."""
     """Decorator used to register a new result drainer type."""
-    def _inner(cls):
+    def _inner(cls: type) -> type:
         drainers[name] = cls
         return cls
     return _inner
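Sketch of how the drainer registry is meant to be used; `'myenv'` is a made-up key matching what `kombu.utils.compat.detect_environment()` would return:

    @register_drainer('myenv')
    class MyDrainer(Drainer):
        """Selected when detect_environment() returns 'myenv'."""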
@@ -33,16 +41,19 @@ def register_drainer(name):
 class Drainer:
     """Result draining service."""

-    def __init__(self, result_consumer):
+    def __init__(self, result_consumer: ResultConsumerT) -> None:
         self.result_consumer = result_consumer

-    def start(self):
-        pass
+    def start(self) -> None:
+        ...
 
-    def stop(self):
-        pass
+    def stop(self) -> None:
+        ...
 
-    def drain_events_until(self, p, timeout=None, on_interval=None, wait=None):
+    def drain_events_until(self, p: Awaitable,
+                           timeout: float = None,
+                           on_interval: Callable = None,
+                           wait: Callable = None) -> Iterator:
         wait = wait or self.result_consumer.drain_events
         time_start = monotonic()

@@ -59,7 +70,8 @@ class Drainer:
             if p.ready:  # got event on the wanted channel.
                 break

-    def wait_for(self, p, wait, timeout=None):
+    def wait_for(self, p: Awaitable, wait: Callable,
+                 timeout: float = None) -> None:
         wait(timeout=timeout)


@@ -67,13 +79,13 @@ class greenletDrainer(Drainer):
     spawn = None
     _g = None

-    def __init__(self, *args, **kwargs):
+    def __init__(self, *args, **kwargs) -> None:
         super(greenletDrainer, self).__init__(*args, **kwargs)
         self._started = threading.Event()
         self._stopped = threading.Event()
         self._shutdown = threading.Event()

-    def run(self):
+    def run(self) -> None:
         self._started.set()
         while not self._stopped.is_set():
             try:
@@ -82,16 +94,17 @@ class greenletDrainer(Drainer):
                 pass
         self._shutdown.set()

-    def start(self):
+    def start(self) -> None:
         if not self._started.is_set():
             self._g = self.spawn(self.run)
             self._started.wait()

-    def stop(self):
+    def stop(self) -> None:
         self._stopped.set()
         self._shutdown.wait(threading.TIMEOUT_MAX)

-    def wait_for(self, p, wait, timeout=None):
+    def wait_for(self, p: Awaitable, wait: Callable,
+                 timeout: float = None) -> None:
         self.start()
         if not p.ready:
             sleep(0)
@@ -101,7 +114,7 @@ class greenletDrainer(Drainer):
 class eventletDrainer(greenletDrainer):

     @cached_property
-    def spawn(self):
+    def spawn(self) -> Callable:
         from eventlet import spawn
         return spawn

@@ -110,7 +123,7 @@ class eventletDrainer(greenletDrainer):
 class geventDrainer(greenletDrainer):

     @cached_property
-    def spawn(self):
+    def spawn(self) -> Callable:
         from gevent import spawn
         return spawn

@@ -118,10 +131,13 @@ class geventDrainer(greenletDrainer):
 class AsyncBackendMixin:
     """Mixin for backends that enables the async API."""

-    def _collect_into(self, result, bucket):
+    def _collect_into(self, result: ResultT, bucket: deque) -> None:
         self.result_consumer.buckets[result] = bucket

-    def iter_native(self, result, no_ack=True, **kwargs):
+    async def iter_native(
+            self, result: ResultT,
+            *,
+            no_ack: bool = True,
+            **kwargs) -> AsyncIterator[Tuple[str, Mapping]]:
         self._ensure_not_eager()

         results = result.results
@@ -145,7 +161,9 @@ class AsyncBackendMixin:
             node = bucket.popleft()
             yield node.id, node._cache

-    def add_pending_result(self, result, weak=False, start_drainer=True):
+    def add_pending_result(self, result: ResultT,
+                           weak: bool = False,
+                           start_drainer: bool = True) -> ResultT:
         if start_drainer:
             self.result_consumer.drainer.start()
         try:
@@ -154,57 +172,64 @@ class AsyncBackendMixin:
             self._add_pending_result(result.id, result, weak=weak)
         return result

-    def _maybe_resolve_from_buffer(self, result):
+    def _maybe_resolve_from_buffer(self, result: ResultT) -> None:
         result._maybe_set_cache(self._pending_messages.take(result.id))

-    def _add_pending_result(self, task_id, result, weak=False):
+    def _add_pending_result(self, task_id: str, result: ResultT,
+                            weak: bool = False) -> None:
         concrete, weak_ = self._pending_results
         if task_id not in weak_ and result.id not in concrete:
             (weak_ if weak else concrete)[task_id] = result
             self.result_consumer.consume_from(task_id)

-    def add_pending_results(self, results, weak=False):
+    def add_pending_results(self, results: Sequence[ResultT],
+                            weak: bool = False) -> None:
         self.result_consumer.drainer.start()
-        return [self.add_pending_result(result, weak=weak, start_drainer=False)
-                for result in results]
+        [self.add_pending_result(result, weak=weak, start_drainer=False)
+         for result in results]
 
-    def remove_pending_result(self, result):
+    def remove_pending_result(self, result: ResultT) -> ResultT:
         self._remove_pending_result(result.id)
         self.on_result_fulfilled(result)
         return result

-    def _remove_pending_result(self, task_id):
+    def _remove_pending_result(self, task_id: str) -> None:
         for map in self._pending_results:
             map.pop(task_id, None)

-    def on_result_fulfilled(self, result):
+    def on_result_fulfilled(self, result: ResultT) -> None:
         self.result_consumer.cancel_for(result.id)

-    def wait_for_pending(self, result,
-                         callback=None, propagate=True, **kwargs):
+    def wait_for_pending(self, result: ResultT,
+                         callback: Callable = None,
+                         propagate: bool = True,
+                         **kwargs) -> Any:
         self._ensure_not_eager()
         for _ in self._wait_for_pending(result, **kwargs):
             pass
         return result.maybe_throw(callback=callback, propagate=propagate)

-    def _wait_for_pending(self, result,
-                          timeout=None, on_interval=None, on_message=None,
-                          **kwargs):
+    def _wait_for_pending(self, result: ResultT,
+                          timeout: float = None,
+                          on_interval: Callable = None,
+                          on_message: Callable = None,
+                          **kwargs) -> Iterable:
         return self.result_consumer._wait_for_pending(
         return self.result_consumer._wait_for_pending(
             result, timeout=timeout,
             result, timeout=timeout,
             on_interval=on_interval, on_message=on_message,
             on_interval=on_interval, on_message=on_message,
         )
         )
 
 
     @property
     @property
-    def is_async(self):
+    def is_async(self) -> bool:
         return True
         return True
 
 
 
 
 class BaseResultConsumer:
 class BaseResultConsumer:
     """Manager responsible for consuming result messages."""
     """Manager responsible for consuming result messages."""
 
 
-    def __init__(self, backend, app, accept,
-                 pending_results, pending_messages):
+    def __init__(self, backend: BackendT, app: AppT, accept: Set[str],
+                 pending_results: pending_results_t,
+                 pending_messages: BufferMap) -> None:
         self.backend = backend
         self.backend = backend
         self.app = app
         self.app = app
         self.accept = accept
         self.accept = accept
@@ -214,37 +239,39 @@ class BaseResultConsumer:
        self.buckets = WeakKeyDictionary()
        self.drainer = drainers[detect_environment()](self)

-    def start(self, initial_task_id, **kwargs):
+    def start(self, initial_task_id: str, **kwargs) -> None:
        raise NotImplementedError()

-    def stop(self):
-        pass
+    def stop(self) -> None:
+        ...

-    def drain_events(self, timeout=None):
+    def drain_events(self, timeout: float = None) -> None:
        raise NotImplementedError()

-    def consume_from(self, task_id):
+    def consume_from(self, task_id: str) -> None:
        raise NotImplementedError()

-    def cancel_for(self, task_id):
+    def cancel_for(self, task_id: str) -> None:
        raise NotImplementedError()

-    def _after_fork(self):
+    def _after_fork(self) -> None:
        self.buckets.clear()
        self.buckets = WeakKeyDictionary()
        self.on_message = None
        self.on_after_fork()

-    def on_after_fork(self):
-        pass
+    def on_after_fork(self) -> None:
+        ...

-    def drain_events_until(self, p, timeout=None, on_interval=None):
+    def drain_events_until(self, p: Awaitable,
+                           timeout: float = None,
+                           on_interval: Callable = None) -> Iterable:
        return self.drainer.drain_events_until(
            p, timeout=timeout, on_interval=on_interval)

    def _wait_for_pending(self, result,
                          timeout=None, on_interval=None, on_message=None,
-                          **kwargs):
+                          **kwargs) -> Iterable:
        self.on_wait_for_pending(result, timeout=timeout, **kwargs)
        prev_on_m, self.on_message = self.on_message, on_message
        try:
@@ -258,13 +285,14 @@ class BaseResultConsumer:
        finally:
            self.on_message = prev_on_m

-    def on_wait_for_pending(self, result, timeout=None, **kwargs):
-        pass
+    def on_wait_for_pending(self, result: ResultT,
+                            timeout: float = None, **kwargs) -> None:
+        ...

-    def on_out_of_band_result(self, message):
+    def on_out_of_band_result(self, message: MessageT) -> None:
        self.on_state_change(message.payload, message)

-    def _get_pending_result(self, task_id):
+    def _get_pending_result(self, task_id: str) -> ResultT:
        for mapping in self._pending_results:
            try:
                return mapping[task_id]
@@ -272,7 +300,7 @@ class BaseResultConsumer:
                pass
        raise KeyError(task_id)

-    def on_state_change(self, meta, message):
+    def on_state_change(self, meta: Mapping, message: MessageT) -> None:
        if self.on_message:
            self.on_message(meta)
        if meta['status'] in states.READY_STATES:

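A minimal, standalone sketch of the concrete/weak pending-result bookkeeping used by AsyncBackendMixin above (names here are illustrative, not Celery's public API). The point of the weak map is that a result registered with weak=True drops out of the registry as soon as the caller releases its last reference, so abandoned results stop being tracked:

from weakref import WeakValueDictionary

class Result:
    """Stand-in for an AsyncResult; it only needs an id and weakref support."""

    def __init__(self, id):
        self.id = id

concrete, weak = {}, WeakValueDictionary()

def add_pending_result(result, weak_ref=False):
    # Register each result once, in either the strong or the weak map,
    # mirroring _add_pending_result() above.
    if result.id not in weak and result.id not in concrete:
        (weak if weak_ref else concrete)[result.id] = result
    return result

r = add_pending_result(Result('id1'), weak_ref=True)
assert 'id1' in weak
del r  # weakly-held entries vanish once the caller drops them (CPython)
assert 'id1' not in weak
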
+ 264 - 166
celery/backends/base.py

@@ -9,8 +9,12 @@
 import sys
 import time
 
-from collections import namedtuple
 from datetime import timedelta
+from numbers import Number
+from typing import (
+    Any, AnyStr, AsyncIterator, Callable, Dict, Iterable, Iterator,
+    Mapping, MutableMapping, NamedTuple, Optional, Set, Sequence, Tuple,
+)
 from weakref import WeakValueDictionary
 
 from billiard.einfo import ExceptionInfo
@@ -18,6 +22,7 @@ from kombu.serialization import (
     dumps, loads, prepare_accept_content,
     registry as serializer_registry,
 )
+from kombu.types import ProducerT
 from kombu.utils.encoding import bytes_to_str, ensure_bytes, from_utf8
 from kombu.utils.url import maybe_sanitize_url
 
@@ -30,6 +35,7 @@ from celery.exceptions import (
 from celery.result import (
     GroupResult, ResultBase, allow_join_result, result_from_tuple,
 )
+from celery.types import AppT, BackendT, ResultT, RequestT, SignatureT
 from celery.utils.collections import BufferMap
 from celery.utils.functional import LRUCache, arity_greater
 from celery.utils.log import get_logger
@@ -47,10 +53,6 @@ logger = get_logger(__name__)
 
 
 MESSAGE_BUFFER_MAX = 8192
 
-pending_results_t = namedtuple('pending_results_t', (
-    'concrete', 'weak',
-))
-
 E_NO_BACKEND = """
 No result backend is configured.
 Please see the documentation for more information.

@@ -66,15 +68,22 @@ Result backends that supports chords: Redis, Database, Memcached, and more.
 """
 """
 
 
 
 
-def unpickle_backend(cls, args, kwargs):
+class pending_results_t(NamedTuple):
+    """Tuple of concrete and weak references to pending results."""
+
+    concrete: Mapping
+    weak: WeakValueDictionary
+
+
+def unpickle_backend(cls: type, args: Tuple, kwargs: Dict) -> BackendT:
     """Return an unpickled backend."""
     """Return an unpickled backend."""
     return cls(*args, app=current_app._get_current_object(), **kwargs)
     return cls(*args, app=current_app._get_current_object(), **kwargs)
 
 
 
 
 class _nulldict(dict):
 class _nulldict(dict):
 
 
-    def ignore(self, *a, **kw):
-        pass
+    def ignore(self, *a, **kw) -> None:
+        ...
     __setitem__ = update = setdefault = ignore
     __setitem__ = update = setdefault = ignore
 
 
 
 
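The class-based NamedTuple above is behaviorally equivalent to the collections.namedtuple it replaces; it only adds field annotations and a docstring. A quick sketch of the equivalence:

from typing import Mapping, NamedTuple
from weakref import WeakValueDictionary

class pending_results_t(NamedTuple):
    concrete: Mapping
    weak: WeakValueDictionary

# Instances still behave like plain tuples: indexing, unpacking and
# attribute access all work exactly as with collections.namedtuple.
p = pending_results_t({}, WeakValueDictionary())
assert p.concrete is p[0] and p.weak is p[1]
concrete, weak = p
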
@@ -89,7 +98,7 @@ class Backend:
     #: Time to sleep between polling each individual item
     #: in `ResultSet.iterate`. as opposed to the `interval`
     #: argument which is for each pass.
-    subpolling_interval = None
+    subpolling_interval: float = None
 
     #: If true the backend must implement :meth:`get_many`.
     supports_native_join = False
@@ -109,9 +118,15 @@ class Backend:
         'interval_max': 1,
     }
 
-    def __init__(self, app,
-                 serializer=None, max_cached_results=None, accept=None,
-                 expires=None, expires_type=None, url=None, **kwargs):
+    def __init__(self, app: AppT,
+                 *,
+                 serializer: str = None,
+                 max_cached_results: int = None,
+                 accept: Set[str] = None,
+                 expires: float = None,
+                 expires_type: Callable = None,
+                 url: str = None,
+                 **kwargs) -> None:
         self.app = app
         conf = self.app.conf
         self.serializer = serializer or conf.result_serializer
@@ -128,7 +143,7 @@ class Backend:
         self._pending_messages = BufferMap(MESSAGE_BUFFER_MAX)
         self.url = url
 
-    def as_uri(self, include_password=False):
+    def as_uri(self, include_password: bool = False) -> str:
         """Return the backend as an URI, sanitizing the password or not."""
         # when using maybe_sanitize_url(), "/" is added
         # we're stripping it for consistency
@@ -137,33 +152,42 @@ class Backend:
         url = maybe_sanitize_url(self.url or '')
         return url[:-1] if url.endswith(':///') else url
 
-    def mark_as_started(self, task_id, **meta):
+    async def mark_as_started(self, task_id: str, **meta) -> Any:
         """Mark a task as started."""
-        return self.store_result(task_id, meta, states.STARTED)
+        return await self.store_result(task_id, meta, states.STARTED)
 
-    def mark_as_done(self, task_id, result,
-                     request=None, store_result=True, state=states.SUCCESS):
+    async def mark_as_done(self, task_id: str, result: Any,
+                           *,
+                           request: RequestT = None,
+                           store_result: bool = True,
+                           state: str = states.SUCCESS) -> None:
         """Mark task as successfully executed."""
         if store_result:
-            self.store_result(task_id, result, state, request=request)
+            await self.store_result(task_id, result, state, request=request)
         if request and request.chord:
-            self.on_chord_part_return(request, state, result)
-
-    def mark_as_failure(self, task_id, exc,
-                        traceback=None, request=None,
-                        store_result=True, call_errbacks=True,
-                        state=states.FAILURE):
+            await self.on_chord_part_return(request, state, result)
+
+    async def mark_as_failure(self, task_id: str, exc: Exception,
+                              *,
+                              traceback: str = None,
+                              request: RequestT = None,
+                              store_result: bool = True,
+                              call_errbacks: bool = True,
+                              state: str = states.FAILURE) -> None:
         """Mark task as executed with failure."""
         if store_result:
-            self.store_result(task_id, exc, state,
-                              traceback=traceback, request=request)
+            await self.store_result(task_id, exc, state,
+                                    traceback=traceback, request=request)
         if request:
             if request.chord:
-                self.on_chord_part_return(request, state, exc)
+                await self.on_chord_part_return(request, state, exc)
             if call_errbacks and request.errbacks:
-                self._call_task_errbacks(request, exc, traceback)
+                await self._call_task_errbacks(request, exc, traceback)
 
-    def _call_task_errbacks(self, request, exc, traceback):
+    async def _call_task_errbacks(self,
+                                  request: RequestT,
+                                  exc: Exception,
+                                  traceback: str) -> None:
         old_signature = []
         for errback in request.errbacks:
             errback = self.app.signature(errback)
@@ -176,30 +200,38 @@ class Backend:
             # need to do so if the errback only takes a single task_id arg.
             task_id = request.id
             root_id = request.root_id or task_id
-            group(old_signature, app=self.app).apply_async(
+            await group(old_signature, app=self.app).apply_async(
                 (task_id,), parent_id=task_id, root_id=root_id
             )
 
-    def mark_as_revoked(self, task_id, reason='',
-                        request=None, store_result=True, state=states.REVOKED):
+    async def mark_as_revoked(self, task_id: str, reason: str = '',
+                              *,
+                              request: RequestT = None,
+                              store_result: bool = True,
+                              state: str = states.REVOKED) -> None:
         exc = TaskRevokedError(reason)
         if store_result:
-            self.store_result(task_id, exc, state,
-                              traceback=None, request=request)
+            await self.store_result(task_id, exc, state,
+                                    traceback=None, request=request)
         if request and request.chord:
-            self.on_chord_part_return(request, state, exc)
-
-    def mark_as_retry(self, task_id, exc, traceback=None,
-                      request=None, store_result=True, state=states.RETRY):
+            await self.on_chord_part_return(request, state, exc)
+
+    async def mark_as_retry(self, task_id: str, exc: Exception,
+                            *,
+                            traceback: str = None,
+                            request: RequestT = None,
+                            store_result: bool = True,
+                            state: str = states.RETRY) -> None:
         """Mark task as being retried.
 
         Note:
             Stores the current exception (if any).
         """
-        return self.store_result(task_id, exc, state,
-                                 traceback=traceback, request=request)
+        return await self.store_result(task_id, exc, state,
+                                       traceback=traceback, request=request)
 
-    def chord_error_from_stack(self, callback, exc=None):
+    async def chord_error_from_stack(self, callback: Callable,
+                                     exc: Exception = None) -> ExceptionInfo:
         # need below import for test for some crazy reason
         from celery import group  # pylint: disable
         app = self.app
@@ -208,34 +240,37 @@ class Backend:
         except KeyError:
             backend = self
         try:
-            group(
+            await group(
                 [app.signature(errback)
                  for errback in callback.options.get('link_error') or []],
                 app=app,
             ).apply_async((callback.id,))
         except Exception as eb_exc:  # pylint: disable=broad-except
-            return backend.fail_from_current_stack(callback.id, exc=eb_exc)
+            return await backend.fail_from_current_stack(
+                callback.id, exc=eb_exc)
         else:
-            return backend.fail_from_current_stack(callback.id, exc=exc)
+            return await backend.fail_from_current_stack(
+                callback.id, exc=exc)
 
-    def fail_from_current_stack(self, task_id, exc=None):
+    async def fail_from_current_stack(self, task_id: str,
+                                      exc: Exception = None) -> ExceptionInfo:
         type_, real_exc, tb = sys.exc_info()
         try:
             exc = real_exc if exc is None else exc
             ei = ExceptionInfo((type_, exc, tb))
-            self.mark_as_failure(task_id, exc, ei.traceback)
+            await self.mark_as_failure(task_id, exc, traceback=ei.traceback)
             return ei
         finally:
             del tb
 
-    def prepare_exception(self, exc, serializer=None):
+    def prepare_exception(self, exc: Exception, serializer: str = None) -> Any:
         """Prepare exception for serialization."""
         serializer = self.serializer if serializer is None else serializer
         if serializer in EXCEPTION_ABLE_CODECS:
             return get_pickleable_exception(exc)
         return {'exc_type': type(exc).__name__, 'exc_message': str(exc)}
 
-    def exception_to_python(self, exc):
+    def exception_to_python(self, exc: Any) -> Exception:
         """Convert serialized exception to Python exception."""
         if exc:
             if not isinstance(exc, BaseException):
@@ -245,34 +280,35 @@ class Backend:
                 exc = get_pickled_exception(exc)
         return exc
 
-    def prepare_value(self, result):
+    def prepare_value(self, result: Any) -> Any:
         """Prepare value for storage."""
         if self.serializer != 'pickle' and isinstance(result, ResultBase):
             return result.as_tuple()
         return result
 
-    def encode(self, data):
+    def encode(self, data: Any) -> AnyStr:
         _, _, payload = self._encode(data)
         return payload
 
-    def _encode(self, data):
+    def _encode(self, data: Any) -> Tuple[str, str, AnyStr]:
         return dumps(data, serializer=self.serializer)
 
-    def meta_from_decoded(self, meta):
+    def meta_from_decoded(self, meta: MutableMapping) -> MutableMapping:
         if meta['status'] in self.EXCEPTION_STATES:
             meta['result'] = self.exception_to_python(meta['result'])
         return meta
 
-    def decode_result(self, payload):
+    def decode_result(self, payload: AnyStr) -> Mapping:
         return self.meta_from_decoded(self.decode(payload))
 
-    def decode(self, payload):
+    def decode(self, payload: AnyStr) -> Mapping:
         return loads(payload,
                      content_type=self.content_type,
                      content_encoding=self.content_encoding,
                      accept=self.accept)
 
-    def prepare_expires(self, value, type=None):
+    def prepare_expires(self, value: Optional[Number],
+                        type: Callable = None) -> Optional[Number]:
         if value is None:
             value = self.app.conf.result_expires
         if isinstance(value, timedelta):
@@ -281,61 +317,63 @@ class Backend:
             return type(value)
         return value
 
-    def prepare_persistent(self, enabled=None):
+    def prepare_persistent(self, enabled: bool = None) -> bool:
         if enabled is not None:
             return enabled
         p = self.app.conf.result_persistent
         return self.persistent if p is None else p
 
-    def encode_result(self, result, state):
+    def encode_result(self, result: Any, state: str) -> Any:
         if state in self.EXCEPTION_STATES and isinstance(result, Exception):
             return self.prepare_exception(result)
         else:
             return self.prepare_value(result)
 
-    def is_cached(self, task_id):
+    def is_cached(self, task_id: str) -> bool:
         return task_id in self._cache
 
-    def store_result(self, task_id, result, state,
-                     traceback=None, request=None, **kwargs):
+    async def store_result(self, task_id: str, result: Any, state: str,
+                           *,
+                           traceback: str = None,
+                           request: RequestT = None, **kwargs) -> Any:
         """Update task state and result."""
         result = self.encode_result(result, state)
-        self._store_result(task_id, result, state, traceback,
-                           request=request, **kwargs)
+        await self._store_result(task_id, result, state, traceback,
+                                 request=request, **kwargs)
         return result
 
-    def forget(self, task_id):
+    async def forget(self, task_id: str) -> None:
         self._cache.pop(task_id, None)
-        self._forget(task_id)
+        await self._forget(task_id)
 
-    def _forget(self, task_id):
+    async def _forget(self, task_id: str) -> None:
         raise NotImplementedError('backend does not implement forget.')
 
-    def get_state(self, task_id):
+    def get_state(self, task_id: str) -> str:
         """Get the state of a task."""
         return self.get_task_meta(task_id)['status']
 
-    def get_traceback(self, task_id):
+    def get_traceback(self, task_id: str) -> str:
         """Get the traceback for a failed task."""
         return self.get_task_meta(task_id).get('traceback')
 
-    def get_result(self, task_id):
+    def get_result(self, task_id: str) -> Any:
         """Get the result of a task."""
         return self.get_task_meta(task_id).get('result')
 
-    def get_children(self, task_id):
+    def get_children(self, task_id: str) -> Sequence[str]:
         """Get the list of subtasks sent by a task."""
         try:
             return self.get_task_meta(task_id)['children']
         except KeyError:
             pass
 
-    def _ensure_not_eager(self):
+    def _ensure_not_eager(self) -> None:
         if self.app.conf.task_always_eager:
             raise RuntimeError(
                 "Cannot retrieve result with task_always_eager enabled")
 
-    def get_task_meta(self, task_id, cache=True):
+    def get_task_meta(self, task_id: str, cache: bool = True) -> Mapping:
         self._ensure_not_eager()
         if cache:
             try:
@@ -348,15 +386,15 @@ class Backend:
             self._cache[task_id] = meta
         return meta
 
-    def reload_task_result(self, task_id):
+    def reload_task_result(self, task_id: str) -> None:
         """Reload task result, even if it has been previously fetched."""
         self._cache[task_id] = self.get_task_meta(task_id, cache=False)
 
-    def reload_group_result(self, group_id):
+    def reload_group_result(self, group_id: str) -> None:
         """Reload group result, even if it has been previously fetched."""
         self._cache[group_id] = self.get_group_meta(group_id, cache=False)
 
-    def get_group_meta(self, group_id, cache=True):
+    def get_group_meta(self, group_id: str, cache: bool = True) -> Mapping:
         self._ensure_not_eager()
         if cache:
             try:
@@ -369,91 +407,118 @@ class Backend:
             self._cache[group_id] = meta
         return meta
 
-    def restore_group(self, group_id, cache=True):
+    def restore_group(self, group_id: str, cache: bool = True) -> GroupResult:
         """Get the result for a group."""
         meta = self.get_group_meta(group_id, cache=cache)
         if meta:
             return meta['result']
 
-    def save_group(self, group_id, result):
+    def save_group(self, group_id: str, result: GroupResult) -> GroupResult:
         """Store the result of an executed group."""
         return self._save_group(group_id, result)
 
-    def delete_group(self, group_id):
+    def _save_group(self, group_id: str, result: GroupResult) -> GroupResult:
+        raise NotImplementedError()
+
+    def delete_group(self, group_id: str) -> None:
         self._cache.pop(group_id, None)
-        return self._delete_group(group_id)
+        self._delete_group(group_id)
 
-    def cleanup(self):
+    def _delete_group(self, group_id: str) -> None:
+        raise NotImplementedError()
+
+    def cleanup(self) -> None:
         """Backend cleanup.
 
         Note:
             This is run by :class:`celery.task.DeleteExpiredTaskMetaTask`.
         """
-        pass
+        ...
 
-    def process_cleanup(self):
+    def process_cleanup(self) -> None:
         """Cleanup actions to do at the end of a task worker process."""
-        pass
+        ...
 
-    def on_task_call(self, producer, task_id):
+    async def on_task_call(self, producer: ProducerT, task_id: str) -> Mapping:
         return {}
 
-    def add_to_chord(self, chord_id, result):
+    async def add_to_chord(self, chord_id: str, result: ResultT) -> None:
         raise NotImplementedError('Backend does not support add_to_chord')
 
-    def on_chord_part_return(self, request, state, result, **kwargs):
-        pass
+    async def on_chord_part_return(self, request: RequestT,
+                                   state: str,
+                                   result: Any,
+                                   **kwargs) -> None:
+        ...
 
-    def fallback_chord_unlock(self, group_id, body, result=None,
-                              countdown=1, **kwargs):
+    async def fallback_chord_unlock(self, group_id: str, body: SignatureT,
+                                    result: ResultT = None,
+                                    countdown: float = 1,
+                                    **kwargs) -> None:
         kwargs['result'] = [r.as_tuple() for r in result]
-        self.app.tasks['celery.chord_unlock'].apply_async(
+        await self.app.tasks['celery.chord_unlock'].apply_async(
             (group_id, body,), kwargs, countdown=countdown,
         )
 
-    def ensure_chords_allowed(self):
-        pass
+    def ensure_chords_allowed(self) -> None:
+        ...
 
-    def apply_chord(self, header, partial_args, group_id, body,
-                    options={}, **kwargs):
+    async def apply_chord(self, header: SignatureT, partial_args: Sequence,
+                          *,
+                          group_id: str, body: SignatureT,
+                          options: Mapping = {}, **kwargs) -> ResultT:
         self.ensure_chords_allowed()
         fixed_options = {k: v for k, v in options.items() if k != 'task_id'}
-        result = header(*partial_args, task_id=group_id, **fixed_options or {})
-        self.fallback_chord_unlock(group_id, body, **kwargs)
+        result = await header(
+            *partial_args, task_id=group_id, **fixed_options or {})
+        await self.fallback_chord_unlock(group_id, body, **kwargs)
         return result
 
-    def current_task_children(self, request=None):
+    def current_task_children(self,
+                              request: RequestT = None) -> Sequence[Tuple]:
         request = request or getattr(get_current_task(), 'request', None)
         if request:
             return [r.as_tuple() for r in getattr(request, 'children', [])]
 
-    def __reduce__(self, args=(), kwargs={}):
+    def __reduce__(self, args: Tuple = (), kwargs: Dict = {}) -> Tuple:
         return (unpickle_backend, (self.__class__, args, kwargs))
 
 
 class SyncBackendMixin:
 
-    def iter_native(self, result, timeout=None, interval=0.5, no_ack=True,
-                    on_message=None, on_interval=None):
+    async def iter_native(
+            self, result: ResultT,
+            *,
+            timeout: float = None,
+            interval: float = 0.5,
+            no_ack: bool = True,
+            on_message: Callable = None,
+            on_interval: Callable = None) -> Iterable[Tuple[str, Mapping]]:
         self._ensure_not_eager()
         results = result.results
         if not results:
             return iter([])
         return self.get_many(
             {r.id for r in results},
             timeout=timeout, interval=interval, no_ack=no_ack,
             on_message=on_message, on_interval=on_interval,
         )
 
-    def wait_for_pending(self, result, timeout=None, interval=0.5,
-                         no_ack=True, on_message=None, on_interval=None,
-                         callback=None, propagate=True):
+    async def wait_for_pending(self, result: ResultT,
+                               *,
+                               timeout: float = None,
+                               interval: float = 0.5,
+                               no_ack: bool = True,
+                               on_message: Callable = None,
+                               on_interval: Callable = None,
+                               callback: Callable = None,
+                               propagate: bool = True) -> Any:
         self._ensure_not_eager()
         if on_message is not None:
             raise ImproperlyConfigured(
                 'Backend does not support on_message callback')
 
-        meta = self.wait_for(
+        meta = await self.wait_for(
             result.id, timeout=timeout,
             interval=interval,
             on_interval=on_interval,
@@ -463,8 +528,12 @@ class SyncBackendMixin:
             result._maybe_set_cache(meta)
             return result.maybe_throw(propagate=propagate, callback=callback)
 
-    def wait_for(self, task_id,
-                 timeout=None, interval=0.5, no_ack=True, on_interval=None):
+    async def wait_for(self, task_id: str,
+                       *,
+                       timeout: float = None,
+                       interval: float = 0.5,
+                       no_ack: bool = True,
+                       on_interval: Callable = None) -> Mapping:
         """Wait for task and return its result.
 
         If the task raises an exception, this exception
@@ -484,21 +553,23 @@ class SyncBackendMixin:
             if meta['status'] in states.READY_STATES:
                 return meta
             if on_interval:
-                on_interval()
+                await on_interval()
             # avoid hammering the CPU checking status.
             time.sleep(interval)
             time_elapsed += interval
             if timeout and time_elapsed >= timeout:
                 raise TimeoutError('The operation timed out.')
 
-    def add_pending_result(self, result, weak=False):
+    def add_pending_result(self, result: ResultT,
+                           *,
+                           weak: bool = False) -> ResultT:
         return result
 
-    def remove_pending_result(self, result):
+    def remove_pending_result(self, result: ResultT) -> ResultT:
         return result
 
     @property
-    def is_async(self):
+    def is_async(self) -> bool:
         return False
 
 
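The polling loop in SyncBackendMixin.wait_for above accumulates slept time rather than consulting a clock; a simplified synchronous sketch of the same arithmetic (fetch_meta is a hypothetical stand-in for fetching task metadata):

import time

def wait_for(fetch_meta, timeout=None, interval=0.5, on_interval=None):
    # Poll until the task reaches a ready state, sleeping `interval`
    # seconds per iteration; give up once the total sleep passes `timeout`.
    time_elapsed = 0.0
    while True:
        meta = fetch_meta()
        if meta['status'] in {'SUCCESS', 'FAILURE'}:
            return meta
        if on_interval:
            on_interval()
        time.sleep(interval)
        time_elapsed += interval
        if timeout and time_elapsed >= timeout:
            raise TimeoutError('The operation timed out.')

# wait_for(lambda: {'status': 'SUCCESS'}) returns on the first poll.
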
@@ -514,7 +585,7 @@ class BaseKeyValueStoreBackend(Backend):
     chord_keyprefix = 'chord-unlock-'
     implements_incr = False
 
-    def __init__(self, *args, **kwargs):
+    def __init__(self, *args, **kwargs) -> None:
         if hasattr(self.key_t, '__func__'):  # pragma: no cover
             self.key_t = self.key_t.__func__  # remove binding
         self._encode_prefixes()
@@ -522,51 +593,51 @@ class BaseKeyValueStoreBackend(Backend):
         if self.implements_incr:
             self.apply_chord = self._apply_chord_incr
 
-    def _encode_prefixes(self):
+    def _encode_prefixes(self) -> None:
         self.task_keyprefix = self.key_t(self.task_keyprefix)
         self.group_keyprefix = self.key_t(self.group_keyprefix)
         self.chord_keyprefix = self.key_t(self.chord_keyprefix)
 
-    def get(self, key):
+    async def get(self, key: AnyStr) -> AnyStr:
         raise NotImplementedError('Must implement the get method.')
 
-    def mget(self, keys):
+    async def mget(self, keys: Sequence[AnyStr]) -> Sequence[AnyStr]:
         raise NotImplementedError('Does not support get_many')
 
-    def set(self, key, value):
+    async def set(self, key: AnyStr, value: AnyStr) -> None:
         raise NotImplementedError('Must implement the set method.')
 
-    def delete(self, key):
+    async def delete(self, key: AnyStr) -> None:
         raise NotImplementedError('Must implement the delete method')
 
-    def incr(self, key):
+    async def incr(self, key: AnyStr) -> int:
         raise NotImplementedError('Does not implement incr')
 
-    def expire(self, key, value):
-        pass
+    async def expire(self, key: AnyStr, value: Number) -> None:
+        ...
 
-    def get_key_for_task(self, task_id, key=''):
+    def get_key_for_task(self, task_id: str, key: AnyStr = '') -> AnyStr:
         """Get the cache key for a task by id."""
         key_t = self.key_t
         return key_t('').join([
             self.task_keyprefix, key_t(task_id), key_t(key),
         ])
 
-    def get_key_for_group(self, group_id, key=''):
+    def get_key_for_group(self, group_id: str, key: AnyStr = '') -> AnyStr:
         """Get the cache key for a group by id."""
         key_t = self.key_t
         return key_t('').join([
             self.group_keyprefix, key_t(group_id), key_t(key),
         ])
 
-    def get_key_for_chord(self, group_id, key=''):
+    def get_key_for_chord(self, group_id: str, key: AnyStr = '') -> AnyStr:
         """Get the cache key for the chord waiting on group with given id."""
         key_t = self.key_t
         return key_t('').join([
             self.chord_keyprefix, key_t(group_id), key_t(key),
         ])
 
-    def _strip_prefix(self, key):
+    def _strip_prefix(self, key: AnyStr) -> AnyStr:
         """Take bytes: emit string."""
         key = self.key_t(key)
         for prefix in self.task_keyprefix, self.group_keyprefix:
@@ -574,14 +645,18 @@ class BaseKeyValueStoreBackend(Backend):
                 return bytes_to_str(key[len(prefix):])
         return bytes_to_str(key)
 
-    def _filter_ready(self, values, READY_STATES=states.READY_STATES):
+    def _filter_ready(
+            self,
+            values: Sequence[Tuple[Any, Any]],
+            *,
+            READY_STATES: Set[str] = states.READY_STATES) -> Iterable[Tuple]:
         for k, v in values:
             if v is not None:
                 v = self.decode_result(v)
                 if v['status'] in READY_STATES:
                     yield k, v
 
-    def _mget_to_results(self, values, keys):
+    def _mget_to_results(self, values: Any, keys: Sequence[AnyStr]) -> Mapping:
         if hasattr(values, 'items'):
             # client returns dict so mapping preserved.
             return {
@@ -595,9 +670,16 @@ class BaseKeyValueStoreBackend(Backend):
                 for i, v in self._filter_ready(enumerate(values))
             }
 
-    def get_many(self, task_ids, timeout=None, interval=0.5, no_ack=True,
-                 on_message=None, on_interval=None, max_iterations=None,
-                 READY_STATES=states.READY_STATES):
+    async def get_many(
+            self, task_ids: Sequence[str],
+            *,
+            timeout: float = None,
+            interval: float = 0.5,
+            no_ack: bool = True,
+            on_message: Callable = None,
+            on_interval: Callable = None,
+            max_iterations: int = None,
+            READY_STATES=states.READY_STATES,
+    ) -> AsyncIterator[Tuple[str, Mapping]]:
         interval = 0.5 if interval is None else interval
         ids = task_ids if isinstance(task_ids, set) else set(task_ids)
         cached_ids = set()
@@ -616,54 +698,59 @@ class BaseKeyValueStoreBackend(Backend):
         iterations = 0
         while ids:
             keys = list(ids)
-            r = self._mget_to_results(self.mget([self.get_key_for_task(k)
-                                                 for k in keys]), keys)
+            payloads = await self.mget([
+                self.get_key_for_task(k) for k in keys
+            ])
+            r = self._mget_to_results(payloads, keys)
             cache.update(r)
             ids.difference_update({bytes_to_str(v) for v in r})
             for key, value in r.items():
                 if on_message is not None:
-                    on_message(value)
+                    await on_message(value)
                 yield bytes_to_str(key), value
             if timeout and iterations * interval >= timeout:
                 raise TimeoutError('Operation timed out ({0})'.format(timeout))
             if on_interval:
-                on_interval()
+                await on_interval()
             time.sleep(interval)  # don't busy loop.
             iterations += 1
             if max_iterations and iterations >= max_iterations:
                 break
 
-    def _forget(self, task_id):
-        self.delete(self.get_key_for_task(task_id))
+    async def _forget(self, task_id: str) -> None:
+        await self.delete(self.get_key_for_task(task_id))
 
-    def _store_result(self, task_id, result, state,
-                      traceback=None, request=None, **kwargs):
+    async def _store_result(self, task_id: str, result: Any, state: str,
+                            traceback: str = None,
+                            request: RequestT = None,
+                            **kwargs) -> Any:
         meta = {
             'status': state, 'result': result, 'traceback': traceback,
             'children': self.current_task_children(request),
             'task_id': bytes_to_str(task_id),
         }
-        self.set(self.get_key_for_task(task_id), self.encode(meta))
+        await self.set(self.get_key_for_task(task_id), self.encode(meta))
         return result
 
-    def _save_group(self, group_id, result):
-        self.set(self.get_key_for_group(group_id),
-                 self.encode({'result': result.as_tuple()}))
+    async def _save_group(self,
+                          group_id: str, result: GroupResult) -> GroupResult:
+        await self.set(self.get_key_for_group(group_id),
+                       self.encode({'result': result.as_tuple()}))
         return result
 
-    def _delete_group(self, group_id):
-        self.delete(self.get_key_for_group(group_id))
+    async def _delete_group(self, group_id: str) -> None:
+        await self.delete(self.get_key_for_group(group_id))
 
-    def _get_task_meta_for(self, task_id):
+    async def _get_task_meta_for(self, task_id: str) -> Mapping:
         """Get task meta-data for a task by id."""
-        meta = self.get(self.get_key_for_task(task_id))
+        meta = await self.get(self.get_key_for_task(task_id))
         if not meta:
             return {'status': states.PENDING, 'result': None}
         return self.decode_result(meta)
 
-    def _restore_group(self, group_id):
+    async def _restore_group(self, group_id: str) -> GroupResult:
         """Get group meta-data for a group by id."""
-        meta = self.get(self.get_key_for_group(group_id))
+        meta = await self.get(self.get_key_for_group(group_id))
         # previously this was always pickled, but later this
         # was extended to support other serializers, so the
         # structure is kind of weird.
@@ -673,16 +760,25 @@ class BaseKeyValueStoreBackend(Backend):
             meta['result'] = result_from_tuple(result, self.app)
             return meta
 
-    def _apply_chord_incr(self, header, partial_args, group_id, body,
-                          result=None, options={}, **kwargs):
+    async def _apply_chord_incr(self, header: SignatureT,
+                                partial_args: Sequence,
+                                group_id: str,
+                                body: SignatureT,
+                                result: ResultT = None,
+                                options: Mapping = {},
+                                **kwargs) -> ResultT:
         self.ensure_chords_allowed()
-        self.save_group(group_id, self.app.GroupResult(group_id, result))
+        await self.save_group(group_id, self.app.GroupResult(group_id, result))
 
         fixed_options = {k: v for k, v in options.items() if k != 'task_id'}
 
-        return header(*partial_args, task_id=group_id, **fixed_options or {})
+        return await header(
+            *partial_args, task_id=group_id, **fixed_options or {})
 
-    def on_chord_part_return(self, request, state, result, **kwargs):
+    async def on_chord_part_return(
+            self,
+            request: RequestT, state: str, result: Any,
+            **kwargs) -> None:
         if not self.implements_incr:
             return
         app = self.app
@@ -691,25 +787,27 @@ class BaseKeyValueStoreBackend(Backend):
             return
         key = self.get_key_for_chord(gid)
         try:
-            deps = GroupResult.restore(gid, backend=self)
+            deps = await GroupResult.restore(gid, backend=self)
         except Exception as exc:  # pylint: disable=broad-except
             callback = maybe_signature(request.chord, app=app)
             logger.exception('Chord %r raised: %r', gid, exc)
-            return self.chord_error_from_stack(
+            await self.chord_error_from_stack(
                 callback,
                 ChordError('Cannot restore group: {0!r}'.format(exc)),
             )
+            return
         if deps is None:
             try:
                 raise ValueError(gid)
             except ValueError as exc:
                 callback = maybe_signature(request.chord, app=app)
                 logger.exception('Chord callback %r raised: %r', gid, exc)
-                return self.chord_error_from_stack(
+                await self.chord_error_from_stack(
                     callback,
                     ChordError('GroupResult {0} no longer exists'.format(gid)),
                 )
-        val = self.incr(key)
+                return
+        val = await self.incr(key)
         size = len(deps)
         if val > size:  # pragma: no cover
             logger.warning('Chord counter incremented too many times for %r',
@@ -719,7 +817,7 @@ class BaseKeyValueStoreBackend(Backend):
             j = deps.join_native if deps.supports_native_join else deps.join
             try:
                 with allow_join_result():
-                    ret = j(timeout=3.0, propagate=True)
+                    ret = await j(timeout=3.0, propagate=True)
             except Exception as exc:  # pylint: disable=broad-except
                 try:
                     culprit = next(deps._failed_join_report())
@@ -730,21 +828,21 @@ class BaseKeyValueStoreBackend(Backend):
                     reason = repr(exc)
                     reason = repr(exc)
 
 
                 logger.exception('Chord %r raised: %r', gid, reason)
                 logger.exception('Chord %r raised: %r', gid, reason)
-                self.chord_error_from_stack(callback, ChordError(reason))
+                await self.chord_error_from_stack(callback, ChordError(reason))
             else:
             else:
                 try:
                 try:
-                    callback.delay(ret)
+                    await callback.delay(ret)
                 except Exception as exc:  # pylint: disable=broad-except
                 except Exception as exc:  # pylint: disable=broad-except
                     logger.exception('Chord %r raised: %r', gid, exc)
                     logger.exception('Chord %r raised: %r', gid, exc)
-                    self.chord_error_from_stack(
+                    await self.chord_error_from_stack(
                         callback,
                         callback,
                         ChordError('Callback error: {0!r}'.format(exc)),
                         ChordError('Callback error: {0!r}'.format(exc)),
                     )
                     )
             finally:
             finally:
                 deps.delete()
                 deps.delete()
-                self.client.delete(key)
+                await self.client.delete(key)
         else:
         else:
-            self.expire(key, self.expires)
+            await self.expire(key, self.expires)
 
 
 
 
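The bookkeeping above reduces to one atomic counter per chord: every finished
header task increments it, and the callback fires once the counter reaches the
group size. A minimal sketch of that pattern, assuming an asyncio-style
key/value client (all names below are illustrative, not part of this diff):

    async def on_part_done(client, gid, size, fire_callback):
        key = 'chord-unlock-{0}'.format(gid)
        val = await client.incr(key)   # atomic: one bump per finished task
        if val == size:                # last part seen -> fire exactly once
            try:
                await fire_callback()
            finally:
                await client.delete(key)   # drop the counter afterwards
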
 class KeyValueStoreBackend(BaseKeyValueStoreBackend, SyncBackendMixin):
@@ -756,16 +854,16 @@ class DisabledBackend(BaseBackend):
 
     _cache = {}   # need this attribute to reset cache in tests.
 
-    def store_result(self, *args, **kwargs):
-        pass
+    async def store_result(self, *args, **kwargs) -> Any:
+        ...
 
-    def ensure_chords_allowed(self):
+    def ensure_chords_allowed(self) -> None:
         raise NotImplementedError(E_CHORD_NO_BACKEND.strip())
 
-    def _is_disabled(self, *args, **kwargs):
+    def _is_disabled(self, *args, **kwargs) -> bool:
         raise NotImplementedError(E_NO_BACKEND.strip())
 
-    def as_uri(self, *args, **kwargs):
+    def as_uri(self, *args, **kwargs) -> str:
         return 'disabled://'
 
     get_state = get_result = get_traceback = _is_disabled
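
For context, this is what the disabled backend means at the call site, assuming
an app created without any result backend and a broker available (`add` is an
illustrative task):

    from celery import Celery

    app = Celery('example')   # no backend= argument: DisabledBackend is used

    @app.task
    def add(x, y):
        return x + y

    result = add.delay(2, 2)  # sending the task still works
    result.get()              # raises NotImplementedError (E_NO_BACKEND)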

+ 71 - 47
celery/backends/rpc.py

@@ -6,12 +6,17 @@ RPC-style result backend, using reply-to and one queue per client.
 import kombu
 import time
 
+from typing import Any, Dict, Iterator, Mapping, Set, Tuple, Union
+
 from kombu.common import maybe_declare
+from kombu.types import (ChannelT, ConnectionT, ConsumerT, EntityT,
+                         MessageT, ProducerT)
 from kombu.utils.compat import register_after_fork
 from kombu.utils.objects import cached_property
 
 from celery import states
 from celery._state import current_task, task_join_will_block
+from celery.types import AppT, BackendT, ResultT, RequestT
+from celery.result import GroupResult
 
 from . import base
 from .async import AsyncBackendMixin, BaseResultConsumer
@@ -32,21 +37,23 @@ class BacklogLimitExceeded(Exception):
     """Too much state history to fast-forward."""
 
 
-def _on_after_fork_cleanup_backend(backend):
+def _on_after_fork_cleanup_backend(backend: BackendT) -> None:
     backend._after_fork()
 
 
 class ResultConsumer(BaseResultConsumer):
     Consumer = kombu.Consumer
 
-    _connection = None
-    _consumer = None
+    _connection: ConnectionT = None
+    _consumer: ConsumerT = None
 
-    def __init__(self, *args, **kwargs):
+    def __init__(self, *args, **kwargs) -> None:
         super().__init__(*args, **kwargs)
         self._create_binding = self.backend._create_binding
 
-    def start(self, initial_task_id, no_ack=True, **kwargs):
+    def start(self, initial_task_id: str,
+              *,
+              no_ack: bool = True, **kwargs) -> None:
         self._connection = self.app.connection()
         initial_queue = self._create_binding(initial_task_id)
         self._consumer = self.Consumer(
@@ -55,33 +62,34 @@ class ResultConsumer(BaseResultConsumer):
             accept=self.accept)
         self._consumer.consume()
 
-    def drain_events(self, timeout=None):
+    def drain_events(self, timeout: float = None) -> None:
         if self._connection:
-            return self._connection.drain_events(timeout=timeout)
+            self._connection.drain_events(timeout=timeout)
         elif timeout:
             time.sleep(timeout)
 
-    def stop(self):
+    def stop(self) -> None:
         try:
             self._consumer.cancel()
         finally:
             self._connection.close()
 
-    def on_after_fork(self):
+    def on_after_fork(self) -> None:
         self._consumer = None
         if self._connection is not None:
             self._connection.collect()
             self._connection = None
 
-    def consume_from(self, task_id):
+    def consume_from(self, task_id: str) -> None:
         if self._consumer is None:
-            return self.start(task_id)
-        queue = self._create_binding(task_id)
-        if not self._consumer.consuming_from(queue):
-            self._consumer.add_queue(queue)
-            self._consumer.consume()
+            self.start(task_id)
+        else:
+            queue = self._create_binding(task_id)
+            if not self._consumer.consuming_from(queue):
+                self._consumer.add_queue(queue)
+                self._consumer.consume()
 
-    def cancel_for(self, task_id):
+    def cancel_for(self, task_id: str) -> None:
         if self._consumer:
             self._consumer.cancel_by_queue(self._create_binding(task_id).name)
 
@@ -117,8 +125,14 @@ class RPCBackend(base.Backend, AsyncBackendMixin):
 
         can_cache_declaration = False
 
-    def __init__(self, app, connection=None, exchange=None, exchange_type=None,
-                 persistent=None, serializer=None, auto_delete=True, **kwargs):
+    def __init__(self, app: AppT,
+                 connection: ConnectionT = None,
+                 exchange: str = None,
+                 exchange_type: str = None,
+                 persistent: bool = None,
+                 serializer: str = None,
+                 auto_delete: bool = True,
+                 **kwargs) -> None:
         super().__init__(app, **kwargs)
         conf = self.app.conf
         self._connection = connection
@@ -139,32 +153,36 @@ class RPCBackend(base.Backend, AsyncBackendMixin):
         if register_after_fork is not None:
             register_after_fork(self, _on_after_fork_cleanup_backend)
 
-    def _after_fork(self):
+    def _after_fork(self) -> None:
         # clear state for child processes.
         self._pending_results.clear()
         self.result_consumer._after_fork()
 
-    def _create_exchange(self, name, type='direct', delivery_mode=2):
+    def _create_exchange(self, name: str,
                         type: str = 'direct',
                         delivery_mode: Union[int, str] = 2) -> Exchange:
         # uses direct to queue routing (anon exchange).
         return self.Exchange(None)
 
-    def _create_binding(self, task_id):
+    def _create_binding(self, task_id: str) -> EntityT:
         """Create new binding for task with id."""
         # RPC backend caches the binding, as one queue is used for all tasks.
         return self.binding
 
-    def ensure_chords_allowed(self):
+    def ensure_chords_allowed(self) -> None:
         raise NotImplementedError(E_NO_CHORD_SUPPORT.strip())
 
-    def on_task_call(self, producer, task_id):
+    def on_task_call(self, producer: ProducerT, task_id: str) -> Mapping:
         # Called every time a task is sent when using this backend.
         # We declare the queue we receive replies on in advance of sending
         # the message, but we skip this if running in the prefork pool
         # (task_join_will_block), as we know the queue is already declared.
         if not task_join_will_block():
             maybe_declare(self.binding(producer.channel), retry=True)
+        return {}
 
-    def destination_for(self, task_id, request):
+    def destination_for(self,
                        task_id: str, request: RequestT) -> Tuple[str, str]:
         """Get the destination for result by task id.
 
         Returns:
@@ -179,22 +197,23 @@ class RPCBackend(base.Backend, AsyncBackendMixin):
                 'RPC backend missing task request for {0!r}'.format(task_id))
         return request.reply_to, request.correlation_id or task_id
 
-    def on_reply_declare(self, task_id):
+    def on_reply_declare(self, task_id: str) -> None:
         # Return value here is used as the `declare=` argument
         # for Producer.publish.
         # By default we don't have to declare anything when sending a result.
-        pass
+        ...
 
-    def on_result_fulfilled(self, result):
+    def on_result_fulfilled(self, result: ResultT) -> None:
         # This usually cancels the queue after the result is received,
         # but we don't have to cancel since we have one queue per process.
-        pass
+        ...
 
-    def as_uri(self, include_password=True):
+    def as_uri(self, include_password: bool = True) -> str:
         return 'rpc://'
 
-    def store_result(self, task_id, result, state,
-                     traceback=None, request=None, **kwargs):
+    def store_result(self, task_id: str, result: Any, state: str,
+                     traceback: str = None, request: RequestT = None,
+                     **kwargs) -> Any:
         """Send task return value and state."""
         routing_key, correlation_id = self.destination_for(task_id, request)
         if not routing_key:
@@ -212,7 +231,9 @@ class RPCBackend(base.Backend, AsyncBackendMixin):
             )
         return result

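store_result() above is just a publish back to the caller: destination_for()
recovers the client's reply queue and correlation id from the request. A rough
sketch of the flow (backend, producer and request stand in for the real
objects):

    routing_key, correlation_id = backend.destination_for(task_id, request)
    # routing_key is request.reply_to: the client's per-process reply queue.
    # correlation_id lets the client match the reply to its request.
    producer.publish(
        {'task_id': task_id, 'status': 'SUCCESS', 'result': 42},
        routing_key=routing_key,
        correlation_id=correlation_id,
    )
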
-    def _to_result(self, task_id, state, result, traceback, request):
+    def _to_result(self,
+                   task_id: str, state: str, result: Any,
+                   traceback: str, request: RequestT) -> Mapping:
         return {
             'task_id': task_id,
             'status': state,
@@ -221,7 +242,7 @@ class RPCBackend(base.Backend, AsyncBackendMixin):
             'children': self.current_task_children(request),
         }
 
-    def on_out_of_band_result(self, task_id, message):
+    def on_out_of_band_result(self, task_id: str, message: MessageT) -> None:
         # Callback called when a reply for a task is received,
         # but we have no idea what to do with it.
         # Since the result is not pending, we put it in a separate
@@ -230,7 +251,8 @@ class RPCBackend(base.Backend, AsyncBackendMixin):
             self.result_consumer.on_out_of_band_result(message)
         self._out_of_band[task_id] = message
 
-    def get_task_meta(self, task_id, backlog_limit=1000):
+    def get_task_meta(self, task_id: str,
+                      *, backlog_limit: int = 1000) -> Mapping:
         buffered = self._out_of_band.pop(task_id, None)
         if buffered:
             return self._set_cache_by_message(task_id, buffered)
@@ -262,13 +284,15 @@ class RPCBackend(base.Backend, AsyncBackendMixin):
                 # result probably pending.
                 return {'status': states.PENDING, 'result': None}
 
-    def _set_cache_by_message(self, task_id, message):
+    def _set_cache_by_message(self,
+                              task_id: str, message: MessageT) -> Mapping:
         payload = self._cache[task_id] = self.meta_from_decoded(
             message.payload)
         return payload
 
-    def _slurp_from_queue(self, task_id, accept,
-                          limit=1000, no_ack=False):
+    def _slurp_from_queue(self, task_id: str, accept: Set[str],
+                          limit: int = 1000,
+                          no_ack: bool = False) -> Iterator[MessageT]:
         with self.app.pool.acquire_channel(block=True) as (_, channel):
             binding = self._create_binding(task_id)(channel)
             binding.declare()
@@ -281,7 +305,7 @@ class RPCBackend(base.Backend, AsyncBackendMixin):
             else:
                 raise self.BacklogLimitExceeded(task_id)
 
-    def _get_message_task_id(self, message):
+    def _get_message_task_id(self, message: MessageT) -> str:
         try:
             # try property first so we don't have to deserialize
             # the payload.
@@ -290,10 +314,10 @@ class RPCBackend(base.Backend, AsyncBackendMixin):
             # message sent by old Celery version, need to deserialize.
             return message.payload['task_id']
 
-    def revive(self, channel):
-        pass
+    def revive(self, channel: ChannelT) -> None:
+        ...
 
-    def reload_task_result(self, task_id):
+    def reload_task_result(self, task_id: str) -> None:
         raise NotImplementedError(
             'reload_task_result is not supported by this backend.')
 
@@ -302,19 +326,19 @@ class RPCBackend(base.Backend, AsyncBackendMixin):
         raise NotImplementedError(
             'reload_group_result is not supported by this backend.')
 
-    def save_group(self, group_id, result):
+    def save_group(self, group_id: str, result: GroupResult) -> None:
         raise NotImplementedError(
             'save_group is not supported by this backend.')
 
-    def restore_group(self, group_id, cache=True):
+    def restore_group(self, group_id: str, cache: bool = True) -> GroupResult:
         raise NotImplementedError(
             'restore_group is not supported by this backend.')
 
-    def delete_group(self, group_id):
+    def delete_group(self, group_id: str) -> None:
         raise NotImplementedError(
             'delete_group is not supported by this backend.')
 
-    def __reduce__(self, args=(), kwargs={}):
+    def __reduce__(self, args: Tuple = (), kwargs: Dict = {}) -> Tuple:
         return super().__reduce__(args, dict(
             kwargs,
             connection=self._connection,
@@ -327,7 +351,7 @@ class RPCBackend(base.Backend, AsyncBackendMixin):
         ))
 
     @property
-    def binding(self):
+    def binding(self) -> EntityT:
         return self.Queue(
             self.oid, self.exchange, self.oid,
             durable=False,
@@ -336,6 +360,6 @@ class RPCBackend(base.Backend, AsyncBackendMixin):
         )
 
     @cached_property
-    def oid(self):
+    def oid(self) -> str:
         # cached here is the app OID: name of queue we receive results on.
         return self.app.oid
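
Configuring this backend is unchanged by the diff; the standard one-liner is
enough, after which results travel over AMQP reply queues:

    from celery import Celery

    app = Celery('proj', broker='amqp://', backend='rpc://')

    @app.task
    def add(x, y):
        return x + y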

+ 9 - 3
celery/beat.py

@@ -8,10 +8,10 @@ import shelve
 import sys
 import traceback
 
-from collections import namedtuple
 from functools import total_ordering
 from threading import Event, Thread
 from time import monotonic
+from typing import NamedTuple
 
 from billiard import ensure_multiprocessing
 from billiard.context import Process
@@ -32,8 +32,6 @@ __all__ = [
     'PersistentScheduler', 'Service', 'EmbeddedService',
 ]
 
-event_t = namedtuple('event_t', ('time', 'priority', 'entry'))
-
 logger = get_logger(__name__)
 debug, info, error, warning = (logger.debug, logger.info,
                                logger.error, logger.warning)
@@ -41,6 +39,14 @@ debug, info, error, warning = (logger.debug, logger.info,
 DEFAULT_MAX_INTERVAL = 300  # 5 minutes
 
 
+class event_t(NamedTuple):
+    """Represents a beat event in the scheduler heap."""
+
+    time: float
+    priority: int
+    entry: 'ScheduleEntry'
+
+
 class SchedulingError(Exception):
     """An error occurred while scheduling a task."""

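Since NamedTuple instances are still plain tuples, heap ordering is unchanged:
entries compare by time first, then priority. A quick illustration using the
event_t defined above (None stands in for a real ScheduleEntry):

    import heapq

    heap = []
    heapq.heappush(heap, event_t(time=10.0, priority=5, entry=None))
    heapq.heappush(heap, event_t(time=5.0, priority=9, entry=None))
    assert heapq.heappop(heap).time == 5.0   # earliest event pops first
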
+ 48 - 47
celery/bootsteps.py

@@ -122,24 +122,24 @@ class StartStopStep(Step):
     #: Optional obj created by the :meth:`create` method.
     #: This is used by :class:`StartStopStep` to keep the
     #: original service object.
-    obj = None
+    obj: Any = None
 
-    def start(self, parent):
+    async def start(self, parent: Any) -> None:
         if self.obj:
-            return self.obj.start()
+            await self.obj.start()
 
-    def stop(self, parent):
+    async def stop(self, parent: Any) -> None:
         if self.obj:
-            return self.obj.stop()
+            await self.obj.stop()
 
-    def close(self, parent):
-        pass
+    async def close(self, parent: Any) -> None:
+        ...
 
-    def terminate(self, parent):
+    async def terminate(self, parent: Any) -> None:
         if self.obj:
-            return getattr(self.obj, 'terminate', self.obj.stop)()
+            await getattr(self.obj, 'terminate', self.obj.stop)()
 
-    def include(self, parent):
+    def include(self, parent: Any) -> bool:
         inc, ret = self._should_include(parent)
         if inc:
             self.obj = ret
@@ -156,27 +156,27 @@ class ConsumerStep(StartStopStep):
     def get_consumers(self, channel):
         raise NotImplementedError('missing get_consumers')
 
-    def start(self, c):
+    async def start(self, c):
         channel = c.connection.channel()
         self.consumers = self.get_consumers(channel)
         for consumer in self.consumers or []:
             consumer.consume()
 
-    def stop(self, c):
-        self._close(c, True)
+    async def stop(self, c):
+        await self._close(c, True)
 
-    def shutdown(self, c):
-        self._close(c, False)
+    async def shutdown(self, c):
+        await self._close(c, False)
 
-    def _close(self, c, cancel_consumers=True):
+    async def _close(self, c, cancel_consumers=True):
         channels = set()
         for consumer in self.consumers or []:
             if cancel_consumers:
-                ignore_errors(c.connection, consumer.cancel)
+                await ignore_errors(c.connection, consumer.cancel)
             if consumer.channel:
                 channels.add(consumer.channel)
         for channel in channels:
-            ignore_errors(c.connection, channel.close)
+            await ignore_errors(c.connection, channel.close)


 def _pre(ns: Step, fmt: str) -> str:
@@ -235,11 +235,11 @@ class Blueprint:
 
     GraphFormatter = StepFormatter
 
-    name = None                        # type: str
-    state = None                       # type: int
-    started = 0                        # type: int
-    default_steps = set()              # type: Set[Union[str, Step]]
-    state_to_name = {                  # type: Mapping[int, str]
+    name: str = None
+    state: int = None
+    started: int = 0
+    default_steps: Set[Union[str, Step]] = set()
+    state_to_name: Mapping[int, str] = {
         0: 'initializing',
         RUN: 'running',
         CLOSE: 'closing',
@@ -257,16 +257,16 @@ class Blueprint:
         self.on_close = on_close
         self.on_stopped = on_stopped
         self.shutdown_complete = Event()
-        self.steps = {}  # type: Mapping[str, Step]
+        self.steps: Mapping[str, Step] = {}
 
-    def start(self, parent: Any) -> None:
+    async def start(self, parent: Any) -> None:
         self.state = RUN
         if self.on_start:
             self.on_start()
         for i, step in enumerate(s for s in parent.steps if s is not None):
             self._debug('Starting %s', step.alias)
             self.started = i + 1
-            step.start(parent)
+            await step.start(parent)
             logger.debug('^-- substep ok')
 
     def human_state(self) -> str:
@@ -278,23 +278,24 @@ class Blueprint:
             info.update(step.info(parent) or {})
         return info
 
-    def close(self, parent: Any) -> None:
+    async def close(self, parent: Any) -> None:
         if self.on_close:
             self.on_close()
-        self.send_all(parent, 'close', 'closing', reverse=False)
-
-    def restart(self,
-                parent: Any,
-                method: str = 'stop',
-                description: str = 'restarting',
-                propagate: bool = False) -> None:
-        self.send_all(parent, method, description, propagate=propagate)
-
-    def send_all(self, parent: Any, method: str,
-                 description: str = None,
-                 reverse: bool = True,
-                 propagate: bool = True,
-                 args: Sequence = ()) -> None:
+        await self.send_all(parent, 'close', 'closing', reverse=False)
+
+    async def restart(
+            self,
+            parent: Any,
+            method: str = 'stop',
+            description: str = 'restarting',
+            propagate: bool = False) -> None:
+        await self.send_all(parent, method, description, propagate=propagate)
+
+    async def send_all(self, parent: Any, method: str,
+                       description: str = None,
+                       reverse: bool = True,
+                       propagate: bool = True,
+                       args: Sequence = ()) -> None:
         description = description or method.replace('_', ' ')
         steps = reversed(parent.steps) if reverse else parent.steps
         for step in steps:
@@ -304,16 +305,16 @@ class Blueprint:
                     self._debug('%s %s...',
                                 description.capitalize(), step.alias)
                     try:
-                        fun(parent, *args)
+                        await fun(parent, *args)
                     except Exception as exc:  # pylint: disable=broad-except
                         if propagate:
                             raise
                         logger.exception(
                             'Error on %s %s: %r', description, step.alias, exc)
 
-    def stop(self, parent: Any,
-             close: bool = True,
-             terminate: bool = False) -> None:
+    async def stop(self, parent: Any,
+                   close: bool = True,
+                   terminate: bool = False) -> None:
         what = 'terminating' if terminate else 'stopping'
         if self.state in (CLOSE, TERMINATE):
             return
@@ -323,10 +324,10 @@ class Blueprint:
             self.state = TERMINATE
             self.shutdown_complete.set()
             return
-        self.close(parent)
+        await self.close(parent)
         self.state = CLOSE
 
-        self.restart(
+        await self.restart(
             parent, 'terminate' if terminate else 'stop',
             description=what, propagate=False,
        )

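Under this now-async lifecycle a custom bootstep only has to expose awaitable
hooks; a sketch under that assumption (PingService and PingStep are made up
for illustration):

    from celery import bootsteps

    class PingService:
        # Illustrative service exposing the awaitable interface that
        # StartStopStep.start()/stop() above expect.
        async def start(self):
            print('ping service started')

        async def stop(self):
            print('ping service stopped')

    class PingStep(bootsteps.StartStopStep):
        requires = {'celery.worker.components:Pool'}

        def create(self, parent):
            # include() keeps the created object on self.obj; the async
            # start()/stop() above are then awaited by the Blueprint.
            return PingService()
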
+ 254 - 125
celery/canvas.py

@@ -12,14 +12,21 @@ from collections import MutableSequence, deque
 from copy import deepcopy
 from functools import partial as _partial, reduce
 from operator import itemgetter
+from typing import (
+    Any, Callable, Dict, Iterable, Iterator,
+    List, Mapping, Sequence, Tuple, Optional,
+    cast,
+)
 
+from kombu.types import ConnectionT, ProducerT
 from kombu.utils.functional import fxrange, reprcall
 from kombu.utils.objects import cached_property
 from kombu.utils.uuid import uuid
-from vine import barrier
+from vine import Thenable, barrier
 
 from celery._state import current_app
 from celery.result import GroupResult
+from celery.types import AppT, ResultT, RouterT, SignatureT, TaskT
 from celery.utils import abstract
 from celery.utils.functional import (
     maybe_list, is_list, _regen, regen, chunks as _chunks,
@@ -34,7 +41,7 @@ __all__ = [
 ]


-def maybe_unroll_group(g):
+def maybe_unroll_group(g: SignatureT) -> Any:
     """Unroll group with only one member."""
     # Issue #1656
     try:
@@ -50,11 +57,11 @@ def maybe_unroll_group(g):
         return g.tasks[0] if size == 1 else g


-def task_name_from(task):
+def task_name_from(task: Any) -> str:
     return getattr(task, 'name', task)


-def _upgrade(fields, sig):
+def _upgrade(fields: Mapping, sig: SignatureT) -> SignatureT:
     """Used by custom signatures in .from_dict, to keep common fields."""
     sig.update(chord_size=fields.get('chord_size'))
     return sig
@@ -118,17 +125,18 @@ class Signature(dict):
     """
 
     TYPES = {}
-    _app = _type = None
+    _app: AppT = None
+    _type: TaskT = None
 
     @classmethod
-    def register_type(cls, name=None):
-        def _inner(subclass):
+    def register_type(cls, name: str = None) -> Callable:
+        def _inner(subclass: type) -> type:
             cls.TYPES[name or subclass.__name__] = subclass
             return subclass
         return _inner
 
     @classmethod
-    def from_dict(cls, d, app=None):
+    def from_dict(cls, d: Mapping, app: AppT = None) -> SignatureT:
         typ = d.get('subtask_type')
         if typ:
             target_cls = cls.TYPES[typ]
@@ -136,9 +144,16 @@ class Signature(dict):
                 return target_cls.from_dict(d, app=app)
         return Signature(d, app=app)
 
-    def __init__(self, task=None, args=None, kwargs=None, options=None,
-                 type=None, subtask_type=None, immutable=False,
-                 app=None, **ex):
+    def __init__(self,
+                 task: str = None,
+                 args: Sequence = None,
+                 kwargs: Mapping = None,
+                 options: Mapping = None,
+                 type: TaskT = None,
+                 subtask_type: str = None,
+                 immutable: bool = False,
+                 app: AppT = None,
+                 **ex) -> None:
         self._app = app
 
         if isinstance(task, dict):
@@ -161,16 +176,17 @@ class Signature(dict):
                 chord_size=None,
             )
 
-    def __call__(self, *partial_args, **partial_kwargs):
+    def __call__(self, *partial_args, **partial_kwargs) -> Any:
         """Call the task directly (in the current process)."""
         args, kwargs, _ = self._merge(partial_args, partial_kwargs, None)
         return self.type(*args, **kwargs)
 
-    def delay(self, *partial_args, **partial_kwargs):
+    def delay(self, *partial_args, **partial_kwargs) -> ResultT:
         """Shortcut to :meth:`apply_async` using star arguments."""
         return self.apply_async(partial_args, partial_kwargs)
 
-    def apply(self, args=(), kwargs={}, **options):
+    def apply(self, args: Sequence = (), kwargs: Mapping = {},
+              **options) -> ResultT:
         """Call task locally.
 
         Same as :meth:`apply_async` but executes the task inline instead
@@ -180,7 +196,11 @@ class Signature(dict):
         args, kwargs, options = self._merge(args, kwargs, options)
         return self.type.apply(args, kwargs, **options)
 
-    def apply_async(self, args=(), kwargs={}, route_name=None, **options):
+    def apply_async(self,
+                    args: Sequence = (),
+                    kwargs: Mapping = {},
+                    route_name: str = None,
+                    **options) -> ResultT:
         """Apply this task asynchronously.
 
         Arguments:
@@ -209,7 +229,10 @@ class Signature(dict):
         #   Borks on this, as it's a property
         return _apply(args, kwargs, **options)
 
-    def _merge(self, args=(), kwargs={}, options={}, force=False):
+    def _merge(self, args: Sequence = (),
               kwargs: Mapping = {},
               options: Mapping = {},
               force: bool = False) -> Tuple[Tuple, Dict, Dict]:
         if self.immutable and not force:
             return (self.args, self.kwargs,
                     dict(self.options, **options) if options else self.options)
@@ -217,7 +240,7 @@ class Signature(dict):
                 dict(self.kwargs, **kwargs) if kwargs else self.kwargs,
                 dict(self.options, **options) if options else self.options)

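The immutable branch of _merge() is what .si() relies on: caller-supplied
partial arguments (such as the previous task's return value in a chain) are
dropped instead of merged. For example (add/mul are illustrative tasks):

    sig = add.si(2, 2)            # immutable; add.s(2, 2) would merge args
    (mul.s(4, 4) | sig).delay()   # add still executes as add(2, 2)
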
-    def clone(self, args=(), kwargs={}, **opts):
+    def clone(self,
+              args: Sequence = (),
+              kwargs: Mapping = {},
+              **opts) -> SignatureT:
         """Create a copy of this signature.
 
         Arguments:
@@ -240,8 +266,8 @@ class Signature(dict):
         return s
     partial = clone
 
-    def freeze(self, _id=None, group_id=None, chord=None,
-               root_id=None, parent_id=None):
+    def freeze(self,
+               _id: str = None,
+               *,
+               group_id: str = None,
+               chord: SignatureT = None,
+               root_id: str = None,
+               parent_id: str = None) -> ResultT:
         """Finalize the signature by adding a concrete task id.
 
         The task won't be called and you shouldn't call the signature
@@ -273,7 +304,7 @@ class Signature(dict):
         return self.AsyncResult(tid)
     _freeze = freeze

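freeze() is the standard way to learn the task id before anything is
published (add is an illustrative task):

    sig = add.s(2, 2)
    res = sig.freeze()   # res.id is now fixed; nothing has been sent yet
    sig.delay()          # publishes using the same pre-assigned id
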
-    def replace(self, args=None, kwargs=None, options=None):
+    def replace(self,
+                args: Sequence = None,
+                kwargs: Mapping = None,
+                options: Mapping = None) -> SignatureT:
         """Replace the args, kwargs or options set for this signature.
 
         These are only replaced if the argument for the section is
@@ -288,7 +322,7 @@ class Signature(dict):
             s.options = options
         return s
 
-    def set(self, immutable=None, **options):
+    def set(self, immutable: bool = None, **options) -> SignatureT:
         """Set arbitrary execution options (same as ``.options.update(…)``).
 
         Returns:
@@ -300,44 +334,45 @@ class Signature(dict):
         self.options.update(options)
         return self
 
-    def set_immutable(self, immutable):
+    def set_immutable(self, immutable: bool) -> None:
         self.immutable = immutable
 
-    def _with_list_option(self, key):
+    def _with_list_option(self, key: str) -> Any:
         items = self.options.setdefault(key, [])
         if not isinstance(items, MutableSequence):
             items = self.options[key] = [items]
         return items
 
-    def append_to_list_option(self, key, value):
+    def append_to_list_option(self, key: str, value: Any) -> Any:
         items = self._with_list_option(key)
         if value not in items:
             items.append(value)
         return value
 
-    def extend_list_option(self, key, value):
+    def extend_list_option(self, key: str, value: Any) -> None:
         items = self._with_list_option(key)
         items.extend(maybe_list(value))
 
-    def link(self, callback):
+    def link(self, callback: SignatureT) -> SignatureT:
         """Add callback task to be applied if this task succeeds.
 
         Returns:
             Signature: the argument passed, for chaining
                 or use with :func:`~functools.reduce`.
         """
-        return self.append_to_list_option('link', callback)
+        return cast(SignatureT, self.append_to_list_option('link', callback))
 
-    def link_error(self, errback):
+    def link_error(self, errback: SignatureT) -> SignatureT:
         """Add callback task to be applied on error in task execution.
 
         Returns:
             Signature: the argument passed, for chaining
                 or use with :func:`~functools.reduce`.
         """
-        return self.append_to_list_option('link_error', errback)
+        return cast(SignatureT,
+                    self.append_to_list_option('link_error', errback))
 
-    def on_error(self, errback):
+    def on_error(self, errback: SignatureT) -> SignatureT:
         """Version of :meth:`link_error` that supports chaining.
 
         on_error chains the original signature, not the errback so::
 
@@ -350,7 +385,7 @@ class Signature(dict):
         self.link_error(errback)
         return self
 
-    def flatten_links(self):
+    def flatten_links(self) -> List[SignatureT]:
         """Return a recursive list of dependencies.
         """Return a recursive list of dependencies.
 
 
         "unchain" if you will, but with links intact.
         "unchain" if you will, but with links intact.
@@ -361,7 +396,7 @@ class Signature(dict):
                 for link in maybe_list(self.options.get('link')) or [])
                 for link in maybe_list(self.options.get('link')) or [])
         )))
         )))
 
 
-    def __or__(self, other):
+    def __or__(self, other: SignatureT) -> SignatureT:
         # These could be implemented in each individual class,
         # These could be implemented in each individual class,
         # I'm sure, but for now we have this.
         # I'm sure, but for now we have this.
         if isinstance(other, chord) and len(other.tasks) == 1:
         if isinstance(other, chord) and len(other.tasks) == 1:
@@ -431,7 +466,7 @@ class Signature(dict):
             return _chain(self, other, app=self._app)
             return _chain(self, other, app=self._app)
         return NotImplemented
         return NotImplemented
 
 
-    def election(self):
+    def election(self) -> ResultT:
         type = self.type
         type = self.type
         app = type.app
         app = type.app
         tid = self.options.get('task_id') or uuid()
         tid = self.options.get('task_id') or uuid()
@@ -442,50 +477,50 @@ class Signature(dict):
                                  connection=P.connection)
                                  connection=P.connection)
             return type.AsyncResult(tid)
             return type.AsyncResult(tid)
 
 
-    def reprcall(self, *args, **kwargs):
+    def reprcall(self, *args, **kwargs) -> str:
         args, kwargs, _ = self._merge(args, kwargs, {}, force=True)
         args, kwargs, _ = self._merge(args, kwargs, {}, force=True)
         return reprcall(self['task'], args, kwargs)
         return reprcall(self['task'], args, kwargs)
 
 
-    def __deepcopy__(self, memo):
+    def __deepcopy__(self, memo) -> SignatureT:
         memo[id(self)] = self
         memo[id(self)] = self
         return dict(self)
         return dict(self)
 
 
-    def __invert__(self):
+    def __invert__(self) -> Any:
         return self.apply_async().get()
         return self.apply_async().get()
 
 
-    def __reduce__(self):
+    def __reduce__(self) -> Tuple:
         # for serialization, the task type is lazily loaded,
         # for serialization, the task type is lazily loaded,
         # and not stored in the dict itself.
         # and not stored in the dict itself.
         return signature, (dict(self),)
         return signature, (dict(self),)
 
 
-    def __json__(self):
+    def __json__(self) -> Any:
         return dict(self)
         return dict(self)
 
 
-    def __repr__(self):
+    def __repr__(self) -> str:
         return self.reprcall()
         return self.reprcall()
 
 
     @property
     @property
-    def name(self):
+    def name(self) -> str:
         # for duck typing compatibility with Task.name
         # for duck typing compatibility with Task.name
         return self.task
         return self.task
 
 
     @cached_property
     @cached_property
-    def type(self):
+    def type(self) -> TaskT:
         return self._type or self.app.tasks[self['task']]
         return self._type or self.app.tasks[self['task']]
 
 
     @cached_property
     @cached_property
-    def app(self):
+    def app(self) -> AppT:
         return self._app or current_app
         return self._app or current_app
 
 
     @cached_property
     @cached_property
-    def AsyncResult(self):
+    def AsyncResult(self) -> type:
         try:
         try:
             return self.type.AsyncResult
             return self.type.AsyncResult
         except KeyError:  # task not registered
         except KeyError:  # task not registered
             return self.app.AsyncResult
             return self.app.AsyncResult
 
 
     @cached_property
     @cached_property
-    def _apply_async(self):
+    def _apply_async(self) -> Callable:
         try:
         try:
             return self.type.apply_async
             return self.type.apply_async
         except KeyError:
         except KeyError:
@@ -509,7 +544,7 @@ class _chain(Signature):
     tasks = getitem_property('kwargs.tasks', 'Tasks in chain.')
     tasks = getitem_property('kwargs.tasks', 'Tasks in chain.')
 
 
     @classmethod
     @classmethod
-    def from_dict(cls, d, app=None):
+    def from_dict(cls, d: Mapping, app: AppT = None) -> SignatureT:
         tasks = d['kwargs']['tasks']
         tasks = d['kwargs']['tasks']
         if tasks:
         if tasks:
             if isinstance(tasks, tuple):  # aaaargh
             if isinstance(tasks, tuple):  # aaaargh
@@ -518,7 +553,7 @@ class _chain(Signature):
             tasks[0] = maybe_signature(tasks[0], app=app)
             tasks[0] = maybe_signature(tasks[0], app=app)
         return _upgrade(d, _chain(tasks, app=app, **d['options']))
         return _upgrade(d, _chain(tasks, app=app, **d['options']))
 
 
-    def __init__(self, *tasks, **options):
+    def __init__(self, *tasks, **options) -> None:
         tasks = (regen(tasks[0]) if len(tasks) == 1 and is_list(tasks[0])
         tasks = (regen(tasks[0]) if len(tasks) == 1 and is_list(tasks[0])
                  else tasks)
                  else tasks)
         Signature.__init__(
         Signature.__init__(
@@ -528,11 +563,11 @@ class _chain(Signature):
         self.subtask_type = 'chain'
         self.subtask_type = 'chain'
         self._frozen = None
         self._frozen = None
 
 
-    def __call__(self, *args, **kwargs):
+    def __call__(self, *args, **kwargs) -> Any:
         if self.tasks:
         if self.tasks:
             return self.apply_async(args, kwargs)
             return self.apply_async(args, kwargs)
 
 
-    def clone(self, *args, **kwargs):
+    def clone(self, *args, **kwargs) -> SignatureT:
         to_signature = maybe_signature
         to_signature = maybe_signature
         s = Signature.clone(self, *args, **kwargs)
         s = Signature.clone(self, *args, **kwargs)
         s.kwargs['tasks'] = [
         s.kwargs['tasks'] = [
@@ -541,7 +576,10 @@ class _chain(Signature):
         ]
         ]
         return s
         return s
 
 
-    def apply_async(self, args=(), kwargs={}, **options):
+    def apply_async(self,
+                    args: Sequence = (),
+                    kwargs: Mapping = {},
+                    **options) -> ResultT:
         # python is best at unpacking kwargs, so .run is here to do that.
         # python is best at unpacking kwargs, so .run is here to do that.
         app = self.app
         app = self.app
         if app.conf.task_always_eager:
         if app.conf.task_always_eager:
@@ -549,9 +587,19 @@ class _chain(Signature):
         return self.run(args, kwargs, app=app, **(
         return self.run(args, kwargs, app=app, **(
             dict(self.options, **options) if options else self.options))
             dict(self.options, **options) if options else self.options))
 
 
-    def run(self, args=(), kwargs={}, group_id=None, chord=None,
-            task_id=None, link=None, link_error=None,
-            producer=None, root_id=None, parent_id=None, app=None, **options):
+    def run(self,
+            args: Sequence = (),
+            kwargs: Mapping = {},
+            group_id: str = None,
+            chord: SignatureT = None,
+            task_id: str = None,
+            link: Sequence[SignatureT] = None,
+            link_error: Sequence[SignatureT] = None,
+            producer: ProducerT = None,
+            root_id: str = None,
+            parent_id: str = None,
+            app: AppT = None,
+            **options) -> ResultT:
         # pylint: disable=redefined-outer-name
         # pylint: disable=redefined-outer-name
         #   XXX chord is also a class in outer scope.
         #   XXX chord is also a class in outer scope.
         app = app or self.app
         app = app or self.app
@@ -580,8 +628,13 @@ class _chain(Signature):
             first_task.apply_async(**options)
             first_task.apply_async(**options)
             return results[0]
             return results[0]
 
 
-    def freeze(self, _id=None, group_id=None, chord=None,
-               root_id=None, parent_id=None):
+    def freeze(self,
+               _id: str = None,
+               *,
+               group_id: str = None,
+               chord: SignatureT = None,
+               root_id: str = None,
+               parent_id: str = None) -> ResultT:
         # pylint: disable=redefined-outer-name
         # pylint: disable=redefined-outer-name
         #   XXX chord is also a class in outer scope.
         #   XXX chord is also a class in outer scope.
         _, results = self._frozen = self.prepare_steps(
         _, results = self._frozen = self.prepare_steps(
@@ -590,10 +643,17 @@ class _chain(Signature):
         )
         )
         return results[0]
         return results[0]
 
 
-    def prepare_steps(self, args, tasks,
-                      root_id=None, parent_id=None, link_error=None, app=None,
-                      last_task_id=None, group_id=None, chord_body=None,
-                      clone=True, from_dict=Signature.from_dict):
+    def prepare_steps(self, args: Sequence, tasks: Sequence[SignatureT],
+                      root_id: str = None,
+                      parent_id: str = None,
+                      link_error: Sequence[SignatureT] = None,
+                      app: AppT = None,
+                      last_task_id: str = None,
+                      group_id: str = None,
+                      chord_body: SignatureT = None,
+                      clone: bool = True,
+                      from_dict: Callable = Signature.from_dict
+                      ) -> Tuple[Sequence[SignatureT], Sequence[ResultT]]:
         app = app or self.app
         app = app or self.app
         # use chain message field for protocol 2 and later.
         # use chain message field for protocol 2 and later.
         # this avoids pickle blowing the stack on the recursion
         # this avoids pickle blowing the stack on the recursion
@@ -693,7 +753,10 @@ class _chain(Signature):
                 prev_res = node
                 prev_res = node
         return tasks, results
         return tasks, results
 
 
-    def apply(self, args=(), kwargs={}, **options):
+    def apply(self,
+              args: Sequence = (),
+              kwargs: Mapping = {},
+              **options) -> ResultT:
         last, fargs = None, args
         last, fargs = None, args
         for task in self.tasks:
         for task in self.tasks:
             res = task.clone(fargs).apply(
             res = task.clone(fargs).apply(
@@ -702,7 +765,7 @@ class _chain(Signature):
         return last
         return last
 
 
     @property
     @property
-    def app(self):
+    def app(self) -> AppT:
         app = self._app
         app = self._app
         if app is None:
         if app is None:
             try:
             try:
@@ -711,7 +774,7 @@ class _chain(Signature):
                 pass
                 pass
         return app or current_app
         return app or current_app
 
 
-    def __repr__(self):
+    def __repr__(self) -> str:
         if not self.tasks:
         if not self.tasks:
             return '<{0}@{1:#x}: empty>'.format(
             return '<{0}@{1:#x}: empty>'.format(
                 type(self).__name__, id(self))
                 type(self).__name__, id(self))
@@ -769,7 +832,7 @@ class chain(_chain):
     """
     """
 
 
     # could be function, but must be able to reference as :class:`chain`.
     # could be function, but must be able to reference as :class:`chain`.
-    def __new__(cls, *tasks, **kwargs):
+    def __new__(cls, *tasks, **kwargs) -> SignatureT:
         # This forces `chain(X, Y, Z)` to work the same way as `X | Y | Z`
         # This forces `chain(X, Y, Z)` to work the same way as `X | Y | Z`
         if not kwargs and tasks:
         if not kwargs and tasks:
             if len(tasks) == 1 and is_list(tasks[0]):
             if len(tasks) == 1 and is_list(tasks[0]):
@@ -784,18 +847,21 @@ class _basemap(Signature):
     _unpack_args = itemgetter('task', 'it')
     _unpack_args = itemgetter('task', 'it')
 
 
     @classmethod
     @classmethod
-    def from_dict(cls, d, app=None):
+    def from_dict(cls, d: Mapping, app: AppT = None) -> SignatureT:
         return _upgrade(
         return _upgrade(
             d, cls(*cls._unpack_args(d['kwargs']), app=app, **d['options']),
             d, cls(*cls._unpack_args(d['kwargs']), app=app, **d['options']),
         )
         )
 
 
-    def __init__(self, task, it, **options):
+    def __init__(self, task: str, it: Iterable, **options) -> None:
         Signature.__init__(
         Signature.__init__(
             self, self._task_name, (),
             self, self._task_name, (),
             {'task': task, 'it': regen(it)}, immutable=True, **options
             {'task': task, 'it': regen(it)}, immutable=True, **options
         )
         )
 
 
-    def apply_async(self, args=(), kwargs={}, **opts):
+    def apply_async(self,
+                    args: Sequence = (),
+                    kwargs: Mapping = {},
+                    **opts) -> ResultT:
         # need to evaluate generators
         # need to evaluate generators
         task, it = self._unpack_args(self.kwargs)
         task, it = self._unpack_args(self.kwargs)
         return self.type.apply_async(
         return self.type.apply_async(
@@ -815,7 +881,7 @@ class xmap(_basemap):
 
 
     _task_name = 'celery.map'
     _task_name = 'celery.map'
 
 
-    def __repr__(self):
+    def __repr__(self) -> str:
         task, it = self._unpack_args(self.kwargs)
         task, it = self._unpack_args(self.kwargs)
         return '[{0}(x) for x in {1}]'.format(
         return '[{0}(x) for x in {1}]'.format(
             task.task, truncate(repr(it), 100))
             task.task, truncate(repr(it), 100))
@@ -827,7 +893,7 @@ class xstarmap(_basemap):
 
 
     _task_name = 'celery.starmap'
     _task_name = 'celery.starmap'
 
 
-    def __repr__(self):
+    def __repr__(self) -> str:
         task, it = self._unpack_args(self.kwargs)
         task, it = self._unpack_args(self.kwargs)
         return '[{0}(*x) for x in {1}]'.format(
         return '[{0}(*x) for x in {1}]'.format(
             task.task, truncate(repr(it), 100))
             task.task, truncate(repr(it), 100))
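For orientation, a brief usage sketch (assuming an example task `add` defined elsewhere; `xstarmap` is the primitive behind `Task.starmap`), showing the repr rendered by the method above:

    sig = add.starmap(zip(range(10), range(10)))
    repr(sig)  # rendered by __repr__ above, e.g. '[tasks.add(*x) for x in <...>]'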
@@ -840,29 +906,32 @@ class chunks(Signature):
     _unpack_args = itemgetter('task', 'it', 'n')
     _unpack_args = itemgetter('task', 'it', 'n')
 
 
     @classmethod
     @classmethod
-    def from_dict(cls, d, app=None):
+    def from_dict(cls, d: Mapping, app: AppT = None) -> SignatureT:
         return _upgrade(
         return _upgrade(
             d, chunks(*cls._unpack_args(
             d, chunks(*cls._unpack_args(
                 d['kwargs']), app=app, **d['options']),
                 d['kwargs']), app=app, **d['options']),
         )
         )
 
 
-    def __init__(self, task, it, n, **options):
+    def __init__(self, task: str, it: Iterable, n: int, **options) -> None:
         Signature.__init__(
         Signature.__init__(
             self, 'celery.chunks', (),
             self, 'celery.chunks', (),
             {'task': task, 'it': regen(it), 'n': n},
             {'task': task, 'it': regen(it), 'n': n},
             immutable=True, **options
             immutable=True, **options
         )
         )
 
 
-    def __call__(self, **options):
+    def __call__(self, **options) -> ResultT:
         return self.apply_async(**options)
         return self.apply_async(**options)
 
 
-    def apply_async(self, args=(), kwargs={}, **opts):
+    def apply_async(self,
+                    args: Sequence = (),
+                    kwargs: Mapping = {},
+                    **opts) -> ResultT:
         return self.group().apply_async(
         return self.group().apply_async(
             args, kwargs,
             args, kwargs,
             route_name=task_name_from(self.kwargs.get('task')), **opts
             route_name=task_name_from(self.kwargs.get('task')), **opts
         )
         )
 
 
-    def group(self):
+    def group(self) -> SignatureT:
         # need to evaluate generators
         # need to evaluate generators
         task, it, n = self._unpack_args(self.kwargs)
         task, it, n = self._unpack_args(self.kwargs)
         return group((xstarmap(task, part, app=self._app)
         return group((xstarmap(task, part, app=self._app)
@@ -870,11 +939,12 @@ class chunks(Signature):
                      app=self._app)
                      app=self._app)
 
 
     @classmethod
     @classmethod
-    def apply_chunks(cls, task, it, n, app=None):
+    def apply_chunks(cls, task: str, it: Iterable, n: int,
+                     app: AppT = None) -> ResultT:
         return cls(task, it, n, app=app)()
         return cls(task, it, n, app=app)()
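A minimal usage sketch for context (assumes a hypothetical task `add`): `chunks` splits an iterable of argument tuples into parts of `n` items, each part handled by a single `xstarmap` invocation:

    res = add.chunks(zip(range(100), range(100)), 10)()
    res.get()  # ten lists of ten results each: [[0, 2, 4, ...], ...]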
 
 
 
 
-def _maybe_group(tasks, app):
+def _maybe_group(tasks: Any, app: AppT) -> Sequence[SignatureT]:
     if isinstance(tasks, dict):
     if isinstance(tasks, dict):
         tasks = signature(tasks, app=app)
         tasks = signature(tasks, app=app)
 
 
@@ -921,12 +991,12 @@ class group(Signature):
     tasks = getitem_property('kwargs.tasks', 'Tasks in group.')
     tasks = getitem_property('kwargs.tasks', 'Tasks in group.')
 
 
     @classmethod
     @classmethod
-    def from_dict(cls, d, app=None):
+    def from_dict(cls, d: Mapping, app: AppT = None) -> SignatureT:
         return _upgrade(
         return _upgrade(
             d, group(d['kwargs']['tasks'], app=app, **d['options']),
             d, group(d['kwargs']['tasks'], app=app, **d['options']),
         )
         )
 
 
-    def __init__(self, *tasks, **options):
+    def __init__(self, *tasks, **options) -> None:
         if len(tasks) == 1:
         if len(tasks) == 1:
             tasks = tasks[0]
             tasks = tasks[0]
             if isinstance(tasks, group):
             if isinstance(tasks, group):
@@ -938,17 +1008,27 @@ class group(Signature):
         )
         )
         self.subtask_type = 'group'
         self.subtask_type = 'group'
 
 
-    def __call__(self, *partial_args, **options):
+    def __call__(self, *partial_args, **options) -> ResultT:
         return self.apply_async(partial_args, **options)
         return self.apply_async(partial_args, **options)
 
 
-    def skew(self, start=1.0, stop=None, step=1.0):
+    def skew(self,
+             start: float = 1.0,
+             stop: float = None,
+             step: float = 1.0) -> SignatureT:
         it = fxrange(start, stop, step, repeatlast=True)
         it = fxrange(start, stop, step, repeatlast=True)
         for task in self.tasks:
         for task in self.tasks:
             task.set(countdown=next(it))
             task.set(countdown=next(it))
         return self
         return self
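A sketch of what `skew` does (hypothetical `add` task): it spreads the `countdown` of the group's tasks over `fxrange(start, stop, step)`, so the tasks fire staggered rather than all at once:

    g = group(add.s(i, i) for i in range(10))
    g.skew(start=1, stop=10)()  # first task after ~1s, last after ~10s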
 
 
-    def apply_async(self, args=(), kwargs=None, add_to_parent=True,
-                    producer=None, link=None, link_error=None, **options):
+    def apply_async(self,
+                    args: Sequence = (),
+                    kwargs: Mapping = None,
+                    *,
+                    add_to_parent: bool = True,
+                    producer: ProducerT = None,
+                    link: Sequence[SignatureT] = None,
+                    link_error: Sequence[SignatureT] = None,
+                    **options) -> ResultT:
         if link is not None:
         if link is not None:
             raise TypeError('Cannot add link to group: use a chord')
             raise TypeError('Cannot add link to group: use a chord')
         if link_error is not None:
         if link_error is not None:
@@ -982,7 +1062,10 @@ class group(Signature):
             parent_task.add_trail(result)
             parent_task.add_trail(result)
         return result
         return result
 
 
-    def apply(self, args=(), kwargs={}, **options):
+    def apply(self,
+              args: Sequence = (),
+              kwargs: Mapping = {},
+              **options) -> ResultT:
         app = self.app
         app = self.app
         if not self.tasks:
         if not self.tasks:
             return self.freeze()  # empty group returns GroupResult
             return self.freeze()  # empty group returns GroupResult
@@ -992,23 +1075,30 @@ class group(Signature):
             sig.apply(args=args, kwargs=kwargs, **options) for sig, _ in tasks
             sig.apply(args=args, kwargs=kwargs, **options) for sig, _ in tasks
         ])
         ])
 
 
-    def set_immutable(self, immutable):
+    def set_immutable(self, immutable: bool) -> None:
         for task in self.tasks:
         for task in self.tasks:
             task.set_immutable(immutable)
             task.set_immutable(immutable)
 
 
-    def link(self, sig):
+    def link(self, sig: SignatureT) -> SignatureT:
         # Simply link to first task
         # Simply link to first task
         sig = sig.clone().set(immutable=True)
         sig = sig.clone().set(immutable=True)
         return self.tasks[0].link(sig)
         return self.tasks[0].link(sig)
 
 
-    def link_error(self, sig):
+    def link_error(self, sig: SignatureT) -> SignatureT:
         sig = sig.clone().set(immutable=True)
         sig = sig.clone().set(immutable=True)
         return self.tasks[0].link_error(sig)
         return self.tasks[0].link_error(sig)
 
 
-    def _prepared(self, tasks, partial_args, group_id, root_id, app,
-                  CallableSignature=abstract.CallableSignature,
-                  from_dict=Signature.from_dict,
-                  isinstance=isinstance, tuple=tuple):
+    def _prepared(self, tasks: Sequence[SignatureT],
+                  partial_args: Sequence,
+                  group_id: str,
+                  root_id: str,
+                  app: AppT,
+                  *,
+                  CallableSignature: Callable = abstract.CallableSignature,
+                  from_dict: Callable = Signature.from_dict,
+                  isinstance: Callable = isinstance,
+                  tuple: Callable = tuple
+                  ) -> Iterable[Tuple[SignatureT, ResultT]]:
         for task in tasks:
         for task in tasks:
             if isinstance(task, CallableSignature):
             if isinstance(task, CallableSignature):
                 # local sigs are always of type Signature, and we
                 # local sigs are always of type Signature, and we
@@ -1029,9 +1119,15 @@ class group(Signature):
                     task.args = tuple(partial_args) + tuple(task.args)
                     task.args = tuple(partial_args) + tuple(task.args)
                 yield task, task.freeze(group_id=group_id, root_id=root_id)
                 yield task, task.freeze(group_id=group_id, root_id=root_id)
 
 
-    def _apply_tasks(self, tasks, producer=None, app=None, p=None,
-                     add_to_parent=None, chord=None,
-                     args=None, kwargs=None, **options):
+    def _apply_tasks(self, tasks: Sequence[SignatureT],
+                     producer: ProducerT = None,
+                     app: AppT = None,
+                     p: Thenable = None,
+                     add_to_parent: bool = None,
+                     chord: SignatureT = None,
+                     args: Sequence = None,
+                     kwargs: Mapping = None,
+                     **options) -> Iterable[ResultT]:
         # pylint: disable=redefined-outer-name
         # pylint: disable=redefined-outer-name
         #   XXX chord is also a class in outer scope.
         #   XXX chord is also a class in outer scope.
         app = app or self.app
         app = app or self.app
@@ -1053,7 +1149,7 @@ class group(Signature):
                     res.then(p, weak=True)
                     res.then(p, weak=True)
                 yield res  # <-- r.parent, etc set in the frozen result.
                 yield res  # <-- r.parent, etc set in the frozen result.
 
 
-    def _freeze_gid(self, options):
+    def _freeze_gid(self, options: Mapping) -> Tuple[Mapping, str, str]:
         # remove task_id and use that as the group_id,
         # remove task_id and use that as the group_id,
         # if we don't remove it then every task will have the same id...
         # if we don't remove it then every task will have the same id...
         options = dict(self.options, **options)
         options = dict(self.options, **options)
@@ -1061,8 +1157,13 @@ class group(Signature):
             options.pop('task_id', uuid()))
             options.pop('task_id', uuid()))
         return options, group_id, options.get('root_id')
         return options, group_id, options.get('root_id')
 
 
-    def freeze(self, _id=None, group_id=None, chord=None,
-               root_id=None, parent_id=None):
+    def freeze(self,
+               _id: str = None,
+               *,
+               group_id: str = None,
+               chord: SignatureT = None,
+               root_id: str = None,
+               parent_id: str = None) -> ResultT:
         # pylint: disable=redefined-outer-name
         # pylint: disable=redefined-outer-name
         #   XXX chord is also a class in outer scope.
         #   XXX chord is also a class in outer scope.
         opts = self.options
         opts = self.options
@@ -1089,7 +1190,9 @@ class group(Signature):
         return self.app.GroupResult(gid, results)
         return self.app.GroupResult(gid, results)
     _freeze = freeze
     _freeze = freeze
 
 
-    def _freeze_unroll(self, new_tasks, group_id, chord, root_id, parent_id):
+    def _freeze_unroll(self, new_tasks: Sequence[SignatureT],
+                       group_id: str, chord: SignatureT,
+                       root_id: str, parent_id: str) -> Iterator[ResultT]:
         # pylint: disable=redefined-outer-name
         # pylint: disable=redefined-outer-name
         #   XXX chord is also a class in outer scope.
         #   XXX chord is also a class in outer scope.
         stack = deque(self.tasks)
         stack = deque(self.tasks)
@@ -1103,18 +1206,18 @@ class group(Signature):
                                   chord=chord, root_id=root_id,
                                   chord=chord, root_id=root_id,
                                   parent_id=parent_id)
                                   parent_id=parent_id)
 
 
-    def __repr__(self):
+    def __repr__(self) -> str:
         if self.tasks:
         if self.tasks:
             return remove_repeating_from_task(
             return remove_repeating_from_task(
                 self.tasks[0]['task'],
                 self.tasks[0]['task'],
                 'group({0.tasks!r})'.format(self))
                 'group({0.tasks!r})'.format(self))
         return 'group(<empty>)'
         return 'group(<empty>)'
 
 
-    def __len__(self):
+    def __len__(self) -> int:
         return len(self.tasks)
         return len(self.tasks)
 
 
     @property
     @property
-    def app(self):
+    def app(self) -> AppT:
         app = self._app
         app = self._app
         if app is None:
         if app is None:
             try:
             try:
@@ -1153,18 +1256,25 @@ class chord(Signature):
     """
     """
 
 
     @classmethod
     @classmethod
-    def from_dict(cls, d, app=None):
+    def from_dict(cls, d: Mapping, app: AppT = None) -> SignatureT:
         args, d['kwargs'] = cls._unpack_args(**d['kwargs'])
         args, d['kwargs'] = cls._unpack_args(**d['kwargs'])
         return _upgrade(d, cls(*args, app=app, **d))
         return _upgrade(d, cls(*args, app=app, **d))
 
 
     @staticmethod
     @staticmethod
-    def _unpack_args(header=None, body=None, **kwargs):
+    def _unpack_args(header: Sequence[SignatureT] = None,
+                     body: SignatureT = None,
+                     **kwargs) -> Tuple[Tuple, Mapping]:
         # Python signatures are better at extracting keys from dicts
         # Python signatures are better at extracting keys from dicts
         # than manually popping things off.
         # than manually popping things off.
         return (header, body), kwargs
         return (header, body), kwargs
 
 
-    def __init__(self, header, body=None, task='celery.chord',
-                 args=(), kwargs={}, app=None, **options):
+    def __init__(self, header: Sequence[SignatureT],
+                 body: SignatureT = None,
+                 task: str = 'celery.chord',
+                 args: Sequence = (),
+                 kwargs: Mapping = {},
+                 app: AppT = None,
+                 **options) -> None:
         Signature.__init__(
         Signature.__init__(
             self, task, args,
             self, task, args,
             dict(kwargs=kwargs, header=_maybe_group(header, app),
             dict(kwargs=kwargs, header=_maybe_group(header, app),
@@ -1172,11 +1282,16 @@ class chord(Signature):
         )
         )
         self.subtask_type = 'chord'
         self.subtask_type = 'chord'
 
 
-    def __call__(self, body=None, **options):
+    def __call__(self, body: SignatureT = None, **options) -> ResultT:
         return self.apply_async((), {'body': body} if body else {}, **options)
         return self.apply_async((), {'body': body} if body else {}, **options)
 
 
-    def freeze(self, _id=None, group_id=None, chord=None,
-               root_id=None, parent_id=None):
+    def freeze(self,
+               _id: str = None,
+               *,
+               group_id: str = None,
+               chord: SignatureT = None,
+               root_id: str = None,
+               parent_id: str = None) -> ResultT:
         # pylint: disable=redefined-outer-name
         # pylint: disable=redefined-outer-name
         #   XXX chord is also a class in outer scope.
         #   XXX chord is also a class in outer scope.
         if not isinstance(self.tasks, group):
         if not isinstance(self.tasks, group):
@@ -1200,9 +1315,15 @@ class chord(Signature):
         self.id = self.tasks.id
         self.id = self.tasks.id
         return bodyres
         return bodyres
 
 
-    def apply_async(self, args=(), kwargs={}, task_id=None,
-                    producer=None, connection=None,
-                    router=None, result_cls=None, **options):
+    def apply_async(self,
+                    args: Sequence = (),
+                    kwargs: Mapping = {},
+                    task_id: str = None,
+                    producer: ProducerT = None,
+                    connection: ConnectionT = None,
+                    router: RouterT = None,
+                    result_cls: type = None,
+                    **options) -> ResultT:
         kwargs = kwargs or {}
         kwargs = kwargs or {}
         args = (tuple(args) + tuple(self.args)
         args = (tuple(args) + tuple(self.args)
                 if args and not self.immutable else self.args)
                 if args and not self.immutable else self.args)
@@ -1223,7 +1344,10 @@ class chord(Signature):
         # chord([A, B, ...], C)
         # chord([A, B, ...], C)
         return self.run(tasks, body, args, task_id=task_id, **options)
         return self.run(tasks, body, args, task_id=task_id, **options)
 
 
-    def apply(self, args=(), kwargs={}, propagate=True, body=None, **options):
+    def apply(self, args: Sequence = (), kwargs: Mapping = {},
+              propagate: bool = True,
+              body: SignatureT = None,
+              **options) -> ResultT:
         body = self.body if body is None else body
         body = self.body if body is None else body
         tasks = (self.tasks.clone() if isinstance(self.tasks, group)
         tasks = (self.tasks.clone() if isinstance(self.tasks, group)
                  else group(self.tasks, app=self.app))
                  else group(self.tasks, app=self.app))
@@ -1231,7 +1355,8 @@ class chord(Signature):
             args=(tasks.apply(args, kwargs).get(propagate=propagate),),
             args=(tasks.apply(args, kwargs).get(propagate=propagate),),
         )
         )
 
 
-    def _traverse_tasks(self, tasks, value=None):
+    def _traverse_tasks(self, tasks: Sequence[SignatureT],
+                        value: Any = None) -> Iterator[SignatureT]:
         stack = deque(list(tasks))
         stack = deque(list(tasks))
         while stack:
         while stack:
             task = stack.popleft()
             task = stack.popleft()
@@ -1240,12 +1365,14 @@ class chord(Signature):
             else:
             else:
                 yield task if value is None else value
                 yield task if value is None else value
 
 
-    def __length_hint__(self):
+    def __length_hint__(self) -> int:
         return sum(self._traverse_tasks(self.tasks, 1))
         return sum(self._traverse_tasks(self.tasks, 1))
 
 
-    def run(self, header, body, partial_args, app=None, interval=None,
-            countdown=1, max_retries=None, eager=False,
-            task_id=None, **options):
+    def run(self, header: SignatureT, body: SignatureT, partial_args: Sequence,
+            app: AppT = None, interval: float = None,
+            countdown: float = 1, max_retries: int = None,
+            eager: bool = False,
+            task_id: str = None, **options) -> ResultT:
         app = app or self._get_app(body)
         app = app or self._get_app(body)
         group_id = header.options.get('task_id') or uuid()
         group_id = header.options.get('task_id') or uuid()
         root_id = body.options.get('root_id')
         root_id = body.options.get('root_id')
@@ -1267,7 +1394,7 @@ class chord(Signature):
         bodyres.parent = parent
         bodyres.parent = parent
         return bodyres
         return bodyres
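For context, the canonical chord round trip (a hedged sketch: `add` and `tsum` are assumed example tasks, `tsum` summing a sequence): the header group runs in parallel, then the body receives the list of header results:

    res = chord((add.s(i, i) for i in range(10)), tsum.s())()
    res.get()  # -> 90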
 
 
-    def clone(self, *args, **kwargs):
+    def clone(self, *args, **kwargs) -> SignatureT:
         s = Signature.clone(self, *args, **kwargs)
         s = Signature.clone(self, *args, **kwargs)
         # need to make copy of body
         # need to make copy of body
         try:
         try:
@@ -1276,20 +1403,20 @@ class chord(Signature):
             pass
             pass
         return s
         return s
 
 
-    def link(self, callback):
+    def link(self, callback: SignatureT) -> SignatureT:
         self.body.link(callback)
         self.body.link(callback)
         return callback
         return callback
 
 
-    def link_error(self, errback):
+    def link_error(self, errback: SignatureT) -> SignatureT:
         self.body.link_error(errback)
         self.body.link_error(errback)
         return errback
         return errback
 
 
-    def set_immutable(self, immutable):
+    def set_immutable(self, immutable: bool) -> None:
         # changes mutability of header only, not callback.
         # changes mutability of header only, not callback.
         for task in self.tasks:
         for task in self.tasks:
             task.set_immutable(immutable)
             task.set_immutable(immutable)
 
 
-    def __repr__(self):
+    def __repr__(self) -> str:
         if self.body:
         if self.body:
             if isinstance(self.body, _chain):
             if isinstance(self.body, _chain):
                 return remove_repeating_from_task(
                 return remove_repeating_from_task(
@@ -1304,10 +1431,10 @@ class chord(Signature):
         return '<chord without body: {0.tasks!r}>'.format(self)
         return '<chord without body: {0.tasks!r}>'.format(self)
 
 
     @cached_property
     @cached_property
-    def app(self):
+    def app(self) -> AppT:
         return self._get_app(self.body)
         return self._get_app(self.body)
 
 
-    def _get_app(self, body=None):
+    def _get_app(self, body: SignatureT = None) -> AppT:
         app = self._app
         app = self._app
         if app is None:
         if app is None:
             try:
             try:
@@ -1323,7 +1450,7 @@ class chord(Signature):
     body = getitem_property('kwargs.body', 'Body task of chord.')
     body = getitem_property('kwargs.body', 'Body task of chord.')
 
 
 
 
-def signature(varies, *args, **kwargs):
+def signature(varies: SignatureT, *args, **kwargs) -> SignatureT:
     """Create new signature.
     """Create new signature.
 
 
     - if the first argument is a signature already then it's cloned.
     - if the first argument is a signature already then it's cloned.
@@ -1340,11 +1467,13 @@ def signature(varies, *args, **kwargs):
     return Signature(varies, *args, **kwargs)
     return Signature(varies, *args, **kwargs)
 
 
 
 
-def maybe_signature(d, app=None, clone=False):
+def maybe_signature(d: Optional[SignatureT],
+                    app: AppT = None,
+                    clone: bool = False) -> Optional[SignatureT]:
     """Ensure obj is a signature, or None.
     """Ensure obj is a signature, or None.
 
 
     Arguments:
     Arguments:
-        d (Optional[Union[abstract.CallableSignature, Mapping]]):
+        d (Optional[SignatureT]):
             Signature or dict-serialized signature.
             Signature or dict-serialized signature.
         app (celery.Celery):
         app (celery.Celery):
             App to bind signature to.
             App to bind signature to.
@@ -1353,7 +1482,7 @@ def maybe_signature(d, app=None, clone=False):
            will be cloned when this flag is enabled.
            will be cloned when this flag is enabled.
 
 
     Returns:
     Returns:
-        Optional[abstract.CallableSignature]
+        Optional[SignatureT]
     """
     """
     if d is not None:
     if d is not None:
         if isinstance(d, abstract.CallableSignature):
         if isinstance(d, abstract.CallableSignature):
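A compact sketch of the three accepted inputs (task name is illustrative):

    s1 = signature('tasks.add', args=(2, 2), countdown=10)  # from a task name
    s2 = signature(s1)              # an existing signature is cloned
    s3 = signature(dict(s1))        # a dict-serialized signature is upgraded
    assert maybe_signature(None) is None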

+ 9 - 2
celery/concurrency/asynpool.py

@@ -23,11 +23,12 @@ import socket
 import struct
 import struct
 import time
 import time
 
 
-from collections import Counter, deque, namedtuple
+from collections import Counter, deque
 from io import BytesIO
 from io import BytesIO
 from numbers import Integral
 from numbers import Integral
 from pickle import HIGHEST_PROTOCOL
 from pickle import HIGHEST_PROTOCOL
 from time import sleep
 from time import sleep
+from typing import NamedTuple
 from weakref import WeakValueDictionary, ref
 from weakref import WeakValueDictionary, ref
 
 
 from billiard.pool import RUN, TERMINATE, ACK, NACK, WorkersJoined
 from billiard.pool import RUN, TERMINATE, ACK, NACK, WorkersJoined
@@ -91,7 +92,13 @@ SCHED_STRATEGIES = {
 }
 }
 SCHED_STRATEGY_TO_NAME = {v: k for k, v in SCHED_STRATEGIES.items()}
 SCHED_STRATEGY_TO_NAME = {v: k for k, v in SCHED_STRATEGIES.items()}
 
 
-Ack = namedtuple('Ack', ('id', 'fd', 'payload'))
+
+class Ack(NamedTuple):
+    """Ack message payload."""
+
+    id: int
+    fd: int
+    payload: bytes
 
 
 
 
 def gen_not_started(gen):
 def gen_not_started(gen):
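The `namedtuple` to `typing.NamedTuple` rewrite above is behavior-preserving; a quick sketch of the invariants that keep holding:

    ack = Ack(id=1, fd=42, payload=b'')
    job_id, fd, payload = ack        # still unpacks as a plain tuple
    assert ack[0] == ack.id == 1     # indexing and attribute access agree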

+ 15 - 26
celery/contrib/pytest.py

@@ -12,12 +12,11 @@ NO_WORKER = os.environ.get('NO_WORKER')
 
 
 
 
 @contextmanager
 @contextmanager
-def _create_app(request,
-                enable_logging=False,
-                use_trap=False,
-                parameters={},
-                **config):
-    # type: (Any, **Any) -> Celery
+def _create_app(request: Any,
+                enable_logging: bool = False,
+                use_trap: bool = False,
+                parameters: Mapping = {},
+                **config) -> Celery:
     """Utility context used to setup Celery app for pytest fixtures."""
     """Utility context used to setup Celery app for pytest fixtures."""
     test_app = TestApp(
     test_app = TestApp(
         set_as_current=False,
         set_as_current=False,
@@ -41,8 +40,7 @@ def _create_app(request,
 
 
 
 
 @pytest.fixture(scope='session')
 @pytest.fixture(scope='session')
-def use_celery_app_trap():
-    # type: () -> bool
+def use_celery_app_trap() -> bool:
     """You can override this fixture to enable the app trap.
     """You can override this fixture to enable the app trap.
 
 
     The app trap raises an exception whenever something attempts
     The app trap raises an exception whenever something attempts
@@ -56,8 +54,7 @@ def celery_session_app(request,
                        celery_config,
                        celery_config,
                        celery_parameters,
                        celery_parameters,
                        celery_enable_logging,
                        celery_enable_logging,
-                       use_celery_app_trap):
-    # type: (Any) -> Celery
+                       use_celery_app_trap) -> Celery:
     """Session Fixture: Return app for session fixtures."""
     """Session Fixture: Return app for session fixtures."""
     mark = request.node.get_marker('celery')
     mark = request.node.get_marker('celery')
     config = dict(celery_config, **mark.kwargs if mark else {})
     config = dict(celery_config, **mark.kwargs if mark else {})
@@ -77,8 +74,7 @@ def celery_session_worker(request,
                           celery_session_app,
                           celery_session_app,
                           celery_includes,
                           celery_includes,
                           celery_worker_pool,
                           celery_worker_pool,
-                          celery_worker_parameters):
-    # type: (Any, Celery, Sequence[str], str) -> WorkController
+                          celery_worker_parameters) -> WorkController:
     """Session Fixture: Start worker that lives throughout test suite."""
     """Session Fixture: Start worker that lives throughout test suite."""
     if not NO_WORKER:
     if not NO_WORKER:
         for module in celery_includes:
         for module in celery_includes:
@@ -90,15 +86,13 @@ def celery_session_worker(request,
 
 
 
 
 @pytest.fixture(scope='session')
 @pytest.fixture(scope='session')
-def celery_enable_logging():
-    # type: () -> bool
+def celery_enable_logging() -> bool:
     """You can override this fixture to enable logging."""
     """You can override this fixture to enable logging."""
     return False
     return False
 
 
 
 
 @pytest.fixture(scope='session')
 @pytest.fixture(scope='session')
-def celery_includes():
-    # type: () -> Sequence[str]
+def celery_includes() -> Sequence[str]:
     """You can override this include modules when a worker start.
     """You can override this include modules when a worker start.
 
 
     You can have this return a list of module names to import,
     You can have this return a list of module names to import,
@@ -108,8 +102,7 @@ def celery_includes():
 
 
 
 
 @pytest.fixture(scope='session')
 @pytest.fixture(scope='session')
-def celery_worker_pool():
-    # type: () -> Union[str, Any]
+def celery_worker_pool() -> Union[str, type]:
     """You can override this fixture to set the worker pool.
     """You can override this fixture to set the worker pool.
 
 
     The "solo" pool is used by default, but you can set this to
     The "solo" pool is used by default, but you can set this to
@@ -119,8 +112,7 @@ def celery_worker_pool():
 
 
 
 
 @pytest.fixture(scope='session')
 @pytest.fixture(scope='session')
-def celery_config():
-    # type: () -> Mapping[str, Any]
+def celery_config() -> Mapping[str, Any]:
     """Redefine this fixture to configure the test Celery app.
     """Redefine this fixture to configure the test Celery app.
 
 
     The config returned by your fixture will then be used
     The config returned by your fixture will then be used
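A hypothetical `conftest.py` override for illustration (the broker/backend URLs are placeholders):

    import pytest

    @pytest.fixture(scope='session')
    def celery_config():
        return {
            'broker_url': 'memory://',
            'result_backend': 'cache+memory://',
        }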
@@ -130,8 +122,7 @@ def celery_config():
 
 
 
 
 @pytest.fixture(scope='session')
 @pytest.fixture(scope='session')
-def celery_parameters():
-    # type: () -> Mapping[str, Any]
+def celery_parameters() -> Mapping[str, Any]:
     """Redefine this fixture to change the init parameters of test Celery app.
     """Redefine this fixture to change the init parameters of test Celery app.
 
 
     The dict returned by your fixture will then be used
     The dict returned by your fixture will then be used
@@ -141,8 +132,7 @@ def celery_parameters():
 
 
 
 
 @pytest.fixture(scope='session')
 @pytest.fixture(scope='session')
-def celery_worker_parameters():
-    # type: () -> Mapping[str, Any]
+def celery_worker_parameters() -> Mapping[str, Any]:
     """Redefine this fixture to change the init parameters of Celery workers.
     """Redefine this fixture to change the init parameters of Celery workers.
 
 
     This can be used e.g. to define queues the worker will consume tasks from.
     This can be used e.g. to define queues the worker will consume tasks from.
@@ -175,8 +165,7 @@ def celery_worker(request,
                   celery_app,
                   celery_app,
                   celery_includes,
                   celery_includes,
                   celery_worker_pool,
                   celery_worker_pool,
-                  celery_worker_parameters):
-    # type: (Any, Celery, Sequence[str], str) -> WorkController
+                  celery_worker_parameters) -> WorkController:
     """Fixture: Start worker in a thread, stop it when the test returns."""
     """Fixture: Start worker in a thread, stop it when the test returns."""
     if not NO_WORKER:
     if not NO_WORKER:
         for module in celery_includes:
         for module in celery_includes:
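A sketch of the intended use (task defined inline for brevity, following the documented fixture example; names are illustrative):

    def test_mul(celery_app, celery_worker):
        @celery_app.task
        def mul(x, y):
            return x * y

        assert mul.delay(4, 4).get(timeout=10) == 16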

+ 126 - 82
celery/events/state.py

@@ -23,6 +23,7 @@ from decimal import Decimal
 from itertools import islice
 from itertools import islice
 from operator import itemgetter
 from operator import itemgetter
 from time import time
 from time import time
+from typing import Any, Iterator, List, Mapping, Sequence, Set, Tuple
 from weakref import WeakSet, ref
 from weakref import WeakSet, ref
 
 
 from kombu.clocks import timetuple
 from kombu.clocks import timetuple
@@ -64,7 +65,7 @@ R_WORKER = '<Worker: {0.hostname} ({0.status_string} clock:{0.clock})'
 R_TASK = '<Task: {0.name}({0.uuid}) {0.state} clock:{0.clock}>'
 R_TASK = '<Task: {0.name}({0.uuid}) {0.state} clock:{0.clock}>'
 
 
 #: Mapping of task event names to task state.
 #: Mapping of task event names to task state.
-TASK_EVENT_TO_STATE = {
+TASK_EVENT_TO_STATE: Mapping[str, str] = {
     'sent': states.PENDING,
     'sent': states.PENDING,
     'received': states.RECEIVED,
     'received': states.RECEIVED,
     'started': states.STARTED,
     'started': states.STARTED,
@@ -92,26 +93,31 @@ class CallableDefaultdict(defaultdict):
         ...     'proj.tasks.add', reverse=True))
         ...     'proj.tasks.add', reverse=True))
     """
     """
 
 
-    def __init__(self, fun, *args, **kwargs):
+    def __init__(self, fun: Callable, *args, **kwargs) -> None:
         self.fun = fun
         self.fun = fun
         super().__init__(*args, **kwargs)
         super().__init__(*args, **kwargs)
 
 
-    def __call__(self, *args, **kwargs):
+    def __call__(self, *args, **kwargs) -> Any:
         return self.fun(*args, **kwargs)
         return self.fun(*args, **kwargs)
 Callable.register(CallableDefaultdict)  # noqa: E305
 Callable.register(CallableDefaultdict)  # noqa: E305
 
 
 
 
 @memoize(maxsize=1000, keyfun=lambda a, _: a[0])
 @memoize(maxsize=1000, keyfun=lambda a, _: a[0])
-def _warn_drift(hostname, drift, local_received, timestamp):
+def _warn_drift(hostname: str, drift: float,
+                local_received: float, timestamp: float) -> None:
     # we use memoize here so the warning is only logged once per hostname
     # we use memoize here so the warning is only logged once per hostname
     warn(DRIFT_WARNING, hostname, drift,
     warn(DRIFT_WARNING, hostname, drift,
          datetime.fromtimestamp(local_received),
          datetime.fromtimestamp(local_received),
          datetime.fromtimestamp(timestamp))
          datetime.fromtimestamp(timestamp))
 
 
 
 
-def heartbeat_expires(timestamp, freq=60,
-                      expire_window=HEARTBEAT_EXPIRE_WINDOW,
-                      Decimal=Decimal, float=float, isinstance=isinstance):
+def heartbeat_expires(timestamp: float,
+                      freq: float = 60.0,
+                      expire_window: float = HEARTBEAT_EXPIRE_WINDOW,
+                      *,
+                      Decimal: Callable = Decimal,
+                      float: Callable = float,
+                      isinstance: Callable = isinstance) -> float:
     """Return time when heartbeat expires."""
     """Return time when heartbeat expires."""
     # some json implementations return decimal.Decimal objects,
     # some json implementations return decimal.Decimal objects,
     # which aren't compatible with float.
     # which aren't compatible with float.
@@ -121,26 +127,26 @@ def heartbeat_expires(timestamp, freq=60,
     return timestamp + (freq * (expire_window / 1e2))
     return timestamp + (freq * (expire_window / 1e2))
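A worked example of the arithmetic: `expire_window` is a percentage of `freq`, so with the default window of 200 a worker is considered offline two heartbeat intervals after its last timestamp:

    assert heartbeat_expires(1000.0, freq=60.0, expire_window=200) == 1120.0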
 
 
 
 
-def _depickle_task(cls, fields):
+def _depickle_task(cls: type, fields: Mapping) -> 'Task':
     return cls(**fields)
     return cls(**fields)
 
 
 
 
-def with_unique_field(attr):
+def with_unique_field(attr: str) -> Callable:
 
 
-    def _decorate_cls(cls):
+    def _decorate_cls(cls: type) -> type:
 
 
-        def __eq__(this, other):
+        def __eq__(this, other: Any) -> bool:
             if isinstance(other, this.__class__):
             if isinstance(other, this.__class__):
                 return getattr(this, attr) == getattr(other, attr)
                 return getattr(this, attr) == getattr(other, attr)
             return NotImplemented
             return NotImplemented
         cls.__eq__ = __eq__
         cls.__eq__ = __eq__
 
 
-        def __ne__(this, other):
+        def __ne__(this, other: Any) -> bool:
             res = this.__eq__(other)
             res = this.__eq__(other)
             return True if res is NotImplemented else not res
             return True if res is NotImplemented else not res
         cls.__ne__ = __ne__
         cls.__ne__ = __ne__
 
 
-        def __hash__(this):
+        def __hash__(this: Any) -> int:
             return hash(getattr(this, attr))
             return hash(getattr(this, attr))
         cls.__hash__ = __hash__
         cls.__hash__ = __hash__
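The decorator derives identity from a single attribute; a minimal sketch (class name is hypothetical):

    @with_unique_field('hostname')
    class Node:
        def __init__(self, hostname):
            self.hostname = hostname

    assert Node('w1@host') == Node('w1@host')
    assert hash(Node('w1@host')) == hash('w1@host')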
 
 
@@ -161,9 +167,18 @@ class Worker:
     if not PYPY:  # pragma: no cover
     if not PYPY:  # pragma: no cover
         __slots__ = _fields + ('event', '__dict__', '__weakref__')
         __slots__ = _fields + ('event', '__dict__', '__weakref__')
 
 
-    def __init__(self, hostname=None, pid=None, freq=60,
-                 heartbeats=None, clock=0, active=None, processed=None,
-                 loadavg=None, sw_ident=None, sw_ver=None, sw_sys=None):
+    def __init__(self,
+                 hostname: str = None,
+                 pid: int = None,
+                 freq: float = 60.0,
+                 heartbeats: Sequence[float] = None,
+                 clock: int = 0,
+                 active: int = None,
+                 processed: int = None,
+                 loadavg: Tuple[float, float, float] = None,
+                 sw_ident: str = None,
+                 sw_ver: str = None,
+                 sw_sys: str = None) -> None:
         self.hostname = hostname
         self.hostname = hostname
         self.pid = pid
         self.pid = pid
         self.freq = freq
         self.freq = freq
@@ -177,23 +192,28 @@ class Worker:
         self.sw_sys = sw_sys
         self.sw_sys = sw_sys
         self.event = self._create_event_handler()
         self.event = self._create_event_handler()
 
 
-    def __reduce__(self):
+    def __reduce__(self) -> Tuple:
         return self.__class__, (self.hostname, self.pid, self.freq,
         return self.__class__, (self.hostname, self.pid, self.freq,
                                 self.heartbeats, self.clock, self.active,
                                 self.heartbeats, self.clock, self.active,
                                 self.processed, self.loadavg, self.sw_ident,
                                 self.processed, self.loadavg, self.sw_ident,
                                 self.sw_ver, self.sw_sys)
                                 self.sw_ver, self.sw_sys)
 
 
-    def _create_event_handler(self):
+    def _create_event_handler(self) -> Callable:
         _set = object.__setattr__
         _set = object.__setattr__
         hbmax = self.heartbeat_max
         hbmax = self.heartbeat_max
         heartbeats = self.heartbeats
         heartbeats = self.heartbeats
         hb_pop = self.heartbeats.pop
         hb_pop = self.heartbeats.pop
         hb_append = self.heartbeats.append
         hb_append = self.heartbeats.append
 
 
-        def event(type_, timestamp=None,
-                  local_received=None, fields=None,
-                  max_drift=HEARTBEAT_DRIFT_MAX, abs=abs, int=int,
-                  insort=bisect.insort, len=len):
+        def event(type_: str,
+                  timestamp: float = None,
+                  local_received: float = None,
+                  fields: Mapping = None,
+                  max_drift: float = HEARTBEAT_DRIFT_MAX,
+                  abs: Callable = abs,
+                  int: Callable = int,
+                  insort: Callable = bisect.insort,
+                  len: Callable = len) -> None:
             fields = fields or {}
             fields = fields or {}
             for k, v in fields.items():
             for k, v in fields.items():
                 _set(self, k, v)
                 _set(self, k, v)
@@ -216,28 +236,28 @@ class Worker:
                         insort(heartbeats, local_received)
                         insort(heartbeats, local_received)
         return event
         return event
 
 
-    def update(self, f, **kw):
+    def update(self, f: Mapping, **kw) -> None:
         for k, v in (dict(f, **kw) if kw else f).items():
         for k, v in (dict(f, **kw) if kw else f).items():
             setattr(self, k, v)
             setattr(self, k, v)
 
 
-    def __repr__(self):
+    def __repr__(self) -> str:
         return R_WORKER.format(self)
         return R_WORKER.format(self)
 
 
     @property
     @property
-    def status_string(self):
+    def status_string(self) -> str:
         return 'ONLINE' if self.alive else 'OFFLINE'
         return 'ONLINE' if self.alive else 'OFFLINE'
 
 
     @property
     @property
-    def heartbeat_expires(self):
+    def heartbeat_expires(self) -> float:
         return heartbeat_expires(self.heartbeats[-1],
         return heartbeat_expires(self.heartbeats[-1],
                                  self.freq, self.expire_window)
                                  self.freq, self.expire_window)
 
 
     @property
     @property
-    def alive(self, nowfun=time):
+    def alive(self, *, nowfun: Callable = time) -> bool:
         return bool(self.heartbeats and nowfun() < self.heartbeat_expires)
         return bool(self.heartbeats and nowfun() < self.heartbeat_expires)
 
 
     @property
     @property
-    def id(self):
+    def id(self) -> str:
         return '{0.hostname}.{0.pid}'.format(self)
         return '{0.hostname}.{0.pid}'.format(self)
 
 
 
 
@@ -285,7 +305,11 @@ class Task:
         'root_id', 'parent_id',
         'root_id', 'parent_id',
     )
     )
 
 
-    def __init__(self, uuid=None, cluster_state=None, children=None, **kwargs):
+    def __init__(self,
+                 uuid: str = None,
+                 cluster_state: 'State' = None,
+                 children: Sequence[str] = None,
+                 **kwargs) -> None:
         self.uuid = uuid
         self.uuid = uuid
         self.cluster_state = cluster_state
         self.cluster_state = cluster_state
         self.children = WeakSet(
         self.children = WeakSet(
@@ -301,10 +325,14 @@ class Task:
         if kwargs:
         if kwargs:
             self.__dict__.update(kwargs)
             self.__dict__.update(kwargs)
 
 
-    def event(self, type_, timestamp=None, local_received=None, fields=None,
-              precedence=states.precedence,
-              setattr=setattr, task_event_to_state=TASK_EVENT_TO_STATE.get,
-              RETRY=states.RETRY):
+    def event(self, type_: str,
+              timestamp: float = None,
+              local_received: float = None,
+              fields: Mapping = None,
+              precedence: Callable = states.precedence,
+              setattr: Callable = setattr,
+              task_event_to_state: Callable = TASK_EVENT_TO_STATE.get,
+              RETRY: str = states.RETRY) -> None:
         fields = fields or {}
         fields = fields or {}
 
 
         # using .get is faster than catching KeyError in this case.
         # using .get is faster than catching KeyError in this case.
@@ -331,11 +359,12 @@ class Task:
         # update current state with info from this event.
         # update current state with info from this event.
         self.__dict__.update(fields)
         self.__dict__.update(fields)
 
 
-    def info(self, fields=None, extra=[]):
+    def info(self, fields: Sequence[str] = None,
+             extra: Sequence[str] = []) -> Mapping:
         """Information about this task suitable for on-screen display."""
         """Information about this task suitable for on-screen display."""
         fields = self._info_fields if fields is None else fields
         fields = self._info_fields if fields is None else fields
 
 
-        def _keys():
+        def _keys() -> Iterator[Tuple[str, Any]]:
             for key in list(fields) + list(extra):
             for key in list(fields) + list(extra):
                 value = getattr(self, key, None)
                 value = getattr(self, key, None)
                 if value is not None:
                 if value is not None:
@@ -343,46 +372,46 @@ class Task:
 
 
         return dict(_keys())
         return dict(_keys())
 
 
-    def __repr__(self):
+    def __repr__(self) -> str:
         return R_TASK.format(self)
         return R_TASK.format(self)
 
 
-    def as_dict(self):
+    def as_dict(self) -> Mapping:
         get = object.__getattribute__
         get = object.__getattribute__
         handler = self._serializer_handlers.get
         handler = self._serializer_handlers.get
         return {
         return {
             k: handler(k, pass1)(get(self, k)) for k in self._fields
             k: handler(k, pass1)(get(self, k)) for k in self._fields
         }
         }
 
 
-    def _serializable_children(self, value):
+    def _serializable_children(self, value: Any) -> Sequence[str]:
         return [task.id for task in self.children]
         return [task.id for task in self.children]
 
 
-    def _serializable_root(self, value):
+    def _serializable_root(self, value: Any) -> str:
         return self.root_id
         return self.root_id
 
 
-    def _serializable_parent(self, value):
+    def _serializable_parent(self, value: Any) -> str:
         return self.parent_id
         return self.parent_id
 
 
-    def __reduce__(self):
+    def __reduce__(self) -> Tuple:
         return _depickle_task, (self.__class__, self.as_dict())
         return _depickle_task, (self.__class__, self.as_dict())
 
 
     @property
     @property
-    def id(self):
+    def id(self) -> str:
         return self.uuid
         return self.uuid
 
 
     @property
     @property
-    def origin(self):
+    def origin(self) -> str:
         return self.client if self.worker is None else self.worker.id
         return self.client if self.worker is None else self.worker.id
 
 
     @property
     @property
-    def ready(self):
+    def ready(self) -> bool:
         return self.state in states.READY_STATES
         return self.state in states.READY_STATES
 
 
     @cached_property
     @cached_property
-    def parent(self):
+    def parent(self) -> 'Task':
         return self.parent_id and self.cluster_state.tasks[self.parent_id]
         return self.parent_id and self.cluster_state.tasks[self.parent_id]
 
 
     @cached_property
     @cached_property
-    def root(self):
+    def root(self) -> 'Task':
         return self.root_id and self.cluster_state.tasks[self.root_id]
         return self.root_id and self.cluster_state.tasks[self.root_id]
 
 
 
 
@@ -395,11 +424,17 @@ class State:
     task_count = 0
     task_count = 0
     heap_multiplier = 4
     heap_multiplier = 4
 
 
-    def __init__(self, callback=None,
-                 workers=None, tasks=None, taskheap=None,
-                 max_workers_in_memory=5000, max_tasks_in_memory=10000,
-                 on_node_join=None, on_node_leave=None,
-                 tasks_by_type=None, tasks_by_worker=None):
+    def __init__(self,
+                 callback: Callable = None,
+                 workers: Sequence[Worker] = None,
+                 tasks: Sequence[Task] = None,
+                 taskheap: List = None,
+                 max_workers_in_memory: int = 5000,
+                 max_tasks_in_memory: int = 10000,
+                 on_node_join: Callable = None,
+                 on_node_leave: Callable = None,
+                 tasks_by_type: Mapping = None,
+                 tasks_by_worker: Mapping = None) -> None:
         self.event_callback = callback
         self.event_callback = callback
         self.workers = (LRUCache(max_workers_in_memory)
         self.workers = (LRUCache(max_workers_in_memory)
                         if workers is None else workers)
                         if workers is None else workers)
@@ -416,23 +451,21 @@ class State:
         self._tasks_to_resolve = {}
         self._tasks_to_resolve = {}
         self.rebuild_taskheap()
         self.rebuild_taskheap()
 
 
-        # type: Mapping[TaskName, WeakSet[Task]]
-        self.tasks_by_type = CallableDefaultdict(
+        self.tasks_by_type: Mapping[str, Set[Task]] = CallableDefaultdict(
             self._tasks_by_type, WeakSet)
             self._tasks_by_type, WeakSet)
         self.tasks_by_type.update(
         self.tasks_by_type.update(
             _deserialize_Task_WeakSet_Mapping(tasks_by_type, self.tasks))
             _deserialize_Task_WeakSet_Mapping(tasks_by_type, self.tasks))
 
 
-        # type: Mapping[Hostname, WeakSet[Task]]
-        self.tasks_by_worker = CallableDefaultdict(
+        self.tasks_by_worker: Mapping[str, Set[Task]] = CallableDefaultdict(
             self._tasks_by_worker, WeakSet)
             self._tasks_by_worker, WeakSet)
         self.tasks_by_worker.update(
         self.tasks_by_worker.update(
             _deserialize_Task_WeakSet_Mapping(tasks_by_worker, self.tasks))
             _deserialize_Task_WeakSet_Mapping(tasks_by_worker, self.tasks))
 
 
     @cached_property
     @cached_property
-    def _event(self):
+    def _event(self) -> Callable:
         return self._create_dispatcher()
         return self._create_dispatcher()
 
 
-    def freeze_while(self, fun, *args, **kwargs):
+    def freeze_while(self, fun: Callable, *args, **kwargs) -> Any:
         clear_after = kwargs.pop('clear_after', False)
         clear_after = kwargs.pop('clear_after', False)
         with self._mutex:
         with self._mutex:
             try:
             try:
@@ -441,11 +474,11 @@ class State:
                 if clear_after:
                 if clear_after:
                     self._clear()
                     self._clear()
 
 
-    def clear_tasks(self, ready=True):
+    def clear_tasks(self, ready: bool = True) -> None:
         with self._mutex:
         with self._mutex:
-            return self._clear_tasks(ready)
+            self._clear_tasks(ready)
 
 
-    def _clear_tasks(self, ready=True):
+    def _clear_tasks(self, ready: bool = True) -> None:
         if ready:
         if ready:
             in_progress = {
             in_progress = {
                 uuid: task for uuid, task in self.itertasks()
                 uuid: task for uuid, task in self.itertasks()
@@ -457,17 +490,18 @@ class State:
             self.tasks.clear()
             self.tasks.clear()
         self._taskheap[:] = []
         self._taskheap[:] = []
 
 
-    def _clear(self, ready=True):
+    def _clear(self, ready: bool = True) -> None:
         self.workers.clear()
         self.workers.clear()
         self._clear_tasks(ready)
         self._clear_tasks(ready)
         self.event_count = 0
         self.event_count = 0
         self.task_count = 0
         self.task_count = 0
 
 
-    def clear(self, ready=True):
+    def clear(self, ready: bool = True) -> None:
         with self._mutex:
         with self._mutex:
-            return self._clear(ready)
+            self._clear(ready)
 
 
-    def get_or_create_worker(self, hostname, **kwargs):
+    def get_or_create_worker(self, hostname: str,
+                             **kwargs) -> Tuple[Worker, bool]:
         """Get or create worker by hostname.
         """Get or create worker by hostname.
 
 
         Returns:
         Returns:
@@ -483,7 +517,7 @@ class State:
                 hostname, **kwargs)
                 hostname, **kwargs)
             return worker, True
             return worker, True
 
 
-    def get_or_create_task(self, uuid):
+    def get_or_create_task(self, uuid: str) -> Tuple[Task, bool]:
         """Get or create task by uuid."""
         """Get or create task by uuid."""
         try:
         try:
             return self.tasks[uuid], False
             return self.tasks[uuid], False
@@ -491,11 +525,11 @@ class State:
             task = self.tasks[uuid] = self.Task(uuid, cluster_state=self)
             task = self.tasks[uuid] = self.Task(uuid, cluster_state=self)
             return task, True
             return task, True
 
 
-    def event(self, event):
+    def event(self, event: Mapping) -> Tuple[Any, str]:
         with self._mutex:
         with self._mutex:
             return self._event(event)
             return self._event(event)
 
 
-    def _create_dispatcher(self):
+    def _create_dispatcher(self) -> Callable:
         # noqa: C901
         # noqa: C901
         # pylint: disable=too-many-statements
         # pylint: disable=too-many-statements
         # This code is highly optimized, but not for reusability.
         # This code is highly optimized, but not for reusability.
@@ -522,9 +556,12 @@ class State:
         get_task_by_type_set = self.tasks_by_type.__getitem__
         get_task_by_type_set = self.tasks_by_type.__getitem__
         get_task_by_worker_set = self.tasks_by_worker.__getitem__
         get_task_by_worker_set = self.tasks_by_worker.__getitem__
 
 
-        def _event(event,
-                   timetuple=timetuple, KeyError=KeyError,
-                   insort=bisect.insort, created=True):
+        def _event(event: Mapping,
+                   *,
+                   timetuple: Callable = timetuple,
+                   KeyError: type = KeyError,
+                   insort: Callable = bisect.insort,
+                   created: bool = True) -> Tuple[Any, str]:
             self.event_count += 1
             self.event_count += 1
             if event_callback:
             if event_callback:
                 event_callback(self, event)
                 event_callback(self, event)
@@ -618,27 +655,29 @@ class State:
                 return (task, task_created), subject
                 return (task, task_created), subject
         return _event
         return _event
 
 
-    def _add_pending_task_child(self, task):
+    def _add_pending_task_child(self, task: Task) -> None:
         try:
         try:
             ch = self._tasks_to_resolve[task.parent_id]
             ch = self._tasks_to_resolve[task.parent_id]
         except KeyError:
         except KeyError:
             ch = self._tasks_to_resolve[task.parent_id] = WeakSet()
             ch = self._tasks_to_resolve[task.parent_id] = WeakSet()
         ch.add(task)
         ch.add(task)
 
 
-    def rebuild_taskheap(self, timetuple=timetuple):
+    def rebuild_taskheap(self, *, timetuple: Callable = timetuple) -> None:
         heap = self._taskheap[:] = [
         heap = self._taskheap[:] = [
             timetuple(t.clock, t.timestamp, t.origin, ref(t))
             timetuple(t.clock, t.timestamp, t.origin, ref(t))
             for t in self.tasks.values()
             for t in self.tasks.values()
         ]
         ]
         heap.sort()
         heap.sort()
 
 
-    def itertasks(self, limit=None):
+    def itertasks(self, limit: int = None) -> Iterator[Tuple[str, Task]]:
         for index, row in enumerate(self.tasks.items()):
         for index, row in enumerate(self.tasks.items()):
             yield row
             yield row
             if limit and index + 1 >= limit:
             if limit and index + 1 >= limit:
                 break
                 break
 
 
-    def tasks_by_time(self, limit=None, reverse=True):
+    def tasks_by_time(self,
+                      limit: int = None,
+                      reverse: bool = True) -> Iterator[Tuple[str, Task]]:
         """Generator yielding tasks ordered by time.
         """Generator yielding tasks ordered by time.
 
 
         Yields:
         Yields:
@@ -658,7 +697,9 @@ class State:
                     seen.add(uuid)
                     seen.add(uuid)
     tasks_by_timestamp = tasks_by_time
     tasks_by_timestamp = tasks_by_time
 
 
-    def _tasks_by_type(self, name, limit=None, reverse=True):
+    def _tasks_by_type(self, name: str,
+                       limit: int = None,
+                       reverse: bool = True) -> Iterator[Tuple[str, Task]]:
         """Get all tasks by type.
         """Get all tasks by type.
 
 
         This is slower than accessing :attr:`tasks_by_type`,
         This is slower than accessing :attr:`tasks_by_type`,
@@ -673,7 +714,9 @@ class State:
             0, limit,
             0, limit,
         )
         )
 
 
-    def _tasks_by_worker(self, hostname, limit=None, reverse=True):
+    def _tasks_by_worker(self, hostname: str,
+                         limit: int = None,
+                         reverse: bool = True) -> Iterator[Tuple[str, Task]]:
         """Get all tasks by worker.
         """Get all tasks by worker.
 
 
         Slower than accessing :attr:`tasks_by_worker`, but ordered by time.
         Slower than accessing :attr:`tasks_by_worker`, but ordered by time.
@@ -684,18 +727,18 @@ class State:
             0, limit,
             0, limit,
         )
         )
 
 
-    def task_types(self):
+    def task_types(self) -> List[str]:
         """Return a list of all seen task types."""
         """Return a list of all seen task types."""
         return sorted(self._seen_types)
         return sorted(self._seen_types)
 
 
-    def alive_workers(self):
+    def alive_workers(self) -> Iterator[Worker]:
         """Return a list of (seemingly) alive workers."""
         """Return a list of (seemingly) alive workers."""
         return (w for w in self.workers.values() if w.alive)
         return (w for w in self.workers.values() if w.alive)
 
 
-    def __repr__(self):
+    def __repr__(self) -> str:
         return R_STATE.format(self)
         return R_STATE.format(self)
 
 
-    def __reduce__(self):
+    def __reduce__(self) -> Tuple:
         return self.__class__, (
         return self.__class__, (
             self.event_callback, self.workers, self.tasks, None,
             self.event_callback, self.workers, self.tasks, None,
             self.max_workers_in_memory, self.max_tasks_in_memory,
             self.max_workers_in_memory, self.max_tasks_in_memory,
@@ -705,10 +748,11 @@ class State:
         )
         )
 
 
 
 
-def _serialize_Task_WeakSet_Mapping(mapping):
+def _serialize_Task_WeakSet_Mapping(mapping: Mapping) -> Mapping:
     return {name: [t.id for t in tasks] for name, tasks in mapping.items()}
     return {name: [t.id for t in tasks] for name, tasks in mapping.items()}
 
 
 
 
-def _deserialize_Task_WeakSet_Mapping(mapping, tasks):
+def _deserialize_Task_WeakSet_Mapping(
+        mapping: Mapping, tasks: Mapping[str, Task]) -> Mapping[str, Set[Task]]:
     return {name: WeakSet(tasks[i] for i in ids if i in tasks)
     return {name: WeakSet(tasks[i] for i in ids if i in tasks)
             for name, ids in (mapping or {}).items()}
             for name, ids in (mapping or {}).items()}
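Note on the pair above: WeakSet itself is not picklable, so the per-name task
indexes are flattened to plain task ids and rebuilt against a live id-to-task
mapping on load. A minimal round-trip sketch (DummyTask is a hypothetical
stand-in for the real Task class; only the .id attribute matters here):

    from weakref import WeakSet

    class DummyTask:
        # stand-in object: the helpers only touch .id
        def __init__(self, id):
            self.id = id

    tasks = {'t1': DummyTask('t1'), 't2': DummyTask('t2')}
    by_type = {'proj.add': WeakSet(tasks.values())}

    flat = _serialize_Task_WeakSet_Mapping(by_type)
    # -> {'proj.add': ['t1', 't2']} (order may vary)
    back = _deserialize_Task_WeakSet_Mapping(flat, tasks)
    assert sorted(t.id for t in back['proj.add']) == ['t1', 't2']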

+ 0 - 4
celery/platforms.py

@@ -14,8 +14,6 @@ import signal as _signal
 import sys
 import sys
 import warnings
 import warnings
 
 
-from collections import namedtuple
-
 from billiard.compat import get_fdmax, close_open_fds
 from billiard.compat import get_fdmax, close_open_fds
 # fileno used to be in this module
 # fileno used to be in this module
 from kombu.utils.compat import maybe_fileno
 from kombu.utils.compat import maybe_fileno
@@ -65,8 +63,6 @@ PIDFILE_MODE = ((os.R_OK | os.W_OK) << 6) | ((os.R_OK) << 3) | ((os.R_OK))
 PIDLOCKED = """ERROR: Pidfile ({0}) already exists.
 PIDLOCKED = """ERROR: Pidfile ({0}) already exists.
 Seems we're already running? (pid: {1})"""
 Seems we're already running? (pid: {1})"""
 
 
-_range = namedtuple('_range', ('start', 'stop'))
-
 C_FORCE_ROOT = os.environ.get('C_FORCE_ROOT', False)
 C_FORCE_ROOT = os.environ.get('C_FORCE_ROOT', False)
 
 
 ROOT_DISALLOWED = """\
 ROOT_DISALLOWED = """\

+ 11 - 2
celery/schedules.py

@@ -4,8 +4,9 @@ import numbers
 import re
 import re
 
 
 from bisect import bisect, bisect_left
 from bisect import bisect, bisect_left
-from collections import Iterable, namedtuple
+from collections import Iterable
 from datetime import datetime, timedelta
 from datetime import datetime, timedelta
+from typing import NamedTuple
 
 
 from kombu.utils.objects import cached_property
 from kombu.utils.objects import cached_property
 
 
@@ -21,7 +22,6 @@ __all__ = [
     'maybe_schedule', 'solar',
     'maybe_schedule', 'solar',
 ]
 ]
 
 
-schedstate = namedtuple('schedstate', ('is_due', 'next'))
 
 
 CRON_PATTERN_INVALID = """\
 CRON_PATTERN_INVALID = """\
 Invalid crontab pattern.  Valid range is {min}-{max}. \
 Invalid crontab pattern.  Valid range is {min}-{max}. \
@@ -50,6 +50,15 @@ SOLAR_INVALID_EVENT = """\
 Argument event "{event}" is invalid, must be one of {all_events}.\
 Argument event "{event}" is invalid, must be one of {all_events}.\
 """
 """
 
 
+
+
+class schedstate(NamedTuple):
+    """Return value of ``schedule.is_due``."""
+
+    is_due: bool
+    next: float
+
 
 
 def cronfield(s):
 def cronfield(s):
     return '*' if s is None else s
     return '*' if s is None else s
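With schedstate now a typed NamedTuple, the is_due contract is explicit: a
flag saying "run now", plus the number of seconds until beat should check
again. A hedged sketch of a custom schedule honoring that contract (the
subclass is illustrative, not part of this commit):

    from celery.schedules import schedule

    class every_ten_seconds(schedule):
        # illustrative: due every 10 seconds, re-check after the remainder
        def is_due(self, last_run_at):
            elapsed = (self.now() - last_run_at).total_seconds()
            if elapsed >= 10.0:
                return schedstate(is_due=True, next=10.0)
            return schedstate(is_due=False, next=10.0 - elapsed)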

+ 13 - 17
celery/utils/collections.py

@@ -103,7 +103,7 @@ class DictAttribute:
     `obj[k] = val -> obj.k = val`
     `obj[k] = val -> obj.k = val`
     """
     """
 
 
-    obj = None  # type: Mapping[Any, Any]
+    obj: Mapping = None
 
 
     def __init__(self, obj: Mapping) -> None:
     def __init__(self, obj: Mapping) -> None:
         object.__setattr__(self, 'obj', obj)
         object.__setattr__(self, 'obj', obj)
@@ -155,10 +155,10 @@ MutableMapping.register(DictAttribute)  # noqa: E305
 class ChainMap(MutableMapping):
 class ChainMap(MutableMapping):
     """Key lookup on a sequence of maps."""
     """Key lookup on a sequence of maps."""
 
 
-    key_t = None      # type: Optional[KeyCallback]
-    changes = None    # type: Mapping
-    defaults = None   # type: Sequence[Mapping]
-    maps = None       # type: Sequence[Mapping]
+    key_t: Optional[KeyCallback] = None
+    changes: Mapping = None
+    defaults: Sequence[Mapping] = None
+    maps: Sequence[Mapping] = None
 
 
     def __init__(self, *maps: Sequence[Mapping],
     def __init__(self, *maps: Sequence[Mapping],
                  key_t: KeyCallback = None, **kwargs) -> None:
                  key_t: KeyCallback = None, **kwargs) -> None:
@@ -305,7 +305,6 @@ class ConfigurationView(ChainMap, AttributeDictMixin):
         return key,
         return key,
 
 
     def __getitem__(self, key: str) -> Any:
     def __getitem__(self, key: str) -> Any:
-        # type: (str) -> Any
         keys = self._to_keys(key)
         keys = self._to_keys(key)
         getitem = super(ConfigurationView, self).__getitem__
         getitem = super(ConfigurationView, self).__getitem__
         for k in keys + (
         for k in keys + (
@@ -420,11 +419,9 @@ class LimitedSet:
         self.minlen = 0 if minlen is None else minlen
         self.minlen = 0 if minlen is None else minlen
         self.expires = 0 if expires is None else expires
         self.expires = 0 if expires is None else expires
 
 
-        # type: Mapping[str, Any]
-        self._data = {}
+        self._data: Mapping[str, Any] = {}
 
 
-        # type: Sequence[Tuple[float, Any]]
-        self._heap = []
+        self._heap: Sequence[Tuple[float, Any]] = []
 
 
         if data:
         if data:
             # import items from data
             # import items from data
@@ -577,7 +574,7 @@ MutableSet.register(LimitedSet)  # noqa: E305
 class Evictable:
 class Evictable:
     """Mixin for classes supporting the ``evict`` method."""
     """Mixin for classes supporting the ``evict`` method."""
 
 
-    Empty = Empty  # type: Exception
+    Empty = Empty
 
 
     def evict(self) -> None:
     def evict(self) -> None:
         """Force evict until maxsize is enforced."""
         """Force evict until maxsize is enforced."""
@@ -607,12 +604,12 @@ class Messagebuffer(Evictable):
                  maxsize: Optional[int],
                  maxsize: Optional[int],
                  iterable: Optional[Iterable]=None, deque: Any=deque) -> None:
                  iterable: Optional[Iterable]=None, deque: Any=deque) -> None:
         self.maxsize = maxsize
         self.maxsize = maxsize
-        self.data = deque(iterable or [])  # type: deque
+        self.data = deque(iterable or [])
 
 
-        self._append = self.data.append    # type: Callable[[Any], None]
-        self._pop = self.data.popleft      # type: Callable[[], Any]
-        self._len = self.data.__len__      # type: Callable[[], int]
-        self._extend = self.data.extend    # type: Callable[[Iterable], None]
+        self._append = self.data.append
+        self._pop = self.data.popleft
+        self._len = self.data.__len__
+        self._extend = self.data.extend
 
 
     def put(self, item: Any) -> None:
     def put(self, item: Any) -> None:
         self._append(item)
         self._append(item)
@@ -682,7 +679,6 @@ class BufferMap(OrderedDict, Evictable):
         if iterable:
         if iterable:
             self.update(iterable)
             self.update(iterable)
 
 
-        # type: int
         self.total = sum(len(buf) for buf in self.items())
         self.total = sum(len(buf) for buf in self.items())
 
 
     def put(self, key: Any, item: Any) -> None:
     def put(self, key: Any, item: Any) -> None:
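For context on the structures annotated above: LimitedSet pairs the _data
dict with the _heap of (timestamp, key) entries, so membership checks stay
O(1) while eviction pops the oldest entries. A hedged usage sketch, assuming
the maxlen/expires constructor keywords the class exposes upstream:

    from celery.utils.collections import LimitedSet

    revoked = LimitedSet(maxlen=3, expires=3600.0)
    for uuid in 'abcd':
        revoked.add(uuid)

    assert 'd' in revoked    # newest entries are kept
    # 'a' is expected to drop out once maxlen is exceeded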

+ 3 - 3
celery/utils/debug.py

@@ -29,7 +29,7 @@ __all__ = [
     'humanbytes', 'mem_rss', 'ps', 'cry',
     'humanbytes', 'mem_rss', 'ps', 'cry',
 ]
 ]
 
 
-UNITS = (               # type: Sequence[Tuple[float, str]]
+UNITS: Sequence[Tuple[float, str]] = (
     (2 ** 40.0, 'TB'),
     (2 ** 40.0, 'TB'),
     (2 ** 30.0, 'GB'),
     (2 ** 30.0, 'GB'),
     (2 ** 20.0, 'MB'),
     (2 ** 20.0, 'MB'),
@@ -37,8 +37,8 @@ UNITS = (               # type: Sequence[Tuple[float, str]]
     (0.0, 'b'),
     (0.0, 'b'),
 )
 )
 
 
-_process = None         # type: Optional[Process]
-_mem_sample = []        # type: MutableSequence[str]
+_process: Optional[Process] = None
+_mem_sample: MutableSequence[str] = []
 
 
 
 
 def _on_blocking(signum: int, frame: Any) -> None:
 def _on_blocking(signum: int, frame: Any) -> None:
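The UNITS table reads top-down: the first divisor not exceeding the value
picks the suffix, and the trailing 0.0 sentinel catches plain bytes. A hedged
re-implementation for illustration (the real helper is humanbytes; the name
and format below are approximate):

    def humanbytes_sketch(n: float) -> str:
        # walk UNITS top-down; guard the 0.0 sentinel against division
        for div, unit in UNITS:
            if n >= div:
                return '{0:.2f}{1}'.format(n / div if div else n, unit)

    humanbytes_sketch(3 * 2 ** 30)   # -> '3.00GB'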

+ 5 - 5
celery/utils/functional.py

@@ -50,8 +50,8 @@ class mlazy(lazy):
     """
     """
 
 
     #: Set to :const:`True` after the object has been evaluated.
     #: Set to :const:`True` after the object has been evaluated.
-    evaluated = False  # type: bool
-    _value = None      # type: Any
+    evaluated: bool = False
+    _value: Any = None
 
 
     def evaluate(self) -> Any:
     def evaluate(self) -> Any:
         if not self.evaluated:
         if not self.evaluated:
@@ -169,7 +169,7 @@ def mattrgetter(*attrs: str) -> Callable[[Any], Mapping[str, Any]]:
 
 
 def uniq(it: Iterable) -> Iterable[Any]:
 def uniq(it: Iterable) -> Iterable[Any]:
     """Return all unique elements in ``it``, preserving order."""
     """Return all unique elements in ``it``, preserving order."""
-    seen = set()  # type: MutableSet
+    seen = set()
     return (seen.add(obj) or obj for obj in it if obj not in seen)
     return (seen.add(obj) or obj for obj in it if obj not in seen)
 
 
 
 
@@ -274,7 +274,7 @@ def head_from_fun(fun: Callable,
         name, fun = fun.__class__.__name__, fun.__call__
         name, fun = fun.__class__.__name__, fun.__call__
     else:
     else:
         name = fun.__name__
         name = fun.__name__
-    definition = FUNHEAD_TEMPLATE.format(   # type: str
+    definition = FUNHEAD_TEMPLATE.format(
         fun_name=name,
         fun_name=name,
         fun_args=_argsfromspec(getfullargspec(fun)),
         fun_args=_argsfromspec(getfullargspec(fun)),
         fun_value=1,
         fun_value=1,
@@ -285,7 +285,7 @@ def head_from_fun(fun: Callable,
     # pylint: disable=exec-used
     # pylint: disable=exec-used
     # Tasks are rarely, if ever, created at runtime - exec here is fine.
     # Tasks are rarely, if ever, created at runtime - exec here is fine.
     exec(definition, namespace)
     exec(definition, namespace)
-    result = namespace[name]  # type: Any
+    result: Any = namespace[name]
     result._source = definition
     result._source = definition
     if bound:
     if bound:
         return partial(result, object())
         return partial(result, object())
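Aside on uniq above: it exploits the fact that set.add returns None, so the
expression `seen.add(obj) or obj` both records the element and yields it,
preserving first-seen order in a single generator expression:

    >>> list(uniq([1, 2, 1, 3, 2, 1]))
    [1, 2, 3]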

+ 7 - 7
celery/utils/graph.py

@@ -47,7 +47,7 @@ class DependencyGraph(Iterable):
     def __init__(self, it: Optional[Iterable]=None,
     def __init__(self, it: Optional[Iterable]=None,
                  formatter: Optional['GraphFormatter']=None) -> None:
                  formatter: Optional['GraphFormatter']=None) -> None:
         self.formatter = formatter or GraphFormatter()
         self.formatter = formatter or GraphFormatter()
-        self.adjacent = {}  # type: Dict[Any, Any]
+        self.adjacent: Dict[Any, Any] = {}
         if it is not None:
         if it is not None:
             self.update(it)
             self.update(it)
 
 
@@ -116,8 +116,8 @@ class DependencyGraph(Iterable):
 
 
         See https://en.wikipedia.org/wiki/Topological_sorting
         See https://en.wikipedia.org/wiki/Topological_sorting
         """
         """
-        count = Counter()  # type: Counter
-        result = []        # type: MutableSequence[Any]
+        count = Counter()
+        result: MutableSequence[Any] = []
 
 
         for node in self:
         for node in self:
             for successor in self[node]:
             for successor in self[node]:
@@ -141,9 +141,9 @@ class DependencyGraph(Iterable):
         See Also:
         See Also:
             :wikipedia:`Tarjan%27s_strongly_connected_components_algorithm`
             :wikipedia:`Tarjan%27s_strongly_connected_components_algorithm`
         """
         """
-        result = []  # type: MutableSequence[Any]
-        stack = []   # type: MutableSequence[Any]
-        low = {}     # type: Dict[Any, Any]
+        result: MutableSequence[Any] = []
+        stack: MutableSequence[Any] = []
+        low: Dict[Any, Any] = {}
 
 
         def visit(node):
         def visit(node):
             if node in low:
             if node in low:
@@ -178,7 +178,7 @@ class DependencyGraph(Iterable):
             formatter (celery.utils.graph.GraphFormatter): Custom graph
             formatter (celery.utils.graph.GraphFormatter): Custom graph
                 formatter to use.
                 formatter to use.
         """
         """
-        seen = set()  # type: MutableSet
+        seen: MutableSet = set()
         draw = formatter or self.formatter
         draw = formatter or self.formatter
 
 
         def P(s):
         def P(s):
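A hedged sketch of the structures these annotations describe: vertices live
in the adjacent dict, and topsort is Kahn's algorithm over the successor
counts shown above, so a node is emitted before its successors. Method names
follow the class's public API; ties between equally valid orders may vary:

    from celery.utils.graph import DependencyGraph

    g = DependencyGraph()
    for node in 'abc':
        g.add_arc(node)       # add vertices
    g.add_edge('a', 'b')      # edge a -> b
    g.add_edge('b', 'c')

    g.topsort()               # -> ['a', 'b', 'c'] here: sources first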

+ 6 - 6
celery/utils/imports.py

@@ -8,7 +8,7 @@ import warnings
 from contextlib import contextmanager
 from contextlib import contextmanager
 from imp import reload
 from imp import reload
 from types import ModuleType
 from types import ModuleType
-from typing import Any, Callable, Iterator, Optional
+from typing import Any, Callable, Iterator
 from kombu.utils.imports import symbol_by_name
 from kombu.utils.imports import symbol_by_name
 
 
 #: Billiard sets this when execv is enabled.
 #: Billiard sets this when execv is enabled.
@@ -65,8 +65,8 @@ def cwd_in_path() -> Iterator:
 
 
 
 
 def find_module(module: str,
 def find_module(module: str,
-                path: Optional[str]=None,
-                imp: Optional[Callable]=None) -> ModuleType:
+                path: str = None,
+                imp: Callable = None) -> ModuleType:
     """Version of :func:`imp.find_module` supporting dots."""
     """Version of :func:`imp.find_module` supporting dots."""
     if imp is None:
     if imp is None:
         imp = importlib.import_module
         imp = importlib.import_module
@@ -86,8 +86,8 @@ def find_module(module: str,
 
 
 
 
 def import_from_cwd(module: str,
 def import_from_cwd(module: str,
-                    imp: Optional[Callable]=None,
-                    package: Optional[str]=None) -> ModuleType:
+                    imp: Callable = None,
+                    package: str = None) -> ModuleType:
     """Import module, temporarily including modules in the current directory.
     """Import module, temporarily including modules in the current directory.
 
 
     Modules located in the current directory have
     Modules located in the current directory have
@@ -100,7 +100,7 @@ def import_from_cwd(module: str,
 
 
 
 
 def reload_from_cwd(module: ModuleType,
 def reload_from_cwd(module: ModuleType,
-                    reloader: Optional[Callable]=None) -> Any:
+                    reloader: Callable = None) -> Any:
     """Reload module (ensuring that CWD is in sys.path)."""
     """Reload module (ensuring that CWD is in sys.path)."""
     if reloader is None:
     if reloader is None:
         reloader = reload
         reloader = reload
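Usage note for the signatures above: import_from_cwd wraps the import in the
cwd_in_path context manager, which is how a project-local module such as a
celeryconfig.py gets picked up even when the working directory is not on
sys.path. A minimal sketch (the module name is an assumption):

    from celery.utils.imports import import_from_cwd

    # assumes a celeryconfig.py sitting in the current working directory
    celeryconfig = import_from_cwd('celeryconfig')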

+ 1 - 1
celery/utils/log.py

@@ -34,7 +34,7 @@ RESERVED_LOGGER_NAMES = {'celery', 'celery.task'}
 # Every logger in the celery package inherits from the "celery"
 # Every logger in the celery package inherits from the "celery"
 # logger, and every task logger inherits from the "celery.task"
 # logger, and every task logger inherits from the "celery.task"
 # logger.
 # logger.
-base_logger = logger = _get_logger('celery')  # type: logging.Logger
+base_logger = logger = _get_logger('celery')
 
 
 
 
 def set_in_sighandler(value: bool) -> None:
 def set_in_sighandler(value: bool) -> None:

+ 34 - 17
celery/utils/saferepr.py

@@ -11,13 +11,13 @@ Differences from regular :func:`repr`:
 Very slow with no limits, super quick with limits.
 Very slow with no limits, super quick with limits.
 """
 """
 import traceback
 import traceback
-from collections import Mapping, deque, namedtuple
+from collections import Mapping, deque
 from decimal import Decimal
 from decimal import Decimal
 from itertools import chain
 from itertools import chain
 from numbers import Number
 from numbers import Number
 from pprint import _recursion
 from pprint import _recursion
 from typing import (
 from typing import (
-    Any, AnyStr, Callable, Iterator, Set, Sequence, Tuple,
+    Any, AnyStr, Callable, Iterator, NamedTuple, Set, Sequence, Tuple,
 )
 )
 from .text import truncate
 from .text import truncate
 
 
@@ -26,25 +26,42 @@ __all__ = ['saferepr', 'reprstream']
 # pylint: disable=redefined-outer-name
 # pylint: disable=redefined-outer-name
 # We cache globals and attribute lookups, so disable this warning.
 # We cache globals and attribute lookups, so disable this warning.
 
 
-#: Node representing literal text.
-#:   - .value: is the literal text value
-#:   - .truncate: specifies if this text can be truncated, for things like
-#:                LIT_DICT_END this will be False, as we always display
-#:                the ending brackets, e.g:  [[[1, 2, 3, ...,], ..., ]]
-#:   - .direction: If +1 the current level is increment by one,
-#:                 if -1 the current level is decremented by one, and
-#:                 if 0 the current level is unchanged.
-_literal = namedtuple('_literal', ('value', 'truncate', 'direction'))
 
 
-#: Node representing a dictionary key.
-_key = namedtuple('_key', ('value',))
+class _literal(NamedTuple):
+    """Node representing literal text.
 
 
-#: Node representing quoted text, e.g. a string value.
-_quoted = namedtuple('_quoted', ('value',))
+    Attributes:
+       - .value: the literal text value
+       - .truncate: specifies if this text can be truncated, for things like
+                    LIT_DICT_END this will be False, as we always display
+                    the ending brackets, e.g.:  [[[1, 2, 3, ...,], ..., ]]
+       - .direction: If +1 the current level is incremented by one,
+                     if -1 the current level is decremented by one, and
+                     if 0 the current level is unchanged.
+    """
+
+    value: str
+    truncate: bool
+    direction: int
+
+
+class _key(NamedTuple):
+    """Node representing a dictionary key."""
+
+    value: str
+
+
+class _quoted(NamedTuple):
+    """Node representing quoted text, e.g. a string value."""
+
+    value: str
+
+
+class _dirty(NamedTuple):
+    """Recursion protection."""
 
 
+    objid: int
 
 
-#: Recursion protection.
-_dirty = namedtuple('_dirty', ('objid',))
 
 
 #: Types that are represented as chars.
 #: Types that are represented as chars.
 chars_t = (bytes, str)
 chars_t = (bytes, str)
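The three node classes above form the token stream behind saferepr: _literal
carries bracket depth via its direction field, _key marks dictionary keys,
and _quoted marks string values. From the caller's side only the size cap
matters; a hedged sketch, assuming the second positional argument caps the
output size per the upstream signature:

    from celery.utils.saferepr import saferepr

    saferepr(list(range(1000)), 64)
    # -> '[0, 1, 2, 3, ...'-style output, truncated near 64 characters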

+ 3 - 3
celery/utils/serialization.py

@@ -105,13 +105,13 @@ class UnpickleableExceptionWrapper(Exception):
     """
     """
 
 
     #: The module of the original exception.
     #: The module of the original exception.
-    exc_module = None       # type: str
+    exc_module: str = None
 
 
     #: The name of the original exception class.
     #: The name of the original exception class.
-    exc_cls_name = None     # type: str
+    exc_cls_name: str = None
 
 
     #: The arguments for the original exception.
     #: The arguments for the original exception.
-    exc_args = None         # type: Sequence[Any]
+    exc_args: Sequence[Any] = None
 
 
     def __init__(self, exc_module: str, exc_cls_name: str,
     def __init__(self, exc_module: str, exc_cls_name: str,
                  exc_args: Sequence[Any], text: Optional[str]=None) -> None:
                  exc_args: Sequence[Any], text: Optional[str]=None) -> None:

+ 2 - 4
celery/utils/static/__init__.py

@@ -2,13 +2,11 @@
 import os
 import os
 
 
 
 
-def get_file(*args):
-    # type: (*str) -> str
+def get_file(*args) -> str:
     """Get filename for static file."""
     """Get filename for static file."""
     return os.path.join(os.path.abspath(os.path.dirname(__file__)), *args)
     return os.path.join(os.path.abspath(os.path.dirname(__file__)), *args)
 
 
 
 
-def logo():
-    # type: () -> bytes
+def logo() -> bytes:
     """Celery logo image."""
     """Celery logo image."""
     return get_file('celery_128.png')
     return get_file('celery_128.png')

+ 7 - 4
celery/utils/sysinfo.py

@@ -1,15 +1,18 @@
 # -*- coding: utf-8 -*-
 # -*- coding: utf-8 -*-
 """System information utilities."""
 """System information utilities."""
 import os
 import os
-from collections import namedtuple
 from math import ceil
 from math import ceil
+from typing import NamedTuple
 from kombu.utils.objects import cached_property
 from kombu.utils.objects import cached_property
 
 
 __all__ = ['load_average', 'load_average_t', 'df']
 __all__ = ['load_average', 'load_average_t', 'df']
 
 
-load_average_t = namedtuple('load_average_t', (
-    'min_1', 'min_5', 'min_15',
-))
+class load_average_t(NamedTuple):
+    """Load average information triple."""
+
+    min_1: float
+    min_5: float
+    min_15: float
 
 
 
 
 def _avg(f: float) -> float:
 def _avg(f: float) -> float:
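Since load_average_t is now a typed NamedTuple, the triple keeps its tuple
semantics while gaining named fields; both access styles below work:

    from celery.utils.sysinfo import load_average

    la = load_average()
    min_1, min_5, min_15 = la    # unpacking, as before
    one_minute = la.min_1        # named access, new-style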

+ 27 - 26
celery/worker/autoscale.py

@@ -10,15 +10,13 @@ the :option:`celery worker --autoscale` option is used.
 """
 """
 import os
 import os
 import threading
 import threading
-
 from time import monotonic, sleep
 from time import monotonic, sleep
-
+from typing import Mapping, Optional, Tuple
 from kombu.async.semaphore import DummyLock
 from kombu.async.semaphore import DummyLock
-
 from celery import bootsteps
 from celery import bootsteps
+from celery.types import AutoscalerT, LoopT, PoolT, RequestT, WorkerT
 from celery.utils.log import get_logger
 from celery.utils.log import get_logger
 from celery.utils.threads import bgThread
 from celery.utils.threads import bgThread
-
 from . import state
 from . import state
 from .components import Pool
 from .components import Pool
 
 
@@ -37,11 +35,11 @@ class WorkerComponent(bootsteps.StartStopStep):
     conditional = True
     conditional = True
     requires = (Pool,)
     requires = (Pool,)
 
 
-    def __init__(self, w, **kwargs):
+    def __init__(self, w: WorkerT, **kwargs) -> None:
         self.enabled = w.autoscale
         self.enabled = w.autoscale
         w.autoscaler = None
         w.autoscaler = None
 
 
-    def create(self, w):
+    def create(self, w: WorkerT) -> Optional[AutoscalerT]:
         scaler = w.autoscaler = self.instantiate(
         scaler = w.autoscaler = self.instantiate(
             w.autoscaler_cls,
             w.autoscaler_cls,
             w.pool, w.max_concurrency, w.min_concurrency,
             w.pool, w.max_concurrency, w.min_concurrency,
@@ -49,7 +47,7 @@ class WorkerComponent(bootsteps.StartStopStep):
         )
         )
         return scaler if not w.use_eventloop else None
         return scaler if not w.use_eventloop else None
 
 
-    def register_with_event_loop(self, w, hub):
+    def register_with_event_loop(self, w: WorkerT, hub: LoopT) -> None:
         w.consumer.on_task_message.add(w.autoscaler.maybe_scale)
         w.consumer.on_task_message.add(w.autoscaler.maybe_scale)
         hub.call_repeatedly(
         hub.call_repeatedly(
             w.autoscaler.keepalive, w.autoscaler.maybe_scale,
             w.autoscaler.keepalive, w.autoscaler.maybe_scale,
@@ -59,26 +57,29 @@ class WorkerComponent(bootsteps.StartStopStep):
 class Autoscaler(bgThread):
 class Autoscaler(bgThread):
     """Background thread to autoscale pool workers."""
     """Background thread to autoscale pool workers."""
 
 
-    def __init__(self, pool, max_concurrency,
-                 min_concurrency=0, worker=None,
-                 keepalive=AUTOSCALE_KEEPALIVE, mutex=None):
+    _last_scale_up: float = None
+
+    def __init__(self, pool: PoolT, max_concurrency: int,
+                 min_concurrency: int = 0,
+                 worker: WorkerT = None,
+                 keepalive: float = AUTOSCALE_KEEPALIVE,
+                 mutex: threading.Lock = None) -> None:
         super(Autoscaler, self).__init__()
         super(Autoscaler, self).__init__()
         self.pool = pool
         self.pool = pool
         self.mutex = mutex or threading.Lock()
         self.mutex = mutex or threading.Lock()
         self.max_concurrency = max_concurrency
         self.max_concurrency = max_concurrency
         self.min_concurrency = min_concurrency
         self.min_concurrency = min_concurrency
         self.keepalive = keepalive
         self.keepalive = keepalive
-        self._last_scale_up = None
         self.worker = worker
         self.worker = worker
 
 
         assert self.keepalive, 'cannot scale down too fast.'
         assert self.keepalive, 'cannot scale down too fast.'
 
 
-    def body(self):
+    def body(self) -> None:
         with self.mutex:
         with self.mutex:
             self.maybe_scale()
             self.maybe_scale()
         sleep(1.0)
         sleep(1.0)
 
 
-    def _maybe_scale(self, req=None):
+    def _maybe_scale(self, req: RequestT = None) -> Optional[bool]:
         procs = self.processes
         procs = self.processes
         cur = min(self.qty, self.max_concurrency)
         cur = min(self.qty, self.max_concurrency)
         if cur > procs:
         if cur > procs:
@@ -89,11 +90,11 @@ class Autoscaler(bgThread):
             self.scale_down(procs - cur)
             self.scale_down(procs - cur)
             return True
             return True
 
 
-    def maybe_scale(self, req=None):
+    def maybe_scale(self, req: RequestT = None) -> None:
         if self._maybe_scale(req):
         if self._maybe_scale(req):
             self.pool.maintain_pool()
             self.pool.maintain_pool()
 
 
-    def update(self, max=None, min=None):
+    def update(self, max: int = None, min: int = None) -> Tuple[int, int]:
         with self.mutex:
         with self.mutex:
             if max is not None:
             if max is not None:
                 if max < self.processes:
                 if max < self.processes:
@@ -105,35 +106,35 @@ class Autoscaler(bgThread):
                 self.min_concurrency = min
                 self.min_concurrency = min
             return self.max_concurrency, self.min_concurrency
             return self.max_concurrency, self.min_concurrency
 
 
-    def force_scale_up(self, n):
+    def force_scale_up(self, n: int) -> None:
         with self.mutex:
         with self.mutex:
             new = self.processes + n
             new = self.processes + n
             if new > self.max_concurrency:
             if new > self.max_concurrency:
                 self.max_concurrency = new
                 self.max_concurrency = new
             self._grow(n)
             self._grow(n)
 
 
-    def force_scale_down(self, n):
+    def force_scale_down(self, n: int) -> None:
         with self.mutex:
         with self.mutex:
             new = self.processes - n
             new = self.processes - n
             if new < self.min_concurrency:
             if new < self.min_concurrency:
                 self.min_concurrency = max(new, 0)
                 self.min_concurrency = max(new, 0)
             self._shrink(min(n, self.processes))
             self._shrink(min(n, self.processes))
 
 
-    def scale_up(self, n):
+    def scale_up(self, n: int) -> None:
         self._last_scale_up = monotonic()
         self._last_scale_up = monotonic()
-        return self._grow(n)
+        self._grow(n)
 
 
-    def scale_down(self, n):
+    def scale_down(self, n: int) -> None:
         if self._last_scale_up and (
         if self._last_scale_up and (
                 monotonic() - self._last_scale_up > self.keepalive):
                 monotonic() - self._last_scale_up > self.keepalive):
-            return self._shrink(n)
+            self._shrink(n)
 
 
-    def _grow(self, n):
+    def _grow(self, n: int) -> None:
         info('Scaling up %s processes.', n)
         info('Scaling up %s processes.', n)
         self.pool.grow(n)
         self.pool.grow(n)
         self.worker.consumer._update_prefetch_count(n)
         self.worker.consumer._update_prefetch_count(n)
 
 
-    def _shrink(self, n):
+    def _shrink(self, n: int) -> None:
         info('Scaling down %s processes.', n)
         info('Scaling down %s processes.', n)
         try:
         try:
             self.pool.shrink(n)
             self.pool.shrink(n)
@@ -143,7 +144,7 @@ class Autoscaler(bgThread):
             error('Autoscaler: scale_down: %r', exc, exc_info=True)
             error('Autoscaler: scale_down: %r', exc, exc_info=True)
         self.worker.consumer._update_prefetch_count(-n)
         self.worker.consumer._update_prefetch_count(-n)
 
 
-    def info(self):
+    def info(self) -> Mapping:
         return {
         return {
             'max': self.max_concurrency,
             'max': self.max_concurrency,
             'min': self.min_concurrency,
             'min': self.min_concurrency,
@@ -152,9 +153,9 @@ class Autoscaler(bgThread):
         }
         }
 
 
     @property
     @property
-    def qty(self):
+    def qty(self) -> int:
         return len(state.reserved_requests)
         return len(state.reserved_requests)
 
 
     @property
     @property
-    def processes(self):
+    def processes(self) -> int:
         return self.pool.num_processes
         return self.pool.num_processes
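To make the scaling rule above concrete: demand (qty, the number of reserved
requests) is clamped into [min_concurrency, max_concurrency]; growth happens
immediately, while shrinking additionally waits out keepalive since the last
scale-up. A hypothetical pure-function restatement of _maybe_scale:

    def plan_scale(reserved: int, procs: int, max_c: int, min_c: int) -> int:
        # >0: grow by that many processes, <0: shrink, 0: no change
        cur = min(reserved, max_c)
        if cur > procs:
            return cur - procs
        cur = max(reserved, min_c)
        if cur < procs:
            return cur - procs   # negative: shrink
        return 0

    plan_scale(reserved=10, procs=4, max_c=8, min_c=2)   # -> 4  (grow to 8)
    plan_scale(reserved=1, procs=8, max_c=8, min_c=2)    # -> -6 (shrink to 2)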

+ 28 - 25
celery/worker/components.py

@@ -2,13 +2,14 @@
 """Worker-level Bootsteps."""
 """Worker-level Bootsteps."""
 import atexit
 import atexit
 import warnings
 import warnings
-
+from typing import Mapping, Sequence, Tuple, Union
 from kombu.async import Hub as _Hub, get_event_loop, set_event_loop
 from kombu.async import Hub as _Hub, get_event_loop, set_event_loop
 from kombu.async.semaphore import DummyLock, LaxBoundedSemaphore
 from kombu.async.semaphore import DummyLock, LaxBoundedSemaphore
 from kombu.async.timer import Timer as _Timer
 from kombu.async.timer import Timer as _Timer
-
+from kombu.types import HubT
 from celery import bootsteps
 from celery import bootsteps
 from celery._state import _set_task_join_will_block
 from celery._state import _set_task_join_will_block
+from celery.types import BeatT, ConsumerT, PoolT, StepT, WorkerT
 from celery.exceptions import ImproperlyConfigured
 from celery.exceptions import ImproperlyConfigured
 from celery.platforms import IS_WINDOWS
 from celery.platforms import IS_WINDOWS
 from celery.utils.log import worker_logger as logger
 from celery.utils.log import worker_logger as logger
@@ -32,7 +33,7 @@ as early as possible.
 class Timer(bootsteps.Step):
 class Timer(bootsteps.Step):
     """Timer bootstep."""
     """Timer bootstep."""
 
 
-    def create(self, w):
+    def create(self, w: WorkerT) -> None:
         if w.use_eventloop:
         if w.use_eventloop:
             # does not use dedicated timer thread.
             # does not use dedicated timer thread.
             w.timer = _Timer(max_interval=10.0)
             w.timer = _Timer(max_interval=10.0)
@@ -46,26 +47,26 @@ class Timer(bootsteps.Step):
                                        on_error=self.on_timer_error,
                                        on_error=self.on_timer_error,
                                        on_tick=self.on_timer_tick)
                                        on_tick=self.on_timer_tick)
 
 
-    def on_timer_error(self, exc):
+    def on_timer_error(self, exc: Exception) -> None:
         logger.error('Timer error: %r', exc, exc_info=True)
         logger.error('Timer error: %r', exc, exc_info=True)
 
 
-    def on_timer_tick(self, delay):
+    def on_timer_tick(self, delay: float) -> None:
         logger.debug('Timer wake-up! Next ETA %s secs.', delay)
         logger.debug('Timer wake-up! Next ETA %s secs.', delay)
 
 
 
 
 class Hub(bootsteps.StartStopStep):
 class Hub(bootsteps.StartStopStep):
     """Worker starts the event loop."""
     """Worker starts the event loop."""
 
 
-    requires = (Timer,)
+    requires: Sequence[StepT] = (Timer,)
 
 
-    def __init__(self, w, **kwargs):
+    def __init__(self, w: WorkerT, **kwargs) -> None:
         w.hub = None
         w.hub = None
         super(Hub, self).__init__(w, **kwargs)
         super(Hub, self).__init__(w, **kwargs)
 
 
-    def include_if(self, w):
+    def include_if(self, w: WorkerT) -> bool:
         return w.use_eventloop
         return w.use_eventloop
 
 
-    def create(self, w):
+    def create(self, w: WorkerT) -> 'Hub':
         w.hub = get_event_loop()
         w.hub = get_event_loop()
         if w.hub is None:
         if w.hub is None:
             required_hub = getattr(w._conninfo, 'requires_hub', None)
             required_hub = getattr(w._conninfo, 'requires_hub', None)
@@ -74,16 +75,16 @@ class Hub(bootsteps.StartStopStep):
         self._patch_thread_primitives(w)
         self._patch_thread_primitives(w)
         return self
         return self
 
 
-    def start(self, w):
-        pass
+    async def start(self, w: WorkerT) -> None:
+        ...
 
 
-    def stop(self, w):
+    async def stop(self, w: WorkerT) -> None:
         w.hub.close()
         w.hub.close()
 
 
-    def terminate(self, w):
+    async def terminate(self, w: WorkerT) -> None:
         w.hub.close()
         w.hub.close()
 
 
-    def _patch_thread_primitives(self, w):
+    def _patch_thread_primitives(self, w: WorkerT) -> None:
         # make clock use dummy lock
         # make clock use dummy lock
         w.app.clock.mutex = DummyLock()
         w.app.clock.mutex = DummyLock()
         # multiprocessing's ApplyResult uses this lock.
         # multiprocessing's ApplyResult uses this lock.
@@ -111,7 +112,9 @@ class Pool(bootsteps.StartStopStep):
 
 
     requires = (Hub,)
     requires = (Hub,)
 
 
-    def __init__(self, w, autoscale=None, **kwargs):
+    def __init__(self, w: WorkerT,
+                 autoscale: Union[str, Tuple[int, int]] = None,
+                 **kwargs) -> None:
         w.pool = None
         w.pool = None
         w.max_concurrency = None
         w.max_concurrency = None
         w.min_concurrency = w.concurrency
         w.min_concurrency = w.concurrency
@@ -124,15 +127,15 @@ class Pool(bootsteps.StartStopStep):
             w.max_concurrency, w.min_concurrency = w.autoscale
             w.max_concurrency, w.min_concurrency = w.autoscale
         super(Pool, self).__init__(w, **kwargs)
         super(Pool, self).__init__(w, **kwargs)
 
 
-    def close(self, w):
+    async def close(self, w: WorkerT) -> None:
         if w.pool:
         if w.pool:
             w.pool.close()
             w.pool.close()
 
 
-    def terminate(self, w):
+    async def terminate(self, w: WorkerT) -> None:
         if w.pool:
         if w.pool:
             w.pool.terminate()
             w.pool.terminate()
 
 
-    def create(self, w):
+    def create(self, w: WorkerT) -> PoolT:
         semaphore = None
         semaphore = None
         max_restarts = None
         max_restarts = None
         if w.app.conf.worker_pool in GREEN_POOLS:  # pragma: no cover
         if w.app.conf.worker_pool in GREEN_POOLS:  # pragma: no cover
@@ -168,10 +171,10 @@ class Pool(bootsteps.StartStopStep):
         _set_task_join_will_block(pool.task_join_will_block)
         _set_task_join_will_block(pool.task_join_will_block)
         return pool
         return pool
 
 
-    def info(self, w):
+    def info(self, w: WorkerT) -> Mapping:
         return {'pool': w.pool.info if w.pool else 'N/A'}
         return {'pool': w.pool.info if w.pool else 'N/A'}
 
 
-    def register_with_event_loop(self, w, hub):
+    def register_with_event_loop(self, w: WorkerT, hub: HubT) -> None:
         w.pool.register_with_event_loop(hub)
         w.pool.register_with_event_loop(hub)
 
 
 
 
@@ -184,12 +187,12 @@ class Beat(bootsteps.StartStopStep):
     label = 'Beat'
     label = 'Beat'
     conditional = True
     conditional = True
 
 
-    def __init__(self, w, beat=False, **kwargs):
+    def __init__(self, w: WorkerT, beat: bool = False, **kwargs) -> None:
         self.enabled = w.beat = beat
         self.enabled = w.beat = beat
         w.beat = None
         w.beat = None
         super(Beat, self).__init__(w, beat=beat, **kwargs)
         super(Beat, self).__init__(w, beat=beat, **kwargs)
 
 
-    def create(self, w):
+    def create(self, w: WorkerT) -> BeatT:
         from celery.beat import EmbeddedService
         from celery.beat import EmbeddedService
         if w.pool_cls.__module__.endswith(('gevent', 'eventlet')):
         if w.pool_cls.__module__.endswith(('gevent', 'eventlet')):
             raise ImproperlyConfigured(ERR_B_GREEN)
             raise ImproperlyConfigured(ERR_B_GREEN)
@@ -202,12 +205,12 @@ class Beat(bootsteps.StartStopStep):
 class StateDB(bootsteps.Step):
 class StateDB(bootsteps.Step):
     """Bootstep that sets up between-restart state database file."""
     """Bootstep that sets up between-restart state database file."""
 
 
-    def __init__(self, w, **kwargs):
+    def __init__(self, w: WorkerT, **kwargs) -> None:
         self.enabled = w.statedb
         self.enabled = w.statedb
         w._persistence = None
         w._persistence = None
         super(StateDB, self).__init__(w, **kwargs)
         super(StateDB, self).__init__(w, **kwargs)
 
 
-    def create(self, w):
+    def create(self, w: WorkerT) -> None:
         w._persistence = w.state.Persistent(w.state, w.statedb, w.app.clock)
         w._persistence = w.state.Persistent(w.state, w.statedb, w.app.clock)
         atexit.register(w._persistence.save)
         atexit.register(w._persistence.save)
 
 
@@ -217,7 +220,7 @@ class Consumer(bootsteps.StartStopStep):
 
 
     last = True
     last = True
 
 
-    def create(self, w):
+    def create(self, w: WorkerT) -> ConsumerT:
         if w.max_concurrency:
         if w.max_concurrency:
             prefetch_count = max(w.min_concurrency, 1) * w.prefetch_multiplier
             prefetch_count = max(w.min_concurrency, 1) * w.prefetch_multiplier
         else:
         else:
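The initial prefetch count here seeds the consumer's QoS. Worked through:
with autoscale the floor concurrency is used, so an idle worker does not
hoard messages; otherwise the else branch (which, per the worker source,
multiplies w.concurrency by the multiplier) uses the full process count:

    # --autoscale=10,3 with worker_prefetch_multiplier = 4:
    max(3, 1) * 4    # -> 12 messages prefetched initially
    # plain -c 8 with the same multiplier:
    8 * 4            # -> 32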

+ 4 - 2
celery/worker/consumer/agent.py

@@ -1,5 +1,7 @@
 """Celery + :pypi:`cell` integration."""
 """Celery + :pypi:`cell` integration."""
+from typing import Any
 from celery import bootsteps
 from celery import bootsteps
+from celery.types import WorkerConsumerT
 from .connection import Connection
 from .connection import Connection
 
 
 __all__ = ['Agent']
 __all__ = ['Agent']
@@ -11,10 +13,10 @@ class Agent(bootsteps.StartStopStep):
     conditional = True
     conditional = True
     requires = (Connection,)
     requires = (Connection,)
 
 
-    def __init__(self, c, **kwargs):
+    def __init__(self, c: WorkerConsumerT, **kwargs) -> None:
         self.agent_cls = self.enabled = c.app.conf.worker_agent
         self.agent_cls = self.enabled = c.app.conf.worker_agent
         super(Agent, self).__init__(c, **kwargs)
         super(Agent, self).__init__(c, **kwargs)
 
 
-    def create(self, c):
+    def create(self, c: WorkerConsumerT) -> Any:
         agent = c.agent = self.instantiate(self.agent_cls, c.connection)
         agent = c.agent = self.instantiate(self.agent_cls, c.connection)
         return agent
         return agent
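Agent doubles as a template for user-defined consumer bootsteps: conditional
plus a falsy enabled skips the step entirely, and requires=(Connection,)
guarantees ordering. A hedged sketch of a custom step using the same pattern
(the class name and print calls are illustrative only):

    from celery import bootsteps
    from celery.worker.consumer.connection import Connection

    class AnnounceStep(bootsteps.StartStopStep):
        requires = (Connection,)

        def start(self, c):
            print('broker is up:', c.connection.as_uri())

        def stop(self, c):
            print('consumer stopping')

    # registered via: app.steps['consumer'].add(AnnounceStep)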

+ 8 - 6
celery/worker/consumer/connection.py

@@ -1,6 +1,8 @@
 """Consumer Broker Connection Bootstep."""
 """Consumer Broker Connection Bootstep."""
+from typing import Mapping
 from kombu.common import ignore_errors
 from kombu.common import ignore_errors
 from celery import bootsteps
 from celery import bootsteps
+from celery.types import WorkerConsumerT
 from celery.utils.log import get_logger
 from celery.utils.log import get_logger
 
 
 __all__ = ['Connection']
 __all__ = ['Connection']
@@ -12,22 +14,22 @@ info = logger.info
 class Connection(bootsteps.StartStopStep):
 class Connection(bootsteps.StartStopStep):
     """Service managing the consumer broker connection."""
     """Service managing the consumer broker connection."""
 
 
-    def __init__(self, c, **kwargs):
+    def __init__(self, c: WorkerConsumerT, **kwargs) -> None:
         c.connection = None
         c.connection = None
         super(Connection, self).__init__(c, **kwargs)
         super(Connection, self).__init__(c, **kwargs)
 
 
-    def start(self, c):
-        c.connection = c.connect()
+    async def start(self, c) -> None:
+        c.connection = await c.connect()
         info('Connected to %s', c.connection.as_uri())
         info('Connected to %s', c.connection.as_uri())
 
 
-    def shutdown(self, c):
+    async def shutdown(self, c) -> None:
         # We must set self.connection to None here, so
         # We must set self.connection to None here, so
         # that the green pidbox thread exits.
         # that the green pidbox thread exits.
         connection, c.connection = c.connection, None
         connection, c.connection = c.connection, None
         if connection:
         if connection:
-            ignore_errors(connection, connection.close)
+            await ignore_errors(connection, connection.close)
 
 
-    def info(self, c):
+    def info(self, c) -> Mapping:
         params = 'N/A'
         params = 'N/A'
         if c.connection:
         if c.connection:
             params = c.connection.info()
             params = c.connection.info()
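The consumer.py diff below reworks rate limiting around kombu's TokenBucket,
which doubles as a holding pen: a request either consumes a token immediately
or is parked with add() and replayed later by _on_bucket_wakeup via pop(). A
hedged sketch of the primitive itself:

    from kombu.utils.limits import TokenBucket
    from celery.utils.time import rate

    bucket = TokenBucket(rate('2/s'), capacity=1)
    bucket.can_consume(1)      # True: a token is available
    bucket.can_consume(1)      # False: drained, caller must wait
    bucket.expected_time(1)    # roughly 0.5s until the next token
    bucket.add('request')      # park work, like _schedule_bucket_request
    bucket.pop()               # replay it, like _on_bucket_wakeup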

+ 113 - 83
celery/worker/consumer/consumer.py

@@ -11,19 +11,24 @@ import os
 
 
 from collections import defaultdict
 from collections import defaultdict
 from time import sleep
 from time import sleep
+from typing import Callable, Mapping, Optional, Tuple
 
 
 from billiard.common import restart_state
 from billiard.common import restart_state
 from billiard.exceptions import RestartFreqExceeded
 from billiard.exceptions import RestartFreqExceeded
 from kombu.async.semaphore import DummyLock
 from kombu.async.semaphore import DummyLock
+from kombu.types import ConnectionT, MessageT
 from kombu.utils.compat import _detect_environment
 from kombu.utils.compat import _detect_environment
 from kombu.utils.encoding import safe_repr
 from kombu.utils.encoding import safe_repr
 from kombu.utils.limits import TokenBucket
 from kombu.utils.limits import TokenBucket
-from vine import ppartial, promise
+from vine import Thenable, ppartial, promise
 
 
 from celery import bootsteps
 from celery import bootsteps
 from celery import signals
 from celery import signals
 from celery.app.trace import build_tracer
 from celery.app.trace import build_tracer
 from celery.exceptions import InvalidTaskError, NotRegistered
 from celery.exceptions import InvalidTaskError, NotRegistered
+from celery.types import (
+    AppT, LoopT, PoolT, RequestT, TaskT, TimerT, WorkerT, WorkerConsumerT,
+)
 from celery.utils.functional import noop
 from celery.utils.functional import noop
 from celery.utils.log import get_logger
 from celery.utils.log import get_logger
 from celery.utils.nodenames import gethostname
 from celery.utils.nodenames import gethostname
@@ -109,7 +114,7 @@ body: {0}
 """
 """
 
 
 
 
-def dump_body(m, body):
+def dump_body(m: MessageT, body: bytes) -> str:
     """Format message body for debugging purposes."""
     """Format message body for debugging purposes."""
     # v2 protocol does not deserialize body
     # v2 protocol does not deserialize body
     body = m.body if body is None else body
     body = m.body if body is None else body
@@ -151,15 +156,25 @@ class Consumer:
             'celery.worker.consumer.agent:Agent',
             'celery.worker.consumer.agent:Agent',
         ]
         ]
 
 
-        def shutdown(self, parent):
-            self.send_all(parent, 'shutdown')
-
-    def __init__(self, on_task_request,
-                 init_callback=noop, hostname=None,
-                 pool=None, app=None,
-                 timer=None, controller=None, hub=None, amqheartbeat=None,
-                 worker_options=None, disable_rate_limits=False,
-                 initial_prefetch_count=2, prefetch_multiplier=1, **kwargs):
+        async def shutdown(self, parent: WorkerConsumerT) -> None:
+            await self.send_all(parent, 'shutdown')
+
+    def __init__(self,
+                 on_task_request: Callable,
+                 *,
+                 init_callback: Callable = noop,
+                 hostname: str = None,
+                 pool: PoolT = None,
+                 app: AppT = None,
+                 timer: TimerT = None,
+                 controller: WorkerT = None,
+                 hub: LoopT = None,
+                 amqheartbeat: float = None,
+                 worker_options: Mapping = None,
+                 disable_rate_limits: bool = False,
+                 initial_prefetch_count: int = 2,
+                 prefetch_multiplier: int = 1,
+                 **kwargs) -> None:
         self.app = app
         self.app = app
         self.controller = controller
         self.controller = controller
         self.init_callback = init_callback
         self.init_callback = init_callback
@@ -213,14 +228,14 @@ class Consumer:
         )
         )
         self.blueprint.apply(self, **dict(worker_options or {}, **kwargs))
         self.blueprint.apply(self, **dict(worker_options or {}, **kwargs))
 
 
-    def call_soon(self, p, *args, **kwargs):
+    def call_soon(self, p: Thenable, *args, **kwargs) -> Thenable:
         p = ppartial(p, *args, **kwargs)
         p = ppartial(p, *args, **kwargs)
         if self.hub:
         if self.hub:
             return self.hub.call_soon(p)
             return self.hub.call_soon(p)
         self._pending_operations.append(p)
         self._pending_operations.append(p)
         return p
         return p
 
 
-    def perform_pending_operations(self):
+    def perform_pending_operations(self) -> None:
         if not self.hub:
         if not self.hub:
             while self._pending_operations:
             while self._pending_operations:
                 try:
                 try:
@@ -228,16 +243,16 @@ class Consumer:
                 except Exception as exc:  # pylint: disable=broad-except
                 except Exception as exc:  # pylint: disable=broad-except
                     logger.exception('Pending callback raised: %r', exc)
                     logger.exception('Pending callback raised: %r', exc)
 
 
-    def bucket_for_task(self, type):
+    def bucket_for_task(self, type: TaskT) -> Optional[TokenBucket]:
         limit = rate(getattr(type, 'rate_limit', None))
         limit = rate(getattr(type, 'rate_limit', None))
         return TokenBucket(limit, capacity=1) if limit else None
         return TokenBucket(limit, capacity=1) if limit else None
 
 
-    def reset_rate_limits(self):
+    def reset_rate_limits(self) -> None:
         self.task_buckets.update(
         self.task_buckets.update(
             (n, self.bucket_for_task(t)) for n, t in self.app.tasks.items()
             (n, self.bucket_for_task(t)) for n, t in self.app.tasks.items()
         )
         )
 
 
-    def _update_prefetch_count(self, index=0):
+    def _update_prefetch_count(self, index: int = 0) -> int:
         """Update prefetch count after pool/shrink grow operations.
         """Update prefetch count after pool/shrink grow operations.
 
 
         Index must be the change in number of processes as a positive
         Index must be the change in number of processes as a positive
@@ -256,33 +271,35 @@ class Consumer:
         )
         )
         return self._update_qos_eventually(index)
         return self._update_qos_eventually(index)
 
 
-    def _update_qos_eventually(self, index):
+    def _update_qos_eventually(self, index: int) -> int:
         return (self.qos.decrement_eventually if index < 0
         return (self.qos.decrement_eventually if index < 0
                 else self.qos.increment_eventually)(
                 else self.qos.increment_eventually)(
             abs(index) * self.prefetch_multiplier)
             abs(index) * self.prefetch_multiplier)
 
 
-    def _limit_move_to_pool(self, request):
+    async def _limit_move_to_pool(self, request: RequestT) -> None:
         task_reserved(request)
         task_reserved(request)
-        self.on_task_request(request)
+        await self.on_task_request(request)
 
 
-    def _on_bucket_wakeup(self, bucket, tokens):
+    async def _on_bucket_wakeup(
+            self, bucket: TokenBucket, tokens: int) -> None:
         try:
         try:
             request = bucket.pop()
             request = bucket.pop()
         except IndexError:
         except IndexError:
             pass
             pass
         else:
         else:
-            self._limit_move_to_pool(request)
+            await self._limit_move_to_pool(request)
             self._schedule_oldest_bucket_request(bucket, tokens)
             self._schedule_oldest_bucket_request(bucket, tokens)
 
 
-    def _schedule_oldest_bucket_request(self, bucket, tokens):
+    def _schedule_oldest_bucket_request(
+            self, bucket: TokenBucket, tokens: int) -> None:
         try:
         try:
             request = bucket.pop()
             request = bucket.pop()
         except IndexError:
         except IndexError:
             pass
             pass
         else:
         else:
-            return self._schedule_bucket_request(request, bucket, tokens)
+            self._schedule_bucket_request(request, bucket, tokens)
 
 
-    def _schedule_bucket_request(self, request, bucket, tokens):
+    def _schedule_bucket_request(
+            self, request: RequestT, bucket: TokenBucket, tokens: int) -> None:
         bucket.can_consume(tokens)
         bucket.can_consume(tokens)
         bucket.add(request)
         bucket.add(request)
         pri = self._limit_order = (self._limit_order + 1) % 10
         pri = self._limit_order = (self._limit_order + 1) % 10
@@ -292,12 +309,16 @@ class Consumer:
             priority=pri,
             priority=pri,
         )
         )
 
 
-    def _limit_task(self, request, bucket, tokens):
+    def _limit_task(
+            self, request: RequestT,
+            bucket: TokenBucket,
+            tokens: int) -> None:
         if bucket.contents:
         if bucket.contents:
-            return bucket.add(request)
-        return self._schedule_bucket_request(request, bucket, tokens)
+            bucket.add(request)
+        else:
+            self._schedule_bucket_request(request, bucket, tokens)
 
 
-    def start(self):
+    async def start(self) -> None:
         blueprint = self.blueprint
         blueprint = self.blueprint
         while blueprint.state != CLOSE:
         while blueprint.state != CLOSE:
             maybe_shutdown()
             maybe_shutdown()
@@ -309,7 +330,7 @@ class Consumer:
                     sleep(1)
                     sleep(1)
             self.restart_count += 1
             self.restart_count += 1
             try:
             try:
-                blueprint.start(self)
+                await blueprint.start(self)
             except self.connection_errors as exc:
             except self.connection_errors as exc:
                 # If we're not retrying connections, no need to catch
                 # If we're not retrying connections, no need to catch
                 # connection errors
                 # connection errors
@@ -324,42 +345,42 @@ class Consumer:
                     else:
                     else:
                         self.on_connection_error_before_connected(exc)
                         self.on_connection_error_before_connected(exc)
                     self.on_close()
                     self.on_close()
-                    blueprint.restart(self)
+                    await blueprint.restart(self)
 
 
-    def on_connection_error_before_connected(self, exc):
+    def on_connection_error_before_connected(self, exc: Exception) -> None:
         error(CONNECTION_ERROR, self.conninfo.as_uri(), exc,
         error(CONNECTION_ERROR, self.conninfo.as_uri(), exc,
               'Trying to reconnect...')
               'Trying to reconnect...')
 
 
-    def on_connection_error_after_connected(self, exc):
+    def on_connection_error_after_connected(self, exc: Exception) -> None:
         warn(CONNECTION_RETRY, exc_info=True)
         warn(CONNECTION_RETRY, exc_info=True)
         try:
         try:
             self.connection.collect()
             self.connection.collect()
         except Exception:  # pylint: disable=broad-except
         except Exception:  # pylint: disable=broad-except
             pass
             pass
 
 
-    def register_with_event_loop(self, hub):
-        self.blueprint.send_all(
+    async def register_with_event_loop(self, hub: LoopT) -> None:
+        await self.blueprint.send_all(
             self, 'register_with_event_loop', args=(hub,),
             self, 'register_with_event_loop', args=(hub,),
             description='Hub.register',
             description='Hub.register',
         )
         )
 
 
-    def shutdown(self):
-        self.blueprint.shutdown(self)
+    async def shutdown(self) -> None:
+        await self.blueprint.shutdown(self)
 
 
-    def stop(self):
-        self.blueprint.stop(self)
+    async def stop(self) -> None:
+        await self.blueprint.stop(self)
 
 
-    def on_ready(self):
+    def on_ready(self) -> None:
         callback, self.init_callback = self.init_callback, None
         callback, self.init_callback = self.init_callback, None
         if callback:
         if callback:
             callback(self)
             callback(self)
 
 
-    def loop_args(self):
+    def loop_args(self) -> Tuple:
         return (self, self.connection, self.task_consumer,
         return (self, self.connection, self.task_consumer,
                 self.blueprint, self.hub, self.qos, self.amqheartbeat,
                 self.blueprint, self.hub, self.qos, self.amqheartbeat,
                 self.app.clock, self.amqheartbeat_rate)
                 self.app.clock, self.amqheartbeat_rate)
 
 
-    def on_decode_error(self, message, exc):
+    async def on_decode_error(self, message: MessageT, exc: Exception) -> None:
         """Callback called if an error occurs while decoding a message.
         """Callback called if an error occurs while decoding a message.
 
 
         Simply logs the error and acknowledges the message so it
         Simply logs the error and acknowledges the message so it
@@ -373,9 +394,9 @@ class Consumer:
              exc, message.content_type, message.content_encoding,
              exc, message.content_type, message.content_encoding,
              safe_repr(message.headers), dump_body(message, message.body),
              safe_repr(message.headers), dump_body(message, message.body),
              exc_info=1)
              exc_info=1)
-        message.ack()
+        await message.ack()
 
 
-    def on_close(self):
+    def on_close(self) -> None:
         # Clear internal queues to get rid of old messages.
         # Clear internal queues to get rid of old messages.
         # They can't be acked anyway, as a delivery tag is specific
         # They can't be acked anyway, as a delivery tag is specific
         # to the current channel.
         # to the current channel.
@@ -390,26 +411,29 @@ class Consumer:
         if self.pool and self.pool.flush:
         if self.pool and self.pool.flush:
             self.pool.flush()
             self.pool.flush()
 
 
-    def connect(self):
+    async def connect(self) -> ConnectionT:
         """Establish the broker connection used for consuming tasks.
         """Establish the broker connection used for consuming tasks.
 
 
         Retries establishing the connection if the
         Retries establishing the connection if the
         :setting:`broker_connection_retry` setting is enabled
         :setting:`broker_connection_retry` setting is enabled
         """
         """
-        conn = self.connection_for_read(heartbeat=self.amqheartbeat)
+        conn = await self.connection_for_read(heartbeat=self.amqheartbeat)
+        await conn.connect()
         if self.hub:
         if self.hub:
             conn.transport.register_with_event_loop(conn.connection, self.hub)
             conn.transport.register_with_event_loop(conn.connection, self.hub)
         return conn
         return conn
 
 
-    def connection_for_read(self, heartbeat=None):
-        return self.ensure_connected(
+    async def connection_for_read(self,
+                                  heartbeat: float = None) -> ConnectionT:
+        return await self.ensure_connected(
             self.app.connection_for_read(heartbeat=heartbeat))
             self.app.connection_for_read(heartbeat=heartbeat))
 
 
-    def connection_for_write(self, heartbeat=None):
-        return self.ensure_connected(
+    async def connection_for_write(self,
+                                   heartbeat: float = None) -> ConnectionT:
+        return await self.ensure_connected(
             self.app.connection_for_write(heartbeat=heartbeat))
             self.app.connection_for_write(heartbeat=heartbeat))
 
 
-    def ensure_connected(self, conn):
+    async def ensure_connected(self, conn: ConnectionT) -> ConnectionT:
         # Callback called for each retry while the connection
         # Callback called for each retry while the connection
         # can't be established.
         # can't be established.
         def _error_handler(exc, interval, next_step=CONNECTION_RETRY_STEP):
         def _error_handler(exc, interval, next_step=CONNECTION_RETRY_STEP):
@@ -422,25 +446,28 @@ class Consumer:
         # until needed.
         # until needed.
         if not self.app.conf.broker_connection_retry:
         if not self.app.conf.broker_connection_retry:
             # retry disabled, just call connect directly.
             # retry disabled, just call connect directly.
-            conn.connect()
+            await conn.connect()
             return conn
             return conn
 
 
-        conn = conn.ensure_connection(
+        conn = await conn.ensure_connection(
             _error_handler, self.app.conf.broker_connection_max_retries,
             _error_handler, self.app.conf.broker_connection_max_retries,
             callback=maybe_shutdown,
             callback=maybe_shutdown,
         )
         )
         return conn
         return conn
 
 
-    def _flush_events(self):
+    def _flush_events(self) -> None:
         if self.event_dispatcher:
         if self.event_dispatcher:
             self.event_dispatcher.flush()
             self.event_dispatcher.flush()
 
 
-    def on_send_event_buffered(self):
+    def on_send_event_buffered(self) -> None:
         if self.hub:
         if self.hub:
             self.hub._ready.add(self._flush_events)
             self.hub._ready.add(self._flush_events)
 
 
-    def add_task_queue(self, queue, exchange=None, exchange_type=None,
-                       routing_key=None, **options):
+    async def add_task_queue(self, queue: str,
+                             exchange: str = None,
+                             exchange_type: str = None,
+                             routing_key: str = None,
+                             **options) -> None:
         cset = self.task_consumer
         cset = self.task_consumer
         queues = self.app.amqp.queues
         queues = self.app.amqp.queues
         # Must use in' here, as __missing__ will automatically
         # Must use in' here, as __missing__ will automatically
@@ -458,33 +485,34 @@ class Consumer:
                                   routing_key=routing_key, **options)
                                   routing_key=routing_key, **options)
         if not cset.consuming_from(queue):
         if not cset.consuming_from(queue):
             cset.add_queue(q)
             cset.add_queue(q)
-            cset.consume()
+            await cset.consume()
             info('Started consuming from %s', queue)
             info('Started consuming from %s', queue)
 
 
-    def cancel_task_queue(self, queue):
+    async def cancel_task_queue(self, queue):
         info('Canceling queue %s', queue)
         info('Canceling queue %s', queue)
         self.app.amqp.queues.deselect(queue)
         self.app.amqp.queues.deselect(queue)
-        self.task_consumer.cancel_by_queue(queue)
+        await self.task_consumer.cancel_by_queue(queue)
 
 
-    def apply_eta_task(self, task):
+    def apply_eta_task(self, request: RequestT) -> None:
         """Method called by the timer to apply a task with an ETA/countdown."""
         """Method called by the timer to apply a task with an ETA/countdown."""
-        task_reserved(task)
-        self.on_task_request(task)
+        task_reserved(request)
+        self.on_task_request(request)
         self.qos.decrement_eventually()
         self.qos.decrement_eventually()
 
 
-    def _message_report(self, body, message):
+    def _message_report(self, body: bytes, message: MessageT) -> str:
         return MESSAGE_REPORT.format(dump_body(message, body),
         return MESSAGE_REPORT.format(dump_body(message, body),
                                      safe_repr(message.content_type),
                                      safe_repr(message.content_type),
                                      safe_repr(message.content_encoding),
                                      safe_repr(message.content_encoding),
                                      safe_repr(message.delivery_info),
                                      safe_repr(message.delivery_info),
                                      safe_repr(message.headers))
                                      safe_repr(message.headers))
 
 
-    def on_unknown_message(self, body, message):
+    async def on_unknown_message(self, body: bytes, message: MessageT) -> None:
         warn(UNKNOWN_FORMAT, self._message_report(body, message))
         warn(UNKNOWN_FORMAT, self._message_report(body, message))
-        message.reject_log_error(logger, self.connection_errors)
+        await message.reject_log_error(logger, self.connection_errors)
         signals.task_rejected.send(sender=self, message=message, exc=None)
         signals.task_rejected.send(sender=self, message=message, exc=None)
 
 
-    def on_unknown_task(self, body, message, exc):
+    async def on_unknown_task(
+            self, body: bytes, message: MessageT, exc: Exception) -> None:
         error(UNKNOWN_TASK_ERROR, exc, dump_body(message, body), exc_info=True)
         error(UNKNOWN_TASK_ERROR, exc, dump_body(message, body), exc_info=True)
         try:
         try:
             id_, name = message.headers['id'], message.headers['task']
             id_, name = message.headers['id'], message.headers['task']
@@ -499,12 +527,12 @@ class Consumer:
             reply_to=message.properties.get('reply_to'),
             reply_to=message.properties.get('reply_to'),
             errbacks=None,
             errbacks=None,
         )
         )
-        message.reject_log_error(logger, self.connection_errors)
-        self.app.backend.mark_as_failure(
+        await message.reject_log_error(logger, self.connection_errors)
+        await self.app.backend.mark_as_failure(
             id_, NotRegistered(name), request=request,
             id_, NotRegistered(name), request=request,
         )
         )
         if self.event_dispatcher:
         if self.event_dispatcher:
-            self.event_dispatcher.send(
+            await self.event_dispatcher.send(
                 'task-failed', uuid=id_,
                 'task-failed', uuid=id_,
                 exception='NotRegistered({0!r})'.format(name),
                 exception='NotRegistered({0!r})'.format(name),
             )
             )
@@ -512,27 +540,29 @@ class Consumer:
             sender=self, message=message, exc=exc, name=name, id=id_,
             sender=self, message=message, exc=exc, name=name, id=id_,
         )
         )
 
 
-    def on_invalid_task(self, body, message, exc):
+    async def on_invalid_task(
+            self, body: bytes, message: MessageT, exc: Exception) -> None:
         error(INVALID_TASK_ERROR, exc, dump_body(message, body), exc_info=True)
         error(INVALID_TASK_ERROR, exc, dump_body(message, body), exc_info=True)
-        message.reject_log_error(logger, self.connection_errors)
+        await message.reject_log_error(logger, self.connection_errors)
         signals.task_rejected.send(sender=self, message=message, exc=exc)
         signals.task_rejected.send(sender=self, message=message, exc=exc)
 
 
-    def update_strategies(self):
+    def update_strategies(self) -> None:
         loader = self.app.loader
         loader = self.app.loader
         for name, task in self.app.tasks.items():
         for name, task in self.app.tasks.items():
             self.strategies[name] = task.start_strategy(self.app, self)
             self.strategies[name] = task.start_strategy(self.app, self)
             task.__trace__ = build_tracer(name, task, loader, self.hostname,
             task.__trace__ = build_tracer(name, task, loader, self.hostname,
                                           app=self.app)
                                           app=self.app)
 
 
-    def create_task_handler(self, promise=promise):
+    def create_task_handler(self) -> Callable:
         strategies = self.strategies
         strategies = self.strategies
         on_unknown_message = self.on_unknown_message
         on_unknown_message = self.on_unknown_message
         on_unknown_task = self.on_unknown_task
         on_unknown_task = self.on_unknown_task
         on_invalid_task = self.on_invalid_task
         on_invalid_task = self.on_invalid_task
         callbacks = self.on_task_message
         callbacks = self.on_task_message
         call_soon = self.call_soon
         call_soon = self.call_soon
+        promise_t = promise
 
 
-        def on_task_received(message):
+        async def on_task_received(message: MessageT) -> None:
             # payload will only be set for v1 protocol, since v2
             # payload will only be set for v1 protocol, since v2
             # will defer deserializing the message body to the pool.
             # will defer deserializing the message body to the pool.
             payload = None
             payload = None
@@ -544,29 +574,29 @@ class Consumer:
                 try:
                 try:
                     payload = message.decode()
                     payload = message.decode()
                 except Exception as exc:  # pylint: disable=broad-except
                 except Exception as exc:  # pylint: disable=broad-except
-                    return self.on_decode_error(message, exc)
+                    return await self.on_decode_error(message, exc)
                 try:
                 try:
                     type_, payload = payload['task'], payload  # protocol v1
                     type_, payload = payload['task'], payload  # protocol v1
                 except (TypeError, KeyError):
                 except (TypeError, KeyError):
-                    return on_unknown_message(payload, message)
+                    return await on_unknown_message(payload, message)
             try:
             try:
                 strategy = strategies[type_]
                 strategy = strategies[type_]
             except KeyError as exc:
             except KeyError as exc:
-                return on_unknown_task(None, message, exc)
+                return await on_unknown_task(None, message, exc)
             else:
             else:
                 try:
                 try:
-                    strategy(
+                    await strategy(
                         message, payload,
                         message, payload,
-                        promise(call_soon, (message.ack_log_error,)),
-                        promise(call_soon, (message.reject_log_error,)),
+                        promise_t(call_soon, (message.ack_log_error,)),
+                        promise_t(call_soon, (message.reject_log_error,)),
                         callbacks,
                         callbacks,
                     )
                     )
                 except InvalidTaskError as exc:
                 except InvalidTaskError as exc:
-                    return on_invalid_task(payload, message, exc)
+                    return await on_invalid_task(payload, message, exc)
 
 
         return on_task_received
         return on_task_received
 
 
-    def __repr__(self):
+    def __repr__(self) -> str:
         """``repr(self)``."""
         """``repr(self)``."""
         return '<Consumer: {self.hostname} ({state})>'.format(
         return '<Consumer: {self.hostname} ({state})>'.format(
             self=self, state=self.blueprint.human_state(),
             self=self, state=self.blueprint.human_state(),
@@ -583,9 +613,9 @@ class Evloop(bootsteps.StartStopStep):
     label = 'event loop'
     label = 'event loop'
     last = True
     last = True
 
 
-    def start(self, c):
+    async def start(self, c: WorkerConsumerT) -> None:
         self.patch_all(c)
         self.patch_all(c)
         c.loop(*c.loop_args())
         c.loop(*c.loop_args())
 
 
-    def patch_all(self, c):
+    def patch_all(self, c: WorkerConsumerT) -> None:
         c.qos._mutex = DummyLock()
         c.qos._mutex = DummyLock()

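The net effect of the consumer changes above is that every message callback is now a coroutine: decode errors, unknown messages and the task fast path all ack or reject asynchronously, and protocol v2 still avoids decoding the body on the hot path. A toy, self-contained sketch of that dispatch order — FakeMessage and handle_add are stand-ins invented for illustration, not Celery or kombu classes:

    import asyncio

    class FakeMessage:
        # Minimal stand-in for a kombu message.
        def __init__(self, headers, payload=None):
            self.headers = headers
            self._payload = payload
            self.acked = False

        def decode(self):
            return self._payload

        async def ack(self):
            self.acked = True

    async def on_task_received(message, strategies):
        try:
            type_ = message.headers['task']    # protocol v2: name in headers
        except (TypeError, KeyError):
            type_ = message.decode()['task']   # protocol v1: decode the body
        await strategies[type_](message)

    async def handle_add(message):
        await message.ack()                    # acking is awaited now, too

    asyncio.run(on_task_received(
        FakeMessage({'task': 'tasks.add'}), {'tasks.add': handle_add}))
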
+ 3 - 2
celery/worker/consumer/control.py

@@ -6,6 +6,7 @@ The actual commands are implemented in :mod:`celery.worker.control`.
 """
 """
 from celery import bootsteps
 from celery import bootsteps
 from celery.utils.log import get_logger
 from celery.utils.log import get_logger
+from celery.types import WorkerConsumerT
 from celery.worker import pidbox
 from celery.worker import pidbox
 from .tasks import Tasks
 from .tasks import Tasks
 
 
@@ -19,7 +20,7 @@ class Control(bootsteps.StartStopStep):
 
 
     requires = (Tasks,)
     requires = (Tasks,)
 
 
-    def __init__(self, c, **kwargs):
+    def __init__(self, c: WorkerConsumerT, **kwargs) -> None:
         self.is_green = c.pool is not None and c.pool.is_green
         self.is_green = c.pool is not None and c.pool.is_green
         self.box = (pidbox.gPidbox if self.is_green else pidbox.Pidbox)(c)
         self.box = (pidbox.gPidbox if self.is_green else pidbox.Pidbox)(c)
         self.start = self.box.start
         self.start = self.box.start
@@ -27,6 +28,6 @@ class Control(bootsteps.StartStopStep):
         self.shutdown = self.box.shutdown
         self.shutdown = self.box.shutdown
         super(Control, self).__init__(c, **kwargs)
         super(Control, self).__init__(c, **kwargs)
 
 
-    def include_if(self, c):
+    def include_if(self, c: WorkerConsumerT) -> bool:
         return (c.app.conf.worker_enable_remote_control and
         return (c.app.conf.worker_enable_remote_control and
                 c.conninfo.supports_exchange_type('fanout'))
                 c.conninfo.supports_exchange_type('fanout'))

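Control is gated at build time: include_if() keeps the whole pidbox step out of the blueprint when remote control is disabled or the transport can't do fanout. A minimal sketch of the same pattern for a custom step (the class name is hypothetical; include_if is the real bootstep hook):

    from celery import bootsteps

    class FanoutOnlyStep(bootsteps.StartStopStep):
        """Skipped entirely on transports without fanout (e.g. SQS)."""

        def include_if(self, c):
            # Same test Control uses above.
            return c.conninfo.supports_exchange_type('fanout')
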
+ 14 - 13
celery/worker/consumer/events.py

@@ -4,6 +4,7 @@
 """
 """
 from kombu.common import ignore_errors
 from kombu.common import ignore_errors
 from celery import bootsteps
 from celery import bootsteps
+from celery.types import WorkerConsumerT
 from .connection import Connection
 from .connection import Connection
 
 
 __all__ = ['Events']
 __all__ = ['Events']
@@ -14,11 +15,11 @@ class Events(bootsteps.StartStopStep):
 
 
     requires = (Connection,)
     requires = (Connection,)
 
 
-    def __init__(self, c,
-                 task_events=True,
-                 without_heartbeat=False,
-                 without_gossip=False,
-                 **kwargs):
+    def __init__(self, c: WorkerConsumerT,
+                 task_events: bool = True,
+                 without_heartbeat: bool = False,
+                 without_gossip: bool = False,
+                 **kwargs) -> None:
         self.groups = None if task_events else ['worker']
         self.groups = None if task_events else ['worker']
         self.send_events = (
         self.send_events = (
             task_events or
             task_events or
@@ -28,7 +29,7 @@ class Events(bootsteps.StartStopStep):
         c.event_dispatcher = None
         c.event_dispatcher = None
         super(Events, self).__init__(c, **kwargs)
         super(Events, self).__init__(c, **kwargs)
 
 
-    def start(self, c):
+    async def start(self, c: WorkerConsumerT) -> None:
         # flush events sent while connection was down.
         # flush events sent while connection was down.
         prev = self._close(c)
         prev = self._close(c)
         dis = c.event_dispatcher = c.app.events.Dispatcher(
         dis = c.event_dispatcher = c.app.events.Dispatcher(
@@ -45,10 +46,13 @@ class Events(bootsteps.StartStopStep):
             dis.extend_buffer(prev)
             dis.extend_buffer(prev)
             dis.flush()
             dis.flush()
 
 
-    def stop(self, c):
-        pass
+    async def stop(self, c: WorkerConsumerT) -> None:
+        ...
 
 
-    def _close(self, c):
+    async def shutdown(self, c: WorkerConsumerT) -> None:
+        await self._close(c)
+
+    async def _close(self, c: WorkerConsumerT) -> None:
         if c.event_dispatcher:
         if c.event_dispatcher:
             dispatcher = c.event_dispatcher
             dispatcher = c.event_dispatcher
             # remember changes from remote control commands:
             # remember changes from remote control commands:
@@ -56,10 +60,7 @@ class Events(bootsteps.StartStopStep):
 
 
             # close custom connection
             # close custom connection
             if dispatcher.connection:
             if dispatcher.connection:
-                ignore_errors(c, dispatcher.connection.close)
+                await ignore_errors(c, dispatcher.connection.close)
             ignore_errors(c, dispatcher.close)
             ignore_errors(c, dispatcher.close)
             c.event_dispatcher = None
             c.event_dispatcher = None
             return dispatcher
             return dispatcher
-
-    def shutdown(self, c):
-        self._close(c)

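Note how Events now splits teardown in two: stop() became a no-op so the buffered dispatcher survives connection restarts, while shutdown() does the real close exactly once at worker exit. A hedged sketch of that three-phase contract, assuming an async-aware bootstep base as in this branch (my_resource is a made-up attribute):

    from celery import bootsteps

    class BufferedStep(bootsteps.StartStopStep):
        async def start(self, c):
            c.my_resource = []      # re-created on every (re)connection

        async def stop(self, c):
            ...                     # connection lost: keep the buffer around

        async def shutdown(self, c):
            c.my_resource = None    # worker exiting: release for real
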
+ 48 - 38
celery/worker/consumer/gossip.py

@@ -3,11 +3,14 @@ from collections import defaultdict
 from functools import partial
 from heapq import heappush
 from operator import itemgetter
+from typing import Callable, Set, Sequence
 
 from kombu import Consumer
 from kombu.async.semaphore import DummyLock
+from kombu.types import ChannelT, ConsumerT, MessageT
 
 from celery import bootsteps
+from celery.types import AppT, EventT, SignatureT, WorkerT, WorkerConsumerT
 from celery.utils.log import get_logger
 from celery.utils.objects import Bunch
 
@@ -29,10 +32,13 @@ class Gossip(bootsteps.ConsumerStep):
     _cons_stamp_fields = itemgetter(
         'id', 'clock', 'hostname', 'pid', 'topic', 'action', 'cver',
     )
-    compatible_transports = {'amqp', 'redis'}
+    compatible_transports: Set[str] = {'amqp', 'redis'}
 
-    def __init__(self, c, without_gossip=False,
-                 interval=5.0, heartbeat_interval=2.0, **kwargs):
+    def __init__(self, c: WorkerConsumerT,
+                 without_gossip: bool = False,
+                 interval: float = 5.0,
+                 heartbeat_interval: float = 2.0,
+                 **kwargs) -> None:
         self.enabled = not without_gossip and self.compatible_transport(c.app)
         self.app = c.app
         c.gossip = self
@@ -72,40 +78,41 @@ class Gossip(bootsteps.ConsumerStep):
 
         super(Gossip, self).__init__(c, **kwargs)
 
-    def compatible_transport(self, app):
+    def compatible_transport(self, app: AppT) -> bool:
         with app.connection_for_read() as conn:
             return conn.transport.driver_type in self.compatible_transports
 
-    def election(self, id, topic, action=None):
+    async def election(self, id: str, topic: str, action: str = None):
         self.consensus_replies[id] = []
-        self.dispatcher.send(
+        await self.dispatcher.send(
             'worker-elect',
             id=id, topic=topic, action=action, cver=1,
         )
 
-    def call_task(self, task):
+    async def call_task(self, task: SignatureT) -> None:
         try:
             self.app.signature(task).apply_async()
         except Exception as exc:  # pylint: disable=broad-except
             logger.exception('Could not call task: %r', exc)
 
-    def on_elect(self, event):
+    async def on_elect(self, event: EventT) -> None:
         try:
             (id_, clock, hostname, pid,
              topic, action, _) = self._cons_stamp_fields(event)
         except KeyError as exc:
-            return logger.exception('election request missing field %s', exc)
-        heappush(
-            self.consensus_requests[id_],
-            (clock, '%s.%s' % (hostname, pid), topic, action),
-        )
-        self.dispatcher.send('worker-elect-ack', id=id_)
+            logger.exception('election request missing field %s', exc)
+        else:
+            heappush(
+                self.consensus_requests[id_],
+                (clock, '%s.%s' % (hostname, pid), topic, action),
+            )
+            await self.dispatcher.send('worker-elect-ack', id=id_)
 
-    def start(self, c):
+    async def start(self, c: WorkerConsumerT) -> None:
         super().start(c)
         self.dispatcher = c.event_dispatcher
 
-    def on_elect_ack(self, event):
+    async def on_elect_ack(self, event: EventT) -> None:
         id = event['id']
         try:
             replies = self.consensus_replies[id]
@@ -125,59 +132,62 @@ class Gossip(bootsteps.ConsumerStep):
                 except KeyError:
                     logger.exception('Unknown election topic %r', topic)
                 else:
-                    handler(action)
+                    await handler(action)
             else:
                 info('node %s elected for %r', leader, id)
             self.consensus_requests.pop(id, None)
             self.consensus_replies.pop(id, None)
 
-    def on_node_join(self, worker):
+    async def on_node_join(self, worker: WorkerT):
         debug('%s joined the party', worker.hostname)
-        self._call_handlers(self.on.node_join, worker)
+        await self._call_handlers(self.on.node_join, worker)
 
-    def on_node_leave(self, worker):
+    async def on_node_leave(self, worker: WorkerT):
         debug('%s left', worker.hostname)
-        self._call_handlers(self.on.node_leave, worker)
+        await self._call_handlers(self.on.node_leave, worker)
 
-    def on_node_lost(self, worker):
+    async def on_node_lost(self, worker: WorkerT):
         info('missed heartbeat from %s', worker.hostname)
-        self._call_handlers(self.on.node_lost, worker)
+        await self._call_handlers(self.on.node_lost, worker)
 
-    def _call_handlers(self, handlers, *args, **kwargs):
+    async def _call_handlers(self, handlers: Sequence[Callable],
+                             *args, **kwargs) -> None:
         for handler in handlers:
             try:
-                handler(*args, **kwargs)
+                await handler(*args, **kwargs)
             except Exception as exc:  # pylint: disable=broad-except
                 logger.exception(
                     'Ignored error from handler %r: %r', handler, exc)
 
-    def register_timer(self):
+    def register_timer(self) -> None:
         if self._tref is not None:
             self._tref.cancel()
         self._tref = self.timer.call_repeatedly(self.interval, self.periodic)
 
-    def periodic(self):
+    async def periodic(self) -> None:
         workers = self.state.workers
         dirty = set()
         for worker in workers.values():
             if not worker.alive:
                 dirty.add(worker)
-                self.on_node_lost(worker)
+                await self.on_node_lost(worker)
         for worker in dirty:
             workers.pop(worker.hostname, None)
 
-    def get_consumers(self, channel):
+    def get_consumers(self, channel: ChannelT) -> Sequence[ConsumerT]:
         self.register_timer()
         ev = self.Receiver(channel, routing_key='worker.#',
                            queue_ttl=self.heartbeat_interval)
-        return [Consumer(
-            channel,
-            queues=[ev.queue],
-            on_message=partial(self.on_message, ev.event_from_message),
-            no_ack=True
-        )]
-
-    def on_message(self, prepare, message):
+        return [
+            Consumer(
+                channel,
+                queues=[ev.queue],
+                on_message=partial(self.on_message, ev.event_from_message),
+                no_ack=True
+            )
+        ]
+
+    async def on_message(self, prepare: Callable, message: MessageT) -> None:
         _type = message.delivery_info['routing_key']
 
         # For redis when `fanout_patterns=False` (See Issue #1882)
@@ -188,7 +198,7 @@ class Gossip(bootsteps.ConsumerStep):
         except KeyError:
             pass
         else:
-            return handler(message.payload)
+            return await handler(message.payload)
 
         # proto2: hostname in header; proto1: in body
         hostname = (message.headers.get('hostname') or

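The election bookkeeping above relies purely on tuple ordering: every candidate is heappushed as (clock, 'hostname.pid', topic, action), so the reply with the lowest logical clock sits at heap[0], with ties broken lexically by the hostname.pid string. A quick runnable illustration (hostnames invented):

    from heapq import heappush

    requests = []
    for clock, ident in [(12, 'worker2.4021'), (9, 'worker1.77'), (9, 'worker0.13')]:
        heappush(requests, (clock, ident, 'replicate', 'start'))

    clock, leader, topic, action = requests[0]  # heap[0] is the smallest tuple
    print(leader)  # 'worker0.13': lowest clock wins, ties broken by host.pid
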
+ 11 - 6
celery/worker/consumer/heart.py

@@ -1,5 +1,6 @@
 """Worker Event Heartbeat Bootstep."""
 """Worker Event Heartbeat Bootstep."""
 from celery import bootsteps
 from celery import bootsteps
+from celery.types import WorkerConsumerT
 from celery.worker import heartbeat
 from celery.worker import heartbeat
 from .events import Events
 from .events import Events
 
 
@@ -17,19 +18,23 @@ class Heart(bootsteps.StartStopStep):
 
 
     requires = (Events,)
     requires = (Events,)
 
 
-    def __init__(self, c,
-                 without_heartbeat=False, heartbeat_interval=None, **kwargs):
+    def __init__(self, c: WorkerConsumerT,
+                 without_heartbeat: bool = False,
+                 heartbeat_interval: float = None,
+                 **kwargs) -> None:
         self.enabled = not without_heartbeat
         self.enabled = not without_heartbeat
         self.heartbeat_interval = heartbeat_interval
         self.heartbeat_interval = heartbeat_interval
         c.heart = None
         c.heart = None
         super(Heart, self).__init__(c, **kwargs)
         super(Heart, self).__init__(c, **kwargs)
 
 
-    def start(self, c):
+    async def start(self, c: WorkerConsumerT) -> None:
         c.heart = heartbeat.Heart(
         c.heart = heartbeat.Heart(
             c.timer, c.event_dispatcher, self.heartbeat_interval,
             c.timer, c.event_dispatcher, self.heartbeat_interval,
         )
         )
-        c.heart.start()
+        await c.heart.start()
 
 
-    def stop(self, c):
-        c.heart = c.heart and c.heart.stop()
+    async def stop(self, c: WorkerConsumerT) -> None:
+        heart, c.heart = c.heart, None
+        if heart:
+            await heart.stop()
     shutdown = stop
     shutdown = stop

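The rewritten Heart.stop() detaches c.heart before awaiting, so a racing stop()/shutdown() pair can only stop the heart once. The same swap-then-await idiom in isolation (all names here are illustrative):

    import asyncio

    class Beater:
        stops = 0

        async def stop(self):
            Beater.stops += 1

    class Owner:
        heart = None

    async def stop(owner):
        heart, owner.heart = owner.heart, None  # detach before any await
        if heart:
            await heart.stop()

    async def main():
        owner = Owner()
        owner.heart = Beater()
        await asyncio.gather(stop(owner), stop(owner))
        print(Beater.stops)  # 1

    asyncio.run(main())
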
+ 22 - 13
celery/worker/consumer/mingle.py

@@ -1,5 +1,7 @@
 """Worker <-> Worker Sync at startup (Bootstep)."""
 """Worker <-> Worker Sync at startup (Bootstep)."""
+from typing import Mapping
 from celery import bootsteps
 from celery import bootsteps
+from celery.types import AppT, WorkerConsumerT
 from celery.utils.log import get_logger
 from celery.utils.log import get_logger
 from .events import Events
 from .events import Events
 
 
@@ -23,53 +25,60 @@ class Mingle(bootsteps.StartStopStep):
     requires = (Events,)
     requires = (Events,)
     compatible_transports = {'amqp', 'redis'}
     compatible_transports = {'amqp', 'redis'}
 
 
-    def __init__(self, c, without_mingle=False, **kwargs):
+    def __init__(self, c: WorkerConsumerT,
+                 without_mingle: bool = False,
+                 **kwargs) -> None:
         self.enabled = not without_mingle and self.compatible_transport(c.app)
         self.enabled = not without_mingle and self.compatible_transport(c.app)
         super(Mingle, self).__init__(
         super(Mingle, self).__init__(
             c, without_mingle=without_mingle, **kwargs)
             c, without_mingle=without_mingle, **kwargs)
 
 
-    def compatible_transport(self, app):
+    def compatible_transport(self, app: AppT) -> bool:
         with app.connection_for_read() as conn:
         with app.connection_for_read() as conn:
             return conn.transport.driver_type in self.compatible_transports
             return conn.transport.driver_type in self.compatible_transports
 
 
-    def start(self, c):
-        self.sync(c)
+    async def start(self, c: WorkerConsumerT) -> None:
+        await self.sync(c)
 
 
-    def sync(self, c):
+    async def sync(self, c: WorkerConsumerT) -> None:
         info('mingle: searching for neighbors')
         info('mingle: searching for neighbors')
-        replies = self.send_hello(c)
+        replies = await self.send_hello(c)
         if replies:
         if replies:
             info('mingle: sync with %s nodes',
             info('mingle: sync with %s nodes',
                  len([reply for reply, value in replies.items() if value]))
                  len([reply for reply, value in replies.items() if value]))
-            [self.on_node_reply(c, nodename, reply)
+            [await self.on_node_reply(c, nodename, reply)
              for nodename, reply in replies.items() if reply]
              for nodename, reply in replies.items() if reply]
             info('mingle: sync complete')
             info('mingle: sync complete')
         else:
         else:
             info('mingle: all alone')
             info('mingle: all alone')
 
 
-    def send_hello(self, c):
+    async def send_hello(self, c: WorkerConsumerT) -> Mapping:
         inspect = c.app.control.inspect(timeout=1.0, connection=c.connection)
         inspect = c.app.control.inspect(timeout=1.0, connection=c.connection)
         our_revoked = c.controller.state.revoked
         our_revoked = c.controller.state.revoked
         replies = inspect.hello(c.hostname, our_revoked._data) or {}
         replies = inspect.hello(c.hostname, our_revoked._data) or {}
         replies.pop(c.hostname, None)  # delete my own response
         replies.pop(c.hostname, None)  # delete my own response
         return replies
         return replies
 
 
-    def on_node_reply(self, c, nodename, reply):
+    async def on_node_reply(self, c: WorkerConsumerT,
+                            nodename: str, reply: Mapping) -> None:
         debug('mingle: processing reply from %s', nodename)
         debug('mingle: processing reply from %s', nodename)
         try:
         try:
-            self.sync_with_node(c, **reply)
+            await self.sync_with_node(c, **reply)
         except MemoryError:
         except MemoryError:
             raise
             raise
         except Exception as exc:  # pylint: disable=broad-except
         except Exception as exc:  # pylint: disable=broad-except
             exception('mingle: sync with %s failed: %r', nodename, exc)
             exception('mingle: sync with %s failed: %r', nodename, exc)
 
 
-    def sync_with_node(self, c, clock=None, revoked=None, **kwargs):
+    async def sync_with_node(self, c: WorkerConsumerT,
+                             clock: int = None,
+                             revoked: Mapping = None,
+                             **kwargs) -> None:
         self.on_clock_event(c, clock)
         self.on_clock_event(c, clock)
         self.on_revoked_received(c, revoked)
         self.on_revoked_received(c, revoked)
 
 
-    def on_clock_event(self, c, clock):
+    def on_clock_event(self, c: WorkerConsumerT, clock: int):
         c.app.clock.adjust(clock) if clock else c.app.clock.forward()
         c.app.clock.adjust(clock) if clock else c.app.clock.forward()
 
 
-    def on_revoked_received(self, c, revoked):
+    def on_revoked_received(self,
+                            c: WorkerConsumerT, revoked: Mapping) -> None:
         if revoked:
         if revoked:
             c.controller.state.revoked.update(revoked)
             c.controller.state.revoked.update(revoked)

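Mingle's sync is just the broadcast hello command seen from the other side: each reply carries the peer's logical clock and revoked-task set. Roughly the same call from application code, assuming a reachable broker and at least one running worker (broker URL and node name are placeholders):

    from celery import Celery

    app = Celery(broker='amqp://guest@localhost//')  # placeholder broker

    replies = app.control.inspect(timeout=1.0).hello('me@host', revoked={}) or {}
    replies.pop('me@host', None)  # drop our own reply, like Mingle does
    for nodename, reply in replies.items():
        if reply:
            print(nodename, reply['clock'], len(reply['revoked']))
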
+ 9 - 7
celery/worker/consumer/tasks.py

@@ -1,6 +1,8 @@
 """Worker Task Consumer Bootstep."""
 """Worker Task Consumer Bootstep."""
+from typing import Mapping
 from kombu.common import QoS, ignore_errors
 from kombu.common import QoS, ignore_errors
 from celery import bootsteps
 from celery import bootsteps
+from celery.types import WorkerConsumerT
 from celery.utils.log import get_logger
 from celery.utils.log import get_logger
 from .mingle import Mingle
 from .mingle import Mingle
 
 
@@ -14,11 +16,11 @@ class Tasks(bootsteps.StartStopStep):
 
 
     requires = (Mingle,)
     requires = (Mingle,)
 
 
-    def __init__(self, c, **kwargs):
+    def __init__(self, c: WorkerConsumerT, **kwargs) -> None:
         c.task_consumer = c.qos = None
         c.task_consumer = c.qos = None
         super(Tasks, self).__init__(c, **kwargs)
         super(Tasks, self).__init__(c, **kwargs)
 
 
-    def start(self, c):
+    async def start(self, c: WorkerConsumerT) -> None:
         """Start task consumer."""
         """Start task consumer."""
         c.update_strategies()
         c.update_strategies()
 
 
@@ -36,20 +38,20 @@ class Tasks(bootsteps.StartStopStep):
             c.connection, on_decode_error=c.on_decode_error,
             c.connection, on_decode_error=c.on_decode_error,
         )
         )
 
 
-        def set_prefetch_count(prefetch_count):
-            return c.task_consumer.qos(
+        async def set_prefetch_count(prefetch_count: int) -> None:
+            await c.task_consumer.qos(
                 prefetch_count=prefetch_count,
                 prefetch_count=prefetch_count,
                 apply_global=qos_global,
                 apply_global=qos_global,
             )
             )
         c.qos = QoS(set_prefetch_count, c.initial_prefetch_count)
         c.qos = QoS(set_prefetch_count, c.initial_prefetch_count)
 
 
-    def stop(self, c):
+    async def stop(self, c: WorkerConsumerT) -> None:
         """Stop task consumer."""
         """Stop task consumer."""
         if c.task_consumer:
         if c.task_consumer:
             debug('Canceling task consumer...')
             debug('Canceling task consumer...')
             ignore_errors(c, c.task_consumer.cancel)
             ignore_errors(c, c.task_consumer.cancel)
 
 
-    def shutdown(self, c):
+    async def shutdown(self, c: WorkerConsumerT) -> None:
         """Shutdown task consumer."""
         """Shutdown task consumer."""
         if c.task_consumer:
         if c.task_consumer:
             self.stop(c)
             self.stop(c)
@@ -57,6 +59,6 @@ class Tasks(bootsteps.StartStopStep):
             ignore_errors(c, c.task_consumer.close)
             ignore_errors(c, c.task_consumer.close)
             c.task_consumer = None
             c.task_consumer = None
 
 
-    def info(self, c):
+    def info(self, c: WorkerConsumerT) -> Mapping:
         """Return task consumer info."""
         """Return task consumer info."""
         return {'prefetch_count': c.qos.value if c.qos else 'N/A'}
         return {'prefetch_count': c.qos.value if c.qos else 'N/A'}

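The QoS object handed to c.qos here is kombu's bookkeeping wrapper: increments and decrements accumulate locally and only hit the broker when update() runs. Its behaviour in isolation, with a print standing in for the now-awaited basic.qos call:

    from kombu.common import QoS

    def set_prefetch_count(prefetch_count):
        # In the worker this is the channel's basic.qos call.
        print('basic.qos prefetch_count=%r' % (prefetch_count,))

    qos = QoS(set_prefetch_count, 10)  # initial prefetch of 10
    qos.increment_eventually()         # 11, not sent yet
    qos.decrement_eventually()         # back to 10, still not sent
    qos.update()                       # flushes: basic.qos prefetch_count=10
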
+ 100 - 58
celery/worker/control.py

@@ -2,19 +2,19 @@
 """Worker remote control command implementations."""
 """Worker remote control command implementations."""
 import io
 import io
 import tempfile
 import tempfile
-
-from collections import UserDict, namedtuple
-
+from collections import UserDict
+from typing import (
+    Any, Callable, Iterable, Mapping, NamedTuple, Sequence, Tuple, Union,
+)
 from billiard.common import TERM_SIGNAME
 from billiard.common import TERM_SIGNAME
 from kombu.utils.encoding import safe_repr
 from kombu.utils.encoding import safe_repr
-
 from celery.exceptions import WorkerShutdown
 from celery.exceptions import WorkerShutdown
 from celery.platforms import signals as _signals
 from celery.platforms import signals as _signals
+from celery.types import ControlStateT as StateT, RequestT, TimerT
 from celery.utils.functional import maybe_list
 from celery.utils.functional import maybe_list
 from celery.utils.log import get_logger
 from celery.utils.log import get_logger
 from celery.utils.serialization import jsonify, strtobool
 from celery.utils.serialization import jsonify, strtobool
 from celery.utils.time import rate
 from celery.utils.time import rate
-
 from . import state as worker_state
 from . import state as worker_state
 from .request import Request
 from .request import Request
 
 
@@ -23,17 +23,24 @@ __all__ = ['Panel']
 DEFAULT_TASK_INFO_ITEMS = ('exchange', 'routing_key', 'rate_limit')
 DEFAULT_TASK_INFO_ITEMS = ('exchange', 'routing_key', 'rate_limit')
 logger = get_logger(__name__)
 logger = get_logger(__name__)
 
 
-controller_info_t = namedtuple('controller_info_t', [
-    'alias', 'type', 'visible', 'default_timeout',
-    'help', 'signature', 'args', 'variadic',
-])
+class controller_info_t(NamedTuple):
+    """Metadata about control command."""
+
+    alias: str
+    type: str
+    visible: bool
+    default_timeout: float
+    help: str
+    signature: str
+    args: Sequence[Tuple[str, type]]
+    variadic: str
 
 
 
 
-def ok(value):
+def ok(value: Any) -> Mapping:
     return {'ok': value}
     return {'ok': value}
 
 
 
 
-def nok(value):
+def nok(value: Any) -> Mapping:
     return {'error': value}
     return {'error': value}
 
 
 
 
@@ -44,17 +51,24 @@ class Panel(UserDict):
     meta = {}      # -"-
     meta = {}      # -"-
 
 
     @classmethod
     @classmethod
-    def register(cls, *args, **kwargs):
+    def register(cls, *args, **kwargs) -> Callable:
         if args:
         if args:
             return cls._register(**kwargs)(*args)
             return cls._register(**kwargs)(*args)
         return cls._register(**kwargs)
         return cls._register(**kwargs)
 
 
     @classmethod
     @classmethod
-    def _register(cls, name=None, alias=None, type='control',
-                  visible=True, default_timeout=1.0, help=None,
-                  signature=None, args=None, variadic=None):
-
-        def _inner(fun):
+    def _register(cls,
+                  name: str = None,
+                  alias: str = None,
+                  type: str = 'control',
+                  visible: bool = True,
+                  default_timeout: float = 1.0,
+                  help: str = None,
+                  signature: str = None,
+                  args: Sequence[Tuple[str, type]] = None,
+                  variadic: str = None) -> Callable:
+
+        def _inner(fun: Callable) -> Callable:
             control_name = name or fun.__name__
             control_name = name or fun.__name__
             _help = help or (fun.__doc__ or '').strip().split('\n')[0]
             _help = help or (fun.__doc__ or '').strip().split('\n')[0]
             cls.data[control_name] = fun
             cls.data[control_name] = fun
@@ -67,18 +81,18 @@ class Panel(UserDict):
         return _inner
         return _inner
 
 
 
 
-def control_command(**kwargs):
+def control_command(**kwargs) -> Callable:
     return Panel.register(type='control', **kwargs)
     return Panel.register(type='control', **kwargs)
 
 
 
 
-def inspect_command(**kwargs):
+def inspect_command(**kwargs) -> Callable:
     return Panel.register(type='inspect', **kwargs)
     return Panel.register(type='inspect', **kwargs)
 
 
 # -- App
 # -- App
 
 
 
 
 @inspect_command()
 @inspect_command()
-def report(state):
+def report(state: StateT) -> Mapping:
     """Information about Celery installation for bug reports."""
     """Information about Celery installation for bug reports."""
     return ok(state.app.bugreport())
     return ok(state.app.bugreport())
 
 
@@ -88,14 +102,14 @@ def report(state):
     signature='[include_defaults=False]',
     signature='[include_defaults=False]',
     args=[('with_defaults', strtobool)],
     args=[('with_defaults', strtobool)],
 )
 )
-def conf(state, with_defaults=False, **kwargs):
+def conf(state: StateT, with_defaults: bool = False, **kwargs) -> Mapping:
     """List configuration."""
     """List configuration."""
     return jsonify(state.app.conf.table(with_defaults=with_defaults),
     return jsonify(state.app.conf.table(with_defaults=with_defaults),
                    keyfilter=_wanted_config_key,
                    keyfilter=_wanted_config_key,
                    unknown_type_filter=safe_repr)
                    unknown_type_filter=safe_repr)
 
 
 
 
-def _wanted_config_key(key):
+def _wanted_config_key(key: Any) -> bool:
     return isinstance(key, str) and not key.startswith('__')
     return isinstance(key, str) and not key.startswith('__')
 
 
 
 
@@ -105,7 +119,9 @@ def _wanted_config_key(key):
     variadic='ids',
     variadic='ids',
     signature='[id1 [id2 [... [idN]]]]',
     signature='[id1 [id2 [... [idN]]]]',
 )
 )
-def query_task(state, ids, **kwargs):
+def query_task(state: StateT,
+               ids: Union[str, Sequence[str]],
+               **kwargs) -> Mapping[str, Tuple[str, Mapping]]:
     """Query for task information by id."""
     """Query for task information by id."""
     return {
     return {
         req.id: (_state_of_task(req), req.info())
         req.id: (_state_of_task(req), req.info())
@@ -113,8 +129,10 @@ def query_task(state, ids, **kwargs):
     }
     }
 
 
 
 
-def _find_requests_by_id(ids,
-                         get_request=worker_state.requests.__getitem__):
+def _find_requests_by_id(
+        ids: Sequence[str],
+        *,
+        get_request=worker_state.requests.__getitem__) -> Iterable[RequestT]:
     for task_id in ids:
     for task_id in ids:
         try:
         try:
             yield get_request(task_id)
             yield get_request(task_id)
@@ -122,9 +140,11 @@ def _find_requests_by_id(ids,
             pass
             pass
 
 
 
 
-def _state_of_task(request,
-                   is_active=worker_state.active_requests.__contains__,
-                   is_reserved=worker_state.reserved_requests.__contains__):
+def _state_of_task(
+        request: RequestT,
+        *,
+        is_active=worker_state.active_requests.__contains__,
+        is_reserved=worker_state.reserved_requests.__contains__) -> str:
     if is_active(request):
     if is_active(request):
         return 'active'
         return 'active'
     elif is_reserved(request):
     elif is_reserved(request):
@@ -136,7 +156,9 @@ def _state_of_task(request,
     variadic='task_id',
     variadic='task_id',
     signature='[id1 [id2 [... [idN]]]]',
     signature='[id1 [id2 [... [idN]]]]',
 )
 )
-def revoke(state, task_id, terminate=False, signal=None, **kwargs):
+def revoke(state: StateT, task_id: str,
+           terminate: bool = False,
+           signal: Union[str, int] = None, **kwargs) -> Mapping:
     """Revoke task by task id (or list of ids).
     """Revoke task by task id (or list of ids).
 
 
     Keyword Arguments:
     Keyword Arguments:
@@ -176,7 +198,7 @@ def revoke(state, task_id, terminate=False, signal=None, **kwargs):
     args=[('signal', str)],
     args=[('signal', str)],
     signature='<signal> [id1 [id2 [... [idN]]]]'
     signature='<signal> [id1 [id2 [... [idN]]]]'
 )
 )
-def terminate(state, signal, task_id, **kwargs):
+def terminate(state: StateT, signal: str, task_id: str, **kwargs) -> Mapping:
     """Terminate task by task id (or list of ids)."""
     """Terminate task by task id (or list of ids)."""
     return revoke(state, task_id, terminate=True, signal=signal)
     return revoke(state, task_id, terminate=True, signal=signal)
 
 
@@ -185,7 +207,8 @@ def terminate(state, signal, task_id, **kwargs):
     args=[('task_name', str), ('rate_limit', str)],
     args=[('task_name', str), ('rate_limit', str)],
     signature='<task_name> <rate_limit (e.g., 5/s | 5/m | 5/h)>',
     signature='<task_name> <rate_limit (e.g., 5/s | 5/m | 5/h)>',
 )
 )
-def rate_limit(state, task_name, rate_limit, **kwargs):
+def rate_limit(state: StateT, task_name: str, rate_limit: Union[str, int],
+               **kwargs) -> Mapping:
     """Tell worker(s) to modify the rate limit for a task by type.
     """Tell worker(s) to modify the rate limit for a task by type.
 
 
     See Also:
     See Also:
@@ -225,7 +248,10 @@ def rate_limit(state, task_name, rate_limit, **kwargs):
     args=[('task_name', str), ('soft', float), ('hard', float)],
     args=[('task_name', str), ('soft', float), ('hard', float)],
     signature='<task_name> <soft_secs> [hard_secs]',
     signature='<task_name> <soft_secs> [hard_secs]',
 )
 )
-def time_limit(state, task_name=None, hard=None, soft=None, **kwargs):
+def time_limit(state: StateT,
+               task_name: str = None,
+               hard: float = None,
+               soft: float = None, **kwargs) -> Mapping:
     """Tell worker(s) to modify the time limit for task by type.
     """Tell worker(s) to modify the time limit for task by type.
 
 
     Arguments:
     Arguments:
@@ -252,13 +278,14 @@ def time_limit(state, task_name=None, hard=None, soft=None, **kwargs):
 
 
 
 
 @inspect_command()
 @inspect_command()
-def clock(state, **kwargs):
+def clock(state: StateT, **kwargs) -> Mapping:
     """Get current logical clock value."""
     """Get current logical clock value."""
     return {'clock': state.app.clock.value}
     return {'clock': state.app.clock.value}
 
 
 
 
 @control_command()
 @control_command()
-def election(state, id, topic, action=None, **kwargs):
+def election(state: StateT, id: str, topic: str, action: str = None,
+             **kwargs) -> None:
     """Hold election.
     """Hold election.
 
 
     Arguments:
     Arguments:
@@ -271,7 +298,7 @@ def election(state, id, topic, action=None, **kwargs):
 
 
 
 
 @control_command()
 @control_command()
-def enable_events(state):
+def enable_events(state: StateT) -> Mapping:
     """Tell worker(s) to send task-related events."""
     """Tell worker(s) to send task-related events."""
     dispatcher = state.consumer.event_dispatcher
     dispatcher = state.consumer.event_dispatcher
     if dispatcher.groups and 'task' not in dispatcher.groups:
     if dispatcher.groups and 'task' not in dispatcher.groups:
@@ -282,7 +309,7 @@ def enable_events(state):
 
 
 
 
 @control_command()
 @control_command()
-def disable_events(state):
+def disable_events(state: StateT) -> Mapping:
     """Tell worker(s) to stop sending task-related events."""
     """Tell worker(s) to stop sending task-related events."""
     dispatcher = state.consumer.event_dispatcher
     dispatcher = state.consumer.event_dispatcher
     if 'task' in dispatcher.groups:
     if 'task' in dispatcher.groups:
@@ -293,7 +320,7 @@ def disable_events(state):
 
 
 
 
 @control_command()
 @control_command()
-def heartbeat(state):
+def heartbeat(state: StateT) -> None:
     """Tell worker(s) to send event heartbeat immediately."""
     """Tell worker(s) to send event heartbeat immediately."""
     logger.debug('Heartbeat requested by remote.')
     logger.debug('Heartbeat requested by remote.')
     dispatcher = state.consumer.event_dispatcher
     dispatcher = state.consumer.event_dispatcher
@@ -303,7 +330,8 @@ def heartbeat(state):
 # -- Worker
 # -- Worker
 
 
 @inspect_command(visible=False)
 @inspect_command(visible=False)
-def hello(state, from_node, revoked=None, **kwargs):
+def hello(state: StateT, from_node: str,
+          revoked: Mapping = None, **kwargs) -> Mapping:
     """Request mingle sync-data."""
     """Request mingle sync-data."""
     # pylint: disable=redefined-outer-name
     # pylint: disable=redefined-outer-name
     # XXX Note that this redefines `revoked`:
     # XXX Note that this redefines `revoked`:
@@ -319,24 +347,24 @@ def hello(state, from_node, revoked=None, **kwargs):
 
 
 
 
 @inspect_command(default_timeout=0.2)
 @inspect_command(default_timeout=0.2)
-def ping(state, **kwargs):
+def ping(state: StateT, **kwargs) -> Mapping:
     """Ping worker(s)."""
     """Ping worker(s)."""
     return ok('pong')
     return ok('pong')
 
 
 
 
 @inspect_command()
 @inspect_command()
-def stats(state, **kwargs):
+def stats(state: StateT, **kwargs) -> Mapping:
     """Request worker statistics/information."""
     """Request worker statistics/information."""
     return state.consumer.controller.stats()
     return state.consumer.controller.stats()
 
 
 
 
 @inspect_command(alias='dump_schedule')
 @inspect_command(alias='dump_schedule')
-def scheduled(state, **kwargs):
+def scheduled(state: StateT, **kwargs) -> Sequence[Mapping]:
     """List of currently scheduled ETA/countdown tasks."""
     """List of currently scheduled ETA/countdown tasks."""
     return list(_iter_schedule_requests(state.consumer.timer))
     return list(_iter_schedule_requests(state.consumer.timer))
 
 
 
 
-def _iter_schedule_requests(timer):
+def _iter_schedule_requests(timer: TimerT) -> Iterable[Mapping]:
     for waiting in timer.schedule.queue:
     for waiting in timer.schedule.queue:
         try:
         try:
             arg0 = waiting.entry.args[0]
             arg0 = waiting.entry.args[0]
@@ -352,7 +380,7 @@ def _iter_schedule_requests(timer):
 
 
 
 
 @inspect_command(alias='dump_reserved')
 @inspect_command(alias='dump_reserved')
-def reserved(state, **kwargs):
+def reserved(state: StateT, **kwargs) -> Sequence[Mapping]:
     """List of currently reserved tasks, not including scheduled/active."""
     """List of currently reserved tasks, not including scheduled/active."""
     reserved_tasks = (
     reserved_tasks = (
         state.tset(worker_state.reserved_requests) -
         state.tset(worker_state.reserved_requests) -
@@ -364,14 +392,14 @@ def reserved(state, **kwargs):
 
 
 
 
 @inspect_command(alias='dump_active')
 @inspect_command(alias='dump_active')
-def active(state, **kwargs):
+def active(state: StateT, **kwargs) -> Sequence[Mapping]:
     """List of tasks currently being executed."""
     """List of tasks currently being executed."""
     return [request.info()
     return [request.info()
             for request in state.tset(worker_state.active_requests)]
             for request in state.tset(worker_state.active_requests)]
 
 
 
 
 @inspect_command(alias='dump_revoked')
 @inspect_command(alias='dump_revoked')
-def revoked(state, **kwargs):
+def revoked(state: StateT, **kwargs) -> Sequence[str]:
     """List of revoked task-ids."""
     """List of revoked task-ids."""
     return list(worker_state.revoked)
     return list(worker_state.revoked)
 
 
@@ -381,7 +409,9 @@ def revoked(state, **kwargs):
     variadic='taskinfoitems',
     variadic='taskinfoitems',
     signature='[attr1 [attr2 [... [attrN]]]]',
     signature='[attr1 [attr2 [... [attrN]]]]',
 )
 )
-def registered(state, taskinfoitems=None, builtins=False, **kwargs):
+def registered(state: StateT,
+               taskinfoitems: Sequence[str] = None,
+               builtins: bool = False, **kwargs) -> Sequence[Mapping]:
     """List of registered tasks.
     """List of registered tasks.
 
 
     Arguments:
     Arguments:
@@ -415,7 +445,10 @@ def registered(state, taskinfoitems=None, builtins=False, **kwargs):
     args=[('type', str), ('num', int), ('max_depth', int)],
     args=[('type', str), ('num', int), ('max_depth', int)],
     signature='[object_type=Request] [num=200 [max_depth=10]]',
     signature='[object_type=Request] [num=200 [max_depth=10]]',
 )
 )
-def objgraph(state, num=200, max_depth=10, type='Request'):  # pragma: no cover
+def objgraph(state: StateT,  # pragma: no cover
+             num: int = 200,
+             max_depth: int = 10,
+             type: str = 'Request') -> Mapping:
     """Create graph of uncollected objects (memory-leak debugging).
     """Create graph of uncollected objects (memory-leak debugging).
 
 
     Arguments:
     Arguments:
@@ -440,7 +473,7 @@ def objgraph(state, num=200, max_depth=10, type='Request'):  # pragma: no cover
 
 
 
 
 @inspect_command()
 @inspect_command()
-def memsample(state, **kwargs):
+def memsample(state: StateT, **kwargs) -> str:
     """Sample current RSS memory usage."""
     """Sample current RSS memory usage."""
     from celery.utils.debug import sample_mem
     from celery.utils.debug import sample_mem
     return sample_mem()
     return sample_mem()
@@ -450,7 +483,8 @@ def memsample(state, **kwargs):
     args=[('samples', int)],
     args=[('samples', int)],
     signature='[n_samples=10]',
     signature='[n_samples=10]',
 )
 )
-def memdump(state, samples=10, **kwargs):  # pragma: no cover
+def memdump(state: StateT,  # pragma: no cover
+            samples: int = 10, **kwargs) -> str:
     """Dump statistics of previous memsample requests."""
     """Dump statistics of previous memsample requests."""
     from celery.utils import debug
     from celery.utils import debug
     out = io.StringIO()
     out = io.StringIO()
@@ -464,7 +498,7 @@ def memdump(state, samples=10, **kwargs):  # pragma: no cover
     args=[('n', int)],
     args=[('n', int)],
     signature='[N=1]',
     signature='[N=1]',
 )
 )
-def pool_grow(state, n=1, **kwargs):
+def pool_grow(state: StateT, n: int = 1, **kwargs) -> Mapping:
     """Grow pool by n processes/threads."""
     """Grow pool by n processes/threads."""
     if state.consumer.controller.autoscaler:
     if state.consumer.controller.autoscaler:
         state.consumer.controller.autoscaler.force_scale_up(n)
         state.consumer.controller.autoscaler.force_scale_up(n)
@@ -478,7 +512,7 @@ def pool_grow(state, n=1, **kwargs):
     args=[('n', int)],
     args=[('n', int)],
     signature='[N=1]',
     signature='[N=1]',
 )
 )
-def pool_shrink(state, n=1, **kwargs):
+def pool_shrink(state: StateT, n: int = 1, **kwargs) -> Mapping:
     """Shrink pool by n processes/threads."""
     """Shrink pool by n processes/threads."""
     if state.consumer.controller.autoscaler:
     if state.consumer.controller.autoscaler:
         state.consumer.controller.autoscaler.force_scale_down(n)
         state.consumer.controller.autoscaler.force_scale_down(n)
@@ -489,7 +523,11 @@ def pool_shrink(state, n=1, **kwargs):
 
 
 
 
 @control_command()
 @control_command()
-def pool_restart(state, modules=None, reload=False, reloader=None, **kwargs):
+def pool_restart(state: StateT,
+                 modules: Sequence[str] = None,
+                 reload: bool = False,
+                 reloader: Callable = None,
+                 **kwargs) -> Mapping:
     """Restart execution pool."""
     """Restart execution pool."""
     if state.app.conf.worker_pool_restarts:
     if state.app.conf.worker_pool_restarts:
         state.consumer.controller.reload(modules, reload, reloader=reloader)
         state.consumer.controller.reload(modules, reload, reloader=reloader)
@@ -502,7 +540,7 @@ def pool_restart(state, modules=None, reload=False, reloader=None, **kwargs):
     args=[('max', int), ('min', int)],
     args=[('max', int), ('min', int)],
     signature='[max [min]]',
     signature='[max [min]]',
 )
 )
-def autoscale(state, max=None, min=None):
+def autoscale(state: StateT, max: int = None, min: int = None) -> Mapping:
     """Modify autoscale settings."""
     """Modify autoscale settings."""
     autoscaler = state.consumer.controller.autoscaler
     autoscaler = state.consumer.controller.autoscaler
     if autoscaler:
     if autoscaler:
@@ -512,7 +550,8 @@ def autoscale(state, max=None, min=None):
 
 
 
 
 @control_command()
 @control_command()
-def shutdown(state, msg='Got shutdown from remote', **kwargs):
+def shutdown(state: StateT, msg: str = 'Got shutdown from remote',
+             **kwargs) -> None:
     """Shutdown worker(s)."""
     """Shutdown worker(s)."""
     logger.warning(msg)
     logger.warning(msg)
     raise WorkerShutdown(msg)
     raise WorkerShutdown(msg)
@@ -529,8 +568,11 @@ def shutdown(state, msg='Got shutdown from remote', **kwargs):
     ],
     ],
     signature='<queue> [exchange [type [routing_key]]]',
     signature='<queue> [exchange [type [routing_key]]]',
 )
 )
-def add_consumer(state, queue, exchange=None, exchange_type=None,
-                 routing_key=None, **options):
+def add_consumer(state: StateT, queue: str,
+                 exchange: str = None,
+                 exchange_type: str = None,
+                 routing_key: str = None,
+                 **options) -> Mapping:
     """Tell worker(s) to consume from task queue by name."""
     """Tell worker(s) to consume from task queue by name."""
     state.consumer.call_soon(
     state.consumer.call_soon(
         state.consumer.add_task_queue,
         state.consumer.add_task_queue,
@@ -542,7 +584,7 @@ def add_consumer(state, queue, exchange=None, exchange_type=None,
     args=[('queue', str)],
     args=[('queue', str)],
     signature='<queue>',
     signature='<queue>',
 )
 )
-def cancel_consumer(state, queue, **_):
+def cancel_consumer(state: StateT, queue: str, **_) -> Mapping:
     """Tell worker(s) to stop consuming from task queue by name."""
     """Tell worker(s) to stop consuming from task queue by name."""
     state.consumer.call_soon(
     state.consumer.call_soon(
         state.consumer.cancel_task_queue, queue,
         state.consumer.cancel_task_queue, queue,
@@ -551,7 +593,7 @@ def cancel_consumer(state, queue, **_):
 
 
 
 
 @inspect_command()
 @inspect_command()
-def active_queues(state):
+async def active_queues(state: StateT) -> Sequence[Mapping]:
     """List the task queues a worker are currently consuming from."""
     """List the task queues a worker are currently consuming from."""
     if state.consumer.task_consumer:
     if state.consumer.task_consumer:
         return [dict(queue.as_dict(recurse=True))
         return [dict(queue.as_dict(recurse=True))

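The handlers above are the worker-side half of the remote-control protocol;
clients reach them through the app.control broadcast API. A minimal sketch of
driving them from client code (broker URL and worker name are placeholders,
not part of this change):

    from celery import Celery

    app = Celery('demo', broker='amqp://guest@localhost//')

    # Broadcast commands map onto the @control_command handlers above:
    app.control.pool_grow(n=2, destination=['worker1@example.com'])
    app.control.autoscale(max=10, min=2)
    app.control.add_consumer('priority', exchange='priority',
                             exchange_type='direct',
                             routing_key='priority.high')

    # Inspect commands collect replies from the @inspect_command handlers:
    replies = app.control.inspect().active_queues()
    # e.g. {'worker1@example.com': [{'name': 'priority', ...}]}
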
+ 15 - 9
celery/worker/heartbeat.py

@@ -4,7 +4,10 @@
 This is the internal thread responsible for sending heartbeat events
 This is the internal thread responsible for sending heartbeat events
 at regular intervals (may not be an actual thread).
 at regular intervals (may not be an actual thread).
 """
 """
+from typing import Awaitable
+from celery.events import EventDispatcher
 from celery.signals import heartbeat_sent
 from celery.signals import heartbeat_sent
+from celery.types import EventT, TimerT
 from celery.utils.sysinfo import load_average
 from celery.utils.sysinfo import load_average
 from .state import SOFTWARE_INFO, active_requests, all_total_count
 from .state import SOFTWARE_INFO, active_requests, all_total_count
 
 
@@ -22,7 +25,8 @@ class Heart:
             heartbeats.  Default is 2 seconds.
             heartbeats.  Default is 2 seconds.
     """
     """
 
 
-    def __init__(self, timer, eventer, interval=None):
+    def __init__(self, *, timer: TimerT, eventer: EventDispatcher,
+                 interval: float = None):
         self.timer = timer
         self.timer = timer
         self.eventer = eventer
         self.eventer = eventer
         self.interval = float(interval or 2.0)
         self.interval = float(interval or 2.0)
@@ -36,23 +40,25 @@ class Heart:
         self._send_sent_signal = (
         self._send_sent_signal = (
             heartbeat_sent.send if heartbeat_sent.receivers else None)
             heartbeat_sent.send if heartbeat_sent.receivers else None)
 
 
-    def _send(self, event):
+    async def _send(self, event: EventT) -> Awaitable:
         if self._send_sent_signal is not None:
         if self._send_sent_signal is not None:
             self._send_sent_signal(sender=self)
             self._send_sent_signal(sender=self)
-        return self.eventer.send(event, freq=self.interval,
-                                 active=len(active_requests),
-                                 processed=all_total_count[0],
-                                 loadavg=load_average(),
-                                 **SOFTWARE_INFO)
+        return await self.eventer.send(
+            event,
+            freq=self.interval,
+            active=len(active_requests),
+            processed=all_total_count[0],
+            loadavg=load_average(),
+            **SOFTWARE_INFO)
 
 
-    def start(self):
+    async def start(self) -> None:
         if self.eventer.enabled:
         if self.eventer.enabled:
-            self._send('worker-online')
+            await self._send('worker-online')
             self.tref = self.timer.call_repeatedly(
             self.tref = self.timer.call_repeatedly(
                 self.interval, self._send, ('worker-heartbeat',),
                 self.interval, self._send, ('worker-heartbeat',),
             )
             )
 
 
-    def stop(self):
+    async def stop(self) -> None:
         if self.tref is not None:
         if self.tref is not None:
             self.timer.cancel(self.tref)
             self.timer.cancel(self.tref)
             self.tref = None
             self.tref = None

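For reference, the 'worker-heartbeat' event assembled by _send() above ends
up carrying roughly the following fields once the counters and SOFTWARE_INFO
are merged in (a sketch with illustrative values, not captured output; the
dispatcher adds hostname, timestamp and clock fields of its own):

    heartbeat = {
        'type': 'worker-heartbeat',
        'hostname': 'worker1@example.com',  # added by the event dispatcher
        'freq': 2.0,                        # self.interval
        'active': 3,                        # len(active_requests)
        'processed': 1528,                  # all_total_count[0]
        'loadavg': [0.42, 0.35, 0.30],      # load_average()
        'sw_ident': 'py-celery',            # SOFTWARE_INFO
        'sw_ver': '4.0.0',
        'sw_sys': 'Linux',
    }
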
+ 72 - 47
celery/worker/request.py

@@ -7,21 +7,27 @@ how tasks are executed.
 import logging
 import logging
 import sys
 import sys
 
 
-from datetime import datetime
+from datetime import datetime, tzinfo
+from typing import Any, Awaitable, Callable, Mapping, Sequence, Tuple, Union
 from weakref import ref
 from weakref import ref
 
 
 from billiard.common import TERM_SIGNAME
 from billiard.common import TERM_SIGNAME
+from billiard.einfo import ExceptionInfo
+from kombu.types import MessageT
 from kombu.utils.encoding import safe_repr, safe_str
 from kombu.utils.encoding import safe_repr, safe_str
 from kombu.utils.objects import cached_property
 from kombu.utils.objects import cached_property
 
 
 from celery import signals
 from celery import signals
 from celery.app.trace import trace_task, trace_task_ret
 from celery.app.trace import trace_task, trace_task_ret
+from celery.events import EventDispatcher
 from celery.exceptions import (
 from celery.exceptions import (
     Ignore, TaskRevokedError, InvalidTaskError,
     Ignore, TaskRevokedError, InvalidTaskError,
     SoftTimeLimitExceeded, TimeLimitExceeded,
     SoftTimeLimitExceeded, TimeLimitExceeded,
     WorkerLostError, Terminated, Retry, Reject,
     WorkerLostError, Terminated, Retry, Reject,
 )
 )
 from celery.platforms import signals as _signals
 from celery.platforms import signals as _signals
+from celery.types import AppT, PoolT, SignatureT, TaskT
+from celery.utils.collections import LimitedSet
 from celery.utils.functional import maybe, noop
 from celery.utils.functional import maybe, noop
 from celery.utils.log import get_logger
 from celery.utils.log import get_logger
 from celery.utils.nodenames import gethostname
 from celery.utils.nodenames import gethostname
@@ -44,7 +50,7 @@ _does_info = False
 _does_debug = False
 _does_debug = False
 
 
 
 
-def __optimize__():
+def __optimize__() -> None:
     # this is also called by celery.app.trace.setup_worker_optimizations
     # this is also called by celery.app.trace.setup_worker_optimizations
     global _does_debug
     global _does_debug
     global _does_info
     global _does_info
@@ -65,13 +71,13 @@ class Request:
     """A request for task execution."""
     """A request for task execution."""
 
 
     acknowledged = False
     acknowledged = False
-    time_start = None
-    worker_pid = None
-    time_limits = (None, None)
-    _already_revoked = False
-    _terminate_on_ack = None
-    _apply_result = None
-    _tzlocal = None
+    time_start: float = None
+    worker_pid: int = None
+    time_limits: Tuple[float, float] = (None, None)
+    _already_revoked: bool = False
+    _terminate_on_ack: bool = None
+    _apply_result: Awaitable = None
+    _tzlocal: tzinfo = None
 
 
     if not IS_PYPY:  # pragma: no cover
     if not IS_PYPY:  # pragma: no cover
         __slots__ = (
         __slots__ = (
@@ -83,13 +89,23 @@ class Request:
             '__weakref__', '__dict__',
             '__weakref__', '__dict__',
         )
         )
 
 
-    def __init__(self, message, on_ack=noop,
-                 hostname=None, eventer=None, app=None,
-                 connection_errors=None, request_dict=None,
-                 task=None, on_reject=noop, body=None,
-                 headers=None, decoded=False, utc=True,
-                 maybe_make_aware=maybe_make_aware,
-                 maybe_iso8601=maybe_iso8601, **opts):
+    def __init__(self, message: MessageT,
+                 *,
+                 on_ack: Callable = noop,
+                 hostname: str = None,
+                 eventer: EventDispatcher = None,
+                 app: AppT = None,
+                 connection_errors: Tuple = None,
+                 request_dict: Mapping = None,
+                 task: str = None,
+                 on_reject: Callable = noop,
+                 body: Mapping = None,
+                 headers: Mapping = None,
+                 decoded: bool = False,
+                 utc: bool = True,
+                 maybe_make_aware: Callable = maybe_make_aware,
+                 maybe_iso8601: Callable = maybe_iso8601,
+                 **opts) -> None:
         if headers is None:
         if headers is None:
             headers = message.headers
             headers = message.headers
         if body is None:
         if body is None:
@@ -163,10 +179,10 @@ class Request:
         self.request_dict = headers
         self.request_dict = headers
 
 
     @property
     @property
-    def delivery_info(self):
+    def delivery_info(self) -> Mapping:
         return self.request_dict['delivery_info']
         return self.request_dict['delivery_info']
 
 
-    def execute_using_pool(self, pool, **kwargs):
+    def execute_using_pool(self, pool: PoolT, **kwargs) -> Awaitable:
         """Used by the worker to send this task to the pool.
         """Used by the worker to send this task to the pool.
 
 
         Arguments:
         Arguments:
@@ -198,7 +214,7 @@ class Request:
         self._apply_result = maybe(ref, result)
         self._apply_result = maybe(ref, result)
         return result
         return result
 
 
-    def execute(self, loglevel=None, logfile=None):
+    def execute(self, loglevel: int = None, logfile: str = None) -> Any:
         """Execute the task in a :func:`~celery.app.trace.trace_task`.
         """Execute the task in a :func:`~celery.app.trace.trace_task`.
 
 
         Arguments:
         Arguments:
@@ -230,15 +246,16 @@ class Request:
         self.acknowledge()
         self.acknowledge()
         return retval
         return retval
 
 
-    def maybe_expire(self):
+    def maybe_expire(self) -> bool:
         """If expired, mark the task as revoked."""
         """If expired, mark the task as revoked."""
         if self.expires:
         if self.expires:
             now = datetime.now(self.expires.tzinfo)
             now = datetime.now(self.expires.tzinfo)
             if now > self.expires:
             if now > self.expires:
                 revoked_tasks.add(self.id)
                 revoked_tasks.add(self.id)
                 return True
                 return True
+        return False
 
 
-    def terminate(self, pool, signal=None):
+    def terminate(self, pool: PoolT, signal: Union[str, int] = None) -> None:
         signal = _signals.signum(signal or TERM_SIGNAME)
         signal = _signals.signum(signal or TERM_SIGNAME)
         if self.time_start:
         if self.time_start:
             pool.terminate_job(self.worker_pid, signal)
             pool.terminate_job(self.worker_pid, signal)
@@ -250,7 +267,8 @@ class Request:
             if obj is not None:
             if obj is not None:
                 obj.terminate(signal)
                 obj.terminate(signal)
 
 
-    def _announce_revoked(self, reason, terminated, signum, expired):
+    def _announce_revoked(self, reason: str, terminated: bool,
+                          signum: int, expired: bool) -> None:
         task_ready(self)
         task_ready(self)
         self.send_event('task-revoked',
         self.send_event('task-revoked',
                         terminated=terminated, signum=signum, expired=expired)
                         terminated=terminated, signum=signum, expired=expired)
@@ -262,7 +280,7 @@ class Request:
         send_revoked(self.task, request=self,
         send_revoked(self.task, request=self,
                      terminated=terminated, signum=signum, expired=expired)
                      terminated=terminated, signum=signum, expired=expired)
 
 
-    def revoked(self):
+    def revoked(self) -> bool:
         """If revoked, skip task and mark state."""
         """If revoked, skip task and mark state."""
         expired = False
         expired = False
         if self._already_revoked:
         if self._already_revoked:
@@ -277,11 +295,11 @@ class Request:
             return True
             return True
         return False
         return False
 
 
-    def send_event(self, type, **fields):
+    def send_event(self, type: str, **fields) -> Awaitable:
         if self.eventer and self.eventer.enabled and self.task.send_events:
         if self.eventer and self.eventer.enabled and self.task.send_events:
-            self.eventer.send(type, uuid=self.id, **fields)
+            return self.eventer.send(type, uuid=self.id, **fields)
 
 
-    def on_accepted(self, pid, time_accepted):
+    def on_accepted(self, pid: int, time_accepted: float) -> None:
         """Handler called when task is accepted by worker pool."""
         """Handler called when task is accepted by worker pool."""
         self.worker_pid = pid
         self.worker_pid = pid
         self.time_start = time_accepted
         self.time_start = time_accepted
@@ -294,7 +312,7 @@ class Request:
         if self._terminate_on_ack is not None:
         if self._terminate_on_ack is not None:
             self.terminate(*self._terminate_on_ack)
             self.terminate(*self._terminate_on_ack)
 
 
-    def on_timeout(self, soft, timeout):
+    def on_timeout(self, soft: bool, timeout: float) -> None:
         """Handler called if the task times out."""
         """Handler called if the task times out."""
         task_ready(self)
         task_ready(self)
         if soft:
         if soft:
@@ -313,7 +331,8 @@ class Request:
         if self.task.acks_late:
         if self.task.acks_late:
             self.acknowledge()
             self.acknowledge()
 
 
-    def on_success(self, failed__retval__runtime, **kwargs):
+    def on_success(self, failed__retval__runtime: Tuple[bool, Any, float],
+                   **kwargs) -> None:
         """Handler called if the task was successfully processed."""
         """Handler called if the task was successfully processed."""
         failed, retval, runtime = failed__retval__runtime
         failed, retval, runtime = failed__retval__runtime
         if failed:
         if failed:
@@ -327,7 +346,7 @@ class Request:
 
 
         self.send_event('task-succeeded', result=retval, runtime=runtime)
         self.send_event('task-succeeded', result=retval, runtime=runtime)
 
 
-    def on_retry(self, exc_info):
+    def on_retry(self, exc_info: ExceptionInfo) -> None:
         """Handler called if the task should be retried."""
         """Handler called if the task should be retried."""
         if self.task.acks_late:
         if self.task.acks_late:
             self.acknowledge()
             self.acknowledge()
@@ -336,7 +355,9 @@ class Request:
                         exception=safe_repr(exc_info.exception.exc),
                         exception=safe_repr(exc_info.exception.exc),
                         traceback=safe_str(exc_info.traceback))
                         traceback=safe_str(exc_info.traceback))
 
 
-    def on_failure(self, exc_info, send_failed_event=True, return_ok=False):
+    def on_failure(self, exc_info: ExceptionInfo,
+                   send_failed_event: bool = True,
+                   return_ok: bool = False) -> None:
         """Handler called if the task raised an exception."""
         """Handler called if the task raised an exception."""
         task_ready(self)
         task_ready(self)
         if isinstance(exc_info.exception, MemoryError):
         if isinstance(exc_info.exception, MemoryError):
@@ -385,19 +406,19 @@ class Request:
             error('Task handler raised error: %r', exc,
             error('Task handler raised error: %r', exc,
                   exc_info=exc_info.exc_info)
                   exc_info=exc_info.exc_info)
 
 
-    def acknowledge(self):
+    def acknowledge(self) -> None:
         """Acknowledge task."""
         """Acknowledge task."""
         if not self.acknowledged:
         if not self.acknowledged:
             self.on_ack(logger, self.connection_errors)
             self.on_ack(logger, self.connection_errors)
             self.acknowledged = True
             self.acknowledged = True
 
 
-    def reject(self, requeue=False):
+    def reject(self, requeue: bool = False) -> None:
         if not self.acknowledged:
         if not self.acknowledged:
             self.on_reject(logger, self.connection_errors, requeue)
             self.on_reject(logger, self.connection_errors, requeue)
             self.acknowledged = True
             self.acknowledged = True
             self.send_event('task-rejected', requeue=requeue)
             self.send_event('task-rejected', requeue=requeue)
 
 
-    def info(self, safe=False):
+    def info(self, safe: bool = False) -> Mapping:
         return {
         return {
             'id': self.id,
             'id': self.id,
             'name': self.name,
             'name': self.name,
@@ -411,10 +432,10 @@ class Request:
             'worker_pid': self.worker_pid,
             'worker_pid': self.worker_pid,
         }
         }
 
 
-    def humaninfo(self):
+    def humaninfo(self) -> str:
         return '{0.name}[{0.id}]'.format(self)
         return '{0.name}[{0.id}]'.format(self)
 
 
-    def __str__(self):
+    def __str__(self) -> str:
         """``str(self)``."""
         """``str(self)``."""
         return ' '.join([
         return ' '.join([
             self.humaninfo(),
             self.humaninfo(),
@@ -422,7 +443,7 @@ class Request:
             ' expires:[{0}]'.format(self.expires) if self.expires else '',
             ' expires:[{0}]'.format(self.expires) if self.expires else '',
         ])
         ])
 
 
-    def __repr__(self):
+    def __repr__(self) -> str:
         """``repr(self)``."""
         """``repr(self)``."""
         return '<{0}: {1} {2} {3}>'.format(
         return '<{0}: {1} {2} {3}>'.format(
             type(self).__name__, self.humaninfo(),
             type(self).__name__, self.humaninfo(),
@@ -430,32 +451,32 @@ class Request:
         )
         )
 
 
     @property
     @property
-    def tzlocal(self):
+    def tzlocal(self) -> tzinfo:
         if self._tzlocal is None:
         if self._tzlocal is None:
             self._tzlocal = self.app.conf.timezone
             self._tzlocal = self.app.conf.timezone
         return self._tzlocal
         return self._tzlocal
 
 
     @property
     @property
-    def store_errors(self):
+    def store_errors(self) -> bool:
         return (not self.task.ignore_result or
         return (not self.task.ignore_result or
                 self.task.store_errors_even_if_ignored)
                 self.task.store_errors_even_if_ignored)
 
 
     @property
     @property
-    def reply_to(self):
+    def reply_to(self) -> str:
         # used by rpc backend when failures reported by parent process
         # used by rpc backend when failures reported by parent process
         return self.request_dict['reply_to']
         return self.request_dict['reply_to']
 
 
     @property
     @property
-    def correlation_id(self):
+    def correlation_id(self) -> str:
         # used similarly to reply_to
         # used similarly to reply_to
         return self.request_dict['correlation_id']
         return self.request_dict['correlation_id']
 
 
     @cached_property
     @cached_property
-    def _payload(self):
+    def _payload(self) -> Any:
         return self.body if self._decoded else self.message.payload
         return self.body if self._decoded else self.message.payload
 
 
     @cached_property
     @cached_property
-    def chord(self):
+    def chord(self) -> SignatureT:
         # used by backend.mark_as_failure when failure is reported
         # used by backend.mark_as_failure when failure is reported
         # by parent process
         # by parent process
         # pylint: disable=unpacking-non-sequence
         # pylint: disable=unpacking-non-sequence
@@ -464,7 +485,7 @@ class Request:
         return embed.get('chord')
         return embed.get('chord')
 
 
     @cached_property
     @cached_property
-    def errbacks(self):
+    def errbacks(self) -> Sequence[SignatureT]:
         # used by backend.mark_as_failure when failure is reported
         # used by backend.mark_as_failure when failure is reported
         # by parent process
         # by parent process
         # pylint: disable=unpacking-non-sequence
         # pylint: disable=unpacking-non-sequence
@@ -473,15 +494,19 @@ class Request:
         return embed.get('errbacks')
         return embed.get('errbacks')
 
 
     @cached_property
     @cached_property
-    def group(self):
+    def group(self) -> str:
         # used by backend.on_chord_part_return when failures reported
         # used by backend.on_chord_part_return when failures reported
         # by parent process
         # by parent process
         return self.request_dict['group']
         return self.request_dict['group']
 
 
 
 
-def create_request_cls(base, task, pool, hostname, eventer,
-                       ref=ref, revoked_tasks=revoked_tasks,
-                       task_ready=task_ready, trace=trace_task_ret):
+def create_request_cls(base: type, task: TaskT, pool: PoolT,
+                       hostname: str, eventer: EventDispatcher,
+                       *,
+                       ref: Callable = ref,
+                       revoked_tasks: LimitedSet = revoked_tasks,
+                       task_ready: Callable = task_ready,
+                       trace: Callable = trace_task_ret) -> type:
     default_time_limit = task.time_limit
     default_time_limit = task.time_limit
     default_soft_time_limit = task.soft_time_limit
     default_soft_time_limit = task.soft_time_limit
     apply_async = pool.apply_async
     apply_async = pool.apply_async

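The expiry rule in maybe_expire() above compares "now" in the timezone of the
task's own expires value. A self-contained sketch of just that check:

    from datetime import datetime, timedelta, timezone

    def is_expired(expires):
        # Same comparison Request.maybe_expire() performs.
        if expires is not None:
            return datetime.now(expires.tzinfo) > expires
        return False

    deadline = datetime.now(timezone.utc) + timedelta(seconds=30)
    assert not is_expired(deadline)
    assert is_expired(deadline - timedelta(minutes=5))
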
+ 52 - 40
celery/worker/state.py

@@ -12,12 +12,17 @@ import weakref
 import zlib
 import zlib
 
 
 from collections import Counter
 from collections import Counter
+from typing import (
+    Any, Callable, List, Mapping, MutableMapping, MutableSet, Sequence, Union,
+)
 
 
+from kombu.clocks import Clock
 from kombu.serialization import pickle, pickle_protocol
 from kombu.serialization import pickle, pickle_protocol
 from kombu.utils.objects import cached_property
 from kombu.utils.objects import cached_property
 
 
 from celery import __version__
 from celery import __version__
 from celery.exceptions import WorkerShutdown, WorkerTerminate
 from celery.exceptions import WorkerShutdown, WorkerTerminate
+from celery.types import RequestT
 from celery.utils.collections import LimitedSet
 from celery.utils.collections import LimitedSet
 
 
 __all__ = [
 __all__ = [
@@ -27,7 +32,7 @@ __all__ = [
 ]
 ]
 
 
 #: Worker software/platform information.
 #: Worker software/platform information.
-SOFTWARE_INFO = {
+SOFTWARE_INFO: Mapping[str, Any] = {
     'sw_ident': 'py-celery',
     'sw_ident': 'py-celery',
     'sw_ver': __version__,
     'sw_ver': __version__,
     'sw_sys': platform.system(),
     'sw_sys': platform.system(),
@@ -41,28 +46,28 @@ REVOKES_MAX = 50000
 REVOKE_EXPIRES = 10800
 REVOKE_EXPIRES = 10800
 
 
 #: Mapping of reserved task_id->Request.
 #: Mapping of reserved task_id->Request.
-requests = {}
+requests: MutableMapping[str, RequestT] = {}
 
 
 #: set of all reserved :class:`~celery.worker.request.Request`'s.
 #: set of all reserved :class:`~celery.worker.request.Request`'s.
-reserved_requests = weakref.WeakSet()
+reserved_requests: MutableSet[RequestT] = weakref.WeakSet()
 
 
 #: set of currently active :class:`~celery.worker.request.Request`'s.
 #: set of currently active :class:`~celery.worker.request.Request`'s.
-active_requests = weakref.WeakSet()
+active_requests: MutableSet[RequestT] = weakref.WeakSet()
 
 
 #: count of tasks accepted by the worker, sorted by type.
 #: count of tasks accepted by the worker, sorted by type.
-total_count = Counter()
+total_count: MutableMapping[str, int] = Counter()
 
 
 #: count of all tasks accepted by the worker
 #: count of all tasks accepted by the worker
-all_total_count = [0]
+all_total_count: List[int] = [0]
 
 
 #: the list of currently revoked tasks.  Persistent if ``statedb`` set.
 #: the list of currently revoked tasks.  Persistent if ``statedb`` set.
 revoked = LimitedSet(maxlen=REVOKES_MAX, expires=REVOKE_EXPIRES)
 revoked = LimitedSet(maxlen=REVOKES_MAX, expires=REVOKE_EXPIRES)
 
 
-should_stop = None
-should_terminate = None
+should_stop: Union[int, bool] = None
+should_terminate: Union[int, bool] = None
 
 
 
 
-def reset_state():
+def reset_state() -> None:
     requests.clear()
     requests.clear()
     reserved_requests.clear()
     reserved_requests.clear()
     active_requests.clear()
     active_requests.clear()
@@ -71,7 +76,7 @@ def reset_state():
     revoked.clear()
     revoked.clear()
 
 
 
 
-def maybe_shutdown():
+def maybe_shutdown() -> None:
     """Shutdown if flags have been set."""
     """Shutdown if flags have been set."""
     if should_stop is not None and should_stop is not False:
     if should_stop is not None and should_stop is not False:
         raise WorkerShutdown(should_stop)
         raise WorkerShutdown(should_stop)
@@ -79,28 +84,34 @@ def maybe_shutdown():
         raise WorkerTerminate(should_terminate)
         raise WorkerTerminate(should_terminate)
 
 
 
 
-def task_reserved(request,
-                  add_request=requests.__setitem__,
-                  add_reserved_request=reserved_requests.add):
+def task_reserved(
+        request: RequestT,
+        *,
+        add_request: Callable = requests.__setitem__,
+        add_reserved_request: Callable = reserved_requests.add) -> None:
     """Update global state when a task has been reserved."""
     """Update global state when a task has been reserved."""
     add_request(request.id, request)
     add_request(request.id, request)
     add_reserved_request(request)
     add_reserved_request(request)
 
 
 
 
-def task_accepted(request,
-                  _all_total_count=all_total_count,
-                  add_active_request=active_requests.add,
-                  add_to_total_count=total_count.update):
+def task_accepted(
+        request: RequestT,
+        *,
+        _all_total_count: Sequence[int] = all_total_count,
+        add_active_request: Callable = active_requests.add,
+        add_to_total_count: Callable = total_count.update) -> None:
     """Update global state when a task has been accepted."""
     """Update global state when a task has been accepted."""
     add_active_request(request)
     add_active_request(request)
     add_to_total_count({request.name: 1})
     add_to_total_count({request.name: 1})
     all_total_count[0] += 1
     all_total_count[0] += 1
 
 
 
 
-def task_ready(request,
-               remove_request=requests.pop,
-               discard_active_request=active_requests.discard,
-               discard_reserved_request=reserved_requests.discard):
+def task_ready(
+        request: RequestT,
+        *,
+        remove_request: Callable = requests.pop,
+        discard_reserved_request: Callable = reserved_requests.discard,
+        discard_active_request: Callable = active_requests.discard) -> None:
     """Update global state when a task is ready."""
     """Update global state when a task is ready."""
     remove_request(request.id, None)
     remove_request(request.id, None)
     discard_active_request(request)
     discard_active_request(request)
@@ -129,7 +140,7 @@ if C_BENCH:  # pragma: no cover
 
 
     if current_process()._name == 'MainProcess':
     if current_process()._name == 'MainProcess':
         @atexit.register
         @atexit.register
-        def on_shutdown():
+        def on_shutdown() -> None:
             if bench_first is not None and bench_last is not None:
             if bench_first is not None and bench_last is not None:
                 print('- Time spent in benchmark: {0!r}'.format(
                 print('- Time spent in benchmark: {0!r}'.format(
                       bench_last - bench_first))
                       bench_last - bench_first))
@@ -137,7 +148,7 @@ if C_BENCH:  # pragma: no cover
                       sum(bench_sample) / len(bench_sample)))
                       sum(bench_sample) / len(bench_sample)))
                 memdump()
                 memdump()
 
 
-    def task_reserved(request):  # noqa
+    def task_reserved(request: RequestT, **kwargs) -> None:  # noqa
         """Called when a task is reserved by the worker."""
         """Called when a task is reserved by the worker."""
         global bench_start
         global bench_start
         global bench_first
         global bench_first
@@ -149,7 +160,7 @@ if C_BENCH:  # pragma: no cover
 
 
         return __reserved(request)
         return __reserved(request)
 
 
-    def task_ready(request):  # noqa
+    def task_ready(request: RequestT, **kwargs) -> None:  # noqa
         """Called when a task is completed."""
         """Called when a task is completed."""
         global all_count
         global all_count
         global bench_start
         global bench_start
@@ -182,39 +193,40 @@ class Persistent:
     decompress = zlib.decompress
     decompress = zlib.decompress
     _is_open = False
     _is_open = False
 
 
-    def __init__(self, state, filename, clock=None):
+    def __init__(self, state: Any,
+                 filename: str, clock: Clock = None) -> None:
         self.state = state
         self.state = state
         self.filename = filename
         self.filename = filename
         self.clock = clock
         self.clock = clock
         self.merge()
         self.merge()
 
 
-    def open(self):
+    def open(self) -> MutableMapping:
         return self.storage.open(
         return self.storage.open(
             self.filename, protocol=self.protocol, writeback=True,
             self.filename, protocol=self.protocol, writeback=True,
         )
         )
 
 
-    def merge(self):
+    def merge(self) -> None:
         self._merge_with(self.db)
         self._merge_with(self.db)
 
 
-    def sync(self):
+    def sync(self) -> None:
         self._sync_with(self.db)
         self._sync_with(self.db)
         self.db.sync()
         self.db.sync()
 
 
-    def close(self):
+    def close(self) -> None:
         if self._is_open:
         if self._is_open:
             self.db.close()
             self.db.close()
             self._is_open = False
             self._is_open = False
 
 
-    def save(self):
+    def save(self) -> None:
         self.sync()
         self.sync()
         self.close()
         self.close()
 
 
-    def _merge_with(self, d):
+    def _merge_with(self, d: MutableMapping) -> MutableMapping:
         self._merge_revoked(d)
         self._merge_revoked(d)
         self._merge_clock(d)
         self._merge_clock(d)
         return d
         return d
 
 
-    def _sync_with(self, d):
+    def _sync_with(self, d: MutableMapping) -> MutableMapping:
         self._revoked_tasks.purge()
         self._revoked_tasks.purge()
         d.update({
         d.update({
             str('__proto__'): 3,
             str('__proto__'): 3,
@@ -223,11 +235,11 @@ class Persistent:
         })
         })
         return d
         return d
 
 
-    def _merge_clock(self, d):
+    def _merge_clock(self, d: MutableMapping) -> None:
         if self.clock:
         if self.clock:
             d[str('clock')] = self.clock.adjust(d.get(str('clock')) or 0)
             d[str('clock')] = self.clock.adjust(d.get(str('clock')) or 0)
 
 
-    def _merge_revoked(self, d):
+    def _merge_revoked(self, d: MutableMapping) -> None:
         try:
         try:
             self._merge_revoked_v3(d[str('zrevoked')])
             self._merge_revoked_v3(d[str('zrevoked')])
         except KeyError:
         except KeyError:
@@ -238,29 +250,29 @@ class Persistent:
         # purge expired items at boot
         # purge expired items at boot
         self._revoked_tasks.purge()
         self._revoked_tasks.purge()
 
 
-    def _merge_revoked_v3(self, zrevoked):
+    def _merge_revoked_v3(self, zrevoked: bytes) -> None:
         if zrevoked:
         if zrevoked:
             self._revoked_tasks.update(pickle.loads(self.decompress(zrevoked)))
             self._revoked_tasks.update(pickle.loads(self.decompress(zrevoked)))
 
 
-    def _merge_revoked_v2(self, saved):
+    def _merge_revoked_v2(self, saved: Mapping) -> None:
         if not isinstance(saved, LimitedSet):
         if not isinstance(saved, LimitedSet):
             # (pre 3.0.18) used to be stored as a dict
             # (pre 3.0.18) used to be stored as a dict
             return self._merge_revoked_v1(saved)
             return self._merge_revoked_v1(saved)
         self._revoked_tasks.update(saved)
         self._revoked_tasks.update(saved)
 
 
-    def _merge_revoked_v1(self, saved):
+    def _merge_revoked_v1(self, saved: Sequence) -> None:
         add = self._revoked_tasks.add
         add = self._revoked_tasks.add
         for item in saved:
         for item in saved:
             add(item)
             add(item)
 
 
-    def _dumps(self, obj):
+    def _dumps(self, obj: Any) -> bytes:
         return pickle.dumps(obj, protocol=self.protocol)
         return pickle.dumps(obj, protocol=self.protocol)
 
 
     @property
     @property
-    def _revoked_tasks(self):
+    def _revoked_tasks(self) -> LimitedSet:
         return self.state.revoked
         return self.state.revoked
 
 
     @cached_property
     @cached_property
-    def db(self):
+    def db(self) -> MutableMapping:
         self._is_open = True
         self._is_open = True
         return self.open()
         return self.open()

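Persistent._sync_with() above writes three keys to the shelve database: a
protocol marker, the zlib-compressed pickle of the revoked-task set, and the
logical clock. A sketch of that round trip using a plain dict and set as
stand-ins for the shelve db and LimitedSet (contents illustrative):

    import pickle
    import zlib

    revoked_ids = {'40e431b4', '8e8e9a61'}  # stand-in for LimitedSet

    db = {
        '__proto__': 3,
        'zrevoked': zlib.compress(pickle.dumps(revoked_ids)),
        'clock': 17,
    }

    assert pickle.loads(zlib.decompress(db['zrevoked'])) == revoked_ids
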
+ 45 - 14
celery/worker/strategy.py

@@ -1,14 +1,16 @@
 # -*- coding: utf-8 -*-
 # -*- coding: utf-8 -*-
 """Task execution strategy (optimization)."""
 """Task execution strategy (optimization)."""
 import logging
 import logging
-
+from typing import (
+    Awaitable, Callable, Dict, List, Mapping, NamedTuple, Sequence, Tuple,
+)
 from kombu.async.timer import to_timestamp
 from kombu.async.timer import to_timestamp
-
+from kombu.types import MessageT
 from celery.exceptions import InvalidTaskError
 from celery.exceptions import InvalidTaskError
+from celery.types import AppT, WorkerConsumerT
 from celery.utils.log import get_logger
 from celery.utils.log import get_logger
 from celery.utils.saferepr import saferepr
 from celery.utils.saferepr import saferepr
 from celery.utils.time import timezone
 from celery.utils.time import timezone
-
 from .request import Request, create_request_cls
 from .request import Request, create_request_cls
 from .state import task_reserved
 from .state import task_reserved
 
 
@@ -20,7 +22,16 @@ logger = get_logger(__name__)
 # We cache globals and attribute lookups, so disable this warning.
 # We cache globals and attribute lookups, so disable this warning.
 
 
 
 
-def proto1_to_proto2(message, body):
+class converted_message_t(NamedTuple):
+    """Describes a converted message."""
+
+    body: Tuple[List, Dict, Mapping]
+    headers: Mapping
+    decoded: bool
+    utc: bool
+
+
+def proto1_to_proto2(message: MessageT, body: Mapping) -> converted_message_t:
     """Convert Task message protocol 1 arguments to protocol 2.
     """Convert Task message protocol 1 arguments to protocol 2.
 
 
     Returns:
     Returns:
@@ -50,13 +61,28 @@ def proto1_to_proto2(message, body):
         'chord': body.get('chord'),
         'chord': body.get('chord'),
         'chain': None,
         'chain': None,
     }
     }
-    return (args, kwargs, embed), body, True, body.get('utc', True)
+    return converted_message_t(
+        body=(args, kwargs, embed),
+        headers=body,
+        decoded=True,
+        utc=body.get('utc', True),
+    )
+
 
 
+StrategyT = Callable[
+    [MessageT, Mapping, Callable, Callable, Sequence[Callable]],
+    Awaitable,
+]
 
 
-def default(task, app, consumer,
-            info=logger.info, error=logger.error, task_reserved=task_reserved,
-            to_system_tz=timezone.to_system, bytes=bytes,
-            proto1_to_proto2=proto1_to_proto2):
+
+def default(task: str, app: AppT, consumer: WorkerConsumerT,
+            *,
+            info: Callable = logger.info,
+            error: Callable = logger.error,
+            task_reserved: Callable = task_reserved,
+            to_system_tz: Callable = timezone.to_system,
+            proto1_to_proto2: Callable = proto1_to_proto2,
+            bytes: Callable = bytes) -> StrategyT:
     """Default task execution strategy.
     """Default task execution strategy.
 
 
     Note:
     Note:
@@ -83,9 +109,14 @@ def default(task, app, consumer,
     Req = create_request_cls(Request, task, consumer.pool, hostname, eventer)
     Req = create_request_cls(Request, task, consumer.pool, hostname, eventer)
 
 
     revoked_tasks = consumer.controller.state.revoked
     revoked_tasks = consumer.controller.state.revoked
-
-    def task_message_handler(message, body, ack, reject, callbacks,
-                             to_timestamp=to_timestamp):
+    convert_to_timestamp = to_timestamp
+
+    async def task_message_handler(
+            message: MessageT,
+            body: Mapping,
+            ack: Callable,
+            reject: Callable,
+            callbacks: Sequence[Callable]) -> Awaitable:
         if body is None:
         if body is None:
             body, headers, decoded, utc = (
             body, headers, decoded, utc = (
                 message.body, message.headers, False, True,
                 message.body, message.headers, False, True,
@@ -118,9 +149,9 @@ def default(task, app, consumer,
         if req.eta:
         if req.eta:
             try:
             try:
                 if req.utc:
                 if req.utc:
-                    eta = to_timestamp(to_system_tz(req.eta))
+                    eta = convert_to_timestamp(to_system_tz(req.eta))
                 else:
                 else:
-                    eta = to_timestamp(req.eta, timezone.local)
+                    eta = convert_to_timestamp(req.eta, timezone.local)
             except (OverflowError, ValueError) as exc:
             except (OverflowError, ValueError) as exc:
                 error("Couldn't convert ETA %r to timestamp: %r. Task: %r",
                 error("Couldn't convert ETA %r to timestamp: %r. Task: %r",
                       req.eta, exc, req.info(safe=True), exc_info=True)
                       req.eta, exc, req.info(safe=True), exc_info=True)

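To make the conversion above concrete: a protocol-1 body carries args/kwargs
inline, while protocol 2 moves them into the (args, kwargs, embed) payload
and reuses the old body as the headers. Shape only, with placeholder values,
using the converted_message_t defined in this change:

    v1_body = {
        'task': 'proj.add', 'id': 'f2b1ae51', 'args': (2, 2), 'kwargs': {},
        'callbacks': None, 'errbacks': None, 'chord': None, 'utc': True,
    }
    embed = {'callbacks': None, 'errbacks': None,
             'chord': None, 'chain': None}
    converted = converted_message_t(
        body=((2, 2), {}, embed),  # (args, kwargs, embed)
        headers=v1_body,           # the old body becomes the headers
        decoded=True,
        utc=True,
    )
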
+ 116 - 85
celery/worker/worker.py

@@ -15,6 +15,8 @@ The worker consists of several components, all managed by bootsteps
 import os
 import os
 import sys
 import sys
 
 
+from typing import Any, Callable, Iterator, Mapping, Sequence, Set, Union
+
 from billiard import cpu_count
 from billiard import cpu_count
 from kombu.utils.compat import detect_environment
 from kombu.utils.compat import detect_environment
 
 
@@ -25,7 +27,10 @@ from celery import signals
 from celery.exceptions import (
 from celery.exceptions import (
     ImproperlyConfigured, WorkerTerminate, TaskRevokedError,
     ImproperlyConfigured, WorkerTerminate, TaskRevokedError,
 )
 )
-from celery.platforms import EX_FAILURE, create_pidlock
+from celery.types import (
+    AppT, BlueprintT, LoopT, PoolT, RequestT, StepT, WorkerConsumerT,
+)
+from celery.platforms import EX_FAILURE, Pidfile, create_pidlock
 from celery.utils.imports import reload_from_cwd
 from celery.utils.imports import reload_from_cwd
 from celery.utils.log import mlevel, worker_logger as logger
 from celery.utils.log import mlevel, worker_logger as logger
 from celery.utils.nodenames import default_nodename, worker_direct
 from celery.utils.nodenames import default_nodename, worker_direct
@@ -57,25 +62,27 @@ Trying to deselect queue subset of {0!r}, but queue {1} isn't
 defined in the `task_queues` setting.
 defined in the `task_queues` setting.
 """
 """
 
 
+CSVListArgT = Union[Sequence[str], str]
+
 
 
 class WorkController:
 class WorkController:
     """Unmanaged worker instance."""
     """Unmanaged worker instance."""
 
 
-    app = None
+    app: AppT = None
 
 
-    pidlock = None
-    blueprint = None
-    pool = None
-    semaphore = None
+    pidlock: Pidfile = None
+    blueprint: BlueprintT = None
+    pool: PoolT = None
+    semaphore: Any = None
 
 
     #: contains the exit code if a :exc:`SystemExit` event is handled.
     #: contains the exit code if a :exc:`SystemExit` event is handled.
-    exitcode = None
+    exitcode: int = None
 
 
     class Blueprint(bootsteps.Blueprint):
     class Blueprint(bootsteps.Blueprint):
         """Worker bootstep blueprint."""
         """Worker bootstep blueprint."""
 
 
         name = 'Worker'
         name = 'Worker'
-        default_steps = {
+        default_steps: Set[Union[str, StepT]] = {
             'celery.worker.components:Hub',
             'celery.worker.components:Hub',
             'celery.worker.components:Pool',
             'celery.worker.components:Pool',
             'celery.worker.components:Beat',
             'celery.worker.components:Beat',
@@ -85,7 +92,12 @@ class WorkController:
             'celery.worker.autoscale:WorkerComponent',
             'celery.worker.autoscale:WorkerComponent',
         }
         }
 
 
-    def __init__(self, app=None, hostname=None, **kwargs):
+    def __init__(
+            self,
+            *,
+            app: AppT = None,
+            hostname: str = None,
+            **kwargs) -> None:
         self.app = app or self.app
         self.app = app or self.app
         self.hostname = default_nodename(hostname)
         self.hostname = default_nodename(hostname)
         self.app.loader.init_worker()
         self.app.loader.init_worker()
@@ -95,9 +107,15 @@ class WorkController:
 
 
         self.setup_instance(**self.prepare_args(**kwargs))
         self.setup_instance(**self.prepare_args(**kwargs))
 
 
-    def setup_instance(self, queues=None, ready_callback=None, pidfile=None,
-                       include=None, use_eventloop=None, exclude_queues=None,
-                       **kwargs):
+    def setup_instance(self,
+                       *,
+                       queues: CSVListArgT = None,
+                       ready_callback: Callable = None,
+                       pidfile: str = None,
+                       include: CSVListArgT = None,
+                       use_eventloop: bool = None,
+                       exclude_queues: CSVListArgT = None,
+                       **kwargs) -> None:
         self.pidfile = pidfile
         self.pidfile = pidfile
         self.setup_queues(queues, exclude_queues)
         self.setup_queues(queues, exclude_queues)
         self.setup_includes(str_to_list(include))
         self.setup_includes(str_to_list(include))
@@ -135,33 +153,34 @@ class WorkController:
         )
         )
         self.blueprint.apply(self, **kwargs)
         self.blueprint.apply(self, **kwargs)
 
 
-    def on_init_blueprint(self):
-        pass
+    def on_init_blueprint(self) -> None:
+        ...
 
 
-    def on_before_init(self, **kwargs):
-        pass
+    def on_before_init(self, **kwargs) -> None:
+        ...
 
 
-    def on_after_init(self, **kwargs):
-        pass
+    def on_after_init(self, **kwargs) -> None:
+        ...
 
 
-    def on_start(self):
+    def on_start(self) -> None:
         if self.pidfile:
         if self.pidfile:
             self.pidlock = create_pidlock(self.pidfile)
             self.pidlock = create_pidlock(self.pidfile)
 
 
-    def on_consumer_ready(self, consumer):
-        pass
+    def on_consumer_ready(self, consumer: WorkerConsumerT) -> None:
+        ...
 
 
-    def on_close(self):
+    def on_close(self) -> None:
         self.app.loader.shutdown_worker()
         self.app.loader.shutdown_worker()
 
 
-    def on_stopped(self):
+    def on_stopped(self) -> None:
         self.timer.stop()
         self.timer.stop()
         self.consumer.shutdown()
         self.consumer.shutdown()
 
 
         if self.pidlock:
         if self.pidlock:
             self.pidlock.release()
             self.pidlock.release()
 
 
-    def setup_queues(self, include, exclude=None):
+    def setup_queues(self, include: CSVListArgT,
+                     exclude: CSVListArgT = None) -> None:
         include = str_to_list(include)
         include = str_to_list(include)
         exclude = str_to_list(exclude)
         exclude = str_to_list(exclude)
         try:
         try:
@@ -177,7 +196,7 @@ class WorkController:
         if self.app.conf.worker_direct:
         if self.app.conf.worker_direct:
             self.app.amqp.queues.select_add(worker_direct(self.hostname))
             self.app.amqp.queues.select_add(worker_direct(self.hostname))
 
 
-    def setup_includes(self, includes):
+    def setup_includes(self, includes: Sequence[str]) -> None:
         # Update celery_include to have all known task modules, so that we
         # Update celery_include to have all known task modules, so that we
         # ensure all task modules are imported in case an execv happens.
         # ensure all task modules are imported in case an execv happens.
         prev = tuple(self.app.conf.include)
         prev = tuple(self.app.conf.include)
@@ -189,81 +208,88 @@ class WorkController:
                         for task in self.app.tasks.values()}
                         for task in self.app.tasks.values()}
         self.app.conf.include = tuple(set(prev) | task_modules)
         self.app.conf.include = tuple(set(prev) | task_modules)
 
 
-    def prepare_args(self, **kwargs):
+    def prepare_args(self, **kwargs) -> Mapping:
         return kwargs
         return kwargs
 
 
-    def _send_worker_shutdown(self):
+    def _send_worker_shutdown(self) -> None:
         signals.worker_shutdown.send(sender=self)
         signals.worker_shutdown.send(sender=self)
 
 
-    def start(self):
+    async def start(self) -> None:
         try:
         try:
-            self.blueprint.start(self)
+            await self.blueprint.start(self)
         except WorkerTerminate:
         except WorkerTerminate:
-            self.terminate()
+            await self.terminate()
         except Exception as exc:
         except Exception as exc:
             logger.critical('Unrecoverable error: %r', exc, exc_info=True)
             logger.critical('Unrecoverable error: %r', exc, exc_info=True)
-            self.stop(exitcode=EX_FAILURE)
+            await self.stop(exitcode=EX_FAILURE)
         except SystemExit as exc:
         except SystemExit as exc:
-            self.stop(exitcode=exc.code)
+            await self.stop(exitcode=exc.code)
         except KeyboardInterrupt:
         except KeyboardInterrupt:
-            self.stop(exitcode=EX_FAILURE)
+            await self.stop(exitcode=EX_FAILURE)
 
 
-    def register_with_event_loop(self, hub):
-        self.blueprint.send_all(
+    async def register_with_event_loop(self, hub: LoopT) -> None:
+        await self.blueprint.send_all(
             self, 'register_with_event_loop', args=(hub,),
             self, 'register_with_event_loop', args=(hub,),
             description='hub.register',
             description='hub.register',
         )
         )
 
 
-    def _process_task_sem(self, req):
-        return self._quick_acquire(self._process_task, req)
+    async def _process_task_sem(self, req: RequestT) -> None:
+        await self._quick_acquire(self._process_task, req)
 
 
-    def _process_task(self, req):
+    async def _process_task(self, req: RequestT) -> None:
         """Process task by sending it to the pool of workers."""
         """Process task by sending it to the pool of workers."""
         try:
         try:
-            req.execute_using_pool(self.pool)
+            await req.execute_using_pool(self.pool)
         except TaskRevokedError:
         except TaskRevokedError:
             try:
             try:
                 self._quick_release()   # Issue 877
                 self._quick_release()   # Issue 877
             except AttributeError:
             except AttributeError:
                 pass
                 pass
 
 
-    def signal_consumer_close(self):
+    def signal_consumer_close(self) -> None:
         try:
         try:
             self.consumer.close()
             self.consumer.close()
         except AttributeError:
         except AttributeError:
             pass
             pass
 
 
-    def should_use_eventloop(self):
+    def should_use_eventloop(self) -> bool:
         return (detect_environment() == 'default' and
         return (detect_environment() == 'default' and
                 self._conninfo.transport.implements.async and
                 self._conninfo.transport.implements.async and
                 not self.app.IS_WINDOWS)
                 not self.app.IS_WINDOWS)
 
 
-    def stop(self, in_sighandler=False, exitcode=None):
+    async def stop(self,
+                   *,
+                   in_sighandler: bool = False,
+                   exitcode: int = None) -> None:
         """Graceful shutdown of the worker server."""
         """Graceful shutdown of the worker server."""
         if exitcode is not None:
         if exitcode is not None:
             self.exitcode = exitcode
             self.exitcode = exitcode
         if self.blueprint.state == RUN:
         if self.blueprint.state == RUN:
             self.signal_consumer_close()
             self.signal_consumer_close()
             if not in_sighandler or self.pool.signal_safe:
             if not in_sighandler or self.pool.signal_safe:
-                self._shutdown(warm=True)
+                await self._shutdown(warm=True)
         self._send_worker_shutdown()
         self._send_worker_shutdown()
 
 
-    def terminate(self, in_sighandler=False):
+    async def terminate(self, *, in_sighandler: bool = False) -> None:
         """Not so graceful shutdown of the worker server."""
         """Not so graceful shutdown of the worker server."""
         if self.blueprint.state != TERMINATE:
         if self.blueprint.state != TERMINATE:
             self.signal_consumer_close()
             self.signal_consumer_close()
             if not in_sighandler or self.pool.signal_safe:
             if not in_sighandler or self.pool.signal_safe:
-                self._shutdown(warm=False)
+                await self._shutdown(warm=False)
 
 
-    def _shutdown(self, warm=True):
+    async def _shutdown(self, *, warm: bool = True) -> None:
         # if blueprint does not exist it means that we had an
         # if blueprint does not exist it means that we had an
         # error before the bootsteps could be initialized.
         # error before the bootsteps could be initialized.
         if self.blueprint is not None:
         if self.blueprint is not None:
             with default_socket_timeout(SHUTDOWN_SOCKET_TIMEOUT):  # Issue 975
             with default_socket_timeout(SHUTDOWN_SOCKET_TIMEOUT):  # Issue 975
-                self.blueprint.stop(self, terminate=not warm)
+                await self.blueprint.stop(self, terminate=not warm)
                 self.blueprint.join()
                 self.blueprint.join()
 
 
-    def reload(self, modules=None, reload=False, reloader=None):
+    def reload(self,
+               modules: Sequence[str] = None,
+               *,
+               reload: bool = False,
+               reloader: Callable = None) -> None:
         list(self._reload_modules(
         list(self._reload_modules(
             modules, force_reload=reload, reloader=reloader))
             modules, force_reload=reload, reloader=reloader))
 
 
@@ -275,14 +301,17 @@ class WorkController:
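The reload() method above is what the pool_restart remote-control command
(see celery/worker/control.py earlier in this change) reaches through
consumer.controller.reload(). Client-side sketch; workers must run with the
worker_pool_restarts setting enabled for this to take effect:

    from celery import Celery

    app = Celery('demo', broker='amqp://guest@localhost//')
    app.control.pool_restart(modules=['proj.tasks'], reload=True)
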
         except NotImplementedError:
         except NotImplementedError:
             pass
             pass
 
 
-    def _reload_modules(self, modules=None, **kwargs):
+    def _reload_modules(self, modules: Sequence[str] = None, **kwargs) -> None:
         return (
         return (
             self._maybe_reload_module(m, **kwargs)
             self._maybe_reload_module(m, **kwargs)
             for m in set(self.app.loader.task_modules
             for m in set(self.app.loader.task_modules
                          if modules is None else (modules or ()))
                          if modules is None else (modules or ()))
         )
         )
 
 
-    def _maybe_reload_module(self, module, force_reload=False, reloader=None):
+    def _maybe_reload_module(self, module: str,
+                             *,
+                             force_reload: bool = False,
+                             reloader: Callable = None) -> Any:
         if module not in sys.modules:
         if module not in sys.modules:
             logger.debug('importing module %s', module)
             logger.debug('importing module %s', module)
             return self.app.loader.import_from_cwd(module)
             return self.app.loader.import_from_cwd(module)
@@ -290,12 +319,12 @@ class WorkController:
             logger.debug('reloading module %s', module)
             logger.debug('reloading module %s', module)
             return reload_from_cwd(sys.modules[module], reloader)
             return reload_from_cwd(sys.modules[module], reloader)
 
 
-    def info(self):
+    def info(self) -> Mapping[str, Any]:
         return {'total': self.state.total_count,
         return {'total': self.state.total_count,
                 'pid': os.getpid(),
                 'pid': os.getpid(),
                 'clock': str(self.app.clock)}
                 'clock': str(self.app.clock)}
 
 
-    def rusage(self):
+    def rusage(self) -> Mapping[str, Any]:
         if resource is None:
         if resource is None:
             raise NotImplementedError('rusage not supported by this platform')
             raise NotImplementedError('rusage not supported by this platform')
         s = resource.getrusage(resource.RUSAGE_SELF)
         s = resource.getrusage(resource.RUSAGE_SELF)
@@ -318,7 +347,7 @@ class WorkController:
             'nivcsw': s.ru_nivcsw,
             'nivcsw': s.ru_nivcsw,
         }
         }
 
 
-    def stats(self):
+    def stats(self) -> Mapping[str, Any]:
         info = self.info()
         info = self.info()
         info.update(self.blueprint.info(self))
         info.update(self.blueprint.info(self))
         info.update(self.consumer.blueprint.info(self.consumer))
         info.update(self.consumer.blueprint.info(self.consumer))
@@ -328,49 +357,54 @@ class WorkController:
             info['rusage'] = 'N/A'
             info['rusage'] = 'N/A'
         return info
         return info
 
 
-    def __repr__(self):
+    def __repr__(self) -> str:
         """``repr(worker)``."""
         """``repr(worker)``."""
         return '<Worker: {self.hostname} ({state})>'.format(
         return '<Worker: {self.hostname} ({state})>'.format(
             self=self,
             self=self,
             state=self.blueprint.human_state() if self.blueprint else 'INIT',
             state=self.blueprint.human_state() if self.blueprint else 'INIT',
         )
         )
 
 
-    def __str__(self):
+    def __str__(self) -> str:
         """``str(worker) == worker.hostname``."""
         """``str(worker) == worker.hostname``."""
         return self.hostname
         return self.hostname
 
 
     @property
     @property
-    def state(self):
+    def state(self) -> Any:
         return state
         return state
 
 
-    def setup_defaults(self, concurrency=None, loglevel='WARN', logfile=None,
-                       task_events=None, pool=None, consumer_cls=None,
-                       timer_cls=None, timer_precision=None,
-                       autoscaler_cls=None,
-                       pool_putlocks=None,
-                       pool_restarts=None,
-                       optimization=None, O=None,  # O maps to -O=fair
-                       statedb=None,
-                       time_limit=None,
-                       soft_time_limit=None,
-                       scheduler=None,
-                       pool_cls=None,              # XXX use pool
-                       state_db=None,              # XXX use statedb
-                       task_time_limit=None,       # XXX use time_limit
-                       task_soft_time_limit=None,  # XXX use soft_time_limit
-                       scheduler_cls=None,         # XXX use scheduler
-                       schedule_filename=None,
-                       max_tasks_per_child=None,
-                       prefetch_multiplier=None, disable_rate_limits=None,
-                       worker_lost_wait=None,
-                       max_memory_per_child=None, **_kw):
+    def setup_defaults(self,
+                       *,
+                       concurrency: int = None,
+                       loglevel: Union[str, int] = 'WARN',
+                       logfile: str = None,
+                       task_events: bool = None,
+                       pool: Union[str, type] = None,
+                       consumer_cls: Union[str, type] = None,
+                       timer_cls: Union[str, type] = None,
+                       timer_precision: float = None,
+                       autoscaler_cls: Union[str, type] = None,
+                       pool_putlocks: bool = None,
+                       pool_restarts: bool = None,
+                       optimization: str = None,
+                       O: str = None,  # O maps to -O=fair
+                       statedb: str = None,
+                       time_limit: float = None,
+                       soft_time_limit: float = None,
+                       scheduler: Union[str, type] = None,
+                       schedule_filename: str = None,
+                       max_tasks_per_child: int = None,
+                       prefetch_multiplier: float = None,
+                       disable_rate_limits: bool = None,
+                       worker_lost_wait: float = None,
+                       max_memory_per_child: float = None,
+                       **_kw) -> None:
          either = self.app.either
          self.loglevel = loglevel
          self.logfile = logfile

          self.concurrency = either('worker_concurrency', concurrency)
          self.task_events = either('worker_send_task_events', task_events)
-        self.pool_cls = either('worker_pool', pool, pool_cls)
+        self.pool_cls = either('worker_pool', pool)
          self.consumer_cls = either('worker_consumer', consumer_cls)
          self.timer_cls = either('worker_timer', timer_cls)
          self.timer_precision = either(
@@ -380,16 +414,13 @@ class WorkController:
          self.autoscaler_cls = either('worker_autoscaler', autoscaler_cls)
          self.pool_putlocks = either('worker_pool_putlocks', pool_putlocks)
          self.pool_restarts = either('worker_pool_restarts', pool_restarts)
-        self.statedb = either('worker_state_db', statedb, state_db)
+        self.statedb = either('worker_state_db', statedb)
          self.schedule_filename = either(
              'beat_schedule_filename', schedule_filename,
          )
-        self.scheduler = either('beat_scheduler', scheduler, scheduler_cls)
-        self.time_limit = either(
-            'task_time_limit', time_limit, task_time_limit)
-        self.soft_time_limit = either(
-            'task_soft_time_limit', soft_time_limit, task_soft_time_limit,
-        )
+        self.scheduler = either('beat_scheduler', scheduler)
+        self.time_limit = either('task_time_limit', time_limit)
+        self.soft_time_limit = either('task_soft_time_limit', soft_time_limit)
          self.max_tasks_per_child = either(
              'worker_max_tasks_per_child', max_tasks_per_child,
          )
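
Note: with the deprecated aliases gone, each setting is resolved from a
single argument plus its configuration key. ``app.either()`` returns the
first passed value that is not ``None``, falling back to the configured
value; a simplified model of that behavior (not the real implementation):

    def either(conf, key, *values):
        # First explicitly passed, non-None value wins ...
        for value in values:
            if value is not None:
                return value
        # ... otherwise fall back to the configuration key.
        return conf.get(key)

So ``setup_defaults(pool='solo')`` still beats ``app.conf['worker_pool']``,
while the removed aliases (``pool_cls``, ``state_db``, ``task_time_limit``,
``task_soft_time_limit``, ``scheduler_cls``) no longer take part in the
fallback chain.
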