Browse Source

Some json implementation may give decimal.Decimal, not float. Closes #1731

Ask Solem 11 years ago
parent
commit
7f0981d5b2
2 changed files with 21 additions and 2 deletions
  1. 9 2
      celery/events/state.py
  2. 12 0
      celery/tests/events/test_state.py

+ 9 - 2
celery/events/state.py

@@ -22,6 +22,7 @@ import sys
 import threading
 
 from datetime import datetime
+from decimal import Decimal
 from heapq import heapify, heappush, heappop
 from itertools import islice
 from operator import itemgetter
@@ -66,8 +67,14 @@ __all__ = ['Worker', 'Task', 'State', 'heartbeat_expires']
 
 
 def heartbeat_expires(timestamp, freq=60,
-                      expire_window=HEARTBEAT_EXPIRE_WINDOW):
-    return timestamp + freq * (expire_window / 1e2)
+                      expire_window=HEARTBEAT_EXPIRE_WINDOW,
+                      Decimal=Decimal, float=float, isinstance=isinstance):
+    # some json implementations returns decimal.Decimal objects,
+    # which are not compatible with float.
+    freq = float(freq) if isinstance(freq, Decimal) else freq
+    if isinstance(timestamp, Decimal):
+        timestamp = float(timestamp)
+    return timestamp + (freq * (expire_window / 1e2))
 
 
 def _depickle_task(cls, fields):

+ 12 - 0
celery/tests/events/test_state.py

@@ -2,6 +2,7 @@ from __future__ import absolute_import
 
 import pickle
 
+from decimal import Decimal
 from random import shuffle
 from time import time
 from itertools import count
@@ -171,6 +172,17 @@ class test_Worker(AppCase):
             hash(Worker(hostname='foo')), hash(Worker(hostname='bar')),
         )
 
+    def test_compatible_with_Decimal(self):
+        w = Worker('george@vandelay.com')
+        timestamp, local_received = Decimal(time()), time()
+        w.event('worker-online', timestamp, local_received, fields={
+            'hostname': 'george@vandelay.com',
+            'timestamp': timestamp,
+            'local_received': local_received,
+            'freq': Decimal(5.6335431),
+        })
+        self.assertTrue(w.alive)
+
     def test_survives_missing_timestamp(self):
         worker = Worker(hostname='foo')
         worker.event('heartbeat')