Ask Solem 8 роки тому
батько
коміт
3dcc606b7e
51 змінених файлів з 2488 додано та 1583 видалено
  1. 11 4
      celery/__init__.py
  2. 8 8
      celery/app/__init__.py
  3. 147 78
      celery/app/amqp.py
  4. 8 4
      celery/app/backends.py
  5. 271 163
      celery/app/base.py
  6. 13 6
      celery/app/defaults.py
  7. 12 6
      celery/app/events.py
  8. 56 26
      celery/app/log.py
  9. 7 16
      celery/app/registry.py
  10. 32 12
      celery/app/routes.py
  11. 190 122
      celery/app/task.py
  12. 17 5
      celery/app/trace.py
  13. 61 36
      celery/app/utils.py
  14. 80 52
      celery/backends/async.py
  15. 264 166
      celery/backends/base.py
  16. 71 47
      celery/backends/rpc.py
  17. 9 3
      celery/beat.py
  18. 48 47
      celery/bootsteps.py
  19. 254 125
      celery/canvas.py
  20. 9 2
      celery/concurrency/asynpool.py
  21. 15 26
      celery/contrib/pytest.py
  22. 126 82
      celery/events/state.py
  23. 0 4
      celery/platforms.py
  24. 11 2
      celery/schedules.py
  25. 13 17
      celery/utils/collections.py
  26. 3 3
      celery/utils/debug.py
  27. 5 5
      celery/utils/functional.py
  28. 7 7
      celery/utils/graph.py
  29. 6 6
      celery/utils/imports.py
  30. 1 1
      celery/utils/log.py
  31. 34 17
      celery/utils/saferepr.py
  32. 3 3
      celery/utils/serialization.py
  33. 2 4
      celery/utils/static/__init__.py
  34. 7 4
      celery/utils/sysinfo.py
  35. 27 26
      celery/worker/autoscale.py
  36. 28 25
      celery/worker/components.py
  37. 4 2
      celery/worker/consumer/agent.py
  38. 8 6
      celery/worker/consumer/connection.py
  39. 113 83
      celery/worker/consumer/consumer.py
  40. 3 2
      celery/worker/consumer/control.py
  41. 14 13
      celery/worker/consumer/events.py
  42. 48 38
      celery/worker/consumer/gossip.py
  43. 11 6
      celery/worker/consumer/heart.py
  44. 22 13
      celery/worker/consumer/mingle.py
  45. 9 7
      celery/worker/consumer/tasks.py
  46. 100 58
      celery/worker/control.py
  47. 15 9
      celery/worker/heartbeat.py
  48. 72 47
      celery/worker/request.py
  49. 52 40
      celery/worker/state.py
  50. 45 14
      celery/worker/strategy.py
  51. 116 85
      celery/worker/worker.py

+ 11 - 4
celery/__init__.py

@@ -8,7 +8,7 @@
 import os
 import re
 import sys
-from collections import namedtuple
+from typing import NamedTuple
 
 SERIES = 'latentcall'
 
@@ -30,9 +30,16 @@ __all__ = [
 
 VERSION_BANNER = '{0} ({1})'.format(__version__, SERIES)
 
-version_info_t = namedtuple('version_info_t', (
-    'major', 'minor', 'micro', 'releaselevel', 'serial',
-))
+
+class version_info_t(NamedTuple):
+    """Version information tuple."""
+
+    major: int
+    minor: int
+    micro: int
+    releaselevel: str
+    serial: str
+
 
 # bumpversion can only search for {current_version}
 # so we have to parse the version here.

+ 8 - 8
celery/app/__init__.py

@@ -1,11 +1,11 @@
-# -*- coding: utf-8 -*-
-"""Celery Application."""
+from typing import Callable, Union
 from celery.local import Proxy
 from celery import _state
 from celery._state import (
     app_or_default, enable_trace, disable_trace,
     push_current_task, pop_current_task,
 )
+from celery.types import AppT, TaskT
 from .base import Celery
 from .utils import AppPickler
 
@@ -19,12 +19,12 @@ __all__ = [
 default_app = Proxy(lambda: _state.default_app)
 
 
-def bugreport(app=None):
+def bugreport(app: AppT = None) -> str:
     """Return information useful in bug reports."""
     return (app or _state.get_current_app()).bugreport()
 
 
-def shared_task(*args, **kwargs):
+def shared_task(*args, **kwargs) -> Union[Callable, TaskT]:
     """Create shared task (decorator).
 
     This can be used by library authors to create tasks that'll work
@@ -48,10 +48,10 @@ def shared_task(*args, **kwargs):
         >>> add.app is app2
         True
     """
-    def create_shared_task(**options):
+    def create_shared_task(**options) -> Callable:
 
-        def __inner(fun):
-            name = options.get('name')
+        def __inner(fun: Callable) -> TaskT:
+            name: str = options.get('name')
             # Set as shared task so that unfinalized apps,
             # and future apps will register a copy of this task.
             _state.connect_on_app_finalize(
@@ -66,7 +66,7 @@ def shared_task(*args, **kwargs):
 
             # Return a proxy that always gets the task from the current
             # apps task registry.
-            def task_by_cons():
+            def task_by_cons() -> TaskT:
                 app = _state.get_current_app()
                 return app.tasks[
                     name or app.gen_task_name(fun.__name__, fun.__module__)

+ 147 - 78
celery/app/amqp.py

@@ -2,17 +2,24 @@
 """Sending/Receiving Messages (Kombu integration)."""
 import numbers
 
-from collections import Mapping, namedtuple
-from datetime import timedelta
+from collections import Mapping
+from datetime import datetime, timedelta, tzinfo
+from typing import (
+    Any, Callable, MutableMapping, NamedTuple, Set, Sequence, Union,
+    cast,
+)
 from weakref import WeakValueDictionary
 
 from kombu import pools
 from kombu import Connection, Consumer, Exchange, Producer, Queue
 from kombu.common import Broadcast
+from kombu.types import ChannelT, ConsumerT, EntityT, ProducerT, ResourceT
 from kombu.utils.functional import maybe_list
 from kombu.utils.objects import cached_property
 
 from celery import signals
+from celery.events import EventDispatcher
+from celery.types import AppT, ResultT, RouterT, SignatureT
 from celery.utils.nodenames import anon_nodename
 from celery.utils.saferepr import saferepr
 from celery.utils.text import indent as textindent
@@ -31,11 +38,19 @@ QUEUE_FORMAT = """
 key={0.routing_key}
 """
 
-task_message = namedtuple('task_message',
-                          ('headers', 'properties', 'body', 'sent_event'))
+QueuesArgT = Union[Mapping[str, Queue], Sequence[Queue]]
 
 
-def utf8dict(d, encoding='utf-8'):
+class task_message(NamedTuple):
+    """Represents a task message that can be sent."""
+
+    headers: MutableMapping
+    properties: MutableMapping
+    body: Any
+    sent_event: Mapping
+
+
+def utf8dict(d: Mapping, encoding: str = 'utf-8') -> Mapping:
     return {k.decode(encoding) if isinstance(k, bytes) else k: v
             for k, v in d.items()}
 
@@ -54,11 +69,16 @@ class Queues(dict):
 
     #: If set, this is a subset of queues to consume from.
     #: The rest of the queues are then used for routing only.
-    _consume_from = None
-
-    def __init__(self, queues=None, default_exchange=None,
-                 create_missing=True, ha_policy=None, autoexchange=None,
-                 max_priority=None, default_routing_key=None):
+    _consume_from: Mapping[str, Queue] = None
+
+    def __init__(self,
+                 queues: QueuesArgT = None,
+                 default_exchange: str = None,
+                 create_missing: bool = True,
+                 ha_policy: Union[Sequence, str] = None,
+                 autoexchange: Callable[[str], Exchange] = None,
+                 max_priority: int = None,
+                 default_routing_key: str = None) -> None:
         dict.__init__(self)
         self.aliases = WeakValueDictionary()
         self.default_exchange = default_exchange
@@ -72,25 +92,25 @@ class Queues(dict):
         for name, q in (queues or {}).items():
             self.add(q) if isinstance(q, Queue) else self.add_compat(name, **q)
 
-    def __getitem__(self, name):
+    def __getitem__(self, name: str) -> Queue:
         try:
             return self.aliases[name]
         except KeyError:
             return dict.__getitem__(self, name)
 
-    def __setitem__(self, name, queue):
+    def __setitem__(self, name: str, queue: Queue) -> None:
         if self.default_exchange and not queue.exchange:
             queue.exchange = self.default_exchange
         dict.__setitem__(self, name, queue)
         if queue.alias:
             self.aliases[queue.alias] = queue
 
-    def __missing__(self, name):
+    def __missing__(self, name: str) -> Queue:
         if self.create_missing:
             return self.add(self.new_missing(name))
         raise KeyError(name)
 
-    def add(self, queue, **kwargs):
+    def add(self, queue: Union[Queue, str], **kwargs) -> Queue:
         """Add new queue.
 
         The first argument can either be a :class:`kombu.Queue` instance,
@@ -111,14 +131,14 @@ class Queues(dict):
             return self.add_compat(queue, **kwargs)
         return self._add(queue)
 
-    def add_compat(self, name, **options):
+    def add_compat(self, name: str, **options) -> Queue:
         # docs used to use binding_key as routing key
         options.setdefault('routing_key', options.get('binding_key'))
         if options['routing_key'] is None:
             options['routing_key'] = name
         return self._add(Queue.from_dict(name, **options))
 
-    def _add(self, queue):
+    def _add(self, queue: Queue) -> Queue:
         if not queue.routing_key:
             if queue.exchange is None or queue.exchange.name == '':
                 queue.exchange = self.default_exchange
@@ -134,18 +154,19 @@ class Queues(dict):
         self[queue.name] = queue
         return queue
 
-    def _set_ha_policy(self, args):
+    def _set_ha_policy(self, args: MutableMapping) -> None:
         policy = self.ha_policy
         if isinstance(policy, (list, tuple)):
-            return args.update({'x-ha-policy': 'nodes',
-                                'x-ha-policy-params': list(policy)})
-        args['x-ha-policy'] = policy
+            args.update({'x-ha-policy': 'nodes',
+                         'x-ha-policy-params': list(policy)})
+        else:
+            args['x-ha-policy'] = policy
 
-    def _set_max_priority(self, args):
+    def _set_max_priority(self, args: MutableMapping) -> None:
         if 'x-max-priority' not in args and self.max_priority is not None:
-            return args.update({'x-max-priority': self.max_priority})
+            args.update({'x-max-priority': self.max_priority})
 
-    def format(self, indent=0, indent_first=True):
+    def format(self, indent: int = 0, indent_first: bool = True) -> str:
         """Format routing table into string for log dumps."""
         active = self.consume_from
         if not active:
@@ -156,7 +177,7 @@ class Queues(dict):
             return textindent('\n'.join(info), indent)
         return info[0] + '\n' + textindent('\n'.join(info[1:]), indent)
 
-    def select_add(self, queue, **kwargs):
+    def select_add(self, queue: Queue, **kwargs) -> Queue:
         """Add new task queue that'll be consumed from.
 
         The queue will be active even when a subset has been selected
@@ -167,7 +188,7 @@ class Queues(dict):
             self._consume_from[q.name] = q
         return q
 
-    def select(self, include):
+    def select(self, include: Union[Sequence[str], str]) -> None:
         """Select a subset of currently defined queues to consume from.
 
         Arguments:
@@ -178,7 +199,7 @@ class Queues(dict):
                 name: self[name] for name in maybe_list(include)
             }
 
-    def deselect(self, exclude):
+    def deselect(self, exclude: Union[Sequence[str], str]) -> None:
         """Deselect queues so that they won't be consumed from.
 
         Arguments:
@@ -189,19 +210,20 @@ class Queues(dict):
             exclude = maybe_list(exclude)
             if self._consume_from is None:
                 # using selection
-                return self.select(k for k in self if k not in exclude)
-            # using all queues
-            for queue in exclude:
-                self._consume_from.pop(queue, None)
+                self.select(k for k in self if k not in exclude)
+            else:
+                # using all queues
+                for queue in exclude:
+                    self._consume_from.pop(queue, None)
 
-    def new_missing(self, name):
+    def new_missing(self, name: str) -> Queue:
         return Queue(name, self.autoexchange(name), name)
 
     @property
-    def consume_from(self):
+    def consume_from(self) -> Mapping[str, Queue]:
         if self._consume_from is not None:
             return self._consume_from
-        return self
+        return cast(Mapping[str, Queue], self)
 
 
 class AMQP:
@@ -221,13 +243,13 @@ class AMQP:
 
     #: Underlying producer pool instance automatically
     #: set by the :attr:`producer_pool`.
-    _producer_pool = None
+    _producer_pool: ResourceT = None
 
     # Exchange class/function used when defining automatic queues.
     # For example, you can use ``autoexchange = lambda n: None`` to use the
     # AMQP default exchange: a shortcut to bypass routing
     # and instead send directly to the queue named in the routing key.
-    autoexchange = None
+    autoexchange: Callable[[str], Exchange] = None
 
     #: Max size of positional argument representation used for
     #: logging purposes.
@@ -236,7 +258,9 @@ class AMQP:
     #: Max size of keyword argument representation used for logging purposes.
     kwargsrepr_maxsize = 1024
 
-    def __init__(self, app):
+    task_protocols: Mapping[int, Callable] = None
+
+    def __init__(self, app: AppT) -> None:
         self.app = app
         self.task_protocols = {
             1: self.as_task_v1,
@@ -244,15 +268,18 @@ class AMQP:
         }
 
     @cached_property
-    def create_task_message(self):
+    def create_task_message(self) -> Callable:
         return self.task_protocols[self.app.conf.task_protocol]
 
     @cached_property
-    def send_task_message(self):
+    def send_task_message(self) -> Callable:
         return self._create_task_sender()
 
-    def Queues(self, queues, create_missing=None, ha_policy=None,
-               autoexchange=None, max_priority=None):
+    def Queues(self, queues: QueuesArgT,
+               create_missing: bool = None,
+               ha_policy: Union[Sequence, str] = None,
+               autoexchange: Callable[[str], Exchange] = None,
+               max_priority: int = None) -> Queues:
         # Create new :class:`Queues` instance, using queue defaults
         # from the current configuration.
         conf = self.app.conf
@@ -267,23 +294,27 @@ class AMQP:
             queues = (Queue(conf.task_default_queue,
                             exchange=self.default_exchange,
                             routing_key=default_routing_key),)
-        autoexchange = (self.autoexchange if autoexchange is None
-                        else autoexchange)
+        autoexchange = (self.autoexchange
+                        if autoexchange is None else autoexchange)
         return self.queues_cls(
             queues, self.default_exchange, create_missing,
             ha_policy, autoexchange, max_priority, default_routing_key,
         )
 
-    def Router(self, queues=None, create_missing=None):
+    def Router(self, queues: Mapping[str, Queue] = None,
+               create_missing: bool = None) -> RouterT:
         """Return the current task router."""
         return _routes.Router(self.routes, queues or self.queues,
                               self.app.either('task_create_missing_queues',
                                               create_missing), app=self.app)
 
-    def flush_routes(self):
+    def flush_routes(self) -> None:
         self._rtable = _routes.prepare(self.app.conf.task_routes)
 
-    def TaskConsumer(self, channel, queues=None, accept=None, **kw):
+    def TaskConsumer(self, channel: ChannelT,
+                     queues: Mapping[str, Queue] = None,
+                     accept: Set[str] = None,
+                     **kw) -> ConsumerT:
         if accept is None:
             accept = self.app.conf.accept_content
         return self.Consumer(
@@ -292,14 +323,30 @@ class AMQP:
             **kw
         )
 
-    def as_task_v2(self, task_id, name, args=None, kwargs=None,
-                   countdown=None, eta=None, group_id=None,
-                   expires=None, retries=0, chord=None,
-                   callbacks=None, errbacks=None, reply_to=None,
-                   time_limit=None, soft_time_limit=None,
-                   create_sent_event=False, root_id=None, parent_id=None,
-                   shadow=None, chain=None, now=None, timezone=None,
-                   origin=None, argsrepr=None, kwargsrepr=None):
+    def as_task_v2(self, task_id: str, name: str, *,
+                   args: Sequence = None,
+                   kwargs: Mapping = None,
+                   countdown: float = None,
+                   eta: datetime = None,
+                   group_id: str = None,
+                   expires: Union[float, datetime] = None,
+                   retries: int = 0,
+                   chord: SignatureT = None,
+                   callbacks: Sequence[SignatureT] = None,
+                   errbacks: Sequence[SignatureT] = None,
+                   reply_to: str = None,
+                   time_limit: float = None,
+                   soft_time_limit: float = None,
+                   create_sent_event: bool = False,
+                   root_id: str = None,
+                   parent_id: str = None,
+                   shadow: str = None,
+                   chain: Sequence[SignatureT] = None,
+                   now: datetime = None,
+                   timezone: tzinfo = None,
+                   origin: str = None,
+                   argsrepr: str = None,
+                   kwargsrepr: str = None) -> task_message:
         args = args or ()
         kwargs = kwargs or {}
         if not isinstance(args, (list, tuple)):
@@ -372,13 +419,26 @@ class AMQP:
             } if create_sent_event else None,
         )
 
-    def as_task_v1(self, task_id, name, args=None, kwargs=None,
-                   countdown=None, eta=None, group_id=None,
-                   expires=None, retries=0,
-                   chord=None, callbacks=None, errbacks=None, reply_to=None,
-                   time_limit=None, soft_time_limit=None,
-                   create_sent_event=False, root_id=None, parent_id=None,
-                   shadow=None, now=None, timezone=None):
+    def as_task_v1(self, task_id, name,
+                   args: Sequence = None,
+                   kwargs: Mapping = None,
+                   countdown: float = None,
+                   eta: datetime = None,
+                   group_id: str = None,
+                   expires: Union[float, datetime] = None,
+                   retries: int = 0,
+                   chord: SignatureT = None,
+                   callbacks: Sequence[SignatureT] = None,
+                   errbacks: Sequence[SignatureT] = None,
+                   reply_to: str = None,
+                   time_limit: float = None,
+                   soft_time_limit: float = None,
+                   create_sent_event: bool = False,
+                   root_id: str = None,
+                   parent_id: str = None,
+                   shadow: str = None,
+                   now: datetime = None,
+                   timezone: tzinfo = None) -> task_message:
         args = args or ()
         kwargs = kwargs or {}
         utc = self.utc
@@ -436,12 +496,12 @@ class AMQP:
             } if create_sent_event else None,
         )
 
-    def _verify_seconds(self, s, what):
+    def _verify_seconds(self, s: float, what: str) -> float:
         if s < INT_MIN:
             raise ValueError('%s is out of range: %r' % (what, s))
         return s
 
-    def _create_task_sender(self):
+    def _create_task_sender(self) -> Callable:
         default_retry = self.app.conf.task_publish_retry
         default_policy = self.app.conf.task_publish_retry_policy
         default_delivery_mode = self.app.conf.task_default_delivery_mode
@@ -459,13 +519,22 @@ class AMQP:
         default_serializer = self.app.conf.task_serializer
         default_compressor = self.app.conf.result_compression
 
-        def send_task_message(producer, name, message,
-                              exchange=None, routing_key=None, queue=None,
-                              event_dispatcher=None,
-                              retry=None, retry_policy=None,
-                              serializer=None, delivery_mode=None,
-                              compression=None, declare=None,
-                              headers=None, exchange_type=None, **kwargs):
+        def send_task_message(
+                producer: ProducerT, name: str, message: task_message,
+                *,
+                exchange: Union[Exchange, str] = None,
+                routing_key: str = None,
+                queue: str = None,
+                event_dispatcher: EventDispatcher = None,
+                retry: bool = None,
+                retry_policy: Mapping[str, Any] = None,
+                serializer: str = None,
+                delivery_mode: Union[str, int] = None,
+                compression: str = None,
+                declare: Sequence[EntityT] = None,
+                headers: Mapping = None,
+                exchange_type: str = None,
+                **kwargs) -> ResultT:
             retry = default_retry if retry is None else retry
             headers2, properties, body, sent_event = message
             if headers:
@@ -547,30 +616,30 @@ class AMQP:
         return send_task_message
 
     @cached_property
-    def default_queue(self):
+    def default_queue(self) -> Queue:
         return self.queues[self.app.conf.task_default_queue]
 
     @cached_property
-    def queues(self):
+    def queues(self) -> Queues:
         """Queue name⇒ declaration mapping."""
         return self.Queues(self.app.conf.task_queues)
 
     @queues.setter  # noqa
-    def queues(self, queues):
+    def queues(self, queues: QueuesArgT) -> Queues:
         return self.Queues(queues)
 
     @property
-    def routes(self):
+    def routes(self) -> Sequence[RouterT]:
         if self._rtable is None:
             self.flush_routes()
         return self._rtable
 
     @cached_property
-    def router(self):
+    def router(self) -> RouterT:
         return self.Router()
 
     @property
-    def producer_pool(self):
+    def producer_pool(self) -> ResourceT:
         if self._producer_pool is None:
             self._producer_pool = pools.producers[
                 self.app.connection_for_write()]
@@ -578,16 +647,16 @@ class AMQP:
         return self._producer_pool
 
     @cached_property
-    def default_exchange(self):
+    def default_exchange(self) -> Exchange:
         return Exchange(self.app.conf.task_default_exchange,
                         self.app.conf.task_default_exchange_type)
 
     @cached_property
-    def utc(self):
+    def utc(self) -> bool:
         return self.app.conf.enable_utc
 
     @cached_property
-    def _event_dispatcher(self):
+    def _event_dispatcher(self) -> EventDispatcher:
         # We call Dispatcher.publish with a custom producer
         # so don't need the diuspatcher to be enabled.
         return self.app.events.Dispatcher(enabled=False)

+ 8 - 4
celery/app/backends.py

@@ -2,8 +2,10 @@
 """Backend selection."""
 import sys
 import types
+from typing import Mapping, Tuple, Union
 from celery.exceptions import ImproperlyConfigured
 from celery._state import current_app
+from celery.types import LoaderT
 from celery.utils.imports import load_extension_class_names, symbol_by_name
 
 __all__ = ['by_name', 'by_url']
@@ -12,7 +14,7 @@ UNKNOWN_BACKEND = """
 Unknown result backend: {0!r}.  Did you spell that correctly? ({1!r})
 """
 
-BACKEND_ALIASES = {
+BACKEND_ALIASES: Mapping[str, str] = {
     'amqp': 'celery.backends.amqp:AMQPBackend',
     'rpc': 'celery.backends.rpc.RPCBackend',
     'cache': 'celery.backends.cache:CacheBackend',
@@ -31,8 +33,9 @@ BACKEND_ALIASES = {
 }
 
 
-def by_name(backend=None, loader=None,
-            extension_namespace='celery.result_backends'):
+def by_name(backend: Union[str, type] = None,
+            loader: LoaderT = None,
+            extension_namespace: str = 'celery.result_backends') -> type:
     """Get backend class by name/alias."""
     backend = backend or 'disabled'
     loader = loader or current_app.loader
@@ -50,7 +53,8 @@ def by_name(backend=None, loader=None,
     return cls
 
 
-def by_url(backend=None, loader=None):
+def by_url(backend: Union[str, type] = None,
+           loader: LoaderT = None) -> Tuple[type, str]:
     """Get backend class by URL."""
     url = None
     if backend and '://' in backend:

+ 271 - 163
celery/app/base.py

@@ -5,14 +5,22 @@ import threading
 import warnings
 
 from collections import UserDict, defaultdict, deque
+from datetime import datetime
 from operator import attrgetter
+from types import ModuleType
+from typing import (
+    Any, Callable, ContextManager, Dict, List,
+    Mapping, MutableMapping, Optional, Set, Sequence, Tuple, Union,
+)
 
+from amqp.types import SSLArg
 from kombu import pools
-from kombu.clocks import LamportClock
+from kombu.clocks import Clock, LamportClock
 from kombu.common import oid_from
 from kombu.utils.compat import register_after_fork
 from kombu.utils.objects import cached_property
 from kombu.utils.uuid import uuid
+from kombu.types import ConnectionT, ProducerT, ResourceT
 from vine import starpromise
 from vine.utils import wraps
 
@@ -27,6 +35,11 @@ from celery._state import (
 from celery.exceptions import AlwaysEagerIgnored, ImproperlyConfigured
 from celery.loaders import get_loader_cls
 from celery.local import PromiseProxy, maybe_evaluate
+from celery.types import (
+    AppT, AppAMQPT, AppControlT, AppEventsT, AppLogT,
+    BackendT, BeatT, LoaderT, ResultT, RouterT, ScheduleT,
+    SignalT, SignatureT, TaskT, TaskRegistryT, WorkerT,
+)
 from celery.utils import abstract
 from celery.utils.collections import AttributeDictMixin
 from celery.utils.dispatch import Signal
@@ -71,7 +84,7 @@ Example:
 """
 
 
-def app_has_custom(app, attr):
+def app_has_custom(app: AppT, attr: str) -> bool:
     """Return true if app has customized method `attr`.
 
     Note:
@@ -83,14 +96,14 @@ def app_has_custom(app, attr):
                       monkey_patched=[__name__])
 
 
-def _unpickle_appattr(reverse_name, args):
+def _unpickle_appattr(reverse_name: str, args: Tuple) -> Any:
     """Unpickle app."""
     # Given an attribute name and a list of args, gets
     # the attribute from the current app and calls it.
     return get_current_app()._rgetattr(reverse_name)(*args)
 
 
-def _after_fork_cleanup_app(app):
+def _after_fork_cleanup_app(app: AppT) -> None:
     # This is used with multiprocessing.register_after_fork,
     # so need to be at module level.
     try:
@@ -107,39 +120,41 @@ class PendingConfiguration(UserDict, AttributeDictMixin):
     # accessing any key will finalize the configuration,
     # replacing `app.conf` with a concrete settings object.
 
-    callback = None
-    _data = None
+    callback: Callable[[], MutableMapping] = None
+    _data: MutableMapping = None
 
-    def __init__(self, conf, callback):
+    def __init__(self,
+                 conf: MutableMapping,
+                 callback: Callable[[], MutableMapping]) -> None:
         object.__setattr__(self, '_data', conf)
         object.__setattr__(self, 'callback', callback)
 
-    def __setitem__(self, key, value):
+    def __setitem__(self, key: str, value: Any) -> None:
         self._data[key] = value
 
-    def clear(self):
+    def clear(self) -> None:
         self._data.clear()
 
-    def update(self, *args, **kwargs):
+    def update(self, *args, **kwargs) -> None:
         self._data.update(*args, **kwargs)
 
-    def setdefault(self, *args, **kwargs):
-        return self._data.setdefault(*args, **kwargs)
+    def setdefault(self, key: str, value: Any) -> Any:
+        return self._data.setdefault(key, value)
 
-    def __contains__(self, key):
+    def __contains__(self, key: str) -> bool:
         # XXX will not show finalized configuration
         # setdefault will cause `key in d` to happen,
         # so for setdefault to be lazy, so does contains.
         return key in self._data
 
-    def __len__(self):
+    def __len__(self) -> int:
         return len(self.data)
 
-    def __repr__(self):
+    def __repr__(self) -> str:
         return repr(self.data)
 
     @cached_property
-    def data(self):
+    def data(self) -> MutableMapping:
         return self.callback()
 
 
@@ -182,54 +197,71 @@ class Celery:
     SYSTEM = platforms.SYSTEM
     IS_macOS, IS_WINDOWS = platforms.IS_macOS, platforms.IS_WINDOWS
 
+    clock: Clock = None
+
     #: Name of the `__main__` module.  Required for standalone scripts.
     #:
     #: If set this will be used instead of `__main__` when automatically
     #: generating task names.
-    main = None
+    main: str = None
 
     #: Custom options for command-line programs.
     #: See :ref:`extending-commandoptions`
-    user_options = None
+    user_options: MutableMapping[str, Set] = None
 
     #: Custom bootsteps to extend and modify the worker.
     #: See :ref:`extending-bootsteps`.
-    steps = None
+    steps: MutableMapping[str, Set] = None
 
     builtin_fixups = BUILTIN_FIXUPS
 
-    amqp_cls = 'celery.app.amqp:AMQP'
-    backend_cls = None
-    events_cls = 'celery.app.events:Events'
-    loader_cls = None
-    log_cls = 'celery.app.log:Logging'
-    control_cls = 'celery.app.control:Control'
-    task_cls = 'celery.app.task:Task'
-    registry_cls = TaskRegistry
+    amqp_cls: Union[str, type] = 'celery.app.amqp:AMQP'
+    backend_cls: Union[str, type] = None
+    events_cls: Union[str, type] = 'celery.app.events:Events'
+    loader_cls: Union[str, type] = None
+    log_cls: Union[str, type] = 'celery.app.log:Logging'
+    control_cls: Union[str, type] = 'celery.app.control:Control'
+    task_cls: Union[str, type] = 'celery.app.task:Task'
+    registry_cls: Union[str, type] = TaskRegistry
 
-    _fixups = None
-    _pool = None
-    _conf = None
-    _after_fork_registered = False
+    _fixups: List = None
+    _pool: ResourceT = None
+    _conf: MutableMapping = None
+    _after_fork_registered: bool = False
 
     #: Signal sent when app is loading configuration.
-    on_configure = None
+    on_configure: SignatureT = None
 
     #: Signal sent after app has prepared the configuration.
-    on_after_configure = None
+    on_after_configure: SignalT = None
 
     #: Signal sent after app has been finalized.
-    on_after_finalize = None
+    on_after_finalize: SignalT = None
 
     #: Signal sent by every new process after fork.
-    on_after_fork = None
-
-    def __init__(self, main=None, loader=None, backend=None,
-                 amqp=None, events=None, log=None, control=None,
-                 set_as_current=True, tasks=None, broker=None, include=None,
-                 changes=None, config_source=None, fixups=None, task_cls=None,
-                 autofinalize=True, namespace=None, strict_typing=True,
-                 **kwargs):
+    on_after_fork: SignalT = None
+
+    def __init__(self,
+                 main: str = None,
+                 *,
+                 loader: Union[str, type] = None,
+                 backend: Union[str, type] = None,
+                 amqp: Union[str, type] = None,
+                 events: Union[str, type] = None,
+                 log: Union[str, type] = None,
+                 control: Union[str, type] = None,
+                 set_as_current: bool = True,
+                 tasks: Union[str, type] = None,
+                 broker: str = None,
+                 include: Sequence[str] = None,
+                 changes: MutableMapping = None,
+                 config_source: str = None,
+                 fixups: Sequence[Union[str, type]] = None,
+                 task_cls: Union[str, type] = None,
+                 autofinalize: bool = True,
+                 namespace: str = None,
+                 strict_typing: bool = True,
+                 **kwargs) -> None:
         self.clock = LamportClock()
         self.main = main
         self.amqp_cls = amqp or self.amqp_cls
@@ -299,7 +331,7 @@ class Celery:
         self.on_init()
         _register_app(self)
 
-    def _get_default_loader(self):
+    def _get_default_loader(self) -> Union[str, type]:
         # the --loader command-line argument sets the environment variable.
         return (
             os.environ.get('CELERY_LOADER') or
@@ -307,30 +339,30 @@ class Celery:
             'celery.loaders.app:AppLoader'
         )
 
-    def on_init(self):
+    def on_init(self) -> None:
         """Optional callback called at init."""
-        pass
+        ...
 
-    def __autoset(self, key, value):
+    def __autoset(self, key: str, value: Any) -> None:
         if value:
             self._preconf[key] = value
             self._preconf_set_by_auto.add(key)
 
-    def set_current(self):
+    def set_current(self) -> None:
         """Make this the current app for this thread."""
         _set_current_app(self)
 
-    def set_default(self):
+    def set_default(self) -> None:
         """Make this the default app for all threads."""
         set_default_app(self)
 
-    def _ensure_after_fork(self):
+    def _ensure_after_fork(self) -> None:
         if not self._after_fork_registered:
             self._after_fork_registered = True
             if register_after_fork is not None:
                 register_after_fork(self, _after_fork_cleanup_app)
 
-    def close(self):
+    def close(self) -> None:
         """Clean up after the application.
 
         Only necessary for dynamically created apps, and you should
@@ -344,25 +376,25 @@ class Celery:
         self._pool = None
         _deregister_app(self)
 
-    def start(self, argv=None):
+    def start(self, argv: List[str] = None) -> None:
         """Run :program:`celery` using `argv`.
 
         Uses :data:`sys.argv` if `argv` is not specified.
         """
-        return instantiate(
+        instantiate(
             'celery.bin.celery:CeleryCommand', app=self
         ).execute_from_commandline(argv)
 
-    def worker_main(self, argv=None):
+    def worker_main(self, argv: List[str] = None) -> None:
         """Run :program:`celery worker` using `argv`.
 
         Uses :data:`sys.argv` if `argv` is not specified.
         """
-        return instantiate(
+        instantiate(
             'celery.bin.worker:worker', app=self
         ).execute_from_commandline(argv)
 
-    def task(self, *args, **opts):
+    def task(self, *args, **opts) -> Union[Callable, TaskT]:
         """Decorator to create a task class out of any callable.
 
         Examples:
@@ -398,10 +430,15 @@ class Celery:
             from . import shared_task
             return shared_task(*args, lazy=False, **opts)
 
-        def inner_create_task_cls(shared=True, filter=None, lazy=True, **opts):
+        def inner_create_task_cls(
+                *,
+                shared: bool = True,
+                filter: Callable = None,
+                lazy: bool = True,
+                **opts) -> Callable:
             _filt = filter  # stupid 2to3
 
-            def _create_task_cls(fun):
+            def _create_task_cls(fun: Callable) -> TaskT:
                 if shared:
                     def cons(app):
                         return app._task_from_fun(fun, **opts)
@@ -430,7 +467,12 @@ class Celery:
                     sum([len(args), len(opts)])))
         return inner_create_task_cls(**opts)
 
-    def _task_from_fun(self, fun, name=None, base=None, bind=False, **options):
+    def _task_from_fun(self, fun: Callable,
+                       *,
+                       name: str = None,
+                       base: type = None,
+                       bind: bool = False,
+                       **options) -> TaskT:
         if not self.finalized and not self.autofinalize:
             raise RuntimeError('Contract breach: app not finalized')
         name = name or self.gen_task_name(fun.__name__, fun.__module__)
@@ -462,7 +504,7 @@ class Celery:
             if autoretry_for and not hasattr(task, '_orig_run'):
 
                 @wraps(task.run)
-                def run(*args, **kwargs):
+                def run(*args, **kwargs) -> Any:
                     try:
                         return task._orig_run(*args, **kwargs)
                     except autoretry_for as exc:
@@ -473,7 +515,7 @@ class Celery:
             task = self._tasks[name]
         return task
 
-    def register_task(self, task):
+    def register_task(self, task: TaskT) -> TaskT:
         """Utility for registering a task-based class.
 
         Note:
@@ -490,10 +532,10 @@ class Celery:
         task.bind(self)
         return task
 
-    def gen_task_name(self, name, module):
+    def gen_task_name(self, name: str, module: str) -> str:
         return gen_task_name(self, name, module)
 
-    def finalize(self, auto=False):
+    def finalize(self, auto: bool = False) -> None:
         """Finalize the app.
 
         This loads built-in tasks, evaluates pending task decorators,
@@ -515,7 +557,7 @@ class Celery:
 
                 self.on_after_finalize.send(sender=self)
 
-    def add_defaults(self, fun):
+    def add_defaults(self, fun: Union[Callable, Mapping]) -> None:
         """Add default configuration from dict ``d``.
 
         If the argument is a callable function then it will be regarded
@@ -536,11 +578,15 @@ class Celery:
         if not callable(fun):
             d, fun = fun, lambda: d
         if self.configured:
-            return self._conf.add_defaults(fun())
-        self._pending_defaults.append(fun)
+            self._conf.add_defaults(fun())
+        else:
+            self._pending_defaults.append(fun)
 
-    def config_from_object(self, obj,
-                           silent=False, force=False, namespace=None):
+    def config_from_object(self, obj: Any,
+                           *,
+                           silent: bool = False,
+                           force: bool = False,
+                           namespace: str = None) -> None:
         """Read configuration from object.
 
         Object is either an actual object or the name of a module to import.
@@ -561,9 +607,12 @@ class Celery:
         if force or self.configured:
             self._conf = None
             if self.loader.config_from_object(obj, silent=silent):
-                return self.conf
+                conf = self.conf  # noqa
 
-    def config_from_envvar(self, variable_name, silent=False, force=False):
+    def config_from_envvar(self, variable_name: str,
+                           *,
+                           silent: bool = False,
+                           force: bool = False) -> None:
         """Read configuration from environment variable.
 
         The value of the environment variable must be the name
@@ -574,20 +623,26 @@ class Celery:
             >>> celery.config_from_envvar('CELERY_CONFIG_MODULE')
         """
         module_name = os.environ.get(variable_name)
-        if not module_name:
-            if silent:
-                return False
+        if module_name:
+            self.config_from_object(module_name, silent=silent, force=force)
+        elif not silent:
             raise ImproperlyConfigured(
                 ERR_ENVVAR_NOT_SET.strip().format(variable_name))
-        return self.config_from_object(module_name, silent=silent, force=force)
 
-    def config_from_cmdline(self, argv, namespace='celery'):
+    def config_from_cmdline(self, argv: List[str],
+                            namespace: str = 'celery') -> None:
         self._conf.update(
             self.loader.cmdline_config_parser(argv, namespace)
         )
 
-    def setup_security(self, allowed_serializers=None, key=None, cert=None,
-                       store=None, digest='sha1', serializer='json'):
+    def setup_security(self,
+                       allowed_serializers: Set[str] = None,
+                       *,
+                       key: str = None,
+                       cert: str = None,
+                       store: str = None,
+                       digest: str = 'sha1',
+                       serializer: str = 'json') -> None:
         """Setup the message-signing serializer.
 
         This will affect all application instances (a global operation).
@@ -612,11 +667,14 @@ class Celery:
                 the serializers supported.  Default is ``json``.
         """
         from celery.security import setup_security
-        return setup_security(allowed_serializers, key, cert,
-                              store, digest, serializer, app=self)
-
-    def autodiscover_tasks(self, packages=None,
-                           related_name='tasks', force=False):
+        setup_security(allowed_serializers, key, cert,
+                       store, digest, serializer, app=self)
+
+    def autodiscover_tasks(self,
+                           packages: Sequence[str] = None,
+                           *,
+                           related_name: str = 'tasks',
+                           force: bool = False) -> None:
         """Auto-discover task modules.
 
         Searches a list of packages for a "tasks.py" module (or use
@@ -655,37 +713,61 @@ class Celery:
                 to happen immediately.
         """
         if force:
-            return self._autodiscover_tasks(packages, related_name)
-        signals.import_modules.connect(starpromise(
-            self._autodiscover_tasks, packages, related_name,
-        ), weak=False, sender=self)
+            self._autodiscover_tasks(packages, related_name)
+        else:
+            signals.import_modules.connect(starpromise(
+                self._autodiscover_tasks, packages, related_name,
+            ), weak=False, sender=self)
 
-    def _autodiscover_tasks(self, packages, related_name, **kwargs):
+    def _autodiscover_tasks(
+            self, packages: Sequence[str], related_name: str,
+            **kwargs) -> None:
         if packages:
-            return self._autodiscover_tasks_from_names(packages, related_name)
-        return self._autodiscover_tasks_from_fixups(related_name)
+            self._autodiscover_tasks_from_names(packages, related_name)
+        else:
+            self._autodiscover_tasks_from_fixups(related_name)
 
-    def _autodiscover_tasks_from_names(self, packages, related_name):
+    def _autodiscover_tasks_from_names(
+            self, packages: Sequence[str], related_name: str) -> None:
         # packages argument can be lazy
-        return self.loader.autodiscover_tasks(
+        self.loader.autodiscover_tasks(
             packages() if callable(packages) else packages, related_name,
         )
 
-    def _autodiscover_tasks_from_fixups(self, related_name):
-        return self._autodiscover_tasks_from_names([
+    def _autodiscover_tasks_from_fixups(self, related_name: str) -> None:
+        self._autodiscover_tasks_from_names([
             pkg for fixup in self._fixups
             for pkg in fixup.autodiscover_tasks()
             if hasattr(fixup, 'autodiscover_tasks')
         ], related_name=related_name)
 
-    def send_task(self, name, args=None, kwargs=None, countdown=None,
-                  eta=None, task_id=None, producer=None, connection=None,
-                  router=None, result_cls=None, expires=None,
-                  link=None, link_error=None,
-                  add_to_parent=True, group_id=None, retries=0, chord=None,
-                  reply_to=None, time_limit=None, soft_time_limit=None,
-                  root_id=None, parent_id=None, route_name=None,
-                  shadow=None, chain=None, task_type=None, **options):
+    def send_task(self, name: str,
+                  args: Sequence = None,
+                  kwargs: Mapping = None,
+                  countdown: float = None,
+                  eta: datetime = None,
+                  task_id: str = None,
+                  producer: ProducerT = None,
+                  connection: ConnectionT = None,
+                  router: RouterT = None,
+                  result_cls: type = None,
+                  expires: Union[float, datetime] = None,
+                  link: Union[SignatureT, Sequence[SignatureT]]=None,
+                  link_error: Union[SignatureT, Sequence[SignatureT]] = None,
+                  add_to_parent: bool = True,
+                  group_id: str = None,
+                  retries: int = 0,
+                  chord: SignatureT = None,
+                  reply_to: str = None,
+                  time_limit: float = None,
+                  soft_time_limit: float = None,
+                  root_id: str = None,
+                  parent_id: str = None,
+                  route_name: str = None,
+                  shadow: str = None,
+                  chain: List[SignatureT] = None,
+                  task_type: TaskT = None,
+                  **options) -> ResultT:
         """Send task by name.
 
         Supports the same arguments as :meth:`@-Task.apply_async`.
@@ -737,7 +819,7 @@ class Celery:
                 parent.add_trail(result)
         return result
 
-    def connection_for_read(self, url=None, **kwargs):
+    def connection_for_read(self, url: str = None, **kwargs) -> ConnectionT:
         """Establish connection used for consuming.
 
         See Also:
@@ -745,7 +827,7 @@ class Celery:
         """
         return self._connection(url or self.conf.broker_read_url, **kwargs)
 
-    def connection_for_write(self, url=None, **kwargs):
+    def connection_for_write(self, url: str = None, **kwargs) -> ConnectionT:
         """Establish connection used for producing.
 
         See Also:
@@ -753,11 +835,20 @@ class Celery:
         """
         return self._connection(url or self.conf.broker_write_url, **kwargs)
 
-    def connection(self, hostname=None, userid=None, password=None,
-                   virtual_host=None, port=None, ssl=None,
-                   connect_timeout=None, transport=None,
-                   transport_options=None, heartbeat=None,
-                   login_method=None, failover_strategy=None, **kwargs):
+    def connection(self,
+                   hostname: str = None,
+                   userid: str = None,
+                   password: str = None,
+                   virtual_host: str = None,
+                   port: int = None,
+                   ssl: SSLArg = None,
+                   connect_timeout: float = None,
+                   transport: Any = None,
+                   transport_options: Mapping = None,
+                   heartbeat: float = None,
+                   login_method: str = None,
+                   failover_strategy: str = None,
+                   **kwargs) -> ConnectionT:
         """Establish a connection to the message broker.
 
         Please use :meth:`connection_for_read` and
@@ -796,11 +887,19 @@ class Celery:
             **kwargs
         )
 
-    def _connection(self, url, userid=None, password=None,
-                    virtual_host=None, port=None, ssl=None,
-                    connect_timeout=None, transport=None,
-                    transport_options=None, heartbeat=None,
-                    login_method=None, failover_strategy=None, **kwargs):
+    def _connection(self, url: str,
+                    userid: str = None,
+                    password: str = None,
+                    virtual_host: str = None,
+                    port: int = None,
+                    ssl: SSLArg = None,
+                    connect_timeout: float = None,
+                    transport: Any = None,
+                    transport_options: Mapping = None,
+                    heartbeat: float = None,
+                    login_method: str = None,
+                    failover_strategy: str = None,
+                    **kwargs) -> ConnectionT:
         conf = self.conf
         return self.amqp.Connection(
             url,
@@ -824,13 +923,14 @@ class Celery:
         )
     broker_connection = connection
 
-    def _acquire_connection(self, pool=True):
+    def _acquire_connection(self, pool: bool = True) -> ConnectionT:
         """Helper for :meth:`connection_or_acquire`."""
         if pool:
             return self.pool.acquire(block=True)
         return self.connection_for_write()
 
-    def connection_or_acquire(self, connection=None, pool=True, *_, **__):
+    def connection_or_acquire(self, connection: ConnectionT = None,
+                              pool: bool = True, *_, **__) -> ContextManager:
         """Context used to acquire a connection from the pool.
 
         For use within a :keyword:`with` statement to get a connection
@@ -842,7 +942,8 @@ class Celery:
         """
         return FallbackContext(connection, self._acquire_connection, pool=pool)
 
-    def producer_or_acquire(self, producer=None):
+    def producer_or_acquire(self,
+                            producer: ProducerT = None) -> ContextManager:
         """Context used to acquire a producer from the pool.
 
         For use within a :keyword:`with` statement to get a producer
@@ -856,23 +957,23 @@ class Celery:
             producer, self.producer_pool.acquire, block=True,
         )
 
-    def prepare_config(self, c):
+    def prepare_config(self, c: Mapping) -> Mapping:
         """Prepare configuration before it is merged with the defaults."""
         return find_deprecated_settings(c)
 
-    def now(self):
+    def now(self) -> datetime:
         """Return the current time and date as a datetime."""
         return self.loader.now(utc=self.conf.enable_utc)
 
-    def select_queues(self, queues=None):
+    def select_queues(self, queues: Sequence[str] = None) -> None:
         """Select subset of queues.
 
         Arguments:
             queues (Sequence[str]): a list of queue names to keep.
         """
-        return self.amqp.queues.select(queues)
+        self.amqp.queues.select(queues)
 
-    def either(self, default_key, *defaults):
+    def either(self, default_key: str, *defaults) -> Any:
         """Get key from configuration or use default values.
 
         Fallback to the value of a configuration key if none of the
@@ -882,17 +983,17 @@ class Celery:
             first(None, defaults), starpromise(self.conf.get, default_key),
         ])
 
-    def bugreport(self):
+    def bugreport(self) -> str:
         """Return information useful in bug reports."""
         return bugreport(self)
 
-    def _get_backend(self):
+    def _get_backend(self) -> BackendT:
         backend, url = backends.by_url(
             self.backend_cls or self.conf.result_backend,
             self.loader)
         return backend(app=self, url=url)
 
-    def _finalize_pending_conf(self):
+    def _finalize_pending_conf(self) -> MutableMapping:
         """Get config value by key and finalize loading the configuration.
 
         Note:
@@ -902,7 +1003,7 @@ class Celery:
         conf = self._conf = self._load_config()
         return conf
 
-    def _load_config(self):
+    def _load_config(self) -> MutableMapping:
         if isinstance(self.on_configure, Signal):
             self.on_configure.send(sender=self)
         else:
@@ -936,7 +1037,7 @@ class Celery:
         self.on_after_configure.send(sender=self, source=self._conf)
         return self._conf
 
-    def _after_fork(self):
+    def _after_fork(self) -> None:
         self._pool = None
         try:
             self.__dict__['amqp']._producer_pool = None
@@ -944,13 +1045,14 @@ class Celery:
             pass
         self.on_after_fork.send(sender=self)
 
-    def signature(self, *args, **kwargs):
+    def signature(self, *args, **kwargs) -> SignatureT:
         """Return a new :class:`~celery.Signature` bound to this app."""
         kwargs['app'] = self
         return self._canvas.signature(*args, **kwargs)
 
-    def add_periodic_task(self, schedule, sig,
-                          args=(), kwargs=(), name=None, **opts):
+    def add_periodic_task(self, schedule: ScheduleT, sig: SignatureT,
+                          args: Tuple = (), kwargs: Dict = (),
+                          name: str = None, **opts) -> str:
         key, entry = self._sig_to_periodic_task_entry(
             schedule, sig, args, kwargs, name, **opts)
         if self.configured:
@@ -959,8 +1061,10 @@ class Celery:
             self._pending_periodic_tasks.append((key, entry))
         return key
 
-    def _sig_to_periodic_task_entry(self, schedule, sig,
-                                    args=(), kwargs={}, name=None, **opts):
+    def _sig_to_periodic_task_entry(
+            self, schedule: ScheduleT, sig: SignatureT,
+            args: Tuple = (), kwargs: Dict = {},
+            name: str = None, **opts) -> Tuple[str, Mapping]:
         sig = (sig.clone(args, kwargs)
                if isinstance(sig, abstract.CallableSignature)
                else self.signature(sig.name, args, kwargs))
@@ -972,18 +1076,22 @@ class Celery:
             'options': dict(sig.options, **opts),
         }
 
-    def _add_periodic_task(self, key, entry):
+    def _add_periodic_task(self, key: str, entry: Mapping) -> None:
         self._conf.beat_schedule[key] = entry
 
-    def create_task_cls(self):
+    def create_task_cls(self) -> type:
         """Create a base task class bound to this app."""
         return self.subclass_with_self(
             self.task_cls, name='Task', attribute='_app',
             keep_reduce=True, abstract=True,
         )
 
-    def subclass_with_self(self, Class, name=None, attribute='app',
-                           reverse=None, keep_reduce=False, **kw):
+    def subclass_with_self(self, Class: type,
+                           name: str = None,
+                           attribute: str = 'app',
+                           reverse: str = None,
+                           keep_reduce: bool = False,
+                           **kw) -> type:
         """Subclass an app-compatible class.
 
         App-compatible means that the class has a class attribute that
@@ -1017,24 +1125,24 @@ class Celery:
 
         return type(name or Class.__name__, (Class,), attrs)
 
-    def _rgetattr(self, path):
+    def _rgetattr(self, path: str) -> Any:
         return attrgetter(path)(self)
 
-    def __enter__(self):
+    def __enter__(self) -> AppT:
         return self
 
-    def __exit__(self, *exc_info):
+    def __exit__(self, *exc_info) -> None:
         self.close()
 
-    def __repr__(self):
+    def __repr__(self) -> str:
         return '<{0} {1}>'.format(type(self).__name__, appstr(self))
 
-    def __reduce__(self):
+    def __reduce__(self) -> Tuple:
         if self._using_v1_reduce:
             return self.__reduce_v1__()
         return (_unpickle_app_v2, (self.__class__, self.__reduce_keys__()))
 
-    def __reduce_v1__(self):
+    def __reduce_v1__(self) -> Tuple:
         # Reduce only pickles the configuration changes,
         # so the default configuration doesn't have to be passed
         # between processes.
@@ -1043,7 +1151,7 @@ class Celery:
             (self.__class__, self.Pickler) + self.__reduce_args__(),
         )
 
-    def __reduce_keys__(self):
+    def __reduce_keys__(self) -> Mapping[str, Any]:
         """Keyword arguments used to reconstruct the object when unpickling."""
         return {
             'main': self.main,
@@ -1061,7 +1169,7 @@ class Celery:
             'namespace': self.namespace,
         }
 
-    def __reduce_args__(self):
+    def __reduce_args__(self) -> Tuple:
         """Deprecated method, please use :meth:`__reduce_keys__` instead."""
         return (self.main, self._conf.changes if self.configured else {},
                 self.loader_cls, self.backend_cls, self.amqp_cls,
@@ -1069,7 +1177,7 @@ class Celery:
                 False, self._config_source)
 
     @cached_property
-    def Worker(self):
+    def Worker(self) -> WorkerT:
         """Worker application.
 
         See Also:
@@ -1078,7 +1186,7 @@ class Celery:
         return self.subclass_with_self('celery.apps.worker:Worker')
 
     @cached_property
-    def WorkController(self, **kwargs):
+    def WorkController(self, **kwargs) -> WorkerT:
         """Embeddable worker.
 
         See Also:
@@ -1087,7 +1195,7 @@ class Celery:
         return self.subclass_with_self('celery.worker:WorkController')
 
     @cached_property
-    def Beat(self, **kwargs):
+    def Beat(self, **kwargs) -> BeatT:
         """:program:`celery beat` scheduler application.
 
         See Also:
@@ -1096,7 +1204,7 @@ class Celery:
         return self.subclass_with_self('celery.apps.beat:Beat')
 
     @cached_property
-    def Task(self):
+    def Task(self) -> type:
         """Base task class for this app."""
         return self.create_task_cls()
 
@@ -1105,7 +1213,7 @@ class Celery:
         return prepare_annotations(self.conf.task_annotations)
 
     @cached_property
-    def AsyncResult(self):
+    def AsyncResult(self) -> type:
         """Create new result instance.
 
         See Also:
@@ -1114,11 +1222,11 @@ class Celery:
         return self.subclass_with_self('celery.result:AsyncResult')
 
     @cached_property
-    def ResultSet(self):
+    def ResultSet(self) -> type:
         return self.subclass_with_self('celery.result:ResultSet')
 
     @cached_property
-    def GroupResult(self):
+    def GroupResult(self) -> type:
         """Create new group result instance.
 
         See Also:
@@ -1127,7 +1235,7 @@ class Celery:
         return self.subclass_with_self('celery.result:GroupResult')
 
     @property
-    def pool(self):
+    def pool(self) -> ResourceT:
         """Broker connection pool: :class:`~@pool`.
 
         Note:
@@ -1141,12 +1249,12 @@ class Celery:
         return self._pool
 
     @property
-    def current_task(self):
+    def current_task(self) -> Optional[TaskT]:
         """Instance of task being executed, or :const:`None`."""
         return _task_stack.top
 
     @property
-    def current_worker_task(self):
+    def current_worker_task(self) -> Optional[TaskT]:
         """The task currently being executed by a worker or :const:`None`.
 
         Differs from :data:`current_task` in that it's not affected
@@ -1155,7 +1263,7 @@ class Celery:
         return get_current_worker_task()
 
     @cached_property
-    def oid(self):
+    def oid(self) -> str:
         """Universally unique identifier for this app."""
         # since 4.0: thread.get_ident() is not included when
         # generating the process id.  This is due to how the RPC
@@ -1164,53 +1272,53 @@ class Celery:
         return oid_from(self, threads=False)
 
     @cached_property
-    def amqp(self):
+    def amqp(self) -> AppAMQPT:
         """AMQP related functionality: :class:`~@amqp`."""
         return instantiate(self.amqp_cls, app=self)
 
     @cached_property
-    def backend(self):
+    def backend(self) -> BackendT:
         """Current backend instance."""
         return self._get_backend()
 
     @property
-    def conf(self):
+    def conf(self) -> MutableMapping:
         """Current configuration."""
         if self._conf is None:
             self._conf = self._load_config()
         return self._conf
 
     @conf.setter
-    def conf(self, d):  # noqa
+    def conf(self, d: MutableMapping) -> None:  # noqa
         self._conf = d
 
     @cached_property
-    def control(self):
+    def control(self) -> AppControlT:
         """Remote control: :class:`~@control`."""
         return instantiate(self.control_cls, app=self)
 
     @cached_property
-    def events(self):
+    def events(self) -> AppEventsT:
         """Consuming and sending events: :class:`~@events`."""
         return instantiate(self.events_cls, app=self)
 
     @cached_property
-    def loader(self):
+    def loader(self) -> LoaderT:
         """Current loader instance."""
         return get_loader_cls(self.loader_cls)(app=self)
 
     @cached_property
-    def log(self):
+    def log(self) -> AppLogT:
         """Logging: :class:`~@log`."""
         return instantiate(self.log_cls, app=self)
 
     @cached_property
-    def _canvas(self):
+    def _canvas(self) -> ModuleType:
         from celery import canvas
         return canvas
 
     @cached_property
-    def tasks(self):
+    def tasks(self) -> TaskRegistryT:
         """Task registry.
 
         Warning:

+ 13 - 6
celery/app/defaults.py

@@ -1,8 +1,9 @@
 # -*- coding: utf-8 -*-
 """Configuration introspection and defaults."""
 import sys
-from collections import deque, namedtuple
+from collections import deque
 from datetime import timedelta
+from typing import NamedTuple
 from celery.utils.functional import memoize
 from celery.utils.serialization import strtobool
 
@@ -31,7 +32,13 @@ OLD_NS = {'celery_{0}'}
 OLD_NS_BEAT = {'celerybeat_{0}'}
 OLD_NS_WORKER = {'celeryd_{0}'}
 
-searchresult = namedtuple('searchresult', ('namespace', 'key', 'type'))
+
+class find_result_t(NamedTuple):
+    """Return value of :func:`find`."""
+
+    namespace: str
+    key: str
+    type: 'Option'
 
 
 def Namespace(__old__=None, **options):
@@ -340,18 +347,18 @@ def find(name, namespace='celery'):
     # - Try specified name-space first.
     namespace = namespace.lower()
     try:
-        return searchresult(
+        return find_result_t(
             namespace, name.lower(), NAMESPACES[namespace][name.lower()],
         )
     except KeyError:
         # - Try all the other namespaces.
         for ns, opts in NAMESPACES.items():
             if ns.lower() == name.lower():
-                return searchresult(None, ns, opts)
+                return find_result_t(None, ns, opts)
             elif isinstance(opts, dict):
                 try:
-                    return searchresult(ns, name.lower(), opts[name.lower()])
+                    return find_result_t(ns, name.lower(), opts[name.lower()])
                 except KeyError:
                     pass
     # - See if name is a qualname last.
-    return searchresult(None, name.lower(), DEFAULTS[name.lower()])
+    return find_result_t(None, name.lower(), DEFAULTS[name.lower()])

+ 12 - 6
celery/app/events.py

@@ -1,5 +1,8 @@
 """Implementation for the app.events shortcuts."""
 from contextlib import contextmanager
+from celery.events import EventDispatcher, EventReceiver
+from celery.events.state import State
+from celery.types import AppT
 from kombu.utils.objects import cached_property
 
 
@@ -10,27 +13,30 @@ class Events:
     dispatcher_cls = 'celery.events.dispatcher:EventDispatcher'
     state_cls = 'celery.events.state:State'
 
-    def __init__(self, app=None):
+    def __init__(self, app: AppT = None):
         self.app = app
 
     @cached_property
-    def Receiver(self):
+    def Receiver(self) -> EventReceiver:
         return self.app.subclass_with_self(
             self.receiver_cls, reverse='events.Receiver')
 
     @cached_property
-    def Dispatcher(self):
+    def Dispatcher(self) -> EventDispatcher:
         return self.app.subclass_with_self(
             self.dispatcher_cls, reverse='events.Dispatcher')
 
     @cached_property
-    def State(self):
+    def State(self) -> State:
         return self.app.subclass_with_self(
             self.state_cls, reverse='events.State')
 
     @contextmanager
-    def default_dispatcher(self, hostname=None, enabled=True,
-                           buffer_while_offline=False):
+    def default_dispatcher(
+            self,
+            hostname: str = None,
+            enabled: bool = True,
+            buffer_while_offline: bool = False) -> EventDispatcher:
         with self.app.amqp.producer_pool.acquire(block=True) as prod:
             # pylint: disable=too-many-function-args
             # This is a property pylint...

+ 56 - 26
celery/app/log.py

@@ -12,6 +12,7 @@ import os
 import sys
 
 from logging.handlers import WatchedFileHandler
+from typing import Union
 
 from kombu.utils.encoding import set_default_encoding_file
 
@@ -19,6 +20,7 @@ from celery import signals
 from celery._state import get_current_task
 from celery.local import class_property
 from celery.platforms import isatty
+from celery.types import AppT
 from celery.utils.log import (
     get_logger, mlevel,
     ColorFormatter, LoggingProxy, get_multiprocessing_logger,
@@ -35,7 +37,7 @@ MP_LOG = os.environ.get('MP_LOG', False)
 class TaskFormatter(ColorFormatter):
     """Formatter for tasks, adding the task name and id."""
 
-    def format(self, record):
+    def format(self, record: logging.LogRecord) -> str:
         task = get_current_task()
         if task and task.request:
             record.__dict__.update(task_id=task.request.id,
@@ -54,15 +56,20 @@ class Logging:
     #: will do nothing.
     _setup = False
 
-    def __init__(self, app):
+    def __init__(self, app: AppT) -> None:
         self.app = app
         self.loglevel = mlevel(logging.WARN)
         self.format = self.app.conf.worker_log_format
         self.task_format = self.app.conf.worker_task_log_format
         self.colorize = self.app.conf.worker_log_color
 
-    def setup(self, loglevel=None, logfile=None, redirect_stdouts=False,
-              redirect_level='WARNING', colorize=None, hostname=None):
+    def setup(self,
+              loglevel: Union[str, int] = None,
+              logfile: str = None,
+              redirect_stdouts: bool = False,
+              redirect_level: str = 'WARNING',
+              colorize: bool = None,
+              hostname: str = None) -> bool:
         loglevel = mlevel(loglevel)
         handled = self.setup_logging_subsystem(
             loglevel, logfile, colorize=colorize, hostname=hostname,
@@ -76,7 +83,9 @@ class Logging:
         )
         return handled
 
-    def redirect_stdouts(self, loglevel=None, name='celery.redirected'):
+    def redirect_stdouts(self,
+                         loglevel: int = None,
+                         name: str = 'celery.redirected') -> None:
         self.redirect_stdouts_to_logger(
             get_logger(name), loglevel=loglevel
         )
@@ -85,10 +94,15 @@ class Logging:
             CELERY_LOG_REDIRECT_LEVEL=str(loglevel or ''),
         )
 
-    def setup_logging_subsystem(self, loglevel=None, logfile=None, format=None,
-                                colorize=None, hostname=None, **kwargs):
+    def setup_logging_subsystem(self,
+                                loglevel: Union[int, str] = None,
+                                logfile: str = None,
+                                format: str = None,
+                                colorize: bool = None,
+                                hostname: str = None,
+                                **kwargs) -> bool:
         if self.already_setup:
-            return
+            return False
         if logfile and hostname:
             logfile = node_format(logfile, hostname)
         Logging._setup = True
@@ -144,18 +158,28 @@ class Logging:
         os.environ.update(_MP_FORK_LOGLEVEL_=str(loglevel),
                           _MP_FORK_LOGFILE_=logfile_name,
                           _MP_FORK_LOGFORMAT_=format)
-        return receivers
-
-    def _configure_logger(self, logger, logfile, loglevel,
-                          format, colorize, **kwargs):
+        return bool(receivers)
+
+    def _configure_logger(self,
+                          logger: logging.Logger,
+                          logfile: str,
+                          loglevel: int,
+                          format: str,
+                          colorize: bool,
+                          **kwargs) -> None:
         if logger is not None:
             self.setup_handlers(logger, logfile, format,
                                 colorize, **kwargs)
             if loglevel:
                 logger.setLevel(loglevel)
 
-    def setup_task_loggers(self, loglevel=None, logfile=None, format=None,
-                           colorize=None, propagate=False, **kwargs):
+    def setup_task_loggers(self,
+                           loglevel: Union[str, int] = None,
+                           logfile: str = None,
+                           format: str = None,
+                           colorize: bool = None,
+                           propagate: bool = False,
+                           **kwargs) -> logging.Logger:
         """Setup the task logger.
 
         If `logfile` is not specified, then `sys.stderr` is used.
@@ -181,8 +205,10 @@ class Logging:
         )
         return logger
 
-    def redirect_stdouts_to_logger(self, logger, loglevel=None,
-                                   stdout=True, stderr=True):
+    def redirect_stdouts_to_logger(self, logger: logging.Logger,
+                                   loglevel: int = None,
+                                   stdout: bool = True,
+                                   stderr: bool = True) -> LoggingProxy:
         """Redirect :class:`sys.stdout` and :class:`sys.stderr` to logger.
 
         Arguments:
@@ -197,7 +223,9 @@ class Logging:
             sys.stderr = proxy
         return proxy
 
-    def supports_color(self, colorize=None, logfile=None):
+    def supports_color(self,
+                       colorize: bool = None,
+                       logfile: str = None) -> bool:
         colorize = self.colorize if colorize is None else colorize
         if self.app.IS_WINDOWS:
             # Windows does not support ANSI color codes.
@@ -208,11 +236,12 @@ class Logging:
             return logfile is None and isatty(sys.stderr)
         return colorize
 
-    def colored(self, logfile=None, enabled=None):
+    def colored(self, logfile: str = None, enabled: bool = None) -> colored:
         return colored(enabled=self.supports_color(enabled, logfile))
 
-    def setup_handlers(self, logger, logfile, format, colorize,
-                       formatter=ColorFormatter, **kwargs):
+    def setup_handlers(self, logger: logging.Logger,
+                       logfile: str, format: str, colorize: bool,
+                       formatter=ColorFormatter, **kwargs) -> logging.Logger:
         if self._is_configured(logger):
             return logger
         handler = self._detect_handler(logfile)
@@ -220,30 +249,31 @@ class Logging:
         logger.addHandler(handler)
         return logger
 
-    def _detect_handler(self, logfile=None):
+    def _detect_handler(self, logfile: str = None) -> logging.Handler:
         """Create handler from filename, an open stream or `None` (stderr)."""
         logfile = sys.__stderr__ if logfile is None else logfile
         if hasattr(logfile, 'write'):
             return logging.StreamHandler(logfile)
         return WatchedFileHandler(logfile)
 
-    def _has_handler(self, logger):
+    def _has_handler(self, logger: logging.Logger) -> bool:
         return any(
             not isinstance(h, logging.NullHandler)
             for h in logger.handlers or []
         )
 
-    def _is_configured(self, logger):
+    def _is_configured(self, logger: logging.Logger) -> bool:
         return self._has_handler(logger) and not getattr(
             logger, '_rudimentary_setup', False)
 
-    def get_default_logger(self, name='celery', **kwargs):
+    def get_default_logger(self, name: str = 'celery',
+                           **kwargs) -> logging.Logger:
         return get_logger(name)
 
     @class_property
-    def already_setup(self):
+    def already_setup(self) -> bool:
         return self._setup
 
     @already_setup.setter  # noqa
-    def already_setup(self, was_setup):
+    def already_setup(self, was_setup: bool) -> None:
         self._setup = was_setup

+ 7 - 16
celery/app/registry.py

@@ -2,8 +2,10 @@
 """Registry of available tasks."""
 import inspect
 from importlib import import_module
+from typing import Any
 from celery._state import get_current_app
 from celery.exceptions import NotRegistered, InvalidTaskError
+from celery.types import TaskT
 
 __all__ = ['TaskRegistry']
 
@@ -13,10 +15,10 @@ class TaskRegistry(dict):
 
     NotRegistered = NotRegistered
 
-    def __missing__(self, key):
+    def __missing__(self, key: str) -> Any:
         raise self.NotRegistered(key)
 
-    def register(self, task):
+    def register(self, task: TaskT) -> None:
         """Register a task in the task registry.
 
         The task will be automatically instantiated if not already an
@@ -28,7 +30,7 @@ class TaskRegistry(dict):
                     type(task).__name__))
         self[task.name] = inspect.isclass(task) and task() or task
 
-    def unregister(self, name):
+    def unregister(self, name: str) -> None:
         """Unregister task by name.
 
         Arguments:
@@ -43,23 +45,12 @@ class TaskRegistry(dict):
         except KeyError:
             raise self.NotRegistered(name)
 
-    # -- these methods are irrelevant now and will be removed in 4.0
-    def regular(self):
-        return self.filter_types('regular')
 
-    def periodic(self):
-        return self.filter_types('periodic')
-
-    def filter_types(self, type):
-        return {name: task for name, task in self.items()
-                if getattr(task, 'type', 'regular') == type}
-
-
-def _unpickle_task(name):
+def _unpickle_task(name: str) -> TaskT:
     return get_current_app().tasks[name]
 
 
-def _unpickle_task_v2(name, module=None):
+def _unpickle_task_v2(name: str, module: str = None) -> TaskT:
     if module:
         import_module(module)
     return get_current_app().tasks[name]

+ 32 - 12
celery/app/routes.py

@@ -6,8 +6,10 @@ Contains utilities for working with task routers, (:setting:`task_routes`).
 import re
 import string
 from collections import Mapping, OrderedDict
+from typing import Any, Callable, Sequence, Union, Tuple
 from kombu import Queue
 from celery.exceptions import QueueNotFound
+from celery.types import AppT, RouterT, TaskT
 from celery.utils.collections import lpmerge
 from celery.utils.functional import maybe_evaluate, mlazy
 from celery.utils.imports import symbol_by_name
@@ -15,7 +17,8 @@ from celery.utils.imports import symbol_by_name
 __all__ = ['MapRoute', 'Router', 'prepare']
 
 
-def glob_to_re(glob, quote=string.punctuation.replace('*', '')):
+def glob_to_re(glob: str, *,
+               quote: str = string.punctuation.replace('*', '')) -> str:
     glob = ''.join('\\' + c if c in quote else c for c in glob)
     return glob.replace('*', '.+?')
 
@@ -23,7 +26,10 @@ def glob_to_re(glob, quote=string.punctuation.replace('*', '')):
 class MapRoute:
     """Creates a router out of a :class:`dict`."""
 
-    def __init__(self, map):
+    map: Sequence[Tuple[str, Any]] = None
+    patterns: Mapping = None
+
+    def __init__(self, map: Union[Mapping, Sequence[Tuple[str, Any]]]) -> None:
         map = map.items() if isinstance(map, Mapping) else map
         self.map = {}
         self.patterns = OrderedDict()
@@ -35,7 +41,7 @@ class MapRoute:
             else:
                 self.map[k] = v
 
-    def __call__(self, name, *args, **kwargs):
+    def __call__(self, name: str, *args, **kwargs) -> Mapping:
         try:
             return dict(self.map[name])
         except KeyError:
@@ -53,14 +59,19 @@ class MapRoute:
 class Router:
     """Route tasks based on the :setting:`task_routes` setting."""
 
-    def __init__(self, routes=None, queues=None,
-                 create_missing=False, app=None):
+    def __init__(self,
+                 routes: Sequence = None,
+                 queues: Mapping = None,
+                 create_missing: bool = False,
+                 app: AppT = None) -> None:
         self.app = app
         self.queues = {} if queues is None else queues
         self.routes = [] if routes is None else routes
         self.create_missing = create_missing
 
-    def route(self, options, name, args=(), kwargs={}, task_type=None):
+    def route(self, options: Mapping, name: str,
+              args: Sequence = (), kwargs: Mapping = {},
+              task_type: TaskT = None) -> Mapping:
         options = self.expand_destination(options)  # expands 'queue'
         if self.routes:
             route = self.lookup_route(name, args, kwargs, options, task_type)
@@ -71,7 +82,7 @@ class Router:
                               self.app.conf.task_default_queue), options)
         return options
 
-    def expand_destination(self, route):
+    def expand_destination(self, route: Union[str, Mapping]) -> Mapping:
         # Route can be a queue name: convenient for direct exchanges.
         if isinstance(route, str):
             queue, route = route, {}
@@ -91,15 +102,24 @@ class Router:
                         'Queue {0!r} missing from task_queues'.format(queue))
         return route
 
-    def lookup_route(self, name,
-                     args=None, kwargs=None, options=None, task_type=None):
+    def lookup_route(self, name: str,
+                     args: Sequence = None,
+                     kwargs: Mapping = None,
+                     options: Mapping = None,
+                     task_type: TaskT = None) -> Mapping:
         query = self.query_router
         for router in self.routes:
             route = query(router, name, args, kwargs, options, task_type)
             if route is not None:
                 return route
 
-    def query_router(self, router, task, args, kwargs, options, task_type):
+    def query_router(self,
+                     router: Union[RouterT, Callable],
+                     task: str,
+                     args: Sequence,
+                     kwargs: Mapping,
+                     options: Mapping,
+                     task_type: TaskT) -> None:
         router = maybe_evaluate(router)
         if hasattr(router, 'route_for_task'):
             # pre 4.0 router class
@@ -107,7 +127,7 @@ class Router:
         return router(task, args, kwargs, options, task=task_type)
 
 
-def expand_router_string(router):
+def expand_router_string(router: Any):
     router = symbol_by_name(router)
     if hasattr(router, 'route_for_task'):
         # need to instantiate pre 4.0 router classes
@@ -115,7 +135,7 @@ def expand_router_string(router):
     return router
 
 
-def prepare(routes):
+def prepare(routes: Any) -> Sequence[RouterT]:
     """Expand the :setting:`task_routes` setting."""
     def expand_route(route):
         if isinstance(route, (Mapping, list, tuple)):

+ 190 - 122
celery/app/task.py

@@ -2,8 +2,14 @@
 """Task implementation: request context and the task base class."""
 import sys
 
+from datetime import datetime
+from typing import (
+    Any, Awaitable, Callable, Iterable, Mapping, Sequence, Tuple, Union,
+)
+
 from billiard.einfo import ExceptionInfo
 from kombu.exceptions import OperationalError
+from kombu.types import ProducerT
 from kombu.utils.uuid import uuid
 
 from celery import current_app, group
@@ -13,9 +19,13 @@ from celery.canvas import signature
 from celery.exceptions import Ignore, MaxRetriesExceededError, Reject, Retry
 from celery.local import class_property
 from celery.result import EagerResult
+from celery.types import (
+    AppT, BackendT, ResultT, SignatureT, TaskT, TracerT, WorkerConsumerT,
+)
 from celery.utils import abstract
 from celery.utils.functional import mattrgetter, maybe_list
 from celery.utils.imports import instantiate
+from celery.utils.threads import LocalStack
 
 from .annotations import resolve_all as resolve_all_annotations
 from .registry import _unpickle_task_v2
@@ -40,13 +50,13 @@ R_INSTANCE = '<@task: {0.name} of {app}{flags}>'
 TaskType = type
 
 
-def _strflags(flags, default=''):
+def _strflags(flags: Sequence, default: str = '') -> str:
     if flags:
         return ' ({0})'.format(', '.join(flags))
     return default
 
 
-def _reprtask(task, fmt=None, flags=None):
+def _reprtask(task: TaskT, fmt: str = None, flags: Sequence = None) -> str:
     flags = list(flags) if flags is not None else []
     if not fmt:
         fmt = R_BOUND_TASK if task._app else R_UNBOUND_TASK
@@ -59,51 +69,54 @@ def _reprtask(task, fmt=None, flags=None):
 class Context:
     """Task request variables (Task.request)."""
 
-    logfile = None
-    loglevel = None
-    hostname = None
-    id = None
-    args = None
-    kwargs = None
-    retries = 0
-    eta = None
-    expires = None
-    is_eager = False
-    headers = None
-    delivery_info = None
-    reply_to = None
-    root_id = None
-    parent_id = None
-    correlation_id = None
-    taskset = None   # compat alias to group
-    group = None
-    chord = None
-    chain = None
-    utc = None
-    called_directly = True
-    callbacks = None
-    errbacks = None
-    timelimit = None
-    origin = None
-    _children = None   # see property
-    _protected = 0
-
-    def __init__(self, *args, **kwargs):
+    logfile: str = None
+    loglevel: int = None
+    hostname: str = None
+    id: str = None
+    args: Sequence = None
+    kwargs: Mapping = None
+    retries: int = 0
+    eta: datetime = None
+    expires: Union[float, datetime] = None
+    is_eager: bool = False
+    headers: Mapping = None
+    delivery_info: Mapping = None
+    reply_to: str = None
+    root_id: str = None
+    parent_id: str = None
+    correlation_id: str = None
+    # compat alias to group
+    taskset: str = None
+    group: str = None
+    chord: SignatureT = None
+    chain: Sequence[SignatureT] = None
+    utc: bool = None
+    called_directly: bool = True
+    callbacks: Sequence[SignatureT] = None
+    errbacks: Sequence[SignatureT] = None
+    timelimit: Tuple[float, float] = None
+    origin: str = None
+
+    # see property
+    _children: Sequence[ResultT] = None
+    _protected: int = 0
+
+    def __init__(self, *args, **kwargs) -> None:
         self.update(*args, **kwargs)
 
-    def update(self, *args, **kwargs):
-        return self.__dict__.update(*args, **kwargs)
+    def update(self, *args, **kwargs) -> None:
+        self.__dict__.update(*args, **kwargs)
 
-    def clear(self):
-        return self.__dict__.clear()
+    def clear(self) -> None:
+        self.__dict__.clear()
 
-    def get(self, key, default=None):
+    def get(self, key: str, default: Any = None) -> Any:
         return getattr(self, key, default)
 
-    def __repr__(self):
+    def __repr__(self) -> str:
         return '<Context: {0!r}>'.format(vars(self))
 
-    def as_execution_options(self):
+    def as_execution_options(self) -> Mapping:
         limit_hard, limit_soft = self.timelimit or (None, None)
         return {
             'task_id': self.id,
@@ -124,7 +137,7 @@ class Context:
         }
 
     @property
-    def children(self):
+    def children(self) -> Sequence[ResultT]:
         # children must be an empy list for every thread
         if self._children is None:
             self._children = []
@@ -141,7 +154,7 @@ class Task:
         is overridden).
     """
 
-    __trace__ = None
+    __trace__: TracerT = None
     __v2_compat__ = False  # set by old base in celery.task.base
 
     MaxRetriesExceededError = MaxRetriesExceededError
@@ -151,42 +164,42 @@ class Task:
     Strategy = 'celery.worker.strategy:default'
 
     #: This is the instance bound to if the task is a method of a class.
-    __self__ = None
+    __self__: Any = None
 
     #: The application instance associated with this task class.
-    _app = None
+    _app: AppT = None
 
     #: Name of the task.
-    name = None
+    name: str = None
 
     #: Enable argument checking.
     #: You can set this to false if you don't want the signature to be
     #: checked when calling the task.
     #: Defaults to :attr:`app.strict_typing <@Celery.strict_typing>`.
-    typing = None
+    typing: bool = None
 
     #: Maximum number of retries before giving up.  If set to :const:`None`,
     #: it will **never** stop retrying.
-    max_retries = 3
+    max_retries: int = 3
 
     #: Default time in seconds before a retry of the task should be
     #: executed.  3 minutes by default.
-    default_retry_delay = 3 * 60
+    default_retry_delay = 180.0
 
     #: Rate limit for this task type.  Examples: :const:`None` (no rate
     #: limit), `'100/s'` (hundred tasks a second), `'100/m'` (hundred tasks
     #: a minute),`'100/h'` (hundred tasks an hour)
-    rate_limit = None
+    rate_limit: Union[int, str] = None
 
     #: If enabled the worker won't store task state and return values
     #: for this task.  Defaults to the :setting:`task_ignore_result`
     #: setting.
-    ignore_result = None
+    ignore_result: bool = None
 
     #: If enabled the request will keep track of subtasks started by
     #: this task, and this information will be sent with the result
     #: (``result.children``).
-    trail = True
+    trail: bool = True
 
     #: If enabled the worker will send monitoring events related to
     #: this task (but only if the worker is configured to send
@@ -194,29 +207,29 @@ class Task:
     #: Note that this has no effect on the task-failure event case
     #: where a task is not registered (as it will have no task class
     #: to check this flag).
-    send_events = True
+    send_events: bool = True
 
     #: When enabled errors will be stored even if the task is otherwise
     #: configured to ignore results.
-    store_errors_even_if_ignored = None
+    store_errors_even_if_ignored: bool = None
 
     #: The name of a serializer that are registered with
     #: :mod:`kombu.serialization.registry`.  Default is `'pickle'`.
-    serializer = None
+    serializer: str = None
 
     #: Hard time limit.
     #: Defaults to the :setting:`task_time_limit` setting.
-    time_limit = None
+    time_limit: float = None
 
     #: Soft time limit.
     #: Defaults to the :setting:`task_soft_time_limit` setting.
-    soft_time_limit = None
+    soft_time_limit: float = None
 
     #: The result store backend used for this task.
-    backend = None
+    backend: BackendT = None
 
     #: If disabled this task won't be registered automatically.
-    autoregister = True
+    autoregister: bool = True
 
     #: If enabled the task will report its status as 'started' when the task
     #: is executed by a worker.  Disabled by default as the normal behavior
@@ -229,7 +242,7 @@ class Task:
     #:
     #: The application default can be overridden using the
     #: :setting:`task_track_started` setting.
-    track_started = None
+    track_started: bool = None
 
     #: When enabled messages for this task will be acknowledged **after**
     #: the task has been executed, and not *just before* (the
@@ -240,7 +253,7 @@ class Task:
     #:
     #: The application default can be overridden with the
     #: :setting:`task_acks_late` setting.
-    acks_late = None
+    acks_late: bool = None
 
     #: Even if :attr:`acks_late` is enabled, the worker will
     #: acknowledge tasks when the worker process executing them abruptly
@@ -252,7 +265,7 @@ class Task:
     #:
     #: Warning: Enabling this can cause message loops; make sure you know
     #: what you're doing.
-    reject_on_worker_lost = None
+    reject_on_worker_lost: bool = None
 
     #: Tuple of expected exceptions.
     #:
@@ -260,29 +273,29 @@ class Task:
     #: and that shouldn't be regarded as a real error by the worker.
     #: Currently this means that the state will be updated to an error
     #: state, but the worker won't log the event as an error.
-    throws = ()
+    throws: Tuple[type] = ()
 
     #: Default task expiry time.
-    expires = None
+    expires: float = None
 
     #: Max length of result representation used in logs and events.
-    resultrepr_maxsize = 1024
+    resultrepr_maxsize: int = 1024
 
     #: Task request stack, the current request will be the topmost.
-    request_stack = None
+    request_stack: LocalStack = None
 
     #: Some may expect a request to exist even if the task hasn't been
     #: called.  This should probably be deprecated.
-    _default_request = None
+    _default_request: Context = None
 
     #: Deprecated attribute ``abstract`` here for compatibility.
-    abstract = True
+    abstract: bool = True
 
-    _exec_options = None
+    _exec_options: Mapping = None
 
-    __bound__ = False
+    __bound__: bool = False
 
-    from_config = (
+    from_config: Tuple[Tuple[str, str], ...] = (
         ('serializer', 'task_serializer'),
         ('rate_limit', 'task_default_rate_limit'),
         ('track_started', 'task_track_started'),
@@ -292,13 +305,14 @@ class Task:
         ('store_errors_even_if_ignored', 'task_store_errors_even_if_ignored'),
     )
 
-    _backend = None  # set by backend property.
+    # set by backend property.
+    _backend: BackendT = None
 
     # - Tasks are lazily bound, so that configuration is not set
     # - until the task is actually used
 
     @classmethod
-    def bind(cls, app):
+    def bind(cls, app: AppT) -> AppT:
         was_bound, cls.__bound__ = cls.__bound__, True
         cls._app = app
         conf = app.conf
@@ -324,17 +338,17 @@ class Task:
         return app
 
     @classmethod
-    def on_bound(cls, app):
+    def on_bound(cls, app: AppT) -> None:
         """Called when the task is bound to an app.
 
         Note:
             This class method can be defined to do additional actions when
             the task class is bound to an app.
         """
-        pass
+        ...
 
     @classmethod
-    def _get_app(cls):
+    def _get_app(cls) -> AppT:
         if cls._app is None:
             cls._app = current_app
         if not cls.__bound__:
@@ -345,7 +359,7 @@ class Task:
     app = class_property(_get_app, bind)
 
     @classmethod
-    def annotate(cls):
+    def annotate(cls) -> None:
         for d in resolve_all_annotations(cls.app.annotations, cls):
             for key, value in d.items():
                 if key.startswith('@'):
@@ -354,7 +368,7 @@ class Task:
                     setattr(cls, key, value)
 
     @classmethod
-    def add_around(cls, attr, around):
+    def add_around(cls, attr: str, around: Callable) -> None:
         orig = getattr(cls, attr)
         if getattr(orig, '__wrapped__', None):
             orig = orig.__wrapped__
@@ -362,7 +376,7 @@ class Task:
         meth.__wrapped__ = orig
         setattr(cls, attr, meth)
 
-    def __call__(self, *args, **kwargs):
+    def __call__(self, *args, **kwargs) -> Any:
         _task_stack.push(self)
         self.push_request(args=args, kwargs=kwargs)
         try:
@@ -374,7 +388,7 @@ class Task:
             self.pop_request()
             _task_stack.pop()
 
-    def __reduce__(self):
+    def __reduce__(self) -> Tuple:
         # - tasks are pickled into the name of the task only, and the reciever
         # - simply grabs it from the local registry.
         # - in later versions the module of the task is also included,
@@ -384,14 +398,15 @@ class Task:
         mod = mod if mod and mod in sys.modules else None
         return (_unpickle_task_v2, (self.name, mod), None)
 
-    def run(self, *args, **kwargs):
+    def run(self, *args, **kwargs) -> Any:
         """The body of the task executed by workers."""
         raise NotImplementedError('Tasks must define the run method.')
 
-    def start_strategy(self, app, consumer, **kwargs):
+    def start_strategy(self, app: AppT, consumer: WorkerConsumerT,
+                       **kwargs) -> Callable:
         return instantiate(self.Strategy, self, app, consumer, **kwargs)
 
-    def delay(self, *args, **kwargs):
+    def delay(self, *args, **kwargs) -> ResultT:
         """Star argument version of :meth:`apply_async`.
 
         Does not support the extra options enabled by :meth:`apply_async`.
@@ -404,8 +419,15 @@ class Task:
         """
         return self.apply_async(args, kwargs)
 
-    def apply_async(self, args=None, kwargs=None, task_id=None, producer=None,
-                    link=None, link_error=None, shadow=None, **options):
+    def apply_async(self,
+                    args: Sequence = None,
+                    kwargs: Mapping = None,
+                    task_id: str = None,
+                    producer: ProducerT = None,
+                    link: Sequence[SignatureT] = None,
+                    link_error: Sequence[SignatureT] = None,
+                    shadow: str = None,
+                    **options) -> ResultT:
         """Apply tasks asynchronously by sending a message.
 
         Arguments:
@@ -526,7 +548,10 @@ class Task:
             **options
         )
 
-    def shadow_name(self, args, kwargs, options):
+    def shadow_name(self,
+                    args: Sequence,
+                    kwargs: Mapping,
+                    options: Mapping) -> str:
         """Override for custom task name in worker logs/monitoring.
 
         Example:
@@ -548,8 +573,12 @@ class Task:
         """
         pass
 
-    def signature_from_request(self, request=None, args=None, kwargs=None,
-                               queue=None, **extra_options):
+    def signature_from_request(self,
+                               request: Context = None,
+                               args: Sequence = None,
+                               kwargs: Mapping = None,
+                               queue: str = None,
+                               **extra_options) -> SignatureT:
         request = self.request if request is None else request
         args = request.args if args is None else args
         kwargs = request.kwargs if kwargs is None else kwargs
@@ -570,8 +599,15 @@ class Task:
         )
     subtask_from_request = signature_from_request  # XXX compat
 
-    def retry(self, args=None, kwargs=None, exc=None, throw=True,
-              eta=None, countdown=None, max_retries=None, **options):
+    def retry(self,
+              args: Sequence = None,
+              kwargs: Mapping = None,
+              exc: Exception = None,
+              throw: bool = True,
+              eta: datetime = None,
+              countdown: float = None,
+              max_retries: int = None,
+              **options) -> None:
         """Retry the task.
 
         Example:
@@ -680,10 +716,18 @@ class Task:
             raise ret
         return ret
 
-    def apply(self, args=None, kwargs=None,
-              link=None, link_error=None,
-              task_id=None, retries=None, throw=None,
-              logfile=None, loglevel=None, headers=None, **options):
+    def apply(self,
+              args: Sequence = None,
+              kwargs: Mapping = None,
+              link: Sequence[SignatureT] = None,
+              link_error: Sequence[SignatureT] = None,
+              task_id: str = None,
+              retries: int = None,
+              throw: bool = None,
+              logfile: str = None,
+              loglevel: int = None,
+              headers: Mapping = None,
+              **options) -> ResultT:
         """Execute this task locally, by blocking until the task returns.
 
         Arguments:
@@ -735,7 +779,7 @@ class Task:
         state = states.SUCCESS if ret.info is None else ret.info.state
         return EagerResult(task_id, retval, state, traceback=tb)
 
-    def AsyncResult(self, task_id, **kwargs):
+    def AsyncResult(self, task_id: str, **kwargs) -> ResultT:
         """Get AsyncResult instance for this kind of task.
 
         Arguments:
@@ -744,7 +788,8 @@ class Task:
         return self._get_app().AsyncResult(
             task_id, backend=self.backend, **kwargs)
 
-    def signature(self, args=None, *starargs, **starkwargs):
+    def signature(self, args: Sequence = None,
+                  *starargs, **starkwargs) -> SignatureT:
         """Create signature.
 
         Returns:
@@ -756,36 +801,39 @@ class Task:
         return signature(self, args, *starargs, **starkwargs)
     subtask = signature
 
-    def s(self, *args, **kwargs):
+    def s(self, *args, **kwargs) -> SignatureT:
         """Create signature.
 
         Shortcut for ``.s(*a, **k) -> .signature(a, k)``.
         """
         return self.signature(args, kwargs)
 
-    def si(self, *args, **kwargs):
+    def si(self, *args, **kwargs) -> SignatureT:
         """Create immutable signature.
 
         Shortcut for ``.si(*a, **k) -> .signature(a, k, immutable=True)``.
         """
         return self.signature(args, kwargs, immutable=True)
 
-    def chunks(self, it, n):
+    def chunks(self, it: Iterable, n: int) -> SignatureT:
         """Create a :class:`~celery.canvas.chunks` task for this task."""
         from celery import chunks
         return chunks(self.s(), it, n, app=self.app)
 
-    def map(self, it):
+    def map(self, it: Iterable) -> SignatureT:
         """Create a :class:`~celery.canvas.xmap` task from ``it``."""
         from celery import xmap
         return xmap(self.s(), it, app=self.app)
 
-    def starmap(self, it):
+    def starmap(self, it: Iterable) -> SignatureT:
         """Create a :class:`~celery.canvas.xstarmap` task from ``it``."""
         from celery import xstarmap
         return xstarmap(self.s(), it, app=self.app)
 
-    def send_event(self, type_, retry=True, retry_policy=None, **fields):
+    def send_event(self, type_: str,
+                   retry: bool = True,
+                   retry_policy: Mapping = None,
+                   **fields) -> Awaitable:
         """Send monitoring event message.
 
         This can be used to add custom event types in :pypi:`Flower`
@@ -811,7 +859,7 @@ class Task:
                 type_,
                 uuid=req.id, retry=retry, retry_policy=retry_policy, **fields)
 
-    def replace(self, sig):
+    def replace(self, sig: SignatureT) -> None:
         """Replace this task, with a new task inheriting the task id.
 
         .. versionadded:: 4.0
@@ -851,7 +899,8 @@ class Task:
         sig.delay()
         raise Ignore('Replaced by new task')
 
-    def add_to_chord(self, sig, lazy=False):
+    def add_to_chord(self, sig: SignatureT,
+                     lazy: bool = False) -> Union[ResultT, SignatureT]:
         """Add signature to the chord the current task is a member of.
 
         .. versionadded:: 4.0
@@ -871,7 +920,10 @@ class Task:
         self.backend.add_to_chord(self.request.group, result)
         return sig.delay() if not lazy else sig
 
-    def update_state(self, task_id=None, state=None, meta=None):
+    def update_state(self,
+                     task_id: str = None,
+                     state: str = None,
+                     meta: Mapping = None) -> None:
         """Update task state.
 
         Arguments:
@@ -884,7 +936,11 @@ class Task:
             task_id = self.request.id
         self.backend.store_result(task_id, meta, state)
 
-    def on_success(self, retval, task_id, args, kwargs):
+    def on_success(self,
+                   retval: Any,
+                   task_id: str,
+                   args: Sequence,
+                   kwargs: Mapping) -> None:
         """Success handler.
 
         Run by the worker if the task executes successfully.
@@ -898,9 +954,14 @@ class Task:
         Returns:
             None: The return value of this handler is ignored.
         """
-        pass
-
-    def on_retry(self, exc, task_id, args, kwargs, einfo):
+        ...
+
+    def on_retry(self,
+                 exc: Exception,
+                 task_id: str,
+                 args: Sequence,
+                 kwargs: Mapping,
+                 einfo: ExceptionInfo) -> None:
         """Retry handler.
 
         This is run by the worker when the task is to be retried.
@@ -915,9 +976,14 @@ class Task:
         Returns:
             None: The return value of this handler is ignored.
         """
-        pass
-
-    def on_failure(self, exc, task_id, args, kwargs, einfo):
+        ...
+
+    def on_failure(self,
+                   exc: Exception,
+                   task_id: str,
+                   args: Sequence,
+                   kwargs: Mapping,
+                   einfo: ExceptionInfo) -> None:
         """Error handler.
 
         This is run by the worker when the task fails.
@@ -932,9 +998,11 @@ class Task:
         Returns:
             None: The return value of this handler is ignored.
         """
-        pass
+        ...
 
-    def after_return(self, status, retval, task_id, args, kwargs, einfo):
+    def after_return(self, status: str, retval: Any, task_id: str,
+                     args: Sequence, kwargs: Mapping,
+                     einfo: ExceptionInfo) -> None:
         """Handler called after the task returns.
 
         Arguments:
@@ -948,24 +1016,24 @@ class Task:
         Returns:
             None: The return value of this handler is ignored.
         """
-        pass
+        ...
 
-    def add_trail(self, result):
+    def add_trail(self, result: ResultT) -> ResultT:
         if self.trail:
             self.request.children.append(result)
         return result
 
-    def push_request(self, *args, **kwargs):
+    def push_request(self, *args, **kwargs) -> None:
         self.request_stack.push(Context(*args, **kwargs))
 
-    def pop_request(self):
+    def pop_request(self) -> None:
         self.request_stack.pop()
 
-    def __repr__(self):
+    def __repr__(self) -> str:
         """``repr(task)``."""
         return _reprtask(self, R_SELF_TASK if self.__self__ else R_INSTANCE)
 
-    def _get_request(self):
+    def _get_request(self) -> Context:
         """Get current request object."""
         req = self.request_stack.top
         if req is None:
@@ -977,23 +1045,23 @@ class Task:
         return req
     request = property(_get_request)
 
-    def _get_exec_options(self):
+    def _get_exec_options(self) -> Mapping:
         if self._exec_options is None:
             self._exec_options = extract_exec_options(self)
         return self._exec_options
 
     @property
-    def backend(self):
+    def backend(self) -> BackendT:
         backend = self._backend
         if backend is None:
             return self.app.backend
         return backend
 
     @backend.setter
-    def backend(self, value):  # noqa
+    def backend(self, value: BackendT) -> None:  # noqa
         self._backend = value
 
     @property
-    def __name__(self):
+    def __name__(self) -> str:
         return self.__class__.__name__
 BaseTask = Task  # noqa: E305 XXX compat alias

+ 17 - 5
celery/app/trace.py

@@ -19,8 +19,8 @@ import logging
 import os
 import sys
 
-from collections import namedtuple
 from time import monotonic
+from typing import Any, NamedTuple
 from warnings import warn
 
 from billiard.einfo import ExceptionInfo
@@ -79,9 +79,16 @@ LOG_RETRY = """\
 Task %(name)s[%(id)s] retry: %(exc)s\
 """
 
-log_policy_t = namedtuple(
-    'log_policy_t', ('format', 'description', 'severity', 'traceback', 'mail'),
-)
+
+class log_policy_t(NamedTuple):
+    """Describes the logging policy for a specific state."""
+
+    format: str
+    description: str
+    severity: int
+    traceback: int
+    mail: int
+
 
 log_policy_reject = log_policy_t(LOG_REJECTED, 'rejected', logging.WARN, 1, 1)
 log_policy_ignore = log_policy_t(LOG_IGNORED, 'ignored', logging.INFO, 0, 0)
@@ -111,7 +118,12 @@ IGNORE_STATES = frozenset({IGNORED, RETRY, REJECTED})
 _localized = []
 _patched = {}
 
-trace_ok_t = namedtuple('trace_ok_t', ('retval', 'info', 'runtime', 'retstr'))
+trace_ok_t = NamedTuple('trace_ok_t', [
+    ('retval', Any),
+    ('info', 'TraceInfo'),
+    ('runtime', float),
+    ('retstr', str),
+])
 
 
 def task_has_custom(task, attr):

+ 61 - 36
celery/app/utils.py

@@ -4,21 +4,23 @@ import os
 import platform as _platform
 import re
 
-from collections import Mapping, namedtuple
+from collections import Mapping
 from copy import deepcopy
 from types import ModuleType
+from typing import Any, Callable, MutableMapping, NamedTuple, Set, Union
 
 from kombu.utils.url import maybe_sanitize_url
 
 from celery.exceptions import ImproperlyConfigured
 from celery.platforms import pyimplementation
+from celery.types import AppT
 from celery.utils.collections import ConfigurationView
 from celery.utils.text import pretty
 from celery.utils.imports import import_from_cwd, symbol_by_name, qualname
 
 from .defaults import (
     _TO_NEW_KEY, _TO_OLD_KEY, _OLD_DEFAULTS, _OLD_SETTING_KEYS,
-    DEFAULTS, SETTING_KEYS, find,
+    DEFAULTS, SETTING_KEYS, find, find_result_t,
 )
 
 __all__ = [
@@ -65,7 +67,7 @@ Or change all of the settings to use the new format :)
 FMT_REPLACE_SETTING = '{replace:<36} -> {with_}'
 
 
-def appstr(app):
+def appstr(app: AppT) -> str:
     """String used in __repr__ etc, to id app instances."""
     return '{0}:{1:#x}'.format(app.main or '__main__', id(app))
 
@@ -80,7 +82,7 @@ class Settings(ConfigurationView):
     """
 
     @property
-    def broker_read_url(self):
+    def broker_read_url(self) -> str:
         return (
             os.environ.get('CELERY_BROKER_READ_URL') or
             self.get('broker_read_url') or
@@ -88,7 +90,7 @@ class Settings(ConfigurationView):
         )
 
     @property
-    def broker_write_url(self):
+    def broker_write_url(self) -> str:
         return (
             os.environ.get('CELERY_BROKER_WRITE_URL') or
             self.get('broker_write_url') or
@@ -96,40 +98,40 @@ class Settings(ConfigurationView):
         )
 
     @property
-    def broker_url(self):
+    def broker_url(self) -> str:
         return (
             os.environ.get('CELERY_BROKER_URL') or
             self.first('broker_url', 'broker_host')
         )
 
     @property
-    def task_default_exchange(self):
+    def task_default_exchange(self) -> str:
         return self.first(
             'task_default_exchange',
             'task_default_queue',
         )
 
     @property
-    def task_default_routing_key(self):
+    def task_default_routing_key(self) -> str:
         return self.first(
             'task_default_routing_key',
             'task_default_queue',
         )
 
     @property
-    def timezone(self):
+    def timezone(self) -> str:
         # this way we also support django's time zone.
         return self.first('timezone', 'time_zone')
 
-    def without_defaults(self):
+    def without_defaults(self) -> 'Settings':
         """Return the current configuration, but without defaults."""
         # the last stash is the default settings, so just skip that
         return Settings({}, self.maps[:-1])
 
-    def value_set_for(self, key):
+    def value_set_for(self, key: str) -> bool:
         return key in self.without_defaults()
 
-    def find_option(self, name, namespace=''):
+    def find_option(self, name: str, namespace: str = '') -> find_result_t:
         """Search for option by name.
 
         Example:
@@ -146,11 +148,11 @@ class Settings(ConfigurationView):
         """
         return find(name, namespace)
 
-    def find_value_for_key(self, name, namespace='celery'):
+    def find_value_for_key(self, name: str, namespace: str = 'celery') -> Any:
         """Shortcut to ``get_by_parts(*find_option(name)[:-1])``."""
         return self.get_by_parts(*self.find_option(name, namespace)[:-1])
 
-    def get_by_parts(self, *parts):
+    def get_by_parts(self, *parts) -> Any:
         """Return the current value for setting specified as a path.
 
         Example:
@@ -160,7 +162,7 @@ class Settings(ConfigurationView):
         """
         return self['_'.join(part for part in parts if part)]
 
-    def finalize(self):
+    def finalize(self) -> None:
         # See PendingConfiguration in celery/app/base.py
         # first access will read actual configuration.
         try:
@@ -169,7 +171,9 @@ class Settings(ConfigurationView):
             pass
         return self
 
-    def table(self, with_defaults=False, censored=True):
+    def table(self, *,
+              with_defaults: bool = False,
+              censored: bool = True) -> Mapping:
         filt = filter_hidden_settings if censored else lambda v: v
         dict_members = dir(dict)
         self.finalize()
@@ -179,24 +183,29 @@ class Settings(ConfigurationView):
             if not k.startswith('_') and k not in dict_members
         })
 
-    def humanize(self, with_defaults=False, censored=True):
+    def humanize(self, *,
+                 with_defaults: bool = False,
+                 censored: bool = True) -> str:
         """Return a human readable text showing configuration changes."""
         return '\n'.join(
             '{0}: {1}'.format(key, pretty(value, width=50))
             for key, value in self.table(with_defaults, censored).items())
 
 
-def _new_key_to_old(key, convert=_TO_OLD_KEY.get):
+def _new_key_to_old(key: str, *, convert: Callable = _TO_OLD_KEY.get) -> str:
     return convert(key, key)
 
 
-def _old_key_to_new(key, convert=_TO_NEW_KEY.get):
+def _old_key_to_new(key: str, *, convert: Callable = _TO_NEW_KEY.get) -> str:
     return convert(key, key)
 
 
-_settings_info_t = namedtuple('settings_info_t', (
-    'defaults', 'convert', 'key_t', 'mix_error',
-))
+class _settings_info_t(NamedTuple):
+    defaults: Mapping
+    convert: Mapping
+    key_t: Callable
+    mix_error: str
+
 
 _settings_info = _settings_info_t(
     DEFAULTS, _TO_NEW_KEY, _old_key_to_new, E_MIX_OLD_INTO_NEW,
@@ -206,8 +215,12 @@ _old_settings_info = _settings_info_t(
 )
 
 
-def detect_settings(conf, preconf={}, ignore_keys=set(), prefix=None,
-                    all_keys=SETTING_KEYS, old_keys=_OLD_SETTING_KEYS):
+def detect_settings(conf: MutableMapping,
+                    preconf: Mapping = {},
+                    ignore_keys: Set = set(),
+                    prefix: str = None,
+                    all_keys: Set = SETTING_KEYS,
+                    old_keys: Set = _OLD_SETTING_KEYS) -> Settings:
     source = conf
     if conf is None:
         source, conf = preconf, {}
@@ -260,42 +273,51 @@ def detect_settings(conf, preconf={}, ignore_keys=set(), prefix=None,
 class AppPickler:
     """Old application pickler/unpickler (< 3.1)."""
 
-    def __call__(self, cls, *args):
+    def __call__(self, cls: type, *args) -> AppT:
         kwargs = self.build_kwargs(*args)
         app = self.construct(cls, **kwargs)
         self.prepare(app, **kwargs)
         return app
 
-    def prepare(self, app, **kwargs):
+    def prepare(self, app: AppT, **kwargs) -> None:
         app.conf.update(kwargs['changes'])
 
-    def build_kwargs(self, *args):
+    def build_kwargs(self, *args) -> Mapping:
         return self.build_standard_kwargs(*args)
 
-    def build_standard_kwargs(self, main, changes, loader, backend, amqp,
-                              events, log, control, accept_magic_kwargs,
-                              config_source=None):
+    def build_standard_kwargs(
+            self,
+            main: str,
+            changes: Mapping,
+            loader: Union[str, type],
+            backend: Union[str, type],
+            amqp: Union[str, type],
+            events: Union[str, type],
+            log: Union[str, type],
+            control: Union[str, type],
+            accept_magic_kwargs: bool,
+            config_source: str = None) -> Mapping:
         return dict(main=main, loader=loader, backend=backend, amqp=amqp,
                     changes=changes, events=events, log=log, control=control,
                     set_as_current=False,
                     config_source=config_source)
 
-    def construct(self, cls, **kwargs):
+    def construct(self, cls: Callable, **kwargs) -> Any:
         return cls(**kwargs)
 
 
-def _unpickle_app(cls, pickler, *args):
+def _unpickle_app(cls, pickler: type, *args) -> AppT:
     """Rebuild app for versions 2.5+."""
     return pickler()(cls, *args)
 
 
-def _unpickle_app_v2(cls, kwargs):
+def _unpickle_app_v2(cls: type, kwargs: Mapping) -> AppT:
     """Rebuild app for versions 3.1+."""
     kwargs['set_as_current'] = False
     return cls(**kwargs)
 
 
-def filter_hidden_settings(conf):
+def filter_hidden_settings(conf: Mapping) -> Mapping:
     """Filter sensitive settings."""
     def maybe_censor(key, value, mask='*' * 8):
         if isinstance(value, Mapping):
@@ -314,7 +336,7 @@ def filter_hidden_settings(conf):
     return {k: maybe_censor(k, v) for k, v in conf.items()}
 
 
-def bugreport(app):
+def bugreport(app: AppT) -> str:
     """Return a string containing information useful in bug-reports."""
     import billiard
     import celery
@@ -344,7 +366,10 @@ def bugreport(app):
     )
 
 
-def find_app(app, symbol_by_name=symbol_by_name, imp=import_from_cwd):
+def find_app(app: AppT,
+             *,
+             symbol_by_name: Callable = symbol_by_name,
+             imp: Callable = import_from_cwd) -> AppT:
     """Find app by name."""
     from .base import Celery
 

+ 80 - 52
celery/backends/async.py

@@ -4,14 +4,22 @@ import threading
 
 from collections import deque
 from time import monotonic, sleep
-from weakref import WeakKeyDictionary
+from typing import (
+    Any, Awaitable, Callable, Iterable, Iterator, Mapping, Set, Sequence,
+)
 from queue import Empty
+from weakref import WeakKeyDictionary
 
+from kombu.types import MessageT
 from kombu.utils.compat import detect_environment
 from kombu.utils.objects import cached_property
 
 from celery import states
 from celery.exceptions import TimeoutError
+from celery.types import AppT, BackendT, ResultT, ResultConsumerT
+from celery.utils.collections import BufferMap
+
+from .base import pending_results_t
 
 __all__ = [
     'AsyncBackendMixin', 'BaseResultConsumer', 'Drainer',
@@ -21,9 +29,9 @@ __all__ = [
 drainers = {}
 
 
-def register_drainer(name):
+def register_drainer(name: str) -> Callable:
     """Decorator used to register a new result drainer type."""
-    def _inner(cls):
+    def _inner(cls: type) -> type:
         drainers[name] = cls
         return cls
     return _inner
@@ -33,16 +41,19 @@ def register_drainer(name):
 class Drainer:
     """Result draining service."""
 
-    def __init__(self, result_consumer):
+    def __init__(self, result_consumer: ResultConsumerT) -> None:
         self.result_consumer = result_consumer
 
-    def start(self):
-        pass
+    def start(self) -> None:
+        ...
 
-    def stop(self):
-        pass
+    def stop(self) -> None:
+        ...
 
-    def drain_events_until(self, p, timeout=None, on_interval=None, wait=None):
+    def drain_events_until(self, p: Awaitable,
+                           timeout: float = None,
+                           on_interval: Callable = None,
+                           wait: Callable = None) -> Iterator:
         wait = wait or self.result_consumer.drain_events
         time_start = monotonic()
 
@@ -59,7 +70,8 @@ class Drainer:
             if p.ready:  # got event on the wanted channel.
                 break
 
-    def wait_for(self, p, wait, timeout=None):
+    def wait_for(self, p: Awaitable, wait: Callable,
+                 timeout: float = None) -> None:
         wait(timeout=timeout)
 
 
@@ -67,13 +79,13 @@ class greenletDrainer(Drainer):
     spawn = None
     _g = None
 
-    def __init__(self, *args, **kwargs):
+    def __init__(self, *args, **kwargs) -> None:
         super(greenletDrainer, self).__init__(*args, **kwargs)
         self._started = threading.Event()
         self._stopped = threading.Event()
         self._shutdown = threading.Event()
 
-    def run(self):
+    def run(self) -> None:
         self._started.set()
         while not self._stopped.is_set():
             try:
@@ -82,16 +94,17 @@ class greenletDrainer(Drainer):
                 pass
         self._shutdown.set()
 
-    def start(self):
+    def start(self) -> None:
         if not self._started.is_set():
             self._g = self.spawn(self.run)
             self._started.wait()
 
-    def stop(self):
+    def stop(self) -> None:
         self._stopped.set()
         self._shutdown.wait(threading.TIMEOUT_MAX)
 
-    def wait_for(self, p, wait, timeout=None):
+    def wait_for(self, p: Awaitable, wait: Callable,
+                 timeout: float = None) -> None:
         self.start()
         if not p.ready:
             sleep(0)
@@ -101,7 +114,7 @@ class greenletDrainer(Drainer):
 class eventletDrainer(greenletDrainer):
 
     @cached_property
-    def spawn(self):
+    def spawn(self) -> Callable:
         from eventlet import spawn
         return spawn
 
@@ -110,7 +123,7 @@ class eventletDrainer(greenletDrainer):
 class geventDrainer(greenletDrainer):
 
     @cached_property
-    def spawn(self):
+    def spawn(self) -> Callable:
         from gevent import spawn
         return spawn
 
@@ -118,10 +131,13 @@ class geventDrainer(greenletDrainer):
 class AsyncBackendMixin:
     """Mixin for backends that enables the async API."""
 
-    def _collect_into(self, result, bucket):
+    def _collect_into(self, result: ResultT, bucket: deque):
         self.result_consumer.buckets[result] = bucket
 
-    def iter_native(self, result, no_ack=True, **kwargs):
+    async def iter_native(
+            self, result: ResultT,
+            *,
+            no_ack: bool = True, **kwargs) -> Iterator[str, Mapping]:
         self._ensure_not_eager()
 
         results = result.results
@@ -145,7 +161,9 @@ class AsyncBackendMixin:
             node = bucket.popleft()
             yield node.id, node._cache
 
-    def add_pending_result(self, result, weak=False, start_drainer=True):
+    def add_pending_result(self, result: ResultT,
+                           weak: bool = False,
+                           start_drainer: bool = True) -> ResultT:
         if start_drainer:
             self.result_consumer.drainer.start()
         try:
@@ -154,57 +172,64 @@ class AsyncBackendMixin:
             self._add_pending_result(result.id, result, weak=weak)
         return result
 
-    def _maybe_resolve_from_buffer(self, result):
+    def _maybe_resolve_from_buffer(self, result: ResultT) -> None:
         result._maybe_set_cache(self._pending_messages.take(result.id))
 
-    def _add_pending_result(self, task_id, result, weak=False):
+    def _add_pending_result(self, task_id: str, result: ResultT,
+                            weak: bool = False) -> None:
         concrete, weak_ = self._pending_results
         if task_id not in weak_ and result.id not in concrete:
             (weak_ if weak else concrete)[task_id] = result
             self.result_consumer.consume_from(task_id)
 
-    def add_pending_results(self, results, weak=False):
+    def add_pending_results(self, results: Sequence[ResultT],
+                            weak: bool = False) -> None:
         self.result_consumer.drainer.start()
-        return [self.add_pending_result(result, weak=weak, start_drainer=False)
-                for result in results]
+        [self.add_pending_result(result, weak=weak, start_drainer=False)
+         for result in results]
 
-    def remove_pending_result(self, result):
+    def remove_pending_result(self, result: ResultT) -> ResultT:
         self._remove_pending_result(result.id)
         self.on_result_fulfilled(result)
         return result
 
-    def _remove_pending_result(self, task_id):
+    def _remove_pending_result(self, task_id: str) -> None:
         for map in self._pending_results:
             map.pop(task_id, None)
 
-    def on_result_fulfilled(self, result):
+    def on_result_fulfilled(self, result: ResultT) -> None:
         self.result_consumer.cancel_for(result.id)
 
-    def wait_for_pending(self, result,
-                         callback=None, propagate=True, **kwargs):
+    def wait_for_pending(self, result: ResultT,
+                         callback: Callable = None,
+                         propagate: bool = True,
+                         **kwargs) -> Any:
         self._ensure_not_eager()
         for _ in self._wait_for_pending(result, **kwargs):
             pass
         return result.maybe_throw(callback=callback, propagate=propagate)
 
-    def _wait_for_pending(self, result,
-                          timeout=None, on_interval=None, on_message=None,
-                          **kwargs):
+    def _wait_for_pending(self, result: ResultT,
+                          timeout: float = None,
+                          on_interval: Callable = None,
+                          on_message: Callable = None,
+                          **kwargs) -> Iterable:
         return self.result_consumer._wait_for_pending(
             result, timeout=timeout,
             on_interval=on_interval, on_message=on_message,
         )
 
     @property
-    def is_async(self):
+    def is_async(self) -> bool:
         return True
 
 
 class BaseResultConsumer:
     """Manager responsible for consuming result messages."""
 
-    def __init__(self, backend, app, accept,
-                 pending_results, pending_messages):
+    def __init__(self, backend: BackendT, app: AppT, accept: Set[str],
+                 pending_results: pending_results_t,
+                 pending_messages: BufferMap) -> None:
         self.backend = backend
         self.app = app
         self.accept = accept
@@ -214,37 +239,39 @@ class BaseResultConsumer:
         self.buckets = WeakKeyDictionary()
         self.drainer = drainers[detect_environment()](self)
 
-    def start(self, initial_task_id, **kwargs):
+    def start(self, initial_task_id: str, **kwargs) -> None:
         raise NotImplementedError()
 
-    def stop(self):
-        pass
+    def stop(self) -> None:
+        ...
 
-    def drain_events(self, timeout=None):
+    def drain_events(self, timeout: float = None) -> None:
         raise NotImplementedError()
 
-    def consume_from(self, task_id):
+    def consume_from(self, task_id: str) -> None:
         raise NotImplementedError()
 
-    def cancel_for(self, task_id):
+    def cancel_for(self, task_id: str) -> None:
         raise NotImplementedError()
 
-    def _after_fork(self):
+    def _after_fork(self) -> None:
         self.buckets.clear()
         self.buckets = WeakKeyDictionary()
         self.on_message = None
         self.on_after_fork()
 
-    def on_after_fork(self):
-        pass
+    def on_after_fork(self) -> None:
+        ...
 
-    def drain_events_until(self, p, timeout=None, on_interval=None):
+    def drain_events_until(self, p: Awaitable,
+                           timeout: float = None,
+                           on_interval: Callable = None) -> Iterable:
         return self.drainer.drain_events_until(
             p, timeout=timeout, on_interval=on_interval)
 
     def _wait_for_pending(self, result,
                           timeout=None, on_interval=None, on_message=None,
-                          **kwargs):
+                          **kwargs) -> Iterable:
         self.on_wait_for_pending(result, timeout=timeout, **kwargs)
         prev_on_m, self.on_message = self.on_message, on_message
         try:
@@ -258,13 +285,14 @@ class BaseResultConsumer:
         finally:
             self.on_message = prev_on_m
 
-    def on_wait_for_pending(self, result, timeout=None, **kwargs):
-        pass
+    def on_wait_for_pending(self, result: ResultT,
+                            timeout: float = None, **kwargs) -> None:
+        ...
 
-    def on_out_of_band_result(self, message):
+    def on_out_of_band_result(self, message: MessageT) -> None:
         self.on_state_change(message.payload, message)
 
-    def _get_pending_result(self, task_id):
+    def _get_pending_result(self, task_id: str) -> ResultT:
         for mapping in self._pending_results:
             try:
                 return mapping[task_id]
@@ -272,7 +300,7 @@ class BaseResultConsumer:
                 pass
         raise KeyError(task_id)
 
-    def on_state_change(self, meta, message):
+    def on_state_change(self, meta: Mapping, message: MessageT) -> None:
         if self.on_message:
             self.on_message(meta)
         if meta['status'] in states.READY_STATES:

+ 264 - 166
celery/backends/base.py

@@ -9,8 +9,12 @@
 import sys
 import time
 
-from collections import namedtuple
 from datetime import timedelta
+from numbers import Number
+from typing import (
+    Any, AnyStr, Callable, Dict, Iterable, Iterator,
+    Mapping, MutableMapping, NamedTuple, Optional, Set, Sequence, Tuple,
+)
 from weakref import WeakValueDictionary
 
 from billiard.einfo import ExceptionInfo
@@ -18,6 +22,7 @@ from kombu.serialization import (
     dumps, loads, prepare_accept_content,
     registry as serializer_registry,
 )
+from kombu.types import ProducerT
 from kombu.utils.encoding import bytes_to_str, ensure_bytes, from_utf8
 from kombu.utils.url import maybe_sanitize_url
 
@@ -30,6 +35,7 @@ from celery.exceptions import (
 from celery.result import (
     GroupResult, ResultBase, allow_join_result, result_from_tuple,
 )
+from celery.types import AppT, BackendT, ResultT, RequestT, SignatureT
 from celery.utils.collections import BufferMap
 from celery.utils.functional import LRUCache, arity_greater
 from celery.utils.log import get_logger
@@ -47,10 +53,6 @@ logger = get_logger(__name__)
 
 MESSAGE_BUFFER_MAX = 8192
 
-pending_results_t = namedtuple('pending_results_t', (
-    'concrete', 'weak',
-))
-
 E_NO_BACKEND = """
 No result backend is configured.
 Please see the documentation for more information.
@@ -66,15 +68,22 @@ Result backends that supports chords: Redis, Database, Memcached, and more.
 """
 
 
-def unpickle_backend(cls, args, kwargs):
+class pending_results_t(NamedTuple):
+    """Tuple of concrete and weak references to pending results."""
+
+    concrete: Mapping
+    weak: WeakValueDictionary
+
+
+def unpickle_backend(cls: type, args: Tuple, kwargs: Dict) -> BackendT:
     """Return an unpickled backend."""
     return cls(*args, app=current_app._get_current_object(), **kwargs)
 
 
 class _nulldict(dict):
 
-    def ignore(self, *a, **kw):
-        pass
+    def ignore(self, *a, **kw) -> None:
+        ...
     __setitem__ = update = setdefault = ignore
 
 
@@ -89,7 +98,7 @@ class Backend:
     #: Time to sleep between polling each individual item
     #: in `ResultSet.iterate`. as opposed to the `interval`
     #: argument which is for each pass.
-    subpolling_interval = None
+    subpolling_interval: float = None
 
     #: If true the backend must implement :meth:`get_many`.
     supports_native_join = False
@@ -109,9 +118,15 @@ class Backend:
         'interval_max': 1,
     }
 
-    def __init__(self, app,
-                 serializer=None, max_cached_results=None, accept=None,
-                 expires=None, expires_type=None, url=None, **kwargs):
+    def __init__(self, app: AppT,
+                 *,
+                 serializer: str = None,
+                 max_cached_results: int = None,
+                 accept: Set[str] = None,
+                 expires: float = None,
+                 expires_type: Callable = None,
+                 url: str = None,
+                 **kwargs) -> None:
         self.app = app
         conf = self.app.conf
         self.serializer = serializer or conf.result_serializer
@@ -128,7 +143,7 @@ class Backend:
         self._pending_messages = BufferMap(MESSAGE_BUFFER_MAX)
         self.url = url
 
-    def as_uri(self, include_password=False):
+    def as_uri(self, include_password: bool = False) -> str:
         """Return the backend as an URI, sanitizing the password or not."""
         # when using maybe_sanitize_url(), "/" is added
         # we're stripping it for consistency
@@ -137,33 +152,42 @@ class Backend:
         url = maybe_sanitize_url(self.url or '')
         return url[:-1] if url.endswith(':///') else url
 
-    def mark_as_started(self, task_id, **meta):
+    async def mark_as_started(self, task_id: str, **meta) -> Any:
         """Mark a task as started."""
-        return self.store_result(task_id, meta, states.STARTED)
+        return await self.store_result(task_id, meta, states.STARTED)
 
-    def mark_as_done(self, task_id, result,
-                     request=None, store_result=True, state=states.SUCCESS):
+    async def mark_as_done(self, task_id: str, result: Any,
+                           *,
+                           request: RequestT = None,
+                           store_result: bool = True,
+                           state: str = states.SUCCESS) -> None:
         """Mark task as successfully executed."""
         if store_result:
-            self.store_result(task_id, result, state, request=request)
+            await self.store_result(task_id, result, state, request=request)
         if request and request.chord:
-            self.on_chord_part_return(request, state, result)
-
-    def mark_as_failure(self, task_id, exc,
-                        traceback=None, request=None,
-                        store_result=True, call_errbacks=True,
-                        state=states.FAILURE):
+            await self.on_chord_part_return(request, state, result)
+
+    async def mark_as_failure(self, task_id: str, exc: Exception,
+                              *,
+                              traceback: str = None,
+                              request: RequestT = None,
+                              store_result: bool = True,
+                              call_errbacks: bool = True,
+                              state: str = states.FAILURE) -> None:
         """Mark task as executed with failure."""
         if store_result:
-            self.store_result(task_id, exc, state,
-                              traceback=traceback, request=request)
+            await self.store_result(task_id, exc, state,
+                                    traceback=traceback, request=request)
         if request:
             if request.chord:
-                self.on_chord_part_return(request, state, exc)
+                await self.on_chord_part_return(request, state, exc)
             if call_errbacks and request.errbacks:
-                self._call_task_errbacks(request, exc, traceback)
+                await self._call_task_errbacks(request, exc, traceback)
 
-    def _call_task_errbacks(self, request, exc, traceback):
+    async def _call_task_errbacks(self,
+                                  request: RequestT,
+                                  exc: Exception,
+                                  traceback: str) -> None:
         old_signature = []
         for errback in request.errbacks:
             errback = self.app.signature(errback)
@@ -176,30 +200,38 @@ class Backend:
             # need to do so if the errback only takes a single task_id arg.
             task_id = request.id
             root_id = request.root_id or task_id
-            group(old_signature, app=self.app).apply_async(
+            await group(old_signature, app=self.app).apply_async(
                 (task_id,), parent_id=task_id, root_id=root_id
             )
 
-    def mark_as_revoked(self, task_id, reason='',
-                        request=None, store_result=True, state=states.REVOKED):
+    async def mark_as_revoked(self, task_id: str, reason: str = '',
+                              *,
+                              request: RequestT = None,
+                              store_result: bool = True,
+                              state: str = states.REVOKED) -> None:
         exc = TaskRevokedError(reason)
         if store_result:
-            self.store_result(task_id, exc, state,
-                              traceback=None, request=request)
+            await self.store_result(task_id, exc, state,
+                                    traceback=None, request=request)
         if request and request.chord:
-            self.on_chord_part_return(request, state, exc)
-
-    def mark_as_retry(self, task_id, exc, traceback=None,
-                      request=None, store_result=True, state=states.RETRY):
+            await self.on_chord_part_return(request, state, exc)
+
+    async def mark_as_retry(self, task_id: str, exc: Exception,
+                            *,
+                            traceback: str = None,
+                            request: RequestT = None,
+                            store_result: bool = True,
+                            state: str = states.RETRY) -> None:
         """Mark task as being retries.
 
         Note:
             Stores the current exception (if any).
         """
-        return self.store_result(task_id, exc, state,
-                                 traceback=traceback, request=request)
+        return await self.store_result(task_id, exc, state,
+                                       traceback=traceback, request=request)
 
-    def chord_error_from_stack(self, callback, exc=None):
+    async def chord_error_from_stack(self, callback: Callable,
+                                     exc: Exception = None) -> ExceptionInfo:
         # need below import for test for some crazy reason
         from celery import group  # pylint: disable
         app = self.app
@@ -208,34 +240,37 @@ class Backend:
         except KeyError:
             backend = self
         try:
-            group(
+            await group(
                 [app.signature(errback)
                  for errback in callback.options.get('link_error') or []],
                 app=app,
             ).apply_async((callback.id,))
         except Exception as eb_exc:  # pylint: disable=broad-except
-            return backend.fail_from_current_stack(callback.id, exc=eb_exc)
+            return await backend.fail_from_current_stack(
+                callback.id, exc=eb_exc)
         else:
-            return backend.fail_from_current_stack(callback.id, exc=exc)
+            return await backend.fail_from_current_stack(
+                callback.id, exc=exc)
 
-    def fail_from_current_stack(self, task_id, exc=None):
+    async def fail_from_current_stack(self, task_id: str,
+                                      exc: Exception = None) -> ExceptionInfo:
         type_, real_exc, tb = sys.exc_info()
         try:
             exc = real_exc if exc is None else exc
             ei = ExceptionInfo((type_, exc, tb))
-            self.mark_as_failure(task_id, exc, ei.traceback)
+            await self.mark_as_failure(task_id, exc, ei.traceback)
             return ei
         finally:
             del tb
 
-    def prepare_exception(self, exc, serializer=None):
+    def prepare_exception(self, exc: Exception, serializer: str = None) -> Any:
         """Prepare exception for serialization."""
         serializer = self.serializer if serializer is None else serializer
         if serializer in EXCEPTION_ABLE_CODECS:
             return get_pickleable_exception(exc)
         return {'exc_type': type(exc).__name__, 'exc_message': str(exc)}
 
-    def exception_to_python(self, exc):
+    def exception_to_python(self, exc: Any) -> Exception:
         """Convert serialized exception to Python exception."""
         if exc:
             if not isinstance(exc, BaseException):
@@ -245,34 +280,35 @@ class Backend:
                 exc = get_pickled_exception(exc)
         return exc
 
-    def prepare_value(self, result):
+    def prepare_value(self, result: Any) -> Any:
         """Prepare value for storage."""
         if self.serializer != 'pickle' and isinstance(result, ResultBase):
             return result.as_tuple()
         return result
 
-    def encode(self, data):
+    def encode(self, data: Any) -> AnyStr:
         _, _, payload = self._encode(data)
         return payload
 
-    def _encode(self, data):
+    def _encode(self, data: Any) -> AnyStr:
         return dumps(data, serializer=self.serializer)
 
-    def meta_from_decoded(self, meta):
+    def meta_from_decoded(self, meta: MutableMapping) -> MutableMapping:
         if meta['status'] in self.EXCEPTION_STATES:
             meta['result'] = self.exception_to_python(meta['result'])
         return meta
 
-    def decode_result(self, payload):
+    def decode_result(self, payload: AnyStr) -> Mapping:
         return self.meta_from_decoded(self.decode(payload))
 
-    def decode(self, payload):
+    def decode(self, payload: AnyStr) -> Mapping:
         return loads(payload,
                      content_type=self.content_type,
                      content_encoding=self.content_encoding,
                      accept=self.accept)
 
-    def prepare_expires(self, value, type=None):
+    def prepare_expires(self, value: Optional[Number],
+                        type: Callable = None) -> Optional[Number]:
         if value is None:
             value = self.app.conf.result_expires
         if isinstance(value, timedelta):
@@ -281,61 +317,63 @@ class Backend:
             return type(value)
         return value
 
-    def prepare_persistent(self, enabled=None):
+    def prepare_persistent(self, enabled: bool = None) -> bool:
         if enabled is not None:
             return enabled
         p = self.app.conf.result_persistent
         return self.persistent if p is None else p
 
-    def encode_result(self, result, state):
+    def encode_result(self, result: Any, state: str) -> Any:
         if state in self.EXCEPTION_STATES and isinstance(result, Exception):
             return self.prepare_exception(result)
         else:
             return self.prepare_value(result)
 
-    def is_cached(self, task_id):
+    def is_cached(self, task_id: str) -> bool:
         return task_id in self._cache
 
-    def store_result(self, task_id, result, state,
-                     traceback=None, request=None, **kwargs):
+    async def store_result(self, task_id: str, result: Any, state: str,
+                           *,
+                           traceback: str = None,
+                           request: RequestT = None, **kwargs) -> Any:
         """Update task state and result."""
         result = self.encode_result(result, state)
-        self._store_result(task_id, result, state, traceback,
-                           request=request, **kwargs)
+        await self._store_result(task_id, result, state, traceback,
+                                 request=request, **kwargs)
         return result
 
-    def forget(self, task_id):
+    async def forget(self, task_id: str) -> None:
         self._cache.pop(task_id, None)
-        self._forget(task_id)
+        await self._forget(task_id)
 
-    def _forget(self, task_id):
+    async def _forget(self, task_id: str) -> None:
         raise NotImplementedError('backend does not implement forget.')
 
-    def get_state(self, task_id):
+    def get_state(self, task_id: str) -> str:
         """Get the state of a task."""
         return self.get_task_meta(task_id)['status']
 
-    def get_traceback(self, task_id):
+    def get_traceback(self, task_id: str) -> str:
         """Get the traceback for a failed task."""
         return self.get_task_meta(task_id).get('traceback')
 
-    def get_result(self, task_id):
+    def get_result(self, task_id: str) -> Any:
         """Get the result of a task."""
         return self.get_task_meta(task_id).get('result')
 
-    def get_children(self, task_id):
+    def get_children(self, task_id: str) -> Sequence[str]:
         """Get the list of subtasks sent by a task."""
         try:
             return self.get_task_meta(task_id)['children']
         except KeyError:
             pass
 
-    def _ensure_not_eager(self):
+    def _ensure_not_eager(self) -> None:
         if self.app.conf.task_always_eager:
             raise RuntimeError(
                 "Cannot retrieve result with task_always_eager enabled")
 
-    def get_task_meta(self, task_id, cache=True):
+    def get_task_meta(self, task_id: str, cache: bool = True) -> Mapping:
         self._ensure_not_eager()
         if cache:
             try:
@@ -348,15 +386,15 @@ class Backend:
             self._cache[task_id] = meta
         return meta
 
-    def reload_task_result(self, task_id):
+    def reload_task_result(self, task_id: str) -> None:
         """Reload task result, even if it has been previously fetched."""
         self._cache[task_id] = self.get_task_meta(task_id, cache=False)
 
-    def reload_group_result(self, group_id):
+    def reload_group_result(self, group_id: str) -> None:
         """Reload group result, even if it has been previously fetched."""
         self._cache[group_id] = self.get_group_meta(group_id, cache=False)
 
-    def get_group_meta(self, group_id, cache=True):
+    def get_group_meta(self, group_id: str, cache: bool = True) -> Mapping:
         self._ensure_not_eager()
         if cache:
             try:
@@ -369,91 +407,118 @@ class Backend:
             self._cache[group_id] = meta
         return meta
 
-    def restore_group(self, group_id, cache=True):
+    def restore_group(self, group_id: str, cache: bool = True) -> GroupResult:
         """Get the result for a group."""
         meta = self.get_group_meta(group_id, cache=cache)
         if meta:
             return meta['result']
 
-    def save_group(self, group_id, result):
+    def save_group(self, group_id: str, result: GroupResult) -> GroupResult:
         """Store the result of an executed group."""
         return self._save_group(group_id, result)
 
-    def delete_group(self, group_id):
+    def _save_group(self, group_id: str, result: GroupResult) -> GroupResult:
+        raise NotImplementedError()
+
+    def delete_group(self, group_id: str) -> None:
         self._cache.pop(group_id, None)
-        return self._delete_group(group_id)
+        self._delete_group(group_id)
 
-    def cleanup(self):
+    def _delete_group(self, group_id: str) -> None:
+        raise NotImplementedError()
+
+    def cleanup(self) -> None:
         """Backend cleanup.
 
         Note:
             This is run by :class:`celery.task.DeleteExpiredTaskMetaTask`.
         """
-        pass
+        ...
 
-    def process_cleanup(self):
+    def process_cleanup(self) -> None:
         """Cleanup actions to do at the end of a task worker process."""
-        pass
+        ...
 
-    def on_task_call(self, producer, task_id):
+    async def on_task_call(self, producer: ProducerT, task_id: str) -> Mapping:
         return {}
 
-    def add_to_chord(self, chord_id, result):
+    async def add_to_chord(self, chord_id: str, result: ResultT) -> None:
         raise NotImplementedError('Backend does not support add_to_chord')
 
-    def on_chord_part_return(self, request, state, result, **kwargs):
-        pass
+    async def on_chord_part_return(self, request: RequestT,
+                                   state: str,
+                                   result: Any,
+                                   **kwargs) -> None:
+        ...
 
-    def fallback_chord_unlock(self, group_id, body, result=None,
-                              countdown=1, **kwargs):
+    async def fallback_chord_unlock(self, group_id: str, body: SignatureT,
+                                    result: ResultT = None,
+                                    countdown: float = 1,
+                                    **kwargs) -> None:
         kwargs['result'] = [r.as_tuple() for r in result]
-        self.app.tasks['celery.chord_unlock'].apply_async(
+        await self.app.tasks['celery.chord_unlock'].apply_async(
             (group_id, body,), kwargs, countdown=countdown,
         )
 
-    def ensure_chords_allowed(self):
-        pass
+    def ensure_chords_allowed(self) -> None:
+        ...
 
-    def apply_chord(self, header, partial_args, group_id, body,
-                    options={}, **kwargs):
+    async def apply_chord(self, header: SignatureT, partial_args: Sequence,
+                          *,
+                          group_id: str, body: SignatureT,
+                          options: Mapping = {}, **kwargs) -> ResultT:
         self.ensure_chords_allowed()
         fixed_options = {k: v for k, v in options.items() if k != 'task_id'}
-        result = header(*partial_args, task_id=group_id, **fixed_options or {})
-        self.fallback_chord_unlock(group_id, body, **kwargs)
+        result = await header(
+            *partial_args, task_id=group_id, **fixed_options or {})
+        await self.fallback_chord_unlock(group_id, body, **kwargs)
         return result
 
-    def current_task_children(self, request=None):
+    def current_task_children(self,
+                              request: RequestT = None) -> Sequence[Tuple]:
         request = request or getattr(get_current_task(), 'request', None)
         if request:
             return [r.as_tuple() for r in getattr(request, 'children', [])]
 
-    def __reduce__(self, args=(), kwargs={}):
+    def __reduce__(self, args: Tuple = (), kwargs: Dict = {}) -> Tuple:
         return (unpickle_backend, (self.__class__, args, kwargs))
 
 
 class SyncBackendMixin:
 
-    def iter_native(self, result, timeout=None, interval=0.5, no_ack=True,
-                    on_message=None, on_interval=None):
+    async def iter_native(
+            self, result: ResultT,
+            *,
+            timeout: float = None,
+            interval: float = 0.5,
+            no_ack: bool = True,
+            on_message: Callable = None,
+            on_interval: Callable = None) -> Iterable[Tuple[str, Mapping]]:
         self._ensure_not_eager()
         results = result.results
         if not results:
             return iter([])
-        return self.get_many(
+        return await self.get_many(
             {r.id for r in results},
             timeout=timeout, interval=interval, no_ack=no_ack,
             on_message=on_message, on_interval=on_interval,
         )
 
-    def wait_for_pending(self, result, timeout=None, interval=0.5,
-                         no_ack=True, on_message=None, on_interval=None,
-                         callback=None, propagate=True):
+    async def wait_for_pending(self, result: ResultT,
+                               *,
+                               timeout: float = None,
+                               interval: float = 0.5,
+                               no_ack: bool = True,
+                               on_message: Callable = None,
+                               on_interval: Callable = None,
+                               callback: Callable = None,
+                               propagate: bool = True) -> Any:
         self._ensure_not_eager()
         if on_message is not None:
             raise ImproperlyConfigured(
                 'Backend does not support on_message callback')
 
-        meta = self.wait_for(
+        meta = await self.wait_for(
             result.id, timeout=timeout,
             interval=interval,
             on_interval=on_interval,
@@ -463,8 +528,12 @@ class SyncBackendMixin:
             result._maybe_set_cache(meta)
             return result.maybe_throw(propagate=propagate, callback=callback)
 
-    def wait_for(self, task_id,
-                 timeout=None, interval=0.5, no_ack=True, on_interval=None):
+    async def wait_for(self, task_id: str,
+                       *,
+                       timeout: float = None,
+                       interval: float = 0.5,
+                       no_ack: bool = True,
+                       on_interval: Callable = None) -> Mapping:
         """Wait for task and return its result.
 
         If the task raises an exception, this exception
@@ -484,21 +553,23 @@ class SyncBackendMixin:
             if meta['status'] in states.READY_STATES:
                 return meta
             if on_interval:
-                on_interval()
+                await on_interval()
             # avoid hammering the CPU checking status.
             time.sleep(interval)
             time_elapsed += interval
             if timeout and time_elapsed >= timeout:
                 raise TimeoutError('The operation timed out.')
 
-    def add_pending_result(self, result, weak=False):
+    def add_pending_result(self, result: ResultT,
+                           *,
+                           weak: bool = False) -> ResultT:
         return result
 
-    def remove_pending_result(self, result):
+    def remove_pending_result(self, result: ResultT) -> ResultT:
         return result
 
     @property
-    def is_async(self):
+    def is_async(self) -> bool:
         return False
 
 
@@ -514,7 +585,7 @@ class BaseKeyValueStoreBackend(Backend):
     chord_keyprefix = 'chord-unlock-'
     implements_incr = False
 
-    def __init__(self, *args, **kwargs):
+    def __init__(self, *args, **kwargs) -> None:
         if hasattr(self.key_t, '__func__'):  # pragma: no cover
             self.key_t = self.key_t.__func__  # remove binding
         self._encode_prefixes()
@@ -522,51 +593,51 @@ class BaseKeyValueStoreBackend(Backend):
         if self.implements_incr:
             self.apply_chord = self._apply_chord_incr
 
-    def _encode_prefixes(self):
+    def _encode_prefixes(self) -> None:
         self.task_keyprefix = self.key_t(self.task_keyprefix)
         self.group_keyprefix = self.key_t(self.group_keyprefix)
         self.chord_keyprefix = self.key_t(self.chord_keyprefix)
 
-    def get(self, key):
+    async def get(self, key: AnyStr) -> AnyStr:
         raise NotImplementedError('Must implement the get method.')
 
-    def mget(self, keys):
+    async def mget(self, keys: Sequence[AnyStr]) -> Sequence[AnyStr]:
         raise NotImplementedError('Does not support get_many')
 
-    def set(self, key, value):
+    async def set(self, key: AnyStr, value: AnyStr) -> None:
         raise NotImplementedError('Must implement the set method.')
 
-    def delete(self, key):
+    async def delete(self, key: AnyStr) -> None:
         raise NotImplementedError('Must implement the delete method')
 
-    def incr(self, key):
+    async def incr(self, key: AnyStr) -> None:
         raise NotImplementedError('Does not implement incr')
 
-    def expire(self, key, value):
-        pass
+    async def expire(self, key: AnyStr, value: AnyStr) -> None:
+        ...
 
-    def get_key_for_task(self, task_id, key=''):
+    def get_key_for_task(self, task_id: str, key: AnyStr = '') -> AnyStr:
         """Get the cache key for a task by id."""
         key_t = self.key_t
         return key_t('').join([
             self.task_keyprefix, key_t(task_id), key_t(key),
         ])
 
-    def get_key_for_group(self, group_id, key=''):
+    def get_key_for_group(self, group_id: str, key: AnyStr = '') -> AnyStr:
         """Get the cache key for a group by id."""
         key_t = self.key_t
         return key_t('').join([
             self.group_keyprefix, key_t(group_id), key_t(key),
         ])
 
-    def get_key_for_chord(self, group_id, key=''):
+    def get_key_for_chord(self, group_id: str, key: AnyStr = '') -> AnyStr:
         """Get the cache key for the chord waiting on group with given id."""
         key_t = self.key_t
         return key_t('').join([
             self.chord_keyprefix, key_t(group_id), key_t(key),
         ])
 
-    def _strip_prefix(self, key):
+    def _strip_prefix(self, key: AnyStr) -> AnyStr:
         """Take bytes: emit string."""
         key = self.key_t(key)
         for prefix in self.task_keyprefix, self.group_keyprefix:
@@ -574,14 +645,18 @@ class BaseKeyValueStoreBackend(Backend):
                 return bytes_to_str(key[len(prefix):])
         return bytes_to_str(key)
 
-    def _filter_ready(self, values, READY_STATES=states.READY_STATES):
+    def _filter_ready(
+            self,
+            values: Sequence[Tuple[Any, Any]],
+            *,
+            READY_STATES: Set[str] = states.READY_STATES) -> Iterable[Tuple]:
         for k, v in values:
             if v is not None:
                 v = self.decode_result(v)
                 if v['status'] in READY_STATES:
                     yield k, v
 
-    def _mget_to_results(self, values, keys):
+    def _mget_to_results(self, values: Any, keys: Sequence[AnyStr]) -> Mapping:
         if hasattr(values, 'items'):
             # client returns dict so mapping preserved.
             return {
@@ -595,9 +670,16 @@ class BaseKeyValueStoreBackend(Backend):
                 for i, v in self._filter_ready(enumerate(values))
             }
 
-    def get_many(self, task_ids, timeout=None, interval=0.5, no_ack=True,
-                 on_message=None, on_interval=None, max_iterations=None,
-                 READY_STATES=states.READY_STATES):
+    async def get_many(
+            self, task_ids: Sequence[str],
+            *,
+            timeout: float = None,
+            interval: float = 0.5,
+            no_ack: bool = True,
+            on_message: Callable = None,
+            on_interval: Callable = None,
+            max_iterations: int = None,
+            READY_STATES=states.READY_STATES) -> Iterator[str, Mapping]:
         interval = 0.5 if interval is None else interval
         ids = task_ids if isinstance(task_ids, set) else set(task_ids)
         cached_ids = set()
@@ -616,54 +698,59 @@ class BaseKeyValueStoreBackend(Backend):
         iterations = 0
         while ids:
             keys = list(ids)
-            r = self._mget_to_results(self.mget([self.get_key_for_task(k)
-                                                 for k in keys]), keys)
+            payloads = await self.mget([
+                self.get_key_for_task(k) for k in keys
+            ])
+            r = self._mget_to_results(payloads, keys)
             cache.update(r)
             ids.difference_update({bytes_to_str(v) for v in r})
             for key, value in r.items():
                 if on_message is not None:
-                    on_message(value)
+                    await on_message(value)
                 yield bytes_to_str(key), value
             if timeout and iterations * interval >= timeout:
                 raise TimeoutError('Operation timed out ({0})'.format(timeout))
             if on_interval:
-                on_interval()
+                await on_interval()
             time.sleep(interval)  # don't busy loop.
             iterations += 1
             if max_iterations and iterations >= max_iterations:
                 break
 
-    def _forget(self, task_id):
-        self.delete(self.get_key_for_task(task_id))
+    async def _forget(self, task_id: str) -> None:
+        await self.delete(self.get_key_for_task(task_id))
 
-    def _store_result(self, task_id, result, state,
-                      traceback=None, request=None, **kwargs):
+    async def _store_result(self, task_id: str, result: Any, state: str,
+                            traceback: str = None,
+                            request: RequestT = None,
+                            **kwargs) -> Any:
         meta = {
             'status': state, 'result': result, 'traceback': traceback,
             'children': self.current_task_children(request),
             'task_id': bytes_to_str(task_id),
         }
-        self.set(self.get_key_for_task(task_id), self.encode(meta))
+        await self.set(self.get_key_for_task(task_id), self.encode(meta))
         return result
 
-    def _save_group(self, group_id, result):
-        self.set(self.get_key_for_group(group_id),
-                 self.encode({'result': result.as_tuple()}))
+    async def _save_group(self,
+                          group_id: str, result: GroupResult) -> GroupResult:
+        await self.set(self.get_key_for_group(group_id),
+                       self.encode({'result': result.as_tuple()}))
         return result
 
-    def _delete_group(self, group_id):
-        self.delete(self.get_key_for_group(group_id))
+    async def _delete_group(self, group_id: str) -> None:
+        await self.delete(self.get_key_for_group(group_id))
 
-    def _get_task_meta_for(self, task_id):
+    async def _get_task_meta_for(self, task_id: str) -> Mapping:
         """Get task meta-data for a task by id."""
-        meta = self.get(self.get_key_for_task(task_id))
+        meta = await self.get(self.get_key_for_task(task_id))
         if not meta:
             return {'status': states.PENDING, 'result': None}
         return self.decode_result(meta)
 
-    def _restore_group(self, group_id):
+    async def _restore_group(self, group_id: str) -> GroupResult:
         """Get task meta-data for a task by id."""
-        meta = self.get(self.get_key_for_group(group_id))
+        meta = await self.get(self.get_key_for_group(group_id))
         # previously this was always pickled, but later this
         # was extended to support other serializers, so the
         # structure is kind of weird.
@@ -673,16 +760,25 @@ class BaseKeyValueStoreBackend(Backend):
             meta['result'] = result_from_tuple(result, self.app)
             return meta
 
-    def _apply_chord_incr(self, header, partial_args, group_id, body,
-                          result=None, options={}, **kwargs):
+    async def _apply_chord_incr(self, header: SignatureT,
+                                partial_args: Sequence,
+                                group_id: str,
+                                body: SignatureT,
+                                result: ResultT = None,
+                                options: Mapping = {},
+                                **kwargs) -> ResultT:
         self.ensure_chords_allowed()
-        self.save_group(group_id, self.app.GroupResult(group_id, result))
+        await self.save_group(group_id, self.app.GroupResult(group_id, result))
 
         fixed_options = {k: v for k, v in options.items() if k != 'task_id'}
 
-        return header(*partial_args, task_id=group_id, **fixed_options or {})
+        return await header(
+            *partial_args, task_id=group_id, **fixed_options or {})
 
-    def on_chord_part_return(self, request, state, result, **kwargs):
+    async def on_chord_part_return(
+            self,
+            request: RequestT, state: str, result: Any,
+            **kwargs) -> None:
         if not self.implements_incr:
             return
         app = self.app
@@ -691,25 +787,27 @@ class BaseKeyValueStoreBackend(Backend):
             return
         key = self.get_key_for_chord(gid)
         try:
-            deps = GroupResult.restore(gid, backend=self)
+            deps = await GroupResult.restore(gid, backend=self)
         except Exception as exc:  # pylint: disable=broad-except
             callback = maybe_signature(request.chord, app=app)
             logger.exception('Chord %r raised: %r', gid, exc)
-            return self.chord_error_from_stack(
+            await self.chord_error_from_stack(
                 callback,
                 ChordError('Cannot restore group: {0!r}'.format(exc)),
             )
+            return
         if deps is None:
             try:
                 raise ValueError(gid)
             except ValueError as exc:
                 callback = maybe_signature(request.chord, app=app)
                 logger.exception('Chord callback %r raised: %r', gid, exc)
-                return self.chord_error_from_stack(
+                await self.chord_error_from_stack(
                     callback,
                     ChordError('GroupResult {0} no longer exists'.format(gid)),
                 )
-        val = self.incr(key)
+                return
+        val = await self.incr(key)
         size = len(deps)
         if val > size:  # pragma: no cover
             logger.warning('Chord counter incremented too many times for %r',
@@ -719,7 +817,7 @@ class BaseKeyValueStoreBackend(Backend):
             j = deps.join_native if deps.supports_native_join else deps.join
             try:
                 with allow_join_result():
-                    ret = j(timeout=3.0, propagate=True)
+                    ret = await j(timeout=3.0, propagate=True)
             except Exception as exc:  # pylint: disable=broad-except
                 try:
                     culprit = next(deps._failed_join_report())
@@ -730,21 +828,21 @@ class BaseKeyValueStoreBackend(Backend):
                     reason = repr(exc)
 
                 logger.exception('Chord %r raised: %r', gid, reason)
-                self.chord_error_from_stack(callback, ChordError(reason))
+                await self.chord_error_from_stack(callback, ChordError(reason))
             else:
                 try:
-                    callback.delay(ret)
+                    await callback.delay(ret)
                 except Exception as exc:  # pylint: disable=broad-except
                     logger.exception('Chord %r raised: %r', gid, exc)
-                    self.chord_error_from_stack(
+                    await self.chord_error_from_stack(
                         callback,
                         ChordError('Callback error: {0!r}'.format(exc)),
                     )
             finally:
                 deps.delete()
-                self.client.delete(key)
+                await self.client.delete(key)
         else:
-            self.expire(key, self.expires)
+            await self.expire(key, self.expires)
 
 
 class KeyValueStoreBackend(BaseKeyValueStoreBackend, SyncBackendMixin):
@@ -756,16 +854,16 @@ class DisabledBackend(BaseBackend):
 
     _cache = {}   # need this attribute to reset cache in tests.
 
-    def store_result(self, *args, **kwargs):
-        pass
+    async def store_result(self, *args, **kwargs) -> Any:
+        ...
 
-    def ensure_chords_allowed(self):
+    def ensure_chords_allowed(self) -> None:
         raise NotImplementedError(E_CHORD_NO_BACKEND.strip())
 
-    def _is_disabled(self, *args, **kwargs):
+    def _is_disabled(self, *args, **kwargs) -> bool:
         raise NotImplementedError(E_NO_BACKEND.strip())
 
-    def as_uri(self, *args, **kwargs):
+    def as_uri(self, *args, **kwargs) -> str:
         return 'disabled://'
 
     get_state = get_result = get_traceback = _is_disabled

+ 71 - 47
celery/backends/rpc.py

@@ -6,12 +6,17 @@ RPC-style result backend, using reply-to and one queue per client.
 import kombu
 import time
 
+from typing import Any, Dict, Iterator, Mapping, Set, Tuple, Union
+
 from kombu.common import maybe_declare
+from kombu.types import ChannelT, ConnectionT, EntityT, MessageT, ProducerT
 from kombu.utils.compat import register_after_fork
 from kombu.utils.objects import cached_property
 
 from celery import states
 from celery._state import current_task, task_join_will_block
+from celery.types import AppT, BackendT, ResultT, RequestT
+from celery.result import GroupResult
 
 from . import base
 from .async import AsyncBackendMixin, BaseResultConsumer
@@ -32,21 +37,23 @@ class BacklogLimitExceeded(Exception):
     """Too much state history to fast-forward."""
 
 
-def _on_after_fork_cleanup_backend(backend):
+def _on_after_fork_cleanup_backend(backend: BackendT) -> None:
     backend._after_fork()
 
 
 class ResultConsumer(BaseResultConsumer):
     Consumer = kombu.Consumer
 
-    _connection = None
-    _consumer = None
+    _connection: ConnectionT = None
+    _consumer: ConsumerT = None
 
-    def __init__(self, *args, **kwargs):
+    def __init__(self, *args, **kwargs) -> None:
         super().__init__(*args, **kwargs)
         self._create_binding = self.backend._create_binding
 
-    def start(self, initial_task_id, no_ack=True, **kwargs):
+    def start(self, initial_task_id: str,
+              *,
+              no_ack: bool = True, **kwargs) -> None:
         self._connection = self.app.connection()
         initial_queue = self._create_binding(initial_task_id)
         self._consumer = self.Consumer(
@@ -55,33 +62,34 @@ class ResultConsumer(BaseResultConsumer):
             accept=self.accept)
         self._consumer.consume()
 
-    def drain_events(self, timeout=None):
+    def drain_events(self, timeout: float = None) -> None:
         if self._connection:
-            return self._connection.drain_events(timeout=timeout)
+            self._connection.drain_events(timeout=timeout)
         elif timeout:
             time.sleep(timeout)
 
-    def stop(self):
+    def stop(self) -> None:
         try:
             self._consumer.cancel()
         finally:
             self._connection.close()
 
-    def on_after_fork(self):
+    def on_after_fork(self) -> None:
         self._consumer = None
         if self._connection is not None:
             self._connection.collect()
             self._connection = None
 
-    def consume_from(self, task_id):
+    def consume_from(self, task_id: str) -> None:
         if self._consumer is None:
-            return self.start(task_id)
-        queue = self._create_binding(task_id)
-        if not self._consumer.consuming_from(queue):
-            self._consumer.add_queue(queue)
-            self._consumer.consume()
+            self.start(task_id)
+        else:
+            queue = self._create_binding(task_id)
+            if not self._consumer.consuming_from(queue):
+                self._consumer.add_queue(queue)
+                self._consumer.consume()
 
-    def cancel_for(self, task_id):
+    def cancel_for(self, task_id: str) -> None:
         if self._consumer:
             self._consumer.cancel_by_queue(self._create_binding(task_id).name)
 
@@ -117,8 +125,14 @@ class RPCBackend(base.Backend, AsyncBackendMixin):
 
         can_cache_declaration = False
 
-    def __init__(self, app, connection=None, exchange=None, exchange_type=None,
-                 persistent=None, serializer=None, auto_delete=True, **kwargs):
+    def __init__(self, app: AppT,
+                 connection: ConnectionT = None,
+                 exchange: str = None,
+                 exchange_type: str = None,
+                 persistent: bool = None,
+                 serializer: str = None,
+                 auto_delete: bool = True,
+                 **kwargs) -> None:
         super().__init__(app, **kwargs)
         conf = self.app.conf
         self._connection = connection
@@ -139,32 +153,36 @@ class RPCBackend(base.Backend, AsyncBackendMixin):
         if register_after_fork is not None:
             register_after_fork(self, _on_after_fork_cleanup_backend)
 
-    def _after_fork(self):
+    def _after_fork(self) -> None:
         # clear state for child processes.
         self._pending_results.clear()
         self.result_consumer._after_fork()
 
-    def _create_exchange(self, name, type='direct', delivery_mode=2):
+    def _create_exchange(self, name: str,
+                         type: str = 'direct',
+                         delivery_mode: Union[int, str] = 2) -> Exchange:
         # uses direct to queue routing (anon exchange).
         return self.Exchange(None)
 
-    def _create_binding(self, task_id):
+    def _create_binding(self, task_id: str) -> EntityT:
         """Create new binding for task with id."""
         # RPC backend caches the binding, as one queue is used for all tasks.
         return self.binding
 
-    def ensure_chords_allowed(self):
+    def ensure_chords_allowed(self) -> None:
         raise NotImplementedError(E_NO_CHORD_SUPPORT.strip())
 
-    def on_task_call(self, producer, task_id):
+    def on_task_call(self, producer: ProducerT, task_id: str) -> Mapping:
         # Called every time a task is sent when using this backend.
         # We declare the queue we receive replies on in advance of sending
         # the message, but we skip this if running in the prefork pool
         # (task_join_will_block), as we know the queue is already declared.
         if not task_join_will_block():
             maybe_declare(self.binding(producer.channel), retry=True)
+        return {}
 
-    def destination_for(self, task_id, request):
+    def destination_for(self,
+                        task_id: str, request: RequestT) -> Tuple[str, str]:
         """Get the destination for result by task id.
 
         Returns:
@@ -179,22 +197,23 @@ class RPCBackend(base.Backend, AsyncBackendMixin):
                 'RPC backend missing task request for {0!r}'.format(task_id))
         return request.reply_to, request.correlation_id or task_id
 
-    def on_reply_declare(self, task_id):
+    def on_reply_declare(self, task_id: str) -> None:
         # Return value here is used as the `declare=` argument
         # for Producer.publish.
         # By default we don't have to declare anything when sending a result.
-        pass
+        ...
 
-    def on_result_fulfilled(self, result):
+    def on_result_fulfilled(self, result: ResultT) -> None:
         # This usually cancels the queue after the result is received,
         # but we don't have to cancel since we have one queue per process.
-        pass
+        ...
 
-    def as_uri(self, include_password=True):
+    def as_uri(self, include_password: bool = True) -> str:
         return 'rpc://'
 
-    def store_result(self, task_id, result, state,
-                     traceback=None, request=None, **kwargs):
+    def store_result(self, task_id: str, result: Any, state: str,
+                     traceback: str = None, request: RequestT = None,
+                     **kwargs) -> Any:
         """Send task return value and state."""
         routing_key, correlation_id = self.destination_for(task_id, request)
         if not routing_key:
@@ -212,7 +231,9 @@ class RPCBackend(base.Backend, AsyncBackendMixin):
             )
         return result
 
-    def _to_result(self, task_id, state, result, traceback, request):
+    def _to_result(self,
+                   task_id: str, state: str, result: Any,
+                   traceback: str, request: RequestT) -> Mapping:
         return {
             'task_id': task_id,
             'status': state,
@@ -221,7 +242,7 @@ class RPCBackend(base.Backend, AsyncBackendMixin):
             'children': self.current_task_children(request),
         }
 
-    def on_out_of_band_result(self, task_id, message):
+    def on_out_of_band_result(self, task_id: str, message: MessageT) -> None:
         # Callback called when a reply for a task is received,
         # but we have no idea what do do with it.
         # Since the result is not pending, we put it in a separate
@@ -230,7 +251,8 @@ class RPCBackend(base.Backend, AsyncBackendMixin):
             self.result_consumer.on_out_of_band_result(message)
         self._out_of_band[task_id] = message
 
-    def get_task_meta(self, task_id, backlog_limit=1000):
+    def get_task_meta(self, task_id: str,
+                      *, backlog_limit: int = 1000) -> Mapping:
         buffered = self._out_of_band.pop(task_id, None)
         if buffered:
             return self._set_cache_by_message(task_id, buffered)
@@ -262,13 +284,15 @@ class RPCBackend(base.Backend, AsyncBackendMixin):
                 # result probably pending.
                 return {'status': states.PENDING, 'result': None}
 
-    def _set_cache_by_message(self, task_id, message):
+    def _set_cache_by_message(self,
+                              task_id: str, message: MessageT) -> Mapping:
         payload = self._cache[task_id] = self.meta_from_decoded(
             message.payload)
         return payload
 
-    def _slurp_from_queue(self, task_id, accept,
-                          limit=1000, no_ack=False):
+    def _slurp_from_queue(self, task_id: str, accept: Set[str],
+                          limit: int = 1000,
+                          no_ack: bool = False) -> Iterator[MessageT]:
         with self.app.pool.acquire_channel(block=True) as (_, channel):
             binding = self._create_binding(task_id)(channel)
             binding.declare()
@@ -281,7 +305,7 @@ class RPCBackend(base.Backend, AsyncBackendMixin):
             else:
                 raise self.BacklogLimitExceeded(task_id)
 
-    def _get_message_task_id(self, message):
+    def _get_message_task_id(self, message: MessageT) -> str:
         try:
             # try property first so we don't have to deserialize
             # the payload.
@@ -290,10 +314,10 @@ class RPCBackend(base.Backend, AsyncBackendMixin):
             # message sent by old Celery version, need to deserialize.
             return message.payload['task_id']
 
-    def revive(self, channel):
-        pass
+    def revive(self, channel: ChannelT) -> None:
+        ...
 
-    def reload_task_result(self, task_id):
+    def reload_task_result(self, task_id: str) -> None:
         raise NotImplementedError(
             'reload_task_result is not supported by this backend.')
 
@@ -302,19 +326,19 @@ class RPCBackend(base.Backend, AsyncBackendMixin):
         raise NotImplementedError(
             'reload_group_result is not supported by this backend.')
 
-    def save_group(self, group_id, result):
+    def save_group(self, group_id: str, result: GroupResult) -> None:
         raise NotImplementedError(
             'save_group is not supported by this backend.')
 
-    def restore_group(self, group_id, cache=True):
+    def restore_group(self, group_id: str, cache: bool = True) -> GroupResult:
         raise NotImplementedError(
             'restore_group is not supported by this backend.')
 
-    def delete_group(self, group_id):
+    def delete_group(self, group_id: str) -> None:
         raise NotImplementedError(
             'delete_group is not supported by this backend.')
 
-    def __reduce__(self, args=(), kwargs={}):
+    def __reduce__(self, args: Tuple = (), kwargs: Dict = {}) -> Tuple:
         return super().__reduce__(args, dict(
             kwargs,
             connection=self._connection,
@@ -327,7 +351,7 @@ class RPCBackend(base.Backend, AsyncBackendMixin):
         ))
 
     @property
-    def binding(self):
+    def binding(self) -> EntityT:
         return self.Queue(
             self.oid, self.exchange, self.oid,
             durable=False,
@@ -336,6 +360,6 @@ class RPCBackend(base.Backend, AsyncBackendMixin):
         )
 
     @cached_property
-    def oid(self):
+    def oid(self) -> str:
         # cached here is the app OID: name of queue we receive results on.
         return self.app.oid

+ 9 - 3
celery/beat.py

@@ -8,10 +8,10 @@ import shelve
 import sys
 import traceback
 
-from collections import namedtuple
 from functools import total_ordering
 from threading import Event, Thread
 from time import monotonic
+from typing import NamedTuple
 
 from billiard import ensure_multiprocessing
 from billiard.context import Process
@@ -32,8 +32,6 @@ __all__ = [
     'PersistentScheduler', 'Service', 'EmbeddedService',
 ]
 
-event_t = namedtuple('event_t', ('time', 'priority', 'entry'))
-
 logger = get_logger(__name__)
 debug, info, error, warning = (logger.debug, logger.info,
                                logger.error, logger.warning)
@@ -41,6 +39,14 @@ debug, info, error, warning = (logger.debug, logger.info,
 DEFAULT_MAX_INTERVAL = 300  # 5 minutes
 
 
+class event_t(NamedTuple):
+    """Represents beat event in heap."""
+
+    time: float
+    priority: int
+    entry: 'ScheduleEntry'
+
+
 class SchedulingError(Exception):
     """An error occurred while scheduling a task."""
 

+ 48 - 47
celery/bootsteps.py

@@ -122,24 +122,24 @@ class StartStopStep(Step):
     #: Optional obj created by the :meth:`create` method.
     #: This is used by :class:`StartStopStep` to keep the
     #: original service object.
-    obj = None
+    obj: Any = None
 
-    def start(self, parent):
+    async def start(self, parent: Any) -> None:
         if self.obj:
-            return self.obj.start()
+            await self.obj.start()
 
-    def stop(self, parent):
+    async def stop(self, parent: Any) -> None:
         if self.obj:
-            return self.obj.stop()
+            await self.obj.stop()
 
-    def close(self, parent):
-        pass
+    async def close(self, parent: Any) -> None:
+        ...
 
-    def terminate(self, parent):
+    async def terminate(self, parent: Any) -> None:
         if self.obj:
-            return getattr(self.obj, 'terminate', self.obj.stop)()
+            await getattr(self.obj, 'terminate', self.obj.stop)()
 
-    def include(self, parent):
+    def include(self, parent: Any) -> bool:
         inc, ret = self._should_include(parent)
         if inc:
             self.obj = ret
@@ -156,27 +156,27 @@ class ConsumerStep(StartStopStep):
     def get_consumers(self, channel):
         raise NotImplementedError('missing get_consumers')
 
-    def start(self, c):
+    async def start(self, c):
         channel = c.connection.channel()
         self.consumers = self.get_consumers(channel)
         for consumer in self.consumers or []:
             consumer.consume()
 
-    def stop(self, c):
-        self._close(c, True)
+    async def stop(self, c):
+        await self._close(c, True)
 
-    def shutdown(self, c):
-        self._close(c, False)
+    async def shutdown(self, c):
+        await self._close(c, False)
 
-    def _close(self, c, cancel_consumers=True):
+    async def _close(self, c, cancel_consumers=True):
         channels = set()
         for consumer in self.consumers or []:
             if cancel_consumers:
-                ignore_errors(c.connection, consumer.cancel)
+                await ignore_errors(c.connection, consumer.cancel)
             if consumer.channel:
                 channels.add(consumer.channel)
         for channel in channels:
-            ignore_errors(c.connection, channel.close)
+            await ignore_errors(c.connection, channel.close)
 
 
 def _pre(ns: Step, fmt: str) -> str:
@@ -235,11 +235,11 @@ class Blueprint:
 
     GraphFormatter = StepFormatter
 
-    name = None                        # type: str
-    state = None                       # type: int
-    started = 0                        # type: int
-    default_steps = set()              # type: Set[Union[str, Step]]
-    state_to_name = {                  # type: Mapping[int, str]
+    name: str = None
+    state: int = None
+    started: int = 0
+    default_steps: Set[Union[str, Step]] = set()
+    state_to_name: Mapping[int, str] = {
         0: 'initializing',
         RUN: 'running',
         CLOSE: 'closing',
@@ -257,16 +257,16 @@ class Blueprint:
         self.on_close = on_close
         self.on_stopped = on_stopped
         self.shutdown_complete = Event()
-        self.steps = {} : Mapping[str, Step]
+        self.steps: Mapping[str, Step] = {}
 
-    def start(self, parent: Any) -> None:
+    async def start(self, parent: Any) -> None:
         self.state = RUN
         if self.on_start:
             self.on_start()
         for i, step in enumerate(s for s in parent.steps if s is not None):
             self._debug('Starting %s', step.alias)
             self.started = i + 1
-            step.start(parent)
+            await step.start(parent)
             logger.debug('^-- substep ok')
 
     def human_state(self) -> str:
@@ -278,23 +278,24 @@ class Blueprint:
             info.update(step.info(parent) or {})
         return info
 
-    def close(self, parent: Any) -> None:
+    async def close(self, parent: Any) -> None:
         if self.on_close:
             self.on_close()
-        self.send_all(parent, 'close', 'closing', reverse=False)
-
-    def restart(self,
-                parent: Any,
-                method: str = 'stop',
-                description: str = 'restarting',
-                propagate: bool = False) -> None:
-        self.send_all(parent, method, description, propagate=propagate)
-
-    def send_all(self, parent: Any, method: str,
-                 description: str = None,
-                 reverse: bool = True,
-                 propagate: bool = True,
-                 args: Sequence = ()) -> None:
+        await self.send_all(parent, 'close', 'closing', reverse=False)
+
+    async def restart(
+            self,
+            parent: Any,
+            method: str = 'stop',
+            description: str = 'restarting',
+            propagate: bool = False) -> None:
+        await self.send_all(parent, method, description, propagate=propagate)
+
+    async def send_all(self, parent: Any, method: str,
+                       description: str = None,
+                       reverse: bool = True,
+                       propagate: bool = True,
+                       args: Sequence = ()) -> None:
         description = description or method.replace('_', ' ')
         steps = reversed(parent.steps) if reverse else parent.steps
         for step in steps:
@@ -304,16 +305,16 @@ class Blueprint:
                     self._debug('%s %s...',
                                 description.capitalize(), step.alias)
                     try:
-                        fun(parent, *args)
+                        await fun(parent, *args)
                     except Exception as exc:  # pylint: disable=broad-except
                         if propagate:
                             raise
                         logger.exception(
                             'Error on %s %s: %r', description, step.alias, exc)
 
-    def stop(self, parent: Any,
-             close: bool = True,
-             terminate: bool = False) -> None:
+    async def stop(self, parent: Any,
+                   close: bool = True,
+                   terminate: bool = False) -> None:
         what = 'terminating' if terminate else 'stopping'
         if self.state in (CLOSE, TERMINATE):
             return
@@ -323,10 +324,10 @@ class Blueprint:
             self.state = TERMINATE
             self.shutdown_complete.set()
             return
-        self.close(parent)
+        await self.close(parent)
         self.state = CLOSE
 
-        self.restart(
+        await self.restart(
             parent, 'terminate' if terminate else 'stop',
             description=what, propagate=False,
         )

+ 254 - 125
celery/canvas.py

@@ -12,14 +12,21 @@ from collections import MutableSequence, deque
 from copy import deepcopy
 from functools import partial as _partial, reduce
 from operator import itemgetter
+from typing import (
+    Any, Callable, Dict, Iterable, Iterator,
+    List, Mapping, Sequence, Tuple, Optional,
+    cast,
+)
 
+from kombu.types import ConnectionT, ProducerT
 from kombu.utils.functional import fxrange, reprcall
 from kombu.utils.objects import cached_property
 from kombu.utils.uuid import uuid
-from vine import barrier
+from vine import Thenable, barrier
 
 from celery._state import current_app
 from celery.result import GroupResult
+from celery.types import AppT, ResultT, RouterT, SignatureT, TaskT
 from celery.utils import abstract
 from celery.utils.functional import (
     maybe_list, is_list, _regen, regen, chunks as _chunks,
@@ -34,7 +41,7 @@ __all__ = [
 ]
 
 
-def maybe_unroll_group(g):
+def maybe_unroll_group(g: SignatureT) -> Any:
     """Unroll group with only one member."""
     # Issue #1656
     try:
@@ -50,11 +57,11 @@ def maybe_unroll_group(g):
         return g.tasks[0] if size == 1 else g
 
 
-def task_name_from(task):
+def task_name_from(task: Any) -> str:
     return getattr(task, 'name', task)
 
 
-def _upgrade(fields, sig):
+def _upgrade(fields: Mapping, sig: SignatureT) -> SignatureT:
     """Used by custom signatures in .from_dict, to keep common fields."""
     sig.update(chord_size=fields.get('chord_size'))
     return sig
@@ -118,17 +125,18 @@ class Signature(dict):
     """
 
     TYPES = {}
-    _app = _type = None
+    _app: AppT = None
+    _type: TaskT = None
 
     @classmethod
-    def register_type(cls, name=None):
-        def _inner(subclass):
+    def register_type(cls, name: str = None) -> Callable:
+        def _inner(subclass: type) -> type:
             cls.TYPES[name or subclass.__name__] = subclass
             return subclass
         return _inner
 
     @classmethod
-    def from_dict(cls, d, app=None):
+    def from_dict(cls, d: Mapping, app: AppT = None) -> SignatureT:
         typ = d.get('subtask_type')
         if typ:
             target_cls = cls.TYPES[typ]
@@ -136,9 +144,16 @@ class Signature(dict):
                 return target_cls.from_dict(d, app=app)
         return Signature(d, app=app)
 
-    def __init__(self, task=None, args=None, kwargs=None, options=None,
-                 type=None, subtask_type=None, immutable=False,
-                 app=None, **ex):
+    def __init__(self,
+                 task: str = None,
+                 args: Sequence = None,
+                 kwargs: Mapping = None,
+                 options: Mapping = None,
+                 type: TaskT = None,
+                 subtask_type: str = None,
+                 immutable: bool = False,
+                 app: AppT = None,
+                 **ex) -> None:
         self._app = app
 
         if isinstance(task, dict):
@@ -161,16 +176,17 @@ class Signature(dict):
                 chord_size=None,
             )
 
-    def __call__(self, *partial_args, **partial_kwargs):
+    def __call__(self, *partial_args, **partial_kwargs) -> Any:
         """Call the task directly (in the current process)."""
         args, kwargs, _ = self._merge(partial_args, partial_kwargs, None)
         return self.type(*args, **kwargs)
 
-    def delay(self, *partial_args, **partial_kwargs):
+    def delay(self, *partial_args, **partial_kwargs) -> ResultT:
         """Shortcut to :meth:`apply_async` using star arguments."""
         return self.apply_async(partial_args, partial_kwargs)
 
-    def apply(self, args=(), kwargs={}, **options):
+    def apply(self, args: Sequence = (), kwargs: Mapping = {},
+              **options) -> ResultT:
         """Call task locally.
 
         Same as :meth:`apply_async` but executed the task inline instead
@@ -180,7 +196,11 @@ class Signature(dict):
         args, kwargs, options = self._merge(args, kwargs, options)
         return self.type.apply(args, kwargs, **options)
 
-    def apply_async(self, args=(), kwargs={}, route_name=None, **options):
+    def apply_async(self,
+                    args: Sequence = (),
+                    kwargs: Mapping = {},
+                    route_name: str = None,
+                    **options) -> None:
         """Apply this task asynchronously.
 
         Arguments:
@@ -209,7 +229,10 @@ class Signature(dict):
         #   Borks on this, as it's a property
         return _apply(args, kwargs, **options)
 
-    def _merge(self, args=(), kwargs={}, options={}, force=False):
+    def _merge(self, args: Sequence = (),
+               kwargs: Mapping = {},
+               options: Mapping = {},
+               force: bool = False) -> Tuple[Tuple, Dict, Dict]:
         if self.immutable and not force:
             return (self.args, self.kwargs,
                     dict(self.options, **options) if options else self.options)
@@ -217,7 +240,10 @@ class Signature(dict):
                 dict(self.kwargs, **kwargs) if kwargs else self.kwargs,
                 dict(self.options, **options) if options else self.options)
 
-    def clone(self, args=(), kwargs={}, **opts):
+    def clone(self,
+              args: Sequence = (),
+              kwargs: Mapping = {},
+              **opts) -> SignatureT:
         """Create a copy of this signature.
 
         Arguments:
@@ -240,8 +266,13 @@ class Signature(dict):
         return s
     partial = clone
 
-    def freeze(self, _id=None, group_id=None, chord=None,
-               root_id=None, parent_id=None):
+    def freeze(self,
+               _id: str = None,
+               *,
+               group_id: str = None,
+               chord: SignatureT = None,
+               root_id: str = None,
+               parent_id: str = None) -> ResultT:
         """Finalize the signature by adding a concrete task id.
 
         The task won't be called and you shouldn't call the signature
@@ -273,7 +304,10 @@ class Signature(dict):
         return self.AsyncResult(tid)
     _freeze = freeze
 
-    def replace(self, args=None, kwargs=None, options=None):
+    def replace(self,
+                args: Sequence = None,
+                kwargs: Mapping = None,
+                options: Mapping = None) -> SignatureT:
         """Replace the args, kwargs or options set for this signature.
 
         These are only replaced if the argument for the section is
@@ -288,7 +322,7 @@ class Signature(dict):
             s.options = options
         return s
 
-    def set(self, immutable=None, **options):
+    def set(self, immutable: bool = None, **options) -> SignatureT:
         """Set arbitrary execution options (same as ``.options.update(…)``).
 
         Returns:
@@ -300,44 +334,45 @@ class Signature(dict):
         self.options.update(options)
         return self
 
-    def set_immutable(self, immutable):
+    def set_immutable(self, immutable: bool) -> None:
         self.immutable = immutable
 
-    def _with_list_option(self, key):
+    def _with_list_option(self, key: str) -> Any:
         items = self.options.setdefault(key, [])
         if not isinstance(items, MutableSequence):
             items = self.options[key] = [items]
         return items
 
-    def append_to_list_option(self, key, value):
+    def append_to_list_option(self, key: str, value: Any) -> Any:
         items = self._with_list_option(key)
         if value not in items:
             items.append(value)
         return value
 
-    def extend_list_option(self, key, value):
+    def extend_list_option(self, key: str, value: Any) -> None:
         items = self._with_list_option(key)
         items.extend(maybe_list(value))
 
-    def link(self, callback):
+    def link(self, callback: SignatureT) -> SignatureT:
         """Add callback task to be applied if this task succeeds.
 
         Returns:
             Signature: the argument passed, for chaining
                 or use with :func:`~functools.reduce`.
         """
-        return self.append_to_list_option('link', callback)
+        return cast(SignatureT, self.append_to_list_option('link', callback))
 
-    def link_error(self, errback):
+    def link_error(self, errback: SignatureT) -> SignatureT:
         """Add callback task to be applied on error in task execution.
 
         Returns:
             Signature: the argument passed, for chaining
                 or use with :func:`~functools.reduce`.
         """
-        return self.append_to_list_option('link_error', errback)
+        return cast(SignatureT,
+                    self.append_to_list_option('link_error', errback))
 
-    def on_error(self, errback):
+    def on_error(self, errback: SignatureT) -> SignatureT:
         """Version of :meth:`link_error` that supports chaining.
 
         on_error chains the original signature, not the errback so::
@@ -350,7 +385,7 @@ class Signature(dict):
         self.link_error(errback)
         return self
 
-    def flatten_links(self):
+    def flatten_links(self) -> List[SignatureT]:
         """Return a recursive list of dependencies.
 
         "unchain" if you will, but with links intact.
@@ -361,7 +396,7 @@ class Signature(dict):
                 for link in maybe_list(self.options.get('link')) or [])
         )))
 
-    def __or__(self, other):
+    def __or__(self, other: SignatureT) -> SignatureT:
         # These could be implemented in each individual class,
         # I'm sure, but for now we have this.
         if isinstance(other, chord) and len(other.tasks) == 1:
@@ -431,7 +466,7 @@ class Signature(dict):
             return _chain(self, other, app=self._app)
         return NotImplemented
 
-    def election(self):
+    def election(self) -> ResultT:
         type = self.type
         app = type.app
         tid = self.options.get('task_id') or uuid()
@@ -442,50 +477,50 @@ class Signature(dict):
                                  connection=P.connection)
             return type.AsyncResult(tid)
 
-    def reprcall(self, *args, **kwargs):
+    def reprcall(self, *args, **kwargs) -> str:
         args, kwargs, _ = self._merge(args, kwargs, {}, force=True)
         return reprcall(self['task'], args, kwargs)
 
-    def __deepcopy__(self, memo):
+    def __deepcopy__(self, memo) -> SignatureT:
         memo[id(self)] = self
         return dict(self)
 
-    def __invert__(self):
+    def __invert__(self) -> Any:
         return self.apply_async().get()
 
-    def __reduce__(self):
+    def __reduce__(self) -> Tuple:
         # for serialization, the task type is lazily loaded,
         # and not stored in the dict itself.
         return signature, (dict(self),)
 
-    def __json__(self):
+    def __json__(self) -> Any:
         return dict(self)
 
-    def __repr__(self):
+    def __repr__(self) -> str:
         return self.reprcall()
 
     @property
-    def name(self):
+    def name(self) -> str:
         # for duck typing compatibility with Task.name
         return self.task
 
     @cached_property
-    def type(self):
+    def type(self) -> TaskT:
         return self._type or self.app.tasks[self['task']]
 
     @cached_property
-    def app(self):
+    def app(self) -> AppT:
         return self._app or current_app
 
     @cached_property
-    def AsyncResult(self):
+    def AsyncResult(self) -> type:
         try:
             return self.type.AsyncResult
         except KeyError:  # task not registered
             return self.app.AsyncResult
 
     @cached_property
-    def _apply_async(self):
+    def _apply_async(self) -> Callable:
         try:
             return self.type.apply_async
         except KeyError:
@@ -509,7 +544,7 @@ class _chain(Signature):
     tasks = getitem_property('kwargs.tasks', 'Tasks in chain.')
 
     @classmethod
-    def from_dict(cls, d, app=None):
+    def from_dict(cls, d: Mapping, app: AppT = None) -> SignatureT:
         tasks = d['kwargs']['tasks']
         if tasks:
             if isinstance(tasks, tuple):  # aaaargh
@@ -518,7 +553,7 @@ class _chain(Signature):
             tasks[0] = maybe_signature(tasks[0], app=app)
         return _upgrade(d, _chain(tasks, app=app, **d['options']))
 
-    def __init__(self, *tasks, **options):
+    def __init__(self, *tasks, **options) -> None:
         tasks = (regen(tasks[0]) if len(tasks) == 1 and is_list(tasks[0])
                  else tasks)
         Signature.__init__(
@@ -528,11 +563,11 @@ class _chain(Signature):
         self.subtask_type = 'chain'
         self._frozen = None
 
-    def __call__(self, *args, **kwargs):
+    def __call__(self, *args, **kwargs) -> Any:
         if self.tasks:
             return self.apply_async(args, kwargs)
 
-    def clone(self, *args, **kwargs):
+    def clone(self, *args, **kwargs) -> SignatureT:
         to_signature = maybe_signature
         s = Signature.clone(self, *args, **kwargs)
         s.kwargs['tasks'] = [
@@ -541,7 +576,10 @@ class _chain(Signature):
         ]
         return s
 
-    def apply_async(self, args=(), kwargs={}, **options):
+    def apply_async(self,
+                    args: Sequence = (),
+                    kwargs: Mapping = {},
+                    **options) -> ResultT:
         # python is best at unpacking kwargs, so .run is here to do that.
         app = self.app
         if app.conf.task_always_eager:
@@ -549,9 +587,19 @@ class _chain(Signature):
         return self.run(args, kwargs, app=app, **(
             dict(self.options, **options) if options else self.options))
 
-    def run(self, args=(), kwargs={}, group_id=None, chord=None,
-            task_id=None, link=None, link_error=None,
-            producer=None, root_id=None, parent_id=None, app=None, **options):
+    def run(self,
+            args: Sequence = (),
+            kwargs: Mapping = {},
+            group_id: str = None,
+            chord: SignatureT = None,
+            task_id: str = None,
+            link: Sequence[SignatureT] = None,
+            link_error: Sequence[SignatureT] = None,
+            producer: ProducerT = None,
+            root_id: str = None,
+            parent_id: str = None,
+            app: AppT = None,
+            **options) -> ResultT:
         # pylint: disable=redefined-outer-name
         #   XXX chord is also a class in outer scope.
         app = app or self.app
@@ -580,8 +628,13 @@ class _chain(Signature):
             first_task.apply_async(**options)
             return results[0]
 
-    def freeze(self, _id=None, group_id=None, chord=None,
-               root_id=None, parent_id=None):
+    def freeze(self,
+               _id: str = None,
+               *,
+               group_id: str = None,
+               chord: SignatureT = None,
+               root_id: str = None,
+               parent_id: str = None) -> ResultT:
         # pylint: disable=redefined-outer-name
         #   XXX chord is also a class in outer scope.
         _, results = self._frozen = self.prepare_steps(
@@ -590,10 +643,17 @@ class _chain(Signature):
         )
         return results[0]
 
-    def prepare_steps(self, args, tasks,
-                      root_id=None, parent_id=None, link_error=None, app=None,
-                      last_task_id=None, group_id=None, chord_body=None,
-                      clone=True, from_dict=Signature.from_dict):
+    def prepare_steps(self, args: Sequence, tasks: Sequence[SignatureT],
+                      root_id: str = None,
+                      parent_id: str = None,
+                      link_error: Sequence[SignatureT] = None,
+                      app: AppT = None,
+                      last_task_id: str = None,
+                      group_id: str = None,
+                      chord_body: SignatureT = None,
+                      clone: bool = True,
+                      from_dict: Callable = Signature.from_dict
+                      ) -> Tuple[Sequence[SignatureT], Sequence[ResultT]]:
         app = app or self.app
         # use chain message field for protocol 2 and later.
         # this avoids pickle blowing the stack on the recursion
@@ -693,7 +753,10 @@ class _chain(Signature):
                 prev_res = node
         return tasks, results
 
-    def apply(self, args=(), kwargs={}, **options):
+    def apply(self,
+              args: Sequence = (),
+              kwargs: Mapping = {},
+              **options) -> ResultT:
         last, fargs = None, args
         for task in self.tasks:
             res = task.clone(fargs).apply(
@@ -702,7 +765,7 @@ class _chain(Signature):
         return last
 
     @property
-    def app(self):
+    def app(self) -> AppT:
         app = self._app
         if app is None:
             try:
@@ -711,7 +774,7 @@ class _chain(Signature):
                 pass
         return app or current_app
 
-    def __repr__(self):
+    def __repr__(self) -> str:
         if not self.tasks:
             return '<{0}@{1:#x}: empty>'.format(
                 type(self).__name__, id(self))
@@ -769,7 +832,7 @@ class chain(_chain):
     """
 
     # could be function, but must be able to reference as :class:`chain`.
-    def __new__(cls, *tasks, **kwargs):
+    def __new__(cls, *tasks, **kwargs) -> SignatureT:
         # This forces `chain(X, Y, Z)` to work the same way as `X | Y | Z`
         if not kwargs and tasks:
             if len(tasks) == 1 and is_list(tasks[0]):
@@ -784,18 +847,21 @@ class _basemap(Signature):
     _unpack_args = itemgetter('task', 'it')
 
     @classmethod
-    def from_dict(cls, d, app=None):
+    def from_dict(cls, d: Mapping, app: AppT = None) -> SignatureT:
         return _upgrade(
             d, cls(*cls._unpack_args(d['kwargs']), app=app, **d['options']),
         )
 
-    def __init__(self, task, it, **options):
+    def __init__(self, task: str, it: Iterable, **options) -> None:
         Signature.__init__(
             self, self._task_name, (),
             {'task': task, 'it': regen(it)}, immutable=True, **options
         )
 
-    def apply_async(self, args=(), kwargs={}, **opts):
+    def apply_async(self,
+                    args: Sequence = (),
+                    kwargs: Mapping = {},
+                    **opts) -> ResultT:
         # need to evaluate generators
         task, it = self._unpack_args(self.kwargs)
         return self.type.apply_async(
@@ -815,7 +881,7 @@ class xmap(_basemap):
 
     _task_name = 'celery.map'
 
-    def __repr__(self):
+    def __repr__(self) -> str:
         task, it = self._unpack_args(self.kwargs)
         return '[{0}(x) for x in {1}]'.format(
             task.task, truncate(repr(it), 100))
@@ -827,7 +893,7 @@ class xstarmap(_basemap):
 
     _task_name = 'celery.starmap'
 
-    def __repr__(self):
+    def __repr__(self) -> str:
         task, it = self._unpack_args(self.kwargs)
         return '[{0}(*x) for x in {1}]'.format(
             task.task, truncate(repr(it), 100))
@@ -840,29 +906,32 @@ class chunks(Signature):
     _unpack_args = itemgetter('task', 'it', 'n')
 
     @classmethod
-    def from_dict(cls, d, app=None):
+    def from_dict(cls, d: Mapping, app: AppT = None) -> SignatureT:
         return _upgrade(
             d, chunks(*cls._unpack_args(
                 d['kwargs']), app=app, **d['options']),
         )
 
-    def __init__(self, task, it, n, **options):
+    def __init__(self, task: str, it: Iterable, n: int, **options) -> None:
         Signature.__init__(
             self, 'celery.chunks', (),
             {'task': task, 'it': regen(it), 'n': n},
             immutable=True, **options
         )
 
-    def __call__(self, **options):
+    def __call__(self, **options) -> ResultT:
         return self.apply_async(**options)
 
-    def apply_async(self, args=(), kwargs={}, **opts):
+    def apply_async(self,
+                    args: Sequence = (),
+                    kwargs: Mapping = {},
+                    **opts) -> ResultT:
         return self.group().apply_async(
             args, kwargs,
             route_name=task_name_from(self.kwargs.get('task')), **opts
         )
 
-    def group(self):
+    def group(self) -> SignatureT:
         # need to evaluate generators
         task, it, n = self._unpack_args(self.kwargs)
         return group((xstarmap(task, part, app=self._app)
@@ -870,11 +939,12 @@ class chunks(Signature):
                      app=self._app)
 
     @classmethod
-    def apply_chunks(cls, task, it, n, app=None):
+    def apply_chunks(cls, task: str, it: Iterable, n: int,
+                     app: AppT = None) -> ResultT:
         return cls(task, it, n, app=app)()
 
 
-def _maybe_group(tasks, app):
+def _maybe_group(tasks: Any, app: AppT) -> Sequence[SignatureT]:
     if isinstance(tasks, dict):
         tasks = signature(tasks, app=app)
 
@@ -921,12 +991,12 @@ class group(Signature):
     tasks = getitem_property('kwargs.tasks', 'Tasks in group.')
 
     @classmethod
-    def from_dict(cls, d, app=None):
+    def from_dict(cls, d: Mapping, app: AppT = None) -> SignatureT:
         return _upgrade(
             d, group(d['kwargs']['tasks'], app=app, **d['options']),
         )
 
-    def __init__(self, *tasks, **options):
+    def __init__(self, *tasks, **options) -> None:
         if len(tasks) == 1:
             tasks = tasks[0]
             if isinstance(tasks, group):
@@ -938,17 +1008,27 @@ class group(Signature):
         )
         self.subtask_type = 'group'
 
-    def __call__(self, *partial_args, **options):
+    def __call__(self, *partial_args, **options) -> ResultT:
         return self.apply_async(partial_args, **options)
 
-    def skew(self, start=1.0, stop=None, step=1.0):
+    def skew(self,
+             start: float = 1.0,
+             stop: float = None,
+             step: float = 1.0) -> SignatureT:
         it = fxrange(start, stop, step, repeatlast=True)
         for task in self.tasks:
             task.set(countdown=next(it))
         return self
 
-    def apply_async(self, args=(), kwargs=None, add_to_parent=True,
-                    producer=None, link=None, link_error=None, **options):
+    def apply_async(self,
+                    args: Sequence = (),
+                    kwargs: Mapping = None,
+                    *,
+                    add_to_parent: bool = True,
+                    producer: ProducerT = None,
+                    link: Sequence[SignatureT] = None,
+                    link_error: Sequence[SignatureT] = None,
+                    **options) -> ResultT:
         if link is not None:
             raise TypeError('Cannot add link to group: use a chord')
         if link_error is not None:
@@ -982,7 +1062,10 @@ class group(Signature):
             parent_task.add_trail(result)
         return result
 
-    def apply(self, args=(), kwargs={}, **options):
+    def apply(self,
+              args: Sequence = (),
+              kwargs: Mapping = {},
+              **options) -> ResultT:
         app = self.app
         if not self.tasks:
             return self.freeze()  # empty group returns GroupResult
@@ -992,23 +1075,30 @@ class group(Signature):
             sig.apply(args=args, kwargs=kwargs, **options) for sig, _ in tasks
         ])
 
-    def set_immutable(self, immutable):
+    def set_immutable(self, immutable: bool) -> None:
         for task in self.tasks:
             task.set_immutable(immutable)
 
-    def link(self, sig):
+    def link(self, sig: SignatureT) -> SignatureT:
         # Simply link to first task
         sig = sig.clone().set(immutable=True)
         return self.tasks[0].link(sig)
 
-    def link_error(self, sig):
+    def link_error(self, sig: SignatureT) -> SignatureT:
         sig = sig.clone().set(immutable=True)
         return self.tasks[0].link_error(sig)
 
-    def _prepared(self, tasks, partial_args, group_id, root_id, app,
-                  CallableSignature=abstract.CallableSignature,
-                  from_dict=Signature.from_dict,
-                  isinstance=isinstance, tuple=tuple):
+    def _prepared(self, tasks: Sequence[SignatureT],
+                  partial_args: Sequence,
+                  group_id: str,
+                  root_id: str,
+                  app: AppT,
+                  *,
+                  CallableSignature: Callable = abstract.CallableSignature,
+                  from_dict: Callable = Signature.from_dict,
+                  isinstance: Callable = isinstance,
+                  tuple: Callable = tuple
+                  ) -> Iterable[Tuple[SignatureT, ResultT]]:
         for task in tasks:
             if isinstance(task, CallableSignature):
                 # local sigs are always of type Signature, and we
@@ -1029,9 +1119,15 @@ class group(Signature):
                     task.args = tuple(partial_args) + tuple(task.args)
                 yield task, task.freeze(group_id=group_id, root_id=root_id)
 
-    def _apply_tasks(self, tasks, producer=None, app=None, p=None,
-                     add_to_parent=None, chord=None,
-                     args=None, kwargs=None, **options):
+    def _apply_tasks(self, tasks: Sequence[SignatureT],
+                     producer: ProducerT = None,
+                     app: AppT = None,
+                     p: Thenable = None,
+                     add_to_parent: bool = None,
+                     chord: SignatureT = None,
+                     args: Sequence = None,
+                     kwargs: Mapping = None,
+                     **options) -> Iterable[ResultT]:
         # pylint: disable=redefined-outer-name
         #   XXX chord is also a class in outer scope.
         app = app or self.app
@@ -1053,7 +1149,7 @@ class group(Signature):
                     res.then(p, weak=True)
                 yield res  # <-- r.parent, etc set in the frozen result.
 
-    def _freeze_gid(self, options):
+    def _freeze_gid(self, options: Mapping) -> Tuple[Mapping, str, str]:
         # remove task_id and use that as the group_id,
         # if we don't remove it then every task will have the same id...
         options = dict(self.options, **options)
@@ -1061,8 +1157,13 @@ class group(Signature):
             options.pop('task_id', uuid()))
         return options, group_id, options.get('root_id')
 
-    def freeze(self, _id=None, group_id=None, chord=None,
-               root_id=None, parent_id=None):
+    def freeze(self,
+               _id: str = None,
+               *,
+               group_id: str = None,
+               chord: SignatureT = None,
+               root_id: str = None,
+               parent_id: str = None) -> ResultT:
         # pylint: disable=redefined-outer-name
         #   XXX chord is also a class in outer scope.
         opts = self.options
@@ -1089,7 +1190,9 @@ class group(Signature):
         return self.app.GroupResult(gid, results)
     _freeze = freeze
 
-    def _freeze_unroll(self, new_tasks, group_id, chord, root_id, parent_id):
+    def _freeze_unroll(self, new_tasks: Sequence[SignatureT],
+                       group_id: str, chord: SignatureT,
+                       root_id: str, parent_id: str) -> Iterator[ResultT]:
         # pylint: disable=redefined-outer-name
         #   XXX chord is also a class in outer scope.
         stack = deque(self.tasks)
@@ -1103,18 +1206,18 @@ class group(Signature):
                                   chord=chord, root_id=root_id,
                                   parent_id=parent_id)
 
-    def __repr__(self):
+    def __repr__(self) -> str:
         if self.tasks:
             return remove_repeating_from_task(
                 self.tasks[0]['task'],
                 'group({0.tasks!r})'.format(self))
         return 'group(<empty>)'
 
-    def __len__(self):
+    def __len__(self) -> int:
         return len(self.tasks)
 
     @property
-    def app(self):
+    def app(self) -> AppT:
         app = self._app
         if app is None:
             try:
@@ -1153,18 +1256,25 @@ class chord(Signature):
     """
 
     @classmethod
-    def from_dict(cls, d, app=None):
+    def from_dict(cls, d: Mapping, app: AppT = None) -> SignatureT:
         args, d['kwargs'] = cls._unpack_args(**d['kwargs'])
         return _upgrade(d, cls(*args, app=app, **d))
 
     @staticmethod
-    def _unpack_args(header=None, body=None, **kwargs):
+    def _unpack_args(header: Sequence[SignatureT] = None,
+                     body: SignatureT = None,
+                     **kwargs) -> Tuple[Tuple, Mapping]:
         # Python signatures are better at extracting keys from dicts
         # than manually popping things off.
         return (header, body), kwargs
 
-    def __init__(self, header, body=None, task='celery.chord',
-                 args=(), kwargs={}, app=None, **options):
+    def __init__(self, header: Sequence[SignatureT],
+                 body: SignatureT = None,
+                 task: str = 'celery.chord',
+                 args: Sequence = (),
+                 kwargs: Mapping = {},
+                 app: AppT = None,
+                 **options) -> None:
         Signature.__init__(
             self, task, args,
             dict(kwargs=kwargs, header=_maybe_group(header, app),
@@ -1172,11 +1282,16 @@ class chord(Signature):
         )
         self.subtask_type = 'chord'
 
-    def __call__(self, body=None, **options):
+    def __call__(self, body: SignatureT = None, **options) -> ResultT:
         return self.apply_async((), {'body': body} if body else {}, **options)
 
-    def freeze(self, _id=None, group_id=None, chord=None,
-               root_id=None, parent_id=None):
+    def freeze(self,
+               _id: str = None,
+               *,
+               group_id: str = None,
+               chord: SignatureT = None,
+               root_id: str = None,
+               parent_id: str = None) -> ResultT:
         # pylint: disable=redefined-outer-name
         #   XXX chord is also a class in outer scope.
         if not isinstance(self.tasks, group):
@@ -1200,9 +1315,15 @@ class chord(Signature):
         self.id = self.tasks.id
         return bodyres
 
-    def apply_async(self, args=(), kwargs={}, task_id=None,
-                    producer=None, connection=None,
-                    router=None, result_cls=None, **options):
+    def apply_async(self,
+                    args: Sequence =(),
+                    kwargs: Mapping = {},
+                    task_id: str = None,
+                    producer: ProducerT = None,
+                    connection: ConnectionT = None,
+                    router: RouterT = None,
+                    result_cls: type = None,
+                    **options) -> ResultT:
         kwargs = kwargs or {}
         args = (tuple(args) + tuple(self.args)
                 if args and not self.immutable else self.args)
@@ -1223,7 +1344,10 @@ class chord(Signature):
         # chord([A, B, ...], C)
         return self.run(tasks, body, args, task_id=task_id, **options)
 
-    def apply(self, args=(), kwargs={}, propagate=True, body=None, **options):
+    def apply(self, args: Sequence = (), kwargs: Mapping = {},
+              propagate: bool = True,
+              body: SignatureT = None,
+              **options) -> ResultT:
         body = self.body if body is None else body
         tasks = (self.tasks.clone() if isinstance(self.tasks, group)
                  else group(self.tasks, app=self.app))
@@ -1231,7 +1355,8 @@ class chord(Signature):
             args=(tasks.apply(args, kwargs).get(propagate=propagate),),
         )
 
-    def _traverse_tasks(self, tasks, value=None):
+    def _traverse_tasks(self, tasks: Sequence[SignatureT],
+                        value: Any = None) -> Iterator[SignatureT]:
         stack = deque(list(tasks))
         while stack:
             task = stack.popleft()
@@ -1240,12 +1365,14 @@ class chord(Signature):
             else:
                 yield task if value is None else value
 
-    def __length_hint__(self):
+    def __length_hint__(self) -> int:
         return sum(self._traverse_tasks(self.tasks, 1))
 
-    def run(self, header, body, partial_args, app=None, interval=None,
-            countdown=1, max_retries=None, eager=False,
-            task_id=None, **options):
+    def run(self, header: SignatureT, body: SignatureT, partial_args: Sequence,
+            app: AppT = None, interval: float = None,
+            countdown: float = 1, max_retries: int = None,
+            eager: bool = False,
+            task_id: str = None, **options) -> ResultT:
         app = app or self._get_app(body)
         group_id = header.options.get('task_id') or uuid()
         root_id = body.options.get('root_id')
@@ -1267,7 +1394,7 @@ class chord(Signature):
         bodyres.parent = parent
         return bodyres
 
-    def clone(self, *args, **kwargs):
+    def clone(self, *args, **kwargs) -> SignatureT:
         s = Signature.clone(self, *args, **kwargs)
         # need to make copy of body
         try:
@@ -1276,20 +1403,20 @@ class chord(Signature):
             pass
         return s
 
-    def link(self, callback):
+    def link(self, callback: SignatureT) -> SignatureT:
         self.body.link(callback)
         return callback
 
-    def link_error(self, errback):
+    def link_error(self, errback: SignatureT) -> SignatureT:
         self.body.link_error(errback)
         return errback
 
-    def set_immutable(self, immutable):
+    def set_immutable(self, immutable: bool) -> None:
         # changes mutability of header only, not callback.
         for task in self.tasks:
             task.set_immutable(immutable)
 
-    def __repr__(self):
+    def __repr__(self) -> str:
         if self.body:
             if isinstance(self.body, _chain):
                 return remove_repeating_from_task(
@@ -1304,10 +1431,10 @@ class chord(Signature):
         return '<chord without body: {0.tasks!r}>'.format(self)
 
     @cached_property
-    def app(self):
+    def app(self) -> AppT:
         return self._get_app(self.body)
 
-    def _get_app(self, body=None):
+    def _get_app(self, body: SignatureT = None) -> AppT:
         app = self._app
         if app is None:
             try:
@@ -1323,7 +1450,7 @@ class chord(Signature):
     body = getitem_property('kwargs.body', 'Body task of chord.')
 
 
-def signature(varies, *args, **kwargs):
+def signature(varies: SignatureT, *args, **kwargs) -> SignatureT:
     """Create new signature.
 
     - if the first argument is a signature already then it's cloned.
@@ -1340,11 +1467,13 @@ def signature(varies, *args, **kwargs):
     return Signature(varies, *args, **kwargs)
 
 
-def maybe_signature(d, app=None, clone=False):
+def maybe_signature(d: Optional[SignatureT],
+                    app: AppT = None,
+                    clone: bool = False) -> Optional[SignatureT]:
     """Ensure obj is a signature, or None.
 
     Arguments:
-        d (Optional[Union[abstract.CallableSignature, Mapping]]):
+        d (Optional[SignatureT]):
             Signature or dict-serialized signature.
         app (celery.Celery):
             App to bind signature to.
@@ -1353,7 +1482,7 @@ def maybe_signature(d, app=None, clone=False):
            will be cloned when this flag is enabled.
 
     Returns:
-        Optional[abstract.CallableSignature]
+        Optional[SignatureT]
     """
     if d is not None:
         if isinstance(d, abstract.CallableSignature):

+ 9 - 2
celery/concurrency/asynpool.py

@@ -23,11 +23,12 @@ import socket
 import struct
 import time
 
-from collections import Counter, deque, namedtuple
+from collections import Counter, deque
 from io import BytesIO
 from numbers import Integral
 from pickle import HIGHEST_PROTOCOL
 from time import sleep
+from typing import NamedTuple
 from weakref import WeakValueDictionary, ref
 
 from billiard.pool import RUN, TERMINATE, ACK, NACK, WorkersJoined
@@ -91,7 +92,13 @@ SCHED_STRATEGIES = {
 }
 SCHED_STRATEGY_TO_NAME = {v: k for k, v in SCHED_STRATEGIES.items()}
 
-Ack = namedtuple('Ack', ('id', 'fd', 'payload'))
+
+class Ack(NamedTuple):
+    """Ack message payload."""
+
+    id: int
+    fd: int
+    payload: bytes
 
 
 def gen_not_started(gen):

+ 15 - 26
celery/contrib/pytest.py

@@ -12,12 +12,11 @@ NO_WORKER = os.environ.get('NO_WORKER')
 
 
 @contextmanager
-def _create_app(request,
-                enable_logging=False,
-                use_trap=False,
-                parameters={},
-                **config):
-    # type: (Any, **Any) -> Celery
+def _create_app(request: Any,
+                enable_logging: bool = False,
+                use_trap: bool = False,
+                parameters: Mapping = {},
+                **config) -> Celery:
     """Utility context used to setup Celery app for pytest fixtures."""
     test_app = TestApp(
         set_as_current=False,
@@ -41,8 +40,7 @@ def _create_app(request,
 
 
 @pytest.fixture(scope='session')
-def use_celery_app_trap():
-    # type: () -> bool
+def use_celery_app_trap() -> bool:
     """You can override this fixture to enable the app trap.
 
     The app trap raises an exception whenever something attempts
@@ -56,8 +54,7 @@ def celery_session_app(request,
                        celery_config,
                        celery_parameters,
                        celery_enable_logging,
-                       use_celery_app_trap):
-    # type: (Any) -> Celery
+                       use_celery_app_trap) -> Celery:
     """Session Fixture: Return app for session fixtures."""
     mark = request.node.get_marker('celery')
     config = dict(celery_config, **mark.kwargs if mark else {})
@@ -77,8 +74,7 @@ def celery_session_worker(request,
                           celery_session_app,
                           celery_includes,
                           celery_worker_pool,
-                          celery_worker_parameters):
-    # type: (Any, Celery, Sequence[str], str) -> WorkController
+                          celery_worker_parameters) -> WorkController:
     """Session Fixture: Start worker that lives throughout test suite."""
     if not NO_WORKER:
         for module in celery_includes:
@@ -90,15 +86,13 @@ def celery_session_worker(request,
 
 
 @pytest.fixture(scope='session')
-def celery_enable_logging():
-    # type: () -> bool
+def celery_enable_logging() -> bool:
     """You can override this fixture to enable logging."""
     return False
 
 
 @pytest.fixture(scope='session')
-def celery_includes():
-    # type: () -> Sequence[str]
+def celery_includes() -> Sequence[str]:
     """You can override this include modules when a worker start.
 
     You can have this return a list of module names to import,
@@ -108,8 +102,7 @@ def celery_includes():
 
 
 @pytest.fixture(scope='session')
-def celery_worker_pool():
-    # type: () -> Union[str, Any]
+def celery_worker_pool() -> Union[str, type]:
     """You can override this fixture to set the worker pool.
 
     The "solo" pool is used by default, but you can set this to
@@ -119,8 +112,7 @@ def celery_worker_pool():
 
 
 @pytest.fixture(scope='session')
-def celery_config():
-    # type: () -> Mapping[str, Any]
+def celery_config() -> Mapping[str, Any]:
     """Redefine this fixture to configure the test Celery app.
 
     The config returned by your fixture will then be used
@@ -130,8 +122,7 @@ def celery_config():
 
 
 @pytest.fixture(scope='session')
-def celery_parameters():
-    # type: () -> Mapping[str, Any]
+def celery_parameters() -> Mapping[str, Any]:
     """Redefine this fixture to change the init parameters of test Celery app.
 
     The dict returned by your fixture will then be used
@@ -141,8 +132,7 @@ def celery_parameters():
 
 
 @pytest.fixture(scope='session')
-def celery_worker_parameters():
-    # type: () -> Mapping[str, Any]
+def celery_worker_parameters() -> Mapping[str, Any]:
     """Redefine this fixture to change the init parameters of Celery workers.
 
     This can be used e. g. to define queues the worker will consume tasks from.
@@ -175,8 +165,7 @@ def celery_worker(request,
                   celery_app,
                   celery_includes,
                   celery_worker_pool,
-                  celery_worker_parameters):
-    # type: (Any, Celery, Sequence[str], str) -> WorkController
+                  celery_worker_parameters) -> WorkController:
     """Fixture: Start worker in a thread, stop it when the test returns."""
     if not NO_WORKER:
         for module in celery_includes:

+ 126 - 82
celery/events/state.py

@@ -23,6 +23,7 @@ from decimal import Decimal
 from itertools import islice
 from operator import itemgetter
 from time import time
+from typing import Any, Iterator, List, Mapping, Set, Sequence, Tuple
 from weakref import WeakSet, ref
 
 from kombu.clocks import timetuple
@@ -64,7 +65,7 @@ R_WORKER = '<Worker: {0.hostname} ({0.status_string} clock:{0.clock})'
 R_TASK = '<Task: {0.name}({0.uuid}) {0.state} clock:{0.clock}>'
 
 #: Mapping of task event names to task state.
-TASK_EVENT_TO_STATE = {
+TASK_EVENT_TO_STATE: Mapping[str, str] = {
     'sent': states.PENDING,
     'received': states.RECEIVED,
     'started': states.STARTED,
@@ -92,26 +93,31 @@ class CallableDefaultdict(defaultdict):
         ...     'proj.tasks.add', reverse=True))
     """
 
-    def __init__(self, fun, *args, **kwargs):
+    def __init__(self, fun: Callable, *args, **kwargs) -> None:
         self.fun = fun
         super().__init__(*args, **kwargs)
 
-    def __call__(self, *args, **kwargs):
+    def __call__(self, *args, **kwargs) -> Any:
         return self.fun(*args, **kwargs)
 Callable.register(CallableDefaultdict)  # noqa: E305
 
 
 @memoize(maxsize=1000, keyfun=lambda a, _: a[0])
-def _warn_drift(hostname, drift, local_received, timestamp):
+def _warn_drift(hostname: str, drift: float,
+                local_received: float, timestamp: float) -> None:
     # we use memoize here so the warning is only logged once per hostname
     warn(DRIFT_WARNING, hostname, drift,
          datetime.fromtimestamp(local_received),
          datetime.fromtimestamp(timestamp))
 
 
-def heartbeat_expires(timestamp, freq=60,
-                      expire_window=HEARTBEAT_EXPIRE_WINDOW,
-                      Decimal=Decimal, float=float, isinstance=isinstance):
+def heartbeat_expires(timestamp: float,
+                      freq: float = 60.0,
+                      expire_window: float = HEARTBEAT_EXPIRE_WINDOW,
+                      *,
+                      Decimal: Callable = Decimal,
+                      float: Callable = float,
+                      isinstance: Callable = isinstance) -> float:
     """Return time when heartbeat expires."""
     # some json implementations returns decimal.Decimal objects,
     # which aren't compatible with float.
@@ -121,26 +127,26 @@ def heartbeat_expires(timestamp, freq=60,
     return timestamp + (freq * (expire_window / 1e2))
 
 
-def _depickle_task(cls, fields):
+def _depickle_task(cls: type, fields: Mapping) -> 'Task':
     return cls(**fields)
 
 
-def with_unique_field(attr):
+def with_unique_field(attr: str) -> Callable:
 
-    def _decorate_cls(cls):
+    def _decorate_cls(cls: type) -> type:
 
-        def __eq__(this, other):
+        def __eq__(this, other: Any) -> bool:
             if isinstance(other, this.__class__):
                 return getattr(this, attr) == getattr(other, attr)
             return NotImplemented
         cls.__eq__ = __eq__
 
-        def __ne__(this, other):
+        def __ne__(this, other: Any) -> bool:
             res = this.__eq__(other)
             return True if res is NotImplemented else not res
         cls.__ne__ = __ne__
 
-        def __hash__(this):
+        def __hash__(this: Any) -> int:
             return hash(getattr(this, attr))
         cls.__hash__ = __hash__
 
@@ -161,9 +167,18 @@ class Worker:
     if not PYPY:  # pragma: no cover
         __slots__ = _fields + ('event', '__dict__', '__weakref__')
 
-    def __init__(self, hostname=None, pid=None, freq=60,
-                 heartbeats=None, clock=0, active=None, processed=None,
-                 loadavg=None, sw_ident=None, sw_ver=None, sw_sys=None):
+    def __init__(self,
+                 hostname: str = None,
+                 pid: int = None,
+                 freq: float = 60.0,
+                 heartbeats: Sequence[float] = None,
+                 clock: int = 0,
+                 active: int = None,
+                 processed: int = None,
+                 loadavg: Tuple[float, float, float] = None,
+                 sw_ident: str = None,
+                 sw_ver: str = None,
+                 sw_sys: str = None) -> None:
         self.hostname = hostname
         self.pid = pid
         self.freq = freq
@@ -177,23 +192,28 @@ class Worker:
         self.sw_sys = sw_sys
         self.event = self._create_event_handler()
 
-    def __reduce__(self):
+    def __reduce__(self) -> Tuple:
         return self.__class__, (self.hostname, self.pid, self.freq,
                                 self.heartbeats, self.clock, self.active,
                                 self.processed, self.loadavg, self.sw_ident,
                                 self.sw_ver, self.sw_sys)
 
-    def _create_event_handler(self):
+    def _create_event_handler(self) -> Callable:
         _set = object.__setattr__
         hbmax = self.heartbeat_max
         heartbeats = self.heartbeats
         hb_pop = self.heartbeats.pop
         hb_append = self.heartbeats.append
 
-        def event(type_, timestamp=None,
-                  local_received=None, fields=None,
-                  max_drift=HEARTBEAT_DRIFT_MAX, abs=abs, int=int,
-                  insort=bisect.insort, len=len):
+        def event(type_: str,
+                  timestamp: float = None,
+                  local_received: float = None,
+                  fields: Mapping = None,
+                  max_drift: float = HEARTBEAT_DRIFT_MAX,
+                  abs: Callable = abs,
+                  int: Callable = int,
+                  insort: Callable = bisect.insort,
+                  len: Callable = len) -> None:
             fields = fields or {}
             for k, v in fields.items():
                 _set(self, k, v)
@@ -216,28 +236,28 @@ class Worker:
                         insort(heartbeats, local_received)
         return event
 
-    def update(self, f, **kw):
+    def update(self, f: Mapping, **kw) -> None:
         for k, v in (dict(f, **kw) if kw else f).items():
             setattr(self, k, v)
 
-    def __repr__(self):
+    def __repr__(self) -> str:
         return R_WORKER.format(self)
 
     @property
-    def status_string(self):
+    def status_string(self) -> str:
         return 'ONLINE' if self.alive else 'OFFLINE'
 
     @property
-    def heartbeat_expires(self):
+    def heartbeat_expires(self) -> float:
         return heartbeat_expires(self.heartbeats[-1],
                                  self.freq, self.expire_window)
 
     @property
-    def alive(self, nowfun=time):
+    def alive(self, *, nowfun: Callable = time) -> bool:
         return bool(self.heartbeats and nowfun() < self.heartbeat_expires)
 
     @property
-    def id(self):
+    def id(self) -> str:
         return '{0.hostname}.{0.pid}'.format(self)
 
 
@@ -285,7 +305,11 @@ class Task:
         'root_id', 'parent_id',
     )
 
-    def __init__(self, uuid=None, cluster_state=None, children=None, **kwargs):
+    def __init__(self,
+                 uuid: str = None,
+                 cluster_state: 'State' = None,
+                 children: Sequence[str] = None,
+                 **kwargs) -> None:
         self.uuid = uuid
         self.cluster_state = cluster_state
         self.children = WeakSet(
@@ -301,10 +325,14 @@ class Task:
         if kwargs:
             self.__dict__.update(kwargs)
 
-    def event(self, type_, timestamp=None, local_received=None, fields=None,
-              precedence=states.precedence,
-              setattr=setattr, task_event_to_state=TASK_EVENT_TO_STATE.get,
-              RETRY=states.RETRY):
+    def event(self, type_: str,
+              timestamp: float = None,
+              local_received: float = None,
+              fields: Mapping = None,
+              precedence: Callable = states.precedence,
+              setattr: Callable = setattr,
+              task_event_to_state: Callable = TASK_EVENT_TO_STATE.get,
+              RETRY: str = states.RETRY) -> None:
         fields = fields or {}
 
         # using .get is faster than catching KeyError in this case.
@@ -331,11 +359,12 @@ class Task:
         # update current state with info from this event.
         self.__dict__.update(fields)
 
-    def info(self, fields=None, extra=[]):
+    def info(self, fields: Sequence[str] = None,
+             extra: Sequence[str] = []) -> Mapping:
         """Information about this task suitable for on-screen display."""
         fields = self._info_fields if fields is None else fields
 
-        def _keys():
+        def _keys() -> Iterator[Tuple[str, Any]]:
             for key in list(fields) + list(extra):
                 value = getattr(self, key, None)
                 if value is not None:
@@ -343,46 +372,46 @@ class Task:
 
         return dict(_keys())
 
-    def __repr__(self):
+    def __repr__(self) -> str:
         return R_TASK.format(self)
 
-    def as_dict(self):
+    def as_dict(self) -> Mapping:
         get = object.__getattribute__
         handler = self._serializer_handlers.get
         return {
             k: handler(k, pass1)(get(self, k)) for k in self._fields
         }
 
-    def _serializable_children(self, value):
+    def _serializable_children(self, value: Any) -> Sequence[str]:
         return [task.id for task in self.children]
 
-    def _serializable_root(self, value):
+    def _serializable_root(self, value: Any) -> str:
         return self.root_id
 
-    def _serializable_parent(self, value):
+    def _serializable_parent(self, value: Any) -> str:
         return self.parent_id
 
-    def __reduce__(self):
+    def __reduce__(self) -> Tuple:
         return _depickle_task, (self.__class__, self.as_dict())
 
     @property
-    def id(self):
+    def id(self) -> str:
         return self.uuid
 
     @property
-    def origin(self):
+    def origin(self) -> str:
         return self.client if self.worker is None else self.worker.id
 
     @property
-    def ready(self):
+    def ready(self) -> bool:
         return self.state in states.READY_STATES
 
     @cached_property
-    def parent(self):
+    def parent(self) -> 'Task':
         return self.parent_id and self.cluster_state.tasks[self.parent_id]
 
     @cached_property
-    def root(self):
+    def root(self) -> 'Task':
         return self.root_id and self.cluster_state.tasks[self.root_id]
 
 
@@ -395,11 +424,17 @@ class State:
     task_count = 0
     heap_multiplier = 4
 
-    def __init__(self, callback=None,
-                 workers=None, tasks=None, taskheap=None,
-                 max_workers_in_memory=5000, max_tasks_in_memory=10000,
-                 on_node_join=None, on_node_leave=None,
-                 tasks_by_type=None, tasks_by_worker=None):
+    def __init__(self,
+                 callback: Callable = None,
+                 workers: Sequence[Worker] = None,
+                 tasks: Sequence[Task] = None,
+                 taskheap: List = None,
+                 max_workers_in_memory: int = 5000,
+                 max_tasks_in_memory: int = 10000,
+                 on_node_join: Callable = None,
+                 on_node_leave: Callable = None,
+                 tasks_by_type: Mapping = None,
+                 tasks_by_worker: Mapping = None) -> None:
         self.event_callback = callback
         self.workers = (LRUCache(max_workers_in_memory)
                         if workers is None else workers)
@@ -416,23 +451,21 @@ class State:
         self._tasks_to_resolve = {}
         self.rebuild_taskheap()
 
-        # type: Mapping[TaskName, WeakSet[Task]]
-        self.tasks_by_type = CallableDefaultdict(
+        self.tasks_by_type: Mapping[str, Set[Task]] = CallableDefaultdict(
             self._tasks_by_type, WeakSet)
         self.tasks_by_type.update(
             _deserialize_Task_WeakSet_Mapping(tasks_by_type, self.tasks))
 
-        # type: Mapping[Hostname, WeakSet[Task]]
-        self.tasks_by_worker = CallableDefaultdict(
+        self.tasks_by_worker: Mapping[str, Set[Task] = CallableDefaultdict(
             self._tasks_by_worker, WeakSet)
         self.tasks_by_worker.update(
             _deserialize_Task_WeakSet_Mapping(tasks_by_worker, self.tasks))
 
     @cached_property
-    def _event(self):
+    def _event(self) -> Callable:
         return self._create_dispatcher()
 
-    def freeze_while(self, fun, *args, **kwargs):
+    def freeze_while(self, fun: Callable, *args, **kwargs) -> Any:
         clear_after = kwargs.pop('clear_after', False)
         with self._mutex:
             try:
@@ -441,11 +474,11 @@ class State:
                 if clear_after:
                     self._clear()
 
-    def clear_tasks(self, ready=True):
+    def clear_tasks(self, ready: bool = True) -> None:
         with self._mutex:
-            return self._clear_tasks(ready)
+            self._clear_tasks(ready)
 
-    def _clear_tasks(self, ready=True):
+    def _clear_tasks(self, ready: bool = True) -> None:
         if ready:
             in_progress = {
                 uuid: task for uuid, task in self.itertasks()
@@ -457,17 +490,18 @@ class State:
             self.tasks.clear()
         self._taskheap[:] = []
 
-    def _clear(self, ready=True):
+    def _clear(self, ready: bool = True) -> None:
         self.workers.clear()
         self._clear_tasks(ready)
         self.event_count = 0
         self.task_count = 0
 
-    def clear(self, ready=True):
+    def clear(self, ready: bool = True) -> None:
         with self._mutex:
-            return self._clear(ready)
+            self._clear(ready)
 
-    def get_or_create_worker(self, hostname, **kwargs):
+    def get_or_create_worker(self, hostname: str,
+                             **kwargs) -> Tuple[Worker, bool]:
         """Get or create worker by hostname.
 
         Returns:
@@ -483,7 +517,7 @@ class State:
                 hostname, **kwargs)
             return worker, True
 
-    def get_or_create_task(self, uuid):
+    def get_or_create_task(self, uuid: str) -> Tuple[Task, bool]:
         """Get or create task by uuid."""
         try:
             return self.tasks[uuid], False
@@ -491,11 +525,11 @@ class State:
             task = self.tasks[uuid] = self.Task(uuid, cluster_state=self)
             return task, True
 
-    def event(self, event):
+    def event(self, event: Mapping) -> Tuple[Any, str]:
         with self._mutex:
             return self._event(event)
 
-    def _create_dispatcher(self):
+    def _create_dispatcher(self) -> Callable:
         # noqa: C901
         # pylint: disable=too-many-statements
         # This code is highly optimized, but not for reusability.
@@ -522,9 +556,12 @@ class State:
         get_task_by_type_set = self.tasks_by_type.__getitem__
         get_task_by_worker_set = self.tasks_by_worker.__getitem__
 
-        def _event(event,
-                   timetuple=timetuple, KeyError=KeyError,
-                   insort=bisect.insort, created=True):
+        def _event(event: Mapping,
+                   *,
+                   timetuple: Callable = timetuple,
+                   KeyError: type = KeyError,
+                   insort: Callable = bisect.insort,
+                   created: bool = True) -> Tuple[Any, str]:
             self.event_count += 1
             if event_callback:
                 event_callback(self, event)
@@ -618,27 +655,29 @@ class State:
                 return (task, task_created), subject
         return _event
 
-    def _add_pending_task_child(self, task):
+    def _add_pending_task_child(self, task: Task) -> None:
         try:
             ch = self._tasks_to_resolve[task.parent_id]
         except KeyError:
             ch = self._tasks_to_resolve[task.parent_id] = WeakSet()
         ch.add(task)
 
-    def rebuild_taskheap(self, timetuple=timetuple):
+    def rebuild_taskheap(self, *, timetuple: Callable = timetuple) -> None:
         heap = self._taskheap[:] = [
             timetuple(t.clock, t.timestamp, t.origin, ref(t))
             for t in self.tasks.values()
         ]
         heap.sort()
 
-    def itertasks(self, limit=None):
+    def itertasks(self, limit: int = None) -> Iterator[Task]:
         for index, row in enumerate(self.tasks.items()):
             yield row
             if limit and index + 1 >= limit:
                 break
 
-    def tasks_by_time(self, limit=None, reverse=True):
+    def tasks_by_time(self,
+                      limit: int = None,
+                      reverse: bool = True) -> Iterator[Tuple[str, Task]]:
         """Generator yielding tasks ordered by time.
 
         Yields:
@@ -658,7 +697,9 @@ class State:
                     seen.add(uuid)
     tasks_by_timestamp = tasks_by_time
 
-    def _tasks_by_type(self, name, limit=None, reverse=True):
+    def _tasks_by_type(self, name: str,
+                       limit: int = None,
+                       reverse: bool = True) -> Iterator[str, Task]:
         """Get all tasks by type.
 
         This is slower than accessing :attr:`tasks_by_type`,
@@ -673,7 +714,9 @@ class State:
             0, limit,
         )
 
-    def _tasks_by_worker(self, hostname, limit=None, reverse=True):
+    def _tasks_by_worker(self, hostname: str,
+                         limit: int = None,
+                         reverse: bool = True) -> Iterator[str, Task]:
         """Get all tasks by worker.
 
         Slower than accessing :attr:`tasks_by_worker`, but ordered by time.
@@ -684,18 +727,18 @@ class State:
             0, limit,
         )
 
-    def task_types(self):
+    def task_types(self) -> List[str]:
         """Return a list of all seen task types."""
         return sorted(self._seen_types)
 
-    def alive_workers(self):
+    def alive_workers(self) -> Iterator[Worker]:
         """Return a list of (seemingly) alive workers."""
         return (w for w in self.workers.values() if w.alive)
 
-    def __repr__(self):
+    def __repr__(self) -> str:
         return R_STATE.format(self)
 
-    def __reduce__(self):
+    def __reduce__(self) -> Tuple:
         return self.__class__, (
             self.event_callback, self.workers, self.tasks, None,
             self.max_workers_in_memory, self.max_tasks_in_memory,
@@ -705,10 +748,11 @@ class State:
         )
 
 
-def _serialize_Task_WeakSet_Mapping(mapping):
+def _serialize_Task_WeakSet_Mapping(mapping: Mapping) -> Mapping:
     return {name: [t.id for t in tasks] for name, tasks in mapping.items()}
 
 
-def _deserialize_Task_WeakSet_Mapping(mapping, tasks):
+def _deserialize_Task_WeakSet_Mapping(
+        mapping: Mapping, tasks: Sequence[Task]) -> Mapping[str, Set[Task]]:
     return {name: WeakSet(tasks[i] for i in ids if i in tasks)
             for name, ids in (mapping or {}).items()}

+ 0 - 4
celery/platforms.py

@@ -14,8 +14,6 @@ import signal as _signal
 import sys
 import warnings
 
-from collections import namedtuple
-
 from billiard.compat import get_fdmax, close_open_fds
 # fileno used to be in this module
 from kombu.utils.compat import maybe_fileno
@@ -65,8 +63,6 @@ PIDFILE_MODE = ((os.R_OK | os.W_OK) << 6) | ((os.R_OK) << 3) | ((os.R_OK))
 PIDLOCKED = """ERROR: Pidfile ({0}) already exists.
 Seems we're already running? (pid: {1})"""
 
-_range = namedtuple('_range', ('start', 'stop'))
-
 C_FORCE_ROOT = os.environ.get('C_FORCE_ROOT', False)
 
 ROOT_DISALLOWED = """\

+ 11 - 2
celery/schedules.py

@@ -4,8 +4,9 @@ import numbers
 import re
 
 from bisect import bisect, bisect_left
-from collections import Iterable, namedtuple
+from collections import Iterable
 from datetime import datetime, timedelta
+from typing import NamedTuple
 
 from kombu.utils.objects import cached_property
 
@@ -21,7 +22,6 @@ __all__ = [
     'maybe_schedule', 'solar',
 ]
 
-schedstate = namedtuple('schedstate', ('is_due', 'next'))
 
 CRON_PATTERN_INVALID = """\
 Invalid crontab pattern.  Valid range is {min}-{max}. \
@@ -50,6 +50,15 @@ SOLAR_INVALID_EVENT = """\
 Argument event "{event}" is invalid, must be one of {all_events}.\
 """
 
+Kailuga1
+
+
+class schedstate(NamedTuple):
+    """Return value of ``schedule.is_due``."""
+
+    is_due: bool
+    next: float
+
 
 def cronfield(s):
     return '*' if s is None else s

+ 13 - 17
celery/utils/collections.py

@@ -103,7 +103,7 @@ class DictAttribute:
     `obj[k] = val -> obj.k = val`
     """
 
-    obj = None  # type: Mapping[Any, Any]
+    obj: Mapping = None
 
     def __init__(self, obj: Mapping) -> None:
         object.__setattr__(self, 'obj', obj)
@@ -155,10 +155,10 @@ MutableMapping.register(DictAttribute)  # noqa: E305
 class ChainMap(MutableMapping):
     """Key lookup on a sequence of maps."""
 
-    key_t = None      # type: Optional[KeyCallback]
-    changes = None    # type: Mapping
-    defaults = None   # type: Sequence[Mapping]
-    maps = None       # type: Sequence[Mapping]
+    key_t: Optional[KeyCallback] = None
+    changes: Mapping = None
+    defaults: Sequence[Mapping] = None
+    maps: Sequence[Mapping] = None
 
     def __init__(self, *maps: Sequence[Mapping],
                  key_t: KeyCallback = None, **kwargs) -> None:
@@ -305,7 +305,6 @@ class ConfigurationView(ChainMap, AttributeDictMixin):
         return key,
 
     def __getitem__(self, key: str) -> Any:
-        # type: (str) -> Any
         keys = self._to_keys(key)
         getitem = super(ConfigurationView, self).__getitem__
         for k in keys + (
@@ -420,11 +419,9 @@ class LimitedSet:
         self.minlen = 0 if minlen is None else minlen
         self.expires = 0 if expires is None else expires
 
-        # type: Mapping[str, Any]
-        self._data = {}
+        self._data: Mapping[str, Any] = {}
 
-        # type: Sequence[Tuple[float, Any]]
-        self._heap = []
+        self._heap: Sequence[Tuple[float, Any]] = []
 
         if data:
             # import items from data
@@ -577,7 +574,7 @@ MutableSet.register(LimitedSet)  # noqa: E305
 class Evictable:
     """Mixin for classes supporting the ``evict`` method."""
 
-    Empty = Empty  # type: Exception
+    Empty = Empty
 
     def evict(self) -> None:
         """Force evict until maxsize is enforced."""
@@ -607,12 +604,12 @@ class Messagebuffer(Evictable):
                  maxsize: Optional[int],
                  iterable: Optional[Iterable]=None, deque: Any=deque) -> None:
         self.maxsize = maxsize
-        self.data = deque(iterable or [])  # type: deque
+        self.data = deque(iterable or [])
 
-        self._append = self.data.append    # type: Callable[[Any], None]
-        self._pop = self.data.popleft      # type: Callable[[], Any]
-        self._len = self.data.__len__      # type: Callable[[], int]
-        self._extend = self.data.extend    # type: Callable[[Iterable], None]
+        self._append = self.data.append
+        self._pop = self.data.popleft
+        self._len = self.data.__len__
+        self._extend = self.data.extend
 
     def put(self, item: Any) -> None:
         self._append(item)
@@ -682,7 +679,6 @@ class BufferMap(OrderedDict, Evictable):
         if iterable:
             self.update(iterable)
 
-        # type: int
         self.total = sum(len(buf) for buf in self.items())
 
     def put(self, key: Any, item: Any) -> None:

+ 3 - 3
celery/utils/debug.py

@@ -29,7 +29,7 @@ __all__ = [
     'humanbytes', 'mem_rss', 'ps', 'cry',
 ]
 
-UNITS = (               # type: Sequence[Tuple[float, str]]
+UNITS: Sequence[Tuple[float, str]] = (
     (2 ** 40.0, 'TB'),
     (2 ** 30.0, 'GB'),
     (2 ** 20.0, 'MB'),
@@ -37,8 +37,8 @@ UNITS = (               # type: Sequence[Tuple[float, str]]
     (0.0, 'b'),
 )
 
-_process = None         # type: Optional[Process]
-_mem_sample = []        # type: MutableSequence[str]
+_process: Process = None
+_mem_sample: MutableSequence[str] = []
 
 
 def _on_blocking(signum: int, frame: Any) -> None:

+ 5 - 5
celery/utils/functional.py

@@ -50,8 +50,8 @@ class mlazy(lazy):
     """
 
     #: Set to :const:`True` after the object has been evaluated.
-    evaluated = False  # type: bool
-    _value = None      # type: Any
+    evaluated: bool = False
+    _value: Any = None
 
     def evaluate(self) -> Any:
         if not self.evaluated:
@@ -169,7 +169,7 @@ def mattrgetter(*attrs: str) -> Callable[[Any], Mapping[str, Any]]:
 
 def uniq(it: Iterable) -> Iterable[Any]:
     """Return all unique elements in ``it``, preserving order."""
-    seen = set()  # type: MutableSet
+    seen = set()
     return (seen.add(obj) or obj for obj in it if obj not in seen)
 
 
@@ -274,7 +274,7 @@ def head_from_fun(fun: Callable,
         name, fun = fun.__class__.__name__, fun.__call__
     else:
         name = fun.__name__
-    definition = FUNHEAD_TEMPLATE.format(   # type: str
+    definition = FUNHEAD_TEMPLATE.format(
         fun_name=name,
         fun_args=_argsfromspec(getfullargspec(fun)),
         fun_value=1,
@@ -285,7 +285,7 @@ def head_from_fun(fun: Callable,
     # pylint: disable=exec-used
     # Tasks are rarely, if ever, created at runtime - exec here is fine.
     exec(definition, namespace)
-    result = namespace[name]  # type: Any
+    result: Any = namespace[name]
     result._source = definition
     if bound:
         return partial(result, object())

+ 7 - 7
celery/utils/graph.py

@@ -47,7 +47,7 @@ class DependencyGraph(Iterable):
     def __init__(self, it: Optional[Iterable]=None,
                  formatter: Optional['GraphFormatter']=None) -> None:
         self.formatter = formatter or GraphFormatter()
-        self.adjacent = {}  # type: Dict[Any, Any]
+        self.adjacent: Dict[Any, Any] = {}
         if it is not None:
             self.update(it)
 
@@ -116,8 +116,8 @@ class DependencyGraph(Iterable):
 
         See https://en.wikipedia.org/wiki/Topological_sorting
         """
-        count = Counter()  # type: Counter
-        result = []        # type: MutableSequence[Any]
+        count = Counter()
+        result: MutableSequence[Any] = []
 
         for node in self:
             for successor in self[node]:
@@ -141,9 +141,9 @@ class DependencyGraph(Iterable):
         See Also:
             :wikipedia:`Tarjan%27s_strongly_connected_components_algorithm`
         """
-        result = []  # type: MutableSequence[Any]
-        stack = []   # type: MutableSequence[Any]
-        low = {}     # type: Dict[Any, Any]
+        result: MutableSequence[Any] = []
+        stack: MutableSequence[Any] = []
+        low: Dict[Any, Any] = {}
 
         def visit(node):
             if node in low:
@@ -178,7 +178,7 @@ class DependencyGraph(Iterable):
             formatter (celery.utils.graph.GraphFormatter): Custom graph
                 formatter to use.
         """
-        seen = set()  # type: MutableSet
+        seen: MutableSet = set()
         draw = formatter or self.formatter
 
         def P(s):

+ 6 - 6
celery/utils/imports.py

@@ -8,7 +8,7 @@ import warnings
 from contextlib import contextmanager
 from imp import reload
 from types import ModuleType
-from typing import Any, Callable, Iterator, Optional
+from typing import Any, Callable, Iterator
 from kombu.utils.imports import symbol_by_name
 
 #: Billiard sets this when execv is enabled.
@@ -65,8 +65,8 @@ def cwd_in_path() -> Iterator:
 
 
 def find_module(module: str,
-                path: Optional[str]=None,
-                imp: Optional[Callable]=None) -> ModuleType:
+                path: str = None,
+                imp: Callable = None) -> ModuleType:
     """Version of :func:`imp.find_module` supporting dots."""
     if imp is None:
         imp = importlib.import_module
@@ -86,8 +86,8 @@ def find_module(module: str,
 
 
 def import_from_cwd(module: str,
-                    imp: Optional[Callable]=None,
-                    package: Optional[str]=None) -> ModuleType:
+                    imp: Callable = None,
+                    package: str = None) -> ModuleType:
     """Import module, temporarily including modules in the current directory.
 
     Modules located in the current directory has
@@ -100,7 +100,7 @@ def import_from_cwd(module: str,
 
 
 def reload_from_cwd(module: ModuleType,
-                    reloader: Optional[Callable]=None) -> Any:
+                    reloader: Callable = None) -> Any:
     """Reload module (ensuring that CWD is in sys.path)."""
     if reloader is None:
         reloader = reload

+ 1 - 1
celery/utils/log.py

@@ -34,7 +34,7 @@ RESERVED_LOGGER_NAMES = {'celery', 'celery.task'}
 # Every logger in the celery package inherits from the "celery"
 # logger, and every task logger inherits from the "celery.task"
 # logger.
-base_logger = logger = _get_logger('celery')  # type: logging.Logger
+base_logger = logger = _get_logger('celery')
 
 
 def set_in_sighandler(value: bool) -> None:

+ 34 - 17
celery/utils/saferepr.py

@@ -11,13 +11,13 @@ Differences from regular :func:`repr`:
 Very slow with no limits, super quick with limits.
 """
 import traceback
-from collections import Mapping, deque, namedtuple
+from collections import Mapping, deque
 from decimal import Decimal
 from itertools import chain
 from numbers import Number
 from pprint import _recursion
 from typing import (
-    Any, AnyStr, Callable, Iterator, Set, Sequence, Tuple,
+    Any, AnyStr, Callable, Iterator, NamedTuple, Set, Sequence, Tuple,
 )
 from .text import truncate
 
@@ -26,25 +26,42 @@ __all__ = ['saferepr', 'reprstream']
 # pylint: disable=redefined-outer-name
 # We cache globals and attribute lookups, so disable this warning.
 
-#: Node representing literal text.
-#:   - .value: is the literal text value
-#:   - .truncate: specifies if this text can be truncated, for things like
-#:                LIT_DICT_END this will be False, as we always display
-#:                the ending brackets, e.g:  [[[1, 2, 3, ...,], ..., ]]
-#:   - .direction: If +1 the current level is increment by one,
-#:                 if -1 the current level is decremented by one, and
-#:                 if 0 the current level is unchanged.
-_literal = namedtuple('_literal', ('value', 'truncate', 'direction'))
 
-#: Node representing a dictionary key.
-_key = namedtuple('_key', ('value',))
+class _literal(NamedTuple):
+    """Node representing literal text.
 
-#: Node representing quoted text, e.g. a string value.
-_quoted = namedtuple('_quoted', ('value',))
+    Attributes:
+       - .value: is the literal text value
+       - .truncate: specifies if this text can be truncated, for things like
+                    LIT_DICT_END this will be False, as we always display
+                    the ending brackets, e.g:  [[[1, 2, 3, ...,], ..., ]]
+       - .direction: If +1 the current level is increment by one,
+                     if -1 the current level is decremented by one, and
+                     if 0 the current level is unchanged.
+    """
+
+    value: str
+    truncate: bool
+    direction: int
+
+
+class _key(NamedTuple):
+    """Node representing a dictionary key."""
+
+    value: str
+
+
+class _quoted(NamedTuple):
+    """Node representing quoted text, e.g. a string value."""
+
+    value: str
+
+
+class _dirty(NamedTuple):
+    """Recursion protection."""
 
+    objid: int
 
-#: Recursion protection.
-_dirty = namedtuple('_dirty', ('objid',))
 
 #: Types that are repsented as chars.
 chars_t = (bytes, str)

+ 3 - 3
celery/utils/serialization.py

@@ -105,13 +105,13 @@ class UnpickleableExceptionWrapper(Exception):
     """
 
     #: The module of the original exception.
-    exc_module = None       # type: str
+    exc_module: str = None
 
     #: The name of the original exception class.
-    exc_cls_name = None     # type: str
+    exc_cls_name: str = None
 
     #: The arguments for the original exception.
-    exc_args = None         # type: Sequence[Any]
+    exc_args: Sequence[Any] = None
 
     def __init__(self, exc_module: str, exc_cls_name: str,
                  exc_args: Sequence[Any], text: Optional[str]=None) -> None:

+ 2 - 4
celery/utils/static/__init__.py

@@ -2,13 +2,11 @@
 import os
 
 
-def get_file(*args):
-    # type: (*str) -> str
+def get_file(*args) -> str:
     """Get filename for static file."""
     return os.path.join(os.path.abspath(os.path.dirname(__file__)), *args)
 
 
-def logo():
-    # type: () -> bytes
+def logo() -> bytes:
     """Celery logo image."""
     return get_file('celery_128.png')

+ 7 - 4
celery/utils/sysinfo.py

@@ -1,15 +1,18 @@
 # -*- coding: utf-8 -*-
 """System information utilities."""
 import os
-from collections import namedtuple
 from math import ceil
+from typing import NamedTuple
 from kombu.utils.objects import cached_property
 
 __all__ = ['load_average', 'load_average_t', 'df']
 
-load_average_t = namedtuple('load_average_t', (
-    'min_1', 'min_5', 'min_15',
-))
+class load_average_t(NamedTuple):
+    """Load average information triple."""
+
+    min_1: float
+    min_5: float
+    min_15: float
 
 
 def _avg(f: float) -> float:

+ 27 - 26
celery/worker/autoscale.py

@@ -10,15 +10,13 @@ the :option:`celery worker --autoscale` option is used.
 """
 import os
 import threading
-
 from time import monotonic, sleep
-
+from typing import Mapping, Optional, Tuple
 from kombu.async.semaphore import DummyLock
-
 from celery import bootsteps
+from celery.types import AutoscalerT, LoopT, PoolT, RequestT, WorkerT
 from celery.utils.log import get_logger
 from celery.utils.threads import bgThread
-
 from . import state
 from .components import Pool
 
@@ -37,11 +35,11 @@ class WorkerComponent(bootsteps.StartStopStep):
     conditional = True
     requires = (Pool,)
 
-    def __init__(self, w, **kwargs):
+    def __init__(self, w: WorkerT, **kwargs) -> None:
         self.enabled = w.autoscale
         w.autoscaler = None
 
-    def create(self, w):
+    def create(self, w: WorkerT) -> Optional[AutoscalerT]:
         scaler = w.autoscaler = self.instantiate(
             w.autoscaler_cls,
             w.pool, w.max_concurrency, w.min_concurrency,
@@ -49,7 +47,7 @@ class WorkerComponent(bootsteps.StartStopStep):
         )
         return scaler if not w.use_eventloop else None
 
-    def register_with_event_loop(self, w, hub):
+    def register_with_event_loop(self, w: WorkerT, hub: LoopT) -> None:
         w.consumer.on_task_message.add(w.autoscaler.maybe_scale)
         hub.call_repeatedly(
             w.autoscaler.keepalive, w.autoscaler.maybe_scale,
@@ -59,26 +57,29 @@ class WorkerComponent(bootsteps.StartStopStep):
 class Autoscaler(bgThread):
     """Background thread to autoscale pool workers."""
 
-    def __init__(self, pool, max_concurrency,
-                 min_concurrency=0, worker=None,
-                 keepalive=AUTOSCALE_KEEPALIVE, mutex=None):
+    _last_scale_up: float = None
+
+    def __init__(self, pool: PoolT, max_concurrency: int,
+                 min_concurrency: int = 0,
+                 worker: WorkerT = None,
+                 keepalive: float = AUTOSCALE_KEEPALIVE,
+                 mutex: threading.Lock = None) -> None:
         super(Autoscaler, self).__init__()
         self.pool = pool
         self.mutex = mutex or threading.Lock()
         self.max_concurrency = max_concurrency
         self.min_concurrency = min_concurrency
         self.keepalive = keepalive
-        self._last_scale_up = None
         self.worker = worker
 
         assert self.keepalive, 'cannot scale down too fast.'
 
-    def body(self):
+    def body(self) -> None:
         with self.mutex:
             self.maybe_scale()
         sleep(1.0)
 
-    def _maybe_scale(self, req=None):
+    def _maybe_scale(self, req: RequestT = None) -> None:
         procs = self.processes
         cur = min(self.qty, self.max_concurrency)
         if cur > procs:
@@ -89,11 +90,11 @@ class Autoscaler(bgThread):
             self.scale_down(procs - cur)
             return True
 
-    def maybe_scale(self, req=None):
+    def maybe_scale(self, req: RequestT = None) -> None:
         if self._maybe_scale(req):
             self.pool.maintain_pool()
 
-    def update(self, max=None, min=None):
+    def update(self, max: int = None, min: int = None) -> Tuple[int, int]:
         with self.mutex:
             if max is not None:
                 if max < self.processes:
@@ -105,35 +106,35 @@ class Autoscaler(bgThread):
                 self.min_concurrency = min
             return self.max_concurrency, self.min_concurrency
 
-    def force_scale_up(self, n):
+    def force_scale_up(self, n: int) -> None:
         with self.mutex:
             new = self.processes + n
             if new > self.max_concurrency:
                 self.max_concurrency = new
             self._grow(n)
 
-    def force_scale_down(self, n):
+    def force_scale_down(self, n: int) -> None:
         with self.mutex:
             new = self.processes - n
             if new < self.min_concurrency:
                 self.min_concurrency = max(new, 0)
             self._shrink(min(n, self.processes))
 
-    def scale_up(self, n):
+    def scale_up(self, n: int) -> None:
         self._last_scale_up = monotonic()
-        return self._grow(n)
+        self._grow(n)
 
-    def scale_down(self, n):
+    def scale_down(self, n: int) -> None:
         if self._last_scale_up and (
                 monotonic() - self._last_scale_up > self.keepalive):
-            return self._shrink(n)
+            self._shrink(n)
 
-    def _grow(self, n):
+    def _grow(self, n: int) -> None:
         info('Scaling up %s processes.', n)
         self.pool.grow(n)
         self.worker.consumer._update_prefetch_count(n)
 
-    def _shrink(self, n):
+    def _shrink(self, n: int) -> None:
         info('Scaling down %s processes.', n)
         try:
             self.pool.shrink(n)
@@ -143,7 +144,7 @@ class Autoscaler(bgThread):
             error('Autoscaler: scale_down: %r', exc, exc_info=True)
         self.worker.consumer._update_prefetch_count(-n)
 
-    def info(self):
+    def info(self) -> Mapping:
         return {
             'max': self.max_concurrency,
             'min': self.min_concurrency,
@@ -152,9 +153,9 @@ class Autoscaler(bgThread):
         }
 
     @property
-    def qty(self):
+    def qty(self) -> int:
         return len(state.reserved_requests)
 
     @property
-    def processes(self):
+    def processes(self) -> int:
         return self.pool.num_processes

+ 28 - 25
celery/worker/components.py

@@ -2,13 +2,14 @@
 """Worker-level Bootsteps."""
 import atexit
 import warnings
-
+from typing import Mapping, Sequence, Tuple, Union
 from kombu.async import Hub as _Hub, get_event_loop, set_event_loop
 from kombu.async.semaphore import DummyLock, LaxBoundedSemaphore
 from kombu.async.timer import Timer as _Timer
-
+from kombu.types import HubT
 from celery import bootsteps
 from celery._state import _set_task_join_will_block
+from celery.types import BeatT, ConsumerT, PoolT, StepT, WorkerT
 from celery.exceptions import ImproperlyConfigured
 from celery.platforms import IS_WINDOWS
 from celery.utils.log import worker_logger as logger
@@ -32,7 +33,7 @@ as early as possible.
 class Timer(bootsteps.Step):
     """Timer bootstep."""
 
-    def create(self, w):
+    def create(self, w: WorkerT) -> None:
         if w.use_eventloop:
             # does not use dedicated timer thread.
             w.timer = _Timer(max_interval=10.0)
@@ -46,26 +47,26 @@ class Timer(bootsteps.Step):
                                        on_error=self.on_timer_error,
                                        on_tick=self.on_timer_tick)
 
-    def on_timer_error(self, exc):
+    def on_timer_error(self, exc: Exception) -> None:
         logger.error('Timer error: %r', exc, exc_info=True)
 
-    def on_timer_tick(self, delay):
+    def on_timer_tick(self, delay: float) -> None:
         logger.debug('Timer wake-up! Next ETA %s secs.', delay)
 
 
 class Hub(bootsteps.StartStopStep):
     """Worker starts the event loop."""
 
-    requires = (Timer,)
+    requires: Sequence[StepT] = (Timer,)
 
-    def __init__(self, w, **kwargs):
+    def __init__(self, w: WorkerT, **kwargs) -> None:
         w.hub = None
         super(Hub, self).__init__(w, **kwargs)
 
-    def include_if(self, w):
+    def include_if(self, w: WorkerT) -> bool:
         return w.use_eventloop
 
-    def create(self, w):
+    def create(self, w: WorkerT) -> 'Hub':
         w.hub = get_event_loop()
         if w.hub is None:
             required_hub = getattr(w._conninfo, 'requires_hub', None)
@@ -74,16 +75,16 @@ class Hub(bootsteps.StartStopStep):
         self._patch_thread_primitives(w)
         return self
 
-    def start(self, w):
-        pass
+    async def start(self, w: WorkerT) -> None:
+        ...
 
-    def stop(self, w):
+    async def stop(self, w: WorkerT) -> None:
         w.hub.close()
 
-    def terminate(self, w):
+    async def terminate(self, w: WorkerT) -> None:
         w.hub.close()
 
-    def _patch_thread_primitives(self, w):
+    def _patch_thread_primitives(self, w: WorkerT) -> None:
         # make clock use dummy lock
         w.app.clock.mutex = DummyLock()
         # multiprocessing's ApplyResult uses this lock.
@@ -111,7 +112,9 @@ class Pool(bootsteps.StartStopStep):
 
     requires = (Hub,)
 
-    def __init__(self, w, autoscale=None, **kwargs):
+    def __init__(self, w: WorkerT,
+                 autoscale: Union[str, Tuple[int, int]] = None,
+                 **kwargs) -> None:
         w.pool = None
         w.max_concurrency = None
         w.min_concurrency = w.concurrency
@@ -124,15 +127,15 @@ class Pool(bootsteps.StartStopStep):
             w.max_concurrency, w.min_concurrency = w.autoscale
         super(Pool, self).__init__(w, **kwargs)
 
-    def close(self, w):
+    async def close(self, w: WorkerT) -> None:
         if w.pool:
             w.pool.close()
 
-    def terminate(self, w):
+    async def terminate(self, w: WorkerT) -> None:
         if w.pool:
             w.pool.terminate()
 
-    def create(self, w):
+    def create(self, w: WorkerT) -> PoolT:
         semaphore = None
         max_restarts = None
         if w.app.conf.worker_pool in GREEN_POOLS:  # pragma: no cover
@@ -168,10 +171,10 @@ class Pool(bootsteps.StartStopStep):
         _set_task_join_will_block(pool.task_join_will_block)
         return pool
 
-    def info(self, w):
+    def info(self, w: WorkerT) -> Mapping:
         return {'pool': w.pool.info if w.pool else 'N/A'}
 
-    def register_with_event_loop(self, w, hub):
+    def register_with_event_loop(self, w: WorkerT, hub: HubT) -> None:
         w.pool.register_with_event_loop(hub)
 
 
@@ -184,12 +187,12 @@ class Beat(bootsteps.StartStopStep):
     label = 'Beat'
     conditional = True
 
-    def __init__(self, w, beat=False, **kwargs):
+    def __init__(self, w: WorkerT, beat: bool = False, **kwargs) -> None:
         self.enabled = w.beat = beat
         w.beat = None
         super(Beat, self).__init__(w, beat=beat, **kwargs)
 
-    def create(self, w):
+    def create(self, w: WorkerT) -> BeatT:
         from celery.beat import EmbeddedService
         if w.pool_cls.__module__.endswith(('gevent', 'eventlet')):
             raise ImproperlyConfigured(ERR_B_GREEN)
@@ -202,12 +205,12 @@ class Beat(bootsteps.StartStopStep):
 class StateDB(bootsteps.Step):
     """Bootstep that sets up between-restart state database file."""
 
-    def __init__(self, w, **kwargs):
+    def __init__(self, w: WorkerT, **kwargs) -> None:
         self.enabled = w.statedb
         w._persistence = None
         super(StateDB, self).__init__(w, **kwargs)
 
-    def create(self, w):
+    def create(self, w: WorkerT) -> None:
         w._persistence = w.state.Persistent(w.state, w.statedb, w.app.clock)
         atexit.register(w._persistence.save)
 
@@ -217,7 +220,7 @@ class Consumer(bootsteps.StartStopStep):
 
     last = True
 
-    def create(self, w):
+    def create(self, w: WorkerT) -> ConsumerT:
         if w.max_concurrency:
             prefetch_count = max(w.min_concurrency, 1) * w.prefetch_multiplier
         else:

+ 4 - 2
celery/worker/consumer/agent.py

@@ -1,5 +1,7 @@
 """Celery + :pypi:`cell` integration."""
+from typing import Any
 from celery import bootsteps
+from celery.types import WorkerConsumerT
 from .connection import Connection
 
 __all__ = ['Agent']
@@ -11,10 +13,10 @@ class Agent(bootsteps.StartStopStep):
     conditional = True
     requires = (Connection,)
 
-    def __init__(self, c, **kwargs):
+    def __init__(self, c: WorkerConsumerT, **kwargs) -> None:
         self.agent_cls = self.enabled = c.app.conf.worker_agent
         super(Agent, self).__init__(c, **kwargs)
 
-    def create(self, c):
+    def create(self, c: WorkerConsumerT) -> Any:
         agent = c.agent = self.instantiate(self.agent_cls, c.connection)
         return agent

+ 8 - 6
celery/worker/consumer/connection.py

@@ -1,6 +1,8 @@
 """Consumer Broker Connection Bootstep."""
+from typing import Mapping
 from kombu.common import ignore_errors
 from celery import bootsteps
+from celery.types import WorkerConsumerT
 from celery.utils.log import get_logger
 
 __all__ = ['Connection']
@@ -12,22 +14,22 @@ info = logger.info
 class Connection(bootsteps.StartStopStep):
     """Service managing the consumer broker connection."""
 
-    def __init__(self, c, **kwargs):
+    def __init__(self, c: WorkerConsumerT, **kwargs) -> None:
         c.connection = None
         super(Connection, self).__init__(c, **kwargs)
 
-    def start(self, c):
-        c.connection = c.connect()
+    async def start(self, c) -> None:
+        c.connection = await c.connect()
         info('Connected to %s', c.connection.as_uri())
 
-    def shutdown(self, c):
+    async def shutdown(self, c) -> None:
         # We must set self.connection to None here, so
         # that the green pidbox thread exits.
         connection, c.connection = c.connection, None
         if connection:
-            ignore_errors(connection, connection.close)
+            await ignore_errors(connection, connection.close)
 
-    def info(self, c):
+    def info(self, c) -> Mapping:
         params = 'N/A'
         if c.connection:
             params = c.connection.info()

+ 113 - 83
celery/worker/consumer/consumer.py

@@ -11,19 +11,24 @@ import os
 
 from collections import defaultdict
 from time import sleep
+from typing import Callable, Mapping, Tuple
 
 from billiard.common import restart_state
 from billiard.exceptions import RestartFreqExceeded
 from kombu.async.semaphore import DummyLock
+from kombu.types import ConnectionT, MessageT
 from kombu.utils.compat import _detect_environment
 from kombu.utils.encoding import safe_repr
 from kombu.utils.limits import TokenBucket
-from vine import ppartial, promise
+from vine import Thenable, ppartial, promise
 
 from celery import bootsteps
 from celery import signals
 from celery.app.trace import build_tracer
 from celery.exceptions import InvalidTaskError, NotRegistered
+from celery.types import (
+    AppT, LoopT, PoolT, RequestT, TaskT, TimerT, WorkerT, WorkerConsumerT,
+)
 from celery.utils.functional import noop
 from celery.utils.log import get_logger
 from celery.utils.nodenames import gethostname
@@ -109,7 +114,7 @@ body: {0}
 """
 
 
-def dump_body(m, body):
+def dump_body(m: MessageT, body: bytes):
     """Format message body for debugging purposes."""
     # v2 protocol does not deserialize body
     body = m.body if body is None else body
@@ -151,15 +156,25 @@ class Consumer:
             'celery.worker.consumer.agent:Agent',
         ]
 
-        def shutdown(self, parent):
-            self.send_all(parent, 'shutdown')
-
-    def __init__(self, on_task_request,
-                 init_callback=noop, hostname=None,
-                 pool=None, app=None,
-                 timer=None, controller=None, hub=None, amqheartbeat=None,
-                 worker_options=None, disable_rate_limits=False,
-                 initial_prefetch_count=2, prefetch_multiplier=1, **kwargs):
+        async def shutdown(self, parent: WorkerConsumerT) -> None:
+            await self.send_all(parent, 'shutdown')
+
+    def __init__(self,
+                 on_task_request: Callable,
+                 *,
+                 init_callback: Callable = noop,
+                 hostname: str = None,
+                 pool: PoolT = None,
+                 app: AppT = None,
+                 timer: TimerT = None,
+                 controller: WorkerT = None,
+                 hub: LoopT = None,
+                 amqheartbeat: float = None,
+                 worker_options: Mapping = None,
+                 disable_rate_limits: bool = False,
+                 initial_prefetch_count: int = 2,
+                 prefetch_multiplier: int = 1,
+                 **kwargs) -> None:
         self.app = app
         self.controller = controller
         self.init_callback = init_callback
@@ -213,14 +228,14 @@ class Consumer:
         )
         self.blueprint.apply(self, **dict(worker_options or {}, **kwargs))
 
-    def call_soon(self, p, *args, **kwargs):
+    def call_soon(self, p: Thenable, *args, **kwargs) -> Thenable:
         p = ppartial(p, *args, **kwargs)
         if self.hub:
             return self.hub.call_soon(p)
         self._pending_operations.append(p)
         return p
 
-    def perform_pending_operations(self):
+    def perform_pending_operations(self) -> None:
         if not self.hub:
             while self._pending_operations:
                 try:
@@ -228,16 +243,16 @@ class Consumer:
                 except Exception as exc:  # pylint: disable=broad-except
                     logger.exception('Pending callback raised: %r', exc)
 
-    def bucket_for_task(self, type):
+    def bucket_for_task(self, type: TaskT) -> TokenBucket:
         limit = rate(getattr(type, 'rate_limit', None))
         return TokenBucket(limit, capacity=1) if limit else None
 
-    def reset_rate_limits(self):
+    def reset_rate_limits(self) -> None:
         self.task_buckets.update(
             (n, self.bucket_for_task(t)) for n, t in self.app.tasks.items()
         )
 
-    def _update_prefetch_count(self, index=0):
+    def _update_prefetch_count(self, index: int = 0) -> int:
         """Update prefetch count after pool/shrink grow operations.
 
         Index must be the change in number of processes as a positive
@@ -256,33 +271,35 @@ class Consumer:
         )
         return self._update_qos_eventually(index)
 
-    def _update_qos_eventually(self, index):
+    def _update_qos_eventually(self, index: int) -> int:
         return (self.qos.decrement_eventually if index < 0
                 else self.qos.increment_eventually)(
             abs(index) * self.prefetch_multiplier)
 
-    def _limit_move_to_pool(self, request):
+    async def _limit_move_to_pool(self, request: RequestT) -> None:
         task_reserved(request)
-        self.on_task_request(request)
+        await self.on_task_request(request)
 
-    def _on_bucket_wakeup(self, bucket, tokens):
+    async def _on_bucket_wakeup(self, bucket: TokenBucket, tokens: int):
         try:
             request = bucket.pop()
         except IndexError:
             pass
         else:
-            self._limit_move_to_pool(request)
+            await self._limit_move_to_pool(request)
             self._schedule_oldest_bucket_request(bucket, tokens)
 
-    def _schedule_oldest_bucket_request(self, bucket, tokens):
+    def _schedule_oldest_bucket_request(
+            self, bucket: TokenBucket, tokens: int) -> None:
         try:
             request = bucket.pop()
         except IndexError:
             pass
         else:
-            return self._schedule_bucket_request(request, bucket, tokens)
+            self._schedule_bucket_request(request, bucket, tokens)
 
-    def _schedule_bucket_request(self, request, bucket, tokens):
+    def _schedule_bucket_request(
+            self, request: RequestT, bucket: TokenBucket, tokens: int) -> None:
         bucket.can_consume(tokens)
         bucket.add(request)
         pri = self._limit_order = (self._limit_order + 1) % 10
@@ -292,12 +309,16 @@ class Consumer:
             priority=pri,
         )
 
-    def _limit_task(self, request, bucket, tokens):
+    def _limit_task(
+            self, request: RequestT,
+            bucket: TokenBucket,
+            tokens: int) -> None:
         if bucket.contents:
-            return bucket.add(request)
-        return self._schedule_bucket_request(request, bucket, tokens)
+            bucket.add(request)
+        else:
+            self._schedule_bucket_request(request, bucket, tokens)
 
-    def start(self):
+    async def start(self) -> None:
         blueprint = self.blueprint
         while blueprint.state != CLOSE:
             maybe_shutdown()
@@ -309,7 +330,7 @@ class Consumer:
                     sleep(1)
             self.restart_count += 1
             try:
-                blueprint.start(self)
+                await blueprint.start(self)
             except self.connection_errors as exc:
                 # If we're not retrying connections, no need to catch
                 # connection errors
@@ -324,42 +345,42 @@ class Consumer:
                     else:
                         self.on_connection_error_before_connected(exc)
                     self.on_close()
-                    blueprint.restart(self)
+                    await blueprint.restart(self)
 
-    def on_connection_error_before_connected(self, exc):
+    def on_connection_error_before_connected(self, exc: Exception) -> None:
         error(CONNECTION_ERROR, self.conninfo.as_uri(), exc,
               'Trying to reconnect...')
 
-    def on_connection_error_after_connected(self, exc):
+    def on_connection_error_after_connected(self, exc: Exception) -> None:
         warn(CONNECTION_RETRY, exc_info=True)
         try:
             self.connection.collect()
         except Exception:  # pylint: disable=broad-except
             pass
 
-    def register_with_event_loop(self, hub):
-        self.blueprint.send_all(
+    async def register_with_event_loop(self, hub: LoopT) -> None:
+        await self.blueprint.send_all(
             self, 'register_with_event_loop', args=(hub,),
             description='Hub.register',
         )
 
-    def shutdown(self):
-        self.blueprint.shutdown(self)
+    async def shutdown(self) -> None:
+        await self.blueprint.shutdown(self)
 
-    def stop(self):
-        self.blueprint.stop(self)
+    async def stop(self) -> None:
+        await self.blueprint.stop(self)
 
-    def on_ready(self):
+    def on_ready(self) -> None:
         callback, self.init_callback = self.init_callback, None
         if callback:
             callback(self)
 
-    def loop_args(self):
+    def loop_args(self) -> Tuple:
         return (self, self.connection, self.task_consumer,
                 self.blueprint, self.hub, self.qos, self.amqheartbeat,
                 self.app.clock, self.amqheartbeat_rate)
 
-    def on_decode_error(self, message, exc):
+    async def on_decode_error(self, message: MessageT, exc: Exception) -> None:
         """Callback called if an error occurs while decoding a message.
 
         Simply logs the error and acknowledges the message so it
@@ -373,9 +394,9 @@ class Consumer:
              exc, message.content_type, message.content_encoding,
              safe_repr(message.headers), dump_body(message, message.body),
              exc_info=1)
-        message.ack()
+        await message.ack()
 
-    def on_close(self):
+    def on_close(self) -> None:
         # Clear internal queues to get rid of old messages.
         # They can't be acked anyway, as a delivery tag is specific
         # to the current channel.
@@ -390,26 +411,29 @@ class Consumer:
         if self.pool and self.pool.flush:
             self.pool.flush()
 
-    def connect(self):
+    async def connect(self) -> ConnectionT:
         """Establish the broker connection used for consuming tasks.
 
         Retries establishing the connection if the
         :setting:`broker_connection_retry` setting is enabled
         """
         conn = self.connection_for_read(heartbeat=self.amqheartbeat)
+        await conn.connect()
         if self.hub:
             conn.transport.register_with_event_loop(conn.connection, self.hub)
         return conn
 
-    def connection_for_read(self, heartbeat=None):
-        return self.ensure_connected(
+    async def connection_for_read(self,
+                                  heartbeat: float = None) -> ConnectionT:
+        return await self.ensure_connected(
             self.app.connection_for_read(heartbeat=heartbeat))
 
-    def connection_for_write(self, heartbeat=None):
-        return self.ensure_connected(
+    async def connection_for_write(self,
+                                   heartbeat: float = None) -> ConnectionT:
+        return await self.ensure_connected(
             self.app.connection_for_write(heartbeat=heartbeat))
 
-    def ensure_connected(self, conn):
+    async def ensure_connected(self, conn: ConnectionT) -> ConnectionT:
         # Callback called for each retry while the connection
         # can't be established.
         def _error_handler(exc, interval, next_step=CONNECTION_RETRY_STEP):
@@ -422,25 +446,28 @@ class Consumer:
         # until needed.
         if not self.app.conf.broker_connection_retry:
             # retry disabled, just call connect directly.
-            conn.connect()
+            await conn.connect()
             return conn
 
-        conn = conn.ensure_connection(
+        conn = await conn.ensure_connection(
             _error_handler, self.app.conf.broker_connection_max_retries,
             callback=maybe_shutdown,
         )
         return conn
 
-    def _flush_events(self):
+    def _flush_events(self) -> None:
         if self.event_dispatcher:
             self.event_dispatcher.flush()
 
-    def on_send_event_buffered(self):
+    def on_send_event_buffered(self) -> None:
         if self.hub:
             self.hub._ready.add(self._flush_events)
 
-    def add_task_queue(self, queue, exchange=None, exchange_type=None,
-                       routing_key=None, **options):
+    async def add_task_queue(self, queue: str,
+                             exchange: str = None,
+                             exchange_type: str = None,
+                             routing_key: str = None,
+                             **options) -> None:
         cset = self.task_consumer
         queues = self.app.amqp.queues
         # Must use in' here, as __missing__ will automatically
@@ -458,33 +485,34 @@ class Consumer:
                                   routing_key=routing_key, **options)
         if not cset.consuming_from(queue):
             cset.add_queue(q)
-            cset.consume()
+            await cset.consume()
             info('Started consuming from %s', queue)
 
-    def cancel_task_queue(self, queue):
+    async def cancel_task_queue(self, queue):
         info('Canceling queue %s', queue)
         self.app.amqp.queues.deselect(queue)
-        self.task_consumer.cancel_by_queue(queue)
+        await self.task_consumer.cancel_by_queue(queue)
 
-    def apply_eta_task(self, task):
+    def apply_eta_task(self, request: RequestT) -> None:
         """Method called by the timer to apply a task with an ETA/countdown."""
-        task_reserved(task)
-        self.on_task_request(task)
+        task_reserved(request)
+        self.on_task_request(request)
         self.qos.decrement_eventually()
 
-    def _message_report(self, body, message):
+    def _message_report(self, body: bytes, message: MessageT) -> str:
         return MESSAGE_REPORT.format(dump_body(message, body),
                                      safe_repr(message.content_type),
                                      safe_repr(message.content_encoding),
                                      safe_repr(message.delivery_info),
                                      safe_repr(message.headers))
 
-    def on_unknown_message(self, body, message):
+    async def on_unknown_message(self, body: bytes, message: MessageT) -> None:
         warn(UNKNOWN_FORMAT, self._message_report(body, message))
-        message.reject_log_error(logger, self.connection_errors)
+        await message.reject_log_error(logger, self.connection_errors)
         signals.task_rejected.send(sender=self, message=message, exc=None)
 
-    def on_unknown_task(self, body, message, exc):
+    async def on_unknown_task(
+            self, body: bytes, message: MessageT, exc: Exception) -> None:
         error(UNKNOWN_TASK_ERROR, exc, dump_body(message, body), exc_info=True)
         try:
             id_, name = message.headers['id'], message.headers['task']
@@ -499,12 +527,12 @@ class Consumer:
             reply_to=message.properties.get('reply_to'),
             errbacks=None,
         )
-        message.reject_log_error(logger, self.connection_errors)
-        self.app.backend.mark_as_failure(
+        await message.reject_log_error(logger, self.connection_errors)
+        await self.app.backend.mark_as_failure(
             id_, NotRegistered(name), request=request,
         )
         if self.event_dispatcher:
-            self.event_dispatcher.send(
+            await self.event_dispatcher.send(
                 'task-failed', uuid=id_,
                 exception='NotRegistered({0!r})'.format(name),
             )
@@ -512,27 +540,29 @@ class Consumer:
             sender=self, message=message, exc=exc, name=name, id=id_,
         )
 
-    def on_invalid_task(self, body, message, exc):
+    async def on_invalid_task(
+            self, body: bytes, message: MessageT, exc: Exception) -> None:
         error(INVALID_TASK_ERROR, exc, dump_body(message, body), exc_info=True)
-        message.reject_log_error(logger, self.connection_errors)
+        await message.reject_log_error(logger, self.connection_errors)
         signals.task_rejected.send(sender=self, message=message, exc=exc)
 
-    def update_strategies(self):
+    def update_strategies(self) -> None:
         loader = self.app.loader
         for name, task in self.app.tasks.items():
             self.strategies[name] = task.start_strategy(self.app, self)
             task.__trace__ = build_tracer(name, task, loader, self.hostname,
                                           app=self.app)
 
-    def create_task_handler(self, promise=promise):
+    def create_task_handler(self) -> Callable:
         strategies = self.strategies
         on_unknown_message = self.on_unknown_message
         on_unknown_task = self.on_unknown_task
         on_invalid_task = self.on_invalid_task
         callbacks = self.on_task_message
         call_soon = self.call_soon
+        promise_t = promise
 
-        def on_task_received(message):
+        async def on_task_received(message: MessageT) -> None:
             # payload will only be set for v1 protocol, since v2
             # will defer deserializing the message body to the pool.
             payload = None
@@ -544,29 +574,29 @@ class Consumer:
                 try:
                     payload = message.decode()
                 except Exception as exc:  # pylint: disable=broad-except
-                    return self.on_decode_error(message, exc)
+                    return await self.on_decode_error(message, exc)
                 try:
                     type_, payload = payload['task'], payload  # protocol v1
                 except (TypeError, KeyError):
-                    return on_unknown_message(payload, message)
+                    return await on_unknown_message(payload, message)
             try:
                 strategy = strategies[type_]
             except KeyError as exc:
-                return on_unknown_task(None, message, exc)
+                return await on_unknown_task(None, message, exc)
             else:
                 try:
-                    strategy(
+                    await strategy(
                         message, payload,
-                        promise(call_soon, (message.ack_log_error,)),
-                        promise(call_soon, (message.reject_log_error,)),
+                        promise_t(call_soon, (message.ack_log_error,)),
+                        promise_t(call_soon, (message.reject_log_error,)),
                         callbacks,
                     )
                 except InvalidTaskError as exc:
-                    return on_invalid_task(payload, message, exc)
+                    return await on_invalid_task(payload, message, exc)
 
         return on_task_received
 
-    def __repr__(self):
+    def __repr__(self) -> str:
         """``repr(self)``."""
         return '<Consumer: {self.hostname} ({state})>'.format(
             self=self, state=self.blueprint.human_state(),
@@ -583,9 +613,9 @@ class Evloop(bootsteps.StartStopStep):
     label = 'event loop'
     last = True
 
-    def start(self, c):
+    async def start(self, c: WorkerConsumerT) -> None:
         self.patch_all(c)
         c.loop(*c.loop_args())
 
-    def patch_all(self, c):
+    def patch_all(self, c: WorkerConsumerT) -> None:
         c.qos._mutex = DummyLock()

+ 3 - 2
celery/worker/consumer/control.py

@@ -6,6 +6,7 @@ The actual commands are implemented in :mod:`celery.worker.control`.
 """
 from celery import bootsteps
 from celery.utils.log import get_logger
+from celery.types import WorkerConsumerT
 from celery.worker import pidbox
 from .tasks import Tasks
 
@@ -19,7 +20,7 @@ class Control(bootsteps.StartStopStep):
 
     requires = (Tasks,)
 
-    def __init__(self, c, **kwargs):
+    def __init__(self, c: WorkerConsumerT, **kwargs) -> None:
         self.is_green = c.pool is not None and c.pool.is_green
         self.box = (pidbox.gPidbox if self.is_green else pidbox.Pidbox)(c)
         self.start = self.box.start
@@ -27,6 +28,6 @@ class Control(bootsteps.StartStopStep):
         self.shutdown = self.box.shutdown
         super(Control, self).__init__(c, **kwargs)
 
-    def include_if(self, c):
+    def include_if(self, c: WorkerConsumerT) -> bool:
         return (c.app.conf.worker_enable_remote_control and
                 c.conninfo.supports_exchange_type('fanout'))

+ 14 - 13
celery/worker/consumer/events.py

@@ -4,6 +4,7 @@
 """
 from kombu.common import ignore_errors
 from celery import bootsteps
+from celery.types import WorkerConsumerT
 from .connection import Connection
 
 __all__ = ['Events']
@@ -14,11 +15,11 @@ class Events(bootsteps.StartStopStep):
 
     requires = (Connection,)
 
-    def __init__(self, c,
-                 task_events=True,
-                 without_heartbeat=False,
-                 without_gossip=False,
-                 **kwargs):
+    def __init__(self, c: WorkerConsumerT,
+                 task_events: bool = True,
+                 without_heartbeat: bool = False,
+                 without_gossip: bool = False,
+                 **kwargs) -> None:
         self.groups = None if task_events else ['worker']
         self.send_events = (
             task_events or
@@ -28,7 +29,7 @@ class Events(bootsteps.StartStopStep):
         c.event_dispatcher = None
         super(Events, self).__init__(c, **kwargs)
 
-    def start(self, c):
+    async def start(self, c: WorkerConsumerT) -> None:
         # flush events sent while connection was down.
         prev = self._close(c)
         dis = c.event_dispatcher = c.app.events.Dispatcher(
@@ -45,10 +46,13 @@ class Events(bootsteps.StartStopStep):
             dis.extend_buffer(prev)
             dis.flush()
 
-    def stop(self, c):
-        pass
+    async def stop(self, c: WorkerConsumerT) -> None:
+        ...
 
-    def _close(self, c):
+    async def shutdown(self, c: WorkerConsumerT) -> None:
+        await self._close(c)
+
+    async def _close(self, c: WorkerConsumerT) -> None:
         if c.event_dispatcher:
             dispatcher = c.event_dispatcher
             # remember changes from remote control commands:
@@ -56,10 +60,7 @@ class Events(bootsteps.StartStopStep):
 
             # close custom connection
             if dispatcher.connection:
-                ignore_errors(c, dispatcher.connection.close)
+                await ignore_errors(c, dispatcher.connection.close)
             ignore_errors(c, dispatcher.close)
             c.event_dispatcher = None
             return dispatcher
-
-    def shutdown(self, c):
-        self._close(c)

+ 48 - 38
celery/worker/consumer/gossip.py

@@ -3,11 +3,14 @@ from collections import defaultdict
 from functools import partial
 from heapq import heappush
 from operator import itemgetter
+from typing import Callable, Set, Sequence
 
 from kombu import Consumer
 from kombu.async.semaphore import DummyLock
+from kombu.types import ChannelT, ConsumerT, MessageT
 
 from celery import bootsteps
+from celery.types import AppT, EventT, SignatureT, WorkerT, WorkerConsumerT
 from celery.utils.log import get_logger
 from celery.utils.objects import Bunch
 
@@ -29,10 +32,13 @@ class Gossip(bootsteps.ConsumerStep):
     _cons_stamp_fields = itemgetter(
         'id', 'clock', 'hostname', 'pid', 'topic', 'action', 'cver',
     )
-    compatible_transports = {'amqp', 'redis'}
+    compatible_transports: Set[str] = {'amqp', 'redis'}
 
-    def __init__(self, c, without_gossip=False,
-                 interval=5.0, heartbeat_interval=2.0, **kwargs):
+    def __init__(self, c: WorkerConsumerT,
+                 without_gossip: bool = False,
+                 interval: float = 5.0,
+                 heartbeat_interval: float = 2.0,
+                 **kwargs) -> None:
         self.enabled = not without_gossip and self.compatible_transport(c.app)
         self.app = c.app
         c.gossip = self
@@ -72,40 +78,41 @@ class Gossip(bootsteps.ConsumerStep):
 
         super(Gossip, self).__init__(c, **kwargs)
 
-    def compatible_transport(self, app):
+    def compatible_transport(self, app: AppT) -> bool:
         with app.connection_for_read() as conn:
             return conn.transport.driver_type in self.compatible_transports
 
-    def election(self, id, topic, action=None):
+    async def election(self, id: str, topic: str, action: str = None):
         self.consensus_replies[id] = []
-        self.dispatcher.send(
+        await self.dispatcher.send(
             'worker-elect',
             id=id, topic=topic, action=action, cver=1,
         )
 
-    def call_task(self, task):
+    async def call_task(self, task: SignatureT) -> None:
         try:
             self.app.signature(task).apply_async()
         except Exception as exc:  # pylint: disable=broad-except
             logger.exception('Could not call task: %r', exc)
 
-    def on_elect(self, event):
+    async def on_elect(self, event: EventT) -> None:
         try:
             (id_, clock, hostname, pid,
              topic, action, _) = self._cons_stamp_fields(event)
         except KeyError as exc:
-            return logger.exception('election request missing field %s', exc)
-        heappush(
-            self.consensus_requests[id_],
-            (clock, '%s.%s' % (hostname, pid), topic, action),
-        )
-        self.dispatcher.send('worker-elect-ack', id=id_)
+            logger.exception('election request missing field %s', exc)
+        else:
+            heappush(
+                self.consensus_requests[id_],
+                (clock, '%s.%s' % (hostname, pid), topic, action),
+            )
+            await self.dispatcher.send('worker-elect-ack', id=id_)
 
-    def start(self, c):
+    async def start(self, c: WorkerConsumerT) -> None:
         super().start(c)
         self.dispatcher = c.event_dispatcher
 
-    def on_elect_ack(self, event):
+    async def on_elect_ack(self, event: EventT) -> None:
         id = event['id']
         try:
             replies = self.consensus_replies[id]
@@ -125,59 +132,62 @@ class Gossip(bootsteps.ConsumerStep):
                 except KeyError:
                     logger.exception('Unknown election topic %r', topic)
                 else:
-                    handler(action)
+                    await handler(action)
             else:
                 info('node %s elected for %r', leader, id)
             self.consensus_requests.pop(id, None)
             self.consensus_replies.pop(id, None)
 
-    def on_node_join(self, worker):
+    async def on_node_join(self, worker: WorkerT):
         debug('%s joined the party', worker.hostname)
-        self._call_handlers(self.on.node_join, worker)
+        await self._call_handlers(self.on.node_join, worker)
 
-    def on_node_leave(self, worker):
+    async def on_node_leave(self, worker: WorkerT):
         debug('%s left', worker.hostname)
-        self._call_handlers(self.on.node_leave, worker)
+        await self._call_handlers(self.on.node_leave, worker)
 
-    def on_node_lost(self, worker):
+    async def on_node_lost(self, worker: WorkerT):
         info('missed heartbeat from %s', worker.hostname)
-        self._call_handlers(self.on.node_lost, worker)
+        await self._call_handlers(self.on.node_lost, worker)
 
-    def _call_handlers(self, handlers, *args, **kwargs):
+    async def _call_handlers(self, handlers: Sequence[Callable],
+                             *args, **kwargs) -> None:
         for handler in handlers:
             try:
-                handler(*args, **kwargs)
+                await handler(*args, **kwargs)
             except Exception as exc:  # pylint: disable=broad-except
                 logger.exception(
                     'Ignored error from handler %r: %r', handler, exc)
 
-    def register_timer(self):
+    def register_timer(self) -> None:
         if self._tref is not None:
             self._tref.cancel()
         self._tref = self.timer.call_repeatedly(self.interval, self.periodic)
 
-    def periodic(self):
+    async def periodic(self) -> None:
         workers = self.state.workers
         dirty = set()
         for worker in workers.values():
             if not worker.alive:
                 dirty.add(worker)
-                self.on_node_lost(worker)
+                await self.on_node_lost(worker)
         for worker in dirty:
             workers.pop(worker.hostname, None)
 
-    def get_consumers(self, channel):
+    def get_consumers(self, channel: ChannelT) -> Sequence[ConsumerT]:
         self.register_timer()
         ev = self.Receiver(channel, routing_key='worker.#',
                            queue_ttl=self.heartbeat_interval)
-        return [Consumer(
-            channel,
-            queues=[ev.queue],
-            on_message=partial(self.on_message, ev.event_from_message),
-            no_ack=True
-        )]
-
-    def on_message(self, prepare, message):
+        return [
+            Consumer(
+                channel,
+                queues=[ev.queue],
+                on_message=partial(self.on_message, ev.event_from_message),
+                no_ack=True
+            )
+        ]
+
+    async def on_message(self, prepare: Callable, message: MessageT) -> None:
         _type = message.delivery_info['routing_key']
 
         # For redis when `fanout_patterns=False` (See Issue #1882)
@@ -188,7 +198,7 @@ class Gossip(bootsteps.ConsumerStep):
         except KeyError:
             pass
         else:
-            return handler(message.payload)
+            return await handler(message.payload)
 
         # proto2: hostname in header; proto1: in body
         hostname = (message.headers.get('hostname') or

+ 11 - 6
celery/worker/consumer/heart.py

@@ -1,5 +1,6 @@
 """Worker Event Heartbeat Bootstep."""
 from celery import bootsteps
+from celery.types import WorkerConsumerT
 from celery.worker import heartbeat
 from .events import Events
 
@@ -17,19 +18,23 @@ class Heart(bootsteps.StartStopStep):
 
     requires = (Events,)
 
-    def __init__(self, c,
-                 without_heartbeat=False, heartbeat_interval=None, **kwargs):
+    def __init__(self, c: WorkerConsumerT,
+                 without_heartbeat: bool = False,
+                 heartbeat_interval: float = None,
+                 **kwargs) -> None:
         self.enabled = not without_heartbeat
         self.heartbeat_interval = heartbeat_interval
         c.heart = None
         super(Heart, self).__init__(c, **kwargs)
 
-    def start(self, c):
+    async def start(self, c: WorkerConsumerT) -> None:
         c.heart = heartbeat.Heart(
             c.timer, c.event_dispatcher, self.heartbeat_interval,
         )
-        c.heart.start()
+        await c.heart.start()
 
-    def stop(self, c):
-        c.heart = c.heart and c.heart.stop()
+    async def stop(self, c: WorkerConsumerT) -> None:
+        heart, c.heart = c.heart, None
+        if heart:
+            await heart.stop()
     shutdown = stop

+ 22 - 13
celery/worker/consumer/mingle.py

@@ -1,5 +1,7 @@
 """Worker <-> Worker Sync at startup (Bootstep)."""
+from typing import Mapping
 from celery import bootsteps
+from celery.types import AppT, WorkerConsumerT
 from celery.utils.log import get_logger
 from .events import Events
 
@@ -23,53 +25,60 @@ class Mingle(bootsteps.StartStopStep):
     requires = (Events,)
     compatible_transports = {'amqp', 'redis'}
 
-    def __init__(self, c, without_mingle=False, **kwargs):
+    def __init__(self, c: WorkerConsumerT,
+                 without_mingle: bool = False,
+                 **kwargs) -> None:
         self.enabled = not without_mingle and self.compatible_transport(c.app)
         super(Mingle, self).__init__(
             c, without_mingle=without_mingle, **kwargs)
 
-    def compatible_transport(self, app):
+    def compatible_transport(self, app: AppT) -> bool:
         with app.connection_for_read() as conn:
             return conn.transport.driver_type in self.compatible_transports
 
-    def start(self, c):
-        self.sync(c)
+    async def start(self, c: WorkerConsumerT) -> None:
+        await self.sync(c)
 
-    def sync(self, c):
+    async def sync(self, c: WorkerConsumerT) -> None:
         info('mingle: searching for neighbors')
-        replies = self.send_hello(c)
+        replies = await self.send_hello(c)
         if replies:
             info('mingle: sync with %s nodes',
                  len([reply for reply, value in replies.items() if value]))
-            [self.on_node_reply(c, nodename, reply)
+            [await self.on_node_reply(c, nodename, reply)
              for nodename, reply in replies.items() if reply]
             info('mingle: sync complete')
         else:
             info('mingle: all alone')
 
-    def send_hello(self, c):
+    async def send_hello(self, c: WorkerConsumerT) -> Mapping:
         inspect = c.app.control.inspect(timeout=1.0, connection=c.connection)
         our_revoked = c.controller.state.revoked
         replies = inspect.hello(c.hostname, our_revoked._data) or {}
         replies.pop(c.hostname, None)  # delete my own response
         return replies
 
-    def on_node_reply(self, c, nodename, reply):
+    async def on_node_reply(self, c: WorkerConsumerT,
+                            nodename: str, reply: Mapping) -> None:
         debug('mingle: processing reply from %s', nodename)
         try:
-            self.sync_with_node(c, **reply)
+            await self.sync_with_node(c, **reply)
         except MemoryError:
             raise
         except Exception as exc:  # pylint: disable=broad-except
             exception('mingle: sync with %s failed: %r', nodename, exc)
 
-    def sync_with_node(self, c, clock=None, revoked=None, **kwargs):
+    async def sync_with_node(self, c: WorkerConsumerT,
+                             clock: int = None,
+                             revoked: Mapping = None,
+                             **kwargs) -> None:
         self.on_clock_event(c, clock)
         self.on_revoked_received(c, revoked)
 
-    def on_clock_event(self, c, clock):
+    def on_clock_event(self, c: WorkerConsumerT, clock: int):
         c.app.clock.adjust(clock) if clock else c.app.clock.forward()
 
-    def on_revoked_received(self, c, revoked):
+    def on_revoked_received(self,
+                            c: WorkerConsumerT, revoked: Mapping) -> None:
         if revoked:
             c.controller.state.revoked.update(revoked)

+ 9 - 7
celery/worker/consumer/tasks.py

@@ -1,6 +1,8 @@
 """Worker Task Consumer Bootstep."""
+from typing import Mapping
 from kombu.common import QoS, ignore_errors
 from celery import bootsteps
+from celery.types import WorkerConsumerT
 from celery.utils.log import get_logger
 from .mingle import Mingle
 
@@ -14,11 +16,11 @@ class Tasks(bootsteps.StartStopStep):
 
     requires = (Mingle,)
 
-    def __init__(self, c, **kwargs):
+    def __init__(self, c: WorkerConsumerT, **kwargs) -> None:
         c.task_consumer = c.qos = None
         super(Tasks, self).__init__(c, **kwargs)
 
-    def start(self, c):
+    async def start(self, c: WorkerConsumerT) -> None:
         """Start task consumer."""
         c.update_strategies()
 
@@ -36,20 +38,20 @@ class Tasks(bootsteps.StartStopStep):
             c.connection, on_decode_error=c.on_decode_error,
         )
 
-        def set_prefetch_count(prefetch_count):
-            return c.task_consumer.qos(
+        async def set_prefetch_count(prefetch_count: int) -> None:
+            await c.task_consumer.qos(
                 prefetch_count=prefetch_count,
                 apply_global=qos_global,
             )
         c.qos = QoS(set_prefetch_count, c.initial_prefetch_count)
 
-    def stop(self, c):
+    async def stop(self, c: WorkerConsumerT) -> None:
         """Stop task consumer."""
         if c.task_consumer:
             debug('Canceling task consumer...')
             ignore_errors(c, c.task_consumer.cancel)
 
-    def shutdown(self, c):
+    async def shutdown(self, c: WorkerConsumerT) -> None:
         """Shutdown task consumer."""
         if c.task_consumer:
             self.stop(c)
@@ -57,6 +59,6 @@ class Tasks(bootsteps.StartStopStep):
             ignore_errors(c, c.task_consumer.close)
             c.task_consumer = None
 
-    def info(self, c):
+    def info(self, c: WorkerConsumerT) -> Mapping:
         """Return task consumer info."""
         return {'prefetch_count': c.qos.value if c.qos else 'N/A'}

+ 100 - 58
celery/worker/control.py

@@ -2,19 +2,19 @@
 """Worker remote control command implementations."""
 import io
 import tempfile
-
-from collections import UserDict, namedtuple
-
+from collections import UserDict
+from typing import (
+    Any, Callable, Iterable, Mapping, NamedTuple, Sequence, Tuple, Union,
+)
 from billiard.common import TERM_SIGNAME
 from kombu.utils.encoding import safe_repr
-
 from celery.exceptions import WorkerShutdown
 from celery.platforms import signals as _signals
+from celery.types import ControlStateT as StateT, RequestT, TimerT
 from celery.utils.functional import maybe_list
 from celery.utils.log import get_logger
 from celery.utils.serialization import jsonify, strtobool
 from celery.utils.time import rate
-
 from . import state as worker_state
 from .request import Request
 
@@ -23,17 +23,24 @@ __all__ = ['Panel']
 DEFAULT_TASK_INFO_ITEMS = ('exchange', 'routing_key', 'rate_limit')
 logger = get_logger(__name__)
 
-controller_info_t = namedtuple('controller_info_t', [
-    'alias', 'type', 'visible', 'default_timeout',
-    'help', 'signature', 'args', 'variadic',
-])
+class controller_info_t(NamedTuple):
+    """Metadata about control command."""
+
+    alias: str
+    type: str
+    visible: bool
+    default_timeout: float
+    help: str
+    signature: str
+    args: Sequence[Tuple[str, type]]
+    variadic: str
 
 
-def ok(value):
+def ok(value: Any) -> Mapping:
     return {'ok': value}
 
 
-def nok(value):
+def nok(value: Any) -> Mapping:
     return {'error': value}
 
 
@@ -44,17 +51,24 @@ class Panel(UserDict):
     meta = {}      # -"-
 
     @classmethod
-    def register(cls, *args, **kwargs):
+    def register(cls, *args, **kwargs) -> Callable:
         if args:
             return cls._register(**kwargs)(*args)
         return cls._register(**kwargs)
 
     @classmethod
-    def _register(cls, name=None, alias=None, type='control',
-                  visible=True, default_timeout=1.0, help=None,
-                  signature=None, args=None, variadic=None):
-
-        def _inner(fun):
+    def _register(cls,
+                  name: str = None,
+                  alias: str = None,
+                  type: str = 'control',
+                  visible: bool = True,
+                  default_timeout: float = 1.0,
+                  help: str = None,
+                  signature: str = None,
+                  args: Sequence[Tuple[str, type]] = None,
+                  variadic: str = None) -> Callable:
+
+        def _inner(fun: Callable) -> Callable:
             control_name = name or fun.__name__
             _help = help or (fun.__doc__ or '').strip().split('\n')[0]
             cls.data[control_name] = fun
@@ -67,18 +81,18 @@ class Panel(UserDict):
         return _inner
 
 
-def control_command(**kwargs):
+def control_command(**kwargs) -> Callable:
     return Panel.register(type='control', **kwargs)
 
 
-def inspect_command(**kwargs):
+def inspect_command(**kwargs) -> Callable:
     return Panel.register(type='inspect', **kwargs)
 
 # -- App
 
 
 @inspect_command()
-def report(state):
+def report(state: StateT) -> Mapping:
     """Information about Celery installation for bug reports."""
     return ok(state.app.bugreport())
 
@@ -88,14 +102,14 @@ def report(state):
     signature='[include_defaults=False]',
     args=[('with_defaults', strtobool)],
 )
-def conf(state, with_defaults=False, **kwargs):
+def conf(state: StateT, with_defaults: bool = False, **kwargs) -> Mapping:
     """List configuration."""
     return jsonify(state.app.conf.table(with_defaults=with_defaults),
                    keyfilter=_wanted_config_key,
                    unknown_type_filter=safe_repr)
 
 
-def _wanted_config_key(key):
+def _wanted_config_key(key: Any) -> bool:
     return isinstance(key, str) and not key.startswith('__')
 
 
@@ -105,7 +119,9 @@ def _wanted_config_key(key):
     variadic='ids',
     signature='[id1 [id2 [... [idN]]]]',
 )
-def query_task(state, ids, **kwargs):
+def query_task(state: StateT,
+               ids: Union[str, Sequence[str]],
+               **kwargs) -> Mapping[str, Tuple[str, Mapping]]:
     """Query for task information by id."""
     return {
         req.id: (_state_of_task(req), req.info())
@@ -113,8 +129,10 @@ def query_task(state, ids, **kwargs):
     }
 
 
-def _find_requests_by_id(ids,
-                         get_request=worker_state.requests.__getitem__):
+def _find_requests_by_id(
+        ids: Sequence[str],
+        *,
+        get_request=worker_state.requests.__getitem__) -> Iterable[RequestT]:
     for task_id in ids:
         try:
             yield get_request(task_id)
@@ -122,9 +140,11 @@ def _find_requests_by_id(ids,
             pass
 
 
-def _state_of_task(request,
-                   is_active=worker_state.active_requests.__contains__,
-                   is_reserved=worker_state.reserved_requests.__contains__):
+def _state_of_task(
+        request: RequestT,
+        *,
+        is_active=worker_state.active_requests.__contains__,
+        is_reserved=worker_state.reserved_requests.__contains__) -> str:
     if is_active(request):
         return 'active'
     elif is_reserved(request):
@@ -136,7 +156,9 @@ def _state_of_task(request,
     variadic='task_id',
     signature='[id1 [id2 [... [idN]]]]',
 )
-def revoke(state, task_id, terminate=False, signal=None, **kwargs):
+def revoke(state: StateT, task_id: str,
+           terminate: bool = False,
+           signal: Union[str, int] = None, **kwargs) -> Mapping:
     """Revoke task by task id (or list of ids).
 
     Keyword Arguments:
@@ -176,7 +198,7 @@ def revoke(state, task_id, terminate=False, signal=None, **kwargs):
     args=[('signal', str)],
     signature='<signal> [id1 [id2 [... [idN]]]]'
 )
-def terminate(state, signal, task_id, **kwargs):
+def terminate(state: StateT, signal: str, task_id: str, **kwargs) -> Mapping:
     """Terminate task by task id (or list of ids)."""
     return revoke(state, task_id, terminate=True, signal=signal)
 
@@ -185,7 +207,8 @@ def terminate(state, signal, task_id, **kwargs):
     args=[('task_name', str), ('rate_limit', str)],
     signature='<task_name> <rate_limit (e.g., 5/s | 5/m | 5/h)>',
 )
-def rate_limit(state, task_name, rate_limit, **kwargs):
+def rate_limit(state: StateT, task_name: str, rate_limit: Union[str, int],
+               **kwargs) -> Mapping:
     """Tell worker(s) to modify the rate limit for a task by type.
 
     See Also:
@@ -225,7 +248,10 @@ def rate_limit(state, task_name, rate_limit, **kwargs):
     args=[('task_name', str), ('soft', float), ('hard', float)],
     signature='<task_name> <soft_secs> [hard_secs]',
 )
-def time_limit(state, task_name=None, hard=None, soft=None, **kwargs):
+def time_limit(state: StateT,
+               task_name: str = None,
+               hard: float = None,
+               soft: float = None, **kwargs) -> Mapping:
     """Tell worker(s) to modify the time limit for task by type.
 
     Arguments:
@@ -252,13 +278,14 @@ def time_limit(state, task_name=None, hard=None, soft=None, **kwargs):
 
 
 @inspect_command()
-def clock(state, **kwargs):
+def clock(state: StateT, **kwargs) -> Mapping:
     """Get current logical clock value."""
     return {'clock': state.app.clock.value}
 
 
 @control_command()
-def election(state, id, topic, action=None, **kwargs):
+def election(state: StateT, id: str, topic: str, action: str = None,
+             **kwargs) -> None:
     """Hold election.
 
     Arguments:
@@ -271,7 +298,7 @@ def election(state, id, topic, action=None, **kwargs):
 
 
 @control_command()
-def enable_events(state):
+def enable_events(state: StateT) -> Mapping:
     """Tell worker(s) to send task-related events."""
     dispatcher = state.consumer.event_dispatcher
     if dispatcher.groups and 'task' not in dispatcher.groups:
@@ -282,7 +309,7 @@ def enable_events(state):
 
 
 @control_command()
-def disable_events(state):
+def disable_events(state: StateT) -> Mapping:
     """Tell worker(s) to stop sending task-related events."""
     dispatcher = state.consumer.event_dispatcher
     if 'task' in dispatcher.groups:
@@ -293,7 +320,7 @@ def disable_events(state):
 
 
 @control_command()
-def heartbeat(state):
+def heartbeat(state: StateT) -> None:
     """Tell worker(s) to send event heartbeat immediately."""
     logger.debug('Heartbeat requested by remote.')
     dispatcher = state.consumer.event_dispatcher
@@ -303,7 +330,8 @@ def heartbeat(state):
 # -- Worker
 
 @inspect_command(visible=False)
-def hello(state, from_node, revoked=None, **kwargs):
+def hello(state: StateT, from_node: str,
+          revoked: Mapping = None, **kwargs) -> Mapping:
     """Request mingle sync-data."""
     # pylint: disable=redefined-outer-name
     # XXX Note that this redefines `revoked`:
@@ -319,24 +347,24 @@ def hello(state, from_node, revoked=None, **kwargs):
 
 
 @inspect_command(default_timeout=0.2)
-def ping(state, **kwargs):
+def ping(state: StateT, **kwargs) -> Mapping:
     """Ping worker(s)."""
     return ok('pong')
 
 
 @inspect_command()
-def stats(state, **kwargs):
+def stats(state: StateT, **kwargs) -> Mapping:
     """Request worker statistics/information."""
     return state.consumer.controller.stats()
 
 
 @inspect_command(alias='dump_schedule')
-def scheduled(state, **kwargs):
+def scheduled(state: StateT, **kwargs) -> Sequence[Mapping]:
     """List of currently scheduled ETA/countdown tasks."""
     return list(_iter_schedule_requests(state.consumer.timer))
 
 
-def _iter_schedule_requests(timer):
+def _iter_schedule_requests(timer: TimerT) -> Iterable[Mapping]:
     for waiting in timer.schedule.queue:
         try:
             arg0 = waiting.entry.args[0]
@@ -352,7 +380,7 @@ def _iter_schedule_requests(timer):
 
 
 @inspect_command(alias='dump_reserved')
-def reserved(state, **kwargs):
+def reserved(state: StateT, **kwargs) -> Sequence[Mapping]:
     """List of currently reserved tasks, not including scheduled/active."""
     reserved_tasks = (
         state.tset(worker_state.reserved_requests) -
@@ -364,14 +392,14 @@ def reserved(state, **kwargs):
 
 
 @inspect_command(alias='dump_active')
-def active(state, **kwargs):
+def active(state: StateT, **kwargs) -> Sequence[Mapping]:
     """List of tasks currently being executed."""
     return [request.info()
             for request in state.tset(worker_state.active_requests)]
 
 
 @inspect_command(alias='dump_revoked')
-def revoked(state, **kwargs):
+def revoked(state: StateT, **kwargs) -> Sequence[str]:
     """List of revoked task-ids."""
     return list(worker_state.revoked)
 
@@ -381,7 +409,9 @@ def revoked(state, **kwargs):
     variadic='taskinfoitems',
     signature='[attr1 [attr2 [... [attrN]]]]',
 )
-def registered(state, taskinfoitems=None, builtins=False, **kwargs):
+def registered(state: StateT,
+               taskinfoitems: Sequence[str] = None,
+               builtins: bool = False, **kwargs) -> Sequence[Mapping]:
     """List of registered tasks.
 
     Arguments:
@@ -415,7 +445,10 @@ def registered(state, taskinfoitems=None, builtins=False, **kwargs):
     args=[('type', str), ('num', int), ('max_depth', int)],
     signature='[object_type=Request] [num=200 [max_depth=10]]',
 )
-def objgraph(state, num=200, max_depth=10, type='Request'):  # pragma: no cover
+def objgraph(state: StateT,  # pragma: no cover
+             num: int = 200,
+             max_depth: int = 10,
+             type: str = 'Request') -> Mapping:
     """Create graph of uncollected objects (memory-leak debugging).
 
     Arguments:
@@ -440,7 +473,7 @@ def objgraph(state, num=200, max_depth=10, type='Request'):  # pragma: no cover
 
 
 @inspect_command()
-def memsample(state, **kwargs):
+def memsample(state: StateT, **kwargs) -> str:
     """Sample current RSS memory usage."""
     from celery.utils.debug import sample_mem
     return sample_mem()
@@ -450,7 +483,8 @@ def memsample(state, **kwargs):
     args=[('samples', int)],
     signature='[n_samples=10]',
 )
-def memdump(state, samples=10, **kwargs):  # pragma: no cover
+def memdump(state: StateT,  # pragma: no cover
+            samples: int = 10, **kwargs) -> str:
     """Dump statistics of previous memsample requests."""
     from celery.utils import debug
     out = io.StringIO()
@@ -464,7 +498,7 @@ def memdump(state, samples=10, **kwargs):  # pragma: no cover
     args=[('n', int)],
     signature='[N=1]',
 )
-def pool_grow(state, n=1, **kwargs):
+def pool_grow(state: StateT, n: int = 1, **kwargs) -> Mapping:
     """Grow pool by n processes/threads."""
     if state.consumer.controller.autoscaler:
         state.consumer.controller.autoscaler.force_scale_up(n)
@@ -478,7 +512,7 @@ def pool_grow(state, n=1, **kwargs):
     args=[('n', int)],
     signature='[N=1]',
 )
-def pool_shrink(state, n=1, **kwargs):
+def pool_shrink(state: StateT, n: int = 1, **kwargs) -> Mapping:
     """Shrink pool by n processes/threads."""
     if state.consumer.controller.autoscaler:
         state.consumer.controller.autoscaler.force_scale_down(n)
@@ -489,7 +523,11 @@ def pool_shrink(state, n=1, **kwargs):
 
 
 @control_command()
-def pool_restart(state, modules=None, reload=False, reloader=None, **kwargs):
+def pool_restart(state: StateT,
+                 modules: Sequence[str] = None,
+                 reload: bool = False,
+                 reloader: Callable = None,
+                 **kwargs) -> Mapping:
     """Restart execution pool."""
     if state.app.conf.worker_pool_restarts:
         state.consumer.controller.reload(modules, reload, reloader=reloader)
@@ -502,7 +540,7 @@ def pool_restart(state, modules=None, reload=False, reloader=None, **kwargs):
     args=[('max', int), ('min', int)],
     signature='[max [min]]',
 )
-def autoscale(state, max=None, min=None):
+def autoscale(state: StateT, max: int = None, min: int = None) -> Mapping:
     """Modify autoscale settings."""
     autoscaler = state.consumer.controller.autoscaler
     if autoscaler:
@@ -512,7 +550,8 @@ def autoscale(state, max=None, min=None):
 
 
 @control_command()
-def shutdown(state, msg='Got shutdown from remote', **kwargs):
+def shutdown(state: StateT, msg: str = 'Got shutdown from remote',
+             **kwargs) -> None:
     """Shutdown worker(s)."""
     logger.warning(msg)
     raise WorkerShutdown(msg)
@@ -529,8 +568,11 @@ def shutdown(state, msg='Got shutdown from remote', **kwargs):
     ],
     signature='<queue> [exchange [type [routing_key]]]',
 )
-def add_consumer(state, queue, exchange=None, exchange_type=None,
-                 routing_key=None, **options):
+def add_consumer(state: StateT, queue: str,
+                 exchange: str = None,
+                 exchange_type: str = None,
+                 routing_key: str = None,
+                 **options) -> Mapping:
     """Tell worker(s) to consume from task queue by name."""
     state.consumer.call_soon(
         state.consumer.add_task_queue,
@@ -542,7 +584,7 @@ def add_consumer(state, queue, exchange=None, exchange_type=None,
     args=[('queue', str)],
     signature='<queue>',
 )
-def cancel_consumer(state, queue, **_):
+def cancel_consumer(state: StateT, queue: str, **_) -> Mapping:
     """Tell worker(s) to stop consuming from task queue by name."""
     state.consumer.call_soon(
         state.consumer.cancel_task_queue, queue,
@@ -551,7 +593,7 @@ def cancel_consumer(state, queue, **_):
 
 
 @inspect_command()
-def active_queues(state):
+async def active_queues(state: StateT) -> Sequence[Mapping]:
     """List the task queues a worker are currently consuming from."""
     if state.consumer.task_consumer:
         return [dict(queue.as_dict(recurse=True))

+ 15 - 9
celery/worker/heartbeat.py

@@ -4,7 +4,10 @@
 This is the internal thread responsible for sending heartbeat events
 at regular intervals (may not be an actual thread).
 """
+from typing import Awaitable
+from celery.events import EventDispatcher
 from celery.signals import heartbeat_sent
+from celery.types import EventT, TimerT
 from celery.utils.sysinfo import load_average
 from .state import SOFTWARE_INFO, active_requests, all_total_count
 
@@ -22,7 +25,8 @@ class Heart:
             heartbeats.  Default is 2 seconds.
     """
 
-    def __init__(self, timer, eventer, interval=None):
+    def __init__(self, *, timer: TimerT, eventer: EventDispatcher,
+                 interval: float = None):
         self.timer = timer
         self.eventer = eventer
         self.interval = float(interval or 2.0)
@@ -36,23 +40,25 @@ class Heart:
         self._send_sent_signal = (
             heartbeat_sent.send if heartbeat_sent.receivers else None)
 
-    def _send(self, event):
+    async def _send(self, event: EventT) -> Awaitable:
         if self._send_sent_signal is not None:
             self._send_sent_signal(sender=self)
-        return self.eventer.send(event, freq=self.interval,
-                                 active=len(active_requests),
-                                 processed=all_total_count[0],
-                                 loadavg=load_average(),
-                                 **SOFTWARE_INFO)
+        return await self.eventer.send(
+            event,
+            freq=self.interval,
+            active=len(active_requests),
+            processed=all_total_count[0],
+            loadavg=load_average(),
+            **SOFTWARE_INFO)
 
-    def start(self):
+    async def start(self) -> None:
         if self.eventer.enabled:
             self._send('worker-online')
             self.tref = self.timer.call_repeatedly(
                 self.interval, self._send, ('worker-heartbeat',),
             )
 
-    def stop(self):
+    async def stop(self) -> None:
         if self.tref is not None:
             self.timer.cancel(self.tref)
             self.tref = None

+ 72 - 47
celery/worker/request.py

@@ -7,21 +7,27 @@ how tasks are executed.
 import logging
 import sys
 
-from datetime import datetime
+from datetime import datetime, tzinfo
+from typing import Any, Awaitable, Callable, Mapping, Sequence, Tuple, Union
 from weakref import ref
 
 from billiard.common import TERM_SIGNAME
+from billiard.einfo import ExceptionInfo
+from kombu.types import MessageT
 from kombu.utils.encoding import safe_repr, safe_str
 from kombu.utils.objects import cached_property
 
 from celery import signals
 from celery.app.trace import trace_task, trace_task_ret
+from celery.events import EventDispatcher
 from celery.exceptions import (
     Ignore, TaskRevokedError, InvalidTaskError,
     SoftTimeLimitExceeded, TimeLimitExceeded,
     WorkerLostError, Terminated, Retry, Reject,
 )
 from celery.platforms import signals as _signals
+from celery.types import AppT, PoolT, SignatureT, TaskT
+from celery.utils.collections import LimitedSet
 from celery.utils.functional import maybe, noop
 from celery.utils.log import get_logger
 from celery.utils.nodenames import gethostname
@@ -44,7 +50,7 @@ _does_info = False
 _does_debug = False
 
 
-def __optimize__():
+def __optimize__() -> None:
     # this is also called by celery.app.trace.setup_worker_optimizations
     global _does_debug
     global _does_info
@@ -65,13 +71,13 @@ class Request:
     """A request for task execution."""
 
     acknowledged = False
-    time_start = None
-    worker_pid = None
-    time_limits = (None, None)
-    _already_revoked = False
-    _terminate_on_ack = None
-    _apply_result = None
-    _tzlocal = None
+    time_start: float = None
+    worker_pid: int = None
+    time_limits: Tuple[float, float] = (None, None)
+    _already_revoked: bool = False
+    _terminate_on_ack: bool = None
+    _apply_result: Awaitable = None
+    _tzlocal: tzinfo = None
 
     if not IS_PYPY:  # pragma: no cover
         __slots__ = (
@@ -83,13 +89,23 @@ class Request:
             '__weakref__', '__dict__',
         )
 
-    def __init__(self, message, on_ack=noop,
-                 hostname=None, eventer=None, app=None,
-                 connection_errors=None, request_dict=None,
-                 task=None, on_reject=noop, body=None,
-                 headers=None, decoded=False, utc=True,
-                 maybe_make_aware=maybe_make_aware,
-                 maybe_iso8601=maybe_iso8601, **opts):
+    def __init__(self, message: MessageT,
+                 *,
+                 on_ack: Callable = noop,
+                 hostname: str = None,
+                 eventer: EventDispatcher = None,
+                 app: AppT = None,
+                 connection_errors: Tuple = None,
+                 request_dict: Mapping = None,
+                 task: str = None,
+                 on_reject: Callable = noop,
+                 body: Mapping = None,
+                 headers: Mapping = None,
+                 decoded: bool = False,
+                 utc: bool = True,
+                 maybe_make_aware: Callable = maybe_make_aware,
+                 maybe_iso8601: Callable = maybe_iso8601,
+                 **opts) -> None:
         if headers is None:
             headers = message.headers
         if body is None:
@@ -163,10 +179,10 @@ class Request:
         self.request_dict = headers
 
     @property
-    def delivery_info(self):
+    def delivery_info(self) -> Mapping:
         return self.request_dict['delivery_info']
 
-    def execute_using_pool(self, pool, **kwargs):
+    def execute_using_pool(self, pool: PoolT, **kwargs) -> Awaitable:
         """Used by the worker to send this task to the pool.
 
         Arguments:
@@ -198,7 +214,7 @@ class Request:
         self._apply_result = maybe(ref, result)
         return result
 
-    def execute(self, loglevel=None, logfile=None):
+    def execute(self, loglevel: int = None, logfile: str = None) -> Any:
         """Execute the task in a :func:`~celery.app.trace.trace_task`.
 
         Arguments:
@@ -230,15 +246,16 @@ class Request:
         self.acknowledge()
         return retval
 
-    def maybe_expire(self):
+    def maybe_expire(self) -> bool:
         """If expired, mark the task as revoked."""
         if self.expires:
             now = datetime.now(self.expires.tzinfo)
             if now > self.expires:
                 revoked_tasks.add(self.id)
                 return True
+        return False
 
-    def terminate(self, pool, signal=None):
+    def terminate(self, pool: PoolT, signal: Union[str, int] = None) -> None:
         signal = _signals.signum(signal or TERM_SIGNAME)
         if self.time_start:
             pool.terminate_job(self.worker_pid, signal)
@@ -250,7 +267,8 @@ class Request:
             if obj is not None:
                 obj.terminate(signal)
 
-    def _announce_revoked(self, reason, terminated, signum, expired):
+    def _announce_revoked(self, reason: str, terminated: bool,
+                          signum: int, expired: bool):
         task_ready(self)
         self.send_event('task-revoked',
                         terminated=terminated, signum=signum, expired=expired)
@@ -262,7 +280,7 @@ class Request:
         send_revoked(self.task, request=self,
                      terminated=terminated, signum=signum, expired=expired)
 
-    def revoked(self):
+    def revoked(self) -> bool:
         """If revoked, skip task and mark state."""
         expired = False
         if self._already_revoked:
@@ -277,11 +295,11 @@ class Request:
             return True
         return False
 
-    def send_event(self, type, **fields):
+    def send_event(self, type: str, **fields) -> Awaitable:
         if self.eventer and self.eventer.enabled and self.task.send_events:
-            self.eventer.send(type, uuid=self.id, **fields)
+            return self.eventer.send(type, uuid=self.id, **fields)
 
-    def on_accepted(self, pid, time_accepted):
+    def on_accepted(self, pid: int, time_accepted: float) -> None:
         """Handler called when task is accepted by worker pool."""
         self.worker_pid = pid
         self.time_start = time_accepted
@@ -294,7 +312,7 @@ class Request:
         if self._terminate_on_ack is not None:
             self.terminate(*self._terminate_on_ack)
 
-    def on_timeout(self, soft, timeout):
+    def on_timeout(self, soft: bool, timeout: float) -> None:
         """Handler called if the task times out."""
         task_ready(self)
         if soft:
@@ -313,7 +331,8 @@ class Request:
         if self.task.acks_late:
             self.acknowledge()
 
-    def on_success(self, failed__retval__runtime, **kwargs):
+    def on_success(self, failed__retval__runtime: Tuple[bool, Any, float],
+                   **kwargs) -> None:
         """Handler called if the task was successfully processed."""
         failed, retval, runtime = failed__retval__runtime
         if failed:
@@ -327,7 +346,7 @@ class Request:
 
         self.send_event('task-succeeded', result=retval, runtime=runtime)
 
-    def on_retry(self, exc_info):
+    def on_retry(self, exc_info: ExceptionInfo) -> None:
         """Handler called if the task should be retried."""
         if self.task.acks_late:
             self.acknowledge()
@@ -336,7 +355,9 @@ class Request:
                         exception=safe_repr(exc_info.exception.exc),
                         traceback=safe_str(exc_info.traceback))
 
-    def on_failure(self, exc_info, send_failed_event=True, return_ok=False):
+    def on_failure(self, exc_info: ExceptionInfo,
+                   send_failed_event: bool = True,
+                   return_ok: bool = False) -> None:
         """Handler called if the task raised an exception."""
         task_ready(self)
         if isinstance(exc_info.exception, MemoryError):
@@ -385,19 +406,19 @@ class Request:
             error('Task handler raised error: %r', exc,
                   exc_info=exc_info.exc_info)
 
-    def acknowledge(self):
+    def acknowledge(self) -> None:
         """Acknowledge task."""
         if not self.acknowledged:
             self.on_ack(logger, self.connection_errors)
             self.acknowledged = True
 
-    def reject(self, requeue=False):
+    def reject(self, requeue: bool = False) -> None:
         if not self.acknowledged:
             self.on_reject(logger, self.connection_errors, requeue)
             self.acknowledged = True
             self.send_event('task-rejected', requeue=requeue)
 
-    def info(self, safe=False):
+    def info(self, safe: bool = False) -> Mapping:
         return {
             'id': self.id,
             'name': self.name,
@@ -411,10 +432,10 @@ class Request:
             'worker_pid': self.worker_pid,
         }
 
-    def humaninfo(self):
+    def humaninfo(self) -> str:
         return '{0.name}[{0.id}]'.format(self)
 
-    def __str__(self):
+    def __str__(self) -> str:
         """``str(self)``."""
         return ' '.join([
             self.humaninfo(),
@@ -422,7 +443,7 @@ class Request:
             ' expires:[{0}]'.format(self.expires) if self.expires else '',
         ])
 
-    def __repr__(self):
+    def __repr__(self) -> str:
         """``repr(self)``."""
         return '<{0}: {1} {2} {3}>'.format(
             type(self).__name__, self.humaninfo(),
@@ -430,32 +451,32 @@ class Request:
         )
 
     @property
-    def tzlocal(self):
+    def tzlocal(self) -> tzinfo:
         if self._tzlocal is None:
             self._tzlocal = self.app.conf.timezone
         return self._tzlocal
 
     @property
-    def store_errors(self):
+    def store_errors(self) -> bool:
         return (not self.task.ignore_result or
                 self.task.store_errors_even_if_ignored)
 
     @property
-    def reply_to(self):
+    def reply_to(self) -> str:
         # used by rpc backend when failures reported by parent process
         return self.request_dict['reply_to']
 
     @property
-    def correlation_id(self):
+    def correlation_id(self) -> str:
         # used similarly to reply_to
         return self.request_dict['correlation_id']
 
     @cached_property
-    def _payload(self):
+    def _payload(self) -> Any:
         return self.body if self._decoded else self.message.payload
 
     @cached_property
-    def chord(self):
+    def chord(self) -> SignatureT:
         # used by backend.mark_as_failure when failure is reported
         # by parent process
         # pylint: disable=unpacking-non-sequence
@@ -464,7 +485,7 @@ class Request:
         return embed.get('chord')
 
     @cached_property
-    def errbacks(self):
+    def errbacks(self) -> Sequence[SignatureT]:
         # used by backend.mark_as_failure when failure is reported
         # by parent process
         # pylint: disable=unpacking-non-sequence
@@ -473,15 +494,19 @@ class Request:
         return embed.get('errbacks')
 
     @cached_property
-    def group(self):
+    def group(self) -> str:
         # used by backend.on_chord_part_return when failures reported
         # by parent process
         return self.request_dict['group']
 
 
-def create_request_cls(base, task, pool, hostname, eventer,
-                       ref=ref, revoked_tasks=revoked_tasks,
-                       task_ready=task_ready, trace=trace_task_ret):
+def create_request_cls(base: type, task: TaskT, pool: PoolT,
+                       hostname: str, eventer: EventDispatcher,
+                       *,
+                       ref: Callable = ref,
+                       revoked_tasks: LimitedSet = revoked_tasks,
+                       task_ready: Callable = task_ready,
+                       trace: Callable = trace_task_ret) -> type:
     default_time_limit = task.time_limit
     default_soft_time_limit = task.soft_time_limit
     apply_async = pool.apply_async

+ 52 - 40
celery/worker/state.py

@@ -12,12 +12,17 @@ import weakref
 import zlib
 
 from collections import Counter
+from typing import (
+    Any, Callable, Mapping, MutableMapping, Set, Sequence, Union,
+)
 
+from kombu.clocks import Clock
 from kombu.serialization import pickle, pickle_protocol
 from kombu.utils.objects import cached_property
 
 from celery import __version__
 from celery.exceptions import WorkerShutdown, WorkerTerminate
+from celery.types import RequestT
 from celery.utils.collections import LimitedSet
 
 __all__ = [
@@ -27,7 +32,7 @@ __all__ = [
 ]
 
 #: Worker software/platform information.
-SOFTWARE_INFO = {
+SOFTWARE_INFO: Mapping[str, Any] = {
     'sw_ident': 'py-celery',
     'sw_ver': __version__,
     'sw_sys': platform.system(),
@@ -41,28 +46,28 @@ REVOKES_MAX = 50000
 REVOKE_EXPIRES = 10800
 
 #: Mapping of reserved task_id->Request.
-requests = {}
+requests: Mapping[RequestT] = {}
 
 #: set of all reserved :class:`~celery.worker.request.Request`'s.
-reserved_requests = weakref.WeakSet()
+reserved_requests: Set[RequestT] = weakref.WeakSet()
 
 #: set of currently active :class:`~celery.worker.request.Request`'s.
-active_requests = weakref.WeakSet()
+active_requests: Set[RequestT] = weakref.WeakSet()
 
 #: count of tasks accepted by the worker, sorted by type.
-total_count = Counter()
+total_count: Mapping[str, int] = Counter()
 
 #: count of all tasks accepted by the worker
-all_total_count = [0]
+all_total_count: Sequence[int] = [0]
 
 #: the list of currently revoked tasks.  Persistent if ``statedb`` set.
 revoked = LimitedSet(maxlen=REVOKES_MAX, expires=REVOKE_EXPIRES)
 
-should_stop = None
-should_terminate = None
+should_stop: Union[int, bool] = None
+should_terminate: Union[int, bool] = None
 
 
-def reset_state():
+def reset_state() -> None:
     requests.clear()
     reserved_requests.clear()
     active_requests.clear()
@@ -71,7 +76,7 @@ def reset_state():
     revoked.clear()
 
 
-def maybe_shutdown():
+def maybe_shutdown() -> None:
     """Shutdown if flags have been set."""
     if should_stop is not None and should_stop is not False:
         raise WorkerShutdown(should_stop)
@@ -79,28 +84,34 @@ def maybe_shutdown():
         raise WorkerTerminate(should_terminate)
 
 
-def task_reserved(request,
-                  add_request=requests.__setitem__,
-                  add_reserved_request=reserved_requests.add):
+def task_reserved(
+        request: RequestT,
+        *,
+        add_request: Callable = requests.__setitem__,
+        add_reserved_request: Callable = reserved_requests.add) -> None:
     """Update global state when a task has been reserved."""
     add_request(request.id, request)
     add_reserved_request(request)
 
 
-def task_accepted(request,
-                  _all_total_count=all_total_count,
-                  add_active_request=active_requests.add,
-                  add_to_total_count=total_count.update):
+def task_accepted(
+        request: RequestT,
+        *,
+        _all_total_count: Sequence[int] = all_total_count,
+        add_active_request: Callable = active_requests.add,
+        add_to_total_count: Callable = total_count.update) -> None:
     """Update global state when a task has been accepted."""
     add_active_request(request)
     add_to_total_count({request.name: 1})
     all_total_count[0] += 1
 
 
-def task_ready(request,
-               remove_request=requests.pop,
-               discard_active_request=active_requests.discard,
-               discard_reserved_request=reserved_requests.discard):
+def task_ready(
+        request: RequestT,
+        *,
+        remove_request: Callable = requests.pop,
+        discard_reserved_request: Callable = reserved_requests.discard,
+        discard_active_request: Callable = active_requests.discard) -> None:
     """Update global state when a task is ready."""
     remove_request(request.id, None)
     discard_active_request(request)
@@ -129,7 +140,7 @@ if C_BENCH:  # pragma: no cover
 
     if current_process()._name == 'MainProcess':
         @atexit.register
-        def on_shutdown():
+        def on_shutdown() -> None:
             if bench_first is not None and bench_last is not None:
                 print('- Time spent in benchmark: {0!r}'.format(
                       bench_last - bench_first))
@@ -137,7 +148,7 @@ if C_BENCH:  # pragma: no cover
                       sum(bench_sample) / len(bench_sample)))
                 memdump()
 
-    def task_reserved(request):  # noqa
+    def task_reserved(request: RequestT, **kwargs) -> None:  # noqa
         """Called when a task is reserved by the worker."""
         global bench_start
         global bench_first
@@ -149,7 +160,7 @@ if C_BENCH:  # pragma: no cover
 
         return __reserved(request)
 
-    def task_ready(request):  # noqa
+    def task_ready(request: RequestT, **kwargs) -> None:  # noqa
         """Called when a task is completed."""
         global all_count
         global bench_start
@@ -182,39 +193,40 @@ class Persistent:
     decompress = zlib.decompress
     _is_open = False
 
-    def __init__(self, state, filename, clock=None):
+    def __init__(self, state: Any,
+                 filename: str, clock: Clock = None) -> None:
         self.state = state
         self.filename = filename
         self.clock = clock
         self.merge()
 
-    def open(self):
+    def open(self) -> MutableMapping:
         return self.storage.open(
             self.filename, protocol=self.protocol, writeback=True,
         )
 
-    def merge(self):
+    def merge(self) -> None:
         self._merge_with(self.db)
 
-    def sync(self):
+    def sync(self) -> None:
         self._sync_with(self.db)
         self.db.sync()
 
-    def close(self):
+    def close(self) -> None:
         if self._is_open:
             self.db.close()
             self._is_open = False
 
-    def save(self):
+    def save(self) -> None:
         self.sync()
         self.close()
 
-    def _merge_with(self, d):
+    def _merge_with(self, d: MutableMapping) -> MutableMapping:
         self._merge_revoked(d)
         self._merge_clock(d)
         return d
 
-    def _sync_with(self, d):
+    def _sync_with(self, d: MutableMapping) -> MutableMapping:
         self._revoked_tasks.purge()
         d.update({
             str('__proto__'): 3,
@@ -223,11 +235,11 @@ class Persistent:
         })
         return d
 
-    def _merge_clock(self, d):
+    def _merge_clock(self, d: MutableMapping):
         if self.clock:
             d[str('clock')] = self.clock.adjust(d.get(str('clock')) or 0)
 
-    def _merge_revoked(self, d):
+    def _merge_revoked(self, d: MutableMapping) -> None:
         try:
             self._merge_revoked_v3(d[str('zrevoked')])
         except KeyError:
@@ -238,29 +250,29 @@ class Persistent:
         # purge expired items at boot
         self._revoked_tasks.purge()
 
-    def _merge_revoked_v3(self, zrevoked):
+    def _merge_revoked_v3(self, zrevoked: str) -> None:
         if zrevoked:
             self._revoked_tasks.update(pickle.loads(self.decompress(zrevoked)))
 
-    def _merge_revoked_v2(self, saved):
+    def _merge_revoked_v2(self, saved: Mapping) -> Mapping:
         if not isinstance(saved, LimitedSet):
             # (pre 3.0.18) used to be stored as a dict
             return self._merge_revoked_v1(saved)
         self._revoked_tasks.update(saved)
 
-    def _merge_revoked_v1(self, saved):
+    def _merge_revoked_v1(self, saved: Sequence) -> None:
         add = self._revoked_tasks.add
         for item in saved:
             add(item)
 
-    def _dumps(self, obj):
+    def _dumps(self, obj: Any) -> bytes:
         return pickle.dumps(obj, protocol=self.protocol)
 
     @property
-    def _revoked_tasks(self):
+    def _revoked_tasks(self) -> LimitedSet:
         return self.state.revoked
 
     @cached_property
-    def db(self):
+    def db(self) -> MutableMapping:
         self._is_open = True
         return self.open()

+ 45 - 14
celery/worker/strategy.py

@@ -1,14 +1,16 @@
 # -*- coding: utf-8 -*-
 """Task execution strategy (optimization)."""
 import logging
-
+from typing import (
+    Awaitable, Callable, Dict, List, Mapping, NamedTuple, Sequence, Tuple,
+)
 from kombu.async.timer import to_timestamp
-
+from kombu.types import MessageT
 from celery.exceptions import InvalidTaskError
+from celery.types import AppT, WorkerConsumerT
 from celery.utils.log import get_logger
 from celery.utils.saferepr import saferepr
 from celery.utils.time import timezone
-
 from .request import Request, create_request_cls
 from .state import task_reserved
 
@@ -20,7 +22,16 @@ logger = get_logger(__name__)
 # We cache globals and attribute lookups, so disable this warning.
 
 
-def proto1_to_proto2(message, body):
+class converted_message_t(NamedTuple):
+    """Describes a converted message."""
+
+    body: Tuple[List, Dict, Mapping]
+    headers: Mapping
+    decoded: bool
+    utc: bool
+
+
+def proto1_to_proto2(message: MessageT, body: Mapping) -> converted_message_t:
     """Convert Task message protocol 1 arguments to protocol 2.
 
     Returns:
@@ -50,13 +61,28 @@ def proto1_to_proto2(message, body):
         'chord': body.get('chord'),
         'chain': None,
     }
-    return (args, kwargs, embed), body, True, body.get('utc', True)
+    return converted_message_t(
+        body=(args, kwargs, embed),
+        headers=body,
+        decoded=True,
+        utc=body.get('utc', True),
+    )
+
 
+StrategyT = Callable[
+    [MessageT, Mapping, Callable, Callable, Sequence[Callable]],
+    Awaitable,
+]
 
-def default(task, app, consumer,
-            info=logger.info, error=logger.error, task_reserved=task_reserved,
-            to_system_tz=timezone.to_system, bytes=bytes,
-            proto1_to_proto2=proto1_to_proto2):
+
+def default(task: str, app: AppT, consumer: WorkerConsumerT,
+            *,
+            info: Callable = logger.info,
+            error: Callable = logger.error,
+            task_reserved: Callable = task_reserved,
+            to_system_tz: Callable = timezone.to_system,
+            proto1_to_proto2: Callable = proto1_to_proto2,
+            bytes: Callable = bytes) -> StrategyT:
     """Default task execution strategy.
 
     Note:
@@ -83,9 +109,14 @@ def default(task, app, consumer,
     Req = create_request_cls(Request, task, consumer.pool, hostname, eventer)
 
     revoked_tasks = consumer.controller.state.revoked
-
-    def task_message_handler(message, body, ack, reject, callbacks,
-                             to_timestamp=to_timestamp):
+    convert_to_timestamp = to_timestamp
+
+    async def task_message_handler(
+            message: MessageT,
+            body: Mapping,
+            ack: Callable,
+            reject: Callable,
+            callbacks: Sequence[Callable]) -> Awaitable:
         if body is None:
             body, headers, decoded, utc = (
                 message.body, message.headers, False, True,
@@ -118,9 +149,9 @@ def default(task, app, consumer,
         if req.eta:
             try:
                 if req.utc:
-                    eta = to_timestamp(to_system_tz(req.eta))
+                    eta = convert_to_timestamp(to_system_tz(req.eta))
                 else:
-                    eta = to_timestamp(req.eta, timezone.local)
+                    eta = convert_to_timestamp(req.eta, timezone.local)
             except (OverflowError, ValueError) as exc:
                 error("Couldn't convert ETA %r to timestamp: %r. Task: %r",
                       req.eta, exc, req.info(safe=True), exc_info=True)

+ 116 - 85
celery/worker/worker.py

@@ -15,6 +15,8 @@ The worker consists of several components, all managed by bootsteps
 import os
 import sys
 
+from typing import Any, Callable, Mapping, Set, Sequence, Union
+
 from billiard import cpu_count
 from kombu.utils.compat import detect_environment
 
@@ -25,7 +27,10 @@ from celery import signals
 from celery.exceptions import (
     ImproperlyConfigured, WorkerTerminate, TaskRevokedError,
 )
-from celery.platforms import EX_FAILURE, create_pidlock
+from celery.types import (
+    AppT, BlueprintT, LoopT, PoolT, RequestT, StepT, WorkerConsumerT,
+)
+from celery.platforms import EX_FAILURE, Pidfile, create_pidlock
 from celery.utils.imports import reload_from_cwd
 from celery.utils.log import mlevel, worker_logger as logger
 from celery.utils.nodenames import default_nodename, worker_direct
@@ -57,25 +62,27 @@ Trying to deselect queue subset of {0!r}, but queue {1} isn't
 defined in the `task_queues` setting.
 """
 
+CSVListArgT = Union[Sequence[str], str]
+
 
 class WorkController:
     """Unmanaged worker instance."""
 
-    app = None
+    app: AppT = None
 
-    pidlock = None
-    blueprint = None
-    pool = None
-    semaphore = None
+    pidlock: Pidfile = None
+    blueprint: BlueprintT = None
+    pool: PoolT = None
+    semaphore: Any = None
 
     #: contains the exit code if a :exc:`SystemExit` event is handled.
-    exitcode = None
+    exitcode: int = None
 
     class Blueprint(bootsteps.Blueprint):
         """Worker bootstep blueprint."""
 
         name = 'Worker'
-        default_steps = {
+        default_steps: Set[Union[str, StepT]] = {
             'celery.worker.components:Hub',
             'celery.worker.components:Pool',
             'celery.worker.components:Beat',
@@ -85,7 +92,12 @@ class WorkController:
             'celery.worker.autoscale:WorkerComponent',
         }
 
-    def __init__(self, app=None, hostname=None, **kwargs):
+    def __init__(
+            self,
+            *,
+            app: AppT = None,
+            hostname: str = None,
+            **kwargs) -> None:
         self.app = app or self.app
         self.hostname = default_nodename(hostname)
         self.app.loader.init_worker()
@@ -95,9 +107,15 @@ class WorkController:
 
         self.setup_instance(**self.prepare_args(**kwargs))
 
-    def setup_instance(self, queues=None, ready_callback=None, pidfile=None,
-                       include=None, use_eventloop=None, exclude_queues=None,
-                       **kwargs):
+    def setup_instance(self,
+                       *,
+                       queues: CSVListArgT = None,
+                       ready_callback: Callable = None,
+                       pidfile: str = None,
+                       include: CSVListArgT = None,
+                       use_eventloop: bool = None,
+                       exclude_queues: CSVListArgT = None,
+                       **kwargs) -> None:
         self.pidfile = pidfile
         self.setup_queues(queues, exclude_queues)
         self.setup_includes(str_to_list(include))
@@ -135,33 +153,34 @@ class WorkController:
         )
         self.blueprint.apply(self, **kwargs)
 
-    def on_init_blueprint(self):
-        pass
+    def on_init_blueprint(self) -> None:
+        ...
 
-    def on_before_init(self, **kwargs):
-        pass
+    def on_before_init(self, **kwargs) -> None:
+        ...
 
-    def on_after_init(self, **kwargs):
-        pass
+    def on_after_init(self, **kwargs) -> None:
+        ...
 
-    def on_start(self):
+    def on_start(self) -> None:
         if self.pidfile:
             self.pidlock = create_pidlock(self.pidfile)
 
-    def on_consumer_ready(self, consumer):
-        pass
+    def on_consumer_ready(self, consumer: WorkerConsumerT) -> None:
+        ...
 
-    def on_close(self):
+    def on_close(self) -> None:
         self.app.loader.shutdown_worker()
 
-    def on_stopped(self):
+    def on_stopped(self) -> None:
         self.timer.stop()
         self.consumer.shutdown()
 
         if self.pidlock:
             self.pidlock.release()
 
-    def setup_queues(self, include, exclude=None):
+    def setup_queues(self, include: CSVListArgT,
+                     exclude: CSVListArgT = None) -> None:
         include = str_to_list(include)
         exclude = str_to_list(exclude)
         try:
@@ -177,7 +196,7 @@ class WorkController:
         if self.app.conf.worker_direct:
             self.app.amqp.queues.select_add(worker_direct(self.hostname))
 
-    def setup_includes(self, includes):
+    def setup_includes(self, includes: Sequence[str]) -> None:
         # Update celery_include to have all known task modules, so that we
         # ensure all task modules are imported in case an execv happens.
         prev = tuple(self.app.conf.include)
@@ -189,81 +208,88 @@ class WorkController:
                         for task in self.app.tasks.values()}
         self.app.conf.include = tuple(set(prev) | task_modules)
 
-    def prepare_args(self, **kwargs):
+    def prepare_args(self, **kwargs) -> Mapping:
         return kwargs
 
-    def _send_worker_shutdown(self):
+    def _send_worker_shutdown(self) -> None:
         signals.worker_shutdown.send(sender=self)
 
-    def start(self):
+    async def start(self) -> None:
         try:
-            self.blueprint.start(self)
+            await self.blueprint.start(self)
         except WorkerTerminate:
-            self.terminate()
+            await self.terminate()
         except Exception as exc:
             logger.critical('Unrecoverable error: %r', exc, exc_info=True)
-            self.stop(exitcode=EX_FAILURE)
+            await self.stop(exitcode=EX_FAILURE)
         except SystemExit as exc:
-            self.stop(exitcode=exc.code)
+            await self.stop(exitcode=exc.code)
         except KeyboardInterrupt:
-            self.stop(exitcode=EX_FAILURE)
+            await self.stop(exitcode=EX_FAILURE)
 
-    def register_with_event_loop(self, hub):
-        self.blueprint.send_all(
+    async def register_with_event_loop(self, hub: LoopT) -> None:
+        await self.blueprint.send_all(
             self, 'register_with_event_loop', args=(hub,),
             description='hub.register',
         )
 
-    def _process_task_sem(self, req):
-        return self._quick_acquire(self._process_task, req)
+    async def _process_task_sem(self, req: RequestT) -> None:
+        await self._quick_acquire(self._process_task, req)
 
-    def _process_task(self, req):
+    async def _process_task(self, req: RequestT) -> None:
         """Process task by sending it to the pool of workers."""
         try:
-            req.execute_using_pool(self.pool)
+            await req.execute_using_pool(self.pool)
         except TaskRevokedError:
             try:
                 self._quick_release()   # Issue 877
             except AttributeError:
                 pass
 
-    def signal_consumer_close(self):
+    def signal_consumer_close(self) -> None:
         try:
             self.consumer.close()
         except AttributeError:
             pass
 
-    def should_use_eventloop(self):
+    def should_use_eventloop(self) -> bool:
         return (detect_environment() == 'default' and
                 self._conninfo.transport.implements.async and
                 not self.app.IS_WINDOWS)
 
-    def stop(self, in_sighandler=False, exitcode=None):
+    async def stop(self,
+                   *,
+                   in_sighandler: bool = False,
+                   exitcode: int = None) -> None:
         """Graceful shutdown of the worker server."""
         if exitcode is not None:
             self.exitcode = exitcode
         if self.blueprint.state == RUN:
             self.signal_consumer_close()
             if not in_sighandler or self.pool.signal_safe:
-                self._shutdown(warm=True)
+                await self._shutdown(warm=True)
         self._send_worker_shutdown()
 
-    def terminate(self, in_sighandler=False):
+    async def terminate(self, *, in_sighandler: bool = False) -> None:
         """Not so graceful shutdown of the worker server."""
         if self.blueprint.state != TERMINATE:
             self.signal_consumer_close()
             if not in_sighandler or self.pool.signal_safe:
-                self._shutdown(warm=False)
+                await self._shutdown(warm=False)
 
-    def _shutdown(self, warm=True):
+    async def _shutdown(self, *, warm: bool = True) -> None:
         # if blueprint does not exist it means that we had an
         # error before the bootsteps could be initialized.
         if self.blueprint is not None:
             with default_socket_timeout(SHUTDOWN_SOCKET_TIMEOUT):  # Issue 975
-                self.blueprint.stop(self, terminate=not warm)
+                await self.blueprint.stop(self, terminate=not warm)
                 self.blueprint.join()
 
-    def reload(self, modules=None, reload=False, reloader=None):
+    def reload(self,
+               modules: Sequence[str] = None,
+               *,
+               reload: bool = False,
+               reloader: Callable = None) -> None:
         list(self._reload_modules(
             modules, force_reload=reload, reloader=reloader))
 
@@ -275,14 +301,17 @@ class WorkController:
         except NotImplementedError:
             pass
 
-    def _reload_modules(self, modules=None, **kwargs):
+    def _reload_modules(self, modules: Sequence[str] = None, **kwargs) -> None:
         return (
             self._maybe_reload_module(m, **kwargs)
             for m in set(self.app.loader.task_modules
                          if modules is None else (modules or ()))
         )
 
-    def _maybe_reload_module(self, module, force_reload=False, reloader=None):
+    def _maybe_reload_module(self, module: str,
+                             *,
+                             force_reload: bool = False,
+                             reloader: Callable = None) -> Any:
         if module not in sys.modules:
             logger.debug('importing module %s', module)
             return self.app.loader.import_from_cwd(module)
@@ -290,12 +319,12 @@ class WorkController:
             logger.debug('reloading module %s', module)
             return reload_from_cwd(sys.modules[module], reloader)
 
-    def info(self):
+    def info(self) -> Mapping[str, Any]:
         return {'total': self.state.total_count,
                 'pid': os.getpid(),
                 'clock': str(self.app.clock)}
 
-    def rusage(self):
+    def rusage(self) -> Mapping[str, Any]:
         if resource is None:
             raise NotImplementedError('rusage not supported by this platform')
         s = resource.getrusage(resource.RUSAGE_SELF)
@@ -318,7 +347,7 @@ class WorkController:
             'nivcsw': s.ru_nivcsw,
         }
 
-    def stats(self):
+    def stats(self) -> Mapping[str, Any]:
         info = self.info()
         info.update(self.blueprint.info(self))
         info.update(self.consumer.blueprint.info(self.consumer))
@@ -328,49 +357,54 @@ class WorkController:
             info['rusage'] = 'N/A'
         return info
 
-    def __repr__(self):
+    def __repr__(self) -> str:
         """``repr(worker)``."""
         return '<Worker: {self.hostname} ({state})>'.format(
             self=self,
             state=self.blueprint.human_state() if self.blueprint else 'INIT',
         )
 
-    def __str__(self):
+    def __str__(self) -> str:
         """``str(worker) == worker.hostname``."""
         return self.hostname
 
     @property
-    def state(self):
+    def state(self) -> int:
         return state
 
-    def setup_defaults(self, concurrency=None, loglevel='WARN', logfile=None,
-                       task_events=None, pool=None, consumer_cls=None,
-                       timer_cls=None, timer_precision=None,
-                       autoscaler_cls=None,
-                       pool_putlocks=None,
-                       pool_restarts=None,
-                       optimization=None, O=None,  # O maps to -O=fair
-                       statedb=None,
-                       time_limit=None,
-                       soft_time_limit=None,
-                       scheduler=None,
-                       pool_cls=None,              # XXX use pool
-                       state_db=None,              # XXX use statedb
-                       task_time_limit=None,       # XXX use time_limit
-                       task_soft_time_limit=None,  # XXX use soft_time_limit
-                       scheduler_cls=None,         # XXX use scheduler
-                       schedule_filename=None,
-                       max_tasks_per_child=None,
-                       prefetch_multiplier=None, disable_rate_limits=None,
-                       worker_lost_wait=None,
-                       max_memory_per_child=None, **_kw):
+    def setup_defaults(self,
+                       *,
+                       concurrency: int = None,
+                       loglevel: Union[str, int] = 'WARN',
+                       logfile: str = None,
+                       task_events: bool = None,
+                       pool: Union[str, type] = None,
+                       consumer_cls: Union[str, type] = None,
+                       timer_cls: Union[str, type] = None,
+                       timer_precision: float = None,
+                       autoscaler_cls: Union[str, type] = None,
+                       pool_putlocks: bool = None,
+                       pool_restarts: bool = None,
+                       optimization: str = None,
+                       O: str = None,  # O maps to -O=fair
+                       statedb: str = None,
+                       time_limit: float = None,
+                       soft_time_limit: float = None,
+                       scheduler: Union[str, type] = None,
+                       schedule_filename: str = None,
+                       max_tasks_per_child: int = None,
+                       prefetch_multiplier: float = None,
+                       disable_rate_limits: bool = None,
+                       worker_lost_wait: float = None,
+                       max_memory_per_child: float = None,
+                       **_kw) -> None:
         either = self.app.either
         self.loglevel = loglevel
         self.logfile = logfile
 
         self.concurrency = either('worker_concurrency', concurrency)
         self.task_events = either('worker_send_task_events', task_events)
-        self.pool_cls = either('worker_pool', pool, pool_cls)
+        self.pool_cls = either('worker_pool', pool)
         self.consumer_cls = either('worker_consumer', consumer_cls)
         self.timer_cls = either('worker_timer', timer_cls)
         self.timer_precision = either(
@@ -380,16 +414,13 @@ class WorkController:
         self.autoscaler_cls = either('worker_autoscaler', autoscaler_cls)
         self.pool_putlocks = either('worker_pool_putlocks', pool_putlocks)
         self.pool_restarts = either('worker_pool_restarts', pool_restarts)
-        self.statedb = either('worker_state_db', statedb, state_db)
+        self.statedb = either('worker_state_db', statedb)
         self.schedule_filename = either(
             'beat_schedule_filename', schedule_filename,
         )
-        self.scheduler = either('beat_scheduler', scheduler, scheduler_cls)
-        self.time_limit = either(
-            'task_time_limit', time_limit, task_time_limit)
-        self.soft_time_limit = either(
-            'task_soft_time_limit', soft_time_limit, task_soft_time_limit,
-        )
+        self.scheduler = either('beat_scheduler', scheduler)
+        self.time_limit = either('task_time_limit', time_limit)
+        self.soft_time_limit = either('task_soft_time_limit', soft_time_limit)
         self.max_tasks_per_child = either(
             'worker_max_tasks_per_child', max_tasks_per_child,
         )