Bladeren bron

Fix BytesWarning in concurrency/asynpool.py (#4846)

When Python is executed with the -b CLI option, Celery issues the
following warnings:

.../celery/concurrency/asynpool.py:1254: BytesWarning: Comparison between bytes and string
  header = pack(b'>I', size)
.../celery/concurrency/asynpool.py:814: BytesWarning: Comparison between bytes and string
  header = pack(b'>I', body_size)
.../celery/concurrency/asynpool.py:76: BytesWarning: Comparison between bytes and string
  return unpack(fmt, iobuf.getvalue())  # <-- BytesIO

This occurs due to passing a bytes object to the fmt argument of struct
functions. The solution was borrowed from py-amqp:
https://github.com/celery/py-amqp/pull/117#issuecomment-267181264

For a discussion on passing str instead of bytes to struct functions,
see: https://github.com/python/typeshed/pull/669

Information on the -b CLI option:
https://docs.python.org/3/using/cmdline.html#miscellaneous-options

  -b

  Issue a warning when comparing bytes or bytearray with str or bytes
  with int. Issue an error when the option is given twice (-bb).
Jon Dufresne 6 jaren geleden
bovenliggende
commit
a7c741d78f
2 gewijzigde bestanden met toevoegingen van 28 en 11 verwijderingen
  1. 9 11
      celery/concurrency/asynpool.py
  2. 19 0
      celery/platforms.py

+ 9 - 11
celery/concurrency/asynpool.py

@@ -20,7 +20,6 @@ import gc
 import os
 import select
 import socket
-import struct
 import sys
 import time
 from collections import deque, namedtuple
@@ -41,6 +40,7 @@ from kombu.utils.functional import fxrange
 from vine import promise
 
 from celery.five import Counter, items, values
+from celery.platforms import pack, unpack, unpack_from
 from celery.utils.functional import noop
 from celery.utils.log import get_logger
 from celery.worker import state as worker_state
@@ -52,13 +52,11 @@ try:
     from _billiard import read as __read__
     readcanbuf = True
 
+    # unpack_from supports memoryview in 2.7.6 and 3.3+
     if sys.version_info[0] == 2 and sys.version_info < (2, 7, 6):
 
-        def unpack_from(fmt, view, _unpack_from=struct.unpack_from):  # noqa
+        def unpack_from(fmt, view, _unpack_from=unpack_from):  # noqa
             return _unpack_from(fmt, view.tobytes())  # <- memoryview
-    else:
-        # unpack_from supports memoryview in 2.7.6 and 3.3+
-        unpack_from = struct.unpack_from  # noqa
 
 except ImportError:  # pragma: no cover
 
@@ -70,7 +68,7 @@ except ImportError:  # pragma: no cover
         return n
     readcanbuf = False  # noqa
 
-    def unpack_from(fmt, iobuf, unpack=struct.unpack):  # noqa
+    def unpack_from(fmt, iobuf, unpack=unpack):  # noqa
         return unpack(fmt, iobuf.getvalue())  # <-- BytesIO
 
 __all__ = ('AsynPool',)
@@ -252,7 +250,7 @@ class ResultHandler(_pool.ResultHandler):
                            else EOFError())
                 Hr += n
 
-        body_size, = unpack_from(b'>i', bufv)
+        body_size, = unpack_from('>i', bufv)
         if readcanbuf:
             buf = bytearray(body_size)
             bufv = memoryview(buf)
@@ -658,7 +656,7 @@ class AsynPool(_pool.Pool):
         self.on_process_down = on_process_down
 
     def _create_write_handlers(self, hub,
-                               pack=struct.pack, dumps=_pickle.dumps,
+                               pack=pack, dumps=_pickle.dumps,
                                protocol=HIGHEST_PROTOCOL):
         """Create handlers used to write data to child processes."""
         fileno_to_inq = self._fileno_to_inq
@@ -820,7 +818,7 @@ class AsynPool(_pool.Pool):
             # inqueues are writable.
             body = dumps(tup, protocol=protocol)
             body_size = len(body)
-            header = pack(b'>I', body_size)
+            header = pack('>I', body_size)
             # index 1,0 is the job ID.
             job = get_job(tup[1][0])
             job._payload = buf_t(header), buf_t(body), body_size
@@ -1255,11 +1253,11 @@ class AsynPool(_pool.Pool):
         return removed
 
     def _create_payload(self, type_, args,
-                        dumps=_pickle.dumps, pack=struct.pack,
+                        dumps=_pickle.dumps, pack=pack,
                         protocol=HIGHEST_PROTOCOL):
         body = dumps((type_, args), protocol=protocol)
         size = len(body)
-        header = pack(b'>I', size)
+        header = pack('>I', size)
         return header, body, size
 
     @classmethod

+ 19 - 0
celery/platforms.py

@@ -13,6 +13,7 @@ import numbers
 import os
 import platform as _platform
 import signal as _signal
+import struct
 import sys
 import warnings
 from collections import namedtuple
@@ -795,3 +796,21 @@ def check_privileges(accept_content):
         warnings.warn(RuntimeWarning(ROOT_DISCOURAGED.format(
             uid=uid, euid=euid, gid=gid, egid=egid,
         )))
+
+
+if sys.version_info < (2, 7, 7):  # pragma: no cover
+    import functools
+
+    def _to_bytes_arg(fun):
+        @functools.wraps(fun)
+        def _inner(s, *args, **kwargs):
+            return fun(s.encode(), *args, **kwargs)
+        return _inner
+
+    pack = _to_bytes_arg(struct.pack)
+    unpack = _to_bytes_arg(struct.unpack)
+    unpack_from = _to_bytes_arg(struct.unpack_from)
+else:
+    pack = struct.pack
+    unpack = struct.unpack
+    unpack_from = struct.unpack_from