写在前面
最近想体验一下AI方面的技术,但是在做一些TensorFlow官网上的简单实验时却发现由于网络、环境等问题,里面的示例代码并不能很顺利地跑在我的机器上,主要原因出在从远端自动load_data
的时候,总是会报https、ssl等网络相关的问题。之前的Fashion MNIST,通过网上的一些攻略,找到了本地导入的方法,于是在进行下一个imdb(电影评论文本分类)的小实验时,想自己写一个类似功能的函数来本地导入数据集,于是就有了本文,代码和思路供参考。(另外,我发现keras中的imdb这部分代码并不长,于是把另外一个也需要联网的imdb.get_word_index()
也顺便一起改成离线版了,反正网络问题要出现就都是一起出现的…)
我的环境
操作系统:CentOS 7
Conda版本:从TUNA下载的Anaconda3-5.3.1-Linux-x86_64.sh
TensorFlow版本:1.15.4
Keras:Tensorflow中自带的
使用方法
-
下载imdb相关的两个文件(imdb.npz, imdb_word_index.json)到本地目录(例如
/home/user/datasets/TF_imdb
)。度盘下载链接 / 提取码:ic54 -
复制下方小节的【离线导入代码块】到.py文件中,修改代码中标出的路径,定位到刚才下载的两个文件。(如果缺少相应的包,可以用pip install)
-
把官网示例代码中的
imdb.load_data(num_words=10000)
修改成我们自定义的manual_imdb_load_data(num_words=10000)
;同理,imdb.get_word_index()
修改成manual_imdb_get_word_index()
-
其他代码可以和官网示例中保持一致,运行测试即可。
【离线导入代码块】:由于是从官方代码中改的,有些注释我保留了。用之前记得先把下载的两个文件放在对应路径下。
import tensorflow as tf
from tensorflow import keras
import numpy as np
print(tf.__version__) # 简单检测一下tensorflow的版本
# 修改imdb的一些函数(因为https有问题)
from tensorflow.python.platform import tf_logging as logging
import json
def manual_imdb_load_data(path='imdb.npz',
num_words=None,
skip_top=0,
maxlen=None,
seed=113,
start_char=1,
oov_char=2,
index_from=3,
**kwargs):
if 'nb_words' in kwargs:
logging.warning('The `nb_words` argument in `load_data` '
'has been renamed `num_words`.')
num_words = kwargs.pop('nb_words')
if kwargs:
raise TypeError('Unrecognized keyword arguments: ' + str(kwargs))
path = '/home/user/datasets/TF_imdb/imdb.npz' # 【修改点1】
with np.load(path, allow_pickle=True) as f:
x_train, labels_train = f['x_train'], f['y_train']
x_test, labels_test = f['x_test'], f['y_test']
rng = np.random.RandomState(seed)
indices = np.arange(len(x_train))
rng.shuffle(indices)
x_train = x_train[indices]
labels_train = labels_train[indices]
indices = np.arange(len(x_test))
rng.shuffle(indices)
x_test = x_test[indices]
labels_test = labels_test[indices]
if start_char is not None:
x_train = [[start_char] + [w + index_from for w in x] for x in x_train]
x_test = [[start_char] + [w + index_from for w in x] for x in x_test]
elif index_from:
x_train = [[w + index_from for w in x] for x in x_train]
x_test = [[w + index_from for w in x] for x in x_test]
if maxlen:
x_train, labels_train = keras.preprocessing.sequence._remove_long_seq(maxlen, x_train, labels_train)
x_test, labels_test = keras.preprocessing.sequence._remove_long_seq(maxlen, x_test, labels_test)
if not x_train or not x_test:
raise ValueError('After filtering for sequences shorter than maxlen=' +
str(maxlen) + ', no sequence was kept. '
'Increase maxlen.')
xs = np.concatenate([x_train, x_test])
labels = np.concatenate([labels_train, labels_test])
if not num_words:
num_words = max(max(x) for x in xs)
# by convention, use 2 as OOV word
# reserve 'index_from' (=3 by default) characters:
# 0 (padding), 1 (start), 2 (OOV)
if oov_char is not None:
xs = [
[w if (skip_top <= w < num_words) else oov_char for w in x] for x in xs
]
else:
xs = [[w for w in x if skip_top <= w < num_words] for x in xs]
idx = len(x_train)
x_train, y_train = np.array(xs[:idx]), np.array(labels[:idx])
x_test, y_test = np.array(xs[idx:]), np.array(labels[idx:])
return (x_train, y_train), (x_test, y_test)
def manual_imdb_get_word_index(path='imdb_word_index.json'):
"""Retrieves a dict mapping words to their index in the IMDB dataset.
Arguments:
path: where to cache the data (relative to `~/.keras/dataset`).
Returns:
The word index dictionary. Keys are word strings, values are their index.
"""
path = '/home/user/datasets/TF_imdb/imdb_word_index.json' # 【修改点2】
with open(path) as f:
return json.load(f)
#------------------------------offline loading preparation-------------------------------#