Add OTA support to ESP-IDF webserver

This commit is contained in:
J. Nick Koston 2025-06-29 10:33:49 -05:00
parent 21e1f3d103
commit b77c1d0af8
No known key found for this signature in database
6 changed files with 463 additions and 13 deletions

View File

@ -71,12 +71,6 @@ def validate_local(config):
return config
def validate_ota(config):
if CORE.using_esp_idf and config[CONF_OTA]:
raise cv.Invalid("Enabling 'ota' is not supported for IDF framework yet")
return config
def validate_sorting_groups(config):
if CONF_SORTING_GROUPS in config and config[CONF_VERSION] != 3:
raise cv.Invalid(
@ -178,7 +172,7 @@ CONFIG_SCHEMA = cv.All(
CONF_OTA,
esp8266=True,
esp32_arduino=True,
esp32_idf=False,
esp32_idf=True,
bk72xx=True,
rtl87xx=True,
): cv.boolean,
@ -190,7 +184,6 @@ CONFIG_SCHEMA = cv.All(
cv.only_on([PLATFORM_ESP32, PLATFORM_ESP8266, PLATFORM_BK72XX, PLATFORM_RTL87XX]),
default_url,
validate_local,
validate_ota,
validate_sorting_groups,
)

View File

@ -14,6 +14,10 @@
#endif
#endif
#ifdef USE_ESP_IDF
#include "esphome/components/ota/ota_backend.h"
#endif
namespace esphome {
namespace web_server_base {
@ -93,6 +97,67 @@ void OTARequestHandler::handleUpload(AsyncWebServerRequest *request, const Strin
}
}
#endif
#ifdef USE_ESP_IDF
// ESP-IDF implementation
if (index == 0) {
ESP_LOGI(TAG, "OTA Update Start: %s", filename.c_str());
this->ota_read_length_ = 0;
this->ota_started_ = false;
// Create OTA backend
this->ota_backend_ = ota::make_ota_backend();
// Begin OTA with unknown size
auto result = this->ota_backend_->begin(0);
if (result != ota::OTA_RESPONSE_OK) {
ESP_LOGE(TAG, "OTA begin failed: %d", result);
this->ota_backend_.reset();
return;
}
this->ota_started_ = true;
} else if (!this->ota_started_ || !this->ota_backend_) {
// Begin failed or was aborted
return;
}
// Write data
if (len > 0) {
auto result = this->ota_backend_->write(data, len);
if (result != ota::OTA_RESPONSE_OK) {
ESP_LOGE(TAG, "OTA write failed: %d", result);
this->ota_backend_->abort();
this->ota_backend_.reset();
this->ota_started_ = false;
return;
}
this->ota_read_length_ += len;
const uint32_t now = millis();
if (now - this->last_ota_progress_ > 1000) {
if (request->contentLength() != 0) {
float percentage = (this->ota_read_length_ * 100.0f) / request->contentLength();
ESP_LOGD(TAG, "OTA in progress: %0.1f%%", percentage);
} else {
ESP_LOGD(TAG, "OTA in progress: %u bytes read", this->ota_read_length_);
}
this->last_ota_progress_ = now;
}
}
if (final) {
auto result = this->ota_backend_->end();
if (result == ota::OTA_RESPONSE_OK) {
ESP_LOGI(TAG, "OTA update successful!");
this->parent_->set_timeout(100, []() { App.safe_reboot(); });
} else {
ESP_LOGE(TAG, "OTA end failed: %d", result);
}
this->ota_backend_.reset();
this->ota_started_ = false;
}
#endif
}
void OTARequestHandler::handleRequest(AsyncWebServerRequest *request) {
#ifdef USE_ARDUINO
@ -108,10 +173,20 @@ void OTARequestHandler::handleRequest(AsyncWebServerRequest *request) {
response->addHeader("Connection", "close");
request->send(response);
#endif
#ifdef USE_ESP_IDF
AsyncWebServerResponse *response;
if (this->ota_started_ && this->ota_backend_) {
response = request->beginResponse(200, "text/plain", "Update Successful!");
} else {
response = request->beginResponse(200, "text/plain", "Update Failed!");
}
response->addHeader("Connection", "close");
request->send(response);
#endif
}
void WebServerBase::add_ota_handler() {
#ifdef USE_ARDUINO
#if defined(USE_ARDUINO) || defined(USE_ESP_IDF)
this->add_handler(new OTARequestHandler(this)); // NOLINT
#endif
}

View File

@ -142,6 +142,10 @@ class OTARequestHandler : public AsyncWebHandler {
uint32_t last_ota_progress_{0};
uint32_t ota_read_length_{0};
WebServerBase *parent_;
#ifdef USE_ESP_IDF
std::unique_ptr<ota::OTABackend> ota_backend_;
bool ota_started_{false};
#endif
};
} // namespace web_server_base

View File

@ -0,0 +1,226 @@
#ifdef USE_ESP_IDF
#include "multipart_parser.h"
#include "esphome/core/log.h"
namespace esphome {
namespace web_server_idf {
static const char *const TAG = "multipart_parser";
bool MultipartParser::parse(const uint8_t *data, size_t len) {
// Append new data to buffer
buffer_.insert(buffer_.end(), data, data + len);
while (state_ != DONE && state_ != ERROR && !buffer_.empty()) {
switch (state_) {
case BOUNDARY_SEARCH:
if (!find_boundary()) {
return false;
}
state_ = HEADERS;
break;
case HEADERS:
if (!parse_headers()) {
return false;
}
state_ = CONTENT;
content_start_ = 0; // Content starts at current buffer position
break;
case CONTENT:
if (!extract_content()) {
return false;
}
break;
default:
break;
}
}
return part_ready_;
}
bool MultipartParser::get_current_part(Part &part) const {
if (!part_ready_ || content_length_ == 0) {
return false;
}
part.name = current_name_;
part.filename = current_filename_;
part.content_type = current_content_type_;
part.data = buffer_.data() + content_start_;
part.length = content_length_;
return true;
}
void MultipartParser::consume_part() {
if (!part_ready_) {
return;
}
// Remove consumed data from buffer
if (content_start_ + content_length_ < buffer_.size()) {
buffer_.erase(buffer_.begin(), buffer_.begin() + content_start_ + content_length_);
} else {
buffer_.clear();
}
// Reset for next part
part_ready_ = false;
content_start_ = 0;
content_length_ = 0;
current_name_.clear();
current_filename_.clear();
current_content_type_.clear();
// Look for next boundary
state_ = BOUNDARY_SEARCH;
}
void MultipartParser::reset() {
buffer_.clear();
state_ = BOUNDARY_SEARCH;
part_ready_ = false;
content_start_ = 0;
content_length_ = 0;
current_name_.clear();
current_filename_.clear();
current_content_type_.clear();
}
bool MultipartParser::find_boundary() {
// Look for boundary in buffer
size_t boundary_pos = find_pattern(reinterpret_cast<const uint8_t *>(boundary_.c_str()), boundary_.length());
if (boundary_pos == std::string::npos) {
// Keep some data for next iteration to handle split boundaries
if (buffer_.size() > boundary_.length() + 4) {
buffer_.erase(buffer_.begin(), buffer_.end() - boundary_.length() - 4);
}
return false;
}
// Remove everything up to and including the boundary
buffer_.erase(buffer_.begin(), buffer_.begin() + boundary_pos + boundary_.length());
// Skip CRLF after boundary
if (buffer_.size() >= 2 && buffer_[0] == '\r' && buffer_[1] == '\n') {
buffer_.erase(buffer_.begin(), buffer_.begin() + 2);
}
// Check if this is the end boundary
if (buffer_.size() >= 2 && buffer_[0] == '-' && buffer_[1] == '-') {
state_ = DONE;
return false;
}
return true;
}
bool MultipartParser::parse_headers() {
while (true) {
std::string line = read_line();
if (line.empty()) {
// Check if we have enough data for a line
auto crlf_pos = find_pattern(reinterpret_cast<const uint8_t *>("\r\n"), 2);
if (crlf_pos == std::string::npos) {
return false; // Need more data
}
// Empty line means headers are done
buffer_.erase(buffer_.begin(), buffer_.begin() + 2);
return true;
}
// Parse Content-Disposition header
if (line.find("Content-Disposition:") == 0) {
// Extract name
size_t name_pos = line.find("name=\"");
if (name_pos != std::string::npos) {
name_pos += 6;
size_t name_end = line.find("\"", name_pos);
if (name_end != std::string::npos) {
current_name_ = line.substr(name_pos, name_end - name_pos);
}
}
// Extract filename if present
size_t filename_pos = line.find("filename=\"");
if (filename_pos != std::string::npos) {
filename_pos += 10;
size_t filename_end = line.find("\"", filename_pos);
if (filename_end != std::string::npos) {
current_filename_ = line.substr(filename_pos, filename_end - filename_pos);
}
}
}
// Parse Content-Type header
else if (line.find("Content-Type:") == 0) {
current_content_type_ = line.substr(14);
// Trim whitespace
size_t start = current_content_type_.find_first_not_of(" \t");
if (start != std::string::npos) {
current_content_type_ = current_content_type_.substr(start);
}
}
}
}
bool MultipartParser::extract_content() {
// Look for next boundary
std::string search_boundary = "\r\n" + boundary_;
size_t boundary_pos =
find_pattern(reinterpret_cast<const uint8_t *>(search_boundary.c_str()), search_boundary.length());
if (boundary_pos != std::string::npos) {
// Found complete part
content_length_ = boundary_pos - content_start_;
part_ready_ = true;
return true;
}
// No boundary found yet, but we might have partial content
// Keep enough bytes to ensure we don't split a boundary
size_t safe_length = buffer_.size();
if (safe_length > search_boundary.length() + 4) {
safe_length -= search_boundary.length() + 4;
if (safe_length > content_start_) {
content_length_ = safe_length - content_start_;
// We have partial content but not complete yet
return false;
}
}
return false;
}
std::string MultipartParser::read_line() {
auto crlf_pos = find_pattern(reinterpret_cast<const uint8_t *>("\r\n"), 2);
if (crlf_pos == std::string::npos) {
return "";
}
std::string line(buffer_.begin(), buffer_.begin() + crlf_pos);
buffer_.erase(buffer_.begin(), buffer_.begin() + crlf_pos + 2);
return line;
}
size_t MultipartParser::find_pattern(const uint8_t *pattern, size_t pattern_len, size_t start) const {
if (buffer_.size() < pattern_len + start) {
return std::string::npos;
}
for (size_t i = start; i <= buffer_.size() - pattern_len; ++i) {
if (memcmp(buffer_.data() + i, pattern, pattern_len) == 0) {
return i;
}
}
return std::string::npos;
}
} // namespace web_server_idf
} // namespace esphome
#endif

View File

@ -0,0 +1,67 @@
#pragma once
#ifdef USE_ESP_IDF
#include <string>
#include <vector>
#include <cstring>
namespace esphome {
namespace web_server_idf {
// Multipart form data parser for ESP-IDF
class MultipartParser {
public:
enum State { BOUNDARY_SEARCH, HEADERS, CONTENT, DONE, ERROR };
struct Part {
std::string name;
std::string filename;
std::string content_type;
const uint8_t *data;
size_t length;
};
explicit MultipartParser(const std::string &boundary) : boundary_("--" + boundary), state_(BOUNDARY_SEARCH) {}
// Process incoming data chunk
// Returns true if a complete part is available
bool parse(const uint8_t *data, size_t len);
// Get the current part if available
bool get_current_part(Part &part) const;
// Consume the current part and move to next
void consume_part();
State get_state() const { return state_; }
bool is_done() const { return state_ == DONE; }
bool has_error() const { return state_ == ERROR; }
// Reset parser for reuse
void reset();
private:
bool find_boundary();
bool parse_headers();
bool extract_content();
std::string read_line();
size_t find_pattern(const uint8_t *pattern, size_t pattern_len, size_t start = 0) const;
std::string boundary_;
std::string end_boundary_;
State state_;
std::vector<uint8_t> buffer_;
// Current part info
std::string current_name_;
std::string current_filename_;
std::string current_content_type_;
size_t content_start_{0};
size_t content_length_{0};
bool part_ready_{false};
};
} // namespace web_server_idf
} // namespace esphome
#endif

View File

@ -8,6 +8,7 @@
#include "esp_tls_crypto.h"
#include "utils.h"
#include "multipart_parser.h"
#include "web_server_idf.h"
@ -72,10 +73,24 @@ void AsyncWebServer::begin() {
esp_err_t AsyncWebServer::request_post_handler(httpd_req_t *r) {
ESP_LOGVV(TAG, "Enter AsyncWebServer::request_post_handler. uri=%s", r->uri);
auto content_type = request_get_header(r, "Content-Type");
if (content_type.has_value() && *content_type != "application/x-www-form-urlencoded") {
ESP_LOGW(TAG, "Only application/x-www-form-urlencoded supported for POST request");
// fallback to get handler to support backward compatibility
return AsyncWebServer::request_handler(r);
// Check if this is a multipart form data request (for OTA updates)
bool is_multipart = false;
std::string boundary;
if (content_type.has_value()) {
std::string ct = content_type.value();
if (ct.find("multipart/form-data") != std::string::npos) {
is_multipart = true;
// Extract boundary
size_t boundary_pos = ct.find("boundary=");
if (boundary_pos != std::string::npos) {
boundary = ct.substr(boundary_pos + 9);
}
} else if (ct != "application/x-www-form-urlencoded") {
ESP_LOGW(TAG, "Unsupported content type for POST: %s", ct.c_str());
// fallback to get handler to support backward compatibility
return AsyncWebServer::request_handler(r);
}
}
if (!request_has_header(r, "Content-Length")) {
@ -84,6 +99,76 @@ esp_err_t AsyncWebServer::request_post_handler(httpd_req_t *r) {
return ESP_OK;
}
// Handle multipart form data
if (is_multipart && !boundary.empty()) {
// Create request object
AsyncWebServerRequest req(r);
auto *server = static_cast<AsyncWebServer *>(r->user_ctx);
// Find handler that can handle this request
AsyncWebHandler *found_handler = nullptr;
for (auto *handler : server->handlers_) {
if (handler->canHandle(&req)) {
found_handler = handler;
break;
}
}
if (!found_handler) {
httpd_resp_send_err(r, HTTPD_404_NOT_FOUND, nullptr);
return ESP_OK;
}
// Handle multipart upload
MultipartParser parser(boundary);
static constexpr size_t CHUNK_SIZE = 1024;
uint8_t *chunk_buf = new uint8_t[CHUNK_SIZE];
size_t total_len = r->content_len;
size_t remaining = total_len;
bool first_part = true;
while (remaining > 0) {
size_t to_read = std::min(remaining, CHUNK_SIZE);
int recv_len = httpd_req_recv(r, reinterpret_cast<char *>(chunk_buf), to_read);
if (recv_len <= 0) {
delete[] chunk_buf;
if (recv_len == HTTPD_SOCK_ERR_TIMEOUT) {
httpd_resp_send_err(r, HTTPD_408_REQ_TIMEOUT, nullptr);
return ESP_ERR_TIMEOUT;
}
httpd_resp_send_err(r, HTTPD_400_BAD_REQUEST, nullptr);
return ESP_FAIL;
}
// Parse multipart data
if (parser.parse(chunk_buf, recv_len)) {
MultipartParser::Part part;
if (parser.get_current_part(part) && !part.filename.empty()) {
// This is a file upload
found_handler->handleUpload(&req, part.filename, first_part ? 0 : 1, const_cast<uint8_t *>(part.data),
part.length, false);
first_part = false;
parser.consume_part();
}
}
remaining -= recv_len;
}
// Final call to handler
if (!first_part) {
found_handler->handleUpload(&req, "", 2, nullptr, 0, true);
}
delete[] chunk_buf;
// Let handler send response
found_handler->handleRequest(&req);
return ESP_OK;
}
// Handle regular form data
if (r->content_len > HTTPD_MAX_REQ_HDR_LEN) {
ESP_LOGW(TAG, "Request size is to big: %zu", r->content_len);
httpd_resp_send_err(r, HTTPD_400_BAD_REQUEST, nullptr);