package de.jstacs.sequenceScores.statisticalModels.differentiable.mixture;

import de.jstacs.data.DataSet;
import de.jstacs.data.sequences.Sequence;
import de.jstacs.io.ArrayHandler;
import de.jstacs.io.NonParsableException;
import de.jstacs.io.XMLParser;
import de.jstacs.sequenceScores.statisticalModels.differentiable.AbstractDifferentiableStatisticalModel;
import de.jstacs.sequenceScores.statisticalModels.differentiable.DifferentiableStatisticalModel;
import de.jstacs.sequenceScores.statisticalModels.differentiable.NormalizedDiffSM;
import de.jstacs.sequenceScores.statisticalModels.differentiable.SamplingDifferentiableStatisticalModel;
import de.jstacs.sequenceScores.statisticalModels.differentiable.mixture.motif.DurationDiffSM;
import de.jstacs.utils.DoubleList;
import de.jstacs.utils.IntList;
import de.jstacs.utils.Normalisation;
import de.jstacs.utils.ToolBox;
import de.jstacs.utils.random.DirichletMRG;
import de.jstacs.utils.random.DirichletMRGParams;
import de.jtem.numericalMethods.calculus.specialFunctions.Gamma;
import java.util.Arrays;
import java.util.LinkedList;

/* loaded from: input_file:de/jstacs/sequenceScores/statisticalModels/differentiable/mixture/AbstractMixtureDiffSM.class */
/**
 * Abstract base class for differentiable statistical models that are mixtures of component models.
 *
 * <p>The parameter vector is laid out as {@code [component 0 | component 1 | ... | hidden parameters]};
 * the offsets of these parts are kept in {@link #paramRef}. Hidden (mixture) parameters are only part
 * of the parameter vector if {@link #optimizeHidden} is set; with {@link #freeParams} the last hidden
 * parameter is fixed to 0 and not part of the vector.
 */
public abstract class AbstractMixtureDiffSM extends AbstractDifferentiableStatisticalModel implements SamplingDifferentiableStatisticalModel {
    /** The number of recommended starts returned by {@link #getNumberOfRecommendedStarts()}. */
    private int starts;
    /**
     * Offsets into the parameter vector: {@code paramRef[i]} is the first parameter of component
     * {@code i}, {@code paramRef[function.length]} the first hidden parameter and the last entry the
     * total number of parameters. {@code null} while some component reports -1 parameters.
     */
    protected int[] paramRef;
    /** Whether the hidden parameters are part of the parameter vector (only possible for more than one component). */
    protected boolean optimizeHidden;
    /** Whether the last hidden parameter is fixed to 0 and excluded from the parameter vector. */
    protected boolean freeParams;
    /** Whether {@link #initializeFunction(int, boolean, DataSet[], double[][])} uses the plug-in initialization. */
    private boolean plugIn;
    /** The component models; entries may be {@code null}. */
    protected DifferentiableStatisticalModel[] function;
    /** The log hidden parameters, one per component. */
    protected double[] hiddenParameter;
    /** {@code hiddenParameter[i] - logHiddenNorm}. */
    protected double[] logHiddenPotential;
    /** {@code exp(logHiddenPotential[i])}. */
    protected double[] hiddenPotential;
    /** One score per component, filled by {@link #fillComponentScores(Sequence, int)}. */
    protected double[] componentScore;
    /** Per-component summands whose log-sum is the log normalization constant. */
    protected double[] partNorm;
    /** Cached log normalization constant; {@code NaN} if it has to be recomputed. */
    protected double norm;
    /** Log-sum of the hidden parameters if the model is normalized, otherwise 0. */
    protected double logHiddenNorm;
    /** Log normalizer of the prior on the hidden parameters, see {@link #computeLogGammaSum()}. */
    protected double logGammaSum;
    /** Scratch lists (not used in this class; presumably used by subclasses — verify). */
    protected DoubleList[] dList;
    /** Scratch lists (not used in this class; presumably used by subclasses — verify). */
    protected IntList[] iList;
    /** Cached result of {@link #determineIsNormalized()}. */
    private boolean isNormalized;

    /**
     * Creates a new mixture of (clones of) the given components.
     *
     * @param length the length passed to the super constructor
     * @param starts the number of recommended starts; has to be positive
     * @param numberOfComponents the number of mixture components; has to be positive
     * @param optimizeHidden whether the hidden parameters should be optimized; ignored for a single component
     * @param plugIn whether the plug-in initialization should be used
     * @param component the component models, which are cloned
     * @throws CloneNotSupportedException if a component could not be cloned
     * @throws IllegalArgumentException if {@code starts} or {@code numberOfComponents} is not positive
     */
    public AbstractMixtureDiffSM(int length, int starts, int numberOfComponents, boolean optimizeHidden, boolean plugIn, DifferentiableStatisticalModel... component) throws CloneNotSupportedException {
        super(component[0].getAlphabetContainer(), length);
        this.function = (DifferentiableStatisticalModel[]) ArrayHandler.clone(component);
        if (starts < 1) {
            throw new IllegalArgumentException("The number of recommended starts has to be positive.");
        }
        this.starts = starts;
        // '< 1' (not '== 0'): a negative count would otherwise fail later with NegativeArraySizeException
        if (numberOfComponents < 1) {
            throw new IllegalArgumentException("The number of components has to be positive.");
        }
        this.isNormalized = determineIsNormalized();
        this.hiddenParameter = new double[numberOfComponents];
        this.logHiddenPotential = new double[numberOfComponents];
        this.hiddenPotential = new double[numberOfComponents];
        this.partNorm = new double[numberOfComponents];
        // hiddenParameter is still all zeros; derive the potentials from it
        setHiddenParameters(this.hiddenParameter, 0);
        this.componentScore = new double[numberOfComponents];
        this.optimizeHidden = optimizeHidden && numberOfComponents > 1;
        this.plugIn = plugIn;
        this.paramRef = null;
        // freeParams still has its default value (false) here
        init(this.freeParams);
        this.norm = Double.NaN;
    }

    /**
     * Computes {@link #logGammaSum} = logΓ(Σ α_i) − Σ logΓ(α_i) over the hyperparameters α_i of the
     * hidden parameters, i.e. the log normalizer of a Dirichlet with these hyperparameters. It is 0 for a
     * single component or a non-positive ESS.
     */
    public void computeLogGammaSum() {
        this.logGammaSum = 0.0d;
        int numberOfComponents = getNumberOfComponents();
        if (numberOfComponents <= 1 || getESS() <= 0.0d) {
            return;
        }
        double alphaSum = 0.0d;
        for (int i = 0; i < numberOfComponents; i++) {
            double alpha = getHyperparameterForHiddenParameter(i);
            alphaSum += alpha;
            this.logGammaSum -= Gamma.logOfGamma(alpha);
        }
        this.logGammaSum += Gamma.logOfGamma(alphaSum);
    }

    /**
     * Creates an instance from its XML representation.
     *
     * @param stringBuffer the XML representation
     * @throws NonParsableException if the representation could not be parsed
     */
    public AbstractMixtureDiffSM(StringBuffer stringBuffer) throws NonParsableException {
        super(stringBuffer);
    }

    /**
     * Returns a deep copy (decompiled name of {@code clone()}): components and all per-component arrays are
     * copied and the parameter offsets are rebuilt.
     */
    @Override
    public AbstractMixtureDiffSM mo110clone() throws CloneNotSupportedException {
        AbstractMixtureDiffSM copy = (AbstractMixtureDiffSM) super.mo110clone();
        copy.cloneFunctions(this.function);
        copy.hiddenParameter = this.hiddenParameter.clone();
        copy.logHiddenPotential = this.logHiddenPotential.clone();
        copy.hiddenPotential = this.hiddenPotential.clone();
        copy.componentScore = this.componentScore.clone();
        copy.partNorm = this.partNorm.clone();
        // forces init to allocate fresh scratch lists and offsets for the copy
        copy.iList = null;
        copy.paramRef = null;
        copy.init(this.freeParams);
        return copy;
    }

    /**
     * Replaces the components of this instance by clones of the given models.
     *
     * @param models the models to clone
     * @throws CloneNotSupportedException if a model could not be cloned
     */
    protected void cloneFunctions(DifferentiableStatisticalModel[] models) throws CloneNotSupportedException {
        this.function = (DifferentiableStatisticalModel[]) ArrayHandler.clone(models);
    }

    /**
     * Returns the hyperparameter of the hidden parameter with the given index.
     *
     * @param index the index of the hidden parameter
     * @return the hyperparameter
     */
    public abstract double getHyperparameterForHiddenParameter(int index);

    /**
     * Returns the log prior term: the hyperparameter-weighted hidden parameters (corrected by the hidden
     * normalizer for normalized models), the prior terms of all components and {@link #logGammaSum}.
     */
    @Override
    public double getLogPriorTerm() {
        double logPrior = 0.0d;
        double alphaSum = 0.0d;
        for (int i = 0; i < this.hiddenParameter.length; i++) {
            double alpha = getHyperparameterForHiddenParameter(i);
            alphaSum += alpha;
            logPrior += this.hiddenParameter[i] * alpha;
        }
        if (isNormalized()) {
            logPrior -= alphaSum * this.logHiddenNorm;
        }
        for (DifferentiableStatisticalModel component : this.function) {
            if (component != null) {
                logPrior += component.getLogPriorTerm();
            }
        }
        return logPrior + this.logGammaSum;
    }

    /**
     * Adds the gradient of {@link #getLogPriorTerm()} to {@code grad}, starting at {@code start}.
     *
     * @param grad the gradient array to add to
     * @param start the offset of this model's parameters in {@code grad}
     * @throws Exception if a component fails to add its gradient
     */
    @Override
    public void addGradientOfLogPriorTerm(double[] grad, int start) throws Exception {
        for (int i = 0; i < this.function.length; i++) {
            if (this.function[i] != null) {
                this.function[i].addGradientOfLogPriorTerm(grad, start + this.paramRef[i]);
            }
        }
        int hiddenEnd = start + this.paramRef[this.function.length + 1];
        double ess = getESS();
        int component = 0;
        for (int p = start + this.paramRef[this.function.length]; p < hiddenEnd; p++) {
            grad[p] += getHyperparameterForHiddenParameter(component)
                    - (isNormalized() ? ess * this.hiddenPotential[component] : 0.0d);
            component++;
        }
    }

    /**
     * Returns the index of the component with the highest score for the sequence.
     *
     * @param sequence the sequence
     * @param start the start position in the sequence
     * @return the index of the maximal entry of the component scores
     */
    public int getIndexOfMaximalComponentFor(Sequence sequence, int start) {
        fillComponentScores(sequence, start);
        return ToolBox.getMaxIndex(this.componentScore);
    }

    /**
     * Returns the current parameter vector (components followed by the optimized hidden parameters).
     *
     * @throws Exception if the parameters are not yet defined
     */
    @Override
    public double[] getCurrentParameterValues() throws Exception {
        int numberOfParameters = getNumberOfParameters();
        if (numberOfParameters == -1) {
            throw new Exception("No parameters exists, yet.");
        }
        double[] params = new double[numberOfParameters];
        int numberOfFunctions = this.function.length;
        for (int i = 0; i < numberOfFunctions; i++) {
            if (this.function[i] != null) {
                double[] componentParams = this.function[i].getCurrentParameterValues();
                System.arraycopy(componentParams, 0, params, this.paramRef[i], componentParams.length);
            }
        }
        int hiddenStart = this.paramRef[numberOfFunctions];
        System.arraycopy(this.hiddenParameter, 0, params, hiddenStart, this.paramRef[numberOfFunctions + 1] - hiddenStart);
        return params;
    }

    /** Returns the log-sum of the component scores for the sequence. */
    @Override
    public double getLogScoreFor(Sequence sequence, int start) {
        fillComponentScores(sequence, start);
        return Normalisation.getLogSum(this.componentScore);
    }

    /** Returns 0 for normalized models, otherwise the (lazily computed) log normalization constant. */
    @Override
    public final double getLogNormalizationConstant() {
        if (isNormalized()) {
            return 0.0d;
        }
        if (Double.isNaN(this.norm)) {
            precomputeNorm();
        }
        return this.norm;
    }

    /** Returns the number of mixture components. */
    public final int getNumberOfComponents() {
        return this.componentScore.length;
    }

    /** Returns the total number of parameters, or -1 if they are not yet defined. */
    @Override
    public final int getNumberOfParameters() {
        if (this.paramRef == null) {
            return -1;
        }
        return this.paramRef[this.paramRef.length - 1];
    }

    @Override
    public final int getNumberOfRecommendedStarts() {
        return this.starts;
    }

    /**
     * Returns the normalized probabilities of the components for the sequence (scored from position 0).
     *
     * @param sequence the sequence
     * @return one probability per component
     */
    public double[] getProbsForComponent(Sequence sequence) {
        fillComponentScores(sequence, 0);
        double[] probs = new double[this.componentScore.length];
        Normalisation.logSumNormalisation(this.componentScore, 0, probs.length, probs, 0);
        return probs;
    }

    /**
     * Returns clones of the components.
     *
     * @throws CloneNotSupportedException if a component could not be cloned
     */
    public DifferentiableStatisticalModel[] getDifferentiableStatisticalModels() throws CloneNotSupportedException {
        return (DifferentiableStatisticalModel[]) ArrayHandler.clone(this.function);
    }

    /**
     * Returns the size of the event space of the parameter with the given index: the number of hidden
     * parameters for a hidden parameter, otherwise the value reported by the owning component.
     */
    @Override
    public int getSizeOfEventSpaceForRandomVariablesOfParameter(int index) {
        int[] indices = getIndices(index);
        if (indices[0] == this.function.length) {
            return this.hiddenParameter.length;
        }
        return this.function[indices[0]].getSizeOfEventSpaceForRandomVariablesOfParameter(indices[1]);
    }

    /**
     * Initializes this model either randomly or, if plug-in initialization is enabled, via
     * {@link #initializeUsingPlugIn(int, boolean, DataSet[], double[][])} (after a random initialization
     * if the model was not yet initialized).
     */
    @Override
    public void initializeFunction(int index, boolean freeParameters, DataSet[] data, double[][] weights) throws Exception {
        if (!this.plugIn) {
            initializeFunctionRandomly(freeParameters);
            return;
        }
        if (!isInitialized()) {
            initializeFunctionRandomly(freeParameters);
        }
        initializeUsingPlugIn(index, freeParameters, data, weights);
        init(freeParameters);
    }

    /**
     * Subclass hook for the plug-in initialization.
     *
     * @param index the index passed to {@code initializeFunction}
     * @param freeParameters whether free parameters are used
     * @param data the data passed to {@code initializeFunction}
     * @param weights the weights passed to {@code initializeFunction}
     * @throws Exception if the initialization fails
     */
    protected abstract void initializeUsingPlugIn(int index, boolean freeParameters, DataSet[] data, double[][] weights) throws Exception;

    /** Randomly initializes all components and, if optimized, the hidden parameters. */
    @Override
    public void initializeFunctionRandomly(boolean freeParameters) throws Exception {
        for (DifferentiableStatisticalModel component : this.function) {
            if (component != null) {
                component.initializeFunctionRandomly(freeParameters);
            }
        }
        if (this.optimizeHidden) {
            initializeHiddenPotentialRandomly();
        }
        init(freeParameters);
    }

    /**
     * Draws the hidden potentials from a Dirichlet distribution (all-ones parameters if the ESS is 0,
     * otherwise the hidden hyperparameters) and derives the hidden parameters from them.
     */
    public void initializeHiddenPotentialRandomly() {
        double[] alpha = new double[getNumberOfComponents()];
        if (getESS() == 0.0d) {
            Arrays.fill(alpha, 1.0d);
        } else {
            for (int i = 0; i < alpha.length; i++) {
                alpha[i] = getHyperparameterForHiddenParameter(i);
            }
        }
        DirichletMRG.DEFAULT_INSTANCE.generate(this.hiddenPotential, 0, this.hiddenPotential.length, new DirichletMRGParams(alpha));
        computeHiddenParameter(this.hiddenPotential, false);
    }

    /** Returns {@code true} if the parameter offsets exist and every non-null component is initialized. */
    @Override
    public boolean isInitialized() {
        if (this.paramRef == null) {
            return false;
        }
        for (DifferentiableStatisticalModel component : this.function) {
            if (component != null && !component.isInitialized()) {
                return false;
            }
        }
        return true;
    }

    /**
     * Sets the parameters of all components and, if optimized, the hidden parameters.
     *
     * @param params the parameter vector
     * @param start the offset of this model's parameters in {@code params}
     */
    @Override
    public void setParameters(double[] params, int start) {
        for (int i = 0; i < this.function.length; i++) {
            if (this.function[i] != null) {
                setParametersForFunction(i, params, start + this.paramRef[i]);
            }
        }
        this.isNormalized = determineIsNormalized();
        if (this.optimizeHidden) {
            setHiddenParameters(params, start + this.paramRef[this.function.length]);
        } else if (this.isNormalized) {
            this.norm = 0.0d;
        } else {
            this.norm = Double.NaN;
        }
    }

    /** Returns whether this mixture is normalized; by default determined from the components. */
    protected boolean determineIsNormalized() {
        return isNormalized(this.function);
    }

    /**
     * Initializes the hidden parts of the components (where supported) uniformly and, if the hidden
     * parameters are optimized, sets them such that all components have the same a-priori probability,
     * i.e. {@code hiddenParameter[i] + logZ_i} is equal for all components.
     */
    public void initializeHiddenUniformly() {
        int numberOfComponents = getNumberOfComponents();
        for (DifferentiableStatisticalModel component : this.function) {
            if (component == null) {
                continue;
            }
            if (component instanceof AbstractMixtureDiffSM) {
                ((AbstractMixtureDiffSM) component).initializeHiddenUniformly();
            } else if (component instanceof NormalizedDiffSM) {
                ((NormalizedDiffSM) component).initializeHiddenUniformly();
            } else if (component instanceof DurationDiffSM) {
                ((DurationDiffSM) component).initializeUniformly();
            }
        }
        if (this.optimizeHidden) {
            double[] hidden = new double[numberOfComponents];
            // With free parameters the LAST hidden parameter is fixed to 0 (see setHiddenParameters), so the
            // offset must be the normalization constant of that last component (index n-1, not n).
            double offset = this.freeParams ? getLogNormalizationConstantForComponent(numberOfComponents - 1) : 0.0d;
            for (int i = 0; i < numberOfComponents; i++) {
                hidden[i] = offset - getLogNormalizationConstantForComponent(i);
            }
            setHiddenParameters(hidden, 0);
        }
        init(this.freeParams);
    }

    /**
     * Copies the hidden parameters from {@code params} (all of them, or all but the last one which is then
     * fixed to 0 if free parameters are used) and recomputes the hidden potentials and normalizers.
     *
     * @param params the source array (may be {@link #hiddenParameter} itself)
     * @param start the index of the first hidden parameter in {@code params}
     */
    public void setHiddenParameters(double[] params, int start) {
        int numberOfFree = this.hiddenParameter.length - (this.freeParams ? 1 : 0);
        // arraycopy behaves as if via a temporary copy, so params == hiddenParameter is safe
        System.arraycopy(params, start, this.hiddenParameter, 0, numberOfFree);
        if (this.freeParams) {
            this.hiddenParameter[numberOfFree] = 0.0d;
        }
        if (isNormalized()) {
            this.logHiddenNorm = Normalisation.getLogSum(this.hiddenParameter);
            this.norm = 0.0d;
        } else {
            this.logHiddenNorm = 0.0d;
            this.norm = Double.NaN;
        }
        for (int i = 0; i < this.logHiddenPotential.length; i++) {
            this.logHiddenPotential[i] = this.hiddenParameter[i] - this.logHiddenNorm;
            this.hiddenPotential[i] = Math.exp(this.logHiddenPotential[i]);
            this.partNorm[i] = this.logHiddenPotential[i];
        }
    }

    /**
     * Sets the parameters of a single component.
     *
     * @param index the component index
     * @param params the parameter vector
     * @param start the offset of the component's parameters in {@code params}
     */
    public void setParametersForFunction(int index, double[] params, int start) {
        this.function[index].setParameters(params, start);
    }

    @Override
    public final StringBuffer toXML() {
        StringBuffer xml = new StringBuffer(10000);
        XMLParser.appendObjectWithTags(xml, Integer.valueOf(this.length), "length");
        XMLParser.appendObjectWithTags(xml, Integer.valueOf(this.starts), "starts");
        XMLParser.appendObjectWithTags(xml, Boolean.valueOf(this.freeParams), "freeParams");
        XMLParser.appendObjectWithTags(xml, this.function, "function");
        XMLParser.appendObjectWithTags(xml, Boolean.valueOf(this.optimizeHidden), "optimizeHidden");
        XMLParser.appendObjectWithTags(xml, Boolean.valueOf(this.plugIn), "plugIn");
        XMLParser.appendObjectWithTags(xml, this.hiddenParameter, "hiddenParameter");
        xml.append(getFurtherInformation());
        XMLParser.addTags(xml, getXMLTag());
        return xml;
    }

    @Override
    protected final void fromXML(StringBuffer stringBuffer) throws NonParsableException {
        StringBuffer xml = XMLParser.extractForTag(stringBuffer, getXMLTag());
        this.length = ((Integer) XMLParser.extractObjectForTags(xml, "length", Integer.TYPE)).intValue();
        this.starts = ((Integer) XMLParser.extractObjectForTags(xml, "starts", Integer.TYPE)).intValue();
        this.freeParams = ((Boolean) XMLParser.extractObjectForTags(xml, "freeParams", Boolean.TYPE)).booleanValue();
        this.function = (DifferentiableStatisticalModel[]) XMLParser.extractObjectForTags(xml, "function", DifferentiableStatisticalModel[].class);
        this.alphabets = this.function[0].getAlphabetContainer();
        this.optimizeHidden = ((Boolean) XMLParser.extractObjectForTags(xml, "optimizeHidden", Boolean.TYPE)).booleanValue();
        this.plugIn = ((Boolean) XMLParser.extractObjectForTags(xml, "plugIn", Boolean.TYPE)).booleanValue();
        this.hiddenParameter = (double[]) XMLParser.extractObjectForTags(xml, "hiddenParameter", double[].class);
        int numberOfComponents = this.hiddenParameter.length;
        this.hiddenPotential = new double[numberOfComponents];
        this.logHiddenPotential = new double[numberOfComponents];
        this.partNorm = new double[numberOfComponents];
        this.componentScore = new double[numberOfComponents];
        extractFurtherInformation(xml);
        this.isNormalized = determineIsNormalized();
        setHiddenParameters(this.hiddenParameter, 0);
        this.norm = Double.NaN;
        init(this.freeParams);
        computeLogGammaSum();
    }

    /** Returns additional XML written by subclasses inside this model's tag; empty by default. */
    protected StringBuffer getFurtherInformation() {
        return new StringBuffer(1);
    }

    /**
     * Parses the additional XML written by {@link #getFurtherInformation()}; does nothing by default.
     *
     * @param stringBuffer the XML content of this model's tag
     * @throws NonParsableException if the content could not be parsed
     */
    protected void extractFurtherInformation(StringBuffer stringBuffer) throws NonParsableException {
    }

    /**
     * Maps a global parameter index to {@code {part, localIndex}} where {@code part} is the component
     * index ({@code function.length} for hidden parameters) and {@code localIndex} the offset within it.
     * Requires {@code 0 <= index < getNumberOfParameters()}.
     */
    public int[] getIndices(int index) {
        int part = 0;
        // find the last part whose offset is <= index (parts without parameters are skipped)
        while (index >= this.paramRef[part]) {
            part++;
        }
        part--;
        return new int[] { part, index - this.paramRef[part] };
    }

    /** Returns the XML tag of this model; the simple class name by default. */
    protected String getXMLTag() {
        return getClass().getSimpleName();
    }

    /** (Re-)computes the parameter offsets for the components followed by the hidden parameters. */
    public void init(boolean freeParameters) {
        initWithLength(freeParameters, this.function.length + 2);
    }

    /**
     * (Re-)computes {@link #paramRef} with the given length, allocates the scratch lists on first use and
     * stores {@code freeParameters}. Sets {@link #paramRef} to {@code null} (and returns early) if a
     * component reports -1 parameters.
     *
     * @param freeParameters whether the last hidden parameter is fixed (one parameter less)
     * @param length the length of {@link #paramRef}
     */
    protected final void initWithLength(boolean freeParameters, int length) {
        if (this.paramRef == null || this.paramRef.length != length) {
            this.paramRef = new int[length];
        }
        if (this.iList == null) {
            this.iList = new IntList[Math.max(this.function.length, this.hiddenParameter.length)];
            this.dList = new DoubleList[this.iList.length];
            for (int i = 0; i < this.iList.length; i++) {
                this.iList[i] = new IntList();
                this.dList[i] = new DoubleList();
            }
        }
        int part = 0;
        for (; part < this.function.length; part++) {
            int numberOfParameters = this.function[part] == null ? 0 : this.function[part].getNumberOfParameters();
            if (numberOfParameters == -1) {
                this.paramRef = null;
                return;
            }
            this.paramRef[part + 1] = this.paramRef[part] + numberOfParameters;
        }
        if (this.optimizeHidden) {
            this.paramRef[part + 1] = (this.paramRef[part] + this.hiddenParameter.length) - (freeParameters ? 1 : 0);
        } else {
            this.paramRef[part + 1] = this.paramRef[part];
        }
        this.freeParams = freeParameters;
    }

    /**
     * Derives the hidden parameters from (unnormalized) component weights: {@code log(w_i / Σw) - logZ_i},
     * shifted so that the last parameter is 0 if free parameters are used.
     *
     * @param weights the component weights; MODIFIED in place if {@code addHyperparameters} is set
     * @param addHyperparameters whether the hidden hyperparameters are added to {@code weights} first
     */
    public void computeHiddenParameter(double[] weights, boolean addHyperparameters) {
        int count = this.hiddenParameter.length;
        double weightSum = 0.0d;
        for (int i = 0; i < count; i++) {
            if (addHyperparameters) {
                weights[i] += getHyperparameterForHiddenParameter(i);
            }
            weightSum += weights[i];
        }
        double logWeightSum = Math.log(weightSum);
        double shift = 0.0d;
        if (this.freeParams) {
            count--;
            shift = (Math.log(weights[count]) - logWeightSum) - getLogNormalizationConstantForComponent(count);
            this.hiddenParameter[count] = 0.0d;
        }
        for (int i = 0; i < count; i++) {
            this.hiddenParameter[i] = ((Math.log(weights[i]) - logWeightSum) - getLogNormalizationConstantForComponent(i)) - shift;
        }
        setHiddenParameters(this.hiddenParameter, 0);
    }

    /** Computes {@link #norm} as the log-sum over components of hidden log potential plus component log normalizer. */
    public void precomputeNorm() {
        for (int i = 0; i < this.hiddenPotential.length; i++) {
            this.partNorm[i] = this.logHiddenPotential[i] + getLogNormalizationConstantForComponent(i);
        }
        this.norm = Normalisation.getLogSum(this.partNorm);
    }

    /**
     * Returns the log normalization constant of the given component.
     *
     * @param component the component index
     * @return the log normalization constant
     */
    protected abstract double getLogNormalizationConstantForComponent(int component);

    /** Returns the a-priori probabilities of the components (normalized hidden potential times component normalizer). */
    public double[] getAPrioriMixtureProbabilities() {
        double[] probs = new double[getNumberOfComponents()];
        for (int i = 0; i < probs.length; i++) {
            probs[i] = this.logHiddenPotential[i] + getLogNormalizationConstantForComponent(i);
        }
        Normalisation.logSumNormalisation(probs);
        return probs;
    }

    /**
     * Fills {@link #componentScore} with one score per component for the sequence.
     *
     * @param sequence the sequence
     * @param start the start position in the sequence
     */
    protected abstract void fillComponentScores(Sequence sequence, int start);

    @Override
    public final boolean isNormalized() {
        return this.isNormalized;
    }

    /**
     * Returns a clone of the component with the given index, or {@code null} if there is none.
     *
     * @throws CloneNotSupportedException if the component could not be cloned
     */
    public DifferentiableStatisticalModel getFunction(int index) throws CloneNotSupportedException {
        if (this.function[index] != null) {
            return (DifferentiableStatisticalModel) this.function[index].mo110clone();
        }
        return null;
    }

    /**
     * Returns clones of all components.
     *
     * @throws CloneNotSupportedException if a component could not be cloned
     */
    public DifferentiableStatisticalModel[] getFunctions() throws CloneNotSupportedException {
        return (DifferentiableStatisticalModel[]) ArrayHandler.clone(this.function);
    }

    /**
     * Returns the sampling groups: those of each component (a single group of all its parameters for
     * components that are not {@link SamplingDifferentiableStatisticalModel}s) followed by one group of the
     * hidden parameters that are actually part of the parameter vector (omitted if there are none).
     *
     * @param parameterOffset the offset of this model's parameters in the global parameter vector
     * @return the groups of global parameter indices
     */
    @Override
    public int[][] getSamplingGroups(int parameterOffset) {
        LinkedList<int[]> groups = new LinkedList<int[]>();
        for (int i = 0; i < this.function.length; i++) {
            if (this.function[i] == null) {
                // a missing component owns no parameters (consistent with the other methods)
                continue;
            }
            int start = parameterOffset + this.paramRef[i];
            if (this.function[i] instanceof SamplingDifferentiableStatisticalModel) {
                for (int[] group : ((SamplingDifferentiableStatisticalModel) this.function[i]).getSamplingGroups(start)) {
                    groups.add(group);
                }
            } else {
                int[] group = new int[this.function[i].getNumberOfParameters()];
                for (int j = 0; j < group.length; j++) {
                    group[j] = start + j;
                }
                groups.add(group);
            }
        }
        // only the hidden parameters present in the vector: one fewer with free parameters, none if not optimized
        int hiddenStart = this.paramRef[this.function.length];
        int numberOfHidden = this.paramRef[this.function.length + 1] - hiddenStart;
        if (numberOfHidden > 0) {
            int[] hidden = new int[numberOfHidden];
            for (int j = 0; j < numberOfHidden; j++) {
                hidden[j] = parameterOffset + hiddenStart + j;
            }
            groups.add(hidden);
        }
        return groups.toArray(new int[0][]);
    }
}
