Bean Machine Retrospective, part 5

Let’s take another look at the “hello world” example and think more carefully about what is actually going on:

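from torch import tensor
from torch.distributions import Bernoulli, Beta
from beanmachine.ppl import random_variable
from beanmachine.ppl.inference.bmg_inference import BMGInference

@random_variable
def fairness():
  return Beta(2,2)

@random_variable
def flip(n):
  return Bernoulli(fairness())

heads = tensor(1)
tails = tensor(0)

observations = { flip(0) : heads, ... flip(9): tails }
queries = [ fairness() ]
num_samples = 10000
results = BMGInference().infer(queries, observations, num_samples)
samples = results[fairness()]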

There’s a lot going on here. Let’s start by clearing up what a call to a random variable function actually returns.

It sure looks like fairness() returns an instance of a PyTorch distribution object, Beta(2,2), but the model behaves as though it returned a sample from that distribution when the result is passed to Bernoulli. What’s going on?

The call doesn’t return either. It returns a random variable identifier, which I will abbreviate as RVID. An RVID is essentially a tuple of the original function and all the arguments. (If you think that sounds like a key to a memoization dictionary, you’re right!)
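To see why a memoization key explains the “behaves like a sample” puzzle, here is a small self-contained sketch. It is not Bean Machine’s implementation: RVID, World, and the stand-in fairness and flip functions are hypothetical, and Python’s random module replaces PyTorch. The idea it illustrates is that if every random variable is looked up through its (function, arguments) key, each one gets exactly one value per sampled “world”, so every flip sees the same fairness.

```python
import random
from typing import Any, Callable, Dict, NamedTuple, Tuple


class RVID(NamedTuple):
    # Hypothetical simplified identifier: the original function plus its arguments.
    function: Callable[..., Any]
    arguments: Tuple[Any, ...]


class World:
    """Hypothetical memo table: one sampled value per random variable."""

    def __init__(self, seed: int = 0) -> None:
        self.rng = random.Random(seed)
        self.values: Dict[RVID, Any] = {}

    def value_of(self, rvid: RVID) -> Any:
        # Sample each random variable at most once; later lookups reuse the value.
        if rvid not in self.values:
            draw = rvid.function(self, *rvid.arguments)
            self.values[rvid] = draw(self.rng)
        return self.values[rvid]


def fairness(world: World) -> Callable[[random.Random], float]:
    # Stand-in for Beta(2, 2): a function that draws one sample from it.
    return lambda rng: rng.betavariate(2, 2)


def flip(world: World, n: int) -> Callable[[random.Random], int]:
    # Stand-in for Bernoulli(fairness()): the bias is the memoized value of
    # fairness(), so every flip in this world uses the same bias.
    bias = world.value_of(RVID(fairness, ()))
    return lambda rng: 1 if rng.random() < bias else 0


world = World(seed=42)
outcomes = [world.value_of(RVID(flip, (n,))) for n in range(10)]
# world.values now holds eleven entries: fairness() plus the ten flips.
```

Because RVID(flip, (3,)) == RVID(flip, (3,)) and both hash the same, any code holding an RVID can find “its” value again, which is exactly what makes the identifier useful as a dictionary key.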

This is an oversimplification, but you can imagine for now that it works like this:

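from collections import namedtuple
from torch.distributions import Beta

# Simplified stand-in: an RVID is just the original function plus its arguments.
RVID = namedtuple("RVID", ["function", "arguments"])

def random_variable(original_function):
  def new_function(*args):
    return RVID(original_function, args)
  # The actual implementation stashes away useful information
  # and has logic needed for other Bean Machine inference algorithms
  # but for our purposes, pretend it just returns an RVID.
  return new_function

def fairness():
  return Beta(2,2)
fairness = random_variable(fairness)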

The decorator lets us manipulate function calls which represent random variables as values! Now it should be clear how

queries = [ fairness() ]

works; what we’re really doing here is

queries = [ RVID(fairness, ()) ]
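Here is a runnable sketch of that equivalence, assuming the same oversimplified decorator as above. To keep it free of PyTorch, the bodies return placeholder strings and plain ints stand in for tensor(1) and tensor(0); the partial observations dictionary is illustrative only. functools.wraps is used so the original function stays reachable as fairness.__wrapped__, which lets us spell out RVID(fairness, ()) explicitly.

```python
import functools
from typing import Any, Callable, NamedTuple, Tuple


class RVID(NamedTuple):
    # Simplified stand-in for Bean Machine's identifier: function plus arguments.
    function: Callable[..., Any]
    arguments: Tuple[Any, ...]


def random_variable(original_function):
    # Oversimplified decorator: a call builds an RVID rather than running the
    # body. functools.wraps also records the original as __wrapped__.
    @functools.wraps(original_function)
    def new_function(*args):
        return RVID(original_function, args)
    return new_function


@random_variable
def fairness():
    return "Beta(2,2)"  # placeholder string; never executed in this demo


@random_variable
def flip(n):
    return "Bernoulli(fairness())"  # placeholder string; never executed


heads, tails = 1, 0  # plain ints stand in for tensor(1) and tensor(0)
observations = {flip(0): heads, flip(9): tails}  # partial, for illustration
queries = [fairness()]

# Fresh calls rebuild equal RVIDs, so they find the existing entries.
print(observations[flip(0)], flip(5) in observations)  # prints: 1 False
```

The same property is what lets the final line of the “hello world” program, results[fairness()], find the samples for the query: a fresh call produces an RVID equal to the one in the queries list.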

That clears up how it is that we treat calls to random variables as unique values. What about inference?

Leaving aside the way the decorators make calls to random variables produce RVIDs, our “hello world” program acts just like any other Python program right up to here:

results = BMGInference().infer(queries, observations, num_samples)

The queries argument is a list of RVIDs, and observations is a dictionary mapping RVIDs onto their observed values. Plainly, infer causes a miracle to happen: it returns a dictionary mapping each queried RVID onto a tensor of num_samples values that are plausible samples from the posterior distribution of that random variable.
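To make that contract concrete for this one model, here is a hedged sketch of what the result for fairness() should look like. It does not reflect how BMG computes anything: it exploits the textbook fact that a Beta(2, 2) prior with Bernoulli observations has the closed-form posterior Beta(2 + heads, 2 + tails), which a general-purpose inference engine cannot assume. The observed flips are made up (the text elides the middle ones), a string key stands in for the RVID, and Python lists stand in for tensors.

```python
import random
from typing import Dict, List


def posterior_fairness_samples(flips: List[int], num_samples: int, seed: int = 0) -> List[float]:
    # Closed-form shortcut for this one model only: with a Beta(2, 2) prior and
    # Bernoulli observations, the posterior of fairness is Beta(2 + heads, 2 + tails).
    heads = sum(flips)
    tails = len(flips) - heads
    rng = random.Random(seed)
    return [rng.betavariate(2 + heads, 2 + tails) for _ in range(num_samples)]


# Hypothetical data: flip(0) is heads and flip(9) is tails, as in the example;
# the flips in between are invented. Total: 7 heads, 3 tails.
observed_flips = [1, 1, 1, 0, 1, 1, 0, 1, 1, 0]
num_samples = 10000

# The string key stands in for the RVID produced by fairness().
results: Dict[str, List[float]] = {
    "fairness()": posterior_fairness_samples(observed_flips, num_samples)
}
samples = results["fairness()"]
posterior_mean = sum(samples) / len(samples)
# The exact posterior here is Beta(9, 5), whose mean is 9 / 14, about 0.643.
```

Whatever machinery infer uses internally, its output for this model should be statistically indistinguishable from draws like these: num_samples plausible values of fairness, concentrated around the posterior mean.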

Of course it is no miracle. We take the following steps:

  • Transform the source code of each queried or observed function (and transitively their callees) into an equivalent program which partially evaluates the model, accumulating a graph as it goes
  • Execute the transformed code to accumulate the graph
  • Transform the accumulated graph into a valid BMG graph
  • Perform inference on the graph in BMG
  • Marshal the samples returned into the dictionary data structure expected by the user
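To give a feel for the first two steps, here is a minimal sketch under loud assumptions. It skips the source transformation entirely and instead writes fairness and flip by hand in an already “transformed” style: each function receives a hypothetical GraphBuilder and emits graph nodes instead of constructing distributions. The Node and GraphBuilder types, the node labels, and the two made-up observations are invented for illustration; they are not BMG’s actual graph representation.

```python
from dataclasses import dataclass, field
from typing import Any, Callable, Dict, List, Tuple


@dataclass
class Node:
    # Hypothetical graph node: an operation label plus its input nodes.
    label: str
    inputs: Tuple["Node", ...] = ()


@dataclass
class GraphBuilder:
    """Hypothetical sketch: run the model, accumulating nodes instead of sampling."""

    nodes: List[Node] = field(default_factory=list)
    memo: Dict[Tuple[Callable[..., Any], Tuple[Any, ...]], Node] = field(default_factory=dict)

    def add(self, label: str, *inputs: Node) -> Node:
        node = Node(label, inputs)
        self.nodes.append(node)
        return node

    def sample_of(self, function: Callable[..., Any], *args: Any) -> Node:
        # The (function, arguments) pair plays the role of an RVID: each random
        # variable gets exactly one Sample node no matter how often it is called.
        key = (function, args)
        if key not in self.memo:
            distribution = function(self, *args)
            self.memo[key] = self.add("Sample", distribution)
        return self.memo[key]


def fairness(g: GraphBuilder) -> Node:
    # Instead of constructing Beta(2, 2), emit a Beta node with constant inputs.
    return g.add("Beta", g.add("Const 2.0"), g.add("Const 2.0"))


def flip(g: GraphBuilder, n: int) -> Node:
    # Bernoulli's parameter is the shared Sample node for fairness().
    return g.add("Bernoulli", g.sample_of(fairness))


g = GraphBuilder()
observations = {0: 1, 9: 0}  # made-up subset: flip(0) is heads, flip(9) is tails
for n, value in observations.items():
    g.add(f"Observe {value}", g.sample_of(flip, n))
query = g.add("Query", g.sample_of(fairness))

print([node.label for node in g.nodes])
```

Because sample_of is memoized on (function, arguments), both Bernoulli nodes end up pointing at the single Sample node for fairness(), mirroring the memoization role that the RVID plays in the description above.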

Coming up on FAIC: we will look at how we implemented each of those steps.
