debug.py 4.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198
  1. # -*- coding: utf-8 -*-
  2. """Utilities for debugging memory usage, blocking calls, etc."""
  3. from __future__ import absolute_import, print_function, unicode_literals
  4. import os
  5. import sys
  6. import traceback
  7. from contextlib import contextmanager
  8. from functools import partial
  9. from pprint import pprint
  10. from celery.five import WhateverIO, items, range
  11. from celery.platforms import signals
  12. try:
  13. from psutil import Process
  14. except ImportError:
  15. Process = None # noqa
  16. __all__ = [
  17. 'blockdetection', 'sample_mem', 'memdump', 'sample',
  18. 'humanbytes', 'mem_rss', 'ps', 'cry',
  19. ]
  20. UNITS = (
  21. (2 ** 40.0, 'TB'),
  22. (2 ** 30.0, 'GB'),
  23. (2 ** 20.0, 'MB'),
  24. (2 ** 10.0, 'KB'),
  25. (0.0, 'b'),
  26. )
  27. _process = None
  28. _mem_sample = []
  29. def _on_blocking(signum, frame):
  30. import inspect
  31. raise RuntimeError(
  32. 'Blocking detection timed-out at: {0}'.format(
  33. inspect.getframeinfo(frame)
  34. )
  35. )
  36. @contextmanager
  37. def blockdetection(timeout):
  38. """Context that raises an exception if process is blocking.
  39. Uses ``SIGALRM`` to detect blocking functions.
  40. """
  41. if not timeout:
  42. yield
  43. else:
  44. old_handler = signals['ALRM']
  45. old_handler = None if old_handler == _on_blocking else old_handler
  46. signals['ALRM'] = _on_blocking
  47. try:
  48. yield signals.arm_alarm(timeout)
  49. finally:
  50. if old_handler:
  51. signals['ALRM'] = old_handler
  52. signals.reset_alarm()
  53. def sample_mem():
  54. """Sample RSS memory usage.
  55. Statistics can then be output by calling :func:`memdump`.
  56. """
  57. current_rss = mem_rss()
  58. _mem_sample.append(current_rss)
  59. return current_rss
  60. def _memdump(samples=10): # pragma: no cover
  61. S = _mem_sample
  62. prev = list(S) if len(S) <= samples else sample(S, samples)
  63. _mem_sample[:] = []
  64. import gc
  65. gc.collect()
  66. after_collect = mem_rss()
  67. return prev, after_collect
  68. def memdump(samples=10, file=None): # pragma: no cover
  69. """Dump memory statistics.
  70. Will print a sample of all RSS memory samples added by
  71. calling :func:`sample_mem`, and in addition print
  72. used RSS memory after :func:`gc.collect`.
  73. """
  74. say = partial(print, file=file)
  75. if ps() is None:
  76. say('- rss: (psutil not installed).')
  77. return
  78. prev, after_collect = _memdump(samples)
  79. if prev:
  80. say('- rss (sample):')
  81. for mem in prev:
  82. say('- > {0},'.format(mem))
  83. say('- rss (end): {0}.'.format(after_collect))
  84. def sample(x, n, k=0):
  85. """Given a list `x` a sample of length ``n`` of that list is returned.
  86. For example, if `n` is 10, and `x` has 100 items, a list of every tenth.
  87. item is returned.
  88. ``k`` can be used as offset.
  89. """
  90. j = len(x) // n
  91. for _ in range(n):
  92. try:
  93. yield x[k]
  94. except IndexError:
  95. break
  96. k += j
  97. def hfloat(f, p=5):
  98. """Convert float to value suitable for humans.
  99. Arguments:
  100. f (float): The floating point number.
  101. p (int): Floating point precision (default is 5).
  102. """
  103. i = int(f)
  104. return i if i == f else '{0:.{p}}'.format(f, p=p)
  105. def humanbytes(s):
  106. """Convert bytes to human-readable form (e.g., KB, MB)."""
  107. return next(
  108. '{0}{1}'.format(hfloat(s / div if div else s), unit)
  109. for div, unit in UNITS if s >= div
  110. )
  111. def mem_rss():
  112. """Return RSS memory usage as a humanized string."""
  113. p = ps()
  114. if p is not None:
  115. return humanbytes(_process_memory_info(p).rss)
  116. def ps(): # pragma: no cover
  117. """Return the global :class:`psutil.Process` instance.
  118. Note:
  119. Returns :const:`None` if :pypi:`psutil` is not installed.
  120. """
  121. global _process
  122. if _process is None and Process is not None:
  123. _process = Process(os.getpid())
  124. return _process
  125. def _process_memory_info(process):
  126. try:
  127. return process.memory_info()
  128. except AttributeError:
  129. return process.get_memory_info()
  130. def cry(out=None, sepchr='=', seplen=49): # pragma: no cover
  131. """Return stack-trace of all active threads.
  132. See Also:
  133. Taken from https://gist.github.com/737056.
  134. """
  135. import threading
  136. out = WhateverIO() if out is None else out
  137. P = partial(print, file=out)
  138. # get a map of threads by their ID so we can print their names
  139. # during the traceback dump
  140. tmap = {t.ident: t for t in threading.enumerate()}
  141. sep = sepchr * seplen
  142. for tid, frame in items(sys._current_frames()):
  143. thread = tmap.get(tid)
  144. if not thread:
  145. # skip old junk (left-overs from a fork)
  146. continue
  147. P('{0.name}'.format(thread))
  148. P(sep)
  149. traceback.print_stack(frame, file=out)
  150. P(sep)
  151. P('LOCAL VARIABLES')
  152. P(sep)
  153. pprint(frame.f_locals, stream=out)
  154. P('\n')
  155. return out.getvalue()