/*
 * Decompiled with CFR 0.152.
 */
package de.lmu.ifi.dbs.elki.algorithm.statistics;

import de.lmu.ifi.dbs.elki.algorithm.AbstractDistanceBasedAlgorithm;
import de.lmu.ifi.dbs.elki.data.LabelList;
import de.lmu.ifi.dbs.elki.data.type.AlternativeTypeInformation;
import de.lmu.ifi.dbs.elki.data.type.TypeInformation;
import de.lmu.ifi.dbs.elki.data.type.TypeUtil;
import de.lmu.ifi.dbs.elki.database.Database;
import de.lmu.ifi.dbs.elki.database.ids.DBIDIter;
import de.lmu.ifi.dbs.elki.database.ids.DBIDUtil;
import de.lmu.ifi.dbs.elki.database.ids.DBIDs;
import de.lmu.ifi.dbs.elki.database.ids.DoubleDBIDListMIter;
import de.lmu.ifi.dbs.elki.database.ids.HashSetModifiableDBIDs;
import de.lmu.ifi.dbs.elki.database.ids.ModifiableDBIDs;
import de.lmu.ifi.dbs.elki.database.ids.ModifiableDoubleDBIDList;
import de.lmu.ifi.dbs.elki.database.query.distance.DistanceQuery;
import de.lmu.ifi.dbs.elki.database.relation.Relation;
import de.lmu.ifi.dbs.elki.distance.distancefunction.DistanceFunction;
import de.lmu.ifi.dbs.elki.evaluation.scores.AveragePrecisionEvaluation;
import de.lmu.ifi.dbs.elki.evaluation.scores.ROCEvaluation;
import de.lmu.ifi.dbs.elki.logging.Logging;
import de.lmu.ifi.dbs.elki.logging.progress.FiniteProgress;
import de.lmu.ifi.dbs.elki.logging.statistics.DoubleStatistic;
import de.lmu.ifi.dbs.elki.result.Result;
import de.lmu.ifi.dbs.elki.result.textwriter.TextWriteable;
import de.lmu.ifi.dbs.elki.result.textwriter.TextWriterStream;
import de.lmu.ifi.dbs.elki.utilities.exceptions.AbortException;
import de.lmu.ifi.dbs.elki.utilities.optionhandling.OptionID;
import de.lmu.ifi.dbs.elki.utilities.optionhandling.constraints.CommonConstraints;
import de.lmu.ifi.dbs.elki.utilities.optionhandling.constraints.ParameterConstraint;
import de.lmu.ifi.dbs.elki.utilities.optionhandling.parameterization.Parameterization;
import de.lmu.ifi.dbs.elki.utilities.optionhandling.parameters.DoubleParameter;
import de.lmu.ifi.dbs.elki.utilities.optionhandling.parameters.Flag;
import de.lmu.ifi.dbs.elki.utilities.optionhandling.parameters.IntParameter;
import de.lmu.ifi.dbs.elki.utilities.optionhandling.parameters.RandomParameter;
import de.lmu.ifi.dbs.elki.utilities.random.RandomFactory;
import it.unimi.dsi.fastutil.objects.Object2IntMap;
import it.unimi.dsi.fastutil.objects.Object2IntOpenHashMap;
import it.unimi.dsi.fastutil.objects.ObjectIterator;

/**
 * Evaluate a distance function with respect to retrieval quality.
 * <p>
 * For every object of a (possibly sampled) set of query objects, all database
 * objects are ranked by their distance to the query. The ranking is scored
 * against the set of objects whose label matches the query's label (class
 * label or label list). Reported values are averaged over all evaluated
 * queries:
 * <ul>
 * <li>mean average precision (MAP),</li>
 * <li>ROC AUC,</li>
 * <li>kNN majority-vote accuracy for k = 1 .. maxk.</li>
 * </ul>
 *
 * @param <O> Object type
 */
public class EvaluateRetrievalPerformance<O>
extends AbstractDistanceBasedAlgorithm<O, RetrievalPerformanceResult> {
    /** Class logger. */
    private static final Logging LOG = Logging.getLogger(EvaluateRetrievalPerformance.class);

    /** Relative share of objects used as queries, in (0, 1]. */
    protected double sampling = 1.0;

    /** Random generator factory used for sampling the query objects. */
    protected RandomFactory random = null;

    /** Whether the query object itself takes part in its own ranking. */
    protected boolean includeSelf;

    /** Prefix for the names of the logged statistics. */
    private final String PREFIX = this.getClass().getName();

    /** Largest k for the kNN evaluation; 0 disables the kNN evaluation. */
    protected int maxk = 100;

    /**
     * Constructor.
     *
     * @param distanceFunction Distance function to evaluate
     * @param sampling Relative share of objects to use as queries
     * @param random Random generator factory for the sampling
     * @param includeSelf Include the query object in its own ranking
     * @param maxk Largest k for the kNN evaluation (0 to disable)
     */
    public EvaluateRetrievalPerformance(DistanceFunction<? super O> distanceFunction, double sampling, RandomFactory random, boolean includeSelf, int maxk) {
        super(distanceFunction);
        this.sampling = sampling;
        this.random = random;
        this.includeSelf = includeSelf;
        this.maxk = maxk;
    }

    /**
     * Run the evaluation.
     * <p>
     * Queries without any matching object are skipped. If {@code includeSelf}
     * is disabled, the query itself does not count as a match either.
     *
     * @param database Database, used to obtain the distance query
     * @param relation Relation holding the objects to compare
     * @param lrelation Relation holding the labels
     * @return Averaged retrieval performance
     * @throws AbortException if no query had a matching object, or if the
     *         accumulated MAP / ROC AUC is NaN or negative
     */
    public RetrievalPerformanceResult run(Database database, Relation<O> relation, Relation<?> lrelation) {
        DistanceQuery<O> distQuery = database.getDistanceQuery(relation, this.getDistanceFunction(), new Object[0]);
        DBIDs ids = DBIDUtil.randomSample(relation.getDBIDs(), this.sampling, this.random);
        // Buffers reused for every query to avoid reallocation.
        HashSetModifiableDBIDs posn = DBIDUtil.newHashSet();
        ModifiableDoubleDBIDList nlist = DBIDUtil.newDistanceDBIDList(relation.size());
        Object2IntOpenHashMap<Object> counters = new Object2IntOpenHashMap<Object>();
        double map = 0.0;
        double mroc = 0.0;
        double[] knnperf = new double[this.maxk];
        int samples = 0;
        FiniteProgress objloop = LOG.isVerbose() ? new FiniteProgress("Processing query objects", ids.size(), LOG) : null;
        for (DBIDIter iter = ids.iter(); iter.valid(); iter.advance()) {
            Object label = lrelation.get(iter);
            this.findMatches(posn, lrelation, label);
            if (!this.includeSelf) {
                // The query is excluded from its own ranking, so it must not be a
                // positive either. Otherwise a query whose only match is itself
                // would be scored against positives that can never appear.
                posn.remove(iter);
            }
            if (posn.size() > 0) {
                this.computeDistances(nlist, iter, distQuery, relation);
                if (nlist.size() != relation.size() - (this.includeSelf ? 0 : 1)) {
                    LOG.warning("Neighbor list does not have the desired size: " + nlist.size());
                }
                map += AveragePrecisionEvaluation.STATIC.evaluate(posn, nlist);
                mroc += ROCEvaluation.STATIC.evaluate(posn, nlist);
                KNNEvaluator.STATIC.evaluateKNN(knnperf, nlist, lrelation, counters, label);
                ++samples;
            }
            LOG.incrementProcessed(objloop);
        }
        LOG.ensureCompleted(objloop);
        if (samples < 1) {
            throw new AbortException("No object matched - are labels parsed correctly?");
        }
        // Negated comparisons so that NaN (for which every comparison is false) aborts too.
        if (!(map >= 0.0) || !(mroc >= 0.0)) {
            throw new AbortException("NaN in MAP/ROC.");
        }
        map /= samples;
        mroc /= samples;
        LOG.statistics(new DoubleStatistic(this.PREFIX + ".map", map));
        LOG.statistics(new DoubleStatistic(this.PREFIX + ".rocauc", mroc));
        LOG.statistics(new DoubleStatistic(this.PREFIX + ".samples", samples));
        for (int k = 0; k < this.maxk; ++k) {
            knnperf[k] /= samples;
            LOG.statistics(new DoubleStatistic(this.PREFIX + ".knn-" + (k + 1), knnperf[k]));
        }
        return new RetrievalPerformanceResult(samples, map, mroc, knnperf);
    }

    /**
     * Test whether a label matches a reference label.
     * <p>
     * Two label lists match if they share at least one non-null label. In all
     * other cases the labels must be identical or {@code equals}.
     *
     * @param ref Reference label (never matches if {@code null})
     * @param test Label to test
     * @return {@code true} on a match
     */
    protected static boolean match(Object ref, Object test) {
        if (ref == null) {
            return false;
        }
        // Cheap identity check first; often holds for shared class label instances.
        if (ref == test) {
            return true;
        }
        if (ref instanceof LabelList && test instanceof LabelList) {
            LabelList lref = (LabelList) ref;
            LabelList ltest = (LabelList) test;
            int s1 = lref.size();
            int s2 = ltest.size();
            if (s1 == 0 || s2 == 0) {
                return false;
            }
            for (int i = 0; i < s1; ++i) {
                String l1 = lref.get(i);
                if (l1 == null) {
                    continue;
                }
                for (int j = 0; j < s2; ++j) {
                    if (l1.equals(ltest.get(j))) {
                        return true;
                    }
                }
            }
        }
        return ref.equals(test);
    }

    /**
     * Collect all objects whose label matches the given label.
     *
     * @param posn Output set; cleared before filling
     * @param lrelation Label relation
     * @param label Query label
     */
    private void findMatches(ModifiableDBIDs posn, Relation<?> lrelation, Object label) {
        posn.clear();
        for (DBIDIter ri = lrelation.iterDBIDs(); ri.valid(); ri.advance()) {
            if (EvaluateRetrievalPerformance.match(label, lrelation.get(ri))) {
                posn.add(ri);
            }
        }
    }

    /**
     * Compute the distances from the query to all objects, sorted ascending.
     * The query itself is skipped unless {@code includeSelf} is set. NaN
     * distances are replaced by positive infinity so that they rank last.
     *
     * @param nlist Output list; cleared before filling
     * @param query Query object id
     * @param distQuery Distance query
     * @param relation Data relation
     */
    private void computeDistances(ModifiableDoubleDBIDList nlist, DBIDIter query, DistanceQuery<O> distQuery, Relation<O> relation) {
        nlist.clear();
        O qo = relation.get(query);
        for (DBIDIter ri = relation.iterDBIDs(); ri.valid(); ri.advance()) {
            if (!this.includeSelf && DBIDUtil.equal(ri, query)) {
                continue;
            }
            // Object-to-id distance: the query object is NOT a DBID reference.
            double dist = distQuery.distance(qo, ri);
            if (Double.isNaN(dist)) {
                dist = Double.POSITIVE_INFINITY;
            }
            nlist.add(dist, ri);
        }
        nlist.sort();
    }

    @Override
    public TypeInformation[] getInputTypeRestriction() {
        // Labels may be given as a class label or as a label list.
        AlternativeTypeInformation cls = new AlternativeTypeInformation(TypeUtil.CLASSLABEL, TypeUtil.LABELLIST);
        return TypeUtil.array(this.getDistanceFunction().getInputTypeRestriction(), cls);
    }

    @Override
    protected Logging getLogger() {
        return LOG;
    }

    /**
     * Parameterization class.
     *
     * @param <O> Object type
     */
    public static class Parameterizer<O>
    extends AbstractDistanceBasedAlgorithm.Parameterizer<O> {
        /** Relative amount of objects to use as queries. */
        public static final OptionID SAMPLING_ID = new OptionID("map.sampling", "Relative amount of object to sample.");
        /** Random seed for the sampling. */
        public static final OptionID SEED_ID = new OptionID("map.sampling-seed", "Random seed for deterministic sampling.");
        /** Flag to include the query object in its own ranking. */
        public static final OptionID INCLUDESELF_ID = new OptionID("map.includeself", "Include the query object in the evaluation.");
        /** Largest k for the kNN evaluation. */
        public static final OptionID MAXK_ID = new OptionID("map.maxk", "Maximum value of k for kNN evaluation.");

        /** Relative share of objects to use as queries. */
        protected double sampling = 1.0;
        /** Random generator factory. */
        protected RandomFactory seed = null;
        /** Include the query object in its own ranking. */
        protected boolean includeSelf;
        /** Largest k for the kNN evaluation; stays 0 (disabled) unless given. */
        protected int maxk = 0;

        @Override
        protected void makeOptions(Parameterization config) {
            super.makeOptions(config);
            DoubleParameter samplingP = new DoubleParameter(SAMPLING_ID);
            samplingP.addConstraint((ParameterConstraint) CommonConstraints.GREATER_THAN_ZERO_DOUBLE);
            samplingP.addConstraint((ParameterConstraint) CommonConstraints.LESS_EQUAL_ONE_DOUBLE);
            samplingP.setOptional(true);
            if (config.grab(samplingP)) {
                this.sampling = (Double) samplingP.getValue();
            }
            RandomParameter rndP = new RandomParameter(SEED_ID);
            if (config.grab(rndP)) {
                this.seed = (RandomFactory) rndP.getValue();
            }
            Flag includeP = new Flag(INCLUDESELF_ID);
            if (config.grab(includeP)) {
                this.includeSelf = includeP.isTrue();
            }
            IntParameter maxkP = new IntParameter(MAXK_ID);
            maxkP.addConstraint((ParameterConstraint) CommonConstraints.GREATER_EQUAL_ONE_INT);
            maxkP.setOptional(true);
            if (config.grab(maxkP)) {
                this.maxk = maxkP.intValue();
            }
        }

        @Override
        protected EvaluateRetrievalPerformance<O> makeInstance() {
            return new EvaluateRetrievalPerformance<O>(this.distanceFunction, this.sampling, this.seed, this.includeSelf, this.maxk);
        }
    }

    /**
     * Result of the retrieval evaluation: values averaged over all queries.
     */
    public static class RetrievalPerformanceResult
    implements Result,
    TextWriteable {
        /** Number of evaluated queries. */
        private final int samplesize;
        /** Mean average precision. */
        private final double map;
        /** Mean ROC AUC. */
        private final double rocauc;
        /** Mean kNN accuracy; index i holds the value for k = i + 1. */
        private final double[] knnperf;

        /**
         * Constructor.
         *
         * @param samplesize Number of evaluated queries
         * @param map Mean average precision
         * @param rocauc Mean ROC AUC
         * @param knnperf Mean kNN accuracy, index i for k = i + 1
         */
        public RetrievalPerformanceResult(int samplesize, double map, double rocauc, double[] knnperf) {
            this.map = map;
            this.rocauc = rocauc;
            this.samplesize = samplesize;
            this.knnperf = knnperf;
        }

        /**
         * @return Mean ROC AUC
         */
        public double getROCAUC() {
            return this.rocauc;
        }

        /**
         * @return Mean average precision
         */
        public double getMAP() {
            return this.map;
        }

        @Override
        public String getLongName() {
            return "Distance function retrieval evaluation.";
        }

        @Override
        public String getShortName() {
            return "distance-retrieval-evaluation";
        }

        @Override
        public void writeToText(TextWriterStream out, String label) {
            out.inlinePrintNoQuotes("MAP");
            out.inlinePrint(this.map);
            out.flush();
            out.inlinePrintNoQuotes("ROCAUC");
            out.inlinePrint(this.rocauc);
            out.flush();
            out.inlinePrintNoQuotes("Samplesize");
            out.inlinePrint(this.samplesize);
            out.flush();
            for (int i = 0; i < this.knnperf.length; ++i) {
                out.inlinePrintNoQuotes("knn-" + (i + 1));
                out.inlinePrint(this.knnperf[i]);
                out.flush();
            }
        }
    }

    /**
     * Evaluate kNN majority-vote classification on a sorted neighbor list.
     */
    public static class KNNEvaluator {
        /** Shared, stateless instance. */
        public static final KNNEvaluator STATIC = new KNNEvaluator();

        /**
         * Add the kNN vote accuracy for one query to {@code knnperf}.
         * <p>
         * For each k, the labels with the highest vote count among the first k
         * neighbors win. A tie between w winners, p of which match the query
         * label, scores p / w. Neighbors at identical distance are processed as
         * one group, so every k inside such a group receives the same score.
         * If no neighbor has contributed a vote yet, the score is 0.
         *
         * @param knnperf Accumulator; index i receives the score for k = i + 1
         * @param nlist Neighbors sorted by ascending distance
         * @param lrelation Label relation
         * @param counters Scratch vote counter; cleared on entry
         * @param label Query label
         */
        public void evaluateKNN(double[] knnperf, ModifiableDoubleDBIDList nlist, Relation<?> lrelation, Object2IntOpenHashMap<Object> counters, Object label) {
            final int maxk = knnperf.length;
            int k = 0; // number of neighbors consumed so far
            int prevk = 0; // number of knnperf slots already filled
            int max = 0; // highest vote count of any label so far
            counters.clear();
            DoubleDBIDListMIter iter = nlist.iter();
            while (iter.valid() && prevk < maxk) {
                double prev = iter.doubleValue();
                max = Math.max(max, this.countkNN(counters, lrelation.get(iter)));
                iter.advance();
                ++k;
                if (iter.valid() && !(iter.doubleValue() > prev)) {
                    continue; // still inside a group of tied distances
                }
                int pties = 0;
                int ties = 0;
                ObjectIterator<Object2IntMap.Entry<Object>> cit = counters.object2IntEntrySet().fastIterator();
                while (cit.hasNext()) {
                    Object2IntMap.Entry<Object> entry = cit.next();
                    if (entry.getIntValue() < max) {
                        continue;
                    }
                    ++ties;
                    if (isVoteFor(entry.getKey(), label)) {
                        ++pties;
                    }
                }
                // Without any vote there is no prediction: count it as a miss, not 0/0.
                double score = ties > 0 ? pties / (double) ties : 0.0;
                // Slots prevk .. k-1 correspond to k = prevk+1 .. k neighbors.
                while (prevk < k && prevk < maxk) {
                    knnperf[prevk++] += score;
                }
            }
        }

        /**
         * Test whether a winning vote key supports the query label.
         *
         * @param key Vote key (a label)
         * @param label Query label; may be a label list
         * @return {@code true} if the key equals the label or one of its entries
         */
        private static boolean isVoteFor(Object key, Object label) {
            if (key == null) {
                return false;
            }
            if (key.equals(label)) {
                return true;
            }
            if (label instanceof LabelList) {
                LabelList ll = (LabelList) label;
                for (int i = 0, e = ll.size(); i < e; ++i) {
                    if (key.equals(ll.get(i))) {
                        return true;
                    }
                }
            }
            return false;
        }

        /**
         * Add one vote for a label, or for each entry of a label list.
         *
         * @param counters Vote counters
         * @param l Label or label list of the neighbor
         * @return The largest vote count after this update among the updated
         *         labels (0 for an empty label list)
         */
        public int countkNN(Object2IntOpenHashMap<Object> counters, Object l) {
            // addTo returns the OLD value, so add the increment to get the new count.
            if (l instanceof LabelList) {
                LabelList ll = (LabelList) l;
                int m = 0;
                for (int i = 0, e = ll.size(); i < e; ++i) {
                    m = Math.max(m, counters.addTo(ll.get(i), 1) + 1);
                }
                return m;
            }
            return counters.addTo(l, 1) + 1;
        }
    }
}

