Simplify get pipeline method (#93865)

This commit is contained in:
Paulus Schoutsen 2023-05-31 11:06:03 -04:00 committed by GitHub
parent 4bade86dcc
commit 927b59fe5a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 38 additions and 15 deletions

View File

@ -34,6 +34,7 @@ __all__ = (
"Pipeline",
"PipelineEvent",
"PipelineEventType",
"PipelineNotFound",
)
CONFIG_SCHEMA = cv.empty_config_schema(DOMAIN)
@ -57,13 +58,10 @@ async def async_pipeline_from_audio_stream(
conversation_id: str | None = None,
tts_audio_output: str | None = None,
) -> None:
"""Create an audio pipeline from an audio stream."""
pipeline = async_get_pipeline(hass, pipeline_id=pipeline_id)
if pipeline is None:
raise PipelineNotFound(
"pipeline_not_found", f"Pipeline {pipeline_id} not found"
)
"""Create an audio pipeline from an audio stream.
Raises PipelineNotFound if no pipeline is found.
"""
pipeline_input = PipelineInput(
conversation_id=conversation_id,
stt_metadata=stt_metadata,
@ -71,13 +69,12 @@ async def async_pipeline_from_audio_stream(
run=PipelineRun(
hass,
context=context,
pipeline=pipeline,
pipeline=async_get_pipeline(hass, pipeline_id=pipeline_id),
start_stage=PipelineStage.STT,
end_stage=PipelineStage.TTS,
event_callback=event_callback,
tts_audio_output=tts_audio_output,
),
)
await pipeline_input.validate()
await pipeline_input.execute()

View File

@ -36,6 +36,7 @@ from .const import DOMAIN
from .error import (
IntentRecognitionError,
PipelineError,
PipelineNotFound,
SpeechToTextError,
TextToSpeechError,
)
@ -208,9 +209,7 @@ async def async_create_default_pipeline(
@callback
def async_get_pipeline(
hass: HomeAssistant, pipeline_id: str | None = None
) -> Pipeline | None:
def async_get_pipeline(hass: HomeAssistant, pipeline_id: str | None = None) -> Pipeline:
"""Get a pipeline by id or the preferred pipeline."""
pipeline_data: PipelineData = hass.data[DOMAIN]
@ -218,7 +217,15 @@ def async_get_pipeline(
# A pipeline was not specified, use the preferred one
pipeline_id = pipeline_data.pipeline_store.async_get_preferred_item()
return pipeline_data.pipeline_store.data.get(pipeline_id)
pipeline = pipeline_data.pipeline_store.data.get(pipeline_id)
# If invalid pipeline ID was specified
if pipeline is None:
raise PipelineNotFound(
"pipeline_not_found", f"Pipeline {pipeline_id} not found"
)
return pipeline
@callback

View File

@ -17,6 +17,7 @@ from homeassistant.helpers import config_validation as cv
from homeassistant.util import language as language_util
from .const import DOMAIN
from .error import PipelineNotFound
from .pipeline import (
PipelineData,
PipelineError,
@ -85,8 +86,9 @@ async def websocket_run(
) -> None:
"""Run a pipeline."""
pipeline_id = msg.get("pipeline")
pipeline = async_get_pipeline(hass, pipeline_id=pipeline_id)
if pipeline is None:
try:
pipeline = async_get_pipeline(hass, pipeline_id=pipeline_id)
except PipelineNotFound:
connection.send_error(
msg["id"],
"pipeline-not-found",

View File

@ -15,6 +15,7 @@ from homeassistant.components import stt, tts
from homeassistant.components.assist_pipeline import (
PipelineEvent,
PipelineEventType,
PipelineNotFound,
async_pipeline_from_audio_stream,
select as pipeline_select,
)
@ -337,6 +338,15 @@ class VoiceAssistantUDPServer(asyncio.DatagramProtocol):
await self._tts_done.wait()
_LOGGER.debug("Pipeline finished")
except PipelineNotFound:
self.handle_event(
VoiceAssistantEventType.VOICE_ASSISTANT_ERROR,
{
"code": "pipeline not found",
"message": "Selected pipeline timeout",
},
)
_LOGGER.warning("Pipeline not found")
except asyncio.TimeoutError:
self.handle_event(
VoiceAssistantEventType.VOICE_ASSISTANT_ERROR,

View File

@ -18,6 +18,7 @@ from homeassistant.components.assist_pipeline import (
Pipeline,
PipelineEvent,
PipelineEventType,
PipelineNotFound,
async_get_pipeline,
async_pipeline_from_audio_stream,
select as pipeline_select,
@ -45,7 +46,11 @@ def make_protocol(
DOMAIN,
voip_device.voip_id,
)
pipeline = async_get_pipeline(hass, pipeline_id)
try:
pipeline: Pipeline | None = async_get_pipeline(hass, pipeline_id)
except PipelineNotFound:
pipeline = None
if (
(pipeline is None)
or (pipeline.stt_engine is None)
@ -261,6 +266,8 @@ class PipelineRtpDatagramProtocol(RtpDatagramProtocol):
await self._tts_done.wait()
_LOGGER.debug("Pipeline finished")
except PipelineNotFound:
_LOGGER.warning("Pipeline not found")
except asyncio.TimeoutError:
# Expected after caller hangs up
_LOGGER.debug("Pipeline timeout")