Use the shared zeroconf instance when attempting to create another Zeroconf instance (#38744)

This commit is contained in:
J. Nick Koston 2020-08-12 09:08:33 -05:00 committed by GitHub
parent 34cb12d3c9
commit 444df4a7d2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 187 additions and 17 deletions

View File

@ -30,6 +30,8 @@ from homeassistant.helpers.network import NoURLAvailableError, get_url
from homeassistant.helpers.singleton import singleton from homeassistant.helpers.singleton import singleton
from homeassistant.loader import async_get_homekit, async_get_zeroconf from homeassistant.loader import async_get_homekit, async_get_zeroconf
from .usage import install_multiple_zeroconf_catcher
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
DOMAIN = "zeroconf" DOMAIN = "zeroconf"
@ -135,6 +137,8 @@ def setup(hass, config):
ipv6=zc_config.get(CONF_IPV6, DEFAULT_IPV6), ipv6=zc_config.get(CONF_IPV6, DEFAULT_IPV6),
) )
install_multiple_zeroconf_catcher(zeroconf)
# Get instance UUID # Get instance UUID
uuid = asyncio.run_coroutine_threadsafe( uuid = asyncio.run_coroutine_threadsafe(
hass.helpers.instance_id.async_get(), hass.loop hass.helpers.instance_id.async_get(), hass.loop

View File

@ -0,0 +1,50 @@
"""Zeroconf usage utility to warn about multiple instances."""
import logging
import zeroconf
from homeassistant.helpers.frame import (
MissingIntegrationFrame,
get_integration_frame,
report_integration,
)
_LOGGER = logging.getLogger(__name__)
def install_multiple_zeroconf_catcher(hass_zc) -> None:
"""Wrap the Zeroconf class to return the shared instance if multiple instances are detected."""
def new_zeroconf_new(self, *k, **kw):
_report(
"attempted to create another Zeroconf instance. Please use the shared Zeroconf via await homeassistant.components.zeroconf.async_get_instance(hass)",
)
return hass_zc
def new_zeroconf_init(self, *k, **kw):
return
zeroconf.Zeroconf.__new__ = new_zeroconf_new
zeroconf.Zeroconf.__init__ = new_zeroconf_init
def _report(what: str) -> None:
"""Report incorrect usage.
Async friendly.
"""
integration_frame = None
try:
integration_frame = get_integration_frame(exclude_integrations={"zeroconf"})
except MissingIntegrationFrame:
pass
if not integration_frame:
_LOGGER.warning(
"Detected code that %s. Please report this issue.", what, stack_info=True
)
return
report_integration(what, integration_frame)

View File

@ -3,7 +3,7 @@ import asyncio
import functools import functools
import logging import logging
from traceback import FrameSummary, extract_stack from traceback import FrameSummary, extract_stack
from typing import Any, Callable, Tuple, TypeVar, cast from typing import Any, Callable, Optional, Tuple, TypeVar, cast
from homeassistant.exceptions import HomeAssistantError from homeassistant.exceptions import HomeAssistantError
@ -12,15 +12,24 @@ _LOGGER = logging.getLogger(__name__)
CALLABLE_T = TypeVar("CALLABLE_T", bound=Callable) # pylint: disable=invalid-name CALLABLE_T = TypeVar("CALLABLE_T", bound=Callable) # pylint: disable=invalid-name
def get_integration_frame() -> Tuple[FrameSummary, str, str]: def get_integration_frame(
exclude_integrations: Optional[set] = None,
) -> Tuple[FrameSummary, str, str]:
"""Return the frame, integration and integration path of the current stack frame.""" """Return the frame, integration and integration path of the current stack frame."""
found_frame = None found_frame = None
if not exclude_integrations:
exclude_integrations = set()
for frame in reversed(extract_stack()): for frame in reversed(extract_stack()):
for path in ("custom_components/", "homeassistant/components/"): for path in ("custom_components/", "homeassistant/components/"):
try: try:
index = frame.filename.index(path) index = frame.filename.index(path)
start = index + len(path)
end = frame.filename.index("/", start)
integration = frame.filename[start:end]
if integration not in exclude_integrations:
found_frame = frame found_frame = frame
break break
except ValueError: except ValueError:
continue continue
@ -31,11 +40,6 @@ def get_integration_frame() -> Tuple[FrameSummary, str, str]:
if found_frame is None: if found_frame is None:
raise MissingIntegrationFrame raise MissingIntegrationFrame
start = index + len(path)
end = found_frame.filename.index("/", start)
integration = found_frame.filename[start:end]
return found_frame, integration, path return found_frame, integration, path
@ -49,11 +53,24 @@ def report(what: str) -> None:
Async friendly. Async friendly.
""" """
try: try:
found_frame, integration, path = get_integration_frame() integration_frame = get_integration_frame()
except MissingIntegrationFrame: except MissingIntegrationFrame:
# Did not source from an integration? Hard error. # Did not source from an integration? Hard error.
raise RuntimeError(f"Detected code that {what}. Please report this issue.") raise RuntimeError(f"Detected code that {what}. Please report this issue.")
report_integration(what, integration_frame)
def report_integration(
what: str, integration_frame: Tuple[FrameSummary, str, str]
) -> None:
"""Report incorrect usage in an integration.
Async friendly.
"""
found_frame, integration, path = integration_frame
index = found_frame.filename.index(path) index = found_frame.filename.index(path)
if path == "custom_components/": if path == "custom_components/":
extra = " to the custom component author" extra = " to the custom component author"

View File

@ -1,8 +1,15 @@
"""Fixtures for component testing.""" """Fixtures for component testing."""
import pytest import pytest
from homeassistant.components import zeroconf
from tests.async_mock import patch from tests.async_mock import patch
zeroconf.orig_install_multiple_zeroconf_catcher = (
zeroconf.install_multiple_zeroconf_catcher
)
zeroconf.install_multiple_zeroconf_catcher = lambda zc: None
@pytest.fixture(autouse=True) @pytest.fixture(autouse=True)
def prevent_io(): def prevent_io():

View File

@ -0,0 +1,11 @@
"""conftest for zeroconf."""
import pytest
from tests.async_mock import patch
@pytest.fixture
def mock_zeroconf():
"""Mock zeroconf."""
with patch("homeassistant.components.zeroconf.HaZeroconf") as mock_zc:
yield mock_zc.return_value

View File

@ -1,5 +1,4 @@
"""Test Zeroconf component setup process.""" """Test Zeroconf component setup process."""
import pytest
from zeroconf import InterfaceChoice, IPVersion, ServiceInfo, ServiceStateChange from zeroconf import InterfaceChoice, IPVersion, ServiceInfo, ServiceStateChange
from homeassistant.components import zeroconf from homeassistant.components import zeroconf
@ -22,13 +21,6 @@ HOMEKIT_STATUS_UNPAIRED = b"1"
HOMEKIT_STATUS_PAIRED = b"0" HOMEKIT_STATUS_PAIRED = b"0"
@pytest.fixture
def mock_zeroconf():
"""Mock zeroconf."""
with patch("homeassistant.components.zeroconf.HaZeroconf") as mock_zc:
yield mock_zc.return_value
def service_update_mock(zeroconf, services, handlers): def service_update_mock(zeroconf, services, handlers):
"""Call service update handler.""" """Call service update handler."""
for service in services: for service in services:

View File

@ -0,0 +1,56 @@
"""Test Zeroconf multiple instance protection."""
import zeroconf
from homeassistant.components.zeroconf import async_get_instance
from homeassistant.components.zeroconf.usage import install_multiple_zeroconf_catcher
from tests.async_mock import Mock, patch
async def test_multiple_zeroconf_instances(hass, mock_zeroconf, caplog):
"""Test creating multiple zeroconf throws without an integration."""
zeroconf_instance = await async_get_instance(hass)
install_multiple_zeroconf_catcher(zeroconf_instance)
new_zeroconf_instance = zeroconf.Zeroconf()
assert new_zeroconf_instance == zeroconf_instance
assert "Zeroconf" in caplog.text
async def test_multiple_zeroconf_instances_gives_shared(hass, mock_zeroconf, caplog):
"""Test creating multiple zeroconf gives the shared instance to an integration."""
zeroconf_instance = await async_get_instance(hass)
install_multiple_zeroconf_catcher(zeroconf_instance)
correct_frame = Mock(
filename="/config/custom_components/burncpu/light.py",
lineno="23",
line="self.light.is_on",
)
with patch(
"homeassistant.helpers.frame.extract_stack",
return_value=[
Mock(
filename="/home/dev/homeassistant/core.py",
lineno="23",
line="do_something()",
),
correct_frame,
Mock(
filename="/home/dev/homeassistant/components/zeroconf/usage.py",
lineno="23",
line="self.light.is_on",
),
Mock(filename="/home/dev/mdns/lights.py", lineno="2", line="something()",),
],
):
assert zeroconf.Zeroconf() == zeroconf_instance
assert "custom_components/burncpu/light.py" in caplog.text
assert "23" in caplog.text
assert "self.light.is_on" in caplog.text

View File

@ -36,6 +36,39 @@ async def test_extract_frame_integration(caplog):
assert found_frame == correct_frame assert found_frame == correct_frame
async def test_extract_frame_integration_with_excluded_intergration(caplog):
"""Test extracting the current frame from integration context."""
correct_frame = Mock(
filename="/home/dev/homeassistant/components/mdns/light.py",
lineno="23",
line="self.light.is_on",
)
with patch(
"homeassistant.helpers.frame.extract_stack",
return_value=[
Mock(
filename="/home/dev/homeassistant/core.py",
lineno="23",
line="do_something()",
),
correct_frame,
Mock(
filename="/home/dev/homeassistant/components/zeroconf/usage.py",
lineno="23",
line="self.light.is_on",
),
Mock(filename="/home/dev/mdns/lights.py", lineno="2", line="something()",),
],
):
found_frame, integration, path = frame.get_integration_frame(
exclude_integrations={"zeroconf"}
)
assert integration == "mdns"
assert path == "homeassistant/components/"
assert found_frame == correct_frame
async def test_extract_frame_no_integration(caplog): async def test_extract_frame_no_integration(caplog):
"""Test extracting the current frame without integration context.""" """Test extracting the current frame without integration context."""
with patch( with patch(