Browse Source

Tests passing

Ask Solem 12 years ago
parent
commit
d8af0040dc
3 changed files with 55 additions and 39 deletions
  1. 10 2
      celery/tests/app/test_app.py
  2. 44 36
      celery/tests/tasks/test_tasks.py
  3. 1 1
      celery/utils/timeutils.py

+ 10 - 2
celery/tests/app/test_app.py

@@ -157,15 +157,23 @@ class test_App(Case):
             _utils.MP_MAIN_FILE = None
 
     def test_base_task_inherits_magic_kwargs_from_app(self):
-        from celery.app.task import Task
+        from celery.task import Task as OldTask
 
-        class timkX(Task):
+        class timkX(OldTask):
             abstract = True
 
         app = Celery(set_as_current=False, accept_magic_kwargs=True)
         timkX.bind(app)
         self.assertTrue(timkX.accept_magic_kwargs)
 
+        from celery import Task as NewTask
+
+        class timkY(NewTask):
+            abstract = True
+
+        timkY.bind(app)
+        self.assertFalse(timkY.accept_magic_kwargs)
+
     def test_annotate_decorator(self):
         from celery.app.task import Task
 

+ 44 - 36
celery/tests/tasks/test_tasks.py

@@ -6,10 +6,16 @@ from functools import wraps
 from mock import patch
 from pickle import loads, dumps
 
-from celery import task
-from celery.task import current, Task
+from celery.task import (
+    current,
+    task,
+    Task,
+    BaseTask,
+    TaskSet,
+    periodic_task,
+    PeriodicTask
+)
 from celery.app import app_or_default
-from celery.task import task as task_dec
 from celery.exceptions import RetryTaskError
 from celery.execute import send_task
 from celery.result import EagerResult
@@ -25,14 +31,14 @@ def return_True(*args, **kwargs):
     return True
 
 
-return_True_task = task_dec()(return_True)
+return_True_task = task()(return_True)
 
 
 def raise_exception(self, **kwargs):
     raise Exception('%s error' % self.__class__)
 
 
-class MockApplyTask(task.Task):
+class MockApplyTask(Task):
     applied = 0
 
     def run(self, x, y):
@@ -43,18 +49,18 @@ class MockApplyTask(task.Task):
         self.applied += 1
 
 
-@task.task(name='c.unittest.increment_counter_task', count=0)
+@task(name='c.unittest.increment_counter_task', count=0)
 def increment_counter(increment_by=1):
     increment_counter.count += increment_by or 1
     return increment_counter.count
 
 
-@task.task(name='c.unittest.raising_task')
+@task(name='c.unittest.raising_task')
 def raising():
     raise KeyError('foo')
 
 
-@task.task(max_retries=3, iterations=0)
+@task(max_retries=3, iterations=0)
 def retry_task(arg1, arg2, kwarg=1, max_retries=None, care=True):
     current.iterations += 1
     rmax = current.max_retries if max_retries is None else max_retries
@@ -67,7 +73,7 @@ def retry_task(arg1, arg2, kwarg=1, max_retries=None, care=True):
         raise current.retry(countdown=0, max_retries=rmax)
 
 
-@task.task(max_retries=3, iterations=0)
+@task(max_retries=3, iterations=0, accept_magic_kwargs=True)
 def retry_task_noargs(**kwargs):
     current.iterations += 1
 
@@ -78,7 +84,8 @@ def retry_task_noargs(**kwargs):
         raise current.retry(countdown=0)
 
 
-@task.task(max_retries=3, iterations=0, base=MockApplyTask)
+@task(max_retries=3, iterations=0, base=MockApplyTask,
+        accept_magic_kwargs=True)
 def retry_task_mockapply(arg1, arg2, kwarg=1, **kwargs):
     current.iterations += 1
 
@@ -94,7 +101,7 @@ class MyCustomException(Exception):
     """Random custom exception."""
 
 
-@task.task(max_retries=3, iterations=0, accept_magic_kwargs=True)
+@task(max_retries=3, iterations=0, accept_magic_kwargs=True)
 def retry_task_customexc(arg1, arg2, kwarg=1, **kwargs):
     current.iterations += 1
 
@@ -123,6 +130,7 @@ class test_task_retries(Case):
         self.assertEqual(retry_task.iterations, 11)
 
     def test_retry_no_args(self):
+        assert retry_task_noargs.accept_magic_kwargs
         retry_task_noargs.__class__.max_retries = 3
         retry_task_noargs.iterations = 0
         retry_task_noargs.apply()
@@ -205,14 +213,14 @@ class test_tasks(Case):
     def test_unpickle_task(self):
         import pickle
 
-        @task_dec
+        @task
         def xxx():
             pass
 
         self.assertIs(pickle.loads(pickle.dumps(xxx)), xxx.app.tasks[xxx.name])
 
     def createTask(self, name):
-        return task.task(__module__=self.__module__, name=name)(return_True)
+        return task(__module__=self.__module__, name=name)(return_True)
 
     def test_AsyncResult(self):
         task_id = uuid()
@@ -240,7 +248,7 @@ class test_tasks(Case):
 
     def test_incomplete_task_cls(self):
 
-        class IncompleteTask(task.Task):
+        class IncompleteTask(Task):
             name = 'c.unittest.t.itask'
 
         with self.assertRaises(NotImplementedError):
@@ -256,7 +264,7 @@ class test_tasks(Case):
 
     def test_regular_task(self):
         T1 = self.createTask('c.unittest.t.t1')
-        self.assertIsInstance(T1, task.BaseTask)
+        self.assertIsInstance(T1, BaseTask)
         self.assertTrue(T1.run())
         self.assertTrue(callable(T1),
                 'Task class is callable()')
@@ -377,7 +385,7 @@ class test_tasks(Case):
 
     def test_update_state(self):
 
-        @task_dec
+        @task
         def yyy():
             pass
 
@@ -393,7 +401,7 @@ class test_tasks(Case):
 
     def test_repr(self):
 
-        @task_dec
+        @task
         def task_test_repr():
             pass
 
@@ -401,7 +409,7 @@ class test_tasks(Case):
 
     def test_has___name__(self):
 
-        @task_dec
+        @task
         def yyy2():
             pass
 
@@ -423,13 +431,13 @@ class test_TaskSet(Case):
     @with_eager_tasks
     def test_function_taskset(self):
         subtasks = [return_True_task.s(i) for i in range(1, 6)]
-        ts = task.TaskSet(subtasks)
+        ts = TaskSet(subtasks)
         res = ts.apply_async()
         self.assertListEqual(res.join(), [True, True, True, True, True])
 
     def test_counter_taskset(self):
         increment_counter.count = 0
-        ts = task.TaskSet(tasks=[
+        ts = TaskSet(tasks=[
             increment_counter.s(),
             increment_counter.s(increment_by=2),
             increment_counter.s(increment_by=3),
@@ -460,7 +468,7 @@ class test_TaskSet(Case):
 
     def test_named_taskset(self):
         prefix = 'test_named_taskset-'
-        ts = task.TaskSet([return_True_task.subtask([1])])
+        ts = TaskSet([return_True_task.subtask([1])])
         res = ts.apply(taskset_id=prefix + uuid())
         self.assertTrue(res.taskset_id.startswith(prefix))
 
@@ -511,7 +519,7 @@ class test_apply_task(Case):
             f.get()
 
 
-@task.periodic_task(run_every=timedelta(hours=1))
+@periodic_task(run_every=timedelta(hours=1))
 def my_periodic():
     pass
 
@@ -520,7 +528,7 @@ class test_periodic_tasks(Case):
 
     def test_must_have_run_every(self):
         with self.assertRaises(NotImplementedError):
-            type('Foo', (task.PeriodicTask, ), {'__module__': __name__})
+            type('Foo', (PeriodicTask, ), {'__module__': __name__})
 
     def test_remaining_estimate(self):
         self.assertIsInstance(
@@ -547,43 +555,43 @@ class test_periodic_tasks(Case):
         self.assertTrue(repr(p.run_every))
 
 
-@task.periodic_task(run_every=crontab())
+@periodic_task(run_every=crontab())
 def every_minute():
     pass
 
 
-@task.periodic_task(run_every=crontab(minute='*/15'))
+@periodic_task(run_every=crontab(minute='*/15'))
 def quarterly():
     pass
 
 
-@task.periodic_task(run_every=crontab(minute=30))
+@periodic_task(run_every=crontab(minute=30))
 def hourly():
     pass
 
 
-@task.periodic_task(run_every=crontab(hour=7, minute=30))
+@periodic_task(run_every=crontab(hour=7, minute=30))
 def daily():
     pass
 
 
-@task.periodic_task(run_every=crontab(hour=7, minute=30,
-                                      day_of_week='thursday'))
+@periodic_task(run_every=crontab(hour=7, minute=30,
+                                 day_of_week='thursday'))
 def weekly():
     pass
 
 
-@task.periodic_task(run_every=crontab(hour=7, minute=30,
-                                      day_of_week='thursday',
-                                      day_of_month='8-14'))
+@periodic_task(run_every=crontab(hour=7, minute=30,
+                                 day_of_week='thursday',
+                                 day_of_month='8-14'))
 def monthly():
     pass
 
 
-@task.periodic_task(run_every=crontab(hour=7, minute=30,
-                                      day_of_week='thursday',
-                                      day_of_month='8-14',
-                                      month_of_year=3))
+@periodic_task(run_every=crontab(hour=7, minute=30,
+                                 day_of_week='thursday',
+                                 day_of_month='8-14',
+                                 month_of_year=3))
 def yearly():
     pass
 

+ 1 - 1
celery/utils/timeutils.py

@@ -48,7 +48,7 @@ class _Zone(object):
         return self.get_timezone(tzinfo)
 
     def to_local(self, dt, local=None, orig=None):
-        return to_tz(dt, orig or self.utc).astimezone(
+        return set_tz(dt, orig or self.utc).astimezone(
                     self.tz_or_local(local))
 
     def get_timezone(self, zone):