# 📅 Last modified: 2022-03-11 14:45:18.363000    🧑 Author: Mango
import torch
import torch.nn as nn
class Model(nn.Module):
    """Linear regression written by hand from raw parameters plus a fixed offset.

    Computes ``x @ weight + bias + a_constant_tensor``. ``weight`` and ``bias``
    are trainable ``nn.Parameter`` objects initialised to zero.
    ``a_constant_tensor`` is a registered buffer holding ``0.5``. It is not
    returned by ``parameters()``, so an optimizer never updates it. It is still
    included in ``state_dict()`` and moves with the module on
    ``.to()`` / ``.cuda()``.

    Args:
        in_features: Size of the last dimension of the input. Defaults to 2.
        out_features: Size of the last dimension of the output. Defaults to 1.
    """

    def __init__(self, in_features=2, out_features=1):
        super().__init__()
        self.weight = nn.Parameter(torch.zeros(in_features, out_features))
        self.bias = nn.Parameter(torch.zeros(out_features))
        # Registered as a buffer rather than a plain attribute, so it follows
        # the module across devices and is serialised, but is not trained.
        self.register_buffer("a_constant_tensor", torch.tensor([0.5]))

    def forward(self, x):
        """Apply the affine map and add the constant offset.

        Args:
            x: Tensor of shape ``(..., in_features)``. A 1-D tensor of length
                ``in_features`` is also accepted.

        Returns:
            Tensor of shape ``(..., out_features)``.
        """
        # ``@`` (torch.matmul) handles 1-D and batched inputs, unlike
        # torch.mm, which only accepts 2-D matrices. For 2-D inputs the
        # result is the same.
        return x @ self.weight + self.bias + self.a_constant_tensor
# Build the model on the GPU when one is available, and fall back to the CPU
# otherwise: a bare ``.cuda()`` raises on machines without CUDA support.
model = Model().to("cuda" if torch.cuda.is_available() else "cpu")