Add timeout to lutron_caseta to prevent it blocking startup (#45769)

This commit is contained in:
J. Nick Koston 2021-01-31 10:43:00 -10:00 committed by GitHub
parent 852af7e372
commit 73d7d80731
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 97 additions and 38 deletions

View File

@ -1,10 +1,12 @@
"""Component for interacting with a Lutron Caseta system.""" """Component for interacting with a Lutron Caseta system."""
import asyncio import asyncio
import logging import logging
import ssl
from aiolip import LIP from aiolip import LIP
from aiolip.data import LIPMode from aiolip.data import LIPMode
from aiolip.protocol import LIP_BUTTON_PRESS from aiolip.protocol import LIP_BUTTON_PRESS
import async_timeout
from pylutron_caseta.smartbridge import Smartbridge from pylutron_caseta.smartbridge import Smartbridge
import voluptuous as vol import voluptuous as vol
@ -29,6 +31,7 @@ from .const import (
BRIDGE_DEVICE_ID, BRIDGE_DEVICE_ID,
BRIDGE_LEAP, BRIDGE_LEAP,
BRIDGE_LIP, BRIDGE_LIP,
BRIDGE_TIMEOUT,
BUTTON_DEVICES, BUTTON_DEVICES,
CONF_CA_CERTS, CONF_CA_CERTS,
CONF_CERTFILE, CONF_CERTFILE,
@ -94,15 +97,26 @@ async def async_setup_entry(hass, config_entry):
keyfile = hass.config.path(config_entry.data[CONF_KEYFILE]) keyfile = hass.config.path(config_entry.data[CONF_KEYFILE])
certfile = hass.config.path(config_entry.data[CONF_CERTFILE]) certfile = hass.config.path(config_entry.data[CONF_CERTFILE])
ca_certs = hass.config.path(config_entry.data[CONF_CA_CERTS]) ca_certs = hass.config.path(config_entry.data[CONF_CA_CERTS])
bridge = None
bridge = Smartbridge.create_tls( try:
hostname=host, keyfile=keyfile, certfile=certfile, ca_certs=ca_certs bridge = Smartbridge.create_tls(
) hostname=host, keyfile=keyfile, certfile=certfile, ca_certs=ca_certs
)
except ssl.SSLError:
_LOGGER.error("Invalid certificate used to connect to bridge at %s.", host)
return False
await bridge.connect() timed_out = True
if not bridge.is_connected(): try:
with async_timeout.timeout(BRIDGE_TIMEOUT):
await bridge.connect()
timed_out = False
except asyncio.TimeoutError:
_LOGGER.error("Timeout while trying to connect to bridge at %s.", host)
if timed_out or not bridge.is_connected():
await bridge.close() await bridge.close()
_LOGGER.error("Unable to connect to Lutron Caseta bridge at %s", host)
raise ConfigEntryNotReady raise ConfigEntryNotReady
_LOGGER.debug("Connected to Lutron Caseta bridge via LEAP at %s", host) _LOGGER.debug("Connected to Lutron Caseta bridge via LEAP at %s", host)

View File

@ -2,7 +2,9 @@
import asyncio import asyncio
import logging import logging
import os import os
import ssl
import async_timeout
from pylutron_caseta.pairing import PAIR_CA, PAIR_CERT, PAIR_KEY, async_pair from pylutron_caseta.pairing import PAIR_CA, PAIR_CERT, PAIR_KEY, async_pair
from pylutron_caseta.smartbridge import Smartbridge from pylutron_caseta.smartbridge import Smartbridge
import voluptuous as vol import voluptuous as vol
@ -15,6 +17,7 @@ from homeassistant.core import callback
from .const import ( from .const import (
ABORT_REASON_ALREADY_CONFIGURED, ABORT_REASON_ALREADY_CONFIGURED,
ABORT_REASON_CANNOT_CONNECT, ABORT_REASON_CANNOT_CONNECT,
BRIDGE_TIMEOUT,
CONF_CA_CERTS, CONF_CA_CERTS,
CONF_CERTFILE, CONF_CERTFILE,
CONF_KEYFILE, CONF_KEYFILE,
@ -50,6 +53,8 @@ class LutronCasetaFlowHandler(config_entries.ConfigFlow, domain=DOMAIN):
"""Initialize a Lutron Caseta flow.""" """Initialize a Lutron Caseta flow."""
self.data = {} self.data = {}
self.lutron_id = None self.lutron_id = None
self.tls_assets_validated = False
self.attempted_tls_validation = False
async def async_step_user(self, user_input=None): async def async_step_user(self, user_input=None):
"""Handle a flow initialized by the user.""" """Handle a flow initialized by the user."""
@ -92,11 +97,16 @@ class LutronCasetaFlowHandler(config_entries.ConfigFlow, domain=DOMAIN):
self._configure_tls_assets() self._configure_tls_assets()
if (
not self.attempted_tls_validation
and await self.hass.async_add_executor_job(self._tls_assets_exist)
and await self.async_validate_connectable_bridge_config()
):
self.tls_assets_validated = True
self.attempted_tls_validation = True
if user_input is not None: if user_input is not None:
if ( if self.tls_assets_validated:
await self.hass.async_add_executor_job(self._tls_assets_exist)
and await self.async_validate_connectable_bridge_config()
):
# If we previous paired and the tls assets already exist, # If we previous paired and the tls assets already exist,
# we do not need to go though pairing again. # we do not need to go though pairing again.
return self.async_create_entry(title=self.bridge_id, data=self.data) return self.async_create_entry(title=self.bridge_id, data=self.data)
@ -207,6 +217,8 @@ class LutronCasetaFlowHandler(config_entries.ConfigFlow, domain=DOMAIN):
async def async_validate_connectable_bridge_config(self): async def async_validate_connectable_bridge_config(self):
"""Check if we can connect to the bridge with the current config.""" """Check if we can connect to the bridge with the current config."""
bridge = None
try: try:
bridge = Smartbridge.create_tls( bridge = Smartbridge.create_tls(
hostname=self.data[CONF_HOST], hostname=self.data[CONF_HOST],
@ -214,16 +226,23 @@ class LutronCasetaFlowHandler(config_entries.ConfigFlow, domain=DOMAIN):
certfile=self.hass.config.path(self.data[CONF_CERTFILE]), certfile=self.hass.config.path(self.data[CONF_CERTFILE]),
ca_certs=self.hass.config.path(self.data[CONF_CA_CERTS]), ca_certs=self.hass.config.path(self.data[CONF_CA_CERTS]),
) )
except ssl.SSLError:
await bridge.connect() _LOGGER.error(
if not bridge.is_connected(): "Invalid certificate used to connect to bridge at %s.",
return False
await bridge.close()
return True
except Exception: # pylint: disable=broad-except
_LOGGER.exception(
"Unknown exception while checking connectivity to bridge %s",
self.data[CONF_HOST], self.data[CONF_HOST],
) )
return False return False
connected_ok = False
try:
with async_timeout.timeout(BRIDGE_TIMEOUT):
await bridge.connect()
connected_ok = bridge.is_connected()
except asyncio.TimeoutError:
_LOGGER.error(
"Timeout while trying to connect to bridge at %s.",
self.data[CONF_HOST],
)
await bridge.close()
return connected_ok

View File

@ -33,3 +33,5 @@ ACTION_RELEASE = "release"
CONF_TYPE = "type" CONF_TYPE = "type"
CONF_SUBTYPE = "subtype" CONF_SUBTYPE = "subtype"
BRIDGE_TIMEOUT = 35

View File

@ -1,5 +1,6 @@
"""Test the Lutron Caseta config flow.""" """Test the Lutron Caseta config flow."""
import asyncio import asyncio
import ssl
from unittest.mock import AsyncMock, patch from unittest.mock import AsyncMock, patch
from pylutron_caseta.pairing import PAIR_CA, PAIR_CERT, PAIR_KEY from pylutron_caseta.pairing import PAIR_CA, PAIR_CERT, PAIR_KEY
@ -21,6 +22,14 @@ from homeassistant.const import CONF_HOST
from tests.common import MockConfigEntry from tests.common import MockConfigEntry
EMPTY_MOCK_CONFIG_ENTRY = {
CONF_HOST: "",
CONF_KEYFILE: "",
CONF_CERTFILE: "",
CONF_CA_CERTS: "",
}
MOCK_ASYNC_PAIR_SUCCESS = { MOCK_ASYNC_PAIR_SUCCESS = {
PAIR_KEY: "mock_key", PAIR_KEY: "mock_key",
PAIR_CERT: "mock_cert", PAIR_CERT: "mock_cert",
@ -115,21 +124,34 @@ async def test_bridge_cannot_connect(hass):
async def test_bridge_cannot_connect_unknown_error(hass): async def test_bridge_cannot_connect_unknown_error(hass):
"""Test checking for connection and encountering an unknown error.""" """Test checking for connection and encountering an unknown error."""
entry_mock_data = {
CONF_HOST: "",
CONF_KEYFILE: "",
CONF_CERTFILE: "",
CONF_CA_CERTS: "",
}
with patch.object(Smartbridge, "create_tls") as create_tls: with patch.object(Smartbridge, "create_tls") as create_tls:
mock_bridge = MockBridge() mock_bridge = MockBridge()
mock_bridge.connect = AsyncMock(side_effect=Exception()) mock_bridge.connect = AsyncMock(side_effect=asyncio.TimeoutError)
create_tls.return_value = mock_bridge create_tls.return_value = mock_bridge
result = await hass.config_entries.flow.async_init( result = await hass.config_entries.flow.async_init(
DOMAIN, DOMAIN,
context={"source": config_entries.SOURCE_IMPORT}, context={"source": config_entries.SOURCE_IMPORT},
data=entry_mock_data, data=EMPTY_MOCK_CONFIG_ENTRY,
)
assert result["type"] == "form"
assert result["step_id"] == STEP_IMPORT_FAILED
assert result["errors"] == {"base": ERROR_CANNOT_CONNECT}
result = await hass.config_entries.flow.async_configure(result["flow_id"], {})
assert result["type"] == data_entry_flow.RESULT_TYPE_ABORT
assert result["reason"] == CasetaConfigFlow.ABORT_REASON_CANNOT_CONNECT
async def test_bridge_invalid_ssl_error(hass):
"""Test checking for connection and encountering invalid ssl certs."""
with patch.object(Smartbridge, "create_tls", side_effect=ssl.SSLError):
result = await hass.config_entries.flow.async_init(
DOMAIN,
context={"source": config_entries.SOURCE_IMPORT},
data=EMPTY_MOCK_CONFIG_ENTRY,
) )
assert result["type"] == "form" assert result["type"] == "form"
@ -351,23 +373,25 @@ async def test_form_user_reuses_existing_assets_when_pairing_again(hass, tmpdir)
assert result["errors"] is None assert result["errors"] is None
assert result["step_id"] == "user" assert result["step_id"] == "user"
result2 = await hass.config_entries.flow.async_configure( with patch.object(Smartbridge, "create_tls") as create_tls:
result["flow_id"], create_tls.return_value = MockBridge(can_connect=True)
{ result2 = await hass.config_entries.flow.async_configure(
CONF_HOST: "1.1.1.1", result["flow_id"],
}, {
) CONF_HOST: "1.1.1.1",
await hass.async_block_till_done() },
)
await hass.async_block_till_done()
assert result2["type"] == "form" assert result2["type"] == "form"
assert result2["step_id"] == "link" assert result2["step_id"] == "link"
with patch.object(Smartbridge, "create_tls") as create_tls, patch( with patch(
"homeassistant.components.lutron_caseta.async_setup", return_value=True "homeassistant.components.lutron_caseta.async_setup", return_value=True
), patch( ), patch(
"homeassistant.components.lutron_caseta.async_setup_entry", "homeassistant.components.lutron_caseta.async_setup_entry",
return_value=True, return_value=True,
): ):
create_tls.return_value = MockBridge(can_connect=True)
result3 = await hass.config_entries.flow.async_configure( result3 = await hass.config_entries.flow.async_configure(
result2["flow_id"], result2["flow_id"],
{}, {},