mirror of
https://github.com/home-assistant/core.git
synced 2025-11-06 01:19:29 +00:00
238 lines
7.4 KiB
Python
238 lines
7.4 KiB
Python
"""Code to generate common control usage patterns."""
|
|
|
|
from __future__ import annotations
|
|
|
|
from collections import Counter
|
|
from collections.abc import Callable, Sequence
|
|
from datetime import datetime, timedelta
|
|
from functools import cache
|
|
import logging
|
|
from typing import Any, Literal, cast
|
|
|
|
from sqlalchemy import select
|
|
from sqlalchemy.engine.row import Row
|
|
from sqlalchemy.orm import Session
|
|
|
|
from homeassistant.components.recorder import get_instance
|
|
from homeassistant.components.recorder.db_schema import EventData, Events, EventTypes
|
|
from homeassistant.components.recorder.models import uuid_hex_to_bytes_or_none
|
|
from homeassistant.components.recorder.util import session_scope
|
|
from homeassistant.const import Platform
|
|
from homeassistant.core import HomeAssistant
|
|
from homeassistant.helpers import entity_registry as er
|
|
from homeassistant.util import dt as dt_util
|
|
from homeassistant.util.json import json_loads_object
|
|
|
|
from .models import EntityUsagePredictions
|
|
|
|
_LOGGER = logging.getLogger(__name__)

# Names of the time-of-day buckets; these match the labels that
# time_category() returns and are used as keys of the per-bucket counters.
TIME_CATEGORIES = ["morning", "afternoon", "evening", "night"]

# Upper bound on how many entity IDs are reported for each time category.
RESULTS_TO_INCLUDE = 8

# Domains whose entities are counted; service calls that target entities
# of any other domain are ignored when building the predictions.
ALLOWED_DOMAINS = {
    # Entity platforms
    Platform.AIR_QUALITY,
    Platform.ALARM_CONTROL_PANEL,
    Platform.BINARY_SENSOR,
    Platform.BUTTON,
    Platform.CAMERA,
    Platform.CLIMATE,
    Platform.COVER,
    Platform.FAN,
    Platform.HUMIDIFIER,
    Platform.LAWN_MOWER,
    Platform.LIGHT,
    Platform.LOCK,
    Platform.MEDIA_PLAYER,
    Platform.NUMBER,
    Platform.SCENE,
    Platform.SELECT,
    Platform.SENSOR,
    Platform.SIREN,
    Platform.SWITCH,
    Platform.VACUUM,
    Platform.VALVE,
    Platform.WATER_HEATER,
    # Helpers with own domain
    "counter",
    "group",
    "input_boolean",
    "input_button",
    "input_datetime",
    "input_number",
    "input_select",
    "input_text",
    "schedule",
    "timer",
}
|
|
|
|
|
|
@cache
def time_category(hour: int) -> Literal["morning", "afternoon", "evening", "night"]:
    """Return the time-of-day bucket that the given hour falls into.

    Hours 6-11 are "morning", 12-17 are "afternoon", 18-21 are "evening";
    every other value is "night". Results are memoized per hour.
    """
    # Anything outside the 06:00-21:59 daytime window counts as night.
    if 6 <= hour < 22:
        if hour < 12:
            return "morning"
        if hour < 18:
            return "afternoon"
        return "evening"
    return "night"
|
|
|
|
|
|
async def async_predict_common_control(
    hass: HomeAssistant, user_id: str
) -> EntityUsagePredictions:
    """Generate a list of commonly used entities for a user.

    The user's recorded service call events are fetched in the recorder's
    executor, bucketed by the local hour at which they fired, and the
    entities they targeted are counted per bucket. Only the first event
    seen for each context is considered.

    Args:
        hass: Home Assistant instance
        user_id: User ID to filter events by.

    Returns:
        EntityUsagePredictions whose morning/afternoon/evening/night fields
        each hold up to RESULTS_TO_INCLUDE entity IDs, most used first.
    """
    # Get the recorder instance to ensure it's ready
    recorder = get_instance(hass)
    ent_reg = er.async_get(hass)

    # Execute the database operation in the recorder's executor
    data = await recorder.async_add_executor_job(
        _fetch_with_session, hass, _fetch_and_process_data, ent_reg, user_id
    )

    # One usage counter per time category
    results: dict[str, Counter[str]] = {
        time_cat: Counter() for time_cat in TIME_CATEGORIES
    }

    # Keep track of contexts that we processed so that we will only process
    # the first service call in a context, and not subsequent calls.
    context_processed: set[bytes] = set()
    context_id: bytes
    time_fired_ts: float
    shared_data: str | None
    local_time_zone = dt_util.get_default_time_zone()

    for context_id, time_fired_ts, shared_data in data:
        # Skip if we have already processed an event that was part of this context
        if context_id in context_processed:
            continue

        # Mark this context as processed
        context_processed.add(context_id)

        # Without a payload there is nothing to parse
        if not shared_data:
            continue

        try:
            event_data = json_loads_object(shared_data)
        except (ValueError, TypeError) as err:
            _LOGGER.debug("Failed to parse event data: %s", err)
            continue

        # Empty event data, skipping
        if not event_data:
            continue

        # Missing, empty or non-mapping service data cannot name any entity
        service_data = event_data.get("service_data")
        if not service_data or not isinstance(service_data, dict):
            continue

        # Prefer the entity IDs nested under "target", falling back to a
        # top-level "entity_id". Some services use "target" for something
        # other than a mapping (e.g. a list of recipients), so only call
        # .get() on it when it really is a dict.
        target = service_data.get("target")
        raw_entity_ids: Any = (
            target.get("entity_id") if isinstance(target, dict) else None
        ) or service_data.get("entity_id")

        # No entity IDs found, skip this event
        if raw_entity_ids is None:
            continue

        if not isinstance(raw_entity_ids, list):
            raw_entity_ids = [raw_entity_ids]

        # Keep only string entity IDs from allowed domains that are either
        # unknown to the entity registry or not hidden there.
        entity_ids = [
            entity_id
            for entity_id in raw_entity_ids
            if isinstance(entity_id, str)
            and entity_id.split(".")[0] in ALLOWED_DOMAINS
            and ((entry := ent_reg.async_get(entity_id)) is None or not entry.hidden)
        ]

        if not entity_ids:
            continue

        # An event without a usable timestamp cannot be assigned to a time
        # category, so it must not be counted at all.
        if not time_fired_ts:
            continue

        # Convert to local time for time category determination
        period = time_category(
            datetime.fromtimestamp(time_fired_ts, local_time_zone).hour
        )

        # Count entity usage (each listed entity gets +1 for this event)
        results[period].update(entity_ids)

    def _most_used(period: str) -> list[str]:
        """Return the top entity IDs of one time category, most used first."""
        return [
            ent_id for (ent_id, _) in results[period].most_common(RESULTS_TO_INCLUDE)
        ]

    return EntityUsagePredictions(
        morning=_most_used("morning"),
        afternoon=_most_used("afternoon"),
        evening=_most_used("evening"),
        night=_most_used("night"),
    )
|
|
|
|
|
|
def _fetch_and_process_data(
    session: Session, ent_reg: er.EntityRegistry, user_id: str
) -> Sequence[Row[tuple[bytes | None, float | None, str | None]]]:
    """Load a user's recent ``call_service`` events from the database.

    Each returned row holds the event's binary context ID, its fired
    timestamp and the shared event data. Only events fired within the last
    30 days whose context user matches ``user_id`` are selected, ordered by
    fired timestamp. ``ent_reg`` is accepted but not used by this function.

    Raises:
        ValueError: If ``uuid_hex_to_bytes_or_none`` yields no value for
            ``user_id``.
    """
    user_id_bytes = uuid_hex_to_bytes_or_none(user_id)
    if not user_id_bytes:
        raise ValueError("Invalid user_id format")

    # Lower bound of the look-back window, as a UNIX timestamp
    cutoff_ts = (dt_util.utcnow() - timedelta(days=30)).timestamp()

    # Start from the events table and attach the payload and type tables
    statement = select(
        Events.context_id_bin,
        Events.time_fired_ts,
        EventData.shared_data,
    ).select_from(Events)
    statement = statement.outerjoin(EventData, Events.data_id == EventData.data_id)
    statement = statement.outerjoin(
        EventTypes, Events.event_type_id == EventTypes.event_type_id
    )

    # Restrict to recent service calls made by this user
    statement = (
        statement.where(Events.time_fired_ts >= cutoff_ts)
        .where(Events.context_user_id_bin == user_id_bytes)
        .where(EventTypes.event_type == "call_service")
    )
    statement = statement.order_by(Events.time_fired_ts)

    rows = session.connection().execute(statement).all()
    return rows
|
|
|
|
|
|
def _fetch_with_session(
    hass: HomeAssistant,
    fetch_func: Callable[
        ..., Sequence[Row[tuple[bytes | None, float | None, str | None]]]
    ],
    *args: object,
) -> Sequence[Row[tuple[bytes | None, float | None, str | None]]]:
    """Execute a fetch function with a database session.

    Args:
        hass: Home Assistant instance used to open the session scope.
        fetch_func: Callable invoked as ``fetch_func(session, *args)``; it
            receives the session first, followed by the extra arguments.
        *args: Additional positional arguments forwarded to ``fetch_func``.

    Returns:
        Whatever ``fetch_func`` returns.
    """
    # The session is opened read-only and only lives for this call
    with session_scope(hass=hass, read_only=True) as session:
        return fetch_func(session, *args)
|