Advanced Computing Platform for Theoretical Physics

Commit 67a02f07 authored by Pengfei Zhou
Browse files

main

parent c2e3ae76
gru_mpo @ 22c85b41
Subproject commit 22c85b412e7be2a152c2f7347d52e02cdbec0155
import torch as t
import numpy as np
from torch.utils.data import DataLoader
from torch import optim
from torch import nn
from model import *
from torchnet import meter
import tqdm
from config import *
from test import *
import sys
import time
def train():
    """Train the poetry model on ``tang.npz`` and checkpoint it every epoch.

    Reads hyper-parameters from ``Config`` (``use_gpu``, ``batch_size``,
    ``embedding_dim``, ``hidden_dim``, ``mpo``, ``lr``, ``epoch``,
    ``plot_every``, ``model_prefix``) and stores the selected device back on
    ``Config.device``.

    Side effects:
        * prints dataset diagnostics, the running loss, generated samples and
          the per-epoch wall time to stdout;
        * writes the running loss and generated samples to ``result.txt``;
        * saves ``model.state_dict()`` to
          ``<model_prefix>_<mpo>_<epoch>.pth`` after each epoch.
    """
    Config.device = t.device("cuda") if Config.use_gpu else t.device("cpu")
    device = Config.device

    # NOTE(review): allow_pickle=True unpickles the object arrays stored in the
    # archive; only load a tang.npz that comes from a trusted source.
    datas = np.load("tang.npz", allow_pickle=True)
    print(datas)
    data = datas['data']
    print(data)
    print(data.shape[0])
    print(np.max(data))
    # The vocab mappings are stored as 0-d object arrays; .item() unwraps them.
    ix2word = datas['ix2word'].item()
    word2ix = datas['word2ix'].item()
    print(ix2word[data[1][0]])

    data = t.from_numpy(data)
    dataloader = DataLoader(data,
                            batch_size=Config.batch_size,
                            shuffle=True,
                            num_workers=2)

    # NOTE(review): 8400 is a hard-coded first argument, presumably the
    # vocabulary size; confirm it matches what PoetryModel/generate expect.
    model = PoetryModel(8400,
                        embedding_dim=Config.embedding_dim,
                        hidden_dim=Config.hidden_dim,
                        mpo=Config.mpo)
    optimizer = optim.Adam(model.parameters(), lr=Config.lr)
    criterion = nn.CrossEntropyLoss()
    model.to(device)
    loss_meter = meter.AverageValueMeter()

    # The context manager guarantees result.txt is closed even if training
    # raises; explicit UTF-8 because the log contains Chinese text and the
    # platform default encoding may not be able to represent it.
    with open('result.txt', 'w', encoding='utf-8') as f:
        for epoch in range(Config.epoch):
            time0 = time.time()
            loss_meter.reset()
            for li, data_ in enumerate(dataloader):
                # Swap the first two axes before feeding the model
                # (presumably batch-first -> sequence-first; TODO confirm).
                data_ = data_.long().transpose(1, 0).contiguous()
                data_ = data_.to(device)
                optimizer.zero_grad()
                # Next-token objective: the target is the input shifted by one
                # position along the first axis.
                input_, target = data_[:-1, :], data_[1:, :]
                output, _ = model(input_)
                # The target must be flattened with view(-1) to line up with
                # the model output passed to the loss.
                loss = criterion(output, target.view(-1))
                loss.backward()
                optimizer.step()
                loss_meter.add(loss.item())

                # Periodic progress report: running loss plus sample poems.
                if (1 + li) % Config.plot_every == 0:
                    print("训练损失为%s" % (str(loss_meter.mean)))
                    f.write("训练损失为%s" % (str(loss_meter.mean)))
                    for word in list(u"春江花朝秋月夜"):
                        gen_poetry = ''.join(
                            generate(model, word, ix2word, word2ix))
                        print(gen_poetry)
                        f.write(gen_poetry)
                        f.write("\n\n\n")
                    # Flush so partial results survive an interrupted run.
                    f.flush()

            t.save(model.state_dict(),
                   '%s_%s_%s.pth' % (Config.model_prefix, Config.mpo, epoch))
            print('used time: ', time.time() - time0)
# Script entry point: start training only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    train()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment