Enforce type hints for config_flow (#72756)

* Enforce type hints for config_flow

* Keep astroid migration for another PR

* Defer elif case

* Adjust tests

* Use ancestors

* Match on single base_class

* Invert for loops

* Review comments

* slots is new in 3.10
This commit is contained in:
epenet 2022-06-01 13:09:53 +02:00 committed by GitHub
parent 5d2326386d
commit 4c7837a576
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 163 additions and 2 deletions

View File

@ -25,6 +25,14 @@ class TypeHintMatch:
return_type: list[str] | str | None | object return_type: list[str] | str | None | object
@dataclass
class ClassTypeHintMatch:
"""Class for pattern matching."""
base_class: str
matches: list[TypeHintMatch]
_TYPE_HINT_MATCHERS: dict[str, re.Pattern[str]] = { _TYPE_HINT_MATCHERS: dict[str, re.Pattern[str]] = {
# a_or_b matches items such as "DiscoveryInfoType | None" # a_or_b matches items such as "DiscoveryInfoType | None"
"a_or_b": re.compile(r"^(\w+) \| (\w+)$"), "a_or_b": re.compile(r"^(\w+) \| (\w+)$"),
@ -368,6 +376,65 @@ _FUNCTION_MATCH: dict[str, list[TypeHintMatch]] = {
], ],
} }
_CLASS_MATCH: dict[str, list[ClassTypeHintMatch]] = {
"config_flow": [
ClassTypeHintMatch(
base_class="ConfigFlow",
matches=[
TypeHintMatch(
function_name="async_step_dhcp",
arg_types={
1: "DhcpServiceInfo",
},
return_type="FlowResult",
),
TypeHintMatch(
function_name="async_step_hassio",
arg_types={
1: "HassioServiceInfo",
},
return_type="FlowResult",
),
TypeHintMatch(
function_name="async_step_homekit",
arg_types={
1: "ZeroconfServiceInfo",
},
return_type="FlowResult",
),
TypeHintMatch(
function_name="async_step_mqtt",
arg_types={
1: "MqttServiceInfo",
},
return_type="FlowResult",
),
TypeHintMatch(
function_name="async_step_ssdp",
arg_types={
1: "SsdpServiceInfo",
},
return_type="FlowResult",
),
TypeHintMatch(
function_name="async_step_usb",
arg_types={
1: "UsbServiceInfo",
},
return_type="FlowResult",
),
TypeHintMatch(
function_name="async_step_zeroconf",
arg_types={
1: "ZeroconfServiceInfo",
},
return_type="FlowResult",
),
],
),
]
}
def _is_valid_type( def _is_valid_type(
expected_type: list[str] | str | None | object, node: astroid.NodeNG expected_type: list[str] | str | None | object, node: astroid.NodeNG
@ -494,10 +561,12 @@ class HassTypeHintChecker(BaseChecker): # type: ignore[misc]
def __init__(self, linter: PyLinter | None = None) -> None: def __init__(self, linter: PyLinter | None = None) -> None:
super().__init__(linter) super().__init__(linter)
self._function_matchers: list[TypeHintMatch] = [] self._function_matchers: list[TypeHintMatch] = []
self._class_matchers: list[ClassTypeHintMatch] = []
def visit_module(self, node: astroid.Module) -> None: def visit_module(self, node: astroid.Module) -> None:
"""Called when a Module node is visited.""" """Called when a Module node is visited."""
self._function_matchers = [] self._function_matchers = []
self._class_matchers = []
if (module_platform := _get_module_platform(node.name)) is None: if (module_platform := _get_module_platform(node.name)) is None:
return return
@ -505,8 +574,28 @@ class HassTypeHintChecker(BaseChecker): # type: ignore[misc]
if module_platform in _PLATFORMS: if module_platform in _PLATFORMS:
self._function_matchers.extend(_FUNCTION_MATCH["__any_platform__"]) self._function_matchers.extend(_FUNCTION_MATCH["__any_platform__"])
if matches := _FUNCTION_MATCH.get(module_platform): if function_matches := _FUNCTION_MATCH.get(module_platform):
self._function_matchers.extend(matches) self._function_matchers.extend(function_matches)
if class_matches := _CLASS_MATCH.get(module_platform):
self._class_matchers = class_matches
def visit_classdef(self, node: astroid.ClassDef) -> None:
"""Called when a ClassDef node is visited."""
ancestor: astroid.ClassDef
for ancestor in node.ancestors():
for class_matches in self._class_matchers:
if ancestor.name == class_matches.base_class:
self._visit_class_functions(node, class_matches.matches)
def _visit_class_functions(
self, node: astroid.ClassDef, matches: list[TypeHintMatch]
) -> None:
for match in matches:
for function_node in node.mymethods():
function_name: str | None = function_node.name
if match.function_name == function_name:
self._check_function(function_node, match)
def visit_functiondef(self, node: astroid.FunctionDef) -> None: def visit_functiondef(self, node: astroid.FunctionDef) -> None:
"""Called when a FunctionDef node is visited.""" """Called when a FunctionDef node is visited."""

View File

@ -277,3 +277,75 @@ def test_valid_list_dict_str_any(
with assert_no_messages(linter): with assert_no_messages(linter):
type_hint_checker.visit_asyncfunctiondef(func_node) type_hint_checker.visit_asyncfunctiondef(func_node)
def test_invalid_config_flow_step(
linter: UnittestLinter, type_hint_checker: BaseChecker
) -> None:
"""Ensure invalid hints are rejected for ConfigFlow step."""
class_node, func_node, arg_node = astroid.extract_node(
"""
class ConfigFlow():
pass
class AxisFlowHandler( #@
ConfigFlow, domain=AXIS_DOMAIN
):
async def async_step_zeroconf( #@
self,
device_config: dict #@
):
pass
""",
"homeassistant.components.pylint_test.config_flow",
)
type_hint_checker.visit_module(class_node.parent)
with assert_adds_messages(
linter,
pylint.testutils.MessageTest(
msg_id="hass-argument-type",
node=arg_node,
args=(2, "ZeroconfServiceInfo"),
line=10,
col_offset=8,
end_line=10,
end_col_offset=27,
),
pylint.testutils.MessageTest(
msg_id="hass-return-type",
node=func_node,
args="FlowResult",
line=8,
col_offset=4,
end_line=8,
end_col_offset=33,
),
):
type_hint_checker.visit_classdef(class_node)
def test_valid_config_flow_step(
linter: UnittestLinter, type_hint_checker: BaseChecker
) -> None:
"""Ensure valid hints are accepted for ConfigFlow step."""
class_node = astroid.extract_node(
"""
class ConfigFlow():
pass
class AxisFlowHandler( #@
ConfigFlow, domain=AXIS_DOMAIN
):
async def async_step_zeroconf(
self,
device_config: ZeroconfServiceInfo
) -> FlowResult:
pass
""",
"homeassistant.components.pylint_test.config_flow",
)
type_hint_checker.visit_module(class_node.parent)
with assert_no_messages(linter):
type_hint_checker.visit_classdef(class_node)