/*
 * Decompiled with CFR 0.152.
 */
package org.sourceid.openid.connect.rp;

import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;
import org.apache.commons.collections.CollectionUtils;
import org.apache.commons.lang.StringUtils;
import org.jose4j.base64url.Base64Url;
import org.jose4j.jwt.JwtClaims;
import org.jose4j.jwt.MalformedClaimException;
import org.jose4j.jwt.consumer.InvalidJwtException;
import org.jose4j.jwt.consumer.JwtConsumer;
import org.jose4j.jwt.consumer.JwtConsumerBuilder;
import org.sourceid.common.HashAlgorithm;
import org.sourceid.common.HashUtil;
import org.sourceid.common.IDGenerator;
import org.sourceid.common.SpaceDelimitedStringUtil;
import org.sourceid.oauth20.client.AuthorizationResponseException;
import org.sourceid.oauth20.protocol.Parameters;
import org.sourceid.oauth20.protocol.ResponseType;
import org.sourceid.openid.connect.domain.OIDCProviderConnection;
import org.sourceid.openid.connect.rp.OIDCAuthenticationResponse;
import org.sourceid.saml20.domain.OIDCRequestParamSetting;
import org.sourceid.saml20.domain.OIDCRequestParams;
import org.sourceid.websso.wrapper.InMessageContext;
import org.sourceid.websso.wrapper.OutMessageContext;

public class OIDCProtocolAuthProcessor {
    /**
     * Parameters that are rejected in an authentication response when the caller indicates that
     * no access token is expected in it.
     */
    private static final List<String> ACCESS_TOKEN_RESPONSE_PARAMS =
            Collections.unmodifiableList(Arrays.asList("access_token", "token_type", "expires_in"));

    private final OIDCProviderConnection providerConnection;

    /**
     * Creates a processor bound to a single OpenID Provider connection.
     *
     * @param oidcProviderConnection provider configuration used to build authentication requests
     *     and to validate authentication responses
     */
    public OIDCProtocolAuthProcessor(OIDCProviderConnection oidcProviderConnection) {
        this.providerConnection = oidcProviderConnection;
    }

    /**
     * Builds an authorization-code authentication request without any custom request parameters.
     *
     * @param ssoRedirectUri value sent as {@code redirect_uri}
     * @param state value sent as {@code state} and stored as the relay state
     * @param nonce value sent as {@code nonce}
     * @return the outbound message context describing the request
     */
    public OutMessageContext makeAuthnRequest(String ssoRedirectUri, String state, String nonce) {
        return this.makeAuthnRequest(ssoRedirectUri, state, nonce,
                Collections.<OIDCRequestParamSetting>emptyList(), null);
    }

    /**
     * Builds an authorization-code authentication request ({@code response_type=code},
     * {@code response_mode=form_post}).
     *
     * <p>The endpoint is the pushed authorization request endpoint when one is configured,
     * otherwise the authorization endpoint. When PKCE is enabled on the connection a S256 code
     * challenge is added. Custom parameters are then merged via
     * {@link #populateCustomRequestParam(Collection, Map, Map)}.
     *
     * @param ssoRedirectUri value sent as {@code redirect_uri}
     * @param state value sent as {@code state} and stored as the relay state
     * @param nonce value sent as {@code nonce}
     * @param requestParamSettings configured custom request parameters; may be empty
     * @param incomingRequestParams parameters of the incoming request that may override configured
     *     values; may be {@code null}
     * @return the outbound message context describing the request
     */
    public OutMessageContext makeAuthnRequest(String ssoRedirectUri, String state, String nonce,
            Collection<OIDCRequestParamSetting> requestParamSettings, Map<String, String[]> incomingRequestParams) {
        Map<String, Object> params = new HashMap<>();
        params.put(Parameters.SCOPE, this.providerConnection.getSpaceSeparatedScope());
        params.put("response_type", "code");
        this.handleAcrValues(params);
        this.populateCommonRequestParam(ssoRedirectUri, state, nonce, params);

        OutMessageContext outContext = new OutMessageContext();
        // Prefer the pushed authorization request endpoint whenever one is configured.
        String parEndpoint = this.providerConnection.getPushedAuthorizationRequestEndpoint();
        if (StringUtils.isEmpty(parEndpoint)) {
            outContext.setEndpoint(this.providerConnection.getAuthorizationEndpoint());
        } else {
            outContext.setEndpoint(parEndpoint);
        }

        this.populateCodeChallenge(params, this.providerConnection.isEnableProofKeyForCodeExchange(), outContext);
        params = this.populateCustomRequestParam(requestParamSettings, params, incomingRequestParams);
        outContext.setParams(params);
        outContext.setRelayState(state);
        outContext.setEntityId(this.providerConnection.getIssuer());
        return outContext;
    }

    /**
     * Adds a PKCE (RFC 7636) S256 code challenge to the request parameters when enabled.
     *
     * <p>The generated {@code code_verifier} is not sent in the request. It is stored on the
     * message context as supplemental context under {@code "code_verifier"}.
     *
     * @param params request parameters to add {@code code_challenge}/{@code code_challenge_method} to
     * @param pKCEEnabled whether PKCE is enabled; when {@code false} nothing is changed
     * @param outMessageContext context receiving the generated verifier
     */
    protected void populateCodeChallenge(Map<String, Object> params, boolean pKCEEnabled, OutMessageContext outMessageContext) {
        if (!pKCEEnabled) {
            return;
        }
        String codeVerifier = Base64Url.encode(IDGenerator.generateBytes(32));
        // S256: code_challenge = BASE64URL(SHA-256(code_verifier)).
        byte[] hashedChallenge = HashUtil.hashToBytes(codeVerifier, HashAlgorithm.SHA256);
        params.put("code_challenge", Base64Url.encode(hashedChallenge));
        params.put("code_challenge_method", "S256");
        outMessageContext.setSupplementalContext("code_verifier", codeVerifier);
    }

    /** Adds the configured ACR values, space-joined, as {@code acr_values} when any are configured. */
    private void handleAcrValues(Map<String, Object> params) {
        if (this.hasConfiguredAcrValues()) {
            params.put("acr_values", String.join(" ", this.providerConnection.getAcrValues()));
        }
    }

    /** Returns {@code true} when the connection has a non-null, non-empty set of ACR values. */
    private boolean hasConfiguredAcrValues() {
        return this.providerConnection.getAcrValues() != null && !this.providerConnection.getAcrValues().isEmpty();
    }

    /**
     * Adds the parameters every authentication request carries: {@code client_id},
     * {@code redirect_uri}, {@code state}, {@code response_mode=form_post} and {@code nonce}.
     *
     * @param ssoRedirectUri value for {@code redirect_uri}
     * @param state value for {@code state}
     * @param nonce value for {@code nonce}
     * @param params map the parameters are written into
     */
    protected void populateCommonRequestParam(String ssoRedirectUri, String state, String nonce, Map<String, Object> params) {
        params.put(Parameters.CLIENT_ID, this.providerConnection.getClientId());
        params.put("redirect_uri", ssoRedirectUri);
        params.put("state", state);
        params.put("response_mode", ResponseType.ResponseMode.form_post.name());
        params.put("nonce", nonce);
    }

    /**
     * Merges configured custom request parameters into the existing request parameters.
     *
     * <ol>
     *   <li>If there are no settings, {@code existingParamMap} is returned unchanged (same instance).</li>
     *   <li>Each setting with a non-empty configured value first removes any existing parameter of
     *       the same name.</li>
     *   <li>If the setting allows override and the incoming request carries the parameter, or the
     *       configured value is empty, the incoming values are used, with blank entries dropped.
     *       Otherwise the configured value is added.</li>
     * </ol>
     *
     * <p>NOTE(review): when override applies but every incoming value is blank, neither the incoming
     * nor the configured value ends up in the result. Confirm this is intended.
     *
     * @param requestParamSettings configured custom parameters; may be {@code null} or empty
     * @param existingParamMap parameters already built for the request
     * @param incomingParamMap parameters of the incoming request; {@code null} is treated as empty
     * @return the merged parameters
     */
    public Map<String, Object> populateCustomRequestParam(Collection<OIDCRequestParamSetting> requestParamSettings,
            Map<String, Object> existingParamMap, Map<String, String[]> incomingParamMap) {
        if (CollectionUtils.isEmpty(requestParamSettings)) {
            return existingParamMap;
        }
        OIDCRequestParams existingParams = new OIDCRequestParams(existingParamMap);
        // Treat a missing incoming map as "no incoming parameters"; isOverridable() and
        // overrideExistingParameter() already tolerate that case.
        Map<String, String[]> incomingSource =
                incomingParamMap != null ? incomingParamMap : Collections.<String, String[]>emptyMap();
        OIDCRequestParams incomingParams = new OIDCRequestParams(incomingSource.entrySet().stream()
                .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue)));

        // Pass 1: a non-empty configured value replaces any default already present.
        for (OIDCRequestParamSetting paramSetting : requestParamSettings) {
            String name = paramSetting.getName();
            String value = paramSetting.getValue().getValue();
            if (existingParams.getParamMap().containsKey(name) && StringUtils.isNotEmpty(value)) {
                existingParams.getParamMap().remove(name);
            }
        }

        // Pass 2: take incoming values where allowed (or nothing else is configured), else the configured value.
        for (OIDCRequestParamSetting paramSetting : requestParamSettings) {
            String paramName = paramSetting.getName();
            String paramValue = paramSetting.getValue().getValue();
            if (this.isOverridable(incomingParams.getParamMap(), paramSetting) || StringUtils.isEmpty(paramValue)) {
                this.overrideExistingParameter(paramSetting, existingParams, incomingParams.getParamMap());
            } else {
                existingParams.getParamMap().computeIfAbsent(paramName, key -> new HashSet<>()).add(paramValue);
            }
        }
        return existingParams.convertToStringObjectMap();
    }

    /** Returns {@code true} when the setting allows override and the incoming request carries the parameter. */
    private boolean isOverridable(Map<String, Set<String>> incomingParams, OIDCRequestParamSetting paramSetting) {
        return incomingParams != null
                && incomingParams.containsKey(paramSetting.getName())
                && paramSetting.isOverride();
    }

    /**
     * Copies the non-empty incoming values for the setting's parameter into {@code existingParams}.
     *
     * <p>Blank entries are removed from the incoming value set in place. Nothing is added when the
     * incoming map is {@code null}, lacks the parameter, or holds only blank values.
     */
    private void overrideExistingParameter(OIDCRequestParamSetting requestParamSetting, OIDCRequestParams existingParams, Map<String, Set<String>> incomingParams) {
        if (incomingParams == null) {
            return;
        }
        String paramName = requestParamSetting.getName();
        Set<String> incomingParamValues = incomingParams.get(paramName);
        if (incomingParamValues == null) {
            return;
        }
        incomingParamValues.removeIf(StringUtils::isEmpty);
        if (!incomingParamValues.isEmpty()) {
            existingParams.getParamMap().computeIfAbsent(paramName, key -> new HashSet<>()).addAll(incomingParamValues);
        }
    }

    /**
     * Validates an authentication response. The ID token is read from the {@code id_token}
     * parameter and the state from {@link #getState(InMessageContext)}. An access token may be present.
     *
     * @param inContext the inbound message
     * @return the validated response
     * @throws AuthorizationResponseException if the response carries an error or fails validation
     */
    public OIDCAuthenticationResponse validateAuthnResponse(InMessageContext inContext) throws AuthorizationResponseException {
        String idToken = inContext.getParam("id_token");
        String state = this.getState(inContext);
        return this.getOidcAuthenticationResponse(inContext, idToken, state, true);
    }

    /** Returns the state of the response; by default the relay state of the inbound message. */
    protected String getState(InMessageContext inContext) {
        return inContext.getRelayState();
    }

    /**
     * Checks an authentication response and its ID token, in this order:
     * <ol>
     *   <li>protocol {@code error};</li>
     *   <li>unexpected access-token parameters (only when {@code isAccessTokenExpected} is false);</li>
     *   <li>missing state;</li>
     *   <li>missing ID token;</li>
     *   <li>the ID token claims (see {@link #getJwtClaims(String)});</li>
     *   <li>a single audience;</li>
     *   <li>an ACR match when ACR values are configured.</li>
     * </ol>
     *
     * <p>NOTE(review): the {@code nonce} claim is not compared to the request's nonce here. Confirm
     * that callers do this with the returned claims.
     *
     * @param inContext the inbound message
     * @param idToken the serialized ID token, possibly {@code null}
     * @param state the response state, possibly {@code null}
     * @param isAccessTokenExpected whether {@code access_token}/{@code token_type}/{@code expires_in} may be present
     * @return the validated response
     * @throws AuthorizationResponseException on any of the failures above
     */
    protected OIDCAuthenticationResponse getOidcAuthenticationResponse(InMessageContext inContext, String idToken, String state, boolean isAccessTokenExpected) throws AuthorizationResponseException {
        String error = inContext.getParam("error");
        String errorDescription = inContext.getParam("error_description");
        String errorUri = inContext.getParam("error_uri");
        if (error != null) {
            throw new AuthorizationResponseException(error, errorDescription, errorUri);
        }
        if (!isAccessTokenExpected) {
            for (String paramName : ACCESS_TOKEN_RESPONSE_PARAMS) {
                if (inContext.getParam(paramName) != null) {
                    throw new AuthorizationResponseException("Unexpected parameter " + paramName + " in authentication response");
                }
            }
        }
        if (state == null) {
            throw new AuthorizationResponseException("Missing 'state' parameter in authentication response");
        }
        if (idToken == null) {
            throw new AuthorizationResponseException("Missing ID token in authentication response");
        }
        try {
            JwtClaims jwtClaims = this.getJwtClaims(idToken);
            List<String> audience = jwtClaims.getAudience();
            if (audience.size() > 1) {
                throw new AuthorizationResponseException("ID token contains more than one audience: " + StringUtils.join(audience, ","));
            }
            if (this.hasConfiguredAcrValues() && !this.hasAcrValues(jwtClaims)) {
                throw new AuthorizationResponseException("ID token ACR value: '" + jwtClaims.getStringClaimValue("acr") + "' does not match any of the request ACR values.");
            }
            return new OIDCAuthenticationResponse(state, jwtClaims);
        } catch (MalformedClaimException | InvalidJwtException e) {
            throw new AuthorizationResponseException("Error validating ID token: " + e, e);
        }
    }

    /**
     * Parses the ID token and validates its claims.
     *
     * <p>Requires {@code exp} and {@code sub}, expects the connection's issuer and its client id
     * as audience, and allows the configured clock skew.
     *
     * <p>NOTE(review): the signature is deliberately NOT verified here
     * ({@code setSkipSignatureVerification()}). Confirm that signature validation happens
     * elsewhere, e.g. in a subclass override, or is unnecessary for the way the token is obtained.
     *
     * @param idToken the serialized ID token
     * @return the token's claims
     * @throws InvalidJwtException if the token cannot be processed or fails claim validation
     */
    protected JwtClaims getJwtClaims(String idToken) throws InvalidJwtException {
        JwtConsumer jwtConsumer = new JwtConsumerBuilder()
                .setRequireExpirationTime()
                .setAllowedClockSkewInSeconds(this.providerConnection.getAllowedClockSkewSecs())
                .setRequireSubject()
                .setExpectedIssuer(this.providerConnection.getIssuer())
                .setExpectedAudience(this.providerConnection.getClientId())
                .setSkipSignatureVerification()
                .build();
        return jwtConsumer.processToClaims(idToken);
    }

    /**
     * Returns {@code true} when any configured ACR value appears in the token's space-delimited
     * {@code acr} claim.
     */
    private boolean hasAcrValues(JwtClaims jwtClaims) {
        if (!jwtClaims.hasClaim("acr")) {
            return false;
        }
        List<?> tokenAcrs = SpaceDelimitedStringUtil.fromString(jwtClaims.getClaimValueAsString("acr"));
        if (tokenAcrs == null) {
            return false;
        }
        for (String acr : this.providerConnection.getAcrValues()) {
            if (tokenAcrs.contains(acr)) {
                return true;
            }
        }
        return false;
    }
}

