mirror of
https://github.com/home-assistant/core.git
synced 2025-04-25 09:47:52 +00:00
Add stt entity (#91230)
* Add stt entity * Update demo platform * Rename ProviderEntity to SpeechToTextEntity * Fix get method * Run all init tests for config entry setup * Fix and test metadata from header * Test config entry unload * Rename get provider entity * Test post for non existing provider * Test entity name before addition * Test restore state * Use register shutdown * Update deprecation comment
This commit is contained in:
parent
22a1a6846d
commit
473cbf7f9b
@ -36,6 +36,7 @@ COMPONENTS_WITH_CONFIG_ENTRY_DEMO_PLATFORM = [
|
||||
Platform.SELECT,
|
||||
Platform.SENSOR,
|
||||
Platform.SIREN,
|
||||
Platform.STT,
|
||||
Platform.SWITCH,
|
||||
Platform.TEXT,
|
||||
Platform.UPDATE,
|
||||
|
@ -13,8 +13,11 @@ from homeassistant.components.stt import (
|
||||
SpeechMetadata,
|
||||
SpeechResult,
|
||||
SpeechResultState,
|
||||
SpeechToTextEntity,
|
||||
)
|
||||
from homeassistant.config_entries import ConfigEntry
|
||||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.helpers.entity_platform import AddEntitiesCallback
|
||||
from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType
|
||||
|
||||
SUPPORT_LANGUAGES = ["en", "de"]
|
||||
@ -29,6 +32,60 @@ async def async_get_engine(
|
||||
return DemoProvider()
|
||||
|
||||
|
||||
async def async_setup_entry(
|
||||
hass: HomeAssistant,
|
||||
config_entry: ConfigEntry,
|
||||
async_add_entities: AddEntitiesCallback,
|
||||
) -> None:
|
||||
"""Set up Demo speech platform via config entry."""
|
||||
async_add_entities([DemoProviderEntity()])
|
||||
|
||||
|
||||
class DemoProviderEntity(SpeechToTextEntity):
|
||||
"""Demo speech API provider entity."""
|
||||
|
||||
@property
|
||||
def supported_languages(self) -> list[str]:
|
||||
"""Return a list of supported languages."""
|
||||
return SUPPORT_LANGUAGES
|
||||
|
||||
@property
|
||||
def supported_formats(self) -> list[AudioFormats]:
|
||||
"""Return a list of supported formats."""
|
||||
return [AudioFormats.WAV]
|
||||
|
||||
@property
|
||||
def supported_codecs(self) -> list[AudioCodecs]:
|
||||
"""Return a list of supported codecs."""
|
||||
return [AudioCodecs.PCM]
|
||||
|
||||
@property
|
||||
def supported_bit_rates(self) -> list[AudioBitRates]:
|
||||
"""Return a list of supported bit rates."""
|
||||
return [AudioBitRates.BITRATE_16]
|
||||
|
||||
@property
|
||||
def supported_sample_rates(self) -> list[AudioSampleRates]:
|
||||
"""Return a list of supported sample rates."""
|
||||
return [AudioSampleRates.SAMPLERATE_16000, AudioSampleRates.SAMPLERATE_44100]
|
||||
|
||||
@property
|
||||
def supported_channels(self) -> list[AudioChannels]:
|
||||
"""Return a list of supported channels."""
|
||||
return [AudioChannels.CHANNEL_STEREO]
|
||||
|
||||
async def async_process_audio_stream(
|
||||
self, metadata: SpeechMetadata, stream: AsyncIterable[bytes]
|
||||
) -> SpeechResult:
|
||||
"""Process an audio stream to STT service."""
|
||||
|
||||
# Read available data
|
||||
async for _ in stream:
|
||||
pass
|
||||
|
||||
return SpeechResult("Turn the Kitchen Lights on", SpeechResultState.SUCCESS)
|
||||
|
||||
|
||||
class DemoProvider(Provider):
|
||||
"""Demo speech API provider."""
|
||||
|
||||
|
@ -1,9 +1,12 @@
|
||||
"""Provide functionality to STT."""
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import abstractmethod
|
||||
import asyncio
|
||||
from collections.abc import AsyncIterable
|
||||
from dataclasses import asdict
|
||||
from typing import Any
|
||||
import logging
|
||||
from typing import Any, final
|
||||
|
||||
from aiohttp import web
|
||||
from aiohttp.hdrs import istr
|
||||
@ -14,10 +17,16 @@ from aiohttp.web_exceptions import (
|
||||
)
|
||||
|
||||
from homeassistant.components.http import HomeAssistantView
|
||||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.config_entries import ConfigEntry
|
||||
from homeassistant.const import STATE_UNAVAILABLE, STATE_UNKNOWN
|
||||
from homeassistant.core import HomeAssistant, callback
|
||||
from homeassistant.helpers.entity_component import EntityComponent
|
||||
from homeassistant.helpers.restore_state import RestoreEntity
|
||||
from homeassistant.helpers.typing import ConfigType
|
||||
from homeassistant.util import dt as dt_util
|
||||
|
||||
from .const import (
|
||||
DATA_PROVIDERS,
|
||||
DOMAIN,
|
||||
AudioBitRates,
|
||||
AudioChannels,
|
||||
@ -36,6 +45,7 @@ from .legacy import (
|
||||
|
||||
__all__ = [
|
||||
"async_get_provider",
|
||||
"async_get_speech_to_text_entity",
|
||||
"AudioBitRates",
|
||||
"AudioChannels",
|
||||
"AudioCodecs",
|
||||
@ -43,26 +53,158 @@ __all__ = [
|
||||
"AudioSampleRates",
|
||||
"DOMAIN",
|
||||
"Provider",
|
||||
"SpeechToTextEntity",
|
||||
"SpeechMetadata",
|
||||
"SpeechResult",
|
||||
"SpeechResultState",
|
||||
]
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@callback
|
||||
def async_get_speech_to_text_entity(
|
||||
hass: HomeAssistant, entity_id: str
|
||||
) -> SpeechToTextEntity | None:
|
||||
"""Return stt entity."""
|
||||
component: EntityComponent[SpeechToTextEntity] = hass.data[DOMAIN]
|
||||
|
||||
return component.get_entity(entity_id)
|
||||
|
||||
|
||||
async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
|
||||
"""Set up STT."""
|
||||
component = hass.data[DOMAIN] = EntityComponent[SpeechToTextEntity](
|
||||
_LOGGER, DOMAIN, hass
|
||||
)
|
||||
|
||||
component.register_shutdown()
|
||||
platform_setups = async_setup_legacy(hass, config)
|
||||
|
||||
if platform_setups:
|
||||
await asyncio.wait([asyncio.create_task(setup) for setup in platform_setups])
|
||||
|
||||
hass.http.register_view(SpeechToTextView(hass.data[DOMAIN]))
|
||||
hass.http.register_view(SpeechToTextView(hass.data[DATA_PROVIDERS]))
|
||||
return True
|
||||
|
||||
|
||||
async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
|
||||
"""Set up a config entry."""
|
||||
component: EntityComponent[SpeechToTextEntity] = hass.data[DOMAIN]
|
||||
return await component.async_setup_entry(entry)
|
||||
|
||||
|
||||
async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
|
||||
"""Unload a config entry."""
|
||||
component: EntityComponent[SpeechToTextEntity] = hass.data[DOMAIN]
|
||||
return await component.async_unload_entry(entry)
|
||||
|
||||
|
||||
class SpeechToTextEntity(RestoreEntity):
|
||||
"""Represent a single STT provider."""
|
||||
|
||||
_attr_should_poll = False
|
||||
__last_processed: str | None = None
|
||||
|
||||
@property
|
||||
@final
|
||||
def name(self) -> str:
|
||||
"""Return the name of the provider entity."""
|
||||
# Only one entity is allowed per platform for now.
|
||||
if self.platform is None:
|
||||
raise RuntimeError("Entity is not added to hass yet.")
|
||||
|
||||
return self.platform.platform_name
|
||||
|
||||
@property
|
||||
@final
|
||||
def state(self) -> str | None:
|
||||
"""Return the state of the provider entity."""
|
||||
if self.__last_processed is None:
|
||||
return None
|
||||
return self.__last_processed
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def supported_languages(self) -> list[str]:
|
||||
"""Return a list of supported languages."""
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def supported_formats(self) -> list[AudioFormats]:
|
||||
"""Return a list of supported formats."""
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def supported_codecs(self) -> list[AudioCodecs]:
|
||||
"""Return a list of supported codecs."""
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def supported_bit_rates(self) -> list[AudioBitRates]:
|
||||
"""Return a list of supported bit rates."""
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def supported_sample_rates(self) -> list[AudioSampleRates]:
|
||||
"""Return a list of supported sample rates."""
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def supported_channels(self) -> list[AudioChannels]:
|
||||
"""Return a list of supported channels."""
|
||||
|
||||
async def async_internal_added_to_hass(self) -> None:
|
||||
"""Call when the provider entity is added to hass."""
|
||||
await super().async_internal_added_to_hass()
|
||||
state = await self.async_get_last_state()
|
||||
if (
|
||||
state is not None
|
||||
and state.state is not None
|
||||
and state.state not in (STATE_UNAVAILABLE, STATE_UNKNOWN)
|
||||
):
|
||||
self.__last_processed = state.state
|
||||
|
||||
@final
|
||||
async def internal_async_process_audio_stream(
|
||||
self, metadata: SpeechMetadata, stream: AsyncIterable[bytes]
|
||||
) -> SpeechResult:
|
||||
"""Process an audio stream to STT service.
|
||||
|
||||
Only streaming content is allowed!
|
||||
"""
|
||||
self.__last_processed = dt_util.utcnow().isoformat()
|
||||
self.async_write_ha_state()
|
||||
return await self.async_process_audio_stream(metadata=metadata, stream=stream)
|
||||
|
||||
@abstractmethod
|
||||
async def async_process_audio_stream(
|
||||
self, metadata: SpeechMetadata, stream: AsyncIterable[bytes]
|
||||
) -> SpeechResult:
|
||||
"""Process an audio stream to STT service.
|
||||
|
||||
Only streaming content is allowed!
|
||||
"""
|
||||
|
||||
@callback
|
||||
def check_metadata(self, metadata: SpeechMetadata) -> bool:
|
||||
"""Check if given metadata supported by this provider."""
|
||||
if (
|
||||
metadata.language not in self.supported_languages
|
||||
or metadata.format not in self.supported_formats
|
||||
or metadata.codec not in self.supported_codecs
|
||||
or metadata.bit_rate not in self.supported_bit_rates
|
||||
or metadata.sample_rate not in self.supported_sample_rates
|
||||
or metadata.channel not in self.supported_channels
|
||||
):
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
class SpeechToTextView(HomeAssistantView):
|
||||
"""STT view to generate a text from audio stream."""
|
||||
|
||||
_legacy_provider_reported = False
|
||||
requires_auth = True
|
||||
url = "/api/stt/{provider}"
|
||||
name = "api:stt:provider"
|
||||
@ -73,9 +215,17 @@ class SpeechToTextView(HomeAssistantView):
|
||||
|
||||
async def post(self, request: web.Request, provider: str) -> web.Response:
|
||||
"""Convert Speech (audio) to text."""
|
||||
if provider not in self.providers:
|
||||
hass: HomeAssistant = request.app["hass"]
|
||||
provider_entity: SpeechToTextEntity | None = None
|
||||
if (
|
||||
not (
|
||||
provider_entity := async_get_speech_to_text_entity(
|
||||
hass, f"{DOMAIN}.{provider}"
|
||||
)
|
||||
)
|
||||
and provider not in self.providers
|
||||
):
|
||||
raise HTTPNotFound()
|
||||
stt_provider: Provider = self.providers[provider]
|
||||
|
||||
# Get metadata
|
||||
try:
|
||||
@ -83,6 +233,9 @@ class SpeechToTextView(HomeAssistantView):
|
||||
except ValueError as err:
|
||||
raise HTTPBadRequest(text=str(err)) from err
|
||||
|
||||
if not provider_entity:
|
||||
stt_provider = self._get_provider(provider)
|
||||
|
||||
# Check format
|
||||
if not stt_provider.check_metadata(metadata):
|
||||
raise HTTPUnsupportedMediaType()
|
||||
@ -91,15 +244,34 @@ class SpeechToTextView(HomeAssistantView):
|
||||
result = await stt_provider.async_process_audio_stream(
|
||||
metadata, request.content
|
||||
)
|
||||
else:
|
||||
# Check format
|
||||
if not provider_entity.check_metadata(metadata):
|
||||
raise HTTPUnsupportedMediaType()
|
||||
|
||||
# Process audio stream
|
||||
result = await provider_entity.internal_async_process_audio_stream(
|
||||
metadata, request.content
|
||||
)
|
||||
|
||||
# Return result
|
||||
return self.json(asdict(result))
|
||||
|
||||
async def get(self, request: web.Request, provider: str) -> web.Response:
|
||||
"""Return provider specific audio information."""
|
||||
if provider not in self.providers:
|
||||
hass: HomeAssistant = request.app["hass"]
|
||||
if (
|
||||
not (
|
||||
provider_entity := async_get_speech_to_text_entity(
|
||||
hass, f"{DOMAIN}.{provider}"
|
||||
)
|
||||
)
|
||||
and provider not in self.providers
|
||||
):
|
||||
raise HTTPNotFound()
|
||||
stt_provider: Provider = self.providers[provider]
|
||||
|
||||
if not provider_entity:
|
||||
stt_provider = self._get_provider(provider)
|
||||
|
||||
return self.json(
|
||||
{
|
||||
@ -112,6 +284,54 @@ class SpeechToTextView(HomeAssistantView):
|
||||
}
|
||||
)
|
||||
|
||||
return self.json(
|
||||
{
|
||||
"languages": provider_entity.supported_languages,
|
||||
"formats": provider_entity.supported_formats,
|
||||
"codecs": provider_entity.supported_codecs,
|
||||
"sample_rates": provider_entity.supported_sample_rates,
|
||||
"bit_rates": provider_entity.supported_bit_rates,
|
||||
"channels": provider_entity.supported_channels,
|
||||
}
|
||||
)
|
||||
|
||||
def _get_provider(self, provider: str) -> Provider:
|
||||
"""Get provider.
|
||||
|
||||
Method for legacy providers.
|
||||
This can be removed when we remove the legacy provider support.
|
||||
"""
|
||||
stt_provider = self.providers[provider]
|
||||
|
||||
if not self._legacy_provider_reported:
|
||||
self._legacy_provider_reported = True
|
||||
report_issue = self._suggest_report_issue(provider, stt_provider)
|
||||
# This should raise in Home Assistant Core 2023.9
|
||||
_LOGGER.warning(
|
||||
"Provider %s (%s) is using a legacy implementation, "
|
||||
"and should be updated to use the SpeechToTextEntity. Please "
|
||||
"%s",
|
||||
provider,
|
||||
type(stt_provider),
|
||||
report_issue,
|
||||
)
|
||||
|
||||
return stt_provider
|
||||
|
||||
def _suggest_report_issue(self, provider: str, provider_instance: object) -> str:
|
||||
"""Suggest to report an issue."""
|
||||
report_issue = ""
|
||||
if "custom_components" in type(provider_instance).__module__:
|
||||
report_issue = "report it to the custom integration author."
|
||||
else:
|
||||
report_issue = (
|
||||
"create a bug report at "
|
||||
"https://github.com/home-assistant/core/issues?q=is%3Aopen+is%3Aissue"
|
||||
)
|
||||
report_issue += f"+label%3A%22integration%3A+{provider}%22"
|
||||
|
||||
return report_issue
|
||||
|
||||
|
||||
def _metadata_from_header(request: web.Request) -> SpeechMetadata:
|
||||
"""Extract STT metadata from header.
|
||||
@ -138,7 +358,7 @@ def _metadata_from_header(request: web.Request) -> SpeechMetadata:
|
||||
for entry in data:
|
||||
key, _, value = entry.strip().partition("=")
|
||||
if key not in fields:
|
||||
raise ValueError(f"Invalid field {key}")
|
||||
raise ValueError(f"Invalid field: {key}")
|
||||
args[key] = value
|
||||
|
||||
for field in fields:
|
||||
@ -154,5 +374,5 @@ def _metadata_from_header(request: web.Request) -> SpeechMetadata:
|
||||
sample_rate=args["sample_rate"],
|
||||
channel=args["channel"],
|
||||
)
|
||||
except TypeError as err:
|
||||
except ValueError as err:
|
||||
raise ValueError(f"Wrong format of X-Speech-Content: {err}") from err
|
||||
|
@ -2,6 +2,7 @@
|
||||
from enum import Enum
|
||||
|
||||
DOMAIN = "stt"
|
||||
DATA_PROVIDERS = f"{DOMAIN}_providers"
|
||||
|
||||
|
||||
class AudioCodecs(str, Enum):
|
||||
|
@ -13,6 +13,7 @@ from homeassistant.helpers.typing import ConfigType
|
||||
from homeassistant.setup import async_prepare_setup_platform
|
||||
|
||||
from .const import (
|
||||
DATA_PROVIDERS,
|
||||
DOMAIN,
|
||||
AudioBitRates,
|
||||
AudioChannels,
|
||||
@ -31,15 +32,15 @@ def async_get_provider(
|
||||
) -> Provider | None:
|
||||
"""Return provider."""
|
||||
if domain:
|
||||
return hass.data[DOMAIN].get(domain)
|
||||
return hass.data[DATA_PROVIDERS].get(domain)
|
||||
|
||||
if not hass.data[DOMAIN]:
|
||||
if not hass.data[DATA_PROVIDERS]:
|
||||
return None
|
||||
|
||||
if "cloud" in hass.data[DOMAIN]:
|
||||
return hass.data[DOMAIN]["cloud"]
|
||||
if "cloud" in hass.data[DATA_PROVIDERS]:
|
||||
return hass.data[DATA_PROVIDERS]["cloud"]
|
||||
|
||||
return next(iter(hass.data[DOMAIN].values()))
|
||||
return next(iter(hass.data[DATA_PROVIDERS].values()))
|
||||
|
||||
|
||||
@callback
|
||||
@ -47,7 +48,7 @@ def async_setup_legacy(
|
||||
hass: HomeAssistant, config: ConfigType
|
||||
) -> list[Coroutine[Any, Any, None]]:
|
||||
"""Set up legacy speech to text providers."""
|
||||
providers = hass.data[DOMAIN] = {}
|
||||
providers = hass.data[DATA_PROVIDERS] = {}
|
||||
|
||||
async def async_setup_platform(p_type, p_config=None, discovery_info=None):
|
||||
"""Set up a TTS platform."""
|
||||
|
@ -4,18 +4,31 @@ from http import HTTPStatus
|
||||
import pytest
|
||||
|
||||
from homeassistant.components import stt
|
||||
from homeassistant.components.demo import DOMAIN as DEMO_DOMAIN
|
||||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.setup import async_setup_component
|
||||
|
||||
from tests.common import MockConfigEntry
|
||||
from tests.typing import ClientSessionGenerator
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
async def setup_comp(hass):
|
||||
"""Set up demo component."""
|
||||
@pytest.fixture
|
||||
async def setup_legacy_platform(hass: HomeAssistant) -> None:
|
||||
"""Set up legacy demo platform."""
|
||||
assert await async_setup_component(hass, stt.DOMAIN, {"stt": {"platform": "demo"}})
|
||||
await hass.async_block_till_done()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def setup_config_entry(hass: HomeAssistant) -> None:
|
||||
"""Set up demo component from config entry."""
|
||||
config_entry = MockConfigEntry(domain=DEMO_DOMAIN)
|
||||
config_entry.add_to_hass(hass)
|
||||
assert await hass.config_entries.async_setup(config_entry.entry_id)
|
||||
await hass.async_block_till_done()
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("setup_legacy_platform")
|
||||
async def test_demo_settings(hass_client: ClientSessionGenerator) -> None:
|
||||
"""Test retrieve settings from demo provider."""
|
||||
client = await hass_client()
|
||||
@ -34,6 +47,7 @@ async def test_demo_settings(hass_client: ClientSessionGenerator) -> None:
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("setup_legacy_platform")
|
||||
async def test_demo_speech_no_metadata(hass_client: ClientSessionGenerator) -> None:
|
||||
"""Test retrieve settings from demo provider."""
|
||||
client = await hass_client()
|
||||
@ -42,6 +56,7 @@ async def test_demo_speech_no_metadata(hass_client: ClientSessionGenerator) -> N
|
||||
assert response.status == HTTPStatus.BAD_REQUEST
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("setup_legacy_platform")
|
||||
async def test_demo_speech_wrong_metadata(hass_client: ClientSessionGenerator) -> None:
|
||||
"""Test retrieve settings from demo provider."""
|
||||
client = await hass_client()
|
||||
@ -59,6 +74,7 @@ async def test_demo_speech_wrong_metadata(hass_client: ClientSessionGenerator) -
|
||||
assert response.status == HTTPStatus.UNSUPPORTED_MEDIA_TYPE
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("setup_legacy_platform")
|
||||
async def test_demo_speech(hass_client: ClientSessionGenerator) -> None:
|
||||
"""Test retrieve settings from demo provider."""
|
||||
client = await hass_client()
|
||||
@ -77,3 +93,26 @@ async def test_demo_speech(hass_client: ClientSessionGenerator) -> None:
|
||||
|
||||
assert response.status == HTTPStatus.OK
|
||||
assert response_data == {"text": "Turn the Kitchen Lights on", "result": "success"}
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("setup_config_entry")
|
||||
async def test_config_entry_demo_speech(
|
||||
hass_client: ClientSessionGenerator, hass: HomeAssistant
|
||||
) -> None:
|
||||
"""Test retrieve settings from demo provider from config entry."""
|
||||
client = await hass_client()
|
||||
|
||||
response = await client.post(
|
||||
"/api/stt/demo",
|
||||
headers={
|
||||
"X-Speech-Content": (
|
||||
"format=wav; codec=pcm; sample_rate=16000; bit_rate=16; channel=2;"
|
||||
" language=de"
|
||||
)
|
||||
},
|
||||
data=b"Test",
|
||||
)
|
||||
response_data = await response.json()
|
||||
|
||||
assert response.status == HTTPStatus.OK
|
||||
assert response_data == {"text": "Turn the Kitchen Lights on", "result": "success"}
|
||||
|
@ -6,7 +6,9 @@ from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from homeassistant.components.stt import Provider
|
||||
from homeassistant.config_entries import ConfigEntry
|
||||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.helpers.entity_platform import AddEntitiesCallback
|
||||
from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType
|
||||
|
||||
from tests.common import MockPlatform, mock_platform
|
||||
@ -54,3 +56,19 @@ def mock_stt_platform(
|
||||
mock_platform(hass, f"{integration}.stt", loaded_platform)
|
||||
|
||||
return loaded_platform
|
||||
|
||||
|
||||
def mock_stt_entity_platform(
|
||||
hass: HomeAssistant,
|
||||
tmp_path: Path,
|
||||
integration: str,
|
||||
async_setup_entry: Callable[
|
||||
[HomeAssistant, ConfigEntry, AddEntitiesCallback],
|
||||
Coroutine[Any, Any, None],
|
||||
]
|
||||
| None = None,
|
||||
) -> MockPlatform:
|
||||
"""Specialize the mock platform for stt."""
|
||||
loaded_platform = MockPlatform(async_setup_entry=async_setup_entry)
|
||||
mock_platform(hass, f"{integration}.stt", loaded_platform)
|
||||
return loaded_platform
|
||||
|
@ -1,5 +1,5 @@
|
||||
"""Test STT component setup."""
|
||||
from collections.abc import AsyncIterable
|
||||
from collections.abc import AsyncIterable, Generator
|
||||
from http import HTTPStatus
|
||||
from pathlib import Path
|
||||
from unittest.mock import AsyncMock
|
||||
@ -7,6 +7,7 @@ from unittest.mock import AsyncMock
|
||||
import pytest
|
||||
|
||||
from homeassistant.components.stt import (
|
||||
DOMAIN,
|
||||
AudioBitRates,
|
||||
AudioChannels,
|
||||
AudioCodecs,
|
||||
@ -16,17 +17,30 @@ from homeassistant.components.stt import (
|
||||
SpeechMetadata,
|
||||
SpeechResult,
|
||||
SpeechResultState,
|
||||
SpeechToTextEntity,
|
||||
async_get_provider,
|
||||
)
|
||||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.config_entries import ConfigEntry, ConfigEntryState, ConfigFlow
|
||||
from homeassistant.core import HomeAssistant, State
|
||||
from homeassistant.helpers.entity_platform import AddEntitiesCallback
|
||||
from homeassistant.setup import async_setup_component
|
||||
|
||||
from .common import mock_stt_platform
|
||||
from .common import mock_stt_entity_platform, mock_stt_platform
|
||||
|
||||
from tests.common import (
|
||||
MockConfigEntry,
|
||||
MockModule,
|
||||
mock_config_flow,
|
||||
mock_integration,
|
||||
mock_platform,
|
||||
mock_restore_cache,
|
||||
)
|
||||
from tests.typing import ClientSessionGenerator
|
||||
|
||||
TEST_DOMAIN = "test"
|
||||
|
||||
class MockProvider(Provider):
|
||||
|
||||
class BaseProvider:
|
||||
"""Mock provider."""
|
||||
|
||||
fail_process_audio = False
|
||||
@ -73,7 +87,15 @@ class MockProvider(Provider):
|
||||
if self.fail_process_audio:
|
||||
return SpeechResult(None, SpeechResultState.ERROR)
|
||||
|
||||
return SpeechResult("test", SpeechResultState.SUCCESS)
|
||||
return SpeechResult("test_result", SpeechResultState.SUCCESS)
|
||||
|
||||
|
||||
class MockProvider(BaseProvider, Provider):
|
||||
"""Mock provider."""
|
||||
|
||||
|
||||
class MockProviderEntity(BaseProvider, SpeechToTextEntity):
|
||||
"""Mock provider entity."""
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@ -82,26 +104,113 @@ def mock_provider() -> MockProvider:
|
||||
return MockProvider()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_provider_entity() -> MockProviderEntity:
|
||||
"""Test provider entity fixture."""
|
||||
return MockProviderEntity()
|
||||
|
||||
|
||||
class STTFlow(ConfigFlow):
|
||||
"""Test flow."""
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def config_flow_fixture(hass: HomeAssistant) -> Generator[None, None, None]:
|
||||
"""Mock config flow."""
|
||||
mock_platform(hass, f"{TEST_DOMAIN}.config_flow")
|
||||
|
||||
with mock_config_flow(TEST_DOMAIN, STTFlow):
|
||||
yield
|
||||
|
||||
|
||||
@pytest.fixture(name="setup")
|
||||
async def setup_fixture(
|
||||
hass: HomeAssistant,
|
||||
tmp_path: Path,
|
||||
request: pytest.FixtureRequest,
|
||||
) -> None:
|
||||
"""Set up the test environment."""
|
||||
if request.param == "mock_setup":
|
||||
await mock_setup(hass, tmp_path, MockProvider())
|
||||
elif request.param == "mock_config_entry_setup":
|
||||
await mock_config_entry_setup(hass, tmp_path, MockProviderEntity())
|
||||
else:
|
||||
raise RuntimeError("Invalid setup fixture")
|
||||
|
||||
|
||||
async def mock_setup(
|
||||
hass: HomeAssistant, tmp_path: Path, mock_provider: MockProvider
|
||||
hass: HomeAssistant,
|
||||
tmp_path: Path,
|
||||
mock_provider: MockProvider,
|
||||
) -> None:
|
||||
"""Set up a test provider."""
|
||||
mock_stt_platform(
|
||||
hass,
|
||||
tmp_path,
|
||||
"test",
|
||||
TEST_DOMAIN,
|
||||
async_get_engine=AsyncMock(return_value=mock_provider),
|
||||
)
|
||||
assert await async_setup_component(hass, "stt", {"stt": {"platform": "test"}})
|
||||
assert await async_setup_component(hass, "stt", {"stt": {"platform": TEST_DOMAIN}})
|
||||
await hass.async_block_till_done()
|
||||
|
||||
|
||||
async def mock_config_entry_setup(
|
||||
hass: HomeAssistant, tmp_path: Path, mock_provider_entity: MockProviderEntity
|
||||
) -> MockConfigEntry:
|
||||
"""Set up a test provider via config entry."""
|
||||
|
||||
async def async_setup_entry_init(
|
||||
hass: HomeAssistant, config_entry: ConfigEntry
|
||||
) -> bool:
|
||||
"""Set up test config entry."""
|
||||
await hass.config_entries.async_forward_entry_setup(config_entry, DOMAIN)
|
||||
return True
|
||||
|
||||
async def async_unload_entry_init(
|
||||
hass: HomeAssistant, config_entry: ConfigEntry
|
||||
) -> bool:
|
||||
"""Unload up test config entry."""
|
||||
await hass.config_entries.async_forward_entry_unload(config_entry, DOMAIN)
|
||||
return True
|
||||
|
||||
mock_integration(
|
||||
hass,
|
||||
MockModule(
|
||||
TEST_DOMAIN,
|
||||
async_setup_entry=async_setup_entry_init,
|
||||
async_unload_entry=async_unload_entry_init,
|
||||
),
|
||||
)
|
||||
|
||||
async def async_setup_entry_platform(
|
||||
hass: HomeAssistant,
|
||||
config_entry: ConfigEntry,
|
||||
async_add_entities: AddEntitiesCallback,
|
||||
) -> None:
|
||||
"""Set up test stt platform via config entry."""
|
||||
async_add_entities([mock_provider_entity])
|
||||
|
||||
mock_stt_entity_platform(hass, tmp_path, TEST_DOMAIN, async_setup_entry_platform)
|
||||
|
||||
config_entry = MockConfigEntry(domain=TEST_DOMAIN)
|
||||
config_entry.add_to_hass(hass)
|
||||
assert await hass.config_entries.async_setup(config_entry.entry_id)
|
||||
await hass.async_block_till_done()
|
||||
|
||||
return config_entry
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"setup", ["mock_setup", "mock_config_entry_setup"], indirect=True
|
||||
)
|
||||
async def test_get_provider_info(
|
||||
hass: HomeAssistant, hass_client: ClientSessionGenerator
|
||||
hass: HomeAssistant,
|
||||
hass_client: ClientSessionGenerator,
|
||||
setup: str,
|
||||
) -> None:
|
||||
"""Test engine that doesn't exist."""
|
||||
client = await hass_client()
|
||||
response = await client.get("/api/stt/test")
|
||||
response = await client.get(f"/api/stt/{TEST_DOMAIN}")
|
||||
assert response.status == HTTPStatus.OK
|
||||
assert await response.json() == {
|
||||
"languages": ["en"],
|
||||
@ -113,22 +222,44 @@ async def test_get_provider_info(
|
||||
}
|
||||
|
||||
|
||||
async def test_get_non_existing_provider_info(
|
||||
hass: HomeAssistant, hass_client: ClientSessionGenerator
|
||||
@pytest.mark.parametrize(
|
||||
"setup", ["mock_setup", "mock_config_entry_setup"], indirect=True
|
||||
)
|
||||
async def test_non_existing_provider(
|
||||
hass: HomeAssistant,
|
||||
hass_client: ClientSessionGenerator,
|
||||
setup: str,
|
||||
) -> None:
|
||||
"""Test streaming to engine that doesn't exist."""
|
||||
client = await hass_client()
|
||||
|
||||
response = await client.get("/api/stt/not_exist")
|
||||
assert response.status == HTTPStatus.NOT_FOUND
|
||||
|
||||
response = await client.post(
|
||||
"/api/stt/not_exist",
|
||||
headers={
|
||||
"X-Speech-Content": (
|
||||
"format=wav; codec=pcm; sample_rate=16000; bit_rate=16; channel=1;"
|
||||
" language=en"
|
||||
)
|
||||
},
|
||||
)
|
||||
assert response.status == HTTPStatus.NOT_FOUND
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"setup", ["mock_setup", "mock_config_entry_setup"], indirect=True
|
||||
)
|
||||
async def test_stream_audio(
|
||||
hass: HomeAssistant, hass_client: ClientSessionGenerator
|
||||
hass: HomeAssistant,
|
||||
hass_client: ClientSessionGenerator,
|
||||
setup: str,
|
||||
) -> None:
|
||||
"""Test streaming audio and getting response."""
|
||||
client = await hass_client()
|
||||
response = await client.post(
|
||||
"/api/stt/test",
|
||||
f"/api/stt/{TEST_DOMAIN}",
|
||||
headers={
|
||||
"X-Speech-Content": (
|
||||
"format=wav; codec=pcm; sample_rate=16000; bit_rate=16; channel=1;"
|
||||
@ -137,20 +268,39 @@ async def test_stream_audio(
|
||||
},
|
||||
)
|
||||
assert response.status == HTTPStatus.OK
|
||||
assert await response.json() == {"text": "test", "result": "success"}
|
||||
assert await response.json() == {"text": "test_result", "result": "success"}
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"setup", ["mock_setup", "mock_config_entry_setup"], indirect=True
|
||||
)
|
||||
@pytest.mark.parametrize(
|
||||
("header", "status", "error"),
|
||||
(
|
||||
(None, 400, "Missing X-Speech-Content header"),
|
||||
(
|
||||
(
|
||||
"format=wav; codec=pcm; sample_rate=16000; bit_rate=16; channel=100;"
|
||||
" language=en; unknown=1"
|
||||
),
|
||||
400,
|
||||
"Invalid field: unknown",
|
||||
),
|
||||
(
|
||||
(
|
||||
"format=wav; codec=pcm; sample_rate=16000; bit_rate=16; channel=100;"
|
||||
" language=en"
|
||||
),
|
||||
400,
|
||||
"100 is not a valid AudioChannels",
|
||||
"Wrong format of X-Speech-Content: 100 is not a valid AudioChannels",
|
||||
),
|
||||
(
|
||||
(
|
||||
"format=wav; codec=pcm; sample_rate=16000; bit_rate=16; channel=bad channel;"
|
||||
" language=en"
|
||||
),
|
||||
400,
|
||||
"Wrong format of X-Speech-Content: invalid literal for int() with base 10: 'bad channel'",
|
||||
),
|
||||
(
|
||||
"format=wav; codec=pcm; sample_rate=16000",
|
||||
@ -165,6 +315,7 @@ async def test_metadata_errors(
|
||||
header: str | None,
|
||||
status: int,
|
||||
error: str,
|
||||
setup: str,
|
||||
) -> None:
|
||||
"""Test metadata errors."""
|
||||
client = await hass_client()
|
||||
@ -172,11 +323,55 @@ async def test_metadata_errors(
|
||||
if header:
|
||||
headers["X-Speech-Content"] = header
|
||||
|
||||
response = await client.post("/api/stt/test", headers=headers)
|
||||
response = await client.post(f"/api/stt/{TEST_DOMAIN}", headers=headers)
|
||||
assert response.status == status
|
||||
assert await response.text() == error
|
||||
|
||||
|
||||
async def test_get_provider(hass: HomeAssistant, mock_provider: MockProvider) -> None:
|
||||
async def test_get_provider(
|
||||
hass: HomeAssistant,
|
||||
tmp_path: Path,
|
||||
mock_provider: MockProvider,
|
||||
) -> None:
|
||||
"""Test we can get STT providers."""
|
||||
assert mock_provider == async_get_provider(hass, "test")
|
||||
await mock_setup(hass, tmp_path, mock_provider)
|
||||
assert mock_provider == async_get_provider(hass, TEST_DOMAIN)
|
||||
|
||||
|
||||
async def test_config_entry_unload(
|
||||
hass: HomeAssistant, tmp_path: Path, mock_provider_entity: MockProviderEntity
|
||||
) -> None:
|
||||
"""Test we can unload config entry."""
|
||||
config_entry = await mock_config_entry_setup(hass, tmp_path, mock_provider_entity)
|
||||
assert config_entry.state == ConfigEntryState.LOADED
|
||||
await hass.config_entries.async_unload(config_entry.entry_id)
|
||||
assert config_entry.state == ConfigEntryState.NOT_LOADED
|
||||
|
||||
|
||||
def test_entity_name_raises_before_addition(
|
||||
hass: HomeAssistant,
|
||||
tmp_path: Path,
|
||||
mock_provider_entity: MockProviderEntity,
|
||||
) -> None:
|
||||
"""Test entity name raises before addition to Home Assistant."""
|
||||
with pytest.raises(RuntimeError):
|
||||
mock_provider_entity.name # pylint: disable=pointless-statement
|
||||
|
||||
|
||||
async def test_restore_state(
|
||||
hass: HomeAssistant,
|
||||
tmp_path: Path,
|
||||
mock_provider_entity: MockProviderEntity,
|
||||
) -> None:
|
||||
"""Test we restore state in the integration."""
|
||||
entity_id = f"{DOMAIN}.{TEST_DOMAIN}"
|
||||
timestamp = "2023-01-01T23:59:59+00:00"
|
||||
mock_restore_cache(hass, (State(entity_id, timestamp),))
|
||||
|
||||
config_entry = await mock_config_entry_setup(hass, tmp_path, mock_provider_entity)
|
||||
await hass.async_block_till_done()
|
||||
|
||||
assert config_entry.state == ConfigEntryState.LOADED
|
||||
state = hass.states.get(entity_id)
|
||||
assert state
|
||||
assert state.state == timestamp
|
||||
|
Loading…
x
Reference in New Issue
Block a user