📅  最后修改于: 2020-11-11 00:39:18             🧑  作者: Mango
还有另一种查找预测的方法。在上一节中,我们使用forward()和实现线性模型来找到预测。此方法非常有效且可靠。很容易理解和实施。
在自定义模块中,我们使用类创建一个自定义模块,它是init()和forward()方法和模型。 init()方法用于初始化该类的新实例。在此init()方法中,第一个参数是self,它指示该类的实例(该对象尚未初始化),而在其自身之后,我们可以添加其他参数。
下一个参数是初始化线性模型的实例。在上一节中,初始化线性模型需要输入大小以及输出大小等于1,但是在自定义模块中,我们传递输入大小和输出大小变量而不传递其默认值。
在这种情况下,需要导入割炬的nn包。在此,我们使用继承,以便此子类将利用我们的基类和模块中的代码。
模块本身通常将充当所有神经网络模块的基类。之后,我们创建一个模型以进行预测。
让我们看一个示例,该示例如何通过创建自定义模块来完成预测:
对于单个数据
import torch
import torch.nn as nn
class LinearRegression(nn.Module):
def __init__(self,input_size, output_size):
super().__init__()
self.linear=nn.Linear(input_size,output_size)
def forward(self,x):
pred=self.linear(x)
return pred
torch.manual_seed(1)
model=LinearRegression(1,1)
x=torch.tensor([1.0])
print(model.forward(x))
输出:
tensor([0.0739], grad_fn=)
对于多个数据
import torch
import torch.nn as nn
class LinearRegression(nn.Module):
def __init__(self,input_size, output_size):
super().__init__()
self.linear=nn.Linear(input_size,output_size)
def forward(self,x):
pred=self.linear(x)
return pred
torch.manual_seed(1)
model=LinearRegression(1,1)
x=torch.tensor([[1.0],[2.0],[3.0]])
print(model.forward(x))
输出:
tensor([[0.0739],
[0.5891],
[1.1044]], grad_fn=)