from __future__ import absolute_import, unicode_literals from celery.contrib.abortable import AbortableTask, AbortableAsyncResult class test_AbortableTask: def setup(self): @self.app.task(base=AbortableTask, shared=False) def abortable(): return True self.abortable = abortable def test_async_result_is_abortable(self): result = self.abortable.apply_async() tid = result.id assert isinstance( self.abortable.AsyncResult(tid), AbortableAsyncResult) def test_is_not_aborted(self): self.abortable.push_request() try: result = self.abortable.apply_async() tid = result.id assert not self.abortable.is_aborted(task_id=tid) finally: self.abortable.pop_request() def test_is_aborted_not_abort_result(self): self.abortable.AsyncResult = self.app.AsyncResult self.abortable.push_request() try: self.abortable.request.id = 'foo' assert not self.abortable.is_aborted() finally: self.abortable.pop_request() def test_abort_yields_aborted(self): self.abortable.push_request() try: result = self.abortable.apply_async() result.abort() tid = result.id assert self.abortable.is_aborted(task_id=tid) finally: self.abortable.pop_request()