/*
 * Decompiled with CFR 0.152.
 */
package cc.redberry.core.tensorgenerator;

import cc.redberry.core.context.CC;
import cc.redberry.core.indexmapping.IndexMappings;
import cc.redberry.core.indexmapping.Mapping;
import cc.redberry.core.indices.Indices;
import cc.redberry.core.indices.IndicesFactory;
import cc.redberry.core.indices.SimpleIndices;
import cc.redberry.core.indices.StructureOfIndices;
import cc.redberry.core.number.Complex;
import cc.redberry.core.number.Rational;
import cc.redberry.core.solver.frobenius.FrobeniusSolver;
import cc.redberry.core.tensor.ApplyIndexMapping;
import cc.redberry.core.tensor.FastTensors;
import cc.redberry.core.tensor.Product;
import cc.redberry.core.tensor.SimpleTensor;
import cc.redberry.core.tensor.Sum;
import cc.redberry.core.tensor.SumBuilder;
import cc.redberry.core.tensor.Tensor;
import cc.redberry.core.tensor.TensorBuilder;
import cc.redberry.core.tensor.Tensors;
import cc.redberry.core.tensorgenerator.GeneratedTensor;
import cc.redberry.core.tensorgenerator.TensorGeneratorUtils;
import cc.redberry.core.transformations.expand.ExpandTransformation;
import cc.redberry.core.transformations.symmetrization.SymmetrizeTransformation;
import cc.redberry.core.transformations.symmetrization.SymmetrizeUpperLowerIndicesTransformation;
import cc.redberry.core.utils.TensorUtils;
import gnu.trove.map.hash.TIntObjectHashMap;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashSet;
import java.util.List;
import org.apache.commons.math3.util.ArithmeticUtils;

/**
 * Generates a tensor with a prescribed set of free indices as a sum of all products of the given
 * sample tensors whose free indices can be matched onto the target indices, optionally multiplying
 * each term by a freshly generated scalar coefficient.
 *
 * <p>Instances are single-use: all work happens in the constructor and the outcome is exposed via
 * the static {@link #generate} and {@link #generateStructure} entry points.
 */
public class TensorGenerator {
    /** Building blocks (possibly expanded with raised/lowered variants, see {@link #expandSamples}). */
    private final Tensor[] samples;
    /** Lower target indices, sorted ascending. */
    private final int[] lowerArray;
    /** Upper target indices, sorted ascending. */
    private final int[] upperArray;
    /**
     * Coefficient symbols created by this generator. NOTE(review): this list is only written, never
     * read by the public entry points, and symbols presumably introduced by
     * {@code FastTensors.multiplySumElementsOnFactors} are not recorded here, so it is not an
     * authoritative record of the coefficients in {@link #result}.
     */
    private final List<SimpleTensor> coefficients = new ArrayList<SimpleTensor>();
    /** When true, a symmetrized term (a Sum) is multiplied by a single coefficient. */
    private final boolean symmetricForm;
    /** Target free indices; their symmetries are imposed on the result when non-trivial. */
    private final SimpleIndices indices;
    /** The generated tensor, assigned once by {@link #generate()}. */
    private Tensor result;
    /** Whether terms are multiplied by new symbols from {@code CC.generateNewSymbol()}. */
    private final boolean withCoefficients;

    private TensorGenerator(SimpleIndices indices, Tensor[] samples, boolean symmetricForm, boolean withCoefficients, boolean raiseLowerSamples) {
        this.samples = raiseLowerSamples ? TensorGenerator.expandSamples(samples) : samples;
        this.indices = indices;
        this.symmetricForm = symmetricForm;
        this.lowerArray = indices.getLower().toArray();
        this.upperArray = indices.getUpper().toArray();
        this.withCoefficients = withCoefficients;
        Arrays.sort(this.lowerArray);
        Arrays.sort(this.upperArray);
        this.generate();
    }

    /**
     * Enumerates every admissible multiplicity assignment of the samples and sums the resulting
     * terms into {@link #result}.
     *
     * <p>The multiplicities are the solutions of a Frobenius (non-negative integer) system with
     * one equation per index kind: the samples' lower free-index counts, weighted by
     * multiplicity, must add up to the number of lower target indices, and likewise for upper
     * indices.
     */
    private void generate() {
        final int sampleCount = this.samples.length;
        // Each equation row is: per-sample counts followed by the required total.
        int[] lowCounts = new int[sampleCount + 1];
        int[] upCounts = new int[sampleCount + 1];
        for (int s = 0; s < sampleCount; ++s) {
            Indices free = this.samples[s].getIndices().getFree();
            lowCounts[s] = free.getLower().size();
            upCounts[s] = free.getUpper().size();
        }
        lowCounts[sampleCount] = this.lowerArray.length;
        upCounts[sampleCount] = this.upperArray.length;

        FrobeniusSolver fbSolver = new FrobeniusSolver(lowCounts, upCounts);
        SumBuilder sum = new SumBuilder();
        int[] combination;
        while ((combination = fbSolver.take()) != null) {
            sum.put(this.buildTerm(combination));
        }
        Tensor built = sum.build();
        this.result = this.indices.getSymmetries().isTrivial() ? built : this.symmetrize(built);
    }

    /**
     * Builds the term corresponding to one solution of the Frobenius system.
     *
     * <p>Sample {@code i} is used {@code combination[i]} times. Target indices are handed out to
     * the successive factors in sorted order. The product is then symmetrized over upper/lower
     * index placements and, depending on the flags, multiplied by a coefficient.
     *
     * @param combination multiplicity of each sample in the product
     * @return the term to add to the generated sum
     */
    private Tensor buildTerm(int[] combination) {
        List<Tensor> factors = new ArrayList<Tensor>();
        int u = 0; // next unused position in upperArray
        int l = 0; // next unused position in lowerArray
        for (int i = 0; i < combination.length; ++i) {
            for (int j = 0; j < combination[i]; ++j) {
                Tensor sample = this.samples[i];
                Indices free = sample.getIndices().getFree();
                Indices termLow = free.getLower();
                Indices termUp = free.getUpper();
                int upSize = termUp.size();
                int[] oldIndices = new int[upSize + termLow.size()];
                int[] newIndices = new int[oldIndices.length];
                for (int k = 0; k < upSize; ++k) {
                    oldIndices[k] = termUp.get(k);
                    newIndices[k] = this.upperArray[u++];
                }
                for (int k = 0; k < termLow.size(); ++k) {
                    oldIndices[upSize + k] = termLow.get(k);
                    newIndices[upSize + k] = this.lowerArray[l++];
                }
                // A fresh copy of the target indices is passed on every call, as in the original
                // code. NOTE(review): confirm whether applyIndexMapping mutates this array.
                factors.add(ApplyIndexMapping.applyIndexMapping(sample, new Mapping(oldIndices, newIndices), this.indices.getAllIndices().copy()));
            }
        }
        Tensor[] prodArray = factors.toArray(new Tensor[factors.size()]);
        Tensors.resolveAllDummies(prodArray);
        Tensor term = SymmetrizeUpperLowerIndicesTransformation.symmetrizeUpperLowerIndices(Tensors.multiplyAndRenameConflictingDummies(prodArray));
        if (this.symmetricForm || !(term instanceof Sum)) {
            Tensor coefficient;
            if (this.withCoefficients) {
                coefficient = CC.generateNewSymbol();
                this.coefficients.add((SimpleTensor) coefficient);
            } else {
                coefficient = Complex.ONE;
            }
            // A symmetrized Sum is divided by its number of summands, so the coefficient multiplies their mean.
            Tensor normalization = term instanceof Sum ? new Complex(new Rational(1, term.size())) : Complex.ONE;
            return Tensors.multiply(coefficient, term, normalization);
        }
        if (this.withCoefficients) {
            return FastTensors.multiplySumElementsOnFactors((Sum) term);
        }
        return term;
    }

    /**
     * Imposes the symmetries of {@link #indices} on {@code tensor}, expands it, and re-labels the
     * scalar coefficient of every summand.
     *
     * <p>Summands whose coefficients are related by an index mapping receive the same (possibly
     * negated) replacement. Summands that are not products, or have no scalar part, are kept
     * unchanged.
     *
     * @param tensor the unsymmetrized generated tensor
     * @return the symmetrized tensor
     */
    private Tensor symmetrize(Tensor tensor) {
        Tensor symmetrized = new SymmetrizeTransformation(this.indices, false).transform(tensor);
        symmetrized = ExpandTransformation.expand(symmetrized);
        if (!(symmetrized instanceof Sum)) {
            return symmetrized;
        }
        // Coefficients seen so far, bucketed by hash code; each entry is {old coefficient, replacement}.
        TIntObjectHashMap<List<Tensor[]>> seen = new TIntObjectHashMap<List<Tensor[]>>();
        TensorBuilder rebuild = symmetrized.getBuilder();
        for (Tensor t : symmetrized) {
            Tensor[] sc;
            if (!(t instanceof Product) || (sc = ((Product) t).getAllScalarsWithoutFactor()).length == 0) {
                // Nothing to re-label. Keep the term instead of silently dropping it from the sum.
                rebuild.put(t);
                continue;
            }
            assert sc.length == 1;
            Product product = (Product) t;
            Tensor newCoefficient = this.replacementFor(seen, sc[0]);
            rebuild.put(Tensors.multiply(product.getFactor(), newCoefficient, product.getDataSubProduct()));
        }
        return rebuild.build();
    }

    /**
     * Returns the coefficient that should stand in place of {@code oldCoefficient}, and records
     * the choice in {@code seen}.
     *
     * <p>If an index mapping exists from a previously seen coefficient, that coefficient's
     * replacement is reused (negated when the mapping carries a sign). Otherwise the coefficient
     * is kept as-is, unless it is composite and {@link #withCoefficients} is set, in which case a
     * new symbol is generated for it.
     */
    private Tensor replacementFor(TIntObjectHashMap<List<Tensor[]>> seen, Tensor oldCoefficient) {
        int hash = oldCoefficient.hashCode();
        List<Tensor[]> bucket = seen.get(hash);
        if (bucket == null) {
            bucket = new ArrayList<Tensor[]>();
            seen.put(hash, bucket);
        }
        for (Tensor[] pair : bucket) {
            Mapping match = IndexMappings.getFirst(pair[0], oldCoefficient);
            if (match != null) {
                return match.getSign() ? Tensors.negate(pair[1]) : pair[1];
            }
        }
        // Default is the coefficient itself, never a value left over from a previous term.
        Tensor newCoefficient = oldCoefficient;
        if (!(oldCoefficient instanceof SimpleTensor) && this.withCoefficients) {
            newCoefficient = CC.generateNewSymbol();
            this.coefficients.add((SimpleTensor) newCoefficient);
            this.coefficients.removeAll(TensorUtils.getAllSymbols(oldCoefficient));
        }
        bucket.add(new Tensor[]{oldCoefficient, newCoefficient});
        return newCoefficient;
    }

    private Tensor result() {
        return this.result;
    }

    /**
     * Generates a tensor with free indices {@code indices} built from products of {@code samples}.
     *
     * @param indices          free indices of the tensor to generate; their symmetries are imposed
     *                         on the result when non-trivial
     * @param samples          building blocks of the generated products
     * @param symmetricForm    if true, each symmetrized term receives a single coefficient;
     *                         otherwise (with coefficients) each of its summands gets its own
     * @param withCoefficients whether terms are multiplied by newly generated symbols
     * @param raiseLower       whether samples are first expanded into all raised/lowered variants
     * @return the generated tensor
     */
    public static Tensor generate(SimpleIndices indices, Tensor[] samples, boolean symmetricForm, boolean withCoefficients, boolean raiseLower) {
        return new TensorGenerator(indices, samples, symmetricForm, withCoefficients, raiseLower).result();
    }

    /**
     * Same as {@link #generate}, but also returns the symbols found in the generated tensor.
     *
     * <p>Note that these are all symbols from {@code TensorUtils.getAllSymbols} of the result,
     * which may include symbols already present in the samples.
     */
    public static GeneratedTensor generateStructure(SimpleIndices indices, Tensor[] samples, boolean symmetricForm, boolean withCoefficients, boolean raiseLower) {
        Tensor generated = new TensorGenerator(indices, samples, symmetricForm, withCoefficients, raiseLower).result();
        SimpleTensor[] generatedCoefficients = TensorUtils.getAllSymbols(generated).toArray(new SimpleTensor[0]);
        return new GeneratedTensor(generatedCoefficients, generated);
    }

    /**
     * De-duplicates {@code samples} (see {@link Wrapper}) and replaces each distinct sample with
     * the variants returned by {@code TensorGeneratorUtils.allStatesCombinations}.
     */
    private static Tensor[] expandSamples(Tensor[] samples) {
        HashSet<Wrapper> distinct = new HashSet<Wrapper>();
        for (Tensor sample : samples) {
            distinct.add(new Wrapper(sample));
        }
        ArrayList<Tensor> expanded = new ArrayList<Tensor>();
        for (Wrapper w : distinct) {
            // Capacity hint: presumably 2^(number of free indices) raised/lowered variants per sample.
            expanded.ensureCapacity(expanded.size() + ArithmeticUtils.pow(2, w.tensor.getIndices().getFree().size()));
            expanded.addAll(Arrays.asList(TensorGeneratorUtils.allStatesCombinations(w.tensor)));
        }
        return expanded.toArray(new Tensor[expanded.size()]);
    }

    /**
     * Hash-set key that treats two samples as equal when their free indices have the same
     * structure and an index mapping exists between them.
     */
    private static class Wrapper {
        private final Tensor tensor;
        private final StructureOfIndices freeIndices;

        private Wrapper(Tensor tensor) {
            this.tensor = tensor;
            this.freeIndices = StructureOfIndices.create(IndicesFactory.createSimple(null, tensor.getIndices().getFree()));
        }

        @Override
        public boolean equals(Object o) {
            if (this == o) {
                return true;
            }
            if (o == null || this.getClass() != o.getClass()) {
                return false;
            }
            Wrapper wrapper = (Wrapper) o;
            return this.freeIndices.equals(wrapper.freeIndices) && IndexMappings.anyMappingExists(this.tensor, wrapper.tensor);
        }

        /**
         * Delegates to the tensor's hash code. NOTE(review): this assumes tensors related by an
         * index mapping share a hash code, as required by {@link #equals}.
         */
        @Override
        public int hashCode() {
            return this.tensor.hashCode();
        }
    }
}

