diff --git a/homeassistant/components/tplink/siren.py b/homeassistant/components/tplink/siren.py index d1ce03c1469..027fa2dd58f 100644 --- a/homeassistant/components/tplink/siren.py +++ b/homeassistant/components/tplink/siren.py @@ -4,20 +4,27 @@ from __future__ import annotations from collections.abc import Callable from dataclasses import dataclass -from typing import Any +import math +from typing import TYPE_CHECKING, Any, cast from kasa import Device, Module from homeassistant.components.siren import ( + ATTR_DURATION, + ATTR_TONE, + ATTR_VOLUME_LEVEL, DOMAIN as SIREN_DOMAIN, SirenEntity, SirenEntityDescription, SirenEntityFeature, + SirenTurnOnServiceParameters, ) from homeassistant.core import HomeAssistant, callback +from homeassistant.exceptions import ServiceValidationError from homeassistant.helpers.entity_platform import AddEntitiesCallback from . import TPLinkConfigEntry, legacy_device_id +from .const import DOMAIN from .coordinator import TPLinkDataUpdateCoordinator from .entity import ( CoordinatedTPLinkModuleEntity, @@ -86,7 +93,13 @@ class TPLinkSirenEntity(CoordinatedTPLinkModuleEntity, SirenEntity): """Representation of a tplink siren entity.""" _attr_name = None - _attr_supported_features = SirenEntityFeature.TURN_OFF | SirenEntityFeature.TURN_ON + _attr_supported_features = ( + SirenEntityFeature.TURN_OFF + | SirenEntityFeature.TURN_ON + | SirenEntityFeature.TONES + | SirenEntityFeature.DURATION + | SirenEntityFeature.VOLUME_SET + ) entity_description: TPLinkSirenEntityDescription @@ -102,10 +115,38 @@ class TPLinkSirenEntity(CoordinatedTPLinkModuleEntity, SirenEntity): super().__init__(device, coordinator, description, parent=parent) self._alarm_module = device.modules[Module.Alarm] + alarm_vol_feat = self._alarm_module.get_feature("alarm_volume") + alarm_duration_feat = self._alarm_module.get_feature("alarm_duration") + if TYPE_CHECKING: + assert alarm_vol_feat + assert alarm_duration_feat + self._alarm_volume_max = alarm_vol_feat.maximum_value + self._alarm_duration_max = alarm_duration_feat.maximum_value + @async_refresh_after async def async_turn_on(self, **kwargs: Any) -> None: """Turn the siren on.""" - await self._alarm_module.play() + turn_on_params = cast(SirenTurnOnServiceParameters, kwargs) + if (volume := kwargs.get(ATTR_VOLUME_LEVEL)) is not None: + # service parameter is a % so we round up to the nearest int + volume = math.ceil(volume * self._alarm_volume_max) + + if (duration := kwargs.get(ATTR_DURATION)) is not None: + if duration < 1 or duration > self._alarm_duration_max: + raise ServiceValidationError( + translation_domain=DOMAIN, + translation_key="invalid_alarm_duration", + translation_placeholders={ + "duration": str(duration), + "duration_max": str(self._alarm_duration_max), + }, + ) + + await self._alarm_module.play( + duration=turn_on_params.get(ATTR_DURATION), + volume=volume, + sound=kwargs.get(ATTR_TONE), + ) @async_refresh_after async def async_turn_off(self, **kwargs: Any) -> None: @@ -116,4 +157,8 @@ class TPLinkSirenEntity(CoordinatedTPLinkModuleEntity, SirenEntity): def _async_update_attrs(self) -> bool: """Update the entity's attributes.""" self._attr_is_on = self._alarm_module.active + # alarm_sounds returns list[str], so we need to widen the type + self._attr_available_tones = cast( + list[str | int], self._alarm_module.alarm_sounds + ) return True diff --git a/homeassistant/components/tplink/strings.json b/homeassistant/components/tplink/strings.json index 9c32dd5bbf4..fa284a3cc83 100644 --- a/homeassistant/components/tplink/strings.json +++ b/homeassistant/components/tplink/strings.json @@ -367,6 +367,9 @@ }, "unsupported_mode": { "message": "Tried to set unsupported mode: {mode}" + }, + "invalid_alarm_duration": { + "message": "Invalid duration {duration} available: 1-{duration_max}s" } }, "issues": { diff --git a/tests/components/tplink/__init__.py b/tests/components/tplink/__init__.py index a056555f4c0..008d25a3dcb 100644 --- a/tests/components/tplink/__init__.py +++ b/tests/components/tplink/__init__.py @@ -178,12 +178,6 @@ def _mocked_device( device_config.host = ip_address device.host = ip_address - if modules: - device.modules = { - module_name: MODULE_TO_MOCK_GEN[module_name](device) - for module_name in modules - } - device_features = {} if features: device_features = { @@ -201,6 +195,13 @@ def _mocked_device( ) device.features = device_features + # Add modules after features so modules can add required features + if modules: + device.modules = { + module_name: MODULE_TO_MOCK_GEN[module_name](device) + for module_name in modules + } + for mod in device.modules.values(): mod.get_feature.side_effect = device_features.get mod.has_feature.side_effect = lambda id: id in device_features @@ -251,7 +252,10 @@ def _mocked_feature( feature.id = id feature.name = name or id.upper() feature.set_value = AsyncMock() - if not (fixture := FEATURES_FIXTURE.get(id)): + if fixture := FEATURES_FIXTURE.get(id): + # copy the fixture so tests do not interfere with each other + fixture = dict(fixture) + else: assert require_fixture is False, ( f"No fixture defined for feature {id} and require_fixture is True" ) @@ -259,7 +263,8 @@ def _mocked_feature( f"Value must be provided if feature {id} not defined in features.json" ) fixture = {"value": value, "category": "Primary", "type": "Sensor"} - elif value is not UNDEFINED: + + if value is not UNDEFINED: fixture["value"] = value feature.value = fixture["value"] @@ -352,9 +357,23 @@ def _mocked_fan_module(effect) -> Fan: def _mocked_alarm_module(device): alarm = MagicMock(auto_spec=Alarm, name="Mocked alarm") alarm.active = False + alarm.alarm_sounds = "Foo", "Bar" alarm.play = AsyncMock() alarm.stop = AsyncMock() + device.features["alarm_volume"] = _mocked_feature( + "alarm_volume", + minimum_value=0, + maximum_value=3, + value=None, + ) + device.features["alarm_duration"] = _mocked_feature( + "alarm_duration", + minimum_value=0, + maximum_value=300, + value=None, + ) + return alarm diff --git a/tests/components/tplink/snapshots/test_siren.ambr b/tests/components/tplink/snapshots/test_siren.ambr index b144288bd1c..7141ccfa084 100644 --- a/tests/components/tplink/snapshots/test_siren.ambr +++ b/tests/components/tplink/snapshots/test_siren.ambr @@ -40,7 +40,12 @@ 'aliases': set({ }), 'area_id': None, - 'capabilities': None, + 'capabilities': dict({ + 'available_tones': tuple( + 'Foo', + 'Bar', + ), + }), 'config_entry_id': , 'device_class': None, 'device_id': , @@ -62,7 +67,7 @@ 'original_name': None, 'platform': 'tplink', 'previous_unique_id': None, - 'supported_features': , + 'supported_features': , 'translation_key': None, 'unique_id': '123456789ABCDEFGH', 'unit_of_measurement': None, @@ -71,8 +76,12 @@ # name: test_states[siren.hub-state] StateSnapshot({ 'attributes': ReadOnlyDict({ + 'available_tones': tuple( + 'Foo', + 'Bar', + ), 'friendly_name': 'hub', - 'supported_features': , + 'supported_features': , }), 'context': , 'entity_id': 'siren.hub', diff --git a/tests/components/tplink/test_siren.py b/tests/components/tplink/test_siren.py index 8c3328558b0..1d820bca1d1 100644 --- a/tests/components/tplink/test_siren.py +++ b/tests/components/tplink/test_siren.py @@ -7,12 +7,16 @@ import pytest from syrupy.assertion import SnapshotAssertion from homeassistant.components.siren import ( + ATTR_DURATION, + ATTR_TONE, + ATTR_VOLUME_LEVEL, DOMAIN as SIREN_DOMAIN, SERVICE_TURN_OFF, SERVICE_TURN_ON, ) from homeassistant.const import ATTR_ENTITY_ID, Platform from homeassistant.core import HomeAssistant +from homeassistant.exceptions import ServiceValidationError from homeassistant.helpers import device_registry as dr, entity_registry as er from . import _mocked_device, setup_platform_for_device, snapshot_platform @@ -74,3 +78,91 @@ async def test_turn_on_and_off( ) alarm_module.play.assert_called() + + +@pytest.mark.parametrize( + ("max_volume", "volume_level", "expected_volume"), + [ + pytest.param(3, 0.1, 1, id="smart-10%"), + pytest.param(3, 0.3, 1, id="smart-30%"), + pytest.param(3, 0.99, 3, id="smart-99%"), + pytest.param(3, 1, 3, id="smart-100%"), + pytest.param(10, 0.1, 1, id="smartcam-10%"), + pytest.param(10, 0.3, 3, id="smartcam-30%"), + pytest.param(10, 0.99, 10, id="smartcam-99%"), + pytest.param(10, 1, 10, id="smartcam-100%"), + ], +) +async def test_turn_on_with_volume( + hass: HomeAssistant, + mock_config_entry: MockConfigEntry, + mocked_hub: Device, + max_volume: int, + volume_level: float, + expected_volume: int, +) -> None: + """Test that turn_on volume parameters work as expected.""" + + alarm_module = mocked_hub.modules[Module.Alarm] + alarm_volume_feat = alarm_module.get_feature("alarm_volume") + assert alarm_volume_feat + alarm_volume_feat.maximum_value = max_volume + + await setup_platform_for_device(hass, mock_config_entry, Platform.SIREN, mocked_hub) + + await hass.services.async_call( + SIREN_DOMAIN, + SERVICE_TURN_ON, + {ATTR_ENTITY_ID: [ENTITY_ID], ATTR_VOLUME_LEVEL: volume_level}, + blocking=True, + ) + + alarm_module.play.assert_called_with( + volume=expected_volume, duration=None, sound=None + ) + + +async def test_turn_on_with_duration_and_sound( + hass: HomeAssistant, + mock_config_entry: MockConfigEntry, + mocked_hub: Device, +) -> None: + """Test that turn_on tone and duration parameters work as expected.""" + + alarm_module = mocked_hub.modules[Module.Alarm] + + await setup_platform_for_device(hass, mock_config_entry, Platform.SIREN, mocked_hub) + + await hass.services.async_call( + SIREN_DOMAIN, + SERVICE_TURN_ON, + {ATTR_ENTITY_ID: [ENTITY_ID], ATTR_DURATION: 5, ATTR_TONE: "Foo"}, + blocking=True, + ) + + alarm_module.play.assert_called_with(volume=None, duration=5, sound="Foo") + + +@pytest.mark.parametrize(("duration"), [0, 301]) +async def test_turn_on_with_invalid_duration( + hass: HomeAssistant, + mock_config_entry: MockConfigEntry, + mocked_hub: Device, + duration: int, +) -> None: + """Test that turn_on with invalid_duration raises an error.""" + + await setup_platform_for_device(hass, mock_config_entry, Platform.SIREN, mocked_hub) + + msg = f"Invalid duration {duration} available: 1-300s" + + with pytest.raises(ServiceValidationError, match=msg): + await hass.services.async_call( + SIREN_DOMAIN, + SERVICE_TURN_ON, + { + ATTR_ENTITY_ID: [ENTITY_ID], + ATTR_DURATION: duration, + }, + blocking=True, + )