mirror of
https://github.com/home-assistant/core.git
synced 2025-07-23 13:17:32 +00:00
Clean up Speech-to-text integration and add tests (#79012)
This commit is contained in:
parent
1b144c0e4d
commit
5774664234
@ -3,6 +3,7 @@ from __future__ import annotations
|
|||||||
|
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
import asyncio
|
import asyncio
|
||||||
|
from dataclasses import asdict, dataclass
|
||||||
import logging
|
import logging
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
@ -13,7 +14,6 @@ from aiohttp.web_exceptions import (
|
|||||||
HTTPNotFound,
|
HTTPNotFound,
|
||||||
HTTPUnsupportedMediaType,
|
HTTPUnsupportedMediaType,
|
||||||
)
|
)
|
||||||
import attr
|
|
||||||
|
|
||||||
from homeassistant.components.http import HomeAssistantView
|
from homeassistant.components.http import HomeAssistantView
|
||||||
from homeassistant.core import HomeAssistant, callback
|
from homeassistant.core import HomeAssistant, callback
|
||||||
@ -34,9 +34,18 @@ from .const import (
|
|||||||
_LOGGER = logging.getLogger(__name__)
|
_LOGGER = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@callback
|
||||||
|
def async_get_provider(hass: HomeAssistant, domain: str | None = None) -> Provider:
|
||||||
|
"""Return provider."""
|
||||||
|
if domain is None:
|
||||||
|
domain = next(iter(hass.data[DOMAIN]))
|
||||||
|
|
||||||
|
return hass.data[DOMAIN][domain]
|
||||||
|
|
||||||
|
|
||||||
async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
|
async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
|
||||||
"""Set up STT."""
|
"""Set up STT."""
|
||||||
providers = {}
|
providers = hass.data[DOMAIN] = {}
|
||||||
|
|
||||||
async def async_setup_platform(p_type, p_config=None, discovery_info=None):
|
async def async_setup_platform(p_type, p_config=None, discovery_info=None):
|
||||||
"""Set up a TTS platform."""
|
"""Set up a TTS platform."""
|
||||||
@ -80,24 +89,30 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
|
|||||||
return True
|
return True
|
||||||
|
|
||||||
|
|
||||||
@attr.s
|
@dataclass
|
||||||
class SpeechMetadata:
|
class SpeechMetadata:
|
||||||
"""Metadata of audio stream."""
|
"""Metadata of audio stream."""
|
||||||
|
|
||||||
language: str = attr.ib()
|
language: str
|
||||||
format: AudioFormats = attr.ib()
|
format: AudioFormats
|
||||||
codec: AudioCodecs = attr.ib()
|
codec: AudioCodecs
|
||||||
bit_rate: AudioBitRates = attr.ib(converter=int)
|
bit_rate: AudioBitRates
|
||||||
sample_rate: AudioSampleRates = attr.ib(converter=int)
|
sample_rate: AudioSampleRates
|
||||||
channel: AudioChannels = attr.ib(converter=int)
|
channel: AudioChannels
|
||||||
|
|
||||||
|
def __post_init__(self) -> None:
|
||||||
|
"""Finish initializing the metadata."""
|
||||||
|
self.bit_rate = AudioBitRates(int(self.bit_rate))
|
||||||
|
self.sample_rate = AudioSampleRates(int(self.sample_rate))
|
||||||
|
self.channel = AudioChannels(int(self.channel))
|
||||||
|
|
||||||
|
|
||||||
@attr.s
|
@dataclass
|
||||||
class SpeechResult:
|
class SpeechResult:
|
||||||
"""Result of audio Speech."""
|
"""Result of audio Speech."""
|
||||||
|
|
||||||
text: str | None = attr.ib()
|
text: str | None
|
||||||
result: SpeechResultState = attr.ib()
|
result: SpeechResultState
|
||||||
|
|
||||||
|
|
||||||
class Provider(ABC):
|
class Provider(ABC):
|
||||||
@ -171,30 +186,6 @@ class SpeechToTextView(HomeAssistantView):
|
|||||||
"""Initialize a tts view."""
|
"""Initialize a tts view."""
|
||||||
self.providers = providers
|
self.providers = providers
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _metadata_from_header(request: web.Request) -> SpeechMetadata | None:
|
|
||||||
"""Extract metadata from header.
|
|
||||||
|
|
||||||
X-Speech-Content: format=wav; codec=pcm; sample_rate=16000; bit_rate=16; channel=1; language=de_de
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
data = request.headers[istr("X-Speech-Content")].split(";")
|
|
||||||
except KeyError:
|
|
||||||
_LOGGER.warning("Missing X-Speech-Content")
|
|
||||||
return None
|
|
||||||
|
|
||||||
# Convert Header data
|
|
||||||
args: dict[str, Any] = {}
|
|
||||||
for value in data:
|
|
||||||
value = value.strip()
|
|
||||||
args[value.partition("=")[0]] = value.partition("=")[2]
|
|
||||||
|
|
||||||
try:
|
|
||||||
return SpeechMetadata(**args)
|
|
||||||
except TypeError as err:
|
|
||||||
_LOGGER.warning("Wrong format of X-Speech-Content: %s", err)
|
|
||||||
return None
|
|
||||||
|
|
||||||
async def post(self, request: web.Request, provider: str) -> web.Response:
|
async def post(self, request: web.Request, provider: str) -> web.Response:
|
||||||
"""Convert Speech (audio) to text."""
|
"""Convert Speech (audio) to text."""
|
||||||
if provider not in self.providers:
|
if provider not in self.providers:
|
||||||
@ -202,9 +193,10 @@ class SpeechToTextView(HomeAssistantView):
|
|||||||
stt_provider: Provider = self.providers[provider]
|
stt_provider: Provider = self.providers[provider]
|
||||||
|
|
||||||
# Get metadata
|
# Get metadata
|
||||||
metadata = self._metadata_from_header(request)
|
try:
|
||||||
if not metadata:
|
metadata = metadata_from_header(request)
|
||||||
raise HTTPBadRequest()
|
except ValueError as err:
|
||||||
|
raise HTTPBadRequest(text=str(err)) from err
|
||||||
|
|
||||||
# Check format
|
# Check format
|
||||||
if not stt_provider.check_metadata(metadata):
|
if not stt_provider.check_metadata(metadata):
|
||||||
@ -216,7 +208,7 @@ class SpeechToTextView(HomeAssistantView):
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Return result
|
# Return result
|
||||||
return self.json(attr.asdict(result))
|
return self.json(asdict(result))
|
||||||
|
|
||||||
async def get(self, request: web.Request, provider: str) -> web.Response:
|
async def get(self, request: web.Request, provider: str) -> web.Response:
|
||||||
"""Return provider specific audio information."""
|
"""Return provider specific audio information."""
|
||||||
@ -234,3 +226,47 @@ class SpeechToTextView(HomeAssistantView):
|
|||||||
"channels": stt_provider.supported_channels,
|
"channels": stt_provider.supported_channels,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def metadata_from_header(request: web.Request) -> SpeechMetadata:
|
||||||
|
"""Extract STT metadata from header.
|
||||||
|
|
||||||
|
X-Speech-Content: format=wav; codec=pcm; sample_rate=16000; bit_rate=16; channel=1; language=de_de
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
data = request.headers[istr("X-Speech-Content")].split(";")
|
||||||
|
except KeyError as err:
|
||||||
|
raise ValueError("Missing X-Speech-Content header") from err
|
||||||
|
|
||||||
|
fields = (
|
||||||
|
"language",
|
||||||
|
"format",
|
||||||
|
"codec",
|
||||||
|
"bit_rate",
|
||||||
|
"sample_rate",
|
||||||
|
"channel",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Convert Header data
|
||||||
|
args: dict[str, Any] = {}
|
||||||
|
for entry in data:
|
||||||
|
key, _, value = entry.strip().partition("=")
|
||||||
|
if key not in fields:
|
||||||
|
raise ValueError(f"Invalid field {key}")
|
||||||
|
args[key] = value
|
||||||
|
|
||||||
|
for field in fields:
|
||||||
|
if field not in args:
|
||||||
|
raise ValueError(f"Missing {field} in X-Speech-Content header")
|
||||||
|
|
||||||
|
try:
|
||||||
|
return SpeechMetadata(
|
||||||
|
language=args["language"],
|
||||||
|
format=args["format"],
|
||||||
|
codec=args["codec"],
|
||||||
|
bit_rate=args["bit_rate"],
|
||||||
|
sample_rate=args["sample_rate"],
|
||||||
|
channel=args["channel"],
|
||||||
|
)
|
||||||
|
except TypeError as err:
|
||||||
|
raise ValueError(f"Wrong format of X-Speech-Content: {err}") from err
|
||||||
|
@ -1,30 +1,155 @@
|
|||||||
"""Test STT component setup."""
|
"""Test STT component setup."""
|
||||||
|
from asyncio import StreamReader
|
||||||
from http import HTTPStatus
|
from http import HTTPStatus
|
||||||
|
from unittest.mock import AsyncMock, Mock
|
||||||
|
|
||||||
from homeassistant.components import stt
|
import pytest
|
||||||
|
|
||||||
|
from homeassistant.components.stt import (
|
||||||
|
AudioBitRates,
|
||||||
|
AudioChannels,
|
||||||
|
AudioCodecs,
|
||||||
|
AudioFormats,
|
||||||
|
AudioSampleRates,
|
||||||
|
Provider,
|
||||||
|
SpeechMetadata,
|
||||||
|
SpeechResult,
|
||||||
|
SpeechResultState,
|
||||||
|
async_get_provider,
|
||||||
|
)
|
||||||
from homeassistant.setup import async_setup_component
|
from homeassistant.setup import async_setup_component
|
||||||
|
|
||||||
|
from tests.common import mock_platform
|
||||||
async def test_setup_comp(hass):
|
|
||||||
"""Set up demo component."""
|
|
||||||
assert await async_setup_component(hass, stt.DOMAIN, {"stt": {}})
|
|
||||||
|
|
||||||
|
|
||||||
async def test_demo_settings_not_exists(hass, hass_client):
|
class TestProvider(Provider):
|
||||||
"""Test retrieve settings from demo provider."""
|
"""Test provider."""
|
||||||
assert await async_setup_component(hass, stt.DOMAIN, {"stt": {}})
|
|
||||||
|
fail_process_audio = False
|
||||||
|
|
||||||
|
def __init__(self) -> None:
|
||||||
|
"""Init test provider."""
|
||||||
|
self.calls = []
|
||||||
|
|
||||||
|
@property
|
||||||
|
def supported_languages(self):
|
||||||
|
"""Return a list of supported languages."""
|
||||||
|
return ["en"]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def supported_formats(self) -> list[AudioFormats]:
|
||||||
|
"""Return a list of supported formats."""
|
||||||
|
return [AudioFormats.WAV, AudioFormats.OGG]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def supported_codecs(self) -> list[AudioCodecs]:
|
||||||
|
"""Return a list of supported codecs."""
|
||||||
|
return [AudioCodecs.PCM, AudioCodecs.OPUS]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def supported_bit_rates(self) -> list[AudioBitRates]:
|
||||||
|
"""Return a list of supported bitrates."""
|
||||||
|
return [AudioBitRates.BITRATE_16]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def supported_sample_rates(self) -> list[AudioSampleRates]:
|
||||||
|
"""Return a list of supported samplerates."""
|
||||||
|
return [AudioSampleRates.SAMPLERATE_16000]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def supported_channels(self) -> list[AudioChannels]:
|
||||||
|
"""Return a list of supported channels."""
|
||||||
|
return [AudioChannels.CHANNEL_MONO]
|
||||||
|
|
||||||
|
async def async_process_audio_stream(
|
||||||
|
self, metadata: SpeechMetadata, stream: StreamReader
|
||||||
|
) -> SpeechResult:
|
||||||
|
"""Process an audio stream."""
|
||||||
|
self.calls.append((metadata, stream))
|
||||||
|
if self.fail_process_audio:
|
||||||
|
return SpeechResult(None, SpeechResultState.ERROR)
|
||||||
|
|
||||||
|
return SpeechResult("test", SpeechResultState.SUCCESS)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def test_provider():
|
||||||
|
"""Test provider fixture."""
|
||||||
|
return TestProvider()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(autouse=True)
|
||||||
|
async def mock_setup(hass, test_provider):
|
||||||
|
"""Set up a test provider."""
|
||||||
|
mock_platform(
|
||||||
|
hass, "test.stt", Mock(async_get_engine=AsyncMock(return_value=test_provider))
|
||||||
|
)
|
||||||
|
assert await async_setup_component(hass, "stt", {"stt": {"platform": "test"}})
|
||||||
|
|
||||||
|
|
||||||
|
async def test_get_provider_info(hass, hass_client):
|
||||||
|
"""Test engine that doesn't exist."""
|
||||||
client = await hass_client()
|
client = await hass_client()
|
||||||
|
response = await client.get("/api/stt/test")
|
||||||
|
assert response.status == HTTPStatus.OK
|
||||||
|
assert await response.json() == {
|
||||||
|
"languages": ["en"],
|
||||||
|
"formats": ["wav", "ogg"],
|
||||||
|
"codecs": ["pcm", "opus"],
|
||||||
|
"sample_rates": [16000],
|
||||||
|
"bit_rates": [16],
|
||||||
|
"channels": [1],
|
||||||
|
}
|
||||||
|
|
||||||
response = await client.get("/api/stt/beer")
|
|
||||||
|
|
||||||
|
async def test_get_non_existing_provider_info(hass, hass_client):
|
||||||
|
"""Test streaming to engine that doesn't exist."""
|
||||||
|
client = await hass_client()
|
||||||
|
response = await client.get("/api/stt/not_exist")
|
||||||
assert response.status == HTTPStatus.NOT_FOUND
|
assert response.status == HTTPStatus.NOT_FOUND
|
||||||
|
|
||||||
|
|
||||||
async def test_demo_speech_not_exists(hass, hass_client):
|
async def test_stream_audio(hass, hass_client):
|
||||||
"""Test retrieve settings from demo provider."""
|
"""Test streaming audio and getting response."""
|
||||||
assert await async_setup_component(hass, stt.DOMAIN, {"stt": {}})
|
|
||||||
client = await hass_client()
|
client = await hass_client()
|
||||||
|
response = await client.post(
|
||||||
|
"/api/stt/test",
|
||||||
|
headers={
|
||||||
|
"X-Speech-Content": "format=wav; codec=pcm; sample_rate=16000; bit_rate=16; channel=1; language=en"
|
||||||
|
},
|
||||||
|
)
|
||||||
|
assert response.status == HTTPStatus.OK
|
||||||
|
assert await response.json() == {"text": "test", "result": "success"}
|
||||||
|
|
||||||
response = await client.post("/api/stt/beer", data=b"test")
|
|
||||||
|
|
||||||
assert response.status == HTTPStatus.NOT_FOUND
|
@pytest.mark.parametrize(
|
||||||
|
"header,status,error",
|
||||||
|
(
|
||||||
|
(None, 400, "Missing X-Speech-Content header"),
|
||||||
|
(
|
||||||
|
"format=wav; codec=pcm; sample_rate=16000; bit_rate=16; channel=100; language=en",
|
||||||
|
400,
|
||||||
|
"100 is not a valid AudioChannels",
|
||||||
|
),
|
||||||
|
(
|
||||||
|
"format=wav; codec=pcm; sample_rate=16000",
|
||||||
|
400,
|
||||||
|
"Missing language in X-Speech-Content header",
|
||||||
|
),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
async def test_metadata_errors(hass, hass_client, header, status, error):
|
||||||
|
"""Test metadata errors."""
|
||||||
|
client = await hass_client()
|
||||||
|
headers = {}
|
||||||
|
if header:
|
||||||
|
headers["X-Speech-Content"] = header
|
||||||
|
|
||||||
|
response = await client.post("/api/stt/test", headers=headers)
|
||||||
|
assert response.status == status
|
||||||
|
assert await response.text() == error
|
||||||
|
|
||||||
|
|
||||||
|
async def test_get_provider(hass, test_provider):
|
||||||
|
"""Test we can get STT providers."""
|
||||||
|
assert test_provider == async_get_provider(hass, "test")
|
||||||
|
Loading…
x
Reference in New Issue
Block a user