mirror of
https://github.com/home-assistant/core.git
synced 2025-04-25 17:57:55 +00:00
Stronger type annotations for conversation content (#140725)
stronger type annotations for conversation content
This commit is contained in:
parent
012b4645f3
commit
056616f9c5
@ -51,8 +51,7 @@ def async_get_chat_log(
|
|||||||
)
|
)
|
||||||
if user_input is not None and (
|
if user_input is not None and (
|
||||||
(content := chat_log.content[-1]).role != "user"
|
(content := chat_log.content[-1]).role != "user"
|
||||||
# MyPy doesn't understand that content is a UserContent here
|
or content.content != user_input.text
|
||||||
or content.content != user_input.text # type: ignore[union-attr]
|
|
||||||
):
|
):
|
||||||
chat_log.async_add_user_content(UserContent(content=user_input.text))
|
chat_log.async_add_user_content(UserContent(content=user_input.text))
|
||||||
|
|
||||||
@ -128,7 +127,7 @@ class ConverseError(HomeAssistantError):
|
|||||||
class SystemContent:
|
class SystemContent:
|
||||||
"""Base class for chat messages."""
|
"""Base class for chat messages."""
|
||||||
|
|
||||||
role: str = field(init=False, default="system")
|
role: Literal["system"] = field(init=False, default="system")
|
||||||
content: str
|
content: str
|
||||||
|
|
||||||
|
|
||||||
@ -136,7 +135,7 @@ class SystemContent:
|
|||||||
class UserContent:
|
class UserContent:
|
||||||
"""Assistant content."""
|
"""Assistant content."""
|
||||||
|
|
||||||
role: str = field(init=False, default="user")
|
role: Literal["user"] = field(init=False, default="user")
|
||||||
content: str
|
content: str
|
||||||
|
|
||||||
|
|
||||||
@ -144,7 +143,7 @@ class UserContent:
|
|||||||
class AssistantContent:
|
class AssistantContent:
|
||||||
"""Assistant content."""
|
"""Assistant content."""
|
||||||
|
|
||||||
role: str = field(init=False, default="assistant")
|
role: Literal["assistant"] = field(init=False, default="assistant")
|
||||||
agent_id: str
|
agent_id: str
|
||||||
content: str | None = None
|
content: str | None = None
|
||||||
tool_calls: list[llm.ToolInput] | None = None
|
tool_calls: list[llm.ToolInput] | None = None
|
||||||
@ -154,7 +153,7 @@ class AssistantContent:
|
|||||||
class ToolResultContent:
|
class ToolResultContent:
|
||||||
"""Tool result content."""
|
"""Tool result content."""
|
||||||
|
|
||||||
role: str = field(init=False, default="tool_result")
|
role: Literal["tool_result"] = field(init=False, default="tool_result")
|
||||||
agent_id: str
|
agent_id: str
|
||||||
tool_call_id: str
|
tool_call_id: str
|
||||||
tool_name: str
|
tool_name: str
|
||||||
@ -193,8 +192,8 @@ class ChatLog:
|
|||||||
|
|
||||||
return (
|
return (
|
||||||
last_msg.role == "assistant"
|
last_msg.role == "assistant"
|
||||||
and last_msg.content is not None # type: ignore[union-attr]
|
and last_msg.content is not None
|
||||||
and last_msg.content.strip().endswith( # type: ignore[union-attr]
|
and last_msg.content.strip().endswith(
|
||||||
(
|
(
|
||||||
"?",
|
"?",
|
||||||
";", # Greek question mark
|
";", # Greek question mark
|
||||||
|
@ -188,7 +188,7 @@ def _convert_content(
|
|||||||
| conversation.SystemContent,
|
| conversation.SystemContent,
|
||||||
) -> Content:
|
) -> Content:
|
||||||
"""Convert HA content to Google content."""
|
"""Convert HA content to Google content."""
|
||||||
if content.role != "assistant" or not content.tool_calls: # type: ignore[union-attr]
|
if content.role != "assistant" or not content.tool_calls:
|
||||||
role = "model" if content.role == "assistant" else content.role
|
role = "model" if content.role == "assistant" else content.role
|
||||||
return Content(
|
return Content(
|
||||||
role=role,
|
role=role,
|
||||||
@ -321,24 +321,14 @@ class GoogleGenerativeAIConversationEntity(
|
|||||||
|
|
||||||
for chat_content in chat_log.content[1:-1]:
|
for chat_content in chat_log.content[1:-1]:
|
||||||
if chat_content.role == "tool_result":
|
if chat_content.role == "tool_result":
|
||||||
# mypy doesn't like picking a type based on checking shared property 'role'
|
tool_results.append(chat_content)
|
||||||
tool_results.append(cast(conversation.ToolResultContent, chat_content))
|
|
||||||
continue
|
continue
|
||||||
|
|
||||||
if tool_results:
|
if tool_results:
|
||||||
messages.append(_create_google_tool_response_content(tool_results))
|
messages.append(_create_google_tool_response_content(tool_results))
|
||||||
tool_results.clear()
|
tool_results.clear()
|
||||||
|
|
||||||
messages.append(
|
messages.append(_convert_content(chat_content))
|
||||||
_convert_content(
|
|
||||||
cast(
|
|
||||||
conversation.UserContent
|
|
||||||
| conversation.SystemContent
|
|
||||||
| conversation.AssistantContent,
|
|
||||||
chat_content,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
if tool_results:
|
if tool_results:
|
||||||
messages.append(_create_google_tool_response_content(tool_results))
|
messages.append(_create_google_tool_response_content(tool_results))
|
||||||
|
@ -82,13 +82,13 @@ def _convert_content_to_param(
|
|||||||
tool_call_id=content.tool_call_id,
|
tool_call_id=content.tool_call_id,
|
||||||
content=json.dumps(content.tool_result),
|
content=json.dumps(content.tool_result),
|
||||||
)
|
)
|
||||||
if content.role != "assistant" or not content.tool_calls: # type: ignore[union-attr]
|
if content.role != "assistant" or not content.tool_calls:
|
||||||
role = content.role
|
role: Literal["system", "user", "assistant", "developer"] = content.role
|
||||||
if role == "system":
|
if role == "system":
|
||||||
role = "developer"
|
role = "developer"
|
||||||
return cast(
|
return cast(
|
||||||
ChatCompletionMessageParam,
|
ChatCompletionMessageParam,
|
||||||
{"role": content.role, "content": content.content}, # type: ignore[union-attr]
|
{"role": content.role, "content": content.content},
|
||||||
)
|
)
|
||||||
|
|
||||||
# Handle the Assistant content including tool calls.
|
# Handle the Assistant content including tool calls.
|
||||||
|
Loading…
x
Reference in New Issue
Block a user