1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88
from os.path import exists

import numpy as np
import pandas as pd
import torch
from torch import nn
from torch.utils.data import DataLoader, Dataset

# Mini-batch size handed to every DataLoader built in this script.
batch_size = 100
# Width of the one-hot rows produced by mydata.__getitem__.
word_cnt = 29

# Full table of samples; mydata filters it by the `part` column and reads
# the `x` (comma-separated ids) and `y` (label) columns.
src = pd.read_csv("./data/data.csv")
class mydata(Dataset):
    """Dataset over the rows of ``src`` whose ``part`` column equals ``typ``.

    Each item is a pair ``(features, label)``:

    * ``features`` -- a ``(15, word_cnt)`` float tensor with one one-hot row
      per id found in the comma-separated ``x`` column,
    * ``label`` -- the ``y`` column converted to ``int``.
    """

    def __init__(self, typ):
        self.data = src[src.part == typ]
        self.typ = typ

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        row = self.data.iloc[idx]
        ids = [int(token) for token in row['x'].split(',')]

        onehot = np.zeros((15, word_cnt))
        # Only the first 15 ids are encoded; a shorter sequence leaves the
        # remaining rows all-zero.
        for pos, word_id in enumerate(ids[:15]):
            # Id k marks column k - 1.
            onehot[pos, word_id - 1] = 1

        return torch.FloatTensor(onehot), int(row['y'])


train_set = mydata("train")
train_load = DataLoader(train_set, batch_size=batch_size, shuffle=True)
# Rows with part == "test"; train_model passes this loader to test_accuracy.
test_set = mydata("test")
test_load = DataLoader(test_set, batch_size=batch_size, shuffle=True)

# Rows with part == "val"; scored by the final test_accuracy call of the script.
val_set = mydata("val")
val_load = DataLoader(val_set, batch_size=batch_size, shuffle=True)
class Mol(nn.Module):
    """Three 1-D convolutions with ELU activations followed by a linear head.

    The first convolution takes 15 input channels and the head emits 18
    values per sample.  ``forward`` squeezes dimension 2 before the linear
    layer, so the convolution stack has to reduce the length axis to 1; with
    length-29 inputs (kernel 5, strides 2, 2, 1) it shrinks 29 -> 13 -> 5 -> 1.
    """

    def __init__(self):
        super().__init__()
        self.h = 50
        width = self.h

        # Layers are created in the same order as they are applied, which
        # keeps the Sequential indices (0, 2, 4 for the convolutions) stable.
        conv_stack = [
            nn.Conv1d(in_channels=15, out_channels=width, kernel_size=5, stride=2),
            nn.ELU(),
            nn.Conv1d(in_channels=width, out_channels=width, kernel_size=5, stride=2),
            nn.ELU(),
            nn.Conv1d(in_channels=width, out_channels=width, kernel_size=5, stride=1),
            nn.ELU(),
        ]
        self.mol = nn.Sequential(*conv_stack)
        self.lin = nn.Linear(in_features=width, out_features=18)

    def forward(self, x):
        """Return the linear head's output for the batch ``x``."""
        features = self.mol(x)
        # Drop the length axis (only removed when its size is 1).
        features = features.squeeze(dim=2)
        return self.lin(features)


mynn = Mol()
def test_accuracy(data_load):
    """Print the accuracy of ``mynn`` on ``data_load`` and return it.

    Args:
        data_load: iterable yielding ``(inputs, labels)`` batches.

    Returns:
        The fraction of samples whose arg-max prediction equals the label,
        or ``nan`` when ``data_load`` yields no samples (previously this
        case raised ``ZeroDivisionError``).
    """
    correct = 0
    total = 0
    with torch.no_grad():
        for sen, tag in data_load:
            predicted = mynn(sen).argmax(dim=1)
            labels = torch.as_tensor(tag)
            # Count matches for the whole batch at once instead of looping
            # over individual samples.
            correct += int((predicted == labels).sum())
            total += int(predicted.shape[0])

    accuracy = correct / total if total else float("nan")
    # Message text means "accuracy is ...".
    print("准确率为{:f}".format(accuracy))
    return accuracy


def train_model():
    """Train ``mynn`` on ``train_load`` for 30 epochs and save its weights.

    Uses cross-entropy loss with Adam (lr=1e-3).  The loss is printed every
    10 optimisation steps.  After each epoch the model is scored on
    ``test_load`` and checkpointed to ``./model/epoch_<n>.pth``; the final
    weights are written to ``./model/final.pth``.
    """
    # Local import: the file only imports ``exists`` from os.path.
    import os

    # torch.save does not create missing parent directories, so make sure
    # the checkpoint folder exists before any training time is spent.
    os.makedirs("./model", exist_ok=True)

    train_step = 0
    loss_fn = nn.CrossEntropyLoss()
    optim = torch.optim.Adam(mynn.parameters(), lr=1e-3)

    for epoch in range(30):
        # Message text means "epoch: ...".
        print("批次:{}".format(epoch))
        for sen, tag in train_load:
            optim.zero_grad()
            output = mynn(sen)
            res_loss = loss_fn(output, tag)
            res_loss.backward()
            optim.step()

            train_step += 1
            if train_step % 10 == 0:
                # Message text means "training steps: ..., loss: ...".
                print("训练次数:{},loss:{}".format(train_step, res_loss))

        # NOTE(review): the original indentation was lost; per-epoch
        # evaluation and checkpointing is the reconstructed nesting -- confirm.
        test_accuracy(test_load)
        torch.save(mynn.state_dict(), "./model/epoch_{}.pth".format(epoch))

    torch.save(mynn.state_dict(), "./model/final.pth")


# Train only when no final checkpoint is present; otherwise reuse it.
if not exists("./model/final.pth"):
    train_model()
else:
    mynn.load_state_dict(torch.load("./model/final.pth"))
# NOTE(review): reconstructed as running after both branches -- confirm.
test_accuracy(val_load)
|