/*
 * Decompiled with CFR 0.152.
 */
package org.jquantlib.methods.finitedifferences;

import org.jquantlib.math.matrixutilities.Array;
import org.jquantlib.methods.finitedifferences.Operator;

/**
 * Linear operator whose matrix representation is tridiagonal.
 * <p>
 * An operator of size {@code n} is stored as three arrays: the main {@link #diagonal()}
 * (length {@code n}), the {@link #lowerDiagonal()} (length {@code n - 1}, entry {@code i}
 * sits in row {@code i + 1} and multiplies {@code v[i]}) and the {@link #upperDiagonal()}
 * (length {@code n - 1}, entry {@code i} sits in row {@code i} and multiplies {@code v[i + 1]}).
 * The empty operator ({@code n == 0}) stores three empty arrays.
 * <p>
 * The operator is time-dependent when a {@link TimeSetter} is attached;
 * {@link #setTime(double)} then forwards the time and this operator to that callback.
 */
public class TridiagonalOperator
implements Operator {
    /** Callback invoked by {@link #setTime(double)}; {@code null} when the operator is time-independent. */
    protected TimeSetter timeSetter;
    /** Sub-diagonal coefficients (length {@code size() - 1}, or 0 for the empty operator). */
    protected Array lowerDiagonal;
    /** Main-diagonal coefficients (length {@code size()}). */
    protected Array diagonal;
    /** Super-diagonal coefficients (length {@code size() - 1}, or 0 for the empty operator). */
    protected Array upperDiagonal;

    /**
     * Creates an operator of the given size whose coefficient arrays are freshly allocated
     * with {@code new Array(...)}.
     *
     * @param size number of rows; must be 0 or at least 3
     * @throws IllegalStateException for any other size (1, 2 or negative)
     */
    public TridiagonalOperator(final int size) {
        if (size >= 3) {
            this.lowerDiagonal = new Array(size - 1);
            this.diagonal = new Array(size);
            this.upperDiagonal = new Array(size - 1);
        } else if (size == 0) {
            this.lowerDiagonal = new Array(0);
            this.diagonal = new Array(0);
            this.upperDiagonal = new Array(0);
        } else {
            throw new IllegalStateException("Invalid size for Tridiagonal Operator");
        }
    }

    /**
     * Creates an operator from explicit coefficient arrays. The arrays are stored as given (not copied).
     *
     * @param ldiag lower diagonal, length {@code diag.size() - 1} (0 when {@code diag} is empty)
     * @param diag  main diagonal
     * @param udiag upper diagonal, length {@code diag.size() - 1} (0 when {@code diag} is empty)
     * @throws IllegalStateException if an off-diagonal has the wrong length
     */
    public TridiagonalOperator(final Array ldiag, final Array diag, final Array udiag) {
        // The empty operator has empty off-diagonals (the same shape TridiagonalOperator(0)
        // builds); otherwise each off-diagonal is one element shorter than the diagonal.
        final int offDiagonalSize = Math.max(diag.size() - 1, 0);
        if (ldiag.size() != offDiagonalSize) {
            throw new IllegalStateException("wrong size for lower diagonal");
        }
        if (udiag.size() != offDiagonalSize) {
            throw new IllegalStateException("wrong size for upper diagonal");
        }
        this.lowerDiagonal = ldiag;
        this.diagonal = diag;
        this.upperDiagonal = udiag;
    }

    /**
     * Creates an operator with the same coefficient arrays and time setter as {@code t}.
     *
     * @param t operator to copy from
     */
    public TridiagonalOperator(final TridiagonalOperator t) {
        // NOTE(review): the three Array references are shared with t, not copied element-wise,
        // so a later set*Row call on either operator is visible through the other. The C++
        // QuantLib copy constructor copies the elements — confirm whether callers rely on sharing.
        this.diagonal = t.diagonal();
        this.upperDiagonal = t.upperDiagonal();
        this.lowerDiagonal = t.lowerDiagonal();
        this.timeSetter = t.getTimeSetter();
    }

    /**
     * Sets the coefficients of row 0.
     *
     * @param b main-diagonal entry of row 0
     * @param c upper-diagonal entry of row 0
     */
    public void setFirstRow(final double b, final double c) {
        this.diagonal.set(0, b);
        this.upperDiagonal.set(0, c);
    }

    /**
     * Sets the coefficients of one interior row.
     *
     * @param size index of the row to set (despite the name), in {@code [1, size() - 2]}
     * @param a    lower-diagonal entry of that row
     * @param b    main-diagonal entry of that row
     * @param c    upper-diagonal entry of that row
     * @throws IllegalStateException if the row index is not an interior row
     */
    public void setMidRow(final int size, final double a, final double b, final double c) {
        if (size < 1 || size > this.size() - 2) {
            throw new IllegalStateException("out of range in setMidRow");
        }
        this.lowerDiagonal.set(size - 1, a);
        this.diagonal.set(size, b);
        this.upperDiagonal.set(size, c);
    }

    /**
     * Sets every interior row (1 to {@code size() - 2}) to the same coefficients.
     *
     * @param a lower-diagonal entry
     * @param b main-diagonal entry
     * @param c upper-diagonal entry
     */
    public void setMidRows(final double a, final double b, final double c) {
        for (int i = 1; i <= this.size() - 2; ++i) {
            this.lowerDiagonal.set(i - 1, a);
            this.diagonal.set(i, b);
            this.upperDiagonal.set(i, c);
        }
    }

    /**
     * Sets the coefficients of the last row.
     *
     * @param a lower-diagonal entry of the last row
     * @param b main-diagonal entry of the last row
     */
    public void setLastRow(final double a, final double b) {
        this.lowerDiagonal.set(this.size() - 2, a);
        this.diagonal.set(this.size() - 1, b);
    }

    /** @return the internal lower-diagonal array (not a copy) */
    public final Array lowerDiagonal() {
        return this.lowerDiagonal;
    }

    /** @return the internal main-diagonal array (not a copy) */
    public final Array diagonal() {
        return this.diagonal;
    }

    /** @return the internal upper-diagonal array (not a copy) */
    public final Array upperDiagonal() {
        return this.upperDiagonal;
    }

    /** @return the attached time setter, or {@code null} if none */
    public TimeSetter getTimeSetter() {
        return this.timeSetter;
    }

    /** @return the number of rows, i.e. the length of the main diagonal */
    @Override
    public int size() {
        return this.diagonal.size();
    }

    /** @return {@code true} when a {@link TimeSetter} is attached */
    @Override
    public boolean isTimeDependent() {
        return this.timeSetter != null;
    }

    /**
     * Forwards {@code t} and this operator to the attached {@link TimeSetter}; does nothing if none is attached.
     *
     * @param t time to set
     */
    @Override
    public void setTime(final double t) {
        if (this.timeSetter != null) {
            this.timeSetter.setTime(t, this);
        }
    }

    /**
     * Returns a new operator whose diagonals are the sums of this operator's and {@code op}'s.
     * The result has no time setter.
     *
     * @param op operator to add; must be a {@code TridiagonalOperator}
     * @return the sum
     */
    public Operator add(final Operator op) {
        final TridiagonalOperator D = (TridiagonalOperator) op;
        final Array low = this.lowerDiagonal.add(D.lowerDiagonal);
        final Array mid = this.diagonal.add(D.diagonal);
        final Array high = this.upperDiagonal.add(D.upperDiagonal);
        return new TridiagonalOperator(low, mid, high);
    }

    /**
     * Returns a new operator whose diagonals are this operator's minus {@code op}'s.
     * The result has no time setter.
     *
     * @param op operator to subtract; must be a {@code TridiagonalOperator}
     * @return the difference
     */
    public Operator subtract(final Operator op) {
        final TridiagonalOperator D = (TridiagonalOperator) op;
        final Array low = this.lowerDiagonal.sub(D.lowerDiagonal);
        final Array mid = this.diagonal.sub(D.diagonal);
        final Array high = this.upperDiagonal.sub(D.upperDiagonal);
        return new TridiagonalOperator(low, mid, high);
    }

    /**
     * Returns a new operator with every coefficient scaled by {@code a}. The result has no time setter.
     *
     * @param a scale factor
     * @return the scaled operator
     */
    public Operator multiply(final double a) {
        final Array low = this.lowerDiagonal.mul(a);
        final Array mid = this.diagonal.mul(a);
        final Array high = this.upperDiagonal.mul(a);
        return new TridiagonalOperator(low, mid, high);
    }

    /**
     * Builds an identity operator: ones on the main diagonal and freshly allocated
     * ({@code new Array(...)}) off-diagonals.
     *
     * @param size number of rows (0 yields the empty operator)
     * @return the identity operator
     */
    public TridiagonalOperator identity(final int size) {
        // Off-diagonals have size - 1 entries, except for the empty operator.
        final int offDiagonalSize = Math.max(size - 1, 0);
        return new TridiagonalOperator(
                new Array(offDiagonalSize), new Array(size).fill(1.0), new Array(offDiagonalSize));
    }

    /**
     * Exchanges the coefficient arrays and time setters of this operator and {@code from}.
     *
     * @param from operator to swap with; must be a {@code TridiagonalOperator}
     */
    public void swap(final Operator from) {
        final TridiagonalOperator D = (TridiagonalOperator) from;
        this.diagonal.swap(D.diagonal);
        this.lowerDiagonal.swap(D.lowerDiagonal);
        this.upperDiagonal.swap(D.upperDiagonal);
        final TimeSetter tmpTimeSetter = this.timeSetter;
        this.timeSetter = D.timeSetter;
        D.timeSetter = tmpTimeSetter;
    }

    /**
     * Computes the matrix-vector product {@code A v}.
     *
     * @param v vector of length {@code size()}
     * @return a new array holding {@code A v}
     * @throws IllegalStateException if {@code v} has the wrong length
     */
    @Override
    public Array applyTo(final Array v) {
        final int n = this.size();
        if (v.size() != n) {
            throw new IllegalStateException("vector of the wrong size (" + v.size() + " instead of " + n + ")");
        }
        // Main-diagonal contribution diag[j] * v[j].
        final Array result = this.diagonal.mul(v);
        if (n < 2) {
            // 0x0 and 1x1 operators have no off-diagonal coefficients.
            return result;
        }
        // First row: only an upper neighbour.
        result.set(0, result.get(0) + this.upperDiagonal.get(0) * v.get(1));
        // Interior rows: both neighbours.
        for (int j = 1; j <= n - 2; ++j) {
            result.set(j, result.get(j)
                    + this.lowerDiagonal.get(j - 1) * v.get(j - 1)
                    + this.upperDiagonal.get(j) * v.get(j + 1));
        }
        // Last row: only a lower neighbour.
        result.set(n - 1, result.get(n - 1) + this.lowerDiagonal.get(n - 2) * v.get(n - 2));
        return result;
    }

    /**
     * Solves {@code A x = rhs} by the Thomas algorithm (tridiagonal Gaussian elimination
     * without pivoting).
     *
     * @param rhs right-hand side of length {@code size()}
     * @return a new array holding the solution {@code x}
     * @throws IllegalStateException if {@code rhs} has the wrong length or a zero pivot is met
     */
    @Override
    public final Array solveFor(final Array rhs) {
        final int n = this.size();
        if (rhs.size() != n) {
            throw new IllegalStateException("rhs of the wrong size (" + rhs.size() + " instead of " + n + ")");
        }
        final Array result = new Array(n);
        if (n == 0) {
            return result;
        }
        // tmp[j] = upper[j-1] divided by the pivot of row j-1; tmp[0] is unused.
        final Array tmp = new Array(n);
        double bet = this.diagonal.first();
        if (bet == 0.0) {
            throw new IllegalStateException("division by zero");
        }
        result.set(0, rhs.first() / bet);
        // Forward sweep: eliminate the lower diagonal.
        for (int j = 1; j < n; ++j) {
            tmp.set(j, this.upperDiagonal.get(j - 1) / bet);
            bet = this.diagonal.get(j) - this.lowerDiagonal.get(j - 1) * tmp.get(j);
            if (bet == 0.0) {
                throw new IllegalStateException("division by zero");
            }
            result.set(j, (rhs.get(j) - this.lowerDiagonal.get(j - 1) * result.get(j - 1)) / bet);
        }
        // Back substitution, including row 0, so a 1x1 system needs no special case.
        for (int j = n - 2; j >= 0; --j) {
            result.set(j, result.get(j) - tmp.get(j + 1) * result.get(j + 1));
        }
        return result;
    }

    /**
     * Primitive-array variant of {@link #solveFor(Array)}, using the same Thomas algorithm.
     *
     * @param rhs right-hand side of length {@code size()}
     * @return a new array holding the solution
     * @throws IllegalStateException if {@code rhs} has the wrong length or a zero pivot is met
     */
    @Override
    public final double[] solveFor(final double[] rhs) {
        final int n = this.size();
        if (rhs.length != n) {
            throw new IllegalStateException("rhs of the wrong size (" + rhs.length + " instead of " + n + ")");
        }
        final double[] result = new double[n];
        if (n == 0) {
            return result;
        }
        // tmp[j] = upper[j-1] divided by the pivot of row j-1; tmp[0] is unused.
        final double[] tmp = new double[n];
        double bet = this.diagonal.first();
        if (bet == 0.0) {
            throw new IllegalStateException("division by zero");
        }
        result[0] = rhs[0] / bet;
        // Forward sweep: eliminate the lower diagonal.
        for (int j = 1; j < n; ++j) {
            tmp[j] = this.upperDiagonal.get(j - 1) / bet;
            bet = this.diagonal.get(j) - this.lowerDiagonal.get(j - 1) * tmp[j];
            if (bet == 0.0) {
                throw new IllegalStateException("division by zero");
            }
            result[j] = (rhs[j] - this.lowerDiagonal.get(j - 1) * result[j - 1]) / bet;
        }
        // Back substitution, including row 0, so a 1x1 system needs no special case.
        for (int j = n - 2; j >= 0; --j) {
            result[j] -= tmp[j + 1] * result[j + 1];
        }
        return result;
    }

    /**
     * Callback attached to a time-dependent {@link TridiagonalOperator}.
     */
    public static interface TimeSetter {
        /**
         * Invoked by {@link TridiagonalOperator#setTime(double)}.
         *
         * @param t        the time passed to {@code setTime}
         * @param operator the operator whose {@code setTime} was called
         */
        public void setTime(double t, TridiagonalOperator operator);
    }
}

