package jss.glms;

import bayonet.math.SpecialFunctions;
import blang.core.BlangExtensions;
import blang.core.ConstantSupplier;
import blang.core.DeboxedName;
import blang.core.IntVar;
import blang.core.Model;
import blang.core.ModelBuilder;
import blang.core.ModelComponent;
import blang.core.Param;
import blang.core.RealVar;
import blang.inits.Arg;
import blang.inits.DesignatedConstructor;
import blang.io.GlobalDataSource;
import blang.types.Index;
import blang.types.Plate;
import blang.types.Plated;
import blang.types.StaticUtils;
import blang.types.internals.RealScalar;
import ca.ubc.stat.blang.StaticJavaUtils;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Optional;
import java.util.function.Supplier;

/**
 * Blang model class for a logistic classification model whose per-feature
 * parameters are {@link SpikedRealVar}s.
 *
 * <p>As assembled by {@link #components()}, the model is:
 * <ul>
 *   <li>{@code labels.get(instance) ~ Bernoulli(logistic(intercept + dotProduct.compute))}
 *       for every index of the {@code instances} plate;</li>
 *   <li>{@code parameters.get(feature).selected ~ Bernoulli(activeProbability)} and
 *       {@code parameters.get(feature).continuousPart ~ StudentT(1.0, 0.0, sigma)}
 *       for every index of the {@code features} plate;</li>
 *   <li>{@code intercept ~ StudentT(1.0, 0.0, sigma)};</li>
 *   <li>{@code activeProbability ~ ContinuousUniform(0, 1)};</li>
 *   <li>{@code sigma ~ Exponential(1.0)}.</li>
 * </ul>
 *
 * <p>NOTE(review): the "Code generated by" / "Auxiliary method generated to translate"
 * comments suggest this file is produced by a code generator; field names, annotations,
 * the constructor argument order and the {@code $generated__*} members are kept exactly
 * as generated because the constructor comment states they are relied upon at runtime.
 */
@SuppressWarnings("all")
public class SpikeSlabClassification implements Model {
  /**
   * Builder for {@link SpikeSlabClassification}.
   *
   * <p>{@code data}, {@code instances}, {@code features}, {@code covariates},
   * {@code parameters} and {@code labels} are required when the builder is used
   * programmatically; {@code activeProbability}, {@code sigma} and {@code intercept}
   * are optional and default to {@code StaticUtils.latentReal()}.
   */
  public static class Builder implements ModelBuilder {
    /** Set by {@link SpikeSlabClassification#builderFromCommandLine()}; disables the "field was set" checks in {@link #build()}. */
    private boolean fromCommandLine = false;
    
    @Arg
    public GlobalDataSource data;
    
    private boolean data_initialized = false;
    
    /** Sets {@code data} and records that it was provided. */
    public SpikeSlabClassification.Builder setData(final GlobalDataSource data) {
      data_initialized = true;
      this.data = data;
      return this;
    }
    
    @Arg
    public Plate<String> instances;
    
    private boolean instances_initialized = false;
    
    /** Sets {@code instances} and records that it was provided. */
    public SpikeSlabClassification.Builder setInstances(final Plate<String> instances) {
      instances_initialized = true;
      this.instances = instances;
      return this;
    }
    
    @Arg
    public Plate<String> features;
    
    private boolean features_initialized = false;
    
    /** Sets {@code features} and records that it was provided. */
    public SpikeSlabClassification.Builder setFeatures(final Plate<String> features) {
      features_initialized = true;
      this.features = features;
      return this;
    }
    
    @Arg
    public Plated<Double> covariates;
    
    private boolean covariates_initialized = false;
    
    /** Sets {@code covariates} and records that it was provided. */
    public SpikeSlabClassification.Builder setCovariates(final Plated<Double> covariates) {
      covariates_initialized = true;
      this.covariates = covariates;
      return this;
    }
    
    @Arg
    public Plated<SpikedRealVar> parameters;
    
    private boolean parameters_initialized = false;
    
    /** Sets {@code parameters} and records that it was provided. */
    public SpikeSlabClassification.Builder setParameters(final Plated<SpikedRealVar> parameters) {
      parameters_initialized = true;
      this.parameters = parameters;
      return this;
    }
    
    @Arg
    public Plated<IntVar> labels;
    
    private boolean labels_initialized = false;
    
    /** Sets {@code labels} and records that it was provided. */
    public SpikeSlabClassification.Builder setLabels(final Plated<IntVar> labels) {
      labels_initialized = true;
      this.labels = labels;
      return this;
    }
    
    @Arg
    public Optional<RealVar> activeProbability;
    
    /**
     * Sets the optional {@code activeProbability}.
     *
     * @param activeProbability value to use; must not be {@code null} because it is wrapped with {@link Optional#of(Object)}
     */
    public SpikeSlabClassification.Builder setActiveProbability(final RealVar activeProbability) {
      this.activeProbability = Optional.of(activeProbability);
      return this;
    }
    
    @Arg
    public Optional<RealVar> sigma;
    
    /**
     * Sets the optional {@code sigma}.
     *
     * @param sigma value to use; must not be {@code null} because it is wrapped with {@link Optional#of(Object)}
     */
    public SpikeSlabClassification.Builder setSigma(final RealVar sigma) {
      this.sigma = Optional.of(sigma);
      return this;
    }
    
    @Arg
    public Optional<RealVar> intercept;
    
    /**
     * Sets the optional {@code intercept}.
     *
     * @param intercept value to use; must not be {@code null} because it is wrapped with {@link Optional#of(Object)}
     */
    public SpikeSlabClassification.Builder setIntercept(final RealVar intercept) {
      this.intercept = Optional.of(intercept);
      return this;
    }
    
    /**
     * Fails when a required field was not provided through its setter, unless this
     * builder was created for command-line use.
     *
     * @param initialized whether the corresponding setter was called
     * @param fieldName name reported in the error message
     * @throws RuntimeException if the field is missing and the builder is not in command-line mode
     */
    private void requireInitialized(final boolean initialized, final String fieldName) {
      if (!fromCommandLine && !initialized)
        throw new RuntimeException("Not all fields were set in the builder, e.g. missing " + fieldName);
    }
    
    /**
     * Validates the required fields, resolves the optional ones and instantiates the model.
     *
     * @return a new {@link SpikeSlabClassification}
     * @throws RuntimeException if a required field was not set and the builder was not
     *         obtained from {@link SpikeSlabClassification#builderFromCommandLine()}
     */
    public SpikeSlabClassification build() {
      // Required fields are checked in declaration order so the first missing one is reported.
      requireInitialized(data_initialized, "data");
      final GlobalDataSource __data = data;
      requireInitialized(instances_initialized, "instances");
      final Plate<String> __instances = instances;
      requireInitialized(features_initialized, "features");
      final Plate<String> __features = features;
      requireInitialized(covariates_initialized, "covariates");
      final Plated<Double> __covariates = covariates;
      requireInitialized(parameters_initialized, "parameters");
      final Plated<SpikedRealVar> __parameters = parameters;
      requireInitialized(labels_initialized, "labels");
      final Plated<IntVar> __labels = labels;
      // For each optional type, either get the value, or evaluate the ?: expression.
      // The Optional field itself may be null when its setter was never called.
      RealVar activeProbability;
      if (this.activeProbability != null && this.activeProbability.isPresent()) {
        activeProbability = this.activeProbability.get();
      } else {
        activeProbability = $generated__20(data, instances, features, covariates, parameters, labels);
      }
      final RealVar __activeProbability = activeProbability;
      RealVar sigma;
      if (this.sigma != null && this.sigma.isPresent()) {
        sigma = this.sigma.get();
      } else {
        sigma = $generated__21(data, instances, features, covariates, parameters, labels, activeProbability);
      }
      final RealVar __sigma = sigma;
      RealVar intercept;
      if (this.intercept != null && this.intercept.isPresent()) {
        intercept = this.intercept.get();
      } else {
        intercept = $generated__22(data, instances, features, covariates, parameters, labels, activeProbability, sigma);
      }
      final RealVar __intercept = intercept;
      // Build the instance after boxing params
      return new SpikeSlabClassification(
        __parameters, 
        __labels, 
        __activeProbability, 
        __sigma, 
        __intercept, 
        new ConstantSupplier(__data), 
        new ConstantSupplier(__instances), 
        new ConstantSupplier(__features), 
        new ConstantSupplier(__covariates)
      );
    }
  }
  
  /**
   * Creates a builder flagged as coming from the command line, which makes
   * {@link Builder#build()} skip its "all fields were set" checks.
   */
  @DesignatedConstructor
  public static SpikeSlabClassification.Builder builderFromCommandLine() {
    Builder result = new Builder();
    result.fromCommandLine = true;
    return result;
  }
  
  @Param
  private final Supplier<GlobalDataSource> $generated__data;
  
  /** Returns the current value of the {@code data} param. */
  public GlobalDataSource getData() {
    return $generated__data.get();
  }
  
  @Param
  private final Supplier<Plate<String>> $generated__instances;
  
  /** Returns the current value of the {@code instances} param. */
  public Plate<String> getInstances() {
    return $generated__instances.get();
  }
  
  @Param
  private final Supplier<Plate<String>> $generated__features;
  
  /** Returns the current value of the {@code features} param. */
  public Plate<String> getFeatures() {
    return $generated__features.get();
  }
  
  @Param
  private final Supplier<Plated<Double>> $generated__covariates;
  
  /** Returns the current value of the {@code covariates} param. */
  public Plated<Double> getCovariates() {
    return $generated__covariates.get();
  }
  
  private final Plated<SpikedRealVar> parameters;
  
  /** Returns the per-feature spiked parameters. */
  public Plated<SpikedRealVar> getParameters() {
    return parameters;
  }
  
  private final Plated<IntVar> labels;
  
  /** Returns the per-instance labels. */
  public Plated<IntVar> getLabels() {
    return labels;
  }
  
  private final RealVar activeProbability;
  
  /** Returns the {@code activeProbability} variable. */
  public RealVar getActiveProbability() {
    return activeProbability;
  }
  
  private final RealVar sigma;
  
  /** Returns the {@code sigma} variable. */
  public RealVar getSigma() {
    return sigma;
  }
  
  private final RealVar intercept;
  
  /** Returns the {@code intercept} variable. */
  public RealVar getIntercept() {
    return intercept;
  }
  
  /**
   * Utility main method for posterior inference on this model
   */
  public static void main(final String[] arguments) {
    StaticJavaUtils.callRunner(Builder.class, arguments);
  }
  
  /**
   * Auxiliary method generated to translate:
   * instances.indices
   */
  private static Iterable<Index<String>> $generated__0(final GlobalDataSource data, final Plate<String> instances, final Plate<String> features, final Plated<Double> covariates, final Plated<SpikedRealVar> parameters, final Plated<IntVar> labels, final RealVar activeProbability, final RealVar sigma, final RealVar intercept) {
    Collection<Index<String>> _indices = instances.indices();
    return _indices;
  }
  
  /**
   * Auxiliary method generated to translate:
   * DotProduct.of(features, parameters, covariates.slice(instance))
   */
  private static DotProduct $generated__1(final Index<String> instance, final GlobalDataSource data, final Plate<String> instances, final Plate<String> features, final Plated<Double> covariates, final Plated<SpikedRealVar> parameters, final Plated<IntVar> labels, final RealVar activeProbability, final RealVar sigma, final RealVar intercept) {
    DotProduct _of = DotProduct.<String>of(features, parameters, covariates.slice(instance));
    return _of;
  }
  
  /**
   * Auxiliary method generated to translate:
   * labels.get(instance)
   */
  private static IntVar $generated__2(final Index<String> instance, final GlobalDataSource data, final Plate<String> instances, final Plate<String> features, final Plated<Double> covariates, final Plated<SpikedRealVar> parameters, final Plated<IntVar> labels, final RealVar activeProbability, final RealVar sigma, final RealVar intercept) {
    IntVar _get = labels.get(instance);
    return _get;
  }
  
  /**
   * Auxiliary method generated to translate:
   * logistic(intercept + dotProduct.compute)
   *
   * <p>The result is wrapped in a fresh {@code RealConstant} on every call.
   */
  private static RealVar $generated__3(final RealVar intercept, final DotProduct dotProduct) {
    double _compute = dotProduct.compute();
    double _plus = BlangExtensions.operator_plus(intercept, Double.valueOf(_compute));
    double _logistic = SpecialFunctions.logistic(_plus);
    return new blang.core.RealConstant(_logistic);
  }
  
  /** Supplier evaluating {@code logistic(intercept + dotProduct.compute)} on each {@link #get()}. */
  public static class $generated__3_class implements Supplier<RealVar> {
    @Override
    public RealVar get() {
      return $generated__3(intercept, dotProduct);
    }
    
    @Override
    public String toString() {
      return "logistic(intercept + dotProduct.compute)";
    }
    
    private final RealVar intercept;
    
    private final DotProduct dotProduct;
    
    public $generated__3_class(final RealVar intercept, final DotProduct dotProduct) {
      this.intercept = intercept;
      this.dotProduct = dotProduct;
    }
  }
  
  /**
   * Auxiliary method generated to translate:
   * features.indices
   */
  private static Iterable<Index<String>> $generated__4(final GlobalDataSource data, final Plate<String> instances, final Plate<String> features, final Plated<Double> covariates, final Plated<SpikedRealVar> parameters, final Plated<IntVar> labels, final RealVar activeProbability, final RealVar sigma, final RealVar intercept) {
    Collection<Index<String>> _indices = features.indices();
    return _indices;
  }
  
  /**
   * Auxiliary method generated to translate:
   * parameters.get(feature).selected
   */
  private static IntVar $generated__5(final Index<String> feature, final GlobalDataSource data, final Plate<String> instances, final Plate<String> features, final Plated<Double> covariates, final Plated<SpikedRealVar> parameters, final Plated<IntVar> labels, final RealVar activeProbability, final RealVar sigma, final RealVar intercept) {
    return parameters.get(feature).selected;
  }
  
  /**
   * Auxiliary method generated to translate:
   * activeProbability
   */
  private static RealVar $generated__6(final RealVar activeProbability) {
    return activeProbability;
  }
  
  /** Supplier returning the captured {@code activeProbability}. */
  public static class $generated__6_class implements Supplier<RealVar> {
    @Override
    public RealVar get() {
      return $generated__6(activeProbability);
    }
    
    @Override
    public String toString() {
      return "activeProbability";
    }
    
    private final RealVar activeProbability;
    
    public $generated__6_class(final RealVar activeProbability) {
      this.activeProbability = activeProbability;
    }
  }
  
  /**
   * Auxiliary method generated to translate:
   * parameters.get(feature).continuousPart
   */
  private static RealVar $generated__7(final Index<String> feature, final GlobalDataSource data, final Plate<String> instances, final Plate<String> features, final Plated<Double> covariates, final Plated<SpikedRealVar> parameters, final Plated<IntVar> labels, final RealVar activeProbability, final RealVar sigma, final RealVar intercept) {
    return parameters.get(feature).continuousPart;
  }
  
  /**
   * Auxiliary method generated to translate:
   * 1.0
   *
   * <p>The {@code sigma} argument is not read; it is kept to preserve the generated signature.
   */
  private static RealVar $generated__8(final RealVar sigma) {
    return new blang.core.RealConstant(1.0);
  }
  
  /** Supplier of the constant {@code 1.0}; still holds a reference to {@code sigma}. */
  public static class $generated__8_class implements Supplier<RealVar> {
    @Override
    public RealVar get() {
      return $generated__8(sigma);
    }
    
    @Override
    public String toString() {
      return "1.0";
    }
    
    private final RealVar sigma;
    
    public $generated__8_class(final RealVar sigma) {
      this.sigma = sigma;
    }
  }
  
  /**
   * Auxiliary method generated to translate:
   * 0.0
   *
   * <p>The {@code sigma} argument is not read; it is kept to preserve the generated signature.
   */
  private static RealVar $generated__9(final RealVar sigma) {
    return new blang.core.RealConstant(0.0);
  }
  
  /** Supplier of the constant {@code 0.0}; still holds a reference to {@code sigma}. */
  public static class $generated__9_class implements Supplier<RealVar> {
    @Override
    public RealVar get() {
      return $generated__9(sigma);
    }
    
    @Override
    public String toString() {
      return "0.0";
    }
    
    private final RealVar sigma;
    
    public $generated__9_class(final RealVar sigma) {
      this.sigma = sigma;
    }
  }
  
  /**
   * Auxiliary method generated to translate:
   * sigma
   */
  private static RealVar $generated__10(final RealVar sigma) {
    return sigma;
  }
  
  /** Supplier returning the captured {@code sigma}. */
  public static class $generated__10_class implements Supplier<RealVar> {
    @Override
    public RealVar get() {
      return $generated__10(sigma);
    }
    
    @Override
    public String toString() {
      return "sigma";
    }
    
    private final RealVar sigma;
    
    public $generated__10_class(final RealVar sigma) {
      this.sigma = sigma;
    }
  }
  
  /**
   * Auxiliary method generated to translate:
   * intercept
   */
  private static RealVar $generated__11(final GlobalDataSource data, final Plate<String> instances, final Plate<String> features, final Plated<Double> covariates, final Plated<SpikedRealVar> parameters, final Plated<IntVar> labels, final RealVar activeProbability, final RealVar sigma, final RealVar intercept) {
    return intercept;
  }
  
  /**
   * Auxiliary method generated to translate:
   * 1.0
   *
   * <p>The {@code sigma} argument is not read; it is kept to preserve the generated signature.
   */
  private static RealVar $generated__12(final RealVar sigma) {
    return new blang.core.RealConstant(1.0);
  }
  
  /** Supplier of the constant {@code 1.0}; still holds a reference to {@code sigma}. */
  public static class $generated__12_class implements Supplier<RealVar> {
    @Override
    public RealVar get() {
      return $generated__12(sigma);
    }
    
    @Override
    public String toString() {
      return "1.0";
    }
    
    private final RealVar sigma;
    
    public $generated__12_class(final RealVar sigma) {
      this.sigma = sigma;
    }
  }
  
  /**
   * Auxiliary method generated to translate:
   * 0.0
   *
   * <p>The {@code sigma} argument is not read; it is kept to preserve the generated signature.
   */
  private static RealVar $generated__13(final RealVar sigma) {
    return new blang.core.RealConstant(0.0);
  }
  
  /** Supplier of the constant {@code 0.0}; still holds a reference to {@code sigma}. */
  public static class $generated__13_class implements Supplier<RealVar> {
    @Override
    public RealVar get() {
      return $generated__13(sigma);
    }
    
    @Override
    public String toString() {
      return "0.0";
    }
    
    private final RealVar sigma;
    
    public $generated__13_class(final RealVar sigma) {
      this.sigma = sigma;
    }
  }
  
  /**
   * Auxiliary method generated to translate:
   * sigma
   */
  private static RealVar $generated__14(final RealVar sigma) {
    return sigma;
  }
  
  /** Supplier returning the captured {@code sigma}. */
  public static class $generated__14_class implements Supplier<RealVar> {
    @Override
    public RealVar get() {
      return $generated__14(sigma);
    }
    
    @Override
    public String toString() {
      return "sigma";
    }
    
    private final RealVar sigma;
    
    public $generated__14_class(final RealVar sigma) {
      this.sigma = sigma;
    }
  }
  
  /**
   * Auxiliary method generated to translate:
   * activeProbability
   */
  private static RealVar $generated__15(final GlobalDataSource data, final Plate<String> instances, final Plate<String> features, final Plated<Double> covariates, final Plated<SpikedRealVar> parameters, final Plated<IntVar> labels, final RealVar activeProbability, final RealVar sigma, final RealVar intercept) {
    return activeProbability;
  }
  
  /**
   * Auxiliary method generated to translate:
   * 0
   */
  private static RealVar $generated__16() {
    return new blang.core.RealConstant(0);
  }
  
  /** Supplier of the constant {@code 0}. */
  public static class $generated__16_class implements Supplier<RealVar> {
    @Override
    public RealVar get() {
      return $generated__16();
    }
    
    @Override
    public String toString() {
      return "0";
    }
    
    public $generated__16_class() {
      
    }
  }
  
  /**
   * Auxiliary method generated to translate:
   * 1
   */
  private static RealVar $generated__17() {
    return new blang.core.RealConstant(1);
  }
  
  /** Supplier of the constant {@code 1}. */
  public static class $generated__17_class implements Supplier<RealVar> {
    @Override
    public RealVar get() {
      return $generated__17();
    }
    
    @Override
    public String toString() {
      return "1";
    }
    
    public $generated__17_class() {
      
    }
  }
  
  /**
   * Auxiliary method generated to translate:
   * sigma
   */
  private static RealVar $generated__18(final GlobalDataSource data, final Plate<String> instances, final Plate<String> features, final Plated<Double> covariates, final Plated<SpikedRealVar> parameters, final Plated<IntVar> labels, final RealVar activeProbability, final RealVar sigma, final RealVar intercept) {
    return sigma;
  }
  
  /**
   * Auxiliary method generated to translate:
   * 1.0
   */
  private static RealVar $generated__19() {
    return new blang.core.RealConstant(1.0);
  }
  
  /** Supplier of the constant {@code 1.0}. */
  public static class $generated__19_class implements Supplier<RealVar> {
    @Override
    public RealVar get() {
      return $generated__19();
    }
    
    @Override
    public String toString() {
      return "1.0";
    }
    
    public $generated__19_class() {
      
    }
  }
  
  /**
   * Auxiliary method generated to translate:
   * latentReal
   *
   * <p>Default for {@code activeProbability} when the builder was given none.
   */
  private static RealVar $generated__20(final GlobalDataSource data, final Plate<String> instances, final Plate<String> features, final Plated<Double> covariates, final Plated<SpikedRealVar> parameters, final Plated<IntVar> labels) {
    RealScalar _latentReal = StaticUtils.latentReal();
    return _latentReal;
  }
  
  /**
   * Auxiliary method generated to translate:
   * latentReal
   *
   * <p>Default for {@code sigma} when the builder was given none.
   */
  private static RealVar $generated__21(final GlobalDataSource data, final Plate<String> instances, final Plate<String> features, final Plated<Double> covariates, final Plated<SpikedRealVar> parameters, final Plated<IntVar> labels, final RealVar activeProbability) {
    RealScalar _latentReal = StaticUtils.latentReal();
    return _latentReal;
  }
  
  /**
   * Auxiliary method generated to translate:
   * latentReal
   *
   * <p>Default for {@code intercept} when the builder was given none.
   */
  private static RealVar $generated__22(final GlobalDataSource data, final Plate<String> instances, final Plate<String> features, final Plated<Double> covariates, final Plated<SpikedRealVar> parameters, final Plated<IntVar> labels, final RealVar activeProbability, final RealVar sigma) {
    RealScalar _latentReal = StaticUtils.latentReal();
    return _latentReal;
  }
  
  /**
   * Note: the generated code has the following properties used at runtime:
   *   - all arguments are annotated with a name annotation ({@code @DeboxedName} in this file)
   *   - params additionally have a Param annotation
   *   - the order of the arguments is as follows:
   *     - first, all the random variables in the order they occur in the blang file
   *     - second, all the params in the order they occur in the blang file
   * 
   */
  public SpikeSlabClassification(@DeboxedName("parameters") final Plated<SpikedRealVar> parameters, @DeboxedName("labels") final Plated<IntVar> labels, @DeboxedName("activeProbability") final RealVar activeProbability, @DeboxedName("sigma") final RealVar sigma, @DeboxedName("intercept") final RealVar intercept, @Param @DeboxedName("data") final Supplier<GlobalDataSource> $generated__data, @Param @DeboxedName("instances") final Supplier<Plate<String>> $generated__instances, @Param @DeboxedName("features") final Supplier<Plate<String>> $generated__features, @Param @DeboxedName("covariates") final Supplier<Plated<Double>> $generated__covariates) {
    this.$generated__data = $generated__data;
    this.$generated__instances = $generated__instances;
    this.$generated__features = $generated__features;
    this.$generated__covariates = $generated__covariates;
    this.parameters = parameters;
    this.labels = labels;
    this.activeProbability = activeProbability;
    this.sigma = sigma;
    this.intercept = intercept;
  }
  
  /**
   * A component can be either a distribution, support constraint, or another model  
   * which recursively defines additional components.
   *
   * <p>A fresh list is built on every call: one Bernoulli factor per instance, one
   * Bernoulli and one StudentT factor per feature, followed by the priors on
   * {@code intercept}, {@code activeProbability} and {@code sigma}.
   */
  public Collection<ModelComponent> components() {
    // Diamond instead of the raw ArrayList so the element type is checked by the compiler.
    final Collection<ModelComponent> components = new ArrayList<>();
    
    for (Index<String> instance : $generated__0($generated__data.get(), $generated__instances.get(), $generated__features.get(), $generated__covariates.get(), parameters, labels, activeProbability, sigma, intercept)) {
      { // Code generated by: labels.get(instance) | intercept, DotProduct dotProduct = DotProduct.of(features, parameters, covariates.slice(instance)) ~ Bernoulli(logistic(intercept + dotProduct.compute))
        // Required initialization:
        DotProduct dotProduct = $generated__1(instance, $generated__data.get(), $generated__instances.get(), $generated__features.get(), $generated__covariates.get(), parameters, labels, activeProbability, sigma, intercept);
        // Construction and addition of the factor/model:
        components.add(
          new blang.distributions.Bernoulli(
            $generated__2(instance, $generated__data.get(), $generated__instances.get(), $generated__features.get(), $generated__covariates.get(), parameters, labels, activeProbability, sigma, intercept), 
            new $generated__3_class(intercept, dotProduct)
          )
          );
      }
    }
    for (Index<String> feature : $generated__4($generated__data.get(), $generated__instances.get(), $generated__features.get(), $generated__covariates.get(), parameters, labels, activeProbability, sigma, intercept)) {
      { // Code generated by: parameters.get(feature).selected | activeProbability ~ Bernoulli(activeProbability)
        // Construction and addition of the factor/model:
        components.add(
          new blang.distributions.Bernoulli(
            $generated__5(feature, $generated__data.get(), $generated__instances.get(), $generated__features.get(), $generated__covariates.get(), parameters, labels, activeProbability, sigma, intercept), 
            new $generated__6_class(activeProbability)
          )
          );
      }
      { // Code generated by: parameters.get(feature).continuousPart | sigma ~ StudentT(1.0, 0.0, sigma)
        // Construction and addition of the factor/model:
        components.add(
          new blang.distributions.StudentT(
            $generated__7(feature, $generated__data.get(), $generated__instances.get(), $generated__features.get(), $generated__covariates.get(), parameters, labels, activeProbability, sigma, intercept), 
            new $generated__8_class(sigma), 
            new $generated__9_class(sigma), 
            new $generated__10_class(sigma)
          )
          );
      }
    }
    { // Code generated by: intercept | sigma ~ StudentT(1.0, 0.0, sigma)
      // Construction and addition of the factor/model:
      components.add(
        new blang.distributions.StudentT(
          $generated__11($generated__data.get(), $generated__instances.get(), $generated__features.get(), $generated__covariates.get(), parameters, labels, activeProbability, sigma, intercept), 
          new $generated__12_class(sigma), 
          new $generated__13_class(sigma), 
          new $generated__14_class(sigma)
        )
        );
    }
    { // Code generated by: activeProbability ~ ContinuousUniform(0, 1)
      // Construction and addition of the factor/model:
      components.add(
        new blang.distributions.ContinuousUniform(
          $generated__15($generated__data.get(), $generated__instances.get(), $generated__features.get(), $generated__covariates.get(), parameters, labels, activeProbability, sigma, intercept), 
          new $generated__16_class(), 
          new $generated__17_class()
        )
        );
    }
    { // Code generated by: sigma ~ Exponential(1.0)
      // Construction and addition of the factor/model:
      components.add(
        new blang.distributions.Exponential(
          $generated__18($generated__data.get(), $generated__instances.get(), $generated__features.get(), $generated__covariates.get(), parameters, labels, activeProbability, sigma, intercept), 
          new $generated__19_class()
        )
        );
    }
    
    return components;
  }
}
