Handle generic commands as area commands in the LLM Assist API (#118276)

* Handle generic commands as area commands in the LLM Assist API

* Add word area
This commit is contained in:
Paulus Schoutsen 2024-05-28 11:21:17 -04:00 committed by GitHub
parent dbcef2e3c3
commit f0d7f48930
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 35 additions and 12 deletions

View File

@ -231,23 +231,34 @@ class AssistAPI(API):
prompt = [ prompt = [
( (
"Call the intent tools to control Home Assistant. " "Call the intent tools to control Home Assistant. "
"Just pass the name to the intent. "
"When controlling an area, prefer passing area name." "When controlling an area, prefer passing area name."
) )
] ]
area: ar.AreaEntry | None = None
floor: fr.FloorEntry | None = None
if tool_input.device_id: if tool_input.device_id:
device_reg = dr.async_get(self.hass) device_reg = dr.async_get(self.hass)
device = device_reg.async_get(tool_input.device_id) device = device_reg.async_get(tool_input.device_id)
if device: if device:
area_reg = ar.async_get(self.hass) area_reg = ar.async_get(self.hass)
if device.area_id and (area := area_reg.async_get_area(device.area_id)): if device.area_id and (area := area_reg.async_get_area(device.area_id)):
floor_reg = fr.async_get(self.hass) floor_reg = fr.async_get(self.hass)
if area.floor_id and ( if area.floor_id:
floor := floor_reg.async_get_floor(area.floor_id) floor = floor_reg.async_get_floor(area.floor_id)
):
prompt.append(f"You are in {area.name} ({floor.name}).") extra = "and all generic commands like 'turn on the lights' should target this area."
if floor and area:
prompt.append(f"You are in area {area.name} (floor {floor.name}) {extra}")
elif area:
prompt.append(f"You are in area {area.name} {extra}")
else: else:
prompt.append(f"You are in {area.name}.") prompt.append(
"Reject all generic commands like 'turn on the lights' because we "
"don't know in what area this conversation is happening."
)
if tool_input.context and tool_input.context.user_id: if tool_input.context and tool_input.context.user_id:
user = await self.hass.auth.async_get_user(tool_input.context.user_id) user = await self.hass.auth.async_get_user(tool_input.context.user_id)
if user: if user:

View File

@ -371,32 +371,44 @@ async def test_assist_api_prompt(
) )
first_part_prompt = ( first_part_prompt = (
"Call the intent tools to control Home Assistant. " "Call the intent tools to control Home Assistant. "
"Just pass the name to the intent. "
"When controlling an area, prefer passing area name." "When controlling an area, prefer passing area name."
) )
prompt = await api.async_get_api_prompt(tool_input) prompt = await api.async_get_api_prompt(tool_input)
area_prompt = (
"Reject all generic commands like 'turn on the lights' because we don't know in what area "
"this conversation is happening."
)
assert prompt == ( assert prompt == (
f"""{first_part_prompt} f"""{first_part_prompt}
{area_prompt}
{exposed_entities_prompt}""" {exposed_entities_prompt}"""
) )
# Fake that request is made from a specific device ID # Fake that request is made from a specific device ID
tool_input.device_id = device.id tool_input.device_id = device.id
prompt = await api.async_get_api_prompt(tool_input) prompt = await api.async_get_api_prompt(tool_input)
area_prompt = (
"You are in area Test Area and all generic commands like 'turn on the lights' "
"should target this area."
)
assert prompt == ( assert prompt == (
f"""{first_part_prompt} f"""{first_part_prompt}
You are in Test Area. {area_prompt}
{exposed_entities_prompt}""" {exposed_entities_prompt}"""
) )
# Add floor # Add floor
floor = floor_registry.async_create("second floor") floor = floor_registry.async_create("2")
area_registry.async_update(area.id, floor_id=floor.floor_id) area_registry.async_update(area.id, floor_id=floor.floor_id)
prompt = await api.async_get_api_prompt(tool_input) prompt = await api.async_get_api_prompt(tool_input)
area_prompt = (
"You are in area Test Area (floor 2) and all generic commands like 'turn on the lights' "
"should target this area."
)
assert prompt == ( assert prompt == (
f"""{first_part_prompt} f"""{first_part_prompt}
You are in Test Area (second floor). {area_prompt}
{exposed_entities_prompt}""" {exposed_entities_prompt}"""
) )
@ -409,7 +421,7 @@ You are in Test Area (second floor).
prompt = await api.async_get_api_prompt(tool_input) prompt = await api.async_get_api_prompt(tool_input)
assert prompt == ( assert prompt == (
f"""{first_part_prompt} f"""{first_part_prompt}
You are in Test Area (second floor). {area_prompt}
The user name is Test User. The user name is Test User.
{exposed_entities_prompt}""" {exposed_entities_prompt}"""
) )