Browse Source

Make events.state pickleable again

Ask Solem 11 năm trước cách đây
mục cha
commit
3bc777379e
1 tập tin đã thay đổi với 32 bổ sung6 xóa
  1. 32 6
      celery/events/state.py

+ 32 - 6
celery/events/state.py

@@ -22,14 +22,14 @@ import sys
 import threading
 import threading
 
 
 from datetime import datetime
 from datetime import datetime
-from heapq import heappush, heappop
+from heapq import heapify, heappush, heappop
 from itertools import islice
 from itertools import islice
 from operator import itemgetter
 from operator import itemgetter
 from time import time
 from time import time
 from weakref import ref
 from weakref import ref
 
 
 from kombu.clocks import timetuple
 from kombu.clocks import timetuple
-from kombu.utils import cached_property
+from kombu.utils import cached_property, kwdict
 
 
 from celery import states
 from celery import states
 from celery.five import class_property, items, values
 from celery.five import class_property, items, values
@@ -70,6 +70,10 @@ def heartbeat_expires(timestamp, freq=60,
     return timestamp + freq * (expire_window / 1e2)
     return timestamp + freq * (expire_window / 1e2)
 
 
 
 
+def _depickle_task(cls, fields):
+    return cls(**(fields if CAN_KWDICT else kwdict(fields)))
+
+
 def with_unique_field(attr):
 def with_unique_field(attr):
 
 
     def _decorate_cls(cls):
     def _decorate_cls(cls):
@@ -104,14 +108,19 @@ class Worker(object):
     if not PYPY:
     if not PYPY:
         __slots__ = _fields + ('event', '__dict__', '__weakref__')
         __slots__ = _fields + ('event', '__dict__', '__weakref__')
 
 
-    def __init__(self, hostname=None, pid=None, freq=60):
+    def __init__(self, hostname=None, pid=None, freq=60,
+                 heartbeats=None, clock=0):
         self.hostname = hostname
         self.hostname = hostname
         self.pid = pid
         self.pid = pid
         self.freq = freq
         self.freq = freq
-        self.heartbeats = []
-        self.clock = 0
+        self.heartbeats = [] if heartbeats is None else heartbeats
+        self.clock = clock or 0
         self.event = self._create_event_handler()
         self.event = self._create_event_handler()
 
 
+    def __reduce__(self):
+        return self.__class__, (self.hostname, self.pid, self.freq,
+                                self.heartbeats, self.clock)
+
     def _create_event_handler(self):
     def _create_event_handler(self):
         _set = object.__setattr__
         _set = object.__setattr__
         heartbeats = self.heartbeats
         heartbeats = self.heartbeats
@@ -284,6 +293,15 @@ class Task(object):
     def __repr__(self):
     def __repr__(self):
         return R_TASK.format(self)
         return R_TASK.format(self)
 
 
+    def as_dict(self):
+        get = object.__getattribute__
+        return dict(
+            (k, get(self, k)) for k in self._fields
+        )
+
+    def __reduce__(self):
+        return _depickle_task, (self.__class__, self.as_dict())
+
     @property
     @property
     def ready(self):
     def ready(self):
         return self.state in states.READY_STATES
         return self.state in states.READY_STATES
@@ -361,6 +379,7 @@ class State(object):
         self._mutex = threading.Lock()
         self._mutex = threading.Lock()
         self.handlers = {}
         self.handlers = {}
         self._seen_types = set()
         self._seen_types = set()
+        self.rebuild_taskheap()
 
 
     @cached_property
     @cached_property
     def _event(self):
     def _event(self):
@@ -503,6 +522,13 @@ class State(object):
                 return created
                 return created
         return _event
         return _event
 
 
+    def rebuild_taskheap(self, timetuple=timetuple, heapify=heapify):
+        heap = self._taskheap[:] = [
+            timetuple(t.clock, t.timestamp, t.worker.id, ref(t))
+            for t in values(self.tasks)
+        ]
+        heapify(heap)
+
     def itertasks(self, limit=None):
     def itertasks(self, limit=None):
         for index, row in enumerate(items(self.tasks)):
         for index, row in enumerate(items(self.tasks)):
             yield row
             yield row
@@ -557,6 +583,6 @@ class State(object):
 
 
     def __reduce__(self):
     def __reduce__(self):
         return self.__class__, (
         return self.__class__, (
-            self.event_callback, self.workers, self.tasks, self._taskheap,
+            self.event_callback, self.workers, self.tasks, None,
             self.max_workers_in_memory, self.max_tasks_in_memory,
             self.max_workers_in_memory, self.max_tasks_in_memory,
         )
         )