python/pytorch feature examples

Contents

  • About collate_fn
  • torch.multiprocessing.spawn()
  • python hook
  • Asynchronously saving or stopping training
  • Asynchronously saving or stopping training with a decorator
  • A simple data loader: SimpleDataLoader
  • A data loader with a buffer queue: SimpleDataLoader
  • Decorators
  • Logging with a decorator

About collate_fn

import torch
from torch.utils.data import Dataset, DataLoader


def collate_fn(batch):
    # Find the longest sequence in the batch
    max_len = max(len(x) for x in batch)

    # Pad with zeros so that every sequence has the same length
    padded_batch = [x + [0] * (max_len - len(x)) for x in batch]

    # Convert the list of lists into a PyTorch tensor
    return torch.tensor(padded_batch)

class TextDataset(Dataset):
    def __init__(self, texts):
        self.texts = texts

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        return self.texts[idx]



# Suppose we have a few sentences of different lengths
texts = [
    [1, 2, 3],
    [4, 5],
    [6, 7, 8, 9],
    [10]
]

# Create the DataLoader with the custom collate_fn
dataset = TextDataset(texts)
dataloader = DataLoader(dataset, batch_size=2, collate_fn=collate_fn)

for batch in dataloader:
    print(batch)

Output:
tensor([[1, 2, 3],
        [4, 5, 0]])
tensor([[ 6,  7,  8,  9],
        [10,  0,  0,  0]])

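For comparison, the same zero-padding can be done with PyTorch's built-in pad_sequence; a minimal sketch, assuming each sample is first converted to a tensor inside the collate function (collate_fn_pad is just an illustrative name):

import torch
from torch.nn.utils.rnn import pad_sequence

def collate_fn_pad(batch):
    # pad_sequence expects a list of tensors and pads them to the length of the longest one
    tensors = [torch.tensor(x) for x in batch]
    return pad_sequence(tensors, batch_first=True, padding_value=0)

print(collate_fn_pad([[1, 2, 3], [4, 5]]))
# tensor([[1, 2, 3],
#         [4, 5, 0]])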

torch.multiprocessing.spawn()

import os
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.optim as optim
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader, DistributedSampler
from torchvision import datasets, transforms

# Basic hyperparameters
batch_size = 32
learning_rate = 0.01
epochs = 3


# A simple fully connected network
class SimpleModel(nn.Module):
    def __init__(self):
        super(SimpleModel, self).__init__()
        self.fc = nn.Linear(28 * 28, 10)

    def forward(self, x):
        return self.fc(x.view(-1, 28 * 28))


# Initialize the distributed environment
def setup(rank, world_size):
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '12355'
    dist.init_process_group("nccl", rank=rank, world_size=world_size)


# Clean up the distributed environment
def cleanup():
    dist.destroy_process_group()


# Training routine
def train(rank, world_size):
    # Initialize the process group
    setup(rank, world_size)

    # Build the model and move it to the GPU corresponding to this rank
    model = SimpleModel().to(rank)

    # Wrap the model with DistributedDataParallel
    model = DDP(model, device_ids=[rank])

    # Loss function and optimizer
    criterion = nn.CrossEntropyLoss().to(rank)
    optimizer = optim.SGD(model.parameters(), lr=learning_rate)

    # Data preprocessing
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ])

    # Download and load the MNIST dataset
    dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)

    # Use DistributedSampler to shard the data across processes
    train_sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank)

    # The DataLoader uses the sampler so that each GPU sees its own shard
    train_loader = DataLoader(dataset, batch_size=batch_size, sampler=train_sampler)

    # Training loop
    for epoch in range(epochs):
        train_sampler.set_epoch(epoch)  # call once per epoch so the data is reshuffled differently each time
        running_loss = 0.0
        for batch_idx, (data, target) in enumerate(train_loader):
            data, target = data.to(rank), target.to(rank)

            # Forward pass
            outputs = model(data)
            loss = criterion(outputs, target)

            # Backward pass and optimization
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            if batch_idx % 100 == 0:
                print(
                    f"Rank {rank}, Epoch [{epoch + 1}/{epochs}], Step [{batch_idx}/{len(train_loader)}], Loss: {loss.item():.4f}")

    # Destroy the process group
    cleanup()


# Main entry point: use torch.multiprocessing to launch one process per GPU
def main():
    world_size = torch.cuda.device_count()  # number of available GPUs
    torch.multiprocessing.spawn(train, args=(world_size,), nprocs=world_size, join=True)


if __name__ == "__main__":
    main()

Output:
Rank 0, Epoch [1/3], Step [0/938], Loss: 2.4825
Rank 1, Epoch [1/3], Step [0/938], Loss: 2.7278
Rank 1, Epoch [1/3], Step [100/938], Loss: 0.7033
Rank 0, Epoch [1/3], Step [100/938], Loss: 0.8019
Rank 1, Epoch [1/3], Step [200/938], Loss: 0.7423
Rank 0, Epoch [1/3], Step [200/938], Loss: 0.4223
Rank 1, Epoch [1/3], Step [300/938], Loss: 0.2833
Rank 0, Epoch [1/3], Step [300/938], Loss: 0.5531
Rank 1, Epoch [1/3], Step [400/938], Loss: 0.3288
......
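As a reminder of the calling convention: spawn() passes the process index (rank) as the first positional argument, and the tuple given via args supplies the remaining arguments (here world_size). A minimal, GPU-free sketch of just that mechanism (worker is an illustrative name):

import torch.multiprocessing as mp

def worker(rank, world_size, message):
    # rank is injected by spawn(); world_size and message come from args
    print(f"worker {rank}/{world_size}: {message}")

if __name__ == "__main__":
    mp.spawn(worker, args=(2, "hello"), nprocs=2, join=True)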

python hook

import torch
import torch.nn as nn

# Define a simple neural network
class SimpleModel(nn.Module):
    def __init__(self):
        super(SimpleModel, self).__init__()
        self.fc1 = nn.Linear(10, 5)
        self.fc2 = nn.Linear(5, 2)

    def forward(self, x):
        x = self.fc1(x)
        x = torch.relu(x)
        x = self.fc2(x)
        return x

# Create a model instance
model = SimpleModel()

# Define a forward-pass hook function
def forward_hook(module, input, output):
    print(f"Hook: module {module} input {input}, output {output}")

# Register the hook on one layer (here the first fully connected layer)
hook_handle = model.fc1.register_forward_hook(forward_hook)

# Generate a random input and run a forward pass
input_data = torch.randn(1, 10)
output = model(input_data)

# Print the result
print(f"Final output: {output}")

# Remove the hook
hook_handle.remove()

Output:
Hook: module Linear(in_features=10, out_features=5, bias=True) input (tensor([[ 0.7551,  1.3895,  0.4566,  0.5799, -1.2487,  0.1824,  0.8438,  1.0473,
         -0.7047, -0.2592]]),), output tensor([[ 0.2386, -0.1723,  0.2275,  0.0239,  0.1021]],
       grad_fn=<AddmmBackward0>)
Final output: tensor([[ 0.0837, -0.1774]], grad_fn=<AddmmBackward0>)
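Backward hooks work the same way and expose the gradients flowing through a module; a minimal, self-contained sketch using register_full_backward_hook:

import torch
import torch.nn as nn

layer = nn.Linear(10, 5)

# Called after the gradients for this module have been computed
def backward_hook(module, grad_input, grad_output):
    print(f"grad_output shape: {grad_output[0].shape}")

handle = layer.register_full_backward_hook(backward_hook)
layer(torch.randn(1, 10)).sum().backward()  # triggers the backward hook
handle.remove()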

Asynchronously saving or stopping training

import multiprocessing
import time
import torch


# A stand-in for a real training function
def train_model(signal_queue):
    epoch = 0
    try:
        while True:
            print(f"Training... Epoch: {epoch}")
            time.sleep(1)  # simulate training time

            # Periodically check the signal queue for messages
            if not signal_queue.empty():
                signal = signal_queue.get()  # fetch the signal
                if signal == "SAVE":
                    print(f"Saving model at epoch {epoch}")
                    model = {"epoch": epoch}  # dummy model
                    torch.save(model, f"model_epoch_{epoch}.pt")
                    print(f"Model saved at epoch {epoch}")
                elif signal == "STOP":
                    print("Stopping training...")
                    break

            epoch += 1
            if epoch > 20:  # stop automatically after a fixed number of epochs
                print("Reached maximum epochs.")
                break
    except KeyboardInterrupt:
        print("Training interrupted manually.")
    finally:
        print("Training finished.")


# Handle user input in the main process and send signals
def monitor_and_send_signal(signal_queue):
    try:
        while True:
            command = input("Enter 'save' to save the model or 'stop' to stop training: ").strip()
            if command.lower() == 'save':
                signal_queue.put("SAVE")  # send a save-model signal through the queue
            elif command.lower() == 'stop':
                signal_queue.put("STOP")  # send a stop-training signal through the queue
                break
    except KeyboardInterrupt:
        print("Monitoring interrupted manually.")
    finally:
        print("Monitoring finished.")


# Main function: launch the training process
def main():
    signal_queue = multiprocessing.Queue()  # queue used for inter-process communication

    # Create the training process
    training_process = multiprocessing.Process(target=train_model, args=(signal_queue,))

    # Start the training process
    training_process.start()

    # Listen for input and send signals from the main process
    monitor_and_send_signal(signal_queue)

    # Wait for the training process to finish
    training_process.join()


if __name__ == "__main__":
    main()
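If only a stop signal is needed, a multiprocessing.Event is a simpler alternative to the queue; a minimal sketch of the same idea (train_loop and the three-second sleep are purely illustrative):

import multiprocessing
import time

def train_loop(stop_event):
    epoch = 0
    while not stop_event.is_set() and epoch <= 20:
        print(f"Training... Epoch: {epoch}")
        time.sleep(1)  # simulate training time
        epoch += 1
    print("Training finished.")

if __name__ == "__main__":
    stop_event = multiprocessing.Event()
    p = multiprocessing.Process(target=train_loop, args=(stop_event,))
    p.start()
    time.sleep(3)      # let a few epochs run
    stop_event.set()   # ask the training process to stop
    p.join()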

Asynchronously saving or stopping training with a decorator

import functools
import time
import threading
import multiprocessing


# Decorator that monitors a running function
def monitor_decorator(func):
    @functools.wraps(func)  # keep the wrapped function's name so it stays picklable for multiprocessing
    def wrapper(*args, **kwargs):
        # The queue used to receive signals
        signal_queue = kwargs.get('signal_queue')

        if signal_queue is None:
            raise ValueError("signal_queue is required as a keyword argument")

        # A background thread keeps checking the signal queue and controls the run state
        stop_event = threading.Event()

        def check_signals():
            while not stop_event.is_set():
                if not signal_queue.empty():
                    signal = signal_queue.get()
                    if signal == 'STOP':
                        print("Stopping the monitored function.")
                        stop_event.set()
                    elif signal == 'SAVE':
                        print("Saving current state...")
                        # model/state-saving logic could go here
                time.sleep(0.1)

        # Start the signal-checking thread
        signal_thread = threading.Thread(target=check_signals)
        signal_thread.start()

        try:
            # Run the decorated function
            result = func(*args, **kwargs, stop_event=stop_event)
        finally:
            # Shut the thread down
            stop_event.set()
            signal_thread.join()

        return result

    return wrapper


# The monitored function
@monitor_decorator
def long_running_function(*args, stop_event=None, **kwargs):
    epoch = 0
    while not stop_event.is_set():
        print(f"Running epoch {epoch}...")
        time.sleep(1)  # simulate a long-running task
        epoch += 1
        if epoch > 50:  # stop after a fixed number of epochs
            print("Reached maximum epochs.")
            break
    print("Function completed.")


# The main process handles user input and sends signals
def monitor_input(signal_queue):
    while True:
        command = input("Enter 'save' to save the state or 'stop' to stop execution: ").strip().lower()
        if command == 'save':
            signal_queue.put('SAVE')
        elif command == 'stop':
            signal_queue.put('STOP')
            break


# Entry point
if __name__ == "__main__":
    # Create the signal queue
    signal_queue = multiprocessing.Queue()

    # Start the long-running task in a separate process
    process = multiprocessing.Process(target=long_running_function, kwargs={'signal_queue': signal_queue})
    process.start()

    # Listen for user input in the main process and send signals
    monitor_input(signal_queue)

    # Wait for the child process to finish
    process.join()
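The decorator does not require a separate process. A quick sketch that drives long_running_function in the same process, feeding it signals from a helper thread (send_signals is an illustrative name; a thread-safe queue.Queue is sufficient when everything runs in one process):

import queue
import threading
import time

def send_signals(q):
    time.sleep(2)
    q.put('SAVE')  # request a save after two seconds
    time.sleep(2)
    q.put('STOP')  # then stop the monitored function

if __name__ == "__main__":
    q = queue.Queue()
    threading.Thread(target=send_signals, args=(q,), daemon=True).start()
    long_running_function(signal_queue=q)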

A simple data loader: SimpleDataLoader

import random
import time


class SimpleDataLoader:
    def __init__(self, data, batch_size=1, shuffle=False, curriculum_learning_enabled=False, post_process_func=None):
        """
        Initialize the data loader
        :param data: the dataset (list, numpy array, etc.)
        :param batch_size: number of samples per batch
        :param shuffle: whether to shuffle the data on every iteration
        :param curriculum_learning_enabled: whether curriculum learning is enabled
        :param post_process_func: optional post-processing function
        """
        self.data = data
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.curriculum_learning_enabled = curriculum_learning_enabled
        self.post_process_func = post_process_func
        self.len = len(data) // batch_size
        self.current_index = 0  # current batch index
        self.data_iterator = None

    def _create_dataloader(self):
        """Build the data iterator; the shuffle and curriculum-learning settings decide how it is constructed"""
        if self.shuffle:
            random.shuffle(self.data)
        self.data_iterator = iter(self.data)  # create the iterator

    def __iter__(self):
        """Initialize iteration"""
        self.current_index = 0  # reset the index at the start of every iteration
        self._create_dataloader()
        return self

    def __len__(self):
        """Number of batches per epoch"""
        return self.len

    def __next__(self):
        """Return the next batch, applying the post-processing function if one is given"""
        if self.current_index >= self.len:
            raise StopIteration
        batch = [next(self.data_iterator) for _ in range(self.batch_size)]
        if self.post_process_func is not None:
            batch = self.post_process_func(batch)
        self.current_index += 1
        return batch
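A quick usage sketch of the loader; since self.len is len(data) // batch_size, samples that do not fill a complete batch are dropped:

data = list(range(10))
loader = SimpleDataLoader(data, batch_size=3, shuffle=True,
                          post_process_func=lambda batch: [x * 2 for x in batch])

for batch in loader:
    print(batch)  # three batches of three doubled values; the leftover sample is dropped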