批归一化(Batch Normalization, BN)层的手写代码实现。

PyTorch 内部对 BN 层的处理会更复杂;这里只实现一个较为简单的前向传播模块,便于理解其原理。

# -*- coding: utf-8 -*-
# @Time    : 2021/3/9 11:59
# @Author  : Li Gang
# @File    : test4.py
import numpy as np

def batchnorm(X, params, mode):
    """Batch-normalization forward pass.

    Args:
        X: input array of shape (N, D) — N samples, D features.
        params: dict of parameters and state:
            'gamma', 'beta': scale/shift, broadcastable to (D,).
            'running_mean', 'running_var': running statistics of shape (D,);
                initialized to zeros here if absent, and updated in-place
                (written back into `params`) in 'train' mode.
            'eps': numerical-stability constant (default 1e-5).
            'momentum': decay for the running averages (default 0.9),
                used only in 'train' mode.
        mode: 'train' (normalize by batch statistics, update running stats)
              or 'test' (normalize by the stored running statistics).

    Returns:
        Normalized output array with the same shape as X.

    Raises:
        ValueError: if `mode` is neither 'train' nor 'test'.
    """
    # Fix: the original unpacked `params.shape` — a dict has no shape.
    # The feature dimension comes from the input itself.
    N, D = X.shape
    running_mean = params.get('running_mean', np.zeros(D, dtype=X.dtype))
    running_var = params.get('running_var', np.zeros(D, dtype=X.dtype))
    gamma = params.get('gamma')
    beta = params.get('beta')
    eps = params.get('eps', 1e-5)

    if mode == 'train':
        sample_mean = np.mean(X, axis=0)
        sample_var = np.var(X, axis=0)
        # Fix: eps belongs INSIDE the sqrt — x_hat = (x - mu) / sqrt(var + eps)
        # per the original BN paper; outside, a zero-variance feature divides
        # by ~eps's square root error and the formula disagrees with test mode.
        out_ = (X - sample_mean) / np.sqrt(sample_var + eps)
        out = gamma * out_ + beta
        # Fix: default momentum so a missing key doesn't produce None * array.
        momentum = params.get('momentum', 0.9)
        # Exponential moving average of batch statistics, stored for test time.
        params['running_mean'] = momentum * running_mean + (1 - momentum) * sample_mean
        params['running_var'] = momentum * running_var + (1 - momentum) * sample_var
    elif mode == 'test':
        # Fix: normalize by the standard deviation (sqrt of variance),
        # not the raw variance — must mirror the train-mode formula.
        out_ = (X - running_mean) / np.sqrt(running_var + eps)
        out = gamma * out_ + beta
    else:
        # Fix: the original fell through and raised UnboundLocalError on `out`.
        raise ValueError("mode must be 'train' or 'test', got %r" % (mode,))
    return out



上一篇:多线程


下一篇:Zabbix自定义key值监控MySQL主从同步