Allow circular dependency with discovery (#2616)

This commit is contained in:
Paulus Schoutsen 2016-07-25 22:49:10 -07:00 committed by GitHub
parent 9c76b30e24
commit f1632496f0
10 changed files with 157 additions and 60 deletions

View File

@ -72,6 +72,7 @@ def _handle_requirements(hass, component, name):
def _setup_component(hass, domain, config):
"""Setup a component for Home Assistant."""
# pylint: disable=too-many-return-statements,too-many-branches
# pylint: disable=too-many-statements
if domain in hass.config.components:
return True
@ -149,9 +150,15 @@ def _setup_component(hass, domain, config):
_CURRENT_SETUP.append(domain)
try:
if not component.setup(hass, config):
result = component.setup(hass, config)
if result is False:
_LOGGER.error('component %s failed to initialize', domain)
return False
elif result is not True:
_LOGGER.error('component %s did not return boolean if setup '
'was successful. Disabling component.', domain)
loader.set_component(domain, None)
return False
except Exception: # pylint: disable=broad-except
_LOGGER.exception('Error during setup of component %s', domain)
return False

View File

@ -14,6 +14,7 @@ import threading
import time
from types import MappingProxyType
from typing import Any, Callable
import voluptuous as vol
import homeassistant.helpers.temperature as temp_helper
@ -62,6 +63,29 @@ class CoreState(enum.Enum):
return self.value
class JobPriority(util.OrderedEnum):
"""Provides job priorities for event bus jobs."""
EVENT_CALLBACK = 0
EVENT_SERVICE = 1
EVENT_STATE = 2
EVENT_TIME = 3
EVENT_DEFAULT = 4
@staticmethod
def from_event_type(event_type):
"""Return a priority based on event type."""
if event_type == EVENT_TIME_CHANGED:
return JobPriority.EVENT_TIME
elif event_type == EVENT_STATE_CHANGED:
return JobPriority.EVENT_STATE
elif event_type == EVENT_CALL_SERVICE:
return JobPriority.EVENT_SERVICE
elif event_type == EVENT_SERVICE_EXECUTED:
return JobPriority.EVENT_CALLBACK
return JobPriority.EVENT_DEFAULT
class HomeAssistant(object):
"""Root object of the Home Assistant home automation."""
@ -69,7 +93,7 @@ class HomeAssistant(object):
"""Initialize new Home Assistant object."""
self.pool = pool = create_worker_pool()
self.bus = EventBus(pool)
self.services = ServiceRegistry(self.bus, pool)
self.services = ServiceRegistry(self.bus, self.add_job)
self.states = StateMachine(self.bus)
self.config = Config()
self.state = CoreState.not_running
@ -90,6 +114,17 @@ class HomeAssistant(object):
self.pool.block_till_done()
self.state = CoreState.running
def add_job(self,
target: Callable[..., None],
*args: Any,
priority: JobPriority=JobPriority.EVENT_DEFAULT) -> None:
"""Add job to the worker pool.
target: target to call.
args: parameters for method to call.
"""
self.pool.add_job(priority, (target,) + args)
def block_till_stopped(self) -> int:
"""Register service homeassistant/stop and will block until called."""
request_shutdown = threading.Event()
@ -141,30 +176,6 @@ class HomeAssistant(object):
self.state = CoreState.not_running
class JobPriority(util.OrderedEnum):
"""Provides job priorities for event bus jobs."""
EVENT_CALLBACK = 0
EVENT_SERVICE = 1
EVENT_STATE = 2
EVENT_TIME = 3
EVENT_DEFAULT = 4
@staticmethod
def from_event_type(event_type):
"""Return a priority based on event type."""
if event_type == EVENT_TIME_CHANGED:
return JobPriority.EVENT_TIME
elif event_type == EVENT_STATE_CHANGED:
return JobPriority.EVENT_STATE
elif event_type == EVENT_CALL_SERVICE:
return JobPriority.EVENT_SERVICE
elif event_type == EVENT_SERVICE_EXECUTED:
return JobPriority.EVENT_CALLBACK
else:
return JobPriority.EVENT_DEFAULT
class EventOrigin(enum.Enum):
"""Represent the origin of an event."""
@ -222,11 +233,11 @@ class Event(object):
class EventBus(object):
"""Allows firing of and listening for events."""
def __init__(self, pool=None):
def __init__(self, pool: util.ThreadPool):
"""Initialize a new event bus."""
self._listeners = {}
self._lock = threading.Lock()
self._pool = pool or create_worker_pool()
self._pool = pool
@property
def listeners(self):
@ -235,7 +246,7 @@ class EventBus(object):
return {key: len(self._listeners[key])
for key in self._listeners}
def fire(self, event_type, event_data=None, origin=EventOrigin.local):
def fire(self, event_type: str, event_data=None, origin=EventOrigin.local):
"""Fire an event."""
if not self._pool.running:
raise HomeAssistantError('Home Assistant has shut down.')
@ -575,11 +586,11 @@ class ServiceCall(object):
class ServiceRegistry(object):
"""Offers services over the eventbus."""
def __init__(self, bus, pool=None):
def __init__(self, bus, add_job):
"""Initialize a service registry."""
self._services = {}
self._lock = threading.Lock()
self._pool = pool or create_worker_pool()
self._add_job = add_job
self._bus = bus
self._cur_id = 0
bus.listen(EVENT_CALL_SERVICE, self._event_to_service_call)
@ -678,13 +689,11 @@ class ServiceRegistry(object):
service_call = ServiceCall(domain, service, service_data, call_id)
# Add a job to the pool that calls _execute_service
self._pool.add_job(JobPriority.EVENT_SERVICE,
(self._execute_service,
(service_handler, service_call)))
self._add_job(self._execute_service, service_handler, service_call,
priority=JobPriority.EVENT_SERVICE)
def _execute_service(self, service_and_call):
def _execute_service(self, service, call):
"""Execute a service and fires a SERVICE_EXECUTED event."""
service, call = service_and_call
service(call)
if call.call_id is not None:
@ -831,8 +840,8 @@ def create_worker_pool(worker_count=None):
def job_handler(job):
"""Called whenever a job is available to do."""
try:
func, arg = job
func(arg)
func, *args = job
func(*args)
except Exception: # pylint: disable=broad-except
# Catch any exception our service/event_listener might throw
# We do not want to crash our ThreadPool

View File

@ -72,15 +72,20 @@ def load_platform(hass, component, platform, discovered=None,
Use `listen_platform` to register a callback for these events.
"""
if component is not None:
bootstrap.setup_component(hass, component, hass_config)
def discover_platform():
"""Discover platform job."""
# No need to fire event if we could not setup component
if not bootstrap.setup_component(hass, component, hass_config):
return
data = {
ATTR_SERVICE: EVENT_LOAD_PLATFORM.format(component),
ATTR_PLATFORM: platform,
}
data = {
ATTR_SERVICE: EVENT_LOAD_PLATFORM.format(component),
ATTR_PLATFORM: platform,
}
if discovered is not None:
data[ATTR_DISCOVERED] = discovered
if discovered is not None:
data[ATTR_DISCOVERED] = discovered
hass.bus.fire(EVENT_PLATFORM_DISCOVERED, data)
hass.bus.fire(EVENT_PLATFORM_DISCOVERED, data)
hass.add_job(discover_platform)

View File

@ -375,8 +375,6 @@ class ThreadPool(object):
def block_till_done(self):
"""Block till current work is done."""
self._work_queue.join()
# import traceback
# traceback.print_stack()
def stop(self):
"""Finish all the jobs and stops all the threads."""
@ -401,7 +399,7 @@ class ThreadPool(object):
# Get new item from work_queue
job = self._work_queue.get().item
if job == self._quit_task:
if job is self._quit_task:
self._work_queue.task_done()
return

View File

@ -18,7 +18,6 @@ class TestSwitchFlux(unittest.TestCase):
def setUp(self):
"""Setup things to be run when tests are started."""
self.hass = get_test_home_assistant()
# self.hass.config.components = ['flux', 'sun', 'light']
def tearDown(self):
"""Stop everything that was started."""

View File

@ -2,7 +2,6 @@
import json
import os
import unittest
from unittest.mock import patch
from homeassistant.components import demo, device_tracker
from homeassistant.remote import JSONEncoder
@ -10,7 +9,6 @@ from homeassistant.remote import JSONEncoder
from tests.common import mock_http_component, get_test_home_assistant
@patch('homeassistant.components.sun.setup')
class TestDemo(unittest.TestCase):
"""Test the Demo component."""
@ -28,19 +26,19 @@ class TestDemo(unittest.TestCase):
except FileNotFoundError:
pass
def test_if_demo_state_shows_by_default(self, mock_sun_setup):
def test_if_demo_state_shows_by_default(self):
"""Test if demo state shows if we give no configuration."""
demo.setup(self.hass, {demo.DOMAIN: {}})
self.assertIsNotNone(self.hass.states.get('a.Demo_Mode'))
def test_hiding_demo_state(self, mock_sun_setup):
def test_hiding_demo_state(self):
"""Test if you can hide the demo card."""
demo.setup(self.hass, {demo.DOMAIN: {'hide_demo_state': 1}})
self.assertIsNone(self.hass.states.get('a.Demo_Mode'))
def test_all_entities_can_be_loaded_over_json(self, mock_sun_setup):
def test_all_entities_can_be_loaded_over_json(self):
"""Test if you can hide the demo card."""
demo.setup(self.hass, {demo.DOMAIN: {'hide_demo_state': 1}})

View File

@ -1,10 +1,15 @@
"""Test discovery helpers."""
import os
from unittest.mock import patch
from homeassistant import loader, bootstrap, config as config_util
from homeassistant.helpers import discovery
from tests.common import get_test_home_assistant
from tests.common import (get_test_home_assistant, get_test_config_dir,
MockModule, MockPlatform)
VERSION_PATH = os.path.join(get_test_config_dir(), config_util.VERSION_FILE)
class TestHelpersDiscovery:
@ -18,6 +23,9 @@ class TestHelpersDiscovery:
"""Stop everything that was started."""
self.hass.stop()
if os.path.isfile(VERSION_PATH):
os.remove(VERSION_PATH)
@patch('homeassistant.bootstrap.setup_component')
def test_listen(self, mock_setup_component):
"""Test discovery listen/discover combo."""
@ -69,6 +77,7 @@ class TestHelpersDiscovery:
discovery.load_platform(self.hass, 'test_component', 'test_platform',
'discovery info')
self.hass.pool.block_till_done()
assert mock_setup_component.called
assert mock_setup_component.call_args[0] == \
(self.hass, 'test_component', None)
@ -88,3 +97,42 @@ class TestHelpersDiscovery:
self.hass.pool.block_till_done()
assert len(calls) == 1
def test_circular_import(self):
"""Test we don't break doing circular import."""
component_calls = []
platform_calls = []
def component_setup(hass, config):
"""Setup mock component."""
discovery.load_platform(hass, 'switch', 'test_circular')
component_calls.append(1)
return True
def setup_platform(hass, config, add_devices_callback,
discovery_info=None):
"""Setup mock platform."""
platform_calls.append(1)
loader.set_component(
'test_component',
MockModule('test_component', setup=component_setup))
loader.set_component(
'switch.test_circular',
MockPlatform(setup_platform,
dependencies=['test_component']))
bootstrap.from_config_dict({
'test_component': None,
'switch': [{
'platform': 'test_circular',
}],
}, self.hass)
self.hass.pool.block_till_done()
assert 'test_component' in self.hass.config.components
assert 'switch' in self.hass.config.components
assert len(component_calls) == 1
assert len(platform_calls) == 2

View File

@ -226,7 +226,8 @@ class TestHelpersEntityComponent(unittest.TestCase):
@patch('homeassistant.helpers.entity_component.EntityComponent'
'._setup_platform')
def test_setup_does_discovery(self, mock_setup):
@patch('homeassistant.bootstrap.setup_component', return_value=True)
def test_setup_does_discovery(self, mock_setup_component, mock_setup):
"""Test setup for discovery."""
component = EntityComponent(_LOGGER, DOMAIN, self.hass)

View File

@ -161,7 +161,7 @@ class TestBootstrap:
def test_component_not_double_initialized(self):
"""Test we do not setup a component twice."""
mock_setup = mock.MagicMock()
mock_setup = mock.MagicMock(return_value=True)
loader.set_component('comp', MockModule('comp', setup=mock_setup))
@ -302,3 +302,29 @@ class TestBootstrap:
'valid': True
}
})
def test_disable_component_if_invalid_return(self):
"""Test disabling component if invalid return."""
loader.set_component(
'disabled_component',
MockModule('disabled_component', setup=lambda hass, config: None))
assert not bootstrap.setup_component(self.hass, 'disabled_component')
assert loader.get_component('disabled_component') is None
assert 'disabled_component' not in self.hass.config.components
loader.set_component(
'disabled_component',
MockModule('disabled_component', setup=lambda hass, config: False))
assert not bootstrap.setup_component(self.hass, 'disabled_component')
assert loader.get_component('disabled_component') is not None
assert 'disabled_component' not in self.hass.config.components
loader.set_component(
'disabled_component',
MockModule('disabled_component', setup=lambda hass, config: True))
assert bootstrap.setup_component(self.hass, 'disabled_component')
assert loader.get_component('disabled_component') is not None
assert 'disabled_component' in self.hass.config.components

View File

@ -370,7 +370,13 @@ class TestServiceRegistry(unittest.TestCase):
"""Setup things to be run when tests are started."""
self.pool = ha.create_worker_pool(0)
self.bus = ha.EventBus(self.pool)
self.services = ha.ServiceRegistry(self.bus, self.pool)
def add_job(*args, **kwargs):
"""Forward calls to add_job on Home Assistant."""
# self works because we also have self.pool defined.
return ha.HomeAssistant.add_job(self, *args, **kwargs)
self.services = ha.ServiceRegistry(self.bus, add_job)
self.services.register("test_domain", "test_service", lambda x: None)
def tearDown(self): # pylint: disable=invalid-name