From cdc77506de6ad01ecd14c244927bb1ea5495e581 Mon Sep 17 00:00:00 2001 From: Kevin Ahrendt Date: Wed, 30 Apr 2025 19:22:48 -0500 Subject: [PATCH] [micro_wake_word] add new VPE features (#8655) --- .../components/micro_wake_word/__init__.py | 103 ++- .../components/micro_wake_word/automation.h | 54 ++ .../micro_wake_word/micro_wake_word.cpp | 665 +++++++++--------- .../micro_wake_word/micro_wake_word.h | 145 ++-- .../micro_wake_word/preprocessor_settings.h | 19 + .../micro_wake_word/streaming_model.cpp | 193 ++++- .../micro_wake_word/streaming_model.h | 95 ++- esphome/core/defines.h | 1 + tests/components/micro_wake_word/common.yaml | 16 + 9 files changed, 788 insertions(+), 503 deletions(-) create mode 100644 esphome/components/micro_wake_word/automation.h diff --git a/esphome/components/micro_wake_word/__init__.py b/esphome/components/micro_wake_word/__init__.py index 9d5caca937..0efe2ac288 100644 --- a/esphome/components/micro_wake_word/__init__.py +++ b/esphome/components/micro_wake_word/__init__.py @@ -12,6 +12,7 @@ import esphome.config_validation as cv from esphome.const import ( CONF_FILE, CONF_ID, + CONF_INTERNAL, CONF_MICROPHONE, CONF_MODEL, CONF_PASSWORD, @@ -40,6 +41,7 @@ CONF_ON_WAKE_WORD_DETECTED = "on_wake_word_detected" CONF_PROBABILITY_CUTOFF = "probability_cutoff" CONF_SLIDING_WINDOW_AVERAGE_SIZE = "sliding_window_average_size" CONF_SLIDING_WINDOW_SIZE = "sliding_window_size" +CONF_STOP_AFTER_DETECTION = "stop_after_detection" CONF_TENSOR_ARENA_SIZE = "tensor_arena_size" CONF_VAD = "vad" @@ -49,13 +51,20 @@ micro_wake_word_ns = cg.esphome_ns.namespace("micro_wake_word") MicroWakeWord = micro_wake_word_ns.class_("MicroWakeWord", cg.Component) +DisableModelAction = micro_wake_word_ns.class_("DisableModelAction", automation.Action) +EnableModelAction = micro_wake_word_ns.class_("EnableModelAction", automation.Action) StartAction = micro_wake_word_ns.class_("StartAction", automation.Action) StopAction = micro_wake_word_ns.class_("StopAction", 
automation.Action) +ModelIsEnabledCondition = micro_wake_word_ns.class_( + "ModelIsEnabledCondition", automation.Condition +) IsRunningCondition = micro_wake_word_ns.class_( "IsRunningCondition", automation.Condition ) +WakeWordModel = micro_wake_word_ns.class_("WakeWordModel") + def _validate_json_filename(value): value = cv.string(value) @@ -169,9 +178,10 @@ def _convert_manifest_v1_to_v2(v1_manifest): # Original Inception-based V1 manifest models require a minimum of 45672 bytes v2_manifest[KEY_MICRO][CONF_TENSOR_ARENA_SIZE] = 45672 - # Original Inception-based V1 manifest models use a 20 ms feature step size v2_manifest[KEY_MICRO][CONF_FEATURE_STEP_SIZE] = 20 + # Original Inception-based V1 manifest models were trained only on TTS English samples + v2_manifest[KEY_TRAINED_LANGUAGES] = ["en"] return v2_manifest @@ -296,14 +306,16 @@ MODEL_SOURCE_SCHEMA = cv.Any( MODEL_SCHEMA = cv.Schema( { + cv.GenerateID(CONF_ID): cv.declare_id(WakeWordModel), cv.Optional(CONF_MODEL): MODEL_SOURCE_SCHEMA, cv.Optional(CONF_PROBABILITY_CUTOFF): cv.percentage, cv.Optional(CONF_SLIDING_WINDOW_SIZE): cv.positive_int, + cv.Optional(CONF_INTERNAL, default=False): cv.boolean, cv.GenerateID(CONF_RAW_DATA_ID): cv.declare_id(cg.uint8), } ) -# Provide a default VAD model that could be overridden +# Provides a default VAD model that could be overridden VAD_MODEL_SCHEMA = MODEL_SCHEMA.extend( cv.Schema( { @@ -343,6 +355,7 @@ CONFIG_SCHEMA = cv.All( single=True ), cv.Optional(CONF_VAD): _maybe_empty_vad_schema, + cv.Optional(CONF_STOP_AFTER_DETECTION, default=True): cv.boolean, cv.Optional(CONF_MODEL): cv.invalid( f"The {CONF_MODEL} parameter has moved to be a list element under the {CONF_MODELS} parameter." 
), @@ -433,29 +446,20 @@ async def to_code(config): mic_source = await microphone.microphone_source_to_code(config[CONF_MICROPHONE]) cg.add(var.set_microphone_source(mic_source)) + cg.add_define("USE_MICRO_WAKE_WORD") + cg.add_define("USE_OTA_STATE_CALLBACK") + esp32.add_idf_component( name="esp-tflite-micro", repo="https://github.com/espressif/esp-tflite-micro", - ref="v1.3.1", - ) - # add esp-nn dependency for tflite-micro to work around https://github.com/espressif/esp-nn/issues/17 - # ...remove after switching to IDF 5.1.4+ - esp32.add_idf_component( - name="esp-nn", - repo="https://github.com/espressif/esp-nn", - ref="v1.1.0", + ref="v1.3.3.1", ) cg.add_build_flag("-DTF_LITE_STATIC_MEMORY") cg.add_build_flag("-DTF_LITE_DISABLE_X86_NEON") cg.add_build_flag("-DESP_NN") - if on_wake_word_detection_config := config.get(CONF_ON_WAKE_WORD_DETECTED): - await automation.build_automation( - var.get_wake_word_detected_trigger(), - [(cg.std_string, "wake_word")], - on_wake_word_detection_config, - ) + cg.add_library("kahrendt/ESPMicroSpeechFeatures", "1.1.0") if vad_model := config.get(CONF_VAD): cg.add_define("USE_MICRO_WAKE_WORD_VAD") @@ -463,7 +467,7 @@ async def to_code(config): # Use the general model loading code for the VAD codegen config[CONF_MODELS].append(vad_model) - for model_parameters in config[CONF_MODELS]: + for i, model_parameters in enumerate(config[CONF_MODELS]): model_config = model_parameters.get(CONF_MODEL) data = [] manifest, data = _model_config_to_manifest_data(model_config) @@ -474,6 +478,8 @@ async def to_code(config): probability_cutoff = model_parameters.get( CONF_PROBABILITY_CUTOFF, manifest[KEY_MICRO][CONF_PROBABILITY_CUTOFF] ) + quantized_probability_cutoff = int(probability_cutoff * 255) + sliding_window_size = model_parameters.get( CONF_SLIDING_WINDOW_SIZE, manifest[KEY_MICRO][CONF_SLIDING_WINDOW_SIZE], @@ -483,24 +489,40 @@ async def to_code(config): cg.add( var.add_vad_model( prog_arr, - probability_cutoff, + 
quantized_probability_cutoff, sliding_window_size, manifest[KEY_MICRO][CONF_TENSOR_ARENA_SIZE], ) ) else: - cg.add( - var.add_wake_word_model( - prog_arr, - probability_cutoff, - sliding_window_size, - manifest[KEY_WAKE_WORD], - manifest[KEY_MICRO][CONF_TENSOR_ARENA_SIZE], - ) + # Only enable the first wake word by default. After first boot, the enable state is saved/loaded to the flash + default_enabled = i == 0 + wake_word_model = cg.new_Pvariable( + model_parameters[CONF_ID], + str(model_parameters[CONF_ID]), + prog_arr, + quantized_probability_cutoff, + sliding_window_size, + manifest[KEY_WAKE_WORD], + manifest[KEY_MICRO][CONF_TENSOR_ARENA_SIZE], + default_enabled, + model_parameters[CONF_INTERNAL], ) + for lang in manifest[KEY_TRAINED_LANGUAGES]: + cg.add(wake_word_model.add_trained_language(lang)) + + cg.add(var.add_wake_word_model(wake_word_model)) + cg.add(var.set_features_step_size(manifest[KEY_MICRO][CONF_FEATURE_STEP_SIZE])) - cg.add_library("kahrendt/ESPMicroSpeechFeatures", "1.1.0") + cg.add(var.set_stop_after_detection(config[CONF_STOP_AFTER_DETECTION])) + + if on_wake_word_detection_config := config.get(CONF_ON_WAKE_WORD_DETECTED): + await automation.build_automation( + var.get_wake_word_detected_trigger(), + [(cg.std_string, "wake_word")], + on_wake_word_detection_config, + ) MICRO_WAKE_WORD_ACTION_SCHEMA = cv.Schema({cv.GenerateID(): cv.use_id(MicroWakeWord)}) @@ -515,3 +537,30 @@ async def micro_wake_word_action_to_code(config, action_id, template_arg, args): var = cg.new_Pvariable(action_id, template_arg) await cg.register_parented(var, config[CONF_ID]) return var + + +MICRO_WAKE_WORLD_MODEL_ACTION_SCHEMA = automation.maybe_simple_id( + { + cv.Required(CONF_ID): cv.use_id(WakeWordModel), + } +) + + +@register_action( + "micro_wake_word.enable_model", + EnableModelAction, + MICRO_WAKE_WORLD_MODEL_ACTION_SCHEMA, +) +@register_action( + "micro_wake_word.disable_model", + DisableModelAction, + MICRO_WAKE_WORLD_MODEL_ACTION_SCHEMA, +) 
+@register_condition( + "micro_wake_word.model_is_enabled", + ModelIsEnabledCondition, + MICRO_WAKE_WORLD_MODEL_ACTION_SCHEMA, +) +async def model_action(config, action_id, template_arg, args): + parent = await cg.get_variable(config[CONF_ID]) + return cg.new_Pvariable(action_id, template_arg, parent) diff --git a/esphome/components/micro_wake_word/automation.h b/esphome/components/micro_wake_word/automation.h new file mode 100644 index 0000000000..f10a4ed347 --- /dev/null +++ b/esphome/components/micro_wake_word/automation.h @@ -0,0 +1,54 @@ +#pragma once + +#include "micro_wake_word.h" +#include "streaming_model.h" + +#ifdef USE_ESP_IDF +namespace esphome { +namespace micro_wake_word { + +template class StartAction : public Action, public Parented { + public: + void play(Ts... x) override { this->parent_->start(); } +}; + +template class StopAction : public Action, public Parented { + public: + void play(Ts... x) override { this->parent_->stop(); } +}; + +template class IsRunningCondition : public Condition, public Parented { + public: + bool check(Ts... x) override { return this->parent_->is_running(); } +}; + +template class EnableModelAction : public Action { + public: + explicit EnableModelAction(WakeWordModel *wake_word_model) : wake_word_model_(wake_word_model) {} + void play(Ts... x) override { this->wake_word_model_->enable(); } + + protected: + WakeWordModel *wake_word_model_; +}; + +template class DisableModelAction : public Action { + public: + explicit DisableModelAction(WakeWordModel *wake_word_model) : wake_word_model_(wake_word_model) {} + void play(Ts... x) override { this->wake_word_model_->disable(); } + + protected: + WakeWordModel *wake_word_model_; +}; + +template class ModelIsEnabledCondition : public Condition { + public: + explicit ModelIsEnabledCondition(WakeWordModel *wake_word_model) : wake_word_model_(wake_word_model) {} + bool check(Ts... 
x) override { return this->wake_word_model_->is_enabled(); } + + protected: + WakeWordModel *wake_word_model_; +}; + +} // namespace micro_wake_word +} // namespace esphome +#endif diff --git a/esphome/components/micro_wake_word/micro_wake_word.cpp b/esphome/components/micro_wake_word/micro_wake_word.cpp index dd1a8be378..f768b661c0 100644 --- a/esphome/components/micro_wake_word/micro_wake_word.cpp +++ b/esphome/components/micro_wake_word/micro_wake_word.cpp @@ -1,5 +1,4 @@ #include "micro_wake_word.h" -#include "streaming_model.h" #ifdef USE_ESP_IDF @@ -7,41 +6,57 @@ #include "esphome/core/helpers.h" #include "esphome/core/log.h" -#include -#include +#include "esphome/components/audio/audio_transfer_buffer.h" -#include -#include -#include - -#include +#ifdef USE_OTA +#include "esphome/components/ota/ota_backend.h" +#endif namespace esphome { namespace micro_wake_word { static const char *const TAG = "micro_wake_word"; -static const size_t SAMPLE_RATE_HZ = 16000; // 16 kHz -static const size_t BUFFER_LENGTH = 64; // 0.064 seconds -static const size_t BUFFER_SIZE = SAMPLE_RATE_HZ / 1000 * BUFFER_LENGTH; -static const size_t INPUT_BUFFER_SIZE = 16 * SAMPLE_RATE_HZ / 1000; // 16ms * 16kHz / 1000ms +static const ssize_t DETECTION_QUEUE_LENGTH = 5; + +static const size_t DATA_TIMEOUT_MS = 50; + +static const uint32_t RING_BUFFER_DURATION_MS = 120; +static const uint32_t RING_BUFFER_SAMPLES = RING_BUFFER_DURATION_MS * (AUDIO_SAMPLE_FREQUENCY / 1000); +static const size_t RING_BUFFER_SIZE = RING_BUFFER_SAMPLES * sizeof(int16_t); + +static const uint32_t INFERENCE_TASK_STACK_SIZE = 3072; +static const UBaseType_t INFERENCE_TASK_PRIORITY = 3; + +enum EventGroupBits : uint32_t { + COMMAND_STOP = (1 << 0), // Signals the inference task should stop + + TASK_STARTING = (1 << 3), + TASK_RUNNING = (1 << 4), + TASK_STOPPING = (1 << 5), + TASK_STOPPED = (1 << 6), + + ERROR_MEMORY = (1 << 9), + ERROR_INFERENCE = (1 << 10), + + WARNING_FULL_RING_BUFFER = (1 << 13), + + ERROR_BITS = 
ERROR_MEMORY | ERROR_INFERENCE, + ALL_BITS = 0xfffff, // 24 total bits available in an event group +}; float MicroWakeWord::get_setup_priority() const { return setup_priority::AFTER_CONNECTION; } static const LogString *micro_wake_word_state_to_string(State state) { switch (state) { - case State::IDLE: - return LOG_STR("IDLE"); - case State::START_MICROPHONE: - return LOG_STR("START_MICROPHONE"); - case State::STARTING_MICROPHONE: - return LOG_STR("STARTING_MICROPHONE"); + case State::STARTING: + return LOG_STR("STARTING"); case State::DETECTING_WAKE_WORD: return LOG_STR("DETECTING_WAKE_WORD"); - case State::STOP_MICROPHONE: - return LOG_STR("STOP_MICROPHONE"); - case State::STOPPING_MICROPHONE: - return LOG_STR("STOPPING_MICROPHONE"); + case State::STOPPING: + return LOG_STR("STOPPING"); + case State::STOPPED: + return LOG_STR("STOPPED"); default: return LOG_STR("UNKNOWN"); } @@ -51,7 +66,7 @@ void MicroWakeWord::dump_config() { ESP_LOGCONFIG(TAG, "microWakeWord:"); ESP_LOGCONFIG(TAG, " models:"); for (auto &model : this->wake_word_models_) { - model.log_model_config(); + model->log_model_config(); } #ifdef USE_MICRO_WAKE_WORD_VAD this->vad_model_->log_model_config(); @@ -61,108 +76,266 @@ void MicroWakeWord::dump_config() { void MicroWakeWord::setup() { ESP_LOGCONFIG(TAG, "Setting up microWakeWord..."); + this->frontend_config_.window.size_ms = FEATURE_DURATION_MS; + this->frontend_config_.window.step_size_ms = this->features_step_size_; + this->frontend_config_.filterbank.num_channels = PREPROCESSOR_FEATURE_SIZE; + this->frontend_config_.filterbank.lower_band_limit = FILTERBANK_LOWER_BAND_LIMIT; + this->frontend_config_.filterbank.upper_band_limit = FILTERBANK_UPPER_BAND_LIMIT; + this->frontend_config_.noise_reduction.smoothing_bits = NOISE_REDUCTION_SMOOTHING_BITS; + this->frontend_config_.noise_reduction.even_smoothing = NOISE_REDUCTION_EVEN_SMOOTHING; + this->frontend_config_.noise_reduction.odd_smoothing = NOISE_REDUCTION_ODD_SMOOTHING; + 
this->frontend_config_.noise_reduction.min_signal_remaining = NOISE_REDUCTION_MIN_SIGNAL_REMAINING; + this->frontend_config_.pcan_gain_control.enable_pcan = PCAN_GAIN_CONTROL_ENABLE_PCAN; + this->frontend_config_.pcan_gain_control.strength = PCAN_GAIN_CONTROL_STRENGTH; + this->frontend_config_.pcan_gain_control.offset = PCAN_GAIN_CONTROL_OFFSET; + this->frontend_config_.pcan_gain_control.gain_bits = PCAN_GAIN_CONTROL_GAIN_BITS; + this->frontend_config_.log_scale.enable_log = LOG_SCALE_ENABLE_LOG; + this->frontend_config_.log_scale.scale_shift = LOG_SCALE_SCALE_SHIFT; + + this->event_group_ = xEventGroupCreate(); + if (this->event_group_ == nullptr) { + ESP_LOGE(TAG, "Failed to create event group"); + this->mark_failed(); + return; + } + + this->detection_queue_ = xQueueCreate(DETECTION_QUEUE_LENGTH, sizeof(DetectionEvent)); + if (this->detection_queue_ == nullptr) { + ESP_LOGE(TAG, "Failed to create detection event queue"); + this->mark_failed(); + return; + } + this->microphone_source_->add_data_callback([this](const std::vector &data) { - if (this->state_ != State::DETECTING_WAKE_WORD) { + if (this->state_ == State::STOPPED) { return; } - std::shared_ptr temp_ring_buffer = this->ring_buffer_; - if (this->ring_buffer_.use_count() == 2) { - // mWW still owns the ring buffer and temp_ring_buffer does as well, proceed to copy audio into ring buffer - + std::shared_ptr temp_ring_buffer = this->ring_buffer_.lock(); + if (this->ring_buffer_.use_count() > 1) { size_t bytes_free = temp_ring_buffer->free(); if (bytes_free < data.size()) { - ESP_LOGW( - TAG, - "Not enough free bytes in ring buffer to store incoming audio data (free bytes=%d, incoming bytes=%d). " - "Resetting the ring buffer. 
Wake word detection accuracy will be reduced.", - bytes_free, data.size()); - + xEventGroupSetBits(this->event_group_, EventGroupBits::WARNING_FULL_RING_BUFFER); temp_ring_buffer->reset(); } temp_ring_buffer->write((void *) data.data(), data.size()); } }); - if (!this->register_streaming_ops_(this->streaming_op_resolver_)) { - this->mark_failed(); - return; +#ifdef USE_OTA + ota::get_global_ota_callback()->add_on_state_callback( + [this](ota::OTAState state, float progress, uint8_t error, ota::OTAComponent *comp) { + if (state == ota::OTA_STARTED) { + this->suspend_task_(); + } else if (state == ota::OTA_ERROR) { + this->resume_task_(); + } + }); +#endif + ESP_LOGCONFIG(TAG, "Micro Wake Word initialized"); +} + +void MicroWakeWord::inference_task(void *params) { + MicroWakeWord *this_mww = (MicroWakeWord *) params; + + xEventGroupSetBits(this_mww->event_group_, EventGroupBits::TASK_STARTING); + + { // Ensures any C++ objects fall out of scope to deallocate before deleting the task + const size_t new_samples_to_read = this_mww->features_step_size_ * (AUDIO_SAMPLE_FREQUENCY / 1000); + std::unique_ptr audio_buffer; + int8_t features_buffer[PREPROCESSOR_FEATURE_SIZE]; + + if (!(xEventGroupGetBits(this_mww->event_group_) & ERROR_BITS)) { + // Allocate audio transfer buffer + audio_buffer = audio::AudioSourceTransferBuffer::create(new_samples_to_read * sizeof(int16_t)); + + if (audio_buffer == nullptr) { + xEventGroupSetBits(this_mww->event_group_, EventGroupBits::ERROR_MEMORY); + } + } + + if (!(xEventGroupGetBits(this_mww->event_group_) & ERROR_BITS)) { + // Allocate ring buffer + std::shared_ptr temp_ring_buffer = RingBuffer::create(RING_BUFFER_SIZE); + if (temp_ring_buffer.use_count() == 0) { + xEventGroupSetBits(this_mww->event_group_, EventGroupBits::ERROR_MEMORY); + } + audio_buffer->set_source(temp_ring_buffer); + this_mww->ring_buffer_ = temp_ring_buffer; + } + + if (!(xEventGroupGetBits(this_mww->event_group_) & ERROR_BITS)) { + 
this_mww->microphone_source_->start(); + xEventGroupSetBits(this_mww->event_group_, EventGroupBits::TASK_RUNNING); + + while (!(xEventGroupGetBits(this_mww->event_group_) & COMMAND_STOP)) { + audio_buffer->transfer_data_from_source(pdMS_TO_TICKS(DATA_TIMEOUT_MS)); + + if (audio_buffer->available() < new_samples_to_read * sizeof(int16_t)) { + // Insufficient data to generate new spectrogram features, read more next iteration + continue; + } + + // Generate new spectrogram features + size_t processed_samples = this_mww->generate_features_( + (int16_t *) audio_buffer->get_buffer_start(), audio_buffer->available() / sizeof(int16_t), features_buffer); + audio_buffer->decrease_buffer_length(processed_samples * sizeof(int16_t)); + + // Run inference using the new spectorgram features + if (!this_mww->update_model_probabilities_(features_buffer)) { + xEventGroupSetBits(this_mww->event_group_, EventGroupBits::ERROR_INFERENCE); + break; + } + + // Process each model's probabilities and possibly send a Detection Event to the queue + this_mww->process_probabilities_(); + } + } } - ESP_LOGCONFIG(TAG, "Micro Wake Word initialized"); + xEventGroupSetBits(this_mww->event_group_, EventGroupBits::TASK_STOPPING); - this->frontend_config_.window.size_ms = FEATURE_DURATION_MS; - this->frontend_config_.window.step_size_ms = this->features_step_size_; - this->frontend_config_.filterbank.num_channels = PREPROCESSOR_FEATURE_SIZE; - this->frontend_config_.filterbank.lower_band_limit = 125.0; - this->frontend_config_.filterbank.upper_band_limit = 7500.0; - this->frontend_config_.noise_reduction.smoothing_bits = 10; - this->frontend_config_.noise_reduction.even_smoothing = 0.025; - this->frontend_config_.noise_reduction.odd_smoothing = 0.06; - this->frontend_config_.noise_reduction.min_signal_remaining = 0.05; - this->frontend_config_.pcan_gain_control.enable_pcan = 1; - this->frontend_config_.pcan_gain_control.strength = 0.95; - this->frontend_config_.pcan_gain_control.offset = 80.0; - 
this->frontend_config_.pcan_gain_control.gain_bits = 21; - this->frontend_config_.log_scale.enable_log = 1; - this->frontend_config_.log_scale.scale_shift = 6; + this_mww->unload_models_(); + this_mww->microphone_source_->stop(); + FrontendFreeStateContents(&this_mww->frontend_state_); + + xEventGroupSetBits(this_mww->event_group_, EventGroupBits::TASK_STOPPED); + while (true) { + // Continuously delay until the main loop deletes the task + delay(10); + } } -void MicroWakeWord::add_wake_word_model(const uint8_t *model_start, float probability_cutoff, - size_t sliding_window_average_size, const std::string &wake_word, - size_t tensor_arena_size) { - this->wake_word_models_.emplace_back(model_start, probability_cutoff, sliding_window_average_size, wake_word, - tensor_arena_size); +std::vector MicroWakeWord::get_wake_words() { + std::vector external_wake_word_models; + for (auto *model : this->wake_word_models_) { + if (!model->get_internal_only()) { + external_wake_word_models.push_back(model); + } + } + return external_wake_word_models; } +void MicroWakeWord::add_wake_word_model(WakeWordModel *model) { this->wake_word_models_.push_back(model); } + #ifdef USE_MICRO_WAKE_WORD_VAD -void MicroWakeWord::add_vad_model(const uint8_t *model_start, float probability_cutoff, size_t sliding_window_size, +void MicroWakeWord::add_vad_model(const uint8_t *model_start, uint8_t probability_cutoff, size_t sliding_window_size, size_t tensor_arena_size) { this->vad_model_ = make_unique(model_start, probability_cutoff, sliding_window_size, tensor_arena_size); } #endif +void MicroWakeWord::suspend_task_() { + if (this->inference_task_handle_ != nullptr) { + vTaskSuspend(this->inference_task_handle_); + } +} + +void MicroWakeWord::resume_task_() { + if (this->inference_task_handle_ != nullptr) { + vTaskResume(this->inference_task_handle_); + } +} + void MicroWakeWord::loop() { + uint32_t event_group_bits = xEventGroupGetBits(this->event_group_); + + if (event_group_bits & 
EventGroupBits::ERROR_MEMORY) { + xEventGroupClearBits(this->event_group_, EventGroupBits::ERROR_MEMORY); + ESP_LOGE(TAG, "Encountered an error allocating buffers"); + } + + if (event_group_bits & EventGroupBits::ERROR_INFERENCE) { + xEventGroupClearBits(this->event_group_, EventGroupBits::ERROR_INFERENCE); + ESP_LOGE(TAG, "Encountered an error while performing an inference"); + } + + if (event_group_bits & EventGroupBits::WARNING_FULL_RING_BUFFER) { + xEventGroupClearBits(this->event_group_, EventGroupBits::WARNING_FULL_RING_BUFFER); + ESP_LOGW(TAG, "Not enough free bytes in ring buffer to store incoming audio data. Resetting the ring buffer. Wake " + "word detection accuracy will temporarily be reduced."); + } + + if (event_group_bits & EventGroupBits::TASK_STARTING) { + ESP_LOGD(TAG, "Inference task has started, attempting to allocate memory for buffers"); + xEventGroupClearBits(this->event_group_, EventGroupBits::TASK_STARTING); + } + + if (event_group_bits & EventGroupBits::TASK_RUNNING) { + ESP_LOGD(TAG, "Inference task is running"); + + xEventGroupClearBits(this->event_group_, EventGroupBits::TASK_RUNNING); + this->set_state_(State::DETECTING_WAKE_WORD); + } + + if (event_group_bits & EventGroupBits::TASK_STOPPING) { + ESP_LOGD(TAG, "Inference task is stopping, deallocating buffers"); + xEventGroupClearBits(this->event_group_, EventGroupBits::TASK_STOPPING); + } + + if ((event_group_bits & EventGroupBits::TASK_STOPPED)) { + ESP_LOGD(TAG, "Inference task is finished, freeing task resources"); + vTaskDelete(this->inference_task_handle_); + this->inference_task_handle_ = nullptr; + xEventGroupClearBits(this->event_group_, ALL_BITS); + xQueueReset(this->detection_queue_); + this->set_state_(State::STOPPED); + } + + if ((this->pending_start_) && (this->state_ == State::STOPPED)) { + this->set_state_(State::STARTING); + this->pending_start_ = false; + } + + if ((this->pending_stop_) && (this->state_ == State::DETECTING_WAKE_WORD)) { + 
this->set_state_(State::STOPPING); + this->pending_stop_ = false; + } + switch (this->state_) { - case State::IDLE: - break; - case State::START_MICROPHONE: - ESP_LOGD(TAG, "Starting Microphone"); - this->microphone_source_->start(); - this->set_state_(State::STARTING_MICROPHONE); - break; - case State::STARTING_MICROPHONE: - if (this->microphone_source_->is_running()) { - this->set_state_(State::DETECTING_WAKE_WORD); - } - break; - case State::DETECTING_WAKE_WORD: - while (this->has_enough_samples_()) { - this->update_model_probabilities_(); - if (this->detect_wake_words_()) { - ESP_LOGD(TAG, "Wake Word '%s' Detected", (this->detected_wake_word_).c_str()); - this->detected_ = true; - this->set_state_(State::STOP_MICROPHONE); + case State::STARTING: + if ((this->inference_task_handle_ == nullptr) && !this->status_has_error()) { + // Setup preprocesor feature generator. If done in the task, it would lock the task to its initial core, as it + // uses floating point operations. + if (!FrontendPopulateState(&this->frontend_config_, &this->frontend_state_, AUDIO_SAMPLE_FREQUENCY)) { + this->status_momentary_error( + "Failed to allocate buffers for spectrogram feature processor, attempting again in 1 second", 1000); + return; + } + + xTaskCreate(MicroWakeWord::inference_task, "mww", INFERENCE_TASK_STACK_SIZE, (void *) this, + INFERENCE_TASK_PRIORITY, &this->inference_task_handle_); + + if (this->inference_task_handle_ == nullptr) { + FrontendFreeStateContents(&this->frontend_state_); // Deallocate frontend state + this->status_momentary_error("Task failed to start, attempting again in 1 second", 1000); } } break; - case State::STOP_MICROPHONE: - ESP_LOGD(TAG, "Stopping Microphone"); - this->microphone_source_->stop(); - this->set_state_(State::STOPPING_MICROPHONE); - this->unload_models_(); - this->deallocate_buffers_(); - break; - case State::STOPPING_MICROPHONE: - if (this->microphone_source_->is_stopped()) { - this->set_state_(State::IDLE); - if (this->detected_) { - 
this->wake_word_detected_trigger_->trigger(this->detected_wake_word_); - this->detected_ = false; - this->detected_wake_word_ = ""; + case State::DETECTING_WAKE_WORD: { + DetectionEvent detection_event; + while (xQueueReceive(this->detection_queue_, &detection_event, 0)) { + if (detection_event.blocked_by_vad) { + ESP_LOGD(TAG, "Wake word model predicts '%s', but VAD model doesn't.", detection_event.wake_word->c_str()); + } else { + constexpr float uint8_to_float_divisor = + 255.0f; // Converting a quantized uint8 probability to floating point + ESP_LOGD(TAG, "Detected '%s' with sliding average probability is %.2f and max probability is %.2f", + detection_event.wake_word->c_str(), (detection_event.average_probability / uint8_to_float_divisor), + (detection_event.max_probability / uint8_to_float_divisor)); + this->wake_word_detected_trigger_->trigger(*detection_event.wake_word); + if (this->stop_after_detection_) { + this->stop(); + } } } break; + } + case State::STOPPING: + xEventGroupSetBits(this->event_group_, EventGroupBits::COMMAND_STOP); + break; + case State::STOPPED: + break; } } @@ -177,199 +350,40 @@ void MicroWakeWord::start() { return; } - if (this->state_ != State::IDLE) { - ESP_LOGW(TAG, "Wake word is already running"); + if (this->is_running()) { + ESP_LOGW(TAG, "Wake word detection is already running"); return; } - if (!this->load_models_() || !this->allocate_buffers_()) { - ESP_LOGE(TAG, "Failed to load the wake word model(s) or allocate buffers"); - this->status_set_error(); - } else { - this->status_clear_error(); - } + ESP_LOGD(TAG, "Starting wake word detection"); - if (this->status_has_error()) { - ESP_LOGW(TAG, "Wake word component has an error. 
Please check logs"); - return; - } - - this->reset_states_(); - this->set_state_(State::START_MICROPHONE); + this->pending_start_ = true; + this->pending_stop_ = false; } void MicroWakeWord::stop() { - if (this->state_ == State::IDLE) { - ESP_LOGW(TAG, "Wake word is already stopped"); + if (this->state_ == STOPPED) return; - } - if (this->state_ == State::STOPPING_MICROPHONE) { - ESP_LOGW(TAG, "Wake word is already stopping"); - return; - } - this->set_state_(State::STOP_MICROPHONE); + + ESP_LOGD(TAG, "Stopping wake word detection"); + + this->pending_start_ = false; + this->pending_stop_ = true; } void MicroWakeWord::set_state_(State state) { - ESP_LOGD(TAG, "State changed from %s to %s", LOG_STR_ARG(micro_wake_word_state_to_string(this->state_)), - LOG_STR_ARG(micro_wake_word_state_to_string(state))); - this->state_ = state; + if (this->state_ != state) { + ESP_LOGD(TAG, "State changed from %s to %s", LOG_STR_ARG(micro_wake_word_state_to_string(this->state_)), + LOG_STR_ARG(micro_wake_word_state_to_string(state))); + this->state_ = state; + } } -bool MicroWakeWord::allocate_buffers_() { - ExternalRAMAllocator audio_samples_allocator(ExternalRAMAllocator::ALLOW_FAILURE); - - if (this->input_buffer_ == nullptr) { - this->input_buffer_ = audio_samples_allocator.allocate(INPUT_BUFFER_SIZE * sizeof(int16_t)); - if (this->input_buffer_ == nullptr) { - ESP_LOGE(TAG, "Could not allocate input buffer"); - return false; - } - } - - if (this->preprocessor_audio_buffer_ == nullptr) { - this->preprocessor_audio_buffer_ = audio_samples_allocator.allocate(this->new_samples_to_get_()); - if (this->preprocessor_audio_buffer_ == nullptr) { - ESP_LOGE(TAG, "Could not allocate the audio preprocessor's buffer."); - return false; - } - } - - if (this->ring_buffer_.use_count() == 0) { - this->ring_buffer_ = RingBuffer::create(BUFFER_SIZE * sizeof(int16_t)); - if (this->ring_buffer_.use_count() == 0) { - ESP_LOGE(TAG, "Could not allocate ring buffer"); - return false; - } - } - - return 
true; -} - -void MicroWakeWord::deallocate_buffers_() { - ExternalRAMAllocator audio_samples_allocator(ExternalRAMAllocator::ALLOW_FAILURE); - if (this->input_buffer_ != nullptr) { - audio_samples_allocator.deallocate(this->input_buffer_, INPUT_BUFFER_SIZE * sizeof(int16_t)); - this->input_buffer_ = nullptr; - } - - if (this->preprocessor_audio_buffer_ != nullptr) { - audio_samples_allocator.deallocate(this->preprocessor_audio_buffer_, this->new_samples_to_get_()); - this->preprocessor_audio_buffer_ = nullptr; - } - - this->ring_buffer_.reset(); -} - -bool MicroWakeWord::load_models_() { - // Setup preprocesor feature generator - if (!FrontendPopulateState(&this->frontend_config_, &this->frontend_state_, AUDIO_SAMPLE_FREQUENCY)) { - ESP_LOGD(TAG, "Failed to populate frontend state"); - FrontendFreeStateContents(&this->frontend_state_); - return false; - } - - // Setup streaming models - for (auto &model : this->wake_word_models_) { - if (!model.load_model(this->streaming_op_resolver_)) { - ESP_LOGE(TAG, "Failed to initialize a wake word model."); - return false; - } - } -#ifdef USE_MICRO_WAKE_WORD_VAD - if (!this->vad_model_->load_model(this->streaming_op_resolver_)) { - ESP_LOGE(TAG, "Failed to initialize VAD model."); - return false; - } -#endif - - return true; -} - -void MicroWakeWord::unload_models_() { - FrontendFreeStateContents(&this->frontend_state_); - - for (auto &model : this->wake_word_models_) { - model.unload_model(); - } -#ifdef USE_MICRO_WAKE_WORD_VAD - this->vad_model_->unload_model(); -#endif -} - -void MicroWakeWord::update_model_probabilities_() { - int8_t audio_features[PREPROCESSOR_FEATURE_SIZE]; - - if (!this->generate_features_for_window_(audio_features)) { - return; - } - - // Increase the counter since the last positive detection - this->ignore_windows_ = std::min(this->ignore_windows_ + 1, 0); - - for (auto &model : this->wake_word_models_) { - // Perform inference - model.perform_streaming_inference(audio_features); - } -#ifdef 
USE_MICRO_WAKE_WORD_VAD - this->vad_model_->perform_streaming_inference(audio_features); -#endif -} - -bool MicroWakeWord::detect_wake_words_() { - // Verify we have processed samples since the last positive detection - if (this->ignore_windows_ < 0) { - return false; - } - -#ifdef USE_MICRO_WAKE_WORD_VAD - bool vad_state = this->vad_model_->determine_detected(); -#endif - - for (auto &model : this->wake_word_models_) { - if (model.determine_detected()) { -#ifdef USE_MICRO_WAKE_WORD_VAD - if (vad_state) { -#endif - this->detected_wake_word_ = model.get_wake_word(); - return true; -#ifdef USE_MICRO_WAKE_WORD_VAD - } else { - ESP_LOGD(TAG, "Wake word model predicts %s, but VAD model doesn't.", model.get_wake_word().c_str()); - } -#endif - } - } - - return false; -} - -bool MicroWakeWord::has_enough_samples_() { - return this->ring_buffer_->available() >= - (this->features_step_size_ * (AUDIO_SAMPLE_FREQUENCY / 1000)) * sizeof(int16_t); -} - -bool MicroWakeWord::generate_features_for_window_(int8_t features[PREPROCESSOR_FEATURE_SIZE]) { - // Ensure we have enough new audio samples in the ring buffer for a full window - if (!this->has_enough_samples_()) { - return false; - } - - size_t bytes_read = this->ring_buffer_->read((void *) (this->preprocessor_audio_buffer_), - this->new_samples_to_get_() * sizeof(int16_t), pdMS_TO_TICKS(200)); - - if (bytes_read == 0) { - ESP_LOGE(TAG, "Could not read data from Ring Buffer"); - } else if (bytes_read < this->new_samples_to_get_() * sizeof(int16_t)) { - ESP_LOGD(TAG, "Partial Read of Data by Model"); - ESP_LOGD(TAG, "Could only read %d bytes when required %d bytes ", bytes_read, - (int) (this->new_samples_to_get_() * sizeof(int16_t))); - return false; - } - - size_t num_samples_read; - struct FrontendOutput frontend_output = FrontendProcessSamples( - &this->frontend_state_, this->preprocessor_audio_buffer_, this->new_samples_to_get_(), &num_samples_read); +size_t MicroWakeWord::generate_features_(int16_t *audio_buffer, size_t 
samples_available, + int8_t features_buffer[PREPROCESSOR_FEATURE_SIZE]) { + size_t processed_samples = 0; + struct FrontendOutput frontend_output = + FrontendProcessSamples(&this->frontend_state_, audio_buffer, samples_available, &processed_samples); for (size_t i = 0; i < frontend_output.size; ++i) { // These scaling values are set to match the TFLite audio frontend int8 output. @@ -379,8 +393,8 @@ bool MicroWakeWord::generate_features_for_window_(int8_t features[PREPROCESSOR_F // for historical reasons, to match up with the output of other feature // generators. // The process is then further complicated when we quantize the model. This - // means we have to scale the 0.0 to 26.0 real values to the -128 to 127 - // signed integer numbers. + // means we have to scale the 0.0 to 26.0 real values to the -128 (INT8_MIN) + // to 127 (INT8_MAX) signed integer numbers. // All this means that to get matching values from our integer feature // output into the tensor input, we have to perform: // input = (((feature / 25.6) / 26.0) * 256) - 128 @@ -389,74 +403,63 @@ bool MicroWakeWord::generate_features_for_window_(int8_t features[PREPROCESSOR_F constexpr int32_t value_scale = 256; constexpr int32_t value_div = 666; // 666 = 25.6 * 26.0 after rounding int32_t value = ((frontend_output.values[i] * value_scale) + (value_div / 2)) / value_div; - value -= 128; - if (value < -128) { - value = -128; - } - if (value > 127) { - value = 127; - } - features[i] = value; + + value -= INT8_MIN; + features_buffer[i] = clamp(value, INT8_MIN, INT8_MAX); } - return true; + return processed_samples; } -void MicroWakeWord::reset_states_() { - ESP_LOGD(TAG, "Resetting buffers and probabilities"); - this->ring_buffer_->reset(); - this->ignore_windows_ = -MIN_SLICES_BEFORE_DETECTION; +void MicroWakeWord::process_probabilities_() { +#ifdef USE_MICRO_WAKE_WORD_VAD + DetectionEvent vad_state = this->vad_model_->determine_detected(); + + this->vad_state_ = vad_state.detected; // atomic write, so 
thread safe +#endif + for (auto &model : this->wake_word_models_) { - model.reset_probabilities(); + if (model->get_unprocessed_probability_status()) { + // Only detect wake words if there is a new probability since the last check + DetectionEvent wake_word_state = model->determine_detected(); + if (wake_word_state.detected) { +#ifdef USE_MICRO_WAKE_WORD_VAD + if (vad_state.detected) { +#endif + xQueueSend(this->detection_queue_, &wake_word_state, portMAX_DELAY); + model->reset_probabilities(); +#ifdef USE_MICRO_WAKE_WORD_VAD + } else { + wake_word_state.blocked_by_vad = true; + xQueueSend(this->detection_queue_, &wake_word_state, portMAX_DELAY); + } +#endif + } + } + } +} + +void MicroWakeWord::unload_models_() { + for (auto &model : this->wake_word_models_) { + model->unload_model(); } #ifdef USE_MICRO_WAKE_WORD_VAD - this->vad_model_->reset_probabilities(); + this->vad_model_->unload_model(); #endif } -bool MicroWakeWord::register_streaming_ops_(tflite::MicroMutableOpResolver<20> &op_resolver) { - if (op_resolver.AddCallOnce() != kTfLiteOk) - return false; - if (op_resolver.AddVarHandle() != kTfLiteOk) - return false; - if (op_resolver.AddReshape() != kTfLiteOk) - return false; - if (op_resolver.AddReadVariable() != kTfLiteOk) - return false; - if (op_resolver.AddStridedSlice() != kTfLiteOk) - return false; - if (op_resolver.AddConcatenation() != kTfLiteOk) - return false; - if (op_resolver.AddAssignVariable() != kTfLiteOk) - return false; - if (op_resolver.AddConv2D() != kTfLiteOk) - return false; - if (op_resolver.AddMul() != kTfLiteOk) - return false; - if (op_resolver.AddAdd() != kTfLiteOk) - return false; - if (op_resolver.AddMean() != kTfLiteOk) - return false; - if (op_resolver.AddFullyConnected() != kTfLiteOk) - return false; - if (op_resolver.AddLogistic() != kTfLiteOk) - return false; - if (op_resolver.AddQuantize() != kTfLiteOk) - return false; - if (op_resolver.AddDepthwiseConv2D() != kTfLiteOk) - return false; - if (op_resolver.AddAveragePool2D() != 
kTfLiteOk) - return false; - if (op_resolver.AddMaxPool2D() != kTfLiteOk) - return false; - if (op_resolver.AddPad() != kTfLiteOk) - return false; - if (op_resolver.AddPack() != kTfLiteOk) - return false; - if (op_resolver.AddSplitV() != kTfLiteOk) - return false; +bool MicroWakeWord::update_model_probabilities_(const int8_t audio_features[PREPROCESSOR_FEATURE_SIZE]) { + bool success = true; - return true; + for (auto &model : this->wake_word_models_) { + // Perform inference + success = success & model->perform_streaming_inference(audio_features); + } +#ifdef USE_MICRO_WAKE_WORD_VAD + success = success & this->vad_model_->perform_streaming_inference(audio_features); +#endif + + return success; } } // namespace micro_wake_word diff --git a/esphome/components/micro_wake_word/micro_wake_word.h b/esphome/components/micro_wake_word/micro_wake_word.h index b06d35ca1f..626b8bffb8 100644 --- a/esphome/components/micro_wake_word/micro_wake_word.h +++ b/esphome/components/micro_wake_word/micro_wake_word.h @@ -5,33 +5,27 @@ #include "preprocessor_settings.h" #include "streaming_model.h" +#include "esphome/components/microphone/microphone_source.h" + #include "esphome/core/automation.h" #include "esphome/core/component.h" #include "esphome/core/ring_buffer.h" -#include "esphome/components/microphone/microphone_source.h" +#include +#include #include -#include -#include -#include - namespace esphome { namespace micro_wake_word { enum State { - IDLE, - START_MICROPHONE, - STARTING_MICROPHONE, + STARTING, DETECTING_WAKE_WORD, - STOP_MICROPHONE, - STOPPING_MICROPHONE, + STOPPING, + STOPPED, }; -// The number of audio slices to process before accepting a positive detection -static const uint8_t MIN_SLICES_BEFORE_DETECTION = 74; - class MicroWakeWord : public Component { public: void setup() override; @@ -42,7 +36,7 @@ class MicroWakeWord : public Component { void start(); void stop(); - bool is_running() const { return this->state_ != State::IDLE; } + bool is_running() const { 
return this->state_ != State::STOPPED; } void set_features_step_size(uint8_t step_size) { this->features_step_size_ = step_size; } @@ -50,118 +44,87 @@ class MicroWakeWord : public Component { this->microphone_source_ = microphone_source; } + void set_stop_after_detection(bool stop_after_detection) { this->stop_after_detection_ = stop_after_detection; } + Trigger *get_wake_word_detected_trigger() const { return this->wake_word_detected_trigger_; } - void add_wake_word_model(const uint8_t *model_start, float probability_cutoff, size_t sliding_window_average_size, - const std::string &wake_word, size_t tensor_arena_size); + void add_wake_word_model(WakeWordModel *model); #ifdef USE_MICRO_WAKE_WORD_VAD - void add_vad_model(const uint8_t *model_start, float probability_cutoff, size_t sliding_window_size, + void add_vad_model(const uint8_t *model_start, uint8_t probability_cutoff, size_t sliding_window_size, size_t tensor_arena_size); + + // Intended for the voice assistant component to fetch VAD status + bool get_vad_state() { return this->vad_state_; } #endif + // Intended for the voice assistant component to access which wake words are available + // Since these are pointers to the WakeWordModel objects, the voice assistant component can enable or disable them + std::vector get_wake_words(); + protected: microphone::MicrophoneSource *microphone_source_{nullptr}; Trigger *wake_word_detected_trigger_ = new Trigger(); - State state_{State::IDLE}; + State state_{State::STOPPED}; - std::shared_ptr ring_buffer_; - - std::vector wake_word_models_; + std::weak_ptr ring_buffer_; + std::vector wake_word_models_; #ifdef USE_MICRO_WAKE_WORD_VAD std::unique_ptr vad_model_; + bool vad_state_{false}; #endif - tflite::MicroMutableOpResolver<20> streaming_op_resolver_; + bool pending_start_{false}; + bool pending_stop_{false}; + + bool stop_after_detection_; + + uint8_t features_step_size_; // Audio frontend handles generating spectrogram features struct FrontendConfig 
frontend_config_; struct FrontendState frontend_state_; - // When the wake word detection first starts, we ignore this many audio - // feature slices before accepting a positive detection - int16_t ignore_windows_{-MIN_SLICES_BEFORE_DETECTION}; + // Handles managing the stop/state of the inference task + EventGroupHandle_t event_group_; - uint8_t features_step_size_; + // Used to send messages about the models' states to the main loop + QueueHandle_t detection_queue_; - // Stores audio read from the microphone before being added to the ring buffer. - int16_t *input_buffer_{nullptr}; - // Stores audio to be fed into the audio frontend for generating features. - int16_t *preprocessor_audio_buffer_{nullptr}; + static void inference_task(void *params); + TaskHandle_t inference_task_handle_{nullptr}; - bool detected_{false}; - std::string detected_wake_word_{""}; + /// @brief Suspends the inference task + void suspend_task_(); + /// @brief Resumes the inference task + void resume_task_(); void set_state_(State state); - /// @brief Tests if there are enough samples in the ring buffer to generate new features. - /// @return True if enough samples, false otherwise. - bool has_enough_samples_(); + /// @brief Generates spectrogram features from an input buffer of audio samples + /// @param audio_buffer (int16_t *) Buffer containing input audio samples + /// @param samples_available (size_t) Number of samples available in the input buffer + /// @param features_buffer (int8_t *) Buffer to store generated features + /// @return (size_t) Number of samples processed from the input buffer + size_t generate_features_(int16_t *audio_buffer, size_t samples_available, + int8_t features_buffer[PREPROCESSOR_FEATURE_SIZE]); - /// @brief Allocates memory for input_buffer_, preprocessor_audio_buffer_, and ring_buffer_ - /// @return True if successful, false otherwise - bool allocate_buffers_(); + /// @brief Processes any new probabilities for each model. 
If any wake word is detected, it will send a DetectionEvent + /// to the detection_queue_. + void process_probabilities_(); - /// @brief Frees memory allocated for input_buffer_ and preprocessor_audio_buffer_ - void deallocate_buffers_(); - - /// @brief Loads streaming models and prepares the feature generation frontend - /// @return True if successful, false otherwise - bool load_models_(); - - /// @brief Deletes each model's TFLite interpreters and frees tensor arena memory. Frees memory used by the feature - /// generation frontend. + /// @brief Deletes each model's TFLite interpreters and frees tensor arena memory. void unload_models_(); - /** Performs inference with each configured model - * - * If enough audio samples are available, it will generate one slice of new features. - * It then loops through and performs inference with each of the loaded models. - */ - void update_model_probabilities_(); - - /** Checks every model's recent probabilities to determine if the wake word has been predicted - * - * Verifies the models have processed enough new samples for accurate predictions. - * Sets detected_wake_word_ to the wake word, if one is detected. - * @return True if a wake word is predicted, false otherwise - */ - bool detect_wake_words_(); - - /** Generates features for a window of audio samples - * - * Reads samples from the ring buffer and feeds them into the preprocessor frontend. - * Adapted from TFLite microspeech frontend. - * @param features int8_t array to store the audio features - * @return True if successful, false otherwise. 
- */ - bool generate_features_for_window_(int8_t features[PREPROCESSOR_FEATURE_SIZE]); - - /// @brief Resets the ring buffer, ignore_windows_, and sliding window probabilities - void reset_states_(); - - /// @brief Returns true if successfully registered the streaming model's TensorFlow operations - bool register_streaming_ops_(tflite::MicroMutableOpResolver<20> &op_resolver); + /// @brief Runs an inference with each model using the new spectrogram features + /// @param audio_features (int8_t *) Buffer containing new spectrogram features + /// @return True if successful, false if any errors were encountered + bool update_model_probabilities_(const int8_t audio_features[PREPROCESSOR_FEATURE_SIZE]); inline uint16_t new_samples_to_get_() { return (this->features_step_size_ * (AUDIO_SAMPLE_FREQUENCY / 1000)); } }; -template class StartAction : public Action, public Parented { - public: - void play(Ts... x) override { this->parent_->start(); } -}; - -template class StopAction : public Action, public Parented { - public: - void play(Ts... x) override { this->parent_->stop(); } -}; - -template class IsRunningCondition : public Condition, public Parented { - public: - bool check(Ts... x) override { return this->parent_->is_running(); } -}; - } // namespace micro_wake_word } // namespace esphome diff --git a/esphome/components/micro_wake_word/preprocessor_settings.h b/esphome/components/micro_wake_word/preprocessor_settings.h index 03f4fb5230..025e21c5f7 100644 --- a/esphome/components/micro_wake_word/preprocessor_settings.h +++ b/esphome/components/micro_wake_word/preprocessor_settings.h @@ -7,6 +7,10 @@ namespace esphome { namespace micro_wake_word { +// Settings for controlling the spectrogram feature generation by the preprocessor. +// These must match the settings used when training a particular model. +// All microWakeWord models have been trained with these specific parameters. 
+ // The number of features the audio preprocessor generates per slice static const uint8_t PREPROCESSOR_FEATURE_SIZE = 40; // Duration of each slice used as input into the preprocessor @@ -14,6 +18,21 @@ static const uint8_t FEATURE_DURATION_MS = 30; // Audio sample frequency in hertz static const uint16_t AUDIO_SAMPLE_FREQUENCY = 16000; +static const float FILTERBANK_LOWER_BAND_LIMIT = 125.0; +static const float FILTERBANK_UPPER_BAND_LIMIT = 7500.0; + +static const uint8_t NOISE_REDUCTION_SMOOTHING_BITS = 10; +static const float NOISE_REDUCTION_EVEN_SMOOTHING = 0.025; +static const float NOISE_REDUCTION_ODD_SMOOTHING = 0.06; +static const float NOISE_REDUCTION_MIN_SIGNAL_REMAINING = 0.05; + +static const bool PCAN_GAIN_CONTROL_ENABLE_PCAN = true; +static const float PCAN_GAIN_CONTROL_STRENGTH = 0.95; +static const float PCAN_GAIN_CONTROL_OFFSET = 80.0; +static const uint8_t PCAN_GAIN_CONTROL_GAIN_BITS = 21; + +static const bool LOG_SCALE_ENABLE_LOG = true; +static const uint8_t LOG_SCALE_SCALE_SHIFT = 6; } // namespace micro_wake_word } // namespace esphome diff --git a/esphome/components/micro_wake_word/streaming_model.cpp b/esphome/components/micro_wake_word/streaming_model.cpp index d0d2e2df05..6512c0f569 100644 --- a/esphome/components/micro_wake_word/streaming_model.cpp +++ b/esphome/components/micro_wake_word/streaming_model.cpp @@ -1,8 +1,7 @@ -#ifdef USE_ESP_IDF - #include "streaming_model.h" -#include "esphome/core/hal.h" +#ifdef USE_ESP_IDF + #include "esphome/core/helpers.h" #include "esphome/core/log.h" @@ -13,18 +12,18 @@ namespace micro_wake_word { void WakeWordModel::log_model_config() { ESP_LOGCONFIG(TAG, " - Wake Word: %s", this->wake_word_.c_str()); - ESP_LOGCONFIG(TAG, " Probability cutoff: %.3f", this->probability_cutoff_); + ESP_LOGCONFIG(TAG, " Probability cutoff: %.2f", this->probability_cutoff_ / 255.0f); ESP_LOGCONFIG(TAG, " Sliding window size: %d", this->sliding_window_size_); } void VADModel::log_model_config() { ESP_LOGCONFIG(TAG, " - 
VAD Model"); - ESP_LOGCONFIG(TAG, " Probability cutoff: %.3f", this->probability_cutoff_); + ESP_LOGCONFIG(TAG, " Probability cutoff: %.2f", this->probability_cutoff_ / 255.0f); ESP_LOGCONFIG(TAG, " Sliding window size: %d", this->sliding_window_size_); } -bool StreamingModel::load_model(tflite::MicroMutableOpResolver<20> &op_resolver) { - ExternalRAMAllocator arena_allocator(ExternalRAMAllocator::ALLOW_FAILURE); +bool StreamingModel::load_model_() { + RAMAllocator arena_allocator(RAMAllocator::ALLOW_FAILURE); if (this->tensor_arena_ == nullptr) { this->tensor_arena_ = arena_allocator.allocate(this->tensor_arena_size_); @@ -51,8 +50,9 @@ bool StreamingModel::load_model(tflite::MicroMutableOpResolver<20> &op_resolver) } if (this->interpreter_ == nullptr) { - this->interpreter_ = make_unique( - tflite::GetModel(this->model_start_), op_resolver, this->tensor_arena_, this->tensor_arena_size_, this->mrv_); + this->interpreter_ = + make_unique(tflite::GetModel(this->model_start_), this->streaming_op_resolver_, + this->tensor_arena_, this->tensor_arena_size_, this->mrv_); if (this->interpreter_->AllocateTensors() != kTfLiteOk) { ESP_LOGE(TAG, "Failed to allocate tensors for the streaming model"); return false; @@ -84,34 +84,55 @@ bool StreamingModel::load_model(tflite::MicroMutableOpResolver<20> &op_resolver) } } + this->loaded_ = true; + this->reset_probabilities(); return true; } void StreamingModel::unload_model() { this->interpreter_.reset(); - ExternalRAMAllocator arena_allocator(ExternalRAMAllocator::ALLOW_FAILURE); + RAMAllocator arena_allocator(RAMAllocator::ALLOW_FAILURE); - arena_allocator.deallocate(this->tensor_arena_, this->tensor_arena_size_); - this->tensor_arena_ = nullptr; - arena_allocator.deallocate(this->var_arena_, STREAMING_MODEL_VARIABLE_ARENA_SIZE); - this->var_arena_ = nullptr; + if (this->tensor_arena_ != nullptr) { + arena_allocator.deallocate(this->tensor_arena_, this->tensor_arena_size_); + this->tensor_arena_ = nullptr; + } + + if 
(this->var_arena_ != nullptr) { + arena_allocator.deallocate(this->var_arena_, STREAMING_MODEL_VARIABLE_ARENA_SIZE); + this->var_arena_ = nullptr; + } + + this->loaded_ = false; } bool StreamingModel::perform_streaming_inference(const int8_t features[PREPROCESSOR_FEATURE_SIZE]) { - if (this->interpreter_ != nullptr) { + if (this->enabled_ && !this->loaded_) { + // Model is enabled but isn't loaded + if (!this->load_model_()) { + return false; + } + } + + if (!this->enabled_ && this->loaded_) { + // Model is disabled but still loaded + this->unload_model(); + return true; + } + + if (this->loaded_) { TfLiteTensor *input = this->interpreter_->input(0); + uint8_t stride = this->interpreter_->input(0)->dims->data[1]; + this->current_stride_step_ = this->current_stride_step_ % stride; + std::memmove( (int8_t *) (tflite::GetTensorData(input)) + PREPROCESSOR_FEATURE_SIZE * this->current_stride_step_, features, PREPROCESSOR_FEATURE_SIZE); ++this->current_stride_step_; - uint8_t stride = this->interpreter_->input(0)->dims->data[1]; - if (this->current_stride_step_ >= stride) { - this->current_stride_step_ = 0; - TfLiteStatus invoke_status = this->interpreter_->Invoke(); if (invoke_status != kTfLiteOk) { ESP_LOGW(TAG, "Streaming interpreter invoke failed"); @@ -124,65 +145,159 @@ bool StreamingModel::perform_streaming_inference(const int8_t features[PREPROCES if (this->last_n_index_ == this->sliding_window_size_) this->last_n_index_ = 0; this->recent_streaming_probabilities_[this->last_n_index_] = output->data.uint8[0]; // probability; + this->unprocessed_probability_status_ = true; } - return true; + this->ignore_windows_ = std::min(this->ignore_windows_ + 1, 0); } - ESP_LOGE(TAG, "Streaming interpreter is not initialized."); - return false; + return true; } void StreamingModel::reset_probabilities() { for (auto &prob : this->recent_streaming_probabilities_) { prob = 0; } + this->ignore_windows_ = -MIN_SLICES_BEFORE_DETECTION; } -WakeWordModel::WakeWordModel(const uint8_t 
*model_start, float probability_cutoff, size_t sliding_window_average_size, - const std::string &wake_word, size_t tensor_arena_size) { +WakeWordModel::WakeWordModel(const std::string &id, const uint8_t *model_start, uint8_t probability_cutoff, + size_t sliding_window_average_size, const std::string &wake_word, size_t tensor_arena_size, + bool default_enabled, bool internal_only) { + this->id_ = id; this->model_start_ = model_start; this->probability_cutoff_ = probability_cutoff; this->sliding_window_size_ = sliding_window_average_size; this->recent_streaming_probabilities_.resize(sliding_window_average_size, 0); this->wake_word_ = wake_word; this->tensor_arena_size_ = tensor_arena_size; + this->register_streaming_ops_(this->streaming_op_resolver_); + this->current_stride_step_ = 0; + this->internal_only_ = internal_only; + + this->pref_ = global_preferences->make_preference(fnv1_hash(id)); + bool enabled; + if (this->pref_.load(&enabled)) { + // Use the enabled state loaded from flash + this->enabled_ = enabled; + } else { + // If no state saved, then use the default + this->enabled_ = default_enabled; + } }; -bool WakeWordModel::determine_detected() { +void WakeWordModel::enable() { + this->enabled_ = true; + if (!this->internal_only_) { + this->pref_.save(&this->enabled_); + } +} + +void WakeWordModel::disable() { + this->enabled_ = false; + if (!this->internal_only_) { + this->pref_.save(&this->enabled_); + } +} + +DetectionEvent WakeWordModel::determine_detected() { + DetectionEvent detection_event; + detection_event.wake_word = &this->wake_word_; + detection_event.max_probability = 0; + detection_event.average_probability = 0; + + if ((this->ignore_windows_ < 0) || !this->enabled_) { + detection_event.detected = false; + return detection_event; + } + uint32_t sum = 0; for (auto &prob : this->recent_streaming_probabilities_) { + detection_event.max_probability = std::max(detection_event.max_probability, prob); sum += prob; } - float sliding_window_average = 
static_cast(sum) / static_cast(255 * this->sliding_window_size_); + detection_event.average_probability = sum / this->sliding_window_size_; + detection_event.detected = sum > this->probability_cutoff_ * this->sliding_window_size_; - // Detect the wake word if the sliding window average is above the cutoff - if (sliding_window_average > this->probability_cutoff_) { - ESP_LOGD(TAG, "The '%s' model sliding average probability is %.3f and most recent probability is %.3f", - this->wake_word_.c_str(), sliding_window_average, - this->recent_streaming_probabilities_[this->last_n_index_] / (255.0)); - return true; - } - return false; + this->unprocessed_probability_status_ = false; + return detection_event; } -VADModel::VADModel(const uint8_t *model_start, float probability_cutoff, size_t sliding_window_size, +VADModel::VADModel(const uint8_t *model_start, uint8_t probability_cutoff, size_t sliding_window_size, size_t tensor_arena_size) { this->model_start_ = model_start; this->probability_cutoff_ = probability_cutoff; this->sliding_window_size_ = sliding_window_size; this->recent_streaming_probabilities_.resize(sliding_window_size, 0); this->tensor_arena_size_ = tensor_arena_size; -}; + this->register_streaming_ops_(this->streaming_op_resolver_); +} + +DetectionEvent VADModel::determine_detected() { + DetectionEvent detection_event; + detection_event.max_probability = 0; + detection_event.average_probability = 0; + + if (!this->enabled_) { + // We disabled the VAD model for some reason... 
so we shouldn't block wake words from being detected + detection_event.detected = true; + return detection_event; + } -bool VADModel::determine_detected() { uint32_t sum = 0; for (auto &prob : this->recent_streaming_probabilities_) { + detection_event.max_probability = std::max(detection_event.max_probability, prob); sum += prob; } - float sliding_window_average = static_cast(sum) / static_cast(255 * this->sliding_window_size_); + detection_event.average_probability = sum / this->sliding_window_size_; + detection_event.detected = sum > (this->probability_cutoff_ * this->sliding_window_size_); - return sliding_window_average > this->probability_cutoff_; + return detection_event; +} + +bool StreamingModel::register_streaming_ops_(tflite::MicroMutableOpResolver<20> &op_resolver) { + if (op_resolver.AddCallOnce() != kTfLiteOk) + return false; + if (op_resolver.AddVarHandle() != kTfLiteOk) + return false; + if (op_resolver.AddReshape() != kTfLiteOk) + return false; + if (op_resolver.AddReadVariable() != kTfLiteOk) + return false; + if (op_resolver.AddStridedSlice() != kTfLiteOk) + return false; + if (op_resolver.AddConcatenation() != kTfLiteOk) + return false; + if (op_resolver.AddAssignVariable() != kTfLiteOk) + return false; + if (op_resolver.AddConv2D() != kTfLiteOk) + return false; + if (op_resolver.AddMul() != kTfLiteOk) + return false; + if (op_resolver.AddAdd() != kTfLiteOk) + return false; + if (op_resolver.AddMean() != kTfLiteOk) + return false; + if (op_resolver.AddFullyConnected() != kTfLiteOk) + return false; + if (op_resolver.AddLogistic() != kTfLiteOk) + return false; + if (op_resolver.AddQuantize() != kTfLiteOk) + return false; + if (op_resolver.AddDepthwiseConv2D() != kTfLiteOk) + return false; + if (op_resolver.AddAveragePool2D() != kTfLiteOk) + return false; + if (op_resolver.AddMaxPool2D() != kTfLiteOk) + return false; + if (op_resolver.AddPad() != kTfLiteOk) + return false; + if (op_resolver.AddPack() != kTfLiteOk) + return false; + if 
(op_resolver.AddSplitV() != kTfLiteOk) + return false; + + return true; } } // namespace micro_wake_word diff --git a/esphome/components/micro_wake_word/streaming_model.h b/esphome/components/micro_wake_word/streaming_model.h index 0d85579f35..5bd1cf356a 100644 --- a/esphome/components/micro_wake_word/streaming_model.h +++ b/esphome/components/micro_wake_word/streaming_model.h @@ -4,6 +4,8 @@ #include "preprocessor_settings.h" +#include "esphome/core/preferences.h" + #include #include #include @@ -11,30 +13,63 @@ namespace esphome { namespace micro_wake_word { +static const uint8_t MIN_SLICES_BEFORE_DETECTION = 100; static const uint32_t STREAMING_MODEL_VARIABLE_ARENA_SIZE = 1024; +struct DetectionEvent { + std::string *wake_word; + bool detected; + bool partially_detection; // Set if the most recent probability exceed the threshold, but the sliding window average + // hasn't yet + uint8_t max_probability; + uint8_t average_probability; + bool blocked_by_vad = false; +}; + class StreamingModel { public: virtual void log_model_config() = 0; - virtual bool determine_detected() = 0; + virtual DetectionEvent determine_detected() = 0; + // Performs inference on the given features. 
+ // - If the model is enabled but not loaded, it will load it + // - If the model is disabled but loaded, it will unload it + // Returns true if sucessful or false if there is an error bool perform_streaming_inference(const int8_t features[PREPROCESSOR_FEATURE_SIZE]); - /// @brief Sets all recent_streaming_probabilities to 0 + /// @brief Sets all recent_streaming_probabilities to 0 and resets the ignore window count void reset_probabilities(); - /// @brief Allocates tensor and variable arenas and sets up the model interpreter - /// @param op_resolver MicroMutableOpResolver object that must exist until the model is unloaded - /// @return True if successful, false otherwise - bool load_model(tflite::MicroMutableOpResolver<20> &op_resolver); - /// @brief Destroys the TFLite interpreter and frees the tensor and variable arenas' memory void unload_model(); - protected: - uint8_t current_stride_step_{0}; + /// @brief Enable the model. The next performing_streaming_inference call will load it. + virtual void enable() { this->enabled_ = true; } - float probability_cutoff_; + /// @brief Disable the model. The next performing_streaming_inference call will unload it. + virtual void disable() { this->enabled_ = false; } + + /// @brief Return true if the model is enabled. 
+ bool is_enabled() { return this->enabled_; } + + bool get_unprocessed_probability_status() { return this->unprocessed_probability_status_; } + + protected: + /// @brief Allocates tensor and variable arenas and sets up the model interpreter + /// @return True if successful, false otherwise + bool load_model_(); + /// @brief Returns true if successfully registered the streaming model's TensorFlow operations + bool register_streaming_ops_(tflite::MicroMutableOpResolver<20> &op_resolver); + + tflite::MicroMutableOpResolver<20> streaming_op_resolver_; + + bool loaded_{false}; + bool enabled_{true}; + bool unprocessed_probability_status_{false}; + uint8_t current_stride_step_{0}; + int16_t ignore_windows_{-MIN_SLICES_BEFORE_DETECTION}; + + uint8_t probability_cutoff_; // Quantized probability cutoff mapping 0.0 - 1.0 to 0 - 255 size_t sliding_window_size_; size_t last_n_index_{0}; size_t tensor_arena_size_; @@ -50,32 +85,62 @@ class StreamingModel { class WakeWordModel final : public StreamingModel { public: - WakeWordModel(const uint8_t *model_start, float probability_cutoff, size_t sliding_window_average_size, - const std::string &wake_word, size_t tensor_arena_size); + /// @brief Constructs a wake word model object + /// @param id (std::string) identifier for this model + /// @param model_start (const uint8_t *) pointer to the start of the model's TFLite FlatBuffer + /// @param probability_cutoff (uint8_t) probability cutoff for accepting the wake word has been said + /// @param sliding_window_average_size (size_t) the length of the sliding window computing the mean rolling + /// probability + /// @param wake_word (std::string) Friendly name of the wake word + /// @param tensor_arena_size (size_t) Size in bytes for allocating the tensor arena + /// @param default_enabled (bool) If true, it will be enabled by default on first boot + /// @param internal_only (bool) If true, the model will not be exposed to HomeAssistant as an available model + WakeWordModel(const 
std::string &id, const uint8_t *model_start, uint8_t probability_cutoff, + size_t sliding_window_average_size, const std::string &wake_word, size_t tensor_arena_size, + bool default_enabled, bool internal_only); void log_model_config() override; /// @brief Checks for the wake word by comparing the mean probability in the sliding window with the probability /// cutoff /// @return True if wake word is detected, false otherwise - bool determine_detected() override; + DetectionEvent determine_detected() override; + const std::string &get_id() const { return this->id_; } const std::string &get_wake_word() const { return this->wake_word_; } + void add_trained_language(const std::string &language) { this->trained_languages_.push_back(language); } + const std::vector &get_trained_languages() const { return this->trained_languages_; } + + /// @brief Enable the model and save to flash. The next performing_streaming_inference call will load it. + void enable() override; + + /// @brief Disable the model and save to flash. The next performing_streaming_inference call will unload it. 
+ void disable() override; + + bool get_internal_only() { return this->internal_only_; } + protected: + std::string id_; std::string wake_word_; + std::vector trained_languages_; + + bool internal_only_; + + ESPPreferenceObject pref_; }; class VADModel final : public StreamingModel { public: - VADModel(const uint8_t *model_start, float probability_cutoff, size_t sliding_window_size, size_t tensor_arena_size); + VADModel(const uint8_t *model_start, uint8_t probability_cutoff, size_t sliding_window_size, + size_t tensor_arena_size); void log_model_config() override; /// @brief Checks for voice activity by comparing the max probability in the sliding window with the probability /// cutoff /// @return True if voice activity is detected, false otherwise - bool determine_detected() override; + DetectionEvent determine_detected() override; }; } // namespace micro_wake_word diff --git a/esphome/core/defines.h b/esphome/core/defines.h index 81ff6999ba..de963313db 100644 --- a/esphome/core/defines.h +++ b/esphome/core/defines.h @@ -79,6 +79,7 @@ #define USE_LVGL_TEXTAREA #define USE_LVGL_TILEVIEW #define USE_LVGL_TOUCHSCREEN +#define USE_MICRO_WAKE_WORD #define USE_MD5 #define USE_MDNS #define USE_MEDIA_PLAYER diff --git a/tests/components/micro_wake_word/common.yaml b/tests/components/micro_wake_word/common.yaml index b5507397f8..c051c8dd57 100644 --- a/tests/components/micro_wake_word/common.yaml +++ b/tests/components/micro_wake_word/common.yaml @@ -14,8 +14,24 @@ micro_wake_word: microphone: echo_microphone on_wake_word_detected: - logger.log: "Wake word detected" + - micro_wake_word.stop: + - if: + condition: + - micro_wake_word.model_is_enabled: hey_jarvis_model + then: + - micro_wake_word.disable_model: hey_jarvis_model + else: + - micro_wake_word.enable_model: hey_jarvis_model + - if: + condition: + - not: + - micro_wake_word.is_running: + then: + micro_wake_word.start: + stop_after_detection: false models: - model: hey_jarvis probability_cutoff: 0.7 + id: 
hey_jarvis_model - model: okay_nabu sliding_window_size: 5