Add conversation agent to Wyoming (#124373)

* Add conversation agent to Wyoming

* Remove error

* Remove conversation platform from satellite list

* Clean up

* Update homeassistant/components/wyoming/conversation.py

Co-authored-by: Paulus Schoutsen <balloob@gmail.com>

* Remove unnecessary attribute

---------

Co-authored-by: Paulus Schoutsen <balloob@gmail.com>
This commit is contained in:
Michael Hansen 2024-10-16 09:07:56 -05:00 committed by GitHub
parent bcac851677
commit 11ac8f8006
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 553 additions and 1 deletions

View File

@ -0,0 +1,194 @@
"""Support for Wyoming intent recognition services."""
import logging
from wyoming.asr import Transcript
from wyoming.client import AsyncTcpClient
from wyoming.handle import Handled, NotHandled
from wyoming.info import HandleProgram, IntentProgram
from wyoming.intent import Intent, NotRecognized
from homeassistant.components import conversation
from homeassistant.config_entries import ConfigEntry
from homeassistant.core import HomeAssistant
from homeassistant.helpers import intent
from homeassistant.helpers.entity_platform import AddEntitiesCallback
from homeassistant.util import ulid
from .const import DOMAIN
from .data import WyomingService
from .error import WyomingError
from .models import DomainDataItem
_LOGGER = logging.getLogger(__name__)
async def async_setup_entry(
hass: HomeAssistant,
config_entry: ConfigEntry,
async_add_entities: AddEntitiesCallback,
) -> None:
"""Set up Wyoming conversation."""
item: DomainDataItem = hass.data[DOMAIN][config_entry.entry_id]
async_add_entities(
[
WyomingConversationEntity(config_entry, item.service),
]
)
class WyomingConversationEntity(
conversation.ConversationEntity, conversation.AbstractConversationAgent
):
"""Wyoming conversation agent."""
_attr_has_entity_name = True
def __init__(
self,
config_entry: ConfigEntry,
service: WyomingService,
) -> None:
"""Set up provider."""
super().__init__()
self.service = service
self._intent_service: IntentProgram | None = None
self._handle_service: HandleProgram | None = None
for maybe_intent in self.service.info.intent:
if maybe_intent.installed:
self._intent_service = maybe_intent
break
for maybe_handle in self.service.info.handle:
if maybe_handle.installed:
self._handle_service = maybe_handle
break
model_languages: set[str] = set()
if self._intent_service is not None:
for intent_model in self._intent_service.models:
if intent_model.installed:
model_languages.update(intent_model.languages)
self._attr_name = self._intent_service.name
self._attr_supported_features = (
conversation.ConversationEntityFeature.CONTROL
)
elif self._handle_service is not None:
for handle_model in self._handle_service.models:
if handle_model.installed:
model_languages.update(handle_model.languages)
self._attr_name = self._handle_service.name
self._supported_languages = list(model_languages)
self._attr_unique_id = f"{config_entry.entry_id}-conversation"
@property
def supported_languages(self) -> list[str]:
"""Return a list of supported languages."""
return self._supported_languages
async def async_process(
self, user_input: conversation.ConversationInput
) -> conversation.ConversationResult:
"""Process a sentence."""
conversation_id = user_input.conversation_id or ulid.ulid_now()
intent_response = intent.IntentResponse(language=user_input.language)
try:
async with AsyncTcpClient(self.service.host, self.service.port) as client:
await client.write_event(
Transcript(
user_input.text, context={"conversation_id": conversation_id}
).event()
)
while True:
event = await client.read_event()
if event is None:
_LOGGER.debug("Connection lost")
intent_response.async_set_error(
intent.IntentResponseErrorCode.UNKNOWN,
"Connection to service was lost",
)
return conversation.ConversationResult(
response=intent_response,
conversation_id=user_input.conversation_id,
)
if Intent.is_type(event.type):
# Success
recognized_intent = Intent.from_event(event)
_LOGGER.debug("Recognized intent: %s", recognized_intent)
intent_type = recognized_intent.name
intent_slots = {
e.name: {"value": e.value}
for e in recognized_intent.entities
}
intent_response = await intent.async_handle(
self.hass,
DOMAIN,
intent_type,
intent_slots,
text_input=user_input.text,
language=user_input.language,
)
if (not intent_response.speech) and recognized_intent.text:
intent_response.async_set_speech(recognized_intent.text)
break
if NotRecognized.is_type(event.type):
not_recognized = NotRecognized.from_event(event)
intent_response.async_set_error(
intent.IntentResponseErrorCode.NO_INTENT_MATCH,
not_recognized.text,
)
break
if Handled.is_type(event.type):
# Success
handled = Handled.from_event(event)
intent_response.async_set_speech(handled.text)
break
if NotHandled.is_type(event.type):
not_handled = NotHandled.from_event(event)
intent_response.async_set_error(
intent.IntentResponseErrorCode.FAILED_TO_HANDLE,
not_handled.text,
)
break
except (OSError, WyomingError) as err:
_LOGGER.exception("Unexpected error while communicating with service")
intent_response.async_set_error(
intent.IntentResponseErrorCode.UNKNOWN,
f"Error communicating with service: {err}",
)
return conversation.ConversationResult(
response=intent_response,
conversation_id=user_input.conversation_id,
)
except intent.IntentError as err:
_LOGGER.exception("Unexpected error while handling intent")
intent_response.async_set_error(
intent.IntentResponseErrorCode.FAILED_TO_HANDLE,
f"Error handling intent: {err}",
)
return conversation.ConversationResult(
response=intent_response,
conversation_id=user_input.conversation_id,
)
# Success
return conversation.ConversationResult(
response=intent_response, conversation_id=conversation_id
)

View File

@ -37,6 +37,10 @@ class WyomingService:
self.platforms.append(Platform.TTS)
if any(wake.installed for wake in info.wake):
self.platforms.append(Platform.WAKE_WORD)
if any(intent.installed for intent in info.intent) or any(
handle.installed for handle in info.handle
):
self.platforms.append(Platform.CONVERSATION)
def has_services(self) -> bool:
"""Return True if services are installed that Home Assistant can use."""
@ -44,6 +48,8 @@ class WyomingService:
any(asr for asr in self.info.asr if asr.installed)
or any(tts for tts in self.info.tts if tts.installed)
or any(wake for wake in self.info.wake if wake.installed)
or any(intent for intent in self.info.intent if intent.installed)
or any(handle for handle in self.info.handle if handle.installed)
or ((self.info.satellite is not None) and self.info.satellite.installed)
)
@ -70,6 +76,16 @@ class WyomingService:
if wake_installed:
return wake_installed[0].name
# intent recognition (text -> intent)
intent_installed = [intent for intent in self.info.intent if intent.installed]
if intent_installed:
return intent_installed[0].name
# intent handling (text -> text)
handle_installed = [handle for handle in self.info.handle if handle.installed]
if handle_installed:
return handle_installed[0].name
return None
@classmethod

View File

@ -8,7 +8,11 @@ from wyoming.info import (
AsrModel,
AsrProgram,
Attribution,
HandleModel,
HandleProgram,
Info,
IntentModel,
IntentProgram,
Satellite,
TtsProgram,
TtsVoice,
@ -87,6 +91,48 @@ WAKE_WORD_INFO = Info(
)
]
)
INTENT_INFO = Info(
intent=[
IntentProgram(
name="Test Intent",
description="Test Intent",
installed=True,
attribution=TEST_ATTR,
models=[
IntentModel(
name="Test Model",
description="Test Model",
installed=True,
attribution=TEST_ATTR,
languages=["en-US"],
version=None,
)
],
version=None,
)
]
)
HANDLE_INFO = Info(
handle=[
HandleProgram(
name="Test Handle",
description="Test Handle",
installed=True,
attribution=TEST_ATTR,
models=[
HandleModel(
name="Test Model",
description="Test Model",
installed=True,
attribution=TEST_ATTR,
languages=["en-US"],
version=None,
)
],
version=None,
)
]
)
SATELLITE_INFO = Info(
satellite=Satellite(
name="Test Satellite",

View File

@ -13,7 +13,14 @@ from homeassistant.config_entries import ConfigEntry
from homeassistant.core import HomeAssistant
from homeassistant.setup import async_setup_component
from . import SATELLITE_INFO, STT_INFO, TTS_INFO, WAKE_WORD_INFO
from . import (
HANDLE_INFO,
INTENT_INFO,
SATELLITE_INFO,
STT_INFO,
TTS_INFO,
WAKE_WORD_INFO,
)
from tests.common import MockConfigEntry
@ -83,6 +90,36 @@ def wake_word_config_entry(hass: HomeAssistant) -> ConfigEntry:
return entry
@pytest.fixture
def intent_config_entry(hass: HomeAssistant) -> ConfigEntry:
"""Create a config entry."""
entry = MockConfigEntry(
domain="wyoming",
data={
"host": "1.2.3.4",
"port": 1234,
},
title="Test Intent",
)
entry.add_to_hass(hass)
return entry
@pytest.fixture
def handle_config_entry(hass: HomeAssistant) -> ConfigEntry:
"""Create a config entry."""
entry = MockConfigEntry(
domain="wyoming",
data={
"host": "1.2.3.4",
"port": 1234,
},
title="Test Handle",
)
entry.add_to_hass(hass)
return entry
@pytest.fixture
async def init_wyoming_stt(hass: HomeAssistant, stt_config_entry: ConfigEntry):
"""Initialize Wyoming STT."""
@ -115,6 +152,34 @@ async def init_wyoming_wake_word(
await hass.config_entries.async_setup(wake_word_config_entry.entry_id)
@pytest.fixture
async def init_wyoming_intent(
hass: HomeAssistant, intent_config_entry: ConfigEntry
) -> ConfigEntry:
"""Initialize Wyoming intent recognizer."""
with patch(
"homeassistant.components.wyoming.data.load_wyoming_info",
return_value=INTENT_INFO,
):
await hass.config_entries.async_setup(intent_config_entry.entry_id)
return intent_config_entry
@pytest.fixture
async def init_wyoming_handle(
hass: HomeAssistant, handle_config_entry: ConfigEntry
) -> ConfigEntry:
"""Initialize Wyoming intent handler."""
with patch(
"homeassistant.components.wyoming.data.load_wyoming_info",
return_value=HANDLE_INFO,
):
await hass.config_entries.async_setup(handle_config_entry.entry_id)
return handle_config_entry
@pytest.fixture
def metadata(hass: HomeAssistant) -> stt.SpeechMetadata:
"""Get default STT metadata."""

View File

@ -0,0 +1,7 @@
# serializer version: 1
# name: test_connection_lost
'Connection to service was lost'
# ---
# name: test_oserror
'Error communicating with service: Boom!'
# ---

View File

@ -0,0 +1,224 @@
"""Test conversation."""
from __future__ import annotations
from unittest.mock import patch
from syrupy import SnapshotAssertion
from wyoming.asr import Transcript
from wyoming.handle import Handled, NotHandled
from wyoming.intent import Entity, Intent, NotRecognized
from homeassistant.components import conversation
from homeassistant.config_entries import ConfigEntry
from homeassistant.core import Context, HomeAssistant
from homeassistant.helpers import intent
from . import MockAsyncTcpClient
async def test_intent(hass: HomeAssistant, init_wyoming_intent: ConfigEntry) -> None:
"""Test when an intent is recognized."""
agent_id = "conversation.test_intent"
conversation_id = "conversation-1234"
test_intent = Intent(
name="TestIntent",
entities=[Entity(name="entity", value="value")],
text="success",
)
class TestIntentHandler(intent.IntentHandler):
"""Test Intent Handler."""
intent_type = "TestIntent"
async def async_handle(self, intent_obj: intent.Intent):
"""Handle the intent."""
assert intent_obj.slots.get("entity", {}).get("value") == "value"
return intent_obj.create_response()
intent.async_register(hass, TestIntentHandler())
with patch(
"homeassistant.components.wyoming.conversation.AsyncTcpClient",
MockAsyncTcpClient([test_intent.event()]),
):
result = await conversation.async_converse(
hass=hass,
text="test text",
conversation_id=conversation_id,
context=Context(),
language=hass.config.language,
agent_id=agent_id,
)
assert result.response.response_type == intent.IntentResponseType.ACTION_DONE
assert result.response.speech, "No speech"
assert result.response.speech.get("plain", {}).get("speech") == "success"
assert result.conversation_id == conversation_id
async def test_intent_handle_error(
hass: HomeAssistant, init_wyoming_intent: ConfigEntry
) -> None:
"""Test error during handling when an intent is recognized."""
agent_id = "conversation.test_intent"
test_intent = Intent(name="TestIntent", entities=[], text="success")
class TestIntentHandler(intent.IntentHandler):
"""Test Intent Handler."""
intent_type = "TestIntent"
async def async_handle(self, intent_obj: intent.Intent):
"""Handle the intent."""
raise intent.IntentError
intent.async_register(hass, TestIntentHandler())
with patch(
"homeassistant.components.wyoming.conversation.AsyncTcpClient",
MockAsyncTcpClient([test_intent.event()]),
):
result = await conversation.async_converse(
hass=hass,
text="test text",
conversation_id=None,
context=Context(),
language=hass.config.language,
agent_id=agent_id,
)
assert result.response.response_type == intent.IntentResponseType.ERROR
assert result.response.error_code == intent.IntentResponseErrorCode.FAILED_TO_HANDLE
async def test_not_recognized(
hass: HomeAssistant, init_wyoming_intent: ConfigEntry
) -> None:
"""Test when an intent is not recognized."""
agent_id = "conversation.test_intent"
with patch(
"homeassistant.components.wyoming.conversation.AsyncTcpClient",
MockAsyncTcpClient([NotRecognized(text="failure").event()]),
):
result = await conversation.async_converse(
hass=hass,
text="test text",
conversation_id=None,
context=Context(),
language=hass.config.language,
agent_id=agent_id,
)
assert result.response.response_type == intent.IntentResponseType.ERROR
assert result.response.error_code == intent.IntentResponseErrorCode.NO_INTENT_MATCH
assert result.response.speech, "No speech"
assert result.response.speech.get("plain", {}).get("speech") == "failure"
async def test_handle(hass: HomeAssistant, init_wyoming_handle: ConfigEntry) -> None:
"""Test when an intent is handled."""
agent_id = "conversation.test_handle"
conversation_id = "conversation-1234"
with patch(
"homeassistant.components.wyoming.conversation.AsyncTcpClient",
MockAsyncTcpClient([Handled(text="success").event()]),
):
result = await conversation.async_converse(
hass=hass,
text="test text",
conversation_id=conversation_id,
context=Context(),
language=hass.config.language,
agent_id=agent_id,
)
assert result.response.response_type == intent.IntentResponseType.ACTION_DONE
assert result.response.speech, "No speech"
assert result.response.speech.get("plain", {}).get("speech") == "success"
assert result.conversation_id == conversation_id
async def test_not_handled(
hass: HomeAssistant, init_wyoming_handle: ConfigEntry
) -> None:
"""Test when an intent is not handled."""
agent_id = "conversation.test_handle"
with patch(
"homeassistant.components.wyoming.conversation.AsyncTcpClient",
MockAsyncTcpClient([NotHandled(text="failure").event()]),
):
result = await conversation.async_converse(
hass=hass,
text="test text",
conversation_id=None,
context=Context(),
language=hass.config.language,
agent_id=agent_id,
)
assert result.response.response_type == intent.IntentResponseType.ERROR
assert result.response.error_code == intent.IntentResponseErrorCode.FAILED_TO_HANDLE
assert result.response.speech, "No speech"
assert result.response.speech.get("plain", {}).get("speech") == "failure"
async def test_connection_lost(
hass: HomeAssistant, init_wyoming_handle: ConfigEntry, snapshot: SnapshotAssertion
) -> None:
"""Test connection to client is lost."""
agent_id = "conversation.test_handle"
with patch(
"homeassistant.components.wyoming.conversation.AsyncTcpClient",
MockAsyncTcpClient([None]),
):
result = await conversation.async_converse(
hass=hass,
text="test text",
conversation_id=None,
context=Context(),
language=hass.config.language,
agent_id=agent_id,
)
assert result.response.response_type == intent.IntentResponseType.ERROR
assert result.response.error_code == intent.IntentResponseErrorCode.UNKNOWN
assert result.response.speech, "No speech"
assert result.response.speech.get("plain", {}).get("speech") == snapshot()
async def test_oserror(
hass: HomeAssistant, init_wyoming_handle: ConfigEntry, snapshot: SnapshotAssertion
) -> None:
"""Test connection error."""
agent_id = "conversation.test_handle"
mock_client = MockAsyncTcpClient([Transcript("success").event()])
with (
patch(
"homeassistant.components.wyoming.conversation.AsyncTcpClient", mock_client
),
patch.object(mock_client, "read_event", side_effect=OSError("Boom!")),
):
result = await conversation.async_converse(
hass=hass,
text="test text",
conversation_id=None,
context=Context(),
language=hass.config.language,
agent_id=agent_id,
)
assert result.response.response_type == intent.IntentResponseType.ERROR
assert result.response.error_code == intent.IntentResponseErrorCode.UNKNOWN
assert result.response.speech, "No speech"
assert result.response.speech.get("plain", {}).get("speech") == snapshot()