/*
 * Decompiled with CFR 0.152.
 */
package org.sourceid.saml20.service.impl.grouprpc.adaptive;

import com.pingidentity.common.util.PropertyInfo;
import com.pingidentity.common.util.consistent.ConsistentHashRing;
import com.pingidentity.common.util.consistent.HashRingNode;
import com.pingidentity.common.util.consistent.Range;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.Vector;
import java.util.stream.Collectors;
import org.apache.commons.lang3.StringUtils;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.jgroups.Address;
import org.jgroups.util.RspList;
import org.sourceid.config.ConfigStore;
import org.sourceid.config.ConfigStoreFarm;
import org.sourceid.saml20.domain.mgmt.MgmtFactory;
import org.sourceid.saml20.service.DistributedStateCoordinator;
import org.sourceid.saml20.service.RangeRecords;
import org.sourceid.saml20.service.SessionRegistryException;
import org.sourceid.saml20.service.StateService;
import org.sourceid.saml20.service.StateServiceId;
import org.sourceid.saml20.service.StateServiceRegistry;
import org.sourceid.saml20.service.impl.grouprpc.BaseGroupRpc;
import org.sourceid.saml20.service.impl.grouprpc.RspUtil;
import org.sourceid.saml20.service.impl.grouprpc.adaptive.AddressNode;
import org.sourceid.saml20.service.impl.grouprpc.adaptive.RebalanceTask;
import org.sourceid.saml20.service.impl.grouprpc.adaptive.RebalanceWorker;
import org.sourceid.saml20.service.impl.localmemory.IdpSessionRegistryMapImpl;
import org.sourceid.saml20.service.impl.localmemory.InterReqStateMgmtMapImpl;
import org.sourceid.saml20.service.impl.localmemory.SpSessionRegistryMapImpl;
import org.sourceid.saml20.service.util.Node;
import org.sourceid.saml20.service.util.NodeIndexRegistryListener;
import org.sourceid.saml20.state.AdaptiveClusteringConfig;
import org.sourceid.saml20.state.StateMgmtFactory;
import org.sourceid.websso.servlet.SessionIdUtil;

public class DistributedStateCoordinatorImpl
extends BaseGroupRpc
implements DistributedStateCoordinator,
NodeIndexRegistryListener {
    private static final Log log = LogFactory.getLog(DistributedStateCoordinatorImpl.class);
    private static final String RETRIEVE_RECORDS_DELAY_MILLIS = "RetrieveRecordsDelayMillis";
    private static final String RPC_TIMEOUT_MILLIS = "RpcTimeoutMillis";
    private static final String REBALANCE_CHECK_INTERVAL_MILLIS = "RebalanceCheckIntervalMillis";
    private static final String CENTRAL_REBALANCE_LOCK = "CentralRebalanceLock";
    private static final String REPLICATE_CACHE_ACCESS = "replicateCacheAccess";
    private static final Class<?>[] REPLICATE_CACHE_ACCESS_SIG = new Class[]{String.class, String.class};
    private static final String IS_REBALANCE_IN_PROGRESS = "isRebalanceInProgress";
    private static final Class<?>[] IS_REBALANCE_IN_PROGRESS_SIG = new Class[0];
    private volatile ConsistentHashRing localHashRing = null;
    private volatile Map<String, ConsistentHashRing> nodeGroupHashRings = new HashMap<String, ConsistentHashRing>();
    private AdaptiveClusteringConfig configProps = StateMgmtFactory.getAdaptiveClusteringConfig();
    private ConfigStore configStore = ConfigStoreFarm.getConfig("org.sourceid.saml20.service.impl.grouprpc.DistributedStateCoordinatorImpl");
    protected RebalanceWorker rebalanceWorker;
    private Object retrievingRecordsLock = new Object();

    public DistributedStateCoordinatorImpl() {
        this(true);
    }

    public DistributedStateCoordinatorImpl(boolean initializeClustering) {
        super(initializeClustering);
    }

    @Override
    public void initialize() {
        this.rebalanceWorker = new RebalanceWorker(this, StateMgmtFactory.getStateServiceRegistry(), this.configProps, this.configStore);
        if (PropertyInfo.isAdaptiveClusteringEnabled()) {
            MgmtFactory.getNodeIndexRegistry().addListener(this);
            if (this.isDoRebalance()) {
                try {
                    this.acquireRebalanceLock();
                    this.rebalanceWorker.start();
                    MgmtFactory.getNodeIndexRegistry().setLocalNodeStateTracking(this.configProps.isStateTracking());
                    this.rebalanceWorker.waitForInitialRebalanceToComplete();
                }
                finally {
                    this.releaseRebalanceLock();
                }
            } else {
                this.onNodeIndexRegistryChanged();
            }
        }
    }

    @Override
    public Vector<Address> getReplicas(String partitionKey) {
        Collection<Object> replicas = Collections.emptyList();
        if (this.localHashRing != null) {
            replicas = this.localHashRing.getReplicaSet(partitionKey);
        }
        return this.replicasToAddresses(replicas);
    }

    @Override
    public Map<String, Vector<Address>> getNodeGroupReplicas(String partitionKey) {
        return this.getNodeGroupReplicas(partitionKey, false);
    }

    @Override
    public Vector<Address> getNodeGroupReplicas(String groupId, String partitionKey) {
        ConsistentHashRing ring = this.nodeGroupHashRings.get(groupId);
        if (ring != null) {
            return this.replicasToAddresses(ring.getReplicaSet(partitionKey));
        }
        return new Vector<Address>();
    }

    @Override
    public Vector<Address> getOtherNodeGroupReplicas(String partitionKey) {
        if (this.nodeGroupHashRings.size() <= 1) {
            return new Vector<Address>();
        }
        Vector<Address> result = new Vector<Address>();
        Map<String, Vector<Address>> nodeGroupReplicas = this.getNodeGroupReplicas(partitionKey, true);
        for (Map.Entry<String, Vector<Address>> entry : nodeGroupReplicas.entrySet()) {
            result.addAll((Collection<Address>)entry.getValue());
        }
        return result;
    }

    /*
     * WARNING - Removed try catching itself - possible behaviour change.
     */
    @Override
    public Collection<RangeRecords> retrieveStateRecords(StateServiceId serviceId, Collection<Range> ranges) {
        Object object = this.retrievingRecordsLock;
        synchronized (object) {
            StateService<?> service = this.getStateServiceRegistry().getService(serviceId);
            if (service == null) {
                throw new RuntimeException("Service " + serviceId + " not found");
            }
            log.debug((Object)("Retrieving records for service " + service.getServiceId() + (String)(log.isTraceEnabled() ? " and ranges " + ranges : "")));
            int totalRecordsRetrieved = 0;
            ArrayList<RangeRecords> result = new ArrayList<RangeRecords>();
            for (Range range : ranges) {
                if (!this.rebalanceWorker.hasRange(serviceId, range)) {
                    log.trace((Object)("Range " + range + " not found for service " + serviceId));
                    result.add(new RangeRecords(range, Collections.emptyList(), false));
                    continue;
                }
                Collection<?> records2 = service.getRecords(range.getStart(), range.getEnd());
                if (records2.size() > 0) {
                    log.trace((Object)("Returning " + records2.size() + " records for range " + range));
                }
                result.add(new RangeRecords(range, records2, true));
                if ((totalRecordsRetrieved += records2.size()) < this.getRebalanceBatchSize()) continue;
                break;
            }
            Collection retrievedRanges = result.stream().map(records -> records.getRange()).collect(Collectors.toList());
            this.doRetrieveDelay();
            log.debug((Object)("Retrieved " + totalRecordsRetrieved + " records for service " + service.getServiceId() + (String)(log.isTraceEnabled() ? " across ranges " + retrievedRanges : "")));
            return result;
        }
    }

    @Override
    public void onNodeIndexRegistryChanged() {
        List<Node> nodes = MgmtFactory.getNodeIndexRegistry().getNodes();
        this.updateNodeGroupHashRings(nodes);
        ConsistentHashRing newLocalRing = this.nodeGroupHashRings.get(this.getNodeGroupId());
        if (newLocalRing == null) {
            newLocalRing = this.ringFromNodes(Collections.emptyList());
        }
        Set newHashRingNodes = newLocalRing.getNodes();
        if (this.localHashRing == null || !newHashRingNodes.equals(this.localHashRing.getNodes())) {
            log.debug((Object)("State tracking nodes updated: " + newHashRingNodes));
            ConsistentHashRing oldRing = this.localHashRing;
            this.localHashRing = newLocalRing;
            if (this.isDoRebalance()) {
                Address localAddress = this.getLocalAddress();
                if (!newHashRingNodes.stream().anyMatch(node -> ((AddressNode)((Object)node)).getAddress().equals(localAddress))) {
                    log.debug((Object)"Not running rebalance task, as local node is not found or not yet enabled for state tracking");
                } else {
                    this.rebalanceWorker.submit(new RebalanceTask(oldRing, this.localHashRing, localAddress));
                }
            }
        }
    }

    @Override
    public void replicateCacheAccess(String primarySessionId, String extendedSessionId) {
        Vector<Address> nodes = this.getOtherNodeGroupReplicas(SessionIdUtil.getInstance().getSriFromPrimaryValue(primarySessionId));
        if (!nodes.isEmpty()) {
            this.callRemoteMethods(nodes, REPLICATE_CACHE_ACCESS, (Class[])REPLICATE_CACHE_ACCESS_SIG, false, this.getRpcTimeoutMillis(), primarySessionId, extendedSessionId);
        }
    }

    @Override
    public boolean isAnyRebalanceInProgress() {
        if (!PropertyInfo.isAdaptiveClusteringEnabled()) {
            return false;
        }
        String nodeGroupId = this.getNodeGroupId();
        Vector addresses = MgmtFactory.getNodeIndexRegistry().getNodes().stream().filter(node -> node.isStateTracking() && (StringUtils.isEmpty((CharSequence)nodeGroupId) || nodeGroupId.equals(node.getNodeGroupId()))).map(Node::getAddress).collect(Collectors.toCollection(Vector::new));
        if (addresses.isEmpty()) {
            return false;
        }
        RspList rspList = this.callRemoteMethods((Vector<Address>)addresses, IS_REBALANCE_IN_PROGRESS, (Class[])IS_REBALANCE_IN_PROGRESS_SIG, true, this.getRpcTimeoutMillis(), new Object[0]);
        return RspUtil.getAnyTrue(rspList);
    }

    @Override
    public boolean isRebalanceInProgress() {
        return this.rebalanceWorker.isRebalanceInProgress();
    }

    @Override
    public void prepareForShutdown() {
        if (!PropertyInfo.isAdaptiveClusteringEnabled() || !MgmtFactory.getNodeIndexRegistry().isLocalNodeStateTracking()) {
            log.debug((Object)"No runtime state rebalancing is required");
            return;
        }
        try {
            this.acquireRebalanceLock();
            log.info((Object)"Rebalancing runtime state to other nodes in the cluster");
            MgmtFactory.getNodeIndexRegistry().setLocalNodeStateTracking(false);
            boolean done = false;
            while (!done) {
                try {
                    Thread.sleep(this.configStore.getLongValue(REBALANCE_CHECK_INTERVAL_MILLIS, 3000L));
                }
                catch (InterruptedException interruptedException) {
                    // empty catch block
                }
                if (done = !this.isAnyRebalanceInProgress()) continue;
                log.debug((Object)"One or more nodes are still rebalancing, waiting for all rebalancing to complete ...");
            }
            log.info((Object)"Runtime state rebalancing completed");
        }
        finally {
            this.releaseRebalanceLock();
        }
    }

    protected Map<String, Vector<Address>> getNodeGroupReplicas(String partitionKey, boolean excludeLocal) {
        HashMap<String, Vector<Address>> result = new HashMap<String, Vector<Address>>();
        for (Map.Entry<String, ConsistentHashRing> entry : this.nodeGroupHashRings.entrySet()) {
            if (excludeLocal && entry.getKey().equals(this.getNodeGroupId())) continue;
            result.put(entry.getKey(), this.replicasToAddresses(entry.getValue().getReplicaSet(partitionKey)));
        }
        return result;
    }

    protected Vector<Address> replicasToAddresses(Collection<HashRingNode> replicas) {
        Vector<Address> result = new Vector<Address>();
        for (HashRingNode node : replicas) {
            AddressNode addressNode = (AddressNode)node;
            result.add(addressNode.getAddress());
        }
        return result;
    }

    protected String getNodeGroupId() {
        return this.configProps.getNodeGroupId();
    }

    protected int getRebalanceBatchSize() {
        return this.configProps.getRebalanceBatchSize();
    }

    protected StateServiceRegistry getStateServiceRegistry() {
        return StateMgmtFactory.getStateServiceRegistry();
    }

    protected ConsistentHashRing ringFromNodes(Collection<? extends HashRingNode> nodes) {
        return new ConsistentHashRing(nodes, this.configProps.getVirtualNodeCount(), this.configProps.getReplicationFactor(), this.configProps.getHashFunction());
    }

    private void acquireRebalanceLock() {
        if (!this.configProps.isRebalanceLockRequired()) {
            log.debug((Object)"Requiring lock to perform rebalancing is disabled. Moving on without obtaining a rebalancing lock.");
            return;
        }
        try {
            log.debug((Object)"Waiting to acquire central rebalance lock");
            StateMgmtFactory.getClusterLockService().getLock(CENTRAL_REBALANCE_LOCK);
            log.debug((Object)"Central rebalance lock acquired");
        }
        catch (InterruptedException e) {
            log.error((Object)"Unexpected error while waiting to acquire initial rebalance lock", (Throwable)e);
        }
    }

    private void releaseRebalanceLock() {
        if (!this.configProps.isRebalanceLockRequired()) {
            log.debug((Object)"Requiring lock to perform rebalancing is disabled. Moving on without requiring to release the central rebalancing lock.");
            return;
        }
        log.debug((Object)"Releasing central rebalance lock");
        StateMgmtFactory.getClusterLockService().releaseLock(CENTRAL_REBALANCE_LOCK);
    }

    private int getRpcTimeoutMillis() {
        return this.configStore.getIntValue(RPC_TIMEOUT_MILLIS, 1000);
    }

    private void doRetrieveDelay() {
        long retrieveRecordsDelayMillis = this.configStore.getLongValue(RETRIEVE_RECORDS_DELAY_MILLIS, 0L);
        if (retrieveRecordsDelayMillis > 0L) {
            log.debug((Object)("Sleeping " + retrieveRecordsDelayMillis + " milliseconds before returning records"));
            try {
                Thread.sleep(retrieveRecordsDelayMillis);
            }
            catch (InterruptedException interruptedException) {
                // empty catch block
            }
        }
    }

    private boolean isDoRebalance() {
        return this.configProps.isRebalanceEnabled() && this.configProps.isStateTracking();
    }

    private void updateNodeGroupHashRings(Collection<Node> newNodes) {
        HashMap<String, HashSet<AddressNode>> nodeGroups = new HashMap<String, HashSet<AddressNode>>();
        for (Node node : newNodes) {
            if (!node.isStateTracking()) continue;
            String nodeGroupId = node.getNodeGroupId();
            HashSet<AddressNode> ringNodes = (HashSet<AddressNode>)nodeGroups.get(nodeGroupId);
            if (ringNodes == null) {
                ringNodes = new HashSet<AddressNode>();
                nodeGroups.put(nodeGroupId, ringNodes);
            }
            ringNodes.add(new AddressNode(node.getAddress()));
        }
        HashMap<String, ConsistentHashRing> newRings = new HashMap<String, ConsistentHashRing>();
        for (Map.Entry nodeGroupEntry : nodeGroups.entrySet()) {
            newRings.put((String)nodeGroupEntry.getKey(), this.ringFromNodes((Collection)nodeGroupEntry.getValue()));
        }
        this.nodeGroupHashRings = newRings;
    }

    public static class DistributedStateCoordinatorRpcTarget
    implements RpcTarget {
        private InterReqStateMgmtMapImpl interReqStateMgmtMapImpl;
        private IdpSessionRegistryMapImpl idpSessionRegistryMapImpl;
        private SpSessionRegistryMapImpl spSessionRegistryMapImpl;

        public DistributedStateCoordinatorRpcTarget(InterReqStateMgmtMapImpl interReqStateMgmtMapImpl, IdpSessionRegistryMapImpl idpSessionRegistryMapImpl, SpSessionRegistryMapImpl spSessionRegistryMapImpl) {
            this.interReqStateMgmtMapImpl = interReqStateMgmtMapImpl;
            this.idpSessionRegistryMapImpl = idpSessionRegistryMapImpl;
            this.spSessionRegistryMapImpl = spSessionRegistryMapImpl;
        }

        @Override
        public Collection<RangeRecords> retrieveStateRecords(StateServiceId serviceId, Collection<Range> ranges) {
            return StateMgmtFactory.getDistributedStateCoordinator().retrieveStateRecords(serviceId, ranges);
        }

        @Override
        public void replicateCacheAccess(String primarySessionId, String extendedSessionId) {
            if (extendedSessionId != null) {
                this.interReqStateMgmtMapImpl.updateCacheAccessTime(extendedSessionId);
            }
            this.spSessionRegistryMapImpl.updateCacheAccessTime(primarySessionId);
            String sri = SessionIdUtil.getInstance().getSriFromPrimaryValue(primarySessionId);
            try {
                this.idpSessionRegistryMapImpl.updateCacheAccessTime(sri);
            }
            catch (SessionRegistryException e) {
                log.error((Object)("Error updating cache access time for SRI " + sri), (Throwable)e);
            }
        }

        @Override
        public boolean isRebalanceInProgress() {
            return StateMgmtFactory.getDistributedStateCoordinator().isRebalanceInProgress();
        }
    }

    public static interface RpcTarget {
        public Collection<RangeRecords> retrieveStateRecords(StateServiceId var1, Collection<Range> var2);

        public void replicateCacheAccess(String var1, String var2);

        public boolean isRebalanceInProgress();
    }
}

