From c33ba541b0f4ac024d25ed4d833b32d6817a1f71 Mon Sep 17 00:00:00 2001 From: epenet <6771947+epenet@users.noreply.github.com> Date: Wed, 11 Sep 2024 15:11:03 +0200 Subject: [PATCH] Add flexibility to HassEnforceClassModule (#125739) * Add flexibility to HassEnforceClassModule * Adjust --- pylint/plugins/hass_enforce_class_module.py | 5 ++- tests/pylint/test_enforce_class_module.py | 38 ++++++++++++++++++--- 2 files changed, 38 insertions(+), 5 deletions(-) diff --git a/pylint/plugins/hass_enforce_class_module.py b/pylint/plugins/hass_enforce_class_module.py index dcd42f9a1c1..b8f83b1602f 100644 --- a/pylint/plugins/hass_enforce_class_module.py +++ b/pylint/plugins/hass_enforce_class_module.py @@ -68,11 +68,14 @@ class HassEnforceClassModule(BaseChecker): # we only want to check components if not root_name.startswith("homeassistant.components."): return + parts = root_name.split(".") + current_module = parts[3] if len(parts) > 3 else "" ancestors: list[ClassDef] | None = None for match in _MODULES: - if root_name.endswith(f".{match.expected_module}"): + # Allow module.py and module/sub_module.py + if current_module == match.expected_module: continue if ancestors is None: diff --git a/tests/pylint/test_enforce_class_module.py b/tests/pylint/test_enforce_class_module.py index b0f071fde52..13d3c2538a1 100644 --- a/tests/pylint/test_enforce_class_module.py +++ b/tests/pylint/test_enforce_class_module.py @@ -41,11 +41,21 @@ from . import assert_adds_messages, assert_no_messages ), ], ) +@pytest.mark.parametrize( + "path", + [ + "homeassistant.components.pylint_test.coordinator", + "homeassistant.components.pylint_test.coordinator.my_coordinator", + ], +) def test_enforce_class_module_good( - linter: UnittestLinter, enforce_class_module_checker: BaseChecker, code: str + linter: UnittestLinter, + enforce_class_module_checker: BaseChecker, + code: str, + path: str, ) -> None: """Good test cases.""" - root_node = astroid.parse(code, "homeassistant.components.pylint_test.coordinator") + root_node = astroid.parse(code, path) walker = ASTWalker(linter) walker.add_checker(enforce_class_module_checker) @@ -53,9 +63,19 @@ def test_enforce_class_module_good( walker.walk(root_node) +@pytest.mark.parametrize( + "path", + [ + "homeassistant.components.pylint_test", + "homeassistant.components.pylint_test.my_coordinator", + "homeassistant.components.pylint_test.coordinator_other", + "homeassistant.components.pylint_test.sensor", + ], +) def test_enforce_class_module_bad_simple( linter: UnittestLinter, enforce_class_module_checker: BaseChecker, + path: str, ) -> None: """Bad test case with coordinator extending directly.""" root_node = astroid.parse( @@ -66,7 +86,7 @@ def test_enforce_class_module_bad_simple( class TestCoordinator(DataUpdateCoordinator): pass """, - "homeassistant.components.pylint_test", + path, ) walker = ASTWalker(linter) walker.add_checker(enforce_class_module_checker) @@ -87,9 +107,19 @@ def test_enforce_class_module_bad_simple( walker.walk(root_node) +@pytest.mark.parametrize( + "path", + [ + "homeassistant.components.pylint_test", + "homeassistant.components.pylint_test.my_coordinator", + "homeassistant.components.pylint_test.coordinator_other", + "homeassistant.components.pylint_test.sensor", + ], +) def test_enforce_class_module_bad_nested( linter: UnittestLinter, enforce_class_module_checker: BaseChecker, + path: str, ) -> None: """Bad test case with nested coordinators.""" root_node = astroid.parse( @@ -103,7 +133,7 @@ def test_enforce_class_module_bad_nested( class NopeCoordinator(TestCoordinator): pass """, - "homeassistant.components.pylint_test", + path, ) walker = ASTWalker(linter) walker.add_checker(enforce_class_module_checker)