Fix Telegram bot default target when sending messages (#147470)

* handle targets

* updated error message

* validate chat id for single target

* add validation for chat id

* handle empty target

* handle empty target
This commit is contained in:
hanwg 2025-06-26 22:43:09 +08:00 committed by GitHub
parent 40f553a007
commit 68924d23ab
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 102 additions and 37 deletions

View File

@ -29,6 +29,7 @@ from homeassistant.core import (
from homeassistant.exceptions import ( from homeassistant.exceptions import (
ConfigEntryAuthFailed, ConfigEntryAuthFailed,
ConfigEntryNotReady, ConfigEntryNotReady,
HomeAssistantError,
ServiceValidationError, ServiceValidationError,
) )
from homeassistant.helpers import config_validation as cv from homeassistant.helpers import config_validation as cv
@ -390,9 +391,7 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
elif msgtype == SERVICE_DELETE_MESSAGE: elif msgtype == SERVICE_DELETE_MESSAGE:
await notify_service.delete_message(context=service.context, **kwargs) await notify_service.delete_message(context=service.context, **kwargs)
elif msgtype == SERVICE_LEAVE_CHAT: elif msgtype == SERVICE_LEAVE_CHAT:
messages = await notify_service.leave_chat( await notify_service.leave_chat(context=service.context, **kwargs)
context=service.context, **kwargs
)
elif msgtype == SERVICE_SET_MESSAGE_REACTION: elif msgtype == SERVICE_SET_MESSAGE_REACTION:
await notify_service.set_message_reaction(context=service.context, **kwargs) await notify_service.set_message_reaction(context=service.context, **kwargs)
else: else:
@ -400,12 +399,29 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
msgtype, context=service.context, **kwargs msgtype, context=service.context, **kwargs
) )
if service.return_response and messages: if service.return_response and messages is not None:
target: list[int] | None = service.data.get(ATTR_TARGET)
if not target:
target = notify_service.get_target_chat_ids(None)
failed_chat_ids = [chat_id for chat_id in target if chat_id not in messages]
if failed_chat_ids:
raise HomeAssistantError(
f"Failed targets: {failed_chat_ids}",
translation_domain=DOMAIN,
translation_key="failed_chat_ids",
translation_placeholders={
"chat_ids": ", ".join([str(i) for i in failed_chat_ids]),
"bot_name": config_entry.title,
},
)
return { return {
"chats": [ "chats": [
{"chat_id": cid, "message_id": mid} for cid, mid in messages.items() {"chat_id": cid, "message_id": mid} for cid, mid in messages.items()
] ]
} }
return None return None
# Register notification services # Register notification services

View File

@ -287,24 +287,32 @@ class TelegramNotificationService:
inline_message_id = msg_data["inline_message_id"] inline_message_id = msg_data["inline_message_id"]
return message_id, inline_message_id return message_id, inline_message_id
def _get_target_chat_ids(self, target: Any) -> list[int]: def get_target_chat_ids(self, target: int | list[int] | None) -> list[int]:
"""Validate chat_id targets or return default target (first). """Validate chat_id targets or return default target (first).
:param target: optional list of integers ([12234, -12345]) :param target: optional list of integers ([12234, -12345])
:return list of chat_id targets (integers) :return list of chat_id targets (integers)
""" """
allowed_chat_ids: list[int] = self._get_allowed_chat_ids() allowed_chat_ids: list[int] = self._get_allowed_chat_ids()
default_user: int = allowed_chat_ids[0]
if target is not None: if target is None:
if isinstance(target, int): return [allowed_chat_ids[0]]
target = [target]
chat_ids = [t for t in target if t in allowed_chat_ids] chat_ids = [target] if isinstance(target, int) else target
if chat_ids: valid_chat_ids = [
return chat_ids chat_id for chat_id in chat_ids if chat_id in allowed_chat_ids
_LOGGER.warning( ]
"Disallowed targets: %s, using default: %s", target, default_user if not valid_chat_ids:
raise ServiceValidationError(
"Invalid chat IDs",
translation_domain=DOMAIN,
translation_key="invalid_chat_ids",
translation_placeholders={
"chat_ids": ", ".join(str(chat_id) for chat_id in chat_ids),
"bot_name": self.config.title,
},
) )
return [default_user] return valid_chat_ids
def _get_msg_kwargs(self, data: dict[str, Any]) -> dict[str, Any]: def _get_msg_kwargs(self, data: dict[str, Any]) -> dict[str, Any]:
"""Get parameters in message data kwargs.""" """Get parameters in message data kwargs."""
@ -414,9 +422,9 @@ class TelegramNotificationService:
"""Send one message.""" """Send one message."""
try: try:
out = await func_send(*args_msg, **kwargs_msg) out = await func_send(*args_msg, **kwargs_msg)
if not isinstance(out, bool) and hasattr(out, ATTR_MESSAGEID): if isinstance(out, Message):
chat_id = out.chat_id chat_id = out.chat_id
message_id = out[ATTR_MESSAGEID] message_id = out.message_id
self._last_message_id[chat_id] = message_id self._last_message_id[chat_id] = message_id
_LOGGER.debug( _LOGGER.debug(
"Last message ID: %s (from chat_id %s)", "Last message ID: %s (from chat_id %s)",
@ -424,7 +432,7 @@ class TelegramNotificationService:
chat_id, chat_id,
) )
event_data = { event_data: dict[str, Any] = {
ATTR_CHAT_ID: chat_id, ATTR_CHAT_ID: chat_id,
ATTR_MESSAGEID: message_id, ATTR_MESSAGEID: message_id,
} }
@ -437,10 +445,6 @@ class TelegramNotificationService:
self.hass.bus.async_fire( self.hass.bus.async_fire(
EVENT_TELEGRAM_SENT, event_data, context=context EVENT_TELEGRAM_SENT, event_data, context=context
) )
elif not isinstance(out, bool):
_LOGGER.warning(
"Update last message: out_type:%s, out=%s", type(out), out
)
except TelegramError as exc: except TelegramError as exc:
_LOGGER.error( _LOGGER.error(
"%s: %s. Args: %s, kwargs: %s", msg_error, exc, args_msg, kwargs_msg "%s: %s. Args: %s, kwargs: %s", msg_error, exc, args_msg, kwargs_msg
@ -460,7 +464,7 @@ class TelegramNotificationService:
text = f"{title}\n{message}" if title else message text = f"{title}\n{message}" if title else message
params = self._get_msg_kwargs(kwargs) params = self._get_msg_kwargs(kwargs)
msg_ids = {} msg_ids = {}
for chat_id in self._get_target_chat_ids(target): for chat_id in self.get_target_chat_ids(target):
_LOGGER.debug("Send message in chat ID %s with params: %s", chat_id, params) _LOGGER.debug("Send message in chat ID %s with params: %s", chat_id, params)
msg = await self._send_msg( msg = await self._send_msg(
self.bot.send_message, self.bot.send_message,
@ -488,7 +492,7 @@ class TelegramNotificationService:
**kwargs: dict[str, Any], **kwargs: dict[str, Any],
) -> bool: ) -> bool:
"""Delete a previously sent message.""" """Delete a previously sent message."""
chat_id = self._get_target_chat_ids(chat_id)[0] chat_id = self.get_target_chat_ids(chat_id)[0]
message_id, _ = self._get_msg_ids(kwargs, chat_id) message_id, _ = self._get_msg_ids(kwargs, chat_id)
_LOGGER.debug("Delete message %s in chat ID %s", message_id, chat_id) _LOGGER.debug("Delete message %s in chat ID %s", message_id, chat_id)
deleted: bool = await self._send_msg( deleted: bool = await self._send_msg(
@ -513,7 +517,7 @@ class TelegramNotificationService:
**kwargs: dict[str, Any], **kwargs: dict[str, Any],
) -> Any: ) -> Any:
"""Edit a previously sent message.""" """Edit a previously sent message."""
chat_id = self._get_target_chat_ids(chat_id)[0] chat_id = self.get_target_chat_ids(chat_id)[0]
message_id, inline_message_id = self._get_msg_ids(kwargs, chat_id) message_id, inline_message_id = self._get_msg_ids(kwargs, chat_id)
params = self._get_msg_kwargs(kwargs) params = self._get_msg_kwargs(kwargs)
_LOGGER.debug( _LOGGER.debug(
@ -620,7 +624,7 @@ class TelegramNotificationService:
msg_ids = {} msg_ids = {}
if file_content: if file_content:
for chat_id in self._get_target_chat_ids(target): for chat_id in self.get_target_chat_ids(target):
_LOGGER.debug("Sending file to chat ID %s", chat_id) _LOGGER.debug("Sending file to chat ID %s", chat_id)
if file_type == SERVICE_SEND_PHOTO: if file_type == SERVICE_SEND_PHOTO:
@ -738,7 +742,7 @@ class TelegramNotificationService:
msg_ids = {} msg_ids = {}
if stickerid: if stickerid:
for chat_id in self._get_target_chat_ids(target): for chat_id in self.get_target_chat_ids(target):
msg = await self._send_msg( msg = await self._send_msg(
self.bot.send_sticker, self.bot.send_sticker,
"Error sending sticker", "Error sending sticker",
@ -769,7 +773,7 @@ class TelegramNotificationService:
longitude = float(longitude) longitude = float(longitude)
params = self._get_msg_kwargs(kwargs) params = self._get_msg_kwargs(kwargs)
msg_ids = {} msg_ids = {}
for chat_id in self._get_target_chat_ids(target): for chat_id in self.get_target_chat_ids(target):
_LOGGER.debug( _LOGGER.debug(
"Send location %s/%s to chat ID %s", latitude, longitude, chat_id "Send location %s/%s to chat ID %s", latitude, longitude, chat_id
) )
@ -803,7 +807,7 @@ class TelegramNotificationService:
params = self._get_msg_kwargs(kwargs) params = self._get_msg_kwargs(kwargs)
openperiod = kwargs.get(ATTR_OPEN_PERIOD) openperiod = kwargs.get(ATTR_OPEN_PERIOD)
msg_ids = {} msg_ids = {}
for chat_id in self._get_target_chat_ids(target): for chat_id in self.get_target_chat_ids(target):
_LOGGER.debug("Send poll '%s' to chat ID %s", question, chat_id) _LOGGER.debug("Send poll '%s' to chat ID %s", question, chat_id)
msg = await self._send_msg( msg = await self._send_msg(
self.bot.send_poll, self.bot.send_poll,
@ -826,12 +830,12 @@ class TelegramNotificationService:
async def leave_chat( async def leave_chat(
self, self,
chat_id: Any = None, chat_id: int | None = None,
context: Context | None = None, context: Context | None = None,
**kwargs: dict[str, Any], **kwargs: dict[str, Any],
) -> Any: ) -> Any:
"""Remove bot from chat.""" """Remove bot from chat."""
chat_id = self._get_target_chat_ids(chat_id)[0] chat_id = self.get_target_chat_ids(chat_id)[0]
_LOGGER.debug("Leave from chat ID %s", chat_id) _LOGGER.debug("Leave from chat ID %s", chat_id)
return await self._send_msg( return await self._send_msg(
self.bot.leave_chat, "Error leaving chat", None, chat_id, context=context self.bot.leave_chat, "Error leaving chat", None, chat_id, context=context
@ -839,14 +843,14 @@ class TelegramNotificationService:
async def set_message_reaction( async def set_message_reaction(
self, self,
chat_id: int,
reaction: str, reaction: str,
chat_id: int | None = None,
is_big: bool = False, is_big: bool = False,
context: Context | None = None, context: Context | None = None,
**kwargs: dict[str, Any], **kwargs: dict[str, Any],
) -> None: ) -> None:
"""Set the bot's reaction for a given message.""" """Set the bot's reaction for a given message."""
chat_id = self._get_target_chat_ids(chat_id)[0] chat_id = self.get_target_chat_ids(chat_id)[0]
message_id, _ = self._get_msg_ids(kwargs, chat_id) message_id, _ = self._get_msg_ids(kwargs, chat_id)
params = self._get_msg_kwargs(kwargs) params = self._get_msg_kwargs(kwargs)

View File

@ -895,6 +895,12 @@
"missing_allowed_chat_ids": { "missing_allowed_chat_ids": {
"message": "No allowed chat IDs found. Please add allowed chat IDs for {bot_name}." "message": "No allowed chat IDs found. Please add allowed chat IDs for {bot_name}."
}, },
"invalid_chat_ids": {
"message": "Invalid chat IDs: {chat_ids}. Please configure the chat IDs for {bot_name}."
},
"failed_chat_ids": {
"message": "Failed targets: {chat_ids}. Please verify that the chat IDs for {bot_name} have been configured."
},
"missing_input": { "missing_input": {
"message": "{field} is required." "message": "{field} is required."
}, },

View File

@ -677,13 +677,35 @@ async def test_send_message_with_config_entry(
await hass.config_entries.async_setup(mock_broadcast_config_entry.entry_id) await hass.config_entries.async_setup(mock_broadcast_config_entry.entry_id)
await hass.async_block_till_done() await hass.async_block_till_done()
# test: send message to invalid chat id
with pytest.raises(HomeAssistantError) as err:
response = await hass.services.async_call(
DOMAIN,
SERVICE_SEND_MESSAGE,
{
CONF_CONFIG_ENTRY_ID: mock_broadcast_config_entry.entry_id,
ATTR_MESSAGE: "mock message",
ATTR_TARGET: [123456, 1],
},
blocking=True,
return_response=True,
)
await hass.async_block_till_done()
assert err.value.translation_key == "failed_chat_ids"
assert err.value.translation_placeholders["chat_ids"] == "1"
assert err.value.translation_placeholders["bot_name"] == "Mock Title"
# test: send message to valid chat id
response = await hass.services.async_call( response = await hass.services.async_call(
DOMAIN, DOMAIN,
SERVICE_SEND_MESSAGE, SERVICE_SEND_MESSAGE,
{ {
CONF_CONFIG_ENTRY_ID: mock_broadcast_config_entry.entry_id, CONF_CONFIG_ENTRY_ID: mock_broadcast_config_entry.entry_id,
ATTR_MESSAGE: "mock message", ATTR_MESSAGE: "mock message",
ATTR_TARGET: 1, ATTR_TARGET: 123456,
}, },
blocking=True, blocking=True,
return_response=True, return_response=True,
@ -767,6 +789,23 @@ async def test_delete_message(
await hass.config_entries.async_setup(mock_broadcast_config_entry.entry_id) await hass.config_entries.async_setup(mock_broadcast_config_entry.entry_id)
await hass.async_block_till_done() await hass.async_block_till_done()
# test: delete message with invalid chat id
with pytest.raises(ServiceValidationError) as err:
await hass.services.async_call(
DOMAIN,
SERVICE_DELETE_MESSAGE,
{ATTR_CHAT_ID: 1, ATTR_MESSAGEID: "last"},
blocking=True,
)
await hass.async_block_till_done()
assert err.value.translation_key == "invalid_chat_ids"
assert err.value.translation_placeholders["chat_ids"] == "1"
assert err.value.translation_placeholders["bot_name"] == "Mock Title"
# test: delete message with valid chat id
response = await hass.services.async_call( response = await hass.services.async_call(
DOMAIN, DOMAIN,
SERVICE_SEND_MESSAGE, SERVICE_SEND_MESSAGE,
@ -808,7 +847,7 @@ async def test_edit_message(
await hass.services.async_call( await hass.services.async_call(
DOMAIN, DOMAIN,
SERVICE_EDIT_MESSAGE, SERVICE_EDIT_MESSAGE,
{ATTR_MESSAGE: "mock message", ATTR_CHAT_ID: 12345, ATTR_MESSAGEID: 12345}, {ATTR_MESSAGE: "mock message", ATTR_CHAT_ID: 123456, ATTR_MESSAGEID: 12345},
blocking=True, blocking=True,
) )
@ -822,7 +861,7 @@ async def test_edit_message(
await hass.services.async_call( await hass.services.async_call(
DOMAIN, DOMAIN,
SERVICE_EDIT_CAPTION, SERVICE_EDIT_CAPTION,
{ATTR_CAPTION: "mock caption", ATTR_CHAT_ID: 12345, ATTR_MESSAGEID: 12345}, {ATTR_CAPTION: "mock caption", ATTR_CHAT_ID: 123456, ATTR_MESSAGEID: 12345},
blocking=True, blocking=True,
) )
@ -836,7 +875,7 @@ async def test_edit_message(
await hass.services.async_call( await hass.services.async_call(
DOMAIN, DOMAIN,
SERVICE_EDIT_REPLYMARKUP, SERVICE_EDIT_REPLYMARKUP,
{ATTR_KEYBOARD_INLINE: [], ATTR_CHAT_ID: 12345, ATTR_MESSAGEID: 12345}, {ATTR_KEYBOARD_INLINE: [], ATTR_CHAT_ID: 123456, ATTR_MESSAGEID: 12345},
blocking=True, blocking=True,
) )