123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656 |
- from __future__ import absolute_import, unicode_literals
- import random
- import ssl
- from contextlib import contextmanager
- from datetime import timedelta
- from pickle import dumps, loads
- import pytest
- from case import ANY, ContextMock, Mock, call, mock, patch, skip
- from celery import signature, states, uuid
- from celery.canvas import Signature
- from celery.exceptions import (ChordError, CPendingDeprecationWarning,
- ImproperlyConfigured)
- from celery.utils.collections import AttributeDict
def raise_on_second_call(mock, exc, *retval):
    """Make *mock* succeed on its first call and raise *exc* afterwards.

    ``mock.side_effect`` is replaced by a callable which, when first invoked,
    re-arms ``mock.side_effect`` with *exc* and returns ``mock.return_value``.
    For a ``Mock`` this means every subsequent call raises *exc*.

    If exactly one extra positional argument is supplied it becomes
    ``mock.return_value``; supplying more than one raises ``ValueError``.
    """
    def _first_call_then_arm(*args, **kwargs):
        mock.side_effect = exc
        return mock.return_value

    mock.side_effect = _first_call_then_arm
    if retval:
        only, = retval  # ValueError unless exactly one value was passed
        mock.return_value = only
class Connection(object):
    """Minimal stand-in for a redis-py connection object."""

    # Class-level default; ``disconnect`` shadows it with an instance attribute.
    connected = True

    def disconnect(self):
        """Mark this particular connection as closed."""
        self.connected = False
class Pipeline(object):
    """Fake redis pipeline: queues client commands and replays them on execute.

    Any attribute not defined here is treated as a client command name.
    Calling it records the call and returns the pipeline, so calls can be
    chained as with a real redis-py pipeline.
    """

    def __init__(self, client):
        self.client = client
        self.steps = []

    def __getattr__(self, attr):
        # The client method is looked up lazily, when the queued command is
        # called, not when the attribute is accessed.
        def queue_command(*args, **kwargs):
            bound = getattr(self.client, attr)
            self.steps.append((bound, args, kwargs))
            return self
        return queue_command

    def __enter__(self):
        return self

    def __exit__(self, type, value, traceback):
        # Returning None lets any exception from the ``with`` body propagate.
        return None

    def execute(self):
        """Run every queued command in order and return their results."""
        results = []
        for command, args, kwargs in self.steps:
            results.append(command(*args, **kwargs))
        return results
class Redis(mock.MockCallbacks):
    """In-memory fake of the parts of ``redis.StrictRedis`` these tests use.

    Keys live in ``keyspace`` and TTLs in ``expiry``; nothing actually expires.
    The ``mock.MockCallbacks`` base (from the ``case`` package) appears to turn
    instances into mocks whose methods both run the code below and record
    calls (tests assert e.g. ``client.rpush.call_count``) -- confirm against
    the ``case`` documentation.
    """

    Connection = Connection
    Pipeline = Pipeline

    def __init__(self, host=None, port=None, db=None, password=None, **kw):
        self.host = host
        self.port = port
        self.db = db
        self.password = password
        self.keyspace = {}
        self.expiry = {}
        self.connection = self.Connection()

    def get(self, key):
        """Return the stored value for *key*, or None if it is missing."""
        return self.keyspace.get(key)

    def setex(self, key, expires, value):
        """Store *value* under *key* and record *expires* as its TTL."""
        self.set(key, value)
        self.expire(key, expires)

    def set(self, key, value):
        self.keyspace[key] = value

    def expire(self, key, expires):
        """Record a TTL for *key* (not enforced) and return it."""
        self.expiry[key] = expires
        return expires

    def delete(self, key):
        """Remove *key* and any recorded TTL.

        Returns True if the key existed, including when its stored value is
        falsy (such as an empty list).
        """
        self.expiry.pop(key, None)
        if key in self.keyspace:
            del self.keyspace[key]
            return True
        return False

    def pipeline(self):
        return self.Pipeline(self)

    def _get_list(self, key):
        """Return the list stored at *key*, creating an empty one if absent."""
        return self.keyspace.setdefault(key, [])

    def rpush(self, key, value):
        self._get_list(key).append(value)

    def lrange(self, key, start, stop):
        """Return elements ``start``..``stop`` of the list at *key*.

        Follows Redis LRANGE semantics: *stop* is inclusive and negative
        indices count from the end, so ``-1`` means "through the last element".
        """
        end = None if stop == -1 else stop + 1
        return self._get_list(key)[start:end]

    def llen(self, key):
        return len(self.keyspace.get(key) or [])
class Sentinel(mock.MockCallbacks):
    """Fake ``redis.sentinel.Sentinel`` backed by fake ``Redis`` clients."""

    def __init__(self, sentinels, min_other_sentinels=0, sentinel_kwargs=None,
                 **connection_kwargs):
        """Create one fake ``Redis`` client per ``(hostname, port)`` pair.

        *sentinel_kwargs* is forwarded to every sentinel client.  It defaults
        to an empty dict, because ``**None`` would raise ``TypeError``.
        """
        if sentinel_kwargs is None:
            sentinel_kwargs = {}
        self.sentinel_kwargs = sentinel_kwargs
        self.sentinels = [Redis(hostname, port, **self.sentinel_kwargs)
                          for hostname, port in sentinels]
        self.min_other_sentinels = min_other_sentinels
        self.connection_kwargs = connection_kwargs

    def master_for(self, service_name, redis_class):
        """Return a randomly chosen sentinel client.

        *service_name* and *redis_class* are ignored.
        """
        return random.choice(self.sentinels)
class redis(object):
    """Fake of the ``redis`` module, attached to backends as ``.redis``."""

    class UnixDomainSocketConnection(object):
        """Placeholder connection class expected for ``socket://`` URLs."""

        def __init__(self, **kwargs):
            """Accept and discard any keyword arguments."""

    class ConnectionPool(object):
        """Placeholder connection pool; construction arguments are ignored."""

        def __init__(self, **kwargs):
            """Accept and discard any keyword arguments."""

    StrictRedis = Redis
class sentinel(object):
    """Fake module namespace, attached to ``SentinelBackend`` as ``.sentinel``.

    It exposes the fake ``Sentinel`` class defined in this file.
    """

    Sentinel = Sentinel
class test_RedisResultConsumer:
    """Tests for the result consumer exposed by the Redis result backend."""

    def get_backend(self):
        """Return a ``RedisBackend`` instance wired to the fake ``redis``."""
        from celery.backends.redis import RedisBackend

        class _RedisBackend(RedisBackend):
            redis = redis

        return _RedisBackend(app=self.app)

    def get_consumer(self):
        """Return the ``result_consumer`` of a freshly created backend."""
        return self.get_backend().result_consumer

    @patch('celery.backends.asynchronous.BaseResultConsumer.on_after_fork')
    def test_on_after_fork(self, parent_method):
        consumer = self.get_consumer()
        consumer.start('none')
        consumer.on_after_fork()
        parent_method.assert_called_once()
        consumer.backend.client.connection_pool.reset.assert_called_once()
        consumer._pubsub.close.assert_called_once()

        # PubSub instance not initialized - exception would be raised
        # when calling .close()
        consumer._pubsub = None
        parent_method.reset_mock()
        consumer.backend.client.connection_pool.reset.reset_mock()
        consumer.on_after_fork()
        parent_method.assert_called_once()
        consumer.backend.client.connection_pool.reset.assert_called_once()

        # Continues on KeyError
        consumer._pubsub = Mock()
        consumer._pubsub.close = Mock(side_effect=KeyError)
        parent_method.reset_mock()
        consumer.backend.client.connection_pool.reset.reset_mock()
        consumer.on_after_fork()
        parent_method.assert_called_once()
        # Previously reset_mock()'ed but never checked: the pool must still
        # be reset, and close() must have been attempted before it raised.
        consumer.backend.client.connection_pool.reset.assert_called_once()
        consumer._pubsub.close.assert_called_once()

    @patch('celery.backends.redis.ResultConsumer.cancel_for')
    @patch('celery.backends.asynchronous.BaseResultConsumer.on_state_change')
    def test_on_state_change(self, parent_method, cancel_for):
        consumer = self.get_consumer()
        meta = {'task_id': 'testing', 'status': states.SUCCESS}
        message = 'hello'
        consumer.on_state_change(meta, message)
        parent_method.assert_called_once_with(meta, message)
        cancel_for.assert_called_once_with(meta['task_id'])

        # Does not call cancel_for for other states
        meta = {'task_id': 'testing2', 'status': states.PENDING}
        parent_method.reset_mock()
        cancel_for.reset_mock()
        consumer.on_state_change(meta, message)
        parent_method.assert_called_once_with(meta, message)
        cancel_for.assert_not_called()
class test_RedisBackend:
    """Tests for ``RedisBackend`` running against the fake ``redis`` module."""

    def get_backend(self):
        """Return a fresh ``RedisBackend`` subclass that uses the fake ``redis``.

        A new subclass is built on every call, so tests that mutate class
        attributes (see ``test_no_redis``) do not leak into other tests.
        """
        from celery.backends.redis import RedisBackend

        class _RedisBackend(RedisBackend):
            redis = redis

        return _RedisBackend

    def get_E_LOST(self):
        """Return ``E_LOST``, the template ``on_connection_error`` logs."""
        from celery.backends.redis import E_LOST
        return E_LOST

    def setup(self):
        self.Backend = self.get_backend()
        self.E_LOST = self.get_E_LOST()
        self.b = self.Backend(app=self.app)

    @pytest.mark.usefixtures('depends_on_current_app')
    @skip.unless_module('redis')
    def test_reduce(self):
        from celery.backends.redis import RedisBackend
        x = RedisBackend(app=self.app)
        assert loads(dumps(x))

    def test_no_redis(self):
        self.Backend.redis = None
        with pytest.raises(ImproperlyConfigured):
            self.Backend(app=self.app)

    def test_url(self):
        self.app.conf.redis_socket_timeout = 30.0
        self.app.conf.redis_socket_connect_timeout = 100.0
        x = self.Backend(
            'redis://:bosco@vandelay.com:123//1', app=self.app,
        )
        assert x.connparams
        assert x.connparams['host'] == 'vandelay.com'
        assert x.connparams['db'] == 1
        assert x.connparams['port'] == 123
        assert x.connparams['password'] == 'bosco'
        assert x.connparams['socket_timeout'] == 30.0
        assert x.connparams['socket_connect_timeout'] == 100.0

    @skip.unless_module('redis')
    def test_timeouts_in_url_coerced(self):
        x = self.Backend(
            ('redis://:bosco@vandelay.com:123//1?'
             'socket_timeout=30&socket_connect_timeout=100'),
            app=self.app,
        )
        assert x.connparams
        assert x.connparams['host'] == 'vandelay.com'
        assert x.connparams['db'] == 1
        assert x.connparams['port'] == 123
        assert x.connparams['password'] == 'bosco'
        assert x.connparams['socket_timeout'] == 30
        assert x.connparams['socket_connect_timeout'] == 100

    def test_socket_url(self):
        self.app.conf.redis_socket_timeout = 30.0
        self.app.conf.redis_socket_connect_timeout = 100.0
        x = self.Backend(
            'socket:///tmp/redis.sock?virtual_host=/3', app=self.app,
        )
        assert x.connparams
        assert x.connparams['path'] == '/tmp/redis.sock'
        assert (x.connparams['connection_class'] is
                redis.UnixDomainSocketConnection)
        assert 'host' not in x.connparams
        assert 'port' not in x.connparams
        assert x.connparams['socket_timeout'] == 30.0
        assert 'socket_connect_timeout' not in x.connparams
        assert x.connparams['db'] == 3

    @skip.unless_module('redis')
    def test_backend_ssl(self):
        self.app.conf.redis_backend_use_ssl = {
            'ssl_cert_reqs': ssl.CERT_REQUIRED,
            'ssl_ca_certs': '/path/to/ca.crt',
            'ssl_certfile': '/path/to/client.crt',
            'ssl_keyfile': '/path/to/client.key',
        }
        self.app.conf.redis_socket_timeout = 30.0
        self.app.conf.redis_socket_connect_timeout = 100.0
        x = self.Backend(
            'redis://:bosco@vandelay.com:123//1', app=self.app,
        )
        assert x.connparams
        assert x.connparams['host'] == 'vandelay.com'
        assert x.connparams['db'] == 1
        assert x.connparams['port'] == 123
        assert x.connparams['password'] == 'bosco'
        assert x.connparams['socket_timeout'] == 30.0
        assert x.connparams['socket_connect_timeout'] == 100.0
        assert x.connparams['ssl_cert_reqs'] == ssl.CERT_REQUIRED
        assert x.connparams['ssl_ca_certs'] == '/path/to/ca.crt'
        assert x.connparams['ssl_certfile'] == '/path/to/client.crt'
        assert x.connparams['ssl_keyfile'] == '/path/to/client.key'

        from redis.connection import SSLConnection
        assert x.connparams['connection_class'] is SSLConnection

    @skip.unless_module('redis')
    def test_backend_ssl_url(self):
        self.app.conf.redis_socket_timeout = 30.0
        self.app.conf.redis_socket_connect_timeout = 100.0
        x = self.Backend(
            'rediss://:bosco@vandelay.com:123//1?ssl_cert_reqs=CERT_REQUIRED',
            app=self.app,
        )
        assert x.connparams
        assert x.connparams['host'] == 'vandelay.com'
        assert x.connparams['db'] == 1
        assert x.connparams['port'] == 123
        assert x.connparams['password'] == 'bosco'
        assert x.connparams['socket_timeout'] == 30.0
        assert x.connparams['socket_connect_timeout'] == 100.0
        assert x.connparams['ssl_cert_reqs'] == ssl.CERT_REQUIRED

        from redis.connection import SSLConnection
        assert x.connparams['connection_class'] is SSLConnection

    @skip.unless_module('redis')
    def test_backend_ssl_url_options(self):
        x = self.Backend(
            (
                'rediss://:bosco@vandelay.com:123//1?ssl_cert_reqs=CERT_NONE'
                '&ssl_ca_certs=%2Fvar%2Fssl%2Fmyca.pem'
                '&ssl_certfile=%2Fvar%2Fssl%2Fredis-server-cert.pem'
                '&ssl_keyfile=%2Fvar%2Fssl%2Fprivate%2Fworker-key.pem'
            ),
            app=self.app,
        )
        assert x.connparams
        assert x.connparams['host'] == 'vandelay.com'
        assert x.connparams['db'] == 1
        assert x.connparams['port'] == 123
        assert x.connparams['password'] == 'bosco'
        assert x.connparams['ssl_cert_reqs'] == ssl.CERT_NONE
        assert x.connparams['ssl_ca_certs'] == '/var/ssl/myca.pem'
        assert x.connparams['ssl_certfile'] == '/var/ssl/redis-server-cert.pem'
        assert x.connparams['ssl_keyfile'] == '/var/ssl/private/worker-key.pem'

    @skip.unless_module('redis')
    def test_backend_ssl_url_cert_none(self):
        x = self.Backend(
            'rediss://:bosco@vandelay.com:123//1?ssl_cert_reqs=CERT_OPTIONAL',
            app=self.app,
        )
        assert x.connparams
        assert x.connparams['host'] == 'vandelay.com'
        assert x.connparams['db'] == 1
        assert x.connparams['port'] == 123
        assert x.connparams['ssl_cert_reqs'] == ssl.CERT_OPTIONAL

        from redis.connection import SSLConnection
        assert x.connparams['connection_class'] is SSLConnection

    @skip.unless_module('redis')
    @pytest.mark.parametrize("uri", [
        'rediss://:bosco@vandelay.com:123//1?ssl_cert_reqs=CERT_KITTY_CATS',
        'rediss://:bosco@vandelay.com:123//1'
    ])
    def test_backend_ssl_url_invalid(self, uri):
        with pytest.raises(ValueError):
            self.Backend(
                uri,
                app=self.app,
            )

    def test_compat_propertie(self):
        x = self.Backend(
            'redis://:bosco@vandelay.com:123//1', app=self.app,
        )
        with pytest.warns(CPendingDeprecationWarning):
            assert x.host == 'vandelay.com'
        with pytest.warns(CPendingDeprecationWarning):
            assert x.db == 1
        with pytest.warns(CPendingDeprecationWarning):
            assert x.port == 123
        with pytest.warns(CPendingDeprecationWarning):
            assert x.password == 'bosco'

    def test_conf_raises_KeyError(self):
        self.app.conf = AttributeDict({
            'result_serializer': 'json',
            'result_cache_max': 1,
            'result_expires': None,
            'accept_content': ['json'],
        })
        self.Backend(app=self.app)

    @patch('celery.backends.redis.logger')
    def test_on_connection_error(self, logger):
        intervals = iter([10, 20, 30])
        exc = KeyError()
        assert self.b.on_connection_error(None, exc, intervals, 1) == 10
        logger.error.assert_called_with(
            self.E_LOST, 1, 'Inf', 'in 10.00 seconds')
        assert self.b.on_connection_error(10, exc, intervals, 2) == 20
        logger.error.assert_called_with(self.E_LOST, 2, 10, 'in 20.00 seconds')
        assert self.b.on_connection_error(10, exc, intervals, 3) == 30
        logger.error.assert_called_with(self.E_LOST, 3, 10, 'in 30.00 seconds')

    def test_incr(self):
        self.b.client = Mock(name='client')
        self.b.incr('foo')
        self.b.client.incr.assert_called_with('foo')

    def test_expire(self):
        self.b.client = Mock(name='client')
        self.b.expire('foo', 300)
        self.b.client.expire.assert_called_with('foo', 300)

    def test_apply_chord(self, unlock='celery.chord_unlock'):
        self.app.tasks[unlock] = Mock()
        header_result = self.app.GroupResult(
            uuid(),
            [self.app.AsyncResult(x) for x in range(3)],
        )
        self.b.apply_chord(header_result, None)
        assert self.app.tasks[unlock].apply_async.call_count == 0

    def test_unpack_chord_result(self):
        self.b.exception_to_python = Mock(name='etp')
        decode = Mock(name='decode')
        exc = KeyError()
        tup = decode.return_value = (1, 'id1', states.FAILURE, exc)
        with pytest.raises(ChordError):
            self.b._unpack_chord_result(tup, decode)
        decode.assert_called_with(tup)
        self.b.exception_to_python.assert_called_with(exc)

        exc = ValueError()
        tup = decode.return_value = (2, 'id2', states.RETRY, exc)
        ret = self.b._unpack_chord_result(tup, decode)
        self.b.exception_to_python.assert_called_with(exc)
        assert ret is self.b.exception_to_python()

    def test_on_chord_part_return_no_gid_or_tid(self):
        request = Mock(name='request')
        request.id = request.group = None
        assert self.b.on_chord_part_return(request, 'SUCCESS', 10) is None

    def test_ConnectionPool(self):
        self.b.redis = Mock(name='redis')
        assert self.b._ConnectionPool is None
        assert self.b.ConnectionPool is self.b.redis.ConnectionPool
        # Second access checks the same object is returned again.
        assert self.b.ConnectionPool is self.b.redis.ConnectionPool

    def test_expires_defaults_to_config(self):
        self.app.conf.result_expires = 10
        b = self.Backend(expires=None, app=self.app)
        assert b.expires == 10

    def test_expires_is_int(self):
        b = self.Backend(expires=48, app=self.app)
        assert b.expires == 48

    def test_add_to_chord(self):
        b = self.Backend('redis://', app=self.app)
        gid = uuid()
        b.add_to_chord(gid, 'sig')
        b.client.incr.assert_called_with(b.get_key_for_group(gid, '.t'), 1)

    def test_expires_is_None(self):
        b = self.Backend(expires=None, app=self.app)
        assert b.expires == self.app.conf.result_expires.total_seconds()

    def test_expires_is_timedelta(self):
        b = self.Backend(expires=timedelta(minutes=1), app=self.app)
        assert b.expires == 60

    def test_mget(self):
        assert self.b.mget(['a', 'b', 'c'])
        self.b.client.mget.assert_called_with(['a', 'b', 'c'])

    def test_set_no_expire(self):
        self.b.expires = None
        self.b.set('foo', 'bar')
        # Without an expiry the value must be stored without a TTL.
        assert self.b.get('foo') == 'bar'
        self.b.client.setex.assert_not_called()

    def create_task(self):
        """Register a mock task 'foobarbaz' that is part of chord group 'group_id'.

        The task's request chord is a signature with ``chord_size`` 10.
        """
        tid = uuid()
        task = Mock(name='task-{0}'.format(tid))
        task.name = 'foobarbaz'
        self.app.tasks['foobarbaz'] = task
        task.request.chord = signature(task)
        task.request.id = tid
        task.request.chord['chord_size'] = 10
        task.request.group = 'group_id'
        return task

    @patch('celery.result.GroupResult.restore')
    def test_on_chord_part_return(self, restore):
        tasks = [self.create_task() for _ in range(10)]

        for i, task in enumerate(tasks):
            self.b.on_chord_part_return(task.request, states.SUCCESS, i)
            assert self.b.client.rpush.call_count
            self.b.client.rpush.reset_mock()
        assert self.b.client.lrange.call_count
        jkey = self.b.get_key_for_group('group_id', '.j')
        tkey = self.b.get_key_for_group('group_id', '.t')
        self.b.client.delete.assert_has_calls([call(jkey), call(tkey)])
        self.b.client.expire.assert_has_calls([
            call(jkey, 86400), call(tkey, 86400),
        ])

    def test_on_chord_part_return__success(self):
        with self.chord_context(2) as (_, request, callback):
            self.b.on_chord_part_return(request, states.SUCCESS, 10)
            callback.delay.assert_not_called()
            self.b.on_chord_part_return(request, states.SUCCESS, 20)
            callback.delay.assert_called_with([10, 20])

    def test_on_chord_part_return__callback_raises(self):
        with self.chord_context(1) as (_, request, callback):
            callback.delay.side_effect = KeyError(10)
            task = self.app._tasks['add'] = Mock(name='add_task')
            self.b.on_chord_part_return(request, states.SUCCESS, 10)
            task.backend.fail_from_current_stack.assert_called_with(
                callback.id, exc=ANY,
            )

    def test_on_chord_part_return__ChordError(self):
        with self.chord_context(1) as (_, request, callback):
            self.b.client.pipeline = ContextMock()
            raise_on_second_call(self.b.client.pipeline, ChordError())
            self.b.client.pipeline.return_value.rpush().llen().get().expire(
            ).expire().execute.return_value = (1, 1, 0, 4, 5)
            task = self.app._tasks['add'] = Mock(name='add_task')
            self.b.on_chord_part_return(request, states.SUCCESS, 10)
            task.backend.fail_from_current_stack.assert_called_with(
                callback.id, exc=ANY,
            )

    def test_on_chord_part_return__other_error(self):
        with self.chord_context(1) as (_, request, callback):
            self.b.client.pipeline = ContextMock()
            raise_on_second_call(self.b.client.pipeline, RuntimeError())
            self.b.client.pipeline.return_value.rpush().llen().get().expire(
            ).expire().execute.return_value = (1, 1, 0, 4, 5)
            task = self.app._tasks['add'] = Mock(name='add_task')
            self.b.on_chord_part_return(request, states.SUCCESS, 10)
            task.backend.fail_from_current_stack.assert_called_with(
                callback.id, exc=ANY,
            )

    @contextmanager
    def chord_context(self, size=1):
        """Yield ``(tasks, request, callback)`` for a chord of *size* parts.

        ``maybe_signature`` is patched to return ``callback``, a
        ``Signature('add')`` whose ``chord_size`` is *size* and whose
        ``delay`` is a Mock.
        """
        with patch('celery.backends.redis.maybe_signature') as ms:
            tasks = [self.create_task() for _ in range(size)]
            request = Mock(name='request')
            request.id = 'id1'
            request.group = 'gid1'
            callback = ms.return_value = Signature('add')
            callback.id = 'id1'
            callback['chord_size'] = size
            callback.delay = Mock(name='callback.delay')
            yield tasks, request, callback

    def test_process_cleanup(self):
        self.b.process_cleanup()

    def test_get_set_forget(self):
        tid = uuid()
        self.b.store_result(tid, 42, states.SUCCESS)
        assert self.b.get_state(tid) == states.SUCCESS
        assert self.b.get_result(tid) == 42
        self.b.forget(tid)
        assert self.b.get_state(tid) == states.PENDING

    def test_set_expires(self):
        self.b = self.Backend(expires=512, app=self.app)
        tid = uuid()
        key = self.b.get_key_for_task(tid)
        self.b.store_result(tid, 42, states.SUCCESS)
        self.b.client.expire.assert_called_with(
            key, 512,
        )
class test_SentinelBackend:
    """Tests for ``SentinelBackend`` using the fake ``redis``/``sentinel``."""

    def get_backend(self):
        """Return a fresh ``SentinelBackend`` subclass bound to the fakes."""
        from celery.backends.redis import SentinelBackend

        class _SentinelBackend(SentinelBackend):
            redis = redis
            sentinel = sentinel

        return _SentinelBackend

    def get_E_LOST(self):
        from celery.backends.redis import E_LOST
        return E_LOST

    def setup(self):
        self.Backend = self.get_backend()
        self.E_LOST = self.get_E_LOST()
        self.b = self.Backend(app=self.app)

    @pytest.mark.usefixtures('depends_on_current_app')
    @skip.unless_module('redis')
    def test_reduce(self):
        from celery.backends.redis import SentinelBackend
        backend = SentinelBackend(app=self.app)
        assert loads(dumps(backend))

    def test_no_redis(self):
        self.Backend.redis = None
        with pytest.raises(ImproperlyConfigured):
            self.Backend(app=self.app)

    def test_url(self):
        self.app.conf.redis_socket_timeout = 30.0
        self.app.conf.redis_socket_connect_timeout = 100.0
        backend = self.Backend(
            'sentinel://:test@github.com:123/1;'
            'sentinel://:test@github.com:124/1',
            app=self.app,
        )
        params = backend.connparams
        assert params
        assert "host" not in params
        assert params['db'] == 1
        assert "port" not in params
        assert params['password'] == "test"

        # Each ';'-separated URL becomes one entry in ``hosts``.
        hosts = params['hosts']
        assert len(hosts) == 2
        assert [hp['host'] for hp in hosts] == ["github.com", "github.com"]
        assert [hp['port'] for hp in hosts] == [123, 124]
        assert [hp['password'] for hp in hosts] == ["test", "test"]
        assert [hp['db'] for hp in hosts] == [1, 1]

    def test_get_sentinel_instance(self):
        backend = self.Backend(
            'sentinel://:test@github.com:123/1;'
            'sentinel://:test@github.com:124/1',
            app=self.app,
        )
        instance = backend._get_sentinel_instance(**backend.connparams)
        assert instance.sentinel_kwargs == {}
        conn_kwargs = instance.connection_kwargs
        assert conn_kwargs['db'] == 1
        assert conn_kwargs['password'] == "test"
        assert len(instance.sentinels) == 2

    def test_get_pool(self):
        backend = self.Backend(
            'sentinel://:test@github.com:123/1;'
            'sentinel://:test@github.com:124/1',
            app=self.app,
        )
        assert backend._get_pool(**backend.connparams)
|