소스 검색

Signals: Sender argument must account for Proxy/PromiseProxy. Issue #1873

Ask Solem 11 년 전
부모
커밋
c403d0be11
3개의 변경된 파일70개의 추가작업 그리고 1개의 파일을 삭제
  1. 30 1
      celery/local.py
  2. 24 0
      celery/tests/utils/test_local.py
  3. 16 0
      celery/utils/dispatch/signal.py

+ 30 - 1
celery/local.py

@@ -157,7 +157,7 @@ class Proxy(object):
     __setattr__ = lambda x, n, v: setattr(x._get_current_object(), n, v)
     __delattr__ = lambda x, n: delattr(x._get_current_object(), n)
     __str__ = lambda x: str(x._get_current_object())
-    __lt__ = lambda x, o: x._get_current_object() < o
+    __lt_ = lambda x, o: x._get_current_object() < o
     __le__ = lambda x, o: x._get_current_object() <= o
     __eq__ = lambda x, o: x._get_current_object() == o
     __ne__ = lambda x, o: x._get_current_object() != o
@@ -212,12 +212,27 @@ class PromiseProxy(Proxy):
 
     """
 
+    __slots__ = ('__pending__', )
+
     def _get_current_object(self):
         try:
             return object.__getattribute__(self, '__thing')
         except AttributeError:
             return self.__evaluate__()
 
+    def __then__(self, fun, *args, **kwargs):
+        if self.__evaluated__():
+            return fun(*args, **kwargs)
+        from collections import deque
+        try:
+            pending = object.__getattribute__(self, '__pending__')
+        except AttributeError:
+            pending = None
+        if pending is None:
+            pending = deque()
+            object.__setattr__(self, '__pending__', pending)
+        pending.append((fun, args, kwargs))
+
     def __evaluated__(self):
         try:
             object.__getattribute__(self, '__thing')
@@ -243,6 +258,20 @@ class PromiseProxy(Proxy):
                 except AttributeError:  # pragma: no cover
                     # May mask errors so ignore
                     pass
+            try:
+                pending = object.__getattribute__(self, '__pending__')
+            except AttributeError:
+                pass
+            else:
+                try:
+                    while pending:
+                        fun, args, kwargs = pending.popleft()
+                        fun(*args, **kwargs)
+                finally:
+                    try:
+                        object.__delattr__(self, '__pending__')
+                    except AttributeError:
+                        pass
 
 
 def maybe_evaluate(obj):

+ 24 - 0
celery/tests/utils/test_local.py

@@ -329,6 +329,30 @@ class test_PromiseProxy(Case):
         self.assertEqual(p.attr, 123)
         self.assertEqual(X.evals, 1)
 
+    def test_callbacks(self):
+        source = Mock(name='source')
+        p = PromiseProxy(source)
+        cbA = Mock(name='cbA')
+        cbB = Mock(name='cbB')
+        cbC = Mock(name='cbC')
+        p.__then__(cbA, p)
+        p.__then__(cbB, p)
+        self.assertFalse(p.__evaluated__())
+        self.assertTrue(object.__getattribute__(p, '__pending__'))
+
+        self.assertTrue(repr(p))
+        with self.assertRaises(AttributeError):
+            object.__getattribute__(p, '__pending__')
+        cbA.assert_called_with(p)
+        cbB.assert_called_with(p)
+
+        self.assertTrue(p.__evaluated__())
+        p.__then__(cbC, p)
+        cbC.assert_called_with(p)
+
+        with self.assertRaises(AttributeError):
+            object.__getattribute__(p, '__pending__')
+
     def test_maybe_evaluate(self):
         x = PromiseProxy(lambda: 30)
         self.assertFalse(x.__evaluated__())

+ 16 - 0
celery/utils/dispatch/signal.py

@@ -4,7 +4,9 @@ from __future__ import absolute_import
 
 import weakref
 from . import saferef
+
 from celery.five import range
+from celery.local import PromiseProxy, Proxy
 
 __all__ = ['Signal']
 
@@ -12,6 +14,8 @@ WEAKREF_TYPES = (weakref.ReferenceType, saferef.BoundMethodWeakref)
 
 
 def _make_id(target):  # pragma: no cover
+    if isinstance(target, Proxy):
+        target = target._get_current_object()
     if hasattr(target, '__func__'):
         return (id(target.__self__), id(target.__func__))
     return id(target)
@@ -39,6 +43,12 @@ class Signal(object):  # pragma: no cover
             providing_args = []
         self.providing_args = set(providing_args)
 
+    def _connect_proxy(self, fun, sender, weak, dispatch_uid):
+        return self.connect(
+            fun, sender=sender._get_current_object(), weak=weak,
+                dispatch_uid=dispatch_uid,
+        )
+
     def connect(self, *args, **kwargs):
         """Connect receiver to sender for signal.
 
@@ -74,6 +84,12 @@ class Signal(object):  # pragma: no cover
             def _connect_signal(fun):
                 receiver = fun
 
+                if isinstance(sender, PromiseProxy):
+                    sender.__then__(
+                        self._connect_proxy, fun, sender, weak, dispatch_uid,
+                    )
+                    return fun
+
                 if dispatch_uid:
                     lookup_key = (dispatch_uid, _make_id(sender))
                 else: