| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517 | from __future__ import absolute_import, unicode_literalsimport randomimport sslfrom contextlib import contextmanagerfrom datetime import timedeltafrom pickle import dumps, loadsimport pytestfrom case import ANY, ContextMock, Mock, call, mock, patch, skipfrom celery import signature, states, uuidfrom celery.canvas import Signaturefrom celery.exceptions import (ChordError, CPendingDeprecationWarning,                               ImproperlyConfigured)from celery.utils.collections import AttributeDictdef raise_on_second_call(mock, exc, *retval):    def on_first_call(*args, **kwargs):        mock.side_effect = exc        return mock.return_value    mock.side_effect = on_first_call    if retval:        mock.return_value, = retvalclass Connection(object):    connected = True    def disconnect(self):        self.connected = Falseclass Pipeline(object):    def __init__(self, client):        self.client = client        self.steps = []    def __getattr__(self, attr):        def add_step(*args, **kwargs):            self.steps.append((getattr(self.client, attr), args, kwargs))            return self        return add_step    def __enter__(self):        return self    def __exit__(self, type, value, traceback):        pass    def execute(self):        return [step(*a, **kw) for step, a, kw in self.steps]class Redis(mock.MockCallbacks):    Connection = Connection    Pipeline = Pipeline    def __init__(self, host=None, port=None, db=None, password=None, **kw):        self.host = host        self.port = port        self.db = db        self.password = password        self.keyspace = {}        self.expiry = {}        self.connection = self.Connection()    def get(self, key):        return self.keyspace.get(key)    def setex(self, key, expires, value):        self.set(key, value)        self.expire(key, expires)    def set(self, key, value):        self.keyspace[key] = value    def expire(self, key, expires):        self.expiry[key] = expires        return expires    def delete(self, key):        return bool(self.keyspace.pop(key, None))    def pipeline(self):        return self.Pipeline(self)    def _get_list(self, key):        try:            return self.keyspace[key]        except KeyError:            l = self.keyspace[key] = []            return l    def rpush(self, key, value):        self._get_list(key).append(value)    def lrange(self, key, start, stop):        return self._get_list(key)[start:stop]    def llen(self, key):        return len(self.keyspace.get(key) or [])class Sentinel(mock.MockCallbacks):    def __init__(self, sentinels, min_other_sentinels=0, sentinel_kwargs=None,                 **connection_kwargs):        self.sentinel_kwargs = sentinel_kwargs        self.sentinels = [Redis(hostname, port, **self.sentinel_kwargs)                          for hostname, port in sentinels]        self.min_other_sentinels = min_other_sentinels        self.connection_kwargs = connection_kwargs    def master_for(self, service_name, redis_class):        return random.choice(self.sentinels)class redis(object):    StrictRedis = Redis    class ConnectionPool(object):        def __init__(self, **kwargs):            pass    class UnixDomainSocketConnection(object):        def __init__(self, **kwargs):            passclass sentinel(object):    Sentinel = Sentinelclass test_RedisBackend:    def get_backend(self):        from celery.backends.redis import RedisBackend        class _RedisBackend(RedisBackend):            redis = redis        return _RedisBackend    def get_E_LOST(self):        from celery.backends.redis import E_LOST        return E_LOST    def setup(self):        self.Backend = self.get_backend()        self.E_LOST = self.get_E_LOST()        self.b = self.Backend(app=self.app)    @pytest.mark.usefixtures('depends_on_current_app')    @skip.unless_module('redis')    def test_reduce(self):        from celery.backends.redis import RedisBackend        x = RedisBackend(app=self.app)        assert loads(dumps(x))    def test_no_redis(self):        self.Backend.redis = None        with pytest.raises(ImproperlyConfigured):            self.Backend(app=self.app)    def test_url(self):        self.app.conf.redis_socket_timeout = 30.0        self.app.conf.redis_socket_connect_timeout = 100.0        x = self.Backend(            'redis://:bosco@vandelay.com:123//1', app=self.app,        )        assert x.connparams        assert x.connparams['host'] == 'vandelay.com'        assert x.connparams['db'] == 1        assert x.connparams['port'] == 123        assert x.connparams['password'] == 'bosco'        assert x.connparams['socket_timeout'] == 30.0        assert x.connparams['socket_connect_timeout'] == 100.0    def test_socket_url(self):        self.app.conf.redis_socket_timeout = 30.0        self.app.conf.redis_socket_connect_timeout = 100.0        x = self.Backend(            'socket:///tmp/redis.sock?virtual_host=/3', app=self.app,        )        assert x.connparams        assert x.connparams['path'] == '/tmp/redis.sock'        assert (x.connparams['connection_class'] is                redis.UnixDomainSocketConnection)        assert 'host' not in x.connparams        assert 'port' not in x.connparams        assert x.connparams['socket_timeout'] == 30.0        assert 'socket_connect_timeout' not in x.connparams        assert x.connparams['db'] == 3    @skip.unless_module('redis')    def test_backend_ssl(self):        self.app.conf.redis_backend_use_ssl = {            'ssl_cert_reqs': ssl.CERT_REQUIRED,            'ssl_ca_certs': '/path/to/ca.crt',            'ssl_certfile': '/path/to/client.crt',            'ssl_keyfile': '/path/to/client.key',        }        self.app.conf.redis_socket_timeout = 30.0        self.app.conf.redis_socket_connect_timeout = 100.0        x = self.Backend(            'redis://:bosco@vandelay.com:123//1', app=self.app,        )        assert x.connparams        assert x.connparams['host'] == 'vandelay.com'        assert x.connparams['db'] == 1        assert x.connparams['port'] == 123        assert x.connparams['password'] == 'bosco'        assert x.connparams['socket_timeout'] == 30.0        assert x.connparams['socket_connect_timeout'] == 100.0        assert x.connparams['ssl_cert_reqs'] == ssl.CERT_REQUIRED        assert x.connparams['ssl_ca_certs'] == '/path/to/ca.crt'        assert x.connparams['ssl_certfile'] == '/path/to/client.crt'        assert x.connparams['ssl_keyfile'] == '/path/to/client.key'        from redis.connection import SSLConnection        assert x.connparams['connection_class'] is SSLConnection    def test_compat_propertie(self):        x = self.Backend(            'redis://:bosco@vandelay.com:123//1', app=self.app,        )        with pytest.warns(CPendingDeprecationWarning):            assert x.host == 'vandelay.com'        with pytest.warns(CPendingDeprecationWarning):            assert x.db == 1        with pytest.warns(CPendingDeprecationWarning):            assert x.port == 123        with pytest.warns(CPendingDeprecationWarning):            assert x.password == 'bosco'    def test_conf_raises_KeyError(self):        self.app.conf = AttributeDict({            'result_serializer': 'json',            'result_cache_max': 1,            'result_expires': None,            'accept_content': ['json'],        })        self.Backend(app=self.app)    @patch('celery.backends.redis.logger')    def test_on_connection_error(self, logger):        intervals = iter([10, 20, 30])        exc = KeyError()        assert self.b.on_connection_error(None, exc, intervals, 1) == 10        logger.error.assert_called_with(            self.E_LOST, 1, 'Inf', 'in 10.00 seconds')        assert self.b.on_connection_error(10, exc, intervals, 2) == 20        logger.error.assert_called_with(self.E_LOST, 2, 10, 'in 20.00 seconds')        assert self.b.on_connection_error(10, exc, intervals, 3) == 30        logger.error.assert_called_with(self.E_LOST, 3, 10, 'in 30.00 seconds')    def test_incr(self):        self.b.client = Mock(name='client')        self.b.incr('foo')        self.b.client.incr.assert_called_with('foo')    def test_expire(self):        self.b.client = Mock(name='client')        self.b.expire('foo', 300)        self.b.client.expire.assert_called_with('foo', 300)    def test_apply_chord(self):        header = Mock(name='header')        header.results = [Mock(name='t1'), Mock(name='t2')]        self.b.apply_chord(            header, (1, 2), 'gid', None,            options={'max_retries': 10},        )        header.assert_called_with(1, 2, max_retries=10, task_id='gid')    def test_unpack_chord_result(self):        self.b.exception_to_python = Mock(name='etp')        decode = Mock(name='decode')        exc = KeyError()        tup = decode.return_value = (1, 'id1', states.FAILURE, exc)        with pytest.raises(ChordError):            self.b._unpack_chord_result(tup, decode)        decode.assert_called_with(tup)        self.b.exception_to_python.assert_called_with(exc)        exc = ValueError()        tup = decode.return_value = (2, 'id2', states.RETRY, exc)        ret = self.b._unpack_chord_result(tup, decode)        self.b.exception_to_python.assert_called_with(exc)        assert ret is self.b.exception_to_python()    def test_on_chord_part_return_no_gid_or_tid(self):        request = Mock(name='request')        request.id = request.group = None        assert self.b.on_chord_part_return(request, 'SUCCESS', 10) is None    def test_ConnectionPool(self):        self.b.redis = Mock(name='redis')        assert self.b._ConnectionPool is None        assert self.b.ConnectionPool is self.b.redis.ConnectionPool        assert self.b.ConnectionPool is self.b.redis.ConnectionPool    def test_expires_defaults_to_config(self):        self.app.conf.result_expires = 10        b = self.Backend(expires=None, app=self.app)        assert b.expires == 10    def test_expires_is_int(self):        b = self.Backend(expires=48, app=self.app)        assert b.expires == 48    def test_add_to_chord(self):        b = self.Backend('redis://', app=self.app)        gid = uuid()        b.add_to_chord(gid, 'sig')        b.client.incr.assert_called_with(b.get_key_for_group(gid, '.t'), 1)    def test_expires_is_None(self):        b = self.Backend(expires=None, app=self.app)        assert b.expires == self.app.conf.result_expires.total_seconds()    def test_expires_is_timedelta(self):        b = self.Backend(expires=timedelta(minutes=1), app=self.app)        assert b.expires == 60    def test_mget(self):        assert self.b.mget(['a', 'b', 'c'])        self.b.client.mget.assert_called_with(['a', 'b', 'c'])    def test_set_no_expire(self):        self.b.expires = None        self.b.set('foo', 'bar')    def create_task(self):        tid = uuid()        task = Mock(name='task-{0}'.format(tid))        task.name = 'foobarbaz'        self.app.tasks['foobarbaz'] = task        task.request.chord = signature(task)        task.request.id = tid        task.request.chord['chord_size'] = 10        task.request.group = 'group_id'        return task    @patch('celery.result.GroupResult.restore')    def test_on_chord_part_return(self, restore):        tasks = [self.create_task() for i in range(10)]        for i in range(10):            self.b.on_chord_part_return(tasks[i].request, states.SUCCESS, i)            assert self.b.client.rpush.call_count            self.b.client.rpush.reset_mock()        assert self.b.client.lrange.call_count        jkey = self.b.get_key_for_group('group_id', '.j')        tkey = self.b.get_key_for_group('group_id', '.t')        self.b.client.delete.assert_has_calls([call(jkey), call(tkey)])        self.b.client.expire.assert_has_calls([            call(jkey, 86400), call(tkey, 86400),        ])    def test_on_chord_part_return__success(self):        with self.chord_context(2) as (_, request, callback):            self.b.on_chord_part_return(request, states.SUCCESS, 10)            callback.delay.assert_not_called()            self.b.on_chord_part_return(request, states.SUCCESS, 20)            callback.delay.assert_called_with([10, 20])    def test_on_chord_part_return__callback_raises(self):        with self.chord_context(1) as (_, request, callback):            callback.delay.side_effect = KeyError(10)            task = self.app._tasks['add'] = Mock(name='add_task')            self.b.on_chord_part_return(request, states.SUCCESS, 10)            task.backend.fail_from_current_stack.assert_called_with(                callback.id, exc=ANY,            )    def test_on_chord_part_return__ChordError(self):        with self.chord_context(1) as (_, request, callback):            self.b.client.pipeline = ContextMock()            raise_on_second_call(self.b.client.pipeline, ChordError())            self.b.client.pipeline.return_value.rpush().llen().get().expire(            ).expire().execute.return_value = (1, 1, 0, 4, 5)            task = self.app._tasks['add'] = Mock(name='add_task')            self.b.on_chord_part_return(request, states.SUCCESS, 10)            task.backend.fail_from_current_stack.assert_called_with(                callback.id, exc=ANY,            )    def test_on_chord_part_return__other_error(self):        with self.chord_context(1) as (_, request, callback):            self.b.client.pipeline = ContextMock()            raise_on_second_call(self.b.client.pipeline, RuntimeError())            self.b.client.pipeline.return_value.rpush().llen().get().expire(            ).expire().execute.return_value = (1, 1, 0, 4, 5)            task = self.app._tasks['add'] = Mock(name='add_task')            self.b.on_chord_part_return(request, states.SUCCESS, 10)            task.backend.fail_from_current_stack.assert_called_with(                callback.id, exc=ANY,            )    @contextmanager    def chord_context(self, size=1):        with patch('celery.backends.redis.maybe_signature') as ms:            tasks = [self.create_task() for i in range(size)]            request = Mock(name='request')            request.id = 'id1'            request.group = 'gid1'            callback = ms.return_value = Signature('add')            callback.id = 'id1'            callback['chord_size'] = size            callback.delay = Mock(name='callback.delay')            yield tasks, request, callback    def test_process_cleanup(self):        self.b.process_cleanup()    def test_get_set_forget(self):        tid = uuid()        self.b.store_result(tid, 42, states.SUCCESS)        assert self.b.get_state(tid) == states.SUCCESS        assert self.b.get_result(tid) == 42        self.b.forget(tid)        assert self.b.get_state(tid) == states.PENDING    def test_set_expires(self):        self.b = self.Backend(expires=512, app=self.app)        tid = uuid()        key = self.b.get_key_for_task(tid)        self.b.store_result(tid, 42, states.SUCCESS)        self.b.client.expire.assert_called_with(            key, 512,        )class test_SentinelBackend:    def get_backend(self):        from celery.backends.redis import SentinelBackend        class _SentinelBackend(SentinelBackend):            redis = redis            sentinel = sentinel        return _SentinelBackend    def get_E_LOST(self):        from celery.backends.redis import E_LOST        return E_LOST    def setup(self):        self.Backend = self.get_backend()        self.E_LOST = self.get_E_LOST()        self.b = self.Backend(app=self.app)    @pytest.mark.usefixtures('depends_on_current_app')    @skip.unless_module('redis')    def test_reduce(self):        from celery.backends.redis import SentinelBackend        x = SentinelBackend(app=self.app)        assert loads(dumps(x))    def test_no_redis(self):        self.Backend.redis = None        with pytest.raises(ImproperlyConfigured):            self.Backend(app=self.app)    def test_url(self):        self.app.conf.redis_socket_timeout = 30.0        self.app.conf.redis_socket_connect_timeout = 100.0        x = self.Backend(            'sentinel://:test@github.com:123/1;'            'sentinel://:test@github.com:124/1',            app=self.app,        )        assert x.connparams        assert "host" not in x.connparams        assert x.connparams['db'] == 1        assert "port" not in x.connparams        assert x.connparams['password'] == "test"        assert len(x.connparams['hosts']) == 2        expected_hosts = ["github.com", "github.com"]        found_hosts = [cp['host'] for cp in x.connparams['hosts']]        assert found_hosts == expected_hosts        expected_ports = [123, 124]        found_ports = [cp['port'] for cp in x.connparams['hosts']]        assert found_ports == expected_ports        expected_passwords = ["test", "test"]        found_passwords = [cp['password'] for cp in x.connparams['hosts']]        assert found_passwords == expected_passwords        expected_dbs = [1, 1]        found_dbs = [cp['db'] for cp in x.connparams['hosts']]        assert found_dbs == expected_dbs    def test_get_sentinel_instance(self):        x = self.Backend(            'sentinel://:test@github.com:123/1;'            'sentinel://:test@github.com:124/1',            app=self.app,        )        sentinel_instance = x._get_sentinel_instance(**x.connparams)        assert sentinel_instance.sentinel_kwargs == {}        assert sentinel_instance.connection_kwargs['db'] == 1        assert sentinel_instance.connection_kwargs['password'] == "test"        assert len(sentinel_instance.sentinels) == 2    def test_get_pool(self):        x = self.Backend(            'sentinel://:test@github.com:123/1;'            'sentinel://:test@github.com:124/1',            app=self.app,        )        pool = x._get_pool(**x.connparams)        assert pool
 |