/*
 * Decompiled with CFR 0.152.
 */
package com.pingidentity.crypto.jwk;

import com.pingidentity.crypto.jwk.GroupRpcJsonWebKeysAdder;
import com.pingidentity.crypto.jwk.JwkAdditionalKeysState;
import com.pingidentity.crypto.jwk.JwkKeyPair;
import com.pingidentity.crypto.jwk.JwkLifecycleState;
import com.pingidentity.crypto.jwk.JwkUtils;
import com.pingidentity.crypto.jwk.JwkWrapper;
import com.pingidentity.crypto.jwk.JwksContent;
import edu.umd.cs.findbugs.annotations.SuppressFBWarnings;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.ListIterator;
import java.util.Map;
import java.util.concurrent.ScheduledExecutorService;
import java.util.stream.Collectors;
import org.jose4j.jwk.JsonWebKeySet;
import org.jose4j.jwk.PublicJsonWebKey;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.sourceid.config.ConfigStore;
import org.sourceid.config.ConfigStoreFarm;
import org.sourceid.mgmt.JwkDecryptionKeyManager;
import org.sourceid.saml20.domain.mgmt.MgmtFactory;
import org.sourceid.saml20.state.StateAccepter;

@SuppressFBWarnings(value={"UC_USELESS_OBJECT", "SE_TRANSIENT_FIELD_NOT_RESTORED", "IS2_INCONSISTENT_SYNC"})
/**
 * Holds the JSON Web Keys published by this node.
 * <ul>
 *   <li>Signing keys (use {@code "sig"}) are kept in {@code keysById}.</li>
 *   <li>Encryption keys (use {@code "enc"}) are kept in {@code encryptionKeysById}.</li>
 *   <li>Optional additional per-issuer signing keys are held in a {@link JwkAdditionalKeysState}.</li>
 * </ul>
 * Each key moves through the lifecycle CREATED, then ACTIVE, then RETIRED every time
 * {@link #addKeys(JwksContent)} runs. A key that is already RETIRED at that point is dropped.
 * <p>
 * All public methods are {@code synchronized} on this instance. The executor and decryption-key
 * manager are {@code transient}, so a deserialized instance has them as {@code null}. Such an
 * instance is only suitable as the source for {@link #setState(StateAccepter)}.
 */
public class JwkState
implements GroupRpcJsonWebKeysAdder.JwksRpcTarget,
StateAccepter,
Serializable {
    private static final long serialVersionUID = 1L;
    private static final Logger log = LoggerFactory.getLogger(JwkState.class);
    private static final ConfigStore config = ConfigStoreFarm.getConfig("jwks-endpoint-configuration");
    /**
     * Maximum number of non-CREATED encryption keys handed to the decryption key manager.
     * Computed as config "max-keysets-retained" (default 5) multiplied by 4.
     * NOTE(review): the factor of 4 presumably reflects keys per key set; confirm.
     */
    private static final int MAX_KEY_RETAIN_SIZE = config.getIntValue("max-keysets-retained", 5) * 4;
    /** When true (the default), an extra RS256-alg copy of qualifying RSA signing keys is published. */
    private static final boolean ADD_DUPLICATE_RS256_KEY = config.getBooleanValue("add-duplicate-rs256-alg-key", true);
    private Map<String, JwkWrapper> keysById = new LinkedHashMap<String, JwkWrapper>();
    private Map<String, JwkWrapper> encryptionKeysById = new LinkedHashMap<String, JwkWrapper>();
    private transient JwkAdditionalKeysState jwkAdditionalKeysState = null;
    /** Lazily built, unmodifiable sorted view of keysById; null means "needs rebuilding". */
    private List<JwkWrapper> sortedKeys = null;
    /** Lazily built, unmodifiable sorted view of encryptionKeysById; null means "needs rebuilding". */
    private List<JwkWrapper> sortedEncryptionKeys = null;
    private long lastUpdate;
    /** Only reported by toString() in this class. */
    private transient double rollPeriodDays;
    public static final String KEYS = "keys";
    public static final String LAST_MODIFIED = "last_modified";
    private final transient ScheduledExecutorService singleThreadedExecutor;
    private final transient JwkDecryptionKeyManager jwkDecryptionKeyManager;

    /** Creates an empty state wired to the executor and decryption-key manager from {@link MgmtFactory}. */
    public JwkState() {
        this(MgmtFactory.getSingleThreadedExecutor(), MgmtFactory.getJwkDecryptionKeyManager());
    }

    /**
     * Creates an empty state.
     *
     * @param singleThreadedExecutor executor on which decryption keys are handed to the manager
     * @param jwkDecryptionKeyManager receives the current ACTIVE/RETIRED encryption keys after each addKeys
     */
    public JwkState(ScheduledExecutorService singleThreadedExecutor, JwkDecryptionKeyManager jwkDecryptionKeyManager) {
        this.jwkDecryptionKeyManager = jwkDecryptionKeyManager;
        this.singleThreadedExecutor = singleThreadedExecutor;
    }

    /**
     * Rolls the existing keys and merges in the keys carried by {@code jwksContent}.
     * <ol>
     *   <li>Existing keys advance one lifecycle step, and keys that were RETIRED are removed.</li>
     *   <li>New signing keys are added. When enabled, an RS256-alg duplicate of qualifying RSA
     *       keys is added too.</li>
     *   <li>New encryption keys are added.</li>
     *   <li>The per-issuer additional keys are replaced.</li>
     *   <li>The current ACTIVE/RETIRED encryption keys are handed to the decryption key manager
     *       on the single-threaded executor.</li>
     * </ol>
     *
     * @param jwksContent the key material to merge; if it requests removal of all keys, existing state
     *                    is cleared before the roll
     */
    @Override
    public synchronized void addKeys(JwksContent jwksContent) {
        if (jwksContent.isRemoveAllKeys()) {
            this.removeAllKeys();
        }
        this.sortedKeys = null;
        this.sortedEncryptionKeys = null;

        this.updateKeyStates(this.keysById, "sig");
        for (JwkKeyPair keyPair : jwksContent.getKeys()) {
            boolean addRs256Duplicate = ADD_DUPLICATE_RS256_KEY && jwksContent.isDuplicateRsaWithAlg(keyPair.getState());
            JwkUtils.createJwkWrapper(keyPair, "sig").ifPresent(
                wrapper -> this.putNewKey(this.keysById, wrapper, "Added new key to JSON Web Key Set: {}"));
            if (addRs256Duplicate) {
                JwkUtils.createJwkWrapperForRS256SignatureAlgParam(keyPair, "sig").ifPresent(
                    wrapper -> this.putNewKey(this.keysById, wrapper, "Added new key to JSON Web Key Set: {}"));
            }
        }

        this.updateKeyStates(this.encryptionKeysById, "enc");
        if (jwksContent.getEncryptionKeys() != null) {
            for (JwkKeyPair keyPair : jwksContent.getEncryptionKeys()) {
                JwkUtils.createJwkWrapper(keyPair, "enc").ifPresent(
                    wrapper -> this.putNewKey(this.encryptionKeysById, wrapper, "Added new encryption key to JSON Web Key Set: {}"));
            }
        }

        if (this.jwkAdditionalKeysState == null) {
            this.jwkAdditionalKeysState = new JwkAdditionalKeysState(jwksContent.getAdditionalKeys());
        } else if (jwksContent.getAdditionalKeys() != null) {
            this.jwkAdditionalKeysState.replaceAdditionalKeySetsWrapper(jwksContent.getAdditionalKeys());
        }

        // Snapshot the key set under the lock; only the hand-off to the manager runs asynchronously.
        JsonWebKeySet keySet = this.getActiveRetiredEncryptionKeys();
        this.singleThreadedExecutor.execute(() -> this.jwkDecryptionKeyManager.writeDecryptionKeys(keySet));
        this.lastUpdate = System.currentTimeMillis();
    }

    /** Stores {@code wrapper} in {@code target} under its key ID and logs its JWK at INFO. */
    private void putNewKey(Map<String, JwkWrapper> target, JwkWrapper wrapper, String logMessage) {
        target.put(wrapper.getKeyId(), wrapper);
        log.info(logMessage, wrapper.getJwk());
    }

    /**
     * Builds a key set of encryption keys that are not in the CREATED state.
     * Keys are taken newest-first by insertion order, and at most {@code MAX_KEY_RETAIN_SIZE} are included.
     *
     * @return the selected keys, each carrying its private key
     */
    public synchronized JsonWebKeySet getActiveRetiredEncryptionKeys() {
        List<PublicJsonWebKey> selected = new ArrayList<PublicJsonWebKey>();
        // Walk backwards over the insertion-ordered map so the most recently added keys are kept.
        ListIterator<JwkWrapper> newestFirst = new ArrayList<JwkWrapper>(this.encryptionKeysById.values())
            .listIterator(this.encryptionKeysById.size());
        while (newestFirst.hasPrevious() && selected.size() < MAX_KEY_RETAIN_SIZE) {
            JwkWrapper wrapper = newestFirst.previous();
            if (wrapper.getLifecycleState() != JwkLifecycleState.CREATED) {
                selected.add(this.getKey(wrapper));
            }
        }
        return new JsonWebKeySet(selected);
    }

    /**
     * Returns the wrapper's JWK with the wrapper's private key attached.
     * NOTE(review): this sets the private key on the JWK instance held by the wrapper itself (no copy),
     * so the stored JWK carries the private key afterwards. Confirm that every publication path
     * serializes public parameters only.
     */
    private PublicJsonWebKey getKey(JwkWrapper wrapper) {
        PublicJsonWebKey key = (PublicJsonWebKey)wrapper.getJwk();
        key.setPrivateKey(wrapper.getPrivateKey());
        return key;
    }

    /**
     * Merges externally supplied key wrappers into this state.
     * A wrapper is stored when its key ID is unknown, or when its lifecycle state is later than the stored
     * one. Otherwise it is ignored. Wrappers whose use is {@code "sig"} go to the signing keys; all others
     * go to the encryption keys.
     *
     * @param wrappers the wrappers to merge
     */
    public synchronized void registerKeys(List<JwkWrapper> wrappers) {
        boolean changed = false;
        for (JwkWrapper wrapper : wrappers) {
            Map<String, JwkWrapper> keyMap = this.keysForUse(wrapper.getUse());
            JwkWrapper existingKey = keyMap.get(wrapper.getKeyId());
            if (existingKey == null) {
                log.debug("Adding key {} in state {}", wrapper.getKeyId(), wrapper.getLifecycleState());
            } else if (wrapper.getLifecycleState().compareTo(existingKey.getLifecycleState()) > 0) {
                log.debug("Updating state for key ID {}: {}", wrapper.getKeyId(), wrapper.getLifecycleState());
            } else {
                // Our copy is already at the same or a later lifecycle state.
                continue;
            }
            keyMap.put(wrapper.getKeyId(), wrapper);
            changed = true;
        }
        if (changed) {
            this.sortedKeys = null;
            this.sortedEncryptionKeys = null;
        }
    }

    /**
     * Merges the coordinator's keys into this state, then reports back the local keys the coordinator
     * is missing or holds in an older lifecycle state.
     *
     * @param coordKeys keys held by the coordinator
     * @return local signing and encryption keys for which {@code coordKeys} has no entry with the same
     *         key ID and a same-or-later lifecycle state
     */
    @Override
    public synchronized List<JwkWrapper> synchronizeKeys(List<JwkWrapper> coordKeys) {
        this.registerKeys(coordKeys);
        List<JwkWrapper> keysToReturn = new ArrayList<JwkWrapper>();
        for (JwkWrapper wrapper : this.keysById.values()) {
            if (!JwkState.isCurrentIn(coordKeys, wrapper)) {
                keysToReturn.add(wrapper);
            }
        }
        for (JwkWrapper wrapper : this.encryptionKeysById.values()) {
            if (!JwkState.isCurrentIn(coordKeys, wrapper)) {
                keysToReturn.add(wrapper);
            }
        }
        return keysToReturn;
    }

    /** True when {@code keys} holds an entry with the same key ID and a same-or-later lifecycle state. */
    private static boolean isCurrentIn(List<JwkWrapper> keys, JwkWrapper wrapper) {
        return keys.stream().anyMatch(other -> other.getKeyId().equals(wrapper.getKeyId())
            && other.getLifecycleState().compareTo(wrapper.getLifecycleState()) >= 0);
    }

    /** Drops all signing, encryption and additional keys and invalidates the cached sorted views. */
    public synchronized void removeAllKeys() {
        this.sortedKeys = null;
        this.sortedEncryptionKeys = null;
        this.keysById.clear();
        this.encryptionKeysById.clear();
        this.jwkAdditionalKeysState = null;
    }

    /** Advances every key in {@code keys} one lifecycle step and removes those that were already RETIRED. */
    private void updateKeyStates(Map<String, JwkWrapper> keys, String use) {
        // updateAndGetKeysToRemove returns a separate list, so removing from the map here is safe.
        for (JwkWrapper expired : JwkState.updateAndGetKeysToRemove(keys.values())) {
            this.removeKey(expired, use);
        }
    }

    /**
     * Advances each wrapper's lifecycle state in place: CREATED becomes ACTIVE, and ACTIVE becomes RETIRED.
     * Wrappers that are already RETIRED are left unchanged and collected for removal.
     *
     * @param currentKeys the wrappers to advance (mutated in place)
     * @return the wrappers that were RETIRED and should be removed
     * @throws IllegalStateException if a wrapper is in any other lifecycle state
     */
    static List<JwkWrapper> updateAndGetKeysToRemove(Collection<JwkWrapper> currentKeys) {
        List<JwkWrapper> keysToRemove = new ArrayList<JwkWrapper>();
        for (JwkWrapper wrapper : currentKeys) {
            switch (wrapper.getLifecycleState()) {
                case CREATED:
                    wrapper.setLifecycleState(JwkLifecycleState.ACTIVE);
                    break;
                case ACTIVE:
                    wrapper.setLifecycleState(JwkLifecycleState.RETIRED);
                    break;
                case RETIRED:
                    keysToRemove.add(wrapper);
                    break;
                default:
                    throw new IllegalStateException("Unexpected JWK lifecycle state: " + wrapper.getLifecycleState());
            }
        }
        return keysToRemove;
    }

    /** Returns the signing-key map for use {@code "sig"}, otherwise the encryption-key map (null-safe). */
    private Map<String, JwkWrapper> keysForUse(String use) {
        return "sig".equals(use) ? this.keysById : this.encryptionKeysById;
    }

    /** Removes {@code toRemove} from the map selected by {@code use} and logs the removal at INFO. */
    private void removeKey(JwkWrapper toRemove, String use) {
        this.keysForUse(use).remove(toRemove.getKeyId());
        log.info("Removed key from JSON Web Key Set: {}/{}", toRemove.getKeyId(), toRemove.getJwk().getKeyType());
    }

    /** Returns a new unmodifiable list holding {@code keys} in their natural order. */
    private static List<JwkWrapper> sortedView(Collection<? extends JwkWrapper> keys) {
        List<JwkWrapper> copy = new ArrayList<JwkWrapper>(keys);
        Collections.sort(copy);
        return Collections.unmodifiableList(copy);
    }

    /** Returns the signing keys in sorted order (cached, unmodifiable). */
    public synchronized List<JwkWrapper> getKeys() {
        if (this.sortedKeys == null) {
            this.sortedKeys = JwkState.sortedView(this.keysById.values());
        }
        return this.sortedKeys;
    }

    /**
     * Returns the additional signing keys configured for one OAuth issuer instance, sorted and unmodifiable.
     *
     * @param oAuthIssuerInstanceId the issuer instance ID
     * @return the sorted keys, or {@code null} when no additional-key state has been initialised
     */
    public synchronized List<JwkWrapper> getKeys(String oAuthIssuerInstanceId) {
        if (this.jwkAdditionalKeysState == null) {
            return null;
        }
        return JwkState.sortedView(this.jwkAdditionalKeysState.getKeyPairsForIssuer(oAuthIssuerInstanceId));
    }

    /** Returns the encryption keys in sorted order (cached, unmodifiable). */
    public synchronized List<JwkWrapper> getEncryptionKeys() {
        if (this.sortedEncryptionKeys == null) {
            this.sortedEncryptionKeys = JwkState.sortedView(this.encryptionKeysById.values());
        }
        return this.sortedEncryptionKeys;
    }

    /**
     * Returns the issuer-to-keys map of additional signing keys.
     * Only the outer map is wrapped as unmodifiable.
     *
     * @return the map, or {@code null} when there is no additional-key state or it is empty
     */
    public synchronized Map<String, Map<String, JwkWrapper>> getAdditionalSigningKeys() {
        if (this.jwkAdditionalKeysState == null || this.jwkAdditionalKeysState.getIssuerToJwkWrappers().isEmpty()) {
            return null;
        }
        return Collections.unmodifiableMap(this.jwkAdditionalKeysState.getIssuerToJwkWrappers());
    }

    /** Returns a new mutable list: the sorted signing keys followed by the sorted encryption keys. */
    public synchronized List<JwkWrapper> getAllKeys() {
        List<JwkWrapper> allKeys = new ArrayList<JwkWrapper>(this.getKeys());
        allKeys.addAll(this.getEncryptionKeys());
        return allKeys;
    }

    /** Returns the time in epoch milliseconds at which addKeys last completed, or 0 if it has not run. */
    public synchronized long getLastUpdate() {
        return this.lastUpdate;
    }

    /** Sets the roll period; in this class the value is only reported by toString(). */
    public synchronized void setRollPeriodDays(double rollPeriodDays) {
        this.rollPeriodDays = rollPeriodDays;
    }

    /** Looks up a signing key by key ID; returns null if absent. Encryption keys are not searched. */
    public synchronized JwkWrapper getKeyById(String kid) {
        return this.keysById.get(kid);
    }

    /**
     * Adopts the key state of another {@link JwkState}.
     * The maps and cached sorted views are shared by reference with {@code other}, not copied.
     * NOTE(review): the transient fields (additional keys, roll period) are not copied. Confirm this is
     * intended.
     *
     * @param other must be a JwkState; any other type causes a ClassCastException
     */
    @Override
    public synchronized void setState(StateAccepter other) {
        JwkState otherJwkState = (JwkState)other;
        this.keysById = otherJwkState.keysById;
        this.encryptionKeysById = otherJwkState.encryptionKeysById;
        this.sortedKeys = otherJwkState.sortedKeys;
        this.sortedEncryptionKeys = otherJwkState.sortedEncryptionKeys;
        this.lastUpdate = otherJwkState.lastUpdate;
    }

    @Override
    public synchronized String toString() {
        StringBuilder sb = new StringBuilder();
        sb.append("JwkState");
        sb.append("{rollPeriodDays=").append(this.rollPeriodDays);
        sb.append(", # of keys=").append(this.keysById.size());
        sb.append(", # of encryption keys=").append(this.encryptionKeysById.size());
        sb.append(", lastUpdate=").append(this.lastUpdate);
        sb.append('}');
        return sb.toString();
    }
}

