| // Licensed to the Apache Software Foundation (ASF) under one |
| // or more contributor license agreements. See the NOTICE file |
| // distributed with this work for additional information |
| // regarding copyright ownership. The ASF licenses this file |
| // to you under the Apache License, Version 2.0 (the |
| // "License"); you may not use this file except in compliance |
| // with the License. You may obtain a copy of the License at |
| // |
| // http://www.apache.org/licenses/LICENSE-2.0 |
| // |
| // Unless required by applicable law or agreed to in writing, |
| // software distributed under the License is distributed on an |
| // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| // KIND, either express or implied. See the License for the |
| // specific language governing permissions and limitations |
| // under the License. |
| |
| #include "util/coding-util.h" |
| |
| #include <exception> |
| #include <sstream> |
| #include <boost/algorithm/string.hpp> |
| |
| #include "common/compiler-util.h" |
| #include "common/logging.h" |
| |
| #include "common/names.h" |
| #include "sasl/saslutil.h" |
| |
| using boost::algorithm::is_any_of; |
| using namespace impala; |
| using std::uppercase; |
| |
| namespace impala { |
| |
| // Hive selectively encodes characters. This is the whitelist of |
| // characters it will encode. |
| // See common/src/java/org/apache/hadoop/hive/common/FileUtils.java |
| // in the Hive source code for the source of this list. |
| static function<bool (char)> HiveShouldEscape = is_any_of("\"#%\\*/:=?\u00FF"); |
| |
| // It is more convenient to maintain the complement of the set of |
| // characters to escape when not in Hive-compat mode. |
| static function<bool (char)> ShouldNotEscape = is_any_of("-_.~"); |
| |
| static inline void UrlEncode(const char* in, int in_len, string* out, bool hive_compat) { |
| (*out).reserve(in_len); |
| stringstream ss; |
| for (int i = 0; i < in_len; ++i) { |
| const char ch = in[i]; |
| // Escape the character iff a) we are in Hive-compat mode and the |
| // character is in the Hive whitelist or b) we are not in |
| // Hive-compat mode, and the character is not alphanumeric or one |
| // of the four commonly excluded characters. |
| if ((hive_compat && HiveShouldEscape(ch)) || |
| (!hive_compat && !(isalnum(ch) || ShouldNotEscape(ch)))) { |
| ss << '%' << uppercase << hex << static_cast<uint32_t>(ch); |
| } else { |
| ss << ch; |
| } |
| } |
| |
| (*out) = ss.str(); |
| } |
| |
| void UrlEncode(const vector<uint8_t>& in, string* out, bool hive_compat) { |
| if (in.empty()) { |
| *out = ""; |
| } else { |
| UrlEncode(reinterpret_cast<const char*>(in.data()), in.size(), out, hive_compat); |
| } |
| } |
| |
| void UrlEncode(const string& in, string* out, bool hive_compat) { |
| UrlEncode(in.c_str(), in.size(), out, hive_compat); |
| } |
| |
| // Adapted from |
| // http://www.boost.org/doc/libs/1_40_0/doc/html/boost_asio/ |
| // example/http/server3/request_handler.cpp |
| // See http://www.boost.org/LICENSE_1_0.txt for license for this method. |
| bool UrlDecode(const string& in, string* out, bool hive_compat) { |
| out->clear(); |
| out->reserve(in.size()); |
| for (size_t i = 0; i < in.size(); ++i) { |
| if (in[i] == '%') { |
| if (i + 3 <= in.size()) { |
| int value = 0; |
| istringstream is(in.substr(i + 1, 2)); |
| if (is >> hex >> value) { |
| (*out) += static_cast<char>(value); |
| i += 2; |
| } else { |
| return false; |
| } |
| } else { |
| return false; |
| } |
| } else if (!hive_compat && in[i] == '+') { // Hive does not encode ' ' as '+' |
| (*out) += ' '; |
| } else { |
| (*out) += in[i]; |
| } |
| } |
| return true; |
| } |
| |
| bool Base64EncodeBufLen(int64_t in_len, int64_t* out_max) { |
| // Base64 encoding turns every 3 bytes into 4 characters. If the length is not |
| // divisible by 3, it pads the input with extra 0 bytes until it is divisible by 3. |
| // One more character must be allocated to account for Base64Encode's null-padding |
| // of its output. |
| *out_max = 1 + 4 * ((in_len + 2) / 3); |
| if (UNLIKELY(in_len < 0 || |
| *out_max > static_cast<unsigned>(std::numeric_limits<int>::max()))) { |
| return false; |
| } |
| return true; |
| } |
| |
| bool Base64Encode(const char* in, int64_t in_len, int64_t out_max, char* out, |
| int64_t* out_len) { |
| if (UNLIKELY(in_len < 0 || in_len > std::numeric_limits<unsigned>::max() || |
| out_max < 0 || out_max > std::numeric_limits<unsigned>::max())) { |
| return false; |
| } |
| const int encode_result = sasl_encode64(in, static_cast<unsigned>(in_len), out, |
| static_cast<unsigned>(out_max), reinterpret_cast<unsigned*>(out_len)); |
| if (UNLIKELY(encode_result != SASL_OK || *out_len != out_max - 1)) return false; |
| return true; |
| } |
| |
| static inline void Base64Encode(const char* in, int64_t in_len, stringstream* out) { |
| if (in_len == 0) { |
| (*out) << ""; |
| return; |
| } |
| int64_t out_max = 0; |
| if (UNLIKELY(!Base64EncodeBufLen(in_len, &out_max))) return; |
| string result(out_max, '\0'); |
| int64_t out_len = 0; |
| if (UNLIKELY(!Base64Encode(in, in_len, out_max, const_cast<char*>(result.c_str()), |
| &out_len))) { |
| return; |
| } |
| result.resize(out_len); |
| (*out) << result; |
| } |
| |
| void Base64Encode(const vector<uint8_t>& in, string* out) { |
| if (in.empty()) { |
| *out = ""; |
| } else { |
| stringstream ss; |
| Base64Encode(in, &ss); |
| *out = ss.str(); |
| } |
| } |
| |
| void Base64Encode(const vector<uint8_t>& in, stringstream* out) { |
| if (!in.empty()) { |
| // Boost does not like non-null terminated strings |
| string tmp(reinterpret_cast<const char*>(in.data()), in.size()); |
| Base64Encode(tmp.c_str(), tmp.size(), out); |
| } |
| } |
| |
| void Base64Encode(const string& in, string* out) { |
| stringstream ss; |
| Base64Encode(in.c_str(), in.size(), &ss); |
| *out = ss.str(); |
| } |
| |
| void Base64Encode(const string& in, stringstream* out) { |
| Base64Encode(in.c_str(), in.size(), out); |
| } |
| |
| bool Base64DecodeBufLen(const char* in, int64_t in_len, int64_t* out_max) { |
| // Base64 decoding turns every 4 characters into 3 bytes. If the last character of the |
| // encoded string is '=', that character (which represents 6 bits) and the last two bits |
| // of the previous character is ignored, for a total of 8 ignored bits, therefore |
| // producing one fewer byte of output. This is repeated if the second-to-last character |
| // is '='. One more byte must be allocated to account for Base64Decode's null-padding |
| // of its output. |
| if (UNLIKELY((in_len & 3) != 0)) return false; |
| *out_max = 1 + 3 * (in_len / 4); |
| if (in[in_len - 1] == '=') { |
| --(*out_max); |
| if (in[in_len - 2] == '=') { |
| --(*out_max); |
| } |
| } |
| return true; |
| } |
| |
| bool Base64Decode(const char* in, int64_t in_len, int64_t out_max, char* out, |
| int64_t* out_len) { |
| uint32_t out_len_u32 = 0; |
| if (UNLIKELY((in_len & 3) != 0)) return false; |
| const int decode_result = sasl_decode64(in, static_cast<unsigned>(in_len), out, |
| static_cast<unsigned>(out_max), &out_len_u32); |
| *out_len = out_len_u32; |
| if (UNLIKELY(decode_result != SASL_OK || *out_len != out_max - 1)) return false; |
| return true; |
| } |
| |
| void EscapeForHtml(const string& in, stringstream* out) { |
| DCHECK(out != NULL); |
| for (const char c: in) { |
| switch (c) { |
| case '<': (*out) << "<"; |
| break; |
| case '>': (*out) << ">"; |
| break; |
| case '&': (*out) << "&"; |
| break; |
| default: (*out) << c; |
| } |
| } |
| } |
| |
| } |