diff --git a/homeassistant/components/voice_assistant/manifest.json b/homeassistant/components/voice_assistant/manifest.json
index 644c49e9459..f4a17bf52e7 100644
--- a/homeassistant/components/voice_assistant/manifest.json
+++ b/homeassistant/components/voice_assistant/manifest.json
@@ -5,5 +5,6 @@
   "dependencies": ["conversation", "stt", "tts"],
   "documentation": "https://www.home-assistant.io/integrations/voice_assistant",
   "iot_class": "local_push",
-  "quality_scale": "internal"
+  "quality_scale": "internal",
+  "requirements": ["webrtcvad==2.0.10"]
 }
diff --git a/homeassistant/components/voice_assistant/vad.py b/homeassistant/components/voice_assistant/vad.py
new file mode 100644
index 00000000000..e86579b9750
--- /dev/null
+++ b/homeassistant/components/voice_assistant/vad.py
@@ -0,0 +1,128 @@
+"""Voice activity detection."""
+from dataclasses import dataclass, field
+
+import webrtcvad
+
+_SAMPLE_RATE = 16000
+
+
+@dataclass
+class VoiceCommandSegmenter:
+    """Segments an audio stream into voice commands using webrtcvad."""
+
+    vad_mode: int = 3
+    """Aggressiveness in filtering out non-speech. 3 is the most aggressive."""
+
+    vad_frames: int = 480  # 30 ms
+    """Must be 10, 20, or 30 ms at 16Khz."""
+
+    speech_seconds: float = 0.3
+    """Seconds of speech before voice command has started."""
+
+    silence_seconds: float = 0.5
+    """Seconds of silence after voice command has ended."""
+
+    timeout_seconds: float = 15.0
+    """Maximum number of seconds before stopping with timeout=True."""
+
+    reset_seconds: float = 1.0
+    """Seconds before resetting start/stop time counters."""
+
+    _in_command: bool = False
+    """True if inside voice command."""
+
+    _speech_seconds_left: float = 0.0
+    """Seconds left before considering voice command as started."""
+
+    _silence_seconds_left: float = 0.0
+    """Seconds left before considering voice command as stopped."""
+
+    _timeout_seconds_left: float = 0.0
+    """Seconds left before considering voice command timed out."""
+
+    _reset_seconds_left: float = 0.0
+    """Seconds left before resetting start/stop time counters."""
+
+    _vad: webrtcvad.Vad = None
+    _audio_buffer: bytes = field(default_factory=bytes)
+    _bytes_per_chunk: int = 480 * 2  # 16-bit samples
+    _seconds_per_chunk: float = 0.03  # 30 ms
+
+    def __post_init__(self):
+        """Initialize VAD."""
+        self._vad = webrtcvad.Vad(self.vad_mode)
+        self._bytes_per_chunk = self.vad_frames * 2
+        self._seconds_per_chunk = self.vad_frames / _SAMPLE_RATE
+        self.reset()
+
+    def reset(self):
+        """Reset all counters and state."""
+        self._audio_buffer = b""
+        self._speech_seconds_left = self.speech_seconds
+        self._silence_seconds_left = self.silence_seconds
+        self._timeout_seconds_left = self.timeout_seconds
+        self._reset_seconds_left = self.reset_seconds
+        self._in_command = False
+
+    def process(self, samples: bytes) -> bool:
+        """Process 16-bit 16Khz mono audio samples.
+
+        Returns False when command is done.
+        """
+        self._audio_buffer += samples
+
+        # Process in 10, 20, or 30 ms chunks.
+        num_chunks = len(self._audio_buffer) // self._bytes_per_chunk
+        for chunk_idx in range(num_chunks):
+            chunk_offset = chunk_idx * self._bytes_per_chunk
+            chunk = self._audio_buffer[
+                chunk_offset : chunk_offset + self._bytes_per_chunk
+            ]
+            if not self._process_chunk(chunk):
+                self.reset()
+                return False
+
+        if num_chunks > 0:
+            # Remove from buffer
+            self._audio_buffer = self._audio_buffer[
+                num_chunks * self._bytes_per_chunk :
+            ]
+
+        return True
+
+    def _process_chunk(self, chunk: bytes) -> bool:
+        """Process a single chunk of 16-bit 16Khz mono audio.
+
+        Returns False when command is done.
+        """
+        is_speech = self._vad.is_speech(chunk, _SAMPLE_RATE)
+
+        self._timeout_seconds_left -= self._seconds_per_chunk
+        if self._timeout_seconds_left <= 0:
+            return False
+
+        if not self._in_command:
+            if is_speech:
+                self._reset_seconds_left = self.reset_seconds
+                self._speech_seconds_left -= self._seconds_per_chunk
+                if self._speech_seconds_left <= 0:
+                    # Inside voice command
+                    self._in_command = True
+            else:
+                # Reset if enough silence
+                self._reset_seconds_left -= self._seconds_per_chunk
+                if self._reset_seconds_left <= 0:
+                    self._speech_seconds_left = self.speech_seconds
+        else:
+            if not is_speech:
+                self._reset_seconds_left = self.reset_seconds
+                self._silence_seconds_left -= self._seconds_per_chunk
+                if self._silence_seconds_left <= 0:
+                    return False
+            else:
+                # Reset if enough speech
+                self._reset_seconds_left -= self._seconds_per_chunk
+                if self._reset_seconds_left <= 0:
+                    self._silence_seconds_left = self.silence_seconds
+
+        return True
diff --git a/homeassistant/components/voice_assistant/websocket_api.py b/homeassistant/components/voice_assistant/websocket_api.py
index aa295ad5c62..718989f6613 100644
--- a/homeassistant/components/voice_assistant/websocket_api.py
+++ b/homeassistant/components/voice_assistant/websocket_api.py
@@ -20,15 +20,12 @@ from .pipeline import (
     PipelineStage,
     async_get_pipeline,
 )
+from .vad import VoiceCommandSegmenter
 
 DEFAULT_TIMEOUT = 30
 
 _LOGGER = logging.getLogger(__name__)
 
-_VAD_ENERGY_THRESHOLD = 1000
-_VAD_SPEECH_FRAMES = 25
-_VAD_SILENCE_FRAMES = 25
-
 
 @callback
 def async_register_websocket_api(hass: HomeAssistant) -> None:
@@ -36,17 +33,6 @@ def async_register_websocket_api(hass: HomeAssistant) -> None:
     websocket_api.async_register_command(hass, websocket_run)
 
 
-def _get_debiased_energy(audio_data: bytes, width: int = 2) -> float:
-    """Compute RMS of debiased audio."""
-    energy = -audioop.rms(audio_data, width)
-    energy_bytes = bytes([energy & 0xFF, (energy >> 8) & 0xFF])
-    debiased_energy = audioop.rms(
-        audioop.add(audio_data, energy_bytes * (len(audio_data) // width), width), width
-    )
-
-    return debiased_energy
-
-
 @websocket_api.websocket_command(
     {
         vol.Required("type"): "voice_assistant/run",
@@ -105,30 +91,14 @@ async def websocket_run(
 
         async def stt_stream():
             state = None
-            speech_count = 0
-            in_voice_command = False
+            segmenter = VoiceCommandSegmenter()
 
             # Yield until we receive an empty chunk
             while chunk := await audio_queue.get():
                 chunk, state = audioop.ratecv(chunk, 2, 1, 44100, 16000, state)
-                is_speech = _get_debiased_energy(chunk) > _VAD_ENERGY_THRESHOLD
-
-                if in_voice_command:
-                    if is_speech:
-                        speech_count += 1
-                    else:
-                        speech_count -= 1
-
-                        if speech_count <= -_VAD_SILENCE_FRAMES:
-                            _LOGGER.info("Voice command stopped")
-                            break
-                else:
-                    if is_speech:
-                        speech_count += 1
-
-                        if speech_count >= _VAD_SPEECH_FRAMES:
-                            in_voice_command = True
-                            _LOGGER.info("Voice command started")
+                if not segmenter.process(chunk):
+                    # Voice command is finished
+                    break
 
                 yield chunk
 
diff --git a/requirements_all.txt b/requirements_all.txt
index 935d7bdb69f..e01e1c23e89 100644
--- a/requirements_all.txt
+++ b/requirements_all.txt
@@ -2619,6 +2619,9 @@ waterfurnace==1.1.0
 # homeassistant.components.cisco_webex_teams
 webexteamssdk==1.1.1
 
+# homeassistant.components.voice_assistant
+webrtcvad==2.0.10
+
 # homeassistant.components.whirlpool
 whirlpool-sixth-sense==0.18.2
 
diff --git a/requirements_test_all.txt b/requirements_test_all.txt
index 27d71f3d569..89b84246ffd 100644
--- a/requirements_test_all.txt
+++ b/requirements_test_all.txt
@@ -1877,6 +1877,9 @@ wallbox==0.4.12
 # homeassistant.components.folder_watcher
 watchdog==2.3.1
 
+# homeassistant.components.voice_assistant
+webrtcvad==2.0.10
+
 # homeassistant.components.whirlpool
 whirlpool-sixth-sense==0.18.2
 
diff --git a/tests/components/voice_assistant/test_vad.py b/tests/components/voice_assistant/test_vad.py
new file mode 100644
index 00000000000..4285f78d51b
--- /dev/null
+++ b/tests/components/voice_assistant/test_vad.py
@@ -0,0 +1,38 @@
+"""Tests for webrtcvad voice command segmenter."""
+from unittest.mock import patch
+
+from homeassistant.components.voice_assistant.vad import VoiceCommandSegmenter
+
+_ONE_SECOND = 16000 * 2  # 16Khz 16-bit
+
+
+def test_silence() -> None:
+    """Test that 3 seconds of silence does not trigger a voice command."""
+    segmenter = VoiceCommandSegmenter()
+
+    # True return value indicates voice command has not finished
+    assert segmenter.process(bytes(_ONE_SECOND * 3))
+
+
+def test_speech() -> None:
+    """Test that silence + speech + silence triggers a voice command."""
+
+    def is_speech(self, chunk, sample_rate):
+        """Anything non-zero is speech."""
+        return sum(chunk) > 0
+
+    with patch(
+        "webrtcvad.Vad.is_speech",
+        new=is_speech,
+    ):
+        segmenter = VoiceCommandSegmenter()
+
+        # silence
+        assert segmenter.process(bytes(_ONE_SECOND))
+
+        # "speech"
+        assert segmenter.process(bytes([255] * _ONE_SECOND))
+
+        # silence
+        # False return value indicates voice command is finished
+        assert not segmenter.process(bytes(_ONE_SECOND))
diff --git a/tests/components/voice_assistant/test_websocket.py b/tests/components/voice_assistant/test_websocket.py
index ce876550327..54fe51a7a22 100644
--- a/tests/components/voice_assistant/test_websocket.py
+++ b/tests/components/voice_assistant/test_websocket.py
@@ -75,7 +75,7 @@ class MockSTT:
     hass: HomeAssistant,
     config: ConfigType,
     discovery_info: DiscoveryInfoType | None = None,
-) -> tts.Provider:
+) -> stt.Provider:
     """Set up a mock speech component."""
     return MockSttProvider(hass, _TRANSCRIPT)
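
For reference, a minimal usage sketch of the new VoiceCommandSegmenter, mirroring how stt_stream consumes it above. This is not part of the diff; stream_voice_command and audio_chunks are hypothetical names, and the input is assumed to be raw 16-bit, 16 kHz mono PCM bytes.

from homeassistant.components.voice_assistant.vad import VoiceCommandSegmenter


def stream_voice_command(audio_chunks):
    """Yield PCM chunks until the segmenter reports the voice command is done.

    audio_chunks: hypothetical iterable of 16-bit, 16 kHz mono PCM byte chunks.
    Chunk size does not need to match the VAD frame size; process() buffers
    internally and evaluates complete 30 ms frames.
    """
    segmenter = VoiceCommandSegmenter()
    for chunk in audio_chunks:
        if not segmenter.process(chunk):
            # Enough silence after speech (or a timeout) ends the voice command
            break
        yield chunk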