/*
 * Decompiled with CFR 0.152.
 */
package de.lmu.ifi.dbs.elki.algorithm.benchmark;

import de.lmu.ifi.dbs.elki.algorithm.AbstractDistanceBasedAlgorithm;
import de.lmu.ifi.dbs.elki.data.type.TypeInformation;
import de.lmu.ifi.dbs.elki.data.type.TypeUtil;
import de.lmu.ifi.dbs.elki.database.Database;
import de.lmu.ifi.dbs.elki.database.ids.DBIDIter;
import de.lmu.ifi.dbs.elki.database.ids.DBIDRange;
import de.lmu.ifi.dbs.elki.database.ids.DBIDUtil;
import de.lmu.ifi.dbs.elki.database.ids.DBIDs;
import de.lmu.ifi.dbs.elki.database.ids.DoubleDBIDListIter;
import de.lmu.ifi.dbs.elki.database.ids.KNNList;
import de.lmu.ifi.dbs.elki.database.query.distance.DistanceQuery;
import de.lmu.ifi.dbs.elki.database.query.knn.KNNQuery;
import de.lmu.ifi.dbs.elki.database.relation.Relation;
import de.lmu.ifi.dbs.elki.datasource.DatabaseConnection;
import de.lmu.ifi.dbs.elki.datasource.bundle.MultipleObjectsBundle;
import de.lmu.ifi.dbs.elki.distance.distancefunction.DistanceFunction;
import de.lmu.ifi.dbs.elki.logging.Logging;
import de.lmu.ifi.dbs.elki.logging.progress.FiniteProgress;
import de.lmu.ifi.dbs.elki.math.MeanVariance;
import de.lmu.ifi.dbs.elki.result.Result;
import de.lmu.ifi.dbs.elki.utilities.Util;
import de.lmu.ifi.dbs.elki.utilities.exceptions.IncompatibleDataException;
import de.lmu.ifi.dbs.elki.utilities.optionhandling.OptionID;
import de.lmu.ifi.dbs.elki.utilities.optionhandling.constraints.CommonConstraints;
import de.lmu.ifi.dbs.elki.utilities.optionhandling.constraints.ParameterConstraint;
import de.lmu.ifi.dbs.elki.utilities.optionhandling.parameterization.Parameterization;
import de.lmu.ifi.dbs.elki.utilities.optionhandling.parameters.DoubleParameter;
import de.lmu.ifi.dbs.elki.utilities.optionhandling.parameters.IntParameter;
import de.lmu.ifi.dbs.elki.utilities.optionhandling.parameters.ObjectParameter;
import de.lmu.ifi.dbs.elki.utilities.optionhandling.parameters.RandomParameter;
import de.lmu.ifi.dbs.elki.utilities.random.RandomFactory;

/**
 * Benchmark algorithm that runs a kNN query for every query object and
 * reports aggregate statistics about the results.
 * <p>
 * The query objects are either (a sample of) the objects of the input relation,
 * or (a sample of) the objects loaded from a separate {@link DatabaseConnection}.
 * The algorithm does not produce a result object; its only output is what it
 * writes to the logger at the statistics level (a checksum of the neighbor IDs,
 * the mean result size, and the mean k-distance).
 *
 * @param <O> Object type
 */
public class KNNBenchmarkAlgorithm<O>
extends AbstractDistanceBasedAlgorithm<O, Result> {
    /** Class logger. */
    private static final Logging LOG = Logging.getLogger(KNNBenchmarkAlgorithm.class);

    /** Number of neighbors to retrieve per query. */
    protected int k = 10;

    /** Separate data source for the queries; {@code null} means: query with the relation's own objects. */
    protected DatabaseConnection queries = null;

    /**
     * Sampling size. Per the option description, values up to 1 are a relative
     * share and larger values an absolute count; the default of -1 is meant to
     * use all data.
     */
    protected double sampling = -1.0;

    /** Random generator used for sampling. */
    protected RandomFactory random;

    /**
     * Constructor.
     *
     * @param distanceFunction Distance function to use
     * @param k Number of neighbors to retrieve per query
     * @param queries Separate query data source, or {@code null} to query with
     *        the objects of the input relation
     * @param sampling Sampling size (relative share if at most 1, absolute count
     *        otherwise)
     * @param random Random generator used for sampling
     */
    public KNNBenchmarkAlgorithm(DistanceFunction<? super O> distanceFunction, int k, DatabaseConnection queries, double sampling, RandomFactory random) {
        super(distanceFunction);
        this.k = k;
        this.queries = queries;
        this.sampling = sampling;
        this.random = random;
    }

    /**
     * Run the kNN benchmark on the given relation.
     *
     * @param database Database providing the distance and kNN queries
     * @param relation Relation whose objects are searched
     * @return always {@code null}; results are only reported through the logger
     * @throws IncompatibleDataException if a separate query source is configured
     *         but none of its columns matches the distance function's input type
     */
    public Result run(Database database, Relation<O> relation) {
        DistanceQuery<O> distQuery = database.getDistanceQuery(relation, this.getDistanceFunction());
        // k is passed as a query hint to the database.
        KNNQuery<O> knnQuery = database.getKNNQuery(distQuery, this.k);
        QueryStatistics stats = this.queries == null
            ? this.queryWithRelationObjects(relation, knnQuery)
            : this.queryWithExternalObjects(knnQuery);
        stats.report();
        return null;
    }

    /**
     * Run one kNN query for each object of (a sample of) the relation.
     *
     * @param relation Relation providing the query DBIDs
     * @param knnQuery kNN query to benchmark
     * @return accumulated statistics of all queries
     */
    private QueryStatistics queryWithRelationObjects(Relation<O> relation, KNNQuery<O> knnQuery) {
        DBIDs sample = DBIDUtil.randomSample(relation.getDBIDs(), this.sampling, this.random);
        FiniteProgress prog = LOG.isVeryVerbose() ? new FiniteProgress("kNN queries", sample.size(), LOG) : null;
        QueryStatistics stats = new QueryStatistics();
        for (DBIDIter iditer = sample.iter(); iditer.valid(); iditer.advance()) {
            stats.add(knnQuery.getKNNForDBID(iditer, this.k));
            LOG.incrementProcessed(prog);
        }
        LOG.ensureCompleted(prog);
        return stats;
    }

    /**
     * Run one kNN query for each object of (a sample of) the separate query
     * data source.
     *
     * @param knnQuery kNN query to benchmark
     * @return accumulated statistics of all queries
     * @throws IncompatibleDataException if no column of the query data is
     *         compatible with the distance function's input type
     */
    private QueryStatistics queryWithExternalObjects(KNNQuery<O> knnQuery) {
        TypeInformation res = this.getDistanceFunction().getInputTypeRestriction();
        MultipleObjectsBundle bundle = this.queries.loadData();
        // Use the first column whose type the distance function accepts.
        int col = -1;
        for (int i = 0; i < bundle.metaLength(); ++i) {
            if (res.isAssignableFromType(bundle.meta(i))) {
                col = i;
                break;
            }
        }
        if (col < 0) {
            throw new IncompatibleDataException("No compatible data type in query input was found. Expected: " + res.toString());
        }
        // The bundle rows have no DBIDs, so a static DBID range stands in for the
        // row indexes, allowing the DBID-based random sampling to be reused.
        DBIDRange sids = DBIDUtil.generateStaticDBIDRange(bundle.dataLength());
        DBIDs sample = DBIDUtil.randomSample((DBIDs) sids, this.sampling, this.random);
        FiniteProgress prog = LOG.isVeryVerbose() ? new FiniteProgress("kNN queries", sample.size(), LOG) : null;
        QueryStatistics stats = new QueryStatistics();
        for (DBIDIter iditer = sample.iter(); iditer.valid(); iditer.advance()) {
            // Map the sampled DBID back to its row offset in the bundle.
            int off = sids.binarySearch(iditer);
            assert (off >= 0);
            // The column was checked above to be assignable to the distance
            // function's input type; the cast to O cannot be checked at runtime.
            @SuppressWarnings("unchecked")
            O o = (O) bundle.data(off, col);
            stats.add(knnQuery.getKNNForObject(o, this.k));
            LOG.incrementProcessed(prog);
        }
        LOG.ensureCompleted(prog);
        return stats;
    }

    @Override
    public TypeInformation[] getInputTypeRestriction() {
        // The input relation must satisfy the distance function's type restriction.
        return TypeUtil.array(this.getDistanceFunction().getInputTypeRestriction());
    }

    @Override
    protected Logging getLogger() {
        return LOG;
    }

    /**
     * Accumulator for the statistics that are logged after the benchmark.
     */
    private static final class QueryStatistics {
        /** Running hash, combined via Util.mixHashCodes, of the per-query DBID sums. */
        private int hash = 0;

        /** Number of results per query. */
        private final MeanVariance mv = new MeanVariance();

        /** k-distance per query. */
        private final MeanVariance mvdist = new MeanVariance();

        /**
         * Add the result of a single kNN query.
         *
         * @param knns kNN result of one query
         */
        void add(KNNList knns) {
            // Summing makes the per-query checksum independent of the order in
            // which the neighbors are listed; int overflow simply wraps around.
            int ichecksum = 0;
            for (DoubleDBIDListIter it = knns.iter(); it.valid(); it.advance()) {
                ichecksum += DBIDUtil.asInteger(it);
            }
            this.hash = Util.mixHashCodes(this.hash, ichecksum);
            this.mv.put(knns.size());
            this.mvdist.put(knns.getKNNDistance());
        }

        /**
         * Log the accumulated statistics, if statistics logging is enabled.
         */
        void report() {
            if (!LOG.isStatistics()) {
                return;
            }
            LOG.statistics("Result hashcode: " + this.hash);
            LOG.statistics("Mean number of results: " + this.mv.getMean() + " +- " + this.mv.getNaiveStddev());
            if (this.mvdist.getCount() > 0.0) {
                LOG.statistics("Mean k-distance: " + this.mvdist.getMean() + " +- " + this.mvdist.getNaiveStddev());
            }
        }
    }

    /**
     * Parameterization class.
     *
     * @param <O> Object type
     */
    public static class Parameterizer<O>
    extends AbstractDistanceBasedAlgorithm.Parameterizer<O> {
        /** Parameter for the number of neighbors. */
        public static final OptionID K_ID = new OptionID("knnbench.k", "Number of neighbors to retrieve for kNN benchmarking.");

        /** Parameter for the separate query data source. */
        public static final OptionID QUERY_ID = new OptionID("knnbench.query", "Data source for the queries. If not set, the queries are taken from the database.");

        /** Parameter for the sampling size. */
        public static final OptionID SAMPLING_ID = new OptionID("knnbench.sampling", "Sampling size parameter. If the value is less or equal 1, it is assumed to be the relative share. Larger values will be interpreted as integer sizes. By default, all data will be used.");

        /** Parameter for the random generator used for sampling. */
        public static final OptionID RANDOM_ID = new OptionID("knnbench.random", "Random generator for sampling.");

        /** Number of neighbors to retrieve per query. */
        protected int k = 10;

        /** Separate query data source, or {@code null}. */
        protected DatabaseConnection queries = null;

        /** Sampling size; -1 (the default) is meant to use all data. */
        protected double sampling = -1.0;

        /** Random generator used for sampling. */
        protected RandomFactory random;

        @Override
        protected void makeOptions(Parameterization config) {
            super.makeOptions(config);
            IntParameter kP = new IntParameter(K_ID).addConstraint(CommonConstraints.GREATER_EQUAL_ONE_INT);
            if (config.grab(kP)) {
                this.k = kP.intValue();
            }
            ObjectParameter<DatabaseConnection> queryP = new ObjectParameter<DatabaseConnection>(QUERY_ID, DatabaseConnection.class);
            queryP.setOptional(true);
            if (config.grab(queryP)) {
                this.queries = queryP.instantiateClass(config);
            }
            DoubleParameter samplingP = new DoubleParameter(SAMPLING_ID).addConstraint(CommonConstraints.GREATER_THAN_ZERO_DOUBLE).setOptional(true);
            if (config.grab(samplingP)) {
                this.sampling = samplingP.doubleValue();
            }
            RandomParameter randomP = new RandomParameter(RANDOM_ID, RandomFactory.DEFAULT);
            if (config.grab(randomP)) {
                this.random = randomP.getValue();
            }
        }

        @Override
        protected KNNBenchmarkAlgorithm<O> makeInstance() {
            return new KNNBenchmarkAlgorithm<O>(this.distanceFunction, this.k, this.queries, this.sampling, this.random);
        }
    }
}

