diff --git a/esphome/config_validation.py b/esphome/config_validation.py index bb1cb1ac2f..7bd3f90adc 100644 --- a/esphome/config_validation.py +++ b/esphome/config_validation.py @@ -1498,30 +1498,9 @@ def dimensions(value): def directory(value): - import json - value = string(value) path = CORE.relative_config_path(value) - if CORE.vscode and ( - not CORE.ace or os.path.abspath(path) == os.path.abspath(CORE.config_path) - ): - print( - json.dumps( - { - "type": "check_directory_exists", - "path": path, - } - ) - ) - data = json.loads(input()) - assert data["type"] == "directory_exists_response" - if data["content"]: - return value - raise Invalid( - f"Could not find directory '{path}'. Please make sure it exists (full path: {os.path.abspath(path)})." - ) - if not os.path.exists(path): raise Invalid( f"Could not find directory '{path}'. Please make sure it exists (full path: {os.path.abspath(path)})." @@ -1534,30 +1513,9 @@ def directory(value): def file_(value): - import json - value = string(value) path = CORE.relative_config_path(value) - if CORE.vscode and ( - not CORE.ace or os.path.abspath(path) == os.path.abspath(CORE.config_path) - ): - print( - json.dumps( - { - "type": "check_file_exists", - "path": path, - } - ) - ) - data = json.loads(input()) - assert data["type"] == "file_exists_response" - if data["content"]: - return value - raise Invalid( - f"Could not find file '{path}'. Please make sure it exists (full path: {os.path.abspath(path)})." - ) - if not os.path.exists(path): raise Invalid( f"Could not find file '{path}'. Please make sure it exists (full path: {os.path.abspath(path)})." diff --git a/esphome/core/__init__.py b/esphome/core/__init__.py index 2a7b8b9d91..1a81a6d6cd 100644 --- a/esphome/core/__init__.py +++ b/esphome/core/__init__.py @@ -475,7 +475,6 @@ class EsphomeCore: self.dashboard = False # True if command is run from vscode api self.vscode = False - self.ace = False # The name of the node self.name: Optional[str] = None # The friendly name of the node diff --git a/esphome/vscode.py b/esphome/vscode.py index 907ed88216..fb62b60eac 100644 --- a/esphome/vscode.py +++ b/esphome/vscode.py @@ -78,28 +78,47 @@ def _print_file_read_event(path: str) -> None: ) +def _request_and_get_stream_on_stdin(fname: str) -> StringIO: + _print_file_read_event(fname) + raw_yaml_stream = StringIO(_read_file_content_from_json_on_stdin()) + return raw_yaml_stream + + +def _vscode_loader(fname: str) -> dict[str, Any]: + raw_yaml_stream = _request_and_get_stream_on_stdin(fname) + # it is required to set the name on StringIO so document on start_mark + # is set properly. Otherwise it is initialized with "" + raw_yaml_stream.name = fname + return parse_yaml(fname, raw_yaml_stream, _vscode_loader) + + +def _ace_loader(fname: str) -> dict[str, Any]: + raw_yaml_stream = _request_and_get_stream_on_stdin(fname) + return parse_yaml(fname, raw_yaml_stream) + + def read_config(args): while True: CORE.reset() data = json.loads(input()) - assert data["type"] == "validate" + assert data["type"] == "validate" or data["type"] == "exit" + if data["type"] == "exit": + return CORE.vscode = True - CORE.ace = args.ace - f = data["file"] - if CORE.ace: - CORE.config_path = os.path.join(args.configuration, f) + if args.ace: # Running from ESPHome Compiler dashboard, not vscode + CORE.config_path = os.path.join(args.configuration, data["file"]) + loader = _ace_loader else: CORE.config_path = data["file"] + loader = _vscode_loader file_name = CORE.config_path - _print_file_read_event(file_name) - raw_yaml = _read_file_content_from_json_on_stdin() command_line_substitutions: dict[str, Any] = ( dict(args.substitution) if args.substitution else {} ) vs = VSCodeResult() try: - config = parse_yaml(file_name, StringIO(raw_yaml)) + config = loader(file_name) res = validate_config(config, command_line_substitutions) except Exception as err: # pylint: disable=broad-except vs.add_yaml_error(str(err)) diff --git a/esphome/yaml_util.py b/esphome/yaml_util.py index 431f397e38..cbe3fef272 100644 --- a/esphome/yaml_util.py +++ b/esphome/yaml_util.py @@ -3,12 +3,12 @@ from __future__ import annotations import fnmatch import functools import inspect -from io import TextIOWrapper +from io import BytesIO, TextIOBase, TextIOWrapper from ipaddress import _BaseAddress import logging import math import os -from typing import Any +from typing import Any, Callable import uuid import yaml @@ -69,7 +69,10 @@ class ESPForceValue: pass -def make_data_base(value, from_database: ESPHomeDataBase = None): +def make_data_base( + value, from_database: ESPHomeDataBase = None +) -> ESPHomeDataBase | Any: + """Wrap a value in a ESPHomeDataBase object.""" try: value = add_class_to_obj(value, ESPHomeDataBase) if from_database is not None: @@ -102,6 +105,11 @@ def _add_data_ref(fn): class ESPHomeLoaderMixin: """Loader class that keeps track of line numbers.""" + def __init__(self, name: str, yaml_loader: Callable[[str], dict[str, Any]]) -> None: + """Initialize the loader.""" + self.name = name + self.yaml_loader = yaml_loader + @_add_data_ref def construct_yaml_int(self, node): return super().construct_yaml_int(node) @@ -127,7 +135,7 @@ class ESPHomeLoaderMixin: return super().construct_yaml_seq(node) @_add_data_ref - def construct_yaml_map(self, node): + def construct_yaml_map(self, node: yaml.MappingNode) -> OrderedDict[str, Any]: """Traverses the given mapping node and returns a list of constructed key-value pairs.""" assert isinstance(node, yaml.MappingNode) # A list of key-value pairs we find in the current mapping @@ -231,7 +239,7 @@ class ESPHomeLoaderMixin: return OrderedDict(pairs) @_add_data_ref - def construct_env_var(self, node): + def construct_env_var(self, node: yaml.Node) -> str: args = node.value.split() # Check for a default value if len(args) > 1: @@ -243,23 +251,23 @@ class ESPHomeLoaderMixin: ) @property - def _directory(self): + def _directory(self) -> str: return os.path.dirname(self.name) - def _rel_path(self, *args): + def _rel_path(self, *args: str) -> str: return os.path.join(self._directory, *args) @_add_data_ref - def construct_secret(self, node): + def construct_secret(self, node: yaml.Node) -> str: try: - secrets = _load_yaml_internal(self._rel_path(SECRET_YAML)) + secrets = self.yaml_loader(self._rel_path(SECRET_YAML)) except EsphomeError as e: if self.name == CORE.config_path: raise e try: main_config_dir = os.path.dirname(CORE.config_path) main_secret_yml = os.path.join(main_config_dir, SECRET_YAML) - secrets = _load_yaml_internal(main_secret_yml) + secrets = self.yaml_loader(main_secret_yml) except EsphomeError as er: raise EsphomeError(f"{e}\n{er}") from er @@ -272,7 +280,9 @@ class ESPHomeLoaderMixin: return val @_add_data_ref - def construct_include(self, node): + def construct_include( + self, node: yaml.Node + ) -> dict[str, Any] | OrderedDict[str, Any]: from esphome.const import CONF_VARS def extract_file_vars(node): @@ -290,71 +300,93 @@ class ESPHomeLoaderMixin: else: file, vars = node.value, None - result = _load_yaml_internal(self._rel_path(file)) + result = self.yaml_loader(self._rel_path(file)) if not vars: vars = {} result = substitute_vars(result, vars) return result @_add_data_ref - def construct_include_dir_list(self, node): + def construct_include_dir_list(self, node: yaml.Node) -> list[dict[str, Any]]: files = filter_yaml_files(_find_files(self._rel_path(node.value), "*.yaml")) - return [_load_yaml_internal(f) for f in files] + return [self.yaml_loader(f) for f in files] @_add_data_ref - def construct_include_dir_merge_list(self, node): + def construct_include_dir_merge_list(self, node: yaml.Node) -> list[dict[str, Any]]: files = filter_yaml_files(_find_files(self._rel_path(node.value), "*.yaml")) merged_list = [] for fname in files: - loaded_yaml = _load_yaml_internal(fname) + loaded_yaml = self.yaml_loader(fname) if isinstance(loaded_yaml, list): merged_list.extend(loaded_yaml) return merged_list @_add_data_ref - def construct_include_dir_named(self, node): + def construct_include_dir_named( + self, node: yaml.Node + ) -> OrderedDict[str, dict[str, Any]]: files = filter_yaml_files(_find_files(self._rel_path(node.value), "*.yaml")) mapping = OrderedDict() for fname in files: filename = os.path.splitext(os.path.basename(fname))[0] - mapping[filename] = _load_yaml_internal(fname) + mapping[filename] = self.yaml_loader(fname) return mapping @_add_data_ref - def construct_include_dir_merge_named(self, node): + def construct_include_dir_merge_named( + self, node: yaml.Node + ) -> OrderedDict[str, dict[str, Any]]: files = filter_yaml_files(_find_files(self._rel_path(node.value), "*.yaml")) mapping = OrderedDict() for fname in files: - loaded_yaml = _load_yaml_internal(fname) + loaded_yaml = self.yaml_loader(fname) if isinstance(loaded_yaml, dict): mapping.update(loaded_yaml) return mapping @_add_data_ref - def construct_lambda(self, node): + def construct_lambda(self, node: yaml.Node) -> Lambda: return Lambda(str(node.value)) @_add_data_ref - def construct_force(self, node): + def construct_force(self, node: yaml.Node) -> ESPForceValue: obj = self.construct_scalar(node) return add_class_to_obj(obj, ESPForceValue) @_add_data_ref - def construct_extend(self, node): + def construct_extend(self, node: yaml.Node) -> Extend: return Extend(str(node.value)) @_add_data_ref - def construct_remove(self, node): + def construct_remove(self, node: yaml.Node) -> Remove: return Remove(str(node.value)) class ESPHomeLoader(ESPHomeLoaderMixin, FastestAvailableSafeLoader): """Loader class that keeps track of line numbers.""" + def __init__( + self, + stream: TextIOBase | BytesIO, + name: str, + yaml_loader: Callable[[str], dict[str, Any]], + ) -> None: + FastestAvailableSafeLoader.__init__(self, stream) + ESPHomeLoaderMixin.__init__(self, name, yaml_loader) + class ESPHomePurePythonLoader(ESPHomeLoaderMixin, PurePythonLoader): """Loader class that keeps track of line numbers.""" + def __init__( + self, + stream: TextIOBase | BytesIO, + name: str, + yaml_loader: Callable[[str], dict[str, Any]], + ) -> None: + PurePythonLoader.__init__(self, stream) + ESPHomeLoaderMixin.__init__(self, name, yaml_loader) + for _loader in (ESPHomeLoader, ESPHomePurePythonLoader): _loader.add_constructor("tag:yaml.org,2002:int", _loader.construct_yaml_int) @@ -388,17 +420,30 @@ def load_yaml(fname: str, clear_secrets: bool = True) -> Any: return _load_yaml_internal(fname) -def parse_yaml(file_name: str, file_handle: TextIOWrapper) -> Any: +def _load_yaml_internal(fname: str) -> Any: + """Load a YAML file.""" + try: + with open(fname, encoding="utf-8") as f_handle: + return parse_yaml(fname, f_handle) + except (UnicodeDecodeError, OSError) as err: + raise EsphomeError(f"Error reading file {fname}: {err}") from err + + +def parse_yaml( + file_name: str, file_handle: TextIOWrapper, yaml_loader=_load_yaml_internal +) -> Any: """Parse a YAML file.""" try: - return _load_yaml_internal_with_type(ESPHomeLoader, file_name, file_handle) + return _load_yaml_internal_with_type( + ESPHomeLoader, file_name, file_handle, yaml_loader + ) except EsphomeError: # Loading failed, so we now load with the Python loader which has more # readable exceptions # Rewind the stream so we can try again file_handle.seek(0, 0) return _load_yaml_internal_with_type( - ESPHomePurePythonLoader, file_name, file_handle + ESPHomePurePythonLoader, file_name, file_handle, yaml_loader ) @@ -435,23 +480,14 @@ def substitute_vars(config, vars): return result -def _load_yaml_internal(fname: str) -> Any: - """Load a YAML file.""" - try: - with open(fname, encoding="utf-8") as f_handle: - return parse_yaml(fname, f_handle) - except (UnicodeDecodeError, OSError) as err: - raise EsphomeError(f"Error reading file {fname}: {err}") from err - - def _load_yaml_internal_with_type( loader_type: type[ESPHomeLoader] | type[ESPHomePurePythonLoader], fname: str, content: TextIOWrapper, + yaml_loader: Any, ) -> Any: """Load a YAML file.""" - loader = loader_type(content) - loader.name = fname + loader = loader_type(content, fname, yaml_loader) try: return loader.get_single_data() or OrderedDict() except yaml.YAMLError as exc: @@ -470,7 +506,7 @@ def dump(dict_, show_secrets=False): ) -def _is_file_valid(name): +def _is_file_valid(name: str) -> bool: """Decide if a file is valid.""" return not name.startswith(".") diff --git a/tests/unit_tests/test_vscode.py b/tests/unit_tests/test_vscode.py new file mode 100644 index 0000000000..f5ebd63f60 --- /dev/null +++ b/tests/unit_tests/test_vscode.py @@ -0,0 +1,125 @@ +import json +import os +from unittest.mock import Mock, patch + +from esphome import vscode + + +def _run_repl_test(input_data): + """Reusable test function for different input scenarios.""" + input_data.append(_exit()) + with ( + patch("builtins.input", side_effect=input_data), + patch("sys.stdout") as mock_stdout, + ): + args = Mock([]) + args.ace = False + args.substitution = None + vscode.read_config(args) + + # Capture printed output + full_output = "".join(call[0][0] for call in mock_stdout.write.call_args_list) + return full_output.strip().split("\n") + + +def _validate(file_path: str): + return json.dumps({"type": "validate", "file": file_path}) + + +def _file_response(data: str): + return json.dumps({"type": "file_response", "content": data}) + + +def _read_file(file_path: str): + return json.dumps({"type": "read_file", "path": file_path}) + + +def _exit(): + return json.dumps({"type": "exit"}) + + +RESULT_NO_ERROR = '{"type": "result", "yaml_errors": [], "validation_errors": []}' + + +def test_multi_file(): + source_path = os.path.join("dir_path", "x.yaml") + output_lines = _run_repl_test( + [ + _validate(source_path), + # read_file x.yaml + _file_response("""esphome: + name: test1 +esp8266: + board: !secret my_secret_board +"""), + # read_file secrets.yaml + _file_response("""my_secret_board: esp1f"""), + ] + ) + + expected_lines = [ + _read_file(source_path), + _read_file(os.path.join("dir_path", "secrets.yaml")), + RESULT_NO_ERROR, + ] + + assert output_lines == expected_lines + + +def test_shows_correct_range_error(): + source_path = os.path.join("dir_path", "x.yaml") + output_lines = _run_repl_test( + [ + _validate(source_path), + # read_file x.yaml + _file_response("""esphome: + name: test1 +esp8266: + broad: !secret my_secret_board # typo here +"""), + # read_file secrets.yaml + _file_response("""my_secret_board: esp1f"""), + ] + ) + + assert len(output_lines) == 3 + error = json.loads(output_lines[2]) + validation_error = error["validation_errors"][0] + assert validation_error["message"].startswith("[broad] is an invalid option for") + range = validation_error["range"] + assert range["document"] == source_path + assert range["start_line"] == 3 + assert range["start_col"] == 2 + assert range["end_line"] == 3 + assert range["end_col"] == 7 + + +def test_shows_correct_loaded_file_error(): + source_path = os.path.join("dir_path", "x.yaml") + output_lines = _run_repl_test( + [ + _validate(source_path), + # read_file x.yaml + _file_response("""esphome: + name: test1 + +packages: + board: !include .pkg.esp8266.yaml +"""), + # read_file .pkg.esp8266.yaml + _file_response("""esp8266: + broad: esp1f # typo here +"""), + ] + ) + + assert len(output_lines) == 3 + error = json.loads(output_lines[2]) + validation_error = error["validation_errors"][0] + assert validation_error["message"].startswith("[broad] is an invalid option for") + range = validation_error["range"] + assert range["document"] == os.path.join("dir_path", ".pkg.esp8266.yaml") + assert range["start_line"] == 1 + assert range["start_col"] == 2 + assert range["end_line"] == 1 + assert range["end_col"] == 7 diff --git a/tests/unit_tests/test_yaml_util.py b/tests/unit_tests/test_yaml_util.py index 828b2bf14b..f31e9554dc 100644 --- a/tests/unit_tests/test_yaml_util.py +++ b/tests/unit_tests/test_yaml_util.py @@ -42,3 +42,23 @@ def test_loading_a_missing_file(fixture_path): yaml_util.load_yaml(yaml_file) except EsphomeError as err: assert "missing.yaml" in str(err) + + +def test_parsing_with_custom_loader(fixture_path): + """Test custom loader used for vscode connection + Default loader is tested in test_include_with_vars + """ + yaml_file = fixture_path / "yaml_util" / "includetest.yaml" + + loader_calls = [] + + def custom_loader(fname): + loader_calls.append(fname) + + with open(yaml_file, encoding="utf-8") as f_handle: + yaml_util.parse_yaml(yaml_file, f_handle, custom_loader) + + assert len(loader_calls) == 3 + assert loader_calls[0].endswith("includes/included.yaml") + assert loader_calls[1].endswith("includes/list.yaml") + assert loader_calls[2].endswith("includes/scalar.yaml")