技术背景
随机采样问题,不仅仅只是一个统计学/离散数学上的概念,其实在工业领域也都有非常重要的应用价值/潜在应用价值,具体应用场景我们这里就不做赘述。本文重点在于在不同平台上的采样速率,至于另外一个重要的参数检验速率,这里我们先不做评估。因为在Jax中直接支持vmap的操作,而numpy的原生函数大多也支持了向量化的运算,两者更像是同一种算法的不同实现。所以对于检验的场景,两者的速度区别更多的也是在硬件平台上。
随机采样示例
关于Jax的安装和基本使用方法,读者可以自行参考Jax的官方文档,需要注意的是,Jax有CPU、GPU和TPU三个版本,如果需要使用其GPU版本的功能,还需要依赖于jaxlib,另外最好是指定安装对应的CUDA版本,这都是安装过程中所踩过的一些坑。最后如果安装的不是GPU的版本,运行Jax脚本的时候会有相关的提示说明。
随机采样,可以是针对一个给定的连续函数,也可以针对一个离散化的列表,但是为了更好的扩展性,一般问题都会转化成先获取均匀的随机分布,再转化成其他函数形式的分布,如正态分布等。所以这里我们更加的是关注下均匀分布函数的效率:
import numpy as np
import time
import jax.random as random
key = random.PRNGKey(0)
print ('An small example of numpy sampler: \n{}'.format(np.random.uniform(low=0,high=1,size=5)))
print ('An small example of jax sampler: \n{}'.format(random.uniform(key,shape=(5,),minval=0, maxval=1)))
data_size = 400000000
time0 = time.time()
s = np.random.uniform(low=0,high=1,size=data_size)
print ('The numpy time cost is: {}s'.format(time.time()-time0))
time1 = time.time()
v = random.uniform(key,shape=(data_size,),minval=0, maxval=1)
print ('The jax time cost is: {}s'.format(time.time()-time1))
执行结果如下:
An small example of numpy sampler:
[0.33654613 0.20267496 0.86859762 0.14940831 0.30321738]
An small example of jax sampler:
[0.57450044 0.09968603 0.39316022 0.8941783 0.59656656]
The numpy time cost is: 3.6664984226226807s
The jax time cost is: 0.10985755920410156s
同样是在生成双精度浮点数的情况下,我们可预期的GPU的速率在数据长度足够大的情况下一定是会更快的,这个运算结果也佐证了这个说法。
总结概要
关于工业领域中可能使用到的随机采样,更多的是这样的一个场景:给定一个连续或者离散的分布,然后进行大规模的连续采样,采样的同时需要对每一个得到的样点进行分析打分,最终在这大规模的采样过程中,有可能被使用到的样品可能只有其中的几份。那么这样的一个抽象问题,就非常适合使用分布式的多GPU硬件架构来实现。
版权声明
本文首发链接为:https://www.cnblogs.com/dechinphy/p/sampler.html
作者ID:DechinPhy
更多原著文章请参考:https://www.cnblogs.com/dechinphy/
打赏专用链接:https://www.cnblogs.com/dechinphy/gallery/image/379634.html
腾讯云专栏同步:https://cloud.tencent.com/developer/column/91958