Browse Source

Use all redis pipelines as context managers to ensure that they are always cleaned up properly, especially in the case of exceptions

Justin Patrin 9 years ago
parent
commit
c7bf57098a
1 changed files with 21 additions and 19 deletions
  1. 21 19
      celery/backends/redis.py

+ 21 - 19
celery/backends/redis.py

@@ -160,13 +160,13 @@ class RedisBackend(KeyValueStoreBackend):
         return self.ensure(self._set, (key, value), **retry_policy)
         return self.ensure(self._set, (key, value), **retry_policy)
 
 
     def _set(self, key, value):
     def _set(self, key, value):
-        pipe = self.client.pipeline()
-        if self.expires:
-            pipe.setex(key, value, self.expires)
-        else:
-            pipe.set(key, value)
-        pipe.publish(key, value)
-        pipe.execute()
+        with self.client.pipeline() as pipe:
+            if self.expires:
+                pipe.setex(key, value, self.expires)
+            else:
+                pipe.set(key, value)
+            pipe.publish(key, value)
+            pipe.execute()
 
 
     def delete(self, key):
     def delete(self, key):
         self.client.delete(key)
         self.client.delete(key)
@@ -207,13 +207,14 @@ class RedisBackend(KeyValueStoreBackend):
         jkey = self.get_key_for_group(gid, '.j')
         jkey = self.get_key_for_group(gid, '.j')
         tkey = self.get_key_for_group(gid, '.t')
         tkey = self.get_key_for_group(gid, '.t')
         result = self.encode_result(result, state)
         result = self.encode_result(result, state)
-        _, readycount, totaldiff, _, _ = client.pipeline()              \
-            .rpush(jkey, self.encode([1, tid, state, result]))          \
-            .llen(jkey)                                                 \
-            .get(tkey)                                                  \
-            .expire(jkey, 86400)                                        \
-            .expire(tkey, 86400)                                        \
-            .execute()
+        with client.pipeline() as pipe:
+            _, readycount, totaldiff, _, _ = pipe                           \
+                .rpush(jkey, self.encode([1, tid, state, result]))          \
+                .llen(jkey)                                                 \
+                .get(tkey)                                                  \
+                .expire(jkey, 86400)                                        \
+                .expire(tkey, 86400)                                        \
+                .execute()
 
 
         totaldiff = int(totaldiff or 0)
         totaldiff = int(totaldiff or 0)
 
 
@@ -222,11 +223,12 @@ class RedisBackend(KeyValueStoreBackend):
             total = callback['chord_size'] + totaldiff
             total = callback['chord_size'] + totaldiff
             if readycount == total:
             if readycount == total:
                 decode, unpack = self.decode, self._unpack_chord_result
                 decode, unpack = self.decode, self._unpack_chord_result
-                resl, _, _ = client.pipeline()  \
-                    .lrange(jkey, 0, total)     \
-                    .delete(jkey)               \
-                    .delete(tkey)               \
-                    .execute()
+                with client.pipeline() as pipe:
+                    resl, _, _ = pipe               \
+                        .lrange(jkey, 0, total)     \
+                        .delete(jkey)               \
+                        .delete(tkey)               \
+                        .execute()
                 try:
                 try:
                     callback.delay([unpack(tup, decode) for tup in resl])
                     callback.delay([unpack(tup, decode) for tup in resl])
                 except Exception as exc:
                 except Exception as exc: