Simplify get pipeline method (#93865)

This commit is contained in:
Paulus Schoutsen 2023-05-31 11:06:03 -04:00 committed by GitHub
parent 4bade86dcc
commit 927b59fe5a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 38 additions and 15 deletions

View File

@ -34,6 +34,7 @@ __all__ = (
"Pipeline", "Pipeline",
"PipelineEvent", "PipelineEvent",
"PipelineEventType", "PipelineEventType",
"PipelineNotFound",
) )
CONFIG_SCHEMA = cv.empty_config_schema(DOMAIN) CONFIG_SCHEMA = cv.empty_config_schema(DOMAIN)
@ -57,13 +58,10 @@ async def async_pipeline_from_audio_stream(
conversation_id: str | None = None, conversation_id: str | None = None,
tts_audio_output: str | None = None, tts_audio_output: str | None = None,
) -> None: ) -> None:
"""Create an audio pipeline from an audio stream.""" """Create an audio pipeline from an audio stream.
pipeline = async_get_pipeline(hass, pipeline_id=pipeline_id)
if pipeline is None:
raise PipelineNotFound(
"pipeline_not_found", f"Pipeline {pipeline_id} not found"
)
Raises PipelineNotFound if no pipeline is found.
"""
pipeline_input = PipelineInput( pipeline_input = PipelineInput(
conversation_id=conversation_id, conversation_id=conversation_id,
stt_metadata=stt_metadata, stt_metadata=stt_metadata,
@ -71,13 +69,12 @@ async def async_pipeline_from_audio_stream(
run=PipelineRun( run=PipelineRun(
hass, hass,
context=context, context=context,
pipeline=pipeline, pipeline=async_get_pipeline(hass, pipeline_id=pipeline_id),
start_stage=PipelineStage.STT, start_stage=PipelineStage.STT,
end_stage=PipelineStage.TTS, end_stage=PipelineStage.TTS,
event_callback=event_callback, event_callback=event_callback,
tts_audio_output=tts_audio_output, tts_audio_output=tts_audio_output,
), ),
) )
await pipeline_input.validate() await pipeline_input.validate()
await pipeline_input.execute() await pipeline_input.execute()

View File

@ -36,6 +36,7 @@ from .const import DOMAIN
from .error import ( from .error import (
IntentRecognitionError, IntentRecognitionError,
PipelineError, PipelineError,
PipelineNotFound,
SpeechToTextError, SpeechToTextError,
TextToSpeechError, TextToSpeechError,
) )
@ -208,9 +209,7 @@ async def async_create_default_pipeline(
@callback @callback
def async_get_pipeline( def async_get_pipeline(hass: HomeAssistant, pipeline_id: str | None = None) -> Pipeline:
hass: HomeAssistant, pipeline_id: str | None = None
) -> Pipeline | None:
"""Get a pipeline by id or the preferred pipeline.""" """Get a pipeline by id or the preferred pipeline."""
pipeline_data: PipelineData = hass.data[DOMAIN] pipeline_data: PipelineData = hass.data[DOMAIN]
@ -218,7 +217,15 @@ def async_get_pipeline(
# A pipeline was not specified, use the preferred one # A pipeline was not specified, use the preferred one
pipeline_id = pipeline_data.pipeline_store.async_get_preferred_item() pipeline_id = pipeline_data.pipeline_store.async_get_preferred_item()
return pipeline_data.pipeline_store.data.get(pipeline_id) pipeline = pipeline_data.pipeline_store.data.get(pipeline_id)
# If invalid pipeline ID was specified
if pipeline is None:
raise PipelineNotFound(
"pipeline_not_found", f"Pipeline {pipeline_id} not found"
)
return pipeline
@callback @callback

View File

@ -17,6 +17,7 @@ from homeassistant.helpers import config_validation as cv
from homeassistant.util import language as language_util from homeassistant.util import language as language_util
from .const import DOMAIN from .const import DOMAIN
from .error import PipelineNotFound
from .pipeline import ( from .pipeline import (
PipelineData, PipelineData,
PipelineError, PipelineError,
@ -85,8 +86,9 @@ async def websocket_run(
) -> None: ) -> None:
"""Run a pipeline.""" """Run a pipeline."""
pipeline_id = msg.get("pipeline") pipeline_id = msg.get("pipeline")
try:
pipeline = async_get_pipeline(hass, pipeline_id=pipeline_id) pipeline = async_get_pipeline(hass, pipeline_id=pipeline_id)
if pipeline is None: except PipelineNotFound:
connection.send_error( connection.send_error(
msg["id"], msg["id"],
"pipeline-not-found", "pipeline-not-found",

View File

@ -15,6 +15,7 @@ from homeassistant.components import stt, tts
from homeassistant.components.assist_pipeline import ( from homeassistant.components.assist_pipeline import (
PipelineEvent, PipelineEvent,
PipelineEventType, PipelineEventType,
PipelineNotFound,
async_pipeline_from_audio_stream, async_pipeline_from_audio_stream,
select as pipeline_select, select as pipeline_select,
) )
@ -337,6 +338,15 @@ class VoiceAssistantUDPServer(asyncio.DatagramProtocol):
await self._tts_done.wait() await self._tts_done.wait()
_LOGGER.debug("Pipeline finished") _LOGGER.debug("Pipeline finished")
except PipelineNotFound:
self.handle_event(
VoiceAssistantEventType.VOICE_ASSISTANT_ERROR,
{
"code": "pipeline not found",
"message": "Selected pipeline timeout",
},
)
_LOGGER.warning("Pipeline not found")
except asyncio.TimeoutError: except asyncio.TimeoutError:
self.handle_event( self.handle_event(
VoiceAssistantEventType.VOICE_ASSISTANT_ERROR, VoiceAssistantEventType.VOICE_ASSISTANT_ERROR,

View File

@ -18,6 +18,7 @@ from homeassistant.components.assist_pipeline import (
Pipeline, Pipeline,
PipelineEvent, PipelineEvent,
PipelineEventType, PipelineEventType,
PipelineNotFound,
async_get_pipeline, async_get_pipeline,
async_pipeline_from_audio_stream, async_pipeline_from_audio_stream,
select as pipeline_select, select as pipeline_select,
@ -45,7 +46,11 @@ def make_protocol(
DOMAIN, DOMAIN,
voip_device.voip_id, voip_device.voip_id,
) )
pipeline = async_get_pipeline(hass, pipeline_id) try:
pipeline: Pipeline | None = async_get_pipeline(hass, pipeline_id)
except PipelineNotFound:
pipeline = None
if ( if (
(pipeline is None) (pipeline is None)
or (pipeline.stt_engine is None) or (pipeline.stt_engine is None)
@ -261,6 +266,8 @@ class PipelineRtpDatagramProtocol(RtpDatagramProtocol):
await self._tts_done.wait() await self._tts_done.wait()
_LOGGER.debug("Pipeline finished") _LOGGER.debug("Pipeline finished")
except PipelineNotFound:
_LOGGER.warning("Pipeline not found")
except asyncio.TimeoutError: except asyncio.TimeoutError:
# Expected after caller hangs up # Expected after caller hangs up
_LOGGER.debug("Pipeline timeout") _LOGGER.debug("Pipeline timeout")