最近讯飞发布了星火大模型2.0的API,我为此开发了一个简易的SDK,方便大家快速接入和开发。国内有这样的大模型API对于普通开发者来说是个不错的消息,讯飞还提供了大量免费的token,正好可以用来练习和实践。希望我们国内的大型模型能不断优化,创造出更多优秀的产品,为用户提供更好的服务。
二、申请星火认知大模型应用
首先访问官网申请API
xinghuo.xfyun.cn/sparkapi
个人一年内有免费试用200w 的token数
然后到服务管理面板获取申请到的
app_id、api_key、api_secret
应用信息
最后就是熟悉下Web对接方式的接口文档
www.xfyun.cn/doc/spark/W…
三、制作简易版SDK
确认请求方式
官方API提供的是websocket连接的方式来进行通信。
什么是websocket?
WebSocket是一种在单个TCP连接上进行全双工通信的协议
。它为客户端和服务器之间的双向通信提供了一种更简单的方法,可以使数据在一个持久连接上进行交换,而不需要客户端不断发起HTTP请求。它与传统的HTTP请求-响应模式不同,可以实现服务器向客户端推送数据的功能,而不需要客户端发出请求。
具体实现原理可以查看这篇博客,写的挺不错:
juejin.cn/post/702096…
python支持websocket通信常用的库有
讯飞官方也提供了基于 websocket-client 库的DEMO案例,大家感兴趣可以下载看看。
xfyun-doc.cn-bj.ufileos.com/static%2F16…
但官方的DEMO好像不支持 asyncio,因此我打算用websockets与aiohttp库简单的重新封装下。
websocket请求响应示例
这里先展示下这两个库该如何发送websocket请求与处理响应。
websockets
ws服务端demo
import asyncio
import websockets
async def hello(websocket):
name = await websocket.recv()
print(f"<<< {name}")
greeting = f"Hello {name}!"
await websocket.send(greeting)
print(f">>> {greeting}")
async def main():
print("ws server run on localhost:8765")
async with websockets.serve(hello, "localhost", 8765):
await asyncio.Future()
if __name__ == "__main__":
asyncio.run(main())
ws客户端demo
import asyncio
import websockets
async def hello():
uri = "ws://localhost:8765"
async with websockets.connect(uri) as websocket:
name = input("What's your name? ")
await websocket.send(name)
print(f">>> {name}")
greeting = await websocket.recv()
print(f"<<< {greeting}")
if __name__ == "__main__":
asyncio.run(hello())
demo运行效果
aiohttp
ws服务端
from aiohttp import web
app = web.Application()
async def websocket_handler(request):
ws = web.WebSocketResponse()
await ws.prepare(request)
async for msg in ws:
if msg.type == web.WSMsgType.text:
if msg.data == 'close':
print('websocket connection closed')
await ws.close()
else:
print(f"recv data >>> {msg.data}")
await ws.send_str(f"Echo: {msg.data}")
elif msg.type == web.WSMsgType.error:
print(f'ws connection closed with exception {ws.exception()}')
return ws
app.router.add_get('/ws_demo', websocket_handler)
if __name__ == '__main__':
web.run_app(app, host="localhost", port=8080)
ws客户端
import aiohttp
import asyncio
async def ws_demo(session):
async with session.ws_connect('ws://localhost:8080/ws_demo') as ws:
test_data_list = ["hello ws", "close"]
for test_data in test_data_list:
print("send", test_data)
await ws.send_str(test_data)
msg = await ws.receive()
print("recv", msg)
async def main():
async with aiohttp.ClientSession() as session:
await ws_demo(session)
if __name__ == '__main__':
asyncio.run(main())
可以通过 aiohttp.ClientSession().ws_connect()
进行ws连接。
Demo运行效果
封装简易SDK
接受用户问题
组织api请求参数,
api鉴权,获取鉴权后的ws url
建立ws连接,发送ws请求
星火客户端初步封装
import base64
import hashlib
import hmac
import uuid
import json
from datetime import datetime
from time import mktime
from urllib.parse import urlparse, urlencode
from wsgiref.handlers import format_date_time
import aiohttp
import websockets
class SparkChatConfig(BaseModel):
"""星火聊天配置"""
domain: str = Field(default="generalv2", description="api版本")
temperature: float = Field(
default=0.5,
ge=0, le=1,
description="取值为[0,1],默认为0.5, 核采样阈值。用于决定结果随机性,取值越高随机性越强即相同的问题得到的不同答案的可能性越高"
max_tokens: int = Field(default=2048, le=8192, ge=1, description="模型回答的tokens的最大长度")
top_k: int = Field(default=4, le=6, ge=1, description="从k个候选中随机选择⼀个(⾮等概率)")
class SparkClient:
SERVER_URI_MAPPING = {
"general": "ws://spark-api.xf-yun.com/v1.1/chat",
"generalv2": "ws://spark-api.xf-yun.com/v2.1/chat",
def __init__(
self,
app_id: str,
api_secret: str,
api_key: str,
chat_conf: SparkChatConfig = None
self.app_id = app_id
self.api_secret = api_secret
self.api_key = api_key
self.chat_conf = chat_conf or SparkChatConfig()
self.server_uri = self.SERVER_URI_MAPPING[self.chat_conf.domain]
self.answer_full_content = ""
def build_chat_params(self, msg_context_list=None, uid: str = None):
"""构造请求参数"""
def _parse_chat_response(self, chat_resp: str) -> SparkMsgInfo:
"""解析chat响应"""
def get_sign_url(self, host=None, path=None):
"""获取鉴权后url"""
async def achat(self, msg_context_list: list, uid: str = None):
chat_params = self.build_chat_params(msg_context_list, uid)
sign_url = self.get_sign_url()
async with websockets.connect(sign_url) as ws:
await ws.send(chat_params)
async for chat_resp in ws:
spark_msg_info = self._parse_chat_response(chat_resp)
yield spark_msg_info
SparkClient 初始化的基本属性是API服务的申请应用信息与api密钥
app_id
api_secret
api_key
然后也可以初始化聊天对话的配置,默认None使用默认的对话配置
class SparkChatConfig(BaseModel):
"""星火聊天配置"""
domain: str = Field(default="generalv2", description="api版本")
temperature: float = Field(
default=0.5,
ge=0, le=1,
description="取值为[0,1],默认为0.5, 核采样阈值。用于决定结果随机性,取值越高随机性越强即相同的问题得到的不同答案的可能性越高"
max_tokens: int = Field(default=2048, le=8192, ge=1, description="模型回答的tokens的最大长度")
top_k: int = Field(default=4, le=6, ge=1, description="从k个候选中随机选择⼀个(⾮等概率)")
然后根据 domain 参数获取对应api版本的ws的url。
对外的功能就是 achat 对话
async def achat(self, msg_context_list: list, uid: str = None):
chat_params = self.build_chat_params(msg_context_list, uid)
sign_url = self.get_sign_url()
async with websockets.connect(sign_url) as ws:
await ws.send(chat_params)
async for chat_resp in ws:
spark_msg_info = self._parse_chat_response(chat_resp)
yield spark_msg_info
msg_context_list 用户提问的上下文信息列表
msg_context_list = [
{"role": 'user', "content": content},
这里使用的是websockets来进行ws连接处理,由于星火返回的数据是一段一段,因此这里使用 async for 来处接受返回的数据,然后调用 _parse_chat_response
方法处理聊天的数据,最后使用yield返回(异步生成器)。OK,到这里初步结构已经好了,接下来就是具体实现了。
构造对话聊天参数
def build_chat_params(self, msg_context_list=None, uid: str = None):
"""构造请求参数"""
return json.dumps({
"header": self._build_header(uid=uid),
"parameter": self._build_parameter(),
"payload": self._build_payload(msg_context_list)
这里分别通过三个方法一起构造请求参数信息,分别是
_build_header 请求头部信息 (应用、用户信息)
_build_parameter 请求参数信息(对话的配置)
_build_payload 请求载体信息 (问题内容)
def _build_header(self, uid=None):
return {
"app_id": self.app_id,
"uid": uid or uuid.uuid4().hex
def _build_parameter(self):
return {
"chat": {
"domain": self.chat_conf.domain,
"temperature": self.chat_conf.temperature,
"max_tokens": self.chat_conf.max_tokens,
"top_k": self.chat_conf.top_k
def _build_payload(self, msg_context_list: list):
return {
"message": {
"text": msg_context_list
具体组织的信息就是用应用服务配置、聊天配置、以及用户传的问题信息。这样封装看起来就非常的清晰,也好在不同方法中扩展信息,不然一个大字典的组织不美观。
获取鉴权后的ws地址
应该是先获取鉴权后的url再构造请求参数,其实都可以,不影响。
def get_sign_url(self, host=None, path=None):
"""获取鉴权后url"""
host = host or urlparse(self.server_uri).hostname
path = path or urlparse(self.server_uri).path
now = datetime.now()
date = format_date_time(mktime(now.timetuple()))
signature_origin = "host: " + host + "\n"
signature_origin += "date: " + date + "\n"
signature_origin += "GET " + path + " HTTP/1.1"
signature_sha = hmac.new(self.api_secret.encode('utf-8'), signature_origin.encode('utf-8'),
digestmod=hashlib.sha256).digest()
signature_sha_base64 = base64.b64encode(signature_sha).decode(encoding='utf-8')
authorization_origin = f'api_key="{self.api_key}", algorithm="hmac-sha256", headers="host date request-line", signature="{signature_sha_base64}"'
authorization = base64.b64encode(authorization_origin.encode('utf-8')).decode(encoding='utf-8')
v = {
"authorization": authorization,
"date": date,
"host": host
sign_url = self.server_uri + '?' + urlencode(v)
return sign_url
这里的host、path其实是你服务器获取到请求的host与path,如果不给的话,默认使用的是讯飞api的host、path。这里鉴权流程如下
通过 host date request-line 信息进行hmac-sha256进行加密然后进行base64编码
然后把加密后得到的信息 与 api_key 再次进行base64编码
最后把鉴权的参数拼接到ws的uri上生成新的sign_url
发送请求处理响应
async def aiohttp_chat(self, msg_context_list: list, uid: str = None):
chat_params = self.build_chat_params(msg_context_list, uid)
sign_url = self.get_sign_url()
async with aiohttp.ClientSession() as session:
async with session.ws_connect(sign_url) as ws:
await ws.send_str(chat_params)
async for chat_resp in ws:
spark_msg_info = self._parse_chat_response(chat_resp.data)
yield spark_msg_info
async def achat(self, msg_context_list: list, uid: str = None):
chat_params = self.build_chat_params(msg_context_list, uid)
sign_url = self.get_sign_url()
async with websockets.connect(sign_url) as ws:
await ws.send(chat_params)
async for chat_resp in ws:
spark_msg_info = self._parse_chat_response(chat_resp)
yield spark_msg_info
def _parse_chat_response(self, chat_resp: str) -> SparkMsgInfo:
"""解析chat响应"""
chat_resp = json.loads(chat_resp)
code = chat_resp["header"]["code"]
if code != 0:
raise ValueError(f"对话错误,{chat_resp}")
text_list = chat_resp["payload"]["choices"]["text"]
answer_content = text_list[0]["content"]
self.answer_full_content += answer_content
spark_msg_info = SparkMsgInfo()
status = chat_resp["header"]["status"]
sid = chat_resp["header"]["sid"]
spark_msg_info.msg_sid = sid
spark_msg_info.msg_status = status
spark_msg_info.msg_content = answer_content
if status == SparkMessageStatus.END_RET.value:
usage_info = chat_resp["payload"]["usage"]["text"]
spark_msg_info.usage_info = usage_info
spark_msg_info.msg_content = self.answer_full_content
self.answer_full_content = ""
return spark_msg_info
解析响应其实就是获取星火回答的内容并组装成我们自己定义的格式 SparkMsgInfo
星火返回的格式内容如下
"header":{
"code":0,
"message":"Success",
"sid":"cht000cb087@dx18793cd421fb894542",
"status":2
"payload":{
"choices":{
"status":2,
"seq":0,
"text":[
"content":"我可以帮助你的吗?",
"role":"assistant",
"index":0
"usage":{
"text":{
"question_tokens":4,
"prompt_tokens":5,
"completion_tokens":9,
"total_tokens":14
封装的内容如下
class SparkMessageStatus(Enum):
星火消息响应状态
0-代表首个文本结果;1-代表中间文本结果;2-代表最后一个文本结果
FIRST_RET = 0
MID_RET = 1
END_RET = 2
class SparkMsgInfo(BaseModel):
"""星火消息信息"""
msg_sid: str = Field(default=uuid.uuid4().hex, description="消息id,用于唯一标识⼀条消息")
msg_type: str = Field(default="text", description="消息类型,目前仅支持text")
msg_content: str = Field(default="", description="消息内容")
msg_status: SparkMessageStatus = Field(default=SparkMessageStatus.FIRST_RET, description="消息状态")
usage_info: Optional[SparkChatUsageInfo] = Field(default=None, description="消息使用信息")
最后有一个判断就是对话消息状态为 2 代表最后一个文本结果的时候,我把之前的回复的内容拼接到了 answer_full_content 中去,然后就是获取消息token的使用信息后再返回。
这里顺便把aiohttp请求的方式也写了下,主要是当练习用的,SDK连接的方式最好是确认一种方式好,不要两种混用,我一开始不知道aiohttp也可以websocket通信,所以封装的时候使用websockets库,后面才发现aiohttp也支持异步的websocket通信,要说功能性的话感觉还是要使用aiohttp,因为后面可能还要封装http请求的api。
四、使用体验
import asyncio
import random
from spark_ai_sdk.client import SparkClient
from spark_ai_sdk.config import SparkMsgRole, SparkChatConfig, SparkMsgInfo
def build_user_msg_context_list(content):
msg_context_list = [
{"role": SparkMsgRole.USER.value, "content": content},
return msg_context_list
async def main():
chat_conf = SparkChatConfig(domain="generalv2", temperature=0.5, max_tokens=2048, top_k=3)
spark_client = SparkClient(
app_id="",
api_secret="",
api_key="",
chat_conf=chat_conf
questions = ["程序员如何技术提升?", "如何提升系统并发", "如何找女朋友"]
ques = random.choice(questions)
msg_context_list = build_user_msg_context_list(content=ques)
answer_full_content = ""
async for chat_resp in spark_client.achat(msg_context_list):
chat_resp: SparkMsgInfo = chat_resp
answer_full_content += chat_resp.msg_content
print(chat_resp)
print(answer_full_content)
if __name__ == '__main__':
asyncio.run(main())
五、封装总结
做事情不要太着急,写代码也是, 古话说的好,磨刀不误砍柴工 。
首先熟悉API文档,确认请求方式与鉴权、下载示例Demo观摩学习体验下。
调研你不熟悉的领域,例如python 如何进行 websocket 通信,利用Google搜索查询资料,学习一些Demo,获取关键信息,然后逐渐扩展知识面,了解相关的技术,再度扩张,例如 webscokets、aiohttp库具体使用,还是要看官方文档才是最新、最权威的,这时就可以去pypi、github去查找这些开源库学习官方文档和教程。
学习下别人写的一些开源库,可以获得一些灵感,最后就是让代码组织自己的想法去实现。
我一开始的初始想法,就是消息的上下文让调用方自行组织比较好,然后就是数据格式使用pydantic进行封装组织,这样比字典更好维护。后续可能会继续扩展,存储对话的上下文,简化组织消息格式。例如支持本地内存的形式或者Redis的形式进行存储对话上下文,由于token的限制,还可以指定一些存储对话上下文的策略,例如
一次会话只保留最近30对话
总结压缩会话等。
大家也可以去学习下其他优秀开源项目的实践
MetaGPT github.com/geekan/Meta…
langchain github.com/langchain-a…
六、源代码
欢迎大家一起贡献学习。
Github:github.com/HuiDBK/Spar…