Stronger type annotations for conversation content (#140725)

stronger type annotations for conversation content
This commit is contained in:
Denis Shulyaka 2025-03-16 17:59:25 +03:00 committed by GitHub
parent 012b4645f3
commit 056616f9c5
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 13 additions and 24 deletions

View File

@ -51,8 +51,7 @@ def async_get_chat_log(
)
if user_input is not None and (
(content := chat_log.content[-1]).role != "user"
# MyPy doesn't understand that content is a UserContent here
or content.content != user_input.text # type: ignore[union-attr]
or content.content != user_input.text
):
chat_log.async_add_user_content(UserContent(content=user_input.text))
@ -128,7 +127,7 @@ class ConverseError(HomeAssistantError):
class SystemContent:
"""Base class for chat messages."""
role: str = field(init=False, default="system")
role: Literal["system"] = field(init=False, default="system")
content: str
@ -136,7 +135,7 @@ class SystemContent:
class UserContent:
"""Assistant content."""
role: str = field(init=False, default="user")
role: Literal["user"] = field(init=False, default="user")
content: str
@ -144,7 +143,7 @@ class UserContent:
class AssistantContent:
"""Assistant content."""
role: str = field(init=False, default="assistant")
role: Literal["assistant"] = field(init=False, default="assistant")
agent_id: str
content: str | None = None
tool_calls: list[llm.ToolInput] | None = None
@ -154,7 +153,7 @@ class AssistantContent:
class ToolResultContent:
"""Tool result content."""
role: str = field(init=False, default="tool_result")
role: Literal["tool_result"] = field(init=False, default="tool_result")
agent_id: str
tool_call_id: str
tool_name: str
@ -193,8 +192,8 @@ class ChatLog:
return (
last_msg.role == "assistant"
and last_msg.content is not None # type: ignore[union-attr]
and last_msg.content.strip().endswith( # type: ignore[union-attr]
and last_msg.content is not None
and last_msg.content.strip().endswith(
(
"?",
";", # Greek question mark

View File

@ -188,7 +188,7 @@ def _convert_content(
| conversation.SystemContent,
) -> Content:
"""Convert HA content to Google content."""
if content.role != "assistant" or not content.tool_calls: # type: ignore[union-attr]
if content.role != "assistant" or not content.tool_calls:
role = "model" if content.role == "assistant" else content.role
return Content(
role=role,
@ -321,24 +321,14 @@ class GoogleGenerativeAIConversationEntity(
for chat_content in chat_log.content[1:-1]:
if chat_content.role == "tool_result":
# mypy doesn't like picking a type based on checking shared property 'role'
tool_results.append(cast(conversation.ToolResultContent, chat_content))
tool_results.append(chat_content)
continue
if tool_results:
messages.append(_create_google_tool_response_content(tool_results))
tool_results.clear()
messages.append(
_convert_content(
cast(
conversation.UserContent
| conversation.SystemContent
| conversation.AssistantContent,
chat_content,
)
)
)
messages.append(_convert_content(chat_content))
if tool_results:
messages.append(_create_google_tool_response_content(tool_results))

View File

@ -82,13 +82,13 @@ def _convert_content_to_param(
tool_call_id=content.tool_call_id,
content=json.dumps(content.tool_result),
)
if content.role != "assistant" or not content.tool_calls: # type: ignore[union-attr]
role = content.role
if content.role != "assistant" or not content.tool_calls:
role: Literal["system", "user", "assistant", "developer"] = content.role
if role == "system":
role = "developer"
return cast(
ChatCompletionMessageParam,
{"role": content.role, "content": content.content}, # type: ignore[union-attr]
{"role": content.role, "content": content.content},
)
# Handle the Assistant content including tool calls.