diff --git a/homeassistant/helpers/template.py b/homeassistant/helpers/template.py index 40d64ba37ae..b5a6a45e97f 100644 --- a/homeassistant/helpers/template.py +++ b/homeassistant/helpers/template.py @@ -492,7 +492,7 @@ class Template: if ret is None: ret = self.hass.data[wanted_env] = TemplateEnvironment( self.hass, - self._limited, # type: ignore[no-untyped-call] + self._limited, self._strict, ) return ret @@ -2276,7 +2276,12 @@ class HassLoader(jinja2.BaseLoader): class TemplateEnvironment(ImmutableSandboxedEnvironment): """The Home Assistant template environment.""" - def __init__(self, hass, limited=False, strict=False): + def __init__( + self, + hass: HomeAssistant | None, + limited: bool | None = False, + strict: bool | None = False, + ) -> None: """Initialise template environment.""" undefined: type[LoggingUndefined] | type[jinja2.StrictUndefined] if not strict: @@ -2381,6 +2386,10 @@ class TemplateEnvironment(ImmutableSandboxedEnvironment): # can be discarded, we only need to get at the hass object. def hassfunction( func: Callable[Concatenate[HomeAssistant, _P], _R], + jinja_context: Callable[ + [Callable[Concatenate[Any, _P], _R]], + Callable[Concatenate[Any, _P], _R], + ] = pass_context, ) -> Callable[Concatenate[Any, _P], _R]: """Wrap function that depend on hass.""" @@ -2388,42 +2397,40 @@ class TemplateEnvironment(ImmutableSandboxedEnvironment): def wrapper(_: Any, *args: _P.args, **kwargs: _P.kwargs) -> _R: return func(hass, *args, **kwargs) - return pass_context(wrapper) + return jinja_context(wrapper) self.globals["device_entities"] = hassfunction(device_entities) - self.filters["device_entities"] = pass_context(self.globals["device_entities"]) + self.filters["device_entities"] = self.globals["device_entities"] self.globals["device_attr"] = hassfunction(device_attr) - self.filters["device_attr"] = pass_context(self.globals["device_attr"]) + self.filters["device_attr"] = self.globals["device_attr"] self.globals["is_device_attr"] = hassfunction(is_device_attr) - self.tests["is_device_attr"] = pass_eval_context(self.globals["is_device_attr"]) + self.tests["is_device_attr"] = hassfunction(is_device_attr, pass_eval_context) self.globals["config_entry_id"] = hassfunction(config_entry_id) - self.filters["config_entry_id"] = pass_context(self.globals["config_entry_id"]) + self.filters["config_entry_id"] = self.globals["config_entry_id"] self.globals["device_id"] = hassfunction(device_id) - self.filters["device_id"] = pass_context(self.globals["device_id"]) + self.filters["device_id"] = self.globals["device_id"] self.globals["areas"] = hassfunction(areas) - self.filters["areas"] = pass_context(self.globals["areas"]) + self.filters["areas"] = self.globals["areas"] self.globals["area_id"] = hassfunction(area_id) - self.filters["area_id"] = pass_context(self.globals["area_id"]) + self.filters["area_id"] = self.globals["area_id"] self.globals["area_name"] = hassfunction(area_name) - self.filters["area_name"] = pass_context(self.globals["area_name"]) + self.filters["area_name"] = self.globals["area_name"] self.globals["area_entities"] = hassfunction(area_entities) - self.filters["area_entities"] = pass_context(self.globals["area_entities"]) + self.filters["area_entities"] = self.globals["area_entities"] self.globals["area_devices"] = hassfunction(area_devices) - self.filters["area_devices"] = pass_context(self.globals["area_devices"]) + self.filters["area_devices"] = self.globals["area_devices"] self.globals["integration_entities"] = hassfunction(integration_entities) - self.filters["integration_entities"] = pass_context( - self.globals["integration_entities"] - ) + self.filters["integration_entities"] = self.globals["integration_entities"] if limited: # Only device_entities is available to limited templates, mark other @@ -2479,25 +2486,25 @@ class TemplateEnvironment(ImmutableSandboxedEnvironment): return self.globals["expand"] = hassfunction(expand) - self.filters["expand"] = pass_context(self.globals["expand"]) + self.filters["expand"] = self.globals["expand"] self.globals["closest"] = hassfunction(closest) - self.filters["closest"] = pass_context(hassfunction(closest_filter)) + self.filters["closest"] = hassfunction(closest_filter) self.globals["distance"] = hassfunction(distance) self.globals["is_hidden_entity"] = hassfunction(is_hidden_entity) - self.tests["is_hidden_entity"] = pass_eval_context( - self.globals["is_hidden_entity"] + self.tests["is_hidden_entity"] = hassfunction( + is_hidden_entity, pass_eval_context ) self.globals["is_state"] = hassfunction(is_state) - self.tests["is_state"] = pass_eval_context(self.globals["is_state"]) + self.tests["is_state"] = hassfunction(is_state, pass_eval_context) self.globals["is_state_attr"] = hassfunction(is_state_attr) - self.tests["is_state_attr"] = pass_eval_context(self.globals["is_state_attr"]) + self.tests["is_state_attr"] = hassfunction(is_state_attr, pass_eval_context) self.globals["state_attr"] = hassfunction(state_attr) self.filters["state_attr"] = self.globals["state_attr"] self.globals["states"] = AllStates(hass) self.filters["states"] = self.globals["states"] self.globals["has_value"] = hassfunction(has_value) - self.filters["has_value"] = pass_context(self.globals["has_value"]) - self.tests["has_value"] = pass_eval_context(self.globals["has_value"]) + self.filters["has_value"] = self.globals["has_value"] + self.tests["has_value"] = hassfunction(has_value, pass_eval_context) self.globals["utcnow"] = hassfunction(utcnow) self.globals["now"] = hassfunction(now) self.globals["relative_time"] = hassfunction(relative_time) @@ -2575,4 +2582,4 @@ class TemplateEnvironment(ImmutableSandboxedEnvironment): return cached -_NO_HASS_ENV = TemplateEnvironment(None) # type: ignore[no-untyped-call] +_NO_HASS_ENV = TemplateEnvironment(None)