flippy
FlipPy lets you specify probabilistic programs in Python syntax while seamlessly interacting with the rest of Python.
Quick start
pip install flippy-lang
Example: Sum of Bernoullis
Here, we flip two coins (heads = 1, tails = 0), condition on them being equal, and return their sum.
from flippy import infer, flip, condition

@infer
def fn():
    x = flip(0.5)
    y = flip(0.5)
    condition(x == y)
    return x + y

result = fn()
dict(result) # {0: 0.5, 2: 0.5}
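To see where `{0: 0.5, 2: 0.5}` comes from, the same posterior can be computed by hand. The following is a minimal standard-library sketch of exact enumeration for this one model. It does not use FlipPy, and the helper name `enumerate_two_coins` is hypothetical; it only illustrates the arithmetic, not how FlipPy's `Enumeration` is implemented.

```python
from collections import defaultdict
from itertools import product


def enumerate_two_coins(p=0.5):
    """Exact posterior over x + y given x == y, by brute-force enumeration.

    Hypothetical helper for illustration; not part of FlipPy.
    """
    weights = defaultdict(float)
    for x, y in product([0, 1], repeat=2):
        # Prior probability of this pair of independent coin flips.
        prior = (p if x else 1 - p) * (p if y else 1 - p)
        # Conditioning on x == y discards every pair where the coins differ.
        if x == y:
            weights[x + y] += prior
    # Renormalize the surviving weight so the posterior sums to one.
    total = sum(weights.values())
    return {value: weight / total for value, weight in weights.items()}


posterior = enumerate_two_coins()
print(posterior)  # {0: 0.5, 2: 0.5}
```

Each of the four coin outcomes has prior probability 0.25; conditioning keeps only (0, 0) and (1, 1), and renormalizing gives each of the sums 0 and 2 probability 0.5.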
Documentation
Here is the documentation for writing models in FlipPy.
- The core API for declaring a model (link)
- Specifying distributions (link)
- Selecting inference algorithms (link)
Tutorials
Click here to launch an interactive environment with tutorial notebooks, or clone the tutorial repo from here and run the notebooks locally. Statically rendered versions of the notebooks are also available:
- Introductory tutorial
- Rational Speech Acts (RSA)
- Language of Thought (LoT)
- Hidden Markov Models (HMMs)
- Bayesian Non-parametrics
- Intuitive Physics
- Sequential Decision-Making
API
'''
FlipPy lets you specify probabilistic programs in Python syntax
while seamlessly interacting with the rest of Python.

# Quick start

```bash
pip install flippy-lang
```

# Example: Sum of bernoullis

Here, we flip two coins (heads = 1, tails = 0), condition on them being equal,
and return their sum.

```python
from flippy import infer, flip, condition

@infer
def fn():
    x = flip(0.5)
    y = flip(0.5)
    condition(x == y)
    return x + y

result = fn()
dict(result) # {0: 0.5, 2: 0.5}
```

# Documentation

Here is the documentation for writing models in FlipPy.
- The core API for declaring a model ([link](#api))
- Specifying distributions ([link](flippy/distributions))
- Selecting inference algorithms ([link](flippy/inference))

# Tutorials
Click [here](https://mybinder.org/v2/gh/codec-lab/flippy-tutorials/main?urlpath=%2Fdoc%2Ftree%2Fnotebooks%2F00-intro.ipynb)
to launch an interactive environment with tutorial notebooks, or clone the
tutorial repo from [here](https://github.com/codec-lab/flippy-tutorials) and run the notebooks locally.
Statically rendered versions of the notebooks are also available:

- [Introductory tutorial](https://codec-lab.github.io/flippy-tutorials/)
- [Rational Speech Acts (RSA)](https://codec-lab.github.io/flippy-tutorials/01-RSA)
- [Language of Thought (LoT)](https://codec-lab.github.io/flippy-tutorials/02-LoT)
- [Hidden Markov Models (HMMs)](https://codec-lab.github.io/flippy-tutorials/03-HMMs)
- [Bayesian Non-parametrics](https://codec-lab.github.io/flippy-tutorials/04-DP-MM)
- [Intuitive Physics](https://codec-lab.github.io/flippy-tutorials/05-Physics)
- [Sequential Decision-Making](https://codec-lab.github.io/flippy-tutorials/06-Sequential-DM)

# API

'''

import math
from typing import Callable, Sequence, Union, TypeVar, overload, Generic

from flippy.transforms import CPSTransform
from flippy.inference import \
    SimpleEnumeration, Enumeration, SamplePrior, MetropolisHastings, \
    LikelihoodWeighting, InferenceAlgorithm
from flippy.distributions import Categorical, Bernoulli, Distribution, Uniform,\
    Element, Normal
from flippy.distributions.random import default_rng
from flippy.distributions.base import _factor_dist
from flippy.core import global_store
from flippy.hashable import hashabledict
from flippy.map import recursive_map
from flippy.tools import LRUCache

from flippy.interpreter import CPSInterpreter, keep_deterministic, \
    cps_transform_safe_decorator, DescriptorMixIn

__all__ = [
    'infer',

    'flip',
    'draw_from',
    'uniform',
    'normal',

    'factor',
    'condition',
    # 'map_observe',
    'keep_deterministic',
    'mem',

    # submodules

    # Model specification
    'distributions',
    # Inference algorithms
    'inference',

    # Execution model
    'core',
    'callentryexit',
    # 'map',
]

class InferCallable(Generic[Element], DescriptorMixIn):
    '''
    @private
    '''
    def __init__(
        self,
        func: Callable[..., Element],
        method : Union[type[InferenceAlgorithm], str] = "Enumeration",
        cache_size=0,
        **kwargs
    ):
        DescriptorMixIn.__init__(self, func)

        if isinstance(method, str):
            method : type[InferenceAlgorithm] = {
                'Enumeration': Enumeration,
                'SimpleEnumeration': SimpleEnumeration,
                'SamplePrior': SamplePrior,
                'MetropolisHastings': MetropolisHastings,
                'LikelihoodWeighting' : LikelihoodWeighting
            }[method]
        self.cache_size = cache_size
        self.cache = LRUCache(cache_size)
        self.method = method
        self.kwargs = kwargs
        self.func = func
        self.inference_alg = None
        setattr(self, CPSTransform.is_transformed_property, True)

    def _lazy_init(self):
        if self.inference_alg is not None:
            return
        func = self.func
        if isinstance(func, (classmethod, staticmethod)):
            func = func.__func__
        if not CPSTransform.is_transformed(func):
            func = CPSInterpreter().non_cps_callable_to_cps_callable(func)
        self.inference_alg = self.method(func, **self.kwargs)

        if not self.inference_alg.is_cachable:
            self.cache_size = 0

    def __call__(self, *args, _cont=None, _cps=None, _stack=None, **kws) -> Distribution[Element]:
        self._lazy_init()
        if self.cache_size > 0:
            kws_tuple = tuple(sorted(kws.items()))
            if (args, kws_tuple) in self.cache:
                dist = self.cache[args, kws_tuple]
            else:
                dist = self.inference_alg.run(*args, **kws)
                self.cache[args, kws_tuple] = dist
        else:
            dist = self.inference_alg.run(*args, **kws)
        if _cont is None:
            return dist
        else:
            return lambda : _cont(dist)

def infer(
    func: Callable[..., Element]=None,
    method=Enumeration,
    cache_size=1024,
    **kwargs
) -> InferCallable[Element]:
    '''
    Turns a function into a stochastic function, that represents a posterior distribution.

    This is the main interface for performing inference in FlipPy.

    - `method` specifies the inference method and can either be an instance of
    an `InferenceAlgorithm` or a string. Defaults to `Enumeration`.
    - `**kwargs` are keyword arguments passed to the inference method.
    '''
    return InferCallable(func, method, cache_size, **kwargs)
infer = cps_transform_safe_decorator(infer)

# type hints for infer - if we can use ParamSpecs this will be cleaner
InferenceType = Callable[[Callable[..., Element]], InferCallable[Element]]
infer : Callable[..., Union[InferCallable, InferenceType]]

def recursive_filter(fn, iter):
    if not iter:
        return []
    if fn(iter[0]):
        head = [iter[0]]
    else:
        head = []
    return head + recursive_filter(fn, iter[1:])

def recursive_reduce(fn, iter, initializer):
    if len(iter) == 0:
        return initializer
    return recursive_reduce(fn, iter[1:], fn(initializer, iter[0]))

def factor(score):
    '''
    Adds a real-valued `score` (i.e., log-probability) to the weight of the
    current trace.
    '''
    _factor_dist.observe(score)

def condition(cond: float):
    '''
    Used for conditioning statements. When `cond` is a boolean, this behaves like
    typical conditioning.

    - `cond` is a non-negative multiplicative weight for the conditioning. When zero,
    the trace is assigned zero probability.
    '''
    if cond == 0:
        _factor_dist.observe(-float("inf"))
    else:
        _factor_dist.observe(math.log(cond))

def flip(p=.5, name=None):
    '''
    Samples from a Bernoulli distribution with probability `p`.
    '''
    return bool(Bernoulli(p).sample(name=name))

@keep_deterministic
def _draw_from_dist(n: Union[Sequence[Element], int]) -> Distribution[Element]:
    if isinstance(n, int):
        return Categorical(range(n))
    if hasattr(n, '__getitem__'):
        return Categorical(n)
    else:
        return Categorical(list(n))

@overload
def draw_from(n: int) -> int:
    ...
@overload
def draw_from(n: Sequence[Element]) -> Element:
    ...
def draw_from(n: Union[Sequence[Element], int]) -> Element:
    '''
    Samples uniformly from `n` when it is a sequence.
    When `n` is an integer, a sample is drawn from `range(n)`.
    '''
    return _draw_from_dist(n).sample()

def mem(fn: Callable[..., Element]) -> Callable[..., Element]:
    '''
    Turns a function into a stochastically memoized function.
    Stores information in trace-specific storage.
    '''
    def mem_wrapper(*args, **kws):
        key = (fn, args, tuple(sorted(kws.items())))
        kws = hashabledict(kws)
        if key in global_store:
            return global_store.get(key)
        else:
            value = fn(*args, **kws)
            global_store.set(key, value)
            return value
    return mem_wrapper
mem = cps_transform_safe_decorator(mem)

_uniform = Uniform()
def uniform():
    '''
    Samples from a uniform distribution over the interval $[0, 1]$.
    '''
    return _uniform.sample()

_normal = Normal(0, 1)
def normal(mean=0, std=1):
    '''
    Samples from a standard normal distribution.
    '''
    return mean + std * _normal.sample()

@keep_deterministic
def map_log_probability(distribution: Distribution[Element], values: Sequence[Element]) -> float:
    return sum(distribution.log_probability(i) for i in values)

def map_observe(distribution: Distribution[Element], values: Sequence[Element]) -> float:
    """
    Calculates the total log probability of a sequence of
    independent values from a distribution.
    """
    log_prob = map_log_probability(distribution, values)
    factor(log_prob)
    return log_prob
def infer(
    func: Callable[..., Element]=None,
    method=Enumeration,
    cache_size=1024,
    **kwargs
) -> InferCallable[Element]:
    '''
    Turns a function into a stochastic function, that represents a posterior distribution.

    This is the main interface for performing inference in FlipPy.

    - `method` specifies the inference method and can either be an instance of
    an `InferenceAlgorithm` or a string. Defaults to `Enumeration`.
    - `**kwargs` are keyword arguments passed to the inference method.
    '''
    return InferCallable(func, method, cache_size, **kwargs)
Turns a function into a stochastic function that represents a posterior distribution.
This is the main interface for performing inference in FlipPy.
- method specifies the inference method and can be either an InferenceAlgorithm class or the name of one as a string (for example, "Enumeration" or "MetropolisHastings"). Defaults to Enumeration.
- **kwargs are keyword arguments passed to the inference method.
def flip(p=.5, name=None):
    '''
    Samples from a Bernoulli distribution with probability `p`.
    '''
    return bool(Bernoulli(p).sample(name=name))
Samples from a Bernoulli distribution with probability p.
def draw_from(n: Union[Sequence[Element], int]) -> Element:
    '''
    Samples uniformly from `n` when it is a sequence.
    When `n` is an integer, a sample is drawn from `range(n)`.
    '''
    return _draw_from_dist(n).sample()
Samples uniformly from n when it is a sequence.
When n is an integer, a sample is drawn from range(n).
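The two cases above can be made concrete by writing out the distribution being sampled from. This is a standard-library sketch only; `draw_from_support` is a hypothetical helper, not a FlipPy function, and it assumes the elements of a sequence are distinct and hashable.

```python
from fractions import Fraction


def draw_from_support(n):
    """Return the uniform distribution that a draw over `n` samples from.

    Hypothetical helper for illustration; assumes distinct, hashable elements.
    """
    # An integer n stands for range(n); any other iterable is used as given.
    support = list(range(n)) if isinstance(n, int) else list(n)
    return {value: Fraction(1, len(support)) for value in support}


print(draw_from_support(3))           # support is 0, 1, 2, each with probability 1/3
print(draw_from_support(["a", "b"]))  # support is "a", "b", each with probability 1/2
```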
def uniform():
    '''
    Samples from a uniform distribution over the interval $[0, 1]$.
    '''
    return _uniform.sample()
Samples from a uniform distribution over the interval $[0, 1]$.
def normal(mean=0, std=1):
    '''
    Samples from a standard normal distribution.
    '''
    return mean + std * _normal.sample()
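Samples from a normal distribution with mean mean and standard deviation std. With the defaults (mean=0, std=1), this is the standard normal distribution.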
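The source above draws a standard normal value and then shifts and scales it as `mean + std * z`. The sketch below reproduces that location-scale step with Python's standard `random` module rather than FlipPy, so `normal_sample` and the seeded `rng` are assumptions made for illustration.

```python
import random
import statistics


def normal_sample(rng, mean=0.0, std=1.0):
    """Shift and scale one standard normal draw (hypothetical helper)."""
    z = rng.gauss(0.0, 1.0)  # standard normal draw
    return mean + std * z


rng = random.Random(0)
samples = [normal_sample(rng, mean=10.0, std=2.0) for _ in range(20000)]

# The empirical mean and standard deviation should land close to 10 and 2.
print(statistics.fmean(samples), statistics.stdev(samples))
```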
def factor(score):
    '''
    Adds a real-valued `score` (i.e., log-probability) to the weight of the
    current trace.
    '''
    _factor_dist.observe(score)
Adds a real-valued score (i.e., log-probability) to the weight of the
current trace.
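Because the score is added in log space, it multiplies a trace's unnormalized probability by `exp(score)`. The following standard-library sketch shows the effect on a toy model with two traces; the trace names, the dictionary of log weights, and `normalize_log_weights` are all hypothetical and only stand in for whatever bookkeeping an inference algorithm does.

```python
import math


def normalize_log_weights(log_weights):
    """Convert per-trace log weights into normalized probabilities."""
    # Subtract the largest log weight before exponentiating, for stability.
    shift = max(log_weights.values())
    unnormalized = {k: math.exp(w - shift) for k, w in log_weights.items()}
    total = sum(unnormalized.values())
    return {k: v / total for k, v in unnormalized.items()}


# Two traces with equal prior probability 0.5. Assume the model adds a
# score of log(3) on trace "a" only, making it three times as likely as "b".
log_weights = {
    "a": math.log(0.5) + math.log(3),
    "b": math.log(0.5),
}
probs = normalize_log_weights(log_weights)
print(probs)  # "a" is close to 0.75 and "b" is close to 0.25
```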
def condition(cond: float):
    '''
    Used for conditioning statements. When `cond` is a boolean, this behaves like
    typical conditioning.

    - `cond` is a non-negative multiplicative weight for the conditioning. When zero,
    the trace is assigned zero probability.
    '''
    if cond == 0:
        _factor_dist.observe(-float("inf"))
    else:
        _factor_dist.observe(math.log(cond))
Used for conditioning statements. When cond is a boolean, this behaves like
typical conditioning.
- cond is a non-negative multiplicative weight for the conditioning. When zero, the trace is assigned zero probability.
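The source shown for `condition` maps its argument to a log score: zero becomes negative infinity and anything else becomes `math.log(cond)`. The sketch below isolates that mapping as a pure function so the boolean and soft-weight cases can be compared side by side; `condition_to_score` is a hypothetical name and the function observes nothing.

```python
import math


def condition_to_score(cond):
    """Log score that a multiplicative weight `cond` contributes to a trace."""
    if cond == 0:
        # False (which equals 0) and 0.0 rule the trace out entirely.
        return -math.inf
    return math.log(cond)


print(condition_to_score(True))   # 0.0, the trace weight is unchanged
print(condition_to_score(False))  # -inf, the trace gets zero probability
print(condition_to_score(0.5))    # log(0.5), the trace weight is halved
```

A boolean condition is therefore the special case where the weight is either 1 (keep the trace as is) or 0 (discard it).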
def keep_deterministic(fn: Callable[..., R]) -> Callable[..., R]:
    '''
    Decorator to interpret a function as deterministic Python.
    Any random sampling in the function will not be targeted for inference.
    Any conditioning will not be incorporated into the inference process.

    This is helpful if the transform slows a function down, if a
    deterministic library is being called, or if a distribution is being
    directly computed.
    '''
    return KeepDeterministicCallable(fn)
Decorator to interpret a function as deterministic Python. Any random sampling in the function will not be targeted for inference. Any conditioning will not be incorporated into the inference process.
This is helpful if the transform slows a function down, if a deterministic library is being called, or if a distribution is being directly computed.
def mem(fn: Callable[..., Element]) -> Callable[..., Element]:
    '''
    Turns a function into a stochastically memoized function.
    Stores information in trace-specific storage.
    '''
    def mem_wrapper(*args, **kws):
        key = (fn, args, tuple(sorted(kws.items())))
        kws = hashabledict(kws)
        if key in global_store:
            return global_store.get(key)
        else:
            value = fn(*args, **kws)
            global_store.set(key, value)
            return value
    return mem_wrapper
Turns a function into a stochastically memoized function. Stores information in trace-specific storage.
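Stochastic memoization means a random function returns the same value every time it is called with the same arguments within one trace, while a different trace may get a different value. The sketch below imitates that behavior with a plain dictionary that is created fresh for each trace. It is an assumption-laden illustration: `make_mem`, `run_trace`, and `eye_color` are hypothetical names, and the dictionary only stands in for FlipPy's trace-specific storage.

```python
import random


def make_mem(store):
    """Build a memoizing decorator backed by `store`, a per-trace dict."""
    def mem(fn):
        def wrapper(*args, **kws):
            # Same key layout as the source above: function, args, sorted kwargs.
            key = (fn, args, tuple(sorted(kws.items())))
            if key not in store:
                store[key] = fn(*args, **kws)
            return store[key]
        return wrapper
    return mem


def run_trace(seed):
    """Simulate one trace with its own random stream and its own store."""
    rng = random.Random(seed)
    store = {}  # fresh storage, so nothing leaks between traces
    mem = make_mem(store)

    @mem
    def eye_color(person):
        return rng.choice(["brown", "blue", "green"])

    # The repeated call for "alice" reuses the stored value.
    return eye_color("alice"), eye_color("alice"), eye_color("bob")


first, second, other = run_trace(seed=0)
print(first == second)  # True
```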