/*
 * Decompiled with CFR 0.152.
 */
package cc.redberry.core.transformations.factor;

import cc.redberry.core.number.Complex;
import cc.redberry.core.number.Rational;
import cc.redberry.core.tensor.Expression;
import cc.redberry.core.tensor.FastTensors;
import cc.redberry.core.tensor.Power;
import cc.redberry.core.tensor.Product;
import cc.redberry.core.tensor.ProductBuilder;
import cc.redberry.core.tensor.SimpleTensor;
import cc.redberry.core.tensor.Sum;
import cc.redberry.core.tensor.SumBuilder;
import cc.redberry.core.tensor.Tensor;
import cc.redberry.core.tensor.TensorBuilder;
import cc.redberry.core.tensor.Tensors;
import cc.redberry.core.tensor.functions.ScalarFunction;
import cc.redberry.core.tensor.iterator.FromChildToParentIterator;
import cc.redberry.core.tensor.iterator.FromParentToChildIterator;
import cc.redberry.core.tensor.iterator.TreeIteratorAbstract;
import cc.redberry.core.transformations.Transformation;
import cc.redberry.core.transformations.factor.FactorizationEngine;
import cc.redberry.core.transformations.factor.JasFactor;
import cc.redberry.core.transformations.fractions.TogetherTransformation;
import cc.redberry.core.utils.IntArrayList;
import cc.redberry.core.utils.LocalSymbolsProvider;
import cc.redberry.core.utils.TensorUtils;
import gnu.trove.map.hash.TIntObjectHashMap;
import java.math.BigInteger;
import java.util.ArrayList;

/**
 * Transformation that factors sums inside a tensor expression.
 *
 * <p>The symbolic (index-free, see {@link #isSymbolic(Tensor)}) terms of every sum are handed to a
 * {@link FactorizationEngine}; the remaining terms are processed recursively. When the configured
 * engine is a {@link JasFactor}, common factors of the terms are additionally pulled out
 * structurally before the engine is invoked.
 */
public class FactorTransformation
implements Transformation {
    /** Default instance: factors scalar sub-expressions too and uses {@code JasFactor.ENGINE}. */
    public static final FactorTransformation FACTOR = new FactorTransformation(true, JasFactor.ENGINE);
    private final boolean factorScalars;
    private final FactorizationEngine factorizationEngine;

    /**
     * @param factorScalars       if {@code true}, scalar sub-expressions are temporarily replaced by
     *                            fresh symbols before factoring and substituted back afterwards
     * @param factorizationEngine engine used to factor the symbolic parts of sums
     */
    public FactorTransformation(boolean factorScalars, FactorizationEngine factorizationEngine) {
        this.factorScalars = factorScalars;
        this.factorizationEngine = factorizationEngine;
    }

    /** Returns the engine used to factor symbolic sums. */
    public FactorizationEngine getFactorizationEngine() {
        return this.factorizationEngine;
    }

    /**
     * Factors the given tensor.
     *
     * <p>With {@code factorScalars} enabled, the replacements produced by
     * {@code TensorUtils.generateReplacementsOfScalars} (new symbols are named through a
     * {@link LocalSymbolsProvider} with prefix {@code "sclr"}) are applied first. Then the
     * expression is factored, and finally every replacement is applied in reverse (via
     * {@code transpose()}).
     */
    @Override
    public Tensor transform(Tensor tensor) {
        if (!this.factorScalars) {
            return this.factorSymbolicTerms(tensor);
        }
        Expression[] replacements = TensorUtils.generateReplacementsOfScalars(tensor, new LocalSymbolsProvider(tensor, "sclr"));
        for (Expression replacement : replacements) {
            tensor = replacement.transform(tensor);
        }
        tensor = this.factorSymbolicTerms(tensor);
        for (Expression replacement : replacements) {
            tensor = replacement.transpose().transform(tensor);
        }
        return tensor;
    }

    /** Factors {@code tensor} with the given options. */
    public static Tensor factor(Tensor tensor, boolean factorScalars, FactorizationEngine factorizationEngine) {
        return new FactorTransformation(factorScalars, factorizationEngine).transform(tensor);
    }

    /** Factors {@code tensor} using {@code JasFactor.ENGINE}. */
    public static Tensor factor(Tensor tensor, boolean factorScalars) {
        return FactorTransformation.factor(tensor, factorScalars, JasFactor.ENGINE);
    }

    /** Factors {@code tensor} (including scalar sub-expressions) using {@code JasFactor.ENGINE}. */
    public static Tensor factor(Tensor tensor) {
        return FactorTransformation.factor(tensor, true, JasFactor.ENGINE);
    }

    /**
     * Walks the tree top-down. Each sum is split into its symbolic terms, which are factored
     * together by {@link #factorSymbolicTerm(Tensor)}, and the remaining terms, which are
     * processed recursively one by one.
     */
    private Tensor factorSymbolicTerms(Tensor tensor) {
        FromParentToChildIterator iterator = new FromParentToChildIterator(tensor);
        Tensor current;
        while ((current = iterator.next()) != null) {
            if (!(current instanceof Sum)) {
                continue;
            }
            Tensor remainder = current;
            IntArrayList symbolicPositions = new IntArrayList();
            // Scan from the end so that removing position i never shifts a not-yet-visited index.
            for (int i = current.size() - 1; i >= 0; --i) {
                if (!FactorTransformation.isSymbolic(current.get(i))) {
                    continue;
                }
                symbolicPositions.add(i);
                // A non-Sum remainder is the single last term, which is the one being removed.
                remainder = remainder instanceof Sum ? ((Sum) remainder).remove(i) : Complex.ZERO;
            }
            Tensor symbolicPart = ((Sum) current).select(symbolicPositions.toArray());
            symbolicPart = this.factorSymbolicTerm(symbolicPart);
            if (remainder instanceof Sum) {
                SumBuilder sb = new SumBuilder(remainder.size());
                for (Tensor term : remainder) {
                    sb.put(this.factorSymbolicTerms(term));
                }
                remainder = sb.build();
            } else {
                remainder = this.factorSymbolicTerms(remainder);
            }
            iterator.set(Tensors.sum(symbolicPart, remainder));
        }
        return iterator.result();
    }

    /**
     * Factors a purely symbolic expression.
     *
     * <p>For a {@link JasFactor} engine, common factors are first pulled out of every sum
     * (bottom-up, via {@link #factorOut(Tensor)}). Then, top-down, each sum is either factored
     * directly or, when it contains negative powers, first brought to a common denominator with
     * {@link TogetherTransformation}; the sum factors of that result are factored afterwards.
     */
    private Tensor factorSymbolicTerm(Tensor sum) {
        Tensor current;
        if (this.factorizationEngine instanceof JasFactor) {
            FromChildToParentIterator upward = new FromChildToParentIterator(sum);
            while ((current = upward.next()) != null) {
                if (current instanceof Sum) {
                    upward.set(this.factorOut(current));
                }
            }
            sum = upward.result();
        }
        FromParentToChildIterator downward = new FromParentToChildIterator(sum);
        while ((current = downward.next()) != null) {
            if (!(current instanceof Sum)) {
                continue;
            }
            if (!FactorTransformation.needTogether(current)) {
                downward.set(this.factorSum1(current));
                continue;
            }
            current = TogetherTransformation.together(current, this);
            downward.set(current instanceof Product ? this.factorSumFactors(current) : current);
        }
        return downward.result();
    }

    /**
     * Rebuilds {@code product} with every {@link Sum} factor replaced by its factorization.
     * Returns {@code product} itself when it has no sum factor.
     */
    private Tensor factorSumFactors(Tensor product) {
        boolean hasSumFactor = false;
        for (Tensor factor : product) {
            if (factor instanceof Sum) {
                hasSumFactor = true;
                break;
            }
        }
        if (!hasSumFactor) {
            return product;
        }
        TensorBuilder pb = product.getBuilder();
        for (int i = product.size() - 1; i >= 0; --i) {
            Tensor factor = product.get(i);
            pb.put(factor instanceof Sum ? this.factorSum1(factor) : factor);
        }
        return pb.build();
    }

    /**
     * Factors a sum by splitting it (see {@link #reIm(Tensor)}) into terms whose numeric
     * coefficient has a non-zero imaginary part and the rest, then factoring each part separately.
     * The imaginary part is multiplied by {@code i} before being passed to the engine and by
     * {@code -i} afterwards, which leaves its value unchanged.
     */
    private Tensor factorSum1(Tensor sum) {
        Tensor[] parts = FactorTransformation.reIm(sum);
        if (!TensorUtils.isZero(parts[0])) {
            Tensor im = parts[0];
            im = im instanceof Sum ? FastTensors.multiplySumElementsOnFactor((Sum) im, Complex.IMAGINARY_UNIT) : Tensors.multiply(im, Complex.IMAGINARY_UNIT);
            im = this.factorizationEngine.factor(im);
            parts[0] = Tensors.multiply(im, Complex.NEGATIVE_IMAGINARY_UNIT);
        }
        if (!TensorUtils.isZero(parts[1])) {
            parts[1] = this.factorizationEngine.factor(parts[1]);
        }
        return Tensors.sum(parts[0], parts[1]);
    }

    /**
     * Splits a sum into {@code [imaginaryTerms, otherTerms]}, where imaginary terms are complex
     * numbers, or products whose numeric factor has a non-zero imaginary part.
     */
    private static Tensor[] reIm(Tensor sum) {
        IntArrayList imaginary = new IntArrayList(sum.size());
        for (int i = sum.size() - 1; i >= 0; --i) {
            Tensor term = sum.get(i);
            if (term instanceof Complex) {
                if (!((Complex) term).getImaginary().isZero()) {
                    imaginary.add(i);
                }
            } else if (term instanceof Product && !((Product) term).getFactor().getImaginary().isZero()) {
                imaginary.add(i);
            }
        }
        int[] positions = imaginary.toArray();
        Tensor[] parts = new Tensor[2];
        parts[0] = ((Sum) sum).select(positions);
        parts[1] = ((Sum) sum).remove(positions);
        return parts;
    }

    /**
     * Returns {@code true} if {@code t} contains (outside simple tensors) a power with a numeric
     * exponent whose real part is negative.
     */
    private static boolean needTogether(Tensor t) {
        if (t instanceof Power) {
            if (FactorTransformation.needTogether(t.get(0))) {
                return true;
            }
            Tensor exponent = t.get(1);
            // Guard: only numeric exponents can be inspected for sign.
            return exponent instanceof Complex && ((Complex) exponent).getReal().signum() < 0;
        }
        if (t instanceof SimpleTensor) {
            return false;
        }
        for (Tensor child : t) {
            if (FactorTransformation.needTogether(child)) {
                return true;
            }
        }
        return false;
    }

    /**
     * Returns {@code true} if {@code t} has no free indices, is not a {@link ScalarFunction}, and
     * is built only from argument-less simple tensors, numbers, and powers of such expressions with
     * real integer exponents that are not {@code isNumeric()}.
     */
    private static boolean isSymbolic(Tensor t) {
        if (t.getIndices().size() != 0 || t instanceof ScalarFunction) {
            return false;
        }
        if (t instanceof SimpleTensor) {
            return t.size() == 0;
        }
        if (t instanceof Power) {
            if (!FactorTransformation.isSymbolic(t.get(0)) || !TensorUtils.isInteger(t.get(1))) {
                return false;
            }
            Complex exponent = (Complex) t.get(1);
            return exponent.isReal() && !exponent.isNumeric();
        }
        for (Tensor child : t) {
            if (!FactorTransformation.isSymbolic(child)) {
                return false;
            }
        }
        return true;
    }

    /** Applies {@link #factorOut1(Tensor)} to every sum of {@code tensor}, bottom-up. */
    Tensor factorOut(Tensor tensor) {
        FromChildToParentIterator iterator = new FromChildToParentIterator(tensor);
        Tensor current;
        while ((current = iterator.next()) != null) {
            if (current instanceof Sum) {
                iterator.set(this.factorOut1(current));
            }
        }
        return iterator.result();
    }

    /**
     * Returns {@code true} if {@code tensor} is not a sum and is either an integer power of a sum,
     * or a product with at least one such factor.
     */
    private static boolean isProductOfSums(Tensor tensor) {
        if (tensor instanceof Sum) {
            return false;
        }
        if (tensor instanceof Product) {
            for (Tensor factor : tensor) {
                if (FactorTransformation.isIntegerPowerOfSum(factor)) {
                    return true;
                }
            }
        }
        return FactorTransformation.isIntegerPowerOfSum(tensor);
    }

    /** Returns {@code true} for a sum, or a power of a sum with an integer exponent. */
    private static boolean isIntegerPowerOfSum(Tensor tensor) {
        if (tensor instanceof Sum) {
            return true;
        }
        return tensor instanceof Power && tensor.get(0) instanceof Sum && TensorUtils.isInteger(tensor.get(1));
    }

    /**
     * Pulls factors common to all terms out of the sum {@code tensor}.
     *
     * <p>If every term's numeric coefficient is purely imaginary, the sum is first multiplied by
     * {@code -i}, and the result is multiplied back by {@code i} on every return path. Terms that
     * are not products of sums are grouped, factored recursively, and treated as one extra term.
     */
    Tensor factorOut1(Tensor tensor) {
        Boolean allImaginary = null;
        for (Tensor t : tensor) {
            boolean imaginary;
            if (t instanceof Product) {
                imaginary = ((Product) t).getFactor().isImaginary();
            } else if (t instanceof Complex) {
                imaginary = ((Complex) t).isImaginary();
            } else {
                imaginary = false;
            }
            if (allImaginary == null) {
                allImaginary = imaginary;
            } else if (allImaginary != imaginary) {
                allImaginary = false;
            }
        }
        final boolean factorOutImaginaryUnit = allImaginary != null && allImaginary;
        if (factorOutImaginaryUnit) {
            tensor = FastTensors.multiplySumElementsOnFactor((Sum) tensor, Complex.NEGATIVE_IMAGINARY_UNIT);
        }
        if (!(tensor instanceof Sum)) {
            if (factorOutImaginaryUnit) {
                tensor = Tensors.multiply(Complex.IMAGINARY_UNIT, tensor);
            }
            return this.factorOut(tensor);
        }
        Tensor temp = tensor;
        IntArrayList nonProductOfSumsPositions = new IntArrayList();
        for (int i = 0, size = temp.size(); i < size; ++i) {
            if (!FactorTransformation.isProductOfSums(temp.get(i))) {
                nonProductOfSumsPositions.add(i);
            }
        }
        Int pivotPosition = new Int();
        Term[] terms;
        if (nonProductOfSumsPositions.isEmpty() || nonProductOfSumsPositions.size() == temp.size()) {
            terms = FactorTransformation.sum2SplitArray((Sum) temp, pivotPosition);
        } else {
            SumBuilder sb = new SumBuilder();
            // Remove highest positions first so the remaining indices stay valid.
            for (int i = nonProductOfSumsPositions.size() - 1; i >= 0; --i) {
                assert (temp instanceof Sum);
                sb.put(temp.get(nonProductOfSumsPositions.get(i)));
                temp = ((Sum) temp).remove(nonProductOfSumsPositions.get(i));
            }
            Tensor withoutSumsTerm = this.factorSymbolicTerms(sb.build());
            if (FactorTransformation.isProductOfSums(withoutSumsTerm)) {
                temp = Tensors.sum(temp, withoutSumsTerm);
                if (!(temp instanceof Sum)) {
                    // Restore the imaginary unit removed above; previously this path dropped it.
                    return factorOutImaginaryUnit ? Tensors.multiply(Complex.IMAGINARY_UNIT, temp) : temp;
                }
                terms = FactorTransformation.sum2SplitArray((Sum) temp, pivotPosition);
            } else {
                if (!(temp instanceof Sum)) {
                    terms = new Term[]{FactorTransformation.tensor2term(temp), FactorTransformation.tensor2term(withoutSumsTerm)};
                } else {
                    Term[] sumTerms = FactorTransformation.sum2SplitArray((Sum) temp, pivotPosition);
                    terms = new Term[sumTerms.length + 1];
                    System.arraycopy(sumTerms, 0, terms, 0, sumTerms.length);
                    terms[sumTerms.length] = FactorTransformation.tensor2term(withoutSumsTerm);
                }
                int last = terms.length - 1;
                if (terms[pivotPosition.value].factors.length > terms[last].factors.length) {
                    pivotPosition.value = last;
                }
            }
        }
        temp = FactorTransformation.mergeTerms(terms, pivotPosition.value, tensor);
        if (factorOutImaginaryUnit) {
            temp = Tensors.multiply(Complex.IMAGINARY_UNIT, temp);
        }
        return temp;
    }

    /**
     * Finds the factors of the pivot term that occur in every term. Matching is up to sign (via
     * {@code TensorUtils.compare1}) and requires exponents of the same sign; the exponent kept has
     * the smallest magnitude.
     *
     * @return {@code tensor} unchanged if nothing is common, otherwise
     *         {@code (common factors) * (sum of reduced terms)}
     */
    private static Tensor mergeTerms(Term[] terms, int pivotPosition, Tensor tensor) {
        Term pivot = terms[pivotPosition];
        ArrayList<FactorNode> baseFactors = new ArrayList<FactorNode>(pivot.factors.length);
        for (FactorNode node : pivot.factors) {
            FactorNode baseNode = node.clone();
            baseFactors.add(baseNode);
            // Shared holder: its value is filled in once the common exponent is known.
            baseNode.minExponent = new BigInt();
        }
        for (int i = terms.length - 1; i >= 0; --i) {
            if (baseFactors.isEmpty()) {
                return tensor;
            }
            for (int j = baseFactors.size() - 1; j >= 0; --j) {
                FactorNode baseNode = baseFactors.get(j);
                ArrayList<FactorNode> candidates = terms[i].map.get(baseNode.tensor.hashCode());
                if (candidates == null) {
                    baseFactors.remove(j);
                    continue;
                }
                Boolean sign = null;
                for (FactorNode candidate : candidates) {
                    // null: no match; false: equal; true: matches with opposite sign.
                    sign = TensorUtils.compare1(baseNode.tensor, candidate.tensor);
                    if (sign == null) {
                        continue;
                    }
                    BigInteger baseExponent = baseNode.exponent;
                    BigInteger candidateExponent = candidate.exponent;
                    if (baseExponent.signum() != candidateExponent.signum()) {
                        // Stop here: continuing could remove index j a second time.
                        baseFactors.remove(j);
                        break;
                    }
                    if (sign) {
                        candidate.diffSigns = true;
                    }
                    candidate.minExponent = baseNode.minExponent;
                    if ((baseExponent.signum() > 0 && baseExponent.compareTo(candidateExponent) > 0)
                            || (baseExponent.signum() < 0 && baseExponent.compareTo(candidateExponent) < 0)) {
                        baseNode.exponent = candidateExponent;
                    }
                    break;
                }
                if (sign == null) {
                    baseFactors.remove(j);
                }
            }
        }
        if (baseFactors.isEmpty()) {
            return tensor;
        }
        ProductBuilder pb = new ProductBuilder(baseFactors.size(), baseFactors.size());
        for (FactorNode node : baseFactors) {
            // toTensor() must run before minExponent.value is set, so the full exponent is used here.
            pb.put(node.toTensor());
            node.minExponent.value = node.exponent;
        }
        SumBuilder sb = new SumBuilder(tensor.size());
        for (Term term : terms) {
            sb.put(FactorTransformation.nodesToProduct(term.factors));
        }
        return Tensors.multiply(pb.build(), sb.build());
    }

    /**
     * Converts every term of {@code sum} to a {@link Term}, storing in {@code pivotPosition} the
     * index of the term with the fewest factors (ties resolved toward the highest index).
     */
    private static Term[] sum2SplitArray(Sum sum, Int pivotPosition) {
        Term[] terms = new Term[sum.size()];
        int fewestFactors = Integer.MAX_VALUE;
        int pivot = -1;
        for (int i = sum.size() - 1; i >= 0; --i) {
            terms[i] = FactorTransformation.tensor2term(sum.get(i));
            if (terms[i].factors.length < fewestFactors) {
                fewestFactors = terms[i].factors.length;
                pivot = i;
            }
        }
        pivotPosition.value = pivot;
        return terms;
    }

    /** Wraps a product (one node per factor) or any other tensor (single node) as a {@link Term}. */
    private static Term tensor2term(Tensor tensor) {
        if (tensor instanceof Product) {
            FactorNode[] factors = new FactorNode[tensor.size()];
            int i = 0;
            for (Tensor t : tensor) {
                factors[i++] = FactorTransformation.createNode(t);
            }
            return new Term(factors);
        }
        return new Term(new FactorNode[]{FactorTransformation.createNode(tensor)});
    }

    /** Integer powers become {@code (base, exponent)}; anything else becomes {@code (tensor, 1)}. */
    private static FactorNode createNode(Tensor tensor) {
        if (tensor instanceof Power && TensorUtils.isInteger(tensor.get(1))) {
            return new FactorNode(tensor.get(0), ((Rational) ((Complex) tensor.get(1)).getReal()).getNumerator());
        }
        return new FactorNode(tensor, BigInteger.ONE);
    }

    /** Multiplies together {@link FactorNode#toTensor()} of every node. */
    private static Tensor nodesToProduct(FactorNode[] nodes) {
        Tensor[] tensors = new Tensor[nodes.length];
        for (int i = nodes.length - 1; i >= 0; --i) {
            tensors[i] = nodes[i].toTensor();
        }
        return Tensors.multiply(tensors);
    }

    /** Mutable int holder used as an out-parameter. */
    private static final class Int {
        int value;

        private Int() {
        }
    }

    /** Mutable BigInteger holder shared between matching factor nodes. */
    private static final class BigInt {
        BigInteger value;

        private BigInt() {
        }
    }

    /** A factor {@code tensor^exponent} of a term. */
    private static class FactorNode {
        final Tensor tensor;
        BigInteger exponent;
        /** Exponent pulled out of this node, once known (shared with the matching base node). */
        BigInt minExponent;
        /** Set when this node matched the common factor only up to an overall sign. */
        boolean diffSigns = false;

        private FactorNode(Tensor tensor, BigInteger exponent) {
            this.tensor = tensor;
            this.exponent = exponent;
        }

        /**
         * Returns {@code tensor^(exponent - min)} when a pulled-out exponent {@code min} is set
         * (negated if the node matched with opposite sign and {@code min} is odd), otherwise
         * {@code tensor^exponent}.
         */
        Tensor toTensor() {
            BigInteger exponent = this.exponent;
            if (this.minExponent != null && this.minExponent.value != null) {
                exponent = exponent.subtract(this.minExponent.value);
                if (this.diffSigns && this.minExponent.value.testBit(0)) {
                    return Tensors.negate(Tensors.pow(this.tensor, new Complex(exponent)));
                }
            }
            return Tensors.pow(this.tensor, new Complex(exponent));
        }

        @Override
        public boolean equals(Object o) {
            if (!(o instanceof FactorNode)) {
                return false;
            }
            return TensorUtils.equals(((FactorNode) o).tensor, this.tensor);
        }

        @Override
        public int hashCode() {
            return this.tensor.hashCode();
        }

        @Override
        public String toString() {
            return this.tensor + " -> " + this.exponent;
        }

        /** Copies tensor and exponent only (not {@code minExponent}/{@code diffSigns}). */
        @Override
        public FactorNode clone() {
            return new FactorNode(this.tensor, this.exponent);
        }
    }

    /** A term of a sum as an array of factor nodes, indexed by the hash code of each node's tensor. */
    private static class Term {
        final FactorNode[] factors;
        final TIntObjectHashMap<ArrayList<FactorNode>> map;

        private Term(FactorNode[] factors) {
            this.factors = factors;
            this.map = new TIntObjectHashMap<ArrayList<FactorNode>>(factors.length);
            for (FactorNode node : factors) {
                int hash = node.tensor.hashCode();
                ArrayList<FactorNode> bucket = this.map.get(hash);
                if (bucket == null) {
                    bucket = new ArrayList<FactorNode>();
                    this.map.put(hash, bucket);
                }
                bucket.add(node);
            }
        }

        @Override
        public String toString() {
            StringBuilder sb = new StringBuilder();
            for (int i = 0; i < this.factors.length; ++i) {
                if (i != 0) {
                    sb.append('*');
                }
                sb.append('(').append(this.factors[i]).append(')');
            }
            return sb.toString();
        }
    }
}

