package de.jstacs.sequenceScores.statisticalModels.trainable.hmm;

import de.jstacs.Storable;
import de.jstacs.data.AlphabetContainer;
import de.jstacs.data.DataSet;
import de.jstacs.data.WrongAlphabetException;
import de.jstacs.data.WrongLengthException;
import de.jstacs.data.sequences.Sequence;
import de.jstacs.io.ArrayHandler;
import de.jstacs.io.NonParsableException;
import de.jstacs.io.XMLParser;
import de.jstacs.sequenceScores.statisticalModels.trainable.AbstractTrainableStatisticalModel;
import de.jstacs.sequenceScores.statisticalModels.trainable.hmm.states.State;
import de.jstacs.sequenceScores.statisticalModels.trainable.hmm.states.emissions.Emission;
import de.jstacs.sequenceScores.statisticalModels.trainable.hmm.training.HMMTrainingParameterSet;
import de.jstacs.sequenceScores.statisticalModels.trainable.hmm.training.MultiThreadedTrainingParameterSet;
import de.jstacs.sequenceScores.statisticalModels.trainable.hmm.transitions.BasicHigherOrderTransition;
import de.jstacs.sequenceScores.statisticalModels.trainable.hmm.transitions.HigherOrderTransition;
import de.jstacs.sequenceScores.statisticalModels.trainable.hmm.transitions.Transition;
import de.jstacs.sequenceScores.statisticalModels.trainable.hmm.transitions.elements.TransitionElement;
import de.jstacs.utils.IntList;
import de.jstacs.utils.Normalisation;
import de.jstacs.utils.Pair;
import de.jstacs.utils.SafeOutputStream;
import de.jstacs.utils.ToolBox;
import java.io.OutputStream;
import java.text.NumberFormat;
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.regex.Pattern;
import org.apache.batik.util.CSSConstants;
import org.apache.batik.util.SVGConstants;
import org.apache.commons.io.IOUtils;

/* loaded from: input_file:de/jstacs/sequenceScores/statisticalModels/trainable/hmm/AbstractHMM.class */
public abstract class AbstractHMM extends AbstractTrainableStatisticalModel implements Cloneable, Storable {
    protected State[] states;
    protected String[] name;
    protected int[] emissionIdx;
    protected boolean[] forward;
    protected Emission[] emission;
    protected Transition transition;
    protected double[][] fwdMatrix;
    protected double[][] bwdMatrix;
    protected HMMTrainingParameterSet trainingParameter;
    protected SafeOutputStream sostream;
    protected boolean[] finalState;
    protected int threads;
    public static final String START_NODE = "START";

    /* JADX INFO: Access modifiers changed from: protected */
    public AbstractHMM(HMMTrainingParameterSet hMMTrainingParameterSet, String[] strArr, int[] iArr, boolean[] zArr, Emission[] emissionArr) throws CloneNotSupportedException, WrongAlphabetException {
        super(getAlphabetContainer(emissionArr), 0);
        if (!hMMTrainingParameterSet.hasDefaultOrIsSet()) {
            throw new IllegalArgumentException("Please check the training parameters.");
        }
        this.trainingParameter = (HMMTrainingParameterSet) hMMTrainingParameterSet.mo7clone();
        setThreads();
        setOutputStream(SafeOutputStream.DEFAULT_STREAM);
        int length = strArr.length;
        this.name = new String[length];
        HashSet hashSet = new HashSet();
        for (int i = 0; i < length; i++) {
            if (hashSet.contains(strArr[i])) {
                throw new IllegalArgumentException("The state names should be unique. Please check: " + strArr[i]);
            }
            this.name[i] = strArr[i];
            hashSet.add(strArr[i]);
        }
        hashSet.clear();
        if (iArr == null) {
            this.emissionIdx = new int[length];
            for (int i2 = 0; i2 < length; i2++) {
                this.emissionIdx[i2] = i2;
            }
        } else {
            if (length != iArr.length) {
                throw new IllegalArgumentException();
            }
            this.emissionIdx = (int[]) iArr.clone();
        }
        if (zArr == null) {
            this.forward = new boolean[length];
            Arrays.fill(this.forward, true);
        } else {
            if (length != zArr.length) {
                throw new IllegalArgumentException();
            }
            this.forward = (boolean[]) zArr.clone();
        }
        if (emissionArr.length > length) {
            throw new IllegalArgumentException();
        }
        this.emission = (Emission[]) ArrayHandler.clone(emissionArr);
    }

    private void setThreads() {
        this.threads = this.trainingParameter instanceof MultiThreadedTrainingParameterSet ? ((MultiThreadedTrainingParameterSet) this.trainingParameter).getNumberOfThreads() : 1;
    }

    private static AlphabetContainer getAlphabetContainer(Emission... emissionArr) throws WrongAlphabetException {
        AlphabetContainer alphabetContainer = null;
        int i = 0;
        while (alphabetContainer == null) {
            int i2 = i;
            i++;
            alphabetContainer = emissionArr[i2].getAlphabetContainer();
        }
        while (i < emissionArr.length) {
            int i3 = i;
            i++;
            AlphabetContainer alphabetContainer2 = emissionArr[i3].getAlphabetContainer();
            if (alphabetContainer2 != null && !alphabetContainer2.checkConsistency(alphabetContainer)) {
                throw new WrongAlphabetException("All emission should use the same AlphabetContainer.");
            }
        }
        if (alphabetContainer.isSimple()) {
            return alphabetContainer;
        }
        throw new IllegalArgumentException("The AlphabetContainer has to be simple.");
    }

    /* JADX INFO: Access modifiers changed from: protected */
    public AbstractHMM(StringBuffer stringBuffer) throws NonParsableException {
        super(stringBuffer);
        setOutputStream(SafeOutputStream.DEFAULT_STREAM);
    }

    /* JADX INFO: Access modifiers changed from: protected */
    public void initTransition(BasicHigherOrderTransition.AbstractTransitionElement... abstractTransitionElementArr) throws Exception {
        boolean[] zArr = new boolean[this.states.length];
        for (int i = 0; i < this.states.length; i++) {
            zArr[i] = this.states[i].isSilent();
        }
        if (abstractTransitionElementArr instanceof TransitionElement[]) {
            this.transition = new HigherOrderTransition(zArr, (TransitionElement[]) abstractTransitionElementArr);
            return;
        }
        TransitionElement[] transitionElementArr = new TransitionElement[abstractTransitionElementArr.length];
        for (int i2 = 0; i2 < transitionElementArr.length && (abstractTransitionElementArr[0] instanceof TransitionElement); i2++) {
            transitionElementArr[0] = (TransitionElement) abstractTransitionElementArr[0];
        }
        if (0 == abstractTransitionElementArr.length) {
            this.transition = new HigherOrderTransition(zArr, transitionElementArr);
        } else {
            this.transition = new BasicHigherOrderTransition(zArr, abstractTransitionElementArr);
        }
    }

    protected abstract String getXMLTag();

    @Override // de.jstacs.Storable
    public StringBuffer toXML() {
        StringBuffer stringBuffer = new StringBuffer();
        XMLParser.appendObjectWithTags(stringBuffer, this.trainingParameter, "trainingParameter");
        XMLParser.appendObjectWithTags(stringBuffer, this.transition, "transition");
        XMLParser.appendObjectWithTags(stringBuffer, this.name, SVGConstants.SVG_NAME_ATTRIBUTE);
        XMLParser.appendObjectWithTags(stringBuffer, this.emissionIdx, "emissionIdx");
        XMLParser.appendObjectWithTags(stringBuffer, this.forward, "strand");
        XMLParser.appendObjectWithTags(stringBuffer, this.emission, "emission");
        appendFurtherInformation(stringBuffer);
        XMLParser.addTags(stringBuffer, getXMLTag());
        return stringBuffer;
    }

    @Override // de.jstacs.sequenceScores.statisticalModels.trainable.AbstractTrainableStatisticalModel
    protected void fromXML(StringBuffer stringBuffer) throws NonParsableException {
        this.length = 0;
        StringBuffer extractForTag = XMLParser.extractForTag(stringBuffer, getXMLTag());
        this.trainingParameter = (HMMTrainingParameterSet) XMLParser.extractObjectForTags(extractForTag, "trainingParameter");
        setThreads();
        this.transition = (Transition) XMLParser.extractObjectForTags(extractForTag, "transition");
        this.name = (String[]) XMLParser.extractObjectForTags(extractForTag, SVGConstants.SVG_NAME_ATTRIBUTE, String[].class);
        this.emissionIdx = (int[]) XMLParser.extractObjectForTags(extractForTag, "emissionIdx", int[].class);
        this.forward = (boolean[]) XMLParser.extractObjectForTags(extractForTag, "strand", boolean[].class);
        this.emission = (Emission[]) XMLParser.extractObjectForTags(extractForTag, "emission", Emission[].class);
        extractFurtherInformation(extractForTag);
        try {
            this.alphabets = getAlphabetContainer(this.emission);
            createStates();
            determineFinalStates();
        } catch (WrongAlphabetException e) {
            throw new NonParsableException(e.getMessage());
        }
    }

    protected abstract void appendFurtherInformation(StringBuffer stringBuffer);

    protected abstract void extractFurtherInformation(StringBuffer stringBuffer) throws NonParsableException;

    /* JADX WARN: Multi-variable type inference failed */
    /* JADX WARN: Type inference failed for: r1v20, types: [java.lang.Cloneable[], double[][]] */
    /* JADX WARN: Type inference failed for: r1v24, types: [java.lang.Cloneable[], double[][]] */
    @Override // de.jstacs.sequenceScores.statisticalModels.trainable.AbstractTrainableStatisticalModel
    /* renamed from: clone */
    public AbstractHMM mo93clone() throws CloneNotSupportedException {
        AbstractHMM abstractHMM = (AbstractHMM) super.mo93clone();
        abstractHMM.name = (String[]) this.name.clone();
        abstractHMM.emissionIdx = (int[]) this.emissionIdx.clone();
        abstractHMM.forward = (boolean[]) this.forward.clone();
        abstractHMM.emission = (Emission[]) ArrayHandler.clone(this.emission);
        abstractHMM.transition = this.transition.mo135clone();
        abstractHMM.fwdMatrix = (double[][]) ArrayHandler.clone(this.fwdMatrix);
        abstractHMM.bwdMatrix = (double[][]) ArrayHandler.clone(this.bwdMatrix);
        abstractHMM.trainingParameter = (HMMTrainingParameterSet) this.trainingParameter.mo7clone();
        abstractHMM.finalState = (boolean[]) this.finalState.clone();
        abstractHMM.createStates();
        abstractHMM.setOutputStream(this.sostream.doesNothing() ? null : SafeOutputStream.DEFAULT_STREAM);
        return abstractHMM;
    }

    protected abstract void createStates();

    protected abstract void fillFwdMatrix(int i, int i2, Sequence sequence) throws Exception;

    protected abstract void fillBwdMatrix(int i, int i2, Sequence sequence) throws Exception;

    public int getNumberOfThreads() {
        return this.threads;
    }

    public String getGraphvizRepresentation(NumberFormat numberFormat) {
        return getGraphvizRepresentation(numberFormat, (DataSet) null, (double[]) null, false);
    }

    public String getGraphvizRepresentation(NumberFormat numberFormat, boolean z) {
        return getGraphvizRepresentation(numberFormat, (DataSet) null, (double[]) null, z);
    }

    public String getGraphvizRepresentation(NumberFormat numberFormat, DataSet dataSet, double[] dArr, boolean z) {
        HashMap<String, String> hashMap = null;
        if (z) {
            hashMap = new HashMap<>();
            for (int i = 0; i < this.name.length; i++) {
                hashMap.put(this.name[i].charAt(0) + ".*", "same");
            }
        }
        return getGraphvizRepresentation(numberFormat, dataSet, dArr, hashMap);
    }

    public String getGraphvizRepresentation(NumberFormat numberFormat, DataSet dataSet, double[] dArr, HashMap<String, String> hashMap) {
        double[] dArr2;
        double d;
        if (dataSet != null) {
            try {
                dArr2 = getStateFreq(dataSet, dArr);
                d = ToolBox.max(dArr2);
            } catch (Exception e) {
                e.printStackTrace();
                dArr2 = new double[this.states.length];
                d = 0.0d;
            }
        } else {
            dArr2 = new double[this.states.length];
            Arrays.fill(dArr2, -1.0d);
            d = -1.0d;
        }
        StringBuffer stringBuffer = new StringBuffer();
        stringBuffer.append("digraph G {\n\trankdir=" + (hashMap != null ? "TB" : "LR") + "\n\n");
        stringBuffer.append("\tSTART[shape=point];\n\n");
        for (int i = 0; i < this.states.length; i++) {
            stringBuffer.append("\t" + i + "[" + this.states[i].getGraphvizNodeOptions(dArr2[i], d, numberFormat) + ",color=" + (this.finalState[i] ? CSSConstants.CSS_RED_VALUE : CSSConstants.CSS_BLACK_VALUE) + "];\n");
        }
        if (hashMap != null) {
            StringBuffer stringBuffer2 = new StringBuffer();
            HashMap hashMap2 = new HashMap();
            for (int i2 = 0; i2 < this.name.length; i2++) {
                Iterator<String> it = hashMap.keySet().iterator();
                while (true) {
                    if (it.hasNext()) {
                        String next = it.next();
                        if (this.name[i2].matches(next)) {
                            if (((IntList) hashMap2.get(next)) == null) {
                                hashMap2.put(next, new IntList());
                            }
                            ((IntList) hashMap2.get(next)).add(i2);
                        }
                    }
                }
            }
            boolean z = false;
            for (String str : hashMap2.keySet()) {
                stringBuffer2.append("{rank=" + hashMap.get(str) + "; ");
                if (!z && START_NODE.matches(str)) {
                    stringBuffer2.append("START ");
                    z = true;
                }
                IntList intList = (IntList) hashMap2.get(str);
                for (int i3 = 0; i3 < intList.length(); i3++) {
                    stringBuffer2.append(intList.get(i3) + " ");
                }
                stringBuffer2.append(";}\n");
            }
            stringBuffer.append(stringBuffer2);
        }
        stringBuffer.append(IOUtils.LINE_SEPARATOR_UNIX);
        stringBuffer.append(this.transition.getGraphizNetworkRepresentation(numberFormat, null, dataSet != null));
        stringBuffer.append("}");
        return stringBuffer.toString();
    }

    private double[] getStateFreq(DataSet dataSet, double[] dArr) throws Exception {
        double[] dArr2 = new double[this.states.length];
        if (dataSet != null) {
            double d = 1.0d;
            double d2 = 0.0d;
            double[][] createMatrixForStatePosterior = createMatrixForStatePosterior(0, dataSet.getMaximalElementLength() - 1);
            for (int i = 0; i < dataSet.getNumberOfElements(); i++) {
                fillLogStatePosteriorMatrix(createMatrixForStatePosterior, 0, dataSet.getElementAt(i).getLength() - 1, dataSet.getElementAt(i), false);
                if (dArr != null) {
                    d = dArr[i];
                }
                d2 += d;
                for (int i2 = 0; i2 < this.states.length; i2++) {
                    int i3 = i2;
                    dArr2[i3] = dArr2[i3] + (d * Math.exp(Normalisation.getLogSum(createMatrixForStatePosterior[i2])));
                }
            }
            for (int i4 = 0; i4 < this.states.length; i4++) {
                int i5 = i4;
                dArr2[i5] = dArr2[i5] / d2;
            }
        }
        return dArr2;
    }

    /* JADX INFO: Access modifiers changed from: protected */
    public double[][] createMatrixForStatePosterior(int i, int i2) {
        return new double[this.states.length][(i2 - i) + 1 + 1];
    }

    protected abstract void fillLogStatePosteriorMatrix(double[][] dArr, int i, int i2, Sequence sequence, boolean z) throws Exception;

    public double[][] getLogStatePosteriorMatrixFor(int i, int i2, Sequence sequence) throws Exception {
        double[][] createMatrixForStatePosterior = createMatrixForStatePosterior(i, i2);
        fillLogStatePosteriorMatrix(createMatrixForStatePosterior, i, i2, sequence, true);
        return getFinalStatePosterioriMatrix(createMatrixForStatePosterior);
    }

    /* JADX INFO: Access modifiers changed from: protected */
    /* JADX WARN: Multi-variable type inference failed */
    /* JADX WARN: Type inference failed for: r0v2, types: [double[], double[][]] */
    public double[][] getFinalStatePosterioriMatrix(double[][] dArr) {
        ?? r0 = new double[dArr.length];
        for (int i = 0; i < r0.length; i++) {
            r0[i] = new double[dArr[i].length - 1];
            System.arraycopy(dArr[i], 1, r0[i], 0, r0[i].length);
        }
        return r0;
    }

    public double[][] getStatePosteriorMatrixFor(Sequence sequence) throws Exception {
        double[][] logStatePosteriorMatrixFor = getLogStatePosteriorMatrixFor(0, sequence.getLength() - 1, sequence);
        for (int i = 0; i < logStatePosteriorMatrixFor.length; i++) {
            for (int i2 = 0; i2 < logStatePosteriorMatrixFor[i].length; i2++) {
                logStatePosteriorMatrixFor[i][i2] = Math.exp(logStatePosteriorMatrixFor[i][i2]);
            }
        }
        return logStatePosteriorMatrixFor;
    }

    /* JADX WARN: Type inference failed for: r0v2, types: [double[][], double[][][]] */
    public double[][][] getLogStatePosteriorMatrixFor(DataSet dataSet) throws Exception {
        ?? r0 = new double[dataSet.getNumberOfElements()];
        for (int i = 0; i < r0.length; i++) {
            Sequence elementAt = dataSet.getElementAt(i);
            r0[i] = getLogStatePosteriorMatrixFor(0, elementAt.getLength() - 1, elementAt);
        }
        return r0;
    }

    /* JADX WARN: Type inference failed for: r0v2, types: [double[][], double[][][]] */
    public double[][][] getStatePosteriorMatrixFor(DataSet dataSet) throws Exception {
        ?? r0 = new double[dataSet.getNumberOfElements()];
        for (int i = 0; i < r0.length; i++) {
            r0[i] = getStatePosteriorMatrixFor(dataSet.getElementAt(i));
        }
        return r0;
    }

    public abstract Pair<IntList, Double> getViterbiPathFor(int i, int i2, Sequence sequence) throws Exception;

    public Pair<IntList, Double> getViterbiPathFor(Sequence sequence) throws Exception {
        return getViterbiPathFor(0, sequence.getLength() - 1, sequence);
    }

    public Pair<IntList, Double>[] getViterbiPathsFor(DataSet dataSet) throws Exception {
        Pair<IntList, Double>[] pairArr = new Pair[dataSet.getNumberOfElements()];
        for (int i = 0; i < pairArr.length; i++) {
            pairArr[i] = getViterbiPathFor(dataSet.getElementAt(i));
        }
        return pairArr;
    }

    public final String[] decodePath(IntList intList) {
        String[] strArr = new String[intList.length()];
        for (int i = 0; i < strArr.length; i++) {
            strArr[i] = this.name[intList.get(i)];
        }
        return strArr;
    }

    public abstract double getLogProbForPath(IntList intList, int i, Sequence sequence) throws Exception;

    protected abstract void createHelperVariables();

    /* JADX INFO: Access modifiers changed from: protected */
    /* JADX WARN: Multi-variable type inference failed */
    /* JADX WARN: Type inference failed for: r0v7, types: [double[]] */
    public void provideMatrix(int i, int i2) {
        double[][] dArr;
        createHelperVariables();
        int i3 = i2 + 1;
        switch (i) {
            case 0:
                dArr = this.fwdMatrix;
                break;
            case 1:
                dArr = this.bwdMatrix;
                break;
            default:
                throw new IllegalArgumentException("unknown matrix type");
        }
        if (dArr == null || dArr.length < i3) {
            dArr = new double[i3];
            int maximalMarkovOrder = this.transition.getMaximalMarkovOrder();
            int i4 = -1;
            int i5 = 0;
            while (i5 <= maximalMarkovOrder && i5 < i3) {
                i4 = this.transition.getNumberOfIndexes(i5);
                dArr[i5] = new double[i4];
                i5++;
            }
            while (i5 < i3) {
                int i6 = i5;
                i5++;
                dArr[i6] = new double[i4];
            }
        }
        for (int i7 = 0; i7 < i3; i7++) {
            Arrays.fill(dArr[i7], Double.NEGATIVE_INFINITY);
        }
        switch (i) {
            case 0:
                this.fwdMatrix = dArr;
                return;
            case 1:
                this.bwdMatrix = dArr;
                return;
            default:
                return;
        }
    }

    public int getNumberOfStates() {
        return this.states.length;
    }

    @Override // de.jstacs.sequenceScores.statisticalModels.StatisticalModel
    public double getLogProbFor(Sequence sequence, int i, int i2) throws Exception {
        int i3 = (i2 - i) + 1;
        int length = getLength();
        if (!sequence.getAlphabetContainer().checkConsistency(getAlphabetContainer())) {
            throw new WrongAlphabetException("The AlphabetContainer of the sequence and the model do not match.");
        }
        if (length == 0 || i3 == length) {
            return logProb(i, i2, sequence);
        }
        throw new WrongLengthException("The given start position (" + i + ") and end position (" + i2 + ") yield an length of " + i3 + " which is not possible for the current model that models sequences of length " + length + ".");
    }

    /* JADX INFO: Access modifiers changed from: protected */
    public static RuntimeException getRunTimeException(Exception exc) {
        RuntimeException runtimeException;
        if (exc instanceof RuntimeException) {
            runtimeException = (RuntimeException) exc;
        } else {
            runtimeException = new RuntimeException(exc.getMessage());
            runtimeException.setStackTrace(exc.getStackTrace());
        }
        return runtimeException;
    }

    /* JADX INFO: Access modifiers changed from: protected */
    public double logProb(int i, int i2, Sequence sequence) throws Exception {
        try {
            fillBwdMatrix(i, i2, sequence);
            return this.bwdMatrix[0][0];
        } catch (Exception e) {
            throw getRunTimeException(e);
        }
    }

    @Override // de.jstacs.sequenceScores.statisticalModels.trainable.AbstractTrainableStatisticalModel, de.jstacs.sequenceScores.statisticalModels.trainable.TrainableStatisticalModel
    public void train(DataSet dataSet) throws Exception {
        train(dataSet, null);
    }

    public final void setOutputStream(OutputStream outputStream) {
        this.sostream = SafeOutputStream.getSafeOutputStream(outputStream);
    }

    /* JADX INFO: Access modifiers changed from: protected */
    public void finalize() throws Throwable {
        this.transition = null;
        this.states = null;
        this.trainingParameter = null;
        this.bwdMatrix = (double[][]) null;
        this.fwdMatrix = (double[][]) null;
        super.finalize();
    }

    /* JADX INFO: Access modifiers changed from: protected */
    public void determineFinalStates() {
        this.finalState = this.transition.isAbsoring();
        int i = 0;
        while (i < this.finalState.length && !this.finalState[i]) {
            i++;
        }
        if (i == this.finalState.length) {
            for (int i2 = 0; i2 < this.finalState.length; i2++) {
                this.finalState[i2] = !this.states[i2].isSilent();
            }
        }
    }

    /* JADX WARN: Multi-variable type inference failed */
    /* JADX WARN: Type inference failed for: r0v2, types: [int[], int[][]] */
    public static int[][] decodeStatePosterior(double[][]... dArr) {
        ?? r0 = new int[dArr.length];
        for (int i = 0; i < r0.length; i++) {
            r0[i] = new int[dArr[i][0].length];
            for (int i2 = 0; i2 < r0[i].length; i2++) {
                r0[i][i2] = 0;
                for (int i3 = 1; i3 < dArr[i].length; i3++) {
                    if (dArr[i][r0[i][i2]][i2] < dArr[i][i3][i2]) {
                        r0[i][i2] = i3;
                    }
                }
            }
        }
        return r0;
    }

    @Override // de.jstacs.sequenceScores.SequenceScore
    public String toString(NumberFormat numberFormat) {
        String str = ("Transition:\n-----------\n" + this.transition.toString(this.name, numberFormat)) + "\nStates:\n-------\n";
        for (int i = 0; i < this.states.length; i++) {
            str = str + this.states[i].toString(numberFormat) + IOUtils.LINE_SEPARATOR_UNIX;
        }
        return str;
    }
}
