Browse Source

Allow to have a custom request (#3977)

* Allow custom Request, aka custom `on_timeout`.

Allowing a custom Request eases the task of handling timeouts (even hard
timeouts).

Rationale

Some (poorly written) bits of code catch exceptions quite broadly:

  try:
      ...
  except:
      ...

This hurts tasks when a SoftTimeLimitExceeded exception is raised inside such
blocks of code: the exception is swallowed and the soft time limit is silently
ignored.  Rewriting those smelly bits of code can take a lot of effort, and
sometimes the code belongs to a third-party library, which makes the task even
harder.

Using a custom request makes it possible to detect hard time limits.

Your app can be customized like:

    from celery import Task as BaseTask
    from celery.worker.request import Request as BaseRequest

    class Request(BaseRequest):
        def on_timeout(self, soft, timeout):
            super(Request, self).on_timeout(soft, timeout)
            if not soft:
                print('Something hard hit me!')

    class MyTask(BaseTask):
        Request = Request

    @app.task(base=MyTask, bind=True)
    def sometask(self):
        pass

* Check signatures' types have a default Request.

* Test Request is customizable per Task class.

* Document custom requests.

* Exemplify the usage of the custom requests.
Manuel Vázquez Acosta 7 years ago
parent
commit
215376d8f8
4 changed files with 102 additions and 2 deletions
  1. 3 0
      celery/app/task.py
  2. 3 1
      celery/worker/strategy.py
  3. 62 0
      docs/userguide/tasks.rst
  4. 34 1
      t/unit/worker/test_strategy.py

+ 3 - 0
celery/app/task.py

@@ -158,6 +158,9 @@ class Task(object):
     #: Execution strategy used, or the qualified name of one.
     Strategy = 'celery.worker.strategy:default'
 
+    #: Request class used, or the qualified name of one.
+    Request = 'celery.worker.request:Request'
+
     #: This is the instance bound to if the task is a method of a class.
     __self__ = None
 

+ 3 - 1
celery/worker/strategy.py

@@ -11,8 +11,9 @@ from celery.exceptions import InvalidTaskError
 from celery.utils.log import get_logger
 from celery.utils.saferepr import saferepr
 from celery.utils.time import timezone
+from celery.utils.imports import symbol_by_name
 
-from .request import Request, create_request_cls
+from .request import create_request_cls
 from .state import task_reserved
 
 __all__ = ['default']
@@ -84,6 +85,7 @@ def default(task, app, consumer,
     handle = consumer.on_task_request
     limit_task = consumer._limit_task
     body_can_be_buffer = consumer.pool.body_can_be_buffer
+    Request = symbol_by_name(task.Request)
     Req = create_request_cls(Request, task, consumer.pool, hostname, eventer)
 
     revoked_tasks = consumer.controller.state.revoked

+ 62 - 0
docs/userguide/tasks.rst

@@ -1502,6 +1502,68 @@ Handlers
 
     The return value of this handler is ignored.
 
+
+Requests and custom requests
+----------------------------
+
+Upon receiving a message to run a task, the `worker <guide-workers>`:ref:
+creates a `request <celery.worker.request.Request>`:class: to represent such
+demand.
+
+Custom task classes may override which request class to use by changing the
+attribute `celery.app.task.Task.Request`:attr:.  You may either assign the
+custom request class itself, or its fully qualified name.
+
+The request has several responsibilities.  Custom request classes should cover
+them all -- they are responsible for actually running and tracing the task.  We
+strongly recommend inheriting from `celery.worker.request.Request`:class:.
+
+When using the `pre-forking worker <worker-concurrency>`:ref:, the methods
+`~celery.worker.request.Request.on_timeout`:meth: and
+`~celery.worker.request.Request.on_failure`:meth: are executed in the main
+worker process.  An application may leverage this facility to detect failures
+that are not detected using `celery.app.task.Task.on_failure`:meth:.
+
+As an example, the following custom request detects and logs hard time
+limits, and other failures.
+
+.. code-block:: python
+
+   import logging
+   from celery import Task
+   from celery.worker.request import Request
+
+   logger = logging.getLogger('my.package')
+
+   class MyRequest(Request):
+       'A minimal custom request to log failures and hard time limits.'
+
+       def on_timeout(self, soft, timeout):
+           super(MyRequest, self).on_timeout(soft, timeout)
+           if not soft:
+               logger.warning(
+                   'A hard timeout was enforced for task %s',
+                   self.task.name
+               )
+
+       def on_failure(self, exc_info, send_failed_event=True, return_ok=False):
+           super(MyRequest, self).on_failure(
+               exc_info,
+               send_failed_event=send_failed_event,
+               return_ok=return_ok
+           )
+           logger.warning(
+               'Failure detected for task %s',
+               self.task.name
+           )
+
+   class MyTask(Task):
+       Request = MyRequest  # you can use a FQN 'my.package:MyRequest'
+
+   @app.task(base=MyTask)
+   def some_longrunning_task():
+       pass  # use your imagination
+
+
 .. _task-how-they-work:
 
 How it works

+ 34 - 1
t/unit/worker/test_strategy.py

@@ -8,9 +8,14 @@ from contextlib import contextmanager
 from case import Mock, patch
 from kombu.utils.limits import TokenBucket
 
+from celery import Task
 from celery.exceptions import InvalidTaskError
 from celery.worker import state
-from celery.worker.strategy import proto1_to_proto2
+from celery.worker.strategy import (
+    proto1_to_proto2,
+    default as default_strategy
+)
+from celery.worker.request import Request
 from celery.utils.time import rate
 
 
@@ -114,6 +119,7 @@ class test_default_strategy_proto2:
     def _context(self, sig,
                  rate_limits=True, events=True, utc=True, limit=None):
         assert sig.type.Strategy
+        assert sig.type.Request
 
         reserved = Mock()
         consumer = Mock()
@@ -214,3 +220,30 @@ class test_default_strategy_proto1__no_utc(test_default_strategy_proto2):
     def prepare_message(self, message):
         message.payload['utc'] = False
         return message
+
+
+class test_custom_request_for_default_strategy(test_default_strategy_proto2):
+    def test_custom_request_gets_instantiated(self):
+        _MyRequest = Mock(name='MyRequest')
+
+        class MyRequest(Request):
+            def __init__(self, *args, **kwargs):
+                Request.__init__(self, *args, **kwargs)
+                _MyRequest()
+
+        class MyTask(Task):
+            Request = MyRequest
+
+        @self.app.task(base=MyTask)
+        def failed():
+            raise AssertionError
+
+        sig = failed.s()
+        with self._context(sig) as C:
+            task_message_handler = default_strategy(
+                failed,
+                self.app,
+                C.consumer
+            )
+            task_message_handler(C.message, None, None, None, None)
+            _MyRequest.assert_called()