mirror of
https://github.com/home-assistant/core.git
synced 2025-07-23 13:17:32 +00:00
Small improvement of assist_pipeline test coverage (#92115)
This commit is contained in:
parent
57af4672d5
commit
887e656570
@ -5,7 +5,7 @@ import asyncio
|
|||||||
from collections.abc import AsyncIterable, Callable, Iterable
|
from collections.abc import AsyncIterable, Callable, Iterable
|
||||||
from dataclasses import asdict, dataclass, field
|
from dataclasses import asdict, dataclass, field
|
||||||
import logging
|
import logging
|
||||||
from typing import Any
|
from typing import Any, cast
|
||||||
|
|
||||||
import voluptuous as vol
|
import voluptuous as vol
|
||||||
|
|
||||||
@ -332,12 +332,12 @@ class PipelineRun:
|
|||||||
event_callback: PipelineEventCallback
|
event_callback: PipelineEventCallback
|
||||||
language: str = None # type: ignore[assignment]
|
language: str = None # type: ignore[assignment]
|
||||||
runner_data: Any | None = None
|
runner_data: Any | None = None
|
||||||
stt_provider: stt.SpeechToTextEntity | stt.Provider | None = None
|
|
||||||
intent_agent: str | None = None
|
intent_agent: str | None = None
|
||||||
tts_engine: str | None = None
|
|
||||||
tts_audio_output: str | None = None
|
tts_audio_output: str | None = None
|
||||||
|
|
||||||
id: str = field(default_factory=ulid_util.ulid)
|
id: str = field(default_factory=ulid_util.ulid)
|
||||||
|
stt_provider: stt.SpeechToTextEntity | stt.Provider = field(init=False)
|
||||||
|
tts_engine: str = field(init=False)
|
||||||
tts_options: dict | None = field(init=False, default=None)
|
tts_options: dict | None = field(init=False, default=None)
|
||||||
|
|
||||||
def __post_init__(self) -> None:
|
def __post_init__(self) -> None:
|
||||||
@ -388,8 +388,6 @@ class PipelineRun:
|
|||||||
|
|
||||||
async def prepare_speech_to_text(self, metadata: stt.SpeechMetadata) -> None:
|
async def prepare_speech_to_text(self, metadata: stt.SpeechMetadata) -> None:
|
||||||
"""Prepare speech to text."""
|
"""Prepare speech to text."""
|
||||||
stt_provider: stt.SpeechToTextEntity | stt.Provider | None = None
|
|
||||||
|
|
||||||
# pipeline.stt_engine can't be None or this function is not called
|
# pipeline.stt_engine can't be None or this function is not called
|
||||||
stt_provider = stt.async_get_speech_to_text_engine(
|
stt_provider = stt.async_get_speech_to_text_engine(
|
||||||
self.hass,
|
self.hass,
|
||||||
@ -422,9 +420,6 @@ class PipelineRun:
|
|||||||
stream: AsyncIterable[bytes],
|
stream: AsyncIterable[bytes],
|
||||||
) -> str:
|
) -> str:
|
||||||
"""Run speech to text portion of pipeline. Returns the spoken text."""
|
"""Run speech to text portion of pipeline. Returns the spoken text."""
|
||||||
if self.stt_provider is None:
|
|
||||||
raise RuntimeError("Speech to text was not prepared")
|
|
||||||
|
|
||||||
if isinstance(self.stt_provider, stt.Provider):
|
if isinstance(self.stt_provider, stt.Provider):
|
||||||
engine = self.stt_provider.name
|
engine = self.stt_provider.name
|
||||||
else:
|
else:
|
||||||
@ -547,7 +542,8 @@ class PipelineRun:
|
|||||||
|
|
||||||
async def prepare_text_to_speech(self) -> None:
|
async def prepare_text_to_speech(self) -> None:
|
||||||
"""Prepare text to speech."""
|
"""Prepare text to speech."""
|
||||||
engine = self.pipeline.tts_engine
|
# pipeline.tts_engine can't be None or this function is not called
|
||||||
|
engine = cast(str, self.pipeline.tts_engine)
|
||||||
|
|
||||||
tts_options = {}
|
tts_options = {}
|
||||||
if self.pipeline.tts_voice is not None:
|
if self.pipeline.tts_voice is not None:
|
||||||
@ -557,13 +553,18 @@ class PipelineRun:
|
|||||||
tts_options[tts.ATTR_AUDIO_OUTPUT] = self.tts_audio_output
|
tts_options[tts.ATTR_AUDIO_OUTPUT] = self.tts_audio_output
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# pipeline.tts_engine can't be None or this function is not called
|
options_supported = await tts.async_support_options(
|
||||||
if not await tts.async_support_options(
|
|
||||||
self.hass,
|
self.hass,
|
||||||
engine, # type: ignore[arg-type]
|
engine,
|
||||||
self.pipeline.tts_language,
|
self.pipeline.tts_language,
|
||||||
tts_options,
|
tts_options,
|
||||||
):
|
)
|
||||||
|
except HomeAssistantError as err:
|
||||||
|
raise TextToSpeechError(
|
||||||
|
code="tts-not-supported",
|
||||||
|
message=f"Text to speech engine '{engine}' not found",
|
||||||
|
) from err
|
||||||
|
if not options_supported:
|
||||||
raise TextToSpeechError(
|
raise TextToSpeechError(
|
||||||
code="tts-not-supported",
|
code="tts-not-supported",
|
||||||
message=(
|
message=(
|
||||||
@ -571,20 +572,12 @@ class PipelineRun:
|
|||||||
f"does not support language {self.pipeline.tts_language} or options {tts_options}"
|
f"does not support language {self.pipeline.tts_language} or options {tts_options}"
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
except HomeAssistantError as err:
|
|
||||||
raise TextToSpeechError(
|
|
||||||
code="tts-not-supported",
|
|
||||||
message=f"Text to speech engine '{engine}' not found",
|
|
||||||
) from err
|
|
||||||
|
|
||||||
self.tts_engine = engine
|
self.tts_engine = engine
|
||||||
self.tts_options = tts_options
|
self.tts_options = tts_options
|
||||||
|
|
||||||
async def text_to_speech(self, tts_input: str) -> str:
|
async def text_to_speech(self, tts_input: str) -> str:
|
||||||
"""Run text to speech portion of pipeline. Returns URL of TTS audio."""
|
"""Run text to speech portion of pipeline. Returns URL of TTS audio."""
|
||||||
if self.tts_engine is None:
|
|
||||||
raise RuntimeError("Text to speech was not prepared")
|
|
||||||
|
|
||||||
self.process_event(
|
self.process_event(
|
||||||
PipelineEvent(
|
PipelineEvent(
|
||||||
PipelineEventType.TTS_START,
|
PipelineEventType.TTS_START,
|
||||||
|
@ -241,3 +241,42 @@ async def test_pipeline_from_audio_stream_no_stt(
|
|||||||
)
|
)
|
||||||
|
|
||||||
assert not events
|
assert not events
|
||||||
|
|
||||||
|
|
||||||
|
async def test_pipeline_from_audio_stream_unknown_pipeline(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
hass_ws_client: WebSocketGenerator,
|
||||||
|
mock_stt_provider: MockSttProvider,
|
||||||
|
init_components,
|
||||||
|
snapshot: SnapshotAssertion,
|
||||||
|
) -> None:
|
||||||
|
"""Test creating a pipeline from an audio stream.
|
||||||
|
|
||||||
|
In this test, the pipeline does not exist.
|
||||||
|
"""
|
||||||
|
events = []
|
||||||
|
|
||||||
|
async def audio_data():
|
||||||
|
yield b"part1"
|
||||||
|
yield b"part2"
|
||||||
|
yield b""
|
||||||
|
|
||||||
|
# Try to use the created pipeline
|
||||||
|
with pytest.raises(assist_pipeline.PipelineNotFound):
|
||||||
|
await assist_pipeline.async_pipeline_from_audio_stream(
|
||||||
|
hass,
|
||||||
|
Context(),
|
||||||
|
events.append,
|
||||||
|
stt.SpeechMetadata(
|
||||||
|
language="en-UK",
|
||||||
|
format=stt.AudioFormats.WAV,
|
||||||
|
codec=stt.AudioCodecs.PCM,
|
||||||
|
bit_rate=stt.AudioBitRates.BITRATE_16,
|
||||||
|
sample_rate=stt.AudioSampleRates.SAMPLERATE_16000,
|
||||||
|
channel=stt.AudioChannels.CHANNEL_MONO,
|
||||||
|
),
|
||||||
|
audio_data(),
|
||||||
|
pipeline_id="blah",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert not events
|
||||||
|
@ -7,6 +7,7 @@ from syrupy.assertion import SnapshotAssertion
|
|||||||
from homeassistant.components.assist_pipeline.const import DOMAIN
|
from homeassistant.components.assist_pipeline.const import DOMAIN
|
||||||
from homeassistant.components.assist_pipeline.pipeline import Pipeline, PipelineData
|
from homeassistant.components.assist_pipeline.pipeline import Pipeline, PipelineData
|
||||||
from homeassistant.core import HomeAssistant
|
from homeassistant.core import HomeAssistant
|
||||||
|
from homeassistant.exceptions import HomeAssistantError
|
||||||
|
|
||||||
from tests.typing import WebSocketGenerator
|
from tests.typing import WebSocketGenerator
|
||||||
|
|
||||||
@ -430,6 +431,34 @@ async def test_stt_provider_missing(
|
|||||||
assert msg["error"]["code"] == "stt-provider-missing"
|
assert msg["error"]["code"] == "stt-provider-missing"
|
||||||
|
|
||||||
|
|
||||||
|
async def test_stt_provider_bad_metadata(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
hass_ws_client: WebSocketGenerator,
|
||||||
|
init_components,
|
||||||
|
mock_stt_provider,
|
||||||
|
snapshot: SnapshotAssertion,
|
||||||
|
) -> None:
|
||||||
|
"""Test events from a pipeline run with wrong metadata."""
|
||||||
|
with patch.object(mock_stt_provider, "check_metadata", return_value=False):
|
||||||
|
client = await hass_ws_client(hass)
|
||||||
|
|
||||||
|
await client.send_json_auto_id(
|
||||||
|
{
|
||||||
|
"type": "assist_pipeline/run",
|
||||||
|
"start_stage": "stt",
|
||||||
|
"end_stage": "tts",
|
||||||
|
"input": {
|
||||||
|
"sample_rate": 12345,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
# result
|
||||||
|
msg = await client.receive_json()
|
||||||
|
assert not msg["success"]
|
||||||
|
assert msg["error"]["code"] == "stt-provider-unsupported-metadata"
|
||||||
|
|
||||||
|
|
||||||
async def test_stt_stream_failed(
|
async def test_stt_stream_failed(
|
||||||
hass: HomeAssistant,
|
hass: HomeAssistant,
|
||||||
hass_ws_client: WebSocketGenerator,
|
hass_ws_client: WebSocketGenerator,
|
||||||
@ -559,6 +588,64 @@ async def test_tts_failed(
|
|||||||
assert msg["result"] == {"events": events}
|
assert msg["result"] == {"events": events}
|
||||||
|
|
||||||
|
|
||||||
|
async def test_tts_provider_missing(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
hass_ws_client: WebSocketGenerator,
|
||||||
|
init_components,
|
||||||
|
mock_tts_provider,
|
||||||
|
snapshot: SnapshotAssertion,
|
||||||
|
) -> None:
|
||||||
|
"""Test pipeline run with text to speech error."""
|
||||||
|
client = await hass_ws_client(hass)
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
"homeassistant.components.tts.async_support_options",
|
||||||
|
side_effect=HomeAssistantError,
|
||||||
|
):
|
||||||
|
await client.send_json_auto_id(
|
||||||
|
{
|
||||||
|
"type": "assist_pipeline/run",
|
||||||
|
"start_stage": "tts",
|
||||||
|
"end_stage": "tts",
|
||||||
|
"input": {"text": "Lights are on."},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
# result
|
||||||
|
msg = await client.receive_json()
|
||||||
|
assert not msg["success"]
|
||||||
|
assert msg["error"]["code"] == "tts-not-supported"
|
||||||
|
|
||||||
|
|
||||||
|
async def test_tts_provider_bad_options(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
hass_ws_client: WebSocketGenerator,
|
||||||
|
init_components,
|
||||||
|
mock_tts_provider,
|
||||||
|
snapshot: SnapshotAssertion,
|
||||||
|
) -> None:
|
||||||
|
"""Test pipeline run with text to speech error."""
|
||||||
|
client = await hass_ws_client(hass)
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
"homeassistant.components.tts.async_support_options",
|
||||||
|
return_value=False,
|
||||||
|
):
|
||||||
|
await client.send_json_auto_id(
|
||||||
|
{
|
||||||
|
"type": "assist_pipeline/run",
|
||||||
|
"start_stage": "tts",
|
||||||
|
"end_stage": "tts",
|
||||||
|
"input": {"text": "Lights are on."},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
# result
|
||||||
|
msg = await client.receive_json()
|
||||||
|
assert not msg["success"]
|
||||||
|
assert msg["error"]["code"] == "tts-not-supported"
|
||||||
|
|
||||||
|
|
||||||
async def test_invalid_stage_order(
|
async def test_invalid_stage_order(
|
||||||
hass: HomeAssistant, hass_ws_client: WebSocketGenerator, init_components
|
hass: HomeAssistant, hass_ws_client: WebSocketGenerator, init_components
|
||||||
) -> None:
|
) -> None:
|
||||||
|
Loading…
x
Reference in New Issue
Block a user