Refactor Shelly to use data class for ConfigEntry data (#79671)

* Refactor Shelly to use data class for ConfigEntry data

* Apply suggestions from code review

Co-authored-by: epenet <6771947+epenet@users.noreply.github.com>

* Update homeassistant/components/shelly/__init__.py

Co-authored-by: epenet <6771947+epenet@users.noreply.github.com>

* Optimize usage of shelly_entry_data in _async_setup_block_entry

Co-authored-by: epenet <6771947+epenet@users.noreply.github.com>
This commit is contained in:
Shay Levy 2022-10-06 10:10:58 +03:00 committed by GitHub
parent 9b4c7f5dc5
commit 93b2a6cc26
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
13 changed files with 164 additions and 194 deletions

View File

@ -3,7 +3,7 @@ from __future__ import annotations
import asyncio import asyncio
from http import HTTPStatus from http import HTTPStatus
from typing import Any, Final, cast from typing import Any, Final
from aiohttp import ClientResponseError from aiohttp import ClientResponseError
import aioshelly import aioshelly
@ -23,23 +23,20 @@ from homeassistant.helpers.typing import ConfigType
from .const import ( from .const import (
AIOSHELLY_DEVICE_TIMEOUT_SEC, AIOSHELLY_DEVICE_TIMEOUT_SEC,
BLOCK,
CONF_COAP_PORT, CONF_COAP_PORT,
CONF_SLEEP_PERIOD, CONF_SLEEP_PERIOD,
DATA_CONFIG_ENTRY, DATA_CONFIG_ENTRY,
DEFAULT_COAP_PORT, DEFAULT_COAP_PORT,
DEVICE,
DOMAIN, DOMAIN,
LOGGER, LOGGER,
REST,
RPC,
RPC_POLL,
) )
from .coordinator import ( from .coordinator import (
ShellyBlockCoordinator, ShellyBlockCoordinator,
ShellyEntryData,
ShellyRestCoordinator, ShellyRestCoordinator,
ShellyRpcCoordinator, ShellyRpcCoordinator,
ShellyRpcPollingCoordinator, ShellyRpcPollingCoordinator,
get_entry_data,
) )
from .utils import get_block_device_sleep_period, get_coap_context, get_device_entry_gen from .utils import get_block_device_sleep_period, get_coap_context, get_device_entry_gen
@ -101,16 +98,15 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
) )
return False return False
hass.data[DOMAIN][DATA_CONFIG_ENTRY][entry.entry_id] = {} get_entry_data(hass)[entry.entry_id] = ShellyEntryData()
hass.data[DOMAIN][DATA_CONFIG_ENTRY][entry.entry_id][DEVICE] = None
if get_device_entry_gen(entry) == 2: if get_device_entry_gen(entry) == 2:
return await async_setup_rpc_entry(hass, entry) return await _async_setup_rpc_entry(hass, entry)
return await async_setup_block_entry(hass, entry) return await _async_setup_block_entry(hass, entry)
async def async_setup_block_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: async def _async_setup_block_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
"""Set up Shelly block based device from a config entry.""" """Set up Shelly block based device from a config entry."""
temperature_unit = "C" if hass.config.units.is_metric else "F" temperature_unit = "C" if hass.config.units.is_metric else "F"
@ -146,11 +142,26 @@ async def async_setup_block_entry(hass: HomeAssistant, entry: ConfigEntry) -> bo
device_entry = None device_entry = None
sleep_period = entry.data.get(CONF_SLEEP_PERIOD) sleep_period = entry.data.get(CONF_SLEEP_PERIOD)
shelly_entry_data = get_entry_data(hass)[entry.entry_id]
@callback
def _async_block_device_setup() -> None:
"""Set up a block based device that is online."""
shelly_entry_data.block = ShellyBlockCoordinator(hass, entry, device)
shelly_entry_data.block.async_setup()
platforms = BLOCK_SLEEPING_PLATFORMS
if not entry.data.get(CONF_SLEEP_PERIOD):
shelly_entry_data.rest = ShellyRestCoordinator(hass, device, entry)
platforms = BLOCK_PLATFORMS
hass.config_entries.async_setup_platforms(entry, platforms)
@callback @callback
def _async_device_online(_: Any) -> None: def _async_device_online(_: Any) -> None:
LOGGER.debug("Device %s is online, resuming setup", entry.title) LOGGER.debug("Device %s is online, resuming setup", entry.title)
hass.data[DOMAIN][DATA_CONFIG_ENTRY][entry.entry_id][DEVICE] = None shelly_entry_data.device = None
if sleep_period is None: if sleep_period is None:
data = {**entry.data} data = {**entry.data}
@ -158,7 +169,7 @@ async def async_setup_block_entry(hass: HomeAssistant, entry: ConfigEntry) -> bo
data["model"] = device.settings["device"]["type"] data["model"] = device.settings["device"]["type"]
hass.config_entries.async_update_entry(entry, data=data) hass.config_entries.async_update_entry(entry, data=data)
async_block_device_setup(hass, entry, device) _async_block_device_setup()
if sleep_period == 0: if sleep_period == 0:
# Not a sleeping device, finish setup # Not a sleeping device, finish setup
@ -179,10 +190,10 @@ async def async_setup_block_entry(hass: HomeAssistant, entry: ConfigEntry) -> bo
if err.status == HTTPStatus.UNAUTHORIZED: if err.status == HTTPStatus.UNAUTHORIZED:
raise ConfigEntryAuthFailed from err raise ConfigEntryAuthFailed from err
async_block_device_setup(hass, entry, device) _async_block_device_setup()
elif sleep_period is None or device_entry is None: elif sleep_period is None or device_entry is None:
# Need to get sleep info or first time sleeping device setup, wait for device # Need to get sleep info or first time sleeping device setup, wait for device
hass.data[DOMAIN][DATA_CONFIG_ENTRY][entry.entry_id][DEVICE] = device shelly_entry_data.device = device
LOGGER.debug( LOGGER.debug(
"Setup for device %s will resume when device is online", entry.title "Setup for device %s will resume when device is online", entry.title
) )
@ -190,33 +201,12 @@ async def async_setup_block_entry(hass: HomeAssistant, entry: ConfigEntry) -> bo
else: else:
# Restore sensors for sleeping device # Restore sensors for sleeping device
LOGGER.debug("Setting up offline block device %s", entry.title) LOGGER.debug("Setting up offline block device %s", entry.title)
async_block_device_setup(hass, entry, device) _async_block_device_setup()
return True return True
@callback async def _async_setup_rpc_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
def async_block_device_setup(
hass: HomeAssistant, entry: ConfigEntry, device: BlockDevice
) -> None:
"""Set up a block based device that is online."""
block_coordinator = hass.data[DOMAIN][DATA_CONFIG_ENTRY][entry.entry_id][
BLOCK
] = ShellyBlockCoordinator(hass, entry, device)
block_coordinator.async_setup()
platforms = BLOCK_SLEEPING_PLATFORMS
if not entry.data.get(CONF_SLEEP_PERIOD):
hass.data[DOMAIN][DATA_CONFIG_ENTRY][entry.entry_id][
REST
] = ShellyRestCoordinator(hass, device, entry)
platforms = BLOCK_PLATFORMS
hass.config_entries.async_setup_platforms(entry, platforms)
async def async_setup_rpc_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
"""Set up Shelly RPC based device from a config entry.""" """Set up Shelly RPC based device from a config entry."""
options = aioshelly.common.ConnectionOptions( options = aioshelly.common.ConnectionOptions(
entry.data[CONF_HOST], entry.data[CONF_HOST],
@ -237,14 +227,11 @@ async def async_setup_rpc_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool
except (AuthRequired, InvalidAuthError) as err: except (AuthRequired, InvalidAuthError) as err:
raise ConfigEntryAuthFailed from err raise ConfigEntryAuthFailed from err
rpc_coordinator = hass.data[DOMAIN][DATA_CONFIG_ENTRY][entry.entry_id][ shelly_entry_data = get_entry_data(hass)[entry.entry_id]
RPC shelly_entry_data.rpc = ShellyRpcCoordinator(hass, entry, device)
] = ShellyRpcCoordinator(hass, entry, device) shelly_entry_data.rpc.async_setup()
rpc_coordinator.async_setup()
hass.data[DOMAIN][DATA_CONFIG_ENTRY][entry.entry_id][ shelly_entry_data.rpc_poll = ShellyRpcPollingCoordinator(hass, entry, device)
RPC_POLL
] = ShellyRpcPollingCoordinator(hass, entry, device)
hass.config_entries.async_setup_platforms(entry, RPC_PLATFORMS) hass.config_entries.async_setup_platforms(entry, RPC_PLATFORMS)
@ -253,73 +240,32 @@ async def async_setup_rpc_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool
async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
"""Unload a config entry.""" """Unload a config entry."""
shelly_entry_data = get_entry_data(hass)[entry.entry_id]
if get_device_entry_gen(entry) == 2: if get_device_entry_gen(entry) == 2:
unload_ok = await hass.config_entries.async_unload_platforms( if unload_ok := await hass.config_entries.async_unload_platforms(
entry, RPC_PLATFORMS entry, RPC_PLATFORMS
) ):
if unload_ok: if shelly_entry_data.rpc:
await hass.data[DOMAIN][DATA_CONFIG_ENTRY][entry.entry_id][RPC].shutdown() await shelly_entry_data.rpc.shutdown()
hass.data[DOMAIN][DATA_CONFIG_ENTRY].pop(entry.entry_id) get_entry_data(hass).pop(entry.entry_id)
return unload_ok return unload_ok
device = hass.data[DOMAIN][DATA_CONFIG_ENTRY][entry.entry_id].get(DEVICE) if shelly_entry_data.device is not None:
if device is not None:
# If device is present, block coordinator is not setup yet # If device is present, block coordinator is not setup yet
device.shutdown() shelly_entry_data.device.shutdown()
return True return True
platforms = BLOCK_SLEEPING_PLATFORMS platforms = BLOCK_SLEEPING_PLATFORMS
if not entry.data.get(CONF_SLEEP_PERIOD): if not entry.data.get(CONF_SLEEP_PERIOD):
hass.data[DOMAIN][DATA_CONFIG_ENTRY][entry.entry_id][REST] = None shelly_entry_data.rest = None
platforms = BLOCK_PLATFORMS platforms = BLOCK_PLATFORMS
unload_ok = await hass.config_entries.async_unload_platforms(entry, platforms) if unload_ok := await hass.config_entries.async_unload_platforms(entry, platforms):
if unload_ok: if shelly_entry_data.block:
hass.data[DOMAIN][DATA_CONFIG_ENTRY][entry.entry_id][BLOCK].shutdown() shelly_entry_data.block.shutdown()
hass.data[DOMAIN][DATA_CONFIG_ENTRY].pop(entry.entry_id) get_entry_data(hass).pop(entry.entry_id)
return unload_ok return unload_ok
def get_block_device_coordinator(
hass: HomeAssistant, device_id: str
) -> ShellyBlockCoordinator | None:
"""Get a Shelly block device coordinator for the given device id."""
if not hass.data.get(DOMAIN):
return None
dev_reg = device_registry.async_get(hass)
if device := dev_reg.async_get(device_id):
for config_entry in device.config_entries:
if not hass.data[DOMAIN][DATA_CONFIG_ENTRY].get(config_entry):
continue
if coordinator := hass.data[DOMAIN][DATA_CONFIG_ENTRY][config_entry].get(
BLOCK
):
return cast(ShellyBlockCoordinator, coordinator)
return None
def get_rpc_device_coordinator(
hass: HomeAssistant, device_id: str
) -> ShellyRpcCoordinator | None:
"""Get a Shelly RPC device coordinator for the given device id."""
if not hass.data.get(DOMAIN):
return None
dev_reg = device_registry.async_get(hass)
if device := dev_reg.async_get(device_id):
for config_entry in device.config_entries:
if not hass.data[DOMAIN][DATA_CONFIG_ENTRY].get(config_entry):
continue
if coordinator := hass.data[DOMAIN][DATA_CONFIG_ENTRY][config_entry].get(
RPC
):
return cast(ShellyRpcCoordinator, coordinator)
return None

View File

@ -3,7 +3,7 @@ from __future__ import annotations
from collections.abc import Callable from collections.abc import Callable
from dataclasses import dataclass from dataclasses import dataclass
from typing import Final, cast from typing import Final
from homeassistant.components.button import ( from homeassistant.components.button import (
ButtonDeviceClass, ButtonDeviceClass,
@ -18,8 +18,8 @@ from homeassistant.helpers.entity_platform import AddEntitiesCallback
from homeassistant.helpers.update_coordinator import CoordinatorEntity from homeassistant.helpers.update_coordinator import CoordinatorEntity
from homeassistant.util import slugify from homeassistant.util import slugify
from .const import BLOCK, DATA_CONFIG_ENTRY, DOMAIN, RPC, SHELLY_GAS_MODELS from .const import SHELLY_GAS_MODELS
from .coordinator import ShellyBlockCoordinator, ShellyRpcCoordinator from .coordinator import ShellyBlockCoordinator, ShellyRpcCoordinator, get_entry_data
from .utils import get_block_device_name, get_device_entry_gen, get_rpc_device_name from .utils import get_block_device_name, get_device_entry_gen, get_rpc_device_name
@ -80,15 +80,9 @@ async def async_setup_entry(
"""Set buttons for device.""" """Set buttons for device."""
coordinator: ShellyRpcCoordinator | ShellyBlockCoordinator | None = None coordinator: ShellyRpcCoordinator | ShellyBlockCoordinator | None = None
if get_device_entry_gen(config_entry) == 2: if get_device_entry_gen(config_entry) == 2:
if rpc_coordinator := hass.data[DOMAIN][DATA_CONFIG_ENTRY][ coordinator = get_entry_data(hass)[config_entry.entry_id].rpc
config_entry.entry_id
].get(RPC):
coordinator = cast(ShellyRpcCoordinator, rpc_coordinator)
else: else:
if block_coordinator := hass.data[DOMAIN][DATA_CONFIG_ENTRY][ coordinator = get_entry_data(hass)[config_entry.entry_id].block
config_entry.entry_id
].get(BLOCK):
coordinator = cast(ShellyBlockCoordinator, block_coordinator)
if coordinator is not None: if coordinator is not None:
entities = [] entities = []

View File

@ -26,15 +26,8 @@ from homeassistant.helpers.entity_platform import AddEntitiesCallback
from homeassistant.helpers.restore_state import RestoreEntity from homeassistant.helpers.restore_state import RestoreEntity
from homeassistant.helpers.update_coordinator import CoordinatorEntity from homeassistant.helpers.update_coordinator import CoordinatorEntity
from .const import ( from .const import AIOSHELLY_DEVICE_TIMEOUT_SEC, LOGGER, SHTRV_01_TEMPERATURE_SETTINGS
AIOSHELLY_DEVICE_TIMEOUT_SEC, from .coordinator import ShellyBlockCoordinator, get_entry_data
BLOCK,
DATA_CONFIG_ENTRY,
DOMAIN,
LOGGER,
SHTRV_01_TEMPERATURE_SETTINGS,
)
from .coordinator import ShellyBlockCoordinator
from .utils import get_device_entry_gen from .utils import get_device_entry_gen
@ -48,10 +41,8 @@ async def async_setup_entry(
if get_device_entry_gen(config_entry) == 2: if get_device_entry_gen(config_entry) == 2:
return return
coordinator: ShellyBlockCoordinator = hass.data[DOMAIN][DATA_CONFIG_ENTRY][ coordinator = get_entry_data(hass)[config_entry.entry_id].block
config_entry.entry_id assert coordinator
][BLOCK]
if coordinator.device.initialized: if coordinator.device.initialized:
async_setup_climate_entities(async_add_entities, coordinator) async_setup_climate_entities(async_add_entities, coordinator)
else: else:

View File

@ -9,13 +9,7 @@ DOMAIN: Final = "shelly"
LOGGER: Logger = getLogger(__package__) LOGGER: Logger = getLogger(__package__)
BLOCK: Final = "block"
DATA_CONFIG_ENTRY: Final = "config_entry" DATA_CONFIG_ENTRY: Final = "config_entry"
DEVICE: Final = "device"
REST: Final = "rest"
RPC: Final = "rpc"
RPC_POLL: Final = "rpc_poll"
CONF_COAP_PORT: Final = "coap_port" CONF_COAP_PORT: Final = "coap_port"
DEFAULT_COAP_PORT: Final = 5683 DEFAULT_COAP_PORT: Final = 5683
FIRMWARE_PATTERN: Final = re.compile(r"^(\d{8})") FIRMWARE_PATTERN: Final = re.compile(r"^(\d{8})")

View File

@ -3,6 +3,7 @@ from __future__ import annotations
import asyncio import asyncio
from collections.abc import Coroutine from collections.abc import Coroutine
from dataclasses import dataclass
from datetime import timedelta from datetime import timedelta
from typing import Any, cast from typing import Any, cast
@ -27,6 +28,8 @@ from .const import (
ATTR_GENERATION, ATTR_GENERATION,
BATTERY_DEVICES_WITH_PERMANENT_CONNECTION, BATTERY_DEVICES_WITH_PERMANENT_CONNECTION,
CONF_SLEEP_PERIOD, CONF_SLEEP_PERIOD,
DATA_CONFIG_ENTRY,
DOMAIN,
DUAL_MODE_LIGHT_MODELS, DUAL_MODE_LIGHT_MODELS,
ENTRY_RELOAD_COOLDOWN, ENTRY_RELOAD_COOLDOWN,
EVENT_SHELLY_CLICK, EVENT_SHELLY_CLICK,
@ -45,6 +48,22 @@ from .const import (
from .utils import device_update_info, get_block_device_name, get_rpc_device_name from .utils import device_update_info, get_block_device_name, get_rpc_device_name
@dataclass
class ShellyEntryData:
"""Class for sharing data within a given config entry."""
block: ShellyBlockCoordinator | None = None
device: BlockDevice | None = None
rest: ShellyRestCoordinator | None = None
rpc: ShellyRpcCoordinator | None = None
rpc_poll: ShellyRpcPollingCoordinator | None = None
def get_entry_data(hass: HomeAssistant) -> dict[str, ShellyEntryData]:
"""Return Shelly entry data for a given config entry."""
return cast(dict[str, ShellyEntryData], hass.data[DOMAIN][DATA_CONFIG_ENTRY])
class ShellyBlockCoordinator(DataUpdateCoordinator): class ShellyBlockCoordinator(DataUpdateCoordinator):
"""Coordinator for a Shelly block based device.""" """Coordinator for a Shelly block based device."""
@ -532,3 +551,41 @@ class ShellyRpcPollingCoordinator(DataUpdateCoordinator):
def mac(self) -> str: def mac(self) -> str:
"""Mac address of the device.""" """Mac address of the device."""
return cast(str, self.entry.unique_id) return cast(str, self.entry.unique_id)
def get_block_coordinator_by_device_id(
hass: HomeAssistant, device_id: str
) -> ShellyBlockCoordinator | None:
"""Get a Shelly block device coordinator for the given device id."""
if not hass.data.get(DOMAIN):
return None
dev_reg = device_registry.async_get(hass)
if device := dev_reg.async_get(device_id):
for config_entry in device.config_entries:
if not (entry_data := get_entry_data(hass).get(config_entry)):
continue
if coordinator := entry_data.block:
return coordinator
return None
def get_rpc_coordinator_by_device_id(
hass: HomeAssistant, device_id: str
) -> ShellyRpcCoordinator | None:
"""Get a Shelly RPC device coordinator for the given device id."""
if not hass.data.get(DOMAIN):
return None
dev_reg = device_registry.async_get(hass)
if device := dev_reg.async_get(device_id):
for config_entry in device.config_entries:
if not (entry_data := get_entry_data(hass).get(config_entry)):
continue
if coordinator := entry_data.rpc:
return coordinator
return None

View File

@ -15,8 +15,7 @@ from homeassistant.config_entries import ConfigEntry
from homeassistant.core import HomeAssistant, callback from homeassistant.core import HomeAssistant, callback
from homeassistant.helpers.entity_platform import AddEntitiesCallback from homeassistant.helpers.entity_platform import AddEntitiesCallback
from .const import BLOCK, DATA_CONFIG_ENTRY, DOMAIN, RPC from .coordinator import ShellyBlockCoordinator, ShellyRpcCoordinator, get_entry_data
from .coordinator import ShellyBlockCoordinator, ShellyRpcCoordinator
from .entity import ShellyBlockEntity, ShellyRpcEntity from .entity import ShellyBlockEntity, ShellyRpcEntity
from .utils import get_device_entry_gen, get_rpc_key_ids from .utils import get_device_entry_gen, get_rpc_key_ids
@ -40,7 +39,8 @@ def async_setup_block_entry(
async_add_entities: AddEntitiesCallback, async_add_entities: AddEntitiesCallback,
) -> None: ) -> None:
"""Set up cover for device.""" """Set up cover for device."""
coordinator = hass.data[DOMAIN][DATA_CONFIG_ENTRY][config_entry.entry_id][BLOCK] coordinator = get_entry_data(hass)[config_entry.entry_id].block
assert coordinator and coordinator.device.blocks
blocks = [block for block in coordinator.device.blocks if block.type == "roller"] blocks = [block for block in coordinator.device.blocks if block.type == "roller"]
if not blocks: if not blocks:
@ -56,8 +56,8 @@ def async_setup_rpc_entry(
async_add_entities: AddEntitiesCallback, async_add_entities: AddEntitiesCallback,
) -> None: ) -> None:
"""Set up entities for RPC device.""" """Set up entities for RPC device."""
coordinator = hass.data[DOMAIN][DATA_CONFIG_ENTRY][config_entry.entry_id][RPC] coordinator = get_entry_data(hass)[config_entry.entry_id].rpc
assert coordinator
cover_key_ids = get_rpc_key_ids(coordinator.device.status, "cover") cover_key_ids = get_rpc_key_ids(coordinator.device.status, "cover")
if not cover_key_ids: if not cover_key_ids:

View File

@ -22,7 +22,6 @@ from homeassistant.core import CALLBACK_TYPE, HomeAssistant
from homeassistant.helpers.trigger import TriggerActionType, TriggerInfo from homeassistant.helpers.trigger import TriggerActionType, TriggerInfo
from homeassistant.helpers.typing import ConfigType from homeassistant.helpers.typing import ConfigType
from . import get_block_device_coordinator, get_rpc_device_coordinator
from .const import ( from .const import (
ATTR_CHANNEL, ATTR_CHANNEL,
ATTR_CLICK_TYPE, ATTR_CLICK_TYPE,
@ -34,6 +33,10 @@ from .const import (
RPC_INPUTS_EVENTS_TYPES, RPC_INPUTS_EVENTS_TYPES,
SHBTN_MODELS, SHBTN_MODELS,
) )
from .coordinator import (
get_block_coordinator_by_device_id,
get_rpc_coordinator_by_device_id,
)
from .utils import ( from .utils import (
get_block_input_triggers, get_block_input_triggers,
get_rpc_input_triggers, get_rpc_input_triggers,
@ -78,7 +81,7 @@ async def async_validate_trigger_config(
trigger = (config[CONF_TYPE], config[CONF_SUBTYPE]) trigger = (config[CONF_TYPE], config[CONF_SUBTYPE])
if config[CONF_TYPE] in RPC_INPUTS_EVENTS_TYPES: if config[CONF_TYPE] in RPC_INPUTS_EVENTS_TYPES:
rpc_coordinator = get_rpc_device_coordinator(hass, config[CONF_DEVICE_ID]) rpc_coordinator = get_rpc_coordinator_by_device_id(hass, config[CONF_DEVICE_ID])
if not rpc_coordinator or not rpc_coordinator.device.initialized: if not rpc_coordinator or not rpc_coordinator.device.initialized:
return config return config
@ -87,7 +90,9 @@ async def async_validate_trigger_config(
return config return config
elif config[CONF_TYPE] in BLOCK_INPUTS_EVENTS_TYPES: elif config[CONF_TYPE] in BLOCK_INPUTS_EVENTS_TYPES:
block_coordinator = get_block_device_coordinator(hass, config[CONF_DEVICE_ID]) block_coordinator = get_block_coordinator_by_device_id(
hass, config[CONF_DEVICE_ID]
)
if not block_coordinator or not block_coordinator.device.initialized: if not block_coordinator or not block_coordinator.device.initialized:
return config return config
@ -109,12 +114,12 @@ async def async_get_triggers(
"""List device triggers for Shelly devices.""" """List device triggers for Shelly devices."""
triggers: list[dict[str, str]] = [] triggers: list[dict[str, str]] = []
if rpc_coordinator := get_rpc_device_coordinator(hass, device_id): if rpc_coordinator := get_rpc_coordinator_by_device_id(hass, device_id):
input_triggers = get_rpc_input_triggers(rpc_coordinator.device) input_triggers = get_rpc_input_triggers(rpc_coordinator.device)
append_input_triggers(triggers, input_triggers, device_id) append_input_triggers(triggers, input_triggers, device_id)
return triggers return triggers
if block_coordinator := get_block_device_coordinator(hass, device_id): if block_coordinator := get_block_coordinator_by_device_id(hass, device_id):
if block_coordinator.model in SHBTN_MODELS: if block_coordinator.model in SHBTN_MODELS:
input_triggers = get_shbtn_input_triggers() input_triggers = get_shbtn_input_triggers()
append_input_triggers(triggers, input_triggers, device_id) append_input_triggers(triggers, input_triggers, device_id)

View File

@ -6,8 +6,7 @@ from homeassistant.config_entries import ConfigEntry
from homeassistant.const import CONF_PASSWORD, CONF_USERNAME from homeassistant.const import CONF_PASSWORD, CONF_USERNAME
from homeassistant.core import HomeAssistant from homeassistant.core import HomeAssistant
from .const import BLOCK, DATA_CONFIG_ENTRY, DOMAIN, RPC from .coordinator import get_entry_data
from .coordinator import ShellyBlockCoordinator, ShellyRpcCoordinator
TO_REDACT = {CONF_USERNAME, CONF_PASSWORD} TO_REDACT = {CONF_USERNAME, CONF_PASSWORD}
@ -16,12 +15,13 @@ async def async_get_config_entry_diagnostics(
hass: HomeAssistant, entry: ConfigEntry hass: HomeAssistant, entry: ConfigEntry
) -> dict: ) -> dict:
"""Return diagnostics for a config entry.""" """Return diagnostics for a config entry."""
data: dict = hass.data[DOMAIN][DATA_CONFIG_ENTRY][entry.entry_id] shelly_entry_data = get_entry_data(hass)[entry.entry_id]
device_settings: str | dict = "not initialized" device_settings: str | dict = "not initialized"
device_status: str | dict = "not initialized" device_status: str | dict = "not initialized"
if BLOCK in data: if shelly_entry_data.block:
block_coordinator: ShellyBlockCoordinator = data[BLOCK] block_coordinator = shelly_entry_data.block
assert block_coordinator
device_info = { device_info = {
"name": block_coordinator.name, "name": block_coordinator.name,
"model": block_coordinator.model, "model": block_coordinator.model,
@ -51,7 +51,8 @@ async def async_get_config_entry_diagnostics(
] ]
} }
else: else:
rpc_coordinator: ShellyRpcCoordinator = data[RPC] rpc_coordinator = shelly_entry_data.rpc
assert rpc_coordinator
device_info = { device_info = {
"name": rpc_coordinator.name, "name": rpc_coordinator.name,
"model": rpc_coordinator.model, "model": rpc_coordinator.model,

View File

@ -18,21 +18,12 @@ from homeassistant.helpers.restore_state import RestoreEntity
from homeassistant.helpers.typing import StateType from homeassistant.helpers.typing import StateType
from homeassistant.helpers.update_coordinator import CoordinatorEntity from homeassistant.helpers.update_coordinator import CoordinatorEntity
from .const import ( from .const import AIOSHELLY_DEVICE_TIMEOUT_SEC, LOGGER
AIOSHELLY_DEVICE_TIMEOUT_SEC,
BLOCK,
DATA_CONFIG_ENTRY,
DOMAIN,
LOGGER,
REST,
RPC,
RPC_POLL,
)
from .coordinator import ( from .coordinator import (
ShellyBlockCoordinator, ShellyBlockCoordinator,
ShellyRestCoordinator,
ShellyRpcCoordinator, ShellyRpcCoordinator,
ShellyRpcPollingCoordinator, ShellyRpcPollingCoordinator,
get_entry_data,
) )
from .utils import ( from .utils import (
async_remove_shelly_entity, async_remove_shelly_entity,
@ -54,10 +45,8 @@ def async_setup_entry_attribute_entities(
], ],
) -> None: ) -> None:
"""Set up entities for attributes.""" """Set up entities for attributes."""
coordinator: ShellyBlockCoordinator = hass.data[DOMAIN][DATA_CONFIG_ENTRY][ coordinator = get_entry_data(hass)[config_entry.entry_id].block
config_entry.entry_id assert coordinator
][BLOCK]
if coordinator.device.initialized: if coordinator.device.initialized:
async_setup_block_attribute_entities( async_setup_block_attribute_entities(
hass, async_add_entities, coordinator, sensors, sensor_class hass, async_add_entities, coordinator, sensors, sensor_class
@ -166,13 +155,10 @@ def async_setup_entry_rpc(
sensor_class: Callable, sensor_class: Callable,
) -> None: ) -> None:
"""Set up entities for REST sensors.""" """Set up entities for REST sensors."""
coordinator: ShellyRpcCoordinator = hass.data[DOMAIN][DATA_CONFIG_ENTRY][ coordinator = get_entry_data(hass)[config_entry.entry_id].rpc
config_entry.entry_id assert coordinator
][RPC] polling_coordinator = get_entry_data(hass)[config_entry.entry_id].rpc_poll
assert polling_coordinator
polling_coordinator: ShellyRpcPollingCoordinator = hass.data[DOMAIN][
DATA_CONFIG_ENTRY
][config_entry.entry_id][RPC_POLL]
entities = [] entities = []
for sensor_id in sensors: for sensor_id in sensors:
@ -220,10 +206,8 @@ def async_setup_entry_rest(
sensor_class: Callable, sensor_class: Callable,
) -> None: ) -> None:
"""Set up entities for REST sensors.""" """Set up entities for REST sensors."""
coordinator: ShellyRestCoordinator = hass.data[DOMAIN][DATA_CONFIG_ENTRY][ coordinator = get_entry_data(hass)[config_entry.entry_id].rest
config_entry.entry_id assert coordinator
][REST]
entities = [] entities = []
for sensor_id in sensors: for sensor_id in sensors:
description = sensors.get(sensor_id) description = sensors.get(sensor_id)

View File

@ -26,9 +26,6 @@ from homeassistant.util.color import (
) )
from .const import ( from .const import (
BLOCK,
DATA_CONFIG_ENTRY,
DOMAIN,
DUAL_MODE_LIGHT_MODELS, DUAL_MODE_LIGHT_MODELS,
FIRMWARE_PATTERN, FIRMWARE_PATTERN,
KELVIN_MAX_VALUE, KELVIN_MAX_VALUE,
@ -39,11 +36,10 @@ from .const import (
MAX_TRANSITION_TIME, MAX_TRANSITION_TIME,
MODELS_SUPPORTING_LIGHT_TRANSITION, MODELS_SUPPORTING_LIGHT_TRANSITION,
RGBW_MODELS, RGBW_MODELS,
RPC,
SHBLB_1_RGB_EFFECTS, SHBLB_1_RGB_EFFECTS,
STANDARD_RGB_EFFECTS, STANDARD_RGB_EFFECTS,
) )
from .coordinator import ShellyBlockCoordinator, ShellyRpcCoordinator from .coordinator import ShellyBlockCoordinator, ShellyRpcCoordinator, get_entry_data
from .entity import ShellyBlockEntity, ShellyRpcEntity from .entity import ShellyBlockEntity, ShellyRpcEntity
from .utils import ( from .utils import (
async_remove_shelly_entity, async_remove_shelly_entity,
@ -77,8 +73,8 @@ def async_setup_block_entry(
async_add_entities: AddEntitiesCallback, async_add_entities: AddEntitiesCallback,
) -> None: ) -> None:
"""Set up entities for block device.""" """Set up entities for block device."""
coordinator = hass.data[DOMAIN][DATA_CONFIG_ENTRY][config_entry.entry_id][BLOCK] coordinator = get_entry_data(hass)[config_entry.entry_id].block
assert coordinator
blocks = [] blocks = []
assert coordinator.device.blocks assert coordinator.device.blocks
for block in coordinator.device.blocks: for block in coordinator.device.blocks:
@ -108,7 +104,8 @@ def async_setup_rpc_entry(
async_add_entities: AddEntitiesCallback, async_add_entities: AddEntitiesCallback,
) -> None: ) -> None:
"""Set up entities for RPC device.""" """Set up entities for RPC device."""
coordinator = hass.data[DOMAIN][DATA_CONFIG_ENTRY][config_entry.entry_id][RPC] coordinator = get_entry_data(hass)[config_entry.entry_id].rpc
assert coordinator
switch_key_ids = get_rpc_key_ids(coordinator.device.status, "switch") switch_key_ids = get_rpc_key_ids(coordinator.device.status, "switch")
switch_ids = [] switch_ids = []

View File

@ -8,7 +8,6 @@ from homeassistant.const import ATTR_DEVICE_ID
from homeassistant.core import HomeAssistant, callback from homeassistant.core import HomeAssistant, callback
from homeassistant.helpers.typing import EventType from homeassistant.helpers.typing import EventType
from . import get_block_device_coordinator, get_rpc_device_coordinator
from .const import ( from .const import (
ATTR_CHANNEL, ATTR_CHANNEL,
ATTR_CLICK_TYPE, ATTR_CLICK_TYPE,
@ -18,6 +17,10 @@ from .const import (
EVENT_SHELLY_CLICK, EVENT_SHELLY_CLICK,
RPC_INPUTS_EVENTS_TYPES, RPC_INPUTS_EVENTS_TYPES,
) )
from .coordinator import (
get_block_coordinator_by_device_id,
get_rpc_coordinator_by_device_id,
)
from .utils import get_block_device_name, get_rpc_entity_name from .utils import get_block_device_name, get_rpc_entity_name
@ -37,13 +40,13 @@ def async_describe_events(
input_name = f"{event.data[ATTR_DEVICE]} channel {channel}" input_name = f"{event.data[ATTR_DEVICE]} channel {channel}"
if click_type in RPC_INPUTS_EVENTS_TYPES: if click_type in RPC_INPUTS_EVENTS_TYPES:
rpc_coordinator = get_rpc_device_coordinator(hass, device_id) rpc_coordinator = get_rpc_coordinator_by_device_id(hass, device_id)
if rpc_coordinator and rpc_coordinator.device.initialized: if rpc_coordinator and rpc_coordinator.device.initialized:
key = f"input:{channel-1}" key = f"input:{channel-1}"
input_name = get_rpc_entity_name(rpc_coordinator.device, key) input_name = get_rpc_entity_name(rpc_coordinator.device, key)
elif click_type in BLOCK_INPUTS_EVENTS_TYPES: elif click_type in BLOCK_INPUTS_EVENTS_TYPES:
block_coordinator = get_block_device_coordinator(hass, device_id) block_coordinator = get_block_coordinator_by_device_id(hass, device_id)
if block_coordinator and block_coordinator.device.initialized: if block_coordinator and block_coordinator.device.initialized:
device_name = get_block_device_name(block_coordinator.device) device_name = get_block_device_name(block_coordinator.device)
input_name = f"{device_name} channel {channel}" input_name = f"{device_name} channel {channel}"

View File

@ -10,8 +10,7 @@ from homeassistant.config_entries import ConfigEntry
from homeassistant.core import HomeAssistant, callback from homeassistant.core import HomeAssistant, callback
from homeassistant.helpers.entity_platform import AddEntitiesCallback from homeassistant.helpers.entity_platform import AddEntitiesCallback
from .const import BLOCK, DATA_CONFIG_ENTRY, DOMAIN, RPC from .coordinator import ShellyBlockCoordinator, ShellyRpcCoordinator, get_entry_data
from .coordinator import ShellyBlockCoordinator, ShellyRpcCoordinator
from .entity import ShellyBlockEntity, ShellyRpcEntity from .entity import ShellyBlockEntity, ShellyRpcEntity
from .utils import ( from .utils import (
async_remove_shelly_entity, async_remove_shelly_entity,
@ -41,7 +40,8 @@ def async_setup_block_entry(
async_add_entities: AddEntitiesCallback, async_add_entities: AddEntitiesCallback,
) -> None: ) -> None:
"""Set up entities for block device.""" """Set up entities for block device."""
coordinator = hass.data[DOMAIN][DATA_CONFIG_ENTRY][config_entry.entry_id][BLOCK] coordinator = get_entry_data(hass)[config_entry.entry_id].block
assert coordinator
# In roller mode the relay blocks exist but do not contain required info # In roller mode the relay blocks exist but do not contain required info
if ( if (
@ -75,8 +75,8 @@ def async_setup_rpc_entry(
async_add_entities: AddEntitiesCallback, async_add_entities: AddEntitiesCallback,
) -> None: ) -> None:
"""Set up entities for RPC device.""" """Set up entities for RPC device."""
coordinator = hass.data[DOMAIN][DATA_CONFIG_ENTRY][config_entry.entry_id][RPC] coordinator = get_entry_data(hass)[config_entry.entry_id].rpc
assert coordinator
switch_key_ids = get_rpc_key_ids(coordinator.device.status, "switch") switch_key_ids = get_rpc_key_ids(coordinator.device.status, "switch")
switch_ids = [] switch_ids = []

View File

@ -17,8 +17,8 @@ from homeassistant.core import HomeAssistant
from homeassistant.helpers.entity import EntityCategory from homeassistant.helpers.entity import EntityCategory
from homeassistant.helpers.entity_platform import AddEntitiesCallback from homeassistant.helpers.entity_platform import AddEntitiesCallback
from .const import BLOCK, CONF_SLEEP_PERIOD, DATA_CONFIG_ENTRY, DOMAIN from .const import CONF_SLEEP_PERIOD
from .coordinator import ShellyBlockCoordinator, ShellyRpcCoordinator from .coordinator import ShellyBlockCoordinator, ShellyRpcCoordinator, get_entry_data
from .entity import ( from .entity import (
RestEntityDescription, RestEntityDescription,
RpcEntityDescription, RpcEntityDescription,
@ -178,11 +178,9 @@ class RestUpdateEntity(ShellyRestAttributeEntity, UpdateEntity):
) -> None: ) -> None:
"""Install the latest firmware version.""" """Install the latest firmware version."""
config_entry = self.block_coordinator.entry config_entry = self.block_coordinator.entry
block_coordinator = self.hass.data[DOMAIN][DATA_CONFIG_ENTRY][ coordinator = get_entry_data(self.hass)[config_entry.entry_id].block
config_entry.entry_id
].get(BLOCK)
self._in_progress_old_version = self.installed_version self._in_progress_old_version = self.installed_version
await self.entity_description.install(block_coordinator) await self.entity_description.install(coordinator)
class RpcUpdateEntity(ShellyRpcAttributeEntity, UpdateEntity): class RpcUpdateEntity(ShellyRpcAttributeEntity, UpdateEntity):