From 444df4a7d2d20edfbe4ff8fe0ba10413d5b7205e Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Wed, 12 Aug 2020 09:08:33 -0500 Subject: [PATCH] Use the shared zeroconf instance when attempting to create another Zeroconf instance (#38744) --- homeassistant/components/zeroconf/__init__.py | 4 ++ homeassistant/components/zeroconf/usage.py | 50 +++++++++++++++++ homeassistant/helpers/frame.py | 35 +++++++++--- tests/components/conftest.py | 7 +++ tests/components/zeroconf/conftest.py | 11 ++++ tests/components/zeroconf/test_init.py | 8 --- tests/components/zeroconf/test_usage.py | 56 +++++++++++++++++++ tests/helpers/test_frame.py | 33 +++++++++++ 8 files changed, 187 insertions(+), 17 deletions(-) create mode 100644 homeassistant/components/zeroconf/usage.py create mode 100644 tests/components/zeroconf/conftest.py create mode 100644 tests/components/zeroconf/test_usage.py diff --git a/homeassistant/components/zeroconf/__init__.py b/homeassistant/components/zeroconf/__init__.py index 71e2f67bad7..5bbc87f3da8 100644 --- a/homeassistant/components/zeroconf/__init__.py +++ b/homeassistant/components/zeroconf/__init__.py @@ -30,6 +30,8 @@ from homeassistant.helpers.network import NoURLAvailableError, get_url from homeassistant.helpers.singleton import singleton from homeassistant.loader import async_get_homekit, async_get_zeroconf +from .usage import install_multiple_zeroconf_catcher + _LOGGER = logging.getLogger(__name__) DOMAIN = "zeroconf" @@ -135,6 +137,8 @@ def setup(hass, config): ipv6=zc_config.get(CONF_IPV6, DEFAULT_IPV6), ) + install_multiple_zeroconf_catcher(zeroconf) + # Get instance UUID uuid = asyncio.run_coroutine_threadsafe( hass.helpers.instance_id.async_get(), hass.loop diff --git a/homeassistant/components/zeroconf/usage.py b/homeassistant/components/zeroconf/usage.py new file mode 100644 index 00000000000..1303412249c --- /dev/null +++ b/homeassistant/components/zeroconf/usage.py @@ -0,0 +1,50 @@ +"""Zeroconf usage utility to warn about multiple instances.""" + +import logging + +import zeroconf + +from homeassistant.helpers.frame import ( + MissingIntegrationFrame, + get_integration_frame, + report_integration, +) + +_LOGGER = logging.getLogger(__name__) + + +def install_multiple_zeroconf_catcher(hass_zc) -> None: + """Wrap the Zeroconf class to return the shared instance if multiple instances are detected.""" + + def new_zeroconf_new(self, *k, **kw): + _report( + "attempted to create another Zeroconf instance. Please use the shared Zeroconf via await homeassistant.components.zeroconf.async_get_instance(hass)", + ) + return hass_zc + + def new_zeroconf_init(self, *k, **kw): + return + + zeroconf.Zeroconf.__new__ = new_zeroconf_new + zeroconf.Zeroconf.__init__ = new_zeroconf_init + + +def _report(what: str) -> None: + """Report incorrect usage. + + Async friendly. + """ + integration_frame = None + + try: + integration_frame = get_integration_frame(exclude_integrations={"zeroconf"}) + except MissingIntegrationFrame: + pass + + if not integration_frame: + _LOGGER.warning( + "Detected code that %s. Please report this issue.", what, stack_info=True + ) + return + + report_integration(what, integration_frame) diff --git a/homeassistant/helpers/frame.py b/homeassistant/helpers/frame.py index 63d7cba4ec5..35f7b3fab9f 100644 --- a/homeassistant/helpers/frame.py +++ b/homeassistant/helpers/frame.py @@ -3,7 +3,7 @@ import asyncio import functools import logging from traceback import FrameSummary, extract_stack -from typing import Any, Callable, Tuple, TypeVar, cast +from typing import Any, Callable, Optional, Tuple, TypeVar, cast from homeassistant.exceptions import HomeAssistantError @@ -12,15 +12,24 @@ _LOGGER = logging.getLogger(__name__) CALLABLE_T = TypeVar("CALLABLE_T", bound=Callable) # pylint: disable=invalid-name -def get_integration_frame() -> Tuple[FrameSummary, str, str]: +def get_integration_frame( + exclude_integrations: Optional[set] = None, +) -> Tuple[FrameSummary, str, str]: """Return the frame, integration and integration path of the current stack frame.""" found_frame = None + if not exclude_integrations: + exclude_integrations = set() for frame in reversed(extract_stack()): for path in ("custom_components/", "homeassistant/components/"): try: index = frame.filename.index(path) - found_frame = frame + start = index + len(path) + end = frame.filename.index("/", start) + integration = frame.filename[start:end] + if integration not in exclude_integrations: + found_frame = frame + break except ValueError: continue @@ -31,11 +40,6 @@ def get_integration_frame() -> Tuple[FrameSummary, str, str]: if found_frame is None: raise MissingIntegrationFrame - start = index + len(path) - end = found_frame.filename.index("/", start) - - integration = found_frame.filename[start:end] - return found_frame, integration, path @@ -49,11 +53,24 @@ def report(what: str) -> None: Async friendly. """ try: - found_frame, integration, path = get_integration_frame() + integration_frame = get_integration_frame() except MissingIntegrationFrame: # Did not source from an integration? Hard error. raise RuntimeError(f"Detected code that {what}. Please report this issue.") + report_integration(what, integration_frame) + + +def report_integration( + what: str, integration_frame: Tuple[FrameSummary, str, str] +) -> None: + """Report incorrect usage in an integration. + + Async friendly. + """ + + found_frame, integration, path = integration_frame + index = found_frame.filename.index(path) if path == "custom_components/": extra = " to the custom component author" diff --git a/tests/components/conftest.py b/tests/components/conftest.py index 96ab3bca543..e78a67a16f3 100644 --- a/tests/components/conftest.py +++ b/tests/components/conftest.py @@ -1,8 +1,15 @@ """Fixtures for component testing.""" import pytest +from homeassistant.components import zeroconf + from tests.async_mock import patch +zeroconf.orig_install_multiple_zeroconf_catcher = ( + zeroconf.install_multiple_zeroconf_catcher +) +zeroconf.install_multiple_zeroconf_catcher = lambda zc: None + @pytest.fixture(autouse=True) def prevent_io(): diff --git a/tests/components/zeroconf/conftest.py b/tests/components/zeroconf/conftest.py new file mode 100644 index 00000000000..e7d7b030aaf --- /dev/null +++ b/tests/components/zeroconf/conftest.py @@ -0,0 +1,11 @@ +"""conftest for zeroconf.""" +import pytest + +from tests.async_mock import patch + + +@pytest.fixture +def mock_zeroconf(): + """Mock zeroconf.""" + with patch("homeassistant.components.zeroconf.HaZeroconf") as mock_zc: + yield mock_zc.return_value diff --git a/tests/components/zeroconf/test_init.py b/tests/components/zeroconf/test_init.py index e8315b5dc75..2ced2fac8ea 100644 --- a/tests/components/zeroconf/test_init.py +++ b/tests/components/zeroconf/test_init.py @@ -1,5 +1,4 @@ """Test Zeroconf component setup process.""" -import pytest from zeroconf import InterfaceChoice, IPVersion, ServiceInfo, ServiceStateChange from homeassistant.components import zeroconf @@ -22,13 +21,6 @@ HOMEKIT_STATUS_UNPAIRED = b"1" HOMEKIT_STATUS_PAIRED = b"0" -@pytest.fixture -def mock_zeroconf(): - """Mock zeroconf.""" - with patch("homeassistant.components.zeroconf.HaZeroconf") as mock_zc: - yield mock_zc.return_value - - def service_update_mock(zeroconf, services, handlers): """Call service update handler.""" for service in services: diff --git a/tests/components/zeroconf/test_usage.py b/tests/components/zeroconf/test_usage.py new file mode 100644 index 00000000000..0a2095daa6a --- /dev/null +++ b/tests/components/zeroconf/test_usage.py @@ -0,0 +1,56 @@ +"""Test Zeroconf multiple instance protection.""" +import zeroconf + +from homeassistant.components.zeroconf import async_get_instance +from homeassistant.components.zeroconf.usage import install_multiple_zeroconf_catcher + +from tests.async_mock import Mock, patch + + +async def test_multiple_zeroconf_instances(hass, mock_zeroconf, caplog): + """Test creating multiple zeroconf throws without an integration.""" + + zeroconf_instance = await async_get_instance(hass) + + install_multiple_zeroconf_catcher(zeroconf_instance) + + new_zeroconf_instance = zeroconf.Zeroconf() + assert new_zeroconf_instance == zeroconf_instance + + assert "Zeroconf" in caplog.text + + +async def test_multiple_zeroconf_instances_gives_shared(hass, mock_zeroconf, caplog): + """Test creating multiple zeroconf gives the shared instance to an integration.""" + + zeroconf_instance = await async_get_instance(hass) + + install_multiple_zeroconf_catcher(zeroconf_instance) + + correct_frame = Mock( + filename="/config/custom_components/burncpu/light.py", + lineno="23", + line="self.light.is_on", + ) + with patch( + "homeassistant.helpers.frame.extract_stack", + return_value=[ + Mock( + filename="/home/dev/homeassistant/core.py", + lineno="23", + line="do_something()", + ), + correct_frame, + Mock( + filename="/home/dev/homeassistant/components/zeroconf/usage.py", + lineno="23", + line="self.light.is_on", + ), + Mock(filename="/home/dev/mdns/lights.py", lineno="2", line="something()",), + ], + ): + assert zeroconf.Zeroconf() == zeroconf_instance + + assert "custom_components/burncpu/light.py" in caplog.text + assert "23" in caplog.text + assert "self.light.is_on" in caplog.text diff --git a/tests/helpers/test_frame.py b/tests/helpers/test_frame.py index 2e8c83c6517..9f68ecdefb2 100644 --- a/tests/helpers/test_frame.py +++ b/tests/helpers/test_frame.py @@ -36,6 +36,39 @@ async def test_extract_frame_integration(caplog): assert found_frame == correct_frame +async def test_extract_frame_integration_with_excluded_intergration(caplog): + """Test extracting the current frame from integration context.""" + correct_frame = Mock( + filename="/home/dev/homeassistant/components/mdns/light.py", + lineno="23", + line="self.light.is_on", + ) + with patch( + "homeassistant.helpers.frame.extract_stack", + return_value=[ + Mock( + filename="/home/dev/homeassistant/core.py", + lineno="23", + line="do_something()", + ), + correct_frame, + Mock( + filename="/home/dev/homeassistant/components/zeroconf/usage.py", + lineno="23", + line="self.light.is_on", + ), + Mock(filename="/home/dev/mdns/lights.py", lineno="2", line="something()",), + ], + ): + found_frame, integration, path = frame.get_integration_frame( + exclude_integrations={"zeroconf"} + ) + + assert integration == "mdns" + assert path == "homeassistant/components/" + assert found_frame == correct_frame + + async def test_extract_frame_no_integration(caplog): """Test extracting the current frame without integration context.""" with patch(