diff --git a/homeassistant/components/zeroconf/__init__.py b/homeassistant/components/zeroconf/__init__.py index c804fb6b642..3ab7b03b5d4 100644 --- a/homeassistant/components/zeroconf/__init__.py +++ b/homeassistant/components/zeroconf/__init__.py @@ -23,6 +23,7 @@ from homeassistant.const import ( from homeassistant.generated.zeroconf import HOMEKIT, ZEROCONF import homeassistant.helpers.config_validation as cv from homeassistant.helpers.network import NoURLAvailableError, get_url +from homeassistant.helpers.singleton import singleton _LOGGER = logging.getLogger(__name__) @@ -54,12 +55,40 @@ CONFIG_SCHEMA = vol.Schema( ) +@singleton(DOMAIN) +async def async_get_instance(hass): + """Zeroconf instance to be shared with other integrations that use it.""" + return await hass.async_add_executor_job(_get_instance, hass) + + +def _get_instance(hass, default_interface=False): + """Create an instance.""" + args = [InterfaceChoice.Default] if default_interface else [] + zeroconf = HaZeroconf(*args) + + def stop_zeroconf(_): + """Stop Zeroconf.""" + zeroconf.ha_close() + + hass.bus.listen_once(EVENT_HOMEASSISTANT_STOP, stop_zeroconf) + + return zeroconf + + +class HaZeroconf(Zeroconf): + """Zeroconf that cannot be closed.""" + + def close(self): + """Fake method to avoid integrations closing it.""" + + ha_close = Zeroconf.close + + def setup(hass, config): """Set up Zeroconf and make Home Assistant discoverable.""" - if DOMAIN in config and config[DOMAIN].get(CONF_DEFAULT_INTERFACE): - zeroconf = Zeroconf(interfaces=InterfaceChoice.Default) - else: - zeroconf = Zeroconf() + zeroconf = hass.data[DOMAIN] = _get_instance( + hass, config.get(DOMAIN, {}).get(CONF_DEFAULT_INTERFACE) + ) zeroconf_name = f"{hass.config.location_name}.{ZEROCONF_TYPE}" params = { @@ -142,12 +171,6 @@ def setup(hass, config): if HOMEKIT_TYPE not in ZEROCONF: ServiceBrowser(zeroconf, HOMEKIT_TYPE, handlers=[service_update]) - def stop_zeroconf(_): - """Stop Zeroconf.""" - zeroconf.close() - - hass.bus.listen_once(EVENT_HOMEASSISTANT_STOP, stop_zeroconf) - return True diff --git a/tests/components/zeroconf/test_init.py b/tests/components/zeroconf/test_init.py index bd74b3c2b8d..1110f9d145f 100644 --- a/tests/components/zeroconf/test_init.py +++ b/tests/components/zeroconf/test_init.py @@ -4,6 +4,7 @@ from zeroconf import InterfaceChoice, ServiceInfo, ServiceStateChange from homeassistant.components import zeroconf from homeassistant.components.zeroconf import CONF_DEFAULT_INTERFACE +from homeassistant.const import EVENT_HOMEASSISTANT_STOP from homeassistant.generated import zeroconf as zc_gen from homeassistant.setup import async_setup_component @@ -21,7 +22,7 @@ PROPERTIES = { @pytest.fixture def mock_zeroconf(): """Mock zeroconf.""" - with patch("homeassistant.components.zeroconf.Zeroconf") as mock_zc: + with patch("homeassistant.components.zeroconf.HaZeroconf") as mock_zc: yield mock_zc.return_value @@ -78,6 +79,10 @@ async def test_setup(hass, mock_zeroconf): expected_flow_calls += len(matching_components) assert len(mock_config_flow.mock_calls) == expected_flow_calls + # Test instance is set. + assert "zeroconf" in hass.data + assert await hass.components.zeroconf.async_get_instance() is mock_zeroconf + async def test_setup_with_default_interface(hass, mock_zeroconf): """Test default interface config.""" @@ -171,3 +176,11 @@ async def test_info_from_service_non_utf8(hass): assert len(info["properties"]) <= len(raw_info) assert "non-utf8-value" not in info["properties"] assert raw_info["non-utf8-value"] is NON_UTF8_VALUE + + +async def test_get_instance(hass, mock_zeroconf): + """Test we get an instance.""" + assert await hass.components.zeroconf.async_get_instance() is mock_zeroconf + hass.bus.async_fire(EVENT_HOMEASSISTANT_STOP) + await hass.async_block_till_done() + assert len(mock_zeroconf.ha_close.mock_calls) == 1