From 051531d9c1230c44000686d9f5c61837171f0de8 Mon Sep 17 00:00:00 2001 From: Paulus Schoutsen Date: Thu, 1 Apr 2021 16:22:08 -0700 Subject: [PATCH] Clean up mobile app (#48607) Co-authored-by: Martin Hjelmare --- .../components/mobile_app/binary_sensor.py | 6 ++-- homeassistant/components/mobile_app/entity.py | 12 +++----- homeassistant/components/mobile_app/notify.py | 11 +++---- homeassistant/components/mobile_app/sensor.py | 6 ++-- .../components/mobile_app/webhook.py | 8 ++--- homeassistant/util/logging.py | 6 +++- tests/util/test_logging.py | 29 +++++++++++++++++++ 7 files changed, 51 insertions(+), 27 deletions(-) diff --git a/homeassistant/components/mobile_app/binary_sensor.py b/homeassistant/components/mobile_app/binary_sensor.py index 36897dd9f69..616cd97a775 100644 --- a/homeassistant/components/mobile_app/binary_sensor.py +++ b/homeassistant/components/mobile_app/binary_sensor.py @@ -1,6 +1,4 @@ """Binary sensor platform for mobile_app.""" -from functools import partial - from homeassistant.components.binary_sensor import BinarySensorEntity from homeassistant.const import CONF_NAME, CONF_UNIQUE_ID, CONF_WEBHOOK_ID, STATE_ON from homeassistant.core import callback @@ -48,7 +46,7 @@ async def async_setup_entry(hass, config_entry, async_add_entities): async_add_entities(entities) @callback - def handle_sensor_registration(webhook_id, data): + def handle_sensor_registration(data): if data[CONF_WEBHOOK_ID] != webhook_id: return @@ -66,7 +64,7 @@ async def async_setup_entry(hass, config_entry, async_add_entities): async_dispatcher_connect( hass, f"{DOMAIN}_{ENTITY_TYPE}_register", - partial(handle_sensor_registration, webhook_id), + handle_sensor_registration, ) diff --git a/homeassistant/components/mobile_app/entity.py b/homeassistant/components/mobile_app/entity.py index 2f30c4b9f1b..46f4589fa2c 100644 --- a/homeassistant/components/mobile_app/entity.py +++ b/homeassistant/components/mobile_app/entity.py @@ -34,13 +34,14 @@ class MobileAppEntity(RestoreEntity): self._registration = entry.data self._unique_id = config[CONF_UNIQUE_ID] self._entity_type = config[ATTR_SENSOR_TYPE] - self.unsub_dispatcher = None self._name = config[CONF_NAME] async def async_added_to_hass(self): """Register callbacks.""" - self.unsub_dispatcher = async_dispatcher_connect( - self.hass, SIGNAL_SENSOR_UPDATE, self._handle_update + self.async_on_remove( + async_dispatcher_connect( + self.hass, SIGNAL_SENSOR_UPDATE, self._handle_update + ) ) state = await self.async_get_last_state() @@ -49,11 +50,6 @@ class MobileAppEntity(RestoreEntity): self.async_restore_last_state(state) - async def async_will_remove_from_hass(self): - """Disconnect dispatcher listener when removed.""" - if self.unsub_dispatcher is not None: - self.unsub_dispatcher() - @callback def async_restore_last_state(self, last_state): """Restore previous state.""" diff --git a/homeassistant/components/mobile_app/notify.py b/homeassistant/components/mobile_app/notify.py index 763186df998..803f00764e7 100644 --- a/homeassistant/components/mobile_app/notify.py +++ b/homeassistant/components/mobile_app/notify.py @@ -84,17 +84,16 @@ def log_rate_limits(hass, device_name, resp, level=logging.INFO): async def async_get_service(hass, config, discovery_info=None): """Get the mobile_app notification service.""" - session = async_get_clientsession(hass) - service = hass.data[DOMAIN][DATA_NOTIFY] = MobileAppNotificationService(session) + service = hass.data[DOMAIN][DATA_NOTIFY] = MobileAppNotificationService(hass) return service class MobileAppNotificationService(BaseNotificationService): """Implement the notification service for mobile_app.""" - def __init__(self, session): + def __init__(self, hass): """Initialize the service.""" - self._session = session + self._hass = hass @property def targets(self): @@ -141,7 +140,9 @@ class MobileAppNotificationService(BaseNotificationService): try: with async_timeout.timeout(10): - response = await self._session.post(push_url, json=data) + response = await async_get_clientsession(self._hass).post( + push_url, json=data + ) result = await response.json() if response.status in [HTTP_OK, HTTP_CREATED, HTTP_ACCEPTED]: diff --git a/homeassistant/components/mobile_app/sensor.py b/homeassistant/components/mobile_app/sensor.py index 3f4c7d56f3f..7e3c1c13148 100644 --- a/homeassistant/components/mobile_app/sensor.py +++ b/homeassistant/components/mobile_app/sensor.py @@ -1,6 +1,4 @@ """Sensor platform for mobile_app.""" -from functools import partial - from homeassistant.components.sensor import SensorEntity from homeassistant.const import CONF_NAME, CONF_UNIQUE_ID, CONF_WEBHOOK_ID from homeassistant.core import callback @@ -50,7 +48,7 @@ async def async_setup_entry(hass, config_entry, async_add_entities): async_add_entities(entities) @callback - def handle_sensor_registration(webhook_id, data): + def handle_sensor_registration(data): if data[CONF_WEBHOOK_ID] != webhook_id: return @@ -68,7 +66,7 @@ async def async_setup_entry(hass, config_entry, async_add_entities): async_dispatcher_connect( hass, f"{DOMAIN}_{ENTITY_TYPE}_register", - partial(handle_sensor_registration, webhook_id), + handle_sensor_registration, ) diff --git a/homeassistant/components/mobile_app/webhook.py b/homeassistant/components/mobile_app/webhook.py index efef6eb1c8a..6be39f34f00 100644 --- a/homeassistant/components/mobile_app/webhook.py +++ b/homeassistant/components/mobile_app/webhook.py @@ -472,6 +472,7 @@ async def webhook_update_sensor_states(hass, config_entry, data): device_name = config_entry.data[ATTR_DEVICE_NAME] resp = {} + for sensor in data: entity_type = sensor[ATTR_SENSOR_TYPE] @@ -495,8 +496,6 @@ async def webhook_update_sensor_states(hass, config_entry, data): } continue - entry = {CONF_WEBHOOK_ID: config_entry.data[CONF_WEBHOOK_ID]} - try: sensor = sensor_schema_full(sensor) except vol.Invalid as err: @@ -513,9 +512,8 @@ async def webhook_update_sensor_states(hass, config_entry, data): } continue - new_state = {**entry, **sensor} - - async_dispatcher_send(hass, SIGNAL_SENSOR_UPDATE, new_state) + sensor[CONF_WEBHOOK_ID] = config_entry.data[CONF_WEBHOOK_ID] + async_dispatcher_send(hass, SIGNAL_SENSOR_UPDATE, sensor) resp[unique_id] = {"success": True} diff --git a/homeassistant/util/logging.py b/homeassistant/util/logging.py index 5653523b677..ba846c0e8b4 100644 --- a/homeassistant/util/logging.py +++ b/homeassistant/util/logging.py @@ -11,7 +11,7 @@ import traceback from typing import Any, Awaitable, Callable, Coroutine, cast, overload from homeassistant.const import EVENT_HOMEASSISTANT_CLOSE -from homeassistant.core import HomeAssistant, callback +from homeassistant.core import HomeAssistant, callback, is_callback class HideSensitiveDataFilter(logging.Filter): @@ -138,6 +138,7 @@ def catch_log_exception( log_exception(format_err, *args) wrapper_func = async_wrapper + else: @wraps(func) @@ -148,6 +149,9 @@ def catch_log_exception( except Exception: # pylint: disable=broad-except log_exception(format_err, *args) + if is_callback(check_func): + wrapper = callback(wrapper) + wrapper_func = wrapper return wrapper_func diff --git a/tests/util/test_logging.py b/tests/util/test_logging.py index 1a82c35e82d..9277d92f368 100644 --- a/tests/util/test_logging.py +++ b/tests/util/test_logging.py @@ -1,11 +1,13 @@ """Test Home Assistant logging util methods.""" import asyncio +from functools import partial import logging import queue from unittest.mock import patch import pytest +from homeassistant.core import callback, is_callback import homeassistant.util.logging as logging_util @@ -80,3 +82,30 @@ async def test_async_create_catching_coro(hass, caplog): await hass.async_block_till_done() assert "This is a bad coroutine" in caplog.text assert "in test_async_create_catching_coro" in caplog.text + + +def test_catch_log_exception(): + """Test it is still a callback after wrapping including partial.""" + + async def async_meth(): + pass + + assert asyncio.iscoroutinefunction( + logging_util.catch_log_exception(partial(async_meth), lambda: None) + ) + + @callback + def callback_meth(): + pass + + assert is_callback( + logging_util.catch_log_exception(partial(callback_meth), lambda: None) + ) + + def sync_meth(): + pass + + wrapped = logging_util.catch_log_exception(partial(sync_meth), lambda: None) + + assert not is_callback(wrapped) + assert not asyncio.iscoroutinefunction(wrapped)