/*
 * Decompiled with CFR 0.152.
 */
package de.lmu.ifi.dbs.elki.algorithm.statistics;

import de.lmu.ifi.dbs.elki.algorithm.AbstractDistanceBasedAlgorithm;
import de.lmu.ifi.dbs.elki.data.DoubleVector;
import de.lmu.ifi.dbs.elki.data.LabelList;
import de.lmu.ifi.dbs.elki.data.type.AlternativeTypeInformation;
import de.lmu.ifi.dbs.elki.data.type.TypeInformation;
import de.lmu.ifi.dbs.elki.data.type.TypeUtil;
import de.lmu.ifi.dbs.elki.database.Database;
import de.lmu.ifi.dbs.elki.database.ids.DBIDIter;
import de.lmu.ifi.dbs.elki.database.ids.DBIDUtil;
import de.lmu.ifi.dbs.elki.database.ids.DBIDs;
import de.lmu.ifi.dbs.elki.database.ids.DoubleDBIDListIter;
import de.lmu.ifi.dbs.elki.database.ids.KNNList;
import de.lmu.ifi.dbs.elki.database.query.distance.DistanceQuery;
import de.lmu.ifi.dbs.elki.database.query.knn.KNNQuery;
import de.lmu.ifi.dbs.elki.database.relation.Relation;
import de.lmu.ifi.dbs.elki.distance.distancefunction.DistanceFunction;
import de.lmu.ifi.dbs.elki.logging.Logging;
import de.lmu.ifi.dbs.elki.logging.progress.FiniteProgress;
import de.lmu.ifi.dbs.elki.math.MeanVarianceMinMax;
import de.lmu.ifi.dbs.elki.result.CollectionResult;
import de.lmu.ifi.dbs.elki.utilities.optionhandling.OptionID;
import de.lmu.ifi.dbs.elki.utilities.optionhandling.constraints.CommonConstraints;
import de.lmu.ifi.dbs.elki.utilities.optionhandling.constraints.ParameterConstraint;
import de.lmu.ifi.dbs.elki.utilities.optionhandling.parameterization.Parameterization;
import de.lmu.ifi.dbs.elki.utilities.optionhandling.parameters.DoubleParameter;
import de.lmu.ifi.dbs.elki.utilities.optionhandling.parameters.Flag;
import de.lmu.ifi.dbs.elki.utilities.optionhandling.parameters.IntParameter;
import de.lmu.ifi.dbs.elki.utilities.optionhandling.parameters.RandomParameter;
import de.lmu.ifi.dbs.elki.utilities.random.RandomFactory;
import java.util.ArrayList;

/**
 * Evaluates a distance function by the precision of its nearest-neighbor
 * rankings with respect to labels.
 *
 * <p>For every (sampled) query object the nearest neighbors are retrieved, and
 * for each rank {@code 1..k} the fraction of neighbors seen so far whose label
 * {@link #match(Object, Object) matches} the query label is recorded. The
 * per-rank statistics over all query objects form the result.
 *
 * <p>NOTE(review): the class is declared with result type
 * {@code CollectionResult<DoubleVector>}, whereas {@link #run} produces rows of
 * type {@code double[]} - confirm which one is intended.
 *
 * @param <O> object type accepted by the distance function
 */
public class AveragePrecisionAtK<O>
extends AbstractDistanceBasedAlgorithm<O, CollectionResult<DoubleVector>> {
    /** Class logger. */
    private static final Logging LOG = Logging.getLogger(AveragePrecisionAtK.class);

    /** Number of ranks to evaluate. */
    private final int k;

    /** Sampling rate handed to {@code DBIDUtil.randomSample}. */
    private final double sampling;

    /** Random source used for sampling the query objects. */
    private final RandomFactory random;

    /** Whether the query object itself counts as one of its neighbors. */
    private final boolean includeSelf;

    /**
     * Constructor.
     *
     * @param distanceFunction distance function to evaluate
     * @param k number of ranks to evaluate
     * @param sampling sampling rate for choosing query objects
     * @param random random source for sampling
     * @param includeSelf whether the query object is included in the evaluation
     */
    public AveragePrecisionAtK(DistanceFunction<? super O> distanceFunction, int k, double sampling, RandomFactory random, boolean includeSelf) {
        super(distanceFunction);
        this.k = k;
        this.sampling = sampling;
        this.random = random;
        this.includeSelf = includeSelf;
    }

    /**
     * Runs the evaluation.
     *
     * @param database database used to obtain the distance and kNN queries
     * @param relation relation holding the objects to query
     * @param lrelation relation holding the labels to compare
     * @return one row per rank {@code i}, laid out as
     *         {@code {i + 1, mean, sample stddev (0 if count <= 1), min, max, count}}
     *         of the precision observed at that rank
     */
    public CollectionResult<double[]> run(Database database, Relation<O> relation, Relation<?> lrelation) {
        DistanceQuery<O> distQuery = database.getDistanceQuery(relation, this.getDistanceFunction(), new Object[0]);
        // When the query object is skipped below, one extra neighbor is requested to make up for it.
        final int qk = this.includeSelf ? this.k : this.k + 1;
        KNNQuery<O> knnQuery = database.getKNNQuery(distQuery, qk);
        MeanVarianceMinMax[] mvs = MeanVarianceMinMax.newArray(this.k);
        DBIDs ids = DBIDUtil.randomSample(relation.getDBIDs(), this.sampling, this.random);
        FiniteProgress objloop = LOG.isVerbose() ? new FiniteProgress("Computing nearest neighbors", ids.size(), LOG) : null;
        for (DBIDIter iter = ids.iter(); iter.valid(); iter.advance()) {
            KNNList knn = knnQuery.getKNNForDBID(iter, qk);
            Object label = lrelation.get(iter);
            int positive = 0;
            int rank = 0;
            for (DoubleDBIDListIter ri = knn.iter(); rank < this.k && ri.valid(); ri.advance()) {
                if (!this.includeSelf && DBIDUtil.equal(iter, ri)) {
                    // The query object itself does not consume a rank.
                    continue;
                }
                if (match(label, lrelation.get(ri))) {
                    ++positive;
                }
                // Precision at this rank: matches so far / neighbors evaluated so far.
                mvs[rank].put(positive / (double) (rank + 1));
                ++rank;
            }
            LOG.incrementProcessed(objloop);
        }
        LOG.ensureCompleted(objloop);

        // Flatten the per-rank statistics into plain rows.
        ArrayList<double[]> res = new ArrayList<double[]>(this.k);
        for (int i = 0; i < this.k; ++i) {
            MeanVarianceMinMax mv = mvs[i];
            // Sample standard deviation is only requested for more than one observation.
            double std = mv.getCount() > 1.0 ? mv.getSampleStddev() : 0.0;
            res.add(new double[]{i + 1, mv.getMean(), std, mv.getMin(), mv.getMax(), mv.getCount()});
        }
        return new CollectionResult<double[]>("Average Precision", "average-precision", res);
    }

    /**
     * Tests whether two label objects match.
     *
     * <p>A {@code null} reference never matches. Identical references always
     * match. Two {@link LabelList}s match when both are non-empty and share at
     * least one non-null label; two empty-sided lists never match. In every
     * remaining case the decision is delegated to {@code ref.equals(test)}.
     *
     * @param ref reference label (of the query object)
     * @param test label to compare against
     * @return {@code true} if the labels are considered matching
     */
    protected static boolean match(Object ref, Object test) {
        if (ref == null) {
            return false;
        }
        // Cheap identity shortcut before any content comparison.
        if (ref == test) {
            return true;
        }
        if (ref instanceof LabelList && test instanceof LabelList) {
            final LabelList refLabels = (LabelList) ref;
            final LabelList testLabels = (LabelList) test;
            final int refSize = refLabels.size();
            final int testSize = testLabels.size();
            if (refSize == 0 || testSize == 0) {
                return false;
            }
            for (int i = 0; i < refSize; ++i) {
                final String candidate = refLabels.get(i);
                if (candidate == null) {
                    continue;
                }
                for (int j = 0; j < testSize; ++j) {
                    if (candidate.equals(testLabels.get(j))) {
                        return true;
                    }
                }
            }
            // No shared label found: fall through to the equality check below.
        }
        return ref.equals(test);
    }

    /**
     * Declares the required inputs: whatever the distance function needs, plus
     * a relation that is either a class label or a label list.
     *
     * @return input type restrictions
     */
    @Override
    public TypeInformation[] getInputTypeRestriction() {
        TypeInformation labelType = new AlternativeTypeInformation(TypeUtil.CLASSLABEL, TypeUtil.LABELLIST);
        return TypeUtil.array(this.getDistanceFunction().getInputTypeRestriction(), labelType);
    }

    @Override
    protected Logging getLogger() {
        return LOG;
    }

    /**
     * Parameterization class.
     *
     * @param <O> object type accepted by the distance function
     */
    public static class Parameterizer<O>
    extends AbstractDistanceBasedAlgorithm.Parameterizer<O> {
        /** Option for the number of ranks to evaluate. */
        private static final OptionID K_ID = new OptionID("avep.k", "K to compute the average precision at.");

        /** Option for the sampling rate. */
        public static final OptionID SAMPLING_ID = new OptionID("avep.sampling", "Relative amount of object to sample.");

        /** Option for the sampling random seed. */
        public static final OptionID SEED_ID = new OptionID("avep.sampling-seed", "Random seed for deterministic sampling.");

        /** Option to include the query object in the evaluation. */
        public static final OptionID INCLUDESELF_ID = new OptionID("avep.includeself", "Include the query object in the evaluation.");

        /** Number of ranks to evaluate. */
        protected int k = 20;

        /** Sampling rate; stays at 1.0 when the optional parameter is not grabbed. */
        protected double sampling = 1.0;

        /** Random source for sampling. */
        protected RandomFactory seed = null;

        /** Whether the query object is included in the evaluation. */
        protected boolean includeSelf;

        /**
         * Reads the options of this algorithm (after those of the superclass):
         * k, the optional sampling rate, the sampling seed and the
         * include-self flag. Each field is only overwritten when its
         * parameter was successfully grabbed.
         *
         * @param config parameterization to read from
         */
        @Override
        protected void makeOptions(Parameterization config) {
            super.makeOptions(config);

            IntParameter kP = (IntParameter) new IntParameter(K_ID).addConstraint((ParameterConstraint) CommonConstraints.GREATER_THAN_ONE_INT);
            if (config.grab(kP)) {
                this.k = (Integer) kP.getValue();
            }

            // Each call operates on the value returned by the previous one, as in a fluent chain.
            DoubleParameter samplingP = (DoubleParameter) new DoubleParameter(SAMPLING_ID).addConstraint((ParameterConstraint) CommonConstraints.GREATER_THAN_ZERO_DOUBLE);
            samplingP = (DoubleParameter) samplingP.addConstraint((ParameterConstraint) CommonConstraints.LESS_EQUAL_ONE_DOUBLE);
            samplingP = (DoubleParameter) samplingP.setOptional(true);
            if (config.grab(samplingP)) {
                this.sampling = (Double) samplingP.getValue();
            }

            RandomParameter rndP = new RandomParameter(SEED_ID);
            if (config.grab(rndP)) {
                this.seed = (RandomFactory) rndP.getValue();
            }

            Flag includeP = new Flag(INCLUDESELF_ID);
            if (config.grab(includeP)) {
                this.includeSelf = includeP.isTrue();
            }
        }

        /**
         * Builds the algorithm from the values collected in
         * {@link #makeOptions(Parameterization)}.
         *
         * @return configured algorithm instance
         */
        @Override
        protected AveragePrecisionAtK<O> makeInstance() {
            // Explicit type argument instead of the raw type, so the construction is type-checked.
            return new AveragePrecisionAtK<O>(this.distanceFunction, this.k, this.sampling, this.seed, this.includeSelf);
        }
    }
}

