使用python对bin文件进行操作

python对bin文件操作的步骤

背景

博主想对神经网络模型的参数写入 bin 文件,方便在后续创建IP的过程中读取数据进行验证,记录 python 读取 pytorch 的模块参数并进行bin文件写入和读取操作。本文以3x3卷积为例。

示例代码

本文涉及的模块

struct: 该模块可以执行 Python 值和以 Python bytes 对象表示的 C 结构之间的转换。

pytorch :神经网络框架

简单示例

import struct

SAVE_DIR = "./conv3x3_pool_relu_outputs"

import struct

val = -1
a = struct.pack('i', val)# 将 -1 进行二进制打包(4字节), 默认情况下,打包给定 C 结构的结果会包含填充字节以使得所涉及的 C 类型保持正确的对齐
print(a)  # b'\xff\xff\xff\xff',确实是-1的二进制补码表示
file = os.path.join(SAVE_DIR, "wt.bin")
with open(file, "ab+") as fw:  # 二进制追加形式
    fw.write(a)

with open(file, "rb") as fr:  # 二进制读形式
    b = struct.unpack('i', fr.read(4))
    print(b[0])  # (-1,),返回元组
    print(b[0] == val)  # true

完整保存参数代码

# coding:utf-8
"""
for generate conv3x3_pool_relu and data for test.
"""
import os

import torch
import torch.nn as nn

# hyper param
Hin = 6
Win = 12
CHin = 16
CHout = 16
step = 0.1
G_SIZE = 8

SAVE_DIR = "./conv3x3_pool_relu_outputs"

seed = 2021
torch.random.manual_seed(seed)

def format_num(x):
    """
    >0 -> 1, <0 -> -1. switch func.
    """
    return (torch.randn_like(x) > 0).to(torch.float32) * 2 - 1


def save_conv3x3_weight(weight, save_dir="./outputs", filename="conv3x3", size=8):
    """
    写入文件,
    """
    shape = weight.shape
    print("save {} weights(bin format) ".format(filename), shape, end="  ---------wait----------  ")
    assert len(shape) == 4 and shape[0] % size == 0 and shape[1] % size == 0, "input error"

    if not ".dat" in filename:
        filename = filename + "_weight.bin"

    if type(weight) in [torch.nn.Parameter, torch.Tensor]:
        weight = weight.cpu().detach().numpy()

    filepath = os.path.join(save_dir, filename)
    with open(filepath, "wb+") as fw:
        for i in range(0, shape[0], size):
            for j in range(0, shape[1], size):
                for co in range(i, i + size):
                    for ci in range(j, j + size):
                        for h in range(3):
                            for w in range(3):
                                fw.write(struct.pack('i', int(weight[co][ci][h][w])))  # 写入前进行二进制转换
    print("save conv3x3_weight done. save weights to {}".format(filepath))
    return filepath


class Conv3x3PoolRelu(nn.Module):
    def __init__(self, in_channels=16, out_channels=32, save=False, out_dir="./outputs", save_size=8):
        super().__init__()
        assert in_channels % G_SIZE == 0 and out_channels % G_SIZE == 0, "input error!!"

        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
        self.act = nn.ReLU()
        self.pool = nn.MaxPool2d(2, 2, 0)

        self.init_weights()
        self.save_dir = out_dir

    def forward(self, x):
        for name, module in self.named_children():
            print(name, module)
            if type(module) in [nn.Conv2d]:
                save_conv3x3_weight(module.weight, self.save_dir)
            x = module(x)
        return x

    def init_weights(self):
        for idx, m in self.named_modules():
            # print(idx, " ", type(m))
            # initialize
            if type(m) in [nn.Conv2d]:
                weight, bias = m.weight, m.bias
                m_weight, m_bias = format_num(weight), format_num(bias)
                # m_weight, m_bias = format_num(weight), torch.zeros_like(bias)
                m.weight, m.bias = nn.Parameter(m_weight, requires_grad=False), nn.Parameter(m_bias,
                                                                           requires_grad=False)
                
if __name__ == '__main__':
    model = Conv3x3PoolRelu(8, 8, out_dir=SAVE_DIR)
    x = format_num(torch.randn(1, 8, 4, 4))
    y = model(x)
    print(y.shape, y)

    # read to validate
    w = torch.empty(8, 8, 3, 3)
    con, cin, kh, kw = w.shape
    with open("./conv3x3_pool_relu_outputs/conv3x3_weight.bin", "rb") as fr:
        for co in range(con):
            for ci in range(cin):
                for i in range(kh):
                    for j in range(kw): 
                        data = struct.unpack("i", fr.read(4))  # 使用unpack进行转换为数据类型,注意read后面跟的是4个字节
                        w[co][ci][i][j] = data[0]

    print(w)

步骤

详细步骤如下:

  1. 以二进制读写方式打开文件;
  2. 使用struct库对相应的数据类型进行二进制转换(读使用unpack,写使用pack);
  3. 读取或者写入文件中。

重要

Note:写入文件的格式和数据类型之间的关系如下:

格式 C 类型 Python 类型 标准大小
x 填充字节
c char 长度为 1 的字节串 1
b signed char 整数 1
B unsigned char 整数 1
? _Bool bool 1
h short 整数 2
H unsigned short 整数 2
i int 整数 4
I unsigned int 整数 4
l long 整数 4
L unsigned long 整数 4
q long long 整数 8
Q unsigned long long 整数 8
n ssize_t 整数
N size_t 整数
e (6) float 2
f float float 4
d double float 8
s char[] 字节串
p char[] 字节串
P void * 整数

写入 bin 文件主要是将二进制数据写入,如果一开始就是二进制数据,那么就不需要进行 structpack 操作。另外,对于python的数据类型,写入文件的字节顺序、大小与对齐方式可以设置,详细见官方文档[2]。

参考

1、python bin 文件处理 - 云 + 社区 - 腾讯云 (tencent.com)

2、struct — 将字节串解读为打包的二进制数据 — Python 3.8.12 文档

上一篇:XML实现动物园动物添加功能


下一篇:背包问题总结