Mock llm prompts in test_default_prompt for Google Generative AI (#118286)

This commit is contained in:
tronikos 2024-05-27 21:40:26 -07:00 committed by GitHub
parent 69a177e864
commit 4f7a91828e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 15 additions and 211 deletions

View File

@ -150,7 +150,7 @@
Answer in plain text. Keep it simple and to the point.
The current time is 05:00:00.
Today's date is 05/24/24.
Only if the user wants to control a device, tell them to edit the AI configuration and allow access to Home Assistant.
<no_api_prompt>
''',
'role': 'user',
}),
@ -206,7 +206,7 @@
Answer in plain text. Keep it simple and to the point.
The current time is 05:00:00.
Today's date is 05/24/24.
Only if the user wants to control a device, tell them to edit the AI configuration and allow access to Home Assistant.
<no_api_prompt>
''',
'role': 'user',
}),
@ -262,49 +262,7 @@
Answer in plain text. Keep it simple and to the point.
The current time is 05:00:00.
Today's date is 05/24/24.
Call the intent tools to control Home Assistant. Just pass the name to the intent. When controlling an area, prefer passing area name.
An overview of the areas and the devices in this smart home:
light.test_device:
names: Test Device
state: unavailable
areas: Test Area
light.test_service:
names: Test Service
state: unavailable
areas: Test Area
light.test_service_2:
names: Test Service
state: unavailable
areas: Test Area
light.test_service_3:
names: Test Service
state: unavailable
areas: Test Area
light.test_device_2:
names: Test Device 2
state: unavailable
areas: Test Area 2
light.test_device_3:
names: Test Device 3
state: unavailable
areas: Test Area 2
light.test_device_4:
names: Test Device 4
state: unavailable
areas: Test Area 2
light.test_device_3_2:
names: Test Device 3
state: unavailable
areas: Test Area 2
light.none:
names: None
state: unavailable
areas: Test Area 2
light.1:
names: '1'
state: unavailable
areas: Test Area 2
<api_prompt>
''',
'role': 'user',
}),
@ -360,49 +318,7 @@
Answer in plain text. Keep it simple and to the point.
The current time is 05:00:00.
Today's date is 05/24/24.
Call the intent tools to control Home Assistant. Just pass the name to the intent. When controlling an area, prefer passing area name.
An overview of the areas and the devices in this smart home:
light.test_device:
names: Test Device
state: unavailable
areas: Test Area
light.test_service:
names: Test Service
state: unavailable
areas: Test Area
light.test_service_2:
names: Test Service
state: unavailable
areas: Test Area
light.test_service_3:
names: Test Service
state: unavailable
areas: Test Area
light.test_device_2:
names: Test Device 2
state: unavailable
areas: Test Area 2
light.test_device_3:
names: Test Device 3
state: unavailable
areas: Test Area 2
light.test_device_4:
names: Test Device 4
state: unavailable
areas: Test Area 2
light.test_device_3_2:
names: Test Device 3
state: unavailable
areas: Test Area 2
light.none:
names: None
state: unavailable
areas: Test Area 2
light.1:
names: '1'
state: unavailable
areas: Test Area 2
<api_prompt>
''',
'role': 'user',
}),

View File

@ -14,13 +14,7 @@ from homeassistant.components.conversation import trace
from homeassistant.const import CONF_LLM_HASS_API
from homeassistant.core import Context, HomeAssistant
from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers import (
area_registry as ar,
device_registry as dr,
entity_registry as er,
intent,
llm,
)
from homeassistant.helpers import intent, llm
from tests.common import MockConfigEntry
from tests.typing import WebSocketGenerator
@ -47,9 +41,6 @@ async def test_default_prompt(
hass: HomeAssistant,
mock_config_entry: MockConfigEntry,
mock_init_component,
area_registry: ar.AreaRegistry,
device_registry: dr.DeviceRegistry,
entity_registry: er.EntityRegistry,
snapshot: SnapshotAssertion,
agent_id: str | None,
config_entry_options: {},
@ -58,8 +49,6 @@ async def test_default_prompt(
"""Test that the default prompt works."""
entry = MockConfigEntry(title=None)
entry.add_to_hass(hass)
for i in range(3):
area_registry.async_create(f"{i}Empty Area")
if agent_id is None:
agent_id = mock_config_entry.entry_id
@ -68,115 +57,6 @@ async def test_default_prompt(
mock_config_entry,
options={**mock_config_entry.options, **config_entry_options},
)
entities = []
def create_entity(device: dr.DeviceEntry) -> None:
"""Create an entity for a device and track entity_id."""
entity = entity_registry.async_get_or_create(
"light",
"test",
device.id,
device_id=device.id,
original_name=str(device.name),
suggested_object_id=str(device.name),
)
entity.write_unavailable_state(hass)
entities.append(entity.entity_id)
create_entity(
device_registry.async_get_or_create(
config_entry_id=entry.entry_id,
connections={("test", "1234")},
name="Test Device",
manufacturer="Test Manufacturer",
model="Test Model",
suggested_area="Test Area",
)
)
for i in range(3):
create_entity(
device_registry.async_get_or_create(
config_entry_id=entry.entry_id,
connections={("test", f"{i}abcd")},
name="Test Service",
manufacturer="Test Manufacturer",
model="Test Model",
suggested_area="Test Area",
entry_type=dr.DeviceEntryType.SERVICE,
)
)
create_entity(
device_registry.async_get_or_create(
config_entry_id=entry.entry_id,
connections={("test", "5678")},
name="Test Device 2",
manufacturer="Test Manufacturer 2",
model="Device 2",
suggested_area="Test Area 2",
)
)
create_entity(
device_registry.async_get_or_create(
config_entry_id=entry.entry_id,
connections={("test", "9876")},
name="Test Device 3",
manufacturer="Test Manufacturer 3",
model="Test Model 3A",
suggested_area="Test Area 2",
)
)
create_entity(
device_registry.async_get_or_create(
config_entry_id=entry.entry_id,
connections={("test", "qwer")},
name="Test Device 4",
suggested_area="Test Area 2",
)
)
device = device_registry.async_get_or_create(
config_entry_id=entry.entry_id,
connections={("test", "9876-disabled")},
name="Test Device 3",
manufacturer="Test Manufacturer 3",
model="Test Model 3A",
suggested_area="Test Area 2",
)
device_registry.async_update_device(
device.id, disabled_by=dr.DeviceEntryDisabler.USER
)
create_entity(device)
create_entity(
device_registry.async_get_or_create(
config_entry_id=entry.entry_id,
connections={("test", "9876-no-name")},
manufacturer="Test Manufacturer NoName",
model="Test Model NoName",
suggested_area="Test Area 2",
)
)
create_entity(
device_registry.async_get_or_create(
config_entry_id=entry.entry_id,
connections={("test", "9876-integer-values")},
name=1,
manufacturer=2,
model=3,
suggested_area="Test Area 2",
)
)
# Set options for registered entities
ws_client = await hass_ws_client(hass)
await ws_client.send_json_auto_id(
{
"type": "homeassistant/expose_entity",
"assistants": ["conversation"],
"entity_ids": entities,
"should_expose": True,
}
)
response = await ws_client.receive_json()
assert response["success"]
with (
patch("google.generativeai.GenerativeModel") as mock_model,
@ -184,6 +64,14 @@ async def test_default_prompt(
"homeassistant.components.google_generative_ai_conversation.conversation.llm.AssistAPI.async_get_tools",
return_value=[],
) as mock_get_tools,
patch(
"homeassistant.components.google_generative_ai_conversation.conversation.llm.AssistAPI.async_get_api_prompt",
return_value="<api_prompt>",
),
patch(
"homeassistant.components.google_generative_ai_conversation.conversation.llm.async_render_no_api_prompt",
return_value="<no_api_prompt>",
),
):
mock_chat = AsyncMock()
mock_model.return_value.start_chat.return_value = mock_chat
@ -268,7 +156,7 @@ async def test_function_call(
mock_config_entry_with_assist: MockConfigEntry,
mock_init_component,
) -> None:
"""Test that the default prompt works."""
"""Test function calling."""
agent_id = mock_config_entry_with_assist.entry_id
context = Context()
@ -366,7 +254,7 @@ async def test_function_exception(
mock_config_entry_with_assist: MockConfigEntry,
mock_init_component,
) -> None:
"""Test that the default prompt works."""
"""Test exception in function calling."""
agent_id = mock_config_entry_with_assist.entry_id
context = Context()