SLING-12999 - Invalid refresh tokens are not cleared (#35)
Always clear the stored access and refresh tokens when refreshing the token fails.
diff --git a/src/main/java/org/apache/sling/auth/oauth_client/impl/OAuthTokenRefresherImpl.java b/src/main/java/org/apache/sling/auth/oauth_client/impl/OAuthTokenRefresherImpl.java
index 2d61849..0f9798c 100644
--- a/src/main/java/org/apache/sling/auth/oauth_client/impl/OAuthTokenRefresherImpl.java
+++ b/src/main/java/org/apache/sling/auth/oauth_client/impl/OAuthTokenRefresherImpl.java
@@ -27,6 +27,7 @@
import com.nimbusds.oauth2.sdk.RefreshTokenGrant;
import com.nimbusds.oauth2.sdk.TokenErrorResponse;
import com.nimbusds.oauth2.sdk.TokenRequest;
+import com.nimbusds.oauth2.sdk.TokenResponse;
import com.nimbusds.oauth2.sdk.auth.ClientAuthentication;
import com.nimbusds.oauth2.sdk.auth.ClientSecretBasic;
import com.nimbusds.oauth2.sdk.auth.Secret;
@@ -41,11 +42,12 @@
public class OAuthTokenRefresherImpl implements OAuthTokenRefresher {
@Override
- public @NotNull OAuthTokens refreshTokens(@NotNull ClientConnection connection, @NotNull String refreshToken) {
+ public @NotNull OAuthTokens refreshTokens(@NotNull ClientConnection connection, @NotNull String refreshToken)
+ throws OAuthException {
return Converter.toSlingOAuthTokens(refreshTokensInternal(connection, refreshToken));
}
- private static @NotNull Tokens refreshTokensInternal(
+ private @NotNull Tokens refreshTokensInternal(
@NotNull ClientConnection connection, @NotNull String refreshTokenString) throws OAuthException {
try {
// Construct the grant from the saved refresh token
@@ -65,14 +67,12 @@
// Make the token request
TokenRequest request = new TokenRequest.Builder(tokenEndpoint, clientAuth, refreshTokenGrant).build();
- AccessTokenResponse response =
- AccessTokenResponse.parse(request.toHTTPRequest().send());
+ TokenResponse response = TokenResponse.parse(request.toHTTPRequest().send());
if (!response.indicatesSuccess()) {
- // We got an error response...
TokenErrorResponse errorResponse = response.toErrorResponse();
- throw new OAuthException("Failed refreshing the access token "
- + errorResponse.getErrorObject().getCode() + " : "
+ throw new OAuthException("Failed refreshing the access token. Code: "
+ + errorResponse.getErrorObject().getCode() + ", description: "
+ errorResponse.getErrorObject().getDescription());
}
diff --git a/src/main/java/org/apache/sling/auth/oauth_client/impl/TokenAccessImpl.java b/src/main/java/org/apache/sling/auth/oauth_client/impl/TokenAccessImpl.java
index 80b29ce..3fb2562 100644
--- a/src/main/java/org/apache/sling/auth/oauth_client/impl/TokenAccessImpl.java
+++ b/src/main/java/org/apache/sling/auth/oauth_client/impl/TokenAccessImpl.java
@@ -81,15 +81,26 @@
request.getUserPrincipal());
}
- OAuthTokens newTokens = tokenRefresher.refreshTokens(connection, refreshToken.getValue());
- if (newTokens.refreshToken() == null) {
- // retain old refresh token if none was returned
- newTokens =
- new OAuthTokens(newTokens.accessToken(), newTokens.expiresAt(), refreshToken.getValue());
+ OAuthTokens newTokens;
+ try {
+ newTokens = tokenRefresher.refreshTokens(connection, refreshToken.getValue());
+ if (newTokens.refreshToken() == null) {
+ // retain old refresh token if none was returned but call was successful
+ newTokens = new OAuthTokens(
+ newTokens.accessToken(), newTokens.expiresAt(), refreshToken.getValue());
+ }
+ } catch (OAuthException e) {
+ logger.warn(
+ "Failed to refresh access token for connection {} and user {}. Clearing all tokens",
+ connection.name(),
+ request.getRemoteUser(),
+ e);
+ newTokens = new OAuthTokens(null, 0, null);
}
tokenStore.persistTokens(connection, resolver, newTokens);
- return new OAuthTokenResponse(Optional.of(newTokens.accessToken()), connection, request, redirectPath);
+ return new OAuthTokenResponse(
+ Optional.ofNullable(newTokens.accessToken()), connection, request, redirectPath);
}
}
diff --git a/src/test/java/org/apache/sling/auth/oauth_client/TokenAccessImplTest.java b/src/test/java/org/apache/sling/auth/oauth_client/TokenAccessImplTest.java
index 8beb4bf..5549ebc 100644
--- a/src/test/java/org/apache/sling/auth/oauth_client/TokenAccessImplTest.java
+++ b/src/test/java/org/apache/sling/auth/oauth_client/TokenAccessImplTest.java
@@ -19,10 +19,13 @@
package org.apache.sling.auth.oauth_client;
import org.apache.sling.auth.oauth_client.impl.MockOidcConnection;
+import org.apache.sling.auth.oauth_client.impl.OAuthException;
+import org.apache.sling.auth.oauth_client.impl.OAuthToken;
import org.apache.sling.auth.oauth_client.impl.OAuthTokenRefresher;
import org.apache.sling.auth.oauth_client.impl.OAuthTokenStore;
import org.apache.sling.auth.oauth_client.impl.OAuthTokens;
import org.apache.sling.auth.oauth_client.impl.TokenAccessImpl;
+import org.apache.sling.auth.oauth_client.impl.TokenState;
import org.apache.sling.testing.mock.sling.junit5.SlingContext;
import org.apache.sling.testing.mock.sling.junit5.SlingContextExtension;
import org.jetbrains.annotations.NotNull;
@@ -87,7 +90,7 @@
OAuthTokenStore tokenStore = new InMemoryOAuthTokenStore();
- TokenAccessImpl tokenAccess = getTokenAccess(expiredTokens, refreshedTokens, tokenStore);
+ TokenAccessImpl tokenAccess = getTokenAccess(expiredTokens.refreshToken(), refreshedTokens, tokenStore);
tokenStore.persistTokens(MockOidcConnection.DEFAULT_CONNECTION, slingContext.resourceResolver(), expiredTokens);
@@ -108,13 +111,14 @@
}
private static @NotNull TokenAccessImpl getTokenAccess(
- OAuthTokens expiredTokens, OAuthTokens refreshedTokens, OAuthTokenStore tokenStore) {
+ String expectedRefreshToken, OAuthTokens refreshedTokens, OAuthTokenStore tokenStore) {
OAuthTokenRefresher tokenRefresher = new OAuthTokenRefresher() {
@Override
public @NotNull OAuthTokens refreshTokens(
@NotNull ClientConnection connection, @NotNull String refreshToken) {
- if (!refreshToken.equals(expiredTokens.refreshToken())) {
- throw new IllegalArgumentException("Invalid refresh token");
+ if (!refreshToken.equals(expectedRefreshToken)) {
+ throw new OAuthException("Invalid refresh token. Expected '" + expectedRefreshToken + "' but got '"
+ + refreshToken + "'");
}
return refreshedTokens;
}
@@ -169,4 +173,36 @@
assertThat(tokenStore.allTokens()).as("all persisted tokens").isEmpty();
}
+
+ @Test
+ void refreshToken_invalid() {
+
+ OAuthTokens expiredTokens = new OAuthTokens("access", -1, "refresh");
+ OAuthTokens refreshedTokens = new OAuthTokens("access2", 0, null);
+
+ OAuthTokenStore tokenStore = new InMemoryOAuthTokenStore();
+
+ // ensure that the refresher fails when an unexpected refresh token is used
+ TokenAccessImpl tokenAccess =
+ getTokenAccess("NOT " + expiredTokens.refreshToken(), refreshedTokens, tokenStore);
+
+ tokenStore.persistTokens(MockOidcConnection.DEFAULT_CONNECTION, slingContext.resourceResolver(), expiredTokens);
+
+ OAuthTokenResponse tokenResponse =
+ tokenAccess.getAccessToken(MockOidcConnection.DEFAULT_CONNECTION, slingContext.request(), "/");
+
+ assertThat(tokenResponse).as("tokenResponse").isNotNull().satisfies(tr -> {
+ assertThat(tr.hasValidToken()).as("hasValidToken").isFalse();
+ });
+
+ assertThat(tokenStore.getRefreshToken(MockOidcConnection.DEFAULT_CONNECTION, slingContext.resourceResolver()))
+ .as("refresh token after failed refresh")
+ .extracting(OAuthToken::getState)
+ .isEqualTo(TokenState.MISSING);
+
+ assertThat(tokenStore.getAccessToken(MockOidcConnection.DEFAULT_CONNECTION, slingContext.resourceResolver()))
+ .as("acess token after failed refresh")
+ .extracting(OAuthToken::getState)
+ .isEqualTo(TokenState.MISSING);
+ }
}