class DataLoader():
    """Character-level view of the Nietzsche corpus with random batch sampling.

    Attributes set by ``__init__``:
        raw_text: the whole corpus as one lower-cased string.
        chars: sorted list of the distinct characters in ``raw_text``.
        char_indices: mapping character -> integer index into ``chars``.
        indices_char: mapping integer index -> character (inverse mapping).
        text: ``raw_text`` encoded as a list of integer indices.
    """

    def __init__(self):
        """Download (or reuse the cached) corpus and build the vocabulary."""
        # The origin must point at the text file itself; a bare bucket URL
        # does not serve the corpus.
        path = tf.keras.utils.get_file(
            'nietzsche.txt',
            origin='https://s3.amazonaws.com/text-datasets/nietzsche.txt')
        with open(path, encoding='utf-8') as f:
            self.raw_text = f.read().lower()
        self.chars = sorted(set(self.raw_text))
        self.char_indices = {c: i for i, c in enumerate(self.chars)}
        self.indices_char = {i: c for i, c in enumerate(self.chars)}
        self.text = [self.char_indices[c] for c in self.raw_text]

    def get_batch(self, seq_length, batch_size):
        """Draw ``batch_size`` random windows and the character following each.

        Args:
            seq_length: number of consecutive indices in every window.
            batch_size: number of windows to sample (with replacement).

        Returns:
            A pair ``(seq, next_char)`` of NumPy arrays: ``seq`` has one row
            of ``seq_length`` indices per sample, ``next_char`` holds the
            index that immediately follows each window in ``self.text``.

        Raises:
            ValueError: if the text is not longer than ``seq_length``, so no
                window with a following character exists.
        """
        # Exclusive upper bound: the largest start still leaves one index
        # after the window to serve as the prediction target.
        max_start = len(self.text) - seq_length
        if max_start <= 0:
            raise ValueError(
                'seq_length (%d) must be smaller than the text length (%d)'
                % (seq_length, len(self.text)))
        seq = []
        next_char = []
        for _ in range(batch_size):
            index = np.random.randint(0, max_start)
            seq.append(self.text[index:index + seq_length])
            next_char.append(self.text[index + seq_length])
        return np.array(seq), np.array(next_char)
class RNN(tf.keras.Model):
    """Single-layer LSTM that predicts the next character of a sequence.

    Args:
        num_chars: vocabulary size; used as the one-hot depth and as the
            number of output units.
        batch_size: batch size used for the initial LSTM state when the
            batch dimension of the input is not statically known.
        seq_length: number of time steps unrolled in ``call``.
    """

    def __init__(self, num_chars, batch_size, seq_length):
        super().__init__()
        self.num_chars = num_chars
        self.seq_length = seq_length
        self.batch_size = batch_size
        self.cell = tf.keras.layers.LSTMCell(units=256)
        self.dense = tf.keras.layers.Dense(units=self.num_chars)

    def call(self, inputs, from_logits=False):
        """Run the LSTM over ``inputs`` and score the next character.

        Args:
            inputs: integer character indices, indexed as
                ``[batch, time]`` with at least ``seq_length`` time steps.
            from_logits: when true, return the raw dense-layer outputs;
                otherwise return their softmax.

        Returns:
            One row of ``num_chars`` scores per input sequence, computed
            from the LSTM output at the last unrolled time step.
        """
        inputs = tf.one_hot(inputs, depth=self.num_chars)
        # Size the initial state from the batch actually supplied, so the
        # model also works for batches that differ from the constructor
        # value (e.g. a single sequence at generation time) instead of
        # relying on broadcasting against a mismatched state.
        batch_size = inputs.shape[0]
        if batch_size is None:
            batch_size = self.batch_size
        state = self.cell.get_initial_state(batch_size=batch_size, dtype=tf.float32)
        for t in range(self.seq_length):
            output, state = self.cell(inputs[:, t, :], state)
        logits = self.dense(output)
        if from_logits:
            return logits
        return tf.nn.softmax(logits)
# Training hyper-parameters.
num_batches = 1000
seq_length = 40
batch_size = 50
learning_rate = 1e-3

data_loader = DataLoader()
model = RNN(num_chars=len(data_loader.chars), batch_size=batch_size, seq_length=seq_length)
optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate)

for batch_index in range(num_batches):
    x, y = data_loader.get_batch(seq_length, batch_size)
    # GradientTape must be instantiated; the class itself is not a context manager.
    with tf.GradientTape() as tape:
        y_pred = model(x)
        loss = tf.keras.losses.sparse_categorical_crossentropy(y_true=y, y_pred=y_pred)
        # Reduce the per-sample losses to one scalar so it can be logged
        # and differentiated as the batch objective.
        loss = tf.reduce_mean(loss)
    print("batch %d: loss %f" % (batch_index, loss.numpy()))
    # Differentiate only trainable weights: non-trainable variables would
    # produce None gradients that the optimizer cannot apply.
    grads = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(grads_and_vars=zip(grads, model.trainable_variables))
def predict(self, inputs, temperature=1):
    """Sample one next-character index per input sequence.

    Args:
        self: the model; it is called as ``self(inputs, from_logits=True)``
            and must expose ``num_chars``.
        inputs: batch of index sequences; its first dimension is the number
            of samples drawn.
        temperature: positive divisor applied to the logits before the
            softmax. Values below 1 sharpen the distribution, values above
            1 flatten it.

    Returns:
        1-D NumPy array holding one sampled index per row of ``inputs``.

    Raises:
        ValueError: if ``temperature`` is not strictly positive.
    """
    if temperature <= 0:
        raise ValueError('temperature must be positive, got %r' % (temperature,))
    # The sample count follows the inputs, not the logits.
    batch_size = np.shape(inputs)[0]
    logits = np.asarray(self(inputs, from_logits=True), dtype=np.float64)
    # Numerically stable softmax in float64: subtracting the row maximum
    # avoids overflow, and the explicit renormalisation keeps every row
    # summing to 1 as np.random.choice requires.
    scaled = logits / temperature
    scaled -= scaled.max(axis=-1, keepdims=True)
    prob = np.exp(scaled)
    prob /= prob.sum(axis=-1, keepdims=True)
    return np.array([np.random.choice(self.num_chars, p=prob[i, :])
                     for i in range(batch_size)])
# Seed window for generation; the accompanying target character is unused.
x_, _ = data_loader.get_batch(seq_length, 1)
for diversity in [0.2, 0.5, 1.0, 1.2]:
    x = x_
    print("diversity %f" % diversity)
    for t in range(400):
        # `predict` is the module-level sampling function, so the model is
        # passed explicitly; `model.predict` would resolve to the built-in
        # Keras method, which returns probabilities rather than an index.
        y_pred = predict(model, x, diversity)
        print(data_loader.indices_char[y_pred[0]], end='', flush=True)
        # Slide the window: drop the oldest index, append the sampled one.
        x = np.concatenate([x[:, 1:], np.expand_dims(y_pred, axis=1)], axis=-1)
    print("\n")