Selaa lähdekoodia

Merge branch '3.0'

Conflicts:
	celery/bin/celery.py
	docs/userguide/tasks.rst
Ask Solem 12 vuotta sitten
vanhempi
commit
5e302ebd79

+ 2 - 2
celery/backends/base.py

@@ -87,8 +87,8 @@ class BaseBackend(object):
         return self.store_result(task_id, exc, status=states.RETRY,
                                  traceback=traceback)
 
-    def mark_as_revoked(self, task_id):
-        return self.store_result(task_id, TaskRevokedError(),
+    def mark_as_revoked(self, task_id, reason=''):
+        return self.store_result(task_id, TaskRevokedError(reason),
                                  status=states.REVOKED, traceback=None)
 
     def prepare_exception(self, exc):

+ 8 - 24
celery/bin/base.py

@@ -71,7 +71,7 @@ from types import ModuleType
 
 import celery
 from celery.exceptions import CDeprecationWarning, CPendingDeprecationWarning
-from celery.platforms import EX_FAILURE, EX_USAGE
+from celery.platforms import EX_FAILURE, EX_USAGE, maybe_patch_concurrency
 from celery.utils import text
 from celery.utils.imports import symbol_by_name, import_from_cwd
 
@@ -168,9 +168,7 @@ class Command(object):
         if argv is None:
             argv = list(sys.argv)
         # Should we load any special concurrency environment?
-        pool_option = self.with_pool_option(argv)
-        if pool_option:
-            self.maybe_patch_concurrency(argv, *pool_option)
+        self.maybe_patch_concurrency(argv)
         self.on_concurrency_setup()
 
         # Dump version and exit if '--version' arg set.
@@ -179,26 +177,12 @@ class Command(object):
         prog_name = os.path.basename(argv[0])
         return self.handle_argv(prog_name, argv[1:])
 
-    def _find_option_with_arg(self, argv, short_opts=None, long_opts=None):
-        for i, arg in enumerate(argv):
-            if arg.startswith('-'):
-                if long_opts and arg.startswith('--'):
-                    name, _, val = arg.partition('=')
-                    if name in long_opts:
-                        return val
-                if short_opts and arg in short_opts:
-                    return argv[i + 1]
-        raise KeyError('|'.join(short_opts or [] + long_opts or []))
-
-    def maybe_patch_concurrency(self, argv, short_opts=None, long_opts=None):
-        try:
-            pool = self._find_option_with_arg(argv, short_opts, long_opts)
-        except KeyError:
-            pass
-        else:
-            from celery import concurrency
-            # set up eventlet/gevent environments ASAP.
-            concurrency.get_implementation(pool)
+    def maybe_patch_concurrency(self, argv=None):
+        argv = argv or sys.argv
+        pool_option = self.with_pool_option(argv)
+        if pool_option:
+            maybe_patch_concurrency(argv, *pool_option)
+            short_opts, long_opts = pool_option
 
     def on_concurrency_setup(self):
         pass

+ 9 - 3
celery/bin/celery.py

@@ -8,12 +8,15 @@ The :program:`celery` umbrella command.
 """
 from __future__ import absolute_import, print_function
 
-import anyjson
 import sys
+from celery.platforms import maybe_patch_concurrency
+maybe_patch_concurrency(sys.argv, ['-P'], ['--pool'])
+
+import anyjson
 import warnings
 
-from billiard import freeze_support
 from future_builtins import map
+
 from importlib import import_module
 from pprint import pformat
 
@@ -936,8 +939,11 @@ def main():
     try:
         if __name__ != '__main__':  # pragma: no cover
             sys.modules['__main__'] = sys.modules[__name__]
+        cmd = CeleryCommand()
+        cmd.maybe_patch_concurrency()
+        from billiard import freeze_support
         freeze_support()
-        CeleryCommand().execute_from_commandline()
+        cmd.execute_from_commandline()
     except KeyboardInterrupt:
         pass
 

+ 3 - 0
celery/bin/celeryd.py

@@ -117,6 +117,9 @@ The :program:`celery worker` command (previously known as ``celeryd``)
 from __future__ import absolute_import
 
 import sys
+import sys
+from celery.platforms import maybe_patch_concurrency
+maybe_patch_concurrency(sys.argv, ['-P'], ['--pool'])
 
 from billiard import freeze_support
 

+ 1 - 11
celery/concurrency/processes/__init__.py

@@ -12,8 +12,6 @@
 from __future__ import absolute_import
 
 import os
-import platform
-import signal as _signal
 
 from celery import platforms
 from celery import signals
@@ -22,14 +20,6 @@ from celery.concurrency.base import BasePool
 from celery.task import trace
 from billiard.pool import Pool, RUN, CLOSE
 
-if platform.system() == 'Windows':  # pragma: no cover
-    # On Windows os.kill calls TerminateProcess which cannot be
-    # handled by # any process, so this is needed to terminate the task
-    # *and its children* (if any).
-    from ._win import kill_processtree as _kill  # noqa
-else:
-    from os import kill as _kill                 # noqa
-
 #: List of signals to reset when a child process starts.
 WORKER_SIGRESET = frozenset(['SIGTERM',
                              'SIGHUP',
@@ -109,7 +99,7 @@ class TaskPool(BasePool):
             self._pool.close()
 
     def terminate_job(self, pid, signal=None):
-        _kill(pid, signal or _signal.SIGTERM)
+        return self._pool.terminate_job(pid, signal)
 
     def grow(self, n=1):
         return self._pool.grow(n)

+ 0 - 116
celery/concurrency/processes/_win.py

@@ -1,116 +0,0 @@
-# -*- coding: utf-8 -*-
-"""
-    celery.concurrency.processes._win
-    ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
-
-    Windows utilities to terminate process groups.
-
-"""
-from __future__ import absolute_import
-
-import os
-
-# psutil is painfully slow in win32. So to avoid adding big
-# dependencies like pywin32 a ctypes based solution is preferred
-
-# Code based on the winappdbg project http://winappdbg.sourceforge.net/
-# (BSD License)
-from ctypes import (
-    byref, sizeof, windll,
-    Structure, WinError, POINTER,
-    c_size_t, c_char, c_void_p,
-)
-from ctypes.wintypes import DWORD, LONG
-
-ERROR_NO_MORE_FILES = 18
-INVALID_HANDLE_VALUE = c_void_p(-1).value
-
-
-class PROCESSENTRY32(Structure):
-    _fields_ = [
-        ('dwSize',              DWORD),
-        ('cntUsage',            DWORD),
-        ('th32ProcessID',       DWORD),
-        ('th32DefaultHeapID',   c_size_t),
-        ('th32ModuleID',        DWORD),
-        ('cntThreads',          DWORD),
-        ('th32ParentProcessID', DWORD),
-        ('pcPriClassBase',      LONG),
-        ('dwFlags',             DWORD),
-        ('szExeFile',           c_char * 260),
-    ]
-LPPROCESSENTRY32 = POINTER(PROCESSENTRY32)
-
-
-def CreateToolhelp32Snapshot(dwFlags=2, th32ProcessID=0):
-    hSnapshot = windll.kernel32.CreateToolhelp32Snapshot(dwFlags,
-                                                         th32ProcessID)
-    if hSnapshot == INVALID_HANDLE_VALUE:
-        raise WinError()
-    return hSnapshot
-
-
-def Process32First(hSnapshot, pe=None):
-    return _Process32n(windll.kernel32.Process32First, hSnapshot, pe)
-
-
-def Process32Next(hSnapshot, pe=None):
-    return _Process32n(windll.kernel32.Process32Next, hSnapshot, pe)
-
-
-def _Process32n(fun, hSnapshot, pe=None):
-    if pe is None:
-        pe = PROCESSENTRY32()
-    pe.dwSize = sizeof(PROCESSENTRY32)
-    success = fun(hSnapshot, byref(pe))
-    if not success:
-        if windll.kernel32.GetLastError() == ERROR_NO_MORE_FILES:
-            return
-        raise WinError()
-    return pe
-
-
-def get_all_processes_pids():
-    """Return a dictionary with all processes pids as keys and their
-       parents as value. Ignore processes with no parents.
-    """
-    h = CreateToolhelp32Snapshot()
-    parents = {}
-    pe = Process32First(h)
-    while pe:
-        if pe.th32ParentProcessID:
-            parents[pe.th32ProcessID] = pe.th32ParentProcessID
-        pe = Process32Next(h, pe)
-
-    return parents
-
-
-def get_processtree_pids(pid, include_parent=True):
-    """Return a list with all the pids of a process tree"""
-    parents = get_all_processes_pids()
-    all_pids = parents.keys()
-    pids = set([pid])
-    while 1:
-        pids_new = pids.copy()
-
-        for _pid in all_pids:
-            if parents[_pid] in pids:
-                pids_new.add(_pid)
-
-        if pids_new == pids:
-            break
-
-        pids = pids_new.copy()
-
-    if not include_parent:
-        pids.remove(pid)
-
-    return list(pids)
-
-
-def kill_processtree(pid, signum):
-    """Kill a process and all its descendants"""
-    family_pids = get_processtree_pids(pid)
-
-    for _pid in family_pids:
-        os.kill(_pid, signum)

+ 24 - 1
celery/platforms.py

@@ -22,7 +22,6 @@ from future_builtins import map
 
 from .local import try_import
 
-from billiard import current_process
 from kombu.utils.limits import TokenBucket
 
 _setproctitle = try_import('setproctitle')
@@ -67,6 +66,29 @@ def pyimplementation():
         return 'CPython'
 
 
+def _find_option_with_arg(argv, short_opts=None, long_opts=None):
+    for i, arg in enumerate(argv):
+        if arg.startswith('-'):
+            if long_opts and arg.startswith('--'):
+                name, _, val = arg.partition('=')
+                if name in long_opts:
+                    return val
+            if short_opts and arg in short_opts:
+                return argv[i + 1]
+    raise KeyError('|'.join(short_opts or [] + long_opts or []))
+
+
+def maybe_patch_concurrency(argv, short_opts=None, long_opts=None):
+    try:
+        pool = _find_option_with_arg(argv, short_opts, long_opts)
+    except KeyError:
+        pass
+    else:
+        # set up eventlet/gevent environments ASAP.
+        from celery import concurrency
+        concurrency.get_implementation(pool)
+
+
 class LockFailed(Exception):
     """Raised if a pidlock can't be acquired."""
     pass
@@ -591,6 +613,7 @@ else:
 
         """
         if not rate_limit or _setps_bucket.can_consume(1):
+            from billiard import current_process
             if hostname:
                 progname = '{0}@{1}'.format(progname, hostname.split('.')[0])
             return set_process_title(

+ 8 - 2
celery/result.py

@@ -74,14 +74,20 @@ class AsyncResult(ResultBase):
         """Forget about (and possibly remove the result of) this task."""
         self.backend.forget(self.id)
 
-    def revoke(self, connection=None):
+    def revoke(self, connection=None, terminate=False, signal=None):
         """Send revoke signal to all workers.
 
         Any worker receiving the task, or having reserved the
         task, *must* ignore it.
 
+        :keyword terminate: Also terminate the process currently working
+            on the task (if any).
+        :keyword signal: Name of signal to send to process if terminate.
+            Default is TERM.
+
         """
-        self.app.control.revoke(self.id, connection=connection)
+        self.app.control.revoke(self.id, connection=connection,
+                                terminate=terminate, signal=signal)
 
     def get(self, timeout=None, propagate=True, interval=0.5):
         """Wait until task is ready, and return its result.

+ 2 - 1
celery/utils/__init__.py

@@ -10,7 +10,6 @@ from __future__ import absolute_import, print_function
 
 import os
 import sys
-import threading
 import traceback
 import warnings
 import types
@@ -146,6 +145,8 @@ def cry():  # pragma: no cover
     From https://gist.github.com/737056
 
     """
+    import threading
+
     tmap = {}
     main_thread = None
     # get a map of threads by their ID so we can print their names

+ 13 - 9
celery/worker/job.py

@@ -243,11 +243,20 @@ class Request(object):
         if self.time_start:
             signal = _signals.signum(signal or 'TERM')
             pool.terminate_job(self.worker_pid, signal)
-            send_revoked(self.task, signum=signal,
-                         terminated=True, expired=False)
+            self._announce_revoked('terminated', True, signal, False)
         else:
             self._terminate_on_ack = pool, signal
 
+    def _announce_revoked(self, reason, terminated, signum, expired):
+        self.send_event('task-revoked', uuid=self.id,
+                        terminated=terminated, signum=signum, expired=expired)
+        if self.store_errors:
+            self.task.backend.mark_as_revoked(self.id, reason)
+        self.acknowledge()
+        self._already_revoked = True
+        send_revoked(self.task, terminated=terminated,
+                     signum=signum, expired=expired)
+
     def revoked(self):
         """If revoked, skip task and mark state."""
         expired = False
@@ -257,13 +266,8 @@ class Request(object):
             expired = self.maybe_expire()
         if self.id in revoked_tasks:
             warn('Skipping revoked task: %s[%s]', self.name, self.id)
-            self.send_event('task-revoked', uuid=self.id)
-            if self.store_errors:
-                self.task.backend.mark_as_revoked(self.id)
-            self.acknowledge()
-            self._already_revoked = True
-            send_revoked(self.task, terminated=False,
-                         signum=None, expired=expired)
+            self._announce_revoked('expired' if expired else 'revoked',
+                False, None, expired)
             return True
         return False
 

+ 6 - 1
docs/userguide/monitoring.rst

@@ -642,11 +642,16 @@ Task Events
 
     Sent if the execution of the task failed.
 
-* ``task-revoked(uuid)``
+* ``task-revoked(uuid, terminated, signum, expired)``
 
     Sent if the task has been revoked (Note that this is likely
     to be sent by more than one worker).
 
+    - ``terminated`` is set to true if the task process was terminated,
+      and the ``signum`` field set to the signal used.
+
+    - ``expired`` is set to true if the task expired.
+
 * ``task-retried(uuid, exception, traceback, hostname, timestamp)``
 
     Sent if the task failed, but will be retried in the future.

+ 14 - 1
docs/userguide/tasks.rst

@@ -230,7 +230,20 @@ An example task accessing information in the context is:
     @celery.task()
     def dump_context(x, y):
         print('Executing task id {0.id}, args: {0.args!r} kwargs: {0.kwargs!r}'.format(
-                add.request))
+                dump_context.request))
+
+
+:data:`~celery.current_task` can also be used:
+
+.. code-block:: python
+
+    from celery import current_task
+
+    @celery.task()
+    def dump_context(x, y):
+        print('Executing task id {0.id}, args: {0.args!r} kwargs: {0.kwargs!r}'.format(
+                current_task.request))
+
 
 .. _task-logging:
 

+ 0 - 1
extra/release/doc4allmods

@@ -7,7 +7,6 @@ SKIP_FILES="celery.__compat__.rst
             celery.bin.rst
             celery.bin.celeryd_detach.rst
             celery.bin.celeryctl.rst
-            celery.concurrency.processes._win.rst
             celery.contrib.rst
             celery.contrib.bundles.rst
             celery.local.rst