/*
 * Decompiled with CFR 0.152.
 */
package keel.Algorithms.Instance_Generation.LVQ;

import keel.Algorithms.Instance_Generation.Basic.Prototype;
import keel.Algorithms.Instance_Generation.Basic.PrototypeGenerationAlgorithm;
import keel.Algorithms.Instance_Generation.Basic.PrototypeSet;
import keel.Algorithms.Instance_Generation.LVQ.LVQ2;
import keel.Algorithms.Instance_Generation.utilities.KNN.KNN;
import keel.Algorithms.Instance_Generation.utilities.Parameters;

/**
 * LVQ3 prototype generator (Learning Vector Quantization, version 3).
 *
 * <p>For each training sample {@code x}, the two prototypes nearest to it are adapted:
 * <ul>
 *   <li>if both share the class of {@code x}, both are pulled toward {@code x} with the reduced
 *       rate {@code epsilon * alpha_0} ({@link #reward2});</li>
 *   <li>otherwise, if {@code x} falls inside the window between them and exactly one of them has
 *       the class of {@code x}, that one is pulled toward {@code x} ({@link #reward}) and the
 *       other one is pushed away ({@link #penalize}), both with rate {@code alpha_0}.</li>
 * </ul>
 */
public class LVQ3
extends LVQ2 {
    /** Default epsilon factor applied to alpha_0 in the same-class update case. */
    public static final double DEFAULT_EPSILON = 0.1;
    /** Factor scaling alpha_0 for same-class updates; {@code main} restricts it to [0, 1]. */
    protected double epsilon = DEFAULT_EPSILON;
    /**
     * Cached product {@code epsilon * alpha_0}, recomputed by every constructor.
     * NOTE(review): computed once at construction; if the parent class changes alpha_0 during
     * training, this cached value does not follow it — confirm against LVQ2/LVQ1.
     */
    protected double epsilonTimesAlpha_0 = 0.1;

    public LVQ3(PrototypeSet tDataSet, int iter, int nProt, double alpha_0, double windowWidth, double epsilon) {
        super(tDataSet, iter, nProt, alpha_0, windowWidth);
        this.initLVQ3(epsilon, alpha_0);
    }

    public LVQ3(PrototypeSet InitialSet, PrototypeSet tDataSet, int iter, int nProt, double alpha_0, double windowWidth, double epsilon) {
        super(InitialSet, tDataSet, iter, nProt, alpha_0, windowWidth);
        this.initLVQ3(epsilon, alpha_0);
    }

    public LVQ3(PrototypeSet tDataSet, int iter, double pcNprot, double alpha_0, double windowWidth, double epsilon) {
        super(tDataSet, iter, pcNprot, alpha_0, windowWidth);
        this.initLVQ3(epsilon, alpha_0);
    }

    /**
     * Builds the generator from a parameter reader; epsilon is read after the parent has
     * consumed its own parameters.
     */
    public LVQ3(PrototypeSet tDataSet, Parameters par) {
        super(tDataSet, par);
        this.initLVQ3(par.getNextAsDouble(), this.alpha_0);
    }

    /**
     * Shared constructor tail: sets the algorithm name and the epsilon-related rates.
     *
     * @param epsilon factor applied to the learning rate in the same-class case
     * @param alpha0 learning rate used to precompute {@code epsilon * alpha0}
     */
    private void initLVQ3(double epsilon, double alpha0) {
        this.algorithmName = "LVQ3";
        this.epsilon = epsilon;
        this.epsilonTimesAlpha_0 = epsilon * alpha0;
    }

    /**
     * Pulls prototype {@code m} toward sample {@code x} with the full rate:
     * {@code m = m + alpha_0 * (x - m)}.
     */
    @Override
    protected void reward(Prototype m, Prototype x) {
        Prototype term = x.sub(m);
        term = term.mul(this.alpha_0);
        m.set(m.add(term));
    }

    /**
     * Pulls prototype {@code m} toward sample {@code x} with the reduced rate used when both
     * nearest prototypes already share the class of {@code x}:
     * {@code m = m + epsilon * alpha_0 * (x - m)}.
     */
    protected void reward2(Prototype m, Prototype x) {
        Prototype term = x.sub(m);
        term = term.mul(this.epsilonTimesAlpha_0);
        m.set(m.add(term));
    }

    /**
     * Pushes prototype {@code m} away from sample {@code x} with the full rate:
     * {@code m = m - alpha_0 * (x - m)}.
     *
     * <p>In LVQ3 the wrong-class prototype of the window case is moved with the same rate as
     * the correct one; the epsilon-scaled rate is reserved for the same-class case.
     */
    @Override
    protected void penalize(Prototype m, Prototype x) {
        Prototype term = x.sub(m);
        term = term.mul(this.alpha_0);
        m.set(m.sub(term));
    }

    /**
     * Applies one LVQ3 correction step for sample {@code x} against the prototypes in
     * {@code tData}, mutating the two prototypes nearest to {@code x}.
     *
     * @param x training sample; its class is read from output 0
     * @param tData current prototype set
     */
    @Override
    protected void correct(Prototype x, PrototypeSet tData) {
        Prototype uno = tData.nearestTo(x);
        PrototypeSet dosTdata = tData.without(uno);
        Prototype dos = dosTdata.nearestTo(x);
        double clase_x = x.getOutput(0);
        boolean unoMatches = clase_x == uno.getOutput(0);
        boolean dosMatches = clase_x == dos.getOutput(0);
        if (unoMatches && dosMatches) {
            // Both nearest prototypes are correct: nudge both toward x with epsilon * alpha_0.
            // This case must not fall through to the window update below, which would
            // penalize a correct prototype.
            this.reward2(uno, x);
            this.reward2(dos, x);
            return;
        }
        if (!this.isInsideTheWindow(x, uno, dos)) {
            return;
        }
        // At most one of the two prototypes matches here (the both-match case returned above).
        if (unoMatches) {
            this.reward(uno, x);
            this.penalize(dos, x);
        } else if (dosMatches) {
            this.reward(dos, x);
            this.penalize(uno, x);
        }
    }

    /**
     * Command-line entry point: reads training and test sets, builds an LVQ3 generator and
     * reports its 1-NN accuracy on the test set.
     * NOTE(review): the usage text says "% of prototypes" but argument 4 is parsed as an integer
     * count (1 .. training.size() - 1) — confirm which is intended.
     */
    public static void main(String[] args) {
        Parameters.setUse("LVQ3", "<seed> <number of iterations> <% of prototypes> <alpha_0> <window width> <epsilon>");
        Parameters.assertBasicArgs(args);
        PrototypeSet training = PrototypeGenerationAlgorithm.readPrototypeSet(args[0]);
        PrototypeSet test = PrototypeGenerationAlgorithm.readPrototypeSet(args[1]);
        long seed = Parameters.assertExtendedArgAsInt(args, 2, "seed", 0.0, 9.223372036854776E18);
        int iter = Parameters.assertExtendedArgAsInt(args, 3, "number of iterations", 1.0, 2.147483647E9);
        int n_prot = Parameters.assertExtendedArgAsInt(args, 4, "number of prototypes", 1.0, training.size() - 1);
        double alpha_0 = Parameters.assertExtendedArgAsDouble(args, 5, "alpha_0", 0.0, 1.0);
        double wind = Parameters.assertExtendedArgAsDouble(args, 6, "window width", 0.0, 1.0);
        double epsilon = Parameters.assertExtendedArgAsDouble(args, 7, "epsilon", 0.0, 1.0);
        LVQ3.setSeed(seed);
        LVQ3 generator = new LVQ3(training, iter, n_prot, alpha_0, wind, epsilon);
        PrototypeSet resultingSet = generator.execute();
        int accuracy1NN = KNN.classficationAccuracy(resultingSet, test);
        generator.showResultsOfAccuracy(Parameters.getFileName(), accuracy1NN, test);
    }
}

