Update new values coming in for dev registry (#16852)

* Update new values coming in for dev registry

* fix Lint+Test;2C
This commit is contained in:
Paulus Schoutsen 2018-09-27 11:26:58 +02:00 committed by GitHub
parent 29db43edb2
commit da3342f1aa
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 66 additions and 39 deletions

View File

@ -26,11 +26,12 @@ CONNECTION_ZIGBEE = 'zigbee'
class DeviceEntry: class DeviceEntry:
"""Device Registry Entry.""" """Device Registry Entry."""
config_entries = attr.ib(type=set, converter=set) config_entries = attr.ib(type=set, converter=set,
connections = attr.ib(type=set, converter=set) default=attr.Factory(set))
identifiers = attr.ib(type=set, converter=set) connections = attr.ib(type=set, converter=set, default=attr.Factory(set))
manufacturer = attr.ib(type=str) identifiers = attr.ib(type=set, converter=set, default=attr.Factory(set))
model = attr.ib(type=str) manufacturer = attr.ib(type=str, default=None)
model = attr.ib(type=str, default=None)
name = attr.ib(type=str, default=None) name = attr.ib(type=str, default=None)
sw_version = attr.ib(type=str, default=None) sw_version = attr.ib(type=str, default=None)
hub_device_id = attr.ib(type=str, default=None) hub_device_id = attr.ib(type=str, default=None)
@ -56,46 +57,53 @@ class DeviceRegistry:
return None return None
@callback @callback
def async_get_or_create(self, *, config_entry_id, connections, identifiers, def async_get_or_create(self, *, config_entry_id, connections=None,
manufacturer, model, name=None, sw_version=None, identifiers=None, manufacturer=_UNDEF,
model=_UNDEF, name=_UNDEF, sw_version=_UNDEF,
via_hub=None): via_hub=None):
"""Get device. Create if it doesn't exist.""" """Get device. Create if it doesn't exist."""
if not identifiers and not connections: if not identifiers and not connections:
return None return None
if identifiers is None:
identifiers = set()
if connections is None:
connections = set()
device = self.async_get_device(identifiers, connections) device = self.async_get_device(identifiers, connections)
if device is None:
device = DeviceEntry()
self.devices[device.id] = device
if via_hub is not None: if via_hub is not None:
hub_device = self.async_get_device({via_hub}, set()) hub_device = self.async_get_device({via_hub}, set())
hub_device_id = hub_device.id if hub_device else None hub_device_id = hub_device.id if hub_device else _UNDEF
else: else:
hub_device_id = None hub_device_id = _UNDEF
if device is not None: return self._async_update_device(
return self._async_update_device( device.id,
device.id, config_entry_id=config_entry_id, add_config_entry_id=config_entry_id,
hub_device_id=hub_device_id hub_device_id=hub_device_id,
) merge_connections=connections,
merge_identifiers=identifiers,
device = DeviceEntry(
config_entries={config_entry_id},
connections=connections,
identifiers=identifiers,
manufacturer=manufacturer, manufacturer=manufacturer,
model=model, model=model,
name=name, name=name,
sw_version=sw_version, sw_version=sw_version,
hub_device_id=hub_device_id
) )
self.devices[device.id] = device
self.async_schedule_save()
return device
@callback @callback
def _async_update_device(self, device_id, *, config_entry_id=_UNDEF, def _async_update_device(self, device_id, *, add_config_entry_id=_UNDEF,
remove_config_entry_id=_UNDEF, remove_config_entry_id=_UNDEF,
merge_connections=_UNDEF,
merge_identifiers=_UNDEF,
manufacturer=_UNDEF,
model=_UNDEF,
name=_UNDEF,
sw_version=_UNDEF,
hub_device_id=_UNDEF): hub_device_id=_UNDEF):
"""Update device attributes.""" """Update device attributes."""
old = self.devices[device_id] old = self.devices[device_id]
@ -104,21 +112,34 @@ class DeviceRegistry:
config_entries = old.config_entries config_entries = old.config_entries
if (config_entry_id is not _UNDEF and if (add_config_entry_id is not _UNDEF and
config_entry_id not in old.config_entries): add_config_entry_id not in old.config_entries):
config_entries = old.config_entries | {config_entry_id} config_entries = old.config_entries | {add_config_entry_id}
if (remove_config_entry_id is not _UNDEF and if (remove_config_entry_id is not _UNDEF and
remove_config_entry_id in config_entries): remove_config_entry_id in config_entries):
config_entries = set(config_entries) config_entries = config_entries - {remove_config_entry_id}
config_entries.remove(remove_config_entry_id)
if config_entries is not old.config_entries: if config_entries is not old.config_entries:
changes['config_entries'] = config_entries changes['config_entries'] = config_entries
if (hub_device_id is not _UNDEF and for attr_name, value in (
hub_device_id != old.hub_device_id): ('connections', merge_connections),
changes['hub_device_id'] = hub_device_id ('identifiers', merge_identifiers),
):
old_value = getattr(old, attr_name)
if value is not _UNDEF and value != old_value:
changes[attr_name] = old_value | value
for attr_name, value in (
('manufacturer', manufacturer),
('model', model),
('name', name),
('sw_version', sw_version),
('hub_device_id', hub_device_id),
):
if value is not _UNDEF and value != getattr(old, attr_name):
changes[attr_name] = value
if not changes: if not changes:
return old return old

View File

@ -27,7 +27,6 @@ async def test_list_devices(hass, client, registry):
manufacturer='manufacturer', model='model') manufacturer='manufacturer', model='model')
registry.async_get_or_create( registry.async_get_or_create(
config_entry_id='1234', config_entry_id='1234',
connections={},
identifiers={('bridgeid', '1234')}, identifiers={('bridgeid', '1234')},
manufacturer='manufacturer', model='model', manufacturer='manufacturer', model='model',
via_hub=('bridgeid', '0123')) via_hub=('bridgeid', '0123'))

View File

@ -17,7 +17,10 @@ async def test_get_or_create_returns_same_entry(registry):
config_entry_id='1234', config_entry_id='1234',
connections={('ethernet', '12:34:56:78:90:AB:CD:EF')}, connections={('ethernet', '12:34:56:78:90:AB:CD:EF')},
identifiers={('bridgeid', '0123')}, identifiers={('bridgeid', '0123')},
manufacturer='manufacturer', model='model') sw_version='sw-version',
name='name',
manufacturer='manufacturer',
model='model')
entry2 = registry.async_get_or_create( entry2 = registry.async_get_or_create(
config_entry_id='1234', config_entry_id='1234',
connections={('ethernet', '11:22:33:44:55:66:77:88')}, connections={('ethernet', '11:22:33:44:55:66:77:88')},
@ -25,15 +28,19 @@ async def test_get_or_create_returns_same_entry(registry):
manufacturer='manufacturer', model='model') manufacturer='manufacturer', model='model')
entry3 = registry.async_get_or_create( entry3 = registry.async_get_or_create(
config_entry_id='1234', config_entry_id='1234',
connections={('ethernet', '12:34:56:78:90:AB:CD:EF')}, connections={('ethernet', '12:34:56:78:90:AB:CD:EF')}
identifiers={('bridgeid', '1234')}, )
manufacturer='manufacturer', model='model')
assert len(registry.devices) == 1 assert len(registry.devices) == 1
assert entry.id == entry2.id assert entry.id == entry2.id
assert entry.id == entry3.id assert entry.id == entry3.id
assert entry.identifiers == {('bridgeid', '0123')} assert entry.identifiers == {('bridgeid', '0123')}
assert entry3.manufacturer == 'manufacturer'
assert entry3.model == 'model'
assert entry3.name == 'name'
assert entry3.sw_version == 'sw-version'
async def test_requirement_for_identifier_or_connection(registry): async def test_requirement_for_identifier_or_connection(registry):
"""Make sure we do require some descriptor of device.""" """Make sure we do require some descriptor of device."""