Browse Source

100% coverage for celery.worker.job

Ask Solem 12 years ago
parent
commit
58735efe1a
3 changed files with 138 additions and 15 deletions
  1. 12 4
      celery/tests/utils.py
  2. 115 3
      celery/tests/worker/test_request.py
  3. 11 8
      celery/worker/job.py

+ 12 - 4
celery/tests/utils.py

@@ -18,7 +18,7 @@ import time
 import warnings
 
 from contextlib import contextmanager
-from datetime import timedelta
+from datetime import datetime, timedelta
 from functools import partial, wraps
 from types import ModuleType
 
@@ -597,9 +597,16 @@ def body_from_sig(app, sig, utc=True):
     errbacks = sig.options.pop('link_error', None)
     countdown = sig.options.pop('countdown', None)
     if countdown:
-        sig.options['eta'] = app.now() + timedelta(seconds=countdown)
-    eta = sig.options.pop('eta', None)
-    eta = eta.isoformat() if eta else None
+        eta = app.now() + timedelta(seconds=countdown)
+    else:
+        eta = sig.options.pop('eta', None)
+    if eta and isinstance(eta, datetime):
+        eta = eta.isoformat()
+    expires = sig.options.pop('expires', None)
+    if expires and isinstance(expires, int):
+        expires = app.now() + timedelta(seconds=expires)
+    if expires and isinstance(expires, datetime):
+        expires = expires.isoformat()
     return {
         'task': sig.task,
         'id': sig.id,
@@ -609,4 +616,5 @@ def body_from_sig(app, sig, utc=True):
         'errbacks': [dict(s) for s in errbacks] if errbacks else None,
         'eta': eta,
         'utc': utc,
+        'expires': expires,
     }

+ 115 - 3
celery/tests/worker/test_request.py

@@ -4,6 +4,7 @@ from __future__ import absolute_import, unicode_literals
 import anyjson
 import os
 import signal
+import socket
 import sys
 import time
 
@@ -22,6 +23,8 @@ from celery.exceptions import (
     WorkerLostError,
     InvalidTaskError,
     TaskRevokedError,
+    Terminated,
+    Ignore,
 )
 from celery.five import keys
 from celery.task.trace import (
@@ -39,10 +42,15 @@ from celery.task import task as task_dec
 from celery.task.base import Task
 from celery.utils import uuid
 from celery.worker import job as module
-from celery.worker.job import Request, TaskRequest
+from celery.worker.job import Request, TaskRequest, logger as req_logger
 from celery.worker.state import revoked
 
-from celery.tests.utils import AppCase, Case, assert_signal_called
+from celery.tests.utils import (
+    AppCase,
+    Case,
+    assert_signal_called,
+    body_from_sig,
+)
 
 scratch = {'ACK': False}
 some_kwargs_scratchpad = {}
@@ -235,7 +243,111 @@ class MockEventDispatcher(object):
         self.sent.append(event)
 
 
-class test_TaskRequest(AppCase):
+class test_Request(AppCase):
+
+    def setup(self):
+
+        @self.app.task()
+        def add(x, y, **kw_):
+            return x + y
+
+        self.add = add
+
+    def get_request(self, sig, Request=Request, **kwargs):
+        return Request(
+            body_from_sig(self.app, sig),
+            on_ack=Mock(),
+            eventer=Mock(),
+            app=self.app,
+            connection_errors=(socket.error, ),
+            task=sig.type,
+            **kwargs
+        )
+
+    def test_invalid_eta_raises_InvalidTaskError(self):
+        with self.assertRaises(InvalidTaskError):
+            self.get_request(self.add.s(2, 2).set(eta='12345'))
+
+    def test_invalid_expires_raises_InvalidTaskError(self):
+        with self.assertRaises(InvalidTaskError):
+            self.get_request(self.add.s(2, 2).set(expires='12345'))
+
+    def test_valid_expires_with_utc_makes_aware(self):
+        with patch('celery.worker.job.maybe_make_aware') as mma:
+            self.get_request(self.add.s(2, 2).set(expires=10))
+            self.assertTrue(mma.called)
+
+    def test_maybe_expire_when_expires_is_None(self):
+        req = self.get_request(self.add.s(2, 2))
+        self.assertFalse(req.maybe_expire())
+
+    def test_on_retry_acks_if_late(self):
+        self.add.acks_late = True
+        try:
+            req = self.get_request(self.add.s(2, 2))
+            req.on_retry(Mock())
+            req.on_ack.assert_called_with(req_logger, req.connection_errors)
+        finally:
+            self.add.acks_late = False
+
+    def test_on_failure_Termianted(self):
+        einfo = None
+        try:
+            raise Terminated('9')
+        except Terminated:
+            einfo = ExceptionInfo()
+        self.assertIsNotNone(einfo)
+        req = self.get_request(self.add.s(2, 2))
+        req.on_failure(einfo)
+        req.eventer.send.assert_called_with(
+            'task-revoked',
+            uuid=req.id, terminated=True, signum='9', expired=False,
+        )
+
+    def test_log_error_propagates_MemoryError(self):
+        einfo = None
+        try:
+            raise MemoryError()
+        except MemoryError:
+            einfo = ExceptionInfo(internal=True)
+        self.assertIsNotNone(einfo)
+        req = self.get_request(self.add.s(2, 2))
+        with self.assertRaises(MemoryError):
+            req._log_error(einfo)
+
+    def test_log_error_when_Ignore(self):
+        einfo = None
+        try:
+            raise Ignore()
+        except Ignore:
+            einfo = ExceptionInfo(internal=True)
+        self.assertIsNotNone(einfo)
+        req = self.get_request(self.add.s(2, 2))
+        req._log_error(einfo)
+        req.on_ack.assert_called_with(req_logger, req.connection_errors)
+
+    def test_tzlocal_is_cached(self):
+        req = self.get_request(self.add.s(2, 2))
+        req._tzlocal = 'foo'
+        self.assertEqual(req.tzlocal, 'foo')
+
+    def test_execute_magic_kwargs(self):
+        task = self.add.s(2, 2)
+        task._freeze()
+        req = self.get_request(task)
+        self.add.accept_magic_kwargs = True
+        try:
+            pool = Mock()
+            req.execute_using_pool(pool)
+            self.assertTrue(pool.apply_async.called)
+            args = pool.apply_async.call_args[1]['args']
+            self.assertEqual(args[0], task.task)
+            self.assertEqual(args[1], task.id)
+            self.assertEqual(args[2], task.args)
+            kwargs = args[3]
+            self.assertEqual(kwargs.get('task_name'), task.task)
+        finally:
+            self.add.accept_magic_kwargs = False
 
     def test_task_wrapper_repr(self):
         tw = TaskRequest(mytask.name, uuid(), [1], {'f': 'x'})

+ 11 - 8
celery/worker/job.py

@@ -70,7 +70,7 @@ NEEDS_KWDICT = sys.version_info <= (2, 6)
 
 class Request(object):
     """A request for task execution."""
-    if not IS_PYPY:
+    if not IS_PYPY:  # pragma: no cover
         __slots__ = (
             'app', 'name', 'id', 'args', 'kwargs', 'on_ack', 'delivery_info',
             'hostname', 'eventer', 'connection_errors', 'task', 'eta',
@@ -135,7 +135,7 @@ class Request(object):
         if eta is not None:
             try:
                 self.eta = maybe_iso8601(eta)
-            except (AttributeError, ValueError) as exc:
+            except (AttributeError, ValueError, TypeError) as exc:
                 raise InvalidTaskError(
                     'invalid eta value {0!r}: {1}'.format(eta, exc))
             if utc:
@@ -145,7 +145,7 @@ class Request(object):
         if expires is not None:
             try:
                 self.expires = maybe_iso8601(expires)
-            except (AttributeError, ValueError) as exc:
+            except (AttributeError, ValueError, TypeError) as exc:
                 raise InvalidTaskError(
                     'invalid expires value {0!r}: {1}'.format(expires, exc))
             if utc:
@@ -375,6 +375,7 @@ class Request(object):
     def on_failure(self, exc_info):
         """Handler called if the task raised an exception."""
         task_ready(self)
+        send_failed_event = True
 
         if not exc_info.internal:
             exc = exc_info.exception
@@ -389,12 +390,13 @@ class Request(object):
                     self.task.backend.mark_as_failure(self.id, exc)
                 elif isinstance(exc, Terminated):
                     self._announce_revoked('terminated', True, str(exc), False)
+                    send_failed_event = False  # already sent revoked event
             # (acks_late) acknowledge after result stored.
             if self.task.acks_late:
                 self.acknowledge()
-        self._log_error(exc_info)
+        self._log_error(exc_info, send_failed_event=send_failed_event)
 
-    def _log_error(self, einfo):
+    def _log_error(self, einfo, send_failed_event=True):
         einfo.exception = get_pickled_exception(einfo.exception)
         exception, traceback, exc_info, internal, sargs, skwargs = (
             safe_repr(einfo.exception),
@@ -407,9 +409,10 @@ class Request(object):
         format = self.error_msg
         description = 'raised exception'
         severity = logging.ERROR
-        self.send_event(
-            'task-failed', exception=exception, traceback=traceback,
-        )
+        if send_failed_event:
+            self.send_event(
+                'task-failed', exception=exception, traceback=traceback,
+            )
 
         if internal:
             if isinstance(einfo.exception, MemoryError):