/*
 * Decompiled with CFR 0.152.
 */
package de.lmu.ifi.dbs.elki.algorithm.clustering;

import de.lmu.ifi.dbs.elki.algorithm.AbstractDistanceBasedAlgorithm;
import de.lmu.ifi.dbs.elki.algorithm.clustering.ClusteringAlgorithm;
import de.lmu.ifi.dbs.elki.data.Cluster;
import de.lmu.ifi.dbs.elki.data.Clustering;
import de.lmu.ifi.dbs.elki.data.NumberVector;
import de.lmu.ifi.dbs.elki.data.model.MeanModel;
import de.lmu.ifi.dbs.elki.data.type.TypeInformation;
import de.lmu.ifi.dbs.elki.data.type.TypeUtil;
import de.lmu.ifi.dbs.elki.database.Database;
import de.lmu.ifi.dbs.elki.database.ids.ArrayModifiableDBIDs;
import de.lmu.ifi.dbs.elki.database.ids.DBIDIter;
import de.lmu.ifi.dbs.elki.database.ids.DBIDUtil;
import de.lmu.ifi.dbs.elki.database.ids.DBIDs;
import de.lmu.ifi.dbs.elki.database.ids.DoubleDBIDList;
import de.lmu.ifi.dbs.elki.database.ids.DoubleDBIDListIter;
import de.lmu.ifi.dbs.elki.database.ids.ModifiableDBIDs;
import de.lmu.ifi.dbs.elki.database.query.distance.DistanceQuery;
import de.lmu.ifi.dbs.elki.database.query.range.RangeQuery;
import de.lmu.ifi.dbs.elki.database.relation.Relation;
import de.lmu.ifi.dbs.elki.database.relation.RelationUtil;
import de.lmu.ifi.dbs.elki.distance.distancefunction.DistanceFunction;
import de.lmu.ifi.dbs.elki.logging.Logging;
import de.lmu.ifi.dbs.elki.logging.progress.FiniteProgress;
import de.lmu.ifi.dbs.elki.math.linearalgebra.Centroid;
import de.lmu.ifi.dbs.elki.math.statistics.kernelfunctions.EpanechnikovKernelDensityFunction;
import de.lmu.ifi.dbs.elki.math.statistics.kernelfunctions.KernelDensityFunction;
import de.lmu.ifi.dbs.elki.utilities.documentation.Reference;
import de.lmu.ifi.dbs.elki.utilities.optionhandling.OptionID;
import de.lmu.ifi.dbs.elki.utilities.optionhandling.parameterization.Parameterization;
import de.lmu.ifi.dbs.elki.utilities.optionhandling.parameters.DoubleParameter;
import de.lmu.ifi.dbs.elki.utilities.optionhandling.parameters.ObjectParameter;
import de.lmu.ifi.dbs.elki.utilities.pairs.Pair;
import java.util.ArrayList;

@Reference(authors="Y. Cheng", title="Mean shift, mode seeking, and clustering", booktitle="IEEE Transactions on Pattern Analysis and Machine Intelligence 17-8", url="https://doi.org/10.1109/34.400568", bibkey="DBLP:journals/pami/Cheng95")
public class NaiveMeanShiftClustering<V extends NumberVector>
extends AbstractDistanceBasedAlgorithm<V, Clustering<MeanModel>>
implements ClusteringAlgorithm<Clustering<MeanModel>> {
    private static final Logging LOG = Logging.getLogger(NaiveMeanShiftClustering.class);
    KernelDensityFunction kernel = EpanechnikovKernelDensityFunction.KERNEL;
    double bandwidth;
    static final int MAXITER = 1000;

    public NaiveMeanShiftClustering(DistanceFunction<? super V> distanceFunction, KernelDensityFunction kernel, double range) {
        super(distanceFunction);
        this.kernel = kernel;
        this.bandwidth = range;
    }

    public Clustering<MeanModel> run(Database database, Relation<V> relation) {
        DistanceQuery<V> distq = database.getDistanceQuery(relation, this.getDistanceFunction(), new Object[0]);
        RangeQuery<V> rangeq = database.getRangeQuery(distq, new Object[0]);
        NumberVector.Factory<V> factory = RelationUtil.getNumberVectorFactory(relation);
        int dim = RelationUtil.dimensionality(relation);
        double threshold = this.bandwidth * 1.0E-10;
        ArrayList<Pair<Object, ArrayModifiableDBIDs>> clusters = new ArrayList<Pair<Object, ArrayModifiableDBIDs>>();
        ArrayModifiableDBIDs noise = DBIDUtil.newArray();
        FiniteProgress prog = LOG.isVerbose() ? new FiniteProgress("Mean-shift clustering", relation.size(), LOG) : null;
        DBIDIter iter = relation.iterDBIDs();
        while (iter.valid()) {
            Object position = (NumberVector)relation.get(iter);
            int n = 1;
            while (true) {
                boolean okay;
                Object newvec = null;
                DoubleDBIDList neigh = rangeq.getRangeForObject(position, this.bandwidth);
                boolean bl = okay = neigh.size() > 1 || neigh.size() >= 1 && n > 1;
                if (okay) {
                    Centroid newpos = new Centroid(dim);
                    DoubleDBIDListIter niter = neigh.iter();
                    while (niter.valid()) {
                        double d = this.kernel.density(niter.doubleValue() / this.bandwidth);
                        newpos.put((NumberVector)relation.get(niter), d);
                        niter.advance();
                    }
                    newvec = factory.newNumberVector(newpos.getArrayRef());
                }
                if (!okay) {
                    noise.add(iter);
                    break;
                }
                double bestd = Double.POSITIVE_INFINITY;
                Pair bestp = null;
                for (Pair pair : clusters) {
                    double merged = distq.distance(newvec, pair.first);
                    if (!(merged < bestd)) continue;
                    bestd = merged;
                    bestp = pair;
                }
                double delta = distq.distance(position, newvec);
                if (bestd < 10.0 * threshold || bestd * 2.0 < delta) {
                    assert (bestp != null);
                    ((ModifiableDBIDs)bestp.second).add(iter);
                    break;
                }
                if (Double.isNaN(delta)) {
                    LOG.warning("Encountered NaN distance. Invalid center vector? " + newvec.toString());
                    break;
                }
                if (n == 1000 || delta < threshold) {
                    if (n == 1000) {
                        LOG.warning("No convergence after 1000 iterations. Distance: " + delta);
                    }
                    if (LOG.isDebuggingFine()) {
                        LOG.debugFine("New cluster:" + newvec + " delta: " + delta + " threshold: " + threshold + " bestd: " + bestd);
                    }
                    ArrayModifiableDBIDs cids = DBIDUtil.newArray(1);
                    cids.add(iter);
                    clusters.add(new Pair<Object, ArrayModifiableDBIDs>(newvec, cids));
                    break;
                }
                position = newvec;
                ++n;
            }
            LOG.incrementProcessed(prog);
            iter.advance();
        }
        LOG.ensureCompleted(prog);
        ArrayList cs = new ArrayList(clusters.size());
        for (Pair pair : clusters) {
            cs.add(new Cluster<MeanModel>((DBIDs)pair.second, new MeanModel(((NumberVector)pair.first).toArray())));
        }
        if (noise.size() > 0) {
            cs.add(new Cluster((DBIDs)noise, true));
        }
        return new Clustering<MeanModel>("Mean-shift Clustering", "mean-shift-clustering", cs);
    }

    @Override
    public TypeInformation[] getInputTypeRestriction() {
        return TypeUtil.array(TypeUtil.NUMBER_VECTOR_FIELD);
    }

    @Override
    protected Logging getLogger() {
        return LOG;
    }

    public static class Parameterizer<V extends NumberVector>
    extends AbstractDistanceBasedAlgorithm.Parameterizer<V> {
        public static final OptionID KERNEL_ID = new OptionID("meanshift.kernel", "Kernel function to use with mean-shift clustering.");
        public static final OptionID RANGE_ID = new OptionID("meanshift.kernel-bandwidth", "Range of the kernel to use (aka: radius, bandwidth).");
        KernelDensityFunction kernel = EpanechnikovKernelDensityFunction.KERNEL;
        double range;

        @Override
        protected void makeOptions(Parameterization config) {
            DoubleParameter rangeP;
            super.makeOptions(config);
            ObjectParameter kernelP = new ObjectParameter(KERNEL_ID, (Class<?>)KernelDensityFunction.class, EpanechnikovKernelDensityFunction.class);
            if (config.grab(kernelP)) {
                this.kernel = (KernelDensityFunction)kernelP.instantiateClass(config);
            }
            if (config.grab(rangeP = new DoubleParameter(RANGE_ID))) {
                this.range = (Double)rangeP.getValue();
            }
        }

        @Override
        protected NaiveMeanShiftClustering<V> makeInstance() {
            return new NaiveMeanShiftClustering(this.distanceFunction, this.kernel, this.range);
        }
    }
}

