Add stt entity (#91230)

* Add stt entity

* Update demo platform

* Rename ProviderEntity to SpeechToTextEntity

* Fix get method

* Run all init tests for config entry setup

* Fix and test metadata from header

* Test config entry unload

* Rename get provider entity

* Test post for non existing provider

* Test entity name before addition

* Test restore state

* Use register shutdown

* Update deprecation comment
This commit is contained in:
Martin Hjelmare 2023-04-13 19:58:35 +02:00 committed by GitHub
parent 22a1a6846d
commit 473cbf7f9b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 582 additions and 50 deletions

View File

@ -36,6 +36,7 @@ COMPONENTS_WITH_CONFIG_ENTRY_DEMO_PLATFORM = [
Platform.SELECT,
Platform.SENSOR,
Platform.SIREN,
Platform.STT,
Platform.SWITCH,
Platform.TEXT,
Platform.UPDATE,

View File

@ -13,8 +13,11 @@ from homeassistant.components.stt import (
SpeechMetadata,
SpeechResult,
SpeechResultState,
SpeechToTextEntity,
)
from homeassistant.config_entries import ConfigEntry
from homeassistant.core import HomeAssistant
from homeassistant.helpers.entity_platform import AddEntitiesCallback
from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType
SUPPORT_LANGUAGES = ["en", "de"]
@ -29,6 +32,60 @@ async def async_get_engine(
return DemoProvider()
async def async_setup_entry(
hass: HomeAssistant,
config_entry: ConfigEntry,
async_add_entities: AddEntitiesCallback,
) -> None:
"""Set up Demo speech platform via config entry."""
async_add_entities([DemoProviderEntity()])
class DemoProviderEntity(SpeechToTextEntity):
"""Demo speech API provider entity."""
@property
def supported_languages(self) -> list[str]:
"""Return a list of supported languages."""
return SUPPORT_LANGUAGES
@property
def supported_formats(self) -> list[AudioFormats]:
"""Return a list of supported formats."""
return [AudioFormats.WAV]
@property
def supported_codecs(self) -> list[AudioCodecs]:
"""Return a list of supported codecs."""
return [AudioCodecs.PCM]
@property
def supported_bit_rates(self) -> list[AudioBitRates]:
"""Return a list of supported bit rates."""
return [AudioBitRates.BITRATE_16]
@property
def supported_sample_rates(self) -> list[AudioSampleRates]:
"""Return a list of supported sample rates."""
return [AudioSampleRates.SAMPLERATE_16000, AudioSampleRates.SAMPLERATE_44100]
@property
def supported_channels(self) -> list[AudioChannels]:
"""Return a list of supported channels."""
return [AudioChannels.CHANNEL_STEREO]
async def async_process_audio_stream(
self, metadata: SpeechMetadata, stream: AsyncIterable[bytes]
) -> SpeechResult:
"""Process an audio stream to STT service."""
# Read available data
async for _ in stream:
pass
return SpeechResult("Turn the Kitchen Lights on", SpeechResultState.SUCCESS)
class DemoProvider(Provider):
"""Demo speech API provider."""

View File

@ -1,9 +1,12 @@
"""Provide functionality to STT."""
from __future__ import annotations
from abc import abstractmethod
import asyncio
from collections.abc import AsyncIterable
from dataclasses import asdict
from typing import Any
import logging
from typing import Any, final
from aiohttp import web
from aiohttp.hdrs import istr
@ -14,10 +17,16 @@ from aiohttp.web_exceptions import (
)
from homeassistant.components.http import HomeAssistantView
from homeassistant.core import HomeAssistant
from homeassistant.config_entries import ConfigEntry
from homeassistant.const import STATE_UNAVAILABLE, STATE_UNKNOWN
from homeassistant.core import HomeAssistant, callback
from homeassistant.helpers.entity_component import EntityComponent
from homeassistant.helpers.restore_state import RestoreEntity
from homeassistant.helpers.typing import ConfigType
from homeassistant.util import dt as dt_util
from .const import (
DATA_PROVIDERS,
DOMAIN,
AudioBitRates,
AudioChannels,
@ -36,6 +45,7 @@ from .legacy import (
__all__ = [
"async_get_provider",
"async_get_speech_to_text_entity",
"AudioBitRates",
"AudioChannels",
"AudioCodecs",
@ -43,26 +53,158 @@ __all__ = [
"AudioSampleRates",
"DOMAIN",
"Provider",
"SpeechToTextEntity",
"SpeechMetadata",
"SpeechResult",
"SpeechResultState",
]
_LOGGER = logging.getLogger(__name__)
@callback
def async_get_speech_to_text_entity(
hass: HomeAssistant, entity_id: str
) -> SpeechToTextEntity | None:
"""Return stt entity."""
component: EntityComponent[SpeechToTextEntity] = hass.data[DOMAIN]
return component.get_entity(entity_id)
async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
"""Set up STT."""
component = hass.data[DOMAIN] = EntityComponent[SpeechToTextEntity](
_LOGGER, DOMAIN, hass
)
component.register_shutdown()
platform_setups = async_setup_legacy(hass, config)
if platform_setups:
await asyncio.wait([asyncio.create_task(setup) for setup in platform_setups])
hass.http.register_view(SpeechToTextView(hass.data[DOMAIN]))
hass.http.register_view(SpeechToTextView(hass.data[DATA_PROVIDERS]))
return True
async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
"""Set up a config entry."""
component: EntityComponent[SpeechToTextEntity] = hass.data[DOMAIN]
return await component.async_setup_entry(entry)
async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
"""Unload a config entry."""
component: EntityComponent[SpeechToTextEntity] = hass.data[DOMAIN]
return await component.async_unload_entry(entry)
class SpeechToTextEntity(RestoreEntity):
"""Represent a single STT provider."""
_attr_should_poll = False
__last_processed: str | None = None
@property
@final
def name(self) -> str:
"""Return the name of the provider entity."""
# Only one entity is allowed per platform for now.
if self.platform is None:
raise RuntimeError("Entity is not added to hass yet.")
return self.platform.platform_name
@property
@final
def state(self) -> str | None:
"""Return the state of the provider entity."""
if self.__last_processed is None:
return None
return self.__last_processed
@property
@abstractmethod
def supported_languages(self) -> list[str]:
"""Return a list of supported languages."""
@property
@abstractmethod
def supported_formats(self) -> list[AudioFormats]:
"""Return a list of supported formats."""
@property
@abstractmethod
def supported_codecs(self) -> list[AudioCodecs]:
"""Return a list of supported codecs."""
@property
@abstractmethod
def supported_bit_rates(self) -> list[AudioBitRates]:
"""Return a list of supported bit rates."""
@property
@abstractmethod
def supported_sample_rates(self) -> list[AudioSampleRates]:
"""Return a list of supported sample rates."""
@property
@abstractmethod
def supported_channels(self) -> list[AudioChannels]:
"""Return a list of supported channels."""
async def async_internal_added_to_hass(self) -> None:
"""Call when the provider entity is added to hass."""
await super().async_internal_added_to_hass()
state = await self.async_get_last_state()
if (
state is not None
and state.state is not None
and state.state not in (STATE_UNAVAILABLE, STATE_UNKNOWN)
):
self.__last_processed = state.state
@final
async def internal_async_process_audio_stream(
self, metadata: SpeechMetadata, stream: AsyncIterable[bytes]
) -> SpeechResult:
"""Process an audio stream to STT service.
Only streaming content is allowed!
"""
self.__last_processed = dt_util.utcnow().isoformat()
self.async_write_ha_state()
return await self.async_process_audio_stream(metadata=metadata, stream=stream)
@abstractmethod
async def async_process_audio_stream(
self, metadata: SpeechMetadata, stream: AsyncIterable[bytes]
) -> SpeechResult:
"""Process an audio stream to STT service.
Only streaming content is allowed!
"""
@callback
def check_metadata(self, metadata: SpeechMetadata) -> bool:
"""Check if given metadata supported by this provider."""
if (
metadata.language not in self.supported_languages
or metadata.format not in self.supported_formats
or metadata.codec not in self.supported_codecs
or metadata.bit_rate not in self.supported_bit_rates
or metadata.sample_rate not in self.supported_sample_rates
or metadata.channel not in self.supported_channels
):
return False
return True
class SpeechToTextView(HomeAssistantView):
"""STT view to generate a text from audio stream."""
_legacy_provider_reported = False
requires_auth = True
url = "/api/stt/{provider}"
name = "api:stt:provider"
@ -73,9 +215,17 @@ class SpeechToTextView(HomeAssistantView):
async def post(self, request: web.Request, provider: str) -> web.Response:
"""Convert Speech (audio) to text."""
if provider not in self.providers:
hass: HomeAssistant = request.app["hass"]
provider_entity: SpeechToTextEntity | None = None
if (
not (
provider_entity := async_get_speech_to_text_entity(
hass, f"{DOMAIN}.{provider}"
)
)
and provider not in self.providers
):
raise HTTPNotFound()
stt_provider: Provider = self.providers[provider]
# Get metadata
try:
@ -83,6 +233,9 @@ class SpeechToTextView(HomeAssistantView):
except ValueError as err:
raise HTTPBadRequest(text=str(err)) from err
if not provider_entity:
stt_provider = self._get_provider(provider)
# Check format
if not stt_provider.check_metadata(metadata):
raise HTTPUnsupportedMediaType()
@ -91,15 +244,34 @@ class SpeechToTextView(HomeAssistantView):
result = await stt_provider.async_process_audio_stream(
metadata, request.content
)
else:
# Check format
if not provider_entity.check_metadata(metadata):
raise HTTPUnsupportedMediaType()
# Process audio stream
result = await provider_entity.internal_async_process_audio_stream(
metadata, request.content
)
# Return result
return self.json(asdict(result))
async def get(self, request: web.Request, provider: str) -> web.Response:
"""Return provider specific audio information."""
if provider not in self.providers:
hass: HomeAssistant = request.app["hass"]
if (
not (
provider_entity := async_get_speech_to_text_entity(
hass, f"{DOMAIN}.{provider}"
)
)
and provider not in self.providers
):
raise HTTPNotFound()
stt_provider: Provider = self.providers[provider]
if not provider_entity:
stt_provider = self._get_provider(provider)
return self.json(
{
@ -112,6 +284,54 @@ class SpeechToTextView(HomeAssistantView):
}
)
return self.json(
{
"languages": provider_entity.supported_languages,
"formats": provider_entity.supported_formats,
"codecs": provider_entity.supported_codecs,
"sample_rates": provider_entity.supported_sample_rates,
"bit_rates": provider_entity.supported_bit_rates,
"channels": provider_entity.supported_channels,
}
)
def _get_provider(self, provider: str) -> Provider:
"""Get provider.
Method for legacy providers.
This can be removed when we remove the legacy provider support.
"""
stt_provider = self.providers[provider]
if not self._legacy_provider_reported:
self._legacy_provider_reported = True
report_issue = self._suggest_report_issue(provider, stt_provider)
# This should raise in Home Assistant Core 2023.9
_LOGGER.warning(
"Provider %s (%s) is using a legacy implementation, "
"and should be updated to use the SpeechToTextEntity. Please "
"%s",
provider,
type(stt_provider),
report_issue,
)
return stt_provider
def _suggest_report_issue(self, provider: str, provider_instance: object) -> str:
"""Suggest to report an issue."""
report_issue = ""
if "custom_components" in type(provider_instance).__module__:
report_issue = "report it to the custom integration author."
else:
report_issue = (
"create a bug report at "
"https://github.com/home-assistant/core/issues?q=is%3Aopen+is%3Aissue"
)
report_issue += f"+label%3A%22integration%3A+{provider}%22"
return report_issue
def _metadata_from_header(request: web.Request) -> SpeechMetadata:
"""Extract STT metadata from header.
@ -138,7 +358,7 @@ def _metadata_from_header(request: web.Request) -> SpeechMetadata:
for entry in data:
key, _, value = entry.strip().partition("=")
if key not in fields:
raise ValueError(f"Invalid field {key}")
raise ValueError(f"Invalid field: {key}")
args[key] = value
for field in fields:
@ -154,5 +374,5 @@ def _metadata_from_header(request: web.Request) -> SpeechMetadata:
sample_rate=args["sample_rate"],
channel=args["channel"],
)
except TypeError as err:
except ValueError as err:
raise ValueError(f"Wrong format of X-Speech-Content: {err}") from err

View File

@ -2,6 +2,7 @@
from enum import Enum
DOMAIN = "stt"
DATA_PROVIDERS = f"{DOMAIN}_providers"
class AudioCodecs(str, Enum):

View File

@ -13,6 +13,7 @@ from homeassistant.helpers.typing import ConfigType
from homeassistant.setup import async_prepare_setup_platform
from .const import (
DATA_PROVIDERS,
DOMAIN,
AudioBitRates,
AudioChannels,
@ -31,15 +32,15 @@ def async_get_provider(
) -> Provider | None:
"""Return provider."""
if domain:
return hass.data[DOMAIN].get(domain)
return hass.data[DATA_PROVIDERS].get(domain)
if not hass.data[DOMAIN]:
if not hass.data[DATA_PROVIDERS]:
return None
if "cloud" in hass.data[DOMAIN]:
return hass.data[DOMAIN]["cloud"]
if "cloud" in hass.data[DATA_PROVIDERS]:
return hass.data[DATA_PROVIDERS]["cloud"]
return next(iter(hass.data[DOMAIN].values()))
return next(iter(hass.data[DATA_PROVIDERS].values()))
@callback
@ -47,7 +48,7 @@ def async_setup_legacy(
hass: HomeAssistant, config: ConfigType
) -> list[Coroutine[Any, Any, None]]:
"""Set up legacy speech to text providers."""
providers = hass.data[DOMAIN] = {}
providers = hass.data[DATA_PROVIDERS] = {}
async def async_setup_platform(p_type, p_config=None, discovery_info=None):
"""Set up a TTS platform."""

View File

@ -4,18 +4,31 @@ from http import HTTPStatus
import pytest
from homeassistant.components import stt
from homeassistant.components.demo import DOMAIN as DEMO_DOMAIN
from homeassistant.core import HomeAssistant
from homeassistant.setup import async_setup_component
from tests.common import MockConfigEntry
from tests.typing import ClientSessionGenerator
@pytest.fixture(autouse=True)
async def setup_comp(hass):
"""Set up demo component."""
@pytest.fixture
async def setup_legacy_platform(hass: HomeAssistant) -> None:
"""Set up legacy demo platform."""
assert await async_setup_component(hass, stt.DOMAIN, {"stt": {"platform": "demo"}})
await hass.async_block_till_done()
@pytest.fixture
async def setup_config_entry(hass: HomeAssistant) -> None:
"""Set up demo component from config entry."""
config_entry = MockConfigEntry(domain=DEMO_DOMAIN)
config_entry.add_to_hass(hass)
assert await hass.config_entries.async_setup(config_entry.entry_id)
await hass.async_block_till_done()
@pytest.mark.usefixtures("setup_legacy_platform")
async def test_demo_settings(hass_client: ClientSessionGenerator) -> None:
"""Test retrieve settings from demo provider."""
client = await hass_client()
@ -34,6 +47,7 @@ async def test_demo_settings(hass_client: ClientSessionGenerator) -> None:
}
@pytest.mark.usefixtures("setup_legacy_platform")
async def test_demo_speech_no_metadata(hass_client: ClientSessionGenerator) -> None:
"""Test retrieve settings from demo provider."""
client = await hass_client()
@ -42,6 +56,7 @@ async def test_demo_speech_no_metadata(hass_client: ClientSessionGenerator) -> N
assert response.status == HTTPStatus.BAD_REQUEST
@pytest.mark.usefixtures("setup_legacy_platform")
async def test_demo_speech_wrong_metadata(hass_client: ClientSessionGenerator) -> None:
"""Test retrieve settings from demo provider."""
client = await hass_client()
@ -59,6 +74,7 @@ async def test_demo_speech_wrong_metadata(hass_client: ClientSessionGenerator) -
assert response.status == HTTPStatus.UNSUPPORTED_MEDIA_TYPE
@pytest.mark.usefixtures("setup_legacy_platform")
async def test_demo_speech(hass_client: ClientSessionGenerator) -> None:
"""Test retrieve settings from demo provider."""
client = await hass_client()
@ -77,3 +93,26 @@ async def test_demo_speech(hass_client: ClientSessionGenerator) -> None:
assert response.status == HTTPStatus.OK
assert response_data == {"text": "Turn the Kitchen Lights on", "result": "success"}
@pytest.mark.usefixtures("setup_config_entry")
async def test_config_entry_demo_speech(
hass_client: ClientSessionGenerator, hass: HomeAssistant
) -> None:
"""Test retrieve settings from demo provider from config entry."""
client = await hass_client()
response = await client.post(
"/api/stt/demo",
headers={
"X-Speech-Content": (
"format=wav; codec=pcm; sample_rate=16000; bit_rate=16; channel=2;"
" language=de"
)
},
data=b"Test",
)
response_data = await response.json()
assert response.status == HTTPStatus.OK
assert response_data == {"text": "Turn the Kitchen Lights on", "result": "success"}

View File

@ -6,7 +6,9 @@ from pathlib import Path
from typing import Any
from homeassistant.components.stt import Provider
from homeassistant.config_entries import ConfigEntry
from homeassistant.core import HomeAssistant
from homeassistant.helpers.entity_platform import AddEntitiesCallback
from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType
from tests.common import MockPlatform, mock_platform
@ -54,3 +56,19 @@ def mock_stt_platform(
mock_platform(hass, f"{integration}.stt", loaded_platform)
return loaded_platform
def mock_stt_entity_platform(
hass: HomeAssistant,
tmp_path: Path,
integration: str,
async_setup_entry: Callable[
[HomeAssistant, ConfigEntry, AddEntitiesCallback],
Coroutine[Any, Any, None],
]
| None = None,
) -> MockPlatform:
"""Specialize the mock platform for stt."""
loaded_platform = MockPlatform(async_setup_entry=async_setup_entry)
mock_platform(hass, f"{integration}.stt", loaded_platform)
return loaded_platform

View File

@ -1,5 +1,5 @@
"""Test STT component setup."""
from collections.abc import AsyncIterable
from collections.abc import AsyncIterable, Generator
from http import HTTPStatus
from pathlib import Path
from unittest.mock import AsyncMock
@ -7,6 +7,7 @@ from unittest.mock import AsyncMock
import pytest
from homeassistant.components.stt import (
DOMAIN,
AudioBitRates,
AudioChannels,
AudioCodecs,
@ -16,17 +17,30 @@ from homeassistant.components.stt import (
SpeechMetadata,
SpeechResult,
SpeechResultState,
SpeechToTextEntity,
async_get_provider,
)
from homeassistant.core import HomeAssistant
from homeassistant.config_entries import ConfigEntry, ConfigEntryState, ConfigFlow
from homeassistant.core import HomeAssistant, State
from homeassistant.helpers.entity_platform import AddEntitiesCallback
from homeassistant.setup import async_setup_component
from .common import mock_stt_platform
from .common import mock_stt_entity_platform, mock_stt_platform
from tests.common import (
MockConfigEntry,
MockModule,
mock_config_flow,
mock_integration,
mock_platform,
mock_restore_cache,
)
from tests.typing import ClientSessionGenerator
TEST_DOMAIN = "test"
class MockProvider(Provider):
class BaseProvider:
"""Mock provider."""
fail_process_audio = False
@ -73,7 +87,15 @@ class MockProvider(Provider):
if self.fail_process_audio:
return SpeechResult(None, SpeechResultState.ERROR)
return SpeechResult("test", SpeechResultState.SUCCESS)
return SpeechResult("test_result", SpeechResultState.SUCCESS)
class MockProvider(BaseProvider, Provider):
"""Mock provider."""
class MockProviderEntity(BaseProvider, SpeechToTextEntity):
"""Mock provider entity."""
@pytest.fixture
@ -82,26 +104,113 @@ def mock_provider() -> MockProvider:
return MockProvider()
@pytest.fixture
def mock_provider_entity() -> MockProviderEntity:
"""Test provider entity fixture."""
return MockProviderEntity()
class STTFlow(ConfigFlow):
"""Test flow."""
@pytest.fixture(autouse=True)
def config_flow_fixture(hass: HomeAssistant) -> Generator[None, None, None]:
"""Mock config flow."""
mock_platform(hass, f"{TEST_DOMAIN}.config_flow")
with mock_config_flow(TEST_DOMAIN, STTFlow):
yield
@pytest.fixture(name="setup")
async def setup_fixture(
hass: HomeAssistant,
tmp_path: Path,
request: pytest.FixtureRequest,
) -> None:
"""Set up the test environment."""
if request.param == "mock_setup":
await mock_setup(hass, tmp_path, MockProvider())
elif request.param == "mock_config_entry_setup":
await mock_config_entry_setup(hass, tmp_path, MockProviderEntity())
else:
raise RuntimeError("Invalid setup fixture")
async def mock_setup(
hass: HomeAssistant, tmp_path: Path, mock_provider: MockProvider
hass: HomeAssistant,
tmp_path: Path,
mock_provider: MockProvider,
) -> None:
"""Set up a test provider."""
mock_stt_platform(
hass,
tmp_path,
"test",
TEST_DOMAIN,
async_get_engine=AsyncMock(return_value=mock_provider),
)
assert await async_setup_component(hass, "stt", {"stt": {"platform": "test"}})
assert await async_setup_component(hass, "stt", {"stt": {"platform": TEST_DOMAIN}})
await hass.async_block_till_done()
async def mock_config_entry_setup(
hass: HomeAssistant, tmp_path: Path, mock_provider_entity: MockProviderEntity
) -> MockConfigEntry:
"""Set up a test provider via config entry."""
async def async_setup_entry_init(
hass: HomeAssistant, config_entry: ConfigEntry
) -> bool:
"""Set up test config entry."""
await hass.config_entries.async_forward_entry_setup(config_entry, DOMAIN)
return True
async def async_unload_entry_init(
hass: HomeAssistant, config_entry: ConfigEntry
) -> bool:
"""Unload up test config entry."""
await hass.config_entries.async_forward_entry_unload(config_entry, DOMAIN)
return True
mock_integration(
hass,
MockModule(
TEST_DOMAIN,
async_setup_entry=async_setup_entry_init,
async_unload_entry=async_unload_entry_init,
),
)
async def async_setup_entry_platform(
hass: HomeAssistant,
config_entry: ConfigEntry,
async_add_entities: AddEntitiesCallback,
) -> None:
"""Set up test stt platform via config entry."""
async_add_entities([mock_provider_entity])
mock_stt_entity_platform(hass, tmp_path, TEST_DOMAIN, async_setup_entry_platform)
config_entry = MockConfigEntry(domain=TEST_DOMAIN)
config_entry.add_to_hass(hass)
assert await hass.config_entries.async_setup(config_entry.entry_id)
await hass.async_block_till_done()
return config_entry
@pytest.mark.parametrize(
"setup", ["mock_setup", "mock_config_entry_setup"], indirect=True
)
async def test_get_provider_info(
hass: HomeAssistant, hass_client: ClientSessionGenerator
hass: HomeAssistant,
hass_client: ClientSessionGenerator,
setup: str,
) -> None:
"""Test engine that doesn't exist."""
client = await hass_client()
response = await client.get("/api/stt/test")
response = await client.get(f"/api/stt/{TEST_DOMAIN}")
assert response.status == HTTPStatus.OK
assert await response.json() == {
"languages": ["en"],
@ -113,22 +222,44 @@ async def test_get_provider_info(
}
async def test_get_non_existing_provider_info(
hass: HomeAssistant, hass_client: ClientSessionGenerator
@pytest.mark.parametrize(
"setup", ["mock_setup", "mock_config_entry_setup"], indirect=True
)
async def test_non_existing_provider(
hass: HomeAssistant,
hass_client: ClientSessionGenerator,
setup: str,
) -> None:
"""Test streaming to engine that doesn't exist."""
client = await hass_client()
response = await client.get("/api/stt/not_exist")
assert response.status == HTTPStatus.NOT_FOUND
response = await client.post(
"/api/stt/not_exist",
headers={
"X-Speech-Content": (
"format=wav; codec=pcm; sample_rate=16000; bit_rate=16; channel=1;"
" language=en"
)
},
)
assert response.status == HTTPStatus.NOT_FOUND
@pytest.mark.parametrize(
"setup", ["mock_setup", "mock_config_entry_setup"], indirect=True
)
async def test_stream_audio(
hass: HomeAssistant, hass_client: ClientSessionGenerator
hass: HomeAssistant,
hass_client: ClientSessionGenerator,
setup: str,
) -> None:
"""Test streaming audio and getting response."""
client = await hass_client()
response = await client.post(
"/api/stt/test",
f"/api/stt/{TEST_DOMAIN}",
headers={
"X-Speech-Content": (
"format=wav; codec=pcm; sample_rate=16000; bit_rate=16; channel=1;"
@ -137,20 +268,39 @@ async def test_stream_audio(
},
)
assert response.status == HTTPStatus.OK
assert await response.json() == {"text": "test", "result": "success"}
assert await response.json() == {"text": "test_result", "result": "success"}
@pytest.mark.parametrize(
"setup", ["mock_setup", "mock_config_entry_setup"], indirect=True
)
@pytest.mark.parametrize(
("header", "status", "error"),
(
(None, 400, "Missing X-Speech-Content header"),
(
(
"format=wav; codec=pcm; sample_rate=16000; bit_rate=16; channel=100;"
" language=en; unknown=1"
),
400,
"Invalid field: unknown",
),
(
(
"format=wav; codec=pcm; sample_rate=16000; bit_rate=16; channel=100;"
" language=en"
),
400,
"100 is not a valid AudioChannels",
"Wrong format of X-Speech-Content: 100 is not a valid AudioChannels",
),
(
(
"format=wav; codec=pcm; sample_rate=16000; bit_rate=16; channel=bad channel;"
" language=en"
),
400,
"Wrong format of X-Speech-Content: invalid literal for int() with base 10: 'bad channel'",
),
(
"format=wav; codec=pcm; sample_rate=16000",
@ -165,6 +315,7 @@ async def test_metadata_errors(
header: str | None,
status: int,
error: str,
setup: str,
) -> None:
"""Test metadata errors."""
client = await hass_client()
@ -172,11 +323,55 @@ async def test_metadata_errors(
if header:
headers["X-Speech-Content"] = header
response = await client.post("/api/stt/test", headers=headers)
response = await client.post(f"/api/stt/{TEST_DOMAIN}", headers=headers)
assert response.status == status
assert await response.text() == error
async def test_get_provider(hass: HomeAssistant, mock_provider: MockProvider) -> None:
async def test_get_provider(
hass: HomeAssistant,
tmp_path: Path,
mock_provider: MockProvider,
) -> None:
"""Test we can get STT providers."""
assert mock_provider == async_get_provider(hass, "test")
await mock_setup(hass, tmp_path, mock_provider)
assert mock_provider == async_get_provider(hass, TEST_DOMAIN)
async def test_config_entry_unload(
hass: HomeAssistant, tmp_path: Path, mock_provider_entity: MockProviderEntity
) -> None:
"""Test we can unload config entry."""
config_entry = await mock_config_entry_setup(hass, tmp_path, mock_provider_entity)
assert config_entry.state == ConfigEntryState.LOADED
await hass.config_entries.async_unload(config_entry.entry_id)
assert config_entry.state == ConfigEntryState.NOT_LOADED
def test_entity_name_raises_before_addition(
hass: HomeAssistant,
tmp_path: Path,
mock_provider_entity: MockProviderEntity,
) -> None:
"""Test entity name raises before addition to Home Assistant."""
with pytest.raises(RuntimeError):
mock_provider_entity.name # pylint: disable=pointless-statement
async def test_restore_state(
hass: HomeAssistant,
tmp_path: Path,
mock_provider_entity: MockProviderEntity,
) -> None:
"""Test we restore state in the integration."""
entity_id = f"{DOMAIN}.{TEST_DOMAIN}"
timestamp = "2023-01-01T23:59:59+00:00"
mock_restore_cache(hass, (State(entity_id, timestamp),))
config_entry = await mock_config_entry_setup(hass, tmp_path, mock_provider_entity)
await hass.async_block_till_done()
assert config_entry.state == ConfigEntryState.LOADED
state = hass.states.get(entity_id)
assert state
assert state.state == timestamp