Mobile App add device tracker to person registering app (#30460)

This commit is contained in:
Paulus Schoutsen 2020-01-04 23:15:50 +01:00 committed by GitHub
parent e233dd7cbe
commit 95cd0a2c68
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 174 additions and 43 deletions

View File

@ -1,7 +1,11 @@
"""Config flow for Mobile App.""" """Config flow for Mobile App."""
from homeassistant import config_entries import uuid
from .const import ATTR_DEVICE_ID, ATTR_DEVICE_NAME, DOMAIN from homeassistant import config_entries
from homeassistant.components import person
from homeassistant.helpers import entity_registry
from .const import ATTR_APP_ID, ATTR_DEVICE_ID, ATTR_DEVICE_NAME, CONF_USER_ID, DOMAIN
@config_entries.HANDLERS.register(DOMAIN) @config_entries.HANDLERS.register(DOMAIN)
@ -23,7 +27,25 @@ class MobileAppFlowHandler(config_entries.ConfigFlow):
async def async_step_registration(self, user_input=None): async def async_step_registration(self, user_input=None):
"""Handle a flow initialized during registration.""" """Handle a flow initialized during registration."""
await self.async_set_unique_id(user_input[ATTR_DEVICE_ID]) if ATTR_DEVICE_ID in user_input:
# Unique ID is combi of app + device ID.
await self.async_set_unique_id(
f"{user_input[ATTR_APP_ID]}-{user_input[ATTR_DEVICE_ID]}"
)
else:
user_input[ATTR_DEVICE_ID] = str(uuid.uuid4()).replace("-", "")
# Register device tracker entity and add to person registering app
ent_reg = await entity_registry.async_get_registry(self.hass)
devt_entry = ent_reg.async_get_or_create(
"device_tracker",
DOMAIN,
user_input[ATTR_DEVICE_ID],
suggested_object_id=user_input[ATTR_DEVICE_NAME],
)
await person.async_add_user_device_tracker(
self.hass, user_input[CONF_USER_ID], devt_entry.entity_id
)
return self.async_create_entry( return self.async_create_entry(
title=user_input[ATTR_DEVICE_NAME], data=user_input title=user_input[ATTR_DEVICE_NAME], data=user_input

View File

@ -25,7 +25,6 @@ ATTR_DEVICE_ID = "device_id"
ATTR_DEVICE_NAME = "device_name" ATTR_DEVICE_NAME = "device_name"
ATTR_MANUFACTURER = "manufacturer" ATTR_MANUFACTURER = "manufacturer"
ATTR_MODEL = "model" ATTR_MODEL = "model"
ATTR_MODEL_ID = "model_id"
ATTR_OS_NAME = "os_name" ATTR_OS_NAME = "os_name"
ATTR_OS_VERSION = "os_version" ATTR_OS_VERSION = "os_version"
ATTR_PUSH_TOKEN = "push_token" ATTR_PUSH_TOKEN = "push_token"

View File

@ -1,7 +1,6 @@
"""Provides an HTTP API for mobile_app.""" """Provides an HTTP API for mobile_app."""
import secrets import secrets
from typing import Dict from typing import Dict
import uuid
from aiohttp.web import Request, Response from aiohttp.web import Request, Response
from nacl.secret import SecretBox from nacl.secret import SecretBox
@ -21,7 +20,6 @@ from .const import (
ATTR_DEVICE_NAME, ATTR_DEVICE_NAME,
ATTR_MANUFACTURER, ATTR_MANUFACTURER,
ATTR_MODEL, ATTR_MODEL,
ATTR_MODEL_ID,
ATTR_OS_NAME, ATTR_OS_NAME,
ATTR_OS_VERSION, ATTR_OS_VERSION,
ATTR_SUPPORTS_ENCRYPTION, ATTR_SUPPORTS_ENCRYPTION,
@ -50,7 +48,7 @@ class RegistrationsView(HomeAssistantView):
vol.Required(ATTR_DEVICE_NAME): cv.string, vol.Required(ATTR_DEVICE_NAME): cv.string,
vol.Required(ATTR_MANUFACTURER): cv.string, vol.Required(ATTR_MANUFACTURER): cv.string,
vol.Required(ATTR_MODEL): cv.string, vol.Required(ATTR_MODEL): cv.string,
vol.Optional(ATTR_MODEL_ID): cv.string, # Added in 0.104 vol.Optional(ATTR_DEVICE_ID): cv.string, # Added in 0.104
vol.Required(ATTR_OS_NAME): cv.string, vol.Required(ATTR_OS_NAME): cv.string,
vol.Optional(ATTR_OS_VERSION): cv.string, vol.Optional(ATTR_OS_VERSION): cv.string,
vol.Required(ATTR_SUPPORTS_ENCRYPTION, default=False): cv.boolean, vol.Required(ATTR_SUPPORTS_ENCRYPTION, default=False): cv.boolean,
@ -70,14 +68,6 @@ class RegistrationsView(HomeAssistantView):
CONF_CLOUDHOOK_URL CONF_CLOUDHOOK_URL
] = await hass.components.cloud.async_create_cloudhook(webhook_id) ] = await hass.components.cloud.async_create_cloudhook(webhook_id)
model_id = data.get(ATTR_MODEL_ID)
if model_id is None:
data[ATTR_DEVICE_ID] = str(uuid.uuid4()).replace("-", "")
else:
data[ATTR_DEVICE_ID] = f"{data[ATTR_APP_ID]}-{model_id}"
data[CONF_WEBHOOK_ID] = webhook_id data[CONF_WEBHOOK_ID] = webhook_id
if data[ATTR_SUPPORTS_ENCRYPTION] and supports_encryption(): if data[ATTR_SUPPORTS_ENCRYPTION] and supports_encryption():

View File

@ -4,7 +4,7 @@
"config_flow": true, "config_flow": true,
"documentation": "https://www.home-assistant.io/integrations/mobile_app", "documentation": "https://www.home-assistant.io/integrations/mobile_app",
"requirements": ["PyNaCl==1.3.0"], "requirements": ["PyNaCl==1.3.0"],
"dependencies": ["http", "webhook"], "dependencies": ["http", "webhook", "person"],
"after_dependencies": ["cloud"], "after_dependencies": ["cloud"],
"codeowners": ["@robbiet480"] "codeowners": ["@robbiet480"]
} }

View File

@ -1,6 +1,6 @@
"""Support for tracking people.""" """Support for tracking people."""
import logging import logging
from typing import List, Optional from typing import List, Optional, cast
import voluptuous as vol import voluptuous as vol
@ -24,9 +24,8 @@ from homeassistant.const import (
STATE_UNAVAILABLE, STATE_UNAVAILABLE,
STATE_UNKNOWN, STATE_UNKNOWN,
) )
from homeassistant.core import Event, State, callback from homeassistant.core import Event, HomeAssistant, State, callback, split_entity_id
from homeassistant.helpers import collection, entity_registry from homeassistant.helpers import collection, config_validation as cv, entity_registry
import homeassistant.helpers.config_validation as cv
from homeassistant.helpers.entity_component import EntityComponent from homeassistant.helpers.entity_component import EntityComponent
from homeassistant.helpers.event import async_track_state_change from homeassistant.helpers.event import async_track_state_change
from homeassistant.helpers.restore_state import RestoreEntity from homeassistant.helpers.restore_state import RestoreEntity
@ -77,6 +76,29 @@ async def async_create_person(hass, name, *, user_id=None, device_trackers=None)
) )
@bind_hass
async def async_add_user_device_tracker(
hass: HomeAssistant, user_id: str, device_tracker_entity_id: str
):
"""Add a device tracker to a person linked to a user."""
coll = cast(PersonStorageCollection, hass.data[DOMAIN][1])
for person in coll.async_items():
if person.get(ATTR_USER_ID) != user_id:
continue
device_trackers = person["device_trackers"]
if device_tracker_entity_id in device_trackers:
return
await coll.async_update_item(
person[collection.CONF_ID],
{"device_trackers": device_trackers + [device_tracker_entity_id]},
)
break
CREATE_FIELDS = { CREATE_FIELDS = {
vol.Required("name"): vol.All(str, vol.Length(min=1)), vol.Required("name"): vol.All(str, vol.Length(min=1)),
vol.Optional("user_id"): vol.Any(str, None), vol.Optional("user_id"): vol.Any(str, None),
@ -124,6 +146,36 @@ class PersonStorageCollection(collection.StorageCollection):
self.async_add_listener(self._collection_changed) self.async_add_listener(self._collection_changed)
self.yaml_collection = yaml_collection self.yaml_collection = yaml_collection
async def async_load(self) -> None:
"""Load the Storage collection."""
await super().async_load()
self.hass.bus.async_listen(
entity_registry.EVENT_ENTITY_REGISTRY_UPDATED, self._entity_registry_updated
)
async def _entity_registry_updated(self, event) -> None:
"""Handle entity registry updated."""
if event.data["action"] != "remove":
return
entity_id = event.data["entity_id"]
if split_entity_id(entity_id)[0] != "device_tracker":
return
for person in list(self.data.values()):
if entity_id not in person["device_trackers"]:
continue
await self.async_update_item(
person[collection.CONF_ID],
{
"device_trackers": [
devt for devt in person["device_trackers"] if devt != entity_id
]
},
)
async def _process_create_data(self, data: dict) -> dict: async def _process_create_data(self, data: dict) -> dict:
"""Validate the config is valid.""" """Validate the config is valid."""
data = self.CREATE_SCHEMA(data) data = self.CREATE_SCHEMA(data)

View File

@ -17,7 +17,6 @@ REGISTER = {
"device_name": "Test 1", "device_name": "Test 1",
"manufacturer": "mobile_app", "manufacturer": "mobile_app",
"model": "Test", "model": "Test",
"model_id": "mock-model-id",
"os_name": "Linux", "os_name": "Linux",
"os_version": "1.0", "os_version": "1.0",
"supports_encryption": True, "supports_encryption": True,
@ -31,6 +30,7 @@ REGISTER_CLEARTEXT = {
"device_name": "Test 1", "device_name": "Test 1",
"manufacturer": "mobile_app", "manufacturer": "mobile_app",
"model": "Test", "model": "Test",
"device_id": "mock-device-id",
"os_name": "Linux", "os_name": "Linux",
"os_version": "1.0", "os_version": "1.0",
"supports_encryption": False, "supports_encryption": False,

View File

@ -1,15 +1,62 @@
"""Tests for the mobile_app HTTP API.""" """Tests for the mobile_app HTTP API."""
# pylint: disable=redefined-outer-name,unused-import import json
from unittest.mock import patch
import pytest import pytest
from homeassistant.components.mobile_app.const import CONF_SECRET, DOMAIN from homeassistant.components.mobile_app.const import CONF_SECRET, DOMAIN
from homeassistant.const import CONF_WEBHOOK_ID from homeassistant.const import CONF_WEBHOOK_ID
from homeassistant.setup import async_setup_component from homeassistant.setup import async_setup_component
from .const import REGISTER, RENDER_TEMPLATE from .const import REGISTER, REGISTER_CLEARTEXT, RENDER_TEMPLATE
from tests.common import mock_coro
async def test_registration(hass, hass_client): async def test_registration(hass, hass_client, hass_admin_user):
"""Test that registrations happen."""
await async_setup_component(hass, DOMAIN, {DOMAIN: {}})
api_client = await hass_client()
with patch(
"homeassistant.components.person.async_add_user_device_tracker",
spec=True,
return_value=mock_coro(),
) as add_user_dev_track:
resp = await api_client.post(
"/api/mobile_app/registrations", json=REGISTER_CLEARTEXT
)
assert len(add_user_dev_track.mock_calls) == 1
assert add_user_dev_track.mock_calls[0][1][1] == hass_admin_user.id
assert add_user_dev_track.mock_calls[0][1][2] == "device_tracker.test_1"
assert resp.status == 201
register_json = await resp.json()
assert CONF_WEBHOOK_ID in register_json
assert CONF_SECRET in register_json
entries = hass.config_entries.async_entries(DOMAIN)
assert entries[0].unique_id == "io.homeassistant.mobile_app_test-mock-device-id"
assert entries[0].data["device_id"] == REGISTER_CLEARTEXT["device_id"]
assert entries[0].data["app_data"] == REGISTER_CLEARTEXT["app_data"]
assert entries[0].data["app_id"] == REGISTER_CLEARTEXT["app_id"]
assert entries[0].data["app_name"] == REGISTER_CLEARTEXT["app_name"]
assert entries[0].data["app_version"] == REGISTER_CLEARTEXT["app_version"]
assert entries[0].data["device_name"] == REGISTER_CLEARTEXT["device_name"]
assert entries[0].data["manufacturer"] == REGISTER_CLEARTEXT["manufacturer"]
assert entries[0].data["model"] == REGISTER_CLEARTEXT["model"]
assert entries[0].data["os_name"] == REGISTER_CLEARTEXT["os_name"]
assert entries[0].data["os_version"] == REGISTER_CLEARTEXT["os_version"]
assert (
entries[0].data["supports_encryption"]
== REGISTER_CLEARTEXT["supports_encryption"]
)
async def test_registration_encryption(hass, hass_client):
"""Test that registrations happen.""" """Test that registrations happen."""
try: try:
from nacl.secret import SecretBox from nacl.secret import SecretBox
@ -18,8 +65,6 @@ async def test_registration(hass, hass_client):
pytest.skip("libnacl/libsodium is not installed") pytest.skip("libnacl/libsodium is not installed")
return return
import json
await async_setup_component(hass, DOMAIN, {DOMAIN: {}}) await async_setup_component(hass, DOMAIN, {DOMAIN: {}})
api_client = await hass_client() api_client = await hass_client()
@ -28,22 +73,6 @@ async def test_registration(hass, hass_client):
assert resp.status == 201 assert resp.status == 201
register_json = await resp.json() register_json = await resp.json()
assert CONF_WEBHOOK_ID in register_json
assert CONF_SECRET in register_json
entries = hass.config_entries.async_entries(DOMAIN)
assert entries[0].unique_id == "io.homeassistant.mobile_app_test-mock-model-id"
assert entries[0].data["app_data"] == REGISTER["app_data"]
assert entries[0].data["app_id"] == REGISTER["app_id"]
assert entries[0].data["app_name"] == REGISTER["app_name"]
assert entries[0].data["app_version"] == REGISTER["app_version"]
assert entries[0].data["device_name"] == REGISTER["device_name"]
assert entries[0].data["manufacturer"] == REGISTER["manufacturer"]
assert entries[0].data["model"] == REGISTER["model"]
assert entries[0].data["os_name"] == REGISTER["os_name"]
assert entries[0].data["os_version"] == REGISTER["os_version"]
assert entries[0].data["supports_encryption"] == REGISTER["supports_encryption"]
keylen = SecretBox.KEY_SIZE keylen = SecretBox.KEY_SIZE
key = register_json[CONF_SECRET].encode("utf-8") key = register_json[CONF_SECRET].encode("utf-8")

View File

@ -19,7 +19,7 @@ from homeassistant.const import (
STATE_UNKNOWN, STATE_UNKNOWN,
) )
from homeassistant.core import CoreState, State from homeassistant.core import CoreState, State
from homeassistant.helpers import collection from homeassistant.helpers import collection, entity_registry
from homeassistant.setup import async_setup_component from homeassistant.setup import async_setup_component
from tests.common import assert_setup_component, mock_component, mock_restore_cache from tests.common import assert_setup_component, mock_component, mock_restore_cache
@ -664,3 +664,42 @@ async def test_update_person_when_user_removed(
await hass.async_block_till_done() await hass.async_block_till_done()
assert storage_collection.data[person["id"]]["user_id"] is None assert storage_collection.data[person["id"]]["user_id"] is None
async def test_removing_device_tracker(hass, storage_setup):
"""Test we automatically remove removed device trackers."""
storage_collection = hass.data[DOMAIN][1]
reg = await entity_registry.async_get_registry(hass)
entry = reg.async_get_or_create(
"device_tracker", "mobile_app", "bla", suggested_object_id="pixel"
)
person = await storage_collection.async_create_item(
{"name": "Hello", "device_trackers": [entry.entity_id]}
)
reg.async_remove(entry.entity_id)
await hass.async_block_till_done()
assert storage_collection.data[person["id"]]["device_trackers"] == []
async def test_add_user_device_tracker(hass, storage_setup, hass_read_only_user):
"""Test adding a device tracker to a person tied to a user."""
storage_collection = hass.data[DOMAIN][1]
pers = await storage_collection.async_create_item(
{
"name": "Hello",
"user_id": hass_read_only_user.id,
"device_trackers": ["device_tracker.on_create"],
}
)
await person.async_add_user_device_tracker(
hass, hass_read_only_user.id, "device_tracker.added"
)
assert storage_collection.data[pers["id"]]["device_trackers"] == [
"device_tracker.on_create",
"device_tracker.added",
]