diff --git a/homeassistant/components/python_script/__init__.py b/homeassistant/components/python_script/__init__.py index 7b49a6b1b0d..c10c6de5b3f 100644 --- a/homeassistant/components/python_script/__init__.py +++ b/homeassistant/components/python_script/__init__.py @@ -2,8 +2,11 @@ import datetime import glob import logging +from numbers import Number +import operator import os import time +from typing import Any from RestrictedPython import ( compile_restricted_exec, @@ -146,6 +149,36 @@ def discover_scripts(hass): async_set_service_schema(hass, DOMAIN, name, service_desc) +IOPERATOR_TO_OPERATOR = { + "%=": operator.mod, + "&=": operator.and_, + "**=": operator.pow, + "*=": operator.mul, + "+=": operator.add, + "-=": operator.sub, + "//=": operator.floordiv, + "/=": operator.truediv, + "<<=": operator.lshift, + ">>=": operator.rshift, + "@=": operator.matmul, + "^=": operator.xor, + "|=": operator.or_, +} + + +def guarded_inplacevar(op: str, target: Any, operand: Any) -> Any: + """Implement augmented-assign (+=, -=, etc.) operators for restricted code. + + See RestrictedPython's `visit_AugAssign` for details. + """ + if not isinstance(target, (list, Number, str)): + raise ScriptError(f"The {op!r} operation is not allowed on a {type(target)}") + op_fun = IOPERATOR_TO_OPERATOR.get(op) + if not op_fun: + raise ScriptError(f"The {op!r} operation is not allowed") + return op_fun(target, operand) + + @bind_hass def execute_script(hass, name, data=None, return_response=False): """Execute a script.""" @@ -223,6 +256,7 @@ def execute(hass, filename, source, data=None, return_response=False): "_getitem_": default_guarded_getitem, "_iter_unpack_sequence_": guarded_iter_unpack_sequence, "_unpack_sequence_": guarded_unpack_sequence, + "_inplacevar_": guarded_inplacevar, "hass": hass, "data": data or {}, "logger": logger, diff --git a/tests/components/python_script/test_init.py b/tests/components/python_script/test_init.py index ee7fedee0d5..0692bdfd816 100644 --- a/tests/components/python_script/test_init.py +++ b/tests/components/python_script/test_init.py @@ -596,3 +596,48 @@ output = f"hello {data.get('name', 'World')}" blocking=True, return_response=True, ) + + +async def test_augmented_assignment_operations(hass: HomeAssistant) -> None: + """Test that augmented assignment operations work.""" + source = """ +a = 10 +a += 20 +a *= 5 +a -= 8 +b = "foo" +b += "bar" +b *= 2 +c = [] +c += [1, 2, 3] +c *= 2 +hass.states.set('hello.a', a) +hass.states.set('hello.b', b) +hass.states.set('hello.c', c) + """ + + hass.async_add_executor_job(execute, hass, "aug_assign.py", source, {}) + await hass.async_block_till_done() + + assert hass.states.get("hello.a").state == str(((10 + 20) * 5) - 8) + assert hass.states.get("hello.b").state == ("foo" + "bar") * 2 + assert hass.states.get("hello.c").state == str([1, 2, 3] * 2) + + +@pytest.mark.parametrize( + ("case", "error"), + [ + pytest.param( + "d = datetime.date(2024, 1, 1); d += 5", + "The '+=' operation is not allowed", + id="datetime.date", + ), + ], +) +async def test_prohibited_augmented_assignment_operations( + hass: HomeAssistant, case: str, error: str, caplog +) -> None: + """Test that prohibited augmented assignment operations raise an error.""" + hass.async_add_executor_job(execute, hass, "aug_assign_prohibited.py", case, {}) + await hass.async_block_till_done() + assert error in caplog.text