Browse Source

Refactored stuff out of celery.task.base.

TaskSet is now in celery.task.sets.TaskSet
ExecuteRemoteTask is now in celery.task.builtins

The following methods have been removed:

    TaskSet.remote_execute (use celery.task.execute_remote)
    TaskSet.map (use celery.task.dmap)
    TaskSet.map_async (use celery.task.dmap_async)
Ask Solem 15 years ago
parent
commit
ad955af185

+ 6 - 4
celery/task/__init__.py

@@ -7,8 +7,10 @@ Working with tasks and task sets.
 from celery.execute import apply_async
 from celery.execute import apply_async
 from celery.registry import tasks
 from celery.registry import tasks
 from celery.serialization import pickle
 from celery.serialization import pickle
-from celery.task.base import Task, TaskSet, PeriodicTask, ExecuteRemoteTask
-from celery.task.builtins import PingTask
+from celery.task.base import Task, PeriodicTask
+from celery.task.sets import TaskSet
+from celery.task.builtins import PingTask, ExecuteRemoteTask
+from celery.task.builtins import AsynchronousMapTask, _dmap
 from celery.task.control import discard_all
 from celery.task.control import discard_all
 from celery.task.http import HttpDispatchTask
 from celery.task.http import HttpDispatchTask
 
 
@@ -27,7 +29,7 @@ def dmap(fun, args, timeout=None):
         [4, 8, 16]
         [4, 8, 16]
 
 
     """
     """
-    return TaskSet.map(fun, args, timeout=timeout)
+    return _dmap(fun, args, timeout)
 
 
 
 
 def dmap_async(fun, args, timeout=None):
 def dmap_async(fun, args, timeout=None):
@@ -49,7 +51,7 @@ def dmap_async(fun, args, timeout=None):
         [4, 8, 16]
         [4, 8, 16]
 
 
     """
     """
-    return TaskSet.map_async(fun, args, timeout=timeout)
+    return AsynchronousMapTask.delay(pickle.dumps(fun), args, timeout=timeout)
 
 
 
 
 def execute_remote(fun, *args, **kwargs):
 def execute_remote(fun, *args, **kwargs):

+ 2 - 174
celery/task/base.py

@@ -3,18 +3,17 @@ from datetime import timedelta
 
 
 from celery import conf
 from celery import conf
 from celery.log import setup_task_logger
 from celery.log import setup_task_logger
-from celery.utils import gen_unique_id, padlist
 from celery.utils.timeutils import timedelta_seconds
 from celery.utils.timeutils import timedelta_seconds
-from celery.result import BaseAsyncResult, TaskSetResult, EagerResult
+from celery.result import BaseAsyncResult, EagerResult
 from celery.execute import apply_async, apply
 from celery.execute import apply_async, apply
 from celery.registry import tasks
 from celery.registry import tasks
 from celery.backends import default_backend
 from celery.backends import default_backend
 from celery.messaging import TaskPublisher, TaskConsumer
 from celery.messaging import TaskPublisher, TaskConsumer
 from celery.messaging import establish_connection as _establish_connection
 from celery.messaging import establish_connection as _establish_connection
 from celery.exceptions import MaxRetriesExceededError, RetryTaskError
 from celery.exceptions import MaxRetriesExceededError, RetryTaskError
-from celery.serialization import pickle
 
 
 from celery.task.schedules import schedule
 from celery.task.schedules import schedule
+from celery.task.sets import TaskSet
 
 
 
 
 class TaskType(type):
 class TaskType(type):
@@ -499,177 +498,6 @@ class Task(object):
         return "<%s: %s (%s)>" % (kind, self.name, self.type)
         return "<%s: %s (%s)>" % (kind, self.name, self.type)
 
 
 
 
-class ExecuteRemoteTask(Task):
-    """Execute an arbitrary function or object.
-
-    *Note* You probably want :func:`execute_remote` instead, which this
-    is an internal component of.
-
-    The object must be pickleable, so you can't use lambdas or functions
-    defined in the REPL (that is the python shell, or ``ipython``).
-
-    """
-    name = "celery.execute_remote"
-
-    def run(self, ser_callable, fargs, fkwargs, **kwargs):
-        """
-        :param ser_callable: A pickled function or callable object.
-        :param fargs: Positional arguments to apply to the function.
-        :param fkwargs: Keyword arguments to apply to the function.
-
-        """
-        return pickle.loads(ser_callable)(*fargs, **fkwargs)
-
-
-class AsynchronousMapTask(Task):
-    """Task used internally by :func:`dmap_async` and
-    :meth:`TaskSet.map_async`.  """
-    name = "celery.map_async"
-
-    def run(self, ser_callable, args, timeout=None, **kwargs):
-        """:see :meth:`TaskSet.dmap_async`."""
-        return TaskSet.map(pickle.loads(ser_callable), args, timeout=timeout)
-
-
-class TaskSet(object):
-    """A task containing several subtasks, making it possible
-    to track how many, or when all of the tasks has been completed.
-
-    :param task: The task class or name.
-        Can either be a fully qualified task name, or a task class.
-
-    :param args: A list of args, kwargs pairs.
-        e.g. ``[[args1, kwargs1], [args2, kwargs2], ..., [argsN, kwargsN]]``
-
-
-    .. attribute:: task_name
-
-        The name of the task.
-
-    .. attribute:: arguments
-
-        The arguments, as passed to the task set constructor.
-
-    .. attribute:: total
-
-        Total number of tasks in this task set.
-
-    Example
-
-        >>> from djangofeeds.tasks import RefreshFeedTask
-        >>> taskset = TaskSet(RefreshFeedTask, args=[
-        ...                 ([], {"feed_url": "http://cnn.com/rss"}),
-        ...                 ([], {"feed_url": "http://bbc.com/rss"}),
-        ...                 ([], {"feed_url": "http://xkcd.com/rss"})
-        ... ])
-
-        >>> taskset_result = taskset.apply_async()
-        >>> list_of_return_values = taskset_result.join()
-
-    """
-
-    def __init__(self, task, args):
-        try:
-            task_name = task.name
-            task_obj = task
-        except AttributeError:
-            task_name = task
-            task_obj = tasks[task_name]
-
-        # Get task instance
-        task_obj = tasks[task_obj.name]
-
-        self.task = task_obj
-        self.task_name = task_name
-        self.arguments = args
-        self.total = len(args)
-
-    def apply_async(self, connect_timeout=conf.BROKER_CONNECTION_TIMEOUT):
-        """Run all tasks in the taskset.
-
-        :returns: A :class:`celery.result.TaskSetResult` instance.
-
-        Example
-
-            >>> ts = TaskSet(RefreshFeedTask, args=[
-            ...         (["http://foo.com/rss"], {}),
-            ...         (["http://bar.com/rss"], {}),
-            ... ])
-            >>> result = ts.apply_async()
-            >>> result.taskset_id
-            "d2c9b261-8eff-4bfb-8459-1e1b72063514"
-            >>> result.subtask_ids
-            ["b4996460-d959-49c8-aeb9-39c530dcde25",
-            "598d2d18-ab86-45ca-8b4f-0779f5d6a3cb"]
-            >>> result.waiting()
-            True
-            >>> time.sleep(10)
-            >>> result.ready()
-            True
-            >>> result.successful()
-            True
-            >>> result.failed()
-            False
-            >>> result.join()
-            [True, True]
-
-        """
-        if conf.ALWAYS_EAGER:
-            return self.apply()
-
-        taskset_id = gen_unique_id()
-        conn = self.task.establish_connection(connect_timeout=connect_timeout)
-        publisher = self.task.get_publisher(connection=conn)
-        try:
-            subtasks = [self.apply_part(arglist, taskset_id, publisher)
-                            for arglist in self.arguments]
-        finally:
-            publisher.close()
-            conn.close()
-
-        return TaskSetResult(taskset_id, subtasks)
-
-    def apply_part(self, arglist, taskset_id, publisher):
-        """Apply a single part of the taskset."""
-        args, kwargs, opts = padlist(arglist, 3, default={})
-        return apply_async(self.task, args, kwargs,
-                           taskset_id=taskset_id, publisher=publisher, **opts)
-
-    def apply(self):
-        """Applies the taskset locally."""
-        taskset_id = gen_unique_id()
-        subtasks = [apply(self.task, args, kwargs)
-                        for args, kwargs in self.arguments]
-
-        # This will be filled with EagerResults.
-        return TaskSetResult(taskset_id, subtasks)
-
-    @classmethod
-    def remote_execute(cls, func, args):
-        """Apply ``args`` to function by distributing the args to the
-        celery server(s)."""
-        pickled = pickle.dumps(func)
-        arguments = [((pickled, arg, {}), {}) for arg in args]
-        return cls(ExecuteRemoteTask, arguments)
-
-    @classmethod
-    def map(cls, func, args, timeout=None):
-        """Distribute processing of the arguments and collect the results."""
-        remote_task = cls.remote_execute(func, args)
-        return remote_task.apply_async().join(timeout=timeout)
-
-    @classmethod
-    def map_async(cls, func, args, timeout=None):
-        """Distribute processing of the arguments and collect the results
-        asynchronously.
-
-        :returns: :class:`celery.result.AsyncResult` instance.
-
-        """
-        serfunc = pickle.dumps(func)
-        return AsynchronousMapTask.delay(serfunc, args, timeout=timeout)
-
-
 class PeriodicTask(Task):
 class PeriodicTask(Task):
     """A periodic task is a task that behaves like a :manpage:`cron` job.
     """A periodic task is a task that behaves like a :manpage:`cron` job.
 
 

+ 41 - 1
celery/task/builtins.py

@@ -1,7 +1,9 @@
 from datetime import timedelta
 from datetime import timedelta
 
 
-from celery.task.base import Task, PeriodicTask
 from celery.backends import default_backend
 from celery.backends import default_backend
+from celery.serialization import pickle
+from celery.task.base import Task, PeriodicTask
+from celery.task.sets import TaskSet
 
 
 
 
 class DeleteExpiredTaskMetaTask(PeriodicTask):
 class DeleteExpiredTaskMetaTask(PeriodicTask):
@@ -28,3 +30,41 @@ class PingTask(Task):
     def run(self, **kwargs):
     def run(self, **kwargs):
         """:returns: the string ``"pong"``."""
         """:returns: the string ``"pong"``."""
         return "pong"
         return "pong"
+
+
+def _dmap(fun, args, timeout=None):
+    """Apply ``fun`` to every item of ``args`` as a task set and wait.
+
+    ``fun`` is pickled once and one :class:`ExecuteRemoteTask` call is
+    queued per item of ``args``; each item is used as the positional
+    arguments of that call.
+
+    :param fun: The function to distribute.
+    :param args: Iterable of positional-argument sequences, one per call.
+    :keyword timeout: Passed on to ``join`` on the task set result.
+
+    :returns: the joined results of the task set.
+
+    """
+    serialized = pickle.dumps(fun)
+    calls = []
+    for fargs in args:
+        calls.append(((serialized, fargs, {}), {}))
+    taskset = TaskSet(ExecuteRemoteTask, calls)
+    result = taskset.apply_async()
+    return result.join(timeout=timeout)
+
+
+class AsynchronousMapTask(Task):
+    """Task used internally by :func:`dmap_async`.
+
+    Running the map inside a task lets the caller receive one result
+    for the whole map instead of waiting for it.
+
+    """
+    name = "celery.map_async"
+
+    def run(self, serfun, args, timeout=None, **kwargs):
+        """Unpickle ``serfun`` and map it over ``args`` using :func:`_dmap`.
+
+        :param serfun: The pickled function to apply.
+        :param args: Iterable of positional-argument sequences.
+        :keyword timeout: Passed through to :func:`_dmap`.
+
+        """
+        fun = pickle.loads(serfun)
+        return _dmap(fun, args, timeout=timeout)
+
+
+class ExecuteRemoteTask(Task):
+    """Call a pickled function or callable object with the given arguments.
+
+    *Note* This is an internal component of :func:`execute_remote`, which
+    is most likely the function you want to use instead.
+
+    The callable has to be pickleable, which rules out lambdas and
+    functions defined in the REPL (the python shell, or ``ipython``).
+
+    """
+    name = "celery.execute_remote"
+
+    def run(self, ser_callable, fargs, fkwargs, **kwargs):
+        """
+        :param ser_callable: The pickled function or callable object to run.
+        :param fargs: Sequence of positional arguments for the callable.
+        :param fkwargs: Mapping of keyword arguments for the callable.
+
+        :returns: whatever the unpickled callable returns.
+
+        """
+        # NOTE(review): unpickling can execute arbitrary code; this assumes
+        # only trusted producers can publish this task -- confirm.
+        remote_callable = pickle.loads(ser_callable)
+        return remote_callable(*fargs, **fkwargs)

+ 119 - 0
celery/task/sets.py

@@ -0,0 +1,119 @@
+from celery import conf
+from celery.execute import apply_async, apply
+from celery.registry import tasks
+from celery.result import TaskSetResult
+from celery.utils import gen_unique_id, padlist
+
+
+class TaskSet(object):
+    """A task containing several subtasks, making it possible
+    to track how many, or when all of the tasks have been completed.
+
+    :param task: The task class or name.
+        Can either be a fully qualified task name, or a task class.
+
+    :param args: A list of args, kwargs pairs.
+        e.g. ``[[args1, kwargs1], [args2, kwargs2], ..., [argsN, kwargsN]]``
+
+
+    .. attribute:: task
+
+        The task instance, as looked up in the task registry.
+
+    .. attribute:: task_name
+
+        The name of the task.
+
+    .. attribute:: arguments
+
+        The arguments, as passed to the task set constructor.
+
+    .. attribute:: total
+
+        Total number of tasks in this task set.
+
+    Example
+
+        >>> from djangofeeds.tasks import RefreshFeedTask
+        >>> taskset = TaskSet(RefreshFeedTask, args=[
+        ...                 ([], {"feed_url": "http://cnn.com/rss"}),
+        ...                 ([], {"feed_url": "http://bbc.com/rss"}),
+        ...                 ([], {"feed_url": "http://xkcd.com/rss"})
+        ... ])
+
+        >>> taskset_result = taskset.apply_async()
+        >>> list_of_return_values = taskset_result.join()
+
+    """
+
+    def __init__(self, task, args):
+        try:
+            task_name = task.name
+            task_obj = task
+        except AttributeError:
+            # No ``name`` attribute: ``task`` is a task name, so resolve
+            # it through the registry.
+            task_name = task
+            task_obj = tasks[task_name]
+
+        # Get task instance
+        task_obj = tasks[task_obj.name]
+
+        self.task = task_obj
+        self.task_name = task_name
+        self.arguments = args
+        self.total = len(args)
+
+    def apply_async(self, connect_timeout=conf.BROKER_CONNECTION_TIMEOUT):
+        """Run all tasks in the taskset.
+
+        If ``conf.ALWAYS_EAGER`` is set the task set is applied locally
+        by :meth:`apply` instead of being published.
+
+        :keyword connect_timeout: Timeout used when establishing the
+            broker connection.
+
+        :returns: A :class:`celery.result.TaskSetResult` instance.
+
+        Example
+
+            >>> ts = TaskSet(RefreshFeedTask, args=[
+            ...         (["http://foo.com/rss"], {}),
+            ...         (["http://bar.com/rss"], {}),
+            ... ])
+            >>> result = ts.apply_async()
+            >>> result.taskset_id
+            "d2c9b261-8eff-4bfb-8459-1e1b72063514"
+            >>> result.subtask_ids
+            ["b4996460-d959-49c8-aeb9-39c530dcde25",
+            "598d2d18-ab86-45ca-8b4f-0779f5d6a3cb"]
+            >>> result.waiting()
+            True
+            >>> time.sleep(10)
+            >>> result.ready()
+            True
+            >>> result.successful()
+            True
+            >>> result.failed()
+            False
+            >>> result.join()
+            [True, True]
+
+        """
+        if conf.ALWAYS_EAGER:
+            return self.apply()
+
+        taskset_id = gen_unique_id()
+        conn = self.task.establish_connection(connect_timeout=connect_timeout)
+        publisher = self.task.get_publisher(connection=conn)
+        try:
+            # One connection and publisher are shared by all subtasks.
+            subtasks = [self.apply_part(arglist, taskset_id, publisher)
+                            for arglist in self.arguments]
+        finally:
+            publisher.close()
+            conn.close()
+
+        return TaskSetResult(taskset_id, subtasks)
+
+    def apply_part(self, arglist, taskset_id, publisher):
+        """Apply a single part of the taskset.
+
+        :param arglist: ``(args, kwargs)`` or ``(args, kwargs, options)``;
+            missing trailing items are padded with ``{}``.
+        :param taskset_id: Id of the task set this subtask belongs to.
+        :param publisher: Publisher used to send the task message.
+
+        """
+        args, kwargs, opts = padlist(arglist, 3, default={})
+        return apply_async(self.task, args, kwargs,
+                           taskset_id=taskset_id, publisher=publisher, **opts)
+
+    def apply(self):
+        """Applies the taskset locally.
+
+        Every entry in :attr:`arguments` must be exactly an
+        ``(args, kwargs)`` pair here.
+
+        :returns: A :class:`celery.result.TaskSetResult` instance.
+
+        """
+        taskset_id = gen_unique_id()
+        # ``apply`` comes from :mod:`celery.execute`; it has to be imported
+        # by this module or the eager path fails with a NameError.
+        subtasks = [apply(self.task, args, kwargs)
+                        for args, kwargs in self.arguments]
+
+        # This will be filled with EagerResults.
+        return TaskSetResult(taskset_id, subtasks)

+ 12 - 20
celery/tests/test_task.py

@@ -19,6 +19,9 @@ from celery.decorators import task as task_dec
 from celery.exceptions import RetryTaskError
 from celery.exceptions import RetryTaskError
 from celery.worker.listener import parse_iso8601
 from celery.worker.listener import parse_iso8601
 
 
+from celery.tests.utils import with_eager_tasks
+
+
 def return_True(*args, **kwargs):
 def return_True(*args, **kwargs):
     # Task run functions can't be closures/lambdas, as they're pickled.
     # Task run functions can't be closures/lambdas, as they're pickled.
     return True
     return True
@@ -216,36 +219,28 @@ class TestCeleryTasks(unittest.TestCase):
         self.assertEqual(result.backend, RetryTask.backend)
         self.assertEqual(result.backend, RetryTask.backend)
         self.assertEqual(result.task_id, task_id)
         self.assertEqual(result.task_id, task_id)
 
 
+    @with_eager_tasks
     def test_ping(self):
     def test_ping(self):
-        from celery import conf
-        conf.ALWAYS_EAGER = True
         self.assertEqual(task.ping(), 'pong')
         self.assertEqual(task.ping(), 'pong')
-        conf.ALWAYS_EAGER = False
 
 
+    @with_eager_tasks
     def test_execute_remote(self):
     def test_execute_remote(self):
-        from celery import conf
-        conf.ALWAYS_EAGER = True
         self.assertEqual(task.execute_remote(return_True, ["foo"]).get(),
         self.assertEqual(task.execute_remote(return_True, ["foo"]).get(),
-                          True)
-        conf.ALWAYS_EAGER = False
+                         True)
 
 
+    @with_eager_tasks
     def test_dmap(self):
     def test_dmap(self):
-        from celery import conf
         import operator
         import operator
-        conf.ALWAYS_EAGER = True
         res = task.dmap(operator.add, zip(xrange(10), xrange(10)))
         res = task.dmap(operator.add, zip(xrange(10), xrange(10)))
         self.assertEqual(sum(res), sum(operator.add(x, x)
         self.assertEqual(sum(res), sum(operator.add(x, x)
-                                        for x in xrange(10)))
-        conf.ALWAYS_EAGER = False
+                                    for x in xrange(10)))
 
 
+    @with_eager_tasks
     def test_dmap_async(self):
     def test_dmap_async(self):
-        from celery import conf
         import operator
         import operator
-        conf.ALWAYS_EAGER = True
         res = task.dmap_async(operator.add, zip(xrange(10), xrange(10)))
         res = task.dmap_async(operator.add, zip(xrange(10), xrange(10)))
         self.assertEqual(sum(res.get()), sum(operator.add(x, x)
         self.assertEqual(sum(res.get()), sum(operator.add(x, x)
-                                                for x in xrange(10)))
-        conf.ALWAYS_EAGER = False
+                                            for x in xrange(10)))
 
 
     def assertNextTaskDataEqual(self, consumer, presult, task_name,
     def assertNextTaskDataEqual(self, consumer, presult, task_name,
             test_eta=False, **kwargs):
             test_eta=False, **kwargs):
@@ -355,16 +350,13 @@ class TestCeleryTasks(unittest.TestCase):
 
 
 class TestTaskSet(unittest.TestCase):
 class TestTaskSet(unittest.TestCase):
 
 
+    @with_eager_tasks
     def test_function_taskset(self):
     def test_function_taskset(self):
-        from celery import conf
-        conf.ALWAYS_EAGER = True
         ts = task.TaskSet(return_True_task.name, [
         ts = task.TaskSet(return_True_task.name, [
-            ([1], {}), [[2], {}], [[3], {}], [[4], {}], [[5], {}]])
+              ([1], {}), [[2], {}], [[3], {}], [[4], {}], [[5], {}]])
         res = ts.apply_async()
         res = ts.apply_async()
         self.assertListEqual(res.join(), [True, True, True, True, True])
         self.assertListEqual(res.join(), [True, True, True, True, True])
 
 
-        conf.ALWAYS_EAGER = False
-
     def test_counter_taskset(self):
     def test_counter_taskset(self):
         IncrementCounterTask.count = 0
         IncrementCounterTask.count = 0
         ts = task.TaskSet(IncrementCounterTask, [
         ts = task.TaskSet(IncrementCounterTask, [

+ 1 - 1
celery/tests/test_task_builtins.py

@@ -1,6 +1,6 @@
 import unittest2 as unittest
 import unittest2 as unittest
 
 
-from celery.task.base import ExecuteRemoteTask
+from celery.task.builtins import ExecuteRemoteTask
 from celery.task.builtins import PingTask, DeleteExpiredTaskMetaTask
 from celery.task.builtins import PingTask, DeleteExpiredTaskMetaTask
 from celery.serialization import pickle
 from celery.serialization import pickle
 
 

+ 13 - 0
celery/tests/utils.py

@@ -79,6 +79,19 @@ def eager_tasks():
     conf.ALWAYS_EAGER = prev
     conf.ALWAYS_EAGER = prev
 
 
 
 
+def with_eager_tasks(fun):
+    """Decorator running ``fun`` with ``conf.ALWAYS_EAGER`` set to ``True``.
+
+    The previous value of the setting is restored when ``fun`` returns
+    or raises.
+
+    :param fun: The function (typically a test method) to wrap.
+
+    :returns: the wrapped function.
+
+    """
+
+    @wraps(fun)
+    def _inner(*args, **kwargs):
+        from celery import conf
+        prev = conf.ALWAYS_EAGER
+        conf.ALWAYS_EAGER = True
+        try:
+            return fun(*args, **kwargs)
+        finally:
+            conf.ALWAYS_EAGER = prev
+
+    # Without this the decorator evaluates to None, so every decorated
+    # test method would be replaced by None and silently never run.
+    return _inner
+
+
 def with_environ(env_name, env_value):
 def with_environ(env_name, env_value):
 
 
     def _envpatched(fun):
     def _envpatched(fun):

+ 2 - 20
celery/utils/timeutils.py

@@ -29,20 +29,6 @@ def delta_resolution(dt, delta):
     will be rounded to the nearest hour, and so on until seconds
     will be rounded to the nearest hour, and so on until seconds
     which will just return the original datetime.
     which will just return the original datetime.
 
 
-    Examples::
-
-        >>> now = datetime.now()
-        >>> now
-        datetime.datetime(2010, 3, 30, 11, 50, 58, 41065)
-        >>> delta_resolution(now, timedelta(days=2))
-        datetime.datetime(2010, 3, 30, 0, 0)
-        >>> delta_resolution(now, timedelta(hours=2))
-        datetime.datetime(2010, 3, 30, 11, 0)
-        >>> delta_resolution(now, timedelta(minutes=2))
-        datetime.datetime(2010, 3, 30, 11, 50)
-        >>> delta_resolution(now, timedelta(seconds=2))
-        datetime.datetime(2010, 3, 30, 11, 50, 58, 41065)
-
     """
     """
     delta = timedelta_seconds(delta)
     delta = timedelta_seconds(delta)
 
 
@@ -107,12 +93,8 @@ def rate(rate):
 def weekday(name):
 def weekday(name):
     """Return the position of a weekday (0 - 7, where 0 is Sunday).
     """Return the position of a weekday (0 - 7, where 0 is Sunday).
 
 
-        >>> weekday("sunday")
-        0
-        >>> weekday("sun")
-        0
-        >>> weekday("mon")
-        1
+        >>> weekday("sunday"), weekday("sun"), weekday("mon")
+        (0, 0, 1)
 
 
     """
     """
     abbreviation = name[0:3].lower()
     abbreviation = name[0:3].lower()