Add VAD sensitivity option to VoIP devices (#94688)

* Add VAD sensitivity option to VoIP devices

* Use select entity for VAD sensitivity

* Add sensitivity to tests

* Add to assist pipeline tests

* Update homeassistant/components/assist_pipeline/select.py

Co-authored-by: Paulus Schoutsen <balloob@gmail.com>

* Update tests/components/voip/test_voip.py

---------

Co-authored-by: Paulus Schoutsen <balloob@gmail.com>
This commit is contained in:
Michael Hansen 2023-06-23 22:28:13 -05:00 committed by GitHub
parent c42d0feec1
commit 65454c945d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 205 additions and 15 deletions

View File

@ -11,6 +11,7 @@ from homeassistant.helpers import collection, entity_registry as er, restore_sta
from .const import DOMAIN from .const import DOMAIN
from .pipeline import PipelineData, PipelineStorageCollection from .pipeline import PipelineData, PipelineStorageCollection
from .vad import VadSensitivity
OPTION_PREFERRED = "preferred" OPTION_PREFERRED = "preferred"
@ -38,6 +39,25 @@ def get_chosen_pipeline(
) )
@callback
def get_vad_sensitivity(
    hass: HomeAssistant, domain: str, unique_id_prefix: str
) -> VadSensitivity:
    """Get the chosen VAD sensitivity for a domain.

    Looks up the select entity registered under the unique id
    ``{unique_id_prefix}-vad_sensitivity`` for ``domain`` and converts its
    current state string into a ``VadSensitivity``.

    Returns ``VadSensitivity.DEFAULT`` when the entity is not registered,
    has no state, or its state is not a valid sensitivity value.
    """
    ent_reg = er.async_get(hass)
    sensitivity_entity_id = ent_reg.async_get_entity_id(
        Platform.SELECT, domain, f"{unique_id_prefix}-vad_sensitivity"
    )
    if sensitivity_entity_id is None:
        return VadSensitivity.DEFAULT

    state = hass.states.get(sensitivity_entity_id)
    if state is None:
        return VadSensitivity.DEFAULT

    try:
        return VadSensitivity(state.state)
    except ValueError:
        # The state string is not one of the enum values (for instance when
        # the entity currently reports that it is unavailable). Use the
        # default instead of letting the error reach the caller.
        return VadSensitivity.DEFAULT
class AssistPipelineSelect(SelectEntity, restore_state.RestoreEntity): class AssistPipelineSelect(SelectEntity, restore_state.RestoreEntity):
"""Entity to represent a pipeline selector.""" """Entity to represent a pipeline selector."""
@ -102,3 +122,34 @@ class AssistPipelineSelect(SelectEntity, restore_state.RestoreEntity):
if self._attr_current_option not in options: if self._attr_current_option not in options:
self._attr_current_option = OPTION_PREFERRED self._attr_current_option = OPTION_PREFERRED
class VadSensitivitySelect(SelectEntity, restore_state.RestoreEntity):
    """Select entity that exposes the available VAD sensitivity levels."""

    entity_description = SelectEntityDescription(
        key="vad_sensitivity",
        translation_key="vad_sensitivity",
        entity_category=EntityCategory.CONFIG,
    )
    _attr_should_poll = False
    # Start on the default level; every enum value is offered as an option.
    _attr_current_option = VadSensitivity.DEFAULT.value
    _attr_options = [level.value for level in VadSensitivity]

    def __init__(self, hass: HomeAssistant, unique_id_prefix: str) -> None:
        """Initialize a VAD sensitivity selector."""
        self.hass = hass
        self._attr_unique_id = f"{unique_id_prefix}-vad_sensitivity"

    async def async_added_to_hass(self) -> None:
        """Restore the last selected option when the entity is added."""
        await super().async_added_to_hass()

        last_state = await self.async_get_last_state()
        if last_state is None or last_state.state not in self.options:
            # Nothing usable to restore: keep the class-level default option.
            return
        self._attr_current_option = last_state.state

    async def async_select_option(self, option: str) -> None:
        """Store the chosen option and write the updated state."""
        self._attr_current_option = option
        self.async_write_ha_state()

View File

@ -11,6 +11,14 @@
"state": { "state": {
"preferred": "Preferred" "preferred": "Preferred"
} }
},
"vad_sensitivity": {
"name": "Silence sensitivity",
"state": {
"default": "Default",
"aggressive": "Aggressive",
"relaxed": "Relaxed"
}
} }
} }
} }

View File

@ -1,11 +1,35 @@
"""Voice activity detection.""" """Voice activity detection."""
from __future__ import annotations
from dataclasses import dataclass, field from dataclasses import dataclass, field
import webrtcvad import webrtcvad
from homeassistant.backports.enum import StrEnum
_SAMPLE_RATE = 16000 _SAMPLE_RATE = 16000
class VadSensitivity(StrEnum):
"""How quickly the end of a voice command is detected."""
DEFAULT = "default"
RELAXED = "relaxed"
AGGRESSIVE = "aggressive"
@staticmethod
def to_seconds(sensitivity: VadSensitivity | str) -> float:
"""Return seconds of silence for sensitivity level."""
sensitivity = VadSensitivity(sensitivity)
if sensitivity == VadSensitivity.RELAXED:
return 2.0
if sensitivity == VadSensitivity.AGGRESSIVE:
return 0.5
return 1.0
@dataclass @dataclass
class VoiceCommandSegmenter: class VoiceCommandSegmenter:
"""Segments an audio stream into voice commands using webrtcvad.""" """Segments an audio stream into voice commands using webrtcvad."""

View File

@ -4,7 +4,10 @@ from __future__ import annotations
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
from homeassistant.components.assist_pipeline.select import AssistPipelineSelect from homeassistant.components.assist_pipeline.select import (
AssistPipelineSelect,
VadSensitivitySelect,
)
from homeassistant.config_entries import ConfigEntry from homeassistant.config_entries import ConfigEntry
from homeassistant.core import HomeAssistant, callback from homeassistant.core import HomeAssistant, callback
from homeassistant.helpers.entity_platform import AddEntitiesCallback from homeassistant.helpers.entity_platform import AddEntitiesCallback
@ -28,13 +31,18 @@ async def async_setup_entry(
@callback @callback
def async_add_device(device: VoIPDevice) -> None: def async_add_device(device: VoIPDevice) -> None:
"""Add device.""" """Add device."""
async_add_entities([VoipPipelineSelect(hass, device)]) async_add_entities(
[VoipPipelineSelect(hass, device), VoipVadSensitivitySelect(hass, device)]
)
domain_data.devices.async_add_new_device_listener(async_add_device) domain_data.devices.async_add_new_device_listener(async_add_device)
async_add_entities( entities: list[VoIPEntity] = []
[VoipPipelineSelect(hass, device) for device in domain_data.devices] for device in domain_data.devices:
) entities.append(VoipPipelineSelect(hass, device))
entities.append(VoipVadSensitivitySelect(hass, device))
async_add_entities(entities)
class VoipPipelineSelect(VoIPEntity, AssistPipelineSelect): class VoipPipelineSelect(VoIPEntity, AssistPipelineSelect):
@ -44,3 +52,12 @@ class VoipPipelineSelect(VoIPEntity, AssistPipelineSelect):
"""Initialize a pipeline selector.""" """Initialize a pipeline selector."""
VoIPEntity.__init__(self, device) VoIPEntity.__init__(self, device)
AssistPipelineSelect.__init__(self, hass, device.voip_id) AssistPipelineSelect.__init__(self, hass, device.voip_id)
class VoipVadSensitivitySelect(VoIPEntity, VadSensitivitySelect):
    """VAD sensitivity selector for VoIP devices."""

    def __init__(self, hass: HomeAssistant, device: VoIPDevice) -> None:
        """Initialize a VAD sensitivity selector.

        Each base class is initialized explicitly because their
        initializers take different arguments.
        """
        VoIPEntity.__init__(self, device)
        # The device's voip_id is used as the unique id prefix of the select.
        VadSensitivitySelect.__init__(self, hass, device.voip_id)

View File

@ -26,6 +26,14 @@
"state": { "state": {
"preferred": "[%key:component::assist_pipeline::entity::select::pipeline::state::preferred%]" "preferred": "[%key:component::assist_pipeline::entity::select::pipeline::state::preferred%]"
} }
},
"vad_sensitivity": {
"name": "[%key:component::assist_pipeline::entity::select::vad_sensitivity::name%]",
"state": {
"default": "[%key:component::assist_pipeline::entity::select::vad_sensitivity::state::default%]",
"aggressive": "[%key:component::assist_pipeline::entity::select::vad_sensitivity::state::aggressive%]",
"relaxed": "[%key:component::assist_pipeline::entity::select::vad_sensitivity::state::relaxed%]"
}
} }
} }
}, },

View File

@ -23,12 +23,21 @@ from homeassistant.components.assist_pipeline import (
async_pipeline_from_audio_stream, async_pipeline_from_audio_stream,
select as pipeline_select, select as pipeline_select,
) )
from homeassistant.components.assist_pipeline.vad import VoiceCommandSegmenter from homeassistant.components.assist_pipeline.vad import (
VadSensitivity,
VoiceCommandSegmenter,
)
from homeassistant.const import __version__ from homeassistant.const import __version__
from homeassistant.core import Context, HomeAssistant from homeassistant.core import Context, HomeAssistant
from homeassistant.util.ulid import ulid from homeassistant.util.ulid import ulid
from .const import CHANNELS, DOMAIN, RATE, RTP_AUDIO_SETTINGS, WIDTH from .const import (
CHANNELS,
DOMAIN,
RATE,
RTP_AUDIO_SETTINGS,
WIDTH,
)
if TYPE_CHECKING: if TYPE_CHECKING:
from .devices import VoIPDevice, VoIPDevices from .devices import VoIPDevice, VoIPDevices
@ -63,6 +72,12 @@ def make_protocol(
opus_payload_type=call_info.opus_payload_type, opus_payload_type=call_info.opus_payload_type,
) )
vad_sensitivity = pipeline_select.get_vad_sensitivity(
hass,
DOMAIN,
voip_device.voip_id,
)
# Pipeline is properly configured # Pipeline is properly configured
return PipelineRtpDatagramProtocol( return PipelineRtpDatagramProtocol(
hass, hass,
@ -70,6 +85,7 @@ def make_protocol(
voip_device, voip_device,
Context(user_id=devices.config_entry.data["user"]), Context(user_id=devices.config_entry.data["user"]),
opus_payload_type=call_info.opus_payload_type, opus_payload_type=call_info.opus_payload_type,
silence_seconds=VadSensitivity.to_seconds(vad_sensitivity),
) )
@ -130,6 +146,7 @@ class PipelineRtpDatagramProtocol(RtpDatagramProtocol):
error_tone_enabled: bool = True, error_tone_enabled: bool = True,
tone_delay: float = 0.2, tone_delay: float = 0.2,
tts_extra_timeout: float = 1.0, tts_extra_timeout: float = 1.0,
silence_seconds: float = 1.0,
) -> None: ) -> None:
"""Set up pipeline RTP server.""" """Set up pipeline RTP server."""
super().__init__( super().__init__(
@ -151,6 +168,7 @@ class PipelineRtpDatagramProtocol(RtpDatagramProtocol):
self.error_tone_enabled = error_tone_enabled self.error_tone_enabled = error_tone_enabled
self.tone_delay = tone_delay self.tone_delay = tone_delay
self.tts_extra_timeout = tts_extra_timeout self.tts_extra_timeout = tts_extra_timeout
self.silence_seconds = silence_seconds
self._audio_queue: asyncio.Queue[bytes] = asyncio.Queue() self._audio_queue: asyncio.Queue[bytes] = asyncio.Queue()
self._context = context self._context = context
@ -199,7 +217,7 @@ class PipelineRtpDatagramProtocol(RtpDatagramProtocol):
try: try:
# Wait for speech before starting pipeline # Wait for speech before starting pipeline
segmenter = VoiceCommandSegmenter() segmenter = VoiceCommandSegmenter(silence_seconds=self.silence_seconds)
chunk_buffer: deque[bytes] = deque( chunk_buffer: deque[bytes] = deque(
maxlen=self.buffered_chunks_before_speech, maxlen=self.buffered_chunks_before_speech,
) )

View File

@ -9,7 +9,11 @@ from homeassistant.components.assist_pipeline.pipeline import (
PipelineData, PipelineData,
PipelineStorageCollection, PipelineStorageCollection,
) )
from homeassistant.components.assist_pipeline.select import AssistPipelineSelect from homeassistant.components.assist_pipeline.select import (
AssistPipelineSelect,
VadSensitivitySelect,
)
from homeassistant.components.assist_pipeline.vad import VadSensitivity
from homeassistant.config_entries import ConfigEntry from homeassistant.config_entries import ConfigEntry
from homeassistant.core import HomeAssistant from homeassistant.core import HomeAssistant
from homeassistant.helpers import device_registry as dr from homeassistant.helpers import device_registry as dr
@ -30,11 +34,15 @@ class SelectPlatform(MockPlatform):
async_add_entities: AddEntitiesCallback, async_add_entities: AddEntitiesCallback,
) -> None: ) -> None:
"""Set up fake select platform.""" """Set up fake select platform."""
entity = AssistPipelineSelect(hass, "test") pipeline_entity = AssistPipelineSelect(hass, "test")
entity._attr_device_info = DeviceInfo( pipeline_entity._attr_device_info = DeviceInfo(
identifiers={("test", "test")}, identifiers={("test", "test")},
) )
async_add_entities([entity]) sensitivity_entity = VadSensitivitySelect(hass, "test")
sensitivity_entity._attr_device_info = DeviceInfo(
identifiers={("test", "test")},
)
async_add_entities([pipeline_entity, sensitivity_entity])
@pytest.fixture @pytest.fixture
@ -95,6 +103,7 @@ async def test_select_entity_registering_device(
"""Test entity registering as an assist device.""" """Test entity registering as an assist device."""
dev_reg = dr.async_get(hass) dev_reg = dr.async_get(hass)
device = dev_reg.async_get_device({("test", "test")}) device = dev_reg.async_get_device({("test", "test")})
assert device is not None
# Test device is registered # Test device is registered
assert pipeline_data.pipeline_devices == {device.id} assert pipeline_data.pipeline_devices == {device.id}
@ -138,6 +147,7 @@ async def test_select_entity_changing_pipelines(
) )
state = hass.states.get("select.assist_pipeline_test_pipeline") state = hass.states.get("select.assist_pipeline_test_pipeline")
assert state is not None
assert state.state == pipeline_2.name assert state.state == pipeline_2.name
# Reload config entry to test selected option persists # Reload config entry to test selected option persists
@ -145,15 +155,52 @@ async def test_select_entity_changing_pipelines(
assert await hass.config_entries.async_forward_entry_setup(config_entry, "select") assert await hass.config_entries.async_forward_entry_setup(config_entry, "select")
state = hass.states.get("select.assist_pipeline_test_pipeline") state = hass.states.get("select.assist_pipeline_test_pipeline")
assert state is not None
assert state.state == pipeline_2.name assert state.state == pipeline_2.name
# Remove selected pipeline # Remove selected pipeline
await pipeline_storage.async_delete_item(pipeline_2.id) await pipeline_storage.async_delete_item(pipeline_2.id)
state = hass.states.get("select.assist_pipeline_test_pipeline") state = hass.states.get("select.assist_pipeline_test_pipeline")
assert state is not None
assert state.state == "preferred" assert state.state == "preferred"
assert state.attributes["options"] == [ assert state.attributes["options"] == [
"preferred", "preferred",
"Home Assistant", "Home Assistant",
pipeline_1.name, pipeline_1.name,
] ]
async def test_select_entity_changing_vad_sensitivity(
    hass: HomeAssistant,
    init_select: ConfigEntry,
) -> None:
    """Test that the VAD sensitivity select changes and persists its option."""
    entity_id = "select.assist_pipeline_test_vad_sensitivity"
    aggressive = VadSensitivity.AGGRESSIVE.value

    # The select starts out on the default sensitivity
    initial_state = hass.states.get(entity_id)
    assert initial_state is not None
    assert initial_state.state == VadSensitivity.DEFAULT.value

    # Change select to a new sensitivity
    await hass.services.async_call(
        "select",
        "select_option",
        {"entity_id": entity_id, "option": aggressive},
        blocking=True,
    )

    changed_state = hass.states.get(entity_id)
    assert changed_state is not None
    assert changed_state.state == aggressive

    # Reload config entry to test selected option persists
    assert await hass.config_entries.async_forward_entry_unload(init_select, "select")
    assert await hass.config_entries.async_forward_entry_setup(init_select, "select")

    restored_state = hass.states.get(entity_id)
    assert restored_state is not None
    assert restored_state.state == aggressive

View File

@ -17,3 +17,18 @@ async def test_pipeline_select(
state = hass.states.get("select.192_168_1_210_assist_pipeline") state = hass.states.get("select.192_168_1_210_assist_pipeline")
assert state is not None assert state is not None
assert state.state == "preferred" assert state.state == "preferred"
async def test_vad_sensitivity_select(
    hass: HomeAssistant,
    config_entry: ConfigEntry,
    voip_device: VoIPDevice,
) -> None:
    """Check that the VAD sensitivity select is set up for a VoIP device.

    Only the entity's presence and initial state are verified here; the
    select's functionality is tested in assist_pipeline/test_select.py.
    """
    entity_id = "select.192_168_1_210_silence_sensitivity"
    sensitivity_state = hass.states.get(entity_id)
    assert sensitivity_state is not None
    assert sensitivity_state.state == "default"

View File

@ -95,6 +95,7 @@ async def test_pipeline(
listening_tone_enabled=False, listening_tone_enabled=False,
processing_tone_enabled=False, processing_tone_enabled=False,
error_tone_enabled=False, error_tone_enabled=False,
silence_seconds=assist_pipeline.vad.VadSensitivity.to_seconds("aggressive"),
) )
rtp_protocol.transport = Mock() rtp_protocol.transport = Mock()
@ -113,7 +114,7 @@ async def test_pipeline(
# "speech" # "speech"
rtp_protocol.on_chunk(bytes([255] * _ONE_SECOND * 2)) rtp_protocol.on_chunk(bytes([255] * _ONE_SECOND * 2))
# silence # silence (assumes aggressive VAD sensitivity)
rtp_protocol.on_chunk(bytes(_ONE_SECOND)) rtp_protocol.on_chunk(bytes(_ONE_SECOND))
# Wait for mock pipeline to exhaust the audio stream # Wait for mock pipeline to exhaust the audio stream
@ -288,6 +289,7 @@ async def test_tts_timeout(
listening_tone_enabled=True, listening_tone_enabled=True,
processing_tone_enabled=True, processing_tone_enabled=True,
error_tone_enabled=True, error_tone_enabled=True,
silence_seconds=assist_pipeline.vad.VadSensitivity.to_seconds("relaxed"),
) )
rtp_protocol._tone_bytes = tone_bytes rtp_protocol._tone_bytes = tone_bytes
rtp_protocol._processing_bytes = tone_bytes rtp_protocol._processing_bytes = tone_bytes
@ -313,8 +315,8 @@ async def test_tts_timeout(
# "speech" # "speech"
rtp_protocol.on_chunk(bytes([255] * _ONE_SECOND * 2)) rtp_protocol.on_chunk(bytes([255] * _ONE_SECOND * 2))
# silence # silence (assumes relaxed VAD sensitivity)
rtp_protocol.on_chunk(bytes(_ONE_SECOND)) rtp_protocol.on_chunk(bytes(_ONE_SECOND * 4))
# Wait for mock pipeline to exhaust the audio stream # Wait for mock pipeline to exhaust the audio stream
async with async_timeout.timeout(1): async with async_timeout.timeout(1):