mirror of
https://github.com/home-assistant/core.git
synced 2025-04-23 00:37:53 +00:00
Use config entry to setup platforms (#13752)
* Use config entry to setup platforms * Rename to async_forward_entry * Add tests * Catch if platform not exists for entry
This commit is contained in:
parent
cb51553c2d
commit
73de749411
@ -39,18 +39,17 @@ class HueBridge(object):
|
||||
async def async_setup(self, tries=0):
|
||||
"""Set up a phue bridge based on host parameter."""
|
||||
host = self.host
|
||||
hass = self.hass
|
||||
|
||||
try:
|
||||
self.api = await get_bridge(
|
||||
self.hass, host,
|
||||
self.config_entry.data['username']
|
||||
)
|
||||
hass, host, self.config_entry.data['username'])
|
||||
except AuthenticationRequired:
|
||||
# usernames can become invalid if hub is reset or user removed.
|
||||
# We are going to fail the config entry setup and initiate a new
|
||||
# linking procedure. When linking succeeds, it will remove the
|
||||
# old config entry.
|
||||
self.hass.async_add_job(self.hass.config_entries.flow.async_init(
|
||||
hass.async_add_job(hass.config_entries.flow.async_init(
|
||||
DOMAIN, source='import', data={
|
||||
'host': host,
|
||||
}
|
||||
@ -69,7 +68,7 @@ class HueBridge(object):
|
||||
self.config_entry.state = config_entries.ENTRY_STATE_LOADED
|
||||
|
||||
# Unhandled edge case: cancel this if we discover bridge on new IP
|
||||
self.hass.helpers.event.async_call_later(retry_delay, retry_setup)
|
||||
hass.helpers.event.async_call_later(retry_delay, retry_setup)
|
||||
|
||||
return False
|
||||
|
||||
@ -78,11 +77,10 @@ class HueBridge(object):
|
||||
host)
|
||||
return False
|
||||
|
||||
self.hass.async_add_job(
|
||||
self.hass.helpers.discovery.async_load_platform(
|
||||
'light', DOMAIN, {'host': host}))
|
||||
hass.async_add_job(hass.config_entries.async_forward_entry(
|
||||
self.config_entry, 'light'))
|
||||
|
||||
self.hass.services.async_register(
|
||||
hass.services.async_register(
|
||||
DOMAIN, SERVICE_HUE_SCENE, self.hue_activate_scene,
|
||||
schema=SCENE_SCHEMA)
|
||||
|
||||
|
@ -334,7 +334,7 @@ class SetIntentHandler(intent.IntentHandler):
|
||||
|
||||
async def async_setup(hass, config):
|
||||
"""Expose light control via state machine and services."""
|
||||
component = EntityComponent(
|
||||
component = hass.data[DOMAIN] = EntityComponent(
|
||||
_LOGGER, DOMAIN, hass, SCAN_INTERVAL, GROUP_NAME_ALL_LIGHTS)
|
||||
await component.async_setup(config)
|
||||
|
||||
@ -388,6 +388,11 @@ async def async_setup(hass, config):
|
||||
return True
|
||||
|
||||
|
||||
async def async_setup_entry(hass, entry):
|
||||
"""Setup a config entry."""
|
||||
return await hass.data[DOMAIN].async_setup_entry(entry)
|
||||
|
||||
|
||||
class Profiles:
|
||||
"""Representation of available color profiles."""
|
||||
|
||||
|
@ -49,11 +49,17 @@ GROUP_MIN_API_VERSION = (1, 13, 0)
|
||||
|
||||
async def async_setup_platform(hass, config, async_add_devices,
|
||||
discovery_info=None):
|
||||
"""Set up the Hue lights."""
|
||||
if discovery_info is None:
|
||||
return
|
||||
"""Old way of setting up Hue lights.
|
||||
|
||||
bridge = hass.data[hue.DOMAIN][discovery_info['host']]
|
||||
Can only be called when a user accidentally mentions hue platform in their
|
||||
config. But even in that case it would have been ignored.
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
async def async_setup_entry(hass, config_entry, async_add_devices):
|
||||
"""Set up the Hue lights from a config entry."""
|
||||
bridge = hass.data[hue.DOMAIN][config_entry.data['host']]
|
||||
cur_lights = {}
|
||||
cur_groups = {}
|
||||
|
||||
|
@ -187,13 +187,17 @@ class ConfigEntry:
|
||||
|
||||
if not isinstance(result, bool):
|
||||
_LOGGER.error('%s.async_config_entry did not return boolean',
|
||||
self.domain)
|
||||
component.DOMAIN)
|
||||
result = False
|
||||
except Exception: # pylint: disable=broad-except
|
||||
_LOGGER.exception('Error setting up entry %s for %s',
|
||||
self.title, self.domain)
|
||||
self.title, component.DOMAIN)
|
||||
result = False
|
||||
|
||||
# Only store setup result as state if it was not forwarded.
|
||||
if self.domain != component.DOMAIN:
|
||||
return
|
||||
|
||||
if result:
|
||||
self.state = ENTRY_STATE_LOADED
|
||||
else:
|
||||
@ -322,6 +326,27 @@ class ConfigEntries:
|
||||
entries = await self.hass.async_add_job(load_json, path)
|
||||
self._entries = [ConfigEntry(**entry) for entry in entries]
|
||||
|
||||
async def async_forward_entry(self, entry, component):
|
||||
"""Forward the setup of an entry to a different component.
|
||||
|
||||
By default an entry is setup with the component it belongs to. If that
|
||||
component also has related platforms, the component will have to
|
||||
forward the entry to be setup by that component.
|
||||
|
||||
You don't want to await this coroutine if it is called as part of the
|
||||
setup of a component, because it can cause a deadlock.
|
||||
"""
|
||||
# Setup Component if not set up yet
|
||||
if component not in self.hass.config.components:
|
||||
result = await async_setup_component(
|
||||
self.hass, component, self._hass_config)
|
||||
|
||||
if not result:
|
||||
return False
|
||||
|
||||
await entry.async_setup(
|
||||
self.hass, component=getattr(self.hass.components, component))
|
||||
|
||||
async def _async_add_entry(self, entry):
|
||||
"""Add an entry."""
|
||||
self._entries.append(entry)
|
||||
|
@ -93,6 +93,26 @@ class EntityComponent(object):
|
||||
discovery.async_listen_platform(
|
||||
self.hass, self.domain, component_platform_discovered)
|
||||
|
||||
async def async_setup_entry(self, config_entry):
|
||||
"""Setup a config entry."""
|
||||
platform_type = config_entry.domain
|
||||
platform = await async_prepare_setup_platform(
|
||||
self.hass, self.config, self.domain, platform_type)
|
||||
|
||||
if platform is None:
|
||||
return False
|
||||
|
||||
key = config_entry.entry_id
|
||||
|
||||
if key in self._platforms:
|
||||
raise ValueError('Config entry has already been setup!')
|
||||
|
||||
self._platforms[key] = self._async_init_entity_platform(
|
||||
platform_type, platform
|
||||
)
|
||||
|
||||
return await self._platforms[key].async_setup_entry(config_entry)
|
||||
|
||||
@callback
|
||||
def async_extract_from_service(self, service, expand_group=True):
|
||||
"""Extract all known and available entities from a service call.
|
||||
|
@ -1,15 +1,13 @@
|
||||
"""Class to manage the entities for a single platform."""
|
||||
import asyncio
|
||||
from datetime import timedelta
|
||||
|
||||
from homeassistant.const import DEVICE_DEFAULT_NAME
|
||||
from homeassistant.core import callback, valid_entity_id, split_entity_id
|
||||
from homeassistant.exceptions import HomeAssistantError, PlatformNotReady
|
||||
from homeassistant.util.async_ import (
|
||||
run_callback_threadsafe, run_coroutine_threadsafe)
|
||||
import homeassistant.util.dt as dt_util
|
||||
|
||||
from .event import async_track_time_interval, async_track_point_in_time
|
||||
from .event import async_track_time_interval, async_call_later
|
||||
from .entity_registry import async_get_registry
|
||||
|
||||
SLOW_SETUP_WARNING = 10
|
||||
@ -42,6 +40,7 @@ class EntityPlatform(object):
|
||||
self.scan_interval = scan_interval
|
||||
self.entity_namespace = entity_namespace
|
||||
self.async_entities_added_callback = async_entities_added_callback
|
||||
self.config_entry = None
|
||||
self.entities = {}
|
||||
self._tasks = []
|
||||
self._async_unsub_polling = None
|
||||
@ -68,9 +67,47 @@ class EntityPlatform(object):
|
||||
else:
|
||||
self.parallel_updates = None
|
||||
|
||||
async def async_setup(self, platform_config, discovery_info=None, tries=0):
|
||||
"""Setup the platform."""
|
||||
async def async_setup(self, platform_config, discovery_info=None):
|
||||
"""Setup the platform from a config file."""
|
||||
platform = self.platform
|
||||
hass = self.hass
|
||||
|
||||
@callback
|
||||
def async_create_setup_task():
|
||||
"""Get task to setup platform."""
|
||||
if getattr(platform, 'async_setup_platform', None):
|
||||
return platform.async_setup_platform(
|
||||
hass, platform_config,
|
||||
self._async_schedule_add_entities, discovery_info
|
||||
)
|
||||
|
||||
# This should not be replaced with hass.async_add_job because
|
||||
# we don't want to track this task in case it blocks startup.
|
||||
return hass.loop.run_in_executor(
|
||||
None, platform.setup_platform, hass, platform_config,
|
||||
self._schedule_add_entities, discovery_info
|
||||
)
|
||||
await self._async_setup_platform(async_create_setup_task)
|
||||
|
||||
async def async_setup_entry(self, config_entry):
|
||||
"""Setup the platform from a config entry."""
|
||||
# Store it so that we can save config entry ID in entity registry
|
||||
self.config_entry = config_entry
|
||||
platform = self.platform
|
||||
|
||||
@callback
|
||||
def async_create_setup_task():
|
||||
"""Get task to setup platform."""
|
||||
return platform.async_setup_entry(
|
||||
self.hass, config_entry, self._async_schedule_add_entities)
|
||||
|
||||
return await self._async_setup_platform(async_create_setup_task)
|
||||
|
||||
async def _async_setup_platform(self, async_create_setup_task, tries=0):
|
||||
"""Helper to setup a platform via config file or config entry.
|
||||
|
||||
async_create_setup_task creates a coroutine that sets up platform.
|
||||
"""
|
||||
logger = self.logger
|
||||
hass = self.hass
|
||||
full_name = '{}.{}'.format(self.domain, self.platform_name)
|
||||
@ -82,18 +119,8 @@ class EntityPlatform(object):
|
||||
self.platform_name, SLOW_SETUP_WARNING)
|
||||
|
||||
try:
|
||||
if getattr(platform, 'async_setup_platform', None):
|
||||
task = platform.async_setup_platform(
|
||||
hass, platform_config,
|
||||
self._async_schedule_add_entities, discovery_info
|
||||
)
|
||||
else:
|
||||
# This should not be replaced with hass.async_add_job because
|
||||
# we don't want to track this task in case it blocks startup.
|
||||
task = hass.loop.run_in_executor(
|
||||
None, platform.setup_platform, hass, platform_config,
|
||||
self._schedule_add_entities, discovery_info
|
||||
)
|
||||
task = async_create_setup_task()
|
||||
|
||||
await asyncio.wait_for(
|
||||
asyncio.shield(task, loop=hass.loop),
|
||||
SLOW_SETUP_MAX_WAIT, loop=hass.loop)
|
||||
@ -108,23 +135,31 @@ class EntityPlatform(object):
|
||||
pending, loop=self.hass.loop)
|
||||
|
||||
hass.config.components.add(full_name)
|
||||
return True
|
||||
except PlatformNotReady:
|
||||
tries += 1
|
||||
wait_time = min(tries, 6) * 30
|
||||
logger.warning(
|
||||
'Platform %s not ready yet. Retrying in %d seconds.',
|
||||
self.platform_name, wait_time)
|
||||
async_track_point_in_time(
|
||||
hass, self.async_setup(platform_config, discovery_info, tries),
|
||||
dt_util.utcnow() + timedelta(seconds=wait_time))
|
||||
|
||||
async def setup_again(now):
|
||||
"""Run setup again."""
|
||||
await self._async_setup_platform(
|
||||
async_create_setup_task, tries)
|
||||
|
||||
async_call_later(hass, wait_time, setup_again)
|
||||
return False
|
||||
except asyncio.TimeoutError:
|
||||
logger.error(
|
||||
"Setup of platform %s is taking longer than %s seconds."
|
||||
" Startup will proceed without waiting any longer.",
|
||||
self.platform_name, SLOW_SETUP_MAX_WAIT)
|
||||
return False
|
||||
except Exception: # pylint: disable=broad-except
|
||||
logger.exception(
|
||||
"Error while setting up platform %s", self.platform_name)
|
||||
return False
|
||||
finally:
|
||||
warn_task.cancel()
|
||||
|
||||
|
@ -344,7 +344,8 @@ class MockPlatform(object):
|
||||
|
||||
# pylint: disable=invalid-name
|
||||
def __init__(self, setup_platform=None, dependencies=None,
|
||||
platform_schema=None, async_setup_platform=None):
|
||||
platform_schema=None, async_setup_platform=None,
|
||||
async_setup_entry=None):
|
||||
"""Initialize the platform."""
|
||||
self.DEPENDENCIES = dependencies or []
|
||||
|
||||
@ -358,6 +359,9 @@ class MockPlatform(object):
|
||||
if async_setup_platform is not None:
|
||||
self.async_setup_platform = async_setup_platform
|
||||
|
||||
if async_setup_entry is not None:
|
||||
self.async_setup_entry = async_setup_entry
|
||||
|
||||
if setup_platform is None and async_setup_platform is None:
|
||||
self.async_setup_platform = mock_coro_func()
|
||||
|
||||
@ -376,6 +380,14 @@ class MockEntityPlatform(entity_platform.EntityPlatform):
|
||||
async_entities_added_callback=lambda: None
|
||||
):
|
||||
"""Initialize a mock entity platform."""
|
||||
if logger is None:
|
||||
logger = logging.getLogger('homeassistant.helpers.entity_platform')
|
||||
|
||||
# Otherwise the constructor will blow up.
|
||||
if (isinstance(platform, Mock) and
|
||||
isinstance(platform.PARALLEL_UPDATES, Mock)):
|
||||
platform.PARALLEL_UPDATES = 0
|
||||
|
||||
super().__init__(
|
||||
hass=hass,
|
||||
logger=logger,
|
||||
|
@ -18,10 +18,9 @@ async def test_bridge_setup():
|
||||
assert await hue_bridge.async_setup() is True
|
||||
|
||||
assert hue_bridge.api is api
|
||||
assert len(hass.helpers.discovery.async_load_platform.mock_calls) == 1
|
||||
assert hass.helpers.discovery.async_load_platform.mock_calls[0][1][2] == {
|
||||
'host': '1.2.3.4'
|
||||
}
|
||||
assert len(hass.config_entries.async_forward_entry.mock_calls) == 1
|
||||
assert hass.config_entries.async_forward_entry.mock_calls[0][1] == \
|
||||
(entry, 'light')
|
||||
|
||||
|
||||
async def test_bridge_setup_invalid_username():
|
||||
|
@ -9,6 +9,7 @@ from aiohue.lights import Lights
|
||||
from aiohue.groups import Groups
|
||||
import pytest
|
||||
|
||||
from homeassistant import config_entries
|
||||
from homeassistant.components import hue
|
||||
import homeassistant.components.light.hue as hue_light
|
||||
from homeassistant.util import color
|
||||
@ -196,9 +197,11 @@ async def setup_bridge(hass, mock_bridge):
|
||||
"""Load the Hue light platform with the provided bridge."""
|
||||
hass.config.components.add(hue.DOMAIN)
|
||||
hass.data[hue.DOMAIN] = {'mock-host': mock_bridge}
|
||||
await hass.helpers.discovery.async_load_platform('light', 'hue', {
|
||||
config_entry = config_entries.ConfigEntry(1, hue.DOMAIN, 'Mock Title', {
|
||||
'host': 'mock-host'
|
||||
})
|
||||
}, 'test')
|
||||
await hass.config_entries.async_forward_entry(config_entry, 'light')
|
||||
# To flush out the service call to update the group
|
||||
await hass.async_block_till_done()
|
||||
|
||||
|
||||
|
@ -7,6 +7,8 @@ import unittest
|
||||
from unittest.mock import patch, Mock
|
||||
from datetime import timedelta
|
||||
|
||||
import pytest
|
||||
|
||||
import homeassistant.core as ha
|
||||
import homeassistant.loader as loader
|
||||
from homeassistant.exceptions import PlatformNotReady
|
||||
@ -19,7 +21,7 @@ import homeassistant.util.dt as dt_util
|
||||
|
||||
from tests.common import (
|
||||
get_test_home_assistant, MockPlatform, MockModule, mock_coro,
|
||||
async_fire_time_changed, MockEntity)
|
||||
async_fire_time_changed, MockEntity, MockConfigEntry)
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
DOMAIN = "test_domain"
|
||||
@ -333,3 +335,44 @@ def test_setup_dependencies_platform(hass):
|
||||
assert 'test_component' in hass.config.components
|
||||
assert 'test_component2' in hass.config.components
|
||||
assert 'test_domain.test_component' in hass.config.components
|
||||
|
||||
|
||||
async def test_setup_entry(hass):
|
||||
"""Test setup entry calls async_setup_entry on platform."""
|
||||
mock_setup_entry = Mock(return_value=mock_coro(True))
|
||||
loader.set_component(
|
||||
'test_domain.entry_domain',
|
||||
MockPlatform(async_setup_entry=mock_setup_entry))
|
||||
|
||||
component = EntityComponent(_LOGGER, DOMAIN, hass)
|
||||
entry = MockConfigEntry(domain='entry_domain')
|
||||
|
||||
assert await component.async_setup_entry(entry)
|
||||
assert len(mock_setup_entry.mock_calls) == 1
|
||||
p_hass, p_entry, p_add_entities = mock_setup_entry.mock_calls[0][1]
|
||||
assert p_hass is hass
|
||||
assert p_entry is entry
|
||||
|
||||
|
||||
async def test_setup_entry_platform_not_exist(hass):
|
||||
"""Test setup entry fails if platform doesnt exist."""
|
||||
component = EntityComponent(_LOGGER, DOMAIN, hass)
|
||||
entry = MockConfigEntry(domain='non_existing')
|
||||
|
||||
assert (await component.async_setup_entry(entry)) is False
|
||||
|
||||
|
||||
async def test_setup_entry_fails_duplicate(hass):
|
||||
"""Test we don't allow setting up a config entry twice."""
|
||||
mock_setup_entry = Mock(return_value=mock_coro(True))
|
||||
loader.set_component(
|
||||
'test_domain.entry_domain',
|
||||
MockPlatform(async_setup_entry=mock_setup_entry))
|
||||
|
||||
component = EntityComponent(_LOGGER, DOMAIN, hass)
|
||||
entry = MockConfigEntry(domain='entry_domain')
|
||||
|
||||
assert await component.async_setup_entry(entry)
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
await component.async_setup_entry(entry)
|
||||
|
@ -5,6 +5,7 @@ import unittest
|
||||
from unittest.mock import patch, Mock, MagicMock
|
||||
from datetime import timedelta
|
||||
|
||||
from homeassistant.exceptions import PlatformNotReady
|
||||
import homeassistant.loader as loader
|
||||
from homeassistant.helpers.entity import generate_entity_id
|
||||
from homeassistant.helpers.entity_component import (
|
||||
@ -15,7 +16,7 @@ import homeassistant.util.dt as dt_util
|
||||
|
||||
from tests.common import (
|
||||
get_test_home_assistant, MockPlatform, fire_time_changed, mock_registry,
|
||||
MockEntity, MockEntityPlatform)
|
||||
MockEntity, MockEntityPlatform, MockConfigEntry, mock_coro)
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
DOMAIN = "test_domain"
|
||||
@ -511,3 +512,46 @@ async def test_entity_registry_updates(hass):
|
||||
|
||||
state = hass.states.get('test_domain.world')
|
||||
assert state.name == 'after update'
|
||||
|
||||
|
||||
async def test_setup_entry(hass):
|
||||
"""Test we can setup an entry."""
|
||||
async_setup_entry = Mock(return_value=mock_coro(True))
|
||||
platform = MockPlatform(
|
||||
async_setup_entry=async_setup_entry
|
||||
)
|
||||
config_entry = MockConfigEntry()
|
||||
entity_platform = MockEntityPlatform(
|
||||
hass,
|
||||
platform_name=config_entry.domain,
|
||||
platform=platform
|
||||
)
|
||||
|
||||
assert await entity_platform.async_setup_entry(config_entry)
|
||||
|
||||
full_name = '{}.{}'.format(entity_platform.domain, config_entry.domain)
|
||||
assert full_name in hass.config.components
|
||||
assert len(async_setup_entry.mock_calls) == 1
|
||||
|
||||
|
||||
async def test_setup_entry_platform_not_ready(hass, caplog):
|
||||
"""Test when an entry is not ready yet."""
|
||||
async_setup_entry = Mock(side_effect=PlatformNotReady)
|
||||
platform = MockPlatform(
|
||||
async_setup_entry=async_setup_entry
|
||||
)
|
||||
config_entry = MockConfigEntry()
|
||||
ent_platform = MockEntityPlatform(
|
||||
hass,
|
||||
platform_name=config_entry.domain,
|
||||
platform=platform
|
||||
)
|
||||
|
||||
with patch.object(entity_platform, 'async_call_later') as mock_call_later:
|
||||
assert not await ent_platform.async_setup_entry(config_entry)
|
||||
|
||||
full_name = '{}.{}'.format(ent_platform.domain, config_entry.domain)
|
||||
assert full_name not in hass.config.components
|
||||
assert len(async_setup_entry.mock_calls) == 1
|
||||
assert 'Platform test not ready yet' in caplog.text
|
||||
assert len(mock_call_later.mock_calls) == 1
|
||||
|
@ -389,3 +389,39 @@ def test_discovery_init_flow(manager):
|
||||
assert entry.title == 'hello'
|
||||
assert entry.data == data
|
||||
assert entry.source == config_entries.SOURCE_DISCOVERY
|
||||
|
||||
|
||||
async def test_forward_entry_sets_up_component(hass):
|
||||
"""Test we setup the component entry is forwarded to."""
|
||||
entry = MockConfigEntry(domain='original')
|
||||
|
||||
mock_original_setup_entry = MagicMock(return_value=mock_coro(True))
|
||||
loader.set_component(
|
||||
'original',
|
||||
MockModule('original', async_setup_entry=mock_original_setup_entry))
|
||||
|
||||
mock_forwarded_setup_entry = MagicMock(return_value=mock_coro(True))
|
||||
loader.set_component(
|
||||
'forwarded',
|
||||
MockModule('forwarded', async_setup_entry=mock_forwarded_setup_entry))
|
||||
|
||||
await hass.config_entries.async_forward_entry(entry, 'forwarded')
|
||||
assert len(mock_original_setup_entry.mock_calls) == 0
|
||||
assert len(mock_forwarded_setup_entry.mock_calls) == 1
|
||||
|
||||
|
||||
async def test_forward_entry_does_not_setup_entry_if_setup_fails(hass):
|
||||
"""Test we do not setup entry if component setup fails."""
|
||||
entry = MockConfigEntry(domain='original')
|
||||
|
||||
mock_setup = MagicMock(return_value=mock_coro(False))
|
||||
mock_setup_entry = MagicMock()
|
||||
loader.set_component('forwarded', MockModule(
|
||||
'forwarded',
|
||||
async_setup=mock_setup,
|
||||
async_setup_entry=mock_setup_entry,
|
||||
))
|
||||
|
||||
await hass.config_entries.async_forward_entry(entry, 'forwarded')
|
||||
assert len(mock_setup.mock_calls) == 1
|
||||
assert len(mock_setup_entry.mock_calls) == 0
|
||||
|
Loading…
x
Reference in New Issue
Block a user