package cc.kave.repackaged.jayes.inference.junctionTree;

import cc.kave.repackaged.jayes.BayesNet;
import cc.kave.repackaged.jayes.BayesNode;
import cc.kave.repackaged.jayes.factor.AbstractFactor;
import cc.kave.repackaged.jayes.factor.arraywrapper.DoubleArrayWrapper;
import cc.kave.repackaged.jayes.factor.arraywrapper.IArrayWrapper;
import cc.kave.repackaged.jayes.inference.AbstractInferer;
import cc.kave.repackaged.jayes.internal.util.ArrayUtils;
import cc.kave.repackaged.jayes.util.Graph;
import cc.kave.repackaged.jayes.util.MathUtils;
import cc.kave.repackaged.jayes.util.NumericalInstabilityException;
import cc.kave.repackaged.jayes.util.Pair;
import cc.kave.repackaged.jayes.util.sharing.CanonicalArrayWrapperManager;
import cc.kave.repackaged.jayes.util.sharing.CanonicalIntArrayManager;
import cc.kave.repackaged.jayes.util.triangulation.MinFillIn;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.IdentityHashMap;
import java.util.Iterator;
import java.util.List;
import java.util.ListIterator;
import java.util.Map;
import java.util.Set;

/* loaded from: input_file:lib/jayes-0.0.2.jar:cc/kave/repackaged/jayes/inference/junctionTree/JunctionTreeAlgorithm.class */
/**
 * Inference for a {@link BayesNet} using the junction tree algorithm.
 *
 * <p>{@link #setNetwork(BayesNet)} builds a junction tree and multiplies every node's factor into
 * the potential of a cluster that contains the node and all of its parents. It then runs one
 * collect/distribute pass and snapshots the resulting potential and sepset values.
 *
 * <p>Each later belief update does the following:
 * <ul>
 *   <li>restores that snapshot;</li>
 *   <li>enters the current evidence by selecting the observed outcome in every cluster that
 *       contains an observed node;</li>
 *   <li>propagates messages again, skipping subtrees that contain no evidence (collect phase) or
 *       no query factor of an unobserved node (distribute phase).</li>
 * </ul>
 * Individual node beliefs are computed lazily when requested.
 */
public class JunctionTreeAlgorithm extends AbstractInferer {
    /** Sepset factor for each junction tree edge, as returned by the junction tree builder. */
    protected Map<Graph.Edge, AbstractFactor> sepSets;
    /** Structure of the junction tree; vertices are cluster indices into {@link #nodePotentials}. */
    protected Graph junctionTree;
    /** Potential of each cluster, indexed by cluster index. */
    protected AbstractFactor[] nodePotentials;
    /**
     * Prepared index mapping between the potential at {@code edge.getSecond()} and the edge's
     * sepset, keyed by edge identity (not equality).
     */
    protected IdentityHashMap<Graph.Edge, int[]> preparedMultiplications;
    /** For each node id, the indices of all clusters whose potential has that node as dimension. */
    protected int[][] concernedClusters;
    /** For each node id, the concerned cluster potential with the fewest values. */
    protected AbstractFactor[] queryFactors;
    /** For each node id, prepared index mapping from its query factor to a one-variable factor. */
    protected int[][] preparedQueries;
    /** For each node id, whether {@code beliefs[id]} reflects the most recent propagation. */
    protected boolean[] isBeliefValid;
    /** Every potential and sepset paired with a snapshot of its values after initial propagation. */
    protected List<Pair<AbstractFactor, IArrayWrapper>> initializations;
    /** For each cluster index, the ids of the nodes whose query factor is that cluster. */
    protected int[][] queryFactorReverseMapping;
    /** Indices of the clusters that contain at least one observed node. */
    protected Set<Integer> clustersHavingEvidence;
    /** For each node id, whether the node is part of the evidence currently incorporated. */
    protected boolean[] isObserved;
    /** Reusable buffer large enough to hold the values of the largest sepset. */
    protected double[] scratchpad;
    /** Strategy used to build the junction tree; defaults to the min-fill-in heuristic. */
    protected JunctionTreeBuilder junctionTreeBuilder = JunctionTreeBuilder.forHeuristic(new MinFillIn());

    /**
     * Replaces the strategy used to build the junction tree. The builder is only consulted by
     * {@link #setNetwork(BayesNet)}, so this must be called before that to have an effect.
     *
     * @param junctionTreeBuilder the builder to use
     */
    public void setJunctionTreeBuilder(JunctionTreeBuilder junctionTreeBuilder) {
        this.junctionTreeBuilder = junctionTreeBuilder;
    }

    /**
     * Returns the belief (marginal distribution) of {@code bayesNode} given the current evidence.
     *
     * <p>Propagation is re-run only when the inferer's beliefs are marked invalid. The belief of
     * the individual node is then computed on demand from its query factor. An observed node gets
     * probability 1 on its observed outcome.
     *
     * @param bayesNode the node to query
     * @return the belief vector of the node
     * @throws NumericalInstabilityException if the node's marginal cannot be normalized
     */
    @Override // cc.kave.repackaged.jayes.inference.AbstractInferer, cc.kave.repackaged.jayes.inference.IBayesInferer
    public double[] getBeliefs(BayesNode bayesNode) {
        if (!this.beliefsValid) {
            this.beliefsValid = true;
            try {
                updateBeliefs();
            } catch (RuntimeException e) {
                // A failed propagation must not leave the inferer claiming its beliefs are valid.
                this.beliefsValid = false;
                throw e;
            }
        }
        int id = bayesNode.getId();
        if (!this.isBeliefValid[id]) {
            if (this.evidence.containsKey(bayesNode)) {
                Arrays.fill(this.beliefs[id], 0.0d);
                this.beliefs[id][bayesNode.getOutcomeIndex(this.evidence.get(bayesNode))] = 1.0d;
            } else {
                validateBelief(id);
            }
            // Mark valid only after success, so a failed computation is retried on the next call.
            this.isBeliefValid[id] = true;
        }
        return super.getBeliefs(bayesNode);
    }

    /**
     * Recomputes {@code beliefs[nodeId]} by summing the node's query factor down to that node,
     * converting log-scale results back with {@code exp}, and normalizing.
     *
     * @param nodeId id of an unobserved node
     * @throws NumericalInstabilityException if normalization fails
     */
    private void validateBelief(int nodeId) {
        AbstractFactor queryFactor = this.queryFactors[nodeId];
        queryFactor.sumPrepared(new DoubleArrayWrapper(this.beliefs[nodeId]), this.preparedQueries[nodeId]);
        if (queryFactor.isLogScale()) {
            MathUtils.exp(this.beliefs[nodeId]);
        }
        try {
            this.beliefs[nodeId] = MathUtils.normalize(this.beliefs[nodeId]);
        } catch (IllegalArgumentException e) {
            throw new NumericalInstabilityException("Numerical instability detected for evidence: " + this.evidence + " and node : " + nodeId + ", consider using logarithmic scale computation (configurable in FactorFactory)", e);
        }
    }

    /** Invalidates all per-node beliefs and re-propagates the current evidence. */
    @Override // cc.kave.repackaged.jayes.inference.AbstractInferer
    protected void updateBeliefs() {
        Arrays.fill(this.isBeliefValid, false);
        doUpdateBeliefs();
    }

    /**
     * Selects the evidence, restores the initial values, and runs a collect/distribute pass rooted
     * at a cluster touched by evidence, skipping subtrees that do not need messages.
     */
    private void doUpdateBeliefs() {
        incorporateAllEvidence();
        int root = findPropagationRoot();
        replayFactorInitializations();
        collectEvidence(root, skipCollection(root));
        distributeEvidence(root, skipDistribution(root));
    }

    /** Copies the snapshot taken by {@link #storePotentialValues()} back into every factor. */
    private void replayFactorInitializations() {
        for (Pair<AbstractFactor, IArrayWrapper> initialization : this.initializations) {
            initialization.getFirst().copyValues(initialization.getSecond());
        }
    }

    /**
     * Clears all previous selections and observation bookkeeping, then selects the observed
     * outcome of every evidence node in all clusters that contain it.
     */
    private void incorporateAllEvidence() {
        for (Pair<AbstractFactor, IArrayWrapper> initialization : this.initializations) {
            initialization.getFirst().resetSelections();
        }
        this.clustersHavingEvidence.clear();
        Arrays.fill(this.isObserved, false);
        for (BayesNode observed : this.evidence.keySet()) {
            incorporateEvidence(observed);
        }
    }

    /**
     * Marks {@code bayesNode} as observed and selects its observed outcome in every cluster that
     * contains it, recording those clusters as having evidence.
     */
    private void incorporateEvidence(BayesNode bayesNode) {
        int id = bayesNode.getId();
        this.isObserved[id] = true;
        // The observed outcome is the same for every cluster, so look it up once.
        int outcomeIndex = bayesNode.getOutcomeIndex(this.evidence.get(bayesNode));
        for (int cluster : this.concernedClusters[id]) {
            this.nodePotentials[cluster].select(id, outcomeIndex);
            this.clustersHavingEvidence.add(cluster);
        }
    }

    /**
     * Chooses the cluster at which propagation is rooted: the first concerned cluster of the last
     * evidence node in iteration order, or cluster 0 when there is no evidence.
     */
    private int findPropagationRoot() {
        int root = 0;
        for (BayesNode observed : this.evidence.keySet()) {
            root = this.concernedClusters[observed.getId()][0];
        }
        return root;
    }

    /**
     * Returns the clusters that need not send collect messages: clusters without evidence whose
     * whole subtree (away from {@code root}) is also without evidence. The set is used to seed
     * the visited set of {@link #collectEvidence(int, Set)}, so these subtrees are not traversed.
     */
    private Set<Integer> skipCollection(int root) {
        Set<Integer> skipped = new HashSet<Integer>(this.nodePotentials.length);
        recursiveSkipCollection(root, new HashSet<Integer>(this.nodePotentials.length), skipped);
        return skipped;
    }

    /**
     * Depth-first helper for {@link #skipCollection(int)}. Adds {@code cluster} to
     * {@code skipped} if all of its unvisited neighbours were skipped and it has no evidence.
     */
    private void recursiveSkipCollection(int cluster, Set<Integer> visited, Set<Integer> skipped) {
        visited.add(cluster);
        boolean allChildrenSkipped = true;
        for (Graph.Edge edge : this.junctionTree.getIncidentEdges(cluster)) {
            if (!visited.contains(edge.getSecond())) {
                recursiveSkipCollection(edge.getSecond().intValue(), visited, skipped);
                if (!skipped.contains(edge.getSecond())) {
                    allChildrenSkipped = false;
                }
            }
        }
        if (allChildrenSkipped && !this.clustersHavingEvidence.contains(cluster)) {
            skipped.add(cluster);
        }
    }

    /**
     * Returns the clusters that need not receive distribute messages: clusters that are not the
     * query factor of any unobserved node and whose whole subtree (away from {@code root}) is
     * likewise. The set seeds the visited set of {@link #distributeEvidence(int, Set)}.
     */
    private Set<Integer> skipDistribution(int root) {
        Set<Integer> skipped = new HashSet<Integer>(this.nodePotentials.length);
        recursiveSkipDistribution(root, new HashSet<Integer>(this.nodePotentials.length), skipped);
        return skipped;
    }

    /**
     * Depth-first helper for {@link #skipDistribution(int)}. Adds {@code cluster} to
     * {@code skipped} if all of its unvisited neighbours were skipped and it is not the query
     * factor of an unobserved node.
     */
    private void recursiveSkipDistribution(int cluster, Set<Integer> visited, Set<Integer> skipped) {
        visited.add(cluster);
        boolean allChildrenSkipped = true;
        for (Graph.Edge edge : this.junctionTree.getIncidentEdges(cluster)) {
            if (!visited.contains(edge.getSecond())) {
                recursiveSkipDistribution(edge.getSecond().intValue(), visited, skipped);
                if (!skipped.contains(edge.getSecond())) {
                    allChildrenSkipped = false;
                }
            }
        }
        if (allChildrenSkipped && !isQueryFactorOfUnobservedVariable(cluster)) {
            skipped.add(cluster);
        }
    }

    /** Returns whether {@code cluster} is the query factor of at least one unobserved node. */
    private boolean isQueryFactorOfUnobservedVariable(int cluster) {
        for (int node : this.queryFactorReverseMapping[cluster]) {
            if (!this.isObserved[node]) {
                return true;
            }
        }
        return false;
    }

    /**
     * Collect phase: recursively visits the subtree below {@code cluster} and passes a message from
     * every child back towards {@code cluster}. Clusters already in {@code visited} are not entered.
     */
    private void collectEvidence(int cluster, Set<Integer> visited) {
        visited.add(cluster);
        for (Graph.Edge edge : this.junctionTree.getIncidentEdges(cluster)) {
            if (!visited.contains(edge.getSecond())) {
                collectEvidence(edge.getSecond().intValue(), visited);
                messagePass(edge.getBackEdge());
            }
        }
    }

    /**
     * Distribute phase: passes a message from {@code cluster} to each unvisited neighbour and then
     * recurses into it. Clusters already in {@code visited} are not entered.
     */
    private void distributeEvidence(int cluster, Set<Integer> visited) {
        visited.add(cluster);
        for (Graph.Edge edge : this.junctionTree.getIncidentEdges(cluster)) {
            if (!visited.contains(edge.getSecond())) {
                messagePass(edge);
                distributeEvidence(edge.getSecond().intValue(), visited);
            }
        }
    }

    /**
     * Passes a message along {@code edge}, from the cluster at {@code edge.getFirst()} to the
     * cluster at {@code edge.getSecond()}. The steps are:
     * <ol>
     *   <li>The previous sepset values are saved in {@link #scratchpad}.</li>
     *   <li>The sending potential is summed into the sepset.</li>
     *   <li>The new sepset values are combined with the previous ones: subtraction when both ends
     *       are in log scale, division otherwise.</li>
     *   <li>The combined values are multiplied into the receiving potential.</li>
     * </ol>
     * Log-scale conversion is applied when exactly one end is in log scale. The pass is skipped
     * when every variable of the sepset is observed.
     */
    private void messagePass(Graph.Edge edge) {
        AbstractFactor sepSet = this.sepSets.get(edge);
        if (!needMessagePass(sepSet)) {
            return;
        }
        IArrayWrapper sepSetValues = sepSet.getValues();
        // Keep the old sepset values; the update below overwrites them in place.
        System.arraycopy(sepSetValues.toDoubleArray(), 0, this.scratchpad, 0, sepSetValues.length());
        this.nodePotentials[edge.getFirst().intValue()].sumPrepared(sepSetValues, this.preparedMultiplications.get(edge.getBackEdge()));
        if (isOnlyFirstLogScale(edge)) {
            MathUtils.exp(sepSetValues);
        }
        if (areBothEndsLogScale(edge)) {
            MathUtils.secureSubtract(sepSetValues.toDoubleArray(), this.scratchpad, this.scratchpad);
        } else {
            MathUtils.secureDivide(sepSetValues.toDoubleArray(), this.scratchpad, this.scratchpad);
        }
        if (isOnlySecondLogScale(edge)) {
            MathUtils.log(this.scratchpad);
        }
        this.nodePotentials[edge.getSecond().intValue()].multiplyPrepared(new DoubleArrayWrapper(this.scratchpad), this.preparedMultiplications.get(edge));
    }

    /** Returns whether {@code sepSet} has at least one unobserved dimension. */
    private boolean needMessagePass(AbstractFactor sepSet) {
        for (int node : sepSet.getDimensionIDs()) {
            if (!this.isObserved[node]) {
                return true;
            }
        }
        return false;
    }

    /** Returns whether the potential of {@code cluster} uses log scale. */
    private boolean isClusterLogScale(int cluster) {
        return this.nodePotentials[cluster].isLogScale();
    }

    /** Returns whether only the sending end of {@code edge} uses log scale. */
    private boolean isOnlyFirstLogScale(Graph.Edge edge) {
        return isClusterLogScale(edge.getFirst().intValue()) && !isClusterLogScale(edge.getSecond().intValue());
    }

    /** Returns whether only the receiving end of {@code edge} uses log scale. */
    private boolean isOnlySecondLogScale(Graph.Edge edge) {
        return !isClusterLogScale(edge.getFirst().intValue()) && isClusterLogScale(edge.getSecond().intValue());
    }

    /**
     * Builds the junction tree for {@code bayesNet}, initializes all potentials and sepsets,
     * prepares the index mappings, and runs the initial propagation. The resulting values are
     * stored so that later updates can start from them.
     *
     * @param bayesNet the network to run inference on
     */
    @Override // cc.kave.repackaged.jayes.inference.AbstractInferer, cc.kave.repackaged.jayes.inference.IBayesInferer
    public void setNetwork(BayesNet bayesNet) {
        super.setNetwork(bayesNet);
        initializeFields(bayesNet.getNodes().size());
        JunctionTree tree = buildJunctionTree(bayesNet);
        int[] homeClusters = computeHomeClusters(bayesNet, tree.getClusters());
        initializeClusterFactors(bayesNet, tree.getClusters(), homeClusters);
        initializeSepsetFactors(tree.getSepSets());
        determineConcernedClusters();
        setQueryFactors();
        initializePotentialValues();
        multiplyCPTsIntoPotentials(bayesNet, homeClusters);
        prepareMultiplications();
        prepareScratch();
        invokeInitialBeliefUpdate();
        storePotentialValues();
    }

    /** Fills {@link #concernedClusters}: for every node, the clusters having it as a dimension. */
    private void determineConcernedClusters() {
        int nodeCount = this.queryFactors.length;
        List<List<Integer>> clustersPerNode = new ArrayList<List<Integer>>(nodeCount);
        for (int node = 0; node < nodeCount; node++) {
            clustersPerNode.add(new ArrayList<Integer>());
        }
        for (int cluster = 0; cluster < this.nodePotentials.length; cluster++) {
            for (int node : this.nodePotentials[cluster].getDimensionIDs()) {
                clustersPerNode.get(node).add(cluster);
            }
        }
        this.concernedClusters = new int[nodeCount][];
        for (int node = 0; node < nodeCount; node++) {
            this.concernedClusters[node] = ArrayUtils.toIntArray(clustersPerNode.get(node));
        }
    }

    /**
     * Allocates the per-node and bookkeeping structures.
     *
     * @param nodeCount number of nodes in the network
     */
    private void initializeFields(int nodeCount) {
        // Java initializes boolean arrays to false, so no explicit fill is needed.
        this.isBeliefValid = new boolean[this.beliefs.length];
        this.queryFactors = new AbstractFactor[nodeCount];
        this.preparedQueries = new int[nodeCount][];
        this.sepSets = new HashMap<Graph.Edge, AbstractFactor>();
        this.preparedMultiplications = new IdentityHashMap<Graph.Edge, int[]>();
        this.initializations = new ArrayList<Pair<AbstractFactor, IArrayWrapper>>();
        this.clustersHavingEvidence = new HashSet<Integer>();
        this.isObserved = new boolean[nodeCount];
    }

    /** Builds the junction tree with {@link #junctionTreeBuilder} and stores its graph. */
    private JunctionTree buildJunctionTree(BayesNet bayesNet) {
        JunctionTree tree = this.junctionTreeBuilder.buildJunctionTree(bayesNet);
        this.junctionTree = tree.getGraph();
        return tree;
    }

    /**
     * Assigns each node a home cluster: the index of the first cluster that contains the node and
     * all of its parents.
     *
     * @param bayesNet the network
     * @param clusters the junction tree clusters as lists of node ids
     * @return home cluster index for each node id
     * @throws IllegalStateException if some node's family is not contained in any cluster
     */
    private int[] computeHomeClusters(BayesNet bayesNet, List<List<Integer>> clusters) {
        int[] homeClusters = new int[bayesNet.getNodes().size()];
        for (BayesNode bayesNode : bayesNet.getNodes()) {
            List<Integer> family = getNodeAndParentIds(bayesNode);
            int home = -1;
            int clusterIndex = 0;
            for (List<Integer> cluster : clusters) {
                if (cluster.containsAll(family)) {
                    home = clusterIndex;
                    break;
                }
                clusterIndex++;
            }
            if (home < 0) {
                // Silently defaulting to cluster 0 would multiply the CPT into an incompatible potential.
                throw new IllegalStateException("No junction tree cluster contains node " + bayesNode.getId() + " and all of its parents");
            }
            homeClusters[bayesNode.getId()] = home;
        }
        return homeClusters;
    }

    /** Returns the id of {@code bayesNode} followed by the ids of its parents. */
    private List<Integer> getNodeAndParentIds(BayesNode bayesNode) {
        List<Integer> family = new ArrayList<Integer>(bayesNode.getParents().size() + 1);
        family.add(bayesNode.getId());
        for (BayesNode parent : bayesNode.getParents()) {
            family.add(parent.getId());
        }
        return family;
    }

    /**
     * Creates one potential per cluster through the factory. Each potential receives the factors
     * of the nodes whose home cluster it is as multiplication partners.
     */
    private void initializeClusterFactors(BayesNet bayesNet, List<List<Integer>> clusters, int[] homeClusters) {
        this.nodePotentials = new AbstractFactor[clusters.size()];
        Map<Integer, List<AbstractFactor>> partnersByCluster = findMultiplicationPartners(bayesNet, homeClusters);
        int clusterIndex = 0;
        for (List<Integer> cluster : clusters) {
            List<AbstractFactor> partners = partnersByCluster.get(clusterIndex);
            if (partners == null) {
                partners = Collections.<AbstractFactor>emptyList();
            }
            this.nodePotentials[clusterIndex] = this.factory.create(cluster, partners);
            clusterIndex++;
        }
    }

    /** Groups node factors by home cluster index; clusters without any node are absent. */
    private Map<Integer, List<AbstractFactor>> findMultiplicationPartners(BayesNet bayesNet, int[] homeClusters) {
        Map<Integer, List<AbstractFactor>> partnersByCluster = new HashMap<Integer, List<AbstractFactor>>();
        for (BayesNode bayesNode : bayesNet.getNodes()) {
            Integer home = homeClusters[bayesNode.getId()];
            List<AbstractFactor> partners = partnersByCluster.get(home);
            if (partners == null) {
                partners = new ArrayList<AbstractFactor>();
                partnersByCluster.put(home, partners);
            }
            partners.add(bayesNode.getFactor());
        }
        return partnersByCluster;
    }

    /** Creates a factor over each sepset's variables, keyed by its junction tree edge. */
    private void initializeSepsetFactors(List<Pair<Graph.Edge, List<Integer>>> sepSetDefinitions) {
        for (Pair<Graph.Edge, List<Integer>> definition : sepSetDefinitions) {
            this.sepSets.put(definition.getFirst(), this.factory.create(definition.getSecond(), Collections.<AbstractFactor>emptyList()));
        }
    }

    /**
     * Chooses, for every node, the concerned cluster with the fewest values as its query factor.
     * Then records, for every cluster, which nodes use it as query factor.
     */
    private void setQueryFactors() {
        for (int node = 0; node < this.queryFactors.length; node++) {
            for (int cluster : this.concernedClusters[node]) {
                AbstractFactor candidate = this.nodePotentials[cluster];
                if (this.queryFactors[node] == null
                        || this.queryFactors[node].getValues().length() > candidate.getValues().length()) {
                    this.queryFactors[node] = candidate;
                }
            }
        }
        this.queryFactorReverseMapping = new int[this.nodePotentials.length][];
        for (int cluster = 0; cluster < this.nodePotentials.length; cluster++) {
            List<Integer> servedNodes = new ArrayList<Integer>();
            for (int node : this.nodePotentials[cluster].getDimensionIDs()) {
                if (this.queryFactors[node] == this.nodePotentials[cluster]) {
                    servedNodes.add(node);
                }
            }
            this.queryFactorReverseMapping[cluster] = ArrayUtils.toIntArray(servedNodes);
        }
    }

    /** Prepares all index mappings, sharing identical arrays through a canonicalizing manager. */
    private void prepareMultiplications() {
        CanonicalIntArrayManager arrayManager = new CanonicalIntArrayManager();
        prepareSepsetMultiplications(arrayManager);
        prepareQueries(arrayManager);
    }

    /** Prepares, for every directed edge, the mapping between its second cluster and its sepset. */
    private void prepareSepsetMultiplications(CanonicalIntArrayManager arrayManager) {
        for (int cluster = 0; cluster < this.nodePotentials.length; cluster++) {
            for (Graph.Edge edge : this.junctionTree.getIncidentEdges(cluster)) {
                AbstractFactor target = this.nodePotentials[edge.getSecond().intValue()];
                this.preparedMultiplications.put(edge, arrayManager.getInstance(target.prepareMultiplication(this.sepSets.get(edge))));
            }
        }
    }

    /** Prepares, for every node, the mapping from its query factor to a one-variable factor. */
    private void prepareQueries(CanonicalIntArrayManager arrayManager) {
        for (int node = 0; node < this.queryFactors.length; node++) {
            AbstractFactor singleVariable = this.factory.create(Arrays.asList(Integer.valueOf(node)), Collections.<AbstractFactor>emptyList());
            this.preparedQueries[node] = arrayManager.getInstance(this.queryFactors[node].prepareMultiplication(singleVariable));
        }
    }

    /** Allocates {@link #scratchpad} with the size of the largest sepset (0 if there are none). */
    private void prepareScratch() {
        int maxSepSetSize = 0;
        for (AbstractFactor sepSet : this.sepSets.values()) {
            maxSepSetSize = Math.max(maxSepSetSize, sepSet.getValues().length());
        }
        this.scratchpad = new double[maxSepSetSize];
    }

    /** Runs a full collect/distribute pass from cluster 0 without skipping any cluster. */
    private void invokeInitialBeliefUpdate() {
        collectEvidence(0, new HashSet<Integer>());
        distributeEvidence(0, new HashSet<Integer>());
    }

    /**
     * Fills every potential and sepset with the neutral element of multiplication in its scale:
     * 0 for log scale, 1 otherwise. A sepset counts as log scale only when both of its clusters are.
     */
    private void initializePotentialValues() {
        for (AbstractFactor potential : this.nodePotentials) {
            potential.fill(potential.isLogScale() ? 0.0d : 1.0d);
        }
        for (Map.Entry<Graph.Edge, AbstractFactor> entry : this.sepSets.entrySet()) {
            entry.getValue().fill(areBothEndsLogScale(entry.getKey()) ? 0.0d : 1.0d);
        }
    }

    /** Multiplies each node's factor into the potential of its home cluster. */
    private void multiplyCPTsIntoPotentials(BayesNet bayesNet, int[] homeClusters) {
        for (BayesNode bayesNode : bayesNet.getNodes()) {
            AbstractFactor potential = this.nodePotentials[homeClusters[bayesNode.getId()]];
            if (potential.isLogScale()) {
                potential.multiplyCompatibleToLog(bayesNode.getFactor());
            } else {
                potential.multiplyCompatible(bayesNode.getFactor());
            }
        }
    }

    /** Returns whether both clusters joined by {@code edge} use log scale. */
    private boolean areBothEndsLogScale(Graph.Edge edge) {
        return isClusterLogScale(edge.getFirst().intValue()) && isClusterLogScale(edge.getSecond().intValue());
    }

    /**
     * Snapshots the current values of all potentials and sepsets into {@link #initializations}.
     * Identical value arrays are shared through a canonicalizing manager.
     */
    private void storePotentialValues() {
        CanonicalArrayWrapperManager wrapperManager = new CanonicalArrayWrapperManager();
        for (AbstractFactor potential : this.nodePotentials) {
            this.initializations.add(Pair.newPair(potential, wrapperManager.getInstance(potential.getValues().m3clone())));
        }
        for (AbstractFactor sepSet : this.sepSets.values()) {
            this.initializations.add(Pair.newPair(sepSet, wrapperManager.getInstance(sepSet.getValues().m3clone())));
        }
    }
}
