- 导入必要的模块。本方法中,需要调用 Numpy、Matplolib 和 TensorFlow 函数:
- 定义 VariationalAutoencoder 类。采用 __init__ 类方法来定义超参数,如学习率、批量大小、用于输入的占位符、编码器及解码器网络的权重和偏置变量。它还根据 VAE 的网络体系结构建立计算图。在本方法中使用 Xavier 初始化器初始化权重。与使用自己定义的方法进行 Xavier 初始化不同,本方法使用 tf.contrib.layers.xavier_initializer() 来进行初始化。最后,定义损失(生成和潜在)及优化器操作:
- 创建网络编码器和网络解码器。网络编码器的第一层接收输入并生成输入的递减式潜在表示;第二层将输入映射到高斯分布。网络学习这些转变:
- VariationalAutoencoder 类还包含一些帮助函数来生成和重建数据,并适应 VAE:
- 一旦 VAE 类完成,定义一个函数序列,它使用 VAE 类对象并通过给定的数据进行训练:
- 使用 VAE 类和序列函数。采用 MNIST 数据集:
- 定义网络架构,并在 MNIST 数据集上进行 VAE 的训练。在这种情况下,为了简单保留了潜在维度 2。
- 看一下 VAE 是否重构了输入。输出表明那些数字确实被重构了,而且由于使用了二维的潜在空间,所以图像显得模糊了:
- 以下是使用经过训练的 VAE 生成的手写数字样本: