mirror of
https://github.com/home-assistant/core.git
synced 2025-07-21 12:17:07 +00:00
Add ask_question action to Assist satellite (#145233)
* Add get_response to Assist satellite and ESPHome * Rename get_response to ask_question * Add possible answers to questions * Add wildcard support and entity test * Add ESPHome test * Refactor to remove async_ask_question * Use single entity_id instead of target * Fix error message * Remove ESPHome test * Clean up * Revert fix
This commit is contained in:
parent
2c13c70e12
commit
341d9f15f0
@ -1,13 +1,23 @@
|
|||||||
"""Base class for assist satellite entities."""
|
"""Base class for assist satellite entities."""
|
||||||
|
|
||||||
|
from dataclasses import asdict
|
||||||
import logging
|
import logging
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from hassil.util import (
|
||||||
|
PUNCTUATION_END,
|
||||||
|
PUNCTUATION_END_WORD,
|
||||||
|
PUNCTUATION_START,
|
||||||
|
PUNCTUATION_START_WORD,
|
||||||
|
)
|
||||||
import voluptuous as vol
|
import voluptuous as vol
|
||||||
|
|
||||||
from homeassistant.components.http import StaticPathConfig
|
from homeassistant.components.http import StaticPathConfig
|
||||||
from homeassistant.config_entries import ConfigEntry
|
from homeassistant.config_entries import ConfigEntry
|
||||||
from homeassistant.core import HomeAssistant
|
from homeassistant.const import ATTR_ENTITY_ID
|
||||||
|
from homeassistant.core import HomeAssistant, ServiceCall, SupportsResponse
|
||||||
|
from homeassistant.exceptions import HomeAssistantError
|
||||||
from homeassistant.helpers import config_validation as cv
|
from homeassistant.helpers import config_validation as cv
|
||||||
from homeassistant.helpers.entity_component import EntityComponent
|
from homeassistant.helpers.entity_component import EntityComponent
|
||||||
from homeassistant.helpers.typing import ConfigType
|
from homeassistant.helpers.typing import ConfigType
|
||||||
@ -23,6 +33,7 @@ from .const import (
|
|||||||
)
|
)
|
||||||
from .entity import (
|
from .entity import (
|
||||||
AssistSatelliteAnnouncement,
|
AssistSatelliteAnnouncement,
|
||||||
|
AssistSatelliteAnswer,
|
||||||
AssistSatelliteConfiguration,
|
AssistSatelliteConfiguration,
|
||||||
AssistSatelliteEntity,
|
AssistSatelliteEntity,
|
||||||
AssistSatelliteEntityDescription,
|
AssistSatelliteEntityDescription,
|
||||||
@ -34,6 +45,7 @@ from .websocket_api import async_register_websocket_api
|
|||||||
__all__ = [
|
__all__ = [
|
||||||
"DOMAIN",
|
"DOMAIN",
|
||||||
"AssistSatelliteAnnouncement",
|
"AssistSatelliteAnnouncement",
|
||||||
|
"AssistSatelliteAnswer",
|
||||||
"AssistSatelliteConfiguration",
|
"AssistSatelliteConfiguration",
|
||||||
"AssistSatelliteEntity",
|
"AssistSatelliteEntity",
|
||||||
"AssistSatelliteEntityDescription",
|
"AssistSatelliteEntityDescription",
|
||||||
@ -86,6 +98,62 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
|
|||||||
"async_internal_start_conversation",
|
"async_internal_start_conversation",
|
||||||
[AssistSatelliteEntityFeature.START_CONVERSATION],
|
[AssistSatelliteEntityFeature.START_CONVERSATION],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
async def handle_ask_question(call: ServiceCall) -> dict[str, Any]:
|
||||||
|
"""Handle a Show View service call."""
|
||||||
|
satellite_entity_id: str = call.data[ATTR_ENTITY_ID]
|
||||||
|
satellite_entity: AssistSatelliteEntity | None = component.get_entity(
|
||||||
|
satellite_entity_id
|
||||||
|
)
|
||||||
|
if satellite_entity is None:
|
||||||
|
raise HomeAssistantError(
|
||||||
|
f"Invalid Assist satellite entity id: {satellite_entity_id}"
|
||||||
|
)
|
||||||
|
|
||||||
|
ask_question_args = {
|
||||||
|
"question": call.data.get("question"),
|
||||||
|
"question_media_id": call.data.get("question_media_id"),
|
||||||
|
"preannounce": call.data.get("preannounce", False),
|
||||||
|
"answers": call.data.get("answers"),
|
||||||
|
}
|
||||||
|
|
||||||
|
if preannounce_media_id := call.data.get("preannounce_media_id"):
|
||||||
|
ask_question_args["preannounce_media_id"] = preannounce_media_id
|
||||||
|
|
||||||
|
answer = await satellite_entity.async_internal_ask_question(**ask_question_args)
|
||||||
|
|
||||||
|
if answer is None:
|
||||||
|
raise HomeAssistantError("No answer from satellite")
|
||||||
|
|
||||||
|
return asdict(answer)
|
||||||
|
|
||||||
|
hass.services.async_register(
|
||||||
|
domain=DOMAIN,
|
||||||
|
service="ask_question",
|
||||||
|
service_func=handle_ask_question,
|
||||||
|
schema=vol.All(
|
||||||
|
{
|
||||||
|
vol.Required(ATTR_ENTITY_ID): cv.entity_id,
|
||||||
|
vol.Optional("question"): str,
|
||||||
|
vol.Optional("question_media_id"): str,
|
||||||
|
vol.Optional("preannounce"): bool,
|
||||||
|
vol.Optional("preannounce_media_id"): str,
|
||||||
|
vol.Optional("answers"): [
|
||||||
|
{
|
||||||
|
vol.Required("id"): str,
|
||||||
|
vol.Required("sentences"): vol.All(
|
||||||
|
cv.ensure_list,
|
||||||
|
[cv.string],
|
||||||
|
has_one_non_empty_item,
|
||||||
|
has_no_punctuation,
|
||||||
|
),
|
||||||
|
}
|
||||||
|
],
|
||||||
|
},
|
||||||
|
cv.has_at_least_one_key("question", "question_media_id"),
|
||||||
|
),
|
||||||
|
supports_response=SupportsResponse.ONLY,
|
||||||
|
)
|
||||||
hass.data[CONNECTION_TEST_DATA] = {}
|
hass.data[CONNECTION_TEST_DATA] = {}
|
||||||
async_register_websocket_api(hass)
|
async_register_websocket_api(hass)
|
||||||
hass.http.register_view(ConnectionTestView())
|
hass.http.register_view(ConnectionTestView())
|
||||||
@ -110,3 +178,29 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
|
|||||||
async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
|
async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
|
||||||
"""Unload a config entry."""
|
"""Unload a config entry."""
|
||||||
return await hass.data[DATA_COMPONENT].async_unload_entry(entry)
|
return await hass.data[DATA_COMPONENT].async_unload_entry(entry)
|
||||||
|
|
||||||
|
|
||||||
|
def has_no_punctuation(value: list[str]) -> list[str]:
|
||||||
|
"""Validate result does not contain punctuation."""
|
||||||
|
for sentence in value:
|
||||||
|
if (
|
||||||
|
PUNCTUATION_START.search(sentence)
|
||||||
|
or PUNCTUATION_END.search(sentence)
|
||||||
|
or PUNCTUATION_START_WORD.search(sentence)
|
||||||
|
or PUNCTUATION_END_WORD.search(sentence)
|
||||||
|
):
|
||||||
|
raise vol.Invalid("sentence should not contain punctuation")
|
||||||
|
|
||||||
|
return value
|
||||||
|
|
||||||
|
|
||||||
|
def has_one_non_empty_item(value: list[str]) -> list[str]:
|
||||||
|
"""Validate result has at least one item."""
|
||||||
|
if len(value) < 1:
|
||||||
|
raise vol.Invalid("at least one sentence is required")
|
||||||
|
|
||||||
|
for sentence in value:
|
||||||
|
if not sentence:
|
||||||
|
raise vol.Invalid("sentences cannot be empty")
|
||||||
|
|
||||||
|
return value
|
||||||
|
@ -4,12 +4,16 @@ from abc import abstractmethod
|
|||||||
import asyncio
|
import asyncio
|
||||||
from collections.abc import AsyncIterable
|
from collections.abc import AsyncIterable
|
||||||
import contextlib
|
import contextlib
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass, field
|
||||||
from enum import StrEnum
|
from enum import StrEnum
|
||||||
import logging
|
import logging
|
||||||
import time
|
import time
|
||||||
from typing import Any, Literal, final
|
from typing import Any, Literal, final
|
||||||
|
|
||||||
|
from hassil import Intents, recognize
|
||||||
|
from hassil.expression import Expression, ListReference, Sequence
|
||||||
|
from hassil.intents import WildcardSlotList
|
||||||
|
|
||||||
from homeassistant.components import conversation, media_source, stt, tts
|
from homeassistant.components import conversation, media_source, stt, tts
|
||||||
from homeassistant.components.assist_pipeline import (
|
from homeassistant.components.assist_pipeline import (
|
||||||
OPTION_PREFERRED,
|
OPTION_PREFERRED,
|
||||||
@ -105,6 +109,20 @@ class AssistSatelliteAnnouncement:
|
|||||||
"""Media ID to be played before announcement."""
|
"""Media ID to be played before announcement."""
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class AssistSatelliteAnswer:
|
||||||
|
"""Answer to a question."""
|
||||||
|
|
||||||
|
id: str | None
|
||||||
|
"""Matched answer id or None if no answer was matched."""
|
||||||
|
|
||||||
|
sentence: str
|
||||||
|
"""Raw sentence text from user response."""
|
||||||
|
|
||||||
|
slots: dict[str, Any] = field(default_factory=dict)
|
||||||
|
"""Matched slots from answer."""
|
||||||
|
|
||||||
|
|
||||||
class AssistSatelliteEntity(entity.Entity):
|
class AssistSatelliteEntity(entity.Entity):
|
||||||
"""Entity encapsulating the state and functionality of an Assist satellite."""
|
"""Entity encapsulating the state and functionality of an Assist satellite."""
|
||||||
|
|
||||||
@ -120,8 +138,10 @@ class AssistSatelliteEntity(entity.Entity):
|
|||||||
_is_announcing = False
|
_is_announcing = False
|
||||||
_extra_system_prompt: str | None = None
|
_extra_system_prompt: str | None = None
|
||||||
_wake_word_intercept_future: asyncio.Future[str | None] | None = None
|
_wake_word_intercept_future: asyncio.Future[str | None] | None = None
|
||||||
|
_stt_intercept_future: asyncio.Future[str | None] | None = None
|
||||||
_attr_tts_options: dict[str, Any] | None = None
|
_attr_tts_options: dict[str, Any] | None = None
|
||||||
_pipeline_task: asyncio.Task | None = None
|
_pipeline_task: asyncio.Task | None = None
|
||||||
|
_ask_question_future: asyncio.Future[str | None] | None = None
|
||||||
|
|
||||||
__assist_satellite_state = AssistSatelliteState.IDLE
|
__assist_satellite_state = AssistSatelliteState.IDLE
|
||||||
|
|
||||||
@ -309,6 +329,112 @@ class AssistSatelliteEntity(entity.Entity):
|
|||||||
"""Start a conversation from the satellite."""
|
"""Start a conversation from the satellite."""
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
async def async_internal_ask_question(
|
||||||
|
self,
|
||||||
|
question: str | None = None,
|
||||||
|
question_media_id: str | None = None,
|
||||||
|
preannounce: bool = True,
|
||||||
|
preannounce_media_id: str = PREANNOUNCE_URL,
|
||||||
|
answers: list[dict[str, Any]] | None = None,
|
||||||
|
) -> AssistSatelliteAnswer | None:
|
||||||
|
"""Ask a question and get a user's response from the satellite.
|
||||||
|
|
||||||
|
If question_media_id is not provided, question is synthesized to audio
|
||||||
|
with the selected pipeline.
|
||||||
|
|
||||||
|
If question_media_id is provided, it is played directly. It is possible
|
||||||
|
to omit the message and the satellite will not show any text.
|
||||||
|
|
||||||
|
If preannounce is True, a sound is played before the start message or media.
|
||||||
|
If preannounce_media_id is provided, it overrides the default sound.
|
||||||
|
|
||||||
|
Calls async_start_conversation.
|
||||||
|
"""
|
||||||
|
await self._cancel_running_pipeline()
|
||||||
|
|
||||||
|
if question is None:
|
||||||
|
question = ""
|
||||||
|
|
||||||
|
announcement = await self._resolve_announcement_media_id(
|
||||||
|
question,
|
||||||
|
question_media_id,
|
||||||
|
preannounce_media_id=preannounce_media_id if preannounce else None,
|
||||||
|
)
|
||||||
|
|
||||||
|
if self._is_announcing:
|
||||||
|
raise SatelliteBusyError
|
||||||
|
|
||||||
|
self._is_announcing = True
|
||||||
|
self._set_state(AssistSatelliteState.RESPONDING)
|
||||||
|
self._ask_question_future = asyncio.Future()
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Wait for announcement to finish
|
||||||
|
await self.async_start_conversation(announcement)
|
||||||
|
|
||||||
|
# Wait for response text
|
||||||
|
response_text = await self._ask_question_future
|
||||||
|
if response_text is None:
|
||||||
|
raise HomeAssistantError("No answer from question")
|
||||||
|
|
||||||
|
if not answers:
|
||||||
|
return AssistSatelliteAnswer(id=None, sentence=response_text)
|
||||||
|
|
||||||
|
return self._question_response_to_answer(response_text, answers)
|
||||||
|
finally:
|
||||||
|
self._is_announcing = False
|
||||||
|
self._set_state(AssistSatelliteState.IDLE)
|
||||||
|
self._ask_question_future = None
|
||||||
|
|
||||||
|
def _question_response_to_answer(
|
||||||
|
self, response_text: str, answers: list[dict[str, Any]]
|
||||||
|
) -> AssistSatelliteAnswer:
|
||||||
|
"""Match text to a pre-defined set of answers."""
|
||||||
|
|
||||||
|
# Build intents and match
|
||||||
|
intents = Intents.from_dict(
|
||||||
|
{
|
||||||
|
"language": self.hass.config.language,
|
||||||
|
"intents": {
|
||||||
|
"QuestionIntent": {
|
||||||
|
"data": [
|
||||||
|
{
|
||||||
|
"sentences": answer["sentences"],
|
||||||
|
"metadata": {"answer_id": answer["id"]},
|
||||||
|
}
|
||||||
|
for answer in answers
|
||||||
|
]
|
||||||
|
}
|
||||||
|
},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
# Assume slot list references are wildcards
|
||||||
|
wildcard_names: set[str] = set()
|
||||||
|
for intent in intents.intents.values():
|
||||||
|
for intent_data in intent.data:
|
||||||
|
for sentence in intent_data.sentences:
|
||||||
|
_collect_list_references(sentence, wildcard_names)
|
||||||
|
|
||||||
|
for wildcard_name in wildcard_names:
|
||||||
|
intents.slot_lists[wildcard_name] = WildcardSlotList(wildcard_name)
|
||||||
|
|
||||||
|
# Match response text
|
||||||
|
result = recognize(response_text, intents)
|
||||||
|
if result is None:
|
||||||
|
# No match
|
||||||
|
return AssistSatelliteAnswer(id=None, sentence=response_text)
|
||||||
|
|
||||||
|
assert result.intent_metadata
|
||||||
|
return AssistSatelliteAnswer(
|
||||||
|
id=result.intent_metadata["answer_id"],
|
||||||
|
sentence=response_text,
|
||||||
|
slots={
|
||||||
|
entity_name: entity.value
|
||||||
|
for entity_name, entity in result.entities.items()
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
async def async_accept_pipeline_from_satellite(
|
async def async_accept_pipeline_from_satellite(
|
||||||
self,
|
self,
|
||||||
audio_stream: AsyncIterable[bytes],
|
audio_stream: AsyncIterable[bytes],
|
||||||
@ -351,6 +477,11 @@ class AssistSatelliteEntity(entity.Entity):
|
|||||||
self._internal_on_pipeline_event(PipelineEvent(PipelineEventType.RUN_END))
|
self._internal_on_pipeline_event(PipelineEvent(PipelineEventType.RUN_END))
|
||||||
return
|
return
|
||||||
|
|
||||||
|
if (self._ask_question_future is not None) and (
|
||||||
|
start_stage == PipelineStage.STT
|
||||||
|
):
|
||||||
|
end_stage = PipelineStage.STT
|
||||||
|
|
||||||
device_id = self.registry_entry.device_id if self.registry_entry else None
|
device_id = self.registry_entry.device_id if self.registry_entry else None
|
||||||
|
|
||||||
# Refresh context if necessary
|
# Refresh context if necessary
|
||||||
@ -433,6 +564,16 @@ class AssistSatelliteEntity(entity.Entity):
|
|||||||
self._set_state(AssistSatelliteState.IDLE)
|
self._set_state(AssistSatelliteState.IDLE)
|
||||||
elif event.type is PipelineEventType.STT_START:
|
elif event.type is PipelineEventType.STT_START:
|
||||||
self._set_state(AssistSatelliteState.LISTENING)
|
self._set_state(AssistSatelliteState.LISTENING)
|
||||||
|
elif event.type is PipelineEventType.STT_END:
|
||||||
|
# Intercepting text for ask question
|
||||||
|
if (
|
||||||
|
(self._ask_question_future is not None)
|
||||||
|
and (not self._ask_question_future.done())
|
||||||
|
and event.data
|
||||||
|
):
|
||||||
|
self._ask_question_future.set_result(
|
||||||
|
event.data.get("stt_output", {}).get("text")
|
||||||
|
)
|
||||||
elif event.type is PipelineEventType.INTENT_START:
|
elif event.type is PipelineEventType.INTENT_START:
|
||||||
self._set_state(AssistSatelliteState.PROCESSING)
|
self._set_state(AssistSatelliteState.PROCESSING)
|
||||||
elif event.type is PipelineEventType.TTS_START:
|
elif event.type is PipelineEventType.TTS_START:
|
||||||
@ -443,6 +584,12 @@ class AssistSatelliteEntity(entity.Entity):
|
|||||||
if not self._run_has_tts:
|
if not self._run_has_tts:
|
||||||
self._set_state(AssistSatelliteState.IDLE)
|
self._set_state(AssistSatelliteState.IDLE)
|
||||||
|
|
||||||
|
if (self._ask_question_future is not None) and (
|
||||||
|
not self._ask_question_future.done()
|
||||||
|
):
|
||||||
|
# No text for ask question
|
||||||
|
self._ask_question_future.set_result(None)
|
||||||
|
|
||||||
self.on_pipeline_event(event)
|
self.on_pipeline_event(event)
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
@ -577,3 +724,15 @@ class AssistSatelliteEntity(entity.Entity):
|
|||||||
media_id_source=media_id_source,
|
media_id_source=media_id_source,
|
||||||
preannounce_media_id=preannounce_media_id,
|
preannounce_media_id=preannounce_media_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _collect_list_references(expression: Expression, list_names: set[str]) -> None:
|
||||||
|
"""Collect list reference names recursively."""
|
||||||
|
if isinstance(expression, Sequence):
|
||||||
|
seq: Sequence = expression
|
||||||
|
for item in seq.items:
|
||||||
|
_collect_list_references(item, list_names)
|
||||||
|
elif isinstance(expression, ListReference):
|
||||||
|
# {list}
|
||||||
|
list_ref: ListReference = expression
|
||||||
|
list_names.add(list_ref.slot_name)
|
||||||
|
@ -10,6 +10,9 @@
|
|||||||
},
|
},
|
||||||
"start_conversation": {
|
"start_conversation": {
|
||||||
"service": "mdi:forum"
|
"service": "mdi:forum"
|
||||||
|
},
|
||||||
|
"ask_question": {
|
||||||
|
"service": "mdi:microphone-question"
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -5,5 +5,6 @@
|
|||||||
"dependencies": ["assist_pipeline", "http", "stt", "tts"],
|
"dependencies": ["assist_pipeline", "http", "stt", "tts"],
|
||||||
"documentation": "https://www.home-assistant.io/integrations/assist_satellite",
|
"documentation": "https://www.home-assistant.io/integrations/assist_satellite",
|
||||||
"integration_type": "entity",
|
"integration_type": "entity",
|
||||||
"quality_scale": "internal"
|
"quality_scale": "internal",
|
||||||
|
"requirements": ["hassil==2.2.3"]
|
||||||
}
|
}
|
||||||
|
@ -54,3 +54,35 @@ start_conversation:
|
|||||||
required: false
|
required: false
|
||||||
selector:
|
selector:
|
||||||
text:
|
text:
|
||||||
|
ask_question:
|
||||||
|
fields:
|
||||||
|
entity_id:
|
||||||
|
required: true
|
||||||
|
selector:
|
||||||
|
entity:
|
||||||
|
domain: assist_satellite
|
||||||
|
supported_features:
|
||||||
|
- assist_satellite.AssistSatelliteEntityFeature.START_CONVERSATION
|
||||||
|
question:
|
||||||
|
required: false
|
||||||
|
example: "What kind of music would you like to play?"
|
||||||
|
default: ""
|
||||||
|
selector:
|
||||||
|
text:
|
||||||
|
question_media_id:
|
||||||
|
required: false
|
||||||
|
selector:
|
||||||
|
text:
|
||||||
|
preannounce:
|
||||||
|
required: false
|
||||||
|
default: true
|
||||||
|
selector:
|
||||||
|
boolean:
|
||||||
|
preannounce_media_id:
|
||||||
|
required: false
|
||||||
|
selector:
|
||||||
|
text:
|
||||||
|
answers:
|
||||||
|
required: false
|
||||||
|
selector:
|
||||||
|
object:
|
||||||
|
@ -59,6 +59,36 @@
|
|||||||
"description": "Custom media ID to play before the start message or media."
|
"description": "Custom media ID to play before the start message or media."
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
},
|
||||||
|
"ask_question": {
|
||||||
|
"name": "Ask question",
|
||||||
|
"description": "Asks a question and gets the user's response.",
|
||||||
|
"fields": {
|
||||||
|
"entity_id": {
|
||||||
|
"name": "Entity",
|
||||||
|
"description": "Assist satellite entity to ask the question on."
|
||||||
|
},
|
||||||
|
"question": {
|
||||||
|
"name": "Question",
|
||||||
|
"description": "The question to ask."
|
||||||
|
},
|
||||||
|
"question_media_id": {
|
||||||
|
"name": "Question media ID",
|
||||||
|
"description": "The media ID of the question to use instead of text-to-speech."
|
||||||
|
},
|
||||||
|
"preannounce": {
|
||||||
|
"name": "Preannounce",
|
||||||
|
"description": "Play a sound before the start message or media."
|
||||||
|
},
|
||||||
|
"preannounce_media_id": {
|
||||||
|
"name": "Preannounce media ID",
|
||||||
|
"description": "Custom media ID to play before the start message or media."
|
||||||
|
},
|
||||||
|
"answers": {
|
||||||
|
"name": "Answers",
|
||||||
|
"description": "Possible answers to the question."
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
1
requirements_all.txt
generated
1
requirements_all.txt
generated
@ -1129,6 +1129,7 @@ hass-nabucasa==0.102.0
|
|||||||
# homeassistant.components.splunk
|
# homeassistant.components.splunk
|
||||||
hass-splunk==0.1.1
|
hass-splunk==0.1.1
|
||||||
|
|
||||||
|
# homeassistant.components.assist_satellite
|
||||||
# homeassistant.components.conversation
|
# homeassistant.components.conversation
|
||||||
hassil==2.2.3
|
hassil==2.2.3
|
||||||
|
|
||||||
|
1
requirements_test_all.txt
generated
1
requirements_test_all.txt
generated
@ -984,6 +984,7 @@ habluetooth==3.49.0
|
|||||||
# homeassistant.components.cloud
|
# homeassistant.components.cloud
|
||||||
hass-nabucasa==0.102.0
|
hass-nabucasa==0.102.0
|
||||||
|
|
||||||
|
# homeassistant.components.assist_satellite
|
||||||
# homeassistant.components.conversation
|
# homeassistant.components.conversation
|
||||||
hassil==2.2.3
|
hassil==2.2.3
|
||||||
|
|
||||||
|
@ -2,6 +2,7 @@
|
|||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
from collections.abc import Generator
|
from collections.abc import Generator
|
||||||
|
from dataclasses import asdict
|
||||||
from unittest.mock import Mock, patch
|
from unittest.mock import Mock, patch
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
@ -20,6 +21,7 @@ from homeassistant.components.assist_pipeline import (
|
|||||||
)
|
)
|
||||||
from homeassistant.components.assist_satellite import (
|
from homeassistant.components.assist_satellite import (
|
||||||
AssistSatelliteAnnouncement,
|
AssistSatelliteAnnouncement,
|
||||||
|
AssistSatelliteAnswer,
|
||||||
SatelliteBusyError,
|
SatelliteBusyError,
|
||||||
)
|
)
|
||||||
from homeassistant.components.assist_satellite.const import PREANNOUNCE_URL
|
from homeassistant.components.assist_satellite.const import PREANNOUNCE_URL
|
||||||
@ -708,6 +710,127 @@ async def test_start_conversation_default_preannounce(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
("service_data", "response_text", "expected_answer"),
|
||||||
|
[
|
||||||
|
(
|
||||||
|
{"preannounce": False},
|
||||||
|
"jazz",
|
||||||
|
AssistSatelliteAnswer(id=None, sentence="jazz"),
|
||||||
|
),
|
||||||
|
(
|
||||||
|
{
|
||||||
|
"answers": [
|
||||||
|
{"id": "jazz", "sentences": ["[some] jazz [please]"]},
|
||||||
|
{"id": "rock", "sentences": ["[some] rock [please]"]},
|
||||||
|
],
|
||||||
|
"preannounce": False,
|
||||||
|
},
|
||||||
|
"Some Rock, please.",
|
||||||
|
AssistSatelliteAnswer(id="rock", sentence="Some Rock, please."),
|
||||||
|
),
|
||||||
|
(
|
||||||
|
{
|
||||||
|
"answers": [
|
||||||
|
{
|
||||||
|
"id": "genre",
|
||||||
|
"sentences": ["genre {genre} [please]"],
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": "artist",
|
||||||
|
"sentences": ["artist {artist} [please]"],
|
||||||
|
},
|
||||||
|
],
|
||||||
|
"preannounce": False,
|
||||||
|
},
|
||||||
|
"artist Pink Floyd",
|
||||||
|
AssistSatelliteAnswer(
|
||||||
|
id="artist",
|
||||||
|
sentence="artist Pink Floyd",
|
||||||
|
slots={"artist": "Pink Floyd"},
|
||||||
|
),
|
||||||
|
),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
async def test_ask_question(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
init_components: ConfigEntry,
|
||||||
|
entity: MockAssistSatellite,
|
||||||
|
service_data: dict,
|
||||||
|
response_text: str,
|
||||||
|
expected_answer: AssistSatelliteAnswer,
|
||||||
|
) -> None:
|
||||||
|
"""Test asking a question on a device and matching an answer."""
|
||||||
|
entity_id = "assist_satellite.test_entity"
|
||||||
|
question_text = "What kind of music would you like to listen to?"
|
||||||
|
|
||||||
|
await async_update_pipeline(
|
||||||
|
hass, async_get_pipeline(hass), stt_engine="test-stt-engine", stt_language="en"
|
||||||
|
)
|
||||||
|
|
||||||
|
async def speech_to_text(self, *args, **kwargs):
|
||||||
|
self.process_event(
|
||||||
|
PipelineEvent(
|
||||||
|
PipelineEventType.STT_END, {"stt_output": {"text": response_text}}
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
return response_text
|
||||||
|
|
||||||
|
original_start_conversation = entity.async_start_conversation
|
||||||
|
|
||||||
|
async def async_start_conversation(start_announcement):
|
||||||
|
# Verify state change
|
||||||
|
assert entity.state == AssistSatelliteState.RESPONDING
|
||||||
|
await original_start_conversation(start_announcement)
|
||||||
|
|
||||||
|
audio_stream = object()
|
||||||
|
with (
|
||||||
|
patch(
|
||||||
|
"homeassistant.components.assist_pipeline.pipeline.PipelineRun.prepare_speech_to_text"
|
||||||
|
),
|
||||||
|
patch(
|
||||||
|
"homeassistant.components.assist_pipeline.pipeline.PipelineRun.speech_to_text",
|
||||||
|
speech_to_text,
|
||||||
|
),
|
||||||
|
):
|
||||||
|
await entity.async_accept_pipeline_from_satellite(
|
||||||
|
audio_stream, start_stage=PipelineStage.STT
|
||||||
|
)
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch(
|
||||||
|
"homeassistant.components.tts.generate_media_source_id",
|
||||||
|
return_value="media-source://generated",
|
||||||
|
),
|
||||||
|
patch(
|
||||||
|
"homeassistant.components.tts.async_resolve_engine",
|
||||||
|
return_value="tts.cloud",
|
||||||
|
),
|
||||||
|
patch(
|
||||||
|
"homeassistant.components.tts.async_create_stream",
|
||||||
|
return_value=MockResultStream(hass, "wav", b""),
|
||||||
|
),
|
||||||
|
patch(
|
||||||
|
"homeassistant.components.media_source.async_resolve_media",
|
||||||
|
return_value=PlayMedia(
|
||||||
|
url="https://www.home-assistant.io/resolved.mp3",
|
||||||
|
mime_type="audio/mp3",
|
||||||
|
),
|
||||||
|
),
|
||||||
|
patch.object(entity, "async_start_conversation", new=async_start_conversation),
|
||||||
|
):
|
||||||
|
response = await hass.services.async_call(
|
||||||
|
"assist_satellite",
|
||||||
|
"ask_question",
|
||||||
|
{"entity_id": entity_id, "question": question_text, **service_data},
|
||||||
|
blocking=True,
|
||||||
|
return_response=True,
|
||||||
|
)
|
||||||
|
assert entity.state == AssistSatelliteState.IDLE
|
||||||
|
assert response == asdict(expected_answer)
|
||||||
|
|
||||||
|
|
||||||
async def test_wake_word_start_keeps_responding(
|
async def test_wake_word_start_keeps_responding(
|
||||||
hass: HomeAssistant, init_components: ConfigEntry, entity: MockAssistSatellite
|
hass: HomeAssistant, init_components: ConfigEntry, entity: MockAssistSatellite
|
||||||
) -> None:
|
) -> None:
|
||||||
|
Loading…
x
Reference in New Issue
Block a user