/*
 * Decompiled with CFR 0.152.
 */
package dr.evomodel.treedatalikelihood.hmc;

import org.ejml.data.DenseMatrix64F;
import org.ejml.ops.CommonOps;

public interface MultivariateChainRule {
    /**
     * Applies this chain rule to a gradient supplied as a flat array and returns the result.
     *
     * <p>Whether the returned array is freshly allocated or the argument itself depends on the
     * implementation.
     *
     * @param var1 the incoming gradient
     * @return the gradient after the rule has been applied
     */
    public double[] chainGradient(double[] var1);

    /**
     * Applies this chain rule to a gradient held in a matrix.
     *
     * <p>Nothing is returned. The implementations in this file that support this overload
     * ({@code InverseGeneral} and {@code Chain}) overwrite the argument with the result, while
     * {@code Inverse} throws because it does not implement it.
     *
     * @param var1 the incoming gradient; used as the output by the supporting implementations
     */
    public void chainGradient(DenseMatrix64F var1);

    /**
     * Chain rule that turns a gradient {@code G} into {@code -M * G * M}, where {@code M} is the
     * {@code dim x dim} matrix handed to the constructor.
     *
     * <p>The instance keeps a reference to the supplied data (nothing is copied here) and owns a
     * single scratch matrix that every call to {@code chainGradient} overwrites, so concurrent
     * calls on one instance would share that buffer.
     */
    public static class InverseGeneral
    implements MultivariateChainRule {
        /** The matrix {@code M} applied on both sides of the gradient. */
        private final DenseMatrix64F matrix;
        /** Work buffer that receives the intermediate product {@code M * G}. */
        private final DenseMatrix64F scratch;
        /** Number of rows, and of columns, of {@code M}. */
        private final int dim;

        /**
         * Builds the rule from a flat array that is wrapped as a {@code dim x dim} matrix.
         *
         * <p>NOTE(review): the length is not checked for being a perfect square; {@code dim} is
         * simply the truncated square root, so any surplus trailing entries are never addressed.
         *
         * @param dArray the entries of {@code M}; {@code dim} is {@code (int) sqrt(dArray.length)}
         */
        public InverseGeneral(double[] dArray) {
            final int side = (int) Math.sqrt(dArray.length);
            dim = side;
            matrix = DenseMatrix64F.wrap(side, side, dArray);
            scratch = new DenseMatrix64F(side, side);
        }

        /**
         * Builds the rule around an existing matrix, which is stored by reference.
         *
         * @param denseMatrix64F the matrix {@code M}; squareness is verified only by an
         *                       {@code assert}, i.e. only when assertions are enabled
         */
        public InverseGeneral(DenseMatrix64F denseMatrix64F) {
            dim = denseMatrix64F.getNumCols();
            assert (dim == denseMatrix64F.getNumRows()) : "Inverse rule is only valid for square matrices.";
            matrix = denseMatrix64F;
            scratch = new DenseMatrix64F(dim, dim);
        }

        /**
         * Computes {@code -M * G * M} for a gradient given as a flat array.
         *
         * @param dArray the gradient {@code G}, wrapped as a {@code dim x dim} matrix; its length
         *               is compared with {@code dim * dim} only by an {@code assert}
         * @return the backing array of a newly allocated {@code dim x dim} result matrix
         */
        @Override
        public double[] chainGradient(double[] dArray) {
            assert (dArray.length == dim * dim);
            final DenseMatrix64F gradient = DenseMatrix64F.wrap(dim, dim, dArray);
            final DenseMatrix64F result = new DenseMatrix64F(dim, dim);
            // scratch = M * G, then result = -1 * scratch * M
            CommonOps.mult(matrix, gradient, scratch);
            CommonOps.mult(-1.0, scratch, matrix, result);
            return result.getData();
        }

        /**
         * Replaces the contents of the given gradient matrix {@code G} with {@code -M * G * M}.
         *
         * @param denseMatrix64F the gradient on entry, the transformed gradient on exit
         */
        @Override
        public void chainGradient(DenseMatrix64F denseMatrix64F) {
            // The intermediate product goes to the scratch buffer first, so the argument is
            // fully read before it is used as the output of the second multiplication.
            CommonOps.mult(matrix, denseMatrix64F, scratch);
            CommonOps.mult(-1.0, scratch, matrix, denseMatrix64F);
        }
    }

    /**
     * Element-wise chain rule: entry {@code i} of the result is
     * {@code -gradient[i] * vecP[i] / vecV[i]}.
     *
     * <p>Both arrays passed to the constructor are stored by reference, not copied.
     * NOTE(review): the names and the error message suggest {@code vecP} holds precisions and
     * {@code vecV} variances; confirm against the callers.
     */
    public static class Inverse
    implements MultivariateChainRule {
        /** Per-entry numerator factor. */
        private final double[] vecP;
        /** Per-entry divisor; every entry used must be non-zero and not NaN. */
        private final double[] vecV;
        /** Truncated square root of {@code vecP.length}; {@code dim * dim} entries are processed. */
        private final int dim;

        /**
         * @param dArray  the values stored as {@code vecP}; its length determines {@code dim}
         * @param dArray2 the values stored as {@code vecV}
         */
        Inverse(double[] dArray, double[] dArray2) {
            this.vecP = dArray;
            this.vecV = dArray2;
            this.dim = (int)Math.sqrt(dArray.length);
        }

        /**
         * Applies the rule entry by entry and returns the result in a new array.
         *
         * @param dArray the incoming gradient; its length is compared with {@code dim * dim} only
         *               by an {@code assert}
         * @return a new array of length {@code dim * dim}
         * @throws RuntimeException if an entry of {@code vecV} that is needed is zero or NaN
         */
        @Override
        public double[] chainGradient(double[] dArray) {
            assert (dArray.length == this.dim * this.dim);
            // Loop bound computed once instead of on every iteration.
            final int length = this.dim * this.dim;
            final double[] result = new double[length];
            for (int i = 0; i < length; ++i) {
                final double divisor = this.vecV[i];
                if (divisor == 0.0 || Double.isNaN(divisor)) {
                    throw new RuntimeException("0 or NaN value in variance. check start value or use smaller step size for hmc");
                }
                result[i] = -dArray[i] * this.vecP[i] / divisor;
            }
            return result;
        }

        /**
         * Matrix overload; not implemented for this rule.
         *
         * @param denseMatrix64F ignored
         * @throws UnsupportedOperationException always (a {@code RuntimeException} subclass, so
         *         callers catching {@code RuntimeException} keep working)
         */
        @Override
        public void chainGradient(DenseMatrix64F denseMatrix64F) {
            throw new UnsupportedOperationException("not yet implemented");
        }
    }

    /**
     * Composite rule that applies a sequence of rules one after another, in array order.
     *
     * <p>The array passed to the constructor is stored by reference, not copied.
     */
    public static class Chain
    implements MultivariateChainRule {
        /** The rules to apply, first to last. */
        private final MultivariateChainRule[] rules;

        /**
         * @param multivariateChainRuleArray the rules to apply, in application order
         */
        Chain(MultivariateChainRule[] multivariateChainRuleArray) {
            rules = multivariateChainRuleArray;
        }

        /**
         * Feeds the gradient through every rule, each rule receiving the previous rule's output.
         *
         * @param dArray the incoming gradient
         * @return the output of the last rule, or {@code dArray} itself when there are no rules
         */
        @Override
        public double[] chainGradient(double[] dArray) {
            final MultivariateChainRule[] steps = rules;
            double[] current = dArray;
            for (int i = 0; i < steps.length; i++) {
                current = steps[i].chainGradient(current);
            }
            return current;
        }

        /**
         * Passes the same matrix to every rule in turn.
         *
         * @param denseMatrix64F the gradient matrix handed to each rule
         */
        @Override
        public void chainGradient(DenseMatrix64F denseMatrix64F) {
            final MultivariateChainRule[] steps = rules;
            for (int i = 0; i < steps.length; i++) {
                steps[i].chainGradient(denseMatrix64F);
            }
        }
    }
}

