Browse Source

Tests passing

Ask Solem 10 years ago
parent
commit
b0cfa0d818

+ 2 - 2
celery/app/amqp.py

@@ -275,9 +275,9 @@ class AMQP(object):
         kwargs = kwargs or {}
         utc = self.utc
         if not isinstance(args, (list, tuple)):
-            raise ValueError('task args must be a list or tuple')
+            raise TypeError('task args must be a list or tuple')
         if not isinstance(kwargs, Mapping):
-            raise ValueError('task keyword arguments must be a mapping')
+            raise TypeError('task keyword arguments must be a mapping')
         if countdown:  # convert countdown to ETA
             now = now or self.app.now()
             timezone = timezone or self.app.timezone

+ 1 - 1
celery/app/base.py

@@ -286,7 +286,7 @@ class Celery(object):
                 '_decorated': True,
                 '__doc__': fun.__doc__,
                 '__module__': fun.__module__,
-                '__header__': staticmethod(head_from_fun(fun)),
+                '__header__': staticmethod(head_from_fun(fun, bound=bind)),
                 '__wrapped__': fun}, **options))()
             self._tasks[task.name] = task
             task.bind(self)  # connects task to this app

+ 3 - 3
celery/app/task.py

@@ -513,9 +513,9 @@ class Task(object):
         :keyword eta: Explicit time and date to run the retry at
                       (must be a :class:`~datetime.datetime` instance).
         :keyword max_retries: If set, overrides the default retry limit.
-            A value of :const:`None`, means "use the default", so if you want infinite
-            retries you would have to set the :attr:`max_retries` attribute of the
-            task to :const:`None` first.
+            A value of :const:`None`, means "use the default", so if you want
+            infinite retries you would have to set the :attr:`max_retries`
+            attribute of the task to :const:`None` first.
         :keyword time_limit: If set, overrides the default time limit.
         :keyword soft_time_limit: If set, overrides the default soft
                                   time limit.

+ 0 - 1
celery/canvas.py

@@ -451,7 +451,6 @@ class chain(Signature):
             root_id = res.id if root_id is None else root_id
             i += 1
 
-
             if prev_task:
                 # link previous task to this task.
                 prev_task.link(task)

+ 8 - 3
celery/tests/app/test_app.py

@@ -332,9 +332,14 @@ class test_App(AppCase):
 
     def test_apply_async_has__self__(self):
         @self.app.task(__self__='hello', shared=False)
-        def aawsX():
+        def aawsX(x, y):
             pass
 
+        with self.assertRaises(TypeError):
+            aawsX.apply_async(())
+        with self.assertRaises(TypeError):
+            aawsX.apply_async((2, ))
+
         with patch('celery.app.amqp.AMQP.create_task_message') as create:
             with patch('celery.app.amqp.AMQP.send_task_message') as send:
                 create.return_value = Mock(), Mock(), Mock(), Mock()
@@ -346,11 +351,11 @@ class test_App(AppCase):
     def test_apply_async_adds_children(self):
         from celery._state import _task_stack
 
-        @self.app.task(shared=False)
+        @self.app.task(bind=True, shared=False)
         def a3cX1(self):
             pass
 
-        @self.app.task(shared=False)
+        @self.app.task(bind=True, shared=False)
         def a3cX2(self):
             pass
 

+ 1 - 1
celery/tests/app/test_builtins.py

@@ -84,7 +84,7 @@ class test_group(BuiltinsCase):
     def test_apply_async_eager(self):
         self.task.apply = Mock()
         self.app.conf.CELERY_ALWAYS_EAGER = True
-        self.task.apply_async()
+        self.task.apply_async((1, 2, 3, 4, 5))
         self.assertTrue(self.task.apply.called)
 
     def test_apply(self):

+ 0 - 1
celery/tests/app/test_log.py

@@ -207,7 +207,6 @@ class test_default_logger(AppCase):
             self.app.log.setup_logging_subsystem(colorize=True)
 
     def test_setup_logging_subsystem_no_mputil(self):
-        from celery.utils import log as logtools
         with restore_logging():
             with mask_modules('billiard.util'):
                 self.app.log.setup_logging_subsystem()

+ 2 - 2
celery/tests/tasks/test_tasks.py

@@ -261,12 +261,12 @@ class test_tasks(TasksCase):
             IncompleteTask().run()
 
     def test_task_kwargs_must_be_dictionary(self):
-        with self.assertRaises(ValueError):
+        with self.assertRaises(TypeError):
             self.increment_counter.apply_async([], 'str')
 
     def test_task_args_must_be_list(self):
         with self.assertRaises(ValueError):
-            self.increment_counter.apply_async('str', {})
+            self.increment_counter.apply_async('s', {})
 
     def test_regular_task(self):
         self.assertIsInstance(self.mytask, Task)

+ 4 - 3
celery/utils/functional.py

@@ -13,7 +13,7 @@ import threading
 
 from collections import OrderedDict
 from functools import partial, wraps
-from inspect import getargspec, isfunction, ismethod
+from inspect import getargspec, isfunction
 from itertools import islice
 
 from kombu.utils import cached_property
@@ -327,12 +327,11 @@ def _argsfromspec(spec, replace_defaults=True):
     ]))
 
 
-def head_from_fun(fun, debug=True):
+def head_from_fun(fun, bound=False, debug=False):
     if not isfunction(fun) and hasattr(fun, '__call__'):
         name, fun = fun.__class__.__name__, fun.__call__
     else:
         name = fun.__name__
-    spec = getargspec(fun)
     definition = FUNHEAD_TEMPLATE.format(
         fun_name=name,
         fun_args=_argsfromspec(getargspec(fun)),
@@ -344,4 +343,6 @@ def head_from_fun(fun, debug=True):
     exec(definition, namespace)
     result = namespace[name]
     result._source = definition
+    if bound:
+        return partial(result, object())
     return result