/*
 * Decompiled with CFR 0.152.
 */
package dr.evomodel.epidemiology.casetocase;

import dr.app.tools.NexusExporter;
import dr.evolution.coalescent.Coalescent;
import dr.evolution.coalescent.DemographicFunction;
import dr.evolution.coalescent.IntervalList;
import dr.evolution.coalescent.IntervalType;
import dr.evolution.coalescent.LinearGrowth;
import dr.evolution.tree.FlexibleNode;
import dr.evolution.tree.FlexibleTree;
import dr.evolution.tree.NodeRef;
import dr.evolution.tree.Tree;
import dr.evolution.util.TaxonList;
import dr.evomodel.coalescent.demographicmodel.DemographicModel;
import dr.evomodel.epidemiology.casetocase.AbstractCase;
import dr.evomodel.epidemiology.casetocase.AbstractOutbreak;
import dr.evomodel.epidemiology.casetocase.BranchMapModel;
import dr.evomodel.epidemiology.casetocase.CaseToCaseTreeLikelihood;
import dr.evomodel.epidemiology.casetocase.CategoryOutbreak;
import dr.evomodel.epidemiology.casetocase.PartitionedTreeModel;
import dr.evomodel.tree.TreeModel;
import dr.inference.loggers.LogColumn;
import dr.inference.model.Model;
import dr.inference.model.Parameter;
import dr.math.BigDecimalUtils;
import dr.math.Binomial;
import dr.xml.AbstractXMLObjectParser;
import dr.xml.AttributeRule;
import dr.xml.ElementRule;
import dr.xml.XMLObject;
import dr.xml.XMLObjectParser;
import dr.xml.XMLParseException;
import dr.xml.XMLSyntaxRule;
import java.io.IOException;
import java.io.PrintStream;
import java.math.BigDecimal;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;

/**
 * Tree prior for a {@link PartitionedTreeModel}: the part of the tree assigned to each infected
 * case (a {@link CaseToCaseTreeLikelihood.Treelet}) is scored as a coalescent process under the
 * shared {@link DemographicModel}, and the total log likelihood is the sum of those per-case terms.
 *
 * <p>Per-case log likelihoods are cached in {@code partitionTreeLogLikelihoods}; model-change
 * events flag the affected cases so only they are recomputed on the next evaluation.
 */
public class WithinCaseCoalescent
extends CaseToCaseTreeLikelihood {
    public static final String WITHIN_CASE_COALESCENT = "withinCaseCoalescent";
    // Cached coalescent log likelihood per case, indexed by outbreak case index.
    private double[] partitionTreeLogLikelihoods;
    private double[] storedPartitionTreeLogLikelihoods;
    // true at index n => the coalescent term of case n must be recomputed.
    private final boolean[] recalculateCoalescentFlags;
    private final DemographicModel demoModel;
    private final Mode mode;
    // Sum of the per-case coalescent log likelihoods from the last evaluation.
    private double coalescencesLogLikelihood;
    private double storedCoalescencesLogLikelihood;
    // Set whenever a cached treelet is invalidated; triggers explodeTree() in calculateLogLikelihood.
    // NOTE(review): nothing in this class ever resets it to false, so explodeTree() runs on every
    // evaluation — confirm that explodeTree() (superclass) is cheap when no treelet entry is null.
    private boolean pleaseReExplode = true;
    public static XMLObjectParser PARSER = new AbstractXMLObjectParser(){
        public static final String STARTING_NETWORK = "startingNetwork";
        public static final String MAX_FIRST_INF_TO_ROOT = "maxFirstInfToRoot";
        public static final String DEMOGRAPHIC_MODEL = "demographicModel";
        public static final String TRUNCATE = "truncate";
        // TRUNCATE is optional: parseXMLObject falls back to Mode.NORMAL when it is absent.
        private final XMLSyntaxRule[] rules = new XMLSyntaxRule[]{
            new ElementRule(PartitionedTreeModel.class, "The tree"),
            new ElementRule(CategoryOutbreak.class, "The set of cases", 0, 1),
            new ElementRule(STARTING_NETWORK, String.class, "A CSV file containing a specified starting network", true),
            new ElementRule(MAX_FIRST_INF_TO_ROOT, Parameter.class, "The maximum time from the first infection to the root node"),
            new ElementRule(DEMOGRAPHIC_MODEL, DemographicModel.class, "The demographic model for within-case evolution"),
            AttributeRule.newBooleanRule(TRUNCATE, true)
        };

        @Override
        public String getParserName() {
            return WithinCaseCoalescent.WITHIN_CASE_COALESCENT;
        }

        /**
         * Builds a {@link WithinCaseCoalescent} from its XML element.
         *
         * @throws XMLParseException if the constructor reports a missing taxon
         */
        @Override
        public Object parseXMLObject(XMLObject xMLObject) throws XMLParseException {
            PartitionedTreeModel partitionedTreeModel = (PartitionedTreeModel)xMLObject.getChild(TreeModel.class);
            AbstractOutbreak abstractOutbreak = (AbstractOutbreak)xMLObject.getChild(AbstractOutbreak.class);
            Parameter parameter = (Parameter)xMLObject.getElementFirstChild(MAX_FIRST_INF_TO_ROOT);
            DemographicModel demographicModel = (DemographicModel)xMLObject.getElementFirstChild(DEMOGRAPHIC_MODEL);
            // Short-circuit: only read the boolean when the attribute is actually present.
            Mode mode = xMLObject.hasAttribute(TRUNCATE) && xMLObject.getBooleanAttribute(TRUNCATE) ? Mode.TRUNCATE : Mode.NORMAL;
            try {
                return new WithinCaseCoalescent(partitionedTreeModel, abstractOutbreak, parameter, demographicModel, mode);
            }
            catch (TaxonList.MissingTaxonException missingTaxonException) {
                throw new XMLParseException(missingTaxonException.toString());
            }
        }

        @Override
        public String getParserDescription() {
            return "This element provides a tree prior for a partitioned tree, with each partitioned tree generated by a coalescent process";
        }

        @Override
        public Class getReturnType() {
            return WithinCaseCoalescent.class;
        }

        @Override
        public XMLSyntaxRule[] getSyntaxRules() {
            return this.rules;
        }
    };

    /**
     * Creates the likelihood and registers the demographic model and outbreak as sub-models.
     *
     * @param partitionedTreeModel the partitioned phylogeny
     * @param abstractOutbreak the set of cases
     * @param parameter passed through to the superclass (parsed from {@code maxFirstInfToRoot})
     * @param demographicModel demographic model shared by every within-case coalescent
     * @param mode whether the coalescent is truncated at each treelet's zero height
     * @throws TaxonList.MissingTaxonException propagated from the superclass constructor
     */
    public WithinCaseCoalescent(PartitionedTreeModel partitionedTreeModel, AbstractOutbreak abstractOutbreak, Parameter parameter, DemographicModel demographicModel, Mode mode) throws TaxonList.MissingTaxonException {
        super(WITHIN_CASE_COALESCENT, partitionedTreeModel, abstractOutbreak, parameter);
        this.mode = mode;
        this.demoModel = demographicModel;
        this.addModel(demographicModel);
        this.addModel(this.outbreak);
        int caseCount = this.outbreak.getCases().size();
        this.partitionTreeLogLikelihoods = new double[caseCount];
        this.storedPartitionTreeLogLikelihoods = new double[caseCount];
        this.recalculateCoalescentFlags = new boolean[caseCount];
        Arrays.fill(this.recalculateCoalescentFlags, true);
        // Only cases that were ever infected own a treelet; null marks "not yet built".
        this.elementsAsTrees = new HashMap();
        for (AbstractCase abstractCase : this.outbreak.getCases()) {
            if (abstractCase.wasEverInfected()) {
                this.elementsAsTrees.put(abstractCase, null);
            }
        }
        this.storedElementsAsTrees = new HashMap();
    }

    /**
     * Sums the coalescent log likelihoods of all ever-infected cases, recomputing only the
     * flagged ones. Treelets with at most one tip contribute 0.
     *
     * @return the total within-case coalescent log likelihood
     */
    @Override
    protected double calculateLogLikelihood() {
        if (this.pleaseReExplode) {
            this.explodeTree();
        }
        this.coalescencesLogLikelihood = 0.0;
        for (AbstractCase abstractCase : this.outbreak.getCases()) {
            int n = this.outbreak.getCaseIndex(abstractCase);
            if (!abstractCase.wasEverInfected()) {
                this.recalculateCoalescentFlags[n] = false;
                continue;
            }
            if (this.recalculateCoalescentFlags[n]) {
                CaseToCaseTreeLikelihood.Treelet treelet = (CaseToCaseTreeLikelihood.Treelet)this.elementsAsTrees.get(abstractCase);
                if (treelet.getExternalNodeCount() > 1) {
                    SpecifiedZeroCoalescent coalescent = new SpecifiedZeroCoalescent(treelet, this.demoModel, treelet.getZeroHeight(), this.mode == Mode.TRUNCATE);
                    this.partitionTreeLogLikelihoods[n] = coalescent.calculateLogLikelihood();
                } else {
                    // A single lineage has no coalescent events to score.
                    this.partitionTreeLogLikelihoods[n] = 0.0;
                }
                this.recalculateCoalescentFlags[n] = false;
            }
            this.coalescencesLogLikelihood += this.partitionTreeLogLikelihoods[n];
        }
        this.likelihoodKnown = true;
        return this.coalescencesLogLikelihood;
    }

    /** Snapshots the treelet map (shallow copy) and the cached per-case log likelihoods. */
    @Override
    public void storeState() {
        super.storeState();
        this.storedElementsAsTrees = new HashMap(this.elementsAsTrees);
        this.storedPartitionTreeLogLikelihoods = Arrays.copyOf(this.partitionTreeLogLikelihoods, this.partitionTreeLogLikelihoods.length);
        this.storedCoalescencesLogLikelihood = this.coalescencesLogLikelihood;
    }

    /**
     * Reinstates the state captured by {@link #storeState()}. The stored objects are adopted
     * directly; the next storeState() replaces them with fresh copies.
     */
    @Override
    public void restoreState() {
        super.restoreState();
        this.elementsAsTrees = this.storedElementsAsTrees;
        this.partitionTreeLogLikelihoods = this.storedPartitionTreeLogLikelihoods;
        this.coalescencesLogLikelihood = this.storedCoalescencesLogLikelihood;
    }

    /**
     * Invalidates cached per-case values in response to a sub-model change.
     *
     * <ul>
     *   <li>Tree: only {@code PartitionsChangedEvent}s are acted on; every case they report is
     *       invalidated.</li>
     *   <li>Branch map: the event object must be an {@link ArrayList} of
     *       {@code BranchMapChangedEvent}s. For each event the old case, the new case and the
     *       case owning the changed node's parent (if any) are invalidated.</li>
     *   <li>Demographic model: every coalescent term is flagged for recomputation. The treelets
     *       themselves are kept.</li>
     *   <li>Outbreak: when the event object is an {@link AbstractCase}, that case and its
     *       infector (if any) are invalidated.</li>
     * </ul>
     *
     * @throws RuntimeException if the branch map sends an event object that is not an ArrayList
     */
    @Override
    protected void handleModelChangedEvent(Model model, Object object, int n) {
        super.handleModelChangedEvent(model, object, n);
        if (model == this.treeModel) {
            if (object instanceof PartitionedTreeModel.PartitionsChangedEvent) {
                HashSet<AbstractCase> hashSet = ((PartitionedTreeModel.PartitionsChangedEvent)object).getCasesToRecalculate();
                for (AbstractCase abstractCase : hashSet) {
                    this.recalculateCaseWCC(abstractCase);
                }
            }
        } else if (model == this.getBranchMap()) {
            if (!(object instanceof ArrayList)) {
                throw new RuntimeException("Unanticipated model changed event from BranchMapModel");
            }
            ArrayList<?> events = (ArrayList<?>)object;
            for (int i = 0; i < events.size(); ++i) {
                BranchMapModel.BranchMapChangedEvent event = (BranchMapModel.BranchMapChangedEvent)events.get(i);
                this.recalculateCaseWCC(event.getOldCase());
                this.recalculateCaseWCC(event.getNewCase());
                NodeRef node = this.treeModel.getNode(event.getNodeToRecalculate());
                NodeRef parent = this.treeModel.getParent(node);
                if (parent != null) {
                    this.recalculateCaseWCC(this.getBranchMap().get(parent.getNumber()));
                }
            }
        } else if (model == this.demoModel) {
            Arrays.fill(this.recalculateCoalescentFlags, true);
        } else if (model == this.outbreak && object instanceof AbstractCase) {
            AbstractCase abstractCase = (AbstractCase)object;
            this.recalculateCaseWCC(abstractCase);
            AbstractCase infector = ((PartitionedTreeModel)this.treeModel).getInfector(abstractCase);
            if (infector != null) {
                this.recalculateCaseWCC(infector);
            }
        }
    }

    /**
     * Drops the cached treelet of the case at index {@code n} and flags its coalescent term.
     *
     * @param n outbreak case index
     */
    protected void recalculateCaseWCC(int n) {
        this.elementsAsTrees.put(this.outbreak.getCase(n), null);
        this.pleaseReExplode = true;
        this.recalculateCoalescentFlags[n] = true;
    }

    /** Invalidates {@code abstractCase} if it was ever infected; otherwise does nothing. */
    protected void recalculateCaseWCC(AbstractCase abstractCase) {
        if (abstractCase.wasEverInfected()) {
            this.recalculateCaseWCC(this.outbreak.getCaseIndex(abstractCase));
        }
    }

    /** Marks every cached value stale: all coalescent terms and all ever-infected treelets. */
    @Override
    public void makeDirty() {
        super.makeDirty();
        Arrays.fill(this.recalculateCoalescentFlags, true);
        for (AbstractCase abstractCase : this.outbreak.getCases()) {
            if (abstractCase.wasEverInfected()) {
                this.elementsAsTrees.put(abstractCase, null);
            }
        }
        this.pleaseReExplode = true;
    }

    /**
     * Lists the cases of the transmission tree in post-order, starting from the case that owns
     * the phylogeny's root node.
     */
    public ArrayList<AbstractCase> postOrderTransmissionTreeTraversal() {
        return this.traverseTransmissionTree(this.getBranchMap().get(this.treeModel.getRoot().getNumber()));
    }

    // Recursive post-order walk. Infectees are visited in outbreak case-index order so the result is
    // deterministic; the loop over all cases makes each step O(number of cases).
    private ArrayList<AbstractCase> traverseTransmissionTree(AbstractCase abstractCase) {
        ArrayList<AbstractCase> result = new ArrayList<AbstractCase>();
        HashSet<AbstractCase> infectees = ((PartitionedTreeModel)this.treeModel).getInfectees(abstractCase);
        for (int i = 0; i < this.getOutbreak().size(); ++i) {
            AbstractCase candidate = this.getOutbreak().getCase(i);
            if (infectees.contains(candidate)) {
                result.addAll(this.traverseTransmissionTree(candidate));
            }
        }
        result.add(abstractCase);
        return result;
    }

    /**
     * Returns a copy of {@code treelet} whose node heights h are mapped to
     * {@code -log(zeroHeight - h)}. The values are then shifted so that the smallest one is 0,
     * and the copy is resolved.
     * NOTE(review): assumes every node lies strictly below the zero height; otherwise the log
     * yields +Infinity or NaN.
     */
    private CaseToCaseTreeLikelihood.Treelet transformTreelet(CaseToCaseTreeLikelihood.Treelet treelet) {
        double[] dArray = new double[treelet.getNodeCount()];
        double d = treelet.getZeroHeight();
        for (int i = 0; i < treelet.getNodeCount(); ++i) {
            NodeRef nodeRef = treelet.getNode(i);
            double d3 = treelet.getNodeHeight(nodeRef) - d;
            dArray[i] = -Math.log(-d3);
        }
        // Manual minimum: NaN entries are skipped rather than propagated (unlike Math.min).
        double d4 = Double.POSITIVE_INFINITY;
        for (double d5 : dArray) {
            if (!(d5 < d4)) continue;
            d4 = d5;
        }
        double d6 = -d4;
        CaseToCaseTreeLikelihood.Treelet treelet2 = new CaseToCaseTreeLikelihood.Treelet(treelet, d6);
        for (int i = 0; i < treelet2.getNodeCount(); ++i) {
            NodeRef nodeRef = treelet2.getNode(i);
            treelet2.setNodeHeight(nodeRef, dArray[i] - d4);
        }
        treelet2.resolveTree();
        return treelet2;
    }

    /**
     * Coalescent log likelihood of one treelet's intervals.
     *
     * <p>Time starts at {@code -d2} and each interval covers {@code [start, start + length]}.
     *
     * <ul>
     *   <li>When {@code bl} is true (truncated), every interval with at least two lineages is
     *       normalised by {@code log(1 - exp(-C(n,2) * integral(start, 0)))}.</li>
     *   <li>When {@code bl} is false, the demographic function must be a {@link LinearGrowth}.</li>
     * </ul>
     *
     * @param intervalList intervals of the treelet
     * @param demographicFunction population-size function
     * @param d lower bound that {@code N(end) * (integral / length)} must reach on coalescent
     *          intervals; otherwise -Infinity is returned
     * @param d2 zero height of the treelet (integration starts at {@code -d2})
     * @param bl true for the truncated model
     * @return the log likelihood, or {@link Double#NEGATIVE_INFINITY} for impossible configurations
     * @throws RuntimeException if {@code bl} is false and the function is not {@link LinearGrowth}
     *         (checked per interval, so not thrown when there are no intervals)
     */
    public static double calculatePartitionTreeLogLikelihood(IntervalList intervalList, DemographicFunction demographicFunction, double d, double d2, boolean bl) {
        double d3 = 0.0;
        double d4 = -d2;
        int n = intervalList.getIntervalCount();
        for (int i = 0; i < n; ++i) {
            double d5;
            double d6;
            double d7;
            double d8;
            if (bl) {
                d8 = intervalList.getInterval(i);
                d7 = d4 + d8;
                if (d7 == 0.0) {
                    return Double.NEGATIVE_INFINITY;
                }
                d6 = demographicFunction.getIntegral(d4, d7);
                double d9 = demographicFunction.getIntegral(d4, 0.0);
                if (d6 == 0.0 && d8 > tolerance) {
                    return Double.NEGATIVE_INFINITY;
                }
                int n2 = intervalList.getLineageCount(i);
                if (n2 >= 2) {
                    double d10;
                    d5 = Binomial.choose2(n2);
                    if (intervalList.getIntervalType(i) == IntervalType.COALESCENT) {
                        d3 += -d5 * d6;
                        d10 = demographicFunction.getDemographic(d7);
                        if (d8 != 0.0 && !(d10 * (d6 / d8) >= d)) return Double.NEGATIVE_INFINITY;
                        d3 -= Math.log(d10);
                    } else {
                        d10 = Math.exp(-d5 * d9) - Math.exp(-d5 * d6);
                        d3 += Math.log(d10);
                    }
                    // Truncation normaliser log(1 - exp(-C(n,2) * integral(start, 0))); when exp(...)
                    // rounds to exactly 1.0 the high-precision fallback is used instead of log1p.
                    double d11 = (d10 = Math.exp(-d5 * d9)) != 1.0 ? Math.log1p(-d10) : WithinCaseCoalescent.handleDenominatorUnderflow(-d5 * d9);
                    d3 -= d11;
                }
                d4 = d7;
                continue;
            }
            if (!(demographicFunction instanceof LinearGrowth)) {
                throw new RuntimeException("Function must have zero population at t=0 if truncate=false");
            }
            d8 = intervalList.getInterval(i);
            d7 = d4 + d8;
            d6 = demographicFunction.getIntegral(d4, d7);
            if (d6 == 0.0 && d8 != 0.0) {
                return Double.NEGATIVE_INFINITY;
            }
            int n3 = intervalList.getLineageCount(i);
            double d12 = Binomial.choose2(n3);
            d3 += -d12 * d6;
            if (intervalList.getIntervalType(i) == IntervalType.COALESCENT) {
                d5 = demographicFunction.getDemographic(d7);
                if (d8 != 0.0 && !(d5 * (d6 / d8) >= d)) return Double.NEGATIVE_INFINITY;
                d3 -= Math.log(d5);
            }
            d4 = d7;
        }
        return d3;
    }

    // Computes log(1 - exp(d)) with BigDecimal arithmetic (BigDecimalUtils.exp / ln), used when
    // exp(d) rounds to 1.0 in double precision.
    // NOTE(review): the second argument passed to BigDecimalUtils is the operand's scale; confirm it
    // gives sufficient precision, and note that d == 0 leads to ln(0).
    private static double handleDenominatorUnderflow(double d) {
        BigDecimal bigDecimal = new BigDecimal(d);
        BigDecimal bigDecimal2 = BigDecimalUtils.exp(bigDecimal, bigDecimal.scale());
        BigDecimal bigDecimal3 = new BigDecimal(1.0);
        BigDecimal bigDecimal4 = bigDecimal3.subtract(bigDecimal2);
        BigDecimal bigDecimal5 = BigDecimalUtils.ln(bigDecimal4, bigDecimal4.scale());
        return bigDecimal5.doubleValue();
    }

    /**
     * Debug helper: writes {@code tree} in NEXUS format to the file {@code string}. Each node is
     * annotated with a "Number" attribute holding its node number.
     * I/O failures are reported on stdout and otherwise ignored (best-effort).
     */
    public void debugTreelet(Tree tree, String string) {
        try {
            FlexibleTree flexibleTree = new FlexibleTree(tree);
            for (int i = 0; i < flexibleTree.getNodeCount(); ++i) {
                FlexibleNode flexibleNode = (FlexibleNode)flexibleTree.getNode(i);
                flexibleNode.setAttribute("Number", flexibleNode.getNumber());
            }
            // try-with-resources so the output file handle is always released.
            try (PrintStream printStream = new PrintStream(string)) {
                NexusExporter nexusExporter = new NexusExporter(printStream);
                nexusExporter.exportTree((Tree)flexibleTree);
            }
        }
        catch (IOException iOException) {
            System.out.println("IOException");
        }
    }

    /**
     * Returns the superclass log columns. For a {@link CategoryOutbreak} this adds one
     * {@code coal_LL_<i>} column per ever-infected case plus a {@code total_coal_LL} column.
     */
    @Override
    public LogColumn[] passColumns() {
        ArrayList<LogColumn> arrayList = new ArrayList<LogColumn>(Arrays.asList(super.passColumns()));
        if (this.outbreak instanceof CategoryOutbreak) {
            for (int i = 0; i < this.outbreak.size(); ++i) {
                if (!this.outbreak.getCase(i).wasEverInfected()) continue;
                final int n = i;
                arrayList.add(new LogColumn.Abstract("coal_LL_" + i){

                    @Override
                    protected String getFormattedValue() {
                        return String.valueOf(WithinCaseCoalescent.this.partitionTreeLogLikelihoods[n]);
                    }
                });
            }
            arrayList.add(new LogColumn.Abstract("total_coal_LL"){

                @Override
                protected String getFormattedValue() {
                    return String.valueOf(WithinCaseCoalescent.this.coalescencesLogLikelihood);
                }
            });
        }
        // Always return the collected columns rather than null, so the superclass columns are kept.
        return arrayList.toArray(new LogColumn[arrayList.size()]);
    }

    /**
     * Coalescent whose likelihood is evaluated by
     * {@link WithinCaseCoalescent#calculatePartitionTreeLogLikelihood} with a caller-specified zero
     * height. Static because it never uses the enclosing instance.
     */
    private static class SpecifiedZeroCoalescent
    extends Coalescent {
        private final double zeroHeight;
        final boolean truncate;

        private SpecifiedZeroCoalescent(Tree tree, DemographicModel demographicModel, double d, boolean bl) {
            super(tree, demographicModel.getDemographicFunction());
            this.zeroHeight = d;
            this.truncate = bl;
        }

        @Override
        public double calculateLogLikelihood() {
            return WithinCaseCoalescent.calculatePartitionTreeLogLikelihood(this.getIntervals(), this.getDemographicFunction(), 0.0, this.zeroHeight, this.truncate);
        }
    }

    /** Whether the within-case coalescent is truncated at each treelet's zero height. */
    private static enum Mode {
        TRUNCATE,
        NORMAL;

    }
}

