Browse Source

Refactored stuff out of celery.task.base.

TaskSet is now in celery.task.sets.TaskSet
ExecuteRemoteTask is now in celery.task.builtins

The following methods have been removed:

    TaskSet.remote_execute (use celery.task.execute_remote)
    TaskSet.map (use celery.task.dmap)
    TaskSet.map_async (use celery.task.dmap_async)
Ask Solem 14 years ago
parent
commit
ad955af185

+ 6 - 4
celery/task/__init__.py

@@ -7,8 +7,10 @@ Working with tasks and task sets.
 from celery.execute import apply_async
 from celery.registry import tasks
 from celery.serialization import pickle
-from celery.task.base import Task, TaskSet, PeriodicTask, ExecuteRemoteTask
-from celery.task.builtins import PingTask
+from celery.task.base import Task, PeriodicTask
+from celery.task.sets import TaskSet
+from celery.task.builtins import PingTask, ExecuteRemoteTask
+from celery.task.builtins import AsynchronousMapTask, _dmap
 from celery.task.control import discard_all
 from celery.task.http import HttpDispatchTask
 
@@ -27,7 +29,7 @@ def dmap(fun, args, timeout=None):
         [4, 8, 16]
 
     """
-    return TaskSet.map(fun, args, timeout=timeout)
+    return _dmap(fun, args, timeout)
 
 
 def dmap_async(fun, args, timeout=None):
@@ -49,7 +51,7 @@ def dmap_async(fun, args, timeout=None):
         [4, 8, 16]
 
     """
-    return TaskSet.map_async(fun, args, timeout=timeout)
+    return AsynchronousMapTask.delay(pickle.dumps(fun), args, timeout=timeout)
 
 
 def execute_remote(fun, *args, **kwargs):

+ 2 - 174
celery/task/base.py

@@ -3,18 +3,17 @@ from datetime import timedelta
 
 from celery import conf
 from celery.log import setup_task_logger
-from celery.utils import gen_unique_id, padlist
 from celery.utils.timeutils import timedelta_seconds
-from celery.result import BaseAsyncResult, TaskSetResult, EagerResult
+from celery.result import BaseAsyncResult, EagerResult
 from celery.execute import apply_async, apply
 from celery.registry import tasks
 from celery.backends import default_backend
 from celery.messaging import TaskPublisher, TaskConsumer
 from celery.messaging import establish_connection as _establish_connection
 from celery.exceptions import MaxRetriesExceededError, RetryTaskError
-from celery.serialization import pickle
 
 from celery.task.schedules import schedule
+from celery.task.sets import TaskSet
 
 
 class TaskType(type):
@@ -499,177 +498,6 @@ class Task(object):
         return "<%s: %s (%s)>" % (kind, self.name, self.type)
 
 
-class ExecuteRemoteTask(Task):
-    """Execute an arbitrary function or object.
-
-    *Note* You probably want :func:`execute_remote` instead, which this
-    is an internal component of.
-
-    The object must be pickleable, so you can't use lambdas or functions
-    defined in the REPL (that is the python shell, or ``ipython``).
-
-    """
-    name = "celery.execute_remote"
-
-    def run(self, ser_callable, fargs, fkwargs, **kwargs):
-        """
-        :param ser_callable: A pickled function or callable object.
-        :param fargs: Positional arguments to apply to the function.
-        :param fkwargs: Keyword arguments to apply to the function.
-
-        """
-        return pickle.loads(ser_callable)(*fargs, **fkwargs)
-
-
-class AsynchronousMapTask(Task):
-    """Task used internally by :func:`dmap_async` and
-    :meth:`TaskSet.map_async`.  """
-    name = "celery.map_async"
-
-    def run(self, ser_callable, args, timeout=None, **kwargs):
-        """:see :meth:`TaskSet.dmap_async`."""
-        return TaskSet.map(pickle.loads(ser_callable), args, timeout=timeout)
-
-
-class TaskSet(object):
-    """A task containing several subtasks, making it possible
-    to track how many, or when all of the tasks has been completed.
-
-    :param task: The task class or name.
-        Can either be a fully qualified task name, or a task class.
-
-    :param args: A list of args, kwargs pairs.
-        e.g. ``[[args1, kwargs1], [args2, kwargs2], ..., [argsN, kwargsN]]``
-
-
-    .. attribute:: task_name
-
-        The name of the task.
-
-    .. attribute:: arguments
-
-        The arguments, as passed to the task set constructor.
-
-    .. attribute:: total
-
-        Total number of tasks in this task set.
-
-    Example
-
-        >>> from djangofeeds.tasks import RefreshFeedTask
-        >>> taskset = TaskSet(RefreshFeedTask, args=[
-        ...                 ([], {"feed_url": "http://cnn.com/rss"}),
-        ...                 ([], {"feed_url": "http://bbc.com/rss"}),
-        ...                 ([], {"feed_url": "http://xkcd.com/rss"})
-        ... ])
-
-        >>> taskset_result = taskset.apply_async()
-        >>> list_of_return_values = taskset_result.join()
-
-    """
-
-    def __init__(self, task, args):
-        try:
-            task_name = task.name
-            task_obj = task
-        except AttributeError:
-            task_name = task
-            task_obj = tasks[task_name]
-
-        # Get task instance
-        task_obj = tasks[task_obj.name]
-
-        self.task = task_obj
-        self.task_name = task_name
-        self.arguments = args
-        self.total = len(args)
-
-    def apply_async(self, connect_timeout=conf.BROKER_CONNECTION_TIMEOUT):
-        """Run all tasks in the taskset.
-
-        :returns: A :class:`celery.result.TaskSetResult` instance.
-
-        Example
-
-            >>> ts = TaskSet(RefreshFeedTask, args=[
-            ...         (["http://foo.com/rss"], {}),
-            ...         (["http://bar.com/rss"], {}),
-            ... ])
-            >>> result = ts.apply_async()
-            >>> result.taskset_id
-            "d2c9b261-8eff-4bfb-8459-1e1b72063514"
-            >>> result.subtask_ids
-            ["b4996460-d959-49c8-aeb9-39c530dcde25",
-            "598d2d18-ab86-45ca-8b4f-0779f5d6a3cb"]
-            >>> result.waiting()
-            True
-            >>> time.sleep(10)
-            >>> result.ready()
-            True
-            >>> result.successful()
-            True
-            >>> result.failed()
-            False
-            >>> result.join()
-            [True, True]
-
-        """
-        if conf.ALWAYS_EAGER:
-            return self.apply()
-
-        taskset_id = gen_unique_id()
-        conn = self.task.establish_connection(connect_timeout=connect_timeout)
-        publisher = self.task.get_publisher(connection=conn)
-        try:
-            subtasks = [self.apply_part(arglist, taskset_id, publisher)
-                            for arglist in self.arguments]
-        finally:
-            publisher.close()
-            conn.close()
-
-        return TaskSetResult(taskset_id, subtasks)
-
-    def apply_part(self, arglist, taskset_id, publisher):
-        """Apply a single part of the taskset."""
-        args, kwargs, opts = padlist(arglist, 3, default={})
-        return apply_async(self.task, args, kwargs,
-                           taskset_id=taskset_id, publisher=publisher, **opts)
-
-    def apply(self):
-        """Applies the taskset locally."""
-        taskset_id = gen_unique_id()
-        subtasks = [apply(self.task, args, kwargs)
-                        for args, kwargs in self.arguments]
-
-        # This will be filled with EagerResults.
-        return TaskSetResult(taskset_id, subtasks)
-
-    @classmethod
-    def remote_execute(cls, func, args):
-        """Apply ``args`` to function by distributing the args to the
-        celery server(s)."""
-        pickled = pickle.dumps(func)
-        arguments = [((pickled, arg, {}), {}) for arg in args]
-        return cls(ExecuteRemoteTask, arguments)
-
-    @classmethod
-    def map(cls, func, args, timeout=None):
-        """Distribute processing of the arguments and collect the results."""
-        remote_task = cls.remote_execute(func, args)
-        return remote_task.apply_async().join(timeout=timeout)
-
-    @classmethod
-    def map_async(cls, func, args, timeout=None):
-        """Distribute processing of the arguments and collect the results
-        asynchronously.
-
-        :returns: :class:`celery.result.AsyncResult` instance.
-
-        """
-        serfunc = pickle.dumps(func)
-        return AsynchronousMapTask.delay(serfunc, args, timeout=timeout)
-
-
 class PeriodicTask(Task):
     """A periodic task is a task that behaves like a :manpage:`cron` job.
 

+ 41 - 1
celery/task/builtins.py

@@ -1,7 +1,9 @@
 from datetime import timedelta
 
-from celery.task.base import Task, PeriodicTask
 from celery.backends import default_backend
+from celery.serialization import pickle
+from celery.task.base import Task, PeriodicTask
+from celery.task.sets import TaskSet
 
 
 class DeleteExpiredTaskMetaTask(PeriodicTask):
@@ -28,3 +30,41 @@ class PingTask(Task):
     def run(self, **kwargs):
         """:returns: the string ``"pong"``."""
         return "pong"
+
+
+def _dmap(fun, args, timeout=None):
+    pickled = pickle.dumps(fun)
+    arguments = [((pickled, arg, {}), {}) for arg in args]
+    ts = TaskSet(ExecuteRemoteTask, arguments)
+    return ts.apply_async().join(timeout=timeout)
+
+
+class AsynchronousMapTask(Task):
+    """Task used internally by :func:`dmap_async` and
+    :meth:`TaskSet.map_async`.  """
+    name = "celery.map_async"
+
+    def run(self, serfun, args, timeout=None, **kwargs):
+        return _dmap(pickle.loads(serfun), args, timeout=timeout)
+
+
+class ExecuteRemoteTask(Task):
+    """Execute an arbitrary function or object.
+
+    *Note* You probably want :func:`execute_remote` instead, which this
+    is an internal component of.
+
+    The object must be pickleable, so you can't use lambdas or functions
+    defined in the REPL (that is the python shell, or ``ipython``).
+
+    """
+    name = "celery.execute_remote"
+
+    def run(self, ser_callable, fargs, fkwargs, **kwargs):
+        """
+        :param ser_callable: A pickled function or callable object.
+        :param fargs: Positional arguments to apply to the function.
+        :param fkwargs: Keyword arguments to apply to the function.
+
+        """
+        return pickle.loads(ser_callable)(*fargs, **fkwargs)

+ 119 - 0
celery/task/sets.py

@@ -0,0 +1,119 @@
+from celery import conf
+from celery.execute import apply_async
+from celery.registry import tasks
+from celery.result import TaskSetResult
+from celery.utils import gen_unique_id, padlist
+
+
+class TaskSet(object):
+    """A task containing several subtasks, making it possible
+    to track how many, or when all of the tasks has been completed.
+
+    :param task: The task class or name.
+        Can either be a fully qualified task name, or a task class.
+
+    :param args: A list of args, kwargs pairs.
+        e.g. ``[[args1, kwargs1], [args2, kwargs2], ..., [argsN, kwargsN]]``
+
+
+    .. attribute:: task_name
+
+        The name of the task.
+
+    .. attribute:: arguments
+
+        The arguments, as passed to the task set constructor.
+
+    .. attribute:: total
+
+        Total number of tasks in this task set.
+
+    Example
+
+        >>> from djangofeeds.tasks import RefreshFeedTask
+        >>> taskset = TaskSet(RefreshFeedTask, args=[
+        ...                 ([], {"feed_url": "http://cnn.com/rss"}),
+        ...                 ([], {"feed_url": "http://bbc.com/rss"}),
+        ...                 ([], {"feed_url": "http://xkcd.com/rss"})
+        ... ])
+
+        >>> taskset_result = taskset.apply_async()
+        >>> list_of_return_values = taskset_result.join()
+
+    """
+
+    def __init__(self, task, args):
+        try:
+            task_name = task.name
+            task_obj = task
+        except AttributeError:
+            task_name = task
+            task_obj = tasks[task_name]
+
+        # Get task instance
+        task_obj = tasks[task_obj.name]
+
+        self.task = task_obj
+        self.task_name = task_name
+        self.arguments = args
+        self.total = len(args)
+
+    def apply_async(self, connect_timeout=conf.BROKER_CONNECTION_TIMEOUT):
+        """Run all tasks in the taskset.
+
+        :returns: A :class:`celery.result.TaskSetResult` instance.
+
+        Example
+
+            >>> ts = TaskSet(RefreshFeedTask, args=[
+            ...         (["http://foo.com/rss"], {}),
+            ...         (["http://bar.com/rss"], {}),
+            ... ])
+            >>> result = ts.apply_async()
+            >>> result.taskset_id
+            "d2c9b261-8eff-4bfb-8459-1e1b72063514"
+            >>> result.subtask_ids
+            ["b4996460-d959-49c8-aeb9-39c530dcde25",
+            "598d2d18-ab86-45ca-8b4f-0779f5d6a3cb"]
+            >>> result.waiting()
+            True
+            >>> time.sleep(10)
+            >>> result.ready()
+            True
+            >>> result.successful()
+            True
+            >>> result.failed()
+            False
+            >>> result.join()
+            [True, True]
+
+        """
+        if conf.ALWAYS_EAGER:
+            return self.apply()
+
+        taskset_id = gen_unique_id()
+        conn = self.task.establish_connection(connect_timeout=connect_timeout)
+        publisher = self.task.get_publisher(connection=conn)
+        try:
+            subtasks = [self.apply_part(arglist, taskset_id, publisher)
+                            for arglist in self.arguments]
+        finally:
+            publisher.close()
+            conn.close()
+
+        return TaskSetResult(taskset_id, subtasks)
+
+    def apply_part(self, arglist, taskset_id, publisher):
+        """Apply a single part of the taskset."""
+        args, kwargs, opts = padlist(arglist, 3, default={})
+        return apply_async(self.task, args, kwargs,
+                           taskset_id=taskset_id, publisher=publisher, **opts)
+
+    def apply(self):
+        """Applies the taskset locally."""
+        taskset_id = gen_unique_id()
+        subtasks = [apply(self.task, args, kwargs)
+                        for args, kwargs in self.arguments]
+
+        # This will be filled with EagerResults.
+        return TaskSetResult(taskset_id, subtasks)

+ 12 - 20
celery/tests/test_task.py

@@ -19,6 +19,9 @@ from celery.decorators import task as task_dec
 from celery.exceptions import RetryTaskError
 from celery.worker.listener import parse_iso8601
 
+from celery.tests.utils import with_eager_tasks
+
+
 def return_True(*args, **kwargs):
     # Task run functions can't be closures/lambdas, as they're pickled.
     return True
@@ -216,36 +219,28 @@ class TestCeleryTasks(unittest.TestCase):
         self.assertEqual(result.backend, RetryTask.backend)
         self.assertEqual(result.task_id, task_id)
 
+    @with_eager_tasks
     def test_ping(self):
-        from celery import conf
-        conf.ALWAYS_EAGER = True
         self.assertEqual(task.ping(), 'pong')
-        conf.ALWAYS_EAGER = False
 
+    @with_eager_tasks
     def test_execute_remote(self):
-        from celery import conf
-        conf.ALWAYS_EAGER = True
         self.assertEqual(task.execute_remote(return_True, ["foo"]).get(),
-                          True)
-        conf.ALWAYS_EAGER = False
+                         True)
 
+    @with_eager_tasks
     def test_dmap(self):
-        from celery import conf
         import operator
-        conf.ALWAYS_EAGER = True
         res = task.dmap(operator.add, zip(xrange(10), xrange(10)))
         self.assertEqual(sum(res), sum(operator.add(x, x)
-                                        for x in xrange(10)))
-        conf.ALWAYS_EAGER = False
+                                    for x in xrange(10)))
 
+    @with_eager_tasks
     def test_dmap_async(self):
-        from celery import conf
         import operator
-        conf.ALWAYS_EAGER = True
         res = task.dmap_async(operator.add, zip(xrange(10), xrange(10)))
         self.assertEqual(sum(res.get()), sum(operator.add(x, x)
-                                                for x in xrange(10)))
-        conf.ALWAYS_EAGER = False
+                                            for x in xrange(10)))
 
     def assertNextTaskDataEqual(self, consumer, presult, task_name,
             test_eta=False, **kwargs):
@@ -355,16 +350,13 @@ class TestCeleryTasks(unittest.TestCase):
 
 class TestTaskSet(unittest.TestCase):
 
+    @with_eager_tasks
     def test_function_taskset(self):
-        from celery import conf
-        conf.ALWAYS_EAGER = True
         ts = task.TaskSet(return_True_task.name, [
-            ([1], {}), [[2], {}], [[3], {}], [[4], {}], [[5], {}]])
+              ([1], {}), [[2], {}], [[3], {}], [[4], {}], [[5], {}]])
         res = ts.apply_async()
         self.assertListEqual(res.join(), [True, True, True, True, True])
 
-        conf.ALWAYS_EAGER = False
-
     def test_counter_taskset(self):
         IncrementCounterTask.count = 0
         ts = task.TaskSet(IncrementCounterTask, [

+ 1 - 1
celery/tests/test_task_builtins.py

@@ -1,6 +1,6 @@
 import unittest2 as unittest
 
-from celery.task.base import ExecuteRemoteTask
+from celery.task.builtins import ExecuteRemoteTask
 from celery.task.builtins import PingTask, DeleteExpiredTaskMetaTask
 from celery.serialization import pickle
 

+ 13 - 0
celery/tests/utils.py

@@ -79,6 +79,19 @@ def eager_tasks():
     conf.ALWAYS_EAGER = prev
 
 
+def with_eager_tasks(fun):
+
+    @wraps(fun)
+    def _inner(*args, **kwargs):
+        from celery import conf
+        prev = conf.ALWAYS_EAGER
+        conf.ALWAYS_EAGER = True
+        try:
+            return fun(*args, **kwargs)
+        finally:
+            conf.ALWAYS_EAGER = prev
+
+
 def with_environ(env_name, env_value):
 
     def _envpatched(fun):

+ 2 - 20
celery/utils/timeutils.py

@@ -29,20 +29,6 @@ def delta_resolution(dt, delta):
     will be rounded to the nearest hour, and so on until seconds
     which will just return the original datetime.
 
-    Examples::
-
-        >>> now = datetime.now()
-        >>> now
-        datetime.datetime(2010, 3, 30, 11, 50, 58, 41065)
-        >>> delta_resolution(now, timedelta(days=2))
-        datetime.datetime(2010, 3, 30, 0, 0)
-        >>> delta_resolution(now, timedelta(hours=2))
-        datetime.datetime(2010, 3, 30, 11, 0)
-        >>> delta_resolution(now, timedelta(minutes=2))
-        datetime.datetime(2010, 3, 30, 11, 50)
-        >>> delta_resolution(now, timedelta(seconds=2))
-        datetime.datetime(2010, 3, 30, 11, 50, 58, 41065)
-
     """
     delta = timedelta_seconds(delta)
 
@@ -107,12 +93,8 @@ def rate(rate):
 def weekday(name):
     """Return the position of a weekday (0 - 7, where 0 is Sunday).
 
-        >>> weekday("sunday")
-        0
-        >>> weekday("sun")
-        0
-        >>> weekday("mon")
-        1
+        >>> weekday("sunday"), weekday("sun"), weekday("mon")
+        (0, 0, 1)
 
     """
     abbreviation = name[0:3].lower()