diff --git a/homeassistant/components/apple_tv/__init__.py b/homeassistant/components/apple_tv/__init__.py index d61c21972fb..45250451f37 100644 --- a/homeassistant/components/apple_tv/__init__.py +++ b/homeassistant/components/apple_tv/__init__.py @@ -23,6 +23,7 @@ from homeassistant.const import ( Platform, ) from homeassistant.core import HomeAssistant, callback +from homeassistant.exceptions import ConfigEntryNotReady from homeassistant.helpers import device_registry as dr from homeassistant.helpers.aiohttp_client import async_get_clientsession from homeassistant.helpers.dispatcher import ( @@ -49,6 +50,13 @@ PLATFORMS = [Platform.MEDIA_PLAYER, Platform.REMOTE] async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: """Set up a config entry for Apple TV.""" manager = AppleTVManager(hass, entry) + + if manager.is_on: + await manager.connect_once(raise_missing_credentials=True) + if not manager.atv: + address = entry.data[CONF_ADDRESS] + raise ConfigEntryNotReady(f"Not found at {address}, waiting for discovery") + hass.data.setdefault(DOMAIN, {})[entry.unique_id] = manager async def on_hass_stop(event): @@ -148,14 +156,14 @@ class AppleTVManager: self.config_entry = config_entry self.hass = hass self.atv = None - self._is_on = not config_entry.options.get(CONF_START_OFF, False) + self.is_on = not config_entry.options.get(CONF_START_OFF, False) self._connection_attempts = 0 self._connection_was_lost = False self._task = None async def init(self): """Initialize power management.""" - if self._is_on: + if self.is_on: await self.connect() def connection_lost(self, _): @@ -186,13 +194,13 @@ class AppleTVManager: async def connect(self): """Connect to device.""" - self._is_on = True + self.is_on = True self._start_connect_loop() async def disconnect(self): """Disconnect from device.""" _LOGGER.debug("Disconnecting from device") - self._is_on = False + self.is_on = False try: if self.atv: self.atv.close() @@ -205,50 +213,53 @@ class AppleTVManager: def _start_connect_loop(self): """Start background connect loop to device.""" - if not self._task and self.atv is None and self._is_on: + if not self._task and self.atv is None and self.is_on: self._task = asyncio.create_task(self._connect_loop()) else: _LOGGER.debug( - "Not starting connect loop (%s, %s)", self.atv is None, self._is_on + "Not starting connect loop (%s, %s)", self.atv is None, self.is_on ) + async def connect_once(self, raise_missing_credentials): + """Try to connect once.""" + try: + if conf := await self._scan(): + await self._connect(conf, raise_missing_credentials) + except exceptions.AuthenticationError: + self.config_entry.async_start_reauth(self.hass) + asyncio.create_task(self.disconnect()) + _LOGGER.exception( + "Authentication failed for %s, try reconfiguring device", + self.config_entry.data[CONF_NAME], + ) + return + except asyncio.CancelledError: + pass + except Exception: # pylint: disable=broad-except + _LOGGER.exception("Failed to connect") + self.atv = None + async def _connect_loop(self): """Connect loop background task function.""" _LOGGER.debug("Starting connect loop") # Try to find device and connect as long as the user has said that # we are allowed to connect and we are not already connected. - while self._is_on and self.atv is None: - try: - conf = await self._scan() - if conf: - await self._connect(conf) - except exceptions.AuthenticationError: - self.config_entry.async_start_reauth(self.hass) - asyncio.create_task(self.disconnect()) - _LOGGER.exception( - "Authentication failed for %s, try reconfiguring device", - self.config_entry.data[CONF_NAME], - ) + while self.is_on and self.atv is None: + await self.connect_once(raise_missing_credentials=False) + if self.atv is not None: break - except asyncio.CancelledError: - pass - except Exception: # pylint: disable=broad-except - _LOGGER.exception("Failed to connect") - self.atv = None + self._connection_attempts += 1 + backoff = min( + max( + BACKOFF_TIME_LOWER_LIMIT, + randrange(2**self._connection_attempts), + ), + BACKOFF_TIME_UPPER_LIMIT, + ) - if self.atv is None: - self._connection_attempts += 1 - backoff = min( - max( - BACKOFF_TIME_LOWER_LIMIT, - randrange(2**self._connection_attempts), - ), - BACKOFF_TIME_UPPER_LIMIT, - ) - - _LOGGER.debug("Reconnecting in %d seconds", backoff) - await asyncio.sleep(backoff) + _LOGGER.debug("Reconnecting in %d seconds", backoff) + await asyncio.sleep(backoff) _LOGGER.debug("Connect loop ended") self._task = None @@ -287,23 +298,33 @@ class AppleTVManager: # it will update the address and reload the config entry when the device is found. return None - async def _connect(self, conf): + async def _connect(self, conf, raise_missing_credentials): """Connect to device.""" credentials = self.config_entry.data[CONF_CREDENTIALS] - session = async_get_clientsession(self.hass) - + name = self.config_entry.data[CONF_NAME] + missing_protocols = [] for protocol_int, creds in credentials.items(): protocol = Protocol(int(protocol_int)) if conf.get_service(protocol) is not None: conf.set_credentials(protocol, creds) else: - _LOGGER.warning( - "Protocol %s not found for %s, functionality will be reduced", - protocol.name, - self.config_entry.data[CONF_NAME], + missing_protocols.append(protocol.name) + + if missing_protocols: + missing_protocols_str = ", ".join(missing_protocols) + if raise_missing_credentials: + raise ConfigEntryNotReady( + f"Protocol(s) {missing_protocols_str} not yet found for {name}, waiting for discovery." ) + _LOGGER.info( + "Protocol(s) %s not yet found for %s, trying later", + missing_protocols_str, + name, + ) + return _LOGGER.debug("Connecting to device %s", self.config_entry.data[CONF_NAME]) + session = async_get_clientsession(self.hass) self.atv = await connect(conf, self.hass.loop, session=session) self.atv.listener = self