|
@@ -2,21 +2,32 @@ from __future__ import absolute_import
|
|
|
|
|
|
from datetime import timedelta
|
|
|
|
|
|
+from contextlib import contextmanager
|
|
|
from pickle import loads, dumps
|
|
|
|
|
|
from celery import signature
|
|
|
from celery import states
|
|
|
-from celery import group
|
|
|
from celery import uuid
|
|
|
+from celery.canvas import Signature
|
|
|
from celery.datastructures import AttributeDict
|
|
|
-from celery.exceptions import ImproperlyConfigured
|
|
|
+from celery.exceptions import ChordError, ImproperlyConfigured
|
|
|
|
|
|
from celery.tests.case import (
|
|
|
- AppCase, Mock, MockCallbacks, SkipTest,
|
|
|
+ ANY, AppCase, ContextMock, Mock, MockCallbacks, SkipTest,
|
|
|
call, depends_on_current_app, patch,
|
|
|
)
|
|
|
|
|
|
|
|
|
+def raise_on_second_call(mock, exc, *retval):
|
|
|
+
|
|
|
+ def on_first_call(*args, **kwargs):
|
|
|
+ mock.side_effect = exc
|
|
|
+ return mock.return_value
|
|
|
+ mock.side_effect = on_first_call
|
|
|
+ if retval:
|
|
|
+ mock.return_value, = retval
|
|
|
+
|
|
|
+
|
|
|
class Connection(object):
|
|
|
connected = True
|
|
|
|
|
@@ -121,8 +132,14 @@ class test_RedisBackend(AppCase):
|
|
|
|
|
|
return _RedisBackend
|
|
|
|
|
|
+ def get_E_LOST(self):
|
|
|
+ from celery.backends.redis import E_LOST
|
|
|
+ return E_LOST
|
|
|
+
|
|
|
def setup(self):
|
|
|
self.Backend = self.get_backend()
|
|
|
+ self.E_LOST = self.get_E_LOST()
|
|
|
+ self.b = self.Backend(app=self.app)
|
|
|
|
|
|
@depends_on_current_app
|
|
|
def test_reduce(self):
|
|
@@ -184,6 +201,70 @@ class test_RedisBackend(AppCase):
|
|
|
})
|
|
|
self.Backend(app=self.app)
|
|
|
|
|
|
+ @patch('celery.backends.redis.error')
|
|
|
+ def test_on_connection_error(self, error):
|
|
|
+ intervals = iter([10, 20, 30])
|
|
|
+ exc = KeyError()
|
|
|
+ self.assertEqual(
|
|
|
+ self.b.on_connection_error(None, exc, intervals, 1), 10,
|
|
|
+ )
|
|
|
+ error.assert_called_with(self.E_LOST, 1, 'Inf', 'in 10.00 seconds')
|
|
|
+ self.assertEqual(
|
|
|
+ self.b.on_connection_error(10, exc, intervals, 2), 20,
|
|
|
+ )
|
|
|
+ error.assert_called_with(self.E_LOST, 2, 10, 'in 20.00 seconds')
|
|
|
+ self.assertEqual(
|
|
|
+ self.b.on_connection_error(10, exc, intervals, 3), 30,
|
|
|
+ )
|
|
|
+ error.assert_called_with(self.E_LOST, 3, 10, 'in 30.00 seconds')
|
|
|
+
|
|
|
+ def test_incr(self):
|
|
|
+ self.b.client = Mock(name='client')
|
|
|
+ self.b.incr('foo')
|
|
|
+ self.b.client.incr.assert_called_with('foo')
|
|
|
+
|
|
|
+ def test_expire(self):
|
|
|
+ self.b.client = Mock(name='client')
|
|
|
+ self.b.expire('foo', 300)
|
|
|
+ self.b.client.expire.assert_called_with('foo', 300)
|
|
|
+
|
|
|
+ def test_apply_chord(self):
|
|
|
+ header = Mock(name='header')
|
|
|
+ header.results = [Mock(name='t1'), Mock(name='t2')]
|
|
|
+ print(self.b.apply_chord,)
|
|
|
+ self.b.apply_chord(
|
|
|
+ header, (1, 2), 'gid', None,
|
|
|
+ options={'max_retries': 10},
|
|
|
+ )
|
|
|
+ header.assert_called_with(1, 2, max_retries=10, task_id='gid')
|
|
|
+
|
|
|
+ def test_unpack_chord_result(self):
|
|
|
+ self.b.exception_to_python = Mock(name='etp')
|
|
|
+ decode = Mock(name='decode')
|
|
|
+ exc = KeyError()
|
|
|
+ tup = decode.return_value = (1, 'id1', states.FAILURE, exc)
|
|
|
+ with self.assertRaises(ChordError):
|
|
|
+ self.b._unpack_chord_result(tup, decode)
|
|
|
+ decode.assert_called_with(tup)
|
|
|
+ self.b.exception_to_python.assert_called_with(exc)
|
|
|
+
|
|
|
+ exc = ValueError()
|
|
|
+ tup = decode.return_value = (2, 'id2', states.RETRY, exc)
|
|
|
+ ret = self.b._unpack_chord_result(tup, decode)
|
|
|
+ self.b.exception_to_python.assert_called_with(exc)
|
|
|
+ self.assertIs(ret, self.b.exception_to_python())
|
|
|
+
|
|
|
+ def test_on_chord_part_return_no_gid_or_tid(self):
|
|
|
+ request = Mock(name='request')
|
|
|
+ request.id = request.group = None
|
|
|
+ self.assertIsNone(self.b.on_chord_part_return(request, 'SUCCESS', 10))
|
|
|
+
|
|
|
+ def test_ConnectionPool(self):
|
|
|
+ self.b.redis = Mock(name='redis')
|
|
|
+ self.assertIsNone(self.b._ConnectionPool)
|
|
|
+ self.assertIs(self.b.ConnectionPool, self.b.redis.ConnectionPool)
|
|
|
+ self.assertIs(self.b.ConnectionPool, self.b.redis.ConnectionPool)
|
|
|
+
|
|
|
def test_expires_defaults_to_config(self):
|
|
|
self.app.conf.result_expires = 10
|
|
|
b = self.Backend(expires=None, app=self.app)
|
|
@@ -210,68 +291,110 @@ class test_RedisBackend(AppCase):
|
|
|
b = self.Backend(expires=timedelta(minutes=1), app=self.app)
|
|
|
self.assertEqual(b.expires, 60)
|
|
|
|
|
|
- def test_apply_chord(self):
|
|
|
- self.Backend(app=self.app).apply_chord(
|
|
|
- group(app=self.app), (), 'group_id', {},
|
|
|
- result=[self.app.AsyncResult(x) for x in [1, 2, 3]],
|
|
|
- )
|
|
|
-
|
|
|
def test_mget(self):
|
|
|
- b = self.Backend(app=self.app)
|
|
|
- self.assertTrue(b.mget(['a', 'b', 'c']))
|
|
|
- b.client.mget.assert_called_with(['a', 'b', 'c'])
|
|
|
+ self.assertTrue(self.b.mget(['a', 'b', 'c']))
|
|
|
+ self.b.client.mget.assert_called_with(['a', 'b', 'c'])
|
|
|
|
|
|
def test_set_no_expire(self):
|
|
|
- b = self.Backend(app=self.app)
|
|
|
- b.expires = None
|
|
|
- b.set('foo', 'bar')
|
|
|
+ self.b.expires = None
|
|
|
+ self.b.set('foo', 'bar')
|
|
|
+
|
|
|
+ def create_task(self):
|
|
|
+ tid = uuid()
|
|
|
+ task = Mock(name='task-{0}'.format(tid))
|
|
|
+ task.name = 'foobarbaz'
|
|
|
+ self.app.tasks['foobarbaz'] = task
|
|
|
+ task.request.chord = signature(task)
|
|
|
+ task.request.id = tid
|
|
|
+ task.request.chord['chord_size'] = 10
|
|
|
+ task.request.group = 'group_id'
|
|
|
+ return task
|
|
|
|
|
|
@patch('celery.result.GroupResult.restore')
|
|
|
def test_on_chord_part_return(self, restore):
|
|
|
- b = self.Backend(app=self.app)
|
|
|
-
|
|
|
- def create_task():
|
|
|
- tid = uuid()
|
|
|
- task = Mock(name='task-{0}'.format(tid))
|
|
|
- task.name = 'foobarbaz'
|
|
|
- self.app.tasks['foobarbaz'] = task
|
|
|
- task.request.chord = signature(task)
|
|
|
- task.request.id = tid
|
|
|
- task.request.chord['chord_size'] = 10
|
|
|
- task.request.group = 'group_id'
|
|
|
- return task
|
|
|
-
|
|
|
- tasks = [create_task() for i in range(10)]
|
|
|
+ tasks = [self.create_task() for i in range(10)]
|
|
|
|
|
|
for i in range(10):
|
|
|
- b.on_chord_part_return(tasks[i].request, states.SUCCESS, i)
|
|
|
- self.assertTrue(b.client.rpush.call_count)
|
|
|
- b.client.rpush.reset_mock()
|
|
|
- self.assertTrue(b.client.lrange.call_count)
|
|
|
- jkey = b.get_key_for_group('group_id', '.j')
|
|
|
- tkey = b.get_key_for_group('group_id', '.t')
|
|
|
- b.client.delete.assert_has_calls([call(jkey), call(tkey)])
|
|
|
- b.client.expire.assert_has_calls([
|
|
|
+ self.b.on_chord_part_return(tasks[i].request, states.SUCCESS, i)
|
|
|
+ self.assertTrue(self.b.client.rpush.call_count)
|
|
|
+ self.b.client.rpush.reset_mock()
|
|
|
+ self.assertTrue(self.b.client.lrange.call_count)
|
|
|
+ jkey = self.b.get_key_for_group('group_id', '.j')
|
|
|
+ tkey = self.b.get_key_for_group('group_id', '.t')
|
|
|
+ self.b.client.delete.assert_has_calls([call(jkey), call(tkey)])
|
|
|
+ self.b.client.expire.assert_has_calls([
|
|
|
call(jkey, 86400), call(tkey, 86400),
|
|
|
])
|
|
|
|
|
|
+ def test_on_chord_part_return__success(self):
|
|
|
+ with self.chord_context(2) as (_, request, callback):
|
|
|
+ self.b.on_chord_part_return(request, states.SUCCESS, 10)
|
|
|
+ self.assertFalse(callback.delay.called)
|
|
|
+ self.b.on_chord_part_return(request, states.SUCCESS, 20)
|
|
|
+ callback.delay.assert_called_with([10, 20])
|
|
|
+
|
|
|
+ def test_on_chord_part_return__callback_raises(self):
|
|
|
+ with self.chord_context(1) as (_, request, callback):
|
|
|
+ callback.delay.side_effect = KeyError(10)
|
|
|
+ task = self.app._tasks['add'] = Mock(name='add_task')
|
|
|
+ self.b.on_chord_part_return(request, states.SUCCESS, 10)
|
|
|
+ task.backend.fail_from_current_stack.assert_called_with(
|
|
|
+ callback.id, exc=ANY,
|
|
|
+ )
|
|
|
+
|
|
|
+ def test_on_chord_part_return__ChordError(self):
|
|
|
+ with self.chord_context(1) as (_, request, callback):
|
|
|
+ self.b.client.pipeline = ContextMock()
|
|
|
+ raise_on_second_call(self.b.client.pipeline, ChordError())
|
|
|
+ self.b.client.pipeline.return_value.rpush().llen().get().expire(
|
|
|
+ ).expire().execute.return_value = (1, 1, 0, 4, 5)
|
|
|
+ task = self.app._tasks['add'] = Mock(name='add_task')
|
|
|
+ self.b.on_chord_part_return(request, states.SUCCESS, 10)
|
|
|
+ task.backend.fail_from_current_stack.assert_called_with(
|
|
|
+ callback.id, exc=ANY,
|
|
|
+ )
|
|
|
+
|
|
|
+ def test_on_chord_part_return__other_error(self):
|
|
|
+ with self.chord_context(1) as (_, request, callback):
|
|
|
+ self.b.client.pipeline = ContextMock()
|
|
|
+ raise_on_second_call(self.b.client.pipeline, RuntimeError())
|
|
|
+ self.b.client.pipeline.return_value.rpush().llen().get().expire(
|
|
|
+ ).expire().execute.return_value = (1, 1, 0, 4, 5)
|
|
|
+ task = self.app._tasks['add'] = Mock(name='add_task')
|
|
|
+ self.b.on_chord_part_return(request, states.SUCCESS, 10)
|
|
|
+ task.backend.fail_from_current_stack.assert_called_with(
|
|
|
+ callback.id, exc=ANY,
|
|
|
+ )
|
|
|
+
|
|
|
+ @contextmanager
|
|
|
+ def chord_context(self, size=1):
|
|
|
+ with patch('celery.backends.redis.maybe_signature') as ms:
|
|
|
+ tasks = [self.create_task() for i in range(size)]
|
|
|
+ request = Mock(name='request')
|
|
|
+ request.id = 'id1'
|
|
|
+ request.group = 'gid1'
|
|
|
+ callback = ms.return_value = Signature('add')
|
|
|
+ callback.id = 'id1'
|
|
|
+ callback['chord_size'] = size
|
|
|
+ callback.delay = Mock(name='callback.delay')
|
|
|
+ yield tasks, request, callback
|
|
|
+
|
|
|
def test_process_cleanup(self):
|
|
|
- self.Backend(app=self.app).process_cleanup()
|
|
|
+ self.b.process_cleanup()
|
|
|
|
|
|
def test_get_set_forget(self):
|
|
|
- b = self.Backend(app=self.app)
|
|
|
tid = uuid()
|
|
|
- b.store_result(tid, 42, states.SUCCESS)
|
|
|
- self.assertEqual(b.get_status(tid), states.SUCCESS)
|
|
|
- self.assertEqual(b.get_result(tid), 42)
|
|
|
- b.forget(tid)
|
|
|
- self.assertEqual(b.get_status(tid), states.PENDING)
|
|
|
+ self.b.store_result(tid, 42, states.SUCCESS)
|
|
|
+ self.assertEqual(self.b.get_status(tid), states.SUCCESS)
|
|
|
+ self.assertEqual(self.b.get_result(tid), 42)
|
|
|
+ self.b.forget(tid)
|
|
|
+ self.assertEqual(self.b.get_status(tid), states.PENDING)
|
|
|
|
|
|
def test_set_expires(self):
|
|
|
- b = self.Backend(expires=512, app=self.app)
|
|
|
+ self.b = self.Backend(expires=512, app=self.app)
|
|
|
tid = uuid()
|
|
|
- key = b.get_key_for_task(tid)
|
|
|
- b.store_result(tid, 42, states.SUCCESS)
|
|
|
- b.client.expire.assert_called_with(
|
|
|
+ key = self.b.get_key_for_task(tid)
|
|
|
+ self.b.store_result(tid, 42, states.SUCCESS)
|
|
|
+ self.b.client.expire.assert_called_with(
|
|
|
key, 512,
|
|
|
)
|