[Tensorflow] Saving a Keras Subclassed Model with model.save_weights()

A minimal end-to-end example: fit a linear regression with a subclassed tf.keras.Model under TF 1.x eager execution, save its weights with model.save_weights() during training, then restore them into a fresh instance with load_weights() and run inference.

import numpy as np
import matplotlib.pyplot as plt
import os
import time

import tensorflow as tf
tf.enable_eager_execution()    # TF 1.x: evaluate ops immediately, no Session needed

# create data: y = 0.5x + 2 plus Gaussian noise
X = np.linspace(-1, 1, 5000)
np.random.shuffle(X)
y = 0.5 * X + 2 + np.random.normal(0, 0.05, (5000,))

# plot data
plt.scatter(X, y)
plt.show()

# split data into training and test sets
X_train, y_train = X[:4000], y[:4000]
X_test, y_test = X[4000:], y[4000:]

# tf.data: shuffle before batching so individual examples, not whole
# batches, get shuffled
BATCH_SIZE = 32
BUFFER_SIZE = 512
dataset = tf.data.Dataset.from_tensor_slices((X_train, y_train)).shuffle(BUFFER_SIZE).batch(BATCH_SIZE)
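# (Sketch) under eager execution a tf.data.Dataset is directly iterable,
# so the pipeline can be sanity-checked before training:
# for x_batch, y_batch in dataset.take(1):
#     print(x_batch.shape, y_batch.shape)    # (32,), (32,)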

# subclassed model: a single Dense layer, i.e. plain linear regression
UNITS = 1


class Model(tf.keras.Model):
    def __init__(self):
        super(Model, self).__init__()
        self.fc = tf.keras.layers.Dense(units=UNITS)

    def call(self, inputs):
        return self.fc(inputs)


model = Model()

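# Note: a subclassed model creates its weights lazily, on the first call
# (or an explicit model.build()); calling model.summary() before that
# would raise an error because no variables exist yet.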
optimizer = tf.train.AdamOptimizer()


# loss function
def loss_function(real, pred):
    return tf.losses.mean_squared_error(labels=real, predictions=pred)

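# (Sketch) quick sanity check of the loss on dummy tensors:
# loss_function(real=tf.zeros([4, 1]), pred=tf.ones([4, 1]))    # -> 1.0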
EPOCHS = 30
checkpoint_dir = './save_subclassed_keras_model_training_checkpoints'
if not os.path.exists(checkpoint_dir):
    os.makedirs(checkpoint_dir)
# the {epoch} placeholder gives each saved checkpoint a distinct prefix
checkpoint_prefix = os.path.join(checkpoint_dir, 'ckpt_{epoch}')

# training loop
for epoch in range(EPOCHS):
    start = time.time()
    epoch_loss = 0

    for (batch, (x, y)) in enumerate(dataset):
        x = tf.cast(x, tf.float32)
        y = tf.cast(y, tf.float32)
        x = tf.expand_dims(x, axis=1)    # (BATCH_SIZE,) -> (BATCH_SIZE, 1)
        y = tf.expand_dims(y, axis=1)
        # print(x)    # tf.Tensor([...], shape=(BATCH_SIZE, 1), dtype=float32)
        # print(y)    # tf.Tensor([...], shape=(BATCH_SIZE, 1), dtype=float32)
        with tf.GradientTape() as tape:
            predictions = model(x)
            # print(predictions)  # tf.Tensor([...], shape=(BATCH_SIZE, 1), dtype=float32)
            batch_loss = loss_function(real=y, pred=predictions)

        grads = tape.gradient(batch_loss, model.variables)
        optimizer.apply_gradients(zip(grads, model.variables),
                                  global_step=tf.train.get_or_create_global_step())
        epoch_loss += batch_loss

        if (batch + 1) % 10 == 0:
            # batch_loss is already a per-example mean, so print it as is
            print('Epoch {} Batch {} Loss {:.4f}'.format(epoch + 1,
                                                         batch + 1,
                                                         batch_loss))

    # average of the per-batch mean losses over the epoch
    print('Epoch {} Loss {:.4f}'.format(epoch + 1, epoch_loss / (batch + 1)))
    print('Time taken for 1 epoch {} sec\n'.format(time.time() - start))

    # save a checkpoint every 10 epochs
    if (epoch + 1) % 10 == 0:
        model.save_weights(checkpoint_prefix.format(epoch=epoch + 1), overwrite=True)

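# After training, checkpoint_dir holds TF-format checkpoint files, e.g.
#   checkpoint  ckpt_30.index  ckpt_30.data-00000-of-00001  ...
# tf.train.latest_checkpoint() reads the 'checkpoint' state file to find
# the most recently written prefix.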
# rebuild the architecture, create its variables, then restore the weights
_model = Model()
_model.build(input_shape=tf.TensorShape([None, 1]))    # create variables before summary()
_model.load_weights(tf.train.latest_checkpoint(checkpoint_dir))
_model.summary()

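# (Sketch) for a TF-format checkpoint, load_weights() returns the same
# status object as tf.train.Checkpoint.restore(), so the restore can be
# verified explicitly:
# status = _model.load_weights(tf.train.latest_checkpoint(checkpoint_dir))
# status.assert_consumed()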
test_dataset = tf.data.Dataset.from_tensor_slices(X_test).batch(1)

# push one test example through the restored model
for (batch, x) in enumerate(test_dataset):
    x = tf.cast(x, tf.float32)
    x = tf.expand_dims(x, axis=1)
    print(x)
    predictions = _model(x)
    print(predictions)
    break    # inspect only the first example
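One detail worth noting: save_weights() picks its on-disk format from the filename, writing a TensorFlow-format checkpoint unless the path ends in '.h5'; in recent TF releases the format can also be forced with the save_format argument. A minimal follow-up sketch, assuming the script above has run (the ckpt_final name is illustrative), that saves once more explicitly and checks the restored weights against the data-generating function y = 0.5x + 2:

# save once more, forcing the TensorFlow checkpoint format explicitly
model.save_weights(os.path.join(checkpoint_dir, 'ckpt_final'), save_format='tf')

# restore into a fresh instance and inspect the learned parameters
new_model = Model()
new_model.build(input_shape=tf.TensorShape([None, 1]))
new_model.load_weights(os.path.join(checkpoint_dir, 'ckpt_final'))
print(new_model.fc.get_weights())    # expect kernel ~0.5 and bias ~2.0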

 
