在tensorflow2.x的keras中内置了7种类型的数据集:
数据集名称 | 数据集描述 |
---|---|
boston_housing | 波士顿房价数据 |
cifar10 | 10种类别图片集 |
cifar100 | 100种类别图片集 |
fashion_mnist | 10种时尚类别图片集 |
imdb | 电影评论情感分类数据集 |
mnist | 手写数字图片集 |
reuters | 路透社新闻主题分类数据集 |
这些数据的读取都可以使用load_data()方法。不过2种关于文本的数据集imdb和reuters比较特殊,他们的load_data中包含了过滤参数。本文将介绍imdb的load_data参数以及用法。
imdb.load_data的定义如下:
tf.keras.datasets.imdb.load_data(
path='imdb.npz', num_words=None, skip_top=0, maxlen=None, seed=113,
start_char=1, oov_char=2, index_from=3, **kwargs
)
-
path
此参数定义的文件的名称。一般使用默认值 -
num_words
整数。定义的是大于该词频的单词会被读取。如果单词的词频小于该整数,会用oov_char定义的数字代替。默认是用2代替。
需要注意的是,词频越高的单词,其在单词表中的位置越靠前,也就是索引越小,所以实际上保留下来的是索引小于此数值的单词。 -
skip_top
整数。词频低于此整数的单词会被读入。高于此整数的会被oov_char定义的数字代替。 -
maxlen
整数。评论单词数小于此数值的会被读入。比如一条评论包含的单词数是120,如果maxlen=100,则该条评论不会被读入。 -
seed
整数。定义了随机打乱数据时候的初始化种子。跟生成随机数的种子是一样的。 -
start_char
整数。定义了每条评论的起始索引。默认值是1。 -
oov_char
整数。定义了不满足条件单词的替代值。凡是不满足过滤条件的单词的索引都用此数值代替。 -
index_from
整数。单词索引大于此数值的会被读入。 -
**kwargs
兼容用途。 -
num_words使用
示例如下:
from tensorflow.keras import datasets
(x,y),(tx,ty) = datasets.imdb.load_data()
print("全部数据:",len(x),' 第一个评论:',len(x[0]))
print('第一个评论内容:',x[0][0:10])
(x100,y100),(tx100,ty100) = datasets.imdb.load_data(num_words=100)
print("前100词频:",len(x100),' 第一个评论【100】:',len(x100[0]))
print('第一个评论内容【100】:',x100[0][0:10])
结果如下:
全部数据: 25000 第一个评论: 218
第一个评论内容: [1, 14, 22, 16, 43, 530, 973, 1622, 1385, 65]
前100词频: 25000 第一个评论【100】: 218
第一个评论内容【100】: [1, 14, 22, 16, 43, 2, 2, 2, 2, 65]
对比可以发现,索引大于100的都被2替代了。
- skip_top使用
示例如下:
from tensorflow.keras import datasets
(x,y),(tx,ty) = datasets.imdb.load_data()
print("全部数据:",len(x),' 第一个评论:',len(x[0]))
print('第一个评论内容:',x[0][0:10])
(x100,y100),(tx100,ty100) = datasets.imdb.load_data(skip_top=100)
print("跳过前100词频:",len(x100),' 第一个评论【100】:',len(x100[0]))
print('第一个评论内容【100】:',x100[0][0:10])
结果如下:
全部数据: 25000 第一个评论: 218
第一个评论内容: [1, 14, 22, 16, 43, 530, 973, 1622, 1385, 65]
跳过前100词频: 25000 第一个评论【100】: 218
第一个评论内容【100】: [2, 2, 2, 2, 2, 530, 973, 1622, 1385, 2]
对比可以发现,索引小于100的都被2替代了。
- maxlen使用
示例如下:
from tensorflow.keras import datasets
(x,y),(tx,ty) = datasets.imdb.load_data()
print("全部数据:",len(x),' 第一个评论:',len(x[0]))
(x100,y100),(tx100,ty100) = datasets.imdb.load_data(maxlen=100)
print("长度小于100:",len(x100),' 第一个评论【100】:',len(x100[0]))
结果:
全部数据: 25000 第一个评论: 218
长度小于100词频: 5736 第一个评论【100】: 43
可以看出定义了maxlen之后,读入的数据少了。
- start_char使用
示例如下:
from tensorflow.keras import datasets
(x,y),(tx,ty) = datasets.imdb.load_data()
print("全部数据:",len(x),' 第一个评论:',len(x[0]))
print("第一条评论内容:",x[0][0:10])
(x100,y100),(tx100,ty100) = datasets.imdb.load_data(start_char=100)
print("起始索引:",len(x100),' 第一个评论【100】:',len(x100[0]))
print("第一条评论内容【100】:",x100[0][0:10])
结果如下:
全部数据: 25000 第一个评论: 218
第一条评论内容: [1, 14, 22, 16, 43, 530, 973, 1622, 1385, 65]
起始索引: 25000 第一个评论【100】: 218
第一条评论内容【100】: [100, 14, 22, 16, 43, 530, 973, 1622, 1385, 65]
对比可以发现,评论开始的数值被换为100了。
- oov_char使用
示例如下;
from tensorflow.keras import datasets
(x,y),(tx,ty) = datasets.imdb.load_data()
print("全部数据:",len(x),' 第一个评论:',len(x[0]))
print("第一条评论内容:",x[0][0:10])
(x100,y100),(tx100,ty100) = datasets.imdb.load_data(oov_char=100,skip_top=20)
print("替换索引=100:",len(x100),' 第一个评论【100】:',len(x100[0]))
print("第一条评论内容【100】:",x100[0][0:10])
结果如下:
全部数据: 25000 第一个评论: 218
第一条评论内容: [1, 14, 22, 16, 43, 530, 973, 1622, 1385, 65]
替换索引=100: 25000 第一个评论【100】: 218
第一条评论内容【100】: [100, 100, 22, 100, 43, 530, 973, 1622, 1385, 65]
可以发现替换索引是100。起始索引也变为100了。即使定义了start_char也没有作用,这一点一定要注意。
- index_from使用
示例代码:
from tensorflow.keras import datasets
(x,y),(tx,ty) = datasets.imdb.load_data()
print("全部数据:",len(x),' 第一个评论:',len(x[0]))
print("第一条评论内容:",x[0][0:10])
(x100,y100),(tx100,ty100) = datasets.imdb.load_data(index_from=100)
print("index_from=100:",len(x100),' 第一个评论【100】:',len(x100[0]))
print("第一条评论内容【100】:",x100[0][0:10])
结果:
全部数据: 25000 第一个评论: 218
第一条评论内容: [1, 14, 22, 16, 43, 530, 973, 1622, 1385, 65]
index_from=100: 25000 第一个评论【100】: 218
第一条评论内容【100】: [1, 111, 119, 113, 140, 627, 1070, 1719, 1482, 162]
对比可以发现,每个单词都被增加了100-3=97当index_from =100的时候。之所以要减去3是因为默认参数index_from=3,因此不带任何参数的load_data()实际上是在原始的索引上增加了3。
。
- **kwargs使用
兼容使用。
源代码分析
从Github上可以看到此函数的代码:
if 'nb_words' in kwargs:
logging.warning('The `nb_words` argument in `load_data` '
'has been renamed `num_words`.')
num_words = kwargs.pop('nb_words')
if kwargs:
raise TypeError('Unrecognized keyword arguments: ' + str(kwargs))
origin_folder = 'https://storage.googleapis.com/tensorflow/tf-keras-datasets/'
path = get_file(
path,
origin=origin_folder + 'imdb.npz',
file_hash=
'69664113be75683a8fe16e3ed0ab59fda8886cb3cd7ada244f7d9544e4676b9f')
with np.load(path, allow_pickle=True) as f:
x_train, labels_train = f['x_train'], f['y_train']
x_test, labels_test = f['x_test'], f['y_test']
rng = np.random.RandomState(seed)
indices = np.arange(len(x_train))
rng.shuffle(indices)
x_train = x_train[indices]
labels_train = labels_train[indices]
indices = np.arange(len(x_test))
rng.shuffle(indices)
x_test = x_test[indices]
labels_test = labels_test[indices]
if start_char is not None:
x_train = [[start_char] + [w + index_from for w in x] for x in x_train]
x_test = [[start_char] + [w + index_from for w in x] for x in x_test]
elif index_from:
x_train = [[w + index_from for w in x] for x in x_train]
x_test = [[w + index_from for w in x] for x in x_test]
if maxlen:
x_train, labels_train = _remove_long_seq(maxlen, x_train, labels_train)
x_test, labels_test = _remove_long_seq(maxlen, x_test, labels_test)
if not x_train or not x_test:
raise ValueError('After filtering for sequences shorter than maxlen=' +
str(maxlen) + ', no sequence was kept. '
'Increase maxlen.')
xs = np.concatenate([x_train, x_test])
labels = np.concatenate([labels_train, labels_test])
if not num_words:
num_words = max(max(x) for x in xs)
# by convention, use 2 as OOV word
# reserve 'index_from' (=3 by default) characters:
# 0 (padding), 1 (start), 2 (OOV)
if oov_char is not None:
xs = [
[w if (skip_top <= w < num_words) else oov_char for w in x] for x in xs
]
else:
xs = [[w for w in x if skip_top <= w < num_words] for x in xs]
idx = len(x_train)
x_train, y_train = np.array(xs[:idx]), np.array(labels[:idx])
x_test, y_test = np.array(xs[idx:]), np.array(labels[idx:])
return (x_train, y_train), (x_test, y_test)
seed是用来初始随机数的:
rng = np.random.RandomState(seed)
start_char是额外添加的:
x_train = [[start_char] + [w + index_from for w in x] for x in x_train]
下面这段代码是取得评论的最大词频:
if not num_words:
num_words = max(max(x) for x in xs)
这段代码实现了oov_char替换:
if oov_char is not None:
xs = [
[w if (skip_top <= w < num_words) else oov_char for w in x] for x in xs
]
else:
xs = [[w for w in x if skip_top <= w < num_words] for x in xs]
需要注意的是,由于oov_char是全替换索引,也包括start_char。因此在更改oov_char的时候,还要注意start_char也被修改了。这应该是个小bug。