Improve UI in pylint plugin (#74157)

* Adjust FlowResult result type

* Adjust tests

* Adjust return_type

* Use StrEnum for base device_class

* Add test for device_class

* Add and use SentinelValues.DEVICE_CLASS

* Remove duplicate device_class

* Cleanup return-type

* Drop inheritance check from device_class

* Add caching for class methods

* Improve tests

* Adjust duplicate checks

* Adjust tests

* Fix rebase
This commit is contained in:
epenet 2022-08-02 00:03:52 +02:00 committed by GitHub
parent 652a8e9e8a
commit 3eafe13085
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 116 additions and 44 deletions

View File

@ -16,7 +16,6 @@ class _Special(Enum):
"""Sentinel values""" """Sentinel values"""
UNDEFINED = 1 UNDEFINED = 1
DEVICE_CLASS = 2
_PLATFORMS: set[str] = {platform.value for platform in Platform} _PLATFORMS: set[str] = {platform.value for platform in Platform}
@ -466,6 +465,7 @@ _CLASS_MATCH: dict[str, list[ClassTypeHintMatch]] = {
} }
# Overriding properties and functions are normally checked by mypy, and will only # Overriding properties and functions are normally checked by mypy, and will only
# be checked by pylint when --ignore-missing-annotations is False # be checked by pylint when --ignore-missing-annotations is False
_ENTITY_MATCH: list[TypeHintMatch] = [ _ENTITY_MATCH: list[TypeHintMatch] = [
TypeHintMatch( TypeHintMatch(
function_name="should_poll", function_name="should_poll",
@ -505,7 +505,7 @@ _ENTITY_MATCH: list[TypeHintMatch] = [
), ),
TypeHintMatch( TypeHintMatch(
function_name="device_class", function_name="device_class",
return_type=[_Special.DEVICE_CLASS, "str", None], return_type=["str", None],
), ),
TypeHintMatch( TypeHintMatch(
function_name="unit_of_measurement", function_name="unit_of_measurement",
@ -1416,15 +1416,6 @@ def _is_valid_type(
if expected_type is _Special.UNDEFINED: if expected_type is _Special.UNDEFINED:
return True return True
# Special case for device_class
if expected_type is _Special.DEVICE_CLASS and in_return:
return (
isinstance(node, nodes.Name)
and node.name.endswith("DeviceClass")
or isinstance(node, nodes.Attribute)
and node.attrname.endswith("DeviceClass")
)
if isinstance(expected_type, list): if isinstance(expected_type, list):
for expected_type_item in expected_type: for expected_type_item in expected_type:
if _is_valid_type(expected_type_item, node, in_return): if _is_valid_type(expected_type_item, node, in_return):
@ -1636,18 +1627,28 @@ class HassTypeHintChecker(BaseChecker): # type: ignore[misc]
def visit_classdef(self, node: nodes.ClassDef) -> None: def visit_classdef(self, node: nodes.ClassDef) -> None:
"""Called when a ClassDef node is visited.""" """Called when a ClassDef node is visited."""
ancestor: nodes.ClassDef ancestor: nodes.ClassDef
checked_class_methods: set[str] = set()
for ancestor in node.ancestors(): for ancestor in node.ancestors():
for class_matches in self._class_matchers: for class_matches in self._class_matchers:
if ancestor.name == class_matches.base_class: if ancestor.name == class_matches.base_class:
self._visit_class_functions(node, class_matches.matches) self._visit_class_functions(
node, class_matches.matches, checked_class_methods
)
def _visit_class_functions( def _visit_class_functions(
self, node: nodes.ClassDef, matches: list[TypeHintMatch] self,
node: nodes.ClassDef,
matches: list[TypeHintMatch],
checked_class_methods: set[str],
) -> None: ) -> None:
cached_methods: list[nodes.FunctionDef] = list(node.mymethods())
for match in matches: for match in matches:
for function_node in node.mymethods(): for function_node in cached_methods:
if function_node.name in checked_class_methods:
continue
if match.need_to_check_function(function_node): if match.need_to_check_function(function_node):
self._check_function(function_node, match) self._check_function(function_node, match)
checked_class_methods.add(function_node.name)
def visit_functiondef(self, node: nodes.FunctionDef) -> None: def visit_functiondef(self, node: nodes.FunctionDef) -> None:
"""Called when a FunctionDef node is visited.""" """Called when a FunctionDef node is visited."""

View File

@ -307,7 +307,10 @@ def test_invalid_config_flow_step(
"""Ensure invalid hints are rejected for ConfigFlow step.""" """Ensure invalid hints are rejected for ConfigFlow step."""
class_node, func_node, arg_node = astroid.extract_node( class_node, func_node, arg_node = astroid.extract_node(
""" """
class ConfigFlow(): class FlowHandler():
pass
class ConfigFlow(FlowHandler):
pass pass
class AxisFlowHandler( #@ class AxisFlowHandler( #@
@ -329,18 +332,18 @@ def test_invalid_config_flow_step(
msg_id="hass-argument-type", msg_id="hass-argument-type",
node=arg_node, node=arg_node,
args=(2, "ZeroconfServiceInfo", "async_step_zeroconf"), args=(2, "ZeroconfServiceInfo", "async_step_zeroconf"),
line=10, line=13,
col_offset=8, col_offset=8,
end_line=10, end_line=13,
end_col_offset=27, end_col_offset=27,
), ),
pylint.testutils.MessageTest( pylint.testutils.MessageTest(
msg_id="hass-return-type", msg_id="hass-return-type",
node=func_node, node=func_node,
args=("FlowResult", "async_step_zeroconf"), args=("FlowResult", "async_step_zeroconf"),
line=8, line=11,
col_offset=4, col_offset=4,
end_line=8, end_line=11,
end_col_offset=33, end_col_offset=33,
), ),
): ):
@ -353,7 +356,10 @@ def test_valid_config_flow_step(
"""Ensure valid hints are accepted for ConfigFlow step.""" """Ensure valid hints are accepted for ConfigFlow step."""
class_node = astroid.extract_node( class_node = astroid.extract_node(
""" """
class ConfigFlow(): class FlowHandler():
pass
class ConfigFlow(FlowHandler):
pass pass
class AxisFlowHandler( #@ class AxisFlowHandler( #@
@ -377,9 +383,16 @@ def test_invalid_config_flow_async_get_options_flow(
linter: UnittestLinter, type_hint_checker: BaseChecker linter: UnittestLinter, type_hint_checker: BaseChecker
) -> None: ) -> None:
"""Ensure invalid hints are rejected for ConfigFlow async_get_options_flow.""" """Ensure invalid hints are rejected for ConfigFlow async_get_options_flow."""
# AxisOptionsFlow doesn't inherit OptionsFlow, and therefore should fail
class_node, func_node, arg_node = astroid.extract_node( class_node, func_node, arg_node = astroid.extract_node(
""" """
class ConfigFlow(): class FlowHandler():
pass
class ConfigFlow(FlowHandler):
pass
class OptionsFlow(FlowHandler):
pass pass
class AxisOptionsFlow(): class AxisOptionsFlow():
@ -403,18 +416,18 @@ def test_invalid_config_flow_async_get_options_flow(
msg_id="hass-argument-type", msg_id="hass-argument-type",
node=arg_node, node=arg_node,
args=(1, "ConfigEntry", "async_get_options_flow"), args=(1, "ConfigEntry", "async_get_options_flow"),
line=12, line=18,
col_offset=8, col_offset=8,
end_line=12, end_line=18,
end_col_offset=20, end_col_offset=20,
), ),
pylint.testutils.MessageTest( pylint.testutils.MessageTest(
msg_id="hass-return-type", msg_id="hass-return-type",
node=func_node, node=func_node,
args=("OptionsFlow", "async_get_options_flow"), args=("OptionsFlow", "async_get_options_flow"),
line=11, line=17,
col_offset=4, col_offset=4,
end_line=11, end_line=17,
end_col_offset=30, end_col_offset=30,
), ),
): ):
@ -427,10 +440,13 @@ def test_valid_config_flow_async_get_options_flow(
"""Ensure valid hints are accepted for ConfigFlow async_get_options_flow.""" """Ensure valid hints are accepted for ConfigFlow async_get_options_flow."""
class_node = astroid.extract_node( class_node = astroid.extract_node(
""" """
class ConfigFlow(): class FlowHandler():
pass pass
class OptionsFlow(): class ConfigFlow(FlowHandler):
pass
class OptionsFlow(FlowHandler):
pass pass
class AxisOptionsFlow(OptionsFlow): class AxisOptionsFlow(OptionsFlow):
@ -467,7 +483,10 @@ def test_invalid_entity_properties(
class_node, prop_node, func_node = astroid.extract_node( class_node, prop_node, func_node = astroid.extract_node(
""" """
class LockEntity(): class Entity():
pass
class LockEntity(Entity):
pass pass
class DoorLock( #@ class DoorLock( #@
@ -495,27 +514,27 @@ def test_invalid_entity_properties(
msg_id="hass-return-type", msg_id="hass-return-type",
node=prop_node, node=prop_node,
args=(["str", None], "changed_by"), args=(["str", None], "changed_by"),
line=9, line=12,
col_offset=4, col_offset=4,
end_line=9, end_line=12,
end_col_offset=18, end_col_offset=18,
), ),
pylint.testutils.MessageTest( pylint.testutils.MessageTest(
msg_id="hass-argument-type", msg_id="hass-argument-type",
node=func_node, node=func_node,
args=("kwargs", "Any", "async_lock"), args=("kwargs", "Any", "async_lock"),
line=14, line=17,
col_offset=4, col_offset=4,
end_line=14, end_line=17,
end_col_offset=24, end_col_offset=24,
), ),
pylint.testutils.MessageTest( pylint.testutils.MessageTest(
msg_id="hass-return-type", msg_id="hass-return-type",
node=func_node, node=func_node,
args=("None", "async_lock"), args=("None", "async_lock"),
line=14, line=17,
col_offset=4, col_offset=4,
end_line=14, end_line=17,
end_col_offset=24, end_col_offset=24,
), ),
): ):
@ -531,7 +550,10 @@ def test_ignore_invalid_entity_properties(
class_node = astroid.extract_node( class_node = astroid.extract_node(
""" """
class LockEntity(): class Entity():
pass
class LockEntity(Entity):
pass pass
class DoorLock( #@ class DoorLock( #@
@ -566,7 +588,13 @@ def test_named_arguments(
class_node, func_node, percentage_node, preset_mode_node = astroid.extract_node( class_node, func_node, percentage_node, preset_mode_node = astroid.extract_node(
""" """
class FanEntity(): class Entity():
pass
class ToggleEntity(Entity):
pass
class FanEntity(ToggleEntity):
pass pass
class MyFan( #@ class MyFan( #@
@ -591,36 +619,36 @@ def test_named_arguments(
msg_id="hass-argument-type", msg_id="hass-argument-type",
node=percentage_node, node=percentage_node,
args=("percentage", "int | None", "async_turn_on"), args=("percentage", "int | None", "async_turn_on"),
line=10, line=16,
col_offset=8, col_offset=8,
end_line=10, end_line=16,
end_col_offset=18, end_col_offset=18,
), ),
pylint.testutils.MessageTest( pylint.testutils.MessageTest(
msg_id="hass-argument-type", msg_id="hass-argument-type",
node=preset_mode_node, node=preset_mode_node,
args=("preset_mode", "str | None", "async_turn_on"), args=("preset_mode", "str | None", "async_turn_on"),
line=12, line=18,
col_offset=8, col_offset=8,
end_line=12, end_line=18,
end_col_offset=24, end_col_offset=24,
), ),
pylint.testutils.MessageTest( pylint.testutils.MessageTest(
msg_id="hass-argument-type", msg_id="hass-argument-type",
node=func_node, node=func_node,
args=("kwargs", "Any", "async_turn_on"), args=("kwargs", "Any", "async_turn_on"),
line=8, line=14,
col_offset=4, col_offset=4,
end_line=8, end_line=14,
end_col_offset=27, end_col_offset=27,
), ),
pylint.testutils.MessageTest( pylint.testutils.MessageTest(
msg_id="hass-return-type", msg_id="hass-return-type",
node=func_node, node=func_node,
args=("None", "async_turn_on"), args=("None", "async_turn_on"),
line=8, line=14,
col_offset=4, col_offset=4,
end_line=8, end_line=14,
end_col_offset=27, end_col_offset=27,
), ),
): ):
@ -829,3 +857,46 @@ def test_invalid_long_tuple(
), ),
): ):
type_hint_checker.visit_classdef(class_node) type_hint_checker.visit_classdef(class_node)
def test_invalid_device_class(
linter: UnittestLinter, type_hint_checker: BaseChecker
) -> None:
"""Ensure invalid hints are rejected for entity device_class."""
# Set bypass option
type_hint_checker.config.ignore_missing_annotations = False
class_node, prop_node = astroid.extract_node(
"""
class Entity():
pass
class CoverEntity(Entity):
pass
class MyCover( #@
CoverEntity
):
@property
def device_class( #@
self
):
pass
""",
"homeassistant.components.pylint_test.cover",
)
type_hint_checker.visit_module(class_node.parent)
with assert_adds_messages(
linter,
pylint.testutils.MessageTest(
msg_id="hass-return-type",
node=prop_node,
args=(["CoverDeviceClass", "str", None], "device_class"),
line=12,
col_offset=4,
end_line=12,
end_col_offset=20,
),
):
type_hint_checker.visit_classdef(class_node)