Code styling tweaks to core helpers (#85441)

Franck Nijhof 2023-01-09 00:44:09 +01:00 committed by GitHub
parent cf5fca0464
commit 06a35fb7db
22 changed files with 192 additions and 100 deletions

View File

@@ -124,9 +124,14 @@ def _async_create_clientsession(
     # If a package requires a different user agent, override it by passing a headers
     # dictionary to the request method.
     # pylint: disable=protected-access
-    clientsession._default_headers = MappingProxyType({USER_AGENT: SERVER_SOFTWARE})  # type: ignore[assignment]
+    clientsession._default_headers = MappingProxyType(  # type: ignore[assignment]
+        {USER_AGENT: SERVER_SOFTWARE},
+    )
-    clientsession.close = warn_use(clientsession.close, WARN_CLOSE_MSG)  # type: ignore[assignment]
+    clientsession.close = warn_use(  # type: ignore[assignment]
+        clientsession.close,
+        WARN_CLOSE_MSG,
+    )

     if auto_cleanup_method:
         auto_cleanup_method(hass, clientsession)

View File

@@ -611,8 +611,8 @@ def sun(
     # Special case: before sunrise OR after sunset
     # This will handle the very rare case in the polar region when the sun rises/sets
     # but does not set/rise.
-    # However this entire condition does not handle those full days of darkness or light,
-    # the following should be used instead:
+    # However this entire condition does not handle those full days of darkness
+    # or light, the following should be used instead:
     #
     # condition:
     #   condition: state

View File

@@ -778,12 +778,12 @@ def _deprecated_or_removed(
     raise_if_present: bool,
     option_removed: bool,
 ) -> Callable[[dict], dict]:
-    """
-    Log key as deprecated and provide a replacement (if exists) or fail.
+    """Log key as deprecated and provide a replacement (if exists) or fail.

     Expected behavior:
     - Outputs or throws the appropriate deprecation warning if key is detected
-    - Outputs or throws the appropriate error if key is detected and removed from support
+    - Outputs or throws the appropriate error if key is detected
+      and removed from support
     - Processes schema moving the value from key to replacement_key
     - Processes schema changing nothing if only replacement_key provided
     - No warning if only replacement_key provided
@@ -809,7 +809,10 @@ def _deprecated_or_removed(
         """Check if key is in config and log warning or error."""
         if key in config:
             try:
-                near = f"near {config.__config_file__}:{config.__line__} "  # type: ignore[attr-defined]
+                near = (
+                    f"near {config.__config_file__}"  # type: ignore[attr-defined]
+                    f":{config.__line__} "
+                )
             except AttributeError:
                 near = ""
             arguments: tuple[str, ...]
@@ -851,11 +854,11 @@ def deprecated(
     default: Any | None = None,
     raise_if_present: bool | None = False,
 ) -> Callable[[dict], dict]:
-    """
-    Log key as deprecated and provide a replacement (if exists).
+    """Log key as deprecated and provide a replacement (if exists).

     Expected behavior:
-    - Outputs the appropriate deprecation warning if key is detected or raises an exception
+    - Outputs the appropriate deprecation warning if key is detected
+      or raises an exception
     - Processes schema moving the value from key to replacement_key
     - Processes schema changing nothing if only replacement_key provided
     - No warning if only replacement_key provided
@@ -876,11 +879,11 @@ def removed(
     default: Any | None = None,
     raise_if_present: bool | None = True,
 ) -> Callable[[dict], dict]:
-    """
-    Log key as deprecated and fail the config validation.
+    """Log key as deprecated and fail the config validation.

     Expected behavior:
-    - Outputs the appropriate error if key is detected and removed from support or raises an exception
+    - Outputs the appropriate error if key is detected and removed from
+      support or raises an exception.
     """
     return _deprecated_or_removed(
         key,
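For context, these validators are meant to be composed into a voluptuous schema, where they rewrite the config dict before the real schema runs. A minimal usage sketch with hypothetical option names (not part of this change):

    import voluptuous as vol
    import homeassistant.helpers.config_validation as cv

    CONFIG_SCHEMA = vol.All(
        # Warn when "old_option" is present and move its value to "new_option"
        # before the schema below validates the result.
        cv.deprecated("old_option", replacement_key="new_option"),
        vol.Schema({vol.Optional("new_option", default=True): bool}),
    )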

View File

@@ -115,7 +115,7 @@ def deprecated_class(
 def deprecated_function(
     replacement: str,
 ) -> Callable[[Callable[_P, _R]], Callable[_P, _R]]:
-    """Mark function as deprecated and provide a replacement function to be used instead."""
+    """Mark function as deprecated and provide a replacement to be used instead."""

     def deprecated_decorator(func: Callable[_P, _R]) -> Callable[_P, _R]:
         """Decorate function as deprecated."""

View File

@@ -161,7 +161,9 @@ class DeviceRegistryStore(storage.Store[dict[str, list[dict[str, Any]]]]):
                 device.setdefault("configuration_url", None)
                 device.setdefault("disabled_by", None)
                 try:
-                    device["entry_type"] = DeviceEntryType(device.get("entry_type"))  # type: ignore[arg-type]
+                    device["entry_type"] = DeviceEntryType(
+                        device.get("entry_type"),  # type: ignore[arg-type]
+                    )
                 except ValueError:
                     device["entry_type"] = None
                 device.setdefault("name_by_user", None)
@@ -550,7 +552,10 @@ class DeviceRegistry:
                     config_entries=set(device["config_entries"]),
                     configuration_url=device["configuration_url"],
                     # type ignores (if tuple arg was cast): likely https://github.com/python/mypy/issues/8625
-                    connections={tuple(conn) for conn in device["connections"]},  # type: ignore[misc]
+                    connections={
+                        tuple(conn)  # type: ignore[misc]
+                        for conn in device["connections"]
+                    },
                     disabled_by=DeviceEntryDisabler(device["disabled_by"])
                     if device["disabled_by"]
                     else None,
@@ -559,7 +564,10 @@ class DeviceRegistry:
                     else None,
                     hw_version=device["hw_version"],
                     id=device["id"],
-                    identifiers={tuple(iden) for iden in device["identifiers"]},  # type: ignore[misc]
+                    identifiers={
+                        tuple(iden)  # type: ignore[misc]
+                        for iden in device["identifiers"]
+                    },
                     manufacturer=device["manufacturer"],
                     model=device["model"],
                     name_by_user=device["name_by_user"],
@@ -572,8 +580,14 @@ class DeviceRegistry:
                 deleted_devices[device["id"]] = DeletedDeviceEntry(
                     config_entries=set(device["config_entries"]),
                     # type ignores (if tuple arg was cast): likely https://github.com/python/mypy/issues/8625
-                    connections={tuple(conn) for conn in device["connections"]},  # type: ignore[misc]
-                    identifiers={tuple(iden) for iden in device["identifiers"]},  # type: ignore[misc]
+                    connections={
+                        tuple(conn)  # type: ignore[misc]
+                        for conn in device["connections"]
+                    },
+                    identifiers={
+                        tuple(iden)  # type: ignore[misc]
+                        for iden in device["identifiers"]
+                    },
                     id=device["id"],
                     orphaned_timestamp=device["orphaned_timestamp"],
                 )

View File

@@ -185,10 +185,11 @@ class EntityCategory(StrEnum):
     - Not be included in indirect service calls to devices or areas
     """

-    # Config: An entity which allows changing the configuration of a device
+    # Config: An entity which allows changing the configuration of a device.
     CONFIG = "config"

-    # Diagnostic: An entity exposing some configuration parameter or diagnostics of a device
+    # Diagnostic: An entity exposing some configuration parameter,
+    # or diagnostics of a device.
     DIAGNOSTIC = "diagnostic"
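These members are what integrations assign to their entities; a brief, assumed example (hypothetical entity class, imports omitted):

    class UptimeSensor(SensorEntity):
        """Diagnostic sensor; per the class docstring above, it is excluded from indirect service calls."""

        _attr_entity_category = EntityCategory.DIAGNOSTIC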
@@ -198,13 +199,16 @@ ENTITY_CATEGORIES_SCHEMA: Final = vol.Coerce(EntityCategory)
 class EntityPlatformState(Enum):
     """The platform state of an entity."""

-    # Not Added: Not yet added to a platform, polling updates are written to the state machine
+    # Not Added: Not yet added to a platform, polling updates
+    # are written to the state machine.
     NOT_ADDED = auto()

-    # Added: Added to a platform, polling updates are written to the state machine
+    # Added: Added to a platform, polling updates
+    # are written to the state machine.
     ADDED = auto()

-    # Removed: Removed from a platform, polling updates are not written to the state machine
+    # Removed: Removed from a platform, polling updates
+    # are not written to the state machine.
     REMOVED = auto()
@@ -458,7 +462,10 @@ class Entity(ABC):
     @property
     def entity_registry_enabled_default(self) -> bool:
-        """Return if the entity should be enabled when first added to the entity registry."""
+        """Return if the entity should be enabled when first added.
+
+        This only applies when first added to the entity registry.
+        """
         if hasattr(self, "_attr_entity_registry_enabled_default"):
             return self._attr_entity_registry_enabled_default
         if hasattr(self, "entity_description"):
@@ -467,7 +474,10 @@ class Entity(ABC):
     @property
     def entity_registry_visible_default(self) -> bool:
-        """Return if the entity should be visible when first added to the entity registry."""
+        """Return if the entity should be visible when first added.
+
+        This only applies when first added to the entity registry.
+        """
         if hasattr(self, "_attr_entity_registry_visible_default"):
             return self._attr_entity_registry_visible_default
         if hasattr(self, "entity_description"):

View File

@@ -90,11 +90,11 @@ class EntityComponent(Generic[_EntityT]):
     @property
     def entities(self) -> Iterable[_EntityT]:
-        """
-        Return an iterable that returns all entities.
+        """Return an iterable that returns all entities.

-        As the underlying dicts may change when async context is lost, callers that
-        iterate over this asynchronously should make a copy using list() before iterating.
+        As the underlying dicts may change when async context is lost,
+        callers that iterate over this asynchronously should make a copy
+        using list() before iterating.
         """
         return chain.from_iterable(
             platform.entities.values()  # type: ignore[misc]

View File

@@ -158,12 +158,16 @@ class EntityPlatform:
     ) -> asyncio.Semaphore | None:
         """Get or create a semaphore for parallel updates.

-        Semaphore will be created on demand because we base it off if update method is async or not.
+        Semaphore will be created on demand because we base it off if update
+        method is async or not.

-        If parallel updates is set to 0, we skip the semaphore.
-        If parallel updates is set to a number, we initialize the semaphore to that number.
-        The default value for parallel requests is decided based on the first entity that is added to Home Assistant.
-        It's 0 if the entity defines the async_update method, else it's 1.
+        - If parallel updates is set to 0, we skip the semaphore.
+        - If parallel updates is set to a number, we initialize the semaphore
+          to that number.
+
+        The default value for parallel requests is decided based on the first
+        entity that is added to Home Assistant. It's 0 if the entity defines
+        the async_update method, else it's 1.
         """
         if self.parallel_updates_created:
             return self.parallel_updates
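For reference, the parallel updates count used by this semaphore is normally supplied by the integration platform; an assumed, minimal example:

    # In an integration's platform module (e.g. its light.py); illustrative only.
    PARALLEL_UPDATES = 1  # allow at most one update/service call in flight for this platform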
@@ -566,7 +570,9 @@ class EntityPlatform:
             "via_device",
         ):
             if key in device_info:
-                processed_dev_info[key] = device_info[key]  # type: ignore[literal-required]
+                processed_dev_info[key] = device_info[
+                    key  # type: ignore[literal-required]
+                ]

         if "configuration_url" in device_info:
             if device_info["configuration_url"] is None:
@@ -586,7 +592,9 @@ class EntityPlatform:
             )

         try:
-            device = device_registry.async_get_or_create(**processed_dev_info)  # type: ignore[arg-type]
+            device = device_registry.async_get_or_create(
+                **processed_dev_info  # type: ignore[arg-type]
+            )
             device_id = device.id
         except RequiredParameterMissing:
             pass

View File

@@ -119,7 +119,8 @@ class RegistryEntry:
     has_entity_name: bool = attr.ib(default=False)
     name: str | None = attr.ib(default=None)
     options: EntityOptionsType = attr.ib(
-        default=None, converter=attr.converters.default_if_none(factory=dict)  # type: ignore[misc]
+        default=None,
+        converter=attr.converters.default_if_none(factory=dict),  # type: ignore[misc]
     )
     # As set by integration
     original_device_class: str | None = attr.ib(default=None)
@@ -780,8 +781,7 @@ class EntityRegistry:
         new_unique_id: str | UndefinedType = UNDEFINED,
         new_device_id: str | None | UndefinedType = UNDEFINED,
     ) -> RegistryEntry:
-        """
-        Update entity platform.
+        """Update entity platform.

         This should only be used when an entity needs to be migrated between
         integrations.

View File

@@ -710,8 +710,8 @@ def async_track_state_change_filtered(
     Returns
     -------
-    Object used to update the listeners (async_update_listeners) with a new TrackStates or
-    cancel the tracking (async_remove).
+    Object used to update the listeners (async_update_listeners) with a new
+    TrackStates or cancel the tracking (async_remove).
     """
     tracker = _TrackStateChangeFiltered(hass, track_states, action)

View File

@@ -50,8 +50,11 @@ def find_coordinates(
 ) -> str | None:
     """Try to resolve the a location from a supplied name or entity_id.

-    Will recursively resolve an entity if pointed to by the state of the supplied entity.
-    Returns coordinates in the form of '90.000,180.000', an address or the state of the last resolved entity.
+    Will recursively resolve an entity if pointed to by the state of the supplied
+    entity.
+    Returns coordinates in the form of '90.000,180.000', an address or
+    the state of the last resolved entity.
     """
     # Check if a friendly name of a zone was supplied
     if (zone_coords := resolve_zone(hass, name)) is not None:
@@ -70,7 +73,9 @@ def find_coordinates(
         zone_entity = hass.states.get(f"zone.{entity_state.state}")
         if has_location(zone_entity):  # type: ignore[arg-type]
             _LOGGER.debug(
-                "%s is in %s, getting zone location", name, zone_entity.entity_id  # type: ignore[union-attr]
+                "%s is in %s, getting zone location",
+                name,
+                zone_entity.entity_id,  # type: ignore[union-attr]
             )
             return _get_location_from_attributes(zone_entity)  # type: ignore[arg-type]
@@ -97,12 +102,16 @@ def find_coordinates(
         _LOGGER.debug("Resolving nested entity_id: %s", entity_state.state)
         return find_coordinates(hass, entity_state.state, recursion_history)

-    # Might be an address, coordinates or anything else. This has to be checked by the caller.
+    # Might be an address, coordinates or anything else.
+    # This has to be checked by the caller.
     return entity_state.state


 def resolve_zone(hass: HomeAssistant, zone_name: str) -> str | None:
-    """Get a lat/long from a zones friendly_name or None if no zone is found by that friendly_name."""
+    """Get a lat/long from a zones friendly_name.
+
+    None is returned if no zone is found by that friendly_name.
+    """
     states = hass.states.async_all("zone")
     for state in states:
         if state.name == zone_name:
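Illustrative calls for the helper above, with made-up names and values (not from this change):

    find_coordinates(hass, "Home")                  # zone friendly name -> "52.37,4.89"
    find_coordinates(hass, "device_tracker.phone")  # entity with location attributes -> "52.37,4.89"
    find_coordinates(hass, "sensor.next_stop")      # plain state -> "Main Street 1"; the caller must validate it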

View File

@@ -303,7 +303,9 @@ class RestoreEntity(Entity):
         """Get data stored for an entity, if any."""
         if self.hass is None or self.entity_id is None:
             # Return None if this entity isn't added to hass yet
-            _LOGGER.warning("Cannot get last state. Entity not added to hass")  # type: ignore[unreachable]
+            _LOGGER.warning(  # type: ignore[unreachable]
+                "Cannot get last state. Entity not added to hass"
+            )
             return None
         data = await RestoreStateData.async_get_instance(self.hass)
         if self.entity_id not in data.last_states:

View File

@@ -60,9 +60,9 @@ class SchemaFlowFormStep(SchemaFlowStep):
     """Optional property to identify next step.

     - If `next_step` is a function, it is called if the schema validates successfully or
-      if no schema is defined. The `next_step` function is passed the union of config entry
-      options and user input from previous steps. If the function returns None, the flow is
-      ended with `FlowResultType.CREATE_ENTRY`.
+      if no schema is defined. The `next_step` function is passed the union of
+      config entry options and user input from previous steps. If the function returns
+      None, the flow is ended with `FlowResultType.CREATE_ENTRY`.
     - If `next_step` is None, the flow is ended with `FlowResultType.CREATE_ENTRY`.
     """
@@ -71,11 +71,11 @@ class SchemaFlowFormStep(SchemaFlowStep):
     ] | None | UndefinedType = UNDEFINED
     """Optional property to populate suggested values.

-    - If `suggested_values` is UNDEFINED, each key in the schema will get a suggested value
-      from an option with the same key.
+    - If `suggested_values` is UNDEFINED, each key in the schema will get a suggested
+      value from an option with the same key.

-    Note: if a step is retried due to a validation failure, then the user input will have
-      priority over the suggested values.
+    Note: if a step is retried due to a validation failure, then the user input will
+      have priority over the suggested values.
     """
@@ -331,8 +331,8 @@ class SchemaConfigFlowHandler(config_entries.ConfigFlow, ABC):
     ) -> None:
         """Take necessary actions after the options flow is finished, if needed.

-        The options parameter contains config entry options, which is the union of stored
-        options and user input from the options flow steps.
+        The options parameter contains config entry options, which is the union of
+        stored options and user input from the options flow steps.
         """

     @callback

View File

@@ -17,7 +17,7 @@ if TYPE_CHECKING:
 def sensor_device_info_to_hass_device_info(
     sensor_device_info: SensorDeviceInfo,
 ) -> DeviceInfo:
-    """Convert a sensor_state_data sensor device info to a Home Assistant device info."""
+    """Convert a sensor_state_data sensor device info to a HA device info."""
     device_info = DeviceInfo()
     if sensor_device_info.name is not None:
         device_info[const.ATTR_NAME] = sensor_device_info.name

View File

@@ -371,7 +371,8 @@ def async_extract_referenced_entity_ids(
         return selected

     for ent_entry in ent_reg.entities.values():
-        # Do not add entities which are hidden or which are config or diagnostic entities
+        # Do not add entities which are hidden or which are config
+        # or diagnostic entities.
         if ent_entry.entity_category is not None or ent_entry.hidden_by is not None:
             continue
@@ -489,7 +490,10 @@ async def async_get_all_descriptions(
             # Cache missing descriptions
             if description is None:
                 domain_yaml = loaded[domain]
-                yaml_description = domain_yaml.get(service, {})  # type: ignore[union-attr]
+                yaml_description = domain_yaml.get(  # type: ignore[union-attr]
+                    service, {}
+                )

                 # Don't warn for missing services, because it triggers false
                 # positives for things like scripts, that register as a service
@@ -706,11 +710,14 @@ async def _handle_entity_call(
     entity.async_set_context(context)

     if isinstance(func, str):
-        result = hass.async_run_job(partial(getattr(entity, func), **data))  # type: ignore[arg-type]
+        result = hass.async_run_job(
+            partial(getattr(entity, func), **data)  # type: ignore[arg-type]
+        )
     else:
         result = hass.async_run_job(func, entity, data)

-    # Guard because callback functions do not return a task when passed to async_run_job.
+    # Guard because callback functions do not return a task when passed to
+    # async_run_job.
     if result is not None:
         await result

View File

@@ -31,8 +31,7 @@ _LOGGER = logging.getLogger(__name__)
 class AsyncTrackStates:
-    """
-    Record the time when the with-block is entered.
+    """Record the time when the with-block is entered.

     Add all states that have changed since the start time to the return list
     when with-block is exited.
@@ -119,8 +118,7 @@ async def async_reproduce_state(
 def state_as_number(state: State) -> float:
-    """
-    Try to coerce our state to a number.
+    """Try to coerce our state to a number.

     Raises ValueError if this is not possible.
     """

View File

@@ -104,8 +104,9 @@ class Store(Generic[_T]):
     async def async_load(self) -> _T | None:
         """Load data.

-        If the expected version and minor version do not match the given versions, the
-        migrate function will be invoked with migrate_func(version, minor_version, config).
+        If the expected version and minor version do not match the given
+        versions, the migrate function will be invoked with
+        migrate_func(version, minor_version, config).

         Will ensure that when a call comes in while another one is in progress,
         the second call will wait and return the result of the first call.
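A rough sketch of the migration described by this docstring, assuming a Store subclass provides a `_async_migrate_func` hook (the hook name and constructor details are assumptions, not shown in this diff):

    class MyStore(Store[dict]):
        # Constructor arguments (hass, version, key, minor_version, ...) omitted.

        async def _async_migrate_func(self, old_major_version, old_minor_version, old_data):
            # Reshape data written by older versions before async_load returns it.
            if old_major_version == 1 and old_minor_version < 2:
                old_data.setdefault("options", {})
            return old_data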

View File

@@ -255,20 +255,39 @@ class RenderInfo:
     def __repr__(self) -> str:
         """Representation of RenderInfo."""
-        return f"<RenderInfo {self.template} all_states={self.all_states} all_states_lifecycle={self.all_states_lifecycle} domains={self.domains} domains_lifecycle={self.domains_lifecycle} entities={self.entities} rate_limit={self.rate_limit}> has_time={self.has_time}"
+        return (
+            f"<RenderInfo {self.template}"
+            f" all_states={self.all_states}"
+            f" all_states_lifecycle={self.all_states_lifecycle}"
+            f" domains={self.domains}"
+            f" domains_lifecycle={self.domains_lifecycle}"
+            f" entities={self.entities}"
+            f" rate_limit={self.rate_limit}"
+            f" has_time={self.has_time}"
+            ">"
+        )

     def _filter_domains_and_entities(self, entity_id: str) -> bool:
-        """Template should re-render if the entity state changes when we match specific domains or entities."""
+        """Template should re-render if the entity state changes.
+
+        Only when we match specific domains or entities.
+        """
         return (
             split_entity_id(entity_id)[0] in self.domains or entity_id in self.entities
         )

     def _filter_entities(self, entity_id: str) -> bool:
-        """Template should re-render if the entity state changes when we match specific entities."""
+        """Template should re-render if the entity state changes.
+
+        Only when we match specific entities.
+        """
         return entity_id in self.entities

     def _filter_lifecycle_domains(self, entity_id: str) -> bool:
-        """Template should re-render if the entity is added or removed with domains watched."""
+        """Template should re-render if the entity is added or removed.
+
+        Only with domains watched.
+        """
         return split_entity_id(entity_id)[0] in self.domains_lifecycle

     def result(self) -> str:
@@ -359,7 +378,11 @@ class Template:
             wanted_env = _ENVIRONMENT
         ret: TemplateEnvironment | None = self.hass.data.get(wanted_env)
         if ret is None:
-            ret = self.hass.data[wanted_env] = TemplateEnvironment(self.hass, self._limited, self._strict)  # type: ignore[no-untyped-call]
+            ret = self.hass.data[wanted_env] = TemplateEnvironment(  # type: ignore[no-untyped-call]
+                self.hass,
+                self._limited,
+                self._strict,
+            )
         return ret

     def ensure_valid(self) -> None:
@@ -382,7 +405,8 @@ class Template:
     ) -> Any:
         """Render given template.

-        If limited is True, the template is not allowed to access any function or filter depending on hass or the state machine.
+        If limited is True, the template is not allowed to access any function
+        or filter depending on hass or the state machine.
         """
         if self.is_static:
             if not parse_result or self.hass and self.hass.config.legacy_templates:
@@ -407,7 +431,8 @@ class Template:
         This method must be run in the event loop.

-        If limited is True, the template is not allowed to access any function or filter depending on hass or the state machine.
+        If limited is True, the template is not allowed to access any function
+        or filter depending on hass or the state machine.
         """
         if self.is_static:
             if not parse_result or self.hass and self.hass.config.legacy_templates:
@@ -1039,11 +1064,11 @@ def device_entities(hass: HomeAssistant, _device_id: str) -> Iterable[str]:
 def integration_entities(hass: HomeAssistant, entry_name: str) -> Iterable[str]:
-    """
-    Get entity ids for entities tied to an integration/domain.
+    """Get entity ids for entities tied to an integration/domain.

     Provide entry_name as domain to get all entity id's for a integration/domain
-    or provide a config entry title for filtering between instances of the same integration.
+    or provide a config entry title for filtering between instances of the same
+    integration.
     """
     # first try if this is a config entry match
     conf_entry = next(
@@ -1643,8 +1668,7 @@ def fail_when_undefined(value):
 def min_max_from_filter(builtin_filter: Any, name: str) -> Any:
-    """
-    Convert a built-in min/max Jinja filter to a global function.
+    """Convert a built-in min/max Jinja filter to a global function.

     The parameters may be passed as an iterable or as separate arguments.
     """
@@ -1667,16 +1691,17 @@ def min_max_from_filter(builtin_filter: Any, name: str) -> Any:
 def average(*args: Any, default: Any = _SENTINEL) -> Any:
-    """
-    Filter and function to calculate the arithmetic mean of an iterable or of two or more arguments.
+    """Filter and function to calculate the arithmetic mean.
+
+    Calculates the mean of an iterable or of two or more arguments.

     The parameters may be passed as an iterable or as separate arguments.
     """
     if len(args) == 0:
         raise TypeError("average expected at least 1 argument, got 0")

-    # If first argument is iterable and more than 1 argument provided but not a named default,
-    # then use 2nd argument as default.
+    # If first argument is iterable and more than 1 argument provided but not a named
+    # default, then use 2nd argument as default.
     if isinstance(args[0], Iterable):
         average_list = args[0]
         if len(args) > 1 and default is _SENTINEL:
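Illustrative calls of the helper above (made-up values):

    average([1, 2, 3])     # -> 2
    average(1, 2, 3)       # -> 2
    average([1, 2, 3], 0)  # second positional argument is picked up as `default`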
@@ -1884,8 +1909,7 @@ def today_at(time_str: str = "") -> datetime:
 def relative_time(value):
-    """
-    Take a datetime and return its "age" as a string.
+    """Take a datetime and return its "age" as a string.

     The age can be in second, minute, hour, day, month or year. Only the
     biggest unit is considered, e.g. if it's 2 days and 3 hours, "2 days" will
@@ -1953,7 +1977,7 @@ def _render_with_context(
 class LoggingUndefined(jinja2.Undefined):
     """Log on undefined variables."""

-    def _log_message(self):
+    def _log_message(self) -> None:
         template, action = template_cv.get() or ("", "rendering or compiling")
         _LOGGER.warning(
             "Template variable warning: %s when %s '%s'",
@@ -1975,7 +1999,7 @@ class LoggingUndefined(jinja2.Undefined):
             )
             raise ex

-    def __str__(self):
+    def __str__(self) -> str:
         """Log undefined __str___."""
         self._log_message()
         return super().__str__()
@@ -1985,7 +2009,7 @@ class LoggingUndefined(jinja2.Undefined):
         self._log_message()
         return super().__iter__()

-    def __bool__(self):
+    def __bool__(self) -> bool:
         """Log undefined __bool___."""
         self._log_message()
         return super().__bool__()
@@ -1996,13 +2020,16 @@ class TemplateEnvironment(ImmutableSandboxedEnvironment):
     def __init__(self, hass, limited=False, strict=False):
         """Initialise template environment."""
+        undefined: type[LoggingUndefined] | type[jinja2.StrictUndefined]
         if not strict:
             undefined = LoggingUndefined
         else:
             undefined = jinja2.StrictUndefined
         super().__init__(undefined=undefined)
         self.hass = hass
-        self.template_cache = weakref.WeakValueDictionary()
+        self.template_cache: weakref.WeakValueDictionary[
+            str | jinja2.nodes.Template, CodeType | str | None
+        ] = weakref.WeakValueDictionary()
         self.filters["round"] = forgiving_round
         self.filters["multiply"] = multiply
         self.filters["log"] = logarithm
@@ -2138,8 +2165,8 @@ class TemplateEnvironment(ImmutableSandboxedEnvironment):
         if limited:
             # Only device_entities is available to limited templates, mark other
             # functions and filters as unsupported.
-            def unsupported(name):
-                def warn_unsupported(*args, **kwargs):
+            def unsupported(name: str) -> Callable[[], NoReturn]:
+                def warn_unsupported(*args: Any, **kwargs: Any) -> NoReturn:
                     raise TemplateError(
                         f"Use of '{name}' is not supported in limited templates"
                     )
@@ -2247,7 +2274,6 @@ class TemplateEnvironment(ImmutableSandboxedEnvironment):
             defer_init,
         )

-        cached: CodeType | str | None
         if (cached := self.template_cache.get(source)) is None:
             cached = self.template_cache[source] = super().compile(source)

View File

@@ -268,8 +268,7 @@ class TemplateEntity(Entity):
         on_update: Callable[[Any], None] | None = None,
         none_on_template_error: bool = False,
     ) -> None:
-        """
-        Call in the constructor to add a template linked to a attribute.
+        """Call in the constructor to add a template linked to an attribute.

         Parameters
         ----------

View File

@@ -257,7 +257,9 @@ class _TranslationCache:
                 _merge_resources if category == "state" else _build_resources
             )
             new_resources: Mapping[str, dict[str, Any] | str]
-            new_resources = resource_func(translation_strings, components, category)  # type: ignore[assignment]
+            new_resources = resource_func(  # type: ignore[assignment]
+                translation_strings, components, category
+            )
             for component, resource in new_resources.items():
                 category_cache: dict[str, Any] = cached.setdefault(

View File

@@ -127,7 +127,10 @@ class PluggableAction:
         action: TriggerActionType,
         variables: dict[str, Any],
     ) -> CALLBACK_TYPE:
-        """Attach an action to a trigger entry. Existing or future plugs registered will be attached."""
+        """Attach an action to a trigger entry.
+
+        Existing or future plugs registered will be attached.
+        """
         reg = PluggableAction.async_get_registry(hass)
         key = tuple(sorted(trigger.items()))
         entry = reg[key]
@@ -163,7 +166,10 @@ class PluggableAction:
         @callback
         def _remove() -> None:
-            """Remove plug from registration, and clean up entry if there are no actions or plugs registered."""
+            """Remove plug from registration.
+
+            Clean up entry if there are no actions or plugs registered.
+            """
            assert self._entry
            self._entry.plugs.remove(self)
            if not self._entry.actions and not self._entry.plugs:

View File

@@ -403,7 +403,9 @@ class CoordinatorEntity(BaseCoordinatorEntity[_DataUpdateCoordinatorT]):
     def __init__(
         self, coordinator: _DataUpdateCoordinatorT, context: Any = None
     ) -> None:
-        """Create the entity with a DataUpdateCoordinator. Passthrough to BaseCoordinatorEntity.
+        """Create the entity with a DataUpdateCoordinator.
+
+        Passthrough to BaseCoordinatorEntity.

         Necessary to bind TypeVar to correct scope.
         """