Refactor homekit to use a dataclass for entry data (#101738)

This commit is contained in:
J. Nick Koston 2023-10-10 09:14:37 -10:00 committed by GitHub
parent 6c65db2036
commit 265f6653c3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 52 additions and 38 deletions

View File

@ -112,18 +112,16 @@ from .const import (
DEFAULT_HOMEKIT_MODE, DEFAULT_HOMEKIT_MODE,
DEFAULT_PORT, DEFAULT_PORT,
DOMAIN, DOMAIN,
HOMEKIT,
HOMEKIT_MODE_ACCESSORY, HOMEKIT_MODE_ACCESSORY,
HOMEKIT_MODES, HOMEKIT_MODES,
HOMEKIT_PAIRING_QR,
HOMEKIT_PAIRING_QR_SECRET,
MANUFACTURER, MANUFACTURER,
PERSIST_LOCK, PERSIST_LOCK_DATA,
SERVICE_HOMEKIT_RESET_ACCESSORY, SERVICE_HOMEKIT_RESET_ACCESSORY,
SERVICE_HOMEKIT_UNPAIR, SERVICE_HOMEKIT_UNPAIR,
SHUTDOWN_TIMEOUT, SHUTDOWN_TIMEOUT,
) )
from .iidmanager import AccessoryIIDStorage from .iidmanager import AccessoryIIDStorage
from .models import HomeKitEntryData
from .type_triggers import DeviceTriggerAccessory from .type_triggers import DeviceTriggerAccessory
from .util import ( from .util import (
accessory_friendly_name, accessory_friendly_name,
@ -205,11 +203,8 @@ UNPAIR_SERVICE_SCHEMA = vol.All(
def _async_all_homekit_instances(hass: HomeAssistant) -> list[HomeKit]: def _async_all_homekit_instances(hass: HomeAssistant) -> list[HomeKit]:
"""All active HomeKit instances.""" """All active HomeKit instances."""
return [ domain_data: dict[str, HomeKitEntryData] = hass.data[DOMAIN]
data[HOMEKIT] return [data.homekit for data in domain_data.values()]
for data in hass.data[DOMAIN].values()
if isinstance(data, dict) and HOMEKIT in data
]
def _async_get_imported_entries_indices( def _async_get_imported_entries_indices(
@ -231,7 +226,8 @@ def _async_get_imported_entries_indices(
async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
"""Set up the HomeKit from yaml.""" """Set up the HomeKit from yaml."""
hass.data.setdefault(DOMAIN, {})[PERSIST_LOCK] = asyncio.Lock() hass.data[DOMAIN] = {}
hass.data[PERSIST_LOCK_DATA] = asyncio.Lock()
# Initialize the loader before loading entries to ensure # Initialize the loader before loading entries to ensure
# there is no race where multiple entries try to load it # there is no race where multiple entries try to load it
@ -352,7 +348,10 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
hass.bus.async_listen_once(EVENT_HOMEASSISTANT_STOP, homekit.async_stop) hass.bus.async_listen_once(EVENT_HOMEASSISTANT_STOP, homekit.async_stop)
) )
hass.data[DOMAIN][entry.entry_id] = {HOMEKIT: homekit} entry_data = HomeKitEntryData(
homekit=homekit, pairing_qr=None, pairing_qr_secret=None
)
hass.data[DOMAIN][entry.entry_id] = entry_data
if hass.state == CoreState.running: if hass.state == CoreState.running:
await homekit.async_start() await homekit.async_start()
@ -372,7 +371,8 @@ async def _async_update_listener(hass: HomeAssistant, entry: ConfigEntry) -> Non
async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
"""Unload a config entry.""" """Unload a config entry."""
async_dismiss_setup_message(hass, entry.entry_id) async_dismiss_setup_message(hass, entry.entry_id)
homekit = hass.data[DOMAIN][entry.entry_id][HOMEKIT] entry_data: HomeKitEntryData = hass.data[DOMAIN][entry.entry_id]
homekit = entry_data.homekit
if homekit.status == STATUS_RUNNING: if homekit.status == STATUS_RUNNING:
await homekit.async_stop() await homekit.async_stop()
@ -849,7 +849,7 @@ class HomeKit:
self._async_register_bridge() self._async_register_bridge()
_LOGGER.debug("Driver start for %s", self._name) _LOGGER.debug("Driver start for %s", self._name)
await self.driver.async_start() await self.driver.async_start()
async with self.hass.data[DOMAIN][PERSIST_LOCK]: async with self.hass.data[PERSIST_LOCK_DATA]:
await self.hass.async_add_executor_job(self.driver.persist) await self.hass.async_add_executor_job(self.driver.persist)
self.status = STATUS_RUNNING self.status = STATUS_RUNNING
@ -1162,14 +1162,16 @@ class HomeKitPairingQRView(HomeAssistantView):
if not request.query_string: if not request.query_string:
raise Unauthorized() raise Unauthorized()
entry_id, secret = request.query_string.split("-") entry_id, secret = request.query_string.split("-")
hass: HomeAssistant = request.app["hass"]
domain_data: dict[str, HomeKitEntryData] = hass.data[DOMAIN]
if ( if (
entry_id not in request.app["hass"].data[DOMAIN] not (entry_data := domain_data.get(entry_id))
or secret or not secret
!= request.app["hass"].data[DOMAIN][entry_id][HOMEKIT_PAIRING_QR_SECRET] or not entry_data.pairing_qr_secret
or secret != entry_data.pairing_qr_secret
): ):
raise Unauthorized() raise Unauthorized()
return web.Response( return web.Response(
body=request.app["hass"].data[DOMAIN][entry_id][HOMEKIT_PAIRING_QR], body=entry_data.pairing_qr,
content_type="image/svg+xml", content_type="image/svg+xml",
) )

View File

@ -6,13 +6,10 @@ from homeassistant.const import CONF_DEVICES
DEBOUNCE_TIMEOUT = 0.5 DEBOUNCE_TIMEOUT = 0.5
DEVICE_PRECISION_LEEWAY = 6 DEVICE_PRECISION_LEEWAY = 6
DOMAIN = "homekit" DOMAIN = "homekit"
PERSIST_LOCK_DATA = f"{DOMAIN}_persist_lock"
HOMEKIT_FILE = ".homekit.state" HOMEKIT_FILE = ".homekit.state"
HOMEKIT_PAIRING_QR = "homekit-pairing-qr"
HOMEKIT_PAIRING_QR_SECRET = "homekit-pairing-qr-secret"
HOMEKIT = "homekit"
SHUTDOWN_TIMEOUT = 30 SHUTDOWN_TIMEOUT = 30
CONF_ENTRY_INDEX = "index" CONF_ENTRY_INDEX = "index"
PERSIST_LOCK = "persist_lock"
# ### Codecs #### # ### Codecs ####
VIDEO_CODEC_COPY = "copy" VIDEO_CODEC_COPY = "copy"

View File

@ -10,9 +10,9 @@ from homeassistant.components.diagnostics import async_redact_data
from homeassistant.config_entries import ConfigEntry from homeassistant.config_entries import ConfigEntry
from homeassistant.core import HomeAssistant from homeassistant.core import HomeAssistant
from . import HomeKit
from .accessories import HomeAccessory, HomeBridge from .accessories import HomeAccessory, HomeBridge
from .const import DOMAIN, HOMEKIT from .const import DOMAIN
from .models import HomeKitEntryData
TO_REDACT = {"access_token", "entity_picture"} TO_REDACT = {"access_token", "entity_picture"}
@ -21,7 +21,8 @@ async def async_get_config_entry_diagnostics(
hass: HomeAssistant, entry: ConfigEntry hass: HomeAssistant, entry: ConfigEntry
) -> dict[str, Any]: ) -> dict[str, Any]:
"""Return diagnostics for a config entry.""" """Return diagnostics for a config entry."""
homekit: HomeKit = hass.data[DOMAIN][entry.entry_id][HOMEKIT] entry_data: HomeKitEntryData = hass.data[DOMAIN][entry.entry_id]
homekit = entry_data.homekit
data: dict[str, Any] = { data: dict[str, Any] = {
"status": homekit.status, "status": homekit.status,
"config-entry": { "config-entry": {

View File

@ -0,0 +1,15 @@
"""Models for the HomeKit component."""
from dataclasses import dataclass
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from . import HomeKit
@dataclass
class HomeKitEntryData:
"""Class to hold HomeKit data."""
homekit: "HomeKit"
pairing_qr: bytes | None = None
pairing_qr_secret: str | None = None

View File

@ -86,8 +86,6 @@ from .const import (
FEATURE_PLAY_PAUSE, FEATURE_PLAY_PAUSE,
FEATURE_PLAY_STOP, FEATURE_PLAY_STOP,
FEATURE_TOGGLE_MUTE, FEATURE_TOGGLE_MUTE,
HOMEKIT_PAIRING_QR,
HOMEKIT_PAIRING_QR_SECRET,
MAX_NAME_LENGTH, MAX_NAME_LENGTH,
TYPE_FAUCET, TYPE_FAUCET,
TYPE_OUTLET, TYPE_OUTLET,
@ -100,6 +98,7 @@ from .const import (
VIDEO_CODEC_H264_V4L2M2M, VIDEO_CODEC_H264_V4L2M2M,
VIDEO_CODEC_LIBX264, VIDEO_CODEC_LIBX264,
) )
from .models import HomeKitEntryData
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
@ -352,8 +351,10 @@ def async_show_setup_message(
url.svg(buffer, scale=5, module_color="#000", background="#FFF") url.svg(buffer, scale=5, module_color="#000", background="#FFF")
pairing_secret = secrets.token_hex(32) pairing_secret = secrets.token_hex(32)
hass.data[DOMAIN][entry_id][HOMEKIT_PAIRING_QR] = buffer.getvalue() entry_data: HomeKitEntryData = hass.data[DOMAIN][entry_id]
hass.data[DOMAIN][entry_id][HOMEKIT_PAIRING_QR_SECRET] = pairing_secret
entry_data.pairing_qr = buffer.getvalue()
entry_data.pairing_qr_secret = pairing_secret
message = ( message = (
f"To set up {bridge_name} in the Home App, " f"To set up {bridge_name} in the Home App, "

View File

@ -28,12 +28,12 @@ from homeassistant.components.homekit.const import (
CONF_ADVERTISE_IP, CONF_ADVERTISE_IP,
DEFAULT_PORT, DEFAULT_PORT,
DOMAIN, DOMAIN,
HOMEKIT,
HOMEKIT_MODE_ACCESSORY, HOMEKIT_MODE_ACCESSORY,
HOMEKIT_MODE_BRIDGE, HOMEKIT_MODE_BRIDGE,
SERVICE_HOMEKIT_RESET_ACCESSORY, SERVICE_HOMEKIT_RESET_ACCESSORY,
SERVICE_HOMEKIT_UNPAIR, SERVICE_HOMEKIT_UNPAIR,
) )
from homeassistant.components.homekit.models import HomeKitEntryData
from homeassistant.components.homekit.type_triggers import DeviceTriggerAccessory from homeassistant.components.homekit.type_triggers import DeviceTriggerAccessory
from homeassistant.components.homekit.util import get_persist_fullpath_for_entry_id from homeassistant.components.homekit.util import get_persist_fullpath_for_entry_id
from homeassistant.components.light import ( from homeassistant.components.light import (
@ -1799,10 +1799,8 @@ async def test_homekit_uses_system_zeroconf(
entry.add_to_hass(hass) entry.add_to_hass(hass)
assert await hass.config_entries.async_setup(entry.entry_id) assert await hass.config_entries.async_setup(entry.entry_id)
await hass.async_block_till_done() await hass.async_block_till_done()
assert ( entry_data: HomeKitEntryData = hass.data[DOMAIN][entry.entry_id]
hass.data[DOMAIN][entry.entry_id][HOMEKIT].driver.advertiser assert entry_data.homekit.driver.advertiser == system_async_zc
== system_async_zc
)
assert await hass.config_entries.async_unload(entry.entry_id) assert await hass.config_entries.async_unload(entry.entry_id)
await hass.async_block_till_done() await hass.async_block_till_done()

View File

@ -14,8 +14,6 @@ from homeassistant.components.homekit.const import (
DOMAIN, DOMAIN,
FEATURE_ON_OFF, FEATURE_ON_OFF,
FEATURE_PLAY_PAUSE, FEATURE_PLAY_PAUSE,
HOMEKIT_PAIRING_QR,
HOMEKIT_PAIRING_QR_SECRET,
TYPE_FAUCET, TYPE_FAUCET,
TYPE_OUTLET, TYPE_OUTLET,
TYPE_SHOWER, TYPE_SHOWER,
@ -23,6 +21,7 @@ from homeassistant.components.homekit.const import (
TYPE_SWITCH, TYPE_SWITCH,
TYPE_VALVE, TYPE_VALVE,
) )
from homeassistant.components.homekit.models import HomeKitEntryData
from homeassistant.components.homekit.util import ( from homeassistant.components.homekit.util import (
accessory_friendly_name, accessory_friendly_name,
async_dismiss_setup_message, async_dismiss_setup_message,
@ -251,8 +250,9 @@ async def test_async_show_setup_msg(
hass, entry.entry_id, "bridge_name", pincode, "X-HM://0" hass, entry.entry_id, "bridge_name", pincode, "X-HM://0"
) )
await hass.async_block_till_done() await hass.async_block_till_done()
assert hass.data[DOMAIN][entry.entry_id][HOMEKIT_PAIRING_QR_SECRET] entry_data: HomeKitEntryData = hass.data[DOMAIN][entry.entry_id]
assert hass.data[DOMAIN][entry.entry_id][HOMEKIT_PAIRING_QR] assert entry_data.pairing_qr_secret
assert entry_data.pairing_qr
assert len(mock_create.mock_calls) == 1 assert len(mock_create.mock_calls) == 1
assert mock_create.mock_calls[0][1][3] == entry.entry_id assert mock_create.mock_calls[0][1][3] == entry.entry_id