mirror of
https://github.com/home-assistant/core.git
synced 2025-07-18 18:57:06 +00:00
Enforce type hints for config_flow (#72756)
* Enforce type hints for config_flow * Keep astroid migration for another PR * Defer elif case * Adjust tests * Use ancestors * Match on single base_class * Invert for loops * Review comments * slots is new in 3.10
This commit is contained in:
parent
5d2326386d
commit
4c7837a576
@ -25,6 +25,14 @@ class TypeHintMatch:
|
||||
return_type: list[str] | str | None | object
|
||||
|
||||
|
||||
@dataclass
|
||||
class ClassTypeHintMatch:
|
||||
"""Class for pattern matching."""
|
||||
|
||||
base_class: str
|
||||
matches: list[TypeHintMatch]
|
||||
|
||||
|
||||
_TYPE_HINT_MATCHERS: dict[str, re.Pattern[str]] = {
|
||||
# a_or_b matches items such as "DiscoveryInfoType | None"
|
||||
"a_or_b": re.compile(r"^(\w+) \| (\w+)$"),
|
||||
@ -368,6 +376,65 @@ _FUNCTION_MATCH: dict[str, list[TypeHintMatch]] = {
|
||||
],
|
||||
}
|
||||
|
||||
_CLASS_MATCH: dict[str, list[ClassTypeHintMatch]] = {
|
||||
"config_flow": [
|
||||
ClassTypeHintMatch(
|
||||
base_class="ConfigFlow",
|
||||
matches=[
|
||||
TypeHintMatch(
|
||||
function_name="async_step_dhcp",
|
||||
arg_types={
|
||||
1: "DhcpServiceInfo",
|
||||
},
|
||||
return_type="FlowResult",
|
||||
),
|
||||
TypeHintMatch(
|
||||
function_name="async_step_hassio",
|
||||
arg_types={
|
||||
1: "HassioServiceInfo",
|
||||
},
|
||||
return_type="FlowResult",
|
||||
),
|
||||
TypeHintMatch(
|
||||
function_name="async_step_homekit",
|
||||
arg_types={
|
||||
1: "ZeroconfServiceInfo",
|
||||
},
|
||||
return_type="FlowResult",
|
||||
),
|
||||
TypeHintMatch(
|
||||
function_name="async_step_mqtt",
|
||||
arg_types={
|
||||
1: "MqttServiceInfo",
|
||||
},
|
||||
return_type="FlowResult",
|
||||
),
|
||||
TypeHintMatch(
|
||||
function_name="async_step_ssdp",
|
||||
arg_types={
|
||||
1: "SsdpServiceInfo",
|
||||
},
|
||||
return_type="FlowResult",
|
||||
),
|
||||
TypeHintMatch(
|
||||
function_name="async_step_usb",
|
||||
arg_types={
|
||||
1: "UsbServiceInfo",
|
||||
},
|
||||
return_type="FlowResult",
|
||||
),
|
||||
TypeHintMatch(
|
||||
function_name="async_step_zeroconf",
|
||||
arg_types={
|
||||
1: "ZeroconfServiceInfo",
|
||||
},
|
||||
return_type="FlowResult",
|
||||
),
|
||||
],
|
||||
),
|
||||
]
|
||||
}
|
||||
|
||||
|
||||
def _is_valid_type(
|
||||
expected_type: list[str] | str | None | object, node: astroid.NodeNG
|
||||
@ -494,10 +561,12 @@ class HassTypeHintChecker(BaseChecker): # type: ignore[misc]
|
||||
def __init__(self, linter: PyLinter | None = None) -> None:
|
||||
super().__init__(linter)
|
||||
self._function_matchers: list[TypeHintMatch] = []
|
||||
self._class_matchers: list[ClassTypeHintMatch] = []
|
||||
|
||||
def visit_module(self, node: astroid.Module) -> None:
|
||||
"""Called when a Module node is visited."""
|
||||
self._function_matchers = []
|
||||
self._class_matchers = []
|
||||
|
||||
if (module_platform := _get_module_platform(node.name)) is None:
|
||||
return
|
||||
@ -505,8 +574,28 @@ class HassTypeHintChecker(BaseChecker): # type: ignore[misc]
|
||||
if module_platform in _PLATFORMS:
|
||||
self._function_matchers.extend(_FUNCTION_MATCH["__any_platform__"])
|
||||
|
||||
if matches := _FUNCTION_MATCH.get(module_platform):
|
||||
self._function_matchers.extend(matches)
|
||||
if function_matches := _FUNCTION_MATCH.get(module_platform):
|
||||
self._function_matchers.extend(function_matches)
|
||||
|
||||
if class_matches := _CLASS_MATCH.get(module_platform):
|
||||
self._class_matchers = class_matches
|
||||
|
||||
def visit_classdef(self, node: astroid.ClassDef) -> None:
|
||||
"""Called when a ClassDef node is visited."""
|
||||
ancestor: astroid.ClassDef
|
||||
for ancestor in node.ancestors():
|
||||
for class_matches in self._class_matchers:
|
||||
if ancestor.name == class_matches.base_class:
|
||||
self._visit_class_functions(node, class_matches.matches)
|
||||
|
||||
def _visit_class_functions(
|
||||
self, node: astroid.ClassDef, matches: list[TypeHintMatch]
|
||||
) -> None:
|
||||
for match in matches:
|
||||
for function_node in node.mymethods():
|
||||
function_name: str | None = function_node.name
|
||||
if match.function_name == function_name:
|
||||
self._check_function(function_node, match)
|
||||
|
||||
def visit_functiondef(self, node: astroid.FunctionDef) -> None:
|
||||
"""Called when a FunctionDef node is visited."""
|
||||
|
@ -277,3 +277,75 @@ def test_valid_list_dict_str_any(
|
||||
|
||||
with assert_no_messages(linter):
|
||||
type_hint_checker.visit_asyncfunctiondef(func_node)
|
||||
|
||||
|
||||
def test_invalid_config_flow_step(
|
||||
linter: UnittestLinter, type_hint_checker: BaseChecker
|
||||
) -> None:
|
||||
"""Ensure invalid hints are rejected for ConfigFlow step."""
|
||||
class_node, func_node, arg_node = astroid.extract_node(
|
||||
"""
|
||||
class ConfigFlow():
|
||||
pass
|
||||
|
||||
class AxisFlowHandler( #@
|
||||
ConfigFlow, domain=AXIS_DOMAIN
|
||||
):
|
||||
async def async_step_zeroconf( #@
|
||||
self,
|
||||
device_config: dict #@
|
||||
):
|
||||
pass
|
||||
""",
|
||||
"homeassistant.components.pylint_test.config_flow",
|
||||
)
|
||||
type_hint_checker.visit_module(class_node.parent)
|
||||
|
||||
with assert_adds_messages(
|
||||
linter,
|
||||
pylint.testutils.MessageTest(
|
||||
msg_id="hass-argument-type",
|
||||
node=arg_node,
|
||||
args=(2, "ZeroconfServiceInfo"),
|
||||
line=10,
|
||||
col_offset=8,
|
||||
end_line=10,
|
||||
end_col_offset=27,
|
||||
),
|
||||
pylint.testutils.MessageTest(
|
||||
msg_id="hass-return-type",
|
||||
node=func_node,
|
||||
args="FlowResult",
|
||||
line=8,
|
||||
col_offset=4,
|
||||
end_line=8,
|
||||
end_col_offset=33,
|
||||
),
|
||||
):
|
||||
type_hint_checker.visit_classdef(class_node)
|
||||
|
||||
|
||||
def test_valid_config_flow_step(
|
||||
linter: UnittestLinter, type_hint_checker: BaseChecker
|
||||
) -> None:
|
||||
"""Ensure valid hints are accepted for ConfigFlow step."""
|
||||
class_node = astroid.extract_node(
|
||||
"""
|
||||
class ConfigFlow():
|
||||
pass
|
||||
|
||||
class AxisFlowHandler( #@
|
||||
ConfigFlow, domain=AXIS_DOMAIN
|
||||
):
|
||||
async def async_step_zeroconf(
|
||||
self,
|
||||
device_config: ZeroconfServiceInfo
|
||||
) -> FlowResult:
|
||||
pass
|
||||
""",
|
||||
"homeassistant.components.pylint_test.config_flow",
|
||||
)
|
||||
type_hint_checker.visit_module(class_node.parent)
|
||||
|
||||
with assert_no_messages(linter):
|
||||
type_hint_checker.visit_classdef(class_node)
|
||||
|
Loading…
x
Reference in New Issue
Block a user