From ad677b9d417a054263184a3f3dd947a02a9ed28e Mon Sep 17 00:00:00 2001 From: Marc Mueller <30130371+cdce8p@users.noreply.github.com> Date: Tue, 26 Jan 2021 01:03:12 +0100 Subject: [PATCH] Improve Slack notify component (#45479) * Add typing information * Small improvments * Use %r for exceptions * Added exception handlers for aiohttp.ClientError * Added testcase * Changes after review * Bugfixes --- homeassistant/components/slack/notify.py | 155 ++++++++++++++++------- setup.cfg | 2 +- tests/components/slack/test_notify.py | 74 ++++++++++- 3 files changed, 182 insertions(+), 49 deletions(-) diff --git a/homeassistant/components/slack/notify.py b/homeassistant/components/slack/notify.py index 90caad62a58..985f59a6715 100644 --- a/homeassistant/components/slack/notify.py +++ b/homeassistant/components/slack/notify.py @@ -1,7 +1,10 @@ """Slack platform for notify component.""" +from __future__ import annotations + import asyncio import logging import os +from typing import Any, List, Optional, TypedDict from urllib.parse import urlparse from aiohttp import BasicAuth, FormData @@ -21,6 +24,11 @@ from homeassistant.const import CONF_API_KEY, CONF_ICON, CONF_USERNAME from homeassistant.core import callback from homeassistant.helpers import aiohttp_client, config_validation as cv import homeassistant.helpers.template as template +from homeassistant.helpers.typing import ( + ConfigType, + DiscoveryInfoType, + HomeAssistantType, +) _LOGGER = logging.getLogger(__name__) @@ -74,7 +82,38 @@ PLATFORM_SCHEMA = PLATFORM_SCHEMA.extend( ) -async def async_get_service(hass, config, discovery_info=None): +class AuthDictT(TypedDict, total=False): + """Type for auth request data.""" + + auth: BasicAuth + + +class FormDataT(TypedDict): + """Type for form data, file upload.""" + + channels: str + filename: str + initial_comment: str + title: str + token: str + + +class MessageT(TypedDict, total=False): + """Type for message data.""" + + link_names: bool + text: str + username: str # Optional key + icon_url: str # Optional key + icon_emoji: str # Optional key + blocks: List[Any] # Optional key + + +async def async_get_service( + hass: HomeAssistantType, + config: ConfigType, + discovery_info: Optional[DiscoveryInfoType] = None, +) -> Optional[SlackNotificationService]: """Set up the Slack notification service.""" session = aiohttp_client.async_get_clientsession(hass) client = WebClient(token=config[CONF_API_KEY], run_async=True, session=session) @@ -82,8 +121,14 @@ async def async_get_service(hass, config, discovery_info=None): try: await client.auth_test() except SlackApiError as err: - _LOGGER.error("Error while setting up integration: %s", err) - return + _LOGGER.error("Error while setting up integration: %r", err) + return None + except ClientError as err: + _LOGGER.warning( + "Error testing connection to slack: %r " + "Continuing setup anyway, but notify service might not work", + err, + ) return SlackNotificationService( hass, @@ -95,20 +140,20 @@ async def async_get_service(hass, config, discovery_info=None): @callback -def _async_get_filename_from_url(url): +def _async_get_filename_from_url(url: str) -> str: """Return the filename of a passed URL.""" parsed_url = urlparse(url) return os.path.basename(parsed_url.path) @callback -def _async_sanitize_channel_names(channel_list): +def _async_sanitize_channel_names(channel_list: List[str]) -> List[str]: """Remove any # symbols from a channel list.""" return [channel.lstrip("#") for channel in channel_list] @callback -def _async_templatize_blocks(hass, value): +def _async_templatize_blocks(hass: HomeAssistantType, value: Any) -> Any: """Recursive template creator helper function.""" if isinstance(value, list): return [_async_templatize_blocks(hass, item) for item in value] @@ -117,14 +162,21 @@ def _async_templatize_blocks(hass, value): key: _async_templatize_blocks(hass, item) for key, item in value.items() } - tmpl = template.Template(value, hass=hass) + tmpl = template.Template(value, hass=hass) # type: ignore # no-untyped-call return tmpl.async_render(parse_result=False) class SlackNotificationService(BaseNotificationService): """Define the Slack notification logic.""" - def __init__(self, hass, client, default_channel, username, icon): + def __init__( + self, + hass: HomeAssistantType, + client: WebClient, + default_channel: str, + username: Optional[str], + icon: Optional[str], + ) -> None: """Initialize.""" self._client = client self._default_channel = default_channel @@ -132,7 +184,13 @@ class SlackNotificationService(BaseNotificationService): self._icon = icon self._username = username - async def _async_send_local_file_message(self, path, targets, message, title): + async def _async_send_local_file_message( + self, + path: str, + targets: List[str], + message: str, + title: Optional[str], + ) -> None: """Upload a local file (with message) to Slack.""" if not self._hass.config.is_allowed_path(path): _LOGGER.error("Path does not exist or is not allowed: %s", path) @@ -149,12 +207,19 @@ class SlackNotificationService(BaseNotificationService): initial_comment=message, title=title or filename, ) - except SlackApiError as err: - _LOGGER.error("Error while uploading file-based message: %s", err) + except (SlackApiError, ClientError) as err: + _LOGGER.error("Error while uploading file-based message: %r", err) async def _async_send_remote_file_message( - self, url, targets, message, title, *, username=None, password=None - ): + self, + url: str, + targets: List[str], + message: str, + title: Optional[str], + *, + username: Optional[str] = None, + password: Optional[str] = None, + ) -> None: """Upload a remote file (with message) to Slack. Note that we bypass the python-slackclient WebClient and use aiohttp directly, @@ -166,9 +231,9 @@ class SlackNotificationService(BaseNotificationService): return filename = _async_get_filename_from_url(url) - session = aiohttp_client.async_get_clientsession(self.hass) + session = aiohttp_client.async_get_clientsession(self._hass) - kwargs = {} + kwargs: AuthDictT = {} if username and password is not None: kwargs = {"auth": BasicAuth(username, password=password)} @@ -177,49 +242,46 @@ class SlackNotificationService(BaseNotificationService): try: resp.raise_for_status() except ClientError as err: - _LOGGER.error("Error while retrieving %s: %s", url, err) + _LOGGER.error("Error while retrieving %s: %r", url, err) return - data = FormData( - { - "channels": ",".join(targets), - "filename": filename, - "initial_comment": message, - "title": title or filename, - "token": self._client.token, - }, - charset="utf-8", - ) + form_data: FormDataT = { + "channels": ",".join(targets), + "filename": filename, + "initial_comment": message, + "title": title or filename, + "token": self._client.token, + } + + data = FormData(form_data, charset="utf-8") data.add_field("file", resp.content, filename=filename) try: await session.post("https://slack.com/api/files.upload", data=data) except ClientError as err: - _LOGGER.error("Error while uploading file message: %s", err) + _LOGGER.error("Error while uploading file message: %r", err) async def _async_send_text_only_message( self, - targets, - message, - title, + targets: List[str], + message: str, + title: Optional[str], *, - username=None, - icon=None, - blocks=None, - ): + username: Optional[str] = None, + icon: Optional[str] = None, + blocks: Optional[Any] = None, + ) -> None: """Send a text-only message.""" - message_dict = {"link_names": True, "text": message} + message_dict: MessageT = {"link_names": True, "text": message} if username: message_dict["username"] = username if icon: if icon.lower().startswith(("http://", "https://")): - icon_type = "url" + message_dict["icon_url"] = icon else: - icon_type = "emoji" - - message_dict[f"icon_{icon_type}"] = icon + message_dict["icon_emoji"] = icon if blocks: message_dict["blocks"] = blocks @@ -233,17 +295,16 @@ class SlackNotificationService(BaseNotificationService): for target, result in zip(tasks, results): if isinstance(result, SlackApiError): _LOGGER.error( - "There was a Slack API error while sending to %s: %s", + "There was a Slack API error while sending to %s: %r", target, result, ) + elif isinstance(result, ClientError): + _LOGGER.error("Error while sending message to %s: %r", target, result) - async def async_send_message(self, message, **kwargs): + async def async_send_message(self, message: str, **kwargs: Any) -> None: """Send a message to Slack.""" - data = kwargs.get(ATTR_DATA) - - if data is None: - data = {} + data = kwargs.get(ATTR_DATA) or {} try: DATA_SCHEMA(data) @@ -259,7 +320,9 @@ class SlackNotificationService(BaseNotificationService): # Message Type 1: A text-only message if ATTR_FILE not in data: if ATTR_BLOCKS_TEMPLATE in data: - blocks = _async_templatize_blocks(self.hass, data[ATTR_BLOCKS_TEMPLATE]) + blocks = _async_templatize_blocks( + self._hass, data[ATTR_BLOCKS_TEMPLATE] + ) elif ATTR_BLOCKS in data: blocks = data[ATTR_BLOCKS] else: diff --git a/setup.cfg b/setup.cfg index 4137554257f..7761ff2d67e 100644 --- a/setup.cfg +++ b/setup.cfg @@ -42,7 +42,7 @@ warn_redundant_casts = true warn_unused_configs = true -[mypy-homeassistant.block_async_io,homeassistant.bootstrap,homeassistant.components,homeassistant.config_entries,homeassistant.config,homeassistant.const,homeassistant.core,homeassistant.data_entry_flow,homeassistant.exceptions,homeassistant.__init__,homeassistant.loader,homeassistant.__main__,homeassistant.requirements,homeassistant.runner,homeassistant.setup,homeassistant.util,homeassistant.auth.*,homeassistant.components.automation.*,homeassistant.components.binary_sensor.*,homeassistant.components.calendar.*,homeassistant.components.cover.*,homeassistant.components.device_automation.*,homeassistant.components.frontend.*,homeassistant.components.geo_location.*,homeassistant.components.group.*,homeassistant.components.history.*,homeassistant.components.http.*,homeassistant.components.huawei_lte.*,homeassistant.components.hyperion.*,homeassistant.components.image_processing.*,homeassistant.components.integration.*,homeassistant.components.light.*,homeassistant.components.lock.*,homeassistant.components.mailbox.*,homeassistant.components.media_player.*,homeassistant.components.notify.*,homeassistant.components.number.*,homeassistant.components.persistent_notification.*,homeassistant.components.proximity.*,homeassistant.components.remote.*,homeassistant.components.scene.*,homeassistant.components.sensor.*,homeassistant.components.sun.*,homeassistant.components.switch.*,homeassistant.components.systemmonitor.*,homeassistant.components.tts.*,homeassistant.components.vacuum.*,homeassistant.components.water_heater.*,homeassistant.components.weather.*,homeassistant.components.websocket_api.*,homeassistant.components.zone.*,homeassistant.components.zwave_js.*,homeassistant.helpers.*,homeassistant.scripts.*,homeassistant.util.*,tests.components.hyperion.*] +[mypy-homeassistant.block_async_io,homeassistant.bootstrap,homeassistant.components,homeassistant.config_entries,homeassistant.config,homeassistant.const,homeassistant.core,homeassistant.data_entry_flow,homeassistant.exceptions,homeassistant.__init__,homeassistant.loader,homeassistant.__main__,homeassistant.requirements,homeassistant.runner,homeassistant.setup,homeassistant.util,homeassistant.auth.*,homeassistant.components.automation.*,homeassistant.components.binary_sensor.*,homeassistant.components.calendar.*,homeassistant.components.cover.*,homeassistant.components.device_automation.*,homeassistant.components.frontend.*,homeassistant.components.geo_location.*,homeassistant.components.group.*,homeassistant.components.history.*,homeassistant.components.http.*,homeassistant.components.huawei_lte.*,homeassistant.components.hyperion.*,homeassistant.components.image_processing.*,homeassistant.components.integration.*,homeassistant.components.light.*,homeassistant.components.lock.*,homeassistant.components.mailbox.*,homeassistant.components.media_player.*,homeassistant.components.notify.*,homeassistant.components.number.*,homeassistant.components.persistent_notification.*,homeassistant.components.proximity.*,homeassistant.components.remote.*,homeassistant.components.scene.*,homeassistant.components.sensor.*,homeassistant.components.slack.*,homeassistant.components.sun.*,homeassistant.components.switch.*,homeassistant.components.systemmonitor.*,homeassistant.components.tts.*,homeassistant.components.vacuum.*,homeassistant.components.water_heater.*,homeassistant.components.weather.*,homeassistant.components.websocket_api.*,homeassistant.components.zone.*,homeassistant.components.zwave_js.*,homeassistant.helpers.*,homeassistant.scripts.*,homeassistant.util.*,tests.components.hyperion.*] strict = true ignore_errors = false warn_unreachable = true diff --git a/tests/components/slack/test_notify.py b/tests/components/slack/test_notify.py index 7f974b557fe..af304365158 100644 --- a/tests/components/slack/test_notify.py +++ b/tests/components/slack/test_notify.py @@ -1,7 +1,77 @@ """Test slack notifications.""" -from unittest.mock import AsyncMock, Mock +import logging +from unittest.mock import AsyncMock, Mock, patch -from homeassistant.components.slack.notify import SlackNotificationService +from _pytest.logging import LogCaptureFixture +import aiohttp +from slack.errors import SlackApiError + +from homeassistant.components.slack.notify import ( + CONF_DEFAULT_CHANNEL, + SlackNotificationService, + async_get_service, +) +from homeassistant.const import CONF_API_KEY, CONF_ICON, CONF_USERNAME +from homeassistant.helpers.typing import HomeAssistantType + +MODULE_PATH = "homeassistant.components.slack.notify" + + +async def test_get_service(hass: HomeAssistantType, caplog: LogCaptureFixture): + """Test async_get_service with exceptions.""" + config = { + CONF_API_KEY: "12345", + CONF_DEFAULT_CHANNEL: "channel", + } + + with patch(MODULE_PATH + ".aiohttp_client") as mock_session, patch( + MODULE_PATH + ".WebClient" + ) as mock_client, patch( + MODULE_PATH + ".SlackNotificationService" + ) as mock_slack_service: + mock_session.async_get_clientsession.return_value = session = Mock() + mock_client.return_value = client = AsyncMock() + + # Normal setup + mock_slack_service.return_value = service = Mock() + assert await async_get_service(hass, config) == service + mock_slack_service.assert_called_once_with( + hass, client, "channel", username=None, icon=None + ) + mock_client.assert_called_with(token="12345", run_async=True, session=session) + client.auth_test.assert_called_once_with() + mock_slack_service.assert_called_once_with( + hass, client, "channel", username=None, icon=None + ) + mock_slack_service.reset_mock() + + # aiohttp.ClientError + config.update({CONF_USERNAME: "user", CONF_ICON: "icon"}) + mock_slack_service.reset_mock() + mock_slack_service.return_value = service = Mock() + client.auth_test.side_effect = [aiohttp.ClientError] + assert await async_get_service(hass, config) == service + assert len(caplog.records) == 1 + record = caplog.records[0] + assert record.levelno == logging.WARNING + assert aiohttp.ClientError.__qualname__ in record.message + caplog.records.clear() + mock_slack_service.assert_called_once_with( + hass, client, "channel", username="user", icon="icon" + ) + mock_slack_service.reset_mock() + + # SlackApiError + err, level = SlackApiError("msg", "resp"), logging.ERROR + client.auth_test.side_effect = [err] + assert await async_get_service(hass, config) is None + assert len(caplog.records) == 1 + record = caplog.records[0] + assert record.levelno == level + assert err.__class__.__qualname__ in record.message + caplog.records.clear() + mock_slack_service.assert_not_called() + mock_slack_service.reset_mock() async def test_message_includes_default_emoji():