Filter out certain intents from being matched in local fallback (#137763)

* Filter out certain intents from being matched in local fallback

* Only filter if LLM agent can control HA
This commit is contained in:
Paulus Schoutsen 2025-02-19 15:27:42 -05:00 committed by GitHub
parent b2e2ef3119
commit 0b6f49fec2
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 151 additions and 6 deletions

View File

@ -13,7 +13,7 @@ from pathlib import Path
from queue import Empty, Queue from queue import Empty, Queue
from threading import Thread from threading import Thread
import time import time
from typing import Any, Literal, cast from typing import TYPE_CHECKING, Any, Literal, cast
import wave import wave
import hass_nabucasa import hass_nabucasa
@ -30,7 +30,7 @@ from homeassistant.components import (
from homeassistant.components.tts import ( from homeassistant.components.tts import (
generate_media_source_id as tts_generate_media_source_id, generate_media_source_id as tts_generate_media_source_id,
) )
from homeassistant.const import MATCH_ALL from homeassistant.const import ATTR_SUPPORTED_FEATURES, MATCH_ALL
from homeassistant.core import Context, HomeAssistant, callback from homeassistant.core import Context, HomeAssistant, callback
from homeassistant.exceptions import HomeAssistantError from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers import chat_session, intent from homeassistant.helpers import chat_session, intent
@ -81,6 +81,9 @@ from .error import (
) )
from .vad import AudioBuffer, VoiceActivityTimeout, VoiceCommandSegmenter, chunk_samples from .vad import AudioBuffer, VoiceActivityTimeout, VoiceCommandSegmenter, chunk_samples
if TYPE_CHECKING:
from hassil.recognize import RecognizeResult
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
STORAGE_KEY = f"{DOMAIN}.pipelines" STORAGE_KEY = f"{DOMAIN}.pipelines"
@ -123,6 +126,12 @@ STORED_PIPELINE_RUNS = 10
SAVE_DELAY = 10 SAVE_DELAY = 10
@callback
def _async_local_fallback_intent_filter(result: RecognizeResult) -> bool:
"""Filter out intents that are not local fallback."""
return result.intent.name in (intent.INTENT_GET_STATE, intent.INTENT_NEVERMIND)
@callback @callback
def _async_resolve_default_pipeline_settings( def _async_resolve_default_pipeline_settings(
hass: HomeAssistant, hass: HomeAssistant,
@ -1084,10 +1093,22 @@ class PipelineRun:
) )
intent_response.async_set_speech(trigger_response_text) intent_response.async_set_speech(trigger_response_text)
intent_filter: Callable[[RecognizeResult], bool] | None = None
# If the LLM has API access, we filter out some sentences that are
# interfering with LLM operation.
if (
intent_agent_state := self.hass.states.get(self.intent_agent)
) and intent_agent_state.attributes.get(
ATTR_SUPPORTED_FEATURES, 0
) & conversation.ConversationEntityFeature.CONTROL:
intent_filter = _async_local_fallback_intent_filter
# Try local intents first, if preferred. # Try local intents first, if preferred.
elif self.pipeline.prefer_local_intents and ( elif self.pipeline.prefer_local_intents and (
intent_response := await conversation.async_handle_intents( intent_response := await conversation.async_handle_intents(
self.hass, user_input self.hass,
user_input,
intent_filter=intent_filter,
) )
): ):
# Local intent matched # Local intent matched

View File

@ -2,10 +2,12 @@
from __future__ import annotations from __future__ import annotations
from collections.abc import Callable
import logging import logging
import re import re
from typing import Literal from typing import Literal
from hassil.recognize import RecognizeResult
import voluptuous as vol import voluptuous as vol
from homeassistant.config_entries import ConfigEntry from homeassistant.config_entries import ConfigEntry
@ -241,7 +243,10 @@ async def async_handle_sentence_triggers(
async def async_handle_intents( async def async_handle_intents(
hass: HomeAssistant, user_input: ConversationInput hass: HomeAssistant,
user_input: ConversationInput,
*,
intent_filter: Callable[[RecognizeResult], bool] | None = None,
) -> intent.IntentResponse | None: ) -> intent.IntentResponse | None:
"""Try to match input against registered intents and return response. """Try to match input against registered intents and return response.
@ -250,7 +255,9 @@ async def async_handle_intents(
default_agent = async_get_agent(hass) default_agent = async_get_agent(hass)
assert isinstance(default_agent, DefaultAgent) assert isinstance(default_agent, DefaultAgent)
return await default_agent.async_handle_intents(user_input) return await default_agent.async_handle_intents(
user_input, intent_filter=intent_filter
)
async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:

View File

@ -1324,6 +1324,8 @@ class DefaultAgent(ConversationEntity):
async def async_handle_intents( async def async_handle_intents(
self, self,
user_input: ConversationInput, user_input: ConversationInput,
*,
intent_filter: Callable[[RecognizeResult], bool] | None = None,
) -> intent.IntentResponse | None: ) -> intent.IntentResponse | None:
"""Try to match sentence against registered intents and return response. """Try to match sentence against registered intents and return response.
@ -1331,7 +1333,9 @@ class DefaultAgent(ConversationEntity):
Returns None if no match or a matching error occurred. Returns None if no match or a matching error occurred.
""" """
result = await self.async_recognize_intent(user_input, strict_intents_only=True) result = await self.async_recognize_intent(user_input, strict_intents_only=True)
if not isinstance(result, RecognizeResult): if not isinstance(result, RecognizeResult) or (
intent_filter is not None and intent_filter(result)
):
# No error message on failed match # No error message on failed match
return None return None

View File

@ -4,6 +4,7 @@ from collections.abc import AsyncGenerator
from typing import Any from typing import Any
from unittest.mock import ANY, patch from unittest.mock import ANY, patch
from hassil.recognize import Intent, IntentData, RecognizeResult
import pytest import pytest
from homeassistant.components import conversation from homeassistant.components import conversation
@ -16,6 +17,7 @@ from homeassistant.components.assist_pipeline.pipeline import (
PipelineData, PipelineData,
PipelineStorageCollection, PipelineStorageCollection,
PipelineStore, PipelineStore,
_async_local_fallback_intent_filter,
async_create_default_pipeline, async_create_default_pipeline,
async_get_pipeline, async_get_pipeline,
async_get_pipelines, async_get_pipelines,
@ -23,6 +25,7 @@ from homeassistant.components.assist_pipeline.pipeline import (
async_update_pipeline, async_update_pipeline,
) )
from homeassistant.core import HomeAssistant from homeassistant.core import HomeAssistant
from homeassistant.helpers import intent
from homeassistant.setup import async_setup_component from homeassistant.setup import async_setup_component
from . import MANY_LANGUAGES from . import MANY_LANGUAGES
@ -657,3 +660,40 @@ async def test_migrate_after_load(hass: HomeAssistant) -> None:
assert pipeline_updated.stt_engine == "stt.test" assert pipeline_updated.stt_engine == "stt.test"
assert pipeline_updated.tts_engine == "tts.test" assert pipeline_updated.tts_engine == "tts.test"
def test_fallback_intent_filter() -> None:
"""Test that we filter the right things."""
assert (
_async_local_fallback_intent_filter(
RecognizeResult(
intent=Intent(intent.INTENT_GET_STATE),
intent_data=IntentData([]),
entities={},
entities_list=[],
)
)
is True
)
assert (
_async_local_fallback_intent_filter(
RecognizeResult(
intent=Intent(intent.INTENT_NEVERMIND),
intent_data=IntentData([]),
entities={},
entities_list=[],
)
)
is True
)
assert (
_async_local_fallback_intent_filter(
RecognizeResult(
intent=Intent(intent.INTENT_TURN_ON),
intent_data=IntentData([]),
entities={},
entities_list=[],
)
)
is False
)

View File

@ -3154,6 +3154,79 @@ async def test_handle_intents_with_response_errors(
assert response is None assert response is None
@pytest.mark.usefixtures("init_components")
async def test_handle_intents_filters_results(
hass: HomeAssistant,
init_components: None,
area_registry: ar.AreaRegistry,
) -> None:
"""Test that handle_intents can filter responses."""
assert await async_setup_component(hass, "climate", {})
area_registry.async_create("living room")
agent: default_agent.DefaultAgent = hass.data[DATA_DEFAULT_ENTITY]
user_input = ConversationInput(
text="What is the temperature in the living room?",
context=Context(),
conversation_id=None,
device_id=None,
language=hass.config.language,
agent_id=None,
)
mock_result = RecognizeResult(
intent=Intent("HassTurnOn"),
intent_data=IntentData([]),
entities={},
entities_list=[],
)
results = []
def _filter_intents(result):
results.append(result)
# We filter first, not 2nd.
return len(results) == 1
with (
patch(
"homeassistant.components.conversation.default_agent.DefaultAgent.async_recognize_intent",
return_value=mock_result,
) as mock_recognize,
patch(
"homeassistant.components.conversation.default_agent.DefaultAgent._async_process_intent_result",
) as mock_process,
):
response = await agent.async_handle_intents(
user_input, intent_filter=_filter_intents
)
assert len(mock_recognize.mock_calls) == 1
assert len(mock_process.mock_calls) == 0
# It was ignored
assert response is None
# Check we filtered things
assert len(results) == 1
assert results[0] is mock_result
# Second time it is not filtered
response = await agent.async_handle_intents(
user_input, intent_filter=_filter_intents
)
assert len(mock_recognize.mock_calls) == 2
assert len(mock_process.mock_calls) == 2
# Check we filtered things
assert len(results) == 2
assert results[1] is mock_result
# It was ignored
assert response is not None
@pytest.mark.usefixtures("init_components") @pytest.mark.usefixtures("init_components")
async def test_state_names_are_not_translated( async def test_state_names_are_not_translated(
hass: HomeAssistant, hass: HomeAssistant,