package gospl.algo.ipf;

import core.metamodel.IPopulation;
import core.metamodel.attribute.Attribute;
import core.metamodel.entity.ADemoEntity;
import core.metamodel.value.IValue;
import core.util.GSKeywords;
import core.util.GSPerformanceUtil;
import core.util.random.GenstarRandom;
import gama.dev.DEBUG;
import gospl.algo.ipf.margin.Margin;
import gospl.algo.ipf.margin.MarginDescriptor;
import gospl.algo.ipf.margin.MarginalsIPFBuilder;
import gospl.distribution.matrix.AFullNDimensionalMatrix;
import gospl.distribution.matrix.INDimensionalMatrix;
import gospl.distribution.matrix.control.AControl;
import gospl.distribution.matrix.control.ControlFrequency;
import gospl.distribution.matrix.coordinate.ACoordinate;
import java.lang.Number;
import java.util.Arrays;
import java.util.Collection;
import java.util.List;
import java.util.stream.Collectors;

/* loaded from: input_file:gospl/algo/ipf/AGosplIPF.class */
/**
 * Base class for the GoSPl Iterative Proportional Fitting (IPF) procedures.
 *
 * <p>An IPF run repeatedly rescales the cells of a full n-dimensional matrix so that its marginals
 * agree with the control values held in {@link #marginals}. Fitting stops when any of these holds:
 * <ul>
 *   <li>the maximum number of steps has been reached;</li>
 *   <li>the average absolute percentage deviation (AAPD) falls to or below {@code delta};</li>
 *   <li>the AAPD changes by less than {@code delta} between two iterations (the fit has stalled).</li>
 * </ul>
 *
 * @param <T> numeric type of the matrix cell values
 */
public abstract class AGosplIPF<T extends Number> {

    /** Default maximum number of fitting iterations. */
    private static final int DEFAULT_STEP = 100;

    /** Default convergence threshold (10^-4). */
    private static final double DEFAULT_DELTA = Math.pow(10.0d, -4.0d);

    // Maximum number of fitting iterations.
    private int step = DEFAULT_STEP;
    // Convergence threshold, applied both to the AAPD itself and to its change between iterations.
    private double delta = DEFAULT_DELTA;

    // Population given at construction (presumably the seed sample to be fitted; not read in this class).
    protected IPopulation<ADemoEntity, Attribute<? extends IValue>> sampleSeed;
    // Control matrix whose marginals the fitted matrix must match; must be set before process(matrix).
    protected INDimensionalMatrix<Attribute<? extends IValue>, IValue, T> marginals;
    // Builds the margins shared between the control matrix and the matrix being fitted.
    protected MarginalsIPFBuilder<T> marginalProcessor;

    /**
     * Creates an IPF procedure with explicit convergence criteria.
     *
     * @param iPopulation the seed population
     * @param marginalsIPFBuilder builder used to derive the margins to fit
     * @param i maximum number of iterations
     * @param d convergence threshold
     */
    protected AGosplIPF(IPopulation<ADemoEntity, Attribute<? extends IValue>> iPopulation, MarginalsIPFBuilder<T> marginalsIPFBuilder, int i, double d) {
        this(iPopulation, marginalsIPFBuilder);
        this.step = i;
        this.delta = d;
    }

    /**
     * Creates an IPF procedure with explicit convergence criteria and a default margin builder.
     * (Bytecode metadata indicates this constructor was declared protected in the original source.)
     *
     * @param iPopulation the seed population
     * @param i maximum number of iterations
     * @param d convergence threshold
     */
    public AGosplIPF(IPopulation<ADemoEntity, Attribute<? extends IValue>> iPopulation, int i, double d) {
        this(iPopulation, new MarginalsIPFBuilder<>(), i, d);
    }

    /**
     * Creates an IPF procedure with the default convergence criteria (100 steps, delta = 10^-4).
     *
     * @param iPopulation the seed population
     * @param marginalsIPFBuilder builder used to derive the margins to fit
     */
    protected AGosplIPF(IPopulation<ADemoEntity, Attribute<? extends IValue>> iPopulation, MarginalsIPFBuilder<T> marginalsIPFBuilder) {
        this.sampleSeed = iPopulation;
        this.marginalProcessor = marginalsIPFBuilder;
    }

    /**
     * Creates an IPF procedure with default convergence criteria and a default margin builder.
     * (Bytecode metadata indicates this constructor was declared protected in the original source.)
     *
     * @param iPopulation the seed population
     */
    public AGosplIPF(IPopulation<ADemoEntity, Attribute<? extends IValue>> iPopulation) {
        this(iPopulation, new MarginalsIPFBuilder<>());
    }

    /**
     * Sets the control matrix whose marginals the fitted matrix must reproduce.
     * (Bytecode metadata indicates this method was declared protected in the original source.)
     *
     * @param iNDimensionalMatrix the control (marginal) matrix
     */
    public void setMarginalMatrix(INDimensionalMatrix<Attribute<? extends IValue>, IValue, T> iNDimensionalMatrix) {
        this.marginals = iNDimensionalMatrix;
    }

    /** Sets the maximum number of fitting iterations. */
    protected void setMaxStep(int i) {
        this.step = i;
    }

    /** Sets the convergence threshold. */
    protected void setMaxDelta(double d) {
        this.delta = d;
    }

    /**
     * Runs the concrete IPF procedure using the currently configured step limit and threshold.
     *
     * @return the fitted matrix
     */
    public abstract AFullNDimensionalMatrix<T> process();

    /**
     * Updates the convergence criteria, then runs {@link #process()}.
     * The new values remain in effect for later calls.
     *
     * @param d convergence threshold
     * @param i maximum number of iterations
     * @return the fitted matrix
     */
    public AFullNDimensionalMatrix<T> process(double d, int i) {
        this.delta = d;
        this.step = i;
        return process();
    }

    /**
     * Fits {@code aFullNDimensionalMatrix} in place against the marginals of {@link #marginals}.
     * The matrix is normalized before it is returned.
     *
     * @param aFullNDimensionalMatrix the matrix to fit; it is mutated
     * @return the same matrix instance, fitted and normalized
     * @throws IllegalStateException if no marginal matrix has been set
     * @throws IllegalArgumentException if no dimension of the matrix (or its referent) appears in the
     *     marginal matrix
     */
    public AFullNDimensionalMatrix<T> process(AFullNDimensionalMatrix<T> aFullNDimensionalMatrix) {
        if (this.marginals == null) {
            throw new IllegalStateException("Marginal matrix has not been set: call setMarginalMatrix before fitting");
        }
        List<Attribute<? extends IValue>> matchingDimensions = aFullNDimensionalMatrix.getDimensions().stream()
                .filter(this::isControlledByMarginals)
                .toList();
        if (matchingDimensions.isEmpty()) {
            throw new IllegalArgumentException("Output distribution and sample seed does not have any matching dimensions\nDistribution: " + Arrays.toString(this.marginals.getDimensions().toArray()) + "\nSample seed: :" + Arrays.toString(aFullNDimensionalMatrix.getDimensions().toArray()));
        }

        GSPerformanceUtil perf = new GSPerformanceUtil("*** IPF PROCEDURE ***", GSPerformanceUtil.Level.INFO);
        perf.sysoStempPerformance(0, (Object) this);
        // Floating-point division: integer division would only ever report 0% or 100%.
        double matchingShare = matchingDimensions.size() * 100.0d / aFullNDimensionalMatrix.getDimensions().size();
        perf.sysoStempMessage(matchingShare + "% of samples dimensions will be estimate with output controls");
        perf.sysoStempMessage("Sample seed controls' dimension: " + ((String) aFullNDimensionalMatrix.getDimensions().stream()
                .map(attribute -> attribute.getAttributeName() + " = " + attribute.getValueSpace2().getValues().size())
                .collect(Collectors.joining(GSKeywords.SERIALIZE_ELEMENT_SEPARATOR))));

        Collection<Margin<T>> margins = this.marginalProcessor.buildCompliantMarginals(this.marginals, aFullNDimensionalMatrix);
        int marginCount = margins.stream().mapToInt(margin -> margin.size()).sum();
        perf.sysoStempMessage("Convergence criterias are: step = " + this.step + " | delta = " + this.delta);

        double total = this.marginals.getVal().getValue().doubleValue();
        double aapd = computeAapd(margins, aFullNDimensionalMatrix, total, marginCount);
        perf.sysoStempMessage("Start fitting iterations with AAPD = " + aapd);

        // Report progress roughly every 10% of the step budget; never a zero modulus (step < 10).
        int logPeriod = Math.max(1, (int) (this.step * 0.1d));
        double aapdChange = Double.MAX_VALUE;
        int iterations = 0;
        // Stop when the step budget is spent, the error is small enough, or the error has stalled.
        // A NaN AAPD also stops the loop, because every comparison with NaN is false.
        while (iterations < this.step && aapd > this.delta && aapdChange >= this.delta) {
            if ((this.step - iterations - 1) % logPeriod == 0) {
                perf.sysoStempMessage("Step = " + (iterations + 1) + " | average error = " + aapd, GSPerformanceUtil.Level.DEBUG);
            }
            fitOneIteration(margins, aFullNDimensionalMatrix, perf);
            double updatedAapd = computeAapd(margins, aFullNDimensionalMatrix, total, marginCount);
            aapdChange = Math.abs(aapd - updatedAapd);
            aapd = updatedAapd;
            iterations++;
        }

        aFullNDimensionalMatrix.normalize();
        perf.sysoStempMessage("IPF fitting ends with final " + aapd + " AAPD value and " + iterations + " iteration(s)");
        return aFullNDimensionalMatrix;
    }

    // True when the attribute, or its referent attribute, is a dimension of the marginal matrix.
    private boolean isControlledByMarginals(Attribute<? extends IValue> attribute) {
        return this.marginals.getDimensions().contains(attribute)
                || this.marginals.getDimensions().contains(attribute.getReferentAttribute());
    }

    // Average absolute percentage deviation between the matrix marginals and the controls.
    // Returns 0 when there is nothing to compare, instead of dividing 0 by 0.
    private double computeAapd(Collection<Margin<T>> margins, AFullNDimensionalMatrix<T> matrix, double total, int marginCount) {
        if (marginCount == 0) {
            return 0.0d;
        }
        return margins.stream()
                .mapToDouble(margin -> margin.getMarginDescriptors().stream()
                        .mapToDouble(descriptor -> Math.abs(matrix.getVal((Collection<IValue>) descriptor.getSeed())
                                .getDiff(margin.getControl(descriptor)).doubleValue()) / total)
                        .sum())
                .sum() / marginCount;
    }

    // Performs one IPF sweep: rescales every margin slice of the matrix towards its control value.
    private void fitOneIteration(Collection<Margin<T>> margins, AFullNDimensionalMatrix<T> matrix, GSPerformanceUtil perf) {
        for (Margin<T> margin : margins) {
            for (MarginDescriptor descriptor : margin.getMarginDescriptors()) {
                double target = margin.getControl(descriptor).getValue().doubleValue();
                double current = matrix.getVal((Collection<IValue>) descriptor.getSeed()).getValue().doubleValue();
                if (current == 0.0d && target == 0.0d) {
                    // Already on target; a 0/0 factor would poison the cells with NaN.
                    continue;
                }
                // An empty slice is reseeded below and left at factor 1 for this sweep.
                ControlFrequency factor = new ControlFrequency(Double.valueOf(current == 0.0d ? 1.0d : target / current));
                for (ACoordinate<Attribute<? extends IValue>, IValue> coordinate : matrix.getCoordinates(descriptor.getSeed())) {
                    if (current == 0.0d && target > 0.0d) {
                        // Give empty cells a non-zero value so later sweeps can scale them.
                        matrix.setValue(coordinate, matrix.getAtomicVal());
                    }
                    AControl<T> cell = matrix.getVal(coordinate);
                    double before = cell.getValue().doubleValue();
                    cell.multiply(factor);
                    // Trace roughly 1% of the updates when debugging is on.
                    if (DEBUG.IS_ON() && GenstarRandom.getInstance().nextDouble() < 0.01d) {
                        perf.sysoStempMessage("Coord " + coordinate + ":\n AV = " + before + " | Factor = " + factor.getValue().doubleValue() + " | UV = " + cell.getValue().doubleValue(), GSPerformanceUtil.Level.TRACE);
                    }
                }
            }
        }
    }
}
