1、降采样
def down_sample(train_x, train_y):
    """Balance a binary (0/1) dataset by truncating the larger class.

    Samples are split by label into class 0 and class 1. The class with
    more samples is cut down to the size of the other class by keeping its
    first rows in their original order (no shuffling or random selection
    is performed), and the two groups are concatenated along axis 0.

    Args:
        train_x: Array-like of samples, indexed along axis 0 in step with
            ``train_y``.
        train_y: Array-like of labels. Only entries equal to 0 or 1 are
            selected; samples with any other label are dropped.
            NOTE(review): a 1-D label vector is assumed - confirm against
            callers.

    Returns:
        A ``(balanced_x, balanced_y)`` tuple of NumPy arrays. The truncated
        larger class comes first, followed by the whole smaller class. When
        both classes have the same size, class 1 comes first. If one class
        has no samples, both returned arrays are empty along axis 0.
    """
    # Accept plain Python sequences as well as ndarrays.
    features = np.asarray(train_x)
    labels = np.asarray(train_y)

    # Compute each class mask once and reuse it for samples and labels.
    zero_mask = labels == 0
    one_mask = labels == 1
    zeros_x, zeros_y = features[zero_mask], labels[zero_mask]
    ones_x, ones_y = features[one_mask], labels[one_mask]

    if zeros_x.shape[0] > ones_x.shape[0]:
        major_x, major_y, minor_x, minor_y = zeros_x, zeros_y, ones_x, ones_y
    else:
        # Ties land here: class 1 is treated as the class to truncate.
        major_x, major_y, minor_x, minor_y = ones_x, ones_y, zeros_x, zeros_y

    keep = minor_x.shape[0]
    balanced_x = np.concatenate((major_x[:keep], minor_x), axis=0)
    balanced_y = np.concatenate((major_y[:keep], minor_y), axis=0)
    return balanced_x, balanced_y