Hidden Markov Models
This notebook implements a simple Hidden Markov Model (HMM) and demonstrates how two implementations of the same inference query can have very different runtimes.
Prerequisites
This tutorial assumes you have completed the Introduction to FlipPy tutorial.
Learning Objectives
By the end of this tutorial, you will be able to:
- Understand the structure of Hidden Markov Models
- Implement HMM inference in FlipPy
- Recognize why naive inference is slow and how to fix it
What are Hidden Markov Models?
HMMs are probabilistic models for sequential data where:
- There's a sequence of hidden states we can't observe directly
- At each timestep, we receive an observation that depends on the hidden state
- States transition according to a Markov process (next state depends only on current state)
Real-world applications:
- Speech recognition: Hidden phonemes → observed audio
- Genomics: Hidden gene structure (e.g., coding vs. non-coding regions) → observed DNA sequences
- Finance: Hidden market regimes → observed prices
- Weather prediction: Hidden atmospheric states → observed conditions
```python
from flippy import infer, condition, flip, draw_from
import pandas as pd
import seaborn as sns
```
The Rainy/Sunny HMM Example
This classic example (from Russell & Norvig, 2010) models weather prediction:
- Hidden states: The actual weather (rainy or sunny)
- Observations: Whether someone is carrying an umbrella
```
Hidden:    rainy ──→    sunny ──→    rainy ──→    sunny ──→ ...
             ↓            ↓            ↓            ↓
Observed:  umbrella     no umbrella  umbrella     umbrella
```
HMM Components:
| Component | Notation | Description |
|---|---|---|
| Initial distribution | $P(x_1)$ | Probability of starting in each state |
| Transition model | $P(x_{t+1} \mid x_t)$ | How states evolve over time |
| Sensor/observation model | $P(y_t \mid x_t)$ | How observations depend on states |
Our model's parameters:
Initial distribution: rainy and sunny are equally likely (50% each). The code draws the initial state with draw_from(["rainy", "sunny"]), which we take to be uniform.
Transition model:
| Transition | From Rainy | From Sunny |
|---|---|---|
| → Rainy | 70% | 30% |
| → Sunny | 30% | 70% |
Observation model:
| Observation | When Rainy | When Sunny |
|---|---|---|
| Umbrella | 90% | 20% |
| No umbrella | 10% | 80% |
The joint distribution over states and observations is: $$ P(x_{1:T}, y_{1:T}) = P(x_1)\,P(y_1 \mid x_1)\prod_{t=2}^{T} P(x_t \mid x_{t-1})\,P(y_t \mid x_t) $$
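To make the factorization concrete, here is a minimal plain-Python sketch (no FlipPy) that scores one state sequence against one observation sequence. It multiplies $P(x_1)P(y_1 \mid x_1)$ by one transition term and one observation term for each later step. The dictionaries `INITIAL`, `TRANSITION` and `SENSOR` and the function `joint_prob` are illustrative names, not part of FlipPy. The uniform initial distribution is an assumption based on the `draw_from(["rainy", "sunny"])` call used in the code.

```python
from itertools import product

INITIAL = {"rainy": 0.5, "sunny": 0.5}  # assumption: uniform, like draw_from
TRANSITION = {  # TRANSITION[prev][next] = P(next state | previous state)
    "rainy": {"rainy": 0.7, "sunny": 0.3},
    "sunny": {"rainy": 0.3, "sunny": 0.7},
}
SENSOR = {  # SENSOR[state][obs] = P(observation | state)
    "rainy": {"umbrella": 0.9, "no umbrella": 0.1},
    "sunny": {"umbrella": 0.2, "no umbrella": 0.8},
}


def joint_prob(states, observations):
    """Return P(x_{1:T}, y_{1:T}) for one state sequence and one observation sequence."""
    assert len(states) == len(observations) > 0
    # Initial state and its observation: P(x_1) P(y_1 | x_1)
    p = INITIAL[states[0]] * SENSOR[states[0]][observations[0]]
    # Every later step contributes one transition term and one observation term
    for prev, cur, obs in zip(states, states[1:], observations[1:]):
        p *= TRANSITION[prev][cur] * SENSOR[cur][obs]
    return p


print(joint_prob(("rainy", "sunny"), ("umbrella", "no umbrella")))

# Sanity check: summing over all state and observation sequences of length 3
total = sum(
    joint_prob(xs, ys)
    for xs in product(INITIAL, repeat=3)
    for ys in product(["umbrella", "no umbrella"], repeat=3)
)
print(total)
```

The total should come out as 1 up to floating-point rounding. That is a quick check that the three tables define a proper joint distribution.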
```python
def transition(state):
    # Sample the next state from the transition table
    if state == "rainy":
        next_state = "rainy" if flip(.7) else "sunny"
    else:
        next_state = "rainy" if flip(.3) else "sunny"
    return next_state

def sensor_model(state):
    # Sample an observation from the observation table
    if state == "rainy":
        return "umbrella" if flip(.9) else "no umbrella"
    else:
        return "umbrella" if flip(.2) else "no umbrella"

def generate_sequence(length):
    # Sample a (states, observations) pair of the given length
    state_seq = ()
    obs_seq = ()
    state = draw_from(["rainy", "sunny"])  # initial state
    for _ in range(length):
        obs = sensor_model(state)
        obs_seq += (obs,)
        state_seq += (state,)
        state = transition(state)
    return state_seq, obs_seq
```
Inferring latent states from observations (version 1)
HMMs can be used for smoothing. Given a sequence of observations $y_{1:T}$, we infer the hidden state at each time $t$, where $1 \leq t \leq T$. That is, for each $t$, calculate: $$ P(x_t \mid y_{1:T}) = \frac{\sum_{x_{1:t-1},\, x_{t+1:T}} P(x_{1:T}, y_{1:T})}{\sum_{x_{1:T}} P(x_{1:T}, y_{1:T})} $$
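The smoothing formula can be evaluated literally. The plain-Python sketch below uses illustrative names such as `brute_force_smoothing`, assumes the same parameter tables, and assumes a uniform initial state. It enumerates all $|S|^T$ state sequences and accumulates the denominator. It also adds each sequence's probability to the numerator of every $(t, x_t)$ pair that the sequence contains. Because it enumerates every sequence, its cost grows exponentially with $T$, so it is only practical for short sequences.

```python
from itertools import product

STATES = ("rainy", "sunny")
INITIAL = {"rainy": 0.5, "sunny": 0.5}  # assumption: uniform, like draw_from
TRANSITION = {"rainy": {"rainy": 0.7, "sunny": 0.3},   # TRANSITION[prev][next]
              "sunny": {"rainy": 0.3, "sunny": 0.7}}
SENSOR = {"rainy": {"umbrella": 0.9, "no umbrella": 0.1},  # SENSOR[state][obs]
          "sunny": {"umbrella": 0.2, "no umbrella": 0.8}}


def joint_prob(states, observations):
    """P(x_{1:T}, y_{1:T}) for one state sequence and one observation sequence."""
    p = INITIAL[states[0]] * SENSOR[states[0]][observations[0]]
    for prev, cur, obs in zip(states, states[1:], observations[1:]):
        p *= TRANSITION[prev][cur] * SENSOR[cur][obs]
    return p


def brute_force_smoothing(observations):
    """Return P(x_t | y_{1:T}) for each t by enumerating all |S|^T state sequences."""
    T = len(observations)
    numerators = [dict.fromkeys(STATES, 0.0) for _ in range(T)]
    evidence = 0.0  # denominator: sum of the joint over every state sequence
    for states in product(STATES, repeat=T):
        p = joint_prob(states, observations)
        evidence += p
        # This sequence contributes to the numerator of every (t, x_t) it contains
        for t, x_t in enumerate(states):
            numerators[t][x_t] += p
    return [{s: v / evidence for s, v in num.items()} for num in numerators]


obs_seq = ("umbrella", "umbrella", "umbrella", "umbrella", "no umbrella")
for t, marginal in enumerate(brute_force_smoothing(obs_seq)):  # t is 0-based, as in the plots
    print(t, round(marginal["rainy"], 3))
```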
In FlipPy, the simplest way to implement this inference is to infer the posterior over state sequences given the observations. We then compute the marginal at each timestep, as in the following cells.
```python
@infer
def state_seq_model(obs_seq):
    state = draw_from(["rainy", "sunny"])
    state_seq = ()
    for obs in obs_seq:
        # Save the state at time t
        state_seq += (state,)
        # Generate the observation
        obs_ = sensor_model(state)
        # Condition on the observation
        condition(obs == obs_)
        # Transition to the next state
        state = transition(state)
    return state_seq
```
```python
obs_seq = ("umbrella", "umbrella", "umbrella", "umbrella", "no umbrella")
state_seq_dist = state_seq_model(obs_seq)
state_seq_marginals_df = []
for t in range(len(obs_seq)):
    marginal_dist = state_seq_dist.marginalize(lambda seq: seq[t])
    state_seq_marginals_df.append({'t': t, **dict(marginal_dist)})
ax = sns.lineplot(
    data=pd.DataFrame(state_seq_marginals_df),
    x='t', y='rainy'
)
ax.set_ylim(0, 1)
```
Why is This Slow?
The naive approach above computes the full posterior distribution over state sequences, which has one entry for every possible sequence. With two states there are $2^T$ possible sequences, so the work grows exponentially with the sequence length $T$:
| Sequence Length | Number of Possible Sequences | Computation |
|---|---|---|
| 5 | $2^5 = 32$ | Fast |
| 10 | $2^{10} = 1{,}024$ | Okay |
| 15 | $2^{15} = 32{,}768$ | Slow |
| 20 | $2^{20} \approx 1$ million | Very slow |
This throws away the structure that makes HMMs tractable! The Markov property tells us that the future is independent of the past given the present. We should be able to exploit this for efficient computation.
The fix: instead of computing the posterior over whole sequences, compute the marginal at each timestep with its own query, so that the other states are summed out rather than enumerated. Summing out states one step at a time is the key insight behind the forward-backward algorithm.
Try running the cell below. It takes about 30 seconds because it enumerates all $2^{15} = 32{,}768$ possible state sequences:
```python
# Running this cell will take about 30 seconds!
long_state_seq_dist = state_seq_model(obs_seq*3)
long_state_seq_marginals_df = []
for t in range(len(obs_seq)*3):
    marginal_dist = long_state_seq_dist.marginalize(lambda seq: seq[t])
    long_state_seq_marginals_df.append({'t': t, **dict(marginal_dist)})
ax = sns.lineplot(
    data=pd.DataFrame(long_state_seq_marginals_df),
    x='t', y='rainy'
)
ax.set_ylim(0, 1)
```
The Efficient Approach: Compute Marginals Directly
The key insight is that we don't need the full posterior over sequences. We only need the marginal probability of each state at each timestep. If the model returns only the state at a specific time $t$ rather than the full sequence, FlipPy can exploit the Markov structure.
Why does this work?
When we call state_marginal_model(obs, t=2), FlipPy only needs to track:
- The probability distribution over states at time 2
- How observations before and after time 2 constrain this distribution
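It doesn't need to enumerate all possible sequences. The return value depends only on the state at time 2. So any two partial sequences that agree on the current state (and on the state at time 2, once it has been recorded) have identical futures. Their probabilities can be summed instead of tracked separately, which is why the "past" and "future" states can be summed out efficiently.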
Computational complexity:
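- Naive approach: $O(|S|^T)$ where $|S|$ is the number of states and $T$ is the sequence length
- Efficient approach: $O(|S|^2 \cdot T)$ per marginal, which is linear in sequence length. Calling the model once per timestep, as in the cell below, repeats this work $T$ times, for $O(|S|^2 \cdot T^2)$ overall. The forward-backward algorithm obtains all $T$ marginals in a single $O(|S|^2 \cdot T)$ pass.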
This is the same insight behind the classic forward-backward algorithm, but expressed naturally in a probabilistic programming framework.
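For comparison, here is a hedged plain-Python sketch of the classic forward-backward algorithm for this two-state model. The names are illustrative, and a uniform initial state is assumed. The forward pass keeps messages proportional to $P(x_t, y_{1:t})$. The backward pass keeps messages proportional to $P(y_{t+1:T} \mid x_t)$. Their renormalized product is $P(x_t \mid y_{1:T})$. Each pass does $O(|S|^2)$ work per step, so all $T$ marginals together cost $O(|S|^2 \cdot T)$.

```python
STATES = ("rainy", "sunny")
INITIAL = {"rainy": 0.5, "sunny": 0.5}  # assumption: uniform, like draw_from
TRANSITION = {"rainy": {"rainy": 0.7, "sunny": 0.3},   # TRANSITION[prev][next]
              "sunny": {"rainy": 0.3, "sunny": 0.7}}
SENSOR = {"rainy": {"umbrella": 0.9, "no umbrella": 0.1},  # SENSOR[state][obs]
          "sunny": {"umbrella": 0.2, "no umbrella": 0.8}}


def normalize(d):
    z = sum(d.values())
    return {k: v / z for k, v in d.items()}


def forward_backward(observations):
    """Return the smoothed marginal P(x_t | all observations) for every t (0-based)."""
    T = len(observations)
    # Forward pass: alpha[t][s] is proportional to P(x_t = s, y_0..y_t)
    alpha = []
    for t, obs in enumerate(observations):
        if t == 0:
            prior = dict(INITIAL)
        else:
            prior = {s: sum(alpha[-1][r] * TRANSITION[r][s] for r in STATES) for s in STATES}
        alpha.append(normalize({s: prior[s] * SENSOR[s][obs] for s in STATES}))
    # Backward pass: beta[t][s] is proportional to P(y_{t+1}..y_{T-1} | x_t = s)
    beta = [dict.fromkeys(STATES, 1.0) for _ in range(T)]
    for t in range(T - 2, -1, -1):
        nxt = observations[t + 1]
        beta[t] = normalize({
            s: sum(TRANSITION[s][n] * SENSOR[n][nxt] * beta[t + 1][n] for n in STATES)
            for s in STATES
        })
    # Combine: P(x_t = s | all observations) is proportional to alpha[t][s] * beta[t][s]
    return [normalize({s: a[s] * b[s] for s in STATES}) for a, b in zip(alpha, beta)]


obs_seq = ("umbrella", "umbrella", "umbrella", "umbrella", "no umbrella")
for t, marginal in enumerate(forward_backward(obs_seq * 3)):
    print(t, round(marginal["rainy"], 3))
```

A single marginal at time $t$ only needs the forward messages up to $t$ and the backward messages back to $t$. Conceptually, that is the quantity a per-timestep query computes. Forward-backward instead shares those messages across all $t$. Every message is normalized to avoid floating-point underflow on long sequences. Because each scale factor is the same for every state, it cancels in the final renormalization.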
```python
@infer
def state_marginal_model(observations, t):
    assert t < len(observations)
    state = draw_from(["rainy", "sunny"])
    # We will return this at the end
    state_t = None
    for t_, obs in list(enumerate(observations)):
        # Save the state at time t
        if t_ == t:
            state_t = state
        # generate observation
        obs_ = sensor_model(state)
        # Condition on the observation
        condition(obs == obs_)
        # Transition to the next state
        state = transition(state)
    return state_t
```
```python
state_marginals_df = []
for t in range(len(obs_seq)*3):
    marginal_dist = state_marginal_model(obs_seq*3, t)
    state_marginals_df.append({'t': t, **dict(marginal_dist)})
ax = sns.lineplot(
    data=pd.DataFrame(state_marginals_df),
    x='t', y='rainy',
)
ax.set_ylim(0, 1)
```
Interpreting the results: The plot shows the probability of rain at each timestep, given the full observation sequence. Notice:
- When we see umbrellas (timesteps 0-3, 5-8, 10-13), the probability of rain is high
- When we see "no umbrella" (timesteps 4, 9, 14), the probability drops
- The effect of observations propagates both forward and backward in time (this is "smoothing")
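The backward effect can be checked numerically. The plain-Python sketch below uses illustrative names and assumes a uniform initial state. It computes a filtering estimate for timestep 3, which uses only the four umbrellas seen up to that point. It then computes a smoothing estimate, which also conditions on the "no umbrella" at timestep 4 through one backward message. The two are compared. Smoothing should lower the probability of rain at timestep 3, because the 70% persistence in the transition model makes a dry timestep 4 less likely after a rainy timestep 3.

```python
STATES = ("rainy", "sunny")
INITIAL = {"rainy": 0.5, "sunny": 0.5}  # assumption: uniform, like draw_from
TRANSITION = {"rainy": {"rainy": 0.7, "sunny": 0.3},   # TRANSITION[prev][next]
              "sunny": {"rainy": 0.3, "sunny": 0.7}}
SENSOR = {"rainy": {"umbrella": 0.9, "no umbrella": 0.1},  # SENSOR[state][obs]
          "sunny": {"umbrella": 0.2, "no umbrella": 0.8}}


def normalize(d):
    z = sum(d.values())
    return {k: v / z for k, v in d.items()}


def filtered_belief(observations):
    """Forward pass only: belief over the last state given observations so far."""
    belief = dict(INITIAL)
    for t, obs in enumerate(observations):
        if t > 0:  # predict: push the belief through the transition model
            belief = {s: sum(belief[r] * TRANSITION[r][s] for r in STATES) for s in STATES}
        belief = normalize({s: belief[s] * SENSOR[s][obs] for s in STATES})  # update
    return belief


obs_seq = ("umbrella", "umbrella", "umbrella", "umbrella", "no umbrella")
# Filtering at timestep 3 sees only the four umbrellas (timesteps 0-3)
filtered = filtered_belief(obs_seq[:4])
# Smoothing also uses the "no umbrella" at timestep 4 via one backward message:
# beta[s] = sum over n of P(x_4 = n | x_3 = s) * P(y_4 | x_4 = n)
beta = {s: sum(TRANSITION[s][n] * SENSOR[n][obs_seq[4]] for n in STATES) for s in STATES}
smoothed = normalize({s: filtered[s] * beta[s] for s in STATES})
print(round(filtered["rainy"], 2), round(smoothed["rainy"], 2))
```

With these parameters, the probability of rain at timestep 3 drops from about 0.90 (filtering) to about 0.82 (smoothing).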
Summary
In this tutorial, you learned:
- HMMs model sequential data with hidden states and observations
- Naive inference over full sequences is exponentially slow
- Marginal inference exploits the Markov property, so each marginal costs time linear in the sequence length
- FlipPy can automatically exploit this structure when you return the state at a single timestep instead of the whole sequence
Key takeaway: How you structure your probabilistic program affects efficiency. Returning what you actually need (marginals) rather than intermediate representations (full sequences) can yield dramatic speedups.
Extensions
The HMM framework can be extended to:
- Continuous observations: Replace the discrete sensor model with Gaussian observations
- Continuous states: Leads to Kalman filters and particle filters
- Switching dynamics: Different transition models in different regimes
- Input-driven transitions: HMMs where actions influence state transitions (leading to POMDPs)