Table of Contents
- About collate_fn
- torch.multiprocessing.spawn()
- Python hooks
- Asynchronously saving or stopping training
- Asynchronously saving or stopping training with a decorator
- A simple data loader: SimpleDataLoader
- A data loader with a buffered queue: SimpleDataLoader
- Decorators
- Logging with a decorator
About collate_fn
import torch
from torch.utils.data import Dataset, DataLoader

def collate_fn(batch):
    # Right-pad every sample in the batch with 0 up to the longest sequence length
    max_len = max(len(x) for x in batch)
    padded_batch = [x + [0] * (max_len - len(x)) for x in batch]
    return torch.tensor(padded_batch)

class TextDataset(Dataset):
    def __init__(self, texts):
        self.texts = texts

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        return self.texts[idx]

texts = [
    [1, 2, 3],
    [4, 5],
    [6, 7, 8, 9],
    [10]
]

dataset = TextDataset(texts)
dataloader = DataLoader(dataset, batch_size=2, collate_fn=collate_fn)
for batch in dataloader:
    print(batch)
Output:
tensor([[1, 2, 3],
[4, 5, 0]])
tensor([[ 6, 7, 8, 9],
[10, 0, 0, 0]])
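As an aside, PyTorch also ships torch.nn.utils.rnn.pad_sequence, which can do the same right-padding without a hand-written loop. A minimal sketch under that assumption (the helper name collate_with_pad_sequence is mine):

import torch
from torch.nn.utils.rnn import pad_sequence

def collate_with_pad_sequence(batch):
    # batch is a list of variable-length Python lists; convert each to a tensor,
    # then right-pad with 0 up to the longest sequence in the batch
    tensors = [torch.tensor(x) for x in batch]
    return pad_sequence(tensors, batch_first=True, padding_value=0)

print(collate_with_pad_sequence([[1, 2, 3], [4, 5]]))
# tensor([[1, 2, 3],
#         [4, 5, 0]])

It can be passed to DataLoader via collate_fn exactly like the manual version above.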
torch.multiprocessing.spawn()
import os
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.optim as optim
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader, DistributedSampler
from torchvision import datasets, transforms

batch_size = 32
learning_rate = 0.01
epochs = 3

class SimpleModel(nn.Module):
    def __init__(self):
        super(SimpleModel, self).__init__()
        self.fc = nn.Linear(28 * 28, 10)

    def forward(self, x):
        return self.fc(x.view(-1, 28 * 28))

def setup(rank, world_size):
    # Every process joins the same process group; NCCL is the backend for GPU training
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '12355'
    dist.init_process_group("nccl", rank=rank, world_size=world_size)

def cleanup():
    dist.destroy_process_group()

def train(rank, world_size):
    setup(rank, world_size)
    model = SimpleModel().to(rank)
    model = DDP(model, device_ids=[rank])
    criterion = nn.CrossEntropyLoss().to(rank)
    optimizer = optim.SGD(model.parameters(), lr=learning_rate)

    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ])
    dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
    # DistributedSampler gives each rank its own shard of the dataset
    train_sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank)
    train_loader = DataLoader(dataset, batch_size=batch_size, sampler=train_sampler)

    for epoch in range(epochs):
        train_sampler.set_epoch(epoch)
        running_loss = 0.0
        for batch_idx, (data, target) in enumerate(train_loader):
            data, target = data.to(rank), target.to(rank)
            outputs = model(data)
            loss = criterion(outputs, target)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
            if batch_idx % 100 == 0:
                print(
                    f"Rank {rank}, Epoch [{epoch + 1}/{epochs}], Step [{batch_idx}/{len(train_loader)}], Loss: {loss.item():.4f}")
    cleanup()

def main():
    world_size = torch.cuda.device_count()
    # spawn launches world_size processes and calls train(rank, world_size) in each of them
    torch.multiprocessing.spawn(train, args=(world_size,), nprocs=world_size, join=True)

if __name__ == "__main__":
    main()
Output:
Rank 0, Epoch [1/3], Step [0/938], Loss: 2.4825
Rank 1, Epoch [1/3], Step [0/938], Loss: 2.7278
Rank 1, Epoch [1/3], Step [100/938], Loss: 0.7033
Rank 0, Epoch [1/3], Step [100/938], Loss: 0.8019
Rank 1, Epoch [1/3], Step [200/938], Loss: 0.7423
Rank 0, Epoch [1/3], Step [200/938], Loss: 0.4223
Rank 1, Epoch [1/3], Step [300/938], Loss: 0.2833
Rank 0, Epoch [1/3], Step [300/938], Loss: 0.5531
Rank 1, Epoch [1/3], Step [400/938], Loss: 0.3288
......
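What makes this work is that torch.multiprocessing.spawn starts nprocs child processes and calls the target as fn(rank, *args), injecting the process index as the first positional argument; that is why train is declared as train(rank, world_size) while only (world_size,) is passed in args. A minimal CPU-only sketch of just this mechanism (the worker function and its message argument are mine):

import torch.multiprocessing as mp

def worker(rank, world_size, message):
    # rank is injected by spawn; world_size and message come from args
    print(f"Process {rank}/{world_size}: {message}")

if __name__ == "__main__":
    # Launch two processes; each one runs worker(rank, 2, "hello") with rank 0 or 1
    mp.spawn(worker, args=(2, "hello"), nprocs=2, join=True)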
Python hooks
import torch
import torch.nn as nn

class SimpleModel(nn.Module):
    def __init__(self):
        super(SimpleModel, self).__init__()
        self.fc1 = nn.Linear(10, 5)
        self.fc2 = nn.Linear(5, 2)

    def forward(self, x):
        x = self.fc1(x)
        x = torch.relu(x)
        x = self.fc2(x)
        return x

model = SimpleModel()

def forward_hook(module, input, output):
    # Called right after fc1's forward pass; `input` is a tuple of positional inputs
    print(f"Hook: module {module} input {input}, output {output}")

hook_handle = model.fc1.register_forward_hook(forward_hook)
input_data = torch.randn(1, 10)
output = model(input_data)
print(f"Final output: {output}")
hook_handle.remove()
Output:
Hook: module Linear(in_features=10, out_features=5, bias=True) input (tensor([[ 0.7551, 1.3895, 0.4566, 0.5799, -1.2487, 0.1824, 0.8438, 1.0473,
        -0.7047, -0.2592]]),), output tensor([[ 0.2386, -0.1723, 0.2275, 0.0239, 0.1021]],
       grad_fn=<AddmmBackward0>)
Final output: tensor([[ 0.0837, -0.1774]], grad_fn=<AddmmBackward0>)
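Modules also support backward hooks that fire while gradients are being computed; in recent PyTorch versions (1.8 and later, as far as I know) the API for this is register_full_backward_hook. A minimal sketch on a single Linear layer:

import torch
import torch.nn as nn

layer = nn.Linear(10, 5)

def backward_hook(module, grad_input, grad_output):
    # Called after the gradients for this module have been computed
    print(f"grad_output[0] shape: {grad_output[0].shape}")

handle = layer.register_full_backward_hook(backward_hook)
out = layer(torch.randn(1, 10))
out.sum().backward()
handle.remove()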
Asynchronously saving or stopping training
import multiprocessing
import time
import torch

def train_model(signal_queue):
    epoch = 0
    try:
        while True:
            print(f"Training... Epoch: {epoch}")
            time.sleep(1)
            if not signal_queue.empty():
                signal = signal_queue.get()
                if signal == "SAVE":
                    print(f"Saving model at epoch {epoch}")
                    model = {"epoch": epoch}
                    torch.save(model, f"model_epoch_{epoch}.pt")
                    print(f"Model saved at epoch {epoch}")
                elif signal == "STOP":
                    print("Stopping training...")
                    break
            epoch += 1
            if epoch > 20:
                print("Reached maximum epochs.")
                break
    except KeyboardInterrupt:
        print("Training interrupted manually.")
    finally:
        print("Training finished.")

def monitor_and_send_signal(signal_queue):
    try:
        while True:
            command = input("Enter 'save' to save the model or 'stop' to stop training: ").strip()
            if command.lower() == 'save':
                signal_queue.put("SAVE")
            elif command.lower() == 'stop':
                signal_queue.put("STOP")
                break
    except KeyboardInterrupt:
        print("Monitoring interrupted manually.")
    finally:
        print("Monitoring finished.")

def main():
    signal_queue = multiprocessing.Queue()
    training_process = multiprocessing.Process(target=train_model, args=(signal_queue,))
    training_process.start()
    monitor_and_send_signal(signal_queue)
    training_process.join()

if __name__ == "__main__":
    main()
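One caveat: the Python documentation notes that multiprocessing.Queue.empty() is not fully reliable across processes, so a slightly more defensive variant polls with get_nowait() and catches queue.Empty instead. A minimal sketch of that polling helper (the name check_for_signal is mine):

import queue

def check_for_signal(signal_queue):
    # Non-blocking poll: return the next pending signal string, or None if there is none
    try:
        return signal_queue.get_nowait()
    except queue.Empty:
        return None

Inside the training loop this replaces the empty()/get() pair: signal = check_for_signal(signal_queue).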
Asynchronously saving or stopping training with a decorator
import time
import threading
import multiprocessing

def monitor_decorator(func):
    def wrapper(*args, **kwargs):
        signal_queue = kwargs.get('signal_queue')
        if signal_queue is None:
            raise ValueError("signal_queue is required as a keyword argument")
        stop_event = threading.Event()

        def check_signals():
            while not stop_event.is_set():
                if not signal_queue.empty():
                    signal = signal_queue.get()
                    if signal == 'STOP':
                        print("Stopping the monitored function.")
                        stop_event.set()
                    elif signal == 'SAVE':
                        print("Saving current state...")
                time.sleep(0.1)

        signal_thread = threading.Thread(target=check_signals)
        signal_thread.start()
        try:
            result = func(*args, **kwargs, stop_event=stop_event)
        finally:
            stop_event.set()
            signal_thread.join()
        return result
    return wrapper

@monitor_decorator
def long_running_function(*args, stop_event=None, **kwargs):
    epoch = 0
    while not stop_event.is_set():
        print(f"Running epoch {epoch}...")
        time.sleep(1)
        epoch += 1
        if epoch > 50:
            print("Reached maximum epochs.")
            break
    print("Function completed.")

def monitor_input(signal_queue):
    while True:
        command = input("Enter 'save' to save the state or 'stop' to stop execution: ").strip().lower()
        if command == 'save':
            signal_queue.put('SAVE')
        elif command == 'stop':
            signal_queue.put('STOP')
            break

if __name__ == "__main__":
    signal_queue = multiprocessing.Queue()
    process = multiprocessing.Process(target=long_running_function, kwargs={'signal_queue': signal_queue})
    process.start()
    monitor_input(signal_queue)
    process.join()
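A small refinement worth adding to monitor_decorator is functools.wraps, so the wrapped function keeps its original __name__ and docstring; this is standard decorator hygiene and does not change the behavior above. A minimal sketch of just that change:

import functools
import threading

def monitor_decorator(func):
    @functools.wraps(func)  # wrapper now reports func's __name__ and __doc__ instead of "wrapper"
    def wrapper(*args, **kwargs):
        stop_event = threading.Event()
        # ... same signal-checking thread as above ...
        return func(*args, **kwargs, stop_event=stop_event)
    return wrapper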
A simple data loader: SimpleDataLoader
import random
import time

class SimpleDataLoader:
    def __init__(self, data, batch_size=1, shuffle=False, curriculum_learning_enabled=False, post_process_func=None):
        """
        Initialize the data loader.
        :param data: dataset (list, numpy array, etc.)
        :param batch_size: number of samples per batch
        :param shuffle: whether to shuffle the data at the start of each iteration
        :param curriculum_learning_enabled: whether curriculum learning is enabled
        :param post_process_func: optional post-processing function
        """
        self.data = data
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.curriculum_learning_enabled = curriculum_learning_enabled
        self.post_process_func = post_process_func
        self.len = len(data) // batch_size
        self.current_index = 0
        self.data_iterator = None

    def _create_dataloader(self):
        """Create the data iterator; shuffling and curriculum learning decide how it is built."""
        if self.shuffle:
            random.shuffle(self.data)
        self.data_iterator = iter(self.data)

    def __iter__(self):
        """Initialize the iterator."""
        self.current_index = 0
        self._create_dataloader()
        return self
    def __len__