简单的神经网络回归模型
定义一个简单的回归模型,优化器选择了Adam。
·
一个完整的回归模型训练流程(使用均方误差(MSE)作为损失函数)。
从数据准备到模型训练,再参数更新。
在实际应用中,可能需要根据具体任务调整网络结构、损失函数、优化器等。
代码如下:
- 导入PyTorch库的必要模块。
torch
是PyTorch的核心库,torch.nn
用于构建神经网络,torch.optim
用于优化网络参数,DataLoader
和TensorDataset
用于数据加载和批处理
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
- 定义一个简单的回归模型。 这里定义了一个名为
Regressor
的类,它继承自nn.Module
。这个类构建了一个简单的前馈神经网络,用于回归任务。网络由三个全连接层组成,第一层将输入从input_size
维映射到128维,第二层将128维映射到64维,最后一层将64维映射到output_size
维(对于回归任务,通常是1维)。每一层后面都跟着一个ReLU激活函数,除了最后一层。
# 定义一个简单的回归模型
class Regressor(nn.Module):
def __init__(self, input_size, output_size):
super(Regressor, self).__init__()
self.model = nn.Sequential(
nn.Linear(input_size, 128),
nn.ReLU(),
nn.Linear(128, 64),
nn.ReLU(),
nn.Linear(64, output_size)
)
#forward 函数定义了数据通过网络的前向传播路径
def forward(self, x):
return self.model(x)
# 参数设置
input_size = 86 # 特征数量
output_size = 1 # 回归任务的标签数量
# 创建回归模型
regressor = Regressor(input_size, output_size)
- 定义了损失函数和优化器。对于回归任务,通常使用均方误差损失(
MSELoss
)。优化器选择了Adam,这是一种自适应学习率的优化算法,适用于大多数情况。lr=0.001
设置了学习率。
# 损失函数和优化器
criterion = nn.MSELoss() # 均方误差损失
optimizer = optim.Adam(regressor.parameters(), lr=0.001)
- 创建了模拟数据,使用
TensorDataset
将数据和标签封装成一个数据集。
# 假设 data 是一个包含特征的张量,labels 是一个包含连续标签值的张量
data = torch.randn(1000, input_size) # 特征数据
labels = torch.randn(1000, output_size) # 连续标签
# 创建数据加载器
dataset = TensorDataset(data, labels)
dataloader = DataLoader(dataset, batch_size=32, shuffle=True)
训练循环。它遍历指定的epoch次数,每个epoch中,数据被分批次加载,然后执行以下步骤:
optimizer.zero_grad()
清除之前的梯度。outputs = regressor(features)
计算模型的输出。loss = criterion(outputs, targets)
计算损失。loss.backward()
执行反向传播,计算梯度。optimizer.step()
更新模型参数。
每10个批次打印一次损失值,以监控训练进度。
# 训练模型
epochs = 100
for epoch in range(epochs):
for i, (features, targets) in enumerate(dataloader):
optimizer.zero_grad()
outputs = regressor(features)
loss = criterion(outputs, targets)
loss.backward()
optimizer.step()
if (i + 1) % 10 == 0:
print(f'Epoch [{epoch + 1}/{epochs}], Step [{i + 1}/{len(dataloader)}], Loss: {loss.item():.4f}')
print("Training Complete.")
更多推荐
所有评论(0)