CTC前缀分数计算

Feb 25, 2024· · 3 min read

CTC(连接主义的时序分类)是一种在长度不同的序列中计算损失的方式。 对于不定长的模型输出和标签,在没有给定对齐的情况下计算概率和梯度,从而进行模型的训练。

在基于Attention的模型中,使用hybrid ctc+attention的方式训练,在解码过程中,进一步利用训练时CTC头部的信息,可以计算CTC前缀得分,加入到Beam Search解码中。这里前缀得分的计算方式与CTC Loss的前向后向算法的前向部分比较类似。

首先需要明确两个概念:规整字符串和CTC字符串。

在使用CTC进行训练的时候,前者表示训练集的标签,后者表示序列模型的输出。 即CTC字符串经过合并重复和去除Blank之后得到的结果。

在训练过程中,CTC损失计算可以合并得到规整字符串$y = [y_1, y_2, ..., y_N]$的所有CTC字符串$z = [z_1, z_2, ..., z_T]$的概率。

在预测的过程中,在每一个时间步$t$,给定之前$t-1$步计算得到的前缀$g = [y_1, y_2, ..., y_c]$和预测结果$c = y_p$,计算CTC字符串$z_1, z_2, ..., z_t$预测得到给定规整字符串$h=gc$的概率,因此计算过程类似CTC损失的前向部分。

在时间步$t$增加的时候,需要维护两个中间变量 $p^{b}\_{t} \in {\mathbb R}^{T}$ 和 $p^{n}\_{t} \in {\mathbb R}^{T}$。

其中 $p^{b}\_{t}$ 表示 $z[:t]$ 即以空白符号结尾的长度为 $t$ 的CTC字符串规整到目标的概率,$p^{n}\_{t}$ 表示 $z[:t]$ 即非空白符号结尾的长度为 $t$ 的CTC字符串的概率。

这一步的中间变量是为了简化计算的过程,因为在$h=gc$的时候,可以根据 $p\_{t-1}^{b}(g)$计算得到 $p^b\_{t}(h)$。

在初始化时,目标前缀字符串为空串。

  1. $p^{b}\_{t}=\prod_{i=0}^{t} p(z_i=blank)$,即长度为$t$的CTC字符串解析到空串的概率为每一位解析为blank的概率乘积。
  2. $p^{n}_{t}=0$,即任何非空白符号结尾的CTC字符串都不可能解析到空串。

在遍历时,给定前缀为$g = [y_1, y_2, ..., y_c]$和预测结果$c = y_p$,计算并返回$h=gc$的概率。

  1. $y_c \neq y_p$时,返回值为 $[p^{b}\_{t-1}(g)+p^n\_{t-1}(g)]p(y_p)$,即$z[:t-1]$可以规整为$g$和$g -$的概率与$z_t = y_p$的概率乘积。两者相乘表示CTC 字符串$z[:t]$可以规整为$h=gc$的概率。
  2. $y_c=y_p$时,返回值为$[p^b_{t-1}(g)]p(y_p)$,由于前缀$g$最后一个字符与预测字符相同,只有以blank结尾的$z[:t-1]$序列可以规整到$h=gc$。

同时需要更新中间变量,即计算$p_{t}^{b}(h)$和$p_{t}^{n}(h)$,需要分情况讨论。

  1. 更新 $p^{b}\_{t}(h)$:$p^{b}\_{t}(h) = [p^{b}\_{t-1}(g) + p^{n}\_{t-1}(g)]p(-)$,即$z[:t]$规整为h的概率,可以使用$z[:t-1]$规整为g的概率与当前位置预测为空的概率计算。
  2. 更新$p^{n}_{t}(h)$:
    1. 如果$y_c = y_p$时:$p_{t}^{n}(h) = \left[p_{t-1}^{b}(g) + p_{t-1}^{n}(h)\right]p(y_p)$,第一项表示以blank结尾的前缀,后一项表示重复符号的情况。
    2. 如果$y_c \neq y_p$时:$p_{t}^{n}(h) = \left[p_{t-1}^{b}(g) + p_{t-1}^{n}(g) + p_{t-1}^{n}(h)\right]p(y_p)$,分别表示$z[:t-1]$规整到$g$和$h$的情况。

代码

下面结合代码进行分析。

"""
一个简单的使用attention+ctc的混合解码的代码实现
"""
import torch
from torch import nn, Tensor
from datasets import const


def ctc_prefix_score(
    ctc_probs, seq_len, prefix, next_word, prev_state
):
    assert ctc_probs.shape[0] == 1
    assert ctc_probs[0, 0].exp().sum() > 0.99, \
        f"sum up as {ctc_probs[0, 0].exp().sum()}"

    # 获得前缀的长度
    prev_length = prefix.shape[-1]

    # 获得不同长度CTC字符串得到前缀的概率p^n(g)和p^a(g)
    gamma_nbk = prev_state[0].clone()
    gamma_blk = prev_state[1].clone()

    # 计算前缀的概率,包括结尾为blank和非blank两种情况的概率求和
    prev_sum = torch.logaddexp(gamma_nbk, gamma_blk)

    if prev_length > 0:
        if prefix[-1] != next_word:
            log_phi = prev_sum
        else:
            log_phi = gamma_blk
    else:
        log_phi = prev_sum

    # 达到EOS的时候直接输出
    if next_word == const.EOS:
        psi = prev_sum[-1]
        return psi, None, None
    
    if prev_length == 0:
        # 如果前缀为空串,可以直接计算    
        gamma_nbk[0] = ctc_probs[0, 0, next_word]
        gamma_blk[0] = -1e10
    else:
        # 对于长度小于n的CTC字符串,无法得到长度为n的规整字符串
        # 严格说要将[0 : prev_length-1]都设置为0
        # 但是之前的部分不参与迭代的计算
        gamma_nbk[prev_length-1] = -1e10
        gamma_blk[prev_length-1] = -1e10

    # 确认计算开始的位置
    start = max(1, prev_length)
    psi = gamma_nbk[start-1]

    for t in range(start, seq_len):
        # 更新p^n_t(h)为p^n_{t-1}(h)和phi的和
        # 即上一时刻得到h的概率乘上当前位置重复符号的概率
        # 加上上一时刻得到g的概率
        gamma_nbk[t] = torch.logaddexp(
            gamma_nbk[t-1], log_phi[t-1]
        ) + ctc_probs[0, t, next_word]
        # 更新p^b_t(h)为p^b_{t-1}(h)和p^n_{t-1}(h)的和
        # 即在上一时刻就得到h的概率
        # 再乘上当前位置预测得到空串的概率
        gamma_blk[t] = torch.logaddexp(
            gamma_blk[t-1], gamma_nbk[t-1]
        ) + ctc_probs[0, t, const.PAD]

        # 输出的序列概率值为上一时刻得到h的概率
        # 加上上一时刻得到g,当前位置预测c的概率。
        psi = torch.logaddexp(
            psi,
            log_phi[t-1] + ctc_probs[0, t, next_word]
        )

    return psi, gamma_nbk, gamma_blk


@torch.no_grad()
def hybrid_decoder(
    m_transformer: nn.Module,
    m_embed: nn.Module,
    m_mlp: nn.Module,
    ctc_head: nn.Module,
    embed_src_seq: Tensor,
    src_padding_mask: Tensor,
    mask_list,
    beam_size: int,
    max_len: int,
    device: torch.DeviceObjType,
    ctc_weight: int = 0.1
):
    
    assert embed_src_seq.shape[0] == 1
    assert embed_src_seq.shape[2] == m_embed.emb_size

    encode_memory = m_transformer.encode(
        embed_src=embed_src_seq, src_mask=None,
        src_key_padding_mask=src_padding_mask,
    )

    # 获得输入序列的长度
    seq_len = (~src_padding_mask).sum(dim=1)[0]

    # 计算得到CTC头部的预测结果
    ctc_probs = ctc_head.mlp(encode_memory).log_softmax(-1)

    # 初始化CTC中间变量
    gamma_nbk = torch.full((seq_len, ), -1e6, device=device)
    gamma_blk = torch.full((seq_len, ), -1e6, device=device)
    
    for i in range(seq_len):
        if i == 0:
            gamma_blk[i] = ctc_probs[0, i, const.PAD]
        else:
            gamma_blk[i] = gamma_blk[i-1] + ctc_probs[0, i, const.PAD]

    # 初始化所有假设
    total_hypos = [
        (
            torch.tensor([const.BOS], dtype=torch.long, device=device),    
            torch.tensor(0, device=device),    
            (gamma_nbk, gamma_blk),
        )
    ]

    eos_hypos = []

    # 限制解码长度
    decode_len = min(max_len, seq_len)
    
    for decode_idx in range(1, decode_len):
        
        # 记录每次迭代得到的K**2个hypos
        running_hypos = []
        
        for cur_hyp in total_hypos:
            # 获得前缀开始处理
            prefix = cur_hyp[0].unsqueeze(0) # [1, decode_idx]

            embed_prefix = m_embed(prefix)

            attn_mask = mask_list[prefix.shape[1] - 1]
            
            probs = m_transformer.decode(
                embed_tgt=embed_prefix, memory=encode_memory,
                tgt_mask=attn_mask,
                tgt_key_padding_mask=None,
                memory_key_padding_mask=src_padding_mask,
            )
            
            probs = m_mlp(probs[0, -1, :]).log_softmax(dim=-1)

            best_score, best_idx = probs.topk
                k=beam_size, dim=-1
            )  

            for k_idx, next_word in enumerate(best_idx):
                
                # 对于每一种可能的结果打分
                
                attn_score = cur_hyp[1] + best_score[k_idx]

                ctc_score, gamma_nbk, gamma_blk = \
                    ctc_prefix_score(
                        ctc_probs, seq_len,
                        prefix=cur_hyp[0][1:],
                        next_word=next_word,
                        prev_state=cur_hyp[2]
                    )
                
                total_score = (1 - ctc_weight) * attn_score \

                            + ctc_weight * ctc_score

  
                
                total_seq = torch.cat([cur_hyp[0], next_word.unsqueeze(0)])
                if next_word.data != const.EOS:
                    running_hypos.append((total_seq, total_score))    
                else:
                    # 找到了EOS
                    eos_hypos.append((total_seq, total_score))

        # 从K**2个hypo更新总的hypo
        total_hypos = sorted(running_hypos, key=lambda x: x[1], reverse=True)[:beam_size]
        # 结束寻找
        if len(eos_hypos) > beam_size * 2: break
        
    if len(eos_hypos) == 0:
        # 如果没有EOS,就从total里面抓一部分出来
        eos_hypos.extend(
            sorted(total_hypos, key=lambda x: x[1], reverse=True)[:2*beam_size]
        )

    # 添加length penalty    
    best_hypo = sorted(
        eos_hypos, key=lambda x: x[1] / (x[0].shape[0] ** 0.6), reverse=True
    )[0]

    pred_seq = best_hypo[0].unsqueeze(0)
    pred_seq_len = torch.tensor(pred_seq.shape[1], device=device).unsqueeze(0
    
    return pred_seq, pred_seq_len, encode_memory

代码参考包括:

  1. prefix beam search
  2. SSL AVSR
  3. ESPNet