Allow circular dependency with discovery (#2616)

This commit is contained in:
Paulus Schoutsen 2016-07-25 22:49:10 -07:00 committed by GitHub
parent 9c76b30e24
commit f1632496f0
10 changed files with 157 additions and 60 deletions

View File

@ -72,6 +72,7 @@ def _handle_requirements(hass, component, name):
def _setup_component(hass, domain, config): def _setup_component(hass, domain, config):
"""Setup a component for Home Assistant.""" """Setup a component for Home Assistant."""
# pylint: disable=too-many-return-statements,too-many-branches # pylint: disable=too-many-return-statements,too-many-branches
# pylint: disable=too-many-statements
if domain in hass.config.components: if domain in hass.config.components:
return True return True
@ -149,9 +150,15 @@ def _setup_component(hass, domain, config):
_CURRENT_SETUP.append(domain) _CURRENT_SETUP.append(domain)
try: try:
if not component.setup(hass, config): result = component.setup(hass, config)
if result is False:
_LOGGER.error('component %s failed to initialize', domain) _LOGGER.error('component %s failed to initialize', domain)
return False return False
elif result is not True:
_LOGGER.error('component %s did not return boolean if setup '
'was successful. Disabling component.', domain)
loader.set_component(domain, None)
return False
except Exception: # pylint: disable=broad-except except Exception: # pylint: disable=broad-except
_LOGGER.exception('Error during setup of component %s', domain) _LOGGER.exception('Error during setup of component %s', domain)
return False return False

View File

@ -14,6 +14,7 @@ import threading
import time import time
from types import MappingProxyType from types import MappingProxyType
from typing import Any, Callable
import voluptuous as vol import voluptuous as vol
import homeassistant.helpers.temperature as temp_helper import homeassistant.helpers.temperature as temp_helper
@ -62,6 +63,29 @@ class CoreState(enum.Enum):
return self.value return self.value
class JobPriority(util.OrderedEnum):
"""Provides job priorities for event bus jobs."""
EVENT_CALLBACK = 0
EVENT_SERVICE = 1
EVENT_STATE = 2
EVENT_TIME = 3
EVENT_DEFAULT = 4
@staticmethod
def from_event_type(event_type):
"""Return a priority based on event type."""
if event_type == EVENT_TIME_CHANGED:
return JobPriority.EVENT_TIME
elif event_type == EVENT_STATE_CHANGED:
return JobPriority.EVENT_STATE
elif event_type == EVENT_CALL_SERVICE:
return JobPriority.EVENT_SERVICE
elif event_type == EVENT_SERVICE_EXECUTED:
return JobPriority.EVENT_CALLBACK
return JobPriority.EVENT_DEFAULT
class HomeAssistant(object): class HomeAssistant(object):
"""Root object of the Home Assistant home automation.""" """Root object of the Home Assistant home automation."""
@ -69,7 +93,7 @@ class HomeAssistant(object):
"""Initialize new Home Assistant object.""" """Initialize new Home Assistant object."""
self.pool = pool = create_worker_pool() self.pool = pool = create_worker_pool()
self.bus = EventBus(pool) self.bus = EventBus(pool)
self.services = ServiceRegistry(self.bus, pool) self.services = ServiceRegistry(self.bus, self.add_job)
self.states = StateMachine(self.bus) self.states = StateMachine(self.bus)
self.config = Config() self.config = Config()
self.state = CoreState.not_running self.state = CoreState.not_running
@ -90,6 +114,17 @@ class HomeAssistant(object):
self.pool.block_till_done() self.pool.block_till_done()
self.state = CoreState.running self.state = CoreState.running
def add_job(self,
target: Callable[..., None],
*args: Any,
priority: JobPriority=JobPriority.EVENT_DEFAULT) -> None:
"""Add job to the worker pool.
target: target to call.
args: parameters for method to call.
"""
self.pool.add_job(priority, (target,) + args)
def block_till_stopped(self) -> int: def block_till_stopped(self) -> int:
"""Register service homeassistant/stop and will block until called.""" """Register service homeassistant/stop and will block until called."""
request_shutdown = threading.Event() request_shutdown = threading.Event()
@ -141,30 +176,6 @@ class HomeAssistant(object):
self.state = CoreState.not_running self.state = CoreState.not_running
class JobPriority(util.OrderedEnum):
"""Provides job priorities for event bus jobs."""
EVENT_CALLBACK = 0
EVENT_SERVICE = 1
EVENT_STATE = 2
EVENT_TIME = 3
EVENT_DEFAULT = 4
@staticmethod
def from_event_type(event_type):
"""Return a priority based on event type."""
if event_type == EVENT_TIME_CHANGED:
return JobPriority.EVENT_TIME
elif event_type == EVENT_STATE_CHANGED:
return JobPriority.EVENT_STATE
elif event_type == EVENT_CALL_SERVICE:
return JobPriority.EVENT_SERVICE
elif event_type == EVENT_SERVICE_EXECUTED:
return JobPriority.EVENT_CALLBACK
else:
return JobPriority.EVENT_DEFAULT
class EventOrigin(enum.Enum): class EventOrigin(enum.Enum):
"""Represent the origin of an event.""" """Represent the origin of an event."""
@ -222,11 +233,11 @@ class Event(object):
class EventBus(object): class EventBus(object):
"""Allows firing of and listening for events.""" """Allows firing of and listening for events."""
def __init__(self, pool=None): def __init__(self, pool: util.ThreadPool):
"""Initialize a new event bus.""" """Initialize a new event bus."""
self._listeners = {} self._listeners = {}
self._lock = threading.Lock() self._lock = threading.Lock()
self._pool = pool or create_worker_pool() self._pool = pool
@property @property
def listeners(self): def listeners(self):
@ -235,7 +246,7 @@ class EventBus(object):
return {key: len(self._listeners[key]) return {key: len(self._listeners[key])
for key in self._listeners} for key in self._listeners}
def fire(self, event_type, event_data=None, origin=EventOrigin.local): def fire(self, event_type: str, event_data=None, origin=EventOrigin.local):
"""Fire an event.""" """Fire an event."""
if not self._pool.running: if not self._pool.running:
raise HomeAssistantError('Home Assistant has shut down.') raise HomeAssistantError('Home Assistant has shut down.')
@ -575,11 +586,11 @@ class ServiceCall(object):
class ServiceRegistry(object): class ServiceRegistry(object):
"""Offers services over the eventbus.""" """Offers services over the eventbus."""
def __init__(self, bus, pool=None): def __init__(self, bus, add_job):
"""Initialize a service registry.""" """Initialize a service registry."""
self._services = {} self._services = {}
self._lock = threading.Lock() self._lock = threading.Lock()
self._pool = pool or create_worker_pool() self._add_job = add_job
self._bus = bus self._bus = bus
self._cur_id = 0 self._cur_id = 0
bus.listen(EVENT_CALL_SERVICE, self._event_to_service_call) bus.listen(EVENT_CALL_SERVICE, self._event_to_service_call)
@ -678,13 +689,11 @@ class ServiceRegistry(object):
service_call = ServiceCall(domain, service, service_data, call_id) service_call = ServiceCall(domain, service, service_data, call_id)
# Add a job to the pool that calls _execute_service # Add a job to the pool that calls _execute_service
self._pool.add_job(JobPriority.EVENT_SERVICE, self._add_job(self._execute_service, service_handler, service_call,
(self._execute_service, priority=JobPriority.EVENT_SERVICE)
(service_handler, service_call)))
def _execute_service(self, service_and_call): def _execute_service(self, service, call):
"""Execute a service and fires a SERVICE_EXECUTED event.""" """Execute a service and fires a SERVICE_EXECUTED event."""
service, call = service_and_call
service(call) service(call)
if call.call_id is not None: if call.call_id is not None:
@ -831,8 +840,8 @@ def create_worker_pool(worker_count=None):
def job_handler(job): def job_handler(job):
"""Called whenever a job is available to do.""" """Called whenever a job is available to do."""
try: try:
func, arg = job func, *args = job
func(arg) func(*args)
except Exception: # pylint: disable=broad-except except Exception: # pylint: disable=broad-except
# Catch any exception our service/event_listener might throw # Catch any exception our service/event_listener might throw
# We do not want to crash our ThreadPool # We do not want to crash our ThreadPool

View File

@ -72,15 +72,20 @@ def load_platform(hass, component, platform, discovered=None,
Use `listen_platform` to register a callback for these events. Use `listen_platform` to register a callback for these events.
""" """
if component is not None: def discover_platform():
bootstrap.setup_component(hass, component, hass_config) """Discover platform job."""
# No need to fire event if we could not setup component
if not bootstrap.setup_component(hass, component, hass_config):
return
data = { data = {
ATTR_SERVICE: EVENT_LOAD_PLATFORM.format(component), ATTR_SERVICE: EVENT_LOAD_PLATFORM.format(component),
ATTR_PLATFORM: platform, ATTR_PLATFORM: platform,
} }
if discovered is not None: if discovered is not None:
data[ATTR_DISCOVERED] = discovered data[ATTR_DISCOVERED] = discovered
hass.bus.fire(EVENT_PLATFORM_DISCOVERED, data) hass.bus.fire(EVENT_PLATFORM_DISCOVERED, data)
hass.add_job(discover_platform)

View File

@ -375,8 +375,6 @@ class ThreadPool(object):
def block_till_done(self): def block_till_done(self):
"""Block till current work is done.""" """Block till current work is done."""
self._work_queue.join() self._work_queue.join()
# import traceback
# traceback.print_stack()
def stop(self): def stop(self):
"""Finish all the jobs and stops all the threads.""" """Finish all the jobs and stops all the threads."""
@ -401,7 +399,7 @@ class ThreadPool(object):
# Get new item from work_queue # Get new item from work_queue
job = self._work_queue.get().item job = self._work_queue.get().item
if job == self._quit_task: if job is self._quit_task:
self._work_queue.task_done() self._work_queue.task_done()
return return

View File

@ -18,7 +18,6 @@ class TestSwitchFlux(unittest.TestCase):
def setUp(self): def setUp(self):
"""Setup things to be run when tests are started.""" """Setup things to be run when tests are started."""
self.hass = get_test_home_assistant() self.hass = get_test_home_assistant()
# self.hass.config.components = ['flux', 'sun', 'light']
def tearDown(self): def tearDown(self):
"""Stop everything that was started.""" """Stop everything that was started."""

View File

@ -2,7 +2,6 @@
import json import json
import os import os
import unittest import unittest
from unittest.mock import patch
from homeassistant.components import demo, device_tracker from homeassistant.components import demo, device_tracker
from homeassistant.remote import JSONEncoder from homeassistant.remote import JSONEncoder
@ -10,7 +9,6 @@ from homeassistant.remote import JSONEncoder
from tests.common import mock_http_component, get_test_home_assistant from tests.common import mock_http_component, get_test_home_assistant
@patch('homeassistant.components.sun.setup')
class TestDemo(unittest.TestCase): class TestDemo(unittest.TestCase):
"""Test the Demo component.""" """Test the Demo component."""
@ -28,19 +26,19 @@ class TestDemo(unittest.TestCase):
except FileNotFoundError: except FileNotFoundError:
pass pass
def test_if_demo_state_shows_by_default(self, mock_sun_setup): def test_if_demo_state_shows_by_default(self):
"""Test if demo state shows if we give no configuration.""" """Test if demo state shows if we give no configuration."""
demo.setup(self.hass, {demo.DOMAIN: {}}) demo.setup(self.hass, {demo.DOMAIN: {}})
self.assertIsNotNone(self.hass.states.get('a.Demo_Mode')) self.assertIsNotNone(self.hass.states.get('a.Demo_Mode'))
def test_hiding_demo_state(self, mock_sun_setup): def test_hiding_demo_state(self):
"""Test if you can hide the demo card.""" """Test if you can hide the demo card."""
demo.setup(self.hass, {demo.DOMAIN: {'hide_demo_state': 1}}) demo.setup(self.hass, {demo.DOMAIN: {'hide_demo_state': 1}})
self.assertIsNone(self.hass.states.get('a.Demo_Mode')) self.assertIsNone(self.hass.states.get('a.Demo_Mode'))
def test_all_entities_can_be_loaded_over_json(self, mock_sun_setup): def test_all_entities_can_be_loaded_over_json(self):
"""Test if you can hide the demo card.""" """Test if you can hide the demo card."""
demo.setup(self.hass, {demo.DOMAIN: {'hide_demo_state': 1}}) demo.setup(self.hass, {demo.DOMAIN: {'hide_demo_state': 1}})

View File

@ -1,10 +1,15 @@
"""Test discovery helpers.""" """Test discovery helpers."""
import os
from unittest.mock import patch from unittest.mock import patch
from homeassistant import loader, bootstrap, config as config_util
from homeassistant.helpers import discovery from homeassistant.helpers import discovery
from tests.common import get_test_home_assistant from tests.common import (get_test_home_assistant, get_test_config_dir,
MockModule, MockPlatform)
VERSION_PATH = os.path.join(get_test_config_dir(), config_util.VERSION_FILE)
class TestHelpersDiscovery: class TestHelpersDiscovery:
@ -18,6 +23,9 @@ class TestHelpersDiscovery:
"""Stop everything that was started.""" """Stop everything that was started."""
self.hass.stop() self.hass.stop()
if os.path.isfile(VERSION_PATH):
os.remove(VERSION_PATH)
@patch('homeassistant.bootstrap.setup_component') @patch('homeassistant.bootstrap.setup_component')
def test_listen(self, mock_setup_component): def test_listen(self, mock_setup_component):
"""Test discovery listen/discover combo.""" """Test discovery listen/discover combo."""
@ -69,6 +77,7 @@ class TestHelpersDiscovery:
discovery.load_platform(self.hass, 'test_component', 'test_platform', discovery.load_platform(self.hass, 'test_component', 'test_platform',
'discovery info') 'discovery info')
self.hass.pool.block_till_done()
assert mock_setup_component.called assert mock_setup_component.called
assert mock_setup_component.call_args[0] == \ assert mock_setup_component.call_args[0] == \
(self.hass, 'test_component', None) (self.hass, 'test_component', None)
@ -88,3 +97,42 @@ class TestHelpersDiscovery:
self.hass.pool.block_till_done() self.hass.pool.block_till_done()
assert len(calls) == 1 assert len(calls) == 1
def test_circular_import(self):
"""Test we don't break doing circular import."""
component_calls = []
platform_calls = []
def component_setup(hass, config):
"""Setup mock component."""
discovery.load_platform(hass, 'switch', 'test_circular')
component_calls.append(1)
return True
def setup_platform(hass, config, add_devices_callback,
discovery_info=None):
"""Setup mock platform."""
platform_calls.append(1)
loader.set_component(
'test_component',
MockModule('test_component', setup=component_setup))
loader.set_component(
'switch.test_circular',
MockPlatform(setup_platform,
dependencies=['test_component']))
bootstrap.from_config_dict({
'test_component': None,
'switch': [{
'platform': 'test_circular',
}],
}, self.hass)
self.hass.pool.block_till_done()
assert 'test_component' in self.hass.config.components
assert 'switch' in self.hass.config.components
assert len(component_calls) == 1
assert len(platform_calls) == 2

View File

@ -226,7 +226,8 @@ class TestHelpersEntityComponent(unittest.TestCase):
@patch('homeassistant.helpers.entity_component.EntityComponent' @patch('homeassistant.helpers.entity_component.EntityComponent'
'._setup_platform') '._setup_platform')
def test_setup_does_discovery(self, mock_setup): @patch('homeassistant.bootstrap.setup_component', return_value=True)
def test_setup_does_discovery(self, mock_setup_component, mock_setup):
"""Test setup for discovery.""" """Test setup for discovery."""
component = EntityComponent(_LOGGER, DOMAIN, self.hass) component = EntityComponent(_LOGGER, DOMAIN, self.hass)

View File

@ -161,7 +161,7 @@ class TestBootstrap:
def test_component_not_double_initialized(self): def test_component_not_double_initialized(self):
"""Test we do not setup a component twice.""" """Test we do not setup a component twice."""
mock_setup = mock.MagicMock() mock_setup = mock.MagicMock(return_value=True)
loader.set_component('comp', MockModule('comp', setup=mock_setup)) loader.set_component('comp', MockModule('comp', setup=mock_setup))
@ -302,3 +302,29 @@ class TestBootstrap:
'valid': True 'valid': True
} }
}) })
def test_disable_component_if_invalid_return(self):
"""Test disabling component if invalid return."""
loader.set_component(
'disabled_component',
MockModule('disabled_component', setup=lambda hass, config: None))
assert not bootstrap.setup_component(self.hass, 'disabled_component')
assert loader.get_component('disabled_component') is None
assert 'disabled_component' not in self.hass.config.components
loader.set_component(
'disabled_component',
MockModule('disabled_component', setup=lambda hass, config: False))
assert not bootstrap.setup_component(self.hass, 'disabled_component')
assert loader.get_component('disabled_component') is not None
assert 'disabled_component' not in self.hass.config.components
loader.set_component(
'disabled_component',
MockModule('disabled_component', setup=lambda hass, config: True))
assert bootstrap.setup_component(self.hass, 'disabled_component')
assert loader.get_component('disabled_component') is not None
assert 'disabled_component' in self.hass.config.components

View File

@ -370,7 +370,13 @@ class TestServiceRegistry(unittest.TestCase):
"""Setup things to be run when tests are started.""" """Setup things to be run when tests are started."""
self.pool = ha.create_worker_pool(0) self.pool = ha.create_worker_pool(0)
self.bus = ha.EventBus(self.pool) self.bus = ha.EventBus(self.pool)
self.services = ha.ServiceRegistry(self.bus, self.pool)
def add_job(*args, **kwargs):
"""Forward calls to add_job on Home Assistant."""
# self works because we also have self.pool defined.
return ha.HomeAssistant.add_job(self, *args, **kwargs)
self.services = ha.ServiceRegistry(self.bus, add_job)
self.services.register("test_domain", "test_service", lambda x: None) self.services.register("test_domain", "test_service", lambda x: None)
def tearDown(self): # pylint: disable=invalid-name def tearDown(self): # pylint: disable=invalid-name