diff --git a/tests/pylint/__init__.py b/tests/pylint/__init__.py index d6bdd6675f0..e03a2d2a118 100644 --- a/tests/pylint/__init__.py +++ b/tests/pylint/__init__.py @@ -1 +1,31 @@ """Tests for pylint.""" +import contextlib + +from pylint.testutils.unittest_linter import UnittestLinter + + +@contextlib.contextmanager +def assert_no_messages(linter: UnittestLinter): + """Assert that no messages are added by the given method.""" + with assert_adds_messages(linter): + yield + + +@contextlib.contextmanager +def assert_adds_messages(linter: UnittestLinter, *messages): + """Assert that exactly the given method adds the given messages. + + The list of messages must exactly match *all* the messages added by the + method. Additionally, we check to see whether the args in each message can + actually be substituted into the message string. + """ + yield + got = linter.release_messages() + no_msg = "No message." + expected = "\n".join(repr(m) for m in messages) or no_msg + got_str = "\n".join(repr(m) for m in got) or no_msg + msg = ( + "Expected messages did not match actual.\n" + f"\nExpected:\n{expected}\n\nGot:\n{got_str}\n" + ) + assert got == list(messages), msg diff --git a/tests/pylint/conftest.py b/tests/pylint/conftest.py new file mode 100644 index 00000000000..887f50fb628 --- /dev/null +++ b/tests/pylint/conftest.py @@ -0,0 +1,30 @@ +"""Configuration for pylint tests.""" +from importlib.machinery import SourceFileLoader +from types import ModuleType + +from pylint.checkers import BaseChecker +from pylint.testutils.unittest_linter import UnittestLinter +import pytest + + +@pytest.fixture(name="hass_enforce_type_hints") +def hass_enforce_type_hints_fixture() -> ModuleType: + """Fixture to provide a requests mocker.""" + loader = SourceFileLoader( + "hass_enforce_type_hints", "pylint/plugins/hass_enforce_type_hints.py" + ) + return loader.load_module(None) + + +@pytest.fixture(name="linter") +def linter_fixture() -> UnittestLinter: + """Fixture to provide a requests mocker.""" + return UnittestLinter() + + +@pytest.fixture(name="type_hint_checker") +def type_hint_checker_fixture(hass_enforce_type_hints, linter) -> BaseChecker: + """Fixture to provide a requests mocker.""" + type_hint_checker = hass_enforce_type_hints.HassTypeHintChecker(linter) + type_hint_checker.module = "homeassistant.components.pylint_test" + return type_hint_checker diff --git a/tests/pylint/test_enforce_type_hints.py b/tests/pylint/test_enforce_type_hints.py index fe60ed022f4..81fdd2fa916 100644 --- a/tests/pylint/test_enforce_type_hints.py +++ b/tests/pylint/test_enforce_type_hints.py @@ -1,16 +1,17 @@ """Tests for pylint hass_enforce_type_hints plugin.""" # pylint:disable=protected-access -from importlib.machinery import SourceFileLoader import re +from types import ModuleType +from unittest.mock import patch +import astroid +from pylint.checkers import BaseChecker +import pylint.testutils +from pylint.testutils.unittest_linter import UnittestLinter import pytest -loader = SourceFileLoader( - "hass_enforce_type_hints", "pylint/plugins/hass_enforce_type_hints.py" -) -hass_enforce_type_hints = loader.load_module(None) -_TYPE_HINT_MATCHERS: dict[str, re.Pattern] = hass_enforce_type_hints._TYPE_HINT_MATCHERS +from . import assert_adds_messages, assert_no_messages @pytest.mark.parametrize( @@ -20,9 +21,17 @@ _TYPE_HINT_MATCHERS: dict[str, re.Pattern] = hass_enforce_type_hints._TYPE_HINT_ ("Callable[..., Awaitable[None]]", "Callable", "...", "Awaitable[None]"), ], ) -def test_regex_x_of_y_comma_z(string, expected_x, expected_y, expected_z): +def test_regex_x_of_y_comma_z( + hass_enforce_type_hints: ModuleType, + string: str, + expected_x: str, + expected_y: str, + expected_z: str, +) -> None: """Test x_of_y_comma_z regexes.""" - assert (match := _TYPE_HINT_MATCHERS["x_of_y_comma_z"].match(string)) + matchers: dict[str, re.Pattern] = hass_enforce_type_hints._TYPE_HINT_MATCHERS + + assert (match := matchers["x_of_y_comma_z"].match(string)) assert match.group(0) == string assert match.group(1) == expected_x assert match.group(2) == expected_y @@ -33,9 +42,122 @@ def test_regex_x_of_y_comma_z(string, expected_x, expected_y, expected_z): ("string", "expected_a", "expected_b"), [("DiscoveryInfoType | None", "DiscoveryInfoType", "None")], ) -def test_regex_a_or_b(string, expected_a, expected_b): +def test_regex_a_or_b( + hass_enforce_type_hints: ModuleType, string: str, expected_a: str, expected_b: str +) -> None: """Test a_or_b regexes.""" - assert (match := _TYPE_HINT_MATCHERS["a_or_b"].match(string)) + matchers: dict[str, re.Pattern] = hass_enforce_type_hints._TYPE_HINT_MATCHERS + + assert (match := matchers["a_or_b"].match(string)) assert match.group(0) == string assert match.group(1) == expected_a assert match.group(2) == expected_b + + +@pytest.mark.parametrize( + "code", + [ + """ + async def setup( #@ + arg1, arg2 + ): + pass + """ + ], +) +def test_ignore_not_annotations( + hass_enforce_type_hints: ModuleType, type_hint_checker: BaseChecker, code: str +) -> None: + """Ensure that _is_valid_type is not run if there are no annotations.""" + func_node = astroid.extract_node(code) + + with patch.object( + hass_enforce_type_hints, "_is_valid_type", return_value=True + ) as is_valid_type: + type_hint_checker.visit_asyncfunctiondef(func_node) + is_valid_type.assert_not_called() + + +@pytest.mark.parametrize( + "code", + [ + """ + async def setup( #@ + arg1: ArgHint, arg2 + ): + pass + """, + """ + async def setup( #@ + arg1, arg2 + ) -> ReturnHint: + pass + """, + """ + async def setup( #@ + arg1: ArgHint, arg2: ArgHint + ) -> ReturnHint: + pass + """, + ], +) +def test_dont_ignore_partial_annotations( + hass_enforce_type_hints: ModuleType, type_hint_checker: BaseChecker, code: str +) -> None: + """Ensure that _is_valid_type is run if there is at least one annotation.""" + func_node = astroid.extract_node(code) + + with patch.object( + hass_enforce_type_hints, "_is_valid_type", return_value=True + ) as is_valid_type: + type_hint_checker.visit_asyncfunctiondef(func_node) + is_valid_type.assert_called() + + +def test_invalid_discovery_info( + linter: UnittestLinter, type_hint_checker: BaseChecker +) -> None: + """Ensure invalid hints are rejected for discovery_info.""" + type_hint_checker.module = "homeassistant.components.pylint_test.device_tracker" + func_node, discovery_info_node = astroid.extract_node( + """ + async def async_setup_scanner( #@ + hass: HomeAssistant, + config: ConfigType, + async_see: Callable[..., Awaitable[None]], + discovery_info: dict[str, Any] | None = None, #@ + ) -> bool: + pass + """ + ) + + with assert_adds_messages( + linter, + pylint.testutils.MessageTest( + msg_id="hass-argument-type", + node=discovery_info_node, + args=(4, "DiscoveryInfoType | None"), + ), + ): + type_hint_checker.visit_asyncfunctiondef(func_node) + + +def test_valid_discovery_info( + linter: UnittestLinter, type_hint_checker: BaseChecker +) -> None: + """Ensure valid hints are accepted for discovery_info.""" + type_hint_checker.module = "homeassistant.components.pylint_test.device_tracker" + func_node = astroid.extract_node( + """ + async def async_setup_scanner( #@ + hass: HomeAssistant, + config: ConfigType, + async_see: Callable[..., Awaitable[None]], + discovery_info: DiscoveryInfoType | None = None, + ) -> bool: + pass + """ + ) + + with assert_no_messages(linter): + type_hint_checker.visit_asyncfunctiondef(func_node)