一、继承nn.Module类并自定义层
我们要利用pytorch提供的很多便利的方法,则需要将很多自定义操作封装成nn.Module类。
首先,简单实现一个Mylinear类:
from torch import nn # Mylinear继承Module class Mylinear(nn.Module): # 传入输入维度和输出维度 def __init__(self,in_d,out_d): # 调用父类构造函数 super(Mylinear,self).__init__() # 使用Parameter类将w和b封装,这样可以通过nn.Module直接管理,并提供给优化器优化 self.w = nn.Parameter(torch.randn(out_d,in_d)) self.b = nn.Parameter(torch.randn(out_d)) # 实现forward函数,该函数为默认执行的函数,即计算过程,并将输出返回 def forward(self, x): x = x@self.w.t() + self.b return x
转载自原文链接, 如需删除请联系管理员。
原文链接:pytorch学习笔记(4)(Module类、实现Flatten类、Module类作用、数据增强),转载请注明来源!
相关推荐