Improve tplink_lte typing (#108393)

This commit is contained in:
Marc Mueller 2024-01-19 19:06:52 +01:00 committed by GitHub
parent f0077ac27e
commit 7e60979abe
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 31 additions and 21 deletions

View File

@ -1,6 +1,9 @@
"""Support for TP-Link LTE modems.""" """Support for TP-Link LTE modems."""
from __future__ import annotations
import asyncio import asyncio
import logging import logging
from typing import Any
import aiohttp import aiohttp
import attr import attr
@ -15,7 +18,7 @@ from homeassistant.const import (
EVENT_HOMEASSISTANT_STOP, EVENT_HOMEASSISTANT_STOP,
Platform, Platform,
) )
from homeassistant.core import HomeAssistant, callback from homeassistant.core import Event, HomeAssistant, callback
from homeassistant.helpers import config_validation as cv, discovery from homeassistant.helpers import config_validation as cv, discovery
from homeassistant.helpers.aiohttp_client import async_create_clientsession from homeassistant.helpers.aiohttp_client import async_create_clientsession
from homeassistant.helpers.typing import ConfigType from homeassistant.helpers.typing import ConfigType
@ -59,20 +62,20 @@ CONFIG_SCHEMA = vol.Schema(
class ModemData: class ModemData:
"""Class for modem state.""" """Class for modem state."""
host = attr.ib() host: str = attr.ib()
modem = attr.ib() modem: tp_connected.Modem = attr.ib()
connected = attr.ib(init=False, default=True) connected: bool = attr.ib(init=False, default=True)
@attr.s @attr.s
class LTEData: class LTEData:
"""Shared state.""" """Shared state."""
websession = attr.ib() websession: aiohttp.ClientSession = attr.ib()
modem_data: dict[str, ModemData] = attr.ib(init=False, factory=dict) modem_data: dict[str, ModemData] = attr.ib(init=False, factory=dict)
def get_modem_data(self, config): def get_modem_data(self, config: dict[str, Any]) -> ModemData | None:
"""Get the requested or the only modem_data value.""" """Get the requested or the only modem_data value."""
if CONF_HOST in config: if CONF_HOST in config:
return self.modem_data.get(config[CONF_HOST]) return self.modem_data.get(config[CONF_HOST])
@ -107,14 +110,16 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
return True return True
async def _setup_lte(hass, lte_config, delay=0): async def _setup_lte(
hass: HomeAssistant, lte_config: dict[str, Any], delay: int = 0
) -> None:
"""Set up a TP-Link LTE modem.""" """Set up a TP-Link LTE modem."""
host = lte_config[CONF_HOST] host: str = lte_config[CONF_HOST]
password = lte_config[CONF_PASSWORD] password: str = lte_config[CONF_PASSWORD]
websession = hass.data[DATA_KEY].websession lte_data: LTEData = hass.data[DATA_KEY]
modem = tp_connected.Modem(hostname=host, websession=websession) modem = tp_connected.Modem(hostname=host, websession=lte_data.websession)
modem_data = ModemData(host, modem) modem_data = ModemData(host, modem)
@ -124,7 +129,7 @@ async def _setup_lte(hass, lte_config, delay=0):
retry_task = hass.loop.create_task(_retry_login(hass, modem_data, password)) retry_task = hass.loop.create_task(_retry_login(hass, modem_data, password))
@callback @callback
def cleanup_retry(event): def cleanup_retry(event: Event) -> None:
"""Clean up retry task resources.""" """Clean up retry task resources."""
if not retry_task.done(): if not retry_task.done():
retry_task.cancel() retry_task.cancel()
@ -132,20 +137,23 @@ async def _setup_lte(hass, lte_config, delay=0):
hass.bus.async_listen_once(EVENT_HOMEASSISTANT_STOP, cleanup_retry) hass.bus.async_listen_once(EVENT_HOMEASSISTANT_STOP, cleanup_retry)
async def _login(hass, modem_data, password): async def _login(hass: HomeAssistant, modem_data: ModemData, password: str) -> None:
"""Log in and complete setup.""" """Log in and complete setup."""
await modem_data.modem.login(password=password) await modem_data.modem.login(password=password)
modem_data.connected = True modem_data.connected = True
hass.data[DATA_KEY].modem_data[modem_data.host] = modem_data lte_data: LTEData = hass.data[DATA_KEY]
lte_data.modem_data[modem_data.host] = modem_data
async def cleanup(event): async def cleanup(event: Event) -> None:
"""Clean up resources.""" """Clean up resources."""
await modem_data.modem.logout() await modem_data.modem.logout()
hass.bus.async_listen_once(EVENT_HOMEASSISTANT_STOP, cleanup) hass.bus.async_listen_once(EVENT_HOMEASSISTANT_STOP, cleanup)
async def _retry_login(hass, modem_data, password): async def _retry_login(
hass: HomeAssistant, modem_data: ModemData, password: str
) -> None:
"""Sleep and retry setup.""" """Sleep and retry setup."""
_LOGGER.warning("Could not connect to %s. Will keep trying", modem_data.host) _LOGGER.warning("Could not connect to %s. Will keep trying", modem_data.host)

View File

@ -2,6 +2,7 @@
from __future__ import annotations from __future__ import annotations
import logging import logging
from typing import Any
import attr import attr
import tp_connected import tp_connected
@ -11,7 +12,7 @@ from homeassistant.const import CONF_RECIPIENT
from homeassistant.core import HomeAssistant from homeassistant.core import HomeAssistant
from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType
from . import DATA_KEY from . import DATA_KEY, LTEData
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
@ -31,13 +32,14 @@ async def async_get_service(
class TplinkNotifyService(BaseNotificationService): class TplinkNotifyService(BaseNotificationService):
"""Implementation of a notification service.""" """Implementation of a notification service."""
hass = attr.ib() hass: HomeAssistant = attr.ib()
config = attr.ib() config: dict[str, Any] = attr.ib()
async def async_send_message(self, message="", **kwargs): async def async_send_message(self, message: str = "", **kwargs: Any) -> None:
"""Send a message to a user.""" """Send a message to a user."""
modem_data = self.hass.data[DATA_KEY].get_modem_data(self.config) lte_data: LTEData = self.hass.data[DATA_KEY]
modem_data = lte_data.get_modem_data(self.config)
if not modem_data: if not modem_data:
_LOGGER.error("No modem available") _LOGGER.error("No modem available")
return return