开源模型应用落地-chatglm3-6b-gradio-入门篇(七)

开源 0

一、前言

    早前的文章,我们都是通过输入命令的方式来使用Chatglm3-6b模型。现在,我们可以通过使用gradio,通过一个界面与模型进行交互。这样做可以减少重复加载模型和修改代码的麻烦,
让我们更方便地体验模型的效果。


二、术语

2.1、Gradio

    是一个用于构建交互式界面的Python库。它使得在Python中创建快速原型、构建和共享机器学习模型变得更加容易。

    Gradio的主要功能是为机器学习模型提供一个即时的Web界面,使用户能够与模型进行交互,输入数据并查看结果,而无需编写复杂的前端代码。它提供了一个简单的API,可以将输入和输出绑定到模型的函数或方法,并自动生成用户界面。


三、前置条件

3.1. windows or linux操作系统均可

3.2. 下载chatglm3-6b模型

从huggingface下载:https://huggingface.co/THUDM/chatglm3-6b/tree/main

从魔搭下载:魔搭社区汇聚各领域最先进的机器学习模型,提供模型探索体验、推理、训练、部署和应用的一站式服务。https://www.modelscope.cn/models/ZhipuAI/chatglm3-6b/fileshttps://www.modelscope.cn/models/ZhipuAI/chatglm3-6b/files

 3.3. 创建虚拟环境&安装依赖

conda create --name chatglm3 python=3.10conda activate chatglm3pip install protobuf transformers==4.39.3 cpm_kernels torch>=2.0 sentencepiece acceleratepip install gradio

四、技术实现

# -*-  coding = utf-8 -*-import gradio as grimport torchfrom threading import Threadfrom transformers import (    AutoModelForCausalLM,    AutoTokenizer,    StoppingCriteria,    StoppingCriteriaList,    TextIteratorStreamer)modelPath = "/model/chatglm3-6b"def loadTokenizer():    tokenizer = AutoTokenizer.from_pretrained(modelPath, use_fast=False, trust_remote_code=True)    return tokenizerdef loadModel():    model = AutoModelForCausalLM.from_pretrained(modelPath, device_map="auto",  trust_remote_code=True).cuda()    model = model.eval()    return modelclass StopOnTokens(StoppingCriteria):    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:        stop_ids = [0, 2]        for stop_id in stop_ids:            if input_ids[0][-1] == stop_id:                return True        return Falsedef parse_text(text):    lines = text.split("/n")    lines = [line for line in lines if line != ""]    count = 0    for i, line in enumerate(lines):        if "```" in line:            count += 1            items = line.split('`')            if count % 2 == 1:                lines[i] = f'<pre><code class="language-{items[-1]}">'            else:                lines[i] = f'<br></code></pre>'        else:            if i > 0:                if count % 2 == 1:                    line = line.replace("`", "/`")                    line = line.replace("<", "&lt;")                    line = line.replace(">", "&gt;")                    line = line.replace(" ", "&nbsp;")                    line = line.replace("*", "&ast;")                    line = line.replace("_", "&lowbar;")                    line = line.replace("-", "&#45;")                    line = line.replace(".", "&#46;")                    line = line.replace("!", "&#33;")                    line = line.replace("(", "&#40;")                    line = line.replace(")", "&#41;")                    line = line.replace("$", "&#36;")                lines[i] = "<br>" + line    text = "".join(lines)    return textdef predict(history, max_length, top_p, temperature):    stop = StopOnTokens()    messages = []    for idx, (user_msg, model_msg) in enumerate(history):        if idx == len(history) - 1 and not model_msg:            messages.append({"role": "user", "content": user_msg})            break        if user_msg:            messages.append({"role": "user", "content": user_msg})        if model_msg:            messages.append({"role": "assistant", "content": model_msg})    model_inputs = tokenizer.apply_chat_template(messages,                                                 add_generation_prompt=True,                                                 tokenize=True,                                                 return_tensors="pt").to(next(model.parameters()).device)    streamer = TextIteratorStreamer(tokenizer, timeout=60, skip_prompt=True, skip_special_tokens=True)    generate_kwargs = {        "input_ids": model_inputs,        "streamer": streamer,        "max_new_tokens": max_length,        "do_sample": True,        "top_p": top_p,        "temperature": temperature,        "stopping_criteria": StoppingCriteriaList([stop]),        "repetition_penalty": 1.2,    }    t = Thread(target=model.generate, kwargs=generate_kwargs)    t.start()    for new_token in streamer:        if new_token != '':            history[-1][1] += new_token            yield historywith gr.Blocks() as demo:    gr.HTML("""<h1 align="center">ChatGLM3-6B Gradio Simple Demo</h1>""")    chatbot = gr.Chatbot()    with gr.Row():        with gr.Column(scale=4):            with gr.Column(scale=12):                user_input = gr.Textbox(show_label=False, placeholder="Input...", lines=10, container=False)            with gr.Column(min_width=32, scale=1):                submitBtn = gr.Button("Submit")        with gr.Column(scale=1):            emptyBtn = gr.Button("Clear History")            max_length = gr.Slider(0, 32768, value=8192, step=1.0, label="Maximum length", interactive=True)            top_p = gr.Slider(0, 1, value=0.8, step=0.01, label="Top P", interactive=True)            temperature = gr.Slider(0.01, 1, value=0.6, step=0.01, label="Temperature", interactive=True)    def user(query, history):        return "", history + [[parse_text(query), ""]]    submitBtn.click(user, [user_input, chatbot], [user_input, chatbot], queue=False).then(        predict, [chatbot, max_length, top_p, temperature], chatbot    )    emptyBtn.click(lambda: None, None, chatbot, queue=False)if __name__ == '__main__':    model = loadModel()    tokenizer = loadTokenizer()    demo.queue()    demo.launch(server_name="0.0.0.0", server_port=8989, inbrowser=True, share=False)

调用结果:

启动成功:

GPU使用情况:

浏览器访问:

推理:


五、附带说明

5.1. 问题:AttributeError: 'ChatGLMTokenizer' object has no attribute 'apply_chat_template'

1. transformers的版本太低,需要升级

pip install --upgrade transformers==4.39.3

5.2. 界面无法打开

1. 服务监听地址不能是127.0.0.1

2. 检查服务器的安全策略或防火墙配置

 服务端:lsof -i:8989 查看端口是否正常监听

 客户端:telnet ip 8989 查看是否可以正常连接

也许您对下面的内容还感兴趣: