/*
 * Decompiled with CFR 0.152.
 */
package dr.evomodel.treedatalikelihood.preorder;

import dr.evolution.tree.Tree;
import dr.evolution.tree.TreeTrait;
import dr.evomodel.treedatalikelihood.continuous.ContinuousDataLikelihoodDelegate;
import dr.evomodel.treedatalikelihood.preorder.ConditionalVarianceAndTransform2;
import dr.evomodel.treedatalikelihood.preorder.ModelExtensionProvider;
import dr.inference.model.CompoundParameter;
import dr.math.MathUtils;
import dr.math.distributions.MultivariateNormalDistribution;
import dr.math.matrixAlgebra.CholeskyDecomposition;
import dr.math.matrixAlgebra.WrappedVector;
import org.ejml.data.DenseMatrix64F;

/**
 * Supplies "extended" values derived from a tree trait exposed by a continuous-data likelihood.
 *
 * <p>The base implementation returns the tree trait values unchanged; subclasses such as
 * {@link MultivariateNormalExtensionDelegate} override {@link #getExtendedValues(double[])} to
 * post-process them.
 */
public class ContinuousExtensionDelegate {
    protected final TreeTrait treeTrait;
    protected final Tree tree;
    private final ContinuousDataLikelihoodDelegate likelihoodDelegate;

    /**
     * @param likelihoodDelegate likelihood delegate that is notified before the trait is read
     * @param treeTrait          trait whose whole-tree value is read (expected to be a {@code double[]})
     * @param tree               tree the trait is evaluated on
     */
    public ContinuousExtensionDelegate(ContinuousDataLikelihoodDelegate likelihoodDelegate, TreeTrait treeTrait, Tree tree) {
        this.treeTrait = treeTrait;
        this.tree = tree;
        this.likelihoodDelegate = likelihoodDelegate;
    }

    /**
     * Notifies the likelihood delegate that its model changed, then reads the trait for the
     * whole tree (node argument {@code null}).
     *
     * <p>NOTE(review): the {@code fireModelChanged()} call presumably forces the trait to be
     * recomputed from the current parameter values before it is read — confirm.
     *
     * @return the trait value, cast to {@code double[]}
     */
    public double[] getExtendedValues() {
        likelihoodDelegate.fireModelChanged();
        return (double[]) treeTrait.getTrait(tree, null);
    }

    /**
     * Extends already-computed tree values. The base implementation is the identity.
     *
     * @param treeValues values obtained from the tree trait
     * @return {@code treeValues} itself
     */
    public double[] getExtendedValues(double[] treeValues) {
        return treeValues;
    }

    /** @return the trait this delegate reads */
    public TreeTrait getTreeTrait() {
        return treeTrait;
    }

    /** @return the tree the trait is evaluated on */
    public Tree getTree() {
        return tree;
    }

    /**
     * Extension that fills in tip values under a normal extension model.
     *
     * <p>Values are laid out taxon-major: entry {@code taxon * dimTrait + k} is trait {@code k}
     * of taxon {@code taxon}. For every taxon, observed entries are copied from the data
     * parameter. Missing entries are drawn from a normal distribution centred on the
     * corresponding tree values, with the provider's extension variance.
     */
    public static class MultivariateNormalExtensionDelegate
    extends ContinuousExtensionDelegate {
        // Output buffer (nTaxa * dimTrait), reused and returned by every call to getExtendedValues.
        private final double[] sample;
        private final ModelExtensionProvider.NormalExtensionProvider dataModel;
        private final int dimTrait;
        private final int nTaxa;

        /**
         * @param likelihoodDelegate likelihood delegate passed to the base class
         * @param treeTrait          trait supplying the tree values
         * @param dataModel          provider of the data parameter, missing indicators and
         *                           extension variance
         * @param tree               tree whose external node count fixes the number of taxa
         */
        public MultivariateNormalExtensionDelegate(ContinuousDataLikelihoodDelegate likelihoodDelegate, TreeTrait treeTrait, ModelExtensionProvider.NormalExtensionProvider dataModel, Tree tree) {
            super(likelihoodDelegate, treeTrait, tree);
            this.dataModel = dataModel;
            this.dimTrait = dataModel.getDataDimension();
            this.nTaxa = tree.getExternalNodeCount();
            this.sample = new double[nTaxa * dimTrait];
        }

        /**
         * Reads the tree trait, transforms it with the provider's
         * {@code transformTreeTraits}, and extends the result.
         */
        @Override
        public double[] getExtendedValues() {
            double[] treeValues = super.getExtendedValues();
            double[] transformed = dataModel.transformTreeTraits(treeValues);
            return getExtendedValues(transformed);
        }

        /**
         * Builds one extended sample for all taxa.
         *
         * <p>Per taxon:
         * <ul>
         *   <li>observed entries are copied from the data parameter;</li>
         *   <li>with diagonal variance, each missing entry is drawn independently;</li>
         *   <li>otherwise, if every entry is missing, the whole row is drawn jointly using the
         *       Cholesky factor of the extension variance;</li>
         *   <li>otherwise the missing entries are drawn from their conditional distribution given
         *       the observed ones;</li>
         *   <li>fully observed taxa need no sampling and are skipped.</li>
         * </ul>
         *
         * @param treeValues tree means; must hold at least {@code nTaxa * dimTrait} entries
         * @return the internal sample buffer. It is overwritten on the next call, so copy it if
         *         it must be retained.
         */
        @Override
        public double[] getExtendedValues(double[] treeValues) {
            CompoundParameter dataParameter = dataModel.getParameter();
            DenseMatrix64F extensionVariance = dataModel.getExtensionVariance();
            boolean[] missing = dataModel.getDataMissingIndicators();
            boolean diagonal = dataModel.diagonalVariance();

            // Cholesky factor of the full extension variance; computed lazily, at most once per call.
            double[][] fullCholesky = null;

            for (int taxon = 0; taxon < nTaxa; ++taxon) {
                int offset = taxon * dimTrait;
                IndexPartition partition = new IndexPartition(missing, offset, dimTrait);

                copyObservedValues(dataParameter, partition, offset);

                if (partition.nMissing == 0) {
                    // Nothing to impute: avoid building a zero-dimensional conditional distribution.
                    continue;
                }

                if (diagonal) {
                    sampleIndependently(treeValues, extensionVariance, partition, offset);
                } else if (partition.nMissing == dimTrait) {
                    if (fullCholesky == null) {
                        fullCholesky = CholeskyDecomposition.execute(extensionVariance.getData(), 0, dimTrait);
                    }
                    sampleJointly(treeValues, fullCholesky, offset);
                } else {
                    sampleConditionally(treeValues, dataParameter, extensionVariance, partition, taxon, offset);
                }
            }
            return sample;
        }

        /** Copies the observed entries of one taxon from the data parameter into {@code sample}. */
        private void copyObservedValues(CompoundParameter dataParameter, IndexPartition partition, int offset) {
            for (int k : partition.obsInds) {
                int index = offset + k;
                sample[index] = dataParameter.getParameterValue(index);
            }
        }

        /** Draws each missing entry as tree value + sqrt(variance diagonal) * N(0, 1). */
        private void sampleIndependently(double[] treeValues, DenseMatrix64F extensionVariance, IndexPartition partition, int offset) {
            for (int k : partition.misInds) {
                int index = offset + k;
                double sd = Math.sqrt(extensionVariance.get(k, k));
                sample[index] = MathUtils.nextGaussian() * sd + treeValues[index];
            }
        }

        /** Draws a whole (fully missing) row jointly around its tree values using the full Cholesky factor. */
        private void sampleJointly(double[] treeValues, double[][] fullCholesky, int offset) {
            double[] mean = new double[dimTrait];
            System.arraycopy(treeValues, offset, mean, 0, dimTrait);
            double[] draw = MultivariateNormalDistribution.nextMultivariateNormalCholesky(mean, fullCholesky);
            System.arraycopy(draw, 0, sample, offset, dimTrait);
        }

        /** Draws the missing entries of one taxon conditionally on its observed entries. */
        private void sampleConditionally(double[] treeValues, CompoundParameter dataParameter, DenseMatrix64F extensionVariance, IndexPartition partition, int taxon, int offset) {
            ConditionalVarianceAndTransform2 transform =
                    new ConditionalVarianceAndTransform2(extensionVariance, partition.misInds, partition.obsInds);
            WrappedVector conditionalMean = transform.getConditionalMean(
                    dataParameter.getParameter(taxon).getParameterValues(), 0, treeValues, offset);
            double[][] conditionalCholesky = transform.getConditionalCholesky();
            double[] draw = MultivariateNormalDistribution.nextMultivariateNormalCholesky(
                    conditionalMean.getBuffer(), conditionalCholesky);
            // Observed entries were already written by copyObservedValues; only fill the missing ones.
            for (int m = 0; m < partition.nMissing; ++m) {
                sample[offset + partition.misInds[m]] = draw[m];
            }
        }

        /**
         * Splits one taxon's trait indices {@code 0..dim-1} into missing and observed sets, each
         * in ascending order, based on {@code missing[offset + k]}.
         */
        private static final class IndexPartition {
            private final int[] obsInds;
            private final int[] misInds;
            private final int nMissing;

            private IndexPartition(boolean[] missing, int offset, int dim) {
                int count = 0;
                for (int k = 0; k < dim; ++k) {
                    if (missing[offset + k]) {
                        ++count;
                    }
                }
                this.nMissing = count;
                this.misInds = new int[count];
                this.obsInds = new int[dim - count];

                int m = 0;
                int o = 0;
                for (int k = 0; k < dim; ++k) {
                    if (missing[offset + k]) {
                        misInds[m++] = k;
                    } else {
                        obsInds[o++] = k;
                    }
                }
            }
        }
    }
}

