Mock llm prompts in test_default_prompt for Google Generative AI (#118286)

This commit is contained in:
tronikos 2024-05-27 21:40:26 -07:00 committed by GitHub
parent 69a177e864
commit 4f7a91828e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 15 additions and 211 deletions

View File

@ -150,7 +150,7 @@
Answer in plain text. Keep it simple and to the point. Answer in plain text. Keep it simple and to the point.
The current time is 05:00:00. The current time is 05:00:00.
Today's date is 05/24/24. Today's date is 05/24/24.
Only if the user wants to control a device, tell them to edit the AI configuration and allow access to Home Assistant. <no_api_prompt>
''', ''',
'role': 'user', 'role': 'user',
}), }),
@ -206,7 +206,7 @@
Answer in plain text. Keep it simple and to the point. Answer in plain text. Keep it simple and to the point.
The current time is 05:00:00. The current time is 05:00:00.
Today's date is 05/24/24. Today's date is 05/24/24.
Only if the user wants to control a device, tell them to edit the AI configuration and allow access to Home Assistant. <no_api_prompt>
''', ''',
'role': 'user', 'role': 'user',
}), }),
@ -262,49 +262,7 @@
Answer in plain text. Keep it simple and to the point. Answer in plain text. Keep it simple and to the point.
The current time is 05:00:00. The current time is 05:00:00.
Today's date is 05/24/24. Today's date is 05/24/24.
Call the intent tools to control Home Assistant. Just pass the name to the intent. When controlling an area, prefer passing area name. <api_prompt>
An overview of the areas and the devices in this smart home:
light.test_device:
names: Test Device
state: unavailable
areas: Test Area
light.test_service:
names: Test Service
state: unavailable
areas: Test Area
light.test_service_2:
names: Test Service
state: unavailable
areas: Test Area
light.test_service_3:
names: Test Service
state: unavailable
areas: Test Area
light.test_device_2:
names: Test Device 2
state: unavailable
areas: Test Area 2
light.test_device_3:
names: Test Device 3
state: unavailable
areas: Test Area 2
light.test_device_4:
names: Test Device 4
state: unavailable
areas: Test Area 2
light.test_device_3_2:
names: Test Device 3
state: unavailable
areas: Test Area 2
light.none:
names: None
state: unavailable
areas: Test Area 2
light.1:
names: '1'
state: unavailable
areas: Test Area 2
''', ''',
'role': 'user', 'role': 'user',
}), }),
@ -360,49 +318,7 @@
Answer in plain text. Keep it simple and to the point. Answer in plain text. Keep it simple and to the point.
The current time is 05:00:00. The current time is 05:00:00.
Today's date is 05/24/24. Today's date is 05/24/24.
Call the intent tools to control Home Assistant. Just pass the name to the intent. When controlling an area, prefer passing area name. <api_prompt>
An overview of the areas and the devices in this smart home:
light.test_device:
names: Test Device
state: unavailable
areas: Test Area
light.test_service:
names: Test Service
state: unavailable
areas: Test Area
light.test_service_2:
names: Test Service
state: unavailable
areas: Test Area
light.test_service_3:
names: Test Service
state: unavailable
areas: Test Area
light.test_device_2:
names: Test Device 2
state: unavailable
areas: Test Area 2
light.test_device_3:
names: Test Device 3
state: unavailable
areas: Test Area 2
light.test_device_4:
names: Test Device 4
state: unavailable
areas: Test Area 2
light.test_device_3_2:
names: Test Device 3
state: unavailable
areas: Test Area 2
light.none:
names: None
state: unavailable
areas: Test Area 2
light.1:
names: '1'
state: unavailable
areas: Test Area 2
''', ''',
'role': 'user', 'role': 'user',
}), }),

View File

@ -14,13 +14,7 @@ from homeassistant.components.conversation import trace
from homeassistant.const import CONF_LLM_HASS_API from homeassistant.const import CONF_LLM_HASS_API
from homeassistant.core import Context, HomeAssistant from homeassistant.core import Context, HomeAssistant
from homeassistant.exceptions import HomeAssistantError from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers import ( from homeassistant.helpers import intent, llm
area_registry as ar,
device_registry as dr,
entity_registry as er,
intent,
llm,
)
from tests.common import MockConfigEntry from tests.common import MockConfigEntry
from tests.typing import WebSocketGenerator from tests.typing import WebSocketGenerator
@ -47,9 +41,6 @@ async def test_default_prompt(
hass: HomeAssistant, hass: HomeAssistant,
mock_config_entry: MockConfigEntry, mock_config_entry: MockConfigEntry,
mock_init_component, mock_init_component,
area_registry: ar.AreaRegistry,
device_registry: dr.DeviceRegistry,
entity_registry: er.EntityRegistry,
snapshot: SnapshotAssertion, snapshot: SnapshotAssertion,
agent_id: str | None, agent_id: str | None,
config_entry_options: {}, config_entry_options: {},
@ -58,8 +49,6 @@ async def test_default_prompt(
"""Test that the default prompt works.""" """Test that the default prompt works."""
entry = MockConfigEntry(title=None) entry = MockConfigEntry(title=None)
entry.add_to_hass(hass) entry.add_to_hass(hass)
for i in range(3):
area_registry.async_create(f"{i}Empty Area")
if agent_id is None: if agent_id is None:
agent_id = mock_config_entry.entry_id agent_id = mock_config_entry.entry_id
@ -68,115 +57,6 @@ async def test_default_prompt(
mock_config_entry, mock_config_entry,
options={**mock_config_entry.options, **config_entry_options}, options={**mock_config_entry.options, **config_entry_options},
) )
entities = []
def create_entity(device: dr.DeviceEntry) -> None:
"""Create an entity for a device and track entity_id."""
entity = entity_registry.async_get_or_create(
"light",
"test",
device.id,
device_id=device.id,
original_name=str(device.name),
suggested_object_id=str(device.name),
)
entity.write_unavailable_state(hass)
entities.append(entity.entity_id)
create_entity(
device_registry.async_get_or_create(
config_entry_id=entry.entry_id,
connections={("test", "1234")},
name="Test Device",
manufacturer="Test Manufacturer",
model="Test Model",
suggested_area="Test Area",
)
)
for i in range(3):
create_entity(
device_registry.async_get_or_create(
config_entry_id=entry.entry_id,
connections={("test", f"{i}abcd")},
name="Test Service",
manufacturer="Test Manufacturer",
model="Test Model",
suggested_area="Test Area",
entry_type=dr.DeviceEntryType.SERVICE,
)
)
create_entity(
device_registry.async_get_or_create(
config_entry_id=entry.entry_id,
connections={("test", "5678")},
name="Test Device 2",
manufacturer="Test Manufacturer 2",
model="Device 2",
suggested_area="Test Area 2",
)
)
create_entity(
device_registry.async_get_or_create(
config_entry_id=entry.entry_id,
connections={("test", "9876")},
name="Test Device 3",
manufacturer="Test Manufacturer 3",
model="Test Model 3A",
suggested_area="Test Area 2",
)
)
create_entity(
device_registry.async_get_or_create(
config_entry_id=entry.entry_id,
connections={("test", "qwer")},
name="Test Device 4",
suggested_area="Test Area 2",
)
)
device = device_registry.async_get_or_create(
config_entry_id=entry.entry_id,
connections={("test", "9876-disabled")},
name="Test Device 3",
manufacturer="Test Manufacturer 3",
model="Test Model 3A",
suggested_area="Test Area 2",
)
device_registry.async_update_device(
device.id, disabled_by=dr.DeviceEntryDisabler.USER
)
create_entity(device)
create_entity(
device_registry.async_get_or_create(
config_entry_id=entry.entry_id,
connections={("test", "9876-no-name")},
manufacturer="Test Manufacturer NoName",
model="Test Model NoName",
suggested_area="Test Area 2",
)
)
create_entity(
device_registry.async_get_or_create(
config_entry_id=entry.entry_id,
connections={("test", "9876-integer-values")},
name=1,
manufacturer=2,
model=3,
suggested_area="Test Area 2",
)
)
# Set options for registered entities
ws_client = await hass_ws_client(hass)
await ws_client.send_json_auto_id(
{
"type": "homeassistant/expose_entity",
"assistants": ["conversation"],
"entity_ids": entities,
"should_expose": True,
}
)
response = await ws_client.receive_json()
assert response["success"]
with ( with (
patch("google.generativeai.GenerativeModel") as mock_model, patch("google.generativeai.GenerativeModel") as mock_model,
@ -184,6 +64,14 @@ async def test_default_prompt(
"homeassistant.components.google_generative_ai_conversation.conversation.llm.AssistAPI.async_get_tools", "homeassistant.components.google_generative_ai_conversation.conversation.llm.AssistAPI.async_get_tools",
return_value=[], return_value=[],
) as mock_get_tools, ) as mock_get_tools,
patch(
"homeassistant.components.google_generative_ai_conversation.conversation.llm.AssistAPI.async_get_api_prompt",
return_value="<api_prompt>",
),
patch(
"homeassistant.components.google_generative_ai_conversation.conversation.llm.async_render_no_api_prompt",
return_value="<no_api_prompt>",
),
): ):
mock_chat = AsyncMock() mock_chat = AsyncMock()
mock_model.return_value.start_chat.return_value = mock_chat mock_model.return_value.start_chat.return_value = mock_chat
@ -268,7 +156,7 @@ async def test_function_call(
mock_config_entry_with_assist: MockConfigEntry, mock_config_entry_with_assist: MockConfigEntry,
mock_init_component, mock_init_component,
) -> None: ) -> None:
"""Test that the default prompt works.""" """Test function calling."""
agent_id = mock_config_entry_with_assist.entry_id agent_id = mock_config_entry_with_assist.entry_id
context = Context() context = Context()
@ -366,7 +254,7 @@ async def test_function_exception(
mock_config_entry_with_assist: MockConfigEntry, mock_config_entry_with_assist: MockConfigEntry,
mock_init_component, mock_init_component,
) -> None: ) -> None:
"""Test that the default prompt works.""" """Test exception in function calling."""
agent_id = mock_config_entry_with_assist.entry_id agent_id = mock_config_entry_with_assist.entry_id
context = Context() context = Context()