/*
 * Decompiled with CFR 0.152.
 */
package cc.redberry.core.transformations.expand;

import cc.redberry.core.number.Complex;
import cc.redberry.core.tensor.ApplyIndexMapping;
import cc.redberry.core.tensor.FastTensors;
import cc.redberry.core.tensor.Power;
import cc.redberry.core.tensor.Product;
import cc.redberry.core.tensor.Sum;
import cc.redberry.core.tensor.SumBuilder;
import cc.redberry.core.tensor.Tensor;
import cc.redberry.core.tensor.Tensors;
import cc.redberry.core.transformations.Transformation;
import cc.redberry.core.transformations.expand.ExpandTransformation;
import cc.redberry.core.utils.ArraysUtils;
import cc.redberry.core.utils.OutputPort;
import cc.redberry.core.utils.TensorUtils;
import gnu.trove.TIntCollection;
import gnu.trove.set.hash.TIntHashSet;
import java.util.ArrayList;
import java.util.Arrays;

public final class ExpandUtils {
    /**
     * Transformation that expands only the indexless part of a product.
     * <p>
     * Any tensor that is not a {@link Product} is returned unchanged. For a product, its
     * indexless sub-product is passed to {@code ExpandTransformation.expand} and the result
     * is multiplied by the product's data sub-product.
     */
    public static final Transformation expandIndexlessSubproduct = new Transformation(){

        @Override
        public Tensor transform(Tensor t) {
            if (t instanceof Product) {
                Product product = (Product)t;
                Tensor expandedIndexless = ExpandTransformation.expand(product.getIndexlessSubProduct());
                return Tensors.multiply(expandedIndexless, product.getDataSubProduct());
            }
            // Nothing to split for non-products.
            return t;
        }
    };

    /**
     * Multiplies two sums term by term, together with a set of common factors, and collects
     * the transformed products into a new sum.
     *
     * @param s1              first sum
     * @param s2              second sum
     * @param factors         extra factors multiplied into every pairwise product
     * @param transformations transformations applied, in order, to every generated term
     * @return the tensor built by the {@link SumBuilder} from all transformed terms
     */
    public static Tensor expandPairOfSums(Sum s1, Sum s2, Tensor[] factors, Transformation[] transformations) {
        ExpandPairPort pairs = new ExpandPairPort(s1, s2, factors);
        SumBuilder result = new SumBuilder(s1.size() * s2.size());
        // The port signals exhaustion by returning null.
        for (Tensor term = pairs.take(); term != null; term = pairs.take()) {
            result.put(ExpandUtils.apply(transformations, term));
        }
        return result.build();
    }

    /**
     * Multiplies two sums term by term without any additional common factors.
     *
     * @param s1              first sum
     * @param s2              second sum
     * @param transformations transformations applied, in order, to every generated term
     * @return result of the four-argument overload called with an empty factor array
     */
    public static Tensor expandPairOfSums(Sum s1, Sum s2, Transformation[] transformations) {
        Tensor[] noFactors = new Tensor[0];
        return ExpandUtils.expandPairOfSums(s1, s2, noFactors, transformations);
    }

    /**
     * Expands a product whose factors may contain sums.
     * <p>
     * The product is split into its indexless and data sub-products, and the method decides
     * which of the two parts contains sums that have to be expanded:
     * <ul>
     * <li>neither part contains a sum: the product is returned unchanged;</li>
     * <li>only the indexless part: it is expanded and multiplied by the untouched data part;</li>
     * <li>only the data part: it is expanded and the indexless part is multiplied into it;</li>
     * <li>both parts: they are expanded separately, unless the indexless part contains a sum
     * some of whose terms carry indices (see {@link #sumContainsIndexed(Tensor)}); such sums
     * are moved to the data factors and expanded together with them.</li>
     * </ul>
     *
     * @param product         product to expand
     * @param transformations transformations applied to the terms produced during expansion
     * @return the expanded tensor, or {@code product} itself when nothing needs expanding
     */
    public static Tensor expandProductOfSums(Product product, Transformation[] transformations) {
        Tensor indexless = product.getIndexlessSubProduct();
        Tensor data = product.getDataSubProduct();
        boolean expandIndexless = false;
        boolean expandData = false;
        // Set when the indexless part holds a sum with indexed terms; such a sum must be
        // expanded together with the data factors.
        boolean containsIndexlessSumNeededExpand = false;
        if (indexless instanceof Sum && ExpandUtils.sumContainsIndexed(indexless)) {
            containsIndexlessSumNeededExpand = true;
            expandIndexless = true;
            expandData = true;
        }
        if (indexless instanceof Product) {
            for (Tensor t : indexless) {
                if (!(t instanceof Sum)) continue;
                if (ExpandUtils.sumContainsIndexed(t)) {
                    containsIndexlessSumNeededExpand = true;
                    expandData = true;
                    expandIndexless = true;
                    break;
                }
                expandIndexless = true;
            }
        }
        if (!expandData) {
            if (data instanceof Sum) {
                expandData = true;
            }
            if (data instanceof Product) {
                for (Tensor t : data) {
                    if (!(t instanceof Sum)) continue;
                    expandData = true;
                    break;
                }
            }
        }
        if (!expandData && !expandIndexless) {
            return product;
        }
        if (!expandData) {
            // Only the indexless part contains sums.
            return Tensors.multiply(ExpandUtils.expandProductOfSums1(indexless, transformations, false), data);
        }
        if (!expandIndexless) {
            // Only the data part contains sums.
            Tensor newData = ExpandUtils.expandProductOfSums1(data, transformations, true);
            if (newData instanceof Sum) {
                return FastTensors.multiplySumElementsOnFactorAndExpand((Sum)newData, indexless);
            }
            return expandIndexlessSubproduct.transform(Tensors.multiply(indexless, newData));
        }
        if (!containsIndexlessSumNeededExpand) {
            indexless = ExpandUtils.expandProductOfSums1(indexless, transformations, false);
            data = ExpandUtils.expandProductOfSums1(data, transformations, true);
        } else {
            // Element type must be Tensor: the list is filled with tensors and handed to
            // expandProductOfSums1, which takes an Iterable<Tensor>.
            ArrayList<Tensor> dataList;
            if (data instanceof Product) {
                dataList = new ArrayList<Tensor>(Arrays.asList(data.toArray()));
            } else {
                dataList = new ArrayList<Tensor>();
                dataList.add(data);
            }
            if (indexless instanceof Sum) {
                // The whole indexless part is a sum with indexed terms: expand it with the data.
                dataList.add(indexless);
                indexless = Complex.ONE;
                data = ExpandUtils.expandProductOfSums1(dataList, transformations, true);
            } else {
                assert (indexless instanceof Product);
                ArrayList<Tensor> indexlessList = new ArrayList<Tensor>(indexless.size());
                expandIndexless = false;
                for (Tensor in : indexless) {
                    if (ExpandUtils.sumContainsIndexed(in)) {
                        // Sums with indexed terms go to the data factors.
                        dataList.add(in);
                        continue;
                    }
                    if (in instanceof Sum) {
                        expandIndexless = true;
                    }
                    indexlessList.add(in);
                }
                indexless = expandIndexless ? ExpandUtils.expandProductOfSums1(indexlessList, transformations, false) : Tensors.multiply(indexlessList.toArray(new Tensor[indexlessList.size()]));
                data = ExpandUtils.expandProductOfSums1(dataList, transformations, true);
            }
        }
        if (data instanceof Sum) {
            return FastTensors.multiplySumElementsOnFactorAndExpand((Sum)data, indexless);
        }
        return Tensors.multiply(indexless, data);
    }

    /**
     * Applies the given transformations one after another, feeding the output of each into
     * the next.
     *
     * @param transformations transformations to apply, in array order
     * @param tensor          tensor to start from
     * @return the result of the last transformation, or {@code tensor} when the array is empty
     */
    public static Tensor apply(Transformation[] transformations, Tensor tensor) {
        Tensor current = tensor;
        for (int i = 0; i < transformations.length; ++i) {
            current = transformations[i].transform(current);
        }
        return current;
    }

    /**
     * Expands the product of the supplied factors by multiplying out the sums among them.
     * <p>
     * The factors are partitioned into sums and everything else. Sums are expanded pairwise in
     * iteration order; whenever a pairwise expansion yields something that is not a
     * {@link Sum}, that result joins the non-sum factors and accumulation restarts from the
     * next sum. The non-sum factors are multiplied in during the final step.
     *
     * @param tensor          factors to multiply; when this is itself a {@link Tensor} that is
     *                        not a {@link Product}, it is returned unchanged
     * @param transformations transformations applied to the generated terms
     * @param indexed         when {@code true}, {@link #expandIndexlessSubproduct} is prepended
     *                        to the transformations used for pairwise expansion and the
     *                        "and expand" variant is used for multiplying a single sum
     * @return the expanded tensor
     */
    public static Tensor expandProductOfSums1(Iterable<Tensor> tensor, Transformation[] transformations, boolean indexed) {
        Transformation[] pairTransformations;
        if (indexed) {
            pairTransformations = ArraysUtils.addAll(new Transformation[]{expandIndexlessSubproduct}, transformations);
        } else {
            pairTransformations = transformations;
        }

        boolean isTensor = tensor instanceof Tensor;
        if (isTensor && !(tensor instanceof Product)) {
            return (Tensor)tensor;
        }
        int capacity = isTensor ? ((Tensor)tensor).size() : 10;

        // Separate the sums from all other factors.
        ArrayList<Tensor> others = new ArrayList<Tensor>(capacity);
        ArrayList<Sum> sums = new ArrayList<Sum>(capacity);
        for (Tensor factor : tensor) {
            if (factor instanceof Sum) {
                sums.add((Sum)factor);
            } else {
                others.add(factor);
            }
        }

        int sumCount = sums.size();
        if (sumCount == 0) {
            return isTensor ? (Tensor)tensor : Tensors.multiply(others.toArray(new Tensor[others.size()]));
        }
        if (sumCount == 1) {
            return ExpandUtils.multiplySumOnRemainingFactors(sums.get(0), others, transformations, indexed);
        }

        // Fold all sums except the last one; null means "no sum accumulated at the moment".
        Tensor accumulated = sums.get(0);
        for (int i = 1; i < sumCount - 1; ++i) {
            Sum next = sums.get(i);
            if (accumulated == null) {
                accumulated = next;
                continue;
            }
            Tensor expanded = ExpandUtils.expandPairOfSums((Sum)accumulated, next, pairTransformations);
            if (expanded instanceof Sum) {
                accumulated = expanded;
            } else {
                // The expansion is no longer a sum: treat it as an ordinary factor.
                others.add(expanded);
                accumulated = null;
            }
        }

        // The last sum is combined together with the collected non-sum factors.
        Sum last = sums.get(sumCount - 1);
        if (accumulated == null) {
            return ExpandUtils.multiplySumOnRemainingFactors(last, others, transformations, indexed);
        }
        return ExpandUtils.expandPairOfSums((Sum)accumulated, last, others.toArray(new Tensor[others.size()]), pairTransformations);
    }

    /** Multiplies every term of {@code sum} by the product of {@code factors}, using the variant selected by {@code indexed}. */
    private static Tensor multiplySumOnRemainingFactors(Sum sum, ArrayList<Tensor> factors, Transformation[] transformations, boolean indexed) {
        Tensor factor = Tensors.multiply(factors.toArray(new Tensor[factors.size()]));
        if (indexed) {
            return ExpandUtils.multiplySumElementsOnFactorAndExpand(sum, factor, transformations);
        }
        return ExpandUtils.multiplySumElementsOnFactor(sum, factor, transformations);
    }

    /**
     * Multiplies every term of a sum by a factor, expanding the indexless part of each
     * product, and applies the given transformations.
     *
     * @param sum             sum whose terms are multiplied
     * @param factor          factor to multiply by
     * @param transformations transformations applied to the result (per term when the sum and
     *                        the factor have intersecting indices, otherwise to the whole result)
     * @return {@code Complex.ZERO} for a zero factor, {@code sum} itself for a unit factor,
     *         otherwise the transformed product
     * @throws IllegalArgumentException if {@code factor} is a {@link Sum} that has indices
     */
    public static Tensor multiplySumElementsOnFactorAndExpand(Sum sum, Tensor factor, Transformation[] transformations) {
        if (TensorUtils.isZero(factor)) {
            return Complex.ZERO;
        }
        if (TensorUtils.isOne(factor)) {
            return sum;
        }
        boolean indexedSumFactor = factor instanceof Sum && factor.getIndices().size() != 0;
        if (indexedSumFactor) {
            throw new IllegalArgumentException();
        }
        if (!TensorUtils.haveIndicesIntersections(sum, factor)) {
            // No shared indices: delegate to the FastTensors implementation.
            Tensor fast = FastTensors.multiplySumElementsOnFactorAndExpand(sum, factor);
            return ExpandUtils.apply(transformations, fast);
        }
        // Shared indices: build each product through Tensors.multiply, term by term.
        SumBuilder result = new SumBuilder(sum.size());
        for (Tensor term : sum) {
            Tensor product = Tensors.multiply(term, factor);
            Tensor expanded = expandIndexlessSubproduct.transform(product);
            result.put(ExpandUtils.apply(transformations, expanded));
        }
        return result.build();
    }

    /**
     * Multiplies every term of a sum by a factor and applies the given transformations.
     *
     * @param sum             sum whose terms are multiplied
     * @param factor          factor to multiply by
     * @param transformations transformations applied to the result (per term when the sum and
     *                        the factor have intersecting indices, otherwise to the whole result)
     * @return {@code Complex.ZERO} for a zero factor, {@code sum} itself for a unit factor,
     *         otherwise the transformed product
     */
    public static Tensor multiplySumElementsOnFactor(Sum sum, Tensor factor, Transformation[] transformations) {
        if (TensorUtils.isZero(factor)) {
            return Complex.ZERO;
        }
        if (TensorUtils.isOne(factor)) {
            return sum;
        }
        if (!TensorUtils.haveIndicesIntersections(sum, factor)) {
            // No shared indices: delegate to the FastTensors implementation.
            Tensor fast = FastTensors.multiplySumElementsOnFactor(sum, factor);
            return ExpandUtils.apply(transformations, fast);
        }
        // Shared indices: build each product through Tensors.multiply, term by term.
        SumBuilder result = new SumBuilder(sum.size());
        for (Tensor term : sum) {
            Tensor product = Tensors.multiply(term, factor);
            result.put(ExpandUtils.apply(transformations, product));
        }
        return result.build();
    }

    /**
     * Tells whether a tensor is a power that can be expanded: a {@link Power} whose first
     * argument is a {@link Sum} and whose second argument passes {@code TensorUtils.isInteger}.
     *
     * @param t tensor to test
     * @return {@code true} if all three conditions hold
     */
    public static boolean isExpandablePower(Tensor t) {
        if (!(t instanceof Power)) {
            return false;
        }
        Tensor base = t.get(0);
        if (!(base instanceof Sum)) {
            return false;
        }
        return TensorUtils.isInteger(t.get(1));
    }

    /**
     * Tells whether a tensor is a {@link Sum} with at least one term that has indices.
     *
     * @param t tensor to test
     * @return {@code true} if {@code t} is a sum and some term reports a non-empty index set;
     *         {@code false} otherwise, including when {@code t} is not a sum
     */
    public static boolean sumContainsIndexed(Tensor t) {
        if (t instanceof Sum) {
            for (Tensor term : t) {
                if (term.getIndices().size() != 0) {
                    return true;
                }
            }
        }
        return false;
    }

    /**
     * Expands an integer power of a sum by repeated pairwise multiplication, without renaming
     * any indices.
     * <p>
     * The accumulated result is multiplied by {@code argument} up to {@code power - 1} times.
     * If an intermediate result stops being a {@link Sum}, the remaining power of
     * {@code argument} is built with {@code Tensors.pow}, transformed, and multiplied in as a
     * single factor.
     *
     * @param argument        sum to raise to a power
     * @param power           exponent; for values below 2 the loop does not run and
     *                        {@code argument} is returned as is
     * @param transformations transformations applied to the generated terms
     * @return the expanded tensor
     */
    public static Tensor expandSymbolicPower(Sum argument, int power, Transformation[] transformations) {
        // Kept as a Sum so it can be passed to expandPairOfSums, which only accepts sums.
        Sum accumulated = argument;
        for (int i = power - 1; i >= 1; --i) {
            Tensor expanded = ExpandUtils.expandPairOfSums(accumulated, argument, transformations);
            if (!(expanded instanceof Sum)) {
                // Pairwise expansion is no longer possible; i - 1 multiplications remain.
                Tensor remaining = ExpandUtils.apply(transformations, Tensors.pow((Tensor)argument, i - 1));
                return Tensors.multiply(expanded, remaining);
            }
            accumulated = (Sum)expanded;
        }
        return accumulated;
    }

    /**
     * Expands an integer power of a sum by repeated pairwise multiplication, renaming the dummy
     * indices of each new copy of the argument before it is multiplied in.
     * <p>
     * The set of forbidden index names starts from {@code forbiddenIndices} plus all index
     * names of {@code argument}, and is passed to {@code ApplyIndexMapping.renameDummy} on
     * every step. If an intermediate result stops being a {@link Sum}, the remaining power of
     * {@code argument} is built with {@code Tensors.pow}, transformed, and multiplied in as a
     * single factor.
     *
     * @param argument         sum to raise to a power
     * @param power            exponent; for values below 2 the loop does not run and
     *                         {@code argument} is returned as is
     * @param forbiddenIndices index names that renamed dummy indices must avoid
     * @param transformations  transformations applied to the generated terms
     * @return the expanded tensor
     */
    public static Tensor expandPower(Sum argument, int power, int[] forbiddenIndices, Transformation[] transformations) {
        TIntHashSet forbidden = new TIntHashSet(forbiddenIndices);
        TIntHashSet argIndices = TensorUtils.getAllIndicesNamesT(argument);
        forbidden.ensureCapacity(argIndices.size() * power);
        forbidden.addAll((TIntCollection)argIndices);

        // Kept as a Sum so it can be passed to expandPairOfSums, which only accepts sums.
        Sum accumulated = argument;
        for (int i = power - 1; i >= 1; --i) {
            Sum renamed = (Sum)ApplyIndexMapping.renameDummy((Tensor)argument, forbidden.toArray(), forbidden);
            Tensor expanded = ExpandUtils.expandPairOfSums(accumulated, renamed, transformations);
            if (!(expanded instanceof Sum)) {
                // Pairwise expansion is no longer possible; i - 1 multiplications remain.
                Tensor remaining = ExpandUtils.apply(transformations, Tensors.pow((Tensor)argument, i - 1));
                return Tensors.multiply(expanded, remaining);
            }
            accumulated = (Sum)expanded;
        }
        return accumulated;
    }

    /**
     * Output port that lazily produces every pairwise product of the terms of two sums,
     * optionally multiplied by a fixed set of extra factors.
     * <p>
     * Products are emitted in row-major order: for each term of the first sum, all terms of
     * the second sum. The position is kept in a plain field that {@link #take()} updates
     * without any synchronization.
     */
    public static final class ExpandPairPort
    implements OutputPort<Tensor> {
        private final Tensor sum1;
        private final Tensor sum2;
        private final Tensor[] factors;
        /** Number of products already handed out. */
        private long index = 0L;

        /**
         * Creates a port over the pairwise products of two sums with no extra factors.
         *
         * @param s1 first sum
         * @param s2 second sum
         */
        public ExpandPairPort(Sum s1, Sum s2) {
            this(s1, s2, new Tensor[0]);
        }

        /**
         * Creates a port over the pairwise products of two sums, each multiplied by
         * {@code factors}.
         *
         * @param s1      first sum
         * @param s2      second sum
         * @param factors factors multiplied into every product; the array is stored as is
         */
        public ExpandPairPort(Sum s1, Sum s2, Tensor[] factors) {
            this.sum1 = s1;
            this.sum2 = s2;
            this.factors = factors;
        }

        /**
         * Returns the next pairwise product.
         *
         * @return the next product, or {@code null} once all products have been produced
         */
        @Override
        public Tensor take() {
            int size2 = this.sum2.size();
            // Widen before multiplying so the total cannot overflow int for large sums.
            long total = (long)this.sum1.size() * (long)size2;
            if (this.index >= total) {
                return null;
            }
            int i1 = (int)(this.index / (long)size2);
            int i2 = (int)(this.index % (long)size2);
            ++this.index;
            Tensor first = this.sum1.get(i1);
            Tensor second = this.sum2.get(i2);
            if (this.factors.length == 0) {
                return Tensors.multiply(first, second);
            }
            return Tensors.multiply(ArraysUtils.addAll(this.factors, first, second));
        }
    }
}

