Compare commits

...

1 Commits

Author SHA1 Message Date
Paulus Schoutsen
c3906e1a3a Handle validation error when starting stream from audio 2026-03-22 07:49:32 -04:00
3 changed files with 98 additions and 2 deletions

View File

@@ -24,7 +24,7 @@ from .const import (
SAMPLE_WIDTH,
SAMPLES_PER_CHUNK,
)
from .error import PipelineNotFound
from .error import PipelineError, PipelineNotFound
from .pipeline import (
AudioSettings,
Pipeline,
@@ -137,5 +137,21 @@ async def async_pipeline_from_audio_stream(
audio_settings=audio_settings or AudioSettings(),
),
)
await pipeline_input.validate()
try:
await pipeline_input.validate()
except PipelineError as err:
pipeline_input.run.start(
conversation_id=session.conversation_id,
device_id=device_id,
satellite_id=satellite_id,
)
pipeline_input.run.process_event(
PipelineEvent(
PipelineEventType.ERROR,
{"code": err.code, "message": err.message},
)
)
await pipeline_input.run.end()
return
await pipeline_input.execute()

View File

@@ -330,6 +330,49 @@ async def test_pipeline_from_audio_stream_unknown_pipeline(
assert not events
async def test_pipeline_from_audio_stream_validation_pipeline_error(
hass: HomeAssistant,
mock_stt_provider_entity: MockSTTProviderEntity,
init_components,
) -> None:
"""Test validation pipeline errors are emitted as terminal events."""
events: list[assist_pipeline.PipelineEvent] = []
await assist_pipeline.async_update_pipeline(
hass,
assist_pipeline.async_get_pipeline(hass),
conversation_engine="conversation.non_existing",
)
async def audio_data():
yield b"audio"
await assist_pipeline.async_pipeline_from_audio_stream(
hass,
context=Context(),
event_callback=events.append,
stt_metadata=stt.SpeechMetadata(
language="",
format=stt.AudioFormats.WAV,
codec=stt.AudioCodecs.PCM,
bit_rate=stt.AudioBitRates.BITRATE_16,
sample_rate=stt.AudioSampleRates.SAMPLERATE_16000,
channel=stt.AudioChannels.CHANNEL_MONO,
),
stt_stream=audio_data(),
end_stage=assist_pipeline.PipelineStage.INTENT,
)
assert len(events) == 3
assert events[0].type == assist_pipeline.PipelineEventType.RUN_START
assert events[1].type == assist_pipeline.PipelineEventType.ERROR
assert events[1].data == {
"code": "intent-not-supported",
"message": "Intent recognition engine conversation.non_existing is not found",
}
assert events[2].type == assist_pipeline.PipelineEventType.RUN_END
async def test_pipeline_from_audio_stream_wake_word(
hass: HomeAssistant,
mock_stt_provider_entity: MockSTTProviderEntity,

View File

@@ -184,6 +184,43 @@ async def test_new_pipeline_cancels_pipeline(
await pipeline2_finished.wait()
async def test_pipeline_validation_error_ends_pipeline(
hass: HomeAssistant,
init_components: ConfigEntry,
entity: MockAssistSatellite,
) -> None:
"""Test validation pipeline errors end the satellite pipeline cleanly."""
await async_update_pipeline(
hass,
async_get_pipeline(hass),
stt_engine="test-stt-engine",
stt_language="en",
conversation_engine="conversation.non_existing",
)
with patch(
"homeassistant.components.assist_pipeline.pipeline.PipelineRun.prepare_speech_to_text"
):
await entity.async_accept_pipeline_from_satellite(
object(), # type: ignore[arg-type]
end_stage=PipelineStage.INTENT,
)
assert [event.type for event in entity.events[-3:]] == [
PipelineEventType.RUN_START,
PipelineEventType.ERROR,
PipelineEventType.RUN_END,
]
assert entity.events[-2].data == {
"code": "intent-not-supported",
"message": "Intent recognition engine conversation.non_existing is not found",
}
state = hass.states.get(ENTITY_ID)
assert state is not None
assert state.state == AssistSatelliteState.IDLE
@pytest.mark.parametrize(
("service_data", "expected_params"),
[