Add CONTROL supported feature to Google conversation when API access (#123046)

* Add CONTROL supported feature to Google conversation when API access

* Better function name

* Handle entry update inline

* Reload instead of update
This commit is contained in:
Paulus Schoutsen 2024-08-03 08:16:30 +02:00 committed by GitHub
parent f6ad018f8f
commit aa6f0cd55a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 29 additions and 8 deletions

View File

@ -172,6 +172,10 @@ class GoogleGenerativeAIConversationEntity(
model="Generative AI", model="Generative AI",
entry_type=dr.DeviceEntryType.SERVICE, entry_type=dr.DeviceEntryType.SERVICE,
) )
if self.entry.options.get(CONF_LLM_HASS_API):
self._attr_supported_features = (
conversation.ConversationEntityFeature.CONTROL
)
@property @property
def supported_languages(self) -> list[str] | Literal["*"]: def supported_languages(self) -> list[str] | Literal["*"]:
@ -185,6 +189,9 @@ class GoogleGenerativeAIConversationEntity(
self.hass, "conversation", self.entry.entry_id, self.entity_id self.hass, "conversation", self.entry.entry_id, self.entity_id
) )
conversation.async_set_agent(self.hass, self.entry, self) conversation.async_set_agent(self.hass, self.entry, self)
self.entry.async_on_unload(
self.entry.add_update_listener(self._async_entry_update_listener)
)
async def async_will_remove_from_hass(self) -> None: async def async_will_remove_from_hass(self) -> None:
"""When entity will be removed from Home Assistant.""" """When entity will be removed from Home Assistant."""
@ -405,3 +412,10 @@ class GoogleGenerativeAIConversationEntity(
parts.append(llm_api.api_prompt) parts.append(llm_api.api_prompt)
return "\n".join(parts) return "\n".join(parts)
async def _async_entry_update_listener(
self, hass: HomeAssistant, entry: ConfigEntry
) -> None:
"""Handle options update."""
# Reload as we update device info + entity name + supported features
await hass.config_entries.async_reload(entry.entry_id)

View File

@ -215,7 +215,7 @@
), ),
]) ])
# --- # ---
# name: test_default_prompt[config_entry_options0-None] # name: test_default_prompt[config_entry_options0-0-None]
list([ list([
tuple( tuple(
'', '',
@ -263,7 +263,7 @@
), ),
]) ])
# --- # ---
# name: test_default_prompt[config_entry_options0-conversation.google_generative_ai_conversation] # name: test_default_prompt[config_entry_options0-0-conversation.google_generative_ai_conversation]
list([ list([
tuple( tuple(
'', '',
@ -311,7 +311,7 @@
), ),
]) ])
# --- # ---
# name: test_default_prompt[config_entry_options1-None] # name: test_default_prompt[config_entry_options1-1-None]
list([ list([
tuple( tuple(
'', '',
@ -360,7 +360,7 @@
), ),
]) ])
# --- # ---
# name: test_default_prompt[config_entry_options1-conversation.google_generative_ai_conversation] # name: test_default_prompt[config_entry_options1-1-conversation.google_generative_ai_conversation]
list([ list([
tuple( tuple(
'', '',

View File

@ -19,7 +19,7 @@ from homeassistant.components.google_generative_ai_conversation.conversation imp
_escape_decode, _escape_decode,
_format_schema, _format_schema,
) )
from homeassistant.const import CONF_LLM_HASS_API from homeassistant.const import ATTR_SUPPORTED_FEATURES, CONF_LLM_HASS_API
from homeassistant.core import Context, HomeAssistant from homeassistant.core import Context, HomeAssistant
from homeassistant.exceptions import HomeAssistantError from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers import intent, llm from homeassistant.helpers import intent, llm
@ -39,10 +39,13 @@ def freeze_the_time():
"agent_id", [None, "conversation.google_generative_ai_conversation"] "agent_id", [None, "conversation.google_generative_ai_conversation"]
) )
@pytest.mark.parametrize( @pytest.mark.parametrize(
"config_entry_options", ("config_entry_options", "expected_features"),
[ [
{}, ({}, 0),
{CONF_LLM_HASS_API: llm.LLM_API_ASSIST}, (
{CONF_LLM_HASS_API: llm.LLM_API_ASSIST},
conversation.ConversationEntityFeature.CONTROL,
),
], ],
) )
@pytest.mark.usefixtures("mock_init_component") @pytest.mark.usefixtures("mock_init_component")
@ -52,6 +55,7 @@ async def test_default_prompt(
snapshot: SnapshotAssertion, snapshot: SnapshotAssertion,
agent_id: str | None, agent_id: str | None,
config_entry_options: {}, config_entry_options: {},
expected_features: conversation.ConversationEntityFeature,
hass_ws_client: WebSocketGenerator, hass_ws_client: WebSocketGenerator,
) -> None: ) -> None:
"""Test that the default prompt works.""" """Test that the default prompt works."""
@ -98,6 +102,9 @@ async def test_default_prompt(
assert [tuple(mock_call) for mock_call in mock_model.mock_calls] == snapshot assert [tuple(mock_call) for mock_call in mock_model.mock_calls] == snapshot
assert mock_get_tools.called == (CONF_LLM_HASS_API in config_entry_options) assert mock_get_tools.called == (CONF_LLM_HASS_API in config_entry_options)
state = hass.states.get("conversation.google_generative_ai_conversation")
assert state.attributes[ATTR_SUPPORTED_FEATURES] == expected_features
@pytest.mark.parametrize( @pytest.mark.parametrize(
("model_name", "supports_system_instruction"), ("model_name", "supports_system_instruction"),