Improve Slack notify component (#45479)

* Add typing information

* Small improvments

* Use %r for exceptions
* Added exception handlers for aiohttp.ClientError
* Added testcase

* Changes after review

* Bugfixes
This commit is contained in:
Marc Mueller 2021-01-26 01:03:12 +01:00 committed by GitHub
parent ec47df4880
commit ad677b9d41
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 182 additions and 49 deletions

View File

@ -1,7 +1,10 @@
"""Slack platform for notify component.""" """Slack platform for notify component."""
from __future__ import annotations
import asyncio import asyncio
import logging import logging
import os import os
from typing import Any, List, Optional, TypedDict
from urllib.parse import urlparse from urllib.parse import urlparse
from aiohttp import BasicAuth, FormData from aiohttp import BasicAuth, FormData
@ -21,6 +24,11 @@ from homeassistant.const import CONF_API_KEY, CONF_ICON, CONF_USERNAME
from homeassistant.core import callback from homeassistant.core import callback
from homeassistant.helpers import aiohttp_client, config_validation as cv from homeassistant.helpers import aiohttp_client, config_validation as cv
import homeassistant.helpers.template as template import homeassistant.helpers.template as template
from homeassistant.helpers.typing import (
ConfigType,
DiscoveryInfoType,
HomeAssistantType,
)
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
@ -74,7 +82,38 @@ PLATFORM_SCHEMA = PLATFORM_SCHEMA.extend(
) )
async def async_get_service(hass, config, discovery_info=None): class AuthDictT(TypedDict, total=False):
"""Type for auth request data."""
auth: BasicAuth
class FormDataT(TypedDict):
"""Type for form data, file upload."""
channels: str
filename: str
initial_comment: str
title: str
token: str
class MessageT(TypedDict, total=False):
"""Type for message data."""
link_names: bool
text: str
username: str # Optional key
icon_url: str # Optional key
icon_emoji: str # Optional key
blocks: List[Any] # Optional key
async def async_get_service(
hass: HomeAssistantType,
config: ConfigType,
discovery_info: Optional[DiscoveryInfoType] = None,
) -> Optional[SlackNotificationService]:
"""Set up the Slack notification service.""" """Set up the Slack notification service."""
session = aiohttp_client.async_get_clientsession(hass) session = aiohttp_client.async_get_clientsession(hass)
client = WebClient(token=config[CONF_API_KEY], run_async=True, session=session) client = WebClient(token=config[CONF_API_KEY], run_async=True, session=session)
@ -82,8 +121,14 @@ async def async_get_service(hass, config, discovery_info=None):
try: try:
await client.auth_test() await client.auth_test()
except SlackApiError as err: except SlackApiError as err:
_LOGGER.error("Error while setting up integration: %s", err) _LOGGER.error("Error while setting up integration: %r", err)
return return None
except ClientError as err:
_LOGGER.warning(
"Error testing connection to slack: %r "
"Continuing setup anyway, but notify service might not work",
err,
)
return SlackNotificationService( return SlackNotificationService(
hass, hass,
@ -95,20 +140,20 @@ async def async_get_service(hass, config, discovery_info=None):
@callback @callback
def _async_get_filename_from_url(url): def _async_get_filename_from_url(url: str) -> str:
"""Return the filename of a passed URL.""" """Return the filename of a passed URL."""
parsed_url = urlparse(url) parsed_url = urlparse(url)
return os.path.basename(parsed_url.path) return os.path.basename(parsed_url.path)
@callback @callback
def _async_sanitize_channel_names(channel_list): def _async_sanitize_channel_names(channel_list: List[str]) -> List[str]:
"""Remove any # symbols from a channel list.""" """Remove any # symbols from a channel list."""
return [channel.lstrip("#") for channel in channel_list] return [channel.lstrip("#") for channel in channel_list]
@callback @callback
def _async_templatize_blocks(hass, value): def _async_templatize_blocks(hass: HomeAssistantType, value: Any) -> Any:
"""Recursive template creator helper function.""" """Recursive template creator helper function."""
if isinstance(value, list): if isinstance(value, list):
return [_async_templatize_blocks(hass, item) for item in value] return [_async_templatize_blocks(hass, item) for item in value]
@ -117,14 +162,21 @@ def _async_templatize_blocks(hass, value):
key: _async_templatize_blocks(hass, item) for key, item in value.items() key: _async_templatize_blocks(hass, item) for key, item in value.items()
} }
tmpl = template.Template(value, hass=hass) tmpl = template.Template(value, hass=hass) # type: ignore # no-untyped-call
return tmpl.async_render(parse_result=False) return tmpl.async_render(parse_result=False)
class SlackNotificationService(BaseNotificationService): class SlackNotificationService(BaseNotificationService):
"""Define the Slack notification logic.""" """Define the Slack notification logic."""
def __init__(self, hass, client, default_channel, username, icon): def __init__(
self,
hass: HomeAssistantType,
client: WebClient,
default_channel: str,
username: Optional[str],
icon: Optional[str],
) -> None:
"""Initialize.""" """Initialize."""
self._client = client self._client = client
self._default_channel = default_channel self._default_channel = default_channel
@ -132,7 +184,13 @@ class SlackNotificationService(BaseNotificationService):
self._icon = icon self._icon = icon
self._username = username self._username = username
async def _async_send_local_file_message(self, path, targets, message, title): async def _async_send_local_file_message(
self,
path: str,
targets: List[str],
message: str,
title: Optional[str],
) -> None:
"""Upload a local file (with message) to Slack.""" """Upload a local file (with message) to Slack."""
if not self._hass.config.is_allowed_path(path): if not self._hass.config.is_allowed_path(path):
_LOGGER.error("Path does not exist or is not allowed: %s", path) _LOGGER.error("Path does not exist or is not allowed: %s", path)
@ -149,12 +207,19 @@ class SlackNotificationService(BaseNotificationService):
initial_comment=message, initial_comment=message,
title=title or filename, title=title or filename,
) )
except SlackApiError as err: except (SlackApiError, ClientError) as err:
_LOGGER.error("Error while uploading file-based message: %s", err) _LOGGER.error("Error while uploading file-based message: %r", err)
async def _async_send_remote_file_message( async def _async_send_remote_file_message(
self, url, targets, message, title, *, username=None, password=None self,
): url: str,
targets: List[str],
message: str,
title: Optional[str],
*,
username: Optional[str] = None,
password: Optional[str] = None,
) -> None:
"""Upload a remote file (with message) to Slack. """Upload a remote file (with message) to Slack.
Note that we bypass the python-slackclient WebClient and use aiohttp directly, Note that we bypass the python-slackclient WebClient and use aiohttp directly,
@ -166,9 +231,9 @@ class SlackNotificationService(BaseNotificationService):
return return
filename = _async_get_filename_from_url(url) filename = _async_get_filename_from_url(url)
session = aiohttp_client.async_get_clientsession(self.hass) session = aiohttp_client.async_get_clientsession(self._hass)
kwargs = {} kwargs: AuthDictT = {}
if username and password is not None: if username and password is not None:
kwargs = {"auth": BasicAuth(username, password=password)} kwargs = {"auth": BasicAuth(username, password=password)}
@ -177,49 +242,46 @@ class SlackNotificationService(BaseNotificationService):
try: try:
resp.raise_for_status() resp.raise_for_status()
except ClientError as err: except ClientError as err:
_LOGGER.error("Error while retrieving %s: %s", url, err) _LOGGER.error("Error while retrieving %s: %r", url, err)
return return
data = FormData( form_data: FormDataT = {
{ "channels": ",".join(targets),
"channels": ",".join(targets), "filename": filename,
"filename": filename, "initial_comment": message,
"initial_comment": message, "title": title or filename,
"title": title or filename, "token": self._client.token,
"token": self._client.token, }
},
charset="utf-8", data = FormData(form_data, charset="utf-8")
)
data.add_field("file", resp.content, filename=filename) data.add_field("file", resp.content, filename=filename)
try: try:
await session.post("https://slack.com/api/files.upload", data=data) await session.post("https://slack.com/api/files.upload", data=data)
except ClientError as err: except ClientError as err:
_LOGGER.error("Error while uploading file message: %s", err) _LOGGER.error("Error while uploading file message: %r", err)
async def _async_send_text_only_message( async def _async_send_text_only_message(
self, self,
targets, targets: List[str],
message, message: str,
title, title: Optional[str],
*, *,
username=None, username: Optional[str] = None,
icon=None, icon: Optional[str] = None,
blocks=None, blocks: Optional[Any] = None,
): ) -> None:
"""Send a text-only message.""" """Send a text-only message."""
message_dict = {"link_names": True, "text": message} message_dict: MessageT = {"link_names": True, "text": message}
if username: if username:
message_dict["username"] = username message_dict["username"] = username
if icon: if icon:
if icon.lower().startswith(("http://", "https://")): if icon.lower().startswith(("http://", "https://")):
icon_type = "url" message_dict["icon_url"] = icon
else: else:
icon_type = "emoji" message_dict["icon_emoji"] = icon
message_dict[f"icon_{icon_type}"] = icon
if blocks: if blocks:
message_dict["blocks"] = blocks message_dict["blocks"] = blocks
@ -233,17 +295,16 @@ class SlackNotificationService(BaseNotificationService):
for target, result in zip(tasks, results): for target, result in zip(tasks, results):
if isinstance(result, SlackApiError): if isinstance(result, SlackApiError):
_LOGGER.error( _LOGGER.error(
"There was a Slack API error while sending to %s: %s", "There was a Slack API error while sending to %s: %r",
target, target,
result, result,
) )
elif isinstance(result, ClientError):
_LOGGER.error("Error while sending message to %s: %r", target, result)
async def async_send_message(self, message, **kwargs): async def async_send_message(self, message: str, **kwargs: Any) -> None:
"""Send a message to Slack.""" """Send a message to Slack."""
data = kwargs.get(ATTR_DATA) data = kwargs.get(ATTR_DATA) or {}
if data is None:
data = {}
try: try:
DATA_SCHEMA(data) DATA_SCHEMA(data)
@ -259,7 +320,9 @@ class SlackNotificationService(BaseNotificationService):
# Message Type 1: A text-only message # Message Type 1: A text-only message
if ATTR_FILE not in data: if ATTR_FILE not in data:
if ATTR_BLOCKS_TEMPLATE in data: if ATTR_BLOCKS_TEMPLATE in data:
blocks = _async_templatize_blocks(self.hass, data[ATTR_BLOCKS_TEMPLATE]) blocks = _async_templatize_blocks(
self._hass, data[ATTR_BLOCKS_TEMPLATE]
)
elif ATTR_BLOCKS in data: elif ATTR_BLOCKS in data:
blocks = data[ATTR_BLOCKS] blocks = data[ATTR_BLOCKS]
else: else:

View File

@ -42,7 +42,7 @@ warn_redundant_casts = true
warn_unused_configs = true warn_unused_configs = true
[mypy-homeassistant.block_async_io,homeassistant.bootstrap,homeassistant.components,homeassistant.config_entries,homeassistant.config,homeassistant.const,homeassistant.core,homeassistant.data_entry_flow,homeassistant.exceptions,homeassistant.__init__,homeassistant.loader,homeassistant.__main__,homeassistant.requirements,homeassistant.runner,homeassistant.setup,homeassistant.util,homeassistant.auth.*,homeassistant.components.automation.*,homeassistant.components.binary_sensor.*,homeassistant.components.calendar.*,homeassistant.components.cover.*,homeassistant.components.device_automation.*,homeassistant.components.frontend.*,homeassistant.components.geo_location.*,homeassistant.components.group.*,homeassistant.components.history.*,homeassistant.components.http.*,homeassistant.components.huawei_lte.*,homeassistant.components.hyperion.*,homeassistant.components.image_processing.*,homeassistant.components.integration.*,homeassistant.components.light.*,homeassistant.components.lock.*,homeassistant.components.mailbox.*,homeassistant.components.media_player.*,homeassistant.components.notify.*,homeassistant.components.number.*,homeassistant.components.persistent_notification.*,homeassistant.components.proximity.*,homeassistant.components.remote.*,homeassistant.components.scene.*,homeassistant.components.sensor.*,homeassistant.components.sun.*,homeassistant.components.switch.*,homeassistant.components.systemmonitor.*,homeassistant.components.tts.*,homeassistant.components.vacuum.*,homeassistant.components.water_heater.*,homeassistant.components.weather.*,homeassistant.components.websocket_api.*,homeassistant.components.zone.*,homeassistant.components.zwave_js.*,homeassistant.helpers.*,homeassistant.scripts.*,homeassistant.util.*,tests.components.hyperion.*] [mypy-homeassistant.block_async_io,homeassistant.bootstrap,homeassistant.components,homeassistant.config_entries,homeassistant.config,homeassistant.const,homeassistant.core,homeassistant.data_entry_flow,homeassistant.exceptions,homeassistant.__init__,homeassistant.loader,homeassistant.__main__,homeassistant.requirements,homeassistant.runner,homeassistant.setup,homeassistant.util,homeassistant.auth.*,homeassistant.components.automation.*,homeassistant.components.binary_sensor.*,homeassistant.components.calendar.*,homeassistant.components.cover.*,homeassistant.components.device_automation.*,homeassistant.components.frontend.*,homeassistant.components.geo_location.*,homeassistant.components.group.*,homeassistant.components.history.*,homeassistant.components.http.*,homeassistant.components.huawei_lte.*,homeassistant.components.hyperion.*,homeassistant.components.image_processing.*,homeassistant.components.integration.*,homeassistant.components.light.*,homeassistant.components.lock.*,homeassistant.components.mailbox.*,homeassistant.components.media_player.*,homeassistant.components.notify.*,homeassistant.components.number.*,homeassistant.components.persistent_notification.*,homeassistant.components.proximity.*,homeassistant.components.remote.*,homeassistant.components.scene.*,homeassistant.components.sensor.*,homeassistant.components.slack.*,homeassistant.components.sun.*,homeassistant.components.switch.*,homeassistant.components.systemmonitor.*,homeassistant.components.tts.*,homeassistant.components.vacuum.*,homeassistant.components.water_heater.*,homeassistant.components.weather.*,homeassistant.components.websocket_api.*,homeassistant.components.zone.*,homeassistant.components.zwave_js.*,homeassistant.helpers.*,homeassistant.scripts.*,homeassistant.util.*,tests.components.hyperion.*]
strict = true strict = true
ignore_errors = false ignore_errors = false
warn_unreachable = true warn_unreachable = true

View File

@ -1,7 +1,77 @@
"""Test slack notifications.""" """Test slack notifications."""
from unittest.mock import AsyncMock, Mock import logging
from unittest.mock import AsyncMock, Mock, patch
from homeassistant.components.slack.notify import SlackNotificationService from _pytest.logging import LogCaptureFixture
import aiohttp
from slack.errors import SlackApiError
from homeassistant.components.slack.notify import (
CONF_DEFAULT_CHANNEL,
SlackNotificationService,
async_get_service,
)
from homeassistant.const import CONF_API_KEY, CONF_ICON, CONF_USERNAME
from homeassistant.helpers.typing import HomeAssistantType
MODULE_PATH = "homeassistant.components.slack.notify"
async def test_get_service(hass: HomeAssistantType, caplog: LogCaptureFixture):
"""Test async_get_service with exceptions."""
config = {
CONF_API_KEY: "12345",
CONF_DEFAULT_CHANNEL: "channel",
}
with patch(MODULE_PATH + ".aiohttp_client") as mock_session, patch(
MODULE_PATH + ".WebClient"
) as mock_client, patch(
MODULE_PATH + ".SlackNotificationService"
) as mock_slack_service:
mock_session.async_get_clientsession.return_value = session = Mock()
mock_client.return_value = client = AsyncMock()
# Normal setup
mock_slack_service.return_value = service = Mock()
assert await async_get_service(hass, config) == service
mock_slack_service.assert_called_once_with(
hass, client, "channel", username=None, icon=None
)
mock_client.assert_called_with(token="12345", run_async=True, session=session)
client.auth_test.assert_called_once_with()
mock_slack_service.assert_called_once_with(
hass, client, "channel", username=None, icon=None
)
mock_slack_service.reset_mock()
# aiohttp.ClientError
config.update({CONF_USERNAME: "user", CONF_ICON: "icon"})
mock_slack_service.reset_mock()
mock_slack_service.return_value = service = Mock()
client.auth_test.side_effect = [aiohttp.ClientError]
assert await async_get_service(hass, config) == service
assert len(caplog.records) == 1
record = caplog.records[0]
assert record.levelno == logging.WARNING
assert aiohttp.ClientError.__qualname__ in record.message
caplog.records.clear()
mock_slack_service.assert_called_once_with(
hass, client, "channel", username="user", icon="icon"
)
mock_slack_service.reset_mock()
# SlackApiError
err, level = SlackApiError("msg", "resp"), logging.ERROR
client.auth_test.side_effect = [err]
assert await async_get_service(hass, config) is None
assert len(caplog.records) == 1
record = caplog.records[0]
assert record.levelno == level
assert err.__class__.__qualname__ in record.message
caplog.records.clear()
mock_slack_service.assert_not_called()
mock_slack_service.reset_mock()
async def test_message_includes_default_emoji(): async def test_message_includes_default_emoji():