/*
 * Decompiled with CFR 0.152.
 */
package de.lmu.ifi.dbs.elki.algorithm.classification;

import de.lmu.ifi.dbs.elki.algorithm.AbstractAlgorithm;
import de.lmu.ifi.dbs.elki.algorithm.DistanceBasedAlgorithm;
import de.lmu.ifi.dbs.elki.algorithm.classification.Classifier;
import de.lmu.ifi.dbs.elki.data.ClassLabel;
import de.lmu.ifi.dbs.elki.data.type.TypeInformation;
import de.lmu.ifi.dbs.elki.data.type.TypeUtil;
import de.lmu.ifi.dbs.elki.database.Database;
import de.lmu.ifi.dbs.elki.database.ids.DoubleDBIDListIter;
import de.lmu.ifi.dbs.elki.database.ids.KNNList;
import de.lmu.ifi.dbs.elki.database.query.distance.DistanceQuery;
import de.lmu.ifi.dbs.elki.database.query.knn.KNNQuery;
import de.lmu.ifi.dbs.elki.database.relation.Relation;
import de.lmu.ifi.dbs.elki.distance.distancefunction.DistanceFunction;
import de.lmu.ifi.dbs.elki.distance.distancefunction.minkowski.EuclideanDistanceFunction;
import de.lmu.ifi.dbs.elki.logging.Logging;
import de.lmu.ifi.dbs.elki.result.Result;
import de.lmu.ifi.dbs.elki.utilities.Priority;
import de.lmu.ifi.dbs.elki.utilities.documentation.Description;
import de.lmu.ifi.dbs.elki.utilities.documentation.Title;
import de.lmu.ifi.dbs.elki.utilities.exceptions.AbortException;
import de.lmu.ifi.dbs.elki.utilities.optionhandling.AbstractParameterizer;
import de.lmu.ifi.dbs.elki.utilities.optionhandling.OptionID;
import de.lmu.ifi.dbs.elki.utilities.optionhandling.constraints.CommonConstraints;
import de.lmu.ifi.dbs.elki.utilities.optionhandling.constraints.ParameterConstraint;
import de.lmu.ifi.dbs.elki.utilities.optionhandling.parameterization.Parameterization;
import de.lmu.ifi.dbs.elki.utilities.optionhandling.parameters.IntParameter;
import de.lmu.ifi.dbs.elki.utilities.optionhandling.parameters.ObjectParameter;
import it.unimi.dsi.fastutil.objects.Object2IntMap;
import it.unimi.dsi.fastutil.objects.Object2IntOpenHashMap;
import it.unimi.dsi.fastutil.objects.ObjectIterator;
import java.util.ArrayList;
import java.util.Collections;

/**
 * Lazy k-nearest-neighbor classifier.
 * <p>
 * Training only stores a kNN query over the database relation and the label
 * relation; classification runs a kNN search for the query object and takes a
 * majority vote over the neighbors' labels.
 *
 * @param <O> object type
 */
@Title(value="kNN-classifier")
@Description(value="Lazy classifier classifies a given instance to the majority class of the k-nearest neighbors.")
@Priority(value=100)
public class KNNClassifier<O>
extends AbstractAlgorithm<Result>
implements DistanceBasedAlgorithm<O>,
Classifier<O> {
    /** Class logger. */
    private static final Logging LOG = Logging.getLogger(KNNClassifier.class);

    /** Number of neighbors to consult. */
    protected int k;

    /** kNN query built in {@link #buildClassifier}. */
    protected KNNQuery<O> knnq;

    /** Relation providing the class label of each database object. */
    protected Relation<? extends ClassLabel> labelrep;

    /** Distance function used for the kNN search. */
    protected DistanceFunction<? super O> distanceFunction;

    /**
     * Constructor.
     *
     * @param distanceFunction distance function for the kNN search
     * @param k number of neighbors to use
     */
    public KNNClassifier(DistanceFunction<? super O> distanceFunction, int k) {
        this.distanceFunction = distanceFunction;
        this.k = k;
    }

    /**
     * "Trains" the classifier by preparing a kNN query on the database and
     * remembering the label relation. No model is computed.
     *
     * @param database database containing the training objects
     * @param labels class labels of the training objects
     */
    @Override
    public void buildClassifier(Database database, Relation<? extends ClassLabel> labels) {
        Relation<O> relation = database.getRelation(this.getDistanceFunction().getInputTypeRestriction());
        DistanceQuery<O> distanceQuery = database.getDistanceQuery(relation, this.getDistanceFunction());
        // k is passed as an optimizer hint for choosing the kNN query implementation.
        this.knnq = database.getKNNQuery(distanceQuery, this.k);
        this.labelrep = labels;
    }

    /**
     * Classifies an instance by majority vote of its k nearest neighbors.
     * <p>
     * Ties between labels with the same number of votes are broken
     * deterministically in favor of the label of the neighbor that appears
     * first in the kNN result (presumably the closest one, as ELKI kNN lists are
     * ordered by distance).
     *
     * @param instance object to classify
     * @return the predicted label, or {@code null} if no neighbors were found
     */
    @Override
    public ClassLabel classify(O instance) {
        KNNList neighbors = this.knnq.getKNNForObject(instance, this.k);
        Object2IntOpenHashMap<ClassLabel> count = new Object2IntOpenHashMap<ClassLabel>();
        int bestoccur = 0;
        for (DoubleDBIDListIter it = neighbors.iter(); it.valid(); it.advance()) {
            // addTo returns the previous count (default 0), so +1 is the updated count.
            int occur = count.addTo(this.labelrep.get(it), 1) + 1;
            bestoccur = Math.max(bestoccur, occur);
        }
        // Second pass in neighbor order: the first label reaching the maximum wins,
        // instead of whichever label the hash map happens to iterate first.
        for (DoubleDBIDListIter it = neighbors.iter(); it.valid(); it.advance()) {
            ClassLabel label = this.labelrep.get(it);
            if (count.getInt(label) == bestoccur) {
                return label;
            }
        }
        return null;
    }

    /**
     * Estimates class probabilities as the fraction of the k nearest neighbors
     * carrying each label.
     *
     * @param instance object to classify
     * @param labels candidate labels; must be sorted (it is searched with
     *        {@link Collections#binarySearch}). Neighbor labels not contained in
     *        this list are counted in the denominator but not in any entry.
     * @return one probability per entry of {@code labels}; all zeros if the
     *         kNN search returned no neighbors
     */
    public double[] classProbabilities(O instance, ArrayList<ClassLabel> labels) {
        int[] occurrences = new int[labels.size()];
        KNNList neighbors = this.knnq.getKNNForObject(instance, this.k);
        for (DoubleDBIDListIter it = neighbors.iter(); it.valid(); it.advance()) {
            int index = Collections.binarySearch(labels, this.labelrep.get(it));
            if (index >= 0) {
                occurrences[index]++;
            }
        }
        double[] distribution = new double[labels.size()];
        int size = neighbors.size();
        if (size == 0) {
            // Avoid 0/0 = NaN when no neighbors are available.
            return distribution;
        }
        for (int i = 0; i < distribution.length; ++i) {
            distribution[i] = (double) occurrences[i] / (double) size;
        }
        return distribution;
    }

    /**
     * Describes the learned model.
     *
     * @return a fixed description, since a lazy learner builds no model
     */
    @Override
    public String model() {
        return "lazy learner - provides no model";
    }

    /**
     * Not supported: classifiers must be trained and then queried.
     *
     * @param database ignored
     * @return never returns normally
     * @throws AbortException always
     * @deprecated use {@link #buildClassifier} and {@link #classify} instead
     */
    @Override
    @Deprecated
    public Result run(Database database) throws IllegalStateException {
        throw new AbortException("Classifiers cannot auto-run on a database, but need to be trained and can then predict.");
    }

    /**
     * @return the distance function used for the kNN search
     */
    @Override
    public DistanceFunction<? super O> getDistanceFunction() {
        return this.distanceFunction;
    }

    /**
     * @return the accepted input type (number vector fields)
     */
    @Override
    public TypeInformation[] getInputTypeRestriction() {
        // NOTE(review): narrower than the distance function's own restriction used in
        // buildClassifier; confirm whether non-vector data should be allowed here.
        return TypeUtil.array(TypeUtil.NUMBER_VECTOR_FIELD);
    }

    @Override
    protected Logging getLogger() {
        return LOG;
    }

    /**
     * Parameterization class.
     *
     * @param <O> object type
     */
    public static class Parameterizer<O>
    extends AbstractParameterizer {
        /** Parameter for the number of neighbors. */
        public static final OptionID K_ID = new OptionID("knnclassifier.k", "The number of neighbors to take into account for classification.");

        /** Distance function (defaults to Euclidean). */
        protected DistanceFunction<? super O> distanceFunction;

        /** Number of neighbors (at least 1, default 1). */
        protected int k;

        @Override
        protected void makeOptions(Parameterization config) {
            super.makeOptions(config);
            ObjectParameter<DistanceFunction<? super O>> distP = new ObjectParameter<DistanceFunction<? super O>>(DistanceBasedAlgorithm.DISTANCE_FUNCTION_ID, DistanceFunction.class, EuclideanDistanceFunction.class);
            if (config.grab(distP)) {
                this.distanceFunction = distP.instantiateClass(config);
            }
            IntParameter kP = new IntParameter(K_ID, 1);
            kP.addConstraint(CommonConstraints.GREATER_EQUAL_ONE_INT);
            if (config.grab(kP)) {
                this.k = kP.intValue();
            }
        }

        @Override
        protected KNNClassifier<O> makeInstance() {
            return new KNNClassifier<O>(this.distanceFunction, this.k);
        }
    }
}

