Intro¶
A Principled Bayesian Workflow¶
Among his other work, Michael Betancourt has a fantastically clear exposition on how a Bayesian might approach a modeling problem -- "Towards a Principled Bayesian Workflow". Even if you don't use Bayesian inference, this article offers a useful and thoughtful framework for thinking about whether the model you're using to tackle a problem is good enough. It goes beyond the typical "Data Science 101" recipe of creating a training/test/validation split, from sklearn import ..., cross-validation, etc. While those are essential, the framework instead rests on four foundational questions to be answered in a modeling exercise:
Question One... Is our model consistent with our domain expertise?
Models built upon assumptions that conflict with our domain expertise will give rise to inferences and predictions that also conflict with our domain expertise. While we don't need our model to capture every last detail of our knowledge, at the very least our models should not be in outright conflict with that knowledge.
Question Two... Will our computational tools be sufficient to accurately fit our posteriors?
Question Three... Will our inferences provide enough information to answer our questions?
Question Four... Is our model rich enough to capture the relevant structure of the true data generating process?
Question two aside, the answers to these questions are relevant whether you are using a Bayesian approach in Stan, using the workhorse that is XGBoost, fitting a billion parameter NLP model with PyTorch, etc. Furthermore, these questions are more than academic exercises -- answering these questions can help us catch failure modes as well as quantify the business impact of our models.
What does this have to do with RNA-Seq?¶
RNA-Seq and other assays based on high-throughput sequencing (HTS) are powerful tools for interrogating biological systems: from characterizing gene expression to probing the 3-dimensional structure of the genome. While powerful, studies that leverage HTS data often suffer the following fates:
HTS data will often reflect the experimental conditions (e.g. library preparation) or artifacts (e.g. GC-bias) as much as the biological condition (e.g. treatment versus control).
While HTS data gives us data on many loci, we often have far fewer samples -- reducing our ability to identify differences between biological conditions.
In both cases, it is paramount to develop an expressive model that captures the aspects of the data you, as the modeler, think are important, and to understand how well it recapitulates those features. This is where the Bayesian workflow described above comes in. We can adapt it for the purposes of analyzing RNA-seq data (let's limit ourselves to differential gene expression analysis):
Is our model consistent with known features of RNA-seq data (e.g. variation in read depth across genes) as well as the design of the experiment?
(Verbatim) Will our computational tools be sufficient to accurately fit our posteriors?
Will our inferences provide enough information to identify genes (pathways?) differentially expressed between conditions?
import warnings
from itertools import chain

import nest_asyncio
import numpy as np
import pandas as pd
import seaborn as sns
import stan
from matplotlib import pyplot as plt
from scipy.stats import pearsonr

# PyStan 3 runs its own asyncio event loop, which clashes with Jupyter's;
# nest_asyncio patches the running loop so stan.build() / .sample() work here.
nest_asyncio.apply()
del nest_asyncio

warnings.filterwarnings('ignore')
warnings.simplefilter(action='ignore', category=FutureWarning)
RNASEQ_DATA_URL = "https://raw.githubusercontent.com/ucdavis-bioinformatics-training/2018-June-RNA-Seq-Workshop/master/thursday/all_counts.txt"

# Genes are rows in the raw file; transpose so rows are samples and columns are genes.
rnaseq_df = (
    pd.read_csv(RNASEQ_DATA_URL, sep="\t")
    .transpose()
    .astype("int64")
    .reset_index()
    .rename(columns={"index": "sample_id"})
)

# Sample IDs encode cultivar, time point, and replicate; e.g. "C61" is
# cultivar "C", time 6, replicate 1. Non-"C" cultivars use two characters.
rnaseq_df["cultivar"] = [s[0:1] if s.startswith("C") else s[0:2] for s in rnaseq_df["sample_id"]]
rnaseq_df["replicate"] = rnaseq_df["sample_id"].str[-1]
rnaseq_df["time"] = [s[1:-1] if s.startswith("C") else s[2:-1] for s in rnaseq_df["sample_id"]]

metadata_cols = ["sample_id", "cultivar", "replicate", "time"]
gene_cols = [col for col in rnaseq_df if col not in metadata_cols]

print(rnaseq_df.shape)
rnaseq_df.iloc[0:5][metadata_cols + gene_cols[0:5]]
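As a quick sanity check on this parsing (a minimal sketch; "C61" is a real sample ID used in the comparisons below, and should decompose into cultivar "C", time "6", replicate "1"):
# Spot-check the sample_id parsing on a known ID.
assert rnaseq_df["sample_id"].is_unique
rnaseq_df.loc[rnaseq_df["sample_id"] == "C61", metadata_cols]
# Expect: cultivar == "C", time == "6", replicate == "1"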
Probing with our Domain Expertise¶
What are some things we think should be true? Do they hold?
Read depth will vary across samples and on a per-gene basis¶
- Due to library prep differences, some samples will simply get higher coverage
- Due to differences in gene length, GC bias, mapping artifacts, and sheer sampling, a large chunk of genes will get 0 reads while a few genes will have very high read depth
plt.figure(figsize=(10, 12))
plt.subplot(2, 1, 1)
plt.hist(rnaseq_df[gene_cols].sum(axis=1), density=True)
plt.grid(linestyle="--", color="lightgray", alpha=0.5)
plt.title("Total read depth per sample")
plt.subplot(2, 1, 2)
plt.hist(np.log10(1 + rnaseq_df[gene_cols]).mean(axis=0), bins=35, density=True)
plt.grid(linestyle="--", color="lightgray", alpha=0.5)
plt.title("Average read depth per gne")
The shape of gene-level read depth variation should be the same across samples¶
- Even if the total magnitude varies from sample to sample, we shouldn't expect a sample-specific impact on the shape (e.g. quantiles or modes) of the distribution of read depth across genes
fig, axes = plt.subplots(3, 2, figsize=(10, 12), sharey=True, sharex=True)
fig.suptitle("Read depth distribution across samples")
# Sample without replacement so we don't plot the same sample twice.
samples_to_plot = np.random.choice(rnaseq_df["sample_id"], size=6, replace=False)
for idx, ax in enumerate(chain(*axes)):
    sample = samples_to_plot[idx]
    log_read_depth = np.log10(1 + rnaseq_df.query("sample_id == @sample")[gene_cols].values.flatten())
    ax.hist(log_read_depth, density=False, bins=20)
    # Mark the mean (solid) and the first/third quartiles (dashed).
    ax.axvline(log_read_depth.mean(), linewidth=2, color="red")
    for quartile in np.quantile(log_read_depth, [0.25, 0.75]):
        ax.axvline(quartile, linewidth=2, color="red", linestyle="--")
    ax.grid(alpha=0.5, linestyle="-.")
    ax.set_title(sample)
    ax.set_ylabel("Counts")
    ax.set_xlabel("# of reads / gene")
plt.tight_layout()
Technical replicates should look more similar to each other than other samples¶
plt.figure(figsize=(15, 6))
plt.subplot(1, 2, 1)
plt.scatter(
    x=np.log10(1 + rnaseq_df[rnaseq_df["sample_id"] == "C61"][gene_cols].values),
    y=np.log10(1 + rnaseq_df[rnaseq_df["sample_id"] == "C62"][gene_cols].values),
    alpha=0.2
)
plt.grid(linestyle="--", color="lightgray")
plt.plot(np.linspace(0, 5, 25), np.linspace(0, 5, 25), color="yellow", linewidth=2)
plt.xlabel("Read depth (cultivar C @ T=6, rep 1)")
plt.ylabel("Read depth (cultivar C @ T=6, rep 2)")
plt.subplot(1, 2, 2)
plt.scatter(
    x=np.log10(1 + rnaseq_df[rnaseq_df["sample_id"] == "C61"][gene_cols].values),
    y=np.log10(1 + rnaseq_df[rnaseq_df["sample_id"] == "C91"][gene_cols].values),
    alpha=0.2
)
plt.grid(linestyle="--", color="lightgray")
plt.plot(np.linspace(0, 5, 25), np.linspace(0, 5, 25), color="yellow", linewidth=2)
plt.xlabel("Read depth (cultivar C @ T=6, rep 1)")
plt.ylabel("Read depth (cultivar C @ T=9, rep 1)")
plt.suptitle("Read depth comparison of samples");
Mean versus variance of read depth will be non-linear¶
plt.figure(figsize=(10, 6))
plt.scatter(
    x=np.log10(1 + rnaseq_df[gene_cols]).mean(axis=0),
    y=np.log10(1 + rnaseq_df[gene_cols]).std(axis=0),
    alpha=0.2
)
plt.xlabel("Mean of log10(1 + read depth) per gene")
plt.ylabel("Std dev of log10(1 + read depth) per gene")
plt.grid(linestyle="--", color="lightgray", alpha=0.5)
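One quick way to quantify this non-linearity (a sketch, not part of the original analysis): under a Poisson model each gene's variance equals its mean, so the per-gene variance-to-mean ratio of the raw counts should sit near 1; values well above 1 indicate overdispersion.
# Per-gene variance-to-mean ratio of raw counts; ~1 if counts were Poisson.
gene_means = rnaseq_df[gene_cols].mean(axis=0)
gene_vars = rnaseq_df[gene_cols].var(axis=0)
dispersion = gene_vars[gene_means > 0] / gene_means[gene_means > 0]
print(dispersion.quantile([0.25, 0.5, 0.75]))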
A Basic Model¶
Here we introduce a basic model for the read counts observed at each gene for each sample. We assume that count $c_{ij}$ for sample $i$ at gene $j$ is drawn from a Poisson:
$$c_{ij} \sim \text{Pois}(\lambda_{ij})$$
$$\lambda_{ij} = \exp(\mu^{samp}_i + \mu^{gene}_j)$$
$$\mu^{samp}_i \sim \text{Normal}(0, 5)$$
$$\mu^{gene}_j \sim \text{Normal}(0, 5)$$
This model explicitly accounts for some of the features of the data generating process that we interrogated above:
- Sample-specific variation in read depth
- Gene-specific variation in read depth
It is also worth noting the motivation for the Poisson: it describes the number of successes when the number of shots on goal is large but the probability of success on each shot is low. This is analogous to RNA-Seq: lots of reads, but a small chance that any given read aligns to a given gene.
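As a small illustrative sketch of that limit (the numbers below are made up, not taken from the dataset), a Binomial with many trials and a small success probability is well approximated by a Poisson with the same mean:
rng = np.random.default_rng(0)
# Hypothetical values: reads per sample and per-read alignment probability.
n_reads, p_align = 1_000_000, 2e-5
binom_draws = rng.binomial(n=n_reads, p=p_align, size=100_000)
pois_draws = rng.poisson(lam=n_reads * p_align, size=100_000)
# Both should have mean and variance near n_reads * p_align = 20.
print(binom_draws.mean(), binom_draws.var())
print(pois_draws.mean(), pois_draws.var())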
basic_poisson_model_code = """
data {
  int<lower=0> num_genes;                    // number of genes
  int<lower=0> num_samps;                    // number of samples
  int<lower=0> counts[num_samps, num_genes]; // read counts
}
parameters {
  vector[num_genes] mu_gene;
  vector[num_samps] mu_samp;
}
model {
  // Match the Normal(0, 5) priors stated above (and used in the prior
  // predictive simulations below).
  mu_gene ~ normal(0, 5);
  mu_samp ~ normal(0, 5);
  for(s in 1:num_samps) {
    counts[s] ~ poisson(exp(mu_samp[s] + mu_gene));
  }
}
"""
SUBSAMP = True
SUBSAMPSIZE = 100

# Subsample genes without replacement (replace=False avoids duplicate genes)
# to keep the Stan fits fast.
cols = np.random.choice(gene_cols, SUBSAMPSIZE, replace=False) if SUBSAMP else gene_cols
stan_data = {
    "num_genes": len(cols),
    "num_samps": rnaseq_df.shape[0],
    "counts": rnaseq_df[cols].values
}
Prior Predictive Check¶
Instead of going straight to fitting the model, we start with step one: checking whether our model is consistent with our domain expertise, i.e. can it even generate data that captures the aspects of the data we think are important?
To do this we will start with a prior predictive check. The idea is simple:
- Define a summary statistic that describes some germane aspect of the data
- Draw parameters from the prior of your model
- Simulate data from the parameters
- Calculate the chosen summary statistic for each simulated realization of the data
- Compare the distribution of the simulated summary statistic to the observed data
In the final step, the observed statistic should fall within the range of the simulated ones. It does not need to match closely, since the prior did not take the data into account, but it should at least be possible for the model to generate data similar to what we observed.
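In code, the recipe amounts to something like the following generic sketch, where simulate and stat are placeholders (the concrete version for our Poisson model follows):
def prior_predictive_check(simulate, stat, num_draws=1000):
    # simulate() draws parameters from the prior and returns one synthetic
    # dataset; stat() reduces a dataset to a summary statistic.
    return np.array([stat(simulate()) for _ in range(num_draws)])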
def sample_basic_poisson_prior_pred(stat, num_genes, num_samps, num_draws=1000, **kw_args):
    # Draw sample- and gene-level effects from their Normal(0, 5) priors.
    mu_samp_prior = np.random.normal(loc=0, scale=5, size=(num_draws, num_samps))
    mu_gene_prior = np.random.normal(loc=0, scale=5, size=(num_draws, num_genes))
    prior_pred_check = []
    for mu_samps, mu_genes in zip(mu_samp_prior, mu_gene_prior):
        # Simulate one full counts matrix per prior draw:
        # lam[i, j] = exp(mu_samp[i] + mu_gene[j]).
        counts_prior_pred = np.random.poisson(lam=np.exp(mu_samps[:, None] + mu_genes[None, :]))
        prior_pred_check.append(stat(counts_prior_pred, **kw_args))
    return np.array(prior_pred_check)
A simple (though not very informative) check: the median of the generated data¶
Looks good -- a low bar, though! The observed data falls within the range of the prior predictive simulations.
median_prior_pred_check = sample_basic_poisson_prior_pred(np.median, stan_data["num_genes"], stan_data["num_samps"], axis=1)
plt.figure(figsize=(10, 6))
plt.hist(np.log10(1.0 + median_prior_pred_check.flatten()), density=True, label="prior predictive check", bins=35)
plt.hist(np.log10(1.0 + np.median(rnaseq_df[cols], axis=1)), density=True, alpha=0.5, label="observed data")
plt.grid(linestyle="--", color="lightgray", alpha=0.5)
plt.legend();
Can we generate as many genes with 0 reads?¶
Again, looks good! The observed data falls within the range of the prior predictive simulations.
# For each gene, count the number of samples with zero reads.
calc_num_zero_genes = lambda gs: (gs == 0).sum(axis=0)
num_zero_genes_prior_pred_check = sample_basic_poisson_prior_pred(calc_num_zero_genes, stan_data["num_genes"], stan_data["num_samps"])
plt.figure(figsize=(10, 6))
plt.hist(num_zero_genes_prior_pred_check.flatten(), bins=30, density=True, label="prior predictive check")
plt.hist(calc_num_zero_genes(rnaseq_df[cols].values), density=True, alpha=0.5, label="observed data")
plt.grid(linestyle="--", color="lightgray", alpha=0.5)
plt.legend();
A more complicated check: the difference in pairwise correlation between replicates of the same cultivar versus pairs from different cultivars¶
Failure! The model has a hard time generating data in which replicates from the same cultivar are more correlated with each other than pairs from different cultivars -- as is the case in the observed data. This should not be surprising: we did not encode this structure in the model!
def _var_ttest_denom(v1, n1, v2, n2):
    # Standard error of the difference in means under unequal variances
    # (the denominator of Welch's t-statistic).
    return np.sqrt(v1 / n1 + v2 / n2)
def _calc_t_stat(a, b):
    """Welch's (unequal-variance) t-statistic for the difference in means."""
    denom = _var_ttest_denom(np.var(a, ddof=1), len(a), np.var(b, ddof=1), len(b))
    return (np.mean(a) - np.mean(b)) / denom
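These two helpers mirror Welch's unequal-variance t-statistic, so as a sanity check (on arbitrary, made-up inputs) they should agree with scipy.stats.ttest_ind with equal_var=False:
from scipy.stats import ttest_ind

# _calc_t_stat should match scipy's Welch t-statistic.
a = np.random.normal(0.0, 1.0, size=50)
b = np.random.normal(0.5, 2.0, size=40)
assert np.isclose(_calc_t_stat(a, b), ttest_ind(a, b, equal_var=False).statistic)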
def calc_pw_corr_diff(counts):
    # Pairwise (sample x sample) correlations of read counts.
    pw_corr = np.corrcoef(counts)
    indices_1, indices_2 = np.tril_indices(n=pw_corr.shape[0], k=-1)
    pw_corr_flat = pw_corr[indices_1, indices_2]
    # Flag sample pairs that come from the same cultivar.
    pw_reps = np.array([
        1 if rnaseq_df["cultivar"][idx1] == rnaseq_df["cultivar"][idx2] else 0
        for idx1, idx2 in zip(indices_1, indices_2)
    ])
    # t-statistic of different-cultivar minus same-cultivar correlations;
    # negative values mean same-cultivar pairs are more correlated.
    return _calc_t_stat(pw_corr_flat[pw_reps == 0], pw_corr_flat[pw_reps == 1])
pw_corr_diff_prior_pred_check = sample_basic_poisson_prior_pred(calc_pw_corr_diff, stan_data["num_genes"], stan_data["num_samps"])
plt.figure(figsize=(10, 6))
plt.hist(pw_corr_diff_prior_pred_check.flatten(), bins=30, density=True, label="prior predictive check")
plt.hist(calc_pw_corr_diff(rnaseq_df[cols].values), density=True, alpha=0.5, label="observed data")
plt.grid(linestyle="--", color="lightgray", alpha=0.5)
plt.legend();
Fitting the Model¶
With the prior predictive checks done, we compile and fit the model, then repeat the same checks using draws from the posterior -- a posterior predictive check.
basic_poisson_model = stan.build(basic_poisson_model_code, data=stan_data)
basic_poisson_model_fit = basic_poisson_model.sample(num_chains=1)
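Before reading too much into the fit, we can revisit Question Two: did our computational tools actually fit the posterior? A minimal sketch of standard MCMC diagnostics, assuming the arviz package is available (it is not imported above):
import arviz as az

# r_hat near 1 and healthy effective sample sizes suggest the sampler mixed
# well; with num_chains=1, r_hat is less informative, so running more chains
# would strengthen this check.
idata = az.from_pystan(posterior=basic_poisson_model_fit)
print(az.summary(idata, var_names=["mu_samp"]))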
def sample_basic_poisson_post_pred(stat, params, num_genes, num_samps, num_draws=1000, **kw_args):
    post_pred_check = []
    for mu_samps, mu_genes in zip(params["mu_samp"], params["mu_gene"]):
        # Simulate one full counts matrix per posterior draw:
        # lam[i, j] = exp(mu_samp[i] + mu_gene[j]).
        counts_post_pred = np.random.poisson(lam=np.exp(mu_samps[:, None] + mu_genes[None, :]))
        post_pred_check.append(stat(counts_post_pred, **kw_args))
    return np.array(post_pred_check)
# PyStan returns draws with shape (param_dim, num_draws); transpose so each
# row is a single posterior draw.
basic_poisson_post_params = {param: basic_poisson_model_fit[param].T for param in ("mu_samp", "mu_gene")}
median_post_pred_check = sample_basic_poisson_post_pred(np.median, basic_poisson_post_params, stan_data["num_genes"], stan_data["num_samps"], axis=1)
plt.figure(figsize=(10, 6))
plt.hist(np.log10(1.0 + median_post_pred_check.flatten()), density=True, label="posterior predictive check", bins=35)
plt.hist(np.log10(1.0 + np.median(rnaseq_df[cols], axis=1)), density=True, alpha=0.5, label="observed data")
plt.grid(linestyle="--", color="lightgray", alpha=0.5)
plt.legend();
num_zero_genes_post_pred_check = sample_basic_poisson_post_pred(calc_num_zero_genes, basic_poisson_post_params, stan_data["num_genes"], stan_data["num_samps"])
plt.figure(figsize=(10, 6))
plt.hist(num_zero_genes_post_pred_check.flatten(), bins=30, density=True, label="posterior predictive check")
plt.hist(calc_num_zero_genes(rnaseq_df[cols].values), density=True, alpha=0.5, label="observed data")
plt.grid(linestyle="--", color="lightgray", alpha=0.5)
plt.legend();
pw_corr_diff_post_pred_check = sample_basic_poisson_post_pred(calc_pw_corr_diff, basic_poisson_post_params, stan_data["num_genes"], stan_data["num_samps"])
plt.figure(figsize=(10, 6))
plt.hist(pw_corr_diff_post_pred_check.flatten(), bins=30, density=True, label="posterior predictive check")
plt.hist(calc_pw_corr_diff(rnaseq_df[cols].values), density=True, alpha=0.5, label="observed data")
plt.grid(linestyle="--", color="lightgray", alpha=0.5)
plt.legend();
A Cultivar-Aware Model¶
The basic model has no way to generate data in which replicates of the same cultivar are more correlated than pairs from different cultivars: given the sample and gene effects, the counts are independent. To encode this structure, we extend the model with a cultivar-level hierarchy, drawing each sample effect from a normal distribution with a cultivar-specific mean and standard deviation (where $c(i)$ is the cultivar of sample $i$):
$$\mu^{samp}_i \sim \text{Normal}(\mu^{cult}_{c(i)}, \sigma^{cult}_{c(i)})$$
cultivar_poisson_model_code = """
data {
  int<lower=0> num_genes;                                 // number of genes
  int<lower=0> num_samps;                                 // number of samples
  int<lower=0> num_cultivars;                             // number of cultivars
  int<lower=1, upper=num_cultivars> cultivars[num_samps]; // cultivar index per sample (1-based)
  int<lower=0> counts[num_samps, num_genes];              // read counts
}
parameters {
  vector[num_genes] mu_gene;
  vector[num_cultivars] mu_cultivar;
  // sigma_cultivar gets Stan's default improper uniform prior on [0, inf)
  vector<lower=0>[num_cultivars] sigma_cultivar;
  vector[num_samps] mu_samp;
}
model {
  mu_cultivar ~ normal(0, 10);
  mu_gene ~ normal(0, 10);
  for(s in 1:num_samps) {
    mu_samp[s] ~ normal(mu_cultivar[cultivars[s]], sigma_cultivar[cultivars[s]]);
    counts[s] ~ poisson(exp(mu_samp[s] + mu_gene));
  }
}
"""
unique_cultivars = rnaseq_df["cultivar"].unique()
stan_data["num_cultivars"] = unique_cultivars.shape[0]
# Map each sample's cultivar to a 1-based index, matching Stan's indexing.
stan_data["cultivars"] = np.array([np.where(unique_cultivars == cultivar)[0][0] + 1 for cultivar in rnaseq_df["cultivar"]])
cultivar_poisson_model = stan.build(cultivar_poisson_model_code, data=stan_data)
cultivar_poisson_model_fit = cultivar_poisson_model.sample(num_chains=1)
cultivar_poisson_post_params = {param: cultivar_poisson_model_fit[param].T for param in ("mu_samp", "mu_gene", "mu_cultivar", "sigma_cultivar")}
plt.hist(cultivar_poisson_post_params["sigma_cultivar"])
plt.title("Posterior draws of sigma_cultivar");
plt.figure(figsize=(10, 6))
# Compare each sample's effect (on the count scale) under the basic model
# versus the cultivar-aware model; the red line is the identity.
plt.scatter(
    np.exp(basic_poisson_post_params["mu_samp"]).mean(axis=0),
    np.exp(cultivar_poisson_post_params["mu_samp"]).mean(axis=0)
)
plt.plot(np.linspace(0.5, 2.75, 20), np.linspace(0.5, 2.75, 20), color="red")
plt.xlabel("Posterior mean of exp(mu_samp), basic model")
plt.ylabel("Posterior mean of exp(mu_samp), cultivar model")
plt.grid(linestyle="--", color="lightgray", alpha=0.5)
# The posterior predictive simulation only needs mu_samp and mu_gene; the
# cultivar-level parameters enter only through mu_samp.
pw_corr_diff_post_pred_check = sample_basic_poisson_post_pred(calc_pw_corr_diff, cultivar_poisson_post_params, stan_data["num_genes"], stan_data["num_samps"])
plt.figure(figsize=(10, 6))
plt.hist(pw_corr_diff_post_pred_check.flatten(), bins=30, density=True, label="posterior predictive check")
plt.hist(calc_pw_corr_diff(rnaseq_df[cols].values), density=True, alpha=0.5, label="observed data")
plt.grid(linestyle="--", color="lightgray", alpha=0.5)
plt.legend();