SLING-12999 - Invalid refresh tokens are not cleared (#35)

Always clear the stored access and refresh tokens when refreshing the token fails.
diff --git a/src/main/java/org/apache/sling/auth/oauth_client/impl/OAuthTokenRefresherImpl.java b/src/main/java/org/apache/sling/auth/oauth_client/impl/OAuthTokenRefresherImpl.java
index 2d61849..0f9798c 100644
--- a/src/main/java/org/apache/sling/auth/oauth_client/impl/OAuthTokenRefresherImpl.java
+++ b/src/main/java/org/apache/sling/auth/oauth_client/impl/OAuthTokenRefresherImpl.java
@@ -27,6 +27,7 @@
 import com.nimbusds.oauth2.sdk.RefreshTokenGrant;
 import com.nimbusds.oauth2.sdk.TokenErrorResponse;
 import com.nimbusds.oauth2.sdk.TokenRequest;
+import com.nimbusds.oauth2.sdk.TokenResponse;
 import com.nimbusds.oauth2.sdk.auth.ClientAuthentication;
 import com.nimbusds.oauth2.sdk.auth.ClientSecretBasic;
 import com.nimbusds.oauth2.sdk.auth.Secret;
@@ -41,11 +42,12 @@
 public class OAuthTokenRefresherImpl implements OAuthTokenRefresher {
 
     @Override
-    public @NotNull OAuthTokens refreshTokens(@NotNull ClientConnection connection, @NotNull String refreshToken) {
+    public @NotNull OAuthTokens refreshTokens(@NotNull ClientConnection connection, @NotNull String refreshToken)
+            throws OAuthException {
         return Converter.toSlingOAuthTokens(refreshTokensInternal(connection, refreshToken));
     }
 
-    private static @NotNull Tokens refreshTokensInternal(
+    private @NotNull Tokens refreshTokensInternal(
             @NotNull ClientConnection connection, @NotNull String refreshTokenString) throws OAuthException {
         try {
             // Construct the grant from the saved refresh token
@@ -65,14 +67,12 @@
             // Make the token request
             TokenRequest request = new TokenRequest.Builder(tokenEndpoint, clientAuth, refreshTokenGrant).build();
 
-            AccessTokenResponse response =
-                    AccessTokenResponse.parse(request.toHTTPRequest().send());
+            TokenResponse response = TokenResponse.parse(request.toHTTPRequest().send());
 
             if (!response.indicatesSuccess()) {
-                // We got an error response...
                 TokenErrorResponse errorResponse = response.toErrorResponse();
-                throw new OAuthException("Failed refreshing the access token "
-                        + errorResponse.getErrorObject().getCode() + " : "
+                throw new OAuthException("Failed refreshing the access token. Code: "
+                        + errorResponse.getErrorObject().getCode() + ", description: "
                         + errorResponse.getErrorObject().getDescription());
             }
 
diff --git a/src/main/java/org/apache/sling/auth/oauth_client/impl/TokenAccessImpl.java b/src/main/java/org/apache/sling/auth/oauth_client/impl/TokenAccessImpl.java
index 80b29ce..3fb2562 100644
--- a/src/main/java/org/apache/sling/auth/oauth_client/impl/TokenAccessImpl.java
+++ b/src/main/java/org/apache/sling/auth/oauth_client/impl/TokenAccessImpl.java
@@ -81,15 +81,26 @@
                             request.getUserPrincipal());
                 }
 
-                OAuthTokens newTokens = tokenRefresher.refreshTokens(connection, refreshToken.getValue());
-                if (newTokens.refreshToken() == null) {
-                    // retain old refresh token if none was returned
-                    newTokens =
-                            new OAuthTokens(newTokens.accessToken(), newTokens.expiresAt(), refreshToken.getValue());
+                OAuthTokens newTokens;
+                try {
+                    newTokens = tokenRefresher.refreshTokens(connection, refreshToken.getValue());
+                    if (newTokens.refreshToken() == null) {
+                        // retain old refresh token if none was returned but call was successful
+                        newTokens = new OAuthTokens(
+                                newTokens.accessToken(), newTokens.expiresAt(), refreshToken.getValue());
+                    }
+                } catch (OAuthException e) {
+                    logger.warn(
+                            "Failed to refresh access token for connection {} and user {}. Clearing all tokens",
+                            connection.name(),
+                            request.getRemoteUser(),
+                            e);
+                    newTokens = new OAuthTokens(null, 0, null);
                 }
                 tokenStore.persistTokens(connection, resolver, newTokens);
 
-                return new OAuthTokenResponse(Optional.of(newTokens.accessToken()), connection, request, redirectPath);
+                return new OAuthTokenResponse(
+                        Optional.ofNullable(newTokens.accessToken()), connection, request, redirectPath);
             }
         }
 
diff --git a/src/test/java/org/apache/sling/auth/oauth_client/TokenAccessImplTest.java b/src/test/java/org/apache/sling/auth/oauth_client/TokenAccessImplTest.java
index 8beb4bf..5549ebc 100644
--- a/src/test/java/org/apache/sling/auth/oauth_client/TokenAccessImplTest.java
+++ b/src/test/java/org/apache/sling/auth/oauth_client/TokenAccessImplTest.java
@@ -19,10 +19,13 @@
 package org.apache.sling.auth.oauth_client;
 
 import org.apache.sling.auth.oauth_client.impl.MockOidcConnection;
+import org.apache.sling.auth.oauth_client.impl.OAuthException;
+import org.apache.sling.auth.oauth_client.impl.OAuthToken;
 import org.apache.sling.auth.oauth_client.impl.OAuthTokenRefresher;
 import org.apache.sling.auth.oauth_client.impl.OAuthTokenStore;
 import org.apache.sling.auth.oauth_client.impl.OAuthTokens;
 import org.apache.sling.auth.oauth_client.impl.TokenAccessImpl;
+import org.apache.sling.auth.oauth_client.impl.TokenState;
 import org.apache.sling.testing.mock.sling.junit5.SlingContext;
 import org.apache.sling.testing.mock.sling.junit5.SlingContextExtension;
 import org.jetbrains.annotations.NotNull;
@@ -87,7 +90,7 @@
 
         OAuthTokenStore tokenStore = new InMemoryOAuthTokenStore();
 
-        TokenAccessImpl tokenAccess = getTokenAccess(expiredTokens, refreshedTokens, tokenStore);
+        TokenAccessImpl tokenAccess = getTokenAccess(expiredTokens.refreshToken(), refreshedTokens, tokenStore);
 
         tokenStore.persistTokens(MockOidcConnection.DEFAULT_CONNECTION, slingContext.resourceResolver(), expiredTokens);
 
@@ -108,13 +111,14 @@
     }
 
     private static @NotNull TokenAccessImpl getTokenAccess(
-            OAuthTokens expiredTokens, OAuthTokens refreshedTokens, OAuthTokenStore tokenStore) {
+            String expectedRefreshToken, OAuthTokens refreshedTokens, OAuthTokenStore tokenStore) {
         OAuthTokenRefresher tokenRefresher = new OAuthTokenRefresher() {
             @Override
             public @NotNull OAuthTokens refreshTokens(
                     @NotNull ClientConnection connection, @NotNull String refreshToken) {
-                if (!refreshToken.equals(expiredTokens.refreshToken())) {
-                    throw new IllegalArgumentException("Invalid refresh token");
+                if (!refreshToken.equals(expectedRefreshToken)) {
+                    throw new OAuthException("Invalid refresh token. Expected '" + expectedRefreshToken + "' but got '"
+                            + refreshToken + "'");
                 }
                 return refreshedTokens;
             }
@@ -169,4 +173,36 @@
 
         assertThat(tokenStore.allTokens()).as("all persisted tokens").isEmpty();
     }
+
+    @Test
+    void refreshToken_invalid() {
+
+        OAuthTokens expiredTokens = new OAuthTokens("access", -1, "refresh");
+        OAuthTokens refreshedTokens = new OAuthTokens("access2", 0, null);
+
+        OAuthTokenStore tokenStore = new InMemoryOAuthTokenStore();
+
+        // ensure that the refresher fails when an unexpected refresh token is used
+        TokenAccessImpl tokenAccess =
+                getTokenAccess("NOT " + expiredTokens.refreshToken(), refreshedTokens, tokenStore);
+
+        tokenStore.persistTokens(MockOidcConnection.DEFAULT_CONNECTION, slingContext.resourceResolver(), expiredTokens);
+
+        OAuthTokenResponse tokenResponse =
+                tokenAccess.getAccessToken(MockOidcConnection.DEFAULT_CONNECTION, slingContext.request(), "/");
+
+        assertThat(tokenResponse).as("tokenResponse").isNotNull().satisfies(tr -> {
+            assertThat(tr.hasValidToken()).as("hasValidToken").isFalse();
+        });
+
+        assertThat(tokenStore.getRefreshToken(MockOidcConnection.DEFAULT_CONNECTION, slingContext.resourceResolver()))
+                .as("refresh token after failed refresh")
+                .extracting(OAuthToken::getState)
+                .isEqualTo(TokenState.MISSING);
+
+        assertThat(tokenStore.getAccessToken(MockOidcConnection.DEFAULT_CONNECTION, slingContext.resourceResolver()))
+                .as("acess token after failed refresh")
+                .extracting(OAuthToken::getState)
+                .isEqualTo(TokenState.MISSING);
+    }
 }