Code refactoring and clean up

This commit is contained in:
Robin Lintermann 2025-04-30 12:03:30 +00:00
parent 2a96ea0418
commit 2fa18c6838
5 changed files with 72 additions and 99 deletions

View File

@ -5,26 +5,18 @@ from pysmarlaapi import Connection, Federwiege
from homeassistant.config_entries import ConfigEntry from homeassistant.config_entries import ConfigEntry
from homeassistant.const import CONF_ACCESS_TOKEN from homeassistant.const import CONF_ACCESS_TOKEN
from homeassistant.core import HomeAssistant from homeassistant.core import HomeAssistant
from homeassistant.exceptions import ConfigEntryAuthFailed, ConfigEntryError from homeassistant.exceptions import ConfigEntryAuthFailed
from homeassistant.helpers import config_validation as cv
from .const import DOMAIN, HOST, PLATFORMS from .const import HOST, PLATFORMS
CONFIG_SCHEMA = cv.config_entry_only_config_schema(DOMAIN)
type FederwiegeConfigEntry = ConfigEntry[Federwiege] type FederwiegeConfigEntry = ConfigEntry[Federwiege]
async def async_setup_entry(hass: HomeAssistant, entry: FederwiegeConfigEntry) -> bool: async def async_setup_entry(hass: HomeAssistant, entry: FederwiegeConfigEntry) -> bool:
"""Set up this integration using UI.""" """Set up this integration using UI."""
if hass.data.get(DOMAIN) is None: connection = Connection(HOST, token_str=entry.data.get(CONF_ACCESS_TOKEN, None))
hass.data.setdefault(DOMAIN, {})
try:
connection = Connection(HOST, token_str=entry.data.get(CONF_ACCESS_TOKEN, None))
except ValueError as e:
raise ConfigEntryError("Invalid token") from e
# Check if token still has access
if not await connection.get_token(): if not await connection.get_token():
raise ConfigEntryAuthFailed("Invalid authentication") raise ConfigEntryAuthFailed("Invalid authentication")
@ -36,7 +28,7 @@ async def async_setup_entry(hass: HomeAssistant, entry: FederwiegeConfigEntry) -
await hass.config_entries.async_forward_entry_setups( await hass.config_entries.async_forward_entry_setups(
entry, entry,
list(PLATFORMS), PLATFORMS,
) )
return True return True
@ -46,11 +38,11 @@ async def async_unload_entry(hass: HomeAssistant, entry: FederwiegeConfigEntry)
"""Unload a config entry.""" """Unload a config entry."""
unload_ok = await hass.config_entries.async_unload_platforms( unload_ok = await hass.config_entries.async_unload_platforms(
entry, entry,
list(PLATFORMS), PLATFORMS,
) )
if unload_ok: if unload_ok:
federwiege: Federwiege = entry.runtime_data federwiege = entry.runtime_data
federwiege.disconnect() federwiege.disconnect()
return unload_ok return unload_ok

View File

@ -7,7 +7,7 @@ from typing import Any
from pysmarlaapi import Connection from pysmarlaapi import Connection
import voluptuous as vol import voluptuous as vol
from homeassistant import config_entries, exceptions from homeassistant.config_entries import ConfigFlow, ConfigFlowResult
from homeassistant.const import CONF_ACCESS_TOKEN from homeassistant.const import CONF_ACCESS_TOKEN
from .const import DOMAIN, HOST from .const import DOMAIN, HOST
@ -15,68 +15,50 @@ from .const import DOMAIN, HOST
STEP_USER_DATA_SCHEMA = vol.Schema({CONF_ACCESS_TOKEN: str}) STEP_USER_DATA_SCHEMA = vol.Schema({CONF_ACCESS_TOKEN: str})
class SmarlaConfigFlow(config_entries.ConfigFlow, domain=DOMAIN): class SmarlaConfigFlow(ConfigFlow, domain=DOMAIN):
"""Handle a config flow for Swing2Sleep Smarla.""" """Handle a config flow for Swing2Sleep Smarla."""
VERSION = 1 VERSION = 1
CONNECTION_CLASS = config_entries.CONN_CLASS_CLOUD_PUSH
async def _handle_token(self, token: str) -> tuple[dict[str, str], dict[str, str]]:
"""Handle the token input."""
errors: dict[str, str] = {}
info: dict[str, str] = {}
try:
conn = Connection(url=HOST, token_b64=token)
except ValueError:
errors["base"] = "invalid_token"
return (errors, info)
if await conn.get_token():
info["serial_number"] = conn.token.serialNumber
info["token"] = conn.token.get_string()
else:
errors["base"] = "invalid_auth"
return (errors, info)
async def async_step_user( async def async_step_user(
self, user_input: dict[str, Any] | None = None self, user_input: dict[str, Any] | None = None
) -> config_entries.ConfigFlowResult: ) -> ConfigFlowResult:
"""Handle the initial step.""" """Handle the initial step."""
if user_input is None:
return self.async_show_form(
step_id="user",
data_schema=STEP_USER_DATA_SCHEMA,
)
errors: dict[str, str] = {} errors: dict[str, str] = {}
try: if user_input is not None:
info = await self.validate_input(user_input) errors, info = await self._handle_token(token=user_input[CONF_ACCESS_TOKEN])
except InvalidAuth:
errors["base"] = "invalid_auth"
except InvalidToken:
errors["base"] = "invalid_token"
if not errors: if not errors:
return self.async_create_entry( await self.async_set_unique_id(info["serial_number"])
title=info["title"], self._abort_if_unique_id_configured()
data={CONF_ACCESS_TOKEN: info.get(CONF_ACCESS_TOKEN)},
) return self.async_create_entry(
title=info["serial_number"],
data={CONF_ACCESS_TOKEN: info["token"]},
)
return self.async_show_form( return self.async_show_form(
step_id="user", step_id="user",
data_schema=STEP_USER_DATA_SCHEMA, data_schema=STEP_USER_DATA_SCHEMA,
errors=errors, errors=errors,
) )
async def validate_input(self, data: dict[str, Any]):
"""Validate the user input allows us to connect.
Data has the keys from STEP_USER_DATA_SCHEMA with values provided by the user.
"""
try:
conn = Connection(url=HOST, token_b64=data[CONF_ACCESS_TOKEN])
except ValueError as e:
raise InvalidToken from e
await self.async_set_unique_id(conn.token.serialNumber)
self._abort_if_unique_id_configured()
if not await conn.get_token():
raise InvalidAuth
return {
"title": conn.token.serialNumber,
CONF_ACCESS_TOKEN: conn.token.get_string(),
}
class InvalidAuth(exceptions.HomeAssistantError):
"""Error to indicate there is invalid auth."""
class InvalidToken(exceptions.HomeAssistantError):
"""Error to indicate there is an invalid token."""

View File

@ -1,6 +1,9 @@
"""Common base for entities.""" """Common base for entities."""
from typing import Any
from pysmarlaapi import Federwiege from pysmarlaapi import Federwiege
from pysmarlaapi.federwiege.classes import Property
from homeassistant.helpers.device_registry import DeviceInfo from homeassistant.helpers.device_registry import DeviceInfo
from homeassistant.helpers.entity import Entity from homeassistant.helpers.entity import Entity
@ -11,15 +14,18 @@ from .const import DEVICE_MODEL_NAME, DOMAIN, MANUFACTURER_NAME
class SmarlaBaseEntity(Entity): class SmarlaBaseEntity(Entity):
"""Common Base Entity class for defining Smarla device.""" """Common Base Entity class for defining Smarla device."""
_property: Property
_attr_should_poll = False
_attr_has_entity_name = True _attr_has_entity_name = True
def __init__( async def on_change(self, value: Any):
self, """Notify ha when state changes."""
federwiege: Federwiege, self.async_write_ha_state()
) -> None:
"""Initialise the entity."""
super().__init__()
def __init__(self, federwiege: Federwiege, prop: Property) -> None:
"""Initialise the entity."""
self._property = prop
self._attr_device_info = DeviceInfo( self._attr_device_info = DeviceInfo(
identifiers={(DOMAIN, federwiege.serial_number)}, identifiers={(DOMAIN, federwiege.serial_number)},
name=DEVICE_MODEL_NAME, name=DEVICE_MODEL_NAME,
@ -27,3 +33,11 @@ class SmarlaBaseEntity(Entity):
manufacturer=MANUFACTURER_NAME, manufacturer=MANUFACTURER_NAME,
serial_number=federwiege.serial_number, serial_number=federwiege.serial_number,
) )
async def async_added_to_hass(self) -> None:
"""Run when this Entity has been added to HA."""
await self._property.add_listener(self.on_change)
async def async_will_remove_from_hass(self) -> None:
"""Entity being removed from hass."""
await self._property.remove_listener(self.on_change)

View File

@ -44,7 +44,7 @@ async def async_setup_entry(
async_add_entities: AddConfigEntryEntitiesCallback, async_add_entities: AddConfigEntryEntitiesCallback,
) -> None: ) -> None:
"""Set up the Smarla switches from config entry.""" """Set up the Smarla switches from config entry."""
federwiege: Federwiege = config_entry.runtime_data federwiege = config_entry.runtime_data
async_add_entities(SmarlaSwitch(federwiege, desc) for desc in SWITCHES) async_add_entities(SmarlaSwitch(federwiege, desc) for desc in SWITCHES)
@ -52,33 +52,19 @@ class SmarlaSwitch(SmarlaBaseEntity, SwitchEntity):
"""Representation of Smarla switch.""" """Representation of Smarla switch."""
entity_description: SmarlaSwitchEntityDescription entity_description: SmarlaSwitchEntityDescription
_property: Property
_attr_should_poll = False
async def on_change(self, value: Any): _property: Property[bool]
"""Notify ha when state changes."""
self.async_write_ha_state()
def __init__( def __init__(
self, self,
federwiege: Federwiege, federwiege: Federwiege,
description: SmarlaSwitchEntityDescription, desc: SmarlaSwitchEntityDescription,
) -> None: ) -> None:
"""Initialize a Smarla switch.""" """Initialize a Smarla switch."""
super().__init__(federwiege) prop = federwiege.get_service(desc.service).get_property(desc.property)
self._property = federwiege.get_service(description.service).get_property( super().__init__(federwiege, prop)
description.property self.entity_description = desc
) self._attr_unique_id = f"{federwiege.serial_number}-{desc.key}"
self.entity_description = description
self._attr_unique_id = f"{federwiege.serial_number}-{description.key}"
async def async_added_to_hass(self) -> None:
"""Run when this Entity has been added to HA."""
await self._property.add_listener(self.on_change)
async def async_will_remove_from_hass(self) -> None:
"""Entity being removed from hass."""
await self._property.remove_listener(self.on_change)
@property @property
def is_on(self) -> bool: def is_on(self) -> bool:

View File

@ -78,13 +78,12 @@ async def test_device_exists_abort(hass: HomeAssistant) -> None:
) )
config_entry.add_to_hass(hass) config_entry.add_to_hass(hass)
assert len(hass.config_entries.async_entries(DOMAIN)) == 1 with patch.object(Connection, "get_token", new=AsyncMock(return_value=True)):
result = await hass.config_entries.flow.async_init(
result = await hass.config_entries.flow.async_init( DOMAIN,
DOMAIN, context={"source": config_entries.SOURCE_USER},
context={"source": config_entries.SOURCE_USER}, data={CONF_ACCESS_TOKEN: MOCK_ACCESS_TOKEN},
data={CONF_ACCESS_TOKEN: MOCK_ACCESS_TOKEN}, )
)
assert result["type"] is FlowResultType.ABORT assert result["type"] is FlowResultType.ABORT
assert result["reason"] == "already_configured" assert result["reason"] == "already_configured"