mirror of
https://github.com/home-assistant/core.git
synced 2025-07-27 15:17:35 +00:00
Add VAD sensitivity option to VoIP devices (#94688)
* Add VAD sensitivity option to VoIP devices * Use select entitiy for VAD sensitivity * Add sensitivity to tests * Add to assist pipeline tests * Update homeassistant/components/assist_pipeline/select.py Co-authored-by: Paulus Schoutsen <balloob@gmail.com> * Update tests/components/voip/test_voip.py --------- Co-authored-by: Paulus Schoutsen <balloob@gmail.com>
This commit is contained in:
parent
c42d0feec1
commit
65454c945d
@ -11,6 +11,7 @@ from homeassistant.helpers import collection, entity_registry as er, restore_sta
|
||||
|
||||
from .const import DOMAIN
|
||||
from .pipeline import PipelineData, PipelineStorageCollection
|
||||
from .vad import VadSensitivity
|
||||
|
||||
OPTION_PREFERRED = "preferred"
|
||||
|
||||
@ -38,6 +39,25 @@ def get_chosen_pipeline(
|
||||
)
|
||||
|
||||
|
||||
@callback
|
||||
def get_vad_sensitivity(
|
||||
hass: HomeAssistant, domain: str, unique_id_prefix: str
|
||||
) -> VadSensitivity:
|
||||
"""Get the chosen vad sensitivity for a domain."""
|
||||
ent_reg = er.async_get(hass)
|
||||
sensitivity_entity_id = ent_reg.async_get_entity_id(
|
||||
Platform.SELECT, domain, f"{unique_id_prefix}-vad_sensitivity"
|
||||
)
|
||||
if sensitivity_entity_id is None:
|
||||
return VadSensitivity.DEFAULT
|
||||
|
||||
state = hass.states.get(sensitivity_entity_id)
|
||||
if state is None:
|
||||
return VadSensitivity.DEFAULT
|
||||
|
||||
return VadSensitivity(state.state)
|
||||
|
||||
|
||||
class AssistPipelineSelect(SelectEntity, restore_state.RestoreEntity):
|
||||
"""Entity to represent a pipeline selector."""
|
||||
|
||||
@ -102,3 +122,34 @@ class AssistPipelineSelect(SelectEntity, restore_state.RestoreEntity):
|
||||
|
||||
if self._attr_current_option not in options:
|
||||
self._attr_current_option = OPTION_PREFERRED
|
||||
|
||||
|
||||
class VadSensitivitySelect(SelectEntity, restore_state.RestoreEntity):
|
||||
"""Entity to represent VAD sensitivity."""
|
||||
|
||||
entity_description = SelectEntityDescription(
|
||||
key="vad_sensitivity",
|
||||
translation_key="vad_sensitivity",
|
||||
entity_category=EntityCategory.CONFIG,
|
||||
)
|
||||
_attr_should_poll = False
|
||||
_attr_current_option = VadSensitivity.DEFAULT.value
|
||||
_attr_options = [vs.value for vs in VadSensitivity]
|
||||
|
||||
def __init__(self, hass: HomeAssistant, unique_id_prefix: str) -> None:
|
||||
"""Initialize a pipeline selector."""
|
||||
self._attr_unique_id = f"{unique_id_prefix}-vad_sensitivity"
|
||||
self.hass = hass
|
||||
|
||||
async def async_added_to_hass(self) -> None:
|
||||
"""When entity is added to Home Assistant."""
|
||||
await super().async_added_to_hass()
|
||||
|
||||
state = await self.async_get_last_state()
|
||||
if state is not None and state.state in self.options:
|
||||
self._attr_current_option = state.state
|
||||
|
||||
async def async_select_option(self, option: str) -> None:
|
||||
"""Select an option."""
|
||||
self._attr_current_option = option
|
||||
self.async_write_ha_state()
|
||||
|
@ -11,6 +11,14 @@
|
||||
"state": {
|
||||
"preferred": "Preferred"
|
||||
}
|
||||
},
|
||||
"vad_sensitivity": {
|
||||
"name": "Silence sensitivity",
|
||||
"state": {
|
||||
"default": "Default",
|
||||
"aggressive": "Aggressive",
|
||||
"relaxed": "Relaxed"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -1,11 +1,35 @@
|
||||
"""Voice activity detection."""
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
import webrtcvad
|
||||
|
||||
from homeassistant.backports.enum import StrEnum
|
||||
|
||||
_SAMPLE_RATE = 16000
|
||||
|
||||
|
||||
class VadSensitivity(StrEnum):
|
||||
"""How quickly the end of a voice command is detected."""
|
||||
|
||||
DEFAULT = "default"
|
||||
RELAXED = "relaxed"
|
||||
AGGRESSIVE = "aggressive"
|
||||
|
||||
@staticmethod
|
||||
def to_seconds(sensitivity: VadSensitivity | str) -> float:
|
||||
"""Return seconds of silence for sensitivity level."""
|
||||
sensitivity = VadSensitivity(sensitivity)
|
||||
if sensitivity == VadSensitivity.RELAXED:
|
||||
return 2.0
|
||||
|
||||
if sensitivity == VadSensitivity.AGGRESSIVE:
|
||||
return 0.5
|
||||
|
||||
return 1.0
|
||||
|
||||
|
||||
@dataclass
|
||||
class VoiceCommandSegmenter:
|
||||
"""Segments an audio stream into voice commands using webrtcvad."""
|
||||
|
@ -4,7 +4,10 @@ from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from homeassistant.components.assist_pipeline.select import AssistPipelineSelect
|
||||
from homeassistant.components.assist_pipeline.select import (
|
||||
AssistPipelineSelect,
|
||||
VadSensitivitySelect,
|
||||
)
|
||||
from homeassistant.config_entries import ConfigEntry
|
||||
from homeassistant.core import HomeAssistant, callback
|
||||
from homeassistant.helpers.entity_platform import AddEntitiesCallback
|
||||
@ -28,13 +31,18 @@ async def async_setup_entry(
|
||||
@callback
|
||||
def async_add_device(device: VoIPDevice) -> None:
|
||||
"""Add device."""
|
||||
async_add_entities([VoipPipelineSelect(hass, device)])
|
||||
async_add_entities(
|
||||
[VoipPipelineSelect(hass, device), VoipVadSensitivitySelect(hass, device)]
|
||||
)
|
||||
|
||||
domain_data.devices.async_add_new_device_listener(async_add_device)
|
||||
|
||||
async_add_entities(
|
||||
[VoipPipelineSelect(hass, device) for device in domain_data.devices]
|
||||
)
|
||||
entities: list[VoIPEntity] = []
|
||||
for device in domain_data.devices:
|
||||
entities.append(VoipPipelineSelect(hass, device))
|
||||
entities.append(VoipVadSensitivitySelect(hass, device))
|
||||
|
||||
async_add_entities(entities)
|
||||
|
||||
|
||||
class VoipPipelineSelect(VoIPEntity, AssistPipelineSelect):
|
||||
@ -44,3 +52,12 @@ class VoipPipelineSelect(VoIPEntity, AssistPipelineSelect):
|
||||
"""Initialize a pipeline selector."""
|
||||
VoIPEntity.__init__(self, device)
|
||||
AssistPipelineSelect.__init__(self, hass, device.voip_id)
|
||||
|
||||
|
||||
class VoipVadSensitivitySelect(VoIPEntity, VadSensitivitySelect):
|
||||
"""VAD sensitivity selector for VoIP devices."""
|
||||
|
||||
def __init__(self, hass: HomeAssistant, device: VoIPDevice) -> None:
|
||||
"""Initialize a VAD sensitivity selector."""
|
||||
VoIPEntity.__init__(self, device)
|
||||
VadSensitivitySelect.__init__(self, hass, device.voip_id)
|
||||
|
@ -26,6 +26,14 @@
|
||||
"state": {
|
||||
"preferred": "[%key:component::assist_pipeline::entity::select::pipeline::state::preferred%]"
|
||||
}
|
||||
},
|
||||
"vad_sensitivity": {
|
||||
"name": "[%key:component::assist_pipeline::entity::select::vad_sensitivity::name%]",
|
||||
"state": {
|
||||
"default": "[%key:component::assist_pipeline::entity::select::vad_sensitivity::state::default%]",
|
||||
"aggressive": "[%key:component::assist_pipeline::entity::select::vad_sensitivity::state::aggressive%]",
|
||||
"relaxed": "[%key:component::assist_pipeline::entity::select::vad_sensitivity::state::relaxed%]"
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
|
@ -23,12 +23,21 @@ from homeassistant.components.assist_pipeline import (
|
||||
async_pipeline_from_audio_stream,
|
||||
select as pipeline_select,
|
||||
)
|
||||
from homeassistant.components.assist_pipeline.vad import VoiceCommandSegmenter
|
||||
from homeassistant.components.assist_pipeline.vad import (
|
||||
VadSensitivity,
|
||||
VoiceCommandSegmenter,
|
||||
)
|
||||
from homeassistant.const import __version__
|
||||
from homeassistant.core import Context, HomeAssistant
|
||||
from homeassistant.util.ulid import ulid
|
||||
|
||||
from .const import CHANNELS, DOMAIN, RATE, RTP_AUDIO_SETTINGS, WIDTH
|
||||
from .const import (
|
||||
CHANNELS,
|
||||
DOMAIN,
|
||||
RATE,
|
||||
RTP_AUDIO_SETTINGS,
|
||||
WIDTH,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .devices import VoIPDevice, VoIPDevices
|
||||
@ -63,6 +72,12 @@ def make_protocol(
|
||||
opus_payload_type=call_info.opus_payload_type,
|
||||
)
|
||||
|
||||
vad_sensitivity = pipeline_select.get_vad_sensitivity(
|
||||
hass,
|
||||
DOMAIN,
|
||||
voip_device.voip_id,
|
||||
)
|
||||
|
||||
# Pipeline is properly configured
|
||||
return PipelineRtpDatagramProtocol(
|
||||
hass,
|
||||
@ -70,6 +85,7 @@ def make_protocol(
|
||||
voip_device,
|
||||
Context(user_id=devices.config_entry.data["user"]),
|
||||
opus_payload_type=call_info.opus_payload_type,
|
||||
silence_seconds=VadSensitivity.to_seconds(vad_sensitivity),
|
||||
)
|
||||
|
||||
|
||||
@ -130,6 +146,7 @@ class PipelineRtpDatagramProtocol(RtpDatagramProtocol):
|
||||
error_tone_enabled: bool = True,
|
||||
tone_delay: float = 0.2,
|
||||
tts_extra_timeout: float = 1.0,
|
||||
silence_seconds: float = 1.0,
|
||||
) -> None:
|
||||
"""Set up pipeline RTP server."""
|
||||
super().__init__(
|
||||
@ -151,6 +168,7 @@ class PipelineRtpDatagramProtocol(RtpDatagramProtocol):
|
||||
self.error_tone_enabled = error_tone_enabled
|
||||
self.tone_delay = tone_delay
|
||||
self.tts_extra_timeout = tts_extra_timeout
|
||||
self.silence_seconds = silence_seconds
|
||||
|
||||
self._audio_queue: asyncio.Queue[bytes] = asyncio.Queue()
|
||||
self._context = context
|
||||
@ -199,7 +217,7 @@ class PipelineRtpDatagramProtocol(RtpDatagramProtocol):
|
||||
|
||||
try:
|
||||
# Wait for speech before starting pipeline
|
||||
segmenter = VoiceCommandSegmenter()
|
||||
segmenter = VoiceCommandSegmenter(silence_seconds=self.silence_seconds)
|
||||
chunk_buffer: deque[bytes] = deque(
|
||||
maxlen=self.buffered_chunks_before_speech,
|
||||
)
|
||||
|
@ -9,7 +9,11 @@ from homeassistant.components.assist_pipeline.pipeline import (
|
||||
PipelineData,
|
||||
PipelineStorageCollection,
|
||||
)
|
||||
from homeassistant.components.assist_pipeline.select import AssistPipelineSelect
|
||||
from homeassistant.components.assist_pipeline.select import (
|
||||
AssistPipelineSelect,
|
||||
VadSensitivitySelect,
|
||||
)
|
||||
from homeassistant.components.assist_pipeline.vad import VadSensitivity
|
||||
from homeassistant.config_entries import ConfigEntry
|
||||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.helpers import device_registry as dr
|
||||
@ -30,11 +34,15 @@ class SelectPlatform(MockPlatform):
|
||||
async_add_entities: AddEntitiesCallback,
|
||||
) -> None:
|
||||
"""Set up fake select platform."""
|
||||
entity = AssistPipelineSelect(hass, "test")
|
||||
entity._attr_device_info = DeviceInfo(
|
||||
pipeline_entity = AssistPipelineSelect(hass, "test")
|
||||
pipeline_entity._attr_device_info = DeviceInfo(
|
||||
identifiers={("test", "test")},
|
||||
)
|
||||
async_add_entities([entity])
|
||||
sensitivity_entity = VadSensitivitySelect(hass, "test")
|
||||
sensitivity_entity._attr_device_info = DeviceInfo(
|
||||
identifiers={("test", "test")},
|
||||
)
|
||||
async_add_entities([pipeline_entity, sensitivity_entity])
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@ -95,6 +103,7 @@ async def test_select_entity_registering_device(
|
||||
"""Test entity registering as an assist device."""
|
||||
dev_reg = dr.async_get(hass)
|
||||
device = dev_reg.async_get_device({("test", "test")})
|
||||
assert device is not None
|
||||
|
||||
# Test device is registered
|
||||
assert pipeline_data.pipeline_devices == {device.id}
|
||||
@ -138,6 +147,7 @@ async def test_select_entity_changing_pipelines(
|
||||
)
|
||||
|
||||
state = hass.states.get("select.assist_pipeline_test_pipeline")
|
||||
assert state is not None
|
||||
assert state.state == pipeline_2.name
|
||||
|
||||
# Reload config entry to test selected option persists
|
||||
@ -145,15 +155,52 @@ async def test_select_entity_changing_pipelines(
|
||||
assert await hass.config_entries.async_forward_entry_setup(config_entry, "select")
|
||||
|
||||
state = hass.states.get("select.assist_pipeline_test_pipeline")
|
||||
assert state is not None
|
||||
assert state.state == pipeline_2.name
|
||||
|
||||
# Remove selected pipeline
|
||||
await pipeline_storage.async_delete_item(pipeline_2.id)
|
||||
|
||||
state = hass.states.get("select.assist_pipeline_test_pipeline")
|
||||
assert state is not None
|
||||
assert state.state == "preferred"
|
||||
assert state.attributes["options"] == [
|
||||
"preferred",
|
||||
"Home Assistant",
|
||||
pipeline_1.name,
|
||||
]
|
||||
|
||||
|
||||
async def test_select_entity_changing_vad_sensitivity(
|
||||
hass: HomeAssistant,
|
||||
init_select: ConfigEntry,
|
||||
) -> None:
|
||||
"""Test entity tracking pipeline changes."""
|
||||
config_entry = init_select # nicer naming
|
||||
|
||||
state = hass.states.get("select.assist_pipeline_test_vad_sensitivity")
|
||||
assert state is not None
|
||||
assert state.state == VadSensitivity.DEFAULT.value
|
||||
|
||||
# Change select to new pipeline
|
||||
await hass.services.async_call(
|
||||
"select",
|
||||
"select_option",
|
||||
{
|
||||
"entity_id": "select.assist_pipeline_test_vad_sensitivity",
|
||||
"option": VadSensitivity.AGGRESSIVE.value,
|
||||
},
|
||||
blocking=True,
|
||||
)
|
||||
|
||||
state = hass.states.get("select.assist_pipeline_test_vad_sensitivity")
|
||||
assert state is not None
|
||||
assert state.state == VadSensitivity.AGGRESSIVE.value
|
||||
|
||||
# Reload config entry to test selected option persists
|
||||
assert await hass.config_entries.async_forward_entry_unload(config_entry, "select")
|
||||
assert await hass.config_entries.async_forward_entry_setup(config_entry, "select")
|
||||
|
||||
state = hass.states.get("select.assist_pipeline_test_vad_sensitivity")
|
||||
assert state is not None
|
||||
assert state.state == VadSensitivity.AGGRESSIVE.value
|
||||
|
@ -17,3 +17,18 @@ async def test_pipeline_select(
|
||||
state = hass.states.get("select.192_168_1_210_assist_pipeline")
|
||||
assert state is not None
|
||||
assert state.state == "preferred"
|
||||
|
||||
|
||||
async def test_vad_sensitivity_select(
|
||||
hass: HomeAssistant,
|
||||
config_entry: ConfigEntry,
|
||||
voip_device: VoIPDevice,
|
||||
) -> None:
|
||||
"""Test VAD sensitivity select.
|
||||
|
||||
Functionality is tested in assist_pipeline/test_select.py.
|
||||
This test is only to ensure it is set up.
|
||||
"""
|
||||
state = hass.states.get("select.192_168_1_210_silence_sensitivity")
|
||||
assert state is not None
|
||||
assert state.state == "default"
|
||||
|
@ -95,6 +95,7 @@ async def test_pipeline(
|
||||
listening_tone_enabled=False,
|
||||
processing_tone_enabled=False,
|
||||
error_tone_enabled=False,
|
||||
silence_seconds=assist_pipeline.vad.VadSensitivity.to_seconds("aggressive"),
|
||||
)
|
||||
rtp_protocol.transport = Mock()
|
||||
|
||||
@ -113,7 +114,7 @@ async def test_pipeline(
|
||||
# "speech"
|
||||
rtp_protocol.on_chunk(bytes([255] * _ONE_SECOND * 2))
|
||||
|
||||
# silence
|
||||
# silence (assumes aggressive VAD sensitivity)
|
||||
rtp_protocol.on_chunk(bytes(_ONE_SECOND))
|
||||
|
||||
# Wait for mock pipeline to exhaust the audio stream
|
||||
@ -288,6 +289,7 @@ async def test_tts_timeout(
|
||||
listening_tone_enabled=True,
|
||||
processing_tone_enabled=True,
|
||||
error_tone_enabled=True,
|
||||
silence_seconds=assist_pipeline.vad.VadSensitivity.to_seconds("relaxed"),
|
||||
)
|
||||
rtp_protocol._tone_bytes = tone_bytes
|
||||
rtp_protocol._processing_bytes = tone_bytes
|
||||
@ -313,8 +315,8 @@ async def test_tts_timeout(
|
||||
# "speech"
|
||||
rtp_protocol.on_chunk(bytes([255] * _ONE_SECOND * 2))
|
||||
|
||||
# silence
|
||||
rtp_protocol.on_chunk(bytes(_ONE_SECOND))
|
||||
# silence (assumes relaxed VAD sensitivity)
|
||||
rtp_protocol.on_chunk(bytes(_ONE_SECOND * 4))
|
||||
|
||||
# Wait for mock pipeline to exhaust the audio stream
|
||||
async with async_timeout.timeout(1):
|
||||
|
Loading…
x
Reference in New Issue
Block a user