Approximating non-Function Mappings with Mixture Density Networks
- Author: Luke Wood
Neural networks are universal function approximators. Key word: function! While powerful function approximators, neural networks are not able to approximate non-functions. One important restriction to remember about functions: each input maps to exactly one output! Neural networks suffer greatly when the training set has multiple values of Y for a single X.
In this guide I'll show you how to approximate the class of non-functions consisting of mappings from x -> y such that multiple y may exist for a given x. We'll use a class of neural networks called "Mixture Density Networks".
I'm going to use the new multi-backend Keras Core project to build my Mixture Density Networks. Great job to the Keras team on the project - it's awesome to be able to swap frameworks in one line of code.
Some bad news: I use TensorFlow Probability in this guide... so it doesn't actually work with other backends.
Anyways, let's start by installing dependencies and sorting out imports:
pip install -q --upgrade tensorflow-probability keras-core
import numpy as np
import matplotlib.pyplot as plt
from collections import OrderedDict
import math
import random
from keras_core import callbacks
import keras_core
import tensorflow as tf
from keras_core import layers
from keras_core import optimizers
from tensorflow_probability import distributions as tfd
Next, let's generate a noisy spiral that we're going to attempt to approximate. I've defined a few functions below to do this:
def normalize(x):
    return (x - np.min(x)) / (np.max(x) - np.min(x))


def create_noisy_spiral(n, max_jitter=0.2, size_range=[5, 10]):
    R_X = random.choice(range(size_range[0], size_range[1], 10))
    R_Y = random.choice(range(size_range[0], size_range[1], 10))
    X_0 = random.random() * 100 - 50
    Y_0 = random.random() * 100 - 50
    randomized_jitter = lambda: random.uniform(-max_jitter / 2, max_jitter / 2)
    x = lambda _t: _t / 360.0 * math.cos(_t / 90.0 * math.pi) * R_X + X_0
    y = lambda _t: _t / 360.0 * math.sin(_t / 90.0 * math.pi) * R_Y + Y_0
    out = np.zeros([n, 2])
    for i in range(n):
        t = 360.0 / n * i
        out[i, 0] = x(t) + randomized_jitter()
        out[i, 1] = y(t) + randomized_jitter()
    out = out.astype("float32")
    return (normalize(out[:, 0]) * 10) - 5, (normalize(out[:, 1]) * 10) - 5
Next, let's invoke this function many times to construct a sample dataset:
xs = []
ys = []
for _ in range(10):
    x, y = create_noisy_spiral(1000)
    xs.append(x)
    ys.append(y)
x = np.concatenate(xs, axis=0)
y = np.concatenate(ys, axis=0)
x = np.expand_dims(x, axis=1)
plt.scatter(x, y)
plt.show()
As you can see, there are multiple possible values of Y for a given X. A normal neural network trained with a squared-error loss will simply learn the mean of these points for each X.
We can quickly show this with a simple feed-forward model:
N_HIDDEN = 128
model = keras_core.Sequential(
    [
        layers.Dense(N_HIDDEN, activation="relu"),
        layers.Dense(N_HIDDEN, activation="relu"),
        # Note: a ReLU here clamps predictions to >= 0, although y spans [-5, 5].
        layers.Dense(1, activation="relu"),
    ]
)
Let's use mean squared error as well as the Adam optimizer. These tend to be reasonable prototyping choices:
model.compile(optimizer="adam", loss="mse")
We can fit this model quite easily:
model.fit(
    x,
    y,
    epochs=300,
    batch_size=128,
    validation_split=0.15,
    callbacks=[callbacks.EarlyStopping(monitor="val_loss", patience=5)],
)
Epoch 14/300
67/67 ━━━━━━━━━━━━━━━━━━━━ 1s 4ms/step - loss: 6.3333 - val_loss: 8.3465
And let's check out the result:
y_pred = model.predict(x)
As expected, the model learns roughly the average of all points in y for a given x.
plt.scatter(x, y)
plt.scatter(x, y_pred)
plt.show()
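If you want to convince yourself why a squared-error loss ends up at the mean, here's a minimal NumPy sketch. The toy targets (-2 and +2 for a single x) are made up for illustration and aren't part of the spiral dataset. It scans a grid of constant predictions and picks the one with the lowest mean squared error:

```python
import numpy as np

# Hypothetical toy data: one x value with two equally likely targets.
y_targets = np.array([-2.0, 2.0])

# Evaluate the mean squared error of every constant prediction on a grid.
candidates = np.linspace(-3.0, 3.0, 601)
mse = ((y_targets[None, :] - candidates[:, None]) ** 2).mean(axis=1)

best = candidates[np.argmin(mse)]
print(f"MSE-optimal constant prediction: {best:.2f}")
print(f"Mean of the targets: {y_targets.mean():.2f}")
```

The best prediction lands between the two targets, on a value that never actually occurs in the data. That's exactly the failure mode in the plot above.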
Mixture Density Networks
Mixture Density Networks can alleviate this problem. A mixture density is a complicated density expressed in terms of simpler densities: effectively a weighted sum of several probability distributions, with weights that sum to one. Mixture Density Networks learn to parameterize a mixture density distribution based on a given training set.
As a practitioner, all you need to know is that Mixture Density Networks solve the problem of multiple values of Y for a given X. I'm hoping to add a tool to your kit, but I'm not going to formally explain the derivation of Mixture Density Networks in this guide. The most important thing to know is that a Mixture Density Network learns to parameterize a mixture density distribution. This is done by computing a special loss with respect to both the provided y_i label and the predicted distribution for the corresponding x_i. This loss is the negative log-likelihood: it measures how probable y_i is under the predicted mixture distribution, so likely labels give a low loss.
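To make that loss a little more concrete, here's a minimal NumPy sketch for a one-dimensional Gaussian mixture. It's an illustration only, not the TensorFlow Probability code used later in this guide, and the helper names (`gaussian_pdf`, `mixture_nll`) and the two-component parameters are my own assumptions:

```python
import numpy as np


def gaussian_pdf(y, mu, sigma):
    """Density of a 1-D normal distribution evaluated at y."""
    return np.exp(-0.5 * ((y - mu) / sigma) ** 2) / (sigma * np.sqrt(2.0 * np.pi))


def mixture_nll(y, pis, mus, sigmas):
    """Negative log-likelihood of a scalar y under a 1-D Gaussian mixture."""
    likelihood = np.sum(pis * gaussian_pdf(y, mus, sigmas))
    return -np.log(likelihood)


# Hypothetical two-component mixture with modes at -2 and +2.
pis = np.array([0.5, 0.5])
mus = np.array([-2.0, 2.0])
sigmas = np.array([0.5, 0.5])

print(mixture_nll(2.0, pis, mus, sigmas))  # label sits on a mode: low loss
print(mixture_nll(0.0, pis, mus, sigmas))  # label sits between the modes: high loss
```

Unlike mean squared error, this loss doesn't reward predicting the midpoint. The mixture is rewarded for putting probability mass where the labels actually are.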
Let's implement a Mixture Density Network. Below, a ton of helper functions are defined based on an old Keras library, Keras Mixture Density Network Layer. I've adapted the code for use with Keras Core.
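Let's start writing a Mixture Density Network! First, we need a special activation function: ELU plus one plus a tiny epsilon. ELU outputs values in (-1, inf), so ELU + 1 can get arbitrarily close to 0; the epsilon keeps the predicted sigmas strictly positive, since a sigma of 0 causes NaNs in Mixture Density Network loss evaluation.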
def elu_plus_one_plus_epsilon(x):
    return keras_core.activations.elu(x) + 1 + keras_core.backend.epsilon()
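Here's a rough NumPy sketch of why the epsilon matters. The `elu` helper below is a hand-written stand-in for the Keras activation, and `EPSILON = 1e-7` is assumed to match the default of `keras_core.backend.epsilon()`:

```python
import numpy as np

EPSILON = 1e-7  # assumed to match the default of keras_core.backend.epsilon()


def elu(x, alpha=1.0):
    """Hand-written ELU: x for x > 0, alpha * (exp(x) - 1) otherwise."""
    return np.where(x > 0, x, alpha * (np.exp(np.minimum(x, 0.0)) - 1.0))


x = np.array([-100.0, -5.0, 0.0, 5.0])
print(elu(x) + 1.0)            # the first entry rounds to exactly 0.0
print(elu(x) + 1.0 + EPSILON)  # every entry is strictly positive
```

For very negative inputs, `exp(x) - 1` rounds to exactly -1 in floating point, so ELU + 1 alone can hand the loss a sigma of 0.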
Next, let's actually define a MixtureDensity layer that outputs all values needed to sample from the learned mixture distribution:
class MixtureDensityOutput(layers.Layer):
    def __init__(self, output_dimension, num_mixtures, **kwargs):
        super().__init__(**kwargs)
        self.output_dim = output_dimension
        self.num_mix = num_mixtures
        self.mdn_mus = layers.Dense(
            self.num_mix * self.output_dim, name="mdn_mus"
        )  # mix*output vals, no activation
        self.mdn_sigmas = layers.Dense(
            self.num_mix * self.output_dim,
            activation=elu_plus_one_plus_epsilon,
            name="mdn_sigmas",
        )  # mix*output vals, ELU + 1 + epsilon activation (strictly positive)
        self.mdn_pi = layers.Dense(self.num_mix, name="mdn_pi")  # mix vals, logits

    def build(self, input_shape):
        self.mdn_mus.build(input_shape)
        self.mdn_sigmas.build(input_shape)
        self.mdn_pi.build(input_shape)
        super().build(input_shape)

    @property
    def trainable_weights(self):
        return (
            self.mdn_mus.trainable_weights
            + self.mdn_sigmas.trainable_weights
            + self.mdn_pi.trainable_weights
        )

    @property
    def non_trainable_weights(self):
        return (
            self.mdn_mus.non_trainable_weights
            + self.mdn_sigmas.non_trainable_weights
            + self.mdn_pi.non_trainable_weights
        )

    def call(self, x, mask=None):
        # Output layout per row: [mus | sigmas | pi logits]
        return layers.concatenate(
            [self.mdn_mus(x), self.mdn_sigmas(x), self.mdn_pi(x)], name="mdn_outputs"
        )
Let's construct a Mixture Density Network using our new layer:
OUTPUT_DIMS = 1
N_MIXES = 20
mdn_network = keras_core.Sequential(
    [
        layers.Dense(N_HIDDEN, activation="relu"),
        layers.Dense(N_HIDDEN, activation="relu"),
        MixtureDensityOutput(OUTPUT_DIMS, N_MIXES),
    ]
)
Next, let's implement a custom loss function to train the Mixture Density Network layer based on the true values and our expected outputs:
def get_mixture_loss_func(output_dim, num_mixes):
    def mdn_loss_func(y_true, y_pred):
        # Reshape inputs in case this is used in a TimeDistributed layer
        y_pred = tf.reshape(
            y_pred,
            [-1, (2 * num_mixes * output_dim) + num_mixes],
            name="reshape_ypreds",
        )
        y_true = tf.reshape(y_true, [-1, output_dim], name="reshape_ytrue")
        # Split the inputs into parameters
        out_mu, out_sigma, out_pi = tf.split(
            y_pred,
            num_or_size_splits=[
                num_mixes * output_dim,
                num_mixes * output_dim,
                num_mixes,
            ],
            axis=-1,
            name="mdn_coef_split",
        )
        # Construct the mixture models
        cat = tfd.Categorical(logits=out_pi)
        component_splits = [output_dim] * num_mixes
        mus = tf.split(out_mu, num_or_size_splits=component_splits, axis=1)
        sigs = tf.split(out_sigma, num_or_size_splits=component_splits, axis=1)
        coll = [
            tfd.MultivariateNormalDiag(loc=loc, scale_diag=scale)
            for loc, scale in zip(mus, sigs)
        ]
        mixture = tfd.Mixture(cat=cat, components=coll)
        # Negative log-likelihood of the labels, averaged over the batch
        loss = mixture.log_prob(y_true)
        loss = tf.negative(loss)
        loss = tf.reduce_mean(loss)
        return loss

    return mdn_loss_func


mdn_network.compile(loss=get_mixture_loss_func(OUTPUT_DIMS, N_MIXES), optimizer="adam")
Finally, we can call model.fit() like any other Keras model.
mdn_network.fit(
    x,
    y,
    epochs=300,
    batch_size=128,
    validation_split=0.15,
    callbacks=[
        callbacks.EarlyStopping(monitor="loss", patience=10, restore_best_weights=True),
        callbacks.ReduceLROnPlateau(monitor="loss", patience=5),
    ],
)
Let's make some predictions!
y_pred_mixture = mdn_network.predict(x)
print(y_pred_mixture.shape)
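The second dimension of that shape follows directly from how the output layer concatenates its three heads. Here's a small sketch of the arithmetic, assuming the values used in this guide (one output dimension, 20 mixture components):

```python
# Assumes the values used in this guide: one output dimension, 20 components.
OUTPUT_DIMS = 1
N_MIXES = 20

n_mus = N_MIXES * OUTPUT_DIMS     # one mean per component and output dimension
n_sigmas = N_MIXES * OUTPUT_DIMS  # one scale per component and output dimension
n_pis = N_MIXES                   # one mixing logit per component

row_width = n_mus + n_sigmas + n_pis
print(row_width)
```

Each row of the prediction is laid out as means first, then sigmas, then mixing logits.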
The MDN does not output a single value; instead it outputs values to parameterize a mixture distribution. To visualize these outputs, let's sample from the distribution.
Note that sampling is a lossy process. If you want to preserve all information as part of a greater latent representation (e.g. for downstream processing), I recommend you simply keep the distribution parameters in place.
def split_mixture_params(params, output_dim, num_mixes):
    mus = params[: num_mixes * output_dim]
    sigs = params[num_mixes * output_dim : 2 * num_mixes * output_dim]
    pi_logits = params[-num_mixes:]
    return mus, sigs, pi_logits


def softmax(w, t=1.0):
    e = np.array(w) / t  # adjust temperature
    e -= e.max()  # subtract max to protect from exploding exp values.
    e = np.exp(e)
    dist = e / np.sum(e)
    return dist


def sample_from_categorical(dist):
    r = np.random.rand(1)  # uniform random number in [0,1]
    accumulate = 0
    for i in range(0, dist.size):
        accumulate += dist[i]
        if accumulate >= r:
            return i
    # tf.logging was removed in TensorFlow 2; use the TensorFlow logger instead.
    tf.get_logger().info("Error sampling categorical model.")
    return -1


def sample_from_output(params, output_dim, num_mixes, temp=1.0, sigma_temp=1.0):
    mus, sigs, pi_logits = split_mixture_params(params, output_dim, num_mixes)
    pis = softmax(pi_logits, t=temp)
    m = sample_from_categorical(pis)
    # Alternative way to sample from categorical:
    # m = np.random.choice(range(len(pis)), p=pis)
    mus_vector = mus[m * output_dim : (m + 1) * output_dim]
    sig_vector = sigs[m * output_dim : (m + 1) * output_dim]
    scale_matrix = np.identity(output_dim) * sig_vector  # scale matrix from diag
    cov_matrix = np.matmul(scale_matrix, scale_matrix.T)  # cov is scale squared.
    cov_matrix = cov_matrix * sigma_temp  # adjust for sigma temperature
    sample = np.random.multivariate_normal(mus_vector, cov_matrix, 1)
    return sample
Next, let's use our sampling function:
# Sample from the predicted distributions
y_samples = np.apply_along_axis(
    sample_from_output, 1, y_pred_mixture, 1, N_MIXES, temp=1.0
)
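Sampling row by row like this is easy to follow but slow on large datasets. As a rough alternative, here's a vectorized NumPy sketch that assumes a one-dimensional output and the [mus | sigmas | pi logits] layout. The function name `sample_mdn_1d` and the fake parameters are hypothetical, it skips the temperature options, and it returns an array of shape (n,) rather than (n, 1, 1):

```python
import numpy as np


def sample_mdn_1d(params, num_mixes, rng):
    """Draw one sample per row from 1-D mixture parameters."""
    mus = params[:, :num_mixes]
    sigs = params[:, num_mixes : 2 * num_mixes]
    logits = params[:, 2 * num_mixes :]

    # Row-wise softmax, subtracting the max for numerical stability.
    logits = logits - logits.max(axis=1, keepdims=True)
    pis = np.exp(logits)
    pis /= pis.sum(axis=1, keepdims=True)

    # Inverse-CDF sampling of one component index per row.
    u = rng.random((params.shape[0], 1))
    idx = (pis.cumsum(axis=1) < u).sum(axis=1)
    idx = np.minimum(idx, num_mixes - 1)  # guard against floating-point round-off

    # Sample from the chosen Gaussian component of each row.
    rows = np.arange(params.shape[0])
    return rng.normal(mus[rows, idx], sigs[rows, idx])


# Hypothetical parameters: two narrow components at -2 and +2, equal logits.
rng = np.random.default_rng(0)
fake_params = np.tile([-2.0, 2.0, 0.01, 0.01, 0.0, 0.0], (1000, 1))
samples = sample_mdn_1d(fake_params, num_mixes=2, rng=rng)
print(samples.shape)
```

With the network from this guide, the call would look like `sample_mdn_1d(y_pred_mixture, N_MIXES, rng)`.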
Finally, we can visualize our network outputs:
plt.scatter(x, y, alpha=0.05, color="blue", label="Ground Truth")
plt.scatter(
    x,
    y_samples[:, :, 0],
    color="green",
    alpha=0.05,
    label="Mixture Density Network prediction",
)
plt.show()
Beautiful. Love to see it
Conclusions
Neural Networks are universal function approximators - but they can only approximate functions. Mixture Density networks can approximate arbitrary x->y mappings using some neat probability tricks.
One more pretty graphic for the road:
fig, axs = plt.subplots(1, 3)
fig.set_figheight(3)
fig.set_figwidth(12)
axs[0].set_title("Ground Truth")
axs[0].scatter(x, y, alpha=0.05, color="blue")
axs[1].set_title("Normal Model prediction")
axs[1].scatter(x, y_pred, alpha=0.05, color="red")
axs[2].scatter(
    x,
    y_samples[:, :, 0],
    color="green",
    alpha=0.05,
    label="Mixture Density Network prediction",
)
axs[2].set_title("Mixture Density Network prediction")
plt.show()