Add base Entity classes to enforce-class-module pylint plugin (#126473)

This commit is contained in:
epenet 2024-09-24 08:52:07 +02:00 committed by GitHub
parent 31200040da
commit 61ff40c299
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 47 additions and 10 deletions

View File

@ -98,7 +98,7 @@ class PassiveBluetoothDataUpdateCoordinator(
self.async_update_listeners() self.async_update_listeners()
class PassiveBluetoothCoordinatorEntity( class PassiveBluetoothCoordinatorEntity( # pylint: disable=hass-enforce-class-module
BaseCoordinatorEntity[_PassiveBluetoothDataUpdateCoordinatorT] BaseCoordinatorEntity[_PassiveBluetoothDataUpdateCoordinatorT]
): ):
"""A class for entities using DataUpdateCoordinator.""" """A class for entities using DataUpdateCoordinator."""

View File

@ -28,7 +28,9 @@ async def async_setup_entry(
@dataclass(frozen=True, kw_only=True) @dataclass(frozen=True, kw_only=True)
class StarlinkDeviceTrackerEntityDescription(EntityDescription): class StarlinkDeviceTrackerEntityDescription( # pylint: disable=hass-enforce-class-module
EntityDescription
):
"""Describes a Starlink button entity.""" """Describes a Starlink button entity."""
latitude_fn: Callable[[StarlinkData], float] latitude_fn: Callable[[StarlinkData], float]

View File

@ -2,14 +2,23 @@
from __future__ import annotations from __future__ import annotations
from ast import ClassDef
from astroid import nodes from astroid import nodes
from pylint.checkers import BaseChecker from pylint.checkers import BaseChecker
from pylint.lint import PyLinter from pylint.lint import PyLinter
from homeassistant.const import Platform from homeassistant.const import Platform
_BASE_ENTITY_MODULES: set[str] = {
"BaseCoordinatorEntity",
"CoordinatorEntity",
"Entity",
"EntityDescription",
"ManualTriggerEntity",
"RestoreEntity",
"ToggleEntity",
"ToggleEntityDescription",
"TriggerBaseEntity",
}
_MODULES: dict[str, set[str]] = { _MODULES: dict[str, set[str]] = {
"air_quality": {"AirQualityEntity"}, "air_quality": {"AirQualityEntity"},
"alarm_control_panel": { "alarm_control_panel": {
@ -82,6 +91,11 @@ _ENTITY_COMPONENTS: set[str] = {platform.value for platform in Platform}.union(
) )
_MODULE_CLASSES = {
class_name for classes in _MODULES.values() for class_name in classes
}
class HassEnforceClassModule(BaseChecker): class HassEnforceClassModule(BaseChecker):
"""Checker for class in correct module.""" """Checker for class in correct module."""
@ -106,11 +120,15 @@ class HassEnforceClassModule(BaseChecker):
current_integration = parts[2] current_integration = parts[2]
current_module = parts[3] if len(parts) > 3 else "" current_module = parts[3] if len(parts) > 3 else ""
ancestors = list(node.ancestors())
if current_module != "entity" and current_integration not in _ENTITY_COMPONENTS: if current_module != "entity" and current_integration not in _ENTITY_COMPONENTS:
top_level_ancestors = list(node.ancestors(recurs=False)) top_level_ancestors = list(node.ancestors(recurs=False))
for ancestor in top_level_ancestors: for ancestor in top_level_ancestors:
if ancestor.name == "Entity": if ancestor.name in _BASE_ENTITY_MODULES and not any(
anc.name in _MODULE_CLASSES for anc in ancestors
):
self.add_message( self.add_message(
"hass-enforce-class-module", "hass-enforce-class-module",
node=node, node=node,
@ -118,15 +136,10 @@ class HassEnforceClassModule(BaseChecker):
) )
return return
ancestors: list[ClassDef] | None = None
for expected_module, classes in _MODULES.items(): for expected_module, classes in _MODULES.items():
if expected_module in (current_module, current_integration): if expected_module in (current_module, current_integration):
continue continue
if ancestors is None:
ancestors = list(node.ancestors()) # cache result for other modules
for ancestor in ancestors: for ancestor in ancestors:
if ancestor.name in classes: if ancestor.name in classes:
self.add_message( self.add_message(

View File

@ -84,6 +84,12 @@ def test_enforce_class_platform_good(
class CustomSensorEntity(SensorEntity): class CustomSensorEntity(SensorEntity):
pass pass
class CoordinatorEntity:
pass
class CustomCoordinatorSensorEntity(CoordinatorEntity, SensorEntity):
pass
""" """
root_node = astroid.parse(code, path) root_node = astroid.parse(code, path)
walker = ASTWalker(linter) walker = ASTWalker(linter)
@ -115,6 +121,12 @@ def test_enforce_class_module_bad_simple(
class TestCoordinator(DataUpdateCoordinator): class TestCoordinator(DataUpdateCoordinator):
pass pass
class CoordinatorEntity:
pass
class CustomCoordinatorSensorEntity(CoordinatorEntity):
pass
""", """,
path, path,
) )
@ -133,6 +145,16 @@ def test_enforce_class_module_bad_simple(
end_line=5, end_line=5,
end_col_offset=21, end_col_offset=21,
), ),
MessageTest(
msg_id="hass-enforce-class-module",
line=11,
node=root_node.body[3],
args=("CoordinatorEntity", "entity"),
confidence=UNDEFINED,
col_offset=0,
end_line=11,
end_col_offset=35,
),
): ):
walker.walk(root_node) walker.walk(root_node)