123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657 |
- from __future__ import absolute_import
- from __future__ import with_statement
- import errno
- import os
- import resource
- import signal
- from mock import Mock, patch
- from celery import current_app
- from celery import platforms
- from celery.platforms import (
- get_fdmax,
- shellsplit,
- ignore_errno,
- set_process_title,
- signals,
- maybe_drop_privileges,
- setuid,
- setgid,
- seteuid,
- setegid,
- initgroups,
- parse_uid,
- parse_gid,
- detached,
- DaemonContext,
- create_pidlock,
- PIDFile,
- LockFailed,
- setgroups,
- _setgroups_hack
- )
- from celery.tests.utils import Case, WhateverIO, override_stdouts
class test_ignore_errno(Case):
    """``ignore_errno`` must swallow only the error codes it was given."""

    @staticmethod
    def _oserror(code):
        # Build a bare OSError that carries only the given errno.
        err = OSError()
        err.errno = code
        return err

    def test_raises_EBADF(self):
        # EBADF is in the ignore list, so this raise must be suppressed.
        with ignore_errno('EBADF'):
            raise self._oserror(errno.EBADF)

    def test_otherwise(self):
        # ENOENT is not in the ignore list, so it must propagate.
        with self.assertRaises(OSError):
            with ignore_errno('EBADF'):
                raise self._oserror(errno.ENOENT)
class test_shellsplit(Case):
    """``shellsplit`` tokenizes a command line like a shell would."""

    def test_split(self):
        # Quoted words lose their quotes but remain single tokens.
        expected = ['the', 'quick', 'brown', 'fox']
        self.assertEqual(shellsplit("the 'quick' brown fox"), expected)
class test_set_process_title(Case):
    """``set_process_title`` must not fail when setproctitle is absent."""

    def test_when_no_setps(self):
        # Temporarily pretend the optional setproctitle module is missing.
        # Tuple-unpack so that ``prev`` receives the real module object and
        # ``platforms._setproctitle`` receives None (a chained assignment
        # would bind the whole tuple to both names).
        prev, platforms._setproctitle = platforms._setproctitle, None
        try:
            set_process_title('foo')
        finally:
            platforms._setproctitle = prev

    # Backward-compatible alias for the old (non-collected) method name.
    when_no_setps = test_when_no_setps
class test_Signals(Case):
    """Tests for the ``signals`` helper (lookup, install, ignore)."""

    @patch('signal.getsignal')
    def test_getitem(self, getsignal):
        signals['SIGINT']
        getsignal.assert_called_with(signal.SIGINT)

    def test_supported(self):
        self.assertTrue(signals.supported('INT'))
        self.assertFalse(signals.supported('SIGIMAGINARY'))

    def test_signum(self):
        self.assertEqual(signals.signum(13), 13)
        self.assertEqual(signals.signum('INT'), signal.SIGINT)
        self.assertEqual(signals.signum('SIGINT'), signal.SIGINT)
        # Each invalid input needs its own assertRaises: a single context
        # exits at the first raise, so any later call would never execute.
        with self.assertRaises(TypeError):
            signals.signum('int')
        with self.assertRaises(TypeError):
            signals.signum(object())

    @patch('signal.signal')
    def test_ignore(self, set_):
        signals.ignore('SIGINT')
        set_.assert_called_with(signals.signum('INT'), signals.ignored)
        signals.ignore('SIGTERM')
        set_.assert_called_with(signals.signum('TERM'), signals.ignored)

    @patch('signal.signal')
    def test_setitem(self, set_):
        handle = lambda *a: a
        signals['INT'] = handle
        set_.assert_called_with(signal.SIGINT, handle)

    @patch('signal.signal')
    def test_setitem_raises(self, set_):
        # A ValueError from signal.signal() must not escape the assignment.
        set_.side_effect = ValueError()
        signals['INT'] = lambda *a: a
        # Make sure the failing call was actually reached.
        self.assertTrue(set_.called)
if not current_app.IS_WINDOWS:

    class test_get_fdmax(Case):
        """``get_fdmax`` is derived from item 1 of ``resource.getrlimit()``."""

        @patch('resource.getrlimit')
        def test_when_infinity(self, getrlimit):
            # An unlimited value makes get_fdmax return the given default.
            sentinel = object()
            getrlimit.return_value = [None, resource.RLIM_INFINITY]
            self.assertIs(get_fdmax(sentinel), sentinel)

        @patch('resource.getrlimit')
        def test_when_actual(self, getrlimit):
            # A finite value is returned as-is.
            getrlimit.return_value = [None, 13]
            self.assertEqual(get_fdmax(None), 13)
if not current_app.IS_WINDOWS:

    class test_maybe_drop_privileges(Case):
        """Tests for ``maybe_drop_privileges`` with uid and/or gid given."""

        @patch('celery.platforms.parse_uid')
        @patch('pwd.getpwuid')
        @patch('celery.platforms.setgid')
        @patch('celery.platforms.setuid')
        @patch('celery.platforms.initgroups')
        def test_with_uid(self, initgroups, setuid, setgid,
                          getpwuid, parse_uid):
            # Only a uid is given, so the gid comes from the pw_gid of the
            # (faked) passwd entry for that uid.
            uid, gid = 5001, 50001
            getpwuid.return_value = type('pw_struct', (object, ),
                                         {'pw_gid': gid})()
            parse_uid.return_value = uid

            maybe_drop_privileges(uid='user')

            parse_uid.assert_called_with('user')
            getpwuid.assert_called_with(uid)
            setgid.assert_called_with(gid)
            initgroups.assert_called_with(uid, gid)
            setuid.assert_called_with(uid)

        @patch('celery.platforms.parse_uid')
        @patch('celery.platforms.parse_gid')
        @patch('celery.platforms.setgid')
        @patch('celery.platforms.setuid')
        @patch('celery.platforms.initgroups')
        def test_with_guid(self, initgroups, setuid, setgid,
                           parse_gid, parse_uid):
            # Both ids are given explicitly.
            uid, gid = 5001, 50001
            parse_uid.return_value = uid
            parse_gid.return_value = gid

            maybe_drop_privileges(uid='user', gid='group')

            parse_uid.assert_called_with('user')
            parse_gid.assert_called_with('group')
            setgid.assert_called_with(gid)
            initgroups.assert_called_with(uid, gid)
            setuid.assert_called_with(uid)

        @patch('celery.platforms.setuid')
        @patch('celery.platforms.setgid')
        @patch('celery.platforms.parse_gid')
        def test_only_gid(self, parse_gid, setgid, setuid):
            # Only a gid is given: the gid is set, the uid is left alone.
            gid = 50001
            parse_gid.return_value = gid

            maybe_drop_privileges(gid='group')

            parse_gid.assert_called_with('group')
            setgid.assert_called_with(gid)
            self.assertFalse(setuid.called)
if not current_app.IS_WINDOWS:

    class test_setget_uid_gid(Case):
        """Tests for the uid/gid setters and the name-to-id parsers."""

        @patch('celery.platforms.parse_uid')
        @patch('os.setuid')
        def test_setuid(self, os_setuid, parse_uid):
            parse_uid.return_value = 5001
            setuid('user')
            parse_uid.assert_called_with('user')
            os_setuid.assert_called_with(5001)

        @patch('celery.platforms.parse_uid')
        @patch('os.geteuid')
        @patch('os.seteuid')
        def test_seteuid(self, os_seteuid, os_geteuid, parse_uid):
            parse_uid.return_value = 5001
            # Already at the target euid: os.seteuid must not be called.
            os_geteuid.return_value = 5001
            seteuid('user')
            parse_uid.assert_called_with('user')
            self.assertFalse(os_seteuid.called)
            # At a different euid: os.seteuid must be called.
            os_geteuid.return_value = 1
            seteuid('user')
            os_seteuid.assert_called_with(5001)

        @patch('celery.platforms.parse_gid')
        @patch('os.setgid')
        def test_setgid(self, os_setgid, parse_gid):
            parse_gid.return_value = 50001
            setgid('group')
            parse_gid.assert_called_with('group')
            os_setgid.assert_called_with(50001)

        @patch('celery.platforms.parse_gid')
        @patch('os.getegid')
        @patch('os.setegid')
        def test_setegid(self, os_setegid, os_getegid, parse_gid):
            parse_gid.return_value = 50001
            # Already at the target egid: os.setegid must not be called.
            os_getegid.return_value = 50001
            setegid('group')
            parse_gid.assert_called_with('group')
            self.assertFalse(os_setegid.called)
            # At a different egid: os.setegid must be called.
            os_getegid.return_value = 1
            setegid('group')
            os_setegid.assert_called_with(50001)

        def test_parse_uid_when_int(self):
            self.assertEqual(parse_uid(5001), 5001)

        @patch('pwd.getpwnam')
        def test_parse_uid_when_existing_name(self, getpwnam):
            getpwnam.return_value = type('pwent', (object, ),
                                         {'pw_uid': 5001})()
            self.assertEqual(parse_uid('user'), 5001)

        @patch('pwd.getpwnam')
        def test_parse_uid_when_nonexisting_name(self, getpwnam):
            # An unknown user name surfaces as KeyError.
            getpwnam.side_effect = KeyError('user')
            with self.assertRaises(KeyError):
                parse_uid('user')

        def test_parse_gid_when_int(self):
            self.assertEqual(parse_gid(50001), 50001)

        @patch('grp.getgrnam')
        def test_parse_gid_when_existing_name(self, getgrnam):
            getgrnam.return_value = type('grent', (object, ),
                                         {'gr_gid': 50001})()
            self.assertEqual(parse_gid('group'), 50001)

        @patch('grp.getgrnam')
        def test_parse_gid_when_nonexisting_name(self, getgrnam):
            # An unknown group name surfaces as KeyError.
            getgrnam.side_effect = KeyError('group')
            with self.assertRaises(KeyError):
                parse_gid('group')
if not current_app.IS_WINDOWS:

    class test_initgroups(Case):
        """``initgroups`` with and without a native ``os.initgroups``."""

        @patch('pwd.getpwuid')
        @patch('os.initgroups', create=True)
        def test_with_initgroups(self, os_initgroups, getpwuid):
            # getpwuid() is faked as a one-item sequence holding the name.
            getpwuid.return_value = ['user']
            initgroups(5001, 50001)
            os_initgroups.assert_called_with('user', 50001)

        @patch('celery.platforms.setgroups')
        @patch('grp.getgrall')
        @patch('pwd.getpwuid')
        def test_without_initgroups(self, getpwuid, getgrall, setgroups):
            # Hide os.initgroups so the grp-based fallback is exercised;
            # it is put back in the finally clause below.
            saved = getattr(os, 'initgroups', None)
            try:
                delattr(os, 'initgroups')
            except AttributeError:
                pass
            try:
                getpwuid.return_value = ['user']

                class grent(object):
                    gr_mem = ['user']

                    def __init__(self, gid):
                        self.gr_gid = gid

                getgrall.return_value = [grent(gid) for gid in (1, 2, 3)]
                initgroups(5001, 50001)
                setgroups.assert_called_with([1, 2, 3])
            finally:
                if saved is not None:
                    os.initgroups = saved
if not current_app.IS_WINDOWS:

    class test_detached(Case):
        """Tests for ``detached()``, which returns a ``DaemonContext``."""

        def test_without_resource(self):
            # Without the resource module, detaching raises RuntimeError.
            saved, platforms.resource = platforms.resource, None
            try:
                with self.assertRaises(RuntimeError):
                    detached()
            finally:
                platforms.resource = saved

        @patch('celery.platforms._create_pidlock')
        @patch('celery.platforms.signals')
        @patch('celery.platforms.maybe_drop_privileges')
        @patch('os.geteuid')
        @patch('__builtin__.open')
        def test_default(self, open_, geteuid, maybe_drop,
                         signals, pidlock):
            # euid 0 (root): SIGCLD is reset and privileges are dropped.
            geteuid.return_value = 0
            ctx = detached(uid='user', gid='group')
            self.assertIsInstance(ctx, DaemonContext)
            signals.reset.assert_called_with('SIGCLD')
            maybe_drop.assert_called_with(uid='user', gid='group')

            # Non-root with a logfile: it is opened in append mode and
            # closed again.
            open_.return_value = Mock()
            geteuid.return_value = 5001
            ctx = detached(uid='user', gid='group', logfile='/foo/bar')
            self.assertIsInstance(ctx, DaemonContext)
            open_.assert_called_with('/foo/bar', 'a')
            open_.return_value.close.assert_called_with()

            # With a pidfile: a pidlock is created for that path.
            ctx = detached(pidfile='/foo/bar/pid')
            self.assertIsInstance(ctx, DaemonContext)
            pidlock.assert_called_with('/foo/bar/pid')
if not current_app.IS_WINDOWS:

    class test_DaemonContext(Case):
        """Tests for entering/leaving ``DaemonContext`` with os.* mocked."""

        @patch('os.fork')
        @patch('os.setsid')
        @patch('os._exit')
        @patch('os.chdir')
        @patch('os.umask')
        @patch('os.close')
        @patch('os.open')
        @patch('os.dup2')
        def test_open(self, dup2, os_open, close, umask, chdir, _exit,
                      setsid, fork):
            # Child side: fork() returns 0, so the process carries on.
            ctx = DaemonContext(workdir='/opt/workdir')
            fork.return_value = 0
            with ctx:
                self.assertTrue(ctx._is_open)
                # Re-entering while open: total fork count stays at 2.
                with ctx:
                    pass
            self.assertEqual(fork.call_count, 2)
            setsid.assert_called_with()
            self.assertFalse(_exit.called)
            chdir.assert_called_with(ctx.workdir)
            umask.assert_called_with(ctx.umask)
            self.assertTrue(dup2.called)

            # Parent side: fork() returns non-zero and os._exit(0) follows.
            fork.reset_mock()
            fork.return_value = 1
            ctx = DaemonContext(workdir='/opt/workdir')
            with ctx:
                pass
            self.assertEqual(fork.call_count, 1)
            _exit.assert_called_with(0)

            # fake=True: _detach is never invoked.
            ctx = DaemonContext(workdir='/opt/workdir', fake=True)
            ctx._detach = Mock()
            with ctx:
                pass
            self.assertFalse(ctx._detach.called)
if not current_app.IS_WINDOWS:

    class test_PIDFile(Case):
        """Tests for ``PIDFile`` and ``create_pidlock``."""

        @staticmethod
        def _errno_exc(code, exc_type=OSError):
            # Build a bare exception of *exc_type* carrying only an errno.
            exc = exc_type()
            exc.errno = code
            return exc

        @staticmethod
        def _fake_file(mock_open, content):
            # Make a patched open() return an in-memory file pre-filled
            # with *content* and rewound to the start.
            fh = mock_open.return_value = WhateverIO()
            fh.write(content)
            fh.seek(0)
            return fh

        @patch('celery.platforms.PIDFile')
        def test_create_pidlock(self, PIDFile_):
            lock = PIDFile_.return_value = Mock()
            lock.is_locked.return_value = True
            # Locked and not removable as stale: SystemExit is raised.
            lock.remove_if_stale.return_value = False
            with self.assertRaises(SystemExit):
                create_pidlock('/var/pid')
            # Locked but removable as stale: the PIDFile is returned.
            lock.remove_if_stale.return_value = True
            self.assertIs(create_pidlock('/var/pid'), lock)

        def test_context(self):
            pidfile = PIDFile('/var/pid')
            pidfile.write_pid = Mock()
            pidfile.remove = Mock()
            with pidfile as entered:
                self.assertIs(entered, pidfile)
            pidfile.write_pid.assert_called_with()
            pidfile.remove.assert_called_with()

        def test_acquire_raises_LockFailed(self):
            pidfile = PIDFile('/var/pid')
            pidfile.write_pid = Mock(side_effect=OSError())
            with self.assertRaises(LockFailed):
                with pidfile:
                    pass

        @patch('os.path.exists')
        def test_is_locked(self, exists):
            pidfile = PIDFile('/var/pid')
            exists.return_value = True
            self.assertTrue(pidfile.is_locked())
            exists.return_value = False
            self.assertFalse(pidfile.is_locked())

        @patch('__builtin__.open')
        def test_read_pid(self, open_):
            self._fake_file(open_, '1816\n')
            pidfile = PIDFile('/var/pid')
            self.assertEqual(pidfile.read_pid(), 1816)

        @patch('__builtin__.open')
        def test_read_pid_partially_written(self, open_):
            # Content without a trailing newline is rejected.
            self._fake_file(open_, '1816')
            pidfile = PIDFile('/var/pid')
            with self.assertRaises(ValueError):
                pidfile.read_pid()

        @patch('__builtin__.open')
        def test_read_pid_raises_ENOENT(self, open_):
            # A missing pidfile reads as None.
            open_.side_effect = self._errno_exc(errno.ENOENT, IOError)
            pidfile = PIDFile('/var/pid')
            self.assertIsNone(pidfile.read_pid())

        @patch('__builtin__.open')
        def test_read_pid_raises_IOError(self, open_):
            # Any other I/O error propagates.
            open_.side_effect = self._errno_exc(errno.EAGAIN, IOError)
            pidfile = PIDFile('/var/pid')
            with self.assertRaises(IOError):
                pidfile.read_pid()

        @patch('__builtin__.open')
        def test_read_pid_bogus_pidfile(self, open_):
            self._fake_file(open_, 'eighteensixteen\n')
            pidfile = PIDFile('/var/pid')
            with self.assertRaises(ValueError):
                pidfile.read_pid()

        @patch('os.unlink')
        def test_remove(self, unlink):
            unlink.return_value = True
            pidfile = PIDFile('/var/pid')
            pidfile.remove()
            unlink.assert_called_with(pidfile.path)

        @patch('os.unlink')
        def test_remove_ENOENT(self, unlink):
            # ENOENT from unlink() does not propagate out of remove().
            unlink.side_effect = self._errno_exc(errno.ENOENT)
            pidfile = PIDFile('/var/pid')
            pidfile.remove()
            unlink.assert_called_with(pidfile.path)

        @patch('os.unlink')
        def test_remove_EACCES(self, unlink):
            # EACCES from unlink() does not propagate out of remove().
            unlink.side_effect = self._errno_exc(errno.EACCES)
            pidfile = PIDFile('/var/pid')
            pidfile.remove()
            unlink.assert_called_with(pidfile.path)

        @patch('os.unlink')
        def test_remove_OSError(self, unlink):
            # Other errnos from unlink() propagate.
            unlink.side_effect = self._errno_exc(errno.EAGAIN)
            pidfile = PIDFile('/var/pid')
            with self.assertRaises(OSError):
                pidfile.remove()
            unlink.assert_called_with(pidfile.path)

        @patch('os.kill')
        def test_remove_if_stale_process_alive(self, kill):
            pidfile = PIDFile('/var/pid')
            pidfile.read_pid = Mock(return_value=1816)
            # kill(pid, 0) succeeds: not stale.
            kill.return_value = 0
            self.assertFalse(pidfile.remove_if_stale())
            kill.assert_called_with(1816, 0)
            pidfile.read_pid.assert_called_with()
            # kill() failing with ENOENT is not treated as stale either.
            kill.side_effect = self._errno_exc(errno.ENOENT)
            self.assertFalse(pidfile.remove_if_stale())

        @patch('os.kill')
        def test_remove_if_stale_process_dead(self, kill):
            with override_stdouts():
                pidfile = PIDFile('/var/pid')
                pidfile.read_pid = Mock(return_value=1816)
                pidfile.remove = Mock()
                # kill() failing with ESRCH: stale, so the file is removed.
                kill.side_effect = self._errno_exc(errno.ESRCH)
                self.assertTrue(pidfile.remove_if_stale())
                kill.assert_called_with(1816, 0)
                pidfile.remove.assert_called_with()

        def test_remove_if_stale_broken_pid(self):
            with override_stdouts():
                pidfile = PIDFile('/var/pid')
                pidfile.read_pid = Mock(side_effect=ValueError())
                pidfile.remove = Mock()
                # An unparsable pidfile is removed.
                self.assertTrue(pidfile.remove_if_stale())
                pidfile.remove.assert_called_with()

        def test_remove_if_stale_no_pidfile(self):
            pidfile = PIDFile('/var/pid')
            pidfile.read_pid = Mock(return_value=None)
            pidfile.remove = Mock()
            self.assertTrue(pidfile.remove_if_stale())
            pidfile.remove.assert_called_with()

        @patch('os.fsync')
        @patch('os.getpid')
        @patch('os.open')
        @patch('os.fdopen')
        @patch('__builtin__.open')
        def test_write_pid(self, open_, fdopen, osopen, getpid, fsync):
            getpid.return_value = 1816
            osopen.return_value = 13
            written = fdopen.return_value = WhateverIO()
            written.close = Mock()
            # The file opened for re-reading holds the same pid.
            self._fake_file(open_, '1816\n')

            pidfile = PIDFile('/var/pid')
            pidfile.write_pid()

            written.seek(0)
            self.assertEqual(written.readline(), '1816\n')
            self.assertTrue(written.close.called)
            getpid.assert_called_with()
            osopen.assert_called_with(pidfile.path, platforms.PIDFILE_FLAGS,
                                      platforms.PIDFILE_MODE)
            fdopen.assert_called_with(13, 'w')
            fsync.assert_called_with(13)
            open_.assert_called_with(pidfile.path)

        @patch('os.fsync')
        @patch('os.getpid')
        @patch('os.open')
        @patch('os.fdopen')
        @patch('__builtin__.open')
        def test_write_reread_fails(self, open_, fdopen,
                                    osopen, getpid, fsync):
            getpid.return_value = 1816
            osopen.return_value = 13
            written = fdopen.return_value = WhateverIO()
            written.close = Mock()
            # The file opened for re-reading holds a different pid.
            self._fake_file(open_, '11816\n')
            pidfile = PIDFile('/var/pid')
            with self.assertRaises(LockFailed):
                pidfile.write_pid()
if not current_app.IS_WINDOWS:

    class test_setgroups(Case):
        """Tests for ``setgroups`` and the size-probing ``_setgroups_hack``."""

        @patch('os.setgroups', create=True)
        def test_setgroups_hack_ValueError(self, setgroups):

            def on_setgroups(groups):
                # Fake os.setgroups that accepts at most 200 groups.
                if len(groups) <= 200:
                    setgroups.return_value = True
                    return
                raise ValueError()
            setgroups.side_effect = on_setgroups
            _setgroups_hack(range(400))

            # When every call fails with ValueError, the error propagates.
            setgroups.side_effect = ValueError()
            with self.assertRaises(ValueError):
                _setgroups_hack(range(400))

        @patch('os.setgroups', create=True)
        def test_setgroups_hack_OSError(self, setgroups):
            einval = OSError()
            einval.errno = errno.EINVAL

            def on_setgroups(groups):
                # Fake os.setgroups that accepts at most 200 groups.
                if len(groups) <= 200:
                    setgroups.return_value = True
                    return
                raise einval
            setgroups.side_effect = on_setgroups
            _setgroups_hack(range(400))

            # When every call fails with EINVAL, the error propagates.
            setgroups.side_effect = einval
            with self.assertRaises(OSError):
                _setgroups_hack(range(400))

            # An OSError with a different errno (ESRCH) also propagates.
            # The errno must be set on the exception that is raised.
            esrch = OSError()
            esrch.errno = errno.ESRCH
            setgroups.side_effect = esrch
            with self.assertRaises(OSError):
                _setgroups_hack(range(400))

        @patch('os.sysconf')
        @patch('celery.platforms._setgroups_hack')
        def test_setgroups(self, hack, sysconf):
            # The request is truncated to the value reported by sysconf().
            sysconf.return_value = 100
            setgroups(range(400))
            hack.assert_called_with(range(100))

        @patch('os.sysconf')
        @patch('celery.platforms._setgroups_hack')
        def test_setgroups_sysconf_raises(self, hack, sysconf):
            # If sysconf() fails, the full list is passed through.
            sysconf.side_effect = ValueError()
            setgroups(range(400))
            hack.assert_called_with(range(400))

        @patch('os.getgroups')
        @patch('os.sysconf')
        @patch('celery.platforms._setgroups_hack')
        def test_setgroups_raises_ESRCH(self, hack, sysconf, getgroups):
            sysconf.side_effect = ValueError()
            esrch = OSError()
            esrch.errno = errno.ESRCH
            hack.side_effect = esrch
            with self.assertRaises(OSError):
                setgroups(range(400))

        @patch('os.getgroups')
        @patch('os.sysconf')
        @patch('celery.platforms._setgroups_hack')
        def test_setgroups_raises_EPERM(self, hack, sysconf, getgroups):
            sysconf.side_effect = ValueError()
            eperm = OSError()
            eperm.errno = errno.EPERM
            hack.side_effect = eperm

            # EPERM is tolerated when os.getgroups() matches the request.
            getgroups.return_value = range(400)
            setgroups(range(400))
            getgroups.assert_called_with()

            # EPERM propagates when os.getgroups() reports a group (1000)
            # that is not part of the request.
            getgroups.return_value = [1000]
            with self.assertRaises(OSError):
                setgroups(range(400))
            getgroups.assert_called_with()
|