diff --git a/homeassistant/components/group/__init__.py b/homeassistant/components/group/__init__.py index cda49591d5c..87eb2cd615b 100644 --- a/homeassistant/components/group/__init__.py +++ b/homeassistant/components/group/__init__.py @@ -58,7 +58,7 @@ ATTR_ALL = "all" SERVICE_SET = "set" SERVICE_REMOVE = "remove" -PLATFORMS = ["light", "cover"] +PLATFORMS = ["light", "cover", "notify"] _LOGGER = logging.getLogger(__name__) diff --git a/homeassistant/components/group/services.yaml b/homeassistant/components/group/services.yaml index cec4f187ca6..57e11d672dc 100644 --- a/homeassistant/components/group/services.yaml +++ b/homeassistant/components/group/services.yaml @@ -1,6 +1,6 @@ # Describes the format for available group services reload: - description: Reload group configuration. + description: Reload group configuration, entities, and notify services. set: description: Create/Update a user group. diff --git a/homeassistant/components/notify/__init__.py b/homeassistant/components/notify/__init__.py index 77ec48c5435..9a4ec681aab 100644 --- a/homeassistant/components/notify/__init__.py +++ b/homeassistant/components/notify/__init__.py @@ -2,11 +2,12 @@ import asyncio from functools import partial import logging -from typing import Optional +from typing import Any, Dict, Optional import voluptuous as vol from homeassistant.const import CONF_NAME, CONF_PLATFORM +from homeassistant.core import ServiceCall from homeassistant.exceptions import HomeAssistantError from homeassistant.helpers import config_per_platform, discovery import homeassistant.helpers.config_validation as cv @@ -37,10 +38,6 @@ DOMAIN = "notify" SERVICE_NOTIFY = "notify" NOTIFY_SERVICES = "notify_services" -SERVICE = "service" -TARGETS = "targets" -FRIENDLY_NAME = "friendly_name" -TARGET_FRIENDLY_NAME = "target_friendly_name" PLATFORM_SCHEMA = vol.Schema( {vol.Required(CONF_PLATFORM): cv.string, vol.Optional(CONF_NAME): cv.string}, @@ -58,88 +55,160 @@ NOTIFY_SERVICE_SCHEMA = vol.Schema( @bind_hass -async def async_reload(hass, integration_name): +async def async_reload(hass: HomeAssistantType, integration_name: str) -> None: """Register notify services for an integration.""" + if not _async_integration_has_notify_services(hass, integration_name): + return + + tasks = [ + notify_service.async_register_services() + for notify_service in hass.data[NOTIFY_SERVICES][integration_name] + ] + + await asyncio.gather(*tasks) + + +@bind_hass +async def async_reset_platform(hass: HomeAssistantType, integration_name: str) -> None: + """Unregister notify services for an integration.""" + if not _async_integration_has_notify_services(hass, integration_name): + return + + tasks = [ + notify_service.async_unregister_services() + for notify_service in hass.data[NOTIFY_SERVICES][integration_name] + ] + + await asyncio.gather(*tasks) + + del hass.data[NOTIFY_SERVICES][integration_name] + + +def _async_integration_has_notify_services( + hass: HomeAssistantType, integration_name: str +) -> bool: + """Determine if an integration has notify services registered.""" if ( NOTIFY_SERVICES not in hass.data or integration_name not in hass.data[NOTIFY_SERVICES] ): - return + return False - tasks = [ - _async_setup_notify_services(hass, data) - for data in hass.data[NOTIFY_SERVICES][integration_name] - ] - await asyncio.gather(*tasks) + return True -async def _async_setup_notify_services(hass, data): - """Create or remove the notify services.""" - notify_service = data[SERVICE] - friendly_name = data[FRIENDLY_NAME] - targets = data[TARGETS] +class BaseNotificationService: + """An abstract class for notification services.""" - async def _async_notify_message(service): + hass: Optional[HomeAssistantType] = None + + def send_message(self, message, **kwargs): + """Send a message. + + kwargs can contain ATTR_TITLE to specify a title. + """ + raise NotImplementedError() + + async def async_send_message(self, message: Any, **kwargs: Any) -> None: + """Send a message. + + kwargs can contain ATTR_TITLE to specify a title. + """ + await self.hass.async_add_job(partial(self.send_message, message, **kwargs)) # type: ignore + + async def _async_notify_message_service(self, service: ServiceCall) -> None: """Handle sending notification message service calls.""" - await _async_notify_message_service(hass, service, notify_service, targets) + kwargs = {} + message = service.data[ATTR_MESSAGE] + title = service.data.get(ATTR_TITLE) - if hasattr(notify_service, "targets"): - target_friendly_name = data[TARGET_FRIENDLY_NAME] - stale_targets = set(targets) + if title: + title.hass = self.hass + kwargs[ATTR_TITLE] = title.async_render() - for name, target in notify_service.targets.items(): - target_name = slugify(f"{target_friendly_name}_{name}") - if target_name in stale_targets: - stale_targets.remove(target_name) - if target_name in targets: - continue - targets[target_name] = target - hass.services.async_register( - DOMAIN, - target_name, - _async_notify_message, - schema=NOTIFY_SERVICE_SCHEMA, - ) + if self._registered_targets.get(service.service) is not None: + kwargs[ATTR_TARGET] = [self._registered_targets[service.service]] + elif service.data.get(ATTR_TARGET) is not None: + kwargs[ATTR_TARGET] = service.data.get(ATTR_TARGET) - for stale_target_name in stale_targets: - del targets[stale_target_name] - hass.services.async_remove( - DOMAIN, - stale_target_name, - ) + message.hass = self.hass + kwargs[ATTR_MESSAGE] = message.async_render() + kwargs[ATTR_DATA] = service.data.get(ATTR_DATA) - friendly_name_slug = slugify(friendly_name) - if hass.services.has_service(DOMAIN, friendly_name_slug): - return + await self.async_send_message(**kwargs) - hass.services.async_register( - DOMAIN, - friendly_name_slug, - _async_notify_message, - schema=NOTIFY_SERVICE_SCHEMA, - ) + async def async_setup( + self, + hass: HomeAssistantType, + service_name: str, + target_service_name_prefix: str, + ) -> None: + """Store the data for the notify service.""" + # pylint: disable=attribute-defined-outside-init + self.hass = hass + self._service_name = service_name + self._target_service_name_prefix = target_service_name_prefix + self._registered_targets: Dict = {} + async def async_register_services(self) -> None: + """Create or update the notify services.""" + assert self.hass -async def _async_notify_message_service(hass, service, notify_service, targets): - """Handle sending notification message service calls.""" - kwargs = {} - message = service.data[ATTR_MESSAGE] - title = service.data.get(ATTR_TITLE) + if hasattr(self, "targets"): + stale_targets = set(self._registered_targets) - if title: - title.hass = hass - kwargs[ATTR_TITLE] = title.async_render() + # pylint: disable=no-member + for name, target in self.targets.items(): # type: ignore + target_name = slugify(f"{self._target_service_name_prefix}_{name}") + if target_name in stale_targets: + stale_targets.remove(target_name) + if target_name in self._registered_targets: + continue + self._registered_targets[target_name] = target + self.hass.services.async_register( + DOMAIN, + target_name, + self._async_notify_message_service, + schema=NOTIFY_SERVICE_SCHEMA, + ) - if targets.get(service.service) is not None: - kwargs[ATTR_TARGET] = [targets[service.service]] - elif service.data.get(ATTR_TARGET) is not None: - kwargs[ATTR_TARGET] = service.data.get(ATTR_TARGET) + for stale_target_name in stale_targets: + del self._registered_targets[stale_target_name] + self.hass.services.async_remove( + DOMAIN, + stale_target_name, + ) - message.hass = hass - kwargs[ATTR_MESSAGE] = message.async_render() - kwargs[ATTR_DATA] = service.data.get(ATTR_DATA) + if self.hass.services.has_service(DOMAIN, self._service_name): + return - await notify_service.async_send_message(**kwargs) + self.hass.services.async_register( + DOMAIN, + self._service_name, + self._async_notify_message_service, + schema=NOTIFY_SERVICE_SCHEMA, + ) + + async def async_unregister_services(self) -> None: + """Unregister the notify services.""" + assert self.hass + + if self._registered_targets: + remove_targets = set(self._registered_targets) + for remove_target_name in remove_targets: + del self._registered_targets[remove_target_name] + self.hass.services.async_remove( + DOMAIN, + remove_target_name, + ) + + if not self.hass.services.has_service(DOMAIN, self._service_name): + return + + self.hass.services.async_remove( + DOMAIN, + self._service_name, + ) async def async_setup(hass, config): @@ -188,31 +257,19 @@ async def async_setup(hass, config): _LOGGER.exception("Error setting up platform %s", integration_name) return - notify_service.hass = hass - if discovery_info is None: discovery_info = {} - target_friendly_name = ( - p_config.get(CONF_NAME) or discovery_info.get(CONF_NAME) or integration_name + conf_name = p_config.get(CONF_NAME) or discovery_info.get(CONF_NAME) + target_service_name_prefix = conf_name or integration_name + service_name = slugify(conf_name or SERVICE_NOTIFY) + + await notify_service.async_setup(hass, service_name, target_service_name_prefix) + await notify_service.async_register_services() + + hass.data[NOTIFY_SERVICES].setdefault(integration_name, []).append( + notify_service ) - friendly_name = ( - p_config.get(CONF_NAME) or discovery_info.get(CONF_NAME) or SERVICE_NOTIFY - ) - - data = { - FRIENDLY_NAME: friendly_name, - # The targets use a slightly different friendly name - # selection pattern than the base service - TARGET_FRIENDLY_NAME: target_friendly_name, - SERVICE: notify_service, - TARGETS: {}, - } - hass.data[NOTIFY_SERVICES].setdefault(integration_name, []) - hass.data[NOTIFY_SERVICES][integration_name].append(data) - - await _async_setup_notify_services(hass, data) - hass.config.components.add(f"{DOMAIN}.{integration_name}") return True @@ -232,23 +289,3 @@ async def async_setup(hass, config): discovery.async_listen_platform(hass, DOMAIN, async_platform_discovered) return True - - -class BaseNotificationService: - """An abstract class for notification services.""" - - hass: Optional[HomeAssistantType] = None - - def send_message(self, message, **kwargs): - """Send a message. - - kwargs can contain ATTR_TITLE to specify a title. - """ - raise NotImplementedError() - - async def async_send_message(self, message, **kwargs): - """Send a message. - - kwargs can contain ATTR_TITLE to specify a title. - """ - await self.hass.async_add_job(partial(self.send_message, message, **kwargs)) diff --git a/homeassistant/helpers/reload.py b/homeassistant/helpers/reload.py index 1ffba25ce15..19a1e46a6d4 100644 --- a/homeassistant/helpers/reload.py +++ b/homeassistant/helpers/reload.py @@ -2,7 +2,7 @@ import asyncio import logging -from typing import Any, Dict, Iterable, Optional +from typing import Any, Dict, Iterable, List, Optional from homeassistant import config as conf_util from homeassistant.const import SERVICE_RELOAD @@ -12,6 +12,7 @@ from homeassistant.helpers import config_per_platform from homeassistant.helpers.entity_platform import DATA_ENTITY_PLATFORM, EntityPlatform from homeassistant.helpers.typing import HomeAssistantType from homeassistant.loader import async_get_integration +from homeassistant.setup import async_setup_component _LOGGER = logging.getLogger(__name__) @@ -34,29 +35,94 @@ async def async_reload_integration_platforms( _LOGGER.error(err) return - for integration_platform in integration_platforms: - platform = async_get_platform(hass, integration_name, integration_platform) - - if not platform: - continue - - integration = await async_get_integration(hass, integration_platform) - - conf = await conf_util.async_process_component_config( - hass, unprocessed_conf, integration + tasks = [ + _resetup_platform( + hass, integration_name, integration_platform, unprocessed_conf ) + for integration_platform in integration_platforms + ] - if not conf: + await asyncio.gather(*tasks) + + +async def _resetup_platform( + hass: HomeAssistantType, + integration_name: str, + integration_platform: str, + unprocessed_conf: Dict, +) -> None: + """Resetup a platform.""" + integration = await async_get_integration(hass, integration_platform) + + conf = await conf_util.async_process_component_config( + hass, unprocessed_conf, integration + ) + + if not conf: + return + + root_config: Dict = {integration_platform: []} + # Extract only the config for template, ignore the rest. + for p_type, p_config in config_per_platform(conf, integration_platform): + if p_type != integration_name: continue - await platform.async_reset() + root_config[integration_platform].append(p_config) - # Extract only the config for template, ignore the rest. - for p_type, p_config in config_per_platform(conf, integration_platform): - if p_type != integration_name: - continue + component = integration.get_component() - await platform.async_setup(p_config) # type: ignore + if hasattr(component, "async_reset_platform"): + # If the integration has its own way to reset + # use this method. + await component.async_reset_platform(hass, integration_name) # type: ignore + await component.async_setup(hass, root_config) # type: ignore + return + + # If its an entity platform, we use the entity_platform + # async_reset method + platform = async_get_platform(hass, integration_name, integration_platform) + if platform: + await _async_reconfig_platform(platform, root_config[integration_platform]) + return + + if not root_config[integration_platform]: + # No config for this platform + # and its not loaded. Nothing to do + return + + await _async_setup_platform( + hass, integration_name, integration_platform, root_config[integration_platform] + ) + + +async def _async_setup_platform( + hass: HomeAssistantType, + integration_name: str, + integration_platform: str, + platform_configs: List[Dict], +) -> None: + """Platform for the first time when new configuration is added.""" + if integration_platform not in hass.data: + await async_setup_component( + hass, integration_platform, {integration_platform: platform_configs} + ) + return + + entity_component = hass.data[integration_platform] + tasks = [ + entity_component.async_setup_platform(integration_name, p_config) + for p_config in platform_configs + ] + await asyncio.gather(*tasks) + + +async def _async_reconfig_platform( + platform: EntityPlatform, platform_configs: List[Dict] +) -> None: + """Reconfigure an already loaded platform.""" + await platform.async_reset() + tasks = [platform.async_setup(p_config) for p_config in platform_configs] # type: ignore + await asyncio.gather(*tasks) async def async_integration_yaml_config( diff --git a/tests/components/group/test_light.py b/tests/components/group/test_light.py index 8855ca97626..a22c56b4bfc 100644 --- a/tests/components/group/test_light.py +++ b/tests/components/group/test_light.py @@ -683,5 +683,79 @@ async def test_reload(hass): assert hass.states.get("light.outside_patio_lights_g") is not None +async def test_reload_with_platform_not_setup(hass): + """Test the ability to reload lights.""" + hass.states.async_set("light.bowl", STATE_ON) + await async_setup_component( + hass, + LIGHT_DOMAIN, + { + LIGHT_DOMAIN: [ + {"platform": "demo"}, + ] + }, + ) + assert await async_setup_component( + hass, + "group", + { + "group": { + "group_zero": {"entities": "light.Bowl", "icon": "mdi:work"}, + } + }, + ) + await hass.async_block_till_done() + + yaml_path = path.join( + _get_fixtures_base_path(), + "fixtures", + "group/configuration.yaml", + ) + with patch.object(hass_config, "YAML_CONFIG_FILE", yaml_path): + await hass.services.async_call( + DOMAIN, + SERVICE_RELOAD, + {}, + blocking=True, + ) + await hass.async_block_till_done() + + assert hass.states.get("light.light_group") is None + assert hass.states.get("light.master_hall_lights_g") is not None + assert hass.states.get("light.outside_patio_lights_g") is not None + + +async def test_reload_with_base_integration_platform_not_setup(hass): + """Test the ability to reload lights.""" + assert await async_setup_component( + hass, + "group", + { + "group": { + "group_zero": {"entities": "light.Bowl", "icon": "mdi:work"}, + } + }, + ) + await hass.async_block_till_done() + + yaml_path = path.join( + _get_fixtures_base_path(), + "fixtures", + "group/configuration.yaml", + ) + with patch.object(hass_config, "YAML_CONFIG_FILE", yaml_path): + await hass.services.async_call( + DOMAIN, + SERVICE_RELOAD, + {}, + blocking=True, + ) + await hass.async_block_till_done() + + assert hass.states.get("light.light_group") is None + assert hass.states.get("light.master_hall_lights_g") is not None + assert hass.states.get("light.outside_patio_lights_g") is not None + + def _get_fixtures_base_path(): return path.dirname(path.dirname(path.dirname(__file__))) diff --git a/tests/components/group/test_notify.py b/tests/components/group/test_notify.py index b120cf2cea4..62b395dcf4d 100644 --- a/tests/components/group/test_notify.py +++ b/tests/components/group/test_notify.py @@ -1,11 +1,14 @@ """The tests for the notify.group platform.""" import asyncio +from os import path import unittest +from homeassistant import config as hass_config import homeassistant.components.demo.notify as demo +from homeassistant.components.group import SERVICE_RELOAD import homeassistant.components.group.notify as group import homeassistant.components.notify as notify -from homeassistant.setup import setup_component +from homeassistant.setup import async_setup_component, setup_component from tests.async_mock import MagicMock, patch from tests.common import assert_setup_component, get_test_home_assistant @@ -90,3 +93,58 @@ class TestNotifyGroup(unittest.TestCase): "title": "Test notification", "data": {"hello": "world", "test": "message"}, } + + +async def test_reload_notify(hass): + """Verify we can reload the notify service.""" + + assert await async_setup_component( + hass, + "group", + {}, + ) + await hass.async_block_till_done() + + assert await async_setup_component( + hass, + notify.DOMAIN, + { + notify.DOMAIN: [ + {"name": "demo1", "platform": "demo"}, + {"name": "demo2", "platform": "demo"}, + { + "name": "group_notify", + "platform": "group", + "services": [{"service": "demo1"}], + }, + ] + }, + ) + await hass.async_block_till_done() + + assert hass.services.has_service(notify.DOMAIN, "demo1") + assert hass.services.has_service(notify.DOMAIN, "demo2") + assert hass.services.has_service(notify.DOMAIN, "group_notify") + + yaml_path = path.join( + _get_fixtures_base_path(), + "fixtures", + "group/configuration.yaml", + ) + with patch.object(hass_config, "YAML_CONFIG_FILE", yaml_path): + await hass.services.async_call( + "group", + SERVICE_RELOAD, + {}, + blocking=True, + ) + await hass.async_block_till_done() + + assert hass.services.has_service(notify.DOMAIN, "demo1") + assert hass.services.has_service(notify.DOMAIN, "demo2") + assert not hass.services.has_service(notify.DOMAIN, "group_notify") + assert hass.services.has_service(notify.DOMAIN, "new_group_notify") + + +def _get_fixtures_base_path(): + return path.dirname(path.dirname(path.dirname(__file__))) diff --git a/tests/fixtures/group/configuration.yaml b/tests/fixtures/group/configuration.yaml index 9047024e3de..0a5c9e18bd1 100644 --- a/tests/fixtures/group/configuration.yaml +++ b/tests/fixtures/group/configuration.yaml @@ -9,3 +9,10 @@ light: entities: - light.outside_patio_lights - light.outside_patio_lights_2 + +notify: + - platform: group + name: new_group_notify + services: + - service: demo1 + - service: demo2 diff --git a/tests/helpers/test_reload.py b/tests/helpers/test_reload.py index dafcbebdb6e..25844151533 100644 --- a/tests/helpers/test_reload.py +++ b/tests/helpers/test_reload.py @@ -13,8 +13,9 @@ from homeassistant.helpers.reload import ( async_reload_integration_platforms, async_setup_reload_service, ) +from homeassistant.loader import async_get_integration -from tests.async_mock import Mock, patch +from tests.async_mock import AsyncMock, Mock, patch from tests.common import ( MockModule, MockPlatform, @@ -109,6 +110,104 @@ async def test_setup_reload_service(hass): assert len(setup_called) == 2 +async def test_setup_reload_service_when_async_process_component_config_fails(hass): + """Test setting up a reload service with the config processing failing.""" + component_setup = Mock(return_value=True) + + setup_called = [] + + async def setup_platform(*args): + setup_called.append(args) + + mock_integration(hass, MockModule(DOMAIN, setup=component_setup)) + mock_integration(hass, MockModule(PLATFORM, dependencies=[DOMAIN])) + + mock_platform = MockPlatform(async_setup_platform=setup_platform) + mock_entity_platform(hass, f"{DOMAIN}.{PLATFORM}", mock_platform) + + component = EntityComponent(_LOGGER, DOMAIN, hass) + + await component.async_setup({DOMAIN: {"platform": PLATFORM, "sensors": None}}) + await hass.async_block_till_done() + assert component_setup.called + + assert f"{DOMAIN}.{PLATFORM}" in hass.config.components + assert len(setup_called) == 1 + + await async_setup_reload_service(hass, PLATFORM, [DOMAIN]) + + yaml_path = path.join( + _get_fixtures_base_path(), + "fixtures", + "helpers/reload_configuration.yaml", + ) + with patch.object(config, "YAML_CONFIG_FILE", yaml_path), patch.object( + config, "async_process_component_config", return_value=None + ): + await hass.services.async_call( + PLATFORM, + SERVICE_RELOAD, + {}, + blocking=True, + ) + await hass.async_block_till_done() + + assert len(setup_called) == 1 + + +async def test_setup_reload_service_with_platform_that_provides_async_reset_platform( + hass, +): + """Test setting up a reload service using a platform that has its own async_reset_platform.""" + component_setup = AsyncMock(return_value=True) + + setup_called = [] + async_reset_platform_called = [] + + async def setup_platform(*args): + setup_called.append(args) + + async def async_reset_platform(*args): + async_reset_platform_called.append(args) + + mock_integration(hass, MockModule(DOMAIN, async_setup=component_setup)) + integration = await async_get_integration(hass, DOMAIN) + integration.get_component().async_reset_platform = async_reset_platform + + mock_integration(hass, MockModule(PLATFORM, dependencies=[DOMAIN])) + + mock_platform = MockPlatform(async_setup_platform=setup_platform) + mock_entity_platform(hass, f"{DOMAIN}.{PLATFORM}", mock_platform) + + component = EntityComponent(_LOGGER, DOMAIN, hass) + + await component.async_setup({DOMAIN: {"platform": PLATFORM, "name": "xyz"}}) + await hass.async_block_till_done() + assert component_setup.called + + assert f"{DOMAIN}.{PLATFORM}" in hass.config.components + assert len(setup_called) == 1 + + await async_setup_reload_service(hass, PLATFORM, [DOMAIN]) + + yaml_path = path.join( + _get_fixtures_base_path(), + "fixtures", + "helpers/reload_configuration.yaml", + ) + with patch.object(config, "YAML_CONFIG_FILE", yaml_path): + await hass.services.async_call( + PLATFORM, + SERVICE_RELOAD, + {}, + blocking=True, + ) + await hass.async_block_till_done() + + assert len(setup_called) == 1 + assert len(async_reset_platform_called) == 1 + + async def test_async_integration_yaml_config(hass): """Test loading yaml config for an integration.""" mock_integration(hass, MockModule(DOMAIN))