123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281 |
- from __future__ import absolute_import
- from datetime import timedelta
- from pickle import loads, dumps
- from celery import signature
- from celery import states
- from celery import group
- from celery import uuid
- from celery.datastructures import AttributeDict
- from celery.exceptions import ImproperlyConfigured
- from celery.utils.timeutils import timedelta_seconds
- from celery.tests.case import (
- AppCase, Mock, MockCallbacks, SkipTest, depends_on_current_app, patch,
- )
class Connection(object):
    """Bare-bones connection object that only tracks an open/closed flag."""

    # Class-level default: an instance reads as connected until
    # ``disconnect`` stores an instance-level ``False`` over it.
    connected = True

    def disconnect(self):
        """Flag this connection as closed."""
        self.connected = False
class Pipeline(object):
    """Record calls destined for ``client`` and replay them on ``execute``.

    Any name that is not a real attribute of the pipeline is treated as
    the name of a client method: calling it queues the call (the client
    attribute is only looked up at that moment) and returns the pipeline
    itself, so calls can be chained.  The pipeline is also usable as a
    context manager; entering returns the pipeline and exiting does
    nothing.
    """

    # Names that hold the pipeline's own state.  They only reach
    # ``__getattr__`` when the instance was created without ``__init__``.
    _own_state = ('client', 'steps')

    def __init__(self, client):
        self.client = client
        self.steps = []

    def __getattr__(self, attr):
        """Return a recorder that queues a call to ``client.<attr>``.

        Raises:
            AttributeError: for dunder names and for the pipeline's own
                state attributes, none of which are client commands.
        """
        # ``__getattr__`` only runs when normal lookup fails.  Dunder
        # names are protocol probes (``copy`` and ``pickle`` look for
        # ``__deepcopy__``/``__setstate__`` and friends), so answering
        # them with a recorder makes ``hasattr`` lie and breaks copying.
        is_dunder = attr.startswith('__') and attr.endswith('__')
        if is_dunder or attr in self._own_state:
            raise AttributeError(attr)

        def add_step(*args, **kwargs):
            self.steps.append((getattr(self.client, attr), args, kwargs))
            return self
        return add_step

    def __enter__(self):
        return self

    def __exit__(self, type, value, traceback):
        # Nothing to release; exceptions propagate because the implicit
        # return value is None.
        pass

    def execute(self):
        """Invoke every queued call in order and return their results."""
        return [step(*a, **kw) for step, a, kw in self.steps]
class Redis(MockCallbacks):
    """Dictionary-backed fake of a redis client for the backend tests.

    Values are stored in ``keyspace`` and the most recent timeout given
    for a key is stored in ``expiry``; nothing here ever actually expires.

    NOTE(review): the base class ``MockCallbacks`` comes from
    ``celery.tests.case``; how it wraps these methods is not visible in
    this file.
    """

    Connection = Connection
    Pipeline = Pipeline

    def __init__(self, host=None, port=None, db=None, password=None, **kw):
        # Connection settings are recorded only; any extra keyword
        # arguments are accepted and ignored.
        self.host = host
        self.port = port
        self.db = db
        self.password = password
        self.keyspace = {}
        self.expiry = {}
        self.connection = self.Connection()

    def get(self, key):
        """Return the value stored under ``key``, or None if absent."""
        return self.keyspace.get(key)

    def setex(self, key, value, expires):
        """Store ``value`` under ``key`` and record ``expires`` for it."""
        self.set(key, value)
        self.expire(key, expires)

    def set(self, key, value):
        """Store ``value`` under ``key``, replacing any previous value."""
        self.keyspace[key] = value

    def expire(self, key, expires):
        """Record ``expires`` as the timeout for ``key`` and return it."""
        self.expiry[key] = expires
        return expires

    def delete(self, key):
        """Remove ``key``; return True if it was present, else False."""
        # Decide on membership, not on the truthiness of the removed
        # value: a key holding '' / 0 / [] was still deleted.
        existed = key in self.keyspace
        self.keyspace.pop(key, None)
        return existed

    def pipeline(self):
        """Return a new pipeline that records calls against this client."""
        return self.Pipeline(self)

    def _get_list(self, key):
        # Return the list stored at ``key``, creating an empty one first
        # if the key does not exist yet.
        return self.keyspace.setdefault(key, [])

    def rpush(self, key, value):
        """Append ``value`` to the list at ``key``, creating it if needed."""
        self._get_list(key).append(value)

    def lrange(self, key, start, stop):
        """Return the items of the list at ``key`` from ``start`` to ``stop``.

        As with the Redis LRANGE command, ``stop`` is inclusive and
        negative offsets count from the end (``-1`` is the last item).
        A missing key yields an empty list and is not created.
        """
        items = self.keyspace.get(key) or []
        # Convert the inclusive ``stop`` to an exclusive slice bound;
        # ``-1 + 1`` would be 0 (an empty slice), hence the open end.
        end = None if stop == -1 else stop + 1
        return items[start:end]

    def llen(self, key):
        """Return the length of the list at ``key`` (0 if absent)."""
        return len(self.keyspace.get(key) or [])
class redis(object):
    """Namespace the tests assign as the backend's ``redis`` attribute."""

    # Client class exposed under the name ``Redis``.
    Redis = Redis

    class ConnectionPool(object):
        """Placeholder pool type."""

        def __init__(self, **kwargs):
            """Accept any keyword options and ignore them."""

    class UnixDomainSocketConnection(object):
        """Placeholder connection class for socket URLs."""

        def __init__(self, **kwargs):
            """Accept any keyword options and ignore them."""
class test_RedisBackend(AppCase):
    """Tests for ``RedisBackend`` run against this module's fake ``redis``."""

    def get_backend(self):
        # Subclass the real backend, overriding its ``redis`` class
        # attribute with the stand-in defined in this module.
        from celery.backends.redis import RedisBackend

        class _RedisBackend(RedisBackend):
            redis = redis

        return _RedisBackend

    def setup(self):
        # Each test works on its own backend subclass.
        self.Backend = self.get_backend()

    @depends_on_current_app
    def test_reduce(self):
        # The backend must survive a pickle round trip; an ImportError
        # anywhere along the way is reported as a skipped test.
        try:
            from celery.backends.redis import RedisBackend
            backend = RedisBackend(app=self.app, new_join=True)
            restored = loads(dumps(backend))
            self.assertTrue(restored)
        except ImportError:
            raise SkipTest('redis not installed')

    def test_no_redis(self):
        # With the ``redis`` attribute cleared, construction must fail.
        backend_cls = self.Backend
        backend_cls.redis = None
        with self.assertRaises(ImproperlyConfigured):
            backend_cls(app=self.app, new_join=True)

    def test_url(self):
        # Every component of a redis:// URL ends up in ``connparams``.
        backend = self.Backend(
            'redis://:bosco@vandelay.com:123//1',
            app=self.app, new_join=True,
        )
        params = backend.connparams
        self.assertTrue(params)
        self.assertEqual(params['host'], 'vandelay.com')
        self.assertEqual(params['db'], 1)
        self.assertEqual(params['port'], 123)
        self.assertEqual(params['password'], 'bosco')

    def test_socket_url(self):
        # A socket:// URL yields a path and connection class, and no
        # host/port entries.
        backend = self.Backend(
            'socket:///tmp/redis.sock?virtual_host=/3',
            app=self.app, new_join=True,
        )
        params = backend.connparams
        self.assertTrue(params)
        self.assertEqual(params['path'], '/tmp/redis.sock')
        self.assertIs(
            params['connection_class'], redis.UnixDomainSocketConnection,
        )
        for absent in ('host', 'port'):
            self.assertNotIn(absent, params)
        self.assertEqual(params['db'], 3)

    def test_compat_propertie(self):
        # Each compat attribute is read inside its own
        # pending-deprecation assertion, in this order.
        backend = self.Backend(
            'redis://:bosco@vandelay.com:123//1',
            app=self.app, new_join=True,
        )
        expectations = (
            ('host', 'vandelay.com'),
            ('db', 1),
            ('port', 123),
            ('password', 'bosco'),
        )
        for attr, expected in expectations:
            with self.assertPendingDeprecation():
                self.assertEqual(getattr(backend, attr), expected)

    def test_conf_raises_KeyError(self):
        # Constructing against a minimal configuration must not raise.
        settings = {
            'CELERY_RESULT_SERIALIZER': 'json',
            'CELERY_MAX_CACHED_RESULTS': 1,
            'CELERY_ACCEPT_CONTENT': ['json'],
            'CELERY_TASK_RESULT_EXPIRES': None,
        }
        self.app.conf = AttributeDict(settings)
        self.Backend(app=self.app, new_join=True)

    def test_expires_defaults_to_config(self):
        self.app.conf.CELERY_TASK_RESULT_EXPIRES = 10
        backend = self.Backend(expires=None, app=self.app, new_join=True)
        self.assertEqual(backend.expires, 10)

    def test_expires_is_int(self):
        backend = self.Backend(expires=48, app=self.app, new_join=True)
        self.assertEqual(backend.expires, 48)

    def test_set_new_join_from_url_query(self):
        # ``new_join`` given in the URL query selects the new chord hooks.
        backend = self.Backend(
            'redis://?new_join=True;foobar=1', app=self.app,
        )
        self.assertEqual(
            backend.on_chord_part_return, backend._new_chord_return,
        )
        self.assertEqual(backend.apply_chord, backend._new_chord_apply)

    def test_default_is_old_join(self):
        # Without ``new_join`` the new chord hooks are not installed.
        backend = self.Backend(app=self.app)
        self.assertNotEqual(
            backend.on_chord_part_return, backend._new_chord_return,
        )
        self.assertNotEqual(backend.apply_chord, backend._new_chord_apply)

    def test_expires_is_None(self):
        backend = self.Backend(expires=None, app=self.app, new_join=True)
        configured = self.app.conf.CELERY_TASK_RESULT_EXPIRES
        self.assertEqual(backend.expires, timedelta_seconds(configured))

    def test_expires_is_timedelta(self):
        one_minute = timedelta(minutes=1)
        backend = self.Backend(expires=one_minute, app=self.app, new_join=1)
        self.assertEqual(backend.expires, 60)

    def test_apply_chord(self):
        # Smoke test: only checks that the call completes.
        backend = self.Backend(app=self.app, new_join=True)
        backend.apply_chord(
            group(app=self.app), (), 'group_id', {},
            result=[self.app.AsyncResult(x) for x in [1, 2, 3]],
        )

    def test_mget(self):
        keys = ['a', 'b', 'c']
        backend = self.Backend(app=self.app, new_join=True)
        self.assertTrue(backend.mget(keys))
        backend.client.mget.assert_called_with(['a', 'b', 'c'])

    def test_set_no_expire(self):
        # Smoke test: storing with expiry disabled must not raise.
        backend = self.Backend(app=self.app, new_join=True)
        backend.expires = None
        backend.set('foo', 'bar')

    @patch('celery.result.GroupResult.restore')
    def test_on_chord_part_return(self, restore):
        backend = self.Backend(app=self.app, new_join=True)

        def create_task():
            # Build a mock task whose request belongs to group
            # 'group_id' and whose chord expects ten parts.
            tid = uuid()
            task = Mock(name='task-{0}'.format(tid))
            task.name = 'foobarbaz'
            self.app.tasks['foobarbaz'] = task
            request = task.request
            request.chord = signature(task)
            request.id = tid
            request.chord['chord_size'] = 10
            request.group = 'group_id'
            return task

        tasks = [create_task() for _ in range(10)]

        # Every part must push to the group's list; the counter is
        # reset so each iteration checks its own push.
        for index, task in enumerate(tasks):
            backend.on_chord_part_return(task, states.SUCCESS, index)
            self.assertTrue(backend.client.rpush.call_count)
            backend.client.rpush.reset_mock()

        client = backend.client
        self.assertTrue(client.lrange.call_count)
        gkey = backend.get_key_for_group('group_id', '.j')
        client.delete.assert_called_with(gkey)
        client.expire.assert_called_with(gkey, 86400)

    def test_process_cleanup(self):
        # Smoke test: only checks that the call completes.
        backend = self.Backend(app=self.app, new_join=True)
        backend.process_cleanup()

    def test_get_set_forget(self):
        # A stored result is readable until forgotten, after which the
        # task reports PENDING again.
        backend = self.Backend(app=self.app, new_join=True)
        tid = uuid()
        backend.store_result(tid, 42, states.SUCCESS)
        self.assertEqual(backend.get_status(tid), states.SUCCESS)
        self.assertEqual(backend.get_result(tid), 42)
        backend.forget(tid)
        self.assertEqual(backend.get_status(tid), states.PENDING)

    def test_set_expires(self):
        # Storing a result must set the configured expiry on its key.
        backend = self.Backend(expires=512, app=self.app, new_join=True)
        tid = uuid()
        key = backend.get_key_for_task(tid)
        backend.store_result(tid, 42, states.SUCCESS)
        backend.client.expire.assert_called_with(key, 512)
|