/*
 * Decompiled with CFR 0.152.
 */
package cc.redberry.core.parser.preprocessor;

import cc.redberry.core.context.CC;
import cc.redberry.core.context.NameAndStructureOfIndices;
import cc.redberry.core.context.NameDescriptor;
import cc.redberry.core.indexgenerator.IndexGeneratorImpl;
import cc.redberry.core.indices.IndexType;
import cc.redberry.core.indices.Indices;
import cc.redberry.core.indices.IndicesFactory;
import cc.redberry.core.indices.SimpleIndices;
import cc.redberry.core.indices.StructureOfIndices;
import cc.redberry.core.parser.ParseToken;
import cc.redberry.core.parser.ParseTokenSimpleTensor;
import cc.redberry.core.parser.ParseTokenTensorField;
import cc.redberry.core.parser.ParseTokenTransformer;
import cc.redberry.core.parser.ParseUtils;
import cc.redberry.core.parser.TokenType;
import cc.redberry.core.tensor.SimpleTensor;
import cc.redberry.core.utils.ArraysUtils;
import cc.redberry.core.utils.BitArray;
import cc.redberry.core.utils.IntArrayList;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.EnumSet;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;

/**
 * Parse-tree transformer that inserts indices the user omitted from matrix-like tensors.
 *
 * <p>Rules are registered with {@code addInsertionRule}. Each rule names a tensor (its name and
 * full structure of indices) together with the index types that may be left out when the tensor
 * is written. {@link #transform(ParseToken)} then finds every tensor written with some of those
 * types missing and adds freshly generated indices for them:
 * <ul>
 *   <li>in a product, the lower omitted indices of one factor are contracted with the upper
 *       omitted indices of the following factor;</li>
 *   <li>in a sum (or expression), all summands must agree on their omitted indices; a summand that
 *       lacks an expected type is multiplied by a Kronecker delta, which is supported only for
 *       exactly one upper and one lower index;</li>
 *   <li>{@code tr(x)} / {@code tr(x, type, ...)} contracts the omitted indices of the listed types
 *       (all types when none are listed) and is replaced by {@code x} in the tree;</li>
 *   <li>arguments of tensor fields, powers and scalar functions are processed independently.</li>
 * </ul>
 *
 * <p>NOTE(review): no synchronization is performed (the rule lookup is cached lazily in a plain
 * field); presumably intended for single-threaded use — confirm before sharing an instance.
 */
public class GeneralIndicesInsertion
implements ParseTokenTransformer {
    /** Number of index types; the per-type arrays in this class are indexed by {@code IndexType.getType()}. */
    private static final int TYPES_COUNT = 8;

    /** Rules as registered, keyed by the tensor's full name and structure of indices. */
    private final Map<NameAndStructureOfIndices, InsertionRule> initialRules = new HashMap<NameAndStructureOfIndices, InsertionRule>();
    /**
     * Lookup from every "structure with some omittable types removed" to its rule. Built lazily by
     * {@link #ensureMappedRulesInitialized()} and reset to {@code null} whenever a rule is added.
     */
    private Map<NameAndStructureOfIndices, InsertionRule> mappedRules;

    /**
     * Allows indices of {@code omittedIndexType} to be omitted when {@code tensor} is written.
     *
     * @param tensor           tensor whose name identifies the rule
     * @param omittedIndexType index type that may be omitted
     * @throws IllegalArgumentException see {@link #addInsertionRule(NameDescriptor, IndexType)}
     */
    public void addInsertionRule(SimpleTensor tensor, IndexType omittedIndexType) {
        this.addInsertionRule(CC.getNameDescriptor(tensor.getName()), omittedIndexType);
    }

    /**
     * Allows indices of {@code omittedIndexType} to be omitted for the tensor described by {@code nd}.
     * Repeated calls for the same tensor accumulate omittable types in a single rule.
     *
     * @param nd               name descriptor of the tensor
     * @param omittedIndexType index type that may be omitted
     * @throws IllegalArgumentException if the tensor has no indices of that type, or, for a metric
     *                                  type, if their number is odd or they are not written as all
     *                                  upper indices followed by all lower indices
     */
    public void addInsertionRule(NameDescriptor nd, IndexType omittedIndexType) {
        NameAndStructureOfIndices originalStructureAndName = NameDescriptor.extractKey(nd);
        StructureOfIndices structure = nd.getStructureOfIndices();
        byte type = (byte) omittedIndexType.getType();
        int omittedIndicesCount = structure.getTypeData(type).length;
        if (omittedIndicesCount == 0) {
            throw new IllegalArgumentException("No indices of specified type in tensor.");
        }
        if (CC.isMetric(type)) {
            if (omittedIndicesCount % 2 == 1) {
                throw new IllegalArgumentException("The number of omitted indices for metric types should be even.");
            }
            int upperCount = omittedIndicesCount / 2;
            BitArray states = structure.getTypeData(type).states;
            for (int i = 0, size = states.size(); i < size; ++i) {
                // A set bit marks an upper index: the first half must be set, the rest clear.
                if (states.get(i) != (i < upperCount)) {
                    throw new IllegalArgumentException("Inconsistent states signature for metric type.");
                }
            }
        }
        // Invalidate the cached lookup; it is rebuilt on the next transform().
        this.mappedRules = null;
        InsertionRule rule = this.initialRules.get(originalStructureAndName);
        if (rule == null) {
            rule = new InsertionRule(originalStructureAndName);
            this.initialRules.put(originalStructureAndName, rule);
        }
        rule.indicesAllowedToOmit.add(omittedIndexType);
    }

    /**
     * Builds {@link #mappedRules} from {@link #initialRules} if it is not already built.
     * The map is published only once it is complete, so a failure never leaves a partial lookup behind.
     *
     * @throws IllegalStateException if two rules produce the same lookup key
     */
    private void ensureMappedRulesInitialized() {
        if (this.mappedRules != null) {
            return;
        }
        Map<NameAndStructureOfIndices, InsertionRule> rules = new HashMap<NameAndStructureOfIndices, InsertionRule>();
        for (InsertionRule rule : this.initialRules.values()) {
            for (NameAndStructureOfIndices key : rule.getKeys()) {
                if (rules.put(key, rule) != null) {
                    throw new IllegalStateException("Conflicting insertion rules.");
                }
            }
        }
        this.mappedRules = rules;
    }

    /**
     * Inserts omitted indices into {@code node} according to the registered rules.
     *
     * @param node root of the parse tree; it is modified in place
     * @return the (possibly replaced) root, detached from any parent
     * @throws IllegalArgumentException if omitted indices cannot be inserted consistently
     */
    @Override
    public ParseToken transform(ParseToken node) {
        this.ensureMappedRulesInitialized();
        // Presumably all indices already in the tree; the generator must not hand them out again.
        int[] forbidden = ParseUtils.getAllIndicesT(node).toArray();
        IndexGeneratorImpl generator = new IndexGeneratorImpl(forbidden);
        this.transformInsideFieldsAndScalarFunctions(node);
        // The wrapper gives the root a parent, so a top-level tr(...) can replace itself in place.
        ParseToken wrapped = new ParseToken(TokenType.Dummy, node);
        IITransformer transformer = this.createTransformer(wrapped);
        node = wrapped.content[0];
        node.parent = null;
        if (transformer == null) {
            return node;
        }
        OuterIndices outerIndices = transformer.getOuterIndices();
        int[][] upper = new int[TYPES_COUNT][];
        int[][] lower = new int[TYPES_COUNT][];
        for (byte i = 0; i < TYPES_COUNT; ++i) {
            // Upper indices are encoded with the sign bit set.
            upper[i] = new int[outerIndices.upper[i]];
            for (int j = 0; j < upper[i].length; ++j) {
                upper[i][j] = Integer.MIN_VALUE | generator.generate(i);
            }
            lower[i] = new int[outerIndices.lower[i]];
            for (int j = 0; j < lower[i].length; ++j) {
                lower[i][j] = generator.generate(i);
            }
        }
        transformer.apply(generator, upper, lower);
        return node;
    }

    /**
     * Recursively transforms, as independent sub-expressions, the arguments of tensor fields
     * (except {@code tr}), powers and scalar functions found anywhere below {@code pn}.
     */
    private void transformInsideFieldsAndScalarFunctions(ParseToken pn) {
        if (pn.tokenType == TokenType.TensorField
                && !((ParseTokenTensorField) pn).name.equalsIgnoreCase("tr")) {
            this.transformFieldArguments((ParseTokenTensorField) pn);
            // transform() has already descended into every argument; a second pass would only
            // repeat that work (and double it at every nesting level).
            return;
        }
        if (pn.tokenType == TokenType.Power || pn.tokenType == TokenType.ScalarFunction) {
            for (int i = 0; i < pn.content.length; ++i) {
                ParseToken arg = this.transform(pn.content[i]);
                // transform() detaches its result; re-attach it to this node.
                arg.parent = pn;
                pn.content[i] = arg;
            }
            return;
        }
        for (ParseToken child : pn.content) {
            this.transformInsideFieldsAndScalarFunctions(child);
        }
    }

    /**
     * Transforms each argument of {@code field}. When an argument gained indices of a type for which
     * the field's declared argument indices had none, those indices are added to
     * {@code argumentsIndices}.
     *
     * @throws IllegalArgumentException if an argument gained indices of a type that was only
     *                                  partially declared
     */
    private void transformFieldArguments(ParseTokenTensorField field) {
        for (int i = 0; i < field.content.length; ++i) {
            ParseToken newArgNode = this.transform(field.content[i]);
            field.content[i] = newArgNode;
            newArgNode.parent = field;
            SimpleIndices oldArgIndices = field.argumentsIndices[i];
            if (oldArgIndices == null) {
                continue;
            }
            IntArrayList newArgIndices = new IntArrayList(oldArgIndices.getAllIndices().copy());
            Indices newIndices = newArgNode.getIndices();
            for (byte j = 0; j < TYPES_COUNT; ++j) {
                if (oldArgIndices.size(IndexType.getType(j)) >= newIndices.size(IndexType.getType(j))) {
                    continue;
                }
                if (oldArgIndices.size(IndexType.getType(j)) != 0) {
                    throw new IllegalArgumentException("Error in field arg indices.");
                }
                newArgIndices.addAll(newIndices.getOfType(IndexType.getType(j)).getAllIndices());
            }
            field.argumentsIndices[i] = IndicesFactory.createSimple(null, newArgIndices.toArray());
        }
    }

    /**
     * Builds the transformer tree for {@code node}. The result is {@code null} when nothing below
     * {@code node} (along products, sums, expressions and traces) needs indices inserted.
     * Side effect: every {@code tr(...)} encountered is replaced in its parent by its first argument.
     */
    private IITransformer createTransformer(ParseToken node) {
        switch (node.tokenType) {
            case TensorField:
                if (((ParseTokenTensorField) node).name.equalsIgnoreCase("tr")) {
                    return this.createTraceTransformer(node);
                }
                // Non-trace fields intentionally fall through and are looked up like simple tensors
                // (the cast below presumably relies on ParseTokenTensorField extending ParseTokenSimpleTensor).
            case SimpleTensor: {
                InsertionRule rule = this.mappedRules.get(((ParseTokenSimpleTensor) node).getIndicesTypeStructureAndName());
                return rule == null ? null : new SimpleTransformer((ParseTokenSimpleTensor) node, rule);
            }
            case Product: {
                ArrayList<IITransformer> factors = new ArrayList<IITransformer>();
                for (ParseToken factor : node.content) {
                    IITransformer t = this.createTransformer(factor);
                    if (t != null) {
                        factors.add(t);
                    }
                }
                if (factors.isEmpty()) {
                    return null;
                }
                if (factors.size() == 1) {
                    return factors.get(0);
                }
                return new ProductTransformer(factors.toArray(new IITransformer[factors.size()]));
            }
            case Expression: {
                IITransformer lhsTransformer = this.createTransformer(node.content[0]);
                IITransformer rhsTransformer = this.createTransformer(node.content[1]);
                if (lhsTransformer == null && rhsTransformer == null) {
                    return null;
                }
                OuterIndices lhsOuterIndices = lhsTransformer == null ? OuterIndices.EMPTY : lhsTransformer.getOuterIndices();
                OuterIndices rhsOuterIndices = rhsTransformer == null ? OuterIndices.EMPTY : rhsTransformer.getOuterIndices();
                // Every omitted type present on the right-hand side must also be present on the left.
                for (int i = 0; i < TYPES_COUNT; ++i) {
                    boolean rhsHasType = rhsOuterIndices.upper[i] != 0 || rhsOuterIndices.lower[i] != 0;
                    if (rhsHasType && !lhsOuterIndices.initialized[i]) {
                        throw new IllegalArgumentException("Inconsistent matrix expression.");
                    }
                }
                return new SumTransformer(new IITransformer[]{lhsTransformer, rhsTransformer}, lhsOuterIndices, node);
            }
            case Sum: {
                IITransformer[] summands = new IITransformer[node.content.length];
                OuterIndices outerIndices = null;
                for (int i = 0; i < summands.length; ++i) {
                    summands[i] = this.createTransformer(node.content[i]);
                    if (summands[i] == null) {
                        continue;
                    }
                    OuterIndices current = summands[i].getOuterIndices();
                    if (outerIndices == null) {
                        // Clone: the aggregate is updated in place for the following summands.
                        outerIndices = current.clone();
                    } else {
                        outerIndices.cumulativeAggregate(current);
                    }
                }
                if (outerIndices == null) {
                    return null;
                }
                return new SumTransformer(summands, outerIndices, node);
            }
            case Dummy:
                return this.createTransformer(node.content[0]);
            default:
                return null;
        }
    }

    /**
     * Handles {@code tr(x, type...)}. It parses the traced types, replaces the {@code tr} node by
     * {@code x} in its parent, and wraps the transformer of {@code x}, if any, in a
     * {@link TraceTransformer}.
     */
    private IITransformer createTraceTransformer(ParseToken node) {
        Set<IndexType> types = parseTraceTypes(node);
        ParseToken nested = node.content[0];
        ParseToken parent = node.parent;
        int i;
        for (i = 0; i < parent.content.length; ++i) {
            if (parent.content[i] != node) {
                continue;
            }
            parent.content[i] = nested;
            nested.parent = parent;
            break;
        }
        assert (i != parent.content.length);
        IITransformer innerTransformer = this.createTransformer(nested);
        if (innerTransformer == null) {
            return null;
        }
        return new TraceTransformer(innerTransformer, types);
    }

    /**
     * Reads the index types listed after the first argument of a {@code tr} node, given either as
     * short strings or as enum constant names. No list means all types.
     *
     * @throws IllegalArgumentException if an entry is not a simple tensor naming an index type
     */
    private static EnumSet<IndexType> parseTraceTypes(ParseToken node) {
        if (node.content.length == 1) {
            return EnumSet.allOf(IndexType.class);
        }
        EnumSet<IndexType> types = EnumSet.noneOf(IndexType.class);
        for (int i = 1; i < node.content.length; ++i) {
            ParseToken pn = node.content[i];
            if (pn.tokenType != TokenType.SimpleTensor) {
                throw new IllegalArgumentException("Error in trace indices list.");
            }
            String name = ((ParseTokenSimpleTensor) pn).name;
            IndexType type = IndexType.fromShortString(name);
            if (type == null) {
                try {
                    type = IndexType.valueOf(name);
                } catch (IllegalArgumentException e) {
                    // Enum.valueOf throws (it never returns null) for unknown names.
                    throw new IllegalArgumentException("Error in trace indices list.", e);
                }
            }
            types.add(type);
        }
        return types;
    }

    /**
     * Matrix-like product of several transformers. The lower omitted indices of each factor are
     * contracted with the upper omitted indices of the next factor. The incoming outer upper
     * indices go to the leading factors and the incoming outer lower indices to the trailing ones.
     */
    private static class ProductTransformer
    extends MIITransformer {
        private final OuterIndices outerIndices;

        public ProductTransformer(IITransformer[] transformers) {
            super(transformers);
            OuterIndices oi = null;
            for (IITransformer transformer : transformers) {
                if (oi == null) {
                    oi = transformer.getOuterIndices().clone();
                } else {
                    oi.cumulativeAdd(transformer.getOuterIndices());
                }
            }
            this.outerIndices = oi;
        }

        @Override
        public OuterIndices getOuterIndices() {
            return this.outerIndices;
        }

        @Override
        public void apply(IndexGeneratorImpl generator, int[][] upper, int[][] lower) {
            int[] totalCountUpper = new int[TYPES_COUNT];
            int[] totalCountLower = new int[TYPES_COUNT];
            TransformersIndicesRange[] upperRanges = new TransformersIndicesRange[this.transformers.length];
            TransformersIndicesRange[] lowerRanges = new TransformersIndicesRange[this.transformers.length];
            // Each factor's slice [from, from + count) within the concatenated per-type index lists.
            for (int i = 0; i < this.transformers.length; ++i) {
                OuterIndices oi = this.transformers[i].getOuterIndices();
                upperRanges[i] = new TransformersIndicesRange(totalCountUpper.clone(), oi.upper.clone());
                lowerRanges[i] = new TransformersIndicesRange(totalCountLower.clone(), oi.lower.clone());
                for (int j = 0; j < TYPES_COUNT; ++j) {
                    totalCountUpper[j] += oi.upper[j];
                    totalCountLower[j] += oi.lower[j];
                }
            }
            int[][] totalUppers = new int[TYPES_COUNT][];
            int[][] totalLowers = new int[TYPES_COUNT][];
            for (byte j = 0; j < TYPES_COUNT; ++j) {
                // Number of internal (contracted) index pairs of this type.
                int contracted = totalCountUpper[j] - upper[j].length;
                if (contracted < 0 || contracted != totalCountLower[j] - lower[j].length) {
                    throw new IllegalArgumentException("Inconsistent omitted indices in product.");
                }
                totalUppers[j] = new int[totalCountUpper[j]];
                totalLowers[j] = new int[totalCountLower[j]];
                // Outer uppers lead the upper list; outer lowers close the lower list.
                System.arraycopy(upper[j], 0, totalUppers[j], 0, upper[j].length);
                System.arraycopy(lower[j], 0, totalLowers[j], contracted, lower[j].length);
                for (int i = 0; i < contracted; ++i) {
                    totalLowers[j][i] = generator.generate(j);
                    totalUppers[j][i + upper[j].length] = totalLowers[j][i] | Integer.MIN_VALUE;
                }
            }
            for (int i = 0; i < this.transformers.length; ++i) {
                int[][] cUpper = new int[TYPES_COUNT][];
                int[][] cLower = new int[TYPES_COUNT][];
                for (int j = 0; j < TYPES_COUNT; ++j) {
                    cUpper[j] = Arrays.copyOfRange(totalUppers[j], upperRanges[i].from[j], upperRanges[i].from[j] + upperRanges[i].count[j]);
                    cLower[j] = Arrays.copyOfRange(totalLowers[j], lowerRanges[i].from[j], lowerRanges[i].from[j] + lowerRanges[i].count[j]);
                }
                this.transformers[i].apply(generator, cUpper, cLower);
            }
        }
    }

    /**
     * Contracts the outer upper and lower omitted indices of the given types of its inner
     * transformer, removing those types from its own outer indices.
     */
    private static class TraceTransformer
    implements IITransformer {
        private final OuterIndices outerIndices;
        private final IITransformer innerTransformer;
        private final Set<IndexType> typesToContract;

        /**
         * @throws IllegalArgumentException if a traced type has different upper and lower counts
         */
        private TraceTransformer(IITransformer innerTransformer, Set<IndexType> typesToContract) {
            this.innerTransformer = innerTransformer;
            this.typesToContract = new HashSet<IndexType>(typesToContract);
            this.outerIndices = innerTransformer.getOuterIndices().clone();
            for (IndexType type : typesToContract) {
                if (this.outerIndices.upper[type.getType()] != this.outerIndices.lower[type.getType()]) {
                    throw new IllegalArgumentException("Illegal trace usage.");
                }
                if (this.outerIndices.upper[type.getType()] == 0) {
                    // Nothing to contract for this type.
                    this.typesToContract.remove(type);
                }
                this.outerIndices.lower[type.getType()] = 0;
                this.outerIndices.upper[type.getType()] = 0;
            }
        }

        @Override
        public OuterIndices getOuterIndices() {
            return this.outerIndices;
        }

        @Override
        public void apply(IndexGeneratorImpl generator, int[][] upper, int[][] lower) {
            OuterIndices innerIndices = this.innerTransformer.getOuterIndices();
            // Traced types were removed from this transformer's outer indices, so each is supplied
            // with freshly generated upper/lower pairs sharing the same index names.
            int[][] preparedUpper = upper.clone();
            int[][] preparedLower = lower.clone();
            for (IndexType type : this.typesToContract) {
                // Upper and lower counts are equal (checked in the constructor).
                int count = innerIndices.upper[type.getType()];
                int[] l = new int[count];
                int[] u = new int[count];
                for (int i = 0; i < count; ++i) {
                    int generated = generator.generate(type.getType());
                    l[i] = generated;
                    u[i] = Integer.MIN_VALUE | generated;
                }
                preparedLower[type.getType()] = l;
                preparedUpper[type.getType()] = u;
            }
            this.innerTransformer.apply(generator, preparedUpper, preparedLower);
        }
    }

    /** Start offsets and sizes, per index type, of one factor's slice of a product's index lists. */
    private static final class TransformersIndicesRange {
        final int[] from;
        final int[] count;

        public TransformersIndicesRange(int[] from, int[] count) {
            this.from = from;
            this.count = count;
        }
    }

    /**
     * Sum (or expression) whose summands share the same outer omitted indices. Entries of
     * {@code transformers} may be {@code null} for summands without omitted indices.
     */
    private static class SumTransformer
    extends MIITransformer {
        private final OuterIndices outerIndices;
        private final ParseToken parseToken;

        private SumTransformer(IITransformer[] transformers, OuterIndices outerIndices, ParseToken parseToken) {
            super(transformers);
            this.outerIndices = outerIndices;
            this.parseToken = parseToken;
        }

        @Override
        public OuterIndices getOuterIndices() {
            return this.outerIndices;
        }

        @Override
        public void apply(IndexGeneratorImpl generator, int[][] upper, int[][] lower) {
            // Each non-last summand is applied with its own clone of the incoming generator; the
            // clones are merged into generatorTemp and finally into generator.
            IndexGeneratorImpl generatorTemp = null;
            int[][] preparedUpper = new int[TYPES_COUNT][];
            int[][] preparedLower = new int[TYPES_COUNT][];
            int last = this.transformers.length - 1;
            for (int i = 0; i < this.transformers.length; ++i) {
                OuterIndices oi;
                if (this.transformers[i] == null) {
                    oi = OuterIndices.EMPTY;
                } else {
                    oi = this.transformers[i].getOuterIndices();
                    if (oi.equals(this.outerIndices)) {
                        System.arraycopy(upper, 0, preparedUpper, 0, TYPES_COUNT);
                        System.arraycopy(lower, 0, preparedLower, 0, TYPES_COUNT);
                    } else {
                        // Pass outer indices only for the types this summand actually has.
                        for (int j = 0; j < TYPES_COUNT; ++j) {
                            preparedUpper[j] = oi.initialized[j] ? upper[j] : new int[0];
                            preparedLower[j] = oi.initialized[j] ? lower[j] : new int[0];
                        }
                    }
                    if (i != last) {
                        IndexGeneratorImpl generatorClone = generator.clone();
                        this.transformers[i].apply(generatorClone, preparedUpper, preparedLower);
                        if (generatorTemp == null) {
                            generatorTemp = generatorClone;
                        } else {
                            generatorTemp.mergeFrom(generatorClone);
                        }
                    } else if (generatorTemp == null) {
                        this.transformers[i].apply(generator, preparedUpper, preparedLower);
                    } else {
                        this.transformers[i].apply(generatorTemp, preparedUpper, preparedLower);
                        generator.mergeFrom(generatorTemp);
                    }
                }
                // Summands lacking an expected type are multiplied by Kronecker deltas.
                this.parseToken.content[i] = this.addDeltas(oi, this.parseToken.content[i], this.outerIndices, upper, lower);
            }
            // Needed when the last summand has no transformer. NOTE(review): otherwise this repeats
            // the merge above; presumably harmless if mergeFrom is idempotent — confirm.
            if (generatorTemp != null) {
                generator.mergeFrom(generatorTemp);
            }
        }

        /**
         * Multiplies {@code node} by a Kronecker delta for every type that {@code expected} has (with
         * a non-zero count) but {@code inserted} lacks.
         *
         * @return {@code node} itself, or a new product node of the deltas and {@code node}
         * @throws IllegalArgumentException if such a type needs other than one upper and one lower index
         */
        private ParseToken addDeltas(OuterIndices inserted, ParseToken node, OuterIndices expected, int[][] upper, int[][] lower) {
            ArrayList<ParseToken> multipliers = new ArrayList<ParseToken>();
            for (int i = 0; i < TYPES_COUNT; ++i) {
                boolean missing = !inserted.initialized[i] && expected.initialized[i]
                        && (expected.lower[i] != 0 || expected.upper[i] != 0);
                if (!missing) {
                    continue;
                }
                if (expected.lower[i] != 1 || expected.upper[i] != 1) {
                    throw new IllegalArgumentException("Deltas insertion is only supported for one upper and one lower omitted indices.");
                }
                multipliers.add(new ParseTokenSimpleTensor(IndicesFactory.createSimple(null, upper[i][0], lower[i][0]), CC.current().getKroneckerName()));
            }
            if (multipliers.isEmpty()) {
                return node;
            }
            multipliers.add(node);
            return new ParseToken(TokenType.Product, multipliers.toArray(new ParseToken[multipliers.size()]));
        }
    }

    /** Base class for transformers composed of several child transformers. */
    private static abstract class MIITransformer
    implements IITransformer {
        protected final IITransformer[] transformers;

        public MIITransformer(IITransformer[] transformers) {
            this.transformers = transformers;
        }
    }

    /** Leaf transformer: adds the supplied indices to a single simple tensor (or field) node. */
    private static class SimpleTransformer
    implements IITransformer {
        private final ParseTokenSimpleTensor node;
        private final OuterIndices outerIndices = new OuterIndices();

        /**
         * Computes how many upper and lower indices of each omittable type {@code node} is missing.
         *
         * @throws IllegalArgumentException if an omittable type is present with a wrong count
         */
        public SimpleTransformer(ParseTokenSimpleTensor node, InsertionRule insertionRule) {
            this.node = node;
            StructureOfIndices originalStructure = insertionRule.originalStructureAndName.getStructure()[0];
            StructureOfIndices currentStructure = node.getIndicesTypeStructureAndName().getStructure()[0];
            for (IndexType type : insertionRule.indicesAllowedToOmit) {
                BitArray currentStates = currentStructure.getStates(type);
                // getStates may return null (see the originalStates check below); fall back to the count.
                boolean omitted = currentStates == null
                        ? currentStructure.typeCount(type.getType()) == 0
                        : currentStates.size() == 0;
                if (!omitted) {
                    if (currentStructure.typeCount(type.getType()) != originalStructure.typeCount(type.getType())) {
                        throw new IllegalArgumentException("Wrong number of indices of type " + type + ".");
                    }
                    continue;
                }
                BitArray originalStates = originalStructure.getStates(type);
                if (originalStates != null) {
                    // Set bits mark upper indices.
                    this.outerIndices.upper[type.getType()] = originalStates.bitCount();
                    this.outerIndices.lower[type.getType()] = originalStates.size() - this.outerIndices.upper[type.getType()];
                } else {
                    int n = originalStructure.typeCount(type.getType()) / 2;
                    this.outerIndices.lower[type.getType()] = n;
                    this.outerIndices.upper[type.getType()] = n;
                }
            }
            this.outerIndices.init();
        }

        @Override
        public OuterIndices getOuterIndices() {
            return this.outerIndices;
        }

        @Override
        public void apply(IndexGeneratorImpl generator, int[][] upper, int[][] lower) {
            SimpleIndices oldIndices = this.node.indices;
            // New indices = existing ones, then all supplied uppers, then all supplied lowers.
            int[] result = ArraysUtils.addAll(new int[][]{oldIndices.getAllIndices().copy(), ArraysUtils.addAll(upper), ArraysUtils.addAll(lower)});
            this.node.indices = IndicesFactory.createSimple(null, result);
        }
    }

    /** A node in the transformer tree mirroring the matrix-like part of the parse tree. */
    private static interface IITransformer {
        /** Numbers of outer (free) omitted upper/lower indices per type. */
        public OuterIndices getOuterIndices();

        /**
         * Inserts indices using {@code upper}/{@code lower} (per type, matching
         * {@link #getOuterIndices()}) as outer indices; internal ones come from {@code generator}.
         */
        public void apply(IndexGeneratorImpl generator, int[][] upper, int[][] lower);
    }

    /** Per-type counts of outer omitted upper/lower indices, plus whether each type is involved. */
    private static class OuterIndices {
        /** Shared "no omitted indices" instance; must never be mutated. */
        public static final OuterIndices EMPTY = new OuterIndices();
        final int[] upper;
        final int[] lower;
        final boolean[] initialized;

        OuterIndices() {
            this.upper = new int[TYPES_COUNT];
            this.lower = new int[TYPES_COUNT];
            this.initialized = new boolean[TYPES_COUNT];
        }

        private OuterIndices(int[] upper, int[] lower, boolean[] initialized) {
            this.upper = upper;
            this.lower = lower;
            this.initialized = initialized;
        }

        /** Marks as involved every type with a non-zero upper or lower count. */
        public void init() {
            for (int i = 0; i < TYPES_COUNT; ++i) {
                this.initialized[i] = this.upper[i] != 0 || this.lower[i] != 0;
            }
        }

        /**
         * Sum semantics: adopts the counts of types involved in {@code other}.
         *
         * @throws IllegalArgumentException if a type involved in both has different counts
         */
        public void cumulativeAggregate(OuterIndices other) {
            for (int i = 0; i < TYPES_COUNT; ++i) {
                if (!other.initialized[i]) {
                    continue;
                }
                if (this.initialized[i]) {
                    if (this.upper[i] == other.upper[i] && this.lower[i] == other.lower[i]) {
                        continue;
                    }
                    throw new IllegalArgumentException("Inconsistent omitted indices exception.");
                }
                this.upper[i] = other.upper[i];
                this.lower[i] = other.lower[i];
                this.initialized[i] = true;
            }
        }

        /**
         * Product semantics: appends {@code other} as the next factor. Its upper indices are
         * contracted with this one's lower indices, and whatever remains uncontracted stays free.
         */
        public void cumulativeAdd(OuterIndices other) {
            for (int i = 0; i < TYPES_COUNT; ++i) {
                this.initialized[i] |= other.initialized[i];
                int dif = other.upper[i] - this.lower[i];
                this.lower[i] = other.lower[i];
                if (dif < 0) {
                    this.lower[i] -= dif;   // surplus lower indices of this stay free
                } else {
                    this.upper[i] += dif;   // surplus upper indices of other become free
                }
            }
        }

        @Override
        public boolean equals(Object o) {
            if (this == o) {
                return true;
            }
            if (!(o instanceof OuterIndices)) {
                return false;
            }
            OuterIndices that = (OuterIndices) o;
            return Arrays.equals(this.initialized, that.initialized)
                    && Arrays.equals(this.lower, that.lower)
                    && Arrays.equals(this.upper, that.upper);
        }

        @Override
        public int hashCode() {
            int result = Arrays.hashCode(this.upper);
            result = 31 * result + Arrays.hashCode(this.lower);
            result = 31 * result + Arrays.hashCode(this.initialized);
            return result;
        }

        /** Deep copy of all three arrays. */
        @Override
        public OuterIndices clone() {
            return new OuterIndices(this.upper.clone(), this.lower.clone(), this.initialized.clone());
        }
    }

    /** A tensor's full name/structure together with the index types that may be omitted for it. */
    private static class InsertionRule {
        final NameAndStructureOfIndices originalStructureAndName;
        final Set<IndexType> indicesAllowedToOmit = new HashSet<IndexType>();

        private InsertionRule(NameAndStructureOfIndices originalStructureAndName) {
            this.originalStructureAndName = originalStructureAndName;
        }

        /**
         * Returns one lookup key per non-empty subset of the omittable types. Each key is the
         * original name and structure with the indices of that subset removed.
         */
        public NameAndStructureOfIndices[] getKeys() {
            IndexType[] toOmit = this.indicesAllowedToOmit.toArray(new IndexType[this.indicesAllowedToOmit.size()]);
            StructureOfIndices original = this.originalStructureAndName.getStructure()[0];
            NameAndStructureOfIndices[] keys = new NameAndStructureOfIndices[(1 << toOmit.length) - 1];
            // Bit i of `omitted` selects toOmit[i].
            for (int omitted = 1; omitted <= keys.length; ++omitted) {
                // Fetched anew each iteration because they are mutated below
                // (presumably these getters return copies — confirm).
                int[] allCounts = original.getTypesCounts();
                BitArray[] states = original.getStates();
                for (int i = 0; i < toOmit.length; ++i) {
                    if ((omitted & 1 << i) == 0) {
                        continue;
                    }
                    allCounts[toOmit[i].getType()] = 0;
                    states[toOmit[i].getType()] = states[toOmit[i].getType()] == null ? null : BitArray.EMPTY;
                }
                StructureOfIndices[] structures = this.originalStructureAndName.getStructure().clone();
                structures[0] = StructureOfIndices.create(allCounts, states);
                keys[omitted - 1] = new NameAndStructureOfIndices(this.originalStructureAndName.getName(), structures);
            }
            return keys;
        }
    }
}

