mirror of
https://github.com/home-assistant/core.git
synced 2025-07-12 15:57:06 +00:00
Voice Assistant: Require sample rate as input (#91182)
Require sample rate as input
This commit is contained in:
parent
0fee49a32e
commit
1aa8e94224
@ -10,6 +10,7 @@ import voluptuous as vol
|
|||||||
|
|
||||||
from homeassistant.components import stt, websocket_api
|
from homeassistant.components import stt, websocket_api
|
||||||
from homeassistant.core import HomeAssistant, callback
|
from homeassistant.core import HomeAssistant, callback
|
||||||
|
from homeassistant.helpers import config_validation as cv
|
||||||
|
|
||||||
from .pipeline import (
|
from .pipeline import (
|
||||||
PipelineError,
|
PipelineError,
|
||||||
@ -30,23 +31,46 @@ _LOGGER = logging.getLogger(__name__)
|
|||||||
@callback
|
@callback
|
||||||
def async_register_websocket_api(hass: HomeAssistant) -> None:
|
def async_register_websocket_api(hass: HomeAssistant) -> None:
|
||||||
"""Register the websocket API."""
|
"""Register the websocket API."""
|
||||||
websocket_api.async_register_command(hass, websocket_run)
|
websocket_api.async_register_command(
|
||||||
|
hass,
|
||||||
|
"voice_assistant/run",
|
||||||
|
websocket_run,
|
||||||
|
vol.All(
|
||||||
|
websocket_api.BASE_COMMAND_MESSAGE_SCHEMA.extend(
|
||||||
|
{
|
||||||
|
vol.Required("type"): "voice_assistant/run",
|
||||||
|
# pylint: disable-next=unnecessary-lambda
|
||||||
|
vol.Required("start_stage"): lambda val: PipelineStage(val),
|
||||||
|
# pylint: disable-next=unnecessary-lambda
|
||||||
|
vol.Required("end_stage"): lambda val: PipelineStage(val),
|
||||||
|
vol.Optional("input"): dict,
|
||||||
|
vol.Optional("language"): str,
|
||||||
|
vol.Optional("pipeline"): str,
|
||||||
|
vol.Optional("conversation_id"): vol.Any(str, None),
|
||||||
|
vol.Optional("timeout"): vol.Any(float, int),
|
||||||
|
},
|
||||||
|
),
|
||||||
|
cv.key_value_schemas(
|
||||||
|
"start_stage",
|
||||||
|
{
|
||||||
|
PipelineStage.STT: vol.Schema(
|
||||||
|
{vol.Required("input"): {vol.Required("sample_rate"): int}},
|
||||||
|
extra=vol.ALLOW_EXTRA,
|
||||||
|
),
|
||||||
|
PipelineStage.INTENT: vol.Schema(
|
||||||
|
{vol.Required("input"): {"text": str}},
|
||||||
|
extra=vol.ALLOW_EXTRA,
|
||||||
|
),
|
||||||
|
PipelineStage.TTS: vol.Schema(
|
||||||
|
{vol.Required("input"): {"text": str}},
|
||||||
|
extra=vol.ALLOW_EXTRA,
|
||||||
|
),
|
||||||
|
},
|
||||||
|
),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@websocket_api.websocket_command(
|
|
||||||
{
|
|
||||||
vol.Required("type"): "voice_assistant/run",
|
|
||||||
# pylint: disable-next=unnecessary-lambda
|
|
||||||
vol.Required("start_stage"): lambda val: PipelineStage(val),
|
|
||||||
# pylint: disable-next=unnecessary-lambda
|
|
||||||
vol.Required("end_stage"): lambda val: PipelineStage(val),
|
|
||||||
vol.Optional("input"): {"text": str},
|
|
||||||
vol.Optional("language"): str,
|
|
||||||
vol.Optional("pipeline"): str,
|
|
||||||
vol.Optional("conversation_id"): vol.Any(str, None),
|
|
||||||
vol.Optional("timeout"): vol.Any(float, int),
|
|
||||||
}
|
|
||||||
)
|
|
||||||
@websocket_api.async_response
|
@websocket_api.async_response
|
||||||
async def websocket_run(
|
async def websocket_run(
|
||||||
hass: HomeAssistant,
|
hass: HomeAssistant,
|
||||||
@ -88,6 +112,7 @@ async def websocket_run(
|
|||||||
if start_stage == PipelineStage.STT:
|
if start_stage == PipelineStage.STT:
|
||||||
# Audio pipeline that will receive audio as binary websocket messages
|
# Audio pipeline that will receive audio as binary websocket messages
|
||||||
audio_queue: "asyncio.Queue[bytes]" = asyncio.Queue()
|
audio_queue: "asyncio.Queue[bytes]" = asyncio.Queue()
|
||||||
|
incoming_sample_rate = msg["input"]["sample_rate"]
|
||||||
|
|
||||||
async def stt_stream():
|
async def stt_stream():
|
||||||
state = None
|
state = None
|
||||||
@ -95,7 +120,9 @@ async def websocket_run(
|
|||||||
|
|
||||||
# Yield until we receive an empty chunk
|
# Yield until we receive an empty chunk
|
||||||
while chunk := await audio_queue.get():
|
while chunk := await audio_queue.get():
|
||||||
chunk, state = audioop.ratecv(chunk, 2, 1, 44100, 16000, state)
|
chunk, state = audioop.ratecv(
|
||||||
|
chunk, 2, 1, incoming_sample_rate, 16000, state
|
||||||
|
)
|
||||||
if not segmenter.process(chunk):
|
if not segmenter.process(chunk):
|
||||||
# Voice command is finished
|
# Voice command is finished
|
||||||
break
|
break
|
||||||
|
@ -70,6 +70,9 @@ async def test_audio_pipeline(
|
|||||||
"type": "voice_assistant/run",
|
"type": "voice_assistant/run",
|
||||||
"start_stage": "stt",
|
"start_stage": "stt",
|
||||||
"end_stage": "tts",
|
"end_stage": "tts",
|
||||||
|
"input": {
|
||||||
|
"sample_rate": 44100,
|
||||||
|
},
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -263,6 +266,9 @@ async def test_audio_pipeline_timeout(
|
|||||||
"type": "voice_assistant/run",
|
"type": "voice_assistant/run",
|
||||||
"start_stage": "stt",
|
"start_stage": "stt",
|
||||||
"end_stage": "tts",
|
"end_stage": "tts",
|
||||||
|
"input": {
|
||||||
|
"sample_rate": 44100,
|
||||||
|
},
|
||||||
"timeout": 0.0001,
|
"timeout": 0.0001,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
@ -295,6 +301,9 @@ async def test_stt_provider_missing(
|
|||||||
"type": "voice_assistant/run",
|
"type": "voice_assistant/run",
|
||||||
"start_stage": "stt",
|
"start_stage": "stt",
|
||||||
"end_stage": "tts",
|
"end_stage": "tts",
|
||||||
|
"input": {
|
||||||
|
"sample_rate": 44100,
|
||||||
|
},
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -322,6 +331,9 @@ async def test_stt_stream_failed(
|
|||||||
"type": "voice_assistant/run",
|
"type": "voice_assistant/run",
|
||||||
"start_stage": "stt",
|
"start_stage": "stt",
|
||||||
"end_stage": "tts",
|
"end_stage": "tts",
|
||||||
|
"input": {
|
||||||
|
"sample_rate": 44100,
|
||||||
|
},
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user