Use dataclass for tradfri hass.data

This commit is contained in:
Martin Hjelmare 2022-06-21 22:25:29 +02:00
parent e04bb5932d
commit b66bdd444e
9 changed files with 58 additions and 60 deletions

View File

@ -3,7 +3,6 @@
from __future__ import annotations from __future__ import annotations
from datetime import datetime, timedelta from datetime import datetime, timedelta
from typing import Any
from pytradfri import Gateway, RequestError from pytradfri import Gateway, RequestError
from pytradfri.api.aiocoap_api import APIFactory from pytradfri.api.aiocoap_api import APIFactory
@ -20,18 +19,9 @@ from homeassistant.helpers.dispatcher import (
) )
from homeassistant.helpers.event import async_track_time_interval from homeassistant.helpers.event import async_track_time_interval
from .const import ( from .const import CONF_GATEWAY_ID, CONF_IDENTITY, CONF_KEY, DOMAIN, LOGGER
CONF_GATEWAY_ID,
CONF_IDENTITY,
CONF_KEY,
COORDINATOR,
COORDINATOR_LIST,
DOMAIN,
FACTORY,
KEY_API,
LOGGER,
)
from .coordinator import TradfriDeviceDataUpdateCoordinator from .coordinator import TradfriDeviceDataUpdateCoordinator
from .models import TradfriData
PLATFORMS = [ PLATFORMS = [
Platform.COVER, Platform.COVER,
@ -49,15 +39,11 @@ async def async_setup_entry(
entry: ConfigEntry, entry: ConfigEntry,
) -> bool: ) -> bool:
"""Create a gateway.""" """Create a gateway."""
tradfri_data: dict[str, Any] = {}
hass.data.setdefault(DOMAIN, {})[entry.entry_id] = tradfri_data
factory = await APIFactory.init( factory = await APIFactory.init(
entry.data[CONF_HOST], entry.data[CONF_HOST],
psk_id=entry.data[CONF_IDENTITY], psk_id=entry.data[CONF_IDENTITY],
psk=entry.data[CONF_KEY], psk=entry.data[CONF_KEY],
) )
tradfri_data[FACTORY] = factory # Used for async_unload_entry
async def on_hass_stop(event: Event) -> None: async def on_hass_stop(event: Event) -> None:
"""Close connection when hass stops.""" """Close connection when hass stops."""
@ -95,11 +81,7 @@ async def async_setup_entry(
remove_stale_devices(hass, entry, devices) remove_stale_devices(hass, entry, devices)
# Setup the device coordinators # Setup the device coordinators
coordinator_data = { coordinators: list[TradfriDeviceDataUpdateCoordinator] = []
CONF_GATEWAY_ID: gateway,
KEY_API: api,
COORDINATOR_LIST: [],
}
for device in devices: for device in devices:
coordinator = TradfriDeviceDataUpdateCoordinator( coordinator = TradfriDeviceDataUpdateCoordinator(
@ -110,9 +92,12 @@ async def async_setup_entry(
entry.async_on_unload( entry.async_on_unload(
async_dispatcher_connect(hass, SIGNAL_GW, coordinator.set_hub_available) async_dispatcher_connect(hass, SIGNAL_GW, coordinator.set_hub_available)
) )
coordinator_data[COORDINATOR_LIST].append(coordinator) coordinators.append(coordinator)
tradfri_data[COORDINATOR] = coordinator_data tradfri_data = TradfriData(
api=api, coordinators=coordinators, factory=factory, gateway=gateway
)
hass.data.setdefault(DOMAIN, {})[entry.entry_id] = tradfri_data
async def async_keep_alive(now: datetime) -> None: async def async_keep_alive(now: datetime) -> None:
if hass.is_stopping: if hass.is_stopping:
@ -140,8 +125,8 @@ async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
"""Unload a config entry.""" """Unload a config entry."""
unload_ok = await hass.config_entries.async_unload_platforms(entry, PLATFORMS) unload_ok = await hass.config_entries.async_unload_platforms(entry, PLATFORMS)
if unload_ok: if unload_ok:
tradfri_data = hass.data[DOMAIN].pop(entry.entry_id) tradfri_data: TradfriData = hass.data[DOMAIN].pop(entry.entry_id)
factory = tradfri_data[FACTORY] factory = tradfri_data.factory
await factory.shutdown() await factory.shutdown()
return unload_ok return unload_ok

View File

@ -7,8 +7,4 @@ LOGGER = logging.getLogger(__package__)
CONF_GATEWAY_ID = "gateway_id" CONF_GATEWAY_ID = "gateway_id"
CONF_IDENTITY = "identity" CONF_IDENTITY = "identity"
CONF_KEY = "key" CONF_KEY = "key"
COORDINATOR = "coordinator"
COORDINATOR_LIST = "coordinator_list"
DOMAIN = "tradfri" DOMAIN = "tradfri"
FACTORY = "tradfri_factory"
KEY_API = "tradfri_api"

View File

@ -11,9 +11,10 @@ from homeassistant.config_entries import ConfigEntry
from homeassistant.core import HomeAssistant from homeassistant.core import HomeAssistant
from homeassistant.helpers.entity_platform import AddEntitiesCallback from homeassistant.helpers.entity_platform import AddEntitiesCallback
from .const import CONF_GATEWAY_ID, COORDINATOR, COORDINATOR_LIST, DOMAIN, KEY_API from .const import CONF_GATEWAY_ID, DOMAIN
from .coordinator import TradfriDeviceDataUpdateCoordinator from .coordinator import TradfriDeviceDataUpdateCoordinator
from .entity import TradfriBaseEntity from .entity import TradfriBaseEntity
from .models import TradfriData
async def async_setup_entry( async def async_setup_entry(
@ -23,8 +24,8 @@ async def async_setup_entry(
) -> None: ) -> None:
"""Load Tradfri covers based on a config entry.""" """Load Tradfri covers based on a config entry."""
gateway_id = config_entry.data[CONF_GATEWAY_ID] gateway_id = config_entry.data[CONF_GATEWAY_ID]
coordinator_data = hass.data[DOMAIN][config_entry.entry_id][COORDINATOR] tradfri_data: TradfriData = hass.data[DOMAIN][config_entry.entry_id]
api: APIRequestProtocol = coordinator_data[KEY_API] api = tradfri_data.api
async_add_entities( async_add_entities(
TradfriCover( TradfriCover(
@ -32,7 +33,7 @@ async def async_setup_entry(
api, api,
gateway_id, gateway_id,
) )
for device_coordinator in coordinator_data[COORDINATOR_LIST] for device_coordinator in tradfri_data.coordinators
if device_coordinator.device.has_blind_control if device_coordinator.device.has_blind_control
) )

View File

@ -8,15 +8,15 @@ from homeassistant.config_entries import ConfigEntry
from homeassistant.core import HomeAssistant from homeassistant.core import HomeAssistant
from homeassistant.helpers import device_registry as dr from homeassistant.helpers import device_registry as dr
from .const import CONF_GATEWAY_ID, COORDINATOR, COORDINATOR_LIST, DOMAIN from .const import CONF_GATEWAY_ID, DOMAIN
from .models import TradfriData
async def async_get_config_entry_diagnostics( async def async_get_config_entry_diagnostics(
hass: HomeAssistant, entry: ConfigEntry hass: HomeAssistant, entry: ConfigEntry
) -> dict[str, Any]: ) -> dict[str, Any]:
"""Return diagnostics the Tradfri platform.""" """Return diagnostics the Tradfri platform."""
entry_data = hass.data[DOMAIN][entry.entry_id] tradfri_data: TradfriData = hass.data[DOMAIN][entry.entry_id]
coordinator_data = entry_data[COORDINATOR]
device_registry = dr.async_get(hass) device_registry = dr.async_get(hass)
device = cast( device = cast(
@ -28,7 +28,7 @@ async def async_get_config_entry_diagnostics(
device_data: list = [ device_data: list = [
coordinator.device.device_info.model_number coordinator.device.device_info.model_number
for coordinator in coordinator_data[COORDINATOR_LIST] for coordinator in tradfri_data.coordinators
] ]
return { return {

View File

@ -11,9 +11,10 @@ from homeassistant.config_entries import ConfigEntry
from homeassistant.core import HomeAssistant from homeassistant.core import HomeAssistant
from homeassistant.helpers.entity_platform import AddEntitiesCallback from homeassistant.helpers.entity_platform import AddEntitiesCallback
from .const import CONF_GATEWAY_ID, COORDINATOR, COORDINATOR_LIST, DOMAIN, KEY_API from .const import CONF_GATEWAY_ID, DOMAIN
from .coordinator import TradfriDeviceDataUpdateCoordinator from .coordinator import TradfriDeviceDataUpdateCoordinator
from .entity import TradfriBaseEntity from .entity import TradfriBaseEntity
from .models import TradfriData
ATTR_AUTO = "Auto" ATTR_AUTO = "Auto"
ATTR_MAX_FAN_STEPS = 49 ATTR_MAX_FAN_STEPS = 49
@ -36,8 +37,8 @@ async def async_setup_entry(
) -> None: ) -> None:
"""Load Tradfri switches based on a config entry.""" """Load Tradfri switches based on a config entry."""
gateway_id = config_entry.data[CONF_GATEWAY_ID] gateway_id = config_entry.data[CONF_GATEWAY_ID]
coordinator_data = hass.data[DOMAIN][config_entry.entry_id][COORDINATOR] tradfri_data: TradfriData = hass.data[DOMAIN][config_entry.entry_id]
api: APIRequestProtocol = coordinator_data[KEY_API] api = tradfri_data.api
async_add_entities( async_add_entities(
TradfriAirPurifierFan( TradfriAirPurifierFan(
@ -45,7 +46,7 @@ async def async_setup_entry(
api, api,
gateway_id, gateway_id,
) )
for device_coordinator in coordinator_data[COORDINATOR_LIST] for device_coordinator in tradfri_data.coordinators
if device_coordinator.device.has_air_purifier_control if device_coordinator.device.has_air_purifier_control
) )

View File

@ -21,9 +21,10 @@ from homeassistant.core import HomeAssistant
from homeassistant.helpers.entity_platform import AddEntitiesCallback from homeassistant.helpers.entity_platform import AddEntitiesCallback
from homeassistant.util import color as color_util from homeassistant.util import color as color_util
from .const import CONF_GATEWAY_ID, COORDINATOR, COORDINATOR_LIST, DOMAIN, KEY_API from .const import CONF_GATEWAY_ID, DOMAIN
from .coordinator import TradfriDeviceDataUpdateCoordinator from .coordinator import TradfriDeviceDataUpdateCoordinator
from .entity import TradfriBaseEntity from .entity import TradfriBaseEntity
from .models import TradfriData
async def async_setup_entry( async def async_setup_entry(
@ -33,8 +34,8 @@ async def async_setup_entry(
) -> None: ) -> None:
"""Load Tradfri lights based on a config entry.""" """Load Tradfri lights based on a config entry."""
gateway_id = config_entry.data[CONF_GATEWAY_ID] gateway_id = config_entry.data[CONF_GATEWAY_ID]
coordinator_data = hass.data[DOMAIN][config_entry.entry_id][COORDINATOR] tradfri_data: TradfriData = hass.data[DOMAIN][config_entry.entry_id]
api: APIRequestProtocol = coordinator_data[KEY_API] api = tradfri_data.api
async_add_entities( async_add_entities(
TradfriLight( TradfriLight(
@ -42,7 +43,7 @@ async def async_setup_entry(
api, api,
gateway_id, gateway_id,
) )
for device_coordinator in coordinator_data[COORDINATOR_LIST] for device_coordinator in tradfri_data.coordinators
if device_coordinator.device.has_light_control if device_coordinator.device.has_light_control
) )

View File

@ -0,0 +1,19 @@
"""Provide a model for the Tradfri integration."""
from __future__ import annotations
from dataclasses import dataclass
from pytradfri import Gateway
from pytradfri.api.aiocoap_api import APIFactory, APIRequestProtocol
from .coordinator import TradfriDeviceDataUpdateCoordinator
@dataclass
class TradfriData:
"""Data for the Tradfri integration."""
api: APIRequestProtocol
coordinators: list[TradfriDeviceDataUpdateCoordinator]
factory: APIFactory
gateway: Gateway

View File

@ -26,16 +26,10 @@ from homeassistant.core import HomeAssistant, callback
from homeassistant.helpers import entity_registry as er from homeassistant.helpers import entity_registry as er
from homeassistant.helpers.entity_platform import AddEntitiesCallback from homeassistant.helpers.entity_platform import AddEntitiesCallback
from .const import ( from .const import CONF_GATEWAY_ID, DOMAIN, LOGGER
CONF_GATEWAY_ID,
COORDINATOR,
COORDINATOR_LIST,
DOMAIN,
KEY_API,
LOGGER,
)
from .coordinator import TradfriDeviceDataUpdateCoordinator from .coordinator import TradfriDeviceDataUpdateCoordinator
from .entity import TradfriBaseEntity from .entity import TradfriBaseEntity
from .models import TradfriData
@dataclass(frozen=True, kw_only=True) @dataclass(frozen=True, kw_only=True)
@ -132,12 +126,12 @@ async def async_setup_entry(
) -> None: ) -> None:
"""Set up a Tradfri config entry.""" """Set up a Tradfri config entry."""
gateway_id = config_entry.data[CONF_GATEWAY_ID] gateway_id = config_entry.data[CONF_GATEWAY_ID]
coordinator_data = hass.data[DOMAIN][config_entry.entry_id][COORDINATOR] tradfri_data: TradfriData = hass.data[DOMAIN][config_entry.entry_id]
api: APIRequestProtocol = coordinator_data[KEY_API] api = tradfri_data.api
entities: list[TradfriSensor] = [] entities: list[TradfriSensor] = []
for device_coordinator in coordinator_data[COORDINATOR_LIST]: for device_coordinator in tradfri_data.coordinators:
if ( if (
not device_coordinator.device.has_light_control not device_coordinator.device.has_light_control
and not device_coordinator.device.has_socket_control and not device_coordinator.device.has_socket_control

View File

@ -11,9 +11,10 @@ from homeassistant.config_entries import ConfigEntry
from homeassistant.core import HomeAssistant from homeassistant.core import HomeAssistant
from homeassistant.helpers.entity_platform import AddEntitiesCallback from homeassistant.helpers.entity_platform import AddEntitiesCallback
from .const import CONF_GATEWAY_ID, COORDINATOR, COORDINATOR_LIST, DOMAIN, KEY_API from .const import CONF_GATEWAY_ID, DOMAIN
from .coordinator import TradfriDeviceDataUpdateCoordinator from .coordinator import TradfriDeviceDataUpdateCoordinator
from .entity import TradfriBaseEntity from .entity import TradfriBaseEntity
from .models import TradfriData
async def async_setup_entry( async def async_setup_entry(
@ -23,8 +24,8 @@ async def async_setup_entry(
) -> None: ) -> None:
"""Load Tradfri switches based on a config entry.""" """Load Tradfri switches based on a config entry."""
gateway_id = config_entry.data[CONF_GATEWAY_ID] gateway_id = config_entry.data[CONF_GATEWAY_ID]
coordinator_data = hass.data[DOMAIN][config_entry.entry_id][COORDINATOR] tradfri_data: TradfriData = hass.data[DOMAIN][config_entry.entry_id]
api: APIRequestProtocol = coordinator_data[KEY_API] api = tradfri_data.api
async_add_entities( async_add_entities(
TradfriSwitch( TradfriSwitch(
@ -32,7 +33,7 @@ async def async_setup_entry(
api, api,
gateway_id, gateway_id,
) )
for device_coordinator in coordinator_data[COORDINATOR_LIST] for device_coordinator in tradfri_data.coordinators
if device_coordinator.device.has_socket_control if device_coordinator.device.has_socket_control
) )