mirror of
https://github.com/home-assistant/core.git
synced 2025-07-23 13:17:32 +00:00
Add service validation for send file for Telegram bot integration (#146192)
* added service validation for send file * update strings * Apply suggestions from code review Co-authored-by: Martin Hjelmare <marhje52@gmail.com> * updated exception in tests * removed TypeError since it is not thrown --------- Co-authored-by: Martin Hjelmare <marhje52@gmail.com>
This commit is contained in:
parent
f0a2c4e30a
commit
8deec55204
@ -28,11 +28,12 @@ from homeassistant.config_entries import ConfigEntry
|
|||||||
from homeassistant.const import (
|
from homeassistant.const import (
|
||||||
ATTR_COMMAND,
|
ATTR_COMMAND,
|
||||||
CONF_API_KEY,
|
CONF_API_KEY,
|
||||||
|
HTTP_BASIC_AUTHENTICATION,
|
||||||
HTTP_BEARER_AUTHENTICATION,
|
HTTP_BEARER_AUTHENTICATION,
|
||||||
HTTP_DIGEST_AUTHENTICATION,
|
HTTP_DIGEST_AUTHENTICATION,
|
||||||
)
|
)
|
||||||
from homeassistant.core import Context, HomeAssistant
|
from homeassistant.core import Context, HomeAssistant
|
||||||
from homeassistant.exceptions import ServiceValidationError
|
from homeassistant.exceptions import HomeAssistantError, ServiceValidationError
|
||||||
from homeassistant.helpers import issue_registry as ir
|
from homeassistant.helpers import issue_registry as ir
|
||||||
from homeassistant.util.ssl import get_default_context, get_default_no_verify_context
|
from homeassistant.util.ssl import get_default_context, get_default_no_verify_context
|
||||||
|
|
||||||
@ -853,70 +854,119 @@ async def load_data(
|
|||||||
verify_ssl=None,
|
verify_ssl=None,
|
||||||
):
|
):
|
||||||
"""Load data into ByteIO/File container from a source."""
|
"""Load data into ByteIO/File container from a source."""
|
||||||
try:
|
if url is not None:
|
||||||
if url is not None:
|
# Load data from URL
|
||||||
# Load data from URL
|
params: dict[str, Any] = {}
|
||||||
params: dict[str, Any] = {}
|
headers = {}
|
||||||
headers = {}
|
_validate_credentials_input(authentication, username, password)
|
||||||
if authentication == HTTP_BEARER_AUTHENTICATION and password is not None:
|
if authentication == HTTP_BEARER_AUTHENTICATION:
|
||||||
headers = {"Authorization": f"Bearer {password}"}
|
headers = {"Authorization": f"Bearer {password}"}
|
||||||
elif username is not None and password is not None:
|
elif authentication == HTTP_DIGEST_AUTHENTICATION:
|
||||||
if authentication == HTTP_DIGEST_AUTHENTICATION:
|
params["auth"] = httpx.DigestAuth(username, password)
|
||||||
params["auth"] = httpx.DigestAuth(username, password)
|
elif authentication == HTTP_BASIC_AUTHENTICATION:
|
||||||
else:
|
params["auth"] = httpx.BasicAuth(username, password)
|
||||||
params["auth"] = httpx.BasicAuth(username, password)
|
|
||||||
if verify_ssl is not None:
|
|
||||||
params["verify"] = verify_ssl
|
|
||||||
|
|
||||||
retry_num = 0
|
if verify_ssl is not None:
|
||||||
async with httpx.AsyncClient(
|
params["verify"] = verify_ssl
|
||||||
timeout=15, headers=headers, **params
|
|
||||||
) as client:
|
retry_num = 0
|
||||||
while retry_num < num_retries:
|
async with httpx.AsyncClient(timeout=15, headers=headers, **params) as client:
|
||||||
|
while retry_num < num_retries:
|
||||||
|
try:
|
||||||
req = await client.get(url)
|
req = await client.get(url)
|
||||||
if req.status_code != 200:
|
except (httpx.HTTPError, httpx.InvalidURL) as err:
|
||||||
_LOGGER.warning(
|
raise HomeAssistantError(
|
||||||
"Status code %s (retry #%s) loading %s",
|
f"Failed to load URL: {err!s}",
|
||||||
req.status_code,
|
translation_domain=DOMAIN,
|
||||||
retry_num + 1,
|
translation_key="failed_to_load_url",
|
||||||
url,
|
translation_placeholders={"error": str(err)},
|
||||||
)
|
) from err
|
||||||
else:
|
|
||||||
data = io.BytesIO(req.content)
|
|
||||||
if data.read():
|
|
||||||
data.seek(0)
|
|
||||||
data.name = url
|
|
||||||
return data
|
|
||||||
_LOGGER.warning(
|
|
||||||
"Empty data (retry #%s) in %s)", retry_num + 1, url
|
|
||||||
)
|
|
||||||
retry_num += 1
|
|
||||||
if retry_num < num_retries:
|
|
||||||
await asyncio.sleep(
|
|
||||||
1
|
|
||||||
) # Add a sleep to allow other async operations to proceed
|
|
||||||
_LOGGER.warning(
|
|
||||||
"Can't load data in %s after %s retries", url, retry_num
|
|
||||||
)
|
|
||||||
elif filepath is not None:
|
|
||||||
if hass.config.is_allowed_path(filepath):
|
|
||||||
return await hass.async_add_executor_job(
|
|
||||||
_read_file_as_bytesio, filepath
|
|
||||||
)
|
|
||||||
|
|
||||||
_LOGGER.warning("'%s' are not secure to load data from!", filepath)
|
if req.status_code != 200:
|
||||||
else:
|
_LOGGER.warning(
|
||||||
_LOGGER.warning("Can't load data. No data found in params!")
|
"Status code %s (retry #%s) loading %s",
|
||||||
|
req.status_code,
|
||||||
|
retry_num + 1,
|
||||||
|
url,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
data = io.BytesIO(req.content)
|
||||||
|
if data.read():
|
||||||
|
data.seek(0)
|
||||||
|
data.name = url
|
||||||
|
return data
|
||||||
|
_LOGGER.warning("Empty data (retry #%s) in %s)", retry_num + 1, url)
|
||||||
|
retry_num += 1
|
||||||
|
if retry_num < num_retries:
|
||||||
|
await asyncio.sleep(
|
||||||
|
1
|
||||||
|
) # Add a sleep to allow other async operations to proceed
|
||||||
|
raise HomeAssistantError(
|
||||||
|
f"Failed to load URL: {req.status_code}",
|
||||||
|
translation_domain=DOMAIN,
|
||||||
|
translation_key="failed_to_load_url",
|
||||||
|
translation_placeholders={"error": str(req.status_code)},
|
||||||
|
)
|
||||||
|
elif filepath is not None:
|
||||||
|
if hass.config.is_allowed_path(filepath):
|
||||||
|
return await hass.async_add_executor_job(_read_file_as_bytesio, filepath)
|
||||||
|
|
||||||
except (OSError, TypeError) as error:
|
raise ServiceValidationError(
|
||||||
_LOGGER.error("Can't load data into ByteIO: %s", error)
|
"File path has not been configured in allowlist_external_dirs.",
|
||||||
|
translation_domain=DOMAIN,
|
||||||
|
translation_key="allowlist_external_dirs_error",
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise ServiceValidationError(
|
||||||
|
"URL or File is required.",
|
||||||
|
translation_domain=DOMAIN,
|
||||||
|
translation_key="missing_input",
|
||||||
|
translation_placeholders={"field": "URL or File"},
|
||||||
|
)
|
||||||
|
|
||||||
return None
|
|
||||||
|
def _validate_credentials_input(
|
||||||
|
authentication: str | None, username: str | None, password: str | None
|
||||||
|
) -> None:
|
||||||
|
if (
|
||||||
|
authentication in (HTTP_BASIC_AUTHENTICATION, HTTP_DIGEST_AUTHENTICATION)
|
||||||
|
and username is None
|
||||||
|
):
|
||||||
|
raise ServiceValidationError(
|
||||||
|
"Username is required.",
|
||||||
|
translation_domain=DOMAIN,
|
||||||
|
translation_key="missing_input",
|
||||||
|
translation_placeholders={"field": "Username"},
|
||||||
|
)
|
||||||
|
|
||||||
|
if (
|
||||||
|
authentication
|
||||||
|
in (
|
||||||
|
HTTP_BASIC_AUTHENTICATION,
|
||||||
|
HTTP_BEARER_AUTHENTICATION,
|
||||||
|
HTTP_BEARER_AUTHENTICATION,
|
||||||
|
)
|
||||||
|
and password is None
|
||||||
|
):
|
||||||
|
raise ServiceValidationError(
|
||||||
|
"Password is required.",
|
||||||
|
translation_domain=DOMAIN,
|
||||||
|
translation_key="missing_input",
|
||||||
|
translation_placeholders={"field": "Password"},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def _read_file_as_bytesio(file_path: str) -> io.BytesIO:
|
def _read_file_as_bytesio(file_path: str) -> io.BytesIO:
|
||||||
"""Read a file and return it as a BytesIO object."""
|
"""Read a file and return it as a BytesIO object."""
|
||||||
with open(file_path, "rb") as file:
|
try:
|
||||||
data = io.BytesIO(file.read())
|
with open(file_path, "rb") as file:
|
||||||
data.name = file_path
|
data = io.BytesIO(file.read())
|
||||||
return data
|
data.name = file_path
|
||||||
|
return data
|
||||||
|
except OSError as err:
|
||||||
|
raise HomeAssistantError(
|
||||||
|
f"Failed to load file: {err!s}",
|
||||||
|
translation_domain=DOMAIN,
|
||||||
|
translation_key="failed_to_load_file",
|
||||||
|
translation_placeholders={"error": str(err)},
|
||||||
|
) from err
|
||||||
|
@ -867,6 +867,18 @@
|
|||||||
},
|
},
|
||||||
"missing_allowed_chat_ids": {
|
"missing_allowed_chat_ids": {
|
||||||
"message": "No allowed chat IDs found. Please add allowed chat IDs for {bot_name}."
|
"message": "No allowed chat IDs found. Please add allowed chat IDs for {bot_name}."
|
||||||
|
},
|
||||||
|
"missing_input": {
|
||||||
|
"message": "{field} is required."
|
||||||
|
},
|
||||||
|
"failed_to_load_url": {
|
||||||
|
"message": "Failed to load URL: {error}"
|
||||||
|
},
|
||||||
|
"allowlist_external_dirs_error": {
|
||||||
|
"message": "File path has not been configured in allowlist_external_dirs."
|
||||||
|
},
|
||||||
|
"failed_to_load_file": {
|
||||||
|
"message": "Failed to load file: {error}"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"issues": {
|
"issues": {
|
||||||
|
@ -6,7 +6,7 @@ from typing import Any
|
|||||||
from unittest.mock import AsyncMock, MagicMock, mock_open, patch
|
from unittest.mock import AsyncMock, MagicMock, mock_open, patch
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from telegram import Update, User
|
from telegram import Update
|
||||||
from telegram.error import (
|
from telegram.error import (
|
||||||
InvalidToken,
|
InvalidToken,
|
||||||
NetworkError,
|
NetworkError,
|
||||||
@ -16,6 +16,7 @@ from telegram.error import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
from homeassistant.components.telegram_bot import (
|
from homeassistant.components.telegram_bot import (
|
||||||
|
ATTR_AUTHENTICATION,
|
||||||
ATTR_CALLBACK_QUERY_ID,
|
ATTR_CALLBACK_QUERY_ID,
|
||||||
ATTR_CAPTION,
|
ATTR_CAPTION,
|
||||||
ATTR_CHAT_ID,
|
ATTR_CHAT_ID,
|
||||||
@ -27,9 +28,13 @@ from homeassistant.components.telegram_bot import (
|
|||||||
ATTR_MESSAGE_THREAD_ID,
|
ATTR_MESSAGE_THREAD_ID,
|
||||||
ATTR_MESSAGEID,
|
ATTR_MESSAGEID,
|
||||||
ATTR_OPTIONS,
|
ATTR_OPTIONS,
|
||||||
|
ATTR_PASSWORD,
|
||||||
ATTR_QUESTION,
|
ATTR_QUESTION,
|
||||||
ATTR_STICKER_ID,
|
ATTR_STICKER_ID,
|
||||||
ATTR_TARGET,
|
ATTR_TARGET,
|
||||||
|
ATTR_URL,
|
||||||
|
ATTR_USERNAME,
|
||||||
|
ATTR_VERIFY_SSL,
|
||||||
CONF_CONFIG_ENTRY_ID,
|
CONF_CONFIG_ENTRY_ID,
|
||||||
CONF_PLATFORM,
|
CONF_PLATFORM,
|
||||||
DOMAIN,
|
DOMAIN,
|
||||||
@ -53,11 +58,21 @@ from homeassistant.components.telegram_bot import (
|
|||||||
)
|
)
|
||||||
from homeassistant.components.telegram_bot.webhooks import TELEGRAM_WEBHOOK_URL
|
from homeassistant.components.telegram_bot.webhooks import TELEGRAM_WEBHOOK_URL
|
||||||
from homeassistant.config_entries import SOURCE_USER
|
from homeassistant.config_entries import SOURCE_USER
|
||||||
from homeassistant.const import CONF_API_KEY
|
from homeassistant.const import (
|
||||||
|
CONF_API_KEY,
|
||||||
|
HTTP_BASIC_AUTHENTICATION,
|
||||||
|
HTTP_BEARER_AUTHENTICATION,
|
||||||
|
HTTP_DIGEST_AUTHENTICATION,
|
||||||
|
)
|
||||||
from homeassistant.core import Context, HomeAssistant
|
from homeassistant.core import Context, HomeAssistant
|
||||||
from homeassistant.data_entry_flow import FlowResultType
|
from homeassistant.data_entry_flow import FlowResultType
|
||||||
from homeassistant.exceptions import ConfigEntryAuthFailed, ServiceValidationError
|
from homeassistant.exceptions import (
|
||||||
|
ConfigEntryAuthFailed,
|
||||||
|
HomeAssistantError,
|
||||||
|
ServiceValidationError,
|
||||||
|
)
|
||||||
from homeassistant.setup import async_setup_component
|
from homeassistant.setup import async_setup_component
|
||||||
|
from homeassistant.util.file import write_utf8_file
|
||||||
|
|
||||||
from tests.common import MockConfigEntry, async_capture_events
|
from tests.common import MockConfigEntry, async_capture_events
|
||||||
from tests.typing import ClientSessionGenerator
|
from tests.typing import ClientSessionGenerator
|
||||||
@ -566,10 +581,7 @@ async def test_send_message_no_chat_id_error(
|
|||||||
CONF_API_KEY: "mock api key",
|
CONF_API_KEY: "mock api key",
|
||||||
}
|
}
|
||||||
|
|
||||||
with patch(
|
with patch("homeassistant.components.telegram_bot.config_flow.Bot.get_me"):
|
||||||
"homeassistant.components.telegram_bot.config_flow.Bot.get_me",
|
|
||||||
return_value=User(123456, "Testbot", True),
|
|
||||||
):
|
|
||||||
result = await hass.config_entries.flow.async_init(
|
result = await hass.config_entries.flow.async_init(
|
||||||
DOMAIN,
|
DOMAIN,
|
||||||
context={"source": SOURCE_USER},
|
context={"source": SOURCE_USER},
|
||||||
@ -740,8 +752,7 @@ async def test_answer_callback_query(
|
|||||||
await hass.async_block_till_done()
|
await hass.async_block_till_done()
|
||||||
|
|
||||||
with patch(
|
with patch(
|
||||||
"homeassistant.components.telegram_bot.bot.TelegramNotificationService.answer_callback_query",
|
"homeassistant.components.telegram_bot.bot.TelegramNotificationService.answer_callback_query"
|
||||||
AsyncMock(),
|
|
||||||
) as mock:
|
) as mock:
|
||||||
await hass.services.async_call(
|
await hass.services.async_call(
|
||||||
DOMAIN,
|
DOMAIN,
|
||||||
@ -782,3 +793,184 @@ async def test_leave_chat(
|
|||||||
|
|
||||||
await hass.async_block_till_done()
|
await hass.async_block_till_done()
|
||||||
mock.assert_called_once()
|
mock.assert_called_once()
|
||||||
|
|
||||||
|
|
||||||
|
async def test_send_video(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
mock_broadcast_config_entry: MockConfigEntry,
|
||||||
|
mock_external_calls: None,
|
||||||
|
) -> None:
|
||||||
|
"""Test answer callback query."""
|
||||||
|
mock_broadcast_config_entry.add_to_hass(hass)
|
||||||
|
await hass.config_entries.async_setup(mock_broadcast_config_entry.entry_id)
|
||||||
|
await hass.async_block_till_done()
|
||||||
|
|
||||||
|
# test: invalid file path
|
||||||
|
|
||||||
|
with pytest.raises(ServiceValidationError) as err:
|
||||||
|
await hass.services.async_call(
|
||||||
|
DOMAIN,
|
||||||
|
SERVICE_SEND_VIDEO,
|
||||||
|
{
|
||||||
|
ATTR_FILE: "/mock/file",
|
||||||
|
},
|
||||||
|
blocking=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
await hass.async_block_till_done()
|
||||||
|
assert (
|
||||||
|
err.value.args[0]
|
||||||
|
== "File path has not been configured in allowlist_external_dirs."
|
||||||
|
)
|
||||||
|
|
||||||
|
# test: missing username input
|
||||||
|
|
||||||
|
with pytest.raises(ServiceValidationError) as err:
|
||||||
|
await hass.services.async_call(
|
||||||
|
DOMAIN,
|
||||||
|
SERVICE_SEND_VIDEO,
|
||||||
|
{
|
||||||
|
ATTR_URL: "https://mock",
|
||||||
|
ATTR_AUTHENTICATION: HTTP_DIGEST_AUTHENTICATION,
|
||||||
|
ATTR_PASSWORD: "mock password",
|
||||||
|
},
|
||||||
|
blocking=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
await hass.async_block_till_done()
|
||||||
|
assert err.value.args[0] == "Username is required."
|
||||||
|
|
||||||
|
# test: missing password input
|
||||||
|
|
||||||
|
with pytest.raises(ServiceValidationError) as err:
|
||||||
|
await hass.services.async_call(
|
||||||
|
DOMAIN,
|
||||||
|
SERVICE_SEND_VIDEO,
|
||||||
|
{
|
||||||
|
ATTR_URL: "https://mock",
|
||||||
|
ATTR_AUTHENTICATION: HTTP_BEARER_AUTHENTICATION,
|
||||||
|
},
|
||||||
|
blocking=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
await hass.async_block_till_done()
|
||||||
|
assert err.value.args[0] == "Password is required."
|
||||||
|
|
||||||
|
# test: 404 error
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
"homeassistant.components.telegram_bot.bot.httpx.AsyncClient.get"
|
||||||
|
) as mock_get:
|
||||||
|
mock_get.return_value = AsyncMock(status_code=404, text="Success")
|
||||||
|
|
||||||
|
with pytest.raises(HomeAssistantError) as err:
|
||||||
|
await hass.services.async_call(
|
||||||
|
DOMAIN,
|
||||||
|
SERVICE_SEND_VIDEO,
|
||||||
|
{
|
||||||
|
ATTR_URL: "https://mock",
|
||||||
|
ATTR_AUTHENTICATION: HTTP_BASIC_AUTHENTICATION,
|
||||||
|
ATTR_USERNAME: "mock username",
|
||||||
|
ATTR_PASSWORD: "mock password",
|
||||||
|
},
|
||||||
|
blocking=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
await hass.async_block_till_done()
|
||||||
|
assert mock_get.call_count > 0
|
||||||
|
assert err.value.args[0] == "Failed to load URL: 404"
|
||||||
|
|
||||||
|
# test: invalid url
|
||||||
|
|
||||||
|
with pytest.raises(HomeAssistantError) as err:
|
||||||
|
await hass.services.async_call(
|
||||||
|
DOMAIN,
|
||||||
|
SERVICE_SEND_VIDEO,
|
||||||
|
{
|
||||||
|
ATTR_URL: "invalid url",
|
||||||
|
ATTR_VERIFY_SSL: True,
|
||||||
|
ATTR_AUTHENTICATION: HTTP_BEARER_AUTHENTICATION,
|
||||||
|
ATTR_PASSWORD: "mock password",
|
||||||
|
},
|
||||||
|
blocking=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
await hass.async_block_till_done()
|
||||||
|
assert mock_get.call_count > 0
|
||||||
|
assert (
|
||||||
|
err.value.args[0]
|
||||||
|
== "Failed to load URL: Request URL is missing an 'http://' or 'https://' protocol."
|
||||||
|
)
|
||||||
|
|
||||||
|
# test: no url/file input
|
||||||
|
|
||||||
|
with pytest.raises(ServiceValidationError) as err:
|
||||||
|
await hass.services.async_call(
|
||||||
|
DOMAIN,
|
||||||
|
SERVICE_SEND_VIDEO,
|
||||||
|
{},
|
||||||
|
blocking=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
await hass.async_block_till_done()
|
||||||
|
assert err.value.args[0] == "URL or File is required."
|
||||||
|
|
||||||
|
# test: load file error (e.g. not found, permissions error)
|
||||||
|
|
||||||
|
hass.config.allowlist_external_dirs.add("/tmp/") # noqa: S108
|
||||||
|
|
||||||
|
with pytest.raises(HomeAssistantError) as err:
|
||||||
|
await hass.services.async_call(
|
||||||
|
DOMAIN,
|
||||||
|
SERVICE_SEND_VIDEO,
|
||||||
|
{
|
||||||
|
ATTR_FILE: "/tmp/not-exists", # noqa: S108
|
||||||
|
},
|
||||||
|
blocking=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
await hass.async_block_till_done()
|
||||||
|
assert (
|
||||||
|
err.value.args[0]
|
||||||
|
== "Failed to load file: [Errno 2] No such file or directory: '/tmp/not-exists'"
|
||||||
|
)
|
||||||
|
|
||||||
|
# test: success with file
|
||||||
|
write_utf8_file("/tmp/mock", "mock file contents") # noqa: S108
|
||||||
|
|
||||||
|
response = await hass.services.async_call(
|
||||||
|
DOMAIN,
|
||||||
|
SERVICE_SEND_VIDEO,
|
||||||
|
{
|
||||||
|
ATTR_FILE: "/tmp/mock", # noqa: S108
|
||||||
|
},
|
||||||
|
blocking=True,
|
||||||
|
return_response=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
await hass.async_block_till_done()
|
||||||
|
assert response["chats"][0]["message_id"] == 12345
|
||||||
|
|
||||||
|
# test: success with url
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
"homeassistant.components.telegram_bot.bot.httpx.AsyncClient.get"
|
||||||
|
) as mock_get:
|
||||||
|
mock_get.return_value = AsyncMock(status_code=200, content=b"mock content")
|
||||||
|
|
||||||
|
response = await hass.services.async_call(
|
||||||
|
DOMAIN,
|
||||||
|
SERVICE_SEND_VIDEO,
|
||||||
|
{
|
||||||
|
ATTR_URL: "https://mock",
|
||||||
|
ATTR_AUTHENTICATION: HTTP_DIGEST_AUTHENTICATION,
|
||||||
|
ATTR_USERNAME: "mock username",
|
||||||
|
ATTR_PASSWORD: "mock password",
|
||||||
|
},
|
||||||
|
blocking=True,
|
||||||
|
return_response=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
await hass.async_block_till_done()
|
||||||
|
assert mock_get.call_count > 0
|
||||||
|
assert response["chats"][0]["message_id"] == 12345
|
||||||
|
Loading…
x
Reference in New Issue
Block a user