From a2031491334b18b1ab689206ed45d25f3c9c7af3 Mon Sep 17 00:00:00 2001 From: Paulus Schoutsen Date: Sun, 23 Apr 2023 23:06:34 -0400 Subject: [PATCH] Allow entity names for STT entities (#91932) * Allow entity names for STT entities * Fix tests --- .../components/assist_pipeline/pipeline.py | 5 ++- homeassistant/components/demo/stt.py | 2 + homeassistant/components/stt/__init__.py | 22 +--------- tests/components/assist_pipeline/conftest.py | 2 + .../assist_pipeline/snapshots/test_init.ambr | 2 +- tests/components/demo/test_stt.py | 10 ++++- tests/components/stt/test_init.py | 41 +++++++++---------- tests/components/wyoming/test_stt.py | 10 ++--- 8 files changed, 44 insertions(+), 50 deletions(-) diff --git a/homeassistant/components/assist_pipeline/pipeline.py b/homeassistant/components/assist_pipeline/pipeline.py index 57c1ed92f85..6e69fe0dc2e 100644 --- a/homeassistant/components/assist_pipeline/pipeline.py +++ b/homeassistant/components/assist_pipeline/pipeline.py @@ -304,7 +304,10 @@ class PipelineRun: if self.stt_provider is None: raise RuntimeError("Speech to text was not prepared") - engine = self.stt_provider.name + if isinstance(self.stt_provider, stt.Provider): + engine = self.stt_provider.name + else: + engine = self.stt_provider.entity_id self.process_event( PipelineEvent( diff --git a/homeassistant/components/demo/stt.py b/homeassistant/components/demo/stt.py index e1f59fa76ee..07a844c048c 100644 --- a/homeassistant/components/demo/stt.py +++ b/homeassistant/components/demo/stt.py @@ -44,6 +44,8 @@ async def async_setup_entry( class DemoProviderEntity(SpeechToTextEntity): """Demo speech API provider entity.""" + _attr_name = "Demo STT" + @property def supported_languages(self) -> list[str]: """Return a list of supported languages.""" diff --git a/homeassistant/components/stt/__init__.py b/homeassistant/components/stt/__init__.py index ef217aed2a2..0a7711d5619 100644 --- a/homeassistant/components/stt/__init__.py +++ b/homeassistant/components/stt/__init__.py @@ -128,16 +128,6 @@ class SpeechToTextEntity(RestoreEntity): _attr_should_poll = False __last_processed: str | None = None - @property - @final - def name(self) -> str: - """Return the name of the provider entity.""" - # Only one entity is allowed per platform for now. - if self.platform is None: - raise RuntimeError("Entity is not added to hass yet.") - - return self.platform.platform_name - @property @final def state(self) -> str | None: @@ -249,11 +239,7 @@ class SpeechToTextView(HomeAssistantView): hass: HomeAssistant = request.app["hass"] provider_entity: SpeechToTextEntity | None = None if ( - not ( - provider_entity := async_get_speech_to_text_entity( - hass, f"{DOMAIN}.{provider}" - ) - ) + not (provider_entity := async_get_speech_to_text_entity(hass, provider)) and provider not in self.providers ): raise HTTPNotFound() @@ -292,11 +278,7 @@ class SpeechToTextView(HomeAssistantView): """Return provider specific audio information.""" hass: HomeAssistant = request.app["hass"] if ( - not ( - provider_entity := async_get_speech_to_text_entity( - hass, f"{DOMAIN}.{provider}" - ) - ) + not (provider_entity := async_get_speech_to_text_entity(hass, provider)) and provider not in self.providers ): raise HTTPNotFound() diff --git a/tests/components/assist_pipeline/conftest.py b/tests/components/assist_pipeline/conftest.py index d0a5f74281b..744254f9954 100644 --- a/tests/components/assist_pipeline/conftest.py +++ b/tests/components/assist_pipeline/conftest.py @@ -89,6 +89,8 @@ class MockSttProvider(BaseProvider, stt.Provider): class MockSttProviderEntity(BaseProvider, stt.SpeechToTextEntity): """Mock provider entity.""" + _attr_name = "Mock STT" + class MockTTSProvider(tts.Provider): """Mock TTS provider.""" diff --git a/tests/components/assist_pipeline/snapshots/test_init.ambr b/tests/components/assist_pipeline/snapshots/test_init.ambr index f7d2768f3ae..b5c636b4bd6 100644 --- a/tests/components/assist_pipeline/snapshots/test_init.ambr +++ b/tests/components/assist_pipeline/snapshots/test_init.ambr @@ -97,7 +97,7 @@ }), dict({ 'data': dict({ - 'engine': 'test', + 'engine': 'stt.mock_stt', 'metadata': dict({ 'bit_rate': , 'channel': , diff --git a/tests/components/demo/test_stt.py b/tests/components/demo/test_stt.py index 7a8582df29b..5d4242844ee 100644 --- a/tests/components/demo/test_stt.py +++ b/tests/components/demo/test_stt.py @@ -1,10 +1,12 @@ """The tests for the demo stt component.""" from http import HTTPStatus +from unittest.mock import patch import pytest from homeassistant.components import stt from homeassistant.components.demo import DOMAIN as DEMO_DOMAIN +from homeassistant.const import Platform from homeassistant.core import HomeAssistant from homeassistant.setup import async_setup_component @@ -24,7 +26,11 @@ async def setup_config_entry(hass: HomeAssistant) -> None: """Set up demo component from config entry.""" config_entry = MockConfigEntry(domain=DEMO_DOMAIN) config_entry.add_to_hass(hass) - assert await hass.config_entries.async_setup(config_entry.entry_id) + with patch( + "homeassistant.components.demo.COMPONENTS_WITH_CONFIG_ENTRY_DEMO_PLATFORM", + [Platform.STT], + ): + assert await hass.config_entries.async_setup(config_entry.entry_id) await hass.async_block_till_done() @@ -103,7 +109,7 @@ async def test_config_entry_demo_speech( client = await hass_client() response = await client.post( - "/api/stt/demo", + "/api/stt/stt.demo_stt", headers={ "X-Speech-Content": ( "format=wav; codec=pcm; sample_rate=16000; bit_rate=16; channel=2;" diff --git a/tests/components/stt/test_init.py b/tests/components/stt/test_init.py index 7275dcabe29..e7e0decde72 100644 --- a/tests/components/stt/test_init.py +++ b/tests/components/stt/test_init.py @@ -93,10 +93,15 @@ class BaseProvider: class MockProvider(BaseProvider, Provider): """Mock provider.""" + url_path = TEST_DOMAIN + class MockProviderEntity(BaseProvider, SpeechToTextEntity): """Mock provider entity.""" + url_path = "stt.test" + _attr_name = "test" + @pytest.fixture def mock_provider() -> MockProvider: @@ -128,15 +133,19 @@ async def setup_fixture( hass: HomeAssistant, tmp_path: Path, request: pytest.FixtureRequest, -) -> None: +) -> MockProvider | MockProviderEntity: """Set up the test environment.""" if request.param == "mock_setup": - await mock_setup(hass, tmp_path, MockProvider()) + provider = MockProvider() + await mock_setup(hass, tmp_path, provider) elif request.param == "mock_config_entry_setup": - await mock_config_entry_setup(hass, tmp_path, MockProviderEntity()) + provider = MockProviderEntity() + await mock_config_entry_setup(hass, tmp_path, provider) else: raise RuntimeError("Invalid setup fixture") + return provider + async def mock_setup( hass: HomeAssistant, @@ -206,11 +215,11 @@ async def mock_config_entry_setup( async def test_get_provider_info( hass: HomeAssistant, hass_client: ClientSessionGenerator, - setup: str, + setup: MockProvider | MockProviderEntity, ) -> None: """Test engine that doesn't exist.""" client = await hass_client() - response = await client.get(f"/api/stt/{TEST_DOMAIN}") + response = await client.get(f"/api/stt/{setup.url_path}") assert response.status == HTTPStatus.OK assert await response.json() == { "languages": ["de", "de-CH", "en-US"], @@ -228,7 +237,7 @@ async def test_get_provider_info( async def test_non_existing_provider( hass: HomeAssistant, hass_client: ClientSessionGenerator, - setup: str, + setup: MockProvider | MockProviderEntity, ) -> None: """Test streaming to engine that doesn't exist.""" client = await hass_client() @@ -255,14 +264,14 @@ async def test_non_existing_provider( async def test_stream_audio( hass: HomeAssistant, hass_client: ClientSessionGenerator, - setup: str, + setup: MockProvider | MockProviderEntity, ) -> None: """Test streaming audio and getting response.""" client = await hass_client() # Language en is matched with en-US response = await client.post( - f"/api/stt/{TEST_DOMAIN}", + f"/api/stt/{setup.url_path}", headers={ "X-Speech-Content": ( "format=wav; codec=pcm; sample_rate=16000; bit_rate=16; channel=1;" @@ -318,7 +327,7 @@ async def test_metadata_errors( header: str | None, status: int, error: str, - setup: str, + setup: MockProvider | MockProviderEntity, ) -> None: """Test metadata errors.""" client = await hass_client() @@ -326,7 +335,7 @@ async def test_metadata_errors( if header: headers["X-Speech-Content"] = header - response = await client.post(f"/api/stt/{TEST_DOMAIN}", headers=headers) + response = await client.post(f"/api/stt/{setup.url_path}", headers=headers) assert response.status == status assert await response.text() == error @@ -351,16 +360,6 @@ async def test_config_entry_unload( assert config_entry.state == ConfigEntryState.NOT_LOADED -def test_entity_name_raises_before_addition( - hass: HomeAssistant, - tmp_path: Path, - mock_provider_entity: MockProviderEntity, -) -> None: - """Test entity name raises before addition to Home Assistant.""" - with pytest.raises(RuntimeError): - mock_provider_entity.name # pylint: disable=pointless-statement - - async def test_restore_state( hass: HomeAssistant, tmp_path: Path, @@ -388,7 +387,7 @@ async def test_restore_state( async def test_ws_list_engines( hass: HomeAssistant, hass_ws_client: WebSocketGenerator, - setup: str, + setup: MockProvider | MockProviderEntity, engine_id: str, ) -> None: """Test listing speech to text engines.""" diff --git a/tests/components/wyoming/test_stt.py b/tests/components/wyoming/test_stt.py index 6c9e75ffa18..021419f3a5e 100644 --- a/tests/components/wyoming/test_stt.py +++ b/tests/components/wyoming/test_stt.py @@ -13,10 +13,10 @@ from . import MockAsyncTcpClient async def test_support(hass: HomeAssistant, init_wyoming_stt) -> None: """Test supported properties.""" - state = hass.states.get("stt.wyoming") + state = hass.states.get("stt.test_asr") assert state is not None - entity = stt.async_get_speech_to_text_entity(hass, "stt.wyoming") + entity = stt.async_get_speech_to_text_entity(hass, "stt.test_asr") assert entity is not None assert entity.supported_languages == ["en-US"] @@ -29,7 +29,7 @@ async def test_support(hass: HomeAssistant, init_wyoming_stt) -> None: async def test_streaming_audio(hass: HomeAssistant, init_wyoming_stt, snapshot) -> None: """Test streaming audio.""" - entity = stt.async_get_speech_to_text_entity(hass, "stt.wyoming") + entity = stt.async_get_speech_to_text_entity(hass, "stt.test_asr") assert entity is not None async def audio_stream(): @@ -51,7 +51,7 @@ async def test_streaming_audio_connection_lost( hass: HomeAssistant, init_wyoming_stt ) -> None: """Test streaming audio and losing connection.""" - entity = stt.async_get_speech_to_text_entity(hass, "stt.wyoming") + entity = stt.async_get_speech_to_text_entity(hass, "stt.test_asr") assert entity is not None async def audio_stream(): @@ -69,7 +69,7 @@ async def test_streaming_audio_connection_lost( async def test_streaming_audio_oserror(hass: HomeAssistant, init_wyoming_stt) -> None: """Test streaming audio and error raising.""" - entity = stt.async_get_speech_to_text_entity(hass, "stt.wyoming") + entity = stt.async_get_speech_to_text_entity(hass, "stt.test_asr") assert entity is not None async def audio_stream():