Browse Source

Remove 2.4 workarounds

Ask Solem 14 years ago
parent
commit
1dfe6d0d54
43 changed files with 185 additions and 454 deletions
  1. 0 1
      celery/app/base.py
  2. 1 1
      celery/bin/celeryd_multi.py
  3. 1 1
      celery/datastructures.py
  4. 2 2
      celery/db/session.py
  5. 2 0
      celery/execute/__init__.py
  6. 9 13
      celery/execute/trace.py
  7. 2 0
      celery/loaders/__init__.py
  8. 6 0
      celery/loaders/app.py
  9. 10 10
      celery/loaders/base.py
  10. 5 2
      celery/loaders/default.py
  11. 8 0
      celery/local.py
  12. 11 17
      celery/log.py
  13. 3 2
      celery/registry.py
  14. 4 6
      celery/result.py
  15. 4 6
      celery/routes.py
  16. 4 4
      celery/schedules.py
  17. 3 4
      celery/task/base.py
  18. 5 1
      celery/task/http.py
  19. 10 13
      celery/tests/test_app/test_loaders.py
  20. 4 5
      celery/tests/test_backends/test_database.py
  21. 4 5
      celery/tests/test_backends/test_redis.py
  22. 2 1
      celery/tests/test_bin/test_celerybeat.py
  23. 5 5
      celery/tests/test_bin/test_celeryd.py
  24. 4 5
      celery/tests/test_compat/test_decorators.py
  25. 7 45
      celery/tests/test_compat/test_log.py
  26. 0 2
      celery/tests/test_slow/test_buckets.py
  27. 1 5
      celery/tests/test_task/test_result.py
  28. 14 15
      celery/tests/test_task/test_task_builtins.py
  29. 11 37
      celery/tests/test_task/test_task_http.py
  30. 5 11
      celery/tests/test_task/test_task_sets.py
  31. 5 8
      celery/tests/test_utils/test_serialization.py
  32. 4 4
      celery/tests/test_worker/test_worker.py
  33. 6 13
      celery/tests/test_worker/test_worker_job.py
  34. 2 59
      celery/tests/utils.py
  35. 0 3
      celery/utils/__init__.py
  36. 1 99
      celery/utils/compat.py
  37. 2 8
      celery/utils/serialization.py
  38. 0 3
      celery/utils/timer2.py
  39. 1 1
      celery/worker/buckets.py
  40. 4 19
      celery/worker/consumer.py
  41. 5 8
      celery/worker/job.py
  42. 6 9
      celery/worker/mediator.py
  43. 2 1
      celery/worker/state.py

+ 0 - 1
celery/app/base.py

@@ -25,7 +25,6 @@ import kombu
 if kombu.VERSION < (1, 1, 0):
     raise ImportError("Celery requires Kombu version 1.1.0 or higher.")
 
-
 BUGREPORT_INFO = """
 platform -> system:%(system)s arch:%(arch)s imp:%(py_i)s
 software -> celery:%(celery_v)s kombu:%(kombu_v)s py:%(py_v)s

+ 1 - 1
celery/bin/celeryd_multi.py

@@ -92,12 +92,12 @@ import signal
 import socket
 import sys
 
+from collections import defaultdict
 from subprocess import Popen
 from time import sleep
 
 from celery import __version__
 from celery.utils import term
-from celery.utils.compat import any, defaultdict
 
 SIGNAMES = set(sig for sig in dir(signal)
                         if sig.startswith("SIG") and "_" not in sig)

+ 1 - 1
celery/datastructures.py

@@ -8,7 +8,7 @@ Custom data structures.
 :license: BSD, see LICENSE for more details.
 
 """
-from __future__ import generators
+from __future__ import absolute_import
 
 import time
 import traceback

+ 2 - 2
celery/db/session.py

@@ -1,9 +1,9 @@
+from collections import defaultdict
+
 from sqlalchemy import create_engine
 from sqlalchemy.orm import sessionmaker
 from sqlalchemy.ext.declarative import declarative_base
 
-from celery.utils.compat import defaultdict
-
 ResultModelBase = declarative_base()
 
 _SETUP = defaultdict(lambda: False)

+ 2 - 0
celery/execute/__init__.py

@@ -1,3 +1,5 @@
+from __future__ import absolute_import
+
 from celery import current_app
 from celery.utils import deprecated
 

+ 9 - 13
celery/execute/trace.py

@@ -1,3 +1,5 @@
+from __future__ import absolute_import
+
 import sys
 import traceback
 
@@ -32,14 +34,14 @@ class TraceInfo(object):
         """
         try:
             return cls(states.SUCCESS, retval=fun(*args, **kwargs))
-        except (SystemExit, KeyboardInterrupt):
-            raise
         except RetryTaskError, exc:
             return cls(states.RETRY, retval=exc, exc_info=sys.exc_info())
         except Exception, exc:
             if propagate:
                 raise
             return cls(states.FAILURE, retval=exc, exc_info=sys.exc_info())
+        except BaseException, exc:
+            raise
         except:  # pragma: no cover
             # For Python2.5 where raising strings are still allowed
             # (but deprecated)
@@ -93,12 +95,11 @@ class TaskTrace(object):
                                  trace.exc_type, trace.tb, trace.strtb)
         return r
 
-    def handle_after_return(self, status, retval, type_, tb, strtb):
-        einfo = None
+    def handle_after_return(self, status, retval, type_, tb, strtb, einfo=None):
         if status in states.EXCEPTION_STATES:
             einfo = ExceptionInfo((retval, type_, tb))
         self.task.after_return(status, retval, self.task_id,
-                               self.args, self.kwargs, einfo=einfo)
+                               self.args, self.kwargs, einfo)
 
     def handle_success(self, retval, *args):
         """Handle successful execution."""
@@ -107,25 +108,20 @@ class TaskTrace(object):
 
     def handle_retry(self, exc, type_, tb, strtb):
         """Handle retry exception."""
-
         # Create a simpler version of the RetryTaskError that stringifies
         # the original exception instead of including the exception instance.
         # This is for reporting the retry in logs, email etc, while
         # guaranteeing pickleability.
         message, orig_exc = exc.args
         expanded_msg = "%s: %s" % (message, str(orig_exc))
-        einfo = ExceptionInfo((type_,
-                               type_(expanded_msg, None),
-                               tb))
-        self.task.on_retry(exc, self.task_id,
-                           self.args, self.kwargs, einfo=einfo)
+        einfo = ExceptionInfo((type_, type_(expanded_msg, None), tb))
+        self.task.on_retry(exc, self.task_id, self.args, self.kwargs, einfo)
         return einfo
 
     def handle_failure(self, exc, type_, tb, strtb):
         """Handle exception."""
         einfo = ExceptionInfo((type_, exc, tb))
-        self.task.on_failure(exc, self.task_id,
-                             self.args, self.kwargs, einfo=einfo)
+        self.task.on_failure(exc, self.task_id, self.args, self.kwargs, einfo)
         signals.task_failure.send(sender=self.task, task_id=self.task_id,
                                   exception=exc, args=self.args,
                                   kwargs=self.kwargs, traceback=tb,

+ 2 - 0
celery/loaders/__init__.py

@@ -1,3 +1,5 @@
+from __future__ import absolute_import
+
 import os
 
 from celery.utils import get_cls_by_name

+ 6 - 0
celery/loaders/app.py

@@ -1,3 +1,9 @@
+from __future__ import absolute_import
+
+import os
+
+from celery.datastructures import DictAttribute
+from celery.exceptions import ImproperlyConfigured
 from celery.loaders.base import BaseLoader
 
 

+ 10 - 10
celery/loaders/base.py

@@ -1,3 +1,5 @@
+from __future__ import absolute_import
+
 import importlib
 import os
 import re
@@ -11,7 +13,7 @@ from celery.exceptions import ImproperlyConfigured
 from celery.utils import get_cls_by_name
 from celery.utils import import_from_cwd as _import_from_cwd
 
-BUILTIN_MODULES = ["celery.task"]
+BUILTIN_MODULES = frozenset(["celery.task"])
 
 ERROR_ENVVAR_NOT_SET = (
 """The environment variable %r is not set,
@@ -23,7 +25,7 @@ a configuration module.""")
 class BaseLoader(object):
     """The base class for loaders.
 
-    Loaders handles to following things:
+    Loaders handles,
 
         * Reading celery client/worker configurations.
 
@@ -65,14 +67,13 @@ class BaseLoader(object):
         return importlib.import_module(module)
 
     def import_from_cwd(self, module, imp=None):
-        if imp is None:
-            imp = self.import_module
-        return _import_from_cwd(module, imp)
+        return _import_from_cwd(module,
+                self.import_module if imp is None else imp)
 
     def import_default_modules(self):
-        imports = self.conf.get("CELERY_IMPORTS") or []
-        imports = set(list(imports) + BUILTIN_MODULES)
-        return map(self.import_task_module, imports)
+        imports = self.conf.get("CELERY_IMPORTS") or ()
+        imports = set(list(imports)) | BUILTIN_MODULES
+        return [self.import_task_module(module) for module in imports]
 
     def init_worker(self):
         if not self.worker_initialized:
@@ -172,5 +173,4 @@ class BaseLoader(object):
 
     @cached_property
     def mail(self):
-        from celery.utils import mail
-        return mail
+        return self.import_module("celery.utils.mail")

+ 5 - 2
celery/loaders/default.py

@@ -1,10 +1,13 @@
+from __future__ import absolute_import
+
 import os
 import warnings
+
 from importlib import import_module
 
 from celery.datastructures import AttributeDict
-from celery.loaders.base import BaseLoader
 from celery.exceptions import NotConfigured
+from celery.loaders.base import BaseLoader
 
 DEFAULT_CONFIG_MODULE = "celeryconfig"
 
@@ -19,7 +22,7 @@ class Loader(BaseLoader):
         """Read configuration from :file:`celeryconfig.py` and configure
         celery and Django so it can be used by regular Python."""
         configname = os.environ.get("CELERY_CONFIG_MODULE",
-                                    DEFAULT_CONFIG_MODULE)
+                                     DEFAULT_CONFIG_MODULE)
         try:
             celeryconfig = self.import_from_cwd(configname)
         except ImportError:

+ 8 - 0
celery/local.py

@@ -1,3 +1,11 @@
+def try_import(module):
+    from importlib import import_module
+    try:
+        return import_module(module)
+    except ImportError:
+        pass
+
+
 class LocalProxy(object):
     """Code stolen from werkzeug.local.LocalProxy."""
     __slots__ = ('__local', '__dict__', '__name__')

+ 11 - 17
celery/log.py

@@ -1,4 +1,6 @@
 """celery.log"""
+from __future__ import absolute_import
+
 import logging
 import threading
 import sys
@@ -100,37 +102,29 @@ class Logging(object):
         if colorize is None:
             colorize = self.supports_color(logfile)
 
-        if mputil:
-            try:
-                mputil._logger = None
-            except AttributeError:
-                pass
+        if mputil and hasattr(mputil, "_logger"):
+            mputil._logger = None
         ensure_process_aware_logger()
         receivers = signals.setup_logging.send(sender=None,
-                                               loglevel=loglevel,
-                                               logfile=logfile,
-                                               format=format,
-                                               colorize=colorize)
+                        loglevel=loglevel, logfile=logfile,
+                        format=format, colorize=colorize)
         if not receivers:
             root = logging.getLogger()
 
             if self.app.conf.CELERYD_HIJACK_ROOT_LOGGER:
                 root.handlers = []
 
-            mp = mputil and mputil.get_logger() or None
-            for logger in (root, mp):
-                if logger:
-                    self._setup_logger(logger, logfile, format,
-                                       colorize, **kwargs)
-                    logger.setLevel(loglevel)
+            mp = mputil.get_logger() if mputil else None
+            for logger in filter(None, (root, mp)):
+                self._setup_logger(logger, logfile, format, colorize, **kwargs)
+                logger.setLevel(loglevel)
         Logging._setup = True
         return receivers
 
     def _detect_handler(self, logfile=None):
         """Create log handler with either a filename, an open stream
         or :const:`None` (stderr)."""
-        if logfile is None:
-            logfile = sys.__stderr__
+        logfile = sys.__stderr__ if logfile is None else logfile
         if hasattr(logfile, "write"):
             return logging.StreamHandler(logfile)
         return WatchedFileHandler(logfile)

+ 3 - 2
celery/registry.py

@@ -1,4 +1,6 @@
 """celery.registry"""
+from __future__ import absolute_import
+
 import inspect
 
 from celery.exceptions import NotRegistered
@@ -27,8 +29,7 @@ class TaskRegistry(UserDict):
         instance.
 
         """
-        task = inspect.isclass(task) and task() or task
-        self.data[task.name] = task
+        self[task.name] = inspect.isclass(task) and task() or task
 
     def unregister(self, name):
         """Unregister task by name.

+ 4 - 6
celery/result.py

@@ -1,4 +1,4 @@
-from __future__ import generators
+from __future__ import absolute_import
 
 import time
 
@@ -7,10 +7,10 @@ from itertools import imap
 
 from celery import current_app
 from celery import states
+from celery import current_app
 from celery.app import app_or_default
 from celery.exceptions import TimeoutError
 from celery.registry import _unpickle_task
-from celery.utils.compat import any, all
 
 
 def _unpickle_result(task_id, task_name):
@@ -35,10 +35,10 @@ class BaseAsyncResult(object):
     backend = None
 
     def __init__(self, task_id, backend, task_name=None, app=None):
+        self.app = app_or_default(app)
         self.task_id = task_id
         self.backend = backend
         self.task_name = task_name
-        self.app = app_or_default(app)
 
     def forget(self):
         """Forget about (and possibly remove the result of) this task."""
@@ -476,9 +476,7 @@ class TaskSetResult(ResultSet):
     @classmethod
     def restore(self, taskset_id, backend=None):
         """Restore previously saved taskset result."""
-        if backend is None:
-            backend = current_app.backend
-        return backend.restore_taskset(taskset_id)
+        return (backend or current_app.backend).restore_taskset(taskset_id)
 
     def itersubtasks(self):
         """Depreacted.   Use ``iter(self.results)`` instead."""

+ 4 - 6
celery/routes.py

@@ -1,3 +1,5 @@
+from __future__ import absolute_import
+
 from celery.exceptions import QueueNotFound
 from celery.utils import firstmethod, instantiate, lpmerge, mpromise
 
@@ -22,12 +24,8 @@ class Router(object):
             app=None):
         from celery.app import app_or_default
         self.app = app_or_default(app)
-        if queues is None:
-            queues = {}
-        if routes is None:
-            routes = []
-        self.queues = queues
-        self.routes = routes
+        self.queues = {} if queues is None else queues
+        self.routes = [] if routes is None else routes
         self.create_missing = create_missing
 
     def route(self, options, task, args=(), kwargs={}):

+ 4 - 4
celery/schedules.py

@@ -1,3 +1,5 @@
+from __future__ import absolute_import
+
 from datetime import datetime, timedelta
 from dateutil.relativedelta import relativedelta
 from pyparsing import (Word, Literal, ZeroOrMore, Optional,
@@ -256,12 +258,11 @@ class crontab(schedule):
     def remaining_estimate(self, last_run_at):
         """Returns when the periodic task should run next as a timedelta."""
         weekday = last_run_at.isoweekday()
-        if weekday == 7:    # Sunday is day 0, not day 7.
-            weekday = 0
+        weekday = 0 if weekday == 7 else weekday  # Sunday is day 0, not day 7.
 
         execute_this_hour = (weekday in self.day_of_week and
                                 last_run_at.hour in self.hour and
-                                last_run_at.minute < max(self.minute))
+                                    last_run_at.minute < max(self.minute))
 
         if execute_this_hour:
             next_minute = min(minute for minute in self.minute
@@ -271,7 +272,6 @@ class crontab(schedule):
                                   microsecond=0)
         else:
             next_minute = min(self.minute)
-
             execute_today = (weekday in self.day_of_week and
                                  last_run_at.hour < max(self.hour))
 

+ 3 - 4
celery/task/base.py

@@ -626,7 +626,7 @@ class BaseTask(object):
             task_id = self.request.id
         self.backend.store_result(task_id, meta, state)
 
-    def on_retry(self, exc, task_id, args, kwargs, einfo=None):
+    def on_retry(self, exc, task_id, args, kwargs, einfo):
         """Retry handler.
 
         This is run by the worker when the task is to be retried.
@@ -644,8 +644,7 @@ class BaseTask(object):
         """
         pass
 
-    def after_return(self, status, retval, task_id, args,
-            kwargs, einfo=None):
+    def after_return(self, status, retval, task_id, args, kwargs, einfo):
         """Handler called after the task returns.
 
         :param status: Current task state.
@@ -664,7 +663,7 @@ class BaseTask(object):
         if self.request.chord:
             self.backend.on_chord_part_return(self)
 
-    def on_failure(self, exc, task_id, args, kwargs, einfo=None):
+    def on_failure(self, exc, task_id, args, kwargs, einfo):
         """Error handler.
 
         This is run by the worker when the task fails.

+ 5 - 1
celery/task/http.py

@@ -1,12 +1,16 @@
 import urllib2
+
 from urllib import urlencode
 from urlparse import urlparse
+try:
+    from urlparse import parse_qsl
+except ImportError:
+    from cgi import parse_qsl
 
 from anyjson import deserialize
 
 from celery import __version__ as celery_version
 from celery.task.base import Task as BaseTask
-from celery.utils.compat import parse_qsl
 
 GET_METHODS = frozenset(["GET", "HEAD"])
 

+ 10 - 13
celery/tests/test_app/test_loaders.py

@@ -1,3 +1,5 @@
+from __future__ import with_statement
+
 import os
 import sys
 
@@ -11,7 +13,7 @@ from celery.loaders.app import AppLoader
 
 from celery.tests.compat import catch_warnings
 from celery.tests.utils import unittest
-from celery.tests.utils import with_environ, execute_context
+from celery.tests.utils import with_environ
 
 
 class ObjectConfig(object):
@@ -127,17 +129,15 @@ class TestLoaderBase(unittest.TestCase):
         MockMail.Mailer.raise_on_send = True
         opts = dict(self.message_options, **self.server_options)
 
-        def with_catch_warnings(log):
+        with catch_warnings(record=True) as log:
             self.loader.mail_admins(fail_silently=True, **opts)
-            return log[0].message
+            warning = log[0].message
 
-        warning = execute_context(catch_warnings(record=True),
-                                  with_catch_warnings)
-        self.assertIsInstance(warning, MockMail.SendmailWarning)
-        self.assertIn("KeyError", warning.args[0])
+            self.assertIsInstance(warning, MockMail.SendmailWarning)
+            self.assertIn("KeyError", warning.args[0])
 
-        self.assertRaises(KeyError, self.loader.mail_admins,
-                          fail_silently=False, **opts)
+            self.assertRaises(KeyError, self.loader.mail_admins,
+                              fail_silently=False, **opts)
 
     def test_mail_admins(self):
         MockMail.Mailer.raise_on_send = False
@@ -214,13 +214,10 @@ class TestDefaultLoader(unittest.TestCase):
             def import_from_cwd(self, name):
                 raise ImportError(name)
 
-        def with_catch_warnings(log):
+        with catch_warnings(record=True) as log:
             l = _Loader()
             self.assertDictEqual(l.conf, {})
             context_executed[0] = True
-
-        context = catch_warnings(record=True)
-        execute_context(context, with_catch_warnings)
         self.assertTrue(context_executed[0])
 
 

+ 4 - 5
celery/tests/test_backends/test_database.py

@@ -1,3 +1,5 @@
+from __future__ import with_statement
+
 import sys
 
 from datetime import datetime
@@ -10,7 +12,7 @@ from celery.exceptions import ImproperlyConfigured
 from celery.result import AsyncResult
 from celery.utils import gen_unique_id
 
-from celery.tests.utils import execute_context, mask_modules
+from celery.tests.utils import mask_modules
 from celery.tests.utils import unittest
 
 try:
@@ -39,13 +41,10 @@ class test_DatabaseBackend(unittest.TestCase):
             raise SkipTest("sqlalchemy not installed")
 
     def test_missing_SQLAlchemy_raises_ImproperlyConfigured(self):
-
-        def with_SQLAlchemy_masked(_val):
+        with mask_modules("sqlalchemy"):
             from celery.backends.database import _sqlalchemy_installed
             self.assertRaises(ImproperlyConfigured, _sqlalchemy_installed)
 
-        execute_context(mask_modules("sqlalchemy"), with_SQLAlchemy_masked)
-
     def test_pickle_hack_for_sqla_05(self):
         import sqlalchemy as sa
         from celery.db import session

+ 4 - 5
celery/tests/test_backends/test_redis.py

@@ -1,3 +1,5 @@
+from __future__ import with_statement
+
 import sys
 import socket
 from celery.tests.utils import unittest
@@ -11,7 +13,7 @@ from celery.utils import gen_unique_id
 from celery.backends import pyredis
 from celery.backends.pyredis import RedisBackend
 
-from celery.tests.utils import execute_context, mask_modules
+from celery.tests.utils import mask_modules
 
 _no_redis_msg = "* Redis %s. Will not execute related tests."
 _no_redis_msg_emitted = False
@@ -108,12 +110,9 @@ class TestRedisBackendNoRedis(unittest.TestCase):
     def test_redis_None_if_redis_not_installed(self):
         prev = sys.modules.pop("celery.backends.pyredis")
         try:
-
-            def with_redis_masked(_val):
+            with mask_modules("redis"):
                 from celery.backends.pyredis import redis
                 self.assertIsNone(redis)
-            context = mask_modules("redis")
-            execute_context(context, with_redis_masked)
         finally:
             sys.modules["celery.backends.pyredis"] = prev
 

+ 2 - 1
celery/tests/test_bin/test_celerybeat.py

@@ -1,6 +1,8 @@
 import logging
 import sys
 
+from collections import defaultdict
+
 from kombu.tests.utils import redirect_stdouts
 
 from celery import beat
@@ -8,7 +10,6 @@ from celery import platforms
 from celery.app import app_or_default
 from celery.bin import celerybeat as celerybeat_bin
 from celery.apps import beat as beatapp
-from celery.utils.compat import defaultdict
 
 from celery.tests.utils import AppCase
 

+ 5 - 5
celery/tests/test_bin/test_celeryd.py

@@ -1,3 +1,5 @@
+from __future__ import with_statement
+
 import logging
 import os
 import sys
@@ -9,6 +11,7 @@ try:
 except ImportError:
     current_process = None  # noqa
 
+from functools import wraps
 
 from nose import SkipTest
 from kombu.tests.utils import redirect_stdouts
@@ -24,7 +27,7 @@ from celery.exceptions import ImproperlyConfigured
 from celery.utils import patch
 
 from celery.tests.compat import catch_warnings
-from celery.tests.utils import AppCase, execute_context, StringIO
+from celery.tests.utils import AppCase, StringIO
 
 
 patch.ensure_process_aware_logger()
@@ -187,15 +190,12 @@ class test_Worker(AppCase):
 
         prev, os.geteuid = os.geteuid, geteuid
         try:
-
-            def with_catch_warnings(log):
+            with catch_warnings(record=True) as log:
                 worker = self.Worker()
                 worker.run()
                 self.assertTrue(log)
                 self.assertIn("superuser privileges is not encouraged",
                               log[0].message.args[0])
-            context = catch_warnings(record=True)
-            execute_context(context, with_catch_warnings)
         finally:
             os.geteuid = prev
 

+ 4 - 5
celery/tests/test_compat/test_decorators.py

@@ -1,10 +1,11 @@
+from __future__ import with_statement
+
 import warnings
 
 from celery.task import base
 
 from celery.tests.compat import catch_warnings
 from celery.tests.utils import unittest
-from celery.tests.utils import execute_context
 
 
 def add(x, y):
@@ -16,11 +17,9 @@ class test_decorators(unittest.TestCase):
     def setUp(self):
         warnings.resetwarnings()
 
-        def with_catch_warnings(log):
+        with catch_warnings(record=True):
             from celery import decorators
-            return decorators
-        context = catch_warnings(record=True)
-        self.decorators = execute_context(context, with_catch_warnings)
+            self.decorators = decorators
 
     def assertCompatDecorator(self, decorator, type, **opts):
         task = decorator(**opts)(add)

+ 7 - 45
celery/tests/test_compat/test_log.py

@@ -1,4 +1,4 @@
-from __future__ import generators
+from __future__ import with_statement
 
 import sys
 import logging
@@ -10,35 +10,10 @@ from celery import log
 from celery.log import (setup_logger, setup_task_logger,
                         get_default_logger, get_task_logger,
                         redirect_stdouts_to_logger, LoggingProxy)
-from celery.tests.utils import contextmanager
-from celery.tests.utils import override_stdouts, execute_context
 from celery.utils import gen_unique_id
-from celery.utils.compat import LoggerAdapter
 from celery.utils.compat import _CompatLoggerAdapter
-
-
-def get_handlers(logger):
-    if isinstance(logger, LoggerAdapter):
-        return logger.logger.handlers
-    return logger.handlers
-
-
-def set_handlers(logger, new_handlers):
-    if isinstance(logger, LoggerAdapter):
-        logger.logger.handlers = new_handlers
-    logger.handlers = new_handlers
-
-
-@contextmanager
-def wrap_logger(logger, loglevel=logging.ERROR):
-    old_handlers = get_handlers(logger)
-    sio = StringIO()
-    siohandler = logging.StreamHandler(sio)
-    set_handlers(logger, [siohandler])
-
-    yield sio
-
-    set_handlers(logger, old_handlers)
+from celery.tests.utils import (override_stdouts, wrap_logger,
+                                get_handlers, set_handlers)
 
 
 class test_default_logger(unittest.TestCase):
@@ -50,13 +25,10 @@ class test_default_logger(unittest.TestCase):
 
     def _assertLog(self, logger, logmsg, loglevel=logging.ERROR):
 
-        def with_wrap_logger(sio):
+        with wrap_logger(logger, loglevel=loglevel) as sio:
             logger.log(loglevel, logmsg)
             return sio.getvalue().strip()
 
-        context = wrap_logger(logger, loglevel=loglevel)
-        execute_context(context, with_wrap_logger)
-
     def assertDidLogTrue(self, logger, logmsg, reason, loglevel=None):
         val = self._assertLog(logger, logmsg, loglevel=loglevel)
         return self.assertEqual(val, logmsg, reason)
@@ -81,16 +53,13 @@ class test_default_logger(unittest.TestCase):
         l = self.get_logger()
         set_handlers(l, [])
 
-        def with_override_stdouts(outs):
+        with override_stdouts() as outs:
             stdout, stderr = outs
             l = self.setup_logger(logfile=stderr, loglevel=logging.INFO,
                                   root=False)
             l.info("The quick brown fox...")
             self.assertIn("The quick brown fox...", stderr.getvalue())
 
-        context = override_stdouts()
-        execute_context(context, with_override_stdouts)
-
     def test_setup_logger_no_handlers_file(self):
         l = self.get_logger()
         set_handlers(l, [])
@@ -103,14 +72,10 @@ class test_default_logger(unittest.TestCase):
         logger = self.setup_logger(loglevel=logging.ERROR, logfile=None,
                                    root=False)
         try:
-
-            def with_wrap_logger(sio):
+            with wrap_logger(logger) as sio:
                 redirect_stdouts_to_logger(logger, loglevel=logging.ERROR)
                 logger.error("foo")
                 self.assertIn("foo", sio.getvalue())
-
-            context = wrap_logger(logger)
-            execute_context(context, with_wrap_logger)
         finally:
             sys.stdout, sys.stderr = sys.__stdout__, sys.__stderr__
 
@@ -118,7 +83,7 @@ class test_default_logger(unittest.TestCase):
         logger = self.setup_logger(loglevel=logging.ERROR, logfile=None,
                                    root=False)
 
-        def with_wrap_logger(sio):
+        with wrap_logger(logger) as sio:
             p = LoggingProxy(logger, loglevel=logging.ERROR)
             p.close()
             p.write("foo")
@@ -135,9 +100,6 @@ class test_default_logger(unittest.TestCase):
             self.assertFalse(p.isatty())
             self.assertIsNone(p.fileno())
 
-        context = wrap_logger(logger)
-        execute_context(context, with_wrap_logger)
-
 
 class test_task_logger(test_default_logger):
 

+ 0 - 2
celery/tests/test_slow/test_buckets.py

@@ -1,5 +1,3 @@
-from __future__ import generators
-
 import sys
 import time
 

+ 1 - 5
celery/tests/test_task/test_result.py

@@ -1,16 +1,12 @@
-from __future__ import generators
-
-from celery.tests.utils import unittest
-
 from celery import states
 from celery.app import app_or_default
 from celery.utils import gen_unique_id
-from celery.utils.compat import all
 from celery.utils.serialization import pickle
 from celery.result import AsyncResult, EagerResult, TaskSetResult
 from celery.exceptions import TimeoutError
 from celery.task.base import Task
 
+from celery.tests.utils import unittest
 from celery.tests.utils import skip_if_quick
 
 

+ 14 - 15
celery/tests/test_task/test_task_builtins.py

@@ -1,8 +1,10 @@
+from __future__ import with_statement
+
 import warnings
 
 from celery.task import ping, PingTask, backend_cleanup
 from celery.tests.compat import catch_warnings
-from celery.tests.utils import unittest, execute_context
+from celery.tests.utils import unittest
 
 
 def some_func(i):
@@ -14,32 +16,29 @@ class test_deprecated(unittest.TestCase):
     def test_ping(self):
         warnings.resetwarnings()
 
-        def block(log):
+        with catch_warnings(record=True) as log:
             prev = PingTask.app.conf.CELERY_ALWAYS_EAGER
             PingTask.app.conf.CELERY_ALWAYS_EAGER = True
             try:
-                return ping(), log[0].message
+                pong = ping()
+                warning = log[0].message
+                self.assertEqual(pong, "pong")
+                self.assertIsInstance(warning, DeprecationWarning)
+                self.assertIn("ping task has been deprecated",
+                              warning.args[0])
             finally:
                 PingTask.app.conf.CELERY_ALWAYS_EAGER = prev
 
-        pong, warning = execute_context(catch_warnings(record=True), block)
-        self.assertEqual(pong, "pong")
-        self.assertIsInstance(warning, DeprecationWarning)
-        self.assertIn("ping task has been deprecated",
-                      warning.args[0])
-
     def test_TaskSet_import_from_task_base(self):
         warnings.resetwarnings()
 
-        def block(log):
+        with catch_warnings(record=True) as log:
             from celery.task.base import TaskSet, subtask
             TaskSet()
             subtask(PingTask)
-            return log[0].message, log[1].message
-
-        for w in execute_context(catch_warnings(record=True), block):
-            self.assertIsInstance(w, DeprecationWarning)
-            self.assertIn("is deprecated", w.args[0])
+            for w in (log[0].message, log[1].message):
+                self.assertIsInstance(w, DeprecationWarning)
+                self.assertIn("is deprecated", w.args[0])
 
 
 class test_backend_cleanup(unittest.TestCase):

+ 11 - 37
celery/tests/test_task/test_task_http.py

@@ -1,8 +1,9 @@
 # -*- coding: utf-8 -*-
-from __future__ import generators
+from __future__ import with_statement
 
 import logging
 
+from contextlib import contextmanager
 from functools import wraps
 try:
     from urllib import addinfourl
@@ -12,8 +13,7 @@ except ImportError:  # py3k
 from anyjson import serialize
 
 from celery.task import http
-from celery.tests.utils import unittest
-from celery.tests.utils import execute_context, contextmanager, StringIO
+from celery.tests.utils import unittest, StringIO
 
 
 @contextmanager
@@ -98,94 +98,68 @@ class TestHttpDispatch(unittest.TestCase):
     def test_dispatch_success(self):
         logger = logging.getLogger("celery.unittest")
 
-        def with_mock_urlopen(_val):
+        with mock_urlopen(success_response(100)):
             d = http.HttpDispatch("http://example.com/mul", "GET", {
                                     "x": 10, "y": 10}, logger)
             self.assertEqual(d.dispatch(), 100)
 
-        context = mock_urlopen(success_response(100))
-        execute_context(context, with_mock_urlopen)
-
     def test_dispatch_failure(self):
         logger = logging.getLogger("celery.unittest")
 
-        def with_mock_urlopen(_val):
+        with mock_urlopen(fail_response("Invalid moon alignment")):
             d = http.HttpDispatch("http://example.com/mul", "GET", {
                                     "x": 10, "y": 10}, logger)
             self.assertRaises(http.RemoteExecuteError, d.dispatch)
 
-        context = mock_urlopen(fail_response("Invalid moon alignment"))
-        execute_context(context, with_mock_urlopen)
-
     def test_dispatch_empty_response(self):
         logger = logging.getLogger("celery.unittest")
 
-        def with_mock_urlopen(_val):
+        with mock_urlopen(_response("")):
             d = http.HttpDispatch("http://example.com/mul", "GET", {
                                     "x": 10, "y": 10}, logger)
             self.assertRaises(http.InvalidResponseError, d.dispatch)
 
-        context = mock_urlopen(_response(""))
-        execute_context(context, with_mock_urlopen)
-
     def test_dispatch_non_json(self):
         logger = logging.getLogger("celery.unittest")
 
-        def with_mock_urlopen(_val):
+        with mock_urlopen(_response("{'#{:'''")):
             d = http.HttpDispatch("http://example.com/mul", "GET", {
                                     "x": 10, "y": 10}, logger)
             self.assertRaises(http.InvalidResponseError, d.dispatch)
 
-        context = mock_urlopen(_response("{'#{:'''"))
-        execute_context(context, with_mock_urlopen)
-
     def test_dispatch_unknown_status(self):
         logger = logging.getLogger("celery.unittest")
 
-        def with_mock_urlopen(_val):
+        with mock_urlopen(unknown_response()):
             d = http.HttpDispatch("http://example.com/mul", "GET", {
                                     "x": 10, "y": 10}, logger)
             self.assertRaises(http.UnknownStatusError, d.dispatch)
 
-        context = mock_urlopen(unknown_response())
-        execute_context(context, with_mock_urlopen)
-
     def test_dispatch_POST(self):
         logger = logging.getLogger("celery.unittest")
 
-        def with_mock_urlopen(_val):
+        with mock_urlopen(success_response(100)):
             d = http.HttpDispatch("http://example.com/mul", "POST", {
                                     "x": 10, "y": 10}, logger)
             self.assertEqual(d.dispatch(), 100)
 
-        context = mock_urlopen(success_response(100))
-        execute_context(context, with_mock_urlopen)
-
 
 class TestURL(unittest.TestCase):
 
     def test_URL_get_async(self):
         http.HttpDispatchTask.app.conf.CELERY_ALWAYS_EAGER = True
         try:
-
-            def with_mock_urlopen(_val):
+            with mock_urlopen(success_response(100)):
                 d = http.URL("http://example.com/mul").get_async(x=10, y=10)
                 self.assertEqual(d.get(), 100)
-
-            context = mock_urlopen(success_response(100))
-            execute_context(context, with_mock_urlopen)
         finally:
             http.HttpDispatchTask.app.conf.CELERY_ALWAYS_EAGER = False
 
     def test_URL_post_async(self):
         http.HttpDispatchTask.app.conf.CELERY_ALWAYS_EAGER = True
         try:
-
-            def with_mock_urlopen(_val):
+            with mock_urlopen(success_response(100)):
                 d = http.URL("http://example.com/mul").post_async(x=10, y=10)
                 self.assertEqual(d.get(), 100)
-
-            context = mock_urlopen(success_response(100))
-            execute_context(context, with_mock_urlopen)
         finally:
             http.HttpDispatchTask.app.conf.CELERY_ALWAYS_EAGER = False

+ 5 - 11
celery/tests/test_task/test_task_sets.py

@@ -1,3 +1,5 @@
+from __future__ import with_statement
+
 import anyjson
 import warnings
 
@@ -6,7 +8,6 @@ from celery.task import Task
 from celery.task.sets import subtask, TaskSet
 
 from celery.tests.utils import unittest
-from celery.tests.utils import execute_context
 from celery.tests.compat import catch_warnings
 
 
@@ -93,7 +94,7 @@ class test_TaskSet(unittest.TestCase):
     def test_interface__compat(self):
         warnings.resetwarnings()
 
-        def with_catch_warnings(log):
+        with catch_warnings(record=True) as log:
             ts = TaskSet(MockTask, [[(2, 2)], [(4, 4)], [(8, 8)]])
             self.assertTrue(log)
             self.assertIn("Using this invocation of TaskSet is deprecated",
@@ -103,29 +104,22 @@ class test_TaskSet(unittest.TestCase):
                                     for i in (2, 4, 8)])
             return ts
 
-        context = catch_warnings(record=True)
-        execute_context(context, with_catch_warnings)
-
         # TaskSet.task (deprecated)
-        def with_catch_warnings2(log):
+        with catch_warnings(record=True) as log:
             ts = TaskSet(MockTask, [[(2, 2)], [(4, 4)], [(8, 8)]])
             self.assertEqual(ts.task.name, MockTask.name)
             self.assertTrue(log)
             self.assertIn("TaskSet.task is deprecated",
                           log[0].message.args[0])
 
-        execute_context(catch_warnings(record=True), with_catch_warnings2)
-
         # TaskSet.task_name (deprecated)
-        def with_catch_warnings3(log):
+        with catch_warnings(record=True) as log:
             ts = TaskSet(MockTask, [[(2, 2)], [(4, 4)], [(8, 8)]])
             self.assertEqual(ts.task_name, MockTask.name)
             self.assertTrue(log)
             self.assertIn("TaskSet.task_name is deprecated",
                           log[0].message.args[0])
 
-        execute_context(catch_warnings(record=True), with_catch_warnings3)
-
     def test_task_arg_can_be_iterable__compat(self):
         ts = TaskSet([MockTask.subtask((i, i))
                         for i in (2, 4, 8)])

+ 5 - 8
celery/tests/test_utils/test_serialization.py

@@ -1,7 +1,9 @@
+from __future__ import with_statement
+
 import sys
-from celery.tests.utils import unittest
 
-from celery.tests.utils import execute_context, mask_modules
+from celery.tests.utils import unittest
+from celery.tests.utils import mask_modules
 
 
 class TestAAPickle(unittest.TestCase):
@@ -9,14 +11,9 @@ class TestAAPickle(unittest.TestCase):
     def test_no_cpickle(self):
         prev = sys.modules.pop("celery.utils.serialization", None)
         try:
-
-            def with_cPickle_masked(_val):
+            with mask_modules("cPickle"):
                 from celery.utils.serialization import pickle
                 import pickle as orig_pickle
                 self.assertIs(pickle.dumps, orig_pickle.dumps)
-
-            context = mask_modules("cPickle")
-            execute_context(context, with_cPickle_masked)
-
         finally:
             sys.modules["celery.utils.serialization"] = prev

+ 4 - 4
celery/tests/test_worker/test_worker.py

@@ -1,3 +1,5 @@
+from __future__ import with_statement
+
 import socket
 import sys
 
@@ -25,7 +27,7 @@ from celery.utils.timer2 import Timer
 
 from celery.tests.compat import catch_warnings
 from celery.tests.utils import unittest
-from celery.tests.utils import AppCase, execute_context, skip
+from celery.tests.utils import AppCase, skip
 
 
 class PlaceHolder(object):
@@ -253,13 +255,11 @@ class test_Consumer(unittest.TestCase):
         l.event_dispatcher = Mock()
         l.pidbox_node = MockNode()
 
-        def with_catch_warnings(log):
+        with catch_warnings(record=True) as log:
             l.receive_message(m.decode(), m)
             self.assertTrue(log)
             self.assertIn("unknown message", log[0].message.args[0])
 
-        context = catch_warnings(record=True)
-        execute_context(context, with_catch_warnings)
 
     @patch("celery.utils.timer2.to_timestamp")
     def test_receive_message_eta_OverflowError(self, to_timestamp):

+ 6 - 13
celery/tests/test_worker/test_worker_job.py

@@ -1,4 +1,6 @@
 # -*- coding: utf-8 -*-
+from __future__ import with_statement
+
 import anyjson
 import logging
 import os
@@ -27,7 +29,7 @@ from celery.worker.state import revoked
 
 from celery.tests.compat import catch_warnings
 from celery.tests.utils import unittest
-from celery.tests.utils import execute_context, StringIO, wrap_logger
+from celery.tests.utils import StringIO, wrap_logger
 
 
 scratch = {"ACK": False}
@@ -89,19 +91,14 @@ class test_WorkerTaskTrace(unittest.TestCase):
         mytask.backend.process_cleanup = Mock(side_effect=KeyError())
         try:
 
-            def with_wrap_logger(sio):
+            logger = mytask.app.log.get_default_logger()
+            with wrap_logger(logger) as sio:
                 uuid = gen_unique_id()
                 ret = jail(uuid, mytask.name, [2], {})
                 self.assertEqual(ret, 4)
                 mytask.backend.mark_as_done.assert_called_with(uuid, 4)
                 logs = sio.getvalue().strip()
                 self.assertIn("Process cleanup failed", logs)
-                return 1234
-
-            logger = mytask.app.log.get_default_logger()
-            self.assertEqual(execute_context(
-                    wrap_logger(logger), with_wrap_logger), 1234)
-
         finally:
             mytask.backend = backend
 
@@ -418,17 +415,13 @@ class test_TaskRequest(unittest.TestCase):
 
         WorkerTaskTrace.execute = _error_exec
         try:
-
-            def with_catch_warnings(log):
+            with catch_warnings(record=True) as log:
                 res = execute_and_trace(mytask.name, gen_unique_id(),
                                         [4], {})
                 self.assertIsInstance(res, ExceptionInfo)
                 self.assertTrue(log)
                 self.assertIn("Exception outside", log[0].message.args[0])
                 self.assertIn("KeyError", log[0].message.args[0])
-
-            context = catch_warnings(record=True)
-            execute_context(context, with_catch_warnings)
         finally:
             WorkerTaskTrace.execute = old_exec
 

+ 2 - 59
celery/tests/utils.py

@@ -1,5 +1,3 @@
-from __future__ import generators
-
 try:
     import unittest
     unittest.skip
@@ -17,20 +15,15 @@ except ImportError:  # py3k
     import builtins  # noqa
 
 from functools import wraps
+from contextlib import contextmanager
 
-from celery.utils.compat import StringIO, LoggerAdapter
-try:
-    from contextlib import contextmanager
-except ImportError:
-    from celery.tests.utils import fallback_contextmanager
-    contextmanager = fallback_contextmanager  # noqa
 
 import mock
-
 from nose import SkipTest
 
 from celery.app import app_or_default
 from celery.utils import noop
+from celery.utils.compat import StringIO, LoggerAdapter
 
 
 class Mock(mock.Mock):
@@ -60,37 +53,6 @@ class AppCase(unittest.TestCase):
         pass
 
 
-class GeneratorContextManager(object):
-    def __init__(self, gen):
-        self.gen = gen
-
-    def __enter__(self):
-        try:
-            return self.gen.next()
-        except StopIteration:
-            raise RuntimeError("generator didn't yield")
-
-    def __exit__(self, type, value, traceback):
-        if type is None:
-            try:
-                self.gen.next()
-            except StopIteration:
-                return
-            else:
-                raise RuntimeError("generator didn't stop")
-        else:
-            try:
-                self.gen.throw(type, value, traceback)
-                raise RuntimeError("generator didn't stop after throw()")
-            except StopIteration:
-                return True
-            except AttributeError:
-                raise value
-            except:
-                if sys.exc_info()[1] is not value:
-                    raise
-
-
 def get_handlers(logger):
     if isinstance(logger, LoggerAdapter):
         return logger.logger.handlers
@@ -115,25 +77,6 @@ def wrap_logger(logger, loglevel=logging.ERROR):
     set_handlers(logger, old_handlers)
 
 
-def fallback_contextmanager(fun):
-    def helper(*args, **kwds):
-        return GeneratorContextManager(fun(*args, **kwds))
-    return helper
-
-
-def execute_context(context, fun):
-    val = context.__enter__()
-    exc_info = (None, None, None)
-    try:
-        try:
-            return fun(val)
-        except:
-            exc_info = sys.exc_info()
-            raise
-    finally:
-        context.__exit__(*exc_info)
-
-
 @contextmanager
 def eager_tasks():
     app = app_or_default()

+ 0 - 3
celery/utils/__init__.py

@@ -1,5 +1,3 @@
-from __future__ import generators
-
 import os
 import sys
 import operator
@@ -19,7 +17,6 @@ from kombu.utils import rpartition
 
 from celery.utils.compat import StringIO
 
-
 LOG_LEVELS = dict(logging._levelNames)
 LOG_LEVELS["FATAL"] = logging.FATAL
 LOG_LEVELS[logging.FATAL] = "FATAL"

+ 1 - 99
celery/utils/compat.py

@@ -1,5 +1,3 @@
-from __future__ import generators
-
 ############## py3k #########################################################
 try:
     from UserList import UserList       # noqa
@@ -19,39 +17,6 @@ except ImportError:
     except ImportError:
         from io import StringIO         # noqa
 
-############## urlparse.parse_qsl ###########################################
-
-try:
-    from urlparse import parse_qsl
-except ImportError:
-    from cgi import parse_qsl  # noqa
-
-############## __builtin__.all ##############################################
-
-try:
-    all([True])
-    all = all
-except NameError:
-
-    def all(iterable):
-        for item in iterable:
-            if not item:
-                return False
-        return True
-
-############## __builtin__.any ##############################################
-
-try:
-    any([True])
-    any = any
-except NameError:
-
-    def any(iterable):
-        for item in iterable:
-            if item:
-                return True
-        return False
-
 ############## collections.OrderedDict ######################################
 
 import weakref
@@ -270,53 +235,6 @@ try:
 except ImportError:
     OrderedDict = CompatOrderedDict  # noqa
 
-############## collections.defaultdict ######################################
-
-try:
-    from collections import defaultdict
-except ImportError:
-    # Written by Jason Kirtland, taken from Python Cookbook:
-    # <http://aspn.activestate.com/ASPN/Cookbook/Python/Recipe/523034>
-    class defaultdict(dict):  # noqa
-
-        def __init__(self, default_factory=None, *args, **kwargs):
-            dict.__init__(self, *args, **kwargs)
-            self.default_factory = default_factory
-
-        def __getitem__(self, key):
-            try:
-                return dict.__getitem__(self, key)
-            except KeyError:
-                return self.__missing__(key)
-
-        def __missing__(self, key):
-            if self.default_factory is None:
-                raise KeyError(key)
-            self[key] = value = self.default_factory()
-            return value
-
-        def __reduce__(self):
-            f = self.default_factory
-            args = f is None and tuple() or f
-            return type(self), args, None, None, self.iteritems()
-
-        def copy(self):
-            return self.__copy__()
-
-        def __copy__(self):
-            return type(self)(self.default_factory, self)
-
-        def __deepcopy__(self):
-            import copy
-            return type(self)(self.default_factory,
-                        copy.deepcopy(self.items()))
-
-        def __repr__(self):
-            return "defaultdict(%s, %s)" % (self.default_factory,
-                                            dict.__repr__(self))
-    import collections
-    collections.defaultdict = defaultdict           # Pickle needs this.
-
 ############## logging.LoggerAdapter ########################################
 import inspect
 import logging
@@ -326,16 +244,6 @@ except ImportError:
     multiprocessing = None  # noqa
 import sys
 
-from logging import LogRecord
-
-log_takes_extra = "extra" in inspect.getargspec(logging.Logger._log)[0]
-
-# The func argument to LogRecord was added in 2.5
-if "func" not in inspect.getargspec(LogRecord.__init__)[0]:
-
-    def LogRecord(name, level, fn, lno, msg, args, exc_info, func):
-        return logging.LogRecord(name, level, fn, lno, msg, args, exc_info)
-
 
 def _checkLevel(level):
     if isinstance(level, int):
@@ -390,7 +298,7 @@ class _CompatLoggerAdapter(object):
 
     def makeRecord(self, name, level, fn, lno, msg, args, exc_info,
             func=None, extra=None):
-        rv = LogRecord(name, level, fn, lno, msg, args, exc_info, func)
+        rv = logging.LogRecord(name, level, fn, lno, msg, args, exc_info, func)
         if extra is not None:
             for key, value in extra.items():
                 if key in ("message", "asctime") or key in rv.__dict__:
@@ -441,12 +349,6 @@ try:
 except ImportError:
     LoggerAdapter = _CompatLoggerAdapter  # noqa
 
-
-def log_with_extra(logger, level, msg, *args, **kwargs):
-    if not log_takes_extra:
-        kwargs.pop("extra", None)
-    return logger.log(level, msg, *args, **kwargs)
-
 ############## itertools.izip_longest #######################################
 
 try:

+ 2 - 8
celery/utils/serialization.py

@@ -11,7 +11,7 @@ except ImportError:
     cpickle = None  # noqa
 
 if sys.version_info < (2, 6):  # pragma: no cover
-    # cPickle is broken in Python <= 2.5.
+    # cPickle is broken in Python <= 2.6.
     # It unsafely and incorrectly uses relative instead of absolute imports,
     # so e.g.:
     #       exceptions.KeyError
@@ -26,14 +26,8 @@ else:
     pickle = cpickle or pypickle
 
 
-# BaseException was introduced in Python 2.5.
-try:
-    _error_bases = (BaseException, )
-except NameError:  # pragma: no cover
-    _error_bases = (SystemExit, KeyboardInterrupt)
-
 #: List of base classes we probably don't want to reduce to.
-unwanted_base_classes = (StandardError, Exception) + _error_bases + (object, )
+unwanted_base_classes = (StandardError, Exception, BaseException, object)
 
 
 if sys.version_info < (2, 5):  # pragma: no cover

+ 0 - 3
celery/utils/timer2.py

@@ -1,7 +1,4 @@
 """timer2 - Scheduler for Python functions."""
-
-from __future__ import generators
-
 import atexit
 import heapq
 import logging

+ 1 - 1
celery/worker/buckets.py

@@ -6,7 +6,7 @@ from Queue import Queue, Empty
 
 from celery.datastructures import TokenBucket
 from celery.utils import timeutils
-from celery.utils.compat import all, izip_longest, chain_from_iterable
+from celery.utils.compat import izip_longest, chain_from_iterable
 
 
 class RateLimitExceeded(Exception):

+ 4 - 19
celery/worker/consumer.py

@@ -67,9 +67,6 @@ up and running.
   early, *then* close the connection.
 
 """
-
-from __future__ import generators
-
 import socket
 import sys
 import threading
@@ -139,14 +136,11 @@ class QoS(object):
 
     def increment(self, n=1):
         """Increment the current prefetch count value by one."""
-        self._mutex.acquire()
-        try:
+        with self._mutex:
             if self.value:
                 new_value = self.value + max(n, 0)
                 self.value = self.set(new_value)
             return self.value
-        finally:
-            self._mutex.release()
 
     def _sub(self, n=1):
         assert self.value - n > 1
@@ -154,14 +148,11 @@ class QoS(object):
 
     def decrement(self, n=1):
         """Decrement the current prefetch count value by one."""
-        self._mutex.acquire()
-        try:
+        with self._mutex:
             if self.value:
                 self._sub(n)
                 self.set(self.value)
             return self.value
-        finally:
-            self._mutex.release()
 
     def decrement_eventually(self, n=1):
         """Decrement the value, but do not update the qos.
@@ -170,12 +161,9 @@ class QoS(object):
         when necessary.
 
         """
-        self._mutex.acquire()
-        try:
+        with self._mutex:
             if self.value:
                 self._sub(n)
-        finally:
-            self._mutex.release()
 
     def set(self, pcount):
         """Set channel prefetch_count setting."""
@@ -193,11 +181,8 @@ class QoS(object):
 
     def update(self):
         """Update prefetch count with current value."""
-        self._mutex.acquire()
-        try:
+        with self._mutex:
             return self.set(self.value)
-        finally:
-            self._mutex.release()
 
 
 class Consumer(object):

+ 5 - 8
celery/worker/job.py

@@ -1,4 +1,3 @@
-import logging
 import os
 import sys
 import time
@@ -17,7 +16,6 @@ from celery.execute.trace import TaskTrace
 from celery.registry import tasks
 from celery.utils import noop, kwdict, fun_takes_kwargs
 from celery.utils import get_symbol_by_name, truncate_text
-from celery.utils.compat import log_with_extra
 from celery.utils.encoding import safe_repr, safe_str
 from celery.utils.timeutils import maybe_iso8601
 from celery.worker import state
@@ -523,12 +521,11 @@ class TaskRequest(object):
                    "args": self.args,
                    "kwargs": self.kwargs}
 
-        log_with_extra(self.logger, logging.ERROR,
-                       self.error_msg.strip() % context,
-                       exc_info=exc_info,
-                       extra={"data": {"hostname": self.hostname,
-                                       "id": self.task_id,
-                                       "name": self.task_name}})
+        self.logger.error(self.error_msg.strip() % context,
+                          exc_info=exc_info,
+                          extra={"data": {"id": self.task_id,
+                                          "name": self.task_name,
+                                          "hostname": self.hostname}})
 
         task_obj = tasks.get(self.task_name, object)
         self.send_error_email(task_obj, context, exc_info.exception,

+ 6 - 9
celery/worker/mediator.py

@@ -3,7 +3,6 @@
 Worker Controller Threads
 
 """
-import logging
 import os
 import sys
 import threading
@@ -12,7 +11,6 @@ import traceback
 from Queue import Empty
 
 from celery.app import app_or_default
-from celery.utils.compat import log_with_extra
 
 
 class Mediator(threading.Thread):
@@ -51,13 +49,12 @@ class Mediator(threading.Thread):
         try:
             self.callback(task)
         except Exception, exc:
-            log_with_extra(self.logger, logging.ERROR,
-                           "Mediator callback raised exception %r\n%s" % (
-                               exc, traceback.format_exc()),
-                           exc_info=sys.exc_info(),
-                           extra={"data": {"hostname": task.hostname,
-                                           "id": task.task_id,
-                                           "name": task.task_name}})
+            self.logger.error("Mediator callback raised exception %r\n%s" % (
+                                exc, traceback.format_exc()),
+                              exc_info=sys.exc_info(),
+                              extra={"data": {"id": task.task_id,
+                                              "name": task.task_name,
+                                              "hostname": task.hostname}})
 
     def run(self):
         """Move tasks forver or until :meth:`stop` is called."""

+ 2 - 1
celery/worker/state.py

@@ -2,10 +2,11 @@ import os
 import platform
 import shelve
 
+from collections import defaultdict
+
 from kombu.utils import cached_property
 
 from celery import __version__
-from celery.utils.compat import defaultdict
 from celery.datastructures import LimitedSet
 
 #: Worker software/platform information.