-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathfine-tune.py
More file actions
240 lines (186 loc) · 8.81 KB
/
fine-tune.py
File metadata and controls
240 lines (186 loc) · 8.81 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
from torch.utils.data import DataLoader
import torch.backends.cudnn as cudnn
import torch
import torch.nn as nn
import numpy as np
from data_preprocess import SpectrumRead
from dataloader import SpectrumLoader
from model import Discriminator, Generator, spectrum_refine
from evalue import tag_correct_discriminate, AverageMeter
# --- Data parameters ---
data_folder = '../data' # directory where the tag generation results are written
mgf_file_path = "../data/mouse.mgf" # the path of the spectrum mgf file
# The intersection-labeled dataset can be obtained by data_label_union.py
# according to the identification results of PEAKS, MSFragger+ and Open-pFind.
spectrum_results_file_path = '../data/mouse.txt'
# Train and evaluation dataset paths; created automatically from the mgf file.
train_tag_feature_path = '../data/tag_feature_dataset.pkl'
train_tag_label_path = '../data/tag_label_dataset.pkl'
val_tag_feature_path = '../data/val_tag_feature_dataset.pkl'
val_tag_label_path = '../data/val_tag_label_dataset.pkl'
# --- Model parameters ---
check_point_D = 'discriminator_check_point.pth.tar' # discriminator checkpoint file path
check_point_flag = False # True -> resume the discriminator from the checkpoint above
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
cudnn.benchmark = True # let cuDNN auto-tune conv algorithms (fastest for fixed input sizes)
batch_size = 512
lamb = 0.3 # weight balancing the two terms of the generator reward (see generator_adjust)
xi = 0.2 # NOTE(review): unused in this file — presumably consumed elsewhere; confirm
# --- Learning parameters ---
num_epoch = 100
lr = 1e-4
works = 4 # number of dataloader workers for accelerating the training.
print_freq = 200 # print loss every 200 batches.
best_accuracy = 0. # best validation accuracy seen so far; updated during training
max_game_step = 5 # number of adversarial game rounds.
mass_error = 0.01 # the mass difference tolerance for each amino acid inference.
# When searching for the optimal mass-error value, we gradually
# reduce the search range for faster adjustment.
mass_error_reduction = 0.5
gap = 0.2
def main():
    """Run the full fine-tuning pipeline.

    Steps:
      1. Read spectra from the mgf file and drop unlabeled ones.
      2. Generate candidate tags and label them against the search results.
      3. Train the tag discriminator, checkpointing the best model.
      4. Play the adversarial game: alternately tune the generator's mass
         error and retrain the discriminator.
    """
    global check_point_flag, num_epoch, best_accuracy, mass_error, gap
    # Read the mgf file to obtain the spectra.
    s = SpectrumRead(mgf_file_path)
    spectra_list = s.read_spectrum_data()
    print("Read the mgf file successfully!")
    # Filter out the unlabeled spectra.
    spectra_list = spectrum_refine(spectra_list, spectrum_results_file_path)
    # Generate the tag candidates with the default mass-error parameter.
    print("The program is extracting the tag sequence conditioned on each spectrum: ")
    t = Generator(spectra_list, mass_error)
    t.process_spectrum_list()
    print("The extracted tag sequences are saved!")
    # Save the tag candidate features.
    t.save_tag_feature(data_folder)
    # Label the generated tags.
    tag_correct_discriminate(data_folder, spectrum_results_file_path)
    # Initialize model or load checkpoint.
    if check_point_flag is False:
        discriminator = Discriminator()
    else:
        # BUG FIX: load into a local instead of rebinding the module-level
        # path string `check_point_D` (the original clobbered it).
        checkpoint = torch.load(check_point_D)
        discriminator = checkpoint['model']
    # Move to GPU, if available.
    discriminator = discriminator.to(device)
    # Model parameter optimization with classical SGD.
    optimizer = torch.optim.SGD(discriminator.parameters(), lr=lr)
    # Loss function.
    criterion = nn.CrossEntropyLoss().to(device)
    # Training and evaluation dataloaders.
    train_loader = DataLoader(SpectrumLoader(train_tag_feature_path, train_tag_label_path),
                              batch_size=batch_size, shuffle=True, num_workers=works)
    val_loader = DataLoader(SpectrumLoader(val_tag_feature_path, val_tag_label_path),
                            batch_size=batch_size, shuffle=True, num_workers=works)
    print('The Tag Discriminator is ready to train.')
    for epoch in range(num_epoch):
        # One epoch's training.
        discriminator_train(train_loader=train_loader,
                            model=discriminator,
                            criterion=criterion,
                            optimizer=optimizer,
                            epoch=epoch)
        # One epoch's validation.
        recent_accuracy = discriminator_val(val_loader=val_loader,
                                            model=discriminator)
        # BUG FIX: update best_accuracy when improved. The original never
        # assigned it, so every epoch was compared against the stale initial
        # 0.0 and the checkpoint always stored that stale value instead of
        # the model's real accuracy.
        if recent_accuracy > best_accuracy:
            best_accuracy = recent_accuracy
            torch.save({'model': discriminator, 'accuracy': best_accuracy},
                       'discriminator_check_point.pth.tar')
    # After initializing the generator and discriminator, adversarial learning
    # fine-tunes the pair for better performance. Usually 2 iterative steps
    # are enough for the performance metric to converge.
    print('The Adversial Game is ready to play!')
    for i in range(max_game_step):
        # Adjust the tag generator parameters conditioned on the trained discriminator.
        mass_error = generator_adjust(mass_error, discriminator, spectra_list, val_loader)
        # Re-train the discriminator and store the best model conditioned on
        # the current tag generator.
        for epoch in range(num_epoch):
            discriminator_train(train_loader=train_loader,
                                model=discriminator,
                                criterion=criterion,
                                optimizer=optimizer,
                                epoch=epoch)
            recent_accuracy = discriminator_val(val_loader=val_loader,
                                                model=discriminator)
            # Same best-tracking fix as in the first training loop.
            if recent_accuracy > best_accuracy:
                best_accuracy = recent_accuracy
                torch.save({'model': discriminator, 'accuracy': best_accuracy},
                           'discriminator_check_point.pth.tar')
    print('Congratulations! trained model has been saved!')
def generator_adjust(mass_error, discriminator, spectra_list, val_loader):
    """Grid-search the generator's mass-error parameter against the discriminator.

    Regenerates candidate tags for each mass error in a shrinking window
    around the current value, scores them with the discriminator, and keeps
    the error that maximizes the weighted reward.

    Args:
        mass_error: current mass-error tolerance (search-window center).
        discriminator: trained Discriminator used for scoring.
        spectra_list: filtered spectra to regenerate tags from.
        val_loader: validation DataLoader used to compute accuracy.

    Returns:
        The best mass-error value found.
    """
    global gap, mass_error_reduction, best_accuracy, lamb
    best_mass_error = mass_error
    # Shrink the search range on every call for faster convergence.
    gap = mass_error_reduction * gap
    mass_gap = gap * mass_error
    best_reward = 0
    # Grid search within the dynamic parameter range.
    for error in np.arange(mass_error - mass_gap, mass_error + mass_gap, 0.5 * mass_gap):
        # Re-generate the candidate tag features under this mass error.
        t = Generator(spectra_list, error)
        t.process_spectrum_list()
        t.save_tag_feature(data_folder)
        tag_correct_discriminate(data_folder, spectrum_results_file_path)
        # Calculate the accuracy and weighted reward for the tag generator.
        current_accuracy = discriminator_val(val_loader, discriminator)
        # NOTE(review): with lamb < 0.5 this reward decreases as accuracy
        # rises — confirm the weighting is intended.
        current_reward = lamb * current_accuracy + (1 - lamb) * (1 - current_accuracy)
        # Pre-defined stopping criterion: reward has plateaued relative to
        # the best seen so far (evaluated before updating the best).
        converged = abs(best_reward - current_reward) < 0.001
        # BUG FIX: also record the best reward. The original only recorded
        # the error, so `best_reward` stayed 0 and candidates were never
        # compared against the true best.
        if current_reward > best_reward:
            best_reward = current_reward
            best_mass_error = error
        if converged:
            break
    return best_mass_error
# Train the tag discriminator for one epoch.
def discriminator_train(train_loader, model, criterion, optimizer, epoch):
    """Run one training epoch over (tag_feature, tag_label) batches.

    Args:
        train_loader: DataLoader yielding (tag_feature, tag_label) pairs.
        model: the Discriminator being trained.
        criterion: classification loss (CrossEntropyLoss here).
        optimizer: optimizer over the model's parameters.
        epoch: current epoch index, used for logging only.
    """
    model.train()
    losses = AverageMeter()
    for i, (tag_feature, tag_label) in enumerate(train_loader):
        tag_feature = tag_feature.to(device)
        # Collapse the label batch to 1-D class indices, as CrossEntropyLoss
        # expects (each label's first element is the class index).
        tag_label_1_d = [t[0] for t in tag_label]
        tag_label_1_d = torch.LongTensor(tag_label_1_d)
        tag_label = tag_label_1_d.to(device)
        # Forward prop.
        scores = model(tag_feature)
        # BUG FIX: removed a leftover debug `print(tag_label)` that printed
        # the whole label tensor on every batch.
        loss = criterion(scores, tag_label)
        # Back prop.
        optimizer.zero_grad()
        loss.backward()
        # Update weights.
        optimizer.step()
        losses.update(loss.item())
        # Print the real-time loss every `print_freq` batches.
        if i % print_freq == 0:
            print('Epoch: [{0}][{1}/{2}]\t'
                  'Loss {loss.val:.4f} ({loss.avg:.4f})'
                  .format(epoch, i, len(train_loader), loss=losses))
# Evaluate the discriminator on the validation set.
def discriminator_val(val_loader, model):
    """Return the discriminator's mean per-batch accuracy on `val_loader`.

    Predictions and ground truth are both taken as the argmax over the
    class dimension; accuracy is averaged across batches.
    """
    model.eval()
    acc_meter = AverageMeter()
    with torch.no_grad():
        for features, labels in val_loader:
            features = features.to(device)
            labels = labels.to(device)
            logits = model(features)
            # Compare predicted class indices with the true ones.
            predicted = logits.argmax(dim=1)
            truth = labels.argmax(dim=1)
            assert truth.size(0) == predicted.size(0)
            batch_acc = (predicted == truth).sum().item() / truth.size(0)
            # Track both the real-time and running-average accuracy so the
            # training process can be observed.
            acc_meter.update(batch_acc)
    return acc_meter.avg
# Script entry point: run the full training pipeline when executed directly.
if __name__ == '__main__':
    main()