test_worker_state.py 2.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112
  1. from celery.tests.utils import unittest
  2. from celery.datastructures import LimitedSet
  3. from celery.worker import state
  4. class StateResetCase(unittest.TestCase):
  5. def setUp(self):
  6. self.reset_state()
  7. self.on_setup()
  8. def tearDown(self):
  9. self.reset_state()
  10. self.on_teardown()
  11. def reset_state(self):
  12. state.active_requests.clear()
  13. state.revoked.clear()
  14. state.total_count.clear()
  15. def on_setup(self):
  16. pass
  17. def on_teardown(self):
  18. pass
  19. class MockShelve(dict):
  20. filename = None
  21. in_sync = False
  22. closed = False
  23. def open(self, filename):
  24. self.filename = filename
  25. return self
  26. def sync(self):
  27. self.in_sync = True
  28. def close(self):
  29. self.closed = True
  30. class MyPersistent(state.Persistent):
  31. storage = MockShelve()
  32. class test_Persistent(StateResetCase):
  33. def on_setup(self):
  34. self.p = MyPersistent(filename="celery-state")
  35. def test_constructor(self):
  36. self.assertDictEqual(self.p.db, {})
  37. self.assertEqual(self.p.db.filename, self.p.filename)
  38. def test_save(self):
  39. self.p.db["foo"] = "bar"
  40. self.p.save()
  41. self.assertTrue(self.p.db.in_sync)
  42. self.assertTrue(self.p.db.closed)
  43. def add_revoked(self, *ids):
  44. for id in ids:
  45. self.p.db.setdefault("revoked", LimitedSet()).add(id)
  46. def test_merge(self, data=["foo", "bar", "baz"]):
  47. self.add_revoked(*data)
  48. self.p.merge(self.p.db)
  49. for item in data:
  50. self.assertIn(item, state.revoked)
  51. def test_sync(self, data1=["foo", "bar", "baz"],
  52. data2=["baz", "ini", "koz"]):
  53. self.add_revoked(*data1)
  54. for item in data2:
  55. state.revoked.add(item)
  56. self.p.sync(self.p.db)
  57. for item in data2:
  58. self.assertIn(item, self.p.db["revoked"])
  59. class SimpleReq(object):
  60. def __init__(self, task_name):
  61. self.task_name = task_name
  62. class test_state(StateResetCase):
  63. def test_accepted(self, requests=[SimpleReq("foo"),
  64. SimpleReq("bar"),
  65. SimpleReq("baz"),
  66. SimpleReq("baz")]):
  67. for request in requests:
  68. state.task_accepted(request)
  69. for req in requests:
  70. self.assertIn(req, state.active_requests)
  71. self.assertEqual(state.total_count["foo"], 1)
  72. self.assertEqual(state.total_count["bar"], 1)
  73. self.assertEqual(state.total_count["baz"], 2)
  74. def test_ready(self, requests=[SimpleReq("foo"),
  75. SimpleReq("bar")]):
  76. for request in requests:
  77. state.task_accepted(request)
  78. self.assertEqual(len(state.active_requests), 2)
  79. for request in requests:
  80. state.task_ready(request)
  81. self.assertEqual(len(state.active_requests), 0)