|
@@ -3,7 +3,7 @@ from celery.conf import AMQP_CONNECTION_TIMEOUT
|
|
|
from celery.result import AsyncResult, EagerResult
|
|
|
from celery.messaging import TaskPublisher
|
|
|
from celery.registry import tasks
|
|
|
-from celery.utils import gen_unique_id, noop
|
|
|
+from celery.utils import gen_unique_id, noop, fun_takes_kwargs
|
|
|
from functools import partial as curry
|
|
|
from datetime import datetime, timedelta
|
|
|
from multiprocessing import get_logger
|
|
@@ -152,12 +152,17 @@ def apply(task, args, kwargs, **options):
|
|
|
# for it to be callable.
|
|
|
task = inspect.isclass(task) and task() or task
|
|
|
|
|
|
- kwargs.update({"task_name": task.name,
|
|
|
- "task_id": task_id,
|
|
|
- "task_retries": retries,
|
|
|
- "task_is_eager": True,
|
|
|
- "logfile": None,
|
|
|
- "loglevel": 0})
|
|
|
+ default_kwargs = {"task_name": task.name,
|
|
|
+ "task_id": task_id,
|
|
|
+ "task_retries": retries,
|
|
|
+ "task_is_eager": True,
|
|
|
+ "logfile": None,
|
|
|
+ "loglevel": 0}
|
|
|
+ fun = getattr(task, "run", task)
|
|
|
+ supported_keys = fun_takes_kwargs(fun, default_kwargs)
|
|
|
+ extend_with = dict((key, val) for key, val in default_kwargs.items()
|
|
|
+ if key in supported_keys)
|
|
|
+ kwargs.update(extend_with)
|
|
|
|
|
|
try:
|
|
|
ret_value = task(*args, **kwargs)
|