Add base Entity classes to enforce-class-module pylint plugin (#126473)

This commit is contained in:
epenet 2024-09-24 08:52:07 +02:00 committed by GitHub
parent 31200040da
commit 61ff40c299
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 47 additions and 10 deletions

View File

@ -98,7 +98,7 @@ class PassiveBluetoothDataUpdateCoordinator(
self.async_update_listeners()
class PassiveBluetoothCoordinatorEntity(
class PassiveBluetoothCoordinatorEntity( # pylint: disable=hass-enforce-class-module
BaseCoordinatorEntity[_PassiveBluetoothDataUpdateCoordinatorT]
):
"""A class for entities using DataUpdateCoordinator."""

View File

@ -28,7 +28,9 @@ async def async_setup_entry(
@dataclass(frozen=True, kw_only=True)
class StarlinkDeviceTrackerEntityDescription(EntityDescription):
class StarlinkDeviceTrackerEntityDescription( # pylint: disable=hass-enforce-class-module
EntityDescription
):
"""Describes a Starlink button entity."""
latitude_fn: Callable[[StarlinkData], float]

View File

@ -2,14 +2,23 @@
from __future__ import annotations
from ast import ClassDef
from astroid import nodes
from pylint.checkers import BaseChecker
from pylint.lint import PyLinter
from homeassistant.const import Platform
_BASE_ENTITY_MODULES: set[str] = {
"BaseCoordinatorEntity",
"CoordinatorEntity",
"Entity",
"EntityDescription",
"ManualTriggerEntity",
"RestoreEntity",
"ToggleEntity",
"ToggleEntityDescription",
"TriggerBaseEntity",
}
_MODULES: dict[str, set[str]] = {
"air_quality": {"AirQualityEntity"},
"alarm_control_panel": {
@ -82,6 +91,11 @@ _ENTITY_COMPONENTS: set[str] = {platform.value for platform in Platform}.union(
)
_MODULE_CLASSES = {
class_name for classes in _MODULES.values() for class_name in classes
}
class HassEnforceClassModule(BaseChecker):
"""Checker for class in correct module."""
@ -106,11 +120,15 @@ class HassEnforceClassModule(BaseChecker):
current_integration = parts[2]
current_module = parts[3] if len(parts) > 3 else ""
ancestors = list(node.ancestors())
if current_module != "entity" and current_integration not in _ENTITY_COMPONENTS:
top_level_ancestors = list(node.ancestors(recurs=False))
for ancestor in top_level_ancestors:
if ancestor.name == "Entity":
if ancestor.name in _BASE_ENTITY_MODULES and not any(
anc.name in _MODULE_CLASSES for anc in ancestors
):
self.add_message(
"hass-enforce-class-module",
node=node,
@ -118,15 +136,10 @@ class HassEnforceClassModule(BaseChecker):
)
return
ancestors: list[ClassDef] | None = None
for expected_module, classes in _MODULES.items():
if expected_module in (current_module, current_integration):
continue
if ancestors is None:
ancestors = list(node.ancestors()) # cache result for other modules
for ancestor in ancestors:
if ancestor.name in classes:
self.add_message(

View File

@ -84,6 +84,12 @@ def test_enforce_class_platform_good(
class CustomSensorEntity(SensorEntity):
pass
class CoordinatorEntity:
pass
class CustomCoordinatorSensorEntity(CoordinatorEntity, SensorEntity):
pass
"""
root_node = astroid.parse(code, path)
walker = ASTWalker(linter)
@ -115,6 +121,12 @@ def test_enforce_class_module_bad_simple(
class TestCoordinator(DataUpdateCoordinator):
pass
class CoordinatorEntity:
pass
class CustomCoordinatorSensorEntity(CoordinatorEntity):
pass
""",
path,
)
@ -133,6 +145,16 @@ def test_enforce_class_module_bad_simple(
end_line=5,
end_col_offset=21,
),
MessageTest(
msg_id="hass-enforce-class-module",
line=11,
node=root_node.body[3],
args=("CoordinatorEntity", "entity"),
confidence=UNDEFINED,
col_offset=0,
end_line=11,
end_col_offset=35,
),
):
walker.walk(root_node)