
Remove 2.4 workarounds

Ask Solem, 14 years ago
parent
commit
1dfe6d0d54
43 changed files with 185 additions and 454 deletions
  1. celery/app/base.py  +0 -1
  2. celery/bin/celeryd_multi.py  +1 -1
  3. celery/datastructures.py  +1 -1
  4. celery/db/session.py  +2 -2
  5. celery/execute/__init__.py  +2 -0
  6. celery/execute/trace.py  +9 -13
  7. celery/loaders/__init__.py  +2 -0
  8. celery/loaders/app.py  +6 -0
  9. celery/loaders/base.py  +10 -10
  10. celery/loaders/default.py  +5 -2
  11. celery/local.py  +8 -0
  12. celery/log.py  +11 -17
  13. celery/registry.py  +3 -2
  14. celery/result.py  +4 -6
  15. celery/routes.py  +4 -6
  16. celery/schedules.py  +4 -4
  17. celery/task/base.py  +3 -4
  18. celery/task/http.py  +5 -1
  19. celery/tests/test_app/test_loaders.py  +10 -13
  20. celery/tests/test_backends/test_database.py  +4 -5
  21. celery/tests/test_backends/test_redis.py  +4 -5
  22. celery/tests/test_bin/test_celerybeat.py  +2 -1
  23. celery/tests/test_bin/test_celeryd.py  +5 -5
  24. celery/tests/test_compat/test_decorators.py  +4 -5
  25. celery/tests/test_compat/test_log.py  +7 -45
  26. celery/tests/test_slow/test_buckets.py  +0 -2
  27. celery/tests/test_task/test_result.py  +1 -5
  28. celery/tests/test_task/test_task_builtins.py  +14 -15
  29. celery/tests/test_task/test_task_http.py  +11 -37
  30. celery/tests/test_task/test_task_sets.py  +5 -11
  31. celery/tests/test_utils/test_serialization.py  +5 -8
  32. celery/tests/test_worker/test_worker.py  +4 -4
  33. celery/tests/test_worker/test_worker_job.py  +6 -13
  34. celery/tests/utils.py  +2 -59
  35. celery/utils/__init__.py  +0 -3
  36. celery/utils/compat.py  +1 -99
  37. celery/utils/serialization.py  +2 -8
  38. celery/utils/timer2.py  +0 -3
  39. celery/worker/buckets.py  +1 -1
  40. celery/worker/consumer.py  +4 -19
  41. celery/worker/job.py  +5 -8
  42. celery/worker/mediator.py  +6 -9
  43. celery/worker/state.py  +2 -1

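Note: the "2.4 workarounds" removed below are the fallback shims Celery carried while Python 2.4 was supported: 2.4 had no collections.defaultdict, no any()/all() builtins, no with statement, and no urlparse.parse_qsl, so celery/utils/compat.py provided substitutes. This page does not show the contents of that compat module (it loses 99 lines in this commit); the sketch below only illustrates the general try/except-ImportError shape such shims follow, and is not Celery's actual code:

    # Illustrative only -- the typical shape of a Python 2.4 compat shim.
    try:
        from collections import defaultdict      # available since Python 2.5
    except ImportError:
        class defaultdict(dict):                  # minimal 2.4 stand-in
            def __init__(self, default_factory=None, *args, **kwargs):
                dict.__init__(self, *args, **kwargs)
                self.default_factory = default_factory

            def __getitem__(self, key):
                try:
                    return dict.__getitem__(self, key)
                except KeyError:
                    if self.default_factory is None:
                        raise
                    value = self[key] = self.default_factory()
                    return value

    counts = defaultdict(int)
    counts["celery.task"] += 1                    # no KeyError on first access

With 2.4 support gone, modules import these names straight from the standard library, and the long-obsolete "from __future__ import generators" lines (only needed on Python 2.2) give way to absolute_import and with_statement.
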
+ 0 - 1
celery/app/base.py

@@ -25,7 +25,6 @@ import kombu
 if kombu.VERSION < (1, 1, 0):
     raise ImportError("Celery requires Kombu version 1.1.0 or higher.")
 
-
 BUGREPORT_INFO = """
 platform -> system:%(system)s arch:%(arch)s imp:%(py_i)s
 software -> celery:%(celery_v)s kombu:%(kombu_v)s py:%(py_v)s

+ 1 - 1
celery/bin/celeryd_multi.py

@@ -92,12 +92,12 @@ import signal
 import socket
 import sys
 
+from collections import defaultdict
 from subprocess import Popen
 from time import sleep
 
 from celery import __version__
 from celery.utils import term
-from celery.utils.compat import any, defaultdict
 
 SIGNAMES = set(sig for sig in dir(signal)
                         if sig.startswith("SIG") and "_" not in sig)

+ 1 - 1
celery/datastructures.py

@@ -8,7 +8,7 @@ Custom data structures.
 :license: BSD, see LICENSE for more details.
 
 """
-from __future__ import generators
+from __future__ import absolute_import
 
 import time
 import traceback

+ 2 - 2
celery/db/session.py

@@ -1,9 +1,9 @@
+from collections import defaultdict
+
 from sqlalchemy import create_engine
 from sqlalchemy.orm import sessionmaker
 from sqlalchemy.ext.declarative import declarative_base
 
-from celery.utils.compat import defaultdict
-
 ResultModelBase = declarative_base()
 
 _SETUP = defaultdict(lambda: False)

+ 2 - 0
celery/execute/__init__.py

@@ -1,3 +1,5 @@
+from __future__ import absolute_import
+
 from celery import current_app
 from celery.utils import deprecated
 

+ 9 - 13
celery/execute/trace.py

@@ -1,3 +1,5 @@
+from __future__ import absolute_import
+
 import sys
 import traceback
 
@@ -32,14 +34,14 @@ class TraceInfo(object):
         """
         try:
             return cls(states.SUCCESS, retval=fun(*args, **kwargs))
-        except (SystemExit, KeyboardInterrupt):
-            raise
         except RetryTaskError, exc:
             return cls(states.RETRY, retval=exc, exc_info=sys.exc_info())
         except Exception, exc:
             if propagate:
                 raise
             return cls(states.FAILURE, retval=exc, exc_info=sys.exc_info())
+        except BaseException, exc:
+            raise
         except:  # pragma: no cover
             # For Python2.5 where raising strings are still allowed
             # (but deprecated)
@@ -93,12 +95,11 @@ class TaskTrace(object):
                                  trace.exc_type, trace.tb, trace.strtb)
         return r
 
-    def handle_after_return(self, status, retval, type_, tb, strtb):
-        einfo = None
+    def handle_after_return(self, status, retval, type_, tb, strtb, einfo=None):
         if status in states.EXCEPTION_STATES:
             einfo = ExceptionInfo((retval, type_, tb))
         self.task.after_return(status, retval, self.task_id,
-                               self.args, self.kwargs, einfo=einfo)
+                               self.args, self.kwargs, einfo)
 
     def handle_success(self, retval, *args):
         """Handle successful execution."""
@@ -107,25 +108,20 @@ class TaskTrace(object):
 
     def handle_retry(self, exc, type_, tb, strtb):
         """Handle retry exception."""
-
         # Create a simpler version of the RetryTaskError that stringifies
         # the original exception instead of including the exception instance.
         # This is for reporting the retry in logs, email etc, while
         # guaranteeing pickleability.
         message, orig_exc = exc.args
         expanded_msg = "%s: %s" % (message, str(orig_exc))
-        einfo = ExceptionInfo((type_,
-                               type_(expanded_msg, None),
-                               tb))
-        self.task.on_retry(exc, self.task_id,
-                           self.args, self.kwargs, einfo=einfo)
+        einfo = ExceptionInfo((type_, type_(expanded_msg, None), tb))
+        self.task.on_retry(exc, self.task_id, self.args, self.kwargs, einfo)
         return einfo
 
     def handle_failure(self, exc, type_, tb, strtb):
         """Handle exception."""
         einfo = ExceptionInfo((type_, exc, tb))
-        self.task.on_failure(exc, self.task_id,
-                             self.args, self.kwargs, einfo=einfo)
+        self.task.on_failure(exc, self.task_id, self.args, self.kwargs, einfo)
         signals.task_failure.send(sender=self.task, task_id=self.task_id,
                                   exception=exc, args=self.args,
                                   kwargs=self.kwargs, traceback=tb,

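Note: the reordering of exception handlers above is itself a 2.4 cleanup. Since Python 2.5, SystemExit and KeyboardInterrupt derive from BaseException rather than Exception, so a plain "except Exception" no longer swallows them and the dedicated re-raise guard in front of it is redundant; the new BaseException clause after it just re-raises them explicitly. A small, self-contained sketch of that control flow (using the modern "except ... as" spelling rather than the "except ..., exc" form in the diff):

    def trace(fun):
        # Ordinary errors become a failure result; SystemExit and
        # KeyboardInterrupt skip the first handler (they are not Exception
        # subclasses on Python >= 2.5) and propagate to the caller.
        try:
            return ("SUCCESS", fun())
        except Exception as exc:
            return ("FAILURE", exc)
        except BaseException:
            raise

    print(trace(lambda: 1 / 0))        # ('FAILURE', ZeroDivisionError(...))

    def shutdown():
        raise SystemExit("worker shutdown")

    try:
        trace(shutdown)
    except SystemExit as exc:
        print("propagated:", exc)      # not recorded as a task failure
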
+ 2 - 0
celery/loaders/__init__.py

@@ -1,3 +1,5 @@
+from __future__ import absolute_import
+
 import os
 
 from celery.utils import get_cls_by_name

+ 6 - 0
celery/loaders/app.py

@@ -1,3 +1,9 @@
+from __future__ import absolute_import
+
+import os
+
+from celery.datastructures import DictAttribute
+from celery.exceptions import ImproperlyConfigured
 from celery.loaders.base import BaseLoader
 
 

+ 10 - 10
celery/loaders/base.py

@@ -1,3 +1,5 @@
+from __future__ import absolute_import
+
 import importlib
 import os
 import re
@@ -11,7 +13,7 @@ from celery.exceptions import ImproperlyConfigured
 from celery.utils import get_cls_by_name
 from celery.utils import import_from_cwd as _import_from_cwd
 
-BUILTIN_MODULES = ["celery.task"]
+BUILTIN_MODULES = frozenset(["celery.task"])
 
 ERROR_ENVVAR_NOT_SET = (
 """The environment variable %r is not set,
@@ -23,7 +25,7 @@ a configuration module.""")
 class BaseLoader(object):
     """The base class for loaders.
 
-    Loaders handles to following things:
+    Loaders handles,
 
         * Reading celery client/worker configurations.
 
@@ -65,14 +67,13 @@ class BaseLoader(object):
         return importlib.import_module(module)
 
     def import_from_cwd(self, module, imp=None):
-        if imp is None:
-            imp = self.import_module
-        return _import_from_cwd(module, imp)
+        return _import_from_cwd(module,
+                self.import_module if imp is None else imp)
 
     def import_default_modules(self):
-        imports = self.conf.get("CELERY_IMPORTS") or []
-        imports = set(list(imports) + BUILTIN_MODULES)
-        return map(self.import_task_module, imports)
+        imports = self.conf.get("CELERY_IMPORTS") or ()
+        imports = set(list(imports)) | BUILTIN_MODULES
+        return [self.import_task_module(module) for module in imports]
 
     def init_worker(self):
         if not self.worker_initialized:
@@ -172,5 +173,4 @@ class BaseLoader(object):
 
     @cached_property
     def mail(self):
-        from celery.utils import mail
-        return mail
+        return self.import_module("celery.utils.mail")

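Note: making BUILTIN_MODULES a frozenset is what forces the companion rewrite of import_default_modules: the old expression built a set from list concatenation, which a frozenset does not support, while the new set union works across both types and still de-duplicates. A quick sketch of the difference (the CELERY_IMPORTS value below is an invented example):

    BUILTIN_MODULES = frozenset(["celery.task"])
    imports = ("myapp.tasks", "celery.task")          # hypothetical CELERY_IMPORTS

    modules = set(list(imports)) | BUILTIN_MODULES    # set | frozenset -> set
    print(sorted(modules))                            # ['celery.task', 'myapp.tasks']

    try:
        set(list(imports) + BUILTIN_MODULES)          # the pre-change expression
    except TypeError as exc:
        print("breaks once BUILTIN_MODULES is a frozenset:", exc)
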
+ 5 - 2
celery/loaders/default.py

@@ -1,10 +1,13 @@
+from __future__ import absolute_import
+
 import os
 import warnings
+
 from importlib import import_module
 
 from celery.datastructures import AttributeDict
-from celery.loaders.base import BaseLoader
 from celery.exceptions import NotConfigured
+from celery.loaders.base import BaseLoader
 
 DEFAULT_CONFIG_MODULE = "celeryconfig"
 
@@ -19,7 +22,7 @@ class Loader(BaseLoader):
         """Read configuration from :file:`celeryconfig.py` and configure
         celery and Django so it can be used by regular Python."""
         configname = os.environ.get("CELERY_CONFIG_MODULE",
-                                    DEFAULT_CONFIG_MODULE)
+                                     DEFAULT_CONFIG_MODULE)
         try:
             celeryconfig = self.import_from_cwd(configname)
         except ImportError:

+ 8 - 0
celery/local.py

@@ -1,3 +1,11 @@
+def try_import(module):
+    from importlib import import_module
+    try:
+        return import_module(module)
+    except ImportError:
+        pass
+
+
 class LocalProxy(object):
     """Code stolen from werkzeug.local.LocalProxy."""
     __slots__ = ('__local', '__dict__', '__name__')

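Note: try_import, added above, returns the imported module, or falls through to an implicit None when the import fails. A usage sketch (the module names below are only examples):

    from importlib import import_module

    def try_import(module):
        # Same shape as the helper added to celery/local.py above.
        try:
            return import_module(module)
        except ImportError:
            pass

    json = try_import("json")                    # stdlib module: returned as-is
    missing = try_import("no_such_module_xyz")   # hypothetical name: yields None
    print(json is not None, missing is None)     # True True
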
+ 11 - 17
celery/log.py

@@ -1,4 +1,6 @@
 """celery.log"""
 """celery.log"""
+from __future__ import absolute_import
+
 import logging
 import logging
 import threading
 import threading
 import sys
 import sys
@@ -100,37 +102,29 @@ class Logging(object):
         if colorize is None:
         if colorize is None:
             colorize = self.supports_color(logfile)
             colorize = self.supports_color(logfile)
 
 
-        if mputil:
-            try:
-                mputil._logger = None
-            except AttributeError:
-                pass
+        if mputil and hasattr(mputil, "_logger"):
+            mputil._logger = None
         ensure_process_aware_logger()
         ensure_process_aware_logger()
         receivers = signals.setup_logging.send(sender=None,
         receivers = signals.setup_logging.send(sender=None,
-                                               loglevel=loglevel,
-                                               logfile=logfile,
-                                               format=format,
-                                               colorize=colorize)
+                        loglevel=loglevel, logfile=logfile,
+                        format=format, colorize=colorize)
         if not receivers:
         if not receivers:
             root = logging.getLogger()
             root = logging.getLogger()
 
 
             if self.app.conf.CELERYD_HIJACK_ROOT_LOGGER:
             if self.app.conf.CELERYD_HIJACK_ROOT_LOGGER:
                 root.handlers = []
                 root.handlers = []
 
 
-            mp = mputil and mputil.get_logger() or None
-            for logger in (root, mp):
-                if logger:
-                    self._setup_logger(logger, logfile, format,
-                                       colorize, **kwargs)
-                    logger.setLevel(loglevel)
+            mp = mputil.get_logger() if mputil else None
+            for logger in filter(None, (root, mp)):
+                self._setup_logger(logger, logfile, format, colorize, **kwargs)
+                logger.setLevel(loglevel)
         Logging._setup = True
         Logging._setup = True
         return receivers
         return receivers
 
 
     def _detect_handler(self, logfile=None):
     def _detect_handler(self, logfile=None):
         """Create log handler with either a filename, an open stream
         """Create log handler with either a filename, an open stream
         or :const:`None` (stderr)."""
         or :const:`None` (stderr)."""
-        if logfile is None:
-            logfile = sys.__stderr__
+        logfile = sys.__stderr__ if logfile is None else logfile
         if hasattr(logfile, "write"):
         if hasattr(logfile, "write"):
             return logging.StreamHandler(logfile)
             return logging.StreamHandler(logfile)
         return WatchedFileHandler(logfile)
         return WatchedFileHandler(logfile)

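Note: several of the rewrites in this file lean on the conditional expression (x if cond else y, added in Python 2.5) and on filter(None, ...), which drops falsy entries, so the per-logger "if logger:" guard disappears. For instance:

    import logging

    root = logging.getLogger()
    mp = None                     # stand-in for "no multiprocessing logger"

    # Only truthy entries survive filter(None, ...), so None is skipped.
    for logger in filter(None, (root, mp)):
        logger.setLevel(logging.INFO)

    print([l.name for l in filter(None, (root, mp))])   # ['root']
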
+ 3 - 2
celery/registry.py

@@ -1,4 +1,6 @@
 """celery.registry"""
 """celery.registry"""
+from __future__ import absolute_import
+
 import inspect
 import inspect
 
 
 from celery.exceptions import NotRegistered
 from celery.exceptions import NotRegistered
@@ -27,8 +29,7 @@ class TaskRegistry(UserDict):
         instance.
         instance.
 
 
         """
         """
-        task = inspect.isclass(task) and task() or task
-        self.data[task.name] = task
+        self[task.name] = inspect.isclass(task) and task() or task
 
 
     def unregister(self, name):
     def unregister(self, name):
         """Unregister task by name.
         """Unregister task by name.

+ 4 - 6
celery/result.py

@@ -1,4 +1,4 @@
-from __future__ import generators
+from __future__ import absolute_import
 
 import time
 
@@ -7,10 +7,10 @@ from itertools import imap
 
 from celery import current_app
 from celery import states
+from celery import current_app
 from celery.app import app_or_default
 from celery.exceptions import TimeoutError
 from celery.registry import _unpickle_task
-from celery.utils.compat import any, all
 
 
 def _unpickle_result(task_id, task_name):
@@ -35,10 +35,10 @@ class BaseAsyncResult(object):
     backend = None
 
     def __init__(self, task_id, backend, task_name=None, app=None):
+        self.app = app_or_default(app)
         self.task_id = task_id
         self.backend = backend
         self.task_name = task_name
-        self.app = app_or_default(app)
 
     def forget(self):
         """Forget about (and possibly remove the result of) this task."""
@@ -476,9 +476,7 @@ class TaskSetResult(ResultSet):
     @classmethod
     def restore(self, taskset_id, backend=None):
         """Restore previously saved taskset result."""
-        if backend is None:
-            backend = current_app.backend
-        return backend.restore_taskset(taskset_id)
+        return (backend or current_app.backend).restore_taskset(taskset_id)
 
     def itersubtasks(self):
         """Depreacted.   Use ``iter(self.results)`` instead."""

+ 4 - 6
celery/routes.py

@@ -1,3 +1,5 @@
+from __future__ import absolute_import
+
 from celery.exceptions import QueueNotFound
 from celery.utils import firstmethod, instantiate, lpmerge, mpromise
 
@@ -22,12 +24,8 @@ class Router(object):
             app=None):
         from celery.app import app_or_default
         self.app = app_or_default(app)
-        if queues is None:
-            queues = {}
-        if routes is None:
-            routes = []
-        self.queues = queues
-        self.routes = routes
+        self.queues = {} if queues is None else queues
+        self.routes = [] if routes is None else routes
         self.create_missing = create_missing
 
     def route(self, options, task, args=(), kwargs={}):

+ 4 - 4
celery/schedules.py

@@ -1,3 +1,5 @@
+from __future__ import absolute_import
+
 from datetime import datetime, timedelta
 from dateutil.relativedelta import relativedelta
 from pyparsing import (Word, Literal, ZeroOrMore, Optional,
@@ -256,12 +258,11 @@ class crontab(schedule):
     def remaining_estimate(self, last_run_at):
         """Returns when the periodic task should run next as a timedelta."""
         weekday = last_run_at.isoweekday()
-        if weekday == 7:    # Sunday is day 0, not day 7.
-            weekday = 0
+        weekday = 0 if weekday == 7 else weekday  # Sunday is day 0, not day 7.
 
         execute_this_hour = (weekday in self.day_of_week and
                                 last_run_at.hour in self.hour and
-                                last_run_at.minute < max(self.minute))
+                                    last_run_at.minute < max(self.minute))
 
         if execute_this_hour:
             next_minute = min(minute for minute in self.minute
@@ -271,7 +272,6 @@ class crontab(schedule):
                                   microsecond=0)
         else:
             next_minute = min(self.minute)
-
             execute_today = (weekday in self.day_of_week and
                                  last_run_at.hour < max(self.hour))
 

+ 3 - 4
celery/task/base.py

@@ -626,7 +626,7 @@ class BaseTask(object):
             task_id = self.request.id
         self.backend.store_result(task_id, meta, state)
 
-    def on_retry(self, exc, task_id, args, kwargs, einfo=None):
+    def on_retry(self, exc, task_id, args, kwargs, einfo):
         """Retry handler.
 
         This is run by the worker when the task is to be retried.
@@ -644,8 +644,7 @@ class BaseTask(object):
         """
         pass
 
-    def after_return(self, status, retval, task_id, args,
-            kwargs, einfo=None):
+    def after_return(self, status, retval, task_id, args, kwargs, einfo):
         """Handler called after the task returns.
 
         :param status: Current task state.
@@ -664,7 +663,7 @@ class BaseTask(object):
         if self.request.chord:
             self.backend.on_chord_part_return(self)
 
-    def on_failure(self, exc, task_id, args, kwargs, einfo=None):
+    def on_failure(self, exc, task_id, args, kwargs, einfo):
         """Error handler.
 
         This is run by the worker when the task fails.

+ 5 - 1
celery/task/http.py

@@ -1,12 +1,16 @@
 import urllib2
+
 from urllib import urlencode
 from urlparse import urlparse
+try:
+    from urlparse import parse_qsl
+except ImportError:
+    from cgi import parse_qsl
 
 from anyjson import deserialize
 
 from celery import __version__ as celery_version
 from celery.task.base import Task as BaseTask
-from celery.utils.compat import parse_qsl
 
 GET_METHODS = frozenset(["GET", "HEAD"])
 

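Note: parse_qsl has lived in urlparse since Python 2.6; on older versions it was only available from the cgi module, which is what the deleted celery.utils.compat.parse_qsl papered over. The try/except import above keeps pre-2.6 interpreters working without the compat module. For reference, its behaviour (the fallback in this sketch targets the Python 3 location purely so the snippet runs on current interpreters; it is not part of the diff):

    try:
        from urlparse import parse_qsl           # location used in the diff (Py 2.6+)
    except ImportError:
        from urllib.parse import parse_qsl       # Python 3 location, for this sketch

    # Query strings decode to (key, value) pairs, as consumed by HttpDispatch/URL.
    print(parse_qsl("x=10&y=10"))                # [('x', '10'), ('y', '10')]
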
+ 10 - 13
celery/tests/test_app/test_loaders.py

@@ -1,3 +1,5 @@
+from __future__ import with_statement
+
 import os
 import sys
 
@@ -11,7 +13,7 @@ from celery.loaders.app import AppLoader
 
 from celery.tests.compat import catch_warnings
 from celery.tests.utils import unittest
-from celery.tests.utils import with_environ, execute_context
+from celery.tests.utils import with_environ
 
 
 class ObjectConfig(object):
@@ -127,17 +129,15 @@ class TestLoaderBase(unittest.TestCase):
         MockMail.Mailer.raise_on_send = True
         opts = dict(self.message_options, **self.server_options)
 
-        def with_catch_warnings(log):
+        with catch_warnings(record=True) as log:
             self.loader.mail_admins(fail_silently=True, **opts)
-            return log[0].message
+            warning = log[0].message
 
-        warning = execute_context(catch_warnings(record=True),
-                                  with_catch_warnings)
-        self.assertIsInstance(warning, MockMail.SendmailWarning)
-        self.assertIn("KeyError", warning.args[0])
+            self.assertIsInstance(warning, MockMail.SendmailWarning)
+            self.assertIn("KeyError", warning.args[0])
 
-        self.assertRaises(KeyError, self.loader.mail_admins,
-                          fail_silently=False, **opts)
+            self.assertRaises(KeyError, self.loader.mail_admins,
+                              fail_silently=False, **opts)
 
     def test_mail_admins(self):
         MockMail.Mailer.raise_on_send = False
@@ -214,13 +214,10 @@ class TestDefaultLoader(unittest.TestCase):
             def import_from_cwd(self, name):
                 raise ImportError(name)
 
-        def with_catch_warnings(log):
+        with catch_warnings(record=True) as log:
            l = _Loader()
            self.assertDictEqual(l.conf, {})
            context_executed[0] = True
-
-        context = catch_warnings(record=True)
-        execute_context(context, with_catch_warnings)
         self.assertTrue(context_executed[0])
 
 

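Note: most of the test churn from here on is one mechanical rewrite: the execute_context(context, callback) helper is replaced by the with statement (which needs the __future__ import on Python 2.5), and catch_warnings(record=True) is used directly as a context manager that yields the recorded warnings. The celery.tests.compat wrapper itself is not shown in this diff; the standard-library behaviour it mirrors looks like this self-contained example (warn_deprecated is invented for illustration):

    from __future__ import with_statement     # only needed on Python 2.5

    import warnings

    def warn_deprecated():
        warnings.warn("ping task has been deprecated", DeprecationWarning)

    with warnings.catch_warnings(record=True) as log:
        warnings.simplefilter("always")        # make sure the warning is recorded
        warn_deprecated()
        warning = log[0].message
        print(isinstance(warning, DeprecationWarning))   # True
        print("deprecated" in warning.args[0])           # True
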
+ 4 - 5
celery/tests/test_backends/test_database.py

@@ -1,3 +1,5 @@
+from __future__ import with_statement
+
 import sys
 
 from datetime import datetime
@@ -10,7 +12,7 @@ from celery.exceptions import ImproperlyConfigured
 from celery.result import AsyncResult
 from celery.utils import gen_unique_id
 
-from celery.tests.utils import execute_context, mask_modules
+from celery.tests.utils import mask_modules
 from celery.tests.utils import unittest
 
 try:
@@ -39,13 +41,10 @@ class test_DatabaseBackend(unittest.TestCase):
             raise SkipTest("sqlalchemy not installed")
 
     def test_missing_SQLAlchemy_raises_ImproperlyConfigured(self):
-
-        def with_SQLAlchemy_masked(_val):
+        with mask_modules("sqlalchemy"):
             from celery.backends.database import _sqlalchemy_installed
             self.assertRaises(ImproperlyConfigured, _sqlalchemy_installed)
 
-        execute_context(mask_modules("sqlalchemy"), with_SQLAlchemy_masked)
-
     def test_pickle_hack_for_sqla_05(self):
         import sqlalchemy as sa
         from celery.db import session

+ 4 - 5
celery/tests/test_backends/test_redis.py

@@ -1,3 +1,5 @@
+from __future__ import with_statement
+
 import sys
 import socket
 from celery.tests.utils import unittest
@@ -11,7 +13,7 @@ from celery.utils import gen_unique_id
 from celery.backends import pyredis
 from celery.backends.pyredis import RedisBackend
 
-from celery.tests.utils import execute_context, mask_modules
+from celery.tests.utils import mask_modules
 
 _no_redis_msg = "* Redis %s. Will not execute related tests."
 _no_redis_msg_emitted = False
@@ -108,12 +110,9 @@ class TestRedisBackendNoRedis(unittest.TestCase):
     def test_redis_None_if_redis_not_installed(self):
         prev = sys.modules.pop("celery.backends.pyredis")
         try:
-
-            def with_redis_masked(_val):
+            with mask_modules("redis"):
                 from celery.backends.pyredis import redis
                 self.assertIsNone(redis)
-            context = mask_modules("redis")
-            execute_context(context, with_redis_masked)
         finally:
             sys.modules["celery.backends.pyredis"] = prev
 

+ 2 - 1
celery/tests/test_bin/test_celerybeat.py

@@ -1,6 +1,8 @@
 import logging
 import sys
 
+from collections import defaultdict
+
 from kombu.tests.utils import redirect_stdouts
 
 from celery import beat
@@ -8,7 +10,6 @@ from celery import platforms
 from celery.app import app_or_default
 from celery.bin import celerybeat as celerybeat_bin
 from celery.apps import beat as beatapp
-from celery.utils.compat import defaultdict
 
 from celery.tests.utils import AppCase
 

+ 5 - 5
celery/tests/test_bin/test_celeryd.py

@@ -1,3 +1,5 @@
+from __future__ import with_statement
+
 import logging
 import os
 import sys
@@ -9,6 +11,7 @@ try:
 except ImportError:
     current_process = None  # noqa
 
+from functools import wraps
 
 from nose import SkipTest
 from kombu.tests.utils import redirect_stdouts
@@ -24,7 +27,7 @@ from celery.exceptions import ImproperlyConfigured
 from celery.utils import patch
 
 from celery.tests.compat import catch_warnings
-from celery.tests.utils import AppCase, execute_context, StringIO
+from celery.tests.utils import AppCase, StringIO
 
 
 patch.ensure_process_aware_logger()
@@ -187,15 +190,12 @@ class test_Worker(AppCase):
 
         prev, os.geteuid = os.geteuid, geteuid
         try:
-
-            def with_catch_warnings(log):
+            with catch_warnings(record=True) as log:
                 worker = self.Worker()
                 worker.run()
                 self.assertTrue(log)
                 self.assertIn("superuser privileges is not encouraged",
                               log[0].message.args[0])
-            context = catch_warnings(record=True)
-            execute_context(context, with_catch_warnings)
         finally:
             os.geteuid = prev
 

+ 4 - 5
celery/tests/test_compat/test_decorators.py

@@ -1,10 +1,11 @@
+from __future__ import with_statement
+
 import warnings
 
 from celery.task import base
 
 from celery.tests.compat import catch_warnings
 from celery.tests.utils import unittest
-from celery.tests.utils import execute_context
 
 
 def add(x, y):
@@ -16,11 +17,9 @@ class test_decorators(unittest.TestCase):
     def setUp(self):
         warnings.resetwarnings()
 
-        def with_catch_warnings(log):
+        with catch_warnings(record=True):
             from celery import decorators
-            return decorators
-        context = catch_warnings(record=True)
-        self.decorators = execute_context(context, with_catch_warnings)
+            self.decorators = decorators
 
     def assertCompatDecorator(self, decorator, type, **opts):
         task = decorator(**opts)(add)

+ 7 - 45
celery/tests/test_compat/test_log.py

@@ -1,4 +1,4 @@
-from __future__ import generators
+from __future__ import with_statement
 
 import sys
 import logging
@@ -10,35 +10,10 @@ from celery import log
 from celery.log import (setup_logger, setup_task_logger,
                         get_default_logger, get_task_logger,
                         redirect_stdouts_to_logger, LoggingProxy)
-from celery.tests.utils import contextmanager
-from celery.tests.utils import override_stdouts, execute_context
 from celery.utils import gen_unique_id
-from celery.utils.compat import LoggerAdapter
 from celery.utils.compat import _CompatLoggerAdapter
-
-
-def get_handlers(logger):
-    if isinstance(logger, LoggerAdapter):
-        return logger.logger.handlers
-    return logger.handlers
-
-
-def set_handlers(logger, new_handlers):
-    if isinstance(logger, LoggerAdapter):
-        logger.logger.handlers = new_handlers
-    logger.handlers = new_handlers
-
-
-@contextmanager
-def wrap_logger(logger, loglevel=logging.ERROR):
-    old_handlers = get_handlers(logger)
-    sio = StringIO()
-    siohandler = logging.StreamHandler(sio)
-    set_handlers(logger, [siohandler])
-
-    yield sio
-
-    set_handlers(logger, old_handlers)
+from celery.tests.utils import (override_stdouts, wrap_logger,
+                                get_handlers, set_handlers)
 
 
 class test_default_logger(unittest.TestCase):
@@ -50,13 +25,10 @@ class test_default_logger(unittest.TestCase):
 
     def _assertLog(self, logger, logmsg, loglevel=logging.ERROR):
 
-        def with_wrap_logger(sio):
+        with wrap_logger(logger, loglevel=loglevel) as sio:
             logger.log(loglevel, logmsg)
             return sio.getvalue().strip()
 
-        context = wrap_logger(logger, loglevel=loglevel)
-        execute_context(context, with_wrap_logger)
-
    def assertDidLogTrue(self, logger, logmsg, reason, loglevel=None):
        val = self._assertLog(logger, logmsg, loglevel=loglevel)
        return self.assertEqual(val, logmsg, reason)
@@ -81,16 +53,13 @@ class test_default_logger(unittest.TestCase):
         l = self.get_logger()
         set_handlers(l, [])
 
-        def with_override_stdouts(outs):
+        with override_stdouts() as outs:
             stdout, stderr = outs
             l = self.setup_logger(logfile=stderr, loglevel=logging.INFO,
                                   root=False)
             l.info("The quick brown fox...")
             self.assertIn("The quick brown fox...", stderr.getvalue())
 
-        context = override_stdouts()
-        execute_context(context, with_override_stdouts)
-
     def test_setup_logger_no_handlers_file(self):
         l = self.get_logger()
         set_handlers(l, [])
@@ -103,14 +72,10 @@ class test_default_logger(unittest.TestCase):
         logger = self.setup_logger(loglevel=logging.ERROR, logfile=None,
                                    root=False)
         try:
-
-            def with_wrap_logger(sio):
+            with wrap_logger(logger) as sio:
                 redirect_stdouts_to_logger(logger, loglevel=logging.ERROR)
                 logger.error("foo")
                 self.assertIn("foo", sio.getvalue())
-
-            context = wrap_logger(logger)
-            execute_context(context, with_wrap_logger)
         finally:
             sys.stdout, sys.stderr = sys.__stdout__, sys.__stderr__
 
@@ -118,7 +83,7 @@ class test_default_logger(unittest.TestCase):
         logger = self.setup_logger(loglevel=logging.ERROR, logfile=None,
                                    root=False)
 
-        def with_wrap_logger(sio):
+        with wrap_logger(logger) as sio:
             p = LoggingProxy(logger, loglevel=logging.ERROR)
             p.close()
             p.write("foo")
@@ -135,9 +100,6 @@ class test_default_logger(unittest.TestCase):
             self.assertFalse(p.isatty())
             self.assertIsNone(p.fileno())
 
-        context = wrap_logger(logger)
-        execute_context(context, with_wrap_logger)
-
 
 class test_task_logger(test_default_logger):
 

+ 0 - 2
celery/tests/test_slow/test_buckets.py

@@ -1,5 +1,3 @@
-from __future__ import generators
-
 import sys
 import time
 

+ 1 - 5
celery/tests/test_task/test_result.py

@@ -1,16 +1,12 @@
-from __future__ import generators
-
-from celery.tests.utils import unittest
-
 from celery import states
 from celery.app import app_or_default
 from celery.utils import gen_unique_id
-from celery.utils.compat import all
 from celery.utils.serialization import pickle
 from celery.result import AsyncResult, EagerResult, TaskSetResult
 from celery.exceptions import TimeoutError
 from celery.task.base import Task
 
+from celery.tests.utils import unittest
 from celery.tests.utils import skip_if_quick
 
 

+ 14 - 15
celery/tests/test_task/test_task_builtins.py

@@ -1,8 +1,10 @@
+from __future__ import with_statement
+
 import warnings
 
 from celery.task import ping, PingTask, backend_cleanup
 from celery.tests.compat import catch_warnings
-from celery.tests.utils import unittest, execute_context
+from celery.tests.utils import unittest
 
 
 def some_func(i):
@@ -14,32 +16,29 @@ class test_deprecated(unittest.TestCase):
     def test_ping(self):
         warnings.resetwarnings()
 
-        def block(log):
+        with catch_warnings(record=True) as log:
             prev = PingTask.app.conf.CELERY_ALWAYS_EAGER
             PingTask.app.conf.CELERY_ALWAYS_EAGER = True
             try:
-                return ping(), log[0].message
+                pong = ping()
+                warning = log[0].message
+                self.assertEqual(pong, "pong")
+                self.assertIsInstance(warning, DeprecationWarning)
+                self.assertIn("ping task has been deprecated",
+                              warning.args[0])
             finally:
                 PingTask.app.conf.CELERY_ALWAYS_EAGER = prev
 
-        pong, warning = execute_context(catch_warnings(record=True), block)
-        self.assertEqual(pong, "pong")
-        self.assertIsInstance(warning, DeprecationWarning)
-        self.assertIn("ping task has been deprecated",
-                      warning.args[0])
-
     def test_TaskSet_import_from_task_base(self):
         warnings.resetwarnings()
 
-        def block(log):
+        with catch_warnings(record=True) as log:
             from celery.task.base import TaskSet, subtask
             TaskSet()
             subtask(PingTask)
-            return log[0].message, log[1].message
-
-        for w in execute_context(catch_warnings(record=True), block):
-            self.assertIsInstance(w, DeprecationWarning)
-            self.assertIn("is deprecated", w.args[0])
+            for w in (log[0].message, log[1].message):
+                self.assertIsInstance(w, DeprecationWarning)
+                self.assertIn("is deprecated", w.args[0])
 
 
 class test_backend_cleanup(unittest.TestCase):

+ 11 - 37
celery/tests/test_task/test_task_http.py

@@ -1,8 +1,9 @@
 # -*- coding: utf-8 -*-
-from __future__ import generators
+from __future__ import with_statement
 
 import logging
 
+from contextlib import contextmanager
 from functools import wraps
 try:
     from urllib import addinfourl
@@ -12,8 +13,7 @@ except ImportError:  # py3k
 from anyjson import serialize
 
 from celery.task import http
-from celery.tests.utils import unittest
-from celery.tests.utils import execute_context, contextmanager, StringIO
+from celery.tests.utils import unittest, StringIO
 
 
 @contextmanager
@@ -98,94 +98,68 @@ class TestHttpDispatch(unittest.TestCase):
     def test_dispatch_success(self):
         logger = logging.getLogger("celery.unittest")
 
-        def with_mock_urlopen(_val):
+        with mock_urlopen(success_response(100)):
             d = http.HttpDispatch("http://example.com/mul", "GET", {
                                     "x": 10, "y": 10}, logger)
             self.assertEqual(d.dispatch(), 100)
 
-        context = mock_urlopen(success_response(100))
-        execute_context(context, with_mock_urlopen)
-
     def test_dispatch_failure(self):
         logger = logging.getLogger("celery.unittest")
 
-        def with_mock_urlopen(_val):
+        with mock_urlopen(fail_response("Invalid moon alignment")):
             d = http.HttpDispatch("http://example.com/mul", "GET", {
                                     "x": 10, "y": 10}, logger)
             self.assertRaises(http.RemoteExecuteError, d.dispatch)
 
-        context = mock_urlopen(fail_response("Invalid moon alignment"))
-        execute_context(context, with_mock_urlopen)
-
     def test_dispatch_empty_response(self):
         logger = logging.getLogger("celery.unittest")
 
-        def with_mock_urlopen(_val):
+        with mock_urlopen(_response("")):
             d = http.HttpDispatch("http://example.com/mul", "GET", {
                                     "x": 10, "y": 10}, logger)
             self.assertRaises(http.InvalidResponseError, d.dispatch)
 
-        context = mock_urlopen(_response(""))
-        execute_context(context, with_mock_urlopen)
-
     def test_dispatch_non_json(self):
         logger = logging.getLogger("celery.unittest")
 
-        def with_mock_urlopen(_val):
+        with mock_urlopen(_response("{'#{:'''")):
             d = http.HttpDispatch("http://example.com/mul", "GET", {
                                     "x": 10, "y": 10}, logger)
             self.assertRaises(http.InvalidResponseError, d.dispatch)
 
-        context = mock_urlopen(_response("{'#{:'''"))
-        execute_context(context, with_mock_urlopen)
-
     def test_dispatch_unknown_status(self):
         logger = logging.getLogger("celery.unittest")
 
-        def with_mock_urlopen(_val):
+        with mock_urlopen(unknown_response()):
             d = http.HttpDispatch("http://example.com/mul", "GET", {
                                     "x": 10, "y": 10}, logger)
             self.assertRaises(http.UnknownStatusError, d.dispatch)
 
-        context = mock_urlopen(unknown_response())
-        execute_context(context, with_mock_urlopen)
-
     def test_dispatch_POST(self):
         logger = logging.getLogger("celery.unittest")
 
-        def with_mock_urlopen(_val):
+        with mock_urlopen(success_response(100)):
             d = http.HttpDispatch("http://example.com/mul", "POST", {
                                     "x": 10, "y": 10}, logger)
             self.assertEqual(d.dispatch(), 100)
 
-        context = mock_urlopen(success_response(100))
-        execute_context(context, with_mock_urlopen)
-
 
 class TestURL(unittest.TestCase):
 
     def test_URL_get_async(self):
         http.HttpDispatchTask.app.conf.CELERY_ALWAYS_EAGER = True
         try:
-
-            def with_mock_urlopen(_val):
+            with mock_urlopen(success_response(100)):
                 d = http.URL("http://example.com/mul").get_async(x=10, y=10)
                 self.assertEqual(d.get(), 100)
-
-            context = mock_urlopen(success_response(100))
-            execute_context(context, with_mock_urlopen)
         finally:
             http.HttpDispatchTask.app.conf.CELERY_ALWAYS_EAGER = False
 
     def test_URL_post_async(self):
         http.HttpDispatchTask.app.conf.CELERY_ALWAYS_EAGER = True
         try:
-
-            def with_mock_urlopen(_val):
+            with mock_urlopen(success_response(100)):
                 d = http.URL("http://example.com/mul").post_async(x=10, y=10)
                 self.assertEqual(d.get(), 100)
-
-            context = mock_urlopen(success_response(100))
-            execute_context(context, with_mock_urlopen)
         finally:
             http.HttpDispatchTask.app.conf.CELERY_ALWAYS_EAGER = False

+ 5 - 11
celery/tests/test_task/test_task_sets.py

@@ -1,3 +1,5 @@
+from __future__ import with_statement
+
 import anyjson
 import warnings
 
@@ -6,7 +8,6 @@ from celery.task import Task
 from celery.task.sets import subtask, TaskSet
 
 from celery.tests.utils import unittest
-from celery.tests.utils import execute_context
 from celery.tests.compat import catch_warnings
 
 
@@ -93,7 +94,7 @@ class test_TaskSet(unittest.TestCase):
     def test_interface__compat(self):
         warnings.resetwarnings()
 
-        def with_catch_warnings(log):
+        with catch_warnings(record=True) as log:
             ts = TaskSet(MockTask, [[(2, 2)], [(4, 4)], [(8, 8)]])
             self.assertTrue(log)
             self.assertIn("Using this invocation of TaskSet is deprecated",
@@ -103,29 +104,22 @@ class test_TaskSet(unittest.TestCase):
                                     for i in (2, 4, 8)])
             return ts
 
-        context = catch_warnings(record=True)
-        execute_context(context, with_catch_warnings)
-
         # TaskSet.task (deprecated)
-        def with_catch_warnings2(log):
+        with catch_warnings(record=True) as log:
             ts = TaskSet(MockTask, [[(2, 2)], [(4, 4)], [(8, 8)]])
             self.assertEqual(ts.task.name, MockTask.name)
             self.assertTrue(log)
             self.assertIn("TaskSet.task is deprecated",
                           log[0].message.args[0])
 
-        execute_context(catch_warnings(record=True), with_catch_warnings2)
-
         # TaskSet.task_name (deprecated)
-        def with_catch_warnings3(log):
+        with catch_warnings(record=True) as log:
             ts = TaskSet(MockTask, [[(2, 2)], [(4, 4)], [(8, 8)]])
             self.assertEqual(ts.task_name, MockTask.name)
             self.assertTrue(log)
             self.assertIn("TaskSet.task_name is deprecated",
                           log[0].message.args[0])
 
-        execute_context(catch_warnings(record=True), with_catch_warnings3)
-
     def test_task_arg_can_be_iterable__compat(self):
         ts = TaskSet([MockTask.subtask((i, i))
                         for i in (2, 4, 8)])

+ 5 - 8
celery/tests/test_utils/test_serialization.py

@@ -1,7 +1,9 @@
+from __future__ import with_statement
+
 import sys
 import sys
-from celery.tests.utils import unittest
 
 
-from celery.tests.utils import execute_context, mask_modules
+from celery.tests.utils import unittest
+from celery.tests.utils import mask_modules
 
 
 
 
 class TestAAPickle(unittest.TestCase):
 class TestAAPickle(unittest.TestCase):
@@ -9,14 +11,9 @@ class TestAAPickle(unittest.TestCase):
     def test_no_cpickle(self):
     def test_no_cpickle(self):
         prev = sys.modules.pop("celery.utils.serialization", None)
         prev = sys.modules.pop("celery.utils.serialization", None)
         try:
         try:
-
-            def with_cPickle_masked(_val):
+            with mask_modules("cPickle"):
                 from celery.utils.serialization import pickle
                 from celery.utils.serialization import pickle
                 import pickle as orig_pickle
                 import pickle as orig_pickle
                 self.assertIs(pickle.dumps, orig_pickle.dumps)
                 self.assertIs(pickle.dumps, orig_pickle.dumps)
-
-            context = mask_modules("cPickle")
-            execute_context(context, with_cPickle_masked)
-
         finally:
         finally:
             sys.modules["celery.utils.serialization"] = prev
             sys.modules["celery.utils.serialization"] = prev

+ 4 - 4
celery/tests/test_worker/test_worker.py

@@ -1,3 +1,5 @@
+from __future__ import with_statement
+
 import socket
 import socket
 import sys
 import sys
 
 
@@ -25,7 +27,7 @@ from celery.utils.timer2 import Timer
 
 
 from celery.tests.compat import catch_warnings
 from celery.tests.compat import catch_warnings
 from celery.tests.utils import unittest
 from celery.tests.utils import unittest
-from celery.tests.utils import AppCase, execute_context, skip
+from celery.tests.utils import AppCase, skip
 
 
 
 
 class PlaceHolder(object):
 class PlaceHolder(object):
@@ -253,13 +255,11 @@ class test_Consumer(unittest.TestCase):
         l.event_dispatcher = Mock()
         l.event_dispatcher = Mock()
         l.pidbox_node = MockNode()
         l.pidbox_node = MockNode()
 
 
-        def with_catch_warnings(log):
+        with catch_warnings(record=True) as log:
             l.receive_message(m.decode(), m)
             l.receive_message(m.decode(), m)
             self.assertTrue(log)
             self.assertTrue(log)
             self.assertIn("unknown message", log[0].message.args[0])
             self.assertIn("unknown message", log[0].message.args[0])
 
 
-        context = catch_warnings(record=True)
-        execute_context(context, with_catch_warnings)
 
 
     @patch("celery.utils.timer2.to_timestamp")
     @patch("celery.utils.timer2.to_timestamp")
     def test_receive_message_eta_OverflowError(self, to_timestamp):
     def test_receive_message_eta_OverflowError(self, to_timestamp):

+ 6 - 13
celery/tests/test_worker/test_worker_job.py

@@ -1,4 +1,6 @@
 # -*- coding: utf-8 -*-
 # -*- coding: utf-8 -*-
+from __future__ import with_statement
+
 import anyjson
 import anyjson
 import logging
 import logging
 import os
 import os
@@ -27,7 +29,7 @@ from celery.worker.state import revoked
 
 
 from celery.tests.compat import catch_warnings
 from celery.tests.compat import catch_warnings
 from celery.tests.utils import unittest
 from celery.tests.utils import unittest
-from celery.tests.utils import execute_context, StringIO, wrap_logger
+from celery.tests.utils import StringIO, wrap_logger
 
 
 
 
 scratch = {"ACK": False}
 scratch = {"ACK": False}
@@ -89,19 +91,14 @@ class test_WorkerTaskTrace(unittest.TestCase):
         mytask.backend.process_cleanup = Mock(side_effect=KeyError())
         mytask.backend.process_cleanup = Mock(side_effect=KeyError())
         try:
         try:
 
 
-            def with_wrap_logger(sio):
+            logger = mytask.app.log.get_default_logger()
+            with wrap_logger(logger) as sio:
                 uuid = gen_unique_id()
                 uuid = gen_unique_id()
                 ret = jail(uuid, mytask.name, [2], {})
                 ret = jail(uuid, mytask.name, [2], {})
                 self.assertEqual(ret, 4)
                 self.assertEqual(ret, 4)
                 mytask.backend.mark_as_done.assert_called_with(uuid, 4)
                 mytask.backend.mark_as_done.assert_called_with(uuid, 4)
                 logs = sio.getvalue().strip()
                 logs = sio.getvalue().strip()
                 self.assertIn("Process cleanup failed", logs)
                 self.assertIn("Process cleanup failed", logs)
-                return 1234
-
-            logger = mytask.app.log.get_default_logger()
-            self.assertEqual(execute_context(
-                    wrap_logger(logger), with_wrap_logger), 1234)
-
         finally:
         finally:
             mytask.backend = backend
             mytask.backend = backend
 
 
@@ -418,17 +415,13 @@ class test_TaskRequest(unittest.TestCase):
 
 
         WorkerTaskTrace.execute = _error_exec
         WorkerTaskTrace.execute = _error_exec
         try:
         try:
-
-            def with_catch_warnings(log):
+            with catch_warnings(record=True) as log:
                 res = execute_and_trace(mytask.name, gen_unique_id(),
                 res = execute_and_trace(mytask.name, gen_unique_id(),
                                         [4], {})
                                         [4], {})
                 self.assertIsInstance(res, ExceptionInfo)
                 self.assertIsInstance(res, ExceptionInfo)
                 self.assertTrue(log)
                 self.assertTrue(log)
                 self.assertIn("Exception outside", log[0].message.args[0])
                 self.assertIn("Exception outside", log[0].message.args[0])
                 self.assertIn("KeyError", log[0].message.args[0])
                 self.assertIn("KeyError", log[0].message.args[0])
-
-            context = catch_warnings(record=True)
-            execute_context(context, with_catch_warnings)
         finally:
         finally:
             WorkerTaskTrace.execute = old_exec
             WorkerTaskTrace.execute = old_exec
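
Both hunks above follow the same conversion: the nested callback plus execute_context becomes a with-block, and any value the callback used to return is simply asserted inline. A wrap_logger-style helper can be thought of as temporarily pointing a logger at a StringIO; the following is a hedged sketch that ignores the LoggerAdapter handling the real helper has to deal with:

    from __future__ import with_statement

    import logging
    from contextlib import contextmanager
    from StringIO import StringIO

    @contextmanager
    def wrap_logger_sketch(logger, loglevel=logging.ERROR):
        # Capture everything the logger emits while the block runs.
        old_handlers, old_level = logger.handlers[:], logger.level
        sio = StringIO()
        logger.handlers = [logging.StreamHandler(sio)]
        logger.setLevel(loglevel)
        try:
            yield sio
        finally:
            logger.handlers = old_handlers
            logger.setLevel(old_level)

    logger = logging.getLogger("example")
    with wrap_logger_sketch(logger) as sio:
        logger.error("Process cleanup failed: %r", KeyError("x"))
    assert "Process cleanup failed" in sio.getvalue()
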
 
 

+ 2 - 59
celery/tests/utils.py

@@ -1,5 +1,3 @@
-from __future__ import generators
-
 try:
 try:
     import unittest
     import unittest
     unittest.skip
     unittest.skip
@@ -17,20 +15,15 @@ except ImportError:  # py3k
     import builtins  # noqa
     import builtins  # noqa
 
 
 from functools import wraps
 from functools import wraps
+from contextlib import contextmanager
 
 
-from celery.utils.compat import StringIO, LoggerAdapter
-try:
-    from contextlib import contextmanager
-except ImportError:
-    from celery.tests.utils import fallback_contextmanager
-    contextmanager = fallback_contextmanager  # noqa
 
 
 import mock
 import mock
-
 from nose import SkipTest
 from nose import SkipTest
 
 
 from celery.app import app_or_default
 from celery.app import app_or_default
 from celery.utils import noop
 from celery.utils import noop
+from celery.utils.compat import StringIO, LoggerAdapter
 
 
 
 
 class Mock(mock.Mock):
 class Mock(mock.Mock):
@@ -60,37 +53,6 @@ class AppCase(unittest.TestCase):
         pass
         pass
 
 
 
 
-class GeneratorContextManager(object):
-    def __init__(self, gen):
-        self.gen = gen
-
-    def __enter__(self):
-        try:
-            return self.gen.next()
-        except StopIteration:
-            raise RuntimeError("generator didn't yield")
-
-    def __exit__(self, type, value, traceback):
-        if type is None:
-            try:
-                self.gen.next()
-            except StopIteration:
-                return
-            else:
-                raise RuntimeError("generator didn't stop")
-        else:
-            try:
-                self.gen.throw(type, value, traceback)
-                raise RuntimeError("generator didn't stop after throw()")
-            except StopIteration:
-                return True
-            except AttributeError:
-                raise value
-            except:
-                if sys.exc_info()[1] is not value:
-                    raise
-
-
 def get_handlers(logger):
 def get_handlers(logger):
     if isinstance(logger, LoggerAdapter):
     if isinstance(logger, LoggerAdapter):
         return logger.logger.handlers
         return logger.logger.handlers
@@ -115,25 +77,6 @@ def wrap_logger(logger, loglevel=logging.ERROR):
     set_handlers(logger, old_handlers)
     set_handlers(logger, old_handlers)
 
 
 
 
-def fallback_contextmanager(fun):
-    def helper(*args, **kwds):
-        return GeneratorContextManager(fun(*args, **kwds))
-    return helper
-
-
-def execute_context(context, fun):
-    val = context.__enter__()
-    exc_info = (None, None, None)
-    try:
-        try:
-            return fun(val)
-        except:
-            exc_info = sys.exc_info()
-            raise
-    finally:
-        context.__exit__(*exc_info)
-
-
 @contextmanager
 @contextmanager
 def eager_tasks():
 def eager_tasks():
     app = app_or_default()
     app = app_or_default()
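
With the hand-rolled GeneratorContextManager, fallback_contextmanager and execute_context gone, contextlib.contextmanager is imported unconditionally and call sites change from execute_context(ctx, fun) to "with ctx as val: fun(val)". A helper in the style of eager_tasks then reduces to a few lines; this is a sketch assuming the CELERY_ALWAYS_EAGER setting is the only thing it needs to toggle:

    from __future__ import with_statement

    from contextlib import contextmanager

    from celery.app import app_or_default

    @contextmanager
    def eager_tasks_sketch():
        # Execute tasks locally and synchronously inside the block.
        app = app_or_default()
        prev = app.conf.CELERY_ALWAYS_EAGER
        app.conf.CELERY_ALWAYS_EAGER = True
        try:
            yield True
        finally:
            app.conf.CELERY_ALWAYS_EAGER = prev
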

+ 0 - 3
celery/utils/__init__.py

@@ -1,5 +1,3 @@
-from __future__ import generators
-
 import os
 import os
 import sys
 import sys
 import operator
 import operator
@@ -19,7 +17,6 @@ from kombu.utils import rpartition
 
 
 from celery.utils.compat import StringIO
 from celery.utils.compat import StringIO
 
 
-
 LOG_LEVELS = dict(logging._levelNames)
 LOG_LEVELS = dict(logging._levelNames)
 LOG_LEVELS["FATAL"] = logging.FATAL
 LOG_LEVELS["FATAL"] = logging.FATAL
 LOG_LEVELS[logging.FATAL] = "FATAL"
 LOG_LEVELS[logging.FATAL] = "FATAL"

+ 1 - 99
celery/utils/compat.py

@@ -1,5 +1,3 @@
-from __future__ import generators
-
 ############## py3k #########################################################
 ############## py3k #########################################################
 try:
 try:
     from UserList import UserList       # noqa
     from UserList import UserList       # noqa
@@ -19,39 +17,6 @@ except ImportError:
     except ImportError:
     except ImportError:
         from io import StringIO         # noqa
         from io import StringIO         # noqa
 
 
-############## urlparse.parse_qsl ###########################################
-
-try:
-    from urlparse import parse_qsl
-except ImportError:
-    from cgi import parse_qsl  # noqa
-
-############## __builtin__.all ##############################################
-
-try:
-    all([True])
-    all = all
-except NameError:
-
-    def all(iterable):
-        for item in iterable:
-            if not item:
-                return False
-        return True
-
-############## __builtin__.any ##############################################
-
-try:
-    any([True])
-    any = any
-except NameError:
-
-    def any(iterable):
-        for item in iterable:
-            if item:
-                return True
-        return False
-
 ############## collections.OrderedDict ######################################
 ############## collections.OrderedDict ######################################
 
 
 import weakref
 import weakref
@@ -270,53 +235,6 @@ try:
 except ImportError:
 except ImportError:
     OrderedDict = CompatOrderedDict  # noqa
     OrderedDict = CompatOrderedDict  # noqa
 
 
-############## collections.defaultdict ######################################
-
-try:
-    from collections import defaultdict
-except ImportError:
-    # Written by Jason Kirtland, taken from Python Cookbook:
-    # <http://aspn.activestate.com/ASPN/Cookbook/Python/Recipe/523034>
-    class defaultdict(dict):  # noqa
-
-        def __init__(self, default_factory=None, *args, **kwargs):
-            dict.__init__(self, *args, **kwargs)
-            self.default_factory = default_factory
-
-        def __getitem__(self, key):
-            try:
-                return dict.__getitem__(self, key)
-            except KeyError:
-                return self.__missing__(key)
-
-        def __missing__(self, key):
-            if self.default_factory is None:
-                raise KeyError(key)
-            self[key] = value = self.default_factory()
-            return value
-
-        def __reduce__(self):
-            f = self.default_factory
-            args = f is None and tuple() or f
-            return type(self), args, None, None, self.iteritems()
-
-        def copy(self):
-            return self.__copy__()
-
-        def __copy__(self):
-            return type(self)(self.default_factory, self)
-
-        def __deepcopy__(self):
-            import copy
-            return type(self)(self.default_factory,
-                        copy.deepcopy(self.items()))
-
-        def __repr__(self):
-            return "defaultdict(%s, %s)" % (self.default_factory,
-                                            dict.__repr__(self))
-    import collections
-    collections.defaultdict = defaultdict           # Pickle needs this.
-
 ############## logging.LoggerAdapter ########################################
 ############## logging.LoggerAdapter ########################################
 import inspect
 import inspect
 import logging
 import logging
@@ -326,16 +244,6 @@ except ImportError:
     multiprocessing = None  # noqa
     multiprocessing = None  # noqa
 import sys
 import sys
 
 
-from logging import LogRecord
-
-log_takes_extra = "extra" in inspect.getargspec(logging.Logger._log)[0]
-
-# The func argument to LogRecord was added in 2.5
-if "func" not in inspect.getargspec(LogRecord.__init__)[0]:
-
-    def LogRecord(name, level, fn, lno, msg, args, exc_info, func):
-        return logging.LogRecord(name, level, fn, lno, msg, args, exc_info)
-
 
 
 def _checkLevel(level):
 def _checkLevel(level):
     if isinstance(level, int):
     if isinstance(level, int):
@@ -390,7 +298,7 @@ class _CompatLoggerAdapter(object):
 
 
     def makeRecord(self, name, level, fn, lno, msg, args, exc_info,
     def makeRecord(self, name, level, fn, lno, msg, args, exc_info,
             func=None, extra=None):
             func=None, extra=None):
-        rv = LogRecord(name, level, fn, lno, msg, args, exc_info, func)
+        rv = logging.LogRecord(name, level, fn, lno, msg, args, exc_info, func)
         if extra is not None:
         if extra is not None:
             for key, value in extra.items():
             for key, value in extra.items():
                 if key in ("message", "asctime") or key in rv.__dict__:
                 if key in ("message", "asctime") or key in rv.__dict__:
@@ -441,12 +349,6 @@ try:
 except ImportError:
 except ImportError:
     LoggerAdapter = _CompatLoggerAdapter  # noqa
     LoggerAdapter = _CompatLoggerAdapter  # noqa
 
 
-
-def log_with_extra(logger, level, msg, *args, **kwargs):
-    if not log_takes_extra:
-        kwargs.pop("extra", None)
-    return logger.log(level, msg, *args, **kwargs)
-
 ############## itertools.izip_longest #######################################
 ############## itertools.izip_longest #######################################
 
 
 try:
 try:
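
After this removal the callers touched elsewhere in the commit (celeryd_multi, worker.state, worker.buckets) import these names straight from the standard library. A short sketch of the direct replacements, assuming an interpreter recent enough to ship them (parse_qsl moved into urlparse in Python 2.6, while any/all and collections.defaultdict have been available since 2.5):

    from collections import defaultdict
    from urlparse import parse_qsl

    counts = defaultdict(int)
    for name, _value in parse_qsl("a=1&b=2&a=3"):
        counts[name] += 1

    assert dict(counts) == {"a": 2, "b": 1}
    assert any(v > 1 for v in counts.values())
    assert all(v >= 1 for v in counts.values())
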

+ 2 - 8
celery/utils/serialization.py

@@ -11,7 +11,7 @@ except ImportError:
     cpickle = None  # noqa
     cpickle = None  # noqa
 
 
 if sys.version_info < (2, 6):  # pragma: no cover
 if sys.version_info < (2, 6):  # pragma: no cover
-    # cPickle is broken in Python <= 2.5.
+    # cPickle is broken in Python <= 2.6.
     # It unsafely and incorrectly uses relative instead of absolute imports,
     # It unsafely and incorrectly uses relative instead of absolute imports,
     # so e.g.:
     # so e.g.:
     #       exceptions.KeyError
     #       exceptions.KeyError
@@ -26,14 +26,8 @@ else:
     pickle = cpickle or pypickle
     pickle = cpickle or pypickle
 
 
 
 
-# BaseException was introduced in Python 2.5.
-try:
-    _error_bases = (BaseException, )
-except NameError:  # pragma: no cover
-    _error_bases = (SystemExit, KeyboardInterrupt)
-
 #: List of base classes we probably don't want to reduce to.
 #: List of base classes we probably don't want to reduce to.
-unwanted_base_classes = (StandardError, Exception) + _error_bases + (object, )
+unwanted_base_classes = (StandardError, Exception, BaseException, object)
 
 
 
 
 if sys.version_info < (2, 5):  # pragma: no cover
 if sys.version_info < (2, 5):  # pragma: no cover
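
BaseException has been available since Python 2.5, so the fallback tuple can be spelled out directly. For context, a tuple like unwanted_base_classes marks the point at which walking an exception's ancestry stops being useful; the helper below only illustrates that idea and is not the module's actual function:

    unwanted = (StandardError, Exception, BaseException, object)

    def nearest_specific_base(exc_cls):
        # The immediate parent class, unless it is already a generic base.
        base = exc_cls.__mro__[1]
        return None if base in unwanted else base

    class TransportError(IOError):
        pass

    assert nearest_specific_base(TransportError) is IOError
    assert nearest_specific_base(Exception) is None
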

+ 0 - 3
celery/utils/timer2.py

@@ -1,7 +1,4 @@
 """timer2 - Scheduler for Python functions."""
 """timer2 - Scheduler for Python functions."""
-
-from __future__ import generators
-
 import atexit
 import atexit
 import heapq
 import heapq
 import logging
 import logging

+ 1 - 1
celery/worker/buckets.py

@@ -6,7 +6,7 @@ from Queue import Queue, Empty
 
 
 from celery.datastructures import TokenBucket
 from celery.datastructures import TokenBucket
 from celery.utils import timeutils
 from celery.utils import timeutils
-from celery.utils.compat import all, izip_longest, chain_from_iterable
+from celery.utils.compat import izip_longest, chain_from_iterable
 
 
 
 
 class RateLimitExceeded(Exception):
 class RateLimitExceeded(Exception):

+ 4 - 19
celery/worker/consumer.py

@@ -67,9 +67,6 @@ up and running.
   early, *then* close the connection.
   early, *then* close the connection.
 
 
 """
 """
-
-from __future__ import generators
-
 import socket
 import socket
 import sys
 import sys
 import threading
 import threading
@@ -139,14 +136,11 @@ class QoS(object):
 
 
     def increment(self, n=1):
     def increment(self, n=1):
         """Increment the current prefetch count value by one."""
         """Increment the current prefetch count value by one."""
-        self._mutex.acquire()
-        try:
+        with self._mutex:
             if self.value:
             if self.value:
                 new_value = self.value + max(n, 0)
                 new_value = self.value + max(n, 0)
                 self.value = self.set(new_value)
                 self.value = self.set(new_value)
             return self.value
             return self.value
-        finally:
-            self._mutex.release()
 
 
     def _sub(self, n=1):
     def _sub(self, n=1):
         assert self.value - n > 1
         assert self.value - n > 1
@@ -154,14 +148,11 @@ class QoS(object):
 
 
     def decrement(self, n=1):
     def decrement(self, n=1):
         """Decrement the current prefetch count value by one."""
         """Decrement the current prefetch count value by one."""
-        self._mutex.acquire()
-        try:
+        with self._mutex:
             if self.value:
             if self.value:
                 self._sub(n)
                 self._sub(n)
                 self.set(self.value)
                 self.set(self.value)
             return self.value
             return self.value
-        finally:
-            self._mutex.release()
 
 
     def decrement_eventually(self, n=1):
     def decrement_eventually(self, n=1):
         """Decrement the value, but do not update the qos.
         """Decrement the value, but do not update the qos.
@@ -170,12 +161,9 @@ class QoS(object):
         when necessary.
         when necessary.
 
 
         """
         """
-        self._mutex.acquire()
-        try:
+        with self._mutex:
             if self.value:
             if self.value:
                 self._sub(n)
                 self._sub(n)
-        finally:
-            self._mutex.release()
 
 
     def set(self, pcount):
     def set(self, pcount):
         """Set channel prefetch_count setting."""
         """Set channel prefetch_count setting."""
@@ -193,11 +181,8 @@ class QoS(object):
 
 
     def update(self):
     def update(self):
         """Update prefetch count with current value."""
         """Update prefetch count with current value."""
-        self._mutex.acquire()
-        try:
+        with self._mutex:
             return self.set(self.value)
             return self.set(self.value)
-        finally:
-            self._mutex.release()
 
 
 
 
 class Consumer(object):
 class Consumer(object):
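
The QoS changes above are a mechanical translation from acquire()/try/finally/release() to the with-statement: a threading.Lock or RLock used as a context manager acquires on entry and releases on exit, even if the body raises. A minimal sketch of the pattern, with a toy counter standing in for the prefetch value:

    from __future__ import with_statement

    import threading

    class Counter(object):

        def __init__(self):
            self._mutex = threading.RLock()
            self.value = 0

        def increment(self, n=1):
            # Same effect as mutex.acquire() / try / finally: mutex.release().
            with self._mutex:
                self.value += max(n, 0)
                return self.value

    c = Counter()
    assert c.increment() == 1
    assert c.increment(2) == 3
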

+ 5 - 8
celery/worker/job.py

@@ -1,4 +1,3 @@
-import logging
 import os
 import os
 import sys
 import sys
 import time
 import time
@@ -17,7 +16,6 @@ from celery.execute.trace import TaskTrace
 from celery.registry import tasks
 from celery.registry import tasks
 from celery.utils import noop, kwdict, fun_takes_kwargs
 from celery.utils import noop, kwdict, fun_takes_kwargs
 from celery.utils import get_symbol_by_name, truncate_text
 from celery.utils import get_symbol_by_name, truncate_text
-from celery.utils.compat import log_with_extra
 from celery.utils.encoding import safe_repr, safe_str
 from celery.utils.encoding import safe_repr, safe_str
 from celery.utils.timeutils import maybe_iso8601
 from celery.utils.timeutils import maybe_iso8601
 from celery.worker import state
 from celery.worker import state
@@ -523,12 +521,11 @@ class TaskRequest(object):
                    "args": self.args,
                    "args": self.args,
                    "kwargs": self.kwargs}
                    "kwargs": self.kwargs}
 
 
-        log_with_extra(self.logger, logging.ERROR,
-                       self.error_msg.strip() % context,
-                       exc_info=exc_info,
-                       extra={"data": {"hostname": self.hostname,
-                                       "id": self.task_id,
-                                       "name": self.task_name}})
+        self.logger.error(self.error_msg.strip() % context,
+                          exc_info=exc_info,
+                          extra={"data": {"id": self.task_id,
+                                          "name": self.task_name,
+                                          "hostname": self.hostname}})
 
 
         task_obj = tasks.get(self.task_name, object)
         task_obj = tasks.get(self.task_name, object)
         self.send_error_email(task_obj, context, exc_info.exception,
         self.send_error_email(task_obj, context, exc_info.exception,
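
With log_with_extra removed, the worker calls Logger.error directly and passes structured context through the standard extra argument, which the logging module has supported since Python 2.5; that is what made the shim a 2.4-only workaround. The mediator.py change below follows the same pattern. A sketch of the call shape, where the logger name, task name and ids are made up for illustration:

    import logging

    logging.basicConfig()
    logger = logging.getLogger("celery.worker.job")

    try:
        raise KeyError("missing argument")
    except KeyError:
        logger.error("Task example.add[%s] raised exception", "hypothetical-id",
                     exc_info=True,
                     extra={"data": {"id": "hypothetical-id",
                                     "name": "example.add",
                                     "hostname": "worker1.example.com"}})
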

+ 6 - 9
celery/worker/mediator.py

@@ -3,7 +3,6 @@
 Worker Controller Threads
 Worker Controller Threads
 
 
 """
 """
-import logging
 import os
 import os
 import sys
 import sys
 import threading
 import threading
@@ -12,7 +11,6 @@ import traceback
 from Queue import Empty
 from Queue import Empty
 
 
 from celery.app import app_or_default
 from celery.app import app_or_default
-from celery.utils.compat import log_with_extra
 
 
 
 
 class Mediator(threading.Thread):
 class Mediator(threading.Thread):
@@ -51,13 +49,12 @@ class Mediator(threading.Thread):
         try:
         try:
             self.callback(task)
             self.callback(task)
         except Exception, exc:
         except Exception, exc:
-            log_with_extra(self.logger, logging.ERROR,
-                           "Mediator callback raised exception %r\n%s" % (
-                               exc, traceback.format_exc()),
-                           exc_info=sys.exc_info(),
-                           extra={"data": {"hostname": task.hostname,
-                                           "id": task.task_id,
-                                           "name": task.task_name}})
+            self.logger.error("Mediator callback raised exception %r\n%s" % (
+                                exc, traceback.format_exc()),
+                              exc_info=sys.exc_info(),
+                              extra={"data": {"id": task.task_id,
+                                              "name": task.task_name,
+                                              "hostname": task.hostname}})
 
 
     def run(self):
     def run(self):
         """Move tasks forver or until :meth:`stop` is called."""
         """Move tasks forver or until :meth:`stop` is called."""

+ 2 - 1
celery/worker/state.py

@@ -2,10 +2,11 @@ import os
 import platform
 import platform
 import shelve
 import shelve
 
 
+from collections import defaultdict
+
 from kombu.utils import cached_property
 from kombu.utils import cached_property
 
 
 from celery import __version__
 from celery import __version__
-from celery.utils.compat import defaultdict
 from celery.datastructures import LimitedSet
 from celery.datastructures import LimitedSet
 
 
 #: Worker software/platform information.
 #: Worker software/platform information.