Handle generic commands as area commands in the LLM Assist API (#118276)

* Handle generic commands as area commands in the LLM Assist API

* Add word area
This commit is contained in:
Paulus Schoutsen 2024-05-28 11:21:17 -04:00 committed by GitHub
parent dbcef2e3c3
commit f0d7f48930
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 35 additions and 12 deletions

View File

@ -231,23 +231,34 @@ class AssistAPI(API):
prompt = [
(
"Call the intent tools to control Home Assistant. "
"Just pass the name to the intent. "
"When controlling an area, prefer passing area name."
)
]
area: ar.AreaEntry | None = None
floor: fr.FloorEntry | None = None
if tool_input.device_id:
device_reg = dr.async_get(self.hass)
device = device_reg.async_get(tool_input.device_id)
if device:
area_reg = ar.async_get(self.hass)
if device.area_id and (area := area_reg.async_get_area(device.area_id)):
floor_reg = fr.async_get(self.hass)
if area.floor_id and (
floor := floor_reg.async_get_floor(area.floor_id)
):
prompt.append(f"You are in {area.name} ({floor.name}).")
else:
prompt.append(f"You are in {area.name}.")
if area.floor_id:
floor = floor_reg.async_get_floor(area.floor_id)
extra = "and all generic commands like 'turn on the lights' should target this area."
if floor and area:
prompt.append(f"You are in area {area.name} (floor {floor.name}) {extra}")
elif area:
prompt.append(f"You are in area {area.name} {extra}")
else:
prompt.append(
"Reject all generic commands like 'turn on the lights' because we "
"don't know in what area this conversation is happening."
)
if tool_input.context and tool_input.context.user_id:
user = await self.hass.auth.async_get_user(tool_input.context.user_id)
if user:

View File

@ -371,32 +371,44 @@ async def test_assist_api_prompt(
)
first_part_prompt = (
"Call the intent tools to control Home Assistant. "
"Just pass the name to the intent. "
"When controlling an area, prefer passing area name."
)
prompt = await api.async_get_api_prompt(tool_input)
area_prompt = (
"Reject all generic commands like 'turn on the lights' because we don't know in what area "
"this conversation is happening."
)
assert prompt == (
f"""{first_part_prompt}
{area_prompt}
{exposed_entities_prompt}"""
)
# Fake that request is made from a specific device ID
tool_input.device_id = device.id
prompt = await api.async_get_api_prompt(tool_input)
area_prompt = (
"You are in area Test Area and all generic commands like 'turn on the lights' "
"should target this area."
)
assert prompt == (
f"""{first_part_prompt}
You are in Test Area.
{area_prompt}
{exposed_entities_prompt}"""
)
# Add floor
floor = floor_registry.async_create("second floor")
floor = floor_registry.async_create("2")
area_registry.async_update(area.id, floor_id=floor.floor_id)
prompt = await api.async_get_api_prompt(tool_input)
area_prompt = (
"You are in area Test Area (floor 2) and all generic commands like 'turn on the lights' "
"should target this area."
)
assert prompt == (
f"""{first_part_prompt}
You are in Test Area (second floor).
{area_prompt}
{exposed_entities_prompt}"""
)
@ -409,7 +421,7 @@ You are in Test Area (second floor).
prompt = await api.async_get_api_prompt(tool_input)
assert prompt == (
f"""{first_part_prompt}
You are in Test Area (second floor).
{area_prompt}
The user name is Test User.
{exposed_entities_prompt}"""
)