test_state.py 21 KB


  1. import pickle
  2. from decimal import Decimal
  3. from random import shuffle
  4. from time import time
  5. from itertools import count
  6. from case import Mock, patch, skip
  7. from celery import states
  8. from celery import uuid
  9. from celery.events import Event
  10. from celery.events.state import (
  11. HEARTBEAT_EXPIRE_WINDOW,
  12. HEARTBEAT_DRIFT_MAX,
  13. State,
  14. Worker,
  15. Task,
  16. heartbeat_expires,
  17. )
  class replay:
      """Feed a scripted list of events into a ``State`` one at a time.

      Subclasses implement :meth:`setup` to populate ``self.events``.
      Iterating the replay (or calling :meth:`play`) passes each event to
      ``self.state.event``.
      """

      def __init__(self, state):
          self.state = state
          self.rewind()
          self.setup()
          # Logical clock assigned to events that do not carry a 'clock'.
          self.current_clock = 0

      def setup(self):
          """Hook for subclasses: populate ``self.events``."""
          pass

      def next_event(self):
          """Return the next scripted event with derived fields filled in.

          ``local_received`` is copied from ``timestamp``.  An event without
          a ``clock`` gets one more than the previously seen clock.

          Raises:
              IndexError: once every scripted event has been consumed.
          """
          ev = self.events[next(self.position)]
          ev['local_received'] = ev['timestamp']
          try:
              self.current_clock = ev['clock']
          except KeyError:
              ev['clock'] = self.current_clock = self.current_clock + 1
          return ev

      def __iter__(self):
          return self

      def __next__(self):
          # Only running out of scripted events ends the replay.  An
          # IndexError raised by ``state.event`` itself is a real failure
          # and must propagate instead of silently stopping iteration.
          try:
              ev = self.next_event()
          except IndexError:
              raise StopIteration() from None
          self.state.event(ev)

      next = __next__

      def rewind(self):
          """Restart from the first scripted event; returns ``self``."""
          self.position = count(0)
          return self

      def play(self):
          """Feed every remaining scripted event into the state."""
          for _ in self:
              pass
  class ev_worker_online_offline(replay):
      """Worker ``utest1`` comes online and then goes offline."""

      def setup(self):
          host = 'utest1'
          self.events = [
              Event(kind, hostname=host)
              for kind in ('worker-online', 'worker-offline')
          ]
  class ev_worker_heartbeats(replay):
      """Two heartbeats from ``utest1``: one stale, then one current.

      The stale heartbeat is timestamped twice the expiry window in the past.
      """

      def setup(self):
          stale_ts = time() - HEARTBEAT_EXPIRE_WINDOW * 2
          stale = Event('worker-heartbeat', hostname='utest1',
                        timestamp=stale_ts)
          fresh = Event('worker-heartbeat', hostname='utest1')
          self.events = [stale, fresh]
  class ev_task_states(replay):
      """Walk task ``tid`` through each task event, then add child ``tid2``.

      An event of unrelated type (``foo-bar``) is mixed in before the child
      task, whose ``parent_id`` and ``root_id`` both point at ``tid``.
      """

      def setup(self):
          self.tid = uuid()
          self.tid2 = uuid()
          parent, child, host = self.tid, self.tid2, 'utest1'
          self.events = [
              Event('task-received', uuid=parent, name='task1',
                    args='(2, 2)', kwargs="{'foo': 'bar'}",
                    retries=0, eta=None, hostname=host),
              Event('task-started', uuid=parent, hostname=host),
              Event('task-revoked', uuid=parent, hostname=host),
              Event('task-retried', uuid=parent,
                    exception="KeyError('bar')",
                    traceback='line 2 at main', hostname=host),
              Event('task-failed', uuid=parent,
                    exception="KeyError('foo')",
                    traceback='line 1 at main', hostname=host),
              Event('task-succeeded', uuid=parent, result='4',
                    runtime=0.1234, hostname=host),
              Event('foo-bar'),
              Event('task-received', uuid=child, name='task2',
                    args='(4, 4)', kwargs="{'foo': 'bar'}",
                    retries=0, eta=None, parent_id=parent, root_id=parent,
                    hostname=host),
          ]
  def QTEV(type, uuid, hostname, clock, name=None, timestamp=None):
      """Quick task event: build a ``task-<type>`` event.

      A falsy *timestamp* is replaced by the current time.
      """
      if not timestamp:
          timestamp = time()
      event_type = 'task-{0}'.format(type)
      return Event(event_type, uuid=uuid, hostname=hostname,
                   clock=clock, name=name, timestamp=timestamp)
  class ev_logical_clock_ordering(replay):
      """Interleaved task events from three workers with explicit clocks.

      ``tA`` is received/started on ``w1``, retried, then received, started
      and succeeded on ``w3``; ``tB`` and ``tC`` run on ``w2``.  Every clock
      is shifted by ``offset`` so the script can be replayed against the
      same state with increasing clocks.
      """

      def __init__(self, state, offset=0, uids=None):
          self.offset = offset or 0
          self.uids = self.setuids(uids)
          super().__init__(state)

      def setuids(self, uids):
          """Bind ``tA``/``tB``/``tC`` from *uids* (fresh uuids if falsy).

          Returns the same sequence so the caller can store it.
          """
          uids = self.tA, self.tB, self.tC = uids or [uuid(), uuid(), uuid()]
          return uids

      def setup(self):
          offset = self.offset
          tA, tB, tC = self.uids
          self.events = [
              QTEV('received', tA, 'w1', name='tA', clock=offset + 1),
              QTEV('received', tB, 'w2', name='tB', clock=offset + 1),
              QTEV('started', tA, 'w1', name='tA', clock=offset + 3),
              QTEV('received', tC, 'w2', name='tC', clock=offset + 3),
              QTEV('started', tB, 'w2', name='tB', clock=offset + 5),
              QTEV('retried', tA, 'w1', name='tA', clock=offset + 7),
              QTEV('succeeded', tB, 'w2', name='tB', clock=offset + 9),
              QTEV('started', tC, 'w2', name='tC', clock=offset + 10),
              # NOTE: clock 13 is delivered before clock 12 below.
              QTEV('received', tA, 'w3', name='tA', clock=offset + 13),
              QTEV('succeeded', tC, 'w2', name='tC', clock=offset + 12),
              QTEV('started', tA, 'w3', name='tA', clock=offset + 14),
              QTEV('succeeded', tA, 'w3', name='tA', clock=offset + 16),
          ]

      def rewind_with_offset(self, offset, uids=None):
          """Rebuild the script with clock *offset* (and *uids*), restart."""
          self.offset = offset
          self.uids = self.setuids(uids or self.uids)
          self.setup()
          self.rewind()
  class ev_snapshot(replay):
      """Three workers come online, then 20 tasks are received.

      Even indices produce ``task2`` on ``utest2`` and odd indices produce
      ``task1`` on ``utest1``: ten tasks of each type, ten per worker.
      """

      def setup(self):
          self.events = [
              Event('worker-online', hostname='utest1'),
              Event('worker-online', hostname='utest2'),
              Event('worker-online', hostname='utest3'),
          ]
          for i in range(20):
              even = i % 2 == 0
              worker = 'utest2' if even else 'utest1'
              task_name = 'task2' if even else 'task1'
              self.events.append(Event('task-received', name=task_name,
                                       uuid=uuid(), hostname=worker))
  class test_Worker:
      """Tests for :class:`celery.events.state.Worker`."""

      def test_equality(self):
          first = Worker(hostname='foo')
          second = Worker(hostname='foo')
          other = Worker(hostname='bar')
          assert first.hostname == 'foo'
          assert first == second
          assert first != other
          assert hash(first) == hash(second)
          assert hash(first) != hash(other)

      def test_heartbeat_expires__Decimal(self):
          expires = heartbeat_expires(
              Decimal(344313.37), freq=60, expire_window=200,
          )
          assert expires == 344433.37

      def test_compatible_with_Decimal(self):
          hostname = 'george@vandelay.com'
          worker = Worker(hostname)
          timestamp, local_received = Decimal(time()), time()
          fields = {
              'hostname': hostname,
              'timestamp': timestamp,
              'local_received': local_received,
              'freq': Decimal(5.6335431),
          }
          worker.event('worker-online', timestamp, local_received,
                       fields=fields)
          assert worker.alive

      def test_eq_ne_other(self):
          same = Worker('a@b.com')
          assert same == Worker('a@b.com')
          assert same != Worker('b@b.com')
          assert same != object()

      def test_reduce_direct(self):
          original = Worker('george@vandelay.com')
          original.event('worker-online', 10.0, 13.0, fields={
              'hostname': 'george@vandelay.com',
              'timestamp': 10.0,
              'local_received': 13.0,
              'freq': 60,
          })
          factory, ctor_args = original.__reduce__()
          restored = factory(*ctor_args)
          # Every attribute below must survive the __reduce__ round trip.
          for attr in ('hostname', 'pid', 'freq', 'heartbeats', 'clock',
                       'active', 'processed', 'loadavg', 'sw_ident'):
              assert getattr(restored, attr) == getattr(original, attr)

      def test_update(self):
          worker = Worker('george@vandelay.com')
          worker.update({'idx': '301'}, foo=1, clock=30, bah='foo')
          expected = {'idx': '301', 'foo': 1, 'clock': 30, 'bah': 'foo'}
          for attr, value in expected.items():
              assert getattr(worker, attr) == value

      def test_survives_missing_timestamp(self):
          worker = Worker(hostname='foo')
          worker.event('heartbeat')
          assert worker.heartbeats == []

      def test_repr(self):
          assert repr(Worker(hostname='foo'))

      def test_drift_warning(self):
          worker = Worker(hostname='foo')
          with patch('celery.events.state.warn') as warn:
              drifted = time() + (HEARTBEAT_DRIFT_MAX * 2)
              worker.event(None, drifted, time())
              warn.assert_called()
              assert 'Substantial drift' in warn.call_args[0][0]

      def test_updates_heartbeat(self):
          worker = Worker(hostname='foo')
          worker.event(None, time(), time())
          assert len(worker.heartbeats) == 1
          first = worker.heartbeats[0]
          # A heartbeat with an older local_received is recorded, and the
          # first heartbeat is expected to remain the last entry.
          worker.event(None, time(), time() - 10)
          assert len(worker.heartbeats) == 2
          assert worker.heartbeats[-1] == first
  class test_Task:
      """Tests for :class:`celery.events.state.Task`."""

      def test_equality(self):
          first = Task(uuid='foo')
          second = Task(uuid='foo')
          other = Task(uuid='bar')
          assert first.uuid == 'foo'
          assert first == second
          assert first != other
          assert hash(first) == hash(second)
          assert hash(first) != hash(other)

      def test_info(self):
          task = Task(
              uuid='abcdefg', name='tasks.add', args='(2, 2)', kwargs='{}',
              retries=2, result=42, eta=1, runtime=0.0001, expires=1,
              parent_id='bdefc', root_id='dedfef', foo=None, exception=1,
              received=time() - 10, started=time() - 8,
              exchange='celery', routing_key='celery', succeeded=time(),
          )
          info_fields = task._info_fields
          assert sorted(info_fields) == sorted(task.info().keys())
          with_received = info_fields + ('received',)
          assert sorted(with_received) == sorted(
              task.info(extra=('received',)))
          assert sorted(task.info(['args', 'kwargs'])) == ['args', 'kwargs']
          assert not list(task.info('foo'))

      def test_reduce_direct(self):
          original = Task(uuid='uuid', name='tasks.add', args='(2, 2)')
          factory, ctor_args = original.__reduce__()
          assert original == factory(*ctor_args)

      def test_ready(self):
          task = Task(uuid='abcdefg', name='tasks.add')
          task.event('received', time(), time())
          assert not task.ready
          task.event('succeeded', time(), time())
          assert task.ready

      def test_sent(self):
          task = Task(uuid='abcdefg', name='tasks.add')
          task.event('sent', time(), time())
          assert task.state == states.PENDING

      def test_merge(self):
          task = Task()
          for kind in ('failed', 'started'):
              task.event(kind, time(), time())
          # A late 'received' must not override the FAILURE state, but its
          # name/args are still merged in.
          task.event('received', time(), time(), {
              'name': 'tasks.add', 'args': (2, 2),
          })
          assert task.state == states.FAILURE
          assert task.name == 'tasks.add'
          assert task.args == (2, 2)
          task.event('retried', time(), time())
          assert task.state == states.RETRY

      def test_repr(self):
          assert repr(Task(uuid='xxx', name='tasks.add'))
  class test_State:
      """Tests for :class:`celery.events.state.State`."""

      def test_repr(self):
          assert repr(State())

      def test_pickleable(self):
          state = State()
          r = ev_logical_clock_ordering(state)
          r.play()
          assert pickle.loads(pickle.dumps(state))

      def test_task_logical_clock_ordering(self):
          state = State()
          r = ev_logical_clock_ordering(state)
          tA, tB, tC = r.uids
          r.play()
          now = list(state.tasks_by_time())
          assert now[0][0] == tA
          assert now[1][0] == tC
          assert now[2][0] == tB
          # Replay the script many times with shuffled ids and increasing
          # clocks; the ordering of the final replay must still hold.
          for _ in range(1000):
              shuffle(r.uids)
              tA, tB, tC = r.uids
              r.rewind_with_offset(r.current_clock + 1, r.uids)
              r.play()
          now = list(state.tasks_by_time())
          assert now[0][0] == tA
          assert now[1][0] == tC
          assert now[2][0] == tB

      @skip.todo(reason='not working')
      def test_task_descending_clock_ordering(self):
          state = State()
          r = ev_logical_clock_ordering(state)
          tA, tB, tC = r.uids
          r.play()
          now = list(state.tasks_by_time(reverse=False))
          assert now[0][0] == tA
          assert now[1][0] == tB
          assert now[2][0] == tC
          for _ in range(1000):
              shuffle(r.uids)
              tA, tB, tC = r.uids
              r.rewind_with_offset(r.current_clock + 1, r.uids)
              r.play()
          now = list(state.tasks_by_time(reverse=False))
          assert now[0][0] == tB
          assert now[1][0] == tC
          assert now[2][0] == tA

      def test_get_or_create_task(self):
          state = State()
          task, created = state.get_or_create_task('id1')
          assert task.uuid == 'id1'
          assert created
          task2, created2 = state.get_or_create_task('id1')
          assert task2 is task
          assert not created2

      def test_get_or_create_worker(self):
          state = State()
          worker, created = state.get_or_create_worker('george@vandelay.com')
          assert worker.hostname == 'george@vandelay.com'
          assert created
          worker2, created2 = state.get_or_create_worker('george@vandelay.com')
          assert worker2 is worker
          assert not created2

      def test_get_or_create_worker__with_defaults(self):
          state = State()
          worker, created = state.get_or_create_worker(
              'george@vandelay.com', pid=30,
          )
          assert worker.hostname == 'george@vandelay.com'
          assert worker.pid == 30
          assert created
          # Passing new attributes to an existing worker updates it.
          worker2, created2 = state.get_or_create_worker(
              'george@vandelay.com', pid=40,
          )
          assert worker2 is worker
          assert worker2.pid == 40
          assert not created2

      def test_worker_online_offline(self):
          r = ev_worker_online_offline(State())
          next(r)
          assert list(r.state.alive_workers())
          assert r.state.workers['utest1'].alive
          r.play()
          assert not list(r.state.alive_workers())
          assert not r.state.workers['utest1'].alive

      def test_itertasks(self):
          s = State()
          s.tasks = {'a': 'a', 'b': 'b', 'c': 'c', 'd': 'd'}
          assert len(list(s.itertasks(limit=2))) == 2

      def test_worker_heartbeat_expire(self):
          r = ev_worker_heartbeats(State())
          next(r)
          assert not list(r.state.alive_workers())
          assert not r.state.workers['utest1'].alive
          r.play()
          assert list(r.state.alive_workers())
          assert r.state.workers['utest1'].alive

      def test_task_states(self):
          r = ev_task_states(State())

          # RECEIVED
          next(r)
          assert r.tid in r.state.tasks
          task = r.state.tasks[r.tid]
          assert task.state == states.RECEIVED
          assert task.received
          assert task.timestamp == task.received
          assert task.worker.hostname == 'utest1'

          # STARTED
          next(r)
          assert r.state.workers['utest1'].alive
          assert task.state == states.STARTED
          assert task.started
          assert task.timestamp == task.started
          assert task.worker.hostname == 'utest1'

          # REVOKED
          next(r)
          assert task.state == states.REVOKED
          assert task.revoked
          assert task.timestamp == task.revoked
          assert task.worker.hostname == 'utest1'

          # RETRY
          next(r)
          assert task.state == states.RETRY
          assert task.retried
          assert task.timestamp == task.retried
          # Compare for equality: `assert x, 'utest1'` would only check
          # truthiness and use 'utest1' as the failure message.
          assert task.worker.hostname == 'utest1'
          assert task.exception == "KeyError('bar')"
          assert task.traceback == 'line 2 at main'

          # FAILURE
          next(r)
          assert task.state == states.FAILURE
          assert task.failed
          assert task.timestamp == task.failed
          assert task.worker.hostname == 'utest1'
          assert task.exception == "KeyError('foo')"
          assert task.traceback == 'line 1 at main'

          # SUCCESS
          next(r)
          assert task.state == states.SUCCESS
          assert task.succeeded
          assert task.timestamp == task.succeeded
          assert task.worker.hostname == 'utest1'
          assert task.result == '4'
          assert task.runtime == 0.1234

          # children, parent, root
          r.play()
          assert r.tid2 in r.state.tasks
          task2 = r.state.tasks[r.tid2]
          assert task2.parent is task
          assert task2.root is task
          assert task2 in task.children

      def test_task_children_set_if_received_in_wrong_order(self):
          r = ev_task_states(State())
          # Move the child's 'received' event to the front of the script.
          r.events.insert(0, r.events.pop())
          r.play()
          assert r.state.tasks[r.tid2] in r.state.tasks[r.tid].children
          assert r.state.tasks[r.tid2].root is r.state.tasks[r.tid]
          assert r.state.tasks[r.tid2].parent is r.state.tasks[r.tid]

      def assertStateEmpty(self, state):
          """Assert that *state* holds no tasks, workers or counters."""
          assert not state.tasks
          assert not state.workers
          assert not state.event_count
          assert not state.task_count

      def assertState(self, state):
          """Assert that *state* holds tasks, workers and non-zero counters."""
          assert state.tasks
          assert state.workers
          assert state.event_count
          assert state.task_count

      def test_freeze_while(self):
          s = State()
          r = ev_snapshot(s)
          r.play()

          def work():
              pass

          s.freeze_while(work, clear_after=True)
          assert not s.event_count

          s2 = State()
          r = ev_snapshot(s2)
          r.play()
          s2.freeze_while(work, clear_after=False)
          assert s2.event_count

      def test_clear_tasks(self):
          s = State()
          r = ev_snapshot(s)
          r.play()
          assert s.tasks
          s.clear_tasks(ready=False)
          assert not s.tasks

      def test_clear(self):
          r = ev_snapshot(State())
          r.play()
          assert r.state.event_count
          assert r.state.workers
          assert r.state.tasks
          assert r.state.task_count
          # The default clear() leaves these (received, not ready) tasks.
          r.state.clear()
          assert not r.state.event_count
          assert not r.state.workers
          assert r.state.tasks
          assert not r.state.task_count
          r.state.clear(False)
          assert not r.state.tasks

      def test_task_types(self):
          r = ev_snapshot(State())
          r.play()
          assert sorted(r.state.task_types()) == ['task1', 'task2']

      def test_tasks_by_time(self):
          r = ev_snapshot(State())
          r.play()
          assert len(list(r.state.tasks_by_time())) == 20
          assert len(list(r.state.tasks_by_time(reverse=False))) == 20

      def test_tasks_by_type(self):
          r = ev_snapshot(State())
          r.play()
          assert len(list(r.state.tasks_by_type('task1'))) == 10
          assert len(list(r.state.tasks_by_type('task2'))) == 10
          assert len(r.state.tasks_by_type['task1']) == 10
          assert len(r.state.tasks_by_type['task2']) == 10

      def test_alive_workers(self):
          r = ev_snapshot(State())
          r.play()
          assert len(list(r.state.alive_workers())) == 3

      def test_tasks_by_worker(self):
          r = ev_snapshot(State())
          r.play()
          assert len(list(r.state.tasks_by_worker('utest1'))) == 10
          assert len(list(r.state.tasks_by_worker('utest2'))) == 10
          assert len(r.state.tasks_by_worker['utest1']) == 10
          assert len(r.state.tasks_by_worker['utest2']) == 10

      def test_survives_unknown_worker_event(self):
          s = State()
          s.event({
              'type': 'worker-unknown-event-xxx',
              'foo': 'bar',
          })
          s.event({
              'type': 'worker-unknown-event-xxx',
              'hostname': 'xxx',
              'foo': 'bar',
          })

      def test_survives_unknown_worker_leaving(self):
          s = State(on_node_leave=Mock(name='on_node_leave'))
          (worker, created), subject = s.event({
              'type': 'worker-offline',
              'hostname': 'unknown@vandelay.com',
              'timestamp': time(),
              'local_received': time(),
              'clock': 301030134894833,
          })
          assert worker == Worker('unknown@vandelay.com')
          assert not created
          assert subject == 'offline'
          assert 'unknown@vandelay.com' not in s.workers
          s.on_node_leave.assert_called_with(worker)

      def test_on_node_join_callback(self):
          s = State(on_node_join=Mock(name='on_node_join'))
          (worker, created), subject = s.event({
              'type': 'worker-online',
              'hostname': 'george@vandelay.com',
              'timestamp': time(),
              'local_received': time(),
              'clock': 34314,
          })
          assert worker
          assert created
          assert subject == 'online'
          assert 'george@vandelay.com' in s.workers
          s.on_node_join.assert_called_with(worker)

      def test_survives_unknown_task_event(self):
          s = State()
          s.event({
              'type': 'task-unknown-event-xxx',
              'foo': 'bar',
              'uuid': 'x',
              'hostname': 'y',
              'timestamp': time(),
              'local_received': time(),
              'clock': 0,
          })

      def test_limits_maxtasks(self):
          s = State(max_tasks_in_memory=1)
          s.heap_multiplier = 2
          s.event({
              'type': 'task-unknown-event-xxx',
              'foo': 'bar',
              'uuid': 'x',
              'hostname': 'y',
              'clock': 3,
              'timestamp': time(),
              'local_received': time(),
          })
          s.event({
              'type': 'task-unknown-event-xxx',
              'foo': 'bar',
              'uuid': 'y',
              'hostname': 'y',
              'clock': 4,
              'timestamp': time(),
              'local_received': time(),
          })
          s.event({
              'type': 'task-unknown-event-xxx',
              'foo': 'bar',
              'uuid': 'z',
              'hostname': 'y',
              'clock': 5,
              'timestamp': time(),
              'local_received': time(),
          })
          # Only the two newest heap entries are expected to remain
          # (the clock-3 entry is evicted).
          assert len(s._taskheap) == 2
          assert s._taskheap[0].clock == 4
          assert s._taskheap[1].clock == 5
          # A duplicated heap entry must not break tasks_by_time().
          s._taskheap.append(s._taskheap[0])
          assert list(s.tasks_by_time())

      def test_callback(self):
          scratch = {}

          def callback(state, event):
              scratch['recv'] = True

          s = State(callback=callback)
          s.event({'type': 'worker-online'})
          assert scratch.get('recv')