/*
 * Decompiled with CFR 0.152.
 */
package jsat.utils.concurrent;

import java.util.ArrayList;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.function.BinaryOperator;
import java.util.logging.Level;
import java.util.logging.Logger;
import java.util.stream.DoubleStream;
import java.util.stream.IntStream;
import java.util.stream.Stream;
import jsat.utils.FakeExecutor;
import jsat.utils.ListUtils;
import jsat.utils.SystemInfo;
import jsat.utils.concurrent.IndexReducer;
import jsat.utils.concurrent.IndexRunnable;
import jsat.utils.concurrent.LoopChunkReducer;
import jsat.utils.concurrent.LoopChunkRunner;

/**
 * Static helpers for running index loops and chunked loops either serially or in
 * parallel on an {@link ExecutorService}, plus small conveniences for building
 * optionally-parallel streams.
 * <p>
 * All parallel {@code run} overloads block until every submitted task has
 * finished. If a task throws, the first such failure is rethrown to the caller
 * after all tasks are done. Unchecked exceptions and errors are rethrown as-is,
 * matching what the serial path would throw.
 */
public class ParallelUtils {
    /**
     * A shared cached thread pool whose threads are marked as daemon threads, so
     * idle pool threads do not keep the JVM alive.
     */
    public static final ExecutorService CACHED_THREAD_POOL = Executors.newCachedThreadPool(r -> {
        Thread t = Executors.defaultThreadFactory().newThread(r);
        t.setDaemon(true);
        return t;
    });

    /**
     * Runs {@code lcr} over the index range [0, N). The range is split into
     * contiguous chunks when {@code parallel} is true.
     * <p>
     * A fixed thread pool of {@code SystemInfo.LogicalCores} threads is created
     * for this call. It is shut down before returning, even if a chunk fails.
     *
     * @param parallel whether to split the work across threads
     * @param N        the number of indices to process
     * @param lcr      the chunk body, called with half-open [start, end) ranges
     */
    public static void run(boolean parallel, int N, LoopChunkRunner lcr) {
        ExecutorService threadPool = Executors.newFixedThreadPool(SystemInfo.LogicalCores);
        try {
            run(parallel, N, lcr, threadPool);
        } finally {
            threadPool.shutdownNow();
        }
    }

    /**
     * Runs {@code lcr} over the index range [0, N) using the given pool.
     * <p>
     * Serially this is a single {@code lcr.run(0, N)} call. In parallel, the
     * range is cut into {@code min(SystemInfo.LogicalCores, N)} contiguous chunks
     * (see {@link #getStartBlock(int, int, int)}), and one task is submitted per
     * chunk.
     * <p>
     * Returns only after every chunk task has finished. If a chunk throws, the
     * first failure is rethrown once all chunks are done. The pool is not shut
     * down.
     *
     * @param parallel   whether to split the work across threads
     * @param N          the number of indices to process
     * @param lcr        the chunk body, called with half-open [start, end) ranges
     * @param threadPool the executor the chunk tasks are submitted to
     */
    public static void run(boolean parallel, int N, LoopChunkRunner lcr, ExecutorService threadPool) {
        if (!parallel) {
            lcr.run(0, N);
            return;
        }
        int cores_to_use = Math.min(SystemInfo.LogicalCores, N);
        ArrayList<Future<?>> futures = new ArrayList<>(cores_to_use);
        for (int id = 0; id < cores_to_use; id++) {
            final int threadID = id;
            futures.add(threadPool.submit(() -> {
                int start = getStartBlock(N, threadID, cores_to_use);
                int end = getEndBlock(N, threadID, cores_to_use);
                lcr.run(start, end);
            }));
        }
        awaitAll(futures);
    }

    /**
     * Computes a value for each chunk of [0, N) using the given pool, then folds
     * the chunk values together with {@code reducer}.
     * <p>
     * Serially this returns {@code lcr.run(0, N)} directly. In parallel, chunk
     * results are combined in chunk order. A {@code null} running value is
     * replaced by the next chunk's value instead of being passed to
     * {@code reducer}. If no chunks were run, {@code null} is returned.
     * <p>
     * If a chunk throws, the first failure is rethrown once all chunks are done.
     * The pool is not shut down.
     *
     * @param parallel   whether to split the work across threads
     * @param N          the number of indices to process
     * @param lcr        the chunk body, called with half-open [start, end) ranges
     * @param reducer    combines two partial results
     * @param threadPool the executor the chunk tasks are submitted to
     * @param <T>        the result type
     * @return the reduced result, possibly {@code null}
     */
    public static <T> T run(boolean parallel, int N, LoopChunkReducer<T> lcr, BinaryOperator<T> reducer, ExecutorService threadPool) {
        if (!parallel) {
            return lcr.run(0, N);
        }
        int cores_to_use = Math.min(SystemInfo.LogicalCores, N);
        ArrayList<Future<T>> futures = new ArrayList<>(cores_to_use);
        for (int id = 0; id < cores_to_use; id++) {
            final int threadID = id;
            futures.add(threadPool.submit(() -> {
                int start = getStartBlock(N, threadID, cores_to_use);
                int end = getEndBlock(N, threadID, cores_to_use);
                return lcr.run(start, end);
            }));
        }
        T cur = null;
        for (T chunk : awaitAll(futures)) {
            // a null running value means "nothing yet": adopt the chunk instead of reducing
            cur = (cur == null) ? chunk : reducer.apply(cur, chunk);
        }
        return cur;
    }

    /**
     * Reducing variant of {@link #run(boolean, int, LoopChunkReducer, BinaryOperator, ExecutorService)}
     * that creates its own work-stealing pool of {@code SystemInfo.LogicalCores}
     * threads. The pool is shut down before returning, even if a chunk fails.
     *
     * @param parallel whether to split the work across threads
     * @param N        the number of indices to process
     * @param lcr      the chunk body, called with half-open [start, end) ranges
     * @param reducer  combines two partial results
     * @param <T>      the result type
     * @return the reduced result, possibly {@code null}
     */
    public static <T> T run(boolean parallel, int N, LoopChunkReducer<T> lcr, BinaryOperator<T> reducer) {
        ExecutorService threadPool = Executors.newWorkStealingPool(SystemInfo.LogicalCores);
        try {
            return run(parallel, N, lcr, reducer, threadPool);
        } finally {
            threadPool.shutdownNow();
        }
    }

    /**
     * Evaluates {@code ir} at every index in [0, N) and folds the values with
     * {@code reducer}.
     * <p>
     * Serially, values are folded left-to-right starting from index 0. In
     * parallel, the values are reduced with {@link java.util.stream.Stream#reduce(BinaryOperator)}.
     * Either way, an empty range ({@code N <= 0} serially) yields {@code null}.
     *
     * @param parallel whether to evaluate indices in a parallel stream
     * @param N        the number of indices
     * @param ir       produces the value for a single index
     * @param reducer  combines two values
     * @param <T>      the result type
     * @return the reduced value, or {@code null} when there was nothing to reduce
     */
    public static <T> T run(boolean parallel, int N, IndexReducer<T> ir, BinaryOperator<T> reducer) {
        if (!parallel) {
            if (N <= 0) {
                // previously ir.run(0) was called even for an empty range
                return null;
            }
            T runner = ir.run(0);
            for (int i = 1; i < N; ++i) {
                runner = reducer.apply(runner, ir.run(i));
            }
            return runner;
        }
        return range(N, true).mapToObj(j -> ir.run(j)).reduce(reducer).orElse(null);
    }

    /**
     * Runs {@code ir} once for every index in [0, N). A work-stealing pool of
     * {@code SystemInfo.LogicalCores} threads is created for this call. It is
     * shut down before returning, even if a task fails.
     *
     * @param parallel whether to run indices as separate pool tasks
     * @param N        the number of indices
     * @param ir       the per-index body
     */
    public static void run(boolean parallel, int N, IndexRunnable ir) {
        ExecutorService threadPool = Executors.newWorkStealingPool(SystemInfo.LogicalCores);
        try {
            run(parallel, N, ir, threadPool);
        } finally {
            threadPool.shutdownNow();
        }
    }

    /**
     * Runs {@code ir} once for every index in [0, N).
     * <p>
     * Serially this is a plain loop. In parallel, one task per index is
     * submitted to {@code threadPool}. Returns only after every task has
     * finished. If a task throws, the first failure is rethrown once all tasks
     * are done. The pool is not shut down.
     *
     * @param parallel   whether to run indices as separate pool tasks
     * @param N          the number of indices
     * @param ir         the per-index body
     * @param threadPool the executor the tasks are submitted to
     */
    public static void run(boolean parallel, int N, IndexRunnable ir, ExecutorService threadPool) {
        if (!parallel) {
            for (int i = 0; i < N; ++i) {
                ir.run(i);
            }
            return;
        }
        ArrayList<Future<?>> futures = new ArrayList<>(N);
        for (int i = 0; i < N; ++i) {
            final int index = i;
            futures.add(threadPool.submit(() -> {
                ir.run(index);
            }));
        }
        awaitAll(futures);
    }

    /**
     * Waits for each future in order and returns their results in the same
     * order.
     * <p>
     * Every future is waited on even after a failure, so no submitted task is
     * still running when a failure is reported. Afterwards, the first task
     * failure is rethrown: unchecked exceptions and errors as-is, anything else
     * wrapped in a {@link RuntimeException}.
     * <p>
     * If the calling thread is interrupted while waiting, its interrupt status
     * is restored, the interruption is logged, and waiting stops early. Only the
     * results collected before the interrupt are returned.
     *
     * @param futures the futures to wait on, in submission order
     * @param <V>     the result type of the futures
     * @return the collected results, in submission order
     */
    private static <V> ArrayList<V> awaitAll(ArrayList<? extends Future<? extends V>> futures) {
        ArrayList<V> results = new ArrayList<>(futures.size());
        Throwable failure = null;
        for (Future<? extends V> ft : futures) {
            try {
                results.add(ft.get());
            } catch (InterruptedException ex) {
                // keep the interrupt visible to callers instead of swallowing it
                Thread.currentThread().interrupt();
                Logger.getLogger(ParallelUtils.class.getName()).log(Level.SEVERE, null, ex);
                break;
            } catch (ExecutionException ex) {
                if (failure == null) {
                    failure = ex.getCause() != null ? ex.getCause() : ex;
                }
            }
        }
        if (failure instanceof RuntimeException) {
            throw (RuntimeException) failure;
        }
        if (failure instanceof Error) {
            throw (Error) failure;
        }
        if (failure != null) {
            throw new RuntimeException(failure);
        }
        return results;
    }

    /**
     * Returns a new executor for the requested mode.
     * <p>
     * When {@code parallel} is true, this is a fixed pool of
     * {@code SystemInfo.LogicalCores} threads. Otherwise it is a
     * {@link FakeExecutor}. NOTE(review): presumably that executor runs tasks on
     * the calling thread; confirm against FakeExecutor.
     * <p>
     * The executor is newly created each call, and shutting it down is left to
     * the caller.
     *
     * @param parallel whether a real thread pool is wanted
     * @return a newly created executor
     */
    public static ExecutorService getNewExecutor(boolean parallel) {
        if (parallel) {
            return Executors.newFixedThreadPool(SystemInfo.LogicalCores);
        }
        return new FakeExecutor();
    }

    /**
     * Returns {@code source.parallel()} when {@code parallel} is true, otherwise
     * {@code source} unchanged.
     *
     * @param source   the stream to adapt
     * @param parallel whether a parallel stream is wanted
     * @param <T>      the element type
     * @return the possibly-parallel stream
     */
    public static <T> Stream<T> streamP(Stream<T> source, boolean parallel) {
        return parallel ? source.parallel() : source;
    }

    /**
     * Returns {@code source.parallel()} when {@code parallel} is true, otherwise
     * {@code source} unchanged.
     *
     * @param source   the stream to adapt
     * @param parallel whether a parallel stream is wanted
     * @return the possibly-parallel stream
     */
    public static IntStream streamP(IntStream source, boolean parallel) {
        return parallel ? source.parallel() : source;
    }

    /**
     * Returns {@code source.parallel()} when {@code parallel} is true, otherwise
     * {@code source} unchanged.
     *
     * @param source   the stream to adapt
     * @param parallel whether a parallel stream is wanted
     * @return the possibly-parallel stream
     */
    public static DoubleStream streamP(DoubleStream source, boolean parallel) {
        return parallel ? source.parallel() : source;
    }

    /**
     * Equivalent to {@code range(0, end, parallel)}.
     *
     * @param end      exclusive upper bound
     * @param parallel whether a parallel stream is wanted
     * @return a stream over [0, end)
     */
    public static IntStream range(int end, boolean parallel) {
        return range(0, end, parallel);
    }

    /**
     * Returns a stream over the integers in [start, end).
     * <p>
     * The serial case uses {@link IntStream#range(int, int)}. The parallel case
     * builds the stream from {@code ListUtils.range(start, end)}.
     *
     * @param start    inclusive lower bound
     * @param end      exclusive upper bound
     * @param parallel whether a parallel stream is wanted
     * @return a stream over [start, end)
     */
    public static IntStream range(int start, int end, boolean parallel) {
        if (parallel) {
            // NOTE(review): presumably a boxed list is used instead of
            // IntStream.range(...).parallel() so that even small ranges are split
            // across threads. Confirm before simplifying.
            return ListUtils.range(start, end).stream().parallel().mapToInt(i -> i);
        }
        return IntStream.range(start, end);
    }

    /**
     * Returns the first (inclusive) index of block {@code ID} when {@code N}
     * items are split into {@code P} contiguous blocks. Each block gets
     * {@code N / P} items, and the first {@code N % P} blocks get one extra.
     *
     * @param N  total number of items
     * @param ID block index in [0, P)
     * @param P  number of blocks (must be non-zero)
     * @return the inclusive start index of block {@code ID}
     */
    public static int getStartBlock(int N, int ID, int P) {
        int rem = N % P;
        return N / P * ID + Math.min(rem, ID);
    }

    /**
     * {@link #getStartBlock(int, int, int)} with {@code P = SystemInfo.LogicalCores}.
     *
     * @param N  total number of items
     * @param ID block index
     * @return the inclusive start index of block {@code ID}
     */
    public static int getStartBlock(int N, int ID) {
        return getStartBlock(N, ID, SystemInfo.LogicalCores);
    }

    /**
     * Returns the exclusive end index of block {@code ID}, using the same split
     * as {@link #getStartBlock(int, int, int)}. This always equals the start of
     * block {@code ID + 1}.
     *
     * @param N  total number of items
     * @param ID block index in [0, P)
     * @param P  number of blocks (must be non-zero)
     * @return the exclusive end index of block {@code ID}
     */
    public static int getEndBlock(int N, int ID, int P) {
        int rem = N % P;
        return N / P * (ID + 1) + Math.min(rem, ID + 1);
    }

    /**
     * {@link #getEndBlock(int, int, int)} with {@code P = SystemInfo.LogicalCores}.
     *
     * @param N  total number of items
     * @param ID block index
     * @return the exclusive end index of block {@code ID}
     */
    public static int getEndBlock(int N, int ID) {
        return getEndBlock(N, ID, SystemInfo.LogicalCores);
    }
}

