mirror of
https://github.com/home-assistant/core.git
synced 2025-04-22 16:27:56 +00:00
Improve Slack notify component (#45479)
* Add typing information * Small improvments * Use %r for exceptions * Added exception handlers for aiohttp.ClientError * Added testcase * Changes after review * Bugfixes
This commit is contained in:
parent
ec47df4880
commit
ad677b9d41
@ -1,7 +1,10 @@
|
||||
"""Slack platform for notify component."""
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
from typing import Any, List, Optional, TypedDict
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from aiohttp import BasicAuth, FormData
|
||||
@ -21,6 +24,11 @@ from homeassistant.const import CONF_API_KEY, CONF_ICON, CONF_USERNAME
|
||||
from homeassistant.core import callback
|
||||
from homeassistant.helpers import aiohttp_client, config_validation as cv
|
||||
import homeassistant.helpers.template as template
|
||||
from homeassistant.helpers.typing import (
|
||||
ConfigType,
|
||||
DiscoveryInfoType,
|
||||
HomeAssistantType,
|
||||
)
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
@ -74,7 +82,38 @@ PLATFORM_SCHEMA = PLATFORM_SCHEMA.extend(
|
||||
)
|
||||
|
||||
|
||||
async def async_get_service(hass, config, discovery_info=None):
|
||||
class AuthDictT(TypedDict, total=False):
|
||||
"""Type for auth request data."""
|
||||
|
||||
auth: BasicAuth
|
||||
|
||||
|
||||
class FormDataT(TypedDict):
|
||||
"""Type for form data, file upload."""
|
||||
|
||||
channels: str
|
||||
filename: str
|
||||
initial_comment: str
|
||||
title: str
|
||||
token: str
|
||||
|
||||
|
||||
class MessageT(TypedDict, total=False):
|
||||
"""Type for message data."""
|
||||
|
||||
link_names: bool
|
||||
text: str
|
||||
username: str # Optional key
|
||||
icon_url: str # Optional key
|
||||
icon_emoji: str # Optional key
|
||||
blocks: List[Any] # Optional key
|
||||
|
||||
|
||||
async def async_get_service(
|
||||
hass: HomeAssistantType,
|
||||
config: ConfigType,
|
||||
discovery_info: Optional[DiscoveryInfoType] = None,
|
||||
) -> Optional[SlackNotificationService]:
|
||||
"""Set up the Slack notification service."""
|
||||
session = aiohttp_client.async_get_clientsession(hass)
|
||||
client = WebClient(token=config[CONF_API_KEY], run_async=True, session=session)
|
||||
@ -82,8 +121,14 @@ async def async_get_service(hass, config, discovery_info=None):
|
||||
try:
|
||||
await client.auth_test()
|
||||
except SlackApiError as err:
|
||||
_LOGGER.error("Error while setting up integration: %s", err)
|
||||
return
|
||||
_LOGGER.error("Error while setting up integration: %r", err)
|
||||
return None
|
||||
except ClientError as err:
|
||||
_LOGGER.warning(
|
||||
"Error testing connection to slack: %r "
|
||||
"Continuing setup anyway, but notify service might not work",
|
||||
err,
|
||||
)
|
||||
|
||||
return SlackNotificationService(
|
||||
hass,
|
||||
@ -95,20 +140,20 @@ async def async_get_service(hass, config, discovery_info=None):
|
||||
|
||||
|
||||
@callback
|
||||
def _async_get_filename_from_url(url):
|
||||
def _async_get_filename_from_url(url: str) -> str:
|
||||
"""Return the filename of a passed URL."""
|
||||
parsed_url = urlparse(url)
|
||||
return os.path.basename(parsed_url.path)
|
||||
|
||||
|
||||
@callback
|
||||
def _async_sanitize_channel_names(channel_list):
|
||||
def _async_sanitize_channel_names(channel_list: List[str]) -> List[str]:
|
||||
"""Remove any # symbols from a channel list."""
|
||||
return [channel.lstrip("#") for channel in channel_list]
|
||||
|
||||
|
||||
@callback
|
||||
def _async_templatize_blocks(hass, value):
|
||||
def _async_templatize_blocks(hass: HomeAssistantType, value: Any) -> Any:
|
||||
"""Recursive template creator helper function."""
|
||||
if isinstance(value, list):
|
||||
return [_async_templatize_blocks(hass, item) for item in value]
|
||||
@ -117,14 +162,21 @@ def _async_templatize_blocks(hass, value):
|
||||
key: _async_templatize_blocks(hass, item) for key, item in value.items()
|
||||
}
|
||||
|
||||
tmpl = template.Template(value, hass=hass)
|
||||
tmpl = template.Template(value, hass=hass) # type: ignore # no-untyped-call
|
||||
return tmpl.async_render(parse_result=False)
|
||||
|
||||
|
||||
class SlackNotificationService(BaseNotificationService):
|
||||
"""Define the Slack notification logic."""
|
||||
|
||||
def __init__(self, hass, client, default_channel, username, icon):
|
||||
def __init__(
|
||||
self,
|
||||
hass: HomeAssistantType,
|
||||
client: WebClient,
|
||||
default_channel: str,
|
||||
username: Optional[str],
|
||||
icon: Optional[str],
|
||||
) -> None:
|
||||
"""Initialize."""
|
||||
self._client = client
|
||||
self._default_channel = default_channel
|
||||
@ -132,7 +184,13 @@ class SlackNotificationService(BaseNotificationService):
|
||||
self._icon = icon
|
||||
self._username = username
|
||||
|
||||
async def _async_send_local_file_message(self, path, targets, message, title):
|
||||
async def _async_send_local_file_message(
|
||||
self,
|
||||
path: str,
|
||||
targets: List[str],
|
||||
message: str,
|
||||
title: Optional[str],
|
||||
) -> None:
|
||||
"""Upload a local file (with message) to Slack."""
|
||||
if not self._hass.config.is_allowed_path(path):
|
||||
_LOGGER.error("Path does not exist or is not allowed: %s", path)
|
||||
@ -149,12 +207,19 @@ class SlackNotificationService(BaseNotificationService):
|
||||
initial_comment=message,
|
||||
title=title or filename,
|
||||
)
|
||||
except SlackApiError as err:
|
||||
_LOGGER.error("Error while uploading file-based message: %s", err)
|
||||
except (SlackApiError, ClientError) as err:
|
||||
_LOGGER.error("Error while uploading file-based message: %r", err)
|
||||
|
||||
async def _async_send_remote_file_message(
|
||||
self, url, targets, message, title, *, username=None, password=None
|
||||
):
|
||||
self,
|
||||
url: str,
|
||||
targets: List[str],
|
||||
message: str,
|
||||
title: Optional[str],
|
||||
*,
|
||||
username: Optional[str] = None,
|
||||
password: Optional[str] = None,
|
||||
) -> None:
|
||||
"""Upload a remote file (with message) to Slack.
|
||||
|
||||
Note that we bypass the python-slackclient WebClient and use aiohttp directly,
|
||||
@ -166,9 +231,9 @@ class SlackNotificationService(BaseNotificationService):
|
||||
return
|
||||
|
||||
filename = _async_get_filename_from_url(url)
|
||||
session = aiohttp_client.async_get_clientsession(self.hass)
|
||||
session = aiohttp_client.async_get_clientsession(self._hass)
|
||||
|
||||
kwargs = {}
|
||||
kwargs: AuthDictT = {}
|
||||
if username and password is not None:
|
||||
kwargs = {"auth": BasicAuth(username, password=password)}
|
||||
|
||||
@ -177,49 +242,46 @@ class SlackNotificationService(BaseNotificationService):
|
||||
try:
|
||||
resp.raise_for_status()
|
||||
except ClientError as err:
|
||||
_LOGGER.error("Error while retrieving %s: %s", url, err)
|
||||
_LOGGER.error("Error while retrieving %s: %r", url, err)
|
||||
return
|
||||
|
||||
data = FormData(
|
||||
{
|
||||
"channels": ",".join(targets),
|
||||
"filename": filename,
|
||||
"initial_comment": message,
|
||||
"title": title or filename,
|
||||
"token": self._client.token,
|
||||
},
|
||||
charset="utf-8",
|
||||
)
|
||||
form_data: FormDataT = {
|
||||
"channels": ",".join(targets),
|
||||
"filename": filename,
|
||||
"initial_comment": message,
|
||||
"title": title or filename,
|
||||
"token": self._client.token,
|
||||
}
|
||||
|
||||
data = FormData(form_data, charset="utf-8")
|
||||
data.add_field("file", resp.content, filename=filename)
|
||||
|
||||
try:
|
||||
await session.post("https://slack.com/api/files.upload", data=data)
|
||||
except ClientError as err:
|
||||
_LOGGER.error("Error while uploading file message: %s", err)
|
||||
_LOGGER.error("Error while uploading file message: %r", err)
|
||||
|
||||
async def _async_send_text_only_message(
|
||||
self,
|
||||
targets,
|
||||
message,
|
||||
title,
|
||||
targets: List[str],
|
||||
message: str,
|
||||
title: Optional[str],
|
||||
*,
|
||||
username=None,
|
||||
icon=None,
|
||||
blocks=None,
|
||||
):
|
||||
username: Optional[str] = None,
|
||||
icon: Optional[str] = None,
|
||||
blocks: Optional[Any] = None,
|
||||
) -> None:
|
||||
"""Send a text-only message."""
|
||||
message_dict = {"link_names": True, "text": message}
|
||||
message_dict: MessageT = {"link_names": True, "text": message}
|
||||
|
||||
if username:
|
||||
message_dict["username"] = username
|
||||
|
||||
if icon:
|
||||
if icon.lower().startswith(("http://", "https://")):
|
||||
icon_type = "url"
|
||||
message_dict["icon_url"] = icon
|
||||
else:
|
||||
icon_type = "emoji"
|
||||
|
||||
message_dict[f"icon_{icon_type}"] = icon
|
||||
message_dict["icon_emoji"] = icon
|
||||
|
||||
if blocks:
|
||||
message_dict["blocks"] = blocks
|
||||
@ -233,17 +295,16 @@ class SlackNotificationService(BaseNotificationService):
|
||||
for target, result in zip(tasks, results):
|
||||
if isinstance(result, SlackApiError):
|
||||
_LOGGER.error(
|
||||
"There was a Slack API error while sending to %s: %s",
|
||||
"There was a Slack API error while sending to %s: %r",
|
||||
target,
|
||||
result,
|
||||
)
|
||||
elif isinstance(result, ClientError):
|
||||
_LOGGER.error("Error while sending message to %s: %r", target, result)
|
||||
|
||||
async def async_send_message(self, message, **kwargs):
|
||||
async def async_send_message(self, message: str, **kwargs: Any) -> None:
|
||||
"""Send a message to Slack."""
|
||||
data = kwargs.get(ATTR_DATA)
|
||||
|
||||
if data is None:
|
||||
data = {}
|
||||
data = kwargs.get(ATTR_DATA) or {}
|
||||
|
||||
try:
|
||||
DATA_SCHEMA(data)
|
||||
@ -259,7 +320,9 @@ class SlackNotificationService(BaseNotificationService):
|
||||
# Message Type 1: A text-only message
|
||||
if ATTR_FILE not in data:
|
||||
if ATTR_BLOCKS_TEMPLATE in data:
|
||||
blocks = _async_templatize_blocks(self.hass, data[ATTR_BLOCKS_TEMPLATE])
|
||||
blocks = _async_templatize_blocks(
|
||||
self._hass, data[ATTR_BLOCKS_TEMPLATE]
|
||||
)
|
||||
elif ATTR_BLOCKS in data:
|
||||
blocks = data[ATTR_BLOCKS]
|
||||
else:
|
||||
|
@ -42,7 +42,7 @@ warn_redundant_casts = true
|
||||
warn_unused_configs = true
|
||||
|
||||
|
||||
[mypy-homeassistant.block_async_io,homeassistant.bootstrap,homeassistant.components,homeassistant.config_entries,homeassistant.config,homeassistant.const,homeassistant.core,homeassistant.data_entry_flow,homeassistant.exceptions,homeassistant.__init__,homeassistant.loader,homeassistant.__main__,homeassistant.requirements,homeassistant.runner,homeassistant.setup,homeassistant.util,homeassistant.auth.*,homeassistant.components.automation.*,homeassistant.components.binary_sensor.*,homeassistant.components.calendar.*,homeassistant.components.cover.*,homeassistant.components.device_automation.*,homeassistant.components.frontend.*,homeassistant.components.geo_location.*,homeassistant.components.group.*,homeassistant.components.history.*,homeassistant.components.http.*,homeassistant.components.huawei_lte.*,homeassistant.components.hyperion.*,homeassistant.components.image_processing.*,homeassistant.components.integration.*,homeassistant.components.light.*,homeassistant.components.lock.*,homeassistant.components.mailbox.*,homeassistant.components.media_player.*,homeassistant.components.notify.*,homeassistant.components.number.*,homeassistant.components.persistent_notification.*,homeassistant.components.proximity.*,homeassistant.components.remote.*,homeassistant.components.scene.*,homeassistant.components.sensor.*,homeassistant.components.sun.*,homeassistant.components.switch.*,homeassistant.components.systemmonitor.*,homeassistant.components.tts.*,homeassistant.components.vacuum.*,homeassistant.components.water_heater.*,homeassistant.components.weather.*,homeassistant.components.websocket_api.*,homeassistant.components.zone.*,homeassistant.components.zwave_js.*,homeassistant.helpers.*,homeassistant.scripts.*,homeassistant.util.*,tests.components.hyperion.*]
|
||||
[mypy-homeassistant.block_async_io,homeassistant.bootstrap,homeassistant.components,homeassistant.config_entries,homeassistant.config,homeassistant.const,homeassistant.core,homeassistant.data_entry_flow,homeassistant.exceptions,homeassistant.__init__,homeassistant.loader,homeassistant.__main__,homeassistant.requirements,homeassistant.runner,homeassistant.setup,homeassistant.util,homeassistant.auth.*,homeassistant.components.automation.*,homeassistant.components.binary_sensor.*,homeassistant.components.calendar.*,homeassistant.components.cover.*,homeassistant.components.device_automation.*,homeassistant.components.frontend.*,homeassistant.components.geo_location.*,homeassistant.components.group.*,homeassistant.components.history.*,homeassistant.components.http.*,homeassistant.components.huawei_lte.*,homeassistant.components.hyperion.*,homeassistant.components.image_processing.*,homeassistant.components.integration.*,homeassistant.components.light.*,homeassistant.components.lock.*,homeassistant.components.mailbox.*,homeassistant.components.media_player.*,homeassistant.components.notify.*,homeassistant.components.number.*,homeassistant.components.persistent_notification.*,homeassistant.components.proximity.*,homeassistant.components.remote.*,homeassistant.components.scene.*,homeassistant.components.sensor.*,homeassistant.components.slack.*,homeassistant.components.sun.*,homeassistant.components.switch.*,homeassistant.components.systemmonitor.*,homeassistant.components.tts.*,homeassistant.components.vacuum.*,homeassistant.components.water_heater.*,homeassistant.components.weather.*,homeassistant.components.websocket_api.*,homeassistant.components.zone.*,homeassistant.components.zwave_js.*,homeassistant.helpers.*,homeassistant.scripts.*,homeassistant.util.*,tests.components.hyperion.*]
|
||||
strict = true
|
||||
ignore_errors = false
|
||||
warn_unreachable = true
|
||||
|
@ -1,7 +1,77 @@
|
||||
"""Test slack notifications."""
|
||||
from unittest.mock import AsyncMock, Mock
|
||||
import logging
|
||||
from unittest.mock import AsyncMock, Mock, patch
|
||||
|
||||
from homeassistant.components.slack.notify import SlackNotificationService
|
||||
from _pytest.logging import LogCaptureFixture
|
||||
import aiohttp
|
||||
from slack.errors import SlackApiError
|
||||
|
||||
from homeassistant.components.slack.notify import (
|
||||
CONF_DEFAULT_CHANNEL,
|
||||
SlackNotificationService,
|
||||
async_get_service,
|
||||
)
|
||||
from homeassistant.const import CONF_API_KEY, CONF_ICON, CONF_USERNAME
|
||||
from homeassistant.helpers.typing import HomeAssistantType
|
||||
|
||||
MODULE_PATH = "homeassistant.components.slack.notify"
|
||||
|
||||
|
||||
async def test_get_service(hass: HomeAssistantType, caplog: LogCaptureFixture):
|
||||
"""Test async_get_service with exceptions."""
|
||||
config = {
|
||||
CONF_API_KEY: "12345",
|
||||
CONF_DEFAULT_CHANNEL: "channel",
|
||||
}
|
||||
|
||||
with patch(MODULE_PATH + ".aiohttp_client") as mock_session, patch(
|
||||
MODULE_PATH + ".WebClient"
|
||||
) as mock_client, patch(
|
||||
MODULE_PATH + ".SlackNotificationService"
|
||||
) as mock_slack_service:
|
||||
mock_session.async_get_clientsession.return_value = session = Mock()
|
||||
mock_client.return_value = client = AsyncMock()
|
||||
|
||||
# Normal setup
|
||||
mock_slack_service.return_value = service = Mock()
|
||||
assert await async_get_service(hass, config) == service
|
||||
mock_slack_service.assert_called_once_with(
|
||||
hass, client, "channel", username=None, icon=None
|
||||
)
|
||||
mock_client.assert_called_with(token="12345", run_async=True, session=session)
|
||||
client.auth_test.assert_called_once_with()
|
||||
mock_slack_service.assert_called_once_with(
|
||||
hass, client, "channel", username=None, icon=None
|
||||
)
|
||||
mock_slack_service.reset_mock()
|
||||
|
||||
# aiohttp.ClientError
|
||||
config.update({CONF_USERNAME: "user", CONF_ICON: "icon"})
|
||||
mock_slack_service.reset_mock()
|
||||
mock_slack_service.return_value = service = Mock()
|
||||
client.auth_test.side_effect = [aiohttp.ClientError]
|
||||
assert await async_get_service(hass, config) == service
|
||||
assert len(caplog.records) == 1
|
||||
record = caplog.records[0]
|
||||
assert record.levelno == logging.WARNING
|
||||
assert aiohttp.ClientError.__qualname__ in record.message
|
||||
caplog.records.clear()
|
||||
mock_slack_service.assert_called_once_with(
|
||||
hass, client, "channel", username="user", icon="icon"
|
||||
)
|
||||
mock_slack_service.reset_mock()
|
||||
|
||||
# SlackApiError
|
||||
err, level = SlackApiError("msg", "resp"), logging.ERROR
|
||||
client.auth_test.side_effect = [err]
|
||||
assert await async_get_service(hass, config) is None
|
||||
assert len(caplog.records) == 1
|
||||
record = caplog.records[0]
|
||||
assert record.levelno == level
|
||||
assert err.__class__.__qualname__ in record.message
|
||||
caplog.records.clear()
|
||||
mock_slack_service.assert_not_called()
|
||||
mock_slack_service.reset_mock()
|
||||
|
||||
|
||||
async def test_message_includes_default_emoji():
|
||||
|
Loading…
x
Reference in New Issue
Block a user