### Package Design
import pyrichlet
import numpy as np
# Seeded generator; it is reused for the simulated data here and passed as
# `rng` to the mixture models built in the comparison sections further down.
rng = np.random.default_rng(0)
mixture = pyrichlet.mixture_models.DirichletProcessMixture(rng=rng)

# Simulated one-dimensional sample: 100 draws from Normal(loc=1, scale=1)
# followed by 200 draws from Normal(loc=4, scale=2).
y = np.concatenate([rng.normal(1, 1, size=100),
                    rng.normal(4, 2, size=200)])
# Turn the flat vector into a (300, 1) column array before fitting.
y = y.reshape(300, -1)
mixture.fit_gibbs(y, init_groups=2)
# The bare attribute expressions below have no effect when run as a script;
# they only display a value in an interactive session or notebook.
mixture.var_fitted

mixture.gibbs_fitted

# Density estimate from the Gibbs fit evaluated at [1, 2]
# (presumably two separate one-dimensional points -- confirm).
density = mixture.gibbs_eap_density([1, 2])
density

# Extensibility
from pyrichlet import BaseWeight
class BasicDirichletDistribution(BaseWeight):
  """Finite Dirichlet weighting structure with a symmetric concentration.

  Example of extending ``BaseWeight``: the weight vector ``w`` is sampled
  from a Dirichlet distribution with ``n`` components that all share the
  concentration value ``alpha``.

  The attributes ``self.d``, ``self.w`` and ``self._rng`` used below are not
  defined in this class; they are presumably provided by ``BaseWeight``.
  """

  def __init__(self, n=1, alpha=1, rng=None):
    """Store the number of components and build the concentration vector.

    Args:
      n: number of components; length of the concentration vector.
      alpha: concentration value repeated once per component.
      rng: forwarded unchanged to ``BaseWeight.__init__``.
    """
    super().__init__(rng=rng)
    self.n = n
    self.alpha = np.array([alpha] * self.n, dtype=np.float64)

  def random(self, size=None):
    """Sample a new weight vector, store it in ``self.w`` and return it.

    If ``self.d`` is non-empty, the number of occurrences of each value in
    ``self.d`` is added to ``alpha`` before sampling; otherwise the draw
    uses ``alpha`` alone.

    Args:
      size: not used by this implementation.

    Returns:
      The newly sampled ``self.w``.
    """
    if len(self.d) > 0:
      counts = np.bincount(self.d)
      # In-place resize to one count per component: missing trailing
      # components are zero-filled, surplus ones are dropped.
      counts.resize(len(self.alpha), refcheck=False)
      self.w = self._rng.dirichlet(self.alpha + counts)
    else:
      self.w = self._rng.dirichlet(self.alpha)
    # Return on both branches; the value used to be returned only from the
    # `else` branch, so callers got None whenever `self.d` was non-empty.
    return self.w

  def complete(self, size=None):
    """Run the base-class ``complete`` and make sure weights exist.

    Args:
      size: forwarded to ``BaseWeight.complete``.

    Returns:
      ``self.w``, sampled via ``random()`` first if it was still empty.
    """
    super().complete(size)
    if len(self.w) == 0:
      self.random()
    return self.w


from pyrichlet.mixture_models._base import BaseGaussianMixture
class BasicDirichletDistributionMixture(BaseGaussianMixture):
  """Mixture model that plugs ``BasicDirichletDistribution`` into
  ``BaseGaussianMixture`` as its weight model."""

  def __init__(self, *, n=1, alpha=1, rng=None, **kwargs):
    """Build the weight model and hand it to the base class.

    Args:
      n: number of components; also stored as ``self.n``.
      alpha: concentration value passed to ``BasicDirichletDistribution``.
      rng: passed both to the weight model and to the base class.
      **kwargs: forwarded unchanged to ``BaseGaussianMixture.__init__``.
    """
    self.n = n
    super().__init__(
      weight_model=BasicDirichletDistribution(n=n, alpha=alpha, rng=rng),
      rng=rng,
      **kwargs)

# Instantiate the custom mixture defined above with n=5; every other argument
# keeps its default. This rebinds the global name `mixture`.
mixture = BasicDirichletDistributionMixture(n=5)


### Package Usage
import pyrichlet
import numpy as np
# Element 0 of load_penguins() is the data set fitted below; later code reads
# it through `.iloc` and `.columns`. This rebinds the global name `y`.
y = pyrichlet.utils.load_penguins()[0]

# rng=0 is passed straight to the model (presumably interpreted as a seed --
# confirm against the pyrichlet documentation).
mixture = pyrichlet.BetaInBetaMixture(x=0.5, rng=0)
mixture.fit_gibbs(y, init_groups=3)

# Density estimation
import matplotlib.pyplot as plt
# Density values used to colour the scatter below; no evaluation points are
# passed (presumably the fitted data is used -- confirm).
density = mixture.gibbs_map_density()
# Scatter of data columns 0 and 3, coloured by the density values.
plt.scatter(y.iloc[:, 0], y.iloc[:, 3], c=density)
cbar = plt.colorbar()
# Cap the colour scale at two thirds of the largest density value.
plt.clim(0, density.max() / 1.5)
cbar.set_label("density")
plt.xlabel(y.columns[0])
plt.ylabel(y.columns[3])
# The output path is relative to the current working directory.
plt.savefig("../Figures/scatterdens.pdf")
plt.show()

from scipy.stats import multivariate_normal
# Regular evaluation grid for the plane of data columns 0 and 3, flattened to
# one (x, y) pair per row.
XY = np.mgrid[30:60:0.375, 2500:6500:50]
y_space = XY.reshape(2, -1).T
w = mixture.map_sim_params["w"]
theta = mixture.map_sim_params["theta"]
# Each entry of `theta` is used as a (mean, covariance) pair. Restricting the
# mean to entries 0 and 3 and the covariance to the matching rows/columns
# gives the component's marginal over those two coordinates. One density
# vector is produced per weight in `w`.
density = [
  multivariate_normal.pdf(
    y_space,
    component[0][[0, 3]],
    component[1][np.ix_([0, 3], [0, 3])],
    allow_singular=True)
  for _, component in zip(w, theta)
]
# Weighted sum of the component densities at every grid point.
density = w @ density
plt.scatter(y_space[:, 0], y_space[:, 1], c=density)
cbar = plt.colorbar()
plt.xlabel(y.columns[0])
plt.ylabel(y.columns[3])
plt.savefig("../Figures/projdens.pdf")
plt.show()

# Clustering
# With full=True the result is unpacked into a group assignment and an
# uncertainty value per observation.
group, uncertainty = mixture.gibbs_map_cluster(y, full=True)
# Indices sorted by decreasing uncertainty. Only the disabled scatter variant
# below uses them; the active call plots the points in their original order.
order = np.argsort(-uncertainty)
#plt.scatter(y.iloc[order, 0], y.iloc[order, 3], c=group[order],
#            s=5000 * (0.01 + uncertainty[order]))
# Colour encodes the group; marker size grows linearly with the uncertainty,
# with a 0.01 offset so that fully certain points remain visible.
plt.scatter(y.iloc[:, 0], y.iloc[:, 3], c=group,
             s=5000 * (0.01 + uncertainty))
plt.xlabel(y.columns[0])
plt.ylabel(y.columns[3])
plt.savefig("../Figures/grpuncert.pdf")
plt.show()

# Pair plot produced by the fitted model itself, saved from the current figure.
mixture.gibbs_map_pairplot()
plt.savefig("../Figures/pairplot.pdf")
plt.show()


### Comparison and Running Times
import time

def get_time_and_clustering(model):
  """Time a Gibbs fit followed by a MAP clustering of the global data ``y``.

  Args:
    model: object exposing ``fit_gibbs`` and ``gibbs_map_cluster``.

  Returns:
    Tuple ``(elapsed, group)``: wall-clock seconds spent on both calls and
    the value returned by ``model.gibbs_map_cluster(y)``.
  """
  tic = time.time()
  model.fit_gibbs(y, init_groups=3)
  labels = model.gibbs_map_cluster(y)
  elapsed = time.time() - tic
  return elapsed, labels


# pyrichlet
# Gibbs benchmark: one (elapsed seconds, clustering) pair per pyrichlet model.
# Each model sits behind a zero-argument factory so that it is built right
# before it is fitted, one model after another, all sharing the same `rng`.
times_and_clusters = {
  label: get_time_and_clustering(make_model())
  for label, make_model in (
    ("DDM", lambda: pyrichlet.mixture_models.DirichletDistributionMixture(
      n=3, rng=rng)),
    ("DPM", lambda: pyrichlet.mixture_models.DirichletProcessMixture(
      rng=rng)),
    ("PYM", lambda: pyrichlet.mixture_models.PitmanYorMixture(
      pyd=0.1, rng=rng)),
    ("GPM", lambda: pyrichlet.mixture_models.GeometricProcessMixture(
      rng=rng)),
    ("BIDM", lambda: pyrichlet.mixture_models.BetaInDirichletMixture(
      a=0.1, rng=rng)),
    ("BIBM", lambda: pyrichlet.mixture_models.BetaInBetaMixture(rng=rng)),
    ("BBM", lambda: pyrichlet.mixture_models.BetaBinomialMixture(rng=rng)),
    ("EWM", lambda: pyrichlet.mixture_models.EqualWeightedMixture(
      n=3, rng=rng)),
    ("FWM", lambda: pyrichlet.mixture_models.FrequencyWeightedMixture(
      n=3, rng=rng)),
  )
}

# sklearn
from sklearn.mixture import BayesianGaussianMixture

# Time model construction plus fit_predict on the same data `y`, and store the
# result in the same (elapsed seconds, clustering) layout as the other entries.
start_time = time.time()
bgm = BayesianGaussianMixture(n_components=3, max_iter=1000)
sklearn_clustering = bgm.fit_predict(y)
end_time = time.time()
times_and_clusters["sklearn"] = (end_time - start_time, sklearn_clustering)

# mixes
from mixes import GMM

# Time construction, fit and predict of the `mixes` GMM on the same data `y`,
# and store the result as an (elapsed seconds, clustering) pair.
start_time = time.time()
gmm = GMM(3, num_iter=1000)
gmm.fit(y)
mixes_clustering = gmm.predict(y)
end_time = time.time()
times_and_clusters["mixes"] = (end_time - start_time, mixes_clustering)


# bayesmix                      
from bayesmixpy import run_mcmc
### For installation see https://github.com/bayesmix-dev/bayesmix/blob/master/INSTALL.md
### or refer to the Colab notebook "replication_script.ipynb".

# Text configuration passed to run_mcmc below, after g0_params
# (presumably the parameters of the "DP" mixing -- confirm).
dp_params = """
fixed_value {
  totalmass: 1.0
}
"""
# Text configuration passed to run_mcmc below, before dp_params (presumably
# the parameters of the "NNW" hierarchy -- confirm). It is assembled in three
# steps; the doubled braces are literal braces escaped inside the f-string.
g0_params = f"""
fixed_values {{
  mean {{
    size: {y.shape[1]}"""
# One `data:` entry per column mean of `y`.
for x in y.mean():
  g0_params += f"""
    data: {x}"""
# NOTE(review): the scale block below is a hard-coded 4x4 identity matrix,
# whereas the mean size above follows y.shape[1]; the two only agree when `y`
# has four columns.
g0_params += """
  }
  var_scaling: 0.01
  deg_free: 5
  scale {
    rows: 4
    cols: 4
    data: 1.0
    data: 0.0
    data: 0.0
    data: 0.0
    data: 0.0
    data: 1.0
    data: 0.0
    data: 0.0
    data: 0.0
    data: 0.0
    data: 1.0
    data: 0.0
    data: 0.0
    data: 0.0
    data: 0.0
    data: 1.0
    rowmajor: false
  }
}
"""
# Sampler settings passed to run_mcmc below.
neal2_algo = """
algo_id: "Neal2"
rng_seed: 1
iterations: 1000
burnin: 100
init_num_clusters: 1
"""
start_time = time.time()
# NOTE(review): the data handed to the sampler is column-centred, while the
# mean written into g0_params is the uncentred column mean -- confirm that
# this combination is intended.
out = run_mcmc(
  "NNW", "DP", (y - y.mean()).to_numpy(), g0_params, dp_params, neal2_algo,
  [], return_clusters=False, return_num_clusters=False,
  return_best_clus=True)
end_time = time.time()
# Element 3 of the result is used as the clustering (presumably the "best
# clustering" requested through return_best_clus=True -- confirm).
bayesmix_clustering = out[3]
times_and_clusters["bayesmix"] = (end_time - start_time, bayesmix_clustering)

### R
import subprocess
import pandas as pd

# Run the companion R script, which presumably writes the two CSV files read
# below. check=True raises CalledProcessError on a non-zero exit status, so a
# failed R run cannot go unnoticed and leave stale files to be read.
subprocess.run(["./aux.R"], check=True)
r_clusters = pd.read_csv("temp/R_clusters.csv", index_col=0)
r_times = pd.read_csv("temp/R_times.csv", index_col=0).loc["elapsed"]
# One (elapsed seconds, clustering) entry per R method: the labels of the
# "elapsed" row are paired with the columns of the clusters table.
times_and_clusters |= {
  method: (elapsed, labels)
  for method, elapsed, labels in
  zip(r_times.index, r_times.values, r_clusters.T.values)
}

# Process data
from sklearn.metrics import mutual_info_score

# Load the reference labels once rather than once per evaluated method.
true_labels = pyrichlet.utils.load_penguins()[1]
# Every value of times_and_clusters is an (elapsed seconds, clustering) pair.
info_scores = [
  mutual_info_score(true_labels, clusters)
  for _, clusters in times_and_clusters.values()]
# One row per method, in the insertion order of times_and_clusters.
df_times_scores = pd.DataFrame(
  {"Mutual Information Score": info_scores,
   "Running Time": [elapsed for elapsed, _ in times_and_clusters.values()]},
  index=times_and_clusters.keys())

# Plot
# Two y-axes sharing one x-axis: `ax` carries the scores and `ax2` the running
# times on a logarithmic scale.
fig = plt.figure()
ax = fig.add_subplot(111)
ax2 = ax.twinx()
# position=1 and position=0 with width=0.3 put the two bars of each method
# side by side around its tick instead of on top of each other.
df_times_scores["Mutual Information Score"].plot(
  kind="bar", color="#1f77b4", ax=ax, width=0.3, position=1, legend=True)
df_times_scores["Running Time"].plot(
  kind="bar", color="#ff7f0e", ax=ax2, width=0.3, position=0, legend=True,
  logy=True)
ax.set_ylabel("Score")
ax.legend(loc="upper left")
ax2.set_ylabel("Seconds")
ax.set_ylim(0, 1.19)
ax2.set_ylim(0.01, 150)
# NOTE(review): the x-range is fixed; 17 presumably equals the number of
# methods in df_times_scores -- confirm if methods are added or removed.
plt.xlim(-1, 17)
plt.tight_layout()
plt.savefig("../Figures/all_scores.pdf")
plt.show()


### Variational
def var_get_time_and_clustering(model):
  """Time a variational fit followed by a MAP clustering of the global ``y``.

  Args:
    model: object exposing ``fit_variational`` and ``var_map_cluster``.

  Returns:
    Tuple ``(elapsed, group)``: wall-clock seconds spent on both calls and
    the value returned by ``model.var_map_cluster(y)``.
  """
  tic = time.time()
  model.fit_variational(y, n_groups=3)
  labels = model.var_map_cluster(y)
  elapsed = time.time() - tic
  return elapsed, labels


# Variational benchmark: one (elapsed seconds, clustering) pair per model.
# Keyword arguments are evaluated left to right, so the models are built and
# fitted one after another in the order listed, all sharing the same `rng`.
var_times_and_clusters = dict(
  DDM=var_get_time_and_clustering(
    pyrichlet.mixture_models.DirichletDistributionMixture(n=3, rng=rng)),
  DPM=var_get_time_and_clustering(
    pyrichlet.mixture_models.DirichletProcessMixture(rng=rng)),
  PYM=var_get_time_and_clustering(
    pyrichlet.mixture_models.PitmanYorMixture(pyd=0.1, rng=rng)),
  GPM=var_get_time_and_clustering(
    pyrichlet.mixture_models.GeometricProcessMixture(rng=rng)),
  EWM=var_get_time_and_clustering(
    pyrichlet.mixture_models.EqualWeightedMixture(n=3, rng=rng)),
  FWM=var_get_time_and_clustering(
    pyrichlet.mixture_models.FrequencyWeightedMixture(n=3, rng=rng)),
)

# Reuse the entries already computed for these three methods; the "mclust"
# key is expected to come from the R results merged into times_and_clusters.
for method in ("sklearn", "mixes", "mclust"):
  var_times_and_clusters[method] = times_and_clusters[method]

# Load the reference labels once rather than once per evaluated method.
true_labels = pyrichlet.utils.load_penguins()[1]
# Every value of var_times_and_clusters is an (elapsed seconds, clustering)
# pair.
var_info_scores = [
  mutual_info_score(true_labels, clusters)
  for _, clusters in var_times_and_clusters.values()]
# One row per method, in the insertion order of var_times_and_clusters.
df_var_times_scores = pd.DataFrame(
  {"Mutual Information Score": var_info_scores,
   "Running Time": [
     elapsed for elapsed, _ in var_times_and_clusters.values()]},
  index=var_times_and_clusters.keys())
# Two y-axes sharing one x-axis: `ax` carries the scores and `ax2` the running
# times on a logarithmic scale.
fig = plt.figure()
ax = fig.add_subplot(111)
ax2 = ax.twinx()
# position=1 and position=0 with width=0.3 put the two bars of each method
# side by side around its tick instead of on top of each other.
df_var_times_scores["Mutual Information Score"].plot(
  kind="bar", color="#1f77b4", ax=ax, width=0.3, position=1, legend=True)
df_var_times_scores["Running Time"].plot(
  kind="bar", color="#ff7f0e", ax=ax2, width=0.3, position=0, legend=True,
  logy=True)
ax.set_ylabel("Score")
ax.legend(loc="upper left")
ax2.set_ylabel("Seconds")
ax.set_ylim(0, 1.19)
ax2.set_ylim(0.01, 15)
# Fixed x-range: 9 matches the six pyrichlet models plus the three entries
# copied from times_and_clusters.
plt.xlim(-1, 9)
plt.tight_layout()
plt.savefig("../Figures/all_var_scores.pdf")
plt.show()
