[MINOR] Tighten origin and content-type handling in REST/WebSocket layer
## What is this PR for?
Apply stricter defaults to the request-handling layer for tighter out-of-the-box behavior:
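- `CorsFilter` rejects with `403` any state-changing request (POST/PUT/DELETE/PATCH) and any CORS preflight (`OPTIONS` with `Access-Control-Request-Method`) that carries an `Origin` header which `CorsUtils.isValidOrigin` does not accept. Requests without an `Origin` header are not blocked. Other methods (e.g. `GET`) pass through with an empty `Access-Control-Allow-Origin`. `Access-Control-Allow-Credentials` is only sent when the `Origin` is allowed.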
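- The default value of `zeppelin.server.allowed.origins` changes from `*` to empty, so cross-origin browser access must be explicitly enabled. **Operators relying on the previous default need to set this back to `*` or to specific origin(s).** The following are still accepted:
  - requests without an `Origin` header (e.g. most non-browser clients);
  - browser requests whose `Origin` host is `localhost`;
  - browser requests whose `Origin` host is the server's own hostname (as reported by `InetAddress.getLocalHost()`).

  Browsers also send `Origin` on same-origin POST/PUT/DELETE requests. A UI reached through any other hostname (for example a DNS alias or a reverse proxy) must therefore have that origin listed explicitly, or its state-changing requests will be rejected with `403`.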
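- A new Jersey request filter restricts REST request bodies on state-changing methods to `application/json`, `application/x-www-form-urlencoded`, or `multipart/form-data`. Media-type parameters such as `charset` are ignored. A request that carries a body with any other `Content-Type`, or with none, is rejected with `415`.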
- The default `shiro.ini.template` now sets `cookie.sameSite = LAX`.
- `ZeppelinClient` now sends an explicit `Content-Type: application/json` header on its POST/PUT requests (including `newSession`, `addParagraph`, and `updateParagraph`) so they pass the new filter.
- `CorsUtils.isValidOrigin` normalizes the `Origin` header to lowercase before the allow-list membership check, mirroring how the configured origins are stored, so case differences in the `Origin` header do not produce false rejections.
- A small `HttpMethods` utility holds the shared `STATE_CHANGING` method set used by both the servlet filter and the Jersey filter.
## What type of PR is it?
Improvement
## Todos
- [ ] CI green
## Questions
- None
## Screenshots (if appropriate)
N/A
Closes #5229 from jongyoul/minor-cors-hardening.
Signed-off-by: Jongyoul Lee <jongyoul@gmail.com>
diff --git a/conf/shiro.ini.template b/conf/shiro.ini.template
index 6721d17..24b18e2 100644
--- a/conf/shiro.ini.template
+++ b/conf/shiro.ini.template
@@ -87,6 +87,9 @@
cookie = org.apache.shiro.web.servlet.SimpleCookie
cookie.name = JSESSIONID
cookie.httpOnly = true
+### Send the session cookie only on same-site requests and cross-site top-level GET navigations (LAX).
+### Set to NONE only when Zeppelin is intentionally embedded in a different site; browsers then also require 'cookie.secure = true'.
+cookie.sameSite = LAX
### Uncomment the below line only when Zeppelin is running over HTTPS
#cookie.secure = true
sessionManager.sessionIdCookie = $cookie
diff --git a/zeppelin-client/src/main/java/org/apache/zeppelin/client/ZeppelinClient.java b/zeppelin-client/src/main/java/org/apache/zeppelin/client/ZeppelinClient.java
index a21ef99..f5db6f1 100644
--- a/zeppelin-client/src/main/java/org/apache/zeppelin/client/ZeppelinClient.java
+++ b/zeppelin-client/src/main/java/org/apache/zeppelin/client/ZeppelinClient.java
@@ -139,6 +139,7 @@
public SessionInfo newSession(String interpreter) throws Exception {
HttpResponse<JsonNode> response = Unirest
.post("/session")
+ .header("Content-Type", "application/json")
.queryString("interpreter", interpreter)
.asJson();
checkResponse(response);
@@ -307,6 +308,7 @@
bodyObject.put("defaultInterpreterGroup", defaultInterpreterGroup);
HttpResponse<JsonNode> response = Unirest
.post("/notebook")
+ .header("Content-Type", "application/json")
.body(bodyObject.toString())
.asJson();
checkResponse(response);
@@ -346,6 +348,7 @@
HttpResponse<JsonNode> response = Unirest
.post("/notebook/{noteId}")
.routeParam("noteId", noteId)
+ .header("Content-Type", "application/json")
.body(bodyObject.toString())
.asJson();
checkResponse(response);
@@ -362,6 +365,7 @@
HttpResponse<JsonNode> response = Unirest
.put("/notebook/{noteId}/rename")
.routeParam("noteId", noteId)
+ .header("Content-Type", "application/json")
.body(bodyObject.toString())
.asJson();
@@ -397,6 +401,7 @@
bodyObject.put("notePath", notePath);
HttpResponse<JsonNode> response = Unirest
.post("/notebook/getByPath")
+ .header("Content-Type", "application/json")
.body(bodyObject)
.asJson();
return extractNoteResultFromResponse(response);
@@ -494,6 +499,7 @@
.routeParam("noteId", noteId)
.queryString("blocking", "false")
.queryString("isolated", "true")
+ .header("Content-Type", "application/json")
.body(bodyObject)
.asJson();
checkResponse(response);
@@ -531,6 +537,7 @@
HttpResponse<JsonNode> response = Unirest
.post("/notebook/import")
.queryString("notePath", notePath)
+ .header("Content-Type", "application/json")
.body(bodyObject)
.asJson();
checkResponse(response);
@@ -592,6 +599,7 @@
bodyObject.put("text", text);
HttpResponse<JsonNode> response = Unirest.post("/notebook/{noteId}/paragraph")
.routeParam("noteId", noteId)
+ .header("Content-Type", "application/json")
.body(bodyObject.toString())
.asJson();
checkResponse(response);
@@ -617,6 +625,7 @@
HttpResponse<JsonNode> response = Unirest.put("/notebook/{noteId}/paragraph/{paragraphId}")
.routeParam("noteId", noteId)
.routeParam("paragraphId", paragraphId)
+ .header("Content-Type", "application/json")
.body(bodyObject.toString())
.asJson();
checkResponse(response);
@@ -708,6 +717,7 @@
.routeParam("noteId", noteId)
.routeParam("paragraphId", paragraphId)
.queryString("sessionId", sessionId)
+ .header("Content-Type", "application/json")
.body(bodyObject.toString())
.asJson();
checkResponse(response);
@@ -773,6 +783,7 @@
HttpResponse<JsonNode> response = Unirest
.post("/notebook/{noteId}/paragraph/next")
.routeParam("noteId", noteId)
+ .header("Content-Type", "application/json")
.queryString("maxParagraph", maxParagraph)
.asJson();
checkResponse(response);
@@ -892,6 +903,7 @@
HttpResponse<JsonNode> response = Unirest
.put("/interpreter/setting/restart/{interpreter}")
.routeParam("interpreter", interpreter)
+ .header("Content-Type", "application/json")
.body(bodyObject.toString())
.asJson();
checkResponse(response);
diff --git a/zeppelin-server/src/main/java/org/apache/zeppelin/conf/ZeppelinConfiguration.java b/zeppelin-server/src/main/java/org/apache/zeppelin/conf/ZeppelinConfiguration.java
index 3b9ebee..958e4a5 100644
--- a/zeppelin-server/src/main/java/org/apache/zeppelin/conf/ZeppelinConfiguration.java
+++ b/zeppelin-server/src/main/java/org/apache/zeppelin/conf/ZeppelinConfiguration.java
@@ -27,6 +27,7 @@
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
+import java.util.Locale;
import java.util.Map;
import java.util.Optional;
import java.util.OptionalInt;
@@ -726,7 +727,8 @@
return Collections.emptyList();
}
- return Arrays.asList(getString(ConfVars.ZEPPELIN_ALLOWED_ORIGINS).toLowerCase().split(","));
+ return Arrays.asList(
+ getString(ConfVars.ZEPPELIN_ALLOWED_ORIGINS).toLowerCase(Locale.ROOT).split(","));
}
public String getWebsocketMaxTextMessageSize() {
@@ -1045,7 +1047,9 @@
"https://github.com/yarnpkg/yarn/releases/download/"),
// Allows a way to specify a ',' separated list of allowed origins for rest and websockets
// i.e. http://localhost:8080
- ZEPPELIN_ALLOWED_ORIGINS("zeppelin.server.allowed.origins", "*"),
+ // Default is empty (no cross-origin requests permitted). Operators that need cross-origin
+ // access must set this explicitly to the trusted origin(s) or to "*".
+ ZEPPELIN_ALLOWED_ORIGINS("zeppelin.server.allowed.origins", ""),
ZEPPELIN_USERNAME_FORCE_LOWERCASE("zeppelin.username.force.lowercase", false),
ZEPPELIN_CREDENTIALS_PERSIST("zeppelin.credentials.persist", true),
ZEPPELIN_CREDENTIALS_ENCRYPT_KEY("zeppelin.credentials.encryptKey", null),
diff --git a/zeppelin-server/src/main/java/org/apache/zeppelin/rest/filter/AllowedContentTypeFilter.java b/zeppelin-server/src/main/java/org/apache/zeppelin/rest/filter/AllowedContentTypeFilter.java
new file mode 100644
index 0000000..3abe178
--- /dev/null
+++ b/zeppelin-server/src/main/java/org/apache/zeppelin/rest/filter/AllowedContentTypeFilter.java
@@ -0,0 +1,72 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.zeppelin.rest.filter;
+
+import java.util.Locale;
+import java.util.Set;
+
+import jakarta.ws.rs.container.ContainerRequestContext;
+import jakarta.ws.rs.container.ContainerRequestFilter;
+import jakarta.ws.rs.core.MediaType;
+import jakarta.ws.rs.core.Response;
+import jakarta.ws.rs.ext.Provider;
+
+import org.apache.zeppelin.utils.HttpMethods;
+
+/**
+ * Jersey request filter that limits the request body media types accepted by REST endpoints.
+ * <p>
+ * Only requests whose method is state-changing (see {@link HttpMethods#STATE_CHANGING}) and
+ * that carry an entity are inspected. Their {@code Content-Type} base type (parameters such
+ * as {@code charset} are ignored) must be {@code application/json},
+ * {@code application/x-www-form-urlencoded}, or {@code multipart/form-data}; any other type,
+ * or a missing {@code Content-Type}, aborts the request with {@code 415}.
+ */
+@Provider
+public class AllowedContentTypeFilter implements ContainerRequestFilter {
+
+ private static final Set<String> ALLOWED_TYPES = Set.of(
+ "application/json",
+ "application/x-www-form-urlencoded",
+ "multipart/form-data");
+
+ @Override
+ public void filter(ContainerRequestContext ctx) {
+ if (requiresCheck(ctx) && !isAllowed(ctx.getMediaType())) {
+ ctx.abortWith(Response.status(Response.Status.UNSUPPORTED_MEDIA_TYPE)
+ .entity("Unsupported Content-Type")
+ .build());
+ }
+ }
+
+ /** Returns true for a state-changing request (method matched case-insensitively) with a body. */
+ private static boolean requiresCheck(ContainerRequestContext ctx) {
+ String method = ctx.getMethod();
+ return method != null
+ && HttpMethods.STATE_CHANGING.contains(method.toUpperCase(Locale.ROOT))
+ && ctx.hasEntity();
+ }
+
+ /** Returns true when {@code mt} is present and its lowercased type/subtype is allow-listed. */
+ private static boolean isAllowed(MediaType mt) {
+ if (mt == null) {
+ return false;
+ }
+ String base = (mt.getType() + "/" + mt.getSubtype()).toLowerCase(Locale.ROOT);
+ return ALLOWED_TYPES.contains(base);
+ }
+}
diff --git a/zeppelin-server/src/main/java/org/apache/zeppelin/server/CorsFilter.java b/zeppelin-server/src/main/java/org/apache/zeppelin/server/CorsFilter.java
index 51906cf..24d030b 100644
--- a/zeppelin-server/src/main/java/org/apache/zeppelin/server/CorsFilter.java
+++ b/zeppelin-server/src/main/java/org/apache/zeppelin/server/CorsFilter.java
@@ -18,6 +18,8 @@
import java.io.IOException;
import java.net.URISyntaxException;
+import java.net.UnknownHostException;
+import java.util.Locale;
import jakarta.servlet.Filter;
import jakarta.servlet.FilterChain;
import jakarta.servlet.FilterConfig;
@@ -28,6 +30,7 @@
import jakarta.servlet.http.HttpServletResponse;
import org.apache.zeppelin.conf.ZeppelinConfiguration;
import org.apache.zeppelin.utils.CorsUtils;
+import org.apache.zeppelin.utils.HttpMethods;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@@ -46,33 +49,56 @@
@Override
public void doFilter(ServletRequest request, ServletResponse response, FilterChain filterChain)
throws IOException, ServletException {
- String sourceHost = ((HttpServletRequest) request).getHeader("Origin");
- String origin = "";
+ HttpServletRequest httpRequest = (HttpServletRequest) request;
+ HttpServletResponse httpResponse = (HttpServletResponse) response;
- try {
- if (CorsUtils.isValidOrigin(sourceHost, zConf)) {
- origin = sourceHost;
+ String sourceHost = httpRequest.getHeader(CorsUtils.HEADER_ORIGIN);
+ String method = httpRequest.getMethod();
+ // Origin to echo in Access-Control-Allow-Origin; stays empty unless the Origin is accepted.
+ String allowedOrigin = "";
+
+ // Requests without an Origin header (e.g. non-browser clients) skip the origin check.
+ if (sourceHost != null && !sourceHost.isEmpty()) {
+ try {
+ if (CorsUtils.isValidOrigin(sourceHost, zConf)) {
+ allowedOrigin = sourceHost;
+ }
+ } catch (URISyntaxException e) {
+ LOGGER.warn("Rejecting request with malformed Origin header: {}",
+ sanitizeForLog(sourceHost));
+ } catch (UnknownHostException e) {
+ // Treat as not allowed so a misconfigured host doesn't surface as a 500.
+ LOGGER.warn("Cannot resolve local host for Origin check; treating Origin {} as not allowed",
+ sanitizeForLog(sourceHost));
}
- } catch (URISyntaxException e) {
- LOGGER.error("Exception in WebDriverManager while getWebDriver ", e);
+
+ // A rejected Origin may still issue other methods (e.g. GET); those continue below with an
+ // empty Access-Control-Allow-Origin. Preflights and state-changing methods are refused here.
+ if (allowedOrigin.isEmpty() && (isCorsPreflight(httpRequest) || isStateChanging(method))) {
+ LOGGER.warn("Blocking cross-origin {} request from disallowed Origin: {}",
+ method, sanitizeForLog(sourceHost));
+ httpResponse.sendError(HttpServletResponse.SC_FORBIDDEN, "Origin not allowed");
+ return;
+ }
}
- if (((HttpServletRequest) request).getMethod().equals("OPTIONS")) {
- HttpServletResponse resp = ((HttpServletResponse) response);
- addCorsHeaders(resp, origin);
+ addCorsHeaders(httpResponse, allowedOrigin);
+ // Preflights reaching this point (allowed Origin, or no Origin header) are answered here
+ // and never forwarded down the filter chain.
+ if (isCorsPreflight(httpRequest)) {
return;
}
-
- if (response instanceof HttpServletResponse) {
- HttpServletResponse alteredResponse = ((HttpServletResponse) response);
- addCorsHeaders(alteredResponse, origin);
- }
filterChain.doFilter(request, response);
}
+
+ /**
+ * Prepares an untrusted header value for logging by replacing CR, LF and TAB with '_', so a
+ * crafted Origin header cannot forge or split log entries. Returns null for null input.
+ */
+ private static String sanitizeForLog(String value) {
+ return value == null ? null : value.replaceAll("[\\r\\n\\t]", "_");
+ }
+
+ /**
+ * Returns true for a CORS preflight: an OPTIONS request (method compared case-insensitively)
+ * that carries an Access-Control-Request-Method header. A bare OPTIONS request is not one.
+ */
+ private static boolean isCorsPreflight(HttpServletRequest request) {
+ if (!"OPTIONS".equalsIgnoreCase(request.getMethod())) {
+ return false;
+ }
+ return request.getHeader("Access-Control-Request-Method") != null;
+ }
+
+ /**
+ * Returns true when {@code method} names one of {@link HttpMethods#STATE_CHANGING},
+ * ignoring case; a null method is never state-changing.
+ */
+ private static boolean isStateChanging(String method) {
+ if (method == null) {
+ return false;
+ }
+ String normalized = method.toUpperCase(Locale.ROOT);
+ return HttpMethods.STATE_CHANGING.contains(normalized);
+ }
+
private void addCorsHeaders(HttpServletResponse response, String origin) {
response.setHeader("Access-Control-Allow-Origin", origin);
- response.setHeader("Access-Control-Allow-Credentials", "true");
+ if (!origin.isEmpty()) {
+ response.setHeader("Access-Control-Allow-Credentials", "true");
+ }
response.setHeader("Access-Control-Allow-Headers", "authorization,Content-Type");
response.setHeader("Access-Control-Allow-Methods", "POST, GET, OPTIONS, PUT, HEAD, DELETE");
diff --git a/zeppelin-server/src/main/java/org/apache/zeppelin/server/RestApiApplication.java b/zeppelin-server/src/main/java/org/apache/zeppelin/server/RestApiApplication.java
index f4cff80..b43fe01 100644
--- a/zeppelin-server/src/main/java/org/apache/zeppelin/server/RestApiApplication.java
+++ b/zeppelin-server/src/main/java/org/apache/zeppelin/server/RestApiApplication.java
@@ -36,6 +36,7 @@
import org.apache.zeppelin.rest.ZeppelinRestApi;
import org.apache.zeppelin.rest.exception.WebApplicationExceptionMapper;
import org.apache.zeppelin.rest.filter.CacheControlFilter;
+import org.apache.zeppelin.rest.filter.AllowedContentTypeFilter;
import org.glassfish.jersey.server.ServerProperties;
public class RestApiApplication extends Application {
@@ -60,6 +61,7 @@
s.add(GsonProvider.class);
// Filter
s.add(CacheControlFilter.class);
+ s.add(AllowedContentTypeFilter.class);
return s;
}
diff --git a/zeppelin-server/src/main/java/org/apache/zeppelin/utils/CorsUtils.java b/zeppelin-server/src/main/java/org/apache/zeppelin/utils/CorsUtils.java
index 363bfaf..1d20783 100644
--- a/zeppelin-server/src/main/java/org/apache/zeppelin/utils/CorsUtils.java
+++ b/zeppelin-server/src/main/java/org/apache/zeppelin/utils/CorsUtils.java
@@ -20,6 +20,7 @@
import java.net.URI;
import java.net.URISyntaxException;
import java.net.UnknownHostException;
+import java.util.Locale;
import org.apache.zeppelin.conf.ZeppelinConfiguration;
public class CorsUtils {
@@ -36,15 +37,19 @@
if (sourceHost != null && !sourceHost.isEmpty()) {
sourceUriHost = new URI(sourceHost).getHost();
- sourceUriHost = (sourceUriHost == null) ? "" : sourceUriHost.toLowerCase();
+ sourceUriHost = (sourceUriHost == null) ? "" : sourceUriHost.toLowerCase(Locale.ROOT);
}
- sourceUriHost = sourceUriHost.toLowerCase();
- String currentHost = InetAddress.getLocalHost().getHostName().toLowerCase();
+ String currentHost = InetAddress.getLocalHost().getHostName().toLowerCase(Locale.ROOT);
+ // getAllowedOrigins() returns lowercased entries; normalize sourceHost the same way
+ // before the membership check so case differences in the Origin header do not produce
+ // false rejections of explicitly configured origins.
+ String normalizedOrigin =
+ sourceHost == null ? "" : sourceHost.toLowerCase(Locale.ROOT);
return zConf.getAllowedOrigins().contains("*")
|| currentHost.equals(sourceUriHost)
|| "localhost".equals(sourceUriHost)
- || zConf.getAllowedOrigins().contains(sourceHost);
+ || zConf.getAllowedOrigins().contains(normalizedOrigin);
}
}
diff --git a/zeppelin-server/src/main/java/org/apache/zeppelin/utils/HttpMethods.java b/zeppelin-server/src/main/java/org/apache/zeppelin/utils/HttpMethods.java
new file mode 100644
index 0000000..440ab7f
--- /dev/null
+++ b/zeppelin-server/src/main/java/org/apache/zeppelin/utils/HttpMethods.java
@@ -0,0 +1,37 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.zeppelin.utils;
+
+import java.util.Set;
+
+/**
+ * HTTP method constants shared by the servlet-level {@code CorsFilter} and the Jersey
+ * {@code AllowedContentTypeFilter}.
+ */
+public final class HttpMethods {
+
+ /**
+ * Upper-case names of the methods treated as state-changing: POST, PUT, DELETE and PATCH.
+ * The set is immutable and its lookups are case-sensitive, so callers should upper-case the
+ * request method (with {@code Locale.ROOT}) before calling {@code contains}.
+ */
+ public static final Set<String> STATE_CHANGING = Set.of("POST", "PUT", "DELETE", "PATCH");
+
+ /** Constants holder; not meant to be instantiated. */
+ private HttpMethods() {
+ }
+}
diff --git a/zeppelin-server/src/test/java/org/apache/zeppelin/conf/ZeppelinConfigurationTest.java b/zeppelin-server/src/test/java/org/apache/zeppelin/conf/ZeppelinConfigurationTest.java
index a0ab71e..63fcd0d 100644
--- a/zeppelin-server/src/test/java/org/apache/zeppelin/conf/ZeppelinConfigurationTest.java
+++ b/zeppelin-server/src/test/java/org/apache/zeppelin/conf/ZeppelinConfigurationTest.java
@@ -53,7 +53,7 @@
ZeppelinConfiguration zConf = ZeppelinConfiguration.load("zeppelin-test-site.xml");
List<String> origins = zConf.getAllowedOrigins();
- assertEquals(1, origins.size());
+ assertTrue(origins.isEmpty());
}
@Test
diff --git a/zeppelin-server/src/test/java/org/apache/zeppelin/rest/AbstractTestRestApi.java b/zeppelin-server/src/test/java/org/apache/zeppelin/rest/AbstractTestRestApi.java
index 9662421..6e3b4d8 100644
--- a/zeppelin-server/src/test/java/org/apache/zeppelin/rest/AbstractTestRestApi.java
+++ b/zeppelin-server/src/test/java/org/apache/zeppelin/rest/AbstractTestRestApi.java
@@ -211,7 +211,7 @@
LOGGER.info("Connecting to {}", getUrlToTest(zConf) + path);
HttpPut httpPut = new HttpPut(getUrlToTest(zConf) + path);
httpPut.addHeader("Origin", getUrlToTest(zConf));
- httpPut.setEntity(new StringEntity(body, ContentType.TEXT_PLAIN));
+ httpPut.setEntity(new StringEntity(body, ContentType.APPLICATION_JSON));
if (userAndPasswordAreNotBlank(user, pwd)) {
httpPut.setHeader("Cookie", "JSESSIONID=" + getCookie(user, pwd));
}
diff --git a/zeppelin-server/src/test/java/org/apache/zeppelin/rest/filter/AllowedContentTypeFilterTest.java b/zeppelin-server/src/test/java/org/apache/zeppelin/rest/filter/AllowedContentTypeFilterTest.java
new file mode 100644
index 0000000..b5f3f03
--- /dev/null
+++ b/zeppelin-server/src/test/java/org/apache/zeppelin/rest/filter/AllowedContentTypeFilterTest.java
@@ -0,0 +1,126 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.zeppelin.rest.filter;
+
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.never;
+import static org.mockito.Mockito.times;
+import static org.mockito.Mockito.verify;
+import static org.mockito.Mockito.when;
+
+import jakarta.ws.rs.container.ContainerRequestContext;
+import jakarta.ws.rs.core.MediaType;
+import jakarta.ws.rs.core.Response;
+
+import org.junit.jupiter.api.Test;
+import org.junit.jupiter.params.ParameterizedTest;
+import org.junit.jupiter.params.provider.ValueSource;
+import org.mockito.ArgumentCaptor;
+
+/**
+ * Unit tests for {@link AllowedContentTypeFilter}: which methods are inspected and which
+ * {@code Content-Type} values are accepted or rejected with {@code 415}.
+ */
+class AllowedContentTypeFilterTest {
+
+ private final AllowedContentTypeFilter filter = new AllowedContentTypeFilter();
+
+ @Test
+ void getRequestPasses() {
+ // GET is not state-changing, so even a text/plain body is not inspected.
+ ContainerRequestContext ctx = request("GET", true, MediaType.TEXT_PLAIN_TYPE);
+
+ filter.filter(ctx);
+
+ verify(ctx, never()).abortWith(any());
+ }
+
+ @Test
+ void postWithoutBodyPasses() {
+ ContainerRequestContext ctx = request("POST", false, null);
+
+ filter.filter(ctx);
+
+ verify(ctx, never()).abortWith(any());
+ }
+
+ @ParameterizedTest
+ @ValueSource(strings = {
+ "application/json",
+ "application/x-www-form-urlencoded",
+ "multipart/form-data",
+ "multipart/form-data; boundary=xyz",
+ "application/json; charset=UTF-8",
+ "APPLICATION/JSON"})
+ void postAllowedMediaTypePasses(String contentType) {
+ ContainerRequestContext ctx = request("POST", true, MediaType.valueOf(contentType));
+
+ filter.filter(ctx);
+
+ verify(ctx, never()).abortWith(any());
+ }
+
+ @ParameterizedTest
+ @ValueSource(strings = {"text/plain", "application/xml", "application/octet-stream"})
+ void postDisallowedMediaTypeRejected(String contentType) {
+ ContainerRequestContext ctx = request("POST", true, MediaType.valueOf(contentType));
+
+ filter.filter(ctx);
+
+ assertRejectedWith415(ctx);
+ }
+
+ @Test
+ void postWithoutContentTypeRejected() {
+ ContainerRequestContext ctx = request("POST", true, null);
+
+ filter.filter(ctx);
+
+ assertRejectedWith415(ctx);
+ }
+
+ @ParameterizedTest
+ @ValueSource(strings = {"PUT", "DELETE", "PATCH", "post"})
+ void stateChangingTextPlainRejected(String method) {
+ // Method matching is case-insensitive, so a lowercase "post" is inspected too.
+ ContainerRequestContext ctx = request(method, true, MediaType.TEXT_PLAIN_TYPE);
+
+ filter.filter(ctx);
+
+ assertRejectedWith415(ctx);
+ }
+
+ /** Creates a request context mock with the given method, entity presence and media type. */
+ private static ContainerRequestContext request(String method, boolean hasEntity,
+ MediaType mediaType) {
+ ContainerRequestContext ctx = mock(ContainerRequestContext.class);
+ when(ctx.getMethod()).thenReturn(method);
+ when(ctx.hasEntity()).thenReturn(hasEntity);
+ when(ctx.getMediaType()).thenReturn(mediaType);
+ return ctx;
+ }
+
+ /** Verifies the filter aborted exactly once with a 415 response. */
+ private static void assertRejectedWith415(ContainerRequestContext ctx) {
+ ArgumentCaptor<Response> captor = ArgumentCaptor.forClass(Response.class);
+ verify(ctx, times(1)).abortWith(captor.capture());
+ assertEquals(Response.Status.UNSUPPORTED_MEDIA_TYPE.getStatusCode(),
+ captor.getValue().getStatus());
+ }
+}
diff --git a/zeppelin-server/src/test/java/org/apache/zeppelin/server/CorsFilterTest.java b/zeppelin-server/src/test/java/org/apache/zeppelin/server/CorsFilterTest.java
index 0a6f1ed..1048c0b 100644
--- a/zeppelin-server/src/test/java/org/apache/zeppelin/server/CorsFilterTest.java
+++ b/zeppelin-server/src/test/java/org/apache/zeppelin/server/CorsFilterTest.java
@@ -17,15 +17,21 @@
package org.apache.zeppelin.server;
import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertNull;
+import static org.mockito.ArgumentMatchers.anyInt;
import static org.mockito.ArgumentMatchers.anyString;
+import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.never;
+import static org.mockito.Mockito.times;
+import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
import org.apache.zeppelin.conf.ZeppelinConfiguration;
import org.junit.jupiter.api.Test;
-import org.mockito.invocation.InvocationOnMock;
-import org.mockito.stubbing.Answer;
+import org.junit.jupiter.params.ParameterizedTest;
+import org.junit.jupiter.params.provider.ValueSource;
import jakarta.servlet.FilterChain;
import jakarta.servlet.ServletException;
@@ -33,17 +39,16 @@
import jakarta.servlet.http.HttpServletResponse;
import java.io.IOException;
+import java.util.HashMap;
+import java.util.Map;
/**
* Basic CORS REST API tests.
*/
class CorsFilterTest {
- public static String[] headers = new String[8];
- public static Integer count = 0;
@Test
- @SuppressWarnings("rawtypes")
void validCorsFilterTest() throws IOException, ServletException {
CorsFilter filter = new CorsFilter(ZeppelinConfiguration.load());
HttpServletResponse mockResponse = mock(HttpServletResponse.class);
@@ -51,24 +56,14 @@
HttpServletRequest mockRequest = mock(HttpServletRequest.class);
when(mockRequest.getHeader("Origin")).thenReturn("http://localhost:8080");
when(mockRequest.getMethod()).thenReturn("Empty");
- when(mockRequest.getServerName()).thenReturn("localhost");
- count = 0;
-
- doAnswer(new Answer() {
- @Override
- public Object answer(InvocationOnMock invocationOnMock) throws Throwable {
- headers[count] = invocationOnMock.getArguments()[1].toString();
- count++;
- return null;
- }
- }).when(mockResponse).setHeader(anyString(), anyString());
+ Map<String, String> setHeaders = recordSetHeaders(mockResponse);
filter.doFilter(mockRequest, mockResponse, mockedFilterChain);
- assertEquals("http://localhost:8080", headers[0]);
+
+ assertEquals("http://localhost:8080", setHeaders.get("Access-Control-Allow-Origin"));
}
@Test
- @SuppressWarnings("rawtypes")
void invalidCorsFilterTest() throws IOException, ServletException {
CorsFilter filter = new CorsFilter(ZeppelinConfiguration.load());
HttpServletResponse mockResponse = mock(HttpServletResponse.class);
@@ -76,18 +71,118 @@
HttpServletRequest mockRequest = mock(HttpServletRequest.class);
when(mockRequest.getHeader("Origin")).thenReturn("http://evillocalhost:8080");
when(mockRequest.getMethod()).thenReturn("Empty");
- when(mockRequest.getServerName()).thenReturn("evillocalhost");
-
- doAnswer(new Answer() {
- @Override
- public Object answer(InvocationOnMock invocationOnMock) throws Throwable {
- headers[count] = invocationOnMock.getArguments()[1].toString();
- count++;
- return null;
- }
- }).when(mockResponse).setHeader(anyString(), anyString());
+ Map<String, String> setHeaders = recordSetHeaders(mockResponse);
filter.doFilter(mockRequest, mockResponse, mockedFilterChain);
- assertEquals("", headers[0]);
+
+ assertEquals("", setHeaders.get("Access-Control-Allow-Origin"));
+ }
+
+ @ParameterizedTest
+ @ValueSource(strings = {"POST", "PUT", "DELETE", "PATCH"})
+ void crossOriginStateChangingBlocked(String method) throws IOException, ServletException {
+ CorsFilter filter = new CorsFilter(ZeppelinConfiguration.load());
+ HttpServletRequest req = mock(HttpServletRequest.class);
+ when(req.getMethod()).thenReturn(method);
+ when(req.getHeader("Origin")).thenReturn("http://evil.example.com");
+ HttpServletResponse resp = mock(HttpServletResponse.class);
+ FilterChain chain = mock(FilterChain.class);
+
+ filter.doFilter(req, resp, chain);
+
+ // The filter answers 403 itself and never hands the request down the chain.
+ verify(resp).sendError(eq(HttpServletResponse.SC_FORBIDDEN), anyString());
+ verify(chain, never()).doFilter(req, resp);
+ }
+
+ @Test
+ void crossOriginPreflightBlocked() throws IOException, ServletException {
+ CorsFilter filter = new CorsFilter(ZeppelinConfiguration.load());
+ // A genuine preflight: OPTIONS plus Access-Control-Request-Method, from a foreign origin.
+ HttpServletRequest req = mock(HttpServletRequest.class);
+ when(req.getMethod()).thenReturn("OPTIONS");
+ when(req.getHeader("Origin")).thenReturn("http://evil.example.com");
+ when(req.getHeader("Access-Control-Request-Method")).thenReturn("POST");
+ HttpServletResponse resp = mock(HttpServletResponse.class);
+ FilterChain chain = mock(FilterChain.class);
+
+ filter.doFilter(req, resp, chain);
+
+ verify(resp).sendError(eq(HttpServletResponse.SC_FORBIDDEN), anyString());
+ verify(chain, never()).doFilter(req, resp);
+ }
+
+ @Test
+ void allowedOriginPostPasses() throws IOException, ServletException {
+ CorsFilter filter = new CorsFilter(ZeppelinConfiguration.load());
+ // An Origin whose host is "localhost" is accepted even with an empty allow-list.
+ HttpServletRequest req = mock(HttpServletRequest.class);
+ when(req.getMethod()).thenReturn("POST");
+ when(req.getHeader("Origin")).thenReturn("http://localhost");
+ HttpServletResponse resp = mock(HttpServletResponse.class);
+ FilterChain chain = mock(FilterChain.class);
+ Map<String, String> headers = recordSetHeaders(resp);
+
+ filter.doFilter(req, resp, chain);
+
+ verify(resp, never()).sendError(anyInt(), anyString());
+ verify(chain, times(1)).doFilter(req, resp);
+ assertEquals("http://localhost", headers.get("Access-Control-Allow-Origin"));
+ assertEquals("true", headers.get("Access-Control-Allow-Credentials"));
+ }
+
+ @Test
+ void disallowedOriginGetPasses() throws IOException, ServletException {
+ CorsFilter filter = new CorsFilter(ZeppelinConfiguration.load());
+ // GET is not state-changing: it is forwarded, but no origin or credentials are granted.
+ HttpServletRequest req = mock(HttpServletRequest.class);
+ when(req.getMethod()).thenReturn("GET");
+ when(req.getHeader("Origin")).thenReturn("http://evil.example.com");
+ HttpServletResponse resp = mock(HttpServletResponse.class);
+ FilterChain chain = mock(FilterChain.class);
+ Map<String, String> headers = recordSetHeaders(resp);
+
+ filter.doFilter(req, resp, chain);
+
+ verify(resp, never()).sendError(anyInt(), anyString());
+ verify(chain, times(1)).doFilter(req, resp);
+ assertEquals("", headers.get("Access-Control-Allow-Origin"));
+ assertNull(headers.get("Access-Control-Allow-Credentials"));
+ }
+
+ @Test
+ void noOriginPostPasses() throws IOException, ServletException {
+ CorsFilter filter = new CorsFilter(ZeppelinConfiguration.load());
+ // Without an Origin header the request is not subject to the origin check at all.
+ HttpServletRequest req = mock(HttpServletRequest.class);
+ when(req.getMethod()).thenReturn("POST");
+ when(req.getHeader("Origin")).thenReturn(null);
+ HttpServletResponse resp = mock(HttpServletResponse.class);
+ FilterChain chain = mock(FilterChain.class);
+
+ filter.doFilter(req, resp, chain);
+
+ verify(resp, never()).sendError(anyInt(), anyString());
+ verify(chain, times(1)).doFilter(req, resp);
+ }
+
+ @Test
+ void simpleOptionsWithoutPreflightHeaderPasses() throws IOException, ServletException {
+ CorsFilter filter = new CorsFilter(ZeppelinConfiguration.load());
+ // OPTIONS without Access-Control-Request-Method is not a preflight, so it is forwarded.
+ HttpServletRequest req = mock(HttpServletRequest.class);
+ when(req.getMethod()).thenReturn("OPTIONS");
+ when(req.getHeader("Origin")).thenReturn("http://evil.example.com");
+ when(req.getHeader("Access-Control-Request-Method")).thenReturn(null);
+ HttpServletResponse resp = mock(HttpServletResponse.class);
+ FilterChain chain = mock(FilterChain.class);
+
+ filter.doFilter(req, resp, chain);
+
+ verify(resp, never()).sendError(anyInt(), anyString());
+ verify(chain, times(1)).doFilter(req, resp);
+ }
+
+ /**
+ * Captures every {@code setHeader(name, value)} call on the mock response into a map; a later
+ * call for the same name overwrites the earlier value.
+ */
+ private static Map<String, String> recordSetHeaders(HttpServletResponse response) {
+ Map<String, String> captured = new HashMap<>();
+ doAnswer(call -> {
+ String name = call.getArgument(0);
+ String value = call.getArgument(1);
+ captured.put(name, value);
+ return null;
+ }).when(response).setHeader(anyString(), anyString());
+ return captured;
}
}
diff --git a/zeppelin-server/src/test/java/org/apache/zeppelin/service/NotebookServiceTest.java b/zeppelin-server/src/test/java/org/apache/zeppelin/service/NotebookServiceTest.java
index 152d085..0a176ac 100644
--- a/zeppelin-server/src/test/java/org/apache/zeppelin/service/NotebookServiceTest.java
+++ b/zeppelin-server/src/test/java/org/apache/zeppelin/service/NotebookServiceTest.java
@@ -623,7 +623,7 @@
assertEquals("Note name can not contain '..'", e.getMessage());
}
try {
- notebookService.normalizeNotePath("%252525252e%252525252e/tmp/test444");
+ notebookService.normalizeNotePath("%25252525252e%25252525252e/tmp/test444");
fail("Should fail");
} catch (IOException e) {
assertEquals("Exceeded maximum decode attempts. Possible malicious input.", e.getMessage());