Ask Solem 8 years ago
parent
commit
fe248bc3e2

+ 1 - 1
celery/tests/app/test_log.py

@@ -154,7 +154,7 @@ class test_ColorFormatter(AppCase):
 class test_default_logger(AppCase):
 
     def setup(self):
-        self.setup_logger = self.app.log.setup_logging_subsystem
+        self.setup_logger = self.app.log.setup_logger
         self.get_logger = lambda n=None: get_logger(n) if n else logging.root
         signals.setup_logging.receivers[:] = []
         self.app.log.already_setup = False

+ 27 - 17
celery/utils/debug.py

@@ -6,22 +6,31 @@ import traceback
 
 from contextlib import contextmanager
 from functools import partial
+from io import StringIO
+from numbers import Number
 from pprint import pprint
+from typing import (
+    Any, AnyStr, Generator, IO, Iterator, Iterable, MutableSequence,
+    Optional, Sequence, SupportsInt, Tuple, Union,
+)
 
 from celery.five import WhateverIO
 from celery.platforms import signals
 
+from .typing import Timeout
+
 try:
     from psutil import Process
 except ImportError:
-    Process = None  # noqa
+    class Process:  # noqa
+        pass
 
 __all__ = [
     'blockdetection', 'sample_mem', 'memdump', 'sample',
     'humanbytes', 'mem_rss', 'ps', 'cry',
 ]
 
-UNITS = (
+UNITS = (               # type: Sequence[Tuple[float, str]]
     (2 ** 40.0, 'TB'),
     (2 ** 30.0, 'GB'),
     (2 ** 20.0, 'MB'),
@@ -29,11 +38,11 @@ UNITS = (
     (0.0, 'b'),
 )
 
-_process = None
-_mem_sample = []
+_process = None         # type: Optional[Process]
+_mem_sample = []        # type: MutableSequence[str]
 
 
-def _on_blocking(signum, frame):
+def _on_blocking(signum: int, frame: Any) -> None:
     import inspect
     raise RuntimeError(
         'Blocking detection timed-out at: {0}'.format(
@@ -43,7 +52,7 @@ def _on_blocking(signum, frame):
 
 
 @contextmanager
-def blockdetection(timeout):
+def blockdetection(timeout: Timeout) -> Generator:
     """A timeout context using ``SIGALRM`` that can be used to detect blocking
     functions."""
     if not timeout:
@@ -62,7 +71,7 @@ def blockdetection(timeout):
             signals.reset_alarm()
 
 
-def sample_mem():
+def sample_mem() -> str:
     """Sample RSS memory usage.
 
     Statistics can then be output by calling :func:`memdump`.
@@ -72,7 +81,7 @@ def sample_mem():
     return current_rss
 
 
-def _memdump(samples=10):  # pragma: no cover
+def _memdump(samples: int=10) -> Tuple[Iterable[Any], str]:  # pragma: no cover
     S = _mem_sample
     prev = list(S) if len(S) <= samples else sample(S, samples)
     _mem_sample[:] = []
@@ -82,7 +91,7 @@ def _memdump(samples=10):  # pragma: no cover
     return prev, after_collect
 
 
-def memdump(samples=10, file=None):  # pragma: no cover
+def memdump(samples: int=10, file: IO=None) -> None:  # pragma: no cover
     """Dump memory statistics.
 
     Will print a sample of all RSS memory samples added by
@@ -101,7 +110,7 @@ def memdump(samples=10, file=None):  # pragma: no cover
     say('- rss (end): {0}.'.format(after_collect))
 
 
-def sample(x, n, k=0):
+def sample(x: Sequence, n: int, k: int=0) -> Iterator[Any]:
     """Given a list `x` a sample of length ``n`` of that list is returned.
 
     E.g. if `n` is 10, and `x` has 100 items, a list of every tenth.
@@ -118,7 +127,7 @@ def sample(x, n, k=0):
         k += j
 
 
-def hfloat(f, p=5):
+def hfloat(f: Union[SupportsInt, AnyStr], p: int=5) -> str:
     """Convert float to value suitable for humans.
 
     Arguments:
@@ -126,10 +135,10 @@ def hfloat(f, p=5):
         p (int): Floating point precision (default is 5).
     """
     i = int(f)
-    return i if i == f else '{0:.{p}}'.format(f, p=p)
+    return str(i) if i == f else '{0:.{p}}'.format(f, p=p)
 
 
-def humanbytes(s):
+def humanbytes(s: Union[float, int]) -> str:
     """Convert bytes to human-readable form (e.g. KB, MB)."""
     return next(
         '{0}{1}'.format(hfloat(s / div if div else s), unit)
@@ -137,14 +146,14 @@ def humanbytes(s):
     )
 
 
-def mem_rss():
+def mem_rss() -> str:
     """Return RSS memory usage as a humanized string."""
     p = ps()
     if p is not None:
         return humanbytes(_process_memory_info(p).rss)
 
 
-def ps():  # pragma: no cover
+def ps() -> Process:  # pragma: no cover
     """Return the global :class:`psutil.Process` instance,
     or :const:`None` if :pypi:`psutil` is not installed."""
     global _process
@@ -153,14 +162,15 @@ def ps():  # pragma: no cover
     return _process
 
 
-def _process_memory_info(process):
+def _process_memory_info(process: Process) -> Any:
     try:
         return process.memory_info()
     except AttributeError:
         return process.get_memory_info()
 
 
-def cry(out=None, sepchr='=', seplen=49):  # pragma: no cover
+def cry(out: Optional[StringIO]=None,
+        sepchr: str='=', seplen: int=49) -> None:  # pragma: no cover
     """Return stack-trace of all active threads,
     taken from https://gist.github.com/737056."""
     import threading
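
A minimal standalone sketch of the hfloat()/humanbytes() behavior after the change above (not the celery module itself; the KB tier hidden inside the hunk is assumed):

    # mirrors the annotated helpers: hfloat() now returns str on the integral branch
    from typing import Union

    UNITS = (
        (2 ** 40.0, 'TB'),
        (2 ** 30.0, 'GB'),
        (2 ** 20.0, 'MB'),
        (2 ** 10.0, 'KB'),   # assumed tier, elided by the hunk context
        (0.0, 'b'),
    )

    def hfloat(f: Union[int, float], p: int = 5) -> str:
        i = int(f)
        return str(i) if i == f else '{0:.{p}}'.format(f, p=p)

    def humanbytes(s: Union[float, int]) -> str:
        return next(
            '{0}{1}'.format(hfloat(s / div if div else s), unit)
            for div, unit in UNITS if s >= div
        )

    assert humanbytes(2 ** 20) == '1MB'   # integral result is now a str, not an int
    assert humanbytes(1536) == '1.5KB'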

+ 27 - 13
celery/utils/deprecated.py

@@ -2,6 +2,8 @@
 """Deprecation utilities."""
 import warnings
 
+from typing import Any, Callable, Mapping, Optional
+
 from vine.utils import wraps
 
 from celery.exceptions import CPendingDeprecationWarning, CDeprecationWarning
@@ -21,8 +23,11 @@ DEPRECATION_FMT = """
 """
 
 
-def warn(description=None, deprecation=None,
-         removal=None, alternative=None, stacklevel=2):
+def warn(description: Optional[str]=None,
+         deprecation: Optional[str]=None,
+         removal: Optional[str]=None,
+         alternative: Optional[str]=None,
+         stacklevel: int=2) -> None:
     ctx = {'description': description,
            'deprecation': deprecation, 'removal': removal,
            'alternative': alternative}
@@ -33,8 +38,10 @@ def warn(description=None, deprecation=None,
     warnings.warn(w, stacklevel=stacklevel)
 
 
-def Callable(deprecation=None, removal=None,
-             alternative=None, description=None):
+def Callable(deprecation: Optional[str]=None,
+             removal: Optional[str]=None,
+             alternative: Optional[str]=None,
+             description: Optional[str]=None) -> Callable:
     """Decorator for deprecated functions.
 
     A deprecation warning will be emitted when the function is called.
@@ -62,9 +69,11 @@ def Callable(deprecation=None, removal=None,
     return _inner
 
 
-def Property(deprecation=None, removal=None,
-             alternative=None, description=None):
-    def _inner(fun):
+def Property(deprecation: Optional[str]=None,
+             removal: Optional[str]=None,
+             alternative: Optional[str]=None,
+             description: Optional[str]=None) -> Callable:
+    def _inner(fun: Callable) -> Any:
         return _deprecated_property(
             fun, deprecation=deprecation, removal=removal,
             alternative=alternative, description=description or fun.__name__)
@@ -73,7 +82,12 @@ def Property(deprecation=None, removal=None,
 
 class _deprecated_property:
 
-    def __init__(self, fget=None, fset=None, fdel=None, doc=None, **depreinfo):
+    def __init__(self,
+                 fget: Optional[Callable]=None,
+                 fset: Optional[Callable]=None,
+                 fdel: Optional[Callable]=None,
+                 doc: Optional[str]=None,
+                 **depreinfo) -> None:
         self.__get = fget
         self.__set = fset
         self.__del = fdel
@@ -83,13 +97,13 @@ class _deprecated_property:
         self.depreinfo = depreinfo
         self.depreinfo.setdefault('stacklevel', 3)
 
-    def __get__(self, obj, type=None):
+    def __get__(self, obj: Any, type: Optional[Any]=None) -> Any:
         if obj is None:
             return self
         warn(**self.depreinfo)
         return self.__get(obj)
 
-    def __set__(self, obj, value):
+    def __set__(self, obj: Any, value: Any) -> Any:
         if obj is None:
             return self
         if self.__set is None:
@@ -97,7 +111,7 @@ class _deprecated_property:
         warn(**self.depreinfo)
         self.__set(obj, value)
 
-    def __delete__(self, obj):
+    def __delete__(self, obj: Any) -> Any:
         if obj is None:
             return self
         if self.__del is None:
@@ -105,8 +119,8 @@ class _deprecated_property:
         warn(**self.depreinfo)
         self.__del(obj)
 
-    def setter(self, fset):
+    def setter(self, fset: Callable) -> Any:
         return self.__class__(self.__get, fset, self.__del, **self.depreinfo)
 
-    def deleter(self, fdel):
+    def deleter(self, fdel: Callable) -> Any:
         return self.__class__(self.__get, self.__set, fdel, **self.depreinfo)
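
A sketch of how the decorator factory above is meant to be used (the decorated function and version strings are hypothetical; the import path follows the file location):

    from celery.utils import deprecated

    @deprecated.Callable(deprecation='4.0', removal='5.0',
                         alternative='new_api()')
    def old_api():
        return 42

    old_api()   # emits a deprecation warning via warn() before calling through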

+ 44 - 35
celery/utils/functional.py

@@ -4,8 +4,13 @@ import sys
 
 from collections import UserList
 from functools import partial
-from inspect import getfullargspec, isfunction
+from inspect import FullArgSpec, getfullargspec, isfunction
 from itertools import chain, islice
+from typing import (
+    Any, Callable, Iterable, Iterator, Optional,
+    Mapping, MutableSequence, Sequence, Tuple, Union,
+)
+from typing import MutableSet  # noqa
 
 from kombu.utils.functional import (
     LRUCache, dictfilter, lazy, maybe_evaluate, memoize,
@@ -13,6 +18,8 @@ from kombu.utils.functional import (
 )
 from vine import promise
 
+from .typing import ExcInfo
+
 __all__ = [
     'LRUCache', 'is_list', 'maybe_list', 'memoize', 'mlazy', 'noop',
     'first', 'firstmethod', 'chunks', 'padlist', 'mattrgetter', 'uniq',
@@ -27,10 +34,10 @@ def {fun_name}({fun_args}):
 
 class DummyContext:
 
-    def __enter__(self):
+    def __enter__(self) -> Any:
         return self
 
-    def __exit__(self, *exc_info):
+    def __exit__(self, *exc_info: ExcInfo) -> Any:
         pass
 
 
@@ -40,19 +47,18 @@ class mlazy(lazy):
     The function is only evaluated once, every subsequent access
     will return the same value.
     """
-
     #: Set to :const:`True` after the object has been evaluated.
-    evaluated = False
-    _value = None
+    evaluated = False  # type: bool
+    _value = None      # type: Any
 
-    def evaluate(self):
+    def evaluate(self) -> Any:
         if not self.evaluated:
             self._value = super().evaluate()
             self.evaluated = True
         return self._value
 
 
-def noop(*args, **kwargs):
+def noop(*args: Tuple, **kwargs: Mapping) -> Any:
     """No operation.
 
     Takes any arguments/keyword arguments and does nothing.
@@ -60,20 +66,20 @@ def noop(*args, **kwargs):
     pass
 
 
-def pass1(arg, *args, **kwargs):
+def pass1(arg: Any, *args: Tuple, **kwargs: Mapping) -> Any:
     """Take any number of arguments/keyword arguments and return
     the first positional argument."""
     return arg
 
 
-def evaluate_promises(it):
+def evaluate_promises(it: Iterable) -> Iterator[Any]:
     for value in it:
         if isinstance(value, promise):
             value = value()
         yield value
 
 
-def first(predicate, it):
+def first(predicate: Callable[[Any], Any], it: Iterable) -> Any:
     """Return the first element in ``iterable`` that ``predicate`` gives a
     :const:`True` value for.
 
@@ -87,7 +93,7 @@ def first(predicate, it):
     )
 
 
-def firstmethod(method, on_call=None):
+def firstmethod(method: str, on_call: Optional[Callable]=None) -> Any:
     """Return a function that with a list of instances,
     finds the first instance that gives a value for the given method.
 
@@ -109,7 +115,7 @@ def firstmethod(method, on_call=None):
     return _matcher
 
 
-def chunks(it, n):
+def chunks(it: Iterable, n: int) -> Iterable:
     """Split an iterator into chunks with `n` elements each.
 
     Example:
@@ -127,7 +133,8 @@ def chunks(it, n):
         yield [first] + list(islice(it, n - 1))
 
 
-def padlist(container, size, default=None):
+def padlist(container: Sequence, size: int,
+            default: Optional[Any]=None) -> Sequence:
     """Pad list with default elements.
 
     Example:
@@ -143,19 +150,19 @@ def padlist(container, size, default=None):
     return list(container)[:size] + [default] * (size - len(container))
 
 
-def mattrgetter(*attrs):
+def mattrgetter(*attrs: str) -> Callable[[Any], Mapping[str, Any]]:
     """Like :func:`operator.itemgetter` but return :const:`None` on missing
     attributes instead of raising :exc:`AttributeError`."""
     return lambda obj: {attr: getattr(obj, attr, None) for attr in attrs}
 
 
-def uniq(it):
+def uniq(it: Iterable) -> Iterable[Any]:
     """Return all unique elements in ``it``, preserving order."""
-    seen = set()
+    seen = set()  # type: MutableSet
     return (seen.add(obj) or obj for obj in it if obj not in seen)
 
 
-def regen(it):
+def regen(it: Iterable) -> Union[list, tuple, '_regen']:
     """``Regen`` takes any iterable, and if the object is an
     generator it will cache the evaluated list on first access,
     so that the generator can be "consumed" multiple times."""
@@ -167,21 +174,21 @@ def regen(it):
 class _regen(UserList, list):
     # must be subclass of list so that json can encode.
 
-    def __init__(self, it):
-        self.__it = it
-        self.__index = 0
-        self.__consumed = []
+    def __init__(self, it: Iterable) -> None:
+        self.__it = it        # type: Iterator
+        self.__index = 0      # type: int
+        self.__consumed = []  # type: MutableSequence[Any]
 
-    def __reduce__(self):
+    def __reduce__(self) -> Any:
         return list, (self.data,)
 
-    def __length_hint__(self):
+    def __length_hint__(self) -> int:
         return self.__it.__length_hint__()
 
-    def __iter__(self):
+    def __iter__(self) -> Iterator:
         return chain(self.__consumed, self.__it)
 
-    def __getitem__(self, index):
+    def __getitem__(self, index: Any) -> Any:
         if index < 0:
             return self.data[index]
         try:
@@ -196,7 +203,7 @@ class _regen(UserList, list):
                 return self.__consumed[index]
 
     @property
-    def data(self):
+    def data(self) -> MutableSequence:
         try:
             self.__consumed.extend(list(self.__it))
         except StopIteration:
@@ -204,7 +211,7 @@ class _regen(UserList, list):
         return self.__consumed
 
 
-def _argsfromspec(spec, replace_defaults=True):
+def _argsfromspec(spec: FullArgSpec, replace_defaults: bool=True) -> str:
     if spec.defaults:
         split = len(spec.defaults)
         defaults = (list(range(len(spec.defaults))) if replace_defaults
@@ -221,7 +228,8 @@ def _argsfromspec(spec, replace_defaults=True):
     ]))
 
 
-def head_from_fun(fun, bound=False, debug=False):
+def head_from_fun(fun: Callable,
+                  bound: bool=False, debug: bool=False) -> partial:
     # we could use inspect.Signature here, but that implementation
     # is very slow since it implements the argument checking
     # in pure-Python.  Instead we use exec to create a new function
@@ -231,7 +239,7 @@ def head_from_fun(fun, bound=False, debug=False):
         name, fun = fun.__class__.__name__, fun.__call__
     else:
         name = fun.__name__
-    definition = FUNHEAD_TEMPLATE.format(
+    definition = FUNHEAD_TEMPLATE.format(   # type: str
         fun_name=name,
         fun_args=_argsfromspec(getfullargspec(fun)),
         fun_value=1,
@@ -240,21 +248,22 @@ def head_from_fun(fun, bound=False, debug=False):
         print(definition, file=sys.stderr)
     namespace = {'__name__': fun.__module__}
     exec(definition, namespace)
-    result = namespace[name]
+    result = namespace[name]  # type: Any
     result._source = definition
     if bound:
         return partial(result, object())
     return result
 
 
-def arity_greater(fun, n):
+def arity_greater(fun: Callable, n: int) -> bool:
     argspec = getfullargspec(fun)
-    return argspec.varargs or len(argspec.args) > n
+    return bool(argspec.varargs or len(argspec.args) > n)
 
 
-def fun_takes_argument(name, fun, position=None):
+def fun_takes_argument(name: str, fun: Callable,
+                       position: Optional[int]=None) -> bool:
     spec = getfullargspec(fun)
-    return (
+    return bool(
         spec.varkw or spec.varargs or
         (len(spec.args) >= position if position else name in spec.args)
     )
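
Illustrative calls against the helpers annotated above (example values only):

    from celery.utils.functional import chunks, first, padlist, uniq

    assert padlist([1, 2], 4) == [1, 2, None, None]      # pad to length 4 with default
    assert list(uniq([1, 2, 2, 3, 1])) == [1, 2, 3]      # order-preserving dedup
    assert first(lambda x: x > 2, [1, 2, 3, 4]) == 3     # first element matching the predicate
    assert list(chunks(iter(range(7)), 3)) == [[0, 1, 2], [3, 4, 5], [6]]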

+ 58 - 41
celery/utils/graph.py

@@ -2,6 +2,10 @@
 """Dependency graph implementation."""
 from collections import Counter
 from textwrap import dedent
+from typing import (
+    Any, Dict, MutableSet, MutableSequence,
+    Optional, IO, Iterable, Iterator, Sequence, Tuple,
+)
 
 from kombu.utils.encoding import safe_str, bytes_to_str
 
@@ -25,7 +29,7 @@ class CycleError(Exception):
     """A cycle was detected in an acyclic graph."""
 
 
-class DependencyGraph:
+class DependencyGraph(Iterable):
     """A directed acyclic graph of objects and their dependencies.
 
     Supports a robust topological sort
@@ -38,26 +42,27 @@ class DependencyGraph:
         Does not support cycle detection.
     """
 
-    def __init__(self, it=None, formatter=None):
+    def __init__(self, it: Optional[Iterable]=None,
+                 formatter: Optional['GraphFormatter']=None) -> None:
         self.formatter = formatter or GraphFormatter()
-        self.adjacent = {}
+        self.adjacent = {}  # type: Dict[Any, Any]
         if it is not None:
             self.update(it)
 
-    def add_arc(self, obj):
+    def add_arc(self, obj: Any) -> None:
         """Add an object to the graph."""
         self.adjacent.setdefault(obj, [])
 
-    def add_edge(self, A, B):
+    def add_edge(self, A: Any, B: Any) -> None:
         """Add an edge from object ``A`` to object ``B``
         (``A`` depends on ``B``)."""
         self[A].append(B)
 
-    def connect(self, graph):
+    def connect(self, graph: 'DependencyGraph') -> None:
         """Add nodes from another graph."""
         self.adjacent.update(graph.adjacent)
 
-    def topsort(self):
+    def topsort(self) -> Sequence[Any]:
         """Sort the graph topologically.
 
         Returns:
@@ -79,7 +84,7 @@ class DependencyGraph:
                     graph.add_edge(node_c, successor_c)
         return [t[0] for t in graph._khan62()]
 
-    def valency_of(self, obj):
+    def valency_of(self, obj: Any) -> int:
         """Return the valency (degree) of a vertex in the graph."""
         try:
             l = [len(self[obj])]
@@ -89,7 +94,7 @@ class DependencyGraph:
             l.append(self.valency_of(node))
         return sum(l)
 
-    def update(self, it):
+    def update(self, it: Iterable) -> None:
         """Update the graph with data from a list
         of ``(obj, dependencies)`` tuples."""
         tups = list(it)
@@ -99,17 +104,17 @@ class DependencyGraph:
             for dep in deps:
                 self.add_edge(obj, dep)
 
-    def edges(self):
+    def edges(self) -> Iterator[Any]:
         """Return generator that yields for all edges in the graph."""
         return (obj for obj, adj in self.items() if adj)
 
-    def _khan62(self):
+    def _khan62(self) -> Sequence[Any]:
         """Khans simple topological sort algorithm from '62
 
         See https://en.wikipedia.org/wiki/Topological_sorting
         """
-        count = Counter()
-        result = []
+        count = Counter()  # type: Counter
+        result = []        # type: MutableSequence[Any]
 
         for node in self:
             for successor in self[node]:
@@ -127,13 +132,15 @@ class DependencyGraph:
         result.reverse()
         return result
 
-    def _tarjan72(self):
+    def _tarjan72(self) -> Sequence[Any]:
         """Tarjan's algorithm to find strongly connected components.
 
         See Also:
             http://bit.ly/vIMv3h.
         """
-        result, stack, low = [], [], {}
+        result = []  # type: MutableSequence[Any]
+        stack = []   # type: MutableSequence[Any]
+        low = {}     # type: Dict[Any, Any]
 
         def visit(node):
             if node in low:
@@ -159,7 +166,8 @@ class DependencyGraph:
 
         return result
 
-    def to_dot(self, fh, formatter=None):
+    def to_dot(self, fh: IO,
+               formatter: Optional['GraphFormatter']=None) -> None:
         """Convert the graph to DOT format.
 
         Arguments:
@@ -167,7 +175,7 @@ class DependencyGraph:
             formatter (celery.utils.graph.GraphFormatter): Custom graph
                 formatter to use.
         """
-        seen = set()
+        seen = set()  # type: MutableSet
         draw = formatter or self.formatter
 
         def P(s):
@@ -187,28 +195,28 @@ class DependencyGraph:
                 P(draw.edge(obj, req))
         P(draw.tail())
 
-    def format(self, obj):
-        return self.formatter(obj) if self.formatter else obj
+    def format(self, obj: Any) -> Any:
+        return self.formatter.node(obj) if self.formatter else obj
 
-    def __iter__(self):
+    def __iter__(self) -> Iterator[Any]:
         return iter(self.adjacent)
 
-    def __getitem__(self, node):
+    def __getitem__(self, node: Any) -> Any:
         return self.adjacent[node]
 
-    def __len__(self):
+    def __len__(self) -> int:
         return len(self.adjacent)
 
-    def __contains__(self, obj):
+    def __contains__(self, obj: Any) -> bool:
         return obj in self.adjacent
 
-    def items(self):
+    def items(self) -> Any:
         return self.adjacent.items()
 
-    def __repr__(self):
+    def __repr__(self) -> str:
         return '\n'.join(self.repr_node(N) for N in self)
 
-    def repr_node(self, obj, level=1, fmt='{0}({1})'):
+    def repr_node(self, obj: Any, level: int=1, fmt: str='{0}({1})') -> str:
         output = [fmt.format(obj, self.valency_of(obj))]
         if obj in self:
             for other in self[obj]:
@@ -242,8 +250,13 @@ class GraphFormatter:
     term_scheme = {'fillcolor': 'palegreen1', 'color': 'palegreen2'}
     graph_scheme = {'bgcolor': 'mintcream'}
 
-    def __init__(self, root=None, type=None, id=None,
-                 indent=0, inw=' ' * 4, **scheme):
+    def __init__(self,
+                 root: Any=None,
+                 type: Optional[str]=None,
+                 id: Optional[str]=None,
+                 indent: int=0,
+                 inw: str=' ' * 4,
+                 **scheme) -> None:
         self.id = id or 'dependencies'
         self.root = root
         self.type = type or 'digraph'
@@ -253,52 +266,56 @@ class GraphFormatter:
         self.scheme = dict(self.scheme, **scheme)
         self.graph_scheme = dict(self.graph_scheme, root=self.label(self.root))
 
-    def attr(self, name, value):
+    def attr(self, name: str, value: Any) -> str:
         value = '"{0}"'.format(value)
         return self.FMT(self._attr, name=name, value=value)
 
-    def attrs(self, d, scheme=None):
+    def attrs(self, d: Dict, scheme: Optional[Dict]=None) -> str:
         d = dict(self.scheme, **dict(scheme, **d or {}) if scheme else d)
         return self._attrsep.join(
             safe_str(self.attr(k, v)) for k, v in d.items()
         )
 
-    def head(self, **attrs):
+    def head(self, **attrs: Dict[str, str]) -> str:
         return self.FMT(
             self._head, id=self.id, type=self.type,
             attrs=self.attrs(attrs, self.graph_scheme),
         )
 
-    def tail(self):
+    def tail(self) -> str:
         return self.FMT(self._tail)
 
-    def label(self, obj):
+    def label(self, obj: Any) -> str:
         return obj
 
-    def node(self, obj, **attrs):
+    def node(self, obj: Any, **attrs: Dict[str, str]) -> str:
         return self.draw_node(obj, self.node_scheme, attrs)
 
-    def terminal_node(self, obj, **attrs):
+    def terminal_node(self, obj: Any, **attrs: Dict[str, str]) -> str:
         return self.draw_node(obj, self.term_scheme, attrs)
 
-    def edge(self, a, b, **attrs):
+    def edge(self, a: Any, b: Any, **attrs: Dict[str, str]) -> str:
         return self.draw_edge(a, b, **attrs)
 
-    def _enc(self, s):
-        return s.encode('utf-8', 'ignore')
+    def _enc(self, s: str) -> str:
+        return s.encode('utf-8', 'ignore').decode()
 
-    def FMT(self, fmt, *args, **kwargs):
+    def FMT(self, fmt: str, *args, **kwargs) -> str:
         return self._enc(fmt.format(
             *args, **dict(kwargs, IN=self.IN, INp=self.INp)
         ))
 
-    def draw_edge(self, a, b, scheme=None, attrs=None):
+    def draw_edge(self, a: Any, b: Any,
+                  scheme: Optional[Dict]=None,
+                  attrs: Optional[Dict]=None) -> str:
         return self.FMT(
             self._edge, self.label(a), self.label(b),
             dir=self.direction, attrs=self.attrs(attrs, self.edge_scheme),
         )
 
-    def draw_node(self, obj, scheme=None, attrs=None):
+    def draw_node(self, obj: Any,
+                  scheme: Optional[Dict]=None,
+                  attrs: Optional[Dict]=None) -> str:
         return self.FMT(
             self._node, self.label(obj), attrs=self.attrs(attrs, scheme),
         )
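
A small sketch of the DependencyGraph API annotated above; the printed ordering is an expectation, not asserted:

    from celery.utils.graph import DependencyGraph

    # (obj, dependencies) pairs: 'a' depends on 'b', 'b' depends on 'c'
    graph = DependencyGraph([('a', ['b']), ('b', ['c']), ('c', [])])
    print(graph.topsort())         # dependencies sort first, e.g. ['c', 'b', 'a']
    print(graph.valency_of('a'))   # transitive dependency count: 2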

+ 15 - 8
celery/utils/imports.py

@@ -7,6 +7,8 @@ import sys
 
 from contextlib import contextmanager
 from imp import reload
+from types import ModuleType
+from typing import Any, Callable, Iterator, Optional
 
 from kombu.utils import symbol_by_name
 
@@ -27,7 +29,7 @@ class NotAPackage(Exception):
     pass
 
 
-def qualname(obj):
+def qualname(obj: Any) -> str:
     if not hasattr(obj, '__name__') and hasattr(obj, '__class__'):
         obj = obj.__class__
     q = getattr(obj, '__qualname__', None)
@@ -36,7 +38,7 @@ def qualname(obj):
     return q
 
 
-def instantiate(name, *args, **kwargs):
+def instantiate(name: Any, *args, **kwargs) -> Any:
     """Instantiate class by name.
 
     See :func:`symbol_by_name`.
@@ -45,7 +47,7 @@ def instantiate(name, *args, **kwargs):
 
 
 @contextmanager
-def cwd_in_path():
+def cwd_in_path() -> Iterator:
     cwd = os.getcwd()
     if cwd in sys.path:
         yield
@@ -60,7 +62,9 @@ def cwd_in_path():
                 pass
 
 
-def find_module(module, path=None, imp=None):
+def find_module(module: str,
+                path: Optional[str]=None,
+                imp: Optional[Callable]=None) -> ModuleType:
     """Version of :func:`imp.find_module` supporting dots."""
     if imp is None:
         imp = importlib.import_module
@@ -79,7 +83,9 @@ def find_module(module, path=None, imp=None):
         return _imp.find_module(module)
 
 
-def import_from_cwd(module, imp=None, package=None):
+def import_from_cwd(module: str,
+                    imp: Optional[Callable]=None,
+                    package: Optional[str]=None) -> ModuleType:
     """Import module, but make sure it finds modules
     located in the current directory.
 
@@ -92,20 +98,21 @@ def import_from_cwd(module, imp=None, package=None):
         return imp(module, package=package)
 
 
-def reload_from_cwd(module, reloader=None):
+def reload_from_cwd(module: ModuleType,
+                    reloader: Optional[Callable]=None) -> Any:
     if reloader is None:
         reloader = reload
     with cwd_in_path():
         return reloader(module)
 
 
-def module_file(module):
+def module_file(module: ModuleType) -> str:
     """Return the correct original file name of a module."""
     name = module.__file__
     return name[:-1] if name.endswith('.pyc') else name
 
 
-def gen_task_name(app, name, module_name):
+def gen_task_name(app: Any, name: str, module_name: str) -> str:
     """Generate task name from name/module pair."""
     module_name = module_name or '__main__'
     try:
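
A quick sketch of qualname() and module_file() as annotated above (expected output is based on the full module, parts of which are elided by the hunks):

    from celery.utils import imports as _imports
    from celery.utils.imports import module_file, qualname

    print(qualname(ValueError()))   # falls back to the class, e.g. 'builtins.ValueError'
    print(module_file(_imports))    # original .py path, even if loaded from a .pyc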

+ 10 - 8
celery/utils/iso8601.py

@@ -34,7 +34,9 @@ SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
 """
 import re
 
-from datetime import datetime
+from typing import Optional
+
+from datetime import datetime, tzinfo
 from pytz import FixedOffset
 
 __all__ = ['parse_iso8601']
@@ -51,20 +53,20 @@ TIMEZONE_REGEX = re.compile(
 )
 
 
-def parse_iso8601(datestring):
+def parse_iso8601(datestring: str, tz: Optional[tzinfo]=None) -> datetime:
     """Parse and convert ISO-8601 string into a
     :class:`~datetime.datetime` object"""
     m = ISO8601_REGEX.match(datestring)
     if not m:
         raise ValueError('unable to parse date string %r' % datestring)
     groups = m.groupdict()
-    tz = groups['timezone']
-    if tz == 'Z':
+    tz_str = groups['timezone']
+    if tz_str == 'Z':
         tz = FixedOffset(0)
-    elif tz:
-        m = TIMEZONE_REGEX.match(tz)
-        prefix, hours, minutes = m.groups()
-        hours, minutes = int(hours), int(minutes)
+    elif tz_str:
+        m = TIMEZONE_REGEX.match(tz_str)
+        prefix, hours_str, minutes_str = m.groups()
+        hours, minutes = int(hours_str), int(minutes_str)
         if prefix == '-':
             hours = -hours
             minutes = -minutes
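
Example of the parser with the renamed timezone variables above (illustrative values):

    from celery.utils.iso8601 import parse_iso8601

    dt = parse_iso8601('2016-03-15T12:30:45+02:00')
    print(dt)                                      # 2016-03-15 12:30:45+02:00
    print(parse_iso8601('2016-03-15T12:30:45Z'))   # 'Z' maps to FixedOffset(0), i.e. UTC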

+ 24 - 21
celery/utils/log.py

@@ -8,6 +8,8 @@ import threading
 import traceback
 
 from contextlib import contextmanager
+from typing import Any, Iterable, Iterator, Optional, Tuple, Union
+
 from kombu.log import get_logger as _get_logger, LOG_LEVELS
 from kombu.utils.encoding import safe_str
 
@@ -30,15 +32,15 @@ MP_LOG = os.environ.get('MP_LOG', False)
 # Every logger in the celery package inherits from the "celery"
 # logger, and every task logger inherits from the "celery.task"
 # logger.
-base_logger = logger = _get_logger('celery')
+base_logger = logger = _get_logger('celery')  # type: logging.Logger
 
 
-def set_in_sighandler(value):
+def set_in_sighandler(value: bool) -> None:
     global _in_sighandler
     _in_sighandler = value
 
 
-def iter_open_logger_fds():
+def iter_open_logger_fds() -> Iterable[Any]:
     seen = set()
     loggers = (list(logging.Logger.manager.loggerDict.values()) +
                [logging.getLogger(None)])
@@ -56,7 +58,7 @@ def iter_open_logger_fds():
 
 
 @contextmanager
-def in_sighandler():
+def in_sighandler() -> Iterator:
     set_in_sighandler(True)
     try:
         yield
@@ -64,7 +66,7 @@ def in_sighandler():
         set_in_sighandler(False)
 
 
-def logger_isa(l, p, max=1000):
+def logger_isa(l: logging.Logger, p: logging.Logger, max: int=1000) -> bool:
     this, seen = l, set()
     for _ in range(max):
         if this == p:
@@ -83,7 +85,7 @@ def logger_isa(l, p, max=1000):
     return False
 
 
-def get_logger(name):
+def get_logger(name: Union[str, logging.Logger]) -> logging.Logger:
     l = _get_logger(name)
     if logging.root not in (l, l.parent) and l is not base_logger:
         if not logger_isa(l, base_logger):  # pragma: no cover
@@ -93,14 +95,14 @@ task_logger = get_logger('celery.task')
 worker_logger = get_logger('celery.worker')
 
 
-def get_task_logger(name):
+def get_task_logger(name: Union[str, logging.Logger]) -> logging.Logger:
     logger = get_logger(name)
     if not logger_isa(logger, task_logger):
         logger.parent = task_logger
     return logger
 
 
-def mlevel(level):
+def mlevel(level: Union[int, str]) -> int:
     if level and not isinstance(level, numbers.Integral):
         return LOG_LEVELS[level.upper()]
     return level
@@ -116,16 +118,16 @@ class ColorFormatter(logging.Formatter):
         'CRITICAL': COLORS['magenta'],
     }
 
-    def __init__(self, fmt=None, use_color=True):
+    def __init__(self, fmt: Optional[str]=None, use_color: bool=True) -> None:
         logging.Formatter.__init__(self, fmt)
         self.use_color = use_color
 
-    def formatException(self, ei):
+    def formatException(self, ei: Tuple) -> str:
         if ei and not isinstance(ei, tuple):
             ei = sys.exc_info()
         return logging.Formatter.formatException(self, ei)
 
-    def format(self, record):
+    def format(self, record: logging.LogRecord) -> str:
         msg = logging.Formatter.format(self, record)
         color = self.colors.get(record.levelname)
 
@@ -171,7 +173,8 @@ class LoggingProxy:
     loglevel = logging.ERROR
     _thread = threading.local()
 
-    def __init__(self, logger, loglevel=None):
+    def __init__(self, logger: logging.Logger,
+                 loglevel: Optional[Union[int, str]]=None) -> None:
         self.logger = logger
         self.loglevel = mlevel(loglevel or self.logger.level or self.loglevel)
         self._safewrap_handlers()
@@ -194,7 +197,7 @@ class LoggingProxy:
             handler.handleError = WithSafeHandleError().handleError
         return [wrap_handler(h) for h in self.logger.handlers]
 
-    def write(self, data):
+    def write(self, data: Any) -> None:
         """Write message to logging object."""
         if _in_sighandler:
             return print(safe_str(data), file=sys.__stderr__)
@@ -209,7 +212,7 @@ class LoggingProxy:
             finally:
                 self._thread.recurse_protection = False
 
-    def writelines(self, sequence):
+    def writelines(self, sequence: Iterable[str]) -> None:
         """`writelines(sequence_of_strings) -> None`.
 
         Write the strings to the file.
@@ -220,22 +223,22 @@ class LoggingProxy:
         for part in sequence:
             self.write(part)
 
-    def flush(self):
+    def flush(self) -> None:
         """This object is not buffered so any :meth:`flush` requests
         are ignored."""
         pass
 
-    def close(self):
+    def close(self) -> None:
         """When the object is closed, no write requests are forwarded to
         the logging object anymore."""
         self.closed = True
 
-    def isatty(self):
+    def isatty(self) -> bool:
         """Always return :const:`False`. Just here for file support."""
         return False
 
 
-def get_multiprocessing_logger():
+def get_multiprocessing_logger() -> logging.Logger:
     try:
         from billiard import util
     except ImportError:  # pragma: no cover
@@ -244,7 +247,7 @@ def get_multiprocessing_logger():
         return util.get_logger()
 
 
-def reset_multiprocessing_logger():
+def reset_multiprocessing_logger() -> None:
     try:
         from billiard import util
     except ImportError:  # pragma: no cover
@@ -254,7 +257,7 @@ def reset_multiprocessing_logger():
             util._logger = None
 
 
-def current_process():
+def current_process() -> Any:
     try:
         from billiard import process
     except ImportError:  # pragma: no cover
@@ -263,6 +266,6 @@ def current_process():
         return process.current_process()
 
 
-def current_process_index(base=1):
+def current_process_index(base: int=1) -> int:
     index = getattr(current_process(), 'index', None)
     return index + base if index is not None else index
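
Typical use of the logger helpers annotated above:

    import logging
    from celery.utils.log import get_logger, get_task_logger, mlevel

    logger = get_logger(__name__)                  # parented under the base 'celery' logger
    task_logger = get_task_logger('myapp.tasks')   # parented under 'celery.task'
    assert mlevel('INFO') == logging.INFO          # level names map to numeric levels
    assert mlevel(logging.DEBUG) == logging.DEBUG  # integers pass through unchanged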

+ 8 - 5
celery/utils/objects.py

@@ -1,5 +1,6 @@
 # -*- coding: utf-8 -*-
 """Object related utilities including introspection, etc."""
+from typing import Any, Callable, Set, Sequence
 
 __all__ = ['Bunch', 'FallbackContext', 'mro_lookup']
 
@@ -7,11 +8,12 @@ __all__ = ['Bunch', 'FallbackContext', 'mro_lookup']
 class Bunch:
     """Object that enables you to modify attributes."""
 
-    def __init__(self, **kwargs):
+    def __init__(self, **kwargs) -> None:
         self.__dict__.update(kwargs)
 
 
-def mro_lookup(cls, attr, stop=set(), monkey_patched=[]):
+def mro_lookup(cls: Any, attr: str,
+               stop: Set=set(), monkey_patched: Sequence=[]) -> Any:
     """Return the first node by MRO order that defines an attribute.
 
     Arguments:
@@ -67,14 +69,15 @@ class FallbackContext:
             return FallbackContext(connection, create_new_connection)
     """
 
-    def __init__(self, provided, fallback, *fb_args, **fb_kwargs):
+    def __init__(self, provided: Any, fallback: Callable,
+                 *fb_args, **fb_kwargs) -> None:
         self.provided = provided
         self.fallback = fallback
         self.fb_args = fb_args
         self.fb_kwargs = fb_kwargs
         self._context = None
 
-    def __enter__(self):
+    def __enter__(self) -> Any:
         if self.provided is not None:
             return self.provided
         context = self._context = self.fallback(
@@ -82,6 +85,6 @@ class FallbackContext:
         ).__enter__()
         return context
 
-    def __exit__(self, *exc_info):
+    def __exit__(self, *exc_info) -> Any:
         if self._context is not None:
             return self._context.__exit__(*exc_info)
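
A short sketch of Bunch and mro_lookup() as annotated above:

    from celery.utils.objects import Bunch, mro_lookup

    cfg = Bunch(host='localhost', port=6379)
    cfg.port = 6380                        # plain attribute container

    class A: ...
    class B(A):
        flag = True
    class C(B): ...

    assert mro_lookup(C, 'flag') is B      # first class in C's MRO defining 'flag'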

+ 17 - 9
celery/utils/saferepr.py

@@ -17,6 +17,9 @@ from decimal import Decimal
 from itertools import chain
 from numbers import Number
 from pprint import _recursion
+from typing import (
+    Any, Callable, Iterator, MutableSequence, Optional, Set, Sequence, Tuple,
+)
 
 from kombu.utils.encoding import bytes_to_str
 
@@ -46,15 +49,16 @@ LIT_TUPLE_END = _literal(')', False, -1)
 LIT_TUPLE_END_SV = _literal(',)', False, -1)
 
 
-def saferepr(o, maxlen=None, maxlevels=3, seen=None):
+def saferepr(o: Any, maxlen: Optional[int]=None,
+             maxlevels: int=3, seen: Optional[Set]=None) -> str:
     return ''.join(_saferepr(
         o, maxlen=maxlen, maxlevels=maxlevels, seen=seen
     ))
 
 
-def _chaindict(mapping,
-               LIT_DICT_KVSEP=LIT_DICT_KVSEP,
-               LIT_LIST_SEP=LIT_LIST_SEP):
+def _chaindict(mapping: Mapping,
+               LIT_DICT_KVSEP: str=LIT_DICT_KVSEP,
+               LIT_LIST_SEP: str=LIT_LIST_SEP) -> Iterator[Any]:
     size = len(mapping)
     for i, (k, v) in enumerate(mapping.items()):
         yield _key(k)
@@ -64,7 +68,7 @@ def _chaindict(mapping,
             yield LIT_LIST_SEP
 
 
-def _chainlist(it, LIT_LIST_SEP=LIT_LIST_SEP):
+def _chainlist(it: Sequence, LIT_LIST_SEP: str=LIT_LIST_SEP) -> Iterator[Any]:
     size = len(it)
     for i, v in enumerate(it):
         yield v
@@ -72,11 +76,12 @@ def _chainlist(it, LIT_LIST_SEP=LIT_LIST_SEP):
             yield LIT_LIST_SEP
 
 
-def _repr_empty_set(s):
+def _repr_empty_set(s: Any) -> str:
     return '%s()' % (type(s).__name__,)
 
 
-def _saferepr(o, maxlen=None, maxlevels=3, seen=None):
+def _saferepr(o: Any, maxlen: Optional[int]=None,
+              maxlevels: int=3, seen: Optional[Set]=None) -> str:
     stack = deque([iter([o])])
     for token, it in reprstream(stack, seen=seen, maxlevels=maxlevels):
         if maxlen is not None and maxlen <= 0:
@@ -105,7 +110,8 @@ def _saferepr(o, maxlen=None, maxlevels=3, seen=None):
                 yield rest2.value
 
 
-def _reprseq(val, lit_start, lit_end, builtin_type, chainer):
+def _reprseq(val: Any, lit_start: str, lit_end: str, builtin_type: Any,
+             chainer: Callable) -> Tuple[Any, Any, Any]:
     if type(val) is builtin_type:  # noqa
         return lit_start, lit_end, chainer(val)
     return (
@@ -115,7 +121,9 @@ def _reprseq(val, lit_start, lit_end, builtin_type, chainer):
     )
 
 
-def reprstream(stack, seen=None, maxlevels=3, level=0, isinstance=isinstance):
+def reprstream(stack: MutableSequence, seen: Optional[Set]=None,
+               maxlevels: int=3, level: int=0,
+               isinstance: Callable=isinstance) -> Iterator[Any]:
     seen = seen or set()
     append = stack.append
     popleft = stack.popleft
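
Example of the truncating repr as annotated above (output shape is indicative only):

    from celery.utils.saferepr import saferepr

    data = {'id': 123, 'payload': list(range(1000))}
    print(saferepr(data, maxlen=64))   # repr truncated around 64 chars
    print(saferepr(data))              # full repr when maxlen is None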

+ 34 - 25
celery/utils/serialization.py

@@ -8,6 +8,7 @@ from base64 import b64encode as base64encode, b64decode as base64decode
 from functools import partial
 from inspect import getmro
 from itertools import takewhile
+from typing import Any, AnyStr, Callable, Optional, Sequence, Union
 
 from kombu.utils.encoding import bytes_to_str, str_to_bytes
 
@@ -32,12 +33,14 @@ except NameError:  # pragma: no cover
     unwanted_base_classes = (Exception, BaseException, object)  # py3k
 
 
-def subclass_exception(name, parent, module):  # noqa
+def subclass_exception(name: str, parent: Any, module: str) -> Any:  # noqa
     return type(name, (parent,), {'__module__': module})
 
 
-def find_pickleable_exception(exc, loads=pickle.loads,
-                              dumps=pickle.dumps):
+def find_pickleable_exception(
+        exc: Exception,
+        loads: Callable[[AnyStr], Any]=pickle.loads,
+        dumps: Callable[[Any], AnyStr]=pickle.dumps) -> Optional[Exception]:
     """With an exception instance, iterate over its super classes (by MRO)
     and find the first super exception that is pickleable.  It does
     not go below :exc:`Exception` (i.e. it skips :exc:`Exception`,
@@ -63,11 +66,12 @@ def find_pickleable_exception(exc, loads=pickle.loads,
             return superexc
 
 
-def itermro(cls, stop):
+def itermro(cls: Any, stop: Any) -> Any:
     return takewhile(lambda sup: sup not in stop, getmro(cls))
 
 
-def create_exception_cls(name, module, parent=None):
+def create_exception_cls(name: str, module: str,
+                         parent: Optional[Any]=None) -> Exception:
     """Dynamically create an exception class."""
     if not parent:
         parent = Exception
@@ -96,15 +100,16 @@ class UnpickleableExceptionWrapper(Exception):
     """
 
     #: The module of the original exception.
-    exc_module = None
+    exc_module = None       # type: str
 
     #: The name of the original exception class.
-    exc_cls_name = None
+    exc_cls_name = None     # type: str
 
     #: The arguments for the original exception.
-    exc_args = None
+    exc_args = None         # type: Sequence[Any]
 
-    def __init__(self, exc_module, exc_cls_name, exc_args, text=None):
+    def __init__(self, exc_module: str, exc_cls_name: str,
+                 exc_args: Sequence[Any], text: Optional[str]=None) -> None:
         safe_exc_args = []
         for arg in exc_args:
             try:
@@ -118,22 +123,22 @@ class UnpickleableExceptionWrapper(Exception):
         self.text = text
         Exception.__init__(self, exc_module, exc_cls_name, safe_exc_args, text)
 
-    def restore(self):
+    def restore(self) -> Exception:
         return create_exception_cls(self.exc_cls_name,
                                     self.exc_module)(*self.exc_args)
 
-    def __str__(self):
+    def __str__(self) -> str:
         return self.text
 
     @classmethod
-    def from_exception(cls, exc):
+    def from_exception(cls, exc: Exception) -> 'UnpickleableExceptionWrapper':
         return cls(exc.__class__.__module__,
                    exc.__class__.__name__,
                    getattr(exc, 'args', []),
                    safe_repr(exc))
 
 
-def get_pickleable_exception(exc):
+def get_pickleable_exception(exc: Exception) -> Exception:
     """Make sure exception is pickleable."""
     try:
         pickle.loads(pickle.dumps(exc))
@@ -147,7 +152,10 @@ def get_pickleable_exception(exc):
     return UnpickleableExceptionWrapper.from_exception(exc)
 
 
-def get_pickleable_etype(cls, loads=pickle.loads, dumps=pickle.dumps):
+def get_pickleable_etype(
+        cls: Any,
+        loads: Callable[[AnyStr], Any]=pickle.loads,
+        dumps: Callable[[Any], AnyStr]=pickle.dumps) -> Exception:
     try:
         loads(dumps(cls))
     except:
@@ -156,7 +164,7 @@ def get_pickleable_etype(cls, loads=pickle.loads, dumps=pickle.dumps):
         return cls
 
 
-def get_pickled_exception(exc):
+def get_pickled_exception(exc: Exception) -> Exception:
     """Get original exception from exception pickled using
     :meth:`get_pickleable_exception`."""
     if isinstance(exc, UnpickleableExceptionWrapper):
@@ -164,17 +172,18 @@ def get_pickled_exception(exc):
     return exc
 
 
-def b64encode(s):
+def b64encode(s: AnyStr) -> str:
     return bytes_to_str(base64encode(str_to_bytes(s)))
 
 
-def b64decode(s):
+def b64decode(s: AnyStr) -> bytes:
     return base64decode(str_to_bytes(s))
 
 
-def strtobool(term, table={'false': False, 'no': False, '0': False,
-                           'true': True, 'yes': True, '1': True,
-                           'on': True, 'off': False}):
+def strtobool(term: Union[str, bool],
+              table={'false': False, 'no': False, '0': False,
+                     'true': True, 'yes': True, '1': True,
+                     'on': True, 'off': False}) -> bool:
     """Convert common terms for true/false to bool
     (true/false/yes/no/on/off/1/0)."""
     if isinstance(term, str):
@@ -185,10 +194,10 @@ def strtobool(term, table={'false': False, 'no': False, '0': False,
     return term
 
 
-def jsonify(obj,
-            builtin_types=(numbers.Real, str), key=None,
-            keyfilter=None,
-            unknown_type_filter=None):
+def jsonify(obj: Any,
+            builtin_types=(numbers.Real, str), key: Optional[str]=None,
+            keyfilter: Optional[Callable[[str], Any]]=None,
+            unknown_type_filter: Optional[Callable[[Any], Any]]=None) -> Any:
     """Transforms object making it suitable for json serialization"""
     from kombu.abstract import Object as KombuDictType
     _jsonify = partial(jsonify, builtin_types=builtin_types, key=key,
@@ -232,7 +241,7 @@ def jsonify(obj,
         return unknown_type_filter(obj)
 
 
-def maybe_reraise():
+def maybe_reraise() -> None:
     """Re-raise if an exception is currently being handled, or return
     otherwise."""
     exc_info = sys.exc_info()
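
Illustrative calls for the helpers annotated above:

    from celery.utils.serialization import b64decode, b64encode, strtobool

    assert strtobool('yes') is True
    assert strtobool('off') is False
    assert strtobool(True) is True                         # non-strings pass through
    assert b64decode(b64encode('payload')) == b'payload'   # round-trips str to bytes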

+ 22 - 10
celery/utils/sysinfo.py

@@ -2,43 +2,55 @@
 """System information utilities."""
 import os
 
+from collections import namedtuple
+
 from math import ceil
+from typing import Tuple
 
 from kombu.utils import cached_property
 
-__all__ = ['load_average', 'df']
+__all__ = ['load_average', 'load_average_t', 'df']
+
+load_average_t = namedtuple('load_average_t', (
+    'min_1', 'min_5', 'min_15',
+))
+
+
+def _avg(f: float) -> float:
+    return ceil(f * 1e2) / 1e2
 
 
 if hasattr(os, 'getloadavg'):
 
-    def load_average():
-        return tuple(ceil(l * 1e2) / 1e2 for l in os.getloadavg())
+    def load_average() -> load_average_t:
+        min_1, min_5, min_15 = os.getloadavg()
+        return load_average_t(_avg(min_1), _avg(min_5), _avg(min_15))
 
 else:  # pragma: no cover
     # Windows doesn't have getloadavg
-    def load_average():  # noqa
-        return (0.0, 0.0, 0.0)
+    def load_average() -> load_average_t:  # noqa
+        return load_average_t(0.0, 0.0, 0.0)
 
 
 class df:
 
-    def __init__(self, path):
+    def __init__(self, path: str) -> None:
         self.path = path
 
     @property
-    def total_blocks(self):
+    def total_blocks(self) -> float:
         return self.stat.f_blocks * self.stat.f_frsize / 1024
 
     @property
-    def available(self):
+    def available(self) -> float:
         return self.stat.f_bavail * self.stat.f_frsize / 1024
 
     @property
-    def capacity(self):
+    def capacity(self) -> int:
         avail = self.stat.f_bavail
         used = self.stat.f_blocks - self.stat.f_bfree
         return int(ceil(used * 100.0 / (used + avail) + 0.5))
 
     @cached_property
-    def stat(self):
+    def stat(self) -> os.statvfs_result:
         return os.statvfs(os.path.abspath(self.path))
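
Usage of the new load_average_t named tuple and the df helper (values are machine-dependent):

    from celery.utils.sysinfo import df, load_average

    la = load_average()
    print(la.min_1, la.min_5, la.min_15)   # named fields instead of a bare 3-tuple

    disk = df('/')
    print(disk.capacity)                   # percent of the filesystem in use
    print(disk.available)                  # available space in KiB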

+ 3 - 2
celery/utils/term.py

@@ -4,6 +4,7 @@ import platform
 
 from functools import reduce
 from typing import Any, Tuple
+from typing import Mapping  # noqa
 
 __all__ = ['colored']
 
@@ -30,8 +31,8 @@ class colored:
         ...       c.green('dog ')))
     """
 
-    def __init__(self, *s: Tuple[Any],
-                 enabled: bool=True, op: str='', **kwargs):
+    def __init__(self, *s: Tuple[str],
+                 enabled: bool=True, op: str='', **kwargs) -> None:
         self.s = s
         self.enabled = not IS_WINDOWS and enabled
         self.op = op

+ 39 - 36
celery/utils/threads.py

@@ -7,9 +7,12 @@ import threading
 import traceback
 
 from contextlib import contextmanager
+from typing import Any, Callable, Iterator, List, Optional
 
 from celery.local import Proxy
 
+from .typing import Timeout
+
 try:
     from greenlet import getcurrent as get_ident
 except ImportError:  # pragma: no cover
@@ -24,7 +27,6 @@ except ImportError:  # pragma: no cover
             except ImportError:
                 from dummy_thread import get_ident      # noqa
 
-
 __all__ = [
     'bgThread', 'Local', 'LocalStack', 'LocalManager',
     'get_ident', 'default_socket_timeout',
@@ -34,7 +36,7 @@ USE_FAST_LOCALS = os.environ.get('USE_FAST_LOCALS')
 
 
 @contextmanager
-def default_socket_timeout(timeout):
+def default_socket_timeout(timeout: Timeout) -> Iterator:
     prev = socket.getdefaulttimeout()
     socket.setdefaulttimeout(timeout)
     yield
@@ -43,21 +45,21 @@ def default_socket_timeout(timeout):
 
 class bgThread(threading.Thread):
 
-    def __init__(self, name=None, **kwargs):
+    def __init__(self, name: Optional[str]=None, **kwargs) -> None:
         super().__init__()
         self._is_shutdown = threading.Event()
         self._is_stopped = threading.Event()
         self.daemon = True
         self.name = name or self.__class__.__name__
 
-    def body(self):
+    def body(self) -> None:
         raise NotImplementedError()
 
-    def on_crash(self, msg, *fmt, **kwargs):
+    def on_crash(self, msg: str, *fmt, **kwargs) -> None:
         print(msg.format(*fmt), file=sys.stderr)
         traceback.print_exc(None, sys.stderr)
 
-    def run(self):
+    def run(self) -> None:
         body = self.body
         shutdown_set = self._is_shutdown.is_set
         try:
@@ -73,7 +75,7 @@ class bgThread(threading.Thread):
         finally:
             self._set_stopped()
 
-    def _set_stopped(self):
+    def _set_stopped(self) -> None:
         try:
             self._is_stopped.set()
         except TypeError:  # pragma: no cover
@@ -81,7 +83,7 @@ class bgThread(threading.Thread):
             # so gc collected built-in modules.
             pass
 
-    def stop(self):
+    def stop(self) -> None:
         """Graceful shutdown."""
         self._is_shutdown.set()
         self._is_stopped.wait()
@@ -89,7 +91,7 @@ class bgThread(threading.Thread):
             self.join(threading.TIMEOUT_MAX)
 
 
-def release_local(local):
+def release_local(local: 'Local') -> None:
     """Releases the contents of the local for the current context.
     This makes it possible to use locals without a manager.
 
@@ -112,27 +114,27 @@ def release_local(local):
 class Local:
     __slots__ = ('__storage__', '__ident_func__')
 
-    def __init__(self):
+    def __init__(self) -> None:
         object.__setattr__(self, '__storage__', {})
         object.__setattr__(self, '__ident_func__', get_ident)
 
-    def __iter__(self):
+    def __iter__(self) -> Iterator:
         return iter(self.__storage__.items())
 
-    def __call__(self, proxy):
+    def __call__(self, proxy: Any) -> 'Proxy':
         """Create a proxy for a name."""
         return Proxy(self, proxy)
 
-    def __release_local__(self):
+    def __release_local__(self) -> None:
         self.__storage__.pop(self.__ident_func__(), None)
 
-    def __getattr__(self, name):
+    def __getattr__(self, name: str) -> Any:
         try:
             return self.__storage__[self.__ident_func__()][name]
         except KeyError:
             raise AttributeError(name)
 
-    def __setattr__(self, name, value):
+    def __setattr__(self, name: str, value: Any) -> None:
         ident = self.__ident_func__()
         storage = self.__storage__
         try:
@@ -140,7 +142,7 @@ class Local:
         except KeyError:
             storage[ident] = {name: value}
 
-    def __delattr__(self, name):
+    def __delattr__(self, name: str) -> None:
         try:
             del self.__storage__[self.__ident_func__()][name]
         except KeyError:
@@ -172,29 +174,29 @@ class _LocalStack:
     resolves to the topmost item on the stack.
     """
 
-    def __init__(self):
+    def __init__(self) -> None:
         self._local = Local()
 
-    def __release_local__(self):
+    def __release_local__(self) -> None:
         self._local.__release_local__()
 
-    def _get__ident_func__(self):
+    def _get__ident_func__(self) -> Optional[Callable]:
         return self._local.__ident_func__
 
-    def _set__ident_func__(self, value):
+    def _set__ident_func__(self, value: Optional[Callable]) -> None:
         object.__setattr__(self._local, '__ident_func__', value)
     __ident_func__ = property(_get__ident_func__, _set__ident_func__)
     del _get__ident_func__, _set__ident_func__
 
-    def __call__(self):
-        def _lookup():
+    def __call__(self) -> 'Proxy':
+        def _lookup() -> Any:
             rv = self.top
             if rv is None:
                 raise RuntimeError('object unbound')
             return rv
         return Proxy(_lookup)
 
-    def push(self, obj):
+    def push(self, obj: Any) -> Any:
         """Pushes a new item to the stack"""
         rv = getattr(self._local, 'stack', None)
         if rv is None:
@@ -202,7 +204,7 @@ class _LocalStack:
         rv.append(obj)
         return rv
 
-    def pop(self):
+    def pop(self) -> Any:
         """Remove the topmost item from the stack, will return the
         old value or `None` if the stack was already empty.
         """
@@ -215,12 +217,12 @@ class _LocalStack:
         else:
             return stack.pop()
 
-    def __len__(self):
+    def __len__(self) -> int:
         stack = getattr(self._local, 'stack', None)
         return len(stack) if stack else 0
 
     @property
-    def stack(self):
+    def stack(self) -> List:
         """get_current_worker_task uses this to find
         the original task that was executed by the worker."""
         stack = getattr(self._local, 'stack', None)
@@ -229,7 +231,7 @@ class _LocalStack:
         return []
 
     @property
-    def top(self):
+    def top(self) -> Any:
         """The topmost item on the stack.  If the stack is empty,
         `None` is returned.
         """
@@ -250,7 +252,8 @@ class LocalManager:
     function for the wrapped locals.
     """
 
-    def __init__(self, locals=None, ident_func=None):
+    def __init__(self, locals: Optional[List]=None,
+                 ident_func: Optional[Callable]=None) -> None:
         if locals is None:
             self.locals = []
         elif isinstance(locals, Local):
@@ -264,14 +267,14 @@ class LocalManager:
         else:
             self.ident_func = get_ident
 
-    def get_ident(self):
+    def get_ident(self) -> Any:
         """Return the context identifier the local objects use internally
         for this context.  You cannot override this method to change the
         behavior but use it to link other context local objects (such as
         SQLAlchemy's scoped sessions) to the Werkzeug locals."""
         return self.ident_func()
 
-    def cleanup(self):
+    def cleanup(self) -> None:
         """Manually clean up the data in the locals for this context.
 
         Call this at the end of the request or use ``make_middleware()``.
@@ -279,7 +282,7 @@ class LocalManager:
         for local in self.locals:
             release_local(local)
 
-    def __repr__(self):
+    def __repr__(self) -> str:
         return '<{0} storages: {1}>'.format(
             self.__class__.__name__, len(self.locals))
 
@@ -287,18 +290,18 @@ class LocalManager:
 class _FastLocalStack(threading.local):
 
     def __init__(self):
-        self.stack = []
-        self.push = self.stack.append
-        self.pop = self.stack.pop
+        self.stack = []                 # type: List[Any]
+        self.push = self.stack.append   # type: Callable[[Any], None]
+        self.pop = self.stack.pop       # type: Callable[[], Any]
 
     @property
-    def top(self):
+    def top(self) -> Any:
         try:
             return self.stack[-1]
         except (AttributeError, IndexError):
             return None
 
-    def __len__(self):
+    def __len__(self) -> int:
         return len(self.stack)
 
 if USE_FAST_LOCALS:  # pragma: no cover
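
A sketch of the LocalStack behavior annotated above (the public LocalStack name selects one of the two implementations at the end of the module):

    from celery.utils.threads import LocalStack

    stack = LocalStack()
    stack.push({'request_id': 1})
    assert stack.top == {'request_id': 1}
    assert len(stack) == 1
    stack.pop()
    assert stack.top is None      # an empty stack resolves to None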

+ 27 - 21
celery/utils/timer2.py

@@ -10,7 +10,9 @@ import sys
 import threading
 
 from itertools import count
+from numbers import Number
 from time import sleep
+from typing import Callable, MutableSequence, Optional, Union
 
 from kombu.async.timer import Entry, Timer as Schedule, to_timestamp, logger
 
@@ -30,14 +32,17 @@ class Timer(threading.Thread):
     _timer_count = count(1)
 
     if TIMER_DEBUG:  # pragma: no cover
-        def start(self, *args, **kwargs):
+        def start(self, *args, **kwargs) -> None:
             import traceback
             print('- Timer starting')
             traceback.print_stack()
             super().start(*args, **kwargs)
 
-    def __init__(self, schedule=None, on_error=None, on_tick=None,
-                 on_start=None, max_interval=None, **kwargs):
+    def __init__(self, schedule: Optional[Schedule]=None,
+                 on_error: Optional[Callable]=None,
+                 on_tick: Optional[Callable]=None,
+                 on_start: Optional[Callable]=None,
+                 max_interval: Optional[Number]=None, **kwargs) -> None:
         self.schedule = schedule or self.Schedule(on_error=on_error,
                                                   max_interval=max_interval)
         self.on_start = on_start
@@ -50,7 +55,7 @@ class Timer(threading.Thread):
         self.daemon = True
         self.name = 'Timer-{0}'.format(next(self._timer_count))
 
-    def _next_entry(self):
+    def _next_entry(self) -> Optional[Number]:
         with self.not_empty:
             delay, entry = next(self.scheduler)
             if entry is None:
@@ -60,7 +65,7 @@ class Timer(threading.Thread):
         return self.schedule.apply_entry(entry)
     __next__ = next = _next_entry  # for 2to3
 
-    def run(self):
+    def run(self) -> None:
         try:
             self.running = True
             self.scheduler = iter(self.schedule)
@@ -83,60 +88,61 @@ class Timer(threading.Thread):
             logger.error('Thread Timer crashed: %r', exc, exc_info=True)
             os._exit(1)
 
-    def stop(self):
+    def stop(self) -> None:
         self._is_shutdown.set()
         if self.running:
             self._is_stopped.wait()
             self.join(threading.TIMEOUT_MAX)
             self.running = False
 
-    def ensure_started(self):
+    def ensure_started(self) -> None:
         if not self.running and not self.isAlive():
             if self.on_start:
                 self.on_start(self)
             self.start()
 
-    def _do_enter(self, meth, *args, **kwargs):
+    def _do_enter(self, meth: str, *args, **kwargs) -> Entry:
         self.ensure_started()
         with self.mutex:
             entry = getattr(self.schedule, meth)(*args, **kwargs)
             self.not_empty.notify()
             return entry
 
-    def enter(self, entry, eta, priority=None):
+    def enter(self, entry: Entry, eta: float,
+              priority: Optional[int]=None) -> Entry:
         return self._do_enter('enter_at', entry, eta, priority=priority)
 
-    def call_at(self, *args, **kwargs):
+    def call_at(self, *args, **kwargs) -> Entry:
         return self._do_enter('call_at', *args, **kwargs)
 
-    def enter_after(self, *args, **kwargs):
+    def enter_after(self, *args, **kwargs) -> Entry:
         return self._do_enter('enter_after', *args, **kwargs)
 
-    def call_after(self, *args, **kwargs):
+    def call_after(self, *args, **kwargs) -> Entry:
         return self._do_enter('call_after', *args, **kwargs)
 
-    def call_repeatedly(self, *args, **kwargs):
+    def call_repeatedly(self, *args, **kwargs) -> Entry:
         return self._do_enter('call_repeatedly', *args, **kwargs)
 
-    def exit_after(self, secs, priority=10):
+    def exit_after(self, secs: Union[int, float],
+                   priority: Optional[int]=10) -> None:
         self.call_after(secs, sys.exit, priority)
 
-    def cancel(self, tref):
+    def cancel(self, tref: Entry) -> None:
         tref.cancel()
 
-    def clear(self):
+    def clear(self) -> None:
         self.schedule.clear()
 
-    def empty(self):
+    def empty(self) -> bool:
         return not len(self)
 
-    def __len__(self):
+    def __len__(self) -> int:
         return len(self.schedule)
 
-    def __bool__(self):
+    def __bool__(self) -> bool:
         return True
-    __nonzero__ = __bool__
 
     @property
-    def queue(self):
+    def queue(self) -> MutableSequence:
         return self.schedule.queue
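
For reference, a minimal sketch of driving the annotated Timer — not part of the commit; the (secs, fun) argument order of call_after/call_repeatedly is assumed to match the kombu schedule methods they forward to via _do_enter:

```python
from celery.utils.timer2 import Timer

timer = Timer()  # the thread itself is started lazily by ensure_started()

# Schedule a one-shot and a repeating callback (assumed kombu signatures).
tref = timer.call_after(1.0, lambda: print('fired after ~1 second'))
timer.call_repeatedly(5.0, lambda: print('fires every ~5 seconds'))

timer.cancel(tref)   # cancels the one-shot entry via tref.cancel()
timer.stop()         # sets the shutdown event and joins the thread
```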

+ 52 - 37
celery/utils/timeutils.py

@@ -6,6 +6,7 @@ import time as _time
 
 from calendar import monthrange
 from datetime import date, datetime, timedelta, tzinfo
+from typing import Any, Callable, Dict, Optional, Union
 
 from kombu.utils import cached_property, reprcall
 
@@ -53,7 +54,7 @@ class LocalTimezone(tzinfo):
     """
     _offset_cache = {}
 
-    def __init__(self):
+    def __init__(self) -> None:
         # This code is moved in __init__ to execute it as late as possible
         # See get_default_timezone().
         self.STDOFFSET = timedelta(seconds=-_time.timezone)
@@ -64,21 +65,21 @@ class LocalTimezone(tzinfo):
         self.DSTDIFF = self.DSTOFFSET - self.STDOFFSET
         tzinfo.__init__(self)
 
-    def __repr__(self):
+    def __repr__(self) -> str:
         return '<LocalTimezone: UTC{0:+03d}>'.format(
             int(self.DSTOFFSET.total_seconds() / 3600),
         )
 
-    def utcoffset(self, dt):
+    def utcoffset(self, dt: datetime) -> timedelta:
         return self.DSTOFFSET if self._isdst(dt) else self.STDOFFSET
 
-    def dst(self, dt):
+    def dst(self, dt: datetime) -> timedelta:
         return self.DSTDIFF if self._isdst(dt) else ZERO
 
-    def tzname(self, dt):
+    def tzname(self, dt: datetime) -> str:
         return _time.tzname[self._isdst(dt)]
 
-    def fromutc(self, dt):
+    def fromutc(self, dt: datetime) -> datetime:
         # The base tzinfo class no longer implements a DST
         # offset aware .fromutc() in Python 3 (Issue #2306).
 
@@ -91,7 +92,7 @@ class LocalTimezone(tzinfo):
             tz = self._offset_cache[offset] = FixedOffset(offset)
         return tz.fromutc(dt.replace(tzinfo=tz))
 
-    def _isdst(self, dt):
+    def _isdst(self, dt: datetime) -> bool:
         tt = (dt.year, dt.month, dt.day,
               dt.hour, dt.minute, dt.second,
               dt.weekday(), 0, 0)
@@ -102,42 +103,45 @@ class LocalTimezone(tzinfo):
 
 class _Zone:
 
-    def tz_or_local(self, tzinfo=None):
+    def tz_or_local(self, tzinfo: Optional[tzinfo]=None) -> tzinfo:
         if tzinfo is None:
             return self.local
         return self.get_timezone(tzinfo)
 
-    def to_local(self, dt, local=None, orig=None):
+    def to_local(self, dt: datetime,
+                 local: Optional[tzinfo]=None,
+                 orig: Optional[tzinfo]=None) -> datetime:
         if is_naive(dt):
             dt = make_aware(dt, orig or self.utc)
         return localize(dt, self.tz_or_local(local))
 
-    def to_system(self, dt):
+    def to_system(self, dt: datetime) -> datetime:
         # tz=None is a special case since Python 3.3, and will
         # convert to the current local timezone (Issue #2306).
         return dt.astimezone(tz=None)
 
-    def to_local_fallback(self, dt):
+    def to_local_fallback(self, dt: datetime) -> datetime:
         if is_naive(dt):
             return make_aware(dt, self.local)
         return localize(dt, self.local)
 
-    def get_timezone(self, zone):
+    def get_timezone(self, zone: Union[str, tzinfo]) -> tzinfo:
         if isinstance(zone, str):
             return _timezone(zone)
         return zone
 
     @cached_property
-    def local(self):
+    def local(self) -> tzinfo:
         return LocalTimezone()
 
     @cached_property
-    def utc(self):
+    def utc(self) -> tzinfo:
         return self.get_timezone('UTC')
 timezone = _Zone()
 
 
-def maybe_timedelta(delta):
+def maybe_timedelta(
+        delta: Optional[Union[numbers.Real, timedelta]]) -> Optional[timedelta]:
     """Coerces integer to :class:`~datetime.timedelta` if argument
     is an integer."""
     if isinstance(delta, numbers.Real):
@@ -145,7 +149,7 @@ def maybe_timedelta(delta):
     return delta
 
 
-def delta_resolution(dt, delta):
+def delta_resolution(dt: datetime, delta: timedelta) -> datetime:
     """Round a :class:`~datetime.datetime` to the resolution of
     a :class:`~datetime.timedelta`.
 
@@ -169,7 +173,9 @@ def delta_resolution(dt, delta):
     return dt
 
 
-def remaining(start, ends_in, now=None, relative=False):
+def remaining(start: datetime, ends_in: timedelta,
+              now: Optional[Callable[[], datetime]]=None,
+              relative: bool=False) -> timedelta:
     """Calculate the remaining time for a start date and a
     :class:`~datetime.timedelta`.
 
@@ -198,7 +204,7 @@ def remaining(start, ends_in, now=None, relative=False):
     return ret
 
 
-def rate(rate):
+def rate(rate: Union[str, numbers.Number]) -> float:
     """Parse rate strings, such as `"100/m"`, `"2/h"` or `"0.5/s"`
     and convert them to seconds."""
     if rate:
@@ -209,7 +215,7 @@ def rate(rate):
     return 0
 
 
-def weekday(name):
+def weekday(name: str) -> int:
     """Return the position of a weekday (0 - 7, where 0 is Sunday).
 
     Example:
@@ -224,7 +230,9 @@ def weekday(name):
         raise KeyError(name)
 
 
-def humanize_seconds(secs, prefix='', sep='', now='now', microseconds=False):
+def humanize_seconds(secs: numbers.Number,
+                     prefix: str='', sep: str='', now: str='now',
+                     microseconds: bool=False) -> str:
     """Show seconds in human form, e.g. 60 is "1 minute", 7200 is "2
     hours".
 
@@ -245,7 +253,7 @@ def humanize_seconds(secs, prefix='', sep='', now='now', microseconds=False):
     return now
 
 
-def maybe_iso8601(dt):
+def maybe_iso8601(dt: Optional[Union[str, datetime]]) -> Optional[datetime]:
     """Either ``datetime | str -> datetime`` or ``None -> None``"""
     if not dt:
         return
@@ -254,13 +262,13 @@ def maybe_iso8601(dt):
     return parse_iso8601(dt)
 
 
-def is_naive(dt):
+def is_naive(dt: datetime) -> bool:
     """Return :const:`True` if the :class:`~datetime.datetime` is naive
     (does not have timezone information)."""
     return dt.tzinfo is None or dt.tzinfo.utcoffset(dt) is None
 
 
-def make_aware(dt, tz):
+def make_aware(dt: datetime, tz: tzinfo) -> datetime:
     """Sets the timezone for a :class:`~datetime.datetime` object."""
     try:
         _localize = tz.localize
@@ -275,7 +283,7 @@ def make_aware(dt, tz):
                        _localize(dt, is_dst=False))
 
 
-def localize(dt, tz):
+def localize(dt: datetime, tz: tzinfo) -> datetime:
     """Convert aware :class:`~datetime.datetime` to another timezone."""
     dt = dt.astimezone(tz)
     try:
@@ -292,12 +300,12 @@ def localize(dt, tz):
                        _normalize(dt, is_dst=False))
 
 
-def to_utc(dt):
+def to_utc(dt: datetime) -> datetime:
     """Converts naive :class:`~datetime.datetime` to UTC"""
     return make_aware(dt, timezone.utc)
 
 
-def maybe_make_aware(dt, tz=None):
+def maybe_make_aware(dt: datetime, tz: Optional[tzinfo]=None) -> datetime:
     if is_naive(dt):
         dt = to_utc(dt)
     return localize(
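
A hedged sketch of the awareness helpers from the hunks above — not part of the commit; timezone is the module-level _Zone() instance, and the conversions shown follow the docstrings:

```python
from datetime import datetime
from celery.utils.timeutils import (
    is_naive, localize, make_aware, maybe_make_aware, timezone, to_utc,
)

naive = datetime(2016, 1, 1, 12, 0)
assert is_naive(naive)                     # no tzinfo attached

aware = make_aware(naive, timezone.utc)    # attach a timezone to a naive value
as_utc = to_utc(naive)                     # shortcut for the line above
local = localize(as_utc, timezone.local)   # convert to the system timezone
either = maybe_make_aware(naive)           # naive input is upgraded to aware UTC
```
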
@@ -308,9 +316,16 @@ def maybe_make_aware(dt, tz=None):
 class ffwd:
     """Version of ``dateutil.relativedelta`` that only supports addition."""
 
-    def __init__(self, year=None, month=None, weeks=0, weekday=None, day=None,
-                 hour=None, minute=None, second=None, microsecond=None,
-                 **kwargs):
+    def __init__(self,
+                 year: Optional[int]=None,
+                 month: Optional[int]=None,
+                 weeks: int=0,
+                 weekday: Optional[int]=None,
+                 day: Optional[int]=None,
+                 hour: Optional[int]=None,
+                 minute: Optional[int]=None,
+                 second: Optional[numbers.Number]=None,
+                 microsecond: Optional[numbers.Number]=None,
+                 **kwargs) -> None:
         self.year = year
         self.month = month
         self.weeks = weeks
@@ -323,11 +338,11 @@ class ffwd:
         self.days = weeks * 7
         self._has_time = self.hour is not None or self.minute is not None
 
-    def __repr__(self):
-        return reprcall('ffwd', (), self._fields(weeks=self.weeks,
-                                                 weekday=self.weekday))
+    def __repr__(self) -> str:
+        return reprcall('ffwd', (), self._fields(
+            weeks=self.weeks, weekday=self.weekday))
 
-    def __radd__(self, other):
+    def __radd__(self, other: Any) -> datetime:
         if not isinstance(other, date):
             return NotImplemented
         year = self.year or other.year
@@ -339,7 +354,7 @@ class ffwd:
             ret += timedelta(days=(7 - ret.weekday() + self.weekday) % 7)
         return ret + timedelta(days=self.days)
 
-    def _fields(self, **extra):
+    def _fields(self, **extra) -> Dict:
         return dictfilter({
             'year': self.year, 'month': self.month, 'day': self.day,
             'hour': self.hour, 'minute': self.minute,
@@ -347,15 +362,15 @@ class ffwd:
         }, **extra)
 
 
-def utcoffset(time=_time, localtime=_time.localtime):
+def utcoffset(time=_time, localtime=_time.localtime) -> float:
     if localtime().tm_isdst:
         return time.altzone // 3600
     return time.timezone // 3600
 
 
-def adjust_timestamp(ts, offset, here=utcoffset):
+def adjust_timestamp(ts: float, offset: float, here=utcoffset) -> float:
     return ts - (offset - here()) * 3600
 
 
-def maybe_s_to_ms(v):
+def maybe_s_to_ms(v: Optional[numbers.Number]) -> Optional[int]:
     return int(float(v) * 1000.0) if v is not None else v
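
Finally, a hedged sketch exercising the scalar helpers annotated above — not part of the commit; the values in the comments follow from the docstrings and may differ slightly by version:

```python
from datetime import timedelta
from celery.utils.timeutils import (
    humanize_seconds, maybe_s_to_ms, maybe_timedelta, rate, weekday,
)

maybe_timedelta(30)             # -> timedelta(seconds=30)
maybe_timedelta(timedelta(1))   # already a timedelta: returned unchanged
rate('100/m')                   # -> 100 / 60.0, i.e. the per-second rate
weekday('monday')               # -> 1 (Sunday is 0)
humanize_seconds(7200)          # -> '2.00 hours'
maybe_s_to_ms(0.5)              # -> 500; None passes straight through
```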