mirror of
https://github.com/home-assistant/core.git
synced 2025-07-23 21:27:38 +00:00
Improve conversation typing (#136084)
This commit is contained in:
parent
a7d5e52ffe
commit
29b7d5c2e4
@ -1,11 +1,13 @@
|
|||||||
"""Conversation history."""
|
"""Conversation history."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
from collections.abc import AsyncGenerator
|
from collections.abc import AsyncGenerator
|
||||||
from contextlib import asynccontextmanager
|
from contextlib import asynccontextmanager
|
||||||
from dataclasses import dataclass, field, replace
|
from dataclasses import dataclass, field, replace
|
||||||
from datetime import datetime, timedelta
|
from datetime import datetime, timedelta
|
||||||
import logging
|
import logging
|
||||||
from typing import Generic, Literal, TypeVar
|
from typing import Literal
|
||||||
|
|
||||||
from homeassistant.const import EVENT_HOMEASSISTANT_STOP
|
from homeassistant.const import EVENT_HOMEASSISTANT_STOP
|
||||||
from homeassistant.core import (
|
from homeassistant.core import (
|
||||||
@ -25,16 +27,15 @@ from homeassistant.util.hass_dict import HassKey
|
|||||||
from .const import DOMAIN
|
from .const import DOMAIN
|
||||||
from .models import ConversationInput, ConversationResult
|
from .models import ConversationInput, ConversationResult
|
||||||
|
|
||||||
DATA_CHAT_HISTORY: HassKey["dict[str, ChatSession]"] = HassKey(
|
DATA_CHAT_HISTORY: HassKey[dict[str, ChatSession]] = HassKey(
|
||||||
"conversation_chat_session"
|
"conversation_chat_session"
|
||||||
)
|
)
|
||||||
DATA_CHAT_HISTORY_CLEANUP: HassKey["SessionCleanup"] = HassKey(
|
DATA_CHAT_HISTORY_CLEANUP: HassKey[SessionCleanup] = HassKey(
|
||||||
"conversation_chat_session_cleanup"
|
"conversation_chat_session_cleanup"
|
||||||
)
|
)
|
||||||
|
|
||||||
LOGGER = logging.getLogger(__name__)
|
LOGGER = logging.getLogger(__name__)
|
||||||
CONVERSATION_TIMEOUT = timedelta(minutes=5)
|
CONVERSATION_TIMEOUT = timedelta(minutes=5)
|
||||||
_NativeT = TypeVar("_NativeT")
|
|
||||||
|
|
||||||
|
|
||||||
class SessionCleanup:
|
class SessionCleanup:
|
||||||
@ -89,7 +90,7 @@ class SessionCleanup:
|
|||||||
async def async_get_chat_session(
|
async def async_get_chat_session(
|
||||||
hass: HomeAssistant,
|
hass: HomeAssistant,
|
||||||
user_input: ConversationInput,
|
user_input: ConversationInput,
|
||||||
) -> AsyncGenerator["ChatSession"]:
|
) -> AsyncGenerator[ChatSession]:
|
||||||
"""Return chat session."""
|
"""Return chat session."""
|
||||||
all_history = hass.data.get(DATA_CHAT_HISTORY)
|
all_history = hass.data.get(DATA_CHAT_HISTORY)
|
||||||
if all_history is None:
|
if all_history is None:
|
||||||
@ -164,7 +165,7 @@ class ConverseError(HomeAssistantError):
|
|||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class ChatMessage(Generic[_NativeT]):
|
class ChatMessage[_NativeT]:
|
||||||
"""Base class for chat messages.
|
"""Base class for chat messages.
|
||||||
|
|
||||||
When role is native, the content is to be ignored and message
|
When role is native, the content is to be ignored and message
|
||||||
@ -184,7 +185,7 @@ class ChatMessage(Generic[_NativeT]):
|
|||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class ChatSession(Generic[_NativeT]):
|
class ChatSession[_NativeT]:
|
||||||
"""Class holding all information for a specific conversation."""
|
"""Class holding all information for a specific conversation."""
|
||||||
|
|
||||||
hass: HomeAssistant
|
hass: HomeAssistant
|
||||||
|
Loading…
x
Reference in New Issue
Block a user