blob: f16ecd68953c022249fd358e58f4c8f15f58588b [file] [log] [blame]
/**
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.fineract.infrastructure.security.utils;
import edu.umd.cs.findbugs.annotations.SuppressFBWarnings;
import java.sql.Connection;
import java.sql.DatabaseMetaData;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import org.apache.fineract.infrastructure.security.service.SqlValidator;
import org.springframework.jdbc.core.JdbcTemplate;
import org.springframework.jdbc.datasource.DataSourceUtils;
import org.springframework.stereotype.Component;
/**
 * Validates free-form SQL condition fragments before they are appended to a query.
 *
 * <p>
 * Each non-blank condition is passed to {@link SqlValidator}, normalised (lower-cased, operators padded with spaces)
 * and parsed into {@code alias.column} operands. Every alias must resolve to a table named in the {@code from} part of
 * the supplied schema query, and every column must exist in that table according to the JDBC database metadata. Any
 * violation raises {@link SQLInjectionException}.
 */
@Slf4j
@RequiredArgsConstructor
@Component
public class ColumnValidator {

    private final SqlValidator sqlValidator;
    private final JdbcTemplate jdbcTemplate;

    /**
     * Checks that every requested column exists in its table, using JDBC database metadata.
     *
     * @param tableColumnMap
     *            table name mapped to the column names referenced for that table
     * @throws SQLInjectionException
     *             if a referenced column is not reported by the metadata for its table, or the metadata lookup fails
     */
    @SuppressFBWarnings(value = "NP_NULL_ON_SOME_PATH_FROM_RETURN_VALUE", justification = "TODO: fix this!")
    private void validateColumn(Map<String, Set<String>> tableColumnMap) {
        Connection connection = null;
        try {
            connection = Objects.requireNonNull(this.jdbcTemplate.getDataSource()).getConnection();
            DatabaseMetaData dbMetaData = connection.getMetaData();
            for (Map.Entry<String, Set<String>> entry : tableColumnMap.entrySet()) {
                Set<String> tableColumns;
                // Close the metadata cursor as soon as the column names have been read.
                try (ResultSet resultSet = dbMetaData.getColumns(null, null, entry.getKey(), null)) {
                    tableColumns = getTableColumns(resultSet);
                }
                // An unknown table yields no columns, so any requested column fails this check as well.
                if (!tableColumns.containsAll(entry.getValue())) {
                    throw new SQLInjectionException();
                }
            }
        } catch (SQLException e) {
            throw new SQLInjectionException(e);
        } finally {
            if (connection != null) {
                DataSourceUtils.releaseConnection(connection, this.jdbcTemplate.getDataSource());
            }
        }
    }

    /**
     * Reads the {@code column_name} value of every row in a {@code DatabaseMetaData.getColumns} result.
     *
     * @param rs
     *            metadata result set; not closed by this method
     * @return the column names found (possibly empty)
     * @throws SQLException
     *             if reading the result set fails; propagated so the caller can fail the validation with the cause
     */
    private static Set<String> getTableColumns(final ResultSet rs) throws SQLException {
        Set<String> columns = new HashSet<>();
        while (rs.next()) {
            columns.add(rs.getString("column_name"));
        }
        return columns;
    }

    /**
     * Validates each non-blank condition against the tables and columns referenced by the schema query.
     *
     * @param schema
     *            the SQL query the conditions will be applied to; its {@code from} part is used to resolve aliases
     * @param conditions
     *            condition fragments such as {@code "l.id = 5"}; blank entries are skipped
     * @throws SQLInjectionException
     *             if a condition references an unknown alias or column, or cannot be parsed into
     *             {@code alias.column} operands
     */
    public void validateSqlInjection(String schema, String... conditions) {
        // Loop-invariant. Order matters: a bare "=" pass pads every "=" (turning ">=" into "> = "), and the later
        // spaced forms ("> =", "< =", "! =") collapse those compound operators back together.
        List<String> operators = new ArrayList<>(Arrays.asList("=", ">", "<", "> =", "< =", "! =", "!=", ">=", "<="));
        for (String condition : conditions) {
            if (StringUtils.isBlank(condition)) {
                continue;
            }
            sqlValidator.validate("column", condition);
            // Locale.ROOT keeps the lower-casing independent of the JVM default locale (e.g. Turkish dotless i).
            String normalized = condition.trim().replace("( ", "(").replace(" )", ")").toLowerCase(Locale.ROOT);
            for (String op : operators) {
                normalized = replaceAll(normalized, op).replaceAll(" +", " ");
            }
            Set<String> operands = getOperand(normalized);
            String normalizedSchema = schema.trim().replaceAll(" +", " ").toLowerCase(Locale.ROOT);
            Map<String, Set<String>> tableColumnAliasMap = getTableColumnAliasMap(operands);
            Map<String, Set<String>> tableColumnMap = getTableColumnMap(normalizedSchema, tableColumnAliasMap);
            validateColumn(tableColumnMap);
        }
    }

    /**
     * Resolves each alias to the table name that precedes it in the {@code from} part of the schema.
     *
     * @param schema
     *            normalised (lower-cased, single-spaced) schema query
     * @param tableColumnAliasMap
     *            alias mapped to the column names referenced through it
     * @return table name mapped to the referenced column names
     * @throws SQLInjectionException
     *             if an alias does not appear after {@code from}, or has no table token before it
     * @throws IllegalArgumentException
     *             if the schema contains no {@code from}
     */
    private static Map<String, Set<String>> getTableColumnMap(String schema, Map<String, Set<String>> tableColumnAliasMap) {
        int fromIndex = schema.indexOf("from");
        if (fromIndex < 0) {
            throw new IllegalArgumentException("Schema query must contain a 'from' clause");
        }
        // Trailing space so an alias at the very end of the (trimmed) schema still matches " alias ".
        String fromClause = schema.substring(fromIndex) + " ";
        Map<String, Set<String>> tableColumnMap = new HashMap<>();
        for (Map.Entry<String, Set<String>> entry : tableColumnAliasMap.entrySet()) {
            int index = fromClause.indexOf(" " + entry.getKey() + " ");
            if (index < 0) {
                throw new SQLInjectionException();
            }
            // The table name is the space-delimited token immediately before the alias.
            int startPos = fromClause.substring(0, index - 1).lastIndexOf(' ');
            if (startPos < 0) {
                // The qualifier directly follows "from", so there is no table token to resolve it against.
                throw new SQLInjectionException();
            }
            tableColumnMap.put(fromClause.substring(startPos, index).trim(), entry.getValue());
        }
        return tableColumnMap;
    }

    /**
     * Groups {@code alias.column} operands by alias.
     *
     * @param operands
     *            operands extracted from a condition
     * @return alias mapped to the column names referenced through it
     * @throws SQLInjectionException
     *             if an operand does not split into exactly two dot-separated parts
     */
    @SuppressWarnings("StringSplitter")
    private static Map<String, Set<String>> getTableColumnAliasMap(Set<String> operands) {
        Map<String, Set<String>> tableColumnMap = new HashMap<>();
        for (String operand : operands) {
            String[] tableColumn = operand.split("\\.");
            if (tableColumn.length != 2) {
                throw new SQLInjectionException();
            }
            tableColumnMap.computeIfAbsent(tableColumn[0], alias -> new HashSet<>()).add(tableColumn[1]);
        }
        return tableColumnMap;
    }

    /**
     * Extracts the left-hand operand of every comparison or keyword operator in a normalised condition.
     *
     * @param condition
     *            normalised condition text
     * @return the distinct operands found
     * @throws SQLInjectionException
     *             if an operator appears at the very start of the condition (no left-hand operand)
     */
    private static Set<String> getOperand(String condition) {
        Set<String> operandList = new HashSet<>();
        List<String> operatorList = new ArrayList<>(
                Arrays.asList("!=", "=", ">", "<", " like ", " between ", " in ", " in(", " is ", " is not ", " equals ", " not equals "));
        for (String op : operatorList) {
            for (int index = condition.indexOf(op); index > -1; index = condition.indexOf(op, index + op.length())) {
                if (index == 0) {
                    // An operator with nothing to its left cannot have an alias.column operand.
                    throw new SQLInjectionException();
                }
                char currentChar = condition.charAt(index - 1);
                // A "=" that is the tail of "!=", ">=" or "<=" is picked up through "!=", ">" or "<" instead.
                if (!op.equals("=") || !isPrecededByComparisonPrefix(condition, index)) {
                    operandList.add(getOperand(condition, index, currentChar));
                }
            }
        }
        return operandList;
    }

    /**
     * Returns the token immediately before the operator at {@code index}, with parentheses removed.
     *
     * @param condition
     *            normalised condition text
     * @param index
     *            start index of the operator; must be at least 1
     * @param currentChar
     *            the character at {@code index - 1}
     * @return the operand text, trimmed and without parentheses
     */
    private static String getOperand(String condition, int index, char currentChar) {
        // When a space separates operand and operator, skip it before searching for the operand's start.
        int searchEnd = currentChar == ' ' ? index - 1 : index;
        int startPos = condition.substring(0, searchEnd).lastIndexOf(' ');
        String operand = condition.substring(Math.max(startPos, 0), index);
        return operand.trim().replace("(", "").replace(")", "");
    }

    /**
     * Returns whether the character before {@code index} is '!', '<' or '>', i.e. a "=" at {@code index} is part of
     * a compound comparison operator. Returns {@code false} when {@code index} is 0.
     */
    private static boolean isPrecededByComparisonPrefix(String condition, int index) {
        return index > 0 && "!<>".indexOf(condition.charAt(index - 1)) >= 0;
    }

    /**
     * Normalises the spelling of one operator throughout the condition.
     *
     * <p>
     * For "=" every occurrence is padded with spaces, provided at least one occurrence is not part of "!=", ">=" or
     * "<=". The spaced forms "< =", "> =" and "! =" are collapsed. Any other operator is padded with spaces.
     *
     * @param condition
     *            condition text
     * @param op
     *            operator to normalise
     * @return the condition with the operator normalised, or unchanged if it does not occur
     */
    private static String replaceAll(String condition, String op) {
        int index = condition.indexOf(op);
        while (index > -1) {
            if (op.equals("=")) {
                if (!isPrecededByComparisonPrefix(condition, index)) {
                    return condition.replace(op, " " + op + " ");
                }
                index = condition.indexOf(op, index + 2 + op.length());
            } else if (op.equals("< =") || op.equals("> =") || op.equals("! =")) {
                return condition.replace(op, op.replace(" ", ""));
            } else {
                return condition.replace(op, " " + op + " ");
            }
        }
        return condition;
    }
}