Address late feedback Google LLM (#117873)

This commit is contained in:
Paulus Schoutsen 2024-05-21 14:11:18 -04:00 committed by GitHub
parent 2a9b31261c
commit f21226dd0e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 21 additions and 16 deletions

View File

@ -19,7 +19,10 @@ from .singleton import singleton
LLM_API_ASSIST = "assist" LLM_API_ASSIST = "assist"
PROMPT_NO_API_CONFIGURED = "If the user wants to control a device, tell them to edit the AI configuration and allow access to Home Assistant." PROMPT_NO_API_CONFIGURED = (
"If the user wants to control a device, tell them to edit the AI configuration and "
"allow access to Home Assistant."
)
@singleton("llm") @singleton("llm")

View File

@ -1,5 +1,5 @@
# serializer version: 1 # serializer version: 1
# name: test_default_prompt[False-None] # name: test_default_prompt[config_entry_options0-None]
list([ list([
tuple( tuple(
'', '',
@ -58,7 +58,7 @@
), ),
]) ])
# --- # ---
# name: test_default_prompt[False-conversation.google_generative_ai_conversation] # name: test_default_prompt[config_entry_options0-conversation.google_generative_ai_conversation]
list([ list([
tuple( tuple(
'', '',
@ -117,7 +117,7 @@
), ),
]) ])
# --- # ---
# name: test_default_prompt[True-None] # name: test_default_prompt[config_entry_options1-None]
list([ list([
tuple( tuple(
'', '',
@ -176,7 +176,7 @@
), ),
]) ])
# --- # ---
# name: test_default_prompt[True-conversation.google_generative_ai_conversation] # name: test_default_prompt[config_entry_options1-conversation.google_generative_ai_conversation]
list([ list([
tuple( tuple(
'', '',

View File

@ -24,7 +24,13 @@ from tests.common import MockConfigEntry
@pytest.mark.parametrize( @pytest.mark.parametrize(
"agent_id", [None, "conversation.google_generative_ai_conversation"] "agent_id", [None, "conversation.google_generative_ai_conversation"]
) )
@pytest.mark.parametrize("allow_hass_access", [False, True]) @pytest.mark.parametrize(
"config_entry_options",
[
{},
{CONF_LLM_HASS_API: llm.LLM_API_ASSIST},
],
)
async def test_default_prompt( async def test_default_prompt(
hass: HomeAssistant, hass: HomeAssistant,
mock_config_entry: MockConfigEntry, mock_config_entry: MockConfigEntry,
@ -33,7 +39,7 @@ async def test_default_prompt(
device_registry: dr.DeviceRegistry, device_registry: dr.DeviceRegistry,
snapshot: SnapshotAssertion, snapshot: SnapshotAssertion,
agent_id: str | None, agent_id: str | None,
allow_hass_access: bool, config_entry_options: {},
) -> None: ) -> None:
"""Test that the default prompt works.""" """Test that the default prompt works."""
entry = MockConfigEntry(title=None) entry = MockConfigEntry(title=None)
@ -44,14 +50,10 @@ async def test_default_prompt(
if agent_id is None: if agent_id is None:
agent_id = mock_config_entry.entry_id agent_id = mock_config_entry.entry_id
if allow_hass_access: hass.config_entries.async_update_entry(
hass.config_entries.async_update_entry( mock_config_entry,
mock_config_entry, options={**mock_config_entry.options, **config_entry_options},
options={ )
**mock_config_entry.options,
CONF_LLM_HASS_API: llm.LLM_API_ASSIST,
},
)
device_registry.async_get_or_create( device_registry.async_get_or_create(
config_entry_id=entry.entry_id, config_entry_id=entry.entry_id,
@ -145,7 +147,7 @@ async def test_default_prompt(
assert result.response.response_type == intent.IntentResponseType.ACTION_DONE assert result.response.response_type == intent.IntentResponseType.ACTION_DONE
assert result.response.as_dict()["speech"]["plain"]["speech"] == "Hi there!" assert result.response.as_dict()["speech"]["plain"]["speech"] == "Hi there!"
assert [tuple(mock_call) for mock_call in mock_model.mock_calls] == snapshot assert [tuple(mock_call) for mock_call in mock_model.mock_calls] == snapshot
assert mock_get_tools.called == allow_hass_access assert mock_get_tools.called == (CONF_LLM_HASS_API in config_entry_options)
@patch( @patch(