diff --git a/homeassistant/components/wyoming/conversation.py b/homeassistant/components/wyoming/conversation.py new file mode 100644 index 00000000000..9a17559c1f8 --- /dev/null +++ b/homeassistant/components/wyoming/conversation.py @@ -0,0 +1,194 @@ +"""Support for Wyoming intent recognition services.""" + +import logging + +from wyoming.asr import Transcript +from wyoming.client import AsyncTcpClient +from wyoming.handle import Handled, NotHandled +from wyoming.info import HandleProgram, IntentProgram +from wyoming.intent import Intent, NotRecognized + +from homeassistant.components import conversation +from homeassistant.config_entries import ConfigEntry +from homeassistant.core import HomeAssistant +from homeassistant.helpers import intent +from homeassistant.helpers.entity_platform import AddEntitiesCallback +from homeassistant.util import ulid + +from .const import DOMAIN +from .data import WyomingService +from .error import WyomingError +from .models import DomainDataItem + +_LOGGER = logging.getLogger(__name__) + + +async def async_setup_entry( + hass: HomeAssistant, + config_entry: ConfigEntry, + async_add_entities: AddEntitiesCallback, +) -> None: + """Set up Wyoming conversation.""" + item: DomainDataItem = hass.data[DOMAIN][config_entry.entry_id] + async_add_entities( + [ + WyomingConversationEntity(config_entry, item.service), + ] + ) + + +class WyomingConversationEntity( + conversation.ConversationEntity, conversation.AbstractConversationAgent +): + """Wyoming conversation agent.""" + + _attr_has_entity_name = True + + def __init__( + self, + config_entry: ConfigEntry, + service: WyomingService, + ) -> None: + """Set up provider.""" + super().__init__() + + self.service = service + + self._intent_service: IntentProgram | None = None + self._handle_service: HandleProgram | None = None + + for maybe_intent in self.service.info.intent: + if maybe_intent.installed: + self._intent_service = maybe_intent + break + + for maybe_handle in self.service.info.handle: + if maybe_handle.installed: + self._handle_service = maybe_handle + break + + model_languages: set[str] = set() + + if self._intent_service is not None: + for intent_model in self._intent_service.models: + if intent_model.installed: + model_languages.update(intent_model.languages) + + self._attr_name = self._intent_service.name + self._attr_supported_features = ( + conversation.ConversationEntityFeature.CONTROL + ) + elif self._handle_service is not None: + for handle_model in self._handle_service.models: + if handle_model.installed: + model_languages.update(handle_model.languages) + + self._attr_name = self._handle_service.name + + self._supported_languages = list(model_languages) + self._attr_unique_id = f"{config_entry.entry_id}-conversation" + + @property + def supported_languages(self) -> list[str]: + """Return a list of supported languages.""" + return self._supported_languages + + async def async_process( + self, user_input: conversation.ConversationInput + ) -> conversation.ConversationResult: + """Process a sentence.""" + conversation_id = user_input.conversation_id or ulid.ulid_now() + intent_response = intent.IntentResponse(language=user_input.language) + + try: + async with AsyncTcpClient(self.service.host, self.service.port) as client: + await client.write_event( + Transcript( + user_input.text, context={"conversation_id": conversation_id} + ).event() + ) + + while True: + event = await client.read_event() + if event is None: + _LOGGER.debug("Connection lost") + intent_response.async_set_error( + intent.IntentResponseErrorCode.UNKNOWN, + "Connection to service was lost", + ) + return conversation.ConversationResult( + response=intent_response, + conversation_id=user_input.conversation_id, + ) + + if Intent.is_type(event.type): + # Success + recognized_intent = Intent.from_event(event) + _LOGGER.debug("Recognized intent: %s", recognized_intent) + + intent_type = recognized_intent.name + intent_slots = { + e.name: {"value": e.value} + for e in recognized_intent.entities + } + intent_response = await intent.async_handle( + self.hass, + DOMAIN, + intent_type, + intent_slots, + text_input=user_input.text, + language=user_input.language, + ) + + if (not intent_response.speech) and recognized_intent.text: + intent_response.async_set_speech(recognized_intent.text) + + break + + if NotRecognized.is_type(event.type): + not_recognized = NotRecognized.from_event(event) + intent_response.async_set_error( + intent.IntentResponseErrorCode.NO_INTENT_MATCH, + not_recognized.text, + ) + break + + if Handled.is_type(event.type): + # Success + handled = Handled.from_event(event) + intent_response.async_set_speech(handled.text) + break + + if NotHandled.is_type(event.type): + not_handled = NotHandled.from_event(event) + intent_response.async_set_error( + intent.IntentResponseErrorCode.FAILED_TO_HANDLE, + not_handled.text, + ) + break + + except (OSError, WyomingError) as err: + _LOGGER.exception("Unexpected error while communicating with service") + intent_response.async_set_error( + intent.IntentResponseErrorCode.UNKNOWN, + f"Error communicating with service: {err}", + ) + return conversation.ConversationResult( + response=intent_response, + conversation_id=user_input.conversation_id, + ) + except intent.IntentError as err: + _LOGGER.exception("Unexpected error while handling intent") + intent_response.async_set_error( + intent.IntentResponseErrorCode.FAILED_TO_HANDLE, + f"Error handling intent: {err}", + ) + return conversation.ConversationResult( + response=intent_response, + conversation_id=user_input.conversation_id, + ) + + # Success + return conversation.ConversationResult( + response=intent_response, conversation_id=conversation_id + ) diff --git a/homeassistant/components/wyoming/data.py b/homeassistant/components/wyoming/data.py index 1ee0f24f805..a16062ab058 100644 --- a/homeassistant/components/wyoming/data.py +++ b/homeassistant/components/wyoming/data.py @@ -37,6 +37,10 @@ class WyomingService: self.platforms.append(Platform.TTS) if any(wake.installed for wake in info.wake): self.platforms.append(Platform.WAKE_WORD) + if any(intent.installed for intent in info.intent) or any( + handle.installed for handle in info.handle + ): + self.platforms.append(Platform.CONVERSATION) def has_services(self) -> bool: """Return True if services are installed that Home Assistant can use.""" @@ -44,6 +48,8 @@ class WyomingService: any(asr for asr in self.info.asr if asr.installed) or any(tts for tts in self.info.tts if tts.installed) or any(wake for wake in self.info.wake if wake.installed) + or any(intent for intent in self.info.intent if intent.installed) + or any(handle for handle in self.info.handle if handle.installed) or ((self.info.satellite is not None) and self.info.satellite.installed) ) @@ -70,6 +76,16 @@ class WyomingService: if wake_installed: return wake_installed[0].name + # intent recognition (text -> intent) + intent_installed = [intent for intent in self.info.intent if intent.installed] + if intent_installed: + return intent_installed[0].name + + # intent handling (text -> text) + handle_installed = [handle for handle in self.info.handle if handle.installed] + if handle_installed: + return handle_installed[0].name + return None @classmethod diff --git a/tests/components/wyoming/__init__.py b/tests/components/wyoming/__init__.py index 30703159994..4540cdaabfd 100644 --- a/tests/components/wyoming/__init__.py +++ b/tests/components/wyoming/__init__.py @@ -8,7 +8,11 @@ from wyoming.info import ( AsrModel, AsrProgram, Attribution, + HandleModel, + HandleProgram, Info, + IntentModel, + IntentProgram, Satellite, TtsProgram, TtsVoice, @@ -87,6 +91,48 @@ WAKE_WORD_INFO = Info( ) ] ) +INTENT_INFO = Info( + intent=[ + IntentProgram( + name="Test Intent", + description="Test Intent", + installed=True, + attribution=TEST_ATTR, + models=[ + IntentModel( + name="Test Model", + description="Test Model", + installed=True, + attribution=TEST_ATTR, + languages=["en-US"], + version=None, + ) + ], + version=None, + ) + ] +) +HANDLE_INFO = Info( + handle=[ + HandleProgram( + name="Test Handle", + description="Test Handle", + installed=True, + attribution=TEST_ATTR, + models=[ + HandleModel( + name="Test Model", + description="Test Model", + installed=True, + attribution=TEST_ATTR, + languages=["en-US"], + version=None, + ) + ], + version=None, + ) + ] +) SATELLITE_INFO = Info( satellite=Satellite( name="Test Satellite", diff --git a/tests/components/wyoming/conftest.py b/tests/components/wyoming/conftest.py index d504f98a5b0..018fff33821 100644 --- a/tests/components/wyoming/conftest.py +++ b/tests/components/wyoming/conftest.py @@ -13,7 +13,14 @@ from homeassistant.config_entries import ConfigEntry from homeassistant.core import HomeAssistant from homeassistant.setup import async_setup_component -from . import SATELLITE_INFO, STT_INFO, TTS_INFO, WAKE_WORD_INFO +from . import ( + HANDLE_INFO, + INTENT_INFO, + SATELLITE_INFO, + STT_INFO, + TTS_INFO, + WAKE_WORD_INFO, +) from tests.common import MockConfigEntry @@ -83,6 +90,36 @@ def wake_word_config_entry(hass: HomeAssistant) -> ConfigEntry: return entry +@pytest.fixture +def intent_config_entry(hass: HomeAssistant) -> ConfigEntry: + """Create a config entry.""" + entry = MockConfigEntry( + domain="wyoming", + data={ + "host": "1.2.3.4", + "port": 1234, + }, + title="Test Intent", + ) + entry.add_to_hass(hass) + return entry + + +@pytest.fixture +def handle_config_entry(hass: HomeAssistant) -> ConfigEntry: + """Create a config entry.""" + entry = MockConfigEntry( + domain="wyoming", + data={ + "host": "1.2.3.4", + "port": 1234, + }, + title="Test Handle", + ) + entry.add_to_hass(hass) + return entry + + @pytest.fixture async def init_wyoming_stt(hass: HomeAssistant, stt_config_entry: ConfigEntry): """Initialize Wyoming STT.""" @@ -115,6 +152,34 @@ async def init_wyoming_wake_word( await hass.config_entries.async_setup(wake_word_config_entry.entry_id) +@pytest.fixture +async def init_wyoming_intent( + hass: HomeAssistant, intent_config_entry: ConfigEntry +) -> ConfigEntry: + """Initialize Wyoming intent recognizer.""" + with patch( + "homeassistant.components.wyoming.data.load_wyoming_info", + return_value=INTENT_INFO, + ): + await hass.config_entries.async_setup(intent_config_entry.entry_id) + + return intent_config_entry + + +@pytest.fixture +async def init_wyoming_handle( + hass: HomeAssistant, handle_config_entry: ConfigEntry +) -> ConfigEntry: + """Initialize Wyoming intent handler.""" + with patch( + "homeassistant.components.wyoming.data.load_wyoming_info", + return_value=HANDLE_INFO, + ): + await hass.config_entries.async_setup(handle_config_entry.entry_id) + + return handle_config_entry + + @pytest.fixture def metadata(hass: HomeAssistant) -> stt.SpeechMetadata: """Get default STT metadata.""" diff --git a/tests/components/wyoming/snapshots/test_conversation.ambr b/tests/components/wyoming/snapshots/test_conversation.ambr new file mode 100644 index 00000000000..24763cac441 --- /dev/null +++ b/tests/components/wyoming/snapshots/test_conversation.ambr @@ -0,0 +1,7 @@ +# serializer version: 1 +# name: test_connection_lost + 'Connection to service was lost' +# --- +# name: test_oserror + 'Error communicating with service: Boom!' +# --- diff --git a/tests/components/wyoming/test_conversation.py b/tests/components/wyoming/test_conversation.py new file mode 100644 index 00000000000..02b04503962 --- /dev/null +++ b/tests/components/wyoming/test_conversation.py @@ -0,0 +1,224 @@ +"""Test conversation.""" + +from __future__ import annotations + +from unittest.mock import patch + +from syrupy import SnapshotAssertion +from wyoming.asr import Transcript +from wyoming.handle import Handled, NotHandled +from wyoming.intent import Entity, Intent, NotRecognized + +from homeassistant.components import conversation +from homeassistant.config_entries import ConfigEntry +from homeassistant.core import Context, HomeAssistant +from homeassistant.helpers import intent + +from . import MockAsyncTcpClient + + +async def test_intent(hass: HomeAssistant, init_wyoming_intent: ConfigEntry) -> None: + """Test when an intent is recognized.""" + agent_id = "conversation.test_intent" + + conversation_id = "conversation-1234" + test_intent = Intent( + name="TestIntent", + entities=[Entity(name="entity", value="value")], + text="success", + ) + + class TestIntentHandler(intent.IntentHandler): + """Test Intent Handler.""" + + intent_type = "TestIntent" + + async def async_handle(self, intent_obj: intent.Intent): + """Handle the intent.""" + assert intent_obj.slots.get("entity", {}).get("value") == "value" + return intent_obj.create_response() + + intent.async_register(hass, TestIntentHandler()) + + with patch( + "homeassistant.components.wyoming.conversation.AsyncTcpClient", + MockAsyncTcpClient([test_intent.event()]), + ): + result = await conversation.async_converse( + hass=hass, + text="test text", + conversation_id=conversation_id, + context=Context(), + language=hass.config.language, + agent_id=agent_id, + ) + + assert result.response.response_type == intent.IntentResponseType.ACTION_DONE + assert result.response.speech, "No speech" + assert result.response.speech.get("plain", {}).get("speech") == "success" + assert result.conversation_id == conversation_id + + +async def test_intent_handle_error( + hass: HomeAssistant, init_wyoming_intent: ConfigEntry +) -> None: + """Test error during handling when an intent is recognized.""" + agent_id = "conversation.test_intent" + + test_intent = Intent(name="TestIntent", entities=[], text="success") + + class TestIntentHandler(intent.IntentHandler): + """Test Intent Handler.""" + + intent_type = "TestIntent" + + async def async_handle(self, intent_obj: intent.Intent): + """Handle the intent.""" + raise intent.IntentError + + intent.async_register(hass, TestIntentHandler()) + + with patch( + "homeassistant.components.wyoming.conversation.AsyncTcpClient", + MockAsyncTcpClient([test_intent.event()]), + ): + result = await conversation.async_converse( + hass=hass, + text="test text", + conversation_id=None, + context=Context(), + language=hass.config.language, + agent_id=agent_id, + ) + + assert result.response.response_type == intent.IntentResponseType.ERROR + assert result.response.error_code == intent.IntentResponseErrorCode.FAILED_TO_HANDLE + + +async def test_not_recognized( + hass: HomeAssistant, init_wyoming_intent: ConfigEntry +) -> None: + """Test when an intent is not recognized.""" + agent_id = "conversation.test_intent" + + with patch( + "homeassistant.components.wyoming.conversation.AsyncTcpClient", + MockAsyncTcpClient([NotRecognized(text="failure").event()]), + ): + result = await conversation.async_converse( + hass=hass, + text="test text", + conversation_id=None, + context=Context(), + language=hass.config.language, + agent_id=agent_id, + ) + + assert result.response.response_type == intent.IntentResponseType.ERROR + assert result.response.error_code == intent.IntentResponseErrorCode.NO_INTENT_MATCH + assert result.response.speech, "No speech" + assert result.response.speech.get("plain", {}).get("speech") == "failure" + + +async def test_handle(hass: HomeAssistant, init_wyoming_handle: ConfigEntry) -> None: + """Test when an intent is handled.""" + agent_id = "conversation.test_handle" + + conversation_id = "conversation-1234" + + with patch( + "homeassistant.components.wyoming.conversation.AsyncTcpClient", + MockAsyncTcpClient([Handled(text="success").event()]), + ): + result = await conversation.async_converse( + hass=hass, + text="test text", + conversation_id=conversation_id, + context=Context(), + language=hass.config.language, + agent_id=agent_id, + ) + + assert result.response.response_type == intent.IntentResponseType.ACTION_DONE + assert result.response.speech, "No speech" + assert result.response.speech.get("plain", {}).get("speech") == "success" + assert result.conversation_id == conversation_id + + +async def test_not_handled( + hass: HomeAssistant, init_wyoming_handle: ConfigEntry +) -> None: + """Test when an intent is not handled.""" + agent_id = "conversation.test_handle" + + with patch( + "homeassistant.components.wyoming.conversation.AsyncTcpClient", + MockAsyncTcpClient([NotHandled(text="failure").event()]), + ): + result = await conversation.async_converse( + hass=hass, + text="test text", + conversation_id=None, + context=Context(), + language=hass.config.language, + agent_id=agent_id, + ) + + assert result.response.response_type == intent.IntentResponseType.ERROR + assert result.response.error_code == intent.IntentResponseErrorCode.FAILED_TO_HANDLE + assert result.response.speech, "No speech" + assert result.response.speech.get("plain", {}).get("speech") == "failure" + + +async def test_connection_lost( + hass: HomeAssistant, init_wyoming_handle: ConfigEntry, snapshot: SnapshotAssertion +) -> None: + """Test connection to client is lost.""" + agent_id = "conversation.test_handle" + + with patch( + "homeassistant.components.wyoming.conversation.AsyncTcpClient", + MockAsyncTcpClient([None]), + ): + result = await conversation.async_converse( + hass=hass, + text="test text", + conversation_id=None, + context=Context(), + language=hass.config.language, + agent_id=agent_id, + ) + + assert result.response.response_type == intent.IntentResponseType.ERROR + assert result.response.error_code == intent.IntentResponseErrorCode.UNKNOWN + assert result.response.speech, "No speech" + assert result.response.speech.get("plain", {}).get("speech") == snapshot() + + +async def test_oserror( + hass: HomeAssistant, init_wyoming_handle: ConfigEntry, snapshot: SnapshotAssertion +) -> None: + """Test connection error.""" + agent_id = "conversation.test_handle" + + mock_client = MockAsyncTcpClient([Transcript("success").event()]) + + with ( + patch( + "homeassistant.components.wyoming.conversation.AsyncTcpClient", mock_client + ), + patch.object(mock_client, "read_event", side_effect=OSError("Boom!")), + ): + result = await conversation.async_converse( + hass=hass, + text="test text", + conversation_id=None, + context=Context(), + language=hass.config.language, + agent_id=agent_id, + ) + + assert result.response.response_type == intent.IntentResponseType.ERROR + assert result.response.error_code == intent.IntentResponseErrorCode.UNKNOWN + assert result.response.speech, "No speech" + assert result.response.speech.get("plain", {}).get("speech") == snapshot()