YAML loader performance improvements (#111199)

* YAML loader performance improvements

- Cache the name of the loader since we call it multiple
  times for every line

- Add a fast path for scalar tags since they are the
  most common

* Update homeassistant/util/yaml/loader.py

Co-authored-by: Joost Lekkerkerker <joostlek@outlook.com>

* remove unreachable code

---------

Co-authored-by: Joost Lekkerkerker <joostlek@outlook.com>
This commit is contained in:
J. Nick Koston 2024-02-23 08:37:09 -10:00 committed by GitHub
parent b9ed315cf7
commit 5d421e249f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -2,13 +2,12 @@
from __future__ import annotations from __future__ import annotations
from collections.abc import Callable, Iterator from collections.abc import Callable, Iterator
from contextlib import suppress
import fnmatch import fnmatch
from io import StringIO, TextIOWrapper from io import StringIO, TextIOWrapper
import logging import logging
import os import os
from pathlib import Path from pathlib import Path
from typing import Any, TextIO, TypeVar, overload from typing import TYPE_CHECKING, Any, TextIO, TypeVar, overload
import yaml import yaml
@ -28,6 +27,12 @@ from homeassistant.helpers.frame import report
from .const import SECRET_YAML from .const import SECRET_YAML
from .objects import Input, NodeDictClass, NodeListClass, NodeStrClass from .objects import Input, NodeDictClass, NodeListClass, NodeStrClass
if TYPE_CHECKING:
from functools import cached_property
else:
from homeassistant.backports.functools import cached_property
# mypy: allow-untyped-calls, no-warn-return-any # mypy: allow-untyped-calls, no-warn-return-any
JSON_TYPE = list | dict | str JSON_TYPE = list | dict | str
@ -113,10 +118,12 @@ class _LoaderMixin:
name: str name: str
stream: Any stream: Any
@cached_property
def get_name(self) -> str: def get_name(self) -> str:
"""Get the name of the loader.""" """Get the name of the loader."""
return self.name return self.name
@cached_property
def get_stream_name(self) -> str: def get_stream_name(self) -> str:
"""Get the name of the stream.""" """Get the name of the stream."""
return getattr(self.stream, "name", "") return getattr(self.stream, "name", "")
@ -311,9 +318,11 @@ def _add_reference( # type: ignore[no-untyped-def]
obj = NodeListClass(obj) obj = NodeListClass(obj)
if isinstance(obj, str): if isinstance(obj, str):
obj = NodeStrClass(obj) obj = NodeStrClass(obj)
with suppress(AttributeError): try: # noqa: SIM105 suppress is much slower
setattr(obj, "__config_file__", loader.get_name()) setattr(obj, "__config_file__", loader.get_name)
setattr(obj, "__line__", node.start_mark.line + 1) setattr(obj, "__line__", node.start_mark.line + 1)
except AttributeError:
pass
return obj return obj
@ -324,7 +333,7 @@ def _include_yaml(loader: LoaderType, node: yaml.nodes.Node) -> JSON_TYPE:
device_tracker: !include device_tracker.yaml device_tracker: !include device_tracker.yaml
""" """
fname = os.path.join(os.path.dirname(loader.get_name()), node.value) fname = os.path.join(os.path.dirname(loader.get_name), node.value)
try: try:
loaded_yaml = load_yaml(fname, loader.secrets) loaded_yaml = load_yaml(fname, loader.secrets)
if loaded_yaml is None: if loaded_yaml is None:
@ -354,7 +363,7 @@ def _find_files(directory: str, pattern: str) -> Iterator[str]:
def _include_dir_named_yaml(loader: LoaderType, node: yaml.nodes.Node) -> NodeDictClass: def _include_dir_named_yaml(loader: LoaderType, node: yaml.nodes.Node) -> NodeDictClass:
"""Load multiple files from directory as a dictionary.""" """Load multiple files from directory as a dictionary."""
mapping = NodeDictClass() mapping = NodeDictClass()
loc = os.path.join(os.path.dirname(loader.get_name()), node.value) loc = os.path.join(os.path.dirname(loader.get_name), node.value)
for fname in _find_files(loc, "*.yaml"): for fname in _find_files(loc, "*.yaml"):
filename = os.path.splitext(os.path.basename(fname))[0] filename = os.path.splitext(os.path.basename(fname))[0]
if os.path.basename(fname) == SECRET_YAML: if os.path.basename(fname) == SECRET_YAML:
@ -373,7 +382,7 @@ def _include_dir_merge_named_yaml(
) -> NodeDictClass: ) -> NodeDictClass:
"""Load multiple files from directory as a merged dictionary.""" """Load multiple files from directory as a merged dictionary."""
mapping = NodeDictClass() mapping = NodeDictClass()
loc = os.path.join(os.path.dirname(loader.get_name()), node.value) loc = os.path.join(os.path.dirname(loader.get_name), node.value)
for fname in _find_files(loc, "*.yaml"): for fname in _find_files(loc, "*.yaml"):
if os.path.basename(fname) == SECRET_YAML: if os.path.basename(fname) == SECRET_YAML:
continue continue
@ -387,7 +396,7 @@ def _include_dir_list_yaml(
loader: LoaderType, node: yaml.nodes.Node loader: LoaderType, node: yaml.nodes.Node
) -> list[JSON_TYPE]: ) -> list[JSON_TYPE]:
"""Load multiple files from directory as a list.""" """Load multiple files from directory as a list."""
loc = os.path.join(os.path.dirname(loader.get_name()), node.value) loc = os.path.join(os.path.dirname(loader.get_name), node.value)
return [ return [
loaded_yaml loaded_yaml
for f in _find_files(loc, "*.yaml") for f in _find_files(loc, "*.yaml")
@ -400,7 +409,7 @@ def _include_dir_merge_list_yaml(
loader: LoaderType, node: yaml.nodes.Node loader: LoaderType, node: yaml.nodes.Node
) -> JSON_TYPE: ) -> JSON_TYPE:
"""Load multiple files from directory as a merged list.""" """Load multiple files from directory as a merged list."""
loc: str = os.path.join(os.path.dirname(loader.get_name()), node.value) loc: str = os.path.join(os.path.dirname(loader.get_name), node.value)
merged_list: list[JSON_TYPE] = [] merged_list: list[JSON_TYPE] = []
for fname in _find_files(loc, "*.yaml"): for fname in _find_files(loc, "*.yaml"):
if os.path.basename(fname) == SECRET_YAML: if os.path.basename(fname) == SECRET_YAML:
@ -425,7 +434,7 @@ def _handle_mapping_tag(
try: try:
hash(key) hash(key)
except TypeError as exc: except TypeError as exc:
fname = loader.get_stream_name() fname = loader.get_stream_name
raise yaml.MarkedYAMLError( raise yaml.MarkedYAMLError(
context=f'invalid key: "{key}"', context=f'invalid key: "{key}"',
context_mark=yaml.Mark( context_mark=yaml.Mark(
@ -439,7 +448,7 @@ def _handle_mapping_tag(
) from exc ) from exc
if key in seen: if key in seen:
fname = loader.get_stream_name() fname = loader.get_stream_name
_LOGGER.warning( _LOGGER.warning(
'YAML file %s contains duplicate key "%s". Check lines %d and %d', 'YAML file %s contains duplicate key "%s". Check lines %d and %d',
fname, fname,
@ -462,7 +471,7 @@ def _handle_scalar_tag(
loader: LoaderType, node: yaml.nodes.ScalarNode loader: LoaderType, node: yaml.nodes.ScalarNode
) -> str | int | float | None: ) -> str | int | float | None:
"""Add line number and file name to Load YAML sequence.""" """Add line number and file name to Load YAML sequence."""
obj = loader.construct_scalar(node) obj = node.value
if not isinstance(obj, str): if not isinstance(obj, str):
return obj return obj
return _add_reference(obj, loader, node) return _add_reference(obj, loader, node)
@ -486,7 +495,7 @@ def secret_yaml(loader: LoaderType, node: yaml.nodes.Node) -> JSON_TYPE:
if loader.secrets is None: if loader.secrets is None:
raise HomeAssistantError("Secrets not supported in this YAML file") raise HomeAssistantError("Secrets not supported in this YAML file")
return loader.secrets.get(loader.get_name(), node.value) return loader.secrets.get(loader.get_name, node.value)
def add_constructor(tag: Any, constructor: Any) -> None: def add_constructor(tag: Any, constructor: Any) -> None: