Files
core/homeassistant/components/usage_prediction/common_control.py
J. Nick Koston eebaa42496 split out of exec
2025-09-24 21:45:59 -05:00

238 lines
7.4 KiB
Python

"""Code to generate common control usage patterns."""
from __future__ import annotations
from collections import Counter
from collections.abc import Callable, Sequence
from datetime import datetime, timedelta
from functools import cache
import logging
from typing import Any, Literal, cast
from sqlalchemy import select
from sqlalchemy.engine.row import Row
from sqlalchemy.orm import Session
from homeassistant.components.recorder import get_instance
from homeassistant.components.recorder.db_schema import EventData, Events, EventTypes
from homeassistant.components.recorder.models import uuid_hex_to_bytes_or_none
from homeassistant.components.recorder.util import session_scope
from homeassistant.const import Platform
from homeassistant.core import HomeAssistant
from homeassistant.helpers import entity_registry as er
from homeassistant.util import dt as dt_util
from homeassistant.util.json import json_loads_object
from .models import EntityUsagePredictions
_LOGGER = logging.getLogger(__name__)
# Time categories for usage patterns.
# These are the keys of the per-period counters and the values returned by
# time_category().
TIME_CATEGORIES = ["morning", "afternoon", "evening", "night"]
# Maximum number of entity IDs reported for each time category.
RESULTS_TO_INCLUDE = 8
# List of domains for which we want to track usage.
# Entity IDs whose domain (the part before the first ".") is not in this set
# are ignored when counting usage.
ALLOWED_DOMAINS = {
    # Entity platforms
    Platform.AIR_QUALITY,
    Platform.ALARM_CONTROL_PANEL,
    Platform.BINARY_SENSOR,
    Platform.BUTTON,
    Platform.CAMERA,
    Platform.CLIMATE,
    Platform.COVER,
    Platform.FAN,
    Platform.HUMIDIFIER,
    Platform.LAWN_MOWER,
    Platform.LIGHT,
    Platform.LOCK,
    Platform.MEDIA_PLAYER,
    Platform.NUMBER,
    Platform.SCENE,
    Platform.SELECT,
    Platform.SENSOR,
    Platform.SIREN,
    Platform.SWITCH,
    Platform.VACUUM,
    Platform.VALVE,
    Platform.WATER_HEATER,
    # Helpers with own domain
    "counter",
    "group",
    "input_boolean",
    "input_button",
    "input_datetime",
    "input_number",
    "input_select",
    "input_text",
    "schedule",
    "timer",
}
@cache
def time_category(hour: int) -> Literal["morning", "afternoon", "evening", "night"]:
    """Map an hour of the day to its usage time category.

    Hours 6-11 are "morning", 12-17 "afternoon", 18-21 "evening"; every
    other value is "night".
    """
    # Daytime window first; anything outside of it falls through to "night".
    if 6 <= hour < 22:
        if hour < 12:
            return "morning"
        if hour < 18:
            return "afternoon"
        return "evening"
    return "night"
def _extract_entity_ids(service_data: object) -> list[str]:
    """Return the entity IDs referenced by a service call's data.

    Entity IDs under ``target.entity_id`` take precedence; when that is
    missing or empty the top-level ``entity_id`` is used. A single ID is
    wrapped in a list. Anything that does not have the expected shape
    (non-dict data or target, non-string IDs) is ignored instead of raising,
    so one malformed recorded event cannot abort the whole prediction.
    """
    if not isinstance(service_data, dict):
        return []

    raw_ids: object = None
    target = service_data.get("target")
    if isinstance(target, dict):
        raw_ids = target.get("entity_id")
    if not raw_ids:
        raw_ids = service_data.get("entity_id")

    if raw_ids is None:
        return []
    if not isinstance(raw_ids, list):
        raw_ids = [raw_ids]
    return [entity_id for entity_id in raw_ids if isinstance(entity_id, str)]


async def async_predict_common_control(
    hass: HomeAssistant, user_id: str
) -> EntityUsagePredictions:
    """Generate a list of commonly used entities for a user.

    Args:
        hass: Home Assistant instance
        user_id: User ID to filter events by.

    Returns:
        EntityUsagePredictions holding, for each time category, up to
        RESULTS_TO_INCLUDE entity IDs ordered from most to least used.
    """
    # Get the recorder instance to ensure it's ready
    recorder = get_instance(hass)
    ent_reg = er.async_get(hass)

    # Execute the database operation in the recorder's executor
    data = await recorder.async_add_executor_job(
        _fetch_with_session, hass, _fetch_and_process_data, ent_reg, user_id
    )

    # Prepare a dictionary to track results
    results: dict[str, Counter[str]] = {
        time_cat: Counter() for time_cat in TIME_CATEGORIES
    }

    # Keep track of contexts that we processed so that we will only process
    # the first service call in a context, and not subsequent calls.
    context_processed: set[bytes | None] = set()

    context_id: bytes | None
    time_fired_ts: float | None
    shared_data: str | None
    local_time_zone = dt_util.get_default_time_zone()

    for context_id, time_fired_ts, shared_data in data:
        # Skip if we have already processed an event that was part of this context
        if context_id in context_processed:
            continue
        # Mark this context as processed, even if the event ends up being
        # skipped below, so later events of the same context are ignored too.
        context_processed.add(context_id)

        # Parse the event data
        if not shared_data:
            continue
        try:
            event_data = json_loads_object(shared_data)
        except (ValueError, TypeError) as err:
            _LOGGER.debug("Failed to parse event data: %s", err)
            continue

        # Keep only entity IDs in allowed domains that are not hidden in the
        # entity registry; IDs unknown to the registry are kept.
        entity_ids = [
            entity_id
            for entity_id in _extract_entity_ids(event_data.get("service_data"))
            if entity_id.split(".")[0] in ALLOWED_DOMAINS
            and ((entry := ent_reg.async_get(entity_id)) is None or not entry.hidden)
        ]
        if not entity_ids:
            continue

        # Without a timestamp there is no time category to count under.
        if not time_fired_ts:
            continue

        # Convert to local time for time category determination
        period = time_category(
            datetime.fromtimestamp(time_fired_ts, local_time_zone).hour
        )

        # Count entity usage
        results[period].update(entity_ids)

    def _most_used(period: str) -> list[str]:
        """Return the most frequently used entity IDs for a time category."""
        return [
            ent_id for ent_id, _ in results[period].most_common(RESULTS_TO_INCLUDE)
        ]

    return EntityUsagePredictions(
        morning=_most_used("morning"),
        afternoon=_most_used("afternoon"),
        evening=_most_used("evening"),
        night=_most_used("night"),
    )
def _fetch_and_process_data(
    session: Session, ent_reg: er.EntityRegistry, user_id: str
) -> Sequence[Row[tuple[bytes | None, float | None, str | None]]]:
    """Load the user's service call events of the last 30 days.

    Args:
        session: Database session used to run the query.
        ent_reg: Entity registry; accepted but not used by the query.
        user_id: Hex user ID whose events are selected.

    Returns:
        Rows of (context ID, fired timestamp, shared event data) ordered by
        the time the event was fired.

    Raises:
        ValueError: If user_id cannot be converted to bytes.
    """
    cutoff_ts = (dt_util.utcnow() - timedelta(days=30)).timestamp()
    user_id_bin = uuid_hex_to_bytes_or_none(user_id)
    if not user_id_bin:
        raise ValueError("Invalid user_id format")

    # Events joined with their data payload and their event type name
    stmt = select(
        Events.context_id_bin,
        Events.time_fired_ts,
        EventData.shared_data,
    ).select_from(Events)
    stmt = stmt.outerjoin(EventData, Events.data_id == EventData.data_id)
    stmt = stmt.outerjoin(EventTypes, Events.event_type_id == EventTypes.event_type_id)

    # Restrict to recent service calls made by this user
    stmt = stmt.where(Events.time_fired_ts >= cutoff_ts)
    stmt = stmt.where(Events.context_user_id_bin == user_id_bin)
    stmt = stmt.where(EventTypes.event_type == "call_service")
    stmt = stmt.order_by(Events.time_fired_ts)

    connection = session.connection()
    return connection.execute(stmt).all()
def _fetch_with_session(
    hass: HomeAssistant,
    fetch_func: Callable[
        ..., Sequence[Row[tuple[bytes | None, float | None, str | None]]]
    ],
    *args: object,
) -> Sequence[Row[tuple[bytes | None, float | None, str | None]]]:
    """Execute a fetch function with a database session.

    Args:
        hass: Home Assistant instance the session is opened for.
        fetch_func: Callable invoked with the session followed by ``args``.
        *args: Extra positional arguments forwarded to ``fetch_func``.

    Returns:
        The rows returned by ``fetch_func``.
    """
    # The session is opened read-only and closed when the scope exits.
    with session_scope(hass=hass, read_only=True) as session:
        return fetch_func(session, *args)