Wait for discovery to complete before starting apple_tv (#74133)

This commit is contained in:
J. Nick Koston 2022-06-29 03:13:10 -05:00 committed by GitHub
parent 6a0ca2b36d
commit 99329ef04f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -23,6 +23,7 @@ from homeassistant.const import (
Platform, Platform,
) )
from homeassistant.core import HomeAssistant, callback from homeassistant.core import HomeAssistant, callback
from homeassistant.exceptions import ConfigEntryNotReady
from homeassistant.helpers import device_registry as dr from homeassistant.helpers import device_registry as dr
from homeassistant.helpers.aiohttp_client import async_get_clientsession from homeassistant.helpers.aiohttp_client import async_get_clientsession
from homeassistant.helpers.dispatcher import ( from homeassistant.helpers.dispatcher import (
@ -49,6 +50,13 @@ PLATFORMS = [Platform.MEDIA_PLAYER, Platform.REMOTE]
async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
"""Set up a config entry for Apple TV.""" """Set up a config entry for Apple TV."""
manager = AppleTVManager(hass, entry) manager = AppleTVManager(hass, entry)
if manager.is_on:
await manager.connect_once(raise_missing_credentials=True)
if not manager.atv:
address = entry.data[CONF_ADDRESS]
raise ConfigEntryNotReady(f"Not found at {address}, waiting for discovery")
hass.data.setdefault(DOMAIN, {})[entry.unique_id] = manager hass.data.setdefault(DOMAIN, {})[entry.unique_id] = manager
async def on_hass_stop(event): async def on_hass_stop(event):
@ -148,14 +156,14 @@ class AppleTVManager:
self.config_entry = config_entry self.config_entry = config_entry
self.hass = hass self.hass = hass
self.atv = None self.atv = None
self._is_on = not config_entry.options.get(CONF_START_OFF, False) self.is_on = not config_entry.options.get(CONF_START_OFF, False)
self._connection_attempts = 0 self._connection_attempts = 0
self._connection_was_lost = False self._connection_was_lost = False
self._task = None self._task = None
async def init(self): async def init(self):
"""Initialize power management.""" """Initialize power management."""
if self._is_on: if self.is_on:
await self.connect() await self.connect()
def connection_lost(self, _): def connection_lost(self, _):
@ -186,13 +194,13 @@ class AppleTVManager:
async def connect(self): async def connect(self):
"""Connect to device.""" """Connect to device."""
self._is_on = True self.is_on = True
self._start_connect_loop() self._start_connect_loop()
async def disconnect(self): async def disconnect(self):
"""Disconnect from device.""" """Disconnect from device."""
_LOGGER.debug("Disconnecting from device") _LOGGER.debug("Disconnecting from device")
self._is_on = False self.is_on = False
try: try:
if self.atv: if self.atv:
self.atv.close() self.atv.close()
@ -205,24 +213,18 @@ class AppleTVManager:
def _start_connect_loop(self): def _start_connect_loop(self):
"""Start background connect loop to device.""" """Start background connect loop to device."""
if not self._task and self.atv is None and self._is_on: if not self._task and self.atv is None and self.is_on:
self._task = asyncio.create_task(self._connect_loop()) self._task = asyncio.create_task(self._connect_loop())
else: else:
_LOGGER.debug( _LOGGER.debug(
"Not starting connect loop (%s, %s)", self.atv is None, self._is_on "Not starting connect loop (%s, %s)", self.atv is None, self.is_on
) )
async def _connect_loop(self): async def connect_once(self, raise_missing_credentials):
"""Connect loop background task function.""" """Try to connect once."""
_LOGGER.debug("Starting connect loop")
# Try to find device and connect as long as the user has said that
# we are allowed to connect and we are not already connected.
while self._is_on and self.atv is None:
try: try:
conf = await self._scan() if conf := await self._scan():
if conf: await self._connect(conf, raise_missing_credentials)
await self._connect(conf)
except exceptions.AuthenticationError: except exceptions.AuthenticationError:
self.config_entry.async_start_reauth(self.hass) self.config_entry.async_start_reauth(self.hass)
asyncio.create_task(self.disconnect()) asyncio.create_task(self.disconnect())
@ -230,14 +232,23 @@ class AppleTVManager:
"Authentication failed for %s, try reconfiguring device", "Authentication failed for %s, try reconfiguring device",
self.config_entry.data[CONF_NAME], self.config_entry.data[CONF_NAME],
) )
break return
except asyncio.CancelledError: except asyncio.CancelledError:
pass pass
except Exception: # pylint: disable=broad-except except Exception: # pylint: disable=broad-except
_LOGGER.exception("Failed to connect") _LOGGER.exception("Failed to connect")
self.atv = None self.atv = None
if self.atv is None: async def _connect_loop(self):
"""Connect loop background task function."""
_LOGGER.debug("Starting connect loop")
# Try to find device and connect as long as the user has said that
# we are allowed to connect and we are not already connected.
while self.is_on and self.atv is None:
await self.connect_once(raise_missing_credentials=False)
if self.atv is not None:
break
self._connection_attempts += 1 self._connection_attempts += 1
backoff = min( backoff = min(
max( max(
@ -287,23 +298,33 @@ class AppleTVManager:
# it will update the address and reload the config entry when the device is found. # it will update the address and reload the config entry when the device is found.
return None return None
async def _connect(self, conf): async def _connect(self, conf, raise_missing_credentials):
"""Connect to device.""" """Connect to device."""
credentials = self.config_entry.data[CONF_CREDENTIALS] credentials = self.config_entry.data[CONF_CREDENTIALS]
session = async_get_clientsession(self.hass) name = self.config_entry.data[CONF_NAME]
missing_protocols = []
for protocol_int, creds in credentials.items(): for protocol_int, creds in credentials.items():
protocol = Protocol(int(protocol_int)) protocol = Protocol(int(protocol_int))
if conf.get_service(protocol) is not None: if conf.get_service(protocol) is not None:
conf.set_credentials(protocol, creds) conf.set_credentials(protocol, creds)
else: else:
_LOGGER.warning( missing_protocols.append(protocol.name)
"Protocol %s not found for %s, functionality will be reduced",
protocol.name, if missing_protocols:
self.config_entry.data[CONF_NAME], missing_protocols_str = ", ".join(missing_protocols)
if raise_missing_credentials:
raise ConfigEntryNotReady(
f"Protocol(s) {missing_protocols_str} not yet found for {name}, waiting for discovery."
) )
_LOGGER.info(
"Protocol(s) %s not yet found for %s, trying later",
missing_protocols_str,
name,
)
return
_LOGGER.debug("Connecting to device %s", self.config_entry.data[CONF_NAME]) _LOGGER.debug("Connecting to device %s", self.config_entry.data[CONF_NAME])
session = async_get_clientsession(self.hass)
self.atv = await connect(conf, self.hass.loop, session=session) self.atv = await connect(conf, self.hass.loop, session=session)
self.atv.listener = self self.atv.listener = self