/*
 * Decompiled with CFR 0.152.
 */
package Catalano.MachineLearning.Exploration;

import Catalano.MachineLearning.Exploration.IExplorationPolicy;
import java.util.Random;

/**
 * Tabular SARSA (State-Action-Reward-State-Action) learner.
 *
 * <p>Keeps a {@code states x actions} table of Q-values. Action selection is delegated to an
 * {@link IExplorationPolicy}, and {@link #UpdateState(int, int, double, int, int)} applies the
 * on-policy temporal-difference update
 * {@code Q(s,a) <- (1 - alpha) * Q(s,a) + alpha * (r + gamma * Q(s',a'))}.
 *
 * <p>NOTE(review): no synchronization is performed; the Q-table is mutated in place.
 */
public class Sarsa {
    private final int states;
    private final int actions;
    /** Q-value table indexed as {@code qvalues[state][action]}. */
    private final double[][] qvalues;
    private IExplorationPolicy explorationPolicy;
    private double discountFactor = 0.95;
    private double learningRate = 0.25;

    /**
     * Returns the number of states in the Q-table.
     *
     * @return number of states
     */
    public int getStates() {
        return this.states;
    }

    /**
     * Returns the number of actions per state in the Q-table.
     *
     * @return number of actions
     */
    public int getActions() {
        return this.actions;
    }

    /**
     * Returns the exploration policy used by {@link #GetAction(int)}.
     *
     * @return the current exploration policy
     */
    public IExplorationPolicy getExplorationPolicy() {
        return this.explorationPolicy;
    }

    /**
     * Replaces the exploration policy used by {@link #GetAction(int)}.
     *
     * @param explorationPolicy the new policy
     */
    public void setExplorationPolicy(IExplorationPolicy explorationPolicy) {
        this.explorationPolicy = explorationPolicy;
    }

    /**
     * Returns the learning rate (alpha).
     *
     * @return learning rate in {@code [0, 1]}
     */
    public double getLearningRate() {
        return this.learningRate;
    }

    /**
     * Sets the learning rate (alpha). Values outside {@code [0, 1]} are clamped to that range.
     *
     * @param learningRate requested learning rate
     */
    public void setLearningRate(double learningRate) {
        this.learningRate = Math.max(0.0, Math.min(1.0, learningRate));
    }

    /**
     * Returns the discount factor (gamma).
     *
     * @return discount factor in {@code [0, 1]}
     */
    public double getDiscountFactor() {
        return this.discountFactor;
    }

    /**
     * Sets the discount factor (gamma). Values outside {@code [0, 1]} are clamped to that range.
     *
     * @param discountFactor requested discount factor
     */
    public void setDiscountFactor(double discountFactor) {
        this.discountFactor = Math.max(0.0, Math.min(1.0, discountFactor));
    }

    /**
     * Creates a learner with a zero-initialized or randomly-initialized Q-table.
     *
     * @param states number of states
     * @param actions number of actions per state
     * @param explorationPolicy policy used by {@link #GetAction(int)}
     * @param randomize if {@code true}, every Q-value is drawn uniformly from {@code [0, 0.1)};
     *     otherwise all Q-values start at {@code 0}
     */
    public Sarsa(int states, int actions, IExplorationPolicy explorationPolicy, boolean randomize) {
        this(states, actions, explorationPolicy, randomize, null);
    }

    /**
     * Creates a learner, optionally randomizing the Q-table with a caller-supplied generator
     * (e.g. a seeded {@link Random} for reproducible runs).
     *
     * @param states number of states
     * @param actions number of actions per state
     * @param explorationPolicy policy used by {@link #GetAction(int)}
     * @param randomize if {@code true}, every Q-value is drawn uniformly from {@code [0, 0.1)};
     *     otherwise all Q-values start at {@code 0}
     * @param random generator used when {@code randomize} is {@code true}; if {@code null}, a
     *     fresh unseeded {@link Random} is created
     */
    public Sarsa(int states, int actions, IExplorationPolicy explorationPolicy, boolean randomize,
            Random random) {
        this.states = states;
        this.actions = actions;
        this.explorationPolicy = explorationPolicy;
        this.qvalues = new double[states][actions];
        if (randomize) {
            Random r = random != null ? random : new Random();
            for (double[] row : this.qvalues) {
                for (int a = 0; a < row.length; ++a) {
                    row[a] = r.nextDouble() / 10.0;
                }
            }
        }
    }

    /**
     * Chooses an action for the given state by passing that state's row of Q-values to the
     * exploration policy.
     *
     * @param state current state index
     * @return the action index returned by the exploration policy
     */
    public int GetAction(int state) {
        return this.explorationPolicy.ChooseAction(this.qvalues[state]);
    }

    /**
     * Applies the SARSA update for a non-terminal transition
     * {@code (s, a) --r--> (s', a')}:
     * {@code Q(s,a) <- (1 - alpha) * Q(s,a) + alpha * (r + gamma * Q(s',a'))}.
     *
     * @param previousState state {@code s} the action was taken in
     * @param previousAction action {@code a} taken in {@code s}
     * @param reward reward {@code r} received for the transition
     * @param nextState resulting state {@code s'}
     * @param nextAction action {@code a'} chosen in {@code s'}
     */
    public void UpdateState(int previousState, int previousAction, double reward, int nextState, int nextAction) {
        // Read Q(s', a') BEFORE touching Q(s, a): on a self-loop with the same action
        // (s' == s and a' == a) both name the same cell, and the target must use its old value.
        // Reading first also means an invalid (nextState, nextAction) fails before any mutation.
        double target = reward + this.discountFactor * this.qvalues[nextState][nextAction];
        this.blend(previousState, previousAction, target);
    }

    /**
     * Applies the SARSA update for a terminal transition (no successor state), using the reward
     * alone as the target: {@code Q(s,a) <- (1 - alpha) * Q(s,a) + alpha * r}.
     *
     * @param previousState state {@code s} the action was taken in
     * @param previousAction action {@code a} taken in {@code s}
     * @param reward reward {@code r} received for the transition
     */
    public void UpdateState(int previousState, int previousAction, double reward) {
        this.blend(previousState, previousAction, reward);
    }

    /** Moves Q(state, action) toward {@code target} by the learning rate. */
    private void blend(int state, int action, double target) {
        double[] row = this.qvalues[state];
        row[action] = row[action] * (1.0 - this.learningRate) + this.learningRate * target;
    }
}

