Small improvement of assist_pipeline test coverage (#92115)

This commit is contained in:
Erik Montnemery 2023-05-04 19:01:41 +02:00 committed by GitHub
parent 57af4672d5
commit 887e656570
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 142 additions and 23 deletions

View File

@ -5,7 +5,7 @@ import asyncio
from collections.abc import AsyncIterable, Callable, Iterable
from dataclasses import asdict, dataclass, field
import logging
from typing import Any
from typing import Any, cast
import voluptuous as vol
@ -332,12 +332,12 @@ class PipelineRun:
event_callback: PipelineEventCallback
language: str = None # type: ignore[assignment]
runner_data: Any | None = None
stt_provider: stt.SpeechToTextEntity | stt.Provider | None = None
intent_agent: str | None = None
tts_engine: str | None = None
tts_audio_output: str | None = None
id: str = field(default_factory=ulid_util.ulid)
stt_provider: stt.SpeechToTextEntity | stt.Provider = field(init=False)
tts_engine: str = field(init=False)
tts_options: dict | None = field(init=False, default=None)
def __post_init__(self) -> None:
@ -388,8 +388,6 @@ class PipelineRun:
async def prepare_speech_to_text(self, metadata: stt.SpeechMetadata) -> None:
"""Prepare speech to text."""
stt_provider: stt.SpeechToTextEntity | stt.Provider | None = None
# pipeline.stt_engine can't be None or this function is not called
stt_provider = stt.async_get_speech_to_text_engine(
self.hass,
@ -422,9 +420,6 @@ class PipelineRun:
stream: AsyncIterable[bytes],
) -> str:
"""Run speech to text portion of pipeline. Returns the spoken text."""
if self.stt_provider is None:
raise RuntimeError("Speech to text was not prepared")
if isinstance(self.stt_provider, stt.Provider):
engine = self.stt_provider.name
else:
@ -547,7 +542,8 @@ class PipelineRun:
async def prepare_text_to_speech(self) -> None:
"""Prepare text to speech."""
engine = self.pipeline.tts_engine
# pipeline.tts_engine can't be None or this function is not called
engine = cast(str, self.pipeline.tts_engine)
tts_options = {}
if self.pipeline.tts_voice is not None:
@ -557,34 +553,31 @@ class PipelineRun:
tts_options[tts.ATTR_AUDIO_OUTPUT] = self.tts_audio_output
try:
# pipeline.tts_engine can't be None or this function is not called
if not await tts.async_support_options(
options_supported = await tts.async_support_options(
self.hass,
engine, # type: ignore[arg-type]
engine,
self.pipeline.tts_language,
tts_options,
):
raise TextToSpeechError(
code="tts-not-supported",
message=(
f"Text to speech engine {engine} "
f"does not support language {self.pipeline.tts_language} or options {tts_options}"
),
)
)
except HomeAssistantError as err:
raise TextToSpeechError(
code="tts-not-supported",
message=f"Text to speech engine '{engine}' not found",
) from err
if not options_supported:
raise TextToSpeechError(
code="tts-not-supported",
message=(
f"Text to speech engine {engine} "
f"does not support language {self.pipeline.tts_language} or options {tts_options}"
),
)
self.tts_engine = engine
self.tts_options = tts_options
async def text_to_speech(self, tts_input: str) -> str:
"""Run text to speech portion of pipeline. Returns URL of TTS audio."""
if self.tts_engine is None:
raise RuntimeError("Text to speech was not prepared")
self.process_event(
PipelineEvent(
PipelineEventType.TTS_START,

View File

@ -241,3 +241,42 @@ async def test_pipeline_from_audio_stream_no_stt(
)
assert not events
async def test_pipeline_from_audio_stream_unknown_pipeline(
hass: HomeAssistant,
hass_ws_client: WebSocketGenerator,
mock_stt_provider: MockSttProvider,
init_components,
snapshot: SnapshotAssertion,
) -> None:
"""Test creating a pipeline from an audio stream.
In this test, the pipeline does not exist.
"""
events = []
async def audio_data():
yield b"part1"
yield b"part2"
yield b""
# Try to use the created pipeline
with pytest.raises(assist_pipeline.PipelineNotFound):
await assist_pipeline.async_pipeline_from_audio_stream(
hass,
Context(),
events.append,
stt.SpeechMetadata(
language="en-UK",
format=stt.AudioFormats.WAV,
codec=stt.AudioCodecs.PCM,
bit_rate=stt.AudioBitRates.BITRATE_16,
sample_rate=stt.AudioSampleRates.SAMPLERATE_16000,
channel=stt.AudioChannels.CHANNEL_MONO,
),
audio_data(),
pipeline_id="blah",
)
assert not events

View File

@ -7,6 +7,7 @@ from syrupy.assertion import SnapshotAssertion
from homeassistant.components.assist_pipeline.const import DOMAIN
from homeassistant.components.assist_pipeline.pipeline import Pipeline, PipelineData
from homeassistant.core import HomeAssistant
from homeassistant.exceptions import HomeAssistantError
from tests.typing import WebSocketGenerator
@ -430,6 +431,34 @@ async def test_stt_provider_missing(
assert msg["error"]["code"] == "stt-provider-missing"
async def test_stt_provider_bad_metadata(
hass: HomeAssistant,
hass_ws_client: WebSocketGenerator,
init_components,
mock_stt_provider,
snapshot: SnapshotAssertion,
) -> None:
"""Test events from a pipeline run with wrong metadata."""
with patch.object(mock_stt_provider, "check_metadata", return_value=False):
client = await hass_ws_client(hass)
await client.send_json_auto_id(
{
"type": "assist_pipeline/run",
"start_stage": "stt",
"end_stage": "tts",
"input": {
"sample_rate": 12345,
},
}
)
# result
msg = await client.receive_json()
assert not msg["success"]
assert msg["error"]["code"] == "stt-provider-unsupported-metadata"
async def test_stt_stream_failed(
hass: HomeAssistant,
hass_ws_client: WebSocketGenerator,
@ -559,6 +588,64 @@ async def test_tts_failed(
assert msg["result"] == {"events": events}
async def test_tts_provider_missing(
hass: HomeAssistant,
hass_ws_client: WebSocketGenerator,
init_components,
mock_tts_provider,
snapshot: SnapshotAssertion,
) -> None:
"""Test pipeline run with text to speech error."""
client = await hass_ws_client(hass)
with patch(
"homeassistant.components.tts.async_support_options",
side_effect=HomeAssistantError,
):
await client.send_json_auto_id(
{
"type": "assist_pipeline/run",
"start_stage": "tts",
"end_stage": "tts",
"input": {"text": "Lights are on."},
}
)
# result
msg = await client.receive_json()
assert not msg["success"]
assert msg["error"]["code"] == "tts-not-supported"
async def test_tts_provider_bad_options(
hass: HomeAssistant,
hass_ws_client: WebSocketGenerator,
init_components,
mock_tts_provider,
snapshot: SnapshotAssertion,
) -> None:
"""Test pipeline run with text to speech error."""
client = await hass_ws_client(hass)
with patch(
"homeassistant.components.tts.async_support_options",
return_value=False,
):
await client.send_json_auto_id(
{
"type": "assist_pipeline/run",
"start_stage": "tts",
"end_stage": "tts",
"input": {"text": "Lights are on."},
}
)
# result
msg = await client.receive_json()
assert not msg["success"]
assert msg["error"]["code"] == "tts-not-supported"
async def test_invalid_stage_order(
hass: HomeAssistant, hass_ws_client: WebSocketGenerator, init_components
) -> None: