from caffe.proto import caffe_pb2
s = caffe_pb2.SolverParameter()
s.train_net = "train.prototxt" # 定义网络名为trai.prototxt
s.test_net.append("test.prototxt") # 定义测试网络
s.test_interval = 100
s.test_iter.append(10)
# 定义最大迭代次数
s.max_iter = 1000
s.base_lr = 0.1
# 定义学习率衰减率
s.weight_decay = 5e-4
# 义学习率更新方式
s.lr_policy = "step"
# 定义网络打印间隔
s.display = 10
# 定义模型和存储间隔
s.snapshop = 10.
# 定义模型存放路径
s.snapshop_prefix_prefix = "model"
s.type="SGD"
s.solver_mode=caffe_pb2.SolverParameter.GPU #定义网络优化使用gpu
#生成solver文件s.prototxt
with open('net/s.prototxt','w') as f:
f.write(str(s))