tweak types

This commit is contained in:
J. Nick Koston 2025-04-01 11:29:12 -10:00
parent 48fb69347d
commit 0f8a4504b2
No known key found for this signature in database
2 changed files with 44 additions and 24 deletions

View File

@ -78,13 +78,13 @@ def _print_file_read_event(path: str) -> None:
) )
def _request_and_get_stream_on_stdin(fname: str): def _request_and_get_stream_on_stdin(fname: str) -> StringIO:
_print_file_read_event(fname) _print_file_read_event(fname)
raw_yaml_stream = StringIO(_read_file_content_from_json_on_stdin()) raw_yaml_stream = StringIO(_read_file_content_from_json_on_stdin())
return raw_yaml_stream return raw_yaml_stream
def _vscode_loader(fname: str): def _vscode_loader(fname: str) -> dict[str, Any]:
raw_yaml_stream = _request_and_get_stream_on_stdin(fname) raw_yaml_stream = _request_and_get_stream_on_stdin(fname)
# it is required to set the name on StringIO so document on start_mark # it is required to set the name on StringIO so document on start_mark
# is set properly. Otherwise it is initialized with "<file>" # is set properly. Otherwise it is initialized with "<file>"
@ -92,7 +92,7 @@ def _vscode_loader(fname: str):
return parse_yaml(fname, raw_yaml_stream, _vscode_loader) return parse_yaml(fname, raw_yaml_stream, _vscode_loader)
def _ace_loader(fname: str): def _ace_loader(fname: str) -> dict[str, Any]:
raw_yaml_stream = _request_and_get_stream_on_stdin(fname) raw_yaml_stream = _request_and_get_stream_on_stdin(fname)
return parse_yaml(fname, raw_yaml_stream) return parse_yaml(fname, raw_yaml_stream)

View File

@ -3,12 +3,12 @@ from __future__ import annotations
import fnmatch import fnmatch
import functools import functools
import inspect import inspect
from io import TextIOWrapper from io import BytesIO, TextIOBase, TextIOWrapper
from ipaddress import _BaseAddress from ipaddress import _BaseAddress
import logging import logging
import math import math
import os import os
from typing import Any from typing import Any, Callable
import uuid import uuid
import yaml import yaml
@ -69,7 +69,10 @@ class ESPForceValue:
pass pass
def make_data_base(value, from_database: ESPHomeDataBase = None): def make_data_base(
value, from_database: ESPHomeDataBase = None
) -> ESPHomeDataBase | Any:
"""Wrap a value in a ESPHomeDataBase object."""
try: try:
value = add_class_to_obj(value, ESPHomeDataBase) value = add_class_to_obj(value, ESPHomeDataBase)
if from_database is not None: if from_database is not None:
@ -102,7 +105,8 @@ def _add_data_ref(fn):
class ESPHomeLoaderMixin: class ESPHomeLoaderMixin:
"""Loader class that keeps track of line numbers.""" """Loader class that keeps track of line numbers."""
def __init__(self, name, yaml_loader): def __init__(self, name: str, yaml_loader: Callable[[str], dict[str, Any]]) -> None:
"""Initialize the loader."""
self.name = name self.name = name
self.yaml_loader = yaml_loader self.yaml_loader = yaml_loader
@ -131,7 +135,7 @@ class ESPHomeLoaderMixin:
return super().construct_yaml_seq(node) return super().construct_yaml_seq(node)
@_add_data_ref @_add_data_ref
def construct_yaml_map(self, node): def construct_yaml_map(self, node: yaml.MappingNode) -> OrderedDict[str, Any]:
"""Traverses the given mapping node and returns a list of constructed key-value pairs.""" """Traverses the given mapping node and returns a list of constructed key-value pairs."""
assert isinstance(node, yaml.MappingNode) assert isinstance(node, yaml.MappingNode)
# A list of key-value pairs we find in the current mapping # A list of key-value pairs we find in the current mapping
@ -235,7 +239,7 @@ class ESPHomeLoaderMixin:
return OrderedDict(pairs) return OrderedDict(pairs)
@_add_data_ref @_add_data_ref
def construct_env_var(self, node): def construct_env_var(self, node: yaml.Node) -> str:
args = node.value.split() args = node.value.split()
# Check for a default value # Check for a default value
if len(args) > 1: if len(args) > 1:
@ -247,14 +251,14 @@ class ESPHomeLoaderMixin:
) )
@property @property
def _directory(self): def _directory(self) -> str:
return os.path.dirname(self.name) return os.path.dirname(self.name)
def _rel_path(self, *args): def _rel_path(self, *args: str) -> str:
return os.path.join(self._directory, *args) return os.path.join(self._directory, *args)
@_add_data_ref @_add_data_ref
def construct_secret(self, node): def construct_secret(self, node: yaml.Node) -> str:
try: try:
secrets = self.yaml_loader(self._rel_path(SECRET_YAML)) secrets = self.yaml_loader(self._rel_path(SECRET_YAML))
except EsphomeError as e: except EsphomeError as e:
@ -276,7 +280,9 @@ class ESPHomeLoaderMixin:
return val return val
@_add_data_ref @_add_data_ref
def construct_include(self, node): def construct_include(
self, node: yaml.Node
) -> dict[str, Any] | OrderedDict[str, Any]:
from esphome.const import CONF_VARS from esphome.const import CONF_VARS
def extract_file_vars(node): def extract_file_vars(node):
@ -301,12 +307,12 @@ class ESPHomeLoaderMixin:
return result return result
@_add_data_ref @_add_data_ref
def construct_include_dir_list(self, node): def construct_include_dir_list(self, node: yaml.Node) -> list[dict[str, Any]]:
files = filter_yaml_files(_find_files(self._rel_path(node.value), "*.yaml")) files = filter_yaml_files(_find_files(self._rel_path(node.value), "*.yaml"))
return [self.yaml_loader(f) for f in files] return [self.yaml_loader(f) for f in files]
@_add_data_ref @_add_data_ref
def construct_include_dir_merge_list(self, node): def construct_include_dir_merge_list(self, node: yaml.Node) -> list[dict[str, Any]]:
files = filter_yaml_files(_find_files(self._rel_path(node.value), "*.yaml")) files = filter_yaml_files(_find_files(self._rel_path(node.value), "*.yaml"))
merged_list = [] merged_list = []
for fname in files: for fname in files:
@ -316,7 +322,9 @@ class ESPHomeLoaderMixin:
return merged_list return merged_list
@_add_data_ref @_add_data_ref
def construct_include_dir_named(self, node): def construct_include_dir_named(
self, node: yaml.Node
) -> OrderedDict[str, dict[str, Any]]:
files = filter_yaml_files(_find_files(self._rel_path(node.value), "*.yaml")) files = filter_yaml_files(_find_files(self._rel_path(node.value), "*.yaml"))
mapping = OrderedDict() mapping = OrderedDict()
for fname in files: for fname in files:
@ -325,7 +333,9 @@ class ESPHomeLoaderMixin:
return mapping return mapping
@_add_data_ref @_add_data_ref
def construct_include_dir_merge_named(self, node): def construct_include_dir_merge_named(
self, node: yaml.Node
) -> OrderedDict[str, dict[str, Any]]:
files = filter_yaml_files(_find_files(self._rel_path(node.value), "*.yaml")) files = filter_yaml_files(_find_files(self._rel_path(node.value), "*.yaml"))
mapping = OrderedDict() mapping = OrderedDict()
for fname in files: for fname in files:
@ -335,27 +345,32 @@ class ESPHomeLoaderMixin:
return mapping return mapping
@_add_data_ref @_add_data_ref
def construct_lambda(self, node): def construct_lambda(self, node: yaml.Node) -> Lambda:
return Lambda(str(node.value)) return Lambda(str(node.value))
@_add_data_ref @_add_data_ref
def construct_force(self, node): def construct_force(self, node: yaml.Node) -> ESPForceValue:
obj = self.construct_scalar(node) obj = self.construct_scalar(node)
return add_class_to_obj(obj, ESPForceValue) return add_class_to_obj(obj, ESPForceValue)
@_add_data_ref @_add_data_ref
def construct_extend(self, node): def construct_extend(self, node: yaml.Node) -> Extend:
return Extend(str(node.value)) return Extend(str(node.value))
@_add_data_ref @_add_data_ref
def construct_remove(self, node): def construct_remove(self, node: yaml.Node) -> Remove:
return Remove(str(node.value)) return Remove(str(node.value))
class ESPHomeLoader(ESPHomeLoaderMixin, FastestAvailableSafeLoader): class ESPHomeLoader(ESPHomeLoaderMixin, FastestAvailableSafeLoader):
"""Loader class that keeps track of line numbers.""" """Loader class that keeps track of line numbers."""
def __init__(self, stream, name, yaml_loader): def __init__(
self,
stream: TextIOBase | BytesIO,
name: str,
yaml_loader: Callable[[str], dict[str, Any]],
) -> None:
FastestAvailableSafeLoader.__init__(self, stream) FastestAvailableSafeLoader.__init__(self, stream)
ESPHomeLoaderMixin.__init__(self, name, yaml_loader) ESPHomeLoaderMixin.__init__(self, name, yaml_loader)
@ -363,7 +378,12 @@ class ESPHomeLoader(ESPHomeLoaderMixin, FastestAvailableSafeLoader):
class ESPHomePurePythonLoader(ESPHomeLoaderMixin, PurePythonLoader): class ESPHomePurePythonLoader(ESPHomeLoaderMixin, PurePythonLoader):
"""Loader class that keeps track of line numbers.""" """Loader class that keeps track of line numbers."""
def __init__(self, stream, name, yaml_loader): def __init__(
self,
stream: TextIOBase | BytesIO,
name: str,
yaml_loader: Callable[[str], dict[str, Any]],
) -> None:
PurePythonLoader.__init__(self, stream) PurePythonLoader.__init__(self, stream)
ESPHomeLoaderMixin.__init__(self, name, yaml_loader) ESPHomeLoaderMixin.__init__(self, name, yaml_loader)
@ -486,7 +506,7 @@ def dump(dict_, show_secrets=False):
) )
def _is_file_valid(name): def _is_file_valid(name: str) -> bool:
"""Decide if a file is valid.""" """Decide if a file is valid."""
return not name.startswith(".") return not name.startswith(".")