From 0b6f49fec24856249a7c47a08645e54e8c2667b2 Mon Sep 17 00:00:00 2001 From: Paulus Schoutsen Date: Wed, 19 Feb 2025 15:27:42 -0500 Subject: [PATCH] Filter out certain intents from being matched in local fallback (#137763) * Filter out certain intents from being matched in local fallback * Only filter if LLM agent can control HA --- .../components/assist_pipeline/pipeline.py | 27 ++++++- .../components/conversation/__init__.py | 11 ++- .../components/conversation/default_agent.py | 6 +- .../assist_pipeline/test_pipeline.py | 40 ++++++++++ .../conversation/test_default_agent.py | 73 +++++++++++++++++++ 5 files changed, 151 insertions(+), 6 deletions(-) diff --git a/homeassistant/components/assist_pipeline/pipeline.py b/homeassistant/components/assist_pipeline/pipeline.py index cf9fb4c7212..788a207b83a 100644 --- a/homeassistant/components/assist_pipeline/pipeline.py +++ b/homeassistant/components/assist_pipeline/pipeline.py @@ -13,7 +13,7 @@ from pathlib import Path from queue import Empty, Queue from threading import Thread import time -from typing import Any, Literal, cast +from typing import TYPE_CHECKING, Any, Literal, cast import wave import hass_nabucasa @@ -30,7 +30,7 @@ from homeassistant.components import ( from homeassistant.components.tts import ( generate_media_source_id as tts_generate_media_source_id, ) -from homeassistant.const import MATCH_ALL +from homeassistant.const import ATTR_SUPPORTED_FEATURES, MATCH_ALL from homeassistant.core import Context, HomeAssistant, callback from homeassistant.exceptions import HomeAssistantError from homeassistant.helpers import chat_session, intent @@ -81,6 +81,9 @@ from .error import ( ) from .vad import AudioBuffer, VoiceActivityTimeout, VoiceCommandSegmenter, chunk_samples +if TYPE_CHECKING: + from hassil.recognize import RecognizeResult + _LOGGER = logging.getLogger(__name__) STORAGE_KEY = f"{DOMAIN}.pipelines" @@ -123,6 +126,12 @@ STORED_PIPELINE_RUNS = 10 SAVE_DELAY = 10 +@callback +def _async_local_fallback_intent_filter(result: RecognizeResult) -> bool: + """Filter out intents that are not local fallback.""" + return result.intent.name in (intent.INTENT_GET_STATE, intent.INTENT_NEVERMIND) + + @callback def _async_resolve_default_pipeline_settings( hass: HomeAssistant, @@ -1084,10 +1093,22 @@ class PipelineRun: ) intent_response.async_set_speech(trigger_response_text) + intent_filter: Callable[[RecognizeResult], bool] | None = None + # If the LLM has API access, we filter out some sentences that are + # interfering with LLM operation. + if ( + intent_agent_state := self.hass.states.get(self.intent_agent) + ) and intent_agent_state.attributes.get( + ATTR_SUPPORTED_FEATURES, 0 + ) & conversation.ConversationEntityFeature.CONTROL: + intent_filter = _async_local_fallback_intent_filter + # Try local intents first, if preferred. elif self.pipeline.prefer_local_intents and ( intent_response := await conversation.async_handle_intents( - self.hass, user_input + self.hass, + user_input, + intent_filter=intent_filter, ) ): # Local intent matched diff --git a/homeassistant/components/conversation/__init__.py b/homeassistant/components/conversation/__init__.py index 11de75801ba..14c5244c18b 100644 --- a/homeassistant/components/conversation/__init__.py +++ b/homeassistant/components/conversation/__init__.py @@ -2,10 +2,12 @@ from __future__ import annotations +from collections.abc import Callable import logging import re from typing import Literal +from hassil.recognize import RecognizeResult import voluptuous as vol from homeassistant.config_entries import ConfigEntry @@ -241,7 +243,10 @@ async def async_handle_sentence_triggers( async def async_handle_intents( - hass: HomeAssistant, user_input: ConversationInput + hass: HomeAssistant, + user_input: ConversationInput, + *, + intent_filter: Callable[[RecognizeResult], bool] | None = None, ) -> intent.IntentResponse | None: """Try to match input against registered intents and return response. @@ -250,7 +255,9 @@ async def async_handle_intents( default_agent = async_get_agent(hass) assert isinstance(default_agent, DefaultAgent) - return await default_agent.async_handle_intents(user_input) + return await default_agent.async_handle_intents( + user_input, intent_filter=intent_filter + ) async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: diff --git a/homeassistant/components/conversation/default_agent.py b/homeassistant/components/conversation/default_agent.py index e8bd38f5adf..86c46584faf 100644 --- a/homeassistant/components/conversation/default_agent.py +++ b/homeassistant/components/conversation/default_agent.py @@ -1324,6 +1324,8 @@ class DefaultAgent(ConversationEntity): async def async_handle_intents( self, user_input: ConversationInput, + *, + intent_filter: Callable[[RecognizeResult], bool] | None = None, ) -> intent.IntentResponse | None: """Try to match sentence against registered intents and return response. @@ -1331,7 +1333,9 @@ class DefaultAgent(ConversationEntity): Returns None if no match or a matching error occurred. """ result = await self.async_recognize_intent(user_input, strict_intents_only=True) - if not isinstance(result, RecognizeResult): + if not isinstance(result, RecognizeResult) or ( + intent_filter is not None and intent_filter(result) + ): # No error message on failed match return None diff --git a/tests/components/assist_pipeline/test_pipeline.py b/tests/components/assist_pipeline/test_pipeline.py index d52e2a762ee..a7f6fbf7553 100644 --- a/tests/components/assist_pipeline/test_pipeline.py +++ b/tests/components/assist_pipeline/test_pipeline.py @@ -4,6 +4,7 @@ from collections.abc import AsyncGenerator from typing import Any from unittest.mock import ANY, patch +from hassil.recognize import Intent, IntentData, RecognizeResult import pytest from homeassistant.components import conversation @@ -16,6 +17,7 @@ from homeassistant.components.assist_pipeline.pipeline import ( PipelineData, PipelineStorageCollection, PipelineStore, + _async_local_fallback_intent_filter, async_create_default_pipeline, async_get_pipeline, async_get_pipelines, @@ -23,6 +25,7 @@ from homeassistant.components.assist_pipeline.pipeline import ( async_update_pipeline, ) from homeassistant.core import HomeAssistant +from homeassistant.helpers import intent from homeassistant.setup import async_setup_component from . import MANY_LANGUAGES @@ -657,3 +660,40 @@ async def test_migrate_after_load(hass: HomeAssistant) -> None: assert pipeline_updated.stt_engine == "stt.test" assert pipeline_updated.tts_engine == "tts.test" + + +def test_fallback_intent_filter() -> None: + """Test that we filter the right things.""" + assert ( + _async_local_fallback_intent_filter( + RecognizeResult( + intent=Intent(intent.INTENT_GET_STATE), + intent_data=IntentData([]), + entities={}, + entities_list=[], + ) + ) + is True + ) + assert ( + _async_local_fallback_intent_filter( + RecognizeResult( + intent=Intent(intent.INTENT_NEVERMIND), + intent_data=IntentData([]), + entities={}, + entities_list=[], + ) + ) + is True + ) + assert ( + _async_local_fallback_intent_filter( + RecognizeResult( + intent=Intent(intent.INTENT_TURN_ON), + intent_data=IntentData([]), + entities={}, + entities_list=[], + ) + ) + is False + ) diff --git a/tests/components/conversation/test_default_agent.py b/tests/components/conversation/test_default_agent.py index d9f9917b9e0..dca4653b480 100644 --- a/tests/components/conversation/test_default_agent.py +++ b/tests/components/conversation/test_default_agent.py @@ -3154,6 +3154,79 @@ async def test_handle_intents_with_response_errors( assert response is None +@pytest.mark.usefixtures("init_components") +async def test_handle_intents_filters_results( + hass: HomeAssistant, + init_components: None, + area_registry: ar.AreaRegistry, +) -> None: + """Test that handle_intents can filter responses.""" + assert await async_setup_component(hass, "climate", {}) + area_registry.async_create("living room") + + agent: default_agent.DefaultAgent = hass.data[DATA_DEFAULT_ENTITY] + + user_input = ConversationInput( + text="What is the temperature in the living room?", + context=Context(), + conversation_id=None, + device_id=None, + language=hass.config.language, + agent_id=None, + ) + + mock_result = RecognizeResult( + intent=Intent("HassTurnOn"), + intent_data=IntentData([]), + entities={}, + entities_list=[], + ) + results = [] + + def _filter_intents(result): + results.append(result) + # We filter first, not 2nd. + return len(results) == 1 + + with ( + patch( + "homeassistant.components.conversation.default_agent.DefaultAgent.async_recognize_intent", + return_value=mock_result, + ) as mock_recognize, + patch( + "homeassistant.components.conversation.default_agent.DefaultAgent._async_process_intent_result", + ) as mock_process, + ): + response = await agent.async_handle_intents( + user_input, intent_filter=_filter_intents + ) + + assert len(mock_recognize.mock_calls) == 1 + assert len(mock_process.mock_calls) == 0 + + # It was ignored + assert response is None + + # Check we filtered things + assert len(results) == 1 + assert results[0] is mock_result + + # Second time it is not filtered + response = await agent.async_handle_intents( + user_input, intent_filter=_filter_intents + ) + + assert len(mock_recognize.mock_calls) == 2 + assert len(mock_process.mock_calls) == 2 + + # Check we filtered things + assert len(results) == 2 + assert results[1] is mock_result + + # It was ignored + assert response is not None + + @pytest.mark.usefixtures("init_components") async def test_state_names_are_not_translated( hass: HomeAssistant,