mirror of
https://github.com/home-assistant/core.git
synced 2025-04-26 18:27:51 +00:00
Add stt entity (#91230)
* Add stt entity * Update demo platform * Rename ProviderEntity to SpeechToTextEntity * Fix get method * Run all init tests for config entry setup * Fix and test metadata from header * Test config entry unload * Rename get provider entity * Test post for non existing provider * Test entity name before addition * Test restore state * Use register shutdown * Update deprecation comment
This commit is contained in:
parent
22a1a6846d
commit
473cbf7f9b
@ -36,6 +36,7 @@ COMPONENTS_WITH_CONFIG_ENTRY_DEMO_PLATFORM = [
|
|||||||
Platform.SELECT,
|
Platform.SELECT,
|
||||||
Platform.SENSOR,
|
Platform.SENSOR,
|
||||||
Platform.SIREN,
|
Platform.SIREN,
|
||||||
|
Platform.STT,
|
||||||
Platform.SWITCH,
|
Platform.SWITCH,
|
||||||
Platform.TEXT,
|
Platform.TEXT,
|
||||||
Platform.UPDATE,
|
Platform.UPDATE,
|
||||||
|
@ -13,8 +13,11 @@ from homeassistant.components.stt import (
|
|||||||
SpeechMetadata,
|
SpeechMetadata,
|
||||||
SpeechResult,
|
SpeechResult,
|
||||||
SpeechResultState,
|
SpeechResultState,
|
||||||
|
SpeechToTextEntity,
|
||||||
)
|
)
|
||||||
|
from homeassistant.config_entries import ConfigEntry
|
||||||
from homeassistant.core import HomeAssistant
|
from homeassistant.core import HomeAssistant
|
||||||
|
from homeassistant.helpers.entity_platform import AddEntitiesCallback
|
||||||
from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType
|
from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType
|
||||||
|
|
||||||
SUPPORT_LANGUAGES = ["en", "de"]
|
SUPPORT_LANGUAGES = ["en", "de"]
|
||||||
@ -29,6 +32,60 @@ async def async_get_engine(
|
|||||||
return DemoProvider()
|
return DemoProvider()
|
||||||
|
|
||||||
|
|
||||||
|
async def async_setup_entry(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
config_entry: ConfigEntry,
|
||||||
|
async_add_entities: AddEntitiesCallback,
|
||||||
|
) -> None:
|
||||||
|
"""Set up Demo speech platform via config entry."""
|
||||||
|
async_add_entities([DemoProviderEntity()])
|
||||||
|
|
||||||
|
|
||||||
|
class DemoProviderEntity(SpeechToTextEntity):
|
||||||
|
"""Demo speech API provider entity."""
|
||||||
|
|
||||||
|
@property
|
||||||
|
def supported_languages(self) -> list[str]:
|
||||||
|
"""Return a list of supported languages."""
|
||||||
|
return SUPPORT_LANGUAGES
|
||||||
|
|
||||||
|
@property
|
||||||
|
def supported_formats(self) -> list[AudioFormats]:
|
||||||
|
"""Return a list of supported formats."""
|
||||||
|
return [AudioFormats.WAV]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def supported_codecs(self) -> list[AudioCodecs]:
|
||||||
|
"""Return a list of supported codecs."""
|
||||||
|
return [AudioCodecs.PCM]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def supported_bit_rates(self) -> list[AudioBitRates]:
|
||||||
|
"""Return a list of supported bit rates."""
|
||||||
|
return [AudioBitRates.BITRATE_16]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def supported_sample_rates(self) -> list[AudioSampleRates]:
|
||||||
|
"""Return a list of supported sample rates."""
|
||||||
|
return [AudioSampleRates.SAMPLERATE_16000, AudioSampleRates.SAMPLERATE_44100]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def supported_channels(self) -> list[AudioChannels]:
|
||||||
|
"""Return a list of supported channels."""
|
||||||
|
return [AudioChannels.CHANNEL_STEREO]
|
||||||
|
|
||||||
|
async def async_process_audio_stream(
|
||||||
|
self, metadata: SpeechMetadata, stream: AsyncIterable[bytes]
|
||||||
|
) -> SpeechResult:
|
||||||
|
"""Process an audio stream to STT service."""
|
||||||
|
|
||||||
|
# Read available data
|
||||||
|
async for _ in stream:
|
||||||
|
pass
|
||||||
|
|
||||||
|
return SpeechResult("Turn the Kitchen Lights on", SpeechResultState.SUCCESS)
|
||||||
|
|
||||||
|
|
||||||
class DemoProvider(Provider):
|
class DemoProvider(Provider):
|
||||||
"""Demo speech API provider."""
|
"""Demo speech API provider."""
|
||||||
|
|
||||||
|
@ -1,9 +1,12 @@
|
|||||||
"""Provide functionality to STT."""
|
"""Provide functionality to STT."""
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from abc import abstractmethod
|
||||||
import asyncio
|
import asyncio
|
||||||
|
from collections.abc import AsyncIterable
|
||||||
from dataclasses import asdict
|
from dataclasses import asdict
|
||||||
from typing import Any
|
import logging
|
||||||
|
from typing import Any, final
|
||||||
|
|
||||||
from aiohttp import web
|
from aiohttp import web
|
||||||
from aiohttp.hdrs import istr
|
from aiohttp.hdrs import istr
|
||||||
@ -14,10 +17,16 @@ from aiohttp.web_exceptions import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
from homeassistant.components.http import HomeAssistantView
|
from homeassistant.components.http import HomeAssistantView
|
||||||
from homeassistant.core import HomeAssistant
|
from homeassistant.config_entries import ConfigEntry
|
||||||
|
from homeassistant.const import STATE_UNAVAILABLE, STATE_UNKNOWN
|
||||||
|
from homeassistant.core import HomeAssistant, callback
|
||||||
|
from homeassistant.helpers.entity_component import EntityComponent
|
||||||
|
from homeassistant.helpers.restore_state import RestoreEntity
|
||||||
from homeassistant.helpers.typing import ConfigType
|
from homeassistant.helpers.typing import ConfigType
|
||||||
|
from homeassistant.util import dt as dt_util
|
||||||
|
|
||||||
from .const import (
|
from .const import (
|
||||||
|
DATA_PROVIDERS,
|
||||||
DOMAIN,
|
DOMAIN,
|
||||||
AudioBitRates,
|
AudioBitRates,
|
||||||
AudioChannels,
|
AudioChannels,
|
||||||
@ -36,6 +45,7 @@ from .legacy import (
|
|||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"async_get_provider",
|
"async_get_provider",
|
||||||
|
"async_get_speech_to_text_entity",
|
||||||
"AudioBitRates",
|
"AudioBitRates",
|
||||||
"AudioChannels",
|
"AudioChannels",
|
||||||
"AudioCodecs",
|
"AudioCodecs",
|
||||||
@ -43,26 +53,158 @@ __all__ = [
|
|||||||
"AudioSampleRates",
|
"AudioSampleRates",
|
||||||
"DOMAIN",
|
"DOMAIN",
|
||||||
"Provider",
|
"Provider",
|
||||||
|
"SpeechToTextEntity",
|
||||||
"SpeechMetadata",
|
"SpeechMetadata",
|
||||||
"SpeechResult",
|
"SpeechResult",
|
||||||
"SpeechResultState",
|
"SpeechResultState",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
_LOGGER = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@callback
|
||||||
|
def async_get_speech_to_text_entity(
|
||||||
|
hass: HomeAssistant, entity_id: str
|
||||||
|
) -> SpeechToTextEntity | None:
|
||||||
|
"""Return stt entity."""
|
||||||
|
component: EntityComponent[SpeechToTextEntity] = hass.data[DOMAIN]
|
||||||
|
|
||||||
|
return component.get_entity(entity_id)
|
||||||
|
|
||||||
|
|
||||||
async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
|
async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
|
||||||
"""Set up STT."""
|
"""Set up STT."""
|
||||||
|
component = hass.data[DOMAIN] = EntityComponent[SpeechToTextEntity](
|
||||||
|
_LOGGER, DOMAIN, hass
|
||||||
|
)
|
||||||
|
|
||||||
|
component.register_shutdown()
|
||||||
platform_setups = async_setup_legacy(hass, config)
|
platform_setups = async_setup_legacy(hass, config)
|
||||||
|
|
||||||
if platform_setups:
|
if platform_setups:
|
||||||
await asyncio.wait([asyncio.create_task(setup) for setup in platform_setups])
|
await asyncio.wait([asyncio.create_task(setup) for setup in platform_setups])
|
||||||
|
|
||||||
hass.http.register_view(SpeechToTextView(hass.data[DOMAIN]))
|
hass.http.register_view(SpeechToTextView(hass.data[DATA_PROVIDERS]))
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
|
||||||
|
"""Set up a config entry."""
|
||||||
|
component: EntityComponent[SpeechToTextEntity] = hass.data[DOMAIN]
|
||||||
|
return await component.async_setup_entry(entry)
|
||||||
|
|
||||||
|
|
||||||
|
async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
|
||||||
|
"""Unload a config entry."""
|
||||||
|
component: EntityComponent[SpeechToTextEntity] = hass.data[DOMAIN]
|
||||||
|
return await component.async_unload_entry(entry)
|
||||||
|
|
||||||
|
|
||||||
|
class SpeechToTextEntity(RestoreEntity):
|
||||||
|
"""Represent a single STT provider."""
|
||||||
|
|
||||||
|
_attr_should_poll = False
|
||||||
|
__last_processed: str | None = None
|
||||||
|
|
||||||
|
@property
|
||||||
|
@final
|
||||||
|
def name(self) -> str:
|
||||||
|
"""Return the name of the provider entity."""
|
||||||
|
# Only one entity is allowed per platform for now.
|
||||||
|
if self.platform is None:
|
||||||
|
raise RuntimeError("Entity is not added to hass yet.")
|
||||||
|
|
||||||
|
return self.platform.platform_name
|
||||||
|
|
||||||
|
@property
|
||||||
|
@final
|
||||||
|
def state(self) -> str | None:
|
||||||
|
"""Return the state of the provider entity."""
|
||||||
|
if self.__last_processed is None:
|
||||||
|
return None
|
||||||
|
return self.__last_processed
|
||||||
|
|
||||||
|
@property
|
||||||
|
@abstractmethod
|
||||||
|
def supported_languages(self) -> list[str]:
|
||||||
|
"""Return a list of supported languages."""
|
||||||
|
|
||||||
|
@property
|
||||||
|
@abstractmethod
|
||||||
|
def supported_formats(self) -> list[AudioFormats]:
|
||||||
|
"""Return a list of supported formats."""
|
||||||
|
|
||||||
|
@property
|
||||||
|
@abstractmethod
|
||||||
|
def supported_codecs(self) -> list[AudioCodecs]:
|
||||||
|
"""Return a list of supported codecs."""
|
||||||
|
|
||||||
|
@property
|
||||||
|
@abstractmethod
|
||||||
|
def supported_bit_rates(self) -> list[AudioBitRates]:
|
||||||
|
"""Return a list of supported bit rates."""
|
||||||
|
|
||||||
|
@property
|
||||||
|
@abstractmethod
|
||||||
|
def supported_sample_rates(self) -> list[AudioSampleRates]:
|
||||||
|
"""Return a list of supported sample rates."""
|
||||||
|
|
||||||
|
@property
|
||||||
|
@abstractmethod
|
||||||
|
def supported_channels(self) -> list[AudioChannels]:
|
||||||
|
"""Return a list of supported channels."""
|
||||||
|
|
||||||
|
async def async_internal_added_to_hass(self) -> None:
|
||||||
|
"""Call when the provider entity is added to hass."""
|
||||||
|
await super().async_internal_added_to_hass()
|
||||||
|
state = await self.async_get_last_state()
|
||||||
|
if (
|
||||||
|
state is not None
|
||||||
|
and state.state is not None
|
||||||
|
and state.state not in (STATE_UNAVAILABLE, STATE_UNKNOWN)
|
||||||
|
):
|
||||||
|
self.__last_processed = state.state
|
||||||
|
|
||||||
|
@final
|
||||||
|
async def internal_async_process_audio_stream(
|
||||||
|
self, metadata: SpeechMetadata, stream: AsyncIterable[bytes]
|
||||||
|
) -> SpeechResult:
|
||||||
|
"""Process an audio stream to STT service.
|
||||||
|
|
||||||
|
Only streaming content is allowed!
|
||||||
|
"""
|
||||||
|
self.__last_processed = dt_util.utcnow().isoformat()
|
||||||
|
self.async_write_ha_state()
|
||||||
|
return await self.async_process_audio_stream(metadata=metadata, stream=stream)
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def async_process_audio_stream(
|
||||||
|
self, metadata: SpeechMetadata, stream: AsyncIterable[bytes]
|
||||||
|
) -> SpeechResult:
|
||||||
|
"""Process an audio stream to STT service.
|
||||||
|
|
||||||
|
Only streaming content is allowed!
|
||||||
|
"""
|
||||||
|
|
||||||
|
@callback
|
||||||
|
def check_metadata(self, metadata: SpeechMetadata) -> bool:
|
||||||
|
"""Check if given metadata supported by this provider."""
|
||||||
|
if (
|
||||||
|
metadata.language not in self.supported_languages
|
||||||
|
or metadata.format not in self.supported_formats
|
||||||
|
or metadata.codec not in self.supported_codecs
|
||||||
|
or metadata.bit_rate not in self.supported_bit_rates
|
||||||
|
or metadata.sample_rate not in self.supported_sample_rates
|
||||||
|
or metadata.channel not in self.supported_channels
|
||||||
|
):
|
||||||
|
return False
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
class SpeechToTextView(HomeAssistantView):
|
class SpeechToTextView(HomeAssistantView):
|
||||||
"""STT view to generate a text from audio stream."""
|
"""STT view to generate a text from audio stream."""
|
||||||
|
|
||||||
|
_legacy_provider_reported = False
|
||||||
requires_auth = True
|
requires_auth = True
|
||||||
url = "/api/stt/{provider}"
|
url = "/api/stt/{provider}"
|
||||||
name = "api:stt:provider"
|
name = "api:stt:provider"
|
||||||
@ -73,9 +215,17 @@ class SpeechToTextView(HomeAssistantView):
|
|||||||
|
|
||||||
async def post(self, request: web.Request, provider: str) -> web.Response:
|
async def post(self, request: web.Request, provider: str) -> web.Response:
|
||||||
"""Convert Speech (audio) to text."""
|
"""Convert Speech (audio) to text."""
|
||||||
if provider not in self.providers:
|
hass: HomeAssistant = request.app["hass"]
|
||||||
|
provider_entity: SpeechToTextEntity | None = None
|
||||||
|
if (
|
||||||
|
not (
|
||||||
|
provider_entity := async_get_speech_to_text_entity(
|
||||||
|
hass, f"{DOMAIN}.{provider}"
|
||||||
|
)
|
||||||
|
)
|
||||||
|
and provider not in self.providers
|
||||||
|
):
|
||||||
raise HTTPNotFound()
|
raise HTTPNotFound()
|
||||||
stt_provider: Provider = self.providers[provider]
|
|
||||||
|
|
||||||
# Get metadata
|
# Get metadata
|
||||||
try:
|
try:
|
||||||
@ -83,35 +233,105 @@ class SpeechToTextView(HomeAssistantView):
|
|||||||
except ValueError as err:
|
except ValueError as err:
|
||||||
raise HTTPBadRequest(text=str(err)) from err
|
raise HTTPBadRequest(text=str(err)) from err
|
||||||
|
|
||||||
# Check format
|
if not provider_entity:
|
||||||
if not stt_provider.check_metadata(metadata):
|
stt_provider = self._get_provider(provider)
|
||||||
raise HTTPUnsupportedMediaType()
|
|
||||||
|
|
||||||
# Process audio stream
|
# Check format
|
||||||
result = await stt_provider.async_process_audio_stream(
|
if not stt_provider.check_metadata(metadata):
|
||||||
metadata, request.content
|
raise HTTPUnsupportedMediaType()
|
||||||
)
|
|
||||||
|
# Process audio stream
|
||||||
|
result = await stt_provider.async_process_audio_stream(
|
||||||
|
metadata, request.content
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# Check format
|
||||||
|
if not provider_entity.check_metadata(metadata):
|
||||||
|
raise HTTPUnsupportedMediaType()
|
||||||
|
|
||||||
|
# Process audio stream
|
||||||
|
result = await provider_entity.internal_async_process_audio_stream(
|
||||||
|
metadata, request.content
|
||||||
|
)
|
||||||
|
|
||||||
# Return result
|
# Return result
|
||||||
return self.json(asdict(result))
|
return self.json(asdict(result))
|
||||||
|
|
||||||
async def get(self, request: web.Request, provider: str) -> web.Response:
|
async def get(self, request: web.Request, provider: str) -> web.Response:
|
||||||
"""Return provider specific audio information."""
|
"""Return provider specific audio information."""
|
||||||
if provider not in self.providers:
|
hass: HomeAssistant = request.app["hass"]
|
||||||
|
if (
|
||||||
|
not (
|
||||||
|
provider_entity := async_get_speech_to_text_entity(
|
||||||
|
hass, f"{DOMAIN}.{provider}"
|
||||||
|
)
|
||||||
|
)
|
||||||
|
and provider not in self.providers
|
||||||
|
):
|
||||||
raise HTTPNotFound()
|
raise HTTPNotFound()
|
||||||
stt_provider: Provider = self.providers[provider]
|
|
||||||
|
if not provider_entity:
|
||||||
|
stt_provider = self._get_provider(provider)
|
||||||
|
|
||||||
|
return self.json(
|
||||||
|
{
|
||||||
|
"languages": stt_provider.supported_languages,
|
||||||
|
"formats": stt_provider.supported_formats,
|
||||||
|
"codecs": stt_provider.supported_codecs,
|
||||||
|
"sample_rates": stt_provider.supported_sample_rates,
|
||||||
|
"bit_rates": stt_provider.supported_bit_rates,
|
||||||
|
"channels": stt_provider.supported_channels,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
return self.json(
|
return self.json(
|
||||||
{
|
{
|
||||||
"languages": stt_provider.supported_languages,
|
"languages": provider_entity.supported_languages,
|
||||||
"formats": stt_provider.supported_formats,
|
"formats": provider_entity.supported_formats,
|
||||||
"codecs": stt_provider.supported_codecs,
|
"codecs": provider_entity.supported_codecs,
|
||||||
"sample_rates": stt_provider.supported_sample_rates,
|
"sample_rates": provider_entity.supported_sample_rates,
|
||||||
"bit_rates": stt_provider.supported_bit_rates,
|
"bit_rates": provider_entity.supported_bit_rates,
|
||||||
"channels": stt_provider.supported_channels,
|
"channels": provider_entity.supported_channels,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def _get_provider(self, provider: str) -> Provider:
|
||||||
|
"""Get provider.
|
||||||
|
|
||||||
|
Method for legacy providers.
|
||||||
|
This can be removed when we remove the legacy provider support.
|
||||||
|
"""
|
||||||
|
stt_provider = self.providers[provider]
|
||||||
|
|
||||||
|
if not self._legacy_provider_reported:
|
||||||
|
self._legacy_provider_reported = True
|
||||||
|
report_issue = self._suggest_report_issue(provider, stt_provider)
|
||||||
|
# This should raise in Home Assistant Core 2023.9
|
||||||
|
_LOGGER.warning(
|
||||||
|
"Provider %s (%s) is using a legacy implementation, "
|
||||||
|
"and should be updated to use the SpeechToTextEntity. Please "
|
||||||
|
"%s",
|
||||||
|
provider,
|
||||||
|
type(stt_provider),
|
||||||
|
report_issue,
|
||||||
|
)
|
||||||
|
|
||||||
|
return stt_provider
|
||||||
|
|
||||||
|
def _suggest_report_issue(self, provider: str, provider_instance: object) -> str:
|
||||||
|
"""Suggest to report an issue."""
|
||||||
|
report_issue = ""
|
||||||
|
if "custom_components" in type(provider_instance).__module__:
|
||||||
|
report_issue = "report it to the custom integration author."
|
||||||
|
else:
|
||||||
|
report_issue = (
|
||||||
|
"create a bug report at "
|
||||||
|
"https://github.com/home-assistant/core/issues?q=is%3Aopen+is%3Aissue"
|
||||||
|
)
|
||||||
|
report_issue += f"+label%3A%22integration%3A+{provider}%22"
|
||||||
|
|
||||||
|
return report_issue
|
||||||
|
|
||||||
|
|
||||||
def _metadata_from_header(request: web.Request) -> SpeechMetadata:
|
def _metadata_from_header(request: web.Request) -> SpeechMetadata:
|
||||||
"""Extract STT metadata from header.
|
"""Extract STT metadata from header.
|
||||||
@ -138,7 +358,7 @@ def _metadata_from_header(request: web.Request) -> SpeechMetadata:
|
|||||||
for entry in data:
|
for entry in data:
|
||||||
key, _, value = entry.strip().partition("=")
|
key, _, value = entry.strip().partition("=")
|
||||||
if key not in fields:
|
if key not in fields:
|
||||||
raise ValueError(f"Invalid field {key}")
|
raise ValueError(f"Invalid field: {key}")
|
||||||
args[key] = value
|
args[key] = value
|
||||||
|
|
||||||
for field in fields:
|
for field in fields:
|
||||||
@ -154,5 +374,5 @@ def _metadata_from_header(request: web.Request) -> SpeechMetadata:
|
|||||||
sample_rate=args["sample_rate"],
|
sample_rate=args["sample_rate"],
|
||||||
channel=args["channel"],
|
channel=args["channel"],
|
||||||
)
|
)
|
||||||
except TypeError as err:
|
except ValueError as err:
|
||||||
raise ValueError(f"Wrong format of X-Speech-Content: {err}") from err
|
raise ValueError(f"Wrong format of X-Speech-Content: {err}") from err
|
||||||
|
@ -2,6 +2,7 @@
|
|||||||
from enum import Enum
|
from enum import Enum
|
||||||
|
|
||||||
DOMAIN = "stt"
|
DOMAIN = "stt"
|
||||||
|
DATA_PROVIDERS = f"{DOMAIN}_providers"
|
||||||
|
|
||||||
|
|
||||||
class AudioCodecs(str, Enum):
|
class AudioCodecs(str, Enum):
|
||||||
|
@ -13,6 +13,7 @@ from homeassistant.helpers.typing import ConfigType
|
|||||||
from homeassistant.setup import async_prepare_setup_platform
|
from homeassistant.setup import async_prepare_setup_platform
|
||||||
|
|
||||||
from .const import (
|
from .const import (
|
||||||
|
DATA_PROVIDERS,
|
||||||
DOMAIN,
|
DOMAIN,
|
||||||
AudioBitRates,
|
AudioBitRates,
|
||||||
AudioChannels,
|
AudioChannels,
|
||||||
@ -31,15 +32,15 @@ def async_get_provider(
|
|||||||
) -> Provider | None:
|
) -> Provider | None:
|
||||||
"""Return provider."""
|
"""Return provider."""
|
||||||
if domain:
|
if domain:
|
||||||
return hass.data[DOMAIN].get(domain)
|
return hass.data[DATA_PROVIDERS].get(domain)
|
||||||
|
|
||||||
if not hass.data[DOMAIN]:
|
if not hass.data[DATA_PROVIDERS]:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
if "cloud" in hass.data[DOMAIN]:
|
if "cloud" in hass.data[DATA_PROVIDERS]:
|
||||||
return hass.data[DOMAIN]["cloud"]
|
return hass.data[DATA_PROVIDERS]["cloud"]
|
||||||
|
|
||||||
return next(iter(hass.data[DOMAIN].values()))
|
return next(iter(hass.data[DATA_PROVIDERS].values()))
|
||||||
|
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
@ -47,7 +48,7 @@ def async_setup_legacy(
|
|||||||
hass: HomeAssistant, config: ConfigType
|
hass: HomeAssistant, config: ConfigType
|
||||||
) -> list[Coroutine[Any, Any, None]]:
|
) -> list[Coroutine[Any, Any, None]]:
|
||||||
"""Set up legacy speech to text providers."""
|
"""Set up legacy speech to text providers."""
|
||||||
providers = hass.data[DOMAIN] = {}
|
providers = hass.data[DATA_PROVIDERS] = {}
|
||||||
|
|
||||||
async def async_setup_platform(p_type, p_config=None, discovery_info=None):
|
async def async_setup_platform(p_type, p_config=None, discovery_info=None):
|
||||||
"""Set up a TTS platform."""
|
"""Set up a TTS platform."""
|
||||||
|
@ -4,18 +4,31 @@ from http import HTTPStatus
|
|||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from homeassistant.components import stt
|
from homeassistant.components import stt
|
||||||
|
from homeassistant.components.demo import DOMAIN as DEMO_DOMAIN
|
||||||
|
from homeassistant.core import HomeAssistant
|
||||||
from homeassistant.setup import async_setup_component
|
from homeassistant.setup import async_setup_component
|
||||||
|
|
||||||
|
from tests.common import MockConfigEntry
|
||||||
from tests.typing import ClientSessionGenerator
|
from tests.typing import ClientSessionGenerator
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(autouse=True)
|
@pytest.fixture
|
||||||
async def setup_comp(hass):
|
async def setup_legacy_platform(hass: HomeAssistant) -> None:
|
||||||
"""Set up demo component."""
|
"""Set up legacy demo platform."""
|
||||||
assert await async_setup_component(hass, stt.DOMAIN, {"stt": {"platform": "demo"}})
|
assert await async_setup_component(hass, stt.DOMAIN, {"stt": {"platform": "demo"}})
|
||||||
await hass.async_block_till_done()
|
await hass.async_block_till_done()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
async def setup_config_entry(hass: HomeAssistant) -> None:
|
||||||
|
"""Set up demo component from config entry."""
|
||||||
|
config_entry = MockConfigEntry(domain=DEMO_DOMAIN)
|
||||||
|
config_entry.add_to_hass(hass)
|
||||||
|
assert await hass.config_entries.async_setup(config_entry.entry_id)
|
||||||
|
await hass.async_block_till_done()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.usefixtures("setup_legacy_platform")
|
||||||
async def test_demo_settings(hass_client: ClientSessionGenerator) -> None:
|
async def test_demo_settings(hass_client: ClientSessionGenerator) -> None:
|
||||||
"""Test retrieve settings from demo provider."""
|
"""Test retrieve settings from demo provider."""
|
||||||
client = await hass_client()
|
client = await hass_client()
|
||||||
@ -34,6 +47,7 @@ async def test_demo_settings(hass_client: ClientSessionGenerator) -> None:
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.usefixtures("setup_legacy_platform")
|
||||||
async def test_demo_speech_no_metadata(hass_client: ClientSessionGenerator) -> None:
|
async def test_demo_speech_no_metadata(hass_client: ClientSessionGenerator) -> None:
|
||||||
"""Test retrieve settings from demo provider."""
|
"""Test retrieve settings from demo provider."""
|
||||||
client = await hass_client()
|
client = await hass_client()
|
||||||
@ -42,6 +56,7 @@ async def test_demo_speech_no_metadata(hass_client: ClientSessionGenerator) -> N
|
|||||||
assert response.status == HTTPStatus.BAD_REQUEST
|
assert response.status == HTTPStatus.BAD_REQUEST
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.usefixtures("setup_legacy_platform")
|
||||||
async def test_demo_speech_wrong_metadata(hass_client: ClientSessionGenerator) -> None:
|
async def test_demo_speech_wrong_metadata(hass_client: ClientSessionGenerator) -> None:
|
||||||
"""Test retrieve settings from demo provider."""
|
"""Test retrieve settings from demo provider."""
|
||||||
client = await hass_client()
|
client = await hass_client()
|
||||||
@ -59,6 +74,7 @@ async def test_demo_speech_wrong_metadata(hass_client: ClientSessionGenerator) -
|
|||||||
assert response.status == HTTPStatus.UNSUPPORTED_MEDIA_TYPE
|
assert response.status == HTTPStatus.UNSUPPORTED_MEDIA_TYPE
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.usefixtures("setup_legacy_platform")
|
||||||
async def test_demo_speech(hass_client: ClientSessionGenerator) -> None:
|
async def test_demo_speech(hass_client: ClientSessionGenerator) -> None:
|
||||||
"""Test retrieve settings from demo provider."""
|
"""Test retrieve settings from demo provider."""
|
||||||
client = await hass_client()
|
client = await hass_client()
|
||||||
@ -77,3 +93,26 @@ async def test_demo_speech(hass_client: ClientSessionGenerator) -> None:
|
|||||||
|
|
||||||
assert response.status == HTTPStatus.OK
|
assert response.status == HTTPStatus.OK
|
||||||
assert response_data == {"text": "Turn the Kitchen Lights on", "result": "success"}
|
assert response_data == {"text": "Turn the Kitchen Lights on", "result": "success"}
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.usefixtures("setup_config_entry")
|
||||||
|
async def test_config_entry_demo_speech(
|
||||||
|
hass_client: ClientSessionGenerator, hass: HomeAssistant
|
||||||
|
) -> None:
|
||||||
|
"""Test retrieve settings from demo provider from config entry."""
|
||||||
|
client = await hass_client()
|
||||||
|
|
||||||
|
response = await client.post(
|
||||||
|
"/api/stt/demo",
|
||||||
|
headers={
|
||||||
|
"X-Speech-Content": (
|
||||||
|
"format=wav; codec=pcm; sample_rate=16000; bit_rate=16; channel=2;"
|
||||||
|
" language=de"
|
||||||
|
)
|
||||||
|
},
|
||||||
|
data=b"Test",
|
||||||
|
)
|
||||||
|
response_data = await response.json()
|
||||||
|
|
||||||
|
assert response.status == HTTPStatus.OK
|
||||||
|
assert response_data == {"text": "Turn the Kitchen Lights on", "result": "success"}
|
||||||
|
@ -6,7 +6,9 @@ from pathlib import Path
|
|||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from homeassistant.components.stt import Provider
|
from homeassistant.components.stt import Provider
|
||||||
|
from homeassistant.config_entries import ConfigEntry
|
||||||
from homeassistant.core import HomeAssistant
|
from homeassistant.core import HomeAssistant
|
||||||
|
from homeassistant.helpers.entity_platform import AddEntitiesCallback
|
||||||
from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType
|
from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType
|
||||||
|
|
||||||
from tests.common import MockPlatform, mock_platform
|
from tests.common import MockPlatform, mock_platform
|
||||||
@ -54,3 +56,19 @@ def mock_stt_platform(
|
|||||||
mock_platform(hass, f"{integration}.stt", loaded_platform)
|
mock_platform(hass, f"{integration}.stt", loaded_platform)
|
||||||
|
|
||||||
return loaded_platform
|
return loaded_platform
|
||||||
|
|
||||||
|
|
||||||
|
def mock_stt_entity_platform(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
tmp_path: Path,
|
||||||
|
integration: str,
|
||||||
|
async_setup_entry: Callable[
|
||||||
|
[HomeAssistant, ConfigEntry, AddEntitiesCallback],
|
||||||
|
Coroutine[Any, Any, None],
|
||||||
|
]
|
||||||
|
| None = None,
|
||||||
|
) -> MockPlatform:
|
||||||
|
"""Specialize the mock platform for stt."""
|
||||||
|
loaded_platform = MockPlatform(async_setup_entry=async_setup_entry)
|
||||||
|
mock_platform(hass, f"{integration}.stt", loaded_platform)
|
||||||
|
return loaded_platform
|
||||||
|
@ -1,5 +1,5 @@
|
|||||||
"""Test STT component setup."""
|
"""Test STT component setup."""
|
||||||
from collections.abc import AsyncIterable
|
from collections.abc import AsyncIterable, Generator
|
||||||
from http import HTTPStatus
|
from http import HTTPStatus
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from unittest.mock import AsyncMock
|
from unittest.mock import AsyncMock
|
||||||
@ -7,6 +7,7 @@ from unittest.mock import AsyncMock
|
|||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from homeassistant.components.stt import (
|
from homeassistant.components.stt import (
|
||||||
|
DOMAIN,
|
||||||
AudioBitRates,
|
AudioBitRates,
|
||||||
AudioChannels,
|
AudioChannels,
|
||||||
AudioCodecs,
|
AudioCodecs,
|
||||||
@ -16,17 +17,30 @@ from homeassistant.components.stt import (
|
|||||||
SpeechMetadata,
|
SpeechMetadata,
|
||||||
SpeechResult,
|
SpeechResult,
|
||||||
SpeechResultState,
|
SpeechResultState,
|
||||||
|
SpeechToTextEntity,
|
||||||
async_get_provider,
|
async_get_provider,
|
||||||
)
|
)
|
||||||
from homeassistant.core import HomeAssistant
|
from homeassistant.config_entries import ConfigEntry, ConfigEntryState, ConfigFlow
|
||||||
|
from homeassistant.core import HomeAssistant, State
|
||||||
|
from homeassistant.helpers.entity_platform import AddEntitiesCallback
|
||||||
from homeassistant.setup import async_setup_component
|
from homeassistant.setup import async_setup_component
|
||||||
|
|
||||||
from .common import mock_stt_platform
|
from .common import mock_stt_entity_platform, mock_stt_platform
|
||||||
|
|
||||||
|
from tests.common import (
|
||||||
|
MockConfigEntry,
|
||||||
|
MockModule,
|
||||||
|
mock_config_flow,
|
||||||
|
mock_integration,
|
||||||
|
mock_platform,
|
||||||
|
mock_restore_cache,
|
||||||
|
)
|
||||||
from tests.typing import ClientSessionGenerator
|
from tests.typing import ClientSessionGenerator
|
||||||
|
|
||||||
|
TEST_DOMAIN = "test"
|
||||||
|
|
||||||
class MockProvider(Provider):
|
|
||||||
|
class BaseProvider:
|
||||||
"""Mock provider."""
|
"""Mock provider."""
|
||||||
|
|
||||||
fail_process_audio = False
|
fail_process_audio = False
|
||||||
@ -73,7 +87,15 @@ class MockProvider(Provider):
|
|||||||
if self.fail_process_audio:
|
if self.fail_process_audio:
|
||||||
return SpeechResult(None, SpeechResultState.ERROR)
|
return SpeechResult(None, SpeechResultState.ERROR)
|
||||||
|
|
||||||
return SpeechResult("test", SpeechResultState.SUCCESS)
|
return SpeechResult("test_result", SpeechResultState.SUCCESS)
|
||||||
|
|
||||||
|
|
||||||
|
class MockProvider(BaseProvider, Provider):
|
||||||
|
"""Mock provider."""
|
||||||
|
|
||||||
|
|
||||||
|
class MockProviderEntity(BaseProvider, SpeechToTextEntity):
|
||||||
|
"""Mock provider entity."""
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
@ -82,26 +104,113 @@ def mock_provider() -> MockProvider:
|
|||||||
return MockProvider()
|
return MockProvider()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_provider_entity() -> MockProviderEntity:
|
||||||
|
"""Test provider entity fixture."""
|
||||||
|
return MockProviderEntity()
|
||||||
|
|
||||||
|
|
||||||
|
class STTFlow(ConfigFlow):
|
||||||
|
"""Test flow."""
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(autouse=True)
|
@pytest.fixture(autouse=True)
|
||||||
|
def config_flow_fixture(hass: HomeAssistant) -> Generator[None, None, None]:
|
||||||
|
"""Mock config flow."""
|
||||||
|
mock_platform(hass, f"{TEST_DOMAIN}.config_flow")
|
||||||
|
|
||||||
|
with mock_config_flow(TEST_DOMAIN, STTFlow):
|
||||||
|
yield
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(name="setup")
|
||||||
|
async def setup_fixture(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
tmp_path: Path,
|
||||||
|
request: pytest.FixtureRequest,
|
||||||
|
) -> None:
|
||||||
|
"""Set up the test environment."""
|
||||||
|
if request.param == "mock_setup":
|
||||||
|
await mock_setup(hass, tmp_path, MockProvider())
|
||||||
|
elif request.param == "mock_config_entry_setup":
|
||||||
|
await mock_config_entry_setup(hass, tmp_path, MockProviderEntity())
|
||||||
|
else:
|
||||||
|
raise RuntimeError("Invalid setup fixture")
|
||||||
|
|
||||||
|
|
||||||
async def mock_setup(
|
async def mock_setup(
|
||||||
hass: HomeAssistant, tmp_path: Path, mock_provider: MockProvider
|
hass: HomeAssistant,
|
||||||
|
tmp_path: Path,
|
||||||
|
mock_provider: MockProvider,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Set up a test provider."""
|
"""Set up a test provider."""
|
||||||
mock_stt_platform(
|
mock_stt_platform(
|
||||||
hass,
|
hass,
|
||||||
tmp_path,
|
tmp_path,
|
||||||
"test",
|
TEST_DOMAIN,
|
||||||
async_get_engine=AsyncMock(return_value=mock_provider),
|
async_get_engine=AsyncMock(return_value=mock_provider),
|
||||||
)
|
)
|
||||||
assert await async_setup_component(hass, "stt", {"stt": {"platform": "test"}})
|
assert await async_setup_component(hass, "stt", {"stt": {"platform": TEST_DOMAIN}})
|
||||||
|
await hass.async_block_till_done()
|
||||||
|
|
||||||
|
|
||||||
|
async def mock_config_entry_setup(
|
||||||
|
hass: HomeAssistant, tmp_path: Path, mock_provider_entity: MockProviderEntity
|
||||||
|
) -> MockConfigEntry:
|
||||||
|
"""Set up a test provider via config entry."""
|
||||||
|
|
||||||
|
async def async_setup_entry_init(
|
||||||
|
hass: HomeAssistant, config_entry: ConfigEntry
|
||||||
|
) -> bool:
|
||||||
|
"""Set up test config entry."""
|
||||||
|
await hass.config_entries.async_forward_entry_setup(config_entry, DOMAIN)
|
||||||
|
return True
|
||||||
|
|
||||||
|
async def async_unload_entry_init(
|
||||||
|
hass: HomeAssistant, config_entry: ConfigEntry
|
||||||
|
) -> bool:
|
||||||
|
"""Unload up test config entry."""
|
||||||
|
await hass.config_entries.async_forward_entry_unload(config_entry, DOMAIN)
|
||||||
|
return True
|
||||||
|
|
||||||
|
mock_integration(
|
||||||
|
hass,
|
||||||
|
MockModule(
|
||||||
|
TEST_DOMAIN,
|
||||||
|
async_setup_entry=async_setup_entry_init,
|
||||||
|
async_unload_entry=async_unload_entry_init,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
async def async_setup_entry_platform(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
config_entry: ConfigEntry,
|
||||||
|
async_add_entities: AddEntitiesCallback,
|
||||||
|
) -> None:
|
||||||
|
"""Set up test stt platform via config entry."""
|
||||||
|
async_add_entities([mock_provider_entity])
|
||||||
|
|
||||||
|
mock_stt_entity_platform(hass, tmp_path, TEST_DOMAIN, async_setup_entry_platform)
|
||||||
|
|
||||||
|
config_entry = MockConfigEntry(domain=TEST_DOMAIN)
|
||||||
|
config_entry.add_to_hass(hass)
|
||||||
|
assert await hass.config_entries.async_setup(config_entry.entry_id)
|
||||||
|
await hass.async_block_till_done()
|
||||||
|
|
||||||
|
return config_entry
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"setup", ["mock_setup", "mock_config_entry_setup"], indirect=True
|
||||||
|
)
|
||||||
async def test_get_provider_info(
|
async def test_get_provider_info(
|
||||||
hass: HomeAssistant, hass_client: ClientSessionGenerator
|
hass: HomeAssistant,
|
||||||
|
hass_client: ClientSessionGenerator,
|
||||||
|
setup: str,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Test engine that doesn't exist."""
|
"""Test engine that doesn't exist."""
|
||||||
client = await hass_client()
|
client = await hass_client()
|
||||||
response = await client.get("/api/stt/test")
|
response = await client.get(f"/api/stt/{TEST_DOMAIN}")
|
||||||
assert response.status == HTTPStatus.OK
|
assert response.status == HTTPStatus.OK
|
||||||
assert await response.json() == {
|
assert await response.json() == {
|
||||||
"languages": ["en"],
|
"languages": ["en"],
|
||||||
@ -113,22 +222,44 @@ async def test_get_provider_info(
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
async def test_get_non_existing_provider_info(
|
@pytest.mark.parametrize(
|
||||||
hass: HomeAssistant, hass_client: ClientSessionGenerator
|
"setup", ["mock_setup", "mock_config_entry_setup"], indirect=True
|
||||||
|
)
|
||||||
|
async def test_non_existing_provider(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
hass_client: ClientSessionGenerator,
|
||||||
|
setup: str,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Test streaming to engine that doesn't exist."""
|
"""Test streaming to engine that doesn't exist."""
|
||||||
client = await hass_client()
|
client = await hass_client()
|
||||||
|
|
||||||
response = await client.get("/api/stt/not_exist")
|
response = await client.get("/api/stt/not_exist")
|
||||||
assert response.status == HTTPStatus.NOT_FOUND
|
assert response.status == HTTPStatus.NOT_FOUND
|
||||||
|
|
||||||
|
response = await client.post(
|
||||||
|
"/api/stt/not_exist",
|
||||||
|
headers={
|
||||||
|
"X-Speech-Content": (
|
||||||
|
"format=wav; codec=pcm; sample_rate=16000; bit_rate=16; channel=1;"
|
||||||
|
" language=en"
|
||||||
|
)
|
||||||
|
},
|
||||||
|
)
|
||||||
|
assert response.status == HTTPStatus.NOT_FOUND
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"setup", ["mock_setup", "mock_config_entry_setup"], indirect=True
|
||||||
|
)
|
||||||
async def test_stream_audio(
|
async def test_stream_audio(
|
||||||
hass: HomeAssistant, hass_client: ClientSessionGenerator
|
hass: HomeAssistant,
|
||||||
|
hass_client: ClientSessionGenerator,
|
||||||
|
setup: str,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Test streaming audio and getting response."""
|
"""Test streaming audio and getting response."""
|
||||||
client = await hass_client()
|
client = await hass_client()
|
||||||
response = await client.post(
|
response = await client.post(
|
||||||
"/api/stt/test",
|
f"/api/stt/{TEST_DOMAIN}",
|
||||||
headers={
|
headers={
|
||||||
"X-Speech-Content": (
|
"X-Speech-Content": (
|
||||||
"format=wav; codec=pcm; sample_rate=16000; bit_rate=16; channel=1;"
|
"format=wav; codec=pcm; sample_rate=16000; bit_rate=16; channel=1;"
|
||||||
@ -137,20 +268,39 @@ async def test_stream_audio(
|
|||||||
},
|
},
|
||||||
)
|
)
|
||||||
assert response.status == HTTPStatus.OK
|
assert response.status == HTTPStatus.OK
|
||||||
assert await response.json() == {"text": "test", "result": "success"}
|
assert await response.json() == {"text": "test_result", "result": "success"}
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"setup", ["mock_setup", "mock_config_entry_setup"], indirect=True
|
||||||
|
)
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
("header", "status", "error"),
|
("header", "status", "error"),
|
||||||
(
|
(
|
||||||
(None, 400, "Missing X-Speech-Content header"),
|
(None, 400, "Missing X-Speech-Content header"),
|
||||||
|
(
|
||||||
|
(
|
||||||
|
"format=wav; codec=pcm; sample_rate=16000; bit_rate=16; channel=100;"
|
||||||
|
" language=en; unknown=1"
|
||||||
|
),
|
||||||
|
400,
|
||||||
|
"Invalid field: unknown",
|
||||||
|
),
|
||||||
(
|
(
|
||||||
(
|
(
|
||||||
"format=wav; codec=pcm; sample_rate=16000; bit_rate=16; channel=100;"
|
"format=wav; codec=pcm; sample_rate=16000; bit_rate=16; channel=100;"
|
||||||
" language=en"
|
" language=en"
|
||||||
),
|
),
|
||||||
400,
|
400,
|
||||||
"100 is not a valid AudioChannels",
|
"Wrong format of X-Speech-Content: 100 is not a valid AudioChannels",
|
||||||
|
),
|
||||||
|
(
|
||||||
|
(
|
||||||
|
"format=wav; codec=pcm; sample_rate=16000; bit_rate=16; channel=bad channel;"
|
||||||
|
" language=en"
|
||||||
|
),
|
||||||
|
400,
|
||||||
|
"Wrong format of X-Speech-Content: invalid literal for int() with base 10: 'bad channel'",
|
||||||
),
|
),
|
||||||
(
|
(
|
||||||
"format=wav; codec=pcm; sample_rate=16000",
|
"format=wav; codec=pcm; sample_rate=16000",
|
||||||
@ -165,6 +315,7 @@ async def test_metadata_errors(
|
|||||||
header: str | None,
|
header: str | None,
|
||||||
status: int,
|
status: int,
|
||||||
error: str,
|
error: str,
|
||||||
|
setup: str,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Test metadata errors."""
|
"""Test metadata errors."""
|
||||||
client = await hass_client()
|
client = await hass_client()
|
||||||
@ -172,11 +323,55 @@ async def test_metadata_errors(
|
|||||||
if header:
|
if header:
|
||||||
headers["X-Speech-Content"] = header
|
headers["X-Speech-Content"] = header
|
||||||
|
|
||||||
response = await client.post("/api/stt/test", headers=headers)
|
response = await client.post(f"/api/stt/{TEST_DOMAIN}", headers=headers)
|
||||||
assert response.status == status
|
assert response.status == status
|
||||||
assert await response.text() == error
|
assert await response.text() == error
|
||||||
|
|
||||||
|
|
||||||
async def test_get_provider(hass: HomeAssistant, mock_provider: MockProvider) -> None:
|
async def test_get_provider(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
tmp_path: Path,
|
||||||
|
mock_provider: MockProvider,
|
||||||
|
) -> None:
|
||||||
"""Test we can get STT providers."""
|
"""Test we can get STT providers."""
|
||||||
assert mock_provider == async_get_provider(hass, "test")
|
await mock_setup(hass, tmp_path, mock_provider)
|
||||||
|
assert mock_provider == async_get_provider(hass, TEST_DOMAIN)
|
||||||
|
|
||||||
|
|
||||||
|
async def test_config_entry_unload(
|
||||||
|
hass: HomeAssistant, tmp_path: Path, mock_provider_entity: MockProviderEntity
|
||||||
|
) -> None:
|
||||||
|
"""Test we can unload config entry."""
|
||||||
|
config_entry = await mock_config_entry_setup(hass, tmp_path, mock_provider_entity)
|
||||||
|
assert config_entry.state == ConfigEntryState.LOADED
|
||||||
|
await hass.config_entries.async_unload(config_entry.entry_id)
|
||||||
|
assert config_entry.state == ConfigEntryState.NOT_LOADED
|
||||||
|
|
||||||
|
|
||||||
|
def test_entity_name_raises_before_addition(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
tmp_path: Path,
|
||||||
|
mock_provider_entity: MockProviderEntity,
|
||||||
|
) -> None:
|
||||||
|
"""Test entity name raises before addition to Home Assistant."""
|
||||||
|
with pytest.raises(RuntimeError):
|
||||||
|
mock_provider_entity.name # pylint: disable=pointless-statement
|
||||||
|
|
||||||
|
|
||||||
|
async def test_restore_state(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
tmp_path: Path,
|
||||||
|
mock_provider_entity: MockProviderEntity,
|
||||||
|
) -> None:
|
||||||
|
"""Test we restore state in the integration."""
|
||||||
|
entity_id = f"{DOMAIN}.{TEST_DOMAIN}"
|
||||||
|
timestamp = "2023-01-01T23:59:59+00:00"
|
||||||
|
mock_restore_cache(hass, (State(entity_id, timestamp),))
|
||||||
|
|
||||||
|
config_entry = await mock_config_entry_setup(hass, tmp_path, mock_provider_entity)
|
||||||
|
await hass.async_block_till_done()
|
||||||
|
|
||||||
|
assert config_entry.state == ConfigEntryState.LOADED
|
||||||
|
state = hass.states.get(entity_id)
|
||||||
|
assert state
|
||||||
|
assert state.state == timestamp
|
||||||
|
Loading…
x
Reference in New Issue
Block a user