Add Slack thread/reply support (#93384)

This commit is contained in:
Fletcher 2023-09-21 17:06:55 +08:00 committed by GitHub
parent e4742c04f2
commit 11c4c37cf9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 46 additions and 6 deletions

View File

@ -1145,8 +1145,8 @@ build.json @home-assistant/supervisor
/homeassistant/components/sky_hub/ @rogerselwyn /homeassistant/components/sky_hub/ @rogerselwyn
/homeassistant/components/skybell/ @tkdrob /homeassistant/components/skybell/ @tkdrob
/tests/components/skybell/ @tkdrob /tests/components/skybell/ @tkdrob
/homeassistant/components/slack/ @tkdrob /homeassistant/components/slack/ @tkdrob @fletcherau
/tests/components/slack/ @tkdrob /tests/components/slack/ @tkdrob @fletcherau
/homeassistant/components/sleepiq/ @mfugate1 @kbickar /homeassistant/components/sleepiq/ @mfugate1 @kbickar
/tests/components/sleepiq/ @mfugate1 @kbickar /tests/components/sleepiq/ @mfugate1 @kbickar
/homeassistant/components/slide/ @ualex73 /homeassistant/components/slide/ @ualex73

View File

@ -10,6 +10,7 @@ ATTR_SNOOZE = "snooze_endtime"
ATTR_URL = "url" ATTR_URL = "url"
ATTR_USERNAME = "username" ATTR_USERNAME = "username"
ATTR_USER_ID = "user_id" ATTR_USER_ID = "user_id"
ATTR_THREAD_TS = "thread_ts"
CONF_DEFAULT_CHANNEL = "default_channel" CONF_DEFAULT_CHANNEL = "default_channel"

View File

@ -1,7 +1,7 @@
{ {
"domain": "slack", "domain": "slack",
"name": "Slack", "name": "Slack",
"codeowners": ["@tkdrob"], "codeowners": ["@tkdrob", "@fletcherau"],
"config_flow": true, "config_flow": true,
"documentation": "https://www.home-assistant.io/integrations/slack", "documentation": "https://www.home-assistant.io/integrations/slack",
"integration_type": "service", "integration_type": "service",

View File

@ -30,6 +30,7 @@ from .const import (
ATTR_FILE, ATTR_FILE,
ATTR_PASSWORD, ATTR_PASSWORD,
ATTR_PATH, ATTR_PATH,
ATTR_THREAD_TS,
ATTR_URL, ATTR_URL,
ATTR_USERNAME, ATTR_USERNAME,
CONF_DEFAULT_CHANNEL, CONF_DEFAULT_CHANNEL,
@ -50,7 +51,10 @@ FILE_URL_SCHEMA = vol.Schema(
) )
DATA_FILE_SCHEMA = vol.Schema( DATA_FILE_SCHEMA = vol.Schema(
{vol.Required(ATTR_FILE): vol.Any(FILE_PATH_SCHEMA, FILE_URL_SCHEMA)} {
vol.Required(ATTR_FILE): vol.Any(FILE_PATH_SCHEMA, FILE_URL_SCHEMA),
vol.Optional(ATTR_THREAD_TS): cv.string,
}
) )
DATA_TEXT_ONLY_SCHEMA = vol.Schema( DATA_TEXT_ONLY_SCHEMA = vol.Schema(
@ -59,6 +63,7 @@ DATA_TEXT_ONLY_SCHEMA = vol.Schema(
vol.Optional(ATTR_ICON): cv.string, vol.Optional(ATTR_ICON): cv.string,
vol.Optional(ATTR_BLOCKS): list, vol.Optional(ATTR_BLOCKS): list,
vol.Optional(ATTR_BLOCKS_TEMPLATE): list, vol.Optional(ATTR_BLOCKS_TEMPLATE): list,
vol.Optional(ATTR_THREAD_TS): cv.string,
} }
) )
@ -73,7 +78,7 @@ class AuthDictT(TypedDict, total=False):
auth: BasicAuth auth: BasicAuth
class FormDataT(TypedDict): class FormDataT(TypedDict, total=False):
"""Type for form data, file upload.""" """Type for form data, file upload."""
channels: str channels: str
@ -81,6 +86,7 @@ class FormDataT(TypedDict):
initial_comment: str initial_comment: str
title: str title: str
token: str token: str
thread_ts: str # Optional key
class MessageT(TypedDict, total=False): class MessageT(TypedDict, total=False):
@ -92,6 +98,7 @@ class MessageT(TypedDict, total=False):
icon_url: str # Optional key icon_url: str # Optional key
icon_emoji: str # Optional key icon_emoji: str # Optional key
blocks: list[Any] # Optional key blocks: list[Any] # Optional key
thread_ts: str # Optional key
async def async_get_service( async def async_get_service(
@ -142,6 +149,7 @@ class SlackNotificationService(BaseNotificationService):
targets: list[str], targets: list[str],
message: str, message: str,
title: str | None, title: str | None,
thread_ts: str | None,
) -> None: ) -> None:
"""Upload a local file (with message) to Slack.""" """Upload a local file (with message) to Slack."""
if not self._hass.config.is_allowed_path(path): if not self._hass.config.is_allowed_path(path):
@ -158,6 +166,7 @@ class SlackNotificationService(BaseNotificationService):
filename=filename, filename=filename,
initial_comment=message, initial_comment=message,
title=title or filename, title=title or filename,
thread_ts=thread_ts,
) )
except (SlackApiError, ClientError) as err: except (SlackApiError, ClientError) as err:
_LOGGER.error("Error while uploading file-based message: %r", err) _LOGGER.error("Error while uploading file-based message: %r", err)
@ -168,6 +177,7 @@ class SlackNotificationService(BaseNotificationService):
targets: list[str], targets: list[str],
message: str, message: str,
title: str | None, title: str | None,
thread_ts: str | None,
*, *,
username: str | None = None, username: str | None = None,
password: str | None = None, password: str | None = None,
@ -205,6 +215,9 @@ class SlackNotificationService(BaseNotificationService):
"token": self._client.token, "token": self._client.token,
} }
if thread_ts:
form_data["thread_ts"] = thread_ts
data = FormData(form_data, charset="utf-8") data = FormData(form_data, charset="utf-8")
data.add_field("file", resp.content, filename=filename) data.add_field("file", resp.content, filename=filename)
@ -218,6 +231,7 @@ class SlackNotificationService(BaseNotificationService):
targets: list[str], targets: list[str],
message: str, message: str,
title: str | None, title: str | None,
thread_ts: str | None,
*, *,
username: str | None = None, username: str | None = None,
icon: str | None = None, icon: str | None = None,
@ -238,6 +252,9 @@ class SlackNotificationService(BaseNotificationService):
if blocks: if blocks:
message_dict["blocks"] = blocks message_dict["blocks"] = blocks
if thread_ts:
message_dict["thread_ts"] = thread_ts
tasks = { tasks = {
target: self._client.chat_postMessage(**message_dict, channel=target) target: self._client.chat_postMessage(**message_dict, channel=target)
for target in targets for target in targets
@ -286,6 +303,7 @@ class SlackNotificationService(BaseNotificationService):
title, title,
username=data.get(ATTR_USERNAME, self._config.get(ATTR_USERNAME)), username=data.get(ATTR_USERNAME, self._config.get(ATTR_USERNAME)),
icon=data.get(ATTR_ICON, self._config.get(ATTR_ICON)), icon=data.get(ATTR_ICON, self._config.get(ATTR_ICON)),
thread_ts=data.get(ATTR_THREAD_TS),
blocks=blocks, blocks=blocks,
) )
@ -296,11 +314,16 @@ class SlackNotificationService(BaseNotificationService):
targets, targets,
message, message,
title, title,
thread_ts=data.get(ATTR_THREAD_TS),
username=data[ATTR_FILE].get(ATTR_USERNAME), username=data[ATTR_FILE].get(ATTR_USERNAME),
password=data[ATTR_FILE].get(ATTR_PASSWORD), password=data[ATTR_FILE].get(ATTR_PASSWORD),
) )
# Message Type 3: A message that uploads a local file # Message Type 3: A message that uploads a local file
return await self._async_send_local_file_message( return await self._async_send_local_file_message(
data[ATTR_FILE][ATTR_PATH], targets, message, title data[ATTR_FILE][ATTR_PATH],
targets,
message,
title,
thread_ts=data.get(ATTR_THREAD_TS),
) )

View File

@ -6,6 +6,7 @@ from unittest.mock import AsyncMock, Mock
from homeassistant.components import notify from homeassistant.components import notify
from homeassistant.components.slack import DOMAIN from homeassistant.components.slack import DOMAIN
from homeassistant.components.slack.notify import ( from homeassistant.components.slack.notify import (
ATTR_THREAD_TS,
CONF_DEFAULT_CHANNEL, CONF_DEFAULT_CHANNEL,
SlackNotificationService, SlackNotificationService,
) )
@ -93,3 +94,18 @@ async def test_message_icon_url_overrides_default() -> None:
mock_fn.assert_called_once() mock_fn.assert_called_once()
_, kwargs = mock_fn.call_args _, kwargs = mock_fn.call_args
assert kwargs["icon_url"] == expected_icon assert kwargs["icon_url"] == expected_icon
async def test_message_as_reply() -> None:
"""Tests that a message pointer will be passed to Slack if specified."""
mock_client = Mock()
mock_client.chat_postMessage = AsyncMock()
service = SlackNotificationService(None, mock_client, CONF_DATA)
expected_ts = "1624146685.064129"
await service.async_send_message("test", data={ATTR_THREAD_TS: expected_ts})
mock_fn = mock_client.chat_postMessage
mock_fn.assert_called_once()
_, kwargs = mock_fn.call_args
assert kwargs["thread_ts"] == expected_ts