diff --git a/homeassistant/components/notify/legacy.py b/homeassistant/components/notify/legacy.py index af29a9fba99..50b02324827 100644 --- a/homeassistant/components/notify/legacy.py +++ b/homeassistant/components/notify/legacy.py @@ -29,11 +29,13 @@ from .const import ( CONF_FIELDS = "fields" NOTIFY_SERVICES = "notify_services" +NOTIFY_DISCOVERY_DISPATCHER = "notify_discovery_dispatcher" async def async_setup_legacy(hass: HomeAssistant, config: ConfigType) -> None: """Set up legacy notify services.""" hass.data.setdefault(NOTIFY_SERVICES, {}) + hass.data.setdefault(NOTIFY_DISCOVERY_DISPATCHER, None) async def async_setup_platform( integration_name: str, @@ -114,7 +116,9 @@ async def async_setup_legacy(hass: HomeAssistant, config: ConfigType) -> None: """Handle for discovered platform.""" await async_setup_platform(platform, discovery_info=info) - discovery.async_listen_platform(hass, DOMAIN, async_platform_discovered) + hass.data[NOTIFY_DISCOVERY_DISPATCHER] = discovery.async_listen_platform( + hass, DOMAIN, async_platform_discovered + ) @callback @@ -147,6 +151,9 @@ async def async_reload(hass: HomeAssistant, integration_name: str) -> None: @bind_hass async def async_reset_platform(hass: HomeAssistant, integration_name: str) -> None: """Unregister notify services for an integration.""" + if NOTIFY_DISCOVERY_DISPATCHER in hass.data: + hass.data[NOTIFY_DISCOVERY_DISPATCHER]() + hass.data[NOTIFY_DISCOVERY_DISPATCHER] = None if not _async_integration_has_notify_services(hass, integration_name): return diff --git a/homeassistant/helpers/discovery.py b/homeassistant/helpers/discovery.py index 20819ac7504..93efd7e69c1 100644 --- a/homeassistant/helpers/discovery.py +++ b/homeassistant/helpers/discovery.py @@ -95,7 +95,7 @@ def async_listen_platform( hass: core.HomeAssistant, component: str, callback: Callable[[str, dict[str, Any] | None], Any], -) -> None: +) -> Callable[[], None]: """Register a platform loader listener. This method must be run in the event loop. @@ -112,7 +112,7 @@ def async_listen_platform( if task: await task - async_dispatcher_connect( + return async_dispatcher_connect( hass, SIGNAL_PLATFORM_DISCOVERED.format(service), discovery_platform_listener ) diff --git a/homeassistant/helpers/reload.py b/homeassistant/helpers/reload.py index 42fcf1032bb..83698557eb6 100644 --- a/homeassistant/helpers/reload.py +++ b/homeassistant/helpers/reload.py @@ -21,6 +21,8 @@ from .typing import ConfigType _LOGGER = logging.getLogger(__name__) +PLATFORM_RESET_LOCK = "lock_async_reset_platform_{}" + async def async_reload_integration_platforms( hass: HomeAssistant, integration_name: str, integration_platforms: Iterable[str] @@ -79,8 +81,11 @@ async def _resetup_platform( if hasattr(component, "async_reset_platform"): # If the integration has its own way to reset # use this method. - await component.async_reset_platform(hass, integration_name) - await component.async_setup(hass, root_config) + async with hass.data.setdefault( + PLATFORM_RESET_LOCK.format(integration_platform), asyncio.Lock() + ): + await component.async_reset_platform(hass, integration_name) + await component.async_setup(hass, root_config) return # If it's an entity platform, we use the entity_platform diff --git a/tests/components/notify/test_init.py b/tests/components/notify/test_init.py index 9dbbcf9b9b9..ae32884add7 100644 --- a/tests/components/notify/test_init.py +++ b/tests/components/notify/test_init.py @@ -1,8 +1,41 @@ """The tests for notify services that change targets.""" + +from unittest.mock import Mock, patch + +import yaml + +from homeassistant import config as hass_config from homeassistant.components import notify +from homeassistant.const import SERVICE_RELOAD from homeassistant.core import HomeAssistant +from homeassistant.helpers.discovery import async_load_platform +from homeassistant.helpers.reload import async_setup_reload_service from homeassistant.setup import async_setup_component +from tests.common import MockPlatform, mock_platform + + +class MockNotifyPlatform(MockPlatform): + """Help to set up test notify service.""" + + def __init__(self, async_get_service=None, get_service=None): + """Return the notify service.""" + super().__init__() + if get_service: + self.get_service = get_service + if async_get_service: + self.async_get_service = async_get_service + + +def mock_notify_platform( + hass, tmp_path, integration="notify", async_get_service=None, get_service=None +): + """Specialize the mock platform for notify.""" + loaded_platform = MockNotifyPlatform(async_get_service, get_service) + mock_platform(hass, f"{integration}.notify", loaded_platform) + + return loaded_platform + async def test_same_targets(hass: HomeAssistant): """Test not changing the targets in a notify service.""" @@ -73,10 +106,16 @@ async def test_remove_targets(hass: HomeAssistant): class NotificationService(notify.BaseNotificationService): """A test class for notification services.""" - def __init__(self, hass): + def __init__(self, hass, target_list={"a": 1, "b": 2}, name="notify"): """Initialize the service.""" + + async def _async_make_reloadable(hass): + """Initialize the reload service.""" + await async_setup_reload_service(hass, name, [notify.DOMAIN]) + self.hass = hass - self.target_list = {"a": 1, "b": 2} + self.target_list = target_list + hass.async_create_task(_async_make_reloadable(hass)) @property def targets(self): @@ -97,3 +136,197 @@ async def test_warn_template(hass, caplog): # We should only log it once assert caplog.text.count("Passing templates to notify service is deprecated") == 1 assert hass.states.get("persistent_notification.notification") is not None + + +async def test_invalid_platform(hass, caplog, tmp_path): + """Test service setup with an invalid platform.""" + mock_notify_platform(hass, tmp_path, "testnotify1") + # Setup the platform + await async_setup_component( + hass, "notify", {"notify": [{"platform": "testnotify1"}]} + ) + await hass.async_block_till_done() + assert "Invalid notify platform" in caplog.text + caplog.clear() + # Setup the second testnotify2 platform dynamically + mock_notify_platform(hass, tmp_path, "testnotify2") + await async_load_platform( + hass, + "notify", + "testnotify2", + {}, + hass_config={"notify": [{"platform": "testnotify2"}]}, + ) + await hass.async_block_till_done() + assert "Invalid notify platform" in caplog.text + + +async def test_invalid_service(hass, caplog, tmp_path): + """Test service setup with an invalid service object or platform.""" + + def get_service(hass, config, discovery_info=None): + """Return None for an invalid notify service.""" + return None + + mock_notify_platform(hass, tmp_path, "testnotify", get_service=get_service) + # Setup the second testnotify2 platform dynamically + await async_load_platform( + hass, + "notify", + "testnotify", + {}, + hass_config={"notify": [{"platform": "testnotify"}]}, + ) + await hass.async_block_till_done() + assert "Failed to initialize notification service testnotify" in caplog.text + caplog.clear() + + await async_load_platform( + hass, + "notify", + "testnotifyinvalid", + {"notify": [{"platform": "testnotifyinvalid"}]}, + hass_config={"notify": [{"platform": "testnotifyinvalid"}]}, + ) + await hass.async_block_till_done() + assert "Unknown notification service specified" in caplog.text + + +async def test_platform_setup_with_error(hass, caplog, tmp_path): + """Test service setup with an invalid setup.""" + + async def async_get_service(hass, config, discovery_info=None): + """Return None for an invalid notify service.""" + raise Exception("Setup error") + + mock_notify_platform( + hass, tmp_path, "testnotify", async_get_service=async_get_service + ) + # Setup the second testnotify2 platform dynamically + await async_load_platform( + hass, + "notify", + "testnotify", + {}, + hass_config={"notify": [{"platform": "testnotify"}]}, + ) + await hass.async_block_till_done() + assert "Error setting up platform testnotify" in caplog.text + + +async def test_reload_with_notify_builtin_platform_reload(hass, caplog, tmp_path): + """Test reload using the notify platform reload method.""" + + async def async_get_service(hass, config, discovery_info=None): + """Get notify service for mocked platform.""" + targetlist = {"a": 1, "b": 2} + return NotificationService(hass, targetlist, "testnotify") + + # platform with service + mock_notify_platform( + hass, tmp_path, "testnotify", async_get_service=async_get_service + ) + + # Perform a reload using the notify module for testnotify (without services) + await notify.async_reload(hass, "testnotify") + + # Setup the platform + await async_setup_component( + hass, "notify", {"notify": [{"platform": "testnotify"}]} + ) + await hass.async_block_till_done() + assert hass.services.has_service(notify.DOMAIN, "testnotify_a") + assert hass.services.has_service(notify.DOMAIN, "testnotify_b") + + # Perform a reload using the notify module for testnotify (with services) + await notify.async_reload(hass, "testnotify") + assert hass.services.has_service(notify.DOMAIN, "testnotify_a") + assert hass.services.has_service(notify.DOMAIN, "testnotify_b") + + +async def test_setup_platform_and_reload(hass, caplog, tmp_path): + """Test service setup and reload.""" + get_service_called = Mock() + + async def async_get_service(hass, config, discovery_info=None): + """Get notify service for mocked platform.""" + get_service_called(config, discovery_info) + targetlist = {"a": 1, "b": 2} + return NotificationService(hass, targetlist, "testnotify") + + async def async_get_service2(hass, config, discovery_info=None): + """Get notify service for mocked platform.""" + get_service_called(config, discovery_info) + targetlist = {"c": 3, "d": 4} + return NotificationService(hass, targetlist, "testnotify2") + + # Mock first platform + mock_notify_platform( + hass, tmp_path, "testnotify", async_get_service=async_get_service + ) + + # Initialize a second platform testnotify2 + mock_notify_platform( + hass, tmp_path, "testnotify2", async_get_service=async_get_service2 + ) + + # Setup the testnotify platform + await async_setup_component( + hass, "notify", {"notify": [{"platform": "testnotify"}]} + ) + await hass.async_block_till_done() + assert hass.services.has_service("testnotify", SERVICE_RELOAD) + assert hass.services.has_service(notify.DOMAIN, "testnotify_a") + assert hass.services.has_service(notify.DOMAIN, "testnotify_b") + assert get_service_called.call_count == 1 + assert get_service_called.call_args[0][0] == {"platform": "testnotify"} + assert get_service_called.call_args[0][1] is None + get_service_called.reset_mock() + + # Setup the second testnotify2 platform dynamically + await async_load_platform( + hass, + "notify", + "testnotify2", + {}, + hass_config={"notify": [{"platform": "testnotify"}]}, + ) + await hass.async_block_till_done() + assert hass.services.has_service("testnotify2", SERVICE_RELOAD) + assert hass.services.has_service(notify.DOMAIN, "testnotify2_c") + assert hass.services.has_service(notify.DOMAIN, "testnotify2_d") + assert get_service_called.call_count == 1 + assert get_service_called.call_args[0][0] == {} + assert get_service_called.call_args[0][1] == {} + get_service_called.reset_mock() + + # Perform a reload + new_yaml_config_file = tmp_path / "configuration.yaml" + new_yaml_config = yaml.dump({"notify": [{"platform": "testnotify"}]}) + new_yaml_config_file.write_text(new_yaml_config) + + with patch.object(hass_config, "YAML_CONFIG_FILE", new_yaml_config_file): + await hass.services.async_call( + "testnotify", + SERVICE_RELOAD, + {}, + blocking=True, + ) + await hass.services.async_call( + "testnotify2", + SERVICE_RELOAD, + {}, + blocking=True, + ) + await hass.async_block_till_done() + + # Check if the notify services from setup still exist + assert hass.services.has_service(notify.DOMAIN, "testnotify_a") + assert hass.services.has_service(notify.DOMAIN, "testnotify_b") + assert get_service_called.call_count == 1 + assert get_service_called.call_args[0][0] == {"platform": "testnotify"} + assert get_service_called.call_args[0][1] is None + + # Check if the dynamically notify services from setup were removed + assert not hass.services.has_service(notify.DOMAIN, "testnotify2_c") + assert not hass.services.has_service(notify.DOMAIN, "testnotify2_d")