Browse Source

Tasks now supports positional arguments. celery.worker refactored some (TaskWrapper)

Ask Solem 16 years ago
parent
commit
deeab22b8e
4 changed files with 87 additions and 40 deletions
  1. 16 10
      celery/messaging.py
  2. 8 7
      celery/task.py
  3. 9 7
      celery/tests/test_task.py
  4. 54 16
      celery/worker.py

+ 16 - 10
celery/messaging.py

@@ -14,21 +14,27 @@ class TaskPublisher(Publisher):
     exchange = conf.AMQP_EXCHANGE
     routing_key = conf.AMQP_ROUTING_KEY
 
-    def delay_task(self, task_name, **task_kwargs):
-        return self._delay_task(task_name=task_name, extra_data=task_kwargs)
+    def delay_task(self, task_name, *task_args, **task_kwargs):
+        return self._delay_task(task_name=task_name, args=task_args,
+                                kwargs=task_kwargs)
 
-    def delay_task_in_set(self, task_name, taskset_id, task_kwargs):
+    def delay_task_in_set(self, task_name, taskset_id, task_args,
+            task_kwargs):
         return self._delay_task(task_name=task_name, part_of_set=taskset_id,
-                                extra_data=task_kwargs)
+                                args=task_args, kwargs=task_kwargs)
 
-    def _delay_task(self, task_name, part_of_set=None, extra_data=None):
-        extra_data = extra_data or {}
+    def _delay_task(self, task_name, part_of_set=None, args=None, kwargs=None):
+        args = args or []
+        kwargs = kwargs or {}
         task_id = str(uuid.uuid4())
-        message_data = dict(extra_data)
-        message_data["celeryTASK"] = task_name
-        message_data["celeryID"] = task_id
+        message_data = {
+            "id": task_id,
+            "task": task_name,
+            "args": args,
+            "kwargs": kwargs,
+        }
         if part_of_set:
-            message_data["celeryTASKSET"] = part_of_set
+            message_data["taskset"] = part_of_set
         self.send(message_data)
         return task_id
 

+ 8 - 7
celery/task.py

@@ -10,7 +10,7 @@ import uuid
 import traceback
 
 
-def delay_task(task_name, **kwargs):
+def delay_task(task_name, *args, **kwargs):
     """Delay a task for execution by the ``celery`` daemon.
 
     Examples
@@ -23,7 +23,7 @@ def delay_task(task_name, **kwargs):
                 "Task with name %s not registered in the task registry." % (
                     task_name))
     publisher = TaskPublisher(connection=DjangoAMQPConnection)
-    task_id = publisher.delay_task(task_name, **kwargs)
+    task_id = publisher.delay_task(task_name, *args, **kwargs)
     publisher.close()
     return task_id
 
@@ -116,11 +116,11 @@ class Task(object):
         if not self.name:
             raise NotImplementedError("Tasks must define a name attribute.")
 
-    def __call__(self, **kwargs):
+    def __call__(self, *args, **kwargs):
         """The ``__call__`` is called when you do ``Task().run()`` and calls
         the ``run`` method. It also catches any exceptions and logs them."""
         try:
-            retval = self.run(**kwargs)
+            retval = self.run(*args, **kwargs)
         except Exception, e:
             logger = self.get_logger(**kwargs)
             logger.critical("Task got exception %s: %s\n%s" % (
@@ -129,7 +129,7 @@ class Task(object):
         else:
             return retval
 
-    def run(self, **kwargs):
+    def run(self, *args, **kwargs):
         """The actual task. All subclasses of :class:`Task` must define
         the run method, if not a ``NotImplementedError`` exception is raised.
         """
@@ -148,9 +148,9 @@ class Task(object):
         return TaskConsumer(connection=DjangoAMQPConnection)
 
     @classmethod
-    def delay(cls, **kwargs):
+    def delay(cls, *args, **kwargs):
         """Delay this task for execution by the ``celery`` daemon(s)."""
-        return delay_task(cls.name, **kwargs)
+        return delay_task(cls.name, *args, **kwargs)
 
 
 class TaskSet(object):
@@ -209,6 +209,7 @@ class TaskSet(object):
         for arg in self.arguments:
             subtask_id = publisher.delay_task_in_set(task_name=self.task_name,
                                                      taskset_id=taskset_id,
+                                                     task_args=[],
                                                      task_kwargs=arg)
             subtask_ids.append(subtask_id) 
         publisher.close()

+ 9 - 7
celery/tests/test_task.py

@@ -41,10 +41,11 @@ class TestCeleryTasks(unittest.TestCase):
             **kwargs):
         next_task = consumer.fetch()
         task_data = consumer.decoder(next_task.body)
-        self.assertEquals(task_data["celeryID"], task_id)
-        self.assertEquals(task_data["celeryTASK"], task_name)
+        self.assertEquals(task_data["id"], task_id)
+        self.assertEquals(task_data["task"], task_name)
+        task_kwargs = task_data.get("kwargs", {})
         for arg_name, arg_value in kwargs.items():
-            self.assertEquals(task_data.get(arg_name), arg_value)
+            self.assertEquals(task_kwargs.get(arg_name), arg_value)
 
     def test_raising_task(self):
         rtask = self.createTaskCls("RaisingTask", "c.unittest.t.rtask")
@@ -136,8 +137,9 @@ class TestTaskSet(unittest.TestCase):
         consumer = IncrementCounterTask().get_consumer()
         for subtask_id in subtask_ids:
             m = consumer.decoder(consumer.fetch().body)
-            self.assertEquals(m.get("celeryTASKSET"), taskset_id)
-            self.assertEquals(m.get("celeryTASK"), IncrementCounterTask.name)
-            self.assertEquals(m.get("celeryID"), subtask_id)
-            IncrementCounterTask().run(increment_by=m.get("increment_by"))
+            self.assertEquals(m.get("taskset"), taskset_id)
+            self.assertEquals(m.get("task"), IncrementCounterTask.name)
+            self.assertEquals(m.get("id"), subtask_id)
+            IncrementCounterTask().run(
+                    increment_by=m.get("kwargs", {}).get("increment_by"))
         self.assertEquals(IncrementCounterTask.count, sum(xrange(1, 10)))

+ 54 - 16
celery/worker.py

@@ -22,6 +22,42 @@ class UnknownTask(Exception):
     ignored."""
 
 
+class TaskWrapper(object):
+    def __init__(self, task_name, task_id, task_func, args, kwargs):
+        self.task_name = task_name
+        self.task_id = task_id
+        self.task_func = task_func
+        self.args = args
+        self.kwargs = kwargs
+
+    @classmethod
+    def from_message(cls, message):
+        message_data = simplejson.loads(message.body)
+        task_name = message_data.pop("task")
+        task_id = message_data.pop("id")
+        args = message_data.pop("args")
+        kwargs = message_data.pop("kwargs")
+        if task_name not in tasks:
+            message.reject()
+            raise UnknownTask(task_name)
+        task_func = tasks[task_name]
+        return cls(task_name, task_id, task_func, args, kwargs)
+
+    def extend_kwargs_with_logging(self, loglevel, logfile):
+        task_func_kwargs = {"logfile": logfile,
+                            "loglevel": loglevel}
+        task_func_kwargs.update(self.kwargs)
+        return task_func_kwargs
+
+    def execute(self, loglevel, logfile):
+        task_func_kwargs = self.extend_kwargs_with_logging(loglevel, logfile)
+        return self.task_func(*self.args, **task_func_kwargs)
+
+    def execute_using_pool(self, pool, loglevel, logfile):
+        task_func_kwargs = self.extend_kwargs_with_logging(loglevel, logfile)
+        return pool.apply_async(self.task_func, self.args, task_func_kwargs)
+
+
 class TaskDaemon(object):
     """Executes tasks waiting in the task queue.
 
@@ -43,42 +79,44 @@ class TaskDaemon(object):
         self.logger = setup_logger(loglevel, logfile)
         self.pool = multiprocessing.Pool(self.concurrency)
         self.task_consumer = TaskConsumer(connection=DjangoAMQPConnection)
-        self.task_registry = tasks
 
     def fetch_next_task(self):
         message = self.task_consumer.fetch()
         if message is None: # No messages waiting.
             raise EmptyQueue()
 
-        message_data = simplejson.loads(message.body)
-        task_name = message_data.pop("celeryTASK")
-        task_id = message_data.pop("celeryID")
+        task = TaskWrapper.from_message(message)
         self.logger.info("Got task from broker: %s[%s]" % (
-                            task_name, task_id))
-        if task_name not in self.task_registry:
-            message.reject()
-            raise UnknownTask(task_name)
+                            task.task_name, task.task_id))
+
+        return message, task
 
-        task_func = self.task_registry[task_name]
-        task_func_params = {"logfile": self.logfile,
-                            "loglevel": self.loglevel}
-        task_func_params.update(message_data)
+    def execute_next_task(self):
+        message, task = self.fetch_next_task()
 
         try:
-            result = self.pool.apply_async(task_func, [], task_func_params)
+            result = task.execute_using_pool(self.pool, self.loglevel,
+                                             self.logfile)
         except Exception, error:
             self.logger.critical("Worker got exception %s: %s\n%s" % (
                 error.__class__, error, traceback.format_exc()))
             return 

         message.ack()
-        return result, task_name, task_id
+        return result, task.task_name, task.task_id
 
     def run_periodic_tasks(self):
-        for task in PeriodicTaskMeta.objects.get_waiting_tasks():
-            task.delay()
+        """Schedule all waiting periodic tasks for execution.
+       
+        Returns list of :class:`celery.models.PeriodicTaskMeta` objects.
+        """
+        waiting_tasks = PeriodicTaskMeta.objects.get_waiting_tasks()
+        [waiting_task.delay()
+                for waiting_task in waiting_tasks]
+        return waiting_tasks
 
     def run(self):
+        """Run the worker server."""
         results = ProcessQueue(self.concurrency, logger=self.logger,
                 done_msg="Task %(name)s[%(id)s] processed: %(return_value)s")
         last_empty_emit = None