mirror of
https://github.com/home-assistant/core.git
synced 2025-07-15 09:17:10 +00:00
Simplify get pipeline method (#93865)
This commit is contained in:
parent
4bade86dcc
commit
927b59fe5a
@ -34,6 +34,7 @@ __all__ = (
|
||||
"Pipeline",
|
||||
"PipelineEvent",
|
||||
"PipelineEventType",
|
||||
"PipelineNotFound",
|
||||
)
|
||||
|
||||
CONFIG_SCHEMA = cv.empty_config_schema(DOMAIN)
|
||||
@ -57,13 +58,10 @@ async def async_pipeline_from_audio_stream(
|
||||
conversation_id: str | None = None,
|
||||
tts_audio_output: str | None = None,
|
||||
) -> None:
|
||||
"""Create an audio pipeline from an audio stream."""
|
||||
pipeline = async_get_pipeline(hass, pipeline_id=pipeline_id)
|
||||
if pipeline is None:
|
||||
raise PipelineNotFound(
|
||||
"pipeline_not_found", f"Pipeline {pipeline_id} not found"
|
||||
)
|
||||
"""Create an audio pipeline from an audio stream.
|
||||
|
||||
Raises PipelineNotFound if no pipeline is found.
|
||||
"""
|
||||
pipeline_input = PipelineInput(
|
||||
conversation_id=conversation_id,
|
||||
stt_metadata=stt_metadata,
|
||||
@ -71,13 +69,12 @@ async def async_pipeline_from_audio_stream(
|
||||
run=PipelineRun(
|
||||
hass,
|
||||
context=context,
|
||||
pipeline=pipeline,
|
||||
pipeline=async_get_pipeline(hass, pipeline_id=pipeline_id),
|
||||
start_stage=PipelineStage.STT,
|
||||
end_stage=PipelineStage.TTS,
|
||||
event_callback=event_callback,
|
||||
tts_audio_output=tts_audio_output,
|
||||
),
|
||||
)
|
||||
|
||||
await pipeline_input.validate()
|
||||
await pipeline_input.execute()
|
||||
|
@ -36,6 +36,7 @@ from .const import DOMAIN
|
||||
from .error import (
|
||||
IntentRecognitionError,
|
||||
PipelineError,
|
||||
PipelineNotFound,
|
||||
SpeechToTextError,
|
||||
TextToSpeechError,
|
||||
)
|
||||
@ -208,9 +209,7 @@ async def async_create_default_pipeline(
|
||||
|
||||
|
||||
@callback
|
||||
def async_get_pipeline(
|
||||
hass: HomeAssistant, pipeline_id: str | None = None
|
||||
) -> Pipeline | None:
|
||||
def async_get_pipeline(hass: HomeAssistant, pipeline_id: str | None = None) -> Pipeline:
|
||||
"""Get a pipeline by id or the preferred pipeline."""
|
||||
pipeline_data: PipelineData = hass.data[DOMAIN]
|
||||
|
||||
@ -218,7 +217,15 @@ def async_get_pipeline(
|
||||
# A pipeline was not specified, use the preferred one
|
||||
pipeline_id = pipeline_data.pipeline_store.async_get_preferred_item()
|
||||
|
||||
return pipeline_data.pipeline_store.data.get(pipeline_id)
|
||||
pipeline = pipeline_data.pipeline_store.data.get(pipeline_id)
|
||||
|
||||
# If invalid pipeline ID was specified
|
||||
if pipeline is None:
|
||||
raise PipelineNotFound(
|
||||
"pipeline_not_found", f"Pipeline {pipeline_id} not found"
|
||||
)
|
||||
|
||||
return pipeline
|
||||
|
||||
|
||||
@callback
|
||||
|
@ -17,6 +17,7 @@ from homeassistant.helpers import config_validation as cv
|
||||
from homeassistant.util import language as language_util
|
||||
|
||||
from .const import DOMAIN
|
||||
from .error import PipelineNotFound
|
||||
from .pipeline import (
|
||||
PipelineData,
|
||||
PipelineError,
|
||||
@ -85,8 +86,9 @@ async def websocket_run(
|
||||
) -> None:
|
||||
"""Run a pipeline."""
|
||||
pipeline_id = msg.get("pipeline")
|
||||
try:
|
||||
pipeline = async_get_pipeline(hass, pipeline_id=pipeline_id)
|
||||
if pipeline is None:
|
||||
except PipelineNotFound:
|
||||
connection.send_error(
|
||||
msg["id"],
|
||||
"pipeline-not-found",
|
||||
|
@ -15,6 +15,7 @@ from homeassistant.components import stt, tts
|
||||
from homeassistant.components.assist_pipeline import (
|
||||
PipelineEvent,
|
||||
PipelineEventType,
|
||||
PipelineNotFound,
|
||||
async_pipeline_from_audio_stream,
|
||||
select as pipeline_select,
|
||||
)
|
||||
@ -337,6 +338,15 @@ class VoiceAssistantUDPServer(asyncio.DatagramProtocol):
|
||||
await self._tts_done.wait()
|
||||
|
||||
_LOGGER.debug("Pipeline finished")
|
||||
except PipelineNotFound:
|
||||
self.handle_event(
|
||||
VoiceAssistantEventType.VOICE_ASSISTANT_ERROR,
|
||||
{
|
||||
"code": "pipeline not found",
|
||||
"message": "Selected pipeline timeout",
|
||||
},
|
||||
)
|
||||
_LOGGER.warning("Pipeline not found")
|
||||
except asyncio.TimeoutError:
|
||||
self.handle_event(
|
||||
VoiceAssistantEventType.VOICE_ASSISTANT_ERROR,
|
||||
|
@ -18,6 +18,7 @@ from homeassistant.components.assist_pipeline import (
|
||||
Pipeline,
|
||||
PipelineEvent,
|
||||
PipelineEventType,
|
||||
PipelineNotFound,
|
||||
async_get_pipeline,
|
||||
async_pipeline_from_audio_stream,
|
||||
select as pipeline_select,
|
||||
@ -45,7 +46,11 @@ def make_protocol(
|
||||
DOMAIN,
|
||||
voip_device.voip_id,
|
||||
)
|
||||
pipeline = async_get_pipeline(hass, pipeline_id)
|
||||
try:
|
||||
pipeline: Pipeline | None = async_get_pipeline(hass, pipeline_id)
|
||||
except PipelineNotFound:
|
||||
pipeline = None
|
||||
|
||||
if (
|
||||
(pipeline is None)
|
||||
or (pipeline.stt_engine is None)
|
||||
@ -261,6 +266,8 @@ class PipelineRtpDatagramProtocol(RtpDatagramProtocol):
|
||||
await self._tts_done.wait()
|
||||
|
||||
_LOGGER.debug("Pipeline finished")
|
||||
except PipelineNotFound:
|
||||
_LOGGER.warning("Pipeline not found")
|
||||
except asyncio.TimeoutError:
|
||||
# Expected after caller hangs up
|
||||
_LOGGER.debug("Pipeline timeout")
|
||||
|
Loading…
x
Reference in New Issue
Block a user