Ask Solem 8 年之前
父節點
當前提交
3abbaa97e8

+ 1 - 1
celery/app/log.py

@@ -93,7 +93,7 @@ class Logging(object):
             return
         if logfile and hostname:
             logfile = node_format(logfile, hostname)
-        self.already_setup = True
+        Logging._setup = True
         loglevel = mlevel(loglevel or self.loglevel)
         format = format or self.format
         colorize = self.supports_color(colorize, logfile)

+ 3 - 0
celery/app/trace.py

@@ -270,6 +270,9 @@ def build_tracer(name, task, loader=None, hostname=None, store_errors=True,
         :keyword request: Request dict.
 
     """
+    # noqa: C901
+    # pylint: disable=too-many-statements
+
     # If the task doesn't define a custom __call__ method
     # we optimize it away by simply calling the run method directly,
     # saving the extra method call and a line less in the stack trace.

+ 1 - 2
celery/backends/async.py

@@ -45,8 +45,7 @@ class Drainer(object):
     def stop(self):
         pass
 
-    def drain_events_until(self, p, timeout=None, on_interval=None,
-                           monotonic=monotonic, wait=None):
+    def drain_events_until(self, p, timeout=None, on_interval=None, wait=None):
         wait = wait or self.result_consumer.drain_events
         time_start = monotonic()
 

+ 25 - 23
celery/beat.py

@@ -232,6 +232,7 @@ class Scheduler(object):
     def is_due(self, entry):
         return entry.is_due()
 
+    # pylint disable=redefined-outer-name
     def tick(self, event_t=event_t, min=min,
              heappop=heapq.heappop, heappush=heapq.heappush,
              heapify=heapq.heapify, mktime=time.mktime):
@@ -242,8 +243,6 @@ class Scheduler(object):
         Returns:
             float: preferred delay in seconds for next call.
         """
-        # pylint disable=redefined-outer-name
-
         def _when(entry, next_time_to_run):
             return (mktime(entry.schedule.now().timetuple()) +
                     (adjust(next_time_to_run) or 0))
@@ -433,27 +432,7 @@ class PersistentScheduler(Scheduler):
         except Exception as exc:  # pylint: disable=broad-except
             self._store = self._destroy_open_corrupted_schedule(exc)
 
-        for _ in (1, 2):
-            try:
-                self._store[str(b'entries')]
-            except KeyError:
-                # new schedule db
-                try:
-                    self._store[str(b'entries')] = {}
-                except KeyError as exc:
-                    self._store = self._destroy_open_corrupted_schedule(exc)
-                    continue
-            else:
-                if str(b'__version__') not in self._store:
-                    warning('DB Reset: Account for new __version__ field')
-                    self._store.clear()   # remove schedule at 2.2.2 upgrade.
-                elif str(b'tz') not in self._store:
-                    warning('DB Reset: Account for new tz field')
-                    self._store.clear()   # remove schedule at 3.0.8 upgrade
-                elif str(b'utc_enabled') not in self._store:
-                    warning('DB Reset: Account for new utc_enabled field')
-                    self._store.clear()   # remove schedule at 3.0.9 upgrade
-            break
+        self._create_schedule()
 
         tz = self.app.conf.timezone
         stored_tz = self._store.get(str(b'tz'))
@@ -479,6 +458,29 @@ class PersistentScheduler(Scheduler):
         debug('Current schedule:\n' + '\n'.join(
             repr(entry) for entry in values(entries)))
 
+    def _create_schedule(self):
+        for _ in (1, 2):
+            try:
+                self._store[str(b'entries')]
+            except KeyError:
+                # new schedule db
+                try:
+                    self._store[str(b'entries')] = {}
+                except KeyError as exc:
+                    self._store = self._destroy_open_corrupted_schedule(exc)
+                    continue
+            else:
+                if str(b'__version__') not in self._store:
+                    warning('DB Reset: Account for new __version__ field')
+                    self._store.clear()   # remove schedule at 2.2.2 upgrade.
+                elif str(b'tz') not in self._store:
+                    warning('DB Reset: Account for new tz field')
+                    self._store.clear()   # remove schedule at 3.0.8 upgrade
+                elif str(b'utc_enabled') not in self._store:
+                    warning('DB Reset: Account for new utc_enabled field')
+                    self._store.clear()   # remove schedule at 3.0.9 upgrade
+            break
+
     def get_schedule(self):
         return self._store[str(b'entries')]
 

+ 102 - 55
celery/contrib/migrate.py

@@ -14,6 +14,7 @@ from kombu.utils.encoding import ensure_bytes
 from celery.app import app_or_default
 from celery.five import python_2_unicode_compatible, string, string_t
 from celery.utils.nodenames import worker_direct
+from celery.utils.text import str_to_list
 
 __all__ = [
     'StopFiltering', 'State', 'republish', 'migrate_task',
@@ -244,69 +245,115 @@ def prepare_queues(queues):
     return queues
 
 
-def start_filter(app, conn, filter, limit=None, timeout=1.0,
+class Filterer(object):
+
+    def __init__(self, app, conn, filter,
+                 limit=None, timeout=1.0,
                  ack_messages=False, tasks=None, queues=None,
                  callback=None, forever=False, on_declare_queue=None,
                  consume_from=None, state=None, accept=None, **kwargs):
-    """Filter tasks."""
-    state = state or State()
-    queues = prepare_queues(queues)
-    consume_from = [_maybe_queue(app, q)
-                    for q in consume_from or list(queues)]
-    if isinstance(tasks, string_t):
-        tasks = set(tasks.split(','))
-    if tasks is None:
-        tasks = set()
-
-    def update_state(body, message):
-        state.count += 1
-        if limit and state.count >= limit:
+        self.app = app
+        self.conn = conn
+        self.filter = filter
+        self.limit = limit
+        self.timeout = timeout
+        self.ack_messages = ack_messages
+        self.tasks = set(str_to_list(tasks) or [])
+        self.queues = prepare_queues(queues)
+        self.callback = callback
+        self.forever = forever
+        self.on_declare_queue = on_declare_queue
+        self.consume_from = [
+            _maybe_queue(self.app, q)
+            for q in consume_from or list(self.queues)
+        ]
+        self.state = state or State()
+        self.accept = accept
+
+    def start(self):
+        # start migrating messages.
+        with self.prepare_consumer(self.create_consumer()):
+            try:
+                for _ in eventloop(self.conn,  # pragma: no cover
+                                   timeout=self.timeout,
+                                   ignore_timeouts=self.forever):
+                    pass
+            except socket.timeout:
+                pass
+            except StopFiltering:
+                pass
+        return self.state
+
+    def update_state(self, body, message):
+        self.state.count += 1
+        if self.limit and self.state.count >= self.limit:
             raise StopFiltering()
 
-    def ack_message(body, message):
+    def ack_message(self, body, message):
         message.ack()
 
-    consumer = app.amqp.TaskConsumer(conn, queues=consume_from, accept=accept)
-
-    if tasks:
-        filter = filter_callback(filter, tasks)
-        update_state = filter_callback(update_state, tasks)
-        ack_message = filter_callback(ack_message, tasks)
-
-    consumer.register_callback(filter)
-    consumer.register_callback(update_state)
-    if ack_messages:
-        consumer.register_callback(ack_message)
-    if callback is not None:
-        callback = partial(callback, state)
-        if tasks:
-            callback = filter_callback(callback, tasks)
-        consumer.register_callback(callback)
-
-    # declare all queues on the new broker.
-    for queue in consumer.queues:
-        if queues and queue.name not in queues:
-            continue
-        if on_declare_queue is not None:
-            on_declare_queue(queue)
-        try:
-            _, mcount, _ = queue(consumer.channel).queue_declare(passive=True)
-            if mcount:
-                state.total_apx += mcount
-        except conn.channel_errors:
-            pass
-
-    # start migrating messages.
-    with consumer:
-        try:
-            for _ in eventloop(conn,  # pragma: no cover
-                               timeout=timeout, ignore_timeouts=forever):
+    def create_consumer(self):
+        return self.app.amqp.TaskConsumer(
+            self.conn,
+            queues=self.consume_from,
+            accept=self.accept,
+        )
+
+    def prepare_consumer(self, consumer):
+        filter = self.filter
+        update_state = self.update_state
+        ack_message = self.ack_message
+        if self.tasks:
+            filter = filter_callback(filter, self.tasks)
+            update_state = filter_callback(update_state, self.tasks)
+            ack_message = filter_callback(ack_message, self.tasks)
+        consumer.register_callback(filter)
+        consumer.register_callback(update_state)
+        if self.ack_messages:
+            consumer.register_callback(self.ack_message)
+        if self.callback is not None:
+            callback = partial(self.callback, self.state)
+            if self.tasks:
+                callback = filter_callback(callback, self.tasks)
+            consumer.register_callback(callback)
+        self.declare_queues(consumer)
+        return consumer
+
+    def declare_queues(self, consumer):
+        # declare all queues on the new broker.
+        for queue in consumer.queues:
+            if self.queues and queue.name not in self.queues:
+                continue
+            if self.on_declare_queue is not None:
+                self.on_declare_queue(queue)
+            try:
+                _, mcount, _ = queue(
+                    consumer.channel).queue_declare(passive=True)
+                if mcount:
+                    self.state.total_apx += mcount
+            except self.conn.channel_errors:
                 pass
-        except socket.timeout:
-            pass
-        except StopFiltering:
-            pass
-    return state
+
+
+def start_filter(app, conn, filter, limit=None, timeout=1.0,
+                 ack_messages=False, tasks=None, queues=None,
+                 callback=None, forever=False, on_declare_queue=None,
+                 consume_from=None, state=None, accept=None, **kwargs):
+    """Filter tasks."""
+    return Filterer(
+        app, conn, filter,
+        limit=limit,
+        timeout=timeout,
+        ack_messages=ack_messages,
+        tasks=tasks,
+        queues=queues,
+        callback=callback,
+        forever=forever,
+        on_declare_queue=on_declare_queue,
+        consume_from=consume_from,
+        state=state,
+        accept=accept,
+        **kwargs).start()
 
 
 def move_task_by_id(task_id, dest, **kwargs):

+ 3 - 0
celery/events/state.py

@@ -509,6 +509,9 @@ class State(object):
         return self._event(dict(fields, type='-'.join(['worker', type_])))[0]
 
     def _create_dispatcher(self):
+        # noqa: C901
+        # pylint: disable=too-many-statements
+        # This code is highly optimized, but not for reusability.
         get_handler = self.handlers.__getitem__
         event_callback = self.event_callback
         wfields = itemgetter('hostname', 'timestamp', 'local_received')

+ 0 - 2
celery/local.py

@@ -564,8 +564,6 @@ def recreate_module(name, compat_modules=(), by_module={}, direct={},
 
 
 def get_compat_module(pkg, name):
-    from .local import Proxy
-
     def prepare(attr):
         if isinstance(attr, string_t):
             return Proxy(getappattr, (attr,))

+ 29 - 25
celery/platforms.py

@@ -534,31 +534,7 @@ def maybe_drop_privileges(uid=None, gid=None):
     gid = gid and parse_gid(gid)
 
     if uid:
-        # If GID isn't defined, get the primary GID of the user.
-        if not gid and pwd:
-            gid = pwd.getpwuid(uid).pw_gid
-        # Must set the GID before initgroups(), as setgid()
-        # is known to zap the group list on some platforms.
-
-        # setgid must happen before setuid (otherwise the setgid operation
-        # may fail because of insufficient privileges and possibly stay
-        # in a privileged group).
-        setgid(gid)
-        initgroups(uid, gid)
-
-        # at last:
-        setuid(uid)
-        # ... and make sure privileges cannot be restored:
-        try:
-            setuid(0)
-        except OSError as exc:
-            if exc.errno != errno.EPERM:
-                raise
-            # we should get here: cannot restore privileges,
-            # everything was fine.
-        else:
-            raise RuntimeError(
-                'non-root user able to restore privileges after setuid.')
+        _setuid(uid, gid)
     else:
         gid and setgid(gid)
 
@@ -568,6 +544,34 @@ def maybe_drop_privileges(uid=None, gid=None):
         raise SecurityError('Still root gid after drop privileges!')
 
 
+def _setuid(uid, gid):
+    # If GID isn't defined, get the primary GID of the user.
+    if not gid and pwd:
+        gid = pwd.getpwuid(uid).pw_gid
+    # Must set the GID before initgroups(), as setgid()
+    # is known to zap the group list on some platforms.
+
+    # setgid must happen before setuid (otherwise the setgid operation
+    # may fail because of insufficient privileges and possibly stay
+    # in a privileged group).
+    setgid(gid)
+    initgroups(uid, gid)
+
+    # at last:
+    setuid(uid)
+    # ... and make sure privileges cannot be restored:
+    try:
+        setuid(0)
+    except OSError as exc:
+        if exc.errno != errno.EPERM:
+            raise
+        # we should get here: cannot restore privileges,
+        # everything was fine.
+    else:
+        raise SecurityError(
+            'non-root user able to restore privileges after setuid.')
+
+
 class Signals(object):
     """Convenience interface to :mod:`signals`.
 

+ 103 - 1
celery/utils/collections.py

@@ -45,6 +45,7 @@ REPR_LIMITED_SET = """\
 
 
 def force_mapping(m):
+    # type: (Any) -> Mapping
     """Wrap object into supporting the mapping interface if necessary."""
     if isinstance(m, (LazyObject, LazySettings)):
         m = m._wrapped
@@ -52,6 +53,7 @@ def force_mapping(m):
 
 
 def lpmerge(L, R):
+    # type: (Mapping, Mapping) -> Mapping
     """In place left precedent dictionary merge.
 
     Keeps values from `L`, if the value in `R` is :const:`None`.
@@ -66,22 +68,26 @@ class OrderedDict(_OrderedDict):
 
     if PY3:  # pragma: no cover
         def _LRUkey(self):
+            # type: () -> Any
             # return value of od.keys does not support __next__,
             # but this version will also not create a copy of the list.
             return next(iter(keys(self)))
     else:
         if _dict_is_ordered:  # pragma: no cover
             def _LRUkey(self):
+                # type: () -> Any
                 # iterkeys is iterable.
                 return next(self.iterkeys())
         else:
             def _LRUkey(self):
+                # type: () -> Any
                 return self._OrderedDict__root[1][2]
 
     if not hasattr(_OrderedDict, 'move_to_end'):
         if _dict_is_ordered:  # pragma: no cover
 
             def move_to_end(self, key, last=True):
+                # type: (Any, bool) -> None
                 if not last:
                     # we don't use this argument, and the only way to
                     # implement this on PyPy seems to be O(n): creating a
@@ -92,6 +98,7 @@ class OrderedDict(_OrderedDict):
         else:
 
             def move_to_end(self, key, last=True):
+                # type: (Any, bool) -> None
                 link = self._OrderedDict__map[key]
                 link_prev = link[0]
                 link_next = link[1]
@@ -117,6 +124,7 @@ class AttributeDictMixin(object):
     """
 
     def __getattr__(self, k):
+        # type: (str) -> Any
         """`d.key -> d[key]`."""
         try:
             return self[k]
@@ -126,6 +134,7 @@ class AttributeDictMixin(object):
                     type(self).__name__, k))
 
     def __setattr__(self, key, value):
+        # type: (str, Any) -> None
         """`d[key] = value -> d.key = value`."""
         self[key] = value
 
@@ -144,49 +153,61 @@ class DictAttribute(object):
     obj = None
 
     def __init__(self, obj):
+        # type: (Any) -> None
         object.__setattr__(self, 'obj', obj)
 
     def __getattr__(self, key):
+        # type: (Any) -> Any
         return getattr(self.obj, key)
 
     def __setattr__(self, key, value):
+        # type: (Any, Any) -> None
         return setattr(self.obj, key, value)
 
     def get(self, key, default=None):
+        # type: (Any, Any) -> Any
         try:
             return self[key]
         except KeyError:
             return default
 
     def setdefault(self, key, default=None):
+        # type: (Any, Any) -> None
         if key not in self:
             self[key] = default
 
     def __getitem__(self, key):
+        # type: (Any) -> Any
         try:
             return getattr(self.obj, key)
         except AttributeError:
             raise KeyError(key)
 
     def __setitem__(self, key, value):
+        # type: (Any, Any) -> Any
         setattr(self.obj, key, value)
 
     def __contains__(self, key):
+        # type: (Any) -> bool
         return hasattr(self.obj, key)
 
     def _iterate_keys(self):
+        # type: () -> Iterable
         return iter(dir(self.obj))
     iterkeys = _iterate_keys
 
     def __iter__(self):
+        # type: () -> Iterable
         return self._iterate_keys()
 
     def _iterate_items(self):
+        # type: () -> Iterable
         for key in self._iterate_keys():
             yield key, getattr(self.obj, key)
     iteritems = _iterate_items
 
     def _iterate_values(self):
+        # type: () -> Iterable
         for key in self._iterate_keys():
             yield getattr(self.obj, key)
     itervalues = _iterate_values
@@ -198,12 +219,15 @@ class DictAttribute(object):
     else:
 
         def keys(self):
+            # type: () -> List[Any]
             return list(self)
 
         def items(self):
+            # type: () -> List[Tuple[Any, Any]]
             return list(self._iterate_items())
 
         def values(self):
+            # type: () -> List[Any]
             return list(self._iterate_values())
 MutableMapping.register(DictAttribute)
 
@@ -217,6 +241,7 @@ class ChainMap(MutableMapping):
     maps = None
 
     def __init__(self, *maps, **kwargs):
+        # type: (*Mapping, **Any) -> None
         maps = list(maps or [{}])
         self.__dict__.update(
             key_t=kwargs.get('key_t'),
@@ -226,11 +251,13 @@ class ChainMap(MutableMapping):
         )
 
     def add_defaults(self, d):
+        # type: (Mapping) -> None
         d = force_mapping(d)
         self.defaults.insert(0, d)
         self.maps.insert(1, d)
 
     def pop(self, key, *default):
+        # type: (Any, *Any) -> Any
         try:
             return self.maps[0].pop(key, *default)
         except KeyError:
@@ -238,12 +265,15 @@ class ChainMap(MutableMapping):
                 'Key not found in the first mapping: {!r}'.format(key))
 
     def __missing__(self, key):
+        # type: (Any) -> Any
         raise KeyError(key)
 
     def _key(self, key):
+        # type: (Any) -> Any
         return self.key_t(key) if self.key_t is not None else key
 
     def __getitem__(self, key):
+        # type: (Any) -> Any
         _key = self._key(key)
         for mapping in self.maps:
             try:
@@ -253,59 +283,72 @@ class ChainMap(MutableMapping):
         return self.__missing__(key)
 
     def __setitem__(self, key, value):
+        # type: (Any, Any) -> None
         self.changes[self._key(key)] = value
 
     def __delitem__(self, key):
+        # type: (Any) -> None
         try:
             del self.changes[self._key(key)]
         except KeyError:
             raise KeyError('Key not found in first mapping: {0!r}'.format(key))
 
     def clear(self):
+        # type: () -> None
         self.changes.clear()
 
     def get(self, key, default=None):
+        # type: (Any, Any) -> Any
         try:
             return self[self._key(key)]
         except KeyError:
             return default
 
     def __len__(self):
+        # type: () -> int
         return len(set().union(*self.maps))
 
     def __iter__(self):
         return self._iterate_keys()
 
     def __contains__(self, key):
+        # type: (Any) -> bool
         key = self._key(key)
         return any(key in m for m in self.maps)
 
     def __bool__(self):
+        # type: () -> bool
         return any(self.maps)
     __nonzero__ = __bool__  # Py2
 
     def setdefault(self, key, default=None):
+        # type: (Any, Any) -> None
         key = self._key(key)
         if key not in self:
             self[key] = default
 
     def update(self, *args, **kwargs):
+        # type: (*Any, **Any) -> Any
         return self.changes.update(*args, **kwargs)
 
     def __repr__(self):
+        # type: () -> str
         return '{0.__class__.__name__}({1})'.format(
             self, ', '.join(map(repr, self.maps)))
 
     @classmethod
     def fromkeys(cls, iterable, *args):
+        # type: (type, Iterable, *Any) -> 'ChainMap'
         """Create a ChainMap with a single dict created from the iterable."""
         return cls(dict.fromkeys(iterable, *args))
 
     def copy(self):
+        # type: () -> 'ChainMap'
         return self.__class__(self.maps[0].copy(), *self.maps[1:])
     __copy__ = copy  # Py2
 
     def _iter(self, op):
+        # type: (Callable) -> Iterable
         # defaults must be first in the stream, so values in
         # changes take precedence.
         # pylint: disable=bad-reversed-sequence
@@ -313,14 +356,17 @@ class ChainMap(MutableMapping):
         return chain(*[op(d) for d in reversed(self.maps)])
 
     def _iterate_keys(self):
+        # type: () -> Iterable
         return uniq(self._iter(lambda d: d.keys()))
     iterkeys = _iterate_keys
 
     def _iterate_items(self):
+        # type: () -> Iterable
         return ((key, self[key]) for key in self)
     iteritems = _iterate_items
 
     def _iterate_values(self):
+        # type: () -> Iterable
         return (self[key] for key in self)
     itervalues = _iterate_values
 
@@ -331,12 +377,15 @@ class ChainMap(MutableMapping):
 
     else:  # noqa
         def keys(self):
+            # type: () -> List[Any]
             return list(self._iterate_keys())
 
         def items(self):
+            # type: () -> List[Tuple[Any, Any]]
             return list(self._iterate_items())
 
         def values(self):
+            # type: () -> List[Any]
             return list(self._iterate_values())
 
 
@@ -356,6 +405,7 @@ class ConfigurationView(ChainMap, AttributeDictMixin):
     """
 
     def __init__(self, changes, defaults=None, keys=None, prefix=None):
+        # type: (Mapping, Mapping, List[str], str) -> None
         defaults = [] if defaults is None else defaults
         super(ConfigurationView, self).__init__(changes, *defaults)
         self.__dict__.update(
@@ -364,6 +414,7 @@ class ConfigurationView(ChainMap, AttributeDictMixin):
         )
 
     def _to_keys(self, key):
+        # type: (str) -> Sequence[str]
         prefix = self.prefix
         if prefix:
             pkey = prefix + key if not key.startswith(prefix) else key
@@ -371,6 +422,7 @@ class ConfigurationView(ChainMap, AttributeDictMixin):
         return key,
 
     def __getitem__(self, key):
+        # type: (str) -> Any
         keys = self._to_keys(key)
         getitem = super(ConfigurationView, self).__getitem__
         for k in keys + (
@@ -389,26 +441,32 @@ class ConfigurationView(ChainMap, AttributeDictMixin):
             raise
 
     def __setitem__(self, key, value):
+        # type: (str, Any) -> Any
         self.changes[self._key(key)] = value
 
     def first(self, *keys):
+        # type: (*str) -> Any
         return first(None, (self.get(key) for key in keys))
 
     def get(self, key, default=None):
+        # type: (str, Any) -> Any
         try:
             return self[key]
         except KeyError:
             return default
 
     def clear(self):
+        # type: () -> None
         """Remove all changes, but keep defaults."""
         self.changes.clear()
 
     def __contains__(self, key):
+        # type: (str) -> bool
         keys = self._to_keys(key)
         return any(any(k in m for k in keys) for m in self.maps)
 
     def swap_with(self, other):
+        # type: (ConfigurationView) -> None
         changes = other.__dict__['changes']
         defaults = other.__dict__['defaults']
         self.__dict__.update(
@@ -478,6 +536,7 @@ class LimitedSet(object):
     max_heap_percent_overload = 15
 
     def __init__(self, maxlen=0, expires=0, data=None, minlen=0):
+        # type: (int, float, Mapping, int) -> None
         self.maxlen = 0 if maxlen is None else maxlen
         self.minlen = 0 if minlen is None else minlen
         self.expires = 0 if expires is None else expires
@@ -495,20 +554,24 @@ class LimitedSet(object):
             raise ValueError('expires cannot be negative!')
 
     def _refresh_heap(self):
+        # type: () -> None
         """Time consuming recreating of heap.  Don't run this too often."""
         self._heap[:] = [entry for entry in values(self._data)]
         heapify(self._heap)
 
     def _maybe_refresh_heap(self):
+        # type: () -> None
         if self._heap_overload >= self.max_heap_percent_overload:
             self._refresh_heap()
 
     def clear(self):
+        # type: () -> None
         """Clear all data, start from scratch again."""
         self._data.clear()
         self._heap[:] = []
 
     def add(self, item, now=None):
+        # type: (Any, float) -> None
         """Add a new item, or reset the expiry time of an existing item."""
         now = now or time.time()
         if item in self._data:
@@ -520,6 +583,7 @@ class LimitedSet(object):
             self.purge()
 
     def update(self, other):
+        # type: (Iterable) -> None
         """Update this set from other LimitedSet, dict or iterable."""
         if not other:
             return
@@ -546,12 +610,14 @@ class LimitedSet(object):
                 self.add(obj)
 
     def discard(self, item):
+        # type: (Any) -> None
         # mark an existing item as removed.  If KeyError is not found, pass.
         self._data.pop(item, None)
         self._maybe_refresh_heap()
     pop_value = discard
 
     def purge(self, now=None):
+        # type: (float) -> None
         """Check oldest items and remove them if needed.
 
         Arguments:
@@ -572,6 +638,7 @@ class LimitedSet(object):
                 self.pop()
 
     def pop(self, default=None):
+        # type: (Any) -> Any
         """Remove and return the oldest item, or :const:`None` when empty."""
         while self._heap:
             _, item = heappop(self._heap)
@@ -584,6 +651,7 @@ class LimitedSet(object):
         return default
 
     def as_dict(self):
+        # type: () -> Dict
         """Whole set as serializable dictionary.
 
         Example:
@@ -599,35 +667,44 @@ class LimitedSet(object):
         return {key: inserted for inserted, key in values(self._data)}
 
     def __eq__(self, other):
+        # type: (Any) -> bool
         return self._data == other._data
 
     def __ne__(self, other):
+        # type: (Any) -> bool
         return not self.__eq__(other)
 
     def __repr__(self):
+        # type: () -> str
         return REPR_LIMITED_SET.format(
             self, name=type(self).__name__, size=len(self),
         )
 
     def __iter__(self):
+        # type: () -> Iterable
         return (i for _, i in sorted(values(self._data)))
 
     def __len__(self):
+        # type: () -> int
         return len(self._data)
 
     def __contains__(self, key):
+        # type: (Any) -> bool
         return key in self._data
 
     def __reduce__(self):
+        # type: () -> Any
         return self.__class__, (
             self.maxlen, self.expires, self.as_dict(), self.minlen)
 
     def __bool__(self):
+        # type: () -> bool
         return bool(self._data)
     __nonzero__ = __bool__  # Py2
 
     @property
     def _heap_overload(self):
+        # type: () -> float
         """Compute how much is heap bigger than data [percents]."""
         return len(self._heap) * 100 / max(len(self._data), 1) - 100
 MutableSet.register(LimitedSet)
@@ -639,16 +716,19 @@ class Evictable(object):
     Empty = Empty
 
     def evict(self):
+        # type: () -> None
         """Force evict until maxsize is enforced."""
         self._evict(range=count)
 
-    def _evict(self, limit=100, range=range):
+    def _evict(self, limit=100):
+        # type: (int) -> None
         try:
             [self._evict1() for _ in range(limit)]
         except IndexError:
             pass
 
     def _evict1(self):
+        # type: () -> None
         if self._evictcount <= self.maxsize:
             raise IndexError()
         try:
@@ -664,6 +744,7 @@ class Messagebuffer(Evictable):
     Empty = Empty
 
     def __init__(self, maxsize, iterable=None, deque=deque):
+        # type: (int, Iterable, Any) -> None
         self.maxsize = maxsize
         self.data = deque(iterable or [])
         self._append = self.data.append
@@ -672,14 +753,17 @@ class Messagebuffer(Evictable):
         self._extend = self.data.extend
 
     def put(self, item):
+        # type: (Any) -> None
         self._append(item)
         self.maxsize and self._evict()
 
     def extend(self, it):
+        # type: (Iterable) -> None
         self._extend(it)
         self.maxsize and self._evict()
 
     def take(self, *default):
+        # type: (*Any) -> Any
         try:
             return self._pop()
         except IndexError:
@@ -688,14 +772,17 @@ class Messagebuffer(Evictable):
             raise self.Empty()
 
     def _pop_to_evict(self):
+        # type: () -> None
         return self.take()
 
     def __repr__(self):
+        # type: () -> str
         return '<{0}: {1}/{2}>'.format(
             type(self).__name__, len(self), self.maxsize,
         )
 
     def __iter__(self):
+        # type: () -> Iterable
         while 1:
             try:
                 yield self._pop()
@@ -703,19 +790,24 @@ class Messagebuffer(Evictable):
                 break
 
     def __len__(self):
+        # type: () -> int
         return self._len()
 
     def __contains__(self, item):
+        # type: () -> bool
         return item in self.data
 
     def __reversed__(self):
+        # type: () -> Iterable
         return reversed(self.data)
 
     def __getitem__(self, index):
+        # type: (Any) -> Any
         return self.data[index]
 
     @property
     def _evictcount(self):
+        # type: () -> int
         return len(self)
 Sequence.register(Messagebuffer)
 
@@ -732,6 +824,7 @@ class BufferMap(OrderedDict, Evictable):
     bufmaxsize = None
 
     def __init__(self, maxsize, iterable=None, bufmaxsize=1000):
+        # type: (int, Iterable, int) -> None
         super(BufferMap, self).__init__()
         self.maxsize = maxsize
         self.bufmaxsize = 1000
@@ -740,17 +833,20 @@ class BufferMap(OrderedDict, Evictable):
         self.total = sum(len(buf) for buf in items(self))
 
     def put(self, key, item):
+        # type: (Any, Any) -> None
         self._get_or_create_buffer(key).put(item)
         self.total += 1
         self.move_to_end(key)   # least recently used.
         self.maxsize and self._evict()
 
     def extend(self, key, it):
+        # type: (Any, Iterable) -> None
         self._get_or_create_buffer(key).extend(it)
         self.total += len(it)
         self.maxsize and self._evict()
 
     def take(self, key, *default):
+        # type: (Any, *Any) -> Any
         item, throw = None, False
         try:
             buf = self[key]
@@ -772,6 +868,7 @@ class BufferMap(OrderedDict, Evictable):
         return item
 
     def _get_or_create_buffer(self, key):
+        # type: (Any) -> Messagebuffer
         try:
             return self[key]
         except KeyError:
@@ -779,12 +876,15 @@ class BufferMap(OrderedDict, Evictable):
             return buf
 
     def _new_buffer(self):
+        # type: () -> Messagebuffer
         return self.Buffer(maxsize=self.bufmaxsize)
 
     def _LRUpop(self, *default):
+        # type: (*Any) -> Any
         return self[self._LRUkey()].take(*default)
 
     def _pop_to_evict(self):
+        # type: () -> None
         for _ in range(100):
             key = self._LRUkey()
             buf = self[key]
@@ -805,10 +905,12 @@ class BufferMap(OrderedDict, Evictable):
                 break
 
     def __repr__(self):
+        # type: () -> str
         return '<{0}: {1}/{2}>'.format(
             type(self).__name__, self.total, self.maxsize,
         )
 
     @property
     def _evictcount(self):
+        # type: () -> int
         return self.total

+ 15 - 0
celery/utils/text.py

@@ -29,6 +29,7 @@ RE_FORMAT = re.compile(r'%(\w)')
 
 
 def str_to_list(s):
+    # type: (str) -> List[str]
     """Convert string to list."""
     if isinstance(s, string_t):
         return s.split(',')
@@ -36,32 +37,38 @@ def str_to_list(s):
 
 
 def dedent_initial(s, n=4):
+    # type: (str, int) -> str
     """Remove identation from first line of text."""
     return s[n:] if s[:n] == ' ' * n else s
 
 
 def dedent(s, n=4, sep='\n'):
+    # type: (str, int, str) -> str
     """Remove identation."""
     return sep.join(dedent_initial(l) for l in s.splitlines())
 
 
 def fill_paragraphs(s, width, sep='\n'):
+    # type: (str, int, str) -> str
     """Fill paragraphs with newlines (or custom separator)."""
     return sep.join(fill(p, width) for p in s.split(sep))
 
 
 def join(l, sep='\n'):
+    # type: (str, str) -> str
     """Concatenate list of strings."""
     return sep.join(v for v in l if v)
 
 
 def ensure_sep(sep, s, n=2):
+    # type: (str, str, int) -> str
     """Ensure text s ends in separator sep'."""
     return s + sep * (n - s.count(sep))
 ensure_newlines = partial(ensure_sep, '\n')
 
 
 def abbr(S, max, ellipsis='...'):
+    # type: (str, int, str) -> str
     """Abbreviate word."""
     if S is None:
         return '???'
@@ -71,6 +78,7 @@ def abbr(S, max, ellipsis='...'):
 
 
 def abbrtask(S, max):
+    # type: (str, int) -> str
     """Abbreviate task name."""
     if S is None:
         return '???'
@@ -82,11 +90,13 @@ def abbrtask(S, max):
 
 
 def indent(t, indent=0, sep='\n'):
+    # type: (str, int, str) -> str
     """Indent text."""
     return sep.join(' ' * indent + p for p in t.split(sep))
 
 
 def truncate(s, maxlen=128, suffix='...'):
+    # type: (str, int, str) -> str
     """Truncate text to a maximum number of characters."""
     if maxlen and len(s) >= maxlen:
         return s[:maxlen].rsplit(' ', 1)[0] + suffix
@@ -94,12 +104,14 @@ def truncate(s, maxlen=128, suffix='...'):
 
 
 def truncate_bytes(s, maxlen=128, suffix=b'...'):
+    # type: (bytes, int, bytes) -> bytes
     if maxlen and len(s) >= maxlen:
         return s[:maxlen].rsplit(b' ', 1)[0] + suffix
     return s
 
 
 def pluralize(n, text, suffix='s'):
+    # type: (int, str, str) -> str
     """Pluralize term when n is greater than one."""
     if n != 1:
         return text + suffix
@@ -107,6 +119,7 @@ def pluralize(n, text, suffix='s'):
 
 
 def pretty(value, width=80, nl_width=80, sep='\n', **kw):
+    # type: (str, int, int, str, **Any) -> str
     """Format value for printing to console."""
     if isinstance(value, dict):
         return '{{{0} {1}'.format(sep, pformat(value, 4, nl_width)[1:])
@@ -119,10 +132,12 @@ def pretty(value, width=80, nl_width=80, sep='\n', **kw):
 
 
 def match_case(s, other):
+    # type: (str, str) -> str
     return s.upper() if other.isupper() else s.lower()
 
 
 def simple_format(s, keys, pattern=RE_FORMAT, expand=r'\1'):
+    # type: (str, Mapping[str, str], Pattern, str) -> str
     """Format string, expanding abbreviations in keys'."""
     if s:
         keys.setdefault('%', '%')

+ 3 - 3
setup.py

@@ -116,15 +116,15 @@ def parse_dist_meta():
     pats = {re_meta: add_default, re_doc: add_doc}
     here = os.path.abspath(os.path.dirname(__file__))
     with open(os.path.join(here, 'celery', '__init__.py')) as meta_fh:
-        meta = {}
+        distmeta = {}
         for line in meta_fh:
             if line.strip() == '# -eof meta-':
                 break
             for pattern, handler in pats.items():
                 m = pattern.match(line.strip())
                 if m:
-                    meta.update(handler(m))
-        return meta
+                    distmeta.update(handler(m))
+        return distmeta
 
 # -*- Installation Requires -*-
 

+ 1 - 1
t/unit/utils/test_platforms.py

@@ -302,7 +302,7 @@ class test_maybe_drop_privileges:
         setuid.assert_has_calls([call(5001), call(0)])
 
         setuid.side_effect = None
-        with pytest.raises(RuntimeError):
+        with pytest.raises(SecurityError):
             maybe_drop_privileges(uid='user', gid='group')
         setuid.side_effect = OSError()
         setuid.side_effect.errno = errno.EINVAL