From b9aba30a6e8212c28261ec24489001bd14c345de Mon Sep 17 00:00:00 2001 From: Paulus Schoutsen Date: Fri, 3 Jan 2020 21:37:11 +0100 Subject: [PATCH] Extract Collection helper from Person integration (#30313) * Add CRUD foundation * Use collection helper in person integration * Lint/pytest * Add tests * Lint * Create notification --- .../components/device_automation/__init__.py | 12 +- homeassistant/components/person/__init__.py | 429 +++++++----------- .../components/websocket_api/__init__.py | 19 +- .../components/websocket_api/commands.py | 4 +- .../components/websocket_api/connection.py | 6 +- .../components/websocket_api/const.py | 11 + .../components/websocket_api/decorators.py | 16 +- homeassistant/helpers/collection.py | 401 ++++++++++++++++ homeassistant/helpers/entity.py | 3 +- tests/components/conftest.py | 42 -- tests/components/onboarding/test_views.py | 2 +- tests/components/person/test_init.py | 128 +++--- tests/conftest.py | 41 ++ tests/helpers/test_collection.py | 356 +++++++++++++++ 14 files changed, 1074 insertions(+), 396 deletions(-) create mode 100644 homeassistant/helpers/collection.py create mode 100644 tests/helpers/test_collection.py diff --git a/homeassistant/components/device_automation/__init__.py b/homeassistant/components/device_automation/__init__.py index 872a4af6cd6..56e087f0e5f 100644 --- a/homeassistant/components/device_automation/__init__.py +++ b/homeassistant/components/device_automation/__init__.py @@ -173,13 +173,13 @@ async def _async_get_device_automation_capabilities(hass, automation_type, autom return capabilities -@websocket_api.async_response @websocket_api.websocket_command( { vol.Required("type"): "device_automation/action/list", vol.Required("device_id"): str, } ) +@websocket_api.async_response async def websocket_device_automation_list_actions(hass, connection, msg): """Handle request for device actions.""" device_id = msg["device_id"] @@ -187,13 +187,13 @@ async def websocket_device_automation_list_actions(hass, connection, msg): connection.send_result(msg["id"], actions) -@websocket_api.async_response @websocket_api.websocket_command( { vol.Required("type"): "device_automation/condition/list", vol.Required("device_id"): str, } ) +@websocket_api.async_response async def websocket_device_automation_list_conditions(hass, connection, msg): """Handle request for device conditions.""" device_id = msg["device_id"] @@ -201,13 +201,13 @@ async def websocket_device_automation_list_conditions(hass, connection, msg): connection.send_result(msg["id"], conditions) -@websocket_api.async_response @websocket_api.websocket_command( { vol.Required("type"): "device_automation/trigger/list", vol.Required("device_id"): str, } ) +@websocket_api.async_response async def websocket_device_automation_list_triggers(hass, connection, msg): """Handle request for device triggers.""" device_id = msg["device_id"] @@ -215,13 +215,13 @@ async def websocket_device_automation_list_triggers(hass, connection, msg): connection.send_result(msg["id"], triggers) -@websocket_api.async_response @websocket_api.websocket_command( { vol.Required("type"): "device_automation/action/capabilities", vol.Required("action"): dict, } ) +@websocket_api.async_response async def websocket_device_automation_get_action_capabilities(hass, connection, msg): """Handle request for device action capabilities.""" action = msg["action"] @@ -231,13 +231,13 @@ async def websocket_device_automation_get_action_capabilities(hass, connection, connection.send_result(msg["id"], capabilities) -@websocket_api.async_response @websocket_api.websocket_command( { vol.Required("type"): "device_automation/condition/capabilities", vol.Required("condition"): dict, } ) +@websocket_api.async_response async def websocket_device_automation_get_condition_capabilities(hass, connection, msg): """Handle request for device condition capabilities.""" condition = msg["condition"] @@ -247,13 +247,13 @@ async def websocket_device_automation_get_condition_capabilities(hass, connectio connection.send_result(msg["id"], capabilities) -@websocket_api.async_response @websocket_api.websocket_command( { vol.Required("type"): "device_automation/trigger/capabilities", vol.Required("trigger"): dict, } ) +@websocket_api.async_response async def websocket_device_automation_get_trigger_capabilities(hass, connection, msg): """Handle request for device trigger capabilities.""" trigger = msg["trigger"] diff --git a/homeassistant/components/person/__init__.py b/homeassistant/components/person/__init__.py index 2e347cf4d49..4d211ed39de 100644 --- a/homeassistant/components/person/__init__.py +++ b/homeassistant/components/person/__init__.py @@ -1,9 +1,6 @@ """Support for tracking people.""" -from collections import OrderedDict -from itertools import chain import logging -from typing import Optional -import uuid +from typing import List, Optional import voluptuous as vol @@ -28,6 +25,7 @@ from homeassistant.const import ( STATE_UNKNOWN, ) from homeassistant.core import Event, State, callback +from homeassistant.helpers import collection, entity_registry import homeassistant.helpers.config_validation as cv from homeassistant.helpers.entity_component import EntityComponent from homeassistant.helpers.event import async_track_state_change @@ -48,8 +46,7 @@ CONF_USER_ID = "user_id" DOMAIN = "person" STORAGE_KEY = DOMAIN -STORAGE_VERSION = 1 -SAVE_DELAY = 10 +STORAGE_VERSION = 2 # Device tracker states to ignore IGNORE_STATES = (STATE_UNKNOWN, STATE_UNAVAILABLE) @@ -75,217 +72,184 @@ _UNDEF = object() @bind_hass async def async_create_person(hass, name, *, user_id=None, device_trackers=None): """Create a new person.""" - await hass.data[DOMAIN].async_create_person( - name=name, user_id=user_id, device_trackers=device_trackers + await hass.data[DOMAIN][1].async_create_item( + {"name": name, "user_id": user_id, "device_trackers": device_trackers} ) -class PersonManager: - """Manage person data.""" +CREATE_FIELDS = { + vol.Required("name"): vol.All(str, vol.Length(min=1)), + vol.Optional("user_id"): vol.Any(str, None), + vol.Optional("device_trackers", default=list): vol.All( + cv.ensure_list, cv.entities_domain(DEVICE_TRACKER_DOMAIN) + ), +} + + +UPDATE_FIELDS = { + vol.Optional("name"): vol.All(str, vol.Length(min=1)), + vol.Optional("user_id"): vol.Any(str, None), + vol.Optional("device_trackers", default=list): vol.All( + cv.ensure_list, cv.entities_domain(DEVICE_TRACKER_DOMAIN) + ), +} + + +class PersonStore(Store): + """Person storage.""" + + async def _async_migrate_func(self, old_version, old_data): + """Migrate to the new version. + + Migrate storage to use format of collection helper. + """ + return {"items": old_data["persons"]} + + +class PersonStorageCollection(collection.StorageCollection): + """Person collection stored in storage.""" + + CREATE_SCHEMA = vol.Schema(CREATE_FIELDS) + UPDATE_SCHEMA = vol.Schema(UPDATE_FIELDS) def __init__( - self, hass: HomeAssistantType, component: EntityComponent, config_persons + self, + store: Store, + logger: logging.Logger, + id_manager: collection.IDManager, + yaml_collection: collection.YamlCollection, ): - """Initialize person storage.""" - self.hass = hass - self.component = component - self.store = Store(hass, STORAGE_VERSION, STORAGE_KEY) - self.storage_data = None + """Initialize a person storage collection.""" + super().__init__(store, logger, id_manager) + self.async_add_listener(self._collection_changed) + self.yaml_collection = yaml_collection - config_data = self.config_data = OrderedDict() - for conf in config_persons: - person_id = conf[CONF_ID] + async def _process_create_data(self, data: dict) -> dict: + """Validate the config is valid.""" + data = self.CREATE_SCHEMA(data) - if person_id in config_data: - _LOGGER.error("Found config user with duplicate ID: %s", person_id) - continue - - config_data[person_id] = conf - - @property - def storage_persons(self): - """Iterate over persons stored in storage.""" - return list(self.storage_data.values()) - - @property - def config_persons(self): - """Iterate over persons stored in config.""" - return list(self.config_data.values()) - - async def async_initialize(self): - """Get the person data.""" - raw_storage = await self.store.async_load() - - if raw_storage is None: - raw_storage = {"persons": []} - - storage_data = self.storage_data = OrderedDict() - - for person in raw_storage["persons"]: - storage_data[person[CONF_ID]] = person - - entities = [] - seen_users = set() - - for person_conf in self.config_data.values(): - person_id = person_conf[CONF_ID] - user_id = person_conf.get(CONF_USER_ID) - - if user_id is not None: - if await self.hass.auth.async_get_user(user_id) is None: - _LOGGER.error("Invalid user_id detected for person %s", person_id) - continue - - if user_id in seen_users: - _LOGGER.error( - "Duplicate user_id %s detected for person %s", - user_id, - person_id, - ) - continue - - seen_users.add(user_id) - - entities.append(Person(person_conf, False)) - - # To make sure IDs don't overlap between config/storage - seen_persons = set(self.config_data) - - for person_conf in storage_data.values(): - person_id = person_conf[CONF_ID] - user_id = person_conf[CONF_USER_ID] - - if person_id in seen_persons: - _LOGGER.error( - "Skipping adding person from storage with same ID as" - " configuration.yaml entry: %s", - person_id, - ) - continue - - if user_id is not None and user_id in seen_users: - _LOGGER.error( - "Duplicate user_id %s detected for person %s", user_id, person_id - ) - continue - - # To make sure all users have just 1 person linked. - seen_users.add(user_id) - - entities.append(Person(person_conf, True)) - - if entities: - await self.component.async_add_entities(entities) - - self.hass.bus.async_listen(EVENT_USER_REMOVED, self._user_removed) - - async def async_create_person(self, *, name, device_trackers=None, user_id=None): - """Create a new person.""" - if not name: - raise ValueError("Name is required") + user_id = data.get("user_id") if user_id is not None: await self._validate_user_id(user_id) - person = { - CONF_ID: uuid.uuid4().hex, - CONF_NAME: name, - CONF_USER_ID: user_id, - CONF_DEVICE_TRACKERS: device_trackers or [], - } - self.storage_data[person[CONF_ID]] = person - self._async_schedule_save() - await self.component.async_add_entities([Person(person, True)]) - return person + return self.CREATE_SCHEMA(data) - async def async_update_person( - self, person_id, *, name=_UNDEF, device_trackers=_UNDEF, user_id=_UNDEF - ): - """Update person.""" - current = self.storage_data.get(person_id) + @callback + def _get_suggested_id(self, info: dict) -> str: + """Suggest an ID based on the config.""" + return info["name"] - if current is None: - raise ValueError("Invalid person specified.") + async def _update_data(self, data: dict, update_data: dict) -> dict: + """Return a new updated data object.""" + update_data = self.UPDATE_SCHEMA(update_data) - changes = { - key: value - for key, value in ( - (CONF_NAME, name), - (CONF_DEVICE_TRACKERS, device_trackers), - (CONF_USER_ID, user_id), - ) - if value is not _UNDEF and current[key] != value - } + user_id = update_data.get("user_id") - if CONF_USER_ID in changes and user_id is not None: + if user_id is not None: await self._validate_user_id(user_id) - self.storage_data[person_id].update(changes) - self._async_schedule_save() - - for entity in self.component.entities: - if entity.unique_id == person_id: - entity.person_updated() - break - - return self.storage_data[person_id] - - async def async_delete_person(self, person_id): - """Delete person.""" - if person_id not in self.storage_data: - raise ValueError("Invalid person specified.") - - self.storage_data.pop(person_id) - self._async_schedule_save() - ent_reg = await self.hass.helpers.entity_registry.async_get_registry() - - for entity in self.component.entities: - if entity.unique_id == person_id: - await entity.async_remove() - ent_reg.async_remove(entity.entity_id) - break - - @callback - def _async_schedule_save(self) -> None: - """Schedule saving the area registry.""" - self.store.async_delay_save(self._data_to_save, SAVE_DELAY) - - @callback - def _data_to_save(self) -> dict: - """Return data of area registry to store in a file.""" - return {"persons": list(self.storage_data.values())} + return {**data, **update_data} async def _validate_user_id(self, user_id): """Validate the used user_id.""" if await self.hass.auth.async_get_user(user_id) is None: raise ValueError("User does not exist") - if any( - person - for person in chain(self.storage_data.values(), self.config_data.values()) - if person.get(CONF_USER_ID) == user_id - ): - raise ValueError("User already taken") + for persons in (self.data.values(), self.yaml_collection.async_items()): + if any(person for person in persons if person.get(CONF_USER_ID) == user_id): + raise ValueError("User already taken") - async def _user_removed(self, event: Event): - """Handle event that a person is removed.""" - user_id = event.data["user_id"] - for person in self.storage_data.values(): - if person[CONF_USER_ID] == user_id: - await self.async_update_person(person_id=person[CONF_ID], user_id=None) + async def _collection_changed( + self, change_type: str, item_id: str, config: Optional[dict] + ) -> None: + """Handle a collection change.""" + if change_type != collection.CHANGE_REMOVED: + return + + ent_reg = await entity_registry.async_get_registry(self.hass) + ent_reg.async_remove(ent_reg.async_get_entity_id(DOMAIN, DOMAIN, item_id)) + + +async def filter_yaml_data(hass: HomeAssistantType, persons: List[dict]) -> List[dict]: + """Validate YAML data that we can't validate via schema.""" + filtered = [] + person_invalid_user = [] + + for person_conf in persons: + user_id = person_conf.get(CONF_USER_ID) + + if user_id is not None: + if await hass.auth.async_get_user(user_id) is None: + _LOGGER.error( + "Invalid user_id detected for person %s", + person_conf[collection.CONF_ID], + ) + person_invalid_user.append( + f"- Person {person_conf[CONF_NAME]} (id: {person_conf[collection.CONF_ID]}) points at invalid user {user_id}" + ) + continue + + filtered.append(person_conf) + + if person_invalid_user: + hass.components.persistent_notification.async_create( + f""" +The following persons point at invalid users: + +{"- ".join(person_invalid_user)} + """, + "Invalid Person Configuration", + DOMAIN, + ) + + return filtered async def async_setup(hass: HomeAssistantType, config: ConfigType): """Set up the person component.""" - component = EntityComponent(_LOGGER, DOMAIN, hass) - conf_persons = config.get(DOMAIN, []) - manager = hass.data[DOMAIN] = PersonManager(hass, component, conf_persons) - await manager.async_initialize() + entity_component = EntityComponent(_LOGGER, DOMAIN, hass) + id_manager = collection.IDManager() + yaml_collection = collection.YamlCollection( + logging.getLogger(f"{__name__}.yaml_collection"), id_manager + ) + storage_collection = PersonStorageCollection( + PersonStore(hass, STORAGE_VERSION, STORAGE_KEY), + logging.getLogger(f"{__name__}.storage_collection"), + id_manager, + yaml_collection, + ) + + collection.attach_entity_component_collection( + entity_component, yaml_collection, lambda conf: Person(conf, False) + ) + collection.attach_entity_component_collection( + entity_component, storage_collection, lambda conf: Person(conf, True) + ) + + await yaml_collection.async_load( + await filter_yaml_data(hass, config.get(DOMAIN, [])) + ) + await storage_collection.async_load() + + hass.data[DOMAIN] = (yaml_collection, storage_collection) + + collection.StorageCollectionWebsocket( + storage_collection, DOMAIN, DOMAIN, CREATE_FIELDS, UPDATE_FIELDS + ).async_setup(hass, create_list=False) websocket_api.async_register_command(hass, ws_list_person) - websocket_api.async_register_command(hass, ws_create_person) - websocket_api.async_register_command(hass, ws_update_person) - websocket_api.async_register_command(hass, ws_delete_person) + + async def _handle_user_removed(event: Event) -> None: + """Handle a user being removed.""" + user_id = event.data["user_id"] + for person in storage_collection.async_items(): + if person[CONF_USER_ID] == user_id: + await storage_collection.async_update_item( + person[CONF_ID], {"user_id": None} + ) + + hass.bus.async_listen(EVENT_USER_REMOVED, _handle_user_removed) return True @@ -353,21 +317,21 @@ class Person(RestoreEntity): if self.hass.is_running: # Update person now if hass is already running. - self.person_updated() + await self.async_update_config(self._config) else: # Wait for hass start to not have race between person # and device trackers finishing setup. - @callback - def person_start_hass(now): - self.person_updated() + async def person_start_hass(now): + await self.async_update_config(self._config) self.hass.bus.async_listen_once( EVENT_HOMEASSISTANT_START, person_start_hass ) - @callback - def person_updated(self): + async def async_update_config(self, config): """Handle when the config is updated.""" + self._config = config + if self._unsub_track_device is not None: self._unsub_track_device() self._unsub_track_device = None @@ -441,89 +405,12 @@ def ws_list_person( hass: HomeAssistantType, connection: websocket_api.ActiveConnection, msg ): """List persons.""" - manager: PersonManager = hass.data[DOMAIN] + yaml, storage = hass.data[DOMAIN] connection.send_result( - msg["id"], - {"storage": manager.storage_persons, "config": manager.config_persons}, + msg["id"], {"storage": storage.async_items(), "config": yaml.async_items()}, ) -@websocket_api.websocket_command( - { - vol.Required("type"): "person/create", - vol.Required("name"): vol.All(str, vol.Length(min=1)), - vol.Optional("user_id"): vol.Any(str, None), - vol.Optional("device_trackers", default=[]): vol.All( - cv.ensure_list, cv.entities_domain(DEVICE_TRACKER_DOMAIN) - ), - } -) -@websocket_api.require_admin -@websocket_api.async_response -async def ws_create_person( - hass: HomeAssistantType, connection: websocket_api.ActiveConnection, msg -): - """Create a person.""" - manager: PersonManager = hass.data[DOMAIN] - try: - person = await manager.async_create_person( - name=msg["name"], - user_id=msg.get("user_id"), - device_trackers=msg["device_trackers"], - ) - connection.send_result(msg["id"], person) - except ValueError as err: - connection.send_error( - msg["id"], websocket_api.const.ERR_INVALID_FORMAT, str(err) - ) - - -@websocket_api.websocket_command( - { - vol.Required("type"): "person/update", - vol.Required("person_id"): str, - vol.Required("name"): vol.All(str, vol.Length(min=1)), - vol.Optional("user_id"): vol.Any(str, None), - vol.Optional(CONF_DEVICE_TRACKERS, default=[]): vol.All( - cv.ensure_list, cv.entities_domain(DEVICE_TRACKER_DOMAIN) - ), - } -) -@websocket_api.require_admin -@websocket_api.async_response -async def ws_update_person( - hass: HomeAssistantType, connection: websocket_api.ActiveConnection, msg -): - """Update a person.""" - manager: PersonManager = hass.data[DOMAIN] - changes = {} - for key in ("name", "user_id", "device_trackers"): - if key in msg: - changes[key] = msg[key] - - try: - person = await manager.async_update_person(msg["person_id"], **changes) - connection.send_result(msg["id"], person) - except ValueError as err: - connection.send_error( - msg["id"], websocket_api.const.ERR_INVALID_FORMAT, str(err) - ) - - -@websocket_api.websocket_command( - {vol.Required("type"): "person/delete", vol.Required("person_id"): str} -) -@websocket_api.require_admin -@websocket_api.async_response -async def ws_delete_person( - hass: HomeAssistantType, connection: websocket_api.ActiveConnection, msg -): - """Delete a person.""" - manager: PersonManager = hass.data[DOMAIN] - await manager.async_delete_person(msg["person_id"]) - connection.send_result(msg["id"]) - - def _get_latest(prev: Optional[State], curr: State): """Get latest state.""" if prev is None or curr.last_updated > prev.last_updated: diff --git a/homeassistant/components/websocket_api/__init__.py b/homeassistant/components/websocket_api/__init__.py index 2beb2aa2788..60177fcde90 100644 --- a/homeassistant/components/websocket_api/__init__.py +++ b/homeassistant/components/websocket_api/__init__.py @@ -1,5 +1,9 @@ """WebSocket based API for Home Assistant.""" -from homeassistant.core import callback +from typing import Optional, Union, cast + +import voluptuous as vol + +from homeassistant.core import HomeAssistant, callback from homeassistant.loader import bind_hass from . import commands, connection, const, decorators, http, messages @@ -26,13 +30,18 @@ websocket_command = decorators.websocket_command @bind_hass @callback -def async_register_command(hass, command_or_handler, handler=None, schema=None): +def async_register_command( + hass: HomeAssistant, + command_or_handler: Union[str, const.WebSocketCommandHandler], + handler: Optional[const.WebSocketCommandHandler] = None, + schema: Optional[vol.Schema] = None, +) -> None: """Register a websocket command.""" # pylint: disable=protected-access if handler is None: - handler = command_or_handler - command = handler._ws_command - schema = handler._ws_schema + handler = cast(const.WebSocketCommandHandler, command_or_handler) + command = handler._ws_command # type: ignore + schema = handler._ws_schema # type: ignore else: command = command_or_handler handlers = hass.data.get(DOMAIN) diff --git a/homeassistant/components/websocket_api/commands.py b/homeassistant/components/websocket_api/commands.py index 93f926b537a..3e43f824e69 100644 --- a/homeassistant/components/websocket_api/commands.py +++ b/homeassistant/components/websocket_api/commands.py @@ -107,7 +107,6 @@ def handle_unsubscribe_events(hass, connection, msg): ) -@decorators.async_response @decorators.websocket_command( { vol.Required("type"): "call_service", @@ -116,6 +115,7 @@ def handle_unsubscribe_events(hass, connection, msg): vol.Optional("service_data"): dict, } ) +@decorators.async_response async def handle_call_service(hass, connection, msg): """Handle call service command. @@ -181,8 +181,8 @@ def handle_get_states(hass, connection, msg): connection.send_message(messages.result_message(msg["id"], states)) -@decorators.async_response @decorators.websocket_command({vol.Required("type"): "get_services"}) +@decorators.async_response async def handle_get_services(hass, connection, msg): """Handle get services command. diff --git a/homeassistant/components/websocket_api/connection.py b/homeassistant/components/websocket_api/connection.py index ed24a70519d..ae2bb16c6d2 100644 --- a/homeassistant/components/websocket_api/connection.py +++ b/homeassistant/components/websocket_api/connection.py @@ -1,6 +1,6 @@ """Connection session.""" import asyncio -from typing import Any, Callable, Dict, Hashable +from typing import Any, Callable, Dict, Hashable, Optional import voluptuous as vol @@ -37,7 +37,7 @@ class ActiveConnection: return Context(user_id=user.id) @callback - def send_result(self, msg_id, result=None): + def send_result(self, msg_id: int, result: Optional[Any] = None) -> None: """Send a result message.""" self.send_message(messages.result_message(msg_id, result)) @@ -49,7 +49,7 @@ class ActiveConnection: self.send_message(content) @callback - def send_error(self, msg_id, code, message): + def send_error(self, msg_id: int, code: str, message: str) -> None: """Send a error message.""" self.send_message(messages.error_message(msg_id, code, message)) diff --git a/homeassistant/components/websocket_api/const.py b/homeassistant/components/websocket_api/const.py index 8ad9443a4d6..b1fa1263a99 100644 --- a/homeassistant/components/websocket_api/const.py +++ b/homeassistant/components/websocket_api/const.py @@ -3,9 +3,20 @@ import asyncio from concurrent import futures from functools import partial import json +from typing import TYPE_CHECKING, Callable +from homeassistant.core import HomeAssistant from homeassistant.helpers.json import JSONEncoder +if TYPE_CHECKING: + from .connection import ActiveConnection # noqa + + +WebSocketCommandHandler = Callable[ + [HomeAssistant, "ActiveConnection", dict], None +] # pylint: disable=invalid-name + + DOMAIN = "websocket_api" URL = "/api/websocket" MAX_PENDING_MSG = 512 diff --git a/homeassistant/components/websocket_api/decorators.py b/homeassistant/components/websocket_api/decorators.py index 1a1330242bc..87b5d5baf92 100644 --- a/homeassistant/components/websocket_api/decorators.py +++ b/homeassistant/components/websocket_api/decorators.py @@ -1,11 +1,13 @@ """Decorators for the Websocket API.""" from functools import wraps import logging +from typing import Awaitable, Callable -from homeassistant.core import callback +from homeassistant.core import HomeAssistant, callback from homeassistant.exceptions import Unauthorized -from . import messages +from . import const, messages +from .connection import ActiveConnection # mypy: allow-untyped-calls, allow-untyped-defs @@ -20,7 +22,9 @@ async def _handle_async_response(func, hass, connection, msg): connection.async_handle_exception(msg, err) -def async_response(func): +def async_response( + func: Callable[[HomeAssistant, ActiveConnection, dict], Awaitable[None]] +) -> const.WebSocketCommandHandler: """Decorate an async function to handle WebSocket API messages.""" @callback @@ -32,7 +36,7 @@ def async_response(func): return schedule_handler -def require_admin(func): +def require_admin(func: const.WebSocketCommandHandler) -> const.WebSocketCommandHandler: """Websocket decorator to require user to be an admin.""" @wraps(func) @@ -104,7 +108,9 @@ def ws_require_user( return validator -def websocket_command(schema): +def websocket_command( + schema: dict, +) -> Callable[[const.WebSocketCommandHandler], const.WebSocketCommandHandler]: """Tag a function as a websocket command.""" command = schema["type"] diff --git a/homeassistant/helpers/collection.py b/homeassistant/helpers/collection.py new file mode 100644 index 00000000000..80401fcb30f --- /dev/null +++ b/homeassistant/helpers/collection.py @@ -0,0 +1,401 @@ +"""Helper to deal with YAML + storage.""" +from abc import ABC, abstractmethod +import logging +from typing import Any, Awaitable, Callable, Dict, List, Optional, cast + +import voluptuous as vol +from voluptuous.humanize import humanize_error + +from homeassistant.components import websocket_api +from homeassistant.const import CONF_ID +from homeassistant.core import HomeAssistant, callback +from homeassistant.exceptions import HomeAssistantError +from homeassistant.helpers.entity import Entity +from homeassistant.helpers.entity_component import EntityComponent +from homeassistant.helpers.storage import Store +from homeassistant.util import slugify + +STORAGE_VERSION = 1 +SAVE_DELAY = 10 + +CHANGE_ADDED = "added" +CHANGE_UPDATED = "updated" +CHANGE_REMOVED = "removed" + + +ChangeListener = Callable[ + [ + # Change type + str, + # Item ID + str, + # New config (None if removed) + Optional[dict], + ], + Awaitable[None], +] # pylint: disable=invalid-name + + +class CollectionError(HomeAssistantError): + """Base class for collection related errors.""" + + +class ItemNotFound(CollectionError): + """Raised when an item is not found.""" + + def __init__(self, item_id: str): + """Initialize item not found error.""" + super().__init__(f"Item {item_id} not found.") + self.item_id = item_id + + +class IDManager: + """Keep track of IDs across different collections.""" + + def __init__(self) -> None: + """Initiate the ID manager.""" + self.collections: List[Dict[str, Any]] = [] + + def add_collection(self, collection: Dict[str, Any]) -> None: + """Add a collection to check for ID usage.""" + self.collections.append(collection) + + def has_id(self, item_id: str) -> bool: + """Test if the ID exists.""" + return any(item_id in collection for collection in self.collections) + + def generate_id(self, suggestion: str) -> str: + """Generate an ID.""" + base = slugify(suggestion) + proposal = base + attempt = 1 + + while self.has_id(proposal): + attempt += 1 + proposal = f"{base}_{attempt}" + + return proposal + + +class ObservableCollection(ABC): + """Base collection type that can be observed.""" + + def __init__(self, logger: logging.Logger, id_manager: Optional[IDManager] = None): + """Initialize the base collection.""" + self.logger = logger + self.id_manager = id_manager or IDManager() + self.data: Dict[str, dict] = {} + self.listeners: List[ChangeListener] = [] + + self.id_manager.add_collection(self.data) + + @callback + def async_items(self) -> List[dict]: + """Return list of items in collection.""" + return list(self.data.values()) + + @callback + def async_add_listener(self, listener: ChangeListener) -> None: + """Add a listener. + + Will be called with (change_type, item_id, updated_config). + """ + self.listeners.append(listener) + + async def notify_change( + self, change_type: str, item_id: str, item: Optional[dict] + ) -> None: + """Notify listeners of a change.""" + self.logger.debug("%s %s: %s", change_type, item_id, item) + for listener in self.listeners: + await listener(change_type, item_id, item) + + +class YamlCollection(ObservableCollection): + """Offer a fake CRUD interface on top of static YAML.""" + + async def async_load(self, data: List[dict]) -> None: + """Load the storage Manager.""" + for item in data: + item_id = item[CONF_ID] + + if self.id_manager.has_id(item_id): + self.logger.warning("Duplicate ID '%s' detected, skipping", item_id) + continue + + self.data[item_id] = item + await self.notify_change(CHANGE_ADDED, item[CONF_ID], item) + + +class StorageCollection(ObservableCollection): + """Offer a CRUD interface on top of JSON storage.""" + + def __init__( + self, + store: Store, + logger: logging.Logger, + id_manager: Optional[IDManager] = None, + ): + """Initialize the storage collection.""" + super().__init__(logger, id_manager) + self.store = store + + @property + def hass(self) -> HomeAssistant: + """Home Assistant object.""" + return self.store.hass + + async def async_load(self) -> None: + """Load the storage Manager.""" + raw_storage = cast(Optional[dict], await self.store.async_load()) + + if raw_storage is None: + raw_storage = {"items": []} + + for item in raw_storage["items"]: + self.data[item[CONF_ID]] = item + await self.notify_change(CHANGE_ADDED, item[CONF_ID], item) + + @abstractmethod + async def _process_create_data(self, data: dict) -> dict: + """Validate the config is valid.""" + + @callback + @abstractmethod + def _get_suggested_id(self, info: dict) -> str: + """Suggest an ID based on the config.""" + + @abstractmethod + async def _update_data(self, data: dict, update_data: dict) -> dict: + """Return a new updated data object.""" + + async def async_create_item(self, data: dict) -> dict: + """Create a new item.""" + item = await self._process_create_data(data) + item[CONF_ID] = self.id_manager.generate_id(self._get_suggested_id(item)) + self.data[item[CONF_ID]] = item + self._async_schedule_save() + await self.notify_change(CHANGE_ADDED, item[CONF_ID], item) + return item + + async def async_update_item(self, item_id: str, updates: dict) -> dict: + """Update item.""" + if item_id not in self.data: + raise ItemNotFound(item_id) + + if CONF_ID in updates: + raise ValueError("Cannot update ID") + + current = self.data[item_id] + + updated = await self._update_data(current, updates) + + self.data[item_id] = updated + self._async_schedule_save() + + await self.notify_change(CHANGE_UPDATED, item_id, updated) + + return self.data[item_id] + + async def async_delete_item(self, item_id: str) -> None: + """Delete item.""" + if item_id not in self.data: + raise ItemNotFound(item_id) + + self.data.pop(item_id) + self._async_schedule_save() + + await self.notify_change(CHANGE_REMOVED, item_id, None) + + @callback + def _async_schedule_save(self) -> None: + """Schedule saving the area registry.""" + self.store.async_delay_save(self._data_to_save, SAVE_DELAY) + + @callback + def _data_to_save(self) -> dict: + """Return data of area registry to store in a file.""" + return {"items": list(self.data.values())} + + +@callback +def attach_entity_component_collection( + entity_component: EntityComponent, + collection: ObservableCollection, + create_entity: Callable[[dict], Entity], +) -> None: + """Map a collection to an entity component.""" + entities = {} + + async def _collection_changed( + change_type: str, item_id: str, config: Optional[dict] + ) -> None: + """Handle a collection change.""" + if change_type == CHANGE_ADDED: + entity = create_entity(cast(dict, config)) + await entity_component.async_add_entities([entity]) + entities[item_id] = entity + return + + if change_type == CHANGE_REMOVED: + entity = entities.pop(item_id) + await entity.async_remove() + return + + # CHANGE_UPDATED + await entities[item_id].async_update_config(config) # type: ignore + + collection.async_add_listener(_collection_changed) + + +class StorageCollectionWebsocket: + """Class to expose storage collection management over websocket.""" + + def __init__( + self, + storage_collection: StorageCollection, + api_prefix: str, + model_name: str, + create_schema: dict, + update_schema: dict, + ): + """Initialize a websocket CRUD.""" + self.storage_collection = storage_collection + self.api_prefix = api_prefix + self.model_name = model_name + self.create_schema = create_schema + self.update_schema = update_schema + + assert self.api_prefix[-1] != "/", "API prefix should not end in /" + + @property + def item_id_key(self) -> str: + """Return item ID key.""" + return f"{self.model_name}_id" + + @callback + def async_setup(self, hass: HomeAssistant, *, create_list: bool = True) -> None: + """Set up the websocket commands.""" + if create_list: + websocket_api.async_register_command( + hass, + f"{self.api_prefix}/list", + self.ws_list_item, + websocket_api.BASE_COMMAND_MESSAGE_SCHEMA.extend( + {vol.Required("type"): f"{self.api_prefix}/list"} + ), + ) + + websocket_api.async_register_command( + hass, + f"{self.api_prefix}/create", + websocket_api.require_admin( + websocket_api.async_response(self.ws_create_item) + ), + websocket_api.BASE_COMMAND_MESSAGE_SCHEMA.extend( + { + **self.create_schema, + vol.Required("type"): f"{self.api_prefix}/create", + } + ), + ) + + websocket_api.async_register_command( + hass, + f"{self.api_prefix}/update", + websocket_api.require_admin( + websocket_api.async_response(self.ws_update_item) + ), + websocket_api.BASE_COMMAND_MESSAGE_SCHEMA.extend( + { + **self.update_schema, + vol.Required("type"): f"{self.api_prefix}/update", + vol.Required(self.item_id_key): str, + } + ), + ) + + websocket_api.async_register_command( + hass, + f"{self.api_prefix}/delete", + websocket_api.require_admin( + websocket_api.async_response(self.ws_delete_item) + ), + websocket_api.BASE_COMMAND_MESSAGE_SCHEMA.extend( + { + vol.Required("type"): f"{self.api_prefix}/delete", + vol.Required(self.item_id_key): str, + } + ), + ) + + def ws_list_item( + self, hass: HomeAssistant, connection: websocket_api.ActiveConnection, msg: dict + ) -> None: + """List items.""" + connection.send_result(msg["id"], self.storage_collection.async_items()) + + async def ws_create_item( + self, hass: HomeAssistant, connection: websocket_api.ActiveConnection, msg: dict + ) -> None: + """Create a item.""" + try: + data = dict(msg) + data.pop("id") + data.pop("type") + item = await self.storage_collection.async_create_item(data) + connection.send_result(msg["id"], item) + except vol.Invalid as err: + connection.send_error( + msg["id"], + websocket_api.const.ERR_INVALID_FORMAT, + humanize_error(data, err), + ) + except ValueError as err: + connection.send_error( + msg["id"], websocket_api.const.ERR_INVALID_FORMAT, str(err) + ) + + async def ws_update_item( + self, hass: HomeAssistant, connection: websocket_api.ActiveConnection, msg: dict + ) -> None: + """Update a item.""" + data = dict(msg) + msg_id = data.pop("id") + item_id = data.pop(self.item_id_key) + data.pop("type") + + try: + item = await self.storage_collection.async_update_item(item_id, data) + connection.send_result(msg_id, item) + except ItemNotFound: + connection.send_error( + msg["id"], + websocket_api.const.ERR_NOT_FOUND, + f"Unable to find {self.item_id_key} {item_id}", + ) + except vol.Invalid as err: + connection.send_error( + msg["id"], + websocket_api.const.ERR_INVALID_FORMAT, + humanize_error(data, err), + ) + except ValueError as err: + connection.send_error( + msg_id, websocket_api.const.ERR_INVALID_FORMAT, str(err) + ) + + async def ws_delete_item( + self, hass: HomeAssistant, connection: websocket_api.ActiveConnection, msg: dict + ) -> None: + """Delete a item.""" + try: + await self.storage_collection.async_delete_item(msg[self.item_id_key]) + except ItemNotFound: + connection.send_error( + msg["id"], + websocket_api.const.ERR_NOT_FOUND, + f"Unable to find {self.item_id_key} {msg[self.item_id_key]}", + ) + + connection.send_result(msg["id"]) diff --git a/homeassistant/helpers/entity.py b/homeassistant/helpers/entity.py index b7c806950a0..7ccc6c35613 100644 --- a/homeassistant/helpers/entity.py +++ b/homeassistant/helpers/entity.py @@ -473,8 +473,9 @@ class Entity(ABC): self._on_remove = [] self._on_remove.append(func) - async def async_remove(self): + async def async_remove(self) -> None: """Remove entity from Home Assistant.""" + assert self.hass is not None await self.async_internal_will_remove_from_hass() await self.async_will_remove_from_hass() diff --git a/tests/components/conftest.py b/tests/components/conftest.py index a589839c03f..528e804a01a 100644 --- a/tests/components/conftest.py +++ b/tests/components/conftest.py @@ -3,14 +3,6 @@ from unittest.mock import patch import pytest -from homeassistant.components.websocket_api.auth import ( - TYPE_AUTH, - TYPE_AUTH_OK, - TYPE_AUTH_REQUIRED, -) -from homeassistant.components.websocket_api.http import URL -from homeassistant.setup import async_setup_component - from tests.common import mock_coro @@ -22,37 +14,3 @@ def prevent_io(): side_effect=lambda *args: mock_coro([]), ): yield - - -@pytest.fixture -def hass_ws_client(aiohttp_client, hass_access_token): - """Websocket client fixture connected to websocket server.""" - - async def create_client(hass, access_token=hass_access_token): - """Create a websocket client.""" - assert await async_setup_component(hass, "websocket_api", {}) - - client = await aiohttp_client(hass.http.app) - - with patch("homeassistant.components.http.auth.setup_auth"): - websocket = await client.ws_connect(URL) - auth_resp = await websocket.receive_json() - assert auth_resp["type"] == TYPE_AUTH_REQUIRED - - if access_token is None: - await websocket.send_json( - {"type": TYPE_AUTH, "access_token": "incorrect"} - ) - else: - await websocket.send_json( - {"type": TYPE_AUTH, "access_token": access_token} - ) - - auth_ok = await websocket.receive_json() - assert auth_ok["type"] == TYPE_AUTH_OK - - # wrap in client - websocket.client = client - return websocket - - return create_client diff --git a/tests/components/onboarding/test_views.py b/tests/components/onboarding/test_views.py index 6d2c6e4c08f..c7c9782e9a8 100644 --- a/tests/components/onboarding/test_views.py +++ b/tests/components/onboarding/test_views.py @@ -98,7 +98,7 @@ async def test_onboarding_user(hass, hass_storage, aiohttp_client): assert user.name == "Test Name" assert len(user.credentials) == 1 assert user.credentials[0].data["username"] == "test-user" - assert len(hass.data["person"].storage_data) == 1 + assert len(hass.data["person"][1].async_items()) == 1 # Validate refresh token 1 resp = await client.post( diff --git a/tests/components/person/test_init.py b/tests/components/person/test_init.py index da5d7f03d34..6f4ea2a92ee 100644 --- a/tests/components/person/test_init.py +++ b/tests/components/person/test_init.py @@ -1,19 +1,15 @@ """The tests for the person component.""" -from unittest.mock import Mock +import logging import pytest +from homeassistant.components import person from homeassistant.components.device_tracker import ( ATTR_SOURCE_TYPE, SOURCE_TYPE_GPS, SOURCE_TYPE_ROUTER, ) -from homeassistant.components.person import ( - ATTR_SOURCE, - ATTR_USER_ID, - DOMAIN, - PersonManager, -) +from homeassistant.components.person import ATTR_SOURCE, ATTR_USER_ID, DOMAIN from homeassistant.const import ( ATTR_GPS_ACCURACY, ATTR_ID, @@ -23,20 +19,29 @@ from homeassistant.const import ( STATE_UNKNOWN, ) from homeassistant.core import CoreState, State +from homeassistant.helpers import collection from homeassistant.setup import async_setup_component -from tests.common import ( - assert_setup_component, - mock_component, - mock_coro_func, - mock_restore_cache, -) +from tests.common import assert_setup_component, mock_component, mock_restore_cache DEVICE_TRACKER = "device_tracker.test_tracker" DEVICE_TRACKER_2 = "device_tracker.test_tracker_2" -# pylint: disable=redefined-outer-name +@pytest.fixture +def storage_collection(hass): + """Return an empty storage collection.""" + id_manager = collection.IDManager() + return person.PersonStorageCollection( + person.PersonStore(hass, person.STORAGE_VERSION, person.STORAGE_KEY), + logging.getLogger(f"{person.__name__}.storage_collection"), + id_manager, + collection.YamlCollection( + logging.getLogger(f"{person.__name__}.yaml_collection"), id_manager + ), + ) + + @pytest.fixture def storage_setup(hass, hass_storage, hass_admin_user): """Storage setup.""" @@ -433,21 +438,21 @@ async def test_load_person_storage_two_nonlinked(hass, hass_storage): async def test_ws_list(hass, hass_ws_client, storage_setup): """Test listing via WS.""" - manager = hass.data[DOMAIN] + manager = hass.data[DOMAIN][1] client = await hass_ws_client(hass) resp = await client.send_json({"id": 6, "type": "person/list"}) resp = await client.receive_json() assert resp["success"] - assert resp["result"]["storage"] == manager.storage_persons + assert resp["result"]["storage"] == manager.async_items() assert len(resp["result"]["storage"]) == 1 assert len(resp["result"]["config"]) == 0 async def test_ws_create(hass, hass_ws_client, storage_setup, hass_read_only_user): """Test creating via WS.""" - manager = hass.data[DOMAIN] + manager = hass.data[DOMAIN][1] client = await hass_ws_client(hass) @@ -462,7 +467,7 @@ async def test_ws_create(hass, hass_ws_client, storage_setup, hass_read_only_use ) resp = await client.receive_json() - persons = manager.storage_persons + persons = manager.async_items() assert len(persons) == 2 assert resp["success"] @@ -474,7 +479,7 @@ async def test_ws_create_requires_admin( ): """Test creating via WS requires admin.""" hass_admin_user.groups = [] - manager = hass.data[DOMAIN] + manager = hass.data[DOMAIN][1] client = await hass_ws_client(hass) @@ -489,7 +494,7 @@ async def test_ws_create_requires_admin( ) resp = await client.receive_json() - persons = manager.storage_persons + persons = manager.async_items() assert len(persons) == 1 assert not resp["success"] @@ -497,10 +502,10 @@ async def test_ws_create_requires_admin( async def test_ws_update(hass, hass_ws_client, storage_setup): """Test updating via WS.""" - manager = hass.data[DOMAIN] + manager = hass.data[DOMAIN][1] client = await hass_ws_client(hass) - persons = manager.storage_persons + persons = manager.async_items() resp = await client.send_json( { @@ -514,7 +519,7 @@ async def test_ws_update(hass, hass_ws_client, storage_setup): ) resp = await client.receive_json() - persons = manager.storage_persons + persons = manager.async_items() assert len(persons) == 1 assert resp["success"] @@ -533,10 +538,10 @@ async def test_ws_update_require_admin( ): """Test updating via WS requires admin.""" hass_admin_user.groups = [] - manager = hass.data[DOMAIN] + manager = hass.data[DOMAIN][1] client = await hass_ws_client(hass) - original = dict(manager.storage_persons[0]) + original = dict(manager.async_items()[0]) resp = await client.send_json( { @@ -551,23 +556,23 @@ async def test_ws_update_require_admin( resp = await client.receive_json() assert not resp["success"] - not_updated = dict(manager.storage_persons[0]) + not_updated = dict(manager.async_items()[0]) assert original == not_updated async def test_ws_delete(hass, hass_ws_client, storage_setup): """Test deleting via WS.""" - manager = hass.data[DOMAIN] + manager = hass.data[DOMAIN][1] client = await hass_ws_client(hass) - persons = manager.storage_persons + persons = manager.async_items() resp = await client.send_json( {"id": 6, "type": "person/delete", "person_id": persons[0]["id"]} ) resp = await client.receive_json() - persons = manager.storage_persons + persons = manager.async_items() assert len(persons) == 0 assert resp["success"] @@ -581,7 +586,7 @@ async def test_ws_delete_require_admin( ): """Test deleting via WS requires admin.""" hass_admin_user.groups = [] - manager = hass.data[DOMAIN] + manager = hass.data[DOMAIN][1] client = await hass_ws_client(hass) @@ -589,7 +594,7 @@ async def test_ws_delete_require_admin( { "id": 6, "type": "person/delete", - "person_id": manager.storage_persons[0]["id"], + "person_id": manager.async_items()[0]["id"], "name": "Updated Name", "device_trackers": [DEVICE_TRACKER_2], "user_id": None, @@ -598,61 +603,64 @@ async def test_ws_delete_require_admin( resp = await client.receive_json() assert not resp["success"] - persons = manager.storage_persons + persons = manager.async_items() assert len(persons) == 1 -async def test_create_invalid_user_id(hass): +async def test_create_invalid_user_id(hass, storage_collection): """Test we do not allow invalid user ID during creation.""" - manager = PersonManager(hass, Mock(), []) - await manager.async_initialize() with pytest.raises(ValueError): - await manager.async_create_person(name="Hello", user_id="non-existing") + await storage_collection.async_create_item( + {"name": "Hello", "user_id": "non-existing"} + ) -async def test_create_duplicate_user_id(hass, hass_admin_user): +async def test_create_duplicate_user_id(hass, hass_admin_user, storage_collection): """Test we do not allow duplicate user ID during creation.""" - manager = PersonManager(hass, Mock(async_add_entities=mock_coro_func()), []) - await manager.async_initialize() - await manager.async_create_person(name="Hello", user_id=hass_admin_user.id) + await storage_collection.async_create_item( + {"name": "Hello", "user_id": hass_admin_user.id} + ) with pytest.raises(ValueError): - await manager.async_create_person(name="Hello", user_id=hass_admin_user.id) + await storage_collection.async_create_item( + {"name": "Hello", "user_id": hass_admin_user.id} + ) -async def test_update_double_user_id(hass, hass_admin_user): +async def test_update_double_user_id(hass, hass_admin_user, storage_collection): """Test we do not allow double user ID during update.""" - manager = PersonManager(hass, Mock(async_add_entities=mock_coro_func()), []) - await manager.async_initialize() - await manager.async_create_person(name="Hello", user_id=hass_admin_user.id) - person = await manager.async_create_person(name="Hello") + await storage_collection.async_create_item( + {"name": "Hello", "user_id": hass_admin_user.id} + ) + person = await storage_collection.async_create_item({"name": "Hello"}) with pytest.raises(ValueError): - await manager.async_update_person( - person_id=person["id"], user_id=hass_admin_user.id + await storage_collection.async_update_item( + person["id"], {"user_id": hass_admin_user.id} ) -async def test_update_invalid_user_id(hass): +async def test_update_invalid_user_id(hass, storage_collection): """Test updating to invalid user ID.""" - manager = PersonManager(hass, Mock(async_add_entities=mock_coro_func()), []) - await manager.async_initialize() - person = await manager.async_create_person(name="Hello") + person = await storage_collection.async_create_item({"name": "Hello"}) with pytest.raises(ValueError): - await manager.async_update_person( - person_id=person["id"], user_id="non-existing" + await storage_collection.async_update_item( + person["id"], {"user_id": "non-existing"} ) -async def test_update_person_when_user_removed(hass, hass_read_only_user): +async def test_update_person_when_user_removed( + hass, storage_setup, hass_read_only_user +): """Update person when user is removed.""" - manager = PersonManager(hass, Mock(async_add_entities=mock_coro_func()), []) - await manager.async_initialize() - person = await manager.async_create_person( - name="Hello", user_id=hass_read_only_user.id + storage_collection = hass.data[DOMAIN][1] + + person = await storage_collection.async_create_item( + {"name": "Hello", "user_id": hass_read_only_user.id} ) await hass.auth.async_remove_user(hass_read_only_user) await hass.async_block_till_done() - assert person["user_id"] is None + + assert storage_collection.data[person["id"]]["user_id"] is None diff --git a/tests/conftest.py b/tests/conftest.py index 7364b3f0b96..cd77122800a 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -9,6 +9,13 @@ import requests_mock as _requests_mock from homeassistant import util from homeassistant.auth.const import GROUP_ID_ADMIN, GROUP_ID_READ_ONLY from homeassistant.auth.providers import homeassistant, legacy_api_password +from homeassistant.components.websocket_api.auth import ( + TYPE_AUTH, + TYPE_AUTH_OK, + TYPE_AUTH_REQUIRED, +) +from homeassistant.components.websocket_api.http import URL +from homeassistant.setup import async_setup_component from homeassistant.util import location pytest.register_assert_rewrite("tests.common") @@ -187,3 +194,37 @@ def hass_client(hass, aiohttp_client, hass_access_token): ) return auth_client + + +@pytest.fixture +def hass_ws_client(aiohttp_client, hass_access_token): + """Websocket client fixture connected to websocket server.""" + + async def create_client(hass, access_token=hass_access_token): + """Create a websocket client.""" + assert await async_setup_component(hass, "websocket_api", {}) + + client = await aiohttp_client(hass.http.app) + + with patch("homeassistant.components.http.auth.setup_auth"): + websocket = await client.ws_connect(URL) + auth_resp = await websocket.receive_json() + assert auth_resp["type"] == TYPE_AUTH_REQUIRED + + if access_token is None: + await websocket.send_json( + {"type": TYPE_AUTH, "access_token": "incorrect"} + ) + else: + await websocket.send_json( + {"type": TYPE_AUTH, "access_token": access_token} + ) + + auth_ok = await websocket.receive_json() + assert auth_ok["type"] == TYPE_AUTH_OK + + # wrap in client + websocket.client = client + return websocket + + return create_client diff --git a/tests/helpers/test_collection.py b/tests/helpers/test_collection.py new file mode 100644 index 00000000000..29eeca67f3f --- /dev/null +++ b/tests/helpers/test_collection.py @@ -0,0 +1,356 @@ +"""Tests for the collection helper.""" +import logging + +import pytest +import voluptuous as vol + +from homeassistant.helpers import collection, entity, entity_component, storage + +from tests.common import flush_store + +LOGGER = logging.getLogger(__name__) + + +def track_changes(coll: collection.ObservableCollection): + """Create helper to track changes in a collection.""" + changes = [] + + async def listener(*args): + changes.append(args) + + coll.async_add_listener(listener) + + return changes + + +class MockEntity(entity.Entity): + """Entity that is config based.""" + + def __init__(self, config): + """Initialize entity.""" + self._config = config + + @property + def unique_id(self): + """Return unique ID of entity.""" + return self._config["id"] + + @property + def name(self): + """Return name of entity.""" + return self._config["name"] + + @property + def state(self): + """Return state of entity.""" + return self._config["state"] + + async def async_update_config(self, config): + """Update entity config.""" + self._config = config + self.async_write_ha_state() + + +class MockStorageCollection(collection.StorageCollection): + """Mock storage collection.""" + + async def _process_create_data(self, data: dict) -> dict: + """Validate the config is valid.""" + if "name" not in data: + raise ValueError("invalid") + + return data + + def _get_suggested_id(self, info: dict) -> str: + """Suggest an ID based on the config.""" + return info["name"] + + async def _update_data(self, data: dict, update_data: dict) -> dict: + """Return a new updated data object.""" + return {**data, **update_data} + + +def test_id_manager(): + """Test the ID manager.""" + id_manager = collection.IDManager() + assert not id_manager.has_id("some_id") + data = {} + id_manager.add_collection(data) + assert not id_manager.has_id("some_id") + data["some_id"] = 1 + assert id_manager.has_id("some_id") + assert id_manager.generate_id("some_id") == "some_id_2" + assert id_manager.generate_id("bla") == "bla" + + +async def test_observable_collection(): + """Test observerable collection.""" + coll = collection.ObservableCollection(LOGGER) + assert coll.async_items() == [] + coll.data["bla"] = 1 + assert coll.async_items() == [1] + + changes = track_changes(coll) + await coll.notify_change("mock_type", "mock_id", {"mock": "item"}) + assert len(changes) == 1 + assert changes[0] == ("mock_type", "mock_id", {"mock": "item"}) + + +async def test_yaml_collection(): + """Test a YAML collection.""" + id_manager = collection.IDManager() + coll = collection.YamlCollection(LOGGER, id_manager) + changes = track_changes(coll) + await coll.async_load( + [{"id": "mock-1", "name": "Mock 1"}, {"id": "mock-2", "name": "Mock 2"}] + ) + assert id_manager.has_id("mock-1") + assert id_manager.has_id("mock-2") + assert len(changes) == 2 + assert changes[0] == ( + collection.CHANGE_ADDED, + "mock-1", + {"id": "mock-1", "name": "Mock 1"}, + ) + assert changes[1] == ( + collection.CHANGE_ADDED, + "mock-2", + {"id": "mock-2", "name": "Mock 2"}, + ) + + +async def test_yaml_collection_skipping_duplicate_ids(): + """Test YAML collection skipping duplicate IDs.""" + id_manager = collection.IDManager() + id_manager.add_collection({"existing": True}) + coll = collection.YamlCollection(LOGGER, id_manager) + changes = track_changes(coll) + await coll.async_load( + [{"id": "mock-1", "name": "Mock 1"}, {"id": "existing", "name": "Mock 2"}] + ) + assert len(changes) == 1 + assert changes[0] == ( + collection.CHANGE_ADDED, + "mock-1", + {"id": "mock-1", "name": "Mock 1"}, + ) + + +async def test_storage_collection(hass): + """Test storage collection.""" + store = storage.Store(hass, 1, "test-data") + await store.async_save( + { + "items": [ + {"id": "mock-1", "name": "Mock 1", "data": 1}, + {"id": "mock-2", "name": "Mock 2", "data": 2}, + ] + } + ) + id_manager = collection.IDManager() + coll = MockStorageCollection(store, LOGGER, id_manager) + changes = track_changes(coll) + + await coll.async_load() + assert id_manager.has_id("mock-1") + assert id_manager.has_id("mock-2") + assert len(changes) == 2 + assert changes[0] == ( + collection.CHANGE_ADDED, + "mock-1", + {"id": "mock-1", "name": "Mock 1", "data": 1}, + ) + assert changes[1] == ( + collection.CHANGE_ADDED, + "mock-2", + {"id": "mock-2", "name": "Mock 2", "data": 2}, + ) + + item = await coll.async_create_item({"name": "Mock 3"}) + assert item["id"] == "mock_3" + assert len(changes) == 3 + assert changes[2] == ( + collection.CHANGE_ADDED, + "mock_3", + {"id": "mock_3", "name": "Mock 3"}, + ) + + updated_item = await coll.async_update_item("mock-2", {"name": "Mock 2 updated"}) + assert id_manager.has_id("mock-2") + assert updated_item == {"id": "mock-2", "name": "Mock 2 updated", "data": 2} + assert len(changes) == 4 + assert changes[3] == (collection.CHANGE_UPDATED, "mock-2", updated_item) + + with pytest.raises(ValueError): + await coll.async_update_item("mock-2", {"id": "mock-2-updated"}) + + assert id_manager.has_id("mock-2") + assert not id_manager.has_id("mock-2-updated") + assert len(changes) == 4 + + await flush_store(store) + + assert await storage.Store(hass, 1, "test-data").async_load() == { + "items": [ + {"id": "mock-1", "name": "Mock 1", "data": 1}, + {"id": "mock-2", "name": "Mock 2 updated", "data": 2}, + {"id": "mock_3", "name": "Mock 3"}, + ] + } + + +async def test_attach_entity_component_collection(hass): + """Test attaching collection to entity component.""" + ent_comp = entity_component.EntityComponent(LOGGER, "test", hass) + coll = collection.ObservableCollection(LOGGER) + collection.attach_entity_component_collection(ent_comp, coll, MockEntity) + + await coll.notify_change( + collection.CHANGE_ADDED, + "mock_id", + {"id": "mock_id", "state": "initial", "name": "Mock 1"}, + ) + + assert hass.states.get("test.mock_1").name == "Mock 1" + assert hass.states.get("test.mock_1").state == "initial" + + await coll.notify_change( + collection.CHANGE_UPDATED, + "mock_id", + {"id": "mock_id", "state": "second", "name": "Mock 1 updated"}, + ) + + assert hass.states.get("test.mock_1").name == "Mock 1 updated" + assert hass.states.get("test.mock_1").state == "second" + + await coll.notify_change(collection.CHANGE_REMOVED, "mock_id", None) + + assert hass.states.get("test.mock_1") is None + + +async def test_storage_collection_websocket(hass, hass_ws_client): + """Test exposing a storage collection via websockets.""" + store = storage.Store(hass, 1, "test-data") + coll = MockStorageCollection(store, LOGGER) + changes = track_changes(coll) + collection.StorageCollectionWebsocket( + coll, + "test_item/collection", + "test_item", + {vol.Required("name"): str, vol.Required("immutable_string"): str}, + {vol.Optional("name"): str}, + ).async_setup(hass) + + client = await hass_ws_client(hass) + + # Create invalid + await client.send_json( + { + "id": 1, + "type": "test_item/collection/create", + "name": 1, + # Forgot to add immutable_string + } + ) + response = await client.receive_json() + assert not response["success"] + assert response["error"]["code"] == "invalid_format" + assert len(changes) == 0 + + # Create + await client.send_json( + { + "id": 2, + "type": "test_item/collection/create", + "name": "Initial Name", + "immutable_string": "no-changes", + } + ) + response = await client.receive_json() + assert response["success"] + assert response["result"] == { + "id": "initial_name", + "name": "Initial Name", + "immutable_string": "no-changes", + } + assert len(changes) == 1 + assert changes[0] == (collection.CHANGE_ADDED, "initial_name", response["result"]) + + # List + await client.send_json({"id": 3, "type": "test_item/collection/list"}) + response = await client.receive_json() + assert response["success"] + assert response["result"] == [ + { + "id": "initial_name", + "name": "Initial Name", + "immutable_string": "no-changes", + } + ] + assert len(changes) == 1 + + # Update invalid data + await client.send_json( + { + "id": 4, + "type": "test_item/collection/update", + "test_item_id": "initial_name", + "immutable_string": "no-changes", + } + ) + response = await client.receive_json() + assert not response["success"] + assert response["error"]["code"] == "invalid_format" + assert len(changes) == 1 + + # Update invalid item + await client.send_json( + { + "id": 5, + "type": "test_item/collection/update", + "test_item_id": "non-existing", + "name": "Updated name", + } + ) + response = await client.receive_json() + assert not response["success"] + assert response["error"]["code"] == "not_found" + assert len(changes) == 1 + + # Update + await client.send_json( + { + "id": 6, + "type": "test_item/collection/update", + "test_item_id": "initial_name", + "name": "Updated name", + } + ) + response = await client.receive_json() + assert response["success"] + assert response["result"] == { + "id": "initial_name", + "name": "Updated name", + "immutable_string": "no-changes", + } + assert len(changes) == 2 + assert changes[1] == (collection.CHANGE_UPDATED, "initial_name", response["result"]) + + # Delete invalid ID + await client.send_json( + {"id": 7, "type": "test_item/collection/update", "test_item_id": "non-existing"} + ) + response = await client.receive_json() + assert not response["success"] + assert response["error"]["code"] == "not_found" + assert len(changes) == 2 + + # Delete + await client.send_json( + {"id": 8, "type": "test_item/collection/delete", "test_item_id": "initial_name"} + ) + response = await client.receive_json() + assert response["success"] + + assert len(changes) == 3 + assert changes[2] == (collection.CHANGE_REMOVED, "initial_name", None)