|
@@ -9,6 +9,7 @@ from pickle import loads, dumps
|
|
|
|
|
|
from amqp import promise
|
|
|
|
|
|
+from celery import Celery
|
|
|
from celery import shared_task, current_app
|
|
|
from celery import app as _app
|
|
|
from celery import _state
|
|
@@ -19,12 +20,14 @@ from celery.five import items, keys
|
|
|
from celery.loaders.base import BaseLoader, unconfigured
|
|
|
from celery.platforms import pyimplementation
|
|
|
from celery.utils.serialization import pickle
|
|
|
+from celery.utils.timeutils import timezone
|
|
|
|
|
|
from celery.tests.case import (
|
|
|
CELERY_TEST_CONFIG,
|
|
|
AppCase,
|
|
|
Mock,
|
|
|
Case,
|
|
|
+ ContextMock,
|
|
|
depends_on_current_app,
|
|
|
mask_modules,
|
|
|
patch,
|
|
@@ -128,6 +131,12 @@ class test_App(AppCase):
|
|
|
task = app.task(fun)
|
|
|
self.assertEqual(task.name, app.main + '.fun')
|
|
|
|
|
|
+ def test_task_too_many_args(self):
|
|
|
+ with self.assertRaises(TypeError):
|
|
|
+ self.app.task(Mock(name='fun'), True)
|
|
|
+ with self.assertRaises(TypeError):
|
|
|
+ self.app.task(Mock(name='fun'), True, 1, 2)
|
|
|
+
|
|
|
def test_with_config_source(self):
|
|
|
with self.Celery(config_source=ObjectConfig) as app:
|
|
|
self.assertEqual(app.conf.FOO, 1)
|
|
@@ -235,6 +244,18 @@ class test_App(AppCase):
|
|
|
self.assertEqual(prom.fun, self.app._autodiscover_tasks)
|
|
|
self.assertEqual(prom.args[0](), [1, 2, 3])
|
|
|
|
|
|
+ def test_autodiscover_tasks__no_packages(self):
|
|
|
+ fixup1 = Mock(name='fixup')
|
|
|
+ fixup2 = Mock(name='fixup')
|
|
|
+ self.app._autodiscover_tasks_from_names = Mock(name='auto')
|
|
|
+ self.app._fixups = [fixup1, fixup2]
|
|
|
+ fixup1.autodiscover_tasks.return_value = ['A', 'B', 'C']
|
|
|
+ fixup2.autodiscover_tasks.return_value = ['D', 'E', 'F']
|
|
|
+ self.app.autodiscover_tasks(force=True)
|
|
|
+ self.app._autodiscover_tasks_from_names.assert_called_with(
|
|
|
+ ['A', 'B', 'C', 'D', 'E', 'F'], related_name='tasks',
|
|
|
+ )
|
|
|
+
|
|
|
@with_environ('CELERY_BROKER_URL', '')
|
|
|
def test_with_broker(self):
|
|
|
with self.Celery(broker='foo://baribaz') as app:
|
|
@@ -739,6 +760,86 @@ class test_App(AppCase):
|
|
|
self.assertIsNone(self.app._pool)
|
|
|
self.app._after_fork(self.app)
|
|
|
|
|
|
+ def test_global_after_fork(self):
|
|
|
+ app = Mock(name='app')
|
|
|
+ prev, _state._apps = _state._apps, [app]
|
|
|
+ try:
|
|
|
+ obj = Mock(name='obj')
|
|
|
+ _appbase._global_after_fork(obj)
|
|
|
+ app._after_fork.assert_called_with(obj)
|
|
|
+ finally:
|
|
|
+ _state._apps = prev
|
|
|
+
|
|
|
+ @patch('multiprocessing.util', create=True)
|
|
|
+ def test_global_after_fork__raises(self, util):
|
|
|
+ app = Mock(name='app')
|
|
|
+ prev, _state._apps = _state._apps, [app]
|
|
|
+ try:
|
|
|
+ obj = Mock(name='obj')
|
|
|
+ exc = app._after_fork.side_effect = KeyError()
|
|
|
+ _appbase._global_after_fork(obj)
|
|
|
+ util._logger.info.assert_called_with(
|
|
|
+ 'after forker raised exception: %r', exc, exc_info=1)
|
|
|
+ util._logger = None
|
|
|
+ _appbase._global_after_fork(obj)
|
|
|
+ finally:
|
|
|
+ _state._apps = prev
|
|
|
+
|
|
|
+ def test_ensure_after_fork__no_multiprocessing(self):
|
|
|
+ prev, _appbase.register_after_fork = (
|
|
|
+ _appbase.register_after_fork, None)
|
|
|
+ try:
|
|
|
+ _appbase._after_fork_registered = False
|
|
|
+ _appbase._ensure_after_fork()
|
|
|
+ self.assertTrue(_appbase._after_fork_registered)
|
|
|
+ finally:
|
|
|
+ _appbase.register_after_fork = prev
|
|
|
+
|
|
|
+ def test_canvas(self):
|
|
|
+ self.assertTrue(self.app.canvas.Signature)
|
|
|
+
|
|
|
+ def test_signature(self):
|
|
|
+ sig = self.app.signature('foo', (1, 2))
|
|
|
+ self.assertIs(sig.app, self.app)
|
|
|
+
|
|
|
+ def test_timezone__none_set(self):
|
|
|
+ self.app.conf.timezone = None
|
|
|
+ tz = self.app.timezone
|
|
|
+ self.assertEqual(tz, timezone.get_timezone('UTC'))
|
|
|
+
|
|
|
+ def test_compat_on_configure(self):
|
|
|
+ on_configure = Mock(name='on_configure')
|
|
|
+
|
|
|
+ class CompatApp(Celery):
|
|
|
+
|
|
|
+ def on_configure(self, *args, **kwargs):
|
|
|
+ on_configure(*args, **kwargs)
|
|
|
+
|
|
|
+ with CompatApp(set_as_current=False) as app:
|
|
|
+ app.loader = Mock()
|
|
|
+ app.loader.conf = {}
|
|
|
+ app._load_config()
|
|
|
+ on_configure.assert_called_with()
|
|
|
+
|
|
|
+ def test_add_periodic_task(self):
|
|
|
+
|
|
|
+ @self.app.task
|
|
|
+ def add(x, y):
|
|
|
+ pass
|
|
|
+ assert not self.app.configured
|
|
|
+ self.app.add_periodic_task(
|
|
|
+ 10, self.app.signature('add', (2, 2)),
|
|
|
+ name='add1', expires=3,
|
|
|
+ )
|
|
|
+ self.assertTrue(self.app._pending_periodic_tasks)
|
|
|
+ assert not self.app.configured
|
|
|
+
|
|
|
+ sig2 = add.s(4, 4)
|
|
|
+ self.assertTrue(self.app.configured)
|
|
|
+ self.app.add_periodic_task(20, sig2, name='add2', expires=4)
|
|
|
+ self.assertIn('add1', self.app.conf.beat_schedule)
|
|
|
+ self.assertIn('add2', self.app.conf.beat_schedule)
|
|
|
+
|
|
|
def test_pool_no_multiprocessing(self):
|
|
|
with mask_modules('multiprocessing.util'):
|
|
|
pool = self.app.pool
|
|
@@ -747,6 +848,18 @@ class test_App(AppCase):
|
|
|
def test_bugreport(self):
|
|
|
self.assertTrue(self.app.bugreport())
|
|
|
|
|
|
+ def test_send_task__connection_provided(self):
|
|
|
+ connection = Mock(name='connection')
|
|
|
+ router = Mock(name='router')
|
|
|
+ router.route.return_value = {}
|
|
|
+ self.app.amqp = Mock(name='amqp')
|
|
|
+ self.app.amqp.Producer.attach_mock(ContextMock(), 'return_value')
|
|
|
+ self.app.send_task('foo', (1, 2), connection=connection, router=router)
|
|
|
+ self.app.amqp.Producer.assert_called_with(connection)
|
|
|
+ self.app.amqp.send_task_message.assert_called_with(
|
|
|
+ self.app.amqp.Producer(), 'foo',
|
|
|
+ self.app.amqp.create_task_message())
|
|
|
+
|
|
|
def test_send_task_sent_event(self):
|
|
|
|
|
|
class Dispatcher(object):
|
|
@@ -799,6 +912,11 @@ class test_App(AppCase):
|
|
|
x.send(Mock(), Mock())
|
|
|
self.assertFalse(task.app.mail_admins.called)
|
|
|
|
|
|
+ def test_select_queues(self):
|
|
|
+ self.app.amqp = Mock(name='amqp')
|
|
|
+ self.app.select_queues({'foo', 'bar'})
|
|
|
+ self.app.amqp.queues.select.assert_called_with({'foo', 'bar'})
|
|
|
+
|
|
|
|
|
|
class test_defaults(AppCase):
|
|
|
|