Use config entry to setup platforms (#13752)

* Use config entry to setup platforms

* Rename to async_forward_entry

* Add tests

* Catch if platform not exists for entry
This commit is contained in:
Paulus Schoutsen 2018-04-09 10:09:08 -04:00 committed by GitHub
parent cb51553c2d
commit 73de749411
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
12 changed files with 271 additions and 45 deletions

View File

@ -39,18 +39,17 @@ class HueBridge(object):
async def async_setup(self, tries=0):
"""Set up a phue bridge based on host parameter."""
host = self.host
hass = self.hass
try:
self.api = await get_bridge(
self.hass, host,
self.config_entry.data['username']
)
hass, host, self.config_entry.data['username'])
except AuthenticationRequired:
# usernames can become invalid if hub is reset or user removed.
# We are going to fail the config entry setup and initiate a new
# linking procedure. When linking succeeds, it will remove the
# old config entry.
self.hass.async_add_job(self.hass.config_entries.flow.async_init(
hass.async_add_job(hass.config_entries.flow.async_init(
DOMAIN, source='import', data={
'host': host,
}
@ -69,7 +68,7 @@ class HueBridge(object):
self.config_entry.state = config_entries.ENTRY_STATE_LOADED
# Unhandled edge case: cancel this if we discover bridge on new IP
self.hass.helpers.event.async_call_later(retry_delay, retry_setup)
hass.helpers.event.async_call_later(retry_delay, retry_setup)
return False
@ -78,11 +77,10 @@ class HueBridge(object):
host)
return False
self.hass.async_add_job(
self.hass.helpers.discovery.async_load_platform(
'light', DOMAIN, {'host': host}))
hass.async_add_job(hass.config_entries.async_forward_entry(
self.config_entry, 'light'))
self.hass.services.async_register(
hass.services.async_register(
DOMAIN, SERVICE_HUE_SCENE, self.hue_activate_scene,
schema=SCENE_SCHEMA)

View File

@ -334,7 +334,7 @@ class SetIntentHandler(intent.IntentHandler):
async def async_setup(hass, config):
"""Expose light control via state machine and services."""
component = EntityComponent(
component = hass.data[DOMAIN] = EntityComponent(
_LOGGER, DOMAIN, hass, SCAN_INTERVAL, GROUP_NAME_ALL_LIGHTS)
await component.async_setup(config)
@ -388,6 +388,11 @@ async def async_setup(hass, config):
return True
async def async_setup_entry(hass, entry):
"""Setup a config entry."""
return await hass.data[DOMAIN].async_setup_entry(entry)
class Profiles:
"""Representation of available color profiles."""

View File

@ -49,11 +49,17 @@ GROUP_MIN_API_VERSION = (1, 13, 0)
async def async_setup_platform(hass, config, async_add_devices,
discovery_info=None):
"""Set up the Hue lights."""
if discovery_info is None:
return
"""Old way of setting up Hue lights.
bridge = hass.data[hue.DOMAIN][discovery_info['host']]
Can only be called when a user accidentally mentions hue platform in their
config. But even in that case it would have been ignored.
"""
pass
async def async_setup_entry(hass, config_entry, async_add_devices):
"""Set up the Hue lights from a config entry."""
bridge = hass.data[hue.DOMAIN][config_entry.data['host']]
cur_lights = {}
cur_groups = {}

View File

@ -187,13 +187,17 @@ class ConfigEntry:
if not isinstance(result, bool):
_LOGGER.error('%s.async_config_entry did not return boolean',
self.domain)
component.DOMAIN)
result = False
except Exception: # pylint: disable=broad-except
_LOGGER.exception('Error setting up entry %s for %s',
self.title, self.domain)
self.title, component.DOMAIN)
result = False
# Only store setup result as state if it was not forwarded.
if self.domain != component.DOMAIN:
return
if result:
self.state = ENTRY_STATE_LOADED
else:
@ -322,6 +326,27 @@ class ConfigEntries:
entries = await self.hass.async_add_job(load_json, path)
self._entries = [ConfigEntry(**entry) for entry in entries]
async def async_forward_entry(self, entry, component):
"""Forward the setup of an entry to a different component.
By default an entry is setup with the component it belongs to. If that
component also has related platforms, the component will have to
forward the entry to be setup by that component.
You don't want to await this coroutine if it is called as part of the
setup of a component, because it can cause a deadlock.
"""
# Setup Component if not set up yet
if component not in self.hass.config.components:
result = await async_setup_component(
self.hass, component, self._hass_config)
if not result:
return False
await entry.async_setup(
self.hass, component=getattr(self.hass.components, component))
async def _async_add_entry(self, entry):
"""Add an entry."""
self._entries.append(entry)

View File

@ -93,6 +93,26 @@ class EntityComponent(object):
discovery.async_listen_platform(
self.hass, self.domain, component_platform_discovered)
async def async_setup_entry(self, config_entry):
"""Setup a config entry."""
platform_type = config_entry.domain
platform = await async_prepare_setup_platform(
self.hass, self.config, self.domain, platform_type)
if platform is None:
return False
key = config_entry.entry_id
if key in self._platforms:
raise ValueError('Config entry has already been setup!')
self._platforms[key] = self._async_init_entity_platform(
platform_type, platform
)
return await self._platforms[key].async_setup_entry(config_entry)
@callback
def async_extract_from_service(self, service, expand_group=True):
"""Extract all known and available entities from a service call.

View File

@ -1,15 +1,13 @@
"""Class to manage the entities for a single platform."""
import asyncio
from datetime import timedelta
from homeassistant.const import DEVICE_DEFAULT_NAME
from homeassistant.core import callback, valid_entity_id, split_entity_id
from homeassistant.exceptions import HomeAssistantError, PlatformNotReady
from homeassistant.util.async_ import (
run_callback_threadsafe, run_coroutine_threadsafe)
import homeassistant.util.dt as dt_util
from .event import async_track_time_interval, async_track_point_in_time
from .event import async_track_time_interval, async_call_later
from .entity_registry import async_get_registry
SLOW_SETUP_WARNING = 10
@ -42,6 +40,7 @@ class EntityPlatform(object):
self.scan_interval = scan_interval
self.entity_namespace = entity_namespace
self.async_entities_added_callback = async_entities_added_callback
self.config_entry = None
self.entities = {}
self._tasks = []
self._async_unsub_polling = None
@ -68,9 +67,47 @@ class EntityPlatform(object):
else:
self.parallel_updates = None
async def async_setup(self, platform_config, discovery_info=None, tries=0):
"""Setup the platform."""
async def async_setup(self, platform_config, discovery_info=None):
"""Setup the platform from a config file."""
platform = self.platform
hass = self.hass
@callback
def async_create_setup_task():
"""Get task to setup platform."""
if getattr(platform, 'async_setup_platform', None):
return platform.async_setup_platform(
hass, platform_config,
self._async_schedule_add_entities, discovery_info
)
# This should not be replaced with hass.async_add_job because
# we don't want to track this task in case it blocks startup.
return hass.loop.run_in_executor(
None, platform.setup_platform, hass, platform_config,
self._schedule_add_entities, discovery_info
)
await self._async_setup_platform(async_create_setup_task)
async def async_setup_entry(self, config_entry):
"""Setup the platform from a config entry."""
# Store it so that we can save config entry ID in entity registry
self.config_entry = config_entry
platform = self.platform
@callback
def async_create_setup_task():
"""Get task to setup platform."""
return platform.async_setup_entry(
self.hass, config_entry, self._async_schedule_add_entities)
return await self._async_setup_platform(async_create_setup_task)
async def _async_setup_platform(self, async_create_setup_task, tries=0):
"""Helper to setup a platform via config file or config entry.
async_create_setup_task creates a coroutine that sets up platform.
"""
logger = self.logger
hass = self.hass
full_name = '{}.{}'.format(self.domain, self.platform_name)
@ -82,18 +119,8 @@ class EntityPlatform(object):
self.platform_name, SLOW_SETUP_WARNING)
try:
if getattr(platform, 'async_setup_platform', None):
task = platform.async_setup_platform(
hass, platform_config,
self._async_schedule_add_entities, discovery_info
)
else:
# This should not be replaced with hass.async_add_job because
# we don't want to track this task in case it blocks startup.
task = hass.loop.run_in_executor(
None, platform.setup_platform, hass, platform_config,
self._schedule_add_entities, discovery_info
)
task = async_create_setup_task()
await asyncio.wait_for(
asyncio.shield(task, loop=hass.loop),
SLOW_SETUP_MAX_WAIT, loop=hass.loop)
@ -108,23 +135,31 @@ class EntityPlatform(object):
pending, loop=self.hass.loop)
hass.config.components.add(full_name)
return True
except PlatformNotReady:
tries += 1
wait_time = min(tries, 6) * 30
logger.warning(
'Platform %s not ready yet. Retrying in %d seconds.',
self.platform_name, wait_time)
async_track_point_in_time(
hass, self.async_setup(platform_config, discovery_info, tries),
dt_util.utcnow() + timedelta(seconds=wait_time))
async def setup_again(now):
"""Run setup again."""
await self._async_setup_platform(
async_create_setup_task, tries)
async_call_later(hass, wait_time, setup_again)
return False
except asyncio.TimeoutError:
logger.error(
"Setup of platform %s is taking longer than %s seconds."
" Startup will proceed without waiting any longer.",
self.platform_name, SLOW_SETUP_MAX_WAIT)
return False
except Exception: # pylint: disable=broad-except
logger.exception(
"Error while setting up platform %s", self.platform_name)
return False
finally:
warn_task.cancel()

View File

@ -344,7 +344,8 @@ class MockPlatform(object):
# pylint: disable=invalid-name
def __init__(self, setup_platform=None, dependencies=None,
platform_schema=None, async_setup_platform=None):
platform_schema=None, async_setup_platform=None,
async_setup_entry=None):
"""Initialize the platform."""
self.DEPENDENCIES = dependencies or []
@ -358,6 +359,9 @@ class MockPlatform(object):
if async_setup_platform is not None:
self.async_setup_platform = async_setup_platform
if async_setup_entry is not None:
self.async_setup_entry = async_setup_entry
if setup_platform is None and async_setup_platform is None:
self.async_setup_platform = mock_coro_func()
@ -376,6 +380,14 @@ class MockEntityPlatform(entity_platform.EntityPlatform):
async_entities_added_callback=lambda: None
):
"""Initialize a mock entity platform."""
if logger is None:
logger = logging.getLogger('homeassistant.helpers.entity_platform')
# Otherwise the constructor will blow up.
if (isinstance(platform, Mock) and
isinstance(platform.PARALLEL_UPDATES, Mock)):
platform.PARALLEL_UPDATES = 0
super().__init__(
hass=hass,
logger=logger,

View File

@ -18,10 +18,9 @@ async def test_bridge_setup():
assert await hue_bridge.async_setup() is True
assert hue_bridge.api is api
assert len(hass.helpers.discovery.async_load_platform.mock_calls) == 1
assert hass.helpers.discovery.async_load_platform.mock_calls[0][1][2] == {
'host': '1.2.3.4'
}
assert len(hass.config_entries.async_forward_entry.mock_calls) == 1
assert hass.config_entries.async_forward_entry.mock_calls[0][1] == \
(entry, 'light')
async def test_bridge_setup_invalid_username():

View File

@ -9,6 +9,7 @@ from aiohue.lights import Lights
from aiohue.groups import Groups
import pytest
from homeassistant import config_entries
from homeassistant.components import hue
import homeassistant.components.light.hue as hue_light
from homeassistant.util import color
@ -196,9 +197,11 @@ async def setup_bridge(hass, mock_bridge):
"""Load the Hue light platform with the provided bridge."""
hass.config.components.add(hue.DOMAIN)
hass.data[hue.DOMAIN] = {'mock-host': mock_bridge}
await hass.helpers.discovery.async_load_platform('light', 'hue', {
config_entry = config_entries.ConfigEntry(1, hue.DOMAIN, 'Mock Title', {
'host': 'mock-host'
})
}, 'test')
await hass.config_entries.async_forward_entry(config_entry, 'light')
# To flush out the service call to update the group
await hass.async_block_till_done()

View File

@ -7,6 +7,8 @@ import unittest
from unittest.mock import patch, Mock
from datetime import timedelta
import pytest
import homeassistant.core as ha
import homeassistant.loader as loader
from homeassistant.exceptions import PlatformNotReady
@ -19,7 +21,7 @@ import homeassistant.util.dt as dt_util
from tests.common import (
get_test_home_assistant, MockPlatform, MockModule, mock_coro,
async_fire_time_changed, MockEntity)
async_fire_time_changed, MockEntity, MockConfigEntry)
_LOGGER = logging.getLogger(__name__)
DOMAIN = "test_domain"
@ -333,3 +335,44 @@ def test_setup_dependencies_platform(hass):
assert 'test_component' in hass.config.components
assert 'test_component2' in hass.config.components
assert 'test_domain.test_component' in hass.config.components
async def test_setup_entry(hass):
"""Test setup entry calls async_setup_entry on platform."""
mock_setup_entry = Mock(return_value=mock_coro(True))
loader.set_component(
'test_domain.entry_domain',
MockPlatform(async_setup_entry=mock_setup_entry))
component = EntityComponent(_LOGGER, DOMAIN, hass)
entry = MockConfigEntry(domain='entry_domain')
assert await component.async_setup_entry(entry)
assert len(mock_setup_entry.mock_calls) == 1
p_hass, p_entry, p_add_entities = mock_setup_entry.mock_calls[0][1]
assert p_hass is hass
assert p_entry is entry
async def test_setup_entry_platform_not_exist(hass):
"""Test setup entry fails if platform doesnt exist."""
component = EntityComponent(_LOGGER, DOMAIN, hass)
entry = MockConfigEntry(domain='non_existing')
assert (await component.async_setup_entry(entry)) is False
async def test_setup_entry_fails_duplicate(hass):
"""Test we don't allow setting up a config entry twice."""
mock_setup_entry = Mock(return_value=mock_coro(True))
loader.set_component(
'test_domain.entry_domain',
MockPlatform(async_setup_entry=mock_setup_entry))
component = EntityComponent(_LOGGER, DOMAIN, hass)
entry = MockConfigEntry(domain='entry_domain')
assert await component.async_setup_entry(entry)
with pytest.raises(ValueError):
await component.async_setup_entry(entry)

View File

@ -5,6 +5,7 @@ import unittest
from unittest.mock import patch, Mock, MagicMock
from datetime import timedelta
from homeassistant.exceptions import PlatformNotReady
import homeassistant.loader as loader
from homeassistant.helpers.entity import generate_entity_id
from homeassistant.helpers.entity_component import (
@ -15,7 +16,7 @@ import homeassistant.util.dt as dt_util
from tests.common import (
get_test_home_assistant, MockPlatform, fire_time_changed, mock_registry,
MockEntity, MockEntityPlatform)
MockEntity, MockEntityPlatform, MockConfigEntry, mock_coro)
_LOGGER = logging.getLogger(__name__)
DOMAIN = "test_domain"
@ -511,3 +512,46 @@ async def test_entity_registry_updates(hass):
state = hass.states.get('test_domain.world')
assert state.name == 'after update'
async def test_setup_entry(hass):
"""Test we can setup an entry."""
async_setup_entry = Mock(return_value=mock_coro(True))
platform = MockPlatform(
async_setup_entry=async_setup_entry
)
config_entry = MockConfigEntry()
entity_platform = MockEntityPlatform(
hass,
platform_name=config_entry.domain,
platform=platform
)
assert await entity_platform.async_setup_entry(config_entry)
full_name = '{}.{}'.format(entity_platform.domain, config_entry.domain)
assert full_name in hass.config.components
assert len(async_setup_entry.mock_calls) == 1
async def test_setup_entry_platform_not_ready(hass, caplog):
"""Test when an entry is not ready yet."""
async_setup_entry = Mock(side_effect=PlatformNotReady)
platform = MockPlatform(
async_setup_entry=async_setup_entry
)
config_entry = MockConfigEntry()
ent_platform = MockEntityPlatform(
hass,
platform_name=config_entry.domain,
platform=platform
)
with patch.object(entity_platform, 'async_call_later') as mock_call_later:
assert not await ent_platform.async_setup_entry(config_entry)
full_name = '{}.{}'.format(ent_platform.domain, config_entry.domain)
assert full_name not in hass.config.components
assert len(async_setup_entry.mock_calls) == 1
assert 'Platform test not ready yet' in caplog.text
assert len(mock_call_later.mock_calls) == 1

View File

@ -389,3 +389,39 @@ def test_discovery_init_flow(manager):
assert entry.title == 'hello'
assert entry.data == data
assert entry.source == config_entries.SOURCE_DISCOVERY
async def test_forward_entry_sets_up_component(hass):
"""Test we setup the component entry is forwarded to."""
entry = MockConfigEntry(domain='original')
mock_original_setup_entry = MagicMock(return_value=mock_coro(True))
loader.set_component(
'original',
MockModule('original', async_setup_entry=mock_original_setup_entry))
mock_forwarded_setup_entry = MagicMock(return_value=mock_coro(True))
loader.set_component(
'forwarded',
MockModule('forwarded', async_setup_entry=mock_forwarded_setup_entry))
await hass.config_entries.async_forward_entry(entry, 'forwarded')
assert len(mock_original_setup_entry.mock_calls) == 0
assert len(mock_forwarded_setup_entry.mock_calls) == 1
async def test_forward_entry_does_not_setup_entry_if_setup_fails(hass):
"""Test we do not setup entry if component setup fails."""
entry = MockConfigEntry(domain='original')
mock_setup = MagicMock(return_value=mock_coro(False))
mock_setup_entry = MagicMock()
loader.set_component('forwarded', MockModule(
'forwarded',
async_setup=mock_setup,
async_setup_entry=mock_setup_entry,
))
await hass.config_entries.async_forward_entry(entry, 'forwarded')
assert len(mock_setup.mock_calls) == 1
assert len(mock_setup_entry.mock_calls) == 0