Quellcode durchsuchen

Use mattrgetter in apply_async()

Ask Solem vor 15 Jahren
Ursprung
Commit
ba969f807d
3 geänderte Dateien mit 24 neuen und 20 gelöschten Zeilen
  1. 1 0
      celery/conf.py
  2. 16 20
      celery/execute.py
  3. 7 0
      celery/utils.py

+ 1 - 0
celery/conf.py

@@ -78,6 +78,7 @@ DEFAULT_AMQP_CONSUMER_QUEUES = {
             "routing_key": AMQP_CONSUMER_ROUTING_KEY,
             "exchange_type": AMQP_EXCHANGE_TYPE,
         }
+}
 AMQP_CONSUMER_QUEUES = _get("CELERY_AMQP_CONSUMER_QUEUES",
                             DEFAULT_AMQP_CONSUMER_QUEUES)
 AMQP_CONNECTION_TIMEOUT = _get("CELERY_AMQP_CONNECTION_TIMEOUT")

+ 16 - 20
celery/execute.py

@@ -7,18 +7,19 @@ from billiard.utils.functional import curry
 
 from celery import conf
 from celery import signals
-from celery.utils import gen_unique_id, noop, fun_takes_kwargs
+from celery.utils import gen_unique_id, noop, fun_takes_kwargs, mattrgetter
 from celery.result import AsyncResult, EagerResult
 from celery.registry import tasks
-from celery.messaging import TaskPublisher, with_connection_inline
+from celery.messaging import TaskPublisher, with_connection
 from celery.exceptions import RetryTaskError
 from celery.datastructures import ExceptionInfo
 
-TASK_EXEC_OPTIONS = ("routing_key", "exchange",
-                     "immediate", "mandatory",
-                     "priority", "serializer")
+extract_exec_options = mattrgetter("routing_key", "exchange",
+                                   "immediate", "mandatory",
+                                   "priority", "serializer")
 
 
+@with_connection
 def apply_async(task, args=None, kwargs=None, countdown=None, eta=None,
         task_id=None, publisher=None, connection=None, connect_timeout=None,
         **options):
@@ -76,25 +77,20 @@ def apply_async(task, args=None, kwargs=None, countdown=None, eta=None,
     if conf.ALWAYS_EAGER:
         return apply(task, args, kwargs)
 
-    for option_name in TASK_EXEC_OPTIONS:
-        if option_name not in options:
-            options[option_name] = getattr(task, option_name, None)
+    options = dict(extract_exec_options(task), **options)
 
     if countdown: # Convert countdown to ETA.
         eta = datetime.now() + timedelta(seconds=countdown)
 
-    def _delay_task(connection):
-        publish = publisher or TaskPublisher(connection)
-        try:
-            return publish.delay_task(task.name, args or [], kwargs or {},
-                                      task_id=task_id,
-                                      eta=eta,
-                                      **options)
-        finally:
-            publisher or publish.close()
-
-    task_id = with_connection_inline(_delay_task, connection=connection,
-                                     connect_timeout=connect_timeout)
+    publish = publisher or TaskPublisher(connection)
+    try:
+        task_id = publish.delay_task(task.name, args or [], kwargs or {},
+                                     task_id=task_id,
+                                     eta=eta,
+                                     **options)
+    finally:
+        publisher or publish.close()
+
     return AsyncResult(task_id)
 
 

+ 7 - 0
celery/utils.py

@@ -67,6 +67,13 @@ def mitemgetter(*items):
     return lambda container: map(container.get, items)
 
 
+def mattrgetter(*attrs):
+    """Like :func:`operator.itemgetter` but returns ``None`` on missing
+    attributes instead of raising :exc:`AttributeError`."""
+    return lambda obj: dict((attr, getattr(obj, attr, None))
+                                for attr in attrs)
+
+
 def get_full_cls_name(cls):
     """With a class, get its full module and class name."""
     return ".".join([cls.__module__,