Selaa lähdekoodia

More Python 2.4 cleanup

Ask Solem 14 vuotta sitten
vanhempi
commit
b8cf78eb1d

+ 8 - 9
celery/apps/worker.py

@@ -1,3 +1,5 @@
+from __future__ import absolute_import
+
 import atexit
 import logging
 try:
@@ -9,8 +11,6 @@ import socket
 import sys
 import warnings
 
-from kombu.utils import partition
-
 from celery import __version__
 from celery import platforms
 from celery import signals
@@ -61,8 +61,7 @@ class Worker(object):
             autoscale=None, scheduler_cls=None, pool=None, **kwargs):
         self.app = app = app_or_default(app)
         self.concurrency = (concurrency or
-                            app.conf.CELERYD_CONCURRENCY or
-                            cpu_count())
+                            app.conf.CELERYD_CONCURRENCY or cpu_count())
         self.loglevel = loglevel or app.conf.CELERYD_LOG_LEVEL
         self.logfile = logfile or app.conf.CELERYD_LOG_FILE
 
@@ -87,13 +86,13 @@ class Worker(object):
                                        app.conf.CELERY_REDIRECT_STDOUTS_LEVEL)
         self.pool = (pool or app.conf.CELERYD_POOL)
         self.db = db
-        self.use_queues = queues or []
+        self.use_queues = [] if queues is None else queues
         self.queues = None
-        self.include = include or []
+        self.include = [] if include is None else include
         self.pidfile = pidfile
         self.autoscale = None
         if autoscale:
-            max_c, _, min_c = partition(autoscale, ",")
+            max_c, _, min_c = autoscale.partition(",")
             self.autoscale = [int(max_c), min_c and int(min_c) or 0]
         self._isatty = sys.stdout.isatty()
 
@@ -205,8 +204,8 @@ class Worker(object):
             "concurrency": concurrency,
             "loglevel": LOG_LEVELS[self.loglevel],
             "logfile": self.logfile or "[stderr]",
-            "celerybeat": self.run_clockservice and "ON" or "OFF",
-            "events": self.events and "ON" or "OFF",
+            "celerybeat": "ON" if self.run_clockservice else "OFF",
+            "events": "ON" if self.events else "OFF",
             "loader": get_full_cls_name(self.loader.__class__),
             "queues": app.amqp.queues.format(indent=18, indent_first=False),
         }

+ 2 - 2
celery/backends/cache.py

@@ -1,6 +1,6 @@
 from datetime import timedelta
 
-from kombu.utils import partition, cached_property
+from kombu.utils import cached_property
 
 from celery.backends.base import KeyValueStoreBackend
 from celery.exceptions import ImproperlyConfigured
@@ -70,7 +70,7 @@ class CacheBackend(KeyValueStoreBackend):
 
         backend = backend or self.app.conf.CELERY_CACHE_BACKEND
         self.expires = int(self.expires)
-        self.backend, _, servers = partition(backend, "://")
+        self.backend, _, servers = backend.partition("://")
         self.servers = servers.rstrip('/').split(";")
         try:
             self.Client = backends[self.backend]()

+ 1 - 2
celery/bin/camqadm.py

@@ -12,7 +12,6 @@ import pprint
 from itertools import count
 
 from amqplib import client_0_8 as amqp
-from kombu.utils import partition
 
 from celery.app import app_or_default
 from celery.bin.base import Command
@@ -257,7 +256,7 @@ class AMQShell(cmd.Cmd):
         if first:
             return first
         return [cmd for cmd in names
-                    if partition(cmd, ".")[2].startswith(text)]
+                    if cmd.partition(".")[2].startswith(text)]
 
     def dispatch(self, cmd, argline):
         """Dispatch and execute the command.

+ 2 - 3
celery/concurrency/base.py

@@ -104,9 +104,8 @@ class BasePool(object):
         else:
             self.safe_apply_callback(callback, ret_value)
 
-    def on_worker_error(self, errbacks, exc):
-        einfo = ExceptionInfo((exc.__class__, exc, None))
-        [errback(einfo) for errback in errbacks]
+    def on_worker_error(self, errback, exc):
+        errback(ExceptionInfo((exc.__class__, exc, None)))
 
     def safe_apply_callback(self, fun, *args):
         if fun:

+ 24 - 31
celery/events/__init__.py

@@ -1,8 +1,11 @@
+from __future__ import absolute_import, with_statement
+
 import time
 import socket
 import threading
 
 from collections import deque
+from contextlib import contextmanager
 from itertools import count
 
 from kombu.entity import Exchange, Queue
@@ -87,22 +90,17 @@ class EventDispatcher(object):
         :keyword \*\*fields: Event arguments.
 
         """
-        if not self.enabled:
-            return
-
-        self._lock.acquire()
-        event = Event(type, hostname=self.hostname,
-                            clock=self.app.clock.forward(), **fields)
-        try:
-            try:
-                self.publisher.publish(event,
-                                       routing_key=type.replace("-", "."))
-            except Exception, exc:
-                if not self.buffer_while_offline:
-                    raise
-                self._outbound_buffer.append((type, fields, exc))
-        finally:
-            self._lock.release()
+        if self.enabled:
+            with self._lock:
+                event = Event(type, hostname=self.hostname,
+                                    clock=self.app.clock.forward(), **fields)
+                try:
+                    self.publisher.publish(event,
+                                           routing_key=type.replace("-", "."))
+                except Exception, exc:
+                    if not self.buffer_while_offline:
+                        raise
+                    self._outbound_buffer.append((type, fields, exc))
 
     def flush(self):
         while self._outbound_buffer:
@@ -154,6 +152,7 @@ class EventReceiver(object):
         handler = self.handlers.get(type) or self.handlers.get("*")
         handler and handler(event)
 
+    @contextmanager
     def consumer(self):
         """Create event consumer.
 
@@ -164,24 +163,20 @@ class EventReceiver(object):
 
         """
         consumer = Consumer(self.connection.channel(),
-                            queues=[self.queue],
-                            no_ack=True)
+                            queues=[self.queue], no_ack=True)
         consumer.register_callback(self._receive)
-        return consumer
+        with consumer:
+            yield consumer
+        consumer.channel.close()
 
     def itercapture(self, limit=None, timeout=None, wakeup=True):
-        consumer = self.consumer()
-        consumer.consume()
-        if wakeup:
-            self.wakeup_workers(channel=consumer.channel)
+        with self.consumer() as consumer:
+            if wakeup:
+                self.wakeup_workers(channel=consumer.channel)
 
-        yield consumer
+            yield consumer
 
-        try:
             self.drain_events(limit=limit, timeout=timeout)
-        finally:
-            consumer.cancel()
-            consumer.channel.close()
 
     def capture(self, limit=None, timeout=None, wakeup=True):
         """Open up a consumer capturing events.
@@ -190,9 +185,7 @@ class EventReceiver(object):
         stop unless forced via :exc:`KeyboardInterrupt` or :exc:`SystemExit`.
 
         """
-        list(self.itercapture(limit=limit,
-                              timeout=timeout,
-                              wakeup=wakeup))
+        list(self.itercapture(limit=limit, timeout=timeout, wakeup=wakeup))
 
     def wakeup_workers(self, channel=None):
         self.app.control.broadcast("heartbeat",

+ 12 - 22
celery/events/state.py

@@ -1,10 +1,10 @@
+from __future__ import absolute_import, with_statement
+
 import time
 import heapq
 
 from threading import Lock
 
-from kombu.utils import partition
-
 from celery import states
 from celery.datastructures import AttributeDict, LocalCache
 from celery.utils import kwdict
@@ -178,20 +178,16 @@ class State(object):
 
     def freeze_while(self, fun, *args, **kwargs):
         clear_after = kwargs.pop("clear_after", False)
-        self._mutex.acquire()
-        try:
-            return fun(*args, **kwargs)
-        finally:
-            if clear_after:
-                self._clear()
-            self._mutex.release()
+        with self._mutex:
+            try:
+                return fun(*args, **kwargs)
+            finally:
+                if clear_after:
+                    self._clear()
 
     def clear_tasks(self, ready=True):
-        self._mutex.acquire()
-        try:
+        with self._mutex:
             return self._clear_tasks(ready)
-        finally:
-            self._mutex.release()
 
     def _clear_tasks(self, ready=True):
         if ready:
@@ -208,11 +204,8 @@ class State(object):
         self.task_count = 0
 
     def clear(self, ready=True):
-        self._mutex.acquire()
-        try:
+        with self._mutex:
             return self._clear(ready)
-        finally:
-            self._mutex.release()
 
     def get_or_create_worker(self, hostname, **kwargs):
         """Get or create worker by hostname."""
@@ -255,16 +248,13 @@ class State(object):
         task.worker = worker
 
     def event(self, event):
-        self._mutex.acquire()
-        try:
+        with self._mutex:
             return self._dispatch_event(event)
-        finally:
-            self._mutex.release()
 
     def _dispatch_event(self, event):
         self.event_count += 1
         event = kwdict(event)
-        group, _, type = partition(event.pop("type"), "-")
+        group, _, type = event.pop("type").partition("-")
         self.group_handlers[group](type, event)
         if self.event_callback:
             self.event_callback(self, event)

+ 6 - 10
celery/loaders/__init__.py

@@ -2,13 +2,12 @@ from __future__ import absolute_import
 
 import os
 
+from celery import current_app
 from celery.utils import get_cls_by_name
 
 LOADER_ALIASES = {"app": "celery.loaders.app.AppLoader",
                   "default": "celery.loaders.default.Loader",
                   "django": "djcelery.loaders.DjangoLoader"}
-_loader = None
-_settings = None
 
 
 def get_loader_cls(loader):
@@ -17,20 +16,17 @@ def get_loader_cls(loader):
 
 
 def setup_loader():
+    # XXX Deprecate
     return get_loader_cls(os.environ.setdefault("CELERY_LOADER", "default"))()
 
 
 def current_loader():
     """Detect and return the current loader."""
-    global _loader
-    if _loader is None:
-        _loader = setup_loader()
-    return _loader
+    # XXX Deprecate
+    return current_app.loader
 
 
 def load_settings():
     """Load the global settings object."""
-    global _settings
-    if _settings is None:
-        _settings = current_loader().conf
-    return _settings
+    # XXX Deprecate
+    return current_app.conf

+ 4 - 14
celery/tests/test_app/test_loaders.py

@@ -12,8 +12,7 @@ from celery.loaders import default
 from celery.loaders.app import AppLoader
 
 from celery.tests.compat import catch_warnings
-from celery.tests.utils import unittest
-from celery.tests.utils import with_environ
+from celery.tests.utils import unittest, AppCase, with_environ
 
 
 class ObjectConfig(object):
@@ -59,7 +58,7 @@ class DummyLoader(base.BaseLoader):
         return MockMail()
 
 
-class TestLoaders(unittest.TestCase):
+class TestLoaders(AppCase):
 
     def test_get_loader_cls(self):
 
@@ -67,19 +66,10 @@ class TestLoaders(unittest.TestCase):
                           default.Loader)
 
     def test_current_loader(self):
-        loader1 = loaders.current_loader()
-        loader2 = loaders.current_loader()
-        self.assertIs(loader1, loader2)
-        self.assertIs(loader2, loaders._loader)
+        self.assertIs(loaders.current_loader(), self.app.loader)
 
     def test_load_settings(self):
-        loader = loaders.current_loader()
-        loaders._settings = None
-        settings = loaders.load_settings()
-        self.assertTrue(loaders._settings)
-        settings = loaders.load_settings()
-        self.assertIs(settings, loaders._settings)
-        self.assertIs(settings, loader.conf)
+        self.assertIs(loaders.load_settings(), self.app.conf)
 
     @with_environ("CELERY_LOADER", "default")
     def test_detect_loader_CELERY_LOADER(self):

+ 3 - 6
celery/tests/test_backends/test_amqp.py

@@ -1,3 +1,5 @@
+from __future__ import with_statement
+
 import socket
 import sys
 
@@ -202,16 +204,11 @@ class test_AMQPBackend(unittest.TestCase):
                 pass
 
         b = self.create_backend()
-        conn = current_app.pool.acquire(block=False)
-        channel = conn.channel()
-        try:
+        with current_app.pool.acquire_channel(block=False) as (_, channel):
             binding = b._create_binding(gen_unique_id())
             consumer = b._create_consumer(binding, channel)
             self.assertRaises(socket.timeout, b.drain_events,
                               Connection(), consumer, timeout=0.1)
-        finally:
-            channel.close()
-            conn.release()
 
     def test_get_many(self):
         b = self.create_backend()

+ 1 - 1
celery/tests/test_concurrency/test_concurrency_processes.py

@@ -130,7 +130,7 @@ class test_TaskPool(unittest.TestCase):
 
         pool = TaskPool(10)
         exc = KeyError("foo")
-        pool.on_worker_error([errback], exc)
+        pool.on_worker_error(errback, exc)
 
         self.assertTrue(scratch[0])
         self.assertIs(scratch[0].exception, exc)

+ 2 - 10
celery/tests/test_worker/test_worker_job.py

@@ -22,9 +22,8 @@ from celery.log import setup_logger
 from celery.result import AsyncResult
 from celery.task.base import Task
 from celery.utils import gen_unique_id
-from celery.worker.job import WorkerTaskTrace, TaskRequest
-from celery.worker.job import execute_and_trace, AlreadyExecutedError
-from celery.worker.job import InvalidTaskError
+from celery.worker.job import (WorkerTaskTrace, TaskRequest,
+                               InvalidTaskError, execute_and_trace)
 from celery.worker.state import revoked
 
 from celery.tests.compat import catch_warnings
@@ -456,13 +455,6 @@ class test_TaskRequest(unittest.TestCase):
         w.handle_failure(value_, type_, tb_, "")
         self.assertEqual(mytask.backend.get_status(uuid), states.FAILURE)
 
-    def test_executed_bit(self):
-        tw = TaskRequest(mytask.name, gen_unique_id(), [], {})
-        self.assertFalse(tw.executed)
-        tw._set_executed_bit()
-        self.assertTrue(tw.executed)
-        self.assertRaises(AlreadyExecutedError, tw._set_executed_bit)
-
     def test_task_wrapper_mail_attrs(self):
         tw = TaskRequest(mytask.name, gen_unique_id(), [], {})
         x = tw.success_msg % {"name": tw.task_name,

+ 1 - 1
celery/tests/utils.py

@@ -39,7 +39,7 @@ class AppCase(unittest.TestCase):
 
     def setUp(self):
         from celery.app import current_app
-        self._current_app = current_app()
+        self.app = self._current_app = current_app()
         self.setup()
 
     def tearDown(self):

+ 2 - 3
celery/utils/__init__.py

@@ -13,7 +13,6 @@ from itertools import islice
 from pprint import pprint
 
 from kombu.utils import cached_property, gen_unique_id  # noqa
-from kombu.utils import rpartition
 
 from celery.utils.compat import StringIO
 
@@ -304,7 +303,7 @@ def get_cls_by_name(name, aliases={}, imp=None):
         return name                                 # already a class
 
     name = aliases.get(name) or name
-    module_name, _, cls_name = rpartition(name, ".")
+    module_name, _, cls_name = name.rpartition(".")
     try:
         module = imp(module_name)
     except ValueError, exc:
@@ -342,7 +341,7 @@ def abbrtask(S, max):
     if S is None:
         return "???"
     if len(S) > max:
-        module, _, cls = rpartition(S, ".")
+        module, _, cls = S.rpartition(".")
         module = abbr(module, max - len(cls) - 3, False)
         return module + "[.]" + cls
     return S

+ 1 - 3
celery/utils/timeutils.py

@@ -3,8 +3,6 @@ import math
 from datetime import datetime, timedelta
 from dateutil.parser import parse as parse_iso8601
 
-from kombu.utils import partition
-
 DAYNAMES = "sun", "mon", "tue", "wed", "thu", "fri", "sat"
 WEEKDAYS = dict((name, dow) for name, dow in zip(DAYNAMES, range(7)))
 
@@ -90,7 +88,7 @@ def rate(rate):
     and converts them to seconds."""
     if rate:
         if isinstance(rate, basestring):
-            ops, _, modifier = partition(rate, "/")
+            ops, _, modifier = rate.partition("/")
             return RATE_MODIFIER_MAP[modifier or "s"](int(ops)) or 0
         return rate or 0
     return 0

+ 10 - 22
celery/worker/job.py

@@ -1,3 +1,5 @@
+from __future__ import absolute_import
+
 import os
 import sys
 import time
@@ -50,9 +52,12 @@ class InvalidTaskError(Exception):
     """The task has invalid data or is not properly constructed."""
 
 
-class AlreadyExecutedError(Exception):
-    """Tasks can only be executed once, as they might change
-    world-wide state."""
+def default_encode(obj):
+    if sys.platform.startswith("java"):
+        coding = "utf-8"
+    else:
+        coding = sys.getfilesystemencoding()
+    return unicode(obj, coding)
 
 
 class WorkerTaskTrace(TaskTrace):
@@ -211,9 +216,6 @@ class TaskRequest(object):
     #: The message object.  Used to acknowledge the message.
     message = None
 
-    #: Flag set when the task has been executed.
-    executed = False
-
     #: Additional delivery info, e.g. contains the path from
     #: Producer to consumer.
     delivery_info = None
@@ -287,8 +289,8 @@ class TaskRequest(object):
             the message is also rejected.
 
         """
-        _delivery_info = getattr(message, "delivery_info", {})
-        delivery_info = dict((key, _delivery_info.get(key))
+        delivery_info = getattr(message, "delivery_info", {})
+        delivery_info = dict((key, delivery_info.get(key))
                                 for key in WANTED_DELIVERY_INFO)
 
         kwargs = body["kwargs"]
@@ -357,9 +359,6 @@ class TaskRequest(object):
         if self.revoked():
             return
 
-        # Make sure task has not already been executed.
-        self._set_executed_bit()
-
         args = self._get_tracer_args(loglevel, logfile)
         instance_attrs = self.get_instance_attrs(loglevel, logfile)
         result = pool.apply_async(execute_and_trace,
@@ -385,9 +384,6 @@ class TaskRequest(object):
         if self.revoked():
             return
 
-        # Make sure task has not already been executed.
-        self._set_executed_bit()
-
         # acknowledge task as being processed.
         if not self.task.acks_late:
             self.acknowledge()
@@ -585,11 +581,3 @@ class TaskRequest(object):
         """Get the :class:`WorkerTaskTrace` tracer for this task."""
         task_func_kwargs = self.extend_with_default_kwargs(loglevel, logfile)
         return self.task_name, self.task_id, self.args, task_func_kwargs
-
-    def _set_executed_bit(self):
-        """Set task as executed to make sure it's not executed again."""
-        if self.executed:
-            raise AlreadyExecutedError(
-                   "Task %s[%s] has already been executed" % (
-                       self.task_name, self.task_id))
-        self.executed = True