"""Compare posterior marginals of a logistic-regression model across methods.

Overlays the densities of the first three parameters (theta0..theta2) obtained
from Ajax full-rank variational inference, a Laplace approximation, BlackJAX
MCMC and Stan HMC. All results are loaded from pickled files under
./results_data, and the comparison is written as an interactive Plotly figure.
"""
import pickle

import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt  # NOTE(review): appears unused in this script
import pandas as pd
import plotly
import plotly.express as px
import tensorflow_probability.substrates.jax as tfp
from scipy.stats import gaussian_kde

tfd = tfp.distributions

# Set up Plotly's offline notebook mode so fig.show() renders inline in Jupyter.
plotly.offline.init_notebook_mode()
# Common evaluation grid on which every method's density is computed.
x = jnp.linspace(-15, 25, 10000)

# --- Ajax full-rank variational inference ----------------------------------
# NOTE(review): pickle.load can execute arbitrary code; only load result
# files this project produced itself.
with open('./results_data/logistic_full_rank_Ajax', 'rb') as f:
    variational = pickle.load(f)
params = variational.get_params()

# Flatten the fitted distribution into its array leaves. The top-level alias
# `jax.tree_leaves` is deprecated/removed in current JAX; use jax.tree_util.
# Assumes the pytree has exactly two leaves, ordered (mean, scale factor).
# TODO: confirm this against the Ajax distribution type.
loc_m, scale = jax.tree_util.tree_leaves(
    variational.transform_dist(params['beta'])
)
# `scale` is reused to hold scale @ scale.T. Presumably the leaf is a Cholesky
# factor, so this is the covariance matrix (verify).
scale = jnp.dot(scale, scale.T)

all_pdf = []
for i in range(3):
    # Gaussian marginal of coordinate i: mean loc_m[i], std sqrt(cov[i, i]).
    y = tfd.Normal(loc=loc_m[i], scale=jnp.sqrt(scale[i][i])).prob(x)
    all_pdf.append(y)
# --- Laplace approximation -------------------------------------------------
# NOTE(review): pickle.load can execute arbitrary code; only load result
# files this project produced itself.
with open('./results_data/logistic_regression_laplace', 'rb') as f:
    laplace = pickle.load(f)

# The loaded object exposes a mean vector (`.loc`) and `.stddev()`.
# Presumably it is a TFP multivariate normal (verify). Each coordinate is
# drawn as a 1-D normal with that mean and standard deviation.
loc_m = laplace.loc
std = laplace.stddev()
for i in range(3):
    marginal = tfd.Normal(loc=loc_m[i], scale=std[i])
    y = marginal.prob(x)
    all_pdf.append(y)
# --- BlackJAX MCMC: kernel density estimate of the posterior draws ---------
# NOTE(review): pickle.load can execute arbitrary code; only load result
# files this project produced itself.
with open('./results_data/MCMC_Blackjax', 'rb') as f:
    black_samples = pickle.load(f)

# Draws are indexed [:, i]. Presumably the shape is (n_draws, n_params); confirm.
theta_draws = black_samples.position['theta']
for i in range(3):
    # Scalar bw_method is used directly as the KDE bandwidth factor.
    # NOTE(review): the Stan KDE in this script uses SciPy's default (Scott)
    # bandwidth, so the two sampler curves are smoothed differently.
    kde_black = gaussian_kde(theta_draws[:, i], bw_method=0.3)
    pdf_black = kde_black(x)
    all_pdf.append(pdf_black)
# --- Stan HMC: kernel density estimate of the posterior draws --------------
# NOTE(review): pickle.load can execute arbitrary code; only load result
# files this project produced itself.
with open('./results_data/HMC_Stan', 'rb') as f:
    stan_hmc = pickle.load(f)

# Draws are indexed [i, :]. Presumably the shape is (n_params, n_draws), the
# transpose of the BlackJAX layout (confirm).
for i in range(3):
    row = stan_hmc[i, :]
    stan_hmc_kde = gaussian_kde(row)  # SciPy default bandwidth (Scott's rule)
    stan_hmc_pdf = stan_hmc_kde(x)
    all_pdf.append(stan_hmc_pdf)
# --- Assemble a long-format table and plot ---------------------------------
# Method names in the order their curves were appended to `all_pdf`. Within
# each method the curves are theta0, theta1, ... in order.
methods = ['Ajax VI', 'Laplace', 'MCMC', 'Stan HMC']
n_points = x.shape[0]
if len(all_pdf) % len(methods) != 0:
    raise ValueError(
        f"expected a multiple of {len(methods)} density curves "
        f"(one group per method), got {len(all_pdf)}"
    )
n_params = len(all_pdf) // len(methods)

# Derive the labels and the x-repetition count from the same source, so they
# cannot drift out of sync with `all_pdf` when methods/parameters change.
curve_labels = [
    f'{method} theta{j}' for method in methods for j in range(n_params)
]
all_label = [label for label in curve_labels for _ in range(n_points)]

# Stack the curves as (n_curves, n_points), then flatten row by row, which
# lines up with the label order above.
all_pdf = jnp.array(all_pdf).reshape((-1))
x_repeated = jnp.tile(x, len(curve_labels))

to_df = {
    "theta": x_repeated,
    "PDF": all_pdf,
    "label": all_label,
}
df = pd.DataFrame(to_df)
fig = px.line(df, x="theta", y="PDF", color="label", title="logistic regression")
fig.show()
fig.write_html("logistic_reg_result_plotly.html")