diff --git a/esphome/components/openthread/__init__.py b/esphome/components/openthread/__init__.py index 393c47e720..65138e28c7 100644 --- a/esphome/components/openthread/__init__.py +++ b/esphome/components/openthread/__init__.py @@ -22,7 +22,6 @@ from .const import ( CONF_SRP_ID, CONF_TLV, ) -from .tlv import parse_tlv CODEOWNERS = ["@mrene"] @@ -43,29 +42,40 @@ def set_sdkconfig_options(config): add_idf_sdkconfig_option("CONFIG_OPENTHREAD_CLI", False) add_idf_sdkconfig_option("CONFIG_OPENTHREAD_ENABLED", True) - add_idf_sdkconfig_option("CONFIG_OPENTHREAD_NETWORK_PANID", config[CONF_PAN_ID]) - add_idf_sdkconfig_option("CONFIG_OPENTHREAD_NETWORK_CHANNEL", config[CONF_CHANNEL]) - add_idf_sdkconfig_option( - "CONFIG_OPENTHREAD_NETWORK_MASTERKEY", f"{config[CONF_NETWORK_KEY]:X}".lower() - ) - if network_name := config.get(CONF_NETWORK_NAME): - add_idf_sdkconfig_option("CONFIG_OPENTHREAD_NETWORK_NAME", network_name) + if tlv := config.get(CONF_TLV): + cg.add_define("USE_OPENTHREAD_TLVS", tlv) + else: + if pan_id := config.get(CONF_PAN_ID): + add_idf_sdkconfig_option("CONFIG_OPENTHREAD_NETWORK_PANID", pan_id) - if (ext_pan_id := config.get(CONF_EXT_PAN_ID)) is not None: - add_idf_sdkconfig_option( - "CONFIG_OPENTHREAD_NETWORK_EXTPANID", f"{ext_pan_id:X}".lower() - ) - if (mesh_local_prefix := config.get(CONF_MESH_LOCAL_PREFIX)) is not None: - add_idf_sdkconfig_option( - "CONFIG_OPENTHREAD_MESH_LOCAL_PREFIX", f"{mesh_local_prefix}".lower() - ) - if (pskc := config.get(CONF_PSKC)) is not None: - add_idf_sdkconfig_option("CONFIG_OPENTHREAD_NETWORK_PSKC", f"{pskc:X}".lower()) + if channel := config.get(CONF_CHANNEL): + add_idf_sdkconfig_option("CONFIG_OPENTHREAD_NETWORK_CHANNEL", channel) - if CONF_FORCE_DATASET in config: - if config[CONF_FORCE_DATASET]: - cg.add_define("CONFIG_OPENTHREAD_FORCE_DATASET") + if network_key := config.get(CONF_NETWORK_KEY): + add_idf_sdkconfig_option( + "CONFIG_OPENTHREAD_NETWORK_MASTERKEY", f"{network_key:X}".lower() + ) + + if network_name := config.get(CONF_NETWORK_NAME): + add_idf_sdkconfig_option("CONFIG_OPENTHREAD_NETWORK_NAME", network_name) + + if (ext_pan_id := config.get(CONF_EXT_PAN_ID)) is not None: + add_idf_sdkconfig_option( + "CONFIG_OPENTHREAD_NETWORK_EXTPANID", f"{ext_pan_id:X}".lower() + ) + if (mesh_local_prefix := config.get(CONF_MESH_LOCAL_PREFIX)) is not None: + add_idf_sdkconfig_option( + "CONFIG_OPENTHREAD_MESH_LOCAL_PREFIX", f"{mesh_local_prefix}".lower() + ) + if (pskc := config.get(CONF_PSKC)) is not None: + add_idf_sdkconfig_option( + "CONFIG_OPENTHREAD_NETWORK_PSKC", f"{pskc:X}".lower() + ) + + if force_dataset := config.get(CONF_FORCE_DATASET): + if force_dataset: + cg.add_define("USE_OPENTHREAD_FORCE_DATASET") add_idf_sdkconfig_option("CONFIG_OPENTHREAD_DNS64_CLIENT", True) add_idf_sdkconfig_option("CONFIG_OPENTHREAD_SRP_CLIENT", True) @@ -79,22 +89,11 @@ openthread_ns = cg.esphome_ns.namespace("openthread") OpenThreadComponent = openthread_ns.class_("OpenThreadComponent", cg.Component) OpenThreadSrpComponent = openthread_ns.class_("OpenThreadSrpComponent", cg.Component) - -def _convert_tlv(config): - if tlv := config.get(CONF_TLV): - config = config.copy() - parsed_tlv = parse_tlv(tlv) - validated = _CONNECTION_SCHEMA(parsed_tlv) - config.update(validated) - del config[CONF_TLV] - return config - - _CONNECTION_SCHEMA = cv.Schema( { - cv.Inclusive(CONF_PAN_ID, "manual"): cv.hex_int, - cv.Inclusive(CONF_CHANNEL, "manual"): cv.int_, - cv.Inclusive(CONF_NETWORK_KEY, "manual"): cv.hex_int, + cv.Optional(CONF_PAN_ID): cv.hex_int, + cv.Optional(CONF_CHANNEL): cv.int_, + cv.Optional(CONF_NETWORK_KEY): cv.hex_int, cv.Optional(CONF_EXT_PAN_ID): cv.hex_int, cv.Optional(CONF_NETWORK_NAME): cv.string_strict, cv.Optional(CONF_PSKC): cv.hex_int, @@ -112,8 +111,7 @@ CONFIG_SCHEMA = cv.All( cv.Optional(CONF_TLV): cv.string_strict, } ).extend(_CONNECTION_SCHEMA), - cv.has_exactly_one_key(CONF_PAN_ID, CONF_TLV), - _convert_tlv, + cv.has_exactly_one_key(CONF_NETWORK_KEY, CONF_TLV), cv.only_with_esp_idf, only_on_variant(supported=[VARIANT_ESP32C6, VARIANT_ESP32H2]), ) diff --git a/esphome/components/openthread/openthread_esp.cpp b/esphome/components/openthread/openthread_esp.cpp index c5c817382f..dc303cef17 100644 --- a/esphome/components/openthread/openthread_esp.cpp +++ b/esphome/components/openthread/openthread_esp.cpp @@ -111,14 +111,36 @@ void OpenThreadComponent::ot_main() { esp_openthread_cli_create_task(); #endif ESP_LOGI(TAG, "Activating dataset..."); - otOperationalDatasetTlvs dataset; + otOperationalDatasetTlvs dataset = {}; -#ifdef CONFIG_OPENTHREAD_FORCE_DATASET - ESP_ERROR_CHECK(esp_openthread_auto_start(NULL)); -#else +#ifndef USE_OPENTHREAD_FORCE_DATASET + // Check if openthread has a valid dataset from a previous execution otError error = otDatasetGetActiveTlvs(esp_openthread_get_instance(), &dataset); - ESP_ERROR_CHECK(esp_openthread_auto_start((error == OT_ERROR_NONE) ? &dataset : NULL)); + if (error != OT_ERROR_NONE) { + // Make sure the length is 0 so we fallback to the configuration + dataset.mLength = 0; + } else { + ESP_LOGI(TAG, "Found OpenThread-managed dataset, ignoring esphome configuration"); + ESP_LOGI(TAG, "(set force_dataset: true to override)"); + } #endif + +#ifdef USE_OPENTHREAD_TLVS + if (dataset.mLength == 0) { + // If we didn't have an active dataset, and we have tlvs, parse it and pass it to esp_openthread_auto_start + size_t len = (sizeof(USE_OPENTHREAD_TLVS) - 1) / 2; + if (len > sizeof(dataset.mTlvs)) { + ESP_LOGW(TAG, "TLV buffer too small, truncating"); + len = sizeof(dataset.mTlvs); + } + parse_hex(USE_OPENTHREAD_TLVS, sizeof(USE_OPENTHREAD_TLVS) - 1, dataset.mTlvs, len); + dataset.mLength = len; + } +#endif + + // Pass the existing dataset, or NULL which will use the preprocessor definitions + ESP_ERROR_CHECK(esp_openthread_auto_start(dataset.mLength > 0 ? &dataset : nullptr)); + esp_openthread_launch_mainloop(); // Clean up diff --git a/esphome/components/openthread/tlv.py b/esphome/components/openthread/tlv.py deleted file mode 100644 index 4a7d21c47d..0000000000 --- a/esphome/components/openthread/tlv.py +++ /dev/null @@ -1,65 +0,0 @@ -# Sourced from https://gist.github.com/agners/0338576e0003318b63ec1ea75adc90f9 -import binascii -import ipaddress - -from esphome.const import CONF_CHANNEL - -from . import ( - CONF_EXT_PAN_ID, - CONF_MESH_LOCAL_PREFIX, - CONF_NETWORK_KEY, - CONF_NETWORK_NAME, - CONF_PAN_ID, - CONF_PSKC, -) - -TLV_TYPES = { - 0: CONF_CHANNEL, - 1: CONF_PAN_ID, - 2: CONF_EXT_PAN_ID, - 3: CONF_NETWORK_NAME, - 4: CONF_PSKC, - 5: CONF_NETWORK_KEY, - 7: CONF_MESH_LOCAL_PREFIX, -} - - -def parse_tlv(tlv) -> dict: - data = binascii.a2b_hex(tlv) - output = {} - pos = 0 - while pos < len(data): - tag = data[pos] - pos += 1 - _len = data[pos] - pos += 1 - val = data[pos : pos + _len] - pos += _len - if tag in TLV_TYPES: - if tag == 3: - output[TLV_TYPES[tag]] = val.decode("utf-8") - elif tag == 7: - mesh_local_prefix = binascii.hexlify(val).decode("utf-8") - mesh_local_prefix_str = f"{mesh_local_prefix}0000000000000000" - ipv6_bytes = bytes.fromhex(mesh_local_prefix_str) - ipv6_address = ipaddress.IPv6Address(ipv6_bytes) - output[TLV_TYPES[tag]] = f"{ipv6_address}/64" - else: - output[TLV_TYPES[tag]] = int.from_bytes(val) - return output - - -def main(): - import sys - - args = sys.argv[1:] - parsed = parse_tlv(args[0]) - # print the parsed TLV data - for key, value in parsed.items(): - if isinstance(value, bytes): - value = value.hex() - print(f"{key}: {value}") - - -if __name__ == "__main__": - main()