mirror of
https://github.com/home-assistant/core.git
synced 2025-07-25 06:07:17 +00:00
Move legacy stt (#90776)
* Move legacy stt to separate module * Remove case for None as provider * Add error log for unknown platform * Add some tests
This commit is contained in:
parent
584066b809
commit
535fb34207
@ -1,11 +1,8 @@
|
|||||||
"""Provide functionality to STT."""
|
"""Provide functionality to STT."""
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from abc import ABC, abstractmethod
|
|
||||||
import asyncio
|
import asyncio
|
||||||
from collections.abc import AsyncIterable
|
from dataclasses import asdict
|
||||||
from dataclasses import asdict, dataclass
|
|
||||||
import logging
|
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from aiohttp import web
|
from aiohttp import web
|
||||||
@ -17,10 +14,8 @@ from aiohttp.web_exceptions import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
from homeassistant.components.http import HomeAssistantView
|
from homeassistant.components.http import HomeAssistantView
|
||||||
from homeassistant.core import HomeAssistant, callback
|
from homeassistant.core import HomeAssistant
|
||||||
from homeassistant.helpers import config_per_platform, discovery
|
|
||||||
from homeassistant.helpers.typing import ConfigType
|
from homeassistant.helpers.typing import ConfigType
|
||||||
from homeassistant.setup import async_prepare_setup_platform
|
|
||||||
|
|
||||||
from .const import (
|
from .const import (
|
||||||
DOMAIN,
|
DOMAIN,
|
||||||
@ -31,159 +26,40 @@ from .const import (
|
|||||||
AudioSampleRates,
|
AudioSampleRates,
|
||||||
SpeechResultState,
|
SpeechResultState,
|
||||||
)
|
)
|
||||||
|
from .legacy import (
|
||||||
|
Provider,
|
||||||
|
SpeechMetadata,
|
||||||
|
SpeechResult,
|
||||||
|
async_get_provider,
|
||||||
|
async_setup_legacy,
|
||||||
|
)
|
||||||
|
|
||||||
_LOGGER = logging.getLogger(__name__)
|
__all__ = [
|
||||||
|
"async_get_provider",
|
||||||
|
"AudioBitRates",
|
||||||
@callback
|
"AudioChannels",
|
||||||
def async_get_provider(
|
"AudioCodecs",
|
||||||
hass: HomeAssistant, domain: str | None = None
|
"AudioFormats",
|
||||||
) -> Provider | None:
|
"AudioSampleRates",
|
||||||
"""Return provider."""
|
"DOMAIN",
|
||||||
if domain:
|
"Provider",
|
||||||
return hass.data[DOMAIN].get(domain)
|
"SpeechMetadata",
|
||||||
|
"SpeechResult",
|
||||||
if not hass.data[DOMAIN]:
|
"SpeechResultState",
|
||||||
return None
|
]
|
||||||
|
|
||||||
if "cloud" in hass.data[DOMAIN]:
|
|
||||||
return hass.data[DOMAIN]["cloud"]
|
|
||||||
|
|
||||||
return next(iter(hass.data[DOMAIN].values()))
|
|
||||||
|
|
||||||
|
|
||||||
async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
|
async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
|
||||||
"""Set up STT."""
|
"""Set up STT."""
|
||||||
providers = hass.data[DOMAIN] = {}
|
platform_setups = async_setup_legacy(hass, config)
|
||||||
|
|
||||||
async def async_setup_platform(p_type, p_config=None, discovery_info=None):
|
if platform_setups:
|
||||||
"""Set up a TTS platform."""
|
await asyncio.wait([asyncio.create_task(setup) for setup in platform_setups])
|
||||||
if p_config is None:
|
|
||||||
p_config = {}
|
|
||||||
|
|
||||||
platform = await async_prepare_setup_platform(hass, config, DOMAIN, p_type)
|
hass.http.register_view(SpeechToTextView(hass.data[DOMAIN]))
|
||||||
if platform is None:
|
|
||||||
return
|
|
||||||
|
|
||||||
try:
|
|
||||||
provider = await platform.async_get_engine(hass, p_config, discovery_info)
|
|
||||||
if provider is None:
|
|
||||||
_LOGGER.error("Error setting up platform %s", p_type)
|
|
||||||
return
|
|
||||||
|
|
||||||
provider.name = p_type
|
|
||||||
provider.hass = hass
|
|
||||||
|
|
||||||
providers[provider.name] = provider
|
|
||||||
except Exception: # pylint: disable=broad-except
|
|
||||||
_LOGGER.exception("Error setting up platform: %s", p_type)
|
|
||||||
return
|
|
||||||
|
|
||||||
setup_tasks = [
|
|
||||||
asyncio.create_task(async_setup_platform(p_type, p_config))
|
|
||||||
for p_type, p_config in config_per_platform(config, DOMAIN)
|
|
||||||
]
|
|
||||||
|
|
||||||
if setup_tasks:
|
|
||||||
await asyncio.wait(setup_tasks)
|
|
||||||
|
|
||||||
# Add discovery support
|
|
||||||
async def async_platform_discovered(platform, info):
|
|
||||||
"""Handle for discovered platform."""
|
|
||||||
await async_setup_platform(platform, discovery_info=info)
|
|
||||||
|
|
||||||
discovery.async_listen_platform(hass, DOMAIN, async_platform_discovered)
|
|
||||||
|
|
||||||
hass.http.register_view(SpeechToTextView(providers))
|
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class SpeechMetadata:
|
|
||||||
"""Metadata of audio stream."""
|
|
||||||
|
|
||||||
language: str
|
|
||||||
format: AudioFormats
|
|
||||||
codec: AudioCodecs
|
|
||||||
bit_rate: AudioBitRates
|
|
||||||
sample_rate: AudioSampleRates
|
|
||||||
channel: AudioChannels
|
|
||||||
|
|
||||||
def __post_init__(self) -> None:
|
|
||||||
"""Finish initializing the metadata."""
|
|
||||||
self.bit_rate = AudioBitRates(int(self.bit_rate))
|
|
||||||
self.sample_rate = AudioSampleRates(int(self.sample_rate))
|
|
||||||
self.channel = AudioChannels(int(self.channel))
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class SpeechResult:
|
|
||||||
"""Result of audio Speech."""
|
|
||||||
|
|
||||||
text: str | None
|
|
||||||
result: SpeechResultState
|
|
||||||
|
|
||||||
|
|
||||||
class Provider(ABC):
|
|
||||||
"""Represent a single STT provider."""
|
|
||||||
|
|
||||||
hass: HomeAssistant | None = None
|
|
||||||
name: str | None = None
|
|
||||||
|
|
||||||
@property
|
|
||||||
@abstractmethod
|
|
||||||
def supported_languages(self) -> list[str]:
|
|
||||||
"""Return a list of supported languages."""
|
|
||||||
|
|
||||||
@property
|
|
||||||
@abstractmethod
|
|
||||||
def supported_formats(self) -> list[AudioFormats]:
|
|
||||||
"""Return a list of supported formats."""
|
|
||||||
|
|
||||||
@property
|
|
||||||
@abstractmethod
|
|
||||||
def supported_codecs(self) -> list[AudioCodecs]:
|
|
||||||
"""Return a list of supported codecs."""
|
|
||||||
|
|
||||||
@property
|
|
||||||
@abstractmethod
|
|
||||||
def supported_bit_rates(self) -> list[AudioBitRates]:
|
|
||||||
"""Return a list of supported bit rates."""
|
|
||||||
|
|
||||||
@property
|
|
||||||
@abstractmethod
|
|
||||||
def supported_sample_rates(self) -> list[AudioSampleRates]:
|
|
||||||
"""Return a list of supported sample rates."""
|
|
||||||
|
|
||||||
@property
|
|
||||||
@abstractmethod
|
|
||||||
def supported_channels(self) -> list[AudioChannels]:
|
|
||||||
"""Return a list of supported channels."""
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
async def async_process_audio_stream(
|
|
||||||
self, metadata: SpeechMetadata, stream: AsyncIterable[bytes]
|
|
||||||
) -> SpeechResult:
|
|
||||||
"""Process an audio stream to STT service.
|
|
||||||
|
|
||||||
Only streaming of content are allow!
|
|
||||||
"""
|
|
||||||
|
|
||||||
@callback
|
|
||||||
def check_metadata(self, metadata: SpeechMetadata) -> bool:
|
|
||||||
"""Check if given metadata supported by this provider."""
|
|
||||||
if (
|
|
||||||
metadata.language not in self.supported_languages
|
|
||||||
or metadata.format not in self.supported_formats
|
|
||||||
or metadata.codec not in self.supported_codecs
|
|
||||||
or metadata.bit_rate not in self.supported_bit_rates
|
|
||||||
or metadata.sample_rate not in self.supported_sample_rates
|
|
||||||
or metadata.channel not in self.supported_channels
|
|
||||||
):
|
|
||||||
return False
|
|
||||||
return True
|
|
||||||
|
|
||||||
|
|
||||||
class SpeechToTextView(HomeAssistantView):
|
class SpeechToTextView(HomeAssistantView):
|
||||||
"""STT view to generate a text from audio stream."""
|
"""STT view to generate a text from audio stream."""
|
||||||
|
|
||||||
@ -203,7 +79,7 @@ class SpeechToTextView(HomeAssistantView):
|
|||||||
|
|
||||||
# Get metadata
|
# Get metadata
|
||||||
try:
|
try:
|
||||||
metadata = metadata_from_header(request)
|
metadata = _metadata_from_header(request)
|
||||||
except ValueError as err:
|
except ValueError as err:
|
||||||
raise HTTPBadRequest(text=str(err)) from err
|
raise HTTPBadRequest(text=str(err)) from err
|
||||||
|
|
||||||
@ -237,7 +113,7 @@ class SpeechToTextView(HomeAssistantView):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def metadata_from_header(request: web.Request) -> SpeechMetadata:
|
def _metadata_from_header(request: web.Request) -> SpeechMetadata:
|
||||||
"""Extract STT metadata from header.
|
"""Extract STT metadata from header.
|
||||||
|
|
||||||
X-Speech-Content:
|
X-Speech-Content:
|
||||||
|
169
homeassistant/components/stt/legacy.py
Normal file
169
homeassistant/components/stt/legacy.py
Normal file
@ -0,0 +1,169 @@
|
|||||||
|
"""Handle legacy speech to text platforms."""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from collections.abc import AsyncIterable, Coroutine
|
||||||
|
from dataclasses import dataclass
|
||||||
|
import logging
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from homeassistant.core import HomeAssistant, callback
|
||||||
|
from homeassistant.helpers import config_per_platform, discovery
|
||||||
|
from homeassistant.helpers.typing import ConfigType
|
||||||
|
from homeassistant.setup import async_prepare_setup_platform
|
||||||
|
|
||||||
|
from .const import (
|
||||||
|
DOMAIN,
|
||||||
|
AudioBitRates,
|
||||||
|
AudioChannels,
|
||||||
|
AudioCodecs,
|
||||||
|
AudioFormats,
|
||||||
|
AudioSampleRates,
|
||||||
|
SpeechResultState,
|
||||||
|
)
|
||||||
|
|
||||||
|
_LOGGER = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@callback
|
||||||
|
def async_get_provider(
|
||||||
|
hass: HomeAssistant, domain: str | None = None
|
||||||
|
) -> Provider | None:
|
||||||
|
"""Return provider."""
|
||||||
|
if domain:
|
||||||
|
return hass.data[DOMAIN].get(domain)
|
||||||
|
|
||||||
|
if not hass.data[DOMAIN]:
|
||||||
|
return None
|
||||||
|
|
||||||
|
if "cloud" in hass.data[DOMAIN]:
|
||||||
|
return hass.data[DOMAIN]["cloud"]
|
||||||
|
|
||||||
|
return next(iter(hass.data[DOMAIN].values()))
|
||||||
|
|
||||||
|
|
||||||
|
@callback
|
||||||
|
def async_setup_legacy(
|
||||||
|
hass: HomeAssistant, config: ConfigType
|
||||||
|
) -> list[Coroutine[Any, Any, None]]:
|
||||||
|
"""Set up legacy speech to text providers."""
|
||||||
|
providers = hass.data[DOMAIN] = {}
|
||||||
|
|
||||||
|
async def async_setup_platform(p_type, p_config=None, discovery_info=None):
|
||||||
|
"""Set up a TTS platform."""
|
||||||
|
if p_config is None:
|
||||||
|
p_config = {}
|
||||||
|
|
||||||
|
platform = await async_prepare_setup_platform(hass, config, DOMAIN, p_type)
|
||||||
|
if platform is None:
|
||||||
|
_LOGGER.error("Unknown speech to text platform specified")
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
provider = await platform.async_get_engine(hass, p_config, discovery_info)
|
||||||
|
|
||||||
|
provider.name = p_type
|
||||||
|
provider.hass = hass
|
||||||
|
|
||||||
|
providers[provider.name] = provider
|
||||||
|
except Exception: # pylint: disable=broad-except
|
||||||
|
_LOGGER.exception("Error setting up platform: %s", p_type)
|
||||||
|
return
|
||||||
|
|
||||||
|
# Add discovery support
|
||||||
|
async def async_platform_discovered(platform, info):
|
||||||
|
"""Handle for discovered platform."""
|
||||||
|
await async_setup_platform(platform, discovery_info=info)
|
||||||
|
|
||||||
|
discovery.async_listen_platform(hass, DOMAIN, async_platform_discovered)
|
||||||
|
|
||||||
|
return [
|
||||||
|
async_setup_platform(p_type, p_config)
|
||||||
|
for p_type, p_config in config_per_platform(config, DOMAIN)
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class SpeechMetadata:
|
||||||
|
"""Metadata of audio stream."""
|
||||||
|
|
||||||
|
language: str
|
||||||
|
format: AudioFormats
|
||||||
|
codec: AudioCodecs
|
||||||
|
bit_rate: AudioBitRates
|
||||||
|
sample_rate: AudioSampleRates
|
||||||
|
channel: AudioChannels
|
||||||
|
|
||||||
|
def __post_init__(self) -> None:
|
||||||
|
"""Finish initializing the metadata."""
|
||||||
|
self.bit_rate = AudioBitRates(int(self.bit_rate))
|
||||||
|
self.sample_rate = AudioSampleRates(int(self.sample_rate))
|
||||||
|
self.channel = AudioChannels(int(self.channel))
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class SpeechResult:
|
||||||
|
"""Result of audio Speech."""
|
||||||
|
|
||||||
|
text: str | None
|
||||||
|
result: SpeechResultState
|
||||||
|
|
||||||
|
|
||||||
|
class Provider(ABC):
|
||||||
|
"""Represent a single STT provider."""
|
||||||
|
|
||||||
|
hass: HomeAssistant | None = None
|
||||||
|
name: str | None = None
|
||||||
|
|
||||||
|
@property
|
||||||
|
@abstractmethod
|
||||||
|
def supported_languages(self) -> list[str]:
|
||||||
|
"""Return a list of supported languages."""
|
||||||
|
|
||||||
|
@property
|
||||||
|
@abstractmethod
|
||||||
|
def supported_formats(self) -> list[AudioFormats]:
|
||||||
|
"""Return a list of supported formats."""
|
||||||
|
|
||||||
|
@property
|
||||||
|
@abstractmethod
|
||||||
|
def supported_codecs(self) -> list[AudioCodecs]:
|
||||||
|
"""Return a list of supported codecs."""
|
||||||
|
|
||||||
|
@property
|
||||||
|
@abstractmethod
|
||||||
|
def supported_bit_rates(self) -> list[AudioBitRates]:
|
||||||
|
"""Return a list of supported bit rates."""
|
||||||
|
|
||||||
|
@property
|
||||||
|
@abstractmethod
|
||||||
|
def supported_sample_rates(self) -> list[AudioSampleRates]:
|
||||||
|
"""Return a list of supported sample rates."""
|
||||||
|
|
||||||
|
@property
|
||||||
|
@abstractmethod
|
||||||
|
def supported_channels(self) -> list[AudioChannels]:
|
||||||
|
"""Return a list of supported channels."""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def async_process_audio_stream(
|
||||||
|
self, metadata: SpeechMetadata, stream: AsyncIterable[bytes]
|
||||||
|
) -> SpeechResult:
|
||||||
|
"""Process an audio stream to STT service.
|
||||||
|
|
||||||
|
Only streaming of content are allow!
|
||||||
|
"""
|
||||||
|
|
||||||
|
@callback
|
||||||
|
def check_metadata(self, metadata: SpeechMetadata) -> bool:
|
||||||
|
"""Check if given metadata supported by this provider."""
|
||||||
|
if (
|
||||||
|
metadata.language not in self.supported_languages
|
||||||
|
or metadata.format not in self.supported_formats
|
||||||
|
or metadata.codec not in self.supported_codecs
|
||||||
|
or metadata.bit_rate not in self.supported_bit_rates
|
||||||
|
or metadata.sample_rate not in self.supported_sample_rates
|
||||||
|
or metadata.channel not in self.supported_channels
|
||||||
|
):
|
||||||
|
return False
|
||||||
|
return True
|
56
tests/components/stt/common.py
Normal file
56
tests/components/stt/common.py
Normal file
@ -0,0 +1,56 @@
|
|||||||
|
"""Provide common test tools for STT."""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from collections.abc import Callable, Coroutine
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from homeassistant.components.stt import Provider
|
||||||
|
from homeassistant.core import HomeAssistant
|
||||||
|
from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType
|
||||||
|
|
||||||
|
from tests.common import MockPlatform, mock_platform
|
||||||
|
|
||||||
|
|
||||||
|
class MockSTTPlatform(MockPlatform):
|
||||||
|
"""Help to set up test stt service."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
async_get_engine: Callable[
|
||||||
|
[HomeAssistant, ConfigType, DiscoveryInfoType | None],
|
||||||
|
Coroutine[Any, Any, Provider | None],
|
||||||
|
]
|
||||||
|
| None = None,
|
||||||
|
get_engine: Callable[
|
||||||
|
[HomeAssistant, ConfigType, DiscoveryInfoType | None], Provider | None
|
||||||
|
]
|
||||||
|
| None = None,
|
||||||
|
) -> None:
|
||||||
|
"""Return the stt service."""
|
||||||
|
super().__init__()
|
||||||
|
if get_engine:
|
||||||
|
self.get_engine = get_engine
|
||||||
|
if async_get_engine:
|
||||||
|
self.async_get_engine = async_get_engine
|
||||||
|
|
||||||
|
|
||||||
|
def mock_stt_platform(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
tmp_path: Path,
|
||||||
|
integration: str = "stt",
|
||||||
|
async_get_engine: Callable[
|
||||||
|
[HomeAssistant, ConfigType, DiscoveryInfoType | None],
|
||||||
|
Coroutine[Any, Any, Provider | None],
|
||||||
|
]
|
||||||
|
| None = None,
|
||||||
|
get_engine: Callable[
|
||||||
|
[HomeAssistant, ConfigType, DiscoveryInfoType | None], Provider | None
|
||||||
|
]
|
||||||
|
| None = None,
|
||||||
|
):
|
||||||
|
"""Specialize the mock platform for stt."""
|
||||||
|
loaded_platform = MockSTTPlatform(async_get_engine, get_engine)
|
||||||
|
mock_platform(hass, f"{integration}.stt", loaded_platform)
|
||||||
|
|
||||||
|
return loaded_platform
|
@ -1,7 +1,8 @@
|
|||||||
"""Test STT component setup."""
|
"""Test STT component setup."""
|
||||||
from collections.abc import AsyncIterable
|
from collections.abc import AsyncIterable
|
||||||
from http import HTTPStatus
|
from http import HTTPStatus
|
||||||
from unittest.mock import AsyncMock, Mock
|
from pathlib import Path
|
||||||
|
from unittest.mock import AsyncMock
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
@ -20,7 +21,8 @@ from homeassistant.components.stt import (
|
|||||||
from homeassistant.core import HomeAssistant
|
from homeassistant.core import HomeAssistant
|
||||||
from homeassistant.setup import async_setup_component
|
from homeassistant.setup import async_setup_component
|
||||||
|
|
||||||
from tests.common import mock_platform
|
from .common import mock_stt_platform
|
||||||
|
|
||||||
from tests.typing import ClientSessionGenerator
|
from tests.typing import ClientSessionGenerator
|
||||||
|
|
||||||
|
|
||||||
@ -31,7 +33,7 @@ class MockProvider(Provider):
|
|||||||
|
|
||||||
def __init__(self) -> None:
|
def __init__(self) -> None:
|
||||||
"""Init test provider."""
|
"""Init test provider."""
|
||||||
self.calls = []
|
self.calls: list[tuple[SpeechMetadata, AsyncIterable[bytes]]] = []
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def supported_languages(self) -> list[str]:
|
def supported_languages(self) -> list[str]:
|
||||||
@ -81,10 +83,15 @@ def mock_provider() -> MockProvider:
|
|||||||
|
|
||||||
|
|
||||||
@pytest.fixture(autouse=True)
|
@pytest.fixture(autouse=True)
|
||||||
async def mock_setup(hass: HomeAssistant, mock_provider: MockProvider) -> None:
|
async def mock_setup(
|
||||||
|
hass: HomeAssistant, tmp_path: Path, mock_provider: MockProvider
|
||||||
|
) -> None:
|
||||||
"""Set up a test provider."""
|
"""Set up a test provider."""
|
||||||
mock_platform(
|
mock_stt_platform(
|
||||||
hass, "test.stt", Mock(async_get_engine=AsyncMock(return_value=mock_provider))
|
hass,
|
||||||
|
tmp_path,
|
||||||
|
"test",
|
||||||
|
async_get_engine=AsyncMock(return_value=mock_provider),
|
||||||
)
|
)
|
||||||
assert await async_setup_component(hass, "stt", {"stt": {"platform": "test"}})
|
assert await async_setup_component(hass, "stt", {"stt": {"platform": "test"}})
|
||||||
|
|
||||||
|
56
tests/components/stt/test_legacy.py
Normal file
56
tests/components/stt/test_legacy.py
Normal file
@ -0,0 +1,56 @@
|
|||||||
|
"""Test the legacy stt setup."""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from homeassistant.components.stt import Provider
|
||||||
|
from homeassistant.core import HomeAssistant
|
||||||
|
from homeassistant.helpers.discovery import async_load_platform
|
||||||
|
from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType
|
||||||
|
|
||||||
|
from .common import mock_stt_platform
|
||||||
|
|
||||||
|
|
||||||
|
async def test_invalid_platform(
|
||||||
|
hass: HomeAssistant, caplog: pytest.LogCaptureFixture, tmp_path: Path
|
||||||
|
) -> None:
|
||||||
|
"""Test platform setup with an invalid platform."""
|
||||||
|
await async_load_platform(
|
||||||
|
hass,
|
||||||
|
"stt",
|
||||||
|
"bad_stt",
|
||||||
|
{"stt": [{"platform": "bad_stt"}]},
|
||||||
|
hass_config={"stt": [{"platform": "bad_stt"}]},
|
||||||
|
)
|
||||||
|
await hass.async_block_till_done()
|
||||||
|
|
||||||
|
assert "Unknown speech to text platform specified" in caplog.text
|
||||||
|
|
||||||
|
|
||||||
|
async def test_platform_setup_with_error(
|
||||||
|
hass: HomeAssistant, caplog: pytest.LogCaptureFixture, tmp_path: Path
|
||||||
|
) -> None:
|
||||||
|
"""Test platform setup with an error during setup."""
|
||||||
|
|
||||||
|
async def async_get_engine(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
config: ConfigType,
|
||||||
|
discovery_info: DiscoveryInfoType | None = None,
|
||||||
|
) -> Provider:
|
||||||
|
"""Raise exception during platform setup."""
|
||||||
|
raise Exception("Setup error") # pylint: disable=broad-exception-raised
|
||||||
|
|
||||||
|
mock_stt_platform(hass, tmp_path, "bad_stt", async_get_engine=async_get_engine)
|
||||||
|
|
||||||
|
await async_load_platform(
|
||||||
|
hass,
|
||||||
|
"stt",
|
||||||
|
"bad_stt",
|
||||||
|
{},
|
||||||
|
hass_config={"stt": [{"platform": "bad_stt"}]},
|
||||||
|
)
|
||||||
|
await hass.async_block_till_done()
|
||||||
|
|
||||||
|
assert "Error setting up platform: bad_stt" in caplog.text
|
Loading…
x
Reference in New Issue
Block a user