Voice Assistant: Require sample rate as input (#91182)

Require sample rate as input
This commit is contained in:
Paulus Schoutsen 2023-04-10 19:28:03 -04:00 committed by GitHub
parent 0fee49a32e
commit 1aa8e94224
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 55 additions and 16 deletions

View File

@ -10,6 +10,7 @@ import voluptuous as vol
from homeassistant.components import stt, websocket_api
from homeassistant.core import HomeAssistant, callback
from homeassistant.helpers import config_validation as cv
from .pipeline import (
PipelineError,
@ -30,23 +31,46 @@ _LOGGER = logging.getLogger(__name__)
# Registers the "voice_assistant/run" websocket command together with the
# schema used to validate incoming messages for it.
@callback
def async_register_websocket_api(hass: HomeAssistant) -> None:
"""Register the websocket API."""
# NOTE(review): this view has lost its +/- diff markers. The one-line call
# below appears to be the pre-change registration that the multi-line call
# after it replaces -- confirm against the real diff before reading both as live.
websocket_api.async_register_command(hass, websocket_run)
websocket_api.async_register_command(
hass,
"voice_assistant/run",
websocket_run,
# Two validators chained with vol.All: first the general message fields,
# then a check of "input" that depends on the requested start stage.
vol.All(
websocket_api.BASE_COMMAND_MESSAGE_SCHEMA.extend(
{
vol.Required("type"): "voice_assistant/run",
# The lambdas turn the raw value into a PipelineStage member, so the
# second validator below can match on PipelineStage keys.
# pylint: disable-next=unnecessary-lambda
vol.Required("start_stage"): lambda val: PipelineStage(val),
# pylint: disable-next=unnecessary-lambda
vol.Required("end_stage"): lambda val: PipelineStage(val),
# Only checked to be a dict here; its contents are validated per
# start stage by the second validator.
vol.Optional("input"): dict,
vol.Optional("language"): str,
vol.Optional("pipeline"): str,
vol.Optional("conversation_id"): vol.Any(str, None),
vol.Optional("timeout"): vol.Any(float, int),
},
),
# Picks the schema to apply from the value of "start_stage".
# NOTE(review): behaviour for a start stage not listed here is not visible
# in this excerpt -- check cv.key_value_schemas.
cv.key_value_schemas(
"start_stage",
{
# Audio input: the caller must state the sample rate of the audio
# it is going to send, as an integer.
PipelineStage.STT: vol.Schema(
{vol.Required("input"): {vol.Required("sample_rate"): int}},
# ALLOW_EXTRA lets the other top-level message keys pass through.
extra=vol.ALLOW_EXTRA,
),
# Text input: "input" itself is required; the "text" key inside it is
# not wrapped in vol.Required.
# NOTE(review): confirm that leaving "text" optional is intended.
PipelineStage.INTENT: vol.Schema(
{vol.Required("input"): {"text": str}},
extra=vol.ALLOW_EXTRA,
),
PipelineStage.TTS: vol.Schema(
{vol.Required("input"): {"text": str}},
extra=vol.ALLOW_EXTRA,
),
},
),
),
)
@websocket_api.websocket_command(
{
vol.Required("type"): "voice_assistant/run",
# pylint: disable-next=unnecessary-lambda
vol.Required("start_stage"): lambda val: PipelineStage(val),
# pylint: disable-next=unnecessary-lambda
vol.Required("end_stage"): lambda val: PipelineStage(val),
vol.Optional("input"): {"text": str},
vol.Optional("language"): str,
vol.Optional("pipeline"): str,
vol.Optional("conversation_id"): vol.Any(str, None),
vol.Optional("timeout"): vol.Any(float, int),
}
)
@websocket_api.async_response
async def websocket_run(
hass: HomeAssistant,
@ -88,6 +112,7 @@ async def websocket_run(
if start_stage == PipelineStage.STT:
# Audio pipeline that will receive audio as binary websocket messages
audio_queue: "asyncio.Queue[bytes]" = asyncio.Queue()
incoming_sample_rate = msg["input"]["sample_rate"]
async def stt_stream():
state = None
@ -95,7 +120,9 @@ async def websocket_run(
# Yield until we receive an empty chunk
while chunk := await audio_queue.get():
chunk, state = audioop.ratecv(chunk, 2, 1, 44100, 16000, state)
chunk, state = audioop.ratecv(
chunk, 2, 1, incoming_sample_rate, 16000, state
)
if not segmenter.process(chunk):
# Voice command is finished
break

View File

@ -70,6 +70,9 @@ async def test_audio_pipeline(
"type": "voice_assistant/run",
"start_stage": "stt",
"end_stage": "tts",
"input": {
"sample_rate": 44100,
},
}
)
@ -263,6 +266,9 @@ async def test_audio_pipeline_timeout(
"type": "voice_assistant/run",
"start_stage": "stt",
"end_stage": "tts",
"input": {
"sample_rate": 44100,
},
"timeout": 0.0001,
}
)
@ -295,6 +301,9 @@ async def test_stt_provider_missing(
"type": "voice_assistant/run",
"start_stage": "stt",
"end_stage": "tts",
"input": {
"sample_rate": 44100,
},
}
)
@ -322,6 +331,9 @@ async def test_stt_stream_failed(
"type": "voice_assistant/run",
"start_stage": "stt",
"end_stage": "tts",
"input": {
"sample_rate": 44100,
},
}
)