From 636c7ce3504a033df22a7f15de55c8b3241452b1 Mon Sep 17 00:00:00 2001 From: Stackie Jia Date: Thu, 15 Feb 2024 22:17:00 +0800 Subject: [PATCH] Enable strict type checking on apple_tv integration (#101688) * Enable strict type checking on apple_tv integration * move some instance variables to class variables * fix type of attr_value * fix tests for description_placeholders assertion * nits * Apply suggestions from code review * Update remote.py * Apply suggestions from code review * Update __init__.py * Update __init__.py * Update __init__.py * Update config_flow.py * Improve test coverage * Update test_config_flow.py * Update __init__.py --------- Co-authored-by: Joost Lekkerkerker Co-authored-by: Erik Montnemery --- .strict-typing | 1 + homeassistant/components/apple_tv/__init__.py | 64 +++++----- .../components/apple_tv/config_flow.py | 112 +++++++++++++----- .../components/apple_tv/media_player.py | 109 ++++++++++++----- homeassistant/components/apple_tv/remote.py | 16 +-- mypy.ini | 10 ++ tests/components/apple_tv/test_config_flow.py | 45 ++++++- 7 files changed, 253 insertions(+), 104 deletions(-) diff --git a/.strict-typing b/.strict-typing index bd92da2fc50..74535719bb3 100644 --- a/.strict-typing +++ b/.strict-typing @@ -80,6 +80,7 @@ homeassistant.components.anthemav.* homeassistant.components.apache_kafka.* homeassistant.components.apcupsd.* homeassistant.components.api.* +homeassistant.components.apple_tv.* homeassistant.components.apprise.* homeassistant.components.aprs.* homeassistant.components.aqualogic.* diff --git a/homeassistant/components/apple_tv/__init__.py b/homeassistant/components/apple_tv/__init__.py index 875a23c3244..c369b07de36 100644 --- a/homeassistant/components/apple_tv/__init__.py +++ b/homeassistant/components/apple_tv/__init__.py @@ -1,8 +1,10 @@ """The Apple TV integration.""" +from __future__ import annotations + import asyncio import logging from random import randrange -from typing import TYPE_CHECKING, cast +from typing import Any, cast from pyatv import connect, exceptions, scan from pyatv.conf import AppleTV @@ -25,7 +27,7 @@ from homeassistant.const import ( EVENT_HOMEASSISTANT_STOP, Platform, ) -from homeassistant.core import HomeAssistant, callback +from homeassistant.core import Event, HomeAssistant, callback from homeassistant.exceptions import ConfigEntryAuthFailed, ConfigEntryNotReady from homeassistant.helpers import device_registry as dr from homeassistant.helpers.aiohttp_client import async_get_clientsession @@ -89,7 +91,7 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: hass.data.setdefault(DOMAIN, {})[entry.unique_id] = manager - async def on_hass_stop(event): + async def on_hass_stop(event: Event) -> None: """Stop push updates when hass stops.""" await manager.disconnect() @@ -120,33 +122,29 @@ class AppleTVEntity(Entity): _attr_should_poll = False _attr_has_entity_name = True _attr_name = None + atv: AppleTVInterface | None = None - def __init__( - self, name: str, identifier: str | None, manager: "AppleTVManager" - ) -> None: + def __init__(self, name: str, identifier: str, manager: AppleTVManager) -> None: """Initialize device.""" - self.atv: AppleTVInterface = None # type: ignore[assignment] self.manager = manager - if TYPE_CHECKING: - assert identifier is not None self._attr_unique_id = identifier self._attr_device_info = DeviceInfo( identifiers={(DOMAIN, identifier)}, name=name, ) - async def async_added_to_hass(self): + async def async_added_to_hass(self) -> None: """Handle when an entity is about to be added to Home Assistant.""" @callback - def _async_connected(atv): + def _async_connected(atv: AppleTVInterface) -> None: """Handle that a connection was made to a device.""" self.atv = atv self.async_device_connected(atv) self.async_write_ha_state() @callback - def _async_disconnected(): + def _async_disconnected() -> None: """Handle that a connection to a device was lost.""" self.async_device_disconnected() self.atv = None @@ -169,10 +167,10 @@ class AppleTVEntity(Entity): ) ) - def async_device_connected(self, atv): + def async_device_connected(self, atv: AppleTVInterface) -> None: """Handle when connection is made to device.""" - def async_device_disconnected(self): + def async_device_disconnected(self) -> None: """Handle when connection was lost to device.""" @@ -184,22 +182,23 @@ class AppleTVManager(DeviceListener): in case of problems. """ + atv: AppleTVInterface | None = None + _connection_attempts = 0 + _connection_was_lost = False + _task: asyncio.Task[None] | None = None + def __init__(self, hass: HomeAssistant, config_entry: ConfigEntry) -> None: """Initialize power manager.""" self.config_entry = config_entry self.hass = hass - self.atv: AppleTVInterface | None = None self.is_on = not config_entry.options.get(CONF_START_OFF, False) - self._connection_attempts = 0 - self._connection_was_lost = False - self._task = None - async def init(self): + async def init(self) -> None: """Initialize power management.""" if self.is_on: await self.connect() - def connection_lost(self, _): + def connection_lost(self, exception: Exception) -> None: """Device was unexpectedly disconnected. This is a callback function from pyatv.interface.DeviceListener. @@ -210,14 +209,14 @@ class AppleTVManager(DeviceListener): self._connection_was_lost = True self._handle_disconnect() - def connection_closed(self): + def connection_closed(self) -> None: """Device connection was (intentionally) closed. This is a callback function from pyatv.interface.DeviceListener. """ self._handle_disconnect() - def _handle_disconnect(self): + def _handle_disconnect(self) -> None: """Handle that the device disconnected and restart connect loop.""" if self.atv: self.atv.close() @@ -225,12 +224,12 @@ class AppleTVManager(DeviceListener): self._dispatch_send(SIGNAL_DISCONNECTED) self._start_connect_loop() - async def connect(self): + async def connect(self) -> None: """Connect to device.""" self.is_on = True self._start_connect_loop() - async def disconnect(self): + async def disconnect(self) -> None: """Disconnect from device.""" _LOGGER.debug("Disconnecting from device") self.is_on = False @@ -244,7 +243,7 @@ class AppleTVManager(DeviceListener): except Exception: # pylint: disable=broad-except _LOGGER.exception("An error occurred while disconnecting") - def _start_connect_loop(self): + def _start_connect_loop(self) -> None: """Start background connect loop to device.""" if not self._task and self.atv is None and self.is_on: self._task = asyncio.create_task(self._connect_loop()) @@ -258,7 +257,7 @@ class AppleTVManager(DeviceListener): if conf := await self._scan(): await self._connect(conf, raise_missing_credentials) - async def async_first_connect(self): + async def async_first_connect(self) -> None: """Connect to device for the first time.""" connect_ok = False try: @@ -286,7 +285,7 @@ class AppleTVManager(DeviceListener): _LOGGER.exception("Failed to connect") await self.disconnect() - async def _connect_loop(self): + async def _connect_loop(self) -> None: """Connect loop background task function.""" _LOGGER.debug("Starting connect loop") @@ -295,7 +294,8 @@ class AppleTVManager(DeviceListener): while self.is_on and self.atv is None: await self.connect_once(raise_missing_credentials=False) if self.atv is not None: - break + # Calling self.connect_once may have set self.atv + break # type: ignore[unreachable] self._connection_attempts += 1 backoff = min( max( @@ -392,7 +392,7 @@ class AppleTVManager(DeviceListener): self._connection_was_lost = False @callback - def _async_setup_device_registry(self): + def _async_setup_device_registry(self) -> None: attrs = { ATTR_IDENTIFIERS: {(DOMAIN, self.config_entry.unique_id)}, ATTR_MANUFACTURER: "Apple", @@ -423,18 +423,18 @@ class AppleTVManager(DeviceListener): ) @property - def is_connecting(self): + def is_connecting(self) -> bool: """Return true if connection is in progress.""" return self._task is not None - def _address_updated(self, address): + def _address_updated(self, address: str) -> None: """Update cached address in config entry.""" _LOGGER.debug("Changing address to %s", address) self.hass.config_entries.async_update_entry( self.config_entry, data={**self.config_entry.data, CONF_ADDRESS: address} ) - def _dispatch_send(self, signal, *args): + def _dispatch_send(self, signal: str, *args: Any) -> None: """Dispatch a signal to all entities managed by this manager.""" async_dispatcher_send( self.hass, f"{signal}_{self.config_entry.unique_id}", *args diff --git a/homeassistant/components/apple_tv/config_flow.py b/homeassistant/components/apple_tv/config_flow.py index 11d408ee2ca..2bb4608dca1 100644 --- a/homeassistant/components/apple_tv/config_flow.py +++ b/homeassistant/components/apple_tv/config_flow.py @@ -3,7 +3,7 @@ from __future__ import annotations import asyncio from collections import deque -from collections.abc import Mapping +from collections.abc import Awaitable, Callable, Mapping from ipaddress import ip_address import logging from random import randrange @@ -13,12 +13,13 @@ from pyatv import exceptions, pair, scan from pyatv.const import DeviceModel, PairingRequirement, Protocol from pyatv.convert import model_str, protocol_str from pyatv.helpers import get_unique_id +from pyatv.interface import BaseConfig, PairingHandler import voluptuous as vol from homeassistant import config_entries from homeassistant.components import zeroconf from homeassistant.const import CONF_ADDRESS, CONF_NAME, CONF_PIN -from homeassistant.core import callback +from homeassistant.core import HomeAssistant, callback from homeassistant.data_entry_flow import AbortFlow, FlowResult from homeassistant.exceptions import HomeAssistantError from homeassistant.helpers.aiohttp_client import async_get_clientsession @@ -49,10 +50,12 @@ OPTIONS_FLOW = { } -async def device_scan(hass, identifier, loop): +async def device_scan( + hass: HomeAssistant, identifier: str | None, loop: asyncio.AbstractEventLoop +) -> tuple[BaseConfig | None, list[str] | None]: """Scan for a specific device using identifier as filter.""" - def _filter_device(dev): + def _filter_device(dev: BaseConfig) -> bool: if identifier is None: return True if identifier == str(dev.address): @@ -61,9 +64,12 @@ async def device_scan(hass, identifier, loop): return True return any(service.identifier == identifier for service in dev.services) - def _host_filter(): + def _host_filter() -> list[str] | None: + if identifier is None: + return None try: - return [ip_address(identifier)] + ip_address(identifier) + return [identifier] except ValueError: return None @@ -84,6 +90,13 @@ class AppleTVConfigFlow(config_entries.ConfigFlow, domain=DOMAIN): VERSION = 1 + scan_filter: str | None = None + atv: BaseConfig | None = None + atv_identifiers: list[str] | None = None + protocol: Protocol | None = None + pairing: PairingHandler | None = None + protocols_to_pair: deque[Protocol] | None = None + @staticmethod @callback def async_get_options_flow( @@ -92,18 +105,12 @@ class AppleTVConfigFlow(config_entries.ConfigFlow, domain=DOMAIN): """Get options flow for this handler.""" return SchemaOptionsFlowHandler(config_entry, OPTIONS_FLOW) - def __init__(self): + def __init__(self) -> None: """Initialize a new AppleTVConfigFlow.""" - self.scan_filter = None - self.atv = None - self.atv_identifiers = None - self.protocol = None - self.pairing = None - self.credentials = {} # Protocol -> credentials - self.protocols_to_pair = deque() + self.credentials: dict[int, str | None] = {} # Protocol -> credentials @property - def device_identifier(self): + def device_identifier(self) -> str | None: """Return a identifier for the config entry. A device has multiple unique identifiers, but Home Assistant only supports one @@ -118,6 +125,7 @@ class AppleTVConfigFlow(config_entries.ConfigFlow, domain=DOMAIN): existing config entry. If that's the case, the unique_id from that entry is re-used, otherwise the newly discovered identifier is used instead. """ + assert self.atv all_identifiers = set(self.atv.all_identifiers) if unique_id := self._entry_unique_id_from_identifers(all_identifiers): return unique_id @@ -143,7 +151,9 @@ class AppleTVConfigFlow(config_entries.ConfigFlow, domain=DOMAIN): self.context["identifier"] = self.unique_id return await self.async_step_reconfigure() - async def async_step_reconfigure(self, user_input=None): + async def async_step_reconfigure( + self, user_input: dict[str, str] | None = None + ) -> FlowResult: """Inform user that reconfiguration is about to start.""" if user_input is not None: return await self.async_find_device_wrapper( @@ -152,7 +162,9 @@ class AppleTVConfigFlow(config_entries.ConfigFlow, domain=DOMAIN): return self.async_show_form(step_id="reconfigure") - async def async_step_user(self, user_input=None): + async def async_step_user( + self, user_input: dict[str, str] | None = None + ) -> FlowResult: """Handle the initial step.""" errors = {} if user_input is not None: @@ -170,6 +182,7 @@ class AppleTVConfigFlow(config_entries.ConfigFlow, domain=DOMAIN): await self.async_set_unique_id( self.device_identifier, raise_on_progress=False ) + assert self.atv self.context["all_identifiers"] = self.atv.all_identifiers return await self.async_step_confirm() @@ -275,8 +288,11 @@ class AppleTVConfigFlow(config_entries.ConfigFlow, domain=DOMAIN): context["all_identifiers"].append(unique_id) raise AbortFlow("already_in_progress") - async def async_found_zeroconf_device(self, user_input=None): + async def async_found_zeroconf_device( + self, user_input: dict[str, str] | None = None + ) -> FlowResult: """Handle device found after Zeroconf discovery.""" + assert self.atv self.context["all_identifiers"] = self.atv.all_identifiers # Also abort if an integration with this identifier already exists await self.async_set_unique_id(self.device_identifier) @@ -288,7 +304,11 @@ class AppleTVConfigFlow(config_entries.ConfigFlow, domain=DOMAIN): self.context["identifier"] = self.unique_id return await self.async_step_confirm() - async def async_find_device_wrapper(self, next_func, allow_exist=False): + async def async_find_device_wrapper( + self, + next_func: Callable[[], Awaitable[FlowResult]], + allow_exist: bool = False, + ) -> FlowResult: """Find a specific device and call another function when done. This function will do error handling and bail out when an error @@ -306,7 +326,7 @@ class AppleTVConfigFlow(config_entries.ConfigFlow, domain=DOMAIN): return await next_func() - async def async_find_device(self, allow_exist=False): + async def async_find_device(self, allow_exist: bool = False) -> None: """Scan for the selected device to discover services.""" self.atv, self.atv_identifiers = await device_scan( self.hass, self.scan_filter, self.hass.loop @@ -357,8 +377,11 @@ class AppleTVConfigFlow(config_entries.ConfigFlow, domain=DOMAIN): if not allow_exist: raise DeviceAlreadyConfigured() - async def async_step_confirm(self, user_input=None): + async def async_step_confirm( + self, user_input: dict[str, str] | None = None + ) -> FlowResult: """Handle user-confirmation of discovered node.""" + assert self.atv if user_input is not None: expected_identifier_count = len(self.context["all_identifiers"]) # If number of services found during device scan mismatch number of @@ -384,7 +407,7 @@ class AppleTVConfigFlow(config_entries.ConfigFlow, domain=DOMAIN): }, ) - async def async_pair_next_protocol(self): + async def async_pair_next_protocol(self) -> FlowResult: """Start pairing process for the next available protocol.""" await self._async_cleanup() @@ -393,8 +416,16 @@ class AppleTVConfigFlow(config_entries.ConfigFlow, domain=DOMAIN): return await self._async_get_entry() self.protocol = self.protocols_to_pair.popleft() + assert self.atv service = self.atv.get_service(self.protocol) + if service is None: + _LOGGER.debug( + "%s does not support pairing (cannot find a corresponding service)", + self.protocol, + ) + return await self.async_pair_next_protocol() + # Service requires a password if service.requires_password: return await self.async_step_password() @@ -413,7 +444,7 @@ class AppleTVConfigFlow(config_entries.ConfigFlow, domain=DOMAIN): _LOGGER.debug("%s requires pairing", self.protocol) # Protocol specific arguments - pair_args = {} + pair_args: dict[str, Any] = {} if self.protocol in {Protocol.AirPlay, Protocol.Companion, Protocol.DMAP}: pair_args["name"] = "Home Assistant" if self.protocol == Protocol.DMAP: @@ -448,8 +479,11 @@ class AppleTVConfigFlow(config_entries.ConfigFlow, domain=DOMAIN): return await self.async_step_pair_no_pin() - async def async_step_protocol_disabled(self, user_input=None): + async def async_step_protocol_disabled( + self, user_input: dict[str, str] | None = None + ) -> FlowResult: """Inform user that a protocol is disabled and cannot be paired.""" + assert self.protocol if user_input is not None: return await self.async_pair_next_protocol() return self.async_show_form( @@ -457,9 +491,13 @@ class AppleTVConfigFlow(config_entries.ConfigFlow, domain=DOMAIN): description_placeholders={"protocol": protocol_str(self.protocol)}, ) - async def async_step_pair_with_pin(self, user_input=None): + async def async_step_pair_with_pin( + self, user_input: dict[str, str] | None = None + ) -> FlowResult: """Handle pairing step where a PIN is required from the user.""" errors = {} + assert self.pairing + assert self.protocol if user_input is not None: try: self.pairing.pin(user_input[CONF_PIN]) @@ -480,8 +518,12 @@ class AppleTVConfigFlow(config_entries.ConfigFlow, domain=DOMAIN): description_placeholders={"protocol": protocol_str(self.protocol)}, ) - async def async_step_pair_no_pin(self, user_input=None): + async def async_step_pair_no_pin( + self, user_input: dict[str, str] | None = None + ) -> FlowResult: """Handle step where user has to enter a PIN on the device.""" + assert self.pairing + assert self.protocol if user_input is not None: await self.pairing.finish() if self.pairing.has_paired: @@ -497,12 +539,15 @@ class AppleTVConfigFlow(config_entries.ConfigFlow, domain=DOMAIN): step_id="pair_no_pin", description_placeholders={ "protocol": protocol_str(self.protocol), - "pin": pin, + "pin": str(pin), }, ) - async def async_step_service_problem(self, user_input=None): + async def async_step_service_problem( + self, user_input: dict[str, str] | None = None + ) -> FlowResult: """Inform user that a service will not be added.""" + assert self.protocol if user_input is not None: return await self.async_pair_next_protocol() @@ -511,8 +556,11 @@ class AppleTVConfigFlow(config_entries.ConfigFlow, domain=DOMAIN): description_placeholders={"protocol": protocol_str(self.protocol)}, ) - async def async_step_password(self, user_input=None): + async def async_step_password( + self, user_input: dict[str, str] | None = None + ) -> FlowResult: """Inform user that password is not supported.""" + assert self.protocol if user_input is not None: return await self.async_pair_next_protocol() @@ -521,18 +569,20 @@ class AppleTVConfigFlow(config_entries.ConfigFlow, domain=DOMAIN): description_placeholders={"protocol": protocol_str(self.protocol)}, ) - async def _async_cleanup(self): + async def _async_cleanup(self) -> None: """Clean up allocated resources.""" if self.pairing is not None: await self.pairing.close() self.pairing = None - async def _async_get_entry(self): + async def _async_get_entry(self) -> FlowResult: """Return config entry or update existing config entry.""" # Abort if no protocols were paired if not self.credentials: return self.async_abort(reason="setup_failed") + assert self.atv + data = { CONF_NAME: self.atv.name, CONF_CREDENTIALS: self.credentials, diff --git a/homeassistant/components/apple_tv/media_player.py b/homeassistant/components/apple_tv/media_player.py index 789415a1717..a7b5957ecff 100644 --- a/homeassistant/components/apple_tv/media_player.py +++ b/homeassistant/components/apple_tv/media_player.py @@ -16,7 +16,15 @@ from pyatv.const import ( ShuffleState, ) from pyatv.helpers import is_streamable -from pyatv.interface import AppleTV, Playing +from pyatv.interface import ( + AppleTV, + AudioListener, + OutputDevice, + Playing, + PowerListener, + PushListener, + PushUpdater, +) from homeassistant.components import media_source from homeassistant.components.media_player import ( @@ -101,7 +109,9 @@ async def async_setup_entry( async_add_entities([AppleTvMediaPlayer(name, config_entry.unique_id, manager)]) -class AppleTvMediaPlayer(AppleTVEntity, MediaPlayerEntity): +class AppleTvMediaPlayer( + AppleTVEntity, MediaPlayerEntity, PowerListener, AudioListener, PushListener +): """Representation of an Apple TV media player.""" _attr_supported_features = SUPPORT_APPLE_TV @@ -116,9 +126,9 @@ class AppleTvMediaPlayer(AppleTVEntity, MediaPlayerEntity): def async_device_connected(self, atv: AppleTV) -> None: """Handle when connection is made to device.""" # NB: Do not use _is_feature_available here as it only works when playing - if self.atv.features.in_state(FeatureState.Available, FeatureName.PushUpdates): - self.atv.push_updater.listener = self - self.atv.push_updater.start() + if atv.features.in_state(FeatureState.Available, FeatureName.PushUpdates): + atv.push_updater.listener = self + atv.push_updater.start() self._attr_supported_features = SUPPORT_BASE @@ -126,7 +136,7 @@ class AppleTvMediaPlayer(AppleTVEntity, MediaPlayerEntity): # "Unsupported" are considered here as the state of such a feature can never # change after a connection has been established, i.e. an unsupported feature # can never change to be supported. - all_features = self.atv.features.all_features() + all_features = atv.features.all_features() for feature_name, support_flag in SUPPORT_FEATURE_MAPPING.items(): feature_info = all_features.get(feature_name) if feature_info and feature_info.state != FeatureState.Unsupported: @@ -136,16 +146,18 @@ class AppleTvMediaPlayer(AppleTVEntity, MediaPlayerEntity): # metadata update arrives (sometime very soon after this callback returns) # Listen to power updates - self.atv.power.listener = self + atv.power.listener = self # Listen to volume updates - self.atv.audio.listener = self + atv.audio.listener = self - if self.atv.features.in_state(FeatureState.Available, FeatureName.AppList): + if atv.features.in_state(FeatureState.Available, FeatureName.AppList): self.hass.create_task(self._update_app_list()) async def _update_app_list(self) -> None: _LOGGER.debug("Updating app list") + if not self.atv: + return try: apps = await self.atv.apps.app_list() except exceptions.NotSupportedError: @@ -189,33 +201,56 @@ class AppleTvMediaPlayer(AppleTVEntity, MediaPlayerEntity): return None @callback - def playstatus_update(self, _, playing: Playing) -> None: - """Print what is currently playing when it changes.""" - self._playing = playing + def playstatus_update(self, updater: PushUpdater, playstatus: Playing) -> None: + """Print what is currently playing when it changes. + + This is a callback function from pyatv.interface.PushListener. + """ + self._playing = playstatus self.async_write_ha_state() @callback - def playstatus_error(self, _, exception: Exception) -> None: - """Inform about an error and restart push updates.""" + def playstatus_error(self, updater: PushUpdater, exception: Exception) -> None: + """Inform about an error and restart push updates. + + This is a callback function from pyatv.interface.PushListener. + """ _LOGGER.warning("A %s error occurred: %s", exception.__class__, exception) self._playing = None self.async_write_ha_state() @callback def powerstate_update(self, old_state: PowerState, new_state: PowerState) -> None: - """Update power state when it changes.""" + """Update power state when it changes. + + This is a callback function from pyatv.interface.PowerListener. + """ self.async_write_ha_state() @callback def volume_update(self, old_level: float, new_level: float) -> None: - """Update volume when it changes.""" + """Update volume when it changes. + + This is a callback function from pyatv.interface.AudioListener. + """ self.async_write_ha_state() + @callback + def outputdevices_update( + self, old_devices: list[OutputDevice], new_devices: list[OutputDevice] + ) -> None: + """Output devices were updated. + + This is a callback function from pyatv.interface.AudioListener. + """ + @property def app_id(self) -> str | None: """ID of the current running app.""" - if self._is_feature_available(FeatureName.App) and ( - app := self.atv.metadata.app + if ( + self.atv + and self._is_feature_available(FeatureName.App) + and (app := self.atv.metadata.app) is not None ): return app.identifier return None @@ -223,8 +258,10 @@ class AppleTvMediaPlayer(AppleTVEntity, MediaPlayerEntity): @property def app_name(self) -> str | None: """Name of the current running app.""" - if self._is_feature_available(FeatureName.App) and ( - app := self.atv.metadata.app + if ( + self.atv + and self._is_feature_available(FeatureName.App) + and (app := self.atv.metadata.app) is not None ): return app.name return None @@ -255,7 +292,7 @@ class AppleTvMediaPlayer(AppleTVEntity, MediaPlayerEntity): @property def volume_level(self) -> float | None: """Volume level of the media player (0..1).""" - if self._is_feature_available(FeatureName.Volume): + if self.atv and self._is_feature_available(FeatureName.Volume): return self.atv.audio.volume / 100.0 # from percent return None @@ -286,6 +323,8 @@ class AppleTvMediaPlayer(AppleTVEntity, MediaPlayerEntity): """Send the play_media command to the media player.""" # If input (file) has a file format supported by pyatv, then stream it with # RAOP. Otherwise try to play it with regular AirPlay. + if not self.atv: + return if media_type in {MediaType.APP, MediaType.URL}: await self.atv.apps.launch_app(media_id) return @@ -313,7 +352,8 @@ class AppleTvMediaPlayer(AppleTVEntity, MediaPlayerEntity): """Hash value for media image.""" state = self.state if ( - self._playing + self.atv + and self._playing and self._is_feature_available(FeatureName.Artwork) and state not in {None, MediaPlayerState.OFF, MediaPlayerState.IDLE} ): @@ -323,7 +363,11 @@ class AppleTvMediaPlayer(AppleTVEntity, MediaPlayerEntity): async def async_get_media_image(self) -> tuple[bytes | None, str | None]: """Fetch media image of current playing image.""" state = self.state - if self._playing and state not in {MediaPlayerState.OFF, MediaPlayerState.IDLE}: + if ( + self.atv + and self._playing + and state not in {MediaPlayerState.OFF, MediaPlayerState.IDLE} + ): artwork = await self.atv.metadata.artwork() if artwork: return artwork.bytes, artwork.mimetype @@ -439,20 +483,24 @@ class AppleTvMediaPlayer(AppleTVEntity, MediaPlayerEntity): async def async_turn_on(self) -> None: """Turn the media player on.""" - if self._is_feature_available(FeatureName.TurnOn): + if self.atv and self._is_feature_available(FeatureName.TurnOn): await self.atv.power.turn_on() async def async_turn_off(self) -> None: """Turn the media player off.""" - if (self._is_feature_available(FeatureName.TurnOff)) and ( - not self._is_feature_available(FeatureName.PowerState) - or self.atv.power.power_state == PowerState.On + if ( + self.atv + and (self._is_feature_available(FeatureName.TurnOff)) + and ( + not self._is_feature_available(FeatureName.PowerState) + or self.atv.power.power_state == PowerState.On + ) ): await self.atv.power.turn_off() async def async_media_play_pause(self) -> None: """Pause media on media player.""" - if self._playing: + if self.atv and self._playing: await self.atv.remote_control.play_pause() async def async_media_play(self) -> None: @@ -519,5 +567,6 @@ class AppleTvMediaPlayer(AppleTVEntity, MediaPlayerEntity): async def async_select_source(self, source: str) -> None: """Select input source.""" - if app_id := self._app_list.get(source): - await self.atv.apps.launch_app(app_id) + if self.atv: + if app_id := self._app_list.get(source): + await self.atv.apps.launch_app(app_id) diff --git a/homeassistant/components/apple_tv/remote.py b/homeassistant/components/apple_tv/remote.py index 24d2ef68ed4..7baa6321f21 100644 --- a/homeassistant/components/apple_tv/remote.py +++ b/homeassistant/components/apple_tv/remote.py @@ -15,7 +15,7 @@ from homeassistant.const import CONF_NAME from homeassistant.core import HomeAssistant from homeassistant.helpers.entity_platform import AddEntitiesCallback -from . import AppleTVEntity +from . import AppleTVEntity, AppleTVManager from .const import DOMAIN _LOGGER = logging.getLogger(__name__) @@ -38,8 +38,10 @@ async def async_setup_entry( async_add_entities: AddEntitiesCallback, ) -> None: """Load Apple TV remote based on a config entry.""" - name = config_entry.data[CONF_NAME] - manager = hass.data[DOMAIN][config_entry.unique_id] + name: str = config_entry.data[CONF_NAME] + # apple_tv config entries always have a unique id + assert config_entry.unique_id is not None + manager: AppleTVManager = hass.data[DOMAIN][config_entry.unique_id] async_add_entities([AppleTVRemote(name, config_entry.unique_id, manager)]) @@ -47,7 +49,7 @@ class AppleTVRemote(AppleTVEntity, RemoteEntity): """Device that sends commands to an Apple TV.""" @property - def is_on(self): + def is_on(self) -> bool: """Return true if device is on.""" return self.atv is not None @@ -64,13 +66,13 @@ class AppleTVRemote(AppleTVEntity, RemoteEntity): num_repeats = kwargs[ATTR_NUM_REPEATS] delay = kwargs.get(ATTR_DELAY_SECS, DEFAULT_DELAY_SECS) - if not self.is_on: + if not self.atv: _LOGGER.error("Unable to send commands, not connected to %s", self.name) return for _ in range(num_repeats): for single_command in command: - attr_value = None + attr_value: Any = None if attributes := COMMAND_TO_ATTRIBUTE.get(single_command): attr_value = self.atv for attr_name in attributes: @@ -81,5 +83,5 @@ class AppleTVRemote(AppleTVEntity, RemoteEntity): raise ValueError("Command not found. Exiting sequence") _LOGGER.info("Sending command %s", single_command) - await attr_value() # type: ignore[operator] + await attr_value() await asyncio.sleep(delay) diff --git a/mypy.ini b/mypy.ini index 6bafe51e1a0..224508fb6bc 100644 --- a/mypy.ini +++ b/mypy.ini @@ -560,6 +560,16 @@ disallow_untyped_defs = true warn_return_any = true warn_unreachable = true +[mypy-homeassistant.components.apple_tv.*] +check_untyped_defs = true +disallow_incomplete_defs = true +disallow_subclassing_any = true +disallow_untyped_calls = true +disallow_untyped_decorators = true +disallow_untyped_defs = true +warn_return_any = true +warn_unreachable = true + [mypy-homeassistant.components.apprise.*] check_untyped_defs = true disallow_incomplete_defs = true diff --git a/tests/components/apple_tv/test_config_flow.py b/tests/components/apple_tv/test_config_flow.py index 714fe987bc8..9b9020c3cf1 100644 --- a/tests/components/apple_tv/test_config_flow.py +++ b/tests/components/apple_tv/test_config_flow.py @@ -1,6 +1,6 @@ """Test config flow.""" from ipaddress import IPv4Address, ip_address -from unittest.mock import ANY, patch +from unittest.mock import ANY, Mock, patch from pyatv import exceptions from pyatv.const import PairingRequirement, Protocol @@ -125,7 +125,7 @@ async def test_user_adds_full_device(hass: HomeAssistant, full_device, pairing) result["flow_id"], {"pin": 1111} ) assert result4["type"] == data_entry_flow.FlowResultType.FORM - assert result4["description_placeholders"] == {"protocol": "DMAP", "pin": 1111} + assert result4["description_placeholders"] == {"protocol": "DMAP", "pin": "1111"} result5 = await hass.config_entries.flow.async_configure(result["flow_id"], {}) assert result5["type"] == data_entry_flow.FlowResultType.FORM @@ -167,7 +167,7 @@ async def test_user_adds_dmap_device( result3 = await hass.config_entries.flow.async_configure(result["flow_id"], {}) assert result3["type"] == data_entry_flow.FlowResultType.FORM - assert result3["description_placeholders"] == {"pin": 1111, "protocol": "DMAP"} + assert result3["description_placeholders"] == {"pin": "1111", "protocol": "DMAP"} result6 = await hass.config_entries.flow.async_configure( result["flow_id"], {"pin": 1234} @@ -646,7 +646,7 @@ async def test_zeroconf_add_dmap_device( {}, ) assert result2["type"] == data_entry_flow.FlowResultType.FORM - assert result2["description_placeholders"] == {"protocol": "DMAP", "pin": 1111} + assert result2["description_placeholders"] == {"protocol": "DMAP", "pin": "1111"} result3 = await hass.config_entries.flow.async_configure(result["flow_id"], {}) assert result3["type"] == "create_entry" @@ -1130,6 +1130,43 @@ async def test_zeroconf_pair_additionally_found_protocols( assert result5["type"] == data_entry_flow.FlowResultType.CREATE_ENTRY +async def test_zeroconf_mismatch( + hass: HomeAssistant, mock_scan, pairing, mock_zeroconf: None +) -> None: + """Test the technically possible case where a protocol has no service. + + This could happen in case of mDNS issues. + """ + mock_scan.result = [ + create_conf(IPv4Address("127.0.0.1"), "Device", airplay_service()) + ] + mock_scan.result[0].get_service = Mock(return_value=None) + + # Find device with AirPlay service and set up flow for it + result = await hass.config_entries.flow.async_init( + DOMAIN, + context={"source": config_entries.SOURCE_ZEROCONF}, + data=zeroconf.ZeroconfServiceInfo( + ip_address=ip_address("127.0.0.1"), + ip_addresses=[ip_address("127.0.0.1")], + hostname="mock_hostname", + port=None, + type="_airplay._tcp.local.", + name="Kitchen", + properties={"deviceid": "airplayid"}, + ), + ) + assert result["type"] == data_entry_flow.FlowResultType.FORM + await hass.async_block_till_done() + + result = await hass.config_entries.flow.async_configure( + result["flow_id"], + {}, + ) + assert result["type"] == data_entry_flow.FlowResultType.ABORT + assert result["reason"] == "setup_failed" + + # Re-configuration