MINIFICPP-1026 - Added base64 encoder-decoder to StringUtils
Signed-off-by: Arpad Boda <aboda@apache.org>
This closes #690
diff --git a/extensions/expression-language/Expression.cpp b/extensions/expression-language/Expression.cpp
index 873e522..a25e1d3 100644
--- a/extensions/expression-language/Expression.cpp
+++ b/extensions/expression-language/Expression.cpp
@@ -48,7 +48,7 @@
#include <sys/types.h>
#endif
-#include "utils/base64.h"
+#include "utils/StringUtils.h"
#include "Driver.h"
#ifdef EXPRESSION_LANGUAGE_USE_DATE
@@ -709,30 +709,11 @@
}
Value expr_base64Encode(const std::vector<Value> &args) {
- auto arg_0 = args[0].asString();
- char *b64_out = nullptr;
- auto b64_len = Curl_base64_encode(arg_0.c_str(), arg_0.length(), &b64_out);
- if (b64_out) {
- std::string result(b64_out, b64_len);
- free(b64_out);
- return Value(result);
- } else {
- throw std::runtime_error("Failed to encode base64");
- }
+ return Value(utils::StringUtils::to_base64(args[0].asString()));
}
Value expr_base64Decode(const std::vector<Value> &args) {
- auto arg_0 = args[0].asString();
- unsigned char *decode_out = nullptr;
- // size_t Curl_base64_decode(const char *src, unsigned char **outptr)
- auto out_len = Curl_base64_decode(arg_0.c_str(), &decode_out);
- if (decode_out) {
- std::string result(reinterpret_cast<char *>(decode_out), out_len);
- free(decode_out);
- return Value(result);
- } else {
- throw std::runtime_error("Failed to encode base64");
- }
+ return Value(utils::StringUtils::from_base64(args[0].asString()));
}
#ifdef EXPRESSION_LANGUAGE_USE_REGEX
diff --git a/extensions/sftp/client/SFTPClient.cpp b/extensions/sftp/client/SFTPClient.cpp
index d2a2716..8f66ccc 100644
--- a/extensions/sftp/client/SFTPClient.cpp
+++ b/extensions/sftp/client/SFTPClient.cpp
@@ -26,7 +26,6 @@
#include "utils/StringUtils.h"
#include "utils/ScopeGuard.h"
#include "utils/StringUtils.h"
-#include "utils/base64.h"
namespace org {
namespace apache {
@@ -358,11 +357,9 @@
logger_->log_warn("Host %s not found in the host key file", hostname_.c_str());
break;
case LIBSSH2_KNOWNHOST_CHECK_MISMATCH: {
- char* b64_out = nullptr;
- auto b64_len = Curl_base64_encode(hostkey, hostkey_len, &b64_out);
+ auto hostkey_b64 = utils::StringUtils::to_base64(reinterpret_cast<const uint8_t*>(hostkey), hostkey_len);
logger_->log_warn("Host key mismatch for %s, expected: %s, actual: %s", hostname_.c_str(),
- known_host == nullptr ? "" : known_host->key,
- b64_out == nullptr ? "" : std::string(b64_out, b64_len).c_str());
+ known_host == nullptr ? "" : known_host->key, hostkey_b64.c_str());
break;
}
case LIBSSH2_KNOWNHOST_CHECK_MATCH:
diff --git a/libminifi/include/utils/StringUtils.h b/libminifi/include/utils/StringUtils.h
index b45f73d..fc44d93 100644
--- a/libminifi/include/utils/StringUtils.h
+++ b/libminifi/include/utils/StringUtils.h
@@ -17,6 +17,7 @@
#ifndef LIBMINIFI_INCLUDE_IO_STRINGUTILS_H_
#define LIBMINIFI_INCLUDE_IO_STRINGUTILS_H_
#include <iostream>
+#include <cstring>
#include <functional>
#ifdef WIN32
#include <cwctype>
@@ -235,34 +236,7 @@
* @param hex_length the length of hex
* @return true on success
*/
- inline static bool from_hex(uint8_t* data, size_t* data_length, const char* hex, size_t hex_length) {
- if (*data_length < hex_length / 2) {
- return false;
- }
- uint8_t n1;
- bool found_first_nibble = false;
- *data_length = 0;
- for (size_t i = 0; i < hex_length; i++) {
- const uint8_t byte = static_cast<uint8_t>(hex[i]);
- if (byte > 127) {
- continue;
- }
- uint8_t n = hex_lut[byte];
- if (n != SKIP) {
- if (found_first_nibble) {
- data[(*data_length)++] = n1 << 4 | n;
- found_first_nibble = false;
- } else {
- n1 = n;
- found_first_nibble = true;
- }
- }
- }
- if (found_first_nibble) {
- return false;
- }
- return true;
- }
+ static bool from_hex(uint8_t* data, size_t* data_length, const char* hex, size_t hex_length);
/**
* Hexdecodes a string
@@ -270,15 +244,7 @@
* @param hex_length the length of hex
* @return the vector containing the hexdecoded bytes
*/
- inline static std::vector<uint8_t> from_hex(const char* hex, size_t hex_length) {
- std::vector<uint8_t> decoded(hex_length / 2);
- size_t data_length = decoded.size();
- if (!from_hex(decoded.data(), &data_length, hex, hex_length)) {
- throw std::runtime_error("Hexencoded string is malformatted");
- }
- decoded.resize(data_length);
- return decoded;
- }
+ static std::vector<uint8_t> from_hex(const char* hex, size_t hex_length);
/**
* Hexdecodes a string
@@ -292,33 +258,22 @@
/**
* Hexencodes bytes and writes the result to hex
- * @param hex the output buffer where the hexencoded string will be written. Must be at least length * 2 bytes long.
+ * @param hex the output buffer where the hexencoded bytes will be written. Must be at least length * 2 bytes long.
* @param data the bytes to be hexencoded
- * @param length the length of data. Must not be larger than std::numeric_limits<size_t>::max()
+ * @param length the length of data. Must not be larger than std::numeric_limits<size_t>::max() / 2
* @param uppercase whether the hexencoded string should be upper case
+ * @return the size of hexencoded bytes
*/
- inline static void to_hex(char* hex, const uint8_t* data, size_t length, bool uppercase) {
- for (size_t i = 0; i < length; i++) {
- hex[i * 2] = nibble_to_hex(data[i] >> 4, uppercase);
- hex[i * 2 + 1] = nibble_to_hex(data[i] & 0xf, uppercase);
- }
- }
+ static size_t to_hex(char* hex, const uint8_t* data, size_t length, bool uppercase);
/**
* Creates a hexencoded string from data
* @param data the bytes to be hexencoded
- * @param length the length of the data
+ * @param length the length of data. Must not be larger than std::numeric_limits<size_t>::max() / 2 - 1
* @param uppercase whether the hexencoded string should be upper case
* @return the hexencoded string
*/
- inline static std::string to_hex(const uint8_t* data, size_t length, bool uppercase = false) {
- if (length > (std::numeric_limits<size_t>::max)() / 2) {
- throw std::length_error("Data is too large to be hexencoded");
- }
- std::vector<char> buf(length * 2);
- to_hex(buf.data(), data, length, uppercase);
- return std::string(buf.data(), buf.size());
- }
+ static std::string to_hex(const uint8_t* data, size_t length, bool uppercase = false);
/**
* Hexencodes a string
@@ -330,6 +285,87 @@
return to_hex(reinterpret_cast<const uint8_t*>(str.data()), str.length(), uppercase);
}
+ /**
+ * Hexencodes a vector of bytes
+ * @param data the vector of bytes to be hexencoded
+ * @param uppercase whether the hexencoded string should be upper case
+ * @return the hexencoded string
+ */
+ inline static std::string to_hex(const std::vector<uint8_t>& data, bool uppercase = false) {
+ return to_hex(data.data(), data.size(), uppercase);
+ }
+
+ /**
+ * Decodes the Base64 encoded string into data
+ * @param data the output buffer where the decoded bytes will be written. Must be at least (base64_length / 4 + 1) * 3 bytes long.
+ * @param data_length pointer to the length of the data buffer. It will be filled with the length of the decoded bytes.
+ * @param base64 the Base64 encoded string
+ * @param base64_length the length of base64
+ * @return true on success
+ */
+ static bool from_base64(uint8_t* data, size_t* data_length, const char* base64, size_t base64_length);
+
+ /**
+ * Base64 decodes a string
+ * @param base64 the Base64 encoded string
+ * @param base64_length the length of base64
+ * @return the vector containing the decoded bytes
+ */
+ static std::vector<uint8_t> from_base64(const char* base64, size_t base64_length);
+
+ /**
+ * Base64 decodes a string
+ * @param base64 the Base64 encoded string
+ * @return the decoded string
+ */
+ inline static std::string from_base64(const std::string& base64) {
+ auto data = from_base64(base64.data(), base64.length());
+ return std::string(reinterpret_cast<char*>(data.data()), data.size());
+ }
+
+ /**
+ * Base64 encodes bytes and writes the result to base64
+ * @param base64 the output buffer where the Base64 encoded bytes will be written. Must be at least (length / 3 + 1) * 4 bytes long.
+ * @param data the bytes to be Base64 encoded
+ * @param length the length of data. Must not be larger than std::numeric_limits<size_t>::max() * 3 / 4 - 3
+ * @param url if true, the URL-safe Base64 encoding will be used
+ * @param padded if true, padding is added to the Base64 encoded string
+ * @return the size of Base64 encoded bytes
+ */
+ static size_t to_base64(char* base64, const uint8_t* data, size_t length, bool url, bool padded);
+
+ /**
+ * Creates a Base64 encoded string from data
+ * @param data the bytes to be Base64 encoded
+ * @param length the length of the data
+ * @param url if true, the URL-safe Base64 encoding will be used
+ * @param padded if true, padding is added to the Base64 encoded string
+ * @return the Base64 encoded string
+ */
+ static std::string to_base64(const uint8_t* data, size_t length, bool url = false, bool padded = true);
+
+ /**
+ * Base64 encodes a string
+ * @param str the string to be Base64 encoded
+ * @param url if true, the URL-safe Base64 encoding will be used
+ * @param padded if true, padding is added to the Base64 encoded string
+ * @return the Base64 encoded string
+ */
+ inline static std::string to_base64(const std::string& str, bool url = false, bool padded = true) {
+ return to_base64(reinterpret_cast<const uint8_t*>(str.data()), str.length(), url, padded);
+ }
+
+ /**
+ * Base64 encodes a vector of bytes
+ * @param str the vector of bytes to be Base64 encoded
+ * @param url if true, the URL-safe Base64 encoding will be used
+ * @param padded if true, padding is added to the Base64 encoded string
+ * @return the Base64 encoded string
+ */
+ inline static std::string to_base64(const std::vector<uint8_t>& str, bool url = false, bool padded = true) {
+ return to_base64(str.data(), str.size(), url, padded);
+ }
+
static std::string replaceMap(std::string source_string, const std::map<std::string, std::string> &replace_map);
private:
@@ -341,7 +377,15 @@
}
}
- static constexpr uint8_t SKIP = 255;
+ inline static void base64_digits_to_bytes(const uint8_t digits[4], uint8_t* bytes) {  // packs four decoded Base64 digits into bytes[0..2]
+ bytes[0] = digits[0] << 2 | digits[1] >> 4;  // digit 0 shifted up by two, joined with the bits of digit 1 above its low four
+ bytes[1] = (digits[1] & 0x0f) << 4 | digits[2] >> 2;  // low four bits of digit 1, joined with the bits of digit 2 above its low two
+ bytes[2] = (digits[2] & 0x03) << 6 | digits[3];  // low two bits of digit 2, joined with digit 3
+ }
+
+ static constexpr uint8_t SKIP = 0xff;
+ static constexpr uint8_t ILGL = 0xfe;
+ static constexpr uint8_t PDNG = 0xfd;
static constexpr uint8_t hex_lut[128] =
{SKIP, SKIP, SKIP, SKIP, SKIP, SKIP, SKIP, SKIP,
SKIP, SKIP, SKIP, SKIP, SKIP, SKIP, SKIP, SKIP,
@@ -360,6 +404,26 @@
SKIP, SKIP, SKIP, SKIP, SKIP, SKIP, SKIP, SKIP,
SKIP, SKIP, SKIP, SKIP, SKIP, SKIP, SKIP, SKIP};
+ static constexpr const char base64_enc_lut[] = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";
+ static constexpr const char base64_url_enc_lut[] = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-_";
+ static constexpr uint8_t base64_dec_lut[128] =
+ {ILGL, ILGL, ILGL, ILGL, ILGL, ILGL, ILGL, ILGL,
+ ILGL, ILGL, SKIP, ILGL, ILGL, SKIP, ILGL, ILGL,
+ ILGL, ILGL, ILGL, ILGL, ILGL, ILGL, ILGL, ILGL,
+ ILGL, ILGL, ILGL, ILGL, ILGL, ILGL, ILGL, ILGL,
+ ILGL, ILGL, ILGL, ILGL, ILGL, ILGL, ILGL, ILGL,
+ ILGL, ILGL, ILGL, 0x3e, ILGL, 0x3e, ILGL, 0x3f,
+ 0x34, 0x35, 0x36, 0x37, 0x38, 0x39, 0x3a, 0x3b,
+ 0x3c, 0x3d, ILGL, ILGL, ILGL, PDNG, ILGL, ILGL,
+ ILGL, 0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06,
+ 0x07, 0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e,
+ 0x0f, 0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16,
+ 0x17, 0x18, 0x19, ILGL, ILGL, ILGL, ILGL, 0x3f,
+ ILGL, 0x1a, 0x1b, 0x1c, 0x1d, 0x1e, 0x1f, 0x20,
+ 0x21, 0x22, 0x23, 0x24, 0x25, 0x26, 0x27, 0x28,
+ 0x29, 0x2a, 0x2b, 0x2c, 0x2d, 0x2e, 0x2f, 0x30,
+ 0x31, 0x32, 0x33, ILGL, ILGL, ILGL, ILGL, ILGL};
+
};
} /* namespace utils */
diff --git a/libminifi/include/utils/base64.h b/libminifi/include/utils/base64.h
deleted file mode 100644
index d99a3fe..0000000
--- a/libminifi/include/utils/base64.h
+++ /dev/null
@@ -1,188 +0,0 @@
-/***************************************************************************
- * _ _ ____ _
- * Project ___| | | | _ \| |
- * / __| | | | |_) | |
- * | (__| |_| | _ <| |___
- * \___|\___/|_| \_\_____|
- *
- * Copyright (C) 1998 - 2007, Daniel Stenberg, <daniel@haxx.se>, et al.
- *
- * This software is licensed as described in the file COPYING, which
- * you should have received as part of this distribution. The terms
- * are also available at http://curl.haxx.se/docs/copyright.html.
- *
- * You may opt to use, copy, modify, merge, publish, distribute and/or sell
- * copies of the Software, and permit persons to whom the Software is
- * furnished to do so, under the terms of the COPYING file.
- *
- * This software is distributed on an "AS IS" basis, WITHOUT WARRANTY OF ANY
- * KIND, either express or implied.
- *
- * $Id$
- ***************************************************************************/
-
-#ifndef NIFI_MINIFI_CPP_BASE64_H
-#define NIFI_MINIFI_CPP_BASE64_H
-
-/* ---- Base64 Encoding/Decoding Table --- */
-static const char table64[]=
- "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";
-
-static void decodeQuantum(unsigned char *dest, const char *src)
-{
- unsigned int x = 0;
- int i;
- char *found;
-
- for(i = 0; i < 4; i++) {
- if((found = strchr((char *)table64, src[i])))
- x = (x << 6) + (unsigned int)(found - table64);
- else if(src[i] == '=')
- x = (x << 6);
- }
-
- dest[2] = (unsigned char)(x & 255);
- x >>= 8;
- dest[1] = (unsigned char)(x & 255);
- x >>= 8;
- dest[0] = (unsigned char)(x & 255);
-}
-
-
-/*
- * Curl_base64_decode()
- *
- * Given a base64 string at src, decode it and return an allocated memory in
- * the *outptr. Returns the length of the decoded data.
- */
-static size_t Curl_base64_decode(const char *src, unsigned char **outptr)
-{
- int length = 0;
- int equalsTerm = 0;
- int i;
- int numQuantums;
- unsigned char lastQuantum[3];
- size_t rawlen=0;
- unsigned char *newstr;
-
- *outptr = NULL;
-
- while((src[length] != '=') && src[length])
- length++;
- /* A maximum of two = padding characters is allowed */
- if(src[length] == '=') {
- equalsTerm++;
- if(src[length+equalsTerm] == '=')
- equalsTerm++;
- }
- numQuantums = (length + equalsTerm) / 4;
-
- /* Don't allocate a buffer if the decoded length is 0 */
- if (numQuantums <= 0)
- return 0;
-
- rawlen = (numQuantums * 3) - equalsTerm;
-
- /* The buffer must be large enough to make room for the last quantum
- (which may be partially thrown out) and the zero terminator. */
- newstr = static_cast<unsigned char *>(malloc(rawlen+4));
- if(!newstr)
- return 0;
-
- *outptr = newstr;
-
- /* Decode all but the last quantum (which may not decode to a
- multiple of 3 bytes) */
- for(i = 0; i < numQuantums - 1; i++) {
- decodeQuantum((unsigned char *)newstr, src);
- newstr += 3; src += 4;
- }
-
- /* This final decode may actually read slightly past the end of the buffer
- if the input string is missing pad bytes. This will almost always be
- harmless. */
- decodeQuantum(lastQuantum, src);
- for(i = 0; i < 3 - equalsTerm; i++)
- newstr[i] = lastQuantum[i];
-
- newstr[i] = 0; /* zero terminate */
- return rawlen;
-}
-
-/*
- * Curl_base64_encode()
- *
- * Returns the length of the newly created base64 string. The third argument
- * is a pointer to an allocated area holding the base64 data. If something
- * went wrong, -1 is returned.
- *
- */
-static size_t Curl_base64_encode(const char *inp, size_t insize, char **outptr)
-{
- unsigned char ibuf[3];
- unsigned char obuf[4];
- int i;
- int inputparts;
- char *output;
- char *base64data;
-
- char *indata = (char *)inp;
-
- *outptr = NULL; /* set to NULL in case of failure before we reach the end */
-
- if(0 == insize)
- insize = strlen(indata);
-
- base64data = output = (char*)malloc(insize*4/3+4);
- if(NULL == output)
- return 0;
-
- while(insize > 0) {
- for (i = inputparts = 0; i < 3; i++) {
- if(insize > 0) {
- inputparts++;
- ibuf[i] = *indata;
- indata++;
- insize--;
- }
- else
- ibuf[i] = 0;
- }
-
- obuf[0] = (unsigned char) ((ibuf[0] & 0xFC) >> 2);
- obuf[1] = (unsigned char) (((ibuf[0] & 0x03) << 4) | \
- ((ibuf[1] & 0xF0) >> 4));
- obuf[2] = (unsigned char) (((ibuf[1] & 0x0F) << 2) | \
- ((ibuf[2] & 0xC0) >> 6));
- obuf[3] = (unsigned char) (ibuf[2] & 0x3F);
-
- switch(inputparts) {
- case 1: /* only one byte read */
- snprintf(output, 5, "%c%c==",
- table64[obuf[0]],
- table64[obuf[1]]);
- break;
- case 2: /* two bytes read */
- snprintf(output, 5, "%c%c%c=",
- table64[obuf[0]],
- table64[obuf[1]],
- table64[obuf[2]]);
- break;
- default:
- snprintf(output, 5, "%c%c%c%c",
- table64[obuf[0]],
- table64[obuf[1]],
- table64[obuf[2]],
- table64[obuf[3]] );
- break;
- }
- output += 4;
- }
- *output=0;
- *outptr = base64data; /* make it return the actual data memory */
-
- return strlen(base64data); /* return the length of the new data */
-}
-/* ---- End of Base64 Encoding ---- */
-
-#endif //NIFI_MINIFI_CPP_BASE64_H
diff --git a/libminifi/src/utils/StringUtils.cpp b/libminifi/src/utils/StringUtils.cpp
index af33fa0..ce96eec 100644
--- a/libminifi/src/utils/StringUtils.cpp
+++ b/libminifi/src/utils/StringUtils.cpp
@@ -164,8 +164,185 @@
return result_string;
}
+bool StringUtils::from_hex(uint8_t* data, size_t* data_length, const char* hex, size_t hex_length) {  // hexdecodes hex into data; returns false on failure
+ if (*data_length < hex_length / 2) {  // the output buffer cannot hold the largest possible result
+ return false;
+ }
+ uint8_t n1;  // first nibble of the byte being assembled; only read once found_first_nibble is true
+ bool found_first_nibble = false;
+ *data_length = 0;  // from here on it counts the bytes written to data
+ for (size_t i = 0; i < hex_length; i++) {
+ const uint8_t byte = static_cast<uint8_t>(hex[i]);
+ if (byte > 127) {  // hex_lut has only 128 entries; characters above that are ignored
+ continue;
+ }
+ uint8_t n = hex_lut[byte];
+ if (n != SKIP) {  // characters that map to SKIP are silently ignored
+ if (found_first_nibble) {
+ data[(*data_length)++] = n1 << 4 | n;  // second nibble completes the byte
+ found_first_nibble = false;
+ } else {
+ n1 = n;
+ found_first_nibble = true;
+ }
+ }
+ }
+ if (found_first_nibble) {  // a leftover first nibble means an odd number of hex digits
+ return false;
+ }
+ return true;
+}
+
+std::vector<uint8_t> StringUtils::from_hex(const char* hex, size_t hex_length) {  // hexdecodes hex into a newly allocated vector
+ std::vector<uint8_t> decoded(hex_length / 2);  // largest possible result: one byte per two input characters
+ size_t data_length = decoded.size();
+ if (!from_hex(decoded.data(), &data_length, hex, hex_length)) {
+ throw std::invalid_argument("Hexencoded string is malformatted");
+ }
+ decoded.resize(data_length);  // shrink to the number of bytes actually decoded
+ return decoded;
+}
+
+size_t StringUtils::to_hex(char* hex, const uint8_t* data, size_t length, bool uppercase) {  // writes two hex characters per input byte into hex
+ if (length > (std::numeric_limits<size_t>::max)() / 2) {  // length * 2 below would not fit in size_t
+ throw std::length_error("Data is too large to be hexencoded");
+ }
+ for (size_t i = 0; i < length; i++) {
+ hex[i * 2] = nibble_to_hex(data[i] >> 4, uppercase);  // high nibble first
+ hex[i * 2 + 1] = nibble_to_hex(data[i] & 0xf, uppercase);  // then the low nibble
+ }
+ return length * 2;  // number of characters written; no terminator is appended
+}
+
+std::string StringUtils::to_hex(const uint8_t* data, size_t length, bool uppercase /*= false*/) {  // hexencodes data into a new string
+ if (length > ((std::numeric_limits<size_t>::max)() / 2 - 1)) {  // reject lengths whose doubled size could not be represented
+ throw std::length_error("Data is too large to be hexencoded");
+ }
+ std::vector<char> buf(length * 2);  // exactly two output characters per input byte
+ const size_t hex_length = to_hex(buf.data(), data, length, uppercase);
+ return std::string(buf.data(), hex_length);
+}
+
+/**
+ * Decodes Base64 text into a caller-supplied buffer.
+ *
+ * Characters mapped to SKIP in base64_dec_lut (CR and LF) are ignored. Both the
+ * standard ('+', '/') and the URL-safe ('-', '_') alphabets are accepted.
+ * Padding is optional, but when present it must exactly complete a final
+ * group of two or three digits and nothing but skipped characters may follow it.
+ *
+ * @param data output buffer, at least (base64_length / 4 + 1) * 3 bytes long
+ * @param data_length in: capacity of data; out: number of decoded bytes (only set on success)
+ * @param base64 the Base64 encoded input
+ * @param base64_length the length of base64
+ * @return true on success, false if the buffer is too small or the input is malformed
+ */
+bool StringUtils::from_base64(uint8_t* data, size_t* data_length, const char* base64, size_t base64_length) {
+  // Full groups are written straight into data, so demand the worst-case size up front.
+  if (*data_length < (base64_length / 4 + 1) * 3) {
+    return false;
+  }
+
+  uint8_t digits[4];
+  size_t digit_counter = 0U;
+  size_t decoded_size = 0U;
+  size_t padding_counter = 0U;
+  for (size_t i = 0U; i < base64_length; i++) {
+    const uint8_t byte = static_cast<uint8_t>(base64[i]);
+    if (byte > 127) {
+      // base64_dec_lut only has 128 entries
+      return false;
+    }
+
+    const uint8_t decoded = base64_dec_lut[byte];
+    switch (decoded) {
+      case SKIP:
+        continue;
+      case ILGL:
+        return false;
+      case PDNG:
+        padding_counter++;
+        continue;
+      default:
+        if (padding_counter > 0U) {
+          // a digit after padding has started
+          return false;
+        }
+        digits[digit_counter++] = decoded;
+        if (digit_counter == 4U) {
+          base64_digits_to_bytes(digits, data + decoded_size);
+          decoded_size += 3U;
+          digit_counter = 0U;
+        }
+    }
+  }
+
+  // Padding is only meaningful after a partial group. Without the digit_counter == 0 test
+  // a run of exactly four '=' would satisfy 4 - digit_counter and be accepted.
+  if (padding_counter > 0U && (digit_counter == 0U || padding_counter != 4U - digit_counter)) {
+    return false;
+  }
+
+  switch (digit_counter) {
+    case 0:
+      break;
+    case 1:
+      // a single digit carries only six bits, not enough for a byte
+      return false;
+    case 2:
+      digits[2] = 0x00;
+      // fallthrough
+    case 3: {
+      digits[3] = 0x00;
+
+      // Decode into a scratch area and copy only the bytes that carry real data.
+      uint8_t bytes_temp[3];
+      base64_digits_to_bytes(digits, bytes_temp);
+      const size_t num_bytes = digit_counter - 1;
+      memcpy(data + decoded_size, bytes_temp, num_bytes);
+      decoded_size += num_bytes;
+      break;
+    }
+    default:
+      return false;
+  }
+
+  *data_length = decoded_size;
+  return true;
+}
+
+std::vector<uint8_t> StringUtils::from_base64(const char* base64, size_t base64_length) {  // Base64 decodes into a newly allocated vector
+ std::vector<uint8_t> decoded((base64_length / 4 + 1) * 3);  // the minimum buffer size the buffer-based overload accepts
+ size_t data_length = decoded.size();
+ if (!from_base64(decoded.data(), &data_length, base64, base64_length)) {
+ throw std::invalid_argument("Base64 encoded string is malformatted");
+ }
+ decoded.resize(data_length);  // shrink to the number of bytes actually decoded
+ return decoded;
+}
+
+/**
+ * Base64 encodes data into a caller-supplied buffer. No terminator is appended.
+ *
+ * @param base64 output buffer, at least (length / 3 + 1) * 4 bytes long
+ * @param data the bytes to be encoded
+ * @param length the length of data
+ * @param url if true, the URL-safe alphabet ('-', '_') is used instead of ('+', '/')
+ * @param padded if true, the final group is filled up with '=' characters
+ * @return the number of characters written to base64
+ * @throws std::length_error if length is too large for the encoded size to fit in size_t
+ */
+size_t StringUtils::to_base64(char* base64, const uint8_t* data, size_t length, bool url, bool padded) {
+  // Divide before multiplying: max() * 3 would wrap around in size_t and yield a limit
+  // of roughly max() / 4 instead of the intended three quarters of max().
+  if (length > (std::numeric_limits<size_t>::max)() / 4 * 3 - 3) {
+    throw std::length_error("Data is too large to be base64 encoded");
+  }
+
+  const char* enc_lut = url ? base64_url_enc_lut : base64_enc_lut;
+  size_t base64_length = 0U;
+  uint8_t bytes[3];
+  for (size_t i = 0U; i < length; i += 3U) {
+    // The last group may hold only one or two bytes; missing ones are treated as zero.
+    const bool b1_present = i + 1 < length;
+    const bool b2_present = i + 2 < length;
+    bytes[0] = data[i];
+    bytes[1] = b1_present ? data[i + 1] : 0x00;
+    bytes[2] = b2_present ? data[i + 2] : 0x00;
+
+    // The first two digits are always emitted, the last two only if their source byte exists.
+    base64[base64_length++] = enc_lut[(bytes[0] & 0xfc) >> 2];
+    base64[base64_length++] = enc_lut[(bytes[0] & 0x03) << 4 | (bytes[1] & 0xf0) >> 4];
+    if (b1_present) {
+      base64[base64_length++] = enc_lut[(bytes[1] & 0x0f) << 2 | (bytes[2] & 0xc0) >> 6];
+    } else if (padded) {
+      base64[base64_length++] = '=';
+    }
+    if (b2_present) {
+      base64[base64_length++] = enc_lut[bytes[2] & 0x3f];
+    } else if (padded) {
+      base64[base64_length++] = '=';
+    }
+  }
+
+  return base64_length;
+}
+
+std::string StringUtils::to_base64(const uint8_t* data, size_t length, bool url /*= false*/, bool padded /*= true*/) {  // Base64 encodes data into a new string
+ std::vector<char> buf((length / 3 + 1) * 4);  // four output characters per three input bytes, plus one spare group
+ size_t base64_length = to_base64(buf.data(), data, length, url, padded);
+ return std::string(buf.data(), base64_length);  // only the characters actually written are copied
+}
+
constexpr uint8_t StringUtils::SKIP;
constexpr uint8_t StringUtils::hex_lut[128];
+constexpr const char StringUtils::base64_enc_lut[];
+constexpr const char StringUtils::base64_url_enc_lut[];
+constexpr uint8_t StringUtils::base64_dec_lut[128];
} /* namespace utils */
} /* namespace minifi */
diff --git a/libminifi/test/unit/StringUtilsTests.cpp b/libminifi/test/unit/StringUtilsTests.cpp
index d5149c3..30fac99 100644
--- a/libminifi/test/unit/StringUtilsTests.cpp
+++ b/libminifi/test/unit/StringUtilsTests.cpp
@@ -127,16 +127,16 @@
REQUIRE("" == StringUtils::to_hex(""));
REQUIRE("6f" == StringUtils::to_hex("o"));
REQUIRE("666f6f626172" == StringUtils::to_hex("foobar"));
- REQUIRE("000102030405060708090a0b0c0d0e0f" == StringUtils::to_hex({0x00, 0x01, 0x02, 0x03,
- 0x04, 0x05, 0x06, 0x07,
- 0x08, 0x09, 0x0a, 0x0b,
- 0x0c, 0x0d, 0x0e, 0x0f}));
+ REQUIRE("000102030405060708090a0b0c0d0e0f" == StringUtils::to_hex(std::vector<uint8_t>{0x00, 0x01, 0x02, 0x03,
+ 0x04, 0x05, 0x06, 0x07,
+ 0x08, 0x09, 0x0a, 0x0b,
+ 0x0c, 0x0d, 0x0e, 0x0f}));
REQUIRE("6F" == StringUtils::to_hex("o", true /*uppercase*/));
REQUIRE("666F6F626172" == StringUtils::to_hex("foobar", true /*uppercase*/));
- REQUIRE("000102030405060708090A0B0C0D0E0F" == StringUtils::to_hex({0x00, 0x01, 0x02, 0x03,
- 0x04, 0x05, 0x06, 0x07,
- 0x08, 0x09, 0x0a, 0x0b,
- 0x0c, 0x0d, 0x0e, 0x0f}, true /*uppercase*/));
+ REQUIRE("000102030405060708090A0B0C0D0E0F" == StringUtils::to_hex(std::vector<uint8_t>{0x00, 0x01, 0x02, 0x03,
+ 0x04, 0x05, 0x06, 0x07,
+ 0x08, 0x09, 0x0a, 0x0b,
+ 0x0c, 0x0d, 0x0e, 0x0f}, true /*uppercase*/));
}
TEST_CASE("TestStringUtils::testHexDecode", "[test hex decode]") {
@@ -155,28 +155,133 @@
0x04, 0x05, 0x06, 0x07,
0x08, 0x09, 0x0a, 0x0b,
0x0c, 0x0d, 0x0e, 0x0f}) == StringUtils::from_hex("000102030405060708090A0B0C0D0E0F"));
- try {
- StringUtils::from_hex("666f6f62617");
- abort();
- } catch (std::exception& e) {
- REQUIRE(std::string("Hexencoded string is malformatted") == e.what());
- }
- try {
- StringUtils::from_hex("666f6f6261 7");
- abort();
- } catch (std::exception& e) {
- REQUIRE(std::string("Hexencoded string is malformatted") == e.what());
- }
+
+ REQUIRE_THROWS_WITH(StringUtils::from_hex("666f6f62617"), "Hexencoded string is malformatted");
+ REQUIRE_THROWS_WITH(StringUtils::from_hex("666f6f6261 7"), "Hexencoded string is malformatted");
}
TEST_CASE("TestStringUtils::testHexEncodeDecode", "[test hex encode decode]") {
std::mt19937 gen(std::random_device { }());
- const bool uppercase = gen() % 2;
- const size_t length = gen() % 1024;
- std::vector<uint8_t> data(length);
- std::generate_n(data.begin(), data.size(), [&]() -> uint8_t {
- return gen() % 256;
- });
- auto hex = utils::StringUtils::to_hex(data.data(), data.size(), uppercase);
- REQUIRE(data == utils::StringUtils::from_hex(hex.data(), hex.size()));
+ for (size_t i = 0U; i < 1024U; i++) {
+ const bool uppercase = gen() % 2;
+ const size_t length = gen() % 1024;
+ std::vector<uint8_t> data(length);
+ std::generate_n(data.begin(), data.size(), [&]() -> uint8_t {
+ return gen() % 256;
+ });
+ auto hex = utils::StringUtils::to_hex(data.data(), data.size(), uppercase);
+ REQUIRE(data == utils::StringUtils::from_hex(hex.data(), hex.size()));
+ }
+}
+
+TEST_CASE("TestStringUtils::testBase64Encode", "[test base64 encode]") {  // fixed-vector checks for StringUtils::to_base64
+ REQUIRE("" == StringUtils::to_base64(""));
+
+ REQUIRE("bw==" == StringUtils::to_base64("o"));  // input lengths 1..6 with the default padding
+ REQUIRE("b28=" == StringUtils::to_base64("oo"));
+ REQUIRE("b29v" == StringUtils::to_base64("ooo"));
+ REQUIRE("b29vbw==" == StringUtils::to_base64("oooo"));
+ REQUIRE("b29vb28=" == StringUtils::to_base64("ooooo"));
+ REQUIRE("b29vb29v" == StringUtils::to_base64("oooooo"));
+
+ REQUIRE("bw" == StringUtils::to_base64("o", false /*url*/, false /*padded*/));  // the same inputs with padding disabled
+ REQUIRE("b28" == StringUtils::to_base64("oo", false /*url*/, false /*padded*/));
+ REQUIRE("b29v" == StringUtils::to_base64("ooo", false /*url*/, false /*padded*/));
+ REQUIRE("b29vbw" == StringUtils::to_base64("oooo", false /*url*/, false /*padded*/));
+ REQUIRE("b29vb28" == StringUtils::to_base64("ooooo", false /*url*/, false /*padded*/));
+ REQUIRE("b29vb29v" == StringUtils::to_base64("oooooo", false /*url*/, false /*padded*/));
+
+ REQUIRE("ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/" ==  // these bytes encode to every digit of the standard alphabet in order
+ StringUtils::to_base64(std::vector<uint8_t>{0x00, 0x10, 0x83, 0x10,
+ 0x51, 0x87, 0x20, 0x92,
+ 0x8b, 0x30, 0xd3, 0x8f,
+ 0x41, 0x14, 0x93, 0x51,
+ 0x55, 0x97, 0x61, 0x96,
+ 0x9b, 0x71, 0xd7, 0x9f,
+ 0x82, 0x18, 0xa3, 0x92,
+ 0x59, 0xa7, 0xa2, 0x9a,
+ 0xab, 0xb2, 0xdb, 0xaf,
+ 0xc3, 0x1c, 0xb3, 0xd3,
+ 0x5d, 0xb7, 0xe3, 0x9e,
+ 0xbb, 0xf3, 0xdf, 0xbf}));
+ REQUIRE("ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-_" ==  // the same bytes with the URL-safe alphabet
+ StringUtils::to_base64(std::vector<uint8_t>{0x00, 0x10, 0x83, 0x10,
+ 0x51, 0x87, 0x20, 0x92,
+ 0x8b, 0x30, 0xd3, 0x8f,
+ 0x41, 0x14, 0x93, 0x51,
+ 0x55, 0x97, 0x61, 0x96,
+ 0x9b, 0x71, 0xd7, 0x9f,
+ 0x82, 0x18, 0xa3, 0x92,
+ 0x59, 0xa7, 0xa2, 0x9a,
+ 0xab, 0xb2, 0xdb, 0xaf,
+ 0xc3, 0x1c, 0xb3, 0xd3,
+ 0x5d, 0xb7, 0xe3, 0x9e,
+ 0xbb, 0xf3, 0xdf, 0xbf}, true /*url*/));
+}
+
+TEST_CASE("TestStringUtils::testBase64Decode", "[test base64 decode]") {  // fixed-vector checks for StringUtils::from_base64
+ REQUIRE("" == StringUtils::from_base64(""));
+ REQUIRE("o" == StringUtils::from_base64("bw=="));
+ REQUIRE("oo" == StringUtils::from_base64("b28="));
+ REQUIRE("ooo" == StringUtils::from_base64("b29v"));
+ REQUIRE("oooo" == StringUtils::from_base64("b29vbw=="));
+ REQUIRE("ooooo" == StringUtils::from_base64("b29vb28="));
+ REQUIRE("oooooo" == StringUtils::from_base64("b29vb29v"));
+ REQUIRE("\xfb\xff\xbf" == StringUtils::from_base64("-_-_"));  // URL-safe digits decode to the same bytes...
+ REQUIRE("\xfb\xff\xbf" == StringUtils::from_base64("+/+/"));  // ...as their standard counterparts
+ REQUIRE(std::string({ 0, 16, -125, 16,
+ 81, -121, 32, -110,
+ -117, 48, -45, -113,
+ 65, 20, -109, 81,
+ 85, -105, 97, -106,
+ -101, 113, -41, -97,
+ -126, 24, -93, -110,
+ 89, -89, -94, -102,
+ -85, -78, -37, -81,
+ -61, 28, -77, -45,
+ 93, -73, -29, -98,
+ -69, -13, -33, -65}) == StringUtils::from_base64("ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/"));  // every digit of the standard alphabet
+ REQUIRE(std::string({ 0, 16, -125, 16,
+ 81, -121, 32, -110,
+ -117, 48, -45, -113,
+ 65, 20, -109, 81,
+ 85, -105, 97, -106,
+ -101, 113, -41, -97,
+ -126, 24, -93, -110,
+ 89, -89, -94, -102,
+ -85, -78, -37, -81,
+ -61, 28, -77, -45,
+ 93, -73, -29, -98,
+ -69, -13, -33, -65}) == StringUtils::from_base64("ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-_"));  // every digit of the URL-safe alphabet
+
+ REQUIRE("foobarbuzz" == StringUtils::from_base64("Zm9vYmFyYnV6eg=="));
+ REQUIRE("foobarbuzz"== StringUtils::from_base64("\r\nZm9vYmFyYnV6eg=="));  // CR and LF are skipped wherever they appear
+ REQUIRE("foobarbuzz" == StringUtils::from_base64("Zm9\r\nvYmFyYnV6eg=="));
+ REQUIRE("foobarbuzz" == StringUtils::from_base64("Zm\r9vYmFy\n\n\n\n\n\n\n\nYnV6eg=="));
+ REQUIRE("foobarbuzz" == StringUtils::from_base64("\nZ\nm\n9\nv\nY\nm\nF\ny\nY\nn\nV\n6\ne\ng\n=\n=\n"));
+
+ REQUIRE_THROWS_WITH(StringUtils::from_base64("a"), "Base64 encoded string is malformatted");  // a lone trailing digit
+ REQUIRE_THROWS_WITH(StringUtils::from_base64("aaaaa"), "Base64 encoded string is malformatted");
+ REQUIRE_THROWS_WITH(StringUtils::from_base64("aa="), "Base64 encoded string is malformatted");  // too little padding
+ REQUIRE_THROWS_WITH(StringUtils::from_base64("aaaaaa="), "Base64 encoded string is malformatted");
+ REQUIRE_THROWS_WITH(StringUtils::from_base64("aa==?"), "Base64 encoded string is malformatted");  // illegal character after padding
+ REQUIRE_THROWS_WITH(StringUtils::from_base64("aa==a"), "Base64 encoded string is malformatted");  // digit after padding
+ REQUIRE_THROWS_WITH(StringUtils::from_base64("aa==="), "Base64 encoded string is malformatted");  // too much padding
+ REQUIRE_THROWS_WITH(StringUtils::from_base64("?"), "Base64 encoded string is malformatted");  // illegal character
+ REQUIRE_THROWS_WITH(StringUtils::from_base64("aaaa?"), "Base64 encoded string is malformatted");
+}
+
+TEST_CASE("TestStringUtils::testBase64EncodeDecode", "[test base64 encode decode]") {  // randomized encode/decode round trip
+ std::mt19937 gen(std::random_device { }());  // seeded from std::random_device, so each run uses different data
+ for (size_t i = 0U; i < 1024U; i++) {
+ const bool url = gen() % 2;  // pick alphabet and padding at random
+ const bool padded = gen() % 2;
+ const size_t length = gen() % 1024;  // 0..1023 bytes of input
+ std::vector<uint8_t> data(length);
+ std::generate_n(data.begin(), data.size(), [&]() -> uint8_t {
+ return gen() % 256;
+ });
+ auto base64 = utils::StringUtils::to_base64(data.data(), data.size(), url, padded);
+ REQUIRE(data == utils::StringUtils::from_base64(base64.data(), base64.size()));  // decoding must restore the original bytes
+ }
+}