@@ -6,16 +6,17 @@ import socket
 import threading
 import time
 
-from itertools import count
-
 from kombu.entity import Exchange, Queue
 from kombu.messaging import Consumer, Producer
 
 from celery import states
 from celery.exceptions import TimeoutError
+from celery.utils.log import get_logger
 
 from .base import BaseDictBackend
 
+logger = get_logger(__name__)
+
 
 class BacklogLimitExceeded(Exception):
     """Too much state history to fast-forward."""
@@ -39,6 +40,13 @@ class AMQPBackend(BaseDictBackend):
 
     supports_native_join = True
 
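+    #: Per-publish retry policy, in kombu's retry_policy format:
+    #: retry up to 20 times, starting immediately and backing off
+    #: to at most one second between attempts.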
+    retry_policy = {
+        "max_retries": 20,
+        "interval_start": 0,
+        "interval_step": 1,
+        "interval_max": 1,
+    }
+
     def __init__(self, connection=None, exchange=None, exchange_type=None,
                  persistent=None, serializer=None, auto_delete=True,
                  **kwargs):
@@ -83,19 +91,6 @@ class AMQPBackend(BaseDictBackend):
                           auto_delete=self.auto_delete,
                           queue_arguments=self.queue_arguments)
 
-    def _create_producer(self, task_id, connection):
-        self._create_binding(task_id)(connection.default_channel).declare()
-        return self.Producer(connection, exchange=self.exchange,
-                             routing_key=task_id.replace("-", ""),
-                             serializer=self.serializer)
-
-    def _create_consumer(self, bindings, channel):
-        return self.Consumer(channel, bindings, no_ack=True)
-
-    def _publish_result(self, connection, task_id, meta):
-        # cache single channel
-        self._create_producer(task_id, connection).publish(meta)
-
     def revive(self, channel):
         pass
 
@@ -104,27 +99,18 @@ class AMQPBackend(BaseDictBackend):
                       interval_max=1):
         """Send task return value and status."""
         with self.mutex:
-            with self.app.pool.acquire(block=True) as conn:
-
-                def errback(error, delay):
-                    print("Couldn't send result for %r: %r. Retry in %rs." % (
-                        task_id, error, delay))
-
-                send = conn.ensure(self, self._publish_result,
-                                   max_retries=max_retries,
-                                   errback=errback,
-                                   interval_start=interval_start,
-                                   interval_step=interval_step,
-                                   interval_max=interval_max)
-                send(conn, task_id, {"task_id": task_id, "status": status,
-                                     "result": self.encode_result(result, status),
-                                     "traceback": traceback,
-                                     "children": self.current_task_children()})
+            with self.app.amqp.producer_pool.acquire(block=True) as pub:
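+                # declare=[...] (re)declares the result queue before
+                # publishing, and retry/retry_policy let kombu retry the
+                # publish if the broker connection fails mid-send.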
+                pub.publish({"task_id": task_id, "status": status,
+                             "result": self.encode_result(result, status),
+                             "traceback": traceback,
+                             "children": self.current_task_children()},
+                            exchange=self.exchange,
+                            routing_key=task_id.replace("-", ""),
+                            serializer=self.serializer,
+                            retry=True, retry_policy=self.retry_policy,
+                            declare=[self._create_binding(task_id)])
         return result
 
-    def get_task_meta(self, task_id, cache=True):
-        return self.poll(task_id)
-
     def wait_for(self, task_id, timeout=None, cache=True, propagate=True,
                  **kwargs):
         cached_meta = self._cache.get(task_id)
@@ -147,23 +133,30 @@ class AMQPBackend(BaseDictBackend):
         else:
             return self.wait_for(task_id, timeout, cache)
 
-    def poll(self, task_id, backlog_limit=100):
+    def get_task_meta(self, task_id, backlog_limit=1000):
+        # Polling and using basic_get
         with self.app.pool.acquire_channel(block=True) as (_, channel):
             binding = self._create_binding(task_id)(channel)
             binding.declare()
             latest, acc = None, None
-            for i in count():  # fast-forward
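+            # fast-forward: drain any backlog of state messages so only
+            # the most recently published state is kept.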
+            for i in xrange(backlog_limit):
                 latest, acc = acc, binding.get(no_ack=True)
-                if not acc:
+                if not acc:  # no more messages
                     break
-                if i > backlog_limit:
-                    raise self.BacklogLimitExceeded(task_id)
+            else:
+                raise self.BacklogLimitExceeded(task_id)
+
             if latest:
+                # new state to report
                 payload = self._cache[task_id] = latest.payload
                 return payload
-            elif task_id in self._cache:  # use previously received state.
-                return self._cache[task_id]
-            return {"status": states.PENDING, "result": None}
+            else:
+                # no new state, use previous
+                try:
+                    return self._cache[task_id]
+                except KeyError:
+                    # result probably pending.
+                    return {"status": states.PENDING, "result": None}
 
     def drain_events(self, connection, consumer, timeout=None, now=time.time):
         wait = connection.drain_events
@@ -190,13 +183,12 @@ class AMQPBackend(BaseDictBackend):
     def consume(self, task_id, timeout=None):
         with self.app.pool.acquire_channel(block=True) as (conn, channel):
             binding = self._create_binding(task_id)
-            with self._create_consumer(binding, channel) as consumer:
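+            # no_ack=True: the broker discards state messages as soon as
+            # they are delivered, no explicit ack is sent back.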
+            with self.Consumer(channel, binding, no_ack=True) as consumer:
                 return self.drain_events(conn, consumer, timeout).values()[0]
 
     def get_many(self, task_ids, timeout=None, **kwargs):
         with self.app.pool.acquire_channel(block=True) as (conn, channel):
             ids = set(task_ids)
-            cached_ids = set()
             for task_id in ids:
                 try:
                     cached = self._cache[task_id]
@@ -205,11 +197,10 @@ class AMQPBackend(BaseDictBackend):
                 else:
                     if cached["status"] in states.READY_STATES:
                         yield task_id, cached
-                        cached_ids.add(task_id)
+                        ids.discard(task_id)
 
-            ids ^= cached_ids
             bindings = [self._create_binding(task_id) for task_id in task_ids]
-            with self._create_consumer(bindings, channel) as consumer:
+            with self.Consumer(channel, bindings, no_ack=True) as consumer:
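+                # keep draining events until every remaining id has
+                # reported a result.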
                 while ids:
                     r = self.drain_events(conn, consumer, timeout)
                     ids ^= set(r)
@@ -238,12 +229,11 @@ class AMQPBackend(BaseDictBackend):
                 "delete_taskset is not supported by this backend.")
 
     def __reduce__(self, args=(), kwargs={}):
-        kwargs.update(
-            dict(connection=self._connection,
-                 exchange=self.exchange.name,
-                 exchange_type=self.exchange.type,
-                 persistent=self.persistent,
-                 serializer=self.serializer,
-                 auto_delete=self.auto_delete,
-                 expires=self.expires))
+        kwargs.update(connection=self._connection,
+                      exchange=self.exchange.name,
+                      exchange_type=self.exchange.type,
+                      persistent=self.persistent,
+                      serializer=self.serializer,
+                      auto_delete=self.auto_delete,
+                      expires=self.expires)
         return super(AMQPBackend, self).__reduce__(args, kwargs)