def _load_pickle(path):
    """Unpickle and return the object stored at ``path``.

    NOTE(security): ``pickle.load`` can execute arbitrary code while loading.
    Only use this on result files produced locally by this project, never on
    untrusted input.
    """
    with open(path, 'rb') as f:
        return pickle.load(f)


def plot_coin_toss_results(varient='', *, results_dir='results_data',
                           data_prefix='../../data/coin_toss/coin_toss',
                           blackjax_burn_in=300, verbose=True):
    """Plot the coin-toss data and compare several posterior estimates.

    First shows a matplotlib histogram of the observed samples. Then it draws one
    plotly line chart with these posterior PDFs over theta in [0.01, 0.99]:

    * the exact Beta posterior ``Beta(alpha_prior + #ones, beta_prior + #zeros)``,
    * a KDE of 10000 draws from the pickled AJAX VI posterior,
    * a KDE of the pickled Stan HMC samples,
    * a KDE of the pickled BlackJAX chain, after dropping the first
      ``blackjax_burn_in`` steps,
    * the pickled Laplace-approximation distribution's ``prob``.

    Args:
        varient: Suffix appended to every data/result file name, so one call
            plots one experiment variant. It is converted with ``str``.
        results_dir: Directory holding the pickled result files.
        data_prefix: Path prefix passed (with ``varient`` appended) to
            ``get_data``.
        blackjax_burn_in: Number of leading BlackJAX chain steps to discard
            before the KDE.
        verbose: If true, print the prior parameters and the counts of ones and
            zeros.

    Raises:
        ValueError: If any estimate does not yield exactly one PDF value per
            grid point.
    """
    varient = str(varient)

    # get_data returns (samples, alpha_prior, beta_prior); samples are presumably
    # 0/1 coin outcomes. TODO confirm against get_data.
    samples, alpha_prior, beta_prior = get_data(data_prefix + varient)
    plt.hist(samples)
    plt.ylabel("frequency")
    plt.title("Given Data")
    plt.show()

    x = jnp.linspace(0.01, 0.99, 100)
    n_points = x.shape[0]
    all_labels = []
    all_pdfs = []

    def add_estimate(label, pdf):
        # Flatten and validate each PDF so that a shape mismatch fails loudly
        # here. Otherwise it would silently misalign rows in the long table below.
        pdf = jnp.asarray(pdf).reshape(-1)
        if pdf.shape[0] != n_points:
            raise ValueError(
                f"{label!r} PDF has {pdf.shape[0]} values; expected {n_points}"
            )
        all_labels.append(label)
        all_pdfs.append(pdf)

    # Conjugate update: Beta prior + Bernoulli counts -> Beta posterior.
    ones = jnp.sum(samples == 1).astype('float32')
    zeros = jnp.sum(samples == 0).astype('float32')
    if verbose:
        print(alpha_prior, beta_prior, ones, zeros)
    true_post_dist = tfd.Beta(alpha_prior + ones, beta_prior + zeros)
    add_estimate("True Posterior", true_post_dist.prob(x))

    # The sampled object is indexed by "theta", so it is a mapping of draws.
    posterior = _load_pickle(f"{results_dir}/coin_toss_VI_Ajax_result{varient}")
    samples_ajax = posterior.sample(seed=jax.random.PRNGKey(3),
                                    sample_shape=(10000,))
    add_estimate("AJAX VI", gaussian_kde(samples_ajax["theta"])(x))

    stan_hmc = _load_pickle(f"{results_dir}/HMC_Stan{varient}")
    add_estimate("Stan HMC", gaussian_kde(stan_hmc)(x))

    black_samples = _load_pickle(f"{results_dir}/MCMC_BlackJAX{varient}")
    # position['x'] is indexed as (step, dim); keep only column 0 after the
    # discarded leading steps.
    black_theta = black_samples.position['x'][blackjax_burn_in:, 0]
    add_estimate("Blackjax rmh estimate", gaussian_kde(black_theta)(x))

    laplace_dist = _load_pickle(f"{results_dir}/coin_toss_laplace_result{varient}")
    add_estimate("Laplace", laplace_dist.prob(x))

    # Long format: one row per (label, theta) pair. It is label-major to match
    # the order of the concatenated PDFs.
    df = pd.DataFrame({
        "theta": jnp.tile(x, len(all_labels)),
        "PDF": jnp.concatenate(all_pdfs),
        "label": [label for label in all_labels for _ in range(n_points)],
    })
    fig = px.line(df, x="theta", y="PDF", color="label",
                  title=f"Coin toss posterior prior=({alpha_prior},{beta_prior})")
    fig.show()