Outline
- Auto-Encoder
- Variational Auto-Encoders
Auto-Encoder
创建编解码器
import os
import tensorflow as tf
import numpy as np
from tensorflow import keras
from tensorflow.keras import Sequential, layers
from PIL import Image
from matplotlib import pyplot as plt
# Fix both RNGs so weight initialisation and dataset shuffling are reproducible.
tf.random.set_seed(22)
np.random.seed(22)
# Ask TensorFlow's C++ runtime to hide INFO/WARNING log lines. NOTE: this is set
# after `import tensorflow`, so messages emitted during import may still appear.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
# Explicit check instead of `assert`, which is stripped when Python runs with -O.
if not tf.__version__.startswith('2.'):
    raise RuntimeError('This script requires TensorFlow 2.x, found %s' % tf.__version__)
def save_images(imgs, name):
    """Tile up to 100 28x28 grayscale images into a 10x10 grid and save it.

    Images are placed column by column: image k goes to column k // 10 and
    row k % 10 of a 280x280 'L' (8-bit grayscale) canvas.

    Args:
        imgs: sequence/array of 2-D image arrays, each presumably 28x28 uint8
            (they are passed to ``Image.fromarray`` with mode 'L').
            Only the first 100 are used; if fewer are given, the remaining
            cells stay black instead of raising IndexError.
        name: output file path; the format is chosen by PIL from the extension.
    """
    canvas = Image.new('L', (280, 280))
    for index, img in enumerate(imgs[:100]):
        # divmod reproduces the original nested loops: the outer loop
        # stepped the x offset, the inner loop the y offset.
        col, row = divmod(index, 10)
        canvas.paste(Image.fromarray(img, mode='L'), (col * 28, row * 28))
    canvas.save(name)
h_dim = 20  # size of the latent code: 784 input features are compressed to 20
batchsz = 512
lr = 1e-3

(x_train, y_train), (x_test, y_test) = keras.datasets.fashion_mnist.load_data()
# Divide by 255 so pixel values lie in [0, 1] (assuming 8-bit input images).
x_train = x_train.astype(np.float32) / 255.
x_test = x_test.astype(np.float32) / 255.

# An auto-encoder is unsupervised: the datasets are built from images only,
# the labels are loaded but never fed to the model.
train_db = (tf.data.Dataset.from_tensor_slices(x_train)
            .shuffle(batchsz * 5)
            .batch(batchsz))
test_db = tf.data.Dataset.from_tensor_slices(x_test).batch(batchsz)

print(x_train.shape, y_train.shape)
print(x_test.shape, y_test.shape)
class AE(keras.Model):
    """Fully connected auto-encoder for flattened 784-feature inputs.

    The encoder maps [b, 784] -> [b, h_dim] through 256- and 128-unit ReLU
    layers. The decoder mirrors it back to [b, 784]. The last decoder layer
    has no activation, so ``call`` returns unbounded logits; apply a sigmoid
    to obtain values in [0, 1].
    """

    def __init__(self):
        super().__init__()
        # Encoder: [b, 784] -> [b, 256] -> [b, 128] -> [b, h_dim]
        self.encoder = Sequential([
            layers.Dense(256, activation=tf.nn.relu),
            layers.Dense(128, activation=tf.nn.relu),
            layers.Dense(h_dim),
        ])
        # Decoder: [b, h_dim] -> [b, 128] -> [b, 256] -> [b, 784]
        self.decoder = Sequential([
            layers.Dense(128, activation=tf.nn.relu),
            layers.Dense(256, activation=tf.nn.relu),
            layers.Dense(784),
        ])

    def call(self, inputs, training=None):
        """Encode ``inputs`` and decode them back; returns reconstruction logits."""
        # [b, 784] -> [b, h_dim]
        h = self.encoder(inputs)
        # [b, h_dim] -> [b, 784] (logits, no activation)
        x_hat = self.decoder(h)
        return x_hat
# Create the auto-encoder and materialise its weights for batches of
# flattened 784-feature inputs. The batch size is left as None; shapes are
# given as a tuple, the preferred form in TensorFlow.
model = AE()
model.build((None, 784))
model.summary()
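(60000, 28, 28) (60000,)
(10000, 28, 28) (10000,)
Model: "ae"
_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
sequential (Sequential)      multiple                  236436
_________________________________________________________________
sequential_1 (Sequential)    multiple                  237200
=================================================================
Total params: 473,636
Trainable params: 473,636
Non-trainable params: 0
_________________________________________________________________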
训练
# `learning_rate` is the documented keyword; the old `lr` alias is deprecated
# and rejected by newer Keras optimizers.
optimizer = tf.optimizers.Adam(learning_rate=lr)

for epoch in range(10):
    for step, x in enumerate(train_db):
        # [b, 28, 28] -> [b, 784]
        x = tf.reshape(x, [-1, 784])

        with tf.GradientTape() as tape:
            x_rec_logits = model(x)
            # Per-sample reconstruction loss, shape [b]; the model outputs logits.
            rec_loss = tf.losses.binary_crossentropy(x, x_rec_logits,
                                                     from_logits=True)
            # Average over the batch. The previous reduce_min only back-propagated
            # through the single best-reconstructed sample of each batch.
            rec_loss = tf.reduce_mean(rec_loss)

        grads = tape.gradient(rec_loss, model.trainable_variables)
        optimizer.apply_gradients(zip(grads, model.trainable_variables))

        if step % 100 == 0:
            print(epoch, step, float(rec_loss))

    # Evaluation: reconstruct the first test batch and save it as an image grid.
    x = next(iter(test_db))
    logits = model(tf.reshape(x, [-1, 784]))
    x_hat = tf.sigmoid(logits)
    # [b, 784] -> [b, 28, 28]
    x_hat = tf.reshape(x_hat, [-1, 28, 28])
    # Save the reconstructions (use `x` instead of `x_hat` to save the originals),
    # scaled to [0, 255] and cast to 8-bit for PIL.
    rec_images = (x_hat.numpy() * 255.).astype(np.uint8)
    os.makedirs('ae_images', exist_ok=True)
    save_images(rec_images, 'ae_images/rec_epoch_%d.png' % epoch)
0 0 0.09717604517936707
0 100 0.12493347376585007
1 0 0.09747321903705597
1 100 0.12291513383388519
2 0 0.10048121958971024
2 100 0.12292417883872986
3 0 0.10093794018030167
3 100 0.12260882556438446
4 0 0.10006923228502274
4 100 0.12275046110153198
5 0 0.0993042066693306
5 100 0.12257824838161469
6 0 0.0967678651213646
6 100 0.12443818897008896
7 0 0.0965462476015091
7 100 0.12179268896579742
8 0 0.09197664260864258
8 100 0.12110235542058945
9 0 0.0913471132516861
9 100 0.12342415750026703