Enforce config-flow type hints to get options flow (#72831)

* Enforce config-flow type hints to get options flow

* Add checks on return_type

* Fix tests

* Add tests

* Add BinOp to test

* Update tests/pylint/test_enforce_type_hints.py

Co-authored-by: Marc Mueller <30130371+cdce8p@users.noreply.github.com>

* Update pylint/plugins/hass_enforce_type_hints.py

Co-authored-by: Marc Mueller <30130371+cdce8p@users.noreply.github.com>

* Add TypeHintMatch property

* Update pylint/plugins/hass_enforce_type_hints.py

Co-authored-by: Marc Mueller <30130371+cdce8p@users.noreply.github.com>

Co-authored-by: Marc Mueller <30130371+cdce8p@users.noreply.github.com>
This commit is contained in:
epenet 2022-06-13 11:14:30 +02:00 committed by GitHub
parent d9f3e9a71c
commit ca0a185b32
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 124 additions and 2 deletions

View File

@ -22,6 +22,7 @@ class TypeHintMatch:
function_name: str
arg_types: dict[int, str]
return_type: list[str] | str | None | object
check_return_type_inheritance: bool = False
@dataclass
@ -380,6 +381,14 @@ _CLASS_MATCH: dict[str, list[ClassTypeHintMatch]] = {
ClassTypeHintMatch(
base_class="ConfigFlow",
matches=[
TypeHintMatch(
function_name="async_get_options_flow",
arg_types={
0: "ConfigEntry",
},
return_type="OptionsFlow",
check_return_type_inheritance=True,
),
TypeHintMatch(
function_name="async_step_dhcp",
arg_types={
@ -504,6 +513,32 @@ def _is_valid_type(
return isinstance(node, nodes.Attribute) and node.attrname == expected_type
def _is_valid_return_type(match: TypeHintMatch, node: nodes.NodeNG) -> bool:
if _is_valid_type(match.return_type, node):
return True
if isinstance(node, nodes.BinOp):
return _is_valid_return_type(match, node.left) and _is_valid_return_type(
match, node.right
)
if (
match.check_return_type_inheritance
and isinstance(match.return_type, str)
and isinstance(node, nodes.Name)
):
ancestor: nodes.ClassDef
for infer_node in node.infer():
if isinstance(infer_node, nodes.ClassDef):
if infer_node.name == match.return_type:
return True
for ancestor in infer_node.ancestors():
if ancestor.name == match.return_type:
return True
return False
def _get_all_annotations(node: nodes.FunctionDef) -> list[nodes.NodeNG | None]:
args = node.args
annotations: list[nodes.NodeNG | None] = (
@ -619,8 +654,10 @@ class HassTypeHintChecker(BaseChecker): # type: ignore[misc]
)
# Check the return type.
if not _is_valid_type(return_type := match.return_type, node.returns):
self.add_message("hass-return-type", node=node, args=return_type or "None")
if not _is_valid_return_type(match, node.returns):
self.add_message(
"hass-return-type", node=node, args=match.return_type or "None"
)
def register(linter: PyLinter) -> None:

View File

@ -349,3 +349,88 @@ def test_valid_config_flow_step(
with assert_no_messages(linter):
type_hint_checker.visit_classdef(class_node)
def test_invalid_config_flow_async_get_options_flow(
linter: UnittestLinter, type_hint_checker: BaseChecker
) -> None:
"""Ensure invalid hints are rejected for ConfigFlow async_get_options_flow."""
class_node, func_node, arg_node = astroid.extract_node(
"""
class ConfigFlow():
pass
class AxisOptionsFlow():
pass
class AxisFlowHandler( #@
ConfigFlow, domain=AXIS_DOMAIN
):
def async_get_options_flow( #@
config_entry #@
) -> AxisOptionsFlow:
return AxisOptionsFlow(config_entry)
""",
"homeassistant.components.pylint_test.config_flow",
)
type_hint_checker.visit_module(class_node.parent)
with assert_adds_messages(
linter,
pylint.testutils.MessageTest(
msg_id="hass-argument-type",
node=arg_node,
args=(1, "ConfigEntry"),
line=12,
col_offset=8,
end_line=12,
end_col_offset=20,
),
pylint.testutils.MessageTest(
msg_id="hass-return-type",
node=func_node,
args="OptionsFlow",
line=11,
col_offset=4,
end_line=11,
end_col_offset=30,
),
):
type_hint_checker.visit_classdef(class_node)
def test_valid_config_flow_async_get_options_flow(
linter: UnittestLinter, type_hint_checker: BaseChecker
) -> None:
"""Ensure valid hints are accepted for ConfigFlow async_get_options_flow."""
class_node = astroid.extract_node(
"""
class ConfigFlow():
pass
class OptionsFlow():
pass
class AxisOptionsFlow(OptionsFlow):
pass
class OtherOptionsFlow(OptionsFlow):
pass
class AxisFlowHandler( #@
ConfigFlow, domain=AXIS_DOMAIN
):
def async_get_options_flow(
config_entry: ConfigEntry
) -> AxisOptionsFlow | OtherOptionsFlow | OptionsFlow:
if self.use_other:
return OtherOptionsFlow(config_entry)
return AxisOptionsFlow(config_entry)
""",
"homeassistant.components.pylint_test.config_flow",
)
type_hint_checker.visit_module(class_node.parent)
with assert_no_messages(linter):
type_hint_checker.visit_classdef(class_node)