diff --git a/homeassistant/components/tradfri/__init__.py b/homeassistant/components/tradfri/__init__.py index 2c113b63727..4f5997b2fa1 100644 --- a/homeassistant/components/tradfri/__init__.py +++ b/homeassistant/components/tradfri/__init__.py @@ -1,7 +1,7 @@ """Support for IKEA Tradfri.""" from __future__ import annotations -from datetime import timedelta +from datetime import datetime, timedelta import logging from typing import Any @@ -15,7 +15,7 @@ from homeassistant.const import EVENT_HOMEASSISTANT_STOP from homeassistant.core import HomeAssistant from homeassistant.exceptions import ConfigEntryNotReady import homeassistant.helpers.config_validation as cv -from homeassistant.helpers.event import async_track_time_interval +from homeassistant.helpers.event import Event, async_track_time_interval from homeassistant.helpers.typing import ConfigType from .const import ( @@ -97,7 +97,7 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: psk=entry.data[CONF_KEY], ) - async def on_hass_stop(event): + async def on_hass_stop(event: Event) -> None: """Close connection when hass stops.""" await factory.shutdown() @@ -135,7 +135,7 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: hass.config_entries.async_setup_platforms(entry, PLATFORMS) - async def async_keep_alive(now): + async def async_keep_alive(now: datetime) -> None: if hass.is_stopping: return @@ -151,7 +151,7 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: return True -async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry): +async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: """Unload a config entry.""" unload_ok = await hass.config_entries.async_unload_platforms(entry, PLATFORMS) if unload_ok: diff --git a/homeassistant/components/tradfri/base_class.py b/homeassistant/components/tradfri/base_class.py index ed95e47abd5..eb1884cfc1b 100644 --- a/homeassistant/components/tradfri/base_class.py +++ b/homeassistant/components/tradfri/base_class.py @@ -1,7 +1,15 @@ """Base class for IKEA TRADFRI.""" +from __future__ import annotations + from functools import wraps import logging +from typing import Any, Callable +from pytradfri.command import Command +from pytradfri.device.blind import Blind +from pytradfri.device.light import Light +from pytradfri.device.socket import Socket +from pytradfri.device.socket_control import SocketControl from pytradfri.error import PytradfriError from homeassistant.core import callback @@ -34,12 +42,14 @@ class TradfriBaseClass(Entity): _attr_should_poll = False - def __init__(self, device, api, gateway_id): + def __init__( + self, device: Command, api: Callable[[str], Any], gateway_id: str + ) -> None: """Initialize a device.""" self._api = handle_error(api) - self._device = None - self._device_control = None - self._device_data = None + self._device: Command | None = None + self._device_control: SocketControl | None = None + self._device_data: Socket | Light | Blind | None = None self._gateway_id = gateway_id self._refresh(device) @@ -71,7 +81,7 @@ class TradfriBaseClass(Entity): self._refresh(device) self.async_write_ha_state() - def _refresh(self, device): + def _refresh(self, device: Command) -> None: """Refresh the device data.""" self._device = device self._attr_name = device.name @@ -97,7 +107,7 @@ class TradfriBaseDevice(TradfriBaseClass): "via_device": (DOMAIN, self._gateway_id), } - def _refresh(self, device): + def _refresh(self, device: Command) -> None: """Refresh the device data.""" super()._refresh(device) self._attr_available = device.reachable diff --git a/homeassistant/components/tradfri/switch.py b/homeassistant/components/tradfri/switch.py index 00e15f1b875..6dc934814f0 100644 --- a/homeassistant/components/tradfri/switch.py +++ b/homeassistant/components/tradfri/switch.py @@ -1,11 +1,24 @@ """Support for IKEA Tradfri switches.""" +from __future__ import annotations + +from typing import Any, Callable, cast + +from pytradfri.command import Command + from homeassistant.components.switch import SwitchEntity +from homeassistant.config_entries import ConfigEntry +from homeassistant.core import HomeAssistant +from homeassistant.helpers.entity_platform import AddEntitiesCallback from .base_class import TradfriBaseDevice from .const import CONF_GATEWAY_ID, DEVICES, DOMAIN, KEY_API -async def async_setup_entry(hass, config_entry, async_add_entities): +async def async_setup_entry( + hass: HomeAssistant, + config_entry: ConfigEntry, + async_add_entities: AddEntitiesCallback, +) -> None: """Load Tradfri switches based on a config entry.""" gateway_id = config_entry.data[CONF_GATEWAY_ID] tradfri_data = hass.data[DOMAIN][config_entry.entry_id] @@ -22,12 +35,14 @@ async def async_setup_entry(hass, config_entry, async_add_entities): class TradfriSwitch(TradfriBaseDevice, SwitchEntity): """The platform class required by Home Assistant.""" - def __init__(self, device, api, gateway_id): + def __init__( + self, device: Command, api: Callable[[str], Any], gateway_id: str + ) -> None: """Initialize a switch.""" super().__init__(device, api, gateway_id) self._attr_unique_id = f"{gateway_id}-{device.id}" - def _refresh(self, device): + def _refresh(self, device: Command) -> None: """Refresh the switch data.""" super()._refresh(device) @@ -36,14 +51,20 @@ class TradfriSwitch(TradfriBaseDevice, SwitchEntity): self._device_data = device.socket_control.sockets[0] @property - def is_on(self): + def is_on(self) -> bool: """Return true if switch is on.""" - return self._device_data.state + if not self._device_data: + return False + return cast(bool, self._device_data.state) - async def async_turn_off(self, **kwargs): + async def async_turn_off(self, **kwargs: Any) -> None: """Instruct the switch to turn off.""" + if not self._device_control: + return None await self._api(self._device_control.set_state(False)) - async def async_turn_on(self, **kwargs): + async def async_turn_on(self, **kwargs: Any) -> None: """Instruct the switch to turn on.""" + if not self._device_control: + return None await self._api(self._device_control.set_state(True))