diff --git a/homeassistant/helpers/area_registry.py b/homeassistant/helpers/area_registry.py index 562e832cc19..164207a8b2a 100644 --- a/homeassistant/helpers/area_registry.py +++ b/homeassistant/helpers/area_registry.py @@ -25,6 +25,7 @@ class AreaEntry: """Area Registry Entry.""" name: str = attr.ib() + normalized_name: str = attr.ib() id: Optional[str] = attr.ib(default=None) def generate_id(self, existing_ids: Container[str]) -> None: @@ -45,27 +46,47 @@ class AreaRegistry: self.hass = hass self.areas: MutableMapping[str, AreaEntry] = {} self._store = hass.helpers.storage.Store(STORAGE_VERSION, STORAGE_KEY) + self._normalized_name_area_idx: Dict[str, str] = {} @callback def async_get_area(self, area_id: str) -> Optional[AreaEntry]: - """Get all areas.""" + """Get area by id.""" return self.areas.get(area_id) + @callback + def async_get_area_by_name(self, name: str) -> Optional[AreaEntry]: + """Get area by name.""" + normalized_name = normalize_area_name(name) + if normalized_name not in self._normalized_name_area_idx: + return None + return self.areas[self._normalized_name_area_idx[normalized_name]] + @callback def async_list_areas(self) -> Iterable[AreaEntry]: """Get all areas.""" return self.areas.values() + @callback + def async_get_or_create(self, name: str) -> AreaEntry: + """Get or create an area.""" + area = self.async_get_area_by_name(name) + if area: + return area + return self.async_create(name) + @callback def async_create(self, name: str) -> AreaEntry: """Create a new area.""" - if self._async_is_registered(name): - raise ValueError("Name is already in use") + normalized_name = normalize_area_name(name) - area = AreaEntry(name=name) + if self.async_get_area_by_name(name): + raise ValueError(f"The name {name} ({normalized_name}) is already in use") + + area = AreaEntry(name=name, normalized_name=normalized_name) area.generate_id(self.areas) assert area.id is not None self.areas[area.id] = area + self._normalized_name_area_idx[normalized_name] = area.id self.async_schedule_save() self.hass.bus.async_fire( EVENT_AREA_REGISTRY_UPDATED, {"action": "create", "area_id": area.id} @@ -75,12 +96,14 @@ class AreaRegistry: @callback def async_delete(self, area_id: str) -> None: """Delete area.""" + area = self.areas[area_id] device_registry = dr.async_get(self.hass) entity_registry = er.async_get(self.hass) device_registry.async_clear_area_id(area_id) entity_registry.async_clear_area_id(area_id) del self.areas[area_id] + del self._normalized_name_area_idx[area.normalized_name] self.hass.bus.async_fire( EVENT_AREA_REGISTRY_UPDATED, {"action": "remove", "area_id": area_id} @@ -107,23 +130,25 @@ class AreaRegistry: if name == old.name: return old - if self._async_is_registered(name): - raise ValueError("Name is already in use") + normalized_name = normalize_area_name(name) + + if normalized_name != old.normalized_name: + if self.async_get_area_by_name(name): + raise ValueError( + f"The name {name} ({normalized_name}) is already in use" + ) changes["name"] = name + changes["normalized_name"] = normalized_name new = self.areas[area_id] = attr.evolve(old, **changes) + self._normalized_name_area_idx[ + normalized_name + ] = self._normalized_name_area_idx.pop(old.normalized_name) + self.async_schedule_save() return new - @callback - def _async_is_registered(self, name: str) -> Optional[AreaEntry]: - """Check if a name is currently registered.""" - for area in self.areas.values(): - if name == area.name: - return area - return None - async def async_load(self) -> None: """Load the area registry.""" data = await self._store.async_load() @@ -132,7 +157,11 @@ class AreaRegistry: if data is not None: for area in data["areas"]: - areas[area["id"]] = AreaEntry(name=area["name"], id=area["id"]) + normalized_name = normalize_area_name(area["name"]) + areas[area["id"]] = AreaEntry( + name=area["name"], id=area["id"], normalized_name=normalized_name + ) + self._normalized_name_area_idx[normalized_name] = area["id"] self.areas = areas @@ -147,7 +176,11 @@ class AreaRegistry: data = {} data["areas"] = [ - {"name": entry.name, "id": entry.id} for entry in self.areas.values() + { + "name": entry.name, + "id": entry.id, + } + for entry in self.areas.values() ] return data @@ -173,3 +206,8 @@ async def async_get_registry(hass: HomeAssistantType) -> AreaRegistry: This is deprecated and will be removed in the future. Use async_get instead. """ return async_get(hass) + + +def normalize_area_name(area_name: str) -> str: + """Normalize an area name by removing whitespace and case folding.""" + return area_name.casefold().replace(" ", "") diff --git a/homeassistant/helpers/device_registry.py b/homeassistant/helpers/device_registry.py index 77dc2cdf609..3dd44364604 100644 --- a/homeassistant/helpers/device_registry.py +++ b/homeassistant/helpers/device_registry.py @@ -71,6 +71,7 @@ class DeviceEntry: ) ), ) + suggested_area: Optional[str] = attr.ib(default=None) @property def disabled(self) -> bool: @@ -251,6 +252,7 @@ class DeviceRegistry: via_device: Optional[Tuple[str, str]] = None, # To disable a device if it gets created disabled_by: Union[str, None, UndefinedType] = UNDEFINED, + suggested_area: Union[str, None, UndefinedType] = UNDEFINED, ) -> Optional[DeviceEntry]: """Get device. Create if it doesn't exist.""" if not identifiers and not connections: @@ -304,6 +306,7 @@ class DeviceRegistry: sw_version=sw_version, entry_type=entry_type, disabled_by=disabled_by, + suggested_area=suggested_area, ) @callback @@ -321,6 +324,7 @@ class DeviceRegistry: via_device_id: Union[str, None, UndefinedType] = UNDEFINED, remove_config_entry_id: Union[str, UndefinedType] = UNDEFINED, disabled_by: Union[str, None, UndefinedType] = UNDEFINED, + suggested_area: Union[str, None, UndefinedType] = UNDEFINED, ) -> Optional[DeviceEntry]: """Update properties of a device.""" return self._async_update_device( @@ -335,6 +339,7 @@ class DeviceRegistry: via_device_id=via_device_id, remove_config_entry_id=remove_config_entry_id, disabled_by=disabled_by, + suggested_area=suggested_area, ) @callback @@ -356,6 +361,7 @@ class DeviceRegistry: area_id: Union[str, None, UndefinedType] = UNDEFINED, name_by_user: Union[str, None, UndefinedType] = UNDEFINED, disabled_by: Union[str, None, UndefinedType] = UNDEFINED, + suggested_area: Union[str, None, UndefinedType] = UNDEFINED, ) -> Optional[DeviceEntry]: """Update device attributes.""" old = self.devices[device_id] @@ -364,6 +370,16 @@ class DeviceRegistry: config_entries = old.config_entries + if ( + suggested_area not in (UNDEFINED, None, "") + and area_id is UNDEFINED + and old.area_id is None + ): + area = self.hass.helpers.area_registry.async_get( + self.hass + ).async_get_or_create(suggested_area) + area_id = area.id + if ( add_config_entry_id is not UNDEFINED and add_config_entry_id not in old.config_entries @@ -403,6 +419,7 @@ class DeviceRegistry: ("entry_type", entry_type), ("via_device_id", via_device_id), ("disabled_by", disabled_by), + ("suggested_area", suggested_area), ): if value is not UNDEFINED and value != getattr(old, attr_name): changes[attr_name] = value diff --git a/homeassistant/helpers/entity_platform.py b/homeassistant/helpers/entity_platform.py index 509508405a4..2caf7fe46ab 100644 --- a/homeassistant/helpers/entity_platform.py +++ b/homeassistant/helpers/entity_platform.py @@ -399,6 +399,7 @@ class EntityPlatform: "sw_version", "entry_type", "via_device", + "suggested_area", ): if key in device_info: processed_dev_info[key] = device_info[key] diff --git a/tests/components/config/test_area_registry.py b/tests/components/config/test_area_registry.py index f66e16e606f..35176cc79f9 100644 --- a/tests/components/config/test_area_registry.py +++ b/tests/components/config/test_area_registry.py @@ -55,7 +55,7 @@ async def test_create_area_with_name_already_in_use(hass, client, registry): assert not msg["success"] assert msg["error"]["code"] == "invalid_info" - assert msg["error"]["message"] == "Name is already in use" + assert msg["error"]["message"] == "The name mock (mock) is already in use" assert len(registry.areas) == 1 @@ -147,5 +147,5 @@ async def test_update_area_with_name_already_in_use(hass, client, registry): assert not msg["success"] assert msg["error"]["code"] == "invalid_info" - assert msg["error"]["message"] == "Name is already in use" + assert msg["error"]["message"] == "The name mock 2 (mock2) is already in use" assert len(registry.areas) == 2 diff --git a/tests/helpers/test_area_registry.py b/tests/helpers/test_area_registry.py index 0bfa5e597d2..7dca029987e 100644 --- a/tests/helpers/test_area_registry.py +++ b/tests/helpers/test_area_registry.py @@ -58,7 +58,7 @@ async def test_create_area_with_name_already_in_use(hass, registry, update_event with pytest.raises(ValueError) as e_info: area2 = registry.async_create("mock") assert area1 != area2 - assert e_info == "Name is already in use" + assert e_info == "The name mock 2 (mock2) is already in use" await hass.async_block_till_done() @@ -133,6 +133,18 @@ async def test_update_area_with_same_name(registry): assert len(registry.areas) == 1 +async def test_update_area_with_same_name_change_case(registry): + """Make sure that we can reapply the same name with a different case to the area.""" + area = registry.async_create("mock") + + updated_area = registry.async_update(area.id, name="Mock") + + assert updated_area.name == "Mock" + assert updated_area.id == area.id + assert updated_area.normalized_name == area.normalized_name + assert len(registry.areas) == 1 + + async def test_update_area_with_name_already_in_use(registry): """Make sure that we can't update an area with a name already in use.""" area1 = registry.async_create("mock1") @@ -140,17 +152,31 @@ async def test_update_area_with_name_already_in_use(registry): with pytest.raises(ValueError) as e_info: registry.async_update(area1.id, name="mock2") - assert e_info == "Name is already in use" + assert e_info == "The name mock 2 (mock2) is already in use" assert area1.name == "mock1" assert area2.name == "mock2" assert len(registry.areas) == 2 +async def test_update_area_with_normalized_name_already_in_use(registry): + """Make sure that we can't update an area with a normalized name already in use.""" + area1 = registry.async_create("mock1") + area2 = registry.async_create("Moc k2") + + with pytest.raises(ValueError) as e_info: + registry.async_update(area1.id, name="mock2") + assert e_info == "The name mock 2 (mock2) is already in use" + + assert area1.name == "mock1" + assert area2.name == "Moc k2" + assert len(registry.areas) == 2 + + async def test_load_area(hass, registry): """Make sure that we can load/save data correctly.""" - registry.async_create("mock1") - registry.async_create("mock2") + area1 = registry.async_create("mock1") + area2 = registry.async_create("mock2") assert len(registry.areas) == 2 @@ -160,6 +186,11 @@ async def test_load_area(hass, registry): assert list(registry.areas) == list(registry2.areas) + area1_registry2 = registry2.async_get_or_create("mock1") + assert area1_registry2.id == area1.id + area2_registry2 = registry2.async_get_or_create("mock2") + assert area2_registry2.id == area2.id + @pytest.mark.parametrize("load_registries", [False]) async def test_loading_area_from_storage(hass, hass_storage): @@ -173,3 +204,41 @@ async def test_loading_area_from_storage(hass, hass_storage): registry = area_registry.async_get(hass) assert len(registry.areas) == 1 + + +async def test_async_get_or_create(hass, registry): + """Make sure we can get the area by name.""" + area = registry.async_get_or_create("Mock1") + area2 = registry.async_get_or_create("mock1") + area3 = registry.async_get_or_create("mock 1") + + assert area == area2 + assert area == area3 + assert area2 == area3 + + +async def test_async_get_area_by_name(hass, registry): + """Make sure we can get the area by name.""" + registry.async_create("Mock1") + + assert len(registry.areas) == 1 + + assert registry.async_get_area_by_name("M o c k 1").normalized_name == "mock1" + + +async def test_async_get_area_by_name_not_found(hass, registry): + """Make sure we return None for non-existent areas.""" + registry.async_create("Mock1") + + assert len(registry.areas) == 1 + + assert registry.async_get_area_by_name("non_exist") is None + + +async def test_async_get_area(hass, registry): + """Make sure we can get the area by id.""" + area = registry.async_create("Mock1") + + assert len(registry.areas) == 1 + + assert registry.async_get_area(area.id).normalized_name == "mock1" diff --git a/tests/helpers/test_device_registry.py b/tests/helpers/test_device_registry.py index a128f8aa390..bc0e1c3bec9 100644 --- a/tests/helpers/test_device_registry.py +++ b/tests/helpers/test_device_registry.py @@ -8,7 +8,12 @@ from homeassistant.const import EVENT_HOMEASSISTANT_STARTED from homeassistant.core import CoreState, callback from homeassistant.helpers import device_registry, entity_registry -from tests.common import MockConfigEntry, flush_store, mock_device_registry +from tests.common import ( + MockConfigEntry, + flush_store, + mock_area_registry, + mock_device_registry, +) @pytest.fixture @@ -17,6 +22,12 @@ def registry(hass): return mock_device_registry(hass) +@pytest.fixture +def area_registry(hass): + """Return an empty, loaded, registry.""" + return mock_area_registry(hass) + + @pytest.fixture def update_events(hass): """Capture update events.""" @@ -31,7 +42,9 @@ def update_events(hass): return events -async def test_get_or_create_returns_same_entry(hass, registry, update_events): +async def test_get_or_create_returns_same_entry( + hass, registry, area_registry, update_events +): """Make sure we do not duplicate entries.""" entry = registry.async_get_or_create( config_entry_id="1234", @@ -41,6 +54,7 @@ async def test_get_or_create_returns_same_entry(hass, registry, update_events): name="name", manufacturer="manufacturer", model="model", + suggested_area="Game Room", ) entry2 = registry.async_get_or_create( config_entry_id="1234", @@ -48,21 +62,31 @@ async def test_get_or_create_returns_same_entry(hass, registry, update_events): identifiers={("bridgeid", "0123")}, manufacturer="manufacturer", model="model", + suggested_area="Game Room", ) entry3 = registry.async_get_or_create( config_entry_id="1234", connections={(device_registry.CONNECTION_NETWORK_MAC, "12:34:56:AB:CD:EF")}, ) + game_room_area = area_registry.async_get_area_by_name("Game Room") + assert game_room_area is not None + assert len(area_registry.areas) == 1 + assert len(registry.devices) == 1 + assert entry.area_id == game_room_area.id assert entry.id == entry2.id assert entry.id == entry3.id assert entry.identifiers == {("bridgeid", "0123")} + assert entry2.area_id == game_room_area.id + assert entry3.manufacturer == "manufacturer" assert entry3.model == "model" assert entry3.name == "name" assert entry3.sw_version == "sw-version" + assert entry3.suggested_area == "Game Room" + assert entry3.area_id == game_room_area.id await hass.async_block_till_done() @@ -154,6 +178,7 @@ async def test_loading_from_storage(hass, hass_storage): "area_id": "12345A", "name_by_user": "Test Friendly Name", "disabled_by": "user", + "suggested_area": "Kitchen", } ], "deleted_devices": [ @@ -444,7 +469,7 @@ async def test_specifying_via_device_update(registry): assert light.via_device_id == via.id -async def test_loading_saving_data(hass, registry): +async def test_loading_saving_data(hass, registry, area_registry): """Test that we load/save data correctly.""" orig_via = registry.async_get_or_create( config_entry_id="123", @@ -506,7 +531,18 @@ async def test_loading_saving_data(hass, registry): assert orig_light4.id == orig_light3.id - assert len(registry.devices) == 3 + orig_kitchen_light = registry.async_get_or_create( + config_entry_id="999", + connections=set(), + identifiers={("hue", "999")}, + manufacturer="manufacturer", + model="light", + via_device=("hue", "0123"), + disabled_by="user", + suggested_area="Kitchen", + ) + + assert len(registry.devices) == 4 assert len(registry.deleted_devices) == 1 orig_via = registry.async_update_device( @@ -530,6 +566,16 @@ async def test_loading_saving_data(hass, registry): assert orig_light == new_light assert orig_light4 == new_light4 + # Ensure a save/load cycle does not keep suggested area + new_kitchen_light = registry2.async_get_device({("hue", "999")}) + assert orig_kitchen_light.suggested_area == "Kitchen" + + orig_kitchen_light_witout_suggested_area = registry.async_update_device( + orig_kitchen_light.id, suggested_area=None + ) + orig_kitchen_light_witout_suggested_area.suggested_area is None + assert orig_kitchen_light_witout_suggested_area == new_kitchen_light + async def test_no_unnecessary_changes(registry): """Make sure we do not consider devices changes.""" @@ -706,6 +752,33 @@ async def test_update_sw_version(registry): assert updated_entry.sw_version == sw_version +async def test_update_suggested_area(registry, area_registry): + """Verify that we can update the suggested area version of a device.""" + entry = registry.async_get_or_create( + config_entry_id="1234", + connections={(device_registry.CONNECTION_NETWORK_MAC, "12:34:56:AB:CD:EF")}, + identifiers={("bla", "123")}, + ) + assert not entry.suggested_area + assert entry.area_id is None + + suggested_area = "Pool" + + with patch.object(registry, "async_schedule_save") as mock_save: + updated_entry = registry.async_update_device( + entry.id, suggested_area=suggested_area + ) + + assert mock_save.call_count == 1 + assert updated_entry != entry + assert updated_entry.suggested_area == suggested_area + + pool_area = area_registry.async_get_area_by_name("Pool") + assert pool_area is not None + assert updated_entry.area_id == pool_area.id + assert len(area_registry.areas) == 1 + + async def test_cleanup_device_registry(hass, registry): """Test cleanup works.""" config_entry = MockConfigEntry(domain="hue") @@ -1104,3 +1177,35 @@ async def test_get_or_create_sets_default_values(hass, registry): assert entry.name == "default name 1" assert entry.model == "default model 1" assert entry.manufacturer == "default manufacturer 1" + + +async def test_verify_suggested_area_does_not_overwrite_area_id( + hass, registry, area_registry +): + """Make sure suggested area does not override a set area id.""" + game_room_area = area_registry.async_create("Game Room") + + original_entry = registry.async_get_or_create( + config_entry_id="1234", + connections={(device_registry.CONNECTION_NETWORK_MAC, "12:34:56:AB:CD:EF")}, + identifiers={("bridgeid", "0123")}, + sw_version="sw-version", + name="name", + manufacturer="manufacturer", + model="model", + ) + entry = registry.async_update_device(original_entry.id, area_id=game_room_area.id) + + assert entry.area_id == game_room_area.id + + entry2 = registry.async_get_or_create( + config_entry_id="1234", + connections={(device_registry.CONNECTION_NETWORK_MAC, "12:34:56:AB:CD:EF")}, + identifiers={("bridgeid", "0123")}, + sw_version="sw-version", + name="name", + manufacturer="manufacturer", + model="model", + suggested_area="New Game Room", + ) + assert entry2.area_id == game_room_area.id diff --git a/tests/helpers/test_entity_platform.py b/tests/helpers/test_entity_platform.py index 0a939ba2825..ab3e04843f9 100644 --- a/tests/helpers/test_entity_platform.py +++ b/tests/helpers/test_entity_platform.py @@ -728,6 +728,7 @@ async def test_device_info_called(hass): "model": "test-model", "name": "test-name", "sw_version": "test-sw", + "suggested_area": "Heliport", "entry_type": "service", "via_device": ("hue", "via-id"), }, @@ -755,6 +756,7 @@ async def test_device_info_called(hass): assert device.model == "test-model" assert device.name == "test-name" assert device.sw_version == "test-sw" + assert device.suggested_area == "Heliport" assert device.entry_type == "service" assert device.via_device_id == via.id