diff --git a/homeassistant/helpers/template.py b/homeassistant/helpers/template.py index 618bc6ea4f7..ef0b578811e 100644 --- a/homeassistant/helpers/template.py +++ b/homeassistant/helpers/template.py @@ -10,6 +10,7 @@ import random import re from typing import Any, Dict, Iterable, List, Optional, Union from urllib.parse import urlencode as urllib_urlencode +import weakref import jinja2 from jinja2 import contextfilter, contextfunction @@ -958,6 +959,7 @@ class TemplateEnvironment(ImmutableSandboxedEnvironment): """Initialise template environment.""" super().__init__() self.hass = hass + self.template_cache = weakref.WeakValueDictionary() self.filters["round"] = forgiving_round self.filters["multiply"] = multiply self.filters["log"] = logarithm @@ -1042,5 +1044,25 @@ class TemplateEnvironment(ImmutableSandboxedEnvironment): """Test if attribute is safe.""" return isinstance(obj, Namespace) or super().is_safe_attribute(obj, attr, value) + def compile(self, source, name=None, filename=None, raw=False, defer_init=False): + """Compile the template.""" + if ( + name is not None + or filename is not None + or raw is not False + or defer_init is not False + ): + # If there are any non-default keywords args, we do + # not cache. In prodution we currently do not have + # any instance of this. + return super().compile(source, name, filename, raw, defer_init) + + cached = self.template_cache.get(source) + + if cached is None: + cached = self.template_cache[source] = super().compile(source) + + return cached + _NO_HASS_ENV = TemplateEnvironment(None) diff --git a/tests/helpers/test_template.py b/tests/helpers/test_template.py index f755e4e1084..89486129760 100644 --- a/tests/helpers/test_template.py +++ b/tests/helpers/test_template.py @@ -1885,3 +1885,30 @@ def test_urlencode(hass): hass, ) assert tpl.async_render() == "the%20quick%20brown%20fox%20%3D%20true" + + +async def test_cache_garbage_collection(): + """Test caching a template.""" + template_string = ( + "{% set dict = {'foo': 'x&y', 'bar': 42} %} {{ dict | urlencode }}" + ) + tpl = template.Template((template_string),) + tpl.ensure_valid() + assert template._NO_HASS_ENV.template_cache.get( + template_string + ) # pylint: disable=protected-access + + tpl2 = template.Template((template_string),) + tpl2.ensure_valid() + assert template._NO_HASS_ENV.template_cache.get( + template_string + ) # pylint: disable=protected-access + + del tpl + assert template._NO_HASS_ENV.template_cache.get( + template_string + ) # pylint: disable=protected-access + del tpl2 + assert not template._NO_HASS_ENV.template_cache.get( + template_string + ) # pylint: disable=protected-access