/*
 * Decompiled with CFR 0.152.
 */
package dr.evomodel.substmodel;

import dr.evolution.datatype.DataType;
import dr.evomodel.substmodel.ComplexSubstitutionModel;
import dr.evomodel.substmodel.DifferentiableSubstitutionModel;
import dr.evomodel.substmodel.DifferentiableSubstitutionModelUtil;
import dr.evomodel.substmodel.DifferentialMassProvider;
import dr.evomodel.substmodel.FrequencyModel;
import dr.evomodel.substmodel.LogAdditiveCtmcRateProvider;
import dr.evomodel.substmodel.ParameterReplaceableSubstitutionModel;
import dr.inference.distribution.GeneralizedLinearModel;
import dr.inference.distribution.LogLinearModel;
import dr.inference.loggers.LogColumn;
import dr.inference.model.BayesianStochasticSearchVariableSelection;
import dr.inference.model.Likelihood;
import dr.inference.model.Model;
import dr.inference.model.Parameter;
import dr.math.matrixAlgebra.WrappedMatrix;
import dr.util.Citation;
import dr.util.CommonCitations;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Set;

public class GlmSubstitutionModel
extends ComplexSubstitutionModel
implements ParameterReplaceableSubstitutionModel,
DifferentiableSubstitutionModel {
    private final LogAdditiveCtmcRateProvider glm;
    private final double[] testProbabilities;

    public GlmSubstitutionModel(String string, DataType dataType, FrequencyModel frequencyModel, LogAdditiveCtmcRateProvider logAdditiveCtmcRateProvider) {
        super(string, dataType, frequencyModel, (Parameter)null);
        this.glm = logAdditiveCtmcRateProvider;
        this.addModel(logAdditiveCtmcRateProvider);
        this.testProbabilities = new double[this.stateCount * this.stateCount];
    }

    @Override
    public LogAdditiveCtmcRateProvider getRateProvider() {
        return this.glm;
    }

    public GeneralizedLinearModel getGeneralizedLinearModel() {
        if (this.glm instanceof GeneralizedLinearModel) {
            return (GeneralizedLinearModel)((Object)this.glm);
        }
        throw new RuntimeException("Not yet implemented");
    }

    @Override
    protected void setupRelativeRates(double[] dArray) {
        System.arraycopy(this.glm.getRates(), 0, dArray, 0, dArray.length);
    }

    @Override
    public Set<Likelihood> getLikelihoodSet() {
        return new HashSet<Likelihood>(Collections.singletonList(this));
    }

    @Override
    protected void handleModelChangedEvent(Model model, Object object, int n) {
        if (model == this.glm) {
            this.updateMatrix = true;
            this.fireModelChanged();
        } else {
            super.handleModelChangedEvent(model, object, n);
        }
    }

    @Override
    public LogColumn[] getColumns() {
        LogColumn[] logColumnArray = new LogColumn[this.glm.getColumns().length + 2];
        int n = 0;
        LogColumn[] logColumnArray2 = this.glm.getColumns();
        int n2 = logColumnArray2.length;
        for (int i = 0; i < n2; ++i) {
            LogColumn logColumn;
            logColumnArray[n] = logColumn = logColumnArray2[i];
            ++n;
        }
        logColumnArray[n++] = new ComplexSubstitutionModel.LikelihoodColumn(this.getId() + ".L");
        logColumnArray[n] = new ComplexSubstitutionModel.NormalizationColumn(this.getId() + ".Norm");
        return logColumnArray;
    }

    @Override
    public double getLogLikelihood() {
        double d = super.getLogLikelihood();
        if (d == 0.0 && BayesianStochasticSearchVariableSelection.Utils.connectedAndWellConditioned(this.testProbabilities, this)) {
            return 0.0;
        }
        return Double.NEGATIVE_INFINITY;
    }

    @Override
    public String getDescription() {
        return "Generalized linear (model, GLM) substitution model";
    }

    @Override
    public List<Citation> getCitations() {
        return Collections.singletonList(CommonCitations.LEMEY_2014_UNIFYING);
    }

    @Override
    public ParameterReplaceableSubstitutionModel factory(List<Parameter> list, List<Parameter> list2) {
        GeneralizedLinearModel generalizedLinearModel = ((LogLinearModel)this.glm).factory((List)list, (List)list2);
        return new GlmSubstitutionModel(this.getModelName(), this.dataType, this.freqModel, (LogAdditiveCtmcRateProvider)((Object)generalizedLinearModel));
    }

    @Override
    public WrappedMatrix getInfinitesimalDifferentialMatrix(DifferentialMassProvider.DifferentialWrapper.WrtParameter wrtParameter) {
        return DifferentiableSubstitutionModelUtil.getInfinitesimalDifferentialMatrix(wrtParameter, this);
    }

    @Override
    public DifferentialMassProvider.DifferentialWrapper.WrtParameter factory(Parameter parameter, int n) {
        int n2 = ((LogLinearModel)this.glm).getEffectNumber(parameter);
        if (n2 == -1) {
            throw new RuntimeException("Only implemented for single dimensions, break up beta to one for each block for now please.");
        }
        return new WrtOldGLMSubstitutionModelParameter((LogLinearModel)this.glm, n2, n, this.stateCount);
    }

    @Override
    public void setupDifferentialRates(DifferentialMassProvider.DifferentialWrapper.WrtParameter wrtParameter, double[] dArray, double d) {
        double[] dArray2 = new double[this.stateCount * this.stateCount];
        this.getInfinitesimalMatrix(dArray2);
        wrtParameter.setupDifferentialRates(dArray, dArray2, d);
    }

    @Override
    public void setupDifferentialFrequency(DifferentialMassProvider.DifferentialWrapper.WrtParameter wrtParameter, double[] dArray) {
        wrtParameter.setupDifferentialFrequencies(dArray, this.getFrequencyModel().getFrequencies());
    }

    @Override
    public double getWeightedNormalizationGradient(DifferentialMassProvider.DifferentialWrapper.WrtParameter wrtParameter, double[][] dArray, double[] dArray2) {
        double d = 0.0;
        if (this.getNormalization()) {
            for (int i = 0; i < this.stateCount; ++i) {
                d -= dArray[i][i] * this.getFrequencyModel().getFrequency(i);
            }
        }
        return d;
    }

    static class WrtOldGLMSubstitutionModelParameter
    implements DifferentialMassProvider.DifferentialWrapper.WrtParameter {
        private final int dim;
        private final int fixedEffectIndex;
        private final int stateCount;
        private final LogLinearModel glm;

        public WrtOldGLMSubstitutionModelParameter(LogLinearModel logLinearModel, int n, int n2, int n3) {
            this.glm = logLinearModel;
            this.fixedEffectIndex = n;
            this.dim = n2;
            this.stateCount = n3;
        }

        @Override
        public double getRate(int n) {
            throw new RuntimeException("Should not be called.");
        }

        @Override
        public double getNormalizationDifferential() {
            return 0.0;
        }

        @Override
        public void setupDifferentialFrequencies(double[] dArray, double[] dArray2) {
            Arrays.fill(dArray, 1.0);
        }

        @Override
        public void setupDifferentialRates(double[] dArray, double[] dArray2, double d) {
            int n;
            int n2;
            double[] dArray3 = this.glm.getDesignMatrix(this.fixedEffectIndex).getColumnValues(this.dim);
            int n3 = 0;
            for (n2 = 0; n2 < this.stateCount; ++n2) {
                for (n = n2 + 1; n < this.stateCount; ++n) {
                    dArray[n3] = dArray3[n3] * dArray2[this.index(n2, n)];
                    ++n3;
                }
            }
            for (n2 = 0; n2 < this.stateCount; ++n2) {
                for (n = n2 + 1; n < this.stateCount; ++n) {
                    dArray[n3] = dArray3[n3] * dArray2[this.index(n, n2)];
                    ++n3;
                }
            }
        }

        double getChainRule() {
            return Math.exp(this.glm.getFixedEffect(this.fixedEffectIndex).getParameterValue(this.dim));
        }

        private int index(int n, int n2) {
            return n * this.stateCount + n2;
        }
    }
}

