numpy 定义一个rnn_cell_forward

numpy 定义一个rnn_cell_forward

numpy 定义一个rnn_cell_forward

   

import numpy as np
def softmax(x):
    e_x = np.exp(x - np.max(x))
    return e_x / e_x.sum(axis=0)

def rnn_cell_forward(xt, a_prev, parameters):

    Wax = parameters["Wax"]
    Waa = parameters["Waa"]
    Wya = parameters["Wya"]
    ba = parameters["ba"]
    by = parameters["by"]
    
    a_next = np.tanh(np.dot(Wax,xt)+np.dot(Waa,a_prev)+ba)

    yt_pred = softmax(np.dot(Wya,a_next)+by) 

    

    cache = (a_next, a_prev, xt, parameters)
    
    return a_next, yt_pred, cache
#测试
np.random.seed(1)
xt = np.random.randn(3,10)
a_prev = np.random.randn(5,10)
Waa = np.random.randn(5,5)
Wax = np.random.randn(5,3)
Wya = np.random.randn(2,5)
ba = np.random.randn(5,1)
by = np.random.randn(2,1)
parameters = {"Waa": Waa, "Wax": Wax, "Wya": Wya, "ba": ba, "by": by}

a_next, yt_pred, cache = rnn_cell_forward(xt, a_prev, parameters)
print("a_next[4] = ", a_next[4])
print("a_next.shape = ", a_next.shape)
print("yt_pred[1] =", yt_pred[1])
print("yt_pred.shape = ", yt_pred.shape)

样本数量 :10            

xt:3*10   神经元个数:3        a_prev、a_next : 5*10    神经元个数:5          yt: 2*10    神经元个数:2

一个神经元一个偏置项 numpy 定义一个rnn_cell_forwardba: 5*1       by:  2*1 

Waa :  5*5    a_prev numpy 定义一个rnn_cell_forward   a_next    numpy 定义一个rnn_cell_forwardnumpy 定义一个rnn_cell_forward 分别表示 a_prev 的第1,2,3,4,5个神经元到 a_next的第1个神经元的权重.

Wax : 5*3     xt numpy 定义一个rnn_cell_forward a_next   numpy 定义一个rnn_cell_forward  分别表示 xt 的第 1,2,3 个神经元到a_next的第1个神经元的权重。

Wya:2*5     a_next numpy 定义一个rnn_cell_forward yt_pred  numpy 定义一个rnn_cell_forward 分别表示 a_next的第1,2,3,4,5个神经元到 yt_pred 的第1个神经元的权重.

运行结果如下:

a_next[4] =  [ 0.59584544  0.18141802  0.61311866  0.99808218  0.85016201  0.99980978
 -0.18887155  0.99815551  0.6531151   0.82872037]
a_next.shape =  (5, 10)
yt_pred[1] = [0.9888161  0.01682021 0.21140899 0.36817467 0.98988387 0.88945212
 0.36920224 0.9966312  0.9982559  0.17746526]
yt_pred.shape =  (2, 10)

 

 

 

 

 

 

 

上一篇:Pytorch实现Top1准确率和Top5准确率


下一篇:多线程(十、AQS原理-ReentrantLock实现)