Refactor ESPHome connection management logic into a class (#95457)

* Refactor ESPHome setup logic into a class

Avoids all the nonlocals and fixes the C901

* cleanup

* touch ups

* touch ups

* touch ups

* make easier to read

* stale
This commit is contained in:
J. Nick Koston 2023-06-28 20:39:31 -05:00 committed by GitHub
parent a7dfe46fb1
commit dfe7c5ebed
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -137,57 +137,60 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
return True return True
async def async_setup_entry( # noqa: C901 class ESPHomeManager:
hass: HomeAssistant, entry: ConfigEntry """Class to manage an ESPHome connection."""
) -> bool:
"""Set up the esphome component."""
host = entry.data[CONF_HOST]
port = entry.data[CONF_PORT]
password = entry.data[CONF_PASSWORD]
noise_psk = entry.data.get(CONF_NOISE_PSK)
device_id: str = None # type: ignore[assignment]
zeroconf_instance = await zeroconf.async_get_instance(hass) __slots__ = (
"hass",
cli = APIClient( "host",
host, "password",
port, "entry",
password, "cli",
client_info=f"Home Assistant {ha_version}", "device_id",
zeroconf_instance=zeroconf_instance, "domain_data",
noise_psk=noise_psk, "voice_assistant_udp_server",
"reconnect_logic",
"zeroconf_instance",
"entry_data",
) )
services_issue = f"service_calls_not_enabled-{entry.unique_id}" def __init__(
if entry.options.get(CONF_ALLOW_SERVICE_CALLS, DEFAULT_ALLOW_SERVICE_CALLS): self,
async_delete_issue(hass, DOMAIN, services_issue) hass: HomeAssistant,
entry: ConfigEntry,
host: str,
password: str | None,
cli: APIClient,
zeroconf_instance: zeroconf.HaZeroconf,
domain_data: DomainData,
entry_data: RuntimeEntryData,
) -> None:
"""Initialize the esphome manager."""
self.hass = hass
self.host = host
self.password = password
self.entry = entry
self.cli = cli
self.device_id: str | None = None
self.domain_data = domain_data
self.voice_assistant_udp_server: VoiceAssistantUDPServer | None = None
self.reconnect_logic: ReconnectLogic | None = None
self.zeroconf_instance = zeroconf_instance
self.entry_data = entry_data
domain_data = DomainData.get(hass) async def on_stop(self, event: Event) -> None:
entry_data = RuntimeEntryData(
client=cli,
entry_id=entry.entry_id,
store=domain_data.get_or_create_store(hass, entry),
original_options=dict(entry.options),
)
domain_data.set_entry_data(entry, entry_data)
async def on_stop(event: Event) -> None:
"""Cleanup the socket client on HA stop.""" """Cleanup the socket client on HA stop."""
await _cleanup_instance(hass, entry) await _cleanup_instance(self.hass, self.entry)
# Use async_listen instead of async_listen_once so that we don't deregister @property
# the callback twice when shutting down Home Assistant. def services_issue(self) -> str:
# "Unable to remove unknown listener """Return the services issue name for this entry."""
# <function EventBus.async_listen_once.<locals>.onetime_listener>" return f"service_calls_not_enabled-{self.entry.unique_id}"
entry_data.cleanup_callbacks.append(
hass.bus.async_listen(EVENT_HOMEASSISTANT_STOP, on_stop)
)
@callback @callback
def async_on_service_call(service: HomeassistantServiceCall) -> None: def async_on_service_call(self, service: HomeassistantServiceCall) -> None:
"""Call service when user automation in ESPHome config is triggered.""" """Call service when user automation in ESPHome config is triggered."""
device_info = entry_data.device_info hass = self.hass
assert device_info is not None
domain, service_name = service.service.split(".", 1) domain, service_name = service.service.split(".", 1)
service_data = service.data service_data = service.data
@ -201,15 +204,16 @@ async def async_setup_entry( # noqa: C901
template.render_complex(data_template, service.variables) template.render_complex(data_template, service.variables)
) )
except TemplateError as ex: except TemplateError as ex:
_LOGGER.error("Error rendering data template for %s: %s", host, ex) _LOGGER.error("Error rendering data template for %s: %s", self.host, ex)
return return
if service.is_event: if service.is_event:
device_id = self.device_id
# ESPHome uses service call packet for both events and service calls # ESPHome uses service call packet for both events and service calls
# Ensure the user can only send events of form 'esphome.xyz' # Ensure the user can only send events of form 'esphome.xyz'
if domain != "esphome": if domain != "esphome":
_LOGGER.error( _LOGGER.error(
"Can only generate events under esphome domain! (%s)", host "Can only generate events under esphome domain! (%s)", self.host
) )
return return
@ -226,17 +230,21 @@ async def async_setup_entry( # noqa: C901
**service_data, **service_data,
}, },
) )
elif entry.options.get(CONF_ALLOW_SERVICE_CALLS, DEFAULT_ALLOW_SERVICE_CALLS): elif self.entry.options.get(
CONF_ALLOW_SERVICE_CALLS, DEFAULT_ALLOW_SERVICE_CALLS
):
hass.async_create_task( hass.async_create_task(
hass.services.async_call( hass.services.async_call(
domain, service_name, service_data, blocking=True domain, service_name, service_data, blocking=True
) )
) )
else: else:
device_info = self.entry_data.device_info
assert device_info is not None
async_create_issue( async_create_issue(
hass, hass,
DOMAIN, DOMAIN,
services_issue, self.services_issue,
is_fixable=False, is_fixable=False,
severity=IssueSeverity.WARNING, severity=IssueSeverity.WARNING,
translation_key="service_calls_not_allowed", translation_key="service_calls_not_allowed",
@ -256,7 +264,7 @@ async def async_setup_entry( # noqa: C901
) )
async def _send_home_assistant_state( async def _send_home_assistant_state(
entity_id: str, attribute: str | None, state: State | None self, entity_id: str, attribute: str | None, state: State | None
) -> None: ) -> None:
"""Forward Home Assistant states to ESPHome.""" """Forward Home Assistant states to ESPHome."""
if state is None or (attribute and attribute not in state.attributes): if state is None or (attribute and attribute not in state.attributes):
@ -271,102 +279,102 @@ async def async_setup_entry( # noqa: C901
else: else:
send_state = attr_val send_state = attr_val
await cli.send_home_assistant_state(entity_id, attribute, str(send_state)) await self.cli.send_home_assistant_state(entity_id, attribute, str(send_state))
@callback @callback
def async_on_state_subscription( def async_on_state_subscription(
entity_id: str, attribute: str | None = None self, entity_id: str, attribute: str | None = None
) -> None: ) -> None:
"""Subscribe and forward states for requested entities.""" """Subscribe and forward states for requested entities."""
hass = self.hass
async def send_home_assistant_state_event(event: Event) -> None: async def send_home_assistant_state_event(event: Event) -> None:
"""Forward Home Assistant states updates to ESPHome.""" """Forward Home Assistant states updates to ESPHome."""
event_data = event.data
new_state: State | None = event_data.get("new_state")
old_state: State | None = event_data.get("old_state")
if new_state is None or old_state is None:
return
# Only communicate changes to the state or attribute tracked # Only communicate changes to the state or attribute tracked
if event.data.get("new_state") is None or ( if (not attribute and old_state.state == new_state.state) or (
event.data.get("old_state") is not None attribute
and "new_state" in event.data and old_state.attributes.get(attribute)
and ( == new_state.attributes.get(attribute)
(
not attribute
and event.data["old_state"].state
== event.data["new_state"].state
)
or (
attribute
and attribute in event.data["old_state"].attributes
and attribute in event.data["new_state"].attributes
and event.data["old_state"].attributes[attribute]
== event.data["new_state"].attributes[attribute]
)
)
): ):
return return
await _send_home_assistant_state( await self._send_home_assistant_state(
event.data["entity_id"], attribute, event.data.get("new_state") event.data["entity_id"], attribute, new_state
) )
unsub = async_track_state_change_event( self.entry_data.disconnect_callbacks.append(
hass, [entity_id], send_home_assistant_state_event async_track_state_change_event(
hass, [entity_id], send_home_assistant_state_event
)
) )
entry_data.disconnect_callbacks.append(unsub)
# Send initial state # Send initial state
hass.async_create_task( hass.async_create_task(
_send_home_assistant_state(entity_id, attribute, hass.states.get(entity_id)) self._send_home_assistant_state(
entity_id, attribute, hass.states.get(entity_id)
)
) )
voice_assistant_udp_server: VoiceAssistantUDPServer | None = None
def _handle_pipeline_event( def _handle_pipeline_event(
event_type: VoiceAssistantEventType, data: dict[str, str] | None self, event_type: VoiceAssistantEventType, data: dict[str, str] | None
) -> None: ) -> None:
cli.send_voice_assistant_event(event_type, data) self.cli.send_voice_assistant_event(event_type, data)
def _handle_pipeline_finished() -> None: def _handle_pipeline_finished(self) -> None:
nonlocal voice_assistant_udp_server self.entry_data.async_set_assist_pipeline_state(False)
entry_data.async_set_assist_pipeline_state(False) if self.voice_assistant_udp_server is not None:
self.voice_assistant_udp_server.close()
self.voice_assistant_udp_server = None
if voice_assistant_udp_server is not None: async def _handle_pipeline_start(
voice_assistant_udp_server.close() self, conversation_id: str, use_vad: bool
voice_assistant_udp_server = None ) -> int | None:
async def _handle_pipeline_start(conversation_id: str, use_vad: bool) -> int | None:
"""Start a voice assistant pipeline.""" """Start a voice assistant pipeline."""
nonlocal voice_assistant_udp_server if self.voice_assistant_udp_server is not None:
if voice_assistant_udp_server is not None:
return None return None
hass = self.hass
voice_assistant_udp_server = VoiceAssistantUDPServer( voice_assistant_udp_server = VoiceAssistantUDPServer(
hass, entry_data, _handle_pipeline_event, _handle_pipeline_finished hass,
self.entry_data,
self._handle_pipeline_event,
self._handle_pipeline_finished,
) )
port = await voice_assistant_udp_server.start_server() port = await voice_assistant_udp_server.start_server()
assert self.device_id is not None, "Device ID must be set"
hass.async_create_background_task( hass.async_create_background_task(
voice_assistant_udp_server.run_pipeline( voice_assistant_udp_server.run_pipeline(
device_id=device_id, device_id=self.device_id,
conversation_id=conversation_id or None, conversation_id=conversation_id or None,
use_vad=use_vad, use_vad=use_vad,
), ),
"esphome.voice_assistant_udp_server.run_pipeline", "esphome.voice_assistant_udp_server.run_pipeline",
) )
entry_data.async_set_assist_pipeline_state(True) self.entry_data.async_set_assist_pipeline_state(True)
return port return port
async def _handle_pipeline_stop() -> None: async def _handle_pipeline_stop(self) -> None:
"""Stop a voice assistant pipeline.""" """Stop a voice assistant pipeline."""
nonlocal voice_assistant_udp_server if self.voice_assistant_udp_server is not None:
self.voice_assistant_udp_server.stop()
if voice_assistant_udp_server is not None: async def on_connect(self) -> None:
voice_assistant_udp_server.stop()
async def on_connect() -> None:
"""Subscribe to states and list entities on successful API login.""" """Subscribe to states and list entities on successful API login."""
nonlocal device_id entry = self.entry
entry_data = self.entry_data
reconnect_logic = self.reconnect_logic
hass = self.hass
cli = self.cli
try: try:
device_info = await cli.device_info() device_info = await cli.device_info()
@ -389,6 +397,7 @@ async def async_setup_entry( # noqa: C901
entry_data.api_version = cli.api_version entry_data.api_version = cli.api_version
entry_data.available = True entry_data.available = True
if entry_data.device_info.name: if entry_data.device_info.name:
assert reconnect_logic is not None, "Reconnect logic must be set"
reconnect_logic.name = entry_data.device_info.name reconnect_logic.name = entry_data.device_info.name
if device_info.bluetooth_proxy_feature_flags_compat(cli.api_version): if device_info.bluetooth_proxy_feature_flags_compat(cli.api_version):
@ -396,37 +405,38 @@ async def async_setup_entry( # noqa: C901
await async_connect_scanner(hass, entry, cli, entry_data) await async_connect_scanner(hass, entry, cli, entry_data)
) )
device_id = _async_setup_device_registry( _async_setup_device_registry(hass, entry, entry_data.device_info)
hass, entry, entry_data.device_info
)
entry_data.async_update_device_state(hass) entry_data.async_update_device_state(hass)
entity_infos, services = await cli.list_entities_services() entity_infos, services = await cli.list_entities_services()
await entry_data.async_update_static_infos(hass, entry, entity_infos) await entry_data.async_update_static_infos(hass, entry, entity_infos)
await _setup_services(hass, entry_data, services) await _setup_services(hass, entry_data, services)
await cli.subscribe_states(entry_data.async_update_state) await cli.subscribe_states(entry_data.async_update_state)
await cli.subscribe_service_calls(async_on_service_call) await cli.subscribe_service_calls(self.async_on_service_call)
await cli.subscribe_home_assistant_states(async_on_state_subscription) await cli.subscribe_home_assistant_states(self.async_on_state_subscription)
if device_info.voice_assistant_version: if device_info.voice_assistant_version:
entry_data.disconnect_callbacks.append( entry_data.disconnect_callbacks.append(
await cli.subscribe_voice_assistant( await cli.subscribe_voice_assistant(
_handle_pipeline_start, self._handle_pipeline_start,
_handle_pipeline_stop, self._handle_pipeline_stop,
) )
) )
hass.async_create_task(entry_data.async_save_to_store()) hass.async_create_task(entry_data.async_save_to_store())
except APIConnectionError as err: except APIConnectionError as err:
_LOGGER.warning("Error getting initial data for %s: %s", host, err) _LOGGER.warning("Error getting initial data for %s: %s", self.host, err)
# Re-connection logic will trigger after this # Re-connection logic will trigger after this
await cli.disconnect() await cli.disconnect()
else: else:
_async_check_firmware_version(hass, device_info, entry_data.api_version) _async_check_firmware_version(hass, device_info, entry_data.api_version)
_async_check_using_api_password(hass, device_info, bool(password)) _async_check_using_api_password(hass, device_info, bool(self.password))
async def on_disconnect(expected_disconnect: bool) -> None: async def on_disconnect(self, expected_disconnect: bool) -> None:
"""Run disconnect callbacks on API disconnect.""" """Run disconnect callbacks on API disconnect."""
entry_data = self.entry_data
hass = self.hass
host = self.host
name = entry_data.device_info.name if entry_data.device_info else host name = entry_data.device_info.name if entry_data.device_info else host
_LOGGER.debug( _LOGGER.debug(
"%s: %s disconnected (expected=%s), running disconnected callbacks", "%s: %s disconnected (expected=%s), running disconnected callbacks",
@ -453,7 +463,7 @@ async def async_setup_entry( # noqa: C901
# will be cleared anyway. # will be cleared anyway.
entry_data.async_update_device_state(hass) entry_data.async_update_device_state(hass)
async def on_connect_error(err: Exception) -> None: async def on_connect_error(self, err: Exception) -> None:
"""Start reauth flow if appropriate connect error type.""" """Start reauth flow if appropriate connect error type."""
if isinstance( if isinstance(
err, err,
@ -463,32 +473,85 @@ async def async_setup_entry( # noqa: C901
InvalidAuthAPIError, InvalidAuthAPIError,
), ),
): ):
entry.async_start_reauth(hass) self.entry.async_start_reauth(self.hass)
reconnect_logic = ReconnectLogic( async def async_start(self) -> None:
client=cli, """Start the esphome connection manager."""
on_connect=on_connect, hass = self.hass
on_disconnect=on_disconnect, entry = self.entry
entry_data = self.entry_data
if entry.options.get(CONF_ALLOW_SERVICE_CALLS, DEFAULT_ALLOW_SERVICE_CALLS):
async_delete_issue(hass, DOMAIN, self.services_issue)
# Use async_listen instead of async_listen_once so that we don't deregister
# the callback twice when shutting down Home Assistant.
# "Unable to remove unknown listener
# <function EventBus.async_listen_once.<locals>.onetime_listener>"
entry_data.cleanup_callbacks.append(
hass.bus.async_listen(EVENT_HOMEASSISTANT_STOP, self.on_stop)
)
reconnect_logic = ReconnectLogic(
client=self.cli,
on_connect=self.on_connect,
on_disconnect=self.on_disconnect,
zeroconf_instance=self.zeroconf_instance,
name=self.host,
on_connect_error=self.on_connect_error,
)
self.reconnect_logic = reconnect_logic
infos, services = await entry_data.async_load_from_store()
await entry_data.async_update_static_infos(hass, entry, infos)
await _setup_services(hass, entry_data, services)
if entry_data.device_info is not None and entry_data.device_info.name:
reconnect_logic.name = entry_data.device_info.name
if entry.unique_id is None:
hass.config_entries.async_update_entry(
entry, unique_id=format_mac(entry_data.device_info.mac_address)
)
await reconnect_logic.start()
entry_data.cleanup_callbacks.append(reconnect_logic.stop_callback)
entry.async_on_unload(
entry.add_update_listener(entry_data.async_update_listener)
)
async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
"""Set up the esphome component."""
host = entry.data[CONF_HOST]
port = entry.data[CONF_PORT]
password = entry.data[CONF_PASSWORD]
noise_psk = entry.data.get(CONF_NOISE_PSK)
zeroconf_instance = await zeroconf.async_get_instance(hass)
cli = APIClient(
host,
port,
password,
client_info=f"Home Assistant {ha_version}",
zeroconf_instance=zeroconf_instance, zeroconf_instance=zeroconf_instance,
name=host, noise_psk=noise_psk,
on_connect_error=on_connect_error,
) )
infos, services = await entry_data.async_load_from_store() domain_data = DomainData.get(hass)
await entry_data.async_update_static_infos(hass, entry, infos) entry_data = RuntimeEntryData(
await _setup_services(hass, entry_data, services) client=cli,
entry_id=entry.entry_id,
store=domain_data.get_or_create_store(hass, entry),
original_options=dict(entry.options),
)
domain_data.set_entry_data(entry, entry_data)
if entry_data.device_info is not None and entry_data.device_info.name: manager = ESPHomeManager(
reconnect_logic.name = entry_data.device_info.name hass, entry, host, password, cli, zeroconf_instance, domain_data, entry_data
if entry.unique_id is None: )
hass.config_entries.async_update_entry( await manager.async_start()
entry, unique_id=format_mac(entry_data.device_info.mac_address)
)
await reconnect_logic.start()
entry_data.cleanup_callbacks.append(reconnect_logic.stop_callback)
entry.async_on_unload(entry.add_update_listener(entry_data.async_update_listener))
return True return True