在这篇文章中,我们介绍并分解了Distil Whisper:一个新版本,它为音频转录提供了高达6倍的Whisper模型运行速度。
深度学习技术一直在快速发展,并已成为我们日常生活中的关键参与者,尤其是在这个语音到文本应用的时代。无论是为自动人工智能呼叫系统、SIRI或Alexa等语音助手供电,还是与搜索引擎无缝集成:这一功能都显著增强了用户体验。它的广泛采用使它成为我们生活中不可或缺的一部分。
作为开源人工智能领域的有力竞争者,音频语音识别(ASR)模型Whisper by OpenAI获得了巨大的人气。它的有效性水平与其他生产级模型相当,同时用户可以零成本访问。此外,它还为用户提供了一系列预先训练的模型,以利用人工智能的力量转录和翻译任何音频片段。
在这篇文章中,我们将看看最近发布的Distil Whisper项目。Whisper型号的最新迭代提供了高达6倍的运行速度。在本文中,我们将更深入地研究这个模型版本,是什么使它成为可能,然后以代码演示结束。
花点时间浏览Paperspace提供的关于Whisper的综合文章。此外,请点击演示链接,利用Paperspace的免费GPU服务亲身体验该模型。
什么是知识蒸馏(KD)?
在我们深入研究模型本身之前,让我们讨论一下是什么使Distil Whisper的加速成为可能。知识提炼(KD)是指训练一个较小且计算高效的模型的过程,也称为试图模仿较大且更复杂模型或教师行为的学生。从本质上讲,它是一种模型压缩形式,有助于将知识从一个较大的模型转移到训练一个较小的模型,而不会有任何显著的性能损失。在这里,知识指的是学习到的权重和偏差,它们代表了训练模型中的模式理解。
大模型又称教师,接受感兴趣任务的训练,如NLP任务、图像识别等。这种深度学习模型在计算上非常昂贵。接下来,创建一个学生模型,并对其进行相同任务的训练,该模型保留了教师模型的知识。在这里,关键思想是使用教师的模型预测,即软化的概率或logits,作为训练学生模型的目标。
在训练过程中,学生模型不仅旨在模仿教师模型的最终预测,还旨在模仿嵌入中间步骤的知识。这种知识转移有助于学生模型更好地概括和执行任务,同时降低复杂性。
这种模型蒸馏已被证明在模型大小和计算要求方面显著减少,性能退化最小甚至没有。
在Distil Whisper的情况下,教师模型为Whisper,学生模型为Distil Whipper。两个模型共享相同的Seq2Seq架构,但维度不同。
Distil模型
现在,让我们来看看Distil Whisper模型本身。首先,重要的是要了解新模型发布与原始模型发布的区别。以下简要讨论了研究论文中提出的压缩模型的主要变化:
收缩和微调:对于Distilled模型,研究人员实现了基于层的压缩。这是通过从教师模型中最大间隔的层复制权重来初始化学生模型来完成的。例如,当基于32层教师模型建立2层学生模型时,从教师到学生的第一层和第32层的权重被复制。
伪标记:这种形式的蒸馏也可以被视为“序列级”KD,在这个过程中,知识以序列的形式转移到学生模型中。该序列在伪标签中生成。
Kullback-Leibler散度(Kullback-Leibler Divergence:):在KL散度中,学生模型的完全概率分布被训练为与教师模型的分布一致。这种对齐是通过最小化在第i个位置的整个潜在下一个令牌集合上的Kullback-Leibler(KL)分歧来实现的。这可以被解释为“单词级”知识提炼,其中知识通过与潜在标记相关联的logits从教师模型传递到学生模型。
Distil Whisper
自然语言处理(NLP)的最新发展表明,在基于转换器的模型压缩方面取得了重大进展。已经观察到知识蒸馏(KD)在减少BERT等模型的大小方面的成功应用,而没有任何显著的性能损失。Distil Whisper是Whisper的精简版,拥有显著的增强功能-速度快6倍,体积小49%,在分发外的评估集上实现了1%的单词错误率(WER)以内的性能水平。
为了实现这一点,特别值得注意的是,训练目标被优化为最小化提取模型和Whisper模型之间的KL差异,以及在伪标记音频数据上计算的交叉熵损失。
Distil Whisper在22k小时的伪标记音频数据上进行训练,该数据由10个域组成,扬声器数量超过18k。
Distil Whisper有什么新功能?
为了确保训练只包含可靠的伪标签,引入了一种直接的启发式方法来细化伪标签训练数据集。对于每个训练样本,使用Whisper英语归一化器对Whisper生成的基本事实标签和伪标签进行归一化。一旦完成,就计算归一化的基本事实和伪标签之间的字错误率(WER)。超过给定WER阈值的样本将被丢弃。这种过滤方法提高了转录的质量和模型性能。
Whisper的原始论文介绍了一种长格式转录算法,该算法系统地转录30秒的音频片段,并根据模型预测的时间戳调整滑动窗口。在Distil Whisper中,使用了一种替代策略,将长文件音频分块成更小的片段,中间有小的重叠相邻片段。该模型处理每个块,并且通过识别重叠部分之间最长的公共序列来每隔一段时间连接推断出的文本。这种跨步有助于跨块精确转录,而不需要顺序转录。
推测解码(SD)是一种通过引入更快的辅助模型来加快自回归变换器模型推理过程的方法。通过利用更快的辅助模型进行生成,并将验证仅向前传递到主模型,解码过程经历了显著的加速。SD有助于生成与主模型生成的令牌序列相匹配的输出。使用Distil Whisper作为Whisper模型的助手也采用了相同的方法。
推测解码在确保数学输出相同的同时,提供了显著的延迟改进。这使其成为现有Whisper管道的无缝且合乎逻辑的替代品。
架构
下图是Distil Whisper模型的结构示意图。编码器以绿色表示,完全从老师复制到学生,并在培训期间保持固定。学生的解码器仅包括两个解码器层,从教师的初始和最终解码器层初始化(用红色表示)。省略了教师的所有其他解码器层。
该模型基于KL散度和PL损失项的加权组合进行训练。在推理过程中,它能够使用它来顺序地识别关于文本的潜在编码和音频的下一个最可能的令牌。首先,波形音频片段被输入到编码器模块。音频是相对于其中的时间位置进行编码的。解码器块然后能够顺序地处理编码的输入令牌。然后,解码器块将该编码与输入序列中的前一个令牌一起进行,在开始时使用序列开始(BOS)令牌,将输出解码为字符串。
能力
Distil Whisper旨在取代Whisper进行英语语音识别。Distil Whisper的功能基本上可以归结为5个主要的关键功能:
更快的推理:实现六倍的推理速度,同时将性能保持在分发外音频的Whisper的1%字错误率(WER)以内。
对噪音和幻觉的鲁棒性:情节显示,随着噪音变得更加强烈,与在LibriSpeech语料库上训练的其他模型相比,Distil Whisper的WER’S退化得不那么严重。
与Whisper相比,重复5克单词重复的次数减少了1.3倍,插入错误率(IER)降低了2.1%。这表明,与最初的Whisper模型相比,Distil Whisper的幻觉程度有所降低。大v2和distil-largev2的平均删除错误率(DER)保持可比性,性能差异约为0.3%。
专为推测解码而设计:Distil Whisper作为Whisper的辅助模型,在数学上保证Whisper模型的输出相同的同时,推理速度提高了两倍。
商业许可:Distil Whisper已获得许可,可用于商业应用。
代码演示
根据本指南,我们可以运行Distil Whisper模型,并在很短的时间内转录语音的音频样本。此外,通过使用各种Paperspace GPU,可以预期性能会得到增强。
要运行该模型,请首先安装最新版本的变形金刚库。该型号支持4.35及以上版本的变形金刚。
#Install the dependencies
!pip install --upgrade pip
!pip install --upgrade transformers accelerate datasets[audio]
简式转录
短格式转录包括转录持续时间不到30秒的音频样本,这与Whisper模型的最大感受野一致。
使用AutoModelForSpeechSeq2Seq和AutoProcessor类加载Distil Whisper。
import torch
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline
device = "cuda:0" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
model_id = "distil-whisper/distil-large-v2"
model = AutoModelForSpeechSeq2Seq.from_pretrained(
model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True
)
model.to(device)
processor = AutoProcessor.from_pretrained(model_id)
接下来,将模型和处理器传递到管道
pipe = pipeline(
"automatic-speech-recognition",
model=model,
tokenizer=processor.tokenizer,
feature_extractor=processor.feature_extractor,
max_new_tokens=128,
torch_dtype=torch_dtype,
device=device,
)
从LibriSpeech语料库加载数据集,
从数据集导入load_dataset
from datasets import load_dataset
dataset = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
sample = dataset[0]["audio"]
调用管道来转录样本音频,
result = pipe(sample)
print(result["text"])
要转录本地存储的示例音频,请确保传递文件的路径。
result = pipe("path_to_the_audio.mp3")
print(result["text"])
长格式转录
要转录长音频(超过30秒),Distil Whisper使用分块算法。在这里,我们将使用从目录中保存的长格式音频。
再次加载模型和处理器:
import torch from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline device = "cuda:0" if torch.cuda.is_available() else "cpu" torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32 model_id = "distil-whisper/distil-large-v2" model = AutoModelForSpeechSeq2Seq.from_pretrained( model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True ) model.to(device) processor = AutoProcessor.from_pretrained(model_id)
为了启用分块,我们将在管道中使用chunk_length_s参数。对于Distil Whisper,最小区块长度为15秒。为了激活批处理,请包含batch_size参数。
pipe = pipeline( "automatic-speech-recognition", model=model, tokenizer=processor.tokenizer, feature_extractor=processor.feature_extractor, max_new_tokens=128, chunk_length_s=15, batch_size=16, torch_dtype=torch_dtype, device=device, )
现在,我们将加载一个很长的音频样本,该样本已存储在目录中,以方便您使用。将路径传递到要转录的已保存音频文件。也可以随意将您选择的任何mp3样本上传到目录中,并使用此代码演示进行转录。
result = pipe('/content/I_used_LLaMA_2_70B_to_rebuild_GPT_Banker...and_its_AMAZING_(LLM_RAG).mp3') print(result["text"])
导入textwrap库,我们可以使用该库以格式化段落的形式查看结果。
import textwrap wrapper = textwrap.TextWrapper(width=80, initial_indent=" " * 8, subsequent_indent=" " * 8, break_long_words=False, break_on_hyphens=False) print(wrapper.fill(result["text"]))
推测性解码(Speculative Decoding)
推测性解码保证了与Whisper模型类似的输出,但实现速度是Whisper的两倍。这一特性使Distil Whisper成为当前Whisper管道的理想无缝替代品,确保了一致的结果,同时提高了效率。
对于推测解码,我们需要教师和学生模型。下面的代码演示了使用Paperspace平台的推测性解码。
加载教师模型“openai/whisper-large-v2”和处理器。
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor import torch device = "cuda:0" if torch.cuda.is_available() else "cpu" torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32 model_id = "openai/whisper-large-v2" model = AutoModelForSpeechSeq2Seq.from_pretrained( model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True ) model.to(device) processor = AutoProcessor.from_pretrained(model_id)
接下来,加载学生模型。Distil Whisper与教师模型共享完全相同的编码器,只需加载2层解码器,有效地将其视为独立的“仅解码器”模型。
from transformers import AutoModelForCausalLM assistant_model_id = "distil-whisper/distil-large-v2" assistant_model = AutoModelForCausalLM.from_pretrained( assistant_model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True ) assistant_model.to(device)
将学生模型传递到管道,
pipe = pipeline(
"automatic-speech-recognition",
model=model,
tokenizer=processor.tokenizer,
feature_extractor=processor.feature_extractor,
max_new_tokens=128,
generate_kwargs={"assistant_model": assistant_model},
torch_dtype=torch_dtype,
device=device,
)
一旦完成,通过待转录的样本,
从数据集导入load_dataset
from datasets import load_dataset
dataset = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
sample = dataset[0]["audio"]
result = pipe(sample)
print(result["text"])
如需进一步优化,请使用Flash Attention 2
!pip install flash-attn --no-build-isolation
要激活Flash Attention 2,只需在初始化期间将参数use_Flash_address_2=True传递给from_pretrained函数即可。
如果不支持GPU,请使用BetterTransformers。要执行此操作,请进行优化安装。
!pip install --upgrade optimum
下面的代码将模型转换为“BetterTransformer”模型,
model = AutoModelForSpeechSeq2Seq.from_pretrained(model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True)
model = model.to_bettertransformer()
结束思考
在这篇文章中,我们介绍了Distil Whisper,它是Whisper的精炼和加速版本。Distil Whisper是一款令人印象深刻的车型,也是测试应用程序的优秀候选者。在分发外的长格式音频中,DistilWhisper超过了Whisper,出现幻觉和重复的次数更少。这突出了大规模伪标记在提取ASR模型中的有效性,尤其是当与我们的字错误率(WER)阈值滤波器相结合时。我们使用Paperspace平台进一步演示了Distil Whisper,并无缝地使用该模型来转录英语的长格式和短格式音频。
请务必浏览原始论文和Github项目页面,了解有关创建这个令人敬畏的模型所涉及的研究的更多信息。
References
- Original Research Paper : DISTIL-WHISPER: ROBUST KNOWLEDGE DISTILLATION VIA LARGE-SCALE PSEUDO LABELLING
- Code reference Hugging Face github repo : distil-whisper
- Whisper blog post on Paperspace: Create your own speech to text application with Whisper from OpenAI and Flask
- 登录 发表评论