commit 3ceb25adae (mirror of https://github.com/home-assistant/core.git)
Merge branch 'dev' into dev
@@ -7,6 +7,6 @@
   "integration_type": "service",
   "iot_class": "cloud_polling",
   "loggers": ["python_homeassistant_analytics"],
-  "requirements": ["python-homeassistant-analytics==0.8.1"],
+  "requirements": ["python-homeassistant-analytics==0.9.0"],
   "single_config_entry": true
 }
@@ -272,6 +272,7 @@ class AnthropicConversationEntity(
                     continue

                 tool_input = llm.ToolInput(
+                    id=tool_call.id,
                     tool_name=tool_call.name,
                     tool_args=cast(dict[str, Any], tool_call.input),
                 )
@@ -9,6 +9,7 @@ import voluptuous as vol

 from homeassistant.components import stt
 from homeassistant.core import Context, HomeAssistant
+from homeassistant.helpers import chat_session
 from homeassistant.helpers.typing import ConfigType

 from .const import (
@@ -114,24 +115,25 @@ async def async_pipeline_from_audio_stream(

     Raises PipelineNotFound if no pipeline is found.
     """
-    pipeline_input = PipelineInput(
-        conversation_id=conversation_id,
-        device_id=device_id,
-        stt_metadata=stt_metadata,
-        stt_stream=stt_stream,
-        wake_word_phrase=wake_word_phrase,
-        conversation_extra_system_prompt=conversation_extra_system_prompt,
-        run=PipelineRun(
-            hass,
-            context=context,
-            pipeline=async_get_pipeline(hass, pipeline_id=pipeline_id),
-            start_stage=start_stage,
-            end_stage=end_stage,
-            event_callback=event_callback,
-            tts_audio_output=tts_audio_output,
-            wake_word_settings=wake_word_settings,
-            audio_settings=audio_settings or AudioSettings(),
-        ),
-    )
-    await pipeline_input.validate()
-    await pipeline_input.execute()
+    with chat_session.async_get_chat_session(hass, conversation_id) as session:
+        pipeline_input = PipelineInput(
+            conversation_id=session.conversation_id,
+            device_id=device_id,
+            stt_metadata=stt_metadata,
+            stt_stream=stt_stream,
+            wake_word_phrase=wake_word_phrase,
+            conversation_extra_system_prompt=conversation_extra_system_prompt,
+            run=PipelineRun(
+                hass,
+                context=context,
+                pipeline=async_get_pipeline(hass, pipeline_id=pipeline_id),
+                start_stage=start_stage,
+                end_stage=end_stage,
+                event_callback=event_callback,
+                tts_audio_output=tts_audio_output,
+                wake_word_settings=wake_word_settings,
+                audio_settings=audio_settings or AudioSettings(),
+            ),
+        )
+        await pipeline_input.validate()
+        await pipeline_input.execute()
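The wrapper above now allocates or resumes the conversation inside chat_session.async_get_chat_session, so PipelineInput always receives a concrete conversation ID. A minimal standalone sketch of that context-manager pattern (assumed names, not the Home Assistant implementation; real session handling also expires idle sessions and runs cleanup callbacks):

    from contextlib import contextmanager
    from uuid import uuid4

    _sessions: dict[str, dict] = {}

    @contextmanager
    def get_chat_session(conversation_id: str | None = None):
        # Reuse a known conversation ID, otherwise mint a fresh one.
        if conversation_id is None or conversation_id not in _sessions:
            conversation_id = uuid4().hex
            _sessions[conversation_id] = {"conversation_id": conversation_id}
        yield _sessions[conversation_id]

    with get_chat_session(None) as session:
        print(session["conversation_id"])  # freshly minted ID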
@@ -624,7 +624,7 @@ class PipelineRun:
             return
         pipeline_data.pipeline_debug[self.pipeline.id][self.id].events.append(event)

-    def start(self, device_id: str | None) -> None:
+    def start(self, conversation_id: str, device_id: str | None) -> None:
         """Emit run start event."""
         self._device_id = device_id
         self._start_debug_recording_thread()
@@ -632,6 +632,7 @@
         data = {
             "pipeline": self.pipeline.id,
             "language": self.language,
+            "conversation_id": conversation_id,
         }
         if self.runner_data is not None:
             data["runner_data"] = self.runner_data
@@ -1015,7 +1016,7 @@ class PipelineRun:
     async def recognize_intent(
         self,
         intent_input: str,
-        conversation_id: str | None,
+        conversation_id: str,
         device_id: str | None,
         conversation_extra_system_prompt: str | None,
     ) -> str:
@@ -1063,11 +1064,11 @@
             agent_id=self.intent_agent,
             extra_system_prompt=conversation_extra_system_prompt,
         )
-        processed_locally = self.intent_agent == conversation.HOME_ASSISTANT_AGENT

-        agent_id = user_input.agent_id
+        agent_id = self.intent_agent
+        processed_locally = agent_id == conversation.HOME_ASSISTANT_AGENT
         intent_response: intent.IntentResponse | None = None
-        if user_input.agent_id != conversation.HOME_ASSISTANT_AGENT:
+        if not processed_locally:
             # Sentence triggers override conversation agent
             if (
                 trigger_response_text
@@ -1105,13 +1106,13 @@
                 speech: str = intent_response.speech.get("plain", {}).get(
                     "speech", ""
                 )
-                chat_log.async_add_message(
-                    conversation.Content(
-                        role="assistant",
+                async for _ in chat_log.async_add_assistant_content(
+                    conversation.AssistantContent(
                         agent_id=agent_id,
                         content=speech,
                     )
-                )
+                ):
+                    pass
                 conversation_result = conversation.ConversationResult(
                     response=intent_response,
                     conversation_id=session.conversation_id,
@@ -1409,12 +1410,15 @@ def _pipeline_debug_recording_thread_proc(
         wav_writer.close()


-@dataclass
+@dataclass(kw_only=True)
 class PipelineInput:
     """Input to a pipeline run."""

     run: PipelineRun

+    conversation_id: str
+    """Identifier for the conversation."""
+
     stt_metadata: stt.SpeechMetadata | None = None
     """Metadata of stt input audio. Required when start_stage = stt."""

@@ -1430,9 +1434,6 @@ class PipelineInput:
     tts_input: str | None = None
     """Input for text-to-speech. Required when start_stage = tts."""

-    conversation_id: str | None = None
-    """Identifier for the conversation."""
-
     conversation_extra_system_prompt: str | None = None
     """Extra prompt information for the conversation agent."""

@@ -1441,7 +1442,7 @@ class PipelineInput:

     async def execute(self) -> None:
         """Run pipeline."""
-        self.run.start(device_id=self.device_id)
+        self.run.start(conversation_id=self.conversation_id, device_id=self.device_id)
         current_stage: PipelineStage | None = self.run.start_stage
         stt_audio_buffer: list[EnhancedAudioChunk] = []
         stt_processed_stream: AsyncIterable[EnhancedAudioChunk] | None = None
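Switching PipelineInput to @dataclass(kw_only=True) is what allows the new required conversation_id field to coexist with the many defaulted fields. A small self-contained illustration of that behavior (hypothetical class, not the real PipelineInput):

    from dataclasses import dataclass

    @dataclass(kw_only=True)
    class InputSketch:
        tts_input: str | None = None  # a defaulted field may precede...
        conversation_id: str          # ...a required one only because kw_only=True

    print(InputSketch(conversation_id="abc123"))
    # InputSketch(tts_input=None, conversation_id='abc123')
    # InputSketch("abc123") would raise TypeError: all fields are keyword-only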
@@ -14,7 +14,11 @@ import voluptuous as vol
 from homeassistant.components import conversation, stt, tts, websocket_api
 from homeassistant.const import ATTR_DEVICE_ID, ATTR_SECONDS, MATCH_ALL
 from homeassistant.core import HomeAssistant, callback
-from homeassistant.helpers import config_validation as cv, entity_registry as er
+from homeassistant.helpers import (
+    chat_session,
+    config_validation as cv,
+    entity_registry as er,
+)
 from homeassistant.util import language as language_util

 from .const import (
@@ -145,7 +149,6 @@ async def websocket_run(

     # Arguments to PipelineInput
     input_args: dict[str, Any] = {
-        "conversation_id": msg.get("conversation_id"),
         "device_id": msg.get("device_id"),
     }

@@ -233,38 +236,42 @@ async def websocket_run(
         audio_settings=audio_settings or AudioSettings(),
     )

-    pipeline_input = PipelineInput(**input_args)
-
-    try:
-        await pipeline_input.validate()
-    except PipelineError as error:
-        # Report more specific error when possible
-        connection.send_error(msg["id"], error.code, error.message)
-        return
-
-    # Confirm subscription
-    connection.send_result(msg["id"])
-
-    run_task = hass.async_create_task(pipeline_input.execute())
-
-    # Cancel pipeline if user unsubscribes
-    connection.subscriptions[msg["id"]] = run_task.cancel
-
-    try:
-        # Task contains a timeout
-        async with asyncio.timeout(timeout):
-            await run_task
-    except TimeoutError:
-        pipeline_input.run.process_event(
-            PipelineEvent(
-                PipelineEventType.ERROR,
-                {"code": "timeout", "message": "Timeout running pipeline"},
-            )
-        )
-    finally:
-        if unregister_handler is not None:
-            # Unregister binary handler
-            unregister_handler()
+    with chat_session.async_get_chat_session(
+        hass, msg.get("conversation_id")
+    ) as session:
+        input_args["conversation_id"] = session.conversation_id
+        pipeline_input = PipelineInput(**input_args)
+
+        try:
+            await pipeline_input.validate()
+        except PipelineError as error:
+            # Report more specific error when possible
+            connection.send_error(msg["id"], error.code, error.message)
+            return
+
+        # Confirm subscription
+        connection.send_result(msg["id"])
+
+        run_task = hass.async_create_task(pipeline_input.execute())
+
+        # Cancel pipeline if user unsubscribes
+        connection.subscriptions[msg["id"]] = run_task.cancel
+
+        try:
+            # Task contains a timeout
+            async with asyncio.timeout(timeout):
+                await run_task
+        except TimeoutError:
+            pipeline_input.run.process_event(
+                PipelineEvent(
+                    PipelineEventType.ERROR,
+                    {"code": "timeout", "message": "Timeout running pipeline"},
+                )
+            )
+        finally:
+            if unregister_handler is not None:
+                # Unregister binary handler
+                unregister_handler()


 @callback
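The run loop above relies on asyncio.timeout (Python 3.11+) to abort the pipeline and surface an ERROR event. The cancellation semantics in isolation, as a sketch:

    import asyncio

    async def main() -> None:
        try:
            async with asyncio.timeout(0.1):
                await asyncio.sleep(10)  # stands in for the pipeline run
        except TimeoutError:
            # The code inside the scope is cancelled and the CancelledError
            # is converted into TimeoutError at the scope boundary.
            print("pipeline timed out")

    asyncio.run(main())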
@@ -8,7 +8,7 @@ from dataclasses import dataclass
 from enum import StrEnum
 import logging
 import time
-from typing import Any, Final, Literal, final
+from typing import Any, Literal, final

 from homeassistant.components import conversation, media_source, stt, tts
 from homeassistant.components.assist_pipeline import (
@@ -28,14 +28,12 @@ from homeassistant.components.tts import (
 )
 from homeassistant.core import Context, callback
 from homeassistant.exceptions import HomeAssistantError
-from homeassistant.helpers import entity
+from homeassistant.helpers import chat_session, entity
 from homeassistant.helpers.entity import EntityDescription

 from .const import AssistSatelliteEntityFeature
 from .errors import AssistSatelliteError, SatelliteBusyError

-_CONVERSATION_TIMEOUT_SEC: Final = 5 * 60  # 5 minutes
-
 _LOGGER = logging.getLogger(__name__)


@@ -114,7 +112,6 @@ class AssistSatelliteEntity(entity.Entity):
     _attr_vad_sensitivity_entity_id: str | None = None

     _conversation_id: str | None = None
-    _conversation_id_time: float | None = None

     _run_has_tts: bool = False
     _is_announcing = False
@@ -260,6 +257,21 @@
         else:
             self._extra_system_prompt = start_message or None

+        with (
+            # Not passing in a conversation ID will force a new one to be created
+            chat_session.async_get_chat_session(self.hass) as session,
+            conversation.async_get_chat_log(self.hass, session) as chat_log,
+        ):
+            self._conversation_id = session.conversation_id
+
+            if start_message:
+                async for _tool_response in chat_log.async_add_assistant_content(
+                    conversation.AssistantContent(
+                        agent_id=self.entity_id, content=start_message
+                    )
+                ):
+                    pass  # no tool responses.
+
         try:
             await self.async_start_conversation(announcement)
         finally:
@@ -325,51 +337,52 @@

         assert self._context is not None

-        # Reset conversation id if necessary
-        if self._conversation_id_time and (
-            (time.monotonic() - self._conversation_id_time) > _CONVERSATION_TIMEOUT_SEC
-        ):
-            self._conversation_id = None
-            self._conversation_id_time = None
-
         # Set entity state based on pipeline events
         self._run_has_tts = False

         assert self.platform.config_entry is not None
-        self._pipeline_task = self.platform.config_entry.async_create_background_task(
-            self.hass,
-            async_pipeline_from_audio_stream(
-                self.hass,
-                context=self._context,
-                event_callback=self._internal_on_pipeline_event,
-                stt_metadata=stt.SpeechMetadata(
-                    language="",  # set in async_pipeline_from_audio_stream
-                    format=stt.AudioFormats.WAV,
-                    codec=stt.AudioCodecs.PCM,
-                    bit_rate=stt.AudioBitRates.BITRATE_16,
-                    sample_rate=stt.AudioSampleRates.SAMPLERATE_16000,
-                    channel=stt.AudioChannels.CHANNEL_MONO,
-                ),
-                stt_stream=audio_stream,
-                pipeline_id=self._resolve_pipeline(),
-                conversation_id=self._conversation_id,
-                device_id=device_id,
-                tts_audio_output=self.tts_options,
-                wake_word_phrase=wake_word_phrase,
-                audio_settings=AudioSettings(
-                    silence_seconds=self._resolve_vad_sensitivity()
-                ),
-                start_stage=start_stage,
-                end_stage=end_stage,
-                conversation_extra_system_prompt=extra_system_prompt,
-            ),
-            f"{self.entity_id}_pipeline",
-        )
-
-        try:
-            await self._pipeline_task
-        finally:
-            self._pipeline_task = None
+        with chat_session.async_get_chat_session(
+            self.hass, self._conversation_id
+        ) as session:
+            # Store the conversation ID. If it is no longer valid, get_chat_session will reset it
+            self._conversation_id = session.conversation_id
+            self._pipeline_task = (
+                self.platform.config_entry.async_create_background_task(
+                    self.hass,
+                    async_pipeline_from_audio_stream(
+                        self.hass,
+                        context=self._context,
+                        event_callback=self._internal_on_pipeline_event,
+                        stt_metadata=stt.SpeechMetadata(
+                            language="",  # set in async_pipeline_from_audio_stream
+                            format=stt.AudioFormats.WAV,
+                            codec=stt.AudioCodecs.PCM,
+                            bit_rate=stt.AudioBitRates.BITRATE_16,
+                            sample_rate=stt.AudioSampleRates.SAMPLERATE_16000,
+                            channel=stt.AudioChannels.CHANNEL_MONO,
+                        ),
+                        stt_stream=audio_stream,
+                        pipeline_id=self._resolve_pipeline(),
+                        conversation_id=session.conversation_id,
+                        device_id=device_id,
+                        tts_audio_output=self.tts_options,
+                        wake_word_phrase=wake_word_phrase,
+                        audio_settings=AudioSettings(
+                            silence_seconds=self._resolve_vad_sensitivity()
+                        ),
+                        start_stage=start_stage,
+                        end_stage=end_stage,
+                        conversation_extra_system_prompt=extra_system_prompt,
+                    ),
+                    f"{self.entity_id}_pipeline",
+                )
+            )
+
+            try:
+                await self._pipeline_task
+            finally:
+                self._pipeline_task = None

     async def _cancel_running_pipeline(self) -> None:
         """Cancel the current pipeline if it's running."""
@@ -393,11 +406,6 @@
             self._set_state(AssistSatelliteState.LISTENING)
         elif event.type is PipelineEventType.INTENT_START:
             self._set_state(AssistSatelliteState.PROCESSING)
-        elif event.type is PipelineEventType.INTENT_END:
-            assert event.data is not None
-            # Update timeout
-            self._conversation_id_time = time.monotonic()
-            self._conversation_id = event.data["intent_output"]["conversation_id"]
         elif event.type is PipelineEventType.TTS_START:
             # Wait until tts_response_finished is called to return to waiting state
             self._run_has_tts = True
@@ -19,6 +19,8 @@ from .const import (
 )
 from .entity import BangOlufsenEntity

+PARALLEL_UPDATES = 0
+

 async def async_setup_entry(
     hass: HomeAssistant,
@@ -20,7 +20,7 @@
     "bluetooth-adapters==0.21.1",
     "bluetooth-auto-recovery==1.4.2",
     "bluetooth-data-tools==1.23.3",
-    "dbus-fast==2.31.0",
+    "dbus-fast==2.32.0",
    "habluetooth==3.21.0"
  ]
 }
@@ -30,6 +30,16 @@ from .agent_manager import (
     async_get_agent,
     get_agent_manager,
 )
+from .chat_log import (
+    AssistantContent,
+    ChatLog,
+    Content,
+    ConverseError,
+    SystemContent,
+    ToolResultContent,
+    UserContent,
+    async_get_chat_log,
+)
 from .const import (
     ATTR_AGENT_ID,
     ATTR_CONVERSATION_ID,
@@ -48,13 +58,13 @@ from .default_agent import DefaultAgent, async_setup_default_agent
 from .entity import ConversationEntity
 from .http import async_setup as async_setup_conversation_http
 from .models import AbstractConversationAgent, ConversationInput, ConversationResult
-from .session import ChatLog, Content, ConverseError, NativeContent, async_get_chat_log
 from .trace import ConversationTraceEventType, async_conversation_trace_append

 __all__ = [
     "DOMAIN",
     "HOME_ASSISTANT_AGENT",
     "OLD_HOME_ASSISTANT_AGENT",
+    "AssistantContent",
     "ChatLog",
     "Content",
     "ConversationEntity",
@@ -63,7 +73,9 @@ __all__ = [
     "ConversationResult",
     "ConversationTraceEventType",
     "ConverseError",
-    "NativeContent",
+    "SystemContent",
+    "ToolResultContent",
+    "UserContent",
     "async_conversation_trace_append",
     "async_converse",
     "async_get_agent_info",
@@ -2,19 +2,16 @@

 from __future__ import annotations

-from collections.abc import Generator
+from collections.abc import AsyncGenerator, Generator
 from contextlib import contextmanager
 from dataclasses import dataclass, field, replace
 from datetime import datetime
 import logging
-from typing import Literal

 import voluptuous as vol

 from homeassistant.core import HomeAssistant, callback
 from homeassistant.exceptions import HomeAssistantError, TemplateError
 from homeassistant.helpers import chat_session, intent, llm, template
 from homeassistant.util import dt as dt_util
 from homeassistant.util.hass_dict import HassKey
 from homeassistant.util.json import JsonObjectType

@@ -31,7 +28,7 @@ LOGGER = logging.getLogger(__name__)
 def async_get_chat_log(
     hass: HomeAssistant,
     session: chat_session.ChatSession,
-    user_input: ConversationInput,
+    user_input: ConversationInput | None = None,
 ) -> Generator[ChatLog]:
     """Return chat log for a specific chat session."""
     all_history = hass.data.get(DATA_CHAT_HISTORY)
@@ -42,9 +39,9 @@ def async_get_chat_log(
     history = all_history.get(session.conversation_id)

     if history:
-        history = replace(history, messages=history.messages.copy())
+        history = replace(history, content=history.content.copy())
     else:
-        history = ChatLog(hass, session.conversation_id, user_input.agent_id)
+        history = ChatLog(hass, session.conversation_id)

     @callback
     def do_cleanup() -> None:
@@ -53,22 +50,19 @@ def async_get_chat_log(

     session.async_on_cleanup(do_cleanup)

-    message: Content = Content(
-        role="user",
-        agent_id=user_input.agent_id,
-        content=user_input.text,
-    )
-    history.async_add_message(message)
+    if user_input is not None:
+        history.async_add_user_content(UserContent(content=user_input.text))
+
+    last_message = history.content[-1]

     yield history

-    if history.messages[-1] is message:
+    if history.content[-1] is last_message:
         LOGGER.debug(
             "History opened but no assistant message was added, ignoring update"
         )
         return

     history.last_updated = dt_util.utcnow()
     all_history[session.conversation_id] = history

@@ -94,63 +88,94 @@ class ConverseError(HomeAssistantError):
         )


-@dataclass
-class Content:
+@dataclass(frozen=True)
+class SystemContent:
     """Base class for chat messages."""

-    role: Literal["system", "assistant", "user"]
-    agent_id: str | None
+    role: str = field(init=False, default="system")
     content: str


 @dataclass(frozen=True)
-class NativeContent[_NativeT]:
-    """Native content."""
+class UserContent:
+    """Assistant content."""

-    role: str = field(init=False, default="native")
+    role: str = field(init=False, default="user")
+    content: str
+
+
+@dataclass(frozen=True)
+class AssistantContent:
+    """Assistant content."""
+
+    role: str = field(init=False, default="assistant")
     agent_id: str
-    content: _NativeT
+    content: str
+    tool_calls: list[llm.ToolInput] | None = None
+
+
+@dataclass(frozen=True)
+class ToolResultContent:
+    """Tool result content."""
+
+    role: str = field(init=False, default="tool_result")
+    agent_id: str
+    tool_call_id: str
+    tool_name: str
+    tool_result: JsonObjectType
+
+
+Content = SystemContent | UserContent | AssistantContent | ToolResultContent


 @dataclass
-class ChatLog[_NativeT]:
+class ChatLog:
     """Class holding the chat history of a specific conversation."""

     hass: HomeAssistant
     conversation_id: str
-    agent_id: str | None
     user_name: str | None = None
-    messages: list[Content | NativeContent[_NativeT]] = field(
-        default_factory=lambda: [Content(role="system", agent_id=None, content="")]
-    )
+    content: list[Content] = field(default_factory=lambda: [SystemContent(content="")])
     extra_system_prompt: str | None = None
     llm_api: llm.APIInstance | None = None
     last_updated: datetime = field(default_factory=dt_util.utcnow)

     @callback
-    def async_add_message(self, message: Content | NativeContent[_NativeT]) -> None:
-        """Process intent."""
-        if message.role == "system":
-            raise ValueError("Cannot add system messages to history")
-        if message.role != "native" and self.messages[-1].role == message.role:
-            raise ValueError("Cannot add two assistant or user messages in a row")
+    def async_add_user_content(self, content: UserContent) -> None:
+        """Add user content to the log."""
+        self.content.append(content)

-        self.messages.append(message)
+    async def async_add_assistant_content(
+        self, content: AssistantContent
+    ) -> AsyncGenerator[ToolResultContent]:
+        """Add assistant content."""
+        self.content.append(content)

-    @callback
-    def async_get_messages(
-        self, agent_id: str | None = None
-    ) -> list[Content | NativeContent[_NativeT]]:
-        """Get messages for a specific agent ID.
+        if content.tool_calls is None:
+            return

-        This will filter out any native message tied to other agent IDs.
-        It can still include assistant/user messages generated by other agents.
-        """
-        return [
-            message
-            for message in self.messages
-            if message.role != "native" or message.agent_id == agent_id
-        ]
+        if self.llm_api is None:
+            raise ValueError("No LLM API configured")
+
+        for tool_input in content.tool_calls:
+            LOGGER.debug(
+                "Tool call: %s(%s)", tool_input.tool_name, tool_input.tool_args
+            )
+
+            try:
+                tool_result = await self.llm_api.async_call_tool(tool_input)
+            except (HomeAssistantError, vol.Invalid) as e:
+                tool_result = {"error": type(e).__name__}
+                if str(e):
+                    tool_result["error_text"] = str(e)
+            LOGGER.debug("Tool response: %s", tool_result)
+
+            response_content = ToolResultContent(
+                agent_id=content.agent_id,
+                tool_call_id=tool_input.id,
+                tool_name=tool_input.tool_name,
+                tool_result=tool_result,
+            )
+            self.content.append(response_content)
+            yield response_content

     async def async_update_llm_data(
         self,
@@ -250,36 +275,16 @@ class ChatLog[_NativeT]:
         prompt = "\n".join(prompt_parts)

         self.llm_api = llm_api
         self.user_name = user_name
         self.extra_system_prompt = extra_system_prompt
-        self.messages[0] = Content(
-            role="system",
-            agent_id=user_input.agent_id,
-            content=prompt,
-        )
+        self.content[0] = SystemContent(content=prompt)

-        LOGGER.debug("Prompt: %s", self.messages)
+        LOGGER.debug("Prompt: %s", self.content)
         LOGGER.debug("Tools: %s", self.llm_api.tools if self.llm_api else None)

         trace.async_conversation_trace_append(
             trace.ConversationTraceEventType.AGENT_DETAIL,
             {
-                "messages": self.messages,
+                "messages": self.content,
                 "tools": self.llm_api.tools if self.llm_api else None,
             },
         )
-
-    async def async_call_tool(self, tool_input: llm.ToolInput) -> JsonObjectType:
-        """Invoke LLM tool for the configured LLM API."""
-        if not self.llm_api:
-            raise ValueError("No LLM API configured")
-        LOGGER.debug("Tool call: %s(%s)", tool_input.tool_name, tool_input.tool_args)
-
-        try:
-            tool_response = await self.llm_api.async_call_tool(tool_input)
-        except (HomeAssistantError, vol.Invalid) as e:
-            tool_response = {"error": type(e).__name__}
-            if str(e):
-                tool_response["error_text"] = str(e)
-        LOGGER.debug("Tool response: %s", tool_response)
-        return tool_response
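async_add_assistant_content is an async generator: appending the content and executing the tool calls happen lazily, as the caller iterates. That is why call sites throughout this commit drain it with "async for ...: pass" even when they ignore the yielded tool results. The same mechanic in isolation, as a sketch:

    import asyncio
    from collections.abc import AsyncGenerator

    async def add_assistant_content(results: list[str]) -> AsyncGenerator[str, None]:
        # Side effects (tool execution) occur only while iterating.
        for result in results:
            print(f"running tool -> {result}")
            yield result

    async def main() -> None:
        # Drain the generator even if the tool results are unused; if the
        # loop never runs, no tool ever executes.
        async for _ in add_assistant_content(["result-1", "result-2"]):
            pass

    asyncio.run(main())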
@@ -55,6 +55,7 @@ from homeassistant.helpers.entity_component import EntityComponent
 from homeassistant.helpers.event import async_track_state_added_domain
 from homeassistant.util.json import JsonObjectType, json_loads_object

+from .chat_log import AssistantContent, async_get_chat_log
 from .const import (
     DATA_DEFAULT_ENTITY,
     DEFAULT_EXPOSED_ATTRIBUTES,
@@ -63,7 +64,6 @@ from .const import (
 )
 from .entity import ConversationEntity
 from .models import ConversationInput, ConversationResult
-from .session import Content, async_get_chat_log
 from .trace import ConversationTraceEventType, async_conversation_trace_append

 _LOGGER = logging.getLogger(__name__)
@@ -379,13 +379,13 @@ class DefaultAgent(ConversationEntity):
         )

         speech: str = response.speech.get("plain", {}).get("speech", "")
-        chat_log.async_add_message(
-            Content(
-                role="assistant",
-                agent_id=user_input.agent_id,
+        async for _tool_result in chat_log.async_add_assistant_content(
+            AssistantContent(
+                agent_id=user_input.agent_id,  # type: ignore[arg-type]
                 content=speech,
             )
-        )
+        ):
+            pass

         return ConversationResult(
             response=response, conversation_id=session.conversation_id
@@ -22,5 +22,5 @@
   "integration_type": "device",
   "iot_class": "local_polling",
   "loggers": ["eq3btsmart"],
-  "requirements": ["eq3btsmart==1.4.1", "bleak-esphome==2.6.0"]
+  "requirements": ["eq3btsmart==1.4.1", "bleak-esphome==2.7.0"]
 }
@@ -18,7 +18,7 @@
   "requirements": [
     "aioesphomeapi==29.0.0",
     "esphome-dashboard-api==1.2.3",
-    "bleak-esphome==2.6.0"
+    "bleak-esphome==2.7.0"
   ],
   "zeroconf": ["_esphomelib._tcp.local."]
 }
@@ -28,14 +28,14 @@
       "user": {
         "description": "Enter the settings to connect to the camera.",
         "data": {
-          "still_image_url": "Still Image URL (e.g. http://...)",
-          "stream_source": "Stream Source URL (e.g. rtsp://...)",
+          "still_image_url": "Still image URL (e.g. http://...)",
+          "stream_source": "Stream source URL (e.g. rtsp://...)",
           "rtsp_transport": "RTSP transport protocol",
           "authentication": "Authentication",
-          "limit_refetch_to_url_change": "Limit refetch to url change",
+          "limit_refetch_to_url_change": "Limit refetch to URL change",
           "password": "[%key:common::config_flow::data::password%]",
           "username": "[%key:common::config_flow::data::username%]",
-          "framerate": "Frame Rate (Hz)",
+          "framerate": "Frame rate (Hz)",
           "verify_ssl": "[%key:common::config_flow::data::verify_ssl%]"
         }
       },
@@ -4,7 +4,7 @@ from __future__ import annotations

 import codecs
 from collections.abc import Callable
-from typing import Any, Literal
+from typing import Any, Literal, cast

 from google.api_core.exceptions import GoogleAPIError
 import google.generativeai as genai
@@ -149,15 +149,53 @@ def _escape_decode(value: Any) -> Any:
     return value


-def _chat_message_convert(
-    message: conversation.Content | conversation.NativeContent[genai_types.ContentDict],
-) -> genai_types.ContentDict:
-    """Convert any native chat message for this agent to the native format."""
-    if message.role == "native":
-        return message.content
+def _create_google_tool_response_content(
+    content: list[conversation.ToolResultContent],
+) -> protos.Content:
+    """Create a Google tool response content."""
+    return protos.Content(
+        parts=[
+            protos.Part(
+                function_response=protos.FunctionResponse(
+                    name=tool_result.tool_name, response=tool_result.tool_result
+                )
+            )
+            for tool_result in content
+        ]
+    )

-    role = "model" if message.role == "assistant" else message.role
-    return {"role": role, "parts": message.content}
+
+def _convert_content(
+    content: conversation.UserContent
+    | conversation.AssistantContent
+    | conversation.SystemContent,
+) -> genai_types.ContentDict:
+    """Convert HA content to Google content."""
+    if content.role != "assistant" or not content.tool_calls:  # type: ignore[union-attr]
+        role = "model" if content.role == "assistant" else content.role
+        return {"role": role, "parts": content.content}
+
+    # Handle the Assistant content with tool calls.
+    assert type(content) is conversation.AssistantContent
+    parts = []
+
+    if content.content:
+        parts.append(protos.Part(text=content.content))
+
+    if content.tool_calls:
+        parts.extend(
+            [
+                protos.Part(
+                    function_call=protos.FunctionCall(
+                        name=tool_call.tool_name,
+                        args=_escape_decode(tool_call.tool_args),
+                    )
+                )
+                for tool_call in content.tool_calls
+            ]
+        )
+
+    return protos.Content({"role": "model", "parts": parts})


 class GoogleGenerativeAIConversationEntity(
@@ -220,7 +258,7 @@ class GoogleGenerativeAIConversationEntity(
     async def _async_handle_message(
         self,
         user_input: conversation.ConversationInput,
-        session: conversation.ChatLog[genai_types.ContentDict],
+        chat_log: conversation.ChatLog,
     ) -> conversation.ConversationResult:
         """Call the API."""

@@ -228,7 +266,7 @@ class GoogleGenerativeAIConversationEntity(
         options = self.entry.options

         try:
-            await session.async_update_llm_data(
+            await chat_log.async_update_llm_data(
                 DOMAIN,
                 user_input,
                 options.get(CONF_LLM_HASS_API),
@@ -238,10 +276,10 @@ class GoogleGenerativeAIConversationEntity(
             return err.as_conversation_result()

         tools: list[dict[str, Any]] | None = None
-        if session.llm_api:
+        if chat_log.llm_api:
             tools = [
-                _format_tool(tool, session.llm_api.custom_serializer)
-                for tool in session.llm_api.tools
+                _format_tool(tool, chat_log.llm_api.custom_serializer)
+                for tool in chat_log.llm_api.tools
             ]

         model_name = self.entry.options.get(CONF_CHAT_MODEL, RECOMMENDED_CHAT_MODEL)
@@ -252,9 +290,36 @@ class GoogleGenerativeAIConversationEntity(
             "gemini-1.0" not in model_name and "gemini-pro" not in model_name
         )

-        prompt, *messages = [
-            _chat_message_convert(message) for message in session.async_get_messages()
-        ]
+        prompt = chat_log.content[0].content  # type: ignore[union-attr]
+        messages: list[genai_types.ContentDict] = []
+
+        # Google groups tool results, we do not. Group them before sending.
+        tool_results: list[conversation.ToolResultContent] = []
+
+        for chat_content in chat_log.content[1:]:
+            if chat_content.role == "tool_result":
+                # mypy doesn't like picking a type based on checking shared property 'role'
+                tool_results.append(cast(conversation.ToolResultContent, chat_content))
+                continue
+
+            if tool_results:
+                messages.append(_create_google_tool_response_content(tool_results))
+                tool_results.clear()
+
+            messages.append(
+                _convert_content(
+                    cast(
+                        conversation.UserContent
+                        | conversation.SystemContent
+                        | conversation.AssistantContent,
+                        chat_content,
+                    )
+                )
+            )
+
+        if tool_results:
+            messages.append(_create_google_tool_response_content(tool_results))

         model = genai.GenerativeModel(
             model_name=model_name,
             generation_config={
@@ -282,12 +347,12 @@ class GoogleGenerativeAIConversationEntity(
                 ),
             },
             tools=tools or None,
-            system_instruction=prompt["parts"] if supports_system_instruction else None,
+            system_instruction=prompt if supports_system_instruction else None,
         )

         if not supports_system_instruction:
             messages = [
-                {"role": "user", "parts": prompt["parts"]},
+                {"role": "user", "parts": prompt},
                 {"role": "model", "parts": "Ok"},
                 *messages,
             ]
@@ -325,50 +390,40 @@ class GoogleGenerativeAIConversationEntity(
             content = " ".join(
                 [part.text.strip() for part in chat_response.parts if part.text]
             )
-            if content:
-                session.async_add_message(
-                    conversation.Content(
-                        role="assistant",
-                        agent_id=user_input.agent_id,
-                        content=content,
-                    )
-                )

-            function_calls = [
-                part.function_call for part in chat_response.parts if part.function_call
-            ]
-
-            if not function_calls or not session.llm_api:
-                break
-
-            tool_responses = []
-            for function_call in function_calls:
-                tool_call = MessageToDict(function_call._pb)  # noqa: SLF001
+            tool_calls = []
+            for part in chat_response.parts:
+                if not part.function_call:
+                    continue
+                tool_call = MessageToDict(part.function_call._pb)  # noqa: SLF001
                 tool_name = tool_call["name"]
                 tool_args = _escape_decode(tool_call["args"])
-                tool_input = llm.ToolInput(tool_name=tool_name, tool_args=tool_args)
-                function_response = await session.async_call_tool(tool_input)
-                tool_responses.append(
-                    protos.Part(
-                        function_response=protos.FunctionResponse(
-                            name=tool_name, response=function_response
-                        )
-                    )
-                )
-            chat_request = protos.Content(parts=tool_responses)
-            session.async_add_message(
-                conversation.NativeContent(
-                    agent_id=user_input.agent_id,
-                    content=chat_request,
-                )
-            )
+                tool_calls.append(
+                    llm.ToolInput(tool_name=tool_name, tool_args=tool_args)
+                )
+
+            chat_request = _create_google_tool_response_content(
+                [
+                    tool_response
+                    async for tool_response in chat_log.async_add_assistant_content(
+                        conversation.AssistantContent(
+                            agent_id=user_input.agent_id,
+                            content=content,
+                            tool_calls=tool_calls or None,
+                        )
+                    )
+                ]
+            )
+
+            if not tool_calls:
+                break

         response = intent.IntentResponse(language=user_input.language)
         response.async_set_speech(
             " ".join([part.text.strip() for part in chat_response.parts if part.text])
         )
         return conversation.ConversationResult(
-            response=response, conversation_id=session.conversation_id
+            response=response, conversation_id=chat_log.conversation_id
         )

     async def _async_entry_update_listener(
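Google's API expects consecutive tool results batched into one content message, while the chat log stores them individually; the loop above performs that grouping. The same pass on plain strings, for clarity (toy data, not the real content types):

    # Consecutive "tool_result" entries are flushed as a single grouped
    # message, the way _create_google_tool_response_content is used above.
    content = ["user", "tool_result", "tool_result", "assistant", "tool_result"]

    messages: list[str] = []
    pending: list[str] = []
    for item in content:
        if item == "tool_result":
            pending.append(item)
            continue
        if pending:
            messages.append(f"grouped({len(pending)})")
            pending.clear()
        messages.append(item)
    if pending:  # trailing tool results still need to be flushed
        messages.append(f"grouped({len(pending)})")

    print(messages)  # ['user', 'grouped(2)', 'assistant', 'grouped(1)']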
@@ -517,17 +517,22 @@ class SupervisorBackupReaderWriter(BackupReaderWriter):
             raise HomeAssistantError(message) from err

         restore_complete = asyncio.Event()
+        restore_errors: list[dict[str, str]] = []

         @callback
         def on_job_progress(data: Mapping[str, Any]) -> None:
             """Handle backup restore progress."""
             if data.get("done") is True:
                 restore_complete.set()
+                restore_errors.extend(data.get("errors", []))

         unsub = self._async_listen_job_events(job.job_id, on_job_progress)
         try:
             await self._get_job_state(job.job_id, on_job_progress)
             await restore_complete.wait()
+            if restore_errors:
+                # We should add more specific error handling here in the future
+                raise BackupReaderWriterError(f"Restore failed: {restore_errors}")
         finally:
             unsub()
@@ -554,11 +559,23 @@ class SupervisorBackupReaderWriter(BackupReaderWriter):
                 )
                 return

-            on_progress(
-                RestoreBackupEvent(
-                    reason="", stage=None, state=RestoreBackupState.COMPLETED
+            restore_errors = data.get("errors", [])
+            if restore_errors:
+                _LOGGER.warning("Restore backup failed: %s", restore_errors)
+                # We should add more specific error handling here in the future
+                on_progress(
+                    RestoreBackupEvent(
+                        reason="unknown_error",
+                        stage=None,
+                        state=RestoreBackupState.FAILED,
+                    )
+                )
+            else:
+                on_progress(
+                    RestoreBackupEvent(
+                        reason="", stage=None, state=RestoreBackupState.COMPLETED
+                    )
                 )
-            )
             on_progress(IdleEvent())
             unsub()
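The restore path pairs an asyncio.Event with a synchronous callback: the callback collects errors and signals completion, and the awaiting coroutine inspects the collected errors afterwards. A minimal sketch of that pattern:

    import asyncio
    from typing import Any

    async def main() -> None:
        done = asyncio.Event()
        errors: list[dict[str, str]] = []

        def on_job_progress(data: dict[str, Any]) -> None:
            # Record errors and signal completion; the waiter checks the
            # error list once the event fires.
            if data.get("done") is True:
                errors.extend(data.get("errors", []))
                done.set()

        on_job_progress({"done": True, "errors": [{"type": "BackupError"}]})
        await done.wait()
        if errors:
            print(f"Restore failed: {errors}")  # the real code raises here

    asyncio.run(main())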
@@ -1,6 +1,7 @@
 """Constants for the homee integration."""

 from homeassistant.const import (
+    DEGREE,
     LIGHT_LUX,
     PERCENTAGE,
     REVOLUTIONS_PER_MINUTE,
@@ -32,6 +33,7 @@ HOMEE_UNIT_TO_HA_UNIT = {
     "W": UnitOfPower.WATT,
     "m/s": UnitOfSpeed.METERS_PER_SECOND,
     "km/h": UnitOfSpeed.KILOMETERS_PER_HOUR,
+    "°": DEGREE,
     "°F": UnitOfTemperature.FAHRENHEIT,
     "°C": UnitOfTemperature.CELSIUS,
     "K": UnitOfTemperature.KELVIN,
@@ -51,7 +53,7 @@ OPEN_CLOSE_MAP_REVERSED = {
     0.0: "closed",
     1.0: "open",
     2.0: "partial",
-    3.0: "cosing",
+    3.0: "closing",
     4.0: "opening",
 }
 WINDOW_MAP = {
@@ -8,5 +8,5 @@
   "documentation": "https://www.home-assistant.io/integrations/lcn",
   "iot_class": "local_push",
   "loggers": ["pypck"],
-  "requirements": ["pypck==0.8.3", "lcn-frontend==0.2.3"]
+  "requirements": ["pypck==0.8.5", "lcn-frontend==0.2.3"]
 }
@@ -30,7 +30,7 @@ class MotionMountEntity(Entity):
         self.config_entry = config_entry

         # We store the pin, as we might need it during reconnect
-        self.pin = config_entry.data[CONF_PIN]
+        self.pin = config_entry.data.get(CONF_PIN)

         mac = format_mac(mm.mac.hex())
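Using .get() here tolerates config entries created before a PIN was stored: indexing a dict raises KeyError for a missing key, while .get() returns None. In miniature:

    data = {"host": "192.168.1.10"}  # an older entry with no PIN stored
    # data["pin"] would raise KeyError and break entity setup
    pin = data.get("pin")  # None -> the entity treats the PIN as unset
    print(pin)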
@@ -5,34 +5,33 @@ from __future__ import annotations
 from dataclasses import dataclass
 import logging

-from kiota_abstractions.api_error import APIError
-from kiota_abstractions.authentication import BaseBearerTokenAuthenticationProvider
-from msgraph import GraphRequestAdapter, GraphServiceClient
-from msgraph.generated.drives.item.items.items_request_builder import (
-    ItemsRequestBuilder,
+from onedrive_personal_sdk import OneDriveClient
+from onedrive_personal_sdk.exceptions import (
+    AuthenticationError,
+    HttpRequestException,
+    OneDriveException,
 )
-from msgraph.generated.models.drive_item import DriveItem
-from msgraph.generated.models.folder import Folder

 from homeassistant.config_entries import ConfigEntry
 from homeassistant.core import HomeAssistant, callback
 from homeassistant.exceptions import ConfigEntryAuthFailed, ConfigEntryNotReady
+from homeassistant.helpers.aiohttp_client import async_get_clientsession
 from homeassistant.helpers.config_entry_oauth2_flow import (
     OAuth2Session,
     async_get_config_entry_implementation,
 )
-from homeassistant.helpers.httpx_client import create_async_httpx_client
 from homeassistant.helpers.instance_id import async_get as async_get_instance_id

 from .api import OneDriveConfigEntryAccessTokenProvider
-from .const import DATA_BACKUP_AGENT_LISTENERS, DOMAIN, OAUTH_SCOPES
+from .const import DATA_BACKUP_AGENT_LISTENERS, DOMAIN


 @dataclass
 class OneDriveRuntimeData:
     """Runtime data for the OneDrive integration."""

-    items: ItemsRequestBuilder
+    client: OneDriveClient
+    token_provider: OneDriveConfigEntryAccessTokenProvider
     backup_folder_id: str

@@ -47,29 +46,18 @@ async def async_setup_entry(hass: HomeAssistant, entry: OneDriveConfigEntry) ->

     session = OAuth2Session(hass, entry, implementation)

-    auth_provider = BaseBearerTokenAuthenticationProvider(
-        access_token_provider=OneDriveConfigEntryAccessTokenProvider(session)
-    )
-    adapter = GraphRequestAdapter(
-        auth_provider=auth_provider,
-        client=create_async_httpx_client(hass, follow_redirects=True),
-    )
+    token_provider = OneDriveConfigEntryAccessTokenProvider(session)

-    graph_client = GraphServiceClient(
-        request_adapter=adapter,
-        scopes=OAUTH_SCOPES,
-    )
-    assert entry.unique_id
-    drive_item = graph_client.drives.by_drive_id(entry.unique_id)
+    client = OneDriveClient(token_provider, async_get_clientsession(hass))

     # get approot, will be created automatically if it does not exist
     try:
-        approot = await drive_item.special.by_drive_item_id("approot").get()
-    except APIError as err:
-        if err.response_status_code == 403:
-            raise ConfigEntryAuthFailed(
-                translation_domain=DOMAIN, translation_key="authentication_failed"
-            ) from err
+        approot = await client.get_approot()
+    except AuthenticationError as err:
+        raise ConfigEntryAuthFailed(
+            translation_domain=DOMAIN, translation_key="authentication_failed"
+        ) from err
+    except (HttpRequestException, OneDriveException, TimeoutError) as err:
         _LOGGER.debug("Failed to get approot", exc_info=True)
         raise ConfigEntryNotReady(
             translation_domain=DOMAIN,
@@ -77,24 +65,24 @@ async def async_setup_entry(hass: HomeAssistant, entry: OneDriveConfigEntry) ->
             translation_placeholders={"folder": "approot"},
         ) from err

-    if approot is None or not approot.id:
-        _LOGGER.debug("Failed to get approot, was None")
+    instance_id = await async_get_instance_id(hass)
+    backup_folder_name = f"backups_{instance_id[:8]}"
+    try:
+        backup_folder = await client.create_folder(
+            parent_id=approot.id, name=backup_folder_name
+        )
+    except (HttpRequestException, OneDriveException, TimeoutError) as err:
+        _LOGGER.debug("Failed to create backup folder", exc_info=True)
         raise ConfigEntryNotReady(
             translation_domain=DOMAIN,
             translation_key="failed_to_get_folder",
-            translation_placeholders={"folder": "approot"},
-        )
-
-    instance_id = await async_get_instance_id(hass)
-    backup_folder_id = await _async_create_folder_if_not_exists(
-        items=drive_item.items,
-        base_folder_id=approot.id,
-        folder=f"backups_{instance_id[:8]}",
-    )
+            translation_placeholders={"folder": backup_folder_name},
+        ) from err

     entry.runtime_data = OneDriveRuntimeData(
-        items=drive_item.items,
-        backup_folder_id=backup_folder_id,
+        client=client,
+        token_provider=token_provider,
+        backup_folder_id=backup_folder.id,
     )

     _async_notify_backup_listeners_soon(hass)
@@ -116,54 +104,3 @@ def _async_notify_backup_listeners(hass: HomeAssistant) -> None:
 @callback
 def _async_notify_backup_listeners_soon(hass: HomeAssistant) -> None:
     hass.loop.call_soon(_async_notify_backup_listeners, hass)
-
-
-async def _async_create_folder_if_not_exists(
-    items: ItemsRequestBuilder,
-    base_folder_id: str,
-    folder: str,
-) -> str:
-    """Check if a folder exists and create it if it does not exist."""
-    folder_item: DriveItem | None = None
-
-    try:
-        folder_item = await items.by_drive_item_id(f"{base_folder_id}:/{folder}:").get()
-    except APIError as err:
-        if err.response_status_code != 404:
-            _LOGGER.debug("Failed to get folder %s", folder, exc_info=True)
-            raise ConfigEntryNotReady(
-                translation_domain=DOMAIN,
-                translation_key="failed_to_get_folder",
-                translation_placeholders={"folder": folder},
-            ) from err
-        # is 404 not found, create folder
-        _LOGGER.debug("Creating folder %s", folder)
-        request_body = DriveItem(
-            name=folder,
-            folder=Folder(),
-            additional_data={
-                "@microsoft_graph_conflict_behavior": "fail",
-            },
-        )
-        try:
-            folder_item = await items.by_drive_item_id(base_folder_id).children.post(
-                request_body
-            )
-        except APIError as create_err:
-            _LOGGER.debug("Failed to create folder %s", folder, exc_info=True)
-            raise ConfigEntryNotReady(
-                translation_domain=DOMAIN,
-                translation_key="failed_to_create_folder",
-                translation_placeholders={"folder": folder},
-            ) from create_err
-        _LOGGER.debug("Created folder %s", folder)
-    else:
-        _LOGGER.debug("Found folder %s", folder)
-    if folder_item is None or not folder_item.id:
-        _LOGGER.debug("Failed to get folder %s, was None", folder)
-        raise ConfigEntryNotReady(
-            translation_domain=DOMAIN,
-            translation_key="failed_to_get_folder",
-            translation_placeholders={"folder": folder},
-        )
-    return folder_item.id
@@ -1,28 +1,14 @@
 """API for OneDrive bound to Home Assistant OAuth."""

-from typing import Any, cast
+from typing import cast

-from kiota_abstractions.authentication import AccessTokenProvider, AllowedHostsValidator
+from onedrive_personal_sdk import TokenProvider

 from homeassistant.const import CONF_ACCESS_TOKEN
 from homeassistant.helpers import config_entry_oauth2_flow


-class OneDriveAccessTokenProvider(AccessTokenProvider):
-    """Provide OneDrive authentication tied to an OAuth2 based config entry."""
-
-    def __init__(self) -> None:
-        """Initialize OneDrive auth."""
-        super().__init__()
-        # currently allowing all hosts
-        self._allowed_hosts_validator = AllowedHostsValidator(allowed_hosts=[])
-
-    def get_allowed_hosts_validator(self) -> AllowedHostsValidator:
-        """Retrieve the allowed hosts validator."""
-        return self._allowed_hosts_validator
-
-
-class OneDriveConfigFlowAccessTokenProvider(OneDriveAccessTokenProvider):
+class OneDriveConfigFlowAccessTokenProvider(TokenProvider):
     """Provide OneDrive authentication tied to an OAuth2 based config entry."""

     def __init__(self, token: str) -> None:
@@ -30,14 +16,12 @@ class OneDriveConfigFlowAccessTokenProvider(OneDriveAccessTokenProvider):
         super().__init__()
         self._token = token

-    async def get_authorization_token(  # pylint: disable=dangerous-default-value
-        self, uri: str, additional_authentication_context: dict[str, Any] = {}
-    ) -> str:
-        """Return a valid authorization token."""
+    def async_get_access_token(self) -> str:
+        """Return a valid access token."""
         return self._token


-class OneDriveConfigEntryAccessTokenProvider(OneDriveAccessTokenProvider):
+class OneDriveConfigEntryAccessTokenProvider(TokenProvider):
     """Provide OneDrive authentication tied to an OAuth2 based config entry."""

     def __init__(self, oauth_session: config_entry_oauth2_flow.OAuth2Session) -> None:
@@ -45,9 +29,6 @@ class OneDriveConfigEntryAccessTokenProvider(OneDriveAccessTokenProvider):
         super().__init__()
         self._oauth_session = oauth_session

-    async def get_authorization_token(  # pylint: disable=dangerous-default-value
-        self, uri: str, additional_authentication_context: dict[str, Any] = {}
-    ) -> str:
-        """Return a valid authorization token."""
-        await self._oauth_session.async_ensure_token_valid()
+    def async_get_access_token(self) -> str:
+        """Return a valid access token."""
         return cast(str, self._oauth_session.token[CONF_ACCESS_TOKEN])
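The rewritten providers subclass the SDK's TokenProvider and implement the synchronous async_get_access_token hook shown above. A hedged sketch of a provider wired to an arbitrary token source, assuming only the TokenProvider interface this diff itself exhibits:

    from collections.abc import Callable

    from onedrive_personal_sdk import TokenProvider

    class CallbackTokenProvider(TokenProvider):
        """Illustrative only: delegates to a caller-supplied token getter."""

        def __init__(self, getter: Callable[[], str]) -> None:
            super().__init__()
            self._getter = getter

        def async_get_access_token(self) -> str:
            # Keeping the token fresh is the session layer's job, mirroring
            # OneDriveConfigEntryAccessTokenProvider above.
            return self._getter()

    provider = CallbackTokenProvider(lambda: "example-token")
    print(provider.async_get_access_token())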
@@ -2,37 +2,22 @@

 from __future__ import annotations

-import asyncio
 from collections.abc import AsyncIterator, Callable, Coroutine
 from functools import wraps
 import html
 import json
 import logging
-from typing import Any, Concatenate, cast
+from typing import Any, Concatenate

-from httpx import Response, TimeoutException
-from kiota_abstractions.api_error import APIError
-from kiota_abstractions.authentication import AnonymousAuthenticationProvider
-from kiota_abstractions.headers_collection import HeadersCollection
-from kiota_abstractions.method import Method
-from kiota_abstractions.native_response_handler import NativeResponseHandler
-from kiota_abstractions.request_information import RequestInformation
-from kiota_http.middleware.options import ResponseHandlerOption
-from msgraph import GraphRequestAdapter
-from msgraph.generated.drives.item.items.item.content.content_request_builder import (
-    ContentRequestBuilder,
+from aiohttp import ClientTimeout
+from onedrive_personal_sdk.clients.large_file_upload import LargeFileUploadClient
+from onedrive_personal_sdk.exceptions import (
+    AuthenticationError,
+    HashMismatchError,
+    OneDriveException,
 )
-from msgraph.generated.drives.item.items.item.create_upload_session.create_upload_session_post_request_body import (
-    CreateUploadSessionPostRequestBody,
-)
-from msgraph.generated.drives.item.items.item.drive_item_item_request_builder import (
-    DriveItemItemRequestBuilder,
-)
-from msgraph.generated.models.drive_item import DriveItem
-from msgraph.generated.models.drive_item_uploadable_properties import (
-    DriveItemUploadableProperties,
-)
-from msgraph_core.models import LargeFileUploadSession
+from onedrive_personal_sdk.models.items import File, Folder, ItemUpdate
+from onedrive_personal_sdk.models.upload import FileInfo

 from homeassistant.components.backup import (
     AgentBackup,
@@ -41,14 +26,14 @@ from homeassistant.components.backup import (
     suggested_filename,
 )
 from homeassistant.core import HomeAssistant, callback
-from homeassistant.helpers.httpx_client import get_async_client
+from homeassistant.helpers.aiohttp_client import async_get_clientsession

 from . import OneDriveConfigEntry
 from .const import DATA_BACKUP_AGENT_LISTENERS, DOMAIN

 _LOGGER = logging.getLogger(__name__)
 UPLOAD_CHUNK_SIZE = 16 * 320 * 1024  # 5.2MB
-MAX_RETRIES = 5
+TIMEOUT = ClientTimeout(connect=10, total=43200)  # 12 hours


 async def async_get_backup_agents(
@@ -92,18 +77,18 @@ def handle_backup_errors[_R, **P](
     ) -> _R:
         try:
             return await func(self, *args, **kwargs)
-        except APIError as err:
-            if err.response_status_code == 403:
-                self._entry.async_start_reauth(self._hass)
+        except AuthenticationError as err:
+            self._entry.async_start_reauth(self._hass)
+            raise BackupAgentError("Authentication error") from err
+        except OneDriveException as err:
             _LOGGER.error(
-                "Error during backup in %s: Status %s, message %s",
+                "Error during backup in %s:, message %s",
                 func.__name__,
-                err.response_status_code,
-                err.message,
+                err,
             )
             _LOGGER.debug("Full error: %s", err, exc_info=True)
             raise BackupAgentError("Backup operation failed") from err
-        except TimeoutException as err:
+        except TimeoutError as err:
             _LOGGER.error(
                 "Error during backup in %s: Timeout",
                 func.__name__,
@@ -123,7 +108,8 @@ class OneDriveBackupAgent(BackupAgent):
         super().__init__()
         self._hass = hass
         self._entry = entry
-        self._items = entry.runtime_data.items
+        self._client = entry.runtime_data.client
+        self._token_provider = entry.runtime_data.token_provider
         self._folder_id = entry.runtime_data.backup_folder_id
         self.name = entry.title
         assert entry.unique_id
@@ -134,24 +120,12 @@
         self, backup_id: str, **kwargs: Any
     ) -> AsyncIterator[bytes]:
         """Download a backup file."""
-        # this forces the query to return a raw httpx response, but breaks typing
-        backup = await self._find_item_by_backup_id(backup_id)
-        if backup is None or backup.id is None:
+        item = await self._find_item_by_backup_id(backup_id)
+        if item is None:
             raise BackupAgentError("Backup not found")

-        request_config = (
-            ContentRequestBuilder.ContentRequestBuilderGetRequestConfiguration(
-                options=[ResponseHandlerOption(NativeResponseHandler())],
-            )
-        )
-        response = cast(
-            Response,
-            await self._items.by_drive_item_id(backup.id).content.get(
-                request_configuration=request_config
-            ),
-        )
-
-        return response.aiter_bytes(chunk_size=1024)
+        stream = await self._client.download_drive_item(item.id, timeout=TIMEOUT)
+        return stream.iter_chunked(1024)

     @handle_backup_errors
     async def async_upload_backup(
@@ -163,27 +137,20 @@
     ) -> None:
         """Upload a backup."""

-        # upload file in chunks to support large files
-        upload_session_request_body = CreateUploadSessionPostRequestBody(
-            item=DriveItemUploadableProperties(
-                additional_data={
-                    "@microsoft.graph.conflictBehavior": "fail",
-                },
-            )
-        )
-        file_item = self._get_backup_file_item(suggested_filename(backup))
-        upload_session = await file_item.create_upload_session.post(
-            upload_session_request_body
-        )
-
-        if upload_session is None or upload_session.upload_url is None:
+        file = FileInfo(
+            suggested_filename(backup),
+            backup.size,
+            self._folder_id,
+            await open_stream(),
+        )
+        try:
+            item = await LargeFileUploadClient.upload(
+                self._token_provider, file, session=async_get_clientsession(self._hass)
+            )
+        except HashMismatchError as err:
             raise BackupAgentError(
-                translation_domain=DOMAIN, translation_key="backup_no_upload_session"
-            )
-
-        await self._upload_file(
-            upload_session.upload_url, await open_stream(), backup.size
-        )
+                "Hash validation failed, backup file might be corrupt"
+            ) from err

         # store metadata in description
         backup_dict = backup.as_dict()
@@ -191,7 +158,10 @@
         description = json.dumps(backup_dict)
         _LOGGER.debug("Creating metadata: %s", description)

-        await file_item.patch(DriveItem(description=description))
+        await self._client.update_drive_item(
+            path_or_id=item.id,
+            data=ItemUpdate(description=description),
+        )

     @handle_backup_errors
     async def async_delete_backup(
@@ -200,35 +170,31 @@
         **kwargs: Any,
     ) -> None:
         """Delete a backup file."""
-        backup = await self._find_item_by_backup_id(backup_id)
-        if backup is None or backup.id is None:
+        item = await self._find_item_by_backup_id(backup_id)
+        if item is None:
             return
-        await self._items.by_drive_item_id(backup.id).delete()
+        await self._client.delete_drive_item(item.id)

     @handle_backup_errors
     async def async_list_backups(self, **kwargs: Any) -> list[AgentBackup]:
         """List backups."""
-        backups: list[AgentBackup] = []
-        items = await self._items.by_drive_item_id(f"{self._folder_id}").children.get()
-        if items and (values := items.value):
-            for item in values:
-                if (description := item.description) is None:
-                    continue
-                if "homeassistant_version" in description:
-                    backups.append(self._backup_from_description(description))
-        return backups
+        return [
+            self._backup_from_description(item.description)
+            for item in await self._client.list_drive_items(self._folder_id)
+            if item.description and "homeassistant_version" in item.description
+        ]

     @handle_backup_errors
     async def async_get_backup(
         self, backup_id: str, **kwargs: Any
     ) -> AgentBackup | None:
         """Return a backup."""
-        backup = await self._find_item_by_backup_id(backup_id)
-        if backup is None:
-            return None
-
-        assert backup.description  # already checked in _find_item_by_backup_id
-        return self._backup_from_description(backup.description)
+        item = await self._find_item_by_backup_id(backup_id)
+        return (
+            self._backup_from_description(item.description)
+            if item and item.description
+            else None
+        )

     def _backup_from_description(self, description: str) -> AgentBackup:
         """Create a backup object from a description."""
@@ -237,91 +203,13 @@
         )  # OneDrive encodes the description on save automatically
         return AgentBackup.from_dict(json.loads(description))

-    async def _find_item_by_backup_id(self, backup_id: str) -> DriveItem | None:
-        """Find a backup item by its backup ID."""
-
-        items = await self._items.by_drive_item_id(f"{self._folder_id}").children.get()
-        if items and (values := items.value):
-            for item in values:
-                if (description := item.description) is None:
-                    continue
-                if backup_id in description:
-                    return item
-        return None
-
-    def _get_backup_file_item(self, backup_id: str) -> DriveItemItemRequestBuilder:
-        return self._items.by_drive_item_id(f"{self._folder_id}:/{backup_id}:")
-
-    async def _upload_file(
-        self, upload_url: str, stream: AsyncIterator[bytes], total_size: int
-    ) -> None:
-        """Use custom large file upload; SDK does not support stream."""
-
-        adapter = GraphRequestAdapter(
-            auth_provider=AnonymousAuthenticationProvider(),
-            client=get_async_client(self._hass),
+    async def _find_item_by_backup_id(self, backup_id: str) -> File | Folder | None:
+        """Find an item by backup ID."""
+        return next(
+            (
+                item
+                for item in await self._client.list_drive_items(self._folder_id)
+                if item.description and backup_id in item.description
+            ),
+            None,
         )
-
-        async def async_upload(
-            start: int, end: int, chunk_data: bytes
-        ) -> LargeFileUploadSession:
-            info = RequestInformation()
-            info.url = upload_url
-            info.http_method = Method.PUT
-            info.headers = HeadersCollection()
-            info.headers.try_add("Content-Range", f"bytes {start}-{end}/{total_size}")
-            info.headers.try_add("Content-Length", str(len(chunk_data)))
-            info.headers.try_add("Content-Type", "application/octet-stream")
|
||||
_LOGGER.debug(info.headers.get_all())
|
||||
info.set_stream_content(chunk_data)
|
||||
result = await adapter.send_async(info, LargeFileUploadSession, {})
|
||||
_LOGGER.debug("Next expected range: %s", result.next_expected_ranges)
|
||||
return result
|
||||
|
||||
start = 0
|
||||
buffer: list[bytes] = []
|
||||
buffer_size = 0
|
||||
retries = 0
|
||||
|
||||
async for chunk in stream:
|
||||
buffer.append(chunk)
|
||||
buffer_size += len(chunk)
|
||||
if buffer_size >= UPLOAD_CHUNK_SIZE:
|
||||
chunk_data = b"".join(buffer)
|
||||
uploaded_chunks = 0
|
||||
while (
|
||||
buffer_size > UPLOAD_CHUNK_SIZE
|
||||
): # Loop in case the buffer is >= UPLOAD_CHUNK_SIZE * 2
|
||||
slice_start = uploaded_chunks * UPLOAD_CHUNK_SIZE
|
||||
try:
|
||||
await async_upload(
|
||||
start,
|
||||
start + UPLOAD_CHUNK_SIZE - 1,
|
||||
chunk_data[slice_start : slice_start + UPLOAD_CHUNK_SIZE],
|
||||
)
|
||||
except APIError as err:
|
||||
if (
|
||||
err.response_status_code and err.response_status_code < 500
|
||||
): # no retry on 4xx errors
|
||||
raise
|
||||
if retries < MAX_RETRIES:
|
||||
await asyncio.sleep(2**retries)
|
||||
retries += 1
|
||||
continue
|
||||
raise
|
||||
except TimeoutException:
|
||||
if retries < MAX_RETRIES:
|
||||
retries += 1
|
||||
continue
|
||||
raise
|
||||
retries = 0
|
||||
start += UPLOAD_CHUNK_SIZE
|
||||
uploaded_chunks += 1
|
||||
buffer_size -= UPLOAD_CHUNK_SIZE
|
||||
buffer = [chunk_data[UPLOAD_CHUNK_SIZE * uploaded_chunks :]]
|
||||
|
||||
# upload the remaining bytes
|
||||
if buffer:
|
||||
_LOGGER.debug("Last chunk")
|
||||
chunk_data = b"".join(buffer)
|
||||
await async_upload(start, start + len(chunk_data) - 1, chunk_data)
|
||||
|
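Note: the custom `_upload_file` implementation above is removed in favor of `LargeFileUploadClient.upload` from onedrive-personal-sdk. A minimal standalone sketch of the buffer-and-retry pattern the old code implemented follows; the `upload_chunk` callable and both constants are illustrative, not the SDK's actual API:

import asyncio
from collections.abc import AsyncIterator, Awaitable, Callable

UPLOAD_CHUNK_SIZE = 320 * 1024 * 100  # Graph expects multiples of 320 KiB
MAX_RETRIES = 3


async def upload_in_chunks(
    stream: AsyncIterator[bytes],
    upload_chunk: Callable[[int, bytes], Awaitable[None]],
) -> None:
    """Buffer a byte stream and upload fixed-size chunks with backoff."""
    buffer = bytearray()
    offset = 0
    async for part in stream:
        buffer.extend(part)
        while len(buffer) >= UPLOAD_CHUNK_SIZE:
            chunk = bytes(buffer[:UPLOAD_CHUNK_SIZE])
            for attempt in range(MAX_RETRIES + 1):
                try:
                    await upload_chunk(offset, chunk)
                    break
                except TimeoutError:  # stand-in for the transport timeout
                    if attempt == MAX_RETRIES:
                        raise
                    await asyncio.sleep(2**attempt)  # exponential backoff
            offset += len(chunk)
            del buffer[:UPLOAD_CHUNK_SIZE]
    if buffer:  # upload the remaining bytes
        await upload_chunk(offset, bytes(buffer))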
@ -4,16 +4,13 @@ from collections.abc import Mapping
import logging
from typing import Any, cast

from kiota_abstractions.api_error import APIError
from kiota_abstractions.authentication import BaseBearerTokenAuthenticationProvider
from kiota_abstractions.method import Method
from kiota_abstractions.request_information import RequestInformation
from msgraph import GraphRequestAdapter, GraphServiceClient
from onedrive_personal_sdk.clients.client import OneDriveClient
from onedrive_personal_sdk.exceptions import OneDriveException

from homeassistant.config_entries import SOURCE_REAUTH, ConfigFlowResult
from homeassistant.const import CONF_ACCESS_TOKEN, CONF_TOKEN
from homeassistant.helpers.aiohttp_client import async_get_clientsession
from homeassistant.helpers.config_entry_oauth2_flow import AbstractOAuth2FlowHandler
from homeassistant.helpers.httpx_client import get_async_client

from .api import OneDriveConfigFlowAccessTokenProvider
from .const import DOMAIN, OAUTH_SCOPES
@ -39,48 +36,24 @@ class OneDriveConfigFlow(AbstractOAuth2FlowHandler, domain=DOMAIN):
data: dict[str, Any],
) -> ConfigFlowResult:
"""Handle the initial step."""
auth_provider = BaseBearerTokenAuthenticationProvider(
access_token_provider=OneDriveConfigFlowAccessTokenProvider(
cast(str, data[CONF_TOKEN][CONF_ACCESS_TOKEN])
)
)
adapter = GraphRequestAdapter(
auth_provider=auth_provider,
client=get_async_client(self.hass),
token_provider = OneDriveConfigFlowAccessTokenProvider(
cast(str, data[CONF_TOKEN][CONF_ACCESS_TOKEN])
)

graph_client = GraphServiceClient(
request_adapter=adapter,
scopes=OAUTH_SCOPES,
graph_client = OneDriveClient(
token_provider, async_get_clientsession(self.hass)
)

# need to get adapter from client, as client changes it
request_adapter = cast(GraphRequestAdapter, graph_client.request_adapter)

request_info = RequestInformation(
method=Method.GET,
url_template="{+baseurl}/me/drive/special/approot",
path_parameters={},
)
parent_span = request_adapter.start_tracing_span(request_info, "get_approot")

# get the OneDrive id
# use low level methods, to avoid files.read permissions
# which would be required by drives.me.get()
try:
response = await request_adapter.get_http_response_message(
request_info=request_info, parent_span=parent_span
)
except APIError:
approot = await graph_client.get_approot()
except OneDriveException:
self.logger.exception("Failed to connect to OneDrive")
return self.async_abort(reason="connection_error")
except Exception:
self.logger.exception("Unknown error")
return self.async_abort(reason="unknown")

drive: dict = response.json()

await self.async_set_unique_id(drive["parentReference"]["driveId"])
await self.async_set_unique_id(approot.parent_reference.drive_id)

if self.source == SOURCE_REAUTH:
reauth_entry = self._get_reauth_entry()
@ -94,10 +67,11 @@ class OneDriveConfigFlow(AbstractOAuth2FlowHandler, domain=DOMAIN):

self._abort_if_unique_id_configured()

user = drive.get("createdBy", {}).get("user", {}).get("displayName")

title = f"{user}'s OneDrive" if user else "OneDrive"

title = (
f"{approot.created_by.user.display_name}'s OneDrive"
if approot.created_by.user and approot.created_by.user.display_name
else "OneDrive"
)
return self.async_create_entry(title=title, data=data)

async def async_step_reauth(
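Note: the config flow now resolves the drive ID with a single SDK call instead of a hand-built Graph request. A condensed sketch of the new call shape, using only the onedrive-personal-sdk surface visible in this diff (the helper name is hypothetical):

from aiohttp import ClientSession

from onedrive_personal_sdk.clients.client import OneDriveClient
from onedrive_personal_sdk.exceptions import OneDriveException


async def resolve_drive_id(token_provider, session: ClientSession) -> str | None:
    """Return the drive ID that backs the app folder, or None on failure."""
    client = OneDriveClient(token_provider, session)
    try:
        approot = await client.get_approot()
    except OneDriveException:
        return None
    return approot.parent_reference.drive_id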
@ -7,7 +7,7 @@
"documentation": "https://www.home-assistant.io/integrations/onedrive",
"integration_type": "service",
"iot_class": "cloud_polling",
"loggers": ["msgraph", "msgraph-core", "kiota"],
"loggers": ["onedrive_personal_sdk"],
"quality_scale": "bronze",
"requirements": ["msgraph-sdk==1.16.0"]
"requirements": ["onedrive-personal-sdk==0.0.1"]
}
@ -23,31 +23,18 @@
"connection_error": "Failed to connect to OneDrive.",
"wrong_drive": "New account does not contain previously configured OneDrive.",
"unknown": "[%key:common::config_flow::error::unknown%]",
"reauth_successful": "[%key:common::config_flow::abort::reauth_successful%]",
"failed_to_create_folder": "Failed to create backup folder"
"reauth_successful": "[%key:common::config_flow::abort::reauth_successful%]"
},
"create_entry": {
"default": "[%key:common::config_flow::create_entry::authenticated%]"
}
},
"exceptions": {
"backup_not_found": {
"message": "Backup not found"
},
"backup_no_content": {
"message": "Backup has no content"
},
"backup_no_upload_session": {
"message": "Failed to start backup upload"
},
"authentication_failed": {
"message": "Authentication failed"
},
"failed_to_get_folder": {
"message": "Failed to get {folder} folder"
},
"failed_to_create_folder": {
"message": "Failed to create {folder} folder"
}
}
}
@ -24,6 +24,7 @@ from homeassistant.helpers.selector import (
SelectOptionDict,
SelectSelector,
SelectSelectorConfig,
SelectSelectorMode,
TemplateSelector,
)
from homeassistant.helpers.typing import VolDictType
@ -32,14 +33,17 @@ from .const import (
CONF_CHAT_MODEL,
CONF_MAX_TOKENS,
CONF_PROMPT,
CONF_REASONING_EFFORT,
CONF_RECOMMENDED,
CONF_TEMPERATURE,
CONF_TOP_P,
DOMAIN,
RECOMMENDED_CHAT_MODEL,
RECOMMENDED_MAX_TOKENS,
RECOMMENDED_REASONING_EFFORT,
RECOMMENDED_TEMPERATURE,
RECOMMENDED_TOP_P,
UNSUPPORTED_MODELS,
)

_LOGGER = logging.getLogger(__name__)
@ -124,26 +128,32 @@ class OpenAIOptionsFlow(OptionsFlow):
) -> ConfigFlowResult:
"""Manage the options."""
options: dict[str, Any] | MappingProxyType[str, Any] = self.config_entry.options
errors: dict[str, str] = {}

if user_input is not None:
if user_input[CONF_RECOMMENDED] == self.last_rendered_recommended:
if user_input[CONF_LLM_HASS_API] == "none":
user_input.pop(CONF_LLM_HASS_API)
return self.async_create_entry(title="", data=user_input)

# Re-render the options again, now with the recommended options shown/hidden
self.last_rendered_recommended = user_input[CONF_RECOMMENDED]
if user_input.get(CONF_CHAT_MODEL) in UNSUPPORTED_MODELS:
errors[CONF_CHAT_MODEL] = "model_not_supported"
else:
return self.async_create_entry(title="", data=user_input)
else:
# Re-render the options again, now with the recommended options shown/hidden
self.last_rendered_recommended = user_input[CONF_RECOMMENDED]

options = {
CONF_RECOMMENDED: user_input[CONF_RECOMMENDED],
CONF_PROMPT: user_input[CONF_PROMPT],
CONF_LLM_HASS_API: user_input[CONF_LLM_HASS_API],
}
options = {
CONF_RECOMMENDED: user_input[CONF_RECOMMENDED],
CONF_PROMPT: user_input[CONF_PROMPT],
CONF_LLM_HASS_API: user_input[CONF_LLM_HASS_API],
}

schema = openai_config_option_schema(self.hass, options)
return self.async_show_form(
step_id="init",
data_schema=vol.Schema(schema),
errors=errors,
)


@ -210,6 +220,17 @@ def openai_config_option_schema(
description={"suggested_value": options.get(CONF_TEMPERATURE)},
default=RECOMMENDED_TEMPERATURE,
): NumberSelector(NumberSelectorConfig(min=0, max=2, step=0.05)),
vol.Optional(
CONF_REASONING_EFFORT,
description={"suggested_value": options.get(CONF_REASONING_EFFORT)},
default=RECOMMENDED_REASONING_EFFORT,
): SelectSelector(
SelectSelectorConfig(
options=["low", "medium", "high"],
translation_key="reasoning_effort",
mode=SelectSelectorMode.DROPDOWN,
)
),
}
)
return schema
@ -15,3 +15,17 @@ CONF_TOP_P = "top_p"
RECOMMENDED_TOP_P = 1.0
CONF_TEMPERATURE = "temperature"
RECOMMENDED_TEMPERATURE = 1.0
CONF_REASONING_EFFORT = "reasoning_effort"
RECOMMENDED_REASONING_EFFORT = "low"

UNSUPPORTED_MODELS = [
"o1-mini",
"o1-mini-2024-09-12",
"o1-preview",
"o1-preview-2024-09-12",
"gpt-4o-realtime-preview",
"gpt-4o-realtime-preview-2024-12-17",
"gpt-4o-realtime-preview-2024-10-01",
"gpt-4o-mini-realtime-preview",
"gpt-4o-mini-realtime-preview-2024-12-17",
]
@ -31,12 +31,14 @@ from .const import (
CONF_CHAT_MODEL,
CONF_MAX_TOKENS,
CONF_PROMPT,
CONF_REASONING_EFFORT,
CONF_TEMPERATURE,
CONF_TOP_P,
DOMAIN,
LOGGER,
RECOMMENDED_CHAT_MODEL,
RECOMMENDED_MAX_TOKENS,
RECOMMENDED_REASONING_EFFORT,
RECOMMENDED_TEMPERATURE,
RECOMMENDED_TOP_P,
)
@ -68,7 +70,9 @@ def _format_tool(
return ChatCompletionToolParam(type="function", function=tool_spec)


def _message_convert(message: ChatCompletionMessage) -> ChatCompletionMessageParam:
def _convert_message_to_param(
message: ChatCompletionMessage,
) -> ChatCompletionMessageParam:
"""Convert from class to TypedDict."""
tool_calls: list[ChatCompletionMessageToolCallParam] = []
if message.tool_calls:
@ -92,17 +96,42 @@ def _message_convert(message: ChatCompletionMessage) -> ChatCompletionMessagePar
return param


def _chat_message_convert(
message: conversation.Content
| conversation.NativeContent[ChatCompletionMessageParam],
def _convert_content_to_param(
content: conversation.Content,
) -> ChatCompletionMessageParam:
"""Convert any native chat message for this agent to the native format."""
if message.role == "native":
# mypy doesn't understand that checking role ensures content type
return message.content # type: ignore[return-value]
return cast(
ChatCompletionMessageParam,
{"role": message.role, "content": message.content},
if content.role == "tool_result":
assert type(content) is conversation.ToolResultContent
return ChatCompletionToolMessageParam(
role="tool",
tool_call_id=content.tool_call_id,
content=json.dumps(content.tool_result),
)
if content.role != "assistant" or not content.tool_calls: # type: ignore[union-attr]
role = content.role
if role == "system":
role = "developer"
return cast(
ChatCompletionMessageParam,
{"role": content.role, "content": content.content}, # type: ignore[union-attr]
)

# Handle the Assistant content including tool calls.
assert type(content) is conversation.AssistantContent
return ChatCompletionAssistantMessageParam(
role="assistant",
content=content.content,
tool_calls=[
ChatCompletionMessageToolCallParam(
id=tool_call.id,
function=Function(
arguments=json.dumps(tool_call.tool_args),
name=tool_call.tool_name,
),
type="function",
)
for tool_call in content.tool_calls
],
)


@ -166,14 +195,14 @@ class OpenAIConversationEntity(
async def _async_handle_message(
self,
user_input: conversation.ConversationInput,
session: conversation.ChatLog[ChatCompletionMessageParam],
chat_log: conversation.ChatLog,
) -> conversation.ConversationResult:
"""Call the API."""
assert user_input.agent_id
options = self.entry.options

try:
await session.async_update_llm_data(
await chat_log.async_update_llm_data(
DOMAIN,
user_input,
options.get(CONF_LLM_HASS_API),
@ -183,73 +212,77 @@ class OpenAIConversationEntity(
return err.as_conversation_result()

tools: list[ChatCompletionToolParam] | None = None
if session.llm_api:
if chat_log.llm_api:
tools = [
_format_tool(tool, session.llm_api.custom_serializer)
for tool in session.llm_api.tools
_format_tool(tool, chat_log.llm_api.custom_serializer)
for tool in chat_log.llm_api.tools
]

messages = [
_chat_message_convert(message) for message in session.async_get_messages()
]
model = options.get(CONF_CHAT_MODEL, RECOMMENDED_CHAT_MODEL)
messages = [_convert_content_to_param(content) for content in chat_log.content]

client = self.entry.runtime_data

# To prevent infinite loops, we limit the number of iterations
for _iteration in range(MAX_TOOL_ITERATIONS):
try:
result = await client.chat.completions.create(
model=options.get(CONF_CHAT_MODEL, RECOMMENDED_CHAT_MODEL),
messages=messages,
tools=tools or NOT_GIVEN,
max_tokens=options.get(CONF_MAX_TOKENS, RECOMMENDED_MAX_TOKENS),
top_p=options.get(CONF_TOP_P, RECOMMENDED_TOP_P),
temperature=options.get(CONF_TEMPERATURE, RECOMMENDED_TEMPERATURE),
user=session.conversation_id,
model_args = {
"model": model,
"messages": messages,
"tools": tools or NOT_GIVEN,
"max_completion_tokens": options.get(
CONF_MAX_TOKENS, RECOMMENDED_MAX_TOKENS
),
"top_p": options.get(CONF_TOP_P, RECOMMENDED_TOP_P),
"temperature": options.get(CONF_TEMPERATURE, RECOMMENDED_TEMPERATURE),
"user": chat_log.conversation_id,
}

if model.startswith("o"):
model_args["reasoning_effort"] = options.get(
CONF_REASONING_EFFORT, RECOMMENDED_REASONING_EFFORT
)

try:
result = await client.chat.completions.create(**model_args)
except openai.OpenAIError as err:
LOGGER.error("Error talking to OpenAI: %s", err)
raise HomeAssistantError("Error talking to OpenAI") from err

LOGGER.debug("Response %s", result)
response = result.choices[0].message
messages.append(_message_convert(response))
messages.append(_convert_message_to_param(response))

session.async_add_message(
conversation.Content(
role=response.role,
agent_id=user_input.agent_id,
content=response.content or "",
),
tool_calls: list[llm.ToolInput] | None = None
if response.tool_calls:
tool_calls = [
llm.ToolInput(
id=tool_call.id,
tool_name=tool_call.function.name,
tool_args=json.loads(tool_call.function.arguments),
)
for tool_call in response.tool_calls
]

messages.extend(
[
_convert_content_to_param(tool_response)
async for tool_response in chat_log.async_add_assistant_content(
conversation.AssistantContent(
agent_id=user_input.agent_id,
content=response.content or "",
tool_calls=tool_calls,
)
)
]
)

if not response.tool_calls or not session.llm_api:
if not tool_calls:
break

for tool_call in response.tool_calls:
tool_input = llm.ToolInput(
tool_name=tool_call.function.name,
tool_args=json.loads(tool_call.function.arguments),
)
tool_response = await session.async_call_tool(tool_input)
messages.append(
ChatCompletionToolMessageParam(
role="tool",
tool_call_id=tool_call.id,
content=json.dumps(tool_response),
)
)
session.async_add_message(
conversation.NativeContent(
agent_id=user_input.agent_id,
content=messages[-1],
)
)

intent_response = intent.IntentResponse(language=user_input.language)
intent_response.async_set_speech(response.content or "")
return conversation.ConversationResult(
response=intent_response, conversation_id=session.conversation_id
response=intent_response, conversation_id=chat_log.conversation_id
)

async def _async_entry_update_listener(
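Note: `reasoning_effort` is only sent for o-series models, hence the `model.startswith("o")` guard above. A small sketch of that argument assembly in isolation; the option key and default mirror this diff, the helper name is illustrative:

from typing import Any


def build_model_args(model: str, options: dict[str, Any]) -> dict[str, Any]:
    """Build chat-completion kwargs, adding reasoning_effort for o-series models."""
    args: dict[str, Any] = {"model": model}
    if model.startswith("o"):
        # o1/o3-style models accept "low", "medium" or "high"
        args["reasoning_effort"] = options.get("reasoning_effort", "low")
    return args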
@ -23,12 +23,26 @@
"temperature": "Temperature",
"top_p": "Top P",
"llm_hass_api": "[%key:common::config_flow::data::llm_hass_api%]",
"recommended": "Recommended model settings"
"recommended": "Recommended model settings",
"reasoning_effort": "Reasoning effort"
},
"data_description": {
"prompt": "Instruct how the LLM should respond. This can be a template."
"prompt": "Instruct how the LLM should respond. This can be a template.",
"reasoning_effort": "How many reasoning tokens the model should generate before creating a response to the prompt (for certain reasoning models)"
}
}
},
"error": {
"model_not_supported": "This model is not supported, please select a different model"
}
},
"selector": {
"reasoning_effort": {
"options": {
"low": "Low",
"medium": "Medium",
"high": "High"
}
}
},
"services": {
@ -18,7 +18,13 @@ from homeassistant.const import (
STATE_OFF,
STATE_ON,
)
from homeassistant.core import HomeAssistant, ServiceCall, callback
from homeassistant.core import (
HomeAssistant,
ServiceCall,
ServiceResponse,
SupportsResponse,
callback,
)
from homeassistant.helpers import config_validation as cv
from homeassistant.helpers.collection import (
CollectionEntity,
@ -44,6 +50,7 @@ from .const import (
CONF_TO,
DOMAIN,
LOGGER,
SERVICE_GET,
WEEKDAY_TO_CONF,
)

@ -205,6 +212,14 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
reload_service_handler,
)

component.async_register_entity_service(
SERVICE_GET,
{},
async_get_schedule_service,
supports_response=SupportsResponse.ONLY,
)
await component.async_setup(config)

return True


@ -296,6 +311,10 @@ class Schedule(CollectionEntity):
self.async_on_remove(self._clean_up_listener)
self._update()

def get_schedule(self) -> ConfigType:
"""Return the schedule."""
return {d: self._config[d] for d in WEEKDAY_TO_CONF.values()}

@callback
def _update(self, _: datetime | None = None) -> None:
"""Update the states of the schedule."""
@ -390,3 +409,10 @@ class Schedule(CollectionEntity):
data_keys.update(time_range_custom_data.keys())

return frozenset(data_keys)


async def async_get_schedule_service(
schedule: Schedule, service_call: ServiceCall
) -> ServiceResponse:
"""Return the schedule configuration."""
return schedule.get_schedule()
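Note: registering `get_schedule` with `SupportsResponse.ONLY` means the service exists purely to return data, so callers must request the response. A usage sketch, assuming `hass` is in scope and with an illustrative entity ID:

response = await hass.services.async_call(
    "schedule",
    "get_schedule",
    target={"entity_id": "schedule.workdays"},
    blocking=True,
    return_response=True,
)
# Maps each targeted schedule entity to its weekday -> time-range config.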
@ -37,3 +37,5 @@ WEEKDAY_TO_CONF: Final = {
5: CONF_SATURDAY,
6: CONF_SUNDAY,
}

SERVICE_GET: Final = "get_schedule"
@ -2,6 +2,9 @@
"services": {
"reload": {
"service": "mdi:reload"
},
"get_schedule": {
"service": "mdi:calendar-export"
}
}
}
@ -1 +1,5 @@
reload:
get_schedule:
target:
entity:
domain: schedule
@ -25,6 +25,10 @@
"reload": {
"name": "[%key:common::action::reload%]",
"description": "Reloads schedules from the YAML-configuration."
},
"get_schedule": {
"name": "Get schedule",
"description": "Retrieve one or multiple schedules."
}
}
}
@ -1,16 +1,16 @@
{
"config": {
"flow_title": "Add Shark IQ Account",
"flow_title": "Add Shark IQ account",
"step": {
"user": {
"description": "Sign into your Shark Clean account to control your devices.",
"description": "Sign into your SharkClean account to control your devices.",
"data": {
"username": "[%key:common::config_flow::data::username%]",
"password": "[%key:common::config_flow::data::password%]",
"region": "Region"
},
"data_description": {
"region": "Shark IQ uses different services in the EU. Select your region to connect to the correct service for your account."
"region": "Shark IQ uses different services in the EU. Select your region to connect to the correct service for your account."
}
},
"reauth_confirm": {
@ -37,18 +37,18 @@
"region": {
"options": {
"europe": "Europe",
"elsewhere": "Everywhere Else"
"elsewhere": "Everywhere else"
}
}
},
"exceptions": {
"invalid_room": {
"message": "The room {room} is unavailable to your vacuum. Make sure all rooms match the Shark App, including capitalization."
"message": "The room {room} is unavailable to your vacuum. Make sure all rooms match the SharkClean app, including capitalization."
}
},
"services": {
"clean_room": {
"name": "Clean Room",
"name": "Clean room",
"description": "Cleans a specific user-defined room or set of rooms.",
"fields": {
"rooms": {
@ -272,6 +272,18 @@ RPC_SENSORS: Final = {
entity_category=EntityCategory.DIAGNOSTIC,
entity_class=RpcBluTrvBinarySensor,
),
"flood": RpcBinarySensorDescription(
key="flood",
sub_key="alarm",
name="Flood",
device_class=BinarySensorDeviceClass.MOISTURE,
),
"mute": RpcBinarySensorDescription(
key="flood",
sub_key="mute",
name="Mute",
entity_category=EntityCategory.DIAGNOSTIC,
),
}


@ -14,6 +14,7 @@ from homeassistant.config_entries import SOURCE_USER, ConfigFlow, ConfigFlowResu
from homeassistant.const import CONF_HOST, CONF_NAME, CONF_PASSWORD, CONF_USERNAME
from homeassistant.helpers.aiohttp_client import async_get_clientsession
from homeassistant.helpers.device_registry import format_mac
from homeassistant.helpers.service_info.dhcp import DhcpServiceInfo
from homeassistant.helpers.service_info.zeroconf import ZeroconfServiceInfo

from .const import DOMAIN
@ -35,7 +36,8 @@ STEP_AUTH_DATA_SCHEMA = vol.Schema(
class SmlightConfigFlow(ConfigFlow, domain=DOMAIN):
"""Handle a config flow for SMLIGHT Zigbee."""

host: str
_host: str
_device_name: str
client: Api2

async def async_step_user(
@ -45,11 +47,13 @@ class SmlightConfigFlow(ConfigFlow, domain=DOMAIN):
errors: dict[str, str] = {}

if user_input is not None:
self.host = user_input[CONF_HOST]
self.client = Api2(self.host, session=async_get_clientsession(self.hass))
self._host = user_input[CONF_HOST]
self.client = Api2(self._host, session=async_get_clientsession(self.hass))

try:
info = await self.client.get_info()
self._host = str(info.device_ip)
self._device_name = str(info.hostname)

if info.model not in Devices:
return self.async_abort(reason="unsupported_device")
@ -93,15 +97,14 @@ class SmlightConfigFlow(ConfigFlow, domain=DOMAIN):
self, discovery_info: ZeroconfServiceInfo
) -> ConfigFlowResult:
"""Handle a discovered Lan coordinator."""
local_name = discovery_info.hostname[:-1]
node_name = local_name.removesuffix(".local")
mac: str | None = discovery_info.properties.get("mac")
self._device_name = discovery_info.hostname.removesuffix(".local.")
self._host = discovery_info.host

self.host = local_name
self.context["title_placeholders"] = {CONF_NAME: node_name}
self.client = Api2(self.host, session=async_get_clientsession(self.hass))
self.context["title_placeholders"] = {CONF_NAME: self._device_name}
self.client = Api2(self._host, session=async_get_clientsession(self.hass))

mac = discovery_info.properties.get("mac")
# fallback for legacy firmware
# fallback for legacy firmware older than v2.3.x
if mac is None:
try:
info = await self.client.get_info()
@ -111,7 +114,7 @@ class SmlightConfigFlow(ConfigFlow, domain=DOMAIN):
mac = info.MAC

await self.async_set_unique_id(format_mac(mac))
self._abort_if_unique_id_configured()
self._abort_if_unique_id_configured(updates={CONF_HOST: self._host})

return await self.async_step_confirm_discovery()

@ -122,7 +125,6 @@ class SmlightConfigFlow(ConfigFlow, domain=DOMAIN):
errors: dict[str, str] = {}

if user_input is not None:
user_input[CONF_HOST] = self.host
try:
info = await self.client.get_info()

@ -142,7 +144,7 @@ class SmlightConfigFlow(ConfigFlow, domain=DOMAIN):

return self.async_show_form(
step_id="confirm_discovery",
description_placeholders={"host": self.host},
description_placeholders={"host": self._device_name},
errors=errors,
)

@ -151,8 +153,8 @@ class SmlightConfigFlow(ConfigFlow, domain=DOMAIN):
) -> ConfigFlowResult:
"""Handle reauth when API Authentication failed."""

self.host = entry_data[CONF_HOST]
self.client = Api2(self.host, session=async_get_clientsession(self.hass))
self._host = entry_data[CONF_HOST]
self.client = Api2(self._host, session=async_get_clientsession(self.hass))

return await self.async_step_reauth_confirm()

@ -182,6 +184,16 @@ class SmlightConfigFlow(ConfigFlow, domain=DOMAIN):
errors=errors,
)

async def async_step_dhcp(
self, discovery_info: DhcpServiceInfo
) -> ConfigFlowResult:
"""Handle DHCP discovery."""
await self.async_set_unique_id(format_mac(discovery_info.macaddress))
self._abort_if_unique_id_configured(updates={CONF_HOST: discovery_info.ip})
# This should never happen since we only listen to DHCP requests
# for configured devices.
return self.async_abort(reason="already_configured")

async def _async_check_auth_required(self, user_input: dict[str, Any]) -> bool:
"""Check if auth required and attempt to authenticate."""
if await self.client.check_auth_needed():
@ -200,11 +212,10 @@ class SmlightConfigFlow(ConfigFlow, domain=DOMAIN):
await self.async_set_unique_id(
format_mac(info.MAC), raise_on_progress=self.source != SOURCE_USER
)
self._abort_if_unique_id_configured()
self._abort_if_unique_id_configured(updates={CONF_HOST: self._host})

if user_input.get(CONF_HOST) is None:
user_input[CONF_HOST] = self.host
user_input[CONF_HOST] = self._host

assert info.model is not None
title = self.context.get("title_placeholders", {}).get(CONF_NAME) or info.model
title = self._device_name or info.model
return self.async_create_entry(title=title, data=user_input)
@ -3,10 +3,15 @@
"name": "SMLIGHT SLZB",
"codeowners": ["@tl-sl"],
"config_flow": true,
"dhcp": [
{
"registered_devices": true
}
],
"documentation": "https://www.home-assistant.io/integrations/smlight",
"integration_type": "device",
"iot_class": "local_push",
"requirements": ["pysmlight==0.2.1"],
"requirements": ["pysmlight==0.2.2"],
"zeroconf": [
{
"type": "_slzb-06._tcp.local."
@ -65,6 +65,7 @@ BINARY_SENSORS = [
key="currently_obstructed",
translation_key="currently_obstructed",
device_class=BinarySensorDeviceClass.PROBLEM,
entity_category=EntityCategory.DIAGNOSTIC,
value_fn=lambda data: data.status["currently_obstructed"],
),
StarlinkBinarySensorEntityDescription(
@ -114,4 +115,9 @@ BINARY_SENSORS = [
entity_category=EntityCategory.DIAGNOSTIC,
value_fn=lambda data: data.alert["alert_unexpected_location"],
),
StarlinkBinarySensorEntityDescription(
key="connection",
device_class=BinarySensorDeviceClass.CONNECTIVITY,
value_fn=lambda data: data.status["state"] == "CONNECTED",
),
]
@ -14,7 +14,7 @@
},
"reconfigure": {
"title": "Reconfigure your Tado",
"description": "Reconfigure the entry, for your account: `{username}`.",
"description": "Reconfigure the entry for your account: `{username}`.",
"data": {
"password": "[%key:common::config_flow::data::password%]"
},
@ -25,7 +25,7 @@
},
"error": {
"unknown": "[%key:common::config_flow::error::unknown%]",
"no_homes": "There are no homes linked to this tado account.",
"no_homes": "There are no homes linked to this Tado account.",
"invalid_auth": "[%key:common::config_flow::error::invalid_auth%]",
"cannot_connect": "[%key:common::config_flow::error::cannot_connect%]"
}
@ -33,7 +33,7 @@
"options": {
"step": {
"init": {
"description": "Fallback mode lets you choose when to fallback to Smart Schedule from your manual zone overlay. (NEXT_TIME_BLOCK:= Change at next Smart Schedule change; MANUAL:= Dont change until you cancel; TADO_DEFAULT:= Change based on your setting in Tado App).",
"description": "Fallback mode lets you choose when to fallback to Smart Schedule from your manual zone overlay. (NEXT_TIME_BLOCK:= Change at next Smart Schedule change; MANUAL:= Don't change until you cancel; TADO_DEFAULT:= Change based on your setting in the Tado app).",
"data": {
"fallback": "Choose fallback mode."
},
@ -102,11 +102,11 @@
},
"time_period": {
"name": "Time period",
"description": "Choose this or Overlay. Set the time period for the change if you want to be specific. Alternatively use Overlay."
"description": "Choose this or 'Overlay'. Set the time period for the change if you want to be specific."
},
"requested_overlay": {
"name": "Overlay",
"description": "Choose this or Time Period. Allows you to choose an overlay. MANUAL:=Overlay until user removes; NEXT_TIME_BLOCK:=Overlay until next timeblock; TADO_DEFAULT:=Overlay based on tado app setting."
"description": "Choose this or 'Time period'. Allows you to choose an overlay. MANUAL:=Overlay until user removes; NEXT_TIME_BLOCK:=Overlay until next timeblock; TADO_DEFAULT:=Overlay based on Tado app setting."
}
}
},
@ -151,8 +151,8 @@
},
"issues": {
"water_heater_fallback": {
"title": "Tado Water Heater entities now support fallback options",
"description": "Due to added support for water heaters entities, these entities may use different overlay. Please configure integration entity and tado app water heater zone overlay options. Otherwise, please configure the integration entity and Tado app water heater zone overlay options (under Settings -> Rooms & Devices -> Hot Water)."
"title": "Tado water heater entities now support fallback options",
"description": "Due to added support for water heaters entities, these entities may use a different overlay. Please configure the integration entity and Tado app water heater zone overlay options (under Settings -> Rooms & Devices -> Hot Water)."
}
}
}
@ -7,6 +7,7 @@ from pyvesync import VeSync
from homeassistant.config_entries import ConfigEntry
from homeassistant.const import CONF_PASSWORD, CONF_USERNAME, Platform
from homeassistant.core import HomeAssistant, ServiceCall
from homeassistant.helpers import entity_registry as er
from homeassistant.helpers.dispatcher import async_dispatcher_send

from .common import async_generate_device_list
@ -91,3 +92,37 @@ async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
hass.data.pop(DOMAIN)

return unload_ok


async def async_migrate_entry(hass: HomeAssistant, config_entry: ConfigEntry) -> bool:
"""Migrate old entry."""
_LOGGER.debug(
"Migrating VeSync config entry: %s minor version: %s",
config_entry.version,
config_entry.minor_version,
)
if config_entry.minor_version == 1:
# Migrate switch/outlets entity to a new unique ID
_LOGGER.debug("Migrating VeSync config entry from version 1 to version 2")
entity_registry = er.async_get(hass)
registry_entries = er.async_entries_for_config_entry(
entity_registry, config_entry.entry_id
)
for reg_entry in registry_entries:
if "-" not in reg_entry.unique_id and reg_entry.entity_id.startswith(
Platform.SWITCH
):
_LOGGER.debug(
"Migrating switch/outlet entity from unique_id: %s to unique_id: %s",
reg_entry.unique_id,
reg_entry.unique_id + "-device_status",
)
entity_registry.async_update_entity(
reg_entry.entity_id,
new_unique_id=reg_entry.unique_id + "-device_status",
)
else:
_LOGGER.debug("Skipping entity with unique_id: %s", reg_entry.unique_id)
hass.config_entries.async_update_entry(config_entry, minor_version=2)

return True
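Note: this migration only runs because `MINOR_VERSION` is bumped to 2 in the config flow below; Home Assistant invokes `async_migrate_entry` for entries stored with an older version. A condensed sketch of the unique-ID rewrite above (hypothetical helper name, same switch-only filter):

from homeassistant.config_entries import ConfigEntry
from homeassistant.core import HomeAssistant
from homeassistant.helpers import entity_registry as er


async def migrate_switch_unique_ids(hass: HomeAssistant, entry: ConfigEntry) -> None:
    """Append -device_status to legacy VeSync switch unique IDs."""
    registry = er.async_get(hass)
    for reg_entry in er.async_entries_for_config_entry(registry, entry.entry_id):
        if "-" not in reg_entry.unique_id and reg_entry.entity_id.startswith("switch."):
            registry.async_update_entity(
                reg_entry.entity_id,
                new_unique_id=f"{reg_entry.unique_id}-device_status",
            )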
@ -24,6 +24,7 @@ class VeSyncFlowHandler(ConfigFlow, domain=DOMAIN):
"""Handle a config flow."""

VERSION = 1
MINOR_VERSION = 2

@callback
def _show_form(self, errors: dict[str, str] | None = None) -> ConfigFlowResult:
@ -12,5 +12,5 @@
"documentation": "https://www.home-assistant.io/integrations/vesync",
"iot_class": "cloud_polling",
"loggers": ["pyvesync"],
"requirements": ["pyvesync==2.1.16"]
"requirements": ["pyvesync==2.1.17"]
}
@ -83,6 +83,7 @@ class VeSyncSwitchHA(VeSyncBaseSwitch, SwitchEntity):
) -> None:
"""Initialize the VeSync switch device."""
super().__init__(plug, coordinator)
self._attr_unique_id = f"{super().unique_id}-device_status"
self.smartplug = plug


@ -94,4 +95,5 @@ class VeSyncLightSwitch(VeSyncBaseSwitch, SwitchEntity):
) -> None:
"""Initialize Light Switch device class."""
super().__init__(switch, coordinator)
self._attr_unique_id = f"{super().unique_id}-device_status"
self.switch = switch

homeassistant/generated/dhcp.py (generated)
@ -616,6 +616,10 @@ DHCP: Final[list[dict[str, str | bool]]] = [
"hostname": "hub*",
"macaddress": "286D97*",
},
{
"domain": "smlight",
"registered_devices": True,
},
{
"domain": "solaredge",
"hostname": "target",
@ -4,7 +4,7 @@ from __future__ import annotations

from abc import ABC, abstractmethod
from collections.abc import Callable
from dataclasses import dataclass
from dataclasses import dataclass, field as dc_field
from datetime import timedelta
from decimal import Decimal
from enum import Enum
@ -36,6 +36,7 @@ from homeassistant.exceptions import HomeAssistantError
from homeassistant.util import dt as dt_util, yaml as yaml_util
from homeassistant.util.hass_dict import HassKey
from homeassistant.util.json import JsonObjectType
from homeassistant.util.ulid import ulid_now

from . import (
area_registry as ar,
@ -139,6 +140,8 @@ class ToolInput:

tool_name: str
tool_args: dict[str, Any]
# Using lambda for default to allow patching in tests
id: str = dc_field(default_factory=lambda: ulid_now())  # pylint: disable=unnecessary-lambda


class Tool:
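Note: the lambda around `ulid_now` is what keeps the default patchable; it re-resolves the module global on every instantiation, whereas `default_factory=ulid_now` would freeze the original function object. A self-contained demonstration with illustrative names:

from dataclasses import dataclass, field


def make_id() -> str:
    return "real-id"


@dataclass
class Early:
    id: str = field(default_factory=make_id)  # captures the function object now


@dataclass
class Late:
    id: str = field(default_factory=lambda: make_id())  # re-resolves at call time


make_id = lambda: "patched-id"  # roughly what unittest.mock.patch does
assert Early().id == "real-id"  # still the original factory
assert Late().id == "patched-id"  # the lambda sees the patched name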
@ -29,7 +29,7 @@ certifi>=2021.5.30
|
||||
ciso8601==2.3.2
|
||||
cronsim==2.6
|
||||
cryptography==44.0.0
|
||||
dbus-fast==2.31.0
|
||||
dbus-fast==2.32.0
|
||||
fnv-hash-fast==1.2.2
|
||||
go2rtc-client==0.1.2
|
||||
ha-ffmpeg==3.2.2
|
||||
@ -53,7 +53,7 @@ psutil-home-assistant==0.0.1
|
||||
PyJWT==2.10.1
|
||||
pymicro-vad==1.0.1
|
||||
PyNaCl==1.5.0
|
||||
pyOpenSSL==24.3.0
|
||||
pyOpenSSL==25.0.0
|
||||
pyserial==3.5
|
||||
pyspeex-noise==1.0.2
|
||||
python-slugify==8.0.4
|
||||
|
@ -59,7 +59,7 @@ dependencies = [
|
||||
"cryptography==44.0.0",
|
||||
"Pillow==11.1.0",
|
||||
"propcache==0.2.1",
|
||||
"pyOpenSSL==24.3.0",
|
||||
"pyOpenSSL==25.0.0",
|
||||
"orjson==3.10.12",
|
||||
"packaging>=23.1",
|
||||
"psutil-home-assistant==0.0.1",
|
||||
|
2
requirements.txt
generated
2
requirements.txt
generated
@ -31,7 +31,7 @@ PyJWT==2.10.1
|
||||
cryptography==44.0.0
|
||||
Pillow==11.1.0
|
||||
propcache==0.2.1
|
||||
pyOpenSSL==24.3.0
|
||||
pyOpenSSL==25.0.0
|
||||
orjson==3.10.12
|
||||
packaging>=23.1
|
||||
psutil-home-assistant==0.0.1
|
||||
|
18
requirements_all.txt
generated
18
requirements_all.txt
generated
@ -600,7 +600,7 @@ bizkaibus==0.1.1
|
||||
|
||||
# homeassistant.components.eq3btsmart
|
||||
# homeassistant.components.esphome
|
||||
bleak-esphome==2.6.0
|
||||
bleak-esphome==2.7.0
|
||||
|
||||
# homeassistant.components.bluetooth
|
||||
bleak-retry-connector==3.8.0
|
||||
@ -741,7 +741,7 @@ datadog==0.15.0
|
||||
datapoint==0.9.9
|
||||
|
||||
# homeassistant.components.bluetooth
|
||||
dbus-fast==2.31.0
|
||||
dbus-fast==2.32.0
|
||||
|
||||
# homeassistant.components.debugpy
|
||||
debugpy==1.8.11
|
||||
@ -1437,9 +1437,6 @@ motioneye-client==0.3.14
|
||||
# homeassistant.components.bang_olufsen
|
||||
mozart-api==4.1.1.116.4
|
||||
|
||||
# homeassistant.components.onedrive
|
||||
msgraph-sdk==1.16.0
|
||||
|
||||
# homeassistant.components.mullvad
|
||||
mullvad-api==1.0.0
|
||||
|
||||
@ -1561,6 +1558,9 @@ omnilogic==0.4.5
|
||||
# homeassistant.components.ondilo_ico
|
||||
ondilo==0.5.0
|
||||
|
||||
# homeassistant.components.onedrive
|
||||
onedrive-personal-sdk==0.0.1
|
||||
|
||||
# homeassistant.components.onvif
|
||||
onvif-zeep-async==3.2.5
|
||||
|
||||
@ -2205,7 +2205,7 @@ pypalazzetti==0.1.19
|
||||
pypca==0.0.7
|
||||
|
||||
# homeassistant.components.lcn
|
||||
pypck==0.8.3
|
||||
pypck==0.8.5
|
||||
|
||||
# homeassistant.components.pjlink
|
||||
pypjlink2==1.2.1
|
||||
@ -2313,7 +2313,7 @@ pysmarty2==0.10.1
|
||||
pysml==0.0.12
|
||||
|
||||
# homeassistant.components.smlight
|
||||
pysmlight==0.2.1
|
||||
pysmlight==0.2.2
|
||||
|
||||
# homeassistant.components.snmp
|
||||
pysnmp==6.2.6
|
||||
@ -2391,7 +2391,7 @@ python-gitlab==1.6.0
|
||||
python-google-drive-api==0.0.2
|
||||
|
||||
# homeassistant.components.analytics_insights
|
||||
python-homeassistant-analytics==0.8.1
|
||||
python-homeassistant-analytics==0.9.0
|
||||
|
||||
# homeassistant.components.homewizard
|
||||
python-homewizard-energy==v8.3.2
|
||||
@ -2516,7 +2516,7 @@ pyvera==0.3.15
|
||||
pyversasense==0.0.6
|
||||
|
||||
# homeassistant.components.vesync
|
||||
pyvesync==2.1.16
|
||||
pyvesync==2.1.17
|
||||
|
||||
# homeassistant.components.vizio
|
||||
pyvizio==0.1.61
|
||||
|
@ -8,31 +8,31 @@
|
||||
-c homeassistant/package_constraints.txt
|
||||
-r requirements_test_pre_commit.txt
|
||||
astroid==3.3.8
|
||||
coverage==7.6.8
|
||||
coverage==7.6.10
|
||||
freezegun==1.5.1
|
||||
license-expression==30.4.0
|
||||
license-expression==30.4.1
|
||||
mock-open==1.4.0
|
||||
mypy-dev==1.16.0a1
|
||||
pre-commit==4.0.0
|
||||
pydantic==2.10.6
|
||||
pylint==3.3.3
|
||||
pylint-per-file-ignores==1.3.2
|
||||
pipdeptree==2.23.4
|
||||
pytest-asyncio==0.24.0
|
||||
pylint==3.3.4
|
||||
pylint-per-file-ignores==1.4.0
|
||||
pipdeptree==2.25.0
|
||||
pytest-asyncio==0.25.3
|
||||
pytest-aiohttp==1.0.5
|
||||
pytest-cov==6.0.0
|
||||
pytest-freezer==0.4.8
|
||||
pytest-github-actions-annotate-failures==0.2.0
|
||||
pytest-freezer==0.4.9
|
||||
pytest-github-actions-annotate-failures==0.3.0
|
||||
pytest-socket==0.7.0
|
||||
pytest-sugar==1.0.0
|
||||
pytest-timeout==2.3.1
|
||||
pytest-unordered==0.6.1
|
||||
pytest-picked==0.5.0
|
||||
pytest-picked==0.5.1
|
||||
pytest-xdist==3.6.1
|
||||
pytest==8.3.4
|
||||
requests-mock==1.12.1
|
||||
respx==0.22.0
|
||||
syrupy==4.8.0
|
||||
syrupy==4.8.1
|
||||
tqdm==4.66.5
|
||||
types-aiofiles==24.1.0.20241221
|
||||
types-atomicwrites==1.4.5.1
|
||||
|
18
requirements_test_all.txt
generated
18
requirements_test_all.txt
generated
@ -528,7 +528,7 @@ bimmer-connected[china]==0.17.2
|
||||
|
||||
# homeassistant.components.eq3btsmart
|
||||
# homeassistant.components.esphome
|
||||
bleak-esphome==2.6.0
|
||||
bleak-esphome==2.7.0
|
||||
|
||||
# homeassistant.components.bluetooth
|
||||
bleak-retry-connector==3.8.0
|
||||
@ -634,7 +634,7 @@ datadog==0.15.0
|
||||
datapoint==0.9.9
|
||||
|
||||
# homeassistant.components.bluetooth
|
||||
dbus-fast==2.31.0
|
||||
dbus-fast==2.32.0
|
||||
|
||||
# homeassistant.components.debugpy
|
||||
debugpy==1.8.11
|
||||
@ -1206,9 +1206,6 @@ motioneye-client==0.3.14
|
||||
# homeassistant.components.bang_olufsen
|
||||
mozart-api==4.1.1.116.4
|
||||
|
||||
# homeassistant.components.onedrive
|
||||
msgraph-sdk==1.16.0
|
||||
|
||||
# homeassistant.components.mullvad
|
||||
mullvad-api==1.0.0
|
||||
|
||||
@ -1306,6 +1303,9 @@ omnilogic==0.4.5
|
||||
# homeassistant.components.ondilo_ico
|
||||
ondilo==0.5.0
|
||||
|
||||
# homeassistant.components.onedrive
|
||||
onedrive-personal-sdk==0.0.1
|
||||
|
||||
# homeassistant.components.onvif
|
||||
onvif-zeep-async==3.2.5
|
||||
|
||||
@ -1795,7 +1795,7 @@ pyownet==0.10.0.post1
|
||||
pypalazzetti==0.1.19
|
||||
|
||||
# homeassistant.components.lcn
|
||||
pypck==0.8.3
|
||||
pypck==0.8.5
|
||||
|
||||
# homeassistant.components.pjlink
|
||||
pypjlink2==1.2.1
|
||||
@ -1882,7 +1882,7 @@ pysmarty2==0.10.1
|
||||
pysml==0.0.12
|
||||
|
||||
# homeassistant.components.smlight
|
||||
pysmlight==0.2.1
|
||||
pysmlight==0.2.2
|
||||
|
||||
# homeassistant.components.snmp
|
||||
pysnmp==6.2.6
|
||||
@ -1933,7 +1933,7 @@ python-fullykiosk==0.0.14
|
||||
python-google-drive-api==0.0.2
|
||||
|
||||
# homeassistant.components.analytics_insights
|
||||
python-homeassistant-analytics==0.8.1
|
||||
python-homeassistant-analytics==0.9.0
|
||||
|
||||
# homeassistant.components.homewizard
|
||||
python-homewizard-energy==v8.3.2
|
||||
@ -2031,7 +2031,7 @@ pyuptimerobot==22.2.0
|
||||
pyvera==0.3.15
|
||||
|
||||
# homeassistant.components.vesync
|
||||
pyvesync==2.1.16
|
||||
pyvesync==2.1.17
|
||||
|
||||
# homeassistant.components.vizio
|
||||
pyvizio==0.1.61
|
||||
|
2
script/hassfest/docker/Dockerfile
generated
2
script/hassfest/docker/Dockerfile
generated
@ -24,7 +24,7 @@ RUN --mount=from=ghcr.io/astral-sh/uv:0.5.21,source=/uv,target=/bin/uv \
|
||||
--no-cache \
|
||||
-c /usr/src/homeassistant/homeassistant/package_constraints.txt \
|
||||
-r /usr/src/homeassistant/requirements.txt \
|
||||
stdlib-list==0.10.0 pipdeptree==2.23.4 tqdm==4.66.5 ruff==0.9.1 \
|
||||
stdlib-list==0.10.0 pipdeptree==2.25.0 tqdm==4.66.5 ruff==0.9.1 \
|
||||
PyTurboJPEG==1.7.5 go2rtc-client==0.1.2 ha-ffmpeg==3.2.2 hassil==2.2.0 home-assistant-intents==2025.1.28 mutagen==1.47.0 pymicro-vad==1.0.1 pyspeex-noise==1.0.2
|
||||
|
||||
LABEL "name"="hassfest"
|
||||
|
@ -236,6 +236,7 @@ async def test_function_call(
|
||||
mock_tool.async_call.assert_awaited_once_with(
|
||||
hass,
|
||||
llm.ToolInput(
|
||||
id="toolu_0123456789AbCdEfGhIjKlM",
|
||||
tool_name="test_tool",
|
||||
tool_args={"param1": "test_value"},
|
||||
),
|
||||
@ -373,6 +374,7 @@ async def test_function_exception(
|
||||
mock_tool.async_call.assert_awaited_once_with(
|
||||
hass,
|
||||
llm.ToolInput(
|
||||
id="toolu_0123456789AbCdEfGhIjKlM",
|
||||
tool_name="test_tool",
|
||||
tool_args={"param1": "test_value"},
|
||||
),
|
||||
|
@ -3,6 +3,7 @@
|
||||
list([
|
||||
dict({
|
||||
'data': dict({
|
||||
'conversation_id': 'mock-ulid',
|
||||
'language': 'en',
|
||||
'pipeline': <ANY>,
|
||||
}),
|
||||
@ -32,7 +33,7 @@
|
||||
}),
|
||||
dict({
|
||||
'data': dict({
|
||||
'conversation_id': None,
|
||||
'conversation_id': 'mock-ulid',
|
||||
'device_id': None,
|
||||
'engine': 'conversation.home_assistant',
|
||||
'intent_input': 'test transcript',
|
||||
@ -94,6 +95,7 @@
|
||||
list([
|
||||
dict({
|
||||
'data': dict({
|
||||
'conversation_id': 'mock-ulid',
|
||||
'language': 'en',
|
||||
'pipeline': <ANY>,
|
||||
}),
|
||||
@ -123,7 +125,7 @@
|
||||
}),
|
||||
dict({
|
||||
'data': dict({
|
||||
'conversation_id': None,
|
||||
'conversation_id': 'mock-ulid',
|
||||
'device_id': None,
|
||||
'engine': 'conversation.home_assistant',
|
||||
'intent_input': 'test transcript',
|
||||
@ -185,6 +187,7 @@
|
||||
list([
|
||||
dict({
|
||||
'data': dict({
|
||||
'conversation_id': 'mock-ulid',
|
||||
'language': 'en',
|
||||
'pipeline': <ANY>,
|
||||
}),
|
||||
@ -214,7 +217,7 @@
|
||||
}),
|
||||
dict({
|
||||
'data': dict({
|
||||
'conversation_id': None,
|
||||
'conversation_id': 'mock-ulid',
|
||||
'device_id': None,
|
||||
'engine': 'conversation.home_assistant',
|
||||
'intent_input': 'test transcript',
|
||||
@ -276,6 +279,7 @@
|
||||
list([
|
||||
dict({
|
||||
'data': dict({
|
||||
'conversation_id': 'mock-ulid',
|
||||
'language': 'en',
|
||||
'pipeline': <ANY>,
|
||||
}),
|
||||
@ -329,7 +333,7 @@
|
||||
}),
|
||||
dict({
|
||||
'data': dict({
|
||||
'conversation_id': None,
|
||||
'conversation_id': 'mock-ulid',
|
||||
'device_id': None,
|
||||
'engine': 'conversation.home_assistant',
|
||||
'intent_input': 'test transcript',
|
||||
@ -391,6 +395,7 @@
|
||||
list([
|
||||
dict({
|
||||
'data': dict({
|
||||
'conversation_id': 'mock-ulid',
|
||||
'language': 'en',
|
||||
'pipeline': <ANY>,
|
||||
}),
|
||||
@ -427,6 +432,7 @@
|
||||
list([
|
||||
dict({
|
||||
'data': dict({
|
||||
'conversation_id': 'mock-conversation-id',
|
||||
'language': 'en',
|
||||
'pipeline': <ANY>,
|
||||
}),
|
||||
@ -434,7 +440,7 @@
|
||||
}),
|
||||
dict({
|
||||
'data': dict({
|
||||
'conversation_id': None,
|
||||
'conversation_id': 'mock-conversation-id',
|
||||
'device_id': None,
|
||||
'engine': 'conversation.home_assistant',
|
||||
'intent_input': 'test input',
|
||||
@ -478,6 +484,7 @@
|
||||
list([
|
||||
dict({
|
||||
'data': dict({
|
||||
'conversation_id': 'mock-conversation-id',
|
||||
'language': 'en',
|
||||
'pipeline': <ANY>,
|
||||
}),
|
||||
@ -485,7 +492,7 @@
|
||||
}),
|
||||
dict({
|
||||
'data': dict({
|
||||
'conversation_id': None,
|
||||
'conversation_id': 'mock-conversation-id',
|
||||
'device_id': None,
|
||||
'engine': 'conversation.home_assistant',
|
||||
'intent_input': 'test input',
|
||||
@ -529,6 +536,7 @@
|
||||
list([
|
||||
dict({
|
||||
'data': dict({
|
||||
'conversation_id': 'mock-conversation-id',
|
||||
'language': 'en',
|
||||
'pipeline': <ANY>,
|
||||
}),
|
||||
@ -536,7 +544,7 @@
|
||||
}),
|
||||
dict({
|
||||
'data': dict({
|
||||
'conversation_id': None,
|
||||
'conversation_id': 'mock-conversation-id',
|
||||
'device_id': None,
|
||||
'engine': 'conversation.home_assistant',
|
||||
'intent_input': 'test input',
|
||||
@ -580,6 +588,7 @@
|
||||
list([
|
||||
dict({
|
||||
'data': dict({
|
||||
'conversation_id': 'mock-conversation-id',
|
||||
'language': 'en',
|
||||
'pipeline': <ANY>,
|
||||
}),
|
||||
|
@ -1,6 +1,7 @@
# serializer version: 1
# name: test_audio_pipeline
dict({
'conversation_id': 'mock-ulid',
'language': 'en',
'pipeline': <ANY>,
'runner_data': dict({
@ -31,7 +32,7 @@
# ---
# name: test_audio_pipeline.3
dict({
'conversation_id': None,
'conversation_id': 'mock-ulid',
'device_id': None,
'engine': 'conversation.home_assistant',
'intent_input': 'test transcript',
@ -84,6 +85,7 @@
# ---
# name: test_audio_pipeline_debug
dict({
'conversation_id': 'mock-ulid',
'language': 'en',
'pipeline': <ANY>,
'runner_data': dict({
@ -114,7 +116,7 @@
# ---
# name: test_audio_pipeline_debug.3
dict({
'conversation_id': None,
'conversation_id': 'mock-ulid',
'device_id': None,
'engine': 'conversation.home_assistant',
'intent_input': 'test transcript',
@ -179,6 +181,7 @@
# ---
# name: test_audio_pipeline_with_enhancements
dict({
'conversation_id': 'mock-ulid',
'language': 'en',
'pipeline': <ANY>,
'runner_data': dict({
@ -209,7 +212,7 @@
# ---
# name: test_audio_pipeline_with_enhancements.3
dict({
'conversation_id': None,
'conversation_id': 'mock-ulid',
'device_id': None,
'engine': 'conversation.home_assistant',
'intent_input': 'test transcript',
@ -262,6 +265,7 @@
# ---
# name: test_audio_pipeline_with_wake_word_no_timeout
dict({
'conversation_id': 'mock-ulid',
'language': 'en',
'pipeline': <ANY>,
'runner_data': dict({
@ -314,7 +318,7 @@
# ---
# name: test_audio_pipeline_with_wake_word_no_timeout.5
dict({
'conversation_id': None,
'conversation_id': 'mock-ulid',
'device_id': None,
'engine': 'conversation.home_assistant',
'intent_input': 'test transcript',
@ -367,6 +371,7 @@
# ---
# name: test_audio_pipeline_with_wake_word_timeout
dict({
'conversation_id': 'mock-ulid',
'language': 'en',
'pipeline': <ANY>,
'runner_data': dict({
@ -399,6 +404,7 @@
# ---
# name: test_device_capture
dict({
'conversation_id': 'mock-ulid',
'language': 'en',
'pipeline': <ANY>,
'runner_data': dict({
@ -425,6 +431,7 @@
# ---
# name: test_device_capture_override
dict({
'conversation_id': 'mock-ulid',
'language': 'en',
'pipeline': <ANY>,
'runner_data': dict({
@ -473,6 +480,7 @@
# ---
# name: test_device_capture_queue_full
dict({
'conversation_id': 'mock-ulid',
'language': 'en',
'pipeline': <ANY>,
'runner_data': dict({
@ -512,6 +520,7 @@
# ---
# name: test_intent_failed
dict({
'conversation_id': 'mock-ulid',
'language': 'en',
'pipeline': <ANY>,
'runner_data': dict({
@ -522,7 +531,7 @@
# ---
# name: test_intent_failed.1
dict({
'conversation_id': None,
'conversation_id': 'mock-ulid',
'device_id': None,
'engine': 'conversation.home_assistant',
'intent_input': 'Are the lights on?',
@ -535,6 +544,7 @@
# ---
# name: test_intent_timeout
dict({
'conversation_id': 'mock-ulid',
'language': 'en',
'pipeline': <ANY>,
'runner_data': dict({
@ -545,7 +555,7 @@
# ---
# name: test_intent_timeout.1
dict({
'conversation_id': None,
'conversation_id': 'mock-ulid',
'device_id': None,
'engine': 'conversation.home_assistant',
'intent_input': 'Are the lights on?',
@ -564,6 +574,7 @@
# ---
# name: test_pipeline_empty_tts_output
dict({
'conversation_id': 'mock-ulid',
'language': 'en',
'pipeline': <ANY>,
'runner_data': dict({
@ -574,7 +585,7 @@
# ---
# name: test_pipeline_empty_tts_output.1
dict({
'conversation_id': None,
'conversation_id': 'mock-ulid',
'device_id': None,
'engine': 'conversation.home_assistant',
'intent_input': 'never mind',
@ -611,6 +622,7 @@
# ---
# name: test_stt_cooldown_different_ids
dict({
'conversation_id': 'mock-ulid',
'language': 'en',
'pipeline': <ANY>,
'runner_data': dict({
@ -621,6 +633,7 @@
# ---
# name: test_stt_cooldown_different_ids.1
dict({
'conversation_id': 'mock-ulid',
'language': 'en',
'pipeline': <ANY>,
'runner_data': dict({
@ -631,6 +644,7 @@
# ---
# name: test_stt_cooldown_same_id
dict({
'conversation_id': 'mock-ulid',
'language': 'en',
'pipeline': <ANY>,
'runner_data': dict({
@ -641,6 +655,7 @@
# ---
# name: test_stt_cooldown_same_id.1
dict({
'conversation_id': 'mock-ulid',
'language': 'en',
'pipeline': <ANY>,
'runner_data': dict({
@ -651,6 +666,7 @@
# ---
# name: test_stt_stream_failed
dict({
'conversation_id': 'mock-ulid',
'language': 'en',
'pipeline': <ANY>,
'runner_data': dict({
@ -677,6 +693,7 @@
# ---
# name: test_text_only_pipeline[extra_msg0]
dict({
'conversation_id': 'mock-conversation-id',
'language': 'en',
'pipeline': <ANY>,
'runner_data': dict({
@ -723,6 +740,7 @@
# ---
# name: test_text_only_pipeline[extra_msg1]
dict({
'conversation_id': 'mock-conversation-id',
'language': 'en',
'pipeline': <ANY>,
'runner_data': dict({
@ -775,6 +793,7 @@
# ---
# name: test_tts_failed
dict({
'conversation_id': 'mock-ulid',
'language': 'en',
'pipeline': <ANY>,
'runner_data': dict({
@ -796,6 +815,7 @@
# ---
# name: test_wake_word_cooldown_different_entities
dict({
'conversation_id': 'mock-ulid',
'language': 'en',
'pipeline': <ANY>,
'runner_data': dict({
@ -806,6 +826,7 @@
# ---
# name: test_wake_word_cooldown_different_entities.1
dict({
'conversation_id': 'mock-ulid',
'language': 'en',
'pipeline': <ANY>,
'runner_data': dict({
@ -857,6 +878,7 @@
# ---
# name: test_wake_word_cooldown_different_ids
dict({
'conversation_id': 'mock-ulid',
'language': 'en',
'pipeline': <ANY>,
'runner_data': dict({
@ -867,6 +889,7 @@
# ---
# name: test_wake_word_cooldown_different_ids.1
dict({
'conversation_id': 'mock-ulid',
'language': 'en',
'pipeline': <ANY>,
'runner_data': dict({
@ -921,6 +944,7 @@
# ---
# name: test_wake_word_cooldown_same_id
dict({
'conversation_id': 'mock-ulid',
'language': 'en',
'pipeline': <ANY>,
'runner_data': dict({
@ -931,6 +955,7 @@
# ---
# name: test_wake_word_cooldown_same_id.1
dict({
'conversation_id': 'mock-ulid',
'language': 'en',
'pipeline': <ANY>,
'runner_data': dict({

@ -1,11 +1,12 @@
"""Test Voice Assistant init."""

import asyncio
from collections.abc import Generator
from dataclasses import asdict
import itertools as it
from pathlib import Path
import tempfile
from unittest.mock import ANY, patch
from unittest.mock import ANY, Mock, patch
import wave

import hass_nabucasa
@ -41,6 +42,14 @@ from .conftest import (
from tests.typing import ClientSessionGenerator, WebSocketGenerator


@pytest.fixture(autouse=True)
def mock_ulid() -> Generator[Mock]:
"""Mock the ulid of chat sessions."""
with patch("homeassistant.helpers.chat_session.ulid_now") as mock_ulid_now:
mock_ulid_now.return_value = "mock-ulid"
yield mock_ulid_now


def process_events(events: list[assist_pipeline.PipelineEvent]) -> list[dict]:
"""Process events to remove dynamic values."""
processed = []
@ -684,7 +693,7 @@ async def test_wake_word_detection_aborted(
pipeline = assist_pipeline.pipeline.async_get_pipeline(hass, pipeline_id)

pipeline_input = assist_pipeline.pipeline.PipelineInput(
conversation_id=None,
conversation_id="mock-conversation-id",
device_id=None,
stt_metadata=stt.SpeechMetadata(
language="",
@ -771,7 +780,7 @@ async def test_tts_audio_output(

pipeline_input = assist_pipeline.pipeline.PipelineInput(
tts_input="This is a test.",
conversation_id=None,
conversation_id="mock-conversation-id",
device_id=None,
run=assist_pipeline.pipeline.PipelineRun(
hass,
@ -828,7 +837,7 @@ async def test_tts_wav_preferred_format(

pipeline_input = assist_pipeline.pipeline.PipelineInput(
tts_input="This is a test.",
conversation_id=None,
conversation_id="mock-conversation-id",
device_id=None,
run=assist_pipeline.pipeline.PipelineRun(
hass,
@ -896,7 +905,7 @@ async def test_tts_dict_preferred_format(

pipeline_input = assist_pipeline.pipeline.PipelineInput(
tts_input="This is a test.",
conversation_id=None,
conversation_id="mock-conversation-id",
device_id=None,
run=assist_pipeline.pipeline.PipelineRun(
hass,
@ -982,6 +991,7 @@ async def test_sentence_trigger_overrides_conversation_agent(

pipeline_input = assist_pipeline.pipeline.PipelineInput(
intent_input="test trigger sentence",
conversation_id="mock-conversation-id",
run=assist_pipeline.pipeline.PipelineRun(
hass,
context=Context(),
@ -1059,6 +1069,7 @@ async def test_prefer_local_intents(

pipeline_input = assist_pipeline.pipeline.PipelineInput(
intent_input="I'd like to order a stout please",
conversation_id="mock-conversation-id",
run=assist_pipeline.pipeline.PipelineRun(
hass,
context=Context(),
@ -1136,6 +1147,7 @@ async def test_stt_language_used_instead_of_conversation_language(

pipeline_input = assist_pipeline.pipeline.PipelineInput(
intent_input="test input",
conversation_id="mock-conversation-id",
run=assist_pipeline.pipeline.PipelineRun(
hass,
context=Context(),
@ -1210,6 +1222,7 @@ async def test_tts_language_used_instead_of_conversation_language(

pipeline_input = assist_pipeline.pipeline.PipelineInput(
intent_input="test input",
conversation_id="mock-conversation-id",
run=assist_pipeline.pipeline.PipelineRun(
hass,
context=Context(),
@ -1284,6 +1297,7 @@ async def test_pipeline_language_used_instead_of_conversation_language(

pipeline_input = assist_pipeline.pipeline.PipelineInput(
intent_input="test input",
conversation_id="mock-conversation-id",
run=assist_pipeline.pipeline.PipelineRun(
hass,
context=Context(),

@ -2,8 +2,9 @@

import asyncio
import base64
from collections.abc import Generator
from typing import Any
from unittest.mock import ANY, patch
from unittest.mock import ANY, Mock, patch

import pytest
from syrupy.assertion import SnapshotAssertion
@ -35,6 +36,14 @@ from tests.common import MockConfigEntry
from tests.typing import WebSocketGenerator


@pytest.fixture(autouse=True)
def mock_ulid() -> Generator[Mock]:
"""Mock the ulid of chat sessions."""
with patch("homeassistant.helpers.chat_session.ulid_now") as mock_ulid_now:
mock_ulid_now.return_value = "mock-ulid"
yield mock_ulid_now


@pytest.mark.parametrize(
"extra_msg",
[

@ -94,7 +94,9 @@ class MockAssistSatellite(AssistSatelliteEntity):
self, start_announcement: AssistSatelliteConfiguration
) -> None:
"""Start a conversation from the satellite."""
self.start_conversations.append((self._extra_system_prompt, start_announcement))
self.start_conversations.append(
(self._conversation_id, self._extra_system_prompt, start_announcement)
)


@pytest.fixture

@ -1,7 +1,8 @@
"""Test the Assist Satellite entity."""

import asyncio
from unittest.mock import patch
from collections.abc import Generator
from unittest.mock import Mock, patch

import pytest

@ -31,6 +32,14 @@ from . import ENTITY_ID
from .conftest import MockAssistSatellite


@pytest.fixture
def mock_chat_session_conversation_id() -> Generator[Mock]:
"""Mock the ulid library."""
with patch("homeassistant.helpers.chat_session.ulid_now") as mock_ulid_now:
mock_ulid_now.return_value = "mock-conversation-id"
yield mock_ulid_now


@pytest.fixture(autouse=True)
async def set_pipeline_tts(hass: HomeAssistant, init_components: ConfigEntry) -> None:
"""Set up a pipeline with a TTS engine."""
@ -487,6 +496,7 @@ async def test_vad_sensitivity_entity_not_found(
"extra_system_prompt": "Better system prompt",
},
(
"mock-conversation-id",
"Better system prompt",
AssistSatelliteAnnouncement(
message="Hello",
@ -502,6 +512,7 @@ async def test_vad_sensitivity_entity_not_found(
"start_media_id": "media-source://given",
},
(
"mock-conversation-id",
"Hello",
AssistSatelliteAnnouncement(
message="Hello",
@ -514,6 +525,7 @@ async def test_vad_sensitivity_entity_not_found(
(
{"start_media_id": "http://example.com/given.mp3"},
(
"mock-conversation-id",
None,
AssistSatelliteAnnouncement(
message="",
@ -525,6 +537,7 @@ async def test_vad_sensitivity_entity_not_found(
),
],
)
@pytest.mark.usefixtures("mock_chat_session_conversation_id")
async def test_start_conversation(
hass: HomeAssistant,
init_components: ConfigEntry,

@ -9,13 +9,13 @@ from syrupy.assertion import SnapshotAssertion
import voluptuous as vol

from homeassistant.components.conversation import (
Content,
AssistantContent,
ConversationInput,
ConverseError,
NativeContent,
ToolResultContent,
async_get_chat_log,
)
from homeassistant.components.conversation.session import DATA_CHAT_HISTORY
from homeassistant.components.conversation.chat_log import DATA_CHAT_HISTORY
from homeassistant.core import Context, HomeAssistant
from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers import chat_session, llm
@ -40,7 +40,7 @@ def mock_conversation_input(hass: HomeAssistant) -> ConversationInput:
@pytest.fixture
def mock_ulid() -> Generator[Mock]:
"""Mock the ulid library."""
with patch("homeassistant.util.ulid.ulid_now") as mock_ulid_now:
with patch("homeassistant.helpers.chat_session.ulid_now") as mock_ulid_now:
mock_ulid_now.return_value = "mock-ulid"
yield mock_ulid_now

@ -56,13 +56,13 @@ async def test_cleanup(
):
conversation_id = session.conversation_id
# Add message so it persists
chat_log.async_add_message(
Content(
role="assistant",
agent_id=mock_conversation_input.agent_id,
content="",
async for _tool_result in chat_log.async_add_assistant_content(
AssistantContent(
agent_id="mock-agent-id",
content="Hey!",
)
)
):
pytest.fail("should not reach here")

assert conversation_id in hass.data[DATA_CHAT_HISTORY]

@ -79,7 +79,7 @@ async def test_cleanup(
assert conversation_id not in hass.data[DATA_CHAT_HISTORY]


async def test_add_message(
async def test_default_content(
hass: HomeAssistant, mock_conversation_input: ConversationInput
) -> None:
"""Test filtering of messages."""
@ -87,95 +87,11 @@ async def test_add_message(
chat_session.async_get_chat_session(hass) as session,
async_get_chat_log(hass, session, mock_conversation_input) as chat_log,
):
assert len(chat_log.messages) == 2

with pytest.raises(ValueError):
chat_log.async_add_message(
Content(role="system", agent_id=None, content="")
)

# No 2 user messages in a row
assert chat_log.messages[1].role == "user"

with pytest.raises(ValueError):
chat_log.async_add_message(Content(role="user", agent_id=None, content=""))

# No 2 assistant messages in a row
chat_log.async_add_message(Content(role="assistant", agent_id=None, content=""))
assert len(chat_log.messages) == 3
assert chat_log.messages[-1].role == "assistant"

with pytest.raises(ValueError):
chat_log.async_add_message(
Content(role="assistant", agent_id=None, content="")
)


async def test_message_filtering(
hass: HomeAssistant, mock_conversation_input: ConversationInput
) -> None:
"""Test filtering of messages."""
with (
chat_session.async_get_chat_session(hass) as session,
async_get_chat_log(hass, session, mock_conversation_input) as chat_log,
):
messages = chat_log.async_get_messages(agent_id=None)
assert len(messages) == 2
assert messages[0] == Content(
role="system",
agent_id=None,
content="",
)
assert messages[1] == Content(
role="user",
agent_id="mock-agent-id",
content=mock_conversation_input.text,
)
# Cannot add a second user message in a row
with pytest.raises(ValueError):
chat_log.async_add_message(
Content(
role="user",
agent_id="mock-agent-id",
content="Hey!",
)
)

chat_log.async_add_message(
Content(
role="assistant",
agent_id="mock-agent-id",
content="Hey!",
)
)
# Different agent, native messages will be filtered out.
chat_log.async_add_message(
NativeContent(agent_id="another-mock-agent-id", content=1)
)
chat_log.async_add_message(NativeContent(agent_id="mock-agent-id", content=1))
# A non-native message from another agent is not filtered out.
chat_log.async_add_message(
Content(
role="assistant",
agent_id="another-mock-agent-id",
content="Hi!",
)
)

assert len(chat_log.messages) == 6

messages = chat_log.async_get_messages(agent_id="mock-agent-id")
assert len(messages) == 5

assert messages[2] == Content(
role="assistant",
agent_id="mock-agent-id",
content="Hey!",
)
assert messages[3] == NativeContent(agent_id="mock-agent-id", content=1)
assert messages[4] == Content(
role="assistant", agent_id="another-mock-agent-id", content="Hi!"
)
assert len(chat_log.content) == 2
assert chat_log.content[0].role == "system"
assert chat_log.content[0].content == ""
assert chat_log.content[1].role == "user"
assert chat_log.content[1].content == mock_conversation_input.text


async def test_llm_api(
@ -268,12 +184,10 @@ async def test_template_variables(
),
)

assert chat_log.user_name == "Test User"

assert "The instance name is test home." in chat_log.messages[0].content
assert "The user name is Test User." in chat_log.messages[0].content
assert "The user id is 12345." in chat_log.messages[0].content
assert "The calling platform is test." in chat_log.messages[0].content
assert "The instance name is test home." in chat_log.content[0].content
assert "The user name is Test User." in chat_log.content[0].content
assert "The user id is 12345." in chat_log.content[0].content
assert "The calling platform is test." in chat_log.content[0].content


async def test_extra_systen_prompt(
@ -296,16 +210,16 @@ async def test_extra_systen_prompt(
user_llm_hass_api=None,
user_llm_prompt=None,
)
chat_log.async_add_message(
Content(
role="assistant",
async for _tool_result in chat_log.async_add_assistant_content(
AssistantContent(
agent_id="mock-agent-id",
content="Hey!",
)
)
):
pytest.fail("should not reach here")

assert chat_log.extra_system_prompt == extra_system_prompt
assert chat_log.messages[0].content.endswith(extra_system_prompt)
assert chat_log.content[0].content.endswith(extra_system_prompt)

# Verify that follow-up conversations with no system prompt take previous one
conversation_id = chat_log.conversation_id
@ -323,7 +237,7 @@ async def test_extra_systen_prompt(
)

assert chat_log.extra_system_prompt == extra_system_prompt
assert chat_log.messages[0].content.endswith(extra_system_prompt)
assert chat_log.content[0].content.endswith(extra_system_prompt)

# Verify that we take new system prompts
mock_conversation_input.extra_system_prompt = extra_system_prompt2
@ -338,17 +252,17 @@ async def test_extra_systen_prompt(
user_llm_hass_api=None,
user_llm_prompt=None,
)
chat_log.async_add_message(
Content(
role="assistant",
async for _tool_result in chat_log.async_add_assistant_content(
AssistantContent(
agent_id="mock-agent-id",
content="Hey!",
)
)
):
pytest.fail("should not reach here")

assert chat_log.extra_system_prompt == extra_system_prompt2
assert chat_log.messages[0].content.endswith(extra_system_prompt2)
assert extra_system_prompt not in chat_log.messages[0].content
assert chat_log.content[0].content.endswith(extra_system_prompt2)
assert extra_system_prompt not in chat_log.content[0].content

# Verify that follow-up conversations with no system prompt take previous one
mock_conversation_input.extra_system_prompt = None
@ -365,7 +279,7 @@ async def test_extra_systen_prompt(
)

assert chat_log.extra_system_prompt == extra_system_prompt2
assert chat_log.messages[0].content.endswith(extra_system_prompt2)
assert chat_log.content[0].content.endswith(extra_system_prompt2)


async def test_tool_call(
@ -383,8 +297,7 @@ async def test_tool_call(
mock_tool.async_call.return_value = "Test response"

with patch(
"homeassistant.components.conversation.session.llm.AssistAPI._async_get_tools",
return_value=[],
"homeassistant.helpers.llm.AssistAPI._async_get_tools", return_value=[]
) as mock_get_tools:
mock_get_tools.return_value = [mock_tool]

@ -398,14 +311,29 @@ async def test_tool_call(
user_llm_hass_api="assist",
user_llm_prompt=None,
)
result = await chat_log.async_call_tool(
llm.ToolInput(
tool_name="test_tool",
tool_args={"param1": "Test Param"},
result = None
async for tool_result_content in chat_log.async_add_assistant_content(
AssistantContent(
agent_id=mock_conversation_input.agent_id,
content="",
tool_calls=[
llm.ToolInput(
id="mock-tool-call-id",
tool_name="test_tool",
tool_args={"param1": "Test Param"},
)
],
)
)
):
assert result is None
result = tool_result_content

assert result == "Test response"
assert result == ToolResultContent(
agent_id=mock_conversation_input.agent_id,
tool_call_id="mock-tool-call-id",
tool_result="Test response",
tool_name="test_tool",
)


async def test_tool_call_exception(
@ -423,8 +351,7 @@ async def test_tool_call_exception(
mock_tool.async_call.side_effect = HomeAssistantError("Test error")

with patch(
"homeassistant.components.conversation.session.llm.AssistAPI._async_get_tools",
return_value=[],
"homeassistant.helpers.llm.AssistAPI._async_get_tools", return_value=[]
) as mock_get_tools:
mock_get_tools.return_value = [mock_tool]

@ -438,11 +365,26 @@ async def test_tool_call_exception(
user_llm_hass_api="assist",
user_llm_prompt=None,
)
result = await chat_log.async_call_tool(
llm.ToolInput(
tool_name="test_tool",
tool_args={"param1": "Test Param"},
result = None
async for tool_result_content in chat_log.async_add_assistant_content(
AssistantContent(
agent_id=mock_conversation_input.agent_id,
content="",
tool_calls=[
llm.ToolInput(
id="mock-tool-call-id",
tool_name="test_tool",
tool_args={"param1": "Test Param"},
)
],
)
)
):
assert result is None
result = tool_result_content

assert result == {"error": "HomeAssistantError", "error_text": "Test error"}
assert result == ToolResultContent(
agent_id=mock_conversation_input.agent_id,
tool_call_id="mock-tool-call-id",
tool_result={"error": "HomeAssistantError", "error_text": "Test error"},
tool_name="test_tool",
)

@ -36,6 +36,13 @@ def freeze_the_time():
yield


@pytest.fixture(autouse=True)
def mock_ulid_tools():
"""Mock generated ULIDs for tool calls."""
with patch("homeassistant.helpers.llm.ulid_now", return_value="mock-tool-call"):
yield


@pytest.mark.parametrize(
"agent_id", [None, "conversation.google_generative_ai_conversation"]
)
@ -177,6 +184,7 @@ async def test_chat_history(
"homeassistant.components.google_generative_ai_conversation.conversation.llm.AssistAPI._async_get_tools"
)
@pytest.mark.usefixtures("mock_init_component")
@pytest.mark.usefixtures("mock_ulid_tools")
async def test_function_call(
mock_get_tools,
hass: HomeAssistant,
@ -256,6 +264,7 @@ async def test_function_call(
mock_tool.async_call.assert_awaited_once_with(
hass,
llm.ToolInput(
id="mock-tool-call",
tool_name="test_tool",
tool_args={
"param1": ["test_value", "param1's value"],
@ -287,9 +296,7 @@ async def test_function_call(
detail_event = trace_events[1]
assert "Answer in plain text" in detail_event["data"]["messages"][0]["content"]
assert [
p.function_response.name
for p in detail_event["data"]["messages"][2]["content"].parts
if p.function_response
p["tool_name"] for p in detail_event["data"]["messages"][2]["tool_calls"]
] == ["test_tool"]


@ -362,6 +369,7 @@ async def test_function_call_without_parameters(
mock_tool.async_call.assert_awaited_once_with(
hass,
llm.ToolInput(
id="mock-tool-call",
tool_name="test_tool",
tool_args={},
),
@ -451,6 +459,7 @@ async def test_function_exception(
mock_tool.async_call.assert_awaited_once_with(
hass,
llm.ToolInput(
id="mock-tool-call",
tool_name="test_tool",
tool_args={"param1": 1},
),
@ -605,6 +614,7 @@ async def test_template_variables(
mock_chat.send_message_async.return_value = chat_response
mock_part = MagicMock()
mock_part.text = "Model response"
mock_part.function_call = None
chat_response.parts = [mock_part]
result = await conversation.async_converse(
hass, "hello", None, context, agent_id=mock_config_entry.entry_id

@ -324,6 +324,24 @@ TEST_JOB_DONE = supervisor_jobs.Job(
errors=[],
child_jobs=[],
)
TEST_RESTORE_JOB_DONE_WITH_ERROR = supervisor_jobs.Job(
name="backup_manager_partial_restore",
reference="1ef41507",
uuid=UUID(TEST_JOB_ID),
progress=0.0,
stage="copy_additional_locations",
done=True,
errors=[
supervisor_jobs.JobError(
type="BackupInvalidError",
message=(
"Backup was made on supervisor version 2025.02.2.dev3105, "
"can't restore on 2025.01.2.dev3105"
),
)
],
child_jobs=[],
)


@pytest.fixture(autouse=True)
@ -1946,6 +1964,97 @@ async def test_reader_writer_restore_error(
assert response["error"]["code"] == expected_error_code


@pytest.mark.usefixtures("hassio_client", "setup_integration")
async def test_reader_writer_restore_late_error(
hass: HomeAssistant,
hass_ws_client: WebSocketGenerator,
supervisor_client: AsyncMock,
) -> None:
"""Test restoring a backup with error."""
client = await hass_ws_client(hass)
supervisor_client.backups.partial_restore.return_value.job_id = TEST_JOB_ID
supervisor_client.backups.list.return_value = [TEST_BACKUP]
supervisor_client.backups.backup_info.return_value = TEST_BACKUP_DETAILS
supervisor_client.jobs.get_job.return_value = TEST_JOB_NOT_DONE

await client.send_json_auto_id({"type": "backup/subscribe_events"})
response = await client.receive_json()
assert response["event"] == {"manager_state": "idle"}
response = await client.receive_json()
assert response["success"]

await client.send_json_auto_id(
{"type": "backup/restore", "agent_id": "hassio.local", "backup_id": "abc123"}
)
response = await client.receive_json()
assert response["event"] == {
"manager_state": "restore_backup",
"reason": None,
"stage": None,
"state": "in_progress",
}

supervisor_client.backups.partial_restore.assert_called_once_with(
"abc123",
supervisor_backups.PartialRestoreOptions(
addons=None,
background=True,
folders=None,
homeassistant=True,
location=None,
password=None,
),
)

event = {
"event": "job",
"data": {
"name": "backup_manager_partial_restore",
"reference": "7c54aeed",
"uuid": TEST_JOB_ID,
"progress": 0,
"stage": None,
"done": True,
"parent_id": None,
"errors": [
{
"type": "BackupInvalidError",
"message": (
"Backup was made on supervisor version 2025.02.2.dev3105, can't"
" restore on 2025.01.2.dev3105. Must update supervisor first."
),
}
],
"created": "2025-02-03T08:27:49.297997+00:00",
},
}
await client.send_json_auto_id({"type": "supervisor/event", "data": event})
response = await client.receive_json()
assert response["success"]

response = await client.receive_json()
assert response["event"] == {
"manager_state": "restore_backup",
"reason": "backup_reader_writer_error",
"stage": None,
"state": "failed",
}

response = await client.receive_json()
assert response["event"] == {"manager_state": "idle"}

response = await client.receive_json()
assert not response["success"]
assert response["error"] == {
"code": "home_assistant_error",
"message": (
"Restore failed: [{'type': 'BackupInvalidError', 'message': \"Backup "
"was made on supervisor version 2025.02.2.dev3105, can't restore on "
'2025.01.2.dev3105. Must update supervisor first."}]'
),
}


@pytest.mark.parametrize(
("backup", "backup_details", "parameters", "expected_error"),
[
@ -1999,15 +2108,40 @@ async def test_reader_writer_restore_wrong_parameters(
}


@pytest.mark.parametrize(
("get_job_result", "last_non_idle_event"),
[
(
TEST_JOB_DONE,
{
"manager_state": "restore_backup",
"reason": "",
"stage": None,
"state": "completed",
},
),
(
TEST_RESTORE_JOB_DONE_WITH_ERROR,
{
"manager_state": "restore_backup",
"reason": "unknown_error",
"stage": None,
"state": "failed",
},
),
],
)
@pytest.mark.usefixtures("hassio_client")
async def test_restore_progress_after_restart(
hass: HomeAssistant,
hass_ws_client: WebSocketGenerator,
supervisor_client: AsyncMock,
get_job_result: supervisor_jobs.Job,
last_non_idle_event: dict[str, Any],
) -> None:
"""Test restore backup progress after restart."""

supervisor_client.jobs.get_job.return_value = TEST_JOB_DONE
supervisor_client.jobs.get_job.return_value = get_job_result

with patch.dict(os.environ, MOCK_ENVIRON | {RESTORE_JOB_ID_ENV: TEST_JOB_ID}):
assert await async_setup_component(hass, BACKUP_DOMAIN, {BACKUP_DOMAIN: {}})
@ -2018,12 +2152,7 @@ async def test_restore_progress_after_restart(
response = await client.receive_json()

assert response["success"]
assert response["result"]["last_non_idle_event"] == {
"manager_state": "restore_backup",
"reason": "",
"stage": None,
"state": "completed",
}
assert response["result"]["last_non_idle_event"] == last_non_idle_event
assert response["result"]["state"] == "idle"

@ -18,6 +18,13 @@ from homeassistant.helpers import intent, llm
from tests.common import MockConfigEntry


@pytest.fixture(autouse=True)
def mock_ulid_tools():
"""Mock generated ULIDs for tool calls."""
with patch("homeassistant.helpers.llm.ulid_now", return_value="mock-tool-call"):
yield


@pytest.mark.parametrize("agent_id", [None, "conversation.mock_title"])
async def test_chat(
hass: HomeAssistant,
@ -205,6 +212,7 @@ async def test_function_call(
mock_tool.async_call.assert_awaited_once_with(
hass,
llm.ToolInput(
id="mock-tool-call",
tool_name="test_tool",
tool_args=expected_tool_args,
),
@ -285,6 +293,7 @@ async def test_function_exception(
mock_tool.async_call.assert_awaited_once_with(
hass,
llm.ToolInput(
id="mock-tool-call",
tool_name="test_tool",
tool_args={"param1": "test_value"},
),

@ -1,18 +1,9 @@
"""Fixtures for OneDrive tests."""

from collections.abc import AsyncIterator, Generator
from html import escape
from json import dumps
import time
from unittest.mock import AsyncMock, MagicMock, patch

from httpx import Response
from msgraph.generated.models.drive_item import DriveItem
from msgraph.generated.models.drive_item_collection_response import (
DriveItemCollectionResponse,
)
from msgraph.generated.models.upload_session import UploadSession
from msgraph_core.models import LargeFileUploadSession
import pytest

from homeassistant.components.application_credentials import (
@ -23,7 +14,13 @@ from homeassistant.components.onedrive.const import DOMAIN, OAUTH_SCOPES
from homeassistant.core import HomeAssistant
from homeassistant.setup import async_setup_component

from .const import BACKUP_METADATA, CLIENT_ID, CLIENT_SECRET
from .const import (
CLIENT_ID,
CLIENT_SECRET,
MOCK_APPROOT,
MOCK_BACKUP_FILE,
MOCK_BACKUP_FOLDER,
)

from tests.common import MockConfigEntry

@ -70,96 +67,41 @@ def mock_config_entry(expires_at: int, scopes: list[str]) -> MockConfigEntry:
)


@pytest.fixture
def mock_adapter() -> Generator[MagicMock]:
"""Return a mocked GraphAdapter."""
with (
patch(
"homeassistant.components.onedrive.config_flow.GraphRequestAdapter",
autospec=True,
) as mock_adapter,
patch(
"homeassistant.components.onedrive.backup.GraphRequestAdapter",
new=mock_adapter,
),
):
adapter = mock_adapter.return_value
adapter.get_http_response_message.return_value = Response(
status_code=200,
json={
"parentReference": {"driveId": "mock_drive_id"},
"createdBy": {"user": {"displayName": "John Doe"}},
},
)
yield adapter
adapter.send_async.return_value = LargeFileUploadSession(
next_expected_ranges=["2-"]
)


@pytest.fixture(autouse=True)
def mock_graph_client(mock_adapter: MagicMock) -> Generator[MagicMock]:
def mock_onedrive_client() -> Generator[MagicMock]:
"""Return a mocked GraphServiceClient."""
with (
patch(
"homeassistant.components.onedrive.config_flow.GraphServiceClient",
"homeassistant.components.onedrive.config_flow.OneDriveClient",
autospec=True,
) as graph_client,
) as onedrive_client,
patch(
"homeassistant.components.onedrive.GraphServiceClient",
new=graph_client,
"homeassistant.components.onedrive.OneDriveClient",
new=onedrive_client,
),
):
client = graph_client.return_value
client = onedrive_client.return_value
client.get_approot.return_value = MOCK_APPROOT
client.create_folder.return_value = MOCK_BACKUP_FOLDER
client.list_drive_items.return_value = [MOCK_BACKUP_FILE]
client.get_drive_item.return_value = MOCK_BACKUP_FILE

client.request_adapter = mock_adapter
class MockStreamReader:
async def iter_chunked(self, chunk_size: int) -> AsyncIterator[bytes]:
yield b"backup data"

drives = client.drives.by_drive_id.return_value
drives.special.by_drive_item_id.return_value.get = AsyncMock(
return_value=DriveItem(id="approot")
)

drive_items = drives.items.by_drive_item_id.return_value
drive_items.get = AsyncMock(return_value=DriveItem(id="folder_id"))
drive_items.children.post = AsyncMock(return_value=DriveItem(id="folder_id"))
drive_items.children.get = AsyncMock(
return_value=DriveItemCollectionResponse(
value=[
DriveItem(
id=BACKUP_METADATA["backup_id"],
description=escape(dumps(BACKUP_METADATA)),
),
DriveItem(),
]
)
)
drive_items.delete = AsyncMock(return_value=None)
drive_items.create_upload_session.post = AsyncMock(
return_value=UploadSession(upload_url="https://test.tld")
)
drive_items.patch = AsyncMock(return_value=None)

async def generate_bytes() -> AsyncIterator[bytes]:
"""Asynchronous generator that yields bytes."""
yield b"backup data"

drive_items.content.get = AsyncMock(
return_value=Response(status_code=200, content=generate_bytes())
)
client.download_drive_item.return_value = MockStreamReader()

yield client


@pytest.fixture
def mock_drive_items(mock_graph_client: MagicMock) -> MagicMock:
"""Return a mocked DriveItems."""
return mock_graph_client.drives.by_drive_id.return_value.items.by_drive_item_id.return_value


@pytest.fixture
def mock_get_special_folder(mock_graph_client: MagicMock) -> MagicMock:
"""Mock the get special folder method."""
return mock_graph_client.drives.by_drive_id.return_value.special.by_drive_item_id.return_value.get
def mock_large_file_upload_client() -> Generator[AsyncMock]:
"""Return a mocked LargeFileUploadClient upload."""
with patch(
"homeassistant.components.onedrive.backup.LargeFileUploadClient.upload"
) as mock_upload:
yield mock_upload


@pytest.fixture
@ -179,10 +121,3 @@ def mock_instance_id() -> Generator[AsyncMock]:
return_value="9f86d081884c7d659a2feaa0c55ad015a3bf4f1b2b0",
):
yield


@pytest.fixture(autouse=True)
def mock_asyncio_sleep() -> Generator[AsyncMock]:
"""Mock asyncio.sleep."""
with patch("homeassistant.components.onedrive.backup.asyncio.sleep", AsyncMock()):
yield

@ -1,5 +1,18 @@
"""Consts for OneDrive tests."""

from html import escape
from json import dumps

from onedrive_personal_sdk.models.items import (
AppRoot,
Contributor,
File,
Folder,
Hashes,
ItemParentReference,
User,
)

CLIENT_ID = "1234"
CLIENT_SECRET = "5678"

@ -17,3 +30,48 @@ BACKUP_METADATA = {
"protected": False,
"size": 34519040,
}

CONTRIBUTOR = Contributor(
user=User(
display_name="John Doe",
id="id",
email="john@doe.com",
)
)

MOCK_APPROOT = AppRoot(
id="id",
child_count=0,
size=0,
name="name",
parent_reference=ItemParentReference(
drive_id="mock_drive_id", id="id", path="path"
),
created_by=CONTRIBUTOR,
)

MOCK_BACKUP_FOLDER = Folder(
id="id",
name="name",
size=0,
child_count=0,
parent_reference=ItemParentReference(
drive_id="mock_drive_id", id="id", path="path"
),
created_by=CONTRIBUTOR,
)

MOCK_BACKUP_FILE = File(
id="id",
name="23e64aec.tar",
size=34519040,
parent_reference=ItemParentReference(
drive_id="mock_drive_id", id="id", path="path"
),
hashes=Hashes(
quick_xor_hash="hash",
),
mime_type="application/x-tar",
description=escape(dumps(BACKUP_METADATA)),
created_by=CONTRIBUTOR,
)

@ -3,15 +3,14 @@
from __future__ import annotations

from collections.abc import AsyncGenerator
from html import escape
from io import StringIO
from json import dumps
from unittest.mock import Mock, patch

from httpx import TimeoutException
from kiota_abstractions.api_error import APIError
from msgraph.generated.models.drive_item import DriveItem
from msgraph_core.models import LargeFileUploadSession
from onedrive_personal_sdk.exceptions import (
AuthenticationError,
HashMismatchError,
OneDriveException,
)
import pytest

from homeassistant.components.backup import DOMAIN as BACKUP_DOMAIN, AgentBackup
@ -102,14 +101,10 @@ async def test_agents_list_backups(
async def test_agents_get_backup(
hass: HomeAssistant,
hass_ws_client: WebSocketGenerator,
mock_drive_items: MagicMock,
mock_config_entry: MockConfigEntry,
) -> None:
"""Test agent get backup."""

mock_drive_items.get = AsyncMock(
return_value=DriveItem(description=escape(dumps(BACKUP_METADATA)))
)
backup_id = BACKUP_METADATA["backup_id"]
client = await hass_ws_client(hass)
await client.send_json_auto_id({"type": "backup/details", "backup_id": backup_id})
@ -140,7 +135,7 @@ async def test_agents_get_backup(
async def test_agents_delete(
hass: HomeAssistant,
hass_ws_client: WebSocketGenerator,
mock_drive_items: MagicMock,
mock_onedrive_client: MagicMock,
) -> None:
"""Test agent delete backup."""
client = await hass_ws_client(hass)
@ -155,37 +150,15 @@ async def test_agents_delete(

assert response["success"]
assert response["result"] == {"agent_errors": {}}
mock_drive_items.delete.assert_called_once()


async def test_agents_delete_not_found_does_not_throw(
hass: HomeAssistant,
hass_ws_client: WebSocketGenerator,
mock_drive_items: MagicMock,
) -> None:
"""Test agent delete backup."""
mock_drive_items.children.get = AsyncMock(return_value=[])
client = await hass_ws_client(hass)

await client.send_json_auto_id(
{
"type": "backup/delete",
"backup_id": BACKUP_METADATA["backup_id"],
}
)
response = await client.receive_json()

assert response["success"]
assert response["result"] == {"agent_errors": {}}
assert mock_drive_items.delete.call_count == 0
mock_onedrive_client.delete_drive_item.assert_called_once()


async def test_agents_upload(
hass_client: ClientSessionGenerator,
caplog: pytest.LogCaptureFixture,
mock_drive_items: MagicMock,
mock_onedrive_client: MagicMock,
mock_large_file_upload_client: AsyncMock,
mock_config_entry: MockConfigEntry,
mock_adapter: MagicMock,
) -> None:
"""Test agent upload backup."""
client = await hass_client()
@ -200,7 +173,6 @@ async def test_agents_upload(
return_value=test_backup,
),
patch("pathlib.Path.open") as mocked_open,
patch("homeassistant.components.onedrive.backup.UPLOAD_CHUNK_SIZE", 3),
):
mocked_open.return_value.read = Mock(side_effect=[b"test", b""])
fetch_backup.return_value = test_backup
@ -211,31 +183,22 @@ async def test_agents_upload(

assert resp.status == 201
assert f"Uploading backup {test_backup.backup_id}" in caplog.text
mock_drive_items.create_upload_session.post.assert_called_once()
mock_drive_items.patch.assert_called_once()
assert mock_adapter.send_async.call_count == 2
assert mock_adapter.method_calls[0].args[0].content == b"tes"
assert mock_adapter.method_calls[0].args[0].headers.get("Content-Range") == {
"bytes 0-2/34519040"
}
assert mock_adapter.method_calls[1].args[0].content == b"t"
assert mock_adapter.method_calls[1].args[0].headers.get("Content-Range") == {
"bytes 3-3/34519040"
}
mock_large_file_upload_client.assert_called_once()
mock_onedrive_client.update_drive_item.assert_called_once()


async def test_broken_upload_session(
async def test_agents_upload_corrupt_upload(
hass_client: ClientSessionGenerator,
caplog: pytest.LogCaptureFixture,
mock_drive_items: MagicMock,
mock_onedrive_client: MagicMock,
mock_large_file_upload_client: AsyncMock,
mock_config_entry: MockConfigEntry,
) -> None:
"""Test broken upload session."""
"""Test hash validation fails."""
mock_large_file_upload_client.side_effect = HashMismatchError("test")
client = await hass_client()
test_backup = AgentBackup.from_dict(BACKUP_METADATA)

mock_drive_items.create_upload_session.post = AsyncMock(return_value=None)

with (
patch(
"homeassistant.components.backup.manager.BackupManager.async_get_backup",
@ -254,152 +217,18 @@ async def test_broken_upload_session(
)

assert resp.status == 201
assert "Failed to start backup upload" in caplog.text


@pytest.mark.parametrize(
"side_effect",
[
APIError(response_status_code=500),
TimeoutException("Timeout"),
],
)
async def test_agents_upload_errors_retried(
hass_client: ClientSessionGenerator,
caplog: pytest.LogCaptureFixture,
mock_drive_items: MagicMock,
mock_config_entry: MockConfigEntry,
mock_adapter: MagicMock,
side_effect: Exception,
) -> None:
"""Test agent upload backup."""
client = await hass_client()
test_backup = AgentBackup.from_dict(BACKUP_METADATA)

mock_adapter.send_async.side_effect = [
side_effect,
LargeFileUploadSession(next_expected_ranges=["2-"]),
LargeFileUploadSession(next_expected_ranges=["2-"]),
]

with (
patch(
"homeassistant.components.backup.manager.BackupManager.async_get_backup",
) as fetch_backup,
patch(
"homeassistant.components.backup.manager.read_backup",
return_value=test_backup,
),
patch("pathlib.Path.open") as mocked_open,
patch("homeassistant.components.onedrive.backup.UPLOAD_CHUNK_SIZE", 3),
):
mocked_open.return_value.read = Mock(side_effect=[b"test", b""])
fetch_backup.return_value = test_backup
resp = await client.post(
f"/api/backup/upload?agent_id={DOMAIN}.{mock_config_entry.unique_id}",
data={"file": StringIO("test")},
)

assert resp.status == 201
assert mock_adapter.send_async.call_count == 3
assert f"Uploading backup {test_backup.backup_id}" in caplog.text
mock_drive_items.patch.assert_called_once()


async def test_agents_upload_4xx_errors_not_retried(
hass_client: ClientSessionGenerator,
caplog: pytest.LogCaptureFixture,
mock_drive_items: MagicMock,
mock_config_entry: MockConfigEntry,
mock_adapter: MagicMock,
) -> None:
"""Test agent upload backup."""
client = await hass_client()
test_backup = AgentBackup.from_dict(BACKUP_METADATA)

mock_adapter.send_async.side_effect = APIError(response_status_code=404)

with (
patch(
"homeassistant.components.backup.manager.BackupManager.async_get_backup",
) as fetch_backup,
patch(
"homeassistant.components.backup.manager.read_backup",
return_value=test_backup,
),
patch("pathlib.Path.open") as mocked_open,
patch("homeassistant.components.onedrive.backup.UPLOAD_CHUNK_SIZE", 3),
):
mocked_open.return_value.read = Mock(side_effect=[b"test", b""])
fetch_backup.return_value = test_backup
resp = await client.post(
f"/api/backup/upload?agent_id={DOMAIN}.{mock_config_entry.unique_id}",
data={"file": StringIO("test")},
)

assert resp.status == 201
assert mock_adapter.send_async.call_count == 1
assert f"Uploading backup {test_backup.backup_id}" in caplog.text
assert mock_drive_items.patch.call_count == 0
assert "Backup operation failed" in caplog.text


@pytest.mark.parametrize(
("side_effect", "error"),
[
(APIError(response_status_code=500), "Backup operation failed"),
(TimeoutException("Timeout"), "Backup operation timed out"),
],
)
async def test_agents_upload_fails_after_max_retries(
hass_client: ClientSessionGenerator,
caplog: pytest.LogCaptureFixture,
mock_drive_items: MagicMock,
mock_config_entry: MockConfigEntry,
mock_adapter: MagicMock,
side_effect: Exception,
error: str,
) -> None:
"""Test agent upload backup."""
client = await hass_client()
test_backup = AgentBackup.from_dict(BACKUP_METADATA)

mock_adapter.send_async.side_effect = side_effect

with (
patch(
"homeassistant.components.backup.manager.BackupManager.async_get_backup",
) as fetch_backup,
patch(
"homeassistant.components.backup.manager.read_backup",
return_value=test_backup,
),
patch("pathlib.Path.open") as mocked_open,
patch("homeassistant.components.onedrive.backup.UPLOAD_CHUNK_SIZE", 3),
):
mocked_open.return_value.read = Mock(side_effect=[b"test", b""])
fetch_backup.return_value = test_backup
resp = await client.post(
f"/api/backup/upload?agent_id={DOMAIN}.{mock_config_entry.unique_id}",
data={"file": StringIO("test")},
)

assert resp.status == 201
assert mock_adapter.send_async.call_count == 6
assert f"Uploading backup {test_backup.backup_id}" in caplog.text
assert mock_drive_items.patch.call_count == 0
assert error in caplog.text
mock_large_file_upload_client.assert_called_once()
assert mock_onedrive_client.update_drive_item.call_count == 0
assert "Hash validation failed, backup file might be corrupt" in caplog.text


async def test_agents_download(
hass_client: ClientSessionGenerator,
mock_drive_items: MagicMock,
mock_onedrive_client: MagicMock,
mock_config_entry: MockConfigEntry,
) -> None:
"""Test agent download backup."""
mock_drive_items.get = AsyncMock(
return_value=DriveItem(description=escape(dumps(BACKUP_METADATA)))
)
client = await hass_client()
backup_id = BACKUP_METADATA["backup_id"]

@ -408,29 +237,30 @@ async def test_agents_download(
)
assert resp.status == 200
assert await resp.content.read() == b"backup data"
mock_drive_items.content.get.assert_called_once()


@pytest.mark.parametrize(
("side_effect", "error"),
[
(
APIError(response_status_code=500),
OneDriveException(),
"Backup operation failed",
),
(TimeoutException("Timeout"), "Backup operation timed out"),
(TimeoutError(), "Backup operation timed out"),
],
)
async def test_delete_error(
hass: HomeAssistant,
hass_ws_client: WebSocketGenerator,
mock_drive_items: MagicMock,
mock_onedrive_client: MagicMock,
mock_config_entry: MockConfigEntry,
side_effect: Exception,
error: str,
) -> None:
"""Test error during delete."""
mock_drive_items.delete = AsyncMock(side_effect=side_effect)
mock_onedrive_client.delete_drive_item.side_effect = AsyncMock(
side_effect=side_effect
)

client = await hass_ws_client(hass)

@ -448,14 +278,35 @@ async def test_delete_error(
}


async def test_agents_delete_not_found_does_not_throw(
hass: HomeAssistant,
hass_ws_client: WebSocketGenerator,
mock_onedrive_client: MagicMock,
) -> None:
"""Test agent delete backup."""
mock_onedrive_client.list_drive_items.return_value = []
client = await hass_ws_client(hass)

await client.send_json_auto_id(
{
"type": "backup/delete",
"backup_id": BACKUP_METADATA["backup_id"],
}
)
response = await client.receive_json()

assert response["success"]
assert response["result"] == {"agent_errors": {}}


async def test_agents_backup_not_found(
hass: HomeAssistant,
hass_ws_client: WebSocketGenerator,
mock_drive_items: MagicMock,
mock_onedrive_client: MagicMock,
) -> None:
"""Test backup not found."""

mock_drive_items.children.get = AsyncMock(return_value=[])
mock_onedrive_client.list_drive_items.return_value = []
backup_id = BACKUP_METADATA["backup_id"]
client = await hass_ws_client(hass)
await client.send_json_auto_id({"type": "backup/details", "backup_id": backup_id})
@ -468,13 +319,13 @@ async def test_agents_backup_not_found(
async def test_reauth_on_403(
hass: HomeAssistant,
hass_ws_client: WebSocketGenerator,
mock_drive_items: MagicMock,
mock_onedrive_client: MagicMock,
mock_config_entry: MockConfigEntry,
) -> None:
"""Test we re-authenticate on 403."""

mock_drive_items.children.get = AsyncMock(
side_effect=APIError(response_status_code=403)
mock_onedrive_client.list_drive_items.side_effect = AuthenticationError(
403, "Auth failed"
)
backup_id = BACKUP_METADATA["backup_id"]
client = await hass_ws_client(hass)
@ -483,7 +334,7 @@ async def test_reauth_on_403(

assert response["success"]
assert response["result"]["agent_errors"] == {
f"{DOMAIN}.{mock_config_entry.unique_id}": "Backup operation failed"
f"{DOMAIN}.{mock_config_entry.unique_id}": "Authentication error"
}

await hass.async_block_till_done()

@ -3,8 +3,7 @@
from http import HTTPStatus
from unittest.mock import AsyncMock, MagicMock

from httpx import Response
from kiota_abstractions.api_error import APIError
from onedrive_personal_sdk.exceptions import OneDriveException
import pytest

from homeassistant import config_entries
@ -20,7 +19,7 @@ from homeassistant.data_entry_flow import FlowResultType
from homeassistant.helpers import config_entry_oauth2_flow

from . import setup_integration
from .const import CLIENT_ID
from .const import CLIENT_ID, MOCK_APPROOT

from tests.common import MockConfigEntry
from tests.test_util.aiohttp import AiohttpClientMocker
@ -89,25 +88,52 @@ async def test_full_flow(
assert result["data"][CONF_TOKEN]["refresh_token"] == "mock-refresh-token"


@pytest.mark.usefixtures("current_request_with_host")
async def test_full_flow_with_owner_not_found(
hass: HomeAssistant,
hass_client_no_auth: ClientSessionGenerator,
aioclient_mock: AiohttpClientMocker,
mock_setup_entry: AsyncMock,
mock_onedrive_client: MagicMock,
) -> None:
"""Ensure we get a default title if the drive's owner can't be read."""

mock_onedrive_client.get_approot.return_value.created_by.user = None

result = await hass.config_entries.flow.async_init(
DOMAIN, context={"source": config_entries.SOURCE_USER}
)
await _do_get_token(hass, result, hass_client_no_auth, aioclient_mock)
result = await hass.config_entries.flow.async_configure(result["flow_id"])

assert result["type"] is FlowResultType.CREATE_ENTRY
assert len(hass.config_entries.async_entries(DOMAIN)) == 1
assert len(mock_setup_entry.mock_calls) == 1
assert result["title"] == "OneDrive"
assert result["result"].unique_id == "mock_drive_id"
assert result["data"][CONF_TOKEN][CONF_ACCESS_TOKEN] == "mock-access-token"
assert result["data"][CONF_TOKEN]["refresh_token"] == "mock-refresh-token"


@pytest.mark.usefixtures("current_request_with_host")
@pytest.mark.parametrize(
("exception", "error"),
[
(Exception, "unknown"),
(APIError, "connection_error"),
(OneDriveException, "connection_error"),
],
)
async def test_flow_errors(
hass: HomeAssistant,
hass_client_no_auth: ClientSessionGenerator,
aioclient_mock: AiohttpClientMocker,
mock_adapter: MagicMock,
mock_onedrive_client: MagicMock,
exception: Exception,
error: str,
) -> None:
"""Test errors during flow."""

mock_adapter.get_http_response_message.side_effect = exception
mock_onedrive_client.get_approot.side_effect = exception

result = await hass.config_entries.flow.async_init(
DOMAIN, context={"source": config_entries.SOURCE_USER}
@ -172,15 +198,12 @@ async def test_reauth_flow_id_changed(
aioclient_mock: AiohttpClientMocker,
mock_setup_entry: AsyncMock,
mock_config_entry: MockConfigEntry,
mock_adapter: MagicMock,
mock_onedrive_client: MagicMock,
) -> None:
"""Test that the reauth flow fails on a different drive id."""
mock_adapter.get_http_response_message.return_value = Response(
status_code=200,
json={
"parentReference": {"driveId": "other_drive_id"},
},
)
app_root = MOCK_APPROOT
app_root.parent_reference.drive_id = "other_drive_id"
mock_onedrive_client.get_approot.return_value = app_root

await setup_integration(hass, mock_config_entry)

@ -2,7 +2,7 @@

from unittest.mock import MagicMock

from kiota_abstractions.api_error import APIError
from onedrive_personal_sdk.exceptions import AuthenticationError, OneDriveException
import pytest

from homeassistant.config_entries import ConfigEntryState
@ -31,82 +31,31 @@ async def test_load_unload_config_entry(
@pytest.mark.parametrize(
    ("side_effect", "state"),
    [
        (APIError(response_status_code=403), ConfigEntryState.SETUP_ERROR),
        (APIError(response_status_code=500), ConfigEntryState.SETUP_RETRY),
        (AuthenticationError(403, "Auth failed"), ConfigEntryState.SETUP_ERROR),
        (OneDriveException(), ConfigEntryState.SETUP_RETRY),
    ],
)
async def test_approot_errors(
    hass: HomeAssistant,
    mock_config_entry: MockConfigEntry,
    mock_get_special_folder: MagicMock,
    mock_onedrive_client: MagicMock,
    side_effect: Exception,
    state: ConfigEntryState,
) -> None:
    """Test errors during approot retrieval."""
    mock_get_special_folder.side_effect = side_effect
    mock_onedrive_client.get_approot.side_effect = side_effect
    await setup_integration(hass, mock_config_entry)
    assert mock_config_entry.state is state


async def test_faulty_approot(
async def test_get_integration_folder_error(
    hass: HomeAssistant,
    mock_config_entry: MockConfigEntry,
    mock_get_special_folder: MagicMock,
    mock_onedrive_client: MagicMock,
    caplog: pytest.LogCaptureFixture,
) -> None:
    """Test faulty approot retrieval."""
    mock_get_special_folder.return_value = None
    await setup_integration(hass, mock_config_entry)
    assert mock_config_entry.state is ConfigEntryState.SETUP_RETRY
    assert "Failed to get approot folder" in caplog.text


async def test_faulty_integration_folder(
    hass: HomeAssistant,
    mock_config_entry: MockConfigEntry,
    mock_drive_items: MagicMock,
    caplog: pytest.LogCaptureFixture,
) -> None:
    """Test faulty approot retrieval."""
    mock_drive_items.get.return_value = None
    mock_onedrive_client.create_folder.side_effect = OneDriveException()
    await setup_integration(hass, mock_config_entry)
    assert mock_config_entry.state is ConfigEntryState.SETUP_RETRY
    assert "Failed to get backups_9f86d081 folder" in caplog.text


async def test_500_error_during_backup_folder_get(
    hass: HomeAssistant,
    mock_config_entry: MockConfigEntry,
    mock_drive_items: MagicMock,
    caplog: pytest.LogCaptureFixture,
) -> None:
    """Test error during backup folder creation."""
    mock_drive_items.get.side_effect = APIError(response_status_code=500)
    await setup_integration(hass, mock_config_entry)
    assert mock_config_entry.state is ConfigEntryState.SETUP_RETRY
    assert "Failed to get backups_9f86d081 folder" in caplog.text


async def test_error_during_backup_folder_creation(
    hass: HomeAssistant,
    mock_config_entry: MockConfigEntry,
    mock_drive_items: MagicMock,
    caplog: pytest.LogCaptureFixture,
) -> None:
    """Test error during backup folder creation."""
    mock_drive_items.get.side_effect = APIError(response_status_code=404)
    mock_drive_items.children.post.side_effect = APIError()
    await setup_integration(hass, mock_config_entry)
    assert mock_config_entry.state is ConfigEntryState.SETUP_RETRY
    assert "Failed to create backups_9f86d081 folder" in caplog.text


async def test_successful_backup_folder_creation(
    hass: HomeAssistant,
    mock_config_entry: MockConfigEntry,
    mock_drive_items: MagicMock,
) -> None:
    """Test successful backup folder creation."""
    mock_drive_items.get.side_effect = APIError(response_status_code=404)
    await setup_integration(hass, mock_config_entry)
    assert mock_config_entry.state is ConfigEntryState.LOADED
@ -12,12 +12,14 @@ from homeassistant.components.openai_conversation.const import (
    CONF_CHAT_MODEL,
    CONF_MAX_TOKENS,
    CONF_PROMPT,
    CONF_REASONING_EFFORT,
    CONF_RECOMMENDED,
    CONF_TEMPERATURE,
    CONF_TOP_P,
    DOMAIN,
    RECOMMENDED_CHAT_MODEL,
    RECOMMENDED_MAX_TOKENS,
    RECOMMENDED_REASONING_EFFORT,
    RECOMMENDED_TOP_P,
)
from homeassistant.const import CONF_LLM_HASS_API
@ -88,6 +90,27 @@ async def test_options(
    assert options["data"][CONF_CHAT_MODEL] == RECOMMENDED_CHAT_MODEL


async def test_options_unsupported_model(
    hass: HomeAssistant, mock_config_entry, mock_init_component
) -> None:
    """Test the options form giving error about models not supported."""
    options_flow = await hass.config_entries.options.async_init(
        mock_config_entry.entry_id
    )
    result = await hass.config_entries.options.async_configure(
        options_flow["flow_id"],
        {
            CONF_RECOMMENDED: False,
            CONF_PROMPT: "Speak like a pirate",
            CONF_CHAT_MODEL: "o1-mini",
            CONF_LLM_HASS_API: "assist",
        },
    )
    await hass.async_block_till_done()
    assert result["type"] is FlowResultType.FORM
    assert result["errors"] == {"chat_model": "model_not_supported"}


@pytest.mark.parametrize(
    ("side_effect", "error"),
    [
@ -148,6 +171,7 @@ async def test_form_invalid_auth(hass: HomeAssistant, side_effect, error) -> Non
            CONF_CHAT_MODEL: RECOMMENDED_CHAT_MODEL,
            CONF_TOP_P: RECOMMENDED_TOP_P,
            CONF_MAX_TOKENS: RECOMMENDED_MAX_TOKENS,
            CONF_REASONING_EFFORT: RECOMMENDED_REASONING_EFFORT,
        },
    ),
    (
@ -158,6 +182,7 @@ async def test_form_invalid_auth(hass: HomeAssistant, side_effect, error) -> Non
            CONF_CHAT_MODEL: RECOMMENDED_CHAT_MODEL,
            CONF_TOP_P: RECOMMENDED_TOP_P,
            CONF_MAX_TOKENS: RECOMMENDED_MAX_TOKENS,
            CONF_REASONING_EFFORT: RECOMMENDED_REASONING_EFFORT,
        },
        {
            CONF_RECOMMENDED: True,
@ -195,6 +195,7 @@ async def test_function_call(
    mock_tool.async_call.assert_awaited_once_with(
        hass,
        llm.ToolInput(
            id="call_AbCdEfGhIjKlMnOpQrStUvWx",
            tool_name="test_tool",
            tool_args={"param1": "test_value"},
        ),
@ -359,6 +360,7 @@ async def test_function_exception(
    mock_tool.async_call.assert_awaited_once_with(
        hass,
        llm.ToolInput(
            id="call_AbCdEfGhIjKlMnOpQrStUvWx",
            tool_name="test_tool",
            tool_args={"param1": "test_value"},
        ),
tests/components/schedule/snapshots/test_init.ambr (new file, 59 lines)
@ -0,0 +1,59 @@
# serializer version: 1
# name: test_service_get[schedule.from_storage-get-after-update]
  dict({
    'friday': list([
    ]),
    'monday': list([
    ]),
    'saturday': list([
    ]),
    'sunday': list([
    ]),
    'thursday': list([
    ]),
    'tuesday': list([
    ]),
    'wednesday': list([
      dict({
        'from': datetime.time(17, 0),
        'to': datetime.time(19, 0),
      }),
    ]),
  })
# ---
# name: test_service_get[schedule.from_storage-get]
  dict({
    'friday': list([
      dict({
        'data': dict({
          'party_level': 'epic',
        }),
        'from': datetime.time(17, 0),
        'to': datetime.time(23, 59, 59),
      }),
    ]),
    'monday': list([
    ]),
    'saturday': list([
      dict({
        'from': datetime.time(0, 0),
        'to': datetime.time(23, 59, 59),
      }),
    ]),
    'sunday': list([
      dict({
        'data': dict({
          'entry': 'VIPs only',
        }),
        'from': datetime.time(0, 0),
        'to': datetime.time(23, 59, 59, 999999),
      }),
    ]),
    'thursday': list([
    ]),
    'tuesday': list([
    ]),
    'wednesday': list([
    ]),
  })
# ---
@ -8,10 +8,12 @@ from unittest.mock import patch

from freezegun.api import FrozenDateTimeFactory
import pytest
from syrupy.assertion import SnapshotAssertion

from homeassistant.components.schedule import STORAGE_VERSION, STORAGE_VERSION_MINOR
from homeassistant.components.schedule.const import (
    ATTR_NEXT_EVENT,
    CONF_ALL_DAYS,
    CONF_DATA,
    CONF_FRIDAY,
    CONF_FROM,
@ -23,12 +25,14 @@ from homeassistant.components.schedule.const import (
    CONF_TUESDAY,
    CONF_WEDNESDAY,
    DOMAIN,
    SERVICE_GET,
)
from homeassistant.const import (
    ATTR_EDITABLE,
    ATTR_FRIENDLY_NAME,
    ATTR_ICON,
    ATTR_NAME,
    CONF_ENTITY_ID,
    CONF_ICON,
    CONF_ID,
    CONF_NAME,
@ -754,3 +758,66 @@ async def test_ws_create(
    assert result["party_mode"][CONF_MONDAY] == [
        {CONF_FROM: "12:00:00", CONF_TO: saved_to}
    ]


async def test_service_get(
    hass: HomeAssistant,
    hass_ws_client: WebSocketGenerator,
    snapshot: SnapshotAssertion,
    schedule_setup: Callable[..., Coroutine[Any, Any, bool]],
) -> None:
    """Test getting a single schedule via service."""
    assert await schedule_setup()

    entity_id = "schedule.from_storage"

    # Test retrieving a single schedule via service call
    service_result = await hass.services.async_call(
        DOMAIN,
        SERVICE_GET,
        {
            CONF_ENTITY_ID: entity_id,
        },
        blocking=True,
        return_response=True,
    )
    result = service_result.get(entity_id)

    assert set(result) == CONF_ALL_DAYS
    assert result == snapshot(name=f"{entity_id}-get")

    # Now we update the schedule via WS
    client = await hass_ws_client(hass)
    await client.send_json(
        {
            "id": 1,
            "type": f"{DOMAIN}/update",
            f"{DOMAIN}_id": entity_id.rsplit(".", maxsplit=1)[-1],
            CONF_NAME: "Party pooper",
            CONF_ICON: "mdi:party-pooper",
            CONF_MONDAY: [],
            CONF_TUESDAY: [],
            CONF_WEDNESDAY: [{CONF_FROM: "17:00:00", CONF_TO: "19:00:00"}],
            CONF_THURSDAY: [],
            CONF_FRIDAY: [],
            CONF_SATURDAY: [],
            CONF_SUNDAY: [],
        }
    )
    resp = await client.receive_json()
    assert resp["success"]

    # Test retrieving the schedule via service call after WS update
    service_result = await hass.services.async_call(
        DOMAIN,
        SERVICE_GET,
        {
            CONF_ENTITY_ID: entity_id,
        },
        blocking=True,
        return_response=True,
    )
    result = service_result.get(entity_id)

    assert set(result) == CONF_ALL_DAYS
    assert result == snapshot(name=f"{entity_id}-get-after-update")
@ -180,6 +180,7 @@ MOCK_CONFIG = {
        "xcounts": {"expr": None, "unit": None},
        "xfreq": {"expr": None, "unit": None},
    },
    "flood:0": {"id": 0, "name": "Test name"},
    "light:0": {"name": "test light_0"},
    "light:1": {"name": "test light_1"},
    "light:2": {"name": "test light_2"},
@ -326,6 +327,7 @@ MOCK_STATUS_RPC = {
    "em1:1": {"act_power": 123.3},
    "em1data:0": {"total_act_energy": 123456.4},
    "em1data:1": {"total_act_energy": 987654.3},
    "flood:0": {"id": 0, "alarm": False, "mute": False},
    "thermostat:0": {
        "id": 0,
        "enable": True,
@ -46,3 +46,96 @@
    'state': 'off',
  })
# ---
# name: test_rpc_flood_entities[binary_sensor.test_name_flood-entry]
  EntityRegistryEntrySnapshot({
    'aliases': set({
    }),
    'area_id': None,
    'capabilities': None,
    'config_entry_id': <ANY>,
    'device_class': None,
    'device_id': <ANY>,
    'disabled_by': None,
    'domain': 'binary_sensor',
    'entity_category': None,
    'entity_id': 'binary_sensor.test_name_flood',
    'has_entity_name': False,
    'hidden_by': None,
    'icon': None,
    'id': <ANY>,
    'labels': set({
    }),
    'name': None,
    'options': dict({
    }),
    'original_device_class': <BinarySensorDeviceClass.MOISTURE: 'moisture'>,
    'original_icon': None,
    'original_name': 'Test name flood',
    'platform': 'shelly',
    'previous_unique_id': None,
    'supported_features': 0,
    'translation_key': None,
    'unique_id': '123456789ABC-flood:0-flood',
    'unit_of_measurement': None,
  })
# ---
# name: test_rpc_flood_entities[binary_sensor.test_name_flood-state]
  StateSnapshot({
    'attributes': ReadOnlyDict({
      'device_class': 'moisture',
      'friendly_name': 'Test name flood',
    }),
    'context': <ANY>,
    'entity_id': 'binary_sensor.test_name_flood',
    'last_changed': <ANY>,
    'last_reported': <ANY>,
    'last_updated': <ANY>,
    'state': 'off',
  })
# ---
# name: test_rpc_flood_entities[binary_sensor.test_name_mute-entry]
  EntityRegistryEntrySnapshot({
    'aliases': set({
    }),
    'area_id': None,
    'capabilities': None,
    'config_entry_id': <ANY>,
    'device_class': None,
    'device_id': <ANY>,
    'disabled_by': None,
    'domain': 'binary_sensor',
    'entity_category': <EntityCategory.DIAGNOSTIC: 'diagnostic'>,
    'entity_id': 'binary_sensor.test_name_mute',
    'has_entity_name': False,
    'hidden_by': None,
    'icon': None,
    'id': <ANY>,
    'labels': set({
    }),
    'name': None,
    'options': dict({
    }),
    'original_device_class': None,
    'original_icon': None,
    'original_name': 'Test name mute',
    'platform': 'shelly',
    'previous_unique_id': None,
    'supported_features': 0,
    'translation_key': None,
    'unique_id': '123456789ABC-flood:0-mute',
    'unit_of_measurement': None,
  })
# ---
# name: test_rpc_flood_entities[binary_sensor.test_name_mute-state]
  StateSnapshot({
    'attributes': ReadOnlyDict({
      'friendly_name': 'Test name mute',
    }),
    'context': <ANY>,
    'entity_id': 'binary_sensor.test_name_mute',
    'last_changed': <ANY>,
    'last_reported': <ANY>,
    'last_updated': <ANY>,
    'state': 'off',
  })
# ---
@ -496,3 +496,22 @@ async def test_blu_trv_binary_sensor_entity(

        entry = entity_registry.async_get(entity_id)
        assert entry == snapshot(name=f"{entity_id}-entry")


async def test_rpc_flood_entities(
    hass: HomeAssistant,
    mock_rpc_device: Mock,
    entity_registry: EntityRegistry,
    snapshot: SnapshotAssertion,
) -> None:
    """Test RPC flood sensor entities."""
    await init_integration(hass, 4)

    for entity in ("flood", "mute"):
        entity_id = f"{BINARY_SENSOR_DOMAIN}.test_name_{entity}"

        state = hass.states.get(entity_id)
        assert state == snapshot(name=f"{entity_id}-state")

        entry = entity_registry.async_get(entity_id)
        assert entry == snapshot(name=f"{entity_id}-entry")
@ -18,7 +18,8 @@ from tests.common import (
    load_json_object_fixture,
)

MOCK_HOST = "slzb-06.local"
MOCK_DEVICE_NAME = "slzb-06"
MOCK_HOST = "192.168.1.161"
MOCK_USERNAME = "test-user"
MOCK_PASSWORD = "test-pass"
@ -3,7 +3,7 @@
  DeviceRegistryEntrySnapshot({
    'area_id': None,
    'config_entries': <ANY>,
    'configuration_url': 'http://slzb-06.local',
    'configuration_url': 'http://192.168.1.161',
    'connections': set({
      tuple(
        'mac',
@ -8,19 +8,20 @@ from pysmlight.exceptions import SmlightAuthError, SmlightConnectionError
import pytest

from homeassistant.components.smlight.const import DOMAIN
from homeassistant.config_entries import SOURCE_USER, SOURCE_ZEROCONF
from homeassistant.config_entries import SOURCE_DHCP, SOURCE_USER, SOURCE_ZEROCONF
from homeassistant.const import CONF_HOST, CONF_PASSWORD, CONF_USERNAME
from homeassistant.core import HomeAssistant
from homeassistant.data_entry_flow import FlowResultType
from homeassistant.helpers.service_info.dhcp import DhcpServiceInfo
from homeassistant.helpers.service_info.zeroconf import ZeroconfServiceInfo

from .conftest import MOCK_HOST, MOCK_PASSWORD, MOCK_USERNAME
from .conftest import MOCK_DEVICE_NAME, MOCK_HOST, MOCK_PASSWORD, MOCK_USERNAME

from tests.common import MockConfigEntry

DISCOVERY_INFO = ZeroconfServiceInfo(
    ip_address=ip_address("127.0.0.1"),
    ip_addresses=[ip_address("127.0.0.1")],
    ip_address=ip_address("192.168.1.161"),
    ip_addresses=[ip_address("192.168.1.161")],
    hostname="slzb-06.local.",
    name="mock_name",
    port=6638,
@ -29,8 +30,8 @@ DISCOVERY_INFO = ZeroconfServiceInfo(
)

DISCOVERY_INFO_LEGACY = ZeroconfServiceInfo(
    ip_address=ip_address("127.0.0.1"),
    ip_addresses=[ip_address("127.0.0.1")],
    ip_address=ip_address("192.168.1.161"),
    ip_addresses=[ip_address("192.168.1.161")],
    hostname="slzb-06.local.",
    name="mock_name",
    port=6638,
@ -52,7 +53,7 @@ async def test_user_flow(hass: HomeAssistant, mock_setup_entry: AsyncMock) -> No
    result2 = await hass.config_entries.flow.async_configure(
        result["flow_id"],
        {
            CONF_HOST: MOCK_HOST,
            CONF_HOST: "slzb-06p7.local",
        },
    )
@ -76,7 +77,7 @@ async def test_zeroconf_flow(
        DOMAIN, context={"source": SOURCE_ZEROCONF}, data=DISCOVERY_INFO
    )

    assert result["description_placeholders"] == {"host": MOCK_HOST}
    assert result["description_placeholders"] == {"host": MOCK_DEVICE_NAME}
    assert result["type"] is FlowResultType.FORM
    assert result["step_id"] == "confirm_discovery"
@ -113,7 +114,7 @@ async def test_zeroconf_flow_auth(
        DOMAIN, context={"source": SOURCE_ZEROCONF}, data=DISCOVERY_INFO
    )

    assert result["description_placeholders"] == {"host": MOCK_HOST}
    assert result["description_placeholders"] == {"host": MOCK_DEVICE_NAME}
    assert result["type"] is FlowResultType.FORM
    assert result["step_id"] == "confirm_discovery"
@ -167,7 +168,7 @@ async def test_zeroconf_unsupported_abort(
        DOMAIN, context={"source": SOURCE_ZEROCONF}, data=DISCOVERY_INFO
    )

    assert result["description_placeholders"] == {"host": MOCK_HOST}
    assert result["description_placeholders"] == {"host": MOCK_DEVICE_NAME}
    assert result["type"] is FlowResultType.FORM
    assert result["step_id"] == "confirm_discovery"
@ -489,7 +490,7 @@ async def test_zeroconf_legacy_mac(
        data=DISCOVERY_INFO_LEGACY,
    )

    assert result["description_placeholders"] == {"host": MOCK_HOST}
    assert result["description_placeholders"] == {"host": MOCK_DEVICE_NAME}

    result2 = await hass.config_entries.flow.async_configure(
        result["flow_id"], user_input={}
@ -507,6 +508,76 @@ async def test_zeroconf_legacy_mac(
    assert len(mock_smlight_client.get_info.mock_calls) == 3


@pytest.mark.usefixtures("mock_smlight_client")
async def test_zeroconf_updates_host(
    hass: HomeAssistant,
    mock_setup_entry: AsyncMock,
    mock_config_entry: MockConfigEntry,
) -> None:
    """Test zeroconf discovery updates host ip."""
    mock_config_entry.add_to_hass(hass)

    service_info = DISCOVERY_INFO
    service_info.ip_address = ip_address("192.168.1.164")

    result = await hass.config_entries.flow.async_init(
        DOMAIN, context={"source": SOURCE_ZEROCONF}, data=service_info
    )

    assert result["type"] is FlowResultType.ABORT
    assert result["reason"] == "already_configured"

    assert mock_config_entry.data[CONF_HOST] == "192.168.1.164"


@pytest.mark.usefixtures("mock_smlight_client")
async def test_dhcp_discovery_updates_host(
    hass: HomeAssistant,
    mock_setup_entry: AsyncMock,
    mock_config_entry: MockConfigEntry,
) -> None:
    """Test dhcp discovery updates host ip."""
    mock_config_entry.add_to_hass(hass)

    service_info = DhcpServiceInfo(
        ip="192.168.1.164",
        hostname="slzb-06",
        macaddress="aabbccddeeff",
    )
    result = await hass.config_entries.flow.async_init(
        DOMAIN, context={"source": SOURCE_DHCP}, data=service_info
    )

    assert result["type"] is FlowResultType.ABORT
    assert result["reason"] == "already_configured"

    assert mock_config_entry.data[CONF_HOST] == "192.168.1.164"


@pytest.mark.usefixtures("mock_smlight_client")
async def test_dhcp_discovery_aborts(
    hass: HomeAssistant,
    mock_setup_entry: AsyncMock,
    mock_config_entry: MockConfigEntry,
) -> None:
    """Test dhcp discovery aborts when the host is already configured."""
    mock_config_entry.add_to_hass(hass)

    service_info = DhcpServiceInfo(
        ip="192.168.1.161",
        hostname="slzb-06",
        macaddress="000000000000",
    )
    result = await hass.config_entries.flow.async_init(
        DOMAIN, context={"source": SOURCE_DHCP}, data=service_info
    )

    assert result["type"] is FlowResultType.ABORT
    assert result["reason"] == "already_configured"

    assert mock_config_entry.data[CONF_HOST] == "192.168.1.161"


async def test_reauth_flow(
    hass: HomeAssistant,
    mock_smlight_client: MagicMock,
@ -153,3 +153,25 @@ async def humidifier_config_entry(
    await hass.async_block_till_done()

    return entry


@pytest.fixture(name="switch_old_id_config_entry")
async def switch_old_id_config_entry(
    hass: HomeAssistant, requests_mock: requests_mock.Mocker, config
) -> MockConfigEntry:
    """Create a mock VeSync config entry for `switch` with the old unique ID approach."""
    entry = MockConfigEntry(
        title="VeSync",
        domain=DOMAIN,
        data=config[DOMAIN],
        version=1,
        minor_version=1,
    )
    entry.add_to_hass(hass)

    wall_switch = "Wall Switch"
    humidifer = "Humidifier 200s"

    mock_multiple_device_responses(requests_mock, [wall_switch, humidifer])

    return entry
@ -128,7 +128,7 @@
      'sleep',
      'manual',
    ]),
    'mode': None,
    'mode': 'humidity',
    'night_light': True,
    'pid': None,
    'speed': None,
@ -160,6 +160,30 @@
# ---
# name: test_async_get_device_diagnostics__single_fan
  dict({
    '_config_dict': dict({
      'features': list([
        'air_quality',
      ]),
      'levels': list([
        1,
        2,
      ]),
      'models': list([
        'LV-PUR131S',
        'LV-RH131S',
      ]),
      'modes': list([
        'manual',
        'auto',
        'sleep',
        'off',
      ]),
      'module': 'VeSyncAir131',
    }),
    '_features': list([
      'air_quality',
    ]),
    'air_quality_feature': True,
    'cid': 'abcdefghabcdefghabcdefghabcdefgh',
    'config': dict({
    }),
@ -180,6 +204,7 @@
    'device_region': 'US',
    'device_status': 'unknown',
    'device_type': 'LV-PUR131S',
    'enabled': True,
    'extension': None,
    'home_assistant': dict({
      'disabled': False,
@ -271,6 +296,12 @@
    'mac_id': '**REDACTED**',
    'manager': '**REDACTED**',
    'mode': None,
    'modes': list([
      'manual',
      'auto',
      'sleep',
      'off',
    ]),
    'pid': None,
    'speed': None,
    'sub_device_no': None,
@ -367,7 +367,7 @@
    'previous_unique_id': None,
    'supported_features': 0,
    'translation_key': None,
    'unique_id': 'outlet',
    'unique_id': 'outlet-device_status',
    'unit_of_measurement': None,
  }),
])
@ -525,7 +525,7 @@
    'previous_unique_id': None,
    'supported_features': 0,
    'translation_key': None,
    'unique_id': 'switch',
    'unique_id': 'switch-device_status',
    'unit_of_measurement': None,
  }),
])
@ -10,6 +10,9 @@ from homeassistant.components.vesync.const import DOMAIN, VS_DEVICES, VS_MANAGER
from homeassistant.config_entries import ConfigEntry, ConfigEntryState
from homeassistant.const import Platform
from homeassistant.core import HomeAssistant
from homeassistant.helpers import entity_registry as er

from tests.common import MockConfigEntry


async def test_async_setup_entry__not_login(
@ -125,3 +128,55 @@ async def test_async_new_device_discovery(
    assert manager.login.call_count == 1
    assert hass.data[DOMAIN][VS_MANAGER] == manager
    assert hass.data[DOMAIN][VS_DEVICES] == [fan, humidifier]


async def test_migrate_config_entry(
    hass: HomeAssistant,
    switch_old_id_config_entry: MockConfigEntry,
    entity_registry: er.EntityRegistry,
) -> None:
    """Test migration of config entry. Only migrates switches to a new unique_id."""
    switch: er.RegistryEntry = entity_registry.async_get_or_create(
        domain="switch",
        platform="vesync",
        unique_id="switch",
        config_entry=switch_old_id_config_entry,
        suggested_object_id="switch",
    )

    humidifer: er.RegistryEntry = entity_registry.async_get_or_create(
        domain="humidifer",
        platform="vesync",
        unique_id="humidifer",
        config_entry=switch_old_id_config_entry,
        suggested_object_id="humidifer",
    )

    assert switch.unique_id == "switch"
    assert switch_old_id_config_entry.minor_version == 1
    assert humidifer.unique_id == "humidifer"

    await hass.config_entries.async_setup(switch_old_id_config_entry.entry_id)
    await hass.async_block_till_done()

    assert switch_old_id_config_entry.minor_version == 2

    migrated_switch = entity_registry.async_get(switch.entity_id)
    assert migrated_switch is not None
    assert migrated_switch.entity_id.startswith("switch")
    assert migrated_switch.unique_id == "switch-device_status"
    # Confirm humidifer was not impacted
    migrated_humidifer = entity_registry.async_get(humidifer.entity_id)
    assert migrated_humidifer is not None
    assert migrated_humidifer.unique_id == "humidifer"

    # Assert that only one entity exists in the switch domain
    switch_entities = [
        e for e in entity_registry.entities.values() if e.domain == "switch"
    ]
    assert len(switch_entities) == 1

    humidifer_entities = [
        e for e in entity_registry.entities.values() if e.domain == "humidifer"
    ]
    assert len(humidifer_entities) == 1