Browse Source

Worker: Adds state.requests for id -> request mapping (weakrefs)

Ask Solem 8 years ago
parent
commit
5bfa769f6d
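
The change replaces the worker's plain request containers with weakref-based
ones and adds a requests mapping from task id to request (see
celery/worker/state.py at the bottom of this diff). A minimal sketch of the
idea using only the standard library; the Request stub below is illustrative,
not part of the commit:

    import weakref

    class Request(object):
        """Stand-in for celery.worker.request.Request; only .id matters."""
        def __init__(self, id):
            self.id = id

    # id -> request mapping: an entry disappears automatically once the
    # request object has no strong references left anywhere else.
    requests = weakref.WeakValueDictionary()

    req = Request('a1b2')
    requests[req.id] = req
    assert requests.get('a1b2') is req

    del req                              # drop the last strong reference
    assert requests.get('a1b2') is None  # entry gone (immediate on CPython)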

+ 0 - 0
celery/tests/functional/__init__.py


+ 0 - 178
celery/tests/functional/case.py

@@ -1,178 +0,0 @@
-from __future__ import absolute_import, unicode_literals
-
-import atexit
-import logging
-import os
-import signal
-import socket
-import sys
-import traceback
-
-from itertools import count
-from time import time
-
-from celery import current_app
-from celery.exceptions import TimeoutError
-from celery.app.control import flatten_reply
-from celery.utils.imports import qualname
-
-from celery.tests.case import Case
-
-HOSTNAME = socket.gethostname()
-
-
-def say(msg):
-    sys.stderr.write('%s\n' % msg)
-
-
-def try_while(fun, reason='Timed out', timeout=10, interval=0.5):
-    time_start = time()
-    for iterations in count(0):
-        if time() - time_start >= timeout:
-            raise TimeoutError(reason)
-        ret = fun()
-        if ret:
-            return ret
-
-
-class Worker(object):
-    started = False
-    worker_ids = count(1)
-    _shutdown_called = False
-
-    def __init__(self, hostname, loglevel='error', app=None):
-        self.hostname = hostname
-        self.loglevel = loglevel
-        self.app = app or current_app._get_current_object()
-
-    def start(self):
-        if not self.started:
-            self._fork_and_exec()
-            self.started = True
-
-    def _fork_and_exec(self):
-        pid = os.fork()
-        if pid == 0:
-            self.app.worker_main(['worker', '--loglevel=INFO',
-                                  '-n', self.hostname,
-                                  '-P', 'solo'])
-            os._exit(0)
-        self.pid = pid
-
-    def ping(self, *args, **kwargs):
-        return self.app.control.ping(*args, **kwargs)
-
-    def is_alive(self, timeout=1):
-        r = self.ping(destination=[self.hostname], timeout=timeout)
-        return self.hostname in flatten_reply(r)
-
-    def wait_until_started(self, timeout=10, interval=0.5):
-        try_while(
-            lambda: self.is_alive(interval),
-            "Worker won't start (after %s secs.)" % timeout,
-            interval=interval, timeout=timeout,
-        )
-        say('--WORKER %s IS ONLINE--' % self.hostname)
-
-    def ensure_shutdown(self, timeout=10, interval=0.5):
-        os.kill(self.pid, signal.SIGTERM)
-        try_while(
-            lambda: not self.is_alive(interval),
-            "Worker won't shutdown (after %s secs.)" % timeout,
-            timeout=timeout, interval=interval,
-        )
-        say('--WORKER %s IS SHUTDOWN--' % self.hostname)
-        self._shutdown_called = True
-
-    def ensure_started(self):
-        self.start()
-        self.wait_until_started()
-
-    @classmethod
-    def managed(cls, hostname=None, caller=None):
-        hostname = hostname or socket.gethostname()
-        if caller:
-            hostname = '.'.join([qualname(caller), hostname])
-        else:
-            hostname += str(next(cls.worker_ids))
-        worker = cls(hostname)
-        worker.ensure_started()
-        stack = traceback.format_stack()
-
-        @atexit.register
-        def _ensure_shutdown_once():
-            if not worker._shutdown_called:
-                say('-- Found worker not stopped at shutdown: %s\n%s' % (
-                    worker.hostname,
-                    '\n'.join(stack)))
-                worker.ensure_shutdown()
-
-        return worker
-
-
-class WorkerCase(Case):
-    hostname = HOSTNAME
-    worker = None
-
-    @classmethod
-    def setUpClass(cls):
-        logging.getLogger('amqp').setLevel(logging.ERROR)
-        cls.worker = Worker.managed(cls.hostname, caller=cls)
-
-    @classmethod
-    def tearDownClass(cls):
-        cls.worker.ensure_shutdown()
-
-    def assertWorkerAlive(self, timeout=1):
-        self.assertTrue(self.worker.is_alive(timeout))
-
-    def inspect(self, timeout=1):
-        return self.app.control.inspect([self.worker.hostname],
-                                        timeout=timeout)
-
-    def my_response(self, response):
-        return flatten_reply(response)[self.worker.hostname]
-
-    def is_accepted(self, task_id, interval=0.5):
-        active = self.inspect(timeout=interval).active()
-        if active:
-            for task in active[self.worker.hostname]:
-                if task['id'] == task_id:
-                    return True
-        return False
-
-    def is_reserved(self, task_id, interval=0.5):
-        reserved = self.inspect(timeout=interval).reserved()
-        if reserved:
-            for task in reserved[self.worker.hostname]:
-                if task['id'] == task_id:
-                    return True
-        return False
-
-    def is_scheduled(self, task_id, interval=0.5):
-        schedule = self.inspect(timeout=interval).scheduled()
-        if schedule:
-            for item in schedule[self.worker.hostname]:
-                if item['request']['id'] == task_id:
-                    return True
-        return False
-
-    def is_received(self, task_id, interval=0.5):
-        return (self.is_reserved(task_id, interval) or
-                self.is_scheduled(task_id, interval) or
-                self.is_accepted(task_id, interval))
-
-    def ensure_accepted(self, task_id, interval=0.5, timeout=10):
-        return try_while(lambda: self.is_accepted(task_id, interval),
-                         'Task not accepted within timeout',
-                         interval=interval, timeout=timeout)
-
-    def ensure_received(self, task_id, interval=0.5, timeout=10):
-        return try_while(lambda: self.is_received(task_id, interval),
-                         'Task not received within timeout',
-                         interval=interval, timeout=timeout)
-
-    def ensure_scheduled(self, task_id, interval=0.5, timeout=10):
-        return try_while(lambda: self.is_scheduled(task_id, interval),
-                         'Task not scheduled within timeout',
-                         interval=interval, timeout=timeout)

+ 0 - 24
celery/tests/functional/tasks.py

@@ -1,24 +0,0 @@
-from __future__ import absolute_import, unicode_literals
-
-import time
-
-from celery import task, signature
-
-
-@task()
-def add(x, y):
-    return x + y
-
-
-@task()
-def add_cb(x, y, callback=None):
-    result = x + y
-    if callback:
-        return signature(callback).apply_async((result,))
-    return result
-
-
-@task()
-def sleeptask(i):
-    time.sleep(i)
-    return i

+ 7 - 7
celery/tests/worker/test_autoscale.py

@@ -95,8 +95,8 @@ class test_Autoscaler(AppCase):
         x = autoscale.Autoscaler(self.pool, 10, 3, worker=worker)
         x.body()
         self.assertEqual(x.pool.num_processes, 3)
-        for i in range(20):
-            state.reserved_requests.add(i)
+        _keep = [Mock(name='req{0}'.format(i)) for i in range(20)]
+        [state.task_reserved(m) for m in _keep]
         x.body()
         x.body()
         self.assertEqual(x.pool.num_processes, 10)
@@ -129,7 +129,6 @@ class test_Autoscaler(AppCase):
         worker = Mock(name='worker')
         x = autoscale.Autoscaler(self.pool, 10, 3, worker=worker)
         x.scale_up(3)
-        x._last_action = monotonic() - 10000
         x.pool.shrink_raises_exception = True
         x._shrink(1)
 
@@ -201,13 +200,14 @@ class test_Autoscaler(AppCase):
         x = autoscale.Autoscaler(self.pool, 10, 3, worker=worker)
         x.body()  # the body func scales up or down
 
-        for i in range(35):
-            state.reserved_requests.add(i)
+        _keep = [Mock(name='req{0}'.format(i)) for i in range(35)]
+        for req in _keep:
+            state.task_reserved(req)
             x.body()
             total_num_processes.append(self.pool.num_processes)
 
-        for i in range(35):
-            state.reserved_requests.remove(i)
+        for req in _keep:
+            state.task_ready(req)
             x.body()
             total_num_processes.append(self.pool.num_processes)
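
Note that the tests now go through state.task_reserved() with mock requests
pinned in _keep, instead of adding bare integers to reserved_requests
directly: with the state module holding some requests only weakly (next
file), anything without a strong reference can be collected mid-test. A
standalone illustration, assuming CPython's immediate refcount collection:

    import weakref

    class Req(object):
        pass

    active = weakref.WeakSet()

    def add_unreferenced():
        r = Req()
        active.add(r)
        return len(active)      # 1: r is still alive at this point

    assert add_unreferenced() == 1
    assert len(active) == 0     # r collected at function exit; entry dropped

    _keep = [Req() for _ in range(3)]   # strong references, like _keep above
    for r in _keep:
        active.add(r)
    assert len(active) == 3             # pinned for as long as _keep lives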
 

+ 26 - 10
celery/worker/state.py

@@ -15,6 +15,7 @@ import os
 import sys
 import platform
 import shelve
+import weakref
 import zlib
 
 from kombu.serialization import pickle, pickle_protocol
@@ -41,11 +42,13 @@ REVOKES_MAX = 50000
 #: being expired when the max limit has been exceeded.
 REVOKE_EXPIRES = 10800
 
+requests = weakref.WeakValueDictionary()
+
 #: set of all reserved :class:`~celery.worker.request.Request`'s.
 reserved_requests = set()
 
 #: set of currently active :class:`~celery.worker.request.Request`'s.
-active_requests = set()
+active_requests = weakref.WeakSet()
 
 #: count of tasks accepted by the worker, sorted by type.
 total_count = Counter()
@@ -56,9 +59,6 @@ all_total_count = [0]
 #: the list of currently revoked tasks.  Persistent if ``statedb`` set.
 revoked = LimitedSet(maxlen=REVOKES_MAX, expires=REVOKE_EXPIRES)
 
-#: Update global state when a task has been reserved.
-task_reserved = reserved_requests.add
-
 should_stop = None
 should_terminate = None
 
@@ -66,6 +66,7 @@ should_terminate = None
 def reset_state():
     reserved_requests.clear()
     active_requests.clear()
+    requests.clear()
     total_count.clear()
     all_total_count[:] = [0]
     revoked.clear()
@@ -78,17 +79,32 @@ def maybe_shutdown():
         raise WorkerTerminate(should_terminate)
 
 
-def task_accepted(request, _all_total_count=all_total_count):
+def task_reserved(request,
+                  add_request=requests.__setitem__,
+                  add_reserved_request=reserved_requests.add):
+    """Update global state when a task has been reserved."""
+    add_request(request.id, request)
+    add_reserved_request(request)
+
+
+def task_accepted(request,
+                  _all_total_count=all_total_count,
+                  add_active_request=active_requests.add,
+                  add_to_total_count=total_count.update):
     """Updates global state when a task has been accepted."""
-    active_requests.add(request)
-    total_count[request.name] += 1
+    add_active_request(request)
+    add_to_total_count({request.name: 1})
     all_total_count[0] += 1
 
 
-def task_ready(request):
+def task_ready(request,
+               remove_request=requests.pop,
+               discard_active_request=active_requests.discard,
+               discard_reserved_request=reserved_requests.discard):
     """Updates global state when a task is ready."""
-    active_requests.discard(request)
-    reserved_requests.discard(request)
+    remove_request(request, None)
+    discard_active_request(request)
+    discard_reserved_request(request)
 
 
 C_BENCH = os.environ.get('C_BENCH') or os.environ.get('CELERY_BENCH')
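
Taken together, the new helpers pair the reserved/active bookkeeping with an
id -> request lookup, and binding the container methods as default arguments
keeps the hot path free of repeated global lookups. A rough usage sketch
against the new functions; FakeRequest is an illustrative stand-in exposing
only the attributes state.py touches:

    from celery.worker import state

    class FakeRequest(object):
        def __init__(self, id, name):
            self.id = id      # used by task_reserved/task_ready
            self.name = name  # used by task_accepted for total_count

    state.reset_state()
    req = FakeRequest('uuid-1', 'proj.tasks.add')

    state.task_reserved(req)                # into requests + reserved_requests
    assert state.requests['uuid-1'] is req  # the new id -> request lookup

    state.task_accepted(req)                # into active_requests, bumps counters
    assert state.total_count['proj.tasks.add'] == 1

    state.task_ready(req)                   # dropped from all three containers
    assert 'uuid-1' not in state.requests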