"""Fit a one-unit Dense layer to noisy linear data with TF 1.x eager execution.

The script generates samples of y = 0.5 * x + 2 + N(0, 0.05), trains a
subclassed Keras model with a custom GradientTape loop, checkpoints the
weights every SAVE_EVERY epochs, then restores the latest checkpoint into a
fresh model and evaluates it on the held-out split.
"""
import os
import time

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf

# The script relies on TF 1.x APIs (tf.train.AdamOptimizer, tf.losses, ...);
# eager mode must be enabled before any op is created.
tf.enable_eager_execution()


def _as_column(values):
    """Return *values* as a float32 column vector of shape (len(values), 1)."""
    return np.asarray(values, dtype=np.float32).reshape(-1, 1)


# create data: y = 0.5 * x + 2 plus Gaussian noise with std 0.05
NUM_SAMPLES = 5000
NUM_TRAIN = 4000
X = np.linspace(-1, 1, NUM_SAMPLES)
np.random.shuffle(X)
y = 0.5 * X + 2 + np.random.normal(0, 0.05, (NUM_SAMPLES,))

# plot data
plt.scatter(X, y)
plt.show()

# split data
X_train, y_train = X[:NUM_TRAIN], y[:NUM_TRAIN]
X_test, y_test = X[NUM_TRAIN:], y[NUM_TRAIN:]

# tf.data
BATCH_SIZE = 32
BUFFER_SIZE = 512
# Shuffle individual examples *before* batching so batch composition changes
# every epoch; shuffling after .batch() would only reorder fixed batches.
# Casting/reshaping happens once here instead of on every batch.
dataset = (tf.data.Dataset.from_tensor_slices((_as_column(X_train), _as_column(y_train)))
           .shuffle(BUFFER_SIZE)
           .batch(BATCH_SIZE))

# subclassed model
UNITS = 1


class Model(tf.keras.Model):
    """Linear regression model: a single Dense layer with UNITS outputs."""

    def __init__(self):
        super(Model, self).__init__()
        self.fc = tf.keras.layers.Dense(units=UNITS)

    def call(self, inputs):
        return self.fc(inputs)


model = Model()

optimizer = tf.train.AdamOptimizer()


# loss function
def loss_function(real, pred):
    """Return the mean squared error between *real* and *pred* as a scalar tensor.

    tf.losses.mean_squared_error already averages over the batch, so callers
    must not divide the result by the batch size again.
    """
    return tf.losses.mean_squared_error(labels=real, predictions=pred)


EPOCHS = 30
SAVE_EVERY = 10
checkpoint_dir = './save_subclassed_keras_model_training_checkpoints'
if not os.path.exists(checkpoint_dir):
    os.makedirs(checkpoint_dir)
# '{epoch}' placeholder so each saved checkpoint gets its own file.
checkpoint_prefix = os.path.join(checkpoint_dir, 'ckpt_{epoch}')

# Look up the global step once rather than on every training step.
global_step = tf.train.get_or_create_global_step()

# training loop
for epoch in range(EPOCHS):
    start = time.time()
    # Sum of per-example squared errors accumulated over the epoch.
    epoch_loss = 0.0

    # x_batch / y_batch deliberately avoid rebinding `y`, the full label array.
    for batch, (x_batch, y_batch) in enumerate(dataset):
        # x_batch, y_batch: float32 tensors of shape (<=BATCH_SIZE, 1)
        with tf.GradientTape() as tape:
            predictions = model(x_batch)
            batch_loss = loss_function(real=y_batch, pred=predictions)

        variables = model.trainable_variables
        grads = tape.gradient(batch_loss, variables)
        optimizer.apply_gradients(zip(grads, variables), global_step=global_step)

        # batch_loss is a batch mean; weight by batch size so the final
        # division by len(X_train) yields the exact per-example mean.
        epoch_loss += float(batch_loss) * int(x_batch.shape[0])

        if (batch + 1) % 10 == 0:
            print('Epoch {} Batch {} Loss {:.4f}'.format(epoch + 1,
                                                         batch + 1,
                                                         float(batch_loss)))

    print('Epoch {} Loss {:.4f}'.format(epoch + 1, epoch_loss / len(X_train)))
    print('Time taken for 1 epoch {} sec\n'.format(time.time() - start))

    # save checkpoint
    if (epoch + 1) % SAVE_EVERY == 0:
        model.save_weights(checkpoint_prefix.format(epoch=epoch + 1), overwrite=True)

# Restore the most recent checkpoint into a fresh model instance.
latest_checkpoint = tf.train.latest_checkpoint(checkpoint_dir)
if latest_checkpoint is None:
    raise RuntimeError('No checkpoint found in {}'.format(checkpoint_dir))
_model = Model()
_model.load_weights(latest_checkpoint)
_model.build(input_shape=tf.TensorShape([BATCH_SIZE, 1]))
_model.summary()

# Show the prediction for a single test example.
test_dataset = tf.data.Dataset.from_tensor_slices(_as_column(X_test)).batch(1)
for x_sample in test_dataset.take(1):
    print(x_sample)
    print(_model(x_sample))

# Evaluate the restored model on the whole held-out split.
test_predictions = _model(tf.constant(_as_column(X_test)))
test_loss = loss_function(real=tf.constant(_as_column(y_test)), pred=test_predictions)
print('Test Loss {:.4f}'.format(float(test_loss)))