Spring OAuth2 Generate Access Token per request to the Token Endpoint

Following @Thanh Nguyen Van's approach:

I stumbled upon the same problem while developing my backend with Spring Boot and OAuth2. The problem I encountered was that, if multiple devices shared the same tokens, once one device refreshed the token, the other device would be clueless and, long story short, both devices entered a token-refresh frenzy. My solution was to replace the default AuthenticationKeyGenerator with a custom implementation that extends DefaultAuthenticationKeyGenerator and adds a new parameter, client_instance_id, to the key generator mixture. My mobile clients then send this parameter, which has to be unique across app installs (iOS or Android). This is not a special requirement, since most mobile apps already track the application instance in some form.

/**
 * Authentication key generator that separates tokens per client installation.
 * <p>
 * When the token request carries a non-empty {@value #PARAM_CLIENT_INSTANCE_ID} parameter, the key
 * computed by {@link DefaultAuthenticationKeyGenerator} is combined with that instance id and passed
 * through {@code generateKey} again, so each installation gets its own key. Without the parameter,
 * the default key is returned unchanged.
 */
public class EnhancedAuthenticationKeyGenerator extends DefaultAuthenticationKeyGenerator {

    /** Name of the token-request parameter that carries the per-installation identifier. */
    public static final String PARAM_CLIENT_INSTANCE_ID = "client_instance_id";

    private static final String KEY_SUPER_KEY = "super_key";
    private static final String KEY_CLIENT_INSTANCE_ID = PARAM_CLIENT_INSTANCE_ID;

    /**
     * Returns the default key, or, when a client instance id was supplied, a key derived from both
     * the default key and that instance id.
     *
     * @param authentication the authentication whose token key is needed
     * @return the token-store key for {@code authentication}
     */
    @Override
    public String extractKey(final OAuth2Authentication authentication) {
        final String defaultKey = super.extractKey(authentication);
        final String instanceId = findClientInstanceId(authentication.getOAuth2Request());

        // No instance id: behave exactly like the default generator.
        if (instanceId == null || instanceId.isEmpty()) {
            return defaultKey;
        }

        // LinkedHashMap keeps insertion order fixed, so the same inputs always produce the same key.
        final Map<String, String> keyParts = new LinkedHashMap<>(2);
        keyParts.put(KEY_SUPER_KEY, defaultKey);
        keyParts.put(KEY_CLIENT_INSTANCE_ID, instanceId);
        return generateKey(keyParts);
    }

    /** Reads the client instance id from the request parameters, or returns null if absent. */
    private static String findClientInstanceId(final OAuth2Request request) {
        final Map<String, String> params = request.getRequestParameters();
        return params == null ? null : params.get(PARAM_CLIENT_INSTANCE_ID);
    }

}

which you would then inject in a similar manner:

// JDBC-backed token store whose keys are derived per client installation.
final JdbcTokenStore tokenStore = new JdbcTokenStore(mDataSource);
final EnhancedAuthenticationKeyGenerator instanceAwareKeys = new EnhancedAuthenticationKeyGenerator();
tokenStore.setAuthenticationKeyGenerator(instanceAwareKeys);

The HTTP request would then look something like this:

POST /oauth/token HTTP/1.1
Host: {{host}}
Authorization: Basic {{auth_client_basic}}
Content-Type: application/x-www-form-urlencoded

grant_type=password&username={{username}}&password={{password}}&client_instance_id={{instance_id}}

The benefit of this approach is that, if the client doesn't send a client_instance_id, the default key is generated, and if an instance id is provided, the same key is returned every time for the same instance. Also, the key is platform independent. The downside is that the MD5 digest (used internally) is computed twice.


Updated on 21/11/2014

When I double-checked, I found that InMemoryTokenStore uses a hash string derived from the OAuth2Authentication as the key of several maps. When I use the same username, client_id, scope, etc., I get the same key, which may lead to problems. So I consider the old approach deprecated. The following is what I did to avoid the problem.

Create another AuthenticationKeyGenerator that calculates a unique key, called UniqueAuthenticationKeyGenerator:

/*
 * Copyright 2006-2011 the original author or authors.
 * 
 * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 * 
 * http://www.apache.org/licenses/LICENSE-2.0
 * 
 * Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
 * an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations under the License.
 */

/**
 * Basic key generator taking into account the client id, scope, resource ids and username (principal name) if they
 * exist.
 * 
 * @author Dave Syer
 * @author thanh
 */
public class UniqueAuthenticationKeyGenerator implements AuthenticationKeyGenerator {

    private static final String CLIENT_ID = "client_id";

    private static final String SCOPE = "scope";

    private static final String USERNAME = "username";

    private static final String UUID_KEY = "uuid";

    public String extractKey(OAuth2Authentication authentication) {
        Map<String, String> values = new LinkedHashMap<String, String>();
        OAuth2Request authorizationRequest = authentication.getOAuth2Request();
        if (!authentication.isClientOnly()) {
            values.put(USERNAME, authentication.getName());
        }
        values.put(CLIENT_ID, authorizationRequest.getClientId());
        if (authorizationRequest.getScope() != null) {
            values.put(SCOPE, OAuth2Utils.formatParameterList(authorizationRequest.getScope()));
        }
        Map<String, Serializable> extentions = authorizationRequest.getExtensions();
        String uuid = null;
        if (extentions == null) {
            extentions = new HashMap<String, Serializable>(1);
            uuid = UUID.randomUUID().toString();
            extentions.put(UUID_KEY, uuid);
        } else {
            uuid = (String) extentions.get(UUID_KEY);
            if (uuid == null) {
                uuid = UUID.randomUUID().toString();
                extentions.put(UUID_KEY, uuid);
            }
        }
        values.put(UUID_KEY, uuid);

        MessageDigest digest;
        try {
            digest = MessageDigest.getInstance("MD5");
        }
        catch (NoSuchAlgorithmException e) {
            throw new IllegalStateException("MD5 algorithm not available.  Fatal (should be in the JDK).");
        }

        try {
            byte[] bytes = digest.digest(values.toString().getBytes("UTF-8"));
            return String.format("%032x", new BigInteger(1, bytes));
        }
        catch (UnsupportedEncodingException e) {
            throw new IllegalStateException("UTF-8 encoding not available.  Fatal (should be in the JDK).");
        }
    }
}

Finally, wire it up:

<!--
  JDBC-backed token store that derives token keys with UniqueAuthenticationKeyGenerator
  instead of the default generator.
  JdbcTokenStore is constructed from a javax.sql.DataSource (compare
  `new JdbcTokenStore(mDataSource)`), so the constructor argument must reference a
  DataSource bean rather than a JdbcTemplate. Replace "dataSource" with your DataSource bean id.
-->
<bean id="tokenStore" class="org.springframework.security.oauth2.provider.token.store.JdbcTokenStore">
    <constructor-arg ref="dataSource" />
    <property name="authenticationKeyGenerator">
        <!-- Placeholder package: use the real package of UniqueAuthenticationKeyGenerator.
             ("package" is a reserved word in Java and cannot be a package segment.) -->
        <bean class="com.example.oauth2.UniqueAuthenticationKeyGenerator" />
    </property>
</bean>

The approach below may lead to problems; see the updated answer above!

You are using DefaultTokenServices. Try this code and make sure to re-define your `tokenServices`:

package com.thanh.backend.oauth2.core;

import java.util.Date;
import java.util.UUID;

import org.springframework.security.core.AuthenticationException;
import org.springframework.security.oauth2.common.DefaultExpiringOAuth2RefreshToken;
import org.springframework.security.oauth2.common.DefaultOAuth2AccessToken;
import org.springframework.security.oauth2.common.ExpiringOAuth2RefreshToken;
import org.springframework.security.oauth2.common.OAuth2AccessToken;
import org.springframework.security.oauth2.common.OAuth2RefreshToken;
import org.springframework.security.oauth2.provider.OAuth2Authentication;
import org.springframework.security.oauth2.provider.token.DefaultTokenServices;
import org.springframework.security.oauth2.provider.token.TokenEnhancer;
import org.springframework.security.oauth2.provider.token.TokenStore;

/**
 * Token services that issue fresh random access and refresh tokens on every request to the
 * token endpoint. They never look up an existing token in the store.
 *
 * @author thanh
 */
public class SimpleTokenService extends DefaultTokenServices {

    // Own references to the store and enhancer, captured by the overridden setters below
    // (presumably because the superclass does not expose its copies to subclasses).
    private TokenStore tokenStore;

    private TokenEnhancer accessTokenEnhancer;

    /**
     * Creates and stores a new access token, plus a refresh token when the client supports one.
     *
     * @param authentication the authenticated request
     * @return the newly created (and possibly enhanced) access token
     */
    @Override
    public OAuth2AccessToken createAccessToken(OAuth2Authentication authentication) throws AuthenticationException {
        OAuth2RefreshToken refreshToken = createRefreshToken(authentication);
        OAuth2AccessToken accessToken = createAccessToken(authentication, refreshToken);
        tokenStore.storeAccessToken(accessToken, authentication);
        // Store the refresh token actually attached to the (possibly enhanced) access token.
        // It is null when refresh tokens are not supported for this client; storing null would fail.
        OAuth2RefreshToken issuedRefreshToken = accessToken.getRefreshToken();
        if (issuedRefreshToken != null) {
            tokenStore.storeRefreshToken(issuedRefreshToken, authentication);
        }
        return accessToken;
    }

    /** Builds a random access token with the configured validity and scope, then applies the enhancer. */
    private OAuth2AccessToken createAccessToken(OAuth2Authentication authentication, OAuth2RefreshToken refreshToken) {
        DefaultOAuth2AccessToken token = new DefaultOAuth2AccessToken(UUID.randomUUID().toString());
        int validitySeconds = getAccessTokenValiditySeconds(authentication.getOAuth2Request());
        if (validitySeconds > 0) {
            token.setExpiration(new Date(System.currentTimeMillis() + (validitySeconds * 1000L)));
        }
        token.setRefreshToken(refreshToken);
        token.setScope(authentication.getOAuth2Request().getScope());
        return accessTokenEnhancer != null ? accessTokenEnhancer.enhance(token, authentication) : token;
    }

    /** Returns a new expiring refresh token, or null if refresh tokens are not supported for the request. */
    private ExpiringOAuth2RefreshToken createRefreshToken(OAuth2Authentication authentication) {
        if (!isSupportRefreshToken(authentication.getOAuth2Request())) {
            return null;
        }
        int validitySeconds = getRefreshTokenValiditySeconds(authentication.getOAuth2Request());
        return new DefaultExpiringOAuth2RefreshToken(UUID.randomUUID().toString(),
                new Date(System.currentTimeMillis() + (validitySeconds * 1000L)));
    }

    @Override
    public void setTokenEnhancer(TokenEnhancer accessTokenEnhancer) {
        super.setTokenEnhancer(accessTokenEnhancer);
        this.accessTokenEnhancer = accessTokenEnhancer;
    }

    @Override
    public void setTokenStore(TokenStore tokenStore) {
        super.setTokenStore(tokenStore);
        this.tokenStore = tokenStore;
    }
}