|  | @@ -2,7 +2,7 @@ from __future__ import absolute_import
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |  from contextlib import contextmanager
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | -from celery import group
 | 
	
		
			
				|  |  | +from celery import group, uuid
 | 
	
		
			
				|  |  |  from celery import canvas
 | 
	
		
			
				|  |  |  from celery import result
 | 
	
		
			
				|  |  |  from celery.exceptions import ChordError, Retry
 | 
	
	
		
			
				|  | @@ -219,6 +219,57 @@ class test_chord(ChordCase):
 | 
	
		
			
				|  |  |              chord.run = prev
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | +class test_add_to_chord(AppCase):
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    def setup(self):
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +        @self.app.task(shared=False)
 | 
	
		
			
				|  |  | +        def add(x, y):
 | 
	
		
			
				|  |  | +            return x + y
 | 
	
		
			
				|  |  | +        self.add = add
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +        @self.app.task(shared=False, bind=True)
 | 
	
		
			
				|  |  | +        def adds(self, sig, lazy=False):
 | 
	
		
			
				|  |  | +            return self.add_to_chord(sig, lazy)
 | 
	
		
			
				|  |  | +        self.adds = adds
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    def test_add_to_chord(self):
 | 
	
		
			
				|  |  | +        self.app.backend = Mock(name='backend')
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +        sig = self.add.s(2, 2)
 | 
	
		
			
				|  |  | +        sig.delay = Mock(name='sig.delay')
 | 
	
		
			
				|  |  | +        self.adds.request.group = uuid()
 | 
	
		
			
				|  |  | +        self.adds.request.id = uuid()
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +        with self.assertRaises(ValueError):
 | 
	
		
			
				|  |  | +            # task not part of chord
 | 
	
		
			
				|  |  | +            self.adds.run(sig)
 | 
	
		
			
				|  |  | +        self.adds.request.chord = self.add.s()
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +        res1 = self.adds.run(sig, True)
 | 
	
		
			
				|  |  | +        self.assertEqual(res1, sig)
 | 
	
		
			
				|  |  | +        self.assertTrue(sig.options['task_id'])
 | 
	
		
			
				|  |  | +        self.assertEqual(sig.options['group_id'], self.adds.request.group)
 | 
	
		
			
				|  |  | +        self.assertEqual(sig.options['chord'], self.adds.request.chord)
 | 
	
		
			
				|  |  | +        self.assertFalse(sig.delay.called)
 | 
	
		
			
				|  |  | +        self.app.backend.add_to_chord.assert_called_with(
 | 
	
		
			
				|  |  | +            self.adds.request.group, sig.freeze(),
 | 
	
		
			
				|  |  | +        )
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +        self.app.backend.reset_mock()
 | 
	
		
			
				|  |  | +        sig2 = self.add.s(4, 4)
 | 
	
		
			
				|  |  | +        sig2.delay = Mock(name='sig2.delay')
 | 
	
		
			
				|  |  | +        res2 = self.adds.run(sig2)
 | 
	
		
			
				|  |  | +        self.assertEqual(res2, sig2.delay.return_value)
 | 
	
		
			
				|  |  | +        self.assertTrue(sig2.options['task_id'])
 | 
	
		
			
				|  |  | +        self.assertEqual(sig2.options['group_id'], self.adds.request.group)
 | 
	
		
			
				|  |  | +        self.assertEqual(sig2.options['chord'], self.adds.request.chord)
 | 
	
		
			
				|  |  | +        sig2.delay.assert_called_with()
 | 
	
		
			
				|  |  | +        self.app.backend.add_to_chord.assert_called_with(
 | 
	
		
			
				|  |  | +            self.adds.request.group, sig2.freeze(),
 | 
	
		
			
				|  |  | +        )
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  |  class test_Chord_task(ChordCase):
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |      def test_run(self):
 |