from __future__ import absolute_import, unicode_literals

from celery import group, chord
from celery.app import builtins
from celery.five import range
from celery.utils.functional import pass1

from celery.tests.case import AppCase, ContextMock, Mock, patch


class BuiltinsCase(AppCase):
    """Base case that registers two small helper tasks on ``self.app``.

    ``self.xsum`` sums an iterable; ``self.add`` adds two values. Both are
    registered with ``shared=False`` so they stay private to this app.
    """

    def setup(self):
        @self.app.task(shared=False)
        def xsum(x):
            return sum(x)
        self.xsum = xsum

        @self.app.task(shared=False)
        def add(x, y):
            return x + y
        self.add = add


class test_backend_cleanup(BuiltinsCase):

    def test_run(self):
        """The backend-cleanup task invokes ``app.backend.cleanup``."""
        self.app.backend.cleanup = Mock()
        # A plain Mock has no ``__name__``; set one explicitly.
        # NOTE(review): presumably needed by add_backend_cleanup_task — confirm.
        self.app.backend.cleanup.__name__ = 'cleanup'
        cleanup_task = builtins.add_backend_cleanup_task(self.app)
        cleanup_task()
        self.app.backend.cleanup.assert_called()


class test_accumulate(BuiltinsCase):

    def setup(self):
        self.accumulate = self.app.tasks['celery.accumulate']

    def test_with_index(self):
        """With ``index`` given, only the argument at that index is returned."""
        self.assertEqual(self.accumulate(1, 2, 3, 4, index=0), 1)

    def test_no_index(self):
        """Without ``index``, all positional arguments come back as a tuple."""
        self.assertEqual(self.accumulate(1, 2, 3, 4), (1, 2, 3, 4))


class test_map(BuiltinsCase):

    def test_run(self):
        """``celery.map`` calls the task once per item, passing the item whole."""
        @self.app.task(shared=False)
        def map_mul(x):
            return x[0] * x[1]

        res = self.app.tasks['celery.map'](
            map_mul, [(2, 2), (4, 4), (8, 8)],
        )
        self.assertEqual(res, [4, 16, 64])


class test_starmap(BuiltinsCase):

    def test_run(self):
        """``celery.starmap`` unpacks each item as positional arguments."""
        @self.app.task(shared=False)
        def smap_mul(x, y):
            return x * y

        res = self.app.tasks['celery.starmap'](
            smap_mul, [(2, 2), (4, 4), (8, 8)],
        )
        self.assertEqual(res, [4, 16, 64])


class test_chunks(BuiltinsCase):

    @patch('celery.canvas.chunks.apply_chunks')
    def test_run(self, apply_chunks):
        """``celery.chunks`` delegates to ``chunks.apply_chunks`` (patched here)."""
        @self.app.task(shared=False)
        def chunks_mul(chunk):
            return chunk

        self.app.tasks['celery.chunks'](
            chunks_mul, [(2, 2), (4, 4), (8, 8)], 1,
        )
        apply_chunks.assert_called()


class test_group(BuiltinsCase):

    def setup(self):
        # Make maybe_signature a pass-through so group members are used as-is.
        # NOTE(review): relies on pass1 returning its first argument — confirm.
        self.maybe_signature = self.patch('celery.canvas.maybe_signature')
        self.maybe_signature.side_effect = pass1
        # Replace the producer factory so no real broker connection is made;
        # its return value is a ContextMock so it can be used in a ``with``.
        self.app.producer_or_acquire = Mock()
        self.app.producer_or_acquire.attach_mock(ContextMock(), 'return_value')
        self.app.conf.task_always_eager = True
        self.task = builtins.add_group_task(self.app)
        super(test_group, self).setup()

    def test_apply_async_eager(self):
        """With ``task_always_eager`` set, ``apply_async`` goes through ``apply``."""
        self.task.apply = Mock(name='apply')
        self.task.apply_async((1, 2, 3, 4, 5))
        self.task.apply.assert_called()

    def mock_group(self, *tasks):
        """Build and freeze a group whose members' ``clone`` is mocked.

        Returns ``(group, frozen_result)``. Each member's ``clone()`` returns
        a Mock with an ``apply_async`` attribute the tests can inspect.
        """
        g = group(*tasks, app=self.app)
        result = g.freeze()
        for task in g.tasks:
            task.clone = Mock(name='clone')
            task.clone.attach_mock(Mock(), 'apply_async')
        return g, result

    @patch('celery.app.base.Celery.current_worker_task')
    def test_task(self, current_worker_task):
        """Members are sent with the group id, and the result joins the trail."""
        g, result = self.mock_group(self.add.s(2), self.add.s(4))
        self.task(g.tasks, result, result.id, (2,)).results
        g.tasks[0].clone().apply_async.assert_called_with(
            group_id=result.id, producer=self.app.producer_or_acquire(),
            add_to_parent=False,
        )
        current_worker_task.add_trail.assert_called_with(result)

    @patch('celery.app.base.Celery.current_worker_task')
    def test_task__disable_add_to_parent(self, current_worker_task):
        """With ``add_to_parent=False`` the worker trail is left untouched."""
        g, result = self.mock_group(self.add.s(2, 2), self.add.s(4, 4))
        self.task(g.tasks, result, result.id, None, add_to_parent=False)
        current_worker_task.add_trail.assert_not_called()


class test_chain(BuiltinsCase):

    def setup(self):
        # Cooperative super() call, consistent with the sibling test cases.
        super(test_chain, self).setup()
        self.task = builtins.add_chain_task(self.app)

    def test_not_implemented(self):
        """Calling the builtin chain task directly raises NotImplementedError."""
        with self.assertRaises(NotImplementedError):
            self.task()


class test_chord(BuiltinsCase):

    def setup(self):
        self.task = builtins.add_chord_task(self.app)
        super(test_chord, self).setup()

    def test_apply_async(self):
        """``apply_async`` returns a truthy result that has a parent (header)."""
        x = chord([self.add.s(i, i) for i in range(10)], body=self.xsum.s())
        r = x.apply_async()
        self.assertTrue(r)
        self.assertTrue(r.parent)

    def test_run_header_not_group(self):
        """Smoke test: the chord task accepts a plain list as its header."""
        self.task([self.add.s(i, i) for i in range(10)], self.xsum.s())

    def test_forward_options(self):
        """``group_id`` / ``chord`` options are forwarded to the body passed to run."""
        body = self.xsum.s()
        x = chord([self.add.s(i, i) for i in range(10)], body=body)
        x.run = Mock(name='chord.run(x)')
        x.apply_async(group_id='some_group_id')
        x.run.assert_called()
        # Second positional argument to run() is the body signature.
        resbody = x.run.call_args[0][1]
        self.assertEqual(resbody.options['group_id'], 'some_group_id')
        x2 = chord([self.add.s(i, i) for i in range(10)], body=body)
        x2.run = Mock(name='chord.run(x2)')
        x2.apply_async(chord='some_chord_id')
        x2.run.assert_called()
        resbody = x2.run.call_args[0][1]
        self.assertEqual(resbody.options['chord'], 'some_chord_id')

    def test_apply_eager(self):
        """Eager chord: sum of ``i + i`` for ``i`` in 0..9 is 90."""
        self.app.conf.task_always_eager = True
        x = chord([self.add.s(i, i) for i in range(10)], body=self.xsum.s())
        r = x.apply_async()
        self.assertEqual(r.get(), 90)

    def test_apply_eager_with_arguments(self):
        """Args given to apply_async complete the partial header signatures.

        Each ``add(i)`` receives the extra argument ``1``, so the header
        results are ``i + 1`` for ``i`` in 0..9, which sum to 55.
        """
        self.app.conf.task_always_eager = True
        x = chord([self.add.s(i) for i in range(10)], body=self.xsum.s())
        r = x.apply_async([1])
        self.assertEqual(r.get(), 55)