#!/usr/bin/env python
# coding=utf-8
import string

import numpy as np
from keras.layers import Activation, TimeDistributed, Dense, RepeatVector, recurrent
from keras.models import Sequential
import random


class CharacterTable(object):
    """One-hot encoder/decoder over the addition alphabet: digits, '+' and ' '.

    A string is encoded as a ``(maxlen, len(chars))`` float array with a
    single 1 per row marking that position's character.
    """

    def __init__(self, maxlen):
        """Build the char<->index lookup tables.

        :param maxlen: default number of rows produced by :meth:`encode`.
        """
        self.chars = string.digits + '+ '
        self.char_indices = dict((c, i) for i, c in enumerate(self.chars))
        self.indice_chars = dict((i, c) for i, c in enumerate(self.chars))
        self.maxlen = maxlen

    def encode(self, strs, maxlen=None):
        """One-hot encode ``strs`` into a ``(maxlen, len(chars))`` array.

        :param strs: string over ``self.chars``; at most ``maxlen`` long
            (longer input raises IndexError).
        :param maxlen: number of rows; a falsy value falls back to
            ``self.maxlen``.
        :returns: float ndarray with exactly one 1 per row.
        """
        maxlen = maxlen if maxlen else self.maxlen
        # Right-pad with spaces so every row is one-hot. An all-zero row
        # would otherwise decode (argmax) as '0' and break round-tripping.
        padded = strs + ' ' * (maxlen - len(strs))
        vec = np.zeros((maxlen, len(self.chars)))
        for i, c in enumerate(padded):
            vec[i, self.char_indices[c]] = 1
        return vec

    def decode(self, vec, calc_argmax=True):
        """Convert an encoded matrix (or a sequence of indices) back to a string.

        :param vec: 2-D one-hot/probability array when ``calc_argmax`` is
            true; otherwise an iterable of character indices.
        :param calc_argmax: take the argmax over the last axis first.
        """
        if calc_argmax:
            vec = vec.argmax(axis=-1)
        return ''.join(self.indice_chars[x] for x in vec)


def gen_num():
    """Return a random integer in 0..999 built from 1 to 3 random digits."""
    # Draw digits *with* replacement: the previous random.sample('', k)
    # always raised ValueError, and sampling without replacement would
    # never produce numbers with repeated digits (11, 100, 999, ...).
    n_digits = random.randint(1, 3)
    return int(''.join(random.choice(string.digits) for _ in range(n_digits)))


MAXLEN = 7  # longest query: 3 digits + '+' + 3 digits
ctable = CharacterTable(MAXLEN)

# Build 50000 distinct "a+b" questions; a+b and b+a count as the same pair.
# Queries are right-padded with spaces to MAXLEN characters. Answers are
# padded to 4 characters, since the largest sum (999+999=1998) has 4 digits.
questions, expected = [], []
seen = set()
while len(questions) < 50000:
    a, b = gen_num(), gen_num()
    key = tuple(sorted((a, b)))
    if key in seen:
        continue
    seen.add(key)
    questions.append('{}+{}'.format(a, b).ljust(MAXLEN))
    expected.append(str(a + b).ljust(4))
print('total questions', len(questions))

ANSLEN = 4  # answers are space-padded to 4 characters
VALIDATION_SPLIT = 0.02

# One-hot encode questions and answers. Use the builtin ``bool``: the
# ``np.bool`` alias was deprecated in NumPy 1.20 and removed in 1.24.
X = np.zeros((len(questions), MAXLEN, len(ctable.chars)), dtype=bool)
y = np.zeros((len(questions), ANSLEN, len(ctable.chars)), dtype=bool)
for i, sent in enumerate(questions):
    X[i] = ctable.encode(sent)
for i, sent in enumerate(expected):
    y[i] = ctable.encode(sent, ANSLEN)

# Encoder LSTM -> its output repeated once per answer character ->
# two decoder LSTMs -> per-timestep softmax over the alphabet.
model = Sequential()
model.add(recurrent.LSTM(128, input_shape=(MAXLEN, len(ctable.chars))))
model.add(RepeatVector(ANSLEN))
model.add(recurrent.LSTM(128, return_sequences=True))
model.add(recurrent.LSTM(128, return_sequences=True))
model.add(TimeDistributed(Dense(len(ctable.chars))))
model.add(Activation('softmax'))
model.compile(loss='categorical_crossentropy',
              optimizer='adam',
              metrics=['accuracy'])
# NOTE: ``nb_epoch`` is the Keras 1 keyword; Keras 2 renamed it ``epochs``.
model.fit(X, y, batch_size=64, nb_epoch=20,
          validation_split=VALIDATION_SPLIT, verbose=2)

# Spot-check 10 samples from the held-out part. Keras takes the validation
# split from the *last* samples (before shuffling), so sample from there
# rather than from data the model was trained on.
split_at = int(len(X) * (1. - VALIDATION_SPLIT))
for _ in range(10):
    ind = np.random.randint(split_at, len(X))
    x_test, y_test = X[ind:ind + 1], y[ind:ind + 1]
    # decode() takes the argmax of the softmax output, which is exactly what
    # Sequential.predict_classes did (that method is gone in newer Keras).
    y_pred = model.predict(x_test, verbose=0)
    print('Q', ctable.decode(x_test[0]))
    print('T', ctable.decode(y_test[0]))
    print('Pred', ctable.decode(y_pred[0]))

# Persist the architecture (JSON text, so open in text mode) and the weights.
json_string = model.to_json()
with open('rnn_add_model.json', 'w') as fw:
    fw.write(json_string)
model.save_weights('rnn_add_model.h5')
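# 基本是模仿官网例子，精简了一点；训练约 1 小时，准确率 99.6%。
# (Adapted and slightly trimmed from the official Keras example; training
#  takes about 1 hour and reaches 99.6% accuracy.)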