mirror of
https://github.com/home-assistant/core.git
synced 2025-07-10 14:57:09 +00:00
Voice Assistant: improve error handling (#90541)
Co-authored-by: Michael Hansen <mike@rhasspy.org>
This commit is contained in:
parent
84eb9c5f97
commit
01a05340c6
@ -36,12 +36,20 @@ _LOGGER = logging.getLogger(__name__)
|
|||||||
|
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
def async_get_provider(hass: HomeAssistant, domain: str | None = None) -> Provider:
|
def async_get_provider(
|
||||||
|
hass: HomeAssistant, domain: str | None = None
|
||||||
|
) -> Provider | None:
|
||||||
"""Return provider."""
|
"""Return provider."""
|
||||||
if domain is None:
|
if domain:
|
||||||
domain = next(iter(hass.data[DOMAIN]))
|
return hass.data[DOMAIN].get(domain)
|
||||||
|
|
||||||
return hass.data[DOMAIN][domain]
|
if not hass.data[DOMAIN]:
|
||||||
|
return None
|
||||||
|
|
||||||
|
if "cloud" in hass.data[DOMAIN]:
|
||||||
|
return hass.data[DOMAIN]["cloud"]
|
||||||
|
|
||||||
|
return next(iter(hass.data[DOMAIN].values()))
|
||||||
|
|
||||||
|
|
||||||
async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
|
async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
|
||||||
|
@ -8,7 +8,7 @@ import logging
|
|||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from homeassistant.backports.enum import StrEnum
|
from homeassistant.backports.enum import StrEnum
|
||||||
from homeassistant.components import conversation, media_source, stt
|
from homeassistant.components import conversation, media_source, stt, tts
|
||||||
from homeassistant.components.tts.media_source import (
|
from homeassistant.components.tts.media_source import (
|
||||||
generate_media_source_id as tts_generate_media_source_id,
|
generate_media_source_id as tts_generate_media_source_id,
|
||||||
)
|
)
|
||||||
@ -17,8 +17,6 @@ from homeassistant.util.dt import utcnow
|
|||||||
|
|
||||||
from .const import DOMAIN
|
from .const import DOMAIN
|
||||||
|
|
||||||
DEFAULT_TIMEOUT = 30 # seconds
|
|
||||||
|
|
||||||
_LOGGER = logging.getLogger(__name__)
|
_LOGGER = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
@ -151,6 +149,9 @@ class PipelineRun:
|
|||||||
event_callback: Callable[[PipelineEvent], None]
|
event_callback: Callable[[PipelineEvent], None]
|
||||||
language: str = None # type: ignore[assignment]
|
language: str = None # type: ignore[assignment]
|
||||||
runner_data: Any | None = None
|
runner_data: Any | None = None
|
||||||
|
stt_provider: stt.Provider | None = None
|
||||||
|
intent_agent: str | None = None
|
||||||
|
tts_engine: str | None = None
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
"""Set language for pipeline."""
|
"""Set language for pipeline."""
|
||||||
@ -181,13 +182,39 @@ class PipelineRun:
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
async def prepare_speech_to_text(self, metadata: stt.SpeechMetadata) -> None:
|
||||||
|
"""Prepare speech to text."""
|
||||||
|
stt_provider = stt.async_get_provider(self.hass, self.pipeline.stt_engine)
|
||||||
|
|
||||||
|
if stt_provider is None:
|
||||||
|
engine = self.pipeline.stt_engine or "default"
|
||||||
|
raise SpeechToTextError(
|
||||||
|
code="stt-provider-missing",
|
||||||
|
message=f"No speech to text provider for: {engine}",
|
||||||
|
)
|
||||||
|
|
||||||
|
if not stt_provider.check_metadata(metadata):
|
||||||
|
raise SpeechToTextError(
|
||||||
|
code="stt-provider-unsupported-metadata",
|
||||||
|
message=(
|
||||||
|
f"Provider {engine} does not support input speech "
|
||||||
|
"to text metadata"
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
self.stt_provider = stt_provider
|
||||||
|
|
||||||
async def speech_to_text(
|
async def speech_to_text(
|
||||||
self,
|
self,
|
||||||
metadata: stt.SpeechMetadata,
|
metadata: stt.SpeechMetadata,
|
||||||
stream: AsyncIterable[bytes],
|
stream: AsyncIterable[bytes],
|
||||||
) -> str:
|
) -> str:
|
||||||
"""Run speech to text portion of pipeline. Returns the spoken text."""
|
"""Run speech to text portion of pipeline. Returns the spoken text."""
|
||||||
engine = self.pipeline.stt_engine or "default"
|
if self.stt_provider is None:
|
||||||
|
raise RuntimeError("Speech to text was not prepared")
|
||||||
|
|
||||||
|
engine = self.stt_provider.name
|
||||||
|
|
||||||
self.event_callback(
|
self.event_callback(
|
||||||
PipelineEvent(
|
PipelineEvent(
|
||||||
PipelineEventType.STT_START,
|
PipelineEventType.STT_START,
|
||||||
@ -198,28 +225,11 @@ class PipelineRun:
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
try:
|
|
||||||
# Load provider
|
|
||||||
stt_provider: stt.Provider = stt.async_get_provider(
|
|
||||||
self.hass, self.pipeline.stt_engine
|
|
||||||
)
|
|
||||||
assert stt_provider is not None
|
|
||||||
except Exception as src_error:
|
|
||||||
_LOGGER.exception("No speech to text provider for %s", engine)
|
|
||||||
raise SpeechToTextError(
|
|
||||||
code="stt-provider-missing",
|
|
||||||
message=f"No speech to text provider for: {engine}",
|
|
||||||
) from src_error
|
|
||||||
|
|
||||||
if not stt_provider.check_metadata(metadata):
|
|
||||||
raise SpeechToTextError(
|
|
||||||
code="stt-provider-unsupported-metadata",
|
|
||||||
message=f"Provider {engine} does not support input speech to text metadata",
|
|
||||||
)
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Transcribe audio stream
|
# Transcribe audio stream
|
||||||
result = await stt_provider.async_process_audio_stream(metadata, stream)
|
result = await self.stt_provider.async_process_audio_stream(
|
||||||
|
metadata, stream
|
||||||
|
)
|
||||||
except Exception as src_error:
|
except Exception as src_error:
|
||||||
_LOGGER.exception("Unexpected error during speech to text")
|
_LOGGER.exception("Unexpected error during speech to text")
|
||||||
raise SpeechToTextError(
|
raise SpeechToTextError(
|
||||||
@ -253,15 +263,33 @@ class PipelineRun:
|
|||||||
|
|
||||||
return result.text
|
return result.text
|
||||||
|
|
||||||
|
async def prepare_recognize_intent(self) -> None:
|
||||||
|
"""Prepare recognizing an intent."""
|
||||||
|
agent_info = conversation.async_get_agent_info(
|
||||||
|
self.hass, self.pipeline.conversation_engine
|
||||||
|
)
|
||||||
|
|
||||||
|
if agent_info is None:
|
||||||
|
engine = self.pipeline.conversation_engine or "default"
|
||||||
|
raise IntentRecognitionError(
|
||||||
|
code="intent-not-supported",
|
||||||
|
message=f"Intent recognition engine {engine} is not found",
|
||||||
|
)
|
||||||
|
|
||||||
|
self.intent_agent = agent_info["id"]
|
||||||
|
|
||||||
async def recognize_intent(
|
async def recognize_intent(
|
||||||
self, intent_input: str, conversation_id: str | None
|
self, intent_input: str, conversation_id: str | None
|
||||||
) -> str:
|
) -> str:
|
||||||
"""Run intent recognition portion of pipeline. Returns text to speak."""
|
"""Run intent recognition portion of pipeline. Returns text to speak."""
|
||||||
|
if self.intent_agent is None:
|
||||||
|
raise RuntimeError("Recognize intent was not prepared")
|
||||||
|
|
||||||
self.event_callback(
|
self.event_callback(
|
||||||
PipelineEvent(
|
PipelineEvent(
|
||||||
PipelineEventType.INTENT_START,
|
PipelineEventType.INTENT_START,
|
||||||
{
|
{
|
||||||
"engine": self.pipeline.conversation_engine or "default",
|
"engine": self.intent_agent,
|
||||||
"intent_input": intent_input,
|
"intent_input": intent_input,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
@ -274,7 +302,7 @@ class PipelineRun:
|
|||||||
conversation_id=conversation_id,
|
conversation_id=conversation_id,
|
||||||
context=self.context,
|
context=self.context,
|
||||||
language=self.language,
|
language=self.language,
|
||||||
agent_id=self.pipeline.conversation_engine,
|
agent_id=self.intent_agent,
|
||||||
)
|
)
|
||||||
except Exception as src_error:
|
except Exception as src_error:
|
||||||
_LOGGER.exception("Unexpected error during intent recognition")
|
_LOGGER.exception("Unexpected error during intent recognition")
|
||||||
@ -296,13 +324,38 @@ class PipelineRun:
|
|||||||
|
|
||||||
return speech
|
return speech
|
||||||
|
|
||||||
|
async def prepare_text_to_speech(self) -> None:
|
||||||
|
"""Prepare text to speech."""
|
||||||
|
engine = tts.async_resolve_engine(self.hass, self.pipeline.tts_engine)
|
||||||
|
|
||||||
|
if engine is None:
|
||||||
|
engine = self.pipeline.tts_engine or "default"
|
||||||
|
raise TextToSpeechError(
|
||||||
|
code="tts-not-supported",
|
||||||
|
message=f"Text to speech engine '{engine}' not found",
|
||||||
|
)
|
||||||
|
|
||||||
|
if not await tts.async_support_options(self.hass, engine, self.language):
|
||||||
|
raise TextToSpeechError(
|
||||||
|
code="tts-not-supported",
|
||||||
|
message=(
|
||||||
|
f"Text to speech engine {engine} "
|
||||||
|
f"does not support language {self.language}"
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
self.tts_engine = engine
|
||||||
|
|
||||||
async def text_to_speech(self, tts_input: str) -> str:
|
async def text_to_speech(self, tts_input: str) -> str:
|
||||||
"""Run text to speech portion of pipeline. Returns URL of TTS audio."""
|
"""Run text to speech portion of pipeline. Returns URL of TTS audio."""
|
||||||
|
if self.tts_engine is None:
|
||||||
|
raise RuntimeError("Text to speech was not prepared")
|
||||||
|
|
||||||
self.event_callback(
|
self.event_callback(
|
||||||
PipelineEvent(
|
PipelineEvent(
|
||||||
PipelineEventType.TTS_START,
|
PipelineEventType.TTS_START,
|
||||||
{
|
{
|
||||||
"engine": self.pipeline.tts_engine or "default",
|
"engine": self.tts_engine,
|
||||||
"tts_input": tts_input,
|
"tts_input": tts_input,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
@ -315,7 +368,8 @@ class PipelineRun:
|
|||||||
tts_generate_media_source_id(
|
tts_generate_media_source_id(
|
||||||
self.hass,
|
self.hass,
|
||||||
tts_input,
|
tts_input,
|
||||||
engine=self.pipeline.tts_engine,
|
engine=self.tts_engine,
|
||||||
|
language=self.language,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
except Exception as src_error:
|
except Exception as src_error:
|
||||||
@ -341,6 +395,8 @@ class PipelineRun:
|
|||||||
class PipelineInput:
|
class PipelineInput:
|
||||||
"""Input to a pipeline run."""
|
"""Input to a pipeline run."""
|
||||||
|
|
||||||
|
run: PipelineRun
|
||||||
|
|
||||||
stt_metadata: stt.SpeechMetadata | None = None
|
stt_metadata: stt.SpeechMetadata | None = None
|
||||||
"""Metadata of stt input audio. Required when start_stage = stt."""
|
"""Metadata of stt input audio. Required when start_stage = stt."""
|
||||||
|
|
||||||
@ -355,21 +411,10 @@ class PipelineInput:
|
|||||||
|
|
||||||
conversation_id: str | None = None
|
conversation_id: str | None = None
|
||||||
|
|
||||||
async def execute(
|
async def execute(self):
|
||||||
self, run: PipelineRun, timeout: int | float | None = DEFAULT_TIMEOUT
|
"""Run pipeline."""
|
||||||
):
|
self.run.start()
|
||||||
"""Run pipeline with optional timeout."""
|
current_stage = self.run.start_stage
|
||||||
await asyncio.wait_for(
|
|
||||||
self._execute(run),
|
|
||||||
timeout=timeout,
|
|
||||||
)
|
|
||||||
|
|
||||||
async def _execute(self, run: PipelineRun):
|
|
||||||
self._validate(run.start_stage)
|
|
||||||
|
|
||||||
# stt -> intent -> tts
|
|
||||||
run.start()
|
|
||||||
current_stage = run.start_stage
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Speech to text
|
# Speech to text
|
||||||
@ -377,29 +422,29 @@ class PipelineInput:
|
|||||||
if current_stage == PipelineStage.STT:
|
if current_stage == PipelineStage.STT:
|
||||||
assert self.stt_metadata is not None
|
assert self.stt_metadata is not None
|
||||||
assert self.stt_stream is not None
|
assert self.stt_stream is not None
|
||||||
intent_input = await run.speech_to_text(
|
intent_input = await self.run.speech_to_text(
|
||||||
self.stt_metadata,
|
self.stt_metadata,
|
||||||
self.stt_stream,
|
self.stt_stream,
|
||||||
)
|
)
|
||||||
current_stage = PipelineStage.INTENT
|
current_stage = PipelineStage.INTENT
|
||||||
|
|
||||||
if run.end_stage != PipelineStage.STT:
|
if self.run.end_stage != PipelineStage.STT:
|
||||||
tts_input = self.tts_input
|
tts_input = self.tts_input
|
||||||
|
|
||||||
if current_stage == PipelineStage.INTENT:
|
if current_stage == PipelineStage.INTENT:
|
||||||
assert intent_input is not None
|
assert intent_input is not None
|
||||||
tts_input = await run.recognize_intent(
|
tts_input = await self.run.recognize_intent(
|
||||||
intent_input, self.conversation_id
|
intent_input, self.conversation_id
|
||||||
)
|
)
|
||||||
current_stage = PipelineStage.TTS
|
current_stage = PipelineStage.TTS
|
||||||
|
|
||||||
if run.end_stage != PipelineStage.INTENT:
|
if self.run.end_stage != PipelineStage.INTENT:
|
||||||
if current_stage == PipelineStage.TTS:
|
if current_stage == PipelineStage.TTS:
|
||||||
assert tts_input is not None
|
assert tts_input is not None
|
||||||
await run.text_to_speech(tts_input)
|
await self.run.text_to_speech(tts_input)
|
||||||
|
|
||||||
except PipelineError as err:
|
except PipelineError as err:
|
||||||
run.event_callback(
|
self.run.event_callback(
|
||||||
PipelineEvent(
|
PipelineEvent(
|
||||||
PipelineEventType.ERROR,
|
PipelineEventType.ERROR,
|
||||||
{"code": err.code, "message": err.message},
|
{"code": err.code, "message": err.message},
|
||||||
@ -407,11 +452,11 @@ class PipelineInput:
|
|||||||
)
|
)
|
||||||
return
|
return
|
||||||
|
|
||||||
run.end()
|
self.run.end()
|
||||||
|
|
||||||
def _validate(self, stage: PipelineStage):
|
async def validate(self):
|
||||||
"""Validate pipeline input against start stage."""
|
"""Validate pipeline input against start stage."""
|
||||||
if stage == PipelineStage.STT:
|
if self.run.start_stage == PipelineStage.STT:
|
||||||
if self.stt_metadata is None:
|
if self.stt_metadata is None:
|
||||||
raise PipelineRunValidationError(
|
raise PipelineRunValidationError(
|
||||||
"stt_metadata is required for speech to text"
|
"stt_metadata is required for speech to text"
|
||||||
@ -421,13 +466,29 @@ class PipelineInput:
|
|||||||
raise PipelineRunValidationError(
|
raise PipelineRunValidationError(
|
||||||
"stt_stream is required for speech to text"
|
"stt_stream is required for speech to text"
|
||||||
)
|
)
|
||||||
elif stage == PipelineStage.INTENT:
|
elif self.run.start_stage == PipelineStage.INTENT:
|
||||||
if self.intent_input is None:
|
if self.intent_input is None:
|
||||||
raise PipelineRunValidationError(
|
raise PipelineRunValidationError(
|
||||||
"intent_input is required for intent recognition"
|
"intent_input is required for intent recognition"
|
||||||
)
|
)
|
||||||
elif stage == PipelineStage.TTS:
|
elif self.run.start_stage == PipelineStage.TTS:
|
||||||
if self.tts_input is None:
|
if self.tts_input is None:
|
||||||
raise PipelineRunValidationError(
|
raise PipelineRunValidationError(
|
||||||
"tts_input is required for text to speech"
|
"tts_input is required for text to speech"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
start_stage_index = PIPELINE_STAGE_ORDER.index(self.run.start_stage)
|
||||||
|
|
||||||
|
prepare_tasks = []
|
||||||
|
|
||||||
|
if start_stage_index <= PIPELINE_STAGE_ORDER.index(PipelineStage.STT):
|
||||||
|
prepare_tasks.append(self.run.prepare_speech_to_text(self.stt_metadata))
|
||||||
|
|
||||||
|
if start_stage_index <= PIPELINE_STAGE_ORDER.index(PipelineStage.INTENT):
|
||||||
|
prepare_tasks.append(self.run.prepare_recognize_intent())
|
||||||
|
|
||||||
|
if start_stage_index <= PIPELINE_STAGE_ORDER.index(PipelineStage.TTS):
|
||||||
|
prepare_tasks.append(self.run.prepare_text_to_speech())
|
||||||
|
|
||||||
|
if prepare_tasks:
|
||||||
|
await asyncio.gather(*prepare_tasks)
|
||||||
|
@ -5,13 +5,13 @@ from collections.abc import Callable
|
|||||||
import logging
|
import logging
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
|
import async_timeout
|
||||||
import voluptuous as vol
|
import voluptuous as vol
|
||||||
|
|
||||||
from homeassistant.components import stt, websocket_api
|
from homeassistant.components import stt, websocket_api
|
||||||
from homeassistant.core import HomeAssistant, callback
|
from homeassistant.core import HomeAssistant, callback
|
||||||
|
|
||||||
from .pipeline import (
|
from .pipeline import (
|
||||||
DEFAULT_TIMEOUT,
|
|
||||||
PipelineError,
|
PipelineError,
|
||||||
PipelineEvent,
|
PipelineEvent,
|
||||||
PipelineEventType,
|
PipelineEventType,
|
||||||
@ -21,6 +21,8 @@ from .pipeline import (
|
|||||||
async_get_pipeline,
|
async_get_pipeline,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
DEFAULT_TIMEOUT = 30
|
||||||
|
|
||||||
_LOGGER = logging.getLogger(__name__)
|
_LOGGER = logging.getLogger(__name__)
|
||||||
|
|
||||||
_VAD_ENERGY_THRESHOLD = 1000
|
_VAD_ENERGY_THRESHOLD = 1000
|
||||||
@ -155,37 +157,40 @@ async def websocket_run(
|
|||||||
# Input to text to speech system
|
# Input to text to speech system
|
||||||
input_args["tts_input"] = msg["input"]["text"]
|
input_args["tts_input"] = msg["input"]["text"]
|
||||||
|
|
||||||
run_task = hass.async_create_task(
|
input_args["run"] = PipelineRun(
|
||||||
PipelineInput(**input_args).execute(
|
|
||||||
PipelineRun(
|
|
||||||
hass,
|
hass,
|
||||||
context=connection.context(msg),
|
context=connection.context(msg),
|
||||||
pipeline=pipeline,
|
pipeline=pipeline,
|
||||||
start_stage=start_stage,
|
start_stage=start_stage,
|
||||||
end_stage=end_stage,
|
end_stage=end_stage,
|
||||||
event_callback=lambda event: connection.send_event(
|
event_callback=lambda event: connection.send_event(msg["id"], event.as_dict()),
|
||||||
msg["id"], event.as_dict()
|
|
||||||
),
|
|
||||||
runner_data={
|
runner_data={
|
||||||
"stt_binary_handler_id": handler_id,
|
"stt_binary_handler_id": handler_id,
|
||||||
|
"timeout": timeout,
|
||||||
},
|
},
|
||||||
),
|
|
||||||
timeout=timeout,
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# Cancel pipeline if user unsubscribes
|
pipeline_input = PipelineInput(**input_args)
|
||||||
connection.subscriptions[msg["id"]] = run_task.cancel
|
|
||||||
|
try:
|
||||||
|
await pipeline_input.validate()
|
||||||
|
except PipelineError as error:
|
||||||
|
# Report more specific error when possible
|
||||||
|
connection.send_error(msg["id"], error.code, error.message)
|
||||||
|
return
|
||||||
|
|
||||||
# Confirm subscription
|
# Confirm subscription
|
||||||
connection.send_result(msg["id"])
|
connection.send_result(msg["id"])
|
||||||
|
|
||||||
|
run_task = hass.async_create_task(pipeline_input.execute())
|
||||||
|
|
||||||
|
# Cancel pipeline if user unsubscribes
|
||||||
|
connection.subscriptions[msg["id"]] = run_task.cancel
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Task contains a timeout
|
# Task contains a timeout
|
||||||
|
async with async_timeout.timeout(timeout):
|
||||||
await run_task
|
await run_task
|
||||||
except PipelineError as error:
|
|
||||||
# Report more specific error when possible
|
|
||||||
connection.send_error(msg["id"], error.code, error.message)
|
|
||||||
except asyncio.TimeoutError:
|
except asyncio.TimeoutError:
|
||||||
connection.send_event(
|
connection.send_event(
|
||||||
msg["id"],
|
msg["id"],
|
||||||
|
@ -5,12 +5,13 @@
|
|||||||
'pipeline': 'en-US',
|
'pipeline': 'en-US',
|
||||||
'runner_data': dict({
|
'runner_data': dict({
|
||||||
'stt_binary_handler_id': 1,
|
'stt_binary_handler_id': 1,
|
||||||
|
'timeout': 30,
|
||||||
}),
|
}),
|
||||||
})
|
})
|
||||||
# ---
|
# ---
|
||||||
# name: test_audio_pipeline.1
|
# name: test_audio_pipeline.1
|
||||||
dict({
|
dict({
|
||||||
'engine': 'default',
|
'engine': 'test',
|
||||||
'metadata': dict({
|
'metadata': dict({
|
||||||
'bit_rate': 16,
|
'bit_rate': 16,
|
||||||
'channel': 1,
|
'channel': 1,
|
||||||
@ -30,7 +31,7 @@
|
|||||||
# ---
|
# ---
|
||||||
# name: test_audio_pipeline.3
|
# name: test_audio_pipeline.3
|
||||||
dict({
|
dict({
|
||||||
'engine': 'default',
|
'engine': 'homeassistant',
|
||||||
'intent_input': 'test transcript',
|
'intent_input': 'test transcript',
|
||||||
})
|
})
|
||||||
# ---
|
# ---
|
||||||
@ -58,7 +59,7 @@
|
|||||||
# ---
|
# ---
|
||||||
# name: test_audio_pipeline.5
|
# name: test_audio_pipeline.5
|
||||||
dict({
|
dict({
|
||||||
'engine': 'default',
|
'engine': 'test',
|
||||||
'tts_input': "Sorry, I couldn't understand that",
|
'tts_input': "Sorry, I couldn't understand that",
|
||||||
})
|
})
|
||||||
# ---
|
# ---
|
||||||
@ -66,7 +67,7 @@
|
|||||||
dict({
|
dict({
|
||||||
'tts_output': dict({
|
'tts_output': dict({
|
||||||
'mime_type': 'audio/mpeg',
|
'mime_type': 'audio/mpeg',
|
||||||
'url': '/api/tts_proxy/dae2cdcb27a1d1c3b07ba2c7db91480f9d4bfd8f_en_-_test.mp3',
|
'url': '/api/tts_proxy/dae2cdcb27a1d1c3b07ba2c7db91480f9d4bfd8f_en-us_-_test.mp3',
|
||||||
}),
|
}),
|
||||||
})
|
})
|
||||||
# ---
|
# ---
|
||||||
@ -76,12 +77,13 @@
|
|||||||
'pipeline': 'en-US',
|
'pipeline': 'en-US',
|
||||||
'runner_data': dict({
|
'runner_data': dict({
|
||||||
'stt_binary_handler_id': None,
|
'stt_binary_handler_id': None,
|
||||||
|
'timeout': 30,
|
||||||
}),
|
}),
|
||||||
})
|
})
|
||||||
# ---
|
# ---
|
||||||
# name: test_intent_failed.1
|
# name: test_intent_failed.1
|
||||||
dict({
|
dict({
|
||||||
'engine': 'default',
|
'engine': 'homeassistant',
|
||||||
'intent_input': 'Are the lights on?',
|
'intent_input': 'Are the lights on?',
|
||||||
})
|
})
|
||||||
# ---
|
# ---
|
||||||
@ -91,12 +93,13 @@
|
|||||||
'pipeline': 'en-US',
|
'pipeline': 'en-US',
|
||||||
'runner_data': dict({
|
'runner_data': dict({
|
||||||
'stt_binary_handler_id': None,
|
'stt_binary_handler_id': None,
|
||||||
|
'timeout': 0.1,
|
||||||
}),
|
}),
|
||||||
})
|
})
|
||||||
# ---
|
# ---
|
||||||
# name: test_intent_timeout.1
|
# name: test_intent_timeout.1
|
||||||
dict({
|
dict({
|
||||||
'engine': 'default',
|
'engine': 'homeassistant',
|
||||||
'intent_input': 'Are the lights on?',
|
'intent_input': 'Are the lights on?',
|
||||||
})
|
})
|
||||||
# ---
|
# ---
|
||||||
@ -112,6 +115,7 @@
|
|||||||
'pipeline': 'en-US',
|
'pipeline': 'en-US',
|
||||||
'runner_data': dict({
|
'runner_data': dict({
|
||||||
'stt_binary_handler_id': 1,
|
'stt_binary_handler_id': 1,
|
||||||
|
'timeout': 30,
|
||||||
}),
|
}),
|
||||||
})
|
})
|
||||||
# ---
|
# ---
|
||||||
@ -134,12 +138,13 @@
|
|||||||
'pipeline': 'en-US',
|
'pipeline': 'en-US',
|
||||||
'runner_data': dict({
|
'runner_data': dict({
|
||||||
'stt_binary_handler_id': 1,
|
'stt_binary_handler_id': 1,
|
||||||
|
'timeout': 30,
|
||||||
}),
|
}),
|
||||||
})
|
})
|
||||||
# ---
|
# ---
|
||||||
# name: test_stt_stream_failed.1
|
# name: test_stt_stream_failed.1
|
||||||
dict({
|
dict({
|
||||||
'engine': 'default',
|
'engine': 'test',
|
||||||
'metadata': dict({
|
'metadata': dict({
|
||||||
'bit_rate': 16,
|
'bit_rate': 16,
|
||||||
'channel': 1,
|
'channel': 1,
|
||||||
@ -156,12 +161,13 @@
|
|||||||
'pipeline': 'en-US',
|
'pipeline': 'en-US',
|
||||||
'runner_data': dict({
|
'runner_data': dict({
|
||||||
'stt_binary_handler_id': None,
|
'stt_binary_handler_id': None,
|
||||||
|
'timeout': 30,
|
||||||
}),
|
}),
|
||||||
})
|
})
|
||||||
# ---
|
# ---
|
||||||
# name: test_text_only_pipeline.1
|
# name: test_text_only_pipeline.1
|
||||||
dict({
|
dict({
|
||||||
'engine': 'default',
|
'engine': 'homeassistant',
|
||||||
'intent_input': 'Are the lights on?',
|
'intent_input': 'Are the lights on?',
|
||||||
})
|
})
|
||||||
# ---
|
# ---
|
||||||
@ -199,12 +205,13 @@
|
|||||||
'pipeline': 'en-US',
|
'pipeline': 'en-US',
|
||||||
'runner_data': dict({
|
'runner_data': dict({
|
||||||
'stt_binary_handler_id': None,
|
'stt_binary_handler_id': None,
|
||||||
|
'timeout': 30,
|
||||||
}),
|
}),
|
||||||
})
|
})
|
||||||
# ---
|
# ---
|
||||||
# name: test_tts_failed.1
|
# name: test_tts_failed.1
|
||||||
dict({
|
dict({
|
||||||
'engine': 'default',
|
'engine': 'test',
|
||||||
'tts_input': 'Lights are on.',
|
'tts_input': 'Lights are on.',
|
||||||
})
|
})
|
||||||
# ---
|
# ---
|
||||||
|
@ -93,7 +93,7 @@ class MockTTSProvider(tts.Provider):
|
|||||||
@property
|
@property
|
||||||
def supported_languages(self) -> list[str]:
|
def supported_languages(self) -> list[str]:
|
||||||
"""Return list of supported languages."""
|
"""Return list of supported languages."""
|
||||||
return ["en"]
|
return ["en-US"]
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def supported_options(self) -> list[str]:
|
def supported_options(self) -> list[str]:
|
||||||
@ -264,7 +264,7 @@ async def test_intent_timeout(
|
|||||||
"start_stage": "intent",
|
"start_stage": "intent",
|
||||||
"end_stage": "intent",
|
"end_stage": "intent",
|
||||||
"input": {"text": "Are the lights on?"},
|
"input": {"text": "Are the lights on?"},
|
||||||
"timeout": 0.00001,
|
"timeout": 0.1,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -301,7 +301,7 @@ async def test_text_pipeline_timeout(
|
|||||||
await asyncio.sleep(3600)
|
await asyncio.sleep(3600)
|
||||||
|
|
||||||
with patch(
|
with patch(
|
||||||
"homeassistant.components.voice_assistant.pipeline.PipelineInput._execute",
|
"homeassistant.components.voice_assistant.pipeline.PipelineInput.execute",
|
||||||
new=sleepy_run,
|
new=sleepy_run,
|
||||||
):
|
):
|
||||||
await client.send_json(
|
await client.send_json(
|
||||||
@ -381,7 +381,7 @@ async def test_audio_pipeline_timeout(
|
|||||||
await asyncio.sleep(3600)
|
await asyncio.sleep(3600)
|
||||||
|
|
||||||
with patch(
|
with patch(
|
||||||
"homeassistant.components.voice_assistant.pipeline.PipelineInput._execute",
|
"homeassistant.components.voice_assistant.pipeline.PipelineInput.execute",
|
||||||
new=sleepy_run,
|
new=sleepy_run,
|
||||||
):
|
):
|
||||||
await client.send_json(
|
await client.send_json(
|
||||||
@ -427,25 +427,8 @@ async def test_stt_provider_missing(
|
|||||||
|
|
||||||
# result
|
# result
|
||||||
msg = await client.receive_json()
|
msg = await client.receive_json()
|
||||||
assert msg["success"]
|
assert not msg["success"]
|
||||||
|
assert msg["error"]["code"] == "stt-provider-missing"
|
||||||
# run start
|
|
||||||
msg = await client.receive_json()
|
|
||||||
assert msg["event"]["type"] == "run-start"
|
|
||||||
assert msg["event"]["data"] == snapshot
|
|
||||||
|
|
||||||
# stt
|
|
||||||
msg = await client.receive_json()
|
|
||||||
assert msg["event"]["type"] == "stt-start"
|
|
||||||
assert msg["event"]["data"] == snapshot
|
|
||||||
|
|
||||||
# End of audio stream (handler id + empty payload)
|
|
||||||
await client.send_bytes(b"1")
|
|
||||||
|
|
||||||
# stt error
|
|
||||||
msg = await client.receive_json()
|
|
||||||
assert msg["event"]["type"] == "error"
|
|
||||||
assert msg["event"]["data"]["code"] == "stt-provider-missing"
|
|
||||||
|
|
||||||
|
|
||||||
async def test_stt_stream_failed(
|
async def test_stt_stream_failed(
|
||||||
|
Loading…
x
Reference in New Issue
Block a user