Refactor homekit to use a dataclass for entry data (#101738)

This commit is contained in:
J. Nick Koston 2023-10-10 09:14:37 -10:00 committed by GitHub
parent 6c65db2036
commit 265f6653c3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 52 additions and 38 deletions

View File

@ -112,18 +112,16 @@ from .const import (
DEFAULT_HOMEKIT_MODE,
DEFAULT_PORT,
DOMAIN,
HOMEKIT,
HOMEKIT_MODE_ACCESSORY,
HOMEKIT_MODES,
HOMEKIT_PAIRING_QR,
HOMEKIT_PAIRING_QR_SECRET,
MANUFACTURER,
PERSIST_LOCK,
PERSIST_LOCK_DATA,
SERVICE_HOMEKIT_RESET_ACCESSORY,
SERVICE_HOMEKIT_UNPAIR,
SHUTDOWN_TIMEOUT,
)
from .iidmanager import AccessoryIIDStorage
from .models import HomeKitEntryData
from .type_triggers import DeviceTriggerAccessory
from .util import (
accessory_friendly_name,
@ -205,11 +203,8 @@ UNPAIR_SERVICE_SCHEMA = vol.All(
def _async_all_homekit_instances(hass: HomeAssistant) -> list[HomeKit]:
"""All active HomeKit instances."""
return [
data[HOMEKIT]
for data in hass.data[DOMAIN].values()
if isinstance(data, dict) and HOMEKIT in data
]
domain_data: dict[str, HomeKitEntryData] = hass.data[DOMAIN]
return [data.homekit for data in domain_data.values()]
def _async_get_imported_entries_indices(
@ -231,7 +226,8 @@ def _async_get_imported_entries_indices(
async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
"""Set up the HomeKit from yaml."""
hass.data.setdefault(DOMAIN, {})[PERSIST_LOCK] = asyncio.Lock()
hass.data[DOMAIN] = {}
hass.data[PERSIST_LOCK_DATA] = asyncio.Lock()
# Initialize the loader before loading entries to ensure
# there is no race where multiple entries try to load it
@ -352,7 +348,10 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
hass.bus.async_listen_once(EVENT_HOMEASSISTANT_STOP, homekit.async_stop)
)
hass.data[DOMAIN][entry.entry_id] = {HOMEKIT: homekit}
entry_data = HomeKitEntryData(
homekit=homekit, pairing_qr=None, pairing_qr_secret=None
)
hass.data[DOMAIN][entry.entry_id] = entry_data
if hass.state == CoreState.running:
await homekit.async_start()
@ -372,7 +371,8 @@ async def _async_update_listener(hass: HomeAssistant, entry: ConfigEntry) -> Non
async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
"""Unload a config entry."""
async_dismiss_setup_message(hass, entry.entry_id)
homekit = hass.data[DOMAIN][entry.entry_id][HOMEKIT]
entry_data: HomeKitEntryData = hass.data[DOMAIN][entry.entry_id]
homekit = entry_data.homekit
if homekit.status == STATUS_RUNNING:
await homekit.async_stop()
@ -849,7 +849,7 @@ class HomeKit:
self._async_register_bridge()
_LOGGER.debug("Driver start for %s", self._name)
await self.driver.async_start()
async with self.hass.data[DOMAIN][PERSIST_LOCK]:
async with self.hass.data[PERSIST_LOCK_DATA]:
await self.hass.async_add_executor_job(self.driver.persist)
self.status = STATUS_RUNNING
@ -1162,14 +1162,16 @@ class HomeKitPairingQRView(HomeAssistantView):
if not request.query_string:
raise Unauthorized()
entry_id, secret = request.query_string.split("-")
hass: HomeAssistant = request.app["hass"]
domain_data: dict[str, HomeKitEntryData] = hass.data[DOMAIN]
if (
entry_id not in request.app["hass"].data[DOMAIN]
or secret
!= request.app["hass"].data[DOMAIN][entry_id][HOMEKIT_PAIRING_QR_SECRET]
not (entry_data := domain_data.get(entry_id))
or not secret
or not entry_data.pairing_qr_secret
or secret != entry_data.pairing_qr_secret
):
raise Unauthorized()
return web.Response(
body=request.app["hass"].data[DOMAIN][entry_id][HOMEKIT_PAIRING_QR],
body=entry_data.pairing_qr,
content_type="image/svg+xml",
)

View File

@ -6,13 +6,10 @@ from homeassistant.const import CONF_DEVICES
DEBOUNCE_TIMEOUT = 0.5
DEVICE_PRECISION_LEEWAY = 6
DOMAIN = "homekit"
PERSIST_LOCK_DATA = f"{DOMAIN}_persist_lock"
HOMEKIT_FILE = ".homekit.state"
HOMEKIT_PAIRING_QR = "homekit-pairing-qr"
HOMEKIT_PAIRING_QR_SECRET = "homekit-pairing-qr-secret"
HOMEKIT = "homekit"
SHUTDOWN_TIMEOUT = 30
CONF_ENTRY_INDEX = "index"
PERSIST_LOCK = "persist_lock"
# ### Codecs ####
VIDEO_CODEC_COPY = "copy"

View File

@ -10,9 +10,9 @@ from homeassistant.components.diagnostics import async_redact_data
from homeassistant.config_entries import ConfigEntry
from homeassistant.core import HomeAssistant
from . import HomeKit
from .accessories import HomeAccessory, HomeBridge
from .const import DOMAIN, HOMEKIT
from .const import DOMAIN
from .models import HomeKitEntryData
TO_REDACT = {"access_token", "entity_picture"}
@ -21,7 +21,8 @@ async def async_get_config_entry_diagnostics(
hass: HomeAssistant, entry: ConfigEntry
) -> dict[str, Any]:
"""Return diagnostics for a config entry."""
homekit: HomeKit = hass.data[DOMAIN][entry.entry_id][HOMEKIT]
entry_data: HomeKitEntryData = hass.data[DOMAIN][entry.entry_id]
homekit = entry_data.homekit
data: dict[str, Any] = {
"status": homekit.status,
"config-entry": {

View File

@ -0,0 +1,15 @@
"""Models for the HomeKit component."""
from dataclasses import dataclass
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from . import HomeKit
@dataclass
class HomeKitEntryData:
"""Class to hold HomeKit data."""
homekit: "HomeKit"
pairing_qr: bytes | None = None
pairing_qr_secret: str | None = None

View File

@ -86,8 +86,6 @@ from .const import (
FEATURE_PLAY_PAUSE,
FEATURE_PLAY_STOP,
FEATURE_TOGGLE_MUTE,
HOMEKIT_PAIRING_QR,
HOMEKIT_PAIRING_QR_SECRET,
MAX_NAME_LENGTH,
TYPE_FAUCET,
TYPE_OUTLET,
@ -100,6 +98,7 @@ from .const import (
VIDEO_CODEC_H264_V4L2M2M,
VIDEO_CODEC_LIBX264,
)
from .models import HomeKitEntryData
_LOGGER = logging.getLogger(__name__)
@ -352,8 +351,10 @@ def async_show_setup_message(
url.svg(buffer, scale=5, module_color="#000", background="#FFF")
pairing_secret = secrets.token_hex(32)
hass.data[DOMAIN][entry_id][HOMEKIT_PAIRING_QR] = buffer.getvalue()
hass.data[DOMAIN][entry_id][HOMEKIT_PAIRING_QR_SECRET] = pairing_secret
entry_data: HomeKitEntryData = hass.data[DOMAIN][entry_id]
entry_data.pairing_qr = buffer.getvalue()
entry_data.pairing_qr_secret = pairing_secret
message = (
f"To set up {bridge_name} in the Home App, "

View File

@ -28,12 +28,12 @@ from homeassistant.components.homekit.const import (
CONF_ADVERTISE_IP,
DEFAULT_PORT,
DOMAIN,
HOMEKIT,
HOMEKIT_MODE_ACCESSORY,
HOMEKIT_MODE_BRIDGE,
SERVICE_HOMEKIT_RESET_ACCESSORY,
SERVICE_HOMEKIT_UNPAIR,
)
from homeassistant.components.homekit.models import HomeKitEntryData
from homeassistant.components.homekit.type_triggers import DeviceTriggerAccessory
from homeassistant.components.homekit.util import get_persist_fullpath_for_entry_id
from homeassistant.components.light import (
@ -1799,10 +1799,8 @@ async def test_homekit_uses_system_zeroconf(
entry.add_to_hass(hass)
assert await hass.config_entries.async_setup(entry.entry_id)
await hass.async_block_till_done()
assert (
hass.data[DOMAIN][entry.entry_id][HOMEKIT].driver.advertiser
== system_async_zc
)
entry_data: HomeKitEntryData = hass.data[DOMAIN][entry.entry_id]
assert entry_data.homekit.driver.advertiser == system_async_zc
assert await hass.config_entries.async_unload(entry.entry_id)
await hass.async_block_till_done()

View File

@ -14,8 +14,6 @@ from homeassistant.components.homekit.const import (
DOMAIN,
FEATURE_ON_OFF,
FEATURE_PLAY_PAUSE,
HOMEKIT_PAIRING_QR,
HOMEKIT_PAIRING_QR_SECRET,
TYPE_FAUCET,
TYPE_OUTLET,
TYPE_SHOWER,
@ -23,6 +21,7 @@ from homeassistant.components.homekit.const import (
TYPE_SWITCH,
TYPE_VALVE,
)
from homeassistant.components.homekit.models import HomeKitEntryData
from homeassistant.components.homekit.util import (
accessory_friendly_name,
async_dismiss_setup_message,
@ -251,8 +250,9 @@ async def test_async_show_setup_msg(
hass, entry.entry_id, "bridge_name", pincode, "X-HM://0"
)
await hass.async_block_till_done()
assert hass.data[DOMAIN][entry.entry_id][HOMEKIT_PAIRING_QR_SECRET]
assert hass.data[DOMAIN][entry.entry_id][HOMEKIT_PAIRING_QR]
entry_data: HomeKitEntryData = hass.data[DOMAIN][entry.entry_id]
assert entry_data.pairing_qr_secret
assert entry_data.pairing_qr
assert len(mock_create.mock_calls) == 1
assert mock_create.mock_calls[0][1][3] == entry.entry_id