package us.ihmc.utilities.math.functionApproximation;

import com.mathworks.jama.Matrix;
import java.io.PrintWriter;
import java.util.ArrayList;
import java.util.Random;

/* loaded from: input_file:us/ihmc/utilities/math/functionApproximation/LinearRegression.class */
/**
 * Ordinary least-squares linear regression, solved through the normal equations
 * {@code (X^T X) beta = X^T y}.
 *
 * <p>Each row of the input matrix {@code X} holds one sample's regressors (add a column of ones
 * to fit an intercept); the matching row of the output vector {@code y} holds the observed value.
 * The solution is cached after the first successful {@link #solve()}.
 *
 * <p>The input matrices are stored by reference, not copied.
 */
public class LinearRegression {
    /** If |det(X^T X)| falls below this value the system is treated as singular. */
    private static final double SINGULARITY_THRESHOLD = 1.0E-86d;

    /** When true, {@link #solve()} prints X, y and det(X^T X) to standard output. */
    public static boolean VERBOSE = false;

    private Matrix inputMatrix;
    private Matrix outputVector;
    private boolean solved;
    private double error;
    private Matrix betaVector;

    /**
     * Creates a regression from an input matrix and an output vector.
     *
     * @param matrix the regressor matrix X (one row per sample)
     * @param matrix2 the observations y (same number of rows as {@code matrix})
     * @throws IllegalArgumentException if the row counts of the two matrices differ
     */
    public LinearRegression(Matrix matrix, Matrix matrix2) {
        setMatrices(matrix, matrix2);
    }

    /**
     * Creates a regression from a list of regressor rows and a list of observations.
     *
     * @param arrayList one regressor array per sample; all rows must have the same length
     * @param arrayList2 one observation per sample
     * @throws IllegalArgumentException if no samples are given or the list sizes differ
     */
    public LinearRegression(ArrayList<double[]> arrayList, ArrayList<Double> arrayList2) {
        if (arrayList.isEmpty()) {
            throw new IllegalArgumentException("at least one sample is required");
        }
        // The row arrays are reused as-is; no need to pre-allocate inner arrays.
        double[][] inputs = arrayList.toArray(new double[arrayList.size()][]);
        double[] outputs = new double[arrayList2.size()];
        for (int i = 0; i < outputs.length; i++) {
            outputs[i] = arrayList2.get(i).doubleValue();
        }
        setMatrices(new Matrix(inputs), new Matrix(outputs, outputs.length));
    }

    /**
     * Creates a regression from raw arrays.
     *
     * @param dArr regressor rows, one per sample
     * @param dArr2 observations, one per sample
     * @throws IllegalArgumentException if the number of rows and observations differ
     */
    public LinearRegression(double[][] dArr, double[] dArr2) {
        setMatrices(new Matrix(dArr), new Matrix(dArr2, dArr2.length));
    }

    /** Stores the problem data after checking that X and y describe the same samples. */
    private void setMatrices(Matrix matrix, Matrix matrix2) {
        if (matrix.getRowDimension() != matrix2.getRowDimension()) {
            throw new IllegalArgumentException("input matrix has " + matrix.getRowDimension()
                    + " rows but output vector has " + matrix2.getRowDimension() + " rows");
        }
        this.inputMatrix = matrix;
        this.outputVector = matrix2;
        this.solved = false;
    }

    /**
     * Solves the least-squares problem if it has not been solved yet.
     *
     * @return {@code true} if a solution is available, {@code false} if X^T X was considered
     *     singular (its determinant's magnitude is below {@link #SINGULARITY_THRESHOLD})
     */
    public boolean solve() {
        if (this.solved) {
            return true;
        }
        Matrix xTranspose = this.inputMatrix.transpose();
        Matrix xTransX = xTranspose.times(this.inputMatrix);
        double determinant = xTransX.det();
        if (VERBOSE) {
            System.out.println("LinearRegression::solve: X, Y : ");
            this.inputMatrix.print(5, 5);
            this.outputVector.print(5, 5);
            System.out.println("LinearRegression::solve: xTransX.det() : " + determinant);
        }
        if (Math.abs(determinant) < SINGULARITY_THRESHOLD) {
            return false;
        }
        // Solve (X^T X) beta = X^T y directly instead of forming the explicit inverse:
        // cheaper and numerically better behaved.
        this.betaVector = xTransX.solve(xTranspose.times(this.outputVector));
        Matrix residuals = this.outputVector.minus(this.inputMatrix.times(this.betaVector));
        this.error = sumOfSquares(residuals);
        this.solved = true;
        return true;
    }

    /** Returns the sum of the squares of every entry of {@code matrix}. */
    private static double sumOfSquares(Matrix matrix) {
        double sum = 0.0d;
        // Residuals are an n x 1 column, so iterate over rows (and columns, for generality).
        for (int row = 0; row < matrix.getRowDimension(); row++) {
            for (int col = 0; col < matrix.getColumnDimension(); col++) {
                double value = matrix.get(row, col);
                sum += value * value;
            }
        }
        return sum;
    }

    /**
     * Returns the sum of squared residuals {@code ||y - X beta||^2} of the fitted model.
     *
     * @throws IllegalStateException if the regression has not been successfully solved
     */
    public double getSquaredError() {
        if (this.solved) {
            return this.error;
        }
        throw new IllegalStateException("cannot get error before the Regression has been solved");
    }

    /**
     * Returns the fitted coefficient column vector (the internal matrix, not a copy).
     *
     * @return the coefficients, or {@code null} if {@link #solve()} has not succeeded yet
     */
    public Matrix getCoefficientVectorAsMatrix() {
        return this.betaVector;
    }

    /**
     * Copies the fitted coefficients into {@code dArr}.
     *
     * @param dArr destination array; its length must equal the number of coefficients
     * @throws IllegalStateException if the regression has not been successfully solved
     * @throws IllegalArgumentException if {@code dArr} has the wrong length
     */
    public void packCoefficientVector(double[] dArr) {
        if (!this.solved) {
            throw new IllegalStateException("cannot get coefficients before the Regression has been solved");
        }
        if (dArr.length != this.betaVector.getRowDimension()) {
            throw new IllegalArgumentException("given array must have size " + this.betaVector.getRowDimension());
        }
        for (int i = 0; i < dArr.length; i++) {
            dArr[i] = this.betaVector.get(i, 0);
        }
    }

    /** Demo: fits two noisy polynomial models and prints the recovered coefficients. */
    public static void main(String[] strArr) {
        Random random = new Random(1984L);

        // y = x + 5x^2 + noise, regressors [1, x, x^2]
        ArrayList<double[]> quadraticInputs = new ArrayList<double[]>();
        ArrayList<Double> quadraticOutputs = new ArrayList<Double>();
        for (int i = 0; i < 500; i++) {
            double x = (random.nextDouble() * 2.0d) - 1.0d;
            quadraticInputs.add(new double[] {1.0d, x, x * x});
            double y = (1.0d * x) + (5.0d * x * x) + (random.nextDouble() * 0.1d);
            quadraticOutputs.add(Double.valueOf(y));
        }
        solveAndReport(new LinearRegression(quadraticInputs, quadraticOutputs));

        // z = 4 + x + 5xy - 3y^2 + noise, regressors [1, x, x^2, y, y^2, xy]
        ArrayList<double[]> bivariateInputs = new ArrayList<double[]>();
        ArrayList<Double> bivariateOutputs = new ArrayList<Double>();
        for (int i = 0; i < 500; i++) {
            double x = (random.nextDouble() * 2.0d) - 1.0d;
            double y = (random.nextDouble() * 2.0d) - 1.0d;
            bivariateInputs.add(new double[] {1.0d, x, x * x, y, y * y, x * y});
            double z = (((4.0d + (1.0d * x)) + ((5.0d * x) * y)) - ((3.0d * y) * y)) + (random.nextDouble() * 0.1d);
            bivariateOutputs.add(Double.valueOf(z));
        }
        solveAndReport(new LinearRegression(bivariateInputs, bivariateOutputs));
    }

    /** Times {@link #solve()} on the given regression and prints its coefficients and error. */
    private static void solveAndReport(LinearRegression regression) {
        long startTime = System.nanoTime();
        boolean success = regression.solve();
        System.out.println("LinearRegression::main: regression took "
                + ((System.nanoTime() - startTime) / 1000000.0d) + " ms");
        if (!success) {
            System.out.println("LinearRegression::main: regression is singular, no coefficients available");
            return;
        }
        double[] coefficients = new double[regression.getCoefficientVectorAsMatrix().getRowDimension()];
        regression.packCoefficientVector(coefficients);
        StringBuilder message = new StringBuilder("LinearRegression::main: coefficients are ");
        for (int i = 0; i < coefficients.length; i++) {
            if (i > 0) {
                message.append(", ");
            }
            message.append(coefficients[i]);
        }
        System.out.println(message.toString());
        System.out.println("LinearRegression::main: linearRegression.getSquaredError() : "
                + regression.getSquaredError());
    }
}
