mirror of
https://github.com/home-assistant/core.git
synced 2025-04-25 01:38:02 +00:00
Small improvement of assist_pipeline test coverage (#92115)
This commit is contained in:
parent
57af4672d5
commit
887e656570
@ -5,7 +5,7 @@ import asyncio
|
||||
from collections.abc import AsyncIterable, Callable, Iterable
|
||||
from dataclasses import asdict, dataclass, field
|
||||
import logging
|
||||
from typing import Any
|
||||
from typing import Any, cast
|
||||
|
||||
import voluptuous as vol
|
||||
|
||||
@ -332,12 +332,12 @@ class PipelineRun:
|
||||
event_callback: PipelineEventCallback
|
||||
language: str = None # type: ignore[assignment]
|
||||
runner_data: Any | None = None
|
||||
stt_provider: stt.SpeechToTextEntity | stt.Provider | None = None
|
||||
intent_agent: str | None = None
|
||||
tts_engine: str | None = None
|
||||
tts_audio_output: str | None = None
|
||||
|
||||
id: str = field(default_factory=ulid_util.ulid)
|
||||
stt_provider: stt.SpeechToTextEntity | stt.Provider = field(init=False)
|
||||
tts_engine: str = field(init=False)
|
||||
tts_options: dict | None = field(init=False, default=None)
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
@ -388,8 +388,6 @@ class PipelineRun:
|
||||
|
||||
async def prepare_speech_to_text(self, metadata: stt.SpeechMetadata) -> None:
|
||||
"""Prepare speech to text."""
|
||||
stt_provider: stt.SpeechToTextEntity | stt.Provider | None = None
|
||||
|
||||
# pipeline.stt_engine can't be None or this function is not called
|
||||
stt_provider = stt.async_get_speech_to_text_engine(
|
||||
self.hass,
|
||||
@ -422,9 +420,6 @@ class PipelineRun:
|
||||
stream: AsyncIterable[bytes],
|
||||
) -> str:
|
||||
"""Run speech to text portion of pipeline. Returns the spoken text."""
|
||||
if self.stt_provider is None:
|
||||
raise RuntimeError("Speech to text was not prepared")
|
||||
|
||||
if isinstance(self.stt_provider, stt.Provider):
|
||||
engine = self.stt_provider.name
|
||||
else:
|
||||
@ -547,7 +542,8 @@ class PipelineRun:
|
||||
|
||||
async def prepare_text_to_speech(self) -> None:
|
||||
"""Prepare text to speech."""
|
||||
engine = self.pipeline.tts_engine
|
||||
# pipeline.tts_engine can't be None or this function is not called
|
||||
engine = cast(str, self.pipeline.tts_engine)
|
||||
|
||||
tts_options = {}
|
||||
if self.pipeline.tts_voice is not None:
|
||||
@ -557,34 +553,31 @@ class PipelineRun:
|
||||
tts_options[tts.ATTR_AUDIO_OUTPUT] = self.tts_audio_output
|
||||
|
||||
try:
|
||||
# pipeline.tts_engine can't be None or this function is not called
|
||||
if not await tts.async_support_options(
|
||||
options_supported = await tts.async_support_options(
|
||||
self.hass,
|
||||
engine, # type: ignore[arg-type]
|
||||
engine,
|
||||
self.pipeline.tts_language,
|
||||
tts_options,
|
||||
):
|
||||
raise TextToSpeechError(
|
||||
code="tts-not-supported",
|
||||
message=(
|
||||
f"Text to speech engine {engine} "
|
||||
f"does not support language {self.pipeline.tts_language} or options {tts_options}"
|
||||
),
|
||||
)
|
||||
)
|
||||
except HomeAssistantError as err:
|
||||
raise TextToSpeechError(
|
||||
code="tts-not-supported",
|
||||
message=f"Text to speech engine '{engine}' not found",
|
||||
) from err
|
||||
if not options_supported:
|
||||
raise TextToSpeechError(
|
||||
code="tts-not-supported",
|
||||
message=(
|
||||
f"Text to speech engine {engine} "
|
||||
f"does not support language {self.pipeline.tts_language} or options {tts_options}"
|
||||
),
|
||||
)
|
||||
|
||||
self.tts_engine = engine
|
||||
self.tts_options = tts_options
|
||||
|
||||
async def text_to_speech(self, tts_input: str) -> str:
|
||||
"""Run text to speech portion of pipeline. Returns URL of TTS audio."""
|
||||
if self.tts_engine is None:
|
||||
raise RuntimeError("Text to speech was not prepared")
|
||||
|
||||
self.process_event(
|
||||
PipelineEvent(
|
||||
PipelineEventType.TTS_START,
|
||||
|
@ -241,3 +241,42 @@ async def test_pipeline_from_audio_stream_no_stt(
|
||||
)
|
||||
|
||||
assert not events
|
||||
|
||||
|
||||
async def test_pipeline_from_audio_stream_unknown_pipeline(
|
||||
hass: HomeAssistant,
|
||||
hass_ws_client: WebSocketGenerator,
|
||||
mock_stt_provider: MockSttProvider,
|
||||
init_components,
|
||||
snapshot: SnapshotAssertion,
|
||||
) -> None:
|
||||
"""Test creating a pipeline from an audio stream.
|
||||
|
||||
In this test, the pipeline does not exist.
|
||||
"""
|
||||
events = []
|
||||
|
||||
async def audio_data():
|
||||
yield b"part1"
|
||||
yield b"part2"
|
||||
yield b""
|
||||
|
||||
# Try to use the created pipeline
|
||||
with pytest.raises(assist_pipeline.PipelineNotFound):
|
||||
await assist_pipeline.async_pipeline_from_audio_stream(
|
||||
hass,
|
||||
Context(),
|
||||
events.append,
|
||||
stt.SpeechMetadata(
|
||||
language="en-UK",
|
||||
format=stt.AudioFormats.WAV,
|
||||
codec=stt.AudioCodecs.PCM,
|
||||
bit_rate=stt.AudioBitRates.BITRATE_16,
|
||||
sample_rate=stt.AudioSampleRates.SAMPLERATE_16000,
|
||||
channel=stt.AudioChannels.CHANNEL_MONO,
|
||||
),
|
||||
audio_data(),
|
||||
pipeline_id="blah",
|
||||
)
|
||||
|
||||
assert not events
|
||||
|
@ -7,6 +7,7 @@ from syrupy.assertion import SnapshotAssertion
|
||||
from homeassistant.components.assist_pipeline.const import DOMAIN
|
||||
from homeassistant.components.assist_pipeline.pipeline import Pipeline, PipelineData
|
||||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.exceptions import HomeAssistantError
|
||||
|
||||
from tests.typing import WebSocketGenerator
|
||||
|
||||
@ -430,6 +431,34 @@ async def test_stt_provider_missing(
|
||||
assert msg["error"]["code"] == "stt-provider-missing"
|
||||
|
||||
|
||||
async def test_stt_provider_bad_metadata(
|
||||
hass: HomeAssistant,
|
||||
hass_ws_client: WebSocketGenerator,
|
||||
init_components,
|
||||
mock_stt_provider,
|
||||
snapshot: SnapshotAssertion,
|
||||
) -> None:
|
||||
"""Test events from a pipeline run with wrong metadata."""
|
||||
with patch.object(mock_stt_provider, "check_metadata", return_value=False):
|
||||
client = await hass_ws_client(hass)
|
||||
|
||||
await client.send_json_auto_id(
|
||||
{
|
||||
"type": "assist_pipeline/run",
|
||||
"start_stage": "stt",
|
||||
"end_stage": "tts",
|
||||
"input": {
|
||||
"sample_rate": 12345,
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
# result
|
||||
msg = await client.receive_json()
|
||||
assert not msg["success"]
|
||||
assert msg["error"]["code"] == "stt-provider-unsupported-metadata"
|
||||
|
||||
|
||||
async def test_stt_stream_failed(
|
||||
hass: HomeAssistant,
|
||||
hass_ws_client: WebSocketGenerator,
|
||||
@ -559,6 +588,64 @@ async def test_tts_failed(
|
||||
assert msg["result"] == {"events": events}
|
||||
|
||||
|
||||
async def test_tts_provider_missing(
|
||||
hass: HomeAssistant,
|
||||
hass_ws_client: WebSocketGenerator,
|
||||
init_components,
|
||||
mock_tts_provider,
|
||||
snapshot: SnapshotAssertion,
|
||||
) -> None:
|
||||
"""Test pipeline run with text to speech error."""
|
||||
client = await hass_ws_client(hass)
|
||||
|
||||
with patch(
|
||||
"homeassistant.components.tts.async_support_options",
|
||||
side_effect=HomeAssistantError,
|
||||
):
|
||||
await client.send_json_auto_id(
|
||||
{
|
||||
"type": "assist_pipeline/run",
|
||||
"start_stage": "tts",
|
||||
"end_stage": "tts",
|
||||
"input": {"text": "Lights are on."},
|
||||
}
|
||||
)
|
||||
|
||||
# result
|
||||
msg = await client.receive_json()
|
||||
assert not msg["success"]
|
||||
assert msg["error"]["code"] == "tts-not-supported"
|
||||
|
||||
|
||||
async def test_tts_provider_bad_options(
|
||||
hass: HomeAssistant,
|
||||
hass_ws_client: WebSocketGenerator,
|
||||
init_components,
|
||||
mock_tts_provider,
|
||||
snapshot: SnapshotAssertion,
|
||||
) -> None:
|
||||
"""Test pipeline run with text to speech error."""
|
||||
client = await hass_ws_client(hass)
|
||||
|
||||
with patch(
|
||||
"homeassistant.components.tts.async_support_options",
|
||||
return_value=False,
|
||||
):
|
||||
await client.send_json_auto_id(
|
||||
{
|
||||
"type": "assist_pipeline/run",
|
||||
"start_stage": "tts",
|
||||
"end_stage": "tts",
|
||||
"input": {"text": "Lights are on."},
|
||||
}
|
||||
)
|
||||
|
||||
# result
|
||||
msg = await client.receive_json()
|
||||
assert not msg["success"]
|
||||
assert msg["error"]["code"] == "tts-not-supported"
|
||||
|
||||
|
||||
async def test_invalid_stage_order(
|
||||
hass: HomeAssistant, hass_ws_client: WebSocketGenerator, init_components
|
||||
) -> None:
|
||||
|
Loading…
x
Reference in New Issue
Block a user