mirror of
https://github.com/home-assistant/core.git
synced 2025-07-23 21:27:38 +00:00
Restrict Python Script (#8053)
This commit is contained in:
parent
c478f2c7d0
commit
74cc675a38
@ -5,6 +5,9 @@ import logging
|
||||
|
||||
import voluptuous as vol
|
||||
|
||||
from homeassistant.exceptions import HomeAssistantError
|
||||
from homeassistant.util import sanitize_filename
|
||||
|
||||
DOMAIN = 'python_script'
|
||||
REQUIREMENTS = ['restrictedpython==4.0a2']
|
||||
FOLDER = 'python_scripts'
|
||||
@ -14,6 +17,18 @@ CONFIG_SCHEMA = vol.Schema({
|
||||
DOMAIN: vol.Schema(dict)
|
||||
}, extra=vol.ALLOW_EXTRA)
|
||||
|
||||
ALLOWED_HASS = set(['bus', 'services', 'states'])
|
||||
ALLOWED_EVENTBUS = set(['fire'])
|
||||
ALLOWED_STATEMACHINE = set(['entity_ids', 'all', 'get', 'is_state',
|
||||
'is_state_attr', 'remove', 'set'])
|
||||
ALLOWED_SERVICEREGISTRY = set(['services', 'has_service', 'call'])
|
||||
|
||||
|
||||
class ScriptError(HomeAssistantError):
|
||||
"""When a script error occurs."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
def setup(hass, config):
|
||||
"""Initialize the python_script component."""
|
||||
@ -23,21 +38,27 @@ def setup(hass, config):
|
||||
_LOGGER.warning('Folder %s not found in config folder', FOLDER)
|
||||
return False
|
||||
|
||||
def service_handler(call):
|
||||
def python_script_service_handler(call):
|
||||
"""Handle python script service calls."""
|
||||
filename = '{}.py'.format(call.service)
|
||||
with open(hass.config.path(FOLDER, filename)) as fil:
|
||||
execute(hass, filename, fil.read(), call.data)
|
||||
execute_script(hass, call.service, call.data)
|
||||
|
||||
for fil in glob.iglob(os.path.join(path, '*.py')):
|
||||
name = os.path.splitext(os.path.basename(fil))[0]
|
||||
hass.services.register(DOMAIN, name, service_handler)
|
||||
hass.services.register(DOMAIN, name, python_script_service_handler)
|
||||
|
||||
return True
|
||||
|
||||
|
||||
def execute(hass, filename, source, data):
|
||||
def execute_script(hass, name, data=None):
|
||||
"""Execute a script."""
|
||||
filename = '{}.py'.format(name)
|
||||
with open(hass.config.path(FOLDER, sanitize_filename(filename))) as fil:
|
||||
source = fil.read()
|
||||
execute(hass, filename, source, data)
|
||||
|
||||
|
||||
def execute(hass, filename, source, data=None):
|
||||
"""Execute Python source."""
|
||||
from RestrictedPython import compile_restricted_exec
|
||||
from RestrictedPython.Guards import safe_builtins, full_write_guard
|
||||
|
||||
@ -52,24 +73,41 @@ def execute(hass, filename, source, data):
|
||||
_LOGGER.warning('Warning loading script %s: %s', filename,
|
||||
', '.join(compiled.warnings))
|
||||
|
||||
def protected_getattr(obj, name, default=None):
|
||||
"""Restricted method to get attributes."""
|
||||
# pylint: disable=too-many-boolean-expressions
|
||||
if name.startswith('async_'):
|
||||
raise ScriptError('Not allowed to access async methods')
|
||||
elif (obj is hass and name not in ALLOWED_HASS or
|
||||
obj is hass.bus and name not in ALLOWED_EVENTBUS or
|
||||
obj is hass.states and name not in ALLOWED_STATEMACHINE or
|
||||
obj is hass.services and name not in ALLOWED_SERVICEREGISTRY):
|
||||
raise ScriptError('Not allowed to access {}.{}'.format(
|
||||
obj.__class__.__name__, name))
|
||||
|
||||
return getattr(obj, name, default)
|
||||
|
||||
restricted_globals = {
|
||||
'__builtins__': safe_builtins,
|
||||
'_print_': StubPrinter,
|
||||
'_getattr_': getattr,
|
||||
'_getattr_': protected_getattr,
|
||||
'_write_': full_write_guard,
|
||||
}
|
||||
logger = logging.getLogger('{}.{}'.format(__name__, filename))
|
||||
local = {
|
||||
'hass': hass,
|
||||
'data': data,
|
||||
'logger': logging.getLogger('{}.{}'.format(__name__, filename))
|
||||
'data': data or {},
|
||||
'logger': logger
|
||||
}
|
||||
|
||||
try:
|
||||
_LOGGER.info('Executing %s: %s', filename, data)
|
||||
# pylint: disable=exec-used
|
||||
exec(compiled.code, restricted_globals, local)
|
||||
except ScriptError as err:
|
||||
logger.error('Error executing script: %s', err)
|
||||
except Exception as err: # pylint: disable=broad-except
|
||||
_LOGGER.exception('Error executing script %s: %s', filename, err)
|
||||
logger.exception('Error executing script: %s', err)
|
||||
|
||||
|
||||
class StubPrinter:
|
||||
|
@ -120,4 +120,32 @@ raise Exception('boom')
|
||||
hass.async_add_job(execute, hass, 'test.py', source, {})
|
||||
yield from hass.async_block_till_done()
|
||||
|
||||
assert "Error executing script test.py" in caplog.text
|
||||
assert "Error executing script: boom" in caplog.text
|
||||
|
||||
|
||||
@asyncio.coroutine
|
||||
def test_accessing_async_methods(hass, caplog):
|
||||
"""Test compile error logs error."""
|
||||
caplog.set_level(logging.ERROR)
|
||||
source = """
|
||||
hass.async_stop()
|
||||
"""
|
||||
|
||||
hass.async_add_job(execute, hass, 'test.py', source, {})
|
||||
yield from hass.async_block_till_done()
|
||||
|
||||
assert "Not allowed to access async methods" in caplog.text
|
||||
|
||||
|
||||
@asyncio.coroutine
|
||||
def test_accessing_forbidden_methods(hass, caplog):
|
||||
"""Test compile error logs error."""
|
||||
caplog.set_level(logging.ERROR)
|
||||
source = """
|
||||
hass.stop()
|
||||
"""
|
||||
|
||||
hass.async_add_job(execute, hass, 'test.py', source, {})
|
||||
yield from hass.async_block_till_done()
|
||||
|
||||
assert "Not allowed to access HomeAssistant.stop" in caplog.text
|
||||
|
Loading…
x
Reference in New Issue
Block a user