一、前言
早前的文章,我们都是通过输入命令的方式来使用Chatglm3-6b模型。现在,我们可以通过使用gradio,通过一个界面与模型进行交互。这样做可以减少重复加载模型和修改代码的麻烦,
让我们更方便地体验模型的效果。
二、术语
2.1、Gradio
是一个用于构建交互式界面的Python库。它使得在Python中创建快速原型、构建和共享机器学习模型变得更加容易。
Gradio的主要功能是为机器学习模型提供一个即时的Web界面,使用户能够与模型进行交互,输入数据并查看结果,而无需编写复杂的前端代码。它提供了一个简单的API,可以将输入和输出绑定到模型的函数或方法,并自动生成用户界面。
三、前置条件
3.1. windows or linux操作系统均可
3.2. 下载chatglm3-6b模型
从huggingface下载:https://huggingface.co/THUDM/chatglm3-6b/tree/main
从魔搭下载:魔搭社区汇聚各领域最先进的机器学习模型,提供模型探索体验、推理、训练、部署和应用的一站式服务。https://www.modelscope.cn/models/ZhipuAI/chatglm3-6b/fileshttps://www.modelscope.cn/models/ZhipuAI/chatglm3-6b/files
3.3. 创建虚拟环境&安装依赖
conda create --name chatglm3 python=3.10conda activate chatglm3pip install protobuf transformers==4.39.3 cpm_kernels torch>=2.0 sentencepiece acceleratepip install gradio
四、技术实现
# -*- coding = utf-8 -*-import gradio as grimport torchfrom threading import Threadfrom transformers import ( AutoModelForCausalLM, AutoTokenizer, StoppingCriteria, StoppingCriteriaList, TextIteratorStreamer)modelPath = "/model/chatglm3-6b"def loadTokenizer(): tokenizer = AutoTokenizer.from_pretrained(modelPath, use_fast=False, trust_remote_code=True) return tokenizerdef loadModel(): model = AutoModelForCausalLM.from_pretrained(modelPath, device_map="auto", trust_remote_code=True).cuda() model = model.eval() return modelclass StopOnTokens(StoppingCriteria): def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool: stop_ids = [0, 2] for stop_id in stop_ids: if input_ids[0][-1] == stop_id: return True return Falsedef parse_text(text): lines = text.split("/n") lines = [line for line in lines if line != ""] count = 0 for i, line in enumerate(lines): if "```" in line: count += 1 items = line.split('`') if count % 2 == 1: lines[i] = f'<pre><code class="language-{items[-1]}">' else: lines[i] = f'<br></code></pre>' else: if i > 0: if count % 2 == 1: line = line.replace("`", "/`") line = line.replace("<", "<") line = line.replace(">", ">") line = line.replace(" ", " ") line = line.replace("*", "*") line = line.replace("_", "_") line = line.replace("-", "-") line = line.replace(".", ".") line = line.replace("!", "!") line = line.replace("(", "(") line = line.replace(")", ")") line = line.replace("$", "$") lines[i] = "<br>" + line text = "".join(lines) return textdef predict(history, max_length, top_p, temperature): stop = StopOnTokens() messages = [] for idx, (user_msg, model_msg) in enumerate(history): if idx == len(history) - 1 and not model_msg: messages.append({"role": "user", "content": user_msg}) break if user_msg: messages.append({"role": "user", "content": user_msg}) if model_msg: messages.append({"role": "assistant", "content": model_msg}) model_inputs = tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=True, return_tensors="pt").to(next(model.parameters()).device) streamer = TextIteratorStreamer(tokenizer, timeout=60, skip_prompt=True, skip_special_tokens=True) generate_kwargs = { "input_ids": model_inputs, "streamer": streamer, "max_new_tokens": max_length, "do_sample": True, "top_p": top_p, "temperature": temperature, "stopping_criteria": StoppingCriteriaList([stop]), "repetition_penalty": 1.2, } t = Thread(target=model.generate, kwargs=generate_kwargs) t.start() for new_token in streamer: if new_token != '': history[-1][1] += new_token yield historywith gr.Blocks() as demo: gr.HTML("""<h1 align="center">ChatGLM3-6B Gradio Simple Demo</h1>""") chatbot = gr.Chatbot() with gr.Row(): with gr.Column(scale=4): with gr.Column(scale=12): user_input = gr.Textbox(show_label=False, placeholder="Input...", lines=10, container=False) with gr.Column(min_width=32, scale=1): submitBtn = gr.Button("Submit") with gr.Column(scale=1): emptyBtn = gr.Button("Clear History") max_length = gr.Slider(0, 32768, value=8192, step=1.0, label="Maximum length", interactive=True) top_p = gr.Slider(0, 1, value=0.8, step=0.01, label="Top P", interactive=True) temperature = gr.Slider(0.01, 1, value=0.6, step=0.01, label="Temperature", interactive=True) def user(query, history): return "", history + [[parse_text(query), ""]] submitBtn.click(user, [user_input, chatbot], [user_input, chatbot], queue=False).then( predict, [chatbot, max_length, top_p, temperature], chatbot ) emptyBtn.click(lambda: None, None, chatbot, queue=False)if __name__ == '__main__': model = loadModel() tokenizer = loadTokenizer() demo.queue() demo.launch(server_name="0.0.0.0", server_port=8989, inbrowser=True, share=False)
调用结果:
启动成功:
GPU使用情况:
浏览器访问:
推理:
五、附带说明
5.1. 问题:AttributeError: 'ChatGLMTokenizer' object has no attribute 'apply_chat_template'
1. transformers的版本太低,需要升级
pip install --upgrade transformers==4.39.3
5.2. 界面无法打开
1. 服务监听地址不能是127.0.0.1
2. 检查服务器的安全策略或防火墙配置
服务端:lsof -i:8989 查看端口是否正常监听
客户端:telnet ip 8989 查看是否可以正常连接