/*
 * Decompiled with CFR 0.152.
 */
package cc.redberry.core.transformations;

import cc.redberry.core.indexgenerator.IndexGeneratorImpl;
import cc.redberry.core.indexmapping.Mapping;
import cc.redberry.core.indices.Indices;
import cc.redberry.core.indices.IndicesFactory;
import cc.redberry.core.indices.IndicesUtils;
import cc.redberry.core.indices.SimpleIndices;
import cc.redberry.core.number.Complex;
import cc.redberry.core.tensor.ApplyIndexMapping;
import cc.redberry.core.tensor.Power;
import cc.redberry.core.tensor.Product;
import cc.redberry.core.tensor.ProductBuilder;
import cc.redberry.core.tensor.SimpleTensor;
import cc.redberry.core.tensor.Sum;
import cc.redberry.core.tensor.SumBuilder;
import cc.redberry.core.tensor.Tensor;
import cc.redberry.core.tensor.TensorField;
import cc.redberry.core.tensor.Tensors;
import cc.redberry.core.tensor.functions.ScalarFunction;
import cc.redberry.core.transformations.EliminateMetricsTransformation;
import cc.redberry.core.transformations.Transformation;
import cc.redberry.core.transformations.substitutions.SubstitutionTransformation;
import cc.redberry.core.transformations.symmetrization.SymmetrizeTransformation;
import cc.redberry.core.utils.ArraysUtils;
import cc.redberry.core.utils.TensorUtils;
import gnu.trove.TIntCollection;
import gnu.trove.set.TIntSet;
import gnu.trove.set.hash.TIntHashSet;

public final class DifferentiateTransformation
implements Transformation {
    private final SimpleTensor[] vars;
    private final Transformation[] expandAndContract;

    public DifferentiateTransformation(SimpleTensor ... vars) {
        this.vars = vars;
        this.expandAndContract = new Transformation[0];
    }

    public DifferentiateTransformation(SimpleTensor[] vars, Transformation[] expandAndContract) {
        this.vars = vars;
        this.expandAndContract = expandAndContract;
    }

    @Override
    public Tensor transform(Tensor t) {
        return DifferentiateTransformation.differentiate(t, this.expandAndContract, this.vars);
    }

    public static Tensor differentiate(Tensor tensor, SimpleTensor var, int order) {
        if (var.getIndices().size() != 0 && order > 1) {
            throw new IllegalArgumentException();
        }
        while (order > 0) {
            tensor = DifferentiateTransformation.differentiate(tensor, new Transformation[0], var);
            --order;
        }
        return tensor;
    }

    public static Tensor differentiate(Tensor tensor, SimpleTensor ... vars) {
        if (vars.length == 0) {
            return tensor;
        }
        if (vars.length == 1) {
            return DifferentiateTransformation.differentiate(tensor, new Transformation[0], vars[0]);
        }
        return DifferentiateTransformation.differentiate(tensor, new Transformation[0], vars);
    }

    /*
     * WARNING - void declaration
     */
    public static Tensor differentiate(Tensor tensor, Transformation[] expandAndContract, SimpleTensor ... vars) {
        void var7_16;
        if (vars.length == 0) {
            return tensor;
        }
        if (vars.length == 1) {
            return DifferentiateTransformation.differentiate(tensor, expandAndContract, vars[0]);
        }
        boolean needRename = false;
        for (SimpleTensor simpleTensor : vars) {
            if (simpleTensor.getIndices().size() == 0) continue;
            needRename = true;
            break;
        }
        Tensor[] resolvedVars = vars;
        if (needRename) {
            TIntHashSet allTensorIndices = TensorUtils.getAllIndicesNamesT(tensor);
            TIntHashSet dummyTensorIndices = new TIntHashSet((TIntCollection)allTensorIndices);
            dummyTensorIndices.removeAll(tensor.getIndices().getFree().getAllIndices().copy());
            needRename = false;
            SimpleTensor[] simpleTensorArray = vars;
            int n = simpleTensorArray.length;
            for (int i = 0; i < n; ++i) {
                SimpleTensor var = simpleTensorArray[i];
                if (!DifferentiateTransformation.containsIndicesNames(allTensorIndices, var.getIndices().getNamesOfDummies()) && !DifferentiateTransformation.containsIndicesNames(dummyTensorIndices, var.getIndices())) continue;
                needRename = true;
                break;
            }
            for (SimpleTensor var : vars) {
                allTensorIndices.addAll(IndicesUtils.getIndicesNames(var.getIndices().getFree()));
            }
            if (needRename) {
                void var7_14;
                resolvedVars = (SimpleTensor[])vars.clone();
                boolean bl = false;
                while (var7_14 < vars.length) {
                    if (!allTensorIndices.isEmpty() && ((SimpleTensor)resolvedVars[var7_14]).getIndices().size() != 0) {
                        if (((SimpleTensor)resolvedVars[var7_14]).getIndices().size() != ((SimpleTensor)resolvedVars[var7_14]).getIndices().getFree().size()) {
                            resolvedVars[var7_14] = (SimpleTensor)ApplyIndexMapping.renameDummy(resolvedVars[var7_14], allTensorIndices.toArray());
                        }
                        allTensorIndices.addAll(IndicesUtils.getIndicesNames(((SimpleTensor)resolvedVars[var7_14]).getIndices()));
                    }
                    ++var7_14;
                }
                tensor = ApplyIndexMapping.renameDummy(tensor, TensorUtils.getAllIndicesNamesT(resolvedVars).toArray(), allTensorIndices);
            }
            tensor = ApplyIndexMapping.renameIndicesOfFieldsArguments(tensor, (TIntSet)allTensorIndices);
        }
        Tensor[] tensorArray = resolvedVars;
        int n = tensorArray.length;
        boolean bl = false;
        while (var7_16 < n) {
            SimpleTensor simpleTensor = tensorArray[var7_16];
            tensor = DifferentiateTransformation.differentiate1(tensor, DifferentiateTransformation.createRule(simpleTensor), expandAndContract);
            ++var7_16;
        }
        return tensor;
    }

    private static Tensor differentiate(Tensor tensor, Transformation[] expandAndContract, SimpleTensor var) {
        if (var.getIndices().size() != 0) {
            TIntHashSet allTensorIndices = TensorUtils.getAllIndicesNamesT(tensor);
            TIntHashSet dummyTensorIndices = new TIntHashSet((TIntCollection)allTensorIndices);
            dummyTensorIndices.removeAll(tensor.getIndices().getFree().getAllIndices().copy());
            if (DifferentiateTransformation.containsIndicesNames(allTensorIndices, var.getIndices().getNamesOfDummies()) || DifferentiateTransformation.containsIndicesNames(dummyTensorIndices, var.getIndices())) {
                allTensorIndices.addAll(IndicesUtils.getIndicesNames(var.getIndices()));
                var = (SimpleTensor)ApplyIndexMapping.renameDummy(var, TensorUtils.getAllIndicesNamesT(tensor).toArray());
                tensor = ApplyIndexMapping.renameDummy(tensor, TensorUtils.getAllIndicesNamesT(var).toArray(), allTensorIndices);
            } else {
                allTensorIndices.addAll(IndicesUtils.getIndicesNames(var.getIndices()));
            }
            tensor = ApplyIndexMapping.renameIndicesOfFieldsArguments(tensor, (TIntSet)allTensorIndices);
        }
        return DifferentiateTransformation.differentiate1(tensor, DifferentiateTransformation.createRule(var), expandAndContract);
    }

    private static boolean containsIndicesNames(TIntHashSet set, Indices indices) {
        for (int i = 0; i < indices.size(); ++i) {
            if (!set.contains(IndicesUtils.getNameWithType(indices.get(i)))) continue;
            return true;
        }
        return false;
    }

    private static boolean containsIndicesNames(TIntHashSet set, int[] indices) {
        for (int i : indices) {
            if (!set.contains(IndicesUtils.getNameWithType(i))) continue;
            return true;
        }
        return false;
    }

    private static Tensor differentiateWithRenaming(Tensor tensor, SimpleTensorDifferentiationRule rule, Transformation[] expandAndEliminate) {
        SimpleTensorDifferentiationRule newRule = rule.newRuleForTensor(tensor);
        tensor = ApplyIndexMapping.renameDummy(tensor, newRule.getForbidden());
        return DifferentiateTransformation.differentiate1(tensor, newRule, expandAndEliminate);
    }

    private static Tensor differentiate1(Tensor tensor, SimpleTensorDifferentiationRule rule, Transformation[] transformations) {
        if (tensor.getClass() == SimpleTensor.class) {
            Tensor temp = rule.differentiateSimpleTensor((SimpleTensor)tensor);
            return DifferentiateTransformation.applyTransformations(temp, transformations);
        }
        if (tensor.getClass() == TensorField.class) {
            TensorField field = (TensorField)tensor;
            SumBuilder result = new SumBuilder(tensor.size());
            for (int i = tensor.size() - 1; i >= 0; --i) {
                Tensor dArg = DifferentiateTransformation.differentiate1(field.get(i), rule, transformations);
                if (TensorUtils.isZero(dArg)) continue;
                result.put(Tensors.multiply(dArg, Tensors.fieldDerivative(field, field.getArgIndices(i).getInverted(), i)));
            }
            return DifferentiateTransformation.applyTransformations(EliminateMetricsTransformation.eliminate(result.build()), transformations);
        }
        if (tensor instanceof Sum) {
            SumBuilder builder = new SumBuilder();
            for (Tensor t : tensor) {
                Tensor temp = DifferentiateTransformation.differentiate1(t, rule, transformations);
                temp = DifferentiateTransformation.applyTransformations(temp, transformations);
                builder.put(temp);
            }
            return builder.build();
        }
        if (tensor instanceof ScalarFunction) {
            Tensor temp = Tensors.multiply(((ScalarFunction)tensor).derivative(), DifferentiateTransformation.differentiateWithRenaming(tensor.get(0), rule, transformations));
            temp = DifferentiateTransformation.applyTransformations(temp, transformations);
            return temp;
        }
        if (tensor instanceof Power) {
            Tensor temp = Tensors.sum(Tensors.multiplyAndRenameConflictingDummies(tensor.get(1), Tensors.pow(tensor.get(0), Tensors.sum(tensor.get(1), Complex.MINUS_ONE)), DifferentiateTransformation.differentiate1(tensor.get(0), rule, transformations)), Tensors.multiplyAndRenameConflictingDummies(tensor, Tensors.log(tensor.get(0)), DifferentiateTransformation.differentiateWithRenaming(tensor.get(1), rule, transformations)));
            temp = DifferentiateTransformation.applyTransformations(temp, transformations);
            return temp;
        }
        if (tensor instanceof Product) {
            SumBuilder result = new SumBuilder();
            for (int i = tensor.size() - 1; i >= 0; --i) {
                Tensor temp = tensor.set(i, DifferentiateTransformation.differentiate1(tensor.get(i), rule, transformations));
                if (rule.var.getIndices().size() != 0) {
                    temp = EliminateMetricsTransformation.eliminate(temp);
                }
                temp = DifferentiateTransformation.applyTransformations(temp, transformations);
                result.put(temp);
            }
            return result.build();
        }
        if (tensor instanceof Complex) {
            return Complex.ZERO;
        }
        throw new UnsupportedOperationException();
    }

    private static Tensor applyTransformations(Tensor tensor, Transformation[] transformations) {
        for (Transformation transformation : transformations) {
            tensor = transformation.transform(tensor);
        }
        return tensor;
    }

    private static SimpleTensorDifferentiationRule createRule(SimpleTensor var) {
        if (var.getIndices().size() == 0) {
            return new SymbolicDifferentiationRule(var);
        }
        return new SymmetricDifferentiationRule(var);
    }

    private static final class SymmetricDifferentiationRule
    extends SimpleTensorDifferentiationRule {
        private final Tensor derivative;
        private final int[] allFreeFrom;
        private final int[] freeVarIndices;

        private SymmetricDifferentiationRule(SimpleTensor var, Tensor derivative, int[] allFreeFrom, int[] freeVarIndices) {
            super(var);
            this.derivative = derivative;
            this.allFreeFrom = allFreeFrom;
            this.freeVarIndices = freeVarIndices;
        }

        SymmetricDifferentiationRule(SimpleTensor var) {
            super(var);
            int i;
            SimpleIndices varIndices = var.getIndices();
            int[] allFreeVarIndices = new int[varIndices.size()];
            int[] allFreeArgIndices = new int[varIndices.size()];
            int length = allFreeArgIndices.length;
            IndexGeneratorImpl indexGenerator = new IndexGeneratorImpl(varIndices);
            for (i = 0; i < length; ++i) {
                byte type = IndicesUtils.getType(varIndices.get(i));
                int state = IndicesUtils.getRawStateInt(varIndices.get(i));
                allFreeVarIndices[i] = IndicesUtils.setRawState(indexGenerator.generate(type), IndicesUtils.inverseIndexState(state));
                allFreeArgIndices[i] = IndicesUtils.setRawState(indexGenerator.generate(type), state);
            }
            int[] allIndices = ArraysUtils.addAll(allFreeVarIndices, allFreeArgIndices);
            SimpleIndices dIndices = IndicesFactory.createSimple(null, allIndices);
            SimpleTensor symmetric = Tensors.simpleTensor("@!@#@##_AS@23@@#", dIndices);
            SimpleIndices allFreeVarIndicesI = IndicesFactory.createSimple(varIndices.getSymmetries(), allFreeVarIndices);
            Tensor derivative = new SymmetrizeTransformation(allFreeVarIndicesI, true).transform(symmetric);
            derivative = ApplyIndexMapping.applyIndexMapping(derivative, new Mapping(allIndices, ArraysUtils.addAll(varIndices.getInverted().getAllIndices().copy(), allFreeArgIndices)), new int[0]);
            ProductBuilder builder = new ProductBuilder(0, length);
            for (i = 0; i < length; ++i) {
                builder.put(Tensors.createMetricOrKronecker(allFreeArgIndices[i], allFreeVarIndices[i]));
            }
            this.derivative = derivative = new SubstitutionTransformation(symmetric, builder.build()).transform(derivative);
            this.freeVarIndices = var.getIndices().getFree().getInverted().getAllIndices().copy();
            this.allFreeFrom = ArraysUtils.addAll(allFreeArgIndices, this.freeVarIndices);
        }

        @Override
        Tensor differentiateSimpleTensorWithoutCheck(SimpleTensor simpleTensor) {
            int[] to = simpleTensor.getIndices().getAllIndices().copy();
            to = ArraysUtils.addAll(to, this.freeVarIndices);
            return ApplyIndexMapping.applyIndexMapping(this.derivative, new Mapping(this.allFreeFrom, to), new int[0]);
        }

        @Override
        SimpleTensorDifferentiationRule newRuleForTensor(Tensor tensor) {
            return new SymmetricDifferentiationRule(this.var, ApplyIndexMapping.renameDummy(this.derivative, TensorUtils.getAllIndicesNamesT(tensor).toArray()), this.allFreeFrom, this.freeVarIndices);
        }

        @Override
        int[] getForbidden() {
            return TensorUtils.getAllIndicesNamesT(this.derivative).toArray();
        }
    }

    private static final class SymbolicDifferentiationRule
    extends SimpleTensorDifferentiationRule {
        private SymbolicDifferentiationRule(SimpleTensor var) {
            super(var);
        }

        @Override
        Tensor differentiateSimpleTensorWithoutCheck(SimpleTensor simpleTensor) {
            return Complex.ONE;
        }

        @Override
        SimpleTensorDifferentiationRule newRuleForTensor(Tensor tensor) {
            return this;
        }

        @Override
        int[] getForbidden() {
            return new int[0];
        }
    }

    private static abstract class SimpleTensorDifferentiationRule {
        protected final SimpleTensor var;

        protected SimpleTensorDifferentiationRule(SimpleTensor var) {
            this.var = var;
        }

        Tensor differentiateSimpleTensor(SimpleTensor simpleTensor) {
            if (simpleTensor.getName() != this.var.getName()) {
                return Complex.ZERO;
            }
            return this.differentiateSimpleTensorWithoutCheck(simpleTensor);
        }

        abstract SimpleTensorDifferentiationRule newRuleForTensor(Tensor var1);

        abstract Tensor differentiateSimpleTensorWithoutCheck(SimpleTensor var1);

        abstract int[] getForbidden();
    }
}

