From 65454c945dc43bf64bcaf976774def8399a4bcc9 Mon Sep 17 00:00:00 2001 From: Michael Hansen Date: Fri, 23 Jun 2023 22:28:13 -0500 Subject: [PATCH] Add VAD sensitivity option to VoIP devices (#94688) * Add VAD sensitivity option to VoIP devices * Use select entity for VAD sensitivity * Add sensitivity to tests * Add to assist pipeline tests * Update homeassistant/components/assist_pipeline/select.py Co-authored-by: Paulus Schoutsen * Update tests/components/voip/test_voip.py --------- Co-authored-by: Paulus Schoutsen --- .../components/assist_pipeline/select.py | 51 +++++++++++++++++ .../components/assist_pipeline/strings.json | 8 +++ .../components/assist_pipeline/vad.py | 24 ++++++++ homeassistant/components/voip/select.py | 27 +++++++-- homeassistant/components/voip/strings.json | 8 +++ homeassistant/components/voip/voip.py | 24 +++++++- .../components/assist_pipeline/test_select.py | 55 +++++++++++++++++-- tests/components/voip/test_select.py | 15 +++++ tests/components/voip/test_voip.py | 8 ++- 9 files changed, 205 insertions(+), 15 deletions(-) diff --git a/homeassistant/components/assist_pipeline/select.py b/homeassistant/components/assist_pipeline/select.py index 8e9f11252be..2ae46fcb9ac 100644 --- a/homeassistant/components/assist_pipeline/select.py +++ b/homeassistant/components/assist_pipeline/select.py @@ -11,6 +11,7 @@ from homeassistant.helpers import collection, entity_registry as er, restore_sta from .const import DOMAIN from .pipeline import PipelineData, PipelineStorageCollection +from .vad import VadSensitivity OPTION_PREFERRED = "preferred" @@ -38,6 +39,25 @@ def get_chosen_pipeline( ) +@callback +def get_vad_sensitivity( + hass: HomeAssistant, domain: str, unique_id_prefix: str +) -> VadSensitivity: + """Get the chosen vad sensitivity for a domain.""" + ent_reg = er.async_get(hass) + sensitivity_entity_id = ent_reg.async_get_entity_id( + Platform.SELECT, domain, f"{unique_id_prefix}-vad_sensitivity" + ) + if sensitivity_entity_id is None: + 
return VadSensitivity.DEFAULT + + state = hass.states.get(sensitivity_entity_id) + if state is None: + return VadSensitivity.DEFAULT + + return VadSensitivity(state.state) + + class AssistPipelineSelect(SelectEntity, restore_state.RestoreEntity): """Entity to represent a pipeline selector.""" @@ -102,3 +122,34 @@ class AssistPipelineSelect(SelectEntity, restore_state.RestoreEntity): if self._attr_current_option not in options: self._attr_current_option = OPTION_PREFERRED + + +class VadSensitivitySelect(SelectEntity, restore_state.RestoreEntity): + """Entity to represent VAD sensitivity.""" + + entity_description = SelectEntityDescription( + key="vad_sensitivity", + translation_key="vad_sensitivity", + entity_category=EntityCategory.CONFIG, + ) + _attr_should_poll = False + _attr_current_option = VadSensitivity.DEFAULT.value + _attr_options = [vs.value for vs in VadSensitivity] + + def __init__(self, hass: HomeAssistant, unique_id_prefix: str) -> None: + """Initialize a pipeline selector.""" + self._attr_unique_id = f"{unique_id_prefix}-vad_sensitivity" + self.hass = hass + + async def async_added_to_hass(self) -> None: + """When entity is added to Home Assistant.""" + await super().async_added_to_hass() + + state = await self.async_get_last_state() + if state is not None and state.state in self.options: + self._attr_current_option = state.state + + async def async_select_option(self, option: str) -> None: + """Select an option.""" + self._attr_current_option = option + self.async_write_ha_state() diff --git a/homeassistant/components/assist_pipeline/strings.json b/homeassistant/components/assist_pipeline/strings.json index d85eb1aaed9..edcdff752f6 100644 --- a/homeassistant/components/assist_pipeline/strings.json +++ b/homeassistant/components/assist_pipeline/strings.json @@ -11,6 +11,14 @@ "state": { "preferred": "Preferred" } + }, + "vad_sensitivity": { + "name": "Silence sensitivity", + "state": { + "default": "Default", + "aggressive": "Aggressive", + 
"relaxed": "Relaxed" + } } } } diff --git a/homeassistant/components/assist_pipeline/vad.py b/homeassistant/components/assist_pipeline/vad.py index c5f87f1336a..f76de39ccce 100644 --- a/homeassistant/components/assist_pipeline/vad.py +++ b/homeassistant/components/assist_pipeline/vad.py @@ -1,11 +1,35 @@ """Voice activity detection.""" +from __future__ import annotations + from dataclasses import dataclass, field import webrtcvad +from homeassistant.backports.enum import StrEnum + _SAMPLE_RATE = 16000 +class VadSensitivity(StrEnum): + """How quickly the end of a voice command is detected.""" + + DEFAULT = "default" + RELAXED = "relaxed" + AGGRESSIVE = "aggressive" + + @staticmethod + def to_seconds(sensitivity: VadSensitivity | str) -> float: + """Return seconds of silence for sensitivity level.""" + sensitivity = VadSensitivity(sensitivity) + if sensitivity == VadSensitivity.RELAXED: + return 2.0 + + if sensitivity == VadSensitivity.AGGRESSIVE: + return 0.5 + + return 1.0 + + @dataclass class VoiceCommandSegmenter: """Segments an audio stream into voice commands using webrtcvad.""" diff --git a/homeassistant/components/voip/select.py b/homeassistant/components/voip/select.py index 7383e1b886a..94a3aacc0fd 100644 --- a/homeassistant/components/voip/select.py +++ b/homeassistant/components/voip/select.py @@ -4,7 +4,10 @@ from __future__ import annotations from typing import TYPE_CHECKING -from homeassistant.components.assist_pipeline.select import AssistPipelineSelect +from homeassistant.components.assist_pipeline.select import ( + AssistPipelineSelect, + VadSensitivitySelect, +) from homeassistant.config_entries import ConfigEntry from homeassistant.core import HomeAssistant, callback from homeassistant.helpers.entity_platform import AddEntitiesCallback @@ -28,13 +31,18 @@ async def async_setup_entry( @callback def async_add_device(device: VoIPDevice) -> None: """Add device.""" - async_add_entities([VoipPipelineSelect(hass, device)]) + async_add_entities( + 
[VoipPipelineSelect(hass, device), VoipVadSensitivitySelect(hass, device)] + ) domain_data.devices.async_add_new_device_listener(async_add_device) - async_add_entities( - [VoipPipelineSelect(hass, device) for device in domain_data.devices] - ) + entities: list[VoIPEntity] = [] + for device in domain_data.devices: + entities.append(VoipPipelineSelect(hass, device)) + entities.append(VoipVadSensitivitySelect(hass, device)) + + async_add_entities(entities) class VoipPipelineSelect(VoIPEntity, AssistPipelineSelect): @@ -44,3 +52,12 @@ class VoipPipelineSelect(VoIPEntity, AssistPipelineSelect): """Initialize a pipeline selector.""" VoIPEntity.__init__(self, device) AssistPipelineSelect.__init__(self, hass, device.voip_id) + + +class VoipVadSensitivitySelect(VoIPEntity, VadSensitivitySelect): + """VAD sensitivity selector for VoIP devices.""" + + def __init__(self, hass: HomeAssistant, device: VoIPDevice) -> None: + """Initialize a VAD sensitivity selector.""" + VoIPEntity.__init__(self, device) + VadSensitivitySelect.__init__(self, hass, device.voip_id) diff --git a/homeassistant/components/voip/strings.json b/homeassistant/components/voip/strings.json index 2bef9a18008..8bcbb06d4e2 100644 --- a/homeassistant/components/voip/strings.json +++ b/homeassistant/components/voip/strings.json @@ -26,6 +26,14 @@ "state": { "preferred": "[%key:component::assist_pipeline::entity::select::pipeline::state::preferred%]" } + }, + "vad_sensitivity": { + "name": "[%key:component::assist_pipeline::entity::select::vad_sensitivity::name%]", + "state": { + "default": "[%key:component::assist_pipeline::entity::select::vad_sensitivity::state::default%]", + "aggressive": "[%key:component::assist_pipeline::entity::select::vad_sensitivity::state::aggressive%]", + "relaxed": "[%key:component::assist_pipeline::entity::select::vad_sensitivity::state::relaxed%]" + } } } }, diff --git a/homeassistant/components/voip/voip.py b/homeassistant/components/voip/voip.py index d7e261508fd..32cfbd70337 
100644 --- a/homeassistant/components/voip/voip.py +++ b/homeassistant/components/voip/voip.py @@ -23,12 +23,21 @@ from homeassistant.components.assist_pipeline import ( async_pipeline_from_audio_stream, select as pipeline_select, ) -from homeassistant.components.assist_pipeline.vad import VoiceCommandSegmenter +from homeassistant.components.assist_pipeline.vad import ( + VadSensitivity, + VoiceCommandSegmenter, +) from homeassistant.const import __version__ from homeassistant.core import Context, HomeAssistant from homeassistant.util.ulid import ulid -from .const import CHANNELS, DOMAIN, RATE, RTP_AUDIO_SETTINGS, WIDTH +from .const import ( + CHANNELS, + DOMAIN, + RATE, + RTP_AUDIO_SETTINGS, + WIDTH, +) if TYPE_CHECKING: from .devices import VoIPDevice, VoIPDevices @@ -63,6 +72,12 @@ def make_protocol( opus_payload_type=call_info.opus_payload_type, ) + vad_sensitivity = pipeline_select.get_vad_sensitivity( + hass, + DOMAIN, + voip_device.voip_id, + ) + # Pipeline is properly configured return PipelineRtpDatagramProtocol( hass, @@ -70,6 +85,7 @@ def make_protocol( voip_device, Context(user_id=devices.config_entry.data["user"]), opus_payload_type=call_info.opus_payload_type, + silence_seconds=VadSensitivity.to_seconds(vad_sensitivity), ) @@ -130,6 +146,7 @@ class PipelineRtpDatagramProtocol(RtpDatagramProtocol): error_tone_enabled: bool = True, tone_delay: float = 0.2, tts_extra_timeout: float = 1.0, + silence_seconds: float = 1.0, ) -> None: """Set up pipeline RTP server.""" super().__init__( @@ -151,6 +168,7 @@ class PipelineRtpDatagramProtocol(RtpDatagramProtocol): self.error_tone_enabled = error_tone_enabled self.tone_delay = tone_delay self.tts_extra_timeout = tts_extra_timeout + self.silence_seconds = silence_seconds self._audio_queue: asyncio.Queue[bytes] = asyncio.Queue() self._context = context @@ -199,7 +217,7 @@ class PipelineRtpDatagramProtocol(RtpDatagramProtocol): try: # Wait for speech before starting pipeline - segmenter = VoiceCommandSegmenter() + 
segmenter = VoiceCommandSegmenter(silence_seconds=self.silence_seconds) chunk_buffer: deque[bytes] = deque( maxlen=self.buffered_chunks_before_speech, ) diff --git a/tests/components/assist_pipeline/test_select.py b/tests/components/assist_pipeline/test_select.py index 2bc580864d7..bb9c4d45a32 100644 --- a/tests/components/assist_pipeline/test_select.py +++ b/tests/components/assist_pipeline/test_select.py @@ -9,7 +9,11 @@ from homeassistant.components.assist_pipeline.pipeline import ( PipelineData, PipelineStorageCollection, ) -from homeassistant.components.assist_pipeline.select import AssistPipelineSelect +from homeassistant.components.assist_pipeline.select import ( + AssistPipelineSelect, + VadSensitivitySelect, +) +from homeassistant.components.assist_pipeline.vad import VadSensitivity from homeassistant.config_entries import ConfigEntry from homeassistant.core import HomeAssistant from homeassistant.helpers import device_registry as dr @@ -30,11 +34,15 @@ class SelectPlatform(MockPlatform): async_add_entities: AddEntitiesCallback, ) -> None: """Set up fake select platform.""" - entity = AssistPipelineSelect(hass, "test") - entity._attr_device_info = DeviceInfo( + pipeline_entity = AssistPipelineSelect(hass, "test") + pipeline_entity._attr_device_info = DeviceInfo( identifiers={("test", "test")}, ) - async_add_entities([entity]) + sensitivity_entity = VadSensitivitySelect(hass, "test") + sensitivity_entity._attr_device_info = DeviceInfo( + identifiers={("test", "test")}, + ) + async_add_entities([pipeline_entity, sensitivity_entity]) @pytest.fixture @@ -95,6 +103,7 @@ async def test_select_entity_registering_device( """Test entity registering as an assist device.""" dev_reg = dr.async_get(hass) device = dev_reg.async_get_device({("test", "test")}) + assert device is not None # Test device is registered assert pipeline_data.pipeline_devices == {device.id} @@ -138,6 +147,7 @@ async def test_select_entity_changing_pipelines( ) state = 
hass.states.get("select.assist_pipeline_test_pipeline") + assert state is not None assert state.state == pipeline_2.name # Reload config entry to test selected option persists @@ -145,15 +155,52 @@ async def test_select_entity_changing_pipelines( assert await hass.config_entries.async_forward_entry_setup(config_entry, "select") state = hass.states.get("select.assist_pipeline_test_pipeline") + assert state is not None assert state.state == pipeline_2.name # Remove selected pipeline await pipeline_storage.async_delete_item(pipeline_2.id) state = hass.states.get("select.assist_pipeline_test_pipeline") + assert state is not None assert state.state == "preferred" assert state.attributes["options"] == [ "preferred", "Home Assistant", pipeline_1.name, ] + + +async def test_select_entity_changing_vad_sensitivity( + hass: HomeAssistant, + init_select: ConfigEntry, +) -> None: + """Test entity tracking pipeline changes.""" + config_entry = init_select # nicer naming + + state = hass.states.get("select.assist_pipeline_test_vad_sensitivity") + assert state is not None + assert state.state == VadSensitivity.DEFAULT.value + + # Change select to new pipeline + await hass.services.async_call( + "select", + "select_option", + { + "entity_id": "select.assist_pipeline_test_vad_sensitivity", + "option": VadSensitivity.AGGRESSIVE.value, + }, + blocking=True, + ) + + state = hass.states.get("select.assist_pipeline_test_vad_sensitivity") + assert state is not None + assert state.state == VadSensitivity.AGGRESSIVE.value + + # Reload config entry to test selected option persists + assert await hass.config_entries.async_forward_entry_unload(config_entry, "select") + assert await hass.config_entries.async_forward_entry_setup(config_entry, "select") + + state = hass.states.get("select.assist_pipeline_test_vad_sensitivity") + assert state is not None + assert state.state == VadSensitivity.AGGRESSIVE.value diff --git a/tests/components/voip/test_select.py b/tests/components/voip/test_select.py 
index 19c3202576a..9d45477a429 100644 --- a/tests/components/voip/test_select.py +++ b/tests/components/voip/test_select.py @@ -17,3 +17,18 @@ async def test_pipeline_select( state = hass.states.get("select.192_168_1_210_assist_pipeline") assert state is not None assert state.state == "preferred" + + +async def test_vad_sensitivity_select( + hass: HomeAssistant, + config_entry: ConfigEntry, + voip_device: VoIPDevice, +) -> None: + """Test VAD sensitivity select. + + Functionality is tested in assist_pipeline/test_select.py. + This test is only to ensure it is set up. + """ + state = hass.states.get("select.192_168_1_210_silence_sensitivity") + assert state is not None + assert state.state == "default" diff --git a/tests/components/voip/test_voip.py b/tests/components/voip/test_voip.py index 8fc98f31167..9b3f5d963dc 100644 --- a/tests/components/voip/test_voip.py +++ b/tests/components/voip/test_voip.py @@ -95,6 +95,7 @@ async def test_pipeline( listening_tone_enabled=False, processing_tone_enabled=False, error_tone_enabled=False, + silence_seconds=assist_pipeline.vad.VadSensitivity.to_seconds("aggressive"), ) rtp_protocol.transport = Mock() @@ -113,7 +114,7 @@ async def test_pipeline( # "speech" rtp_protocol.on_chunk(bytes([255] * _ONE_SECOND * 2)) - # silence + # silence (assumes aggressive VAD sensitivity) rtp_protocol.on_chunk(bytes(_ONE_SECOND)) # Wait for mock pipeline to exhaust the audio stream @@ -288,6 +289,7 @@ async def test_tts_timeout( listening_tone_enabled=True, processing_tone_enabled=True, error_tone_enabled=True, + silence_seconds=assist_pipeline.vad.VadSensitivity.to_seconds("relaxed"), ) rtp_protocol._tone_bytes = tone_bytes rtp_protocol._processing_bytes = tone_bytes @@ -313,8 +315,8 @@ async def test_tts_timeout( # "speech" rtp_protocol.on_chunk(bytes([255] * _ONE_SECOND * 2)) - # silence - rtp_protocol.on_chunk(bytes(_ONE_SECOND)) + # silence (assumes relaxed VAD sensitivity) + rtp_protocol.on_chunk(bytes(_ONE_SECOND * 4)) # Wait for mock 
pipeline to exhaust the audio stream async with async_timeout.timeout(1):