mirror of
https://github.com/home-assistant/core.git
synced 2025-04-23 00:37:53 +00:00
Extract Collection helper from Person integration (#30313)
* Add CRUD foundation * Use collection helper in person integration * Lint/pytest * Add tests * Lint * Create notification
This commit is contained in:
parent
3033dbd86c
commit
b9aba30a6e
@ -173,13 +173,13 @@ async def _async_get_device_automation_capabilities(hass, automation_type, autom
|
||||
return capabilities
|
||||
|
||||
|
||||
@websocket_api.async_response
|
||||
@websocket_api.websocket_command(
|
||||
{
|
||||
vol.Required("type"): "device_automation/action/list",
|
||||
vol.Required("device_id"): str,
|
||||
}
|
||||
)
|
||||
@websocket_api.async_response
|
||||
async def websocket_device_automation_list_actions(hass, connection, msg):
|
||||
"""Handle request for device actions."""
|
||||
device_id = msg["device_id"]
|
||||
@ -187,13 +187,13 @@ async def websocket_device_automation_list_actions(hass, connection, msg):
|
||||
connection.send_result(msg["id"], actions)
|
||||
|
||||
|
||||
@websocket_api.async_response
|
||||
@websocket_api.websocket_command(
|
||||
{
|
||||
vol.Required("type"): "device_automation/condition/list",
|
||||
vol.Required("device_id"): str,
|
||||
}
|
||||
)
|
||||
@websocket_api.async_response
|
||||
async def websocket_device_automation_list_conditions(hass, connection, msg):
|
||||
"""Handle request for device conditions."""
|
||||
device_id = msg["device_id"]
|
||||
@ -201,13 +201,13 @@ async def websocket_device_automation_list_conditions(hass, connection, msg):
|
||||
connection.send_result(msg["id"], conditions)
|
||||
|
||||
|
||||
@websocket_api.async_response
|
||||
@websocket_api.websocket_command(
|
||||
{
|
||||
vol.Required("type"): "device_automation/trigger/list",
|
||||
vol.Required("device_id"): str,
|
||||
}
|
||||
)
|
||||
@websocket_api.async_response
|
||||
async def websocket_device_automation_list_triggers(hass, connection, msg):
|
||||
"""Handle request for device triggers."""
|
||||
device_id = msg["device_id"]
|
||||
@ -215,13 +215,13 @@ async def websocket_device_automation_list_triggers(hass, connection, msg):
|
||||
connection.send_result(msg["id"], triggers)
|
||||
|
||||
|
||||
@websocket_api.async_response
|
||||
@websocket_api.websocket_command(
|
||||
{
|
||||
vol.Required("type"): "device_automation/action/capabilities",
|
||||
vol.Required("action"): dict,
|
||||
}
|
||||
)
|
||||
@websocket_api.async_response
|
||||
async def websocket_device_automation_get_action_capabilities(hass, connection, msg):
|
||||
"""Handle request for device action capabilities."""
|
||||
action = msg["action"]
|
||||
@ -231,13 +231,13 @@ async def websocket_device_automation_get_action_capabilities(hass, connection,
|
||||
connection.send_result(msg["id"], capabilities)
|
||||
|
||||
|
||||
@websocket_api.async_response
|
||||
@websocket_api.websocket_command(
|
||||
{
|
||||
vol.Required("type"): "device_automation/condition/capabilities",
|
||||
vol.Required("condition"): dict,
|
||||
}
|
||||
)
|
||||
@websocket_api.async_response
|
||||
async def websocket_device_automation_get_condition_capabilities(hass, connection, msg):
|
||||
"""Handle request for device condition capabilities."""
|
||||
condition = msg["condition"]
|
||||
@ -247,13 +247,13 @@ async def websocket_device_automation_get_condition_capabilities(hass, connectio
|
||||
connection.send_result(msg["id"], capabilities)
|
||||
|
||||
|
||||
@websocket_api.async_response
|
||||
@websocket_api.websocket_command(
|
||||
{
|
||||
vol.Required("type"): "device_automation/trigger/capabilities",
|
||||
vol.Required("trigger"): dict,
|
||||
}
|
||||
)
|
||||
@websocket_api.async_response
|
||||
async def websocket_device_automation_get_trigger_capabilities(hass, connection, msg):
|
||||
"""Handle request for device trigger capabilities."""
|
||||
trigger = msg["trigger"]
|
||||
|
@ -1,9 +1,6 @@
|
||||
"""Support for tracking people."""
|
||||
from collections import OrderedDict
|
||||
from itertools import chain
|
||||
import logging
|
||||
from typing import Optional
|
||||
import uuid
|
||||
from typing import List, Optional
|
||||
|
||||
import voluptuous as vol
|
||||
|
||||
@ -28,6 +25,7 @@ from homeassistant.const import (
|
||||
STATE_UNKNOWN,
|
||||
)
|
||||
from homeassistant.core import Event, State, callback
|
||||
from homeassistant.helpers import collection, entity_registry
|
||||
import homeassistant.helpers.config_validation as cv
|
||||
from homeassistant.helpers.entity_component import EntityComponent
|
||||
from homeassistant.helpers.event import async_track_state_change
|
||||
@ -48,8 +46,7 @@ CONF_USER_ID = "user_id"
|
||||
DOMAIN = "person"
|
||||
|
||||
STORAGE_KEY = DOMAIN
|
||||
STORAGE_VERSION = 1
|
||||
SAVE_DELAY = 10
|
||||
STORAGE_VERSION = 2
|
||||
# Device tracker states to ignore
|
||||
IGNORE_STATES = (STATE_UNKNOWN, STATE_UNAVAILABLE)
|
||||
|
||||
@ -75,217 +72,184 @@ _UNDEF = object()
|
||||
@bind_hass
|
||||
async def async_create_person(hass, name, *, user_id=None, device_trackers=None):
|
||||
"""Create a new person."""
|
||||
await hass.data[DOMAIN].async_create_person(
|
||||
name=name, user_id=user_id, device_trackers=device_trackers
|
||||
await hass.data[DOMAIN][1].async_create_item(
|
||||
{"name": name, "user_id": user_id, "device_trackers": device_trackers}
|
||||
)
|
||||
|
||||
|
||||
class PersonManager:
|
||||
"""Manage person data."""
|
||||
CREATE_FIELDS = {
|
||||
vol.Required("name"): vol.All(str, vol.Length(min=1)),
|
||||
vol.Optional("user_id"): vol.Any(str, None),
|
||||
vol.Optional("device_trackers", default=list): vol.All(
|
||||
cv.ensure_list, cv.entities_domain(DEVICE_TRACKER_DOMAIN)
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
UPDATE_FIELDS = {
|
||||
vol.Optional("name"): vol.All(str, vol.Length(min=1)),
|
||||
vol.Optional("user_id"): vol.Any(str, None),
|
||||
vol.Optional("device_trackers", default=list): vol.All(
|
||||
cv.ensure_list, cv.entities_domain(DEVICE_TRACKER_DOMAIN)
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
class PersonStore(Store):
|
||||
"""Person storage."""
|
||||
|
||||
async def _async_migrate_func(self, old_version, old_data):
|
||||
"""Migrate to the new version.
|
||||
|
||||
Migrate storage to use format of collection helper.
|
||||
"""
|
||||
return {"items": old_data["persons"]}
|
||||
|
||||
|
||||
class PersonStorageCollection(collection.StorageCollection):
|
||||
"""Person collection stored in storage."""
|
||||
|
||||
CREATE_SCHEMA = vol.Schema(CREATE_FIELDS)
|
||||
UPDATE_SCHEMA = vol.Schema(UPDATE_FIELDS)
|
||||
|
||||
def __init__(
|
||||
self, hass: HomeAssistantType, component: EntityComponent, config_persons
|
||||
self,
|
||||
store: Store,
|
||||
logger: logging.Logger,
|
||||
id_manager: collection.IDManager,
|
||||
yaml_collection: collection.YamlCollection,
|
||||
):
|
||||
"""Initialize person storage."""
|
||||
self.hass = hass
|
||||
self.component = component
|
||||
self.store = Store(hass, STORAGE_VERSION, STORAGE_KEY)
|
||||
self.storage_data = None
|
||||
"""Initialize a person storage collection."""
|
||||
super().__init__(store, logger, id_manager)
|
||||
self.async_add_listener(self._collection_changed)
|
||||
self.yaml_collection = yaml_collection
|
||||
|
||||
config_data = self.config_data = OrderedDict()
|
||||
for conf in config_persons:
|
||||
person_id = conf[CONF_ID]
|
||||
async def _process_create_data(self, data: dict) -> dict:
|
||||
"""Validate the config is valid."""
|
||||
data = self.CREATE_SCHEMA(data)
|
||||
|
||||
if person_id in config_data:
|
||||
_LOGGER.error("Found config user with duplicate ID: %s", person_id)
|
||||
continue
|
||||
|
||||
config_data[person_id] = conf
|
||||
|
||||
@property
|
||||
def storage_persons(self):
|
||||
"""Iterate over persons stored in storage."""
|
||||
return list(self.storage_data.values())
|
||||
|
||||
@property
|
||||
def config_persons(self):
|
||||
"""Iterate over persons stored in config."""
|
||||
return list(self.config_data.values())
|
||||
|
||||
async def async_initialize(self):
|
||||
"""Get the person data."""
|
||||
raw_storage = await self.store.async_load()
|
||||
|
||||
if raw_storage is None:
|
||||
raw_storage = {"persons": []}
|
||||
|
||||
storage_data = self.storage_data = OrderedDict()
|
||||
|
||||
for person in raw_storage["persons"]:
|
||||
storage_data[person[CONF_ID]] = person
|
||||
|
||||
entities = []
|
||||
seen_users = set()
|
||||
|
||||
for person_conf in self.config_data.values():
|
||||
person_id = person_conf[CONF_ID]
|
||||
user_id = person_conf.get(CONF_USER_ID)
|
||||
|
||||
if user_id is not None:
|
||||
if await self.hass.auth.async_get_user(user_id) is None:
|
||||
_LOGGER.error("Invalid user_id detected for person %s", person_id)
|
||||
continue
|
||||
|
||||
if user_id in seen_users:
|
||||
_LOGGER.error(
|
||||
"Duplicate user_id %s detected for person %s",
|
||||
user_id,
|
||||
person_id,
|
||||
)
|
||||
continue
|
||||
|
||||
seen_users.add(user_id)
|
||||
|
||||
entities.append(Person(person_conf, False))
|
||||
|
||||
# To make sure IDs don't overlap between config/storage
|
||||
seen_persons = set(self.config_data)
|
||||
|
||||
for person_conf in storage_data.values():
|
||||
person_id = person_conf[CONF_ID]
|
||||
user_id = person_conf[CONF_USER_ID]
|
||||
|
||||
if person_id in seen_persons:
|
||||
_LOGGER.error(
|
||||
"Skipping adding person from storage with same ID as"
|
||||
" configuration.yaml entry: %s",
|
||||
person_id,
|
||||
)
|
||||
continue
|
||||
|
||||
if user_id is not None and user_id in seen_users:
|
||||
_LOGGER.error(
|
||||
"Duplicate user_id %s detected for person %s", user_id, person_id
|
||||
)
|
||||
continue
|
||||
|
||||
# To make sure all users have just 1 person linked.
|
||||
seen_users.add(user_id)
|
||||
|
||||
entities.append(Person(person_conf, True))
|
||||
|
||||
if entities:
|
||||
await self.component.async_add_entities(entities)
|
||||
|
||||
self.hass.bus.async_listen(EVENT_USER_REMOVED, self._user_removed)
|
||||
|
||||
async def async_create_person(self, *, name, device_trackers=None, user_id=None):
|
||||
"""Create a new person."""
|
||||
if not name:
|
||||
raise ValueError("Name is required")
|
||||
user_id = data.get("user_id")
|
||||
|
||||
if user_id is not None:
|
||||
await self._validate_user_id(user_id)
|
||||
|
||||
person = {
|
||||
CONF_ID: uuid.uuid4().hex,
|
||||
CONF_NAME: name,
|
||||
CONF_USER_ID: user_id,
|
||||
CONF_DEVICE_TRACKERS: device_trackers or [],
|
||||
}
|
||||
self.storage_data[person[CONF_ID]] = person
|
||||
self._async_schedule_save()
|
||||
await self.component.async_add_entities([Person(person, True)])
|
||||
return person
|
||||
return self.CREATE_SCHEMA(data)
|
||||
|
||||
async def async_update_person(
|
||||
self, person_id, *, name=_UNDEF, device_trackers=_UNDEF, user_id=_UNDEF
|
||||
):
|
||||
"""Update person."""
|
||||
current = self.storage_data.get(person_id)
|
||||
@callback
|
||||
def _get_suggested_id(self, info: dict) -> str:
|
||||
"""Suggest an ID based on the config."""
|
||||
return info["name"]
|
||||
|
||||
if current is None:
|
||||
raise ValueError("Invalid person specified.")
|
||||
async def _update_data(self, data: dict, update_data: dict) -> dict:
|
||||
"""Return a new updated data object."""
|
||||
update_data = self.UPDATE_SCHEMA(update_data)
|
||||
|
||||
changes = {
|
||||
key: value
|
||||
for key, value in (
|
||||
(CONF_NAME, name),
|
||||
(CONF_DEVICE_TRACKERS, device_trackers),
|
||||
(CONF_USER_ID, user_id),
|
||||
)
|
||||
if value is not _UNDEF and current[key] != value
|
||||
}
|
||||
user_id = update_data.get("user_id")
|
||||
|
||||
if CONF_USER_ID in changes and user_id is not None:
|
||||
if user_id is not None:
|
||||
await self._validate_user_id(user_id)
|
||||
|
||||
self.storage_data[person_id].update(changes)
|
||||
self._async_schedule_save()
|
||||
|
||||
for entity in self.component.entities:
|
||||
if entity.unique_id == person_id:
|
||||
entity.person_updated()
|
||||
break
|
||||
|
||||
return self.storage_data[person_id]
|
||||
|
||||
async def async_delete_person(self, person_id):
|
||||
"""Delete person."""
|
||||
if person_id not in self.storage_data:
|
||||
raise ValueError("Invalid person specified.")
|
||||
|
||||
self.storage_data.pop(person_id)
|
||||
self._async_schedule_save()
|
||||
ent_reg = await self.hass.helpers.entity_registry.async_get_registry()
|
||||
|
||||
for entity in self.component.entities:
|
||||
if entity.unique_id == person_id:
|
||||
await entity.async_remove()
|
||||
ent_reg.async_remove(entity.entity_id)
|
||||
break
|
||||
|
||||
@callback
|
||||
def _async_schedule_save(self) -> None:
|
||||
"""Schedule saving the area registry."""
|
||||
self.store.async_delay_save(self._data_to_save, SAVE_DELAY)
|
||||
|
||||
@callback
|
||||
def _data_to_save(self) -> dict:
|
||||
"""Return data of area registry to store in a file."""
|
||||
return {"persons": list(self.storage_data.values())}
|
||||
return {**data, **update_data}
|
||||
|
||||
async def _validate_user_id(self, user_id):
|
||||
"""Validate the used user_id."""
|
||||
if await self.hass.auth.async_get_user(user_id) is None:
|
||||
raise ValueError("User does not exist")
|
||||
|
||||
if any(
|
||||
person
|
||||
for person in chain(self.storage_data.values(), self.config_data.values())
|
||||
if person.get(CONF_USER_ID) == user_id
|
||||
):
|
||||
raise ValueError("User already taken")
|
||||
for persons in (self.data.values(), self.yaml_collection.async_items()):
|
||||
if any(person for person in persons if person.get(CONF_USER_ID) == user_id):
|
||||
raise ValueError("User already taken")
|
||||
|
||||
async def _user_removed(self, event: Event):
|
||||
"""Handle event that a person is removed."""
|
||||
user_id = event.data["user_id"]
|
||||
for person in self.storage_data.values():
|
||||
if person[CONF_USER_ID] == user_id:
|
||||
await self.async_update_person(person_id=person[CONF_ID], user_id=None)
|
||||
async def _collection_changed(
|
||||
self, change_type: str, item_id: str, config: Optional[dict]
|
||||
) -> None:
|
||||
"""Handle a collection change."""
|
||||
if change_type != collection.CHANGE_REMOVED:
|
||||
return
|
||||
|
||||
ent_reg = await entity_registry.async_get_registry(self.hass)
|
||||
ent_reg.async_remove(ent_reg.async_get_entity_id(DOMAIN, DOMAIN, item_id))
|
||||
|
||||
|
||||
async def filter_yaml_data(hass: HomeAssistantType, persons: List[dict]) -> List[dict]:
|
||||
"""Validate YAML data that we can't validate via schema."""
|
||||
filtered = []
|
||||
person_invalid_user = []
|
||||
|
||||
for person_conf in persons:
|
||||
user_id = person_conf.get(CONF_USER_ID)
|
||||
|
||||
if user_id is not None:
|
||||
if await hass.auth.async_get_user(user_id) is None:
|
||||
_LOGGER.error(
|
||||
"Invalid user_id detected for person %s",
|
||||
person_conf[collection.CONF_ID],
|
||||
)
|
||||
person_invalid_user.append(
|
||||
f"- Person {person_conf[CONF_NAME]} (id: {person_conf[collection.CONF_ID]}) points at invalid user {user_id}"
|
||||
)
|
||||
continue
|
||||
|
||||
filtered.append(person_conf)
|
||||
|
||||
if person_invalid_user:
|
||||
hass.components.persistent_notification.async_create(
|
||||
f"""
|
||||
The following persons point at invalid users:
|
||||
|
||||
{"- ".join(person_invalid_user)}
|
||||
""",
|
||||
"Invalid Person Configuration",
|
||||
DOMAIN,
|
||||
)
|
||||
|
||||
return filtered
|
||||
|
||||
|
||||
async def async_setup(hass: HomeAssistantType, config: ConfigType):
|
||||
"""Set up the person component."""
|
||||
component = EntityComponent(_LOGGER, DOMAIN, hass)
|
||||
conf_persons = config.get(DOMAIN, [])
|
||||
manager = hass.data[DOMAIN] = PersonManager(hass, component, conf_persons)
|
||||
await manager.async_initialize()
|
||||
entity_component = EntityComponent(_LOGGER, DOMAIN, hass)
|
||||
id_manager = collection.IDManager()
|
||||
yaml_collection = collection.YamlCollection(
|
||||
logging.getLogger(f"{__name__}.yaml_collection"), id_manager
|
||||
)
|
||||
storage_collection = PersonStorageCollection(
|
||||
PersonStore(hass, STORAGE_VERSION, STORAGE_KEY),
|
||||
logging.getLogger(f"{__name__}.storage_collection"),
|
||||
id_manager,
|
||||
yaml_collection,
|
||||
)
|
||||
|
||||
collection.attach_entity_component_collection(
|
||||
entity_component, yaml_collection, lambda conf: Person(conf, False)
|
||||
)
|
||||
collection.attach_entity_component_collection(
|
||||
entity_component, storage_collection, lambda conf: Person(conf, True)
|
||||
)
|
||||
|
||||
await yaml_collection.async_load(
|
||||
await filter_yaml_data(hass, config.get(DOMAIN, []))
|
||||
)
|
||||
await storage_collection.async_load()
|
||||
|
||||
hass.data[DOMAIN] = (yaml_collection, storage_collection)
|
||||
|
||||
collection.StorageCollectionWebsocket(
|
||||
storage_collection, DOMAIN, DOMAIN, CREATE_FIELDS, UPDATE_FIELDS
|
||||
).async_setup(hass, create_list=False)
|
||||
|
||||
websocket_api.async_register_command(hass, ws_list_person)
|
||||
websocket_api.async_register_command(hass, ws_create_person)
|
||||
websocket_api.async_register_command(hass, ws_update_person)
|
||||
websocket_api.async_register_command(hass, ws_delete_person)
|
||||
|
||||
async def _handle_user_removed(event: Event) -> None:
|
||||
"""Handle a user being removed."""
|
||||
user_id = event.data["user_id"]
|
||||
for person in storage_collection.async_items():
|
||||
if person[CONF_USER_ID] == user_id:
|
||||
await storage_collection.async_update_item(
|
||||
person[CONF_ID], {"user_id": None}
|
||||
)
|
||||
|
||||
hass.bus.async_listen(EVENT_USER_REMOVED, _handle_user_removed)
|
||||
|
||||
return True
|
||||
|
||||
@ -353,21 +317,21 @@ class Person(RestoreEntity):
|
||||
|
||||
if self.hass.is_running:
|
||||
# Update person now if hass is already running.
|
||||
self.person_updated()
|
||||
await self.async_update_config(self._config)
|
||||
else:
|
||||
# Wait for hass start to not have race between person
|
||||
# and device trackers finishing setup.
|
||||
@callback
|
||||
def person_start_hass(now):
|
||||
self.person_updated()
|
||||
async def person_start_hass(now):
|
||||
await self.async_update_config(self._config)
|
||||
|
||||
self.hass.bus.async_listen_once(
|
||||
EVENT_HOMEASSISTANT_START, person_start_hass
|
||||
)
|
||||
|
||||
@callback
|
||||
def person_updated(self):
|
||||
async def async_update_config(self, config):
|
||||
"""Handle when the config is updated."""
|
||||
self._config = config
|
||||
|
||||
if self._unsub_track_device is not None:
|
||||
self._unsub_track_device()
|
||||
self._unsub_track_device = None
|
||||
@ -441,89 +405,12 @@ def ws_list_person(
|
||||
hass: HomeAssistantType, connection: websocket_api.ActiveConnection, msg
|
||||
):
|
||||
"""List persons."""
|
||||
manager: PersonManager = hass.data[DOMAIN]
|
||||
yaml, storage = hass.data[DOMAIN]
|
||||
connection.send_result(
|
||||
msg["id"],
|
||||
{"storage": manager.storage_persons, "config": manager.config_persons},
|
||||
msg["id"], {"storage": storage.async_items(), "config": yaml.async_items()},
|
||||
)
|
||||
|
||||
|
||||
@websocket_api.websocket_command(
|
||||
{
|
||||
vol.Required("type"): "person/create",
|
||||
vol.Required("name"): vol.All(str, vol.Length(min=1)),
|
||||
vol.Optional("user_id"): vol.Any(str, None),
|
||||
vol.Optional("device_trackers", default=[]): vol.All(
|
||||
cv.ensure_list, cv.entities_domain(DEVICE_TRACKER_DOMAIN)
|
||||
),
|
||||
}
|
||||
)
|
||||
@websocket_api.require_admin
|
||||
@websocket_api.async_response
|
||||
async def ws_create_person(
|
||||
hass: HomeAssistantType, connection: websocket_api.ActiveConnection, msg
|
||||
):
|
||||
"""Create a person."""
|
||||
manager: PersonManager = hass.data[DOMAIN]
|
||||
try:
|
||||
person = await manager.async_create_person(
|
||||
name=msg["name"],
|
||||
user_id=msg.get("user_id"),
|
||||
device_trackers=msg["device_trackers"],
|
||||
)
|
||||
connection.send_result(msg["id"], person)
|
||||
except ValueError as err:
|
||||
connection.send_error(
|
||||
msg["id"], websocket_api.const.ERR_INVALID_FORMAT, str(err)
|
||||
)
|
||||
|
||||
|
||||
@websocket_api.websocket_command(
|
||||
{
|
||||
vol.Required("type"): "person/update",
|
||||
vol.Required("person_id"): str,
|
||||
vol.Required("name"): vol.All(str, vol.Length(min=1)),
|
||||
vol.Optional("user_id"): vol.Any(str, None),
|
||||
vol.Optional(CONF_DEVICE_TRACKERS, default=[]): vol.All(
|
||||
cv.ensure_list, cv.entities_domain(DEVICE_TRACKER_DOMAIN)
|
||||
),
|
||||
}
|
||||
)
|
||||
@websocket_api.require_admin
|
||||
@websocket_api.async_response
|
||||
async def ws_update_person(
|
||||
hass: HomeAssistantType, connection: websocket_api.ActiveConnection, msg
|
||||
):
|
||||
"""Update a person."""
|
||||
manager: PersonManager = hass.data[DOMAIN]
|
||||
changes = {}
|
||||
for key in ("name", "user_id", "device_trackers"):
|
||||
if key in msg:
|
||||
changes[key] = msg[key]
|
||||
|
||||
try:
|
||||
person = await manager.async_update_person(msg["person_id"], **changes)
|
||||
connection.send_result(msg["id"], person)
|
||||
except ValueError as err:
|
||||
connection.send_error(
|
||||
msg["id"], websocket_api.const.ERR_INVALID_FORMAT, str(err)
|
||||
)
|
||||
|
||||
|
||||
@websocket_api.websocket_command(
|
||||
{vol.Required("type"): "person/delete", vol.Required("person_id"): str}
|
||||
)
|
||||
@websocket_api.require_admin
|
||||
@websocket_api.async_response
|
||||
async def ws_delete_person(
|
||||
hass: HomeAssistantType, connection: websocket_api.ActiveConnection, msg
|
||||
):
|
||||
"""Delete a person."""
|
||||
manager: PersonManager = hass.data[DOMAIN]
|
||||
await manager.async_delete_person(msg["person_id"])
|
||||
connection.send_result(msg["id"])
|
||||
|
||||
|
||||
def _get_latest(prev: Optional[State], curr: State):
|
||||
"""Get latest state."""
|
||||
if prev is None or curr.last_updated > prev.last_updated:
|
||||
|
@ -1,5 +1,9 @@
|
||||
"""WebSocket based API for Home Assistant."""
|
||||
from homeassistant.core import callback
|
||||
from typing import Optional, Union, cast
|
||||
|
||||
import voluptuous as vol
|
||||
|
||||
from homeassistant.core import HomeAssistant, callback
|
||||
from homeassistant.loader import bind_hass
|
||||
|
||||
from . import commands, connection, const, decorators, http, messages
|
||||
@ -26,13 +30,18 @@ websocket_command = decorators.websocket_command
|
||||
|
||||
@bind_hass
|
||||
@callback
|
||||
def async_register_command(hass, command_or_handler, handler=None, schema=None):
|
||||
def async_register_command(
|
||||
hass: HomeAssistant,
|
||||
command_or_handler: Union[str, const.WebSocketCommandHandler],
|
||||
handler: Optional[const.WebSocketCommandHandler] = None,
|
||||
schema: Optional[vol.Schema] = None,
|
||||
) -> None:
|
||||
"""Register a websocket command."""
|
||||
# pylint: disable=protected-access
|
||||
if handler is None:
|
||||
handler = command_or_handler
|
||||
command = handler._ws_command
|
||||
schema = handler._ws_schema
|
||||
handler = cast(const.WebSocketCommandHandler, command_or_handler)
|
||||
command = handler._ws_command # type: ignore
|
||||
schema = handler._ws_schema # type: ignore
|
||||
else:
|
||||
command = command_or_handler
|
||||
handlers = hass.data.get(DOMAIN)
|
||||
|
@ -107,7 +107,6 @@ def handle_unsubscribe_events(hass, connection, msg):
|
||||
)
|
||||
|
||||
|
||||
@decorators.async_response
|
||||
@decorators.websocket_command(
|
||||
{
|
||||
vol.Required("type"): "call_service",
|
||||
@ -116,6 +115,7 @@ def handle_unsubscribe_events(hass, connection, msg):
|
||||
vol.Optional("service_data"): dict,
|
||||
}
|
||||
)
|
||||
@decorators.async_response
|
||||
async def handle_call_service(hass, connection, msg):
|
||||
"""Handle call service command.
|
||||
|
||||
@ -181,8 +181,8 @@ def handle_get_states(hass, connection, msg):
|
||||
connection.send_message(messages.result_message(msg["id"], states))
|
||||
|
||||
|
||||
@decorators.async_response
|
||||
@decorators.websocket_command({vol.Required("type"): "get_services"})
|
||||
@decorators.async_response
|
||||
async def handle_get_services(hass, connection, msg):
|
||||
"""Handle get services command.
|
||||
|
||||
|
@ -1,6 +1,6 @@
|
||||
"""Connection session."""
|
||||
import asyncio
|
||||
from typing import Any, Callable, Dict, Hashable
|
||||
from typing import Any, Callable, Dict, Hashable, Optional
|
||||
|
||||
import voluptuous as vol
|
||||
|
||||
@ -37,7 +37,7 @@ class ActiveConnection:
|
||||
return Context(user_id=user.id)
|
||||
|
||||
@callback
|
||||
def send_result(self, msg_id, result=None):
|
||||
def send_result(self, msg_id: int, result: Optional[Any] = None) -> None:
|
||||
"""Send a result message."""
|
||||
self.send_message(messages.result_message(msg_id, result))
|
||||
|
||||
@ -49,7 +49,7 @@ class ActiveConnection:
|
||||
self.send_message(content)
|
||||
|
||||
@callback
|
||||
def send_error(self, msg_id, code, message):
|
||||
def send_error(self, msg_id: int, code: str, message: str) -> None:
|
||||
"""Send a error message."""
|
||||
self.send_message(messages.error_message(msg_id, code, message))
|
||||
|
||||
|
@ -3,9 +3,20 @@ import asyncio
|
||||
from concurrent import futures
|
||||
from functools import partial
|
||||
import json
|
||||
from typing import TYPE_CHECKING, Callable
|
||||
|
||||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.helpers.json import JSONEncoder
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .connection import ActiveConnection # noqa
|
||||
|
||||
|
||||
WebSocketCommandHandler = Callable[
|
||||
[HomeAssistant, "ActiveConnection", dict], None
|
||||
] # pylint: disable=invalid-name
|
||||
|
||||
|
||||
DOMAIN = "websocket_api"
|
||||
URL = "/api/websocket"
|
||||
MAX_PENDING_MSG = 512
|
||||
|
@ -1,11 +1,13 @@
|
||||
"""Decorators for the Websocket API."""
|
||||
from functools import wraps
|
||||
import logging
|
||||
from typing import Awaitable, Callable
|
||||
|
||||
from homeassistant.core import callback
|
||||
from homeassistant.core import HomeAssistant, callback
|
||||
from homeassistant.exceptions import Unauthorized
|
||||
|
||||
from . import messages
|
||||
from . import const, messages
|
||||
from .connection import ActiveConnection
|
||||
|
||||
# mypy: allow-untyped-calls, allow-untyped-defs
|
||||
|
||||
@ -20,7 +22,9 @@ async def _handle_async_response(func, hass, connection, msg):
|
||||
connection.async_handle_exception(msg, err)
|
||||
|
||||
|
||||
def async_response(func):
|
||||
def async_response(
|
||||
func: Callable[[HomeAssistant, ActiveConnection, dict], Awaitable[None]]
|
||||
) -> const.WebSocketCommandHandler:
|
||||
"""Decorate an async function to handle WebSocket API messages."""
|
||||
|
||||
@callback
|
||||
@ -32,7 +36,7 @@ def async_response(func):
|
||||
return schedule_handler
|
||||
|
||||
|
||||
def require_admin(func):
|
||||
def require_admin(func: const.WebSocketCommandHandler) -> const.WebSocketCommandHandler:
|
||||
"""Websocket decorator to require user to be an admin."""
|
||||
|
||||
@wraps(func)
|
||||
@ -104,7 +108,9 @@ def ws_require_user(
|
||||
return validator
|
||||
|
||||
|
||||
def websocket_command(schema):
|
||||
def websocket_command(
|
||||
schema: dict,
|
||||
) -> Callable[[const.WebSocketCommandHandler], const.WebSocketCommandHandler]:
|
||||
"""Tag a function as a websocket command."""
|
||||
command = schema["type"]
|
||||
|
||||
|
401
homeassistant/helpers/collection.py
Normal file
401
homeassistant/helpers/collection.py
Normal file
@ -0,0 +1,401 @@
|
||||
"""Helper to deal with YAML + storage."""
|
||||
from abc import ABC, abstractmethod
|
||||
import logging
|
||||
from typing import Any, Awaitable, Callable, Dict, List, Optional, cast
|
||||
|
||||
import voluptuous as vol
|
||||
from voluptuous.humanize import humanize_error
|
||||
|
||||
from homeassistant.components import websocket_api
|
||||
from homeassistant.const import CONF_ID
|
||||
from homeassistant.core import HomeAssistant, callback
|
||||
from homeassistant.exceptions import HomeAssistantError
|
||||
from homeassistant.helpers.entity import Entity
|
||||
from homeassistant.helpers.entity_component import EntityComponent
|
||||
from homeassistant.helpers.storage import Store
|
||||
from homeassistant.util import slugify
|
||||
|
||||
STORAGE_VERSION = 1
|
||||
SAVE_DELAY = 10
|
||||
|
||||
CHANGE_ADDED = "added"
|
||||
CHANGE_UPDATED = "updated"
|
||||
CHANGE_REMOVED = "removed"
|
||||
|
||||
|
||||
ChangeListener = Callable[
|
||||
[
|
||||
# Change type
|
||||
str,
|
||||
# Item ID
|
||||
str,
|
||||
# New config (None if removed)
|
||||
Optional[dict],
|
||||
],
|
||||
Awaitable[None],
|
||||
] # pylint: disable=invalid-name
|
||||
|
||||
|
||||
class CollectionError(HomeAssistantError):
|
||||
"""Base class for collection related errors."""
|
||||
|
||||
|
||||
class ItemNotFound(CollectionError):
|
||||
"""Raised when an item is not found."""
|
||||
|
||||
def __init__(self, item_id: str):
|
||||
"""Initialize item not found error."""
|
||||
super().__init__(f"Item {item_id} not found.")
|
||||
self.item_id = item_id
|
||||
|
||||
|
||||
class IDManager:
|
||||
"""Keep track of IDs across different collections."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""Initiate the ID manager."""
|
||||
self.collections: List[Dict[str, Any]] = []
|
||||
|
||||
def add_collection(self, collection: Dict[str, Any]) -> None:
|
||||
"""Add a collection to check for ID usage."""
|
||||
self.collections.append(collection)
|
||||
|
||||
def has_id(self, item_id: str) -> bool:
|
||||
"""Test if the ID exists."""
|
||||
return any(item_id in collection for collection in self.collections)
|
||||
|
||||
def generate_id(self, suggestion: str) -> str:
|
||||
"""Generate an ID."""
|
||||
base = slugify(suggestion)
|
||||
proposal = base
|
||||
attempt = 1
|
||||
|
||||
while self.has_id(proposal):
|
||||
attempt += 1
|
||||
proposal = f"{base}_{attempt}"
|
||||
|
||||
return proposal
|
||||
|
||||
|
||||
class ObservableCollection(ABC):
|
||||
"""Base collection type that can be observed."""
|
||||
|
||||
def __init__(self, logger: logging.Logger, id_manager: Optional[IDManager] = None):
|
||||
"""Initialize the base collection."""
|
||||
self.logger = logger
|
||||
self.id_manager = id_manager or IDManager()
|
||||
self.data: Dict[str, dict] = {}
|
||||
self.listeners: List[ChangeListener] = []
|
||||
|
||||
self.id_manager.add_collection(self.data)
|
||||
|
||||
@callback
|
||||
def async_items(self) -> List[dict]:
|
||||
"""Return list of items in collection."""
|
||||
return list(self.data.values())
|
||||
|
||||
@callback
|
||||
def async_add_listener(self, listener: ChangeListener) -> None:
|
||||
"""Add a listener.
|
||||
|
||||
Will be called with (change_type, item_id, updated_config).
|
||||
"""
|
||||
self.listeners.append(listener)
|
||||
|
||||
async def notify_change(
|
||||
self, change_type: str, item_id: str, item: Optional[dict]
|
||||
) -> None:
|
||||
"""Notify listeners of a change."""
|
||||
self.logger.debug("%s %s: %s", change_type, item_id, item)
|
||||
for listener in self.listeners:
|
||||
await listener(change_type, item_id, item)
|
||||
|
||||
|
||||
class YamlCollection(ObservableCollection):
|
||||
"""Offer a fake CRUD interface on top of static YAML."""
|
||||
|
||||
async def async_load(self, data: List[dict]) -> None:
|
||||
"""Load the storage Manager."""
|
||||
for item in data:
|
||||
item_id = item[CONF_ID]
|
||||
|
||||
if self.id_manager.has_id(item_id):
|
||||
self.logger.warning("Duplicate ID '%s' detected, skipping", item_id)
|
||||
continue
|
||||
|
||||
self.data[item_id] = item
|
||||
await self.notify_change(CHANGE_ADDED, item[CONF_ID], item)
|
||||
|
||||
|
||||
class StorageCollection(ObservableCollection):
|
||||
"""Offer a CRUD interface on top of JSON storage."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
store: Store,
|
||||
logger: logging.Logger,
|
||||
id_manager: Optional[IDManager] = None,
|
||||
):
|
||||
"""Initialize the storage collection."""
|
||||
super().__init__(logger, id_manager)
|
||||
self.store = store
|
||||
|
||||
@property
|
||||
def hass(self) -> HomeAssistant:
|
||||
"""Home Assistant object."""
|
||||
return self.store.hass
|
||||
|
||||
async def async_load(self) -> None:
|
||||
"""Load the storage Manager."""
|
||||
raw_storage = cast(Optional[dict], await self.store.async_load())
|
||||
|
||||
if raw_storage is None:
|
||||
raw_storage = {"items": []}
|
||||
|
||||
for item in raw_storage["items"]:
|
||||
self.data[item[CONF_ID]] = item
|
||||
await self.notify_change(CHANGE_ADDED, item[CONF_ID], item)
|
||||
|
||||
@abstractmethod
|
||||
async def _process_create_data(self, data: dict) -> dict:
|
||||
"""Validate the config is valid."""
|
||||
|
||||
@callback
|
||||
@abstractmethod
|
||||
def _get_suggested_id(self, info: dict) -> str:
|
||||
"""Suggest an ID based on the config."""
|
||||
|
||||
@abstractmethod
|
||||
async def _update_data(self, data: dict, update_data: dict) -> dict:
|
||||
"""Return a new updated data object."""
|
||||
|
||||
async def async_create_item(self, data: dict) -> dict:
|
||||
"""Create a new item."""
|
||||
item = await self._process_create_data(data)
|
||||
item[CONF_ID] = self.id_manager.generate_id(self._get_suggested_id(item))
|
||||
self.data[item[CONF_ID]] = item
|
||||
self._async_schedule_save()
|
||||
await self.notify_change(CHANGE_ADDED, item[CONF_ID], item)
|
||||
return item
|
||||
|
||||
async def async_update_item(self, item_id: str, updates: dict) -> dict:
|
||||
"""Update item."""
|
||||
if item_id not in self.data:
|
||||
raise ItemNotFound(item_id)
|
||||
|
||||
if CONF_ID in updates:
|
||||
raise ValueError("Cannot update ID")
|
||||
|
||||
current = self.data[item_id]
|
||||
|
||||
updated = await self._update_data(current, updates)
|
||||
|
||||
self.data[item_id] = updated
|
||||
self._async_schedule_save()
|
||||
|
||||
await self.notify_change(CHANGE_UPDATED, item_id, updated)
|
||||
|
||||
return self.data[item_id]
|
||||
|
||||
async def async_delete_item(self, item_id: str) -> None:
|
||||
"""Delete item."""
|
||||
if item_id not in self.data:
|
||||
raise ItemNotFound(item_id)
|
||||
|
||||
self.data.pop(item_id)
|
||||
self._async_schedule_save()
|
||||
|
||||
await self.notify_change(CHANGE_REMOVED, item_id, None)
|
||||
|
||||
@callback
|
||||
def _async_schedule_save(self) -> None:
|
||||
"""Schedule saving the area registry."""
|
||||
self.store.async_delay_save(self._data_to_save, SAVE_DELAY)
|
||||
|
||||
@callback
|
||||
def _data_to_save(self) -> dict:
|
||||
"""Return data of area registry to store in a file."""
|
||||
return {"items": list(self.data.values())}
|
||||
|
||||
|
||||
@callback
|
||||
def attach_entity_component_collection(
|
||||
entity_component: EntityComponent,
|
||||
collection: ObservableCollection,
|
||||
create_entity: Callable[[dict], Entity],
|
||||
) -> None:
|
||||
"""Map a collection to an entity component."""
|
||||
entities = {}
|
||||
|
||||
async def _collection_changed(
|
||||
change_type: str, item_id: str, config: Optional[dict]
|
||||
) -> None:
|
||||
"""Handle a collection change."""
|
||||
if change_type == CHANGE_ADDED:
|
||||
entity = create_entity(cast(dict, config))
|
||||
await entity_component.async_add_entities([entity])
|
||||
entities[item_id] = entity
|
||||
return
|
||||
|
||||
if change_type == CHANGE_REMOVED:
|
||||
entity = entities.pop(item_id)
|
||||
await entity.async_remove()
|
||||
return
|
||||
|
||||
# CHANGE_UPDATED
|
||||
await entities[item_id].async_update_config(config) # type: ignore
|
||||
|
||||
collection.async_add_listener(_collection_changed)
|
||||
|
||||
|
||||
class StorageCollectionWebsocket:
|
||||
"""Class to expose storage collection management over websocket."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
storage_collection: StorageCollection,
|
||||
api_prefix: str,
|
||||
model_name: str,
|
||||
create_schema: dict,
|
||||
update_schema: dict,
|
||||
):
|
||||
"""Initialize a websocket CRUD."""
|
||||
self.storage_collection = storage_collection
|
||||
self.api_prefix = api_prefix
|
||||
self.model_name = model_name
|
||||
self.create_schema = create_schema
|
||||
self.update_schema = update_schema
|
||||
|
||||
assert self.api_prefix[-1] != "/", "API prefix should not end in /"
|
||||
|
||||
@property
|
||||
def item_id_key(self) -> str:
|
||||
"""Return item ID key."""
|
||||
return f"{self.model_name}_id"
|
||||
|
||||
@callback
|
||||
def async_setup(self, hass: HomeAssistant, *, create_list: bool = True) -> None:
|
||||
"""Set up the websocket commands."""
|
||||
if create_list:
|
||||
websocket_api.async_register_command(
|
||||
hass,
|
||||
f"{self.api_prefix}/list",
|
||||
self.ws_list_item,
|
||||
websocket_api.BASE_COMMAND_MESSAGE_SCHEMA.extend(
|
||||
{vol.Required("type"): f"{self.api_prefix}/list"}
|
||||
),
|
||||
)
|
||||
|
||||
websocket_api.async_register_command(
|
||||
hass,
|
||||
f"{self.api_prefix}/create",
|
||||
websocket_api.require_admin(
|
||||
websocket_api.async_response(self.ws_create_item)
|
||||
),
|
||||
websocket_api.BASE_COMMAND_MESSAGE_SCHEMA.extend(
|
||||
{
|
||||
**self.create_schema,
|
||||
vol.Required("type"): f"{self.api_prefix}/create",
|
||||
}
|
||||
),
|
||||
)
|
||||
|
||||
websocket_api.async_register_command(
|
||||
hass,
|
||||
f"{self.api_prefix}/update",
|
||||
websocket_api.require_admin(
|
||||
websocket_api.async_response(self.ws_update_item)
|
||||
),
|
||||
websocket_api.BASE_COMMAND_MESSAGE_SCHEMA.extend(
|
||||
{
|
||||
**self.update_schema,
|
||||
vol.Required("type"): f"{self.api_prefix}/update",
|
||||
vol.Required(self.item_id_key): str,
|
||||
}
|
||||
),
|
||||
)
|
||||
|
||||
websocket_api.async_register_command(
|
||||
hass,
|
||||
f"{self.api_prefix}/delete",
|
||||
websocket_api.require_admin(
|
||||
websocket_api.async_response(self.ws_delete_item)
|
||||
),
|
||||
websocket_api.BASE_COMMAND_MESSAGE_SCHEMA.extend(
|
||||
{
|
||||
vol.Required("type"): f"{self.api_prefix}/delete",
|
||||
vol.Required(self.item_id_key): str,
|
||||
}
|
||||
),
|
||||
)
|
||||
|
||||
def ws_list_item(
|
||||
self, hass: HomeAssistant, connection: websocket_api.ActiveConnection, msg: dict
|
||||
) -> None:
|
||||
"""List items."""
|
||||
connection.send_result(msg["id"], self.storage_collection.async_items())
|
||||
|
||||
async def ws_create_item(
|
||||
self, hass: HomeAssistant, connection: websocket_api.ActiveConnection, msg: dict
|
||||
) -> None:
|
||||
"""Create a item."""
|
||||
try:
|
||||
data = dict(msg)
|
||||
data.pop("id")
|
||||
data.pop("type")
|
||||
item = await self.storage_collection.async_create_item(data)
|
||||
connection.send_result(msg["id"], item)
|
||||
except vol.Invalid as err:
|
||||
connection.send_error(
|
||||
msg["id"],
|
||||
websocket_api.const.ERR_INVALID_FORMAT,
|
||||
humanize_error(data, err),
|
||||
)
|
||||
except ValueError as err:
|
||||
connection.send_error(
|
||||
msg["id"], websocket_api.const.ERR_INVALID_FORMAT, str(err)
|
||||
)
|
||||
|
||||
async def ws_update_item(
|
||||
self, hass: HomeAssistant, connection: websocket_api.ActiveConnection, msg: dict
|
||||
) -> None:
|
||||
"""Update a item."""
|
||||
data = dict(msg)
|
||||
msg_id = data.pop("id")
|
||||
item_id = data.pop(self.item_id_key)
|
||||
data.pop("type")
|
||||
|
||||
try:
|
||||
item = await self.storage_collection.async_update_item(item_id, data)
|
||||
connection.send_result(msg_id, item)
|
||||
except ItemNotFound:
|
||||
connection.send_error(
|
||||
msg["id"],
|
||||
websocket_api.const.ERR_NOT_FOUND,
|
||||
f"Unable to find {self.item_id_key} {item_id}",
|
||||
)
|
||||
except vol.Invalid as err:
|
||||
connection.send_error(
|
||||
msg["id"],
|
||||
websocket_api.const.ERR_INVALID_FORMAT,
|
||||
humanize_error(data, err),
|
||||
)
|
||||
except ValueError as err:
|
||||
connection.send_error(
|
||||
msg_id, websocket_api.const.ERR_INVALID_FORMAT, str(err)
|
||||
)
|
||||
|
||||
async def ws_delete_item(
|
||||
self, hass: HomeAssistant, connection: websocket_api.ActiveConnection, msg: dict
|
||||
) -> None:
|
||||
"""Delete a item."""
|
||||
try:
|
||||
await self.storage_collection.async_delete_item(msg[self.item_id_key])
|
||||
except ItemNotFound:
|
||||
connection.send_error(
|
||||
msg["id"],
|
||||
websocket_api.const.ERR_NOT_FOUND,
|
||||
f"Unable to find {self.item_id_key} {msg[self.item_id_key]}",
|
||||
)
|
||||
|
||||
connection.send_result(msg["id"])
|
@ -473,8 +473,9 @@ class Entity(ABC):
|
||||
self._on_remove = []
|
||||
self._on_remove.append(func)
|
||||
|
||||
async def async_remove(self):
|
||||
async def async_remove(self) -> None:
|
||||
"""Remove entity from Home Assistant."""
|
||||
assert self.hass is not None
|
||||
await self.async_internal_will_remove_from_hass()
|
||||
await self.async_will_remove_from_hass()
|
||||
|
||||
|
@ -3,14 +3,6 @@ from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from homeassistant.components.websocket_api.auth import (
|
||||
TYPE_AUTH,
|
||||
TYPE_AUTH_OK,
|
||||
TYPE_AUTH_REQUIRED,
|
||||
)
|
||||
from homeassistant.components.websocket_api.http import URL
|
||||
from homeassistant.setup import async_setup_component
|
||||
|
||||
from tests.common import mock_coro
|
||||
|
||||
|
||||
@ -22,37 +14,3 @@ def prevent_io():
|
||||
side_effect=lambda *args: mock_coro([]),
|
||||
):
|
||||
yield
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def hass_ws_client(aiohttp_client, hass_access_token):
|
||||
"""Websocket client fixture connected to websocket server."""
|
||||
|
||||
async def create_client(hass, access_token=hass_access_token):
|
||||
"""Create a websocket client."""
|
||||
assert await async_setup_component(hass, "websocket_api", {})
|
||||
|
||||
client = await aiohttp_client(hass.http.app)
|
||||
|
||||
with patch("homeassistant.components.http.auth.setup_auth"):
|
||||
websocket = await client.ws_connect(URL)
|
||||
auth_resp = await websocket.receive_json()
|
||||
assert auth_resp["type"] == TYPE_AUTH_REQUIRED
|
||||
|
||||
if access_token is None:
|
||||
await websocket.send_json(
|
||||
{"type": TYPE_AUTH, "access_token": "incorrect"}
|
||||
)
|
||||
else:
|
||||
await websocket.send_json(
|
||||
{"type": TYPE_AUTH, "access_token": access_token}
|
||||
)
|
||||
|
||||
auth_ok = await websocket.receive_json()
|
||||
assert auth_ok["type"] == TYPE_AUTH_OK
|
||||
|
||||
# wrap in client
|
||||
websocket.client = client
|
||||
return websocket
|
||||
|
||||
return create_client
|
||||
|
@ -98,7 +98,7 @@ async def test_onboarding_user(hass, hass_storage, aiohttp_client):
|
||||
assert user.name == "Test Name"
|
||||
assert len(user.credentials) == 1
|
||||
assert user.credentials[0].data["username"] == "test-user"
|
||||
assert len(hass.data["person"].storage_data) == 1
|
||||
assert len(hass.data["person"][1].async_items()) == 1
|
||||
|
||||
# Validate refresh token 1
|
||||
resp = await client.post(
|
||||
|
@ -1,19 +1,15 @@
|
||||
"""The tests for the person component."""
|
||||
from unittest.mock import Mock
|
||||
import logging
|
||||
|
||||
import pytest
|
||||
|
||||
from homeassistant.components import person
|
||||
from homeassistant.components.device_tracker import (
|
||||
ATTR_SOURCE_TYPE,
|
||||
SOURCE_TYPE_GPS,
|
||||
SOURCE_TYPE_ROUTER,
|
||||
)
|
||||
from homeassistant.components.person import (
|
||||
ATTR_SOURCE,
|
||||
ATTR_USER_ID,
|
||||
DOMAIN,
|
||||
PersonManager,
|
||||
)
|
||||
from homeassistant.components.person import ATTR_SOURCE, ATTR_USER_ID, DOMAIN
|
||||
from homeassistant.const import (
|
||||
ATTR_GPS_ACCURACY,
|
||||
ATTR_ID,
|
||||
@ -23,20 +19,29 @@ from homeassistant.const import (
|
||||
STATE_UNKNOWN,
|
||||
)
|
||||
from homeassistant.core import CoreState, State
|
||||
from homeassistant.helpers import collection
|
||||
from homeassistant.setup import async_setup_component
|
||||
|
||||
from tests.common import (
|
||||
assert_setup_component,
|
||||
mock_component,
|
||||
mock_coro_func,
|
||||
mock_restore_cache,
|
||||
)
|
||||
from tests.common import assert_setup_component, mock_component, mock_restore_cache
|
||||
|
||||
DEVICE_TRACKER = "device_tracker.test_tracker"
|
||||
DEVICE_TRACKER_2 = "device_tracker.test_tracker_2"
|
||||
|
||||
|
||||
# pylint: disable=redefined-outer-name
|
||||
@pytest.fixture
|
||||
def storage_collection(hass):
|
||||
"""Return an empty storage collection."""
|
||||
id_manager = collection.IDManager()
|
||||
return person.PersonStorageCollection(
|
||||
person.PersonStore(hass, person.STORAGE_VERSION, person.STORAGE_KEY),
|
||||
logging.getLogger(f"{person.__name__}.storage_collection"),
|
||||
id_manager,
|
||||
collection.YamlCollection(
|
||||
logging.getLogger(f"{person.__name__}.yaml_collection"), id_manager
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def storage_setup(hass, hass_storage, hass_admin_user):
|
||||
"""Storage setup."""
|
||||
@ -433,21 +438,21 @@ async def test_load_person_storage_two_nonlinked(hass, hass_storage):
|
||||
|
||||
async def test_ws_list(hass, hass_ws_client, storage_setup):
|
||||
"""Test listing via WS."""
|
||||
manager = hass.data[DOMAIN]
|
||||
manager = hass.data[DOMAIN][1]
|
||||
|
||||
client = await hass_ws_client(hass)
|
||||
|
||||
resp = await client.send_json({"id": 6, "type": "person/list"})
|
||||
resp = await client.receive_json()
|
||||
assert resp["success"]
|
||||
assert resp["result"]["storage"] == manager.storage_persons
|
||||
assert resp["result"]["storage"] == manager.async_items()
|
||||
assert len(resp["result"]["storage"]) == 1
|
||||
assert len(resp["result"]["config"]) == 0
|
||||
|
||||
|
||||
async def test_ws_create(hass, hass_ws_client, storage_setup, hass_read_only_user):
|
||||
"""Test creating via WS."""
|
||||
manager = hass.data[DOMAIN]
|
||||
manager = hass.data[DOMAIN][1]
|
||||
|
||||
client = await hass_ws_client(hass)
|
||||
|
||||
@ -462,7 +467,7 @@ async def test_ws_create(hass, hass_ws_client, storage_setup, hass_read_only_use
|
||||
)
|
||||
resp = await client.receive_json()
|
||||
|
||||
persons = manager.storage_persons
|
||||
persons = manager.async_items()
|
||||
assert len(persons) == 2
|
||||
|
||||
assert resp["success"]
|
||||
@ -474,7 +479,7 @@ async def test_ws_create_requires_admin(
|
||||
):
|
||||
"""Test creating via WS requires admin."""
|
||||
hass_admin_user.groups = []
|
||||
manager = hass.data[DOMAIN]
|
||||
manager = hass.data[DOMAIN][1]
|
||||
|
||||
client = await hass_ws_client(hass)
|
||||
|
||||
@ -489,7 +494,7 @@ async def test_ws_create_requires_admin(
|
||||
)
|
||||
resp = await client.receive_json()
|
||||
|
||||
persons = manager.storage_persons
|
||||
persons = manager.async_items()
|
||||
assert len(persons) == 1
|
||||
|
||||
assert not resp["success"]
|
||||
@ -497,10 +502,10 @@ async def test_ws_create_requires_admin(
|
||||
|
||||
async def test_ws_update(hass, hass_ws_client, storage_setup):
|
||||
"""Test updating via WS."""
|
||||
manager = hass.data[DOMAIN]
|
||||
manager = hass.data[DOMAIN][1]
|
||||
|
||||
client = await hass_ws_client(hass)
|
||||
persons = manager.storage_persons
|
||||
persons = manager.async_items()
|
||||
|
||||
resp = await client.send_json(
|
||||
{
|
||||
@ -514,7 +519,7 @@ async def test_ws_update(hass, hass_ws_client, storage_setup):
|
||||
)
|
||||
resp = await client.receive_json()
|
||||
|
||||
persons = manager.storage_persons
|
||||
persons = manager.async_items()
|
||||
assert len(persons) == 1
|
||||
|
||||
assert resp["success"]
|
||||
@ -533,10 +538,10 @@ async def test_ws_update_require_admin(
|
||||
):
|
||||
"""Test updating via WS requires admin."""
|
||||
hass_admin_user.groups = []
|
||||
manager = hass.data[DOMAIN]
|
||||
manager = hass.data[DOMAIN][1]
|
||||
|
||||
client = await hass_ws_client(hass)
|
||||
original = dict(manager.storage_persons[0])
|
||||
original = dict(manager.async_items()[0])
|
||||
|
||||
resp = await client.send_json(
|
||||
{
|
||||
@ -551,23 +556,23 @@ async def test_ws_update_require_admin(
|
||||
resp = await client.receive_json()
|
||||
assert not resp["success"]
|
||||
|
||||
not_updated = dict(manager.storage_persons[0])
|
||||
not_updated = dict(manager.async_items()[0])
|
||||
assert original == not_updated
|
||||
|
||||
|
||||
async def test_ws_delete(hass, hass_ws_client, storage_setup):
|
||||
"""Test deleting via WS."""
|
||||
manager = hass.data[DOMAIN]
|
||||
manager = hass.data[DOMAIN][1]
|
||||
|
||||
client = await hass_ws_client(hass)
|
||||
persons = manager.storage_persons
|
||||
persons = manager.async_items()
|
||||
|
||||
resp = await client.send_json(
|
||||
{"id": 6, "type": "person/delete", "person_id": persons[0]["id"]}
|
||||
)
|
||||
resp = await client.receive_json()
|
||||
|
||||
persons = manager.storage_persons
|
||||
persons = manager.async_items()
|
||||
assert len(persons) == 0
|
||||
|
||||
assert resp["success"]
|
||||
@ -581,7 +586,7 @@ async def test_ws_delete_require_admin(
|
||||
):
|
||||
"""Test deleting via WS requires admin."""
|
||||
hass_admin_user.groups = []
|
||||
manager = hass.data[DOMAIN]
|
||||
manager = hass.data[DOMAIN][1]
|
||||
|
||||
client = await hass_ws_client(hass)
|
||||
|
||||
@ -589,7 +594,7 @@ async def test_ws_delete_require_admin(
|
||||
{
|
||||
"id": 6,
|
||||
"type": "person/delete",
|
||||
"person_id": manager.storage_persons[0]["id"],
|
||||
"person_id": manager.async_items()[0]["id"],
|
||||
"name": "Updated Name",
|
||||
"device_trackers": [DEVICE_TRACKER_2],
|
||||
"user_id": None,
|
||||
@ -598,61 +603,64 @@ async def test_ws_delete_require_admin(
|
||||
resp = await client.receive_json()
|
||||
assert not resp["success"]
|
||||
|
||||
persons = manager.storage_persons
|
||||
persons = manager.async_items()
|
||||
assert len(persons) == 1
|
||||
|
||||
|
||||
async def test_create_invalid_user_id(hass):
|
||||
async def test_create_invalid_user_id(hass, storage_collection):
|
||||
"""Test we do not allow invalid user ID during creation."""
|
||||
manager = PersonManager(hass, Mock(), [])
|
||||
await manager.async_initialize()
|
||||
with pytest.raises(ValueError):
|
||||
await manager.async_create_person(name="Hello", user_id="non-existing")
|
||||
await storage_collection.async_create_item(
|
||||
{"name": "Hello", "user_id": "non-existing"}
|
||||
)
|
||||
|
||||
|
||||
async def test_create_duplicate_user_id(hass, hass_admin_user):
|
||||
async def test_create_duplicate_user_id(hass, hass_admin_user, storage_collection):
|
||||
"""Test we do not allow duplicate user ID during creation."""
|
||||
manager = PersonManager(hass, Mock(async_add_entities=mock_coro_func()), [])
|
||||
await manager.async_initialize()
|
||||
await manager.async_create_person(name="Hello", user_id=hass_admin_user.id)
|
||||
await storage_collection.async_create_item(
|
||||
{"name": "Hello", "user_id": hass_admin_user.id}
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
await manager.async_create_person(name="Hello", user_id=hass_admin_user.id)
|
||||
await storage_collection.async_create_item(
|
||||
{"name": "Hello", "user_id": hass_admin_user.id}
|
||||
)
|
||||
|
||||
|
||||
async def test_update_double_user_id(hass, hass_admin_user):
|
||||
async def test_update_double_user_id(hass, hass_admin_user, storage_collection):
|
||||
"""Test we do not allow double user ID during update."""
|
||||
manager = PersonManager(hass, Mock(async_add_entities=mock_coro_func()), [])
|
||||
await manager.async_initialize()
|
||||
await manager.async_create_person(name="Hello", user_id=hass_admin_user.id)
|
||||
person = await manager.async_create_person(name="Hello")
|
||||
await storage_collection.async_create_item(
|
||||
{"name": "Hello", "user_id": hass_admin_user.id}
|
||||
)
|
||||
person = await storage_collection.async_create_item({"name": "Hello"})
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
await manager.async_update_person(
|
||||
person_id=person["id"], user_id=hass_admin_user.id
|
||||
await storage_collection.async_update_item(
|
||||
person["id"], {"user_id": hass_admin_user.id}
|
||||
)
|
||||
|
||||
|
||||
async def test_update_invalid_user_id(hass):
|
||||
async def test_update_invalid_user_id(hass, storage_collection):
|
||||
"""Test updating to invalid user ID."""
|
||||
manager = PersonManager(hass, Mock(async_add_entities=mock_coro_func()), [])
|
||||
await manager.async_initialize()
|
||||
person = await manager.async_create_person(name="Hello")
|
||||
person = await storage_collection.async_create_item({"name": "Hello"})
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
await manager.async_update_person(
|
||||
person_id=person["id"], user_id="non-existing"
|
||||
await storage_collection.async_update_item(
|
||||
person["id"], {"user_id": "non-existing"}
|
||||
)
|
||||
|
||||
|
||||
async def test_update_person_when_user_removed(hass, hass_read_only_user):
|
||||
async def test_update_person_when_user_removed(
|
||||
hass, storage_setup, hass_read_only_user
|
||||
):
|
||||
"""Update person when user is removed."""
|
||||
manager = PersonManager(hass, Mock(async_add_entities=mock_coro_func()), [])
|
||||
await manager.async_initialize()
|
||||
person = await manager.async_create_person(
|
||||
name="Hello", user_id=hass_read_only_user.id
|
||||
storage_collection = hass.data[DOMAIN][1]
|
||||
|
||||
person = await storage_collection.async_create_item(
|
||||
{"name": "Hello", "user_id": hass_read_only_user.id}
|
||||
)
|
||||
|
||||
await hass.auth.async_remove_user(hass_read_only_user)
|
||||
await hass.async_block_till_done()
|
||||
assert person["user_id"] is None
|
||||
|
||||
assert storage_collection.data[person["id"]]["user_id"] is None
|
||||
|
@ -9,6 +9,13 @@ import requests_mock as _requests_mock
|
||||
from homeassistant import util
|
||||
from homeassistant.auth.const import GROUP_ID_ADMIN, GROUP_ID_READ_ONLY
|
||||
from homeassistant.auth.providers import homeassistant, legacy_api_password
|
||||
from homeassistant.components.websocket_api.auth import (
|
||||
TYPE_AUTH,
|
||||
TYPE_AUTH_OK,
|
||||
TYPE_AUTH_REQUIRED,
|
||||
)
|
||||
from homeassistant.components.websocket_api.http import URL
|
||||
from homeassistant.setup import async_setup_component
|
||||
from homeassistant.util import location
|
||||
|
||||
pytest.register_assert_rewrite("tests.common")
|
||||
@ -187,3 +194,37 @@ def hass_client(hass, aiohttp_client, hass_access_token):
|
||||
)
|
||||
|
||||
return auth_client
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def hass_ws_client(aiohttp_client, hass_access_token):
|
||||
"""Websocket client fixture connected to websocket server."""
|
||||
|
||||
async def create_client(hass, access_token=hass_access_token):
|
||||
"""Create a websocket client."""
|
||||
assert await async_setup_component(hass, "websocket_api", {})
|
||||
|
||||
client = await aiohttp_client(hass.http.app)
|
||||
|
||||
with patch("homeassistant.components.http.auth.setup_auth"):
|
||||
websocket = await client.ws_connect(URL)
|
||||
auth_resp = await websocket.receive_json()
|
||||
assert auth_resp["type"] == TYPE_AUTH_REQUIRED
|
||||
|
||||
if access_token is None:
|
||||
await websocket.send_json(
|
||||
{"type": TYPE_AUTH, "access_token": "incorrect"}
|
||||
)
|
||||
else:
|
||||
await websocket.send_json(
|
||||
{"type": TYPE_AUTH, "access_token": access_token}
|
||||
)
|
||||
|
||||
auth_ok = await websocket.receive_json()
|
||||
assert auth_ok["type"] == TYPE_AUTH_OK
|
||||
|
||||
# wrap in client
|
||||
websocket.client = client
|
||||
return websocket
|
||||
|
||||
return create_client
|
||||
|
356
tests/helpers/test_collection.py
Normal file
356
tests/helpers/test_collection.py
Normal file
@ -0,0 +1,356 @@
|
||||
"""Tests for the collection helper."""
|
||||
import logging
|
||||
|
||||
import pytest
|
||||
import voluptuous as vol
|
||||
|
||||
from homeassistant.helpers import collection, entity, entity_component, storage
|
||||
|
||||
from tests.common import flush_store
|
||||
|
||||
LOGGER = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def track_changes(coll: collection.ObservableCollection):
|
||||
"""Create helper to track changes in a collection."""
|
||||
changes = []
|
||||
|
||||
async def listener(*args):
|
||||
changes.append(args)
|
||||
|
||||
coll.async_add_listener(listener)
|
||||
|
||||
return changes
|
||||
|
||||
|
||||
class MockEntity(entity.Entity):
|
||||
"""Entity that is config based."""
|
||||
|
||||
def __init__(self, config):
|
||||
"""Initialize entity."""
|
||||
self._config = config
|
||||
|
||||
@property
|
||||
def unique_id(self):
|
||||
"""Return unique ID of entity."""
|
||||
return self._config["id"]
|
||||
|
||||
@property
|
||||
def name(self):
|
||||
"""Return name of entity."""
|
||||
return self._config["name"]
|
||||
|
||||
@property
|
||||
def state(self):
|
||||
"""Return state of entity."""
|
||||
return self._config["state"]
|
||||
|
||||
async def async_update_config(self, config):
|
||||
"""Update entity config."""
|
||||
self._config = config
|
||||
self.async_write_ha_state()
|
||||
|
||||
|
||||
class MockStorageCollection(collection.StorageCollection):
|
||||
"""Mock storage collection."""
|
||||
|
||||
async def _process_create_data(self, data: dict) -> dict:
|
||||
"""Validate the config is valid."""
|
||||
if "name" not in data:
|
||||
raise ValueError("invalid")
|
||||
|
||||
return data
|
||||
|
||||
def _get_suggested_id(self, info: dict) -> str:
|
||||
"""Suggest an ID based on the config."""
|
||||
return info["name"]
|
||||
|
||||
async def _update_data(self, data: dict, update_data: dict) -> dict:
|
||||
"""Return a new updated data object."""
|
||||
return {**data, **update_data}
|
||||
|
||||
|
||||
def test_id_manager():
|
||||
"""Test the ID manager."""
|
||||
id_manager = collection.IDManager()
|
||||
assert not id_manager.has_id("some_id")
|
||||
data = {}
|
||||
id_manager.add_collection(data)
|
||||
assert not id_manager.has_id("some_id")
|
||||
data["some_id"] = 1
|
||||
assert id_manager.has_id("some_id")
|
||||
assert id_manager.generate_id("some_id") == "some_id_2"
|
||||
assert id_manager.generate_id("bla") == "bla"
|
||||
|
||||
|
||||
async def test_observable_collection():
|
||||
"""Test observerable collection."""
|
||||
coll = collection.ObservableCollection(LOGGER)
|
||||
assert coll.async_items() == []
|
||||
coll.data["bla"] = 1
|
||||
assert coll.async_items() == [1]
|
||||
|
||||
changes = track_changes(coll)
|
||||
await coll.notify_change("mock_type", "mock_id", {"mock": "item"})
|
||||
assert len(changes) == 1
|
||||
assert changes[0] == ("mock_type", "mock_id", {"mock": "item"})
|
||||
|
||||
|
||||
async def test_yaml_collection():
|
||||
"""Test a YAML collection."""
|
||||
id_manager = collection.IDManager()
|
||||
coll = collection.YamlCollection(LOGGER, id_manager)
|
||||
changes = track_changes(coll)
|
||||
await coll.async_load(
|
||||
[{"id": "mock-1", "name": "Mock 1"}, {"id": "mock-2", "name": "Mock 2"}]
|
||||
)
|
||||
assert id_manager.has_id("mock-1")
|
||||
assert id_manager.has_id("mock-2")
|
||||
assert len(changes) == 2
|
||||
assert changes[0] == (
|
||||
collection.CHANGE_ADDED,
|
||||
"mock-1",
|
||||
{"id": "mock-1", "name": "Mock 1"},
|
||||
)
|
||||
assert changes[1] == (
|
||||
collection.CHANGE_ADDED,
|
||||
"mock-2",
|
||||
{"id": "mock-2", "name": "Mock 2"},
|
||||
)
|
||||
|
||||
|
||||
async def test_yaml_collection_skipping_duplicate_ids():
|
||||
"""Test YAML collection skipping duplicate IDs."""
|
||||
id_manager = collection.IDManager()
|
||||
id_manager.add_collection({"existing": True})
|
||||
coll = collection.YamlCollection(LOGGER, id_manager)
|
||||
changes = track_changes(coll)
|
||||
await coll.async_load(
|
||||
[{"id": "mock-1", "name": "Mock 1"}, {"id": "existing", "name": "Mock 2"}]
|
||||
)
|
||||
assert len(changes) == 1
|
||||
assert changes[0] == (
|
||||
collection.CHANGE_ADDED,
|
||||
"mock-1",
|
||||
{"id": "mock-1", "name": "Mock 1"},
|
||||
)
|
||||
|
||||
|
||||
async def test_storage_collection(hass):
|
||||
"""Test storage collection."""
|
||||
store = storage.Store(hass, 1, "test-data")
|
||||
await store.async_save(
|
||||
{
|
||||
"items": [
|
||||
{"id": "mock-1", "name": "Mock 1", "data": 1},
|
||||
{"id": "mock-2", "name": "Mock 2", "data": 2},
|
||||
]
|
||||
}
|
||||
)
|
||||
id_manager = collection.IDManager()
|
||||
coll = MockStorageCollection(store, LOGGER, id_manager)
|
||||
changes = track_changes(coll)
|
||||
|
||||
await coll.async_load()
|
||||
assert id_manager.has_id("mock-1")
|
||||
assert id_manager.has_id("mock-2")
|
||||
assert len(changes) == 2
|
||||
assert changes[0] == (
|
||||
collection.CHANGE_ADDED,
|
||||
"mock-1",
|
||||
{"id": "mock-1", "name": "Mock 1", "data": 1},
|
||||
)
|
||||
assert changes[1] == (
|
||||
collection.CHANGE_ADDED,
|
||||
"mock-2",
|
||||
{"id": "mock-2", "name": "Mock 2", "data": 2},
|
||||
)
|
||||
|
||||
item = await coll.async_create_item({"name": "Mock 3"})
|
||||
assert item["id"] == "mock_3"
|
||||
assert len(changes) == 3
|
||||
assert changes[2] == (
|
||||
collection.CHANGE_ADDED,
|
||||
"mock_3",
|
||||
{"id": "mock_3", "name": "Mock 3"},
|
||||
)
|
||||
|
||||
updated_item = await coll.async_update_item("mock-2", {"name": "Mock 2 updated"})
|
||||
assert id_manager.has_id("mock-2")
|
||||
assert updated_item == {"id": "mock-2", "name": "Mock 2 updated", "data": 2}
|
||||
assert len(changes) == 4
|
||||
assert changes[3] == (collection.CHANGE_UPDATED, "mock-2", updated_item)
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
await coll.async_update_item("mock-2", {"id": "mock-2-updated"})
|
||||
|
||||
assert id_manager.has_id("mock-2")
|
||||
assert not id_manager.has_id("mock-2-updated")
|
||||
assert len(changes) == 4
|
||||
|
||||
await flush_store(store)
|
||||
|
||||
assert await storage.Store(hass, 1, "test-data").async_load() == {
|
||||
"items": [
|
||||
{"id": "mock-1", "name": "Mock 1", "data": 1},
|
||||
{"id": "mock-2", "name": "Mock 2 updated", "data": 2},
|
||||
{"id": "mock_3", "name": "Mock 3"},
|
||||
]
|
||||
}
|
||||
|
||||
|
||||
async def test_attach_entity_component_collection(hass):
|
||||
"""Test attaching collection to entity component."""
|
||||
ent_comp = entity_component.EntityComponent(LOGGER, "test", hass)
|
||||
coll = collection.ObservableCollection(LOGGER)
|
||||
collection.attach_entity_component_collection(ent_comp, coll, MockEntity)
|
||||
|
||||
await coll.notify_change(
|
||||
collection.CHANGE_ADDED,
|
||||
"mock_id",
|
||||
{"id": "mock_id", "state": "initial", "name": "Mock 1"},
|
||||
)
|
||||
|
||||
assert hass.states.get("test.mock_1").name == "Mock 1"
|
||||
assert hass.states.get("test.mock_1").state == "initial"
|
||||
|
||||
await coll.notify_change(
|
||||
collection.CHANGE_UPDATED,
|
||||
"mock_id",
|
||||
{"id": "mock_id", "state": "second", "name": "Mock 1 updated"},
|
||||
)
|
||||
|
||||
assert hass.states.get("test.mock_1").name == "Mock 1 updated"
|
||||
assert hass.states.get("test.mock_1").state == "second"
|
||||
|
||||
await coll.notify_change(collection.CHANGE_REMOVED, "mock_id", None)
|
||||
|
||||
assert hass.states.get("test.mock_1") is None
|
||||
|
||||
|
||||
async def test_storage_collection_websocket(hass, hass_ws_client):
|
||||
"""Test exposing a storage collection via websockets."""
|
||||
store = storage.Store(hass, 1, "test-data")
|
||||
coll = MockStorageCollection(store, LOGGER)
|
||||
changes = track_changes(coll)
|
||||
collection.StorageCollectionWebsocket(
|
||||
coll,
|
||||
"test_item/collection",
|
||||
"test_item",
|
||||
{vol.Required("name"): str, vol.Required("immutable_string"): str},
|
||||
{vol.Optional("name"): str},
|
||||
).async_setup(hass)
|
||||
|
||||
client = await hass_ws_client(hass)
|
||||
|
||||
# Create invalid
|
||||
await client.send_json(
|
||||
{
|
||||
"id": 1,
|
||||
"type": "test_item/collection/create",
|
||||
"name": 1,
|
||||
# Forgot to add immutable_string
|
||||
}
|
||||
)
|
||||
response = await client.receive_json()
|
||||
assert not response["success"]
|
||||
assert response["error"]["code"] == "invalid_format"
|
||||
assert len(changes) == 0
|
||||
|
||||
# Create
|
||||
await client.send_json(
|
||||
{
|
||||
"id": 2,
|
||||
"type": "test_item/collection/create",
|
||||
"name": "Initial Name",
|
||||
"immutable_string": "no-changes",
|
||||
}
|
||||
)
|
||||
response = await client.receive_json()
|
||||
assert response["success"]
|
||||
assert response["result"] == {
|
||||
"id": "initial_name",
|
||||
"name": "Initial Name",
|
||||
"immutable_string": "no-changes",
|
||||
}
|
||||
assert len(changes) == 1
|
||||
assert changes[0] == (collection.CHANGE_ADDED, "initial_name", response["result"])
|
||||
|
||||
# List
|
||||
await client.send_json({"id": 3, "type": "test_item/collection/list"})
|
||||
response = await client.receive_json()
|
||||
assert response["success"]
|
||||
assert response["result"] == [
|
||||
{
|
||||
"id": "initial_name",
|
||||
"name": "Initial Name",
|
||||
"immutable_string": "no-changes",
|
||||
}
|
||||
]
|
||||
assert len(changes) == 1
|
||||
|
||||
# Update invalid data
|
||||
await client.send_json(
|
||||
{
|
||||
"id": 4,
|
||||
"type": "test_item/collection/update",
|
||||
"test_item_id": "initial_name",
|
||||
"immutable_string": "no-changes",
|
||||
}
|
||||
)
|
||||
response = await client.receive_json()
|
||||
assert not response["success"]
|
||||
assert response["error"]["code"] == "invalid_format"
|
||||
assert len(changes) == 1
|
||||
|
||||
# Update invalid item
|
||||
await client.send_json(
|
||||
{
|
||||
"id": 5,
|
||||
"type": "test_item/collection/update",
|
||||
"test_item_id": "non-existing",
|
||||
"name": "Updated name",
|
||||
}
|
||||
)
|
||||
response = await client.receive_json()
|
||||
assert not response["success"]
|
||||
assert response["error"]["code"] == "not_found"
|
||||
assert len(changes) == 1
|
||||
|
||||
# Update
|
||||
await client.send_json(
|
||||
{
|
||||
"id": 6,
|
||||
"type": "test_item/collection/update",
|
||||
"test_item_id": "initial_name",
|
||||
"name": "Updated name",
|
||||
}
|
||||
)
|
||||
response = await client.receive_json()
|
||||
assert response["success"]
|
||||
assert response["result"] == {
|
||||
"id": "initial_name",
|
||||
"name": "Updated name",
|
||||
"immutable_string": "no-changes",
|
||||
}
|
||||
assert len(changes) == 2
|
||||
assert changes[1] == (collection.CHANGE_UPDATED, "initial_name", response["result"])
|
||||
|
||||
# Delete invalid ID
|
||||
await client.send_json(
|
||||
{"id": 7, "type": "test_item/collection/update", "test_item_id": "non-existing"}
|
||||
)
|
||||
response = await client.receive_json()
|
||||
assert not response["success"]
|
||||
assert response["error"]["code"] == "not_found"
|
||||
assert len(changes) == 2
|
||||
|
||||
# Delete
|
||||
await client.send_json(
|
||||
{"id": 8, "type": "test_item/collection/delete", "test_item_id": "initial_name"}
|
||||
)
|
||||
response = await client.receive_json()
|
||||
assert response["success"]
|
||||
|
||||
assert len(changes) == 3
|
||||
assert changes[2] == (collection.CHANGE_REMOVED, "initial_name", None)
|
Loading…
x
Reference in New Issue
Block a user