AlphaFold2代码阅读(四)

2021SC@SDUSC


class EmbeddingsAndEvoformer

class EmbeddingsAndEvoformer(hk.Module):

  def __init__(self, config, global_config, name='evoformer'):
    super().__init__(name=name)
    self.config = config
    self.global_config = global_config

  def __call__(self, batch, is_training, safe_key=None):

    c = self.config
    gc = self.global_config

1.modules.py中class EvoformerIteration主要实现了一个Evoformer Block,而class EmbeddingsAndEvoformer则负责嵌入输入数据并运行 Evoformer,初始化函数__init__的参数就有name=‘evoformer’
2.class EmbeddingsAndEvoformer后面还产生了MSA, single and pair representations

if safe_key is None:
      safe_key = prng.SafeKey(hk.next_rng_key())

这里和之前的class EvoformerIteration类似
在这里hk.next_rng_key()返回一个唯一的rng键

    preprocess_1d = common_modules.Linear(
        c.msa_channel, name='preprocess_1d')(
            batch['target_feat'])

    preprocess_msa = common_modules.Linear(
        c.msa_channel, name='preprocess_msa')(
            batch['msa_feat'])

    msa_activations = jnp.expand_dims(preprocess_1d, axis=0) + preprocess_msa

    left_single = common_modules.Linear(
        c.pair_channel, name='left_single')(
            batch['target_feat'])
    right_single = common_modules.Linear(
        c.pair_channel, name='right_single')(
            batch['target_feat'])
    pair_activations = left_single[:, None] + right_single[None]
    mask_2d = batch['seq_mask'][:, None] * batch['seq_mask'][None, :

上图中的主要作用是嵌入集群 MSA
经过查找论文,这里对应的伪代码如下
AlphaFold2代码阅读(四)
这段代码介绍了初始嵌入的细节。
将输入的 primary sequence和MSA features转变为MSA 表示
primary sequence特征被转换为form pair表示
这两种表示也由模板和未集群的MSA特性提供信息
AlphaFold2代码阅读(四)

    if c.recycle_pos and 'prev_pos' in batch:
      prev_pseudo_beta = pseudo_beta_fn(
          batch['aatype'], batch['prev_pos'], None)
      dgram = dgram_from_positions(prev_pseudo_beta, **self.config.prev_pos)
      pair_activations += common_modules.Linear(
          c.pair_channel, name='prev_pos_linear')(
              dgram)

    if c.recycle_features:
      if 'prev_msa_first_row' in batch:
        prev_msa_first_row = hk.LayerNorm([-1],
                                          True,
                                          True,
                                          name='prev_msa_first_row_norm')(
                                              batch['prev_msa_first_row'])
        msa_activations = msa_activations.at[0].add(prev_msa_first_row)

      if 'prev_pair' in batch:
        pair_activations += hk.LayerNorm([-1],
                                         True,
                                         True,
                                         name='prev_pair_norm')(
                                             batch['prev_pair'])

上述代码是对注入以前的输出以进行回收
结合伪代码来看
AlphaFold2代码阅读(四)

在AlphaFold中,从结构模块回收预测的主原子坐标,从evote输出对和第一行MSA表示。而上述代码展示了嵌入细节。这两种类型的表示都通过LayerNorm为相应的输入表示准备更新。所预测的碳原子坐标(甘氨酸的阿尔法碳)被用来计算成对距离,然后将其离散成15个宽度为1.25A埃,跨度到大约20埃。所得到的单热二向图被线性投影并添加到对表示更新中。循环更新被注入到网络中.这是用于循环先前预测的唯一机制,而网络的其余部分在所有循环迭代中都是相同的。

    if c.max_relative_feature:
      pos = batch['residue_index']
      offset = pos[:, None] - pos[None, :]
      rel_pos = jax.nn.one_hot(
          jnp.clip(
              offset + c.max_relative_feature,
              a_min=0,
              a_max=2 * c.max_relative_feature),
          2 * c.max_relative_feature + 1)
      pair_activations += common_modules.Linear(
          c.pair_channel, name='pair_activiations')(
              rel_pos)

上述代码主要是为相对位置编码。
为了向网络提供关于链中残基位置的信息,将相对位置特征编码到初始对表示中。具体来说,对于残基,计算链内剪切的相对距离,将其编码为一个热向量,并将该向量的线性投影添加到zij中。由于被最大值32剪切,所以残基链内的任何较大的距离都不会被这个特征所区分。这种归纳偏差消除了对初级序列距离的强调。

    if c.template.enabled:
      template_batch = {k: batch[k] for k in batch if k.startswith('template_')}
      template_pair_representation = TemplateEmbedding(c.template, gc)(
          pair_activations,
          template_batch,
          mask_2d,
          is_training=is_training)

      pair_activations += template_pair_representation

上面这段代码是将模板嵌入到配对激活中。

    extra_msa_feat = create_extra_msa_feature(batch)
    extra_msa_activations = common_modules.Linear(
        c.extra_msa_channel,
        name='extra_msa_activations')(
            extra_msa_feat)

上述代码段功能是嵌入额外的 MSA 功能

   extra_msa_stack_input = {
        'msa': extra_msa_activations,
        'pair': pair_activations,
    }

    extra_msa_stack_iteration = EvoformerIteration(
        c.evoformer, gc, is_extra_msa=True, name='extra_msa_stack')

    def extra_msa_stack_fn(x):
      act, safe_key = x
      safe_key, safe_subkey = safe_key.split()
      extra_evoformer_output = extra_msa_stack_iteration(
          activations=act,
          masks={
              'msa': batch['extra_msa_mask'],
              'pair': mask_2d
          },
          is_training=is_training,
          safe_key=safe_subkey)
      return (extra_evoformer_output, safe_key)

    if gc.use_remat:
      extra_msa_stack_fn = hk.remat(extra_msa_stack_fn)

    extra_msa_stack = layer_stack.layer_stack(
        c.extra_msa_stack_num_block)(
            extra_msa_stack_fn)
    extra_msa_output, safe_key = extra_msa_stack(
        (extra_msa_stack_input, safe_key))

    pair_activations = extra_msa_output['pair']

    evoformer_input = {
        'msa': msa_activations,
        'pair': pair_activations,
    }

    evoformer_masks = {'msa': batch['msa_mask'], 'pair': mask_2d}

未聚类MSA序列特征线性投影形成初始表示,这些表示用包含4个块的额外MSA堆栈进行处理,而上面的代码段就是写额外MSA堆栈的。这个额外的MSA堆栈与主要的Evoformer块非常相似,具有使用全局列级自我注意和更小的表示大小来处理大量序列的显著差异。最终的对表示被用作主Evoformer堆栈的输入,而最终的MSA激活没有使用.

    if c.template.enabled and c.template.embed_torsion_angles:
      num_templ, num_res = batch['template_aatype'].shape
 

这if后的包括未列出的代码都是使用模板嵌入,将num_templ rows附加到 msa_activations

 aatype_one_hot = jax.nn.one_hot(batch['template_aatype'], 22, axis=-1)

这是在嵌入模板 aatypes

      ret = all_atom.atom37_to_torsion_angles(
          aatype=batch['template_aatype'],
          all_atom_pos=batch['template_all_atom_positions'],
          all_atom_mask=batch['template_all_atom_masks'],
          # Ensure consistent behaviour during testing:
          placeholder_for_undefined=not gc.zero_init)

      template_features = jnp.concatenate([
          aatype_one_hot,
          jnp.reshape(
              ret['torsion_angles_sin_cos'], [num_templ, num_res, 14]),
          jnp.reshape(
              ret['alt_torsion_angles_sin_cos'], [num_templ, num_res, 14]),
          ret['torsion_angles_mask']], axis=-1)

      template_activations = common_modules.Linear(
          c.msa_channel,
          initializer='relu',
          name='template_single_embedding')(
              template_features)
      template_activations = jax.nn.relu(template_activations)
      template_activations = common_modules.Linear(
          c.msa_channel,
          initializer='relu',
          name='template_projection')(
              template_activations)

这是在嵌入模板 aatype、扭转角和掩码
并且形状Shape (templates, residues, msa_channels)

      evoformer_input['msa'] = jnp.concatenate(
          [evoformer_input['msa'], template_activations], axis=0)

这是要将模板连接到 msa

      torsion_angle_mask = ret['torsion_angles_mask'][:, :, 2]
      torsion_angle_mask = torsion_angle_mask.astype(
          evoformer_masks['msa'].dtype)
      evoformer_masks['msa'] = jnp.concatenate(
          [evoformer_masks['msa'], torsion_angle_mask], axis=0)

将模板掩码连接到 msa 掩码
从 psi 角度使用掩码,因为它只取决于骨架原子

 evoformer_iteration = EvoformerIteration(
        c.evoformer, gc, is_extra_msa=False, name='evoformer_iteration')

这是在搭建网络主干,调用了前面的EvoformerIteration


总结

  这样结合源代码,论文和伪代码就把class EmbeddingsAndEvoformer基本上看完了。之前的class EvoformerIteration主要实现了一个Evoformer Block,class EmbeddingsAndEvoformer则包含了input embedding和48个Evoformer Block,EvoformerIteration在EmbeddingsAndEvoformer中也用到了。


上一篇:POJ-2676 Sudoku(简单数独-dfs深搜)


下一篇:Mysql Explain Extra