Implement suggested_area in the device registry (#45940)

Co-authored-by: Paulus Schoutsen <paulus@home-assistant.io>
This commit is contained in:
J. Nick Koston 2021-02-19 19:34:33 -10:00 committed by GitHub
parent e6125a1e4e
commit bb7e4d7daa
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 258 additions and 26 deletions

View File

@ -25,6 +25,7 @@ class AreaEntry:
"""Area Registry Entry."""
name: str = attr.ib()
normalized_name: str = attr.ib()
id: Optional[str] = attr.ib(default=None)
def generate_id(self, existing_ids: Container[str]) -> None:
@ -45,27 +46,47 @@ class AreaRegistry:
self.hass = hass
self.areas: MutableMapping[str, AreaEntry] = {}
self._store = hass.helpers.storage.Store(STORAGE_VERSION, STORAGE_KEY)
self._normalized_name_area_idx: Dict[str, str] = {}
@callback
def async_get_area(self, area_id: str) -> Optional[AreaEntry]:
"""Get all areas."""
"""Get area by id."""
return self.areas.get(area_id)
@callback
def async_get_area_by_name(self, name: str) -> Optional[AreaEntry]:
"""Get area by name."""
normalized_name = normalize_area_name(name)
if normalized_name not in self._normalized_name_area_idx:
return None
return self.areas[self._normalized_name_area_idx[normalized_name]]
@callback
def async_list_areas(self) -> Iterable[AreaEntry]:
"""Get all areas."""
return self.areas.values()
@callback
def async_get_or_create(self, name: str) -> AreaEntry:
"""Get or create an area."""
area = self.async_get_area_by_name(name)
if area:
return area
return self.async_create(name)
@callback
def async_create(self, name: str) -> AreaEntry:
"""Create a new area."""
if self._async_is_registered(name):
raise ValueError("Name is already in use")
normalized_name = normalize_area_name(name)
area = AreaEntry(name=name)
if self.async_get_area_by_name(name):
raise ValueError(f"The name {name} ({normalized_name}) is already in use")
area = AreaEntry(name=name, normalized_name=normalized_name)
area.generate_id(self.areas)
assert area.id is not None
self.areas[area.id] = area
self._normalized_name_area_idx[normalized_name] = area.id
self.async_schedule_save()
self.hass.bus.async_fire(
EVENT_AREA_REGISTRY_UPDATED, {"action": "create", "area_id": area.id}
@ -75,12 +96,14 @@ class AreaRegistry:
@callback
def async_delete(self, area_id: str) -> None:
"""Delete area."""
area = self.areas[area_id]
device_registry = dr.async_get(self.hass)
entity_registry = er.async_get(self.hass)
device_registry.async_clear_area_id(area_id)
entity_registry.async_clear_area_id(area_id)
del self.areas[area_id]
del self._normalized_name_area_idx[area.normalized_name]
self.hass.bus.async_fire(
EVENT_AREA_REGISTRY_UPDATED, {"action": "remove", "area_id": area_id}
@ -107,23 +130,25 @@ class AreaRegistry:
if name == old.name:
return old
if self._async_is_registered(name):
raise ValueError("Name is already in use")
normalized_name = normalize_area_name(name)
if normalized_name != old.normalized_name:
if self.async_get_area_by_name(name):
raise ValueError(
f"The name {name} ({normalized_name}) is already in use"
)
changes["name"] = name
changes["normalized_name"] = normalized_name
new = self.areas[area_id] = attr.evolve(old, **changes)
self._normalized_name_area_idx[
normalized_name
] = self._normalized_name_area_idx.pop(old.normalized_name)
self.async_schedule_save()
return new
@callback
def _async_is_registered(self, name: str) -> Optional[AreaEntry]:
"""Check if a name is currently registered."""
for area in self.areas.values():
if name == area.name:
return area
return None
async def async_load(self) -> None:
"""Load the area registry."""
data = await self._store.async_load()
@ -132,7 +157,11 @@ class AreaRegistry:
if data is not None:
for area in data["areas"]:
areas[area["id"]] = AreaEntry(name=area["name"], id=area["id"])
normalized_name = normalize_area_name(area["name"])
areas[area["id"]] = AreaEntry(
name=area["name"], id=area["id"], normalized_name=normalized_name
)
self._normalized_name_area_idx[normalized_name] = area["id"]
self.areas = areas
@ -147,7 +176,11 @@ class AreaRegistry:
data = {}
data["areas"] = [
{"name": entry.name, "id": entry.id} for entry in self.areas.values()
{
"name": entry.name,
"id": entry.id,
}
for entry in self.areas.values()
]
return data
@ -173,3 +206,8 @@ async def async_get_registry(hass: HomeAssistantType) -> AreaRegistry:
This is deprecated and will be removed in the future. Use async_get instead.
"""
return async_get(hass)
def normalize_area_name(area_name: str) -> str:
"""Normalize an area name by removing whitespace and case folding."""
return area_name.casefold().replace(" ", "")

View File

@ -71,6 +71,7 @@ class DeviceEntry:
)
),
)
suggested_area: Optional[str] = attr.ib(default=None)
@property
def disabled(self) -> bool:
@ -251,6 +252,7 @@ class DeviceRegistry:
via_device: Optional[Tuple[str, str]] = None,
# To disable a device if it gets created
disabled_by: Union[str, None, UndefinedType] = UNDEFINED,
suggested_area: Union[str, None, UndefinedType] = UNDEFINED,
) -> Optional[DeviceEntry]:
"""Get device. Create if it doesn't exist."""
if not identifiers and not connections:
@ -304,6 +306,7 @@ class DeviceRegistry:
sw_version=sw_version,
entry_type=entry_type,
disabled_by=disabled_by,
suggested_area=suggested_area,
)
@callback
@ -321,6 +324,7 @@ class DeviceRegistry:
via_device_id: Union[str, None, UndefinedType] = UNDEFINED,
remove_config_entry_id: Union[str, UndefinedType] = UNDEFINED,
disabled_by: Union[str, None, UndefinedType] = UNDEFINED,
suggested_area: Union[str, None, UndefinedType] = UNDEFINED,
) -> Optional[DeviceEntry]:
"""Update properties of a device."""
return self._async_update_device(
@ -335,6 +339,7 @@ class DeviceRegistry:
via_device_id=via_device_id,
remove_config_entry_id=remove_config_entry_id,
disabled_by=disabled_by,
suggested_area=suggested_area,
)
@callback
@ -356,6 +361,7 @@ class DeviceRegistry:
area_id: Union[str, None, UndefinedType] = UNDEFINED,
name_by_user: Union[str, None, UndefinedType] = UNDEFINED,
disabled_by: Union[str, None, UndefinedType] = UNDEFINED,
suggested_area: Union[str, None, UndefinedType] = UNDEFINED,
) -> Optional[DeviceEntry]:
"""Update device attributes."""
old = self.devices[device_id]
@ -364,6 +370,16 @@ class DeviceRegistry:
config_entries = old.config_entries
if (
suggested_area not in (UNDEFINED, None, "")
and area_id is UNDEFINED
and old.area_id is None
):
area = self.hass.helpers.area_registry.async_get(
self.hass
).async_get_or_create(suggested_area)
area_id = area.id
if (
add_config_entry_id is not UNDEFINED
and add_config_entry_id not in old.config_entries
@ -403,6 +419,7 @@ class DeviceRegistry:
("entry_type", entry_type),
("via_device_id", via_device_id),
("disabled_by", disabled_by),
("suggested_area", suggested_area),
):
if value is not UNDEFINED and value != getattr(old, attr_name):
changes[attr_name] = value

View File

@ -399,6 +399,7 @@ class EntityPlatform:
"sw_version",
"entry_type",
"via_device",
"suggested_area",
):
if key in device_info:
processed_dev_info[key] = device_info[key]

View File

@ -55,7 +55,7 @@ async def test_create_area_with_name_already_in_use(hass, client, registry):
assert not msg["success"]
assert msg["error"]["code"] == "invalid_info"
assert msg["error"]["message"] == "Name is already in use"
assert msg["error"]["message"] == "The name mock (mock) is already in use"
assert len(registry.areas) == 1
@ -147,5 +147,5 @@ async def test_update_area_with_name_already_in_use(hass, client, registry):
assert not msg["success"]
assert msg["error"]["code"] == "invalid_info"
assert msg["error"]["message"] == "Name is already in use"
assert msg["error"]["message"] == "The name mock 2 (mock2) is already in use"
assert len(registry.areas) == 2

View File

@ -58,7 +58,7 @@ async def test_create_area_with_name_already_in_use(hass, registry, update_event
with pytest.raises(ValueError) as e_info:
area2 = registry.async_create("mock")
assert area1 != area2
assert e_info == "Name is already in use"
assert e_info == "The name mock 2 (mock2) is already in use"
await hass.async_block_till_done()
@ -133,6 +133,18 @@ async def test_update_area_with_same_name(registry):
assert len(registry.areas) == 1
async def test_update_area_with_same_name_change_case(registry):
"""Make sure that we can reapply the same name with a different case to the area."""
area = registry.async_create("mock")
updated_area = registry.async_update(area.id, name="Mock")
assert updated_area.name == "Mock"
assert updated_area.id == area.id
assert updated_area.normalized_name == area.normalized_name
assert len(registry.areas) == 1
async def test_update_area_with_name_already_in_use(registry):
"""Make sure that we can't update an area with a name already in use."""
area1 = registry.async_create("mock1")
@ -140,17 +152,31 @@ async def test_update_area_with_name_already_in_use(registry):
with pytest.raises(ValueError) as e_info:
registry.async_update(area1.id, name="mock2")
assert e_info == "Name is already in use"
assert e_info == "The name mock 2 (mock2) is already in use"
assert area1.name == "mock1"
assert area2.name == "mock2"
assert len(registry.areas) == 2
async def test_update_area_with_normalized_name_already_in_use(registry):
"""Make sure that we can't update an area with a normalized name already in use."""
area1 = registry.async_create("mock1")
area2 = registry.async_create("Moc k2")
with pytest.raises(ValueError) as e_info:
registry.async_update(area1.id, name="mock2")
assert e_info == "The name mock 2 (mock2) is already in use"
assert area1.name == "mock1"
assert area2.name == "Moc k2"
assert len(registry.areas) == 2
async def test_load_area(hass, registry):
"""Make sure that we can load/save data correctly."""
registry.async_create("mock1")
registry.async_create("mock2")
area1 = registry.async_create("mock1")
area2 = registry.async_create("mock2")
assert len(registry.areas) == 2
@ -160,6 +186,11 @@ async def test_load_area(hass, registry):
assert list(registry.areas) == list(registry2.areas)
area1_registry2 = registry2.async_get_or_create("mock1")
assert area1_registry2.id == area1.id
area2_registry2 = registry2.async_get_or_create("mock2")
assert area2_registry2.id == area2.id
@pytest.mark.parametrize("load_registries", [False])
async def test_loading_area_from_storage(hass, hass_storage):
@ -173,3 +204,41 @@ async def test_loading_area_from_storage(hass, hass_storage):
registry = area_registry.async_get(hass)
assert len(registry.areas) == 1
async def test_async_get_or_create(hass, registry):
"""Make sure we can get the area by name."""
area = registry.async_get_or_create("Mock1")
area2 = registry.async_get_or_create("mock1")
area3 = registry.async_get_or_create("mock 1")
assert area == area2
assert area == area3
assert area2 == area3
async def test_async_get_area_by_name(hass, registry):
"""Make sure we can get the area by name."""
registry.async_create("Mock1")
assert len(registry.areas) == 1
assert registry.async_get_area_by_name("M o c k 1").normalized_name == "mock1"
async def test_async_get_area_by_name_not_found(hass, registry):
"""Make sure we return None for non-existent areas."""
registry.async_create("Mock1")
assert len(registry.areas) == 1
assert registry.async_get_area_by_name("non_exist") is None
async def test_async_get_area(hass, registry):
"""Make sure we can get the area by id."""
area = registry.async_create("Mock1")
assert len(registry.areas) == 1
assert registry.async_get_area(area.id).normalized_name == "mock1"

View File

@ -8,7 +8,12 @@ from homeassistant.const import EVENT_HOMEASSISTANT_STARTED
from homeassistant.core import CoreState, callback
from homeassistant.helpers import device_registry, entity_registry
from tests.common import MockConfigEntry, flush_store, mock_device_registry
from tests.common import (
MockConfigEntry,
flush_store,
mock_area_registry,
mock_device_registry,
)
@pytest.fixture
@ -17,6 +22,12 @@ def registry(hass):
return mock_device_registry(hass)
@pytest.fixture
def area_registry(hass):
"""Return an empty, loaded, registry."""
return mock_area_registry(hass)
@pytest.fixture
def update_events(hass):
"""Capture update events."""
@ -31,7 +42,9 @@ def update_events(hass):
return events
async def test_get_or_create_returns_same_entry(hass, registry, update_events):
async def test_get_or_create_returns_same_entry(
hass, registry, area_registry, update_events
):
"""Make sure we do not duplicate entries."""
entry = registry.async_get_or_create(
config_entry_id="1234",
@ -41,6 +54,7 @@ async def test_get_or_create_returns_same_entry(hass, registry, update_events):
name="name",
manufacturer="manufacturer",
model="model",
suggested_area="Game Room",
)
entry2 = registry.async_get_or_create(
config_entry_id="1234",
@ -48,21 +62,31 @@ async def test_get_or_create_returns_same_entry(hass, registry, update_events):
identifiers={("bridgeid", "0123")},
manufacturer="manufacturer",
model="model",
suggested_area="Game Room",
)
entry3 = registry.async_get_or_create(
config_entry_id="1234",
connections={(device_registry.CONNECTION_NETWORK_MAC, "12:34:56:AB:CD:EF")},
)
game_room_area = area_registry.async_get_area_by_name("Game Room")
assert game_room_area is not None
assert len(area_registry.areas) == 1
assert len(registry.devices) == 1
assert entry.area_id == game_room_area.id
assert entry.id == entry2.id
assert entry.id == entry3.id
assert entry.identifiers == {("bridgeid", "0123")}
assert entry2.area_id == game_room_area.id
assert entry3.manufacturer == "manufacturer"
assert entry3.model == "model"
assert entry3.name == "name"
assert entry3.sw_version == "sw-version"
assert entry3.suggested_area == "Game Room"
assert entry3.area_id == game_room_area.id
await hass.async_block_till_done()
@ -154,6 +178,7 @@ async def test_loading_from_storage(hass, hass_storage):
"area_id": "12345A",
"name_by_user": "Test Friendly Name",
"disabled_by": "user",
"suggested_area": "Kitchen",
}
],
"deleted_devices": [
@ -444,7 +469,7 @@ async def test_specifying_via_device_update(registry):
assert light.via_device_id == via.id
async def test_loading_saving_data(hass, registry):
async def test_loading_saving_data(hass, registry, area_registry):
"""Test that we load/save data correctly."""
orig_via = registry.async_get_or_create(
config_entry_id="123",
@ -506,7 +531,18 @@ async def test_loading_saving_data(hass, registry):
assert orig_light4.id == orig_light3.id
assert len(registry.devices) == 3
orig_kitchen_light = registry.async_get_or_create(
config_entry_id="999",
connections=set(),
identifiers={("hue", "999")},
manufacturer="manufacturer",
model="light",
via_device=("hue", "0123"),
disabled_by="user",
suggested_area="Kitchen",
)
assert len(registry.devices) == 4
assert len(registry.deleted_devices) == 1
orig_via = registry.async_update_device(
@ -530,6 +566,16 @@ async def test_loading_saving_data(hass, registry):
assert orig_light == new_light
assert orig_light4 == new_light4
# Ensure a save/load cycle does not keep suggested area
new_kitchen_light = registry2.async_get_device({("hue", "999")})
assert orig_kitchen_light.suggested_area == "Kitchen"
orig_kitchen_light_witout_suggested_area = registry.async_update_device(
orig_kitchen_light.id, suggested_area=None
)
orig_kitchen_light_witout_suggested_area.suggested_area is None
assert orig_kitchen_light_witout_suggested_area == new_kitchen_light
async def test_no_unnecessary_changes(registry):
"""Make sure we do not consider devices changes."""
@ -706,6 +752,33 @@ async def test_update_sw_version(registry):
assert updated_entry.sw_version == sw_version
async def test_update_suggested_area(registry, area_registry):
"""Verify that we can update the suggested area version of a device."""
entry = registry.async_get_or_create(
config_entry_id="1234",
connections={(device_registry.CONNECTION_NETWORK_MAC, "12:34:56:AB:CD:EF")},
identifiers={("bla", "123")},
)
assert not entry.suggested_area
assert entry.area_id is None
suggested_area = "Pool"
with patch.object(registry, "async_schedule_save") as mock_save:
updated_entry = registry.async_update_device(
entry.id, suggested_area=suggested_area
)
assert mock_save.call_count == 1
assert updated_entry != entry
assert updated_entry.suggested_area == suggested_area
pool_area = area_registry.async_get_area_by_name("Pool")
assert pool_area is not None
assert updated_entry.area_id == pool_area.id
assert len(area_registry.areas) == 1
async def test_cleanup_device_registry(hass, registry):
"""Test cleanup works."""
config_entry = MockConfigEntry(domain="hue")
@ -1104,3 +1177,35 @@ async def test_get_or_create_sets_default_values(hass, registry):
assert entry.name == "default name 1"
assert entry.model == "default model 1"
assert entry.manufacturer == "default manufacturer 1"
async def test_verify_suggested_area_does_not_overwrite_area_id(
hass, registry, area_registry
):
"""Make sure suggested area does not override a set area id."""
game_room_area = area_registry.async_create("Game Room")
original_entry = registry.async_get_or_create(
config_entry_id="1234",
connections={(device_registry.CONNECTION_NETWORK_MAC, "12:34:56:AB:CD:EF")},
identifiers={("bridgeid", "0123")},
sw_version="sw-version",
name="name",
manufacturer="manufacturer",
model="model",
)
entry = registry.async_update_device(original_entry.id, area_id=game_room_area.id)
assert entry.area_id == game_room_area.id
entry2 = registry.async_get_or_create(
config_entry_id="1234",
connections={(device_registry.CONNECTION_NETWORK_MAC, "12:34:56:AB:CD:EF")},
identifiers={("bridgeid", "0123")},
sw_version="sw-version",
name="name",
manufacturer="manufacturer",
model="model",
suggested_area="New Game Room",
)
assert entry2.area_id == game_room_area.id

View File

@ -728,6 +728,7 @@ async def test_device_info_called(hass):
"model": "test-model",
"name": "test-name",
"sw_version": "test-sw",
"suggested_area": "Heliport",
"entry_type": "service",
"via_device": ("hue", "via-id"),
},
@ -755,6 +756,7 @@ async def test_device_info_called(hass):
assert device.model == "test-model"
assert device.name == "test-name"
assert device.sw_version == "test-sw"
assert device.suggested_area == "Heliport"
assert device.entry_type == "service"
assert device.via_device_id == via.id