test_state.py 5.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188
  1. from __future__ import absolute_import, unicode_literals
  2. import pickle
  3. from time import time
  4. import pytest
  5. from case import Mock, patch
  6. from celery import uuid
  7. from celery.exceptions import WorkerShutdown, WorkerTerminate
  8. from celery.utils.collections import LimitedSet
  9. from celery.worker import state
  10. @pytest.fixture
  11. def reset_state():
  12. yield
  13. state.active_requests.clear()
  14. state.revoked.clear()
  15. state.total_count.clear()
  16. class MockShelve(dict):
  17. filename = None
  18. in_sync = False
  19. closed = False
  20. def open(self, filename, **kwargs):
  21. self.filename = filename
  22. return self
  23. def sync(self):
  24. self.in_sync = True
  25. def close(self):
  26. self.closed = True
  27. class MyPersistent(state.Persistent):
  28. storage = MockShelve()
  29. class test_maybe_shutdown:
  30. def teardown(self):
  31. state.should_stop = None
  32. state.should_terminate = None
  33. def test_should_stop(self):
  34. state.should_stop = True
  35. with pytest.raises(WorkerShutdown):
  36. state.maybe_shutdown()
  37. state.should_stop = 0
  38. with pytest.raises(WorkerShutdown):
  39. state.maybe_shutdown()
  40. state.should_stop = False
  41. try:
  42. state.maybe_shutdown()
  43. except SystemExit:
  44. raise RuntimeError('should not have exited')
  45. state.should_stop = None
  46. try:
  47. state.maybe_shutdown()
  48. except SystemExit:
  49. raise RuntimeError('should not have exited')
  50. state.should_stop = 0
  51. try:
  52. state.maybe_shutdown()
  53. except SystemExit as exc:
  54. assert exc.code == 0
  55. else:
  56. raise RuntimeError('should have exited')
  57. state.should_stop = 303
  58. try:
  59. state.maybe_shutdown()
  60. except SystemExit as exc:
  61. assert exc.code == 303
  62. else:
  63. raise RuntimeError('should have exited')
  64. def test_should_terminate(self):
  65. state.should_terminate = True
  66. with pytest.raises(WorkerTerminate):
  67. state.maybe_shutdown()
  68. @pytest.mark.usefixtures('reset_state')
  69. class test_Persistent:
  70. @pytest.fixture
  71. def p(self):
  72. return MyPersistent(state, filename='celery-state')
  73. def test_close_twice(self, p):
  74. p._is_open = False
  75. p.close()
  76. def test_constructor(self, p):
  77. assert p.db == {}
  78. assert p.db.filename == p.filename
  79. def test_save(self, p):
  80. p.db['foo'] = 'bar'
  81. p.save()
  82. assert p.db.in_sync
  83. assert p.db.closed
  84. def add_revoked(self, p, *ids):
  85. for id in ids:
  86. p.db.setdefault(str('revoked'), LimitedSet()).add(id)
  87. def test_merge(self, p, data=['foo', 'bar', 'baz']):
  88. state.revoked.update(data)
  89. p.merge()
  90. for item in data:
  91. assert item in state.revoked
  92. def test_merge_dict(self, p):
  93. p.clock = Mock()
  94. p.clock.adjust.return_value = 626
  95. d = {str('revoked'): {str('abc'): time()}, str('clock'): 313}
  96. p._merge_with(d)
  97. p.clock.adjust.assert_called_with(313)
  98. assert d[str('clock')] == 626
  99. assert str('abc') in state.revoked
  100. def test_sync_clock_and_purge(self, p):
  101. passthrough = Mock()
  102. passthrough.side_effect = lambda x: x
  103. with patch('celery.worker.state.revoked') as revoked:
  104. d = {str('clock'): 0}
  105. p.clock = Mock()
  106. p.clock.forward.return_value = 627
  107. p._dumps = passthrough
  108. p.compress = passthrough
  109. p._sync_with(d)
  110. revoked.purge.assert_called_with()
  111. assert d[str('clock')] == 627
  112. assert str('revoked') not in d
  113. assert d[str('zrevoked')] is revoked
  114. def test_sync(self, p,
  115. data1=['foo', 'bar', 'baz'], data2=['baz', 'ini', 'koz']):
  116. self.add_revoked(p, *data1)
  117. for item in data2:
  118. state.revoked.add(item)
  119. p.sync()
  120. assert p.db[str('zrevoked')]
  121. pickled = p.decompress(p.db[str('zrevoked')])
  122. assert pickled
  123. saved = pickle.loads(pickled)
  124. for item in data2:
  125. assert item in saved
  126. class SimpleReq(object):
  127. def __init__(self, name):
  128. self.id = uuid()
  129. self.name = name
  130. @pytest.mark.usefixtures('reset_state')
  131. class test_state:
  132. def test_accepted(self, requests=[SimpleReq('foo'),
  133. SimpleReq('bar'),
  134. SimpleReq('baz'),
  135. SimpleReq('baz')]):
  136. for request in requests:
  137. state.task_accepted(request)
  138. for req in requests:
  139. assert req in state.active_requests
  140. assert state.total_count['foo'] == 1
  141. assert state.total_count['bar'] == 1
  142. assert state.total_count['baz'] == 2
  143. def test_ready(self, requests=[SimpleReq('foo'),
  144. SimpleReq('bar')]):
  145. for request in requests:
  146. state.task_accepted(request)
  147. assert len(state.active_requests) == 2
  148. for request in requests:
  149. state.task_ready(request)
  150. assert len(state.active_requests) == 0