|  | @@ -3,7 +3,7 @@ from celery.conf import AMQP_CONNECTION_TIMEOUT
 | 
											
												
													
														|  |  from celery.result import AsyncResult, EagerResult
 |  |  from celery.result import AsyncResult, EagerResult
 | 
											
												
													
														|  |  from celery.messaging import TaskPublisher
 |  |  from celery.messaging import TaskPublisher
 | 
											
												
													
														|  |  from celery.registry import tasks
 |  |  from celery.registry import tasks
 | 
											
												
													
														|  | -from celery.utils import gen_unique_id, noop
 |  | 
 | 
											
												
													
														|  | 
 |  | +from celery.utils import gen_unique_id, noop, fun_takes_kwargs
 | 
											
												
													
														|  |  from functools import partial as curry
 |  |  from functools import partial as curry
 | 
											
												
													
														|  |  from datetime import datetime, timedelta
 |  |  from datetime import datetime, timedelta
 | 
											
												
													
														|  |  from multiprocessing import get_logger
 |  |  from multiprocessing import get_logger
 | 
											
										
											
												
													
														|  | @@ -152,12 +152,17 @@ def apply(task, args, kwargs, **options):
 | 
											
												
													
														|  |      # for it to be callable.
 |  |      # for it to be callable.
 | 
											
												
													
														|  |      task = inspect.isclass(task) and task() or task
 |  |      task = inspect.isclass(task) and task() or task
 | 
											
												
													
														|  |  
 |  |  
 | 
											
												
													
														|  | -    kwargs.update({"task_name": task.name,
 |  | 
 | 
											
												
													
														|  | -                   "task_id": task_id,
 |  | 
 | 
											
												
													
														|  | -                   "task_retries": retries,
 |  | 
 | 
											
												
													
														|  | -                   "task_is_eager": True,
 |  | 
 | 
											
												
													
														|  | -                   "logfile": None,
 |  | 
 | 
											
												
													
														|  | -                   "loglevel": 0})
 |  | 
 | 
											
												
													
														|  | 
 |  | +    default_kwargs = {"task_name": task.name,
 | 
											
												
													
														|  | 
 |  | +                      "task_id": task_id,
 | 
											
												
													
														|  | 
 |  | +                      "task_retries": retries,
 | 
											
												
													
														|  | 
 |  | +                      "task_is_eager": True,
 | 
											
												
													
														|  | 
 |  | +                      "logfile": None,
 | 
											
												
													
														|  | 
 |  | +                      "loglevel": 0}
 | 
											
												
													
														|  | 
 |  | +    fun = getattr(task, "run", task)
 | 
											
												
													
														|  | 
 |  | +    supported_keys = fun_takes_kwargs(fun, default_kwargs)
 | 
											
												
													
														|  | 
 |  | +    extend_with = dict((key, val) for key, val in default_kwargs.items()
 | 
											
												
													
														|  | 
 |  | +                            if key in supported_keys)
 | 
											
												
													
														|  | 
 |  | +    kwargs.update(extend_with)
 | 
											
												
													
														|  |  
 |  |  
 | 
											
												
													
														|  |      try:
 |  |      try:
 | 
											
												
													
														|  |          ret_value = task(*args, **kwargs)
 |  |          ret_value = task(*args, **kwargs)
 |