Tensorflow V2.0 图像识别教程
代码:
https://github.com/dwSun/classification-tutorial.git
教程参考官方专家高级教程:
https://tensorflow.google.cn/tutorials/quickstart/advanced?hl=en
这里以 TinyMind 《汉字书法识别》比赛数据为例,展示使用 Tensorflow V2.0 进行图像数据分类模型训练的整个流程。
数据地址请参考:
https://www.tinymind.cn/competitions/41#property_23
或到这里下载:
*练习赛数据下载地址:
训练集:链接: https://pan.baidu.com/s/1UxvN7nVpa0cuY1A-0B8gjg 密码: aujd
测试集: https://pan.baidu.com/s/1tzMYlrNY4XeMadipLCPzTw 密码: 4y9k
数据探索
请参考官方的数据说明
数据处理
竞赛中只有训练集 train 数据有准确的标签,因此这里只使用 train 数据即可,实际应用中,阶段 1、2 的榜单都需要使用。
数据下载
下载数据之后进行解压,得到 train 文件夹,里面有 100 个文件夹,每个文件夹名字即是各个汉字的标签。类似的数据集结构经常在分类任务中见到。可以使用下述命令验证一下每个文件夹下面文件的数量,看数据集是否符合竞赛数据描述:
for l in $(ls); do echo $l $(ls $l|wc -l); done
划分数据集
因为这里只使用了 train 集,因此我们需要对已有数据集进行划分,供模型训练的时候做验证使用,也就是 validation 集的构建。
一般认为,train 用来训练模型,validation 用来对模型进行验证以及超参数( hyper parameter)调整,test 用来做模型的最终验证,我们所谓模型的性能,一般也是指 test 集上模型的性能指标。但是实际项目中,一般只有 train 集,同时没有可靠的 test 集来验证模型,因此一般将 train 集划分出一部分作为 validation,同时将 validation 上的模型性能作为最终模型性能指标。一般情况下,我们不严格区分 validation 和 test。
这里将每个文件夹下面随机50个文件拿出来做 validation。
export train=train
export val=validation
for d in $(ls $train); do
mkdir -p $val/$d/
for f in $(ls train/$d | shuf | head -n 50 ); do
mv $train/$d/$f $val/$d/;
done;
done
需要注意,这里的 validation 只间接通过超参数的调整参与了模型训练。因此有一定的数据浪费。
模型训练代码-数据部分
首先导入 TF 看一下版本
import tensorflow as tf
tf.__version__
'2.1.0'
训练模型的时候,模型内部全部都是数字,没有任何可读性,而且这些数字也需要人为给予一些实际的意义,这里将 100 个汉字作为模型输出数字的文字表述。
需要注意的是,因为模型训练往往是一个循环往复的过程,因此一个稳定的文字标签是很有必要的,这里利用相关 python 代码在首次运行的时候生成了一个标签文件,后续检测到这个标签文件,则直接调用即可。
import os
if os.path.exists("labels.txt"):
with open("labels.txt") as inf:
classes = [l.strip() for l in inf]
else:
classes = os.listdir("worddata/train/")
with open("labels.txt", "w") as of:
of.write("\r\n".join(classes))
相比于 TF V1.x,V2.x 中一个比较大的变化就是数据集读取和处理的工具更加简单,(虽然效率其实低了一些,但是考虑数据读取在一定的优化下,很少成为模型训练的瓶颈,这一点性能损失带来巨大的便利性,还是值得的)。
TF V2.x中提供了直接从目录中读取数据并进行训练的 API 这里使用的API如下。
这里使用了两个数据集,分别代表 train、validation。
需要注意的是,由于 数据中,使用的图像数据集,其数值在(0, 255)之间,不适合直接输入模型进行训练,因此这里使用 rescale 对数据进行缩放。同时,train 数据集做了一定的数据预处理(旋转、明暗度),用于进行数据增广,而 validation则不需要做类似的变换。
img_gen_train = tf.keras.preprocessing.image.ImageDataGenerator(
rescale=1.0 / 255.0, rotation_range=15, brightness_range=(0.5, 1.0)
)
img_gen_val = tf.keras.preprocessing.image.ImageDataGenerator(rescale=1.0 / 255.0)
从 API 的名称和参数可以看出,这个 API 并不直接从目录中读取数据,我们实际使用的时候,要使用这个 API 的一个封装。
在这个封装中,我们指定了图像最终的大小(target_size),颜色模式(color_mode),批量大小(batch_size),同时还有一个非常重要的标签(classes)。
需要注意的是,这里的 color_mode 使用的是灰度模式,读取出来的数据只有一个颜色通道,因为书法的汉字全部是水墨,无所谓颜色。而 classes的使用,保证每次模型训练都使用统一的标签,如果不指定,那么这个 API 会按照一个内置的规则对标签进行编号,这个编号在不同的系统平台之间可能是不一致的。
这里还需要注意一点的是,train 集我们对数据进行了随机打乱 (shuffle), 而 validation 则没有。
batch_size = 32
img_train = img_gen_train.flow_from_directory(
"worddata/train/",
target_size=(128, 128),
color_mode="grayscale",
classes=classes,
batch_size=batch_size,
shuffle=True,
)
img_val = img_gen_train.flow_from_directory(
"worddata/validation/",
target_size=(128, 128),
color_mode="grayscale",
classes=classes,
batch_size=batch_size,
)
Found 35008 images belonging to 100 classes.
Found 5000 images belonging to 100 classes.
到这里,这两个数据集就可以使用了,正式模型训练之前,我们可以先来看看这个数据集是怎么读取数据的,读取出来的数据又是设么样子的。
imgs, labels = next(img_train)
# 因为是 generator 所以可以用next来读取
imgs.shape, labels.shape
((32, 128, 128, 1), (32, 100))
可以看到数据是(batch, width, height, channel), 因为这里是灰度图像,因此 channel 是 1。
需要注意,pyTorch、mxnet使用的数据 layout 与Tensorflow 不同,因此数据也有一些不同的处理方式。
把图片打印出来看看,看看数据和标签之间是否匹配
import numpy as np
from matplotlib import pyplot as plt
plt.imshow(imgs[0, :, :, 0], cmap="gray")
classes[np.argsort(labels[0, :])[-1]]
'益'
模型训练代码-模型构建
TF V2.x 中使用动态图来构建模型,同时由于使用了 keras的 API,因此模型构建比较简单了。这里演示的是使用 class的方式构建模型,对于简单模型,还可以直接使用 Sequential 进行构建。
这里的复杂模型也是用 Sequential 的简单模型进行的叠加。
这里构建的是VGG模型,关于VGG模型的更多细节请参考 1409.1556。
class MyModel(tf.keras.Model):
def __init__(self):
super(MyModel, self).__init__()
# 模型有两个主要部分,特征提取层和分类器
# 这里是特征提取层
self.feature = tf.keras.models.Sequential()
self.feature.add(self.conv(64))
self.feature.add(self.conv(64, add_pooling=True))
self.feature.add(self.conv(128))
self.feature.add(self.conv(128, add_pooling=True))
self.feature.add(self.conv(256))
self.feature.add(self.conv(256))
self.feature.add(self.conv(256, add_pooling=True))
self.feature.add(self.conv(512))
self.feature.add(self.conv(512))
self.feature.add(self.conv(512, add_pooling=True))
self.feature.add(self.conv(512))
self.feature.add(self.conv(512))
self.feature.add(self.conv(512, add_pooling=True))
self.feature.add(tf.keras.layers.GlobalAveragePooling2D())
self.feature.add(tf.keras.layers.Dense(4096, activation="relu"))
self.feature.add(tf.keras.layers.BatchNormalization())
self.feature.add(tf.keras.layers.Dense(4096, activation="relu"))
self.feature.add(tf.keras.layers.BatchNormalization())
self.feature.add(tf.keras.layers.Dropout(0.5))
# 这个简单的机构是分类器
self.pred = tf.keras.layers.Dense(100)
def conv(self, filters, add_pooling=False):
# 模型大量使用重复模块构建,
# 这里将重复模块提取出来,简化模型构建过程
model = tf.keras.models.Sequential(
[
tf.keras.layers.Conv2D(filters, 3, padding="same", activation="relu"),
tf.keras.layers.BatchNormalization(),
]
)
if add_pooling:
model.add(
tf.keras.layers.MaxPool2D(
pool_size=(2, 2), strides=None, padding="same"
)
)
return model
def call(self, x):
# call 用来定义模型各个结构之间的运算关系
x = self.feature(x)
return self.pred(x)
可以看到,上述模型定义中,仅仅关注当前模块的参数即可,模块的输入及输出的关系由框架自动推算得到,节省了很多精力。
这跟 TF 1.x 有很大不同,也跟 pyTorch有很大不同。mxnet 的 gluon api 跟这里的操作是类似的。
实例化一个模型看看:
model = MyModel()
model.build(input_shape=(None, 128, 128, 1))
# 这里的build,是因为模型构建的时候,并没有指定输入数据的尺寸,
# 因此要查看模型的一些数据,需要告知模型的输入数据尺寸,框架据
# 此推断模型内部各模块参数。
model.summary()
Model: "my_model"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
sequential (Sequential) multiple 33645760
_________________________________________________________________
dense_2 (Dense) multiple 409700
=================================================================
Total params: 34,055,460
Trainable params: 34,030,628
Non-trainable params: 24,832
_________________________________________________________________
build 和 summary 仅仅用来查看模型数据,对于模型训练不是必须的
模型训练代码-训练相关部分
要训练模型,我们还需要定义损失,优化器等,同时为了方便训练过程对模型进行验证,我们还需要定义一些性能指标。
loss_object = tf.keras.losses.CategoricalCrossentropy(from_logits=True)
optimizer = tf.keras.optimizers.Adam() # 优化器有些参数可以设置
train_loss = tf.keras.metrics.Mean(name="train_loss")
train_accuracy = tf.keras.metrics.CategoricalAccuracy(name="train_accuracy")
val_loss = tf.keras.metrics.Mean(name="val_loss")
val_accuracy = tf.keras.metrics.CategoricalAccuracy(name="val_accuracy")
TF V1.x 中,模型的训练是需要启动一个 Session 的,而在 TF V2.x中,这个 Session的操作被隐藏了起来,取而代之的是 tf.function。
@tf.function
def train_step(imgs, labels):
with tf.GradientTape() as tape:
# tape 用来追踪整个计算图,并记录梯度
preds = model(imgs, training=True)
loss = loss_object(labels, preds)
grads = tape.gradient(loss, model.trainable_variables)
optimizer.apply_gradients(zip(grads, model.trainable_variables))
train_loss(loss)
train_accuracy(labels, preds)
@tf.function
def val_step(imgs, labels):
# 验证的时候,我们不需要进行梯度更新,
# 也就不需要使用tape
preds = model(imgs, training=True)
loss = loss_object(labels, preds)
val_loss(loss)
val_accuracy(labels, preds)
import time # 模型训练的过程中手动追踪一下模型的训练速度
因为模型整个训练过程一般是一个循环往复的过程,所以经常性的保存重启模型训练中间过程是有必要的,这里使用 checkpoint 来保存模型中间训练结果,TF 整个系列对 checkpoint 的处理都很方便,这个目前是其他框架有欠缺的部分。
ckpt = tf.train.get_checkpoint_state(".")
# 检查 checkpoint 是否存在
if ckpt:
# 如果存在,则加载 checkpoint
model.load_weights(ckpt.model_checkpoint_path)
# 这里是一个比较生硬的方式,其实还可以观察之前训练的过程,
# 手动选择准确率最高的某次 checkpoint 进行加载。
print("model lodaded")
EPOCHS = 20
for epoch in range(EPOCHS):
# 验证数据都是针对整个 epoch 的,
# 所以每个 epoch 之间要对这些数据初始化一下。
train_loss.reset_states()
train_accuracy.reset_states()
val_loss.reset_states()
val_accuracy.reset_states()
total_trained = 0
total_valed = 0
start = time.time()
for imgs, labels in img_gen_train.flow_from_directory(
"worddata/train/",
target_size=(128, 128),
color_mode="grayscale",
classes=classes,
batch_size=batch_size,
shuffle=True,
):
# 之前生成的 img_train 是个 generator,而且它不会自动重启和结束,
# 一旦启动,只要不手动结束,它会一直无限循环输出数据,因此这里手动处
# 理一下数据的生成,并做一个计数。
train_step(imgs, labels)
total_trained += imgs.shape[0]
if total_trained > 35000:
break
period = time.time() - start
train_samples_per_second = total_trained / period
start = time.time()
for imgs, labels in img_gen_val.flow_from_directory(
"worddata/validation/",
target_size=(128, 128),
color_mode="grayscale",
classes=classes,
batch_size=batch_size,
):
val_step(imgs, labels)
total_valed += imgs.shape[0]
if total_valed > 5000:
break
period = time.time() - start
val_samples_per_second = total_valed / period
print(
"Epoch {} Loss {}, Acc {}, Val Loss {}, Val Acc {}".format(
epoch,
train_loss.result(),
train_accuracy.result() * 100,
val_loss.result(),
val_accuracy.result() * 100,
)
)
print(
"Speed train {}imgs/s val {}imgs/s".format(
train_samples_per_second, val_samples_per_second
)
)
model.save_weights("model-{:04d}.ckpt".format(epoch))
# 每个 epoch 保存一下模型,需要注意每次
# 保存要用一个不同的名字,不然会导致覆盖,
# 同时还要关注一下磁盘空间占用,防止太多
# chekcpoint 占满磁盘空间导致错误。
# 注意这个 API 调用每次都会生成数个文件,
# 其中 checkpoint 文件用来记录每次的文
# 件路径,其他文件则存储模型数据和索引信息
Found 35008 images belonging to 100 classes.
Found 5000 images belonging to 100 classes.
Epoch 0 Loss 7.023247718811035, Acc 1.0854661464691162, Val Loss 7.03449010848999, Val Acc 1.0532591342926025
Speed train 333.99072580901236imgs/s val 855.2596907200536imgs/s
Found 35008 images belonging to 100 classes.
Found 5000 images belonging to 100 classes.
Epoch 1 Loss 6.777282238006592, Acc 1.0226234197616577, Val Loss 6.941394805908203, Val Acc 0.9141494631767273
Speed train 327.1325337516975imgs/s val 916.1226758458906imgs/s
Found 35008 images belonging to 100 classes.
Found 5000 images belonging to 100 classes.
Epoch 2 Loss 6.359256267547607, Acc 1.174017310142517, Val Loss 5.80662727355957, Val Acc 1.2321144342422485
Speed train 324.335567399578imgs/s val 962.8157509235044imgs/s
Found 35008 images belonging to 100 classes.
Found 5000 images belonging to 100 classes.
Epoch 3 Loss 5.362932205200195, Acc 1.388254165649414, Val Loss 5.001005172729492, Val Acc 1.5699522495269775
Speed train 323.86471413694295imgs/s val 930.7858136706466imgs/s
Found 35008 images belonging to 100 classes.
Found 5000 images belonging to 100 classes.
Epoch 4 Loss 4.838431358337402, Acc 2.4737203121185303, Val Loss 4.622710704803467, Val Acc 3.5572338104248047
Speed train 322.9755351536125imgs/s val 923.726279512034imgs/s
Found 35008 images belonging to 100 classes.
Found 5000 images belonging to 100 classes.
Epoch 5 Loss 4.364129066467285, Acc 5.950068473815918, Val Loss 3.7800049781799316, Val Acc 11.54610538482666
Speed train 323.1431222660079imgs/s val 943.7244951999392imgs/s
Found 35008 images belonging to 100 classes.
Found 5000 images belonging to 100 classes.
Epoch 6 Loss 3.1959731578826904, Acc 21.5322208404541, Val Loss 2.499628782272339, Val Acc 36.08903121948242
Speed train 322.8869991116342imgs/s val 930.8482939234336imgs/s
Found 35008 images belonging to 100 classes.
Found 5000 images belonging to 100 classes.
Epoch 7 Loss 2.12115216255188, Acc 46.052330017089844, Val Loss 1.6911603212356567, Val Acc 56.79650115966797
Speed train 321.866934720262imgs/s val 936.5390705295545imgs/s
Found 35008 images belonging to 100 classes.
Found 5000 images belonging to 100 classes.
Epoch 8 Loss 1.473148226737976, Acc 61.46023941040039, Val Loss 1.1917930841445923, Val Acc 69.67408752441406
Speed train 321.0786097004539imgs/s val 952.8346292198817imgs/s
Found 35008 images belonging to 100 classes.
Found 5000 images belonging to 100 classes.
Epoch 9 Loss 1.1069483757019043, Acc 70.93235778808594, Val Loss 1.013684868812561, Val Acc 73.92686462402344
Speed train 319.400476897427imgs/s val 923.3046830012753imgs/s
Found 35008 images belonging to 100 classes.
Found 5000 images belonging to 100 classes.
Epoch 10 Loss 0.8779520988464355, Acc 76.97097778320312, Val Loss 0.8483538031578064, Val Acc 78.67646789550781
Speed train 309.2933807436792imgs/s val 900.1829360283824imgs/s
Found 35008 images belonging to 100 classes.
Found 5000 images belonging to 100 classes.
Epoch 11 Loss 0.7251728773117065, Acc 80.77296447753906, Val Loss 0.7506675124168396, Val Acc 81.0214614868164
Speed train 301.2632010228335imgs/s val 898.4208917080757imgs/s
Found 35008 images belonging to 100 classes.
Found 5000 images belonging to 100 classes.
Epoch 12 Loss 0.6008656024932861, Acc 83.94367218017578, Val Loss 0.7291795015335083, Val Acc 81.25993347167969
Speed train 297.387641228917imgs/s val 860.7664859874725imgs/s
Found 35008 images belonging to 100 classes.
Found 5000 images belonging to 100 classes.
Epoch 13 Loss 0.5257374048233032, Acc 85.77182006835938, Val Loss 0.6249831318855286, Val Acc 83.94276428222656
Speed train 294.72389001523715imgs/s val 870.8659186105319imgs/s
Found 35008 images belonging to 100 classes.
Found 5000 images belonging to 100 classes.
Epoch 14 Loss 0.47265708446502686, Acc 87.06867218017578, Val Loss 0.628078043460846, Val Acc 84.91653442382812
Speed train 292.79691865600137imgs/s val 885.4427240527049imgs/s
Found 35008 images belonging to 100 classes.
Found 5000 images belonging to 100 classes.
Epoch 15 Loss 0.41973787546157837, Acc 88.53118896484375, Val Loss 0.5823920965194702, Val Acc 86.40699768066406
Speed train 290.9758017995805imgs/s val 861.7976330609051imgs/s
Found 35008 images belonging to 100 classes.
Found 5000 images belonging to 100 classes.
Epoch 16 Loss 0.35504868626594543, Acc 89.89087677001953, Val Loss 0.5734841823577881, Val Acc 86.44673919677734
Speed train 287.8656709680878imgs/s val 843.401053133528imgs/s
Found 35008 images belonging to 100 classes.
Found 5000 images belonging to 100 classes.
Epoch 17 Loss 0.32744738459587097, Acc 90.64213562011719, Val Loss 0.5623900294303894, Val Acc 87.400634765625
Speed train 286.3891225531947imgs/s val 850.7197864862696imgs/s
Found 35008 images belonging to 100 classes.
Found 5000 images belonging to 100 classes.
Epoch 18 Loss 0.26646262407302856, Acc 92.30461883544922, Val Loss 0.5681359767913818, Val Acc 87.10254669189453
Speed train 293.50211639303797imgs/s val 890.6136287690091imgs/s
Found 35008 images belonging to 100 classes.
Found 5000 images belonging to 100 classes.
Epoch 19 Loss 0.3041706383228302, Acc 91.7533187866211, Val Loss 0.9558600187301636, Val Acc 79.47138214111328
Speed train 314.4602859964906imgs/s val 918.107688418758imgs/s
一些技巧
因为这里定义的模型比较大,同时训练的数据也比较多,每个 epoch 用时较长,因此,如果代码有 bug 的话,经过一次 epoch 再去 debug 效率比较低。
这种情况下,我们使用的数据生成过程又是自己手动指定数据数量的,因此可以尝试缩减模型规模,定义小一些的数据集来快速验证代码。在这个例子里,我们可以通过注释模型中的卷积和全连接层的代码来缩减模型尺寸,通过修改训练循环里面的数据数量来缩减数据数量。
_
下面的代码属于另外一个文件,因此部分代码跟上面是重复的。
模型的使用代码
模型训练好了之后要实际应用。对于模型部署有很多成熟的方案,如 Nvidia 的 TensorRT, Intel 的 OpenVINO 等,都可以做模型的高效部署,这里限于篇幅不涉及相关内容。
在模型训练过程中,也可以使用使用框架提供的 API 做模型的简单部署以方便开发。
import os
import tensorflow as tf
首先要加载模型的标签用于展示,因为我们训练的时候就已经生成了标签文件,这里直接用写好的代码就可以。
if os.path.exists("labels.txt"):
with open("labels.txt") as inf:
classes = [l.strip() for l in inf]
else:
classes = os.listdir("worddata/train/")
with open("labels.txt", "w") as of:
of.write("\r\n".join(classes))
接着是模型的定义,这里直接将训练中使用的模型代码拿来即可。
class MyModel(tf.keras.Model):
def __init__(self):
super(MyModel, self).__init__()
# 模型有两个主要部分,特征提取层和分类器
# 这里是特征提取层
self.feature = tf.keras.models.Sequential()
self.feature.add(self.conv(64))
self.feature.add(self.conv(64, add_pooling=True))
self.feature.add(self.conv(128))
self.feature.add(self.conv(128, add_pooling=True))
self.feature.add(self.conv(256))
self.feature.add(self.conv(256))
self.feature.add(self.conv(256, add_pooling=True))
self.feature.add(self.conv(512))
self.feature.add(self.conv(512))
self.feature.add(self.conv(512, add_pooling=True))
self.feature.add(self.conv(512))
self.feature.add(self.conv(512))
self.feature.add(self.conv(512, add_pooling=True))
self.feature.add(tf.keras.layers.GlobalAveragePooling2D())
self.feature.add(tf.keras.layers.Dense(4096, activation="relu"))
self.feature.add(tf.keras.layers.BatchNormalization())
self.feature.add(tf.keras.layers.Dense(4096, activation="relu"))
self.feature.add(tf.keras.layers.BatchNormalization())
self.feature.add(tf.keras.layers.Dropout(0.5))
# 这个简单的机构是分类器
self.pred = tf.keras.layers.Dense(100)
def conv(self, filters, add_pooling=False):
# 模型大量使用重复模块构建,
# 这里将重复模块提取出来,简化模型构建过程
model = tf.keras.models.Sequential(
[
tf.keras.layers.Conv2D(filters, 3, padding="same", activation="relu"),
tf.keras.layers.BatchNormalization(),
]
)
if add_pooling:
model.add(
tf.keras.layers.MaxPool2D(
pool_size=(2, 2), strides=None, padding="same"
)
)
return model
def call(self, x):
# call 用来定义模型各个结构之间的运算关系
x = self.feature(x)
return self.pred(x)
有了模型的定义之后,我们可以加载训练好的模型,跟模型训练的时候类似,我们可以直接加载模型训练中的 checkpoint。
model = MyModel()
ckpt = tf.train.get_checkpoint_state("./ckpts/")
if ckpt:
model.load_weights(ckpt.model_checkpoint_path)
print("model lodaded")
model lodaded
对于数据,我们需要直接处理图片,因此这里导入一些图片处理的库和数据处理的库
import numpy as np
from matplotlib import pyplot as plt
from PIL import Image
直接打开某个图片
img = Image.open("worddata/validation/臣/1075875876fc1994ab864525089f675e6f9da575.jpg")
plt.imshow(img, cmap="gray")
<matplotlib.image.AxesImage at 0x7f521c3cfd30>
需要注意,模型在训练的时候,我们对数据进行了一些处理,在模型使用的时候,我们要对数据做一样的处理,如果不做的话,模型最终的结果会出现不可预料的问题。
img = img.resize((128, 128))
img = np.array(img) / 255
img.shape
(128, 128)
模型对图片数据的运算其实很简单,一行代码就可以。
这里需要注意模型处理的数据是 4 维的,而上面的图片数据实际是 2 维的,因此要对数据进行维度的扩充。同时模型的输出是 2 维的,带 batch ,所以需要压缩一下维度。
pred = np.squeeze(
tf.nn.softmax(model(img[np.newaxis, :, :, np.newaxis], training=False))
)
pred.argsort()[-5:]
print([pred[idx] for idx in pred.argsort()[-5:]])
print([classes[idx] for idx in pred.argsort()[-5:]])
WARNING:tensorflow:Layer my_model is casting an input tensor from dtype float64 to the layer's dtype of float32, which is new behavior in TensorFlow 2. The layer has dtype float32 because it's dtype defaults to floatx.
If you intended to run this layer in float32, you can safely ignore this warning. If in doubt, this warning is likely only an issue if you are porting a TensorFlow 1.X model to TensorFlow 2.
To change all layers to have dtype float64 by default, call `tf.keras.backend.set_floatx('float64')`. To change just this layer, pass dtype='float64' to the layer constructor. If you are the author of this layer, you can disable autocasting by passing autocast=False to the base Layer constructor.
[8.4665515e-08, 5.306108e-07, 1.8298316e-06, 4.1379462e-05, 0.999956]
['西', '定', '良', '白', '臣']
这里只给出了 top5 的结果,可以看到,准确率还是不错的。