#include #include "wled.h" #define HMAC_KEY_SIZE 32 #define MAX_SESSION_IDS 8 void printByteArray(const byte* arr, size_t len) { for (size_t i = 0; i < len; i++) { Serial.print(arr[i], HEX); } Serial.println(); } struct Nonce { byte sessionId[SESSION_ID_SIZE]; uint32_t counter; }; Nonce knownSessions[MAX_SESSION_IDS] = {}; void moveToFirst(uint32_t i) { if (i >= MAX_SESSION_IDS) return; Nonce tmp = knownSessions[i]; for (int j = i; j > 0; j--) { knownSessions[j] = knownSessions[j - 1]; } knownSessions[0] = tmp; } uint8_t verifyNonce(const byte* sid, uint32_t counter) { Serial.println(F("check sid")); printByteArray(sid, SESSION_ID_SIZE); uint32_t sum = 0; for (size_t i = 0; i < SESSION_ID_SIZE; i++) { sum += sid[i]; } if (sum == 0) { // all-zero session ID is invalid as it is used for uninitialized entries return ERR_NONCE; } for (int i = 0; i < MAX_SESSION_IDS; i++) { if (memcmp(knownSessions[i].sessionId, sid, SESSION_ID_SIZE) == 0) { Serial.print(F("Session ID matches e")); Serial.println(i); if (counter <= knownSessions[i].counter) { Serial.println(F("Retransmission detected!")); return ERR_REPLAY; } knownSessions[i].counter = counter; // nonce good, move this entry to the first position of knownSessions moveToFirst(i); return ERR_NONE; } } Serial.println(F("Unknown session ID!")); return ERR_NONCE; } void addSessionId(byte* sid) { RNG::fill(sid, SESSION_ID_SIZE); // first, try to find a completely unused slot for (int i = 0; i < MAX_SESSION_IDS; i++) { // this is not perfect, but it is extremely unlikely that the first 32 bit of a random session ID are all zeroes if ((uint32_t)(knownSessions[i].sessionId) == 0 && knownSessions[i].counter == 0) { memcpy(knownSessions[i].sessionId, sid, SESSION_ID_SIZE); moveToFirst(i); return; } } // next, find oldest slot that has counter 0 (not used before) // but leave the two most recent slots alone for (int i = MAX_SESSION_IDS - 1; i > 1; i--) { if (knownSessions[i].counter == 0) { memcpy(knownSessions[i].sessionId, sid, SESSION_ID_SIZE); moveToFirst(i); return; } } // if all else fails, overwrite the oldest slot memcpy(knownSessions[MAX_SESSION_IDS - 1].sessionId, sid, SESSION_ID_SIZE); moveToFirst(MAX_SESSION_IDS - 1); } void hexStringToByteArray(const char* hexString, unsigned char* byteArray, size_t byteArraySize) { size_t lenStr = strlen(hexString); if (lenStr < 2 * byteArraySize) byteArraySize = lenStr / 2; for (size_t i = 0; i < byteArraySize; i++) { char c[3] = {hexString[2 * i], hexString[2 * i + 1], '\0'}; // Get two characters byteArray[i] = (unsigned char)strtoul(c, NULL, 16); // Convert to byte } } // requires hexString to be at least 2 * byteLen + 1 characters long char* byteArrayToHexString(char* hexString, const byte* byteArray, size_t byteLen) { for (size_t i = 0; i < byteLen; ++i) { // Convert each byte to a two-character hex string sprintf(&hexString[i * 2], "%02x", byteArray[i]); } // Null-terminate the string hexString[byteLen * 2] = '\0'; return hexString; } void hmacSign(const byte* message, size_t msgLen, const char* pskHex, byte* signature) { size_t len = strlen(pskHex) / 2; // This will drop the last character if the string has an odd length if (len > HMAC_KEY_SIZE) { Serial.println(F("PSK too long!")); return; } unsigned char pskByteArray[len]; hexStringToByteArray(pskHex, pskByteArray, len); SHA256HMAC hmac(pskByteArray, len); hmac.doUpdate(message, msgLen); hmac.doFinal(signature); } bool hmacVerify(const byte* message, size_t msgLen, const char* pskHex, const byte* signature) { byte sigCalculated[SHA256HMAC_SIZE]; hmacSign(message, msgLen, pskHex, sigCalculated); //Serial.print(F("Calculated: ")); //printByteArray(sigCalculated, SHA256HMAC_SIZE); if (memcmp(sigCalculated, signature, SHA256HMAC_SIZE) != 0) { Serial.println(F("HMAC verification failed!")); Serial.print(F("Expected: ")); printByteArray(signature, SHA256HMAC_SIZE); return false; } Serial.println(F("HMAC verification successful!")); return true; } #define WLED_HMAC_TEST_PW "guessihadthekeyafterall" #define WLED_HMAC_TEST_PSK "a6f8488da62c5888d7f640276676e78da8639faf0495110b43e226b35ac37a4c" uint8_t verifyHmacFromJsonString0Term(byte* jsonStr, size_t len) { // Zero-terminate the JSON string (replace the last character, usually '}', with a null terminator temporarily) byte lastChar = jsonStr[len-1]; jsonStr[len-1] = '\0'; uint8_t result = verifyHmacFromJsonStr((const char*)jsonStr, len); jsonStr[len-1] = lastChar; return result; } uint8_t verifyHmacFromJsonStr(const char* jsonStr, uint32_t maxLen) { // Extract the signature from the JSON string size_t jsonLen = strlen(jsonStr); Serial.print(F("Length: ")); Serial.println(jsonLen); if (jsonLen > maxLen) { // memory safety Serial.print(F("JSON string too long!")); Serial.print(F(", max: ")); Serial.println(maxLen); return ERR_HMAC_GEN; } Serial.print(F("Received JSON: ")); Serial.println(jsonStr); const char* macPos = strstr(jsonStr, "\"mac\":\""); if (macPos == nullptr) { Serial.println(F("No MAC found in JSON.")); return ERR_HMAC_MISS; } StaticJsonDocument<128> macDoc; DeserializationError error = deserializeJson(macDoc, macPos +6); if (error) { Serial.print(F("deserializeJson() failed: ")); Serial.println(error.c_str()); return ERR_HMAC_GEN; } const char* mac = macDoc.as(); if (mac == nullptr) { Serial.println(F("Failed MAC JSON.")); return ERR_HMAC_GEN; } Serial.print(F("Received MAC: ")); Serial.println(mac); // extract the message object from the JSON string char* msgPos = strstr(jsonStr, "\"msg\":"); char* objStart = strchr(msgPos + 6, '{'); if (objStart == nullptr) { Serial.println(F("Couldn't find msg object start.")); return ERR_HMAC_GEN; } size_t maxObjLen = jsonLen - (objStart - jsonStr); Serial.print(F("Max object length: ")); Serial.println(maxObjLen); int32_t objDepth = 0; char* objEnd = nullptr; for (size_t i = 0; i < maxObjLen; i++) { Serial.write(objStart[i]); if (objStart[i] == '{') objDepth++; if (objStart[i] == '}') objDepth--; if (objDepth == 0) { Serial.print(F("Found msg object end: ")); Serial.println(i); objEnd = objStart + i; break; } } if (objEnd == nullptr) { Serial.println(F("Couldn't find msg object end.")); return ERR_HMAC_GEN; } // get nonce (note: the nonce implementation uses "nc" for the key instead of "n" to avoid conflicts with segment names) const char* noncePos = strstr(objStart, "\"nc\":"); if (noncePos == nullptr || noncePos > objEnd) { // note that it is critical to check that the nonce is within the "msg" object and thus authenticated Serial.println(F("No nonce found in msg.")); return ERR_HMAC_GEN; } // Convert the MAC from hex string to byte array size_t len = strlen(mac) / 2; // This will drop the last character if the string has an odd length if (len != SHA256HMAC_SIZE) { Serial.println(F("Received MAC not expected size!")); return ERR_HMAC_GEN; } unsigned char macByteArray[len]; hexStringToByteArray(mac, macByteArray, len); // Calculate the HMAC of the message object if (!hmacVerify((const byte*)objStart, objEnd - objStart + 1, WLED_HMAC_TEST_PSK, macByteArray)) { return ERR_HMAC; } // Nonce verification (Replay attack prevention) { StaticJsonDocument<128> nonceDoc; DeserializationError error = deserializeJson(nonceDoc, noncePos +5); if (error) { Serial.print(F("deser nc failed: ")); Serial.println(error.c_str()); return ERR_HMAC_GEN; } JsonObject nonceObj = nonceDoc.as(); if (nonceObj.isNull()) { Serial.println(F("Failed nonce JSON.")); return ERR_HMAC_GEN; } const char* sessionId = nonceObj["sid"]; if (sessionId == nullptr) { Serial.println(F("No session ID found in nonce.")); return ERR_HMAC_GEN; } uint32_t counter = nonceObj["c"] | 0; if (counter == 0) { Serial.println(F("No counter found in nonce.")); return ERR_HMAC_GEN; } if (counter > UINT32_MAX - 100) { Serial.println(F("Counter too large.")); return ERR_NONCE; } byte sidBytes[SESSION_ID_SIZE] = {}; hexStringToByteArray(sessionId, sidBytes, SESSION_ID_SIZE); uint8_t nonceResult = verifyNonce(sidBytes, counter); return nonceResult ? nonceResult : ERR_NONE; } } bool hmacTest() { Serial.println(F("Testing HMAC...")); unsigned long start = millis(); const char message[] = "Hello, World!"; const char psk[] = "d0c0ffeedeadbeef"; byte mac[SHA256HMAC_SIZE]; hmacSign((const byte*)message, strlen(message), psk, mac); Serial.print(F("Took ")); Serial.print(millis() - start); Serial.println(F("ms to sign message.")); Serial.print(F("MAC: ")); printByteArray(mac, SHA256HMAC_SIZE); start = millis(); bool result = hmacVerify((const byte*)message, strlen(message), psk, mac); Serial.print(F("Took ")); Serial.print(millis() - start); Serial.println(F("ms to verify MAC.")); return result; } void printDuration(unsigned long start) { unsigned long end = millis(); Serial.print(F("Took ")); Serial.print(end - start); Serial.println(F(" ms.")); yield(); } #define HMAC_BENCH_ITERATIONS 100 void hmacBenchmark(const char* message) { Serial.print(F("Starting HMAC benchmark with message length:")); Serial.println(strlen(message)); Serial.println(F("100 iterations signing message.")); unsigned long start = millis(); byte mac[SHA256HMAC_SIZE]; for (int i = 0; i < HMAC_BENCH_ITERATIONS; i++) { hmacSign((const byte*)message, strlen(message), WLED_HMAC_TEST_PSK, mac); } printDuration(start); Serial.println(F("100 iterations verifying message.")); start = millis(); for (int i = 0; i < HMAC_BENCH_ITERATIONS; i++) { hmacVerify((const byte*)message, strlen(message), WLED_HMAC_TEST_PSK, mac); } printDuration(start); }