TensorFlow TFRecord 文件存储

import tensorflow as tf
import numpy as np
import skimage
from skimage import data, io, color
from PIL import Image

path = "1.tfrecords"
img_path = '/data/test/img/1.png'

# Write a single tf.train.Example to `path`, holding one feature of each
# kind: an int64, a float, two raw-byte arrays, a UTF-8 string and two
# images.
with tf.python_io.TFRecordWriter(path) as writer:
    # Int64List / FloatList accept plain Python ints / floats directly.
    a = 1024
    b = 10.24

    # Arrays are stored as raw bytes. Neither dtype nor shape is recorded
    # in the file, so whoever reads it back must decode with the same
    # dtype (float32 / int8 here) and restore the shape itself.
    c = [0.1, 0.2, 0.3]
    c = np.array(c).astype(np.float32).tobytes()

    d = [[1, 2], [3, 4]]
    d = np.array(d).astype(np.int8).tobytes()

    # Strings must be encoded to bytes before going into a BytesList.
    e = "Python"
    e = bytes(e, encoding='utf-8')

    # Image loaded through scikit-image and stored at its native size.
    # NOTE: the reading code in this file reshapes it to [256, 256, 3], so
    # the source image is assumed to already be 256x256 RGB.
    img = io.imread(img_path)
    img = img.astype(np.uint8).tobytes()

    # Same image loaded through PIL. The context manager closes the file
    # handle, and converting to RGB guarantees exactly 256 * 256 * 3 bytes
    # regardless of the PNG's own mode (RGBA, palette, greyscale, ...).
    with Image.open(img_path) as pil_img:
        img2 = pil_img.convert('RGB').resize((256, 256))
    img2 = img2.tobytes()

    example = tf.train.Example(features=tf.train.Features(feature={
        'a': tf.train.Feature(int64_list=tf.train.Int64List(value=[a])),
        'b': tf.train.Feature(float_list=tf.train.FloatList(value=[b])),
        'c': tf.train.Feature(bytes_list=tf.train.BytesList(value=[c])),
        'd': tf.train.Feature(bytes_list=tf.train.BytesList(value=[d])),
        'e': tf.train.Feature(bytes_list=tf.train.BytesList(value=[e])),
        'image': tf.train.Feature(bytes_list=tf.train.BytesList(value=[img])),
        'image2': tf.train.Feature(bytes_list=tf.train.BytesList(value=[img2])),
    }))
    writer.write(example.SerializeToString())

# Read
# Read the record back with the TF1 queue-based input pipeline.
filename_queue = tf.train.string_input_producer([path])
_, serialized_example = tf.TFRecordReader().read(filename_queue)

# The parsing spec must mirror the keys and value kinds used when writing.
features = tf.parse_single_example(
    serialized_example,
    features={
        'a': tf.FixedLenFeature([], tf.int64),
        'b': tf.FixedLenFeature([], tf.float32),
        'c': tf.FixedLenFeature([], tf.string),
        'd': tf.FixedLenFeature([], tf.string),
        'e': tf.FixedLenFeature([], tf.string),
        'image': tf.FixedLenFeature([], tf.string),
        'image2': tf.FixedLenFeature([], tf.string),
    })

a = features['a']  # every entry is a tensor, not a Python value
b = features['b']

# Raw-byte features carry no dtype or shape: decode with the dtype that
# was used for writing and restore the shape explicitly.
c = tf.decode_raw(features['c'], tf.float32)

d = tf.decode_raw(features['d'], tf.int8)
d = tf.reshape(d, [2, 2])

e = features['e']

img = tf.decode_raw(features['image'], tf.uint8)
img = tf.reshape(img, shape=[256, 256, 3])

img2 = tf.decode_raw(features['image2'], tf.uint8)
img2 = tf.reshape(img2, [256, 256, 3])

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())

    # The coordinator lets the queue-runner threads be stopped and joined
    # cleanly instead of being abandoned when the session closes.
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    try:
        # Fetch everything in ONE run: each sess.run that touches the
        # reader consumes another record, so separate calls would return
        # fields taken from different records.
        values = sess.run([a, b, c, d, e, img, img2])
    finally:
        coord.request_stop()
        coord.join(threads)

print(values[:5])

# From here on e / img / img2 hold the evaluated values, not tensors.
e, img, img2 = values[4:]
print(type(e), bytes.decode(e))

# io.imshow only draws the image; io.show() is what actually opens the
# window when this runs as a plain script.
io.imshow(img)
io.show()

io.imshow(img2)
io.show()
上一篇:BOSS直聘扫码直接登录 JS 代码


下一篇:【ZZ】 DShader之位移贴图(Displacement Mapping)