Small cleanup of TemplateEnvironment (#99571)

* Small cleanup of TemplateEnvironment

* Fix typo
This commit is contained in:
Erik Montnemery 2023-09-04 22:19:40 +02:00 committed by Bram Kragten
parent cab9c97598
commit 4c0e4fe745

View File

@ -492,7 +492,7 @@ class Template:
if ret is None: if ret is None:
ret = self.hass.data[wanted_env] = TemplateEnvironment( ret = self.hass.data[wanted_env] = TemplateEnvironment(
self.hass, self.hass,
self._limited, # type: ignore[no-untyped-call] self._limited,
self._strict, self._strict,
) )
return ret return ret
@ -2276,7 +2276,12 @@ class HassLoader(jinja2.BaseLoader):
class TemplateEnvironment(ImmutableSandboxedEnvironment): class TemplateEnvironment(ImmutableSandboxedEnvironment):
"""The Home Assistant template environment.""" """The Home Assistant template environment."""
def __init__(self, hass, limited=False, strict=False): def __init__(
self,
hass: HomeAssistant | None,
limited: bool | None = False,
strict: bool | None = False,
) -> None:
"""Initialise template environment.""" """Initialise template environment."""
undefined: type[LoggingUndefined] | type[jinja2.StrictUndefined] undefined: type[LoggingUndefined] | type[jinja2.StrictUndefined]
if not strict: if not strict:
@ -2381,6 +2386,10 @@ class TemplateEnvironment(ImmutableSandboxedEnvironment):
# can be discarded, we only need to get at the hass object. # can be discarded, we only need to get at the hass object.
def hassfunction( def hassfunction(
func: Callable[Concatenate[HomeAssistant, _P], _R], func: Callable[Concatenate[HomeAssistant, _P], _R],
jinja_context: Callable[
[Callable[Concatenate[Any, _P], _R]],
Callable[Concatenate[Any, _P], _R],
] = pass_context,
) -> Callable[Concatenate[Any, _P], _R]: ) -> Callable[Concatenate[Any, _P], _R]:
"""Wrap function that depend on hass.""" """Wrap function that depend on hass."""
@ -2388,42 +2397,40 @@ class TemplateEnvironment(ImmutableSandboxedEnvironment):
def wrapper(_: Any, *args: _P.args, **kwargs: _P.kwargs) -> _R: def wrapper(_: Any, *args: _P.args, **kwargs: _P.kwargs) -> _R:
return func(hass, *args, **kwargs) return func(hass, *args, **kwargs)
return pass_context(wrapper) return jinja_context(wrapper)
self.globals["device_entities"] = hassfunction(device_entities) self.globals["device_entities"] = hassfunction(device_entities)
self.filters["device_entities"] = pass_context(self.globals["device_entities"]) self.filters["device_entities"] = self.globals["device_entities"]
self.globals["device_attr"] = hassfunction(device_attr) self.globals["device_attr"] = hassfunction(device_attr)
self.filters["device_attr"] = pass_context(self.globals["device_attr"]) self.filters["device_attr"] = self.globals["device_attr"]
self.globals["is_device_attr"] = hassfunction(is_device_attr) self.globals["is_device_attr"] = hassfunction(is_device_attr)
self.tests["is_device_attr"] = pass_eval_context(self.globals["is_device_attr"]) self.tests["is_device_attr"] = hassfunction(is_device_attr, pass_eval_context)
self.globals["config_entry_id"] = hassfunction(config_entry_id) self.globals["config_entry_id"] = hassfunction(config_entry_id)
self.filters["config_entry_id"] = pass_context(self.globals["config_entry_id"]) self.filters["config_entry_id"] = self.globals["config_entry_id"]
self.globals["device_id"] = hassfunction(device_id) self.globals["device_id"] = hassfunction(device_id)
self.filters["device_id"] = pass_context(self.globals["device_id"]) self.filters["device_id"] = self.globals["device_id"]
self.globals["areas"] = hassfunction(areas) self.globals["areas"] = hassfunction(areas)
self.filters["areas"] = pass_context(self.globals["areas"]) self.filters["areas"] = self.globals["areas"]
self.globals["area_id"] = hassfunction(area_id) self.globals["area_id"] = hassfunction(area_id)
self.filters["area_id"] = pass_context(self.globals["area_id"]) self.filters["area_id"] = self.globals["area_id"]
self.globals["area_name"] = hassfunction(area_name) self.globals["area_name"] = hassfunction(area_name)
self.filters["area_name"] = pass_context(self.globals["area_name"]) self.filters["area_name"] = self.globals["area_name"]
self.globals["area_entities"] = hassfunction(area_entities) self.globals["area_entities"] = hassfunction(area_entities)
self.filters["area_entities"] = pass_context(self.globals["area_entities"]) self.filters["area_entities"] = self.globals["area_entities"]
self.globals["area_devices"] = hassfunction(area_devices) self.globals["area_devices"] = hassfunction(area_devices)
self.filters["area_devices"] = pass_context(self.globals["area_devices"]) self.filters["area_devices"] = self.globals["area_devices"]
self.globals["integration_entities"] = hassfunction(integration_entities) self.globals["integration_entities"] = hassfunction(integration_entities)
self.filters["integration_entities"] = pass_context( self.filters["integration_entities"] = self.globals["integration_entities"]
self.globals["integration_entities"]
)
if limited: if limited:
# Only device_entities is available to limited templates, mark other # Only device_entities is available to limited templates, mark other
@ -2479,25 +2486,25 @@ class TemplateEnvironment(ImmutableSandboxedEnvironment):
return return
self.globals["expand"] = hassfunction(expand) self.globals["expand"] = hassfunction(expand)
self.filters["expand"] = pass_context(self.globals["expand"]) self.filters["expand"] = self.globals["expand"]
self.globals["closest"] = hassfunction(closest) self.globals["closest"] = hassfunction(closest)
self.filters["closest"] = pass_context(hassfunction(closest_filter)) self.filters["closest"] = hassfunction(closest_filter)
self.globals["distance"] = hassfunction(distance) self.globals["distance"] = hassfunction(distance)
self.globals["is_hidden_entity"] = hassfunction(is_hidden_entity) self.globals["is_hidden_entity"] = hassfunction(is_hidden_entity)
self.tests["is_hidden_entity"] = pass_eval_context( self.tests["is_hidden_entity"] = hassfunction(
self.globals["is_hidden_entity"] is_hidden_entity, pass_eval_context
) )
self.globals["is_state"] = hassfunction(is_state) self.globals["is_state"] = hassfunction(is_state)
self.tests["is_state"] = pass_eval_context(self.globals["is_state"]) self.tests["is_state"] = hassfunction(is_state, pass_eval_context)
self.globals["is_state_attr"] = hassfunction(is_state_attr) self.globals["is_state_attr"] = hassfunction(is_state_attr)
self.tests["is_state_attr"] = pass_eval_context(self.globals["is_state_attr"]) self.tests["is_state_attr"] = hassfunction(is_state_attr, pass_eval_context)
self.globals["state_attr"] = hassfunction(state_attr) self.globals["state_attr"] = hassfunction(state_attr)
self.filters["state_attr"] = self.globals["state_attr"] self.filters["state_attr"] = self.globals["state_attr"]
self.globals["states"] = AllStates(hass) self.globals["states"] = AllStates(hass)
self.filters["states"] = self.globals["states"] self.filters["states"] = self.globals["states"]
self.globals["has_value"] = hassfunction(has_value) self.globals["has_value"] = hassfunction(has_value)
self.filters["has_value"] = pass_context(self.globals["has_value"]) self.filters["has_value"] = self.globals["has_value"]
self.tests["has_value"] = pass_eval_context(self.globals["has_value"]) self.tests["has_value"] = hassfunction(has_value, pass_eval_context)
self.globals["utcnow"] = hassfunction(utcnow) self.globals["utcnow"] = hassfunction(utcnow)
self.globals["now"] = hassfunction(now) self.globals["now"] = hassfunction(now)
self.globals["relative_time"] = hassfunction(relative_time) self.globals["relative_time"] = hassfunction(relative_time)
@ -2575,4 +2582,4 @@ class TemplateEnvironment(ImmutableSandboxedEnvironment):
return cached return cached
_NO_HASS_ENV = TemplateEnvironment(None) # type: ignore[no-untyped-call] _NO_HASS_ENV = TemplateEnvironment(None)