From 265f6653c3c941750913353e3c9be047c74c7a75 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Tue, 10 Oct 2023 09:14:37 -1000 Subject: [PATCH] Refactor homekit to use a dataclass for entry data (#101738) --- homeassistant/components/homekit/__init__.py | 38 ++++++++++--------- homeassistant/components/homekit/const.py | 5 +-- .../components/homekit/diagnostics.py | 7 ++-- homeassistant/components/homekit/models.py | 15 ++++++++ homeassistant/components/homekit/util.py | 9 +++-- tests/components/homekit/test_homekit.py | 8 ++-- tests/components/homekit/test_util.py | 8 ++-- 7 files changed, 52 insertions(+), 38 deletions(-) create mode 100644 homeassistant/components/homekit/models.py diff --git a/homeassistant/components/homekit/__init__.py b/homeassistant/components/homekit/__init__.py index c3b7bf5d2e6..0920530524d 100644 --- a/homeassistant/components/homekit/__init__.py +++ b/homeassistant/components/homekit/__init__.py @@ -112,18 +112,16 @@ from .const import ( DEFAULT_HOMEKIT_MODE, DEFAULT_PORT, DOMAIN, - HOMEKIT, HOMEKIT_MODE_ACCESSORY, HOMEKIT_MODES, - HOMEKIT_PAIRING_QR, - HOMEKIT_PAIRING_QR_SECRET, MANUFACTURER, - PERSIST_LOCK, + PERSIST_LOCK_DATA, SERVICE_HOMEKIT_RESET_ACCESSORY, SERVICE_HOMEKIT_UNPAIR, SHUTDOWN_TIMEOUT, ) from .iidmanager import AccessoryIIDStorage +from .models import HomeKitEntryData from .type_triggers import DeviceTriggerAccessory from .util import ( accessory_friendly_name, @@ -205,11 +203,8 @@ UNPAIR_SERVICE_SCHEMA = vol.All( def _async_all_homekit_instances(hass: HomeAssistant) -> list[HomeKit]: """All active HomeKit instances.""" - return [ - data[HOMEKIT] - for data in hass.data[DOMAIN].values() - if isinstance(data, dict) and HOMEKIT in data - ] + domain_data: dict[str, HomeKitEntryData] = hass.data[DOMAIN] + return [data.homekit for data in domain_data.values()] def _async_get_imported_entries_indices( @@ -231,7 +226,8 @@ def _async_get_imported_entries_indices( async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: """Set up the HomeKit from yaml.""" - hass.data.setdefault(DOMAIN, {})[PERSIST_LOCK] = asyncio.Lock() + hass.data[DOMAIN] = {} + hass.data[PERSIST_LOCK_DATA] = asyncio.Lock() # Initialize the loader before loading entries to ensure # there is no race where multiple entries try to load it @@ -352,7 +348,10 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: hass.bus.async_listen_once(EVENT_HOMEASSISTANT_STOP, homekit.async_stop) ) - hass.data[DOMAIN][entry.entry_id] = {HOMEKIT: homekit} + entry_data = HomeKitEntryData( + homekit=homekit, pairing_qr=None, pairing_qr_secret=None + ) + hass.data[DOMAIN][entry.entry_id] = entry_data if hass.state == CoreState.running: await homekit.async_start() @@ -372,7 +371,8 @@ async def _async_update_listener(hass: HomeAssistant, entry: ConfigEntry) -> Non async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: """Unload a config entry.""" async_dismiss_setup_message(hass, entry.entry_id) - homekit = hass.data[DOMAIN][entry.entry_id][HOMEKIT] + entry_data: HomeKitEntryData = hass.data[DOMAIN][entry.entry_id] + homekit = entry_data.homekit if homekit.status == STATUS_RUNNING: await homekit.async_stop() @@ -849,7 +849,7 @@ class HomeKit: self._async_register_bridge() _LOGGER.debug("Driver start for %s", self._name) await self.driver.async_start() - async with self.hass.data[DOMAIN][PERSIST_LOCK]: + async with self.hass.data[PERSIST_LOCK_DATA]: await self.hass.async_add_executor_job(self.driver.persist) self.status = STATUS_RUNNING @@ -1162,14 +1162,16 @@ class HomeKitPairingQRView(HomeAssistantView): if not request.query_string: raise Unauthorized() entry_id, secret = request.query_string.split("-") - + hass: HomeAssistant = request.app["hass"] + domain_data: dict[str, HomeKitEntryData] = hass.data[DOMAIN] if ( - entry_id not in request.app["hass"].data[DOMAIN] - or secret - != request.app["hass"].data[DOMAIN][entry_id][HOMEKIT_PAIRING_QR_SECRET] + not (entry_data := domain_data.get(entry_id)) + or not secret + or not entry_data.pairing_qr_secret + or secret != entry_data.pairing_qr_secret ): raise Unauthorized() return web.Response( - body=request.app["hass"].data[DOMAIN][entry_id][HOMEKIT_PAIRING_QR], + body=entry_data.pairing_qr, content_type="image/svg+xml", ) diff --git a/homeassistant/components/homekit/const.py b/homeassistant/components/homekit/const.py index bb5ae1ffd1c..5a7ee1d9576 100644 --- a/homeassistant/components/homekit/const.py +++ b/homeassistant/components/homekit/const.py @@ -6,13 +6,10 @@ from homeassistant.const import CONF_DEVICES DEBOUNCE_TIMEOUT = 0.5 DEVICE_PRECISION_LEEWAY = 6 DOMAIN = "homekit" +PERSIST_LOCK_DATA = f"{DOMAIN}_persist_lock" HOMEKIT_FILE = ".homekit.state" -HOMEKIT_PAIRING_QR = "homekit-pairing-qr" -HOMEKIT_PAIRING_QR_SECRET = "homekit-pairing-qr-secret" -HOMEKIT = "homekit" SHUTDOWN_TIMEOUT = 30 CONF_ENTRY_INDEX = "index" -PERSIST_LOCK = "persist_lock" # ### Codecs #### VIDEO_CODEC_COPY = "copy" diff --git a/homeassistant/components/homekit/diagnostics.py b/homeassistant/components/homekit/diagnostics.py index f27171e6eae..347a3df0dd4 100644 --- a/homeassistant/components/homekit/diagnostics.py +++ b/homeassistant/components/homekit/diagnostics.py @@ -10,9 +10,9 @@ from homeassistant.components.diagnostics import async_redact_data from homeassistant.config_entries import ConfigEntry from homeassistant.core import HomeAssistant -from . import HomeKit from .accessories import HomeAccessory, HomeBridge -from .const import DOMAIN, HOMEKIT +from .const import DOMAIN +from .models import HomeKitEntryData TO_REDACT = {"access_token", "entity_picture"} @@ -21,7 +21,8 @@ async def async_get_config_entry_diagnostics( hass: HomeAssistant, entry: ConfigEntry ) -> dict[str, Any]: """Return diagnostics for a config entry.""" - homekit: HomeKit = hass.data[DOMAIN][entry.entry_id][HOMEKIT] + entry_data: HomeKitEntryData = hass.data[DOMAIN][entry.entry_id] + homekit = entry_data.homekit data: dict[str, Any] = { "status": homekit.status, "config-entry": { diff --git a/homeassistant/components/homekit/models.py b/homeassistant/components/homekit/models.py new file mode 100644 index 00000000000..e96af00fead --- /dev/null +++ b/homeassistant/components/homekit/models.py @@ -0,0 +1,15 @@ +"""Models for the HomeKit component.""" +from dataclasses import dataclass +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from . import HomeKit + + +@dataclass +class HomeKitEntryData: + """Class to hold HomeKit data.""" + + homekit: "HomeKit" + pairing_qr: bytes | None = None + pairing_qr_secret: str | None = None diff --git a/homeassistant/components/homekit/util.py b/homeassistant/components/homekit/util.py index 151b97f2cda..8a51f35564e 100644 --- a/homeassistant/components/homekit/util.py +++ b/homeassistant/components/homekit/util.py @@ -86,8 +86,6 @@ from .const import ( FEATURE_PLAY_PAUSE, FEATURE_PLAY_STOP, FEATURE_TOGGLE_MUTE, - HOMEKIT_PAIRING_QR, - HOMEKIT_PAIRING_QR_SECRET, MAX_NAME_LENGTH, TYPE_FAUCET, TYPE_OUTLET, @@ -100,6 +98,7 @@ from .const import ( VIDEO_CODEC_H264_V4L2M2M, VIDEO_CODEC_LIBX264, ) +from .models import HomeKitEntryData _LOGGER = logging.getLogger(__name__) @@ -352,8 +351,10 @@ def async_show_setup_message( url.svg(buffer, scale=5, module_color="#000", background="#FFF") pairing_secret = secrets.token_hex(32) - hass.data[DOMAIN][entry_id][HOMEKIT_PAIRING_QR] = buffer.getvalue() - hass.data[DOMAIN][entry_id][HOMEKIT_PAIRING_QR_SECRET] = pairing_secret + entry_data: HomeKitEntryData = hass.data[DOMAIN][entry_id] + + entry_data.pairing_qr = buffer.getvalue() + entry_data.pairing_qr_secret = pairing_secret message = ( f"To set up {bridge_name} in the Home App, " diff --git a/tests/components/homekit/test_homekit.py b/tests/components/homekit/test_homekit.py index ebb710561d9..5c517ac9cb9 100644 --- a/tests/components/homekit/test_homekit.py +++ b/tests/components/homekit/test_homekit.py @@ -28,12 +28,12 @@ from homeassistant.components.homekit.const import ( CONF_ADVERTISE_IP, DEFAULT_PORT, DOMAIN, - HOMEKIT, HOMEKIT_MODE_ACCESSORY, HOMEKIT_MODE_BRIDGE, SERVICE_HOMEKIT_RESET_ACCESSORY, SERVICE_HOMEKIT_UNPAIR, ) +from homeassistant.components.homekit.models import HomeKitEntryData from homeassistant.components.homekit.type_triggers import DeviceTriggerAccessory from homeassistant.components.homekit.util import get_persist_fullpath_for_entry_id from homeassistant.components.light import ( @@ -1799,10 +1799,8 @@ async def test_homekit_uses_system_zeroconf( entry.add_to_hass(hass) assert await hass.config_entries.async_setup(entry.entry_id) await hass.async_block_till_done() - assert ( - hass.data[DOMAIN][entry.entry_id][HOMEKIT].driver.advertiser - == system_async_zc - ) + entry_data: HomeKitEntryData = hass.data[DOMAIN][entry.entry_id] + assert entry_data.homekit.driver.advertiser == system_async_zc assert await hass.config_entries.async_unload(entry.entry_id) await hass.async_block_till_done() diff --git a/tests/components/homekit/test_util.py b/tests/components/homekit/test_util.py index 0046f90b284..60ee2a4d8e8 100644 --- a/tests/components/homekit/test_util.py +++ b/tests/components/homekit/test_util.py @@ -14,8 +14,6 @@ from homeassistant.components.homekit.const import ( DOMAIN, FEATURE_ON_OFF, FEATURE_PLAY_PAUSE, - HOMEKIT_PAIRING_QR, - HOMEKIT_PAIRING_QR_SECRET, TYPE_FAUCET, TYPE_OUTLET, TYPE_SHOWER, @@ -23,6 +21,7 @@ from homeassistant.components.homekit.const import ( TYPE_SWITCH, TYPE_VALVE, ) +from homeassistant.components.homekit.models import HomeKitEntryData from homeassistant.components.homekit.util import ( accessory_friendly_name, async_dismiss_setup_message, @@ -251,8 +250,9 @@ async def test_async_show_setup_msg( hass, entry.entry_id, "bridge_name", pincode, "X-HM://0" ) await hass.async_block_till_done() - assert hass.data[DOMAIN][entry.entry_id][HOMEKIT_PAIRING_QR_SECRET] - assert hass.data[DOMAIN][entry.entry_id][HOMEKIT_PAIRING_QR] + entry_data: HomeKitEntryData = hass.data[DOMAIN][entry.entry_id] + assert entry_data.pairing_qr_secret + assert entry_data.pairing_qr assert len(mock_create.mock_calls) == 1 assert mock_create.mock_calls[0][1][3] == entry.entry_id