/*
 * Decompiled with CFR 0.152.
 */
package cc.redberry.core.transformations.collect;

import cc.redberry.core.groups.permutations.Permutations;
import cc.redberry.core.indexgenerator.IndexGeneratorImpl;
import cc.redberry.core.indexmapping.IndexMapping;
import cc.redberry.core.indexmapping.IndexMappings;
import cc.redberry.core.indexmapping.Mapping;
import cc.redberry.core.indices.Indices;
import cc.redberry.core.indices.IndicesBuilder;
import cc.redberry.core.indices.IndicesFactory;
import cc.redberry.core.indices.IndicesUtils;
import cc.redberry.core.indices.SimpleIndices;
import cc.redberry.core.number.Complex;
import cc.redberry.core.tensor.ApplyIndexMapping;
import cc.redberry.core.tensor.Expression;
import cc.redberry.core.tensor.Product;
import cc.redberry.core.tensor.SimpleTensor;
import cc.redberry.core.tensor.SumBuilder;
import cc.redberry.core.tensor.Tensor;
import cc.redberry.core.tensor.TensorField;
import cc.redberry.core.tensor.Tensors;
import cc.redberry.core.transformations.EliminateMetricsTransformation;
import cc.redberry.core.transformations.Transformation;
import cc.redberry.core.transformations.expand.ExpandPort;
import cc.redberry.core.transformations.powerexpand.PowerUnfoldTransformation;
import cc.redberry.core.utils.ArraysUtils;
import cc.redberry.core.utils.IntArrayList;
import cc.redberry.core.utils.OutputPort;
import cc.redberry.core.utils.TensorUtils;
import gnu.trove.map.hash.TIntObjectHashMap;
import gnu.trove.set.hash.TIntHashSet;
import java.util.ArrayList;
import java.util.Arrays;

public class CollectTransformation
implements Transformation {
    private final TIntHashSet patternsNames = new TIntHashSet();
    private final Transformation powerExpand;
    private final Transformation[] transformations;

    public CollectTransformation(SimpleTensor[] patterns, Transformation[] transformations) {
        this.powerExpand = new PowerUnfoldTransformation(patterns);
        for (SimpleTensor t : patterns) {
            this.patternsNames.add(t.getName());
        }
        this.transformations = transformations;
    }

    public CollectTransformation(SimpleTensor ... patterns) {
        this(patterns, new Transformation[0]);
    }

    @Override
    public Tensor transform(Tensor t) {
        if (t instanceof Expression) {
            return Transformation.Util.applyToEachChild(t, this);
        }
        return this.transform1(t);
    }

    private Tensor transform1(Tensor t) {
        Tensor current;
        SumBuilder notMatched = new SumBuilder();
        TIntObjectHashMap map = new TIntObjectHashMap();
        OutputPort<Tensor> port = ExpandPort.createPort(t);
        block0: while ((current = port.take()) != null) {
            Split toAdd = this.split(current);
            if (toAdd.factors.length == 0) {
                notMatched.put(current);
                continue;
            }
            ArrayList<Split> nodes = (ArrayList<Split>)map.get(toAdd.hashCode);
            if (nodes == null) {
                nodes = new ArrayList<Split>();
                nodes.add(toAdd);
                map.put(toAdd.hashCode, nodes);
                continue;
            }
            for (Split base : nodes) {
                int[] match = CollectTransformation.matchFactors(base.factors, toAdd.factors);
                if (match == null) continue;
                Tensor[] toAddFactors = Permutations.permute(toAdd.factors, match);
                Mapping mapping = IndexMappings.createBijectiveProductPort(toAddFactors, base.factors).take();
                base.summands.add(ApplyIndexMapping.applyIndexMappingAutomatically(toAdd.summands.get(0), mapping, base.forbidden));
                continue block0;
            }
            nodes.add(toAdd);
        }
        Tensor r = Transformation.Util.applySequentially(notMatched.build(), this.transformations);
        notMatched = new SumBuilder();
        notMatched.put(r);
        for (ArrayList splits : map.valueCollection()) {
            for (Split split : splits) {
                notMatched.put(split.toTensor(this.transformations));
            }
        }
        return notMatched.build();
    }

    private boolean match(Tensor t) {
        if (t instanceof SimpleTensor) {
            return this.patternsNames.contains(t.hashCode());
        }
        if (TensorUtils.isPositiveIntegerPower(t)) {
            return this.patternsNames.contains(t.get(0).hashCode());
        }
        return false;
    }

    /*
     * WARNING - void declaration
     * Enabled force condition propagation
     * Lifted jumps to return sites
     */
    private Split split(Tensor tensor) {
        void var3_8;
        Tensor[] factors;
        if (tensor instanceof SimpleTensor || TensorUtils.isPositiveIntegerPowerOfSimpleTensor(tensor)) {
            if (!this.match(tensor)) return new Split(new Tensor[0], tensor);
            factors = new Tensor[]{tensor};
            Complex complex = Complex.ONE;
        } else {
            if (!(tensor instanceof Product) && !TensorUtils.isPositiveIntegerPowerOfProduct(tensor)) return new Split(new Tensor[0], tensor);
            tensor = this.powerExpand.transform(tensor);
            boolean containsMatch = false;
            for (Object t : tensor instanceof Product ? tensor : tensor.get(0)) {
                if (!this.match((Tensor)t)) continue;
                containsMatch = true;
                break;
            }
            if (!containsMatch) {
                return new Split(new Tensor[0], tensor);
            }
            assert (tensor instanceof Product);
            ArrayList<Tensor> factorsList = new ArrayList<Tensor>();
            Tensor tensor2 = tensor;
            for (Tensor t : tensor) {
                void var3_5;
                if (!this.match(t)) continue;
                factorsList.add(t);
                assert (var3_5 != Complex.ONE);
                if (var3_5 instanceof Product) {
                    Tensor tensor3 = ((Product)var3_5).remove(t);
                    continue;
                }
                Complex complex = Complex.ONE;
            }
            factors = factorsList.toArray(new Tensor[factorsList.size()]);
        }
        TIntHashSet freeIndices = new TIntHashSet(IndicesUtils.getIndicesNames(tensor.getIndices().getFree()));
        Indices factorIndices = new IndicesBuilder().append(factors).getIndices();
        TIntHashSet dummies = new TIntHashSet(IndicesUtils.getIntersections(factorIndices.getUpper().toArray(), factorIndices.getLower().toArray()));
        IntArrayList from = new IntArrayList();
        IntArrayList to = new IntArrayList();
        ArrayList<Object> kroneckers = new ArrayList<Object>();
        IndexGeneratorImpl generator = new IndexGeneratorImpl(TensorUtils.getAllIndicesNamesT(tensor).toArray());
        for (int i = 0; i < factors.length; ++i) {
            from.clear();
            to.clear();
            SimpleIndices currentFactorIndices = IndicesFactory.createSimple(null, factors[i].getIndices());
            for (int j = currentFactorIndices.size() - 1; j >= 0; --j) {
                int newIndex;
                int index = currentFactorIndices.get(j);
                if (freeIndices.contains(IndicesUtils.getNameWithType(index))) {
                    newIndex = IndicesUtils.setRawState(IndicesUtils.getRawStateInt(index), generator.generate(IndicesUtils.getType(index)));
                    from.add(index);
                    to.add(newIndex);
                    kroneckers.add(Tensors.createKronecker(index, IndicesUtils.inverseIndexState(newIndex)));
                    continue;
                }
                if (!IndicesUtils.getState(index) || !dummies.contains(IndicesUtils.getNameWithType(index))) continue;
                newIndex = IndicesUtils.setRawState(IndicesUtils.getRawStateInt(index), generator.generate(IndicesUtils.getType(index)));
                from.add(index);
                to.add(newIndex);
                kroneckers.add(Tensors.createKronecker(index, IndicesUtils.inverseIndexState(newIndex)));
            }
            factors[i] = CollectTransformation.applyDirectMapping(factors[i], new StateSensitiveMapping(from.toArray(), to.toArray()));
        }
        kroneckers.add(var3_8);
        Tensor tensor5 = Tensors.multiply(kroneckers.toArray(new Tensor[kroneckers.size()]));
        tensor5 = EliminateMetricsTransformation.eliminate(tensor5);
        return new Split(factors, tensor5);
    }

    static int[] matchFactors(Tensor[] a, Tensor[] b) {
        if (a.length != b.length) {
            return null;
        }
        int begin = 0;
        int length = a.length;
        int[] permutation = new int[length];
        Arrays.fill(permutation, -1);
        for (int i = 1; i <= length; ++i) {
            if (i != length && a[i].hashCode() == b[i - 1].hashCode()) continue;
            if (i - 1 != begin) {
                int n = begin;
                while (n < i) {
                    block8: {
                        for (int j = begin; j < i; ++j) {
                            if (permutation[j] != -1 || !CollectTransformation.matchSimpleTensors(a[n], b[j])) {
                                continue;
                            }
                            break block8;
                        }
                        return null;
                    }
                    permutation[j] = n++;
                }
            } else {
                if (!CollectTransformation.matchSimpleTensors(a[i - 1], b[i - 1])) {
                    return null;
                }
                permutation[i - 1] = i - 1;
            }
            begin = i;
        }
        return Permutations.inverse(permutation);
    }

    private static boolean matchSimpleTensors(Tensor a, Tensor b) {
        if (a.getClass() != b.getClass()) {
            return false;
        }
        if (a.hashCode() != b.hashCode()) {
            return false;
        }
        if (TensorUtils.isPositiveIntegerPowerOfSimpleTensor(a)) {
            return TensorUtils.isPositiveIntegerPowerOfSimpleTensor(b) && a.get(1).equals(b.get(1)) && CollectTransformation.matchSimpleTensors(a.get(0), b.get(0));
        }
        if (a instanceof TensorField) {
            for (int i = a.size() - 1; i >= 0; --i) {
                if (IndexMappings.positiveMappingExists(a.get(i), b.get(i))) continue;
                return false;
            }
        }
        return true;
    }

    private static Tensor applyDirectMapping(Tensor t, DirectIndexMapping mapping) {
        if (t instanceof SimpleTensor) {
            SimpleTensor st = (SimpleTensor)t;
            SimpleIndices newIndices = st.getIndices().applyIndexMapping(mapping);
            if (t instanceof TensorField) {
                return Tensors.field(st.getName(), newIndices, ((TensorField)st).getArgIndices(), ((TensorField)st).getArguments());
            }
            return Tensors.simpleTensor(st.getName(), newIndices);
        }
        assert (t.getIndices().size() == 0);
        return t;
    }

    private static final class StateSensitiveMapping
    extends DirectIndexMapping {
        private StateSensitiveMapping(int[] from, int[] to) {
            super(from, to);
        }

        @Override
        public int map(int from) {
            int index = Arrays.binarySearch(this.from, from);
            if (index >= 0) {
                return this.to[index];
            }
            return from;
        }
    }

    private static abstract class DirectIndexMapping
    implements IndexMapping {
        final int[] from;
        final int[] to;

        private DirectIndexMapping(int[] from, int[] to) {
            ArraysUtils.quickSort(from, to);
            this.from = from;
            this.to = to;
        }
    }

    private static final class Split {
        final Tensor[] factors;
        final ArrayList<Tensor> summands = new ArrayList();
        final int hashCode;
        final int[] forbidden;

        private Split(Tensor[] factors, Tensor summand) {
            this.factors = factors;
            this.summands.add(summand);
            Arrays.sort(factors);
            this.hashCode = Arrays.hashCode(factors);
            this.forbidden = IndicesUtils.getIndicesNames(new IndicesBuilder().append(factors).getIndices());
        }

        public int hashCode() {
            return this.hashCode;
        }

        Tensor toTensor(Transformation[] transformations) {
            Tensor sum = Transformation.Util.applySequentially(Tensors.sum(this.summands.toArray(new Tensor[this.summands.size()])), transformations);
            Tensor[] ms = new Tensor[this.factors.length + 1];
            ms[ms.length - 1] = sum;
            System.arraycopy(this.factors, 0, ms, 0, this.factors.length);
            return Tensors.multiply(ms);
        }

        public String toString() {
            return Tensors.multiply(this.factors) + " : " + Tensors.sum(this.summands.toArray(new Tensor[this.summands.size()]));
        }
    }
}

