async def handle_ask_question(call: ServiceCall) -> dict[str, Any]:
    """Handle an ask_question service call.

    Looks up the target satellite entity, forwards the question arguments to
    its async_internal_ask_question method and returns the resulting answer
    dataclass converted to a plain dict.

    Raises HomeAssistantError if the entity id is not known to the component
    or if the satellite returns no answer.
    """
    satellite_entity_id: str = call.data[ATTR_ENTITY_ID]
    satellite_entity: AssistSatelliteEntity | None = component.get_entity(
        satellite_entity_id
    )
    if satellite_entity is None:
        raise HomeAssistantError(
            f"Invalid Assist satellite entity id: {satellite_entity_id}"
        )

    # NOTE(review): services.yaml advertises `default: true` for preannounce,
    # but an omitted field falls back to False here -- confirm which is intended.
    ask_question_args = {
        "question": call.data.get("question"),
        "question_media_id": call.data.get("question_media_id"),
        "preannounce": call.data.get("preannounce", False),
        "answers": call.data.get("answers"),
    }

    # Only forward preannounce_media_id when one was supplied, so the entity
    # method's own default value is used otherwise.
    if preannounce_media_id := call.data.get("preannounce_media_id"):
        ask_question_args["preannounce_media_id"] = preannounce_media_id

    answer = await satellite_entity.async_internal_ask_question(**ask_question_args)

    if answer is None:
        raise HomeAssistantError("No answer from satellite")

    # The service only supports responses, so hand back a serializable dict
    return asdict(answer)
def has_one_non_empty_item(value: list[str]) -> list[str]:
    """Validate that there is at least one sentence and that none are blank.

    Intended as a voluptuous validator that runs after the value has been
    coerced to a list of strings.

    Returns the list unchanged when it is valid.

    Raises vol.Invalid if the list is empty, or if any sentence is empty or
    consists only of whitespace.
    """
    if not value:
        raise vol.Invalid("at least one sentence is required")

    # A whitespace-only sentence has no words to match against, so it is
    # rejected the same way as an empty string.
    if any(not sentence.strip() for sentence in value):
        raise vol.Invalid("sentences cannot be empty")

    return value
async def async_internal_ask_question(
    self,
    question: str | None = None,
    question_media_id: str | None = None,
    preannounce: bool = True,
    preannounce_media_id: str = PREANNOUNCE_URL,
    answers: list[dict[str, Any]] | None = None,
) -> AssistSatelliteAnswer | None:
    """Ask a question and get a user's response from the satellite.

    If question_media_id is not provided, question is synthesized to audio
    with the selected pipeline.

    If question_media_id is provided, it is played directly. It is possible
    to omit the question text and the satellite will not show any text.

    If preannounce is True, a sound is played before the question or media.
    If preannounce_media_id is provided, it overrides the default sound.

    If answers is provided, the response text is matched against them and
    the matching answer is returned. Otherwise the raw response text is
    returned with no answer id.

    Calls async_start_conversation.

    Raises SatelliteBusyError if an announcement is already in progress and
    HomeAssistantError if no response text is received.
    """
    await self._cancel_running_pipeline()

    if question is None:
        question = ""

    announcement = await self._resolve_announcement_media_id(
        question,
        question_media_id,
        preannounce_media_id=preannounce_media_id if preannounce else None,
    )

    # No await between this check and setting the flag below
    if self._is_announcing:
        raise SatelliteBusyError

    self._is_announcing = True
    self._set_state(AssistSatelliteState.RESPONDING)
    # Created before the conversation starts: the pipeline event handler
    # resolves this future with the speech-to-text result.
    self._ask_question_future = asyncio.Future()

    try:
        # Wait for announcement to finish
        await self.async_start_conversation(announcement)

        # Wait for response text
        response_text = await self._ask_question_future
        if response_text is None:
            raise HomeAssistantError("No answer from question")

        if not answers:
            # Nothing to match against; return the raw text only
            return AssistSatelliteAnswer(id=None, sentence=response_text)

        return self._question_response_to_answer(response_text, answers)
    finally:
        # Always release the busy flag and drop the future, even on error
        self._is_announcing = False
        self._set_state(AssistSatelliteState.IDLE)
        self._ask_question_future = None
def _collect_list_references(expression: Expression, list_names: set[str]) -> None:
    """Add the slot name of every list reference in expression to list_names.

    Sequences are descended into; any other kind of expression is ignored.
    The set is updated in place.
    """
    # Explicit stack instead of recursion; visiting order does not matter
    # because the names end up in a set.
    pending: list[Expression] = [expression]
    while pending:
        node = pending.pop()
        if isinstance(node, Sequence):
            pending.extend(node.items)
        elif isinstance(node, ListReference):
            # {list}
            list_names.add(node.slot_name)
+ default: "" + selector: + text: + question_media_id: + required: false + selector: + text: + preannounce: + required: false + default: true + selector: + boolean: + preannounce_media_id: + required: false + selector: + text: + answers: + required: false + selector: + object: diff --git a/homeassistant/components/assist_satellite/strings.json b/homeassistant/components/assist_satellite/strings.json index b69711c7106..e0bf2bcfb94 100644 --- a/homeassistant/components/assist_satellite/strings.json +++ b/homeassistant/components/assist_satellite/strings.json @@ -59,6 +59,36 @@ "description": "Custom media ID to play before the start message or media." } } + }, + "ask_question": { + "name": "Ask question", + "description": "Asks a question and gets the user's response.", + "fields": { + "entity_id": { + "name": "Entity", + "description": "Assist satellite entity to ask the question on." + }, + "question": { + "name": "Question", + "description": "The question to ask." + }, + "question_media_id": { + "name": "Question media ID", + "description": "The media ID of the question to use instead of text-to-speech." + }, + "preannounce": { + "name": "Preannounce", + "description": "Play a sound before the question or media." + }, + "preannounce_media_id": { + "name": "Preannounce media ID", + "description": "Custom media ID to play before the question or media." + }, + "answers": { + "name": "Answers", + "description": "Possible answers to the question." 
# Cases: free-form response with no answers, a response matched against
# sentence templates, and a response whose template captures a slot value.
@pytest.mark.parametrize(
    ("service_data", "response_text", "expected_answer"),
    [
        (
            {"preannounce": False},
            "jazz",
            AssistSatelliteAnswer(id=None, sentence="jazz"),
        ),
        (
            {
                "answers": [
                    {"id": "jazz", "sentences": ["[some] jazz [please]"]},
                    {"id": "rock", "sentences": ["[some] rock [please]"]},
                ],
                "preannounce": False,
            },
            # Capitalization and punctuation in the response still match
            "Some Rock, please.",
            AssistSatelliteAnswer(id="rock", sentence="Some Rock, please."),
        ),
        (
            {
                "answers": [
                    {
                        "id": "genre",
                        "sentences": ["genre {genre} [please]"],
                    },
                    {
                        "id": "artist",
                        "sentences": ["artist {artist} [please]"],
                    },
                ],
                "preannounce": False,
            },
            "artist Pink Floyd",
            AssistSatelliteAnswer(
                id="artist",
                sentence="artist Pink Floyd",
                slots={"artist": "Pink Floyd"},
            ),
        ),
    ],
)
async def test_ask_question(
    hass: HomeAssistant,
    init_components: ConfigEntry,
    entity: MockAssistSatellite,
    service_data: dict,
    response_text: str,
    expected_answer: AssistSatelliteAnswer,
) -> None:
    """Test asking a question on a device and matching an answer."""
    entity_id = "assist_satellite.test_entity"
    question_text = "What kind of music would you like to listen to?"

    await async_update_pipeline(
        hass, async_get_pipeline(hass), stt_engine="test-stt-engine", stt_language="en"
    )

    # Stand-in for the pipeline's speech-to-text stage: emits STT_END with the
    # parametrized response text instead of processing audio.
    async def speech_to_text(self, *args, **kwargs):
        self.process_event(
            PipelineEvent(
                PipelineEventType.STT_END, {"stt_output": {"text": response_text}}
            )
        )

        return response_text

    original_start_conversation = entity.async_start_conversation

    # NOTE(review): nesting reconstructed from control flow -- the pipeline is
    # run from inside this stub so that it happens while the question is
    # pending; confirm against the original file.
    async def async_start_conversation(start_announcement):
        # Verify state change
        assert entity.state == AssistSatelliteState.RESPONDING
        await original_start_conversation(start_announcement)

        audio_stream = object()
        with (
            patch(
                "homeassistant.components.assist_pipeline.pipeline.PipelineRun.prepare_speech_to_text"
            ),
            patch(
                "homeassistant.components.assist_pipeline.pipeline.PipelineRun.speech_to_text",
                speech_to_text,
            ),
        ):
            await entity.async_accept_pipeline_from_satellite(
                audio_stream, start_stage=PipelineStage.STT
            )

    with (
        patch(
            "homeassistant.components.tts.generate_media_source_id",
            return_value="media-source://generated",
        ),
        patch(
            "homeassistant.components.tts.async_resolve_engine",
            return_value="tts.cloud",
        ),
        patch(
            "homeassistant.components.tts.async_create_stream",
            return_value=MockResultStream(hass, "wav", b""),
        ),
        patch(
            "homeassistant.components.media_source.async_resolve_media",
            return_value=PlayMedia(
                url="https://www.home-assistant.io/resolved.mp3",
                mime_type="audio/mp3",
            ),
        ),
        patch.object(entity, "async_start_conversation", new=async_start_conversation),
    ):
        response = await hass.services.async_call(
            "assist_satellite",
            "ask_question",
            {"entity_id": entity_id, "question": question_text, **service_data},
            blocking=True,
            return_response=True,
        )
        # The entity returns to idle once the question has been answered
        assert entity.state == AssistSatelliteState.IDLE
        # The service response is the answer dataclass as a plain dict
        assert response == asdict(expected_answer)