mirror of
https://github.com/home-assistant/core.git
synced 2025-07-09 14:27:07 +00:00
Voice Assistant: improve error handling (#90541)
Co-authored-by: Michael Hansen <mike@rhasspy.org>
This commit is contained in:
parent
84eb9c5f97
commit
01a05340c6
@ -36,12 +36,20 @@ _LOGGER = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@callback
|
||||
def async_get_provider(hass: HomeAssistant, domain: str | None = None) -> Provider:
|
||||
def async_get_provider(
|
||||
hass: HomeAssistant, domain: str | None = None
|
||||
) -> Provider | None:
|
||||
"""Return provider."""
|
||||
if domain is None:
|
||||
domain = next(iter(hass.data[DOMAIN]))
|
||||
if domain:
|
||||
return hass.data[DOMAIN].get(domain)
|
||||
|
||||
return hass.data[DOMAIN][domain]
|
||||
if not hass.data[DOMAIN]:
|
||||
return None
|
||||
|
||||
if "cloud" in hass.data[DOMAIN]:
|
||||
return hass.data[DOMAIN]["cloud"]
|
||||
|
||||
return next(iter(hass.data[DOMAIN].values()))
|
||||
|
||||
|
||||
async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
|
||||
|
@ -8,7 +8,7 @@ import logging
|
||||
from typing import Any
|
||||
|
||||
from homeassistant.backports.enum import StrEnum
|
||||
from homeassistant.components import conversation, media_source, stt
|
||||
from homeassistant.components import conversation, media_source, stt, tts
|
||||
from homeassistant.components.tts.media_source import (
|
||||
generate_media_source_id as tts_generate_media_source_id,
|
||||
)
|
||||
@ -17,8 +17,6 @@ from homeassistant.util.dt import utcnow
|
||||
|
||||
from .const import DOMAIN
|
||||
|
||||
DEFAULT_TIMEOUT = 30 # seconds
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@ -151,6 +149,9 @@ class PipelineRun:
|
||||
event_callback: Callable[[PipelineEvent], None]
|
||||
language: str = None # type: ignore[assignment]
|
||||
runner_data: Any | None = None
|
||||
stt_provider: stt.Provider | None = None
|
||||
intent_agent: str | None = None
|
||||
tts_engine: str | None = None
|
||||
|
||||
def __post_init__(self):
|
||||
"""Set language for pipeline."""
|
||||
@ -181,13 +182,39 @@ class PipelineRun:
|
||||
)
|
||||
)
|
||||
|
||||
async def prepare_speech_to_text(self, metadata: stt.SpeechMetadata) -> None:
|
||||
"""Prepare speech to text."""
|
||||
stt_provider = stt.async_get_provider(self.hass, self.pipeline.stt_engine)
|
||||
|
||||
if stt_provider is None:
|
||||
engine = self.pipeline.stt_engine or "default"
|
||||
raise SpeechToTextError(
|
||||
code="stt-provider-missing",
|
||||
message=f"No speech to text provider for: {engine}",
|
||||
)
|
||||
|
||||
if not stt_provider.check_metadata(metadata):
|
||||
raise SpeechToTextError(
|
||||
code="stt-provider-unsupported-metadata",
|
||||
message=(
|
||||
f"Provider {engine} does not support input speech "
|
||||
"to text metadata"
|
||||
),
|
||||
)
|
||||
|
||||
self.stt_provider = stt_provider
|
||||
|
||||
async def speech_to_text(
|
||||
self,
|
||||
metadata: stt.SpeechMetadata,
|
||||
stream: AsyncIterable[bytes],
|
||||
) -> str:
|
||||
"""Run speech to text portion of pipeline. Returns the spoken text."""
|
||||
engine = self.pipeline.stt_engine or "default"
|
||||
if self.stt_provider is None:
|
||||
raise RuntimeError("Speech to text was not prepared")
|
||||
|
||||
engine = self.stt_provider.name
|
||||
|
||||
self.event_callback(
|
||||
PipelineEvent(
|
||||
PipelineEventType.STT_START,
|
||||
@ -198,28 +225,11 @@ class PipelineRun:
|
||||
)
|
||||
)
|
||||
|
||||
try:
|
||||
# Load provider
|
||||
stt_provider: stt.Provider = stt.async_get_provider(
|
||||
self.hass, self.pipeline.stt_engine
|
||||
)
|
||||
assert stt_provider is not None
|
||||
except Exception as src_error:
|
||||
_LOGGER.exception("No speech to text provider for %s", engine)
|
||||
raise SpeechToTextError(
|
||||
code="stt-provider-missing",
|
||||
message=f"No speech to text provider for: {engine}",
|
||||
) from src_error
|
||||
|
||||
if not stt_provider.check_metadata(metadata):
|
||||
raise SpeechToTextError(
|
||||
code="stt-provider-unsupported-metadata",
|
||||
message=f"Provider {engine} does not support input speech to text metadata",
|
||||
)
|
||||
|
||||
try:
|
||||
# Transcribe audio stream
|
||||
result = await stt_provider.async_process_audio_stream(metadata, stream)
|
||||
result = await self.stt_provider.async_process_audio_stream(
|
||||
metadata, stream
|
||||
)
|
||||
except Exception as src_error:
|
||||
_LOGGER.exception("Unexpected error during speech to text")
|
||||
raise SpeechToTextError(
|
||||
@ -253,15 +263,33 @@ class PipelineRun:
|
||||
|
||||
return result.text
|
||||
|
||||
async def prepare_recognize_intent(self) -> None:
|
||||
"""Prepare recognizing an intent."""
|
||||
agent_info = conversation.async_get_agent_info(
|
||||
self.hass, self.pipeline.conversation_engine
|
||||
)
|
||||
|
||||
if agent_info is None:
|
||||
engine = self.pipeline.conversation_engine or "default"
|
||||
raise IntentRecognitionError(
|
||||
code="intent-not-supported",
|
||||
message=f"Intent recognition engine {engine} is not found",
|
||||
)
|
||||
|
||||
self.intent_agent = agent_info["id"]
|
||||
|
||||
async def recognize_intent(
|
||||
self, intent_input: str, conversation_id: str | None
|
||||
) -> str:
|
||||
"""Run intent recognition portion of pipeline. Returns text to speak."""
|
||||
if self.intent_agent is None:
|
||||
raise RuntimeError("Recognize intent was not prepared")
|
||||
|
||||
self.event_callback(
|
||||
PipelineEvent(
|
||||
PipelineEventType.INTENT_START,
|
||||
{
|
||||
"engine": self.pipeline.conversation_engine or "default",
|
||||
"engine": self.intent_agent,
|
||||
"intent_input": intent_input,
|
||||
},
|
||||
)
|
||||
@ -274,7 +302,7 @@ class PipelineRun:
|
||||
conversation_id=conversation_id,
|
||||
context=self.context,
|
||||
language=self.language,
|
||||
agent_id=self.pipeline.conversation_engine,
|
||||
agent_id=self.intent_agent,
|
||||
)
|
||||
except Exception as src_error:
|
||||
_LOGGER.exception("Unexpected error during intent recognition")
|
||||
@ -296,13 +324,38 @@ class PipelineRun:
|
||||
|
||||
return speech
|
||||
|
||||
async def prepare_text_to_speech(self) -> None:
|
||||
"""Prepare text to speech."""
|
||||
engine = tts.async_resolve_engine(self.hass, self.pipeline.tts_engine)
|
||||
|
||||
if engine is None:
|
||||
engine = self.pipeline.tts_engine or "default"
|
||||
raise TextToSpeechError(
|
||||
code="tts-not-supported",
|
||||
message=f"Text to speech engine '{engine}' not found",
|
||||
)
|
||||
|
||||
if not await tts.async_support_options(self.hass, engine, self.language):
|
||||
raise TextToSpeechError(
|
||||
code="tts-not-supported",
|
||||
message=(
|
||||
f"Text to speech engine {engine} "
|
||||
f"does not support language {self.language}"
|
||||
),
|
||||
)
|
||||
|
||||
self.tts_engine = engine
|
||||
|
||||
async def text_to_speech(self, tts_input: str) -> str:
|
||||
"""Run text to speech portion of pipeline. Returns URL of TTS audio."""
|
||||
if self.tts_engine is None:
|
||||
raise RuntimeError("Text to speech was not prepared")
|
||||
|
||||
self.event_callback(
|
||||
PipelineEvent(
|
||||
PipelineEventType.TTS_START,
|
||||
{
|
||||
"engine": self.pipeline.tts_engine or "default",
|
||||
"engine": self.tts_engine,
|
||||
"tts_input": tts_input,
|
||||
},
|
||||
)
|
||||
@ -315,7 +368,8 @@ class PipelineRun:
|
||||
tts_generate_media_source_id(
|
||||
self.hass,
|
||||
tts_input,
|
||||
engine=self.pipeline.tts_engine,
|
||||
engine=self.tts_engine,
|
||||
language=self.language,
|
||||
),
|
||||
)
|
||||
except Exception as src_error:
|
||||
@ -341,6 +395,8 @@ class PipelineRun:
|
||||
class PipelineInput:
|
||||
"""Input to a pipeline run."""
|
||||
|
||||
run: PipelineRun
|
||||
|
||||
stt_metadata: stt.SpeechMetadata | None = None
|
||||
"""Metadata of stt input audio. Required when start_stage = stt."""
|
||||
|
||||
@ -355,21 +411,10 @@ class PipelineInput:
|
||||
|
||||
conversation_id: str | None = None
|
||||
|
||||
async def execute(
|
||||
self, run: PipelineRun, timeout: int | float | None = DEFAULT_TIMEOUT
|
||||
):
|
||||
"""Run pipeline with optional timeout."""
|
||||
await asyncio.wait_for(
|
||||
self._execute(run),
|
||||
timeout=timeout,
|
||||
)
|
||||
|
||||
async def _execute(self, run: PipelineRun):
|
||||
self._validate(run.start_stage)
|
||||
|
||||
# stt -> intent -> tts
|
||||
run.start()
|
||||
current_stage = run.start_stage
|
||||
async def execute(self):
|
||||
"""Run pipeline."""
|
||||
self.run.start()
|
||||
current_stage = self.run.start_stage
|
||||
|
||||
try:
|
||||
# Speech to text
|
||||
@ -377,29 +422,29 @@ class PipelineInput:
|
||||
if current_stage == PipelineStage.STT:
|
||||
assert self.stt_metadata is not None
|
||||
assert self.stt_stream is not None
|
||||
intent_input = await run.speech_to_text(
|
||||
intent_input = await self.run.speech_to_text(
|
||||
self.stt_metadata,
|
||||
self.stt_stream,
|
||||
)
|
||||
current_stage = PipelineStage.INTENT
|
||||
|
||||
if run.end_stage != PipelineStage.STT:
|
||||
if self.run.end_stage != PipelineStage.STT:
|
||||
tts_input = self.tts_input
|
||||
|
||||
if current_stage == PipelineStage.INTENT:
|
||||
assert intent_input is not None
|
||||
tts_input = await run.recognize_intent(
|
||||
tts_input = await self.run.recognize_intent(
|
||||
intent_input, self.conversation_id
|
||||
)
|
||||
current_stage = PipelineStage.TTS
|
||||
|
||||
if run.end_stage != PipelineStage.INTENT:
|
||||
if self.run.end_stage != PipelineStage.INTENT:
|
||||
if current_stage == PipelineStage.TTS:
|
||||
assert tts_input is not None
|
||||
await run.text_to_speech(tts_input)
|
||||
await self.run.text_to_speech(tts_input)
|
||||
|
||||
except PipelineError as err:
|
||||
run.event_callback(
|
||||
self.run.event_callback(
|
||||
PipelineEvent(
|
||||
PipelineEventType.ERROR,
|
||||
{"code": err.code, "message": err.message},
|
||||
@ -407,11 +452,11 @@ class PipelineInput:
|
||||
)
|
||||
return
|
||||
|
||||
run.end()
|
||||
self.run.end()
|
||||
|
||||
def _validate(self, stage: PipelineStage):
|
||||
async def validate(self):
|
||||
"""Validate pipeline input against start stage."""
|
||||
if stage == PipelineStage.STT:
|
||||
if self.run.start_stage == PipelineStage.STT:
|
||||
if self.stt_metadata is None:
|
||||
raise PipelineRunValidationError(
|
||||
"stt_metadata is required for speech to text"
|
||||
@ -421,13 +466,29 @@ class PipelineInput:
|
||||
raise PipelineRunValidationError(
|
||||
"stt_stream is required for speech to text"
|
||||
)
|
||||
elif stage == PipelineStage.INTENT:
|
||||
elif self.run.start_stage == PipelineStage.INTENT:
|
||||
if self.intent_input is None:
|
||||
raise PipelineRunValidationError(
|
||||
"intent_input is required for intent recognition"
|
||||
)
|
||||
elif stage == PipelineStage.TTS:
|
||||
elif self.run.start_stage == PipelineStage.TTS:
|
||||
if self.tts_input is None:
|
||||
raise PipelineRunValidationError(
|
||||
"tts_input is required for text to speech"
|
||||
)
|
||||
|
||||
start_stage_index = PIPELINE_STAGE_ORDER.index(self.run.start_stage)
|
||||
|
||||
prepare_tasks = []
|
||||
|
||||
if start_stage_index <= PIPELINE_STAGE_ORDER.index(PipelineStage.STT):
|
||||
prepare_tasks.append(self.run.prepare_speech_to_text(self.stt_metadata))
|
||||
|
||||
if start_stage_index <= PIPELINE_STAGE_ORDER.index(PipelineStage.INTENT):
|
||||
prepare_tasks.append(self.run.prepare_recognize_intent())
|
||||
|
||||
if start_stage_index <= PIPELINE_STAGE_ORDER.index(PipelineStage.TTS):
|
||||
prepare_tasks.append(self.run.prepare_text_to_speech())
|
||||
|
||||
if prepare_tasks:
|
||||
await asyncio.gather(*prepare_tasks)
|
||||
|
@ -5,13 +5,13 @@ from collections.abc import Callable
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
import async_timeout
|
||||
import voluptuous as vol
|
||||
|
||||
from homeassistant.components import stt, websocket_api
|
||||
from homeassistant.core import HomeAssistant, callback
|
||||
|
||||
from .pipeline import (
|
||||
DEFAULT_TIMEOUT,
|
||||
PipelineError,
|
||||
PipelineEvent,
|
||||
PipelineEventType,
|
||||
@ -21,6 +21,8 @@ from .pipeline import (
|
||||
async_get_pipeline,
|
||||
)
|
||||
|
||||
DEFAULT_TIMEOUT = 30
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
_VAD_ENERGY_THRESHOLD = 1000
|
||||
@ -155,37 +157,40 @@ async def websocket_run(
|
||||
# Input to text to speech system
|
||||
input_args["tts_input"] = msg["input"]["text"]
|
||||
|
||||
run_task = hass.async_create_task(
|
||||
PipelineInput(**input_args).execute(
|
||||
PipelineRun(
|
||||
hass,
|
||||
context=connection.context(msg),
|
||||
pipeline=pipeline,
|
||||
start_stage=start_stage,
|
||||
end_stage=end_stage,
|
||||
event_callback=lambda event: connection.send_event(
|
||||
msg["id"], event.as_dict()
|
||||
),
|
||||
runner_data={
|
||||
"stt_binary_handler_id": handler_id,
|
||||
},
|
||||
),
|
||||
timeout=timeout,
|
||||
)
|
||||
input_args["run"] = PipelineRun(
|
||||
hass,
|
||||
context=connection.context(msg),
|
||||
pipeline=pipeline,
|
||||
start_stage=start_stage,
|
||||
end_stage=end_stage,
|
||||
event_callback=lambda event: connection.send_event(msg["id"], event.as_dict()),
|
||||
runner_data={
|
||||
"stt_binary_handler_id": handler_id,
|
||||
"timeout": timeout,
|
||||
},
|
||||
)
|
||||
|
||||
# Cancel pipeline if user unsubscribes
|
||||
connection.subscriptions[msg["id"]] = run_task.cancel
|
||||
pipeline_input = PipelineInput(**input_args)
|
||||
|
||||
try:
|
||||
await pipeline_input.validate()
|
||||
except PipelineError as error:
|
||||
# Report more specific error when possible
|
||||
connection.send_error(msg["id"], error.code, error.message)
|
||||
return
|
||||
|
||||
# Confirm subscription
|
||||
connection.send_result(msg["id"])
|
||||
|
||||
run_task = hass.async_create_task(pipeline_input.execute())
|
||||
|
||||
# Cancel pipeline if user unsubscribes
|
||||
connection.subscriptions[msg["id"]] = run_task.cancel
|
||||
|
||||
try:
|
||||
# Task contains a timeout
|
||||
await run_task
|
||||
except PipelineError as error:
|
||||
# Report more specific error when possible
|
||||
connection.send_error(msg["id"], error.code, error.message)
|
||||
async with async_timeout.timeout(timeout):
|
||||
await run_task
|
||||
except asyncio.TimeoutError:
|
||||
connection.send_event(
|
||||
msg["id"],
|
||||
|
@ -5,12 +5,13 @@
|
||||
'pipeline': 'en-US',
|
||||
'runner_data': dict({
|
||||
'stt_binary_handler_id': 1,
|
||||
'timeout': 30,
|
||||
}),
|
||||
})
|
||||
# ---
|
||||
# name: test_audio_pipeline.1
|
||||
dict({
|
||||
'engine': 'default',
|
||||
'engine': 'test',
|
||||
'metadata': dict({
|
||||
'bit_rate': 16,
|
||||
'channel': 1,
|
||||
@ -30,7 +31,7 @@
|
||||
# ---
|
||||
# name: test_audio_pipeline.3
|
||||
dict({
|
||||
'engine': 'default',
|
||||
'engine': 'homeassistant',
|
||||
'intent_input': 'test transcript',
|
||||
})
|
||||
# ---
|
||||
@ -58,7 +59,7 @@
|
||||
# ---
|
||||
# name: test_audio_pipeline.5
|
||||
dict({
|
||||
'engine': 'default',
|
||||
'engine': 'test',
|
||||
'tts_input': "Sorry, I couldn't understand that",
|
||||
})
|
||||
# ---
|
||||
@ -66,7 +67,7 @@
|
||||
dict({
|
||||
'tts_output': dict({
|
||||
'mime_type': 'audio/mpeg',
|
||||
'url': '/api/tts_proxy/dae2cdcb27a1d1c3b07ba2c7db91480f9d4bfd8f_en_-_test.mp3',
|
||||
'url': '/api/tts_proxy/dae2cdcb27a1d1c3b07ba2c7db91480f9d4bfd8f_en-us_-_test.mp3',
|
||||
}),
|
||||
})
|
||||
# ---
|
||||
@ -76,12 +77,13 @@
|
||||
'pipeline': 'en-US',
|
||||
'runner_data': dict({
|
||||
'stt_binary_handler_id': None,
|
||||
'timeout': 30,
|
||||
}),
|
||||
})
|
||||
# ---
|
||||
# name: test_intent_failed.1
|
||||
dict({
|
||||
'engine': 'default',
|
||||
'engine': 'homeassistant',
|
||||
'intent_input': 'Are the lights on?',
|
||||
})
|
||||
# ---
|
||||
@ -91,12 +93,13 @@
|
||||
'pipeline': 'en-US',
|
||||
'runner_data': dict({
|
||||
'stt_binary_handler_id': None,
|
||||
'timeout': 0.1,
|
||||
}),
|
||||
})
|
||||
# ---
|
||||
# name: test_intent_timeout.1
|
||||
dict({
|
||||
'engine': 'default',
|
||||
'engine': 'homeassistant',
|
||||
'intent_input': 'Are the lights on?',
|
||||
})
|
||||
# ---
|
||||
@ -112,6 +115,7 @@
|
||||
'pipeline': 'en-US',
|
||||
'runner_data': dict({
|
||||
'stt_binary_handler_id': 1,
|
||||
'timeout': 30,
|
||||
}),
|
||||
})
|
||||
# ---
|
||||
@ -134,12 +138,13 @@
|
||||
'pipeline': 'en-US',
|
||||
'runner_data': dict({
|
||||
'stt_binary_handler_id': 1,
|
||||
'timeout': 30,
|
||||
}),
|
||||
})
|
||||
# ---
|
||||
# name: test_stt_stream_failed.1
|
||||
dict({
|
||||
'engine': 'default',
|
||||
'engine': 'test',
|
||||
'metadata': dict({
|
||||
'bit_rate': 16,
|
||||
'channel': 1,
|
||||
@ -156,12 +161,13 @@
|
||||
'pipeline': 'en-US',
|
||||
'runner_data': dict({
|
||||
'stt_binary_handler_id': None,
|
||||
'timeout': 30,
|
||||
}),
|
||||
})
|
||||
# ---
|
||||
# name: test_text_only_pipeline.1
|
||||
dict({
|
||||
'engine': 'default',
|
||||
'engine': 'homeassistant',
|
||||
'intent_input': 'Are the lights on?',
|
||||
})
|
||||
# ---
|
||||
@ -199,12 +205,13 @@
|
||||
'pipeline': 'en-US',
|
||||
'runner_data': dict({
|
||||
'stt_binary_handler_id': None,
|
||||
'timeout': 30,
|
||||
}),
|
||||
})
|
||||
# ---
|
||||
# name: test_tts_failed.1
|
||||
dict({
|
||||
'engine': 'default',
|
||||
'engine': 'test',
|
||||
'tts_input': 'Lights are on.',
|
||||
})
|
||||
# ---
|
||||
|
@ -93,7 +93,7 @@ class MockTTSProvider(tts.Provider):
|
||||
@property
|
||||
def supported_languages(self) -> list[str]:
|
||||
"""Return list of supported languages."""
|
||||
return ["en"]
|
||||
return ["en-US"]
|
||||
|
||||
@property
|
||||
def supported_options(self) -> list[str]:
|
||||
@ -264,7 +264,7 @@ async def test_intent_timeout(
|
||||
"start_stage": "intent",
|
||||
"end_stage": "intent",
|
||||
"input": {"text": "Are the lights on?"},
|
||||
"timeout": 0.00001,
|
||||
"timeout": 0.1,
|
||||
}
|
||||
)
|
||||
|
||||
@ -301,7 +301,7 @@ async def test_text_pipeline_timeout(
|
||||
await asyncio.sleep(3600)
|
||||
|
||||
with patch(
|
||||
"homeassistant.components.voice_assistant.pipeline.PipelineInput._execute",
|
||||
"homeassistant.components.voice_assistant.pipeline.PipelineInput.execute",
|
||||
new=sleepy_run,
|
||||
):
|
||||
await client.send_json(
|
||||
@ -381,7 +381,7 @@ async def test_audio_pipeline_timeout(
|
||||
await asyncio.sleep(3600)
|
||||
|
||||
with patch(
|
||||
"homeassistant.components.voice_assistant.pipeline.PipelineInput._execute",
|
||||
"homeassistant.components.voice_assistant.pipeline.PipelineInput.execute",
|
||||
new=sleepy_run,
|
||||
):
|
||||
await client.send_json(
|
||||
@ -427,25 +427,8 @@ async def test_stt_provider_missing(
|
||||
|
||||
# result
|
||||
msg = await client.receive_json()
|
||||
assert msg["success"]
|
||||
|
||||
# run start
|
||||
msg = await client.receive_json()
|
||||
assert msg["event"]["type"] == "run-start"
|
||||
assert msg["event"]["data"] == snapshot
|
||||
|
||||
# stt
|
||||
msg = await client.receive_json()
|
||||
assert msg["event"]["type"] == "stt-start"
|
||||
assert msg["event"]["data"] == snapshot
|
||||
|
||||
# End of audio stream (handler id + empty payload)
|
||||
await client.send_bytes(b"1")
|
||||
|
||||
# stt error
|
||||
msg = await client.receive_json()
|
||||
assert msg["event"]["type"] == "error"
|
||||
assert msg["event"]["data"]["code"] == "stt-provider-missing"
|
||||
assert not msg["success"]
|
||||
assert msg["error"]["code"] == "stt-provider-missing"
|
||||
|
||||
|
||||
async def test_stt_stream_failed(
|
||||
|
Loading…
x
Reference in New Issue
Block a user