Use dataclass for tradfri hass.data

This commit is contained in:
Martin Hjelmare 2022-06-21 22:25:29 +02:00
parent e04bb5932d
commit b66bdd444e
9 changed files with 58 additions and 60 deletions

View File

@ -3,7 +3,6 @@
from __future__ import annotations
from datetime import datetime, timedelta
from typing import Any
from pytradfri import Gateway, RequestError
from pytradfri.api.aiocoap_api import APIFactory
@ -20,18 +19,9 @@ from homeassistant.helpers.dispatcher import (
)
from homeassistant.helpers.event import async_track_time_interval
from .const import (
CONF_GATEWAY_ID,
CONF_IDENTITY,
CONF_KEY,
COORDINATOR,
COORDINATOR_LIST,
DOMAIN,
FACTORY,
KEY_API,
LOGGER,
)
from .const import CONF_GATEWAY_ID, CONF_IDENTITY, CONF_KEY, DOMAIN, LOGGER
from .coordinator import TradfriDeviceDataUpdateCoordinator
from .models import TradfriData
PLATFORMS = [
Platform.COVER,
@ -49,15 +39,11 @@ async def async_setup_entry(
entry: ConfigEntry,
) -> bool:
"""Create a gateway."""
tradfri_data: dict[str, Any] = {}
hass.data.setdefault(DOMAIN, {})[entry.entry_id] = tradfri_data
factory = await APIFactory.init(
entry.data[CONF_HOST],
psk_id=entry.data[CONF_IDENTITY],
psk=entry.data[CONF_KEY],
)
tradfri_data[FACTORY] = factory # Used for async_unload_entry
async def on_hass_stop(event: Event) -> None:
"""Close connection when hass stops."""
@ -95,11 +81,7 @@ async def async_setup_entry(
remove_stale_devices(hass, entry, devices)
# Setup the device coordinators
coordinator_data = {
CONF_GATEWAY_ID: gateway,
KEY_API: api,
COORDINATOR_LIST: [],
}
coordinators: list[TradfriDeviceDataUpdateCoordinator] = []
for device in devices:
coordinator = TradfriDeviceDataUpdateCoordinator(
@ -110,9 +92,12 @@ async def async_setup_entry(
entry.async_on_unload(
async_dispatcher_connect(hass, SIGNAL_GW, coordinator.set_hub_available)
)
coordinator_data[COORDINATOR_LIST].append(coordinator)
coordinators.append(coordinator)
tradfri_data[COORDINATOR] = coordinator_data
tradfri_data = TradfriData(
api=api, coordinators=coordinators, factory=factory, gateway=gateway
)
hass.data.setdefault(DOMAIN, {})[entry.entry_id] = tradfri_data
async def async_keep_alive(now: datetime) -> None:
if hass.is_stopping:
@ -140,8 +125,8 @@ async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
"""Unload a config entry."""
unload_ok = await hass.config_entries.async_unload_platforms(entry, PLATFORMS)
if unload_ok:
tradfri_data = hass.data[DOMAIN].pop(entry.entry_id)
factory = tradfri_data[FACTORY]
tradfri_data: TradfriData = hass.data[DOMAIN].pop(entry.entry_id)
factory = tradfri_data.factory
await factory.shutdown()
return unload_ok

View File

@ -7,8 +7,4 @@ LOGGER = logging.getLogger(__package__)
CONF_GATEWAY_ID = "gateway_id"
CONF_IDENTITY = "identity"
CONF_KEY = "key"
COORDINATOR = "coordinator"
COORDINATOR_LIST = "coordinator_list"
DOMAIN = "tradfri"
FACTORY = "tradfri_factory"
KEY_API = "tradfri_api"

View File

@ -11,9 +11,10 @@ from homeassistant.config_entries import ConfigEntry
from homeassistant.core import HomeAssistant
from homeassistant.helpers.entity_platform import AddEntitiesCallback
from .const import CONF_GATEWAY_ID, COORDINATOR, COORDINATOR_LIST, DOMAIN, KEY_API
from .const import CONF_GATEWAY_ID, DOMAIN
from .coordinator import TradfriDeviceDataUpdateCoordinator
from .entity import TradfriBaseEntity
from .models import TradfriData
async def async_setup_entry(
@ -23,8 +24,8 @@ async def async_setup_entry(
) -> None:
"""Load Tradfri covers based on a config entry."""
gateway_id = config_entry.data[CONF_GATEWAY_ID]
coordinator_data = hass.data[DOMAIN][config_entry.entry_id][COORDINATOR]
api: APIRequestProtocol = coordinator_data[KEY_API]
tradfri_data: TradfriData = hass.data[DOMAIN][config_entry.entry_id]
api = tradfri_data.api
async_add_entities(
TradfriCover(
@ -32,7 +33,7 @@ async def async_setup_entry(
api,
gateway_id,
)
for device_coordinator in coordinator_data[COORDINATOR_LIST]
for device_coordinator in tradfri_data.coordinators
if device_coordinator.device.has_blind_control
)

View File

@ -8,15 +8,15 @@ from homeassistant.config_entries import ConfigEntry
from homeassistant.core import HomeAssistant
from homeassistant.helpers import device_registry as dr
from .const import CONF_GATEWAY_ID, COORDINATOR, COORDINATOR_LIST, DOMAIN
from .const import CONF_GATEWAY_ID, DOMAIN
from .models import TradfriData
async def async_get_config_entry_diagnostics(
hass: HomeAssistant, entry: ConfigEntry
) -> dict[str, Any]:
"""Return diagnostics the Tradfri platform."""
entry_data = hass.data[DOMAIN][entry.entry_id]
coordinator_data = entry_data[COORDINATOR]
tradfri_data: TradfriData = hass.data[DOMAIN][entry.entry_id]
device_registry = dr.async_get(hass)
device = cast(
@ -28,7 +28,7 @@ async def async_get_config_entry_diagnostics(
device_data: list = [
coordinator.device.device_info.model_number
for coordinator in coordinator_data[COORDINATOR_LIST]
for coordinator in tradfri_data.coordinators
]
return {

View File

@ -11,9 +11,10 @@ from homeassistant.config_entries import ConfigEntry
from homeassistant.core import HomeAssistant
from homeassistant.helpers.entity_platform import AddEntitiesCallback
from .const import CONF_GATEWAY_ID, COORDINATOR, COORDINATOR_LIST, DOMAIN, KEY_API
from .const import CONF_GATEWAY_ID, DOMAIN
from .coordinator import TradfriDeviceDataUpdateCoordinator
from .entity import TradfriBaseEntity
from .models import TradfriData
ATTR_AUTO = "Auto"
ATTR_MAX_FAN_STEPS = 49
@ -36,8 +37,8 @@ async def async_setup_entry(
) -> None:
"""Load Tradfri switches based on a config entry."""
gateway_id = config_entry.data[CONF_GATEWAY_ID]
coordinator_data = hass.data[DOMAIN][config_entry.entry_id][COORDINATOR]
api: APIRequestProtocol = coordinator_data[KEY_API]
tradfri_data: TradfriData = hass.data[DOMAIN][config_entry.entry_id]
api = tradfri_data.api
async_add_entities(
TradfriAirPurifierFan(
@ -45,7 +46,7 @@ async def async_setup_entry(
api,
gateway_id,
)
for device_coordinator in coordinator_data[COORDINATOR_LIST]
for device_coordinator in tradfri_data.coordinators
if device_coordinator.device.has_air_purifier_control
)

View File

@ -21,9 +21,10 @@ from homeassistant.core import HomeAssistant
from homeassistant.helpers.entity_platform import AddEntitiesCallback
from homeassistant.util import color as color_util
from .const import CONF_GATEWAY_ID, COORDINATOR, COORDINATOR_LIST, DOMAIN, KEY_API
from .const import CONF_GATEWAY_ID, DOMAIN
from .coordinator import TradfriDeviceDataUpdateCoordinator
from .entity import TradfriBaseEntity
from .models import TradfriData
async def async_setup_entry(
@ -33,8 +34,8 @@ async def async_setup_entry(
) -> None:
"""Load Tradfri lights based on a config entry."""
gateway_id = config_entry.data[CONF_GATEWAY_ID]
coordinator_data = hass.data[DOMAIN][config_entry.entry_id][COORDINATOR]
api: APIRequestProtocol = coordinator_data[KEY_API]
tradfri_data: TradfriData = hass.data[DOMAIN][config_entry.entry_id]
api = tradfri_data.api
async_add_entities(
TradfriLight(
@ -42,7 +43,7 @@ async def async_setup_entry(
api,
gateway_id,
)
for device_coordinator in coordinator_data[COORDINATOR_LIST]
for device_coordinator in tradfri_data.coordinators
if device_coordinator.device.has_light_control
)

View File

@ -0,0 +1,19 @@
"""Provide a model for the Tradfri integration."""
from __future__ import annotations
from dataclasses import dataclass
from pytradfri import Gateway
from pytradfri.api.aiocoap_api import APIFactory, APIRequestProtocol
from .coordinator import TradfriDeviceDataUpdateCoordinator
@dataclass
class TradfriData:
"""Data for the Tradfri integration."""
api: APIRequestProtocol
coordinators: list[TradfriDeviceDataUpdateCoordinator]
factory: APIFactory
gateway: Gateway

View File

@ -26,16 +26,10 @@ from homeassistant.core import HomeAssistant, callback
from homeassistant.helpers import entity_registry as er
from homeassistant.helpers.entity_platform import AddEntitiesCallback
from .const import (
CONF_GATEWAY_ID,
COORDINATOR,
COORDINATOR_LIST,
DOMAIN,
KEY_API,
LOGGER,
)
from .const import CONF_GATEWAY_ID, DOMAIN, LOGGER
from .coordinator import TradfriDeviceDataUpdateCoordinator
from .entity import TradfriBaseEntity
from .models import TradfriData
@dataclass(frozen=True, kw_only=True)
@ -132,12 +126,12 @@ async def async_setup_entry(
) -> None:
"""Set up a Tradfri config entry."""
gateway_id = config_entry.data[CONF_GATEWAY_ID]
coordinator_data = hass.data[DOMAIN][config_entry.entry_id][COORDINATOR]
api: APIRequestProtocol = coordinator_data[KEY_API]
tradfri_data: TradfriData = hass.data[DOMAIN][config_entry.entry_id]
api = tradfri_data.api
entities: list[TradfriSensor] = []
for device_coordinator in coordinator_data[COORDINATOR_LIST]:
for device_coordinator in tradfri_data.coordinators:
if (
not device_coordinator.device.has_light_control
and not device_coordinator.device.has_socket_control

View File

@ -11,9 +11,10 @@ from homeassistant.config_entries import ConfigEntry
from homeassistant.core import HomeAssistant
from homeassistant.helpers.entity_platform import AddEntitiesCallback
from .const import CONF_GATEWAY_ID, COORDINATOR, COORDINATOR_LIST, DOMAIN, KEY_API
from .const import CONF_GATEWAY_ID, DOMAIN
from .coordinator import TradfriDeviceDataUpdateCoordinator
from .entity import TradfriBaseEntity
from .models import TradfriData
async def async_setup_entry(
@ -23,8 +24,8 @@ async def async_setup_entry(
) -> None:
"""Load Tradfri switches based on a config entry."""
gateway_id = config_entry.data[CONF_GATEWAY_ID]
coordinator_data = hass.data[DOMAIN][config_entry.entry_id][COORDINATOR]
api: APIRequestProtocol = coordinator_data[KEY_API]
tradfri_data: TradfriData = hass.data[DOMAIN][config_entry.entry_id]
api = tradfri_data.api
async_add_entities(
TradfriSwitch(
@ -32,7 +33,7 @@ async def async_setup_entry(
api,
gateway_id,
)
for device_coordinator in coordinator_data[COORDINATOR_LIST]
for device_coordinator in tradfri_data.coordinators
if device_coordinator.device.has_socket_control
)