mirror of
https://github.com/home-assistant/core.git
synced 2025-07-15 17:27:10 +00:00
Simplify get pipeline method (#93865)
This commit is contained in:
parent
4bade86dcc
commit
927b59fe5a
@ -34,6 +34,7 @@ __all__ = (
|
|||||||
"Pipeline",
|
"Pipeline",
|
||||||
"PipelineEvent",
|
"PipelineEvent",
|
||||||
"PipelineEventType",
|
"PipelineEventType",
|
||||||
|
"PipelineNotFound",
|
||||||
)
|
)
|
||||||
|
|
||||||
CONFIG_SCHEMA = cv.empty_config_schema(DOMAIN)
|
CONFIG_SCHEMA = cv.empty_config_schema(DOMAIN)
|
||||||
@ -57,13 +58,10 @@ async def async_pipeline_from_audio_stream(
|
|||||||
conversation_id: str | None = None,
|
conversation_id: str | None = None,
|
||||||
tts_audio_output: str | None = None,
|
tts_audio_output: str | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Create an audio pipeline from an audio stream."""
|
"""Create an audio pipeline from an audio stream.
|
||||||
pipeline = async_get_pipeline(hass, pipeline_id=pipeline_id)
|
|
||||||
if pipeline is None:
|
|
||||||
raise PipelineNotFound(
|
|
||||||
"pipeline_not_found", f"Pipeline {pipeline_id} not found"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
Raises PipelineNotFound if no pipeline is found.
|
||||||
|
"""
|
||||||
pipeline_input = PipelineInput(
|
pipeline_input = PipelineInput(
|
||||||
conversation_id=conversation_id,
|
conversation_id=conversation_id,
|
||||||
stt_metadata=stt_metadata,
|
stt_metadata=stt_metadata,
|
||||||
@ -71,13 +69,12 @@ async def async_pipeline_from_audio_stream(
|
|||||||
run=PipelineRun(
|
run=PipelineRun(
|
||||||
hass,
|
hass,
|
||||||
context=context,
|
context=context,
|
||||||
pipeline=pipeline,
|
pipeline=async_get_pipeline(hass, pipeline_id=pipeline_id),
|
||||||
start_stage=PipelineStage.STT,
|
start_stage=PipelineStage.STT,
|
||||||
end_stage=PipelineStage.TTS,
|
end_stage=PipelineStage.TTS,
|
||||||
event_callback=event_callback,
|
event_callback=event_callback,
|
||||||
tts_audio_output=tts_audio_output,
|
tts_audio_output=tts_audio_output,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
await pipeline_input.validate()
|
await pipeline_input.validate()
|
||||||
await pipeline_input.execute()
|
await pipeline_input.execute()
|
||||||
|
@ -36,6 +36,7 @@ from .const import DOMAIN
|
|||||||
from .error import (
|
from .error import (
|
||||||
IntentRecognitionError,
|
IntentRecognitionError,
|
||||||
PipelineError,
|
PipelineError,
|
||||||
|
PipelineNotFound,
|
||||||
SpeechToTextError,
|
SpeechToTextError,
|
||||||
TextToSpeechError,
|
TextToSpeechError,
|
||||||
)
|
)
|
||||||
@ -208,9 +209,7 @@ async def async_create_default_pipeline(
|
|||||||
|
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
def async_get_pipeline(
|
def async_get_pipeline(hass: HomeAssistant, pipeline_id: str | None = None) -> Pipeline:
|
||||||
hass: HomeAssistant, pipeline_id: str | None = None
|
|
||||||
) -> Pipeline | None:
|
|
||||||
"""Get a pipeline by id or the preferred pipeline."""
|
"""Get a pipeline by id or the preferred pipeline."""
|
||||||
pipeline_data: PipelineData = hass.data[DOMAIN]
|
pipeline_data: PipelineData = hass.data[DOMAIN]
|
||||||
|
|
||||||
@ -218,7 +217,15 @@ def async_get_pipeline(
|
|||||||
# A pipeline was not specified, use the preferred one
|
# A pipeline was not specified, use the preferred one
|
||||||
pipeline_id = pipeline_data.pipeline_store.async_get_preferred_item()
|
pipeline_id = pipeline_data.pipeline_store.async_get_preferred_item()
|
||||||
|
|
||||||
return pipeline_data.pipeline_store.data.get(pipeline_id)
|
pipeline = pipeline_data.pipeline_store.data.get(pipeline_id)
|
||||||
|
|
||||||
|
# If invalid pipeline ID was specified
|
||||||
|
if pipeline is None:
|
||||||
|
raise PipelineNotFound(
|
||||||
|
"pipeline_not_found", f"Pipeline {pipeline_id} not found"
|
||||||
|
)
|
||||||
|
|
||||||
|
return pipeline
|
||||||
|
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
|
@ -17,6 +17,7 @@ from homeassistant.helpers import config_validation as cv
|
|||||||
from homeassistant.util import language as language_util
|
from homeassistant.util import language as language_util
|
||||||
|
|
||||||
from .const import DOMAIN
|
from .const import DOMAIN
|
||||||
|
from .error import PipelineNotFound
|
||||||
from .pipeline import (
|
from .pipeline import (
|
||||||
PipelineData,
|
PipelineData,
|
||||||
PipelineError,
|
PipelineError,
|
||||||
@ -85,8 +86,9 @@ async def websocket_run(
|
|||||||
) -> None:
|
) -> None:
|
||||||
"""Run a pipeline."""
|
"""Run a pipeline."""
|
||||||
pipeline_id = msg.get("pipeline")
|
pipeline_id = msg.get("pipeline")
|
||||||
|
try:
|
||||||
pipeline = async_get_pipeline(hass, pipeline_id=pipeline_id)
|
pipeline = async_get_pipeline(hass, pipeline_id=pipeline_id)
|
||||||
if pipeline is None:
|
except PipelineNotFound:
|
||||||
connection.send_error(
|
connection.send_error(
|
||||||
msg["id"],
|
msg["id"],
|
||||||
"pipeline-not-found",
|
"pipeline-not-found",
|
||||||
|
@ -15,6 +15,7 @@ from homeassistant.components import stt, tts
|
|||||||
from homeassistant.components.assist_pipeline import (
|
from homeassistant.components.assist_pipeline import (
|
||||||
PipelineEvent,
|
PipelineEvent,
|
||||||
PipelineEventType,
|
PipelineEventType,
|
||||||
|
PipelineNotFound,
|
||||||
async_pipeline_from_audio_stream,
|
async_pipeline_from_audio_stream,
|
||||||
select as pipeline_select,
|
select as pipeline_select,
|
||||||
)
|
)
|
||||||
@ -337,6 +338,15 @@ class VoiceAssistantUDPServer(asyncio.DatagramProtocol):
|
|||||||
await self._tts_done.wait()
|
await self._tts_done.wait()
|
||||||
|
|
||||||
_LOGGER.debug("Pipeline finished")
|
_LOGGER.debug("Pipeline finished")
|
||||||
|
except PipelineNotFound:
|
||||||
|
self.handle_event(
|
||||||
|
VoiceAssistantEventType.VOICE_ASSISTANT_ERROR,
|
||||||
|
{
|
||||||
|
"code": "pipeline not found",
|
||||||
|
"message": "Selected pipeline timeout",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
_LOGGER.warning("Pipeline not found")
|
||||||
except asyncio.TimeoutError:
|
except asyncio.TimeoutError:
|
||||||
self.handle_event(
|
self.handle_event(
|
||||||
VoiceAssistantEventType.VOICE_ASSISTANT_ERROR,
|
VoiceAssistantEventType.VOICE_ASSISTANT_ERROR,
|
||||||
|
@ -18,6 +18,7 @@ from homeassistant.components.assist_pipeline import (
|
|||||||
Pipeline,
|
Pipeline,
|
||||||
PipelineEvent,
|
PipelineEvent,
|
||||||
PipelineEventType,
|
PipelineEventType,
|
||||||
|
PipelineNotFound,
|
||||||
async_get_pipeline,
|
async_get_pipeline,
|
||||||
async_pipeline_from_audio_stream,
|
async_pipeline_from_audio_stream,
|
||||||
select as pipeline_select,
|
select as pipeline_select,
|
||||||
@ -45,7 +46,11 @@ def make_protocol(
|
|||||||
DOMAIN,
|
DOMAIN,
|
||||||
voip_device.voip_id,
|
voip_device.voip_id,
|
||||||
)
|
)
|
||||||
pipeline = async_get_pipeline(hass, pipeline_id)
|
try:
|
||||||
|
pipeline: Pipeline | None = async_get_pipeline(hass, pipeline_id)
|
||||||
|
except PipelineNotFound:
|
||||||
|
pipeline = None
|
||||||
|
|
||||||
if (
|
if (
|
||||||
(pipeline is None)
|
(pipeline is None)
|
||||||
or (pipeline.stt_engine is None)
|
or (pipeline.stt_engine is None)
|
||||||
@ -261,6 +266,8 @@ class PipelineRtpDatagramProtocol(RtpDatagramProtocol):
|
|||||||
await self._tts_done.wait()
|
await self._tts_done.wait()
|
||||||
|
|
||||||
_LOGGER.debug("Pipeline finished")
|
_LOGGER.debug("Pipeline finished")
|
||||||
|
except PipelineNotFound:
|
||||||
|
_LOGGER.warning("Pipeline not found")
|
||||||
except asyncio.TimeoutError:
|
except asyncio.TimeoutError:
|
||||||
# Expected after caller hangs up
|
# Expected after caller hangs up
|
||||||
_LOGGER.debug("Pipeline timeout")
|
_LOGGER.debug("Pipeline timeout")
|
||||||
|
Loading…
x
Reference in New Issue
Block a user