Add VAD sensitivity option to VoIP devices (#94688)

* Add VAD sensitivity option to VoIP devices

* Use select entity for VAD sensitivity

* Add sensitivity to tests

* Add to assist pipeline tests

* Update homeassistant/components/assist_pipeline/select.py

Co-authored-by: Paulus Schoutsen <balloob@gmail.com>

* Update tests/components/voip/test_voip.py

---------

Co-authored-by: Paulus Schoutsen <balloob@gmail.com>
This commit is contained in:
Michael Hansen 2023-06-23 22:28:13 -05:00 committed by GitHub
parent c42d0feec1
commit 65454c945d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 205 additions and 15 deletions

View File

@ -11,6 +11,7 @@ from homeassistant.helpers import collection, entity_registry as er, restore_sta
from .const import DOMAIN
from .pipeline import PipelineData, PipelineStorageCollection
from .vad import VadSensitivity
OPTION_PREFERRED = "preferred"
@ -38,6 +39,25 @@ def get_chosen_pipeline(
)
@callback
def get_vad_sensitivity(
    hass: HomeAssistant, domain: str, unique_id_prefix: str
) -> VadSensitivity:
    """Get the chosen VAD sensitivity for a domain.

    Looks up the select entity registered for ``domain`` under the unique id
    ``{unique_id_prefix}-vad_sensitivity`` and converts its current state
    into a ``VadSensitivity``.

    Returns ``VadSensitivity.DEFAULT`` when the entity is not registered, has
    no state, or its state is not a valid sensitivity value.
    """
    ent_reg = er.async_get(hass)
    sensitivity_entity_id = ent_reg.async_get_entity_id(
        Platform.SELECT, domain, f"{unique_id_prefix}-vad_sensitivity"
    )
    if sensitivity_entity_id is None:
        return VadSensitivity.DEFAULT

    state = hass.states.get(sensitivity_entity_id)
    if state is None:
        return VadSensitivity.DEFAULT

    try:
        return VadSensitivity(state.state)
    except ValueError:
        # The state string is not one of the sensitivity options (e.g. an
        # unavailable/unknown entity state); fall back to the default
        # instead of propagating the error to the caller.
        return VadSensitivity.DEFAULT
class AssistPipelineSelect(SelectEntity, restore_state.RestoreEntity):
"""Entity to represent a pipeline selector."""
@ -102,3 +122,34 @@ class AssistPipelineSelect(SelectEntity, restore_state.RestoreEntity):
if self._attr_current_option not in options:
self._attr_current_option = OPTION_PREFERRED
class VadSensitivitySelect(SelectEntity, restore_state.RestoreEntity):
    """Select entity that lets the user choose a VAD sensitivity level."""

    entity_description = SelectEntityDescription(
        key="vad_sensitivity",
        translation_key="vad_sensitivity",
        entity_category=EntityCategory.CONFIG,
    )
    _attr_should_poll = False
    _attr_current_option = VadSensitivity.DEFAULT.value
    _attr_options = [level.value for level in VadSensitivity]

    def __init__(self, hass: HomeAssistant, unique_id_prefix: str) -> None:
        """Initialize a VAD sensitivity selector."""
        self._attr_unique_id = f"{unique_id_prefix}-vad_sensitivity"
        self.hass = hass

    async def async_added_to_hass(self) -> None:
        """Restore the last selected option, if it is still a valid option."""
        await super().async_added_to_hass()

        last_state = await self.async_get_last_state()
        if last_state is None:
            return
        if last_state.state not in self.options:
            # Keep the class-level default when the stored value is unknown
            return
        self._attr_current_option = last_state.state

    async def async_select_option(self, option: str) -> None:
        """Store the selected option and write the new state."""
        self._attr_current_option = option
        self.async_write_ha_state()

View File

@ -11,6 +11,14 @@
"state": {
"preferred": "Preferred"
}
},
"vad_sensitivity": {
"name": "Silence sensitivity",
"state": {
"default": "Default",
"aggressive": "Aggressive",
"relaxed": "Relaxed"
}
}
}
}

View File

@ -1,11 +1,35 @@
"""Voice activity detection."""
from __future__ import annotations
from dataclasses import dataclass, field
import webrtcvad
from homeassistant.backports.enum import StrEnum
_SAMPLE_RATE = 16000
class VadSensitivity(StrEnum):
"""How quickly the end of a voice command is detected."""
DEFAULT = "default"
RELAXED = "relaxed"
AGGRESSIVE = "aggressive"
@staticmethod
def to_seconds(sensitivity: VadSensitivity | str) -> float:
"""Return seconds of silence for sensitivity level."""
sensitivity = VadSensitivity(sensitivity)
if sensitivity == VadSensitivity.RELAXED:
return 2.0
if sensitivity == VadSensitivity.AGGRESSIVE:
return 0.5
return 1.0
@dataclass
class VoiceCommandSegmenter:
"""Segments an audio stream into voice commands using webrtcvad."""

View File

@ -4,7 +4,10 @@ from __future__ import annotations
from typing import TYPE_CHECKING
from homeassistant.components.assist_pipeline.select import AssistPipelineSelect
from homeassistant.components.assist_pipeline.select import (
AssistPipelineSelect,
VadSensitivitySelect,
)
from homeassistant.config_entries import ConfigEntry
from homeassistant.core import HomeAssistant, callback
from homeassistant.helpers.entity_platform import AddEntitiesCallback
@ -28,13 +31,18 @@ async def async_setup_entry(
@callback
def async_add_device(device: VoIPDevice) -> None:
"""Add device."""
async_add_entities([VoipPipelineSelect(hass, device)])
async_add_entities(
[VoipPipelineSelect(hass, device), VoipVadSensitivitySelect(hass, device)]
)
domain_data.devices.async_add_new_device_listener(async_add_device)
async_add_entities(
[VoipPipelineSelect(hass, device) for device in domain_data.devices]
)
entities: list[VoIPEntity] = []
for device in domain_data.devices:
entities.append(VoipPipelineSelect(hass, device))
entities.append(VoipVadSensitivitySelect(hass, device))
async_add_entities(entities)
class VoipPipelineSelect(VoIPEntity, AssistPipelineSelect):
@ -44,3 +52,12 @@ class VoipPipelineSelect(VoIPEntity, AssistPipelineSelect):
"""Initialize a pipeline selector."""
VoIPEntity.__init__(self, device)
AssistPipelineSelect.__init__(self, hass, device.voip_id)
class VoipVadSensitivitySelect(VoIPEntity, VadSensitivitySelect):
    """Select entity exposing the VAD sensitivity of a VoIP device."""

    def __init__(self, hass: HomeAssistant, device: VoIPDevice) -> None:
        """Set up both base classes for the given VoIP device."""
        # Each base is initialized explicitly because the two constructors
        # are called with different arguments.
        VoIPEntity.__init__(self, device)
        VadSensitivitySelect.__init__(
            self, hass=hass, unique_id_prefix=device.voip_id
        )

View File

@ -26,6 +26,14 @@
"state": {
"preferred": "[%key:component::assist_pipeline::entity::select::pipeline::state::preferred%]"
}
},
"vad_sensitivity": {
"name": "[%key:component::assist_pipeline::entity::select::vad_sensitivity::name%]",
"state": {
"default": "[%key:component::assist_pipeline::entity::select::vad_sensitivity::state::default%]",
"aggressive": "[%key:component::assist_pipeline::entity::select::vad_sensitivity::state::aggressive%]",
"relaxed": "[%key:component::assist_pipeline::entity::select::vad_sensitivity::state::relaxed%]"
}
}
}
},

View File

@ -23,12 +23,21 @@ from homeassistant.components.assist_pipeline import (
async_pipeline_from_audio_stream,
select as pipeline_select,
)
from homeassistant.components.assist_pipeline.vad import VoiceCommandSegmenter
from homeassistant.components.assist_pipeline.vad import (
VadSensitivity,
VoiceCommandSegmenter,
)
from homeassistant.const import __version__
from homeassistant.core import Context, HomeAssistant
from homeassistant.util.ulid import ulid
from .const import CHANNELS, DOMAIN, RATE, RTP_AUDIO_SETTINGS, WIDTH
from .const import (
CHANNELS,
DOMAIN,
RATE,
RTP_AUDIO_SETTINGS,
WIDTH,
)
if TYPE_CHECKING:
from .devices import VoIPDevice, VoIPDevices
@ -63,6 +72,12 @@ def make_protocol(
opus_payload_type=call_info.opus_payload_type,
)
vad_sensitivity = pipeline_select.get_vad_sensitivity(
hass,
DOMAIN,
voip_device.voip_id,
)
# Pipeline is properly configured
return PipelineRtpDatagramProtocol(
hass,
@ -70,6 +85,7 @@ def make_protocol(
voip_device,
Context(user_id=devices.config_entry.data["user"]),
opus_payload_type=call_info.opus_payload_type,
silence_seconds=VadSensitivity.to_seconds(vad_sensitivity),
)
@ -130,6 +146,7 @@ class PipelineRtpDatagramProtocol(RtpDatagramProtocol):
error_tone_enabled: bool = True,
tone_delay: float = 0.2,
tts_extra_timeout: float = 1.0,
silence_seconds: float = 1.0,
) -> None:
"""Set up pipeline RTP server."""
super().__init__(
@ -151,6 +168,7 @@ class PipelineRtpDatagramProtocol(RtpDatagramProtocol):
self.error_tone_enabled = error_tone_enabled
self.tone_delay = tone_delay
self.tts_extra_timeout = tts_extra_timeout
self.silence_seconds = silence_seconds
self._audio_queue: asyncio.Queue[bytes] = asyncio.Queue()
self._context = context
@ -199,7 +217,7 @@ class PipelineRtpDatagramProtocol(RtpDatagramProtocol):
try:
# Wait for speech before starting pipeline
segmenter = VoiceCommandSegmenter()
segmenter = VoiceCommandSegmenter(silence_seconds=self.silence_seconds)
chunk_buffer: deque[bytes] = deque(
maxlen=self.buffered_chunks_before_speech,
)

View File

@ -9,7 +9,11 @@ from homeassistant.components.assist_pipeline.pipeline import (
PipelineData,
PipelineStorageCollection,
)
from homeassistant.components.assist_pipeline.select import AssistPipelineSelect
from homeassistant.components.assist_pipeline.select import (
AssistPipelineSelect,
VadSensitivitySelect,
)
from homeassistant.components.assist_pipeline.vad import VadSensitivity
from homeassistant.config_entries import ConfigEntry
from homeassistant.core import HomeAssistant
from homeassistant.helpers import device_registry as dr
@ -30,11 +34,15 @@ class SelectPlatform(MockPlatform):
async_add_entities: AddEntitiesCallback,
) -> None:
"""Set up fake select platform."""
entity = AssistPipelineSelect(hass, "test")
entity._attr_device_info = DeviceInfo(
pipeline_entity = AssistPipelineSelect(hass, "test")
pipeline_entity._attr_device_info = DeviceInfo(
identifiers={("test", "test")},
)
async_add_entities([entity])
sensitivity_entity = VadSensitivitySelect(hass, "test")
sensitivity_entity._attr_device_info = DeviceInfo(
identifiers={("test", "test")},
)
async_add_entities([pipeline_entity, sensitivity_entity])
@pytest.fixture
@ -95,6 +103,7 @@ async def test_select_entity_registering_device(
"""Test entity registering as an assist device."""
dev_reg = dr.async_get(hass)
device = dev_reg.async_get_device({("test", "test")})
assert device is not None
# Test device is registered
assert pipeline_data.pipeline_devices == {device.id}
@ -138,6 +147,7 @@ async def test_select_entity_changing_pipelines(
)
state = hass.states.get("select.assist_pipeline_test_pipeline")
assert state is not None
assert state.state == pipeline_2.name
# Reload config entry to test selected option persists
@ -145,15 +155,52 @@ async def test_select_entity_changing_pipelines(
assert await hass.config_entries.async_forward_entry_setup(config_entry, "select")
state = hass.states.get("select.assist_pipeline_test_pipeline")
assert state is not None
assert state.state == pipeline_2.name
# Remove selected pipeline
await pipeline_storage.async_delete_item(pipeline_2.id)
state = hass.states.get("select.assist_pipeline_test_pipeline")
assert state is not None
assert state.state == "preferred"
assert state.attributes["options"] == [
"preferred",
"Home Assistant",
pipeline_1.name,
]
async def test_select_entity_changing_vad_sensitivity(
    hass: HomeAssistant,
    init_select: ConfigEntry,
) -> None:
    """Test that a changed VAD sensitivity is applied and survives a reload."""
    entity_id = "select.assist_pipeline_test_vad_sensitivity"
    config_entry = init_select  # nicer naming

    # Entity starts out with the default sensitivity
    initial_state = hass.states.get(entity_id)
    assert initial_state is not None
    assert initial_state.state == VadSensitivity.DEFAULT.value

    # Change select to a new sensitivity
    await hass.services.async_call(
        "select",
        "select_option",
        {
            "entity_id": entity_id,
            "option": VadSensitivity.AGGRESSIVE.value,
        },
        blocking=True,
    )

    changed_state = hass.states.get(entity_id)
    assert changed_state is not None
    assert changed_state.state == VadSensitivity.AGGRESSIVE.value

    # Reload config entry to test selected option persists
    assert await hass.config_entries.async_forward_entry_unload(config_entry, "select")
    assert await hass.config_entries.async_forward_entry_setup(config_entry, "select")

    reloaded_state = hass.states.get(entity_id)
    assert reloaded_state is not None
    assert reloaded_state.state == VadSensitivity.AGGRESSIVE.value

View File

@ -17,3 +17,18 @@ async def test_pipeline_select(
state = hass.states.get("select.192_168_1_210_assist_pipeline")
assert state is not None
assert state.state == "preferred"
async def test_vad_sensitivity_select(
    hass: HomeAssistant,
    config_entry: ConfigEntry,
    voip_device: VoIPDevice,
) -> None:
    """Test that the VAD sensitivity select is set up.

    Functionality is covered in assist_pipeline/test_select.py; this test
    only checks that the entity exists with its default option.
    """
    entity_id = "select.192_168_1_210_silence_sensitivity"

    sensitivity_state = hass.states.get(entity_id)
    assert sensitivity_state is not None
    assert sensitivity_state.state == "default"

View File

@ -95,6 +95,7 @@ async def test_pipeline(
listening_tone_enabled=False,
processing_tone_enabled=False,
error_tone_enabled=False,
silence_seconds=assist_pipeline.vad.VadSensitivity.to_seconds("aggressive"),
)
rtp_protocol.transport = Mock()
@ -113,7 +114,7 @@ async def test_pipeline(
# "speech"
rtp_protocol.on_chunk(bytes([255] * _ONE_SECOND * 2))
# silence
# silence (assumes aggressive VAD sensitivity)
rtp_protocol.on_chunk(bytes(_ONE_SECOND))
# Wait for mock pipeline to exhaust the audio stream
@ -288,6 +289,7 @@ async def test_tts_timeout(
listening_tone_enabled=True,
processing_tone_enabled=True,
error_tone_enabled=True,
silence_seconds=assist_pipeline.vad.VadSensitivity.to_seconds("relaxed"),
)
rtp_protocol._tone_bytes = tone_bytes
rtp_protocol._processing_bytes = tone_bytes
@ -313,8 +315,8 @@ async def test_tts_timeout(
# "speech"
rtp_protocol.on_chunk(bytes([255] * _ONE_SECOND * 2))
# silence
rtp_protocol.on_chunk(bytes(_ONE_SECOND))
# silence (assumes relaxed VAD sensitivity)
rtp_protocol.on_chunk(bytes(_ONE_SECOND * 4))
# Wait for mock pipeline to exhaust the audio stream
async with async_timeout.timeout(1):