Enforce lower case for services and warn if local unknown service called (#2764)

This commit is contained in:
Paulus Schoutsen 2016-08-09 19:41:45 -07:00 committed by GitHub
parent 180a7ec295
commit d80c05b6b6
2 changed files with 15 additions and 10 deletions

View File

@ -582,8 +582,8 @@ class ServiceCall(object):
def __init__(self, domain, service, data=None, call_id=None): def __init__(self, domain, service, data=None, call_id=None):
"""Initialize a service call.""" """Initialize a service call."""
self.domain = domain self.domain = domain.lower()
self.service = service self.service = service.lower()
self.data = data or {} self.data = data or {}
self.call_id = call_id self.call_id = call_id
@ -618,7 +618,7 @@ class ServiceRegistry(object):
def has_service(self, domain, service): def has_service(self, domain, service):
"""Test if specified service exists.""" """Test if specified service exists."""
return service in self._services.get(domain, []) return service.lower() in self._services.get(domain.lower(), [])
# pylint: disable=too-many-arguments # pylint: disable=too-many-arguments
def register(self, domain, service, service_func, description=None, def register(self, domain, service, service_func, description=None,
@ -631,6 +631,8 @@ class ServiceRegistry(object):
Schema is called to coerce and validate the service data. Schema is called to coerce and validate the service data.
""" """
domain = domain.lower()
service = service.lower()
description = description or {} description = description or {}
service_obj = Service(service_func, description.get('description'), service_obj = Service(service_func, description.get('description'),
description.get('fields', {}), schema) description.get('fields', {}), schema)
@ -664,8 +666,8 @@ class ServiceRegistry(object):
call_id = self._generate_unique_id() call_id = self._generate_unique_id()
event_data = { event_data = {
ATTR_DOMAIN: domain, ATTR_DOMAIN: domain.lower(),
ATTR_SERVICE: service, ATTR_SERVICE: service.lower(),
ATTR_SERVICE_DATA: service_data, ATTR_SERVICE_DATA: service_data,
ATTR_SERVICE_CALL_ID: call_id, ATTR_SERVICE_CALL_ID: call_id,
} }
@ -691,11 +693,14 @@ class ServiceRegistry(object):
def _event_to_service_call(self, event): def _event_to_service_call(self, event):
"""Callback for SERVICE_CALLED events from the event bus.""" """Callback for SERVICE_CALLED events from the event bus."""
service_data = event.data.get(ATTR_SERVICE_DATA) service_data = event.data.get(ATTR_SERVICE_DATA)
domain = event.data.get(ATTR_DOMAIN) domain = event.data.get(ATTR_DOMAIN).lower()
service = event.data.get(ATTR_SERVICE) service = event.data.get(ATTR_SERVICE).lower()
call_id = event.data.get(ATTR_SERVICE_CALL_ID) call_id = event.data.get(ATTR_SERVICE_CALL_ID)
if not self.has_service(domain, service): if not self.has_service(domain, service):
if event.origin == EventOrigin.local:
_LOGGER.warning('Unable to find service %s/%s',
domain, service)
return return
service_handler = self._services[domain][service] service_handler = self._services[domain][service]

View File

@ -386,7 +386,7 @@ class TestServiceRegistry(unittest.TestCase):
return ha.HomeAssistant.add_job(self, *args, **kwargs) return ha.HomeAssistant.add_job(self, *args, **kwargs)
self.services = ha.ServiceRegistry(self.bus, add_job) self.services = ha.ServiceRegistry(self.bus, add_job)
self.services.register("test_domain", "test_service", lambda x: None) self.services.register("Test_Domain", "TEST_SERVICE", lambda x: None)
def tearDown(self): # pylint: disable=invalid-name def tearDown(self): # pylint: disable=invalid-name
"""Stop down stuff we started.""" """Stop down stuff we started."""
@ -396,7 +396,7 @@ class TestServiceRegistry(unittest.TestCase):
def test_has_service(self): def test_has_service(self):
"""Test has_service method.""" """Test has_service method."""
self.assertTrue( self.assertTrue(
self.services.has_service("test_domain", "test_service")) self.services.has_service("tesT_domaiN", "tesT_servicE"))
self.assertFalse( self.assertFalse(
self.services.has_service("test_domain", "non_existing")) self.services.has_service("test_domain", "non_existing"))
self.assertFalse( self.assertFalse(
@ -418,7 +418,7 @@ class TestServiceRegistry(unittest.TestCase):
lambda x: calls.append(1)) lambda x: calls.append(1))
self.assertTrue( self.assertTrue(
self.services.call('test_domain', 'register_calls', blocking=True)) self.services.call('test_domain', 'REGISTER_CALLS', blocking=True))
self.assertEqual(1, len(calls)) self.assertEqual(1, len(calls))
def test_call_with_blocking_not_done_in_time(self): def test_call_with_blocking_not_done_in_time(self):