# Intro to Probabilistic Programming with Stan

In this tutorial we’re going to talk about what probabilistic programming is and how we can use it for statistical modeling. If you aren’t familiar at all with Bayesian stats, check out my previous post on the topic. If you’re used to probabilistic programming but just want to learn the Stan language, you can go straight to the fantastic Stan User’s Guide, which explains how to program a wide variety of models.

- What is probabilistic programming?
- Probabilistic programming with Stan
- Assessing model convergence
- Assessing model fit
- Linear Regression in Stan
- Summary

# What is probabilistic programming?

Probabilistic programming is a relatively new and exciting approach to statistical modeling that lets you create models in a standardized language without having to implement any of the nitty-gritty details or work out too much math. Although not all probabilistic programs are Bayesian, probabilistic programming makes Bayesian modeling easy, and so it’s a great way to learn what Bayesian models are, how they’re fit to data, and what you can do with them. To explain what probabilistic programming is, I’m going to use just a little bit of math. Bear with me, because this is important!

In Bayesian statistics, we start with a model and some data. As a simple example, we might model some ratings on a scale using a normal distribution with a particular mean *μ* and variance *σ*². Our goal is to identify the most likely parameter values given our data (that is, the values of *μ* and *σ* that best explain our data). To determine which parameter values are best, we make use of Bayes’ formula:

*P*(*θ*|𝒟) ∝ *P*(*θ*)*P*(𝒟|*θ*)

This formula says that the probability of a parameter value *θ* given our data 𝒟 is proportional to our prior probability of that parameter value multiplied by the likelihood that the data could have been generated from that parameter value. How do we determine the likelihood? Well, sometimes we can derive the likelihood (and hence the posterior) by hand. But in most cases, this approach is too difficult or time-consuming. In probabilistic programming, we write a program that simulates our model given some parameter values. This is actually useful in its own right: we can use this program to see how the model behaves under different settings of the parameters. But in statistical inference, the important part is that we run that program to (approximately) calculate the likelihood, which in turn lets us calculate the posterior probability of the parameter values given our data.
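As a concrete illustration of Bayes’ formula, the following base R sketch evaluates the prior and the likelihood over a grid of candidate means and multiplies them. The ratings, the fixed `sigma`, and the prior are made-up values chosen for illustration, and this kind of grid approximation is only practical for very small models.

```r
# Toy data: five ratings (made up for illustration)
ratings <- c(3.2, 4.1, 3.8, 4.5, 3.9)
sigma <- 0.5   # assume the standard deviation is known

# Grid of candidate values for the mean, theta
theta <- seq(1, 7, by = 0.01)

# Prior P(theta): a normal distribution centered on 4 (an assumption)
prior <- dnorm(theta, mean = 4, sd = 1)

# Likelihood P(D | theta): product of the normal densities of each rating
likelihood <- sapply(theta, function(m) prod(dnorm(ratings, mean = m, sd = sigma)))

# Bayes' formula: the posterior is proportional to prior * likelihood
unnormalized <- prior * likelihood
posterior <- unnormalized / sum(unnormalized)   # normalize so the grid sums to 1

# Summaries of the approximate posterior
post_mean <- sum(theta * posterior)
post_mode <- theta[which.max(posterior)]
c(post_mean = post_mean, post_mode = post_mode)
```

With only five data points the posterior mean sits between the prior mean (4) and the sample mean (3.9), much closer to the latter because the data are far more informative than this prior.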

## Why Stan?

There are a good number of probabilistic programming languages out there. Today we’re going to focus on Stan, which is one of the fastest, most reliable, and most widely used of them. One of the cool things about Stan is that there are a number of different interfaces to it: you can use Stan through R, through Python, through Matlab, through Julia, and even directly through the command line! If you’ve read my tutorial on Bayesian regression with brms, then you’ve actually already used one of the easiest interfaces to Stan, which writes Stan programs for you based on `lmer`-like formulas. Lastly, Stan has one of the largest communities, which makes getting coding help and statistical advice easy.

## The components of a Stan program

Unsurprisingly, Stan programs are written in Stan files, which use the extension `.stan`. The Stan language has similar syntax to `C++`, in that it uses curly braces (`{` and `}`) to define blocks of code, semicolons (`;`) after each statement, and has a type declaration for every variable in the program. The two primitive data types we’ll use are `int` for integers and `real` for floating-point/decimal numbers. There are also a few different container types: `vector` and `row_vector` for one-dimensional containers of reals, `matrix` for two-dimensional containers of reals, and `array` for containers of any type with any number of dimensions. For now, the differences between `array`, `vector`, and `row_vector` aren’t that important. Just know that when possible, we will try to use type `vector`, which will generally be most efficient.

Stan programs consist of up to seven different blocks of code, in the following order (* marks the blocks that nearly every model uses; strictly speaking, Stan lets you omit any of them):

- `functions`
- `data`*
- `transformed data`
- `parameters`*
- `transformed parameters`
- `model`*
- `generated quantities`

In the remainder of the workshop, we’re going to focus on the `data`, `parameters`, `model`, and `generated quantities` blocks, but we’ll also use the `transformed parameters` block.

# Probabilistic programming with Stan

To demonstrate the power of Stan, let’s first get a nice dataset to work with. Here I’m going to load some packages, and then run some code to gather data from the Spotify top 200 songs per week in 2021. Don’t worry about how this code actually works (we can save that for a future meeting…), but know that it will take some time (~5mins) if you run this on your computer.

## Getting some data

```r
library(cmdstanr)   # for stan
library(tidyverse)  # for data wrangling
library(lubridate)  # for dates
library(rvest)      # for scraping spotify charts
library(spotifyr)   # for spotify audio features
library(tidybayes)  # for accessing model posteriors
library(viridis)    # for pretty plots
options(mc.cores=parallel::detectCores())

## gather spotify chart data (modified from https://rpubs.com/argdata/web_scraping)
scrape_spotify <- function(url) {
  page <- url %>% read_html() # read the HTML page

  rank <- page %>%
    html_elements('.chart-table-position') %>%
    html_text() %>%
    as.integer

  track <- page %>%
    html_elements('strong') %>%
    html_text()

  artist <- page %>%
    html_elements('.chart-table-track span') %>%
    html_text() %>%
    str_remove('by ')

  streams <- page %>%
    html_elements('td.chart-table-streams') %>%
    html_text() %>%
    str_remove_all(',') %>%
    as.integer

  URI <- page %>%
    html_elements('a') %>%
    html_attr('href') %>%
    str_subset('https://open.spotify.com/track/') %>%
    str_remove('https://open.spotify.com/track/')

  ## combine, name, and make it a tibble
  chart <- tibble(rank=rank, track=track, artist=artist, streams=streams, URI=URI)
  return(chart)
}

## setup access to Spotify API
access_token <- get_spotify_access_token()

## load the top 200 songs in the US per week in 2021
spotify2021 <- tibble(week=seq(ymd('2021-01-01'), ymd('2021-11-19'), by = 'weeks')) %>%
  mutate(url=paste0('https://spotifycharts.com/regional/us/weekly/', week, '--', week+days(7)),
         data=map(url, scrape_spotify)) %>%
  unnest(data) %>%
  mutate(streams=streams/1000000)  # convert to millions of streams

## extract spotify's audio features for each song
features <- tibble(URI=unique(spotify2021$URI)) %>%
  mutate(features=map(URI, ~ get_track_audio_features(.x, authorization=access_token))) %>%
  unnest(features)

## make one tidy data frame
spotify2021 <- spotify2021 %>% left_join(features, by='URI') %>%
  select(-URI, -analysis_url, -track_href, -id, -type) %>%
  relocate(week, rank, track, artist, streams, duration_ms, tempo,
           time_signature, key, mode, valence, loudness, danceability,
           energy, speechiness, acousticness, instrumentalness, liveness, uri, url)

write_csv(spotify2021, '2021-12-10-spotify-data.csv')
spotify2021
```

```
## # A tibble: 9,400 × 20
##    week        rank track  artist streams duration_ms tempo time_signature   key
##    <date>     <dbl> <chr>  <chr>    <dbl>       <dbl> <dbl>          <dbl> <dbl>
##  1 2021-01-01     1 Good … SZA       6.32      279204 121.               4     1
##  2 2021-01-01     2 Anyone Justi…    6.15      190779 116.               4     2
##  3 2021-01-01     3 34+35  Arian…    5.61      173711 110.               4     0
##  4 2021-01-01     4 Mood … 24kGo…    5.58      140526  91.0              4     7
##  5 2021-01-01     5 Lemon… Inter…    5.37      195429 140.               4     1
##  6 2021-01-01     6 DÁKITI Bad B…    5.16      205090 110.               4     4
##  7 2021-01-01     7 posit… Arian…    5.10      172325 144.               4     0
##  8 2021-01-01     8 Whoop… CJ        4.88      123263 140.               4     3
##  9 2021-01-01     9 WITHO… The K…    4.78      161385  93.0              4     0
## 10 2021-01-01    10 Blind… The W…    4.44      200040 171.               4     1
## # … with 9,390 more rows, and 11 more variables: mode <dbl>, valence <dbl>,
## #   loudness <dbl>, danceability <dbl>, energy <dbl>, speechiness <dbl>,
## #   acousticness <dbl>, instrumentalness <dbl>, liveness <dbl>, uri <chr>,
## #   url <chr>
```

As we can see, we now have a dataframe of Spotify’s weekly top 200 tracks, along with the following information:

- `week`: the week in 2021
- `rank`: the song’s rank (`1` to `200`) in this week, with `1` being the top song
- `track`: the name of the song
- `artist`: the name of the artist who released the song
- `streams`: the number of streams in that week (in millions)
- `duration_ms`: the duration of the track in ms
- `tempo`: the tempo of the track in beats per minute
- `time_signature`: an estimated time signature ranging from `3` to `7` (for 3/4 to 7/4)
- `key`: the key of the song from `0` (for C) to `11` (for B), or `-1` if no key was found
- `mode`: whether the track is in a major (`1`) or minor (`0`) key
- `valence`: the emotional valence of the track from `0` (negative valence/sad) to `1` (positive valence/happy)
- `loudness`: the average loudness of the track in decibels
- `danceability`: an estimate of how danceable the track is, from `0` (least danceable) to `1` (most danceable)
- `energy`: an estimate of the intensity or activity of the track, from `0` (low energy) to `1` (high energy)
- `speechiness`: an estimate of the proportion of speech in the track, from `0` (no speech) to `1` (only speech)
- `acousticness`: an estimate of the degree to which a track is (`1`) or is not (`0`) acoustic
- `instrumentalness`: an estimate of the degree to which a track contains no vocals (`1`) or contains vocals (`0`)
- `liveness`: an estimate of whether the track was performed live (`1`) or not (`0`)
- `uri`: the Spotify unique identifier for the track
- `url`: a link to the track

## Simulating fake data: number of streams

Let’s say we want to know how many times, on average, the top 200 tracks are streamed every week. Of course, we could just use `mean(spotify2021$streams)` to get this number, but to get more information we will need to specify a model. As a start, we can assume a normal distribution with mean *μ* and standard deviation *σ*. Before fitting this model, we might just want to know what data simulated from this model looks like under different parameter values. This is the main goal of simulation: we assume that we know what the values of *μ* and *σ* are to check what the distribution of streams would look like if those values were true. To do that, let’s write a Stan program, which I’ll save in the file `2021-12-10-streams-sim.stan`:

```stan
data {
  real<lower=0> mu;     // the mean
  real<lower=0> sigma;  // the standard deviation
}

parameters {
}

model {
}

generated quantities {
  // simulate data using a normal distribution
  real y_hat = normal_rng(mu, sigma);
}
```

Since we’re simulating with known parameter values, we take `mu` and `sigma` as inputs to Stan by declaring them in the `data` block. The code `real<lower=0> mu;` defines a variable called `mu` that will refer to the mean of the number of streams, and similarly `real<lower=0> sigma;` defines the standard deviation. Both of these variables are lower-bounded at 0 with the expression `<lower=0>`, because it wouldn’t make sense to simulate a negative mean number of streams or a negative standard deviation (we would also put an upper bound here if it made sense). Since our model has no remaining parameters, and we are not yet modeling any data, both the `parameters` and `model` blocks are empty. Finally, in the `generated quantities` block, we are telling our model to simulate the number of streams by drawing a random number from a normal distribution.

To run our Stan program, we will make use of the library `cmdstanr`. The `rstan` library also works for this, but I’ve found `cmdstanr` to be faster and more reliable. Let’s say we know that there are roughly one million streams per week, but this varies with a standard deviation of one hundred thousand streams. We can make a list of these values (expressed in millions of streams, to match our data), and pass them to Stan as data:

```r
streams_sim_data <- list(mu=1, sigma=.1)
streams_sim_model <- cmdstan_model('2021-12-10-streams-sim.stan')  ## compile the model
streams_sim <- streams_sim_model$sample(data=streams_sim_data, fixed_param=TRUE)
```

```
## Running MCMC with 4 chains, at most 20 in parallel...
##
## Chain 1 Iteration: 1 / 1000 [ 0%] (Sampling)
## Chain 1 Iteration: 100 / 1000 [ 10%] (Sampling)
## Chain 1 Iteration: 200 / 1000 [ 20%] (Sampling)
## Chain 1 Iteration: 300 / 1000 [ 30%] (Sampling)
## Chain 1 Iteration: 400 / 1000 [ 40%] (Sampling)
## Chain 1 Iteration: 500 / 1000 [ 50%] (Sampling)
## Chain 1 Iteration: 600 / 1000 [ 60%] (Sampling)
## Chain 1 Iteration: 700 / 1000 [ 70%] (Sampling)
## Chain 1 Iteration: 800 / 1000 [ 80%] (Sampling)
## Chain 1 Iteration: 900 / 1000 [ 90%] (Sampling)
## Chain 1 Iteration: 1000 / 1000 [100%] (Sampling)
## Chain 2 Iteration: 1 / 1000 [ 0%] (Sampling)
## Chain 2 Iteration: 100 / 1000 [ 10%] (Sampling)
## Chain 2 Iteration: 200 / 1000 [ 20%] (Sampling)
## Chain 2 Iteration: 300 / 1000 [ 30%] (Sampling)
## Chain 2 Iteration: 400 / 1000 [ 40%] (Sampling)
## Chain 2 Iteration: 500 / 1000 [ 50%] (Sampling)
## Chain 2 Iteration: 600 / 1000 [ 60%] (Sampling)
## Chain 2 Iteration: 700 / 1000 [ 70%] (Sampling)
## Chain 2 Iteration: 800 / 1000 [ 80%] (Sampling)
## Chain 2 Iteration: 900 / 1000 [ 90%] (Sampling)
## Chain 2 Iteration: 1000 / 1000 [100%] (Sampling)
## Chain 3 Iteration: 1 / 1000 [ 0%] (Sampling)
## Chain 3 Iteration: 100 / 1000 [ 10%] (Sampling)
## Chain 3 Iteration: 200 / 1000 [ 20%] (Sampling)
## Chain 3 Iteration: 300 / 1000 [ 30%] (Sampling)
## Chain 3 Iteration: 400 / 1000 [ 40%] (Sampling)
## Chain 3 Iteration: 500 / 1000 [ 50%] (Sampling)
## Chain 3 Iteration: 600 / 1000 [ 60%] (Sampling)
## Chain 3 Iteration: 700 / 1000 [ 70%] (Sampling)
## Chain 3 Iteration: 800 / 1000 [ 80%] (Sampling)
## Chain 3 Iteration: 900 / 1000 [ 90%] (Sampling)
## Chain 3 Iteration: 1000 / 1000 [100%] (Sampling)
## Chain 4 Iteration: 1 / 1000 [ 0%] (Sampling)
## Chain 4 Iteration: 100 / 1000 [ 10%] (Sampling)
## Chain 4 Iteration: 200 / 1000 [ 20%] (Sampling)
## Chain 4 Iteration: 300 / 1000 [ 30%] (Sampling)
## Chain 4 Iteration: 400 / 1000 [ 40%] (Sampling)
## Chain 4 Iteration: 500 / 1000 [ 50%] (Sampling)
## Chain 4 Iteration: 600 / 1000 [ 60%] (Sampling)
## Chain 4 Iteration: 700 / 1000 [ 70%] (Sampling)
## Chain 4 Iteration: 800 / 1000 [ 80%] (Sampling)
## Chain 4 Iteration: 900 / 1000 [ 90%] (Sampling)
## Chain 4 Iteration: 1000 / 1000 [100%] (Sampling)
## Chain 1 finished in 0.0 seconds.
## Chain 2 finished in 0.0 seconds.
## Chain 3 finished in 0.0 seconds.
## Chain 4 finished in 0.0 seconds.
##
## All 4 chains finished successfully.
## Mean chain execution time: 0.0 seconds.
## Total execution time: 0.3 seconds.
```

As we can see, the model has simulated 1000 stream counts in each of four chains, for 4000 draws in total. Note that above, we used the argument `fixed_param=TRUE` to tell Stan that our model has no parameters, which makes the sampling faster. Let’s look at a summary of our model:

```r
streams_sim
```

```
##  variable mean median   sd  mad   q5  q95 rhat ess_bulk ess_tail
##     y_hat 1.00   1.00 0.10 0.10 0.83 1.17 1.00     3848     3898
```

This summary tells us that our simulated stream counts have an average of about one million and a standard deviation of about one hundred thousand. To access the simulated data, we have a few different options. Within `cmdstanr`, the default is to use `streams_sim$draws()`. However, I find that the `spread_draws` and `gather_draws` functions from `tidybayes` are usually easier to work with, as they give us a nice tidy dataframe of whatever variables we want. Conveniently, we’re also going to use `tidybayes` (technically `ggdist`) to make pretty plots of our draws. Let’s get our draws with `gather_draws` and plot them:

```r
draws <- streams_sim %>% gather_draws(y_hat)

ggplot(draws, aes(x=.value)) +
  stat_halfeye(point_interval=median_hdi, normalize='panels') +
  xlab('Streams (millions/week)') + ylab('Density') +
  facet_wrap(~ .variable, scales='free') +
  theme_tidybayes()
```

Again, this tells us what we already expected: our simulated top 200 songs have somewhere around one million streams per week, and the number of streams is normally distributed around that.

## Sampling from a prior distribution

It’s nice to simulate data, but of course our main goal is to infer what the *actual* mean and standard deviation of stream counts are for the top 200 tracks. To do so, we first need to define a prior distribution. Thankfully, this is pretty easy in Stan: we just move the parameters `mu` and `sigma` from the `data` block to the `parameters` block:

```stan
data {
}

parameters {
  real<lower=0> mu;     // the mean
  real<lower=0> sigma;  // the standard deviation
}

model {
  // define priors for mu and sigma
  mu ~ normal(1, .1);
  sigma ~ normal(0, .1);
}

generated quantities {
  // simulate data using a normal distribution
  real y_hat = normal_rng(mu, sigma);
}
```

Besides the declarations of `mu` and `sigma` being moved to the `parameters` block, we can see that we’ve also added to the `model` block. Specifically, the `model` block now specifies prior distributions over our two parameters. The symbol `~` can be read as “is distributed as”, so we’re saying that `mu` is distributed according to a normal distribution with a mean of one million and a standard deviation of one hundred thousand. Likewise, we’re assuming that `sigma` is distributed normally around 0 with a standard deviation of one hundred thousand. You might think that this would give us negative numbers, but Stan truncates these normal distributions at 0 because of the `<lower=0>` in the parameters’ declarations. Now let’s sample from our prior distribution to simulate some fake data:

```r
streams_prior_model <- cmdstan_model('2021-12-10-streams-prior.stan')  ## compile the model
streams_prior <- streams_prior_model$sample()
streams_prior

streams_prior %>%
  gather_draws(mu, sigma, y_hat) %>%
  mutate(.variable=factor(.variable, levels=c('y_hat', 'sigma', 'mu'))) %>%
  ggplot(aes(x=.value, y=.variable)) +
  stat_halfeye(point_interval=median_hdi, normalize='panels') +
  xlab('Streams (millions/week)') + ylab('Density') +
  facet_wrap(~ .variable, scales='free') +
  theme_tidybayes()
```

```
## Running MCMC with 4 chains, at most 20 in parallel...
##
## Chain 1 Iteration: 1 / 2000 [ 0%] (Warmup)
## Chain 1 Iteration: 100 / 2000 [ 5%] (Warmup)
## Chain 1 Iteration: 200 / 2000 [ 10%] (Warmup)
## Chain 1 Iteration: 300 / 2000 [ 15%] (Warmup)
## Chain 1 Iteration: 400 / 2000 [ 20%] (Warmup)
## Chain 1 Iteration: 500 / 2000 [ 25%] (Warmup)
## Chain 1 Iteration: 600 / 2000 [ 30%] (Warmup)
## Chain 1 Iteration: 700 / 2000 [ 35%] (Warmup)
## Chain 1 Iteration: 800 / 2000 [ 40%] (Warmup)
## Chain 1 Iteration: 900 / 2000 [ 45%] (Warmup)
## Chain 1 Iteration: 1000 / 2000 [ 50%] (Warmup)
## Chain 1 Iteration: 1001 / 2000 [ 50%] (Sampling)
## Chain 1 Iteration: 1100 / 2000 [ 55%] (Sampling)
## Chain 1 Iteration: 1200 / 2000 [ 60%] (Sampling)
## Chain 1 Iteration: 1300 / 2000 [ 65%] (Sampling)
## Chain 1 Iteration: 1400 / 2000 [ 70%] (Sampling)
## Chain 1 Iteration: 1500 / 2000 [ 75%] (Sampling)
## Chain 1 Iteration: 1600 / 2000 [ 80%] (Sampling)
## Chain 1 Iteration: 1700 / 2000 [ 85%] (Sampling)
## Chain 1 Iteration: 1800 / 2000 [ 90%] (Sampling)
## Chain 1 Iteration: 1900 / 2000 [ 95%] (Sampling)
## Chain 1 Iteration: 2000 / 2000 [100%] (Sampling)
## Chain 2 Iteration: 1 / 2000 [ 0%] (Warmup)
## Chain 2 Iteration: 100 / 2000 [ 5%] (Warmup)
## Chain 2 Iteration: 200 / 2000 [ 10%] (Warmup)
## Chain 2 Iteration: 300 / 2000 [ 15%] (Warmup)
## Chain 2 Iteration: 400 / 2000 [ 20%] (Warmup)
## Chain 2 Iteration: 500 / 2000 [ 25%] (Warmup)
## Chain 2 Iteration: 600 / 2000 [ 30%] (Warmup)
## Chain 2 Iteration: 700 / 2000 [ 35%] (Warmup)
## Chain 2 Iteration: 800 / 2000 [ 40%] (Warmup)
## Chain 2 Iteration: 900 / 2000 [ 45%] (Warmup)
## Chain 2 Iteration: 1000 / 2000 [ 50%] (Warmup)
## Chain 2 Iteration: 1001 / 2000 [ 50%] (Sampling)
## Chain 2 Iteration: 1100 / 2000 [ 55%] (Sampling)
## Chain 2 Iteration: 1200 / 2000 [ 60%] (Sampling)
## Chain 2 Iteration: 1300 / 2000 [ 65%] (Sampling)
## Chain 2 Iteration: 1400 / 2000 [ 70%] (Sampling)
## Chain 2 Iteration: 1500 / 2000 [ 75%] (Sampling)
## Chain 2 Iteration: 1600 / 2000 [ 80%] (Sampling)
## Chain 2 Iteration: 1700 / 2000 [ 85%] (Sampling)
## Chain 2 Iteration: 1800 / 2000 [ 90%] (Sampling)
## Chain 2 Iteration: 1900 / 2000 [ 95%] (Sampling)
## Chain 2 Iteration: 2000 / 2000 [100%] (Sampling)
## Chain 3 Iteration: 1 / 2000 [ 0%] (Warmup)
## Chain 3 Iteration: 100 / 2000 [ 5%] (Warmup)
## Chain 3 Iteration: 200 / 2000 [ 10%] (Warmup)
## Chain 3 Iteration: 300 / 2000 [ 15%] (Warmup)
## Chain 3 Iteration: 400 / 2000 [ 20%] (Warmup)
## Chain 3 Iteration: 500 / 2000 [ 25%] (Warmup)
## Chain 3 Iteration: 600 / 2000 [ 30%] (Warmup)
## Chain 3 Iteration: 700 / 2000 [ 35%] (Warmup)
## Chain 3 Iteration: 800 / 2000 [ 40%] (Warmup)
## Chain 3 Iteration: 900 / 2000 [ 45%] (Warmup)
## Chain 3 Iteration: 1000 / 2000 [ 50%] (Warmup)
## Chain 3 Iteration: 1001 / 2000 [ 50%] (Sampling)
## Chain 3 Iteration: 1100 / 2000 [ 55%] (Sampling)
## Chain 3 Iteration: 1200 / 2000 [ 60%] (Sampling)
## Chain 3 Iteration: 1300 / 2000 [ 65%] (Sampling)
## Chain 3 Iteration: 1400 / 2000 [ 70%] (Sampling)
## Chain 3 Iteration: 1500 / 2000 [ 75%] (Sampling)
## Chain 3 Iteration: 1600 / 2000 [ 80%] (Sampling)
## Chain 3 Iteration: 1700 / 2000 [ 85%] (Sampling)
## Chain 3 Iteration: 1800 / 2000 [ 90%] (Sampling)
## Chain 3 Iteration: 1900 / 2000 [ 95%] (Sampling)
## Chain 3 Iteration: 2000 / 2000 [100%] (Sampling)
## Chain 4 Iteration: 1 / 2000 [ 0%] (Warmup)
## Chain 4 Iteration: 100 / 2000 [ 5%] (Warmup)
## Chain 4 Iteration: 200 / 2000 [ 10%] (Warmup)
## Chain 4 Iteration: 300 / 2000 [ 15%] (Warmup)
## Chain 4 Iteration: 400 / 2000 [ 20%] (Warmup)
## Chain 4 Iteration: 500 / 2000 [ 25%] (Warmup)
## Chain 4 Iteration: 600 / 2000 [ 30%] (Warmup)
## Chain 4 Iteration: 700 / 2000 [ 35%] (Warmup)
## Chain 4 Iteration: 800 / 2000 [ 40%] (Warmup)
## Chain 4 Iteration: 900 / 2000 [ 45%] (Warmup)
## Chain 4 Iteration: 1000 / 2000 [ 50%] (Warmup)
## Chain 4 Iteration: 1001 / 2000 [ 50%] (Sampling)
## Chain 4 Iteration: 1100 / 2000 [ 55%] (Sampling)
## Chain 4 Iteration: 1200 / 2000 [ 60%] (Sampling)
## Chain 4 Iteration: 1300 / 2000 [ 65%] (Sampling)
## Chain 4 Iteration: 1400 / 2000 [ 70%] (Sampling)
## Chain 4 Iteration: 1500 / 2000 [ 75%] (Sampling)
## Chain 4 Iteration: 1600 / 2000 [ 80%] (Sampling)
## Chain 4 Iteration: 1700 / 2000 [ 85%] (Sampling)
## Chain 4 Iteration: 1800 / 2000 [ 90%] (Sampling)
## Chain 4 Iteration: 1900 / 2000 [ 95%] (Sampling)
## Chain 4 Iteration: 2000 / 2000 [100%] (Sampling)
## Chain 1 finished in 0.0 seconds.
## Chain 2 finished in 0.0 seconds.
## Chain 3 finished in 0.0 seconds.
## Chain 4 finished in 0.0 seconds.
##
## All 4 chains finished successfully.
## Mean chain execution time: 0.0 seconds.
## Total execution time: 0.2 seconds.
##  variable  mean median   sd  mad    q5   q95 rhat ess_bulk ess_tail
##     lp__  -3.94  -3.61 1.08 0.83 -6.10 -2.85 1.00     1430     1617
##     mu     1.00   1.00 0.10 0.10  0.84  1.17 1.00     1960     1705
##     sigma  0.08   0.07 0.06 0.06  0.01  0.20 1.00     1391     1210
##     y_hat  1.00   1.00 0.14 0.13  0.78  1.24 1.00     2721     2915
```

Just like before, we now have simulated values of `y_hat` centered around one million streams per week. However, the distribution of `y_hat` is wider than before. When we simulated stream counts with a fixed `mu` and `sigma`, the only source of noise in our simulated data was the noise in the sampling process. But now that we have included `mu` and `sigma` as parameters in the model, we also have uncertainty in `mu` and `sigma` that creates some more noise in `y_hat`.
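The extra spread in `y_hat` can be reproduced outside Stan. The following base R sketch draws `mu` and `sigma` from the priors above and then simulates one data point per draw, mirroring the `generated quantities` block. It assumes that the `<lower=0>` bound on `mu` can be ignored (it sits ten standard deviations below the prior mean) and that the bounded `sigma` prior behaves as a half-normal; the variable names are illustrative.

```r
set.seed(2021)
n <- 4000   # same number of draws as 4 chains x 1000 sampling iterations

# mu ~ normal(1, 0.1) with <lower=0>: the bound is 10 SDs below the mean,
# so plain rnorm() draws are effectively never negative
mu <- rnorm(n, mean = 1, sd = 0.1)

# sigma ~ normal(0, 0.1) with <lower=0>: a half-normal, i.e. |normal(0, 0.1)|
sigma <- abs(rnorm(n, mean = 0, sd = 0.1))

# One simulated data point per prior draw, like y_hat in generated quantities
y_hat <- rnorm(n, mean = mu, sd = sigma)

# Fixed-parameter simulation for comparison (mu = 1, sigma = 0.1)
y_fixed <- rnorm(n, mean = 1, sd = 0.1)

round(c(mean_sigma = mean(sigma), sd_fixed = sd(y_fixed), sd_prior = sd(y_hat)), 2)
```

If the sketch is faithful, `mean(sigma)` should land near the 0.08 in the Stan summary, and the standard deviation of `y_hat` should be close to 0.14 rather than 0.10, since the variance of `mu` (0.01) adds to the expected squared `sigma` (also 0.01).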

## Fitting a model to data

You might have noticed that that was a whole lot of work to go through to sample from some normal distributions. Up until now, we could have just as well used `rnorm` a few times to do the trick. So what’s the point? Well, using (almost) the same Stan code, we can now fit this simple model to our data to find the most likely values of *μ* and *σ*:

```stan
data {
  int<lower=0> N;         // the number of data points
  vector<lower=0>[N] y;   // the data to model
}

parameters {
  real<lower=0> mu;       // the mean
  real<lower=0> sigma;    // the standard deviation
}

model {
  // define priors for mu and sigma
  mu ~ normal(1, .1);
  sigma ~ normal(0, .1);

  // define the likelihood of y
  y ~ normal(mu, sigma);
}

generated quantities {
  // simulate data using a normal distribution
  real y_hat = normal_rng(mu, sigma);
}
```

Compared to the previous code, we have added two things. First, in the `data` block, we added declarations for two variables. `y` is a vector containing the stream counts for each track in each week. The syntax `[N]` tells Stan that this vector is `N` numbers long, which is why we also declared a data variable `N`. Second, in the `model` block, we added a line that defines the likelihood of `y` given our model: we are modeling `y` as normally-distributed with mean `mu` and standard deviation `sigma`. Rather than just sampling from our prior distributions, Stan will now sample the values of `mu` and `sigma` according to their posterior probability using Markov Chain Monte Carlo (MCMC), giving us an approximate posterior distribution. Let’s run it and see what happens:

```r
streams_data <- list(N=nrow(spotify2021), y=spotify2021$streams)
streams_model <- cmdstan_model('2021-12-10-streams.stan')  ## compile the model
streams <- streams_model$sample(data=streams_data, save_warmup=TRUE)
streams

draws <- streams %>% gather_draws(mu, sigma, y_hat)
ggplot(draws, aes(x=.value)) +
  stat_halfeye(point_interval=median_hdi, normalize='panels') +
  xlab('Streams (millions/week)') + ylab('Density') +
  facet_wrap(~ .variable, scales='free') +
  theme_tidybayes()
```

```
## Running MCMC with 4 chains, at most 20 in parallel...
##
## Chain 1 Iteration: 1 / 2000 [ 0%] (Warmup)
## Chain 1 Iteration: 100 / 2000 [ 5%] (Warmup)
## Chain 1 Iteration: 200 / 2000 [ 10%] (Warmup)
## Chain 1 Iteration: 300 / 2000 [ 15%] (Warmup)
## Chain 1 Iteration: 400 / 2000 [ 20%] (Warmup)
## Chain 1 Iteration: 500 / 2000 [ 25%] (Warmup)
## Chain 1 Iteration: 600 / 2000 [ 30%] (Warmup)
## Chain 1 Iteration: 700 / 2000 [ 35%] (Warmup)
## Chain 1 Iteration: 800 / 2000 [ 40%] (Warmup)
## Chain 1 Iteration: 900 / 2000 [ 45%] (Warmup)
## Chain 1 Iteration: 1000 / 2000 [ 50%] (Warmup)
## Chain 1 Iteration: 1001 / 2000 [ 50%] (Sampling)
## Chain 1 Iteration: 1100 / 2000 [ 55%] (Sampling)
## Chain 1 Iteration: 1200 / 2000 [ 60%] (Sampling)
## Chain 2 Iteration: 1 / 2000 [ 0%] (Warmup)
## Chain 2 Iteration: 100 / 2000 [ 5%] (Warmup)
## Chain 2 Iteration: 200 / 2000 [ 10%] (Warmup)
## Chain 2 Iteration: 300 / 2000 [ 15%] (Warmup)
## Chain 2 Iteration: 400 / 2000 [ 20%] (Warmup)
## Chain 2 Iteration: 500 / 2000 [ 25%] (Warmup)
## Chain 2 Iteration: 600 / 2000 [ 30%] (Warmup)
## Chain 2 Iteration: 700 / 2000 [ 35%] (Warmup)
## Chain 2 Iteration: 800 / 2000 [ 40%] (Warmup)
## Chain 2 Iteration: 900 / 2000 [ 45%] (Warmup)
## Chain 2 Iteration: 1000 / 2000 [ 50%] (Warmup)
## Chain 2 Iteration: 1001 / 2000 [ 50%] (Sampling)
## Chain 2 Iteration: 1100 / 2000 [ 55%] (Sampling)
## Chain 2 Iteration: 1200 / 2000 [ 60%] (Sampling)
## Chain 2 Iteration: 1300 / 2000 [ 65%] (Sampling)
## Chain 2 Iteration: 1400 / 2000 [ 70%] (Sampling)
## Chain 2 Iteration: 1500 / 2000 [ 75%] (Sampling)
## Chain 2 Iteration: 1600 / 2000 [ 80%] (Sampling)
## Chain 2 Iteration: 1700 / 2000 [ 85%] (Sampling)
## Chain 2 Iteration: 1800 / 2000 [ 90%] (Sampling)
## Chain 2 Iteration: 1900 / 2000 [ 95%] (Sampling)
## Chain 2 Iteration: 2000 / 2000 [100%] (Sampling)
## Chain 3 Iteration: 1 / 2000 [ 0%] (Warmup)
## Chain 3 Iteration: 100 / 2000 [ 5%] (Warmup)
## Chain 3 Iteration: 200 / 2000 [ 10%] (Warmup)
## Chain 3 Iteration: 300 / 2000 [ 15%] (Warmup)
## Chain 3 Iteration: 400 / 2000 [ 20%] (Warmup)
## Chain 3 Iteration: 500 / 2000 [ 25%] (Warmup)
## Chain 3 Iteration: 600 / 2000 [ 30%] (Warmup)
## Chain 3 Iteration: 700 / 2000 [ 35%] (Warmup)
## Chain 3 Iteration: 800 / 2000 [ 40%] (Warmup)
## Chain 3 Iteration: 900 / 2000 [ 45%] (Warmup)
## Chain 3 Iteration: 1000 / 2000 [ 50%] (Warmup)
## Chain 3 Iteration: 1001 / 2000 [ 50%] (Sampling)
## Chain 3 Iteration: 1100 / 2000 [ 55%] (Sampling)
## Chain 3 Iteration: 1200 / 2000 [ 60%] (Sampling)
## Chain 3 Iteration: 1300 / 2000 [ 65%] (Sampling)
## Chain 3 Iteration: 1400 / 2000 [ 70%] (Sampling)
## Chain 3 Iteration: 1500 / 2000 [ 75%] (Sampling)
## Chain 3 Iteration: 1600 / 2000 [ 80%] (Sampling)
## Chain 3 Iteration: 1700 / 2000 [ 85%] (Sampling)
## Chain 3 Iteration: 1800 / 2000 [ 90%] (Sampling)
## Chain 3 Iteration: 1900 / 2000 [ 95%] (Sampling)
## Chain 3 Iteration: 2000 / 2000 [100%] (Sampling)
## Chain 4 Iteration: 1 / 2000 [ 0%] (Warmup)
## Chain 4 Iteration: 100 / 2000 [ 5%] (Warmup)
## Chain 4 Iteration: 200 / 2000 [ 10%] (Warmup)
## Chain 4 Iteration: 300 / 2000 [ 15%] (Warmup)
## Chain 4 Iteration: 400 / 2000 [ 20%] (Warmup)
## Chain 4 Iteration: 500 / 2000 [ 25%] (Warmup)
## Chain 4 Iteration: 600 / 2000 [ 30%] (Warmup)
## Chain 4 Iteration: 700 / 2000 [ 35%] (Warmup)
## Chain 4 Iteration: 800 / 2000 [ 40%] (Warmup)
## Chain 4 Iteration: 900 / 2000 [ 45%] (Warmup)
## Chain 4 Iteration: 1000 / 2000 [ 50%] (Warmup)
## Chain 4 Iteration: 1001 / 2000 [ 50%] (Sampling)
## Chain 4 Iteration: 1100 / 2000 [ 55%] (Sampling)
## Chain 4 Iteration: 1200 / 2000 [ 60%] (Sampling)
## Chain 4 Iteration: 1300 / 2000 [ 65%] (Sampling)
## Chain 4 Iteration: 1400 / 2000 [ 70%] (Sampling)
## Chain 4 Iteration: 1500 / 2000 [ 75%] (Sampling)
## Chain 4 Iteration: 1600 / 2000 [ 80%] (Sampling)
## Chain 4 Iteration: 1700 / 2000 [ 85%] (Sampling)
## Chain 4 Iteration: 1800 / 2000 [ 90%] (Sampling)
## Chain 4 Iteration: 1900 / 2000 [ 95%] (Sampling)
## Chain 4 Iteration: 2000 / 2000 [100%] (Sampling)
## Chain 1 Iteration: 1300 / 2000 [ 65%] (Sampling)
## Chain 1 Iteration: 1400 / 2000 [ 70%] (Sampling)
## Chain 1 Iteration: 1500 / 2000 [ 75%] (Sampling)
## Chain 1 Iteration: 1600 / 2000 [ 80%] (Sampling)
## Chain 1 Iteration: 1700 / 2000 [ 85%] (Sampling)
## Chain 1 Iteration: 1800 / 2000 [ 90%] (Sampling)
## Chain 1 Iteration: 1900 / 2000 [ 95%] (Sampling)
## Chain 1 Iteration: 2000 / 2000 [100%] (Sampling)
## Chain 1 finished in 0.2 seconds.
## Chain 2 finished in 0.2 seconds.
## Chain 3 finished in 0.2 seconds.
## Chain 4 finished in 0.2 seconds.
##
## All 4 chains finished successfully.
## Mean chain execution time: 0.2 seconds.
## Total execution time: 0.8 seconds.
##  variable      mean    median   sd  mad        q5       q95 rhat ess_bulk ess_tail
##     lp__  -10568.72 -10568.40 0.97 0.74 -10570.60 -10567.80 1.00     1861     2564
##     mu         2.52      2.52 0.02 0.02      2.49      2.55 1.00     3421     2786
##     sigma      1.78      1.78 0.01 0.01      1.76      1.80 1.00     3618     2641
##     y_hat      2.54      2.51 1.76 1.77     -0.32      5.41 1.00     4011     3889
```

Even though our prior for `mu` was around one million streams per week,
it looks like our posterior is now around 2.5 million streams per week.
Likewise, the posterior for `sigma` is about 1.8 million, even though
our prior was centered around 0. Finally, looking at the 5% and 95%
quantiles of `y_hat`, it appears that our model estimates the number of
streams per week to be anywhere from about -300,000 to 5.4 million.
Before we talk about these results any further, though, let’s make sure
that we can trust them.

# Assessing model convergence

Since we don’t have direct access to the posterior distribution, Stan
uses Markov Chain Monte Carlo (MCMC) to sample values of `mu` and
`sigma`. We won’t go into the details here, but the gist is that MCMC
approximates the posterior distributions over `mu` and `sigma` by trying
to sample their values in proportion to their posterior probability. If
the samples look like they have come from the posterior distribution, we
say the model has *converged*. If not, we cannot use the sampled values
for inference, because they don’t reflect our posterior.
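
To give a feel for what “sampling in proportion to posterior probability” means, here is a minimal sketch of a random-walk Metropolis sampler in base R. This is *not* the algorithm Stan uses (Stan relies on a more efficient Hamiltonian Monte Carlo sampler), and the simulated data, the fixed `sigma`, the prior, and the proposal width are all assumptions picked purely for illustration.

```r
set.seed(1)

# Simulated "stream counts" (assumption: 50 songs, true mean 2.5, sd 1.8)
y <- rnorm(50, mean = 2.5, sd = 1.8)
sigma <- 1.8  # held fixed so the example stays one-dimensional

# Unnormalized log posterior for mu: log prior + log likelihood
log_post <- function(mu) {
  dnorm(mu, mean = 1, sd = 1, log = TRUE) +
    sum(dnorm(y, mean = mu, sd = sigma, log = TRUE))
}

n_iter <- 5000
mu <- numeric(n_iter)
mu[1] <- 0  # arbitrary starting value
for (t in 2:n_iter) {
  proposal <- mu[t - 1] + rnorm(1, mean = 0, sd = 0.5)
  log_ratio <- log_post(proposal) - log_post(mu[t - 1])
  # accept the proposal with probability min(1, posterior ratio)
  mu[t] <- if (log(runif(1)) < log_ratio) proposal else mu[t - 1]
}

samples <- mu[-(1:1000)]  # discard the first 1000 iterations as warmup

# For this simple conjugate model the posterior mean is known exactly
post_prec <- 1 / 1^2 + length(y) / sigma^2
post_mean <- (1 / 1^2 * 1 + sum(y) / sigma^2) / post_prec
c(mcmc = mean(samples), exact = post_mean)
```

If the chain has converged, the average of the sampled values should land close to the exact posterior mean.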

## The fuzzy caterpillar check

There are a few different metrics for assessing convergence of MCMC chains. Honestly, the best one is visual: the “fuzzy caterpillar” check. The idea is you plot the MCMC chains for each parameter as a function of iteration number, like so:

ggplot(draws, aes(x=.iteration, y=.value, color=factor(.chain))) +
  geom_line() + xlab('Iteration') + ylab('Value') +
  scale_color_discrete(name='Chain') +
  facet_grid(.variable ~ ., scales='free_y') +
  theme_tidybayes()

Since all of these chains look like nice fuzzy caterpillars, we can be
pretty confident that they converged. To demonstrate what the chains
would look like if they *hadn’t* converged, let’s look at the chains
before the warmup period. The warmup period is the first stage of the
model while it is assumed to still be converging: typically we say that
something like the first half of the samples are in the warmup period,
and we throw them away to be left with just the good stuff.

streams$draws(variables=c('mu', 'sigma', 'y_hat'),
              inc_warmup=TRUE, format='draws_df') %>%
  pivot_longer(mu:y_hat, names_to='.variable', values_to='.value') %>%
  filter(.iteration <= 250) %>%
  ggplot(aes(x=.iteration, y=.value, color=factor(.chain))) +
  geom_line() + xlab('Iteration') + ylab('Value') +
  scale_color_discrete(name='Chain') +
  facet_grid(.variable ~ ., scales='free_y') +
  theme_tidybayes()

As we can see, the first 25 or so iterations do not look like nice fuzzy caterpillars. Instead, we can tell all four chains apart from each other, since they are still close to their random initialization values. But by iteration 50, it appears that our model has converged: the parameters have all ended up around the values of our posterior distribution.

## R-hat

If the qualitative visual check isn’t working for you, you might want something a bit more quantitative. One option is R-hat, which (roughly speaking) compares the variance of the parameter values pooled across all chains to the variance within each individual chain. This gives us a good quantification of the fuzzy caterpillar check: if the between-chain variance is high relative to the within-chain variance, the chains are all exploring different regions of the parameter space and don’t overlap much, so R-hat will be well above 1. On the other hand, if the chains overlap, the pooled and within-chain variances are about equal, R-hat is close to 1, and the chains should look like fuzzy caterpillars. Typically we look for R-hat values to be as close to 1 as possible, and we start to be suspicious of poor convergence if R-hat > 1.05.

streams$summary() %>% select(variable, rhat)

## # A tibble: 4 × 2
## variable rhat
## <chr> <dbl>
## 1 lp__ 1.00
## 2 mu 1.00
## 3 sigma 1.00
## 4 y_hat 1.00

Since our R-hat values are all 1.00, our model looks pretty good.
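
To make the between-chain versus within-chain comparison concrete, here is a rough sketch of the classic R-hat calculation in base R. The function name `rhat_basic` and the simulated chains are made up for this example, and the number `cmdstanr` reports comes from a more robust rank-normalized, split-chain variant, so treat this as an illustration of the idea rather than a reimplementation.

```r
set.seed(1)

# Classic (non-split) R-hat for a matrix with one column per chain
rhat_basic <- function(chains) {
  n <- nrow(chains)
  W <- mean(apply(chains, 2, var))       # average within-chain variance
  B <- n * var(colMeans(chains))         # between-chain variance
  var_pooled <- (n - 1) / n * W + B / n  # pooled variance estimate
  sqrt(var_pooled / W)
}

# Four chains that all sample from the same distribution...
good <- matrix(rnorm(4000), nrow = 1000, ncol = 4)
# ...and four chains that are stuck in different places
bad <- sweep(good, 2, c(0, 1, 2, 3), `+`)

rhat_good <- rhat_basic(good)
rhat_bad <- rhat_basic(bad)
c(good = rhat_good, bad = rhat_bad)
```

The overlapping chains should give a value very close to 1, while the shifted chains should give a value far above the 1.05 threshold.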

## Effective Sample Size (ESS)

Related to R-hat, we can also look at the effective sample size (ESS) of
the model. Recall that we sampled 1000 draws from four MCMC chains,
resulting in 4000 total samples from the posterior. In an ideal scenario
where every iteration of the model is totally independent of the
previous iteration, this would mean that we have a sample size of 4000
samples. But most of the time, there is some amount of auto-correlation
of the parameter values between iterations. To account for this, ESS is
the sample size adjusted for within-chain auto-correlation. In other
words, even though we have 4000 samples from the posterior, because of
auto-correlation inherent in the model fitting process, we *effectively*
have fewer independent samples. `cmdstanr` actually gives us two
different ESSs: a bulk ESS and a tail ESS. The bulk ESS tells us the
effective sample size for our estimates of central tendency (i.e.,
mean/median), and the tail ESS tells us the effective sample size for
our estimates of the tail quantiles and credible intervals. Since there
are fewer samples at the tails, we will typically have a lower tail ESS
than bulk ESS. In any case, you want all of these ESSs to be as large as
possible. Minimally, it is good to have an ESS of 1000 for practical
applications.

streams$summary() %>% select(variable, ess_bulk, ess_tail)

## # A tibble: 4 × 3
## variable ess_bulk ess_tail
## <chr> <dbl> <dbl>
## 1 lp__ 1862. 2564.
## 2 mu 3422. 2787.
## 3 sigma 3619. 2641.
## 4 y_hat 4012. 3890.

Our bulk ESS looks very good: the values for `mu`, `sigma`, and `y_hat` are all close to 4000 (`lp__` is lower, but still well above 1000). Though the tail ESS is lower, it is still acceptable.
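
To see how auto-correlation eats into the sample size, here is a crude single-chain ESS sketch in base R. The helper `ess_basic`, the truncation rule, and the simulated chains are all assumptions made for illustration; the bulk and tail ESS that `cmdstanr` reports use more careful multi-chain estimators, so the numbers will not match exactly.

```r
set.seed(1)

# Crude ESS for one chain: N / (1 + 2 * sum of autocorrelations),
# truncating the sum at the first negative autocorrelation
ess_basic <- function(x, max_lag = 200) {
  rho <- as.numeric(acf(x, lag.max = max_lag, plot = FALSE)$acf)[-1]  # drop lag 0
  first_neg <- which(rho < 0)[1]
  if (!is.na(first_neg)) rho <- rho[seq_len(first_neg - 1)]
  length(x) / (1 + 2 * sum(rho))
}

n <- 4000
independent <- rnorm(n)                                      # no auto-correlation
correlated <- as.numeric(arima.sim(list(ar = 0.9), n = n))   # strongly auto-correlated

ess_independent <- ess_basic(independent)
ess_correlated <- ess_basic(correlated)
c(independent = ess_independent, correlated = ess_correlated)
```

Both chains contain 4000 values, but the auto-correlated one should be worth only a small fraction of that many independent samples.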

# Assessing model fit

Now that we know that our model converged, let’s try to figure out how well it fit. In other words, how well does our model describe the data? Just as the fuzzy-caterpillar check provides a quick & easy way of assessing convergence, posterior predictive checks do the same for model fit. To perform a posterior predictive check, all we have to do is plot the distribution of simulated data alongside the distribution of actual data:

draws %>%
  filter(.variable=='y_hat') %>%
  ggplot(aes(x=.value, fill=.variable)) +
  stat_slab(slab_alpha=.75) +
  stat_slab(slab_alpha=.75, data=tibble(.variable='y', .value=spotify2021$streams)) +
  geom_vline(xintercept=mean(spotify2021$streams)) +
  scale_fill_discrete(name='') +
  xlab('Streams (millions/week)') + ylab('Density') +
  coord_cartesian(xlim=c(-5, 10)) +
  theme_tidybayes()

We can see that even though our model captures the mean of the stream
counts (the black vertical line) very well, there are a few problems.
First and foremost, it predicts some negative stream counts. For the top
200 songs on Spotify, not only is a negative number of streams very
unlikely, it is also impossible. Second, it predicts that most stream
counts will be at the mean, but the data have a positive skew. Let’s try
to fix these two issues at once by using a log-normal distribution
instead of a normal distribution. The log-normal distribution is simply
what you get when you exponentiate samples from the normal distribution:
*lognormal(μ,σ)* = *exp(Normal(μ,σ))*. So let’s try this distribution
out, adjusting our priors over `mu` and `sigma`:

streams_model_lognormal <- cmdstan_model('2021-12-10-streams-lognormal.stan') ## compile the model
streams_lognormal <- streams_model_lognormal$sample(data=streams_data)
streams_lognormal

## Running MCMC with 4 chains, at most 20 in parallel...
##
## Chain 1 Iteration: 1 / 2000 [ 0%] (Warmup)
## Chain 1 Iteration: 100 / 2000 [ 5%] (Warmup)
## Chain 1 Iteration: 200 / 2000 [ 10%] (Warmup)
## Chain 1 Iteration: 300 / 2000 [ 15%] (Warmup)
## Chain 1 Iteration: 400 / 2000 [ 20%] (Warmup)
## Chain 2 Iteration: 1 / 2000 [ 0%] (Warmup)
## Chain 2 Iteration: 100 / 2000 [ 5%] (Warmup)
## Chain 2 Iteration: 200 / 2000 [ 10%] (Warmup)
## Chain 2 Iteration: 300 / 2000 [ 15%] (Warmup)
## Chain 2 Iteration: 400 / 2000 [ 20%] (Warmup)
## Chain 2 Iteration: 500 / 2000 [ 25%] (Warmup)
## Chain 2 Iteration: 600 / 2000 [ 30%] (Warmup)
## Chain 2 Iteration: 700 / 2000 [ 35%] (Warmup)
## Chain 2 Iteration: 800 / 2000 [ 40%] (Warmup)
## Chain 2 Iteration: 900 / 2000 [ 45%] (Warmup)
## Chain 2 Iteration: 1000 / 2000 [ 50%] (Warmup)
## Chain 2 Iteration: 1001 / 2000 [ 50%] (Sampling)
## Chain 2 Iteration: 1100 / 2000 [ 55%] (Sampling)
## Chain 2 Iteration: 1200 / 2000 [ 60%] (Sampling)
## Chain 2 Iteration: 1300 / 2000 [ 65%] (Sampling)
## Chain 2 Iteration: 1400 / 2000 [ 70%] (Sampling)
## Chain 2 Iteration: 1500 / 2000 [ 75%] (Sampling)
## Chain 2 Iteration: 1600 / 2000 [ 80%] (Sampling)
## Chain 2 Iteration: 1700 / 2000 [ 85%] (Sampling)
## Chain 2 Iteration: 1800 / 2000 [ 90%] (Sampling)
## Chain 3 Iteration: 1 / 2000 [ 0%] (Warmup)
## Chain 3 Iteration: 100 / 2000 [ 5%] (Warmup)
## Chain 3 Iteration: 200 / 2000 [ 10%] (Warmup)
## Chain 3 Iteration: 300 / 2000 [ 15%] (Warmup)
## Chain 3 Iteration: 400 / 2000 [ 20%] (Warmup)
## Chain 3 Iteration: 500 / 2000 [ 25%] (Warmup)
## Chain 3 Iteration: 600 / 2000 [ 30%] (Warmup)
## Chain 3 Iteration: 700 / 2000 [ 35%] (Warmup)
## Chain 3 Iteration: 800 / 2000 [ 40%] (Warmup)
## Chain 3 Iteration: 900 / 2000 [ 45%] (Warmup)
## Chain 3 Iteration: 1000 / 2000 [ 50%] (Warmup)
## Chain 3 Iteration: 1001 / 2000 [ 50%] (Sampling)
## Chain 3 Iteration: 1100 / 2000 [ 55%] (Sampling)
## Chain 3 Iteration: 1200 / 2000 [ 60%] (Sampling)
## Chain 3 Iteration: 1300 / 2000 [ 65%] (Sampling)
## Chain 3 Iteration: 1400 / 2000 [ 70%] (Sampling)
## Chain 3 Iteration: 1500 / 2000 [ 75%] (Sampling)
## Chain 3 Iteration: 1600 / 2000 [ 80%] (Sampling)
## Chain 3 Iteration: 1700 / 2000 [ 85%] (Sampling)
## Chain 3 Iteration: 1800 / 2000 [ 90%] (Sampling)
## Chain 3 Iteration: 1900 / 2000 [ 95%] (Sampling)
## Chain 3 Iteration: 2000 / 2000 [100%] (Sampling)
## Chain 4 Iteration: 1 / 2000 [ 0%] (Warmup)
## Chain 4 Iteration: 100 / 2000 [ 5%] (Warmup)
## Chain 4 Iteration: 200 / 2000 [ 10%] (Warmup)
## Chain 4 Iteration: 300 / 2000 [ 15%] (Warmup)
## Chain 4 Iteration: 400 / 2000 [ 20%] (Warmup)
## Chain 4 Iteration: 500 / 2000 [ 25%] (Warmup)
## Chain 4 Iteration: 600 / 2000 [ 30%] (Warmup)
## Chain 4 Iteration: 700 / 2000 [ 35%] (Warmup)
## Chain 4 Iteration: 800 / 2000 [ 40%] (Warmup)
## Chain 4 Iteration: 900 / 2000 [ 45%] (Warmup)
## Chain 4 Iteration: 1000 / 2000 [ 50%] (Warmup)
## Chain 4 Iteration: 1001 / 2000 [ 50%] (Sampling)
## Chain 4 Iteration: 1100 / 2000 [ 55%] (Sampling)
## Chain 4 Iteration: 1200 / 2000 [ 60%] (Sampling)
## Chain 4 Iteration: 1300 / 2000 [ 65%] (Sampling)
## Chain 4 Iteration: 1400 / 2000 [ 70%] (Sampling)
## Chain 4 Iteration: 1500 / 2000 [ 75%] (Sampling)
## Chain 4 Iteration: 1600 / 2000 [ 80%] (Sampling)
## Chain 4 Iteration: 1700 / 2000 [ 85%] (Sampling)
## Chain 4 Iteration: 1800 / 2000 [ 90%] (Sampling)
## Chain 4 Iteration: 1900 / 2000 [ 95%] (Sampling)
## Chain 4 Iteration: 2000 / 2000 [100%] (Sampling)
## Chain 1 Iteration: 500 / 2000 [ 25%] (Warmup)
## Chain 1 Iteration: 600 / 2000 [ 30%] (Warmup)
## Chain 1 Iteration: 700 / 2000 [ 35%] (Warmup)
## Chain 1 Iteration: 800 / 2000 [ 40%] (Warmup)
## Chain 1 Iteration: 900 / 2000 [ 45%] (Warmup)
## Chain 1 Iteration: 1000 / 2000 [ 50%] (Warmup)
## Chain 1 Iteration: 1001 / 2000 [ 50%] (Sampling)
## Chain 1 Iteration: 1100 / 2000 [ 55%] (Sampling)
## Chain 1 Iteration: 1200 / 2000 [ 60%] (Sampling)
## Chain 1 Iteration: 1300 / 2000 [ 65%] (Sampling)
## Chain 1 Iteration: 1400 / 2000 [ 70%] (Sampling)
## Chain 1 Iteration: 1500 / 2000 [ 75%] (Sampling)
## Chain 1 Iteration: 1600 / 2000 [ 80%] (Sampling)
## Chain 1 Iteration: 1700 / 2000 [ 85%] (Sampling)
## Chain 1 Iteration: 1800 / 2000 [ 90%] (Sampling)
## Chain 1 Iteration: 1900 / 2000 [ 95%] (Sampling)
## Chain 1 Iteration: 2000 / 2000 [100%] (Sampling)
## Chain 1 finished in 0.5 seconds.
## Chain 2 Iteration: 1900 / 2000 [ 95%] (Sampling)
## Chain 2 Iteration: 2000 / 2000 [100%] (Sampling)
## Chain 2 finished in 0.5 seconds.
## Chain 3 finished in 0.6 seconds.
## Chain 4 finished in 0.6 seconds.
##
## All 4 chains finished successfully.
## Mean chain execution time: 0.5 seconds.
## Total execution time: 1.1 seconds.
## variable mean median sd mad q5 q95 rhat ess_bulk ess_tail
## lp__ -5789.20 -5788.89 0.97 0.67 -5791.16 -5788.29 1.00 1990 2763
## mu 0.81 0.81 0.00 0.00 0.81 0.82 1.00 4259 3110
## sigma 0.45 0.45 0.00 0.00 0.44 0.45 1.00 3122 2651
## y_hat 2.51 2.28 1.20 0.97 1.08 4.77 1.00 3558 3888

streams_lognormal %>%
  gather_draws(y_hat) %>%
  ggplot(aes(x=.value, fill=.variable)) +
  stat_slab(slab_alpha=.75, fill=NA, color='black',
            data=filter(draws, .variable=='y_hat') %>%
              mutate(.variable='y_hat (normal)')) +
  stat_slab(slab_alpha=.75) +
  stat_slab(slab_alpha=.75, data=tibble(.variable='y', .value=spotify2021$streams)) +
  geom_vline(xintercept=mean(spotify2021$streams)) +
  scale_fill_discrete(name='') +
  xlab('Streams (millions/week)') + ylab('Density') +
  coord_cartesian(xlim=c(-5, 10)) +
  theme_tidybayes()

Clearly this model (blue) does a lot better at describing stream counts than the previous one (black line), but it’s not perfect either. Importantly, there is no single gold standard for model fit: a model that fits perfectly fine for some purposes may not be good for other purposes. So it is up to you, the modeler, to determine when your model is good enough to inspect.
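
As a quick sanity check on the log-normal definition used above, the sketch below exponentiates normal draws in base R and compares them with R’s built-in `rlnorm`. The values of `mu` and `sigma` are just the rounded posterior means from the summary table, and the variable names are made up for this example.

```r
set.seed(1)

# Rounded posterior means of mu and sigma from the log-normal fit
mu <- 0.81
sigma <- 0.45

n <- 1e5
via_exp <- exp(rnorm(n, mean = mu, sd = sigma))       # exponentiated normal draws
via_rlnorm <- rlnorm(n, meanlog = mu, sdlog = sigma)  # R's built-in log-normal

# Both versions are strictly positive and right-skewed (mean > median)
summary_table <- rbind(
  via_exp    = c(mean = mean(via_exp), median = median(via_exp), min = min(via_exp)),
  via_rlnorm = c(mean = mean(via_rlnorm), median = median(via_rlnorm), min = min(via_rlnorm)),
  theory     = c(mean = exp(mu + sigma^2 / 2), median = exp(mu), min = 0)
)
round(summary_table, 2)
```

The theoretical mean and median work out to roughly 2.49 and 2.25, in the same ballpark as the `y_hat` summary for the log-normal model, and none of the simulated stream counts can be negative.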

# Linear Regression in Stan

Modeling the mean and standard deviation of just one variable is all
well and good, but as scientists we’re usually more interested in
*effects*. To estimate effects, we’re going to need to add some
predictor variables to our model. To switch things up, let’s say we want
to predict the energy level of a song given the song’s loudness. First,
let’s take a look at the data:

ggplot(spotify2021, aes(x=loudness, y=energy)) +
  geom_point(alpha=.2) +
  theme_tidybayes()

This certainly looks promising! Let’s write a Stan program to see if this is the case:

data {
  int<lower=0> N;                  // the number of data points
  vector[N] x;                     // the loudness of each song
  vector<lower=0, upper=1>[N] y;   // the energy level of each song
}
parameters {
  real alpha;            // intercept
  real beta;             // slope
  real<lower=0> sigma;   // residual standard deviation
}
transformed parameters {
  vector[N] mu = alpha + beta*x;   // predicted mean energy for each song
}
model {
  alpha ~ normal(.5, .5);
  beta ~ normal(0, .1);
  sigma ~ normal(0, 1);
  y ~ normal(mu, sigma);
}
generated quantities {
  // newer Stan array syntax; older releases wrote: real y_hat[N]
  array[N] real y_hat = normal_rng(mu, sigma);
}

Hopefully by now most of this new model looks familiar: we’re modeling
energy as normally distributed with mean `mu` and standard deviation
`sigma`. However, now instead of estimating a single `mu`, we’re
calculating `mu` as a `transformed parameter` based on three things.
Unsurprisingly, `x` is the vector of loudness values for each track.
`alpha` is the intercept, which represents the mean energy level when
`loudness == 0`. And finally, `beta` is the slope, which represents the
average change in energy for every decibel increase in loudness. The
reason we declare `mu` as a `transformed parameter` instead of a regular
old `parameter` is that it makes sampling more efficient: by doing so,
we’re telling Stan that `mu` is just some combination of the other
parameters, so we don’t need to sample it directly (we can just sample
`alpha` and `beta`). I’ve assigned normal priors for each parameter
based on sheer intuition: hopefully the results won’t change much if
these are set slightly differently.

The last thing to note is that now we’re estimating a unique `y_hat` for
each individual data point. The reasoning behind this is that each data
point now has a unique prediction of energy (before, the estimates did
not depend on predictors).

energy_data <- list(N=nrow(spotify2021), x=spotify2021$loudness, y=spotify2021$energy)
energy_model <- cmdstan_model('2021-12-10-energy.stan') ## compile the model
energy <- energy_model$sample(data=energy_data)
energy

## Running MCMC with 4 chains, at most 20 in parallel...
##
## Chain 1 Iteration: 1 / 2000 [ 0%] (Warmup)
## Chain 2 Iteration: 1 / 2000 [ 0%] (Warmup)
## Chain 2 Iteration: 100 / 2000 [ 5%] (Warmup)
## Chain 3 Iteration: 1 / 2000 [ 0%] (Warmup)
## Chain 3 Iteration: 100 / 2000 [ 5%] (Warmup)
## Chain 4 Iteration: 1 / 2000 [ 0%] (Warmup)
## Chain 4 Iteration: 100 / 2000 [ 5%] (Warmup)
## Chain 1 Iteration: 100 / 2000 [ 5%] (Warmup)
## Chain 3 Iteration: 200 / 2000 [ 10%] (Warmup)
## Chain 4 Iteration: 200 / 2000 [ 10%] (Warmup)
## Chain 2 Iteration: 200 / 2000 [ 10%] (Warmup)
## Chain 3 Iteration: 300 / 2000 [ 15%] (Warmup)
## Chain 1 Iteration: 200 / 2000 [ 10%] (Warmup)
## Chain 3 Iteration: 400 / 2000 [ 20%] (Warmup)
## Chain 4 Iteration: 300 / 2000 [ 15%] (Warmup)
## Chain 2 Iteration: 300 / 2000 [ 15%] (Warmup)
## Chain 4 Iteration: 400 / 2000 [ 20%] (Warmup)
## Chain 1 Iteration: 300 / 2000 [ 15%] (Warmup)
## Chain 3 Iteration: 500 / 2000 [ 25%] (Warmup)
## Chain 2 Iteration: 400 / 2000 [ 20%] (Warmup)
## Chain 3 Iteration: 600 / 2000 [ 30%] (Warmup)
## Chain 1 Iteration: 400 / 2000 [ 20%] (Warmup)
## Chain 4 Iteration: 500 / 2000 [ 25%] (Warmup)
## Chain 2 Iteration: 500 / 2000 [ 25%] (Warmup)
## Chain 3 Iteration: 700 / 2000 [ 35%] (Warmup)
## Chain 1 Iteration: 500 / 2000 [ 25%] (Warmup)
## Chain 4 Iteration: 600 / 2000 [ 30%] (Warmup)
## Chain 2 Iteration: 600 / 2000 [ 30%] (Warmup)
## Chain 3 Iteration: 800 / 2000 [ 40%] (Warmup)
## Chain 4 Iteration: 700 / 2000 [ 35%] (Warmup)
## Chain 1 Iteration: 600 / 2000 [ 30%] (Warmup)
## Chain 2 Iteration: 700 / 2000 [ 35%] (Warmup)
## Chain 3 Iteration: 900 / 2000 [ 45%] (Warmup)
## Chain 4 Iteration: 800 / 2000 [ 40%] (Warmup)
## Chain 1 Iteration: 700 / 2000 [ 35%] (Warmup)
## Chain 2 Iteration: 800 / 2000 [ 40%] (Warmup)
## Chain 3 Iteration: 1000 / 2000 [ 50%] (Warmup)
## Chain 3 Iteration: 1001 / 2000 [ 50%] (Sampling)
## Chain 4 Iteration: 900 / 2000 [ 45%] (Warmup)
## Chain 1 Iteration: 800 / 2000 [ 40%] (Warmup)
## Chain 2 Iteration: 900 / 2000 [ 45%] (Warmup)
## Chain 4 Iteration: 1000 / 2000 [ 50%] (Warmup)
## Chain 4 Iteration: 1001 / 2000 [ 50%] (Sampling)
## Chain 1 Iteration: 900 / 2000 [ 45%] (Warmup)
## Chain 2 Iteration: 1000 / 2000 [ 50%] (Warmup)
## Chain 2 Iteration: 1001 / 2000 [ 50%] (Sampling)
## Chain 3 Iteration: 1100 / 2000 [ 55%] (Sampling)
## Chain 1 Iteration: 1000 / 2000 [ 50%] (Warmup)
## Chain 1 Iteration: 1001 / 2000 [ 50%] (Sampling)
## Chain 4 Iteration: 1100 / 2000 [ 55%] (Sampling)
## Chain 2 Iteration: 1100 / 2000 [ 55%] (Sampling)
## Chain 3 Iteration: 1200 / 2000 [ 60%] (Sampling)
## Chain 1 Iteration: 1100 / 2000 [ 55%] (Sampling)
## Chain 4 Iteration: 1200 / 2000 [ 60%] (Sampling)
## Chain 2 Iteration: 1200 / 2000 [ 60%] (Sampling)
## Chain 3 Iteration: 1300 / 2000 [ 65%] (Sampling)
## Chain 1 Iteration: 1200 / 2000 [ 60%] (Sampling)
## Chain 4 Iteration: 1300 / 2000 [ 65%] (Sampling)
## Chain 2 Iteration: 1300 / 2000 [ 65%] (Sampling)
## Chain 3 Iteration: 1400 / 2000 [ 70%] (Sampling)
## Chain 1 Iteration: 1300 / 2000 [ 65%] (Sampling)
## Chain 2 Iteration: 1400 / 2000 [ 70%] (Sampling)
## Chain 4 Iteration: 1400 / 2000 [ 70%] (Sampling)
## Chain 3 Iteration: 1500 / 2000 [ 75%] (Sampling)
## Chain 4 Iteration: 1500 / 2000 [ 75%] (Sampling)
## Chain 1 Iteration: 1400 / 2000 [ 70%] (Sampling)
## Chain 2 Iteration: 1500 / 2000 [ 75%] (Sampling)
## Chain 3 Iteration: 1600 / 2000 [ 80%] (Sampling)
## Chain 4 Iteration: 1600 / 2000 [ 80%] (Sampling)
## Chain 1 Iteration: 1500 / 2000 [ 75%] (Sampling)
## Chain 2 Iteration: 1600 / 2000 [ 80%] (Sampling)
## Chain 3 Iteration: 1700 / 2000 [ 85%] (Sampling)
## Chain 1 Iteration: 1600 / 2000 [ 80%] (Sampling)
## Chain 4 Iteration: 1700 / 2000 [ 85%] (Sampling)
## Chain 2 Iteration: 1700 / 2000 [ 85%] (Sampling)
## Chain 3 Iteration: 1800 / 2000 [ 90%] (Sampling)
## Chain 4 Iteration: 1800 / 2000 [ 90%] (Sampling)
## Chain 2 Iteration: 1800 / 2000 [ 90%] (Sampling)
## Chain 3 Iteration: 1900 / 2000 [ 95%] (Sampling)
## Chain 1 Iteration: 1700 / 2000 [ 85%] (Sampling)
## Chain 4 Iteration: 1900 / 2000 [ 95%] (Sampling)
## Chain 2 Iteration: 1900 / 2000 [ 95%] (Sampling)
## Chain 3 Iteration: 2000 / 2000 [100%] (Sampling)
## Chain 3 finished in 16.5 seconds.
## Chain 1 Iteration: 1800 / 2000 [ 90%] (Sampling)
## Chain 4 Iteration: 2000 / 2000 [100%] (Sampling)
## Chain 4 finished in 17.5 seconds.
## Chain 2 Iteration: 2000 / 2000 [100%] (Sampling)
## Chain 2 finished in 17.7 seconds.
## Chain 1 Iteration: 1900 / 2000 [ 95%] (Sampling)
## Chain 1 Iteration: 2000 / 2000 [100%] (Sampling)
## Chain 1 finished in 19.0 seconds.
##
## All 4 chains finished successfully.
## Mean chain execution time: 17.7 seconds.
## Total execution time: 19.1 seconds.
## variable mean median sd mad q5 q95 rhat ess_bulk ess_tail
## lp__ 15649.15 15649.50 1.20 1.04 15646.70 15650.50 1.00 1272 1456
## alpha 0.94 0.94 0.00 0.00 0.93 0.95 1.00 1541 1279
## beta 0.05 0.05 0.00 0.00 0.05 0.05 1.00 1615 1370
## sigma 0.11 0.11 0.00 0.00 0.11 0.12 1.00 1672 1556
## mu[1] 0.53 0.53 0.00 0.00 0.53 0.53 1.00 3470 2550
## mu[2] 0.55 0.55 0.00 0.00 0.54 0.55 1.00 3798 2745
## mu[3] 0.62 0.62 0.00 0.00 0.62 0.62 1.00 3817 2854
## mu[4] 0.77 0.77 0.00 0.00 0.76 0.77 1.00 1671 1721
## mu[5] 0.64 0.64 0.00 0.00 0.64 0.64 1.00 3293 2835
## mu[6] 0.45 0.45 0.00 0.00 0.44 0.45 1.00 2417 2182
##
## # showing 10 of 18804 rows (change via 'max_rows' argument or 'cmdstanr_max_rows' option)

The results show that there does appear to be a sizable increase in energy with loudness! Let’s plot the coefficients to make this more clear:

energy %>%
  gather_draws(alpha, beta, sigma) %>%
  ggplot(aes(x=.value)) +
  stat_halfeye(point_interval=median_hdi, normalize='panels') +
  xlab('Value') + ylab('Parameter') +
  facet_wrap(~ .variable, scales='free') +
  theme_tidybayes()

Since the coefficient `beta` is clearly greater than zero, we can say we
found an effect! If you’re skeptical, it might help to know that the
frequentist parameter values are extremely similar:

summary(lm(energy ~ loudness, spotify2021))

##
## Call:
## lm(formula = energy ~ loudness, data = spotify2021)
##
## Residuals:
## Min 1Q Median 3Q Max
## -0.39105 -0.06769 -0.00484 0.07276 0.48819
##
## Coefficients:
## Estimate Std. Error t value Pr(>|t|)
## (Intercept) 0.9410411 0.0038950 241.60 <2e-16 ***
## loudness 0.0491423 0.0005569 88.23 <2e-16 ***
## ---
## Signif. codes: 0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1
##
## Residual standard error: 0.1147 on 9398 degrees of freedom
## Multiple R-squared: 0.4531, Adjusted R-squared: 0.453
## F-statistic: 7785 on 1 and 9398 DF, p-value: < 2.2e-16

Let’s plot our regression line on top of the data:

energy_draws <- energy %>%
  spread_draws(mu[.row]) %>%
  mutate(x=spotify2021$loudness[.row])
ggplot(spotify2021, aes(x=loudness, y=energy)) +
  geom_point(alpha=.2) +
  stat_lineribbon(aes(x=x, y=mu), data=energy_draws,
                  .width=.95, show.legend=FALSE) +
  scale_fill_brewer() +
  theme_tidybayes()

If needed, we can also use the posterior distribution to compute other
quantities of interest. For example, let’s calculate the coefficient of
determination, `R^2`:

energy_R2 <- energy %>%
  spread_draws(mu[.row]) %>%
  left_join(tibble(.row=1:nrow(spotify2021), y=spotify2021$energy)) %>%
  group_by(.draw) %>%
  summarize(ss_total=sum((y-mean(y))^2),
            ss_residual=sum((y-mu)^2),
            R2=1 - ss_residual/ss_total)
ggplot(energy_R2, aes(x=R2)) +
  stat_halfeye(point_interval=median_hdi) +
  xlab('R^2') + ylab('Density') +
  theme_tidybayes()

This plot shows us that our model is pretty darn good: it explains about 45% of the variance in energy!
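
The key idea here is that `R^2` is computed once per posterior draw, so we end up with a whole distribution of `R^2` values rather than a single number. Here is a small base-R sketch of that bookkeeping. Everything in it is a stand-in: the data are simulated rather than taken from Spotify, and the “posterior draws” of `alpha` and `beta` are just random jitter around a least-squares fit, not output from Stan.

```r
set.seed(1)

# Simulated stand-in data (assumption: not the Spotify data)
n <- 200
x <- runif(n, -20, 0)
y <- 0.94 + 0.05 * x + rnorm(n, sd = 0.11)

# Stand-in "posterior draws": jitter around the least-squares estimates
fit <- lm(y ~ x)
n_draws <- 1000
alpha <- rnorm(n_draws, coef(fit)[1], 0.01)
beta <- rnorm(n_draws, coef(fit)[2], 0.001)

# One row of predicted means per draw (an n_draws x n matrix)
mu <- outer(alpha, rep(1, n)) + outer(beta, x)

ss_total <- sum((y - mean(y))^2)
ss_residual <- rowSums(sweep(mu, 2, y)^2)  # residual sum of squares, per draw
R2 <- 1 - ss_residual / ss_total

quantile(R2, c(0.05, 0.5, 0.95))
```

Each element of `R2` corresponds to one draw, which is why the plot above shows a density instead of a single point.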

## Entering the matrix

To round off this tutorial, let’s try to make our regression model a little more general. Right now, we only have one predictor variable coded into our model. What if we wanted to allow for more than one variable and interactions between variables? As we talked about in Pranjal’s fantastic tutorial on linear algebra, the way to achieve this is to use matrices. This might sound scary, but really the core idea is still the same.

Let’s go back to stream counts: presume that we want to know what makes the top 200 songs so successful. Is it their energy, their danceability, duration, or some combination of variables? To find out, let’s code a Stan model. To keep things simple, I’m going to ignore the skew in the data and fit a model with a normal likelihood.

data {
  int<lower=0> N;   // the number of data points
  int<lower=0> K;   // the number of regression coefficients
  matrix[N, K] X;   // the predictor variables
  vector[N] y;      // the outcome variable
}
parameters {
  real alpha;            // intercept
  vector[K] beta;        // one slope per predictor
  real<lower=0> sigma;   // residual standard deviation
}
transformed parameters {
  vector[N] mu = alpha + X*beta;   // predicted mean for each data point
}
model {
  alpha ~ normal(0, 1);
  beta ~ normal(0, 1);
  sigma ~ normal(0, 1);
  y ~ normal(mu, sigma);
}
generated quantities {
  // newer Stan array syntax; older releases wrote: real y_hat[N]
  array[N] real y_hat = normal_rng(mu, sigma);
}

As promised, there are only a few differences between this model and the
last. The most obvious difference is that while `x` used to be an
`N`-vector with one value for each data point, `X` is now an `N` by `K`
matrix with one row of `K` predictors for each data point. To make this
clear, I changed from little `x` to big `X` in the code. The other main
difference is that before we used to have a single scalar `beta`, which
represented the effect of `x` on `y`. But now that we have `K` different
`x`’s, `beta` is now a `K`-vector, where the `k`th element represents
the effect of the `k`th predictor variable on `y`. What’s nice about
Stan is that everything else works just as before! Stan recognizes that
`X*beta` is now a matrix-vector multiplication, and it can perform the
whole multiplication out with the same code. If you think it’s more
clear, you always have the option of writing out some loops for this
multiplication:

transformed parameters {
  vector[N] mu;
  // loop over data points
  for (i in 1:N) {
    mu[i] = alpha;
    // loop over predictor variables
    for (k in 1:K) {
      mu[i] = mu[i] + beta[k]*X[i,k];
    }
  }
}

While the other version of the code was a single line, this version is 7
lines of code! And the worst part is that even though this code is
longer, it’s actually *slower* to execute, since Stan can internally
optimize matrix multiplication but it can’t internally optimize these
sorts of loops. So unless you need to expand out the multiplication,
it’s generally better to use the shorter version.
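
If you want to convince yourself that the two versions compute the same thing, here is a small base-R sketch of the same arithmetic (R’s `%*%` plays the role of Stan’s matrix `*`). The dimensions and the values of `X`, `alpha`, and `beta` are made up for the example.

```r
set.seed(1)

N <- 5
K <- 2
X <- matrix(rnorm(N * K), nrow = N, ncol = K)  # made-up predictors
alpha <- 0.5
beta <- c(2, -1)

# Matrix version: one line
mu_matrix <- as.numeric(alpha + X %*% beta)

# Loop version: the same arithmetic written out
mu_loop <- numeric(N)
for (i in 1:N) {
  mu_loop[i] <- alpha
  for (k in 1:K) {
    mu_loop[i] <- mu_loop[i] + beta[k] * X[i, k]
  }
}

all.equal(mu_matrix, mu_loop)
```

Note that the inner update indexes row `i` of `X`, so each data point only ever uses its own row of predictors.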

Now that we have our model, let’s try to predict stream counts using duration and tempo:

streams_data <- list(N=nrow(spotify2021), K=2,
                     X=select(spotify2021, duration_ms, tempo),
                     y=spotify2021$streams)
streams_model <- cmdstan_model('2021-12-10-streams-glm.stan') ## compile the model
streams <- streams_model$sample(data=streams_data)
streams

## Running MCMC with 4 chains, at most 20 in parallel...
##
## Chain 1 Iteration: 1 / 2000 [ 0%] (Warmup)
## Chain 1 Iteration: 100 / 2000 [ 5%] (Warmup)
## Chain 2 Iteration: 1 / 2000 [ 0%] (Warmup)
## Chain 3 Iteration: 1 / 2000 [ 0%] (Warmup)
## Chain 4 Iteration: 1 / 2000 [ 0%] (Warmup)
## Chain 1 Iteration: 200 / 2000 [ 10%] (Warmup)
## Chain 1 Iteration: 300 / 2000 [ 15%] (Warmup)
## Chain 1 Iteration: 400 / 2000 [ 20%] (Warmup)
## Chain 1 Iteration: 500 / 2000 [ 25%] (Warmup)
## Chain 3 Iteration: 100 / 2000 [ 5%] (Warmup)
## Chain 1 Iteration: 600 / 2000 [ 30%] (Warmup)
## Chain 1 Iteration: 700 / 2000 [ 35%] (Warmup)
## Chain 1 Iteration: 800 / 2000 [ 40%] (Warmup)
## Chain 1 Iteration: 900 / 2000 [ 45%] (Warmup)
## Chain 1 Iteration: 1000 / 2000 [ 50%] (Warmup)
## Chain 1 Iteration: 1001 / 2000 [ 50%] (Sampling)
## Chain 3 Iteration: 200 / 2000 [ 10%] (Warmup)
## Chain 2 Iteration: 100 / 2000 [ 5%] (Warmup)
## Chain 1 Iteration: 1100 / 2000 [ 55%] (Sampling)
## Chain 3 Iteration: 300 / 2000 [ 15%] (Warmup)
## Chain 1 Iteration: 1200 / 2000 [ 60%] (Sampling)
## Chain 3 Iteration: 400 / 2000 [ 20%] (Warmup)
## Chain 1 Iteration: 1300 / 2000 [ 65%] (Sampling)
## Chain 1 Iteration: 1400 / 2000 [ 70%] (Sampling)
## Chain 3 Iteration: 500 / 2000 [ 25%] (Warmup)
## Chain 1 Iteration: 1500 / 2000 [ 75%] (Sampling)
## Chain 3 Iteration: 600 / 2000 [ 30%] (Warmup)
## Chain 1 Iteration: 1600 / 2000 [ 80%] (Sampling)
## Chain 4 Iteration: 100 / 2000 [ 5%] (Warmup)
## Chain 3 Iteration: 700 / 2000 [ 35%] (Warmup)
## Chain 1 Iteration: 1700 / 2000 [ 85%] (Sampling)
## Chain 1 Iteration: 1800 / 2000 [ 90%] (Sampling)
## Chain 3 Iteration: 800 / 2000 [ 40%] (Warmup)
## Chain 2 Iteration: 200 / 2000 [ 10%] (Warmup)
## Chain 1 Iteration: 1900 / 2000 [ 95%] (Sampling)
## Chain 3 Iteration: 900 / 2000 [ 45%] (Warmup)
## Chain 1 Iteration: 2000 / 2000 [100%] (Sampling)
## Chain 1 finished in 10.3 seconds.
## Chain 3 Iteration: 1000 / 2000 [ 50%] (Warmup)
## Chain 3 Iteration: 1001 / 2000 [ 50%] (Sampling)
## Chain 3 Iteration: 1100 / 2000 [ 55%] (Sampling)
## Chain 3 Iteration: 1200 / 2000 [ 60%] (Sampling)
## Chain 2 Iteration: 300 / 2000 [ 15%] (Warmup)
## Chain 3 Iteration: 1300 / 2000 [ 65%] (Sampling)
## Chain 3 Iteration: 1400 / 2000 [ 70%] (Sampling)
## Chain 3 Iteration: 1500 / 2000 [ 75%] (Sampling)
## Chain 2 Iteration: 400 / 2000 [ 20%] (Warmup)
## Chain 4 Iteration: 200 / 2000 [ 10%] (Warmup)
## Chain 3 Iteration: 1600 / 2000 [ 80%] (Sampling)
## Chain 3 Iteration: 1700 / 2000 [ 85%] (Sampling)
## Chain 3 Iteration: 1800 / 2000 [ 90%] (Sampling)
## Chain 3 Iteration: 1900 / 2000 [ 95%] (Sampling)
## Chain 3 Iteration: 2000 / 2000 [100%] (Sampling)
## Chain 3 finished in 24.6 seconds.
## Chain 4 Iteration: 300 / 2000 [ 15%] (Warmup)
## Chain 2 Iteration: 500 / 2000 [ 25%] (Warmup)
## Chain 4 Iteration: 400 / 2000 [ 20%] (Warmup)
## Chain 4 Iteration: 500 / 2000 [ 25%] (Warmup)
## Chain 4 Iteration: 600 / 2000 [ 30%] (Warmup)
## Chain 2 Iteration: 600 / 2000 [ 30%] (Warmup)
## Chain 4 Iteration: 700 / 2000 [ 35%] (Warmup)
## Chain 4 Iteration: 800 / 2000 [ 40%] (Warmup)
## Chain 2 Iteration: 700 / 2000 [ 35%] (Warmup)
## Chain 4 Iteration: 900 / 2000 [ 45%] (Warmup)
## Chain 2 Iteration: 800 / 2000 [ 40%] (Warmup)
## Chain 4 Iteration: 1000 / 2000 [ 50%] (Warmup)
## Chain 4 Iteration: 1001 / 2000 [ 50%] (Sampling)
## Chain 2 Iteration: 900 / 2000 [ 45%] (Warmup)
## Chain 4 Iteration: 1100 / 2000 [ 55%] (Sampling)
## Chain 2 Iteration: 1000 / 2000 [ 50%] (Warmup)
## Chain 2 Iteration: 1001 / 2000 [ 50%] (Sampling)
## Chain 2 Iteration: 1100 / 2000 [ 55%] (Sampling)
## Chain 2 Iteration: 1200 / 2000 [ 60%] (Sampling)
## Chain 4 Iteration: 1200 / 2000 [ 60%] (Sampling)
## Chain 2 Iteration: 1300 / 2000 [ 65%] (Sampling)
## Chain 2 Iteration: 1400 / 2000 [ 70%] (Sampling)
## Chain 2 Iteration: 1500 / 2000 [ 75%] (Sampling)
## Chain 2 Iteration: 1600 / 2000 [ 80%] (Sampling)
## Chain 2 Iteration: 1700 / 2000 [ 85%] (Sampling)
## Chain 2 Iteration: 1800 / 2000 [ 90%] (Sampling)
## Chain 2 Iteration: 1900 / 2000 [ 95%] (Sampling)
## Chain 2 Iteration: 2000 / 2000 [100%] (Sampling)
## Chain 2 finished in 119.5 seconds.
## Chain 4 Iteration: 1300 / 2000 [ 65%] (Sampling)
## Chain 4 Iteration: 1400 / 2000 [ 70%] (Sampling)
## Chain 4 Iteration: 1500 / 2000 [ 75%] (Sampling)
## Chain 4 Iteration: 1600 / 2000 [ 80%] (Sampling)
## Chain 4 Iteration: 1700 / 2000 [ 85%] (Sampling)
## Chain 4 Iteration: 1800 / 2000 [ 90%] (Sampling)
## Chain 4 Iteration: 1900 / 2000 [ 95%] (Sampling)
## Chain 4 Iteration: 2000 / 2000 [100%] (Sampling)
## Chain 4 finished in 247.5 seconds.
##
## All 4 chains finished successfully.
## Mean chain execution time: 100.5 seconds.
## Total execution time: 247.6 seconds.
## variable mean median sd mad q5 q95
## lp__ -5158838.53 -152326.00 8680417.44 210440.84 -20199705.00 -10386.00
## alpha 0.36 0.48 1.21 1.19 -1.45 1.93
## beta[1] 0.00 0.00 0.00 0.00 0.00 0.00
## beta[2] 0.20 0.01 0.51 0.17 -0.28 1.06
## sigma 1.79 1.50 1.09 0.91 0.65 3.32
## mu[1] -5.03 2.84 20.73 7.29 -40.00 14.76
## mu[2] 3.39 2.46 2.45 0.98 1.01 7.51
## mu[3] 4.02 2.35 4.42 1.72 -0.27 11.45
## mu[4] 3.73 2.13 4.42 1.91 -0.66 11.14
## mu[5] 7.77 2.66 13.40 4.58 -4.93 30.41
## rhat ess_bulk ess_tail
## 3.57 4 11
## 3.55 4 8
## 4.08 4 10
## 4.23 4 11
## 4.29 4 11
## 4.05 4 10
## 3.59 4 10
## 3.87 4 10
## 3.96 4 10
## 3.36 4 10
##
## # showing 10 of 18805 rows (change via 'max_rows' argument or 'cmdstanr_max_rows' option)

Oh, that doesn’t look good… what went wrong? In addition to taking a really long time to fit, we get some scary warnings, and the R-hat values are huge (between 3 and 4, when they should be close to 1)! Let’s look at the chains:

streams$draws(c('alpha', 'beta', 'sigma'), format='draws_df') %>%
  pivot_longer(alpha:sigma, names_to='.variable', values_to='.value') %>%
  ggplot(aes(x=.iteration, y=.value, color=factor(.chain))) +
  geom_line() + xlab('Iteration') + ylab('Value') +
  scale_color_discrete(name='Chain') +
  facet_grid(.variable ~ ., scales='free_y') +
  theme_tidybayes()
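
To connect a trace plot like this with the R-hat column in the summary, here is a minimal sketch of the classic R-hat statistic in base R. This is a simplified version, not the exact estimator printed in the summary (as far as I know, that one is a rank-normalized split variant), and `rhat_basic`, `good`, and `stuck` are made-up names filled with simulated draws rather than draws from our model. Still, it shows why chains that settle at different values give R-hat values far above 1:

```r
# Classic Gelman-Rubin R-hat for a single parameter.
# draws: matrix with one row per iteration and one column per chain.
rhat_basic <- function(draws) {
  n <- nrow(draws)
  B <- n * var(colMeans(draws))        # between-chain variance
  W <- mean(apply(draws, 2, var))      # average within-chain variance
  var_hat <- (n - 1) / n * W + B / n   # pooled estimate of the posterior variance
  sqrt(var_hat / W)
}

set.seed(2021)
n_iter <- 1000

# Four simulated chains exploring the same distribution
good <- matrix(rnorm(n_iter * 4), ncol = 4)

# Four simulated chains stuck at different values (hypothetical means 0, 3, 6, 9)
stuck <- sapply(c(0, 3, 6, 9), function(m) rnorm(n_iter, mean = m))

rhat_good <- rhat_basic(good)     # close to 1
rhat_stuck <- rhat_basic(stuck)   # far above 1
print(c(good = rhat_good, stuck = rhat_stuck))
```

When the chains disagree about where the posterior is, the between-chain variance dwarfs the within-chain variance, and the ratio blows up.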

These are some bad-looking chains! To get an idea of what went wrong, let’s take another quick look at our data:

summary(streams_data$X)

## duration_ms tempo
## Min. : 52062 Min. : 40.32
## 1st Qu.:167916 1st Qu.: 97.69
## Median :195873 Median :121.97
## Mean :198911 Mean :121.97
## 3rd Qu.:221980 3rd Qu.:142.59
## Max. :690732 Max. :208.92

One thing stands out to me: the scale of `duration_ms` is much, much
larger than that of `tempo`. This can actually cause problems in prior
specification and model fitting, because Stan doesn’t know that since
`duration_ms` is much larger, its `beta` weight should be very small,
whereas the `beta` weight for `tempo` will probably be larger. Although
we could z-score all of our variables to have the same scale for all of
our betas, it would be nice to retain the scales of the variables, since
they are meaningful (e.g., beats per minute is more inherently
meaningful than standard deviations of beats per minute). So, let’s try
our model again, but using duration in seconds rather than milliseconds:

spotify2021$duration_s <- spotify2021$duration_ms / 1000
streams_data <- list(N=nrow(spotify2021), K=2,
                     X=select(spotify2021, duration_s, tempo),
                     y=spotify2021$streams)
streams <- streams_model$sample(data=streams_data)
streams

## Running MCMC with 4 chains, at most 20 in parallel...
##
## Chain 1 Iteration: 1 / 2000 [ 0%] (Warmup)
## Chain 2 Iteration: 1 / 2000 [ 0%] (Warmup)
## Chain 3 Iteration: 1 / 2000 [ 0%] (Warmup)
## Chain 4 Iteration: 1 / 2000 [ 0%] (Warmup)
## Chain 2 Iteration: 100 / 2000 [ 5%] (Warmup)
## Chain 3 Iteration: 100 / 2000 [ 5%] (Warmup)
## Chain 3 Iteration: 200 / 2000 [ 10%] (Warmup)
## Chain 2 Iteration: 200 / 2000 [ 10%] (Warmup)
## Chain 3 Iteration: 300 / 2000 [ 15%] (Warmup)
## Chain 1 Iteration: 100 / 2000 [ 5%] (Warmup)
## Chain 4 Iteration: 100 / 2000 [ 5%] (Warmup)
## Chain 3 Iteration: 400 / 2000 [ 20%] (Warmup)
## Chain 2 Iteration: 300 / 2000 [ 15%] (Warmup)
## Chain 3 Iteration: 500 / 2000 [ 25%] (Warmup)
## Chain 2 Iteration: 400 / 2000 [ 20%] (Warmup)
## Chain 3 Iteration: 600 / 2000 [ 30%] (Warmup)
## Chain 1 Iteration: 200 / 2000 [ 10%] (Warmup)
## Chain 3 Iteration: 700 / 2000 [ 35%] (Warmup)
## Chain 3 Iteration: 800 / 2000 [ 40%] (Warmup)
## Chain 2 Iteration: 500 / 2000 [ 25%] (Warmup)
## Chain 3 Iteration: 900 / 2000 [ 45%] (Warmup)
## Chain 1 Iteration: 300 / 2000 [ 15%] (Warmup)
## Chain 4 Iteration: 200 / 2000 [ 10%] (Warmup)
## Chain 2 Iteration: 600 / 2000 [ 30%] (Warmup)
## Chain 3 Iteration: 1000 / 2000 [ 50%] (Warmup)
## Chain 3 Iteration: 1001 / 2000 [ 50%] (Sampling)
## Chain 1 Iteration: 400 / 2000 [ 20%] (Warmup)
## Chain 2 Iteration: 700 / 2000 [ 35%] (Warmup)
## Chain 3 Iteration: 1100 / 2000 [ 55%] (Sampling)
## Chain 4 Iteration: 300 / 2000 [ 15%] (Warmup)
## Chain 2 Iteration: 800 / 2000 [ 40%] (Warmup)
## Chain 1 Iteration: 500 / 2000 [ 25%] (Warmup)
## Chain 3 Iteration: 1200 / 2000 [ 60%] (Sampling)
## Chain 4 Iteration: 400 / 2000 [ 20%] (Warmup)
## Chain 2 Iteration: 900 / 2000 [ 45%] (Warmup)
## Chain 1 Iteration: 600 / 2000 [ 30%] (Warmup)
## Chain 3 Iteration: 1300 / 2000 [ 65%] (Sampling)
## Chain 2 Iteration: 1000 / 2000 [ 50%] (Warmup)
## Chain 2 Iteration: 1001 / 2000 [ 50%] (Sampling)
## Chain 4 Iteration: 500 / 2000 [ 25%] (Warmup)
## Chain 1 Iteration: 700 / 2000 [ 35%] (Warmup)
## Chain 3 Iteration: 1400 / 2000 [ 70%] (Sampling)
## Chain 1 Iteration: 800 / 2000 [ 40%] (Warmup)
## Chain 4 Iteration: 600 / 2000 [ 30%] (Warmup)
## Chain 2 Iteration: 1100 / 2000 [ 55%] (Sampling)
## Chain 3 Iteration: 1500 / 2000 [ 75%] (Sampling)
## Chain 1 Iteration: 900 / 2000 [ 45%] (Warmup)
## Chain 4 Iteration: 700 / 2000 [ 35%] (Warmup)
## Chain 2 Iteration: 1200 / 2000 [ 60%] (Sampling)
## Chain 3 Iteration: 1600 / 2000 [ 80%] (Sampling)
## Chain 4 Iteration: 800 / 2000 [ 40%] (Warmup)
## Chain 1 Iteration: 1000 / 2000 [ 50%] (Warmup)
## Chain 1 Iteration: 1001 / 2000 [ 50%] (Sampling)
## Chain 2 Iteration: 1300 / 2000 [ 65%] (Sampling)
## Chain 3 Iteration: 1700 / 2000 [ 85%] (Sampling)
## Chain 4 Iteration: 900 / 2000 [ 45%] (Warmup)
## Chain 1 Iteration: 1100 / 2000 [ 55%] (Sampling)
## Chain 2 Iteration: 1400 / 2000 [ 70%] (Sampling)
## Chain 3 Iteration: 1800 / 2000 [ 90%] (Sampling)
## Chain 4 Iteration: 1000 / 2000 [ 50%] (Warmup)
## Chain 4 Iteration: 1001 / 2000 [ 50%] (Sampling)
## Chain 1 Iteration: 1200 / 2000 [ 60%] (Sampling)
## Chain 4 Iteration: 1100 / 2000 [ 55%] (Sampling)
## Chain 2 Iteration: 1500 / 2000 [ 75%] (Sampling)
## Chain 3 Iteration: 1900 / 2000 [ 95%] (Sampling)
## Chain 1 Iteration: 1300 / 2000 [ 65%] (Sampling)
## Chain 4 Iteration: 1200 / 2000 [ 60%] (Sampling)
## Chain 3 Iteration: 2000 / 2000 [100%] (Sampling)
## Chain 3 finished in 26.5 seconds.
## Chain 2 Iteration: 1600 / 2000 [ 80%] (Sampling)
## Chain 1 Iteration: 1400 / 2000 [ 70%] (Sampling)
## Chain 4 Iteration: 1300 / 2000 [ 65%] (Sampling)
## Chain 2 Iteration: 1700 / 2000 [ 85%] (Sampling)
## Chain 1 Iteration: 1500 / 2000 [ 75%] (Sampling)
## Chain 4 Iteration: 1400 / 2000 [ 70%] (Sampling)
## Chain 2 Iteration: 1800 / 2000 [ 90%] (Sampling)
## Chain 1 Iteration: 1600 / 2000 [ 80%] (Sampling)
## Chain 4 Iteration: 1500 / 2000 [ 75%] (Sampling)
## Chain 2 Iteration: 1900 / 2000 [ 95%] (Sampling)
## Chain 1 Iteration: 1700 / 2000 [ 85%] (Sampling)
## Chain 4 Iteration: 1600 / 2000 [ 80%] (Sampling)
## Chain 1 Iteration: 1800 / 2000 [ 90%] (Sampling)
## Chain 2 Iteration: 2000 / 2000 [100%] (Sampling)
## Chain 2 finished in 34.5 seconds.
## Chain 4 Iteration: 1700 / 2000 [ 85%] (Sampling)
## Chain 1 Iteration: 1900 / 2000 [ 95%] (Sampling)
## Chain 4 Iteration: 1800 / 2000 [ 90%] (Sampling)
## Chain 1 Iteration: 2000 / 2000 [100%] (Sampling)
## Chain 1 finished in 38.0 seconds.
## Chain 4 Iteration: 1900 / 2000 [ 95%] (Sampling)
## Chain 4 Iteration: 2000 / 2000 [100%] (Sampling)
## Chain 4 finished in 41.1 seconds.
##
## All 4 chains finished successfully.
## Mean chain execution time: 35.0 seconds.
## Total execution time: 41.2 seconds.
## variable mean median sd mad q5 q95 rhat ess_bulk
## lp__ -10279.87 -10279.50 1.37 1.19 -10282.50 -10278.30 1.01 1298
## alpha 2.00 1.99 0.12 0.12 1.81 2.19 1.00 1202
## beta[1] 0.00 0.00 0.00 0.00 0.00 0.00 1.00 1656
## beta[2] 0.00 0.00 0.00 0.00 0.00 0.00 1.00 1814
## sigma 1.81 1.81 0.01 0.01 1.79 1.83 1.00 1291
## mu[1] 2.65 2.65 0.04 0.04 2.59 2.71 1.00 2171
## mu[2] 2.54 2.54 0.02 0.02 2.51 2.57 1.00 2477
## mu[3] 2.51 2.51 0.02 0.02 2.47 2.55 1.00 1731
## mu[4] 2.42 2.42 0.04 0.04 2.36 2.48 1.01 1300
## mu[5] 2.62 2.62 0.02 0.02 2.58 2.66 1.00 3444
## ess_tail
## 1598
## 1268
## 1867
## 1965
## 1627
## 2466
## 2482
## 2296
## 1559
## 2866
##
## # showing 10 of 18805 rows (change via 'max_rows' argument or 'cmdstanr_max_rows' option)

Much better! Let’s plot the model parameters:

streams$draws(c('alpha', 'beta', 'sigma'), format='draws_df') %>%
  pivot_longer(alpha:sigma, names_to='.variable', values_to='.value') %>%
  ggplot(aes(x=.value)) +
  stat_halfeye(point_interval=median_hdi, normalize='panels') +
  xlab('Value') + ylab('Parameter') +
  facet_wrap(~ .variable, scales='free') +
  theme_tidybayes()

These coefficients might look small (both round to 0.00 in the summary
above), but the posteriors for both `beta[1]` and `beta[2]` appear to
exclude 0. Since `beta[1]` is the coefficient for `duration_s`, it says
that every one-second increase in the duration of a track is associated
with an increase of `.001` * one million = `1000` streams on average!
Likewise, `beta[2]` says that an increase of one beat per minute in
tempo is associated with an increase of `3000` streams.
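
The unit conversions in this section follow a simple rule: dividing a predictor by a constant multiplies its coefficient by that same constant, and the model’s predictions stay the same. Here is a minimal sketch of that rule on simulated data (not the Spotify data), using `lm` instead of Stan just to keep it fast. The variable names and the simulated effect size are made up for illustration:

```r
set.seed(1)
n <- 200

# Simulated durations in milliseconds and an outcome with a tiny per-ms slope
duration_ms <- runif(n, 50000, 700000)
y <- 2 + 1e-6 * duration_ms + rnorm(n, sd = 0.5)

# Fit the same relationship in milliseconds and in seconds
fit_ms <- lm(y ~ duration_ms)
duration_s <- duration_ms / 1000
fit_s <- lm(y ~ duration_s)

# The per-second slope is 1000 times the per-millisecond slope
ratio <- unname(coef(fit_s)["duration_s"] / coef(fit_ms)["duration_ms"])

# The fitted values are identical either way
same_fit <- isTRUE(all.equal(unname(fitted(fit_ms)), unname(fitted(fit_s))))

print(ratio)
print(same_fit)
```

With least squares the two fits are equivalent. In our Stan model they are not quite equivalent, because the `normal(0, 1)` prior on `beta` means something different on each scale, which is why the choice of units mattered for sampling.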

## Design matrices and you

It certainly looks like duration and tempo have some effect on stream counts in the top 200. Could it be that these effects interact with each other? In other words, could the effect of duration depend on tempo? It’s easy to imagine that longer songs could get boring if they are too slow, or perhaps they are overstimulating if they are too fast.

To find out, let’s run one more model. It turns out that we don’t even
need to code a new Stan program; we can use the last one! All we need to
do is set up our input to Stan in a way that allows for interaction
terms. We could do this manually, but there’s a nice trick to do it much
more easily: use the function `model.matrix`. You can think of
`model.matrix` as being like the `lm` function, except that it produces
a design matrix `X` instead of fitting a linear model. It takes a
formula representing the variables to include in your model and a
dataframe including all of the relevant variables, and it produces an
`N` by `K` matrix that we can give to Stan. Let’s take a look at our old
design matrix:

head(model.matrix(streams ~ duration_s + tempo, data=spotify2021))

## (Intercept) duration_s tempo
## 1 1 279.204 121.002
## 2 1 190.779 115.884
## 3 1 173.711 109.978
## 4 1 140.526 90.989
## 5 1 195.429 140.042
## 6 1 205.090 109.928

This matrix has a column of ones for the intercept, a column for
`duration_s`, and a column for `tempo`. But since we have the intercept
`alpha` defined separately in our model, we can get rid of the intercept
column using the notation `0 + ...`.

head(model.matrix(streams ~ 0 + duration_s + tempo, data=spotify2021))

## duration_s tempo
## 1 279.204 121.002
## 2 190.779 115.884
## 3 173.711 109.978
## 4 140.526 90.989
## 5 195.429 140.042
## 6 205.090 109.928

Finally, let’s add in an interaction term:

X <- model.matrix(streams ~ 0 + duration_s * tempo, data=spotify2021)
head(X)

## duration_s tempo duration_s:tempo
## 1 279.204 121.002 33784.24
## 2 190.779 115.884 22108.23
## 3 173.711 109.978 19104.39
## 4 140.526 90.989 12786.32
## 5 195.429 140.042 27368.27
## 6 205.090 109.928 22545.13

We now have two columns for our two predictor variables and one column for their two-way interaction. Let’s see what happens when we run the model:

streams_data <- list(N=nrow(spotify2021), K=ncol(X),
                     X=X, y=spotify2021$streams)
streams_model <- cmdstan_model('2021-12-10-streams-glm.stan')  ## compile the model
streams <- streams_model$sample(data=streams_data)
streams

## Running MCMC with 4 chains, at most 20 in parallel...
##
## Chain 1 Iteration: 1 / 2000 [ 0%] (Warmup)
## Chain 2 Iteration: 1 / 2000 [ 0%] (Warmup)
## Chain 3 Iteration: 1 / 2000 [ 0%] (Warmup)
## Chain 4 Iteration: 1 / 2000 [ 0%] (Warmup)
## Chain 2 Iteration: 100 / 2000 [ 5%] (Warmup)
## Chain 3 Iteration: 100 / 2000 [ 5%] (Warmup)
## Chain 4 Iteration: 100 / 2000 [ 5%] (Warmup)
## Chain 3 Iteration: 200 / 2000 [ 10%] (Warmup)
## Chain 4 Iteration: 200 / 2000 [ 10%] (Warmup)
## Chain 1 Iteration: 100 / 2000 [ 5%] (Warmup)
## Chain 4 Iteration: 300 / 2000 [ 15%] (Warmup)
## Chain 2 Iteration: 200 / 2000 [ 10%] (Warmup)
## Chain 3 Iteration: 300 / 2000 [ 15%] (Warmup)
## Chain 2 Iteration: 300 / 2000 [ 15%] (Warmup)
## Chain 3 Iteration: 400 / 2000 [ 20%] (Warmup)
## Chain 4 Iteration: 400 / 2000 [ 20%] (Warmup)
## Chain 4 Iteration: 500 / 2000 [ 25%] (Warmup)
## Chain 1 Iteration: 200 / 2000 [ 10%] (Warmup)
## Chain 2 Iteration: 400 / 2000 [ 20%] (Warmup)
## Chain 4 Iteration: 600 / 2000 [ 30%] (Warmup)
## Chain 2 Iteration: 500 / 2000 [ 25%] (Warmup)
## Chain 3 Iteration: 500 / 2000 [ 25%] (Warmup)
## Chain 4 Iteration: 700 / 2000 [ 35%] (Warmup)
## Chain 1 Iteration: 300 / 2000 [ 15%] (Warmup)
## Chain 2 Iteration: 600 / 2000 [ 30%] (Warmup)
## Chain 1 Iteration: 400 / 2000 [ 20%] (Warmup)
## Chain 1 Iteration: 500 / 2000 [ 25%] (Warmup)
## Chain 4 Iteration: 800 / 2000 [ 40%] (Warmup)
## Chain 1 Iteration: 600 / 2000 [ 30%] (Warmup)
## Chain 3 Iteration: 600 / 2000 [ 30%] (Warmup)
## Chain 4 Iteration: 900 / 2000 [ 45%] (Warmup)
## Chain 2 Iteration: 700 / 2000 [ 35%] (Warmup)
## Chain 4 Iteration: 1000 / 2000 [ 50%] (Warmup)
## Chain 4 Iteration: 1001 / 2000 [ 50%] (Sampling)
## Chain 3 Iteration: 700 / 2000 [ 35%] (Warmup)
## Chain 2 Iteration: 800 / 2000 [ 40%] (Warmup)
## Chain 3 Iteration: 800 / 2000 [ 40%] (Warmup)
## Chain 2 Iteration: 900 / 2000 [ 45%] (Warmup)
## Chain 2 Iteration: 1000 / 2000 [ 50%] (Warmup)
## Chain 2 Iteration: 1001 / 2000 [ 50%] (Sampling)
## Chain 3 Iteration: 900 / 2000 [ 45%] (Warmup)
## Chain 2 Iteration: 1100 / 2000 [ 55%] (Sampling)
## Chain 2 Iteration: 1200 / 2000 [ 60%] (Sampling)
## Chain 3 Iteration: 1000 / 2000 [ 50%] (Warmup)
## Chain 3 Iteration: 1001 / 2000 [ 50%] (Sampling)
## Chain 2 Iteration: 1300 / 2000 [ 65%] (Sampling)
## Chain 3 Iteration: 1100 / 2000 [ 55%] (Sampling)
## Chain 2 Iteration: 1400 / 2000 [ 70%] (Sampling)
## Chain 3 Iteration: 1200 / 2000 [ 60%] (Sampling)
## Chain 2 Iteration: 1500 / 2000 [ 75%] (Sampling)
## Chain 2 Iteration: 1600 / 2000 [ 80%] (Sampling)
## Chain 3 Iteration: 1300 / 2000 [ 65%] (Sampling)
## Chain 1 Iteration: 700 / 2000 [ 35%] (Warmup)
## Chain 2 Iteration: 1700 / 2000 [ 85%] (Sampling)
## Chain 3 Iteration: 1400 / 2000 [ 70%] (Sampling)
## Chain 4 Iteration: 1100 / 2000 [ 55%] (Sampling)
## Chain 2 Iteration: 1800 / 2000 [ 90%] (Sampling)
## Chain 3 Iteration: 1500 / 2000 [ 75%] (Sampling)
## Chain 2 Iteration: 1900 / 2000 [ 95%] (Sampling)
## Chain 3 Iteration: 1600 / 2000 [ 80%] (Sampling)
## Chain 2 Iteration: 2000 / 2000 [100%] (Sampling)
## Chain 2 finished in 13.0 seconds.
## Chain 3 Iteration: 1700 / 2000 [ 85%] (Sampling)
## Chain 3 Iteration: 1800 / 2000 [ 90%] (Sampling)
## Chain 3 Iteration: 1900 / 2000 [ 95%] (Sampling)
## Chain 3 Iteration: 2000 / 2000 [100%] (Sampling)
## Chain 3 finished in 15.9 seconds.
## Chain 4 Iteration: 1200 / 2000 [ 60%] (Sampling)
## Chain 1 Iteration: 800 / 2000 [ 40%] (Warmup)
## Chain 4 Iteration: 1300 / 2000 [ 65%] (Sampling)
## Chain 4 Iteration: 1400 / 2000 [ 70%] (Sampling)
## Chain 1 Iteration: 900 / 2000 [ 45%] (Warmup)
## Chain 4 Iteration: 1500 / 2000 [ 75%] (Sampling)
## Chain 1 Iteration: 1000 / 2000 [ 50%] (Warmup)
## Chain 1 Iteration: 1001 / 2000 [ 50%] (Sampling)
## Chain 4 Iteration: 1600 / 2000 [ 80%] (Sampling)
## Chain 4 Iteration: 1700 / 2000 [ 85%] (Sampling)
## Chain 1 Iteration: 1100 / 2000 [ 55%] (Sampling)
## Chain 4 Iteration: 1800 / 2000 [ 90%] (Sampling)
## Chain 4 Iteration: 1900 / 2000 [ 95%] (Sampling)
## Chain 1 Iteration: 1200 / 2000 [ 60%] (Sampling)
## Chain 4 Iteration: 2000 / 2000 [100%] (Sampling)
## Chain 4 finished in 84.1 seconds.
## Chain 1 Iteration: 1300 / 2000 [ 65%] (Sampling)
## Chain 1 Iteration: 1400 / 2000 [ 70%] (Sampling)
## Chain 1 Iteration: 1500 / 2000 [ 75%] (Sampling)
## Chain 1 Iteration: 1600 / 2000 [ 80%] (Sampling)
## Chain 1 Iteration: 1700 / 2000 [ 85%] (Sampling)
## Chain 1 Iteration: 1800 / 2000 [ 90%] (Sampling)
## Chain 1 Iteration: 1900 / 2000 [ 95%] (Sampling)
## Chain 1 Iteration: 2000 / 2000 [100%] (Sampling)
## Chain 1 finished in 193.7 seconds.
##
## All 4 chains finished successfully.
## Mean chain execution time: 76.7 seconds.
## Total execution time: 193.8 seconds.
## variable mean median sd mad q5 q95 rhat ess_bulk
## lp__ -10420.09 -10276.90 490.69 10.23 -10679.73 -10269.10 2.15 5
## alpha 0.72 1.24 1.25 0.45 -1.40 1.82 3.37 4
## beta[1] 0.01 0.01 0.01 0.00 0.00 0.02 2.43 4
## beta[2] 0.01 0.01 0.01 0.00 0.01 0.03 2.43 4
## beta[3] 0.00 0.00 0.00 0.00 0.00 0.00 2.41 4
## sigma 1.95 1.83 0.32 0.04 1.79 2.25 2.14 5
## mu[1] 2.63 2.63 0.06 0.04 2.56 2.71 1.13 20
## mu[2] 2.54 2.55 0.03 0.02 2.50 2.58 1.26 11
## mu[3] 2.50 2.51 0.05 0.04 2.42 2.56 1.66 6
## mu[4] 2.34 2.38 0.14 0.07 2.10 2.48 2.19 5
## ess_tail
## 18
## 8
## 38
## 29
## 28
## 18
## 44
## 31
## 26
## 29
##
## # showing 10 of 18806 rows (change via 'max_rows' argument or 'cmdstanr_max_rows' option)

Ah, the model has *yet again* failed to converge. Looking at the above
design matrix, it looks like a similar thing could be going on: the
interaction variable is much larger than the other two variables. This
makes sense, because interactions are just multiplications of other
variables (the data for `duration_s:tempo` is just
`duration_s * tempo`). So what can we do? Well, we have a couple of
options. As we mentioned before, we could z-score our variables, but
that removes all of the information about the scales, which makes it
hard to tell the size of the effects. So we’re going to scale down both
`duration` (to minutes) and `tempo` (to beats per second) individually
to make sure that the interaction doesn’t blow out of proportion.
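
To make the scale problem concrete, here is a small sketch that rebuilds the design matrix from just the six rows printed earlier. The data frame `toy` and the other names are made up for illustration, and six rows obviously don’t tell us the range of the full dataset. It uses a one-sided formula, since no outcome is needed to build a design matrix:

```r
# First six rows of duration_s and tempo, copied from the output above
toy <- data.frame(
  duration_s = c(279.204, 190.779, 173.711, 140.526, 195.429, 205.090),
  tempo      = c(121.002, 115.884, 109.978,  90.989, 140.042, 109.928)
)

# Design matrix with both predictors and their interaction, no intercept
X_toy <- model.matrix(~ 0 + duration_s * tempo, data = toy)

# The interaction column is the elementwise product of the two predictors
is_product <- isTRUE(all.equal(unname(X_toy[, "duration_s:tempo"]),
                               toy$duration_s * toy$tempo))

# Rescale to minutes and beats per second, then rebuild the matrix
toy$duration_m <- toy$duration_s / 60
toy$tempo_bps  <- toy$tempo / 60
X_small <- model.matrix(~ 0 + duration_m * tempo_bps, data = toy)

# Largest value in each column before and after rescaling
print(is_product)
print(apply(X_toy, 2, max))
print(apply(X_small, 2, max))
```

Dividing each predictor by 60 divides the interaction column by 3600, so all three columns end up within an order of magnitude of each other.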

spotify2021$duration_m <- spotify2021$duration_s / 60
spotify2021$tempo_bps <- spotify2021$tempo / 60
X <- model.matrix(streams ~ 0 + duration_m * tempo_bps, data=spotify2021)
head(X)
streams_data <- list(N=nrow(spotify2021), K=ncol(X),
                     X=X, y=spotify2021$streams)
streams <- streams_model$sample(data=streams_data)
streams

## duration_m tempo_bps duration_m:tempo_bps
## 1 4.653400 2.016700 9.384512
## 2 3.179650 1.931400 6.141176
## 3 2.895183 1.832967 5.306775
## 4 2.342100 1.516483 3.551756
## 5 3.257150 2.334033 7.602297
## 6 3.418167 1.832133 6.262537
## Running MCMC with 4 chains, at most 20 in parallel...
##
## Chain 1 Iteration: 1 / 2000 [ 0%] (Warmup)
## Chain 2 Iteration: 1 / 2000 [ 0%] (Warmup)
## Chain 3 Iteration: 1 / 2000 [ 0%] (Warmup)
## Chain 4 Iteration: 1 / 2000 [ 0%] (Warmup)
## Chain 1 Iteration: 100 / 2000 [ 5%] (Warmup)
## Chain 2 Iteration: 100 / 2000 [ 5%] (Warmup)
## Chain 4 Iteration: 100 / 2000 [ 5%] (Warmup)
## Chain 3 Iteration: 100 / 2000 [ 5%] (Warmup)
## Chain 2 Iteration: 200 / 2000 [ 10%] (Warmup)
## Chain 3 Iteration: 200 / 2000 [ 10%] (Warmup)
## Chain 4 Iteration: 200 / 2000 [ 10%] (Warmup)
## Chain 1 Iteration: 200 / 2000 [ 10%] (Warmup)
## Chain 2 Iteration: 300 / 2000 [ 15%] (Warmup)
## Chain 1 Iteration: 300 / 2000 [ 15%] (Warmup)
## Chain 3 Iteration: 300 / 2000 [ 15%] (Warmup)
## Chain 4 Iteration: 300 / 2000 [ 15%] (Warmup)
## Chain 2 Iteration: 400 / 2000 [ 20%] (Warmup)
## Chain 1 Iteration: 400 / 2000 [ 20%] (Warmup)
## Chain 3 Iteration: 400 / 2000 [ 20%] (Warmup)
## Chain 4 Iteration: 400 / 2000 [ 20%] (Warmup)
## Chain 2 Iteration: 500 / 2000 [ 25%] (Warmup)
## Chain 1 Iteration: 500 / 2000 [ 25%] (Warmup)
## Chain 3 Iteration: 500 / 2000 [ 25%] (Warmup)
## Chain 4 Iteration: 500 / 2000 [ 25%] (Warmup)
## Chain 1 Iteration: 600 / 2000 [ 30%] (Warmup)
## Chain 2 Iteration: 600 / 2000 [ 30%] (Warmup)
## Chain 3 Iteration: 600 / 2000 [ 30%] (Warmup)
## Chain 4 Iteration: 600 / 2000 [ 30%] (Warmup)
## Chain 1 Iteration: 700 / 2000 [ 35%] (Warmup)
## Chain 2 Iteration: 700 / 2000 [ 35%] (Warmup)
## Chain 4 Iteration: 700 / 2000 [ 35%] (Warmup)
## Chain 3 Iteration: 700 / 2000 [ 35%] (Warmup)
## Chain 1 Iteration: 800 / 2000 [ 40%] (Warmup)
## Chain 4 Iteration: 800 / 2000 [ 40%] (Warmup)
## Chain 2 Iteration: 800 / 2000 [ 40%] (Warmup)
## Chain 3 Iteration: 800 / 2000 [ 40%] (Warmup)
## Chain 1 Iteration: 900 / 2000 [ 45%] (Warmup)
## Chain 4 Iteration: 900 / 2000 [ 45%] (Warmup)
## Chain 3 Iteration: 900 / 2000 [ 45%] (Warmup)
## Chain 2 Iteration: 900 / 2000 [ 45%] (Warmup)
## Chain 1 Iteration: 1000 / 2000 [ 50%] (Warmup)
## Chain 1 Iteration: 1001 / 2000 [ 50%] (Sampling)
## Chain 3 Iteration: 1000 / 2000 [ 50%] (Warmup)
## Chain 2 Iteration: 1000 / 2000 [ 50%] (Warmup)
## Chain 3 Iteration: 1001 / 2000 [ 50%] (Sampling)
## Chain 2 Iteration: 1001 / 2000 [ 50%] (Sampling)
## Chain 4 Iteration: 1000 / 2000 [ 50%] (Warmup)
## Chain 4 Iteration: 1001 / 2000 [ 50%] (Sampling)
## Chain 1 Iteration: 1100 / 2000 [ 55%] (Sampling)
## Chain 3 Iteration: 1100 / 2000 [ 55%] (Sampling)
## Chain 2 Iteration: 1100 / 2000 [ 55%] (Sampling)
## Chain 4 Iteration: 1100 / 2000 [ 55%] (Sampling)
## Chain 3 Iteration: 1200 / 2000 [ 60%] (Sampling)
## Chain 1 Iteration: 1200 / 2000 [ 60%] (Sampling)
## Chain 2 Iteration: 1200 / 2000 [ 60%] (Sampling)
## Chain 4 Iteration: 1200 / 2000 [ 60%] (Sampling)
## Chain 3 Iteration: 1300 / 2000 [ 65%] (Sampling)
## Chain 1 Iteration: 1300 / 2000 [ 65%] (Sampling)
## Chain 2 Iteration: 1300 / 2000 [ 65%] (Sampling)
## Chain 4 Iteration: 1300 / 2000 [ 65%] (Sampling)
## Chain 3 Iteration: 1400 / 2000 [ 70%] (Sampling)
## Chain 1 Iteration: 1400 / 2000 [ 70%] (Sampling)
## Chain 2 Iteration: 1400 / 2000 [ 70%] (Sampling)
## Chain 4 Iteration: 1400 / 2000 [ 70%] (Sampling)
## Chain 3 Iteration: 1500 / 2000 [ 75%] (Sampling)
## Chain 1 Iteration: 1500 / 2000 [ 75%] (Sampling)
## Chain 2 Iteration: 1500 / 2000 [ 75%] (Sampling)
## Chain 3 Iteration: 1600 / 2000 [ 80%] (Sampling)
## Chain 4 Iteration: 1500 / 2000 [ 75%] (Sampling)
## Chain 1 Iteration: 1600 / 2000 [ 80%] (Sampling)
## Chain 2 Iteration: 1600 / 2000 [ 80%] (Sampling)
## Chain 3 Iteration: 1700 / 2000 [ 85%] (Sampling)
## Chain 4 Iteration: 1600 / 2000 [ 80%] (Sampling)
## Chain 1 Iteration: 1700 / 2000 [ 85%] (Sampling)
## Chain 3 Iteration: 1800 / 2000 [ 90%] (Sampling)
## Chain 2 Iteration: 1700 / 2000 [ 85%] (Sampling)
## Chain 4 Iteration: 1700 / 2000 [ 85%] (Sampling)
## Chain 1 Iteration: 1800 / 2000 [ 90%] (Sampling)
## Chain 3 Iteration: 1900 / 2000 [ 95%] (Sampling)
## Chain 2 Iteration: 1800 / 2000 [ 90%] (Sampling)
## Chain 4 Iteration: 1800 / 2000 [ 90%] (Sampling)
## Chain 3 Iteration: 2000 / 2000 [100%] (Sampling)
## Chain 3 finished in 32.4 seconds.
## Chain 1 Iteration: 1900 / 2000 [ 95%] (Sampling)
## Chain 2 Iteration: 1900 / 2000 [ 95%] (Sampling)
## Chain 4 Iteration: 1900 / 2000 [ 95%] (Sampling)
## Chain 1 Iteration: 2000 / 2000 [100%] (Sampling)
## Chain 1 finished in 34.5 seconds.
## Chain 2 Iteration: 2000 / 2000 [100%] (Sampling)
## Chain 2 finished in 35.3 seconds.
## Chain 4 Iteration: 2000 / 2000 [100%] (Sampling)
## Chain 4 finished in 36.2 seconds.
##
## All 4 chains finished successfully.
## Mean chain execution time: 34.6 seconds.
## Total execution time: 36.3 seconds.
## variable mean median sd mad q5 q95 rhat ess_bulk
## lp__ -10269.61 -10269.30 1.58 1.48 -10272.60 -10267.70 1.00 1112
## alpha 0.66 0.67 0.31 0.30 0.12 1.16 1.00 541
## beta[1] 0.47 0.46 0.09 0.09 0.32 0.62 1.00 533
## beta[2] 0.86 0.86 0.15 0.15 0.62 1.13 1.00 528
## beta[3] -0.21 -0.20 0.05 0.04 -0.29 -0.14 1.00 523
## sigma 1.81 1.81 0.01 0.01 1.79 1.83 1.00 1599
## mu[1] 2.63 2.63 0.04 0.04 2.57 2.69 1.00 4441
## mu[2] 2.54 2.54 0.02 0.02 2.51 2.57 1.00 4013
## mu[3] 2.50 2.50 0.02 0.02 2.46 2.54 1.00 3525
## mu[4] 2.33 2.33 0.04 0.04 2.26 2.39 1.00 1421
## ess_tail
## 1327
## 630
## 671
## 700
## 733
## 1615
## 3322
## 3168
## 2544
## 1970
##
## # showing 10 of 18806 rows (change via 'max_rows' argument or 'cmdstanr_max_rows' option)

It’s not perfect, but for our purposes it’ll do just fine. Let’s plot our parameters:

streams$draws(c('alpha', 'beta', 'sigma'), format='draws_df') %>%
  pivot_longer(alpha:sigma, names_to='.variable', values_to='.value') %>%
  ggplot(aes(x=.value)) +
  stat_halfeye(point_interval=median_hdi, normalize='panels') +
  xlab('Value') + ylab('Parameter') +
  facet_wrap(~ .variable, scales='free') +
  theme_tidybayes()
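
Because this model includes an interaction, the slope for each predictor depends on the value of the other one: the slope for duration is `beta[1] + beta[3] * tempo_bps`, and the slope for tempo is `beta[2] + beta[3] * duration_m`. As a rough sketch, we can plug in the posterior means from the summary above to see where each slope would cross zero. This ignores all posterior uncertainty, so treat the numbers as ballpark figures only; the function and variable names are made up for illustration:

```r
# Posterior means copied from the summary table above (streams are in millions)
beta_duration    <- 0.47    # beta[1]: slope for duration_m when tempo_bps = 0
beta_tempo       <- 0.86    # beta[2]: slope for tempo_bps when duration_m = 0
beta_interaction <- -0.21   # beta[3]: coefficient for duration_m:tempo_bps

# Slope of duration at a given tempo
duration_slope <- function(tempo_bps) beta_duration + beta_interaction * tempo_bps

# Slope of tempo at a given duration
tempo_slope <- function(duration_m) beta_tempo + beta_interaction * duration_m

# Values at which each slope crosses zero
tempo_flip_bps  <- -beta_duration / beta_interaction
duration_flip_m <- -beta_tempo / beta_interaction

print(c(tempo_flip_bpm = tempo_flip_bps * 60,
        duration_flip_min = duration_flip_m))
```

Below these crossover points the slopes are positive, and above them they turn negative.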

From our model’s results, it looks like we do indeed have a negative
interaction: even though stream counts go up by about 450,000 with each
additional minute (`beta[1]`), and even though they go up by about
850,000 for each beat-per-second increase in tempo (`beta[2]`), these
relationships reverse if a song is either too long or too fast. To
visualize how these relationships change, we can add to our Stan program
to calculate model predictions for any arbitrary `X`:

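data {
  int<lower=0> N;            // the number of data points
  int<lower=0> K;            // the number of regression coefficients
  matrix[N, K] X;            // the predictor variables
  vector[N] y;               // the outcome variable
  int<lower=0> N_pred;       // the number of prediction points
  matrix[N_pred, K] X_pred;  // the prediction points
}
parameters {
  real alpha;                // the intercept
  vector[K] beta;            // the regression coefficients
  real<lower=0> sigma;       // the residual standard deviation
}
transformed parameters {
  vector[N] mu = alpha + X*beta;  // the predicted mean for each data point
}
model {
  alpha ~ normal(0, 1);
  beta ~ normal(0, 1);
  sigma ~ normal(0, 1);
  y ~ normal(mu, sigma);
}
generated quantities {
  // Note: recent Stan releases (2.33 and later, to my knowledge) require the
  // newer array syntax instead, e.g. array[N] real y_hat = ...
  real y_hat[N] = normal_rng(mu, sigma);                 // simulated data at X
  vector[N_pred] mu_pred = alpha + X_pred*beta;          // predicted mean at X_pred
  real y_pred_hat[N_pred] = normal_rng(mu_pred, sigma);  // simulated data at X_pred
}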
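
Stan evaluates the generated quantities block once per posterior draw. To get a feel for what that produces, here is a rough base-R analogue. The "draws" below are simulated independently from normal distributions loosely based on the earlier summary, so they are stand-ins and not the fitted posterior (in particular, they ignore any correlations between parameters). The grid values are also arbitrary, and `expand.grid` is used only to keep the sketch in base R:

```r
set.seed(1)
n_draws <- 4000

# Stand-in posterior draws (simulated here, NOT taken from the fitted model)
alpha <- rnorm(n_draws, 0.66, 0.31)
beta  <- cbind(rnorm(n_draws, 0.47, 0.09),
               rnorm(n_draws, 0.86, 0.15),
               rnorm(n_draws, -0.21, 0.05))
sigma <- abs(rnorm(n_draws, 1.81, 0.01))

# A few hypothetical prediction points, in the same column layout as X
grid <- expand.grid(duration_m = c(2, 4), tempo_bps = c(1, 2))
X_pred <- model.matrix(~ 0 + duration_m * tempo_bps, data = grid)

# mu_pred: one row per draw, one column per prediction point
mu_pred <- alpha + beta %*% t(X_pred)

# y_pred_hat: add observation noise, using each draw's own sigma
y_pred_hat <- mu_pred + rnorm(length(mu_pred), sd = sigma)

# Mean prediction at each grid point
print(cbind(grid, mu_mean = colMeans(mu_pred)))
```

The spread of `mu_pred` reflects uncertainty about the regression line, while `y_pred_hat` also includes the residual noise, so it is always wider.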

Now in addition to `X`, we’re also giving Stan a different dataset
called `X_pred`, which contains some values of the predictors we want
model predictions for. Since this data isn’t involved in training, we
don’t need a corresponding `y_pred`. In the generated quantities, I’ve
added two new variables: `mu_pred` is the model prediction for the mean
of `y` at `X_pred`, and `y_pred_hat` is the simulated data at `X_pred`.
Let’s try it out:

X_pred <- expand_grid(duration_m=seq(0, 10, 1),
                      tempo_bps=seq(0.5, 3, 0.5)) %>%
  model.matrix( ~ 0 + duration_m * tempo_bps, .)
head(X_pred)
streams_data <- list(N=nrow(spotify2021), K=ncol(X),
                     X=X, y=spotify2021$streams,
                     N_pred=nrow(X_pred), X_pred=X_pred)
streams_model <- cmdstan_model('2021-12-10-streams-glm-pred.stan')
streams <- streams_model$sample(data=streams_data)
streams

## duration_m tempo_bps duration_m:tempo_bps
## 1 0 0.5 0
## 2 0 1.0 0
## 3 0 1.5 0
## 4 0 2.0 0
## 5 0 2.5 0
## 6 0 3.0 0
## Running MCMC with 4 chains, at most 20 in parallel...
##
## Chain 1 Iteration: 1 / 2000 [ 0%] (Warmup)
## Chain 2 Iteration: 1 / 2000 [ 0%] (Warmup)
## Chain 3 Iteration: 1 / 2000 [ 0%] (Warmup)
## Chain 4 Iteration: 1 / 2000 [ 0%] (Warmup)
## Chain 1 Iteration: 100 / 2000 [ 5%] (Warmup)
## Chain 1 Iteration: 200 / 2000 [ 10%] (Warmup)
## Chain 1 Iteration: 300 / 2000 [ 15%] (Warmup)
## Chain 2 Iteration: 100 / 2000 [ 5%] (Warmup)
## Chain 2 Iteration: 200 / 2000 [ 10%] (Warmup)
## Chain 2 Iteration: 300 / 2000 [ 15%] (Warmup)
## Chain 3 Iteration: 100 / 2000 [ 5%] (Warmup)
## Chain 3 Iteration: 200 / 2000 [ 10%] (Warmup)
## Chain 3 Iteration: 300 / 2000 [ 15%] (Warmup)
## Chain 4 Iteration: 100 / 2000 [ 5%] (Warmup)
## Chain 4 Iteration: 200 / 2000 [ 10%] (Warmup)
## Chain 4 Iteration: 300 / 2000 [ 15%] (Warmup)
## Chain 2 Iteration: 400 / 2000 [ 20%] (Warmup)
## Chain 3 Iteration: 400 / 2000 [ 20%] (Warmup)
## Chain 4 Iteration: 400 / 2000 [ 20%] (Warmup)
## Chain 1 Iteration: 400 / 2000 [ 20%] (Warmup)
## Chain 2 Iteration: 500 / 2000 [ 25%] (Warmup)
## Chain 1 Iteration: 500 / 2000 [ 25%] (Warmup)
## Chain 3 Iteration: 500 / 2000 [ 25%] (Warmup)
## Chain 4 Iteration: 500 / 2000 [ 25%] (Warmup)
## Chain 2 Iteration: 600 / 2000 [ 30%] (Warmup)
## Chain 1 Iteration: 600 / 2000 [ 30%] (Warmup)
## Chain 3 Iteration: 600 / 2000 [ 30%] (Warmup)
## Chain 2 Iteration: 700 / 2000 [ 35%] (Warmup)
## Chain 4 Iteration: 600 / 2000 [ 30%] (Warmup)
## Chain 1 Iteration: 700 / 2000 [ 35%] (Warmup)
## Chain 3 Iteration: 700 / 2000 [ 35%] (Warmup)
## Chain 2 Iteration: 800 / 2000 [ 40%] (Warmup)
## Chain 4 Iteration: 700 / 2000 [ 35%] (Warmup)
## Chain 1 Iteration: 800 / 2000 [ 40%] (Warmup)
## Chain 2 Iteration: 900 / 2000 [ 45%] (Warmup)
## Chain 3 Iteration: 800 / 2000 [ 40%] (Warmup)
## Chain 4 Iteration: 800 / 2000 [ 40%] (Warmup)
## Chain 3 Iteration: 900 / 2000 [ 45%] (Warmup)
## Chain 1 Iteration: 900 / 2000 [ 45%] (Warmup)
## Chain 2 Iteration: 1000 / 2000 [ 50%] (Warmup)
## Chain 2 Iteration: 1001 / 2000 [ 50%] (Sampling)
## Chain 4 Iteration: 900 / 2000 [ 45%] (Warmup)
## Chain 3 Iteration: 1000 / 2000 [ 50%] (Warmup)
## Chain 3 Iteration: 1001 / 2000 [ 50%] (Sampling)
## Chain 1 Iteration: 1000 / 2000 [ 50%] (Warmup)
## Chain 1 Iteration: 1001 / 2000 [ 50%] (Sampling)
## Chain 4 Iteration: 1000 / 2000 [ 50%] (Warmup)
## Chain 4 Iteration: 1001 / 2000 [ 50%] (Sampling)
## Chain 2 Iteration: 1100 / 2000 [ 55%] (Sampling)
## Chain 1 Iteration: 1100 / 2000 [ 55%] (Sampling)
## Chain 3 Iteration: 1100 / 2000 [ 55%] (Sampling)
## Chain 4 Iteration: 1100 / 2000 [ 55%] (Sampling)
## Chain 2 Iteration: 1200 / 2000 [ 60%] (Sampling)
## Chain 1 Iteration: 1200 / 2000 [ 60%] (Sampling)
## Chain 3 Iteration: 1200 / 2000 [ 60%] (Sampling)
## Chain 4 Iteration: 1200 / 2000 [ 60%] (Sampling)
## Chain 2 Iteration: 1300 / 2000 [ 65%] (Sampling)
## Chain 1 Iteration: 1300 / 2000 [ 65%] (Sampling)
## Chain 3 Iteration: 1300 / 2000 [ 65%] (Sampling)
## Chain 4 Iteration: 1300 / 2000 [ 65%] (Sampling)
## Chain 2 Iteration: 1400 / 2000 [ 70%] (Sampling)
## Chain 1 Iteration: 1400 / 2000 [ 70%] (Sampling)
## Chain 3 Iteration: 1400 / 2000 [ 70%] (Sampling)
## Chain 4 Iteration: 1400 / 2000 [ 70%] (Sampling)
## Chain 1 Iteration: 1500 / 2000 [ 75%] (Sampling)
## Chain 2 Iteration: 1500 / 2000 [ 75%] (Sampling)
## Chain 3 Iteration: 1500 / 2000 [ 75%] (Sampling)
## Chain 4 Iteration: 1500 / 2000 [ 75%] (Sampling)
## Chain 1 Iteration: 1600 / 2000 [ 80%] (Sampling)
## Chain 2 Iteration: 1600 / 2000 [ 80%] (Sampling)
## Chain 3 Iteration: 1600 / 2000 [ 80%] (Sampling)
## Chain 4 Iteration: 1600 / 2000 [ 80%] (Sampling)
## Chain 1 Iteration: 1700 / 2000 [ 85%] (Sampling)
## Chain 2 Iteration: 1700 / 2000 [ 85%] (Sampling)
## Chain 3 Iteration: 1700 / 2000 [ 85%] (Sampling)
## Chain 4 Iteration: 1700 / 2000 [ 85%] (Sampling)
## Chain 2 Iteration: 1800 / 2000 [ 90%] (Sampling)
## Chain 1 Iteration: 1800 / 2000 [ 90%] (Sampling)
## Chain 3 Iteration: 1800 / 2000 [ 90%] (Sampling)
## Chain 4 Iteration: 1800 / 2000 [ 90%] (Sampling)
## Chain 1 Iteration: 1900 / 2000 [ 95%] (Sampling)
## Chain 2 Iteration: 1900 / 2000 [ 95%] (Sampling)
## Chain 3 Iteration: 1900 / 2000 [ 95%] (Sampling)
## Chain 4 Iteration: 1900 / 2000 [ 95%] (Sampling)
## Chain 1 Iteration: 2000 / 2000 [100%] (Sampling)
## Chain 1 finished in 34.0 seconds.
## Chain 2 Iteration: 2000 / 2000 [100%] (Sampling)
## Chain 2 finished in 34.3 seconds.
## Chain 3 Iteration: 2000 / 2000 [100%] (Sampling)
## Chain 3 finished in 35.8 seconds.
## Chain 4 Iteration: 2000 / 2000 [100%] (Sampling)
## Chain 4 finished in 35.8 seconds.
##
## All 4 chains finished successfully.
## Mean chain execution time: 35.0 seconds.
## Total execution time: 35.9 seconds.
## variable      mean    median   sd  mad        q5       q95 rhat ess_bulk ess_tail
## lp__     -10269.62 -10269.30 1.57 1.48 -10272.60 -10267.70 1.00     1285     1905
## alpha         0.67      0.67 0.31 0.30      0.15      1.19 1.01      683      908
## beta[1]       0.46      0.46 0.09 0.09      0.31      0.62 1.01      688      923
## beta[2]       0.86      0.86 0.15 0.15      0.60      1.11 1.01      680      835
## beta[3]      -0.21     -0.20 0.04 0.04     -0.28     -0.13 1.01      690      794
## sigma         1.81      1.81 0.01 0.01      1.79      1.83 1.00     1578     1568
## mu[1]         2.63      2.63 0.04 0.04      2.57      2.70 1.00     4807     3453
## mu[2]         2.54      2.54 0.02 0.02      2.51      2.58 1.00     4563     3361
## mu[3]         2.50      2.50 0.02 0.02      2.46      2.54 1.00     3801     3323
## mu[4]         2.33      2.33 0.04 0.04      2.26      2.40 1.00     1556     2485
##
## # showing 10 of 18938 rows (change via 'max_rows' argument or 'cmdstanr_max_rows' option)

Thankfully, the results are pretty much the same. Let’s see what those predictions look like!

# Plot the posterior predictions (mu_pred) on top of the raw data
streams %>%
    spread_draws(mu_pred[.row]) %>%
    # look up the predictor values that go with each prediction row
    mutate(duration_m=X_pred[.row, 'duration_m'],
           tempo_bps=X_pred[.row, 'tempo_bps']) %>%
    group_by(duration_m, tempo_bps) %>%
    # posterior median and highest-density interval for each combination
    median_hdi %>%
    ggplot(aes(x=duration_m, y=mu_pred, group=tempo_bps, fill=tempo_bps)) +
    geom_point(aes(y=streams, color=tempo_bps), alpha=.25, size=.25, data=spotify2021) +
    geom_ribbon(aes(ymin=.lower, ymax=.upper), alpha=.25) +
    geom_line(aes(color=tempo_bps)) +
    xlab('Duration (minutes)') + ylab('Streams (millions)') +
    scale_color_viridis(name='Tempo\n(beats/second)') +
    scale_fill_viridis(name='Tempo\n(beats/second)') +
    coord_cartesian(xlim=c(1, 6), ylim=c(0, 5)) +
    theme_tidybayes()

We can see the overall trends we found earlier: for slower songs, there is a positive relationship between duration and stream count, but for faster songs this trend is reversed (with the switch from positive to negative happening around 2 beats/second). Most importantly, though, our model doesn’t seem to be describing the data super well. If our goal is just to describe some overall patterns in the data, our current model is probably fine. But if we are interested in capturing the data-generating process of stream counts, then we would probably want to refine this model some more until the model-generated data looks reasonably like the actual data.
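If you're wondering where the "around 2 beats/second" switch comes from, here's a rough back-of-the-envelope check using the posterior means from the summary table. This is only a sketch, and it rests on an assumption that isn't shown in this section: that `beta[1]` is the coefficient for duration, `beta[2]` is the coefficient for tempo, and `beta[3]` is the duration-by-tempo interaction. The names `beta`, `duration_slope`, and `crossover_tempo` are made up for illustration. Under that assumption, the slope of duration at a given tempo is `beta[1] + beta[3] * tempo`, which hits zero at `-beta[1] / beta[3]`.

```r
# Posterior means copied from the summary table above.
# ASSUMPTION: beta[1] = duration_m, beta[2] = tempo_bps,
# beta[3] = duration_m:tempo_bps interaction (the column order
# of the design matrix is not shown in this section).
beta <- c(duration_m = 0.46, tempo_bps = 0.86, interaction = -0.21)

# Slope of expected streams with respect to duration at a given tempo
duration_slope <- function(tempo_bps) {
  unname(beta["duration_m"] + beta["interaction"] * tempo_bps)
}

# Tempo (beats/second) at which the duration slope changes sign
crossover_tempo <- unname(-beta["duration_m"] / beta["interaction"])

print(round(crossover_tempo, 2))
print(duration_slope(c(1, 2, 3)))
```

With these numbers the slope is positive at 1 and 2 beats/second and negative at 3, with the sign change a little above 2. Plugging in posterior means is a shortcut: to get the uncertainty around the crossover point, you would compute `-beta[1] / beta[3]` for every posterior draw instead.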

# Summary

In this tutorial, we learned what probabilistic programming is, and we walked through some of the many ways in which you can use the probabilistic programming language Stan. We simulated fake data from a model, simulated fake data from a prior distribution, fit a few different linear models to some data, and visualized linear trends in the data! Although this is a lot, there’s so much more that you can do with Stan, so expect to see it in more tutorials to come. Until next time!