From 57746642349a3ca62959de4447a4eb5963a84ae1 Mon Sep 17 00:00:00 2001 From: Paulus Schoutsen Date: Sat, 24 Sep 2022 03:58:01 -0400 Subject: [PATCH] Clean up Speech-to-text integration and add tests (#79012) --- homeassistant/components/stt/__init__.py | 116 +++++++++++------ tests/components/stt/test_init.py | 153 ++++++++++++++++++++--- 2 files changed, 215 insertions(+), 54 deletions(-) diff --git a/homeassistant/components/stt/__init__.py b/homeassistant/components/stt/__init__.py index 94acf155968..1d68b0a954b 100644 --- a/homeassistant/components/stt/__init__.py +++ b/homeassistant/components/stt/__init__.py @@ -3,6 +3,7 @@ from __future__ import annotations from abc import ABC, abstractmethod import asyncio +from dataclasses import asdict, dataclass import logging from typing import Any @@ -13,7 +14,6 @@ from aiohttp.web_exceptions import ( HTTPNotFound, HTTPUnsupportedMediaType, ) -import attr from homeassistant.components.http import HomeAssistantView from homeassistant.core import HomeAssistant, callback @@ -34,9 +34,18 @@ from .const import ( _LOGGER = logging.getLogger(__name__) +@callback +def async_get_provider(hass: HomeAssistant, domain: str | None = None) -> Provider: + """Return provider.""" + if domain is None: + domain = next(iter(hass.data[DOMAIN])) + + return hass.data[DOMAIN][domain] + + async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: """Set up STT.""" - providers = {} + providers = hass.data[DOMAIN] = {} async def async_setup_platform(p_type, p_config=None, discovery_info=None): """Set up a TTS platform.""" @@ -80,24 +89,30 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: return True -@attr.s +@dataclass class SpeechMetadata: """Metadata of audio stream.""" - language: str = attr.ib() - format: AudioFormats = attr.ib() - codec: AudioCodecs = attr.ib() - bit_rate: AudioBitRates = attr.ib(converter=int) - sample_rate: AudioSampleRates = attr.ib(converter=int) - channel: AudioChannels = attr.ib(converter=int) + language: str + format: AudioFormats + codec: AudioCodecs + bit_rate: AudioBitRates + sample_rate: AudioSampleRates + channel: AudioChannels + + def __post_init__(self) -> None: + """Finish initializing the metadata.""" + self.bit_rate = AudioBitRates(int(self.bit_rate)) + self.sample_rate = AudioSampleRates(int(self.sample_rate)) + self.channel = AudioChannels(int(self.channel)) -@attr.s +@dataclass class SpeechResult: """Result of audio Speech.""" - text: str | None = attr.ib() - result: SpeechResultState = attr.ib() + text: str | None + result: SpeechResultState class Provider(ABC): @@ -171,30 +186,6 @@ class SpeechToTextView(HomeAssistantView): """Initialize a tts view.""" self.providers = providers - @staticmethod - def _metadata_from_header(request: web.Request) -> SpeechMetadata | None: - """Extract metadata from header. - - X-Speech-Content: format=wav; codec=pcm; sample_rate=16000; bit_rate=16; channel=1; language=de_de - """ - try: - data = request.headers[istr("X-Speech-Content")].split(";") - except KeyError: - _LOGGER.warning("Missing X-Speech-Content") - return None - - # Convert Header data - args: dict[str, Any] = {} - for value in data: - value = value.strip() - args[value.partition("=")[0]] = value.partition("=")[2] - - try: - return SpeechMetadata(**args) - except TypeError as err: - _LOGGER.warning("Wrong format of X-Speech-Content: %s", err) - return None - async def post(self, request: web.Request, provider: str) -> web.Response: """Convert Speech (audio) to text.""" if provider not in self.providers: @@ -202,9 +193,10 @@ class SpeechToTextView(HomeAssistantView): stt_provider: Provider = self.providers[provider] # Get metadata - metadata = self._metadata_from_header(request) - if not metadata: - raise HTTPBadRequest() + try: + metadata = metadata_from_header(request) + except ValueError as err: + raise HTTPBadRequest(text=str(err)) from err # Check format if not stt_provider.check_metadata(metadata): @@ -216,7 +208,7 @@ class SpeechToTextView(HomeAssistantView): ) # Return result - return self.json(attr.asdict(result)) + return self.json(asdict(result)) async def get(self, request: web.Request, provider: str) -> web.Response: """Return provider specific audio information.""" @@ -234,3 +226,47 @@ class SpeechToTextView(HomeAssistantView): "channels": stt_provider.supported_channels, } ) + + +def metadata_from_header(request: web.Request) -> SpeechMetadata: + """Extract STT metadata from header. + + X-Speech-Content: format=wav; codec=pcm; sample_rate=16000; bit_rate=16; channel=1; language=de_de + """ + try: + data = request.headers[istr("X-Speech-Content")].split(";") + except KeyError as err: + raise ValueError("Missing X-Speech-Content header") from err + + fields = ( + "language", + "format", + "codec", + "bit_rate", + "sample_rate", + "channel", + ) + + # Convert Header data + args: dict[str, Any] = {} + for entry in data: + key, _, value = entry.strip().partition("=") + if key not in fields: + raise ValueError(f"Invalid field {key}") + args[key] = value + + for field in fields: + if field not in args: + raise ValueError(f"Missing {field} in X-Speech-Content header") + + try: + return SpeechMetadata( + language=args["language"], + format=args["format"], + codec=args["codec"], + bit_rate=args["bit_rate"], + sample_rate=args["sample_rate"], + channel=args["channel"], + ) + except TypeError as err: + raise ValueError(f"Wrong format of X-Speech-Content: {err}") from err diff --git a/tests/components/stt/test_init.py b/tests/components/stt/test_init.py index 3b207fae01a..33242180f77 100644 --- a/tests/components/stt/test_init.py +++ b/tests/components/stt/test_init.py @@ -1,30 +1,155 @@ """Test STT component setup.""" +from asyncio import StreamReader from http import HTTPStatus +from unittest.mock import AsyncMock, Mock -from homeassistant.components import stt +import pytest + +from homeassistant.components.stt import ( + AudioBitRates, + AudioChannels, + AudioCodecs, + AudioFormats, + AudioSampleRates, + Provider, + SpeechMetadata, + SpeechResult, + SpeechResultState, + async_get_provider, +) from homeassistant.setup import async_setup_component - -async def test_setup_comp(hass): - """Set up demo component.""" - assert await async_setup_component(hass, stt.DOMAIN, {"stt": {}}) +from tests.common import mock_platform -async def test_demo_settings_not_exists(hass, hass_client): - """Test retrieve settings from demo provider.""" - assert await async_setup_component(hass, stt.DOMAIN, {"stt": {}}) +class TestProvider(Provider): + """Test provider.""" + + fail_process_audio = False + + def __init__(self) -> None: + """Init test provider.""" + self.calls = [] + + @property + def supported_languages(self): + """Return a list of supported languages.""" + return ["en"] + + @property + def supported_formats(self) -> list[AudioFormats]: + """Return a list of supported formats.""" + return [AudioFormats.WAV, AudioFormats.OGG] + + @property + def supported_codecs(self) -> list[AudioCodecs]: + """Return a list of supported codecs.""" + return [AudioCodecs.PCM, AudioCodecs.OPUS] + + @property + def supported_bit_rates(self) -> list[AudioBitRates]: + """Return a list of supported bitrates.""" + return [AudioBitRates.BITRATE_16] + + @property + def supported_sample_rates(self) -> list[AudioSampleRates]: + """Return a list of supported samplerates.""" + return [AudioSampleRates.SAMPLERATE_16000] + + @property + def supported_channels(self) -> list[AudioChannels]: + """Return a list of supported channels.""" + return [AudioChannels.CHANNEL_MONO] + + async def async_process_audio_stream( + self, metadata: SpeechMetadata, stream: StreamReader + ) -> SpeechResult: + """Process an audio stream.""" + self.calls.append((metadata, stream)) + if self.fail_process_audio: + return SpeechResult(None, SpeechResultState.ERROR) + + return SpeechResult("test", SpeechResultState.SUCCESS) + + +@pytest.fixture +def test_provider(): + """Test provider fixture.""" + return TestProvider() + + +@pytest.fixture(autouse=True) +async def mock_setup(hass, test_provider): + """Set up a test provider.""" + mock_platform( + hass, "test.stt", Mock(async_get_engine=AsyncMock(return_value=test_provider)) + ) + assert await async_setup_component(hass, "stt", {"stt": {"platform": "test"}}) + + +async def test_get_provider_info(hass, hass_client): + """Test engine that doesn't exist.""" client = await hass_client() + response = await client.get("/api/stt/test") + assert response.status == HTTPStatus.OK + assert await response.json() == { + "languages": ["en"], + "formats": ["wav", "ogg"], + "codecs": ["pcm", "opus"], + "sample_rates": [16000], + "bit_rates": [16], + "channels": [1], + } - response = await client.get("/api/stt/beer") +async def test_get_non_existing_provider_info(hass, hass_client): + """Test streaming to engine that doesn't exist.""" + client = await hass_client() + response = await client.get("/api/stt/not_exist") assert response.status == HTTPStatus.NOT_FOUND -async def test_demo_speech_not_exists(hass, hass_client): - """Test retrieve settings from demo provider.""" - assert await async_setup_component(hass, stt.DOMAIN, {"stt": {}}) +async def test_stream_audio(hass, hass_client): + """Test streaming audio and getting response.""" client = await hass_client() + response = await client.post( + "/api/stt/test", + headers={ + "X-Speech-Content": "format=wav; codec=pcm; sample_rate=16000; bit_rate=16; channel=1; language=en" + }, + ) + assert response.status == HTTPStatus.OK + assert await response.json() == {"text": "test", "result": "success"} - response = await client.post("/api/stt/beer", data=b"test") - assert response.status == HTTPStatus.NOT_FOUND +@pytest.mark.parametrize( + "header,status,error", + ( + (None, 400, "Missing X-Speech-Content header"), + ( + "format=wav; codec=pcm; sample_rate=16000; bit_rate=16; channel=100; language=en", + 400, + "100 is not a valid AudioChannels", + ), + ( + "format=wav; codec=pcm; sample_rate=16000", + 400, + "Missing language in X-Speech-Content header", + ), + ), +) +async def test_metadata_errors(hass, hass_client, header, status, error): + """Test metadata errors.""" + client = await hass_client() + headers = {} + if header: + headers["X-Speech-Content"] = header + + response = await client.post("/api/stt/test", headers=headers) + assert response.status == status + assert await response.text() == error + + +async def test_get_provider(hass, test_provider): + """Test we can get STT providers.""" + assert test_provider == async_get_provider(hass, "test")