mirror of
https://github.com/home-assistant/core.git
synced 2025-04-24 09:17:53 +00:00
Voice Assistant: Require sample rate as input (#91182)
Require sample rate as input
This commit is contained in:
parent
0fee49a32e
commit
1aa8e94224
@ -10,6 +10,7 @@ import voluptuous as vol
|
||||
|
||||
from homeassistant.components import stt, websocket_api
|
||||
from homeassistant.core import HomeAssistant, callback
|
||||
from homeassistant.helpers import config_validation as cv
|
||||
|
||||
from .pipeline import (
|
||||
PipelineError,
|
||||
@ -30,23 +31,46 @@ _LOGGER = logging.getLogger(__name__)
|
||||
@callback
|
||||
def async_register_websocket_api(hass: HomeAssistant) -> None:
|
||||
"""Register the websocket API."""
|
||||
websocket_api.async_register_command(hass, websocket_run)
|
||||
websocket_api.async_register_command(
|
||||
hass,
|
||||
"voice_assistant/run",
|
||||
websocket_run,
|
||||
vol.All(
|
||||
websocket_api.BASE_COMMAND_MESSAGE_SCHEMA.extend(
|
||||
{
|
||||
vol.Required("type"): "voice_assistant/run",
|
||||
# pylint: disable-next=unnecessary-lambda
|
||||
vol.Required("start_stage"): lambda val: PipelineStage(val),
|
||||
# pylint: disable-next=unnecessary-lambda
|
||||
vol.Required("end_stage"): lambda val: PipelineStage(val),
|
||||
vol.Optional("input"): dict,
|
||||
vol.Optional("language"): str,
|
||||
vol.Optional("pipeline"): str,
|
||||
vol.Optional("conversation_id"): vol.Any(str, None),
|
||||
vol.Optional("timeout"): vol.Any(float, int),
|
||||
},
|
||||
),
|
||||
cv.key_value_schemas(
|
||||
"start_stage",
|
||||
{
|
||||
PipelineStage.STT: vol.Schema(
|
||||
{vol.Required("input"): {vol.Required("sample_rate"): int}},
|
||||
extra=vol.ALLOW_EXTRA,
|
||||
),
|
||||
PipelineStage.INTENT: vol.Schema(
|
||||
{vol.Required("input"): {"text": str}},
|
||||
extra=vol.ALLOW_EXTRA,
|
||||
),
|
||||
PipelineStage.TTS: vol.Schema(
|
||||
{vol.Required("input"): {"text": str}},
|
||||
extra=vol.ALLOW_EXTRA,
|
||||
),
|
||||
},
|
||||
),
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@websocket_api.websocket_command(
|
||||
{
|
||||
vol.Required("type"): "voice_assistant/run",
|
||||
# pylint: disable-next=unnecessary-lambda
|
||||
vol.Required("start_stage"): lambda val: PipelineStage(val),
|
||||
# pylint: disable-next=unnecessary-lambda
|
||||
vol.Required("end_stage"): lambda val: PipelineStage(val),
|
||||
vol.Optional("input"): {"text": str},
|
||||
vol.Optional("language"): str,
|
||||
vol.Optional("pipeline"): str,
|
||||
vol.Optional("conversation_id"): vol.Any(str, None),
|
||||
vol.Optional("timeout"): vol.Any(float, int),
|
||||
}
|
||||
)
|
||||
@websocket_api.async_response
|
||||
async def websocket_run(
|
||||
hass: HomeAssistant,
|
||||
@ -88,6 +112,7 @@ async def websocket_run(
|
||||
if start_stage == PipelineStage.STT:
|
||||
# Audio pipeline that will receive audio as binary websocket messages
|
||||
audio_queue: "asyncio.Queue[bytes]" = asyncio.Queue()
|
||||
incoming_sample_rate = msg["input"]["sample_rate"]
|
||||
|
||||
async def stt_stream():
|
||||
state = None
|
||||
@ -95,7 +120,9 @@ async def websocket_run(
|
||||
|
||||
# Yield until we receive an empty chunk
|
||||
while chunk := await audio_queue.get():
|
||||
chunk, state = audioop.ratecv(chunk, 2, 1, 44100, 16000, state)
|
||||
chunk, state = audioop.ratecv(
|
||||
chunk, 2, 1, incoming_sample_rate, 16000, state
|
||||
)
|
||||
if not segmenter.process(chunk):
|
||||
# Voice command is finished
|
||||
break
|
||||
|
@ -70,6 +70,9 @@ async def test_audio_pipeline(
|
||||
"type": "voice_assistant/run",
|
||||
"start_stage": "stt",
|
||||
"end_stage": "tts",
|
||||
"input": {
|
||||
"sample_rate": 44100,
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
@ -263,6 +266,9 @@ async def test_audio_pipeline_timeout(
|
||||
"type": "voice_assistant/run",
|
||||
"start_stage": "stt",
|
||||
"end_stage": "tts",
|
||||
"input": {
|
||||
"sample_rate": 44100,
|
||||
},
|
||||
"timeout": 0.0001,
|
||||
}
|
||||
)
|
||||
@ -295,6 +301,9 @@ async def test_stt_provider_missing(
|
||||
"type": "voice_assistant/run",
|
||||
"start_stage": "stt",
|
||||
"end_stage": "tts",
|
||||
"input": {
|
||||
"sample_rate": 44100,
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
@ -322,6 +331,9 @@ async def test_stt_stream_failed(
|
||||
"type": "voice_assistant/run",
|
||||
"start_stage": "stt",
|
||||
"end_stage": "tts",
|
||||
"input": {
|
||||
"sample_rate": 44100,
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user