Przeglądaj źródła

TaskPool api compatible with multiprocessing.Pool

Ask Solem 14 lat temu
rodzic
commit
c6ae39edfb

+ 9 - 25
celery/concurrency/processes/__init__.py

@@ -13,14 +13,10 @@ from celery.utils.functional import partial
 from celery.concurrency.processes.pool import Pool, RUN
 
 
-def pingback(i):
-    return i
-
-
 class TaskPool(object):
     """Process Pool for processing tasks in parallel.
 
-    :param limit: see :attr:`limit`.
+    :param processes: see :attr:`processes`.
     :param logger: see :attr:`logger`.
 
 
@@ -35,18 +31,11 @@ class TaskPool(object):
     """
     Pool = Pool
 
-    def __init__(self, limit, logger=None, initializer=None,
-            maxtasksperchild=None, timeout=None, soft_timeout=None,
-            putlocks=True, initargs=()):
-        self.limit = limit
-        self.logger = logger or log.get_default_logger()
-        self.initializer = initializer
-        self.initargs = initargs
-        self.maxtasksperchild = maxtasksperchild
-        self.timeout = timeout
-        self.soft_timeout = soft_timeout
+    def __init__(self, processes=None, putlocks=True, logger=None, **options):
+        self.processes = processes
         self.putlocks = putlocks
-        self.initargs = initargs
+        self.logger = logger or log.get_default_logger()
+        self.options = options
         self._pool = None
 
     def start(self):
@@ -55,12 +44,7 @@ class TaskPool(object):
         Will pre-fork all workers so they're ready to accept tasks.
 
         """
-        self._pool = self.Pool(processes=self.limit,
-                               initializer=self.initializer,
-                               initargs=self.initargs,
-                               timeout=self.timeout,
-                               soft_timeout=self.soft_timeout,
-                               maxtasksperchild=self.maxtasksperchild)
+        self._pool = self.Pool(processes=self.processes, **self.options)
 
     def stop(self):
         """Gracefully stop the pool."""
@@ -136,8 +120,8 @@ class TaskPool(object):
 
     @property
     def info(self):
-        return {"max-concurrency": self.limit,
+        return {"max-concurrency": self.processes,
                 "processes": [p.pid for p in self._pool._pool],
-                "max-tasks-per-child": self.maxtasksperchild,
+                "max-tasks-per-child": self._pool._maxtasksperchild,
                 "put-guarded-by-semaphore": self.putlocks,
-                "timeouts": (self.soft_timeout, self.timeout)}
+                "timeouts": (self._pool.soft_timeout, self._pool.timeout)}

+ 4 - 4
celery/concurrency/threads.py

@@ -22,16 +22,16 @@ def do_work(target, args=(), kwargs={}, callback=None,
 
 class TaskPool(object):
 
-    def __init__(self, limit, logger=None, **kwargs):
-        self.limit = limit
+    def __init__(self, processes, logger=None, **kwargs):
+        self.processes = processes
         self.logger = logger or log.get_default_logger()
         self._pool = None
 
     def start(self):
-        self._pool = ThreadPool(self.limit)
+        self._pool = ThreadPool(self.processes)
 
     def stop(self):
-        self._pool.dismissWorkers(self.limit, do_join=True)
+        self._pool.dismissWorkers(self.processes, do_join=True)
 
     def apply_async(self, target, args=None, kwargs=None, callbacks=None,
             errbacks=None, accept_callback=None, **compat):

+ 9 - 8
celery/tests/test_concurrency_processes.py

@@ -98,10 +98,6 @@ class test_TaskPool(unittest.TestCase):
         pool.terminate()
         self.assertTrue(_pool.terminated)
 
-    def test_pingback(self):
-        for i in xrange(10):
-            self.assertEqual(mp.pingback(i), i)
-
     def test_on_worker_error(self):
         scratch = [None]
 
@@ -169,8 +165,13 @@ class test_TaskPool(unittest.TestCase):
 
     def test_info(self):
         pool = TaskPool(10)
-        procs = [Object(pid=i) for i in range(pool.limit)]
-        pool._pool = Object(_pool=procs)
+        procs = [Object(pid=i) for i in range(pool.processes)]
+        pool._pool = Object(_pool=procs,
+                            _maxtasksperchild=None,
+                            timeout=10,
+                            soft_timeout=5)
         info = pool.info
-        self.assertEqual(info["max-concurrency"], pool.limit)
-        self.assertEqual(len(info["processes"]), pool.limit)
+        self.assertEqual(info["max-concurrency"], pool.processes)
+        self.assertEqual(len(info["processes"]), pool.processes)
+        self.assertIsNone(info["max-tasks-per-child"])
+        self.assertEqual(info["timeouts"], (5, 10))

+ 3 - 3
celery/tests/test_pool.py

@@ -26,13 +26,13 @@ def raise_something(i):
 class TestTaskPool(unittest.TestCase):
 
     def test_attrs(self):
-        p = TaskPool(limit=2)
-        self.assertEqual(p.limit, 2)
+        p = TaskPool(2)
+        self.assertEqual(p.processes, 2)
         self.assertIsInstance(p.logger, logging.Logger)
         self.assertIsNone(p._pool)
 
     def x_apply(self):
-        p = TaskPool(limit=2)
+        p = TaskPool(2)
         p.start()
         scratchpad = {}
         proc_counter = itertools.count().next