/*
 * Decompiled with CFR 0.152.
 */
package de.lmu.ifi.dbs.elki.algorithm.clustering.kmeans;

import de.lmu.ifi.dbs.elki.algorithm.AbstractNumberVectorDistanceBasedAlgorithm;
import de.lmu.ifi.dbs.elki.algorithm.DistanceBasedAlgorithm;
import de.lmu.ifi.dbs.elki.algorithm.clustering.kmeans.KMeans;
import de.lmu.ifi.dbs.elki.algorithm.clustering.kmeans.initialization.KMeansInitialization;
import de.lmu.ifi.dbs.elki.algorithm.clustering.kmeans.initialization.RandomlyChosenInitialMeans;
import de.lmu.ifi.dbs.elki.data.Cluster;
import de.lmu.ifi.dbs.elki.data.Clustering;
import de.lmu.ifi.dbs.elki.data.DoubleVector;
import de.lmu.ifi.dbs.elki.data.NumberVector;
import de.lmu.ifi.dbs.elki.data.SparseNumberVector;
import de.lmu.ifi.dbs.elki.data.model.KMeansModel;
import de.lmu.ifi.dbs.elki.data.model.Model;
import de.lmu.ifi.dbs.elki.data.type.CombinedTypeInformation;
import de.lmu.ifi.dbs.elki.data.type.TypeInformation;
import de.lmu.ifi.dbs.elki.data.type.TypeUtil;
import de.lmu.ifi.dbs.elki.database.Database;
import de.lmu.ifi.dbs.elki.database.datastore.DataStoreUtil;
import de.lmu.ifi.dbs.elki.database.datastore.WritableIntegerDataStore;
import de.lmu.ifi.dbs.elki.database.ids.DBIDIter;
import de.lmu.ifi.dbs.elki.database.ids.DBIDUtil;
import de.lmu.ifi.dbs.elki.database.ids.DBIDs;
import de.lmu.ifi.dbs.elki.database.ids.ModifiableDBIDs;
import de.lmu.ifi.dbs.elki.database.relation.Relation;
import de.lmu.ifi.dbs.elki.distance.distancefunction.NumberVectorDistanceFunction;
import de.lmu.ifi.dbs.elki.distance.distancefunction.PrimitiveDistanceFunction;
import de.lmu.ifi.dbs.elki.distance.distancefunction.minkowski.EuclideanDistanceFunction;
import de.lmu.ifi.dbs.elki.distance.distancefunction.minkowski.SquaredEuclideanDistanceFunction;
import de.lmu.ifi.dbs.elki.logging.Logging;
import de.lmu.ifi.dbs.elki.logging.progress.IndefiniteProgress;
import de.lmu.ifi.dbs.elki.logging.statistics.DoubleStatistic;
import de.lmu.ifi.dbs.elki.logging.statistics.Duration;
import de.lmu.ifi.dbs.elki.logging.statistics.LongStatistic;
import de.lmu.ifi.dbs.elki.math.linearalgebra.VMath;
import de.lmu.ifi.dbs.elki.utilities.datastructures.arrays.DoubleIntegerArrayQuickSort;
import de.lmu.ifi.dbs.elki.utilities.optionhandling.constraints.CommonConstraints;
import de.lmu.ifi.dbs.elki.utilities.optionhandling.constraints.ParameterConstraint;
import de.lmu.ifi.dbs.elki.utilities.optionhandling.parameterization.Parameterization;
import de.lmu.ifi.dbs.elki.utilities.optionhandling.parameters.Flag;
import de.lmu.ifi.dbs.elki.utilities.optionhandling.parameters.IntParameter;
import de.lmu.ifi.dbs.elki.utilities.optionhandling.parameters.ObjectParameter;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import net.jafama.FastMath;

public abstract class AbstractKMeans<V extends NumberVector, M extends Model>
extends AbstractNumberVectorDistanceBasedAlgorithm<V, Clustering<M>>
implements KMeans<V, M> {
    /** Number of means requested from the initializer (the "k" of k-means). */
    protected int k;
    /** Iteration limit; the constructor stores {@link Integer#MAX_VALUE} for non-positive arguments. */
    protected int maxiter;
    /** Strategy used to choose the initial means. */
    protected KMeansInitialization initializer;

    /**
     * Constructor.
     *
     * @param distanceFunction distance function, handed to the superclass
     * @param k number of means
     * @param maxiter iteration limit; any non-positive value is stored as
     *        {@link Integer#MAX_VALUE}
     * @param initializer strategy used to choose the initial means
     */
    public AbstractKMeans(NumberVectorDistanceFunction<? super V> distanceFunction, int k, int maxiter, KMeansInitialization initializer) {
        super(distanceFunction);
        this.k = k;
        this.initializer = initializer;
        if (maxiter > 0) {
            this.maxiter = maxiter;
        } else {
            this.maxiter = Integer.MAX_VALUE;
        }
    }

    /**
     * Input type restriction: the number-vector-field type combined with
     * whatever the configured distance function itself requires.
     *
     * @return single-element array holding the combined restriction
     */
    @Override
    public TypeInformation[] getInputTypeRestriction() {
        CombinedTypeInformation combined = new CombinedTypeInformation(
            TypeUtil.NUMBER_VECTOR_FIELD,
            this.getDistanceFunction().getInputTypeRestriction());
        return TypeUtil.array(combined);
    }

    /**
     * Asks the configured initializer for {@code k} starting means and logs
     * how long that took as a statistic.
     *
     * @param database database passed through to the initializer
     * @param relation data relation passed through to the initializer
     * @return the means returned by the initializer
     */
    protected double[][] initialMeans(Database database, Relation<V> relation) {
        // NOTE(review): concatenating a Class uses Class.toString(), so the
        // statistic key starts with "class "; kept as-is so existing keys stay stable.
        String timerKey = this.initializer.getClass() + ".time";
        Duration timer = this.getLogger().newDuration(timerKey).begin();
        double[][] chosen = this.initializer.chooseInitialMeans(
            database,
            (Relation<? extends NumberVector>)relation,
            this.k,
            (NumberVectorDistanceFunction<?>)this.getDistanceFunction());
        this.getLogger().statistics(timer.end());
        return chosen;
    }

    /**
     * Computes the mean of every cluster, dispatching to the sparse
     * implementation when the relation's data type is a sparse vector field
     * and to the dense implementation otherwise.
     *
     * @param clusters member ids of each cluster, indexed like {@code means}
     * @param means previous means; reused for clusters that are empty
     * @param relation data relation the ids refer to
     * @return newly computed means, one row per cluster
     */
    protected static double[][] means(List<? extends DBIDs> clusters, double[][] means, Relation<? extends NumberVector> relation) {
        if (TypeUtil.SPARSE_VECTOR_FIELD.isAssignableFromType(relation.getDataTypeInformation())) {
            // The type check above is what justifies this narrowing; without the
            // explicit cast the call to sparseMeans does not type-check.
            @SuppressWarnings("unchecked")
            Relation<? extends SparseNumberVector> sparse = (Relation<? extends SparseNumberVector>)relation;
            return AbstractKMeans.sparseMeans(clusters, means, sparse);
        }
        return AbstractKMeans.denseMeans(clusters, means, relation);
    }

    /**
     * Computes cluster means by summing every dimension of every member.
     * An empty cluster keeps (shares) the row of the previous means.
     *
     * @param clusters member ids of each cluster
     * @param means previous means, used for empty clusters
     * @param relation data relation the ids refer to
     * @return new means, one row per cluster
     */
    private static double[][] denseMeans(List<? extends DBIDs> clusters, double[][] means, Relation<? extends NumberVector> relation) {
        final double[][] result = new double[means.length][];
        for (int c = 0; c < result.length; ++c) {
            final DBIDs members = clusters.get(c);
            if (!members.isEmpty()) {
                DBIDIter it = members.iter();
                // Seed the accumulator with the first member, then add the rest.
                double[] acc = relation.get(it).toArray();
                for (it.advance(); it.valid(); it.advance()) {
                    AbstractKMeans.plusEquals(acc, relation.get(it));
                }
                result[c] = VMath.timesEquals(acc, 1.0 / (double)members.size());
            } else {
                result[c] = means[c];
            }
        }
        return result;
    }

    /**
     * Adds a vector to an accumulator in place: {@code sum[d] += vec[d]} for
     * every index of {@code sum}.
     *
     * @param sum accumulator, modified in place
     * @param vec vector to add; read at indices {@code 0 .. sum.length - 1}
     */
    public static void plusEquals(double[] sum, NumberVector vec) {
        final int dim = sum.length;
        for (int d = 0; d < dim; ++d) {
            sum[d] += vec.doubleValue(d);
        }
    }

    /**
     * Subtracts a vector from an accumulator in place: {@code sum[d] -= vec[d]}
     * for every index of {@code sum}.
     *
     * @param sum accumulator, modified in place
     * @param vec vector to subtract; read at indices {@code 0 .. sum.length - 1}
     */
    public static void minusEquals(double[] sum, NumberVector vec) {
        final int dim = sum.length;
        for (int d = 0; d < dim; ++d) {
            sum[d] -= vec.doubleValue(d);
        }
    }

    /**
     * Moves a vector from one accumulator to another in a single pass: each
     * component is added to {@code add} and subtracted from {@code sub}.
     *
     * @param add accumulator that gains the vector; its length bounds the loop
     * @param sub accumulator that loses the vector
     * @param vec vector being moved
     */
    public static void plusMinusEquals(double[] add, double[] sub, NumberVector vec) {
        for (int d = 0; d < add.length; ++d) {
            final double component = vec.doubleValue(d);
            add[d] += component;
            sub[d] -= component;
        }
    }

    /**
     * Computes cluster means for sparse vectors, visiting only the entries
     * each vector's own iterator reports. An empty cluster keeps (shares)
     * the row of the previous means.
     *
     * @param clusters member ids of each cluster
     * @param means previous means, used for empty clusters
     * @param relation sparse data relation the ids refer to
     * @return new means, one row per cluster
     */
    private static double[][] sparseMeans(List<? extends DBIDs> clusters, double[][] means, Relation<? extends SparseNumberVector> relation) {
        final int numClusters = means.length;
        final double[][] result = new double[numClusters][];
        for (int c = 0; c < numClusters; ++c) {
            final DBIDs members = clusters.get(c);
            if (!members.isEmpty()) {
                DBIDIter it = members.iter();
                // The first member is expanded to a dense array and used as accumulator.
                double[] acc = relation.get(it).toArray();
                for (it.advance(); it.valid(); it.advance()) {
                    SparseNumberVector vec = relation.get(it);
                    for (int pos = vec.iter(); vec.iterValid(pos); pos = vec.iterAdvance(pos)) {
                        acc[vec.iterDim(pos)] += vec.iterDoubleValue(pos);
                    }
                }
                result[c] = VMath.timesEquals(acc, 1.0 / (double)members.size());
            } else {
                result[c] = means[c];
            }
        }
        return result;
    }

    /**
     * For every center {@code i}, fills {@code cnum[i]} with the indices of
     * all other centers and sorts them together with their distances taken
     * from {@code cdist[i]} (presumably ascending, per
     * {@code DoubleIntegerArrayQuickSort}; confirm against that class).
     *
     * @param cdist square matrix of center-to-center distances; not modified
     * @param cnum output; each row needs room for {@code cdist.length - 1} indices
     */
    protected static void nearestMeans(double[][] cdist, int[][] cnum) {
        final int k = cdist.length;
        if (k == 0) {
            // Nothing to sort; also avoids allocating an array of size -1.
            return;
        }
        final double[] buf = new double[k - 1];
        for (int i = 0; i < k; ++i) {
            // Copy row i without its diagonal entry.
            System.arraycopy(cdist[i], 0, buf, 0, i);
            System.arraycopy(cdist[i], i + 1, buf, i, k - i - 1);
            // Buffer position j maps back to center j, skipping i itself.
            for (int j = 0; j < buf.length; ++j) {
                cnum[i][j] = j < i ? j : j + 1;
            }
            DoubleIntegerArrayQuickSort.sort(buf, cnum[i], k - 1);
        }
    }

    /**
     * Updates a mean in place after one vector was added to or removed from
     * its cluster: {@code mean += (vec - mean) * op / newsize} (per the VMath
     * helper names). Does nothing when the new size is zero.
     *
     * @param mean mean to update in place
     * @param vec vector that was added or removed
     * @param newsize cluster size after the change
     * @param op signed weight of the update (callers presumably pass +1 for an
     *        addition and -1 for a removal; confirm against callers)
     */
    protected static void incrementalUpdateMean(double[] mean, NumberVector vec, int newsize, double op) {
        if (newsize != 0) {
            double[] delta = VMath.minusEquals(vec.toArray(), mean);
            VMath.plusTimesEquals(mean, delta, op / (double)newsize);
        }
    }

    /**
     * Replaces the number of means. The value is stored as given; no range
     * check is performed here.
     *
     * @param k new number of means
     */
    @Override
    public void setK(int k) {
        this.k = k;
    }

    /**
     * Replaces the distance function stored in the inherited
     * {@code distanceFunction} field.
     *
     * @param distanceFunction new distance function
     */
    @Override
    public void setDistanceFunction(NumberVectorDistanceFunction<? super V> distanceFunction) {
        this.distanceFunction = distanceFunction;
    }

    /**
     * Replaces the strategy used to choose the initial means.
     *
     * @param init new initialization strategy
     */
    @Override
    public void setInitializer(KMeansInitialization init) {
        this.initializer = init;
    }

    public static abstract class Parameterizer<V extends NumberVector>
    extends AbstractNumberVectorDistanceBasedAlgorithm.Parameterizer<V> {
        /** Number of means, read by {@code getParameterK}. */
        protected int k;
        /** Iteration limit, read by {@code getParameterMaxIter}. */
        protected int maxiter;
        /** Initialization strategy, instantiated by {@code getParameterInitialization}. */
        protected KMeansInitialization initializer;
        /** Variance-statistics flag, set by {@code getParameterVarstat}; off by default. */
        protected boolean varstat = false;

        /**
         * Reads the options in this order: k, initialization, distance
         * function, maximum iterations. The variance-statistics flag is not
         * read here.
         *
         * @param config parameterization to read from
         */
        @Override
        protected void makeOptions(Parameterization config) {
            this.getParameterK(config);
            this.getParameterInitialization(config);
            this.getParameterDistanceFunction(config);
            this.getParameterMaxIter(config);
        }

        /**
         * Reads the number of means, constrained to be at least one. The
         * field is left untouched when the parameter cannot be grabbed.
         *
         * @param config parameterization to read from
         */
        protected void getParameterK(Parameterization config) {
            IntParameter param = (IntParameter)new IntParameter(KMeans.K_ID).addConstraint((ParameterConstraint)CommonConstraints.GREATER_EQUAL_ONE_INT);
            if (!config.grab(param)) {
                return;
            }
            this.k = (Integer)param.getValue();
        }

        /**
         * Reads and instantiates the distance function (the class literal
         * passed last is squared Euclidean). Nothing is logged for a missing
         * instance or for (squared) Euclidean distance. For any other choice
         * one of two warnings is logged, depending on whether this variant
         * needs a metric and the chosen function is not one.
         *
         * @param config parameterization to read from
         */
        protected void getParameterDistanceFunction(Parameterization config) {
            ObjectParameter param = new ObjectParameter(DistanceBasedAlgorithm.DISTANCE_FUNCTION_ID, (Class<?>)PrimitiveDistanceFunction.class, SquaredEuclideanDistanceFunction.class);
            if (!config.grab(param)) {
                return;
            }
            this.distanceFunction = (NumberVectorDistanceFunction)param.instantiateClass(config);
            if (this.distanceFunction == null) {
                return;
            }
            final boolean euclideanFamily = this.distanceFunction instanceof SquaredEuclideanDistanceFunction
                || this.distanceFunction instanceof EuclideanDistanceFunction;
            if (euclideanFamily) {
                // The expected choices; no warning needed.
                return;
            }
            final String message;
            if (this.needsMetric() && !this.distanceFunction.isMetric()) {
                message = "This k-means variants requires the triangle inequality, and thus should only be used with squared Euclidean distance!";
            } else {
                message = "k-means optimizes the sum of squares - it should be used with squared euclidean distance and may stop converging otherwise!";
            }
            Logging.getLogger(this.getClass()).warning(message);
        }

        /**
         * Whether this k-means variant requires a metric distance function.
         * Consulted by {@code getParameterDistanceFunction} to pick the
         * warning text; this base implementation answers {@code false}.
         *
         * @return {@code false}
         */
        protected boolean needsMetric() {
            return false;
        }

        /**
         * Reads and instantiates the initialization strategy (the class
         * literal passed last is {@code RandomlyChosenInitialMeans}). The
         * field is left untouched when the parameter cannot be grabbed.
         *
         * @param config parameterization to read from
         */
        protected void getParameterInitialization(Parameterization config) {
            ObjectParameter param = new ObjectParameter(KMeans.INIT_ID, (Class<?>)KMeansInitialization.class, RandomlyChosenInitialMeans.class);
            if (!config.grab(param)) {
                return;
            }
            this.initializer = (KMeansInitialization)param.instantiateClass(config);
        }

        /**
         * Reads the iteration limit, constrained to be non-negative; the
         * parameter is constructed with the value 0 (presumably its default).
         * The field is left untouched when the parameter cannot be grabbed.
         *
         * @param config parameterization to read from
         */
        protected void getParameterMaxIter(Parameterization config) {
            IntParameter param = (IntParameter)new IntParameter(KMeans.MAXITER_ID, 0).addConstraint((ParameterConstraint)CommonConstraints.GREATER_EQUAL_ZERO_INT);
            if (!config.grab(param)) {
                return;
            }
            this.maxiter = (Integer)param.getValue();
        }

        /**
         * Reads the variance-statistics flag. The field becomes {@code true}
         * only when the flag can be grabbed and is set; otherwise it becomes
         * {@code false}.
         *
         * @param config parameterization to read from
         */
        protected void getParameterVarstat(Parameterization config) {
            Flag flag = new Flag(KMeans.VARSTAT_ID);
            boolean enabled = false;
            if (config.grab(flag)) {
                enabled = flag.isTrue();
            }
            this.varstat = enabled;
        }

        /**
         * Builds the configured algorithm; concrete parameterizers supply
         * the implementation.
         *
         * @return the k-means algorithm instance
         */
        @Override
        protected abstract AbstractKMeans<V, ?> makeInstance();
    }

    protected static abstract class Instance {
        /** Current means, one row per cluster. */
        double[][] means;
        /** Member ids of each cluster, indexed like {@code means}. */
        protected List<ModifiableDBIDs> clusters;
        /** Cluster index last stored for each object; created with the value -1 (presumably the default for unassigned objects). */
        protected WritableIntegerDataStore assignment;
        /** Per-cluster sum of member-to-mean distances, refilled by {@code assignToNearestCluster}. */
        protected double[] varsum;
        /** Data relation being clustered. */
        protected Relation<? extends NumberVector> relation;
        /** Number of distance computations made through {@code distance}. */
        private long diststat = 0L;
        /** Distance function used for all distance computations. */
        private final NumberVectorDistanceFunction<?> df;
        /** Number of clusters, taken from {@code means.length} at construction. */
        protected final int k;
        /** Value of {@code df.isSquared()} captured at construction. */
        protected final boolean isSquared;
        /** Prefix for statistic keys: the runtime class name with "$Instance" removed. */
        protected String key;

        /**
         * Sets up the working state for one run: empty clusters, an
         * assignment store created with the value -1, and a zeroed
         * variance-sum array.
         *
         * @param relation data relation to cluster
         * @param df distance function
         * @param means initial means; their count defines {@code k}
         */
        public Instance(Relation<? extends NumberVector> relation, NumberVectorDistanceFunction<?> df, double[][] means) {
            this.relation = relation;
            this.df = df;
            this.isSquared = df.isSquared();
            this.means = means;
            this.k = means.length;
            this.varsum = new double[this.k];
            // Capacity hint per cluster: twice the average cluster size.
            final int capacityHint = (int)(2.0 * (double)relation.size() / (double)this.k);
            final List<ModifiableDBIDs> sets = new ArrayList<ModifiableDBIDs>(this.k);
            for (int c = 0; c < this.k; ++c) {
                sets.add(DBIDUtil.newHashSet(capacityHint));
            }
            this.clusters = sets;
            // NOTE(review): 3 looks like a combination of storage hint flags - confirm against DataStoreFactory.
            this.assignment = DataStoreUtil.makeIntegerStorage(relation.getDBIDs(), 3, -1);
            this.key = this.getClass().getName().replace("$Instance", "");
        }

        /**
         * Computes the distance between two vectors and counts the call in
         * the distance-computation statistic.
         *
         * @param x first vector
         * @param y second vector
         * @return distance according to the configured distance function
         */
        protected double distance(NumberVector x, NumberVector y) {
            // Counted before the call, so a throwing distance function is still counted.
            ++this.diststat;
            return this.df.distance(x, y);
        }

        /**
         * Runs {@code iterate} until it reports no change or the iteration
         * limit is reached, logging per-iteration and summary statistics.
         *
         * <p>The ".iterations" statistic is the number of {@code iterate}
         * calls actually made, including the final one that reported no
         * change. It is 0 when {@code maxiter} is not positive.
         *
         * @param maxiter maximum number of iterations to perform
         */
        protected void run(int maxiter) {
            Logging log = this.getLogger();
            IndefiniteProgress prog = log.isVerbose() ? new IndefiniteProgress("Iteration") : null;
            int iteration = 0;
            // Increment inside the loop so the counter never exceeds the number of
            // iterations performed, and never overflows when maxiter is Integer.MAX_VALUE.
            while (iteration < maxiter) {
                ++iteration;
                log.incrementProcessed(prog);
                int changed = this.iterate(iteration);
                if (changed == 0) {
                    break;
                }
                if (log.isStatistics()) {
                    log.statistics(new LongStatistic(this.key + "." + iteration + ".reassignments", Math.abs(changed)));
                    double s = VMath.sum(this.varsum);
                    // Variants that do not maintain varsum leave it at zero; skip the entry then.
                    if (s > 0.0) {
                        log.statistics(new DoubleStatistic(this.key + "." + iteration + ".variance-sum", s));
                    }
                }
            }
            log.setCompleted(prog);
            log.statistics(new LongStatistic(this.key + ".iterations", iteration));
            log.statistics(new LongStatistic(this.key + ".distance-computations", this.diststat));
        }

        /**
         * Performs one iteration of the algorithm.
         *
         * @param var1 iteration number; {@code run} passes 1 for the first call
         * @return change indicator: {@code run} stops when this is 0 and
         *         otherwise logs its absolute value as the reassignment count
         */
        protected abstract int iterate(int var1);

        /**
         * Turns per-cluster coordinate sums into means by dividing each sum
         * by the size of its cluster.
         *
         * <p>An empty cluster has no defined average, and dividing by its
         * size of zero would fill the row with NaN or infinity. Such a
         * cluster keeps its current mean instead: the row of {@code this.means}
         * is copied into {@code dst} unless {@code dst} already is that row.
         *
         * @param dst output means, one row per cluster, overwritten in place
         * @param sums per-cluster coordinate sums
         */
        protected void meansFromSums(double[][] dst, double[][] sums) {
            for (int i = 0; i < this.k; ++i) {
                final int size = this.clusters.get(i).size();
                if (size == 0) {
                    final double[] previous = this.means[i];
                    if (dst[i] != previous) {
                        System.arraycopy(previous, 0, dst[i], 0, previous.length);
                    }
                    continue;
                }
                VMath.overwriteTimes(dst[i], sums[i], 1.0 / (double)size);
            }
        }

        /**
         * Copies the first {@code k} rows of {@code src} into the existing
         * rows of {@code dst}.
         *
         * @param src rows to copy from
         * @param dst rows to copy into; each must be at least as long as its source row
         */
        protected void copyMeans(double[][] src, double[][] dst) {
            for (int row = 0; row < this.k; ++row) {
                final double[] from = src[row];
                System.arraycopy(from, 0, dst[row], 0, from.length);
            }
        }

        /**
         * Rebuilds all clusters from scratch by assigning every object to the
         * mean with the smallest distance (the lowest index wins ties), and
         * accumulates that distance into {@code varsum}.
         *
         * @return number of objects whose stored assignment changed
         */
        protected int assignToNearestCluster() {
            assert (this.k == this.means.length);
            Arrays.fill(this.varsum, 0.0);
            for (ModifiableDBIDs members : this.clusters) {
                members.clear();
            }
            int reassigned = 0;
            for (DBIDIter it = this.relation.iterDBIDs(); it.valid(); it.advance()) {
                NumberVector point = this.relation.get(it);
                int best = 0;
                double bestDist = Double.POSITIVE_INFINITY;
                for (int c = 0; c < this.k; ++c) {
                    double dist = this.distance(point, DoubleVector.wrap(this.means[c]));
                    // Strict comparison: ties and NaN never replace the current best.
                    if (dist < bestDist) {
                        best = c;
                        bestDist = dist;
                    }
                }
                this.varsum[best] += bestDist;
                this.clusters.get(best).add(it);
                // putInt returns the value stored before, which tells us whether the object moved.
                if (this.assignment.putInt(it, best) != best) {
                    ++reassigned;
                }
            }
            return reassigned;
        }

        /**
         * Recomputes half the pairwise distances between all means. For a
         * squared distance function the square root is taken first. The
         * symmetric result is stored in {@code cdist}, and {@code sep[i]}
         * receives the smallest half-distance from mean {@code i} to any
         * other mean (positive infinity when there is no other mean).
         * The diagonal of {@code cdist} is not written.
         *
         * @param sep output, one entry per mean
         * @param cdist output matrix of half-distances between means
         */
        protected void recomputeSeperation(double[] sep, double[][] cdist) {
            final int k = this.means.length;
            final boolean squared = this.df.isSquared();
            assert (sep.length == k);
            Arrays.fill(sep, Double.POSITIVE_INFINITY);
            for (int i = 1; i < k; ++i) {
                final DoubleVector mi = DoubleVector.wrap(this.means[i]);
                for (int j = 0; j < i; ++j) {
                    final double raw = this.distance(mi, DoubleVector.wrap(this.means[j]));
                    final double half = 0.5 * (squared ? FastMath.sqrt(raw) : raw);
                    cdist[j][i] = half;
                    cdist[i][j] = half;
                    if (half < sep[i]) {
                        sep[i] = half;
                    }
                    if (half < sep[j]) {
                        sep[j] = half;
                    }
                }
            }
        }

        /**
         * Measures how far each mean moved between two sets of means. For a
         * squared distance function the square root is taken first.
         *
         * @param means old means
         * @param newmeans new means, same count as {@code means}
         * @param dists output: distance moved by each mean
         * @return the largest distance moved, or 0 when there are no means
         */
        protected double movedDistance(double[][] means, double[][] newmeans, double[] dists) {
            assert (newmeans.length == means.length);
            assert (dists.length == means.length);
            final boolean squared = this.df.isSquared();
            double largest = 0.0;
            for (int c = 0; c < means.length; ++c) {
                final double raw = this.distance(DoubleVector.wrap(means[c]), DoubleVector.wrap(newmeans[c]));
                final double moved = squared ? FastMath.sqrt(raw) : raw;
                dists[c] = moved;
                if (moved > largest) {
                    largest = moved;
                }
            }
            return largest;
        }

        /**
         * Wraps the current clusters into a clustering result, using the
         * stored {@code varsum} entry of each cluster for its model. Empty
         * clusters are included, with a warning logged for each.
         *
         * @return clustering with one top-level cluster per mean
         */
        protected Clustering<KMeansModel> buildResult() {
            final Clustering<KMeansModel> clustering = new Clustering<KMeansModel>("k-Means Clustering", "kmeans-clustering");
            int c = 0;
            for (DBIDs ids : this.clusters) {
                if (ids.isEmpty()) {
                    this.getLogger().warning("K-Means produced an empty cluster - bad initialization?");
                }
                KMeansModel model = new KMeansModel(this.means[c], this.varsum[c]);
                clustering.addToplevelCluster(new Cluster<KMeansModel>(ids, model));
                ++c;
            }
            return clustering;
        }

        /**
         * Wraps the current clusters into a clustering result, leaving out
         * empty clusters. When {@code varstat} is set, each cluster's
         * variance sum is recomputed as the sum of distances from its
         * members to its mean; otherwise the models carry 0. With
         * {@code varstat} set and statistics logging enabled, the total
         * variance sum and the distance-computation count are logged.
         *
         * @param varstat whether to compute variance sums
         * @param relation relation used to look up member vectors
         * @return clustering with one top-level cluster per non-empty cluster
         */
        protected Clustering<KMeansModel> buildResult(boolean varstat, Relation<? extends NumberVector> relation) {
            final Clustering<KMeansModel> clustering = new Clustering<KMeansModel>("k-Means Clustering", "kmeans-clustering");
            double total = 0.0;
            for (int c = 0; c < this.clusters.size(); ++c) {
                DBIDs ids = this.clusters.get(c);
                if (!ids.isEmpty()) {
                    double clusterSum = 0.0;
                    if (varstat) {
                        DoubleVector center = DoubleVector.wrap(this.means[c]);
                        for (DBIDIter it = ids.iter(); it.valid(); it.advance()) {
                            clusterSum += this.distance(center, relation.get(it));
                        }
                        total += clusterSum;
                    }
                    KMeansModel model = new KMeansModel(this.means[c], clusterSum);
                    clustering.addToplevelCluster(new Cluster<KMeansModel>(ids, model));
                }
            }
            Logging log = this.getLogger();
            if (varstat && log.isStatistics()) {
                log.statistics(new DoubleStatistic(this.key + ".variance-sum", total));
                log.statistics(new LongStatistic(this.key + ".distance-computations", this.diststat));
            }
            return clustering;
        }

        /**
         * Asks the distance function whether it reports squared distances.
         * This queries the distance function on every call rather than
         * reading the {@code isSquared} field captured at construction.
         *
         * @return result of {@code df.isSquared()}
         */
        protected boolean isSquared() {
            return this.df.isSquared();
        }

        /**
         * Supplies the logger used for progress, warnings and statistics.
         *
         * @return logger of the concrete instance
         */
        abstract Logging getLogger();
    }
}

