|
@@ -2,6 +2,7 @@ import os
|
|
|
import sys
|
|
|
|
|
|
from datetime import timedelta
|
|
|
+from itertools import chain
|
|
|
|
|
|
from celery import routes
|
|
|
from celery.app.defaults import DEFAULTS
|
|
@@ -10,6 +11,55 @@ from celery.utils import noop, isatty
|
|
|
from celery.utils.functional import wraps
|
|
|
|
|
|
|
|
|
+class MultiDictView(AttributeDict):
|
|
|
+ """View for one more more dicts.
|
|
|
+
|
|
|
+ * When getting a key, the dicts are searched in order.
|
|
|
+ * When setting a key, the key is added to the first dict.
|
|
|
+
|
|
|
+ """
|
|
|
+ dicts = None
|
|
|
+
|
|
|
+ def __init__(self, *dicts):
|
|
|
+ self.__dict__["dicts"] = dicts
|
|
|
+
|
|
|
+ def __getitem__(self, key):
|
|
|
+ for d in self.__dict__["dicts"]:
|
|
|
+ try:
|
|
|
+ return d[key]
|
|
|
+ except KeyError:
|
|
|
+ pass
|
|
|
+ raise KeyError(key)
|
|
|
+
|
|
|
+ def __setitem__(self, key, value):
|
|
|
+ self.__dict__["dicts"][0][key] = value
|
|
|
+
|
|
|
+ def get(self, key, default=None):
|
|
|
+ try:
|
|
|
+ return self[key]
|
|
|
+ except KeyError:
|
|
|
+ return default
|
|
|
+
|
|
|
+ def setdefault(self, key, default):
|
|
|
+ try:
|
|
|
+ return self[key]
|
|
|
+ except KeyError:
|
|
|
+ self[key] = default
|
|
|
+ return default
|
|
|
+
|
|
|
+ def __contains__(self, key):
|
|
|
+ for d in self.__dict__["dicts"]:
|
|
|
+ if key in d:
|
|
|
+ return True
|
|
|
+ return False
|
|
|
+
|
|
|
+ def __repr__(self):
|
|
|
+ return repr(dict(iter(self)))
|
|
|
+
|
|
|
+ def __iter__(self):
|
|
|
+ return chain(*[d.iteritems() for d in self.__dict__["dicts"]])
|
|
|
+
|
|
|
+
|
|
|
class BaseApp(object):
|
|
|
_amqp = None
|
|
|
_backend = None
|
|
@@ -30,14 +80,13 @@ class BaseApp(object):
|
|
|
|
|
|
def merge(self, a, b):
|
|
|
"""Like ``dict(a, **b)`` except it will keep values from ``a``
|
|
|
- if the value in ``b`` is :const:`None`""".
|
|
|
+ if the value in ``b`` is :const:`None`."""
|
|
|
b = dict(b)
|
|
|
for key, value in a.items():
|
|
|
if b.get(key) is None:
|
|
|
b[key] = value
|
|
|
return b
|
|
|
|
|
|
-
|
|
|
def AsyncResult(self, task_id, backend=None):
|
|
|
from celery.result import BaseAsyncResult
|
|
|
return BaseAsyncResult(task_id, app=self,
|
|
@@ -47,7 +96,6 @@ class BaseApp(object):
|
|
|
from celery.result import TaskSetResult
|
|
|
return TaskSetResult(taskset_id, results, app=self)
|
|
|
|
|
|
-
|
|
|
def send_task(self, name, args=None, kwargs=None, countdown=None,
|
|
|
eta=None, task_id=None, publisher=None, connection=None,
|
|
|
connect_timeout=None, result_cls=None, expires=None,
|
|
@@ -117,12 +165,13 @@ class BaseApp(object):
|
|
|
|
|
|
def pre_config_merge(self, c):
|
|
|
if not c.get("CELERY_RESULT_BACKEND"):
|
|
|
- c["CELERY_RESULT_BACKEND"] = c.get("CELERY_BACKEND")
|
|
|
+ rbackend = c.get("CELERY_BACKEND")
|
|
|
+ if rbackend:
|
|
|
+ c["CELERY_RESULT_BACKEND"] = backend
|
|
|
if not c.get("BROKER_BACKEND"):
|
|
|
- c["BROKER_BACKEND"] = c.get("BROKER_TRANSPORT") or \
|
|
|
- c.get("CARROT_BACKEND")
|
|
|
- c.setdefault("CELERY_SEND_TASK_ERROR_EMAILS",
|
|
|
- c.get("SEND_CELERY_TASK_ERROR_EMAILS"))
|
|
|
+ cbackend = c.get("BROKER_TRANSPORT") or c.get("CARROT_BACKEND")
|
|
|
+ if cbackend:
|
|
|
+ c["BROKER_BACKEND"] = cbackend
|
|
|
return c
|
|
|
|
|
|
def post_config_merge(self, c):
|
|
@@ -185,7 +234,7 @@ class BaseApp(object):
|
|
|
if self._conf is None:
|
|
|
config = self.pre_config_merge(self.loader.conf)
|
|
|
self._conf = self.post_config_merge(
|
|
|
- AttributeDict(DEFAULTS, **config))
|
|
|
+ MultiDictView(config, DEFAULTS))
|
|
|
return self._conf
|
|
|
|
|
|
@property
|