[audio, microphone] Quantization Improvements (#8695)

This commit is contained in:
Kevin Ahrendt 2025-05-05 16:23:50 -05:00 committed by GitHub
parent 1ac56b06c5
commit 88be14aaa3
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 97 additions and 44 deletions

View File

@ -135,5 +135,30 @@ const char *audio_file_type_to_string(AudioFileType file_type);
void scale_audio_samples(const int16_t *audio_samples, int16_t *output_buffer, int16_t scale_factor, void scale_audio_samples(const int16_t *audio_samples, int16_t *output_buffer, int16_t scale_factor,
size_t samples_to_scale); size_t samples_to_scale);
/// @brief Unpacks a quantized audio sample into a Q31 fixed point number.
/// @param data Pointer to uint8_t array containing the audio sample
/// @param bytes_per_sample The number of bytes per sample
/// @return Q31 sample
inline int32_t unpack_audio_sample_to_q31(const uint8_t *data, size_t bytes_per_sample) {
int32_t sample = 0;
if (bytes_per_sample == 1) {
sample |= data[0] << 24;
} else if (bytes_per_sample == 2) {
sample |= data[0] << 16;
sample |= data[1] << 24;
} else if (bytes_per_sample == 3) {
sample |= data[0] << 8;
sample |= data[1] << 16;
sample |= data[2] << 24;
} else if (bytes_per_sample == 4) {
sample |= data[0];
sample |= data[1] << 8;
sample |= data[2] << 16;
sample |= data[3] << 24;
}
return sample;
}
} // namespace audio } // namespace audio
} // namespace esphome } // namespace esphome

View File

@ -3,16 +3,34 @@
namespace esphome { namespace esphome {
namespace microphone { namespace microphone {
static const int32_t Q25_MAX_VALUE = (1 << 25) - 1;
static const int32_t Q25_MIN_VALUE = ~Q25_MAX_VALUE;
static const uint32_t HISTORY_VALUES = 32;
void MicrophoneSource::add_data_callback(std::function<void(const std::vector<uint8_t> &)> &&data_callback) { void MicrophoneSource::add_data_callback(std::function<void(const std::vector<uint8_t> &)> &&data_callback) {
std::function<void(const std::vector<uint8_t> &)> filtered_callback = std::function<void(const std::vector<uint8_t> &)> filtered_callback =
[this, data_callback](const std::vector<uint8_t> &data) { [this, data_callback](const std::vector<uint8_t> &data) {
if (this->enabled_) { if (this->enabled_) {
data_callback(this->process_audio_(data)); if (this->processed_samples_.use_count() == 0) {
// Create vector if its unused
this->processed_samples_ = std::make_shared<std::vector<uint8_t>>();
}
// Take temporary ownership of samples vector to avoid deallaction before the callback finishes
std::shared_ptr<std::vector<uint8_t>> output_samples = this->processed_samples_;
this->process_audio_(data, *output_samples);
data_callback(*output_samples);
} }
}; };
this->mic_->add_data_callback(std::move(filtered_callback)); this->mic_->add_data_callback(std::move(filtered_callback));
} }
audio::AudioStreamInfo MicrophoneSource::get_audio_stream_info() {
return audio::AudioStreamInfo(this->bits_per_sample_, this->channels_.count(),
this->mic_->get_audio_stream_info().get_sample_rate());
}
void MicrophoneSource::start() { void MicrophoneSource::start() {
if (!this->enabled_) { if (!this->enabled_) {
this->enabled_ = true; this->enabled_ = true;
@ -23,14 +41,21 @@ void MicrophoneSource::stop() {
if (this->enabled_) { if (this->enabled_) {
this->enabled_ = false; this->enabled_ = false;
this->mic_->stop(); this->mic_->stop();
this->processed_samples_.reset();
} }
} }
std::vector<uint8_t> MicrophoneSource::process_audio_(const std::vector<uint8_t> &data) { void MicrophoneSource::process_audio_(const std::vector<uint8_t> &data, std::vector<uint8_t> &filtered_data) {
// Bit depth conversions are obtained by truncating bits or padding with zeros - no dithering is applied. // - Bit depth conversions are obtained by truncating bits or padding with zeros - no dithering is applied.
// - In the comments, Qxx refers to a fixed point number with xx bits of precision for representing fractional values.
// For example, audio with a bit depth of 16 can store a sample in a int16, which can be considered a Q15 number.
// - All samples are converted to Q25 before applying the gain factor - this results in a small precision loss for
// data with 32 bits per sample. Since the maximum gain factor is 64 = (1<<6), this ensures that applying the gain
// will never overflow a 32 bit signed integer. This still retains more bit depth than what is audibly noticeable.
// - Loops for reading/writing data buffers are unrolled, assuming little endian, for a small performance increase.
const size_t source_bytes_per_sample = this->mic_->get_audio_stream_info().samples_to_bytes(1); const size_t source_bytes_per_sample = this->mic_->get_audio_stream_info().samples_to_bytes(1);
const size_t source_channels = this->mic_->get_audio_stream_info().get_channels(); const uint32_t source_channels = this->mic_->get_audio_stream_info().get_channels();
const size_t source_bytes_per_frame = this->mic_->get_audio_stream_info().frames_to_bytes(1); const size_t source_bytes_per_frame = this->mic_->get_audio_stream_info().frames_to_bytes(1);
@ -38,60 +63,48 @@ std::vector<uint8_t> MicrophoneSource::process_audio_(const std::vector<uint8_t>
const size_t target_bytes_per_sample = (this->bits_per_sample_ + 7) / 8; const size_t target_bytes_per_sample = (this->bits_per_sample_ + 7) / 8;
const size_t target_bytes_per_frame = target_bytes_per_sample * this->channels_.count(); const size_t target_bytes_per_frame = target_bytes_per_sample * this->channels_.count();
std::vector<uint8_t> filtered_data;
filtered_data.reserve(target_bytes_per_frame * total_frames); filtered_data.reserve(target_bytes_per_frame * total_frames);
filtered_data.resize(0);
const int32_t target_min_value = -(1 << (8 * target_bytes_per_sample - 1)); for (uint32_t frame_index = 0; frame_index < total_frames; ++frame_index) {
const int32_t target_max_value = (1 << (8 * target_bytes_per_sample - 1)) - 1; for (uint32_t channel_index = 0; channel_index < source_channels; ++channel_index) {
for (size_t frame_index = 0; frame_index < total_frames; ++frame_index) {
for (size_t channel_index = 0; channel_index < source_channels; ++channel_index) {
if (this->channels_.test(channel_index)) { if (this->channels_.test(channel_index)) {
// Channel's current sample is included in the target mask. Convert bits per sample, if necessary. // Channel's current sample is included in the target mask. Convert bits per sample, if necessary.
size_t sample_index = frame_index * source_bytes_per_frame + channel_index * source_bytes_per_sample; const uint32_t sample_index = frame_index * source_bytes_per_frame + channel_index * source_bytes_per_sample;
int32_t sample = 0; int32_t sample = audio::unpack_audio_sample_to_q31(&data[sample_index], source_bytes_per_sample); // Q31
sample >>= 6; // Q31 -> Q25
// Copy the data into the most significant bits of the sample variable to ensure the sign bit is correct
uint8_t bit_offset = (4 - source_bytes_per_sample) * 8;
for (int i = 0; i < source_bytes_per_sample; ++i) {
sample |= data[sample_index + i] << bit_offset;
bit_offset += 8;
}
// Shift data back to the least significant bits
if (source_bytes_per_sample >= target_bytes_per_sample) {
// Keep source bytes per sample of data so that the gain multiplication uses all significant bits instead of
// shifting to the target bytes per sample immediately, potentially losing information.
sample >>= (4 - source_bytes_per_sample) * 8; // ``source_bytes_per_sample`` bytes of valid data
} else {
// Keep padded zeros to match the target bytes per sample
sample >>= (4 - target_bytes_per_sample) * 8; // ``target_bytes_per_sample`` bytes of valid data
}
// Apply gain using multiplication // Apply gain using multiplication
sample *= this->gain_factor_; sample *= this->gain_factor_; // Q25
// Match target output bytes by shifting out the least significant bits // Clamp ``sample`` in case gain multiplication overflows 25 bits
if (source_bytes_per_sample > target_bytes_per_sample) { sample = clamp<int32_t>(sample, Q25_MIN_VALUE, Q25_MAX_VALUE); // Q25
sample >>= 8 * (source_bytes_per_sample -
target_bytes_per_sample); // ``target_bytes_per_sample`` bytes of valid data
}
// Clamp ``sample`` to the target bytes per sample range in case gain multiplication overflows
sample = clamp<int32_t>(sample, target_min_value, target_max_value);
// Copy ``target_bytes_per_sample`` bytes to the output buffer. // Copy ``target_bytes_per_sample`` bytes to the output buffer.
for (int i = 0; i < target_bytes_per_sample; ++i) { if (target_bytes_per_sample == 1) {
sample >>= 18; // Q25 -> Q7
filtered_data.push_back(static_cast<uint8_t>(sample)); filtered_data.push_back(static_cast<uint8_t>(sample));
sample >>= 8; } else if (target_bytes_per_sample == 2) {
sample >>= 10; // Q25 -> Q15
filtered_data.push_back(static_cast<uint8_t>(sample));
filtered_data.push_back(static_cast<uint8_t>(sample >> 8));
} else if (target_bytes_per_sample == 3) {
sample >>= 2; // Q25 -> Q23
filtered_data.push_back(static_cast<uint8_t>(sample));
filtered_data.push_back(static_cast<uint8_t>(sample >> 8));
filtered_data.push_back(static_cast<uint8_t>(sample >> 16));
} else {
sample *= (1 << 6); // Q25 -> Q31
filtered_data.push_back(static_cast<uint8_t>(sample));
filtered_data.push_back(static_cast<uint8_t>(sample >> 8));
filtered_data.push_back(static_cast<uint8_t>(sample >> 16));
filtered_data.push_back(static_cast<uint8_t>(sample >> 24));
} }
} }
} }
} }
return filtered_data;
} }
} // namespace microphone } // namespace microphone

View File

@ -1,15 +1,20 @@
#pragma once #pragma once
#include "microphone.h"
#include "esphome/components/audio/audio.h"
#include <bitset> #include <bitset>
#include <cstddef> #include <cstddef>
#include <cstdint> #include <cstdint>
#include <functional> #include <functional>
#include <vector> #include <vector>
#include "microphone.h"
namespace esphome { namespace esphome {
namespace microphone { namespace microphone {
static const int32_t MAX_GAIN_FACTOR = 64;
class MicrophoneSource { class MicrophoneSource {
/* /*
* @brief Helper class that handles converting raw microphone data to a requested format. * @brief Helper class that handles converting raw microphone data to a requested format.
@ -44,13 +49,23 @@ class MicrophoneSource {
void add_data_callback(std::function<void(const std::vector<uint8_t> &)> &&data_callback); void add_data_callback(std::function<void(const std::vector<uint8_t> &)> &&data_callback);
void set_gain_factor(int32_t gain_factor) { this->gain_factor_ = clamp<int32_t>(gain_factor, 1, MAX_GAIN_FACTOR); }
int32_t get_gain_factor() { return this->gain_factor_; }
/// @brief Gets the AudioStreamInfo of the data after processing
/// @return audio::AudioStreamInfo with the configured bits per sample, configured channel count, and source
/// microphone's sample rate
audio::AudioStreamInfo get_audio_stream_info();
void start(); void start();
void stop(); void stop();
bool is_running() const { return (this->mic_->is_running() && this->enabled_); } bool is_running() const { return (this->mic_->is_running() && this->enabled_); }
bool is_stopped() const { return !this->enabled_; } bool is_stopped() const { return !this->enabled_; }
protected: protected:
std::vector<uint8_t> process_audio_(const std::vector<uint8_t> &data); void process_audio_(const std::vector<uint8_t> &data, std::vector<uint8_t> &filtered_data);
std::shared_ptr<std::vector<uint8_t>> processed_samples_;
Microphone *mic_; Microphone *mic_;
uint8_t bits_per_sample_; uint8_t bits_per_sample_;