Add template list functions: intersect, difference, symmetric_difference, union (#141420)

This commit is contained in:
Franck Nijhof 2025-03-26 07:51:25 +01:00 committed by GitHub
parent 56cc4044e4
commit eb1caeb770
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 230 additions and 0 deletions

View File

@ -2785,6 +2785,50 @@ def flatten(value: Iterable[Any], levels: int | None = None) -> list[Any]:
return flattened return flattened
def intersect(value: Iterable[Any], other: Iterable[Any]) -> list[Any]:
"""Return the common elements between two lists."""
if not isinstance(value, Iterable) or isinstance(value, str):
raise TypeError(f"intersect expected a list, got {type(value).__name__}")
if not isinstance(other, Iterable) or isinstance(other, str):
raise TypeError(f"intersect expected a list, got {type(other).__name__}")
return list(set(value) & set(other))
def difference(value: Iterable[Any], other: Iterable[Any]) -> list[Any]:
"""Return elements in first list that are not in second list."""
if not isinstance(value, Iterable) or isinstance(value, str):
raise TypeError(f"difference expected a list, got {type(value).__name__}")
if not isinstance(other, Iterable) or isinstance(other, str):
raise TypeError(f"difference expected a list, got {type(other).__name__}")
return list(set(value) - set(other))
def union(value: Iterable[Any], other: Iterable[Any]) -> list[Any]:
"""Return all unique elements from both lists combined."""
if not isinstance(value, Iterable) or isinstance(value, str):
raise TypeError(f"union expected a list, got {type(value).__name__}")
if not isinstance(other, Iterable) or isinstance(other, str):
raise TypeError(f"union expected a list, got {type(other).__name__}")
return list(set(value) | set(other))
def symmetric_difference(value: Iterable[Any], other: Iterable[Any]) -> list[Any]:
"""Return elements that are in either list but not in both."""
if not isinstance(value, Iterable) or isinstance(value, str):
raise TypeError(
f"symmetric_difference expected a list, got {type(value).__name__}"
)
if not isinstance(other, Iterable) or isinstance(other, str):
raise TypeError(
f"symmetric_difference expected a list, got {type(other).__name__}"
)
return list(set(value) ^ set(other))
def combine(*args: Any, recursive: bool = False) -> dict[Any, Any]: def combine(*args: Any, recursive: bool = False) -> dict[Any, Any]:
"""Combine multiple dictionaries into one.""" """Combine multiple dictionaries into one."""
if not args: if not args:
@ -2996,11 +3040,13 @@ class TemplateEnvironment(ImmutableSandboxedEnvironment):
self.globals["bool"] = forgiving_boolean self.globals["bool"] = forgiving_boolean
self.globals["combine"] = combine self.globals["combine"] = combine
self.globals["cos"] = cosine self.globals["cos"] = cosine
self.globals["difference"] = difference
self.globals["e"] = math.e self.globals["e"] = math.e
self.globals["flatten"] = flatten self.globals["flatten"] = flatten
self.globals["float"] = forgiving_float self.globals["float"] = forgiving_float
self.globals["iif"] = iif self.globals["iif"] = iif
self.globals["int"] = forgiving_int self.globals["int"] = forgiving_int
self.globals["intersect"] = intersect
self.globals["is_number"] = is_number self.globals["is_number"] = is_number
self.globals["log"] = logarithm self.globals["log"] = logarithm
self.globals["max"] = min_max_from_filter(self.filters["max"], "max") self.globals["max"] = min_max_from_filter(self.filters["max"], "max")
@ -3020,11 +3066,13 @@ class TemplateEnvironment(ImmutableSandboxedEnvironment):
self.globals["sqrt"] = square_root self.globals["sqrt"] = square_root
self.globals["statistical_mode"] = statistical_mode self.globals["statistical_mode"] = statistical_mode
self.globals["strptime"] = strptime self.globals["strptime"] = strptime
self.globals["symmetric_difference"] = symmetric_difference
self.globals["tan"] = tangent self.globals["tan"] = tangent
self.globals["tau"] = math.pi * 2 self.globals["tau"] = math.pi * 2
self.globals["timedelta"] = timedelta self.globals["timedelta"] = timedelta
self.globals["tuple"] = _to_tuple self.globals["tuple"] = _to_tuple
self.globals["typeof"] = typeof self.globals["typeof"] = typeof
self.globals["union"] = union
self.globals["unpack"] = struct_unpack self.globals["unpack"] = struct_unpack
self.globals["urlencode"] = urlencode self.globals["urlencode"] = urlencode
self.globals["version"] = version self.globals["version"] = version
@ -3049,11 +3097,13 @@ class TemplateEnvironment(ImmutableSandboxedEnvironment):
self.filters["combine"] = combine self.filters["combine"] = combine
self.filters["contains"] = contains self.filters["contains"] = contains
self.filters["cos"] = cosine self.filters["cos"] = cosine
self.filters["difference"] = difference
self.filters["flatten"] = flatten self.filters["flatten"] = flatten
self.filters["float"] = forgiving_float_filter self.filters["float"] = forgiving_float_filter
self.filters["from_json"] = from_json self.filters["from_json"] = from_json
self.filters["iif"] = iif self.filters["iif"] = iif
self.filters["int"] = forgiving_int_filter self.filters["int"] = forgiving_int_filter
self.filters["intersect"] = intersect
self.filters["is_defined"] = fail_when_undefined self.filters["is_defined"] = fail_when_undefined
self.filters["is_number"] = is_number self.filters["is_number"] = is_number
self.filters["log"] = logarithm self.filters["log"] = logarithm
@ -3078,12 +3128,14 @@ class TemplateEnvironment(ImmutableSandboxedEnvironment):
self.filters["slugify"] = slugify self.filters["slugify"] = slugify
self.filters["sqrt"] = square_root self.filters["sqrt"] = square_root
self.filters["statistical_mode"] = statistical_mode self.filters["statistical_mode"] = statistical_mode
self.filters["symmetric_difference"] = symmetric_difference
self.filters["tan"] = tangent self.filters["tan"] = tangent
self.filters["timestamp_custom"] = timestamp_custom self.filters["timestamp_custom"] = timestamp_custom
self.filters["timestamp_local"] = timestamp_local self.filters["timestamp_local"] = timestamp_local
self.filters["timestamp_utc"] = timestamp_utc self.filters["timestamp_utc"] = timestamp_utc
self.filters["to_json"] = to_json self.filters["to_json"] = to_json
self.filters["typeof"] = typeof self.filters["typeof"] = typeof
self.filters["union"] = union
self.filters["unpack"] = struct_unpack self.filters["unpack"] = struct_unpack
self.filters["version"] = version self.filters["version"] = version

View File

@ -6790,6 +6790,184 @@ def test_flatten(hass: HomeAssistant) -> None:
template.Template("{{ flatten() }}", hass).async_render() template.Template("{{ flatten() }}", hass).async_render()
def test_intersect(hass: HomeAssistant) -> None:
"""Test the intersect function and filter."""
assert list(
template.Template(
"{{ intersect([1, 2, 5, 3, 4, 10], [1, 2, 3, 4, 5, 11, 99]) }}", hass
).async_render()
) == unordered([1, 2, 3, 4, 5])
assert list(
template.Template(
"{{ [1, 2, 5, 3, 4, 10] | intersect([1, 2, 3, 4, 5, 11, 99]) }}", hass
).async_render()
) == unordered([1, 2, 3, 4, 5])
assert list(
template.Template(
"{{ intersect(['a', 'b', 'c'], ['b', 'c', 'd']) }}", hass
).async_render()
) == unordered(["b", "c"])
assert list(
template.Template(
"{{ ['a', 'b', 'c'] | intersect(['b', 'c', 'd']) }}", hass
).async_render()
) == unordered(["b", "c"])
assert (
template.Template("{{ intersect([], [1, 2, 3]) }}", hass).async_render() == []
)
assert (
template.Template("{{ [] | intersect([1, 2, 3]) }}", hass).async_render() == []
)
with pytest.raises(TemplateError, match="intersect expected a list, got str"):
template.Template("{{ 'string' | intersect([1, 2, 3]) }}", hass).async_render()
with pytest.raises(TemplateError, match="intersect expected a list, got str"):
template.Template("{{ [1, 2, 3] | intersect('string') }}", hass).async_render()
def test_difference(hass: HomeAssistant) -> None:
"""Test the difference function and filter."""
assert list(
template.Template(
"{{ difference([1, 2, 5, 3, 4, 10], [1, 2, 3, 4, 5, 11, 99]) }}", hass
).async_render()
) == [10]
assert list(
template.Template(
"{{ [1, 2, 5, 3, 4, 10] | difference([1, 2, 3, 4, 5, 11, 99]) }}", hass
).async_render()
) == [10]
assert list(
template.Template(
"{{ difference(['a', 'b', 'c'], ['b', 'c', 'd']) }}", hass
).async_render()
) == ["a"]
assert list(
template.Template(
"{{ ['a', 'b', 'c'] | difference(['b', 'c', 'd']) }}", hass
).async_render()
) == ["a"]
assert (
template.Template("{{ difference([], [1, 2, 3]) }}", hass).async_render() == []
)
assert (
template.Template("{{ [] | difference([1, 2, 3]) }}", hass).async_render() == []
)
with pytest.raises(TemplateError, match="difference expected a list, got str"):
template.Template("{{ 'string' | difference([1, 2, 3]) }}", hass).async_render()
with pytest.raises(TemplateError, match="difference expected a list, got str"):
template.Template("{{ [1, 2, 3] | difference('string') }}", hass).async_render()
def test_union(hass: HomeAssistant) -> None:
"""Test the union function and filter."""
assert list(
template.Template(
"{{ union([1, 2, 5, 3, 4, 10], [1, 2, 3, 4, 5, 11, 99]) }}", hass
).async_render()
) == unordered([1, 2, 3, 4, 5, 10, 11, 99])
assert list(
template.Template(
"{{ [1, 2, 5, 3, 4, 10] | union([1, 2, 3, 4, 5, 11, 99]) }}", hass
).async_render()
) == unordered([1, 2, 3, 4, 5, 10, 11, 99])
assert list(
template.Template(
"{{ union(['a', 'b', 'c'], ['b', 'c', 'd']) }}", hass
).async_render()
) == unordered(["a", "b", "c", "d"])
assert list(
template.Template(
"{{ ['a', 'b', 'c'] | union(['b', 'c', 'd']) }}", hass
).async_render()
) == unordered(["a", "b", "c", "d"])
assert list(
template.Template("{{ union([], [1, 2, 3]) }}", hass).async_render()
) == unordered([1, 2, 3])
assert list(
template.Template("{{ [] | union([1, 2, 3]) }}", hass).async_render()
) == unordered([1, 2, 3])
with pytest.raises(TemplateError, match="union expected a list, got str"):
template.Template("{{ 'string' | union([1, 2, 3]) }}", hass).async_render()
with pytest.raises(TemplateError, match="union expected a list, got str"):
template.Template("{{ [1, 2, 3] | union('string') }}", hass).async_render()
def test_symmetric_difference(hass: HomeAssistant) -> None:
"""Test the symmetric_difference function and filter."""
assert list(
template.Template(
"{{ symmetric_difference([1, 2, 5, 3, 4, 10], [1, 2, 3, 4, 5, 11, 99]) }}",
hass,
).async_render()
) == unordered([10, 11, 99])
assert list(
template.Template(
"{{ [1, 2, 5, 3, 4, 10] | symmetric_difference([1, 2, 3, 4, 5, 11, 99]) }}",
hass,
).async_render()
) == unordered([10, 11, 99])
assert list(
template.Template(
"{{ symmetric_difference(['a', 'b', 'c'], ['b', 'c', 'd']) }}", hass
).async_render()
) == unordered(["a", "d"])
assert list(
template.Template(
"{{ ['a', 'b', 'c'] | symmetric_difference(['b', 'c', 'd']) }}", hass
).async_render()
) == unordered(["a", "d"])
assert list(
template.Template(
"{{ symmetric_difference([], [1, 2, 3]) }}", hass
).async_render()
) == unordered([1, 2, 3])
assert list(
template.Template(
"{{ [] | symmetric_difference([1, 2, 3]) }}", hass
).async_render()
) == unordered([1, 2, 3])
with pytest.raises(
TemplateError, match="symmetric_difference expected a list, got str"
):
template.Template(
"{{ 'string' | symmetric_difference([1, 2, 3]) }}", hass
).async_render()
with pytest.raises(
TemplateError, match="symmetric_difference expected a list, got str"
):
template.Template(
"{{ [1, 2, 3] | symmetric_difference('string') }}", hass
).async_render()
def test_md5(hass: HomeAssistant) -> None: def test_md5(hass: HomeAssistant) -> None:
"""Test the md5 function and filter.""" """Test the md5 function and filter."""
assert ( assert (