CTC前缀分数计算

Feb 25, 2024·
Lei Yang
Lei Yang
· 3 min read
blog

CTC(连接主义的时序分类)是一种在长度不同的序列中计算损失的方式。 对于不定长的模型输出和标签,在没有给定对齐的情况下计算概率和梯度,从而进行模型的训练。

在基于Attention的模型中,使用hybrid ctc+attention的方式训练,在解码过程中,进一步利用训练时CTC头部的信息,可以计算CTC前缀得分,加入到Beam Search解码中。这里前缀得分的计算方式与CTC Loss的前向后向算法的前向部分比较类似。

首先需要明确两个概念:规整字符串和CTC字符串。

在使用CTC进行训练的时候,前者表示训练集的标签,后者表示序列模型的输出。 即CTC字符串经过合并重复和去除Blank之后得到的结果。

在训练过程中,CTC损失计算可以合并得到规整字符串 $y = [y_1, y_2, ..., y_N]$ 的所有CTC字符串$z = [z_1, z_2, ..., z_T]$的概率。

在预测的过程中,在每一个时间步$t$,给定之前$t-1$步计算得到的前缀$g = [y_1, y_2, ..., y_c]$和预测结果$c = y_p$,计算CTC字符串$z_1, z_2, ..., z_t$预测得到给定规整字符串$h=gc$的概率,因此计算过程类似CTC损失的前向部分。

在时间步$t$增加的时候,需要维护两个中间变量 $p^{b}\_{t} \in {\mathbb R}^{T}$ 和 $p^{n}\_{t} \in {\mathbb R}^{T}$。

其中 $p^{b}\_{t}$ 表示 $z[:t]$ 即以空白符号结尾的长度为 $t$ 的CTC字符串规整到目标的概率,$p^{n}\_{t}$ 表示 $z[:t]$ 即非空白符号结尾的长度为 $t$ 的CTC字符串的概率。

这一步的中间变量是为了简化计算的过程,因为在$h=gc$的时候,可以根据 $p\_{t-1}^{b}(g)$计算得到 $p^b\_{t}(h)$。

在初始化时,目标前缀字符串为空串。

  1. $p^{b}\_{t}=\prod_{i=0}^{t} p(z_i=blank)$,即长度为$t$的CTC字符串解析到空串的概率为每一位解析为blank的概率乘积。
  2. $p^{n}_{t}=0$,即任何非空白符号结尾的CTC字符串都不可能解析到空串。

在遍历时,给定前缀为$g = [y_1, y_2, ..., y_c]$和预测结果$c = y_p$,计算并返回$h=gc$的概率。

  1. $y_c \neq y_p$时,返回值为 $[p^{b}\_{t-1}(g)+p^n\_{t-1}(g)]p(y_p)$,即$z[:t-1]$可以规整为$g$和$g -$的概率与$z_t = y_p$的概率乘积。两者相乘表示CTC 字符串$z[:t]$可以规整为$h=gc$的概率。
  2. $y_c=y_p$时,返回值为$[p^b_{t-1}(g)]p(y_p)$,由于前缀$g$最后一个字符与预测字符相同,只有以blank结尾的$z[:t-1]$序列可以规整到$h=gc$。

同时需要更新中间变量,即计算$p_{t}^{b}(h)$和$p_{t}^{n}(h)$,需要分情况讨论。

  1. 更新 $p^{b}\_{t}(h)$:$p^{b}\_{t}(h) = [p^{b}\_{t-1}(g) + p^{n}\_{t-1}(g)]p(-)$,即$z[:t]$规整为h的概率,可以使用$z[:t-1]$规整为g的概率与当前位置预测为空的概率计算。
  2. 更新$p^{n}_{t}(h)$:
    1. 如果$y_c = y_p$时:$p_{t}^{n}(h) = \left[p_{t-1}^{b}(g) + p_{t-1}^{n}(h)\right]p(y_p)$,第一项表示以blank结尾的前缀,后一项表示重复符号的情况。
    2. 如果$y_c \neq y_p$时:$p_{t}^{n}(h) = \left[p_{t-1}^{b}(g) + p_{t-1}^{n}(g) + p_{t-1}^{n}(h)\right]p(y_p)$,分别表示$z[:t-1]$规整到$g$和$h$的情况。

代码

下面结合代码进行分析。

"""
一个简单的使用attention+ctc的混合解码的代码实现
"""
import torch
from torch import nn, Tensor
from datasets import const


def ctc_prefix_score(
    ctc_probs, seq_len, prefix, next_word, prev_state
):
    assert ctc_probs.shape[0] == 1
    assert ctc_probs[0, 0].exp().sum() > 0.99, \
        f"sum up as {ctc_probs[0, 0].exp().sum()}"

    # 获得前缀的长度
    prev_length = prefix.shape[-1]

    # 获得不同长度CTC字符串得到前缀的概率p^n(g)和p^a(g)
    gamma_nbk = prev_state[0].clone()
    gamma_blk = prev_state[1].clone()

    # 计算前缀的概率,包括结尾为blank和非blank两种情况的概率求和
    prev_sum = torch.logaddexp(gamma_nbk, gamma_blk)

    if prev_length > 0:
        if prefix[-1] != next_word:
            log_phi = prev_sum
        else:
            log_phi = gamma_blk
    else:
        log_phi = prev_sum

    # 达到EOS的时候直接输出
    if next_word == const.EOS:
        psi = prev_sum[-1]
        return psi, None, None
    
    if prev_length == 0:
        # 如果前缀为空串,可以直接计算    
        gamma_nbk[0] = ctc_probs[0, 0, next_word]
        gamma_blk[0] = -1e10
    else:
        # 对于长度小于n的CTC字符串,无法得到长度为n的规整字符串
        # 严格说要将[0 : prev_length-1]都设置为0
        # 但是之前的部分不参与迭代的计算
        gamma_nbk[prev_length-1] = -1e10
        gamma_blk[prev_length-1] = -1e10

    # 确认计算开始的位置
    start = max(1, prev_length)
    psi = gamma_nbk[start-1]

    for t in range(start, seq_len):
        # 更新p^n_t(h)为p^n_{t-1}(h)和phi的和
        # 即上一时刻得到h的概率乘上当前位置重复符号的概率
        # 加上上一时刻得到g的概率
        gamma_nbk[t] = torch.logaddexp(
            gamma_nbk[t-1], log_phi[t-1]
        ) + ctc_probs[0, t, next_word]
        # 更新p^b_t(h)为p^b_{t-1}(h)和p^n_{t-1}(h)的和
        # 即在上一时刻就得到h的概率
        # 再乘上当前位置预测得到空串的概率
        gamma_blk[t] = torch.logaddexp(
            gamma_blk[t-1], gamma_nbk[t-1]
        ) + ctc_probs[0, t, const.PAD]

        # 输出的序列概率值为上一时刻得到h的概率
        # 加上上一时刻得到g,当前位置预测c的概率。
        psi = torch.logaddexp(
            psi,
            log_phi[t-1] + ctc_probs[0, t, next_word]
        )

    return psi, gamma_nbk, gamma_blk


@torch.no_grad()
def hybrid_decoder(
    m_transformer: nn.Module,
    m_embed: nn.Module,
    m_mlp: nn.Module,
    ctc_head: nn.Module,
    embed_src_seq: Tensor,
    src_padding_mask: Tensor,
    mask_list,
    beam_size: int,
    max_len: int,
    device: torch.DeviceObjType,
    ctc_weight: int = 0.1
):
    
    assert embed_src_seq.shape[0] == 1
    assert embed_src_seq.shape[2] == m_embed.emb_size

    encode_memory = m_transformer.encode(
        embed_src=embed_src_seq, src_mask=None,
        src_key_padding_mask=src_padding_mask,
    )

    # 获得输入序列的长度
    seq_len = (~src_padding_mask).sum(dim=1)[0]

    # 计算得到CTC头部的预测结果
    ctc_probs = ctc_head.mlp(encode_memory).log_softmax(-1)

    # 初始化CTC中间变量
    gamma_nbk = torch.full((seq_len, ), -1e6, device=device)
    gamma_blk = torch.full((seq_len, ), -1e6, device=device)
    
    for i in range(seq_len):
        if i == 0:
            gamma_blk[i] = ctc_probs[0, i, const.PAD]
        else:
            gamma_blk[i] = gamma_blk[i-1] + ctc_probs[0, i, const.PAD]

    # 初始化所有假设
    total_hypos = [
        (
            torch.tensor([const.BOS], dtype=torch.long, device=device),    
            torch.tensor(0, device=device),    
            (gamma_nbk, gamma_blk),
        )
    ]

    eos_hypos = []

    # 限制解码长度
    decode_len = min(max_len, seq_len)
    
    for decode_idx in range(1, decode_len):
        
        # 记录每次迭代得到的K**2个hypos
        running_hypos = []
        
        for cur_hyp in total_hypos:
            # 获得前缀开始处理
            prefix = cur_hyp[0].unsqueeze(0) # [1, decode_idx]

            embed_prefix = m_embed(prefix)

            attn_mask = mask_list[prefix.shape[1] - 1]
            
            probs = m_transformer.decode(
                embed_tgt=embed_prefix, memory=encode_memory,
                tgt_mask=attn_mask,
                tgt_key_padding_mask=None,
                memory_key_padding_mask=src_padding_mask,
            )
            
            probs = m_mlp(probs[0, -1, :]).log_softmax(dim=-1)

            best_score, best_idx = probs.topk
                k=beam_size, dim=-1
            )  

            for k_idx, next_word in enumerate(best_idx):
                
                # 对于每一种可能的结果打分
                
                attn_score = cur_hyp[1] + best_score[k_idx]

                ctc_score, gamma_nbk, gamma_blk = \
                    ctc_prefix_score(
                        ctc_probs, seq_len,
                        prefix=cur_hyp[0][1:],
                        next_word=next_word,
                        prev_state=cur_hyp[2]
                    )
                
                total_score = (1 - ctc_weight) * attn_score \

                            + ctc_weight * ctc_score

  
                
                total_seq = torch.cat([cur_hyp[0], next_word.unsqueeze(0)])
                if next_word.data != const.EOS:
                    running_hypos.append((total_seq, total_score))    
                else:
                    # 找到了EOS
                    eos_hypos.append((total_seq, total_score))

        # 从K**2个hypo更新总的hypo
        total_hypos = sorted(running_hypos, key=lambda x: x[1], reverse=True)[:beam_size]
        # 结束寻找
        if len(eos_hypos) > beam_size * 2: break
        
    if len(eos_hypos) == 0:
        # 如果没有EOS,就从total里面抓一部分出来
        eos_hypos.extend(
            sorted(total_hypos, key=lambda x: x[1], reverse=True)[:2*beam_size]
        )

    # 添加length penalty    
    best_hypo = sorted(
        eos_hypos, key=lambda x: x[1] / (x[0].shape[0] ** 0.6), reverse=True
    )[0]

    pred_seq = best_hypo[0].unsqueeze(0)
    pred_seq_len = torch.tensor(pred_seq.shape[1], device=device).unsqueeze(0
    
    return pred_seq, pred_seq_len, encode_memory

代码参考包括:

  1. prefix beam search
  2. SSL AVSR
  3. ESPNet
Lei Yang
Authors
PhD Candidate in SJTU
Lei Yang received the B.Eng. degree from Shanghai Jiao Tong University (SJTU) in 2021. He is currently pursuing the Ph.D. degree with the SEIEE, SJTU, under the supervision of Professor Shilin Wang. His research interests include computer vision, visual speech recognition, and semantics segmentation.