diff --git a/homeassistant/bootstrap.py b/homeassistant/bootstrap.py index 71e6b88e310..e4b524cfc5a 100644 --- a/homeassistant/bootstrap.py +++ b/homeassistant/bootstrap.py @@ -72,6 +72,7 @@ def _handle_requirements(hass, component, name): def _setup_component(hass, domain, config): """Setup a component for Home Assistant.""" # pylint: disable=too-many-return-statements,too-many-branches + # pylint: disable=too-many-statements if domain in hass.config.components: return True @@ -149,9 +150,15 @@ def _setup_component(hass, domain, config): _CURRENT_SETUP.append(domain) try: - if not component.setup(hass, config): + result = component.setup(hass, config) + if result is False: _LOGGER.error('component %s failed to initialize', domain) return False + elif result is not True: + _LOGGER.error('component %s did not return boolean if setup ' + 'was successful. Disabling component.', domain) + loader.set_component(domain, None) + return False except Exception: # pylint: disable=broad-except _LOGGER.exception('Error during setup of component %s', domain) return False diff --git a/homeassistant/core.py b/homeassistant/core.py index aec9cce5612..a360191f62a 100644 --- a/homeassistant/core.py +++ b/homeassistant/core.py @@ -14,6 +14,7 @@ import threading import time from types import MappingProxyType +from typing import Any, Callable import voluptuous as vol import homeassistant.helpers.temperature as temp_helper @@ -62,6 +63,29 @@ class CoreState(enum.Enum): return self.value +class JobPriority(util.OrderedEnum): + """Provides job priorities for event bus jobs.""" + + EVENT_CALLBACK = 0 + EVENT_SERVICE = 1 + EVENT_STATE = 2 + EVENT_TIME = 3 + EVENT_DEFAULT = 4 + + @staticmethod + def from_event_type(event_type): + """Return a priority based on event type.""" + if event_type == EVENT_TIME_CHANGED: + return JobPriority.EVENT_TIME + elif event_type == EVENT_STATE_CHANGED: + return JobPriority.EVENT_STATE + elif event_type == EVENT_CALL_SERVICE: + return JobPriority.EVENT_SERVICE + elif event_type == EVENT_SERVICE_EXECUTED: + return JobPriority.EVENT_CALLBACK + return JobPriority.EVENT_DEFAULT + + class HomeAssistant(object): """Root object of the Home Assistant home automation.""" @@ -69,7 +93,7 @@ class HomeAssistant(object): """Initialize new Home Assistant object.""" self.pool = pool = create_worker_pool() self.bus = EventBus(pool) - self.services = ServiceRegistry(self.bus, pool) + self.services = ServiceRegistry(self.bus, self.add_job) self.states = StateMachine(self.bus) self.config = Config() self.state = CoreState.not_running @@ -90,6 +114,17 @@ class HomeAssistant(object): self.pool.block_till_done() self.state = CoreState.running + def add_job(self, + target: Callable[..., None], + *args: Any, + priority: JobPriority=JobPriority.EVENT_DEFAULT) -> None: + """Add job to the worker pool. + + target: target to call. + args: parameters for method to call. + """ + self.pool.add_job(priority, (target,) + args) + def block_till_stopped(self) -> int: """Register service homeassistant/stop and will block until called.""" request_shutdown = threading.Event() @@ -141,30 +176,6 @@ class HomeAssistant(object): self.state = CoreState.not_running -class JobPriority(util.OrderedEnum): - """Provides job priorities for event bus jobs.""" - - EVENT_CALLBACK = 0 - EVENT_SERVICE = 1 - EVENT_STATE = 2 - EVENT_TIME = 3 - EVENT_DEFAULT = 4 - - @staticmethod - def from_event_type(event_type): - """Return a priority based on event type.""" - if event_type == EVENT_TIME_CHANGED: - return JobPriority.EVENT_TIME - elif event_type == EVENT_STATE_CHANGED: - return JobPriority.EVENT_STATE - elif event_type == EVENT_CALL_SERVICE: - return JobPriority.EVENT_SERVICE - elif event_type == EVENT_SERVICE_EXECUTED: - return JobPriority.EVENT_CALLBACK - else: - return JobPriority.EVENT_DEFAULT - - class EventOrigin(enum.Enum): """Represent the origin of an event.""" @@ -222,11 +233,11 @@ class Event(object): class EventBus(object): """Allows firing of and listening for events.""" - def __init__(self, pool=None): + def __init__(self, pool: util.ThreadPool): """Initialize a new event bus.""" self._listeners = {} self._lock = threading.Lock() - self._pool = pool or create_worker_pool() + self._pool = pool @property def listeners(self): @@ -235,7 +246,7 @@ class EventBus(object): return {key: len(self._listeners[key]) for key in self._listeners} - def fire(self, event_type, event_data=None, origin=EventOrigin.local): + def fire(self, event_type: str, event_data=None, origin=EventOrigin.local): """Fire an event.""" if not self._pool.running: raise HomeAssistantError('Home Assistant has shut down.') @@ -575,11 +586,11 @@ class ServiceCall(object): class ServiceRegistry(object): """Offers services over the eventbus.""" - def __init__(self, bus, pool=None): + def __init__(self, bus, add_job): """Initialize a service registry.""" self._services = {} self._lock = threading.Lock() - self._pool = pool or create_worker_pool() + self._add_job = add_job self._bus = bus self._cur_id = 0 bus.listen(EVENT_CALL_SERVICE, self._event_to_service_call) @@ -678,13 +689,11 @@ class ServiceRegistry(object): service_call = ServiceCall(domain, service, service_data, call_id) # Add a job to the pool that calls _execute_service - self._pool.add_job(JobPriority.EVENT_SERVICE, - (self._execute_service, - (service_handler, service_call))) + self._add_job(self._execute_service, service_handler, service_call, + priority=JobPriority.EVENT_SERVICE) - def _execute_service(self, service_and_call): + def _execute_service(self, service, call): """Execute a service and fires a SERVICE_EXECUTED event.""" - service, call = service_and_call service(call) if call.call_id is not None: @@ -831,8 +840,8 @@ def create_worker_pool(worker_count=None): def job_handler(job): """Called whenever a job is available to do.""" try: - func, arg = job - func(arg) + func, *args = job + func(*args) except Exception: # pylint: disable=broad-except # Catch any exception our service/event_listener might throw # We do not want to crash our ThreadPool diff --git a/homeassistant/helpers/discovery.py b/homeassistant/helpers/discovery.py index 480c786d31f..b0cf8af0747 100644 --- a/homeassistant/helpers/discovery.py +++ b/homeassistant/helpers/discovery.py @@ -72,15 +72,20 @@ def load_platform(hass, component, platform, discovered=None, Use `listen_platform` to register a callback for these events. """ - if component is not None: - bootstrap.setup_component(hass, component, hass_config) + def discover_platform(): + """Discover platform job.""" + # No need to fire event if we could not setup component + if not bootstrap.setup_component(hass, component, hass_config): + return - data = { - ATTR_SERVICE: EVENT_LOAD_PLATFORM.format(component), - ATTR_PLATFORM: platform, - } + data = { + ATTR_SERVICE: EVENT_LOAD_PLATFORM.format(component), + ATTR_PLATFORM: platform, + } - if discovered is not None: - data[ATTR_DISCOVERED] = discovered + if discovered is not None: + data[ATTR_DISCOVERED] = discovered - hass.bus.fire(EVENT_PLATFORM_DISCOVERED, data) + hass.bus.fire(EVENT_PLATFORM_DISCOVERED, data) + + hass.add_job(discover_platform) diff --git a/homeassistant/util/__init__.py b/homeassistant/util/__init__.py index 1f2584d655a..5dcf3ba2bc9 100644 --- a/homeassistant/util/__init__.py +++ b/homeassistant/util/__init__.py @@ -375,8 +375,6 @@ class ThreadPool(object): def block_till_done(self): """Block till current work is done.""" self._work_queue.join() - # import traceback - # traceback.print_stack() def stop(self): """Finish all the jobs and stops all the threads.""" @@ -401,7 +399,7 @@ class ThreadPool(object): # Get new item from work_queue job = self._work_queue.get().item - if job == self._quit_task: + if job is self._quit_task: self._work_queue.task_done() return diff --git a/tests/components/switch/test_flux.py b/tests/components/switch/test_flux.py index ee20daf07ac..78d1f5190d6 100644 --- a/tests/components/switch/test_flux.py +++ b/tests/components/switch/test_flux.py @@ -18,7 +18,6 @@ class TestSwitchFlux(unittest.TestCase): def setUp(self): """Setup things to be run when tests are started.""" self.hass = get_test_home_assistant() - # self.hass.config.components = ['flux', 'sun', 'light'] def tearDown(self): """Stop everything that was started.""" diff --git a/tests/components/test_demo.py b/tests/components/test_demo.py index c6abe66a7ce..1c9203c944c 100644 --- a/tests/components/test_demo.py +++ b/tests/components/test_demo.py @@ -2,7 +2,6 @@ import json import os import unittest -from unittest.mock import patch from homeassistant.components import demo, device_tracker from homeassistant.remote import JSONEncoder @@ -10,7 +9,6 @@ from homeassistant.remote import JSONEncoder from tests.common import mock_http_component, get_test_home_assistant -@patch('homeassistant.components.sun.setup') class TestDemo(unittest.TestCase): """Test the Demo component.""" @@ -28,19 +26,19 @@ class TestDemo(unittest.TestCase): except FileNotFoundError: pass - def test_if_demo_state_shows_by_default(self, mock_sun_setup): + def test_if_demo_state_shows_by_default(self): """Test if demo state shows if we give no configuration.""" demo.setup(self.hass, {demo.DOMAIN: {}}) self.assertIsNotNone(self.hass.states.get('a.Demo_Mode')) - def test_hiding_demo_state(self, mock_sun_setup): + def test_hiding_demo_state(self): """Test if you can hide the demo card.""" demo.setup(self.hass, {demo.DOMAIN: {'hide_demo_state': 1}}) self.assertIsNone(self.hass.states.get('a.Demo_Mode')) - def test_all_entities_can_be_loaded_over_json(self, mock_sun_setup): + def test_all_entities_can_be_loaded_over_json(self): """Test if you can hide the demo card.""" demo.setup(self.hass, {demo.DOMAIN: {'hide_demo_state': 1}}) diff --git a/tests/helpers/test_discovery.py b/tests/helpers/test_discovery.py index bdc6e2ed119..b6f9ed5dec8 100644 --- a/tests/helpers/test_discovery.py +++ b/tests/helpers/test_discovery.py @@ -1,10 +1,15 @@ """Test discovery helpers.""" +import os from unittest.mock import patch +from homeassistant import loader, bootstrap, config as config_util from homeassistant.helpers import discovery -from tests.common import get_test_home_assistant +from tests.common import (get_test_home_assistant, get_test_config_dir, + MockModule, MockPlatform) + +VERSION_PATH = os.path.join(get_test_config_dir(), config_util.VERSION_FILE) class TestHelpersDiscovery: @@ -18,6 +23,9 @@ class TestHelpersDiscovery: """Stop everything that was started.""" self.hass.stop() + if os.path.isfile(VERSION_PATH): + os.remove(VERSION_PATH) + @patch('homeassistant.bootstrap.setup_component') def test_listen(self, mock_setup_component): """Test discovery listen/discover combo.""" @@ -69,6 +77,7 @@ class TestHelpersDiscovery: discovery.load_platform(self.hass, 'test_component', 'test_platform', 'discovery info') + self.hass.pool.block_till_done() assert mock_setup_component.called assert mock_setup_component.call_args[0] == \ (self.hass, 'test_component', None) @@ -88,3 +97,42 @@ class TestHelpersDiscovery: self.hass.pool.block_till_done() assert len(calls) == 1 + + def test_circular_import(self): + """Test we don't break doing circular import.""" + component_calls = [] + platform_calls = [] + + def component_setup(hass, config): + """Setup mock component.""" + discovery.load_platform(hass, 'switch', 'test_circular') + component_calls.append(1) + return True + + def setup_platform(hass, config, add_devices_callback, + discovery_info=None): + """Setup mock platform.""" + platform_calls.append(1) + + loader.set_component( + 'test_component', + MockModule('test_component', setup=component_setup)) + + loader.set_component( + 'switch.test_circular', + MockPlatform(setup_platform, + dependencies=['test_component'])) + + bootstrap.from_config_dict({ + 'test_component': None, + 'switch': [{ + 'platform': 'test_circular', + }], + }, self.hass) + + self.hass.pool.block_till_done() + + assert 'test_component' in self.hass.config.components + assert 'switch' in self.hass.config.components + assert len(component_calls) == 1 + assert len(platform_calls) == 2 diff --git a/tests/helpers/test_entity_component.py b/tests/helpers/test_entity_component.py index a9a6310eb79..2fa65f6d4ec 100644 --- a/tests/helpers/test_entity_component.py +++ b/tests/helpers/test_entity_component.py @@ -226,7 +226,8 @@ class TestHelpersEntityComponent(unittest.TestCase): @patch('homeassistant.helpers.entity_component.EntityComponent' '._setup_platform') - def test_setup_does_discovery(self, mock_setup): + @patch('homeassistant.bootstrap.setup_component', return_value=True) + def test_setup_does_discovery(self, mock_setup_component, mock_setup): """Test setup for discovery.""" component = EntityComponent(_LOGGER, DOMAIN, self.hass) diff --git a/tests/test_bootstrap.py b/tests/test_bootstrap.py index 34aaa1b83ed..d41dc60ee15 100644 --- a/tests/test_bootstrap.py +++ b/tests/test_bootstrap.py @@ -161,7 +161,7 @@ class TestBootstrap: def test_component_not_double_initialized(self): """Test we do not setup a component twice.""" - mock_setup = mock.MagicMock() + mock_setup = mock.MagicMock(return_value=True) loader.set_component('comp', MockModule('comp', setup=mock_setup)) @@ -302,3 +302,29 @@ class TestBootstrap: 'valid': True } }) + + def test_disable_component_if_invalid_return(self): + """Test disabling component if invalid return.""" + loader.set_component( + 'disabled_component', + MockModule('disabled_component', setup=lambda hass, config: None)) + + assert not bootstrap.setup_component(self.hass, 'disabled_component') + assert loader.get_component('disabled_component') is None + assert 'disabled_component' not in self.hass.config.components + + loader.set_component( + 'disabled_component', + MockModule('disabled_component', setup=lambda hass, config: False)) + + assert not bootstrap.setup_component(self.hass, 'disabled_component') + assert loader.get_component('disabled_component') is not None + assert 'disabled_component' not in self.hass.config.components + + loader.set_component( + 'disabled_component', + MockModule('disabled_component', setup=lambda hass, config: True)) + + assert bootstrap.setup_component(self.hass, 'disabled_component') + assert loader.get_component('disabled_component') is not None + assert 'disabled_component' in self.hass.config.components diff --git a/tests/test_core.py b/tests/test_core.py index cb698cdc53c..e9513a2adb8 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -370,7 +370,13 @@ class TestServiceRegistry(unittest.TestCase): """Setup things to be run when tests are started.""" self.pool = ha.create_worker_pool(0) self.bus = ha.EventBus(self.pool) - self.services = ha.ServiceRegistry(self.bus, self.pool) + + def add_job(*args, **kwargs): + """Forward calls to add_job on Home Assistant.""" + # self works because we also have self.pool defined. + return ha.HomeAssistant.add_job(self, *args, **kwargs) + + self.services = ha.ServiceRegistry(self.bus, add_job) self.services.register("test_domain", "test_service", lambda x: None) def tearDown(self): # pylint: disable=invalid-name