|
@@ -4,52 +4,55 @@ from datetime import timedelta
|
|
|
|
|
|
from pickle import loads, dumps
|
|
|
|
|
|
-from kombu.utils import cached_property, uuid
|
|
|
-
|
|
|
from celery import signature
|
|
|
from celery import states
|
|
|
from celery import group
|
|
|
+from celery import uuid
|
|
|
from celery.datastructures import AttributeDict
|
|
|
from celery.exceptions import ImproperlyConfigured
|
|
|
from celery.utils.timeutils import timedelta_seconds
|
|
|
|
|
|
from celery.tests.case import (
|
|
|
- AppCase, Mock, SkipTest, depends_on_current_app, patch,
|
|
|
+ AppCase, Mock, MockCallbacks, SkipTest, depends_on_current_app, patch,
|
|
|
)
|
|
|
|
|
|
|
|
|
-class Redis(object):
|
|
|
+class Connection(object):
|
|
|
+ connected = True
|
|
|
+
|
|
|
+ def disconnect(self):
|
|
|
+ self.connected = False
|
|
|
+
|
|
|
|
|
|
- class Connection(object):
|
|
|
- connected = True
|
|
|
+class Pipeline(object):
|
|
|
|
|
|
- def disconnect(self):
|
|
|
- self.connected = False
|
|
|
+ def __init__(self, client):
|
|
|
+ self.client = client
|
|
|
+ self.steps = []
|
|
|
|
|
|
- class Pipeline(object):
|
|
|
+ def __getattr__(self, attr):
|
|
|
|
|
|
- def __init__(self, client):
|
|
|
- self.client = client
|
|
|
- self.steps = []
|
|
|
+ def add_step(*args, **kwargs):
|
|
|
+ self.steps.append((getattr(self.client, attr), args, kwargs))
|
|
|
+ return self
|
|
|
+ return add_step
|
|
|
|
|
|
- def __getattr__(self, attr):
|
|
|
+ def execute(self):
|
|
|
+ return [step(*a, **kw) for step, a, kw in self.steps]
|
|
|
|
|
|
- def add_step(*args, **kwargs):
|
|
|
- self.steps.append((getattr(self.client, attr), args, kwargs))
|
|
|
- return self
|
|
|
- return add_step
|
|
|
|
|
|
- def execute(self):
|
|
|
- return [step(*a, **kw) for step, a, kw in self.steps]
|
|
|
+class Redis(MockCallbacks):
|
|
|
+ Connection = Connection
|
|
|
+ Pipeline = Pipeline
|
|
|
|
|
|
def __init__(self, host=None, port=None, db=None, password=None, **kw):
|
|
|
self.host = host
|
|
|
self.port = port
|
|
|
self.db = db
|
|
|
self.password = password
|
|
|
- self.connection = self.Connection()
|
|
|
self.keyspace = {}
|
|
|
self.expiry = {}
|
|
|
+ self.connection = self.Connection()
|
|
|
|
|
|
def get(self, key):
|
|
|
return self.keyspace.get(key)
|
|
@@ -63,16 +66,30 @@ class Redis(object):
|
|
|
|
|
|
def expire(self, key, expires):
|
|
|
self.expiry[key] = expires
|
|
|
+ return expires
|
|
|
|
|
|
def delete(self, key):
|
|
|
- self.keyspace.pop(key)
|
|
|
-
|
|
|
- def publish(self, key, value):
|
|
|
- pass
|
|
|
+ return bool(self.keyspace.pop(key, None))
|
|
|
|
|
|
def pipeline(self):
|
|
|
return self.Pipeline(self)
|
|
|
|
|
|
+ def _get_list(self, key):
|
|
|
+ try:
|
|
|
+ return self.keyspace[key]
|
|
|
+ except KeyError:
|
|
|
+ l = self.keyspace[key] = []
|
|
|
+ return l
|
|
|
+
|
|
|
+ def rpush(self, key, value):
|
|
|
+ self._get_list(key).append(value)
|
|
|
+
|
|
|
+ def lrange(self, key, start, stop):
|
|
|
+ return self._get_list(key)[start:stop]
|
|
|
+
|
|
|
+ def llen(self, key):
|
|
|
+ return len(self.keyspace.get(key) or [])
|
|
|
+
|
|
|
|
|
|
class redis(object):
|
|
|
Redis = Redis
|
|
@@ -91,41 +108,34 @@ class redis(object):
|
|
|
class test_RedisBackend(AppCase):
|
|
|
|
|
|
def get_backend(self):
|
|
|
- from celery.backends import redis
|
|
|
+ from celery.backends.redis import RedisBackend
|
|
|
|
|
|
- class RedisBackend(redis.RedisBackend):
|
|
|
+ class _RedisBackend(RedisBackend):
|
|
|
redis = redis
|
|
|
|
|
|
- return RedisBackend
|
|
|
+ return _RedisBackend
|
|
|
|
|
|
def setup(self):
|
|
|
self.Backend = self.get_backend()
|
|
|
|
|
|
- class MockBackend(self.Backend):
|
|
|
-
|
|
|
- @cached_property
|
|
|
- def client(self):
|
|
|
- return Mock()
|
|
|
-
|
|
|
- self.MockBackend = MockBackend
|
|
|
-
|
|
|
@depends_on_current_app
|
|
|
def test_reduce(self):
|
|
|
try:
|
|
|
from celery.backends.redis import RedisBackend
|
|
|
- x = RedisBackend(app=self.app)
|
|
|
+ x = RedisBackend(app=self.app, new_join=True)
|
|
|
self.assertTrue(loads(dumps(x)))
|
|
|
except ImportError:
|
|
|
raise SkipTest('redis not installed')
|
|
|
|
|
|
def test_no_redis(self):
|
|
|
- self.MockBackend.redis = None
|
|
|
+ self.Backend.redis = None
|
|
|
with self.assertRaises(ImproperlyConfigured):
|
|
|
- self.MockBackend(app=self.app)
|
|
|
+ self.Backend(app=self.app, new_join=True)
|
|
|
|
|
|
def test_url(self):
|
|
|
- x = self.MockBackend(
|
|
|
+ x = self.Backend(
|
|
|
'redis://:bosco@vandelay.com:123//1', app=self.app,
|
|
|
+ new_join=True,
|
|
|
)
|
|
|
self.assertTrue(x.connparams)
|
|
|
self.assertEqual(x.connparams['host'], 'vandelay.com')
|
|
@@ -134,8 +144,9 @@ class test_RedisBackend(AppCase):
|
|
|
self.assertEqual(x.connparams['password'], 'bosco')
|
|
|
|
|
|
def test_socket_url(self):
|
|
|
- x = self.MockBackend(
|
|
|
+ x = self.Backend(
|
|
|
'socket:///tmp/redis.sock?virtual_host=/3', app=self.app,
|
|
|
+ new_join=True,
|
|
|
)
|
|
|
self.assertTrue(x.connparams)
|
|
|
self.assertEqual(x.connparams['path'], '/tmp/redis.sock')
|
|
@@ -148,8 +159,9 @@ class test_RedisBackend(AppCase):
|
|
|
self.assertEqual(x.connparams['db'], 3)
|
|
|
|
|
|
def test_compat_propertie(self):
|
|
|
- x = self.MockBackend(
|
|
|
+ x = self.Backend(
|
|
|
'redis://:bosco@vandelay.com:123//1', app=self.app,
|
|
|
+ new_join=True,
|
|
|
)
|
|
|
with self.assertPendingDeprecation():
|
|
|
self.assertEqual(x.host, 'vandelay.com')
|
|
@@ -167,71 +179,85 @@ class test_RedisBackend(AppCase):
|
|
|
'CELERY_ACCEPT_CONTENT': ['json'],
|
|
|
'CELERY_TASK_RESULT_EXPIRES': None,
|
|
|
})
|
|
|
- self.MockBackend(app=self.app)
|
|
|
+ self.Backend(app=self.app, new_join=True)
|
|
|
|
|
|
def test_expires_defaults_to_config(self):
|
|
|
self.app.conf.CELERY_TASK_RESULT_EXPIRES = 10
|
|
|
- b = self.Backend(expires=None, app=self.app)
|
|
|
+ b = self.Backend(expires=None, app=self.app, new_join=True)
|
|
|
self.assertEqual(b.expires, 10)
|
|
|
|
|
|
def test_expires_is_int(self):
|
|
|
- b = self.Backend(expires=48, app=self.app)
|
|
|
+ b = self.Backend(expires=48, app=self.app, new_join=True)
|
|
|
self.assertEqual(b.expires, 48)
|
|
|
|
|
|
+ def test_set_new_join_from_url_query(self):
|
|
|
+ b = self.Backend('redis://?new_join=True;foobar=1', app=self.app)
|
|
|
+ self.assertEqual(b.on_chord_part_return, b._new_chord_return)
|
|
|
+ self.assertEqual(b.apply_chord, b._new_chord_apply)
|
|
|
+
|
|
|
+ def test_default_is_old_join(self):
|
|
|
+ b = self.Backend(app=self.app)
|
|
|
+ self.assertNotEqual(b.on_chord_part_return, b._new_chord_return)
|
|
|
+ self.assertNotEqual(b.apply_chord, b._new_chord_apply)
|
|
|
+
|
|
|
def test_expires_is_None(self):
|
|
|
- b = self.Backend(expires=None, app=self.app)
|
|
|
+ b = self.Backend(expires=None, app=self.app, new_join=True)
|
|
|
self.assertEqual(b.expires, timedelta_seconds(
|
|
|
self.app.conf.CELERY_TASK_RESULT_EXPIRES))
|
|
|
|
|
|
def test_expires_is_timedelta(self):
|
|
|
- b = self.Backend(expires=timedelta(minutes=1), app=self.app)
|
|
|
+ b = self.Backend(
|
|
|
+ expires=timedelta(minutes=1), app=self.app, new_join=1,
|
|
|
+ )
|
|
|
self.assertEqual(b.expires, 60)
|
|
|
|
|
|
def test_apply_chord(self):
|
|
|
- self.Backend(app=self.app).apply_chord(
|
|
|
+ self.Backend(app=self.app, new_join=True).apply_chord(
|
|
|
group(app=self.app), (), 'group_id', {},
|
|
|
result=[self.app.AsyncResult(x) for x in [1, 2, 3]],
|
|
|
)
|
|
|
|
|
|
def test_mget(self):
|
|
|
- b = self.MockBackend(app=self.app)
|
|
|
+ b = self.Backend(app=self.app, new_join=True)
|
|
|
self.assertTrue(b.mget(['a', 'b', 'c']))
|
|
|
b.client.mget.assert_called_with(['a', 'b', 'c'])
|
|
|
|
|
|
def test_set_no_expire(self):
|
|
|
- b = self.MockBackend(app=self.app)
|
|
|
+ b = self.Backend(app=self.app, new_join=True)
|
|
|
b.expires = None
|
|
|
b.set('foo', 'bar')
|
|
|
|
|
|
@patch('celery.result.GroupResult.restore')
|
|
|
def test_on_chord_part_return(self, restore):
|
|
|
- b = self.MockBackend(app=self.app)
|
|
|
- deps = Mock()
|
|
|
- deps.__len__ = Mock()
|
|
|
- deps.__len__.return_value = 10
|
|
|
- restore.return_value = deps
|
|
|
- b.client.incr.return_value = 1
|
|
|
- task = Mock()
|
|
|
- task.name = 'foobarbaz'
|
|
|
- self.app.tasks['foobarbaz'] = task
|
|
|
- task.request.chord = signature(task)
|
|
|
- task.request.group = 'group_id'
|
|
|
-
|
|
|
- b.on_chord_part_return(task)
|
|
|
- self.assertTrue(b.client.incr.call_count)
|
|
|
-
|
|
|
- b.client.incr.return_value = len(deps)
|
|
|
- b.on_chord_part_return(task)
|
|
|
- deps.join_native.assert_called_with(propagate=True, timeout=3.0)
|
|
|
- deps.delete.assert_called_with()
|
|
|
-
|
|
|
- self.assertTrue(b.client.expire.call_count)
|
|
|
+ b = self.Backend(app=self.app, new_join=True)
|
|
|
+
|
|
|
+ def create_task():
|
|
|
+ tid = uuid()
|
|
|
+ task = Mock(name='task-{0}'.format(tid))
|
|
|
+ task.name = 'foobarbaz'
|
|
|
+ self.app.tasks['foobarbaz'] = task
|
|
|
+ task.request.chord = signature(task)
|
|
|
+ task.request.id = tid
|
|
|
+ task.request.chord['chord_size'] = 10
|
|
|
+ task.request.group = 'group_id'
|
|
|
+ return task
|
|
|
+
|
|
|
+ tasks = [create_task() for i in range(10)]
|
|
|
+
|
|
|
+ for i in range(10):
|
|
|
+ b.on_chord_part_return(tasks[i], states.SUCCESS, i)
|
|
|
+ self.assertTrue(b.client.rpush.call_count)
|
|
|
+ b.client.rpush.reset_mock()
|
|
|
+ self.assertTrue(b.client.lrange.call_count)
|
|
|
+ gkey = b.get_key_for_group('group_id', '.j')
|
|
|
+ b.client.delete.assert_called_with(gkey)
|
|
|
+ b.client.expire.assert_called_witeh(gkey, 86400)
|
|
|
|
|
|
def test_process_cleanup(self):
|
|
|
- self.Backend(app=self.app).process_cleanup()
|
|
|
+ self.Backend(app=self.app, new_join=True).process_cleanup()
|
|
|
|
|
|
def test_get_set_forget(self):
|
|
|
- b = self.Backend(app=self.app)
|
|
|
+ b = self.Backend(app=self.app, new_join=True)
|
|
|
tid = uuid()
|
|
|
b.store_result(tid, 42, states.SUCCESS)
|
|
|
self.assertEqual(b.get_status(tid), states.SUCCESS)
|
|
@@ -240,8 +266,10 @@ class test_RedisBackend(AppCase):
|
|
|
self.assertEqual(b.get_status(tid), states.PENDING)
|
|
|
|
|
|
def test_set_expires(self):
|
|
|
- b = self.Backend(expires=512, app=self.app)
|
|
|
+ b = self.Backend(expires=512, app=self.app, new_join=True)
|
|
|
tid = uuid()
|
|
|
key = b.get_key_for_task(tid)
|
|
|
b.store_result(tid, 42, states.SUCCESS)
|
|
|
- self.assertEqual(b.client.expiry[key], 512)
|
|
|
+ b.client.expire.assert_called_with(
|
|
|
+ key, 512,
|
|
|
+ )
|