|
@@ -4,6 +4,7 @@ from __future__ import absolute_import, unicode_literals
|
|
|
import anyjson
|
|
|
import os
|
|
|
import signal
|
|
|
+import socket
|
|
|
import sys
|
|
|
import time
|
|
|
|
|
@@ -22,6 +23,8 @@ from celery.exceptions import (
|
|
|
WorkerLostError,
|
|
|
InvalidTaskError,
|
|
|
TaskRevokedError,
|
|
|
+ Terminated,
|
|
|
+ Ignore,
|
|
|
)
|
|
|
from celery.five import keys
|
|
|
from celery.task.trace import (
|
|
@@ -39,10 +42,15 @@ from celery.task import task as task_dec
|
|
|
from celery.task.base import Task
|
|
|
from celery.utils import uuid
|
|
|
from celery.worker import job as module
|
|
|
-from celery.worker.job import Request, TaskRequest
|
|
|
+from celery.worker.job import Request, TaskRequest, logger as req_logger
|
|
|
from celery.worker.state import revoked
|
|
|
|
|
|
-from celery.tests.utils import AppCase, Case, assert_signal_called
|
|
|
+from celery.tests.utils import (
|
|
|
+ AppCase,
|
|
|
+ Case,
|
|
|
+ assert_signal_called,
|
|
|
+ body_from_sig,
|
|
|
+)
|
|
|
|
|
|
scratch = {'ACK': False}
|
|
|
some_kwargs_scratchpad = {}
|
|
@@ -235,7 +243,111 @@ class MockEventDispatcher(object):
|
|
|
self.sent.append(event)
|
|
|
|
|
|
|
|
|
-class test_TaskRequest(AppCase):
|
|
|
+class test_Request(AppCase):
|
|
|
+
|
|
|
+ def setup(self):
|
|
|
+
|
|
|
+ @self.app.task()
|
|
|
+ def add(x, y, **kw_):
|
|
|
+ return x + y
|
|
|
+
|
|
|
+ self.add = add
|
|
|
+
|
|
|
+ def get_request(self, sig, Request=Request, **kwargs):
|
|
|
+ return Request(
|
|
|
+ body_from_sig(self.app, sig),
|
|
|
+ on_ack=Mock(),
|
|
|
+ eventer=Mock(),
|
|
|
+ app=self.app,
|
|
|
+ connection_errors=(socket.error, ),
|
|
|
+ task=sig.type,
|
|
|
+ **kwargs
|
|
|
+ )
|
|
|
+
|
|
|
+ def test_invalid_eta_raises_InvalidTaskError(self):
|
|
|
+ with self.assertRaises(InvalidTaskError):
|
|
|
+ self.get_request(self.add.s(2, 2).set(eta='12345'))
|
|
|
+
|
|
|
+ def test_invalid_expires_raises_InvalidTaskError(self):
|
|
|
+ with self.assertRaises(InvalidTaskError):
|
|
|
+ self.get_request(self.add.s(2, 2).set(expires='12345'))
|
|
|
+
|
|
|
+ def test_valid_expires_with_utc_makes_aware(self):
|
|
|
+ with patch('celery.worker.job.maybe_make_aware') as mma:
|
|
|
+ self.get_request(self.add.s(2, 2).set(expires=10))
|
|
|
+ self.assertTrue(mma.called)
|
|
|
+
|
|
|
+ def test_maybe_expire_when_expires_is_None(self):
|
|
|
+ req = self.get_request(self.add.s(2, 2))
|
|
|
+ self.assertFalse(req.maybe_expire())
|
|
|
+
|
|
|
+ def test_on_retry_acks_if_late(self):
|
|
|
+ self.add.acks_late = True
|
|
|
+ try:
|
|
|
+ req = self.get_request(self.add.s(2, 2))
|
|
|
+ req.on_retry(Mock())
|
|
|
+ req.on_ack.assert_called_with(req_logger, req.connection_errors)
|
|
|
+ finally:
|
|
|
+ self.add.acks_late = False
|
|
|
+
|
|
|
+ def test_on_failure_Termianted(self):
|
|
|
+ einfo = None
|
|
|
+ try:
|
|
|
+ raise Terminated('9')
|
|
|
+ except Terminated:
|
|
|
+ einfo = ExceptionInfo()
|
|
|
+ self.assertIsNotNone(einfo)
|
|
|
+ req = self.get_request(self.add.s(2, 2))
|
|
|
+ req.on_failure(einfo)
|
|
|
+ req.eventer.send.assert_called_with(
|
|
|
+ 'task-revoked',
|
|
|
+ uuid=req.id, terminated=True, signum='9', expired=False,
|
|
|
+ )
|
|
|
+
|
|
|
+ def test_log_error_propagates_MemoryError(self):
|
|
|
+ einfo = None
|
|
|
+ try:
|
|
|
+ raise MemoryError()
|
|
|
+ except MemoryError:
|
|
|
+ einfo = ExceptionInfo(internal=True)
|
|
|
+ self.assertIsNotNone(einfo)
|
|
|
+ req = self.get_request(self.add.s(2, 2))
|
|
|
+ with self.assertRaises(MemoryError):
|
|
|
+ req._log_error(einfo)
|
|
|
+
|
|
|
+ def test_log_error_when_Ignore(self):
|
|
|
+ einfo = None
|
|
|
+ try:
|
|
|
+ raise Ignore()
|
|
|
+ except Ignore:
|
|
|
+ einfo = ExceptionInfo(internal=True)
|
|
|
+ self.assertIsNotNone(einfo)
|
|
|
+ req = self.get_request(self.add.s(2, 2))
|
|
|
+ req._log_error(einfo)
|
|
|
+ req.on_ack.assert_called_with(req_logger, req.connection_errors)
|
|
|
+
|
|
|
+ def test_tzlocal_is_cached(self):
|
|
|
+ req = self.get_request(self.add.s(2, 2))
|
|
|
+ req._tzlocal = 'foo'
|
|
|
+ self.assertEqual(req.tzlocal, 'foo')
|
|
|
+
|
|
|
+ def test_execute_magic_kwargs(self):
|
|
|
+ task = self.add.s(2, 2)
|
|
|
+ task._freeze()
|
|
|
+ req = self.get_request(task)
|
|
|
+ self.add.accept_magic_kwargs = True
|
|
|
+ try:
|
|
|
+ pool = Mock()
|
|
|
+ req.execute_using_pool(pool)
|
|
|
+ self.assertTrue(pool.apply_async.called)
|
|
|
+ args = pool.apply_async.call_args[1]['args']
|
|
|
+ self.assertEqual(args[0], task.task)
|
|
|
+ self.assertEqual(args[1], task.id)
|
|
|
+ self.assertEqual(args[2], task.args)
|
|
|
+ kwargs = args[3]
|
|
|
+ self.assertEqual(kwargs.get('task_name'), task.task)
|
|
|
+ finally:
|
|
|
+ self.add.accept_magic_kwargs = False
|
|
|
|
|
|
def test_task_wrapper_repr(self):
|
|
|
tw = TaskRequest(mytask.name, uuid(), [1], {'f': 'x'})
|