Make yeelight aware of the network integration (#54854)

This commit is contained in:
J. Nick Koston 2021-08-20 19:09:22 -05:00 committed by GitHub
parent 1325b38256
commit b71f2689d7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 283 additions and 37 deletions

View File

@ -4,6 +4,7 @@ from __future__ import annotations
import asyncio import asyncio
import contextlib import contextlib
from datetime import timedelta from datetime import timedelta
from ipaddress import IPv4Address, IPv6Address
import logging import logging
from urllib.parse import urlparse from urllib.parse import urlparse
@ -13,6 +14,7 @@ from yeelight import BulbException
from yeelight.aio import KEY_CONNECTED, AsyncBulb from yeelight.aio import KEY_CONNECTED, AsyncBulb
from homeassistant import config_entries from homeassistant import config_entries
from homeassistant.components import network
from homeassistant.config_entries import SOURCE_IMPORT, ConfigEntry, ConfigEntryNotReady from homeassistant.config_entries import SOURCE_IMPORT, ConfigEntry, ConfigEntryNotReady
from homeassistant.const import ( from homeassistant.const import (
CONF_DEVICES, CONF_DEVICES,
@ -269,13 +271,8 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
raise ConfigEntryNotReady from ex raise ConfigEntryNotReady from ex
# Otherwise fall through to discovery # Otherwise fall through to discovery
else: else:
# manually added device # Since device is passed this cannot throw an exception anymore
try: await _async_initialize(hass, entry, entry.data[CONF_HOST], device=device)
await _async_initialize(
hass, entry, entry.data[CONF_HOST], device=device
)
except BulbException as ex:
raise ConfigEntryNotReady from ex
return True return True
async def _async_from_discovery(capabilities: dict[str, str]) -> None: async def _async_from_discovery(capabilities: dict[str, str]) -> None:
@ -367,34 +364,77 @@ class YeelightScanner:
self._unique_id_capabilities = {} self._unique_id_capabilities = {}
self._host_capabilities = {} self._host_capabilities = {}
self._track_interval = None self._track_interval = None
self._listener = None self._listeners = []
self._connected_event = None self._connected_events = []
async def async_setup(self): async def async_setup(self):
"""Set up the scanner.""" """Set up the scanner."""
if self._connected_event: if self._connected_events:
await self._connected_event.wait() await asyncio.gather(*(event.wait() for event in self._connected_events))
return return
self._connected_event = asyncio.Event()
for idx, source_ip in enumerate(await self._async_build_source_set()):
self._connected_events.append(asyncio.Event())
def _wrap_async_connected_idx(idx):
"""Create a function to capture the idx cell variable."""
async def _async_connected(): async def _async_connected():
self._listener.async_search() self._connected_events[idx].set()
self._connected_event.set()
self._listener = SSDPListener( return _async_connected
self._listeners.append(
SSDPListener(
async_callback=self._async_process_entry, async_callback=self._async_process_entry,
service_type=SSDP_ST, service_type=SSDP_ST,
target=SSDP_TARGET, target=SSDP_TARGET,
async_connect_callback=_async_connected, source_ip=source_ip,
async_connect_callback=_wrap_async_connected_idx(idx),
) )
await self._listener.async_start() )
await self._connected_event.wait()
results = await asyncio.gather(
*(listener.async_start() for listener in self._listeners),
return_exceptions=True,
)
failed_listeners = []
for idx, result in enumerate(results):
if not isinstance(result, Exception):
continue
_LOGGER.warning(
"Failed to setup listener for %s: %s",
self._listeners[idx].source_ip,
result,
)
failed_listeners.append(self._listeners[idx])
self._connected_events[idx].set()
for listener in failed_listeners:
self._listeners.remove(listener)
await asyncio.gather(*(event.wait() for event in self._connected_events))
self.async_scan()
async def _async_build_source_set(self) -> set[IPv4Address]:
"""Build the list of ssdp sources."""
adapters = await network.async_get_adapters(self._hass)
sources: set[IPv4Address] = set()
if network.async_only_default_interface_enabled(adapters):
sources.add(IPv4Address("0.0.0.0"))
return sources
return {
source_ip
for source_ip in await network.async_get_enabled_source_ips(self._hass)
if not source_ip.is_loopback and not isinstance(source_ip, IPv6Address)
}
async def async_discover(self): async def async_discover(self):
"""Discover bulbs.""" """Discover bulbs."""
await self.async_setup() await self.async_setup()
for _ in range(DISCOVERY_ATTEMPTS): for _ in range(DISCOVERY_ATTEMPTS):
self._listener.async_search() self.async_scan()
await asyncio.sleep(DISCOVERY_SEARCH_INTERVAL.total_seconds()) await asyncio.sleep(DISCOVERY_SEARCH_INTERVAL.total_seconds())
return self._unique_id_capabilities.values() return self._unique_id_capabilities.values()
@ -402,7 +442,8 @@ class YeelightScanner:
def async_scan(self, *_): def async_scan(self, *_):
"""Send discovery packets.""" """Send discovery packets."""
_LOGGER.debug("Yeelight scanning") _LOGGER.debug("Yeelight scanning")
self._listener.async_search() for listener in self._listeners:
listener.async_search()
async def async_get_capabilities(self, host): async def async_get_capabilities(self, host):
"""Get capabilities via SSDP.""" """Get capabilities via SSDP."""
@ -413,7 +454,8 @@ class YeelightScanner:
self._host_discovered_events.setdefault(host, []).append(host_event) self._host_discovered_events.setdefault(host, []).append(host_event)
await self.async_setup() await self.async_setup()
self._listener.async_search((host, SSDP_TARGET[1])) for listener in self._listeners:
listener.async_search((host, SSDP_TARGET[1]))
with contextlib.suppress(asyncio.TimeoutError): with contextlib.suppress(asyncio.TimeoutError):
await asyncio.wait_for(host_event.wait(), timeout=DISCOVERY_TIMEOUT) await asyncio.wait_for(host_event.wait(), timeout=DISCOVERY_TIMEOUT)

View File

@ -5,6 +5,7 @@
"requirements": ["yeelight==0.7.2", "async-upnp-client==0.20.0"], "requirements": ["yeelight==0.7.2", "async-upnp-client==0.20.0"],
"codeowners": ["@rytilahti", "@zewelor", "@shenxn", "@starkillerOG"], "codeowners": ["@rytilahti", "@zewelor", "@shenxn", "@starkillerOG"],
"config_flow": true, "config_flow": true,
"dependencies": ["network"],
"quality_scale": "platinum", "quality_scale": "platinum",
"iot_class": "local_push", "iot_class": "local_push",
"dhcp": [{ "dhcp": [{

View File

@ -1,6 +1,7 @@
"""Tests for the Yeelight integration.""" """Tests for the Yeelight integration."""
import asyncio import asyncio
from datetime import timedelta from datetime import timedelta
from ipaddress import IPv4Address
from unittest.mock import AsyncMock, MagicMock, patch from unittest.mock import AsyncMock, MagicMock, patch
from async_upnp_client.search import SSDPListener from async_upnp_client.search import SSDPListener
@ -19,6 +20,8 @@ from homeassistant.components.yeelight import (
from homeassistant.const import CONF_DEVICES, CONF_ID, CONF_NAME from homeassistant.const import CONF_DEVICES, CONF_ID, CONF_NAME
from homeassistant.core import callback from homeassistant.core import callback
FAIL_TO_BIND_IP = "1.2.3.4"
IP_ADDRESS = "192.168.1.239" IP_ADDRESS = "192.168.1.239"
MODEL = "color" MODEL = "color"
ID = "0x000000000015243f" ID = "0x000000000015243f"
@ -127,6 +130,8 @@ def _patched_ssdp_listener(info, *args, **kwargs):
listener = SSDPListener(*args, **kwargs) listener = SSDPListener(*args, **kwargs)
async def _async_callback(*_): async def _async_callback(*_):
if kwargs["source_ip"] == IPv4Address(FAIL_TO_BIND_IP):
raise OSError
await listener.async_connect_callback() await listener.async_connect_callback()
@callback @callback
@ -139,12 +144,12 @@ def _patched_ssdp_listener(info, *args, **kwargs):
return listener return listener
def _patch_discovery(no_device=False): def _patch_discovery(no_device=False, capabilities=None):
YeelightScanner._scanner = None # Clear class scanner to reset hass YeelightScanner._scanner = None # Clear class scanner to reset hass
def _generate_fake_ssdp_listener(*args, **kwargs): def _generate_fake_ssdp_listener(*args, **kwargs):
return _patched_ssdp_listener( return _patched_ssdp_listener(
None if no_device else CAPABILITIES, None if no_device else capabilities or CAPABILITIES,
*args, *args,
**kwargs, **kwargs,
) )

View File

@ -51,6 +51,20 @@ DEFAULT_CONFIG = {
async def test_discovery(hass: HomeAssistant): async def test_discovery(hass: HomeAssistant):
"""Test setting up discovery.""" """Test setting up discovery."""
with _patch_discovery(), _patch_discovery_interval():
result = await hass.config_entries.flow.async_init(
DOMAIN, context={"source": config_entries.SOURCE_USER}
)
assert result["type"] == "form"
assert result["step_id"] == "user"
assert not result["errors"]
result2 = await hass.config_entries.flow.async_configure(result["flow_id"], {})
assert result2["type"] == "form"
assert result2["step_id"] == "pick_device"
assert not result2["errors"]
# test we can try again
result = await hass.config_entries.flow.async_init( result = await hass.config_entries.flow.async_init(
DOMAIN, context={"source": config_entries.SOURCE_USER} DOMAIN, context={"source": config_entries.SOURCE_USER}
) )
@ -58,7 +72,6 @@ async def test_discovery(hass: HomeAssistant):
assert result["step_id"] == "user" assert result["step_id"] == "user"
assert not result["errors"] assert not result["errors"]
with _patch_discovery(), _patch_discovery_interval():
result2 = await hass.config_entries.flow.async_configure(result["flow_id"], {}) result2 = await hass.config_entries.flow.async_configure(result["flow_id"], {})
assert result2["type"] == "form" assert result2["type"] == "form"
assert result2["step_id"] == "pick_device" assert result2["step_id"] == "pick_device"
@ -93,6 +106,78 @@ async def test_discovery(hass: HomeAssistant):
assert result2["reason"] == "no_devices_found" assert result2["reason"] == "no_devices_found"
async def test_discovery_with_existing_device_present(hass: HomeAssistant):
"""Test setting up discovery."""
config_entry = MockConfigEntry(
domain=DOMAIN, data={CONF_ID: "0x000000000099999", CONF_HOST: "4.4.4.4"}
)
config_entry.add_to_hass(hass)
alternate_bulb = _mocked_bulb()
alternate_bulb.capabilities["id"] = "0x000000000099999"
alternate_bulb.capabilities["location"] = "yeelight://4.4.4.4"
with _patch_discovery(), patch(f"{MODULE}.AsyncBulb", return_value=alternate_bulb):
await hass.config_entries.async_setup(config_entry.entry_id)
await hass.async_block_till_done()
await hass.async_block_till_done()
result = await hass.config_entries.flow.async_init(
DOMAIN, context={"source": config_entries.SOURCE_USER}
)
assert result["type"] == "form"
assert result["step_id"] == "user"
assert not result["errors"]
with _patch_discovery(), _patch_discovery_interval():
result2 = await hass.config_entries.flow.async_configure(result["flow_id"], {})
await hass.async_block_till_done()
await hass.async_block_till_done()
assert result2["type"] == "form"
assert result2["step_id"] == "pick_device"
assert not result2["errors"]
# Now abort and make sure we can start over
result = await hass.config_entries.flow.async_init(
DOMAIN, context={"source": config_entries.SOURCE_USER}
)
assert result["type"] == "form"
assert result["step_id"] == "user"
assert not result["errors"]
with _patch_discovery(), _patch_discovery_interval():
result2 = await hass.config_entries.flow.async_configure(result["flow_id"], {})
assert result2["type"] == "form"
assert result2["step_id"] == "pick_device"
assert not result2["errors"]
with _patch_discovery(), _patch_discovery_interval(), patch(
f"{MODULE}.AsyncBulb", return_value=_mocked_bulb()
):
result3 = await hass.config_entries.flow.async_configure(
result["flow_id"], {CONF_DEVICE: ID}
)
assert result3["type"] == "create_entry"
assert result3["title"] == UNIQUE_FRIENDLY_NAME
assert result3["data"] == {CONF_ID: ID, CONF_HOST: IP_ADDRESS}
await hass.async_block_till_done()
await hass.async_block_till_done()
# ignore configured devices
result = await hass.config_entries.flow.async_init(
DOMAIN, context={"source": config_entries.SOURCE_USER}
)
assert result["type"] == "form"
assert result["step_id"] == "user"
assert not result["errors"]
with _patch_discovery(), _patch_discovery_interval():
result2 = await hass.config_entries.flow.async_configure(result["flow_id"], {})
assert result2["type"] == "abort"
assert result2["reason"] == "no_devices_found"
async def test_discovery_no_device(hass: HomeAssistant): async def test_discovery_no_device(hass: HomeAssistant):
"""Test discovery without device.""" """Test discovery without device."""
result = await hass.config_entries.flow.async_init( result = await hass.config_entries.flow.async_init(

View File

@ -32,6 +32,7 @@ from . import (
ENTITY_BINARY_SENSOR_TEMPLATE, ENTITY_BINARY_SENSOR_TEMPLATE,
ENTITY_LIGHT, ENTITY_LIGHT,
ENTITY_NIGHTLIGHT, ENTITY_NIGHTLIGHT,
FAIL_TO_BIND_IP,
ID, ID,
IP_ADDRESS, IP_ADDRESS,
MODULE, MODULE,
@ -131,6 +132,107 @@ async def test_setup_discovery(hass: HomeAssistant):
assert hass.states.get(ENTITY_LIGHT) is None assert hass.states.get(ENTITY_LIGHT) is None
_ADAPTERS_WITH_MANUAL_CONFIG = [
{
"auto": True,
"index": 2,
"default": False,
"enabled": True,
"ipv4": [{"address": "192.168.1.5", "network_prefix": 23}],
"ipv6": [],
"name": "eth1",
},
]
async def test_setup_discovery_with_manually_configured_network_adapter(
hass: HomeAssistant,
):
"""Test setting up Yeelight by discovery with a manually configured network adapter."""
config_entry = MockConfigEntry(domain=DOMAIN, data=CONFIG_ENTRY_DATA)
config_entry.add_to_hass(hass)
mocked_bulb = _mocked_bulb()
with _patch_discovery(), patch(
f"{MODULE}.AsyncBulb", return_value=mocked_bulb
), patch(
"homeassistant.components.zeroconf.network.async_get_adapters",
return_value=_ADAPTERS_WITH_MANUAL_CONFIG,
):
assert await hass.config_entries.async_setup(config_entry.entry_id)
await hass.async_block_till_done()
assert hass.states.get(ENTITY_BINARY_SENSOR) is not None
assert hass.states.get(ENTITY_LIGHT) is not None
# Unload
assert await hass.config_entries.async_unload(config_entry.entry_id)
assert hass.states.get(ENTITY_BINARY_SENSOR).state == STATE_UNAVAILABLE
assert hass.states.get(ENTITY_LIGHT).state == STATE_UNAVAILABLE
# Remove
assert await hass.config_entries.async_remove(config_entry.entry_id)
await hass.async_block_till_done()
assert hass.states.get(ENTITY_BINARY_SENSOR) is None
assert hass.states.get(ENTITY_LIGHT) is None
_ADAPTERS_WITH_MANUAL_CONFIG_ONE_FAILING = [
{
"auto": True,
"index": 1,
"default": False,
"enabled": True,
"ipv4": [{"address": FAIL_TO_BIND_IP, "network_prefix": 23}],
"ipv6": [],
"name": "eth0",
},
{
"auto": True,
"index": 2,
"default": False,
"enabled": True,
"ipv4": [{"address": "192.168.1.5", "network_prefix": 23}],
"ipv6": [],
"name": "eth1",
},
]
async def test_setup_discovery_with_manually_configured_network_adapter_one_fails(
hass: HomeAssistant, caplog
):
"""Test setting up Yeelight by discovery with a manually configured network adapter with one that fails to bind."""
config_entry = MockConfigEntry(domain=DOMAIN, data=CONFIG_ENTRY_DATA)
config_entry.add_to_hass(hass)
mocked_bulb = _mocked_bulb()
with _patch_discovery(), patch(
f"{MODULE}.AsyncBulb", return_value=mocked_bulb
), patch(
"homeassistant.components.zeroconf.network.async_get_adapters",
return_value=_ADAPTERS_WITH_MANUAL_CONFIG_ONE_FAILING,
):
assert await hass.config_entries.async_setup(config_entry.entry_id)
await hass.async_block_till_done()
assert hass.states.get(ENTITY_BINARY_SENSOR) is not None
assert hass.states.get(ENTITY_LIGHT) is not None
# Unload
assert await hass.config_entries.async_unload(config_entry.entry_id)
assert hass.states.get(ENTITY_BINARY_SENSOR).state == STATE_UNAVAILABLE
assert hass.states.get(ENTITY_LIGHT).state == STATE_UNAVAILABLE
# Remove
assert await hass.config_entries.async_remove(config_entry.entry_id)
await hass.async_block_till_done()
assert hass.states.get(ENTITY_BINARY_SENSOR) is None
assert hass.states.get(ENTITY_LIGHT) is None
assert f"Failed to setup listener for {FAIL_TO_BIND_IP}" in caplog.text
async def test_setup_import(hass: HomeAssistant): async def test_setup_import(hass: HomeAssistant):
"""Test import from yaml.""" """Test import from yaml."""
mocked_bulb = _mocked_bulb() mocked_bulb = _mocked_bulb()
@ -247,6 +349,17 @@ async def test_async_listen_error_late_discovery(hass, caplog):
assert config_entry.state is ConfigEntryState.LOADED assert config_entry.state is ConfigEntryState.LOADED
assert "Failed to connect to bulb at" in caplog.text assert "Failed to connect to bulb at" in caplog.text
await hass.config_entries.async_unload(config_entry.entry_id)
await hass.async_block_till_done()
caplog.clear()
with _patch_discovery(), patch(f"{MODULE}.AsyncBulb", return_value=_mocked_bulb()):
await hass.config_entries.async_setup(config_entry.entry_id)
await hass.async_block_till_done()
assert "Failed to connect to bulb at" not in caplog.text
assert config_entry.state is ConfigEntryState.LOADED
async def test_async_listen_error_has_host_with_id(hass: HomeAssistant): async def test_async_listen_error_has_host_with_id(hass: HomeAssistant):