Small cleanup of TemplateEnvironment (#99571)

* Small cleanup of TemplateEnvironment

* Fix typo
This commit is contained in:
Erik Montnemery 2023-09-04 22:19:40 +02:00 committed by Bram Kragten
parent cab9c97598
commit 4c0e4fe745

View File

@ -492,7 +492,7 @@ class Template:
if ret is None:
ret = self.hass.data[wanted_env] = TemplateEnvironment(
self.hass,
self._limited, # type: ignore[no-untyped-call]
self._limited,
self._strict,
)
return ret
@ -2276,7 +2276,12 @@ class HassLoader(jinja2.BaseLoader):
class TemplateEnvironment(ImmutableSandboxedEnvironment):
"""The Home Assistant template environment."""
def __init__(self, hass, limited=False, strict=False):
def __init__(
self,
hass: HomeAssistant | None,
limited: bool | None = False,
strict: bool | None = False,
) -> None:
"""Initialise template environment."""
undefined: type[LoggingUndefined] | type[jinja2.StrictUndefined]
if not strict:
@ -2381,6 +2386,10 @@ class TemplateEnvironment(ImmutableSandboxedEnvironment):
# can be discarded, we only need to get at the hass object.
def hassfunction(
func: Callable[Concatenate[HomeAssistant, _P], _R],
jinja_context: Callable[
[Callable[Concatenate[Any, _P], _R]],
Callable[Concatenate[Any, _P], _R],
] = pass_context,
) -> Callable[Concatenate[Any, _P], _R]:
"""Wrap function that depend on hass."""
@ -2388,42 +2397,40 @@ class TemplateEnvironment(ImmutableSandboxedEnvironment):
def wrapper(_: Any, *args: _P.args, **kwargs: _P.kwargs) -> _R:
return func(hass, *args, **kwargs)
return pass_context(wrapper)
return jinja_context(wrapper)
self.globals["device_entities"] = hassfunction(device_entities)
self.filters["device_entities"] = pass_context(self.globals["device_entities"])
self.filters["device_entities"] = self.globals["device_entities"]
self.globals["device_attr"] = hassfunction(device_attr)
self.filters["device_attr"] = pass_context(self.globals["device_attr"])
self.filters["device_attr"] = self.globals["device_attr"]
self.globals["is_device_attr"] = hassfunction(is_device_attr)
self.tests["is_device_attr"] = pass_eval_context(self.globals["is_device_attr"])
self.tests["is_device_attr"] = hassfunction(is_device_attr, pass_eval_context)
self.globals["config_entry_id"] = hassfunction(config_entry_id)
self.filters["config_entry_id"] = pass_context(self.globals["config_entry_id"])
self.filters["config_entry_id"] = self.globals["config_entry_id"]
self.globals["device_id"] = hassfunction(device_id)
self.filters["device_id"] = pass_context(self.globals["device_id"])
self.filters["device_id"] = self.globals["device_id"]
self.globals["areas"] = hassfunction(areas)
self.filters["areas"] = pass_context(self.globals["areas"])
self.filters["areas"] = self.globals["areas"]
self.globals["area_id"] = hassfunction(area_id)
self.filters["area_id"] = pass_context(self.globals["area_id"])
self.filters["area_id"] = self.globals["area_id"]
self.globals["area_name"] = hassfunction(area_name)
self.filters["area_name"] = pass_context(self.globals["area_name"])
self.filters["area_name"] = self.globals["area_name"]
self.globals["area_entities"] = hassfunction(area_entities)
self.filters["area_entities"] = pass_context(self.globals["area_entities"])
self.filters["area_entities"] = self.globals["area_entities"]
self.globals["area_devices"] = hassfunction(area_devices)
self.filters["area_devices"] = pass_context(self.globals["area_devices"])
self.filters["area_devices"] = self.globals["area_devices"]
self.globals["integration_entities"] = hassfunction(integration_entities)
self.filters["integration_entities"] = pass_context(
self.globals["integration_entities"]
)
self.filters["integration_entities"] = self.globals["integration_entities"]
if limited:
# Only device_entities is available to limited templates, mark other
@ -2479,25 +2486,25 @@ class TemplateEnvironment(ImmutableSandboxedEnvironment):
return
self.globals["expand"] = hassfunction(expand)
self.filters["expand"] = pass_context(self.globals["expand"])
self.filters["expand"] = self.globals["expand"]
self.globals["closest"] = hassfunction(closest)
self.filters["closest"] = pass_context(hassfunction(closest_filter))
self.filters["closest"] = hassfunction(closest_filter)
self.globals["distance"] = hassfunction(distance)
self.globals["is_hidden_entity"] = hassfunction(is_hidden_entity)
self.tests["is_hidden_entity"] = pass_eval_context(
self.globals["is_hidden_entity"]
self.tests["is_hidden_entity"] = hassfunction(
is_hidden_entity, pass_eval_context
)
self.globals["is_state"] = hassfunction(is_state)
self.tests["is_state"] = pass_eval_context(self.globals["is_state"])
self.tests["is_state"] = hassfunction(is_state, pass_eval_context)
self.globals["is_state_attr"] = hassfunction(is_state_attr)
self.tests["is_state_attr"] = pass_eval_context(self.globals["is_state_attr"])
self.tests["is_state_attr"] = hassfunction(is_state_attr, pass_eval_context)
self.globals["state_attr"] = hassfunction(state_attr)
self.filters["state_attr"] = self.globals["state_attr"]
self.globals["states"] = AllStates(hass)
self.filters["states"] = self.globals["states"]
self.globals["has_value"] = hassfunction(has_value)
self.filters["has_value"] = pass_context(self.globals["has_value"])
self.tests["has_value"] = pass_eval_context(self.globals["has_value"])
self.filters["has_value"] = self.globals["has_value"]
self.tests["has_value"] = hassfunction(has_value, pass_eval_context)
self.globals["utcnow"] = hassfunction(utcnow)
self.globals["now"] = hassfunction(now)
self.globals["relative_time"] = hassfunction(relative_time)
@ -2575,4 +2582,4 @@ class TemplateEnvironment(ImmutableSandboxedEnvironment):
return cached
_NO_HASS_ENV = TemplateEnvironment(None) # type: ignore[no-untyped-call]
_NO_HASS_ENV = TemplateEnvironment(None)