# 📅  Last modified: 2022-03-11 14:46:23  🧑  Author: Mango
# Model setup: GPyT is a GPT-2-style causal language model trained on
# Python source code. AutoModelWithLMHead is deprecated in transformers;
# AutoModelForCausalLM is the correct auto class for causal LMs.
from transformers import AutoTokenizer, AutoModelForCausalLM

tokenizer = AutoTokenizer.from_pretrained("Sentdex/GPyT")
# NOTE(review): assumes a CUDA device is available — this raises if not.
model = AutoModelForCausalLM.from_pretrained("Sentdex/GPyT").to("cuda")
def generate(code, max_length=100):
    """Autocomplete Python source text with the GPyT model.

    GPyT was trained with newline characters encoded as the literal
    token "<N>", so newlines are substituted before tokenizing and the
    substitution is reversed after decoding. (The original code used an
    empty string here, which both destroyed the input's newlines and —
    because ``"".join``-style empty-pattern ``str.replace`` matches
    between every character — inserted a newline after every single
    character of the decoded output.)

    Args:
        code: Python source string to continue.
        max_length: maximum total token length of the generated sequence.

    Returns:
        The decoded continuation with real newline characters restored.
    """
    newlinechar = "<N>"  # GPyT's newline surrogate token (per model card)
    converted = code.replace("\n", newlinechar)
    tokenized = tokenizer.encode(converted, return_tensors="pt").to("cuda")
    # generate() returns tensors already on the model's device; the
    # original's trailing .to("cuda") on the output was redundant.
    resp = model.generate(tokenized, max_length=max_length)
    decoded = tokenizer.decode(resp[0])
    return decoded.replace(newlinechar, "\n")
if __name__ == "__main__":
    # Smoke test: ask the model to continue an "import" statement.
    print(generate("import"))