瀏覽代碼

Cache the value of gethostname

Ask Solem 9 年之前
父節點
當前提交
46c17e9909

+ 4 - 4
celery/app/trace.py

@@ -17,7 +17,6 @@ from __future__ import absolute_import
 
 
 import logging
 import logging
 import os
 import os
-import socket
 import sys
 import sys
 
 
 from collections import namedtuple
 from collections import namedtuple
@@ -35,6 +34,7 @@ from celery.app import set_default_app
 from celery.app.task import Task as BaseTask, Context
 from celery.app.task import Task as BaseTask, Context
 from celery.exceptions import Ignore, Reject, Retry, InvalidTaskError
 from celery.exceptions import Ignore, Reject, Retry, InvalidTaskError
 from celery.five import monotonic
 from celery.five import monotonic
+from celery.utils import gethostname
 from celery.utils.log import get_logger
 from celery.utils.log import get_logger
 from celery.utils.objects import mro_lookup
 from celery.utils.objects import mro_lookup
 from celery.utils.saferepr import saferepr
 from celery.utils.saferepr import saferepr
@@ -273,7 +273,7 @@ def build_tracer(name, task, loader=None, hostname=None, store_errors=True,
     track_started = task.track_started
     track_started = task.track_started
     track_started = not eager and (task.track_started and not ignore_result)
     track_started = not eager and (task.track_started and not ignore_result)
     publish_result = not eager and not ignore_result
     publish_result = not eager and not ignore_result
-    hostname = hostname or socket.gethostname()
+    hostname = hostname or gethostname()
 
 
     loader_task_init = loader.on_task_init
     loader_task_init = loader.on_task_init
     loader_cleanup = loader.on_process_cleanup
     loader_cleanup = loader.on_process_cleanup
@@ -489,7 +489,7 @@ def _trace_task_ret(name, uuid, request, body, content_type,
         )
         )
     else:
     else:
         args, kwargs, embed = body
         args, kwargs, embed = body
-    hostname = socket.gethostname()
+    hostname = gethostname()
     request.update({
     request.update({
         'args': args, 'kwargs': kwargs,
         'args': args, 'kwargs': kwargs,
         'hostname': hostname, 'is_eager': False,
         'hostname': hostname, 'is_eager': False,
@@ -537,7 +537,7 @@ def report_internal_error(task, exc):
 def setup_worker_optimizations(app, hostname=None):
 def setup_worker_optimizations(app, hostname=None):
     global trace_task_ret
     global trace_task_ret
 
 
-    hostname = hostname or socket.gethostname()
+    hostname = hostname or gethostname()
 
 
     # make sure custom Task.__call__ methods that calls super
     # make sure custom Task.__call__ methods that calls super
     # will not mess up the request/task stack.
     # will not mess up the request/task stack.

+ 2 - 3
celery/bin/multi.py

@@ -100,7 +100,6 @@ import errno
 import os
 import os
 import shlex
 import shlex
 import signal
 import signal
-import socket
 import sys
 import sys
 
 
 from collections import OrderedDict, defaultdict, namedtuple
 from collections import OrderedDict, defaultdict, namedtuple
@@ -115,7 +114,7 @@ from celery import VERSION_BANNER
 from celery.five import items
 from celery.five import items
 from celery.platforms import Pidfile, IS_WINDOWS
 from celery.platforms import Pidfile, IS_WINDOWS
 from celery.utils import term
 from celery.utils import term
-from celery.utils import host_format, node_format, nodesplit
+from celery.utils import gethostname, host_format, node_format, nodesplit
 from celery.utils.text import pluralize
 from celery.utils.text import pluralize
 
 
 __all__ = ['MultiTool']
 __all__ = ['MultiTool']
@@ -480,7 +479,7 @@ def multi_args(p, cmd='celery worker', append='', prefix='', suffix=''):
     cmd = options.pop('--cmd', cmd)
     cmd = options.pop('--cmd', cmd)
     append = options.pop('--append', append)
     append = options.pop('--append', append)
     hostname = options.pop('--hostname',
     hostname = options.pop('--hostname',
-                           options.pop('-n', socket.gethostname()))
+                           options.pop('-n', gethostname()))
     prefix = options.pop('--prefix', prefix) or ''
     prefix = options.pop('--prefix', prefix) or ''
     suffix = options.pop('--suffix', suffix) or hostname
     suffix = options.pop('--suffix', suffix) or hostname
     suffix = '' if suffix in ('""', "''") else suffix
     suffix = '' if suffix in ('""', "''") else suffix

+ 1 - 1
celery/tests/bin/test_base.py

@@ -258,7 +258,7 @@ class test_Command(AppCase):
 
 
     def test_host_format(self):
     def test_host_format(self):
         cmd = MockCommand(app=self.app)
         cmd = MockCommand(app=self.app)
-        with patch('socket.gethostname') as hn:
+        with patch('celery.utils.gethostname') as hn:
             hn.return_value = 'blacktron.example.com'
             hn.return_value = 'blacktron.example.com'
             self.assertEqual(cmd.host_format(''), '')
             self.assertEqual(cmd.host_format(''), '')
             self.assertEqual(
             self.assertEqual(

+ 5 - 5
celery/tests/bin/test_multi.py

@@ -67,7 +67,7 @@ class test_NamespacedOptionParser(AppCase):
 
 
 class test_multi_args(AppCase):
 class test_multi_args(AppCase):
 
 
-    @patch('socket.gethostname')
+    @patch('celery.bin.multi.gethostname')
     def test_parse(self, gethostname):
     def test_parse(self, gethostname):
         gethostname.return_value = 'example.com'
         gethostname.return_value = 'example.com'
         p = NamespacedOptionParser([
         p = NamespacedOptionParser([
@@ -298,7 +298,7 @@ class test_MultiTool(AppCase):
         Pidfile.side_effect = pids
         Pidfile.side_effect = pids
 
 
     @patch('celery.bin.multi.Pidfile')
     @patch('celery.bin.multi.Pidfile')
-    @patch('socket.gethostname')
+    @patch('celery.bin.multi.gethostname')
     def test_getpids(self, gethostname, Pidfile):
     def test_getpids(self, gethostname, Pidfile):
         gethostname.return_value = 'e.com'
         gethostname.return_value = 'e.com'
         self.prepare_pidfile_for_getpids(Pidfile)
         self.prepare_pidfile_for_getpids(Pidfile)
@@ -336,7 +336,7 @@ class test_MultiTool(AppCase):
         nodes = self.t.getpids(p, 'celery worker', callback=None)
         nodes = self.t.getpids(p, 'celery worker', callback=None)
 
 
     @patch('celery.bin.multi.Pidfile')
     @patch('celery.bin.multi.Pidfile')
-    @patch('socket.gethostname')
+    @patch('celery.bin.multi.gethostname')
     @patch('celery.bin.multi.sleep')
     @patch('celery.bin.multi.sleep')
     def test_shutdown_nodes(self, slepp, gethostname, Pidfile):
     def test_shutdown_nodes(self, slepp, gethostname, Pidfile):
         gethostname.return_value = 'e.com'
         gethostname.return_value = 'e.com'
@@ -415,7 +415,7 @@ class test_MultiTool(AppCase):
         self.t.show(['foo', 'bar', 'baz'], 'celery worker')
         self.t.show(['foo', 'bar', 'baz'], 'celery worker')
         self.assertTrue(self.fh.getvalue())
         self.assertTrue(self.fh.getvalue())
 
 
-    @patch('socket.gethostname')
+    @patch('celery.bin.multi.gethostname')
     def test_get(self, gethostname):
     def test_get(self, gethostname):
         gethostname.return_value = 'e.com'
         gethostname.return_value = 'e.com'
         self.t.get(['xuzzy@e.com', 'foo', 'bar', 'baz'], 'celery worker')
         self.t.get(['xuzzy@e.com', 'foo', 'bar', 'baz'], 'celery worker')
@@ -423,7 +423,7 @@ class test_MultiTool(AppCase):
         self.t.get(['foo@e.com', 'foo', 'bar', 'baz'], 'celery worker')
         self.t.get(['foo@e.com', 'foo', 'bar', 'baz'], 'celery worker')
         self.assertTrue(self.fh.getvalue())
         self.assertTrue(self.fh.getvalue())
 
 
-    @patch('socket.gethostname')
+    @patch('celery.bin.multi.gethostname')
     def test_names(self, gethostname):
     def test_names(self, gethostname):
         gethostname.return_value = 'e.com'
         gethostname.return_value = 'e.com'
         self.t.names(['foo', 'bar', 'baz'], 'celery worker')
         self.t.names(['foo', 'bar', 'baz'], 'celery worker')

+ 7 - 4
celery/utils/__init__.py

@@ -26,6 +26,8 @@ from kombu.entity import Exchange, Queue
 from celery.exceptions import CPendingDeprecationWarning, CDeprecationWarning
 from celery.exceptions import CPendingDeprecationWarning, CDeprecationWarning
 from celery.five import WhateverIO, items, reraise, string_t
 from celery.five import WhateverIO, items, reraise, string_t
 
 
+from .functional import memoize
+
 __all__ = ['worker_direct', 'warn_deprecated', 'deprecated', 'lpmerge',
 __all__ = ['worker_direct', 'warn_deprecated', 'deprecated', 'lpmerge',
            'is_iterable', 'isatty', 'cry', 'maybe_reraise', 'strtobool',
            'is_iterable', 'isatty', 'cry', 'maybe_reraise', 'strtobool',
            'jsonify', 'gen_task_name', 'nodename', 'nodesplit',
            'jsonify', 'gen_task_name', 'nodename', 'nodesplit',
@@ -33,7 +35,6 @@ __all__ = ['worker_direct', 'warn_deprecated', 'deprecated', 'lpmerge',
 
 
 PY3 = sys.version_info[0] == 3
 PY3 = sys.version_info[0] == 3
 
 
-
 PENDING_DEPRECATION_FMT = """
 PENDING_DEPRECATION_FMT = """
     {description} is scheduled for deprecation in \
     {description} is scheduled for deprecation in \
     version {deprecation} and removal in version v{removal}. \
     version {deprecation} and removal in version v{removal}. \
@@ -63,6 +64,8 @@ NODENAME_SEP = '@'
 NODENAME_DEFAULT = 'celery'
 NODENAME_DEFAULT = 'celery'
 RE_FORMAT = re.compile(r'%(\w)')
 RE_FORMAT = re.compile(r'%(\w)')
 
 
+gethostname = memoize(1, Cache=dict)(socket.gethostname)
+
 
 
 def worker_direct(hostname):
 def worker_direct(hostname):
     """Return :class:`kombu.Queue` that is a direct route to
     """Return :class:`kombu.Queue` that is a direct route to
@@ -327,7 +330,7 @@ def nodename(name, hostname):
 
 
 def anon_nodename(hostname=None, prefix='gen'):
 def anon_nodename(hostname=None, prefix='gen'):
     return nodename(''.join([prefix, str(os.getpid())]),
     return nodename(''.join([prefix, str(os.getpid())]),
-                    hostname or socket.gethostname())
+                    hostname or gethostname())
 
 
 
 
 def nodesplit(nodename):
 def nodesplit(nodename):
@@ -340,7 +343,7 @@ def nodesplit(nodename):
 
 
 def default_nodename(hostname):
 def default_nodename(hostname):
     name, host = nodesplit(hostname or '')
     name, host = nodesplit(hostname or '')
-    return nodename(name or NODENAME_DEFAULT, host or socket.gethostname())
+    return nodename(name or NODENAME_DEFAULT, host or gethostname())
 
 
 
 
 def node_format(s, nodename, **extra):
 def node_format(s, nodename, **extra):
@@ -357,7 +360,7 @@ _fmt_process_index_with_prefix = partial(_fmt_process_index, '-', '')
 
 
 
 
 def host_format(s, host=None, name=None, **extra):
 def host_format(s, host=None, name=None, **extra):
-    host = host or socket.gethostname()
+    host = host or gethostname()
     hname, _, domain = host.partition('.')
     hname, _, domain = host.partition('.')
     name = name or hname
     name = name or hname
     keys = dict({
     keys = dict({

+ 2 - 2
celery/worker/consumer.py

@@ -14,7 +14,6 @@ import errno
 import kombu
 import kombu
 import logging
 import logging
 import os
 import os
-import socket
 
 
 from collections import defaultdict
 from collections import defaultdict
 from functools import partial
 from functools import partial
@@ -36,6 +35,7 @@ from celery import signals
 from celery.app.trace import build_tracer
 from celery.app.trace import build_tracer
 from celery.canvas import signature
 from celery.canvas import signature
 from celery.exceptions import InvalidTaskError, NotRegistered
 from celery.exceptions import InvalidTaskError, NotRegistered
+from celery.utils import gethostname
 from celery.utils.functional import noop
 from celery.utils.functional import noop
 from celery.utils.log import get_logger
 from celery.utils.log import get_logger
 from celery.utils.text import truncate
 from celery.utils.text import truncate
@@ -172,7 +172,7 @@ class Consumer(object):
         self.app = app
         self.app = app
         self.controller = controller
         self.controller = controller
         self.init_callback = init_callback
         self.init_callback = init_callback
-        self.hostname = hostname or socket.gethostname()
+        self.hostname = hostname or gethostname()
         self.pid = os.getpid()
         self.pid = os.getpid()
         self.pool = pool
         self.pool = pool
         self.timer = timer
         self.timer = timer

+ 2 - 3
celery/worker/request.py

@@ -10,7 +10,6 @@
 from __future__ import absolute_import, unicode_literals
 from __future__ import absolute_import, unicode_literals
 
 
 import logging
 import logging
-import socket
 import sys
 import sys
 
 
 from datetime import datetime
 from datetime import datetime
@@ -27,7 +26,7 @@ from celery.exceptions import (
 )
 )
 from celery.five import string
 from celery.five import string
 from celery.platforms import signals as _signals
 from celery.platforms import signals as _signals
-from celery.utils import cached_property
+from celery.utils import cached_property, gethostname
 from celery.utils.functional import noop
 from celery.utils.functional import noop
 from celery.utils.log import get_logger
 from celery.utils.log import get_logger
 from celery.utils.timeutils import maybe_iso8601, timezone, maybe_make_aware
 from celery.utils.timeutils import maybe_iso8601, timezone, maybe_make_aware
@@ -120,7 +119,7 @@ class Request(object):
         self.kwargsrepr = headers.get('kwargsrepr', '')
         self.kwargsrepr = headers.get('kwargsrepr', '')
         self.on_ack = on_ack
         self.on_ack = on_ack
         self.on_reject = on_reject
         self.on_reject = on_reject
-        self.hostname = hostname or socket.gethostname()
+        self.hostname = hostname or gethostname()
         self.eventer = eventer
         self.eventer = eventer
         self.connection_errors = connection_errors or ()
         self.connection_errors = connection_errors or ()
         self.task = task or self.app.tasks[type]
         self.task = task or self.app.tasks[type]