mirror of
https://github.com/home-assistant/core.git
synced 2025-04-24 01:08:12 +00:00
Add conversation agent to Wyoming (#124373)
* Add conversation agent to Wyoming * Remove error * Remove conversation platform from satellite list * Clean up * Update homeassistant/components/wyoming/conversation.py Co-authored-by: Paulus Schoutsen <balloob@gmail.com> * Remove unnecessary attribute --------- Co-authored-by: Paulus Schoutsen <balloob@gmail.com>
This commit is contained in:
parent
bcac851677
commit
11ac8f8006
194
homeassistant/components/wyoming/conversation.py
Normal file
194
homeassistant/components/wyoming/conversation.py
Normal file
@ -0,0 +1,194 @@
|
||||
"""Support for Wyoming intent recognition services."""
|
||||
|
||||
import logging
|
||||
|
||||
from wyoming.asr import Transcript
|
||||
from wyoming.client import AsyncTcpClient
|
||||
from wyoming.handle import Handled, NotHandled
|
||||
from wyoming.info import HandleProgram, IntentProgram
|
||||
from wyoming.intent import Intent, NotRecognized
|
||||
|
||||
from homeassistant.components import conversation
|
||||
from homeassistant.config_entries import ConfigEntry
|
||||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.helpers import intent
|
||||
from homeassistant.helpers.entity_platform import AddEntitiesCallback
|
||||
from homeassistant.util import ulid
|
||||
|
||||
from .const import DOMAIN
|
||||
from .data import WyomingService
|
||||
from .error import WyomingError
|
||||
from .models import DomainDataItem
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def async_setup_entry(
|
||||
hass: HomeAssistant,
|
||||
config_entry: ConfigEntry,
|
||||
async_add_entities: AddEntitiesCallback,
|
||||
) -> None:
|
||||
"""Set up Wyoming conversation."""
|
||||
item: DomainDataItem = hass.data[DOMAIN][config_entry.entry_id]
|
||||
async_add_entities(
|
||||
[
|
||||
WyomingConversationEntity(config_entry, item.service),
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
class WyomingConversationEntity(
|
||||
conversation.ConversationEntity, conversation.AbstractConversationAgent
|
||||
):
|
||||
"""Wyoming conversation agent."""
|
||||
|
||||
_attr_has_entity_name = True
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config_entry: ConfigEntry,
|
||||
service: WyomingService,
|
||||
) -> None:
|
||||
"""Set up provider."""
|
||||
super().__init__()
|
||||
|
||||
self.service = service
|
||||
|
||||
self._intent_service: IntentProgram | None = None
|
||||
self._handle_service: HandleProgram | None = None
|
||||
|
||||
for maybe_intent in self.service.info.intent:
|
||||
if maybe_intent.installed:
|
||||
self._intent_service = maybe_intent
|
||||
break
|
||||
|
||||
for maybe_handle in self.service.info.handle:
|
||||
if maybe_handle.installed:
|
||||
self._handle_service = maybe_handle
|
||||
break
|
||||
|
||||
model_languages: set[str] = set()
|
||||
|
||||
if self._intent_service is not None:
|
||||
for intent_model in self._intent_service.models:
|
||||
if intent_model.installed:
|
||||
model_languages.update(intent_model.languages)
|
||||
|
||||
self._attr_name = self._intent_service.name
|
||||
self._attr_supported_features = (
|
||||
conversation.ConversationEntityFeature.CONTROL
|
||||
)
|
||||
elif self._handle_service is not None:
|
||||
for handle_model in self._handle_service.models:
|
||||
if handle_model.installed:
|
||||
model_languages.update(handle_model.languages)
|
||||
|
||||
self._attr_name = self._handle_service.name
|
||||
|
||||
self._supported_languages = list(model_languages)
|
||||
self._attr_unique_id = f"{config_entry.entry_id}-conversation"
|
||||
|
||||
@property
|
||||
def supported_languages(self) -> list[str]:
|
||||
"""Return a list of supported languages."""
|
||||
return self._supported_languages
|
||||
|
||||
async def async_process(
|
||||
self, user_input: conversation.ConversationInput
|
||||
) -> conversation.ConversationResult:
|
||||
"""Process a sentence."""
|
||||
conversation_id = user_input.conversation_id or ulid.ulid_now()
|
||||
intent_response = intent.IntentResponse(language=user_input.language)
|
||||
|
||||
try:
|
||||
async with AsyncTcpClient(self.service.host, self.service.port) as client:
|
||||
await client.write_event(
|
||||
Transcript(
|
||||
user_input.text, context={"conversation_id": conversation_id}
|
||||
).event()
|
||||
)
|
||||
|
||||
while True:
|
||||
event = await client.read_event()
|
||||
if event is None:
|
||||
_LOGGER.debug("Connection lost")
|
||||
intent_response.async_set_error(
|
||||
intent.IntentResponseErrorCode.UNKNOWN,
|
||||
"Connection to service was lost",
|
||||
)
|
||||
return conversation.ConversationResult(
|
||||
response=intent_response,
|
||||
conversation_id=user_input.conversation_id,
|
||||
)
|
||||
|
||||
if Intent.is_type(event.type):
|
||||
# Success
|
||||
recognized_intent = Intent.from_event(event)
|
||||
_LOGGER.debug("Recognized intent: %s", recognized_intent)
|
||||
|
||||
intent_type = recognized_intent.name
|
||||
intent_slots = {
|
||||
e.name: {"value": e.value}
|
||||
for e in recognized_intent.entities
|
||||
}
|
||||
intent_response = await intent.async_handle(
|
||||
self.hass,
|
||||
DOMAIN,
|
||||
intent_type,
|
||||
intent_slots,
|
||||
text_input=user_input.text,
|
||||
language=user_input.language,
|
||||
)
|
||||
|
||||
if (not intent_response.speech) and recognized_intent.text:
|
||||
intent_response.async_set_speech(recognized_intent.text)
|
||||
|
||||
break
|
||||
|
||||
if NotRecognized.is_type(event.type):
|
||||
not_recognized = NotRecognized.from_event(event)
|
||||
intent_response.async_set_error(
|
||||
intent.IntentResponseErrorCode.NO_INTENT_MATCH,
|
||||
not_recognized.text,
|
||||
)
|
||||
break
|
||||
|
||||
if Handled.is_type(event.type):
|
||||
# Success
|
||||
handled = Handled.from_event(event)
|
||||
intent_response.async_set_speech(handled.text)
|
||||
break
|
||||
|
||||
if NotHandled.is_type(event.type):
|
||||
not_handled = NotHandled.from_event(event)
|
||||
intent_response.async_set_error(
|
||||
intent.IntentResponseErrorCode.FAILED_TO_HANDLE,
|
||||
not_handled.text,
|
||||
)
|
||||
break
|
||||
|
||||
except (OSError, WyomingError) as err:
|
||||
_LOGGER.exception("Unexpected error while communicating with service")
|
||||
intent_response.async_set_error(
|
||||
intent.IntentResponseErrorCode.UNKNOWN,
|
||||
f"Error communicating with service: {err}",
|
||||
)
|
||||
return conversation.ConversationResult(
|
||||
response=intent_response,
|
||||
conversation_id=user_input.conversation_id,
|
||||
)
|
||||
except intent.IntentError as err:
|
||||
_LOGGER.exception("Unexpected error while handling intent")
|
||||
intent_response.async_set_error(
|
||||
intent.IntentResponseErrorCode.FAILED_TO_HANDLE,
|
||||
f"Error handling intent: {err}",
|
||||
)
|
||||
return conversation.ConversationResult(
|
||||
response=intent_response,
|
||||
conversation_id=user_input.conversation_id,
|
||||
)
|
||||
|
||||
# Success
|
||||
return conversation.ConversationResult(
|
||||
response=intent_response, conversation_id=conversation_id
|
||||
)
|
@ -37,6 +37,10 @@ class WyomingService:
|
||||
self.platforms.append(Platform.TTS)
|
||||
if any(wake.installed for wake in info.wake):
|
||||
self.platforms.append(Platform.WAKE_WORD)
|
||||
if any(intent.installed for intent in info.intent) or any(
|
||||
handle.installed for handle in info.handle
|
||||
):
|
||||
self.platforms.append(Platform.CONVERSATION)
|
||||
|
||||
def has_services(self) -> bool:
|
||||
"""Return True if services are installed that Home Assistant can use."""
|
||||
@ -44,6 +48,8 @@ class WyomingService:
|
||||
any(asr for asr in self.info.asr if asr.installed)
|
||||
or any(tts for tts in self.info.tts if tts.installed)
|
||||
or any(wake for wake in self.info.wake if wake.installed)
|
||||
or any(intent for intent in self.info.intent if intent.installed)
|
||||
or any(handle for handle in self.info.handle if handle.installed)
|
||||
or ((self.info.satellite is not None) and self.info.satellite.installed)
|
||||
)
|
||||
|
||||
@ -70,6 +76,16 @@ class WyomingService:
|
||||
if wake_installed:
|
||||
return wake_installed[0].name
|
||||
|
||||
# intent recognition (text -> intent)
|
||||
intent_installed = [intent for intent in self.info.intent if intent.installed]
|
||||
if intent_installed:
|
||||
return intent_installed[0].name
|
||||
|
||||
# intent handling (text -> text)
|
||||
handle_installed = [handle for handle in self.info.handle if handle.installed]
|
||||
if handle_installed:
|
||||
return handle_installed[0].name
|
||||
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
|
@ -8,7 +8,11 @@ from wyoming.info import (
|
||||
AsrModel,
|
||||
AsrProgram,
|
||||
Attribution,
|
||||
HandleModel,
|
||||
HandleProgram,
|
||||
Info,
|
||||
IntentModel,
|
||||
IntentProgram,
|
||||
Satellite,
|
||||
TtsProgram,
|
||||
TtsVoice,
|
||||
@ -87,6 +91,48 @@ WAKE_WORD_INFO = Info(
|
||||
)
|
||||
]
|
||||
)
|
||||
INTENT_INFO = Info(
|
||||
intent=[
|
||||
IntentProgram(
|
||||
name="Test Intent",
|
||||
description="Test Intent",
|
||||
installed=True,
|
||||
attribution=TEST_ATTR,
|
||||
models=[
|
||||
IntentModel(
|
||||
name="Test Model",
|
||||
description="Test Model",
|
||||
installed=True,
|
||||
attribution=TEST_ATTR,
|
||||
languages=["en-US"],
|
||||
version=None,
|
||||
)
|
||||
],
|
||||
version=None,
|
||||
)
|
||||
]
|
||||
)
|
||||
HANDLE_INFO = Info(
|
||||
handle=[
|
||||
HandleProgram(
|
||||
name="Test Handle",
|
||||
description="Test Handle",
|
||||
installed=True,
|
||||
attribution=TEST_ATTR,
|
||||
models=[
|
||||
HandleModel(
|
||||
name="Test Model",
|
||||
description="Test Model",
|
||||
installed=True,
|
||||
attribution=TEST_ATTR,
|
||||
languages=["en-US"],
|
||||
version=None,
|
||||
)
|
||||
],
|
||||
version=None,
|
||||
)
|
||||
]
|
||||
)
|
||||
SATELLITE_INFO = Info(
|
||||
satellite=Satellite(
|
||||
name="Test Satellite",
|
||||
|
@ -13,7 +13,14 @@ from homeassistant.config_entries import ConfigEntry
|
||||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.setup import async_setup_component
|
||||
|
||||
from . import SATELLITE_INFO, STT_INFO, TTS_INFO, WAKE_WORD_INFO
|
||||
from . import (
|
||||
HANDLE_INFO,
|
||||
INTENT_INFO,
|
||||
SATELLITE_INFO,
|
||||
STT_INFO,
|
||||
TTS_INFO,
|
||||
WAKE_WORD_INFO,
|
||||
)
|
||||
|
||||
from tests.common import MockConfigEntry
|
||||
|
||||
@ -83,6 +90,36 @@ def wake_word_config_entry(hass: HomeAssistant) -> ConfigEntry:
|
||||
return entry
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def intent_config_entry(hass: HomeAssistant) -> ConfigEntry:
|
||||
"""Create a config entry."""
|
||||
entry = MockConfigEntry(
|
||||
domain="wyoming",
|
||||
data={
|
||||
"host": "1.2.3.4",
|
||||
"port": 1234,
|
||||
},
|
||||
title="Test Intent",
|
||||
)
|
||||
entry.add_to_hass(hass)
|
||||
return entry
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def handle_config_entry(hass: HomeAssistant) -> ConfigEntry:
|
||||
"""Create a config entry."""
|
||||
entry = MockConfigEntry(
|
||||
domain="wyoming",
|
||||
data={
|
||||
"host": "1.2.3.4",
|
||||
"port": 1234,
|
||||
},
|
||||
title="Test Handle",
|
||||
)
|
||||
entry.add_to_hass(hass)
|
||||
return entry
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def init_wyoming_stt(hass: HomeAssistant, stt_config_entry: ConfigEntry):
|
||||
"""Initialize Wyoming STT."""
|
||||
@ -115,6 +152,34 @@ async def init_wyoming_wake_word(
|
||||
await hass.config_entries.async_setup(wake_word_config_entry.entry_id)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def init_wyoming_intent(
|
||||
hass: HomeAssistant, intent_config_entry: ConfigEntry
|
||||
) -> ConfigEntry:
|
||||
"""Initialize Wyoming intent recognizer."""
|
||||
with patch(
|
||||
"homeassistant.components.wyoming.data.load_wyoming_info",
|
||||
return_value=INTENT_INFO,
|
||||
):
|
||||
await hass.config_entries.async_setup(intent_config_entry.entry_id)
|
||||
|
||||
return intent_config_entry
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def init_wyoming_handle(
|
||||
hass: HomeAssistant, handle_config_entry: ConfigEntry
|
||||
) -> ConfigEntry:
|
||||
"""Initialize Wyoming intent handler."""
|
||||
with patch(
|
||||
"homeassistant.components.wyoming.data.load_wyoming_info",
|
||||
return_value=HANDLE_INFO,
|
||||
):
|
||||
await hass.config_entries.async_setup(handle_config_entry.entry_id)
|
||||
|
||||
return handle_config_entry
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def metadata(hass: HomeAssistant) -> stt.SpeechMetadata:
|
||||
"""Get default STT metadata."""
|
||||
|
@ -0,0 +1,7 @@
|
||||
# serializer version: 1
|
||||
# name: test_connection_lost
|
||||
'Connection to service was lost'
|
||||
# ---
|
||||
# name: test_oserror
|
||||
'Error communicating with service: Boom!'
|
||||
# ---
|
224
tests/components/wyoming/test_conversation.py
Normal file
224
tests/components/wyoming/test_conversation.py
Normal file
@ -0,0 +1,224 @@
|
||||
"""Test conversation."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from unittest.mock import patch
|
||||
|
||||
from syrupy import SnapshotAssertion
|
||||
from wyoming.asr import Transcript
|
||||
from wyoming.handle import Handled, NotHandled
|
||||
from wyoming.intent import Entity, Intent, NotRecognized
|
||||
|
||||
from homeassistant.components import conversation
|
||||
from homeassistant.config_entries import ConfigEntry
|
||||
from homeassistant.core import Context, HomeAssistant
|
||||
from homeassistant.helpers import intent
|
||||
|
||||
from . import MockAsyncTcpClient
|
||||
|
||||
|
||||
async def test_intent(hass: HomeAssistant, init_wyoming_intent: ConfigEntry) -> None:
|
||||
"""Test when an intent is recognized."""
|
||||
agent_id = "conversation.test_intent"
|
||||
|
||||
conversation_id = "conversation-1234"
|
||||
test_intent = Intent(
|
||||
name="TestIntent",
|
||||
entities=[Entity(name="entity", value="value")],
|
||||
text="success",
|
||||
)
|
||||
|
||||
class TestIntentHandler(intent.IntentHandler):
|
||||
"""Test Intent Handler."""
|
||||
|
||||
intent_type = "TestIntent"
|
||||
|
||||
async def async_handle(self, intent_obj: intent.Intent):
|
||||
"""Handle the intent."""
|
||||
assert intent_obj.slots.get("entity", {}).get("value") == "value"
|
||||
return intent_obj.create_response()
|
||||
|
||||
intent.async_register(hass, TestIntentHandler())
|
||||
|
||||
with patch(
|
||||
"homeassistant.components.wyoming.conversation.AsyncTcpClient",
|
||||
MockAsyncTcpClient([test_intent.event()]),
|
||||
):
|
||||
result = await conversation.async_converse(
|
||||
hass=hass,
|
||||
text="test text",
|
||||
conversation_id=conversation_id,
|
||||
context=Context(),
|
||||
language=hass.config.language,
|
||||
agent_id=agent_id,
|
||||
)
|
||||
|
||||
assert result.response.response_type == intent.IntentResponseType.ACTION_DONE
|
||||
assert result.response.speech, "No speech"
|
||||
assert result.response.speech.get("plain", {}).get("speech") == "success"
|
||||
assert result.conversation_id == conversation_id
|
||||
|
||||
|
||||
async def test_intent_handle_error(
|
||||
hass: HomeAssistant, init_wyoming_intent: ConfigEntry
|
||||
) -> None:
|
||||
"""Test error during handling when an intent is recognized."""
|
||||
agent_id = "conversation.test_intent"
|
||||
|
||||
test_intent = Intent(name="TestIntent", entities=[], text="success")
|
||||
|
||||
class TestIntentHandler(intent.IntentHandler):
|
||||
"""Test Intent Handler."""
|
||||
|
||||
intent_type = "TestIntent"
|
||||
|
||||
async def async_handle(self, intent_obj: intent.Intent):
|
||||
"""Handle the intent."""
|
||||
raise intent.IntentError
|
||||
|
||||
intent.async_register(hass, TestIntentHandler())
|
||||
|
||||
with patch(
|
||||
"homeassistant.components.wyoming.conversation.AsyncTcpClient",
|
||||
MockAsyncTcpClient([test_intent.event()]),
|
||||
):
|
||||
result = await conversation.async_converse(
|
||||
hass=hass,
|
||||
text="test text",
|
||||
conversation_id=None,
|
||||
context=Context(),
|
||||
language=hass.config.language,
|
||||
agent_id=agent_id,
|
||||
)
|
||||
|
||||
assert result.response.response_type == intent.IntentResponseType.ERROR
|
||||
assert result.response.error_code == intent.IntentResponseErrorCode.FAILED_TO_HANDLE
|
||||
|
||||
|
||||
async def test_not_recognized(
|
||||
hass: HomeAssistant, init_wyoming_intent: ConfigEntry
|
||||
) -> None:
|
||||
"""Test when an intent is not recognized."""
|
||||
agent_id = "conversation.test_intent"
|
||||
|
||||
with patch(
|
||||
"homeassistant.components.wyoming.conversation.AsyncTcpClient",
|
||||
MockAsyncTcpClient([NotRecognized(text="failure").event()]),
|
||||
):
|
||||
result = await conversation.async_converse(
|
||||
hass=hass,
|
||||
text="test text",
|
||||
conversation_id=None,
|
||||
context=Context(),
|
||||
language=hass.config.language,
|
||||
agent_id=agent_id,
|
||||
)
|
||||
|
||||
assert result.response.response_type == intent.IntentResponseType.ERROR
|
||||
assert result.response.error_code == intent.IntentResponseErrorCode.NO_INTENT_MATCH
|
||||
assert result.response.speech, "No speech"
|
||||
assert result.response.speech.get("plain", {}).get("speech") == "failure"
|
||||
|
||||
|
||||
async def test_handle(hass: HomeAssistant, init_wyoming_handle: ConfigEntry) -> None:
|
||||
"""Test when an intent is handled."""
|
||||
agent_id = "conversation.test_handle"
|
||||
|
||||
conversation_id = "conversation-1234"
|
||||
|
||||
with patch(
|
||||
"homeassistant.components.wyoming.conversation.AsyncTcpClient",
|
||||
MockAsyncTcpClient([Handled(text="success").event()]),
|
||||
):
|
||||
result = await conversation.async_converse(
|
||||
hass=hass,
|
||||
text="test text",
|
||||
conversation_id=conversation_id,
|
||||
context=Context(),
|
||||
language=hass.config.language,
|
||||
agent_id=agent_id,
|
||||
)
|
||||
|
||||
assert result.response.response_type == intent.IntentResponseType.ACTION_DONE
|
||||
assert result.response.speech, "No speech"
|
||||
assert result.response.speech.get("plain", {}).get("speech") == "success"
|
||||
assert result.conversation_id == conversation_id
|
||||
|
||||
|
||||
async def test_not_handled(
|
||||
hass: HomeAssistant, init_wyoming_handle: ConfigEntry
|
||||
) -> None:
|
||||
"""Test when an intent is not handled."""
|
||||
agent_id = "conversation.test_handle"
|
||||
|
||||
with patch(
|
||||
"homeassistant.components.wyoming.conversation.AsyncTcpClient",
|
||||
MockAsyncTcpClient([NotHandled(text="failure").event()]),
|
||||
):
|
||||
result = await conversation.async_converse(
|
||||
hass=hass,
|
||||
text="test text",
|
||||
conversation_id=None,
|
||||
context=Context(),
|
||||
language=hass.config.language,
|
||||
agent_id=agent_id,
|
||||
)
|
||||
|
||||
assert result.response.response_type == intent.IntentResponseType.ERROR
|
||||
assert result.response.error_code == intent.IntentResponseErrorCode.FAILED_TO_HANDLE
|
||||
assert result.response.speech, "No speech"
|
||||
assert result.response.speech.get("plain", {}).get("speech") == "failure"
|
||||
|
||||
|
||||
async def test_connection_lost(
|
||||
hass: HomeAssistant, init_wyoming_handle: ConfigEntry, snapshot: SnapshotAssertion
|
||||
) -> None:
|
||||
"""Test connection to client is lost."""
|
||||
agent_id = "conversation.test_handle"
|
||||
|
||||
with patch(
|
||||
"homeassistant.components.wyoming.conversation.AsyncTcpClient",
|
||||
MockAsyncTcpClient([None]),
|
||||
):
|
||||
result = await conversation.async_converse(
|
||||
hass=hass,
|
||||
text="test text",
|
||||
conversation_id=None,
|
||||
context=Context(),
|
||||
language=hass.config.language,
|
||||
agent_id=agent_id,
|
||||
)
|
||||
|
||||
assert result.response.response_type == intent.IntentResponseType.ERROR
|
||||
assert result.response.error_code == intent.IntentResponseErrorCode.UNKNOWN
|
||||
assert result.response.speech, "No speech"
|
||||
assert result.response.speech.get("plain", {}).get("speech") == snapshot()
|
||||
|
||||
|
||||
async def test_oserror(
|
||||
hass: HomeAssistant, init_wyoming_handle: ConfigEntry, snapshot: SnapshotAssertion
|
||||
) -> None:
|
||||
"""Test connection error."""
|
||||
agent_id = "conversation.test_handle"
|
||||
|
||||
mock_client = MockAsyncTcpClient([Transcript("success").event()])
|
||||
|
||||
with (
|
||||
patch(
|
||||
"homeassistant.components.wyoming.conversation.AsyncTcpClient", mock_client
|
||||
),
|
||||
patch.object(mock_client, "read_event", side_effect=OSError("Boom!")),
|
||||
):
|
||||
result = await conversation.async_converse(
|
||||
hass=hass,
|
||||
text="test text",
|
||||
conversation_id=None,
|
||||
context=Context(),
|
||||
language=hass.config.language,
|
||||
agent_id=agent_id,
|
||||
)
|
||||
|
||||
assert result.response.response_type == intent.IntentResponseType.ERROR
|
||||
assert result.response.error_code == intent.IntentResponseErrorCode.UNKNOWN
|
||||
assert result.response.speech, "No speech"
|
||||
assert result.response.speech.get("plain", {}).get("speech") == snapshot()
|
Loading…
x
Reference in New Issue
Block a user