CTC前缀分数计算
CTC(连接主义的时序分类)是一种在长度不同的序列中计算损失的方式。 对于不定长的模型输出和标签,在没有给定对齐的情况下计算概率和梯度,从而进行模型的训练。
在基于Attention的模型中,使用hybrid ctc+attention的方式训练,在解码过程中,进一步利用训练时CTC头部的信息,可以计算CTC前缀得分,加入到Beam Search解码中。这里前缀得分的计算方式与CTC Loss的前向后向算法的前向部分比较类似。
首先需要明确两个概念:规整字符串和CTC字符串。
在使用CTC进行训练的时候,前者表示训练集的标签,后者表示序列模型的输出。 即CTC字符串经过合并重复和去除Blank之后得到的结果。
在训练过程中,CTC损失计算可以合并得到规整字符串$y = [y_1, y_2, ..., y_N]$的所有CTC字符串$z = [z_1, z_2, ..., z_T]$的概率。
在预测的过程中,在每一个时间步$t$,给定之前$t-1$步计算得到的前缀$g = [y_1, y_2, ..., y_c]$和预测结果$c = y_p$,计算CTC字符串$z_1, z_2, ..., z_t$预测得到给定规整字符串$h=gc$的概率,因此计算过程类似CTC损失的前向部分。
在时间步$t$增加的时候,需要维护两个中间变量 $p^{b}\_{t} \in {\mathbb R}^{T}$ 和 $p^{n}\_{t} \in {\mathbb R}^{T}$。
其中 $p^{b}\_{t}$ 表示 $z[:t]$ 即以空白符号结尾的长度为 $t$ 的CTC字符串规整到目标的概率,$p^{n}\_{t}$ 表示 $z[:t]$ 即非空白符号结尾的长度为 $t$ 的CTC字符串的概率。
这一步的中间变量是为了简化计算的过程,因为在$h=gc$的时候,可以根据 $p\_{t-1}^{b}(g)$计算得到 $p^b\_{t}(h)$。
在初始化时,目标前缀字符串为空串。
- $p^{b}\_{t}=\prod_{i=0}^{t} p(z_i=blank)$,即长度为$t$的CTC字符串解析到空串的概率为每一位解析为blank的概率乘积。
- $p^{n}_{t}=0$,即任何非空白符号结尾的CTC字符串都不可能解析到空串。
在遍历时,给定前缀为$g = [y_1, y_2, ..., y_c]$和预测结果$c = y_p$,计算并返回$h=gc$的概率。
- $y_c \neq y_p$时,返回值为 $[p^{b}\_{t-1}(g)+p^n\_{t-1}(g)]p(y_p)$,即$z[:t-1]$可以规整为$g$和$g -$的概率与$z_t = y_p$的概率乘积。两者相乘表示CTC 字符串$z[:t]$可以规整为$h=gc$的概率。
- $y_c=y_p$时,返回值为$[p^b_{t-1}(g)]p(y_p)$,由于前缀$g$最后一个字符与预测字符相同,只有以blank结尾的$z[:t-1]$序列可以规整到$h=gc$。
同时需要更新中间变量,即计算$p_{t}^{b}(h)$和$p_{t}^{n}(h)$,需要分情况讨论。
- 更新 $p^{b}\_{t}(h)$:$p^{b}\_{t}(h) = [p^{b}\_{t-1}(g) + p^{n}\_{t-1}(g)]p(-)$,即$z[:t]$规整为h的概率,可以使用$z[:t-1]$规整为g的概率与当前位置预测为空的概率计算。
- 更新$p^{n}_{t}(h)$:
- 如果$y_c = y_p$时:$p_{t}^{n}(h) = \left[p_{t-1}^{b}(g) + p_{t-1}^{n}(h)\right]p(y_p)$,第一项表示以blank结尾的前缀,后一项表示重复符号的情况。
- 如果$y_c \neq y_p$时:$p_{t}^{n}(h) = \left[p_{t-1}^{b}(g) + p_{t-1}^{n}(g) + p_{t-1}^{n}(h)\right]p(y_p)$,分别表示$z[:t-1]$规整到$g$和$h$的情况。
代码
下面结合代码进行分析。
"""
一个简单的使用attention+ctc的混合解码的代码实现
"""
import torch
from torch import nn, Tensor
from datasets import const
def ctc_prefix_score(
ctc_probs, seq_len, prefix, next_word, prev_state
):
assert ctc_probs.shape[0] == 1
assert ctc_probs[0, 0].exp().sum() > 0.99, \
f"sum up as {ctc_probs[0, 0].exp().sum()}"
# 获得前缀的长度
prev_length = prefix.shape[-1]
# 获得不同长度CTC字符串得到前缀的概率p^n(g)和p^a(g)
gamma_nbk = prev_state[0].clone()
gamma_blk = prev_state[1].clone()
# 计算前缀的概率,包括结尾为blank和非blank两种情况的概率求和
prev_sum = torch.logaddexp(gamma_nbk, gamma_blk)
if prev_length > 0:
if prefix[-1] != next_word:
log_phi = prev_sum
else:
log_phi = gamma_blk
else:
log_phi = prev_sum
# 达到EOS的时候直接输出
if next_word == const.EOS:
psi = prev_sum[-1]
return psi, None, None
if prev_length == 0:
# 如果前缀为空串,可以直接计算
gamma_nbk[0] = ctc_probs[0, 0, next_word]
gamma_blk[0] = -1e10
else:
# 对于长度小于n的CTC字符串,无法得到长度为n的规整字符串
# 严格说要将[0 : prev_length-1]都设置为0
# 但是之前的部分不参与迭代的计算
gamma_nbk[prev_length-1] = -1e10
gamma_blk[prev_length-1] = -1e10
# 确认计算开始的位置
start = max(1, prev_length)
psi = gamma_nbk[start-1]
for t in range(start, seq_len):
# 更新p^n_t(h)为p^n_{t-1}(h)和phi的和
# 即上一时刻得到h的概率乘上当前位置重复符号的概率
# 加上上一时刻得到g的概率
gamma_nbk[t] = torch.logaddexp(
gamma_nbk[t-1], log_phi[t-1]
) + ctc_probs[0, t, next_word]
# 更新p^b_t(h)为p^b_{t-1}(h)和p^n_{t-1}(h)的和
# 即在上一时刻就得到h的概率
# 再乘上当前位置预测得到空串的概率
gamma_blk[t] = torch.logaddexp(
gamma_blk[t-1], gamma_nbk[t-1]
) + ctc_probs[0, t, const.PAD]
# 输出的序列概率值为上一时刻得到h的概率
# 加上上一时刻得到g,当前位置预测c的概率。
psi = torch.logaddexp(
psi,
log_phi[t-1] + ctc_probs[0, t, next_word]
)
return psi, gamma_nbk, gamma_blk
@torch.no_grad()
def hybrid_decoder(
m_transformer: nn.Module,
m_embed: nn.Module,
m_mlp: nn.Module,
ctc_head: nn.Module,
embed_src_seq: Tensor,
src_padding_mask: Tensor,
mask_list,
beam_size: int,
max_len: int,
device: torch.DeviceObjType,
ctc_weight: int = 0.1
):
assert embed_src_seq.shape[0] == 1
assert embed_src_seq.shape[2] == m_embed.emb_size
encode_memory = m_transformer.encode(
embed_src=embed_src_seq, src_mask=None,
src_key_padding_mask=src_padding_mask,
)
# 获得输入序列的长度
seq_len = (~src_padding_mask).sum(dim=1)[0]
# 计算得到CTC头部的预测结果
ctc_probs = ctc_head.mlp(encode_memory).log_softmax(-1)
# 初始化CTC中间变量
gamma_nbk = torch.full((seq_len, ), -1e6, device=device)
gamma_blk = torch.full((seq_len, ), -1e6, device=device)
for i in range(seq_len):
if i == 0:
gamma_blk[i] = ctc_probs[0, i, const.PAD]
else:
gamma_blk[i] = gamma_blk[i-1] + ctc_probs[0, i, const.PAD]
# 初始化所有假设
total_hypos = [
(
torch.tensor([const.BOS], dtype=torch.long, device=device),
torch.tensor(0, device=device),
(gamma_nbk, gamma_blk),
)
]
eos_hypos = []
# 限制解码长度
decode_len = min(max_len, seq_len)
for decode_idx in range(1, decode_len):
# 记录每次迭代得到的K**2个hypos
running_hypos = []
for cur_hyp in total_hypos:
# 获得前缀开始处理
prefix = cur_hyp[0].unsqueeze(0) # [1, decode_idx]
embed_prefix = m_embed(prefix)
attn_mask = mask_list[prefix.shape[1] - 1]
probs = m_transformer.decode(
embed_tgt=embed_prefix, memory=encode_memory,
tgt_mask=attn_mask,
tgt_key_padding_mask=None,
memory_key_padding_mask=src_padding_mask,
)
probs = m_mlp(probs[0, -1, :]).log_softmax(dim=-1)
best_score, best_idx = probs.topk(
k=beam_size, dim=-1
)
for k_idx, next_word in enumerate(best_idx):
# 对于每一种可能的结果打分
attn_score = cur_hyp[1] + best_score[k_idx]
ctc_score, gamma_nbk, gamma_blk = \
ctc_prefix_score(
ctc_probs, seq_len,
prefix=cur_hyp[0][1:],
next_word=next_word,
prev_state=cur_hyp[2]
)
total_score = (1 - ctc_weight) * attn_score \
+ ctc_weight * ctc_score
total_seq = torch.cat([cur_hyp[0], next_word.unsqueeze(0)])
if next_word.data != const.EOS:
running_hypos.append((total_seq, total_score))
else:
# 找到了EOS
eos_hypos.append((total_seq, total_score))
# 从K**2个hypo更新总的hypo
total_hypos = sorted(running_hypos, key=lambda x: x[1], reverse=True)[:beam_size]
# 结束寻找
if len(eos_hypos) > beam_size * 2: break
if len(eos_hypos) == 0:
# 如果没有EOS,就从total里面抓一部分出来
eos_hypos.extend(
sorted(total_hypos, key=lambda x: x[1], reverse=True)[:2*beam_size]
)
# 添加length penalty
best_hypo = sorted(
eos_hypos, key=lambda x: x[1] / (x[0].shape[0] ** 0.6), reverse=True
)[0]
pred_seq = best_hypo[0].unsqueeze(0)
pred_seq_len = torch.tensor(pred_seq.shape[1], device=device).unsqueeze(0)
return pred_seq, pred_seq_len, encode_memory
代码参考包括: