mirror of
https://github.com/home-assistant/core.git
synced 2025-07-10 06:47:09 +00:00
Add CONTROL supported feature to Google conversation when API access (#123046)
* Add CONTROL supported feature to Google conversation when API access * Better function name * Handle entry update inline * Reload instead of update
This commit is contained in:
parent
f6ad018f8f
commit
aa6f0cd55a
@ -172,6 +172,10 @@ class GoogleGenerativeAIConversationEntity(
|
|||||||
model="Generative AI",
|
model="Generative AI",
|
||||||
entry_type=dr.DeviceEntryType.SERVICE,
|
entry_type=dr.DeviceEntryType.SERVICE,
|
||||||
)
|
)
|
||||||
|
if self.entry.options.get(CONF_LLM_HASS_API):
|
||||||
|
self._attr_supported_features = (
|
||||||
|
conversation.ConversationEntityFeature.CONTROL
|
||||||
|
)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def supported_languages(self) -> list[str] | Literal["*"]:
|
def supported_languages(self) -> list[str] | Literal["*"]:
|
||||||
@ -185,6 +189,9 @@ class GoogleGenerativeAIConversationEntity(
|
|||||||
self.hass, "conversation", self.entry.entry_id, self.entity_id
|
self.hass, "conversation", self.entry.entry_id, self.entity_id
|
||||||
)
|
)
|
||||||
conversation.async_set_agent(self.hass, self.entry, self)
|
conversation.async_set_agent(self.hass, self.entry, self)
|
||||||
|
self.entry.async_on_unload(
|
||||||
|
self.entry.add_update_listener(self._async_entry_update_listener)
|
||||||
|
)
|
||||||
|
|
||||||
async def async_will_remove_from_hass(self) -> None:
|
async def async_will_remove_from_hass(self) -> None:
|
||||||
"""When entity will be removed from Home Assistant."""
|
"""When entity will be removed from Home Assistant."""
|
||||||
@ -405,3 +412,10 @@ class GoogleGenerativeAIConversationEntity(
|
|||||||
parts.append(llm_api.api_prompt)
|
parts.append(llm_api.api_prompt)
|
||||||
|
|
||||||
return "\n".join(parts)
|
return "\n".join(parts)
|
||||||
|
|
||||||
|
async def _async_entry_update_listener(
|
||||||
|
self, hass: HomeAssistant, entry: ConfigEntry
|
||||||
|
) -> None:
|
||||||
|
"""Handle options update."""
|
||||||
|
# Reload as we update device info + entity name + supported features
|
||||||
|
await hass.config_entries.async_reload(entry.entry_id)
|
||||||
|
@ -215,7 +215,7 @@
|
|||||||
),
|
),
|
||||||
])
|
])
|
||||||
# ---
|
# ---
|
||||||
# name: test_default_prompt[config_entry_options0-None]
|
# name: test_default_prompt[config_entry_options0-0-None]
|
||||||
list([
|
list([
|
||||||
tuple(
|
tuple(
|
||||||
'',
|
'',
|
||||||
@ -263,7 +263,7 @@
|
|||||||
),
|
),
|
||||||
])
|
])
|
||||||
# ---
|
# ---
|
||||||
# name: test_default_prompt[config_entry_options0-conversation.google_generative_ai_conversation]
|
# name: test_default_prompt[config_entry_options0-0-conversation.google_generative_ai_conversation]
|
||||||
list([
|
list([
|
||||||
tuple(
|
tuple(
|
||||||
'',
|
'',
|
||||||
@ -311,7 +311,7 @@
|
|||||||
),
|
),
|
||||||
])
|
])
|
||||||
# ---
|
# ---
|
||||||
# name: test_default_prompt[config_entry_options1-None]
|
# name: test_default_prompt[config_entry_options1-1-None]
|
||||||
list([
|
list([
|
||||||
tuple(
|
tuple(
|
||||||
'',
|
'',
|
||||||
@ -360,7 +360,7 @@
|
|||||||
),
|
),
|
||||||
])
|
])
|
||||||
# ---
|
# ---
|
||||||
# name: test_default_prompt[config_entry_options1-conversation.google_generative_ai_conversation]
|
# name: test_default_prompt[config_entry_options1-1-conversation.google_generative_ai_conversation]
|
||||||
list([
|
list([
|
||||||
tuple(
|
tuple(
|
||||||
'',
|
'',
|
||||||
|
@ -19,7 +19,7 @@ from homeassistant.components.google_generative_ai_conversation.conversation imp
|
|||||||
_escape_decode,
|
_escape_decode,
|
||||||
_format_schema,
|
_format_schema,
|
||||||
)
|
)
|
||||||
from homeassistant.const import CONF_LLM_HASS_API
|
from homeassistant.const import ATTR_SUPPORTED_FEATURES, CONF_LLM_HASS_API
|
||||||
from homeassistant.core import Context, HomeAssistant
|
from homeassistant.core import Context, HomeAssistant
|
||||||
from homeassistant.exceptions import HomeAssistantError
|
from homeassistant.exceptions import HomeAssistantError
|
||||||
from homeassistant.helpers import intent, llm
|
from homeassistant.helpers import intent, llm
|
||||||
@ -39,10 +39,13 @@ def freeze_the_time():
|
|||||||
"agent_id", [None, "conversation.google_generative_ai_conversation"]
|
"agent_id", [None, "conversation.google_generative_ai_conversation"]
|
||||||
)
|
)
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"config_entry_options",
|
("config_entry_options", "expected_features"),
|
||||||
[
|
[
|
||||||
{},
|
({}, 0),
|
||||||
{CONF_LLM_HASS_API: llm.LLM_API_ASSIST},
|
(
|
||||||
|
{CONF_LLM_HASS_API: llm.LLM_API_ASSIST},
|
||||||
|
conversation.ConversationEntityFeature.CONTROL,
|
||||||
|
),
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
@pytest.mark.usefixtures("mock_init_component")
|
@pytest.mark.usefixtures("mock_init_component")
|
||||||
@ -52,6 +55,7 @@ async def test_default_prompt(
|
|||||||
snapshot: SnapshotAssertion,
|
snapshot: SnapshotAssertion,
|
||||||
agent_id: str | None,
|
agent_id: str | None,
|
||||||
config_entry_options: {},
|
config_entry_options: {},
|
||||||
|
expected_features: conversation.ConversationEntityFeature,
|
||||||
hass_ws_client: WebSocketGenerator,
|
hass_ws_client: WebSocketGenerator,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Test that the default prompt works."""
|
"""Test that the default prompt works."""
|
||||||
@ -98,6 +102,9 @@ async def test_default_prompt(
|
|||||||
assert [tuple(mock_call) for mock_call in mock_model.mock_calls] == snapshot
|
assert [tuple(mock_call) for mock_call in mock_model.mock_calls] == snapshot
|
||||||
assert mock_get_tools.called == (CONF_LLM_HASS_API in config_entry_options)
|
assert mock_get_tools.called == (CONF_LLM_HASS_API in config_entry_options)
|
||||||
|
|
||||||
|
state = hass.states.get("conversation.google_generative_ai_conversation")
|
||||||
|
assert state.attributes[ATTR_SUPPORTED_FEATURES] == expected_features
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
("model_name", "supports_system_instruction"),
|
("model_name", "supports_system_instruction"),
|
||||||
|
Loading…
x
Reference in New Issue
Block a user