diff --git a/homeassistant/components/tradfri/__init__.py b/homeassistant/components/tradfri/__init__.py index c3bf6dc43fd..cd6adbc5fef 100644 --- a/homeassistant/components/tradfri/__init__.py +++ b/homeassistant/components/tradfri/__init__.py @@ -3,7 +3,6 @@ from __future__ import annotations from datetime import datetime, timedelta -from typing import Any from pytradfri import Gateway, RequestError from pytradfri.api.aiocoap_api import APIFactory @@ -20,18 +19,9 @@ from homeassistant.helpers.dispatcher import ( ) from homeassistant.helpers.event import async_track_time_interval -from .const import ( - CONF_GATEWAY_ID, - CONF_IDENTITY, - CONF_KEY, - COORDINATOR, - COORDINATOR_LIST, - DOMAIN, - FACTORY, - KEY_API, - LOGGER, -) +from .const import CONF_GATEWAY_ID, CONF_IDENTITY, CONF_KEY, DOMAIN, LOGGER from .coordinator import TradfriDeviceDataUpdateCoordinator +from .models import TradfriData PLATFORMS = [ Platform.COVER, @@ -49,15 +39,11 @@ async def async_setup_entry( entry: ConfigEntry, ) -> bool: """Create a gateway.""" - tradfri_data: dict[str, Any] = {} - hass.data.setdefault(DOMAIN, {})[entry.entry_id] = tradfri_data - factory = await APIFactory.init( entry.data[CONF_HOST], psk_id=entry.data[CONF_IDENTITY], psk=entry.data[CONF_KEY], ) - tradfri_data[FACTORY] = factory # Used for async_unload_entry async def on_hass_stop(event: Event) -> None: """Close connection when hass stops.""" @@ -95,11 +81,7 @@ async def async_setup_entry( remove_stale_devices(hass, entry, devices) # Setup the device coordinators - coordinator_data = { - CONF_GATEWAY_ID: gateway, - KEY_API: api, - COORDINATOR_LIST: [], - } + coordinators: list[TradfriDeviceDataUpdateCoordinator] = [] for device in devices: coordinator = TradfriDeviceDataUpdateCoordinator( @@ -110,9 +92,12 @@ async def async_setup_entry( entry.async_on_unload( async_dispatcher_connect(hass, SIGNAL_GW, coordinator.set_hub_available) ) - coordinator_data[COORDINATOR_LIST].append(coordinator) + coordinators.append(coordinator) - tradfri_data[COORDINATOR] = coordinator_data + tradfri_data = TradfriData( + api=api, coordinators=coordinators, factory=factory, gateway=gateway + ) + hass.data.setdefault(DOMAIN, {})[entry.entry_id] = tradfri_data async def async_keep_alive(now: datetime) -> None: if hass.is_stopping: @@ -140,8 +125,8 @@ async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: """Unload a config entry.""" unload_ok = await hass.config_entries.async_unload_platforms(entry, PLATFORMS) if unload_ok: - tradfri_data = hass.data[DOMAIN].pop(entry.entry_id) - factory = tradfri_data[FACTORY] + tradfri_data: TradfriData = hass.data[DOMAIN].pop(entry.entry_id) + factory = tradfri_data.factory await factory.shutdown() return unload_ok diff --git a/homeassistant/components/tradfri/const.py b/homeassistant/components/tradfri/const.py index e42bb6f5f4d..9a9da766baf 100644 --- a/homeassistant/components/tradfri/const.py +++ b/homeassistant/components/tradfri/const.py @@ -7,8 +7,4 @@ LOGGER = logging.getLogger(__package__) CONF_GATEWAY_ID = "gateway_id" CONF_IDENTITY = "identity" CONF_KEY = "key" -COORDINATOR = "coordinator" -COORDINATOR_LIST = "coordinator_list" DOMAIN = "tradfri" -FACTORY = "tradfri_factory" -KEY_API = "tradfri_api" diff --git a/homeassistant/components/tradfri/cover.py b/homeassistant/components/tradfri/cover.py index 978a806595c..fac72dc4cb5 100644 --- a/homeassistant/components/tradfri/cover.py +++ b/homeassistant/components/tradfri/cover.py @@ -11,9 +11,10 @@ from homeassistant.config_entries import ConfigEntry from homeassistant.core import HomeAssistant from homeassistant.helpers.entity_platform import AddEntitiesCallback -from .const import CONF_GATEWAY_ID, COORDINATOR, COORDINATOR_LIST, DOMAIN, KEY_API +from .const import CONF_GATEWAY_ID, DOMAIN from .coordinator import TradfriDeviceDataUpdateCoordinator from .entity import TradfriBaseEntity +from .models import TradfriData async def async_setup_entry( @@ -23,8 +24,8 @@ async def async_setup_entry( ) -> None: """Load Tradfri covers based on a config entry.""" gateway_id = config_entry.data[CONF_GATEWAY_ID] - coordinator_data = hass.data[DOMAIN][config_entry.entry_id][COORDINATOR] - api: APIRequestProtocol = coordinator_data[KEY_API] + tradfri_data: TradfriData = hass.data[DOMAIN][config_entry.entry_id] + api = tradfri_data.api async_add_entities( TradfriCover( @@ -32,7 +33,7 @@ async def async_setup_entry( api, gateway_id, ) - for device_coordinator in coordinator_data[COORDINATOR_LIST] + for device_coordinator in tradfri_data.coordinators if device_coordinator.device.has_blind_control ) diff --git a/homeassistant/components/tradfri/diagnostics.py b/homeassistant/components/tradfri/diagnostics.py index 4d89fd0081f..bbd5d9bbb12 100644 --- a/homeassistant/components/tradfri/diagnostics.py +++ b/homeassistant/components/tradfri/diagnostics.py @@ -8,15 +8,15 @@ from homeassistant.config_entries import ConfigEntry from homeassistant.core import HomeAssistant from homeassistant.helpers import device_registry as dr -from .const import CONF_GATEWAY_ID, COORDINATOR, COORDINATOR_LIST, DOMAIN +from .const import CONF_GATEWAY_ID, DOMAIN +from .models import TradfriData async def async_get_config_entry_diagnostics( hass: HomeAssistant, entry: ConfigEntry ) -> dict[str, Any]: """Return diagnostics the Tradfri platform.""" - entry_data = hass.data[DOMAIN][entry.entry_id] - coordinator_data = entry_data[COORDINATOR] + tradfri_data: TradfriData = hass.data[DOMAIN][entry.entry_id] device_registry = dr.async_get(hass) device = cast( @@ -28,7 +28,7 @@ async def async_get_config_entry_diagnostics( device_data: list = [ coordinator.device.device_info.model_number - for coordinator in coordinator_data[COORDINATOR_LIST] + for coordinator in tradfri_data.coordinators ] return { diff --git a/homeassistant/components/tradfri/fan.py b/homeassistant/components/tradfri/fan.py index c61cbc97dca..cc893657aba 100644 --- a/homeassistant/components/tradfri/fan.py +++ b/homeassistant/components/tradfri/fan.py @@ -11,9 +11,10 @@ from homeassistant.config_entries import ConfigEntry from homeassistant.core import HomeAssistant from homeassistant.helpers.entity_platform import AddEntitiesCallback -from .const import CONF_GATEWAY_ID, COORDINATOR, COORDINATOR_LIST, DOMAIN, KEY_API +from .const import CONF_GATEWAY_ID, DOMAIN from .coordinator import TradfriDeviceDataUpdateCoordinator from .entity import TradfriBaseEntity +from .models import TradfriData ATTR_AUTO = "Auto" ATTR_MAX_FAN_STEPS = 49 @@ -36,8 +37,8 @@ async def async_setup_entry( ) -> None: """Load Tradfri switches based on a config entry.""" gateway_id = config_entry.data[CONF_GATEWAY_ID] - coordinator_data = hass.data[DOMAIN][config_entry.entry_id][COORDINATOR] - api: APIRequestProtocol = coordinator_data[KEY_API] + tradfri_data: TradfriData = hass.data[DOMAIN][config_entry.entry_id] + api = tradfri_data.api async_add_entities( TradfriAirPurifierFan( @@ -45,7 +46,7 @@ async def async_setup_entry( api, gateway_id, ) - for device_coordinator in coordinator_data[COORDINATOR_LIST] + for device_coordinator in tradfri_data.coordinators if device_coordinator.device.has_air_purifier_control ) diff --git a/homeassistant/components/tradfri/light.py b/homeassistant/components/tradfri/light.py index b99e9c97082..a23d950b4d0 100644 --- a/homeassistant/components/tradfri/light.py +++ b/homeassistant/components/tradfri/light.py @@ -21,9 +21,10 @@ from homeassistant.core import HomeAssistant from homeassistant.helpers.entity_platform import AddEntitiesCallback from homeassistant.util import color as color_util -from .const import CONF_GATEWAY_ID, COORDINATOR, COORDINATOR_LIST, DOMAIN, KEY_API +from .const import CONF_GATEWAY_ID, DOMAIN from .coordinator import TradfriDeviceDataUpdateCoordinator from .entity import TradfriBaseEntity +from .models import TradfriData async def async_setup_entry( @@ -33,8 +34,8 @@ async def async_setup_entry( ) -> None: """Load Tradfri lights based on a config entry.""" gateway_id = config_entry.data[CONF_GATEWAY_ID] - coordinator_data = hass.data[DOMAIN][config_entry.entry_id][COORDINATOR] - api: APIRequestProtocol = coordinator_data[KEY_API] + tradfri_data: TradfriData = hass.data[DOMAIN][config_entry.entry_id] + api = tradfri_data.api async_add_entities( TradfriLight( @@ -42,7 +43,7 @@ async def async_setup_entry( api, gateway_id, ) - for device_coordinator in coordinator_data[COORDINATOR_LIST] + for device_coordinator in tradfri_data.coordinators if device_coordinator.device.has_light_control ) diff --git a/homeassistant/components/tradfri/models.py b/homeassistant/components/tradfri/models.py new file mode 100644 index 00000000000..aaeaf915481 --- /dev/null +++ b/homeassistant/components/tradfri/models.py @@ -0,0 +1,19 @@ +"""Provide a model for the Tradfri integration.""" +from __future__ import annotations + +from dataclasses import dataclass + +from pytradfri import Gateway +from pytradfri.api.aiocoap_api import APIFactory, APIRequestProtocol + +from .coordinator import TradfriDeviceDataUpdateCoordinator + + +@dataclass +class TradfriData: + """Data for the Tradfri integration.""" + + api: APIRequestProtocol + coordinators: list[TradfriDeviceDataUpdateCoordinator] + factory: APIFactory + gateway: Gateway diff --git a/homeassistant/components/tradfri/sensor.py b/homeassistant/components/tradfri/sensor.py index def5fa8c64d..20993d9f409 100644 --- a/homeassistant/components/tradfri/sensor.py +++ b/homeassistant/components/tradfri/sensor.py @@ -26,16 +26,10 @@ from homeassistant.core import HomeAssistant, callback from homeassistant.helpers import entity_registry as er from homeassistant.helpers.entity_platform import AddEntitiesCallback -from .const import ( - CONF_GATEWAY_ID, - COORDINATOR, - COORDINATOR_LIST, - DOMAIN, - KEY_API, - LOGGER, -) +from .const import CONF_GATEWAY_ID, DOMAIN, LOGGER from .coordinator import TradfriDeviceDataUpdateCoordinator from .entity import TradfriBaseEntity +from .models import TradfriData @dataclass(frozen=True, kw_only=True) @@ -132,12 +126,12 @@ async def async_setup_entry( ) -> None: """Set up a Tradfri config entry.""" gateway_id = config_entry.data[CONF_GATEWAY_ID] - coordinator_data = hass.data[DOMAIN][config_entry.entry_id][COORDINATOR] - api: APIRequestProtocol = coordinator_data[KEY_API] + tradfri_data: TradfriData = hass.data[DOMAIN][config_entry.entry_id] + api = tradfri_data.api entities: list[TradfriSensor] = [] - for device_coordinator in coordinator_data[COORDINATOR_LIST]: + for device_coordinator in tradfri_data.coordinators: if ( not device_coordinator.device.has_light_control and not device_coordinator.device.has_socket_control diff --git a/homeassistant/components/tradfri/switch.py b/homeassistant/components/tradfri/switch.py index a7828deb15d..c7b43e38f1e 100644 --- a/homeassistant/components/tradfri/switch.py +++ b/homeassistant/components/tradfri/switch.py @@ -11,9 +11,10 @@ from homeassistant.config_entries import ConfigEntry from homeassistant.core import HomeAssistant from homeassistant.helpers.entity_platform import AddEntitiesCallback -from .const import CONF_GATEWAY_ID, COORDINATOR, COORDINATOR_LIST, DOMAIN, KEY_API +from .const import CONF_GATEWAY_ID, DOMAIN from .coordinator import TradfriDeviceDataUpdateCoordinator from .entity import TradfriBaseEntity +from .models import TradfriData async def async_setup_entry( @@ -23,8 +24,8 @@ async def async_setup_entry( ) -> None: """Load Tradfri switches based on a config entry.""" gateway_id = config_entry.data[CONF_GATEWAY_ID] - coordinator_data = hass.data[DOMAIN][config_entry.entry_id][COORDINATOR] - api: APIRequestProtocol = coordinator_data[KEY_API] + tradfri_data: TradfriData = hass.data[DOMAIN][config_entry.entry_id] + api = tradfri_data.api async_add_entities( TradfriSwitch( @@ -32,7 +33,7 @@ async def async_setup_entry( api, gateway_id, ) - for device_coordinator in coordinator_data[COORDINATOR_LIST] + for device_coordinator in tradfri_data.coordinators if device_coordinator.device.has_socket_control )