Use assignment expressions 01 (#56394)

This commit is contained in:
Marc Mueller 2021-09-19 01:31:35 +02:00 committed by GitHub
parent a4f6c3336f
commit 7af67d34cf
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
21 changed files with 73 additions and 171 deletions

View File

@ -132,16 +132,14 @@ def get_arguments() -> argparse.Namespace:
def daemonize() -> None: def daemonize() -> None:
"""Move current process to daemon process.""" """Move current process to daemon process."""
# Create first fork # Create first fork
pid = os.fork() if os.fork() > 0:
if pid > 0:
sys.exit(0) sys.exit(0)
# Decouple fork # Decouple fork
os.setsid() os.setsid()
# Create second fork # Create second fork
pid = os.fork() if os.fork() > 0:
if pid > 0:
sys.exit(0) sys.exit(0)
# redirect standard file descriptors to devnull # redirect standard file descriptors to devnull

View File

@ -341,8 +341,7 @@ class AuthManager:
"System generated users cannot enable multi-factor auth module." "System generated users cannot enable multi-factor auth module."
) )
module = self.get_auth_mfa_module(mfa_module_id) if (module := self.get_auth_mfa_module(mfa_module_id)) is None:
if module is None:
raise ValueError(f"Unable find multi-factor auth module: {mfa_module_id}") raise ValueError(f"Unable find multi-factor auth module: {mfa_module_id}")
await module.async_setup_user(user.id, data) await module.async_setup_user(user.id, data)
@ -356,8 +355,7 @@ class AuthManager:
"System generated users cannot disable multi-factor auth module." "System generated users cannot disable multi-factor auth module."
) )
module = self.get_auth_mfa_module(mfa_module_id) if (module := self.get_auth_mfa_module(mfa_module_id)) is None:
if module is None:
raise ValueError(f"Unable find multi-factor auth module: {mfa_module_id}") raise ValueError(f"Unable find multi-factor auth module: {mfa_module_id}")
await module.async_depose_user(user.id) await module.async_depose_user(user.id)
@ -498,8 +496,7 @@ class AuthManager:
Will raise InvalidAuthError on errors. Will raise InvalidAuthError on errors.
""" """
provider = self._async_resolve_provider(refresh_token) if provider := self._async_resolve_provider(refresh_token):
if provider:
provider.async_validate_refresh_token(refresh_token, remote_ip) provider.async_validate_refresh_token(refresh_token, remote_ip)
async def async_validate_access_token( async def async_validate_access_token(

View File

@ -96,8 +96,7 @@ class AuthStore:
groups = [] groups = []
for group_id in group_ids or []: for group_id in group_ids or []:
group = self._groups.get(group_id) if (group := self._groups.get(group_id)) is None:
if group is None:
raise ValueError(f"Invalid group specified {group_id}") raise ValueError(f"Invalid group specified {group_id}")
groups.append(group) groups.append(group)
@ -160,8 +159,7 @@ class AuthStore:
if group_ids is not None: if group_ids is not None:
groups = [] groups = []
for grid in group_ids: for grid in group_ids:
group = self._groups.get(grid) if (group := self._groups.get(grid)) is None:
if group is None:
raise ValueError("Invalid group specified.") raise ValueError("Invalid group specified.")
groups.append(group) groups.append(group)
@ -446,16 +444,14 @@ class AuthStore:
) )
continue continue
token_type = rt_dict.get("token_type") if (token_type := rt_dict.get("token_type")) is None:
if token_type is None:
if rt_dict["client_id"] is None: if rt_dict["client_id"] is None:
token_type = models.TOKEN_TYPE_SYSTEM token_type = models.TOKEN_TYPE_SYSTEM
else: else:
token_type = models.TOKEN_TYPE_NORMAL token_type = models.TOKEN_TYPE_NORMAL
# old refresh_token don't have last_used_at (pre-0.78) # old refresh_token don't have last_used_at (pre-0.78)
last_used_at_str = rt_dict.get("last_used_at") if last_used_at_str := rt_dict.get("last_used_at"):
if last_used_at_str:
last_used_at = dt_util.parse_datetime(last_used_at_str) last_used_at = dt_util.parse_datetime(last_used_at_str)
else: else:
last_used_at = None last_used_at = None

View File

@ -118,9 +118,7 @@ class NotifyAuthModule(MultiFactorAuthModule):
if self._user_settings is not None: if self._user_settings is not None:
return return
data = await self._user_store.async_load() if (data := await self._user_store.async_load()) is None:
if data is None:
data = {STORAGE_USERS: {}} data = {STORAGE_USERS: {}}
self._user_settings = { self._user_settings = {
@ -207,8 +205,7 @@ class NotifyAuthModule(MultiFactorAuthModule):
await self._async_load() await self._async_load()
assert self._user_settings is not None assert self._user_settings is not None
notify_setting = self._user_settings.get(user_id) if (notify_setting := self._user_settings.get(user_id)) is None:
if notify_setting is None:
return False return False
# user_input has been validate in caller # user_input has been validate in caller
@ -225,8 +222,7 @@ class NotifyAuthModule(MultiFactorAuthModule):
await self._async_load() await self._async_load()
assert self._user_settings is not None assert self._user_settings is not None
notify_setting = self._user_settings.get(user_id) if (notify_setting := self._user_settings.get(user_id)) is None:
if notify_setting is None:
raise ValueError("Cannot find user_id") raise ValueError("Cannot find user_id")
def generate_secret_and_one_time_password() -> str: def generate_secret_and_one_time_password() -> str:

View File

@ -92,9 +92,7 @@ class TotpAuthModule(MultiFactorAuthModule):
if self._users is not None: if self._users is not None:
return return
data = await self._user_store.async_load() if (data := await self._user_store.async_load()) is None:
if data is None:
data = {STORAGE_USERS: {}} data = {STORAGE_USERS: {}}
self._users = data.get(STORAGE_USERS, {}) self._users = data.get(STORAGE_USERS, {})
@ -163,8 +161,7 @@ class TotpAuthModule(MultiFactorAuthModule):
"""Validate two factor authentication code.""" """Validate two factor authentication code."""
import pyotp # pylint: disable=import-outside-toplevel import pyotp # pylint: disable=import-outside-toplevel
ota_secret = self._users.get(user_id) # type: ignore if (ota_secret := self._users.get(user_id)) is None: # type: ignore
if ota_secret is None:
# even we cannot find user, we still do verify # even we cannot find user, we still do verify
# to make timing the same as if user was found. # to make timing the same as if user was found.
pyotp.TOTP(DUMMY_SECRET).verify(code, valid_window=1) pyotp.TOTP(DUMMY_SECRET).verify(code, valid_window=1)

View File

@ -33,9 +33,7 @@ class AbstractPermissions:
def check_entity(self, entity_id: str, key: str) -> bool: def check_entity(self, entity_id: str, key: str) -> bool:
"""Check if we can access entity.""" """Check if we can access entity."""
entity_func = self._cached_entity_func if (entity_func := self._cached_entity_func) is None:
if entity_func is None:
entity_func = self._cached_entity_func = self._entity_func() entity_func = self._cached_entity_func = self._entity_func()
return entity_func(entity_id, key) return entity_func(entity_id, key)

View File

@ -72,8 +72,7 @@ def compile_policy(
def apply_policy_funcs(object_id: str, key: str) -> bool: def apply_policy_funcs(object_id: str, key: str) -> bool:
"""Apply several policy functions.""" """Apply several policy functions."""
for func in funcs: for func in funcs:
result = func(object_id, key) if (result := func(object_id, key)) is not None:
if result is not None:
return result return result
return False return False

View File

@ -169,9 +169,7 @@ async def load_auth_provider_module(
if hass.config.skip_pip or not hasattr(module, "REQUIREMENTS"): if hass.config.skip_pip or not hasattr(module, "REQUIREMENTS"):
return module return module
processed = hass.data.get(DATA_REQS) if (processed := hass.data.get(DATA_REQS)) is None:
if processed is None:
processed = hass.data[DATA_REQS] = set() processed = hass.data[DATA_REQS] = set()
elif provider in processed: elif provider in processed:
return module return module

View File

@ -82,9 +82,7 @@ class Data:
async def async_load(self) -> None: async def async_load(self) -> None:
"""Load stored data.""" """Load stored data."""
data = await self._store.async_load() if (data := await self._store.async_load()) is None:
if data is None:
data = {"users": []} data = {"users": []}
seen: set[str] = set() seen: set[str] = set()
@ -93,9 +91,7 @@ class Data:
username = user["username"] username = user["username"]
# check if we have duplicates # check if we have duplicates
folded = username.casefold() if (folded := username.casefold()) in seen:
if folded in seen:
self.is_legacy = True self.is_legacy = True
logging.getLogger(__name__).warning( logging.getLogger(__name__).warning(

View File

@ -109,9 +109,8 @@ async def async_setup_hass(
config_dict = None config_dict = None
basic_setup_success = False basic_setup_success = False
safe_mode = runtime_config.safe_mode
if not safe_mode: if not (safe_mode := runtime_config.safe_mode):
await hass.async_add_executor_job(conf_util.process_ha_config_upgrade, hass) await hass.async_add_executor_job(conf_util.process_ha_config_upgrade, hass)
try: try:
@ -368,8 +367,7 @@ async def async_mount_local_lib_path(config_dir: str) -> str:
This function is a coroutine. This function is a coroutine.
""" """
deps_dir = os.path.join(config_dir, "deps") deps_dir = os.path.join(config_dir, "deps")
lib_dir = await async_get_user_site(deps_dir) if (lib_dir := await async_get_user_site(deps_dir)) not in sys.path:
if lib_dir not in sys.path:
sys.path.insert(0, lib_dir) sys.path.insert(0, lib_dir)
return deps_dir return deps_dir
@ -494,17 +492,13 @@ async def _async_set_up_integrations(
_LOGGER.info("Domains to be set up: %s", domains_to_setup) _LOGGER.info("Domains to be set up: %s", domains_to_setup)
logging_domains = domains_to_setup & LOGGING_INTEGRATIONS
# Load logging as soon as possible # Load logging as soon as possible
if logging_domains: if logging_domains := domains_to_setup & LOGGING_INTEGRATIONS:
_LOGGER.info("Setting up logging: %s", logging_domains) _LOGGER.info("Setting up logging: %s", logging_domains)
await async_setup_multi_components(hass, logging_domains, config) await async_setup_multi_components(hass, logging_domains, config)
# Start up debuggers. Start these first in case they want to wait. # Start up debuggers. Start these first in case they want to wait.
debuggers = domains_to_setup & DEBUGGER_INTEGRATIONS if debuggers := domains_to_setup & DEBUGGER_INTEGRATIONS:
if debuggers:
_LOGGER.debug("Setting up debuggers: %s", debuggers) _LOGGER.debug("Setting up debuggers: %s", debuggers)
await async_setup_multi_components(hass, debuggers, config) await async_setup_multi_components(hass, debuggers, config)
@ -524,9 +518,7 @@ async def _async_set_up_integrations(
stage_1_domains.add(domain) stage_1_domains.add(domain)
dep_itg = integration_cache.get(domain) if (dep_itg := integration_cache.get(domain)) is None:
if dep_itg is None:
continue continue
deps_promotion.update(dep_itg.all_dependencies) deps_promotion.update(dep_itg.all_dependencies)

View File

@ -512,9 +512,7 @@ async def async_process_ha_core_config(hass: HomeAssistant, config: dict) -> Non
# Only load auth during startup. # Only load auth during startup.
if not hasattr(hass, "auth"): if not hasattr(hass, "auth"):
auth_conf = config.get(CONF_AUTH_PROVIDERS) if (auth_conf := config.get(CONF_AUTH_PROVIDERS)) is None:
if auth_conf is None:
auth_conf = [{"type": "homeassistant"}] auth_conf = [{"type": "homeassistant"}]
mfa_conf = config.get( mfa_conf = config.get(
@ -598,9 +596,7 @@ async def async_process_ha_core_config(hass: HomeAssistant, config: dict) -> Non
cust_glob = OrderedDict(config[CONF_CUSTOMIZE_GLOB]) cust_glob = OrderedDict(config[CONF_CUSTOMIZE_GLOB])
for name, pkg in config[CONF_PACKAGES].items(): for name, pkg in config[CONF_PACKAGES].items():
pkg_cust = pkg.get(CONF_CORE) if (pkg_cust := pkg.get(CONF_CORE)) is None:
if pkg_cust is None:
continue continue
try: try:
@ -957,9 +953,7 @@ def async_notify_setup_error(
# pylint: disable=import-outside-toplevel # pylint: disable=import-outside-toplevel
from homeassistant.components import persistent_notification from homeassistant.components import persistent_notification
errors = hass.data.get(DATA_PERSISTENT_ERRORS) if (errors := hass.data.get(DATA_PERSISTENT_ERRORS)) is None:
if errors is None:
errors = hass.data[DATA_PERSISTENT_ERRORS] = {} errors = hass.data[DATA_PERSISTENT_ERRORS] = {}
errors[component] = errors.get(component) or display_link errors[component] = errors.get(component) or display_link

View File

@ -492,8 +492,7 @@ class ConfigEntry:
Returns True if config entry is up-to-date or has been migrated. Returns True if config entry is up-to-date or has been migrated.
""" """
handler = HANDLERS.get(self.domain) if (handler := HANDLERS.get(self.domain)) is None:
if handler is None:
_LOGGER.error( _LOGGER.error(
"Flow handler not found for entry %s for %s", self.title, self.domain "Flow handler not found for entry %s for %s", self.title, self.domain
) )
@ -716,9 +715,7 @@ class ConfigEntriesFlowManager(data_entry_flow.FlowManager):
) )
raise data_entry_flow.UnknownHandler raise data_entry_flow.UnknownHandler
handler = HANDLERS.get(handler_key) if (handler := HANDLERS.get(handler_key)) is None:
if handler is None:
raise data_entry_flow.UnknownHandler raise data_entry_flow.UnknownHandler
if not context or "source" not in context: if not context or "source" not in context:
@ -814,9 +811,7 @@ class ConfigEntries:
async def async_remove(self, entry_id: str) -> dict[str, Any]: async def async_remove(self, entry_id: str) -> dict[str, Any]:
"""Remove an entry.""" """Remove an entry."""
entry = self.async_get_entry(entry_id) if (entry := self.async_get_entry(entry_id)) is None:
if entry is None:
raise UnknownEntry raise UnknownEntry
if not entry.state.recoverable: if not entry.state.recoverable:
@ -933,9 +928,7 @@ class ConfigEntries:
Return True if entry has been successfully loaded. Return True if entry has been successfully loaded.
""" """
entry = self.async_get_entry(entry_id) if (entry := self.async_get_entry(entry_id)) is None:
if entry is None:
raise UnknownEntry raise UnknownEntry
if entry.state is not ConfigEntryState.NOT_LOADED: if entry.state is not ConfigEntryState.NOT_LOADED:
@ -957,9 +950,7 @@ class ConfigEntries:
async def async_unload(self, entry_id: str) -> bool: async def async_unload(self, entry_id: str) -> bool:
"""Unload a config entry.""" """Unload a config entry."""
entry = self.async_get_entry(entry_id) if (entry := self.async_get_entry(entry_id)) is None:
if entry is None:
raise UnknownEntry raise UnknownEntry
if not entry.state.recoverable: if not entry.state.recoverable:
@ -972,9 +963,7 @@ class ConfigEntries:
If an entry was not loaded, will just load. If an entry was not loaded, will just load.
""" """
entry = self.async_get_entry(entry_id) if (entry := self.async_get_entry(entry_id)) is None:
if entry is None:
raise UnknownEntry raise UnknownEntry
unload_result = await self.async_unload(entry_id) unload_result = await self.async_unload(entry_id)
@ -991,9 +980,7 @@ class ConfigEntries:
If disabled_by is changed, the config entry will be reloaded. If disabled_by is changed, the config entry will be reloaded.
""" """
entry = self.async_get_entry(entry_id) if (entry := self.async_get_entry(entry_id)) is None:
if entry is None:
raise UnknownEntry raise UnknownEntry
if entry.disabled_by == disabled_by: if entry.disabled_by == disabled_by:
@ -1066,8 +1053,7 @@ class ConfigEntries:
return False return False
for listener_ref in entry.update_listeners: for listener_ref in entry.update_listeners:
listener = listener_ref() if (listener := listener_ref()) is not None:
if listener is not None:
self.hass.async_create_task(listener(self.hass, entry)) self.hass.async_create_task(listener(self.hass, entry))
self._async_schedule_save() self._async_schedule_save()

View File

@ -971,8 +971,7 @@ class State:
if isinstance(last_updated, str): if isinstance(last_updated, str):
last_updated = dt_util.parse_datetime(last_updated) last_updated = dt_util.parse_datetime(last_updated)
context = json_dict.get("context") if context := json_dict.get("context"):
if context:
context = Context(id=context.get("id"), user_id=context.get("user_id")) context = Context(id=context.get("id"), user_id=context.get("user_id"))
return cls( return cls(
@ -1199,8 +1198,7 @@ class StateMachine:
entity_id = entity_id.lower() entity_id = entity_id.lower()
new_state = str(new_state) new_state = str(new_state)
attributes = attributes or {} attributes = attributes or {}
old_state = self._states.get(entity_id) if (old_state := self._states.get(entity_id)) is None:
if old_state is None:
same_state = False same_state = False
same_attr = False same_attr = False
last_changed = None last_changed = None
@ -1658,9 +1656,7 @@ class Config:
def set_time_zone(self, time_zone_str: str) -> None: def set_time_zone(self, time_zone_str: str) -> None:
"""Help to set the time zone.""" """Help to set the time zone."""
time_zone = dt_util.get_time_zone(time_zone_str) if time_zone := dt_util.get_time_zone(time_zone_str):
if time_zone:
self.time_zone = time_zone_str self.time_zone = time_zone_str
dt_util.set_default_time_zone(time_zone) dt_util.set_default_time_zone(time_zone)
else: else:
@ -1717,9 +1713,8 @@ class Config:
store = self.hass.helpers.storage.Store( store = self.hass.helpers.storage.Store(
CORE_STORAGE_VERSION, CORE_STORAGE_KEY, private=True CORE_STORAGE_VERSION, CORE_STORAGE_KEY, private=True
) )
data = await store.async_load()
if not data: if not (data := await store.async_load()):
return return
# In 2021.9 we fixed validation to disallow a path (because that's never correct) # In 2021.9 we fixed validation to disallow a path (because that's never correct)
@ -1792,8 +1787,7 @@ def _async_create_timer(hass: HomeAssistant) -> None:
) )
# If we are more than a second late, a tick was missed # If we are more than a second late, a tick was missed
late = monotonic() - target if (late := monotonic() - target) > 1:
if late > 1:
hass.bus.async_fire( hass.bus.async_fire(
EVENT_TIMER_OUT_OF_SYNC, EVENT_TIMER_OUT_OF_SYNC,
{ATTR_SECONDS: late}, {ATTR_SECONDS: late},

View File

@ -93,9 +93,7 @@ class FlowManager(abc.ABC):
async def async_wait_init_flow_finish(self, handler: str) -> None: async def async_wait_init_flow_finish(self, handler: str) -> None:
"""Wait till all flows in progress are initialized.""" """Wait till all flows in progress are initialized."""
current = self._initializing.get(handler) if not (current := self._initializing.get(handler)):
if not current:
return return
await asyncio.wait(current) await asyncio.wait(current)
@ -189,9 +187,7 @@ class FlowManager(abc.ABC):
self, flow_id: str, user_input: dict | None = None self, flow_id: str, user_input: dict | None = None
) -> FlowResult: ) -> FlowResult:
"""Continue a configuration flow.""" """Continue a configuration flow."""
flow = self._progress.get(flow_id) if (flow := self._progress.get(flow_id)) is None:
if flow is None:
raise UnknownFlow raise UnknownFlow
cur_step = flow.cur_step cur_step = flow.cur_step

View File

@ -18,9 +18,7 @@ def config_per_platform(config: ConfigType, domain: str) -> Iterable[tuple[Any,
Async friendly. Async friendly.
""" """
for config_key in extract_domain_configs(config, domain): for config_key in extract_domain_configs(config, domain):
platform_config = config[config_key] if not (platform_config := config[config_key]):
if not platform_config:
continue continue
if not isinstance(platform_config, list): if not isinstance(platform_config, list):

View File

@ -99,13 +99,11 @@ def get_capability(hass: HomeAssistant, entity_id: str, capability: str) -> Any
First try the statemachine, then entity registry. First try the statemachine, then entity registry.
""" """
state = hass.states.get(entity_id) if state := hass.states.get(entity_id):
if state:
return state.attributes.get(capability) return state.attributes.get(capability)
entity_registry = er.async_get(hass) entity_registry = er.async_get(hass)
entry = entity_registry.async_get(entity_id) if not (entry := entity_registry.async_get(entity_id)):
if not entry:
raise HomeAssistantError(f"Unknown entity {entity_id}") raise HomeAssistantError(f"Unknown entity {entity_id}")
return entry.capabilities.get(capability) if entry.capabilities else None return entry.capabilities.get(capability) if entry.capabilities else None
@ -116,13 +114,11 @@ def get_device_class(hass: HomeAssistant, entity_id: str) -> str | None:
First try the statemachine, then entity registry. First try the statemachine, then entity registry.
""" """
state = hass.states.get(entity_id) if state := hass.states.get(entity_id):
if state:
return state.attributes.get(ATTR_DEVICE_CLASS) return state.attributes.get(ATTR_DEVICE_CLASS)
entity_registry = er.async_get(hass) entity_registry = er.async_get(hass)
entry = entity_registry.async_get(entity_id) if not (entry := entity_registry.async_get(entity_id)):
if not entry:
raise HomeAssistantError(f"Unknown entity {entity_id}") raise HomeAssistantError(f"Unknown entity {entity_id}")
return entry.device_class return entry.device_class
@ -133,13 +129,11 @@ def get_supported_features(hass: HomeAssistant, entity_id: str) -> int:
First try the statemachine, then entity registry. First try the statemachine, then entity registry.
""" """
state = hass.states.get(entity_id) if state := hass.states.get(entity_id):
if state:
return state.attributes.get(ATTR_SUPPORTED_FEATURES, 0) return state.attributes.get(ATTR_SUPPORTED_FEATURES, 0)
entity_registry = er.async_get(hass) entity_registry = er.async_get(hass)
entry = entity_registry.async_get(entity_id) if not (entry := entity_registry.async_get(entity_id)):
if not entry:
raise HomeAssistantError(f"Unknown entity {entity_id}") raise HomeAssistantError(f"Unknown entity {entity_id}")
return entry.supported_features or 0 return entry.supported_features or 0
@ -150,13 +144,11 @@ def get_unit_of_measurement(hass: HomeAssistant, entity_id: str) -> str | None:
First try the statemachine, then entity registry. First try the statemachine, then entity registry.
""" """
state = hass.states.get(entity_id) if state := hass.states.get(entity_id):
if state:
return state.attributes.get(ATTR_UNIT_OF_MEASUREMENT) return state.attributes.get(ATTR_UNIT_OF_MEASUREMENT)
entity_registry = er.async_get(hass) entity_registry = er.async_get(hass)
entry = entity_registry.async_get(entity_id) if not (entry := entity_registry.async_get(entity_id)):
if not entry:
raise HomeAssistantError(f"Unknown entity {entity_id}") raise HomeAssistantError(f"Unknown entity {entity_id}")
return entry.unit_of_measurement return entry.unit_of_measurement
@ -467,8 +459,7 @@ class Entity(ABC):
"""Convert state to string.""" """Convert state to string."""
if not self.available: if not self.available:
return STATE_UNAVAILABLE return STATE_UNAVAILABLE
state = self.state if (state := self.state) is None:
if state is None:
return STATE_UNKNOWN return STATE_UNKNOWN
if isinstance(state, float): if isinstance(state, float):
# If the entity's state is a float, limit precision according to machine # If the entity's state is a float, limit precision according to machine
@ -511,28 +502,22 @@ class Entity(ABC):
entry = self.registry_entry entry = self.registry_entry
# pylint: disable=consider-using-ternary # pylint: disable=consider-using-ternary
name = (entry and entry.name) or self.name if (name := (entry and entry.name) or self.name) is not None:
if name is not None:
attr[ATTR_FRIENDLY_NAME] = name attr[ATTR_FRIENDLY_NAME] = name
icon = (entry and entry.icon) or self.icon if (icon := (entry and entry.icon) or self.icon) is not None:
if icon is not None:
attr[ATTR_ICON] = icon attr[ATTR_ICON] = icon
entity_picture = self.entity_picture if (entity_picture := self.entity_picture) is not None:
if entity_picture is not None:
attr[ATTR_ENTITY_PICTURE] = entity_picture attr[ATTR_ENTITY_PICTURE] = entity_picture
assumed_state = self.assumed_state if assumed_state := self.assumed_state:
if assumed_state:
attr[ATTR_ASSUMED_STATE] = assumed_state attr[ATTR_ASSUMED_STATE] = assumed_state
supported_features = self.supported_features if (supported_features := self.supported_features) is not None:
if supported_features is not None:
attr[ATTR_SUPPORTED_FEATURES] = supported_features attr[ATTR_SUPPORTED_FEATURES] = supported_features
device_class = self.device_class if (device_class := self.device_class) is not None:
if device_class is not None:
attr[ATTR_DEVICE_CLASS] = str(device_class) attr[ATTR_DEVICE_CLASS] = str(device_class)
end = timer() end = timer()
@ -636,8 +621,7 @@ class Entity(ABC):
finished, _ = await asyncio.wait([task], timeout=SLOW_UPDATE_WARNING) finished, _ = await asyncio.wait([task], timeout=SLOW_UPDATE_WARNING)
for done in finished: for done in finished:
exc = done.exception() if exc := done.exception():
if exc:
raise exc raise exc
return return

View File

@ -175,16 +175,14 @@ def async_track_state_change(
def state_change_filter(event: Event) -> bool: def state_change_filter(event: Event) -> bool:
"""Handle specific state changes.""" """Handle specific state changes."""
if from_state is not None: if from_state is not None:
old_state = event.data.get("old_state") if (old_state := event.data.get("old_state")) is not None:
if old_state is not None:
old_state = old_state.state old_state = old_state.state
if not match_from_state(old_state): if not match_from_state(old_state):
return False return False
if to_state is not None: if to_state is not None:
new_state = event.data.get("new_state") if (new_state := event.data.get("new_state")) is not None:
if new_state is not None:
new_state = new_state.state new_state = new_state.state
if not match_to_state(new_state): if not match_to_state(new_state):
@ -246,8 +244,7 @@ def async_track_state_change_event(
care about the state change events so we can care about the state change events so we can
do a fast dict lookup to route events. do a fast dict lookup to route events.
""" """
entity_ids = _async_string_to_lower_list(entity_ids) if not (entity_ids := _async_string_to_lower_list(entity_ids)):
if not entity_ids:
return _remove_empty_listener return _remove_empty_listener
entity_callbacks = hass.data.setdefault(TRACK_STATE_CHANGE_CALLBACKS, {}) entity_callbacks = hass.data.setdefault(TRACK_STATE_CHANGE_CALLBACKS, {})
@ -336,8 +333,7 @@ def async_track_entity_registry_updated_event(
Similar to async_track_state_change_event. Similar to async_track_state_change_event.
""" """
entity_ids = _async_string_to_lower_list(entity_ids) if not (entity_ids := _async_string_to_lower_list(entity_ids)):
if not entity_ids:
return _remove_empty_listener return _remove_empty_listener
entity_callbacks = hass.data.setdefault(TRACK_ENTITY_REGISTRY_UPDATED_CALLBACKS, {}) entity_callbacks = hass.data.setdefault(TRACK_ENTITY_REGISTRY_UPDATED_CALLBACKS, {})
@ -419,8 +415,7 @@ def async_track_state_added_domain(
action: Callable[[Event], Any], action: Callable[[Event], Any],
) -> Callable[[], None]: ) -> Callable[[], None]:
"""Track state change events when an entity is added to domains.""" """Track state change events when an entity is added to domains."""
domains = _async_string_to_lower_list(domains) if not (domains := _async_string_to_lower_list(domains)):
if not domains:
return _remove_empty_listener return _remove_empty_listener
domain_callbacks = hass.data.setdefault(TRACK_STATE_ADDED_DOMAIN_CALLBACKS, {}) domain_callbacks = hass.data.setdefault(TRACK_STATE_ADDED_DOMAIN_CALLBACKS, {})
@ -472,8 +467,7 @@ def async_track_state_removed_domain(
action: Callable[[Event], Any], action: Callable[[Event], Any],
) -> Callable[[], None]: ) -> Callable[[], None]:
"""Track state change events when an entity is removed from domains.""" """Track state change events when an entity is removed from domains."""
domains = _async_string_to_lower_list(domains) if not (domains := _async_string_to_lower_list(domains)):
if not domains:
return _remove_empty_listener return _remove_empty_listener
domain_callbacks = hass.data.setdefault(TRACK_STATE_REMOVED_DOMAIN_CALLBACKS, {}) domain_callbacks = hass.data.setdefault(TRACK_STATE_REMOVED_DOMAIN_CALLBACKS, {})
@ -1185,8 +1179,7 @@ def async_track_point_in_utc_time(
# as measured by utcnow(). That is bad when callbacks have assumptions # as measured by utcnow(). That is bad when callbacks have assumptions
# about the current time. Thus, we rearm the timer for the remaining # about the current time. Thus, we rearm the timer for the remaining
# time. # time.
delta = (utc_point_in_time - now).total_seconds() if (delta := (utc_point_in_time - now).total_seconds()) > 0:
if delta > 0:
_LOGGER.debug("Called %f seconds too early, rearming", delta) _LOGGER.debug("Called %f seconds too early, rearming", delta)
cancel_callback = hass.loop.call_later(delta, run_action, job) cancel_callback = hass.loop.call_later(delta, run_action, job)
@ -1520,11 +1513,9 @@ def _rate_limit_for_event(
event: Event, info: RenderInfo, track_template_: TrackTemplate event: Event, info: RenderInfo, track_template_: TrackTemplate
) -> timedelta | None: ) -> timedelta | None:
"""Determine the rate limit for an event.""" """Determine the rate limit for an event."""
entity_id = event.data.get(ATTR_ENTITY_ID)
# Specifically referenced entities are excluded # Specifically referenced entities are excluded
# from the rate limit # from the rate limit
if entity_id in info.entities: if event.data.get(ATTR_ENTITY_ID) in info.entities:
return None return None
if track_template_.rate_limit is not None: if track_template_.rate_limit is not None:

View File

@ -366,9 +366,7 @@ async def async_process_deps_reqs(
Module is a Python module of either a component or platform. Module is a Python module of either a component or platform.
""" """
processed = hass.data.get(DATA_DEPS_REQS) if (processed := hass.data.get(DATA_DEPS_REQS)) is None:
if processed is None:
processed = hass.data[DATA_DEPS_REQS] = set() processed = hass.data[DATA_DEPS_REQS] = set()
elif integration.domain in processed: elif integration.domain in processed:
return return

View File

@ -132,8 +132,7 @@ def parse_datetime(dt_str: str) -> dt.datetime | None:
with suppress(ValueError, IndexError): with suppress(ValueError, IndexError):
return ciso8601.parse_datetime(dt_str) return ciso8601.parse_datetime(dt_str)
match = DATETIME_RE.match(dt_str) if not (match := DATETIME_RE.match(dt_str)):
if not match:
return None return None
kws: dict[str, Any] = match.groupdict() kws: dict[str, Any] = match.groupdict()
if kws["microsecond"]: if kws["microsecond"]:
@ -269,16 +268,14 @@ def find_next_time_expression_time(
Return None if no such value exists. Return None if no such value exists.
""" """
left = bisect.bisect_left(arr, cmp) if (left := bisect.bisect_left(arr, cmp)) == len(arr):
if left == len(arr):
return None return None
return arr[left] return arr[left]
result = now.replace(microsecond=0) result = now.replace(microsecond=0)
# Match next second # Match next second
next_second = _lower_bound(seconds, result.second) if (next_second := _lower_bound(seconds, result.second)) is None:
if next_second is None:
# No second to match in this minute. Roll-over to next minute. # No second to match in this minute. Roll-over to next minute.
next_second = seconds[0] next_second = seconds[0]
result += dt.timedelta(minutes=1) result += dt.timedelta(minutes=1)

View File

@ -43,8 +43,7 @@ def percentage_to_ordered_list_item(ordered_list: list[T], percentage: int) -> T
51-75: high 51-75: high
76-100: very_high 76-100: very_high
""" """
list_len = len(ordered_list) if not (list_len := len(ordered_list)):
if not list_len:
raise ValueError("The ordered list is empty") raise ValueError("The ordered list is empty")
for offset, speed in enumerate(ordered_list): for offset, speed in enumerate(ordered_list):

View File

@ -60,9 +60,7 @@ class Secrets:
def _load_secret_yaml(self, secret_dir: Path) -> dict[str, str]: def _load_secret_yaml(self, secret_dir: Path) -> dict[str, str]:
"""Load the secrets yaml from path.""" """Load the secrets yaml from path."""
secret_path = secret_dir / SECRET_YAML if (secret_path := secret_dir / SECRET_YAML) in self._cache:
if secret_path in self._cache:
return self._cache[secret_path] return self._cache[secret_path]
_LOGGER.debug("Loading %s", secret_path) _LOGGER.debug("Loading %s", secret_path)