Make UniFi utilise forward_entry_setups (#74835)

This commit is contained in:
Robert Svensson 2022-07-18 17:39:38 +02:00 committed by GitHub
parent b3ef6f4d04
commit 3144d179e0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 103 additions and 126 deletions

View File

@ -5,19 +5,13 @@ from typing import Any
from homeassistant.config_entries import ConfigEntry from homeassistant.config_entries import ConfigEntry
from homeassistant.const import EVENT_HOMEASSISTANT_STOP from homeassistant.const import EVENT_HOMEASSISTANT_STOP
from homeassistant.core import HomeAssistant, callback from homeassistant.core import HomeAssistant, callback
from homeassistant.helpers import device_registry as dr from homeassistant.exceptions import ConfigEntryAuthFailed, ConfigEntryNotReady
from homeassistant.helpers.device_registry import CONNECTION_NETWORK_MAC
from homeassistant.helpers.storage import Store from homeassistant.helpers.storage import Store
from homeassistant.helpers.typing import ConfigType from homeassistant.helpers.typing import ConfigType
from .const import ( from .const import CONF_CONTROLLER, DOMAIN as UNIFI_DOMAIN, UNIFI_WIRELESS_CLIENTS
ATTR_MANUFACTURER, from .controller import PLATFORMS, UniFiController, get_unifi_controller
CONF_CONTROLLER, from .errors import AuthenticationRequired, CannotConnect
DOMAIN as UNIFI_DOMAIN,
LOGGER,
UNIFI_WIRELESS_CLIENTS,
)
from .controller import UniFiController
from .services import async_setup_services, async_unload_services from .services import async_setup_services, async_unload_services
SAVE_DELAY = 10 SAVE_DELAY = 10
@ -40,9 +34,16 @@ async def async_setup_entry(hass: HomeAssistant, config_entry: ConfigEntry) -> b
# Flat configuration was introduced with 2021.3 # Flat configuration was introduced with 2021.3
await async_flatten_entry_data(hass, config_entry) await async_flatten_entry_data(hass, config_entry)
controller = UniFiController(hass, config_entry) try:
if not await controller.async_setup(): api = await get_unifi_controller(hass, config_entry.data)
return False controller = UniFiController(hass, config_entry, api)
await controller.initialize()
except CannotConnect as err:
raise ConfigEntryNotReady from err
except AuthenticationRequired as err:
raise ConfigEntryAuthFailed from err
# Unique ID was introduced with 2021.3 # Unique ID was introduced with 2021.3
if config_entry.unique_id is None: if config_entry.unique_id is None:
@ -50,30 +51,19 @@ async def async_setup_entry(hass: HomeAssistant, config_entry: ConfigEntry) -> b
config_entry, unique_id=controller.site_id config_entry, unique_id=controller.site_id
) )
if not hass.data[UNIFI_DOMAIN]: hass.data[UNIFI_DOMAIN][config_entry.entry_id] = controller
await hass.config_entries.async_forward_entry_setups(config_entry, PLATFORMS)
await controller.async_update_device_registry()
if len(hass.data[UNIFI_DOMAIN]) == 1:
async_setup_services(hass) async_setup_services(hass)
hass.data[UNIFI_DOMAIN][config_entry.entry_id] = controller api.start_websocket()
config_entry.async_on_unload( config_entry.async_on_unload(
hass.bus.async_listen_once(EVENT_HOMEASSISTANT_STOP, controller.shutdown) hass.bus.async_listen_once(EVENT_HOMEASSISTANT_STOP, controller.shutdown)
) )
LOGGER.debug("UniFi Network config options %s", config_entry.options)
if controller.mac is None:
return True
device_registry = dr.async_get(hass)
device_registry.async_get_or_create(
config_entry_id=config_entry.entry_id,
configuration_url=controller.api.url,
connections={(CONNECTION_NETWORK_MAC, controller.mac)},
default_manufacturer=ATTR_MANUFACTURER,
default_model="UniFi Network",
default_name="UniFi Network",
)
return True return True

View File

@ -9,6 +9,7 @@ from __future__ import annotations
from collections.abc import Mapping from collections.abc import Mapping
import socket import socket
from types import MappingProxyType
from typing import Any from typing import Any
from urllib.parse import urlparse from urllib.parse import urlparse
@ -46,7 +47,7 @@ from .const import (
DEFAULT_POE_CLIENTS, DEFAULT_POE_CLIENTS,
DOMAIN as UNIFI_DOMAIN, DOMAIN as UNIFI_DOMAIN,
) )
from .controller import UniFiController, get_controller from .controller import UniFiController, get_unifi_controller
from .errors import AuthenticationRequired, CannotConnect from .errors import AuthenticationRequired, CannotConnect
DEFAULT_PORT = 443 DEFAULT_PORT = 443
@ -99,16 +100,9 @@ class UnifiFlowHandler(config_entries.ConfigFlow, domain=UNIFI_DOMAIN):
} }
try: try:
controller = await get_controller( controller = await get_unifi_controller(
self.hass, self.hass, MappingProxyType(self.config)
host=self.config[CONF_HOST],
username=self.config[CONF_USERNAME],
password=self.config[CONF_PASSWORD],
port=self.config[CONF_PORT],
site=self.config[CONF_SITE_ID],
verify_ssl=self.config[CONF_VERIFY_SSL],
) )
sites = await controller.sites() sites = await controller.sites()
except AuthenticationRequired: except AuthenticationRequired:

View File

@ -4,6 +4,8 @@ from __future__ import annotations
import asyncio import asyncio
from datetime import datetime, timedelta from datetime import datetime, timedelta
import ssl import ssl
from types import MappingProxyType
from typing import Any
from aiohttp import CookieJar from aiohttp import CookieJar
import aiounifi import aiounifi
@ -36,14 +38,19 @@ from homeassistant.const import (
Platform, Platform,
) )
from homeassistant.core import HomeAssistant, callback from homeassistant.core import HomeAssistant, callback
from homeassistant.exceptions import ConfigEntryAuthFailed, ConfigEntryNotReady from homeassistant.helpers import (
from homeassistant.helpers import aiohttp_client, entity_registry as er aiohttp_client,
device_registry as dr,
entity_registry as er,
)
from homeassistant.helpers.device_registry import CONNECTION_NETWORK_MAC
from homeassistant.helpers.dispatcher import async_dispatcher_send from homeassistant.helpers.dispatcher import async_dispatcher_send
from homeassistant.helpers.entity_registry import async_entries_for_config_entry from homeassistant.helpers.entity_registry import async_entries_for_config_entry
from homeassistant.helpers.event import async_track_time_interval from homeassistant.helpers.event import async_track_time_interval
import homeassistant.util.dt as dt_util import homeassistant.util.dt as dt_util
from .const import ( from .const import (
ATTR_MANUFACTURER,
CONF_ALLOW_BANDWIDTH_SENSORS, CONF_ALLOW_BANDWIDTH_SENSORS,
CONF_ALLOW_UPTIME_SENSORS, CONF_ALLOW_UPTIME_SENSORS,
CONF_BLOCK_CLIENT, CONF_BLOCK_CLIENT,
@ -91,12 +98,15 @@ DEVICE_CONNECTED = (
class UniFiController: class UniFiController:
"""Manages a single UniFi Network instance.""" """Manages a single UniFi Network instance."""
def __init__(self, hass, config_entry): def __init__(self, hass, config_entry, api):
"""Initialize the system.""" """Initialize the system."""
self.hass = hass self.hass = hass
self.config_entry = config_entry self.config_entry = config_entry
self.api = api
api.callback = self.async_unifi_signalling_callback
self.available = True self.available = True
self.api = None
self.progress = None self.progress = None
self.wireless_clients = None self.wireless_clients = None
@ -295,36 +305,18 @@ class UniFiController:
unifi_wireless_clients = self.hass.data[UNIFI_WIRELESS_CLIENTS] unifi_wireless_clients = self.hass.data[UNIFI_WIRELESS_CLIENTS]
unifi_wireless_clients.update_data(self.wireless_clients, self.config_entry) unifi_wireless_clients.update_data(self.wireless_clients, self.config_entry)
async def async_setup(self): async def initialize(self):
"""Set up a UniFi Network instance.""" """Set up a UniFi Network instance."""
try: await self.api.initialize()
self.api = await get_controller(
self.hass,
host=self.config_entry.data[CONF_HOST],
username=self.config_entry.data[CONF_USERNAME],
password=self.config_entry.data[CONF_PASSWORD],
port=self.config_entry.data[CONF_PORT],
site=self.config_entry.data[CONF_SITE_ID],
verify_ssl=self.config_entry.data[CONF_VERIFY_SSL],
async_callback=self.async_unifi_signalling_callback,
)
await self.api.initialize()
sites = await self.api.sites()
description = await self.api.site_description()
except CannotConnect as err:
raise ConfigEntryNotReady from err
except AuthenticationRequired as err:
raise ConfigEntryAuthFailed from err
sites = await self.api.sites()
for site in sites.values(): for site in sites.values():
if self.site == site["name"]: if self.site == site["name"]:
self.site_id = site["_id"] self.site_id = site["_id"]
self._site_name = site["desc"] self._site_name = site["desc"]
break break
description = await self.api.site_description()
self._site_role = description[0]["site_role"] self._site_role = description[0]["site_role"]
# Restore clients that are not a part of active clients list. # Restore clients that are not a part of active clients list.
@ -357,18 +349,12 @@ class UniFiController:
self.wireless_clients = wireless_clients.get_data(self.config_entry) self.wireless_clients = wireless_clients.get_data(self.config_entry)
self.update_wireless_clients() self.update_wireless_clients()
self.hass.config_entries.async_setup_platforms(self.config_entry, PLATFORMS)
self.api.start_websocket()
self.config_entry.add_update_listener(self.async_config_entry_updated) self.config_entry.add_update_listener(self.async_config_entry_updated)
self._cancel_heartbeat_check = async_track_time_interval( self._cancel_heartbeat_check = async_track_time_interval(
self.hass, self._async_check_for_stale, CHECK_HEARTBEAT_INTERVAL self.hass, self._async_check_for_stale, CHECK_HEARTBEAT_INTERVAL
) )
return True
@callback @callback
def async_heartbeat( def async_heartbeat(
self, unique_id: str, heartbeat_expire_time: datetime | None = None self, unique_id: str, heartbeat_expire_time: datetime | None = None
@ -397,6 +383,22 @@ class UniFiController:
for unique_id in unique_ids_to_remove: for unique_id in unique_ids_to_remove:
del self._heartbeat_time[unique_id] del self._heartbeat_time[unique_id]
async def async_update_device_registry(self) -> None:
"""Update device registry."""
if self.mac is None:
return
device_registry = dr.async_get(self.hass)
device_registry.async_get_or_create(
config_entry_id=self.config_entry.entry_id,
configuration_url=self.api.url,
connections={(CONNECTION_NETWORK_MAC, self.mac)},
default_manufacturer=ATTR_MANUFACTURER,
default_model="UniFi Network",
default_name="UniFi Network",
)
@staticmethod @staticmethod
async def async_config_entry_updated( async def async_config_entry_updated(
hass: HomeAssistant, config_entry: ConfigEntry hass: HomeAssistant, config_entry: ConfigEntry
@ -463,13 +465,14 @@ class UniFiController:
return True return True
async def get_controller( async def get_unifi_controller(
hass, host, username, password, port, site, verify_ssl, async_callback=None hass: HomeAssistant,
): config: MappingProxyType[str, Any],
) -> aiounifi.Controller:
"""Create a controller object and verify authentication.""" """Create a controller object and verify authentication."""
sslcontext = None sslcontext = None
if verify_ssl: if verify_ssl := bool(config.get(CONF_VERIFY_SSL)):
session = aiohttp_client.async_get_clientsession(hass) session = aiohttp_client.async_get_clientsession(hass)
if isinstance(verify_ssl, str): if isinstance(verify_ssl, str):
sslcontext = ssl.create_default_context(cafile=verify_ssl) sslcontext = ssl.create_default_context(cafile=verify_ssl)
@ -479,14 +482,13 @@ async def get_controller(
) )
controller = aiounifi.Controller( controller = aiounifi.Controller(
host, host=config[CONF_HOST],
username=username, username=config[CONF_USERNAME],
password=password, password=config[CONF_PASSWORD],
port=port, port=config[CONF_PORT],
site=site, site=config[CONF_SITE_ID],
websession=session, websession=session,
sslcontext=sslcontext, sslcontext=sslcontext,
callback=async_callback,
) )
try: try:
@ -498,7 +500,7 @@ async def get_controller(
except aiounifi.Unauthorized as err: except aiounifi.Unauthorized as err:
LOGGER.warning( LOGGER.warning(
"Connected to UniFi Network at %s but not registered: %s", "Connected to UniFi Network at %s but not registered: %s",
host, config[CONF_HOST],
err, err,
) )
raise AuthenticationRequired from err raise AuthenticationRequired from err
@ -510,13 +512,15 @@ async def get_controller(
aiounifi.RequestError, aiounifi.RequestError,
aiounifi.ResponseError, aiounifi.ResponseError,
) as err: ) as err:
LOGGER.error("Error connecting to the UniFi Network at %s: %s", host, err) LOGGER.error(
"Error connecting to the UniFi Network at %s: %s", config[CONF_HOST], err
)
raise CannotConnect from err raise CannotConnect from err
except aiounifi.LoginRequired as err: except aiounifi.LoginRequired as err:
LOGGER.warning( LOGGER.warning(
"Connected to UniFi Network at %s but login required: %s", "Connected to UniFi Network at %s but login required: %s",
host, config[CONF_HOST],
err, err,
) )
raise AuthenticationRequired from err raise AuthenticationRequired from err

View File

@ -30,7 +30,7 @@ from homeassistant.components.unifi.const import (
from homeassistant.components.unifi.controller import ( from homeassistant.components.unifi.controller import (
PLATFORMS, PLATFORMS,
RETRY_TIMER, RETRY_TIMER,
get_controller, get_unifi_controller,
) )
from homeassistant.components.unifi.errors import AuthenticationRequired, CannotConnect from homeassistant.components.unifi.errors import AuthenticationRequired, CannotConnect
from homeassistant.const import ( from homeassistant.const import (
@ -271,7 +271,7 @@ async def test_controller_mac(hass, aioclient_mock):
async def test_controller_not_accessible(hass): async def test_controller_not_accessible(hass):
"""Retry to login gets scheduled when connection fails.""" """Retry to login gets scheduled when connection fails."""
with patch( with patch(
"homeassistant.components.unifi.controller.get_controller", "homeassistant.components.unifi.controller.get_unifi_controller",
side_effect=CannotConnect, side_effect=CannotConnect,
): ):
await setup_unifi_integration(hass) await setup_unifi_integration(hass)
@ -281,7 +281,7 @@ async def test_controller_not_accessible(hass):
async def test_controller_trigger_reauth_flow(hass): async def test_controller_trigger_reauth_flow(hass):
"""Failed authentication trigger a reauthentication flow.""" """Failed authentication trigger a reauthentication flow."""
with patch( with patch(
"homeassistant.components.unifi.controller.get_controller", "homeassistant.components.unifi.get_unifi_controller",
side_effect=AuthenticationRequired, side_effect=AuthenticationRequired,
), patch.object(hass.config_entries.flow, "async_init") as mock_flow_init: ), patch.object(hass.config_entries.flow, "async_init") as mock_flow_init:
await setup_unifi_integration(hass) await setup_unifi_integration(hass)
@ -292,7 +292,7 @@ async def test_controller_trigger_reauth_flow(hass):
async def test_controller_unknown_error(hass): async def test_controller_unknown_error(hass):
"""Unknown errors are handled.""" """Unknown errors are handled."""
with patch( with patch(
"homeassistant.components.unifi.controller.get_controller", "homeassistant.components.unifi.controller.get_unifi_controller",
side_effect=Exception, side_effect=Exception,
): ):
await setup_unifi_integration(hass) await setup_unifi_integration(hass)
@ -470,22 +470,22 @@ async def test_reconnect_mechanism_exceptions(
mock_reconnect.assert_called_once() mock_reconnect.assert_called_once()
async def test_get_controller(hass): async def test_get_unifi_controller(hass):
"""Successful call.""" """Successful call."""
with patch("aiounifi.Controller.check_unifi_os", return_value=True), patch( with patch("aiounifi.Controller.check_unifi_os", return_value=True), patch(
"aiounifi.Controller.login", return_value=True "aiounifi.Controller.login", return_value=True
): ):
assert await get_controller(hass, **CONTROLLER_DATA) assert await get_unifi_controller(hass, CONTROLLER_DATA)
async def test_get_controller_verify_ssl_false(hass): async def test_get_unifi_controller_verify_ssl_false(hass):
"""Successful call with verify ssl set to false.""" """Successful call with verify ssl set to false."""
controller_data = dict(CONTROLLER_DATA) controller_data = dict(CONTROLLER_DATA)
controller_data[CONF_VERIFY_SSL] = False controller_data[CONF_VERIFY_SSL] = False
with patch("aiounifi.Controller.check_unifi_os", return_value=True), patch( with patch("aiounifi.Controller.check_unifi_os", return_value=True), patch(
"aiounifi.Controller.login", return_value=True "aiounifi.Controller.login", return_value=True
): ):
assert await get_controller(hass, **controller_data) assert await get_unifi_controller(hass, controller_data)
@pytest.mark.parametrize( @pytest.mark.parametrize(
@ -501,9 +501,11 @@ async def test_get_controller_verify_ssl_false(hass):
(aiounifi.AiounifiException, AuthenticationRequired), (aiounifi.AiounifiException, AuthenticationRequired),
], ],
) )
async def test_get_controller_fails_to_connect(hass, side_effect, raised_exception): async def test_get_unifi_controller_fails_to_connect(
"""Check that get_controller can handle controller being unavailable.""" hass, side_effect, raised_exception
):
"""Check that get_unifi_controller can handle controller being unavailable."""
with patch("aiounifi.Controller.check_unifi_os", return_value=True), patch( with patch("aiounifi.Controller.check_unifi_os", return_value=True), patch(
"aiounifi.Controller.login", side_effect=side_effect "aiounifi.Controller.login", side_effect=side_effect
), pytest.raises(raised_exception): ), pytest.raises(raised_exception):
await get_controller(hass, **CONTROLLER_DATA) await get_unifi_controller(hass, CONTROLLER_DATA)

View File

@ -1,10 +1,10 @@
"""Test UniFi Network integration setup process.""" """Test UniFi Network integration setup process."""
from unittest.mock import AsyncMock, patch from unittest.mock import patch
from homeassistant.components import unifi from homeassistant.components import unifi
from homeassistant.components.unifi import async_flatten_entry_data from homeassistant.components.unifi import async_flatten_entry_data
from homeassistant.components.unifi.const import CONF_CONTROLLER, DOMAIN as UNIFI_DOMAIN from homeassistant.components.unifi.const import CONF_CONTROLLER, DOMAIN as UNIFI_DOMAIN
from homeassistant.helpers import device_registry as dr from homeassistant.components.unifi.errors import AuthenticationRequired, CannotConnect
from homeassistant.setup import async_setup_component from homeassistant.setup import async_setup_component
from .test_controller import ( from .test_controller import (
@ -29,40 +29,27 @@ async def test_successful_config_entry(hass, aioclient_mock):
assert hass.data[UNIFI_DOMAIN] assert hass.data[UNIFI_DOMAIN]
async def test_controller_fail_setup(hass): async def test_setup_entry_fails_config_entry_not_ready(hass):
"""Test that a failed setup still stores controller.""" """Failed authentication trigger a reauthentication flow."""
with patch("homeassistant.components.unifi.UniFiController") as mock_controller: with patch(
mock_controller.return_value.async_setup = AsyncMock(return_value=False) "homeassistant.components.unifi.get_unifi_controller",
side_effect=CannotConnect,
):
await setup_unifi_integration(hass) await setup_unifi_integration(hass)
assert hass.data[UNIFI_DOMAIN] == {} assert hass.data[UNIFI_DOMAIN] == {}
async def test_controller_mac(hass): async def test_setup_entry_fails_trigger_reauth_flow(hass):
"""Test that configured options for a host are loaded via config entry.""" """Failed authentication trigger a reauthentication flow."""
entry = MockConfigEntry( with patch(
domain=UNIFI_DOMAIN, data=ENTRY_CONFIG, unique_id="1", entry_id=1 "homeassistant.components.unifi.get_unifi_controller",
) side_effect=AuthenticationRequired,
entry.add_to_hass(hass) ), patch.object(hass.config_entries.flow, "async_init") as mock_flow_init:
await setup_unifi_integration(hass)
mock_flow_init.assert_called_once()
with patch("homeassistant.components.unifi.UniFiController") as mock_controller: assert hass.data[UNIFI_DOMAIN] == {}
mock_controller.return_value.async_setup = AsyncMock(return_value=True)
mock_controller.return_value.mac = "mac1"
mock_controller.return_value.api.url = "https://123:443"
assert await unifi.async_setup_entry(hass, entry) is True
assert len(mock_controller.mock_calls) == 2
device_registry = dr.async_get(hass)
device = device_registry.async_get_or_create(
config_entry_id=entry.entry_id,
connections={(dr.CONNECTION_NETWORK_MAC, "mac1")},
)
assert device.configuration_url == "https://123:443"
assert device.manufacturer == "Ubiquiti Networks"
assert device.model == "UniFi Network"
assert device.name == "UniFi Network"
assert device.sw_version is None
async def test_flatten_entry_data(hass): async def test_flatten_entry_data(hass):