/*
 * Decompiled with CFR 0.152.
 */
package org.jquantlib.model.shortrate.onefactormodels;

import org.jquantlib.QL;
import org.jquantlib.instruments.Option;
import org.jquantlib.lang.exceptions.LibraryException;
import org.jquantlib.math.Constants;
import org.jquantlib.math.distributions.NonCentralChiSquaredDistribution;
import org.jquantlib.math.matrixutilities.Array;
import org.jquantlib.methods.lattices.Lattice;
import org.jquantlib.methods.lattices.TrinomialTree;
import org.jquantlib.model.Parameter;
import org.jquantlib.model.TermStructureFittingParameter;
import org.jquantlib.model.shortrate.onefactormodels.CoxIngersollRoss;
import org.jquantlib.model.shortrate.onefactormodels.OneFactorModel;
import org.jquantlib.model.shortrate.onefactormodels.TermStructureConsistentModelClass;
import org.jquantlib.quotes.Handle;
import org.jquantlib.termstructures.Compounding;
import org.jquantlib.termstructures.YieldTermStructure;
import org.jquantlib.time.Frequency;
import org.jquantlib.time.TimeGrid;

/**
 * Extended Cox-Ingersoll-Ross short-rate model.
 * <p>
 * The short rate is represented as {@code r(t) = y(t)^2 + phi(t)} (see the nested {@code Dynamics}
 * class), where the state follows the CIR dynamics of the base class and {@code phi(t)} is a
 * deterministic shift built from the supplied yield term structure.
 * <p>
 * The analytic methods ({@link #A(double, double)}, {@link #discountBondOption}) use the closed-form
 * shift computed by the nested {@code FittingParameter}. The lattice method ({@link #tree(TimeGrid)})
 * instead uses a numerically determined {@link TermStructureFittingParameter}.
 */
public class ExtendedCoxIngersollRoss
extends CoxIngersollRoss {
    private static final String STRIKE_MUST_BE_POSITIVE = "strike must be positive";
    private static final String UNKNOWN_OPTION_TYPE = "unknown option type";

    /** Wraps the yield curve handle the model is fitted to. */
    private final TermStructureConsistentModelClass termstructureConsistentModel;

    /** Analytic deterministic shift phi(t); rebuilt by {@link #generateArguments()}. */
    private Parameter phi;

    /**
     * Creates the model.
     *
     * @param termStructure yield curve from which the deterministic shift phi(t) is built
     * @param theta         CIR parameter theta (long-run level)
     * @param k             CIR parameter k (mean-reversion speed)
     * @param sigma         CIR volatility parameter
     * @param x0            initial value of the CIR state process
     */
    public ExtendedCoxIngersollRoss(Handle<YieldTermStructure> termStructure, double theta, double k, double sigma, double x0) {
        // The base-class constructor takes the initial value first.
        super(x0, theta, k, sigma);
        this.termstructureConsistentModel = new TermStructureConsistentModelClass(termStructure);
        // Must run after termstructureConsistentModel is assigned: phi is built from the curve.
        this.generateArguments();
    }

    /**
     * Returns the short-rate dynamics driven by the analytic shift phi(t).
     *
     * @return a fresh dynamics object built from the current parameter values
     */
    @Override
    public OneFactorModel.ShortRateDynamics dynamics() {
        return new Dynamics(this.phi, this.theta(), this.k(), this.sigma(), this.x0());
    }

    /**
     * Rebuilds the analytic shift phi(t) from the current parameters (theta, k, sigma, x0)
     * and the model's term structure.
     */
    @Override
    public void generateArguments() {
        this.phi = new FittingParameter(this.termstructureConsistentModel.termStructure(),
                this.theta(), this.k(), this.sigma(), this.x0());
    }

    /**
     * Affine bond-price coefficient A(t, s) of the extended model.
     * <p>
     * Starts from the plain CIR coefficient and rescales it by the ratio of market discount
     * factors to CIR-implied ones, and by {@code exp(B(t, s) * phi(t))} for the shift.
     *
     * @param t start time
     * @param s end time
     * @return the extended-model coefficient A(t, s)
     */
    @Override
    public double A(double t, double s) {
        final YieldTermStructure curve = this.currentCurve();
        final double pt = curve.discount(t);
        final double ps = curve.discount(s);
        final double numerator = ps * super.A(0.0, t) * Math.exp(-this.B(0.0, t) * this.x0());
        final double denominator = pt * super.A(0.0, s) * Math.exp(-this.B(0.0, s) * this.x0());
        return super.A(t, s) * Math.exp(this.B(t, s) * this.phi.get(t)) * numerator / denominator;
    }

    /**
     * Prices a European option expiring at {@code t} on a zero-coupon bond paying at {@code s}.
     * <p>
     * For {@code t < QL_EPSILON}, the intrinsic value against the market discount factor
     * P(0, s) is returned. Otherwise the call is priced with the closed-form noncentral
     * chi-squared formula, and the put is obtained by put-call parity:
     * {@code put = call - P(0, s) + strike * P(0, t)}.
     *
     * @param type   option type (Call or Put)
     * @param strike strike price; must be positive (checked via {@code QL.require})
     * @param t      option expiry
     * @param s      bond maturity
     * @return the option value
     */
    @Override
    public double discountBondOption(Option.Type type, double strike, double t, double s) {
        QL.require(strike > 0.0, STRIKE_MUST_BE_POSITIVE);
        final YieldTermStructure curve = this.currentCurve();
        final double discountT = curve.discount(t);
        final double discountS = curve.discount(s);

        if (t < Constants.QL_EPSILON) {
            // Expiry is (numerically) now: only intrinsic value remains.
            switch (type) {
                case Call:
                    return Math.max(discountS - strike, 0.0);
                case Put:
                    return Math.max(strike - discountS, 0.0);
                default:
                    throw new LibraryException(UNKNOWN_OPTION_TYPE);
            }
        }

        final double sigma2 = this.sigma() * this.sigma();
        final double h = Math.sqrt(this.k() * this.k() + 2.0 * sigma2);
        final double r0 = curve.forwardRate(0.0, 0.0, Compounding.Continuous, Frequency.NoFrequency).rate();
        final double b = this.B(t, s);
        final double expht = Math.exp(h * t);

        final double rho = 2.0 * h / (sigma2 * (expht - 1.0));
        final double psi = (this.k() + h) / sigma2;
        final double df = 4.0 * this.k() * this.theta() / sigma2;
        // Initial value of the CIR state: market short rate minus the shift at time zero.
        final double x0Implied = r0 - this.phi.get(0.0);
        final double ncps = 2.0 * rho * rho * x0Implied * expht / (rho + psi + b);
        final double ncpt = 2.0 * rho * rho * x0Implied * expht / (rho + psi);

        final NonCentralChiSquaredDistribution chis = new NonCentralChiSquaredDistribution(df, ncps);
        final NonCentralChiSquaredDistribution chit = new NonCentralChiSquaredDistribution(df, ncpt);

        final double z = Math.log(super.A(t, s) / strike) / b;
        final double call = discountS * chis.op(2.0 * z * (rho + psi + b))
                - strike * discountT * chit.op(2.0 * z * (rho + psi));

        switch (type) {
            case Call:
                return call;
            case Put:
                return call - discountS + strike * discountT;
            default:
                throw new LibraryException(UNKNOWN_OPTION_TYPE);
        }
    }

    /**
     * Builds a trinomial short-rate lattice on the given time grid.
     * <p>
     * Unlike the analytic methods, the lattice uses a numerical
     * {@link TermStructureFittingParameter} for the shift.
     *
     * @param grid time grid of the lattice
     * @return the short-rate tree
     */
    @Override
    public Lattice tree(TimeGrid grid) {
        final TermStructureFittingParameter numericalPhi =
                new TermStructureFittingParameter(this.termstructureConsistentModel.termStructure());
        final Dynamics numericDynamics =
                new Dynamics(numericalPhi, this.theta(), this.k(), this.sigma(), this.x0());
        final TrinomialTree trinomial = new TrinomialTree(numericDynamics.process(), grid, true);
        final TermStructureFittingParameter.NumericalImpl impl =
                (TermStructureFittingParameter.NumericalImpl) numericalPhi.implementation();
        // NOTE(review): the shift values are presumably filled in (fitted to the curve) by the
        // ShortRateTree constructor that receives impl -- confirm in OneFactorModel.
        return new OneFactorModel.ShortRateTree(this, trinomial, numericDynamics, impl, grid);
    }

    /**
     * Returns the yield curve currently linked by the model's term-structure handle.
     */
    private YieldTermStructure currentCurve() {
        return this.termstructureConsistentModel.termStructure().currentLink();
    }

    /**
     * Closed-form deterministic shift phi(t) derived from the curve's instantaneous
     * forward rate and the CIR parameters.
     */
    private static class FittingParameter
    extends TermStructureFittingParameter {
        public FittingParameter(Handle<YieldTermStructure> termStructure, double theta, double k, double sigma, double x0) {
            super(new Impl(termStructure, theta, k, sigma, x0));
        }

        /**
         * Evaluates phi(t) = f(0, t) - 2 k theta (e^{th} - 1) / D - x0 4 h^2 e^{th} / D^2,
         * with D = 2h + (k + h)(e^{th} - 1), h = sqrt(k^2 + 2 sigma^2),
         * and f(0, t) the continuously compounded instantaneous forward rate.
         */
        private static class Impl
        implements Parameter.Impl {
            private final Handle<YieldTermStructure> termStructure;
            private final double theta;
            private final double k;
            private final double x0;
            /** sqrt(k^2 + 2 sigma^2); depends only on constant parameters, so computed once. */
            private final double h;

            public Impl(Handle<YieldTermStructure> termStructure, double theta, double k, double sigma, double x0) {
                this.termStructure = termStructure;
                this.theta = theta;
                this.k = k;
                this.x0 = x0;
                this.h = Math.sqrt(k * k + 2.0 * sigma * sigma);
            }

            /**
             * @param params ignored; the shift has no free parameters of its own
             * @param t      time at which to evaluate the shift
             * @return phi(t)
             */
            @Override
            public double value(Array params, double t) {
                final double forwardRate = this.termStructure.currentLink()
                        .forwardRate(t, t, Compounding.Continuous, Frequency.NoFrequency).rate();
                final double expth = Math.exp(t * this.h);
                final double temp = 2.0 * this.h + (this.k + this.h) * (expth - 1.0);
                return forwardRate
                        - 2.0 * this.k * this.theta * (expth - 1.0) / temp
                        - this.x0 * 4.0 * this.h * this.h * expth / (temp * temp);
            }
        }
    }

    /**
     * CIR dynamics shifted by phi(t): the state variable is {@code sqrt(r - phi(t))}, and the
     * short rate is {@code y^2 + phi(t)}.
     * <p>
     * Kept non-static because it extends the base-class {@code Dynamics}.
     */
    private class Dynamics
    extends CoxIngersollRoss.Dynamics {
        private final Parameter phi;

        public Dynamics(Parameter phi, double theta, double k, double sigma, double x0) {
            super(theta, k, sigma, x0);
            this.phi = phi;
        }

        /** Maps a short rate {@code r} at time {@code t} to the state variable. */
        @Override
        public double variable(double t, double r) {
            return Math.sqrt(r - this.phi.get(t));
        }

        /** Maps a state value {@code y} at time {@code t} back to the short rate. */
        @Override
        public double shortRate(double t, double y) {
            return y * y + this.phi.get(t);
        }
    }
}

