DGL is a library that makes it easy to build graph neural networks.
In this post we will take a look at its GraphConv layer.
https://docs.dgl.ai/en/0.4.x/api/python/nn.pytorch.html?highlight=graphconv#graphconv
GraphConv implements the method proposed in the paper 'Semi-Supervised Classification with Graph Convolutional Networks', linked below.
Since the paper is already explained well on other blogs, I will skip a detailed walkthrough here (I plan to cover it in more depth when I have time).
In short, each node updates its own representation by aggregating the features of its neighboring nodes.
https://arxiv.org/abs/1609.02907
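For reference, the layer-wise propagation rule from the paper (which GraphConv follows, up to minor normalization options) can be written as

$$H^{(l+1)} = \sigma\!\left(\tilde{D}^{-1/2}\,\tilde{A}\,\tilde{D}^{-1/2}\,H^{(l)}\,W^{(l)}\right), \qquad \tilde{A} = A + I_N$$

where $A$ is the adjacency matrix, $\tilde{D}$ is the degree matrix of $\tilde{A}$, $H^{(l)}$ are the node features at layer $l$, $W^{(l)}$ is a learnable weight matrix, and $\sigma$ is an activation such as ReLU.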
Looking at DGL's GraphConv documentation (linked above), the parameters that must always be provided are:
in_feats : dimension of the input features (int)
out_feats : dimension of the output features (int)
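As a quick sanity check, here is a minimal sketch of GraphConv on a tiny hand-built graph (the 3-node ring and the feature sizes are arbitrary examples chosen for illustration; it follows the 0.4.x-style DGLGraph API used throughout this post):
import torch
import dgl
from dgl.nn.pytorch import GraphConv
g = dgl.DGLGraph() # empty graph
g.add_nodes(3) # three nodes
g.add_edges([0, 1, 2], [1, 2, 0]) # directed ring 0->1->2->0, so every node has an in-neighbor
feat = torch.randn(3, 74) # node features: 3 nodes x in_feats=74
conv = GraphConv(in_feats=74, out_feats=128)
h = conv(g, feat) # aggregate neighbor features and project to out_feats
print(h.shape) # torch.Size([3, 128])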
Now let's run GraphConv on actual compounds!
The code below predicts lipophilicity values.
In [1]:
#Predict lipophilicity with regression
In [2]:
import torch.nn as nn
import torch
import numpy as np
import warnings
if torch.cuda.is_available():
    print('use GPU')
    device = 'cuda'
else:
    print('use CPU')
    device = 'cpu'
#libraries for model
import dgl
from dgl.nn.pytorch import GraphConv
from dgl.nn.pytorch.glob import MaxPooling
import dgl.backend as F
#libraries for data
from dgl.data.chem.utils import smiles_to_bigraph
from dgl.data.chem import CanonicalAtomFeaturizer #make atom feature
from torch.utils.data import DataLoader
import dgllife.data.lipophilicity #lipophilicity dataset of MoleculeNet, detail explanation : https://lifesci.dgl.ai/api/data.html?highlight=lipophilicity#dgllife.data.Lipophilicity
import dgl.data.chem.utils.splitters #data splitter
use GPU
Using backend: pytorch
In [3]:
def warn(*args, **kwargs): #ignore warnings
    pass
warnings.warn = warn
def mse(y,f): #measure : mean squared error
    mse = ((y - f)**2).mean(axis=0)
    return mse
def collate(sample): # make batch data for input of model
    smi, graphs, labels = map(list, zip(*sample))
    batched_graph = dgl.batch(graphs)
    return batched_graph, torch.tensor(labels)
Make Graph Data
In [4]:
atom_featurizer = CanonicalAtomFeaturizer() #atom feature : one hot encoding of atom type, atom degree, number of implicit Hs on the atom, formal charge, number of radical electrons, ...
#detail explanation : https://docs.dgl.ai/en/0.4.x/generated/dgl.data.chem.CanonicalAtomFeaturizer.html?highlight=canonicalatomfeaturizer
n_feats = atom_featurizer.feat_size('h')
dataset = dgllife.data.lipophilicity.Lipophilicity(smiles_to_bigraph,atom_featurizer) #download lipophilicity data from dgllife with specified atom featurizer
Processing dgl graphs from scratch... Processing molecule 1000/4200 Processing molecule 2000/4200 Processing molecule 3000/4200 Processing molecule 4000/4200
In [5]:
print('the number of lipophilicity dataset : ', len(dataset))
print('first data : ', dataset[0])
the number of lipophilicity dataset : 4200 first data : ('Cn1c(CN2CCN(CC2)c3ccc(Cl)cc3)nc4ccccc14', DGLGraph(num_nodes=24, num_edges=54, ndata_schemes={'h': Scheme(shape=(74,), dtype=torch.float32)} edata_schemes={}), tensor([3.5400]))
In [6]:
splitter = dgl.data.chem.utils.splitters.RandomSplitter() #random data splitter initialize
split_data = splitter.train_val_test_split(dataset, frac_train = 0.8, frac_val=0.1, frac_test=0.1) #data split train:val:test = 0.8:0.1:0.1
train_loader = DataLoader(split_data[0], batch_size=32, shuffle=True, collate_fn=collate, drop_last=False) #train dataset
val_loader = DataLoader(split_data[1], batch_size=32, shuffle=False, collate_fn=collate, drop_last=False) #validation dataset
test_loader = DataLoader(split_data[2], batch_size=32, shuffle=False, collate_fn=collate, drop_last=False) #test dataset
print('the number of train data : ',len(split_data[0]))
print('the number of validation data : ',len(split_data[1]))
print('the number of test data : ',len(split_data[2]))
the number of train data : 3360 the number of validation data : 420 the number of test data : 420
GCN Model
In [7]:
#3 layer GCN + 1 fully connected layer model
class GCN(nn.Module):
    def __init__(self, num_features_xd=74, embed_dim=128, dropout=0.1):
        super(GCN, self).__init__()
        self.conv1 = GraphConv(num_features_xd, embed_dim)
        self.conv2 = GraphConv(embed_dim, embed_dim*2)
        self.conv3 = GraphConv(embed_dim*2, embed_dim)
        self.max_pooling_readout = MaxPooling() #readout function
        self.relu = nn.ReLU()
        self.out = nn.Linear(embed_dim, 1)

    def forward(self, graph, atom_feats):
        batch_num_objs = graph.batch_num_nodes #list of the number of nodes in each compound
        h = self.conv1(graph, atom_feats)
        h = self.relu(h)
        h = self.conv2(graph, h)
        h = self.relu(h)
        h = self.conv3(graph, h)
        h_atoms = self.relu(h) #atom representations
        h_graph = self.max_pooling_readout(graph, h_atoms) #graph representation : max pooling over the atom representations
        out = self.out(h_graph) #predict lipophilicity
        return out, F.pad_packed_tensor(h_atoms, batch_num_objs, 0), h_graph #prediction, atom representations, graph representation
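As a quick sanity check, we can push two arbitrary example SMILES through the model and look at the output shapes (a sketch reusing atom_featurizer and smiles_to_bigraph from the cells above; 'CCO' and 'c1ccccc1' are just illustrative molecules):
g1 = smiles_to_bigraph('CCO', node_featurizer=atom_featurizer) # ethanol, arbitrary example
g2 = smiles_to_bigraph('c1ccccc1', node_featurizer=atom_featurizer) # benzene, arbitrary example
bg = dgl.batch([g1, g2]) # one batched graph containing both compounds
toy_model = GCN()
pred, atom_repr, graph_repr = toy_model(bg, bg.ndata['h'])
print(pred.shape) # torch.Size([2, 1]) : one lipophilicity prediction per compound
print(graph_repr.shape) # torch.Size([2, 128]) : one graph embedding per compound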
Train and Validation
In [8]:
model = GCN() #model initialize
model = model.to(device) #load model to GPU
loss_fn = nn.MSELoss() #MSE loss initialize
optimizer = torch.optim.Adam(model.parameters(), lr=0.0005) #optimizer, learning rate initialize
best_mse = 1000
best_epoch = -1
counter = 0 # early stopping counter
avg_train_losses = [] # track the train loss per epoch
avg_val_losses = [] # track the validation loss per epoch
for epoch in range(200):
    #################train########################################
    model.train()
    train_losses = []
    for i, data in enumerate(train_loader):
        drugs, labels = data
        labels = labels.to(device)
        atom_feats = drugs.ndata['h'].to(device)
        optimizer.zero_grad()
        output, atoms, compound = model(drugs, atom_feats)
        loss = loss_fn(output, labels.view(-1,1).float().to(device))
        loss.backward()
        optimizer.step()
        train_losses.append(loss.item())
    #################validation#####################################
    model.eval()
    total_preds = torch.Tensor()
    total_labels = torch.Tensor()
    with torch.no_grad():
        for data in val_loader:
            drugs, labels = data
            labels = labels.to(device)
            atom_feats = drugs.ndata['h'].to(device)
            output, atoms, compound = model(drugs, atom_feats)
            total_preds = torch.cat((total_preds, output.cpu()), 0)
            total_labels = torch.cat((total_labels, labels.view(-1,1).cpu()), 0)
    G, P = total_labels.numpy().flatten(), total_preds.numpy().flatten()
    ##################################################################
    train_loss = np.average(train_losses)
    val_loss = mse(G, P)
    avg_train_losses.append(train_loss)
    avg_val_losses.append(val_loss)
    print_msg = (f'[{epoch}/{200}] ' +
                 f'train_loss: {train_loss:.5f} ' +
                 f'val_loss: {val_loss:.5f}')
    print(print_msg)
    if val_loss < best_mse:
        counter = 0
        print('best mse renew : ' + str(best_mse) + ' --> ' + str(val_loss))
        best_mse = val_loss
        best_epoch = epoch
        state = {'Epoch':epoch, 'State_dict':model.state_dict(), 'optimizer':optimizer.state_dict(), 'loss':val_loss}
        torch.save(state, 'model.pt') ###### model path
    else:
        counter = counter + 1
        print('Early Stopping counter : ', str(counter))
        if counter > 30:
            break
[0/200] train_loss: 1.95434 val_loss: 1.51559 best mse renew : 1000 --> 1.5155939 [1/200] train_loss: 1.36955 val_loss: 1.46722 best mse renew : 1.5155939 --> 1.4672236 [2/200] train_loss: 1.33241 val_loss: 1.42904 best mse renew : 1.4672236 --> 1.4290434 [3/200] train_loss: 1.29858 val_loss: 1.42048 best mse renew : 1.4290434 --> 1.4204757 [4/200] train_loss: 1.25950 val_loss: 1.33802 best mse renew : 1.4204757 --> 1.3380173 [5/200] train_loss: 1.12362 val_loss: 1.18202 best mse renew : 1.3380173 --> 1.1820222 [6/200] train_loss: 1.05155 val_loss: 1.10567 best mse renew : 1.1820222 --> 1.105675 [7/200] train_loss: 0.99904 val_loss: 1.08157 best mse renew : 1.105675 --> 1.0815665 [8/200] train_loss: 0.93803 val_loss: 0.93699 best mse renew : 1.0815665 --> 0.93698573 [9/200] train_loss: 0.85106 val_loss: 0.93227 best mse renew : 0.93698573 --> 0.9322726 [10/200] train_loss: 0.83140 val_loss: 0.82554 best mse renew : 0.9322726 --> 0.82553697 [11/200] train_loss: 0.81292 val_loss: 0.79809 best mse renew : 0.82553697 --> 0.7980915 [12/200] train_loss: 0.79233 val_loss: 0.95146 Early Stopping counter : 1 [13/200] train_loss: 0.74127 val_loss: 0.80331 Early Stopping counter : 2 [14/200] train_loss: 0.73618 val_loss: 0.84577 Early Stopping counter : 3 [15/200] train_loss: 0.72450 val_loss: 0.74624 best mse renew : 0.7980915 --> 0.7462437 [16/200] train_loss: 0.70383 val_loss: 0.77018 Early Stopping counter : 1 [17/200] train_loss: 0.68809 val_loss: 0.73231 best mse renew : 0.7462437 --> 0.7323067 [18/200] train_loss: 0.68257 val_loss: 0.73202 best mse renew : 0.7323067 --> 0.7320163 [19/200] train_loss: 0.66296 val_loss: 0.74296 Early Stopping counter : 1 [20/200] train_loss: 0.66372 val_loss: 0.74701 Early Stopping counter : 2 [21/200] train_loss: 0.65405 val_loss: 0.71536 best mse renew : 0.7320163 --> 0.71535736 [22/200] train_loss: 0.64247 val_loss: 0.72089 Early Stopping counter : 1 [23/200] train_loss: 0.60881 val_loss: 0.72555 Early Stopping counter : 2 [24/200] train_loss: 0.61465 val_loss: 0.70507 best mse renew : 0.71535736 --> 0.7050665 [25/200] train_loss: 0.58813 val_loss: 0.67148 best mse renew : 0.7050665 --> 0.6714805 [26/200] train_loss: 0.59718 val_loss: 0.77253 Early Stopping counter : 1 [27/200] train_loss: 0.57969 val_loss: 0.68160 Early Stopping counter : 2 [28/200] train_loss: 0.57552 val_loss: 0.64846 best mse renew : 0.6714805 --> 0.64845574 [29/200] train_loss: 0.56178 val_loss: 0.65439 Early Stopping counter : 1 [30/200] train_loss: 0.56844 val_loss: 0.69861 Early Stopping counter : 2 [31/200] train_loss: 0.56114 val_loss: 0.62967 best mse renew : 0.64845574 --> 0.62967014 [32/200] train_loss: 0.56019 val_loss: 0.66053 Early Stopping counter : 1 [33/200] train_loss: 0.55294 val_loss: 0.64441 Early Stopping counter : 2 [34/200] train_loss: 0.52632 val_loss: 0.64154 Early Stopping counter : 3 [35/200] train_loss: 0.53562 val_loss: 0.63579 Early Stopping counter : 4 [36/200] train_loss: 0.50573 val_loss: 0.65004 Early Stopping counter : 5 [37/200] train_loss: 0.50307 val_loss: 0.62107 best mse renew : 0.62967014 --> 0.6210739 [38/200] train_loss: 0.49255 val_loss: 0.65657 Early Stopping counter : 1 [39/200] train_loss: 0.49819 val_loss: 0.60692 best mse renew : 0.6210739 --> 0.6069179 [40/200] train_loss: 0.51618 val_loss: 0.60522 best mse renew : 0.6069179 --> 0.6052168 [41/200] train_loss: 0.50862 val_loss: 0.62639 Early Stopping counter : 1 [42/200] train_loss: 0.49511 val_loss: 0.62095 Early Stopping counter : 2 [43/200] train_loss: 0.47859 val_loss: 0.59625 best mse 
renew : 0.6052168 --> 0.5962458 [44/200] train_loss: 0.47655 val_loss: 0.60315 Early Stopping counter : 1 [45/200] train_loss: 0.46269 val_loss: 0.61914 Early Stopping counter : 2 [46/200] train_loss: 0.45547 val_loss: 0.63928 Early Stopping counter : 3 [47/200] train_loss: 0.49106 val_loss: 0.62706 Early Stopping counter : 4 [48/200] train_loss: 0.45177 val_loss: 0.60714 Early Stopping counter : 5 [49/200] train_loss: 0.46692 val_loss: 0.59806 Early Stopping counter : 6 [50/200] train_loss: 0.44425 val_loss: 0.59351 best mse renew : 0.5962458 --> 0.59351224 [51/200] train_loss: 0.44528 val_loss: 0.62585 Early Stopping counter : 1 [52/200] train_loss: 0.44690 val_loss: 0.63859 Early Stopping counter : 2 [53/200] train_loss: 0.44482 val_loss: 0.59164 best mse renew : 0.59351224 --> 0.59163505 [54/200] train_loss: 0.42760 val_loss: 0.58326 best mse renew : 0.59163505 --> 0.58325666 [55/200] train_loss: 0.43398 val_loss: 0.59338 Early Stopping counter : 1 [56/200] train_loss: 0.42194 val_loss: 0.61062 Early Stopping counter : 2 [57/200] train_loss: 0.41605 val_loss: 0.64674 Early Stopping counter : 3 [58/200] train_loss: 0.42756 val_loss: 0.61521 Early Stopping counter : 4 [59/200] train_loss: 0.42904 val_loss: 0.56738 best mse renew : 0.58325666 --> 0.56738055 [60/200] train_loss: 0.42589 val_loss: 0.65219 Early Stopping counter : 1 [61/200] train_loss: 0.42727 val_loss: 0.56402 best mse renew : 0.56738055 --> 0.56401503 [62/200] train_loss: 0.41769 val_loss: 0.58184 Early Stopping counter : 1 [63/200] train_loss: 0.39143 val_loss: 0.58395 Early Stopping counter : 2 [64/200] train_loss: 0.43507 val_loss: 0.67928 Early Stopping counter : 3 [65/200] train_loss: 0.40078 val_loss: 0.72767 Early Stopping counter : 4 [66/200] train_loss: 0.39728 val_loss: 0.55930 best mse renew : 0.56401503 --> 0.5593039 [67/200] train_loss: 0.39248 val_loss: 0.58432 Early Stopping counter : 1 [68/200] train_loss: 0.38073 val_loss: 0.56519 Early Stopping counter : 2 [69/200] train_loss: 0.38924 val_loss: 0.61464 Early Stopping counter : 3 [70/200] train_loss: 0.40178 val_loss: 0.58014 Early Stopping counter : 4 [71/200] train_loss: 0.38269 val_loss: 0.62222 Early Stopping counter : 5 [72/200] train_loss: 0.39362 val_loss: 0.56200 Early Stopping counter : 6 [73/200] train_loss: 0.36554 val_loss: 0.56383 Early Stopping counter : 7 [74/200] train_loss: 0.36350 val_loss: 0.55055 best mse renew : 0.5593039 --> 0.5505514 [75/200] train_loss: 0.36439 val_loss: 0.57905 Early Stopping counter : 1 [76/200] train_loss: 0.35732 val_loss: 0.56242 Early Stopping counter : 2 [77/200] train_loss: 0.36075 val_loss: 0.57157 Early Stopping counter : 3 [78/200] train_loss: 0.35178 val_loss: 0.54656 best mse renew : 0.5505514 --> 0.54656196 [79/200] train_loss: 0.35065 val_loss: 0.55964 Early Stopping counter : 1 [80/200] train_loss: 0.34390 val_loss: 0.55837 Early Stopping counter : 2 [81/200] train_loss: 0.35124 val_loss: 0.66200 Early Stopping counter : 3 [82/200] train_loss: 0.36465 val_loss: 0.54476 best mse renew : 0.54656196 --> 0.5447594 [83/200] train_loss: 0.35356 val_loss: 0.55710 Early Stopping counter : 1 [84/200] train_loss: 0.33278 val_loss: 0.55100 Early Stopping counter : 2 [85/200] train_loss: 0.34809 val_loss: 0.56183 Early Stopping counter : 3 [86/200] train_loss: 0.34225 val_loss: 0.58601 Early Stopping counter : 4 [87/200] train_loss: 0.34343 val_loss: 0.56375 Early Stopping counter : 5 [88/200] train_loss: 0.32546 val_loss: 0.55512 Early Stopping counter : 6 [89/200] train_loss: 0.34227 val_loss: 0.55642 Early 
Stopping counter : 7 [90/200] train_loss: 0.33023 val_loss: 0.54111 best mse renew : 0.5447594 --> 0.54110926 [91/200] train_loss: 0.33342 val_loss: 0.54837 Early Stopping counter : 1 [92/200] train_loss: 0.31676 val_loss: 0.63164 Early Stopping counter : 2 [93/200] train_loss: 0.32978 val_loss: 0.53880 best mse renew : 0.54110926 --> 0.5387994 [94/200] train_loss: 0.31332 val_loss: 0.57774 Early Stopping counter : 1 [95/200] train_loss: 0.32147 val_loss: 0.53801 best mse renew : 0.5387994 --> 0.53801316 [96/200] train_loss: 0.31640 val_loss: 0.53207 best mse renew : 0.53801316 --> 0.532071 [97/200] train_loss: 0.31953 val_loss: 0.54631 Early Stopping counter : 1 [98/200] train_loss: 0.30683 val_loss: 0.53966 Early Stopping counter : 2 [99/200] train_loss: 0.30068 val_loss: 0.54622 Early Stopping counter : 3 [100/200] train_loss: 0.31020 val_loss: 0.53605 Early Stopping counter : 4 [101/200] train_loss: 0.30062 val_loss: 0.60919 Early Stopping counter : 5 [102/200] train_loss: 0.29344 val_loss: 0.52990 best mse renew : 0.532071 --> 0.5299048 [103/200] train_loss: 0.29078 val_loss: 0.57145 Early Stopping counter : 1 [104/200] train_loss: 0.30019 val_loss: 0.59016 Early Stopping counter : 2 [105/200] train_loss: 0.29959 val_loss: 0.54217 Early Stopping counter : 3 [106/200] train_loss: 0.28553 val_loss: 0.54888 Early Stopping counter : 4 [107/200] train_loss: 0.28805 val_loss: 0.58825 Early Stopping counter : 5 [108/200] train_loss: 0.33146 val_loss: 0.53426 Early Stopping counter : 6 [109/200] train_loss: 0.29533 val_loss: 0.54330 Early Stopping counter : 7 [110/200] train_loss: 0.28567 val_loss: 0.56272 Early Stopping counter : 8 [111/200] train_loss: 0.29065 val_loss: 0.55840 Early Stopping counter : 9 [112/200] train_loss: 0.29010 val_loss: 0.54545 Early Stopping counter : 10 [113/200] train_loss: 0.27936 val_loss: 0.56230 Early Stopping counter : 11 [114/200] train_loss: 0.28225 val_loss: 0.58313 Early Stopping counter : 12 [115/200] train_loss: 0.27352 val_loss: 0.55847 Early Stopping counter : 13 [116/200] train_loss: 0.27312 val_loss: 0.55384 Early Stopping counter : 14 [117/200] train_loss: 0.26610 val_loss: 0.56275 Early Stopping counter : 15 [118/200] train_loss: 0.26496 val_loss: 0.54125 Early Stopping counter : 16 [119/200] train_loss: 0.25832 val_loss: 0.55301 Early Stopping counter : 17 [120/200] train_loss: 0.26143 val_loss: 0.54153 Early Stopping counter : 18 [121/200] train_loss: 0.26160 val_loss: 0.53584 Early Stopping counter : 19 [122/200] train_loss: 0.25601 val_loss: 0.54358 Early Stopping counter : 20 [123/200] train_loss: 0.26420 val_loss: 0.54525 Early Stopping counter : 21 [124/200] train_loss: 0.26093 val_loss: 0.56776 Early Stopping counter : 22 [125/200] train_loss: 0.26606 val_loss: 0.53230 Early Stopping counter : 23 [126/200] train_loss: 0.26330 val_loss: 0.54361 Early Stopping counter : 24 [127/200] train_loss: 0.25213 val_loss: 0.57000 Early Stopping counter : 25 [128/200] train_loss: 0.25876 val_loss: 0.53212 Early Stopping counter : 26 [129/200] train_loss: 0.24504 val_loss: 0.55074 Early Stopping counter : 27 [130/200] train_loss: 0.23742 val_loss: 0.58509 Early Stopping counter : 28 [131/200] train_loss: 0.25486 val_loss: 0.53119 Early Stopping counter : 29 [132/200] train_loss: 0.24540 val_loss: 0.55007 Early Stopping counter : 30 [133/200] train_loss: 0.24446 val_loss: 0.56702 Early Stopping counter : 31
In [11]:
import matplotlib.pyplot as plt
fig = plt.figure(figsize=(10,8))
plt.plot(range(1,len(avg_train_losses)+1),avg_train_losses, label='Training Loss')
plt.plot(range(1,len(avg_val_losses)+1),avg_val_losses,label='Validation Loss')
# find the epoch with the lowest validation loss
minposs = avg_val_losses.index(min(avg_val_losses))+1
plt.axvline(minposs, linestyle='--', color='r',label='Early Stopping Checkpoint')
plt.xlabel('epochs')
plt.ylabel('loss')
plt.ylim(0, 1) # fixed scale
plt.xlim(0, len(avg_train_losses)+1) # fixed scale
plt.grid(True)
plt.legend()
plt.tight_layout()
plt.show()
In [12]:
print('validation result')
print('MSE : ', best_mse)
print('RMSE : ', best_mse**0.5)
validation result MSE : 0.5299048 RMSE : 0.7279455905299957
Test
In [13]:
trained_model = torch.load('model.pt') #####model path
model = GCN()
model.load_state_dict(trained_model['State_dict']) #load trained weight
model = model.to(device)
model.eval()
total_preds = torch.Tensor()
total_labels = torch.Tensor()
atoms_list=[]
compounds_list = []
with torch.no_grad():
    for data in test_loader:
        drugs, labels = data
        labels = labels.to(device)
        atom_feats = drugs.ndata['h'].to(device)
        output, atoms, compound = model(drugs, atom_feats)
        total_preds = torch.cat((total_preds, output.cpu()), 0)
        total_labels = torch.cat((total_labels, labels.view(-1,1).cpu()), 0)
        atoms_list.append(atoms.cpu())
        compounds_list.append(compound.cpu())
G, P = total_labels.numpy().flatten(), total_preds.numpy().flatten()
In [14]:
result = mse(G,P)
print('test result')
print('MSE : ', result)
print('RMSE : ', result**0.5)
test result MSE : 0.60317504 RMSE : 0.7766434471878028
Representation
In [15]:
import seaborn as sns
import matplotlib.pyplot as plt
#atom representation in first batch, first compound
plt.rcParams["figure.figsize"] = (15,15)
sns.heatmap(atoms_list[0][0])
plt.xlabel('Dimension',fontsize=14)
plt.ylabel('Atom representation',fontsize=14)
plt.title('Compound Embedding Matrix',fontsize=18)
plt.show()
In [16]:
#compound representation (first batch, first compound)
plt.rcParams["figure.figsize"] = (15,0.5)
sns.heatmap(np.expand_dims(compounds_list[0][0],0))
plt.xlabel('Dimension',fontsize=14)
plt.title('Compound representation',fontsize=18)
plt.show()