python-优化生成变量的拒绝方法

我对生成连续随机变量的拒绝方法的优化存在问题.我有一个密度:f(x)= 3/2(1-x ^ 2).这是我的代码:

import random
import matplotlib.pyplot as plt
import numpy  as np
import time
import scipy.stats as ss

a=0   # xmin
b=1   # xmax

m=3/2 # ymax
variables = [] #list for variables

def f(x):
    return 3/2 * (1 - x**2)  #probability density function

reject = 0   # number of rejections
start = time.time()
while len(variables) < 100000:  #I want to generate 100 000 variables
    u1 = random.uniform(a,b)
    u2 = random.uniform(0,m)

    if u2 <= f(u1):
        variables.append(u1)
    else:
        reject +=1
end = time.time()

print("Time: ", end-start)
print("Rejection: ", reject)
x = np.linspace(a,b,1000)
plt.hist(variables,50, density=1)
plt.plot(x, f(x))
plt.show()

ss.probplot(variables, plot=plt)
plt.show()

我的第一个问题:我的概率图设计正确吗?
第二,标题中的内容.如何优化该方法?我想获得一些建议以优化代码.现在,该代码大约需要0.5秒,并且大约有5万次拒绝.是否可以减少拒绝的时间和数量?如果需要,我可以使用另一种生成变量的方法进行优化.

解决方法:

My first question: Is my probability plot made properly?

不可以.它是针对默认正态分布的.您必须将函数f(x)打包到stats.rv_continuous派生的类中,使其成为_pdf方法,然后将其传递给probplot

And the second, what is in the title. How to optimise that method? Is it possible to reduce the time and number of rejections?

当然,您掌握了NumPy向量功能的强大功能.永远不要编写显式循环-vectoriz,vectorize和vectorize!

请看下面的修改代码,而不是单个循环,所有操作都是通过NumPy向量完成的.我的计算机上的时间从0.19减少到0.003,处理了100000个样本(至强,Win10 x64,Anaconda Python 3.7).

import numpy as np
import scipy.stats as ss
import matplotlib.pyplot as plt
import time

a = 0.  # xmin
b = 1.  # xmax

m = 3.0/2.0 # ymax

def f(x):
    return 1.5 * (1.0 - x*x)  # probability density function

start  = time.time()

N = 100000
u1 = np.random.uniform(a, b, N)
u2 = np.random.uniform(0.0, m, N)

negs = np.empty(N)
negs.fill(-1)
variables = np.where(u2 <= f(u1), u1, negs) # accepted samples are positive or 0, rejected are -1

end = time.time()

accept = np.extract(variables>=0.0, variables)
reject = N - len(accept)

print("Time: ", end-start)
print("Rejection: ", reject)

x = np.linspace(a, b, 1000)
plt.hist(accept, 50, density=True)
plt.plot(x, f(x))
plt.show()

ss.probplot(accept, plot=plt) # against normal distribution
plt.show()

关于减少拒绝的数量,您可以使用逆方法以0个拒绝进行采样,它是三次方程式,因此可以轻松地工作

更新

这是用于probplot的代码:

class my_pdf(ss.rv_continuous):
    def _pdf(self, x):
        return 1.5 * (1.0 - x*x)

ss.probplot(accept, dist=my_pdf(a=a, b=b, name='my_pdf'), plot=plt)

你应该得到像

python-优化生成变量的拒绝方法

上一篇:脚本速度与内存使用率


下一篇:php-优化Trie实现