瀏覽代碼

celery command: Better error handling

Ask Solem 12 年之前
父節點
當前提交
66853a2b4c
共有 3 個文件被更改,包括 89 次插入46 次删除
  1. 47 13
      celery/bin/base.py
  2. 40 31
      celery/bin/celery.py
  3. 2 2
      celery/bin/graph.py

+ 47 - 13
celery/bin/base.py

@@ -73,6 +73,7 @@ import warnings
 
 from collections import defaultdict
 from heapq import heappush
+from inspect import getargspec
 from optparse import OptionParser, IndentedHelpFormatter, make_option as Option
 from pprint import pformat
 from types import ModuleType
@@ -104,14 +105,20 @@ find_sformat = re.compile(r'%(\w)')
 
 
 class Error(Exception):
+    status = EX_FAILURE
 
-    def __init__(self, reason, status=EX_FAILURE):
+    def __init__(self, reason, status=None):
         self.reason = reason
-        self.status = status
+        self.status = status if status is not None else self.status
         super(Error, self).__init__(reason, status)
 
     def __str__(self):
         return self.reason
+    __unicode__ = __str__
+
+
+class UsageError(Error):
+    status = EX_USAGE
 
 
 class Extensions(object):
@@ -162,6 +169,8 @@ class Command(object):
     :keyword get_app: Callable returning the current app if no app provided.
 
     """
+    Error = Error
+    UsageError = UsageError
     Parser = OptionParser
 
     #: Arg list used in help.
@@ -218,7 +227,8 @@ class Command(object):
     prog_name = 'celery'
 
     def __init__(self, app=None, get_app=None, no_color=False,
-                 stdout=None, stderr=None, quiet=False):
+                 stdout=None, stderr=None, quiet=False, on_error=None,
+                 on_usage_error=None):
         self.app = app
         self.get_app = get_app or self._get_default_app
         self.stdout = stdout or sys.stdout
@@ -228,19 +238,46 @@ class Command(object):
         self.quiet = quiet
         if not self.description:
             self.description = self.__doc__
+        if on_error:
+            self.on_error = on_error
+        if on_usage_error:
+            self.on_usage_error = on_usage_error
+
+    def run(self, *args, **options):
+        """This is the body of the command called by :meth:`handle_argv`."""
+        raise NotImplementedError('subclass responsibility')
+
+    def on_error(self, exc):
+        self.error(self.colored.red('Error: {0}'.format(exc)))
+
+    def on_usage_error(self, exc):
+        self.handle_error(exc)
+
+    def on_concurrency_setup(self):
+        pass
 
     def __call__(self, *args, **kwargs):
+        self.verify_args(args)
         try:
             ret = self.run(*args, **kwargs)
-        except Error as exc:
-            self.error(self.colored.red('Error: {0}'.format(exc)))
+            return ret if ret is not None else EX_OK
+        except self.UsageError as exc:
+            self.on_usage_error(exc)
+            return exc.status
+        except self.Error as exc:
+            self.on_error(exc)
             return exc.status
 
-        return ret if ret is not None else EX_OK
-
-    def run(self, *args, **options):
-        """This is the body of the command called by :meth:`handle_argv`."""
-        raise NotImplementedError('subclass responsibility')
+    def verify_args(self, given, _index=0):
+        S = getargspec(self.run)
+        _index = 1 if S.args and S.args[0] == 'self' else _index
+        required = S.args[_index:-len(S.defaults) if S.defaults else None]
+        missing = required[len(given):]
+        if missing:
+            raise self.UsageError('Missing required {0}: {1}'.format(
+                text.pluralize(len(missing), 'argument'),
+                ', '.join(missing)
+            ))
 
     def execute_from_commandline(self, argv=None):
         """Execute application from command-line.
@@ -272,9 +309,6 @@ class Command(object):
             maybe_patch_concurrency(argv, *pool_option)
             short_opts, long_opts = pool_option
 
-    def on_concurrency_setup(self):
-        pass
-
     def usage(self, command):
         return '%prog {0} [options] {self.args}'.format(command, self=self)
 

+ 40 - 31
celery/bin/celery.py

@@ -12,6 +12,7 @@ import anyjson
 import os
 import sys
 
+from functools import partial
 from importlib import import_module
 
 from celery.five import string_t, values
@@ -21,7 +22,7 @@ from celery.utils import text
 from celery.utils.timeutils import maybe_iso8601
 
 # Cannot use relative imports here due to a Windows issue (#1111).
-from celery.bin.base import Command, Error, Option, Extensions
+from celery.bin.base import Command, Option, Extensions
 
 # Import commands from other modules
 from celery.bin.amqp import amqp
@@ -63,11 +64,6 @@ def determine_exit_status(ret):
     return EX_OK if ret else EX_FAILURE
 
 
-def ensure_broadcast_supported(app):
-    if app.connection().transport.driver_type == 'sql':
-        raise Error('SQL broker transports does not support broadcast')
-
-
 def main(argv=None):
     # Fix for setuptools generated scripts, so that it will
     # work with multiprocessing fork emulation.
@@ -114,7 +110,7 @@ class list_(Command):
         try:
             bindings = management.get_bindings()
         except NotImplementedError:
-            raise Error('Your transport cannot list bindings.')
+            raise self.Error('Your transport cannot list bindings.')
 
         fmt = lambda q, e, r: self.out('{0:<28} {1:<28} {2}'.format(q, e, r))
         fmt('Queue', 'Exchange', 'Routing Key')
@@ -126,10 +122,12 @@ class list_(Command):
         topics = {'bindings': self.list_bindings}
         available = ', '.join(topics)
         if not what:
-            raise Error('You must specify one of {0}'.format(available))
+            raise self.UsageError(
+                'You must specify one of {0}'.format(available))
         if what not in topics:
-            raise Error('unknown topic {0!r} (choose one of: {1})'.format(
-                what, available))
+            raise self.UsageError(
+                'unknown topic {0!r} (choose one of: {1})'.format(
+                    what, available))
         with self.app.connection() as conn:
             self.app.amqp.TaskConsumer(conn).declare()
             topics[what](conn.manager)
@@ -300,19 +298,20 @@ class _RemoteControl(Command):
 
     def run(self, *args, **kwargs):
         if not args:
-            raise Error('Missing {0.name} method. See --help'.format(self))
+            raise self.UsageError(
+                'Missing {0.name} method. See --help'.format(self))
         return self.do_call_method(args, **kwargs)
 
     def do_call_method(self, args, **kwargs):
         method = args[0]
         if method == 'help':
-            raise Error("Did you mean '{0.name} --help'?".format(self))
+            raise self.Error("Did you mean '{0.name} --help'?".format(self))
         if method not in self.choices:
-            raise Error('Unknown {0.name} method {1}'.format(self, method))
-
-        ensure_broadcast_supported(self.app)
+            raise self.UsageError(
+                'Unknown {0.name} method {1}'.format(self, method))
 
-        ensure_broadcast_supported(self.app)
+        if self.app.connection().transport.driver_type == 'sql':
+            raise self.Error('Broadcast not supported by SQL broker transport')
 
         destination = kwargs.get('destination')
         timeout = kwargs.get('timeout') or self.choices[method][0]
@@ -328,8 +327,8 @@ class _RemoteControl(Command):
                           destination=destination,
                           callback=self.say_remote_command_reply)
         if not replies:
-            raise Error('No nodes replied within time constraint.',
-                        status=EX_UNAVAILABLE)
+            raise self.Error('No nodes replied within time constraint.',
+                             status=EX_UNAVAILABLE)
         return replies
 
     def say(self, direction, title, body=''):
@@ -453,8 +452,8 @@ class status(Command):
         )
         replies = I.run('ping', **kwargs)
         if not replies:
-            raise Error('No nodes replied within time constraint',
-                        status=EX_UNAVAILABLE)
+            raise self.Error('No nodes replied within time constraint',
+                             status=EX_UNAVAILABLE)
         nodecount = len(replies)
         if not kwargs.get('quiet', False):
             self.out('\n{0} {1} online.'.format(
@@ -492,11 +491,7 @@ class migrate(Command):
     def on_migrate_task(self, state, body, message):
         self.out(self.progress_fmt.format(state=state, body=body))
 
-    def run(self, *args, **kwargs):
-        if len(args) != 2:
-            # this never exits due to OptionParser.parse_options
-            self.run_from_argv(self.prog_name, ['migrate', '--help'])
-            raise SystemExit()
+    def run(self, source, destination, **kwargs):
         from kombu import Connection
         from celery.contrib.migrate import migrate_tasks
 
@@ -680,12 +675,26 @@ class CeleryCommand(Command):
             cls, argv = self.commands['help'], ['help']
         cls = self.commands.get(command) or self.commands['help']
         try:
-            return cls(app=self.app).run_from_argv(
-                self.prog_name, argv[1:], command=argv[0],
-            )
-        except (TypeError, Error), exc:
-            raise
-            return self.execute('help', argv)
+            return cls(
+                app=self.app, on_error=self.on_error,
+                on_usage_error=partial(self.on_usage_error, command=command),
+            ).run_from_argv(self.prog_name, argv[1:], command=argv[0])
+        except self.UsageError as exc:
+            self.on_usage_error(exc)
+            return exc.status
+        except self.Error as exc:
+            self.on_error(exc)
+            return exc.status
+
+    def on_usage_error(self, exc, command=None):
+        if command:
+            helps = '{self.prog_name} {command} --help'
+        else:
+            helps = '{self.prog_name} --help'
+        self.error(self.colored.magenta("Error: {0}".format(exc)))
+        self.error("""Please try '{0}'""".format(helps.format(
+            self=self, command=command,
+        )))
 
     def remove_options_at_beginning(self, argv, index=0):
         if argv:

+ 2 - 2
celery/bin/graph.py

@@ -13,7 +13,7 @@ from operator import itemgetter
 from celery.datastructures import DependencyGraph, GraphFormatter
 from celery.five import items
 
-from .base import Command, Error
+from .base import Command
 
 
 class graph(Command):
@@ -26,7 +26,7 @@ class graph(Command):
         map = {'bootsteps': self.bootsteps, 'workers': self.workers}
         not what and self.exit_help('graph')
         if what not in map:
-            raise Error('no graph {0} in {1}'.format(what, '|'.join(map)))
+            raise self.Error('no graph {0} in {1}'.format(what, '|'.join(map)))
         return map[what](*args, **kwargs)
 
     def bootsteps(self, *args, **kwargs):