package jss.perm;

import bayonet.distributions.Random;
import blang.core.LogScaleFactor;
import blang.distributions.Generators;
import blang.mcmc.ConnectedFactor;
import blang.mcmc.SampledVariable;
import blang.mcmc.Sampler;
import java.util.Collections;
import java.util.List;

/**
 * Sampler that proposes swapping two entries of a {@link Permutation} and
 * accepts or rejects the swap based on the change in the summed log density
 * of the connected factors.
 *
 * <p>Per the original notes, one instance is created for each Permutation
 * encountered in a Blang model.
 */
@SuppressWarnings("all")
public class PermutationSampler implements Sampler {
  /**
   * The permutation being resampled; populated automatically
   * (see {@code @SampledVariable}).
   */
  @SampledVariable
  private Permutation permutation;
  
  /**
   * The prior and likelihood elements (collectively, factors) that depend
   * on the permutation being resampled; populated automatically
   * (see {@code @ConnectedFactor}).
   */
  @ConnectedFactor
  private List<LogScaleFactor> numericFactors;
  
  /**
   * Performs one swap move: draws two indices, swaps the corresponding
   * connections, and keeps the swap with probability
   * {@code min(1, exp(newLogDensity - oldLogDensity))}; otherwise the swap
   * is undone.
   *
   * @param rand source of randomness for the index draws and the
   *             accept/reject decision
   */
  @Override
  public void execute(final Random rand) {
    final int size = this.permutation.componentSize();
    // Both indices are drawn independently, so they may coincide.
    // NOTE(review): presumably the upper bound is exclusive — confirm
    // against Generators.discreteUniform.
    final int first = Generators.discreteUniform(rand, 0, size);
    final int second = Generators.discreteUniform(rand, 0, size);

    // Density must be evaluated before and after the in-place proposal.
    final double logDensityBefore = this.logDensity();
    this.swapConnections(first, second);
    final double logDensityAfter = this.logDensity();

    final double logRatio = logDensityAfter - logDensityBefore;
    final double acceptProb = Math.min(1.0, Math.exp(logRatio));
    final boolean accepted = Generators.bernoulli(rand, acceptProb);
    if (accepted) {
      return;
    }
    // Rejected: swapping the same pair again restores the previous state.
    this.swapConnections(first, second);
  }
  
  /**
   * Sums the log densities of all connected factors, in list order.
   *
   * @return the total log density (0.0 when there are no factors)
   */
  public double logDensity() {
    double total = 0.0;
    for (final LogScaleFactor factor : this.numericFactors) {
      total += factor.logDensity();
    }
    return total;
  }

  /** Swaps, in place, the permutation's connections at positions a and b. */
  private void swapConnections(final int a, final int b) {
    Collections.swap(this.permutation.getConnections(), a, b);
  }
}
