Migrate Airgradient select entities to be config source dependent (#120462)

Co-authored-by: Robert Resch <robert@resch.dev>
This commit is contained in:
Joost Lekkerkerker 2024-06-25 22:27:52 +02:00 committed by GitHub
parent 4290a1fcb5
commit 1f0e47b251
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 84 additions and 69 deletions

View File

@ -6,10 +6,14 @@ from dataclasses import dataclass
from airgradient import AirGradientClient, Config
from airgradient.models import ConfigurationControl, LedBarMode, TemperatureUnit
from homeassistant.components.select import SelectEntity, SelectEntityDescription
from homeassistant.components.select import (
DOMAIN as SELECT_DOMAIN,
SelectEntity,
SelectEntityDescription,
)
from homeassistant.const import EntityCategory
from homeassistant.core import HomeAssistant
from homeassistant.exceptions import ServiceValidationError
from homeassistant.core import HomeAssistant, callback
from homeassistant.helpers import entity_registry as er
from homeassistant.helpers.entity_platform import AddEntitiesCallback
from . import AirGradientConfigEntry
@ -24,8 +28,6 @@ class AirGradientSelectEntityDescription(SelectEntityDescription):
value_fn: Callable[[Config], str | None]
set_value_fn: Callable[[AirGradientClient, str], Awaitable[None]]
requires_display: bool = False
requires_led_bar: bool = False
CONFIG_CONTROL_ENTITY = AirGradientSelectEntityDescription(
@ -43,7 +45,7 @@ CONFIG_CONTROL_ENTITY = AirGradientSelectEntityDescription(
),
)
PROTECTED_SELECT_TYPES: tuple[AirGradientSelectEntityDescription, ...] = (
DISPLAY_SELECT_TYPES: tuple[AirGradientSelectEntityDescription, ...] = (
AirGradientSelectEntityDescription(
key="display_temperature_unit",
translation_key="display_temperature_unit",
@ -53,7 +55,6 @@ PROTECTED_SELECT_TYPES: tuple[AirGradientSelectEntityDescription, ...] = (
set_value_fn=lambda client, value: client.set_temperature_unit(
TemperatureUnit(value)
),
requires_display=True,
),
AirGradientSelectEntityDescription(
key="display_pm_standard",
@ -64,8 +65,10 @@ PROTECTED_SELECT_TYPES: tuple[AirGradientSelectEntityDescription, ...] = (
set_value_fn=lambda client, value: client.set_pm_standard(
PM_STANDARD_REVERSE[value]
),
requires_display=True,
),
)
LED_BAR_ENTITIES: tuple[AirGradientSelectEntityDescription, ...] = (
AirGradientSelectEntityDescription(
key="led_bar_mode",
translation_key="led_bar_mode",
@ -73,7 +76,6 @@ PROTECTED_SELECT_TYPES: tuple[AirGradientSelectEntityDescription, ...] = (
entity_category=EntityCategory.CONFIG,
value_fn=lambda config: config.led_bar_mode,
set_value_fn=lambda client, value: client.set_led_bar_mode(LedBarMode(value)),
requires_led_bar=True,
),
)
@ -85,22 +87,52 @@ async def async_setup_entry(
) -> None:
"""Set up AirGradient select entities based on a config entry."""
config_coordinator = entry.runtime_data.config
coordinator = entry.runtime_data.config
measurement_coordinator = entry.runtime_data.measurement
entities = [AirGradientSelect(config_coordinator, CONFIG_CONTROL_ENTITY)]
async_add_entities([AirGradientSelect(coordinator, CONFIG_CONTROL_ENTITY)])
model = measurement_coordinator.data.model
added_entities = False
@callback
def _async_check_entities() -> None:
nonlocal added_entities
entities.extend(
AirGradientProtectedSelect(config_coordinator, description)
for description in PROTECTED_SELECT_TYPES
if (
description.requires_display
and measurement_coordinator.data.model.startswith("I")
)
or (description.requires_led_bar and "L" in measurement_coordinator.data.model)
)
coordinator.data.configuration_control is ConfigurationControl.LOCAL
and not added_entities
):
entities: list[AirGradientSelect] = []
if "I" in model:
entities.extend(
AirGradientSelect(coordinator, description)
for description in DISPLAY_SELECT_TYPES
)
if "L" in model:
entities.extend(
AirGradientSelect(coordinator, description)
for description in LED_BAR_ENTITIES
)
async_add_entities(entities)
async_add_entities(entities)
added_entities = True
elif (
coordinator.data.configuration_control is not ConfigurationControl.LOCAL
and added_entities
):
entity_registry = er.async_get(hass)
for entity_description in DISPLAY_SELECT_TYPES + LED_BAR_ENTITIES:
unique_id = f"{coordinator.serial_number}-{entity_description.key}"
if entity_id := entity_registry.async_get_entity_id(
SELECT_DOMAIN, DOMAIN, unique_id
):
entity_registry.async_remove(entity_id)
added_entities = False
coordinator.async_add_listener(_async_check_entities)
_async_check_entities()
class AirGradientSelect(AirGradientEntity, SelectEntity):
@ -128,19 +160,3 @@ class AirGradientSelect(AirGradientEntity, SelectEntity):
"""Change the selected option."""
await self.entity_description.set_value_fn(self.coordinator.client, option)
await self.coordinator.async_request_refresh()
class AirGradientProtectedSelect(AirGradientSelect):
"""Defines a protected AirGradient select entity."""
async def async_select_option(self, option: str) -> None:
"""Change the selected option."""
if (
self.coordinator.data.configuration_control
is not ConfigurationControl.LOCAL
):
raise ServiceValidationError(
translation_domain=DOMAIN,
translation_key="no_local_configuration",
)
await super().async_select_option(option)

View File

@ -125,10 +125,5 @@
"name": "[%key:component::airgradient::entity::number::display_brightness::name%]"
}
}
},
"exceptions": {
"no_local_configuration": {
"message": "Device should be configured with local configuration to be able to change settings."
}
}
}

View File

@ -1,23 +1,30 @@
"""Tests for the AirGradient select platform."""
from datetime import timedelta
from unittest.mock import AsyncMock, patch
from airgradient import ConfigurationControl
from airgradient import Config
from freezegun.api import FrozenDateTimeFactory
import pytest
from syrupy import SnapshotAssertion
from homeassistant.components.airgradient import DOMAIN
from homeassistant.components.select import (
DOMAIN as SELECT_DOMAIN,
SERVICE_SELECT_OPTION,
)
from homeassistant.const import ATTR_ENTITY_ID, ATTR_OPTION, Platform
from homeassistant.core import HomeAssistant
from homeassistant.exceptions import ServiceValidationError
from homeassistant.helpers import entity_registry as er
from . import setup_integration
from tests.common import MockConfigEntry, snapshot_platform
from tests.common import (
MockConfigEntry,
async_fire_time_changed,
load_fixture,
snapshot_platform,
)
@pytest.mark.usefixtures("entity_registry_enabled_by_default")
@ -56,37 +63,34 @@ async def test_setting_value(
assert mock_airgradient_client.get_config.call_count == 2
async def test_setting_protected_value(
async def test_cloud_creates_no_number(
hass: HomeAssistant,
mock_cloud_airgradient_client: AsyncMock,
mock_config_entry: MockConfigEntry,
freezer: FrozenDateTimeFactory,
) -> None:
"""Test setting protected value."""
await setup_integration(hass, mock_config_entry)
"""Test cloud configuration control."""
with patch("homeassistant.components.airgradient.PLATFORMS", [Platform.SELECT]):
await setup_integration(hass, mock_config_entry)
with pytest.raises(ServiceValidationError):
await hass.services.async_call(
SELECT_DOMAIN,
SERVICE_SELECT_OPTION,
{
ATTR_ENTITY_ID: "select.airgradient_display_temperature_unit",
ATTR_OPTION: "c",
},
blocking=True,
)
mock_cloud_airgradient_client.set_temperature_unit.assert_not_called()
assert len(hass.states.async_all()) == 1
mock_cloud_airgradient_client.get_config.return_value.configuration_control = (
ConfigurationControl.LOCAL
mock_cloud_airgradient_client.get_config.return_value = Config.from_json(
load_fixture("get_config_local.json", DOMAIN)
)
await hass.services.async_call(
SELECT_DOMAIN,
SERVICE_SELECT_OPTION,
{
ATTR_ENTITY_ID: "select.airgradient_display_temperature_unit",
ATTR_OPTION: "c",
},
blocking=True,
freezer.tick(timedelta(minutes=5))
async_fire_time_changed(hass)
await hass.async_block_till_done()
assert len(hass.states.async_all()) == 4
mock_cloud_airgradient_client.get_config.return_value = Config.from_json(
load_fixture("get_config_cloud.json", DOMAIN)
)
mock_cloud_airgradient_client.set_temperature_unit.assert_called_once_with("c")
freezer.tick(timedelta(minutes=5))
async_fire_time_changed(hass)
await hass.async_block_till_done()
assert len(hass.states.async_all()) == 1