/*
 * Decompiled with CFR 0.152.
 */
package de.lmu.ifi.dbs.elki.algorithm.clustering.em;

import de.lmu.ifi.dbs.elki.algorithm.AbstractAlgorithm;
import de.lmu.ifi.dbs.elki.algorithm.clustering.ClusteringAlgorithm;
import de.lmu.ifi.dbs.elki.algorithm.clustering.em.EMClusterModel;
import de.lmu.ifi.dbs.elki.algorithm.clustering.em.EMClusterModelFactory;
import de.lmu.ifi.dbs.elki.algorithm.clustering.em.MultivariateGaussianModelFactory;
import de.lmu.ifi.dbs.elki.algorithm.clustering.kmeans.KMeans;
import de.lmu.ifi.dbs.elki.data.Cluster;
import de.lmu.ifi.dbs.elki.data.Clustering;
import de.lmu.ifi.dbs.elki.data.NumberVector;
import de.lmu.ifi.dbs.elki.data.model.MeanModel;
import de.lmu.ifi.dbs.elki.data.type.SimpleTypeInformation;
import de.lmu.ifi.dbs.elki.data.type.TypeInformation;
import de.lmu.ifi.dbs.elki.data.type.TypeUtil;
import de.lmu.ifi.dbs.elki.database.Database;
import de.lmu.ifi.dbs.elki.database.datastore.DataStoreUtil;
import de.lmu.ifi.dbs.elki.database.datastore.WritableDataStore;
import de.lmu.ifi.dbs.elki.database.ids.ArrayModifiableDBIDs;
import de.lmu.ifi.dbs.elki.database.ids.DBIDIter;
import de.lmu.ifi.dbs.elki.database.ids.DBIDUtil;
import de.lmu.ifi.dbs.elki.database.ids.DBIDs;
import de.lmu.ifi.dbs.elki.database.ids.ModifiableDBIDs;
import de.lmu.ifi.dbs.elki.database.relation.MaterializedRelation;
import de.lmu.ifi.dbs.elki.database.relation.Relation;
import de.lmu.ifi.dbs.elki.distance.distancefunction.minkowski.SquaredEuclideanDistanceFunction;
import de.lmu.ifi.dbs.elki.logging.Logging;
import de.lmu.ifi.dbs.elki.logging.statistics.DoubleStatistic;
import de.lmu.ifi.dbs.elki.logging.statistics.LongStatistic;
import de.lmu.ifi.dbs.elki.math.linearalgebra.VMath;
import de.lmu.ifi.dbs.elki.utilities.Alias;
import de.lmu.ifi.dbs.elki.utilities.Priority;
import de.lmu.ifi.dbs.elki.utilities.documentation.Description;
import de.lmu.ifi.dbs.elki.utilities.documentation.Reference;
import de.lmu.ifi.dbs.elki.utilities.documentation.References;
import de.lmu.ifi.dbs.elki.utilities.documentation.Title;
import de.lmu.ifi.dbs.elki.utilities.optionhandling.AbstractParameterizer;
import de.lmu.ifi.dbs.elki.utilities.optionhandling.OptionID;
import de.lmu.ifi.dbs.elki.utilities.optionhandling.constraints.CommonConstraints;
import de.lmu.ifi.dbs.elki.utilities.optionhandling.constraints.ParameterConstraint;
import de.lmu.ifi.dbs.elki.utilities.optionhandling.parameterization.Parameterization;
import de.lmu.ifi.dbs.elki.utilities.optionhandling.parameters.DoubleParameter;
import de.lmu.ifi.dbs.elki.utilities.optionhandling.parameters.IntParameter;
import de.lmu.ifi.dbs.elki.utilities.optionhandling.parameters.ObjectParameter;
import java.util.ArrayList;
import java.util.List;
import net.jafama.FastMath;

@Title(value="EM-Clustering: Clustering by Expectation Maximization")
@Description(value="Cluster data via Gaussian mixture modeling and the EM algorithm")
@References(value={@Reference(authors="A. P. Dempster, N. M. Laird, D. B. Rubin", title="Maximum Likelihood from Incomplete Data via the EM algorithm", booktitle="Journal of the Royal Statistical Society, Series B, 39(1)", url="http://www.jstor.org/stable/2984875", bibkey="journals/jroyastatsocise2/DempsterLR77"), @Reference(title="Bayesian Regularization for Normal Mixture Estimation and Model-Based Clustering", authors="C. Fraley, A. E. Raftery", booktitle="J. Classification 24(2)", url="https://doi.org/10.1007/s00357-007-0004-5", bibkey="DBLP:journals/classification/FraleyR07")})
@Alias(value={"de.lmu.ifi.dbs.elki.algorithm.clustering.EM"})
@Priority(value=200)
public class EM<V extends NumberVector, M extends MeanModel>
extends AbstractAlgorithm<Clustering<M>>
implements ClusteringAlgorithm<Clustering<M>> {
    /** Class logger; also handed out by {@link #getLogger()}. */
    private static final Logging LOG = Logging.getLogger(EM.class);
    /** Fully qualified class name, used as prefix of the ".iterations" statistic key. */
    private static final String KEY = EM.class.getName();
    /** Number of clusters: passed to the model factory and used as the number of result clusters. */
    private int k;
    /** Threshold on the change of the log likelihood; used both to detect improvement and to stop iterating. */
    private double delta;
    /** Factory that builds the initial cluster models. */
    private EMClusterModelFactory<V, M> mfactory;
    /** Iteration limit; a negative value disables the limit. */
    private int maxiter;
    /** Prior passed on to the model updates; a value &lt;= 0 selects the plain (non-MAP) cluster weight estimate. */
    private double prior = 0.0;
    /** Whether the per-object cluster probabilities are attached to the result instead of being destroyed. */
    private boolean soft;
    /** Lower bound for log densities; the clamp in assignProbabilitiesToInstances uses the same value as a literal. */
    private static final double MIN_LOGLIKELIHOOD = -100000.0;
    /** Type information used for the relation of soft assignments (one double[] per object). */
    public static final SimpleTypeInformation<double[]> SOFT_TYPE = new SimpleTypeInformation<double[]>(double[].class);

    /**
     * Creates an EM instance without iteration limit ({@code maxiter = -1}),
     * without prior ({@code prior = 0}) and without soft assignment output.
     *
     * @param k number of clusters
     * @param delta threshold on the change of the log likelihood
     * @param mfactory factory producing the initial cluster models
     */
    public EM(int k, double delta, EMClusterModelFactory<V, M> mfactory) {
        // Chains through the five-argument constructor, which supplies prior = 0.
        this(k, delta, mfactory, -1, false);
    }

    /**
     * Creates an EM instance without prior ({@code prior = 0}).
     *
     * @param k number of clusters
     * @param delta threshold on the change of the log likelihood
     * @param mfactory factory producing the initial cluster models
     * @param maxiter iteration limit; negative for no limit
     * @param soft whether to attach the soft assignments to the result
     */
    public EM(int k, double delta, EMClusterModelFactory<V, M> mfactory, int maxiter, boolean soft) {
        this(k, delta, mfactory, maxiter, 0.0, soft);
    }

    /**
     * Creates a fully specified EM instance.
     *
     * @param k number of clusters
     * @param delta threshold on the change of the log likelihood
     * @param mfactory factory producing the initial cluster models
     * @param maxiter iteration limit; negative for no limit
     * @param prior prior passed to the model updates; values &lt;= 0 select the plain weight estimate
     * @param soft whether to attach the soft assignments to the result
     */
    public EM(int k, double delta, EMClusterModelFactory<V, M> mfactory, int maxiter, double prior, boolean soft) {
        this.mfactory = mfactory;
        this.k = k;
        this.maxiter = maxiter;
        this.delta = delta;
        this.soft = soft;
        this.prior = prior;
    }

    /**
     * Runs EM clustering on the given relation.
     * <p>
     * Initial models are obtained from the model factory. The method then
     * alternates between updating the models from the current soft assignments
     * and recomputing the soft assignments. It stops as soon as one of the
     * following holds:
     * <ul>
     * <li>the log likelihood changed by at most {@code delta} in the last
     * iteration,</li>
     * <li>the last improvement by more than {@code delta} is older than half of
     * the iterations performed so far,</li>
     * <li>the iteration limit {@code maxiter} is reached (if non-negative).</li>
     * </ul>
     * Finally every object is assigned to the cluster with the largest
     * probability.
     *
     * @param database database, handed to the model factory
     * @param relation relation to cluster
     * @return clustering with one top-level cluster per model; if
     *         {@link #isSoft()} is set, the soft assignments are attached as a
     *         child result
     * @throws IllegalArgumentException if the relation is empty
     */
    public Clustering<M> run(Database database, Relation<V> relation) {
        if (relation.size() == 0) {
            throw new IllegalArgumentException("database empty: must contain elements");
        }
        // Wildcard type: the list is only read, and this matches the helpers' parameter types.
        List<? extends EMClusterModel<M>> models = mfactory.buildInitialModels(database, relation, k, SquaredEuclideanDistanceFunction.STATIC);
        // NOTE(review): 10 is a bit mask of data store hints inlined by the compiler - confirm the flag names.
        WritableDataStore<double[]> probClusterIGivenX = DataStoreUtil.makeStorage(relation.getDBIDs(), 10, double[].class);
        double loglikelihood = assignProbabilitiesToInstances(relation, models, probClusterIGivenX);
        // Only allocated when statistics logging is enabled at this point.
        final DoubleStatistic likestat = LOG.isStatistics() ? new DoubleStatistic(this.getClass().getName() + ".loglikelihood") : null;
        if (likestat != null) {
            LOG.statistics(likestat.setDouble(loglikelihood));
        }

        int it = 0;
        int lastimprovement = 0;
        double bestloglikelihood = loglikelihood;
        for (++it; it < maxiter || maxiter < 0; ++it) {
            final double oldloglikelihood = loglikelihood;
            recomputeCovarianceMatrices(relation, probClusterIGivenX, models, prior);
            loglikelihood = assignProbabilitiesToInstances(relation, models, probClusterIGivenX);
            // Guard on likestat as well: it is null if statistics were off at start-up.
            if (likestat != null && LOG.isStatistics()) {
                LOG.statistics(likestat.setDouble(loglikelihood));
            }
            if (loglikelihood - bestloglikelihood > delta) {
                lastimprovement = it;
                bestloglikelihood = loglikelihood;
            }
            // Converged, or no relevant improvement during the more recent half of the iterations.
            if (Math.abs(loglikelihood - oldloglikelihood) <= delta || lastimprovement < it >> 1) {
                break;
            }
        }
        if (LOG.isStatistics()) {
            LOG.statistics(new LongStatistic(KEY + ".iterations", it));
        }

        // Hard assignment: each object goes to its most probable cluster.
        List<ModifiableDBIDs> hardClusters = new ArrayList<ModifiableDBIDs>(k);
        for (int i = 0; i < k; i++) {
            hardClusters.add(DBIDUtil.newArray());
        }
        for (DBIDIter iditer = relation.iterDBIDs(); iditer.valid(); iditer.advance()) {
            hardClusters.get(VMath.argmax(probClusterIGivenX.get(iditer))).add(iditer);
        }

        Clustering<M> result = new Clustering<M>("EM Clustering", "em-clustering");
        for (int i = 0; i < k; i++) {
            result.addToplevelCluster(new Cluster<M>(hardClusters.get(i), models.get(i).finalizeCluster()));
        }
        if (isSoft()) {
            result.addChildResult(new MaterializedRelation<double[]>("cluster assignments", "em-soft-score", SOFT_TYPE, probClusterIGivenX, relation.getDBIDs()));
        } else {
            // The soft assignments are not part of the result, so release the storage.
            probClusterIGivenX.destroy();
        }
        return result;
    }

    /*
     * WARNING - void declaration
     */
    public static void recomputeCovarianceMatrices(Relation<? extends NumberVector> relation, WritableDataStore<double[]> probClusterIGivenX, List<? extends EMClusterModel<?>> models, double prior) {
        void var8_14;
        int k = models.size();
        boolean needsTwoPass = false;
        for (EMClusterModel<?> eMClusterModel : models) {
            eMClusterModel.beginEStep();
            needsTwoPass |= eMClusterModel.needsTwoPass();
        }
        if (needsTwoPass) {
            DBIDIter iditer = relation.iterDBIDs();
            while (iditer.valid()) {
                double[] dArray = (double[])probClusterIGivenX.get(iditer);
                NumberVector instance = relation.get(iditer);
                for (int i = 0; i < dArray.length; ++i) {
                    double prob = dArray[i];
                    if (!(prob > 1.0E-10)) continue;
                    models.get(i).firstPassE(instance, prob);
                }
                iditer.advance();
            }
            for (EMClusterModel eMClusterModel : models) {
                eMClusterModel.finalizeFirstPassE();
            }
        }
        double[] wsum = new double[k];
        DBIDIter dBIDIter = relation.iterDBIDs();
        while (dBIDIter.valid()) {
            double[] clusterProbabilities = (double[])probClusterIGivenX.get(dBIDIter);
            NumberVector instance = relation.get(dBIDIter);
            int i = 0;
            while (i < clusterProbabilities.length) {
                double prob = clusterProbabilities[i];
                if (prob > 1.0E-10) {
                    models.get(i).updateE(instance, prob);
                }
                int n = i++;
                wsum[n] = wsum[n] + prob;
            }
            dBIDIter.advance();
        }
        boolean bl = false;
        while (var8_14 < models.size()) {
            double weight = prior <= 0.0 ? wsum[var8_14] / (double)relation.size() : (wsum[var8_14] + prior - 1.0) / ((double)relation.size() + prior * (double)k - (double)k);
            models.get((int)var8_14).finalizeEStep(weight, prior);
            ++var8_14;
        }
    }

    /**
     * Recomputes the per-object cluster probabilities from the given models.
     * <p>
     * For every object the log density under each model is obtained (clamped
     * from below at -100000), normalized via log-sum-exp and stored as a
     * probability vector in {@code probClusterIGivenX}.
     *
     * @param relation data relation
     * @param models cluster models
     * @param probClusterIGivenX output storage for the probability vectors
     * @return sum of the per-object log-sum-exp values, divided by the
     *         relation size
     */
    public static double assignProbabilitiesToInstances(Relation<? extends NumberVector> relation, List<? extends EMClusterModel<?>> models, WritableDataStore<double[]> probClusterIGivenX) {
        final int numModels = models.size();
        double total = 0.0;
        for (DBIDIter it = relation.iterDBIDs(); it.valid(); it.advance()) {
            final NumberVector point = relation.get(it);
            final double[] weights = new double[numModels];
            for (int c = 0; c < numModels; c++) {
                final double logDensity = models.get(c).estimateLogDensity(point);
                // Clamp from below; a NaN also ends up at the bound.
                if (logDensity > -100000.0) {
                    weights[c] = logDensity;
                } else {
                    weights[c] = -100000.0;
                }
            }
            final double logNorm = EM.logSumExp(weights);
            // Convert log densities into normalized probabilities in place.
            for (int c = 0; c < numModels; c++) {
                weights[c] = FastMath.exp(weights[c] - logNorm);
            }
            probClusterIGivenX.put(it, weights);
            total += logNorm;
        }
        return total / (double) relation.size();
    }

    /**
     * Computes {@code log(sum_i exp(x[i]))}, relative to the largest entry.
     * <p>
     * Entries not greater than {@code max - 35.350506209} are skipped. The
     * constant is about {@code 51 * ln(2)}, so skipped terms are below
     * {@code 2^-51} relative to the largest one.
     *
     * @param x input values; must contain at least one element
     * @return the log of the sum of the exponentials
     */
    private static double logSumExp(double[] x) {
        double peak = x[0];
        for (int j = 1; j < x.length; j++) {
            if (x[j] > peak) {
                peak = x[j];
            }
        }
        final double threshold = peak - 35.350506209;
        double sum = 0.0;
        for (double term : x) {
            if (term > threshold) {
                // Entries equal to the maximum contribute exactly exp(0) = 1.
                sum += term < peak ? FastMath.exp(term - peak) : 1.0;
            }
        }
        // With a single contributing term the result is just the maximum.
        return sum > 1.0 ? peak + FastMath.log(sum) : peak;
    }

    /**
     * Declares the accepted input: a single relation matching
     * {@code TypeUtil.NUMBER_VECTOR_FIELD}.
     *
     * @return array with the one input type restriction
     */
    @Override
    public TypeInformation[] getInputTypeRestriction() {
        return TypeUtil.array(TypeUtil.NUMBER_VECTOR_FIELD);
    }

    /**
     * Returns the static logger of this class.
     *
     * @return class logger
     */
    @Override
    protected Logging getLogger() {
        return LOG;
    }

    /**
     * Tells whether the soft assignments are attached to the clustering result.
     *
     * @return {@code true} if soft assignments are kept
     */
    public boolean isSoft() {
        return soft;
    }

    /**
     * Sets whether the soft assignments are attached to the clustering result.
     *
     * @param soft {@code true} to keep the soft assignments
     */
    public void setSoft(boolean soft) {
        this.soft = soft;
    }

    /**
     * Parameterization class: reads the options of {@link EM} and creates the
     * algorithm instance.
     *
     * @param <V> vector type
     * @param <M> model type
     */
    public static class Parameterizer<V extends NumberVector, M extends MeanModel>
    extends AbstractParameterizer {
        /** Option for the number of clusters; must be at least 1. */
        public static final OptionID K_ID = new OptionID("em.k", "The number of clusters to find.");
        /** Option for the convergence threshold; must not be negative, defaults to 1e-7. */
        public static final OptionID DELTA_ID = new OptionID("em.delta", "The termination criterion for maximization of E(M): E(M) - E(M') < em.delta");
        /** Option for the model factory; defaults to {@link MultivariateGaussianModelFactory}. */
        public static final OptionID INIT_ID = new OptionID("em.model", "Model factory.");
        /** Optional option for the MAP prior; must be greater than 0 if given. */
        public static final OptionID PRIOR_ID = new OptionID("em.map.prior", "Regularization factor for MAP estimation.");
        /** Number of clusters. */
        protected int k;
        /** Convergence threshold. */
        protected double delta;
        /** Factory for the initial cluster models. */
        protected EMClusterModelFactory<V, M> initializer;
        /** Iteration limit; stays -1 (no limit) unless the option is given. */
        protected int maxiter = -1;
        /** Prior; stays 0 unless the option is given. */
        double prior = 0.0;

        /**
         * Reads all options from the configuration. A field keeps its default
         * whenever the corresponding option could not be grabbed.
         *
         * @param config parameterization to read from
         */
        @Override
        protected void makeOptions(Parameterization config) {
            super.makeOptions(config);
            IntParameter kP = new IntParameter(K_ID)
                    .addConstraint(CommonConstraints.GREATER_EQUAL_ONE_INT);
            if (config.grab(kP)) {
                k = kP.getValue();
            }
            // Typed parameter: instantiateClass then needs no unchecked cast.
            ObjectParameter<EMClusterModelFactory<V, M>> initialP = new ObjectParameter<EMClusterModelFactory<V, M>>(INIT_ID, EMClusterModelFactory.class, MultivariateGaussianModelFactory.class);
            if (config.grab(initialP)) {
                initializer = initialP.instantiateClass(config);
            }
            DoubleParameter deltaP = new DoubleParameter(DELTA_ID, 1.0E-7)
                    .addConstraint(CommonConstraints.GREATER_EQUAL_ZERO_DOUBLE);
            if (config.grab(deltaP)) {
                delta = deltaP.getValue();
            }
            // The iteration limit reuses the option ID declared by KMeans.
            IntParameter maxiterP = new IntParameter(KMeans.MAXITER_ID)
                    .addConstraint(CommonConstraints.GREATER_EQUAL_ZERO_INT)
                    .setOptional(true);
            if (config.grab(maxiterP)) {
                maxiter = maxiterP.getValue();
            }
            DoubleParameter priorP = new DoubleParameter(PRIOR_ID)
                    .setOptional(true)
                    .addConstraint(CommonConstraints.GREATER_THAN_ZERO_DOUBLE);
            if (config.grab(priorP)) {
                prior = priorP.doubleValue();
            }
        }

        /**
         * Creates the algorithm from the parsed options. Soft assignment output
         * is always disabled here; it can be enabled afterwards with
         * {@code setSoft(true)}.
         *
         * @return configured EM instance
         */
        @Override
        protected EM<V, M> makeInstance() {
            return new EM<V, M>(k, delta, initializer, maxiter, prior, false);
        }
    }
}

