diff --git a/esphome/components/web_server/__init__.py b/esphome/components/web_server/__init__.py index d846a3418b..069275a6f3 100644 --- a/esphome/components/web_server/__init__.py +++ b/esphome/components/web_server/__init__.py @@ -71,12 +71,6 @@ def validate_local(config): return config -def validate_ota(config): - if CORE.using_esp_idf and config[CONF_OTA]: - raise cv.Invalid("Enabling 'ota' is not supported for IDF framework yet") - return config - - def validate_sorting_groups(config): if CONF_SORTING_GROUPS in config and config[CONF_VERSION] != 3: raise cv.Invalid( @@ -178,7 +172,7 @@ CONFIG_SCHEMA = cv.All( CONF_OTA, esp8266=True, esp32_arduino=True, - esp32_idf=False, + esp32_idf=True, bk72xx=True, rtl87xx=True, ): cv.boolean, @@ -190,7 +184,6 @@ CONFIG_SCHEMA = cv.All( cv.only_on([PLATFORM_ESP32, PLATFORM_ESP8266, PLATFORM_BK72XX, PLATFORM_RTL87XX]), default_url, validate_local, - validate_ota, validate_sorting_groups, ) diff --git a/esphome/components/web_server_base/web_server_base.cpp b/esphome/components/web_server_base/web_server_base.cpp index 2835585387..6f768d0d21 100644 --- a/esphome/components/web_server_base/web_server_base.cpp +++ b/esphome/components/web_server_base/web_server_base.cpp @@ -14,6 +14,10 @@ #endif #endif +#ifdef USE_ESP_IDF +#include "esphome/components/ota/ota_backend.h" +#endif + namespace esphome { namespace web_server_base { @@ -93,6 +97,67 @@ void OTARequestHandler::handleUpload(AsyncWebServerRequest *request, const Strin } } #endif + +#ifdef USE_ESP_IDF + // ESP-IDF implementation + if (index == 0) { + ESP_LOGI(TAG, "OTA Update Start: %s", filename.c_str()); + this->ota_read_length_ = 0; + this->ota_started_ = false; + + // Create OTA backend + this->ota_backend_ = ota::make_ota_backend(); + + // Begin OTA with unknown size + auto result = this->ota_backend_->begin(0); + if (result != ota::OTA_RESPONSE_OK) { + ESP_LOGE(TAG, "OTA begin failed: %d", result); + this->ota_backend_.reset(); + return; + } + this->ota_started_ = true; + } else if (!this->ota_started_ || !this->ota_backend_) { + // Begin failed or was aborted + return; + } + + // Write data + if (len > 0) { + auto result = this->ota_backend_->write(data, len); + if (result != ota::OTA_RESPONSE_OK) { + ESP_LOGE(TAG, "OTA write failed: %d", result); + this->ota_backend_->abort(); + this->ota_backend_.reset(); + this->ota_started_ = false; + return; + } + + this->ota_read_length_ += len; + + const uint32_t now = millis(); + if (now - this->last_ota_progress_ > 1000) { + if (request->contentLength() != 0) { + float percentage = (this->ota_read_length_ * 100.0f) / request->contentLength(); + ESP_LOGD(TAG, "OTA in progress: %0.1f%%", percentage); + } else { + ESP_LOGD(TAG, "OTA in progress: %u bytes read", this->ota_read_length_); + } + this->last_ota_progress_ = now; + } + } + + if (final) { + auto result = this->ota_backend_->end(); + if (result == ota::OTA_RESPONSE_OK) { + ESP_LOGI(TAG, "OTA update successful!"); + this->parent_->set_timeout(100, []() { App.safe_reboot(); }); + } else { + ESP_LOGE(TAG, "OTA end failed: %d", result); + } + this->ota_backend_.reset(); + this->ota_started_ = false; + } +#endif } void OTARequestHandler::handleRequest(AsyncWebServerRequest *request) { #ifdef USE_ARDUINO @@ -108,10 +173,20 @@ void OTARequestHandler::handleRequest(AsyncWebServerRequest *request) { response->addHeader("Connection", "close"); request->send(response); #endif +#ifdef USE_ESP_IDF + AsyncWebServerResponse *response; + if (this->ota_started_ && this->ota_backend_) { + response = request->beginResponse(200, "text/plain", "Update Successful!"); + } else { + response = request->beginResponse(200, "text/plain", "Update Failed!"); + } + response->addHeader("Connection", "close"); + request->send(response); +#endif } void WebServerBase::add_ota_handler() { -#ifdef USE_ARDUINO +#if defined(USE_ARDUINO) || defined(USE_ESP_IDF) this->add_handler(new OTARequestHandler(this)); // NOLINT #endif } diff --git a/esphome/components/web_server_base/web_server_base.h b/esphome/components/web_server_base/web_server_base.h index 641006cb99..33aba6247a 100644 --- a/esphome/components/web_server_base/web_server_base.h +++ b/esphome/components/web_server_base/web_server_base.h @@ -142,6 +142,10 @@ class OTARequestHandler : public AsyncWebHandler { uint32_t last_ota_progress_{0}; uint32_t ota_read_length_{0}; WebServerBase *parent_; +#ifdef USE_ESP_IDF + std::unique_ptr ota_backend_; + bool ota_started_{false}; +#endif }; } // namespace web_server_base diff --git a/esphome/components/web_server_idf/multipart_parser.cpp b/esphome/components/web_server_idf/multipart_parser.cpp new file mode 100644 index 0000000000..89417733d6 --- /dev/null +++ b/esphome/components/web_server_idf/multipart_parser.cpp @@ -0,0 +1,226 @@ +#ifdef USE_ESP_IDF +#include "multipart_parser.h" +#include "esphome/core/log.h" + +namespace esphome { +namespace web_server_idf { + +static const char *const TAG = "multipart_parser"; + +bool MultipartParser::parse(const uint8_t *data, size_t len) { + // Append new data to buffer + buffer_.insert(buffer_.end(), data, data + len); + + while (state_ != DONE && state_ != ERROR && !buffer_.empty()) { + switch (state_) { + case BOUNDARY_SEARCH: + if (!find_boundary()) { + return false; + } + state_ = HEADERS; + break; + + case HEADERS: + if (!parse_headers()) { + return false; + } + state_ = CONTENT; + content_start_ = 0; // Content starts at current buffer position + break; + + case CONTENT: + if (!extract_content()) { + return false; + } + break; + + default: + break; + } + } + + return part_ready_; +} + +bool MultipartParser::get_current_part(Part &part) const { + if (!part_ready_ || content_length_ == 0) { + return false; + } + + part.name = current_name_; + part.filename = current_filename_; + part.content_type = current_content_type_; + part.data = buffer_.data() + content_start_; + part.length = content_length_; + + return true; +} + +void MultipartParser::consume_part() { + if (!part_ready_) { + return; + } + + // Remove consumed data from buffer + if (content_start_ + content_length_ < buffer_.size()) { + buffer_.erase(buffer_.begin(), buffer_.begin() + content_start_ + content_length_); + } else { + buffer_.clear(); + } + + // Reset for next part + part_ready_ = false; + content_start_ = 0; + content_length_ = 0; + current_name_.clear(); + current_filename_.clear(); + current_content_type_.clear(); + + // Look for next boundary + state_ = BOUNDARY_SEARCH; +} + +void MultipartParser::reset() { + buffer_.clear(); + state_ = BOUNDARY_SEARCH; + part_ready_ = false; + content_start_ = 0; + content_length_ = 0; + current_name_.clear(); + current_filename_.clear(); + current_content_type_.clear(); +} + +bool MultipartParser::find_boundary() { + // Look for boundary in buffer + size_t boundary_pos = find_pattern(reinterpret_cast(boundary_.c_str()), boundary_.length()); + + if (boundary_pos == std::string::npos) { + // Keep some data for next iteration to handle split boundaries + if (buffer_.size() > boundary_.length() + 4) { + buffer_.erase(buffer_.begin(), buffer_.end() - boundary_.length() - 4); + } + return false; + } + + // Remove everything up to and including the boundary + buffer_.erase(buffer_.begin(), buffer_.begin() + boundary_pos + boundary_.length()); + + // Skip CRLF after boundary + if (buffer_.size() >= 2 && buffer_[0] == '\r' && buffer_[1] == '\n') { + buffer_.erase(buffer_.begin(), buffer_.begin() + 2); + } + + // Check if this is the end boundary + if (buffer_.size() >= 2 && buffer_[0] == '-' && buffer_[1] == '-') { + state_ = DONE; + return false; + } + + return true; +} + +bool MultipartParser::parse_headers() { + while (true) { + std::string line = read_line(); + if (line.empty()) { + // Check if we have enough data for a line + auto crlf_pos = find_pattern(reinterpret_cast("\r\n"), 2); + if (crlf_pos == std::string::npos) { + return false; // Need more data + } + // Empty line means headers are done + buffer_.erase(buffer_.begin(), buffer_.begin() + 2); + return true; + } + + // Parse Content-Disposition header + if (line.find("Content-Disposition:") == 0) { + // Extract name + size_t name_pos = line.find("name=\""); + if (name_pos != std::string::npos) { + name_pos += 6; + size_t name_end = line.find("\"", name_pos); + if (name_end != std::string::npos) { + current_name_ = line.substr(name_pos, name_end - name_pos); + } + } + + // Extract filename if present + size_t filename_pos = line.find("filename=\""); + if (filename_pos != std::string::npos) { + filename_pos += 10; + size_t filename_end = line.find("\"", filename_pos); + if (filename_end != std::string::npos) { + current_filename_ = line.substr(filename_pos, filename_end - filename_pos); + } + } + } + // Parse Content-Type header + else if (line.find("Content-Type:") == 0) { + current_content_type_ = line.substr(14); + // Trim whitespace + size_t start = current_content_type_.find_first_not_of(" \t"); + if (start != std::string::npos) { + current_content_type_ = current_content_type_.substr(start); + } + } + } +} + +bool MultipartParser::extract_content() { + // Look for next boundary + std::string search_boundary = "\r\n" + boundary_; + size_t boundary_pos = + find_pattern(reinterpret_cast(search_boundary.c_str()), search_boundary.length()); + + if (boundary_pos != std::string::npos) { + // Found complete part + content_length_ = boundary_pos - content_start_; + part_ready_ = true; + return true; + } + + // No boundary found yet, but we might have partial content + // Keep enough bytes to ensure we don't split a boundary + size_t safe_length = buffer_.size(); + if (safe_length > search_boundary.length() + 4) { + safe_length -= search_boundary.length() + 4; + if (safe_length > content_start_) { + content_length_ = safe_length - content_start_; + // We have partial content but not complete yet + return false; + } + } + + return false; +} + +std::string MultipartParser::read_line() { + auto crlf_pos = find_pattern(reinterpret_cast("\r\n"), 2); + if (crlf_pos == std::string::npos) { + return ""; + } + + std::string line(buffer_.begin(), buffer_.begin() + crlf_pos); + buffer_.erase(buffer_.begin(), buffer_.begin() + crlf_pos + 2); + return line; +} + +size_t MultipartParser::find_pattern(const uint8_t *pattern, size_t pattern_len, size_t start) const { + if (buffer_.size() < pattern_len + start) { + return std::string::npos; + } + + for (size_t i = start; i <= buffer_.size() - pattern_len; ++i) { + if (memcmp(buffer_.data() + i, pattern, pattern_len) == 0) { + return i; + } + } + + return std::string::npos; +} + +} // namespace web_server_idf +} // namespace esphome +#endif \ No newline at end of file diff --git a/esphome/components/web_server_idf/multipart_parser.h b/esphome/components/web_server_idf/multipart_parser.h new file mode 100644 index 0000000000..6d3f3f6575 --- /dev/null +++ b/esphome/components/web_server_idf/multipart_parser.h @@ -0,0 +1,67 @@ +#pragma once +#ifdef USE_ESP_IDF + +#include +#include +#include + +namespace esphome { +namespace web_server_idf { + +// Multipart form data parser for ESP-IDF +class MultipartParser { + public: + enum State { BOUNDARY_SEARCH, HEADERS, CONTENT, DONE, ERROR }; + + struct Part { + std::string name; + std::string filename; + std::string content_type; + const uint8_t *data; + size_t length; + }; + + explicit MultipartParser(const std::string &boundary) : boundary_("--" + boundary), state_(BOUNDARY_SEARCH) {} + + // Process incoming data chunk + // Returns true if a complete part is available + bool parse(const uint8_t *data, size_t len); + + // Get the current part if available + bool get_current_part(Part &part) const; + + // Consume the current part and move to next + void consume_part(); + + State get_state() const { return state_; } + bool is_done() const { return state_ == DONE; } + bool has_error() const { return state_ == ERROR; } + + // Reset parser for reuse + void reset(); + + private: + bool find_boundary(); + bool parse_headers(); + bool extract_content(); + + std::string read_line(); + size_t find_pattern(const uint8_t *pattern, size_t pattern_len, size_t start = 0) const; + + std::string boundary_; + std::string end_boundary_; + State state_; + std::vector buffer_; + + // Current part info + std::string current_name_; + std::string current_filename_; + std::string current_content_type_; + size_t content_start_{0}; + size_t content_length_{0}; + bool part_ready_{false}; +}; + +} // namespace web_server_idf +} // namespace esphome +#endif \ No newline at end of file diff --git a/esphome/components/web_server_idf/web_server_idf.cpp b/esphome/components/web_server_idf/web_server_idf.cpp index 90fdf720cd..2e1cf185db 100644 --- a/esphome/components/web_server_idf/web_server_idf.cpp +++ b/esphome/components/web_server_idf/web_server_idf.cpp @@ -8,6 +8,7 @@ #include "esp_tls_crypto.h" #include "utils.h" +#include "multipart_parser.h" #include "web_server_idf.h" @@ -72,10 +73,24 @@ void AsyncWebServer::begin() { esp_err_t AsyncWebServer::request_post_handler(httpd_req_t *r) { ESP_LOGVV(TAG, "Enter AsyncWebServer::request_post_handler. uri=%s", r->uri); auto content_type = request_get_header(r, "Content-Type"); - if (content_type.has_value() && *content_type != "application/x-www-form-urlencoded") { - ESP_LOGW(TAG, "Only application/x-www-form-urlencoded supported for POST request"); - // fallback to get handler to support backward compatibility - return AsyncWebServer::request_handler(r); + + // Check if this is a multipart form data request (for OTA updates) + bool is_multipart = false; + std::string boundary; + if (content_type.has_value()) { + std::string ct = content_type.value(); + if (ct.find("multipart/form-data") != std::string::npos) { + is_multipart = true; + // Extract boundary + size_t boundary_pos = ct.find("boundary="); + if (boundary_pos != std::string::npos) { + boundary = ct.substr(boundary_pos + 9); + } + } else if (ct != "application/x-www-form-urlencoded") { + ESP_LOGW(TAG, "Unsupported content type for POST: %s", ct.c_str()); + // fallback to get handler to support backward compatibility + return AsyncWebServer::request_handler(r); + } } if (!request_has_header(r, "Content-Length")) { @@ -84,6 +99,76 @@ esp_err_t AsyncWebServer::request_post_handler(httpd_req_t *r) { return ESP_OK; } + // Handle multipart form data + if (is_multipart && !boundary.empty()) { + // Create request object + AsyncWebServerRequest req(r); + auto *server = static_cast(r->user_ctx); + + // Find handler that can handle this request + AsyncWebHandler *found_handler = nullptr; + for (auto *handler : server->handlers_) { + if (handler->canHandle(&req)) { + found_handler = handler; + break; + } + } + + if (!found_handler) { + httpd_resp_send_err(r, HTTPD_404_NOT_FOUND, nullptr); + return ESP_OK; + } + + // Handle multipart upload + MultipartParser parser(boundary); + static constexpr size_t CHUNK_SIZE = 1024; + uint8_t *chunk_buf = new uint8_t[CHUNK_SIZE]; + size_t total_len = r->content_len; + size_t remaining = total_len; + bool first_part = true; + + while (remaining > 0) { + size_t to_read = std::min(remaining, CHUNK_SIZE); + int recv_len = httpd_req_recv(r, reinterpret_cast(chunk_buf), to_read); + + if (recv_len <= 0) { + delete[] chunk_buf; + if (recv_len == HTTPD_SOCK_ERR_TIMEOUT) { + httpd_resp_send_err(r, HTTPD_408_REQ_TIMEOUT, nullptr); + return ESP_ERR_TIMEOUT; + } + httpd_resp_send_err(r, HTTPD_400_BAD_REQUEST, nullptr); + return ESP_FAIL; + } + + // Parse multipart data + if (parser.parse(chunk_buf, recv_len)) { + MultipartParser::Part part; + if (parser.get_current_part(part) && !part.filename.empty()) { + // This is a file upload + found_handler->handleUpload(&req, part.filename, first_part ? 0 : 1, const_cast(part.data), + part.length, false); + first_part = false; + parser.consume_part(); + } + } + + remaining -= recv_len; + } + + // Final call to handler + if (!first_part) { + found_handler->handleUpload(&req, "", 2, nullptr, 0, true); + } + + delete[] chunk_buf; + + // Let handler send response + found_handler->handleRequest(&req); + return ESP_OK; + } + + // Handle regular form data if (r->content_len > HTTPD_MAX_REQ_HDR_LEN) { ESP_LOGW(TAG, "Request size is to big: %zu", r->content_len); httpd_resp_send_err(r, HTTPD_400_BAD_REQUEST, nullptr);