Переглянути джерело

with_default_connection -> with default_connection

Ask Solem 14 років тому
батько
коміт
c396539693

+ 26 - 26
celery/app/base.py

@@ -8,9 +8,12 @@ Application Base Class.
 :license: BSD, see LICENSE for more details.
 
 """
+from __future__ import absolute_import, with_statement
+
 import platform as _platform
 import sys
 
+from contextlib import contextmanager
 from copy import deepcopy
 from functools import wraps
 from threading import Lock
@@ -84,21 +87,15 @@ class LamportClock(object):
 
     def __init__(self, initial_value=0):
         self.value = initial_value
-        self._mutex = Lock()
+        self.mutex = Lock()
 
     def adjust(self, other):
-        self._mutex.acquire()
-        try:
+        with self.mutex:
             self.value = max(self.value, other) + 1
-        finally:
-            self._mutex.release()
 
     def forward(self):
-        self._mutex.acquire()
-        try:
+        with self.mutex:
             self.value += 1
-        finally:
-            self._mutex.release()
         return self.value
 
 
@@ -193,8 +190,8 @@ class BaseApp(object):
         exchange = options.get("exchange")
         exchange_type = options.get("exchange_type")
 
-        def _do_publish(connection=None, **_):
-            publish = publisher or self.amqp.TaskPublisher(connection,
+        with self.default_connection(connection, connect_timeout) as conn:
+            publish = publisher or self.amqp.TaskPublisher(conn,
                                             exchange=exchange,
                                             exchange_type=exchange_type)
             try:
@@ -204,12 +201,8 @@ class BaseApp(object):
                                             expires=expires, **options)
             finally:
                 publisher or publish.close()
-
             return result_cls(new_id)
 
-        return self.with_default_connection(_do_publish)(
-                connection=connection, connect_timeout=connect_timeout)
-
     def AsyncResult(self, task_id, backend=None, task_name=None):
         """Create :class:`celery.result.BaseAsyncResult` instance."""
         from celery.result import BaseAsyncResult
@@ -255,6 +248,16 @@ class BaseApp(object):
                                 "BROKER_CONNECTION_TIMEOUT", connect_timeout),
                     transport_options=self.conf.BROKER_TRANSPORT_OPTIONS)
 
+    @contextmanager
+    def default_connection(self, connection=None, connect_timeout=None):
+        """For use within a with-statement to get a connection from the pool
+        if one is not already provided."""
+        if connection:
+            yield connection
+        else:
+            with self.pool.acquire(block=True) as connection:
+                yield connection
+
     def with_default_connection(self, fun):
         """With any function accepting `connection` and `connect_timeout`
         keyword arguments, establishes a default connection if one is
@@ -263,20 +266,17 @@ class BaseApp(object):
         Any automatically established connection will be closed after
         the function returns.
 
-        """
+        **Deprecated**
 
+        Use ``with app.default_connection(connection)`` instead.
+
+        """
         @wraps(fun)
         def _inner(*args, **kwargs):
-            connection = kwargs.get("connection")
-            kwargs["connection"] = conn = connection or \
-                    self.pool.acquire(block=True)
-            close_connection = not connection and conn.release or None
-
-            try:
-                return fun(*args, **kwargs)
-            finally:
-                if close_connection:
-                    close_connection()
+            connection = kwargs.pop("connection", None)
+            connect_timeout = kwargs.get("connect_timeout")
+            with self.default_connection(connection, connect_timeout) as c:
+                return fun(*args, **dict(kwargs, connection=c))
         return _inner
 
     def prepare_config(self, c):

+ 33 - 47
celery/beat.py

@@ -1,3 +1,5 @@
+from __future__ import absolute_import
+
 import errno
 import os
 import time
@@ -69,10 +71,10 @@ class ScheduleEntry(object):
             options={}, relative=False):
         self.name = name
         self.task = task
-        self.schedule = maybe_schedule(schedule, relative)
         self.args = args
         self.kwargs = kwargs
         self.options = options
+        self.schedule = maybe_schedule(schedule, relative)
         self.last_run_at = last_run_at or self._default_now()
         self.total_run_count = total_run_count or 0
 
@@ -82,25 +84,20 @@ class ScheduleEntry(object):
     def next(self, last_run_at=None):
         """Returns a new instance of the same class, but with
         its date and count fields updated."""
-        last_run_at = last_run_at or datetime.now()
-        total_run_count = self.total_run_count + 1
         return self.__class__(**dict(self,
-                                     last_run_at=last_run_at,
-                                     total_run_count=total_run_count))
+                                     last_run_at=last_run_at or datetime.now(),
+                                     total_run_count=self.total_run_count + 1))
 
     def update(self, other):
         """Update values from another entry.
 
-        Does only update "editable" fields (schedule, args,
-        kwargs, options).
+        Does only update "editable" fields (task, schedule, args, kwargs,
+        options).
 
         """
-        self.task = other.task
-        self.schedule = other.schedule
-        self.args = other.args
-        self.kwargs = other.kwargs
-        self.options = other.options
-
+        self.__dict__.update({"task": other.task, "schedule": other.schedule,
+                              "args": other.args, "kwargs": other.kwargs,
+                              "options": other.options})
     def is_due(self):
         """See :meth:`celery.task.base.PeriodicTask.is_due`."""
         return self.schedule.is_due(self.last_run_at)
@@ -109,8 +106,8 @@ class ScheduleEntry(object):
         return vars(self).iteritems()
 
     def __repr__(self):
-        return "<Entry: %s %s(*%s, **%s) {%s}>" % (
-                self.name, self.task, self.args, self.kwargs, self.schedule)
+        return ("<Entry: %(name)s %(task)s(*%(args)s, **%(kwargs)s) "
+                "{%(schedule)s}>" % vars(self))
 
 
 class Scheduler(object):
@@ -135,15 +132,12 @@ class Scheduler(object):
 
     def __init__(self, schedule=None, logger=None, max_interval=None,
             app=None, Publisher=None, lazy=False, **kwargs):
-        if schedule is None:
-            schedule = {}
-        self.app = app_or_default(app)
-        conf = self.app.conf
-        self.data = maybe_promise(schedule)
-        self.logger = logger or self.app.log.get_default_logger(
-                                                name="celery.beat")
-        self.max_interval = max_interval or conf.CELERYBEAT_MAX_LOOP_INTERVAL
-        self.Publisher = Publisher or self.app.amqp.TaskPublisher
+        app = self.app = app_or_default(app)
+        self.data = maybe_promise({} if schedule is None else schedule)
+        self.logger = logger or app.log.get_default_logger(name="celery.beat")
+        self.max_interval = max_interval or \
+                                app.conf.CELERYBEAT_MAX_LOOP_INTERVAL
+        self.Publisher = Publisher or app.amqp.TaskPublisher
         if not lazy:
             self.setup_schedule()
 
@@ -198,11 +192,7 @@ class Scheduler(object):
         # so we have that done if an exception is raised (doesn't schedule
         # forever.)
         entry = self.reserve(entry)
-
-        try:
-            task = registry.tasks[entry.task]
-        except KeyError:
-            task = None
+        task = registry.tasks.get(entry.task)
 
         try:
             if task:
@@ -345,25 +335,22 @@ class Service(object):
 
     def __init__(self, logger=None, max_interval=None, schedule_filename=None,
             scheduler_cls=None, app=None):
-        self.app = app_or_default(app)
+        app = self.app = app_or_default(app)
         self.max_interval = max_interval or \
-                            self.app.conf.CELERYBEAT_MAX_LOOP_INTERVAL
+                                app.conf.CELERYBEAT_MAX_LOOP_INTERVAL
         self.scheduler_cls = scheduler_cls or self.scheduler_cls
-        self.logger = logger or self.app.log.get_default_logger(
-                                                name="celery.beat")
+        self.logger = logger or app.log.get_default_logger(name="celery.beat")
         self.schedule_filename = schedule_filename or \
-                                    self.app.conf.CELERYBEAT_SCHEDULE_FILENAME
+                                    app.conf.CELERYBEAT_SCHEDULE_FILENAME
 
         self._shutdown = threading.Event()
         self._stopped = threading.Event()
-        silence = self.max_interval < 60 and 10 or 1
         self.debug = SilenceRepeated(self.logger.debug,
-                                     max_iterations=silence)
+                        10 if self.max_interval < 60 else 1)
 
     def start(self, embedded_process=False):
         self.logger.info("Celerybeat: Starting...")
-        self.logger.debug("Celerybeat: "
-            "Ticking with max interval->%s" % (
+        self.logger.debug("Celerybeat: Ticking with max interval->%s" % (
                     humanize_seconds(self.scheduler.max_interval)))
 
         signals.beat_init.send(sender=self)
@@ -372,14 +359,13 @@ class Service(object):
             platforms.set_process_title("celerybeat")
 
         try:
-            try:
-                while not self._shutdown.isSet():
-                    interval = self.scheduler.tick()
-                    self.debug("Celerybeat: Waking up %s." % (
-                            humanize_seconds(interval, prefix="in ")))
-                    time.sleep(interval)
-            except (KeyboardInterrupt, SystemExit):
-                self._shutdown.set()
+            while not self._shutdown.isSet():
+                interval = self.scheduler.tick()
+                self.debug("Celerybeat: Waking up %s." % (
+                        humanize_seconds(interval, prefix="in ")))
+                time.sleep(interval)
+        except (KeyboardInterrupt, SystemExit):
+            self._shutdown.set()
         finally:
             self.sync()
 
@@ -390,7 +376,7 @@ class Service(object):
     def stop(self, wait=False):
         self.logger.info("Celerybeat: Shutting down...")
         self._shutdown.set()
-        wait and self._stopped.wait()           # block until shutdown done.
+        wait and self._stopped.wait()  # block until shutdown done.
 
     def get_scheduler(self, lazy=False):
         filename = self.schedule_filename

+ 9 - 6
celery/loaders/base.py

@@ -38,9 +38,12 @@ class BaseLoader(object):
         * What modules are imported to find tasks?
 
     """
-    worker_initialized = False
-    override_backends = {}
+    builtin_modules = BUILTIN_MODULES
     configured = False
+    error_envvar_not_set = ERROR_ENVVAR_NOT_SET
+    override_backends = {}
+    worker_initialized = False
+
     _conf = None
 
     def __init__(self, app=None, **kwargs):
@@ -71,9 +74,9 @@ class BaseLoader(object):
                 self.import_module if imp is None else imp)
 
     def import_default_modules(self):
-        imports = self.conf.get("CELERY_IMPORTS") or ()
-        imports = set(list(imports)) | BUILTIN_MODULES
-        return [self.import_task_module(module) for module in imports]
+        imports = set(list(self.conf.get("CELERY_IMPORTS") or ()))
+        return [self.import_task_module(module)
+                    for module in imports | self.builtin_modules]
 
     def init_worker(self):
         if not self.worker_initialized:
@@ -85,7 +88,7 @@ class BaseLoader(object):
         if not module_name:
             if silent:
                 return False
-            raise ImproperlyConfigured(ERROR_ENVVAR_NOT_SET % (module_name, ))
+            raise ImproperlyConfigured(self.error_envvar_not_set % module_name)
         return self.config_from_object(module_name, silent=silent)
 
     def config_from_object(self, obj, silent=False):

+ 2 - 6
celery/result.py

@@ -1,4 +1,4 @@
-from __future__ import absolute_import
+from __future__ import absolute_import, with_statement
 
 import time
 
@@ -308,14 +308,10 @@ class ResultSet(object):
 
     def revoke(self, connection=None, connect_timeout=None):
         """Revoke all tasks in the set."""
-
-        def _do_revoke(connection=None, connect_timeout=None):
+        with self.app.default_connection(connection, connect_timeout) as conn:
             for result in self.results:
                 result.revoke(connection=connection)
 
-        return self.app.with_default_connection(_do_revoke)(
-                connection=connection, connect_timeout=connect_timeout)
-
     def __iter__(self):
         return self.iterate()
 

+ 9 - 22
celery/task/control.py

@@ -1,3 +1,5 @@
+from __future__ import absolute_import, with_statement
+
 from kombu.pidbox import Mailbox
 
 from celery.app import app_or_default
@@ -94,16 +96,9 @@ class Control(object):
         :returns: the number of tasks discarded.
 
         """
-
-        def _do_discard(connection=None, connect_timeout=None):
-            consumer = self.app.amqp.get_task_consumer(connection=connection)
-            try:
+        with self.app.default_connection(connection, connect_timeout) as conn:
+            with self.app.amqp.get_task_consumer(connection=conn) as consumer:
                 return consumer.discard_all()
-            finally:
-                consumer.close()
-
-        return self.app.with_default_connection(_do_discard)(
-                connection=connection, connect_timeout=connect_timeout)
 
     def revoke(self, task_id, destination=None, terminate=False,
             signal="SIGTERM", **kwargs):
@@ -211,19 +206,11 @@ class Control(object):
             received.
 
         """
-        def _do_broadcast(connection=None, connect_timeout=None,
-                          channel=None):
-            return self.mailbox(connection)._broadcast(command, arguments,
-                                                       destination, reply,
-                                                       timeout, limit,
-                                                       callback,
-                                                       channel=channel)
-
-        if channel:
-            return _do_broadcast(connection, connect_timeout, channel)
-        else:
-            return self.app.with_default_connection(_do_broadcast)(
-                    connection=connection, connect_timeout=connect_timeout)
+        with self.app.default_connection(connection, connect_timeout) as conn:
+            return self.mailbox(conn)._broadcast(command, arguments,
+                                                 destination, reply, timeout,
+                                                 limit, callback,
+                                                 channel=channel)
 
 
 _default_control = Control(app_or_default())

+ 13 - 16
celery/task/sets.py

@@ -1,3 +1,5 @@
+from __future__ import absolute_import, with_statement
+
 import warnings
 
 from kombu.utils import cached_property
@@ -150,26 +152,21 @@ class TaskSet(UserList):
     def apply_async(self, connection=None, connect_timeout=None,
             publisher=None, taskset_id=None):
         """Apply taskset."""
-        return self.app.with_default_connection(self._apply_async)(
-                    connection=connection,
-                    connect_timeout=connect_timeout,
-                    publisher=publisher,
-                    taskset_id=taskset_id)
+        app = self.app
 
-    def _apply_async(self, connection=None, connect_timeout=None,
-            publisher=None, taskset_id=None):
-        if self.app.conf.CELERY_ALWAYS_EAGER:
+        if app.conf.CELERY_ALWAYS_EAGER:
             return self.apply(taskset_id=taskset_id)
 
-        setid = taskset_id or gen_unique_id()
-        pub = publisher or self.Publisher(connection=connection)
-        try:
-            results = self._async_results(setid, pub)
-        finally:
-            if not publisher:  # created by us.
-                pub.close()
+        with app.default_connection(connection, connect_timeout) as conn:
+            setid = taskset_id or gen_unique_id()
+            pub = publisher or self.Publisher(connection=conn)
+            try:
+                results = self._async_results(setid, pub)
+            finally:
+                if not publisher:  # created by us.
+                    pub.close()
 
-        return self.app.TaskSetResult(setid, results)
+            return app.TaskSetResult(setid, results)
 
     def _async_results(self, taskset_id, publisher):
         return [task.apply_async(taskset_id=taskset_id, publisher=publisher)

+ 4 - 8
celery/utils/timer2.py

@@ -1,4 +1,6 @@
 """timer2 - Scheduler for Python functions."""
+from __future__ import absolute_import, with_statement
+
 import atexit
 import heapq
 import logging
@@ -164,15 +166,12 @@ class Timer(Thread):
                 traceback.print_exception(typ, val, tb)
 
     def next(self):
-        self.not_empty.acquire()
-        try:
+        with self.not_empty:
             delay, entry = self.scheduler.next()
             if entry is None:
                 if delay is None:
                     self.not_empty.wait(1.0)
                 return delay
-        finally:
-            self.not_empty.release()
         return self.apply_entry(entry)
 
     def run(self):
@@ -212,13 +211,10 @@ class Timer(Thread):
 
     def enter(self, entry, eta, priority=None):
         self.ensure_started()
-        self.mutex.acquire()
-        try:
+        with self.mutex:
             entry = self.schedule.enter(entry, eta, priority)
             self.not_empty.notify()
             return entry
-        finally:
-            self.mutex.release()
 
     def apply_at(self, eta, fun, args=(), kwargs={}, priority=0):
         return self.enter(self.Entry(fun, args, kwargs), eta, priority)

+ 9 - 11
celery/worker/__init__.py

@@ -251,7 +251,7 @@ class WorkController(object):
         except SystemTerminate:
             self.terminate()
             raise SystemExit()
-        except (SystemExit, KeyboardInterrupt), exc:
+        except BaseException, exc:
             self.stop()
             raise exc
 
@@ -260,27 +260,25 @@ class WorkController(object):
         try:
             request.task.execute(request, self.pool,
                                  self.loglevel, self.logfile)
+        except Exception, exc:
+            self.logger.critical("Internal error %s: %s\n%s" % (
+                            exc.__class__, exc, traceback.format_exc()))
         except SystemTerminate:
             self.terminate()
             raise SystemExit()
-        except (SystemExit, KeyboardInterrupt), exc:
+        except BaseException, exc:
             self.stop()
             raise exc
-        except Exception, exc:
-            self.logger.critical("Internal error %s: %s\n%s" % (
-                            exc.__class__, exc, traceback.format_exc()))
 
     def stop(self, in_sighandler=False):
         """Graceful shutdown of the worker server."""
-        if in_sighandler and not self.pool.signal_safe:
-            return
-        blocking(self._shutdown, warm=True)
+        if not in_sighandler or self.pool.signal_safe:
+            blocking(self._shutdown, warm=True)
 
     def terminate(self, in_sighandler=False):
         """Not so graceful shutdown of the worker server."""
-        if in_sighandler and not self.pool.signal_safe:
-            return
-        blocking(self._shutdown, warm=False)
+        if not in_sighandler or self.pool.signal_safe:
+            blocking(self._shutdown, warm=False)
 
     def _shutdown(self, warm=True):
         what = (warm and "stopping" or "terminating").capitalize()

+ 4 - 8
celery/worker/buckets.py

@@ -1,3 +1,5 @@
+from __future__ import absolute_import, with_statement
+
 import threading
 
 from collections import deque
@@ -51,14 +53,11 @@ class TaskBucket(object):
     def put(self, request):
         """Put a :class:`~celery.worker.job.TaskRequest` into
         the appropiate bucket."""
-        self.mutex.acquire()
-        try:
+        with self.mutex:
             if request.task_name not in self.buckets:
                 self.add_bucket_for_type(request.task_name)
             self.buckets[request.task_name].put_nowait(request)
             self.not_empty.notify()
-        finally:
-            self.mutex.release()
     put_nowait = put
 
     def _get_immediate(self):
@@ -113,8 +112,7 @@ class TaskBucket(object):
         time_start = time()
         did_timeout = lambda: timeout and time() - time_start > timeout
 
-        self.not_empty.acquire()
-        try:
+        with self.not_empty:
             while True:
                 try:
                     remaining_time, item = self._get()
@@ -129,8 +127,6 @@ class TaskBucket(object):
                     sleep(min(remaining_time, timeout or 1))
                 else:
                     return item
-        finally:
-            self.not_empty.release()
 
     def get_nowait(self):
         return self.get(block=False)