From 1aa8e942240629f1629da183c0ce43b1068f356b Mon Sep 17 00:00:00 2001
From: Paulus Schoutsen
Date: Mon, 10 Apr 2023 19:28:03 -0400
Subject: [PATCH] Voice Assistant: Require sample rate as input (#91182)

Require sample rate as input

The "voice_assistant/run" websocket command is now registered through
websocket_api.async_register_command with an explicit schema instead of
the @websocket_api.websocket_command decorator. The schema validates
"input" according to "start_stage": the STT stage requires an integer
"sample_rate", while the intent and TTS stages accept a "text" string.

The STT audio stream now resamples from the sample rate supplied in the
message to 16000 rather than from a hard-coded 44100. The audio pipeline
tests send "sample_rate": 44100 in their "input".
---
 .../voice_assistant/websocket_api.py          | 59 ++++++++++++++-----
 .../voice_assistant/test_websocket.py         | 12 ++++
 2 files changed, 55 insertions(+), 16 deletions(-)

diff --git a/homeassistant/components/voice_assistant/websocket_api.py b/homeassistant/components/voice_assistant/websocket_api.py
index 42c22bfbed5..de115cf7596 100644
--- a/homeassistant/components/voice_assistant/websocket_api.py
+++ b/homeassistant/components/voice_assistant/websocket_api.py
@@ -10,6 +10,7 @@ import voluptuous as vol
 
 from homeassistant.components import stt, websocket_api
 from homeassistant.core import HomeAssistant, callback
+from homeassistant.helpers import config_validation as cv
 
 from .pipeline import (
     PipelineError,
@@ -30,23 +31,46 @@ _LOGGER = logging.getLogger(__name__)
 @callback
 def async_register_websocket_api(hass: HomeAssistant) -> None:
     """Register the websocket API."""
-    websocket_api.async_register_command(hass, websocket_run)
+    websocket_api.async_register_command(
+        hass,
+        "voice_assistant/run",
+        websocket_run,
+        vol.All(
+            websocket_api.BASE_COMMAND_MESSAGE_SCHEMA.extend(
+                {
+                    vol.Required("type"): "voice_assistant/run",
+                    # pylint: disable-next=unnecessary-lambda
+                    vol.Required("start_stage"): lambda val: PipelineStage(val),
+                    # pylint: disable-next=unnecessary-lambda
+                    vol.Required("end_stage"): lambda val: PipelineStage(val),
+                    vol.Optional("input"): dict,
+                    vol.Optional("language"): str,
+                    vol.Optional("pipeline"): str,
+                    vol.Optional("conversation_id"): vol.Any(str, None),
+                    vol.Optional("timeout"): vol.Any(float, int),
+                },
+            ),
+            cv.key_value_schemas(
+                "start_stage",
+                {
+                    PipelineStage.STT: vol.Schema(
+                        {vol.Required("input"): {vol.Required("sample_rate"): int}},
+                        extra=vol.ALLOW_EXTRA,
+                    ),
+                    PipelineStage.INTENT: vol.Schema(
+                        {vol.Required("input"): {"text": str}},
+                        extra=vol.ALLOW_EXTRA,
+                    ),
+                    PipelineStage.TTS: vol.Schema(
+                        {vol.Required("input"): {"text": str}},
+                        extra=vol.ALLOW_EXTRA,
+                    ),
+                },
+            ),
+        ),
+    )
 
 
-@websocket_api.websocket_command(
-    {
-        vol.Required("type"): "voice_assistant/run",
-        # pylint: disable-next=unnecessary-lambda
-        vol.Required("start_stage"): lambda val: PipelineStage(val),
-        # pylint: disable-next=unnecessary-lambda
-        vol.Required("end_stage"): lambda val: PipelineStage(val),
-        vol.Optional("input"): {"text": str},
-        vol.Optional("language"): str,
-        vol.Optional("pipeline"): str,
-        vol.Optional("conversation_id"): vol.Any(str, None),
-        vol.Optional("timeout"): vol.Any(float, int),
-    }
-)
 @websocket_api.async_response
 async def websocket_run(
     hass: HomeAssistant,
@@ -88,6 +112,7 @@ async def websocket_run(
     if start_stage == PipelineStage.STT:
         # Audio pipeline that will receive audio as binary websocket messages
         audio_queue: "asyncio.Queue[bytes]" = asyncio.Queue()
+        incoming_sample_rate = msg["input"]["sample_rate"]
 
         async def stt_stream():
             state = None
@@ -95,7 +120,9 @@ async def websocket_run(
 
             # Yield until we receive an empty chunk
             while chunk := await audio_queue.get():
-                chunk, state = audioop.ratecv(chunk, 2, 1, 44100, 16000, state)
+                chunk, state = audioop.ratecv(
+                    chunk, 2, 1, incoming_sample_rate, 16000, state
+                )
                 if not segmenter.process(chunk):
                     # Voice command is finished
                     break
diff --git a/tests/components/voice_assistant/test_websocket.py b/tests/components/voice_assistant/test_websocket.py
index 938184b607a..d5bd4d8ea1a 100644
--- a/tests/components/voice_assistant/test_websocket.py
+++ b/tests/components/voice_assistant/test_websocket.py
@@ -70,6 +70,9 @@ async def test_audio_pipeline(
             "type": "voice_assistant/run",
             "start_stage": "stt",
             "end_stage": "tts",
+            "input": {
+                "sample_rate": 44100,
+            },
         }
     )
 
@@ -263,6 +266,9 @@ async def test_audio_pipeline_timeout(
             "type": "voice_assistant/run",
             "start_stage": "stt",
             "end_stage": "tts",
+            "input": {
+                "sample_rate": 44100,
+            },
             "timeout": 0.0001,
         }
     )
@@ -295,6 +301,9 @@ async def test_stt_provider_missing(
             "type": "voice_assistant/run",
             "start_stage": "stt",
             "end_stage": "tts",
+            "input": {
+                "sample_rate": 44100,
+            },
         }
     )
 
@@ -322,6 +331,9 @@ async def test_stt_stream_failed(
             "type": "voice_assistant/run",
             "start_stage": "stt",
             "end_stage": "tts",
+            "input": {
+                "sample_rate": 44100,
+            },
         }
     )
 