Przeglądaj źródła

Tests: Make sure all threads get the current app trap

Ask Solem 11 lat temu
rodzic
commit
00074cafcc

+ 5 - 0
celery/_state.py

@@ -92,6 +92,11 @@ def _get_current_app():
         ))
     return _tls.current_app or default_app
 
+
+def _set_current_app(app):
+    _tls.current_app = app
+
+
 C_STRICT_APP = os.environ.get('C_STRICT_APP')
 if os.environ.get('C_STRICT_APP'):  # pragma: no cover
     def get_current_app():

+ 2 - 2
celery/app/base.py

@@ -26,7 +26,7 @@ from kombu.utils import cached_property, uuid
 from celery import platforms
 from celery import signals
 from celery._state import (
-    _task_stack, _tls, get_current_app, set_default_app,
+    _task_stack, get_current_app, _set_current_app, set_default_app,
     _register_app, get_current_worker_task,
 )
 from celery.exceptions import AlwaysEagerIgnored, ImproperlyConfigured
@@ -147,7 +147,7 @@ class Celery(object):
         _register_app(self)
 
     def set_current(self):
-        _tls.current_app = self
+        _set_current_app(self)
 
     def set_default(self):
         set_default_app(self)

+ 2 - 2
celery/tests/backends/test_mongodb.py

@@ -298,7 +298,7 @@ class test_MongoBackend(AppCase):
         self.backend.taskmeta_collection = MONGODB_COLLECTION
 
         mock_database = MagicMock(spec=['__getitem__', '__setitem__'])
-        mock_collection = Mock()
+        self.backend.collections = mock_collection = Mock()
 
         mock_get_database.return_value = mock_database
         mock_database.__getitem__.return_value = mock_collection
@@ -309,7 +309,7 @@ class test_MongoBackend(AppCase):
         mock_get_database.assert_called_once_with()
         mock_database.__getitem__.assert_called_once_with(
             MONGODB_COLLECTION)
-        mock_collection.assert_called_once_with()
+        self.assertTrue(mock_collection.remove.called)
 
     def test_get_database_authfailure(self):
         x = MongoBackend(app=self.app)

+ 10 - 7
celery/tests/case.py

@@ -412,8 +412,12 @@ class AppCase(Case):
         self._current_app = current_app()
         self._default_app = _state.default_app
         trap = Trap()
+        self._prev_tls = _state._tls
         _state.set_default_app(trap)
-        _state._tls.current_app = trap
+
+        class NonTLS(object):
+            current_app = trap
+        _state._tls = NonTLS()
 
         self.app = self.Celery(set_as_current=False)
         if not self.contained:
@@ -447,13 +451,12 @@ class AppCase(Case):
                 if isinstance(backend.client, DummyClient):
                     backend.client.cache.clear()
                 backend._cache.clear()
-        from celery._state import (
-            _tls, set_default_app, _set_task_join_will_block,
-        )
-        _set_task_join_will_block(False)
+        from celery import _state
+        _state._set_task_join_will_block(False)
 
-        set_default_app(self._default_app)
-        _tls.current_app = self._current_app
+        _state.set_default_app(self._default_app)
+        _state._tls = self._prev_tls
+        _state._tls.current_app = self._current_app
         if self.app is not self._current_app:
             self.app.close()
         self.app = None