RANGER-5513: Enhance Ranger lookup API input validation (#880)
* RANGER-5513:Enhance Ranger lookup API input validation
* RANGER-5513:Enhance Ranger lookup API input validation - fixed copilot comments
* RANGER-5513:Enhance Ranger lookup API input validation - fixed copilot comments #2 set
* RANGER-5513:Enhance Ranger lookup API input validation - fixed copilot comments #3 set
* RANGER-5513:Enhance Ranger lookup API input validation - fixed copilot comments #4 set
* RANGER-5513:Enhance Ranger lookup API input validation - exception propagation issue fix
* RANGER-5513:Enhance Ranger lookup API input validation - hiveclient hms api call validation fix
* RANGER-5513:Enhance Ranger lookup API input validation - remove hadoopexception dependency on schemaregistry client
* RANGER-5513:Enhance Ranger lookup API input validation - HBase lookup issue fix
* RANGER-5513:Enhance Ranger lookup API input validation - Fix Baseclient compilation issue
---------
Co-authored-by: Ramesh Mani <rmani@apache.org>
diff --git a/agents-common/src/main/java/org/apache/ranger/plugin/client/BaseClient.java b/agents-common/src/main/java/org/apache/ranger/plugin/client/BaseClient.java
index dbbebbf..bab9643 100644
--- a/agents-common/src/main/java/org/apache/ranger/plugin/client/BaseClient.java
+++ b/agents-common/src/main/java/org/apache/ranger/plugin/client/BaseClient.java
@@ -31,6 +31,7 @@
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
+import java.util.regex.PatternSyntaxException;
public abstract class BaseClient {
private static final Logger LOG = LoggerFactory.getLogger(BaseClient.class);
@@ -184,6 +185,168 @@ private void init() {
}
}
+ protected void validateSqlIdentifier(String identifier, String identifierType) throws HadoopException {
+ if (StringUtils.isBlank(identifier)) {
+ return;
+ }
+ if (identifier.contains("..") || identifier.contains("//") || identifier.contains("\\")) {
+ String msgDesc = "Invalid " + identifierType + ": [" + identifier + "]. Path traversal patterns are not allowed.";
+ HadoopException hdpException = new HadoopException(msgDesc);
+ hdpException.generateResponseDataMap(false, msgDesc, msgDesc + DEFAULT_ERROR_MESSAGE, null, null);
+ LOG.error(msgDesc);
+ throw hdpException;
+ }
+ if (!identifier.matches("^[a-zA-Z0-9*?\\[\\]\\-\\$%\\{\\}\\=\\/\\._]+$")) {
+ String msgDesc = "Invalid " + identifierType + ": [" + identifier + "]. Only alphanumeric characters along with ( ., _, -, *, ?, [], {}, %, $, = / ) are allowed.";
+ HadoopException hdpException = new HadoopException(msgDesc);
+ hdpException.generateResponseDataMap(false, msgDesc, msgDesc + DEFAULT_ERROR_MESSAGE, null, null);
+ LOG.error(msgDesc);
+ throw hdpException;
+ }
+ }
+
+    /**
+     * Translates a lookup wildcard pattern into a SQL LIKE pattern:
+     * '*' becomes '%' (any sequence) and '?' becomes '_' (any single character).
+     * A null/empty pattern means "match everything" and yields "%".
+     * NOTE(review): '%' and '_' already present in the input are passed through
+     * unescaped and therefore act as SQL wildcards — confirm this is intended.
+     */
+    protected String convertToSqlPattern(String pattern) throws HadoopException {
+        return (pattern == null || pattern.isEmpty()) ? "%" : pattern.replace("*", "%").replace("?", "_");
+    }
+
+    /**
+     * Checks whether a candidate value matches a SQL LIKE pattern by converting
+     * the pattern to an anchored regex (see convertSqlPatternToRegex).
+     *
+     * @param value   candidate string; a null value never matches a concrete pattern
+     * @param pattern SQL LIKE pattern; null or "%" matches everything
+     * @return true when the value matches the pattern
+     * @throws HadoopException if the converted pattern is not a valid regex
+     */
+    protected boolean matchesSqlPattern(String value, String pattern) throws HadoopException {
+        if (pattern == null || pattern.equals("%")) {
+            return true;
+        }
+
+        if (value == null) {
+            // Fix: previously value.matches(...) was called on a null receiver,
+            // throwing NullPointerException for any non-trivial pattern.
+            return false;
+        }
+
+        String regex = convertSqlPatternToRegex(pattern);
+
+        try {
+            return value.matches(regex);
+        } catch (PatternSyntaxException pe) {
+            String msgDesc = "Invalid value: [" + value + "]. Only alphanumeric characters along with ( ., _, -, *, ?, [], {}, %, $, = / ) are allowed.";
+            HadoopException hdpException = new HadoopException(msgDesc);
+            hdpException.generateResponseDataMap(false, msgDesc, msgDesc + DEFAULT_ERROR_MESSAGE, null, null);
+            LOG.error(msgDesc);
+            throw hdpException;
+        }
+    }
+
+ protected void validateUrlResourceName(String resourceName, String resourceType) throws HadoopException {
+ if (resourceName == null) {
+ return;
+ }
+ if (resourceName.contains("..") || resourceName.contains("//") || resourceName.contains("\\")) {
+ String msgDesc = "Invalid " + resourceType + ": [" + resourceName + "]. Path traversal patterns are not allowed.";
+ HadoopException hdpException = new HadoopException(msgDesc);
+ hdpException.generateResponseDataMap(false, msgDesc, msgDesc + DEFAULT_ERROR_MESSAGE, null, null);
+ LOG.error(msgDesc);
+ throw hdpException;
+ }
+ if (!resourceName.matches("^[a-zA-Z0-9_.*\\-]+$")) {
+ String msgDesc = "Invalid " + resourceType + ": [" + resourceName + "]. Only alphanumeric characters with ( ., _, *, -) are allowed.";
+ HadoopException hdpException = new HadoopException(msgDesc);
+ hdpException.generateResponseDataMap(false, msgDesc, msgDesc + DEFAULT_ERROR_MESSAGE, null, null);
+ LOG.error(msgDesc);
+ throw hdpException;
+ }
+ }
+
+    /**
+     * Validates a glob-style lookup pattern ('*' / '?') supplied by a caller:
+     * rejects path-traversal sequences and characters outside the allowed set.
+     * A null or empty pattern is accepted (callers treat it as "match all").
+     */
+    public void validateWildcardPattern(String pattern, String patternType) throws HadoopException {
+        if (pattern == null || pattern.isEmpty()) {
+            return;
+        }
+
+        final String msgDesc;
+
+        if (pattern.contains("..") || pattern.contains("//") || pattern.contains("\\")) {
+            msgDesc = "Invalid " + patternType + ": [" + pattern + "]. Path traversal patterns are not allowed.";
+        } else if (!pattern.matches("^[a-zA-Z0-9_.*?\\[\\]\\-\\$%\\{\\}\\=\\/]+$")) {
+            msgDesc = "Invalid " + patternType + ": [" + pattern + "]. Only alphanumeric characters along with ( ., _, -, *, ?, [], {}, %, $, = / ) are allowed.";
+        } else {
+            return;
+        }
+
+        HadoopException hdpException = new HadoopException(msgDesc);
+        hdpException.generateResponseDataMap(false, msgDesc, msgDesc + DEFAULT_ERROR_MESSAGE, null, null);
+        LOG.error(msgDesc);
+        throw hdpException;
+    }
+
+    /**
+     * Converts a SQL LIKE pattern into a regex accepted by String.matches():
+     * '%' -> ".*", '_' -> ".", and every other regex metacharacter is escaped
+     * so it matches literally. matches() is implicitly anchored, so the
+     * trailing '$' is not required; the leading '^' is kept for clarity.
+     */
+    protected String convertSqlPatternToRegex(String pattern) {
+        StringBuilder regexBuilder = new StringBuilder("^");
+
+        for (int i = 0; i < pattern.length(); i++) {
+            char c = pattern.charAt(i);
+            switch (c) {
+                case '%':
+                    // SQL LIKE wildcard: zero or more characters
+                    regexBuilder.append(".*");
+                    break;
+                case '_':
+                    // SQL LIKE wildcard: exactly one character
+                    regexBuilder.append('.');
+                    break;
+                case '.':
+                case '^':
+                case '$':
+                case '*':
+                case '+':
+                case '?':
+                case '{':
+                case '}':
+                case '[':
+                case ']':
+                case '(':
+                case ')':
+                case '|':
+                case '\\':
+                    // Escape regex metacharacters so they are treated literally.
+                    // Fix: '*' was previously NOT escaped — a leading '*' produced
+                    // an invalid regex (dangling quantifier) and an embedded '*'
+                    // silently acted as a quantifier instead of a literal.
+                    regexBuilder.append('\\').append(c);
+                    break;
+                default:
+                    regexBuilder.append(c);
+                    break;
+            }
+        }
+
+        return regexBuilder.toString();
+    }
+
+    /**
+     * Converts a glob-style wildcard ('*' = any sequence, '?' = any single
+     * character) into an anchored regex. All other regex metacharacters are
+     * escaped so they match literally; a null/empty wildcard matches
+     * everything (".*").
+     */
+    public String convertWildcardToRegex(String wildcard) {
+        if (wildcard == null || wildcard.isEmpty()) {
+            return ".*";
+        }
+        StringBuilder regex = new StringBuilder("^");
+        for (int i = 0; i < wildcard.length(); i++) {
+            char c = wildcard.charAt(i);
+            switch (c) {
+                case '*':
+                    regex.append(".*");
+                    break;
+                case '?':
+                    regex.append(".");
+                    break;
+                case '.':
+                case '\\':
+                case '^':
+                case '$':
+                case '|':
+                case '{':
+                case '}':
+                case '[':
+                case ']':
+                case '(':
+                case ')':
+                case '+':
+                    // Escape so these match literally. Fix: '(', ')' and '+' were
+                    // previously passed through unescaped, letting input smuggle
+                    // live regex constructs (e.g. "(a+)+") past the conversion.
+                    regex.append('\\').append(c);
+                    break;
+                default:
+                    regex.append(c);
+            }
+        }
+        regex.append('$');
+        return regex.toString();
+    }
+
private HadoopException createException(Exception exp) {
return createException("Unable to login to Hadoop environment [" + serviceName + "]", exp);
}
diff --git a/agents-common/src/test/java/org/apache/ranger/plugin/client/TestBaseClient.java b/agents-common/src/test/java/org/apache/ranger/plugin/client/TestBaseClient.java
index 6d27ef9..758f5ae 100644
--- a/agents-common/src/test/java/org/apache/ranger/plugin/client/TestBaseClient.java
+++ b/agents-common/src/test/java/org/apache/ranger/plugin/client/TestBaseClient.java
@@ -326,4 +326,114 @@ class TestClient extends BaseClient {
assertEquals(IllegalArgumentException.class, ex.getClass());
}
}
+
+ @Test
+ public void test15_convertWildcardToRegex() {
+ // Verifies glob-to-regex conversion: wildcards expand, metacharacters
+ // present in the input ('.', '$', '^', '[', ']') are escaped literally.
+ class TestClient extends BaseClient {
+ TestClient() {
+ super("test", new HashMap<>());
+ }
+
+ @Override
+ protected void login() {
+ }
+
+ // Exposes the protected BaseClient method under test.
+ public String convert(String s) {
+ return convertWildcardToRegex(s);
+ }
+ }
+
+ TestClient client = new TestClient();
+ assertEquals(".*", client.convert(null));
+ assertEquals(".*", client.convert(""));
+ assertEquals("^atlas.*$", client.convert("atlas*"));
+ assertEquals("^atlas\\..*$", client.convert("atlas.*"));
+ assertEquals("^.*atlas.*$", client.convert("*atlas*"));
+ assertEquals("^at.as$", client.convert("at?as"));
+ assertEquals("^atlas\\.$", client.convert("atlas."));
+ assertEquals("^atlas\\$$", client.convert("atlas$"));
+ assertEquals("^atlas\\^$", client.convert("atlas^"));
+ assertEquals("^atlas\\[\\]$", client.convert("atlas[]"));
+ }
+
+ @Test
+ public void test16_convertToSqlPattern() throws Exception {
+ // Verifies glob-to-SQL-LIKE conversion: '*' -> '%', '?' -> '_',
+ // and null/empty input maps to the match-all pattern "%".
+ class TestClient extends BaseClient {
+ TestClient() {
+ super("test", new HashMap<>());
+ }
+
+ @Override
+ protected void login() {
+ }
+
+ // Exposes the protected BaseClient method under test.
+ public String convert(String s) throws Exception {
+ return convertToSqlPattern(s);
+ }
+ }
+
+ TestClient client = new TestClient();
+ assertEquals("%", client.convert(null));
+ assertEquals("%", client.convert(""));
+ assertEquals("atlas%", client.convert("atlas*"));
+ assertEquals("at_as", client.convert("at?as"));
+ }
+
+ @Test
+ public void test17_matchesSqlPattern() throws Exception {
+ // Verifies SQL LIKE matching: null/"%" match everything, '%' matches any
+ // run (including '_' in "atlas_test"), '_' matches exactly one character.
+ class TestClient extends BaseClient {
+ TestClient() {
+ super("test", new HashMap<>());
+ }
+
+ @Override
+ protected void login() {
+ }
+
+ // Exposes the protected BaseClient method under test.
+ public boolean match(String v, String p) throws Exception {
+ return matchesSqlPattern(v, p);
+ }
+ }
+
+ TestClient client = new TestClient();
+ assertEquals(true, client.match("atlas", null));
+ assertEquals(true, client.match("atlas", "%"));
+ assertEquals(true, client.match("atlas", "atlas%"));
+ assertEquals(true, client.match("atlas_test", "atlas%"));
+ assertEquals(true, client.match("atlas", "at_as"));
+ // "at_a" covers only four characters, so the five-character value fails.
+ assertEquals(false, client.match("atlas", "at_a"));
+ }
+
+    @Test
+    public void test18_validateWildcardPattern() {
+        // Verifies that valid wildcard patterns pass validation and that a
+        // path-traversal style pattern is rejected with an exception.
+        class TestClient extends BaseClient {
+            TestClient() {
+                super("test", new HashMap<>());
+            }
+
+            @Override
+            protected void login() {
+            }
+
+            // Exposes the public BaseClient method with a fixed pattern type.
+            public void validate(String s) throws Exception {
+                validateWildcardPattern(s, "test");
+            }
+        }
+
+        TestClient client = new TestClient();
+
+        // Fix: replaced the try/fail/catch(Exception) anti-pattern with JUnit 5
+        // assertions; the old broad catch could not tell which call failed.
+        org.junit.jupiter.api.Assertions.assertDoesNotThrow(() -> {
+            client.validate("atlas*");
+            client.validate("atlas.*");
+            client.validate("atlas?");
+        });
+
+        org.junit.jupiter.api.Assertions.assertThrows(Exception.class, () -> client.validate("atlas../test"));
+    }
}
diff --git a/hbase-agent/src/main/java/org/apache/ranger/services/hbase/client/HBaseClient.java b/hbase-agent/src/main/java/org/apache/ranger/services/hbase/client/HBaseClient.java
index 1b11001..ee2630a 100644
--- a/hbase-agent/src/main/java/org/apache/ranger/services/hbase/client/HBaseClient.java
+++ b/hbase-agent/src/main/java/org/apache/ranger/services/hbase/client/HBaseClient.java
@@ -191,6 +191,12 @@ public List<String> getTableList(final String tableNameMatching, final List<Stri
ret = Subject.doAs(subj, new PrivilegedAction<List<String>>() {
@Override
public List<String> run() {
+ String wildcard = tableNameMatching;
+ if (wildcard != null) {
+ wildcard = wildcard.replace(".*", "*");
+ }
+ validateWildcardPattern(wildcard, "table pattern");
+ String safeTablePattern = convertWildcardToRegex(wildcard);
List<String> tableList = new ArrayList<>();
Admin admin = null;
@@ -205,8 +211,7 @@ public List<String> run() {
LOG.info("getTableList: no exception: HbaseAvailability true");
admin = conn.getAdmin();
-
- List<TableDescriptor> htds = admin.listTableDescriptors(Pattern.compile(tableNameMatching));
+ List<TableDescriptor> htds = admin.listTableDescriptors(Pattern.compile(safeTablePattern));
if (htds != null) {
for (TableDescriptor htd : htds) {
@@ -240,6 +245,8 @@ public List<String> run() {
LOG.error(msgDesc + mnre);
throw hdpException;
+ } catch (HadoopException he) {
+ throw he;
} catch (IOException io) {
String msgDesc = "getTableList: Unable to get HBase table List for [repository:" + getConfigHolder().getDatasourceName() + ",table-match:" + tableNameMatching + "].";
HadoopException hdpException = new HadoopException(msgDesc, io);
@@ -291,14 +298,18 @@ public List<String> getColumnFamilyList(final String columnFamilyMatching, final
@Override
public List<String> run() {
+ String wildcard = columnFamilyMatching;
+ if (wildcard != null) {
+ wildcard = wildcard.replace(".*", "*");
+ }
+ validateWildcardPattern(wildcard, "column family pattern");
+ String safeColumnPattern = convertWildcardToRegex(wildcard);
List<String> colfList = new ArrayList<>();
Admin admin = null;
try {
LOG.info("getColumnFamilyList: setting config values from client");
-
setClientConfigValues(conf);
-
LOG.info("getColumnFamilyList: checking HbaseAvailability with the new config");
try (Connection conn = ConnectionFactory.createConnection(conf)) {
@@ -314,8 +325,7 @@ public List<String> run() {
if (htd != null) {
for (ColumnFamilyDescriptor hcd : htd.getColumnFamilies()) {
String colf = hcd.getNameAsString();
-
- if (colf.matches(columnFamilyMatching)) {
+ if (colf.matches(safeColumnPattern)) {
if (existingColumnFamilies != null && existingColumnFamilies.contains(colf)) {
continue;
} else {
@@ -345,6 +355,8 @@ public List<String> run() {
LOG.error(msgDesc + mnre);
throw hdpException;
+ } catch (HadoopException he) {
+ throw he;
} catch (IOException io) {
String msgDesc = "getColumnFamilyList: Unable to get HBase ColumnFamilyList for [repository:" + getConfigHolder().getDatasourceName() + ",table:" + tblName + ", table-match:" + columnFamilyMatching + "] ";
HadoopException hdpException = new HadoopException(msgDesc, io);
diff --git a/hbase-agent/src/test/java/org/apache/ranger/services/hbase/client/TestHBaseClient.java b/hbase-agent/src/test/java/org/apache/ranger/services/hbase/client/TestHBaseClient.java
index 4edf5f9..7c2f858 100644
--- a/hbase-agent/src/test/java/org/apache/ranger/services/hbase/client/TestHBaseClient.java
+++ b/hbase-agent/src/test/java/org/apache/ranger/services/hbase/client/TestHBaseClient.java
@@ -314,11 +314,11 @@ public void test09_getColumnFamilyList_filtersAndExceptions() throws Exception {
List<String> tables = new ArrayList<>(Collections.singletonList("t1"));
List<String> existing = new ArrayList<>(Collections.singletonList("cf1"));
- List<String> ret = client.getColumnFamilyList("cf.*", tables, existing);
+ List<String> ret = client.getColumnFamilyList("cf*", tables, existing);
assertEquals(Collections.singletonList("cf2"), ret);
Mockito.when(admin.getDescriptor(tn)).thenThrow(new IOException("io"));
- assertThrows(HadoopException.class, () -> client.getColumnFamilyList("cf.*", tables, null));
+ assertThrows(HadoopException.class, () -> client.getColumnFamilyList("cf*", tables, null));
}
}
@@ -588,6 +588,291 @@ private ColumnFamilyDescriptor mockCfd(String name) {
return cfd;
}
+ @Test
+ public void test19_validatePattern_validWildcards() throws Exception {
+ // Verifies getTableList() accepts patterns composed only of characters the
+ // validator allows; HBase Admin is mocked to return an empty table list.
+ Map<String, String> props = new HashMap<>();
+ props.put("username", "user");
+
+ try (MockedStatic<HBaseConfiguration> confStatic = Mockito.mockStatic(HBaseConfiguration.class);
+ MockedStatic<Subject> subjectStatic = Mockito.mockStatic(Subject.class);
+ MockedStatic<ConnectionFactory> connFactoryStatic = Mockito.mockStatic(ConnectionFactory.class)) {
+ Configuration conf = Mockito.mock(Configuration.class);
+ confStatic.when(HBaseConfiguration::create).thenReturn(conf);
+
+ // Run the PrivilegedAction inline instead of under a real Subject.
+ Subject subject = Mockito.mock(Subject.class);
+ subjectStatic.when(() -> Subject.doAs(Mockito.any(), Mockito.any(PrivilegedAction.class)))
+ .thenAnswer(inv -> {
+ PrivilegedAction<?> action = inv.getArgument(1);
+ return action.run();
+ });
+ Connection connection = Mockito.mock(Connection.class);
+ Admin admin = Mockito.mock(Admin.class);
+ connFactoryStatic.when(() -> ConnectionFactory.createConnection(Mockito.any(Configuration.class)))
+ .thenReturn(connection);
+ Mockito.when(connection.getAdmin()).thenReturn(admin);
+ Mockito.when(admin.listTableDescriptors(Mockito.any(Pattern.class))).thenReturn(new ArrayList<>());
+
+ HBaseClient client = new TestableHBaseClient("svc", props, subject);
+
+ List<String> validPatterns = Arrays.asList("user*", "test?", "table_name", "prefix-*", "test{user}");
+ for (String pattern : validPatterns) {
+ List<String> result = client.getTableList(pattern, null);
+ assertNotNull(result, "Valid pattern should not throw exception: " + pattern);
+ }
+ }
+ }
+
+    @Test
+    public void test26_getTableList_wildcardReplacement() throws Exception {
+        // Fix: renamed from test19_getTableList_wildcardReplacement — the test19
+        // number prefix was already used by test19_validatePattern_validWildcards.
+        // Verifies that getTableList("atlas.*") is normalized to the glob
+        // "atlas*" and converted to the safe regex "^atlas.*$" before use.
+        Map<String, String> props = new HashMap<>();
+        props.put("username", "user");
+
+        try (MockedStatic<HBaseConfiguration> confStatic = Mockito.mockStatic(HBaseConfiguration.class);
+                MockedStatic<Subject> subjectStatic = Mockito.mockStatic(Subject.class);
+                MockedStatic<ConnectionFactory> connFactoryStatic = Mockito.mockStatic(ConnectionFactory.class)) {
+            Configuration conf = Mockito.mock(Configuration.class);
+            confStatic.when(HBaseConfiguration::create).thenReturn(conf);
+
+            // Run the PrivilegedAction inline instead of under a real Subject.
+            Subject subject = Mockito.mock(Subject.class);
+            subjectStatic.when(() -> Subject.doAs(Mockito.any(), Mockito.any(PrivilegedAction.class)))
+                    .thenAnswer(inv -> {
+                        PrivilegedAction<?> action = inv.getArgument(1);
+                        return action.run();
+                    });
+
+            Connection connection = Mockito.mock(Connection.class);
+            Admin admin = Mockito.mock(Admin.class);
+            connFactoryStatic.when(() -> ConnectionFactory.createConnection(Mockito.any(Configuration.class)))
+                    .thenReturn(connection);
+            Mockito.when(connection.getAdmin()).thenReturn(admin);
+
+            TableDescriptor td1 = Mockito.mock(TableDescriptor.class);
+            TableName tn1 = TableName.valueOf("atlas_test");
+            Mockito.when(td1.getTableName()).thenReturn(tn1);
+
+            // Only answer for the expected sanitized pattern "^atlas.*$".
+            Mockito.when(admin.listTableDescriptors(Mockito.any(Pattern.class))).thenAnswer(inv -> {
+                Pattern p = inv.getArgument(0);
+                if (p.pattern().equals("^atlas.*$")) {
+                    return Collections.singletonList(td1);
+                }
+                return Collections.emptyList();
+            });
+
+            HBaseClient client = new TestableHBaseClient("svc", props, subject);
+
+            List<String> ret = client.getTableList("atlas.*", null);
+            assertEquals(Collections.singletonList("atlas_test"), ret);
+        }
+    }
+
+ @Test
+ public void test20_getColumnFamilyList_wildcardReplacement() throws Exception {
+ // Verifies that getColumnFamilyList("cf.*") is normalized to the glob
+ // "cf*" and converted to a regex that matches the mocked family "cf_test".
+ Map<String, String> props = new HashMap<>();
+ props.put("username", "user");
+
+ try (MockedStatic<HBaseConfiguration> confStatic = Mockito.mockStatic(HBaseConfiguration.class);
+ MockedStatic<Subject> subjectStatic = Mockito.mockStatic(Subject.class);
+ MockedStatic<ConnectionFactory> connFactoryStatic = Mockito.mockStatic(ConnectionFactory.class)) {
+ Configuration conf = Mockito.mock(Configuration.class);
+ confStatic.when(HBaseConfiguration::create).thenReturn(conf);
+
+ // Run the PrivilegedAction inline instead of under a real Subject.
+ Subject subject = Mockito.mock(Subject.class);
+ subjectStatic.when(() -> Subject.doAs(Mockito.any(), Mockito.any(PrivilegedAction.class)))
+ .thenAnswer(inv -> {
+ PrivilegedAction<?> action = inv.getArgument(1);
+ return action.run();
+ });
+
+ Connection connection = Mockito.mock(Connection.class);
+ Admin admin = Mockito.mock(Admin.class);
+ connFactoryStatic.when(() -> ConnectionFactory.createConnection(Mockito.any(Configuration.class)))
+ .thenReturn(connection);
+ Mockito.when(connection.getAdmin()).thenReturn(admin);
+
+ TableDescriptor td = Mockito.mock(TableDescriptor.class);
+ TableName tn = TableName.valueOf("t1");
+ Mockito.when(admin.getDescriptor(tn)).thenReturn(td);
+ ColumnFamilyDescriptor cfd1 = mockCfd("cf_test");
+ Mockito.when(td.getColumnFamilies()).thenReturn(new ColumnFamilyDescriptor[] {cfd1});
+
+ HBaseClient client = new TestableHBaseClient("svc", props, subject);
+
+ List<String> tables = new ArrayList<>(Collections.singletonList("t1"));
+ // "cf.*" should be converted to "cf*" then to "^cf.*$" which matches "cf_test"
+ List<String> ret = client.getColumnFamilyList("cf.*", tables, null);
+ assertEquals(Collections.singletonList("cf_test"), ret);
+ }
+ }
+
+ @Test
+ public void test21_validatePattern_rejectsReDoSPatterns() throws Exception {
+ // Verifies ReDoS-shaped inputs (nested quantifiers, alternation, bounded
+ // repetition) are rejected by validation before any regex is compiled.
+ Map<String, String> props = new HashMap<>();
+ props.put("username", "user");
+
+ try (MockedStatic<HBaseConfiguration> confStatic = Mockito.mockStatic(HBaseConfiguration.class);
+ MockedStatic<Subject> subjectStatic = Mockito.mockStatic(Subject.class)) {
+ Configuration conf = Mockito.mock(Configuration.class);
+ confStatic.when(HBaseConfiguration::create).thenReturn(conf);
+
+ // Run the PrivilegedAction inline instead of under a real Subject.
+ Subject subject = Mockito.mock(Subject.class);
+ subjectStatic.when(() -> Subject.doAs(Mockito.any(), Mockito.any(PrivilegedAction.class)))
+ .thenAnswer(inv -> {
+ PrivilegedAction<?> action = inv.getArgument(1);
+ return action.run();
+ });
+
+ HBaseClient client = new TestableHBaseClient("svc", props, subject);
+
+ List<String> redosPatterns = Arrays.asList("(a+)+", "(a|a)*", "(a+)+$", "a{100,200}", "(x+x+)+y");
+
+ for (String pattern : redosPatterns) {
+ HadoopException ex = assertThrows(HadoopException.class,
+ () -> client.getTableList(pattern, null),
+ "ReDoS pattern should be rejected: " + pattern);
+ String msg = ex.getMessage();
+ assertTrue(msg != null && msg.contains("Invalid") && msg.contains("Only alphanumeric"), "Error should indicate invalid pattern for: " + pattern + ", but got: " + msg);
+ }
+ }
+ }
+
+ @Test
+ public void test22_validatePattern_rejectsComplexRegex() throws Exception {
+ // Verifies regex constructs outside the allowed character set
+ // (groups, '+', '|', shell-style "$(...)") are rejected by validation.
+ Map<String, String> props = new HashMap<>();
+ props.put("username", "user");
+
+ try (MockedStatic<HBaseConfiguration> confStatic = Mockito.mockStatic(HBaseConfiguration.class);
+ MockedStatic<Subject> subjectStatic = Mockito.mockStatic(Subject.class)) {
+ Configuration conf = Mockito.mock(Configuration.class);
+ confStatic.when(HBaseConfiguration::create).thenReturn(conf);
+
+ // Run the PrivilegedAction inline instead of under a real Subject.
+ Subject subject = Mockito.mock(Subject.class);
+ subjectStatic.when(() -> Subject.doAs(Mockito.any(), Mockito.any(PrivilegedAction.class)))
+ .thenAnswer(inv -> {
+ PrivilegedAction<?> action = inv.getArgument(1);
+ return action.run();
+ });
+
+ HBaseClient client = new TestableHBaseClient("svc", props, subject);
+
+ List<String> maliciousPatterns = Arrays.asList("test(abc)", "a+b", "x|y", "$(command)");
+
+ for (String pattern : maliciousPatterns) {
+ HadoopException ex = assertThrows(HadoopException.class,
+ () -> client.getTableList(pattern, null),
+ "Complex regex should be rejected: " + pattern);
+ assertTrue(ex.getMessage().contains("Invalid") && ex.getMessage().contains("Only alphanumeric"), "Error should indicate invalid pattern for: " + pattern);
+ }
+ }
+ }
+
+ @Test
+ public void test23_validatePattern_rejectsInjectionAttempts() throws Exception {
+ // Verifies SQL/path/XSS/control-character injection payloads are rejected
+ // by pattern validation before reaching HBase.
+ Map<String, String> props = new HashMap<>();
+ props.put("username", "user");
+
+ try (MockedStatic<HBaseConfiguration> confStatic = Mockito.mockStatic(HBaseConfiguration.class);
+ MockedStatic<Subject> subjectStatic = Mockito.mockStatic(Subject.class)) {
+ Configuration conf = Mockito.mock(Configuration.class);
+ confStatic.when(HBaseConfiguration::create).thenReturn(conf);
+
+ // Run the PrivilegedAction inline instead of under a real Subject.
+ Subject subject = Mockito.mock(Subject.class);
+ subjectStatic.when(() -> Subject.doAs(Mockito.any(), Mockito.any(PrivilegedAction.class)))
+ .thenAnswer(inv -> {
+ PrivilegedAction<?> action = inv.getArgument(1);
+ return action.run();
+ });
+
+ HBaseClient client = new TestableHBaseClient("svc", props, subject);
+
+ List<String> injectionAttempts = Arrays.asList("'; DROP TABLE users; --", "../../../etc/passwd", "test<script>alert(1)</script>", "table\nname", "test\0null");
+
+ for (String pattern : injectionAttempts) {
+ HadoopException ex = assertThrows(HadoopException.class,
+ () -> client.getTableList(pattern, null),
+ "Injection attempt should be rejected: " + pattern);
+ assertTrue(ex.getMessage().contains("Invalid"),
+ "Error should indicate invalid pattern for: " + pattern);
+ }
+ }
+ }
+
+ @Test
+ public void test24_columnFamilyMatching_rejectsReDoS() throws Exception {
+ // Verifies the column-family lookup path applies the same validation as
+ // the table lookup path: ReDoS-shaped patterns must be rejected.
+ Map<String, String> props = new HashMap<>();
+ props.put("username", "user");
+
+ try (MockedStatic<HBaseConfiguration> confStatic = Mockito.mockStatic(HBaseConfiguration.class);
+ MockedStatic<Subject> subjectStatic = Mockito.mockStatic(Subject.class)) {
+ Configuration conf = Mockito.mock(Configuration.class);
+ confStatic.when(HBaseConfiguration::create).thenReturn(conf);
+
+ // Run the PrivilegedAction inline instead of under a real Subject.
+ Subject subject = Mockito.mock(Subject.class);
+ subjectStatic.when(() -> Subject.doAs(Mockito.any(), Mockito.any(PrivilegedAction.class)))
+ .thenAnswer(inv -> {
+ PrivilegedAction<?> action = inv.getArgument(1);
+ return action.run();
+ });
+
+ HBaseClient client = new TestableHBaseClient("svc", props, subject);
+
+ List<String> tables = Arrays.asList("table1");
+ List<String> redosPatterns = Arrays.asList("(a+)+", "(x|x)*");
+
+ for (String pattern : redosPatterns) {
+ HadoopException ex = assertThrows(HadoopException.class,
+ () -> client.getColumnFamilyList(pattern, tables, null),
+ "ReDoS pattern should be rejected in column family: " + pattern);
+ assertTrue(ex.getMessage().contains("Invalid") && ex.getMessage().contains("Only alphanumeric"),
+ "Error should indicate invalid pattern for: " + pattern);
+ }
+ }
+ }
+
+ @Test
+ public void test25_convertWildcardToRegex_correctConversion() throws Exception {
+ // Verifies end-to-end that the regex handed to HBase actually matches the
+ // intended table: "test*" must match "test_table" via the compiled Pattern.
+ Map<String, String> props = new HashMap<>();
+ props.put("username", "user");
+
+ try (MockedStatic<HBaseConfiguration> confStatic = Mockito.mockStatic(HBaseConfiguration.class);
+ MockedStatic<Subject> subjectStatic = Mockito.mockStatic(Subject.class);
+ MockedStatic<ConnectionFactory> connFactoryStatic = Mockito.mockStatic(ConnectionFactory.class)) {
+ Configuration conf = Mockito.mock(Configuration.class);
+ confStatic.when(HBaseConfiguration::create).thenReturn(conf);
+
+ // Run the PrivilegedAction inline instead of under a real Subject.
+ Subject subject = Mockito.mock(Subject.class);
+ subjectStatic.when(() -> Subject.doAs(Mockito.any(), Mockito.any(PrivilegedAction.class)))
+ .thenAnswer(inv -> {
+ PrivilegedAction<?> action = inv.getArgument(1);
+ return action.run();
+ });
+
+ Connection connection = Mockito.mock(Connection.class);
+ Admin admin = Mockito.mock(Admin.class);
+ connFactoryStatic.when(() -> ConnectionFactory.createConnection(Mockito.any(Configuration.class)))
+ .thenReturn(connection);
+ Mockito.when(connection.getAdmin()).thenReturn(admin);
+
+ TableDescriptor td1 = Mockito.mock(TableDescriptor.class);
+ TableName tn1 = TableName.valueOf("test_table");
+ Mockito.when(td1.getTableName()).thenReturn(tn1);
+ // Answer based on whether the received Pattern matches the mock table name.
+ Mockito.when(admin.listTableDescriptors(Mockito.any(Pattern.class))).thenAnswer(inv -> {
+ Pattern pattern = inv.getArgument(0);
+ List<TableDescriptor> result = new ArrayList<>();
+ if (pattern.matcher("test_table").matches()) {
+ result.add(td1);
+ }
+ return result;
+ });
+
+ HBaseClient client = new TestableHBaseClient("svc", props, subject);
+
+ List<String> result = client.getTableList("test*", null);
+ assertEquals(1, result.size());
+ assertTrue(result.contains("test_table"));
+ }
+ }
+
private static class TestableHBaseClient extends HBaseClient {
private final Subject testSubject;
diff --git a/hive-agent/src/main/java/org/apache/ranger/services/hive/client/HiveClient.java b/hive-agent/src/main/java/org/apache/ranger/services/hive/client/HiveClient.java
index aa40b97..f739fef 100644
--- a/hive-agent/src/main/java/org/apache/ranger/services/hive/client/HiveClient.java
+++ b/hive-agent/src/main/java/org/apache/ranger/services/hive/client/HiveClient.java
@@ -261,6 +261,8 @@ private List<String> getDBListFromHM(String databaseMatching, List<String> dbLis
List<String> ret = new ArrayList<>();
+ validateSqlIdentifier(databaseMatching, "database pattern");
+
try {
if (hiveClient != null) {
List<String> hiveDBList;
@@ -303,20 +305,15 @@ private List<String> getDBList(String databaseMatching, List<String> dbList) thr
List<String> ret = new ArrayList<>();
if (con != null) {
- Statement stat = null;
- ResultSet rs = null;
- String sql = "show databases";
-
- if (databaseMatching != null && !databaseMatching.isEmpty()) {
- sql = sql + " like \"" + databaseMatching + "\"";
- }
+ ResultSet rs = null;
try {
- stat = con.createStatement();
- rs = stat.executeQuery(sql);
+ validateSqlIdentifier(databaseMatching, "database pattern");
+ String schemaPattern = convertToSqlPattern(databaseMatching);
+ rs = con.getMetaData().getSchemas(null, schemaPattern);
while (rs.next()) {
- String dbName = rs.getString(1);
+ String dbName = rs.getString("TABLE_SCHEM");
if (dbList != null && dbList.contains(dbName)) {
continue;
@@ -325,7 +322,7 @@ private List<String> getDBList(String databaseMatching, List<String> dbList) thr
ret.add(dbName);
}
} catch (SQLTimeoutException sqlt) {
- String msgDesc = "Time Out, Unable to execute SQL [" + sql + "].";
+ String msgDesc = "Time Out, Unable to retrieve database list.";
HadoopException hdpException = new HadoopException(msgDesc, sqlt);
hdpException.generateResponseDataMap(false, getMessage(sqlt), msgDesc + ERR_MSG, null, null);
@@ -334,7 +331,7 @@ private List<String> getDBList(String databaseMatching, List<String> dbList) thr
throw hdpException;
} catch (SQLException sqle) {
- String msgDesc = "Unable to execute SQL [" + sql + "].";
+ String msgDesc = "Unable to retrieve database list.";
HadoopException hdpException = new HadoopException(msgDesc, sqle);
hdpException.generateResponseDataMap(false, getMessage(sqle), msgDesc + ERR_MSG, null, null);
@@ -342,9 +339,10 @@ private List<String> getDBList(String databaseMatching, List<String> dbList) thr
LOG.debug("<== HiveClient.getDBList() Error : ", sqle);
throw hdpException;
+ } catch (HadoopException he) {
+ throw he;
} finally {
close(rs);
- close(stat);
}
}
@@ -358,8 +356,11 @@ private List<String> getTblListFromHM(String tableNameMatching, List<String> dbL
List<String> ret = new ArrayList<>();
+ validateSqlIdentifier(tableNameMatching, "table pattern");
+
if (hiveClient != null && dbList != null && !dbList.isEmpty()) {
for (String dbName : dbList) {
+ validateSqlIdentifier(dbName, "database name");
try {
List<String> hiveTblList = hiveClient.getTables(dbName, tableNameMatching);
@@ -394,55 +395,31 @@ private List<String> getTblList(String tableNameMatching, List<String> dbList, L
List<String> ret = new ArrayList<>();
if (con != null) {
- Statement stat = null;
- ResultSet rs = null;
- String sql = null;
+ ResultSet rs = null;
try {
+ validateSqlIdentifier(tableNameMatching, "table pattern");
if (dbList != null && !dbList.isEmpty()) {
for (String db : dbList) {
- sql = "use " + db;
+ validateSqlIdentifier(db, "database name");
+ String tablePattern = convertToSqlPattern(tableNameMatching);
+ rs = con.getMetaData().getTables(null, db, tablePattern, new String[] {"TABLE", "VIEW"});
- try {
- stat = con.createStatement();
+ while (rs.next()) {
+ String tblName = rs.getString("TABLE_NAME");
- stat.execute(sql);
- } finally {
- close(stat);
-
- stat = null;
- }
-
- sql = "show tables ";
-
- if (tableNameMatching != null && !tableNameMatching.isEmpty()) {
- sql = sql + " like \"" + tableNameMatching + "\"";
- }
-
- try {
- stat = con.createStatement();
- rs = stat.executeQuery(sql);
-
- while (rs.next()) {
- String tblName = rs.getString(1);
-
- if (tblList != null && tblList.contains(tblName)) {
- continue;
- }
-
- ret.add(tblName);
+ if (tblList != null && tblList.contains(tblName)) {
+ continue;
}
- } finally {
- close(rs);
- close(stat);
- rs = null;
- stat = null;
+ ret.add(tblName);
}
+ close(rs);
+ rs = null;
}
}
} catch (SQLTimeoutException sqlt) {
- String msgDesc = "Time Out, Unable to execute SQL [" + sql + "].";
+ String msgDesc = "Time Out, Unable to retrieve table list.";
HadoopException hdpException = new HadoopException(msgDesc, sqlt);
hdpException.generateResponseDataMap(false, getMessage(sqlt), msgDesc + ERR_MSG, null, null);
@@ -451,7 +428,7 @@ private List<String> getTblList(String tableNameMatching, List<String> dbList, L
throw hdpException;
} catch (SQLException sqle) {
- String msgDesc = "Unable to execute SQL [" + sql + "].";
+ String msgDesc = "Unable to retrieve table list.";
HadoopException hdpException = new HadoopException(msgDesc, sqle);
hdpException.generateResponseDataMap(false, getMessage(sqle), msgDesc + ERR_MSG, null, null);
@@ -459,6 +436,10 @@ private List<String> getTblList(String tableNameMatching, List<String> dbList, L
LOG.debug("<== HiveClient.getTblList() Error : ", sqle);
throw hdpException;
+ } catch (HadoopException he) {
+ throw he;
+ } finally {
+ close(rs);
}
}
@@ -473,13 +454,17 @@ private List<String> getClmListFromHM(String columnNameMatching, List<String> db
List<String> ret = new ArrayList<>();
String columnNameMatchingRegEx = null;
+ validateSqlIdentifier(columnNameMatching, "column pattern");
+
if (columnNameMatching != null && !columnNameMatching.isEmpty()) {
columnNameMatchingRegEx = columnNameMatching;
}
if (hiveClient != null && dbList != null && !dbList.isEmpty() && tblList != null && !tblList.isEmpty()) {
for (String db : dbList) {
+ validateSqlIdentifier(db, "database name");
for (String tbl : tblList) {
+ validateSqlIdentifier(tbl, "table name");
try {
List<FieldSchema> hiveSch = hiveClient.getFields(db, tbl);
@@ -529,30 +514,20 @@ private List<String> getClmList(String columnNameMatching, List<String> dbList,
columnNameMatchingRegEx = columnNameMatching;
}
- Statement stat = null;
- ResultSet rs = null;
- String sql = null;
+ ResultSet rs = null;
- if (dbList != null && !dbList.isEmpty() && tblList != null && !tblList.isEmpty()) {
- for (String db : dbList) {
- for (String tbl : tblList) {
- try {
- sql = "use " + db;
-
- try {
- stat = con.createStatement();
-
- stat.execute(sql);
- } finally {
- close(stat);
- }
-
- sql = "describe " + tbl;
- stat = con.createStatement();
- rs = stat.executeQuery(sql);
+ try {
+ validateSqlIdentifier(columnNameMatching, "column pattern");
+ if (dbList != null && !dbList.isEmpty() && tblList != null && !tblList.isEmpty()) {
+ for (String db : dbList) {
+ validateSqlIdentifier(db, "database name");
+ for (String tbl : tblList) {
+ validateSqlIdentifier(tbl, "table name");
+ String columnPattern = convertToSqlPattern(columnNameMatching);
+ rs = con.getMetaData().getColumns(null, db, tbl, columnPattern);
while (rs.next()) {
- String columnName = rs.getString(1);
+ String columnName = rs.getString("COLUMN_NAME");
if (colList != null && colList.contains(columnName)) {
continue;
@@ -564,30 +539,33 @@ private List<String> getClmList(String columnNameMatching, List<String> dbList,
ret.add(columnName);
}
}
- } catch (SQLTimeoutException sqlt) {
- String msgDesc = "Time Out, Unable to execute SQL [" + sql + "].";
- HadoopException hdpException = new HadoopException(msgDesc, sqlt);
-
- hdpException.generateResponseDataMap(false, getMessage(sqlt), msgDesc + ERR_MSG, null, null);
-
- LOG.debug("<== HiveClient.getClmList() Error : ", sqlt);
-
- throw hdpException;
- } catch (SQLException sqle) {
- String msgDesc = "Unable to execute SQL [" + sql + "].";
- HadoopException hdpException = new HadoopException(msgDesc, sqle);
-
- hdpException.generateResponseDataMap(false, getMessage(sqle), msgDesc + ERR_MSG, null, null);
-
- LOG.debug("<== HiveClient.getClmList() Error : ", sqle);
-
- throw hdpException;
- } finally {
close(rs);
- close(stat);
+ rs = null;
}
}
}
+ } catch (SQLTimeoutException sqlt) {
+ String msgDesc = "Time Out, Unable to retrieve column list.";
+ HadoopException hdpException = new HadoopException(msgDesc, sqlt);
+
+ hdpException.generateResponseDataMap(false, getMessage(sqlt), msgDesc + ERR_MSG, null, null);
+
+ LOG.debug("<== HiveClient.getClmList() Error : ", sqlt);
+
+ throw hdpException;
+ } catch (SQLException sqle) {
+ String msgDesc = "Unable to retrieve column list.";
+ HadoopException hdpException = new HadoopException(msgDesc, sqle);
+
+ hdpException.generateResponseDataMap(false, getMessage(sqle), msgDesc + ERR_MSG, null, null);
+
+ LOG.debug("<== HiveClient.getClmList() Error : ", sqle);
+
+ throw hdpException;
+ } catch (HadoopException he) {
+ throw he;
+ } finally {
+ close(rs);
}
}
diff --git a/hive-agent/src/test/java/org/apache/ranger/services/hive/client/TestHiveClient.java b/hive-agent/src/test/java/org/apache/ranger/services/hive/client/TestHiveClient.java
index fc74b95..69bc3e3 100644
--- a/hive-agent/src/test/java/org/apache/ranger/services/hive/client/TestHiveClient.java
+++ b/hive-agent/src/test/java/org/apache/ranger/services/hive/client/TestHiveClient.java
@@ -41,12 +41,12 @@
import java.security.Permission;
import java.security.PrivilegedExceptionAction;
import java.sql.Connection;
+import java.sql.DatabaseMetaData;
import java.sql.Driver;
import java.sql.DriverPropertyInfo;
import java.sql.ResultSet;
import java.sql.SQLFeatureNotSupportedException;
import java.sql.SQLTimeoutException;
-import java.sql.Statement;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
@@ -220,12 +220,12 @@ public void test11_getDatabaseList_jdbcPath() throws Exception {
Field fCon = HiveClient.class.getDeclaredField("con");
fCon.setAccessible(true);
Connection con = Mockito.mock(Connection.class);
- Statement stat = Mockito.mock(Statement.class);
+ DatabaseMetaData metadata = Mockito.mock(DatabaseMetaData.class);
ResultSet rs = Mockito.mock(ResultSet.class);
- when(con.createStatement()).thenReturn(stat);
- when(stat.executeQuery(Mockito.anyString())).thenReturn(rs);
+ when(con.getMetaData()).thenReturn(metadata);
+ when(metadata.getSchemas(Mockito.isNull(), Mockito.anyString())).thenReturn(rs);
when(rs.next()).thenReturn(true, false);
- when(rs.getString(1)).thenReturn("db1");
+ when(rs.getString("TABLE_SCHEM")).thenReturn("db1");
fCon.set(client, con);
List<String> out = client.getDatabaseList("db*", null);
assertEquals(Collections.singletonList("db1"), out);
@@ -251,17 +251,15 @@ public void test13_getTableList_jdbcPath_excludesAndPattern() throws Exception {
Field fCon = HiveClient.class.getDeclaredField("con");
fCon.setAccessible(true);
Connection con = Mockito.mock(Connection.class);
- Statement stat = Mockito.mock(Statement.class);
+ DatabaseMetaData metadata = Mockito.mock(DatabaseMetaData.class);
ResultSet rs = Mockito.mock(ResultSet.class);
- when(con.createStatement()).thenReturn(stat);
- when(stat.execute(Mockito.eq("use db1"))).thenReturn(true);
- when(stat.executeQuery(Mockito.anyString())).thenReturn(rs);
+ when(con.getMetaData()).thenReturn(metadata);
+ when(metadata.getTables(Mockito.isNull(), Mockito.eq("db1"), Mockito.anyString(), Mockito.any())).thenReturn(rs);
when(rs.next()).thenReturn(true, true, false);
- when(rs.getString(1)).thenReturn("t1", "t2");
+ when(rs.getString("TABLE_NAME")).thenReturn("t1", "t2");
fCon.set(client, con);
List<String> out = client.getTableList("t*", Collections.singletonList("db1"), Collections.singletonList("t2"));
assertEquals(Collections.singletonList("t1"), out);
- Mockito.verify(stat, Mockito.atLeastOnce()).execute(Mockito.eq("use db1"));
}
@Test
@@ -273,10 +271,9 @@ public void test14_getTableList_jdbcPath_timeoutThrowsHadoopException() throws E
Field fCon = HiveClient.class.getDeclaredField("con");
fCon.setAccessible(true);
Connection con = Mockito.mock(Connection.class);
- Statement stat = Mockito.mock(Statement.class);
- when(con.createStatement()).thenReturn(stat);
- when(stat.execute(Mockito.anyString())).thenReturn(true);
- when(stat.executeQuery(Mockito.anyString())).thenThrow(new SQLTimeoutException("timeout"));
+ DatabaseMetaData metadata = Mockito.mock(DatabaseMetaData.class);
+ when(con.getMetaData()).thenReturn(metadata);
+ when(metadata.getTables(Mockito.isNull(), Mockito.anyString(), Mockito.anyString(), Mockito.any())).thenThrow(new SQLTimeoutException("timeout"));
fCon.set(client, con);
assertThrows(HadoopException.class, () -> client.getTableList("t*", Collections.singletonList("db1"), null));
}
@@ -290,13 +287,12 @@ public void test17_getClmList_jdbcPath_excludesAndPattern() throws Exception {
Field fCon = HiveClient.class.getDeclaredField("con");
fCon.setAccessible(true);
Connection con = Mockito.mock(Connection.class);
- Statement stat = Mockito.mock(Statement.class);
+ DatabaseMetaData metadata = Mockito.mock(DatabaseMetaData.class);
ResultSet rs = Mockito.mock(ResultSet.class);
- when(con.createStatement()).thenReturn(stat);
- when(stat.execute(Mockito.eq("use db1"))).thenReturn(true);
- when(stat.executeQuery(Mockito.eq("describe t1"))).thenReturn(rs);
+ when(con.getMetaData()).thenReturn(metadata);
+ when(metadata.getColumns(Mockito.isNull(), Mockito.eq("db1"), Mockito.eq("t1"), Mockito.anyString())).thenReturn(rs);
when(rs.next()).thenReturn(true, true, false);
- when(rs.getString(1)).thenReturn("c1", "c2");
+ when(rs.getString("COLUMN_NAME")).thenReturn("c1", "c2");
fCon.set(client, con);
List<String> out = client.getColumnList("c*", Collections.singletonList("db1"), Collections.singletonList("t1"), Collections.singletonList("c2"));
assertEquals(Collections.singletonList("c1"), out);
@@ -311,10 +307,9 @@ public void test18_getClmList_jdbcPath_timeoutThrowsHadoopException() throws Exc
Field fCon = HiveClient.class.getDeclaredField("con");
fCon.setAccessible(true);
Connection con = Mockito.mock(Connection.class);
- Statement stat = Mockito.mock(Statement.class);
- when(con.createStatement()).thenReturn(stat);
- when(stat.execute(Mockito.anyString())).thenReturn(true);
- when(stat.executeQuery(Mockito.anyString())).thenThrow(new SQLTimeoutException("timeout"));
+ DatabaseMetaData metadata = Mockito.mock(DatabaseMetaData.class);
+ when(con.getMetaData()).thenReturn(metadata);
+ when(metadata.getColumns(Mockito.isNull(), Mockito.anyString(), Mockito.anyString(), Mockito.anyString())).thenThrow(new SQLTimeoutException("timeout"));
fCon.set(client, con);
assertThrows(HadoopException.class, () -> client.getColumnList("c*", Collections.singletonList("db1"), Collections.singletonList("t1"), null));
}
@@ -476,6 +471,146 @@ public void test25_initHive_nonKerberosPath_invokesJdbcInitConnectionAndWraps()
}
}
+ @Test
+ public void test26_validateSqlIdentifier_validInput() throws Exception {
+ NoopHiveClient client = new NoopHiveClient("svc", new HashMap<>());
+ Method m = HiveClient.class.getSuperclass().getDeclaredMethod("validateSqlIdentifier", String.class, String.class);
+ m.setAccessible(true);
+
+ m.invoke(client, "test_db123", "database");
+ m.invoke(client, "table_name", "table");
+ m.invoke(client, "col*", "column pattern");
+ m.invoke(client, "db%", "database pattern");
+ m.invoke(client, "a_b_c_123", "identifier");
+ }
+
+ @Test
+ public void test27_validateSqlIdentifier_sqlInjectionAttempts() throws Exception {
+ NoopHiveClient client = new NoopHiveClient("svc", new HashMap<>());
+ Method m = HiveClient.class.getSuperclass().getDeclaredMethod("validateSqlIdentifier", String.class, String.class);
+ m.setAccessible(true);
+
+ assertThrows(HadoopException.class, () -> {
+ try {
+ m.invoke(client, "test\" OR 1=1 --", "database");
+ } catch (Exception e) {
+ throw e.getCause();
+ }
+ }, "Should reject double quote injection");
+
+ assertThrows(HadoopException.class, () -> {
+ try {
+ m.invoke(client, "testdb; DROP TABLE users; --", "database");
+ } catch (Exception e) {
+ throw e.getCause();
+ }
+ }, "Should reject semicolon command injection");
+
+ assertThrows(HadoopException.class, () -> {
+ try {
+ m.invoke(client, "test'; DROP DATABASE production; --", "table");
+ } catch (Exception e) {
+ throw e.getCause();
+ }
+ }, "Should reject single quote command injection");
+
+ assertThrows(HadoopException.class, () -> {
+ try {
+ m.invoke(client, "test\n--malicious", "identifier");
+ } catch (Exception e) {
+ throw e.getCause();
+ }
+ }, "Should reject newline injection");
+
+ assertThrows(HadoopException.class, () -> {
+ try {
+ m.invoke(client, "test`malicious`", "identifier");
+ } catch (Exception e) {
+ throw e.getCause();
+ }
+ }, "Should reject backtick injection");
+
+ assertThrows(HadoopException.class, () -> {
+ try {
+ m.invoke(client, "test$(whoami)", "identifier");
+ } catch (Exception e) {
+ throw e.getCause();
+ }
+ }, "Should reject shell command injection");
+
+ assertThrows(HadoopException.class, () -> {
+ try {
+ m.invoke(client, "../../../etc/passwd", "identifier");
+ } catch (Exception e) {
+ throw e.getCause();
+ }
+ }, "Should reject path traversal");
+ }
+
+ @Test
+ public void test28_convertToSqlPattern_convertsWildcards() throws Exception {
+ NoopHiveClient client = new NoopHiveClient("svc", new HashMap<>());
+ Method m = HiveClient.class.getSuperclass().getDeclaredMethod("convertToSqlPattern", String.class);
+ m.setAccessible(true);
+
+ String result1 = (String) m.invoke(client, "test*");
+ assertEquals("test%", result1, "Should convert * to %");
+
+ String result2 = (String) m.invoke(client, "*");
+ assertEquals("%", result2, "Should convert single * to %");
+
+ String result3 = (String) m.invoke(client, "test*pattern*");
+ assertEquals("test%pattern%", result3, "Should convert multiple *");
+
+ String result4 = (String) m.invoke(client, (Object) null);
+ assertEquals("%", result4, "Should handle null as %");
+
+ String result5 = (String) m.invoke(client, "");
+ assertEquals("%", result5, "Should handle empty string as %");
+ }
+
+ @Test
+ public void test29_getDatabaseList_rejectsInjection() throws Exception {
+ NoopHiveClient client = new NoopHiveClient("svc", new HashMap<>());
+ Field fFlag = HiveClient.class.getDeclaredField("enableHiveMetastoreLookup");
+ fFlag.setAccessible(true);
+ fFlag.set(client, false);
+ Field fCon = HiveClient.class.getDeclaredField("con");
+ fCon.setAccessible(true);
+ Connection con = Mockito.mock(Connection.class);
+ fCon.set(client, con);
+
+ assertThrows(HadoopException.class, () -> client.getDatabaseList("testdb\"; DROP DATABASE production; --", null), "Should reject SQL injection in database pattern");
+ }
+
+ @Test
+ public void test30_getTableList_rejectsInjectionInDbName() throws Exception {
+ NoopHiveClient client = new NoopHiveClient("svc", new HashMap<>());
+ Field fFlag = HiveClient.class.getDeclaredField("enableHiveMetastoreLookup");
+ fFlag.setAccessible(true);
+ fFlag.set(client, false);
+ Field fCon = HiveClient.class.getDeclaredField("con");
+ fCon.setAccessible(true);
+ Connection con = Mockito.mock(Connection.class);
+ fCon.set(client, con);
+
+ assertThrows(HadoopException.class, () -> client.getTableList("valid", Collections.singletonList("testdb; DROP TABLE users; --"), null), "Should reject SQL injection in database name");
+ }
+
+ @Test
+ public void test31_getColumnList_rejectsInjectionInTableName() throws Exception {
+ NoopHiveClient client = new NoopHiveClient("svc", new HashMap<>());
+ Field fFlag = HiveClient.class.getDeclaredField("enableHiveMetastoreLookup");
+ fFlag.setAccessible(true);
+ fFlag.set(client, false);
+ Field fCon = HiveClient.class.getDeclaredField("con");
+ fCon.setAccessible(true);
+ Connection con = Mockito.mock(Connection.class);
+ fCon.set(client, con);
+
+ assertThrows(HadoopException.class, () -> client.getColumnList("valid", Collections.singletonList("db1"), Collections.singletonList("test'; DROP TABLE users; --"), null), "Should reject SQL injection in table name");
+ }
+
public static class NoopHiveClient extends HiveClient {
public boolean initCalled;
diff --git a/knox-agent/src/main/java/org/apache/ranger/services/knox/client/KnoxClient.java b/knox-agent/src/main/java/org/apache/ranger/services/knox/client/KnoxClient.java
index 7e06de9..ec6d79e 100644
--- a/knox-agent/src/main/java/org/apache/ranger/services/knox/client/KnoxClient.java
+++ b/knox-agent/src/main/java/org/apache/ranger/services/knox/client/KnoxClient.java
@@ -340,6 +340,7 @@ public List<String> getServiceList(List<String> knoxTopologyList, String service
client.addFilter(new HTTPBasicAuthFilter(userName, decryptedPwd));
for (String topologyName : knoxTopologyList) {
+ validateResourceName(topologyName, "topology name");
WebResource webResource = client.resource(knoxUrl + "/" + topologyName);
response = webResource.accept(EXPECTED_MIME_TYPE).get(ClientResponse.class);
@@ -420,4 +421,32 @@ public List<String> getServiceList(List<String> knoxTopologyList, String service
}
return serviceList;
}
+
+ private void validateResourceName(String resourceName, String resourceType) {
+ if (resourceName == null) {
+ return;
+ }
+
+ if (resourceName.contains("..") || resourceName.contains("//") || resourceName.contains("\\")) {
+ String msgDesc = "Invalid " + resourceType + ": [" + resourceName + "]. Path traversal patterns are not allowed.";
+ HadoopException hdpException = new HadoopException(msgDesc);
+
+ hdpException.generateResponseDataMap(false, msgDesc, msgDesc + ERROR_MSG, null, null);
+
+ LOG.error(msgDesc);
+
+ throw hdpException;
+ }
+
+ if (!resourceName.matches("^[a-zA-Z0-9_.*\\-]+$")) {
+ String msgDesc = "Invalid " + resourceType + ": [" + resourceName + "]. Only alphanumeric characters, dots, underscores, hyphens, and wildcards are allowed.";
+ HadoopException hdpException = new HadoopException(msgDesc);
+
+ hdpException.generateResponseDataMap(false, msgDesc, msgDesc + ERROR_MSG, null, null);
+
+ LOG.error(msgDesc);
+
+ throw hdpException;
+ }
+ }
}
diff --git a/knox-agent/src/test/java/org/apache/ranger/services/knox/client/TestKnoxClient.java b/knox-agent/src/test/java/org/apache/ranger/services/knox/client/TestKnoxClient.java
index e70f7b6..8d3c7aa 100644
--- a/knox-agent/src/test/java/org/apache/ranger/services/knox/client/TestKnoxClient.java
+++ b/knox-agent/src/test/java/org/apache/ranger/services/knox/client/TestKnoxClient.java
@@ -37,6 +37,7 @@
import java.io.ByteArrayOutputStream;
import java.io.PrintStream;
+import java.lang.reflect.Method;
import java.security.Permission;
import java.util.ArrayList;
import java.util.Arrays;
@@ -435,4 +436,85 @@ public void test20_main_validArgs_happyPath_printsServices() {
System.setOut(origOut);
}
}
+
+ @Test
+ public void test12_validateResourceName_rejectsPathTraversal() throws Exception {
+ KnoxClient client = new KnoxClient("https://localhost:8443/gateway/admin/api/v1/topologies", "admin", "pwd");
+
+ List<String> pathTraversalInputs = Arrays.asList("../etc/passwd", "../../sensitive", "topology/../admin", "topology//malicious", "test\\windows\\path", "..\\..\\config");
+
+ for (String input : pathTraversalInputs) {
+ HadoopException ex = Assertions.assertThrows(HadoopException.class,
+ () -> invokeValidateResourceName(client, input, "topology name"),
+ "Path traversal should be rejected: " + input);
+ Assertions.assertTrue(ex.getMessage().contains("Path traversal"),
+ "Error should indicate path traversal for: " + input);
+ }
+ }
+
+ @Test
+ public void test13_validateResourceName_rejectsSpecialCharacters() throws Exception {
+ KnoxClient client = new KnoxClient("https://localhost:8443/gateway/admin/api/v1/topologies", "admin", "pwd");
+
+ List<String> invalidInputs = Arrays.asList("'; DROP TABLE users; --", "topology<script>alert(1)</script>", "test@topology", "topology#name", "test!topology", "topology&name", "topology(with)parens", "topology{with}braces", "topology[with]brackets", "topology$name", "topology%encoded", "topology name", "topology\ttab", "topology\nnewline", "topology;rm -rf /", "topology|cat /etc/passwd", "topology`whoami`", "topology$(whoami)");
+
+ for (String input : invalidInputs) {
+ HadoopException ex = Assertions.assertThrows(HadoopException.class,
+ () -> invokeValidateResourceName(client, input, "topology name"), "Special characters should be rejected: " + input);
+ Assertions.assertTrue(ex.getMessage().contains("Invalid"), "Error should indicate invalid input for: " + input);
+ }
+ }
+
+ @Test
+ public void test14_validateResourceName_acceptsValidNames() throws Exception {
+ KnoxClient client = new KnoxClient("https://localhost:8443/gateway/admin/api/v1/topologies", "admin", "pwd");
+
+ List<String> validInputs = Arrays.asList("topology", "topology_name", "topology123", "TOPOLOGY", "Topology_Name_123", "_topology", "topology_", "topology.name", "topology-name", "topology*", "top*");
+
+ for (String input : validInputs) {
+ try {
+ invokeValidateResourceName(client, input, "topology name");
+ } catch (Exception e) {
+ throw new AssertionError("Valid topology name should not throw exception: " + input, e);
+ }
+ }
+ }
+
+ @Test
+ public void test15_validateResourceName_rejectsNullByteInjection() throws Exception {
+ KnoxClient client = new KnoxClient("https://localhost:8443/gateway/admin/api/v1/topologies", "admin", "pwd");
+
+ HadoopException ex = Assertions.assertThrows(HadoopException.class,
+ () -> invokeValidateResourceName(client, "topology\0null", "topology name"));
+ Assertions.assertTrue(ex.getMessage().contains("Invalid"));
+ }
+
+ @Test
+ public void test16_validateResourceName_rejectsUrlEncoded() throws Exception {
+ KnoxClient client = new KnoxClient("https://localhost:8443/gateway/admin/api/v1/topologies", "admin", "pwd");
+
+ List<String> encodedInputs = Arrays.asList("%2e%2e%2f", "topology%00", "test%20space");
+
+ for (String input : encodedInputs) {
+ HadoopException ex = Assertions.assertThrows(HadoopException.class,
+ () -> invokeValidateResourceName(client, input, "topology name"),
+ "URL encoded attack should be rejected: " + input);
+ Assertions.assertTrue(ex.getMessage().contains("Invalid"),
+ "Error should indicate invalid input for: " + input);
+ }
+ }
+
+ private void invokeValidateResourceName(KnoxClient client, String resourceName, String resourceType) throws Exception {
+ Method method = KnoxClient.class.getDeclaredMethod("validateResourceName", String.class, String.class);
+ method.setAccessible(true);
+ try {
+ method.invoke(client, resourceName, resourceType);
+ } catch (java.lang.reflect.InvocationTargetException e) {
+ Throwable cause = e.getCause();
+ if (cause instanceof HadoopException) {
+ throw (HadoopException) cause;
+ }
+ throw e;
+ }
+ }
}
diff --git a/plugin-elasticsearch/pom.xml b/plugin-elasticsearch/pom.xml
index 32f6203..0461cd6 100644
--- a/plugin-elasticsearch/pom.xml
+++ b/plugin-elasticsearch/pom.xml
@@ -79,5 +79,23 @@
</exclusion>
</exclusions>
</dependency>
+ <dependency>
+ <groupId>org.junit.jupiter</groupId>
+ <artifactId>junit-jupiter</artifactId>
+ <version>${junit.jupiter.version}</version>
+ <scope>test</scope>
+ </dependency>
+ <dependency>
+ <groupId>org.mockito</groupId>
+ <artifactId>mockito-inline</artifactId>
+ <version>${mockito.version}</version>
+ <scope>test</scope>
+ </dependency>
+ <dependency>
+ <groupId>org.mockito</groupId>
+ <artifactId>mockito-junit-jupiter</artifactId>
+ <version>${mockito.version}</version>
+ <scope>test</scope>
+ </dependency>
</dependencies>
</project>
diff --git a/plugin-elasticsearch/src/main/java/org/apache/ranger/services/elasticsearch/client/ElasticsearchClient.java b/plugin-elasticsearch/src/main/java/org/apache/ranger/services/elasticsearch/client/ElasticsearchClient.java
index d7802e8..2a5e7aa 100644
--- a/plugin-elasticsearch/src/main/java/org/apache/ranger/services/elasticsearch/client/ElasticsearchClient.java
+++ b/plugin-elasticsearch/src/main/java/org/apache/ranger/services/elasticsearch/client/ElasticsearchClient.java
@@ -133,6 +133,7 @@ public List<String> getIndexList(final String indexMatching, final List<String>
String indexApi;
if (StringUtils.isNotEmpty(indexMatching)) {
+ validateUrlResourceName(indexMatching, "index pattern");
indexApi = '/' + indexMatching;
if (!indexApi.endsWith("*")) {
diff --git a/plugin-elasticsearch/src/test/java/org/apache/ranger/services/elasticsearch/client/TestElasticsearchClient.java b/plugin-elasticsearch/src/test/java/org/apache/ranger/services/elasticsearch/client/TestElasticsearchClient.java
new file mode 100644
index 0000000..46120ee
--- /dev/null
+++ b/plugin-elasticsearch/src/test/java/org/apache/ranger/services/elasticsearch/client/TestElasticsearchClient.java
@@ -0,0 +1,176 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.ranger.services.elasticsearch.client;
+
+import org.apache.ranger.plugin.client.HadoopException;
+import org.junit.jupiter.api.MethodOrderer;
+import org.junit.jupiter.api.Test;
+import org.junit.jupiter.api.TestMethodOrder;
+import org.junit.jupiter.api.extension.ExtendWith;
+import org.mockito.junit.jupiter.MockitoExtension;
+
+import java.lang.reflect.Method;
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+import static org.junit.jupiter.api.Assertions.assertThrows;
+import static org.junit.jupiter.api.Assertions.assertTrue;
+
+@ExtendWith(MockitoExtension.class)
+@TestMethodOrder(MethodOrderer.MethodName.class)
+public class TestElasticsearchClient {
+ @Test
+ public void test01_validateUrlResourceName_rejectsPathTraversal() throws Exception {
+ Map<String, String> configs = new HashMap<>();
+ configs.put("elasticsearch.url", "http://localhost:9200");
+ configs.put("username", "test");
+ ElasticsearchClient client = new ElasticsearchClient("svc", configs);
+
+ List<String> pathTraversalInputs = Arrays.asList(
+ "../etc/passwd",
+ "../../sensitive",
+ "test/../admin",
+ "index//malicious",
+ "test\\windows\\path",
+ "..\\..\\config");
+
+ for (String input : pathTraversalInputs) {
+ HadoopException ex = assertThrows(HadoopException.class,
+ () -> invokeValidateUrlResourceName(client, input, "index pattern"),
+ "Path traversal should be rejected: " + input);
+ assertTrue(ex.getMessage().contains("Path traversal"),
+ "Error should indicate path traversal for: " + input);
+ }
+ }
+
+ @Test
+ public void test02_validateUrlResourceName_rejectsSpecialCharacters() throws Exception {
+ Map<String, String> configs = new HashMap<>();
+ configs.put("elasticsearch.url", "http://localhost:9200");
+ configs.put("username", "test");
+ ElasticsearchClient client = new ElasticsearchClient("svc", configs);
+
+ List<String> invalidInputs = Arrays.asList(
+ "'; DROP TABLE users; --",
+ "index<script>alert(1)</script>",
+ "test@index",
+ "index#name",
+ "test!index",
+ "index&name",
+ "index(with)parens",
+ "index{with}braces",
+ "index[with]brackets",
+ "index$name",
+ "index%encoded",
+ "index name",
+ "index\ttab",
+ "index\nnewline",
+ "index;rm -rf /",
+ "index|cat /etc/passwd",
+ "index`whoami`",
+ "index$(whoami)");
+
+ for (String input : invalidInputs) {
+ HadoopException ex = assertThrows(HadoopException.class,
+ () -> invokeValidateUrlResourceName(client, input, "index pattern"),
+ "Special characters should be rejected: " + input);
+ assertTrue(ex.getMessage().contains("Invalid"),
+ "Error should indicate invalid input for: " + input);
+ }
+ }
+
+ @Test
+ public void test03_validateUrlResourceName_acceptsValidNames() throws Exception {
+ Map<String, String> configs = new HashMap<>();
+ configs.put("elasticsearch.url", "http://localhost:9200");
+ configs.put("username", "test");
+ ElasticsearchClient client = new ElasticsearchClient("svc", configs);
+
+ List<String> validInputs = Arrays.asList(
+ "index",
+ "index_name",
+ "index123",
+ "INDEX",
+ "Index_Name_123",
+ "_index",
+ "index_",
+ "index.name",
+ "index-name",
+ "index*",
+ "idx*");
+
+ for (String input : validInputs) {
+ try {
+ invokeValidateUrlResourceName(client, input, "index pattern");
+ } catch (Exception e) {
+ throw new AssertionError("Valid index name should not throw exception: " + input, e);
+ }
+ }
+ }
+
+ @Test
+ public void test04_validateUrlResourceName_rejectsNullByteInjection() throws Exception {
+ Map<String, String> configs = new HashMap<>();
+ configs.put("elasticsearch.url", "http://localhost:9200");
+ configs.put("username", "test");
+ ElasticsearchClient client = new ElasticsearchClient("svc", configs);
+
+ HadoopException ex = assertThrows(HadoopException.class,
+ () -> invokeValidateUrlResourceName(client, "index\0null", "index pattern"));
+ assertTrue(ex.getMessage().contains("Invalid"));
+ }
+
+ @Test
+ public void test05_validateUrlResourceName_rejectsUrlEncoded() throws Exception {
+ Map<String, String> configs = new HashMap<>();
+ configs.put("elasticsearch.url", "http://localhost:9200");
+ configs.put("username", "test");
+ ElasticsearchClient client = new ElasticsearchClient("svc", configs);
+
+ List<String> encodedInputs = Arrays.asList(
+ "%2e%2e%2f",
+ "index%00",
+ "test%20space");
+
+ for (String input : encodedInputs) {
+ HadoopException ex = assertThrows(HadoopException.class,
+ () -> invokeValidateUrlResourceName(client, input, "index pattern"),
+ "URL encoded attack should be rejected: " + input);
+ assertTrue(ex.getMessage().contains("Invalid"),
+ "Error should indicate invalid input for: " + input);
+ }
+ }
+
+ private void invokeValidateUrlResourceName(ElasticsearchClient client, String resourceName, String resourceType) throws Exception {
+ Method method = ElasticsearchClient.class.getSuperclass().getDeclaredMethod("validateUrlResourceName", String.class, String.class);
+ method.setAccessible(true);
+ try {
+ method.invoke(client, resourceName, resourceType);
+ } catch (java.lang.reflect.InvocationTargetException e) {
+ Throwable cause = e.getCause();
+ if (cause instanceof HadoopException) {
+ throw (HadoopException) cause;
+ }
+ throw e;
+ }
+ }
+}
diff --git a/plugin-presto/pom.xml b/plugin-presto/pom.xml
index 3a01427..15431c7 100644
--- a/plugin-presto/pom.xml
+++ b/plugin-presto/pom.xml
@@ -107,6 +107,18 @@
<version>${junit.jupiter.version}</version>
<scope>test</scope>
</dependency>
+ <dependency>
+ <groupId>org.mockito</groupId>
+ <artifactId>mockito-inline</artifactId>
+ <version>${mockito.version}</version>
+ <scope>test</scope>
+ </dependency>
+ <dependency>
+ <groupId>org.mockito</groupId>
+ <artifactId>mockito-junit-jupiter</artifactId>
+ <version>${mockito.version}</version>
+ <scope>test</scope>
+ </dependency>
</dependencies>
<build>
<testResources>
diff --git a/plugin-presto/src/main/java/org/apache/ranger/services/presto/client/PrestoClient.java b/plugin-presto/src/main/java/org/apache/ranger/services/presto/client/PrestoClient.java
index 2492dd7..c034824 100644
--- a/plugin-presto/src/main/java/org/apache/ranger/services/presto/client/PrestoClient.java
+++ b/plugin-presto/src/main/java/org/apache/ranger/services/presto/client/PrestoClient.java
@@ -19,7 +19,6 @@
package org.apache.ranger.services.presto.client;
import org.apache.commons.io.FilenameUtils;
-import org.apache.commons.lang3.StringUtils;
import org.apache.ranger.plugin.client.BaseClient;
import org.apache.ranger.plugin.client.HadoopConfigHolder;
import org.apache.ranger.plugin.client.HadoopException;
@@ -118,6 +117,8 @@ public List<String> getSchemaList(String needle, List<String> catalogs, List<Str
ret = getSchemas(ndl, cats, shms);
} catch (HadoopException he) {
LOG.error("<== PrestoClient.getSchemaList() :Unable to get the Schema List", he);
+
+ throw he;
}
return ret;
@@ -300,43 +301,41 @@ private List<String> getCatalogs(String needle, List<String> catalogs) throws Ha
List<String> ret = new ArrayList<>();
if (con != null) {
- Statement stat = null;
- ResultSet rs = null;
- String sql = "SHOW CATALOGS";
+ ResultSet rs = null;
try {
- if (needle != null && !needle.isEmpty() && !needle.equals("*")) {
- // Cannot use a prepared statement for this as presto does not support that
- sql += " LIKE '" + escapeSql(needle) + "%'";
- }
-
- stat = con.createStatement();
- rs = stat.executeQuery(sql);
+ validateSqlIdentifier(needle, "catalog pattern");
+ String catalogPattern = convertToSqlPattern(needle);
+ rs = con.getMetaData().getCatalogs();
while (rs.next()) {
- String catalogName = rs.getString(1);
+ String catalogName = rs.getString("TABLE_CAT");
if (catalogs != null && catalogs.contains(catalogName)) {
continue;
}
- ret.add(catalogName);
+ if (catalogPattern == null || catalogPattern.equals("%") || matchesSqlPattern(catalogName, catalogPattern)) {
+ ret.add(catalogName);
+ }
}
} catch (SQLTimeoutException sqlt) {
- String msgDesc = "Time Out, Unable to execute SQL [" + sql + "].";
+ String msgDesc = "Time Out, Unable to retrieve catalog list.";
HadoopException hdpException = new HadoopException(msgDesc, sqlt);
hdpException.generateResponseDataMap(false, getMessage(sqlt), msgDesc + ERR_MSG, null, null);
+ throw hdpException;
} catch (SQLException se) {
- String msg = "Unable to execute SQL [" + sql + "]. ";
+ String msg = "Unable to retrieve catalog list. ";
HadoopException he = new HadoopException(msg, se);
he.generateResponseDataMap(false, getMessage(se), msg + ERR_MSG, null, null);
throw he;
+ } catch (HadoopException he) {
+ throw he;
} finally {
close(rs);
- close(stat);
}
}
@@ -347,43 +346,31 @@ private List<String> getSchemas(String needle, List<String> catalogs, List<Strin
List<String> ret = new ArrayList<>();
if (con != null) {
- Statement stat = null;
- ResultSet rs = null;
- String sql = null;
+ ResultSet rs = null;
try {
+ validateSqlIdentifier(needle, "schema pattern");
+ String schemaPattern = convertToSqlPattern(needle);
if (catalogs != null && !catalogs.isEmpty()) {
for (String catalog : catalogs) {
- sql = "SHOW SCHEMAS FROM \"" + escapeSql(catalog) + "\"";
+ validateSqlIdentifier(catalog, "catalog name");
+ rs = con.getMetaData().getSchemas(catalog, schemaPattern);
- try {
- if (needle != null && !needle.isEmpty() && !needle.equals("*")) {
- sql += " LIKE '" + escapeSql(needle) + "%'";
+ while (rs.next()) {
+ String schema = rs.getString("TABLE_SCHEM");
+
+ if (schemas != null && schemas.contains(schema)) {
+ continue;
}
- stat = con.createStatement();
- rs = stat.executeQuery(sql);
-
- while (rs.next()) {
- String schema = rs.getString(1);
-
- if (schemas != null && schemas.contains(schema)) {
- continue;
- }
-
- ret.add(schema);
- }
- } finally {
- close(rs);
- close(stat);
-
- rs = null;
- stat = null;
+ ret.add(schema);
}
+ close(rs);
+ rs = null;
}
}
} catch (SQLTimeoutException sqlt) {
- String msgDesc = "Time Out, Unable to execute SQL [" + sql + "].";
+ String msgDesc = "Time Out, Unable to retrieve schema list.";
HadoopException hdpException = new HadoopException(msgDesc, sqlt);
hdpException.generateResponseDataMap(false, getMessage(sqlt), msgDesc + ERR_MSG, null, null);
@@ -392,7 +379,7 @@ private List<String> getSchemas(String needle, List<String> catalogs, List<Strin
throw hdpException;
} catch (SQLException sqle) {
- String msgDesc = "Unable to execute SQL [" + sql + "].";
+ String msgDesc = "Unable to retrieve schema list.";
HadoopException hdpException = new HadoopException(msgDesc, sqle);
hdpException.generateResponseDataMap(false, getMessage(sqle), msgDesc + ERR_MSG, null, null);
@@ -400,6 +387,10 @@ private List<String> getSchemas(String needle, List<String> catalogs, List<Strin
LOG.debug("<== PrestoClient.getSchemas() Error : ", sqle);
throw hdpException;
+ } catch (HadoopException he) {
+ throw he;
+ } finally {
+ close(rs);
}
}
@@ -410,61 +401,54 @@ private List<String> getTables(String needle, List<String> catalogs, List<String
List<String> ret = new ArrayList<>();
if (con != null) {
- Statement stat = null;
- ResultSet rs = null;
- String sql = null;
+ ResultSet rs = null;
- if (catalogs != null && !catalogs.isEmpty() && schemas != null && !schemas.isEmpty()) {
- try {
+ try {
+ validateSqlIdentifier(needle, "table pattern");
+ String tablePattern = convertToSqlPattern(needle);
+ if (catalogs != null && !catalogs.isEmpty() && schemas != null && !schemas.isEmpty()) {
for (String catalog : catalogs) {
+ validateSqlIdentifier(catalog, "catalog name");
for (String schema : schemas) {
- sql = "SHOW tables FROM \"" + escapeSql(catalog) + "\".\"" + escapeSql(schema) + "\"";
+ validateSqlIdentifier(schema, "schema name");
+ rs = con.getMetaData().getTables(catalog, schema, tablePattern, new String[] {"TABLE", "VIEW"});
- try {
- if (needle != null && !needle.isEmpty() && !needle.equals("*")) {
- sql += " LIKE '" + escapeSql(needle) + "%'";
+ while (rs.next()) {
+ String table = rs.getString("TABLE_NAME");
+
+ if (tables != null && tables.contains(table)) {
+ continue;
}
- stat = con.createStatement();
- rs = stat.executeQuery(sql);
-
- while (rs.next()) {
- String table = rs.getString(1);
-
- if (tables != null && tables.contains(table)) {
- continue;
- }
-
- ret.add(table);
- }
- } finally {
- close(rs);
- close(stat);
-
- rs = null;
- stat = null;
+ ret.add(table);
}
+ close(rs);
+ rs = null;
}
}
- } catch (SQLTimeoutException sqlt) {
- String msgDesc = "Time Out, Unable to execute SQL [" + sql + "].";
- HadoopException hdpException = new HadoopException(msgDesc, sqlt);
-
- hdpException.generateResponseDataMap(false, getMessage(sqlt), msgDesc + ERR_MSG, null, null);
-
- LOG.debug("<== PrestoClient.getTables() Error : ", sqlt);
-
- throw hdpException;
- } catch (SQLException sqle) {
- String msgDesc = "Unable to execute SQL [" + sql + "].";
- HadoopException hdpException = new HadoopException(msgDesc, sqle);
-
- hdpException.generateResponseDataMap(false, getMessage(sqle), msgDesc + ERR_MSG, null, null);
-
- LOG.debug("<== PrestoClient.getTables() Error : ", sqle);
-
- throw hdpException;
}
+ } catch (SQLTimeoutException sqlt) {
+ String msgDesc = "Time Out, Unable to retrieve table list.";
+ HadoopException hdpException = new HadoopException(msgDesc, sqlt);
+
+ hdpException.generateResponseDataMap(false, getMessage(sqlt), msgDesc + ERR_MSG, null, null);
+
+ LOG.debug("<== PrestoClient.getTables() Error : ", sqlt);
+
+ throw hdpException;
+ } catch (SQLException sqle) {
+ String msgDesc = "Unable to retrieve table list.";
+ HadoopException hdpException = new HadoopException(msgDesc, sqle);
+
+ hdpException.generateResponseDataMap(false, getMessage(sqle), msgDesc + ERR_MSG, null, null);
+
+ LOG.debug("<== PrestoClient.getTables() Error : ", sqle);
+
+ throw hdpException;
+ } catch (HadoopException he) {
+ throw he;
+ } finally {
+ close(rs);
}
}
@@ -477,66 +461,64 @@ private List<String> getColumns(String needle, List<String> catalogs, List<Strin
if (con != null) {
String regex = null;
ResultSet rs = null;
- String sql = null;
- Statement stat = null;
if (needle != null && !needle.isEmpty()) {
regex = needle;
}
- if (catalogs != null && !catalogs.isEmpty() && schemas != null && !schemas.isEmpty() && tables != null && !tables.isEmpty()) {
- try {
+ try {
+ validateSqlIdentifier(needle, "column pattern");
+ String columnPattern = convertToSqlPattern(needle);
+ if (catalogs != null && !catalogs.isEmpty() && schemas != null && !schemas.isEmpty() && tables != null && !tables.isEmpty()) {
for (String catalog : catalogs) {
+ validateSqlIdentifier(catalog, "catalog name");
for (String schema : schemas) {
+ validateSqlIdentifier(schema, "schema name");
for (String table : tables) {
- sql = "SHOW COLUMNS FROM \"" + escapeSql(catalog) + "\"." + "\"" + escapeSql(schema) + "\"." + "\"" + escapeSql(table) + "\"";
+ validateSqlIdentifier(table, "table name");
+ rs = con.getMetaData().getColumns(catalog, schema, table, columnPattern);
- try {
- stat = con.createStatement();
- rs = stat.executeQuery(sql);
+ while (rs.next()) {
+ String column = rs.getString("COLUMN_NAME");
- while (rs.next()) {
- String column = rs.getString(1);
-
- if (columns != null && columns.contains(column)) {
- continue;
- }
-
- if (regex == null) {
- ret.add(column);
- } else if (FilenameUtils.wildcardMatch(column, regex)) {
- ret.add(column);
- }
+ if (columns != null && columns.contains(column)) {
+ continue;
}
- } finally {
- close(rs);
- close(stat);
- stat = null;
- rs = null;
+ if (regex == null) {
+ ret.add(column);
+ } else if (FilenameUtils.wildcardMatch(column, regex)) {
+ ret.add(column);
+ }
}
+ close(rs);
+ rs = null;
}
}
}
- } catch (SQLTimeoutException sqlt) {
- String msgDesc = "Time Out, Unable to execute SQL [" + sql + "].";
- HadoopException hdpException = new HadoopException(msgDesc, sqlt);
-
- hdpException.generateResponseDataMap(false, getMessage(sqlt), msgDesc + ERR_MSG, null, null);
-
- LOG.debug("<== PrestoClient.getColumns() Error : ", sqlt);
-
- throw hdpException;
- } catch (SQLException sqle) {
- String msgDesc = "Unable to execute SQL [" + sql + "].";
- HadoopException hdpException = new HadoopException(msgDesc, sqle);
-
- hdpException.generateResponseDataMap(false, getMessage(sqle), msgDesc + ERR_MSG, null, null);
-
- LOG.debug("<== PrestoClient.getColumns() Error : ", sqle);
-
- throw hdpException;
}
+ } catch (SQLTimeoutException sqlt) {
+ String msgDesc = "Time Out, Unable to retrieve column list.";
+ HadoopException hdpException = new HadoopException(msgDesc, sqlt);
+
+ hdpException.generateResponseDataMap(false, getMessage(sqlt), msgDesc + ERR_MSG, null, null);
+
+ LOG.debug("<== PrestoClient.getColumns() Error : ", sqlt);
+
+ throw hdpException;
+ } catch (SQLException sqle) {
+ String msgDesc = "Unable to retrieve column list.";
+ HadoopException hdpException = new HadoopException(msgDesc, sqle);
+
+ hdpException.generateResponseDataMap(false, getMessage(sqle), msgDesc + ERR_MSG, null, null);
+
+ LOG.debug("<== PrestoClient.getColumns() Error : ", sqle);
+
+ throw hdpException;
+ } catch (HadoopException he) {
+ throw he;
+ } finally {
+ close(rs);
}
}
@@ -552,11 +534,4 @@ private void close(Connection con) {
LOG.error("Unable to close Presto SQL connection", e);
}
}
-
- private static String escapeSql(String str) {
- if (str == null) {
- return null;
- }
- return StringUtils.replace(str, "'", "''");
- }
}
diff --git a/plugin-presto/src/test/java/org/apache/ranger/services/presto/client/TestPrestoClient.java b/plugin-presto/src/test/java/org/apache/ranger/services/presto/client/TestPrestoClient.java
new file mode 100644
index 0000000..29462e6
--- /dev/null
+++ b/plugin-presto/src/test/java/org/apache/ranger/services/presto/client/TestPrestoClient.java
@@ -0,0 +1,332 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.ranger.services.presto.client;
+
+import org.apache.ranger.plugin.client.HadoopException;
+import org.junit.jupiter.api.MethodOrderer;
+import org.junit.jupiter.api.Test;
+import org.junit.jupiter.api.TestMethodOrder;
+import org.junit.jupiter.api.extension.ExtendWith;
+import org.mockito.MockedStatic;
+import org.mockito.Mockito;
+import org.mockito.junit.jupiter.MockitoExtension;
+
+import javax.security.auth.Subject;
+
+import java.lang.reflect.Field;
+import java.security.PrivilegedAction;
+import java.sql.Connection;
+import java.sql.DatabaseMetaData;
+import java.sql.DriverManager;
+import java.sql.ResultSet;
+import java.sql.SQLTimeoutException;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertThrows;
+import static org.junit.jupiter.api.Assertions.assertTrue;
+import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.ArgumentMatchers.anyString;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.when;
+
+@ExtendWith(MockitoExtension.class)
+@TestMethodOrder(MethodOrderer.MethodName.class)
+public class TestPrestoClient {
+ private PrestoClient createMockedClient() throws Exception {
+ return createMockedClient(new HashMap<>());
+ }
+
+ private PrestoClient createMockedClient(Map<String, String> props) throws Exception {
+ Map<String, String> config = new HashMap<>();
+ config.put("username", "test");
+ config.put("password", "test");
+ config.put("jdbc.driverClassName", "com.facebook.presto.jdbc.PrestoDriver");
+ config.put("jdbc.url", "jdbc:presto://localhost:8080");
+ config.putAll(props);
+
+ NoConnectionPrestoClient client = new NoConnectionPrestoClient("svc", config);
+ return client;
+ }
+
+ @Test
+ public void test01_getCatalogList_normalOperation() throws Exception {
+ try (MockedStatic<DriverManager> dmStatic = Mockito.mockStatic(DriverManager.class);
+ MockedStatic<Subject> subjectStatic = Mockito.mockStatic(Subject.class)) {
+ Connection mockCon = mock(Connection.class);
+ DatabaseMetaData metadata = mock(DatabaseMetaData.class);
+ ResultSet rs = mock(ResultSet.class);
+ dmStatic.when(() -> DriverManager.getConnection(anyString(), any())).thenReturn(mockCon);
+ when(mockCon.getMetaData()).thenReturn(metadata);
+ when(metadata.getCatalogs()).thenReturn(rs);
+ when(rs.next()).thenReturn(true, true, false);
+ when(rs.getString("TABLE_CAT")).thenReturn("catalog1", "catalog2");
+ subjectStatic.when(() -> Subject.doAs(any(), any(PrivilegedAction.class)))
+ .thenAnswer(inv -> {
+ PrivilegedAction<?> action = inv.getArgument(1);
+ return action.run();
+ });
+
+ PrestoClient client = createMockedClient();
+ List<String> out = client.getCatalogList("*", null);
+ assertEquals(Arrays.asList("catalog1", "catalog2"), out);
+ }
+ }
+
+ @Test
+ public void test02_getCatalogList_withExcludeList() throws Exception {
+ try (MockedStatic<DriverManager> dmStatic = Mockito.mockStatic(DriverManager.class);
+ MockedStatic<Subject> subjectStatic = Mockito.mockStatic(Subject.class)) {
+ Connection mockCon = mock(Connection.class);
+ DatabaseMetaData metadata = mock(DatabaseMetaData.class);
+ ResultSet rs = mock(ResultSet.class);
+ dmStatic.when(() -> DriverManager.getConnection(anyString(), any())).thenReturn(mockCon);
+ when(mockCon.getMetaData()).thenReturn(metadata);
+ when(metadata.getCatalogs()).thenReturn(rs);
+ when(rs.next()).thenReturn(true, true, false);
+ when(rs.getString("TABLE_CAT")).thenReturn("catalog1", "catalog2");
+ subjectStatic.when(() -> Subject.doAs(any(), any(PrivilegedAction.class)))
+ .thenAnswer(inv -> {
+ PrivilegedAction<?> action = inv.getArgument(1);
+ return action.run();
+ });
+
+ PrestoClient client = createMockedClient();
+ List<String> out = client.getCatalogList("*", Collections.singletonList("catalog1"));
+ assertEquals(Collections.singletonList("catalog2"), out);
+ }
+ }
+
+ @Test
+ public void test03_getCatalogList_timeoutThrowsHadoopException() throws Exception {
+ try (MockedStatic<DriverManager> dmStatic = Mockito.mockStatic(DriverManager.class);
+ MockedStatic<Subject> subjectStatic = Mockito.mockStatic(Subject.class)) {
+ Connection mockCon = mock(Connection.class);
+ DatabaseMetaData metadata = mock(DatabaseMetaData.class);
+ dmStatic.when(() -> DriverManager.getConnection(anyString(), any())).thenReturn(mockCon);
+ when(mockCon.getMetaData()).thenReturn(metadata);
+ when(metadata.getCatalogs()).thenThrow(SQLTimeoutException.class);
+ subjectStatic.when(() -> Subject.doAs(any(), any(PrivilegedAction.class)))
+ .thenAnswer(inv -> {
+ PrivilegedAction<?> action = inv.getArgument(1);
+ return action.run();
+ });
+
+ PrestoClient client = createMockedClient();
+ assertThrows(HadoopException.class, () -> client.getCatalogList("*", null));
+ }
+ }
+
+ @Test
+ public void test04_validateSqlIdentifier_rejectsSqlInjection() throws Exception {
+ try (MockedStatic<DriverManager> dmStatic = Mockito.mockStatic(DriverManager.class);
+ MockedStatic<Subject> subjectStatic = Mockito.mockStatic(Subject.class)) {
+ Connection mockCon = mock(Connection.class);
+ dmStatic.when(() -> DriverManager.getConnection(anyString(), any())).thenReturn(mockCon);
+ subjectStatic.when(() -> Subject.doAs(any(), any(PrivilegedAction.class)))
+ .thenAnswer(inv -> {
+ PrivilegedAction<?> action = inv.getArgument(1);
+ return action.run();
+ });
+
+ PrestoClient client = createMockedClient();
+
+ List<String> maliciousInputs = Arrays.asList(
+ "'; DROP TABLE users; --",
+ "test\" OR 1=1 --",
+ "catalog'; DELETE FROM users; --",
+ "../../../etc/passwd",
+ "test<script>alert(1)</script>");
+
+ for (String input : maliciousInputs) {
+ HadoopException ex = assertThrows(HadoopException.class,
+ () -> client.getCatalogList(input, null),
+ "SQL injection attempt should be rejected: " + input);
+ assertTrue(ex.getMessage().contains("Invalid"),
+ "Error should indicate invalid input for: " + input);
+ }
+ }
+ }
+
+ @Test
+ public void test05_validateSqlIdentifier_rejectsSpecialCharacters() throws Exception {
+ try (MockedStatic<DriverManager> dmStatic = Mockito.mockStatic(DriverManager.class);
+ MockedStatic<Subject> subjectStatic = Mockito.mockStatic(Subject.class)) {
+ Connection mockCon = mock(Connection.class);
+ dmStatic.when(() -> DriverManager.getConnection(anyString(), any())).thenReturn(mockCon);
+ subjectStatic.when(() -> Subject.doAs(any(), any(PrivilegedAction.class)))
+ .thenAnswer(inv -> {
+ PrivilegedAction<?> action = inv.getArgument(1);
+ return action.run();
+ });
+
+ PrestoClient client = createMockedClient();
+
+ List<String> invalidInputs = Arrays.asList(
+ "test@catalog",
+ "catalog#name",
+ "table!name",
+ "name(with)parens");
+
+ for (String input : invalidInputs) {
+ HadoopException ex = assertThrows(HadoopException.class,
+ () -> client.getCatalogList(input, null),
+ "Special characters should be rejected: " + input);
+ assertTrue(ex.getMessage().contains("Invalid"),
+ "Error should indicate invalid input for: " + input);
+ }
+ }
+ }
+
+ @Test
+ public void test06_schemaValidation_rejectsSqlInjection() throws Exception {
+ try (MockedStatic<DriverManager> dmStatic = Mockito.mockStatic(DriverManager.class);
+ MockedStatic<Subject> subjectStatic = Mockito.mockStatic(Subject.class)) {
+ Connection mockCon = mock(Connection.class);
+ dmStatic.when(() -> DriverManager.getConnection(anyString(), any())).thenReturn(mockCon);
+ subjectStatic.when(() -> Subject.doAs(any(), any(PrivilegedAction.class)))
+ .thenAnswer(inv -> {
+ PrivilegedAction<?> action = inv.getArgument(1);
+ return action.run();
+ });
+
+ PrestoClient client = createMockedClient();
+ Field fCon = PrestoClient.class.getDeclaredField("con");
+ fCon.setAccessible(true);
+ fCon.set(client, mockCon);
+
+ HadoopException ex = assertThrows(HadoopException.class,
+ () -> client.getSchemaList("'; DROP TABLE users; --", Collections.singletonList("catalog1"), null));
+ assertTrue(ex.getMessage().contains("Invalid"));
+ }
+ }
+
+ @Test
+ public void test07_tableValidation_rejectsSqlInjection() throws Exception {
+ try (MockedStatic<DriverManager> dmStatic = Mockito.mockStatic(DriverManager.class);
+ MockedStatic<Subject> subjectStatic = Mockito.mockStatic(Subject.class)) {
+ Connection mockCon = mock(Connection.class);
+ dmStatic.when(() -> DriverManager.getConnection(anyString(), any())).thenReturn(mockCon);
+ subjectStatic.when(() -> Subject.doAs(any(), any(PrivilegedAction.class)))
+ .thenAnswer(inv -> {
+ PrivilegedAction<?> action = inv.getArgument(1);
+ return action.run();
+ });
+
+ PrestoClient client = createMockedClient();
+
+ HadoopException ex = assertThrows(HadoopException.class,
+ () -> client.getTableList("'; DROP TABLE users; --", Collections.singletonList("catalog1"), Collections.singletonList("schema1"), null));
+ assertTrue(ex.getMessage().contains("Invalid"));
+ }
+ }
+
+ @Test
+ public void test08_columnValidation_rejectsSqlInjection() throws Exception {
+ try (MockedStatic<DriverManager> dmStatic = Mockito.mockStatic(DriverManager.class);
+ MockedStatic<Subject> subjectStatic = Mockito.mockStatic(Subject.class)) {
+ Connection mockCon = mock(Connection.class);
+ dmStatic.when(() -> DriverManager.getConnection(anyString(), any())).thenReturn(mockCon);
+ subjectStatic.when(() -> Subject.doAs(any(), any(PrivilegedAction.class)))
+ .thenAnswer(inv -> {
+ PrivilegedAction<?> action = inv.getArgument(1);
+ return action.run();
+ });
+
+ PrestoClient client = createMockedClient();
+
+ HadoopException ex = assertThrows(HadoopException.class,
+ () -> client.getColumnList("'; DROP TABLE users; --", Collections.singletonList("catalog1"), Collections.singletonList("schema1"), Collections.singletonList("table1"), null));
+ assertTrue(ex.getMessage().contains("Invalid"));
+ }
+ }
+
+ @Test
+ public void test09_catalogName_validateInList() throws Exception {
+ try (MockedStatic<DriverManager> dmStatic = Mockito.mockStatic(DriverManager.class);
+ MockedStatic<Subject> subjectStatic = Mockito.mockStatic(Subject.class)) {
+ Connection mockCon = mock(Connection.class);
+ DatabaseMetaData metadata = mock(DatabaseMetaData.class);
+ ResultSet rs = mock(ResultSet.class);
+ dmStatic.when(() -> DriverManager.getConnection(anyString(), any())).thenReturn(mockCon);
+ when(mockCon.getMetaData()).thenReturn(metadata);
+ when(metadata.getSchemas(anyString(), anyString())).thenReturn(rs);
+ when(rs.next()).thenReturn(false);
+ subjectStatic.when(() -> Subject.doAs(any(), any(PrivilegedAction.class)))
+ .thenAnswer(inv -> {
+ PrivilegedAction<?> action = inv.getArgument(1);
+ return action.run();
+ });
+
+ PrestoClient client = createMockedClient();
+ Field fCon = PrestoClient.class.getDeclaredField("con");
+ fCon.setAccessible(true);
+ fCon.set(client, mockCon);
+
+ HadoopException ex = assertThrows(HadoopException.class,
+ () -> client.getSchemaList("schema1", Arrays.asList("catalog1", "'; DROP TABLE --"), null));
+ assertTrue(ex.getMessage().contains("Invalid"));
+ }
+ }
+
+ @Test
+ public void test10_schemaName_validateInList() throws Exception {
+ try (MockedStatic<DriverManager> dmStatic = Mockito.mockStatic(DriverManager.class);
+ MockedStatic<Subject> subjectStatic = Mockito.mockStatic(Subject.class)) {
+ Connection mockCon = mock(Connection.class);
+ DatabaseMetaData metadata = mock(DatabaseMetaData.class);
+ ResultSet rs = mock(ResultSet.class);
+ dmStatic.when(() -> DriverManager.getConnection(anyString(), any())).thenReturn(mockCon);
+ when(mockCon.getMetaData()).thenReturn(metadata);
+ when(metadata.getTables(anyString(), anyString(), anyString(), any())).thenReturn(rs);
+ when(rs.next()).thenReturn(false);
+ subjectStatic.when(() -> Subject.doAs(any(), any(PrivilegedAction.class)))
+ .thenAnswer(inv -> {
+ PrivilegedAction<?> action = inv.getArgument(1);
+ return action.run();
+ });
+
+ PrestoClient client = createMockedClient();
+
+ HadoopException ex = assertThrows(HadoopException.class,
+ () -> client.getTableList("table1", Collections.singletonList("catalog1"), Arrays.asList("schema1", "'; DROP --"), null));
+ assertTrue(ex.getMessage().contains("Invalid"));
+ }
+ }
+
+ private static class NoConnectionPrestoClient extends PrestoClient {
+ public NoConnectionPrestoClient(String serviceName, Map<String, String> connectionProperties) {
+ super(serviceName, connectionProperties);
+ }
+
+ @Override
+ protected Subject getLoginSubject() {
+ Subject subject = new Subject();
+ return subject;
+ }
+
+ @Override
+ protected void login() {
+ }
+ }
+}
diff --git a/plugin-schema-registry/src/main/java/org/apache/ranger/services/schema/registry/client/AutocompletionAgent.java b/plugin-schema-registry/src/main/java/org/apache/ranger/services/schema/registry/client/AutocompletionAgent.java
index 074d005..ac5d3a0 100644
--- a/plugin-schema-registry/src/main/java/org/apache/ranger/services/schema/registry/client/AutocompletionAgent.java
+++ b/plugin-schema-registry/src/main/java/org/apache/ranger/services/schema/registry/client/AutocompletionAgent.java
@@ -118,15 +118,63 @@ public String toString() {
}
List<String> expandSchemaMetadataNameRegex(List<String> schemaGroupList, String lookupSchemaMetadataName) {
+ validatePattern(lookupSchemaMetadataName, "schema metadata pattern");
+ String safePattern = convertWildcardToRegex(lookupSchemaMetadataName);
List<String> res = new ArrayList<>();
Collection<String> schemas = client.getSchemaNames(schemaGroupList);
schemas.forEach(sName -> {
- if (sName.matches(lookupSchemaMetadataName)) {
+ if (sName.matches(safePattern)) {
res.add(sName);
}
});
return res;
}
+
+ private void validatePattern(String pattern, String patternType) {
+ if (pattern == null || pattern.isEmpty()) {
+ return;
+ }
+ if (!pattern.matches("^[a-zA-Z0-9*?\\[\\]\\-\\$%\\{\\}\\=\\/\\._]+$")) {
+ String msgDesc = "Invalid " + patternType + ": [" + pattern + "]. Only alphanumeric characters along with ( ., _, -, *, ?, [], {}, %, $, =, / ) are allowed.";
+ LOG.error(msgDesc);
+ throw new IllegalArgumentException(msgDesc);
+ }
+ }
+
+ protected String convertWildcardToRegex(String wildcard) {
+ if (wildcard == null || wildcard.isEmpty()) {
+ return ".*";
+ }
+ StringBuilder regex = new StringBuilder("^");
+ for (int i = 0; i < wildcard.length(); i++) {
+ char c = wildcard.charAt(i);
+ switch (c) {
+ case '*':
+ regex.append(".*");
+ break;
+ case '?':
+ regex.append(".");
+ break;
+ case '.':
+ case '\\':
+ case '^':
+ case '$':
+ case '|':
+ regex.append('\\').append(c);
+ break;
+ case '{':
+ case '}':
+ case '[':
+ case ']':
+ regex.append('\\').append(c);
+ break;
+ default:
+ regex.append(c);
+ }
+ }
+ regex.append('$');
+ return regex.toString();
+ }
}
diff --git a/plugin-schema-registry/src/test/java/org/apache/ranger/services/schema/registry/client/AutocompletionAgentTest.java b/plugin-schema-registry/src/test/java/org/apache/ranger/services/schema/registry/client/AutocompletionAgentTest.java
index e578ec8..779ce3b 100644
--- a/plugin-schema-registry/src/test/java/org/apache/ranger/services/schema/registry/client/AutocompletionAgentTest.java
+++ b/plugin-schema-registry/src/test/java/org/apache/ranger/services/schema/registry/client/AutocompletionAgentTest.java
@@ -19,6 +19,7 @@
import org.apache.ranger.services.schema.registry.client.connection.ISchemaRegistryClient;
import org.apache.ranger.services.schema.registry.client.util.DefaultSchemaRegistryClientForTesting;
+import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;
import java.util.ArrayList;
@@ -27,8 +28,6 @@
import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.Matchers.is;
-import static org.junit.jupiter.api.Assertions.assertEquals;
-import static org.junit.jupiter.api.Assertions.assertNull;
public class AutocompletionAgentTest {
@Test
@@ -37,11 +36,11 @@ public void connectionTest() {
AutocompletionAgent autocompletionAgent = new AutocompletionAgent("schema-registry", client);
HashMap<String, Object> res = autocompletionAgent.connectionTest();
- assertEquals(true, res.get("connectivityStatus"));
- assertEquals("ConnectionTest Successful", res.get("message"));
- assertEquals("ConnectionTest Successful", res.get("description"));
- assertNull(res.get("objectId"));
- assertNull(res.get("fieldName"));
+ Assertions.assertEquals(true, res.get("connectivityStatus"));
+ Assertions.assertEquals("ConnectionTest Successful", res.get("message"));
+ Assertions.assertEquals("ConnectionTest Successful", res.get("description"));
+ Assertions.assertNull(res.get("objectId"));
+ Assertions.assertNull(res.get("fieldName"));
client = new DefaultSchemaRegistryClientForTesting() {
public void checkConnection() throws Exception {
@@ -52,11 +51,11 @@ public void checkConnection() throws Exception {
res = autocompletionAgent.connectionTest();
String errMessage = "You can still save the repository and start creating policies, but you would not be able to use autocomplete for resource names. Check server logs for more info.";
- assertEquals(false, res.get("connectivityStatus"));
+ Assertions.assertEquals(false, res.get("connectivityStatus"));
assertThat(res.get("message"), is(errMessage));
assertThat(res.get("description"), is(errMessage));
- assertNull(res.get("objectId"));
- assertNull(res.get("fieldName"));
+ Assertions.assertNull(res.get("objectId"));
+ Assertions.assertNull(res.get("fieldName"));
}
@Test
@@ -75,7 +74,7 @@ public List<String> getSchemaGroups() {
// doesn't contain any groups that starts with 'tesSome'
List<String> initialGroups = new ArrayList<>();
List<String> res = autocompletionAgent.getSchemaGroupList("tesSome", initialGroups);
- assertEquals(0, res.size());
+ Assertions.assertEquals(0, res.size());
// Empty initialGroups and the list of groups returned by ISchemaRegistryClient
// contains a group that starts with 'tes'
@@ -83,7 +82,7 @@ public List<String> getSchemaGroups() {
res = autocompletionAgent.getSchemaGroupList("tes", initialGroups);
List<String> expected = new ArrayList<>();
expected.add("testGroup");
- assertEquals(1, res.size());
+ Assertions.assertEquals(1, res.size());
assertThat(res, is(expected));
// initialGroups contains one element, list of the groups returned by ISchemaRegistryClient
@@ -93,7 +92,7 @@ public List<String> getSchemaGroups() {
res = autocompletionAgent.getSchemaGroupList("tes", initialGroups);
expected = new ArrayList<>();
expected.add("testGroup");
- assertEquals(1, res.size());
+ Assertions.assertEquals(1, res.size());
assertThat(res, is(expected));
// initialGroups contains one element, list of the groups returned by ISchemaRegistryClient
@@ -104,7 +103,7 @@ public List<String> getSchemaGroups() {
expected = new ArrayList<>();
expected.add("testGroup2");
expected.add("testGroup");
- assertEquals(2, res.size());
+ Assertions.assertEquals(2, res.size());
assertThat(res, is(expected));
}
@@ -129,11 +128,11 @@ public List<String> getSchemaNames(List<String> schemaGroup) {
List<String> res = autocompletionAgent.getSchemaMetadataList("tes", groupList, new ArrayList<>());
List<String> expected = new ArrayList<>();
expected.add("testSchema");
- assertEquals(1, res.size());
+ Assertions.assertEquals(1, res.size());
assertThat(res, is(expected));
res = autocompletionAgent.getSchemaMetadataList("tesSome", groupList, new ArrayList<>());
- assertEquals(0, res.size());
+ Assertions.assertEquals(0, res.size());
}
@Test
@@ -168,10 +167,114 @@ public List<String> getSchemaBranches(String schemaMetadataName) {
List<String> res = autocompletionAgent.getBranchList("tes", groups, schemaList, new ArrayList<>());
List<String> expected = new ArrayList<>();
expected.add("testBranch");
- assertEquals(1, res.size());
+ Assertions.assertEquals(1, res.size());
assertThat(res, is(expected));
res = autocompletionAgent.getSchemaMetadataList("tesSome", schemaList, new ArrayList<>());
- assertEquals(0, res.size());
+ Assertions.assertEquals(0, res.size());
+ }
+
+ @Test
+ void testValidatePattern_validAlphanumeric() {
+ ISchemaRegistryClient client = new DefaultSchemaRegistryClientForTesting() {
+ public List<String> getSchemaNames(List<String> schemaGroup) {
+ List<String> schemas = new ArrayList<>();
+ schemas.add("mySchema123");
+ return schemas;
+ }
+ };
+ AutocompletionAgent agent = new AutocompletionAgent("test", client);
+ List<String> groups = new ArrayList<>();
+ groups.add("testGroup");
+ List<String> result = agent.expandSchemaMetadataNameRegex(groups, "mySchema123");
+ Assertions.assertEquals(1, result.size());
+ Assertions.assertEquals("mySchema123", result.get(0));
+ }
+
+ @Test
+ void testValidatePattern_validWildcards() {
+ ISchemaRegistryClient client = new DefaultSchemaRegistryClientForTesting() {
+ public List<String> getSchemaNames(List<String> schemaGroup) {
+ List<String> schemas = new ArrayList<>();
+ schemas.add("mySchema123");
+ schemas.add("testSchema");
+ return schemas;
+ }
+ };
+ AutocompletionAgent agent = new AutocompletionAgent("test", client);
+ List<String> groups = new ArrayList<>();
+ groups.add("testGroup");
+ List<String> result = agent.expandSchemaMetadataNameRegex(groups, "my*");
+ Assertions.assertEquals(1, result.size());
+ Assertions.assertEquals("mySchema123", result.get(0));
+ }
+
+ @Test
+ void testValidatePattern_rejectsReDoSPattern() {
+ ISchemaRegistryClient client = new DefaultSchemaRegistryClientForTesting();
+ AutocompletionAgent agent = new AutocompletionAgent("test", client);
+ List<String> groups = new ArrayList<>();
+ groups.add("testGroup");
+ Assertions.assertThrows(IllegalArgumentException.class, () -> {
+ agent.expandSchemaMetadataNameRegex(groups, "(a+)+");
+ });
+ }
+
+ @Test
+ void testValidatePattern_rejectsComplexRegex() {
+ ISchemaRegistryClient client = new DefaultSchemaRegistryClientForTesting();
+ AutocompletionAgent agent = new AutocompletionAgent("test", client);
+ List<String> groups = new ArrayList<>();
+ groups.add("testGroup");
+ Assertions.assertThrows(IllegalArgumentException.class, () -> {
+ agent.expandSchemaMetadataNameRegex(groups, "test{1,5}");
+ });
+ }
+
+ @Test
+ void testValidatePattern_rejectsInjectionAttempt() {
+ ISchemaRegistryClient client = new DefaultSchemaRegistryClientForTesting();
+ AutocompletionAgent agent = new AutocompletionAgent("test", client);
+ List<String> groups = new ArrayList<>();
+ groups.add("testGroup");
+ Assertions.assertThrows(IllegalArgumentException.class, () -> {
+ agent.expandSchemaMetadataNameRegex(groups, "test'; DROP TABLE users--");
+ });
+ }
+
+ @Test
+ void testConvertWildcardToRegex_asterisk() {
+ ISchemaRegistryClient client = new DefaultSchemaRegistryClientForTesting() {
+ public List<String> getSchemaNames(List<String> schemaGroup) {
+ List<String> schemas = new ArrayList<>();
+ schemas.add("testSchema");
+ schemas.add("prodSchema");
+ return schemas;
+ }
+ };
+ AutocompletionAgent agent = new AutocompletionAgent("test", client);
+ List<String> groups = new ArrayList<>();
+ groups.add("testGroup");
+ List<String> result = agent.expandSchemaMetadataNameRegex(groups, "test*");
+ Assertions.assertEquals(1, result.size());
+ Assertions.assertEquals("testSchema", result.get(0));
+ }
+
+ @Test
+ void testConvertWildcardToRegex_questionMark() {
+ ISchemaRegistryClient client = new DefaultSchemaRegistryClientForTesting() {
+ public List<String> getSchemaNames(List<String> schemaGroup) {
+ List<String> schemas = new ArrayList<>();
+ schemas.add("schema1");
+ schemas.add("schema12");
+ return schemas;
+ }
+ };
+ AutocompletionAgent agent = new AutocompletionAgent("test", client);
+ List<String> groups = new ArrayList<>();
+ groups.add("testGroup");
+ List<String> result = agent.expandSchemaMetadataNameRegex(groups, "schema?");
+ Assertions.assertEquals(1, result.size());
+ Assertions.assertEquals("schema1", result.get(0));
}
}
diff --git a/plugin-solr/pom.xml b/plugin-solr/pom.xml
index f409216..38b8c08 100644
--- a/plugin-solr/pom.xml
+++ b/plugin-solr/pom.xml
@@ -119,5 +119,23 @@
</exclusion>
</exclusions>
</dependency>
+ <dependency>
+ <groupId>org.junit.jupiter</groupId>
+ <artifactId>junit-jupiter</artifactId>
+ <version>${junit.jupiter.version}</version>
+ <scope>test</scope>
+ </dependency>
+ <dependency>
+ <groupId>org.mockito</groupId>
+ <artifactId>mockito-inline</artifactId>
+ <version>${mockito.version}</version>
+ <scope>test</scope>
+ </dependency>
+ <dependency>
+ <groupId>org.mockito</groupId>
+ <artifactId>mockito-junit-jupiter</artifactId>
+ <version>${mockito.version}</version>
+ <scope>test</scope>
+ </dependency>
</dependencies>
</project>
diff --git a/plugin-solr/src/main/java/org/apache/ranger/services/solr/client/ServiceSolrClient.java b/plugin-solr/src/main/java/org/apache/ranger/services/solr/client/ServiceSolrClient.java
index 3966c11..dd7f9be 100644
--- a/plugin-solr/src/main/java/org/apache/ranger/services/solr/client/ServiceSolrClient.java
+++ b/plugin-solr/src/main/java/org/apache/ranger/services/solr/client/ServiceSolrClient.java
@@ -418,10 +418,10 @@ private List<String> getCoresList(List<String> ignoreCollectionList) throws Exce
}
private List<String> getFieldList(String collection, List<String> ignoreFieldList) throws Exception {
- // TODO: Best is to get the collections based on the collection value which could contain wild cards
String queryStr = "";
if (collection != null && !collection.isEmpty()) {
+ validateResourceName(collection, "collection name");
queryStr += "/" + collection;
}
@@ -619,6 +619,34 @@ private void login(Map<String, String> configs) {
}
}
+ private void validateResourceName(String resourceName, String resourceType) {
+ if (resourceName == null) {
+ return;
+ }
+
+ if (resourceName.contains("..") || resourceName.contains("//") || resourceName.contains("\\")) {
+ String msgDesc = "Invalid " + resourceType + ": [" + resourceName + "]. Path traversal patterns are not allowed.";
+ HadoopException hdpException = new HadoopException(msgDesc);
+
+ hdpException.generateResponseDataMap(false, msgDesc, msgDesc + RangerSolrConstants.errMessage, null, null);
+
+ LOG.error(msgDesc);
+
+ throw hdpException;
+ }
+
+ if (!resourceName.matches("^[a-zA-Z0-9_.*\\-]+$")) {
+ String msgDesc = "Invalid " + resourceType + ": [" + resourceName + "]. Only alphanumeric characters, dots, underscores, hyphens, and wildcards are allowed.";
+ HadoopException hdpException = new HadoopException(msgDesc);
+
+ hdpException.generateResponseDataMap(false, msgDesc, msgDesc + RangerSolrConstants.errMessage, null, null);
+
+ LOG.error(msgDesc);
+
+ throw hdpException;
+ }
+ }
+
private HadoopException createException(String msgDesc, Exception exp) {
HadoopException hdpException = new HadoopException(msgDesc, exp);
final String fullDescription = exp != null ? BaseClient.getMessage(exp) : msgDesc;
diff --git a/plugin-solr/src/test/java/org/apache/ranger/services/solr/client/TestServiceSolrClient.java b/plugin-solr/src/test/java/org/apache/ranger/services/solr/client/TestServiceSolrClient.java
new file mode 100644
index 0000000..0e02258
--- /dev/null
+++ b/plugin-solr/src/test/java/org/apache/ranger/services/solr/client/TestServiceSolrClient.java
@@ -0,0 +1,165 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.ranger.services.solr.client;
+
+import org.apache.ranger.plugin.client.HadoopException;
+import org.junit.jupiter.api.MethodOrderer;
+import org.junit.jupiter.api.Test;
+import org.junit.jupiter.api.TestMethodOrder;
+import org.junit.jupiter.api.extension.ExtendWith;
+import org.mockito.junit.jupiter.MockitoExtension;
+
+import java.lang.reflect.Method;
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+import static org.junit.jupiter.api.Assertions.assertThrows;
+import static org.junit.jupiter.api.Assertions.assertTrue;
+
+@ExtendWith(MockitoExtension.class)
+@TestMethodOrder(MethodOrderer.MethodName.class)
+public class TestServiceSolrClient {
+ @Test
+ public void test01_validateResourceName_rejectsPathTraversal() throws Exception {
+ Map<String, String> configs = new HashMap<>();
+ configs.put("username", "test");
+ configs.put("password", "test");
+ NoopServiceSolrClient client = new NoopServiceSolrClient("svc", configs, "http://localhost:8983/solr", false);
+
+ List<String> pathTraversalInputs = Arrays.asList("../etc/passwd", "../../sensitive", "test/../admin", "collection//malicious", "test\\windows\\path", "..\\..\\config");
+
+ for (String input : pathTraversalInputs) {
+ HadoopException ex = assertThrows(HadoopException.class,
+ () -> invokeGetFieldList(client, input, null),
+ "Path traversal should be rejected: " + input);
+ assertTrue(ex.getMessage().contains("Path traversal"),
+ "Error should indicate path traversal for: " + input);
+ }
+ }
+
+ @Test
+ public void test02_validateResourceName_rejectsSpecialCharacters() throws Exception {
+ Map<String, String> configs = new HashMap<>();
+ configs.put("username", "test");
+ configs.put("password", "test");
+ NoopServiceSolrClient client = new NoopServiceSolrClient("svc", configs, "http://localhost:8983/solr", false);
+
+ List<String> invalidInputs = Arrays.asList("'; DROP TABLE users; --", "collection<script>alert(1)</script>", "test@collection", "collection#name", "test!collection", "collection&name", "collection(with)parens", "collection{with}braces", "collection[with]brackets", "collection$name", "collection%encoded", "collection name", "collection\ttab", "collection\nnewline");
+
+ for (String input : invalidInputs) {
+ HadoopException ex = assertThrows(HadoopException.class,
+ () -> invokeGetFieldList(client, input, null),
+ "Special characters should be rejected: " + input);
+ assertTrue(ex.getMessage().contains("Invalid"),
+ "Error should indicate invalid input for: " + input);
+ }
+ }
+
+ @Test
+ public void test03_validateResourceName_acceptsValidNames() throws Exception {
+ Map<String, String> configs = new HashMap<>();
+ configs.put("username", "test");
+ configs.put("password", "test");
+ NoopServiceSolrClient client = new NoopServiceSolrClient("svc", configs, "http://localhost:8983/solr", false);
+
+ List<String> validInputs = Arrays.asList("collection", "collection_name", "collection123", "COLLECTION", "Collection_Name_123", "_collection", "collection_", "collection.name", "collection-name", "collection*", "coll*");
+
+ Method validateMethod = ServiceSolrClient.class.getDeclaredMethod("validateResourceName", String.class, String.class);
+ validateMethod.setAccessible(true);
+
+ for (String input : validInputs) {
+ try {
+ validateMethod.invoke(client, input, "collection name");
+ } catch (Exception e) {
+ throw new AssertionError("Valid collection name should not throw exception: " + input, e);
+ }
+ }
+ }
+
+ @Test
+ public void test04_validateResourceName_rejectsNullByteInjection() throws Exception {
+ Map<String, String> configs = new HashMap<>();
+ configs.put("username", "test");
+ configs.put("password", "test");
+ NoopServiceSolrClient client = new NoopServiceSolrClient("svc", configs, "http://localhost:8983/solr", false);
+
+ HadoopException ex = assertThrows(HadoopException.class,
+ () -> invokeGetFieldList(client, "collection\0null", null));
+ assertTrue(ex.getMessage().contains("Invalid"));
+ }
+
+ @Test
+ public void test05_validateResourceName_rejectsCommandInjection() throws Exception {
+ Map<String, String> configs = new HashMap<>();
+ configs.put("username", "test");
+ configs.put("password", "test");
+ NoopServiceSolrClient client = new NoopServiceSolrClient("svc", configs, "http://localhost:8983/solr", false);
+
+ List<String> commandInjectionInputs = Arrays.asList("collection;rm -rf /", "collection|cat /etc/passwd", "collection`whoami`", "collection$(whoami)", "collection&&ls");
+
+ for (String input : commandInjectionInputs) {
+ HadoopException ex = assertThrows(HadoopException.class,
+ () -> invokeGetFieldList(client, input, null),
+ "Command injection should be rejected: " + input);
+ assertTrue(ex.getMessage().contains("Invalid"),
+ "Error should indicate invalid input for: " + input);
+ }
+ }
+
+ @Test
+ public void test06_validateResourceName_rejectsUrlEncoded() throws Exception {
+ Map<String, String> configs = new HashMap<>();
+ configs.put("username", "test");
+ configs.put("password", "test");
+ NoopServiceSolrClient client = new NoopServiceSolrClient("svc", configs, "http://localhost:8983/solr", false);
+
+ List<String> encodedInputs = Arrays.asList("%2e%2e%2f", "collection%00", "test%20space");
+
+ for (String input : encodedInputs) {
+ HadoopException ex = assertThrows(HadoopException.class,
+ () -> invokeGetFieldList(client, input, null),
+ "URL encoded attack should be rejected: " + input);
+ assertTrue(ex.getMessage().contains("Invalid"),
+ "Error should indicate invalid input for: " + input);
+ }
+ }
+
+ private List<String> invokeGetFieldList(ServiceSolrClient client, String collection, List<String> ignoreList) throws Exception {
+ Method method = ServiceSolrClient.class.getDeclaredMethod("getFieldList", String.class, List.class);
+ method.setAccessible(true);
+ try {
+ return (List<String>) method.invoke(client, collection, ignoreList);
+ } catch (java.lang.reflect.InvocationTargetException e) {
+ Throwable cause = e.getCause();
+ if (cause instanceof HadoopException) {
+ throw (HadoopException) cause;
+ }
+ throw e;
+ }
+ }
+
+ private static class NoopServiceSolrClient extends ServiceSolrClient {
+ public NoopServiceSolrClient(String serviceName, Map<String, String> configs, String url, boolean isSolrCloud) {
+ super(serviceName, configs, url, isSolrCloud);
+ }
+ }
+}
diff --git a/plugin-trino/pom.xml b/plugin-trino/pom.xml
index 3db390b..cfb6d62 100644
--- a/plugin-trino/pom.xml
+++ b/plugin-trino/pom.xml
@@ -167,6 +167,18 @@
<version>${junit.jupiter.version}</version>
<scope>test</scope>
</dependency>
+ <dependency>
+ <groupId>org.mockito</groupId>
+ <artifactId>mockito-inline</artifactId>
+ <version>${mockito.version}</version>
+ <scope>test</scope>
+ </dependency>
+ <dependency>
+ <groupId>org.mockito</groupId>
+ <artifactId>mockito-junit-jupiter</artifactId>
+ <version>${mockito.version}</version>
+ <scope>test</scope>
+ </dependency>
</dependencies>
<build>
<testResources>
diff --git a/plugin-trino/src/main/java/org/apache/ranger/services/trino/client/TrinoClient.java b/plugin-trino/src/main/java/org/apache/ranger/services/trino/client/TrinoClient.java
index 7325347..ef98869 100644
--- a/plugin-trino/src/main/java/org/apache/ranger/services/trino/client/TrinoClient.java
+++ b/plugin-trino/src/main/java/org/apache/ranger/services/trino/client/TrinoClient.java
@@ -14,7 +14,6 @@
package org.apache.ranger.services.trino.client;
import org.apache.commons.io.FilenameUtils;
-import org.apache.commons.lang3.StringUtils;
import org.apache.ranger.plugin.client.BaseClient;
import org.apache.ranger.plugin.client.HadoopConfigHolder;
import org.apache.ranger.plugin.client.HadoopException;
@@ -115,6 +114,8 @@ public List<String> getSchemaList(String needle, List<String> catalogs, List<Str
ret = getSchemas(ndl, cats, shms);
} catch (HadoopException he) {
LOG.error("<== TrinoClient.getSchemaList() :Unable to get the Schema List", he);
+
+ throw he;
}
return ret;
@@ -302,44 +303,41 @@ private List<String> getCatalogs(String needle, List<String> catalogs)
List<String> ret = new ArrayList<>();
if (con != null) {
- Statement stat = null;
- ResultSet rs = null;
- String sql = "SHOW CATALOGS";
+ ResultSet rs = null;
try {
- if (needle != null && !needle.isEmpty() && !needle.equals("*")) {
- // Cannot use a prepared statement for this as trino does not support that
- sql += " LIKE '" + escapeSql(needle) + "%'";
- }
-
- stat = con.createStatement();
- rs = stat.executeQuery(sql);
+ validateSqlIdentifier(needle, "catalog pattern");
+ String catalogPattern = convertToSqlPattern(needle);
+ rs = con.getMetaData().getCatalogs();
while (rs.next()) {
- String catalogName = rs.getString(1);
+ String catalogName = rs.getString("TABLE_CAT");
if (catalogs != null && catalogs.contains(catalogName)) {
continue;
}
- ret.add(catalogName);
+ if (catalogPattern == null || catalogPattern.equals("%") || matchesSqlPattern(catalogName, catalogPattern)) {
+ ret.add(catalogName);
+ }
}
} catch (SQLTimeoutException sqlt) {
- String msgDesc = "Time Out, Unable to execute SQL [" + sql + "].";
+ String msgDesc = "Time Out, Unable to retrieve catalog list.";
HadoopException hdpException = new HadoopException(msgDesc, sqlt);
hdpException.generateResponseDataMap(false, getMessage(sqlt), msgDesc + ERR_MSG, null, null);
throw hdpException;
} catch (SQLException se) {
- String msg = "Unable to execute SQL [" + sql + "]. ";
+ String msg = "Unable to retrieve catalog list. ";
HadoopException he = new HadoopException(msg, se);
he.generateResponseDataMap(false, getMessage(se), msg + ERR_MSG, null, null);
throw he;
+ } catch (HadoopException he) {
+ throw he;
} finally {
close(rs);
- close(stat);
}
}
return ret;
@@ -350,43 +348,31 @@ private List<String> getSchemas(String needle, List<String> catalogs, List<Strin
List<String> ret = new ArrayList<>();
if (con != null) {
- Statement stat = null;
- ResultSet rs = null;
- String sql = null;
+ ResultSet rs = null;
try {
+ validateSqlIdentifier(needle, "schema pattern");
+ String schemaPattern = convertToSqlPattern(needle);
if (catalogs != null && !catalogs.isEmpty()) {
for (String catalog : catalogs) {
- sql = "SHOW SCHEMAS FROM \"" + escapeSql(catalog) + "\"";
+ validateSqlIdentifier(catalog, "catalog name");
+ rs = con.getMetaData().getSchemas(catalog, schemaPattern);
- try {
- if (needle != null && !needle.isEmpty() && !needle.equals("*")) {
- sql += " LIKE '" + escapeSql(needle) + "%'";
+ while (rs.next()) {
+ String schema = rs.getString("TABLE_SCHEM");
+
+ if (schemas != null && schemas.contains(schema)) {
+ continue;
}
- stat = con.createStatement();
- rs = stat.executeQuery(sql);
-
- while (rs.next()) {
- String schema = rs.getString(1);
-
- if (schemas != null && schemas.contains(schema)) {
- continue;
- }
-
- ret.add(schema);
- }
- } finally {
- close(rs);
- close(stat);
-
- rs = null;
- stat = null;
+ ret.add(schema);
}
+ close(rs);
+ rs = null;
}
}
} catch (SQLTimeoutException sqlt) {
- String msgDesc = "Time Out, Unable to execute SQL [" + sql + "].";
+ String msgDesc = "Time Out, Unable to retrieve schema list.";
HadoopException hdpException = new HadoopException(msgDesc, sqlt);
hdpException.generateResponseDataMap(false, getMessage(sqlt), msgDesc + ERR_MSG, null, null);
@@ -397,7 +383,7 @@ private List<String> getSchemas(String needle, List<String> catalogs, List<Strin
throw hdpException;
} catch (SQLException sqle) {
- String msgDesc = "Unable to execute SQL [" + sql + "].";
+ String msgDesc = "Unable to retrieve schema list.";
HadoopException hdpException = new HadoopException(msgDesc, sqle);
hdpException.generateResponseDataMap(false, getMessage(sqle), msgDesc + ERR_MSG, null, null);
@@ -407,6 +393,10 @@ private List<String> getSchemas(String needle, List<String> catalogs, List<Strin
}
throw hdpException;
+ } catch (HadoopException he) {
+ throw he;
+ } finally {
+ close(rs);
}
}
@@ -418,65 +408,58 @@ private List<String> getTables(String needle, List<String> catalogs, List<String
List<String> ret = new ArrayList<>();
if (con != null) {
- Statement stat = null;
- ResultSet rs = null;
- String sql = null;
+ ResultSet rs = null;
- if (catalogs != null && !catalogs.isEmpty() && schemas != null && !schemas.isEmpty()) {
- try {
+ try {
+ validateSqlIdentifier(needle, "table pattern");
+ String tablePattern = convertToSqlPattern(needle);
+ if (catalogs != null && !catalogs.isEmpty() && schemas != null && !schemas.isEmpty()) {
for (String catalog : catalogs) {
+ validateSqlIdentifier(catalog, "catalog name");
for (String schema : schemas) {
- sql = "SHOW tables FROM \"" + escapeSql(catalog) + "\".\"" + escapeSql(schema) + "\"";
+ validateSqlIdentifier(schema, "schema name");
+ rs = con.getMetaData().getTables(catalog, schema, tablePattern, new String[] {"TABLE", "VIEW"});
- try {
- if (needle != null && !needle.isEmpty() && !needle.equals("*")) {
- sql += " LIKE '" + escapeSql(needle) + "%'";
+ while (rs.next()) {
+ String table = rs.getString("TABLE_NAME");
+
+ if (tables != null && tables.contains(table)) {
+ continue;
}
- stat = con.createStatement();
- rs = stat.executeQuery(sql);
-
- while (rs.next()) {
- String table = rs.getString(1);
-
- if (tables != null && tables.contains(table)) {
- continue;
- }
-
- ret.add(table);
- }
- } finally {
- close(rs);
- close(stat);
-
- rs = null;
- stat = null;
+ ret.add(table);
}
+ close(rs);
+ rs = null;
}
}
- } catch (SQLTimeoutException sqlt) {
- String msgDesc = "Time Out, Unable to execute SQL [" + sql + "].";
- HadoopException hdpException = new HadoopException(msgDesc, sqlt);
-
- hdpException.generateResponseDataMap(false, getMessage(sqlt), msgDesc + ERR_MSG, null, null);
-
- if (LOG.isDebugEnabled()) {
- LOG.debug("<== TrinoClient.getTables() Error : ", sqlt);
- }
-
- throw hdpException;
- } catch (SQLException sqle) {
- String msgDesc = "Unable to execute SQL [" + sql + "].";
- HadoopException hdpException = new HadoopException(msgDesc, sqle);
-
- hdpException.generateResponseDataMap(false, getMessage(sqle), msgDesc + ERR_MSG, null, null);
-
- if (LOG.isDebugEnabled()) {
- LOG.debug("<== TrinoClient.getTables() Error : ", sqle);
- }
-
- throw hdpException;
}
+ } catch (SQLTimeoutException sqlt) {
+ String msgDesc = "Time Out, Unable to retrieve table list.";
+ HadoopException hdpException = new HadoopException(msgDesc, sqlt);
+
+ hdpException.generateResponseDataMap(false, getMessage(sqlt), msgDesc + ERR_MSG, null, null);
+
+ if (LOG.isDebugEnabled()) {
+ LOG.debug("<== TrinoClient.getTables() Error : ", sqlt);
+ }
+
+ throw hdpException;
+ } catch (SQLException sqle) {
+ String msgDesc = "Unable to retrieve table list.";
+ HadoopException hdpException = new HadoopException(msgDesc, sqle);
+
+ hdpException.generateResponseDataMap(false, getMessage(sqle), msgDesc + ERR_MSG, null, null);
+
+ if (LOG.isDebugEnabled()) {
+ LOG.debug("<== TrinoClient.getTables() Error : ", sqle);
+ }
+
+ throw hdpException;
+ } catch (HadoopException he) {
+ throw he;
+ } finally {
+ close(rs);
}
}
@@ -490,72 +473,68 @@ private List<String> getColumns(String needle, List<String> catalogs, List<Strin
if (con != null) {
String regex = null;
ResultSet rs = null;
- String sql = null;
- Statement stat = null;
if (needle != null && !needle.isEmpty()) {
regex = needle;
}
- if (catalogs != null && !catalogs.isEmpty() && schemas != null && !schemas.isEmpty() && tables != null && !tables.isEmpty()) {
- try {
+ try {
+ validateSqlIdentifier(needle, "column pattern");
+ String columnPattern = convertToSqlPattern(needle);
+ if (catalogs != null && !catalogs.isEmpty() && schemas != null && !schemas.isEmpty() && tables != null && !tables.isEmpty()) {
for (String catalog : catalogs) {
+ validateSqlIdentifier(catalog, "catalog name");
for (String schema : schemas) {
+ validateSqlIdentifier(schema, "schema name");
for (String table : tables) {
- sql = "SHOW COLUMNS FROM \"" + escapeSql(catalog) + "\"." +
- "\"" + escapeSql(schema) + "\"." +
- "\"" + escapeSql(table) + "\"";
+ validateSqlIdentifier(table, "table name");
+ rs = con.getMetaData().getColumns(catalog, schema, table, columnPattern);
- try {
- stat = con.createStatement();
- rs = stat.executeQuery(sql);
+ while (rs.next()) {
+ String column = rs.getString("COLUMN_NAME");
- while (rs.next()) {
- String column = rs.getString(1);
-
- if (columns != null && columns.contains(column)) {
- continue;
- }
-
- if (regex == null) {
- ret.add(column);
- } else if (FilenameUtils.wildcardMatch(column, regex)) {
- ret.add(column);
- }
+ if (columns != null && columns.contains(column)) {
+ continue;
}
- } finally {
- close(rs);
- close(stat);
- stat = null;
- rs = null;
+ if (regex == null) {
+ ret.add(column);
+ } else if (FilenameUtils.wildcardMatch(column, regex)) {
+ ret.add(column);
+ }
}
+ close(rs);
+ rs = null;
}
}
}
- } catch (SQLTimeoutException sqlt) {
- String msgDesc = "Time Out, Unable to execute SQL [" + sql + "].";
- HadoopException hdpException = new HadoopException(msgDesc, sqlt);
-
- hdpException.generateResponseDataMap(false, getMessage(sqlt), msgDesc + ERR_MSG, null, null);
-
- if (LOG.isDebugEnabled()) {
- LOG.debug("<== TrinoClient.getColumns() Error : ", sqlt);
- }
-
- throw hdpException;
- } catch (SQLException sqle) {
- String msgDesc = "Unable to execute SQL [" + sql + "].";
- HadoopException hdpException = new HadoopException(msgDesc, sqle);
-
- hdpException.generateResponseDataMap(false, getMessage(sqle), msgDesc + ERR_MSG, null, null);
-
- if (LOG.isDebugEnabled()) {
- LOG.debug("<== TrinoClient.getColumns() Error : ", sqle);
- }
-
- throw hdpException;
}
+ } catch (SQLTimeoutException sqlt) {
+ String msgDesc = "Time Out, Unable to retrieve column list.";
+ HadoopException hdpException = new HadoopException(msgDesc, sqlt);
+
+ hdpException.generateResponseDataMap(false, getMessage(sqlt), msgDesc + ERR_MSG, null, null);
+
+ if (LOG.isDebugEnabled()) {
+ LOG.debug("<== TrinoClient.getColumns() Error : ", sqlt);
+ }
+
+ throw hdpException;
+ } catch (SQLException sqle) {
+ String msgDesc = "Unable to retrieve column list.";
+ HadoopException hdpException = new HadoopException(msgDesc, sqle);
+
+ hdpException.generateResponseDataMap(false, getMessage(sqle), msgDesc + ERR_MSG, null, null);
+
+ if (LOG.isDebugEnabled()) {
+ LOG.debug("<== TrinoClient.getColumns() Error : ", sqle);
+ }
+
+ throw hdpException;
+ } catch (HadoopException he) {
+ throw he;
+ } finally {
+ close(rs);
}
}
@@ -571,11 +550,4 @@ private void close(Connection con) {
LOG.error("Unable to close Trino SQL connection", e);
}
}
-
- private static String escapeSql(String str) {
- if (str == null) {
- return null;
- }
- return StringUtils.replace(str, "'", "''");
- }
}
diff --git a/plugin-trino/src/test/java/org/apache/ranger/services/trino/client/TestTrinoClient.java b/plugin-trino/src/test/java/org/apache/ranger/services/trino/client/TestTrinoClient.java
new file mode 100644
index 0000000..6d1276d
--- /dev/null
+++ b/plugin-trino/src/test/java/org/apache/ranger/services/trino/client/TestTrinoClient.java
@@ -0,0 +1,321 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.ranger.services.trino.client;
+
+import org.apache.ranger.plugin.client.HadoopException;
+import org.junit.jupiter.api.MethodOrderer;
+import org.junit.jupiter.api.Test;
+import org.junit.jupiter.api.TestMethodOrder;
+import org.junit.jupiter.api.extension.ExtendWith;
+import org.mockito.MockedStatic;
+import org.mockito.Mockito;
+import org.mockito.junit.jupiter.MockitoExtension;
+
+import javax.security.auth.Subject;
+
+import java.lang.reflect.Field;
+import java.security.PrivilegedAction;
+import java.sql.Connection;
+import java.sql.DatabaseMetaData;
+import java.sql.DriverManager;
+import java.sql.ResultSet;
+import java.sql.SQLTimeoutException;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertThrows;
+import static org.junit.jupiter.api.Assertions.assertTrue;
+import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.ArgumentMatchers.anyString;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.when;
+
+@ExtendWith(MockitoExtension.class)
+@TestMethodOrder(MethodOrderer.MethodName.class)
+public class TestTrinoClient { // RANGER-5513: verifies TrinoClient lookup APIs reject SQL-injection / path-traversal style input
+    private TrinoClient createMockedClient() throws Exception { // convenience overload: client with default connection properties only
+        return createMockedClient(new HashMap<>());
+    }
+
+    private TrinoClient createMockedClient(Map<String, String> props) throws Exception { // builds a client whose login()/getLoginSubject() are no-ops (see NoConnectionTrinoClient)
+        Map<String, String> config = new HashMap<>();
+        config.put("username", "test");
+        config.put("password", "test");
+        config.put("jdbc.driverClassName", "io.trino.jdbc.TrinoDriver");
+        config.put("jdbc.url", "jdbc:trino://localhost:8080");
+        config.putAll(props); // caller-supplied entries override the defaults above
+
+        NoConnectionTrinoClient client = new NoConnectionTrinoClient("svc", config);
+        return client;
+    }
+
+    @Test
+    public void test01_getCatalogList_normalOperation() throws Exception { // happy path: catalog names from JDBC metadata are returned unchanged
+        try (MockedStatic<DriverManager> dmStatic = Mockito.mockStatic(DriverManager.class);
+                MockedStatic<Subject> subjectStatic = Mockito.mockStatic(Subject.class)) {
+            Connection mockCon = mock(Connection.class);
+            DatabaseMetaData metadata = mock(DatabaseMetaData.class);
+            ResultSet rs = mock(ResultSet.class);
+            dmStatic.when(() -> DriverManager.getConnection(anyString(), any())).thenReturn(mockCon);
+            when(mockCon.getMetaData()).thenReturn(metadata);
+            when(metadata.getCatalogs()).thenReturn(rs);
+            when(rs.next()).thenReturn(true, true, false); // two catalog rows, then end of result set
+            when(rs.getString("TABLE_CAT")).thenReturn("catalog1", "catalog2");
+            subjectStatic.when(() -> Subject.doAs(any(), any(PrivilegedAction.class)))
+                    .thenAnswer(inv -> { // execute the privileged action inline instead of via JAAS
+                        PrivilegedAction<?> action = inv.getArgument(1);
+                        return action.run();
+                    });
+
+            TrinoClient client = createMockedClient();
+            List<String> out = client.getCatalogList("*", null);
+            assertEquals(Arrays.asList("catalog1", "catalog2"), out);
+        }
+    }
+
+    @Test
+    public void test02_getCatalogList_withExcludeList() throws Exception { // catalogs already present in the exclude list are filtered out of the result
+        try (MockedStatic<DriverManager> dmStatic = Mockito.mockStatic(DriverManager.class);
+                MockedStatic<Subject> subjectStatic = Mockito.mockStatic(Subject.class)) {
+            Connection mockCon = mock(Connection.class);
+            DatabaseMetaData metadata = mock(DatabaseMetaData.class);
+            ResultSet rs = mock(ResultSet.class);
+            dmStatic.when(() -> DriverManager.getConnection(anyString(), any())).thenReturn(mockCon);
+            when(mockCon.getMetaData()).thenReturn(metadata);
+            when(metadata.getCatalogs()).thenReturn(rs);
+            when(rs.next()).thenReturn(true, true, false);
+            when(rs.getString("TABLE_CAT")).thenReturn("catalog1", "catalog2");
+            subjectStatic.when(() -> Subject.doAs(any(), any(PrivilegedAction.class)))
+                    .thenAnswer(inv -> {
+                        PrivilegedAction<?> action = inv.getArgument(1);
+                        return action.run();
+                    });
+
+            TrinoClient client = createMockedClient();
+            List<String> out = client.getCatalogList("*", Collections.singletonList("catalog1")); // "catalog1" excluded
+            assertEquals(Collections.singletonList("catalog2"), out);
+        }
+    }
+
+    @Test
+    public void test03_getCatalogList_timeoutThrowsHadoopException() throws Exception { // SQLTimeoutException from the driver must surface as HadoopException
+        try (MockedStatic<DriverManager> dmStatic = Mockito.mockStatic(DriverManager.class);
+                MockedStatic<Subject> subjectStatic = Mockito.mockStatic(Subject.class)) {
+            Connection mockCon = mock(Connection.class);
+            DatabaseMetaData metadata = mock(DatabaseMetaData.class);
+            dmStatic.when(() -> DriverManager.getConnection(anyString(), any())).thenReturn(mockCon);
+            when(mockCon.getMetaData()).thenReturn(metadata);
+            when(metadata.getCatalogs()).thenThrow(SQLTimeoutException.class);
+            subjectStatic.when(() -> Subject.doAs(any(), any(PrivilegedAction.class)))
+                    .thenAnswer(inv -> {
+                        PrivilegedAction<?> action = inv.getArgument(1);
+                        return action.run();
+                    });
+
+            TrinoClient client = createMockedClient();
+            assertThrows(HadoopException.class, () -> client.getCatalogList("*", null));
+        }
+    }
+
+    @Test
+    public void test04_validateSqlIdentifier_rejectsSqlInjection() throws Exception { // injection/traversal payloads in the lookup string must be rejected up front
+        try (MockedStatic<DriverManager> dmStatic = Mockito.mockStatic(DriverManager.class);
+                MockedStatic<Subject> subjectStatic = Mockito.mockStatic(Subject.class)) {
+            Connection mockCon = mock(Connection.class); // NOTE(review): validation presumably throws before any JDBC call, so this stub may go unused — confirm strict stubbing does not flag it
+            dmStatic.when(() -> DriverManager.getConnection(anyString(), any())).thenReturn(mockCon);
+            subjectStatic.when(() -> Subject.doAs(any(), any(PrivilegedAction.class)))
+                    .thenAnswer(inv -> {
+                        PrivilegedAction<?> action = inv.getArgument(1);
+                        return action.run();
+                    });
+
+            TrinoClient client = createMockedClient();
+
+            List<String> maliciousInputs = Arrays.asList("'; DROP TABLE users; --", "test\" OR 1=1 --", "catalog'; DELETE FROM users; --", "../../../etc/passwd", "test<script>alert(1)</script>");
+
+            for (String input : maliciousInputs) {
+                HadoopException ex = assertThrows(HadoopException.class,
+                        () -> client.getCatalogList(input, null),
+                        "SQL injection attempt should be rejected: " + input);
+                assertTrue(ex.getMessage().contains("Invalid"),
+                        "Error should indicate invalid input for: " + input);
+            }
+        }
+    }
+
+    @Test
+    public void test05_validateSqlIdentifier_rejectsSpecialCharacters() throws Exception { // characters outside the identifier whitelist must be rejected
+        try (MockedStatic<DriverManager> dmStatic = Mockito.mockStatic(DriverManager.class);
+                MockedStatic<Subject> subjectStatic = Mockito.mockStatic(Subject.class)) {
+            Connection mockCon = mock(Connection.class); // see NOTE in test04: stub likely unused once validation throws
+            dmStatic.when(() -> DriverManager.getConnection(anyString(), any())).thenReturn(mockCon);
+            subjectStatic.when(() -> Subject.doAs(any(), any(PrivilegedAction.class)))
+                    .thenAnswer(inv -> {
+                        PrivilegedAction<?> action = inv.getArgument(1);
+                        return action.run();
+                    });
+
+            TrinoClient client = createMockedClient();
+
+            List<String> invalidInputs = Arrays.asList("test@catalog", "catalog#name", "table!name", "name(with)parens");
+
+            for (String input : invalidInputs) {
+                HadoopException ex = assertThrows(HadoopException.class,
+                        () -> client.getCatalogList(input, null),
+                        "Special characters should be rejected: " + input);
+                assertTrue(ex.getMessage().contains("Invalid"),
+                        "Error should indicate invalid input for: " + input);
+            }
+        }
+    }
+
+    @Test
+    public void test06_schemaValidation_rejectsSqlInjection() throws Exception { // getSchemaList: injection payload in the schema lookup string is rejected
+        try (MockedStatic<DriverManager> dmStatic = Mockito.mockStatic(DriverManager.class);
+                MockedStatic<Subject> subjectStatic = Mockito.mockStatic(Subject.class)) {
+            Connection mockCon = mock(Connection.class);
+            dmStatic.when(() -> DriverManager.getConnection(anyString(), any())).thenReturn(mockCon);
+            subjectStatic.when(() -> Subject.doAs(any(), any(PrivilegedAction.class)))
+                    .thenAnswer(inv -> {
+                        PrivilegedAction<?> action = inv.getArgument(1);
+                        return action.run();
+                    });
+
+            TrinoClient client = createMockedClient();
+            Field fCon = TrinoClient.class.getDeclaredField("con"); // inject the mocked connection directly; avoids a real JDBC connect
+            fCon.setAccessible(true);
+            fCon.set(client, mockCon);
+
+            HadoopException ex = assertThrows(HadoopException.class,
+                    () -> client.getSchemaList("'; DROP TABLE users; --", Collections.singletonList("catalog1"), null));
+            assertTrue(ex.getMessage().contains("Invalid"));
+        }
+    }
+
+    @Test
+    public void test07_tableValidation_rejectsSqlInjection() throws Exception { // getTableList: injection payload in the table lookup string is rejected
+        try (MockedStatic<DriverManager> dmStatic = Mockito.mockStatic(DriverManager.class);
+                MockedStatic<Subject> subjectStatic = Mockito.mockStatic(Subject.class)) {
+            Connection mockCon = mock(Connection.class); // see NOTE in test04: stub likely unused once validation throws
+            dmStatic.when(() -> DriverManager.getConnection(anyString(), any())).thenReturn(mockCon);
+            subjectStatic.when(() -> Subject.doAs(any(), any(PrivilegedAction.class)))
+                    .thenAnswer(inv -> {
+                        PrivilegedAction<?> action = inv.getArgument(1);
+                        return action.run();
+                    });
+
+            TrinoClient client = createMockedClient();
+
+            HadoopException ex = assertThrows(HadoopException.class,
+                    () -> client.getTableList("'; DROP TABLE users; --", Collections.singletonList("catalog1"), Collections.singletonList("schema1"), null));
+            assertTrue(ex.getMessage().contains("Invalid"));
+        }
+    }
+
+    @Test
+    public void test08_columnValidation_rejectsSqlInjection() throws Exception { // getColumnList: injection payload in the column lookup string is rejected
+        try (MockedStatic<DriverManager> dmStatic = Mockito.mockStatic(DriverManager.class);
+                MockedStatic<Subject> subjectStatic = Mockito.mockStatic(Subject.class)) {
+            Connection mockCon = mock(Connection.class); // see NOTE in test04: stub likely unused once validation throws
+            dmStatic.when(() -> DriverManager.getConnection(anyString(), any())).thenReturn(mockCon);
+            subjectStatic.when(() -> Subject.doAs(any(), any(PrivilegedAction.class)))
+                    .thenAnswer(inv -> {
+                        PrivilegedAction<?> action = inv.getArgument(1);
+                        return action.run();
+                    });
+
+            TrinoClient client = createMockedClient();
+
+            HadoopException ex = assertThrows(HadoopException.class,
+                    () -> client.getColumnList("'; DROP TABLE users; --", Collections.singletonList("catalog1"), Collections.singletonList("schema1"), Collections.singletonList("table1"), null));
+            assertTrue(ex.getMessage().contains("Invalid"));
+        }
+    }
+
+    @Test
+    public void test09_catalogName_validateInList() throws Exception { // entries of the caller-supplied catalog list are validated, not just the lookup string
+        try (MockedStatic<DriverManager> dmStatic = Mockito.mockStatic(DriverManager.class);
+                MockedStatic<Subject> subjectStatic = Mockito.mockStatic(Subject.class)) {
+            Connection mockCon = mock(Connection.class);
+            DatabaseMetaData metadata = mock(DatabaseMetaData.class);
+            ResultSet rs = mock(ResultSet.class);
+            dmStatic.when(() -> DriverManager.getConnection(anyString(), any())).thenReturn(mockCon);
+            when(mockCon.getMetaData()).thenReturn(metadata);
+            when(metadata.getSchemas(anyString(), anyString())).thenReturn(rs); // presumably consumed for the valid "catalog1" entry before the invalid one throws — TODO confirm
+            when(rs.next()).thenReturn(false);
+            subjectStatic.when(() -> Subject.doAs(any(), any(PrivilegedAction.class)))
+                    .thenAnswer(inv -> {
+                        PrivilegedAction<?> action = inv.getArgument(1);
+                        return action.run();
+                    });
+
+            TrinoClient client = createMockedClient();
+            Field fCon = TrinoClient.class.getDeclaredField("con"); // inject the mocked connection directly; avoids a real JDBC connect
+            fCon.setAccessible(true);
+            fCon.set(client, mockCon);
+
+            HadoopException ex = assertThrows(HadoopException.class,
+                    () -> client.getSchemaList("schema1", Arrays.asList("catalog1", "'; DROP TABLE --"), null));
+            assertTrue(ex.getMessage().contains("Invalid"));
+        }
+    }
+
+    @Test
+    public void test10_schemaName_validateInList() throws Exception { // entries of the caller-supplied schema list are validated as well
+        try (MockedStatic<DriverManager> dmStatic = Mockito.mockStatic(DriverManager.class);
+                MockedStatic<Subject> subjectStatic = Mockito.mockStatic(Subject.class)) {
+            Connection mockCon = mock(Connection.class);
+            DatabaseMetaData metadata = mock(DatabaseMetaData.class);
+            ResultSet rs = mock(ResultSet.class);
+            dmStatic.when(() -> DriverManager.getConnection(anyString(), any())).thenReturn(mockCon);
+            when(mockCon.getMetaData()).thenReturn(metadata);
+            when(metadata.getTables(anyString(), anyString(), anyString(), any())).thenReturn(rs); // presumably consumed for the valid "schema1" entry before the invalid one throws — TODO confirm
+            when(rs.next()).thenReturn(false);
+            subjectStatic.when(() -> Subject.doAs(any(), any(PrivilegedAction.class)))
+                    .thenAnswer(inv -> {
+                        PrivilegedAction<?> action = inv.getArgument(1);
+                        return action.run();
+                    });
+
+            TrinoClient client = createMockedClient();
+            Field fCon = TrinoClient.class.getDeclaredField("con"); // inject the mocked connection directly; avoids a real JDBC connect
+            fCon.setAccessible(true);
+            fCon.set(client, mockCon);
+
+            HadoopException ex = assertThrows(HadoopException.class,
+                    () -> client.getTableList("table1", Collections.singletonList("catalog1"), Arrays.asList("schema1", "'; DROP --"), null));
+            assertTrue(ex.getMessage().contains("Invalid"));
+        }
+    }
+
+    private static class NoConnectionTrinoClient extends TrinoClient { // test double: skips real JAAS login and JDBC connection setup
+        public NoConnectionTrinoClient(String serviceName, Map<String, String> connectionProperties) {
+            super(serviceName, connectionProperties);
+        }
+
+        @Override
+        protected Subject getLoginSubject() { // fresh, credential-free Subject so Subject.doAs runs without a real login context
+            Subject subject = new Subject();
+            return subject;
+        }
+
+        @Override
+        protected void login() { // intentionally a no-op: tests never authenticate
+        }
+    }
+}