Browse Source

Merge branch '3.0'

Ask Solem 12 years ago
parent
commit
243d58ca60

+ 1 - 1
celery/app/log.py

@@ -38,7 +38,7 @@ class TaskFormatter(ColorFormatter):
 
     def format(self, record):
         task = get_current_task()
-        if task:
+        if task and task.request:
             record.__dict__.update(task_id=task.request.id,
                                    task_name=task.name)
         else:

+ 15 - 5
celery/app/task.py

@@ -236,6 +236,10 @@ class Task(object):
     #: Default task expiry time.
     expires = None
 
+    #: Some may expect a request to exist even if the task has not been
+    #: called.  This should probably be deprecated.
+    _default_request = None
+
     __bound__ = False
 
     from_config = (
@@ -274,7 +278,6 @@ class Task(object):
 
             from celery.utils.threads import LocalStack
             self.request_stack = LocalStack()
-            self.request_stack.push(Context())
 
         # PeriodicTask uses this to add itself to the PeriodicTask schedule.
         self.on_bound(app)
@@ -783,10 +786,17 @@ class Task(object):
         """`repr(task)`"""
         return '<@task: {0.name}>'.format(self)
 
-    @property
-    def request(self):
-        """Current request object."""
-        return self.request_stack.top
+    def _get_request(self):
+        """Get current request object."""
+        req = self.request_stack.top
+        if req is None:
+            # task was not called, but some may still expect a request
+            # to be there, perhaps that should be deprecated.
+            if self._default_request is None:
+                self._default_request = Context()
+            return self._default_request
+        return req
+    request = property(_get_request)
 
     @property
     def __name__(self):

+ 6 - 5
celery/task/base.py

@@ -6,7 +6,7 @@
     The task implementation has been moved to :mod:`celery.app.task`.
 
     This contains the backward compatible Task class used in the old API,
-    and shouldn't be used anymore.
+    and shouldn't be used in new applications.
 
 """
 from __future__ import absolute_import
@@ -21,7 +21,8 @@ from celery.utils.log import get_task_logger
 
 #: list of methods that must be classmethods in the old API.
 _COMPAT_CLASSMETHODS = (
-    'delay', 'apply_async', 'retry', 'apply', 'AsyncResult', 'subtask',
+    'delay', 'apply_async', 'retry', 'apply',
+    'AsyncResult', 'subtask', '_get_request',
 )
 
 
@@ -61,10 +62,10 @@ class Task(BaseTask):
     for name in _COMPAT_CLASSMETHODS:
         locals()[name] = reclassmethod(getattr(BaseTask, name))
 
+    @class_property
     @classmethod
-    def _get_request(self):
-        return self.request_stack.top
-    request = class_property(_get_request)
+    def request(cls):
+        return cls._get_request()
 
     @classmethod
     def get_logger(self, **kwargs):

+ 2 - 2
celery/task/trace.py

@@ -387,8 +387,8 @@ def _install_stack_protection():
         def __protected_call__(self, *args, **kwargs):
             stack = self.request_stack
             req = stack.top
-            if req and not req._protected and len(stack) == 2 and \
-                    not req.called_directly:
+            if req and not req._protected and \
+                    len(stack) == 1 and not req.called_directly:
                 req._protected = 1
                 return self.run(*args, **kwargs)
             return orig(self, *args, **kwargs)

+ 21 - 9
celery/tests/contrib/test_abortable.py

@@ -21,19 +21,31 @@ class test_AbortableTask(Case):
 
     def test_is_not_aborted(self):
         t = MyAbortableTask()
-        result = t.apply_async()
-        tid = result.id
-        self.assertFalse(t.is_aborted(task_id=tid))
+        t.push_request()
+        try:
+            result = t.apply_async()
+            tid = result.id
+            self.assertFalse(t.is_aborted(task_id=tid))
+        finally:
+            t.pop_request()
 
     def test_is_aborted_not_abort_result(self):
         t = MyAbortableTask()
         t.AsyncResult = AsyncResult
-        t.request.id = 'foo'
-        self.assertFalse(t.is_aborted())
+        t.push_request()
+        try:
+            t.request.id = 'foo'
+            self.assertFalse(t.is_aborted())
+        finally:
+            t.pop_request()
 
     def test_abort_yields_aborted(self):
         t = MyAbortableTask()
-        result = t.apply_async()
-        result.abort()
-        tid = result.id
-        self.assertTrue(t.is_aborted(task_id=tid))
+        t.push_request()
+        try:
+            result = t.apply_async()
+            result.abort()
+            tid = result.id
+            self.assertTrue(t.is_aborted(task_id=tid))
+        finally:
+            t.pop_request()

+ 1 - 0
celery/tests/tasks/test_sets.py

@@ -148,6 +148,7 @@ class test_TaskSet(Case):
         def xyz():
             pass
         from celery._state import _task_stack
+        xyz.push_request()
         _task_stack.push(xyz)
         try:
             ts.apply_async(publisher=Publisher())

+ 64 - 37
celery/tests/tasks/test_tasks.py

@@ -141,27 +141,37 @@ class test_task_retries(Case):
         self.assertEqual(retry_task_noargs.iterations, 4)
 
     def test_retry_kwargs_can_be_empty(self):
-        with self.assertRaises(RetryTaskError):
-            retry_task_mockapply.retry(args=[4, 4], kwargs=None)
-
-    def test_retry_not_eager(self):
-        retry_task_mockapply.request.called_directly = False
-        exc = Exception('baz')
+        retry_task_mockapply.push_request()
         try:
-            retry_task_mockapply.retry(args=[4, 4], kwargs={'task_retries': 0},
-                                       exc=exc, throw=False)
-            self.assertTrue(retry_task_mockapply.__class__.applied)
+            with self.assertRaises(RetryTaskError):
+                retry_task_mockapply.retry(args=[4, 4], kwargs=None)
         finally:
-            retry_task_mockapply.__class__.applied = 0
+            retry_task_mockapply.pop_request()
 
+    def test_retry_not_eager(self):
+        retry_task_mockapply.push_request()
         try:
-            with self.assertRaises(RetryTaskError):
+            retry_task_mockapply.request.called_directly = False
+            exc = Exception('baz')
+            try:
                 retry_task_mockapply.retry(
                     args=[4, 4], kwargs={'task_retries': 0},
-                    exc=exc, throw=True)
-            self.assertTrue(retry_task_mockapply.__class__.applied)
+                    exc=exc, throw=False,
+                )
+                self.assertTrue(retry_task_mockapply.__class__.applied)
+            finally:
+                retry_task_mockapply.__class__.applied = 0
+
+            try:
+                with self.assertRaises(RetryTaskError):
+                    retry_task_mockapply.retry(
+                        args=[4, 4], kwargs={'task_retries': 0},
+                        exc=exc, throw=True)
+                self.assertTrue(retry_task_mockapply.__class__.applied)
+            finally:
+                retry_task_mockapply.__class__.applied = 0
         finally:
-            retry_task_mockapply.__class__.applied = 0
+            retry_task_mockapply.pop_request()
 
     def test_retry_with_kwargs(self):
         retry_task_customexc.__class__.max_retries = 3
@@ -322,11 +332,16 @@ class test_tasks(Case):
         self.assertTrue(publisher.exchange)
 
     def test_context_get(self):
-        request = self.createTask('c.unittest.t.c.g').request
-        request.foo = 32
-        self.assertEqual(request.get('foo'), 32)
-        self.assertEqual(request.get('bar', 36), 36)
-        request.clear()
+        task = self.createTask('c.unittest.t.c.g')
+        task.push_request()
+        try:
+            request = task.request
+            request.foo = 32
+            self.assertEqual(request.get('foo'), 32)
+            self.assertEqual(request.get('bar', 36), 36)
+            request.clear()
+        finally:
+            task.pop_request()
 
     def test_task_class_repr(self):
         task = self.createTask('c.unittest.t.repr')
@@ -350,9 +365,13 @@ class test_tasks(Case):
 
     def test_after_return(self):
         task = self.createTask('c.unittest.t.after_return')
-        task.request.chord = return_True_task.s()
-        task.after_return('SUCCESS', 1.0, 'foobar', (), {}, None)
-        task.request.clear()
+        task.push_request()
+        try:
+            task.request.chord = return_True_task.s()
+            task.after_return('SUCCESS', 1.0, 'foobar', (), {}, None)
+            task.request.clear()
+        finally:
+            task.pop_request()
 
     def test_send_task_sent_event(self):
         T1 = self.createTask('c.unittest.t.t1')
@@ -393,15 +412,19 @@ class test_tasks(Case):
         def yyy():
             pass
 
-        tid = uuid()
-        yyy.update_state(tid, 'FROBULATING', {'fooz': 'baaz'})
-        self.assertEqual(yyy.AsyncResult(tid).status, 'FROBULATING')
-        self.assertDictEqual(yyy.AsyncResult(tid).result, {'fooz': 'baaz'})
-
-        yyy.request.id = tid
-        yyy.update_state(state='FROBUZATING', meta={'fooz': 'baaz'})
-        self.assertEqual(yyy.AsyncResult(tid).status, 'FROBUZATING')
-        self.assertDictEqual(yyy.AsyncResult(tid).result, {'fooz': 'baaz'})
+        yyy.push_request()
+        try:
+            tid = uuid()
+            yyy.update_state(tid, 'FROBULATING', {'fooz': 'baaz'})
+            self.assertEqual(yyy.AsyncResult(tid).status, 'FROBULATING')
+            self.assertDictEqual(yyy.AsyncResult(tid).result, {'fooz': 'baaz'})
+
+            yyy.request.id = tid
+            yyy.update_state(state='FROBUZATING', meta={'fooz': 'baaz'})
+            self.assertEqual(yyy.AsyncResult(tid).status, 'FROBUZATING')
+            self.assertDictEqual(yyy.AsyncResult(tid).result, {'fooz': 'baaz'})
+        finally:
+            yyy.pop_request()
 
     def test_repr(self):
 
@@ -421,13 +444,17 @@ class test_tasks(Case):
 
     def test_get_logger(self):
         t1 = self.createTask('c.unittest.t.t1')
-        logfh = WhateverIO()
-        logger = t1.get_logger(logfile=logfh, loglevel=0)
-        self.assertTrue(logger)
+        t1.push_request()
+        try:
+            logfh = WhateverIO()
+            logger = t1.get_logger(logfile=logfh, loglevel=0)
+            self.assertTrue(logger)
 
-        t1.request.loglevel = 3
-        logger = t1.get_logger(logfile=logfh, loglevel=None)
-        self.assertTrue(logger)
+            t1.request.loglevel = 3
+            logger = t1.get_logger(logfile=logfh, loglevel=None)
+            self.assertTrue(logger)
+        finally:
+            t1.pop_request()
 
 
 class test_TaskSet(Case):

+ 5 - 4
celery/tests/worker/test_request.py

@@ -619,7 +619,7 @@ class test_TaskRequest(AppCase):
     def test_worker_task_trace_handle_retry(self):
         from celery.exceptions import RetryTaskError
         tid = uuid()
-        mytask.request.update({'id': tid})
+        mytask.push_request(id=tid)
         try:
             raise ValueError('foo')
         except Exception as exc:
@@ -634,12 +634,13 @@ class test_TaskRequest(AppCase):
                 self.assertEqual(mytask.backend.get_status(tid),
                                  states.RETRY)
         finally:
-            mytask.request.clear()
+            mytask.pop_request()
 
     def test_worker_task_trace_handle_failure(self):
         tid = uuid()
-        mytask.request.update({'id': tid})
+        mytask.push_request()
         try:
+            mytask.request.id = tid
             try:
                 raise ValueError('foo')
             except Exception as exc:
@@ -651,7 +652,7 @@ class test_TaskRequest(AppCase):
                 self.assertEqual(mytask.backend.get_status(tid),
                                  states.FAILURE)
         finally:
-            mytask.request.clear()
+            mytask.pop_request()
 
     def test_task_wrapper_mail_attrs(self):
         tw = TaskRequest(mytask.name, uuid(), [], {})