/*
 * Decompiled with CFR 0.152.
 */
package de.lmu.ifi.dbs.elki.algorithm.projection;

import de.lmu.ifi.dbs.elki.algorithm.projection.AbstractProjectionAlgorithm;
import de.lmu.ifi.dbs.elki.algorithm.projection.AffinityMatrix;
import de.lmu.ifi.dbs.elki.algorithm.projection.AffinityMatrixBuilder;
import de.lmu.ifi.dbs.elki.algorithm.projection.PerplexityAffinityMatrixBuilder;
import de.lmu.ifi.dbs.elki.data.DoubleVector;
import de.lmu.ifi.dbs.elki.data.type.TypeInformation;
import de.lmu.ifi.dbs.elki.data.type.TypeUtil;
import de.lmu.ifi.dbs.elki.data.type.VectorFieldTypeInformation;
import de.lmu.ifi.dbs.elki.database.datastore.DataStoreFactory;
import de.lmu.ifi.dbs.elki.database.datastore.WritableDataStore;
import de.lmu.ifi.dbs.elki.database.ids.DBIDArrayIter;
import de.lmu.ifi.dbs.elki.database.ids.DBIDs;
import de.lmu.ifi.dbs.elki.database.relation.MaterializedRelation;
import de.lmu.ifi.dbs.elki.database.relation.Relation;
import de.lmu.ifi.dbs.elki.logging.Logging;
import de.lmu.ifi.dbs.elki.logging.progress.FiniteProgress;
import de.lmu.ifi.dbs.elki.logging.statistics.Duration;
import de.lmu.ifi.dbs.elki.logging.statistics.LongStatistic;
import de.lmu.ifi.dbs.elki.math.MathUtil;
import de.lmu.ifi.dbs.elki.utilities.documentation.Reference;
import de.lmu.ifi.dbs.elki.utilities.exceptions.AbortException;
import de.lmu.ifi.dbs.elki.utilities.optionhandling.AbstractParameterizer;
import de.lmu.ifi.dbs.elki.utilities.optionhandling.OptionID;
import de.lmu.ifi.dbs.elki.utilities.optionhandling.constraints.CommonConstraints;
import de.lmu.ifi.dbs.elki.utilities.optionhandling.constraints.ParameterConstraint;
import de.lmu.ifi.dbs.elki.utilities.optionhandling.parameterization.Parameterization;
import de.lmu.ifi.dbs.elki.utilities.optionhandling.parameters.DoubleParameter;
import de.lmu.ifi.dbs.elki.utilities.optionhandling.parameters.Flag;
import de.lmu.ifi.dbs.elki.utilities.optionhandling.parameters.IntParameter;
import de.lmu.ifi.dbs.elki.utilities.optionhandling.parameters.ObjectParameter;
import de.lmu.ifi.dbs.elki.utilities.optionhandling.parameters.RandomParameter;
import de.lmu.ifi.dbs.elki.utilities.random.RandomFactory;
import java.util.Arrays;
import java.util.Random;

/**
 * Stochastic Neighbor Embedding (SNE).
 * <p>
 * Projects the input into a {@code dim}-dimensional space. It runs gradient
 * descent on the difference between the input affinities {@code p_ij} and the
 * normalized Gaussian-kernel affinities {@code q_ij ~ exp(-||y_i - y_j||^2)} of
 * the projected points. The optimizer uses momentum, which switches from an
 * initial to a final value, together with per-coordinate adaptive gains.
 *
 * @param <O> Object type of the input relation
 */
@Reference(authors = "G. Hinton, S. Roweis", //
    title = "Stochastic Neighbor Embedding", //
    booktitle = "Advances in Neural Information Processing Systems 15", //
    url = "http://papers.nips.cc/paper/2276-stochastic-neighbor-embedding", //
    bibkey = "DBLP:conf/nips/HintonR02")
public class SNE<O> extends AbstractProjectionAlgorithm<Relation<DoubleVector>> {
    /** Class logger. */
    private static final Logging LOG = Logging.getLogger(SNE.class);

    /** Lower bound applied to the normalized q_ij in the gradient, avoiding degenerate values. */
    protected static final double MIN_QIJ = 1.0E-12;

    /** Standard deviation of the Gaussian used for the random initial coordinates. */
    protected static final double INITIAL_SOLUTION_SCALE = 1.0E-4;

    /** Lower bound for the per-coordinate adaptive gains. */
    protected static final double MIN_GAIN = 0.01;

    /** Builder for the input affinity matrix p_ij. */
    protected AffinityMatrixBuilder<? super O> affinity;

    /** Number of distance computations in the projected space (statistics only). */
    protected long projectedDistances;

    /** Output dimensionality. */
    protected int dim;

    /** Learning rate of the gradient descent. */
    protected double learningRate;

    /** Momentum used before {@link #momentumSwitch}. */
    protected double initialMomentum;

    /** Momentum used from {@link #momentumSwitch} on. */
    protected double finalMomentum;

    /** Iteration at which the momentum switches; set by the constructor to {@code iterations / 4}. */
    protected int momentumSwitch;

    /** Number of optimization iterations. */
    protected int iterations;

    /** Source of randomness for the initial solution. */
    protected RandomFactory random;

    /**
     * Constructor with default parameters: final momentum 0.8, learning rate
     * 200, 1000 iterations, and keep = true.
     *
     * @param affinity Affinity matrix builder
     * @param dim Output dimensionality
     * @param random Random generator factory
     */
    public SNE(AffinityMatrixBuilder<? super O> affinity, int dim, RandomFactory random) {
        this(affinity, dim, 0.8, 200.0, 1000, random, true);
    }

    /**
     * Full constructor.
     *
     * @param affinity Affinity matrix builder
     * @param dim Output dimensionality
     * @param finalMomentum Final momentum
     * @param learningRate Learning rate
     * @param iterations Number of iterations
     * @param random Random generator factory
     * @param keep Passed to the superclass constructor; presumably controls
     *        whether the input relation is kept (TODO confirm)
     */
    public SNE(AffinityMatrixBuilder<? super O> affinity, int dim, double finalMomentum, double learningRate, int iterations, RandomFactory random, boolean keep) {
        super(keep);
        this.affinity = affinity;
        this.dim = dim;
        this.iterations = iterations;
        this.learningRate = learningRate;
        // Heuristic: start with momentum 0.5, unless the final momentum is smaller than 0.6.
        this.initialMomentum = finalMomentum >= 0.6 ? 0.5 : 0.5 * finalMomentum;
        this.finalMomentum = finalMomentum;
        this.momentumSwitch = iterations / 4;
        this.random = random;
    }

    /**
     * Runs the projection.
     *
     * @param relation Input relation
     * @return New relation containing the projected vectors
     */
    public Relation<DoubleVector> run(Relation<O> relation) {
        AffinityMatrix pij = this.affinity.computeAffinityMatrix(relation, 1.0);
        int size = pij.size();
        double[][] sol = randomInitialSolution(size, this.dim, this.random.getSingleThreadedRandom());
        this.projectedDistances = 0L;
        this.optimizeSNE(pij, sol);
        LOG.statistics(new LongStatistic(this.getClass().getName() + ".projected-distances", this.projectedDistances));
        this.removePreviousRelation(relation);
        DBIDs ids = relation.getDBIDs();
        // 30 is a DataStoreFactory hint bitmask.
        // NOTE(review): consider using the named HINT_* constants instead.
        WritableDataStore<DoubleVector> proj = DataStoreFactory.FACTORY.makeStorage(ids, 30, DoubleVector.class);
        VectorFieldTypeInformation<DoubleVector> otype = new VectorFieldTypeInformation<DoubleVector>(DoubleVector.FACTORY, this.dim);
        // Row i of the solution belongs to the object at offset i of the affinity matrix.
        for (DBIDArrayIter it = pij.iterDBIDs(); it.valid(); it.advance()) {
            proj.put(it, DoubleVector.wrap(sol[it.getOffset()]));
        }
        return new MaterializedRelation<DoubleVector>("SNE", "SNE", otype, proj, ids);
    }

    /**
     * Generates a random initial solution.
     * <p>
     * Coordinates are drawn row by row from a zero-mean Gaussian with standard
     * deviation {@link #INITIAL_SOLUTION_SCALE}.
     *
     * @param size Number of points
     * @param dim Output dimensionality
     * @param random Random generator
     * @return Initial solution as a {@code size x dim} array
     */
    protected static double[][] randomInitialSolution(int size, int dim, Random random) {
        double[][] sol = new double[size][dim];
        for (int i = 0; i < size; ++i) {
            for (int j = 0; j < dim; ++j) {
                sol[i][j] = random.nextGaussian() * INITIAL_SOLUTION_SCALE;
            }
        }
        return sol;
    }

    /**
     * Runs the iterative gradient descent, updating {@code sol} in place.
     * <p>
     * Each point uses three consecutive slices of {@code dim} entries in the
     * {@code meta} array: the gradient, the previous update (movement), and the
     * adaptive gains. Gains are initialized to 1.
     *
     * @param pij Input affinity matrix
     * @param sol Current solution, modified in place
     * @throws AbortException if the working array would exceed the Java array size limit
     */
    protected void optimizeSNE(AffinityMatrix pij, double[][] sol) {
        final int size = pij.size();
        if ((long) size * 3L * (long) this.dim > 0x7FFFFFFAL) {
            throw new AbortException("Memory exceeds Java array size limit.");
        }
        final int dim3 = this.dim * 3;
        final double[] meta = new double[size * dim3];
        // Initialize only the gain slice of every point to 1.
        for (int off = 2 * this.dim; off < meta.length; off += dim3) {
            Arrays.fill(meta, off, off + this.dim, 1.0);
        }
        final double[][] qij = new double[size][size];
        FiniteProgress prog = LOG.isVerbose() ? new FiniteProgress("Iterative Optimization", this.iterations, LOG) : null;
        Duration timer = LOG.isStatistics() ? LOG.newDuration(this.getClass().getName() + ".runtime.optimization").begin() : null;
        for (int it = 0; it < this.iterations; ++it) {
            double qij_sum = this.computeQij(qij, sol);
            this.computeGradient(pij, qij, 1.0 / qij_sum, sol, meta);
            this.updateSolution(sol, meta, it);
            LOG.incrementProcessed(prog);
        }
        LOG.ensureCompleted(prog);
        if (timer != null) {
            LOG.statistics(timer.end());
        }
    }

    /**
     * Computes the unnormalized affinities {@code exp(-||y_i - y_j||^2)} in the
     * projected space, for all pairs {@code i != j}.
     * <p>
     * Only the lower triangle is computed and mirrored into the upper triangle.
     * The diagonal is not written.
     *
     * @param qij Output matrix of affinities, filled in place
     * @param solution Current solution
     * @return Sum of all off-diagonal affinities (both triangles)
     */
    protected double computeQij(double[][] qij, double[][] solution) {
        double qij_sum = 0.0;
        for (int i = 1; i < qij.length; ++i) {
            final double[] qij_i = qij[i];
            final double[] vi = solution[i];
            for (int j = 0; j < i; ++j) {
                final double d = MathUtil.exp(-this.sqDist(vi, solution[j]));
                qij[j][i] = d;
                qij_i[j] = d;
                qij_sum += d;
            }
        }
        // Every pair was counted once but appears twice in the symmetric matrix.
        return qij_sum * 2.0;
    }

    /**
     * Squared Euclidean distance. Also increments {@link #projectedDistances}.
     *
     * @param v1 First vector
     * @param v2 Second vector
     * @return Squared Euclidean distance
     */
    protected double sqDist(double[] v1, double[] v2) {
        assert (v1.length == v2.length) : "Lengths do not agree: " + v1.length + " " + v2.length;
        double sum = 0.0;
        for (int i = 0; i < v1.length; ++i) {
            double diff = v1[i] - v2[i];
            sum += diff * diff;
        }
        ++this.projectedDistances;
        return sum;
    }

    /**
     * Computes the SNE gradient
     * {@code 4 * sum_j (p_ij - q_ij) * (y_i - y_j)} into the gradient slice of
     * {@code meta} for each point.
     *
     * @param pij Input affinities
     * @param qij Unnormalized projected affinities
     * @param qij_isum Inverse of the sum of all {@code qij} (normalization factor)
     * @param sol Current solution
     * @param meta Working array; only the gradient slices are overwritten
     */
    protected void computeGradient(AffinityMatrix pij, double[][] qij, double qij_isum, double[][] sol, double[] meta) {
        final int dim3 = this.dim * 3;
        final int size = pij.size();
        for (int i = 0, off = 0; i < size; ++i, off += dim3) {
            final double[] sol_i = sol[i];
            final double[] qij_i = qij[i];
            Arrays.fill(meta, off, off + this.dim, 0.0); // Clear the gradient only.
            for (int j = 0; j < size; ++j) {
                if (i == j) {
                    continue;
                }
                final double[] sol_j = sol[j];
                // Normalize, and bound from below to avoid degenerate values.
                final double q = MathUtil.max(qij_i[j] * qij_isum, MIN_QIJ);
                final double a = 4.0 * (pij.get(i, j) - q);
                for (int k = 0; k < this.dim; ++k) {
                    meta[off + k] += a * (sol_i[k] - sol_j[k]);
                }
            }
        }
    }

    /**
     * Applies one gradient descent step with momentum and adaptive gains.
     * <p>
     * A gain grows by 0.2 when the gradient sign differs from the sign of the
     * previous movement. Otherwise it shrinks by a factor of 0.8. Gains never
     * fall below {@link #MIN_GAIN}.
     *
     * @param sol Solution, updated in place
     * @param meta Working array (gradient, movement, gain slices)
     * @param it Current iteration number, used to select the momentum
     */
    protected void updateSolution(double[][] sol, double[] meta, int it) {
        final double mom = it < this.momentumSwitch && this.initialMomentum < this.finalMomentum ? this.initialMomentum : this.finalMomentum;
        final int dim3 = this.dim * 3;
        for (int i = 0, off = 0; i < sol.length; ++i, off += dim3) {
            final double[] sol_i = sol[i];
            for (int k = 0; k < this.dim; ++k) {
                // Indexes of this coordinate within the three meta slices.
                final int gradk = off + k;
                final int movk = gradk + this.dim;
                final int gaink = movk + this.dim;
                // The gain update uses the movement from the previous iteration.
                final boolean signsDiffer = (meta[gradk] > 0.0) != (meta[movk] > 0.0);
                meta[gaink] = MathUtil.max(signsDiffer ? meta[gaink] + 0.2 : meta[gaink] * 0.8, MIN_GAIN);
                meta[movk] *= mom; // Dampen the previous movement.
                meta[movk] -= this.learningRate * meta[gradk] * meta[gaink]; // Learning step.
                sol_i[k] += meta[movk];
            }
        }
    }

    @Override
    public TypeInformation[] getInputTypeRestriction() {
        return TypeUtil.array(this.affinity.getInputTypeRestriction());
    }

    @Override
    protected Logging getLogger() {
        return LOG;
    }

    /**
     * Parameterization class.
     *
     * @param <O> Object type
     */
    public static class Parameterizer<O> extends AbstractParameterizer {
        /** Affinity matrix builder. */
        public static final OptionID AFFINITY_ID = new OptionID("tsne.affinity", "Affinity matrix builder.");

        /** Output dimensionality. */
        public static final OptionID DIM_ID = new OptionID("tsne.dim", "Output dimensionality.");

        /** Final momentum. */
        public static final OptionID MOMENTUM_ID = new OptionID("tsne.momentum", "The final momentum to use.");

        /** Learning rate. */
        public static final OptionID LEARNING_RATE_ID = new OptionID("tsne.learningrate", "Learning rate of the method.");

        /** Number of iterations. */
        public static final OptionID ITER_ID = new OptionID("tsne.iter", "Number of iterations to perform.");

        /** Random generator seed. */
        public static final OptionID RANDOM_ID = new OptionID("tsne.seed", "Random generator seed");

        /** Affinity matrix builder. */
        protected AffinityMatrixBuilder<? super O> affinity;

        /** Output dimensionality. */
        protected int dim;

        /** Learning rate. */
        protected double learningRate;

        /** Final momentum. */
        protected double finalMomentum;

        /** Number of iterations. */
        protected int iterations;

        /** Random generator factory. */
        protected RandomFactory random;

        /** Value of the keep flag, passed to the algorithm constructor. */
        protected boolean keep;

        @Override
        protected void makeOptions(Parameterization config) {
            super.makeOptions(config);
            // Typed parameter: avoids the raw ObjectParameter and the unchecked cast.
            ObjectParameter<AffinityMatrixBuilder<? super O>> affinityP = new ObjectParameter<>(AFFINITY_ID, AffinityMatrixBuilder.class, this.getDefaultAffinity());
            if (config.grab(affinityP)) {
                this.affinity = affinityP.instantiateClass(config);
            }
            IntParameter dimP = new IntParameter(DIM_ID) //
                .setDefaultValue(2) //
                .addConstraint(CommonConstraints.GREATER_EQUAL_ONE_INT);
            if (config.grab(dimP)) {
                this.dim = dimP.intValue();
            }
            DoubleParameter momentumP = new DoubleParameter(MOMENTUM_ID) //
                .setDefaultValue(0.8) //
                .addConstraint(CommonConstraints.GREATER_THAN_ZERO_DOUBLE) //
                .addConstraint(CommonConstraints.LESS_EQUAL_ONE_DOUBLE);
            if (config.grab(momentumP)) {
                this.finalMomentum = momentumP.doubleValue();
            }
            DoubleParameter learningRateP = new DoubleParameter(LEARNING_RATE_ID) //
                .setDefaultValue(200.0) //
                .addConstraint(CommonConstraints.GREATER_THAN_ZERO_DOUBLE);
            if (config.grab(learningRateP)) {
                this.learningRate = learningRateP.doubleValue();
            }
            IntParameter maxiterP = new IntParameter(ITER_ID) //
                .setDefaultValue(1000) //
                .addConstraint(CommonConstraints.GREATER_EQUAL_ZERO_INT);
            if (config.grab(maxiterP)) {
                this.iterations = maxiterP.intValue();
            }
            RandomParameter randP = new RandomParameter(RANDOM_ID);
            if (config.grab(randP)) {
                this.random = randP.getValue();
            }
            Flag keepF = new Flag(AbstractProjectionAlgorithm.KEEP_ID);
            this.keep = config.grab(keepF) && keepF.isTrue();
        }

        /**
         * Default affinity matrix builder class.
         *
         * @return {@link PerplexityAffinityMatrixBuilder}
         */
        protected Class<?> getDefaultAffinity() {
            return PerplexityAffinityMatrixBuilder.class;
        }

        @Override
        protected SNE<O> makeInstance() {
            return new SNE<O>(this.affinity, this.dim, this.finalMomentum, this.learningRate, this.iterations, this.random, this.keep);
        }
    }
}

