Small improvement of assist_pipeline test coverage (#92115)

This commit is contained in:
Erik Montnemery 2023-05-04 19:01:41 +02:00 committed by GitHub
parent 57af4672d5
commit 887e656570
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 142 additions and 23 deletions

View File

@ -5,7 +5,7 @@ import asyncio
from collections.abc import AsyncIterable, Callable, Iterable from collections.abc import AsyncIterable, Callable, Iterable
from dataclasses import asdict, dataclass, field from dataclasses import asdict, dataclass, field
import logging import logging
from typing import Any from typing import Any, cast
import voluptuous as vol import voluptuous as vol
@ -332,12 +332,12 @@ class PipelineRun:
event_callback: PipelineEventCallback event_callback: PipelineEventCallback
language: str = None # type: ignore[assignment] language: str = None # type: ignore[assignment]
runner_data: Any | None = None runner_data: Any | None = None
stt_provider: stt.SpeechToTextEntity | stt.Provider | None = None
intent_agent: str | None = None intent_agent: str | None = None
tts_engine: str | None = None
tts_audio_output: str | None = None tts_audio_output: str | None = None
id: str = field(default_factory=ulid_util.ulid) id: str = field(default_factory=ulid_util.ulid)
stt_provider: stt.SpeechToTextEntity | stt.Provider = field(init=False)
tts_engine: str = field(init=False)
tts_options: dict | None = field(init=False, default=None) tts_options: dict | None = field(init=False, default=None)
def __post_init__(self) -> None: def __post_init__(self) -> None:
@ -388,8 +388,6 @@ class PipelineRun:
async def prepare_speech_to_text(self, metadata: stt.SpeechMetadata) -> None: async def prepare_speech_to_text(self, metadata: stt.SpeechMetadata) -> None:
"""Prepare speech to text.""" """Prepare speech to text."""
stt_provider: stt.SpeechToTextEntity | stt.Provider | None = None
# pipeline.stt_engine can't be None or this function is not called # pipeline.stt_engine can't be None or this function is not called
stt_provider = stt.async_get_speech_to_text_engine( stt_provider = stt.async_get_speech_to_text_engine(
self.hass, self.hass,
@ -422,9 +420,6 @@ class PipelineRun:
stream: AsyncIterable[bytes], stream: AsyncIterable[bytes],
) -> str: ) -> str:
"""Run speech to text portion of pipeline. Returns the spoken text.""" """Run speech to text portion of pipeline. Returns the spoken text."""
if self.stt_provider is None:
raise RuntimeError("Speech to text was not prepared")
if isinstance(self.stt_provider, stt.Provider): if isinstance(self.stt_provider, stt.Provider):
engine = self.stt_provider.name engine = self.stt_provider.name
else: else:
@ -547,7 +542,8 @@ class PipelineRun:
async def prepare_text_to_speech(self) -> None: async def prepare_text_to_speech(self) -> None:
"""Prepare text to speech.""" """Prepare text to speech."""
engine = self.pipeline.tts_engine # pipeline.tts_engine can't be None or this function is not called
engine = cast(str, self.pipeline.tts_engine)
tts_options = {} tts_options = {}
if self.pipeline.tts_voice is not None: if self.pipeline.tts_voice is not None:
@ -557,13 +553,18 @@ class PipelineRun:
tts_options[tts.ATTR_AUDIO_OUTPUT] = self.tts_audio_output tts_options[tts.ATTR_AUDIO_OUTPUT] = self.tts_audio_output
try: try:
# pipeline.tts_engine can't be None or this function is not called options_supported = await tts.async_support_options(
if not await tts.async_support_options(
self.hass, self.hass,
engine, # type: ignore[arg-type] engine,
self.pipeline.tts_language, self.pipeline.tts_language,
tts_options, tts_options,
): )
except HomeAssistantError as err:
raise TextToSpeechError(
code="tts-not-supported",
message=f"Text to speech engine '{engine}' not found",
) from err
if not options_supported:
raise TextToSpeechError( raise TextToSpeechError(
code="tts-not-supported", code="tts-not-supported",
message=( message=(
@ -571,20 +572,12 @@ class PipelineRun:
f"does not support language {self.pipeline.tts_language} or options {tts_options}" f"does not support language {self.pipeline.tts_language} or options {tts_options}"
), ),
) )
except HomeAssistantError as err:
raise TextToSpeechError(
code="tts-not-supported",
message=f"Text to speech engine '{engine}' not found",
) from err
self.tts_engine = engine self.tts_engine = engine
self.tts_options = tts_options self.tts_options = tts_options
async def text_to_speech(self, tts_input: str) -> str: async def text_to_speech(self, tts_input: str) -> str:
"""Run text to speech portion of pipeline. Returns URL of TTS audio.""" """Run text to speech portion of pipeline. Returns URL of TTS audio."""
if self.tts_engine is None:
raise RuntimeError("Text to speech was not prepared")
self.process_event( self.process_event(
PipelineEvent( PipelineEvent(
PipelineEventType.TTS_START, PipelineEventType.TTS_START,

View File

@ -241,3 +241,42 @@ async def test_pipeline_from_audio_stream_no_stt(
) )
assert not events assert not events
async def test_pipeline_from_audio_stream_unknown_pipeline(
hass: HomeAssistant,
hass_ws_client: WebSocketGenerator,
mock_stt_provider: MockSttProvider,
init_components,
snapshot: SnapshotAssertion,
) -> None:
"""Test creating a pipeline from an audio stream.
In this test, the pipeline does not exist.
"""
events = []
async def audio_data():
yield b"part1"
yield b"part2"
yield b""
# Try to use the created pipeline
with pytest.raises(assist_pipeline.PipelineNotFound):
await assist_pipeline.async_pipeline_from_audio_stream(
hass,
Context(),
events.append,
stt.SpeechMetadata(
language="en-UK",
format=stt.AudioFormats.WAV,
codec=stt.AudioCodecs.PCM,
bit_rate=stt.AudioBitRates.BITRATE_16,
sample_rate=stt.AudioSampleRates.SAMPLERATE_16000,
channel=stt.AudioChannels.CHANNEL_MONO,
),
audio_data(),
pipeline_id="blah",
)
assert not events

View File

@ -7,6 +7,7 @@ from syrupy.assertion import SnapshotAssertion
from homeassistant.components.assist_pipeline.const import DOMAIN from homeassistant.components.assist_pipeline.const import DOMAIN
from homeassistant.components.assist_pipeline.pipeline import Pipeline, PipelineData from homeassistant.components.assist_pipeline.pipeline import Pipeline, PipelineData
from homeassistant.core import HomeAssistant from homeassistant.core import HomeAssistant
from homeassistant.exceptions import HomeAssistantError
from tests.typing import WebSocketGenerator from tests.typing import WebSocketGenerator
@ -430,6 +431,34 @@ async def test_stt_provider_missing(
assert msg["error"]["code"] == "stt-provider-missing" assert msg["error"]["code"] == "stt-provider-missing"
async def test_stt_provider_bad_metadata(
hass: HomeAssistant,
hass_ws_client: WebSocketGenerator,
init_components,
mock_stt_provider,
snapshot: SnapshotAssertion,
) -> None:
"""Test events from a pipeline run with wrong metadata."""
with patch.object(mock_stt_provider, "check_metadata", return_value=False):
client = await hass_ws_client(hass)
await client.send_json_auto_id(
{
"type": "assist_pipeline/run",
"start_stage": "stt",
"end_stage": "tts",
"input": {
"sample_rate": 12345,
},
}
)
# result
msg = await client.receive_json()
assert not msg["success"]
assert msg["error"]["code"] == "stt-provider-unsupported-metadata"
async def test_stt_stream_failed( async def test_stt_stream_failed(
hass: HomeAssistant, hass: HomeAssistant,
hass_ws_client: WebSocketGenerator, hass_ws_client: WebSocketGenerator,
@ -559,6 +588,64 @@ async def test_tts_failed(
assert msg["result"] == {"events": events} assert msg["result"] == {"events": events}
async def test_tts_provider_missing(
hass: HomeAssistant,
hass_ws_client: WebSocketGenerator,
init_components,
mock_tts_provider,
snapshot: SnapshotAssertion,
) -> None:
"""Test pipeline run with text to speech error."""
client = await hass_ws_client(hass)
with patch(
"homeassistant.components.tts.async_support_options",
side_effect=HomeAssistantError,
):
await client.send_json_auto_id(
{
"type": "assist_pipeline/run",
"start_stage": "tts",
"end_stage": "tts",
"input": {"text": "Lights are on."},
}
)
# result
msg = await client.receive_json()
assert not msg["success"]
assert msg["error"]["code"] == "tts-not-supported"
async def test_tts_provider_bad_options(
hass: HomeAssistant,
hass_ws_client: WebSocketGenerator,
init_components,
mock_tts_provider,
snapshot: SnapshotAssertion,
) -> None:
"""Test pipeline run with text to speech error."""
client = await hass_ws_client(hass)
with patch(
"homeassistant.components.tts.async_support_options",
return_value=False,
):
await client.send_json_auto_id(
{
"type": "assist_pipeline/run",
"start_stage": "tts",
"end_stage": "tts",
"input": {"text": "Lights are on."},
}
)
# result
msg = await client.receive_json()
assert not msg["success"]
assert msg["error"]["code"] == "tts-not-supported"
async def test_invalid_stage_order( async def test_invalid_stage_order(
hass: HomeAssistant, hass_ws_client: WebSocketGenerator, init_components hass: HomeAssistant, hass_ws_client: WebSocketGenerator, init_components
) -> None: ) -> None: