package de.jstacs.sequenceScores.statisticalModels.differentiable.homogeneous;

import de.jstacs.data.AlphabetContainer;
import de.jstacs.data.DataSet;
import de.jstacs.data.sequences.Sequence;
import de.jstacs.io.NonParsableException;
import de.jstacs.io.XMLParser;
import de.jstacs.sequenceScores.statisticalModels.trainable.discrete.inhomogeneous.MEMConstraint;
import de.jstacs.utils.DoubleList;
import de.jstacs.utils.IntList;
import de.jstacs.utils.random.DirichletMRG;
import de.jstacs.utils.random.FastDirichletMRGParams;
import de.jtem.numericalMethods.calculus.specialFunctions.Gamma;
import java.text.NumberFormat;
import java.util.Arrays;

/* loaded from: input_file:de/jstacs/sequenceScores/statisticalModels/differentiable/homogeneous/HomogeneousMM0DiffSM.class */
/**
 * Homogeneous Markov model of order 0 (instance name {@code "hMM(0)"}).
 *
 * <p>Each symbol of the first alphabet of the container owns one log-potential {@code lambda},
 * stored in a {@link MEMConstraint}. The unnormalized log-score of a sequence is the sum of the
 * lambdas of its symbols. The log normalization constant for length {@code L} is
 * {@code L * log(norm)}, where {@code norm} is the sum of {@code exp(lambda)} over all symbols.
 */
public class HomogeneousMM0DiffSM extends HomogeneousDiffSM {
    /** Equivalent sample size given at construction time. */
    private double ess;
    /** Sum of {@code exp(lambda)} over all symbols; kept in sync by {@link #recomputeNorm()}. */
    private double norm;
    /** Sum of the (symmetric) Dirichlet hyper-parameters used by the prior. */
    private double sumOfHyperParams;
    /** Constant part of the log prior, see {@link #computeConstantsOfLogPrior()}. */
    private double logGammaSum;
    /** Scratch buffer holding one counter per symbol. */
    private int[] counter;
    /** Whether the last lambda is fixed (and hence not exposed as a free parameter). */
    private boolean freeParams;
    /** Whether {@link #initializeFunction} uses estimates from the data instead of uniform values. */
    private boolean plugIn;
    /** Whether the parameters are exposed to, and may be changed by, an optimizer. */
    private boolean optimize;
    /** Container of the lambdas, one specific constraint per symbol. */
    private MEMConstraint params;
    /** Number of parameters exposed to the optimizer. */
    private int anz;

    /**
     * Creates a new model with uniform symbol probabilities.
     *
     * @param alphabets the alphabet container; only the alphabet at position 0 is used
     * @param length the length passed to the super class
     * @param ess the equivalent sample size, must be non-negative
     * @param plugIn whether {@link #initializeFunction} uses estimates from the data
     * @param optimize whether the parameters are exposed to an optimizer
     * @throws IllegalArgumentException if {@code ess} is negative
     */
    public HomogeneousMM0DiffSM(AlphabetContainer alphabets, int length, double ess, boolean plugIn, boolean optimize) {
        super(alphabets, length);
        if (ess < 0.0d) {
            throw new IllegalArgumentException("The ess has to be non-negative.");
        }
        this.ess = ess;
        this.sumOfHyperParams = ess * length;
        this.params = new MEMConstraint(new int[]{0}, new int[]{(int) alphabets.getAlphabetLengthAt(0)});
        this.plugIn = plugIn;
        this.optimize = optimize;
        // uniform distribution: lambda = -log(|alphabet|) for every symbol
        setAllLambdas(-Math.log(alphabets.getAlphabetLengthAt(0)));
        setFreeParams(false);
        computeConstantsOfLogPrior();
    }

    /**
     * Recreates a model from its XML representation (see {@link #toXML()}).
     *
     * @param stringBuffer the XML representation
     * @throws NonParsableException if the representation cannot be parsed
     */
    public HomogeneousMM0DiffSM(StringBuffer stringBuffer) throws NonParsableException {
        super(stringBuffer);
    }

    /**
     * Returns a copy whose constraint container and scratch counter are cloned as well.
     */
    @Override
    public HomogeneousMM0DiffSM mo71clone() throws CloneNotSupportedException {
        HomogeneousMM0DiffSM copy = (HomogeneousMM0DiffSM) super.mo71clone();
        copy.params = this.params.mo89clone();
        copy.counter = this.counter.clone();
        return copy;
    }

    @Override
    public String getInstanceName() {
        return "hMM(0)";
    }

    /**
     * Returns the unnormalized log-score, i.e. the sum of the lambdas of the symbols at positions
     * {@code start} to {@code end} (both inclusive).
     */
    @Override
    public double getLogScoreFor(Sequence sequence, int start, int end) {
        double score = 0.0d;
        for (int pos = start; pos <= end; pos++) {
            score += this.params.getLambda(this.params.satisfiesSpecificConstraint(sequence, pos));
        }
        return score;
    }

    /**
     * Returns the unnormalized log-score of positions {@code start} to {@code end} (inclusive).
     * For every exposed parameter that occurs, it appends the parameter index and the number of
     * occurrences, which is the partial derivative of the score.
     */
    @Override
    public double getLogScoreAndPartialDerivation(Sequence sequence, int start, int end, IntList indices, DoubleList partialDer) {
        Arrays.fill(this.counter, 0);
        for (int pos = start; pos <= end; pos++) {
            this.counter[this.params.satisfiesSpecificConstraint(sequence, pos)]++;
        }
        double score = 0.0d;
        for (int sym = 0; sym < this.counter.length; sym++) {
            int count = this.counter[sym];
            if (count > 0) {
                score += count * this.params.getLambda(sym);
                // only the first anz lambdas are free parameters of the optimizer
                if (sym < this.anz) {
                    indices.add(sym);
                    partialDer.add(count);
                }
            }
        }
        return score;
    }

    @Override
    public int getNumberOfParameters() {
        return this.anz;
    }

    /**
     * Copies the {@link #getNumberOfParameters()} exposed lambdas from {@code values}, starting at
     * {@code start}, and updates the normalization. Does nothing if the model is not optimized.
     */
    @Override
    public void setParameters(double[] values, int start) {
        if (this.optimize) {
            for (int k = 0; k < this.anz; k++) {
                this.params.setLambda(k, values[start + k]);
            }
            recomputeNorm();
        }
    }

    @Override
    public StringBuffer toXML() {
        StringBuffer stringBuffer = new StringBuffer(1000);
        XMLParser.appendObjectWithTags(stringBuffer, Integer.valueOf(this.length), "length");
        XMLParser.appendObjectWithTags(stringBuffer, this.alphabets, "alphabets");
        XMLParser.appendObjectWithTags(stringBuffer, Double.valueOf(this.ess), "ess");
        XMLParser.appendObjectWithTags(stringBuffer, Double.valueOf(this.sumOfHyperParams), "sumOfHyperParams");
        XMLParser.appendObjectWithTags(stringBuffer, this.params, "params");
        XMLParser.appendObjectWithTags(stringBuffer, Boolean.valueOf(this.freeParams), "freeParams");
        XMLParser.appendObjectWithTags(stringBuffer, Boolean.valueOf(this.plugIn), "plugIn");
        XMLParser.appendObjectWithTags(stringBuffer, Boolean.valueOf(this.optimize), "optimize");
        XMLParser.addTags(stringBuffer, getClass().getSimpleName());
        return stringBuffer;
    }

    /** Returns a fresh array with the currently exposed lambdas. */
    @Override
    public double[] getCurrentParameterValues() {
        double[] values = new double[this.anz];
        for (int k = 0; k < this.anz; k++) {
            values[k] = this.params.getLambda(k);
        }
        return values;
    }

    /**
     * Initializes the lambdas.
     *
     * <p>In plug-in mode they come from the symbol frequencies of {@code data[index]}, estimated
     * with the prior hyper-parameters. A {@code null} {@code weights} array or a {@code null}
     * {@code weights[index]} is treated as weight 1 for every sequence. Otherwise the lambdas are
     * set to the uniform distribution.
     *
     * @param index index of the data set (and weights) to use
     * @param freeParams whether the last lambda is fixed and not exposed as a parameter
     * @param data the data sets, may be {@code null}
     * @param weights the sequence weights, may be {@code null}
     */
    @Override
    public void initializeFunction(int index, boolean freeParams, DataSet[] data, double[][] weights) {
        this.params.reset();
        if (this.plugIn) {
            if (data != null && data[index] != null) {
                boolean unweighted = weights == null || weights[index] == null;
                for (int n = 0; n < data[index].getNumberOfElements(); n++) {
                    Sequence seq = data[index].getElementAt(n);
                    double weight = unweighted ? 1.0d : weights[index][n];
                    int seqLength = seq.getLength();
                    for (int pos = 0; pos < seqLength; pos++) {
                        this.params.add(seq.discreteVal(pos), weight);
                    }
                }
            }
            this.params.estimate(this.sumOfHyperParams);
            int size = this.params.getNumberOfSpecificConstraints();
            for (int sym = 0; sym < size; sym++) {
                this.params.setExpLambda(sym, this.params.getFreq(sym));
            }
        } else {
            setAllLambdas(-Math.log(this.alphabets.getAlphabetLengthAt(0)));
        }
        // also recomputes norm, which must account for the shift applied when freeParams is true
        setFreeParams(freeParams);
    }

    /**
     * Draws the symbol probabilities from a symmetric Dirichlet distribution. The Dirichlet
     * parameter is {@code sumOfHyperParams / |alphabet|}, or 1 if there are no hyper-parameters.
     */
    @Override
    public void initializeFunctionRandomly(boolean freeParams) {
        int size = this.params.getNumberOfSpecificConstraints();
        double alpha = this.sumOfHyperParams == 0.0d ? 1.0d : this.sumOfHyperParams / size;
        double[] probs = DirichletMRG.DEFAULT_INSTANCE.generate(size, new FastDirichletMRGParams(alpha));
        for (int sym = 0; sym < size; sym++) {
            this.params.setExpLambda(sym, probs[sym]);
        }
        setFreeParams(freeParams);
    }

    @Override
    protected void fromXML(StringBuffer stringBuffer) throws NonParsableException {
        StringBuffer extractForTag = XMLParser.extractForTag(stringBuffer, getClass().getSimpleName());
        this.length = ((Integer) XMLParser.extractObjectForTags(extractForTag, "length", Integer.TYPE)).intValue();
        this.alphabets = (AlphabetContainer) XMLParser.extractObjectForTags(extractForTag, "alphabets");
        this.ess = ((Double) XMLParser.extractObjectForTags(extractForTag, "ess", Double.TYPE)).doubleValue();
        this.sumOfHyperParams = ((Double) XMLParser.extractObjectForTags(extractForTag, "sumOfHyperParams", Double.TYPE)).doubleValue();
        this.params = (MEMConstraint) XMLParser.extractObjectForTags(extractForTag, "params", MEMConstraint.class);
        this.plugIn = ((Boolean) XMLParser.extractObjectForTags(extractForTag, "plugIn", Boolean.TYPE)).booleanValue();
        this.optimize = ((Boolean) XMLParser.extractObjectForTags(extractForTag, "optimize", Boolean.TYPE)).booleanValue();
        // setFreeParams recomputes norm from the loaded lambdas (no stale accumulation)
        setFreeParams(((Boolean) XMLParser.extractObjectForTags(extractForTag, "freeParams", Boolean.TYPE)).booleanValue());
        computeConstantsOfLogPrior();
    }

    /**
     * Sets the free-parameter mode.
     *
     * <p>Reallocates the scratch counter and determines the number of exposed parameters. If
     * {@code freeParams} is true, it shifts all lambdas so that the last one becomes 0. It then
     * recomputes {@link #norm} so that it matches the (possibly shifted) lambdas.
     */
    private void setFreeParams(boolean freeParams) {
        this.freeParams = freeParams;
        int size = this.params.getNumberOfSpecificConstraints();
        this.counter = new int[size];
        if (this.optimize) {
            this.anz = size - (freeParams ? 1 : 0);
        } else {
            this.anz = 0;
        }
        if (freeParams) {
            double last = this.params.getLambda(size - 1);
            for (int sym = 0; sym < size; sym++) {
                this.params.setLambda(sym, this.params.getLambda(sym) - last);
            }
        }
        recomputeNorm();
    }

    /** Sets {@link #norm} to the sum of {@code exp(lambda)} over all symbols. */
    private void recomputeNorm() {
        double sum = 0.0d;
        int size = this.params.getNumberOfSpecificConstraints();
        for (int sym = 0; sym < size; sym++) {
            sum += this.params.getExpLambda(sym);
        }
        this.norm = sum;
    }

    /** Sets every lambda to {@code value}. */
    private void setAllLambdas(double value) {
        int size = this.params.getNumberOfSpecificConstraints();
        for (int sym = 0; sym < size; sym++) {
            this.params.setLambda(sym, value);
        }
    }

    /**
     * Returns the number of symbols for every exposed parameter.
     *
     * @throws IndexOutOfBoundsException if {@code index} is not an exposed parameter index
     */
    @Override
    public int getSizeOfEventSpaceForRandomVariablesOfParameter(int index) {
        if (index >= 0 && index < this.anz) {
            return this.params.getNumberOfSpecificConstraints();
        }
        throw new IndexOutOfBoundsException();
    }

    /**
     * Returns {@code log(norm^L) = L * log(norm)} for sequences of length {@code L}.
     *
     * @throws RuntimeException if {@code sequenceLength} is 0
     */
    @Override
    public double getLogNormalizationConstant(int sequenceLength) {
        if (sequenceLength == 0) {
            throw new RuntimeException("The normalization constant can not be computed for length 0.");
        }
        return Math.log(this.norm) * sequenceLength;
    }

    /**
     * Returns the log of the partial derivative of {@code norm^L} with respect to
     * {@code lambda_k}.
     *
     * <p>Since {@code norm = sum_j exp(lambda_j)}, the derivative is
     * {@code L * norm^(L-1) * exp(lambda_k)}, whose log is
     * {@code log(L) + (L-1)*log(norm) + lambda_k}.
     *
     * @throws IndexOutOfBoundsException if {@code parameterIndex} is not an exposed parameter index
     */
    @Override
    public double getLogPartialNormalizationConstant(int parameterIndex, int sequenceLength) throws Exception {
        if (parameterIndex >= 0 && parameterIndex < this.anz) {
            return Math.log(sequenceLength)
                    + Math.log(this.norm) * (sequenceLength - 1)
                    + this.params.getLambda(parameterIndex);
        }
        throw new IndexOutOfBoundsException();
    }

    @Override
    public double getESS() {
        return this.ess;
    }

    /** Returns tab-separated {@code symbol: probability} pairs, formatted with {@code numberFormat}. */
    @Override
    public String toString(NumberFormat numberFormat) {
        StringBuilder sb = new StringBuilder(100);
        int size = this.params.getNumberOfSpecificConstraints();
        for (int sym = 0; sym < size; sym++) {
            if (sym > 0) {
                sb.append('\t');
            }
            sb.append(this.alphabets.getSymbol(0, sym))
                    .append(": ")
                    .append(numberFormat.format(this.params.getExpLambda(sym) / this.norm));
        }
        return sb.toString();
    }

    /**
     * Returns the log of the symmetric Dirichlet prior on the lambdas.
     *
     * <p>This is {@code (sumOfHyperParams / |alphabet|) * sum_j lambda_j} plus a constant, or 0 if
     * the parameters are not optimized.
     */
    @Override
    public double getLogPriorTerm() {
        if (!this.optimize) {
            return 0.0d;
        }
        int size = this.params.getNumberOfSpecificConstraints();
        double lambdaSum = 0.0d;
        for (int sym = 0; sym < size; sym++) {
            lambdaSum += this.params.getLambda(sym);
        }
        return ((lambdaSum * this.sumOfHyperParams) / size) + this.logGammaSum;
    }

    /**
     * Precomputes {@code logGamma(A) - n * logGamma(A / n)} for hyper-parameter sum {@code A} and
     * {@code n} symbols.
     *
     * <p>For {@code A == 0} (no prior, allowed since the ess may be 0), log-Gamma is not finite, so
     * the constant is defined as 0.
     */
    private void computeConstantsOfLogPrior() {
        if (this.sumOfHyperParams == 0.0d) {
            this.logGammaSum = 0.0d;
            return;
        }
        int size = this.params.getNumberOfSpecificConstraints();
        this.logGammaSum = Gamma.logOfGamma(this.sumOfHyperParams) - (size * Gamma.logOfGamma(this.sumOfHyperParams / size));
    }

    /** Adds {@code sumOfHyperParams / |alphabet|} to the gradient entry of every exposed parameter. */
    @Override
    public void addGradientOfLogPriorTerm(double[] grad, int start) {
        double hyper = this.sumOfHyperParams / this.params.getNumberOfSpecificConstraints();
        for (int k = 0; k < this.anz; k++) {
            grad[start + k] += hyper;
        }
    }

    @Override
    public boolean isInitialized() {
        return true;
    }

    @Override
    public byte getMaximalMarkovOrder() {
        return (byte) 0;
    }

    /**
     * Sets the hyper-parameter sum to {@code sum_i lengths[i] * weights[i]} and updates the prior
     * constants.
     *
     * @throws IllegalArgumentException if the arrays differ in length or contain negative entries
     */
    @Override
    public void setStatisticForHyperparameters(int[] lengths, double[] weights) throws Exception {
        if (weights.length != lengths.length) {
            throw new IllegalArgumentException("The length of both arrays (length, weight) have to be identical.");
        }
        this.sumOfHyperParams = 0.0d;
        for (int i = 0; i < lengths.length; i++) {
            if (weights[i] < 0.0d || lengths[i] < 0) {
                throw new IllegalArgumentException("check length and weight for entry " + i);
            }
            this.sumOfHyperParams += lengths[i] * weights[i];
        }
        computeConstantsOfLogPrior();
    }

    /** Sets all symbol probabilities to {@code 1 / |alphabet|}. */
    @Override
    public void initializeUniformly(boolean freeParams) {
        int size = this.params.getNumberOfSpecificConstraints();
        double prob = 1.0d / size;
        for (int sym = 0; sym < size; sym++) {
            this.params.setExpLambda(sym, prob);
        }
        setFreeParams(freeParams);
    }

    /** Returns a single sampling group containing one consecutive index per symbol, starting at {@code offset}. */
    @Override
    public int[][] getSamplingGroups(int offset) {
        int[][] groups = new int[1][this.params.getNumberOfSpecificConstraints()];
        for (int k = 0; k < groups[0].length; k++) {
            groups[0][k] = offset + k;
        }
        return groups;
    }
}
