mirror of
https://github.com/home-assistant/core.git
synced 2025-07-21 12:17:07 +00:00
Filter out certain intents from being matched in local fallback (#137763)
* Filter out certain intents from being matched in local fallback * Only filter if LLM agent can control HA
This commit is contained in:
parent
b2e2ef3119
commit
0b6f49fec2
@ -13,7 +13,7 @@ from pathlib import Path
|
|||||||
from queue import Empty, Queue
|
from queue import Empty, Queue
|
||||||
from threading import Thread
|
from threading import Thread
|
||||||
import time
|
import time
|
||||||
from typing import Any, Literal, cast
|
from typing import TYPE_CHECKING, Any, Literal, cast
|
||||||
import wave
|
import wave
|
||||||
|
|
||||||
import hass_nabucasa
|
import hass_nabucasa
|
||||||
@ -30,7 +30,7 @@ from homeassistant.components import (
|
|||||||
from homeassistant.components.tts import (
|
from homeassistant.components.tts import (
|
||||||
generate_media_source_id as tts_generate_media_source_id,
|
generate_media_source_id as tts_generate_media_source_id,
|
||||||
)
|
)
|
||||||
from homeassistant.const import MATCH_ALL
|
from homeassistant.const import ATTR_SUPPORTED_FEATURES, MATCH_ALL
|
||||||
from homeassistant.core import Context, HomeAssistant, callback
|
from homeassistant.core import Context, HomeAssistant, callback
|
||||||
from homeassistant.exceptions import HomeAssistantError
|
from homeassistant.exceptions import HomeAssistantError
|
||||||
from homeassistant.helpers import chat_session, intent
|
from homeassistant.helpers import chat_session, intent
|
||||||
@ -81,6 +81,9 @@ from .error import (
|
|||||||
)
|
)
|
||||||
from .vad import AudioBuffer, VoiceActivityTimeout, VoiceCommandSegmenter, chunk_samples
|
from .vad import AudioBuffer, VoiceActivityTimeout, VoiceCommandSegmenter, chunk_samples
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from hassil.recognize import RecognizeResult
|
||||||
|
|
||||||
_LOGGER = logging.getLogger(__name__)
|
_LOGGER = logging.getLogger(__name__)
|
||||||
|
|
||||||
STORAGE_KEY = f"{DOMAIN}.pipelines"
|
STORAGE_KEY = f"{DOMAIN}.pipelines"
|
||||||
@ -123,6 +126,12 @@ STORED_PIPELINE_RUNS = 10
|
|||||||
SAVE_DELAY = 10
|
SAVE_DELAY = 10
|
||||||
|
|
||||||
|
|
||||||
|
@callback
|
||||||
|
def _async_local_fallback_intent_filter(result: RecognizeResult) -> bool:
|
||||||
|
"""Filter out intents that are not local fallback."""
|
||||||
|
return result.intent.name in (intent.INTENT_GET_STATE, intent.INTENT_NEVERMIND)
|
||||||
|
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
def _async_resolve_default_pipeline_settings(
|
def _async_resolve_default_pipeline_settings(
|
||||||
hass: HomeAssistant,
|
hass: HomeAssistant,
|
||||||
@ -1084,10 +1093,22 @@ class PipelineRun:
|
|||||||
)
|
)
|
||||||
intent_response.async_set_speech(trigger_response_text)
|
intent_response.async_set_speech(trigger_response_text)
|
||||||
|
|
||||||
|
intent_filter: Callable[[RecognizeResult], bool] | None = None
|
||||||
|
# If the LLM has API access, we filter out some sentences that are
|
||||||
|
# interfering with LLM operation.
|
||||||
|
if (
|
||||||
|
intent_agent_state := self.hass.states.get(self.intent_agent)
|
||||||
|
) and intent_agent_state.attributes.get(
|
||||||
|
ATTR_SUPPORTED_FEATURES, 0
|
||||||
|
) & conversation.ConversationEntityFeature.CONTROL:
|
||||||
|
intent_filter = _async_local_fallback_intent_filter
|
||||||
|
|
||||||
# Try local intents first, if preferred.
|
# Try local intents first, if preferred.
|
||||||
elif self.pipeline.prefer_local_intents and (
|
elif self.pipeline.prefer_local_intents and (
|
||||||
intent_response := await conversation.async_handle_intents(
|
intent_response := await conversation.async_handle_intents(
|
||||||
self.hass, user_input
|
self.hass,
|
||||||
|
user_input,
|
||||||
|
intent_filter=intent_filter,
|
||||||
)
|
)
|
||||||
):
|
):
|
||||||
# Local intent matched
|
# Local intent matched
|
||||||
|
@ -2,10 +2,12 @@
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from collections.abc import Callable
|
||||||
import logging
|
import logging
|
||||||
import re
|
import re
|
||||||
from typing import Literal
|
from typing import Literal
|
||||||
|
|
||||||
|
from hassil.recognize import RecognizeResult
|
||||||
import voluptuous as vol
|
import voluptuous as vol
|
||||||
|
|
||||||
from homeassistant.config_entries import ConfigEntry
|
from homeassistant.config_entries import ConfigEntry
|
||||||
@ -241,7 +243,10 @@ async def async_handle_sentence_triggers(
|
|||||||
|
|
||||||
|
|
||||||
async def async_handle_intents(
|
async def async_handle_intents(
|
||||||
hass: HomeAssistant, user_input: ConversationInput
|
hass: HomeAssistant,
|
||||||
|
user_input: ConversationInput,
|
||||||
|
*,
|
||||||
|
intent_filter: Callable[[RecognizeResult], bool] | None = None,
|
||||||
) -> intent.IntentResponse | None:
|
) -> intent.IntentResponse | None:
|
||||||
"""Try to match input against registered intents and return response.
|
"""Try to match input against registered intents and return response.
|
||||||
|
|
||||||
@ -250,7 +255,9 @@ async def async_handle_intents(
|
|||||||
default_agent = async_get_agent(hass)
|
default_agent = async_get_agent(hass)
|
||||||
assert isinstance(default_agent, DefaultAgent)
|
assert isinstance(default_agent, DefaultAgent)
|
||||||
|
|
||||||
return await default_agent.async_handle_intents(user_input)
|
return await default_agent.async_handle_intents(
|
||||||
|
user_input, intent_filter=intent_filter
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
|
async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
|
||||||
|
@ -1324,6 +1324,8 @@ class DefaultAgent(ConversationEntity):
|
|||||||
async def async_handle_intents(
|
async def async_handle_intents(
|
||||||
self,
|
self,
|
||||||
user_input: ConversationInput,
|
user_input: ConversationInput,
|
||||||
|
*,
|
||||||
|
intent_filter: Callable[[RecognizeResult], bool] | None = None,
|
||||||
) -> intent.IntentResponse | None:
|
) -> intent.IntentResponse | None:
|
||||||
"""Try to match sentence against registered intents and return response.
|
"""Try to match sentence against registered intents and return response.
|
||||||
|
|
||||||
@ -1331,7 +1333,9 @@ class DefaultAgent(ConversationEntity):
|
|||||||
Returns None if no match or a matching error occurred.
|
Returns None if no match or a matching error occurred.
|
||||||
"""
|
"""
|
||||||
result = await self.async_recognize_intent(user_input, strict_intents_only=True)
|
result = await self.async_recognize_intent(user_input, strict_intents_only=True)
|
||||||
if not isinstance(result, RecognizeResult):
|
if not isinstance(result, RecognizeResult) or (
|
||||||
|
intent_filter is not None and intent_filter(result)
|
||||||
|
):
|
||||||
# No error message on failed match
|
# No error message on failed match
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
@ -4,6 +4,7 @@ from collections.abc import AsyncGenerator
|
|||||||
from typing import Any
|
from typing import Any
|
||||||
from unittest.mock import ANY, patch
|
from unittest.mock import ANY, patch
|
||||||
|
|
||||||
|
from hassil.recognize import Intent, IntentData, RecognizeResult
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from homeassistant.components import conversation
|
from homeassistant.components import conversation
|
||||||
@ -16,6 +17,7 @@ from homeassistant.components.assist_pipeline.pipeline import (
|
|||||||
PipelineData,
|
PipelineData,
|
||||||
PipelineStorageCollection,
|
PipelineStorageCollection,
|
||||||
PipelineStore,
|
PipelineStore,
|
||||||
|
_async_local_fallback_intent_filter,
|
||||||
async_create_default_pipeline,
|
async_create_default_pipeline,
|
||||||
async_get_pipeline,
|
async_get_pipeline,
|
||||||
async_get_pipelines,
|
async_get_pipelines,
|
||||||
@ -23,6 +25,7 @@ from homeassistant.components.assist_pipeline.pipeline import (
|
|||||||
async_update_pipeline,
|
async_update_pipeline,
|
||||||
)
|
)
|
||||||
from homeassistant.core import HomeAssistant
|
from homeassistant.core import HomeAssistant
|
||||||
|
from homeassistant.helpers import intent
|
||||||
from homeassistant.setup import async_setup_component
|
from homeassistant.setup import async_setup_component
|
||||||
|
|
||||||
from . import MANY_LANGUAGES
|
from . import MANY_LANGUAGES
|
||||||
@ -657,3 +660,40 @@ async def test_migrate_after_load(hass: HomeAssistant) -> None:
|
|||||||
|
|
||||||
assert pipeline_updated.stt_engine == "stt.test"
|
assert pipeline_updated.stt_engine == "stt.test"
|
||||||
assert pipeline_updated.tts_engine == "tts.test"
|
assert pipeline_updated.tts_engine == "tts.test"
|
||||||
|
|
||||||
|
|
||||||
|
def test_fallback_intent_filter() -> None:
|
||||||
|
"""Test that we filter the right things."""
|
||||||
|
assert (
|
||||||
|
_async_local_fallback_intent_filter(
|
||||||
|
RecognizeResult(
|
||||||
|
intent=Intent(intent.INTENT_GET_STATE),
|
||||||
|
intent_data=IntentData([]),
|
||||||
|
entities={},
|
||||||
|
entities_list=[],
|
||||||
|
)
|
||||||
|
)
|
||||||
|
is True
|
||||||
|
)
|
||||||
|
assert (
|
||||||
|
_async_local_fallback_intent_filter(
|
||||||
|
RecognizeResult(
|
||||||
|
intent=Intent(intent.INTENT_NEVERMIND),
|
||||||
|
intent_data=IntentData([]),
|
||||||
|
entities={},
|
||||||
|
entities_list=[],
|
||||||
|
)
|
||||||
|
)
|
||||||
|
is True
|
||||||
|
)
|
||||||
|
assert (
|
||||||
|
_async_local_fallback_intent_filter(
|
||||||
|
RecognizeResult(
|
||||||
|
intent=Intent(intent.INTENT_TURN_ON),
|
||||||
|
intent_data=IntentData([]),
|
||||||
|
entities={},
|
||||||
|
entities_list=[],
|
||||||
|
)
|
||||||
|
)
|
||||||
|
is False
|
||||||
|
)
|
||||||
|
@ -3154,6 +3154,79 @@ async def test_handle_intents_with_response_errors(
|
|||||||
assert response is None
|
assert response is None
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.usefixtures("init_components")
|
||||||
|
async def test_handle_intents_filters_results(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
init_components: None,
|
||||||
|
area_registry: ar.AreaRegistry,
|
||||||
|
) -> None:
|
||||||
|
"""Test that handle_intents can filter responses."""
|
||||||
|
assert await async_setup_component(hass, "climate", {})
|
||||||
|
area_registry.async_create("living room")
|
||||||
|
|
||||||
|
agent: default_agent.DefaultAgent = hass.data[DATA_DEFAULT_ENTITY]
|
||||||
|
|
||||||
|
user_input = ConversationInput(
|
||||||
|
text="What is the temperature in the living room?",
|
||||||
|
context=Context(),
|
||||||
|
conversation_id=None,
|
||||||
|
device_id=None,
|
||||||
|
language=hass.config.language,
|
||||||
|
agent_id=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
mock_result = RecognizeResult(
|
||||||
|
intent=Intent("HassTurnOn"),
|
||||||
|
intent_data=IntentData([]),
|
||||||
|
entities={},
|
||||||
|
entities_list=[],
|
||||||
|
)
|
||||||
|
results = []
|
||||||
|
|
||||||
|
def _filter_intents(result):
|
||||||
|
results.append(result)
|
||||||
|
# We filter first, not 2nd.
|
||||||
|
return len(results) == 1
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch(
|
||||||
|
"homeassistant.components.conversation.default_agent.DefaultAgent.async_recognize_intent",
|
||||||
|
return_value=mock_result,
|
||||||
|
) as mock_recognize,
|
||||||
|
patch(
|
||||||
|
"homeassistant.components.conversation.default_agent.DefaultAgent._async_process_intent_result",
|
||||||
|
) as mock_process,
|
||||||
|
):
|
||||||
|
response = await agent.async_handle_intents(
|
||||||
|
user_input, intent_filter=_filter_intents
|
||||||
|
)
|
||||||
|
|
||||||
|
assert len(mock_recognize.mock_calls) == 1
|
||||||
|
assert len(mock_process.mock_calls) == 0
|
||||||
|
|
||||||
|
# It was ignored
|
||||||
|
assert response is None
|
||||||
|
|
||||||
|
# Check we filtered things
|
||||||
|
assert len(results) == 1
|
||||||
|
assert results[0] is mock_result
|
||||||
|
|
||||||
|
# Second time it is not filtered
|
||||||
|
response = await agent.async_handle_intents(
|
||||||
|
user_input, intent_filter=_filter_intents
|
||||||
|
)
|
||||||
|
|
||||||
|
assert len(mock_recognize.mock_calls) == 2
|
||||||
|
assert len(mock_process.mock_calls) == 2
|
||||||
|
|
||||||
|
# Check we filtered things
|
||||||
|
assert len(results) == 2
|
||||||
|
assert results[1] is mock_result
|
||||||
|
|
||||||
|
# It was ignored
|
||||||
|
assert response is not None
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.usefixtures("init_components")
|
@pytest.mark.usefixtures("init_components")
|
||||||
async def test_state_names_are_not_translated(
|
async def test_state_names_are_not_translated(
|
||||||
hass: HomeAssistant,
|
hass: HomeAssistant,
|
||||||
|
Loading…
x
Reference in New Issue
Block a user