Functional nonce implemetation

This commit is contained in:
Christian Schwinne 2024-11-17 01:09:36 +01:00
parent fd624aa94b
commit 82073e6bc2
5 changed files with 112 additions and 57 deletions

View File

@ -421,6 +421,8 @@
#define SEG_CAPABILITY_W 0x02 #define SEG_CAPABILITY_W 0x02
#define SEG_CAPABILITY_CCT 0x04 #define SEG_CAPABILITY_CCT 0x04
#define SESSION_ID_SIZE 16
// WLED Error modes // WLED Error modes
#define ERR_NONE 0 // All good :) #define ERR_NONE 0 // All good :)
#define ERR_DENIED 1 // Permission denied #define ERR_DENIED 1 // Permission denied

View File

@ -3,11 +3,13 @@
#define HMAC_KEY_SIZE 32 #define HMAC_KEY_SIZE 32
#define SESSION_ID_SIZE 16
#define MAX_SESSION_IDS 8 #define MAX_SESSION_IDS 8
void getNonce(byte* nonce) { void printByteArray(const byte* arr, size_t len) {
RNG::fill(nonce, SESSION_ID_SIZE); for (size_t i = 0; i < len; i++) {
Serial.print(arr[i], HEX);
}
Serial.println();
} }
struct Nonce { struct Nonce {
@ -27,26 +29,38 @@ void moveToFirst(uint32_t i) {
knownSessions[0] = tmp; knownSessions[0] = tmp;
} }
bool verifyNonce(const byte* sid, uint32_t counter) { uint8_t verifyNonce(const byte* sid, uint32_t counter) {
Serial.println(F("check sid"));
printByteArray(sid, SESSION_ID_SIZE);
uint32_t sum = 0;
for (size_t i = 0; i < SESSION_ID_SIZE; i++) {
sum += sid[i];
}
if (sum == 0) { // all-zero session ID is invalid as it is used for uninitialized entries
return ERR_NONCE;
}
for (int i = 0; i < MAX_SESSION_IDS; i++) { for (int i = 0; i < MAX_SESSION_IDS; i++) {
if (memcmp(knownSessions[i].sessionId, sid, SESSION_ID_SIZE) == 0) { if (memcmp(knownSessions[i].sessionId, sid, SESSION_ID_SIZE) == 0) {
Serial.print(F("Session ID matches e"));
Serial.println(i);
if (counter <= knownSessions[i].counter) { if (counter <= knownSessions[i].counter) {
Serial.println(F("Retransmission detected!")); Serial.println(F("Retransmission detected!"));
return false; return ERR_REPLAY;
} }
knownSessions[i].counter = counter; knownSessions[i].counter = counter;
// nonce good, move this entry to the first position of knownSessions // nonce good, move this entry to the first position of knownSessions
moveToFirst(i); moveToFirst(i);
return true; return ERR_NONE;
} }
} }
Serial.println(F("Unknown session ID!")); Serial.println(F("Unknown session ID!"));
return false; return ERR_NONCE;
} }
void addSession(const char* sid) { void addSessionId(byte* sid) {
byte sid_new[SESSION_ID_SIZE]; RNG::fill(sid, SESSION_ID_SIZE);
RNG::fill(sid_new, SESSION_ID_SIZE);
// first, try to find a completely unused slot // first, try to find a completely unused slot
for (int i = 0; i < MAX_SESSION_IDS; i++) { for (int i = 0; i < MAX_SESSION_IDS; i++) {
@ -71,20 +85,30 @@ void addSession(const char* sid) {
moveToFirst(MAX_SESSION_IDS - 1); moveToFirst(MAX_SESSION_IDS - 1);
} }
void printByteArray(const byte* arr, size_t len) {
for (size_t i = 0; i < len; i++) {
Serial.print(arr[i], HEX);
}
Serial.println();
}
void hexStringToByteArray(const char* hexString, unsigned char* byteArray, size_t byteArraySize) { void hexStringToByteArray(const char* hexString, unsigned char* byteArray, size_t byteArraySize) {
size_t lenStr = strlen(hexString);
if (lenStr < 2 * byteArraySize) byteArraySize = lenStr / 2;
for (size_t i = 0; i < byteArraySize; i++) { for (size_t i = 0; i < byteArraySize; i++) {
char c[3] = {hexString[2 * i], hexString[2 * i + 1], '\0'}; // Get two characters char c[3] = {hexString[2 * i], hexString[2 * i + 1], '\0'}; // Get two characters
byteArray[i] = (unsigned char)strtoul(c, NULL, 16); // Convert to byte byteArray[i] = (unsigned char)strtoul(c, NULL, 16); // Convert to byte
} }
} }
// requires hexString to be at least 2 * byteLen + 1 characters long
char* byteArrayToHexString(char* hexString, const byte* byteArray, size_t byteLen) {
for (size_t i = 0; i < byteLen; ++i) {
// Convert each byte to a two-character hex string
sprintf(&hexString[i * 2], "%02x", byteArray[i]);
}
// Null-terminate the string
hexString[byteLen * 2] = '\0';
return hexString;
}
void hmacSign(const byte* message, size_t msgLen, const char* pskHex, byte* signature) { void hmacSign(const byte* message, size_t msgLen, const char* pskHex, byte* signature) {
size_t len = strlen(pskHex) / 2; // This will drop the last character if the string has an odd length size_t len = strlen(pskHex) / 2; // This will drop the last character if the string has an odd length
if (len > HMAC_KEY_SIZE) { if (len > HMAC_KEY_SIZE) {
@ -182,7 +206,6 @@ uint8_t verifyHmacFromJsonStr(const char* jsonStr, uint32_t maxLen) {
objEnd = objStart + i; objEnd = objStart + i;
break; break;
} }
//i++;
} }
if (objEnd == nullptr) { if (objEnd == nullptr) {
Serial.println(F("Couldn't find msg object end.")); Serial.println(F("Couldn't find msg object end."));
@ -196,39 +219,6 @@ uint8_t verifyHmacFromJsonStr(const char* jsonStr, uint32_t maxLen) {
Serial.println(F("No nonce found in msg.")); Serial.println(F("No nonce found in msg."));
return ERR_HMAC_GEN; return ERR_HMAC_GEN;
} }
// {
// StaticJsonDocument<128> nonceDoc;
// DeserializationError error = deserializeJson(nonceDoc, noncePos +5);
// if (error) {
// Serial.print(F("deser nc failed: "));
// Serial.println(error.c_str());
// return false;
// }
// JsonObject nonceObj = nonceDoc.as<JsonObject>();
// if (nonceObj.isNull()) {
// Serial.println(F("Failed nonce JSON."));
// return false;
// }
// const char* sessionId = nonceObj["sid"];
// if (sessionId == nullptr) {
// Serial.println(F("No session ID found in nonce."));
// return false;
// }
// uint32_t counter = nonceObj["c"] | 0;
// if (counter == 0) {
// Serial.println(F("No counter found in nonce."));
// return false;
// }
// if (counter > UINT32_MAX - 100) {
// Serial.println(F("Counter too large."));
// return false;
// }
// byte sidBytes[SESSION_ID_SIZE];
// hexStringToByteArray(sessionId, sidBytes, SESSION_ID_SIZE);
// if (!verifyNonce(sidBytes, counter)) {
// return false;
// }
// }
// Convert the MAC from hex string to byte array // Convert the MAC from hex string to byte array
size_t len = strlen(mac) / 2; // This will drop the last character if the string has an odd length size_t len = strlen(mac) / 2; // This will drop the last character if the string has an odd length
@ -240,8 +230,44 @@ uint8_t verifyHmacFromJsonStr(const char* jsonStr, uint32_t maxLen) {
hexStringToByteArray(mac, macByteArray, len); hexStringToByteArray(mac, macByteArray, len);
// Calculate the HMAC of the message object // Calculate the HMAC of the message object
bool hmacOk = hmacVerify((const byte*)objStart, objEnd - objStart + 1, WLED_HMAC_TEST_PSK, macByteArray); if (!hmacVerify((const byte*)objStart, objEnd - objStart + 1, WLED_HMAC_TEST_PSK, macByteArray)) {
return hmacOk ? ERR_NONE : ERR_HMAC; return ERR_HMAC;
}
// Nonce verification (Replay attack prevention)
{
StaticJsonDocument<128> nonceDoc;
DeserializationError error = deserializeJson(nonceDoc, noncePos +5);
if (error) {
Serial.print(F("deser nc failed: "));
Serial.println(error.c_str());
return ERR_HMAC_GEN;
}
JsonObject nonceObj = nonceDoc.as<JsonObject>();
if (nonceObj.isNull()) {
Serial.println(F("Failed nonce JSON."));
return ERR_HMAC_GEN;
}
const char* sessionId = nonceObj["sid"];
if (sessionId == nullptr) {
Serial.println(F("No session ID found in nonce."));
return ERR_HMAC_GEN;
}
uint32_t counter = nonceObj["c"] | 0;
if (counter == 0) {
Serial.println(F("No counter found in nonce."));
return ERR_HMAC_GEN;
}
if (counter > UINT32_MAX - 100) {
Serial.println(F("Counter too large."));
return ERR_NONCE;
}
byte sidBytes[SESSION_ID_SIZE] = {};
hexStringToByteArray(sessionId, sidBytes, SESSION_ID_SIZE);
uint8_t nonceResult = verifyNonce(sidBytes, counter);
return nonceResult ? nonceResult : ERR_NONE;
}
} }
bool hmacTest() { bool hmacTest() {

View File

@ -713,6 +713,12 @@ function parseInfo(i) {
} else { } else {
gId("filter2D").classList.remove('hide'); gId("filter2D").classList.remove('hide');
} }
if (useSRA && i.sid) {
if (sraWindow) {
sraWindow.postMessage(JSON.stringify({"wled-ui":"sid","sid":i.sid}), sraOrigin);
}
}
// if (i.noaudio) { // if (i.noaudio) {
// gId("filterVol").classList.add("hide"); // gId("filterVol").classList.add("hide");
// gId("filterFreq").classList.add("hide"); // gId("filterFreq").classList.add("hide");
@ -1438,9 +1444,18 @@ function makeWS() {
} else } else
i = lastinfo; i = lastinfo;
if (json.error) { if (json.error) {
if (json.error == 1) { if (json.error == 42) {
showToast('HMAC verification failed! Please make sure you used the right password!', true); showToast('HMAC verification failed! Please make sure you used the right password!', true);
return; return;
} else if (json.error == 43) {
showToast("This light's control is password protected. Please access it through rc.wled.me", true);
return;
} else if (json.error == 41) {
showToast('Replayed message detected!', true);
return;
} else if (json.error == 40) {
showToast('Invalid nonce', true);
return;
} }
showToast(json.error, true); showToast(json.error, true);
return; return;

View File

@ -96,6 +96,8 @@ uint16_t approximateKelvinFromRGB(uint32_t rgb);
void setRandomColor(byte* rgb); void setRandomColor(byte* rgb);
//crypto.cpp //crypto.cpp
void addSessionId(byte* sid);
char* byteArrayToHexString(char* hexString, const byte* byteArray, size_t byteLen);
void hmacSign(const byte* message, size_t msgLen, const char* pskHex, byte* signature); void hmacSign(const byte* message, size_t msgLen, const char* pskHex, byte* signature);
bool hmacVerify(const byte* message, size_t msgLen, const char* pskHex, const byte* signature); bool hmacVerify(const byte* message, size_t msgLen, const char* pskHex, const byte* signature);
uint8_t verifyHmacFromJsonStr(const char* jsonStr, uint32_t maxLen); uint8_t verifyHmacFromJsonStr(const char* jsonStr, uint32_t maxLen);
@ -460,7 +462,7 @@ void serveSettingsJS(AsyncWebServerRequest* request);
//ws.cpp //ws.cpp
void handleWs(); void handleWs();
void wsEvent(AsyncWebSocket * server, AsyncWebSocketClient * client, AwsEventType type, void * arg, uint8_t *data, size_t len); void wsEvent(AsyncWebSocket * server, AsyncWebSocketClient * client, AwsEventType type, void * arg, uint8_t *data, size_t len);
void sendDataWs(AsyncWebSocketClient * client = nullptr); void sendDataWs(AsyncWebSocketClient * client = nullptr, bool initialConnection = false);
//xml.cpp //xml.cpp
void XML_response(Print& dest); void XML_response(Print& dest);

View File

@ -32,7 +32,7 @@ void wsEvent(AsyncWebSocket * server, AsyncWebSocketClient * client, AwsEventTyp
if(type == WS_EVT_CONNECT){ if(type == WS_EVT_CONNECT){
//client connected //client connected
DEBUG_PRINTLN(F("WS client connected.")); DEBUG_PRINTLN(F("WS client connected."));
sendDataWs(client); sendDataWs(client, true);
} else if(type == WS_EVT_DISCONNECT){ } else if(type == WS_EVT_DISCONNECT){
//client disconnected //client disconnected
if (client->id() == wsLiveClientId) wsLiveClientId = 0; if (client->id() == wsLiveClientId) wsLiveClientId = 0;
@ -138,7 +138,7 @@ void wsEvent(AsyncWebSocket * server, AsyncWebSocketClient * client, AwsEventTyp
} }
} }
void sendDataWs(AsyncWebSocketClient * client) void sendDataWs(AsyncWebSocketClient * client, bool initialConnection)
{ {
if (!ws.count()) return; if (!ws.count()) return;
@ -157,6 +157,16 @@ void sendDataWs(AsyncWebSocketClient * client)
JsonObject info = pDoc->createNestedObject("info"); JsonObject info = pDoc->createNestedObject("info");
serializeInfo(info); serializeInfo(info);
if (initialConnection) {
char sid[SESSION_ID_SIZE*2+1] = {};
byte sidBytes[SESSION_ID_SIZE] = {};
addSessionId(sidBytes);
byteArrayToHexString(sid, sidBytes, SESSION_ID_SIZE);
Serial.print(F("New session ID: "));
Serial.println(sid);
info["sid"] = sid;
}
size_t len = measureJson(*pDoc); size_t len = measureJson(*pDoc);
DEBUG_PRINTF_P(PSTR("JSON buffer size: %u for WS request (%u).\n"), pDoc->memoryUsage(), len); DEBUG_PRINTF_P(PSTR("JSON buffer size: %u for WS request (%u).\n"), pDoc->memoryUsage(), len);