mirror of
https://github.com/home-assistant/core.git
synced 2025-04-24 17:27:52 +00:00
Enable strict type checking on apple_tv integration (#101688)
* Enable strict type checking on apple_tv integration * move some instance variables to class variables * fix type of attr_value * fix tests for description_placeholders assertion * nits * Apply suggestions from code review * Update remote.py * Apply suggestions from code review * Update __init__.py * Update __init__.py * Update __init__.py * Update config_flow.py * Improve test coverage * Update test_config_flow.py * Update __init__.py --------- Co-authored-by: Joost Lekkerkerker <joostlek@outlook.com> Co-authored-by: Erik Montnemery <erik@montnemery.com>
This commit is contained in:
parent
d555f91702
commit
636c7ce350
@ -80,6 +80,7 @@ homeassistant.components.anthemav.*
|
||||
homeassistant.components.apache_kafka.*
|
||||
homeassistant.components.apcupsd.*
|
||||
homeassistant.components.api.*
|
||||
homeassistant.components.apple_tv.*
|
||||
homeassistant.components.apprise.*
|
||||
homeassistant.components.aprs.*
|
||||
homeassistant.components.aqualogic.*
|
||||
|
@ -1,8 +1,10 @@
|
||||
"""The Apple TV integration."""
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from random import randrange
|
||||
from typing import TYPE_CHECKING, cast
|
||||
from typing import Any, cast
|
||||
|
||||
from pyatv import connect, exceptions, scan
|
||||
from pyatv.conf import AppleTV
|
||||
@ -25,7 +27,7 @@ from homeassistant.const import (
|
||||
EVENT_HOMEASSISTANT_STOP,
|
||||
Platform,
|
||||
)
|
||||
from homeassistant.core import HomeAssistant, callback
|
||||
from homeassistant.core import Event, HomeAssistant, callback
|
||||
from homeassistant.exceptions import ConfigEntryAuthFailed, ConfigEntryNotReady
|
||||
from homeassistant.helpers import device_registry as dr
|
||||
from homeassistant.helpers.aiohttp_client import async_get_clientsession
|
||||
@ -89,7 +91,7 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
|
||||
|
||||
hass.data.setdefault(DOMAIN, {})[entry.unique_id] = manager
|
||||
|
||||
async def on_hass_stop(event):
|
||||
async def on_hass_stop(event: Event) -> None:
|
||||
"""Stop push updates when hass stops."""
|
||||
await manager.disconnect()
|
||||
|
||||
@ -120,33 +122,29 @@ class AppleTVEntity(Entity):
|
||||
_attr_should_poll = False
|
||||
_attr_has_entity_name = True
|
||||
_attr_name = None
|
||||
atv: AppleTVInterface | None = None
|
||||
|
||||
def __init__(
|
||||
self, name: str, identifier: str | None, manager: "AppleTVManager"
|
||||
) -> None:
|
||||
def __init__(self, name: str, identifier: str, manager: AppleTVManager) -> None:
|
||||
"""Initialize device."""
|
||||
self.atv: AppleTVInterface = None # type: ignore[assignment]
|
||||
self.manager = manager
|
||||
if TYPE_CHECKING:
|
||||
assert identifier is not None
|
||||
self._attr_unique_id = identifier
|
||||
self._attr_device_info = DeviceInfo(
|
||||
identifiers={(DOMAIN, identifier)},
|
||||
name=name,
|
||||
)
|
||||
|
||||
async def async_added_to_hass(self):
|
||||
async def async_added_to_hass(self) -> None:
|
||||
"""Handle when an entity is about to be added to Home Assistant."""
|
||||
|
||||
@callback
|
||||
def _async_connected(atv):
|
||||
def _async_connected(atv: AppleTVInterface) -> None:
|
||||
"""Handle that a connection was made to a device."""
|
||||
self.atv = atv
|
||||
self.async_device_connected(atv)
|
||||
self.async_write_ha_state()
|
||||
|
||||
@callback
|
||||
def _async_disconnected():
|
||||
def _async_disconnected() -> None:
|
||||
"""Handle that a connection to a device was lost."""
|
||||
self.async_device_disconnected()
|
||||
self.atv = None
|
||||
@ -169,10 +167,10 @@ class AppleTVEntity(Entity):
|
||||
)
|
||||
)
|
||||
|
||||
def async_device_connected(self, atv):
|
||||
def async_device_connected(self, atv: AppleTVInterface) -> None:
|
||||
"""Handle when connection is made to device."""
|
||||
|
||||
def async_device_disconnected(self):
|
||||
def async_device_disconnected(self) -> None:
|
||||
"""Handle when connection was lost to device."""
|
||||
|
||||
|
||||
@ -184,22 +182,23 @@ class AppleTVManager(DeviceListener):
|
||||
in case of problems.
|
||||
"""
|
||||
|
||||
atv: AppleTVInterface | None = None
|
||||
_connection_attempts = 0
|
||||
_connection_was_lost = False
|
||||
_task: asyncio.Task[None] | None = None
|
||||
|
||||
def __init__(self, hass: HomeAssistant, config_entry: ConfigEntry) -> None:
|
||||
"""Initialize power manager."""
|
||||
self.config_entry = config_entry
|
||||
self.hass = hass
|
||||
self.atv: AppleTVInterface | None = None
|
||||
self.is_on = not config_entry.options.get(CONF_START_OFF, False)
|
||||
self._connection_attempts = 0
|
||||
self._connection_was_lost = False
|
||||
self._task = None
|
||||
|
||||
async def init(self):
|
||||
async def init(self) -> None:
|
||||
"""Initialize power management."""
|
||||
if self.is_on:
|
||||
await self.connect()
|
||||
|
||||
def connection_lost(self, _):
|
||||
def connection_lost(self, exception: Exception) -> None:
|
||||
"""Device was unexpectedly disconnected.
|
||||
|
||||
This is a callback function from pyatv.interface.DeviceListener.
|
||||
@ -210,14 +209,14 @@ class AppleTVManager(DeviceListener):
|
||||
self._connection_was_lost = True
|
||||
self._handle_disconnect()
|
||||
|
||||
def connection_closed(self):
|
||||
def connection_closed(self) -> None:
|
||||
"""Device connection was (intentionally) closed.
|
||||
|
||||
This is a callback function from pyatv.interface.DeviceListener.
|
||||
"""
|
||||
self._handle_disconnect()
|
||||
|
||||
def _handle_disconnect(self):
|
||||
def _handle_disconnect(self) -> None:
|
||||
"""Handle that the device disconnected and restart connect loop."""
|
||||
if self.atv:
|
||||
self.atv.close()
|
||||
@ -225,12 +224,12 @@ class AppleTVManager(DeviceListener):
|
||||
self._dispatch_send(SIGNAL_DISCONNECTED)
|
||||
self._start_connect_loop()
|
||||
|
||||
async def connect(self):
|
||||
async def connect(self) -> None:
|
||||
"""Connect to device."""
|
||||
self.is_on = True
|
||||
self._start_connect_loop()
|
||||
|
||||
async def disconnect(self):
|
||||
async def disconnect(self) -> None:
|
||||
"""Disconnect from device."""
|
||||
_LOGGER.debug("Disconnecting from device")
|
||||
self.is_on = False
|
||||
@ -244,7 +243,7 @@ class AppleTVManager(DeviceListener):
|
||||
except Exception: # pylint: disable=broad-except
|
||||
_LOGGER.exception("An error occurred while disconnecting")
|
||||
|
||||
def _start_connect_loop(self):
|
||||
def _start_connect_loop(self) -> None:
|
||||
"""Start background connect loop to device."""
|
||||
if not self._task and self.atv is None and self.is_on:
|
||||
self._task = asyncio.create_task(self._connect_loop())
|
||||
@ -258,7 +257,7 @@ class AppleTVManager(DeviceListener):
|
||||
if conf := await self._scan():
|
||||
await self._connect(conf, raise_missing_credentials)
|
||||
|
||||
async def async_first_connect(self):
|
||||
async def async_first_connect(self) -> None:
|
||||
"""Connect to device for the first time."""
|
||||
connect_ok = False
|
||||
try:
|
||||
@ -286,7 +285,7 @@ class AppleTVManager(DeviceListener):
|
||||
_LOGGER.exception("Failed to connect")
|
||||
await self.disconnect()
|
||||
|
||||
async def _connect_loop(self):
|
||||
async def _connect_loop(self) -> None:
|
||||
"""Connect loop background task function."""
|
||||
_LOGGER.debug("Starting connect loop")
|
||||
|
||||
@ -295,7 +294,8 @@ class AppleTVManager(DeviceListener):
|
||||
while self.is_on and self.atv is None:
|
||||
await self.connect_once(raise_missing_credentials=False)
|
||||
if self.atv is not None:
|
||||
break
|
||||
# Calling self.connect_once may have set self.atv
|
||||
break # type: ignore[unreachable]
|
||||
self._connection_attempts += 1
|
||||
backoff = min(
|
||||
max(
|
||||
@ -392,7 +392,7 @@ class AppleTVManager(DeviceListener):
|
||||
self._connection_was_lost = False
|
||||
|
||||
@callback
|
||||
def _async_setup_device_registry(self):
|
||||
def _async_setup_device_registry(self) -> None:
|
||||
attrs = {
|
||||
ATTR_IDENTIFIERS: {(DOMAIN, self.config_entry.unique_id)},
|
||||
ATTR_MANUFACTURER: "Apple",
|
||||
@ -423,18 +423,18 @@ class AppleTVManager(DeviceListener):
|
||||
)
|
||||
|
||||
@property
|
||||
def is_connecting(self):
|
||||
def is_connecting(self) -> bool:
|
||||
"""Return true if connection is in progress."""
|
||||
return self._task is not None
|
||||
|
||||
def _address_updated(self, address):
|
||||
def _address_updated(self, address: str) -> None:
|
||||
"""Update cached address in config entry."""
|
||||
_LOGGER.debug("Changing address to %s", address)
|
||||
self.hass.config_entries.async_update_entry(
|
||||
self.config_entry, data={**self.config_entry.data, CONF_ADDRESS: address}
|
||||
)
|
||||
|
||||
def _dispatch_send(self, signal, *args):
|
||||
def _dispatch_send(self, signal: str, *args: Any) -> None:
|
||||
"""Dispatch a signal to all entities managed by this manager."""
|
||||
async_dispatcher_send(
|
||||
self.hass, f"{signal}_{self.config_entry.unique_id}", *args
|
||||
|
@ -3,7 +3,7 @@ from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from collections import deque
|
||||
from collections.abc import Mapping
|
||||
from collections.abc import Awaitable, Callable, Mapping
|
||||
from ipaddress import ip_address
|
||||
import logging
|
||||
from random import randrange
|
||||
@ -13,12 +13,13 @@ from pyatv import exceptions, pair, scan
|
||||
from pyatv.const import DeviceModel, PairingRequirement, Protocol
|
||||
from pyatv.convert import model_str, protocol_str
|
||||
from pyatv.helpers import get_unique_id
|
||||
from pyatv.interface import BaseConfig, PairingHandler
|
||||
import voluptuous as vol
|
||||
|
||||
from homeassistant import config_entries
|
||||
from homeassistant.components import zeroconf
|
||||
from homeassistant.const import CONF_ADDRESS, CONF_NAME, CONF_PIN
|
||||
from homeassistant.core import callback
|
||||
from homeassistant.core import HomeAssistant, callback
|
||||
from homeassistant.data_entry_flow import AbortFlow, FlowResult
|
||||
from homeassistant.exceptions import HomeAssistantError
|
||||
from homeassistant.helpers.aiohttp_client import async_get_clientsession
|
||||
@ -49,10 +50,12 @@ OPTIONS_FLOW = {
|
||||
}
|
||||
|
||||
|
||||
async def device_scan(hass, identifier, loop):
|
||||
async def device_scan(
|
||||
hass: HomeAssistant, identifier: str | None, loop: asyncio.AbstractEventLoop
|
||||
) -> tuple[BaseConfig | None, list[str] | None]:
|
||||
"""Scan for a specific device using identifier as filter."""
|
||||
|
||||
def _filter_device(dev):
|
||||
def _filter_device(dev: BaseConfig) -> bool:
|
||||
if identifier is None:
|
||||
return True
|
||||
if identifier == str(dev.address):
|
||||
@ -61,9 +64,12 @@ async def device_scan(hass, identifier, loop):
|
||||
return True
|
||||
return any(service.identifier == identifier for service in dev.services)
|
||||
|
||||
def _host_filter():
|
||||
def _host_filter() -> list[str] | None:
|
||||
if identifier is None:
|
||||
return None
|
||||
try:
|
||||
return [ip_address(identifier)]
|
||||
ip_address(identifier)
|
||||
return [identifier]
|
||||
except ValueError:
|
||||
return None
|
||||
|
||||
@ -84,6 +90,13 @@ class AppleTVConfigFlow(config_entries.ConfigFlow, domain=DOMAIN):
|
||||
|
||||
VERSION = 1
|
||||
|
||||
scan_filter: str | None = None
|
||||
atv: BaseConfig | None = None
|
||||
atv_identifiers: list[str] | None = None
|
||||
protocol: Protocol | None = None
|
||||
pairing: PairingHandler | None = None
|
||||
protocols_to_pair: deque[Protocol] | None = None
|
||||
|
||||
@staticmethod
|
||||
@callback
|
||||
def async_get_options_flow(
|
||||
@ -92,18 +105,12 @@ class AppleTVConfigFlow(config_entries.ConfigFlow, domain=DOMAIN):
|
||||
"""Get options flow for this handler."""
|
||||
return SchemaOptionsFlowHandler(config_entry, OPTIONS_FLOW)
|
||||
|
||||
def __init__(self):
|
||||
def __init__(self) -> None:
|
||||
"""Initialize a new AppleTVConfigFlow."""
|
||||
self.scan_filter = None
|
||||
self.atv = None
|
||||
self.atv_identifiers = None
|
||||
self.protocol = None
|
||||
self.pairing = None
|
||||
self.credentials = {} # Protocol -> credentials
|
||||
self.protocols_to_pair = deque()
|
||||
self.credentials: dict[int, str | None] = {} # Protocol -> credentials
|
||||
|
||||
@property
|
||||
def device_identifier(self):
|
||||
def device_identifier(self) -> str | None:
|
||||
"""Return a identifier for the config entry.
|
||||
|
||||
A device has multiple unique identifiers, but Home Assistant only supports one
|
||||
@ -118,6 +125,7 @@ class AppleTVConfigFlow(config_entries.ConfigFlow, domain=DOMAIN):
|
||||
existing config entry. If that's the case, the unique_id from that entry is
|
||||
re-used, otherwise the newly discovered identifier is used instead.
|
||||
"""
|
||||
assert self.atv
|
||||
all_identifiers = set(self.atv.all_identifiers)
|
||||
if unique_id := self._entry_unique_id_from_identifers(all_identifiers):
|
||||
return unique_id
|
||||
@ -143,7 +151,9 @@ class AppleTVConfigFlow(config_entries.ConfigFlow, domain=DOMAIN):
|
||||
self.context["identifier"] = self.unique_id
|
||||
return await self.async_step_reconfigure()
|
||||
|
||||
async def async_step_reconfigure(self, user_input=None):
|
||||
async def async_step_reconfigure(
|
||||
self, user_input: dict[str, str] | None = None
|
||||
) -> FlowResult:
|
||||
"""Inform user that reconfiguration is about to start."""
|
||||
if user_input is not None:
|
||||
return await self.async_find_device_wrapper(
|
||||
@ -152,7 +162,9 @@ class AppleTVConfigFlow(config_entries.ConfigFlow, domain=DOMAIN):
|
||||
|
||||
return self.async_show_form(step_id="reconfigure")
|
||||
|
||||
async def async_step_user(self, user_input=None):
|
||||
async def async_step_user(
|
||||
self, user_input: dict[str, str] | None = None
|
||||
) -> FlowResult:
|
||||
"""Handle the initial step."""
|
||||
errors = {}
|
||||
if user_input is not None:
|
||||
@ -170,6 +182,7 @@ class AppleTVConfigFlow(config_entries.ConfigFlow, domain=DOMAIN):
|
||||
await self.async_set_unique_id(
|
||||
self.device_identifier, raise_on_progress=False
|
||||
)
|
||||
assert self.atv
|
||||
self.context["all_identifiers"] = self.atv.all_identifiers
|
||||
return await self.async_step_confirm()
|
||||
|
||||
@ -275,8 +288,11 @@ class AppleTVConfigFlow(config_entries.ConfigFlow, domain=DOMAIN):
|
||||
context["all_identifiers"].append(unique_id)
|
||||
raise AbortFlow("already_in_progress")
|
||||
|
||||
async def async_found_zeroconf_device(self, user_input=None):
|
||||
async def async_found_zeroconf_device(
|
||||
self, user_input: dict[str, str] | None = None
|
||||
) -> FlowResult:
|
||||
"""Handle device found after Zeroconf discovery."""
|
||||
assert self.atv
|
||||
self.context["all_identifiers"] = self.atv.all_identifiers
|
||||
# Also abort if an integration with this identifier already exists
|
||||
await self.async_set_unique_id(self.device_identifier)
|
||||
@ -288,7 +304,11 @@ class AppleTVConfigFlow(config_entries.ConfigFlow, domain=DOMAIN):
|
||||
self.context["identifier"] = self.unique_id
|
||||
return await self.async_step_confirm()
|
||||
|
||||
async def async_find_device_wrapper(self, next_func, allow_exist=False):
|
||||
async def async_find_device_wrapper(
|
||||
self,
|
||||
next_func: Callable[[], Awaitable[FlowResult]],
|
||||
allow_exist: bool = False,
|
||||
) -> FlowResult:
|
||||
"""Find a specific device and call another function when done.
|
||||
|
||||
This function will do error handling and bail out when an error
|
||||
@ -306,7 +326,7 @@ class AppleTVConfigFlow(config_entries.ConfigFlow, domain=DOMAIN):
|
||||
|
||||
return await next_func()
|
||||
|
||||
async def async_find_device(self, allow_exist=False):
|
||||
async def async_find_device(self, allow_exist: bool = False) -> None:
|
||||
"""Scan for the selected device to discover services."""
|
||||
self.atv, self.atv_identifiers = await device_scan(
|
||||
self.hass, self.scan_filter, self.hass.loop
|
||||
@ -357,8 +377,11 @@ class AppleTVConfigFlow(config_entries.ConfigFlow, domain=DOMAIN):
|
||||
if not allow_exist:
|
||||
raise DeviceAlreadyConfigured()
|
||||
|
||||
async def async_step_confirm(self, user_input=None):
|
||||
async def async_step_confirm(
|
||||
self, user_input: dict[str, str] | None = None
|
||||
) -> FlowResult:
|
||||
"""Handle user-confirmation of discovered node."""
|
||||
assert self.atv
|
||||
if user_input is not None:
|
||||
expected_identifier_count = len(self.context["all_identifiers"])
|
||||
# If number of services found during device scan mismatch number of
|
||||
@ -384,7 +407,7 @@ class AppleTVConfigFlow(config_entries.ConfigFlow, domain=DOMAIN):
|
||||
},
|
||||
)
|
||||
|
||||
async def async_pair_next_protocol(self):
|
||||
async def async_pair_next_protocol(self) -> FlowResult:
|
||||
"""Start pairing process for the next available protocol."""
|
||||
await self._async_cleanup()
|
||||
|
||||
@ -393,8 +416,16 @@ class AppleTVConfigFlow(config_entries.ConfigFlow, domain=DOMAIN):
|
||||
return await self._async_get_entry()
|
||||
|
||||
self.protocol = self.protocols_to_pair.popleft()
|
||||
assert self.atv
|
||||
service = self.atv.get_service(self.protocol)
|
||||
|
||||
if service is None:
|
||||
_LOGGER.debug(
|
||||
"%s does not support pairing (cannot find a corresponding service)",
|
||||
self.protocol,
|
||||
)
|
||||
return await self.async_pair_next_protocol()
|
||||
|
||||
# Service requires a password
|
||||
if service.requires_password:
|
||||
return await self.async_step_password()
|
||||
@ -413,7 +444,7 @@ class AppleTVConfigFlow(config_entries.ConfigFlow, domain=DOMAIN):
|
||||
_LOGGER.debug("%s requires pairing", self.protocol)
|
||||
|
||||
# Protocol specific arguments
|
||||
pair_args = {}
|
||||
pair_args: dict[str, Any] = {}
|
||||
if self.protocol in {Protocol.AirPlay, Protocol.Companion, Protocol.DMAP}:
|
||||
pair_args["name"] = "Home Assistant"
|
||||
if self.protocol == Protocol.DMAP:
|
||||
@ -448,8 +479,11 @@ class AppleTVConfigFlow(config_entries.ConfigFlow, domain=DOMAIN):
|
||||
|
||||
return await self.async_step_pair_no_pin()
|
||||
|
||||
async def async_step_protocol_disabled(self, user_input=None):
|
||||
async def async_step_protocol_disabled(
|
||||
self, user_input: dict[str, str] | None = None
|
||||
) -> FlowResult:
|
||||
"""Inform user that a protocol is disabled and cannot be paired."""
|
||||
assert self.protocol
|
||||
if user_input is not None:
|
||||
return await self.async_pair_next_protocol()
|
||||
return self.async_show_form(
|
||||
@ -457,9 +491,13 @@ class AppleTVConfigFlow(config_entries.ConfigFlow, domain=DOMAIN):
|
||||
description_placeholders={"protocol": protocol_str(self.protocol)},
|
||||
)
|
||||
|
||||
async def async_step_pair_with_pin(self, user_input=None):
|
||||
async def async_step_pair_with_pin(
|
||||
self, user_input: dict[str, str] | None = None
|
||||
) -> FlowResult:
|
||||
"""Handle pairing step where a PIN is required from the user."""
|
||||
errors = {}
|
||||
assert self.pairing
|
||||
assert self.protocol
|
||||
if user_input is not None:
|
||||
try:
|
||||
self.pairing.pin(user_input[CONF_PIN])
|
||||
@ -480,8 +518,12 @@ class AppleTVConfigFlow(config_entries.ConfigFlow, domain=DOMAIN):
|
||||
description_placeholders={"protocol": protocol_str(self.protocol)},
|
||||
)
|
||||
|
||||
async def async_step_pair_no_pin(self, user_input=None):
|
||||
async def async_step_pair_no_pin(
|
||||
self, user_input: dict[str, str] | None = None
|
||||
) -> FlowResult:
|
||||
"""Handle step where user has to enter a PIN on the device."""
|
||||
assert self.pairing
|
||||
assert self.protocol
|
||||
if user_input is not None:
|
||||
await self.pairing.finish()
|
||||
if self.pairing.has_paired:
|
||||
@ -497,12 +539,15 @@ class AppleTVConfigFlow(config_entries.ConfigFlow, domain=DOMAIN):
|
||||
step_id="pair_no_pin",
|
||||
description_placeholders={
|
||||
"protocol": protocol_str(self.protocol),
|
||||
"pin": pin,
|
||||
"pin": str(pin),
|
||||
},
|
||||
)
|
||||
|
||||
async def async_step_service_problem(self, user_input=None):
|
||||
async def async_step_service_problem(
|
||||
self, user_input: dict[str, str] | None = None
|
||||
) -> FlowResult:
|
||||
"""Inform user that a service will not be added."""
|
||||
assert self.protocol
|
||||
if user_input is not None:
|
||||
return await self.async_pair_next_protocol()
|
||||
|
||||
@ -511,8 +556,11 @@ class AppleTVConfigFlow(config_entries.ConfigFlow, domain=DOMAIN):
|
||||
description_placeholders={"protocol": protocol_str(self.protocol)},
|
||||
)
|
||||
|
||||
async def async_step_password(self, user_input=None):
|
||||
async def async_step_password(
|
||||
self, user_input: dict[str, str] | None = None
|
||||
) -> FlowResult:
|
||||
"""Inform user that password is not supported."""
|
||||
assert self.protocol
|
||||
if user_input is not None:
|
||||
return await self.async_pair_next_protocol()
|
||||
|
||||
@ -521,18 +569,20 @@ class AppleTVConfigFlow(config_entries.ConfigFlow, domain=DOMAIN):
|
||||
description_placeholders={"protocol": protocol_str(self.protocol)},
|
||||
)
|
||||
|
||||
async def _async_cleanup(self):
|
||||
async def _async_cleanup(self) -> None:
|
||||
"""Clean up allocated resources."""
|
||||
if self.pairing is not None:
|
||||
await self.pairing.close()
|
||||
self.pairing = None
|
||||
|
||||
async def _async_get_entry(self):
|
||||
async def _async_get_entry(self) -> FlowResult:
|
||||
"""Return config entry or update existing config entry."""
|
||||
# Abort if no protocols were paired
|
||||
if not self.credentials:
|
||||
return self.async_abort(reason="setup_failed")
|
||||
|
||||
assert self.atv
|
||||
|
||||
data = {
|
||||
CONF_NAME: self.atv.name,
|
||||
CONF_CREDENTIALS: self.credentials,
|
||||
|
@ -16,7 +16,15 @@ from pyatv.const import (
|
||||
ShuffleState,
|
||||
)
|
||||
from pyatv.helpers import is_streamable
|
||||
from pyatv.interface import AppleTV, Playing
|
||||
from pyatv.interface import (
|
||||
AppleTV,
|
||||
AudioListener,
|
||||
OutputDevice,
|
||||
Playing,
|
||||
PowerListener,
|
||||
PushListener,
|
||||
PushUpdater,
|
||||
)
|
||||
|
||||
from homeassistant.components import media_source
|
||||
from homeassistant.components.media_player import (
|
||||
@ -101,7 +109,9 @@ async def async_setup_entry(
|
||||
async_add_entities([AppleTvMediaPlayer(name, config_entry.unique_id, manager)])
|
||||
|
||||
|
||||
class AppleTvMediaPlayer(AppleTVEntity, MediaPlayerEntity):
|
||||
class AppleTvMediaPlayer(
|
||||
AppleTVEntity, MediaPlayerEntity, PowerListener, AudioListener, PushListener
|
||||
):
|
||||
"""Representation of an Apple TV media player."""
|
||||
|
||||
_attr_supported_features = SUPPORT_APPLE_TV
|
||||
@ -116,9 +126,9 @@ class AppleTvMediaPlayer(AppleTVEntity, MediaPlayerEntity):
|
||||
def async_device_connected(self, atv: AppleTV) -> None:
|
||||
"""Handle when connection is made to device."""
|
||||
# NB: Do not use _is_feature_available here as it only works when playing
|
||||
if self.atv.features.in_state(FeatureState.Available, FeatureName.PushUpdates):
|
||||
self.atv.push_updater.listener = self
|
||||
self.atv.push_updater.start()
|
||||
if atv.features.in_state(FeatureState.Available, FeatureName.PushUpdates):
|
||||
atv.push_updater.listener = self
|
||||
atv.push_updater.start()
|
||||
|
||||
self._attr_supported_features = SUPPORT_BASE
|
||||
|
||||
@ -126,7 +136,7 @@ class AppleTvMediaPlayer(AppleTVEntity, MediaPlayerEntity):
|
||||
# "Unsupported" are considered here as the state of such a feature can never
|
||||
# change after a connection has been established, i.e. an unsupported feature
|
||||
# can never change to be supported.
|
||||
all_features = self.atv.features.all_features()
|
||||
all_features = atv.features.all_features()
|
||||
for feature_name, support_flag in SUPPORT_FEATURE_MAPPING.items():
|
||||
feature_info = all_features.get(feature_name)
|
||||
if feature_info and feature_info.state != FeatureState.Unsupported:
|
||||
@ -136,16 +146,18 @@ class AppleTvMediaPlayer(AppleTVEntity, MediaPlayerEntity):
|
||||
# metadata update arrives (sometime very soon after this callback returns)
|
||||
|
||||
# Listen to power updates
|
||||
self.atv.power.listener = self
|
||||
atv.power.listener = self
|
||||
|
||||
# Listen to volume updates
|
||||
self.atv.audio.listener = self
|
||||
atv.audio.listener = self
|
||||
|
||||
if self.atv.features.in_state(FeatureState.Available, FeatureName.AppList):
|
||||
if atv.features.in_state(FeatureState.Available, FeatureName.AppList):
|
||||
self.hass.create_task(self._update_app_list())
|
||||
|
||||
async def _update_app_list(self) -> None:
|
||||
_LOGGER.debug("Updating app list")
|
||||
if not self.atv:
|
||||
return
|
||||
try:
|
||||
apps = await self.atv.apps.app_list()
|
||||
except exceptions.NotSupportedError:
|
||||
@ -189,33 +201,56 @@ class AppleTvMediaPlayer(AppleTVEntity, MediaPlayerEntity):
|
||||
return None
|
||||
|
||||
@callback
|
||||
def playstatus_update(self, _, playing: Playing) -> None:
|
||||
"""Print what is currently playing when it changes."""
|
||||
self._playing = playing
|
||||
def playstatus_update(self, updater: PushUpdater, playstatus: Playing) -> None:
|
||||
"""Print what is currently playing when it changes.
|
||||
|
||||
This is a callback function from pyatv.interface.PushListener.
|
||||
"""
|
||||
self._playing = playstatus
|
||||
self.async_write_ha_state()
|
||||
|
||||
@callback
|
||||
def playstatus_error(self, _, exception: Exception) -> None:
|
||||
"""Inform about an error and restart push updates."""
|
||||
def playstatus_error(self, updater: PushUpdater, exception: Exception) -> None:
|
||||
"""Inform about an error and restart push updates.
|
||||
|
||||
This is a callback function from pyatv.interface.PushListener.
|
||||
"""
|
||||
_LOGGER.warning("A %s error occurred: %s", exception.__class__, exception)
|
||||
self._playing = None
|
||||
self.async_write_ha_state()
|
||||
|
||||
@callback
|
||||
def powerstate_update(self, old_state: PowerState, new_state: PowerState) -> None:
|
||||
"""Update power state when it changes."""
|
||||
"""Update power state when it changes.
|
||||
|
||||
This is a callback function from pyatv.interface.PowerListener.
|
||||
"""
|
||||
self.async_write_ha_state()
|
||||
|
||||
@callback
|
||||
def volume_update(self, old_level: float, new_level: float) -> None:
|
||||
"""Update volume when it changes."""
|
||||
"""Update volume when it changes.
|
||||
|
||||
This is a callback function from pyatv.interface.AudioListener.
|
||||
"""
|
||||
self.async_write_ha_state()
|
||||
|
||||
@callback
|
||||
def outputdevices_update(
|
||||
self, old_devices: list[OutputDevice], new_devices: list[OutputDevice]
|
||||
) -> None:
|
||||
"""Output devices were updated.
|
||||
|
||||
This is a callback function from pyatv.interface.AudioListener.
|
||||
"""
|
||||
|
||||
@property
|
||||
def app_id(self) -> str | None:
|
||||
"""ID of the current running app."""
|
||||
if self._is_feature_available(FeatureName.App) and (
|
||||
app := self.atv.metadata.app
|
||||
if (
|
||||
self.atv
|
||||
and self._is_feature_available(FeatureName.App)
|
||||
and (app := self.atv.metadata.app) is not None
|
||||
):
|
||||
return app.identifier
|
||||
return None
|
||||
@ -223,8 +258,10 @@ class AppleTvMediaPlayer(AppleTVEntity, MediaPlayerEntity):
|
||||
@property
|
||||
def app_name(self) -> str | None:
|
||||
"""Name of the current running app."""
|
||||
if self._is_feature_available(FeatureName.App) and (
|
||||
app := self.atv.metadata.app
|
||||
if (
|
||||
self.atv
|
||||
and self._is_feature_available(FeatureName.App)
|
||||
and (app := self.atv.metadata.app) is not None
|
||||
):
|
||||
return app.name
|
||||
return None
|
||||
@ -255,7 +292,7 @@ class AppleTvMediaPlayer(AppleTVEntity, MediaPlayerEntity):
|
||||
@property
|
||||
def volume_level(self) -> float | None:
|
||||
"""Volume level of the media player (0..1)."""
|
||||
if self._is_feature_available(FeatureName.Volume):
|
||||
if self.atv and self._is_feature_available(FeatureName.Volume):
|
||||
return self.atv.audio.volume / 100.0 # from percent
|
||||
return None
|
||||
|
||||
@ -286,6 +323,8 @@ class AppleTvMediaPlayer(AppleTVEntity, MediaPlayerEntity):
|
||||
"""Send the play_media command to the media player."""
|
||||
# If input (file) has a file format supported by pyatv, then stream it with
|
||||
# RAOP. Otherwise try to play it with regular AirPlay.
|
||||
if not self.atv:
|
||||
return
|
||||
if media_type in {MediaType.APP, MediaType.URL}:
|
||||
await self.atv.apps.launch_app(media_id)
|
||||
return
|
||||
@ -313,7 +352,8 @@ class AppleTvMediaPlayer(AppleTVEntity, MediaPlayerEntity):
|
||||
"""Hash value for media image."""
|
||||
state = self.state
|
||||
if (
|
||||
self._playing
|
||||
self.atv
|
||||
and self._playing
|
||||
and self._is_feature_available(FeatureName.Artwork)
|
||||
and state not in {None, MediaPlayerState.OFF, MediaPlayerState.IDLE}
|
||||
):
|
||||
@ -323,7 +363,11 @@ class AppleTvMediaPlayer(AppleTVEntity, MediaPlayerEntity):
|
||||
async def async_get_media_image(self) -> tuple[bytes | None, str | None]:
|
||||
"""Fetch media image of current playing image."""
|
||||
state = self.state
|
||||
if self._playing and state not in {MediaPlayerState.OFF, MediaPlayerState.IDLE}:
|
||||
if (
|
||||
self.atv
|
||||
and self._playing
|
||||
and state not in {MediaPlayerState.OFF, MediaPlayerState.IDLE}
|
||||
):
|
||||
artwork = await self.atv.metadata.artwork()
|
||||
if artwork:
|
||||
return artwork.bytes, artwork.mimetype
|
||||
@ -439,20 +483,24 @@ class AppleTvMediaPlayer(AppleTVEntity, MediaPlayerEntity):
|
||||
|
||||
async def async_turn_on(self) -> None:
|
||||
"""Turn the media player on."""
|
||||
if self._is_feature_available(FeatureName.TurnOn):
|
||||
if self.atv and self._is_feature_available(FeatureName.TurnOn):
|
||||
await self.atv.power.turn_on()
|
||||
|
||||
async def async_turn_off(self) -> None:
|
||||
"""Turn the media player off."""
|
||||
if (self._is_feature_available(FeatureName.TurnOff)) and (
|
||||
not self._is_feature_available(FeatureName.PowerState)
|
||||
or self.atv.power.power_state == PowerState.On
|
||||
if (
|
||||
self.atv
|
||||
and (self._is_feature_available(FeatureName.TurnOff))
|
||||
and (
|
||||
not self._is_feature_available(FeatureName.PowerState)
|
||||
or self.atv.power.power_state == PowerState.On
|
||||
)
|
||||
):
|
||||
await self.atv.power.turn_off()
|
||||
|
||||
async def async_media_play_pause(self) -> None:
|
||||
"""Pause media on media player."""
|
||||
if self._playing:
|
||||
if self.atv and self._playing:
|
||||
await self.atv.remote_control.play_pause()
|
||||
|
||||
async def async_media_play(self) -> None:
|
||||
@ -519,5 +567,6 @@ class AppleTvMediaPlayer(AppleTVEntity, MediaPlayerEntity):
|
||||
|
||||
async def async_select_source(self, source: str) -> None:
|
||||
"""Select input source."""
|
||||
if app_id := self._app_list.get(source):
|
||||
await self.atv.apps.launch_app(app_id)
|
||||
if self.atv:
|
||||
if app_id := self._app_list.get(source):
|
||||
await self.atv.apps.launch_app(app_id)
|
||||
|
@ -15,7 +15,7 @@ from homeassistant.const import CONF_NAME
|
||||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.helpers.entity_platform import AddEntitiesCallback
|
||||
|
||||
from . import AppleTVEntity
|
||||
from . import AppleTVEntity, AppleTVManager
|
||||
from .const import DOMAIN
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
@ -38,8 +38,10 @@ async def async_setup_entry(
|
||||
async_add_entities: AddEntitiesCallback,
|
||||
) -> None:
|
||||
"""Load Apple TV remote based on a config entry."""
|
||||
name = config_entry.data[CONF_NAME]
|
||||
manager = hass.data[DOMAIN][config_entry.unique_id]
|
||||
name: str = config_entry.data[CONF_NAME]
|
||||
# apple_tv config entries always have a unique id
|
||||
assert config_entry.unique_id is not None
|
||||
manager: AppleTVManager = hass.data[DOMAIN][config_entry.unique_id]
|
||||
async_add_entities([AppleTVRemote(name, config_entry.unique_id, manager)])
|
||||
|
||||
|
||||
@ -47,7 +49,7 @@ class AppleTVRemote(AppleTVEntity, RemoteEntity):
|
||||
"""Device that sends commands to an Apple TV."""
|
||||
|
||||
@property
|
||||
def is_on(self):
|
||||
def is_on(self) -> bool:
|
||||
"""Return true if device is on."""
|
||||
return self.atv is not None
|
||||
|
||||
@ -64,13 +66,13 @@ class AppleTVRemote(AppleTVEntity, RemoteEntity):
|
||||
num_repeats = kwargs[ATTR_NUM_REPEATS]
|
||||
delay = kwargs.get(ATTR_DELAY_SECS, DEFAULT_DELAY_SECS)
|
||||
|
||||
if not self.is_on:
|
||||
if not self.atv:
|
||||
_LOGGER.error("Unable to send commands, not connected to %s", self.name)
|
||||
return
|
||||
|
||||
for _ in range(num_repeats):
|
||||
for single_command in command:
|
||||
attr_value = None
|
||||
attr_value: Any = None
|
||||
if attributes := COMMAND_TO_ATTRIBUTE.get(single_command):
|
||||
attr_value = self.atv
|
||||
for attr_name in attributes:
|
||||
@ -81,5 +83,5 @@ class AppleTVRemote(AppleTVEntity, RemoteEntity):
|
||||
raise ValueError("Command not found. Exiting sequence")
|
||||
|
||||
_LOGGER.info("Sending command %s", single_command)
|
||||
await attr_value() # type: ignore[operator]
|
||||
await attr_value()
|
||||
await asyncio.sleep(delay)
|
||||
|
10
mypy.ini
10
mypy.ini
@ -560,6 +560,16 @@ disallow_untyped_defs = true
|
||||
warn_return_any = true
|
||||
warn_unreachable = true
|
||||
|
||||
[mypy-homeassistant.components.apple_tv.*]
|
||||
check_untyped_defs = true
|
||||
disallow_incomplete_defs = true
|
||||
disallow_subclassing_any = true
|
||||
disallow_untyped_calls = true
|
||||
disallow_untyped_decorators = true
|
||||
disallow_untyped_defs = true
|
||||
warn_return_any = true
|
||||
warn_unreachable = true
|
||||
|
||||
[mypy-homeassistant.components.apprise.*]
|
||||
check_untyped_defs = true
|
||||
disallow_incomplete_defs = true
|
||||
|
@ -1,6 +1,6 @@
|
||||
"""Test config flow."""
|
||||
from ipaddress import IPv4Address, ip_address
|
||||
from unittest.mock import ANY, patch
|
||||
from unittest.mock import ANY, Mock, patch
|
||||
|
||||
from pyatv import exceptions
|
||||
from pyatv.const import PairingRequirement, Protocol
|
||||
@ -125,7 +125,7 @@ async def test_user_adds_full_device(hass: HomeAssistant, full_device, pairing)
|
||||
result["flow_id"], {"pin": 1111}
|
||||
)
|
||||
assert result4["type"] == data_entry_flow.FlowResultType.FORM
|
||||
assert result4["description_placeholders"] == {"protocol": "DMAP", "pin": 1111}
|
||||
assert result4["description_placeholders"] == {"protocol": "DMAP", "pin": "1111"}
|
||||
|
||||
result5 = await hass.config_entries.flow.async_configure(result["flow_id"], {})
|
||||
assert result5["type"] == data_entry_flow.FlowResultType.FORM
|
||||
@ -167,7 +167,7 @@ async def test_user_adds_dmap_device(
|
||||
|
||||
result3 = await hass.config_entries.flow.async_configure(result["flow_id"], {})
|
||||
assert result3["type"] == data_entry_flow.FlowResultType.FORM
|
||||
assert result3["description_placeholders"] == {"pin": 1111, "protocol": "DMAP"}
|
||||
assert result3["description_placeholders"] == {"pin": "1111", "protocol": "DMAP"}
|
||||
|
||||
result6 = await hass.config_entries.flow.async_configure(
|
||||
result["flow_id"], {"pin": 1234}
|
||||
@ -646,7 +646,7 @@ async def test_zeroconf_add_dmap_device(
|
||||
{},
|
||||
)
|
||||
assert result2["type"] == data_entry_flow.FlowResultType.FORM
|
||||
assert result2["description_placeholders"] == {"protocol": "DMAP", "pin": 1111}
|
||||
assert result2["description_placeholders"] == {"protocol": "DMAP", "pin": "1111"}
|
||||
|
||||
result3 = await hass.config_entries.flow.async_configure(result["flow_id"], {})
|
||||
assert result3["type"] == "create_entry"
|
||||
@ -1130,6 +1130,43 @@ async def test_zeroconf_pair_additionally_found_protocols(
|
||||
assert result5["type"] == data_entry_flow.FlowResultType.CREATE_ENTRY
|
||||
|
||||
|
||||
async def test_zeroconf_mismatch(
|
||||
hass: HomeAssistant, mock_scan, pairing, mock_zeroconf: None
|
||||
) -> None:
|
||||
"""Test the technically possible case where a protocol has no service.
|
||||
|
||||
This could happen in case of mDNS issues.
|
||||
"""
|
||||
mock_scan.result = [
|
||||
create_conf(IPv4Address("127.0.0.1"), "Device", airplay_service())
|
||||
]
|
||||
mock_scan.result[0].get_service = Mock(return_value=None)
|
||||
|
||||
# Find device with AirPlay service and set up flow for it
|
||||
result = await hass.config_entries.flow.async_init(
|
||||
DOMAIN,
|
||||
context={"source": config_entries.SOURCE_ZEROCONF},
|
||||
data=zeroconf.ZeroconfServiceInfo(
|
||||
ip_address=ip_address("127.0.0.1"),
|
||||
ip_addresses=[ip_address("127.0.0.1")],
|
||||
hostname="mock_hostname",
|
||||
port=None,
|
||||
type="_airplay._tcp.local.",
|
||||
name="Kitchen",
|
||||
properties={"deviceid": "airplayid"},
|
||||
),
|
||||
)
|
||||
assert result["type"] == data_entry_flow.FlowResultType.FORM
|
||||
await hass.async_block_till_done()
|
||||
|
||||
result = await hass.config_entries.flow.async_configure(
|
||||
result["flow_id"],
|
||||
{},
|
||||
)
|
||||
assert result["type"] == data_entry_flow.FlowResultType.ABORT
|
||||
assert result["reason"] == "setup_failed"
|
||||
|
||||
|
||||
# Re-configuration
|
||||
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user