mirror of
https://github.com/home-assistant/core.git
synced 2025-07-23 13:17:32 +00:00
Add service validation for send file for Telegram bot integration (#146192)
* added service validation for send file * update strings * Apply suggestions from code review Co-authored-by: Martin Hjelmare <marhje52@gmail.com> * updated exception in tests * removed TypeError since it is not thrown --------- Co-authored-by: Martin Hjelmare <marhje52@gmail.com>
This commit is contained in:
parent
f0a2c4e30a
commit
8deec55204
@ -28,11 +28,12 @@ from homeassistant.config_entries import ConfigEntry
|
||||
from homeassistant.const import (
|
||||
ATTR_COMMAND,
|
||||
CONF_API_KEY,
|
||||
HTTP_BASIC_AUTHENTICATION,
|
||||
HTTP_BEARER_AUTHENTICATION,
|
||||
HTTP_DIGEST_AUTHENTICATION,
|
||||
)
|
||||
from homeassistant.core import Context, HomeAssistant
|
||||
from homeassistant.exceptions import ServiceValidationError
|
||||
from homeassistant.exceptions import HomeAssistantError, ServiceValidationError
|
||||
from homeassistant.helpers import issue_registry as ir
|
||||
from homeassistant.util.ssl import get_default_context, get_default_no_verify_context
|
||||
|
||||
@ -853,70 +854,119 @@ async def load_data(
|
||||
verify_ssl=None,
|
||||
):
|
||||
"""Load data into ByteIO/File container from a source."""
|
||||
try:
|
||||
if url is not None:
|
||||
# Load data from URL
|
||||
params: dict[str, Any] = {}
|
||||
headers = {}
|
||||
if authentication == HTTP_BEARER_AUTHENTICATION and password is not None:
|
||||
headers = {"Authorization": f"Bearer {password}"}
|
||||
elif username is not None and password is not None:
|
||||
if authentication == HTTP_DIGEST_AUTHENTICATION:
|
||||
params["auth"] = httpx.DigestAuth(username, password)
|
||||
else:
|
||||
params["auth"] = httpx.BasicAuth(username, password)
|
||||
if verify_ssl is not None:
|
||||
params["verify"] = verify_ssl
|
||||
if url is not None:
|
||||
# Load data from URL
|
||||
params: dict[str, Any] = {}
|
||||
headers = {}
|
||||
_validate_credentials_input(authentication, username, password)
|
||||
if authentication == HTTP_BEARER_AUTHENTICATION:
|
||||
headers = {"Authorization": f"Bearer {password}"}
|
||||
elif authentication == HTTP_DIGEST_AUTHENTICATION:
|
||||
params["auth"] = httpx.DigestAuth(username, password)
|
||||
elif authentication == HTTP_BASIC_AUTHENTICATION:
|
||||
params["auth"] = httpx.BasicAuth(username, password)
|
||||
|
||||
retry_num = 0
|
||||
async with httpx.AsyncClient(
|
||||
timeout=15, headers=headers, **params
|
||||
) as client:
|
||||
while retry_num < num_retries:
|
||||
if verify_ssl is not None:
|
||||
params["verify"] = verify_ssl
|
||||
|
||||
retry_num = 0
|
||||
async with httpx.AsyncClient(timeout=15, headers=headers, **params) as client:
|
||||
while retry_num < num_retries:
|
||||
try:
|
||||
req = await client.get(url)
|
||||
if req.status_code != 200:
|
||||
_LOGGER.warning(
|
||||
"Status code %s (retry #%s) loading %s",
|
||||
req.status_code,
|
||||
retry_num + 1,
|
||||
url,
|
||||
)
|
||||
else:
|
||||
data = io.BytesIO(req.content)
|
||||
if data.read():
|
||||
data.seek(0)
|
||||
data.name = url
|
||||
return data
|
||||
_LOGGER.warning(
|
||||
"Empty data (retry #%s) in %s)", retry_num + 1, url
|
||||
)
|
||||
retry_num += 1
|
||||
if retry_num < num_retries:
|
||||
await asyncio.sleep(
|
||||
1
|
||||
) # Add a sleep to allow other async operations to proceed
|
||||
_LOGGER.warning(
|
||||
"Can't load data in %s after %s retries", url, retry_num
|
||||
)
|
||||
elif filepath is not None:
|
||||
if hass.config.is_allowed_path(filepath):
|
||||
return await hass.async_add_executor_job(
|
||||
_read_file_as_bytesio, filepath
|
||||
)
|
||||
except (httpx.HTTPError, httpx.InvalidURL) as err:
|
||||
raise HomeAssistantError(
|
||||
f"Failed to load URL: {err!s}",
|
||||
translation_domain=DOMAIN,
|
||||
translation_key="failed_to_load_url",
|
||||
translation_placeholders={"error": str(err)},
|
||||
) from err
|
||||
|
||||
_LOGGER.warning("'%s' are not secure to load data from!", filepath)
|
||||
else:
|
||||
_LOGGER.warning("Can't load data. No data found in params!")
|
||||
if req.status_code != 200:
|
||||
_LOGGER.warning(
|
||||
"Status code %s (retry #%s) loading %s",
|
||||
req.status_code,
|
||||
retry_num + 1,
|
||||
url,
|
||||
)
|
||||
else:
|
||||
data = io.BytesIO(req.content)
|
||||
if data.read():
|
||||
data.seek(0)
|
||||
data.name = url
|
||||
return data
|
||||
_LOGGER.warning("Empty data (retry #%s) in %s)", retry_num + 1, url)
|
||||
retry_num += 1
|
||||
if retry_num < num_retries:
|
||||
await asyncio.sleep(
|
||||
1
|
||||
) # Add a sleep to allow other async operations to proceed
|
||||
raise HomeAssistantError(
|
||||
f"Failed to load URL: {req.status_code}",
|
||||
translation_domain=DOMAIN,
|
||||
translation_key="failed_to_load_url",
|
||||
translation_placeholders={"error": str(req.status_code)},
|
||||
)
|
||||
elif filepath is not None:
|
||||
if hass.config.is_allowed_path(filepath):
|
||||
return await hass.async_add_executor_job(_read_file_as_bytesio, filepath)
|
||||
|
||||
except (OSError, TypeError) as error:
|
||||
_LOGGER.error("Can't load data into ByteIO: %s", error)
|
||||
raise ServiceValidationError(
|
||||
"File path has not been configured in allowlist_external_dirs.",
|
||||
translation_domain=DOMAIN,
|
||||
translation_key="allowlist_external_dirs_error",
|
||||
)
|
||||
else:
|
||||
raise ServiceValidationError(
|
||||
"URL or File is required.",
|
||||
translation_domain=DOMAIN,
|
||||
translation_key="missing_input",
|
||||
translation_placeholders={"field": "URL or File"},
|
||||
)
|
||||
|
||||
return None
|
||||
|
||||
def _validate_credentials_input(
|
||||
authentication: str | None, username: str | None, password: str | None
|
||||
) -> None:
|
||||
if (
|
||||
authentication in (HTTP_BASIC_AUTHENTICATION, HTTP_DIGEST_AUTHENTICATION)
|
||||
and username is None
|
||||
):
|
||||
raise ServiceValidationError(
|
||||
"Username is required.",
|
||||
translation_domain=DOMAIN,
|
||||
translation_key="missing_input",
|
||||
translation_placeholders={"field": "Username"},
|
||||
)
|
||||
|
||||
if (
|
||||
authentication
|
||||
in (
|
||||
HTTP_BASIC_AUTHENTICATION,
|
||||
HTTP_BEARER_AUTHENTICATION,
|
||||
HTTP_BEARER_AUTHENTICATION,
|
||||
)
|
||||
and password is None
|
||||
):
|
||||
raise ServiceValidationError(
|
||||
"Password is required.",
|
||||
translation_domain=DOMAIN,
|
||||
translation_key="missing_input",
|
||||
translation_placeholders={"field": "Password"},
|
||||
)
|
||||
|
||||
|
||||
def _read_file_as_bytesio(file_path: str) -> io.BytesIO:
|
||||
"""Read a file and return it as a BytesIO object."""
|
||||
with open(file_path, "rb") as file:
|
||||
data = io.BytesIO(file.read())
|
||||
data.name = file_path
|
||||
return data
|
||||
try:
|
||||
with open(file_path, "rb") as file:
|
||||
data = io.BytesIO(file.read())
|
||||
data.name = file_path
|
||||
return data
|
||||
except OSError as err:
|
||||
raise HomeAssistantError(
|
||||
f"Failed to load file: {err!s}",
|
||||
translation_domain=DOMAIN,
|
||||
translation_key="failed_to_load_file",
|
||||
translation_placeholders={"error": str(err)},
|
||||
) from err
|
||||
|
@ -867,6 +867,18 @@
|
||||
},
|
||||
"missing_allowed_chat_ids": {
|
||||
"message": "No allowed chat IDs found. Please add allowed chat IDs for {bot_name}."
|
||||
},
|
||||
"missing_input": {
|
||||
"message": "{field} is required."
|
||||
},
|
||||
"failed_to_load_url": {
|
||||
"message": "Failed to load URL: {error}"
|
||||
},
|
||||
"allowlist_external_dirs_error": {
|
||||
"message": "File path has not been configured in allowlist_external_dirs."
|
||||
},
|
||||
"failed_to_load_file": {
|
||||
"message": "Failed to load file: {error}"
|
||||
}
|
||||
},
|
||||
"issues": {
|
||||
|
@ -6,7 +6,7 @@ from typing import Any
|
||||
from unittest.mock import AsyncMock, MagicMock, mock_open, patch
|
||||
|
||||
import pytest
|
||||
from telegram import Update, User
|
||||
from telegram import Update
|
||||
from telegram.error import (
|
||||
InvalidToken,
|
||||
NetworkError,
|
||||
@ -16,6 +16,7 @@ from telegram.error import (
|
||||
)
|
||||
|
||||
from homeassistant.components.telegram_bot import (
|
||||
ATTR_AUTHENTICATION,
|
||||
ATTR_CALLBACK_QUERY_ID,
|
||||
ATTR_CAPTION,
|
||||
ATTR_CHAT_ID,
|
||||
@ -27,9 +28,13 @@ from homeassistant.components.telegram_bot import (
|
||||
ATTR_MESSAGE_THREAD_ID,
|
||||
ATTR_MESSAGEID,
|
||||
ATTR_OPTIONS,
|
||||
ATTR_PASSWORD,
|
||||
ATTR_QUESTION,
|
||||
ATTR_STICKER_ID,
|
||||
ATTR_TARGET,
|
||||
ATTR_URL,
|
||||
ATTR_USERNAME,
|
||||
ATTR_VERIFY_SSL,
|
||||
CONF_CONFIG_ENTRY_ID,
|
||||
CONF_PLATFORM,
|
||||
DOMAIN,
|
||||
@ -53,11 +58,21 @@ from homeassistant.components.telegram_bot import (
|
||||
)
|
||||
from homeassistant.components.telegram_bot.webhooks import TELEGRAM_WEBHOOK_URL
|
||||
from homeassistant.config_entries import SOURCE_USER
|
||||
from homeassistant.const import CONF_API_KEY
|
||||
from homeassistant.const import (
|
||||
CONF_API_KEY,
|
||||
HTTP_BASIC_AUTHENTICATION,
|
||||
HTTP_BEARER_AUTHENTICATION,
|
||||
HTTP_DIGEST_AUTHENTICATION,
|
||||
)
|
||||
from homeassistant.core import Context, HomeAssistant
|
||||
from homeassistant.data_entry_flow import FlowResultType
|
||||
from homeassistant.exceptions import ConfigEntryAuthFailed, ServiceValidationError
|
||||
from homeassistant.exceptions import (
|
||||
ConfigEntryAuthFailed,
|
||||
HomeAssistantError,
|
||||
ServiceValidationError,
|
||||
)
|
||||
from homeassistant.setup import async_setup_component
|
||||
from homeassistant.util.file import write_utf8_file
|
||||
|
||||
from tests.common import MockConfigEntry, async_capture_events
|
||||
from tests.typing import ClientSessionGenerator
|
||||
@ -566,10 +581,7 @@ async def test_send_message_no_chat_id_error(
|
||||
CONF_API_KEY: "mock api key",
|
||||
}
|
||||
|
||||
with patch(
|
||||
"homeassistant.components.telegram_bot.config_flow.Bot.get_me",
|
||||
return_value=User(123456, "Testbot", True),
|
||||
):
|
||||
with patch("homeassistant.components.telegram_bot.config_flow.Bot.get_me"):
|
||||
result = await hass.config_entries.flow.async_init(
|
||||
DOMAIN,
|
||||
context={"source": SOURCE_USER},
|
||||
@ -740,8 +752,7 @@ async def test_answer_callback_query(
|
||||
await hass.async_block_till_done()
|
||||
|
||||
with patch(
|
||||
"homeassistant.components.telegram_bot.bot.TelegramNotificationService.answer_callback_query",
|
||||
AsyncMock(),
|
||||
"homeassistant.components.telegram_bot.bot.TelegramNotificationService.answer_callback_query"
|
||||
) as mock:
|
||||
await hass.services.async_call(
|
||||
DOMAIN,
|
||||
@ -782,3 +793,184 @@ async def test_leave_chat(
|
||||
|
||||
await hass.async_block_till_done()
|
||||
mock.assert_called_once()
|
||||
|
||||
|
||||
async def test_send_video(
|
||||
hass: HomeAssistant,
|
||||
mock_broadcast_config_entry: MockConfigEntry,
|
||||
mock_external_calls: None,
|
||||
) -> None:
|
||||
"""Test answer callback query."""
|
||||
mock_broadcast_config_entry.add_to_hass(hass)
|
||||
await hass.config_entries.async_setup(mock_broadcast_config_entry.entry_id)
|
||||
await hass.async_block_till_done()
|
||||
|
||||
# test: invalid file path
|
||||
|
||||
with pytest.raises(ServiceValidationError) as err:
|
||||
await hass.services.async_call(
|
||||
DOMAIN,
|
||||
SERVICE_SEND_VIDEO,
|
||||
{
|
||||
ATTR_FILE: "/mock/file",
|
||||
},
|
||||
blocking=True,
|
||||
)
|
||||
|
||||
await hass.async_block_till_done()
|
||||
assert (
|
||||
err.value.args[0]
|
||||
== "File path has not been configured in allowlist_external_dirs."
|
||||
)
|
||||
|
||||
# test: missing username input
|
||||
|
||||
with pytest.raises(ServiceValidationError) as err:
|
||||
await hass.services.async_call(
|
||||
DOMAIN,
|
||||
SERVICE_SEND_VIDEO,
|
||||
{
|
||||
ATTR_URL: "https://mock",
|
||||
ATTR_AUTHENTICATION: HTTP_DIGEST_AUTHENTICATION,
|
||||
ATTR_PASSWORD: "mock password",
|
||||
},
|
||||
blocking=True,
|
||||
)
|
||||
|
||||
await hass.async_block_till_done()
|
||||
assert err.value.args[0] == "Username is required."
|
||||
|
||||
# test: missing password input
|
||||
|
||||
with pytest.raises(ServiceValidationError) as err:
|
||||
await hass.services.async_call(
|
||||
DOMAIN,
|
||||
SERVICE_SEND_VIDEO,
|
||||
{
|
||||
ATTR_URL: "https://mock",
|
||||
ATTR_AUTHENTICATION: HTTP_BEARER_AUTHENTICATION,
|
||||
},
|
||||
blocking=True,
|
||||
)
|
||||
|
||||
await hass.async_block_till_done()
|
||||
assert err.value.args[0] == "Password is required."
|
||||
|
||||
# test: 404 error
|
||||
|
||||
with patch(
|
||||
"homeassistant.components.telegram_bot.bot.httpx.AsyncClient.get"
|
||||
) as mock_get:
|
||||
mock_get.return_value = AsyncMock(status_code=404, text="Success")
|
||||
|
||||
with pytest.raises(HomeAssistantError) as err:
|
||||
await hass.services.async_call(
|
||||
DOMAIN,
|
||||
SERVICE_SEND_VIDEO,
|
||||
{
|
||||
ATTR_URL: "https://mock",
|
||||
ATTR_AUTHENTICATION: HTTP_BASIC_AUTHENTICATION,
|
||||
ATTR_USERNAME: "mock username",
|
||||
ATTR_PASSWORD: "mock password",
|
||||
},
|
||||
blocking=True,
|
||||
)
|
||||
|
||||
await hass.async_block_till_done()
|
||||
assert mock_get.call_count > 0
|
||||
assert err.value.args[0] == "Failed to load URL: 404"
|
||||
|
||||
# test: invalid url
|
||||
|
||||
with pytest.raises(HomeAssistantError) as err:
|
||||
await hass.services.async_call(
|
||||
DOMAIN,
|
||||
SERVICE_SEND_VIDEO,
|
||||
{
|
||||
ATTR_URL: "invalid url",
|
||||
ATTR_VERIFY_SSL: True,
|
||||
ATTR_AUTHENTICATION: HTTP_BEARER_AUTHENTICATION,
|
||||
ATTR_PASSWORD: "mock password",
|
||||
},
|
||||
blocking=True,
|
||||
)
|
||||
|
||||
await hass.async_block_till_done()
|
||||
assert mock_get.call_count > 0
|
||||
assert (
|
||||
err.value.args[0]
|
||||
== "Failed to load URL: Request URL is missing an 'http://' or 'https://' protocol."
|
||||
)
|
||||
|
||||
# test: no url/file input
|
||||
|
||||
with pytest.raises(ServiceValidationError) as err:
|
||||
await hass.services.async_call(
|
||||
DOMAIN,
|
||||
SERVICE_SEND_VIDEO,
|
||||
{},
|
||||
blocking=True,
|
||||
)
|
||||
|
||||
await hass.async_block_till_done()
|
||||
assert err.value.args[0] == "URL or File is required."
|
||||
|
||||
# test: load file error (e.g. not found, permissions error)
|
||||
|
||||
hass.config.allowlist_external_dirs.add("/tmp/") # noqa: S108
|
||||
|
||||
with pytest.raises(HomeAssistantError) as err:
|
||||
await hass.services.async_call(
|
||||
DOMAIN,
|
||||
SERVICE_SEND_VIDEO,
|
||||
{
|
||||
ATTR_FILE: "/tmp/not-exists", # noqa: S108
|
||||
},
|
||||
blocking=True,
|
||||
)
|
||||
|
||||
await hass.async_block_till_done()
|
||||
assert (
|
||||
err.value.args[0]
|
||||
== "Failed to load file: [Errno 2] No such file or directory: '/tmp/not-exists'"
|
||||
)
|
||||
|
||||
# test: success with file
|
||||
write_utf8_file("/tmp/mock", "mock file contents") # noqa: S108
|
||||
|
||||
response = await hass.services.async_call(
|
||||
DOMAIN,
|
||||
SERVICE_SEND_VIDEO,
|
||||
{
|
||||
ATTR_FILE: "/tmp/mock", # noqa: S108
|
||||
},
|
||||
blocking=True,
|
||||
return_response=True,
|
||||
)
|
||||
|
||||
await hass.async_block_till_done()
|
||||
assert response["chats"][0]["message_id"] == 12345
|
||||
|
||||
# test: success with url
|
||||
|
||||
with patch(
|
||||
"homeassistant.components.telegram_bot.bot.httpx.AsyncClient.get"
|
||||
) as mock_get:
|
||||
mock_get.return_value = AsyncMock(status_code=200, content=b"mock content")
|
||||
|
||||
response = await hass.services.async_call(
|
||||
DOMAIN,
|
||||
SERVICE_SEND_VIDEO,
|
||||
{
|
||||
ATTR_URL: "https://mock",
|
||||
ATTR_AUTHENTICATION: HTTP_DIGEST_AUTHENTICATION,
|
||||
ATTR_USERNAME: "mock username",
|
||||
ATTR_PASSWORD: "mock password",
|
||||
},
|
||||
blocking=True,
|
||||
return_response=True,
|
||||
)
|
||||
|
||||
await hass.async_block_till_done()
|
||||
assert mock_get.call_count > 0
|
||||
assert response["chats"][0]["message_id"] == 12345
|
||||
|
Loading…
x
Reference in New Issue
Block a user