最近实习中遇到了C++解析TFRecord的需求,搜索一圈发现虽然tensorflow C++ API中提供了相应的接口,但是编译C++版本的Tensorflow并不容易&很不清真,把他当做自己的项目的依赖就更离谱了。内网外网找了很久都发现没有相关的教程,于是调研了一圈,写了个自定义的解析脚本,只需要安装了解protobuf即可使用。读懂本文以及使用对应代码需要对protobuf有一定了解。
TFRecord的官方文档中说明了TFRecord由若干tf.train.Example组成,每条tf.train.Example其实就是一个protobuf的message,而这个message的定义文件就是tensorflow的代码库中的example.proto
文件(目前的文件链接)
值得注意的是,为了保证数据的正确性,TFRecord并不是直接将一条条Example的序列化结果首尾相连的,而是给每一条Example添加了一些header和footer。这些额外的信息对应的定义在tensorflow的代码库中的record_writer.h
文件里(目前的文件链接),最重要的就是下面几行
class RecordWriter {
public:
// Format of a single record:
// uint64 length
// uint32 masked crc of length
// byte data[length]
// uint32 masked crc of data
static constexpr size_t kHeaderSize = sizeof(uint64) + sizeof(uint32);
static constexpr size_t kFooterSize = sizeof(uint32);
可以看到,Example序列化完的有效数据就是中间的data字段,data的长度是length字段,此外这俩字段都还有一个crc校验。
这样解析tfrecord的思路就很清晰了:
- 拿到一个tfrecord文件,以二进制模式读
- 前8个字节读到一个uint64变量中,获取length信息
- 跳过4个字节的length的crc校验码(只要胆子大)
- 读取length个字节的data,交给protobuf接口来解析
- 跳过4个字节的data的crc校验码
代码还在整理中,后面会post一下github链接