debug.py 5.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206
  1. # -*- coding: utf-8 -*-
  2. """Utilities for debugging memory usage, blocking calls, etc."""
  3. import os
  4. import sys
  5. import traceback
  6. from contextlib import contextmanager
  7. from functools import partial
  8. from io import StringIO
  9. from pprint import pprint
  10. from typing import (
  11. Any, AnyStr, Generator, IO, Iterator, Iterable,
  12. Optional, Sequence, SupportsInt, Tuple, Union,
  13. )
  14. from typing import MutableSequence # noqa
  15. from celery.platforms import signals
  16. from .typing import Timeout
  17. try:
  18. from psutil import Process
  19. except ImportError:
  20. class Process: # noqa
  21. pass
# Public names exported by a star-import of this module.
__all__ = [
    'blockdetection', 'sample_mem', 'memdump', 'sample',
    'humanbytes', 'mem_rss', 'ps', 'cry',
]

# ``(divisor, suffix)`` pairs ordered from the largest unit down to bytes.
# ``humanbytes`` picks the first pair whose divisor does not exceed the value;
# the final ``0.0`` divisor marks plain bytes, which are not scaled.
UNITS: Sequence[Tuple[float, str]] = (
    (2 ** 40.0, 'TB'),
    (2 ** 30.0, 'GB'),
    (2 ** 20.0, 'MB'),
    (2 ** 10.0, 'KB'),
    (0.0, 'b'),
)

# Process handle cached by ``ps()``; stays ``None`` until it is first created.
_process: Process = None
# RSS readings appended by ``sample_mem()`` and drained by ``_memdump()``.
_mem_sample: MutableSequence[str] = []
  35. def _on_blocking(signum: int, frame: Any) -> None:
  36. import inspect
  37. raise RuntimeError(
  38. 'Blocking detection timed-out at: {0}'.format(
  39. inspect.getframeinfo(frame)
  40. )
  41. )
  42. @contextmanager
  43. def blockdetection(timeout: Timeout) -> Generator:
  44. """Context that raises an exception if process is blocking.
  45. Uses ``SIGALRM`` to detect blocking functions.
  46. """
  47. if not timeout:
  48. yield
  49. else:
  50. old_handler = signals['ALRM']
  51. old_handler = None if old_handler == _on_blocking else old_handler
  52. signals['ALRM'] = _on_blocking
  53. try:
  54. yield signals.arm_alarm(timeout)
  55. finally:
  56. if old_handler:
  57. signals['ALRM'] = old_handler
  58. signals.reset_alarm()
  59. def sample_mem() -> str:
  60. """Sample RSS memory usage.
  61. Statistics can then be output by calling :func:`memdump`.
  62. """
  63. current_rss = mem_rss()
  64. _mem_sample.append(current_rss)
  65. return current_rss
  66. def _memdump(samples: int=10) -> Tuple[Iterable[Any], str]: # pragma: no cover
  67. S = _mem_sample
  68. prev = list(S) if len(S) <= samples else sample(S, samples)
  69. _mem_sample[:] = []
  70. import gc
  71. gc.collect()
  72. after_collect = mem_rss()
  73. return prev, after_collect
  74. def memdump(samples: int=10, file: IO=None) -> None: # pragma: no cover
  75. """Dump memory statistics.
  76. Will print a sample of all RSS memory samples added by
  77. calling :func:`sample_mem`, and in addition print
  78. used RSS memory after :func:`gc.collect`.
  79. """
  80. say = partial(print, file=file)
  81. if ps() is None:
  82. say('- rss: (psutil not installed).')
  83. return
  84. prev, after_collect = _memdump(samples)
  85. if prev:
  86. say('- rss (sample):')
  87. for mem in prev:
  88. say('- > {0},'.format(mem))
  89. say('- rss (end): {0}.'.format(after_collect))
  90. def sample(x: Sequence, n: int, k: int=0) -> Iterator[Any]:
  91. """Given a list `x` a sample of length ``n`` of that list is returned.
  92. For example, if `n` is 10, and `x` has 100 items, a list of every tenth.
  93. item is returned.
  94. ``k`` can be used as offset.
  95. """
  96. j = len(x) // n
  97. for _ in range(n):
  98. try:
  99. yield x[k]
  100. except IndexError:
  101. break
  102. k += j
  103. def hfloat(f: Union[SupportsInt, AnyStr], p: int=5) -> str:
  104. """Convert float to value suitable for humans.
  105. Arguments:
  106. f (float): The floating point number.
  107. p (int): Floating point precision (default is 5).
  108. """
  109. i = int(f)
  110. return str(i) if i == f else '{0:.{p}}'.format(f, p=p)
  111. def humanbytes(s: Union[float, int]) -> str:
  112. """Convert bytes to human-readable form (e.g., KB, MB)."""
  113. return next(
  114. '{0}{1}'.format(hfloat(s / div if div else s), unit)
  115. for div, unit in UNITS if s >= div
  116. )
  117. def mem_rss() -> str:
  118. """Return RSS memory usage as a humanized string."""
  119. p = ps()
  120. if p is not None:
  121. return humanbytes(_process_memory_info(p).rss)
  122. def ps() -> Process: # pragma: no cover
  123. """Return the global :class:`psutil.Process` instance.
  124. Note:
  125. Returns :const:`None` if :pypi:`psutil` is not installed.
  126. """
  127. global _process
  128. if _process is None and Process is not None:
  129. _process = Process(os.getpid())
  130. return _process
  131. def _process_memory_info(process: Process) -> Any:
  132. try:
  133. return process.memory_info()
  134. except AttributeError:
  135. return process.get_memory_info()
  136. def cry(out: Optional[StringIO]=None,
  137. sepchr: str='=', seplen: int=49) -> None: # pragma: no cover
  138. """Return stack-trace of all active threads.
  139. See Also:
  140. Taken from https://gist.github.com/737056.
  141. """
  142. import threading
  143. out = StringIO() if out is None else out
  144. P = partial(print, file=out)
  145. # get a map of threads by their ID so we can print their names
  146. # during the traceback dump
  147. tmap = {t.ident: t for t in threading.enumerate()}
  148. sep = sepchr * seplen
  149. for tid, frame in sys._current_frames().items():
  150. thread = tmap.get(tid)
  151. if not thread:
  152. # skip old junk (left-overs from a fork)
  153. continue
  154. P('{0.name}'.format(thread))
  155. P(sep)
  156. traceback.print_stack(frame, file=out)
  157. P(sep)
  158. P('LOCAL VARIABLES')
  159. P(sep)
  160. pprint(frame.f_locals, stream=out)
  161. P('\n')
  162. return out.getvalue()