banner
NEWS LETTER

ECnet代码解读(一)

Scroll down

数据处理部分代码

#data.py#

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
import warnings
import numpy as np
import pandas as pd
from Bio import SeqIO
import torch.utils.data
from sklearn.model_selection import KFold, ShuffleSplit

from ecnet import vocab
from ecnet.local_feature import CCMPredEncoder
from ecnet.global_feature import TAPEEncoder


class SequenceData(torch.utils.data.Dataset):
def __init__(self, sequences, labels):
self.sequences = sequences
self.labels = labels

def __len__(self):
return len(self.labels)

def __getitem__(self, index):
return self.sequences[index], self.labels[index]


class MetagenesisData(torch.utils.data.Dataset):
def __init__(self, data):
self.data = data

def __len__(self):
return len(self.data)

def __getitem__(self, index):
return self.data[index]

def index_encoding(sequences):
'''
Modified from https://github.com/openvax/mhcflurry/blob/master/mhcflurry/amino_acid.py#L110-L130

Parameters
----------
sequences: list of equal-length sequences

Returns
-------
np.array with shape (#sequences, length of sequences)
'''
df = pd.DataFrame(iter(s) for s in sequences)
encoding = df.replace(vocab.AMINO_ACID_INDEX)
encoding = encoding.values.astype(np.int)
return encoding


class Dataset(object):
def __init__(self,
train_tsv=None, test_tsv=None,
fasta=None, ccmpred_output=None,
use_loc_feat=True, use_glob_feat=True,
split_ratio=[0.9, 0.1],
random_seed=42):
"""
split_ratio: [train, valid] or [train, valid, test]
"""

self.train_tsv = train_tsv
self.test_tsv = test_tsv
self.fasta = fasta
self.use_loc_feat = use_loc_feat
self.use_glob_feat = use_glob_feat
self.split_ratio = split_ratio
self.rng = np.random.RandomState(random_seed)

self.native_sequence = self._read_native_sequence()
if train_tsv is not None:
self.full_df = self._read_mutation_df(train_tsv)
else:
self.full_df = None

if test_tsv is None:
if len(split_ratio) != 3:
split_ratio = [0.7, 0.1, 0.2]
warnings.warn("\nsplit_ratio should have 3 elements if test_tsv is None." + \
f"Changing split_ratio to {split_ratio}. " + \
"Set to other values using --split_ratio.")
self.train_df, self.valid_df, self.test_df = \
self._split_dataset_df(self.full_df, split_ratio)
else:
if len(split_ratio) != 2:
split_ratio = [0.9, 0.1]
warnings.warn("\nsplit_ratio should have 2 elements if test_tsv is provided. " + \
f"Changing split_ratio to {split_ratio}. " + \
"Set to other values using --split_ratio.")
self.test_df = self._read_mutation_df(test_tsv)
if self.full_df is not None:
self.train_df, self.valid_df, _ = \
self._split_dataset_df(self.full_df, split_ratio)

if self.full_df is not None:
self.train_valid_df = pd.concat(
[self.train_df, self.valid_df]).reset_index(drop=True)

if self.use_loc_feat:
self.ccmpred_encoder = CCMPredEncoder(
ccmpred_output=ccmpred_output, seq_len=len(self.native_sequence))
if self.use_glob_feat:
self.tape_encoder = TAPEEncoder()

def _read_native_sequence(self):
fasta = SeqIO.read(self.fasta, 'fasta')
native_sequence = str(fasta.seq)
return native_sequence


def _check_split_ratio(self, split_ratio):
"""
Modified from: https://github.com/pytorch/text/blob/3d28b1b7c1fb2ddac4adc771207318b0a0f4e4f9/torchtext/data/dataset.py#L284-L311
"""
test_ratio = 0.
if isinstance(split_ratio, float):
assert 0. < split_ratio < 1., (
"Split ratio {} not between 0 and 1".format(split_ratio))
valid_ratio = 1. - split_ratio
return (split_ratio, valid_ratio, test_ratio)
elif isinstance(split_ratio, list):
length = len(split_ratio)
assert length == 2 or length == 3, (
"Length of split ratio list should be 2 or 3, got {}".format(split_ratio))
ratio_sum = sum(split_ratio)
if not ratio_sum == 1.:
split_ratio = [float(ratio) / ratio_sum for ratio in split_ratio]
if length == 2:
return tuple(split_ratio + [test_ratio])
return tuple(split_ratio)
else:
raise ValueError('Split ratio must be float or a list, got {}'
.format(type(split_ratio)))


def _split_dataset_df(self, input_df, split_ratio, resample_split=False):
"""
Modified from:
https://github.com/pytorch/text/blob/3d28b1b7c1fb2ddac4adc771207318b0a0f4e4f9/torchtext/data/dataset.py#L86-L136
"""
_rng = self.rng.randint(512) if resample_split else self.rng
df = input_df.copy()
df = df.sample(frac=1, random_state=_rng).reset_index(drop=True)
N = len(df)
train_ratio, valid_ratio, test_ratio = self._check_split_ratio(split_ratio)
train_len = int(round(train_ratio * N))
valid_len = N - train_len if not test_ratio else int(round(valid_ratio * N))

train_df = df.iloc[:train_len].reset_index(drop=True)
valid_df = df.iloc[train_len:train_len + valid_len].reset_index(drop=True)
test_df = df.iloc[train_len + valid_len:].reset_index(drop=True)

return train_df, valid_df, test_df


def _mutation_to_sequence(self, mutation):
'''
Parameters
----------
mutation: ';'.join(WiM) (wide-type W at position i mutated to M)
'''
sequence = self.native_sequence
for mut in mutation.split(';'):
wt_aa = mut[0]
mt_aa = mut[-1]
pos = int(mut[1:-1])
assert wt_aa == sequence[pos - 1],\
"%s: %s->%s (fasta WT: %s)"%(pos, wt_aa, mt_aa, sequence[pos - 1])
sequence = sequence[:(pos - 1)] + mt_aa + sequence[pos:]
return sequence


def _mutations_to_sequences(self, mutations):
return [self._mutation_to_sequence(m) for m in mutations]


def _drop_invalid_mutation(self, df):
'''
Drop mutations WiM where
- W is incosistent with the i-th AA in native_sequence
- M is ambiguous, e.g., 'X'
'''
flags = []
for mutation in df['mutation'].values:
for mut in mutation.split(';'):
wt_aa = mut[0]
mt_aa = mut[-1]
pos = int(mut[1:-1])
valid = True if wt_aa == self.native_sequence[pos - 1] else False
valid = valid and (mt_aa not in ['X'])
flags.append(valid)
df = df[flags].reset_index(drop=True)
return df

def _read_mutation_df(self, tsv):
df = pd.read_table(tsv)
df = self._drop_invalid_mutation(df)
df['sequence'] = self._mutations_to_sequences(df['mutation'].values)
return df


def encode_seq_enc(self, sequences):
seq_enc = index_encoding(sequences)
seq_enc = torch.from_numpy(seq_enc.astype(np.int))
return seq_enc

def encode_loc_feat(self, sequences):
feat = self.ccmpred_encoder.encode(sequences)
feat = torch.from_numpy(feat).float()
return feat

def encode_glob_feat(self, sequences):
feat = self.tape_encoder.encode(sequences)
feat = torch.from_numpy(feat).float()
return feat

def build_data(self, mode, return_df=False):
if mode == 'train':
df = self.train_df.copy()
elif mode == 'valid':
df = self.valid_df.copy()
elif mode == 'test':
df = self.test_df.copy()
else:
raise NotImplementedError

sequences = df['sequence'].values
seq_enc = self.encode_seq_enc(sequences)
if self.use_loc_feat:
loc_feat = self.encode_loc_feat(sequences)
if self.use_glob_feat:
glob_feat = self.encode_glob_feat(sequences)

labels = df['score'].values
labels = torch.from_numpy(labels.astype(np.float32))

samples = []
for i in range(len(df)):
sample = {
'sequence':sequences[i],
'label':labels[i],
'seq_enc': seq_enc[i],
}
if self.use_loc_feat:
sample['loc_feat'] = loc_feat[i]
if self.use_glob_feat:
sample['glob_feat'] = glob_feat[i]
samples.append(sample)
data = MetagenesisData(samples)
if return_df:
return data, df
else:
return data

def get_dataloader(self, mode, batch_size=128,
return_df=False, resample_train_valid=False):
if resample_train_valid:
self.train_df, self.valid_df, _ = \
self._split_dataset_df(
self.train_valid_df, self.split_ratio[:2], resample_split=True)

if mode == 'train_valid':
train_data, train_df = self.build_data('train', return_df=True)
valid_data, valid_df = self.build_data('valid', return_df=True)
train_loader = torch.utils.data.DataLoader(train_data, batch_size=batch_size)
valid_loader = torch.utils.data.DataLoader(valid_data, batch_size=batch_size)
if return_df:
return (train_loader, train_df), (valid_loader, valid_df)
else:
return train_loader, valid_loader
elif mode == 'test':
test_data, test_df = self.build_data('test', return_df=True)
test_loader = torch.utils.data.DataLoader(test_data, batch_size=batch_size)
if return_df:
return test_loader, test_df
else:
return test_loader
else:
raise NotImplementedError

if __name__ == '__main__':
protein_name = 'gb1'
dataset_name = 'Envision_Gray2018'
dataset = Dataset(
train_tsv=f'../../output/mutagenesis/{dataset_name}/{protein_name}/data.tsv',
fasta=f'../../output/mutagenesis/{dataset_name}/{protein_name}/native_sequence.fasta',
ccmpred_output=f'../../output/homologous/{dataset_name}/{protein_name}/hhblits/ccmpred/{protein_name}.braw',
split_ratio=[0.7, 0.1, 0.2],
use_loc_feat=False, use_glob_feat=False,
)
# dataset.build_data('train')
(loader, df), (_, _) = dataset.get_dataloader('train_valid',
batch_size=32, return_df=True)
print(df.head())
print(len(loader.__iter__()))
(loader, df), (_, _) = dataset.get_dataloader('train_valid',
batch_size=32, return_df=True, resample_train_valid=True)
print(df.head())
print(len(loader.__iter__()))
loader, df = dataset.get_dataloader('test',
batch_size=32, return_df=True, resample_train_valid=True)
print(next(loader.__iter__()))

数据加载和预处理

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
class SequenceData(torch.utils.data.Dataset):
def __init__(self, sequences, labels):
self.sequences = sequences
self.labels = labels

def __len__(self):
return len(self.labels)

def __getitem__(self, index):
return self.sequences[index], self.labels[index]


class MetagenesisData(torch.utils.data.Dataset):
def __init__(self, data):
self.data = data

def __len__(self):
return len(self.data)

def __getitem__(self, index):
return self.data[index]
  1. SequenceData
    • **__init__(self, sequences, labels)**:构造函数接受序列和标签作为输入,并将它们存储为类的属性。序列可以是一系列的特征向量,标签通常是相应的目标值或类别标签。
    • **__len__(self)**:此方法返回数据集的大小,即标签的数量。这对于迭代和批量处理数据集非常重要。
    • **__getitem__(self, index)**:此方法允许通过索引访问数据集中的特定项。返回的是给定索引处的序列和标签。
  2. MetagenesisData
    • **__init__(self, data)**:构造函数接受数据作为输入,并将其存储为类的属性。这些数据可能是一个复杂的结构,包括特征和标签。
    • **__len__(self)**:和上面一样,此方法返回数据集的大小。
    • **__getitem__(self, index)**:此方法允许通过索引访问数据集中的特定项。返回的是给定索引处的数据项。

这两个类的主要作用是封装数据,并提供一种标准化的方式来访问和迭代数据。这对于训练和评估机器学习模型非常重要,因为它允许模型以一致的方式处理数据,而无需关心数据的底层表示和结构。

通过将数据封装在这些类中,可以更容易地与 PyTorch 的其他组件(例如 DataLoader)集成,从而实现数据的批量加载、随机抽样和多线程加载等功能。这有助于提高训练和评估过程的效率和灵活性。

编码和数据准备

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
def index_encoding(sequences):
'''
Modified from https://github.com/openvax/mhcflurry/blob/master/mhcflurry/amino_acid.py#L110-L130

Parameters
----------
sequences: list of equal-length sequences

Returns
-------
np.array with shape (#sequences, length of sequences)
'''
df = pd.DataFrame(iter(s) for s in sequences)
encoding = df.replace(vocab.AMINO_ACID_INDEX)
encoding = encoding.values.astype(np.int)
return encoding


class Dataset(object):
def __init__(self,
train_tsv=None, test_tsv=None,
fasta=None, ccmpred_output=None,
use_loc_feat=True, use_glob_feat=True,
split_ratio=[0.9, 0.1],
random_seed=42):
"""
split_ratio: [train, valid] or [train, valid, test]
"""

self.train_tsv = train_tsv
self.test_tsv = test_tsv
self.fasta = fasta
self.use_loc_feat = use_loc_feat
self.use_glob_feat = use_glob_feat
self.split_ratio = split_ratio
self.rng = np.random.RandomState(random_seed)

self.native_sequence = self._read_native_sequence()
if train_tsv is not None:
self.full_df = self._read_mutation_df(train_tsv)
else:
self.full_df = None

if test_tsv is None:
if len(split_ratio) != 3:
split_ratio = [0.7, 0.1, 0.2]
warnings.warn("\nsplit_ratio should have 3 elements if test_tsv is None." + \
f"Changing split_ratio to {split_ratio}. " + \
"Set to other values using --split_ratio.")
self.train_df, self.valid_df, self.test_df = \
self._split_dataset_df(self.full_df, split_ratio)
else:
if len(split_ratio) != 2:
split_ratio = [0.9, 0.1]
warnings.warn("\nsplit_ratio should have 2 elements if test_tsv is provided. " + \
f"Changing split_ratio to {split_ratio}. " + \
"Set to other values using --split_ratio.")
self.test_df = self._read_mutation_df(test_tsv)
if self.full_df is not None:
self.train_df, self.valid_df, _ = \
self._split_dataset_df(self.full_df, split_ratio)

if self.full_df is not None:
self.train_valid_df = pd.concat(
[self.train_df, self.valid_df]).reset_index(drop=True)

if self.use_loc_feat:
self.ccmpred_encoder = CCMPredEncoder(
ccmpred_output=ccmpred_output, seq_len=len(self.native_sequence))
if self.use_glob_feat:
self.tape_encoder = TAPEEncoder()

def _read_native_sequence(self):
fasta = SeqIO.read(self.fasta, 'fasta')
native_sequence = str(fasta.seq)
return native_sequence

index_encoding 函数

这个函数用于将氨基酸序列转换为整数索引编码。

  • 参数sequences - 等长序列的列表。
  • 返回值:整数编码的 NumPy 数组,形状为 (#序列, 序列长度)。

流程:

  1. 使用 Pandas 的 DataFrame 将序列转换为表格形式。
  2. 使用 vocab.AMINO_ACID_INDEX 替换 DataFrame 中的值,将氨基酸转换为对应的索引。
  3. 将结果转换为整数类型的 NumPy 数组并返回。

Dataset

这个类是一个数据集的容器,用于存储和管理数据。

  • 构造函数:接受许多参数,包括训练和测试的 TSV 文件路径、FASTA 文件路径、CCMPred 输出、本地特征和全局特征的使用、数据集的划分比例以及随机种子。

流程:

  1. 初始化参数和随机数生成器。
  2. 读取本地序列(通过 _read_native_sequence 方法)。
  3. 根据提供的 TSV 文件读取训练和测试数据。
  4. 根据 split_ratio 对数据进行划分。
  5. 如果使用本地特征,则初始化 CCMPred 编码器;如果使用全局特征,则初始化 TAPE 编码器。

_read_native_sequence 方法

这个私有方法用于从 FASTA 文件中读取本地序列,并将其转换为字符串格式。

数据集划分

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
def _split_dataset_df(self, input_df, split_ratio, resample_split=False):
"""
Modified from:
https://github.com/pytorch/text/blob/3d28b1b7c1fb2ddac4adc771207318b0a0f4e4f9/torchtext/data/dataset.py#L86-L136
"""
_rng = self.rng.randint(512) if resample_split else self.rng
df = input_df.copy()
df = df.sample(frac=1, random_state=_rng).reset_index(drop=True)
N = len(df)
train_ratio, valid_ratio, test_ratio = self._check_split_ratio(split_ratio)
train_len = int(round(train_ratio * N))
valid_len = N - train_len if not test_ratio else int(round(valid_ratio * N))

train_df = df.iloc[:train_len].reset_index(drop=True)
valid_df = df.iloc[train_len:train_len + valid_len].reset_index(drop=True)
test_df = df.iloc[train_len + valid_len:].reset_index(drop=True)

return train_df, valid_df, test_df

  • 参数
    • input_df:要划分的 Pandas DataFrame。
    • split_ratio:一个包含划分比例的列表,如 [0.7, 0.1, 0.2]
    • resample_split:一个布尔值,如果为 True,则在划分数据之前重新采样随机种子。
  • 流程
    1. 通过调用 self.rng.randint(512) 创建一个随机数生成器 _rng,或者使用现有的随机数生成器 self.rng
    2. 复制输入的 DataFrame。
    3. 使用 _rng 随机化 DataFrame 的行顺序。
    4. 计算训练、验证和测试集的大小。首先通过调用 _check_split_ratio 方法验证和解析 split_ratio。然后,基于这些比例计算每个子集的长度。
    5. 使用 Pandas 的 iloc 方法,根据计算出的长度从随机化的 DataFrame 中切分训练、验证和测试集。
    6. 重置每个子集的索引并返回。
  • 返回值:三个 DataFrame,分别代表训练集、验证集和测试集。

突变的写入和读取

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
def _mutation_to_sequence(self, mutation):
'''
Parameters
----------
mutation: ';'.join(WiM) (wide-type W at position i mutated to M)
'''
sequence = self.native_sequence
for mut in mutation.split(';'):
wt_aa = mut[0]
mt_aa = mut[-1]
pos = int(mut[1:-1])
assert wt_aa == sequence[pos - 1],\
"%s: %s->%s (fasta WT: %s)"%(pos, wt_aa, mt_aa, sequence[pos - 1])
sequence = sequence[:(pos - 1)] + mt_aa + sequence[pos:]
return sequence


def _mutations_to_sequences(self, mutations):
return [self._mutation_to_sequence(m) for m in mutations]


def _drop_invalid_mutation(self, df):
'''
Drop mutations WiM where
- W is incosistent with the i-th AA in native_sequence
- M is ambiguous, e.g., 'X'
'''
flags = []
for mutation in df['mutation'].values:
for mut in mutation.split(';'):
wt_aa = mut[0]
mt_aa = mut[-1]
pos = int(mut[1:-1])
valid = True if wt_aa == self.native_sequence[pos - 1] else False
valid = valid and (mt_aa not in ['X'])
flags.append(valid)
df = df[flags].reset_index(drop=True)
return df

def _read_mutation_df(self, tsv):
df = pd.read_table(tsv)
df = self._drop_invalid_mutation(df)
df['sequence'] = self._mutations_to_sequences(df['mutation'].values)
return df

1. _mutation_to_sequence 方法:

  • 目的:将一个突变描述转换为一个完整的氨基酸序列。

  • 参数mutation - 描述单个或多个突变的字符串,例如 “A5T;F10Y” 表示位置5的A突变为T,位置10的F突变为Y。

  • 流程

    • 从原始的 native_sequence 开始。

    • 对于

      1
      mutation

      中的每个突变:

      • 检查突变是否与 native_sequence 中的氨基酸相匹配。
      • 如果相匹配,则在序列中应用突变。
    • 返回突变后的序列。

2. _mutations_to_sequences 方法:

  • 目的:将多个突变描述转换为多个氨基酸序列。

  • 参数mutations - 描述多个突变的字符串列表。

  • 流程

    • 对于每个突变描述,调用 _mutation_to_sequence 方法并收集结果。

3. _drop_invalid_mutation 方法:

  • 目的:从 DataFrame 中删除无效的突变。

  • 参数df - 包含突变描述的 DataFrame。

  • 流程

    • 对于每个突变描述:
      • 检查突变是否与 native_sequence 中的氨基酸相匹配。
      • 检查突变是否为不明确的氨基酸,例如 “X”。
    • 基于上述检查结果,保留有效的突变,并从 DataFrame 中删除无效的突变。

4. _read_mutation_df 方法:

  • 目的:从 TSV 文件中读取突变数据,并转换为相应的氨基酸序列。

  • 参数tsv - TSV 文件的路径。

  • 流程

    • 使用 Pandas 从文件中读取数据。
    • 调用 _drop_invalid_mutation 方法,去除无效的突变。
    • 调用 _mutations_to_sequences 方法,将突变转换为相应的氨基酸序列,并将其添加到 DataFrame 中。

特征编码和数据载入

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
def encode_seq_enc(self, sequences):
seq_enc = index_encoding(sequences)
seq_enc = torch.from_numpy(seq_enc.astype(np.int))
return seq_enc

def encode_loc_feat(self, sequences):
feat = self.ccmpred_encoder.encode(sequences)
feat = torch.from_numpy(feat).float()
return feat

def encode_glob_feat(self, sequences):
feat = self.tape_encoder.encode(sequences)
feat = torch.from_numpy(feat).float()
return feat

def build_data(self, mode, return_df=False):
if mode == 'train':
df = self.train_df.copy()
elif mode == 'valid':
df = self.valid_df.copy()
elif mode == 'test':
df = self.test_df.copy()
else:
raise NotImplementedError

sequences = df['sequence'].values
seq_enc = self.encode_seq_enc(sequences)
if self.use_loc_feat:
loc_feat = self.encode_loc_feat(sequences)
if self.use_glob_feat:
glob_feat = self.encode_glob_feat(sequences)

labels = df['score'].values
labels = torch.from_numpy(labels.astype(np.float32))

samples = []
for i in range(len(df)):
sample = {
'sequence':sequences[i],
'label':labels[i],
'seq_enc': seq_enc[i],
}
if self.use_loc_feat:
sample['loc_feat'] = loc_feat[i]
if self.use_glob_feat:
sample['glob_feat'] = glob_feat[i]
samples.append(sample)
data = MetagenesisData(samples)
if return_df:
return data, df
else:
return data

def get_dataloader(self, mode, batch_size=128,
return_df=False, resample_train_valid=False):
if resample_train_valid:
self.train_df, self.valid_df, _ = \
self._split_dataset_df(
self.train_valid_df, self.split_ratio[:2], resample_split=True)

if mode == 'train_valid':
train_data, train_df = self.build_data('train', return_df=True)
valid_data, valid_df = self.build_data('valid', return_df=True)
train_loader = torch.utils.data.DataLoader(train_data, batch_size=batch_size)
valid_loader = torch.utils.data.DataLoader(valid_data, batch_size=batch_size)
if return_df:
return (train_loader, train_df), (valid_loader, valid_df)
else:
return train_loader, valid_loader
elif mode == 'test':
test_data, test_df = self.build_data('test', return_df=True)
test_loader = torch.utils.data.DataLoader(test_data, batch_size=batch_size)
if return_df:
return test_loader, test_df
else:
return test_loader
else:
raise NotImplementedError

1. encode_seq_enc 方法:

  • 目的:将氨基酸序列转换为整数索引编码。

  • 参数sequences - 氨基酸序列的列表。

  • 流程

    • 调用之前定义的 index_encoding 函数来获取整数编码。
    • 将 NumPy 数组转换为 PyTorch 张量并返回。

2. encode_loc_featencode_glob_feat 方法:

  • 目的:使用 CCMPred 编码器和 TAPE 编码器对序列进行本地和全局特征编码。

  • 参数sequences - 氨基酸序列的列表。

  • 流程

    • 调用编码器的 encode 方法来获取特征。
    • 将特征转换为 PyTorch 张量并返回。

3. build_data 方法:

  • 目的:根据给定的模式(训练、验证或测试)构建数据集。

  • 参数mode - 指定要构建的数据集类型;return_df - 如果为 True,则返回原始 DataFrame。

  • 流程

    • 根据模式复制相应的 DataFrame。
    • 对序列进行整数索引编码。
    • 如果使用,对序列进行本地和全局特征编码。
    • 从 DataFrame 中提取标签,并转换为 PyTorch 张量。
    • 创建样本字典并收集。
    • 使用 MetagenesisData 类创建数据集。
    • 返回数据集(和可选的 DataFrame)。

4. get_dataloader 方法:

  • 目的:根据给定的模式和批量大小构建数据加载器。

  • 参数mode, batch_size, return_df, resample_train_valid

  • 流程

    • 如果重新采样,重新划分训练和验证集。
    • 根据模式构建数据集。
    • 使用 PyTorch 的 DataLoader 创建数据加载器。
    • 返回数据加载器(和可选的 DataFrame)。

I'm so cute. Please give me money.

其他文章
请输入关键词进行搜索