Convert cert_expiry to use asyncio (#106919)

This commit is contained in:
J. Nick Koston 2024-01-05 08:03:53 -10:00 committed by GitHub
parent 9a15a5b6c2
commit 24ee64e20c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 34 additions and 22 deletions

View File

@ -1,7 +1,10 @@
"""Helper functions for the Cert Expiry platform.""" """Helper functions for the Cert Expiry platform."""
import asyncio
import datetime
from functools import cache from functools import cache
import socket import socket
import ssl import ssl
from typing import Any
from homeassistant.core import HomeAssistant from homeassistant.core import HomeAssistant
from homeassistant.util import dt as dt_util from homeassistant.util import dt as dt_util
@ -21,31 +24,38 @@ def _get_default_ssl_context():
return ssl.create_default_context() return ssl.create_default_context()
def get_cert( async def async_get_cert(
hass: HomeAssistant,
host: str, host: str,
port: int, port: int,
): ) -> dict[str, Any]:
"""Get the certificate for the host and port combination.""" """Get the certificate for the host and port combination."""
ctx = _get_default_ssl_context() async with asyncio.timeout(TIMEOUT):
address = (host, port) transport, _ = await hass.loop.create_connection(
with socket.create_connection(address, timeout=TIMEOUT) as sock, ctx.wrap_socket( asyncio.Protocol,
sock, server_hostname=address[0] host,
) as ssock: port,
cert = ssock.getpeercert() ssl=_get_default_ssl_context(),
return cert happy_eyeballs_delay=0.25,
server_hostname=host,
)
try:
return transport.get_extra_info("peercert")
finally:
transport.close()
async def get_cert_expiry_timestamp( async def get_cert_expiry_timestamp(
hass: HomeAssistant, hass: HomeAssistant,
hostname: str, hostname: str,
port: int, port: int,
): ) -> datetime.datetime:
"""Return the certificate's expiration timestamp.""" """Return the certificate's expiration timestamp."""
try: try:
cert = await hass.async_add_executor_job(get_cert, hostname, port) cert = await async_get_cert(hass, hostname, port)
except socket.gaierror as err: except socket.gaierror as err:
raise ResolveFailed(f"Cannot resolve hostname: {hostname}") from err raise ResolveFailed(f"Cannot resolve hostname: {hostname}") from err
except socket.timeout as err: except asyncio.TimeoutError as err:
raise ConnectionTimeout( raise ConnectionTimeout(
f"Connection timeout with server: {hostname}:{port}" f"Connection timeout with server: {hostname}:{port}"
) from err ) from err

View File

@ -1,4 +1,5 @@
"""Tests for the Cert Expiry config flow.""" """Tests for the Cert Expiry config flow."""
import asyncio
import socket import socket
import ssl import ssl
from unittest.mock import patch from unittest.mock import patch
@ -48,7 +49,7 @@ async def test_user_with_bad_cert(hass: HomeAssistant) -> None:
assert result["step_id"] == "user" assert result["step_id"] == "user"
with patch( with patch(
"homeassistant.components.cert_expiry.helper.get_cert", "homeassistant.components.cert_expiry.helper.async_get_cert",
side_effect=ssl.SSLError("some error"), side_effect=ssl.SSLError("some error"),
): ):
result = await hass.config_entries.flow.async_configure( result = await hass.config_entries.flow.async_configure(
@ -153,7 +154,7 @@ async def test_import_with_name(hass: HomeAssistant) -> None:
async def test_bad_import(hass: HomeAssistant) -> None: async def test_bad_import(hass: HomeAssistant) -> None:
"""Test import step.""" """Test import step."""
with patch( with patch(
"homeassistant.components.cert_expiry.helper.get_cert", "homeassistant.components.cert_expiry.helper.async_get_cert",
side_effect=ConnectionRefusedError(), side_effect=ConnectionRefusedError(),
): ):
result = await hass.config_entries.flow.async_init( result = await hass.config_entries.flow.async_init(
@ -198,7 +199,7 @@ async def test_abort_on_socket_failed(hass: HomeAssistant) -> None:
) )
with patch( with patch(
"homeassistant.components.cert_expiry.helper.get_cert", "homeassistant.components.cert_expiry.helper.async_get_cert",
side_effect=socket.gaierror(), side_effect=socket.gaierror(),
): ):
result = await hass.config_entries.flow.async_configure( result = await hass.config_entries.flow.async_configure(
@ -208,8 +209,8 @@ async def test_abort_on_socket_failed(hass: HomeAssistant) -> None:
assert result["errors"] == {CONF_HOST: "resolve_failed"} assert result["errors"] == {CONF_HOST: "resolve_failed"}
with patch( with patch(
"homeassistant.components.cert_expiry.helper.get_cert", "homeassistant.components.cert_expiry.helper.async_get_cert",
side_effect=socket.timeout(), side_effect=asyncio.TimeoutError,
): ):
result = await hass.config_entries.flow.async_configure( result = await hass.config_entries.flow.async_configure(
result["flow_id"], user_input={CONF_HOST: HOST} result["flow_id"], user_input={CONF_HOST: HOST}
@ -218,7 +219,7 @@ async def test_abort_on_socket_failed(hass: HomeAssistant) -> None:
assert result["errors"] == {CONF_HOST: "connection_timeout"} assert result["errors"] == {CONF_HOST: "connection_timeout"}
with patch( with patch(
"homeassistant.components.cert_expiry.helper.get_cert", "homeassistant.components.cert_expiry.helper.async_get_cert",
side_effect=ConnectionRefusedError, side_effect=ConnectionRefusedError,
): ):
result = await hass.config_entries.flow.async_configure( result = await hass.config_entries.flow.async_configure(

View File

@ -57,7 +57,7 @@ async def test_async_setup_entry_bad_cert(hass: HomeAssistant) -> None:
) )
with patch( with patch(
"homeassistant.components.cert_expiry.helper.get_cert", "homeassistant.components.cert_expiry.helper.async_get_cert",
side_effect=ssl.SSLError("some error"), side_effect=ssl.SSLError("some error"),
): ):
entry.add_to_hass(hass) entry.add_to_hass(hass)
@ -146,7 +146,7 @@ async def test_update_sensor_network_errors(hass: HomeAssistant) -> None:
next_update = starting_time + timedelta(hours=24) next_update = starting_time + timedelta(hours=24)
with freeze_time(next_update), patch( with freeze_time(next_update), patch(
"homeassistant.components.cert_expiry.helper.get_cert", "homeassistant.components.cert_expiry.helper.async_get_cert",
side_effect=socket.gaierror, side_effect=socket.gaierror,
): ):
async_fire_time_changed(hass, utcnow() + timedelta(hours=24)) async_fire_time_changed(hass, utcnow() + timedelta(hours=24))
@ -174,7 +174,7 @@ async def test_update_sensor_network_errors(hass: HomeAssistant) -> None:
next_update = starting_time + timedelta(hours=72) next_update = starting_time + timedelta(hours=72)
with freeze_time(next_update), patch( with freeze_time(next_update), patch(
"homeassistant.components.cert_expiry.helper.get_cert", "homeassistant.components.cert_expiry.helper.async_get_cert",
side_effect=ssl.SSLError("something bad"), side_effect=ssl.SSLError("something bad"),
): ):
async_fire_time_changed(hass, utcnow() + timedelta(hours=72)) async_fire_time_changed(hass, utcnow() + timedelta(hours=72))
@ -189,7 +189,8 @@ async def test_update_sensor_network_errors(hass: HomeAssistant) -> None:
next_update = starting_time + timedelta(hours=96) next_update = starting_time + timedelta(hours=96)
with freeze_time(next_update), patch( with freeze_time(next_update), patch(
"homeassistant.components.cert_expiry.helper.get_cert", side_effect=Exception() "homeassistant.components.cert_expiry.helper.async_get_cert",
side_effect=Exception(),
): ):
async_fire_time_changed(hass, utcnow() + timedelta(hours=96)) async_fire_time_changed(hass, utcnow() + timedelta(hours=96))
await hass.async_block_till_done() await hass.async_block_till_done()