首页 » 技术分享 » pytorch学习笔记(4)(Module类、实现Flatten类、Module类作用、数据增强)

pytorch学习笔记(4)(Module类、实现Flatten类、Module类作用、数据增强)

 

一、继承nn.Module类并自定义层

我们要利用pytorch提供的很多便利的方法,则需要将很多自定义操作封装成nn.Module类。

首先,简单实现一个Mylinear类:

from torch import nn

# Mylinear继承Module
class Mylinear(nn.Module):
    # 传入输入维度和输出维度
    def __init__(self,in_d,out_d):
        # 调用父类构造函数
        super(Mylinear,self).__init__()
        # 使用Parameter类将w和b封装,这样可以通过nn.Module直接管理,并提供给优化器优化
        self.w = nn.Parameter(torch.randn(out_d,in_d))
        self.b = nn.Parameter(torch.randn(out_d))

    # 实现forward函数,该函数为默认执行的函数,即计算过程,并将输出返回
    def forward(self, x):
        x = x@self.w.t() + self.b
        return x

转载自原文链接, 如需删除请联系管理员。

原文链接:pytorch学习笔记(4)(Module类、实现Flatten类、Module类作用、数据增强),转载请注明来源!

0