|
@@ -10,9 +10,10 @@ from celery.result import AsyncResult, GroupResult, ResultSet
|
|
|
|
|
|
from .conftest import flaky, get_active_redis_channels, get_redis_connection
|
|
|
from .tasks import (add, add_chord_to_chord, add_replaced, add_to_all,
|
|
|
- add_to_all_to_chord, build_chain_inside_task, collect_ids,
|
|
|
- delayed_sum, delayed_sum_with_soft_guard, identity, ids,
|
|
|
- print_unicode, redis_echo, second_order_replace1, tsum)
|
|
|
+ add_to_all_to_chord, build_chain_inside_task, chord_error,
|
|
|
+ collect_ids, delayed_sum, delayed_sum_with_soft_guard,
|
|
|
+ fail, identity, ids, print_unicode, redis_echo,
|
|
|
+ second_order_replace1, tsum)
|
|
|
|
|
|
TIMEOUT = 120
|
|
|
|
|
@@ -521,3 +522,61 @@ class test_chord:
|
|
|
assert value == 1
|
|
|
assert root_id == expected_root_id
|
|
|
assert parent_id is None
|
|
|
+
|
|
|
+ def test_chord_on_error(self, manager):
|
|
|
+ from celery import states
|
|
|
+ from .tasks import ExpectedException
|
|
|
+ import time
|
|
|
+
|
|
|
+ if not manager.app.conf.result_backend.startswith('redis'):
|
|
|
+ raise pytest.skip('Requires redis result backend.')
|
|
|
+
|
|
|
+ # Run the chord and wait for the error callback to finish.
|
|
|
+ c1 = chord(
|
|
|
+ header=[add.s(1, 2), add.s(3, 4), fail.s()],
|
|
|
+ body=print_unicode.s('This should not be called').on_error(
|
|
|
+ chord_error.s()),
|
|
|
+ )
|
|
|
+ res = c1()
|
|
|
+ try:
|
|
|
+ res.wait(propagate=False)
|
|
|
+ except ExpectedException:
|
|
|
+ pass
|
|
|
+ # Got to wait for children to populate.
|
|
|
+ while not res.children:
|
|
|
+ time.sleep(0.1)
|
|
|
+ try:
|
|
|
+ res.children[0].children[0].wait(propagate=False)
|
|
|
+ except ExpectedException:
|
|
|
+ pass
|
|
|
+
|
|
|
+ # Extract the results of the successful tasks from the chord.
|
|
|
+ #
|
|
|
+ # We could do this inside the error handler, and probably would in a
|
|
|
+ # real system, but for the purposes of the test it's obnoxious to get
|
|
|
+ # data out of the error handler.
|
|
|
+ #
|
|
|
+ # So for clarity of our test, we instead do it here.
|
|
|
+
|
|
|
+ # Use the error callback's result to find the failed task.
|
|
|
+ error_callback_result = AsyncResult(
|
|
|
+ res.children[0].children[0].result[0])
|
|
|
+ failed_task_id = error_callback_result.result.args[0].split()[3]
|
|
|
+
|
|
|
+ # Use new group_id result metadata to get group ID.
|
|
|
+ failed_task_result = AsyncResult(failed_task_id)
|
|
|
+ original_group_id = failed_task_result._get_task_meta()['group_id']
|
|
|
+
|
|
|
+ # Use group ID to get preserved group result.
|
|
|
+ backend = fail.app.backend
|
|
|
+ j_key = backend.get_key_for_group(original_group_id, '.j')
|
|
|
+ redis_connection = get_redis_connection()
|
|
|
+ chord_results = [backend.decode(t) for t in
|
|
|
+ redis_connection.lrange(j_key, 0, 3)]
|
|
|
+
|
|
|
+ # Validate group result
|
|
|
+ assert [cr[3] for cr in chord_results if cr[2] == states.SUCCESS] == \
|
|
|
+ [3, 7]
|
|
|
+
|
|
|
+ assert len([cr for cr in chord_results if cr[2] != states.SUCCESS]
|
|
|
+ ) == 1
|