flippy

FlipPy lets you specify probabilistic programs in Python syntax while seamlessly interacting with the rest of Python.

Quick start

pip install flippy-lang

Example: Sum of Bernoullis

Here, we flip two coins (heads = 1, tails = 0), condition on them being equal, and return their sum.

from flippy import infer, flip, condition

@infer
def fn():
    x = flip(0.5)
    y = flip(0.5)
    condition(x == y)
    return x + y

result = fn()
dict(result) # {0: 0.5, 2: 0.5}
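
To make the semantics of this example concrete, here is a plain-Python sketch that does not use FlipPy. It enumerates all four coin outcomes, drops the ones ruled out by the condition, and renormalizes what remains. It assumes that exact enumeration is what the default `Enumeration` method computes for a model this small; it is an illustration, not FlipPy's implementation.

```python
from itertools import product


def enumerate_sum_of_coins():
    """Exact posterior over x + y for two fair coins, conditioned on x == y."""
    weights = {}
    # Each coin is 0 (tails) or 1 (heads) with probability 1/2, so every
    # joint outcome has prior probability 1/4.
    for x, y in product([0, 1], repeat=2):
        prior = 0.5 * 0.5
        # condition(x == y): outcomes that violate the condition get weight 0,
        # so they are simply skipped here.
        if x != y:
            continue
        weights[x + y] = weights.get(x + y, 0.0) + prior
    # Renormalize the surviving weight into a probability distribution.
    total = sum(weights.values())
    return {value: w / total for value, w in weights.items()}


print(enumerate_sum_of_coins())  # {0: 0.5, 2: 0.5}
```
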

Documentation

Here is the documentation for writing models in FlipPy.

  • The core API for declaring a model (see API below)
  • Specifying distributions (flippy.distributions)
  • Selecting inference algorithms (flippy.inference)

Tutorials

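Launch an interactive environment with the tutorial notebooks on Binder (https://mybinder.org/v2/gh/codec-lab/flippy-tutorials/main?urlpath=%2Fdoc%2Ftree%2Fnotebooks%2F00-intro.ipynb), or clone the tutorial repo (https://github.com/codec-lab/flippy-tutorials) and run the notebooks locally. Statically rendered versions of the notebooks are also available:

  • Introductory tutorial: https://codec-lab.github.io/flippy-tutorials/
  • Rational Speech Acts (RSA): https://codec-lab.github.io/flippy-tutorials/01-RSA
  • Language of Thought (LoT): https://codec-lab.github.io/flippy-tutorials/02-LoT
  • Hidden Markov Models (HMMs): https://codec-lab.github.io/flippy-tutorials/03-HMMs
  • Bayesian Non-parametrics: https://codec-lab.github.io/flippy-tutorials/04-DP-MM
  • Intuitive Physics: https://codec-lab.github.io/flippy-tutorials/05-Physics
  • Sequential Decision-Making: https://codec-lab.github.io/flippy-tutorials/06-Sequential-DM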

API

'''
FlipPy lets you specify probabilistic programs in Python syntax
while seamlessly interacting with the rest of Python.

# Quick start

```bash
pip install flippy-lang
```

# Example: Sum of Bernoullis

Here, we flip two coins (heads = 1, tails = 0), condition on them being equal,
and return their sum.

```python
from flippy import infer, flip, condition

@infer
def fn():
    x = flip(0.5)
    y = flip(0.5)
    condition(x == y)
    return x + y

result = fn()
dict(result) # {0: 0.5, 2: 0.5}
```

# Documentation

Here is the documentation for writing models in FlipPy.
- The core API for declaring a model ([link](#api))
- Specifying distributions ([link](flippy/distributions))
- Selecting inference algorithms ([link](flippy/inference))

# Tutorials
Click [here](https://mybinder.org/v2/gh/codec-lab/flippy-tutorials/main?urlpath=%2Fdoc%2Ftree%2Fnotebooks%2F00-intro.ipynb)
to launch an interactive environment with tutorial notebooks, or clone the
tutorial repo from [here](https://github.com/codec-lab/flippy-tutorials) and run the notebooks locally.
Statically rendered versions of the notebooks are also available:

- [Introductory tutorial](https://codec-lab.github.io/flippy-tutorials/)
- [Rational Speech Acts (RSA)](https://codec-lab.github.io/flippy-tutorials/01-RSA)
- [Language of Thought (LoT)](https://codec-lab.github.io/flippy-tutorials/02-LoT)
- [Hidden Markov Models (HMMs)](https://codec-lab.github.io/flippy-tutorials/03-HMMs)
- [Bayesian Non-parametrics](https://codec-lab.github.io/flippy-tutorials/04-DP-MM)
- [Intuitive Physics](https://codec-lab.github.io/flippy-tutorials/05-Physics)
- [Sequential Decision-Making](https://codec-lab.github.io/flippy-tutorials/06-Sequential-DM)

# API

'''

import math
from typing import Callable, Sequence, Union, TypeVar, overload, Generic

from flippy.transforms import CPSTransform
from flippy.inference import \
    SimpleEnumeration, Enumeration, SamplePrior, MetropolisHastings, \
    LikelihoodWeighting, InferenceAlgorithm
from flippy.distributions import Categorical, Bernoulli, Distribution, Uniform,\
    Element, Normal
from flippy.distributions.random import default_rng
from flippy.distributions.base import _factor_dist
from flippy.core import global_store
from flippy.hashable import hashabledict
from flippy.map import recursive_map
from flippy.tools import LRUCache

from flippy.interpreter import CPSInterpreter, keep_deterministic, \
    cps_transform_safe_decorator, DescriptorMixIn

__all__ = [
    'infer',

    'flip',
    'draw_from',
    'uniform',
    'normal',

    'factor',
    'condition',
    # 'map_observe',
    'keep_deterministic',
    'mem',

    # submodules

    # Model specification
    'distributions',
    # Inference algorithms
    'inference',

    # Execution model
    'core',
    'callentryexit',
    # 'map',
]

class InferCallable(Generic[Element], DescriptorMixIn):
    '''
    @private
    '''
    def __init__(
        self,
        func: Callable[..., Element],
        method : Union[type[InferenceAlgorithm], str] = "Enumeration",
        cache_size=0,
        **kwargs
    ):
        DescriptorMixIn.__init__(self, func)

        if isinstance(method, str):
            method : type[InferenceAlgorithm] = {
                'Enumeration': Enumeration,
                'SimpleEnumeration': SimpleEnumeration,
                'SamplePrior': SamplePrior,
                'MetropolisHastings': MetropolisHastings,
                'LikelihoodWeighting' : LikelihoodWeighting
            }[method]
        self.cache_size = cache_size
        self.cache = LRUCache(cache_size)
        self.method = method
        self.kwargs = kwargs
        self.func = func
        self.inference_alg = None
        setattr(self, CPSTransform.is_transformed_property, True)

    def _lazy_init(self):
        if self.inference_alg is not None:
            return
        func = self.func
        if isinstance(func, (classmethod, staticmethod)):
            func = func.__func__
        if not CPSTransform.is_transformed(func):
            func = CPSInterpreter().non_cps_callable_to_cps_callable(func)
        self.inference_alg = self.method(func, **self.kwargs)

        if not self.inference_alg.is_cachable:
            self.cache_size = 0

    def __call__(self, *args, _cont=None, _cps=None, _stack=None, **kws) -> Distribution[Element]:
        self._lazy_init()
        if self.cache_size > 0:
            kws_tuple = tuple(sorted(kws.items()))
            if (args, kws_tuple) in self.cache:
                dist = self.cache[args, kws_tuple]
            else:
                dist = self.inference_alg.run(*args, **kws)
                self.cache[args, kws_tuple] = dist
        else:
            dist = self.inference_alg.run(*args, **kws)
        if _cont is None:
            return dist
        else:
            return lambda : _cont(dist)

def infer(
    func: Callable[..., Element]=None,
    method=Enumeration,
    cache_size=1024,
    **kwargs
) -> InferCallable[Element]:
    '''
    Turns a function into a stochastic function that represents a posterior distribution.

    This is the main interface for performing inference in FlipPy.

    - `method` specifies the inference method and can be either an
    `InferenceAlgorithm` subclass or one of the strings `"Enumeration"`,
    `"SimpleEnumeration"`, `"SamplePrior"`, `"MetropolisHastings"`, or
    `"LikelihoodWeighting"`. Defaults to `Enumeration`.
    - `cache_size` is the maximum number of argument combinations whose
    distributions are cached; set it to 0 to disable caching.
    - `**kwargs` are keyword arguments passed to the inference method.
    '''
    return InferCallable(func, method, cache_size, **kwargs)
infer = cps_transform_safe_decorator(infer)

# type hints for infer - if we can use ParamSpecs this will be cleaner
InferenceType = Callable[[Callable[..., Element]], InferCallable[Element]]
infer : Callable[..., Union[InferCallable, InferenceType]]

def recursive_filter(fn, iter):
    if not iter:
        return []
    if fn(iter[0]):
        head = [iter[0]]
    else:
        head = []
    return head + recursive_filter(fn, iter[1:])

def recursive_reduce(fn, iter, initializer):
    if len(iter) == 0:
        return initializer
    return recursive_reduce(fn, iter[1:], fn(initializer, iter[0]))

def factor(score):
    '''
    Adds a real-valued `score` (i.e., log-probability) to the weight of the
    current trace.
    '''
    _factor_dist.observe(score)

def condition(cond: float):
    '''
    Used for conditioning statements. When `cond` is a boolean, this behaves like
    typical conditioning.

    - `cond` is a non-negative multiplicative weight for the conditioning. When zero,
        the trace is assigned zero probability.
    '''
    if cond == 0:
        _factor_dist.observe(-float("inf"))
    else:
        _factor_dist.observe(math.log(cond))

def flip(p=.5, name=None):
    '''
    Samples from a Bernoulli distribution with success probability `p` and
    returns the outcome as a `bool`.
    '''
    return bool(Bernoulli(p).sample(name=name))

@keep_deterministic
def _draw_from_dist(n: Union[Sequence[Element], int]) -> Distribution[Element]:
    if isinstance(n, int):
        return Categorical(range(n))
    if hasattr(n, '__getitem__'):
        return Categorical(n)
    else:
        return Categorical(list(n))

@overload
def draw_from(n: int) -> int:
    ...
@overload
def draw_from(n: Sequence[Element]) -> Element:
    ...
def draw_from(n: Union[Sequence[Element], int]) -> Element:
    '''
    Samples uniformly from `n` when it is a sequence.
    When `n` is an integer, a sample is drawn from `range(n)`.
    '''
    return _draw_from_dist(n).sample()

def mem(fn: Callable[..., Element]) -> Callable[..., Element]:
    '''
    Turns a function into a stochastically memoized function.
    Stores information in trace-specific storage.
    '''
    def mem_wrapper(*args, **kws):
        key = (fn, args, tuple(sorted(kws.items())))
        kws = hashabledict(kws)
        if key in global_store:
            return global_store.get(key)
        else:
            value = fn(*args, **kws)
            global_store.set(key, value)
            return value
    return mem_wrapper
mem = cps_transform_safe_decorator(mem)

_uniform = Uniform()
def uniform():
    '''
    Samples from a uniform distribution over the interval $[0, 1]$.
    '''
    return _uniform.sample()

_normal = Normal(0, 1)
def normal(mean=0, std=1):
    '''
    Samples from a normal distribution with mean `mean` and standard
    deviation `std` (a standard normal distribution by default).
    '''
    return mean + std * _normal.sample()

@keep_deterministic
def map_log_probability(distribution: Distribution[Element], values: Sequence[Element]) -> float:
    return sum(distribution.log_probability(i) for i in values)

def map_observe(distribution: Distribution[Element], values: Sequence[Element]) -> float:
    """
    Calculates the total log probability of a sequence of independent values
    under `distribution`, adds it to the current trace's weight via `factor`,
    and returns it.
    """
    log_prob = map_log_probability(distribution, values)
    factor(log_prob)
    return log_prob
def infer( func: Callable[..., ~Element] = None, method=<class 'flippy.inference.Enumeration'>, cache_size=1024, **kwargs) -> flippy.InferCallable[~Element]:

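Turns a function into a stochastic function that represents a posterior distribution.

This is the main interface for performing inference in FlipPy.

  • method specifies the inference method and can be either an InferenceAlgorithm subclass or one of the strings "Enumeration", "SimpleEnumeration", "SamplePrior", "MetropolisHastings", or "LikelihoodWeighting". Defaults to Enumeration.
  • cache_size is the maximum number of argument combinations whose distributions are cached; set it to 0 to disable caching.
  • **kwargs are keyword arguments passed to the inference method.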
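
The source listing shows that `infer` also takes a `cache_size` argument (default 1024). It also shows that `InferCallable.__call__` reuses a previously computed distribution when called again with the same arguments, keyed on the positional arguments plus the sorted keyword items. Caching is switched off when the algorithm reports `is_cachable` as false. The plain-Python sketch below illustrates that keying scheme. `LRUDistributionCache` and `run_inference` are hypothetical names. The least-recently-used eviction policy is an assumption based on the name of FlipPy's `LRUCache`, not something this page documents.

```python
from collections import OrderedDict


class LRUDistributionCache:
    """Hypothetical stand-in for the per-function cache behind infer(..., cache_size=...)."""

    def __init__(self, max_size):
        self.max_size = max_size
        self._store = OrderedDict()

    @staticmethod
    def key(args, kwargs):
        # Same keying as InferCallable.__call__: positional args plus sorted
        # keyword items, so f(a=1, b=2) and f(b=2, a=1) share one entry.
        return (args, tuple(sorted(kwargs.items())))

    def get_or_compute(self, compute, *args, **kwargs):
        if self.max_size <= 0:
            # A size of 0 disables caching: always recompute.
            return compute(*args, **kwargs)
        k = self.key(args, kwargs)
        if k in self._store:
            self._store.move_to_end(k)  # mark as most recently used
            return self._store[k]
        value = compute(*args, **kwargs)
        self._store[k] = value
        if len(self._store) > self.max_size:
            self._store.popitem(last=False)  # evict the least recently used entry
        return value


calls = []


def run_inference(n, *, a=0, b=0):
    # Stands in for an expensive call such as inference_alg.run(...).
    calls.append((n, a, b))
    return {i: 1 / n for i in range(n)}


cache = LRUDistributionCache(max_size=2)
cache.get_or_compute(run_inference, 3, a=1, b=2)
cache.get_or_compute(run_inference, 3, b=2, a=1)  # same key, served from the cache
print(len(calls))  # 1
```

Sorting the keyword items makes the key independent of the order in which keyword arguments are written. Like FlipPy's own key, this requires every argument value to be hashable.
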
def flip(p=0.5, name=None):

Samples from a Bernoulli distribution with success probability p and returns the outcome as a bool.
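
Because `flip` converts the result of `Bernoulli(p).sample(...)` to a `bool`, `flip(p)` returns `True` with probability `p`. The plain-Python sketch below reproduces that behaviour outside of FlipPy. `flip_sketch` and its `rng` parameter are hypothetical. Inside an `@infer` model you would use FlipPy's own `flip`, since a draw from Python's `random` module is presumably not visible to the inference algorithm.

```python
import random


def flip_sketch(p=0.5, rng=random):
    """Return True with probability p (hypothetical stand-in for flip)."""
    if not 0.0 <= p <= 1.0:
        raise ValueError("p must be in [0, 1]")
    # rng.random() is uniform on [0, 1), so the event rng.random() < p has probability p.
    return rng.random() < p


rng = random.Random(0)
draws = [flip_sketch(0.3, rng) for _ in range(10_000)]
print(sum(draws) / len(draws))  # close to 0.3
```
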

def draw_from(n: Union[Sequence[~Element], int]) -> ~Element:

Samples uniformly from n when it is a sequence. When n is an integer, a sample is drawn from range(n).
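
The dispatch in `_draw_from_dist` handles three cases. An integer `n` is treated as `range(n)`, anything indexable is passed straight to `Categorical`, and other iterables (such as sets) are materialized with `list(n)`. The plain-Python sketch below mirrors that dispatch, with `random.choice` standing in for sampling from `Categorical`. Both helper names are hypothetical. The sketch is uniform over positions, so it says nothing about how FlipPy's `Categorical` treats repeated elements.

```python
import random


def draw_from_support(n):
    """Mirror the dispatch in _draw_from_dist (hypothetical helper)."""
    if isinstance(n, int):
        return range(n)  # an integer means "draw from 0, 1, ..., n - 1"
    if hasattr(n, "__getitem__"):
        return n  # indexable sequences are used as-is
    return list(n)  # other iterables, e.g. sets, are materialized first


def draw_from_sketch(n, rng=random):
    # Uniform choice over the positions of the support.
    return rng.choice(draw_from_support(n))


rng = random.Random(1)
print(draw_from_sketch(6, rng))
print(draw_from_sketch(["red", "green", "blue"], rng))
print(draw_from_sketch({"heads", "tails"}, rng))
```
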

def uniform():

Samples from a uniform distribution over the interval $[0, 1]$.

def normal(mean=0, std=1):

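Samples from a normal distribution whose mean and standard deviation are given by the mean and std arguments (a standard normal distribution by default).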
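
The source computes `mean + std * _normal.sample()` with `_normal = Normal(0, 1)`. By the location-scale property of the Gaussian, the result has mean `mean` and standard deviation `std`. The plain-Python check below uses the standard library's `random.gauss` in place of FlipPy's sampler; `normal_sketch` is a hypothetical helper.

```python
import random
import statistics


def normal_sketch(mean=0.0, std=1.0, rng=random):
    """Draw z from a standard normal, then shift and scale it (hypothetical helper)."""
    z = rng.gauss(0.0, 1.0)
    # If z ~ Normal(0, 1), then mean + std * z ~ Normal(mean, std**2).
    return mean + std * z


rng = random.Random(42)
samples = [normal_sketch(10.0, 2.0, rng) for _ in range(20_000)]
print(statistics.fmean(samples), statistics.stdev(samples))  # near 10 and 2
```
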

def factor(score):

Adds a real-valued score (i.e., log-probability) to the weight of the current trace.
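
Because `factor` adds a log-probability to the trace weight, multiplying a trace's probability by $w$ corresponds to `factor(math.log(w))`. The plain-Python sketch below does not use FlipPy. It enumerates the two traces of a single fair flip, adds a score of $\log 3$ to the `True` trace, and normalizes, which triples the relative weight of `True`. The function name is hypothetical.

```python
import math


def posterior_with_factor(score_if_true):
    """Exact posterior of x = flip(0.5) after factor(score_if_true if x else 0.0)."""
    log_weights = {}
    for x in (False, True):
        log_prior = math.log(0.5)
        score = score_if_true if x else 0.0
        # factor adds the score to the trace's log weight.
        log_weights[x] = log_prior + score
    # Normalize with log-sum-exp for numerical stability.
    m = max(log_weights.values())
    log_z = m + math.log(sum(math.exp(lw - m) for lw in log_weights.values()))
    return {x: math.exp(lw - log_z) for x, lw in log_weights.items()}


print(posterior_with_factor(math.log(3)))  # approximately {False: 0.25, True: 0.75}
```
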

def condition(cond: float):

Used for conditioning statements. When cond is a boolean, this behaves like typical conditioning.

  • cond is a non-negative multiplicative weight for the conditioning. When zero, the trace is assigned zero probability.
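
The source implements `condition(cond)` by observing `math.log(cond)` on the same factor distribution that `factor` uses, and maps `cond == 0` to a log weight of negative infinity. Since `True == 1` and `False == 0` in Python, a boolean gives the usual hard conditioning. A value such as 0.9 instead down-weights a trace without removing it. The helper `condition_log_weight` below is a hypothetical plain-Python mirror of that logic, applied by hand to a soft version of the two-coin example.

```python
import math


def condition_log_weight(cond):
    """Log weight that condition(cond) adds to a trace, following the source listing."""
    if cond == 0:
        return -math.inf  # zero weight: the trace is ruled out
    return math.log(cond)  # True == 1, so math.log(True) == 0.0


print(condition_log_weight(True), condition_log_weight(False))  # 0.0 -inf

# Soft version of the two-coin example: instead of discarding traces where
# x != y, down-weight them by 0.1 and weight matching traces by 0.9.
weights = {}
for x in (0, 1):
    for y in (0, 1):
        log_w = math.log(0.25) + condition_log_weight(0.9 if x == y else 0.1)
        weights[x + y] = weights.get(x + y, 0.0) + math.exp(log_w)
total = sum(weights.values())
print({k: round(v / total, 3) for k, v in weights.items()})  # {0: 0.45, 1: 0.1, 2: 0.45}
```
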
def keep_deterministic(fn: Callable[..., ~R]) -> Callable[..., ~R]:
def keep_deterministic(fn: Callable[..., R]) -> Callable[..., R]:
    '''
    Decorator to interpret a function as deterministic Python.
    Any random sampling in the function will not be targeted for inference.
    Any conditioning will not be incorporated into the inference process.

    This is helpful if the transform slows a function down, if a
    deterministic library is being called, or if a distribution is being
    directly computed.
    '''
    return KeepDeterministicCallable(fn)

Decorator to interpret a function as deterministic Python. Any random sampling in the function will not be targeted for inference. Any conditioning will not be incorporated into the inference process.

This is helpful if the CPS transform slows a function down, if a deterministic library is being called, or if a distribution is being computed directly.

def mem(fn: Callable[..., ~Element]) -> Callable[..., ~Element]:

Turns a function into a stochastically memoized function. Stores information in trace-specific storage.
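
`mem` caches results in trace-specific storage (`global_store`). Within one execution, a memoized random function therefore returns the same value for the same arguments, while a different trace may draw a different value. The sketch below uses an ordinary dict per trace as a hypothetical stand-in for `global_store`, keyed the same way as `mem_wrapper` in the source: function, positional arguments, and sorted keyword items. `mem_sketch` and `eye_color` are illustrative names, not FlipPy API.

```python
import random


def mem_sketch(fn, store):
    """Memoize fn into `store`, keyed like FlipPy's mem: (fn, args, sorted kwargs)."""
    def wrapper(*args, **kwargs):
        key = (fn, args, tuple(sorted(kwargs.items())))
        if key not in store:
            store[key] = fn(*args, **kwargs)
        return store[key]
    return wrapper


rng = random.Random(0)


def eye_color(person):
    # `person` only serves as part of the memoization key.
    return rng.choice(["blue", "brown", "green"])


# One dict per trace stands in for FlipPy's trace-specific storage.
trace_a = {}
color = mem_sketch(eye_color, trace_a)
print(color("alice") == color("alice"))  # True: same trace, same argument
```
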