Extract Collection helper from Person integration (#30313)

* Add CRUD foundation

* Use collection helper in person integration

* Lint/pytest

* Add tests

* Lint

* Create notification
This commit is contained in:
Paulus Schoutsen 2020-01-03 21:37:11 +01:00 committed by GitHub
parent 3033dbd86c
commit b9aba30a6e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
14 changed files with 1074 additions and 396 deletions

View File

@ -173,13 +173,13 @@ async def _async_get_device_automation_capabilities(hass, automation_type, autom
return capabilities
@websocket_api.async_response
@websocket_api.websocket_command(
{
vol.Required("type"): "device_automation/action/list",
vol.Required("device_id"): str,
}
)
@websocket_api.async_response
async def websocket_device_automation_list_actions(hass, connection, msg):
"""Handle request for device actions."""
device_id = msg["device_id"]
@ -187,13 +187,13 @@ async def websocket_device_automation_list_actions(hass, connection, msg):
connection.send_result(msg["id"], actions)
@websocket_api.async_response
@websocket_api.websocket_command(
{
vol.Required("type"): "device_automation/condition/list",
vol.Required("device_id"): str,
}
)
@websocket_api.async_response
async def websocket_device_automation_list_conditions(hass, connection, msg):
"""Handle request for device conditions."""
device_id = msg["device_id"]
@ -201,13 +201,13 @@ async def websocket_device_automation_list_conditions(hass, connection, msg):
connection.send_result(msg["id"], conditions)
@websocket_api.async_response
@websocket_api.websocket_command(
{
vol.Required("type"): "device_automation/trigger/list",
vol.Required("device_id"): str,
}
)
@websocket_api.async_response
async def websocket_device_automation_list_triggers(hass, connection, msg):
"""Handle request for device triggers."""
device_id = msg["device_id"]
@ -215,13 +215,13 @@ async def websocket_device_automation_list_triggers(hass, connection, msg):
connection.send_result(msg["id"], triggers)
@websocket_api.async_response
@websocket_api.websocket_command(
{
vol.Required("type"): "device_automation/action/capabilities",
vol.Required("action"): dict,
}
)
@websocket_api.async_response
async def websocket_device_automation_get_action_capabilities(hass, connection, msg):
"""Handle request for device action capabilities."""
action = msg["action"]
@ -231,13 +231,13 @@ async def websocket_device_automation_get_action_capabilities(hass, connection,
connection.send_result(msg["id"], capabilities)
@websocket_api.async_response
@websocket_api.websocket_command(
{
vol.Required("type"): "device_automation/condition/capabilities",
vol.Required("condition"): dict,
}
)
@websocket_api.async_response
async def websocket_device_automation_get_condition_capabilities(hass, connection, msg):
"""Handle request for device condition capabilities."""
condition = msg["condition"]
@ -247,13 +247,13 @@ async def websocket_device_automation_get_condition_capabilities(hass, connectio
connection.send_result(msg["id"], capabilities)
@websocket_api.async_response
@websocket_api.websocket_command(
{
vol.Required("type"): "device_automation/trigger/capabilities",
vol.Required("trigger"): dict,
}
)
@websocket_api.async_response
async def websocket_device_automation_get_trigger_capabilities(hass, connection, msg):
"""Handle request for device trigger capabilities."""
trigger = msg["trigger"]

View File

@ -1,9 +1,6 @@
"""Support for tracking people."""
from collections import OrderedDict
from itertools import chain
import logging
from typing import Optional
import uuid
from typing import List, Optional
import voluptuous as vol
@ -28,6 +25,7 @@ from homeassistant.const import (
STATE_UNKNOWN,
)
from homeassistant.core import Event, State, callback
from homeassistant.helpers import collection, entity_registry
import homeassistant.helpers.config_validation as cv
from homeassistant.helpers.entity_component import EntityComponent
from homeassistant.helpers.event import async_track_state_change
@ -48,8 +46,7 @@ CONF_USER_ID = "user_id"
DOMAIN = "person"
STORAGE_KEY = DOMAIN
STORAGE_VERSION = 1
SAVE_DELAY = 10
STORAGE_VERSION = 2
# Device tracker states to ignore
IGNORE_STATES = (STATE_UNKNOWN, STATE_UNAVAILABLE)
@ -75,217 +72,184 @@ _UNDEF = object()
@bind_hass
async def async_create_person(hass, name, *, user_id=None, device_trackers=None):
"""Create a new person."""
await hass.data[DOMAIN].async_create_person(
name=name, user_id=user_id, device_trackers=device_trackers
await hass.data[DOMAIN][1].async_create_item(
{"name": name, "user_id": user_id, "device_trackers": device_trackers}
)
class PersonManager:
"""Manage person data."""
CREATE_FIELDS = {
vol.Required("name"): vol.All(str, vol.Length(min=1)),
vol.Optional("user_id"): vol.Any(str, None),
vol.Optional("device_trackers", default=list): vol.All(
cv.ensure_list, cv.entities_domain(DEVICE_TRACKER_DOMAIN)
),
}
UPDATE_FIELDS = {
vol.Optional("name"): vol.All(str, vol.Length(min=1)),
vol.Optional("user_id"): vol.Any(str, None),
vol.Optional("device_trackers", default=list): vol.All(
cv.ensure_list, cv.entities_domain(DEVICE_TRACKER_DOMAIN)
),
}
class PersonStore(Store):
"""Person storage."""
async def _async_migrate_func(self, old_version, old_data):
"""Migrate to the new version.
Migrate storage to use format of collection helper.
"""
return {"items": old_data["persons"]}
class PersonStorageCollection(collection.StorageCollection):
"""Person collection stored in storage."""
CREATE_SCHEMA = vol.Schema(CREATE_FIELDS)
UPDATE_SCHEMA = vol.Schema(UPDATE_FIELDS)
def __init__(
self, hass: HomeAssistantType, component: EntityComponent, config_persons
self,
store: Store,
logger: logging.Logger,
id_manager: collection.IDManager,
yaml_collection: collection.YamlCollection,
):
"""Initialize person storage."""
self.hass = hass
self.component = component
self.store = Store(hass, STORAGE_VERSION, STORAGE_KEY)
self.storage_data = None
"""Initialize a person storage collection."""
super().__init__(store, logger, id_manager)
self.async_add_listener(self._collection_changed)
self.yaml_collection = yaml_collection
config_data = self.config_data = OrderedDict()
for conf in config_persons:
person_id = conf[CONF_ID]
async def _process_create_data(self, data: dict) -> dict:
"""Validate the config is valid."""
data = self.CREATE_SCHEMA(data)
if person_id in config_data:
_LOGGER.error("Found config user with duplicate ID: %s", person_id)
continue
config_data[person_id] = conf
@property
def storage_persons(self):
"""Iterate over persons stored in storage."""
return list(self.storage_data.values())
@property
def config_persons(self):
"""Iterate over persons stored in config."""
return list(self.config_data.values())
async def async_initialize(self):
"""Get the person data."""
raw_storage = await self.store.async_load()
if raw_storage is None:
raw_storage = {"persons": []}
storage_data = self.storage_data = OrderedDict()
for person in raw_storage["persons"]:
storage_data[person[CONF_ID]] = person
entities = []
seen_users = set()
for person_conf in self.config_data.values():
person_id = person_conf[CONF_ID]
user_id = person_conf.get(CONF_USER_ID)
if user_id is not None:
if await self.hass.auth.async_get_user(user_id) is None:
_LOGGER.error("Invalid user_id detected for person %s", person_id)
continue
if user_id in seen_users:
_LOGGER.error(
"Duplicate user_id %s detected for person %s",
user_id,
person_id,
)
continue
seen_users.add(user_id)
entities.append(Person(person_conf, False))
# To make sure IDs don't overlap between config/storage
seen_persons = set(self.config_data)
for person_conf in storage_data.values():
person_id = person_conf[CONF_ID]
user_id = person_conf[CONF_USER_ID]
if person_id in seen_persons:
_LOGGER.error(
"Skipping adding person from storage with same ID as"
" configuration.yaml entry: %s",
person_id,
)
continue
if user_id is not None and user_id in seen_users:
_LOGGER.error(
"Duplicate user_id %s detected for person %s", user_id, person_id
)
continue
# To make sure all users have just 1 person linked.
seen_users.add(user_id)
entities.append(Person(person_conf, True))
if entities:
await self.component.async_add_entities(entities)
self.hass.bus.async_listen(EVENT_USER_REMOVED, self._user_removed)
async def async_create_person(self, *, name, device_trackers=None, user_id=None):
"""Create a new person."""
if not name:
raise ValueError("Name is required")
user_id = data.get("user_id")
if user_id is not None:
await self._validate_user_id(user_id)
person = {
CONF_ID: uuid.uuid4().hex,
CONF_NAME: name,
CONF_USER_ID: user_id,
CONF_DEVICE_TRACKERS: device_trackers or [],
}
self.storage_data[person[CONF_ID]] = person
self._async_schedule_save()
await self.component.async_add_entities([Person(person, True)])
return person
return self.CREATE_SCHEMA(data)
async def async_update_person(
self, person_id, *, name=_UNDEF, device_trackers=_UNDEF, user_id=_UNDEF
):
"""Update person."""
current = self.storage_data.get(person_id)
@callback
def _get_suggested_id(self, info: dict) -> str:
"""Suggest an ID based on the config."""
return info["name"]
if current is None:
raise ValueError("Invalid person specified.")
async def _update_data(self, data: dict, update_data: dict) -> dict:
"""Return a new updated data object."""
update_data = self.UPDATE_SCHEMA(update_data)
changes = {
key: value
for key, value in (
(CONF_NAME, name),
(CONF_DEVICE_TRACKERS, device_trackers),
(CONF_USER_ID, user_id),
)
if value is not _UNDEF and current[key] != value
}
user_id = update_data.get("user_id")
if CONF_USER_ID in changes and user_id is not None:
if user_id is not None:
await self._validate_user_id(user_id)
self.storage_data[person_id].update(changes)
self._async_schedule_save()
for entity in self.component.entities:
if entity.unique_id == person_id:
entity.person_updated()
break
return self.storage_data[person_id]
async def async_delete_person(self, person_id):
"""Delete person."""
if person_id not in self.storage_data:
raise ValueError("Invalid person specified.")
self.storage_data.pop(person_id)
self._async_schedule_save()
ent_reg = await self.hass.helpers.entity_registry.async_get_registry()
for entity in self.component.entities:
if entity.unique_id == person_id:
await entity.async_remove()
ent_reg.async_remove(entity.entity_id)
break
@callback
def _async_schedule_save(self) -> None:
"""Schedule saving the area registry."""
self.store.async_delay_save(self._data_to_save, SAVE_DELAY)
@callback
def _data_to_save(self) -> dict:
"""Return data of area registry to store in a file."""
return {"persons": list(self.storage_data.values())}
return {**data, **update_data}
async def _validate_user_id(self, user_id):
"""Validate the used user_id."""
if await self.hass.auth.async_get_user(user_id) is None:
raise ValueError("User does not exist")
if any(
person
for person in chain(self.storage_data.values(), self.config_data.values())
if person.get(CONF_USER_ID) == user_id
):
raise ValueError("User already taken")
for persons in (self.data.values(), self.yaml_collection.async_items()):
if any(person for person in persons if person.get(CONF_USER_ID) == user_id):
raise ValueError("User already taken")
async def _user_removed(self, event: Event):
"""Handle event that a person is removed."""
user_id = event.data["user_id"]
for person in self.storage_data.values():
if person[CONF_USER_ID] == user_id:
await self.async_update_person(person_id=person[CONF_ID], user_id=None)
async def _collection_changed(
self, change_type: str, item_id: str, config: Optional[dict]
) -> None:
"""Handle a collection change."""
if change_type != collection.CHANGE_REMOVED:
return
ent_reg = await entity_registry.async_get_registry(self.hass)
ent_reg.async_remove(ent_reg.async_get_entity_id(DOMAIN, DOMAIN, item_id))
async def filter_yaml_data(hass: HomeAssistantType, persons: List[dict]) -> List[dict]:
"""Validate YAML data that we can't validate via schema."""
filtered = []
person_invalid_user = []
for person_conf in persons:
user_id = person_conf.get(CONF_USER_ID)
if user_id is not None:
if await hass.auth.async_get_user(user_id) is None:
_LOGGER.error(
"Invalid user_id detected for person %s",
person_conf[collection.CONF_ID],
)
person_invalid_user.append(
f"- Person {person_conf[CONF_NAME]} (id: {person_conf[collection.CONF_ID]}) points at invalid user {user_id}"
)
continue
filtered.append(person_conf)
if person_invalid_user:
hass.components.persistent_notification.async_create(
f"""
The following persons point at invalid users:
{"- ".join(person_invalid_user)}
""",
"Invalid Person Configuration",
DOMAIN,
)
return filtered
async def async_setup(hass: HomeAssistantType, config: ConfigType):
"""Set up the person component."""
component = EntityComponent(_LOGGER, DOMAIN, hass)
conf_persons = config.get(DOMAIN, [])
manager = hass.data[DOMAIN] = PersonManager(hass, component, conf_persons)
await manager.async_initialize()
entity_component = EntityComponent(_LOGGER, DOMAIN, hass)
id_manager = collection.IDManager()
yaml_collection = collection.YamlCollection(
logging.getLogger(f"{__name__}.yaml_collection"), id_manager
)
storage_collection = PersonStorageCollection(
PersonStore(hass, STORAGE_VERSION, STORAGE_KEY),
logging.getLogger(f"{__name__}.storage_collection"),
id_manager,
yaml_collection,
)
collection.attach_entity_component_collection(
entity_component, yaml_collection, lambda conf: Person(conf, False)
)
collection.attach_entity_component_collection(
entity_component, storage_collection, lambda conf: Person(conf, True)
)
await yaml_collection.async_load(
await filter_yaml_data(hass, config.get(DOMAIN, []))
)
await storage_collection.async_load()
hass.data[DOMAIN] = (yaml_collection, storage_collection)
collection.StorageCollectionWebsocket(
storage_collection, DOMAIN, DOMAIN, CREATE_FIELDS, UPDATE_FIELDS
).async_setup(hass, create_list=False)
websocket_api.async_register_command(hass, ws_list_person)
websocket_api.async_register_command(hass, ws_create_person)
websocket_api.async_register_command(hass, ws_update_person)
websocket_api.async_register_command(hass, ws_delete_person)
async def _handle_user_removed(event: Event) -> None:
"""Handle a user being removed."""
user_id = event.data["user_id"]
for person in storage_collection.async_items():
if person[CONF_USER_ID] == user_id:
await storage_collection.async_update_item(
person[CONF_ID], {"user_id": None}
)
hass.bus.async_listen(EVENT_USER_REMOVED, _handle_user_removed)
return True
@ -353,21 +317,21 @@ class Person(RestoreEntity):
if self.hass.is_running:
# Update person now if hass is already running.
self.person_updated()
await self.async_update_config(self._config)
else:
# Wait for hass start to not have race between person
# and device trackers finishing setup.
@callback
def person_start_hass(now):
self.person_updated()
async def person_start_hass(now):
await self.async_update_config(self._config)
self.hass.bus.async_listen_once(
EVENT_HOMEASSISTANT_START, person_start_hass
)
@callback
def person_updated(self):
async def async_update_config(self, config):
"""Handle when the config is updated."""
self._config = config
if self._unsub_track_device is not None:
self._unsub_track_device()
self._unsub_track_device = None
@ -441,89 +405,12 @@ def ws_list_person(
hass: HomeAssistantType, connection: websocket_api.ActiveConnection, msg
):
"""List persons."""
manager: PersonManager = hass.data[DOMAIN]
yaml, storage = hass.data[DOMAIN]
connection.send_result(
msg["id"],
{"storage": manager.storage_persons, "config": manager.config_persons},
msg["id"], {"storage": storage.async_items(), "config": yaml.async_items()},
)
@websocket_api.websocket_command(
{
vol.Required("type"): "person/create",
vol.Required("name"): vol.All(str, vol.Length(min=1)),
vol.Optional("user_id"): vol.Any(str, None),
vol.Optional("device_trackers", default=[]): vol.All(
cv.ensure_list, cv.entities_domain(DEVICE_TRACKER_DOMAIN)
),
}
)
@websocket_api.require_admin
@websocket_api.async_response
async def ws_create_person(
hass: HomeAssistantType, connection: websocket_api.ActiveConnection, msg
):
"""Create a person."""
manager: PersonManager = hass.data[DOMAIN]
try:
person = await manager.async_create_person(
name=msg["name"],
user_id=msg.get("user_id"),
device_trackers=msg["device_trackers"],
)
connection.send_result(msg["id"], person)
except ValueError as err:
connection.send_error(
msg["id"], websocket_api.const.ERR_INVALID_FORMAT, str(err)
)
@websocket_api.websocket_command(
{
vol.Required("type"): "person/update",
vol.Required("person_id"): str,
vol.Required("name"): vol.All(str, vol.Length(min=1)),
vol.Optional("user_id"): vol.Any(str, None),
vol.Optional(CONF_DEVICE_TRACKERS, default=[]): vol.All(
cv.ensure_list, cv.entities_domain(DEVICE_TRACKER_DOMAIN)
),
}
)
@websocket_api.require_admin
@websocket_api.async_response
async def ws_update_person(
hass: HomeAssistantType, connection: websocket_api.ActiveConnection, msg
):
"""Update a person."""
manager: PersonManager = hass.data[DOMAIN]
changes = {}
for key in ("name", "user_id", "device_trackers"):
if key in msg:
changes[key] = msg[key]
try:
person = await manager.async_update_person(msg["person_id"], **changes)
connection.send_result(msg["id"], person)
except ValueError as err:
connection.send_error(
msg["id"], websocket_api.const.ERR_INVALID_FORMAT, str(err)
)
@websocket_api.websocket_command(
{vol.Required("type"): "person/delete", vol.Required("person_id"): str}
)
@websocket_api.require_admin
@websocket_api.async_response
async def ws_delete_person(
hass: HomeAssistantType, connection: websocket_api.ActiveConnection, msg
):
"""Delete a person."""
manager: PersonManager = hass.data[DOMAIN]
await manager.async_delete_person(msg["person_id"])
connection.send_result(msg["id"])
def _get_latest(prev: Optional[State], curr: State):
"""Get latest state."""
if prev is None or curr.last_updated > prev.last_updated:

View File

@ -1,5 +1,9 @@
"""WebSocket based API for Home Assistant."""
from homeassistant.core import callback
from typing import Optional, Union, cast
import voluptuous as vol
from homeassistant.core import HomeAssistant, callback
from homeassistant.loader import bind_hass
from . import commands, connection, const, decorators, http, messages
@ -26,13 +30,18 @@ websocket_command = decorators.websocket_command
@bind_hass
@callback
def async_register_command(hass, command_or_handler, handler=None, schema=None):
def async_register_command(
hass: HomeAssistant,
command_or_handler: Union[str, const.WebSocketCommandHandler],
handler: Optional[const.WebSocketCommandHandler] = None,
schema: Optional[vol.Schema] = None,
) -> None:
"""Register a websocket command."""
# pylint: disable=protected-access
if handler is None:
handler = command_or_handler
command = handler._ws_command
schema = handler._ws_schema
handler = cast(const.WebSocketCommandHandler, command_or_handler)
command = handler._ws_command # type: ignore
schema = handler._ws_schema # type: ignore
else:
command = command_or_handler
handlers = hass.data.get(DOMAIN)

View File

@ -107,7 +107,6 @@ def handle_unsubscribe_events(hass, connection, msg):
)
@decorators.async_response
@decorators.websocket_command(
{
vol.Required("type"): "call_service",
@ -116,6 +115,7 @@ def handle_unsubscribe_events(hass, connection, msg):
vol.Optional("service_data"): dict,
}
)
@decorators.async_response
async def handle_call_service(hass, connection, msg):
"""Handle call service command.
@ -181,8 +181,8 @@ def handle_get_states(hass, connection, msg):
connection.send_message(messages.result_message(msg["id"], states))
@decorators.async_response
@decorators.websocket_command({vol.Required("type"): "get_services"})
@decorators.async_response
async def handle_get_services(hass, connection, msg):
"""Handle get services command.

View File

@ -1,6 +1,6 @@
"""Connection session."""
import asyncio
from typing import Any, Callable, Dict, Hashable
from typing import Any, Callable, Dict, Hashable, Optional
import voluptuous as vol
@ -37,7 +37,7 @@ class ActiveConnection:
return Context(user_id=user.id)
@callback
def send_result(self, msg_id, result=None):
def send_result(self, msg_id: int, result: Optional[Any] = None) -> None:
"""Send a result message."""
self.send_message(messages.result_message(msg_id, result))
@ -49,7 +49,7 @@ class ActiveConnection:
self.send_message(content)
@callback
def send_error(self, msg_id, code, message):
def send_error(self, msg_id: int, code: str, message: str) -> None:
"""Send a error message."""
self.send_message(messages.error_message(msg_id, code, message))

View File

@ -3,9 +3,20 @@ import asyncio
from concurrent import futures
from functools import partial
import json
from typing import TYPE_CHECKING, Callable
from homeassistant.core import HomeAssistant
from homeassistant.helpers.json import JSONEncoder
if TYPE_CHECKING:
from .connection import ActiveConnection # noqa
WebSocketCommandHandler = Callable[
[HomeAssistant, "ActiveConnection", dict], None
] # pylint: disable=invalid-name
DOMAIN = "websocket_api"
URL = "/api/websocket"
MAX_PENDING_MSG = 512

View File

@ -1,11 +1,13 @@
"""Decorators for the Websocket API."""
from functools import wraps
import logging
from typing import Awaitable, Callable
from homeassistant.core import callback
from homeassistant.core import HomeAssistant, callback
from homeassistant.exceptions import Unauthorized
from . import messages
from . import const, messages
from .connection import ActiveConnection
# mypy: allow-untyped-calls, allow-untyped-defs
@ -20,7 +22,9 @@ async def _handle_async_response(func, hass, connection, msg):
connection.async_handle_exception(msg, err)
def async_response(func):
def async_response(
func: Callable[[HomeAssistant, ActiveConnection, dict], Awaitable[None]]
) -> const.WebSocketCommandHandler:
"""Decorate an async function to handle WebSocket API messages."""
@callback
@ -32,7 +36,7 @@ def async_response(func):
return schedule_handler
def require_admin(func):
def require_admin(func: const.WebSocketCommandHandler) -> const.WebSocketCommandHandler:
"""Websocket decorator to require user to be an admin."""
@wraps(func)
@ -104,7 +108,9 @@ def ws_require_user(
return validator
def websocket_command(schema):
def websocket_command(
schema: dict,
) -> Callable[[const.WebSocketCommandHandler], const.WebSocketCommandHandler]:
"""Tag a function as a websocket command."""
command = schema["type"]

View File

@ -0,0 +1,401 @@
"""Helper to deal with YAML + storage."""
from abc import ABC, abstractmethod
import logging
from typing import Any, Awaitable, Callable, Dict, List, Optional, cast
import voluptuous as vol
from voluptuous.humanize import humanize_error
from homeassistant.components import websocket_api
from homeassistant.const import CONF_ID
from homeassistant.core import HomeAssistant, callback
from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers.entity import Entity
from homeassistant.helpers.entity_component import EntityComponent
from homeassistant.helpers.storage import Store
from homeassistant.util import slugify
STORAGE_VERSION = 1
SAVE_DELAY = 10
CHANGE_ADDED = "added"
CHANGE_UPDATED = "updated"
CHANGE_REMOVED = "removed"
ChangeListener = Callable[
[
# Change type
str,
# Item ID
str,
# New config (None if removed)
Optional[dict],
],
Awaitable[None],
] # pylint: disable=invalid-name
class CollectionError(HomeAssistantError):
"""Base class for collection related errors."""
class ItemNotFound(CollectionError):
"""Raised when an item is not found."""
def __init__(self, item_id: str):
"""Initialize item not found error."""
super().__init__(f"Item {item_id} not found.")
self.item_id = item_id
class IDManager:
"""Keep track of IDs across different collections."""
def __init__(self) -> None:
"""Initiate the ID manager."""
self.collections: List[Dict[str, Any]] = []
def add_collection(self, collection: Dict[str, Any]) -> None:
"""Add a collection to check for ID usage."""
self.collections.append(collection)
def has_id(self, item_id: str) -> bool:
"""Test if the ID exists."""
return any(item_id in collection for collection in self.collections)
def generate_id(self, suggestion: str) -> str:
"""Generate an ID."""
base = slugify(suggestion)
proposal = base
attempt = 1
while self.has_id(proposal):
attempt += 1
proposal = f"{base}_{attempt}"
return proposal
class ObservableCollection(ABC):
"""Base collection type that can be observed."""
def __init__(self, logger: logging.Logger, id_manager: Optional[IDManager] = None):
"""Initialize the base collection."""
self.logger = logger
self.id_manager = id_manager or IDManager()
self.data: Dict[str, dict] = {}
self.listeners: List[ChangeListener] = []
self.id_manager.add_collection(self.data)
@callback
def async_items(self) -> List[dict]:
"""Return list of items in collection."""
return list(self.data.values())
@callback
def async_add_listener(self, listener: ChangeListener) -> None:
"""Add a listener.
Will be called with (change_type, item_id, updated_config).
"""
self.listeners.append(listener)
async def notify_change(
self, change_type: str, item_id: str, item: Optional[dict]
) -> None:
"""Notify listeners of a change."""
self.logger.debug("%s %s: %s", change_type, item_id, item)
for listener in self.listeners:
await listener(change_type, item_id, item)
class YamlCollection(ObservableCollection):
"""Offer a fake CRUD interface on top of static YAML."""
async def async_load(self, data: List[dict]) -> None:
"""Load the storage Manager."""
for item in data:
item_id = item[CONF_ID]
if self.id_manager.has_id(item_id):
self.logger.warning("Duplicate ID '%s' detected, skipping", item_id)
continue
self.data[item_id] = item
await self.notify_change(CHANGE_ADDED, item[CONF_ID], item)
class StorageCollection(ObservableCollection):
"""Offer a CRUD interface on top of JSON storage."""
def __init__(
self,
store: Store,
logger: logging.Logger,
id_manager: Optional[IDManager] = None,
):
"""Initialize the storage collection."""
super().__init__(logger, id_manager)
self.store = store
@property
def hass(self) -> HomeAssistant:
"""Home Assistant object."""
return self.store.hass
async def async_load(self) -> None:
"""Load the storage Manager."""
raw_storage = cast(Optional[dict], await self.store.async_load())
if raw_storage is None:
raw_storage = {"items": []}
for item in raw_storage["items"]:
self.data[item[CONF_ID]] = item
await self.notify_change(CHANGE_ADDED, item[CONF_ID], item)
@abstractmethod
async def _process_create_data(self, data: dict) -> dict:
"""Validate the config is valid."""
@callback
@abstractmethod
def _get_suggested_id(self, info: dict) -> str:
"""Suggest an ID based on the config."""
@abstractmethod
async def _update_data(self, data: dict, update_data: dict) -> dict:
"""Return a new updated data object."""
async def async_create_item(self, data: dict) -> dict:
"""Create a new item."""
item = await self._process_create_data(data)
item[CONF_ID] = self.id_manager.generate_id(self._get_suggested_id(item))
self.data[item[CONF_ID]] = item
self._async_schedule_save()
await self.notify_change(CHANGE_ADDED, item[CONF_ID], item)
return item
async def async_update_item(self, item_id: str, updates: dict) -> dict:
"""Update item."""
if item_id not in self.data:
raise ItemNotFound(item_id)
if CONF_ID in updates:
raise ValueError("Cannot update ID")
current = self.data[item_id]
updated = await self._update_data(current, updates)
self.data[item_id] = updated
self._async_schedule_save()
await self.notify_change(CHANGE_UPDATED, item_id, updated)
return self.data[item_id]
async def async_delete_item(self, item_id: str) -> None:
"""Delete item."""
if item_id not in self.data:
raise ItemNotFound(item_id)
self.data.pop(item_id)
self._async_schedule_save()
await self.notify_change(CHANGE_REMOVED, item_id, None)
@callback
def _async_schedule_save(self) -> None:
"""Schedule saving the area registry."""
self.store.async_delay_save(self._data_to_save, SAVE_DELAY)
@callback
def _data_to_save(self) -> dict:
"""Return data of area registry to store in a file."""
return {"items": list(self.data.values())}
@callback
def attach_entity_component_collection(
entity_component: EntityComponent,
collection: ObservableCollection,
create_entity: Callable[[dict], Entity],
) -> None:
"""Map a collection to an entity component."""
entities = {}
async def _collection_changed(
change_type: str, item_id: str, config: Optional[dict]
) -> None:
"""Handle a collection change."""
if change_type == CHANGE_ADDED:
entity = create_entity(cast(dict, config))
await entity_component.async_add_entities([entity])
entities[item_id] = entity
return
if change_type == CHANGE_REMOVED:
entity = entities.pop(item_id)
await entity.async_remove()
return
# CHANGE_UPDATED
await entities[item_id].async_update_config(config) # type: ignore
collection.async_add_listener(_collection_changed)
class StorageCollectionWebsocket:
"""Class to expose storage collection management over websocket."""
def __init__(
self,
storage_collection: StorageCollection,
api_prefix: str,
model_name: str,
create_schema: dict,
update_schema: dict,
):
"""Initialize a websocket CRUD."""
self.storage_collection = storage_collection
self.api_prefix = api_prefix
self.model_name = model_name
self.create_schema = create_schema
self.update_schema = update_schema
assert self.api_prefix[-1] != "/", "API prefix should not end in /"
@property
def item_id_key(self) -> str:
"""Return item ID key."""
return f"{self.model_name}_id"
@callback
def async_setup(self, hass: HomeAssistant, *, create_list: bool = True) -> None:
"""Set up the websocket commands."""
if create_list:
websocket_api.async_register_command(
hass,
f"{self.api_prefix}/list",
self.ws_list_item,
websocket_api.BASE_COMMAND_MESSAGE_SCHEMA.extend(
{vol.Required("type"): f"{self.api_prefix}/list"}
),
)
websocket_api.async_register_command(
hass,
f"{self.api_prefix}/create",
websocket_api.require_admin(
websocket_api.async_response(self.ws_create_item)
),
websocket_api.BASE_COMMAND_MESSAGE_SCHEMA.extend(
{
**self.create_schema,
vol.Required("type"): f"{self.api_prefix}/create",
}
),
)
websocket_api.async_register_command(
hass,
f"{self.api_prefix}/update",
websocket_api.require_admin(
websocket_api.async_response(self.ws_update_item)
),
websocket_api.BASE_COMMAND_MESSAGE_SCHEMA.extend(
{
**self.update_schema,
vol.Required("type"): f"{self.api_prefix}/update",
vol.Required(self.item_id_key): str,
}
),
)
websocket_api.async_register_command(
hass,
f"{self.api_prefix}/delete",
websocket_api.require_admin(
websocket_api.async_response(self.ws_delete_item)
),
websocket_api.BASE_COMMAND_MESSAGE_SCHEMA.extend(
{
vol.Required("type"): f"{self.api_prefix}/delete",
vol.Required(self.item_id_key): str,
}
),
)
def ws_list_item(
self, hass: HomeAssistant, connection: websocket_api.ActiveConnection, msg: dict
) -> None:
"""List items."""
connection.send_result(msg["id"], self.storage_collection.async_items())
async def ws_create_item(
self, hass: HomeAssistant, connection: websocket_api.ActiveConnection, msg: dict
) -> None:
"""Create a item."""
try:
data = dict(msg)
data.pop("id")
data.pop("type")
item = await self.storage_collection.async_create_item(data)
connection.send_result(msg["id"], item)
except vol.Invalid as err:
connection.send_error(
msg["id"],
websocket_api.const.ERR_INVALID_FORMAT,
humanize_error(data, err),
)
except ValueError as err:
connection.send_error(
msg["id"], websocket_api.const.ERR_INVALID_FORMAT, str(err)
)
async def ws_update_item(
self, hass: HomeAssistant, connection: websocket_api.ActiveConnection, msg: dict
) -> None:
"""Update a item."""
data = dict(msg)
msg_id = data.pop("id")
item_id = data.pop(self.item_id_key)
data.pop("type")
try:
item = await self.storage_collection.async_update_item(item_id, data)
connection.send_result(msg_id, item)
except ItemNotFound:
connection.send_error(
msg["id"],
websocket_api.const.ERR_NOT_FOUND,
f"Unable to find {self.item_id_key} {item_id}",
)
except vol.Invalid as err:
connection.send_error(
msg["id"],
websocket_api.const.ERR_INVALID_FORMAT,
humanize_error(data, err),
)
except ValueError as err:
connection.send_error(
msg_id, websocket_api.const.ERR_INVALID_FORMAT, str(err)
)
async def ws_delete_item(
self, hass: HomeAssistant, connection: websocket_api.ActiveConnection, msg: dict
) -> None:
"""Delete a item."""
try:
await self.storage_collection.async_delete_item(msg[self.item_id_key])
except ItemNotFound:
connection.send_error(
msg["id"],
websocket_api.const.ERR_NOT_FOUND,
f"Unable to find {self.item_id_key} {msg[self.item_id_key]}",
)
connection.send_result(msg["id"])

View File

@ -473,8 +473,9 @@ class Entity(ABC):
self._on_remove = []
self._on_remove.append(func)
async def async_remove(self):
async def async_remove(self) -> None:
"""Remove entity from Home Assistant."""
assert self.hass is not None
await self.async_internal_will_remove_from_hass()
await self.async_will_remove_from_hass()

View File

@ -3,14 +3,6 @@ from unittest.mock import patch
import pytest
from homeassistant.components.websocket_api.auth import (
TYPE_AUTH,
TYPE_AUTH_OK,
TYPE_AUTH_REQUIRED,
)
from homeassistant.components.websocket_api.http import URL
from homeassistant.setup import async_setup_component
from tests.common import mock_coro
@ -22,37 +14,3 @@ def prevent_io():
side_effect=lambda *args: mock_coro([]),
):
yield
@pytest.fixture
def hass_ws_client(aiohttp_client, hass_access_token):
"""Websocket client fixture connected to websocket server."""
async def create_client(hass, access_token=hass_access_token):
"""Create a websocket client."""
assert await async_setup_component(hass, "websocket_api", {})
client = await aiohttp_client(hass.http.app)
with patch("homeassistant.components.http.auth.setup_auth"):
websocket = await client.ws_connect(URL)
auth_resp = await websocket.receive_json()
assert auth_resp["type"] == TYPE_AUTH_REQUIRED
if access_token is None:
await websocket.send_json(
{"type": TYPE_AUTH, "access_token": "incorrect"}
)
else:
await websocket.send_json(
{"type": TYPE_AUTH, "access_token": access_token}
)
auth_ok = await websocket.receive_json()
assert auth_ok["type"] == TYPE_AUTH_OK
# wrap in client
websocket.client = client
return websocket
return create_client

View File

@ -98,7 +98,7 @@ async def test_onboarding_user(hass, hass_storage, aiohttp_client):
assert user.name == "Test Name"
assert len(user.credentials) == 1
assert user.credentials[0].data["username"] == "test-user"
assert len(hass.data["person"].storage_data) == 1
assert len(hass.data["person"][1].async_items()) == 1
# Validate refresh token 1
resp = await client.post(

View File

@ -1,19 +1,15 @@
"""The tests for the person component."""
from unittest.mock import Mock
import logging
import pytest
from homeassistant.components import person
from homeassistant.components.device_tracker import (
ATTR_SOURCE_TYPE,
SOURCE_TYPE_GPS,
SOURCE_TYPE_ROUTER,
)
from homeassistant.components.person import (
ATTR_SOURCE,
ATTR_USER_ID,
DOMAIN,
PersonManager,
)
from homeassistant.components.person import ATTR_SOURCE, ATTR_USER_ID, DOMAIN
from homeassistant.const import (
ATTR_GPS_ACCURACY,
ATTR_ID,
@ -23,20 +19,29 @@ from homeassistant.const import (
STATE_UNKNOWN,
)
from homeassistant.core import CoreState, State
from homeassistant.helpers import collection
from homeassistant.setup import async_setup_component
from tests.common import (
assert_setup_component,
mock_component,
mock_coro_func,
mock_restore_cache,
)
from tests.common import assert_setup_component, mock_component, mock_restore_cache
DEVICE_TRACKER = "device_tracker.test_tracker"
DEVICE_TRACKER_2 = "device_tracker.test_tracker_2"
# pylint: disable=redefined-outer-name
@pytest.fixture
def storage_collection(hass):
"""Return an empty storage collection."""
id_manager = collection.IDManager()
return person.PersonStorageCollection(
person.PersonStore(hass, person.STORAGE_VERSION, person.STORAGE_KEY),
logging.getLogger(f"{person.__name__}.storage_collection"),
id_manager,
collection.YamlCollection(
logging.getLogger(f"{person.__name__}.yaml_collection"), id_manager
),
)
@pytest.fixture
def storage_setup(hass, hass_storage, hass_admin_user):
"""Storage setup."""
@ -433,21 +438,21 @@ async def test_load_person_storage_two_nonlinked(hass, hass_storage):
async def test_ws_list(hass, hass_ws_client, storage_setup):
"""Test listing via WS."""
manager = hass.data[DOMAIN]
manager = hass.data[DOMAIN][1]
client = await hass_ws_client(hass)
resp = await client.send_json({"id": 6, "type": "person/list"})
resp = await client.receive_json()
assert resp["success"]
assert resp["result"]["storage"] == manager.storage_persons
assert resp["result"]["storage"] == manager.async_items()
assert len(resp["result"]["storage"]) == 1
assert len(resp["result"]["config"]) == 0
async def test_ws_create(hass, hass_ws_client, storage_setup, hass_read_only_user):
"""Test creating via WS."""
manager = hass.data[DOMAIN]
manager = hass.data[DOMAIN][1]
client = await hass_ws_client(hass)
@ -462,7 +467,7 @@ async def test_ws_create(hass, hass_ws_client, storage_setup, hass_read_only_use
)
resp = await client.receive_json()
persons = manager.storage_persons
persons = manager.async_items()
assert len(persons) == 2
assert resp["success"]
@ -474,7 +479,7 @@ async def test_ws_create_requires_admin(
):
"""Test creating via WS requires admin."""
hass_admin_user.groups = []
manager = hass.data[DOMAIN]
manager = hass.data[DOMAIN][1]
client = await hass_ws_client(hass)
@ -489,7 +494,7 @@ async def test_ws_create_requires_admin(
)
resp = await client.receive_json()
persons = manager.storage_persons
persons = manager.async_items()
assert len(persons) == 1
assert not resp["success"]
@ -497,10 +502,10 @@ async def test_ws_create_requires_admin(
async def test_ws_update(hass, hass_ws_client, storage_setup):
"""Test updating via WS."""
manager = hass.data[DOMAIN]
manager = hass.data[DOMAIN][1]
client = await hass_ws_client(hass)
persons = manager.storage_persons
persons = manager.async_items()
resp = await client.send_json(
{
@ -514,7 +519,7 @@ async def test_ws_update(hass, hass_ws_client, storage_setup):
)
resp = await client.receive_json()
persons = manager.storage_persons
persons = manager.async_items()
assert len(persons) == 1
assert resp["success"]
@ -533,10 +538,10 @@ async def test_ws_update_require_admin(
):
"""Test updating via WS requires admin."""
hass_admin_user.groups = []
manager = hass.data[DOMAIN]
manager = hass.data[DOMAIN][1]
client = await hass_ws_client(hass)
original = dict(manager.storage_persons[0])
original = dict(manager.async_items()[0])
resp = await client.send_json(
{
@ -551,23 +556,23 @@ async def test_ws_update_require_admin(
resp = await client.receive_json()
assert not resp["success"]
not_updated = dict(manager.storage_persons[0])
not_updated = dict(manager.async_items()[0])
assert original == not_updated
async def test_ws_delete(hass, hass_ws_client, storage_setup):
"""Test deleting via WS."""
manager = hass.data[DOMAIN]
manager = hass.data[DOMAIN][1]
client = await hass_ws_client(hass)
persons = manager.storage_persons
persons = manager.async_items()
resp = await client.send_json(
{"id": 6, "type": "person/delete", "person_id": persons[0]["id"]}
)
resp = await client.receive_json()
persons = manager.storage_persons
persons = manager.async_items()
assert len(persons) == 0
assert resp["success"]
@ -581,7 +586,7 @@ async def test_ws_delete_require_admin(
):
"""Test deleting via WS requires admin."""
hass_admin_user.groups = []
manager = hass.data[DOMAIN]
manager = hass.data[DOMAIN][1]
client = await hass_ws_client(hass)
@ -589,7 +594,7 @@ async def test_ws_delete_require_admin(
{
"id": 6,
"type": "person/delete",
"person_id": manager.storage_persons[0]["id"],
"person_id": manager.async_items()[0]["id"],
"name": "Updated Name",
"device_trackers": [DEVICE_TRACKER_2],
"user_id": None,
@ -598,61 +603,64 @@ async def test_ws_delete_require_admin(
resp = await client.receive_json()
assert not resp["success"]
persons = manager.storage_persons
persons = manager.async_items()
assert len(persons) == 1
async def test_create_invalid_user_id(hass):
async def test_create_invalid_user_id(hass, storage_collection):
"""Test we do not allow invalid user ID during creation."""
manager = PersonManager(hass, Mock(), [])
await manager.async_initialize()
with pytest.raises(ValueError):
await manager.async_create_person(name="Hello", user_id="non-existing")
await storage_collection.async_create_item(
{"name": "Hello", "user_id": "non-existing"}
)
async def test_create_duplicate_user_id(hass, hass_admin_user):
async def test_create_duplicate_user_id(hass, hass_admin_user, storage_collection):
"""Test we do not allow duplicate user ID during creation."""
manager = PersonManager(hass, Mock(async_add_entities=mock_coro_func()), [])
await manager.async_initialize()
await manager.async_create_person(name="Hello", user_id=hass_admin_user.id)
await storage_collection.async_create_item(
{"name": "Hello", "user_id": hass_admin_user.id}
)
with pytest.raises(ValueError):
await manager.async_create_person(name="Hello", user_id=hass_admin_user.id)
await storage_collection.async_create_item(
{"name": "Hello", "user_id": hass_admin_user.id}
)
async def test_update_double_user_id(hass, hass_admin_user):
async def test_update_double_user_id(hass, hass_admin_user, storage_collection):
"""Test we do not allow double user ID during update."""
manager = PersonManager(hass, Mock(async_add_entities=mock_coro_func()), [])
await manager.async_initialize()
await manager.async_create_person(name="Hello", user_id=hass_admin_user.id)
person = await manager.async_create_person(name="Hello")
await storage_collection.async_create_item(
{"name": "Hello", "user_id": hass_admin_user.id}
)
person = await storage_collection.async_create_item({"name": "Hello"})
with pytest.raises(ValueError):
await manager.async_update_person(
person_id=person["id"], user_id=hass_admin_user.id
await storage_collection.async_update_item(
person["id"], {"user_id": hass_admin_user.id}
)
async def test_update_invalid_user_id(hass):
async def test_update_invalid_user_id(hass, storage_collection):
"""Test updating to invalid user ID."""
manager = PersonManager(hass, Mock(async_add_entities=mock_coro_func()), [])
await manager.async_initialize()
person = await manager.async_create_person(name="Hello")
person = await storage_collection.async_create_item({"name": "Hello"})
with pytest.raises(ValueError):
await manager.async_update_person(
person_id=person["id"], user_id="non-existing"
await storage_collection.async_update_item(
person["id"], {"user_id": "non-existing"}
)
async def test_update_person_when_user_removed(hass, hass_read_only_user):
async def test_update_person_when_user_removed(
hass, storage_setup, hass_read_only_user
):
"""Update person when user is removed."""
manager = PersonManager(hass, Mock(async_add_entities=mock_coro_func()), [])
await manager.async_initialize()
person = await manager.async_create_person(
name="Hello", user_id=hass_read_only_user.id
storage_collection = hass.data[DOMAIN][1]
person = await storage_collection.async_create_item(
{"name": "Hello", "user_id": hass_read_only_user.id}
)
await hass.auth.async_remove_user(hass_read_only_user)
await hass.async_block_till_done()
assert person["user_id"] is None
assert storage_collection.data[person["id"]]["user_id"] is None

View File

@ -9,6 +9,13 @@ import requests_mock as _requests_mock
from homeassistant import util
from homeassistant.auth.const import GROUP_ID_ADMIN, GROUP_ID_READ_ONLY
from homeassistant.auth.providers import homeassistant, legacy_api_password
from homeassistant.components.websocket_api.auth import (
TYPE_AUTH,
TYPE_AUTH_OK,
TYPE_AUTH_REQUIRED,
)
from homeassistant.components.websocket_api.http import URL
from homeassistant.setup import async_setup_component
from homeassistant.util import location
pytest.register_assert_rewrite("tests.common")
@ -187,3 +194,37 @@ def hass_client(hass, aiohttp_client, hass_access_token):
)
return auth_client
@pytest.fixture
def hass_ws_client(aiohttp_client, hass_access_token):
"""Websocket client fixture connected to websocket server."""
async def create_client(hass, access_token=hass_access_token):
"""Create a websocket client."""
assert await async_setup_component(hass, "websocket_api", {})
client = await aiohttp_client(hass.http.app)
with patch("homeassistant.components.http.auth.setup_auth"):
websocket = await client.ws_connect(URL)
auth_resp = await websocket.receive_json()
assert auth_resp["type"] == TYPE_AUTH_REQUIRED
if access_token is None:
await websocket.send_json(
{"type": TYPE_AUTH, "access_token": "incorrect"}
)
else:
await websocket.send_json(
{"type": TYPE_AUTH, "access_token": access_token}
)
auth_ok = await websocket.receive_json()
assert auth_ok["type"] == TYPE_AUTH_OK
# wrap in client
websocket.client = client
return websocket
return create_client

View File

@ -0,0 +1,356 @@
"""Tests for the collection helper."""
import logging
import pytest
import voluptuous as vol
from homeassistant.helpers import collection, entity, entity_component, storage
from tests.common import flush_store
LOGGER = logging.getLogger(__name__)
def track_changes(coll: collection.ObservableCollection):
"""Create helper to track changes in a collection."""
changes = []
async def listener(*args):
changes.append(args)
coll.async_add_listener(listener)
return changes
class MockEntity(entity.Entity):
"""Entity that is config based."""
def __init__(self, config):
"""Initialize entity."""
self._config = config
@property
def unique_id(self):
"""Return unique ID of entity."""
return self._config["id"]
@property
def name(self):
"""Return name of entity."""
return self._config["name"]
@property
def state(self):
"""Return state of entity."""
return self._config["state"]
async def async_update_config(self, config):
"""Update entity config."""
self._config = config
self.async_write_ha_state()
class MockStorageCollection(collection.StorageCollection):
"""Mock storage collection."""
async def _process_create_data(self, data: dict) -> dict:
"""Validate the config is valid."""
if "name" not in data:
raise ValueError("invalid")
return data
def _get_suggested_id(self, info: dict) -> str:
"""Suggest an ID based on the config."""
return info["name"]
async def _update_data(self, data: dict, update_data: dict) -> dict:
"""Return a new updated data object."""
return {**data, **update_data}
def test_id_manager():
"""Test the ID manager."""
id_manager = collection.IDManager()
assert not id_manager.has_id("some_id")
data = {}
id_manager.add_collection(data)
assert not id_manager.has_id("some_id")
data["some_id"] = 1
assert id_manager.has_id("some_id")
assert id_manager.generate_id("some_id") == "some_id_2"
assert id_manager.generate_id("bla") == "bla"
async def test_observable_collection():
"""Test observerable collection."""
coll = collection.ObservableCollection(LOGGER)
assert coll.async_items() == []
coll.data["bla"] = 1
assert coll.async_items() == [1]
changes = track_changes(coll)
await coll.notify_change("mock_type", "mock_id", {"mock": "item"})
assert len(changes) == 1
assert changes[0] == ("mock_type", "mock_id", {"mock": "item"})
async def test_yaml_collection():
"""Test a YAML collection."""
id_manager = collection.IDManager()
coll = collection.YamlCollection(LOGGER, id_manager)
changes = track_changes(coll)
await coll.async_load(
[{"id": "mock-1", "name": "Mock 1"}, {"id": "mock-2", "name": "Mock 2"}]
)
assert id_manager.has_id("mock-1")
assert id_manager.has_id("mock-2")
assert len(changes) == 2
assert changes[0] == (
collection.CHANGE_ADDED,
"mock-1",
{"id": "mock-1", "name": "Mock 1"},
)
assert changes[1] == (
collection.CHANGE_ADDED,
"mock-2",
{"id": "mock-2", "name": "Mock 2"},
)
async def test_yaml_collection_skipping_duplicate_ids():
"""Test YAML collection skipping duplicate IDs."""
id_manager = collection.IDManager()
id_manager.add_collection({"existing": True})
coll = collection.YamlCollection(LOGGER, id_manager)
changes = track_changes(coll)
await coll.async_load(
[{"id": "mock-1", "name": "Mock 1"}, {"id": "existing", "name": "Mock 2"}]
)
assert len(changes) == 1
assert changes[0] == (
collection.CHANGE_ADDED,
"mock-1",
{"id": "mock-1", "name": "Mock 1"},
)
async def test_storage_collection(hass):
"""Test storage collection."""
store = storage.Store(hass, 1, "test-data")
await store.async_save(
{
"items": [
{"id": "mock-1", "name": "Mock 1", "data": 1},
{"id": "mock-2", "name": "Mock 2", "data": 2},
]
}
)
id_manager = collection.IDManager()
coll = MockStorageCollection(store, LOGGER, id_manager)
changes = track_changes(coll)
await coll.async_load()
assert id_manager.has_id("mock-1")
assert id_manager.has_id("mock-2")
assert len(changes) == 2
assert changes[0] == (
collection.CHANGE_ADDED,
"mock-1",
{"id": "mock-1", "name": "Mock 1", "data": 1},
)
assert changes[1] == (
collection.CHANGE_ADDED,
"mock-2",
{"id": "mock-2", "name": "Mock 2", "data": 2},
)
item = await coll.async_create_item({"name": "Mock 3"})
assert item["id"] == "mock_3"
assert len(changes) == 3
assert changes[2] == (
collection.CHANGE_ADDED,
"mock_3",
{"id": "mock_3", "name": "Mock 3"},
)
updated_item = await coll.async_update_item("mock-2", {"name": "Mock 2 updated"})
assert id_manager.has_id("mock-2")
assert updated_item == {"id": "mock-2", "name": "Mock 2 updated", "data": 2}
assert len(changes) == 4
assert changes[3] == (collection.CHANGE_UPDATED, "mock-2", updated_item)
with pytest.raises(ValueError):
await coll.async_update_item("mock-2", {"id": "mock-2-updated"})
assert id_manager.has_id("mock-2")
assert not id_manager.has_id("mock-2-updated")
assert len(changes) == 4
await flush_store(store)
assert await storage.Store(hass, 1, "test-data").async_load() == {
"items": [
{"id": "mock-1", "name": "Mock 1", "data": 1},
{"id": "mock-2", "name": "Mock 2 updated", "data": 2},
{"id": "mock_3", "name": "Mock 3"},
]
}
async def test_attach_entity_component_collection(hass):
"""Test attaching collection to entity component."""
ent_comp = entity_component.EntityComponent(LOGGER, "test", hass)
coll = collection.ObservableCollection(LOGGER)
collection.attach_entity_component_collection(ent_comp, coll, MockEntity)
await coll.notify_change(
collection.CHANGE_ADDED,
"mock_id",
{"id": "mock_id", "state": "initial", "name": "Mock 1"},
)
assert hass.states.get("test.mock_1").name == "Mock 1"
assert hass.states.get("test.mock_1").state == "initial"
await coll.notify_change(
collection.CHANGE_UPDATED,
"mock_id",
{"id": "mock_id", "state": "second", "name": "Mock 1 updated"},
)
assert hass.states.get("test.mock_1").name == "Mock 1 updated"
assert hass.states.get("test.mock_1").state == "second"
await coll.notify_change(collection.CHANGE_REMOVED, "mock_id", None)
assert hass.states.get("test.mock_1") is None
async def test_storage_collection_websocket(hass, hass_ws_client):
"""Test exposing a storage collection via websockets."""
store = storage.Store(hass, 1, "test-data")
coll = MockStorageCollection(store, LOGGER)
changes = track_changes(coll)
collection.StorageCollectionWebsocket(
coll,
"test_item/collection",
"test_item",
{vol.Required("name"): str, vol.Required("immutable_string"): str},
{vol.Optional("name"): str},
).async_setup(hass)
client = await hass_ws_client(hass)
# Create invalid
await client.send_json(
{
"id": 1,
"type": "test_item/collection/create",
"name": 1,
# Forgot to add immutable_string
}
)
response = await client.receive_json()
assert not response["success"]
assert response["error"]["code"] == "invalid_format"
assert len(changes) == 0
# Create
await client.send_json(
{
"id": 2,
"type": "test_item/collection/create",
"name": "Initial Name",
"immutable_string": "no-changes",
}
)
response = await client.receive_json()
assert response["success"]
assert response["result"] == {
"id": "initial_name",
"name": "Initial Name",
"immutable_string": "no-changes",
}
assert len(changes) == 1
assert changes[0] == (collection.CHANGE_ADDED, "initial_name", response["result"])
# List
await client.send_json({"id": 3, "type": "test_item/collection/list"})
response = await client.receive_json()
assert response["success"]
assert response["result"] == [
{
"id": "initial_name",
"name": "Initial Name",
"immutable_string": "no-changes",
}
]
assert len(changes) == 1
# Update invalid data
await client.send_json(
{
"id": 4,
"type": "test_item/collection/update",
"test_item_id": "initial_name",
"immutable_string": "no-changes",
}
)
response = await client.receive_json()
assert not response["success"]
assert response["error"]["code"] == "invalid_format"
assert len(changes) == 1
# Update invalid item
await client.send_json(
{
"id": 5,
"type": "test_item/collection/update",
"test_item_id": "non-existing",
"name": "Updated name",
}
)
response = await client.receive_json()
assert not response["success"]
assert response["error"]["code"] == "not_found"
assert len(changes) == 1
# Update
await client.send_json(
{
"id": 6,
"type": "test_item/collection/update",
"test_item_id": "initial_name",
"name": "Updated name",
}
)
response = await client.receive_json()
assert response["success"]
assert response["result"] == {
"id": "initial_name",
"name": "Updated name",
"immutable_string": "no-changes",
}
assert len(changes) == 2
assert changes[1] == (collection.CHANGE_UPDATED, "initial_name", response["result"])
# Delete invalid ID
await client.send_json(
{"id": 7, "type": "test_item/collection/update", "test_item_id": "non-existing"}
)
response = await client.receive_json()
assert not response["success"]
assert response["error"]["code"] == "not_found"
assert len(changes) == 2
# Delete
await client.send_json(
{"id": 8, "type": "test_item/collection/delete", "test_item_id": "initial_name"}
)
response = await client.receive_json()
assert response["success"]
assert len(changes) == 3
assert changes[2] == (collection.CHANGE_REMOVED, "initial_name", None)