From 4a06e20318ba50fddb4be1f3ea20a988075622f6 Mon Sep 17 00:00:00 2001 From: Paulus Schoutsen Date: Fri, 2 Aug 2024 12:31:31 +0200 Subject: [PATCH] Ollama implement CONTROL supported feature (#123049) --- .../components/ollama/conversation.py | 18 +++++++++++++ tests/components/ollama/test_conversation.py | 25 ++++++++++++++++++- 2 files changed, 42 insertions(+), 1 deletion(-) diff --git a/homeassistant/components/ollama/conversation.py b/homeassistant/components/ollama/conversation.py index f59e268394b..9f66083f506 100644 --- a/homeassistant/components/ollama/conversation.py +++ b/homeassistant/components/ollama/conversation.py @@ -106,6 +106,10 @@ class OllamaConversationEntity( self._history: dict[str, MessageHistory] = {} self._attr_name = entry.title self._attr_unique_id = entry.entry_id + if self.entry.options.get(CONF_LLM_HASS_API): + self._attr_supported_features = ( + conversation.ConversationEntityFeature.CONTROL + ) async def async_added_to_hass(self) -> None: """When entity is added to Home Assistant.""" @@ -114,6 +118,9 @@ class OllamaConversationEntity( self.hass, "conversation", self.entry.entry_id, self.entity_id ) conversation.async_set_agent(self.hass, self.entry, self) + self.entry.async_on_unload( + self.entry.add_update_listener(self._async_entry_update_listener) + ) async def async_will_remove_from_hass(self) -> None: """When entity will be removed from Home Assistant.""" @@ -334,3 +341,14 @@ class OllamaConversationEntity( message_history.messages = [ message_history.messages[0] ] + message_history.messages[drop_index:] + + async def _async_entry_update_listener( + self, hass: HomeAssistant, entry: ConfigEntry + ) -> None: + """Handle options update.""" + if entry.options.get(CONF_LLM_HASS_API): + self._attr_supported_features = ( + conversation.ConversationEntityFeature.CONTROL + ) + else: + self._attr_supported_features = conversation.ConversationEntityFeature(0) diff --git a/tests/components/ollama/test_conversation.py b/tests/components/ollama/test_conversation.py index b5a94cc6f57..c83dce3b565 100644 --- a/tests/components/ollama/test_conversation.py +++ b/tests/components/ollama/test_conversation.py @@ -10,7 +10,7 @@ import voluptuous as vol from homeassistant.components import conversation, ollama from homeassistant.components.conversation import trace -from homeassistant.const import CONF_LLM_HASS_API, MATCH_ALL +from homeassistant.const import ATTR_SUPPORTED_FEATURES, CONF_LLM_HASS_API, MATCH_ALL from homeassistant.core import Context, HomeAssistant from homeassistant.exceptions import HomeAssistantError from homeassistant.helpers import intent, llm @@ -554,3 +554,26 @@ async def test_conversation_agent( mock_config_entry.entry_id ) assert agent.supported_languages == MATCH_ALL + + state = hass.states.get("conversation.mock_title") + assert state + assert state.attributes[ATTR_SUPPORTED_FEATURES] == 0 + + +async def test_conversation_agent_with_assist( + hass: HomeAssistant, + mock_config_entry_with_assist: MockConfigEntry, + mock_init_component, +) -> None: + """Test OllamaConversationEntity.""" + agent = conversation.get_agent_manager(hass).async_get_agent( + mock_config_entry_with_assist.entry_id + ) + assert agent.supported_languages == MATCH_ALL + + state = hass.states.get("conversation.mock_title") + assert state + assert ( + state.attributes[ATTR_SUPPORTED_FEATURES] + == conversation.ConversationEntityFeature.CONTROL + )