Add strict typing to tradfri __init__ and switch (#56002)

* Add strict typing to __init__ and switch.

* Review comments.

* Review comments.

* Corrected switch.
This commit is contained in:
jan iversen 2021-09-18 23:24:35 +02:00 committed by GitHub
parent f6526de7b6
commit 9b710cad5d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 49 additions and 18 deletions

View File

@ -1,7 +1,7 @@
"""Support for IKEA Tradfri.""" """Support for IKEA Tradfri."""
from __future__ import annotations from __future__ import annotations
from datetime import timedelta from datetime import datetime, timedelta
import logging import logging
from typing import Any from typing import Any
@ -15,7 +15,7 @@ from homeassistant.const import EVENT_HOMEASSISTANT_STOP
from homeassistant.core import HomeAssistant from homeassistant.core import HomeAssistant
from homeassistant.exceptions import ConfigEntryNotReady from homeassistant.exceptions import ConfigEntryNotReady
import homeassistant.helpers.config_validation as cv import homeassistant.helpers.config_validation as cv
from homeassistant.helpers.event import async_track_time_interval from homeassistant.helpers.event import Event, async_track_time_interval
from homeassistant.helpers.typing import ConfigType from homeassistant.helpers.typing import ConfigType
from .const import ( from .const import (
@ -97,7 +97,7 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
psk=entry.data[CONF_KEY], psk=entry.data[CONF_KEY],
) )
async def on_hass_stop(event): async def on_hass_stop(event: Event) -> None:
"""Close connection when hass stops.""" """Close connection when hass stops."""
await factory.shutdown() await factory.shutdown()
@ -135,7 +135,7 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
hass.config_entries.async_setup_platforms(entry, PLATFORMS) hass.config_entries.async_setup_platforms(entry, PLATFORMS)
async def async_keep_alive(now): async def async_keep_alive(now: datetime) -> None:
if hass.is_stopping: if hass.is_stopping:
return return
@ -151,7 +151,7 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
return True return True
async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry): async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
"""Unload a config entry.""" """Unload a config entry."""
unload_ok = await hass.config_entries.async_unload_platforms(entry, PLATFORMS) unload_ok = await hass.config_entries.async_unload_platforms(entry, PLATFORMS)
if unload_ok: if unload_ok:

View File

@ -1,7 +1,15 @@
"""Base class for IKEA TRADFRI.""" """Base class for IKEA TRADFRI."""
from __future__ import annotations
from functools import wraps from functools import wraps
import logging import logging
from typing import Any, Callable
from pytradfri.command import Command
from pytradfri.device.blind import Blind
from pytradfri.device.light import Light
from pytradfri.device.socket import Socket
from pytradfri.device.socket_control import SocketControl
from pytradfri.error import PytradfriError from pytradfri.error import PytradfriError
from homeassistant.core import callback from homeassistant.core import callback
@ -34,12 +42,14 @@ class TradfriBaseClass(Entity):
_attr_should_poll = False _attr_should_poll = False
def __init__(self, device, api, gateway_id): def __init__(
self, device: Command, api: Callable[[str], Any], gateway_id: str
) -> None:
"""Initialize a device.""" """Initialize a device."""
self._api = handle_error(api) self._api = handle_error(api)
self._device = None self._device: Command | None = None
self._device_control = None self._device_control: SocketControl | None = None
self._device_data = None self._device_data: Socket | Light | Blind | None = None
self._gateway_id = gateway_id self._gateway_id = gateway_id
self._refresh(device) self._refresh(device)
@ -71,7 +81,7 @@ class TradfriBaseClass(Entity):
self._refresh(device) self._refresh(device)
self.async_write_ha_state() self.async_write_ha_state()
def _refresh(self, device): def _refresh(self, device: Command) -> None:
"""Refresh the device data.""" """Refresh the device data."""
self._device = device self._device = device
self._attr_name = device.name self._attr_name = device.name
@ -97,7 +107,7 @@ class TradfriBaseDevice(TradfriBaseClass):
"via_device": (DOMAIN, self._gateway_id), "via_device": (DOMAIN, self._gateway_id),
} }
def _refresh(self, device): def _refresh(self, device: Command) -> None:
"""Refresh the device data.""" """Refresh the device data."""
super()._refresh(device) super()._refresh(device)
self._attr_available = device.reachable self._attr_available = device.reachable

View File

@ -1,11 +1,24 @@
"""Support for IKEA Tradfri switches.""" """Support for IKEA Tradfri switches."""
from __future__ import annotations
from typing import Any, Callable, cast
from pytradfri.command import Command
from homeassistant.components.switch import SwitchEntity from homeassistant.components.switch import SwitchEntity
from homeassistant.config_entries import ConfigEntry
from homeassistant.core import HomeAssistant
from homeassistant.helpers.entity_platform import AddEntitiesCallback
from .base_class import TradfriBaseDevice from .base_class import TradfriBaseDevice
from .const import CONF_GATEWAY_ID, DEVICES, DOMAIN, KEY_API from .const import CONF_GATEWAY_ID, DEVICES, DOMAIN, KEY_API
async def async_setup_entry(hass, config_entry, async_add_entities): async def async_setup_entry(
hass: HomeAssistant,
config_entry: ConfigEntry,
async_add_entities: AddEntitiesCallback,
) -> None:
"""Load Tradfri switches based on a config entry.""" """Load Tradfri switches based on a config entry."""
gateway_id = config_entry.data[CONF_GATEWAY_ID] gateway_id = config_entry.data[CONF_GATEWAY_ID]
tradfri_data = hass.data[DOMAIN][config_entry.entry_id] tradfri_data = hass.data[DOMAIN][config_entry.entry_id]
@ -22,12 +35,14 @@ async def async_setup_entry(hass, config_entry, async_add_entities):
class TradfriSwitch(TradfriBaseDevice, SwitchEntity): class TradfriSwitch(TradfriBaseDevice, SwitchEntity):
"""The platform class required by Home Assistant.""" """The platform class required by Home Assistant."""
def __init__(self, device, api, gateway_id): def __init__(
self, device: Command, api: Callable[[str], Any], gateway_id: str
) -> None:
"""Initialize a switch.""" """Initialize a switch."""
super().__init__(device, api, gateway_id) super().__init__(device, api, gateway_id)
self._attr_unique_id = f"{gateway_id}-{device.id}" self._attr_unique_id = f"{gateway_id}-{device.id}"
def _refresh(self, device): def _refresh(self, device: Command) -> None:
"""Refresh the switch data.""" """Refresh the switch data."""
super()._refresh(device) super()._refresh(device)
@ -36,14 +51,20 @@ class TradfriSwitch(TradfriBaseDevice, SwitchEntity):
self._device_data = device.socket_control.sockets[0] self._device_data = device.socket_control.sockets[0]
@property @property
def is_on(self): def is_on(self) -> bool:
"""Return true if switch is on.""" """Return true if switch is on."""
return self._device_data.state if not self._device_data:
return False
return cast(bool, self._device_data.state)
async def async_turn_off(self, **kwargs): async def async_turn_off(self, **kwargs: Any) -> None:
"""Instruct the switch to turn off.""" """Instruct the switch to turn off."""
if not self._device_control:
return None
await self._api(self._device_control.set_state(False)) await self._api(self._device_control.set_state(False))
async def async_turn_on(self, **kwargs): async def async_turn_on(self, **kwargs: Any) -> None:
"""Instruct the switch to turn on.""" """Instruct the switch to turn on."""
if not self._device_control:
return None
await self._api(self._device_control.set_state(True)) await self._api(self._device_control.set_state(True))