
No longer depends on multiprocessing backport, only billiard

Ask Solem 13 years ago
parent
commit 703ae4b0b7

+ 2 - 2
celery/app/__init__.py

@@ -54,13 +54,13 @@ def _app_or_default(app=None):
 
 def _app_or_default_trace(app=None):  # pragma: no cover
     from traceback import print_stack
-    from multiprocessing import current_process
+    from celery.utils.mp import get_process_name
     if app is None:
         if getattr(state._tls, "current_app", None):
             print("-- RETURNING TO CURRENT APP --")  # noqa+
             print_stack()
             return state._tls.current_app
-        if current_process()._name == "MainProcess":
+        if get_process_name() == "MainProcess":
             raise Exception("DEFAULT APP")
         print("-- RETURNING TO DEFAULT APP --")      # noqa+
         print_stack()

+ 1 - 18
celery/apps/worker.py

@@ -3,10 +3,6 @@ from __future__ import absolute_import
 
 import atexit
 import logging
-try:
-    import multiprocessing
-except ImportError:
-    multiprocessing = None  # noqa
 import os
 import socket
 import sys
@@ -19,6 +15,7 @@ from celery.exceptions import ImproperlyConfigured, SystemTerminate
 from celery.utils import cry, isatty
 from celery.utils.imports import qualname
 from celery.utils.log import LOG_LEVELS, get_logger, mlevel
+from celery.utils.mp import cpu_count, get_process_name
 from celery.utils.text import pluralize
 from celery.worker import WorkController
 
@@ -60,20 +57,6 @@ enable the CELERY_CREATE_MISSING_QUEUES setting.
 """
 
 
-def cpu_count():
-    if multiprocessing is not None:
-        try:
-            return multiprocessing.cpu_count()
-        except NotImplementedError:
-            pass
-    return 2
-
-
-def get_process_name():
-    if multiprocessing is not None:
-        return multiprocessing.current_process().name
-
-
 class Worker(configurated):
     WorkController = WorkController
 

+ 3 - 6
celery/beat.py

@@ -18,10 +18,6 @@ import shelve
 import sys
 import threading
 import traceback
-try:
-    import multiprocessing
-except ImportError:
-    multiprocessing = None  # noqa
 
 from kombu.utils import reprcall
 from kombu.utils.functional import maybe_promise
@@ -36,6 +32,7 @@ from .utils import cached_property
 from .utils.imports import instantiate
 from .utils.timeutils import humanize_seconds
 from .utils.log import get_logger
+from .utils.mp import Process
 
 logger = get_logger(__name__)
 debug, info, error = logger.debug, logger.info, logger.error
@@ -449,9 +446,9 @@ class _Threaded(threading.Thread):
         self.service.stop(wait=True)
 
 
-if multiprocessing is not None:
+if Process is not None:
 
-    class _Process(multiprocessing.Process):
+    class _Process(Process):
         """Embedded task scheduler using multiprocessing."""
 
         def __init__(self, *args, **kwargs):
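A hedged sketch of what this guard enables (not code from the commit; start_embedded and run are illustrative names): when neither billiard nor the stdlib multiprocessing module is importable, Process is None and an embedded scheduler can degrade to a plain thread instead of a child process.

import threading
from celery.utils.mp import Process

def start_embedded(run, thread=False):
    # "run" is a hypothetical callable driving the scheduler loop.
    if thread or Process is None:
        runner = threading.Thread(target=run)
        runner.daemon = True
    else:
        runner = Process(target=run)
    runner.start()
    return runner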

+ 1 - 8
celery/bin/celeryd.py

@@ -75,14 +75,7 @@ from __future__ import absolute_import
 
 import sys
 
-try:
-    import multiprocessing  # noqa
-except ImportError:  # pragma: no cover
-    freeze_support = lambda: True
-else:
-    # patch with freeze_support from billiard
-    from billiard import freeze_support  # noqa
-
+from celery.utils.mp import freeze_support
 from celery.bin.base import Command, Option
 
 

+ 1 - 2
celery/concurrency/processes/__init__.py

@@ -10,8 +10,7 @@ from celery import platforms
 from celery import signals
 from celery.app import app_or_default
 from celery.concurrency.base import BasePool
-
-from billiard.pool import Pool, RUN
+from celery.utils.mp import Pool, RUN
 
 if platform.system() == "Windows":  # pragma: no cover
     # On Windows os.kill calls TerminateProcess which cannot be

+ 3 - 2
celery/contrib/rdb.py

@@ -67,8 +67,9 @@ class Rdb(Pdb):
         self.active = True
 
         try:
-            from multiprocessing import current_process
-            _, port_skew = current_process().name.split('-')
+            from celery.utils.mp import current_process
+            if current_process:
+                _, port_skew = current_process().name.split('-')
         except (ImportError, ValueError):
             pass
         port_skew = int(port_skew)
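A hedged illustration of the port-skew logic above (base_port and the printed values are hypothetical, not from the commit): the remote debugger offsets its listening port by the worker's index, parsed from the "Name-N" process name, so concurrent pool workers do not collide on the same port.

from celery.utils.mp import current_process

base_port = 6900            # hypothetical base port for this example
port_skew = 0
if current_process:         # None when no multiprocessing backend is installed
    try:
        _, port_skew = current_process().name.split('-')
    except ValueError:      # e.g. "MainProcess" carries no "-N" suffix
        pass
print("debugger port:", base_port + int(port_skew))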

+ 1 - 5
celery/platforms.py

@@ -23,6 +23,7 @@ import sys
 from .local import try_import
 
 from kombu.utils.limits import TokenBucket
+from celery.utils.mp import current_process
 
 _setproctitle = try_import("setproctitle")
 resource = try_import("resource")
@@ -34,11 +35,6 @@ EX_FAILURE = 1
 EX_UNAVAILABLE = getattr(os, "EX_UNAVAILABLE", 69)
 EX_USAGE = getattr(os, "EX_USAGE", 64)
 
-try:
-    from multiprocessing.process import current_process
-except ImportError:
-    current_process = None  # noqa
-
 SYSTEM = _platform.system()
 IS_OSX = SYSTEM == "Darwin"
 IS_WINDOWS = SYSTEM == "Windows"

+ 2 - 3
celery/tests/test_app/test_beat.py

@@ -305,9 +305,8 @@ class test_Service(Case):
 class test_EmbeddedService(Case):
 
     def test_start_stop_process(self):
-        try:
-            from multiprocessing import Process
-        except ImportError:
+        from celery.utils.mp import Process
+        if not Process:
             raise SkipTest("multiprocessing not available")
 
         s = beat.EmbeddedService()

+ 1 - 47
celery/tests/test_bin/test_celeryd.py

@@ -6,10 +6,6 @@ import os
 import sys
 
 from functools import wraps
-try:
-    from multiprocessing import current_process
-except ImportError:
-    current_process = None  # noqa
 
 from mock import patch
 from nose import SkipTest
@@ -24,6 +20,7 @@ from celery.exceptions import ImproperlyConfigured, SystemTerminate
 
 from celery.tests.utils import (AppCase, WhateverIO, mask_modules,
                                 reset_modules, skip_unless_module)
+from celery.utils.mp import current_process
 
 
 from celery.utils.log import ensure_process_aware_logger
@@ -57,49 +54,6 @@ class Worker(cd.Worker):
     WorkController = _WorkController
 
 
-class test_compilation(AppCase):
-
-    def test_no_multiprocessing(self):
-        with mask_modules("multiprocessing"):
-            with reset_modules("celery.apps.worker"):
-                from celery.apps.worker import multiprocessing
-                self.assertIsNone(multiprocessing)
-
-    def test_cpu_count_no_mp(self):
-        with mask_modules("multiprocessing"):
-            with reset_modules("celery.apps.worker"):
-                from celery.apps.worker import cpu_count
-                self.assertEqual(cpu_count(), 2)
-
-    @skip_unless_module("multiprocessing")
-    def test_no_cpu_count(self):
-
-        @patch("multiprocessing.cpu_count")
-        def _do_test(pcount):
-            pcount.side_effect = NotImplementedError("cpu_count")
-            from celery.apps.worker import cpu_count
-            self.assertEqual(cpu_count(), 2)
-            pcount.assert_called_with()
-
-        _do_test()
-
-    def test_process_name_wo_mp(self):
-        with mask_modules("multiprocessing"):
-            with reset_modules("celery.apps.worker"):
-                from celery.apps.worker import get_process_name
-                self.assertIsNone(get_process_name())
-
-    @skip_unless_module("multiprocessing")
-    def test_process_name_w_mp(self):
-
-        @patch("multiprocessing.current_process")
-        def _do_test(current_process):
-            from celery.apps.worker import get_process_name
-            self.assertTrue(get_process_name())
-
-        _do_test()
-
-
 class test_Worker(AppCase):
     Worker = Worker
 

+ 1 - 3
celery/utils/__init__.py

@@ -27,9 +27,7 @@ from .compat import StringIO
 
 from .imports import symbol_by_name, qualname
 from .functional import noop
-
-register_after_fork = symbol_by_name(
-    "multiprocessing.util.register_after_fork", default=noop)
+from .mp import register_after_fork
 
 PENDING_DEPRECATION_FMT = """
     %(description)s is scheduled for deprecation in \

+ 1 - 12
celery/utils/log.py

@@ -6,15 +6,10 @@ import sys
 import threading
 import traceback
 
-try:
-    from multiprocessing import current_process
-    from multiprocessing import util as mputil
-except ImportError:
-    current_process = mputil = None  # noqa
-
 from kombu.log import get_logger as _get_logger, LOG_LEVELS
 
 from .encoding import safe_str, str_t
+from .mp import current_process, util as mputil
 from .term import colored
 
 _process_aware = False
@@ -170,12 +165,6 @@ class LoggingProxy(object):
 
 def _patch_logger_class():
     """Make sure process name is recorded when loggers are used."""
-
-    try:
-        from multiprocessing.process import current_process
-    except ImportError:
-        current_process = None  # noqa
-
     logging._acquireLock()
     try:
         OldLoggerClass = logging.getLoggerClass()

+ 47 - 0
celery/utils/mp.py

@@ -0,0 +1,47 @@
+try:
+    import billiard
+    from billiard import util
+    from billiard import pool
+    current_process = billiard.current_process
+    register_after_fork = util.register_after_fork
+    freeze_support = billiard.freeze_support
+    Process = billiard.Process
+    cpu_count = billiard.cpu_count
+    Pool = pool.Pool
+    RUN = pool.RUN
+except ImportError:
+    try:
+        import multiprocessing
+        from multiprocessing import util
+        from multiprocessing import pool
+        current_process = multiprocessing.current_process
+        register_after_fork = util.register_after_fork
+        freeze_support = multiprocessing.freeze_support
+        Process = multiprocessing.Process
+        cpu_count = multiprocessing.cpu_count
+        Pool = pool.Pool
+        RUN = pool.RUN
+    except ImportError:
+        current_process = None
+        util = None
+        register_after_fork = lambda *a, **kw: None
+        freeze_support = lambda: True
+        Process = None
+        cpu_count = lambda: 2
+        Pool = None
+        RUN = 1
+
+
+def get_process_name():
+    if current_process is not None:
+        return current_process().name
+
+def forking_enable(enabled):
+    try:
+        from billiard import forking_enable
+    except ImportError:
+        try:
+            from multiprocessing import forking_enable
+        except ImportError:
+            return
+    forking_enable(enabled)
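A hedged usage sketch of the new celery.utils.mp shim (the calling code is illustrative, not part of the commit): consumers import from one place and get billiard when it is installed, the stdlib multiprocessing module as a fallback, or safe defaults when neither is importable, so they no longer need their own try/except ImportError guards.

from celery.utils.mp import Process, cpu_count, get_process_name

if __name__ == "__main__":  # guard needed on spawn-based platforms
    print("suggested concurrency:", cpu_count())   # falls back to 2 with no backend
    print("current process:", get_process_name())  # None when no backend is importable

    if Process is not None:
        # Process is billiard.Process, or multiprocessing.Process as a fallback.
        child = Process(target=print, args=("hello from a child process",))
        child.start()
        child.join()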

+ 1 - 6
celery/utils/patch.py

@@ -18,12 +18,7 @@ _process_aware = False
 
 def _patch_logger_class():
     """Make sure process name is recorded when loggers are used."""
-
-    try:
-        from multiprocessing.process import current_process
-    except ImportError:
-        current_process = None  # noqa
-
+    from .mp import current_process
     logging._acquireLock()
     try:
         OldLoggerClass = logging.getLoggerClass()

+ 2 - 6
celery/worker/__init__.py

@@ -28,6 +28,7 @@ from celery.app import app_or_default, set_default_app
 from celery.app.abstract import configurated, from_config
 from celery.exceptions import SystemTerminate
 from celery.utils.functional import noop
+from celery.utils.mp import forking_enable
 from celery.utils.imports import qualname, reload_from_cwd
 from celery.utils.log import get_logger
 
@@ -86,12 +87,7 @@ class Pool(abstract.StartStopComponent):
             w.max_concurrency, w.min_concurrency = w.autoscale
 
     def create(self, w):
-        try:
-            from billiard import forking_enable
-        except ImportError:
-            pass
-        else:
-            forking_enable(not w.force_execv)
+        forking_enable(not w.force_execv)
         pool = w.pool = self.instantiate(w.pool_cls, w.min_concurrency,
                                 initargs=(w.app, w.hostname),
                                 maxtasksperchild=w.max_tasks_per_child,

+ 1 - 2
requirements/py25.txt

@@ -1,4 +1,3 @@
-multiprocessing==2.6.2.1
 importlib
 ordereddict
-simplejson
+simplejson

+ 1 - 2
setup.cfg

@@ -42,8 +42,7 @@ upload-dir = docs/.build/html
 [bdist_rpm]
 requires = uuid
            importlib
-           multiprocessing == 2.6.2.1
-           billiard>=2.7.3.0
+           billiard>=2.7.3.2
            python-dateutil >= 1.5
            anyjson >= 0.3.1
            kombu >= 2.1.5

+ 1 - 3
setup.py

@@ -112,7 +112,7 @@ try:
 except ImportError:
     install_requires.append("importlib")
 install_requires.extend([
-    "billiard>=2.7.3.0",
+    "billiard>=2.7.3.2",
     "anyjson>=0.3.1",
     "kombu>=2.1.5,<3.0",
 ])
@@ -126,8 +126,6 @@ is_jython = sys.platform.startswith("java")
 is_pypy = hasattr(sys, "pypy_version_info")
 if sys.version_info < (2, 7):
     install_requires.append("ordereddict") # Replacement for the ordered dict
-if sys.version_info < (2, 6) and not (is_jython or is_pypy):
-    install_requires.append("multiprocessing")
 
 if is_jython:
     install_requires.append("threadpool")