通过动态类选择 (HF-Softmax) 进行大规模分类的加速训练

内容来源:https://github.com/yl-1993/hfsoftmax

一、文章

Accelerated Training for Massive Classification via Dynamic Class Selection (HF-Softmax)

Accelerated Training for Massive Classification via Dynamic Class Selection, AAAI 2018 (Oral)

二、训练

  1. Install PyTorch. (Better to install the latest master from source)
  2. Follow the instruction of InsightFace and download training data.
  3. Decode the data(.rec) to images and generate training/validation list.

1、安装 PyTorch。 (最好从源码安装最新的master)
2、按照 InsightFace 的说明下载训练数据。
3、将数据(.rec)解码为图像并生成训练/验证列表。

python tools/rec2img.py --in-folder xxx --out-folder yyy

4、尝试正常训练。 使用 torch.nn.DataParallel(多线程)进行并行训练。

sh scripts/train.sh dataset_path

5、尝试抽样训练。 使用 1 个 GPU 进行训练,默认采样数为 1000。

python paramserver/paramserver.py
sh scripts/train_hf.sh dataset_path

三、分布式训练

对于分布式训练,每个 GPU 上有一个进程。

为 PyTroch 分布式训练提供了一些后端。 如果您想使用 nccl 作为后端进行分布式训练,请按照说明安装 NCCL2。

要测试您的分布式设置,您可以执行

sh scripts/test_distributed.sh

安装 NCCL2 后,您应该从源代码重新编译 PyTorch。

python setup.py clean install

在我们的例子中,我们使用 libnccl2=2.2.13-1+cuda9.0 libnccl-dev=2.2.13-1+cuda9.0 和 PyTorch 的 master 0.5.0a0+e31ab99

四、Hashing Forest

We use Annoy to approximate the hashing forest. You can adjust sample_numntrees and interval to balance performance and cost.

我们使用 Annoy 来近似哈希森林。您可以调整sample_num、ntrees 和interval 以平衡性能和成本。

五、参数服务器Parameter Sever

Parameter server is decoupled with PyTorch. A client is developed to communicate with the server. Other platforms can integrate the parameter server via the communication API. Currently, it only supports syncronized SGD updater.

参数服务器与 PyTorch 解耦。 开发了一个客户端来与服务器进行通信。 其他平台可以通过通信API集成参数服务器。 目前,它仅支持同步 SGD 更新程序。

六、评估Evaluation

./scripts/eval.sh arch model_path dataset_path outputs

它使用 torch.nn.DataParallel 提取特征并将其保存为 .npy。 这些特征随后将用于执行验证测试。

如果使用分布式训练,请在特征提取期间设置strict=False。

请注意,来自 InsightFace 的 bin 文件,例如 lfw.bin,是由 Python2 获取的。 Python 3.0+ 无法处理它。 您可以使用 Python2 进行评估,也可以先通过 Python3 重新获取 bin 文件。

上一篇:pyinstaller 打包exe相关


下一篇:CnetOS双网卡绑定