小语言模型输入法设计

Posted by wyj on January 8, 2025

背景

在去年的六月,我忽然产生了一个idea:有没有一种可能,可以制作小语言模型的输入法?即使用语言模型产生的预测词汇作为输入法的选项。LLM一方面成本太高,另一方面太慢、硬件要求太高,可能没法在终端设备上跑起来;但是输入法由于有着输入拼音字母的限制,完全不需要做到LLM那么智能,就能预测好用户想要的下一个词汇。相比于语音输入,目前使用键盘的输入效率是非常低的,而语音输入并非所有场合都适用,因此输入法应该还有不少提升的空间。

虽说有了这样的想法,我当时并没有自己实现它的打算,主要是由于我对于AI并不感兴趣,也没怎么了解过相关技术。然而到了10月,我发现形势出现了巨大的变化:由于我想要申请CS方面的PhD,我发现很多我感兴趣的老师最近做的项目都和AI有关系,如果再像找暑研时一样把AI当做黑名单关键词,会给我自己带来很大的障碍。因此,我开始认为了解神经网络的基本知识、学会自己训练AI是必要的生存技能。那么该训练什么AI呢?我回想起来这个六月的idea,觉得还挺感兴趣的,因此想要实践一下。

tlt对于AI比较了解,因此我首先询问了tlt的意见。由于我之前没有深度学习的基础知识,tlt认为至少需要投入300小时才能完成这个项目。这真出乎我的想象,让我有点望而却步。然而那段时间申请的事情已经基本结束了,我实在是太无聊了,最终还是下决心开始实现这个项目。

准备

通过询问chatGPT,我对于Transformer的工作原理有了一个基本的了解。我认为,我需要做到一个双输入的、预测下一个token的语言模型:一个输入为上文,另一个输入为用户输入的首字母序列:这也被分为一些token,不会很多,因为中文的音节很少,仅有400个左右;前缀的种类数自然也不多。其预测的下一个token必须要满足首字母序列的第一个音节,且在此基础上,尽量符合前文和后续的首字母序列。比如:输入可能是“今天打算去”和“gy”,则模型理应回答“公”;一般的输入法不考虑上下文,可能因为“关于”的词频比“公园”更高,输入gy就会产生“关于”这一预测词,而这显然不对;而一般的LLM可能会回答“散步”,因为没有拼音gy的限制。

在开始做之前,我先列了一个todo list:

  1. 获取中文的Tokenizer和训练数据,使用pypinyin标注上拼音。由于该应用的特殊性,很可能只能一个字一个token。
  2. 编写简易的分词算法,把连续的拼音输入分为若干个字对应的拼音前缀;列出每个拼音前缀的全部可行token。
  3. 参考他人代码编写训练过程,但是要改用自定义的loss函数:如果预测的字和拼音不一样就额外惩罚它。
  4. 训练,得到模型。
  5. 进行推理测试,判断其智能程度;将预测下一个token概率的生成环节封装成生成预测词列表(反复预测直到耗尽输入的拼音序列),可能还有对应的概率。
  6. 压缩该模型,使得在普通电脑/手机上可用;且生成下一个token达到实时速度。

(事后来看,我并没有完全按照todo list执行,偷懒跳过了不少步骤,也受到了相应的惩罚)

实现

为了训练一个自己的模型,chatGPT给出的帮助还远远不够,需要对于Transformer有更多的了解。我主要是使用Hugging Face学到了这些知识。我打算使用GPT2模型来训练,因此主要参考了chinese_gpt2来实现中文GPT2的训练。

首先是下载数据,我使用了chinese_gpt2中提到的中文数据集。我靠数据是百度网盘的啊,这不得下到天荒地老?还是自己找找吧。但是,我搜到一个奇怪的百度网盘青春版,一辈子可以下载3次。那就下一次吧。太大了,要我下app。但好歹app里能下。这下载10MB/s都比我从手机数据线传到电脑快了。

得到15G的中文数据集之后,需要tokenize,并且切割为长度为context_length的段落,作为可以批量处理的数据。我选择按照教程把上下文长度设置为512。这个过程运行了半个小时。然后是训练集的制作,就是把一些汉字随机地替换为其拼音。这与NLP中的mask类似,按照经验是遮掉15%;我打算每一段遮掉1~10个词,遮6段,这其实还不到遮住15%,只遮了10%不到(但考虑到是连续遮蔽,很可能遮不了太多);每个词都随机取前缀或首字母,遮掉词更多首字母越可能。运行一个半小时之后,训练集制作完成了。

然后我就准备训练模型了。我本来打算加入按照拼音限制额外惩罚的loss函数,我都让chatGPT把代码写好了,但最后还是懒得修改代码加入这个自定义的loss函数,要改的似乎太多了;我想就算没有这样的限制,模型应该也能学到拼音与汉字之间的对应关系。

按照几个教程的指导,训练很快就顺利开始了。我的模型有0.1B个参数,在两块3090上训练需要耗时数天。在训练的同时,我编写代码完成了拼音输入的分词,我选择简单地用dp实现;此外还需要将模型生成预测词的过程封装起来,并且编写一个简单的交互界面实现输入过程;这都没有什么难度。最后,在一周之内我的模型就做好了,比tlt的预估要容易不少。

测试

我失望地发现,我的模型并不能满足我的需求。在我的电脑上,如果使用CPU推理,把10个拼音转变为汉字就需要3秒多;就算使用GPU也要1秒钟。这不能满足输入法实时生成预测词汇的要求。并且这个模型没能学到较为罕见的词,如“色盲”等;就算只是常用词,我发现首字母预测的难度似乎远超我的想象,它可以生成首字母和我输入相同的通顺句子,但是并非我想要的输入。不仅如此,虽然说我的训练数据之中将汉字替换为拼音的区间都结束在完整的词汇处,但由于我采取的方法本身的限制,模型生成的预测与人类总是一次性输入完整词组的输入习惯并不相符。

以《三体》中的一段作为例子:正确的句子是“飞向未知的太空深处”,而我的模型根据首字母生成的首个预测是“发现完整的太空是从”,而我电脑上的搜狗输入法产生的预测是“防汛物资的听课手册”。由此可见,与输入法相比,我的模型预测出的句子确实更加通顺,且与上下文相符,但还远不足以令人满意。

我尝试优化这一模型在PC上的推理速度,llama.cpp这样的项目的存在让我相信实际上可以加快非常多。比如,它展示运行13B参数的llama能达到16token/s的生成速度,那我这0.1B的小模型应该一秒几百到几千token没问题吧?我问chatGPT该怎样才能加速,它提出可以用模型量化、裁剪等手段减小模型,也能用ONNX加速推理。llama.cpp看上去对于运行的模型过度限制,可能要大改代码;按照chatGPT的建议,我选择尝试用ONNX Runtime加速。量化之后,模型的尺寸确实大幅下降了;然而让我失望的是,执行速度几乎没有变快。优化的尝试失败了。

related work

snz一开始认为这个问题太过简单,然而看到我失败之后,他决定把这个问题作为NLP课程的project。由于是正式的课程作业,snz在开工之前调研了related work,找到了一些已有的论文;而我只是让chatGPT搜索了一下,但这个愚蠢的chatGPT什么都没能找到,因此我误以为这个问题没什么人研究过。阅读了现有的论文,特别是2022年的Exploring and Adapting Chinese GPT to Pinyin Input Method这篇文章之后,我意识到我的这个想法早就被人实现过了,而且实现得还比我好很多:

  • 这篇文章制作训练集的方法更加科学,不是将汉字替换为拼音,而是将原有的汉字token接在拼音token之后,使得模型可以学习到要生成完整的词组;
  • 通过修改position embedding,把汉字与拼音设置为同一个position,让模型更好地学习到汉字和拼音的对应关系;
  • 由于训练时使用了按照拼音限制额外惩罚的loss函数(我偷懒没加的那个),模型效果有很大提升。根据论文中的数据,我这样不设置特殊loss函数的模型效果还不如直接选择GPT的输出中符合拼音限制的最高概率的一项呢!我变成了人家的人形ablation study。

当然,我搞这个项目本身也不是什么正经的科研,只是为了了解深度学习与AI训练过程的一个练手项目而已,我本来也不指望能够有很好的效果。这个思考与尝试的过程才是关键,如果只是按照别人的论文照猫画虎一遍的话,学习效果肯定是要大打折扣的。

而snz的实现则更加关注在低性能设备上的可用性:为了加快推理速度,他选择使用尺寸小很多的模型,并且使用了过量的训练数据来在模型尺寸的限制下尽量提升效果。snz还实现了保存past_key_values(这和普通的无限制文本生成并不相同,需要单独设计一下,因此我没有尝试去做),不再需要每次重复计算attention,据他说这能让模型的推理提速许多倍。为了让预测结果更加实用,snz的程序不是像我一样仅仅链式地生成最高概率的预测,而是使用beam search去搜索出前k个概率最高的预测词汇;当然这也大大提升了计算量。

讨论

在做出尝试之前,我和snz都以为按照拼音限制的文本生成应该不是一个特别困难的问题,使用参数量较小的模型就足够了;然而实际尝试之后,我们都发现这个任务比想象中难不少。看来在客户端运行AI模型的推理实在是不太可行,我认为可能还是需要稍微大一些的模型,并且在服务器上执行推理,才能达成接近实时生成智能词汇预测的目的:就像在Chrome的地址栏里输入时下方产生的预测结果那样。

有限制的文本生成并不仅限于拼音输入法,许多文字游戏也有这个需求;比如经常被用来测试大模型智能的lipogram编写问题,还有限制更加严格的anagram。我这学期的计算思维课程展示就是做的anagram,我发现相关论文中使用的手段也和用LLM做拼音输入法较为类似。当然,拼音输入法比起这些问题更加实用一些。