banner
NEWS LETTER

ECnet代码解读(二)

Scroll down

#ecnet.py#

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
class ECNet(object):
def __init__(self,
output_dir=None,
train_tsv=None, test_tsv=None,
fasta=None, ccmpred_output=None,
use_loc_feat=True, use_glob_feat=True,
split_ratio=[0.9, 0.1],
random_seed=42,
nn_name='lstm', n_ensembles=1,
d_embed=20, d_model=128, d_h=128, nlayers=1,
batch_size=128, save_log=False):

self.dataset = Dataset(
train_tsv=train_tsv, test_tsv=test_tsv,
fasta=fasta, ccmpred_output=ccmpred_output,
use_loc_feat=use_loc_feat, use_glob_feat=use_glob_feat,
split_ratio=split_ratio,
random_seed=random_seed)
self.saver = Saver(output_dir=output_dir)
self.logger = Logger(logfile=self.saver.save_dir/'exp.log' if save_log else None)
self.use_loc_feat = use_loc_feat
self.use_glob_feat = use_glob_feat
vocab_size = len(vocab.AMINO_ACIDS)
seq_len = len(self.dataset.native_sequence)
proj_loc_config = {
'layer': nn.Linear,
'd_in': seq_len + 1,
'd_out': min(128, seq_len)
}
proj_glob_config = {
'layer': nn.Identity,
'd_in': 768,
'd_out': 768,
}

self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if nn_name in ['lstm', 'blstm']:
self.models = [LSTMPredictor(
d_embed=d_embed, d_model=d_model, d_h=d_h, nlayers=nlayers,
vocab_size=vocab_size, seq_len=seq_len,
bidirectional=True if nn_name == 'blstm' else False,
use_loc_feat=use_loc_feat, use_glob_feat=use_glob_feat,
proj_loc_config=proj_loc_config, proj_glob_config=proj_glob_config
).to(self.device) for _ in range(n_ensembles)]
else:
raise NotImplementedError

self.criterion = F.mse_loss
self.batch_size = batch_size
self.optimizers = [optim.Adam(model.parameters()) for model in self.models]
self._test_pack = None

__init__ 方法

这个构造函数负责初始化 ECNet 类的所有必要组件。

参数

  • output_dir:保存输出的目录路径。
  • train_tsv, test_tsv:训练和测试数据的 TSV 文件路径。
  • fasta, ccmpred_output:FASTA 文件路径和 CCMPred 输出路径。
  • use_loc_feat, use_glob_feat:是否使用本地和全局特征。
  • split_ratio:数据划分比例。
  • random_seed:随机种子,用于确保可重现性。
  • nn_name:神经网络类型,例如 LSTM 或双向 LSTM。
  • n_ensembles:模型集合的大小。
  • d_embed, d_model, d_h, nlayers:模型的维度和层数参数。
  • batch_size:训练批次的大小。
  • save_log:是否保存日志。

主要组件和流程

  1. 数据集:使用 Dataset 类创建数据集对象,包括所有训练、验证和测试数据。
  2. 保存器:使用 Saver 类创建保存器对象,用于管理输出目录和文件保存。
  3. 日志记录器:使用 Logger 类创建日志记录器,可选择保存到文件。
  4. 模型参数:计算和设置模型的特定参数,如词汇表大小、序列长度和投影层配置。
  5. 设备:检测是否有可用的 GPU,并据此设置 PyTorch 设备。
  6. 模型创建:根据指定的网络类型和参数创建一个或多个 LSTM 预测器模型。
  7. 损失函数:设置均方误差损失函数。
  8. 优化器:为每个模型创建 Adam 优化器。

模型的加载点和重用接口设置

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
@property
def test_pack(self):
if self._test_pack is None:
test_loader, test_df = self.dataset.get_dataloader(
'test', batch_size=self.batch_size, return_df=True)
self._test_pack = (test_loader, test_df)
return self._test_pack

@property
def test_loader(self):
return self.test_pack[0]

@property
def test_df(self):
return self.test_pack[1]

def load_checkpoint(self, checkpoint_dir):
checkpoint_dir = pathlib.Path(checkpoint_dir)
if not checkpoint_dir.is_dir():
raise ValueError(f'{checkpoint_dir} is not a directory')
for i in range(len(self.models)):
checkpoint_path = f'{checkpoint_dir}/model_{i + 1}.pt'
self.logger.info('Load pretrained model from {}'.format(checkpoint_path))
pt = torch.load(checkpoint_path)
model_dict = self.models[i].state_dict()
model_pretrained_dict = {k: v for k, v in pt['model_state_dict'].items() if k in model_dict}
model_dict.update(model_pretrained_dict)
self.models[i].load_state_dict(model_dict)
self.optimizers[i].load_state_dict(pt['optimizer_state_dict'])


def load_single_pretrained_model(self, checkpoint_path, model=None, optimizer=None, is_resume=False):
self.logger.info('Load pretrained model from {}'.format(checkpoint_path))
pt = torch.load(checkpoint_path)
model_dict = model.state_dict()
model_pretrained_dict = {k: v for k, v in pt['model_state_dict'].items() if k in model_dict}
model_dict.update(model_pretrained_dict)
model.load_state_dict(model_dict)
optimizer.load_state_dict(pt['optimizer_state_dict'])
return (model, optimizer, pt['log_info']) if is_resume else (model, optimizer)


def save_checkpoint(self, ckp_name=None, model_dict=None, opt_dict=None, log_info=None):
ckp = {'model_state_dict': model_dict,
'optimizer_state_dict': opt_dict}
ckp['log_info'] = log_info
self.saver.save_ckp(ckp, ckp_name)


test_pack 属性

这个属性返回一个包含测试数据加载器和测试 DataFrame 的元组。

  • 流程

    • 如果 _test_pack 尚未设置,则使用 get_dataloader 方法从数据集中获取测试数据加载器和测试 DataFrame。
    • 将这些值存储在 _test_pack 中,并返回。

test_loadertest_df 属性

这两个属性是 test_pack 属性的便捷访问器,分别返回测试数据加载器和测试 DataFrame。

load_checkpoint 方法

  • 目的:从给定的目录加载预训练的模型检查点。

  • 参数checkpoint_dir - 包含模型检查点文件的目录路径。

  • 流程

    • 验证给定的路径是否为目录。
    • 对于每个模型:
      • 构建检查点文件的路径。
      • 从文件加载检查点。
      • 更新模型的状态字典并加载到模型中。
      • 加载优化器的状态字典。

load_single_pretrained_model 方法

  • 目的:从给定的文件路径加载单个预训练模型。

  • 参数checkpoint_path - 检查点文件的路径;model - 要加载的模型;optimizer - 要加载的优化器;is_resume - 是否返回日志信息。

  • 流程

    • 从文件加载检查点。
    • 更新模型的状态字典并加载到模型中。
    • 加载优化器的状态字典。
    • 返回模型、优化器和可选的日志信息。

save_checkpoint 方法

  • 目的:保存模型的检查点。

  • 参数ckp_name - 检查点的名称;model_dict - 模型的状态字典;opt_dict - 优化器的状态字典;log_info - 日志信息。

  • 流程

    • 创建一个包含状态字典和日志信息的检查点字典。
    • 使用 saversave_ckp 方法保存检查点。

模型训练和测试

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
def train(self, epochs=1000, log_freq=100, eval_freq=50,
patience=500, save_checkpoint=False, resume_path=None):
assert eval_freq <= log_freq
monitoring_score = 'corr'
for midx, (model, optimizer) in enumerate(zip(self.models, self.optimizers), start=1):
(train_loader, train_df), (valid_loader, valid_df) = \
self.dataset.get_dataloader(
'train_valid', self.batch_size,
return_df=True, resample_train_valid=True)
if resume_path is not None:
model, optimizer, log_info = self.load_single_pretrained_model(
'{}/model_{}.pt'.format(resume_path, midx),
model=model, optimizer=optimizer, is_resume=True)
start_epoch = log_info['epoch'] + 1
best_score = log_info['best_{}'.format(monitoring_score)]
else:
start_epoch = 1
best_score = None

best_model_state_dict = None
stopper = EarlyStopping(patience=patience, eval_freq=eval_freq, best_score=best_score)
model.train()
try:
for epoch in range(start_epoch, epochs + 1):
time_start = time.time()
tot_loss = 0
for step, batch in tqdm(enumerate(train_loader, 1),
leave=False, desc=f'M-{midx} E-{epoch}', total=len(train_loader)):
y = batch['label'].to(self.device)
X = batch['seq_enc'].to(self.device)
if self.use_loc_feat:
loc_feat = batch['loc_feat'].to(self.device)
else:
loc_feat = None
if self.use_glob_feat:
glob_feat = batch['glob_feat'].to(self.device)
else:
glob_feat = None

optimizer.zero_grad()
output = model(X, glob_feat=glob_feat, loc_feat=loc_feat)
output = output.view(-1)
loss = self.criterion(output, y)

loss.backward()
optimizer.step()
tot_loss += loss.item()

if epoch % eval_freq == 0:
val_results = self.test(test_model=model, test_loader=valid_loader,
test_df=valid_df, mode='val')
model.train()
is_best = stopper.update(val_results['metric'][monitoring_score])
if is_best:
best_model_state_dict = copy.deepcopy(model.state_dict())
if save_checkpoint:
self.save_checkpoint(ckp_name='model_{}.pt'.format(midx),
model_dict=model.state_dict(),
opt_dict=optimizer.state_dict(),
log_info={
'epoch': epoch,
'best_{}'.format(monitoring_score): stopper.best_score,
'val_loss':val_results['loss'],
'val_results':val_results['metric']
})

if epoch % log_freq == 0:
train_results = self.test(test_model=model, test_loader=train_loader,
test_df=train_df, mode='val')
if (log_freq <= eval_freq) or (log_freq % eval_freq != 0):
val_results = self.test(test_model=model, test_loader=valid_loader,
test_df=valid_df, mode='val')
model.train()
self.logger.info(
'Model: {}/{}'.format(midx, len(self.models))
+ '\tEpoch: {}/{}'.format(epoch, epochs)
+ '\tTrain loss: {:.4f}'.format(tot_loss / step)
+ '\tVal loss: {:.4f}'.format(val_results['loss'])
+ '\t' + '\t'.join(['Val {}: {:.4f}'.format(k, v) \
for (k, v) in val_results['metric'].items()])
+ '\tBest {n}: {b:.4f}\t'.format(n=monitoring_score, b=stopper.best_score)
+ '\t{:.1f} s/epoch'.format(time.time() - time_start)
)
time_start = time.time()

if stopper.early_stop:
self.logger.info('Eearly stop at epoch {}'.format(epoch))
break
except KeyboardInterrupt:
self.logger.info('Exiting model training from keyboard interrupt')
if best_model_state_dict is not None:
model.load_state_dict(best_model_state_dict)

test_results = self.test(test_model=model, model_label='model_{}'.format(midx))
test_res_msg = 'Testing Model {}: Loss: {:.4f}\t'.format(midx, test_results['loss'])
test_res_msg += '\t'.join(['Test {}: {:.6f}'.format(k, v) \
for (k, v) in test_results['metric'].items()])
self.logger.info(test_res_msg + '\n')

train 方法

  • 目的:训练模型的整个过程。

  • 参数

    • epochs:训练周期的数量。
    • log_freq:记录训练信息的频率。
    • eval_freq:进行验证和检查早期停止的频率。
    • patience:早期停止的耐心参数,即在验证得分没有改进时允许的连续周期数。
    • save_checkpoint:是否在找到更好的验证得分时保存检查点。
    • resume_path:从先前保存的检查点恢复训练的路径。

主要流程

  1. 模型循环:遍历所有模型和对应的优化器。

  2. 加载或恢复检查点:如果提供了 resume_path,则从检查点加载或恢复模型。

  3. 准备训练和验证数据加载器:从数据集中获取训练和验证的数据加载器。

  4. 早期停止:初始化一个早期停止对象来监视验证得分。

  5. 训练循环

    • 将模型设置为训练模式。
    • 对于每个周期:
      • 初始化时间和损失计数器。
      • 遍历训练加载器中的批次:
        • 将批次数据移动到适当的设备上。
        • 执行前向传递和损失计算。
        • 执行反向传递和优化器步骤。
        • 累积批次损失。
      • 如果到达验证频率,则在验证集上评估模型,并使用早期停止对象更新得分。
      • 如果到达日志频率,则在训练和验证集上评估模型,并记录结果。
      • 检查早期停止条件,如果满足,则中断训练。
  6. 异常处理:捕获键盘中断,允许用户手动停止训练。

  7. 加载最佳模型:如果找到更好的模型,则加载最佳状态字典。

  8. 测试:在测试集上测试模型并记录结果。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
def test(self, test_model=None, test_loader=None, test_df=None,
checkpoint_dir=None, save_prediction=False,
calc_metric=True, calc_loss=True, model_label=None, mode='test'):
if checkpoint_dir is not None:
self.load_pretrained_model(checkpoint_dir)
if test_loader is None and test_df is None:
test_loader = self.test_loader
test_df = self.test_df
test_models = self.models if test_model is None else [test_model]
esb_ypred, esb_yprob = None, None
esb_loss = 0
for model in test_models:
model.eval()
y_true, y_pred, y_prob = None, None, None
tot_loss = 0
with torch.no_grad():
for step, batch in tqdm(enumerate(test_loader, 1),
desc=mode, leave=False, total=len(test_loader)):
X = batch['seq_enc'].to(self.device)
if self.use_loc_feat:
loc_feat = batch['loc_feat'].to(self.device)
else:
loc_feat = None
if self.use_glob_feat:
glob_feat = batch['glob_feat'].to(self.device)
else:
glob_feat = None

output = model(X, glob_feat=glob_feat, loc_feat=loc_feat)
output = output.view(-1)
if calc_loss:
y = batch['label'].to(self.device)
loss = self.criterion(output, y)
tot_loss += loss.item()
y_pred = output if y_pred is None else torch.cat((y_pred, output), dim=0)

y_pred = y_pred.detach().cpu() if self.device == torch.device('cuda') else y_pred.detach()
esb_ypred = y_pred.view(-1, 1) if esb_ypred is None else torch.cat((esb_ypred, y_pred.view(-1, 1)), dim=1)
esb_loss += tot_loss / step

esb_ypred = esb_ypred.mean(axis=1).numpy()
esb_loss /= len(test_models)

if calc_metric:
y_fitness = test_df['score'].values
eval_results = scipy.stats.spearmanr(y_fitness, esb_ypred)[0]

test_results = {}
results_df = test_df.copy()
results_df = results_df.drop(columns=['sequence'])
results_df['prediction'] = esb_ypred
test_results['df'] = results_df
if save_prediction:
self.saver.save_df(results_df, 'prediction.tsv')
test_results['loss'] = esb_loss
if calc_metric:
test_results['metric'] = {'corr': eval_results}
return test_results

test 方法

  • 目的:对给定的测试模型在测试加载器上的性能进行评估。

  • 参数

    • test_model:要测试的模型(可选)。
    • test_loader:包含测试数据的数据加载器(可选)。
    • test_df:包含测试数据的 DataFrame(可选)。
    • checkpoint_dir:从检查点加载预训练模型的目录(可选)。
    • save_prediction:是否保存预测结果到文件(可选)。
    • calc_metric:是否计算评估指标(例如相关性)(可选)。
    • calc_loss:是否计算损失(可选)。
    • model_label:模型标签(可选,未在代码中使用)。
    • mode:描述测试模式的字符串(例如 ‘test’ 或 ‘val’)。

主要流程

  1. 加载预训练模型:如果提供了 checkpoint_dir,则从检查点加载预训练模型。

  2. 设置测试加载器和 DataFrame:如果没有提供,使用默认的测试加载器和测试 DataFrame。

  3. 选择测试模型:如果没有提供 test_model,则使用所有模型进行测试。

  4. 模型循环

    :遍历要测试的模型。

    • 将模型设置为评估模式。
    • 初始化真实值和预测值的变量。
    • 遍历测试加载器中的批次:
      • 将批次数据移动到适当的设备上。
      • 执行前向传递。
      • 如果 calc_loss 为 True,则计算损失。
      • 收集预测值。
    • 累积预测值和损失。
  5. 计算集成预测和损失:通过平均所有模型的预测和损失来计算集成结果。

  6. 计算评估指标:如果 calc_metric 为 True,则计算相关性等评估指标。

  7. 保存预测:如果 save_prediction 为 True,则将预测结果保存到文件中。

  8. 返回结果:返回包含预测、损失和评估指标的结果字典。

I'm so cute. Please give me money.

其他文章
请输入关键词进行搜索