mirror of
https://github.com/home-assistant/core.git
synced 2025-07-16 01:37:08 +00:00
Fix Telegram bot default target when sending messages (#147470)
* handle targets * updated error message * validate chat id for single target * add validation for chat id * handle empty target * handle empty target
This commit is contained in:
parent
40f553a007
commit
68924d23ab
@ -29,6 +29,7 @@ from homeassistant.core import (
|
||||
from homeassistant.exceptions import (
|
||||
ConfigEntryAuthFailed,
|
||||
ConfigEntryNotReady,
|
||||
HomeAssistantError,
|
||||
ServiceValidationError,
|
||||
)
|
||||
from homeassistant.helpers import config_validation as cv
|
||||
@ -390,9 +391,7 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
|
||||
elif msgtype == SERVICE_DELETE_MESSAGE:
|
||||
await notify_service.delete_message(context=service.context, **kwargs)
|
||||
elif msgtype == SERVICE_LEAVE_CHAT:
|
||||
messages = await notify_service.leave_chat(
|
||||
context=service.context, **kwargs
|
||||
)
|
||||
await notify_service.leave_chat(context=service.context, **kwargs)
|
||||
elif msgtype == SERVICE_SET_MESSAGE_REACTION:
|
||||
await notify_service.set_message_reaction(context=service.context, **kwargs)
|
||||
else:
|
||||
@ -400,12 +399,29 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
|
||||
msgtype, context=service.context, **kwargs
|
||||
)
|
||||
|
||||
if service.return_response and messages:
|
||||
if service.return_response and messages is not None:
|
||||
target: list[int] | None = service.data.get(ATTR_TARGET)
|
||||
if not target:
|
||||
target = notify_service.get_target_chat_ids(None)
|
||||
|
||||
failed_chat_ids = [chat_id for chat_id in target if chat_id not in messages]
|
||||
if failed_chat_ids:
|
||||
raise HomeAssistantError(
|
||||
f"Failed targets: {failed_chat_ids}",
|
||||
translation_domain=DOMAIN,
|
||||
translation_key="failed_chat_ids",
|
||||
translation_placeholders={
|
||||
"chat_ids": ", ".join([str(i) for i in failed_chat_ids]),
|
||||
"bot_name": config_entry.title,
|
||||
},
|
||||
)
|
||||
|
||||
return {
|
||||
"chats": [
|
||||
{"chat_id": cid, "message_id": mid} for cid, mid in messages.items()
|
||||
]
|
||||
}
|
||||
|
||||
return None
|
||||
|
||||
# Register notification services
|
||||
|
@ -287,24 +287,32 @@ class TelegramNotificationService:
|
||||
inline_message_id = msg_data["inline_message_id"]
|
||||
return message_id, inline_message_id
|
||||
|
||||
def _get_target_chat_ids(self, target: Any) -> list[int]:
|
||||
def get_target_chat_ids(self, target: int | list[int] | None) -> list[int]:
|
||||
"""Validate chat_id targets or return default target (first).
|
||||
|
||||
:param target: optional list of integers ([12234, -12345])
|
||||
:return list of chat_id targets (integers)
|
||||
"""
|
||||
allowed_chat_ids: list[int] = self._get_allowed_chat_ids()
|
||||
default_user: int = allowed_chat_ids[0]
|
||||
if target is not None:
|
||||
if isinstance(target, int):
|
||||
target = [target]
|
||||
chat_ids = [t for t in target if t in allowed_chat_ids]
|
||||
if chat_ids:
|
||||
return chat_ids
|
||||
_LOGGER.warning(
|
||||
"Disallowed targets: %s, using default: %s", target, default_user
|
||||
|
||||
if target is None:
|
||||
return [allowed_chat_ids[0]]
|
||||
|
||||
chat_ids = [target] if isinstance(target, int) else target
|
||||
valid_chat_ids = [
|
||||
chat_id for chat_id in chat_ids if chat_id in allowed_chat_ids
|
||||
]
|
||||
if not valid_chat_ids:
|
||||
raise ServiceValidationError(
|
||||
"Invalid chat IDs",
|
||||
translation_domain=DOMAIN,
|
||||
translation_key="invalid_chat_ids",
|
||||
translation_placeholders={
|
||||
"chat_ids": ", ".join(str(chat_id) for chat_id in chat_ids),
|
||||
"bot_name": self.config.title,
|
||||
},
|
||||
)
|
||||
return [default_user]
|
||||
return valid_chat_ids
|
||||
|
||||
def _get_msg_kwargs(self, data: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Get parameters in message data kwargs."""
|
||||
@ -414,9 +422,9 @@ class TelegramNotificationService:
|
||||
"""Send one message."""
|
||||
try:
|
||||
out = await func_send(*args_msg, **kwargs_msg)
|
||||
if not isinstance(out, bool) and hasattr(out, ATTR_MESSAGEID):
|
||||
if isinstance(out, Message):
|
||||
chat_id = out.chat_id
|
||||
message_id = out[ATTR_MESSAGEID]
|
||||
message_id = out.message_id
|
||||
self._last_message_id[chat_id] = message_id
|
||||
_LOGGER.debug(
|
||||
"Last message ID: %s (from chat_id %s)",
|
||||
@ -424,7 +432,7 @@ class TelegramNotificationService:
|
||||
chat_id,
|
||||
)
|
||||
|
||||
event_data = {
|
||||
event_data: dict[str, Any] = {
|
||||
ATTR_CHAT_ID: chat_id,
|
||||
ATTR_MESSAGEID: message_id,
|
||||
}
|
||||
@ -437,10 +445,6 @@ class TelegramNotificationService:
|
||||
self.hass.bus.async_fire(
|
||||
EVENT_TELEGRAM_SENT, event_data, context=context
|
||||
)
|
||||
elif not isinstance(out, bool):
|
||||
_LOGGER.warning(
|
||||
"Update last message: out_type:%s, out=%s", type(out), out
|
||||
)
|
||||
except TelegramError as exc:
|
||||
_LOGGER.error(
|
||||
"%s: %s. Args: %s, kwargs: %s", msg_error, exc, args_msg, kwargs_msg
|
||||
@ -460,7 +464,7 @@ class TelegramNotificationService:
|
||||
text = f"{title}\n{message}" if title else message
|
||||
params = self._get_msg_kwargs(kwargs)
|
||||
msg_ids = {}
|
||||
for chat_id in self._get_target_chat_ids(target):
|
||||
for chat_id in self.get_target_chat_ids(target):
|
||||
_LOGGER.debug("Send message in chat ID %s with params: %s", chat_id, params)
|
||||
msg = await self._send_msg(
|
||||
self.bot.send_message,
|
||||
@ -488,7 +492,7 @@ class TelegramNotificationService:
|
||||
**kwargs: dict[str, Any],
|
||||
) -> bool:
|
||||
"""Delete a previously sent message."""
|
||||
chat_id = self._get_target_chat_ids(chat_id)[0]
|
||||
chat_id = self.get_target_chat_ids(chat_id)[0]
|
||||
message_id, _ = self._get_msg_ids(kwargs, chat_id)
|
||||
_LOGGER.debug("Delete message %s in chat ID %s", message_id, chat_id)
|
||||
deleted: bool = await self._send_msg(
|
||||
@ -513,7 +517,7 @@ class TelegramNotificationService:
|
||||
**kwargs: dict[str, Any],
|
||||
) -> Any:
|
||||
"""Edit a previously sent message."""
|
||||
chat_id = self._get_target_chat_ids(chat_id)[0]
|
||||
chat_id = self.get_target_chat_ids(chat_id)[0]
|
||||
message_id, inline_message_id = self._get_msg_ids(kwargs, chat_id)
|
||||
params = self._get_msg_kwargs(kwargs)
|
||||
_LOGGER.debug(
|
||||
@ -620,7 +624,7 @@ class TelegramNotificationService:
|
||||
|
||||
msg_ids = {}
|
||||
if file_content:
|
||||
for chat_id in self._get_target_chat_ids(target):
|
||||
for chat_id in self.get_target_chat_ids(target):
|
||||
_LOGGER.debug("Sending file to chat ID %s", chat_id)
|
||||
|
||||
if file_type == SERVICE_SEND_PHOTO:
|
||||
@ -738,7 +742,7 @@ class TelegramNotificationService:
|
||||
|
||||
msg_ids = {}
|
||||
if stickerid:
|
||||
for chat_id in self._get_target_chat_ids(target):
|
||||
for chat_id in self.get_target_chat_ids(target):
|
||||
msg = await self._send_msg(
|
||||
self.bot.send_sticker,
|
||||
"Error sending sticker",
|
||||
@ -769,7 +773,7 @@ class TelegramNotificationService:
|
||||
longitude = float(longitude)
|
||||
params = self._get_msg_kwargs(kwargs)
|
||||
msg_ids = {}
|
||||
for chat_id in self._get_target_chat_ids(target):
|
||||
for chat_id in self.get_target_chat_ids(target):
|
||||
_LOGGER.debug(
|
||||
"Send location %s/%s to chat ID %s", latitude, longitude, chat_id
|
||||
)
|
||||
@ -803,7 +807,7 @@ class TelegramNotificationService:
|
||||
params = self._get_msg_kwargs(kwargs)
|
||||
openperiod = kwargs.get(ATTR_OPEN_PERIOD)
|
||||
msg_ids = {}
|
||||
for chat_id in self._get_target_chat_ids(target):
|
||||
for chat_id in self.get_target_chat_ids(target):
|
||||
_LOGGER.debug("Send poll '%s' to chat ID %s", question, chat_id)
|
||||
msg = await self._send_msg(
|
||||
self.bot.send_poll,
|
||||
@ -826,12 +830,12 @@ class TelegramNotificationService:
|
||||
|
||||
async def leave_chat(
|
||||
self,
|
||||
chat_id: Any = None,
|
||||
chat_id: int | None = None,
|
||||
context: Context | None = None,
|
||||
**kwargs: dict[str, Any],
|
||||
) -> Any:
|
||||
"""Remove bot from chat."""
|
||||
chat_id = self._get_target_chat_ids(chat_id)[0]
|
||||
chat_id = self.get_target_chat_ids(chat_id)[0]
|
||||
_LOGGER.debug("Leave from chat ID %s", chat_id)
|
||||
return await self._send_msg(
|
||||
self.bot.leave_chat, "Error leaving chat", None, chat_id, context=context
|
||||
@ -839,14 +843,14 @@ class TelegramNotificationService:
|
||||
|
||||
async def set_message_reaction(
|
||||
self,
|
||||
chat_id: int,
|
||||
reaction: str,
|
||||
chat_id: int | None = None,
|
||||
is_big: bool = False,
|
||||
context: Context | None = None,
|
||||
**kwargs: dict[str, Any],
|
||||
) -> None:
|
||||
"""Set the bot's reaction for a given message."""
|
||||
chat_id = self._get_target_chat_ids(chat_id)[0]
|
||||
chat_id = self.get_target_chat_ids(chat_id)[0]
|
||||
message_id, _ = self._get_msg_ids(kwargs, chat_id)
|
||||
params = self._get_msg_kwargs(kwargs)
|
||||
|
||||
|
@ -895,6 +895,12 @@
|
||||
"missing_allowed_chat_ids": {
|
||||
"message": "No allowed chat IDs found. Please add allowed chat IDs for {bot_name}."
|
||||
},
|
||||
"invalid_chat_ids": {
|
||||
"message": "Invalid chat IDs: {chat_ids}. Please configure the chat IDs for {bot_name}."
|
||||
},
|
||||
"failed_chat_ids": {
|
||||
"message": "Failed targets: {chat_ids}. Please verify that the chat IDs for {bot_name} have been configured."
|
||||
},
|
||||
"missing_input": {
|
||||
"message": "{field} is required."
|
||||
},
|
||||
|
@ -677,13 +677,35 @@ async def test_send_message_with_config_entry(
|
||||
await hass.config_entries.async_setup(mock_broadcast_config_entry.entry_id)
|
||||
await hass.async_block_till_done()
|
||||
|
||||
# test: send message to invalid chat id
|
||||
|
||||
with pytest.raises(HomeAssistantError) as err:
|
||||
response = await hass.services.async_call(
|
||||
DOMAIN,
|
||||
SERVICE_SEND_MESSAGE,
|
||||
{
|
||||
CONF_CONFIG_ENTRY_ID: mock_broadcast_config_entry.entry_id,
|
||||
ATTR_MESSAGE: "mock message",
|
||||
ATTR_TARGET: [123456, 1],
|
||||
},
|
||||
blocking=True,
|
||||
return_response=True,
|
||||
)
|
||||
await hass.async_block_till_done()
|
||||
|
||||
assert err.value.translation_key == "failed_chat_ids"
|
||||
assert err.value.translation_placeholders["chat_ids"] == "1"
|
||||
assert err.value.translation_placeholders["bot_name"] == "Mock Title"
|
||||
|
||||
# test: send message to valid chat id
|
||||
|
||||
response = await hass.services.async_call(
|
||||
DOMAIN,
|
||||
SERVICE_SEND_MESSAGE,
|
||||
{
|
||||
CONF_CONFIG_ENTRY_ID: mock_broadcast_config_entry.entry_id,
|
||||
ATTR_MESSAGE: "mock message",
|
||||
ATTR_TARGET: 1,
|
||||
ATTR_TARGET: 123456,
|
||||
},
|
||||
blocking=True,
|
||||
return_response=True,
|
||||
@ -767,6 +789,23 @@ async def test_delete_message(
|
||||
await hass.config_entries.async_setup(mock_broadcast_config_entry.entry_id)
|
||||
await hass.async_block_till_done()
|
||||
|
||||
# test: delete message with invalid chat id
|
||||
|
||||
with pytest.raises(ServiceValidationError) as err:
|
||||
await hass.services.async_call(
|
||||
DOMAIN,
|
||||
SERVICE_DELETE_MESSAGE,
|
||||
{ATTR_CHAT_ID: 1, ATTR_MESSAGEID: "last"},
|
||||
blocking=True,
|
||||
)
|
||||
await hass.async_block_till_done()
|
||||
|
||||
assert err.value.translation_key == "invalid_chat_ids"
|
||||
assert err.value.translation_placeholders["chat_ids"] == "1"
|
||||
assert err.value.translation_placeholders["bot_name"] == "Mock Title"
|
||||
|
||||
# test: delete message with valid chat id
|
||||
|
||||
response = await hass.services.async_call(
|
||||
DOMAIN,
|
||||
SERVICE_SEND_MESSAGE,
|
||||
@ -808,7 +847,7 @@ async def test_edit_message(
|
||||
await hass.services.async_call(
|
||||
DOMAIN,
|
||||
SERVICE_EDIT_MESSAGE,
|
||||
{ATTR_MESSAGE: "mock message", ATTR_CHAT_ID: 12345, ATTR_MESSAGEID: 12345},
|
||||
{ATTR_MESSAGE: "mock message", ATTR_CHAT_ID: 123456, ATTR_MESSAGEID: 12345},
|
||||
blocking=True,
|
||||
)
|
||||
|
||||
@ -822,7 +861,7 @@ async def test_edit_message(
|
||||
await hass.services.async_call(
|
||||
DOMAIN,
|
||||
SERVICE_EDIT_CAPTION,
|
||||
{ATTR_CAPTION: "mock caption", ATTR_CHAT_ID: 12345, ATTR_MESSAGEID: 12345},
|
||||
{ATTR_CAPTION: "mock caption", ATTR_CHAT_ID: 123456, ATTR_MESSAGEID: 12345},
|
||||
blocking=True,
|
||||
)
|
||||
|
||||
@ -836,7 +875,7 @@ async def test_edit_message(
|
||||
await hass.services.async_call(
|
||||
DOMAIN,
|
||||
SERVICE_EDIT_REPLYMARKUP,
|
||||
{ATTR_KEYBOARD_INLINE: [], ATTR_CHAT_ID: 12345, ATTR_MESSAGEID: 12345},
|
||||
{ATTR_KEYBOARD_INLINE: [], ATTR_CHAT_ID: 123456, ATTR_MESSAGEID: 12345},
|
||||
blocking=True,
|
||||
)
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user