Fix Telegram bot default target when sending messages (#147470)

* handle targets

* updated error message

* validate chat id for single target

* add validation for chat id

* handle empty target

* handle empty target
This commit is contained in:
hanwg 2025-06-26 22:43:09 +08:00 committed by GitHub
parent 40f553a007
commit 68924d23ab
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 102 additions and 37 deletions

View File

@ -29,6 +29,7 @@ from homeassistant.core import (
from homeassistant.exceptions import (
ConfigEntryAuthFailed,
ConfigEntryNotReady,
HomeAssistantError,
ServiceValidationError,
)
from homeassistant.helpers import config_validation as cv
@ -390,9 +391,7 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
elif msgtype == SERVICE_DELETE_MESSAGE:
await notify_service.delete_message(context=service.context, **kwargs)
elif msgtype == SERVICE_LEAVE_CHAT:
messages = await notify_service.leave_chat(
context=service.context, **kwargs
)
await notify_service.leave_chat(context=service.context, **kwargs)
elif msgtype == SERVICE_SET_MESSAGE_REACTION:
await notify_service.set_message_reaction(context=service.context, **kwargs)
else:
@ -400,12 +399,29 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
msgtype, context=service.context, **kwargs
)
if service.return_response and messages:
if service.return_response and messages is not None:
target: list[int] | None = service.data.get(ATTR_TARGET)
if not target:
target = notify_service.get_target_chat_ids(None)
failed_chat_ids = [chat_id for chat_id in target if chat_id not in messages]
if failed_chat_ids:
raise HomeAssistantError(
f"Failed targets: {failed_chat_ids}",
translation_domain=DOMAIN,
translation_key="failed_chat_ids",
translation_placeholders={
"chat_ids": ", ".join([str(i) for i in failed_chat_ids]),
"bot_name": config_entry.title,
},
)
return {
"chats": [
{"chat_id": cid, "message_id": mid} for cid, mid in messages.items()
]
}
return None
# Register notification services

View File

@ -287,24 +287,32 @@ class TelegramNotificationService:
inline_message_id = msg_data["inline_message_id"]
return message_id, inline_message_id
def _get_target_chat_ids(self, target: Any) -> list[int]:
def get_target_chat_ids(self, target: int | list[int] | None) -> list[int]:
"""Validate chat_id targets or return default target (first).
:param target: optional list of integers ([12234, -12345])
:return list of chat_id targets (integers)
"""
allowed_chat_ids: list[int] = self._get_allowed_chat_ids()
default_user: int = allowed_chat_ids[0]
if target is not None:
if isinstance(target, int):
target = [target]
chat_ids = [t for t in target if t in allowed_chat_ids]
if chat_ids:
return chat_ids
_LOGGER.warning(
"Disallowed targets: %s, using default: %s", target, default_user
if target is None:
return [allowed_chat_ids[0]]
chat_ids = [target] if isinstance(target, int) else target
valid_chat_ids = [
chat_id for chat_id in chat_ids if chat_id in allowed_chat_ids
]
if not valid_chat_ids:
raise ServiceValidationError(
"Invalid chat IDs",
translation_domain=DOMAIN,
translation_key="invalid_chat_ids",
translation_placeholders={
"chat_ids": ", ".join(str(chat_id) for chat_id in chat_ids),
"bot_name": self.config.title,
},
)
return [default_user]
return valid_chat_ids
def _get_msg_kwargs(self, data: dict[str, Any]) -> dict[str, Any]:
"""Get parameters in message data kwargs."""
@ -414,9 +422,9 @@ class TelegramNotificationService:
"""Send one message."""
try:
out = await func_send(*args_msg, **kwargs_msg)
if not isinstance(out, bool) and hasattr(out, ATTR_MESSAGEID):
if isinstance(out, Message):
chat_id = out.chat_id
message_id = out[ATTR_MESSAGEID]
message_id = out.message_id
self._last_message_id[chat_id] = message_id
_LOGGER.debug(
"Last message ID: %s (from chat_id %s)",
@ -424,7 +432,7 @@ class TelegramNotificationService:
chat_id,
)
event_data = {
event_data: dict[str, Any] = {
ATTR_CHAT_ID: chat_id,
ATTR_MESSAGEID: message_id,
}
@ -437,10 +445,6 @@ class TelegramNotificationService:
self.hass.bus.async_fire(
EVENT_TELEGRAM_SENT, event_data, context=context
)
elif not isinstance(out, bool):
_LOGGER.warning(
"Update last message: out_type:%s, out=%s", type(out), out
)
except TelegramError as exc:
_LOGGER.error(
"%s: %s. Args: %s, kwargs: %s", msg_error, exc, args_msg, kwargs_msg
@ -460,7 +464,7 @@ class TelegramNotificationService:
text = f"{title}\n{message}" if title else message
params = self._get_msg_kwargs(kwargs)
msg_ids = {}
for chat_id in self._get_target_chat_ids(target):
for chat_id in self.get_target_chat_ids(target):
_LOGGER.debug("Send message in chat ID %s with params: %s", chat_id, params)
msg = await self._send_msg(
self.bot.send_message,
@ -488,7 +492,7 @@ class TelegramNotificationService:
**kwargs: dict[str, Any],
) -> bool:
"""Delete a previously sent message."""
chat_id = self._get_target_chat_ids(chat_id)[0]
chat_id = self.get_target_chat_ids(chat_id)[0]
message_id, _ = self._get_msg_ids(kwargs, chat_id)
_LOGGER.debug("Delete message %s in chat ID %s", message_id, chat_id)
deleted: bool = await self._send_msg(
@ -513,7 +517,7 @@ class TelegramNotificationService:
**kwargs: dict[str, Any],
) -> Any:
"""Edit a previously sent message."""
chat_id = self._get_target_chat_ids(chat_id)[0]
chat_id = self.get_target_chat_ids(chat_id)[0]
message_id, inline_message_id = self._get_msg_ids(kwargs, chat_id)
params = self._get_msg_kwargs(kwargs)
_LOGGER.debug(
@ -620,7 +624,7 @@ class TelegramNotificationService:
msg_ids = {}
if file_content:
for chat_id in self._get_target_chat_ids(target):
for chat_id in self.get_target_chat_ids(target):
_LOGGER.debug("Sending file to chat ID %s", chat_id)
if file_type == SERVICE_SEND_PHOTO:
@ -738,7 +742,7 @@ class TelegramNotificationService:
msg_ids = {}
if stickerid:
for chat_id in self._get_target_chat_ids(target):
for chat_id in self.get_target_chat_ids(target):
msg = await self._send_msg(
self.bot.send_sticker,
"Error sending sticker",
@ -769,7 +773,7 @@ class TelegramNotificationService:
longitude = float(longitude)
params = self._get_msg_kwargs(kwargs)
msg_ids = {}
for chat_id in self._get_target_chat_ids(target):
for chat_id in self.get_target_chat_ids(target):
_LOGGER.debug(
"Send location %s/%s to chat ID %s", latitude, longitude, chat_id
)
@ -803,7 +807,7 @@ class TelegramNotificationService:
params = self._get_msg_kwargs(kwargs)
openperiod = kwargs.get(ATTR_OPEN_PERIOD)
msg_ids = {}
for chat_id in self._get_target_chat_ids(target):
for chat_id in self.get_target_chat_ids(target):
_LOGGER.debug("Send poll '%s' to chat ID %s", question, chat_id)
msg = await self._send_msg(
self.bot.send_poll,
@ -826,12 +830,12 @@ class TelegramNotificationService:
async def leave_chat(
self,
chat_id: Any = None,
chat_id: int | None = None,
context: Context | None = None,
**kwargs: dict[str, Any],
) -> Any:
"""Remove bot from chat."""
chat_id = self._get_target_chat_ids(chat_id)[0]
chat_id = self.get_target_chat_ids(chat_id)[0]
_LOGGER.debug("Leave from chat ID %s", chat_id)
return await self._send_msg(
self.bot.leave_chat, "Error leaving chat", None, chat_id, context=context
@ -839,14 +843,14 @@ class TelegramNotificationService:
async def set_message_reaction(
self,
chat_id: int,
reaction: str,
chat_id: int | None = None,
is_big: bool = False,
context: Context | None = None,
**kwargs: dict[str, Any],
) -> None:
"""Set the bot's reaction for a given message."""
chat_id = self._get_target_chat_ids(chat_id)[0]
chat_id = self.get_target_chat_ids(chat_id)[0]
message_id, _ = self._get_msg_ids(kwargs, chat_id)
params = self._get_msg_kwargs(kwargs)

View File

@ -895,6 +895,12 @@
"missing_allowed_chat_ids": {
"message": "No allowed chat IDs found. Please add allowed chat IDs for {bot_name}."
},
"invalid_chat_ids": {
"message": "Invalid chat IDs: {chat_ids}. Please configure the chat IDs for {bot_name}."
},
"failed_chat_ids": {
"message": "Failed targets: {chat_ids}. Please verify that the chat IDs for {bot_name} have been configured."
},
"missing_input": {
"message": "{field} is required."
},

View File

@ -677,13 +677,35 @@ async def test_send_message_with_config_entry(
await hass.config_entries.async_setup(mock_broadcast_config_entry.entry_id)
await hass.async_block_till_done()
# test: send message to invalid chat id
with pytest.raises(HomeAssistantError) as err:
response = await hass.services.async_call(
DOMAIN,
SERVICE_SEND_MESSAGE,
{
CONF_CONFIG_ENTRY_ID: mock_broadcast_config_entry.entry_id,
ATTR_MESSAGE: "mock message",
ATTR_TARGET: [123456, 1],
},
blocking=True,
return_response=True,
)
await hass.async_block_till_done()
assert err.value.translation_key == "failed_chat_ids"
assert err.value.translation_placeholders["chat_ids"] == "1"
assert err.value.translation_placeholders["bot_name"] == "Mock Title"
# test: send message to valid chat id
response = await hass.services.async_call(
DOMAIN,
SERVICE_SEND_MESSAGE,
{
CONF_CONFIG_ENTRY_ID: mock_broadcast_config_entry.entry_id,
ATTR_MESSAGE: "mock message",
ATTR_TARGET: 1,
ATTR_TARGET: 123456,
},
blocking=True,
return_response=True,
@ -767,6 +789,23 @@ async def test_delete_message(
await hass.config_entries.async_setup(mock_broadcast_config_entry.entry_id)
await hass.async_block_till_done()
# test: delete message with invalid chat id
with pytest.raises(ServiceValidationError) as err:
await hass.services.async_call(
DOMAIN,
SERVICE_DELETE_MESSAGE,
{ATTR_CHAT_ID: 1, ATTR_MESSAGEID: "last"},
blocking=True,
)
await hass.async_block_till_done()
assert err.value.translation_key == "invalid_chat_ids"
assert err.value.translation_placeholders["chat_ids"] == "1"
assert err.value.translation_placeholders["bot_name"] == "Mock Title"
# test: delete message with valid chat id
response = await hass.services.async_call(
DOMAIN,
SERVICE_SEND_MESSAGE,
@ -808,7 +847,7 @@ async def test_edit_message(
await hass.services.async_call(
DOMAIN,
SERVICE_EDIT_MESSAGE,
{ATTR_MESSAGE: "mock message", ATTR_CHAT_ID: 12345, ATTR_MESSAGEID: 12345},
{ATTR_MESSAGE: "mock message", ATTR_CHAT_ID: 123456, ATTR_MESSAGEID: 12345},
blocking=True,
)
@ -822,7 +861,7 @@ async def test_edit_message(
await hass.services.async_call(
DOMAIN,
SERVICE_EDIT_CAPTION,
{ATTR_CAPTION: "mock caption", ATTR_CHAT_ID: 12345, ATTR_MESSAGEID: 12345},
{ATTR_CAPTION: "mock caption", ATTR_CHAT_ID: 123456, ATTR_MESSAGEID: 12345},
blocking=True,
)
@ -836,7 +875,7 @@ async def test_edit_message(
await hass.services.async_call(
DOMAIN,
SERVICE_EDIT_REPLYMARKUP,
{ATTR_KEYBOARD_INLINE: [], ATTR_CHAT_ID: 12345, ATTR_MESSAGEID: 12345},
{ATTR_KEYBOARD_INLINE: [], ATTR_CHAT_ID: 123456, ATTR_MESSAGEID: 12345},
blocking=True,
)