#ecnet.py#
1 | class ECNet(object): |
__init__
方法
这个构造函数负责初始化 ECNet
类的所有必要组件。
参数
output_dir
:保存输出的目录路径。train_tsv
,test_tsv
:训练和测试数据的 TSV 文件路径。fasta
,ccmpred_output
:FASTA 文件路径和 CCMPred 输出路径。use_loc_feat
,use_glob_feat
:是否使用本地和全局特征。split_ratio
:数据划分比例。random_seed
:随机种子,用于确保可重现性。nn_name
:神经网络类型,例如 LSTM 或双向 LSTM。n_ensembles
:模型集合的大小。d_embed
,d_model
,d_h
,nlayers
:模型的维度和层数参数。batch_size
:训练批次的大小。save_log
:是否保存日志。
主要组件和流程
- 数据集:使用
Dataset
类创建数据集对象,包括所有训练、验证和测试数据。 - 保存器:使用
Saver
类创建保存器对象,用于管理输出目录和文件保存。 - 日志记录器:使用
Logger
类创建日志记录器,可选择保存到文件。 - 模型参数:计算和设置模型的特定参数,如词汇表大小、序列长度和投影层配置。
- 设备:检测是否有可用的 GPU,并据此设置 PyTorch 设备。
- 模型创建:根据指定的网络类型和参数创建一个或多个 LSTM 预测器模型。
- 损失函数:设置均方误差损失函数。
- 优化器:为每个模型创建 Adam 优化器。
模型的加载点和重用接口设置
1 |
|
test_pack
属性
这个属性返回一个包含测试数据加载器和测试 DataFrame 的元组。
流程
:
- 如果
_test_pack
尚未设置,则使用get_dataloader
方法从数据集中获取测试数据加载器和测试 DataFrame。 - 将这些值存储在
_test_pack
中,并返回。
- 如果
test_loader
和 test_df
属性
这两个属性是 test_pack
属性的便捷访问器,分别返回测试数据加载器和测试 DataFrame。
load_checkpoint
方法
目的:从给定的目录加载预训练的模型检查点。
参数:
checkpoint_dir
- 包含模型检查点文件的目录路径。流程
:
- 验证给定的路径是否为目录。
- 对于每个模型:
- 构建检查点文件的路径。
- 从文件加载检查点。
- 更新模型的状态字典并加载到模型中。
- 加载优化器的状态字典。
load_single_pretrained_model
方法
目的:从给定的文件路径加载单个预训练模型。
参数:
checkpoint_path
- 检查点文件的路径;model
- 要加载的模型;optimizer
- 要加载的优化器;is_resume
- 是否返回日志信息。流程
:
- 从文件加载检查点。
- 更新模型的状态字典并加载到模型中。
- 加载优化器的状态字典。
- 返回模型、优化器和可选的日志信息。
save_checkpoint
方法
目的:保存模型的检查点。
参数:
ckp_name
- 检查点的名称;model_dict
- 模型的状态字典;opt_dict
- 优化器的状态字典;log_info
- 日志信息。流程
:
- 创建一个包含状态字典和日志信息的检查点字典。
- 使用
saver
的save_ckp
方法保存检查点。
模型训练和测试
1 | def train(self, epochs=1000, log_freq=100, eval_freq=50, |
train
方法
目的:训练模型的整个过程。
参数
:
epochs
:训练周期的数量。log_freq
:记录训练信息的频率。eval_freq
:进行验证和检查早期停止的频率。patience
:早期停止的耐心参数,即在验证得分没有改进时允许的连续周期数。save_checkpoint
:是否在找到更好的验证得分时保存检查点。resume_path
:从先前保存的检查点恢复训练的路径。
主要流程
模型循环:遍历所有模型和对应的优化器。
加载或恢复检查点:如果提供了
resume_path
,则从检查点加载或恢复模型。准备训练和验证数据加载器:从数据集中获取训练和验证的数据加载器。
早期停止:初始化一个早期停止对象来监视验证得分。
训练循环
:
- 将模型设置为训练模式。
- 对于每个周期:
- 初始化时间和损失计数器。
- 遍历训练加载器中的批次:
- 将批次数据移动到适当的设备上。
- 执行前向传递和损失计算。
- 执行反向传递和优化器步骤。
- 累积批次损失。
- 如果到达验证频率,则在验证集上评估模型,并使用早期停止对象更新得分。
- 如果到达日志频率,则在训练和验证集上评估模型,并记录结果。
- 检查早期停止条件,如果满足,则中断训练。
异常处理:捕获键盘中断,允许用户手动停止训练。
加载最佳模型:如果找到更好的模型,则加载最佳状态字典。
测试:在测试集上测试模型并记录结果。
1 | def test(self, test_model=None, test_loader=None, test_df=None, |
test
方法
目的:对给定的测试模型在测试加载器上的性能进行评估。
参数
:
test_model
:要测试的模型(可选)。test_loader
:包含测试数据的数据加载器(可选)。test_df
:包含测试数据的 DataFrame(可选)。checkpoint_dir
:从检查点加载预训练模型的目录(可选)。save_prediction
:是否保存预测结果到文件(可选)。calc_metric
:是否计算评估指标(例如相关性)(可选)。calc_loss
:是否计算损失(可选)。model_label
:模型标签(可选,未在代码中使用)。mode
:描述测试模式的字符串(例如 ‘test’ 或 ‘val’)。
主要流程
加载预训练模型:如果提供了
checkpoint_dir
,则从检查点加载预训练模型。设置测试加载器和 DataFrame:如果没有提供,使用默认的测试加载器和测试 DataFrame。
选择测试模型:如果没有提供
test_model
,则使用所有模型进行测试。模型循环
:遍历要测试的模型。
- 将模型设置为评估模式。
- 初始化真实值和预测值的变量。
- 遍历测试加载器中的批次:
- 将批次数据移动到适当的设备上。
- 执行前向传递。
- 如果
calc_loss
为 True,则计算损失。 - 收集预测值。
- 累积预测值和损失。
计算集成预测和损失:通过平均所有模型的预测和损失来计算集成结果。
计算评估指标:如果
calc_metric
为 True,则计算相关性等评估指标。保存预测:如果
save_prediction
为 True,则将预测结果保存到文件中。返回结果:返回包含预测、损失和评估指标的结果字典。