Browse Source

Tests passing

Ask Solem 13 years ago
parent
commit
e5a4944a94

+ 5 - 4
celery/tests/test_worker/__init__.py

@@ -473,7 +473,7 @@ class test_Consumer(unittest.TestCase):
         items = [entry[2] for entry in self.eta_schedule.queue]
         items = [entry[2] for entry in self.eta_schedule.queue]
         found = 0
         found = 0
         for item in items:
         for item in items:
-            if item.args[0].task_name == foo_task.name:
+            if item.args[0].name == foo_task.name:
                 found = True
                 found = True
         self.assertTrue(found)
         self.assertTrue(found)
         self.assertTrue(l.task_consumer.qos.call_count)
         self.assertTrue(l.task_consumer.qos.call_count)
@@ -725,8 +725,9 @@ class test_WorkController(AppCase):
         from celery import Celery
         from celery import Celery
         from celery import signals
         from celery import signals
         from celery.app import _tls
         from celery.app import _tls
-        from celery.worker import process_initializer
-        from celery.worker import WORKER_SIGRESET, WORKER_SIGIGNORE
+        from celery.concurrency.processes import process_initializer
+        from celery.concurrency.processes import (WORKER_SIGRESET,
+                                                  WORKER_SIGIGNORE)
 
 
         def on_worker_process_init(**kwargs):
         def on_worker_process_init(**kwargs):
             on_worker_process_init.called = True
             on_worker_process_init.called = True
@@ -881,7 +882,7 @@ class test_WorkController(AppCase):
 
 
         state.Persistent = Mock()
         state.Persistent = Mock()
         try:
         try:
-            worker = self.create_worker(db="statefilename")
+            worker = self.create_worker(state_db="statefilename")
             self.assertTrue(worker._persistence)
             self.assertTrue(worker._persistence)
         finally:
         finally:
             state.Persistent = Persistent
             state.Persistent = Persistent

+ 6 - 6
celery/tests/test_worker/test_worker_job.py

@@ -29,7 +29,7 @@ from celery.task import task as task_dec
 from celery.task.base import Task
 from celery.task.base import Task
 from celery.utils import uuid
 from celery.utils import uuid
 from celery.utils.encoding import from_utf8, default_encode
 from celery.utils.encoding import from_utf8, default_encode
-from celery.worker.job import Request, execute_and_trace
+from celery.worker.job import Request, TaskRequest, execute_and_trace
 from celery.worker.state import revoked
 from celery.worker.state import revoked
 
 
 from celery.tests.compat import catch_warnings
 from celery.tests.compat import catch_warnings
@@ -430,7 +430,7 @@ class test_TaskRequest(unittest.TestCase):
             mytask.acks_late = False
             mytask.acks_late = False
 
 
     def test_from_message_invalid_kwargs(self):
     def test_from_message_invalid_kwargs(self):
-        body = dict(task="foo", id=1, args=(), kwargs="foo")
+        body = dict(task=mytask.name, id=1, args=(), kwargs="foo")
         with self.assertRaises(InvalidTaskError):
         with self.assertRaises(InvalidTaskError):
             TaskRequest.from_message(None, body)
             TaskRequest.from_message(None, body)
 
 
@@ -547,7 +547,7 @@ class test_TaskRequest(unittest.TestCase):
                           content_type="application/json",
                           content_type="application/json",
                           content_encoding="utf-8")
                           content_encoding="utf-8")
         tw = TaskRequest.from_message(m, m.decode())
         tw = TaskRequest.from_message(m, m.decode())
-        self.assertIsInstance(tw, TaskRequest)
+        self.assertIsInstance(tw, Request)
         self.assertEqual(tw.task_name, body["task"])
         self.assertEqual(tw.task_name, body["task"])
         self.assertEqual(tw.task_id, body["id"])
         self.assertEqual(tw.task_id, body["id"])
         self.assertEqual(tw.args, body["args"])
         self.assertEqual(tw.args, body["args"])
@@ -563,7 +563,7 @@ class test_TaskRequest(unittest.TestCase):
                           content_type="application/json",
                           content_type="application/json",
                           content_encoding="utf-8")
                           content_encoding="utf-8")
         tw = TaskRequest.from_message(m, m.decode())
         tw = TaskRequest.from_message(m, m.decode())
-        self.assertIsInstance(tw, TaskRequest)
+        self.assertIsInstance(tw, Request)
         self.assertEquals(tw.args, [])
         self.assertEquals(tw.args, [])
         self.assertEquals(tw.kwargs, {})
         self.assertEquals(tw.kwargs, {})
 
 
@@ -572,7 +572,7 @@ class test_TaskRequest(unittest.TestCase):
         m = Message(None, body=anyjson.serialize(body), backend="foo",
         m = Message(None, body=anyjson.serialize(body), backend="foo",
                           content_type="application/json",
                           content_type="application/json",
                           content_encoding="utf-8")
                           content_encoding="utf-8")
-        with self.assertRaises(InvalidTaskError):
+        with self.assertRaises(KeyError):
             TaskRequest.from_message(m, m.decode())
             TaskRequest.from_message(m, m.decode())
 
 
     def test_from_message_nonexistant_task(self):
     def test_from_message_nonexistant_task(self):
@@ -665,7 +665,7 @@ class test_TaskRequest(unittest.TestCase):
                     "task_id": tw.task_id,
                     "task_id": tw.task_id,
                     "task_retries": 0,
                     "task_retries": 0,
                     "task_is_eager": False,
                     "task_is_eager": False,
-                    "delivery_info": {},
+                    "delivery_info": {"exchange": None, "routing_key": None},
                     "task_name": tw.task_name})
                     "task_name": tw.task_name})
 
 
     def _test_on_failure(self, exception):
     def _test_on_failure(self, exception):

+ 4 - 3
celery/worker/__init__.py

@@ -24,7 +24,7 @@ from .. import beat
 from .. import concurrency as _concurrency
 from .. import concurrency as _concurrency
 from .. import registry, signals
 from .. import registry, signals
 from ..app import app_or_default
 from ..app import app_or_default
-from ..app.abstract import configured, from_config
+from ..app.abstract import configurated, from_config
 from ..exceptions import SystemTerminate
 from ..exceptions import SystemTerminate
 from ..log import SilenceRepeated
 from ..log import SilenceRepeated
 from ..utils import noop, instantiate
 from ..utils import noop, instantiate
@@ -65,6 +65,7 @@ class WorkController(configurated):
 
 
     _state = None
     _state = None
     _running = 0
     _running = 0
+    _persistence = None
 
 
     def __init__(self, loglevel=None, hostname=None, logger=None,
     def __init__(self, loglevel=None, hostname=None, logger=None,
             ready_callback=noop, embed_clockservice=False, autoscale=None,
             ready_callback=noop, embed_clockservice=False, autoscale=None,
@@ -106,7 +107,7 @@ class WorkController(configurated):
         # Threads + Pool + Consumer
         # Threads + Pool + Consumer
         self.autoscaler = None
         self.autoscaler = None
         max_concurrency = None
         max_concurrency = None
-        min_concurrency = concurrency
+        min_concurrency = self.concurrency
         if autoscale:
         if autoscale:
             max_concurrency, min_concurrency = autoscale
             max_concurrency, min_concurrency = autoscale
 
 
@@ -139,7 +140,7 @@ class WorkController(configurated):
                                         logger=self.logger)
                                         logger=self.logger)
 
 
         self.scheduler = instantiate(self.eta_scheduler_cls,
         self.scheduler = instantiate(self.eta_scheduler_cls,
-                                precision=eta_scheduler_precision,
+                                precision=self.eta_scheduler_precision,
                                 on_error=self.on_timer_error,
                                 on_error=self.on_timer_error,
                                 on_tick=self.on_timer_tick)
                                 on_tick=self.on_timer_tick)
 
 

+ 9 - 5
celery/worker/job.py

@@ -70,7 +70,7 @@ class Request(object):
                  "_does_debug", "_does_info", "request_dict",
                  "_does_debug", "_does_info", "request_dict",
                  "acknowledged", "success_msg", "error_msg",
                  "acknowledged", "success_msg", "error_msg",
                  "retry_msg", "time_start", "worker_pid",
                  "retry_msg", "time_start", "worker_pid",
-                 "_already_revoked", "_terminate_on_ack", "_tzinfo")
+                 "_already_revoked", "_terminate_on_ack", "_tzlocal")
 
 
     #: Format string used to log task success.
     #: Format string used to log task success.
     success_msg = """\
     success_msg = """\
@@ -97,6 +97,7 @@ class Request(object):
         self.kwargs = body.get("kwargs", {})
         self.kwargs = body.get("kwargs", {})
         eta = body.get("eta")
         eta = body.get("eta")
         expires = body.get("expires")
         expires = body.get("expires")
+        utc = body.get("utc", False)
         self.on_ack = on_ack
         self.on_ack = on_ack
         self.hostname = hostname or socket.gethostname()
         self.hostname = hostname or socket.gethostname()
         self.logger = logger or self.app.log.get_default_logger()
         self.logger = logger or self.app.log.get_default_logger()
@@ -105,7 +106,7 @@ class Request(object):
         self.task = task or tasks[name]
         self.task = task or tasks[name]
         self.acknowledged = self._already_revoked = False
         self.acknowledged = self._already_revoked = False
         self.time_start = self.worker_pid = self._terminate_on_ack = None
         self.time_start = self.worker_pid = self._terminate_on_ack = None
-        self._tzinfo = None
+        self._tzlocal = None
 
 
         # timezone means the message is timezone-aware, and the only timezone
         # timezone means the message is timezone-aware, and the only timezone
         # supported at this point is UTC.
         # supported at this point is UTC.
@@ -121,6 +122,7 @@ class Request(object):
         else:
         else:
             self.expires = None
             self.expires = None
 
 
+        delivery_info = {} if delivery_info is None else delivery_info
         self.delivery_info = {
         self.delivery_info = {
             "exchange": delivery_info.get("exchange"),
             "exchange": delivery_info.get("exchange"),
             "routing_key": delivery_info.get("routing_key"),
             "routing_key": delivery_info.get("routing_key"),
@@ -141,7 +143,9 @@ class Request(object):
     @classmethod
     @classmethod
     def from_message(cls, message, body, **kwargs):
     def from_message(cls, message, body, **kwargs):
         # should be deprecated
         # should be deprecated
-        return cls(body, delivery_info=message.delivery_info, **kwargs)
+        return Request(body,
+                   delivery_info=getattr(message, "delivery_info", None),
+                   **kwargs)
 
 
     def extend_with_default_kwargs(self, loglevel, logfile):
     def extend_with_default_kwargs(self, loglevel, logfile):
         """Extend the tasks keyword arguments with standard task arguments.
         """Extend the tasks keyword arguments with standard task arguments.
@@ -160,7 +164,7 @@ class Request(object):
                           "loglevel": loglevel,
                           "loglevel": loglevel,
                           "task_id": self.id,
                           "task_id": self.id,
                           "task_name": self.name,
                           "task_name": self.name,
-                          "task_retries": self.request_dict["retries"],
+                          "task_retries": self.request_dict.get("retries", 0),
                           "task_is_eager": False,
                           "task_is_eager": False,
                           "delivery_info": self.delivery_info}
                           "delivery_info": self.delivery_info}
         fun = self.task.run
         fun = self.task.run
@@ -435,7 +439,7 @@ class Request(object):
 
 
 class TaskRequest(Request):
 class TaskRequest(Request):
 
 
-    def __init__(name, id, args=(), kwargs={},
+    def __init__(self, name, id, args=(), kwargs={},
             eta=None, expires=None, **options):
             eta=None, expires=None, **options):
         """Compatibility class."""
         """Compatibility class."""