Stronger type annotations for conversation content (#140725)

stronger type annotations for conversation content
This commit is contained in:
Denis Shulyaka 2025-03-16 17:59:25 +03:00 committed by GitHub
parent 012b4645f3
commit 056616f9c5
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 13 additions and 24 deletions

View File

@ -51,8 +51,7 @@ def async_get_chat_log(
) )
if user_input is not None and ( if user_input is not None and (
(content := chat_log.content[-1]).role != "user" (content := chat_log.content[-1]).role != "user"
# MyPy doesn't understand that content is a UserContent here or content.content != user_input.text
or content.content != user_input.text # type: ignore[union-attr]
): ):
chat_log.async_add_user_content(UserContent(content=user_input.text)) chat_log.async_add_user_content(UserContent(content=user_input.text))
@ -128,7 +127,7 @@ class ConverseError(HomeAssistantError):
class SystemContent: class SystemContent:
"""Base class for chat messages.""" """Base class for chat messages."""
role: str = field(init=False, default="system") role: Literal["system"] = field(init=False, default="system")
content: str content: str
@ -136,7 +135,7 @@ class SystemContent:
class UserContent: class UserContent:
"""Assistant content.""" """Assistant content."""
role: str = field(init=False, default="user") role: Literal["user"] = field(init=False, default="user")
content: str content: str
@ -144,7 +143,7 @@ class UserContent:
class AssistantContent: class AssistantContent:
"""Assistant content.""" """Assistant content."""
role: str = field(init=False, default="assistant") role: Literal["assistant"] = field(init=False, default="assistant")
agent_id: str agent_id: str
content: str | None = None content: str | None = None
tool_calls: list[llm.ToolInput] | None = None tool_calls: list[llm.ToolInput] | None = None
@ -154,7 +153,7 @@ class AssistantContent:
class ToolResultContent: class ToolResultContent:
"""Tool result content.""" """Tool result content."""
role: str = field(init=False, default="tool_result") role: Literal["tool_result"] = field(init=False, default="tool_result")
agent_id: str agent_id: str
tool_call_id: str tool_call_id: str
tool_name: str tool_name: str
@ -193,8 +192,8 @@ class ChatLog:
return ( return (
last_msg.role == "assistant" last_msg.role == "assistant"
and last_msg.content is not None # type: ignore[union-attr] and last_msg.content is not None
and last_msg.content.strip().endswith( # type: ignore[union-attr] and last_msg.content.strip().endswith(
( (
"?", "?",
";", # Greek question mark ";", # Greek question mark

View File

@ -188,7 +188,7 @@ def _convert_content(
| conversation.SystemContent, | conversation.SystemContent,
) -> Content: ) -> Content:
"""Convert HA content to Google content.""" """Convert HA content to Google content."""
if content.role != "assistant" or not content.tool_calls: # type: ignore[union-attr] if content.role != "assistant" or not content.tool_calls:
role = "model" if content.role == "assistant" else content.role role = "model" if content.role == "assistant" else content.role
return Content( return Content(
role=role, role=role,
@ -321,24 +321,14 @@ class GoogleGenerativeAIConversationEntity(
for chat_content in chat_log.content[1:-1]: for chat_content in chat_log.content[1:-1]:
if chat_content.role == "tool_result": if chat_content.role == "tool_result":
# mypy doesn't like picking a type based on checking shared property 'role' tool_results.append(chat_content)
tool_results.append(cast(conversation.ToolResultContent, chat_content))
continue continue
if tool_results: if tool_results:
messages.append(_create_google_tool_response_content(tool_results)) messages.append(_create_google_tool_response_content(tool_results))
tool_results.clear() tool_results.clear()
messages.append( messages.append(_convert_content(chat_content))
_convert_content(
cast(
conversation.UserContent
| conversation.SystemContent
| conversation.AssistantContent,
chat_content,
)
)
)
if tool_results: if tool_results:
messages.append(_create_google_tool_response_content(tool_results)) messages.append(_create_google_tool_response_content(tool_results))

View File

@ -82,13 +82,13 @@ def _convert_content_to_param(
tool_call_id=content.tool_call_id, tool_call_id=content.tool_call_id,
content=json.dumps(content.tool_result), content=json.dumps(content.tool_result),
) )
if content.role != "assistant" or not content.tool_calls: # type: ignore[union-attr] if content.role != "assistant" or not content.tool_calls:
role = content.role role: Literal["system", "user", "assistant", "developer"] = content.role
if role == "system": if role == "system":
role = "developer" role = "developer"
return cast( return cast(
ChatCompletionMessageParam, ChatCompletionMessageParam,
{"role": content.role, "content": content.content}, # type: ignore[union-attr] {"role": content.role, "content": content.content},
) )
# Handle the Assistant content including tool calls. # Handle the Assistant content including tool calls.