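The server trains a DecisionTreeClassifier on the iris dataset and waits for the client to send the four iris features as an array. It predicts the class and sends the result back to the client.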
The server and the client use the same hello_grpc.proto file:
syntax = "proto3";

service GRPC {
  rpc HelloHarrytsz(HelloReq) returns (HelloReply) {}
  rpc pre(Req) returns (Reply) {}
}

message Req {
  repeated double arr = 1;
}

message Reply {
  repeated int32 res = 1;
}

message HelloReq {
  string name = 1;
  int32 age = 2;
}

message HelloReply {
  string res = 1;
}
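In Python, `repeated double arr` is exposed as a list-like container of floats and `repeated int32 res` takes integers. Because `model.predict` returns a NumPy array, converting its labels to plain Python ints before building `Reply` keeps the response independent of how a particular protobuf runtime treats NumPy scalar types. The helper `to_int32_list` below is hypothetical and not part of the generated code; it is a minimal sketch that also checks the signed 32-bit range of an `int32` field:

```python
# Range of a protobuf int32 field (signed 32-bit integer).
INT32_MIN = -(2 ** 31)
INT32_MAX = 2 ** 31 - 1


def to_int32_list(values):
    """Convert predicted labels to plain Python ints for a repeated int32 field.

    `values` may be any iterable of integer-like objects (for example a NumPy
    array returned by predict()); int() turns NumPy integer scalars into
    ordinary Python ints.
    """
    result = []
    for v in values:
        i = int(v)
        if i != v:  # reject values such as 1.5 that are not whole numbers
            raise ValueError(f"non-integer label: {v!r}")
        if not INT32_MIN <= i <= INT32_MAX:
            raise ValueError(f"label out of int32 range: {i}")
        result.append(i)
    return result
```

In the server this would be used as `pb2.Reply(res=to_int32_list(res))`; for small class indices like the iris labels, `res.tolist()` achieves the same conversion.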
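On both the server and the client, run the following command in the directory that contains hello_grpc.proto to generate hello_grpc_pb2.py and hello_grpc_pb2_grpc.py (this requires the grpcio-tools package):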
python -m grpc_tools.protoc -I. --python_out=. --grpc_python_out=. hello_grpc.proto
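Since the same command has to be run on both machines, it can help to keep it in a small script. The sketch below only builds the argument list equivalent to the command above; the function name `build_protoc_cmd` is hypothetical. Using `sys.executable` makes sure the generator runs under the same interpreter (and virtual environment) that will later import the generated modules:

```python
import sys
from pathlib import Path


def build_protoc_cmd(proto_file, out_dir="."):
    """Return the argument list for:

    python -m grpc_tools.protoc -I. --python_out=. --grpc_python_out=. hello_grpc.proto
    """
    proto_path = Path(proto_file)
    include_dir = str(proto_path.parent)  # "." when the .proto is in the current directory
    return [
        sys.executable, "-m", "grpc_tools.protoc",
        f"-I{include_dir}",
        f"--python_out={out_dir}",
        f"--grpc_python_out={out_dir}",
        str(proto_path),
    ]
```

To actually regenerate the files, pass the list to the standard library: `subprocess.run(build_protoc_cmd("hello_grpc.proto"), check=True)`, which raises `CalledProcessError` if protoc reports an error.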
Server (Mac):
# service.py
# coding: utf-8
import pickle
import time
from concurrent import futures

import grpc

import hello_grpc_pb2 as pb2
import hello_grpc_pb2_grpc as pb2_grpc

# Path to the model pickled by DTC.py; adjust it to wherever dtc.model was saved.
MODEL_PATH = './tree/dtc.model'


class GRPC(pb2_grpc.GRPCServicer):
    def __init__(self):
        # Load the model once at startup instead of re-reading it on every request.
        # The file must be opened in binary mode ('rb').
        with open(MODEL_PATH, 'rb') as f:
            self.model = pickle.load(f)

    def HelloHarrytsz(self, request, context):
        name = request.name
        age = request.age
        res = f"My name is {name}, I'm {age} years old"
        return pb2.HelloReply(res=res)

    def pre(self, request, context):
        # request.arr holds the four iris features sent by the client
        print("request: ", request.arr)
        # predict() expects a 2-D input: one row per sample
        res = self.model.predict([list(request.arr)])
        print("res: ", res)
        # Convert the NumPy array to plain Python ints for the repeated int32 field
        return pb2.Reply(res=res.tolist())


def run():
    grpc_server = grpc.server(
        futures.ThreadPoolExecutor(max_workers=4)
    )
    pb2_grpc.add_GRPCServicer_to_server(GRPC(), grpc_server)
    grpc_server.add_insecure_port('192.168.31.125:5000')
    print('Server will start at 192.168.31.125:5000')
    grpc_server.start()
    try:
        while True:
            time.sleep(3600)
    except KeyboardInterrupt:
        # stop() requires a grace period in seconds; 0 gives in-flight RPCs no extra time
        grpc_server.stop(0)


if __name__ == "__main__":
    run()
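The `pre` handler assumes the request carries exactly four numbers; if a client sends a different count, `predict` fails inside the server and the client only sees a generic error. A minimal sketch of an input check follows; `validate_features` and `N_FEATURES` are hypothetical names. Inside `pre`, a `ValueError` from this helper could be turned into a proper gRPC status with `context.abort(grpc.StatusCode.INVALID_ARGUMENT, str(err))`, so the client receives an `RpcError` with a readable message:

```python
import math

# Iris feature order: sepal length, sepal width, petal length, petal width (cm)
N_FEATURES = 4


def validate_features(arr):
    """Return the features as a list of floats, or raise ValueError.

    Checks what the `pre` handler assumes but never verifies:
    exactly four finite numbers.
    """
    values = [float(x) for x in arr]
    if len(values) != N_FEATURES:
        raise ValueError(f"expected {N_FEATURES} features, got {len(values)}")
    if not all(math.isfinite(x) for x in values):
        raise ValueError("features must be finite numbers")
    return values
```

With this in place, the prediction line would become `self.model.predict([validate_features(request.arr)])`.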
Model training script:
# DTC.py
import pickle

from sklearn import datasets
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

# Load the iris dataset: 150 samples, 4 features, 3 classes
data = datasets.load_iris()
x, y = data.data, data.target

# Hold out 25% of the samples for evaluation; y stays 1-D, as scikit-learn expects
x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.25)

clf = DecisionTreeClassifier()
clf.fit(x_train, y_train)

y_pred = clf.predict(x_test)
print("Accuracy: ", accuracy_score(y_test, y_pred))

# Serialize the trained model; the server reads this file back in binary mode ('rb')
with open('dtc.model', 'wb') as f:
    pickle.dump(clf, f)
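If DTC.py may be re-run while a server process is loading dtc.model, the server could read a half-written file. A common remedy is to write to a temporary file in the same directory and then rename it over the target with `os.replace`, which replaces the file in one step on the same filesystem. The sketch below assumes that approach; `save_model_atomic` and `load_model` are hypothetical helper names, and any picklable object (such as the trained `clf`) can be passed in. Note that `pickle.load` should only be used on files you trust, since unpickling can execute arbitrary code:

```python
import os
import pickle
import tempfile


def save_model_atomic(model, path):
    """Pickle `model` to `path` so readers never see a partially written file."""
    directory = os.path.dirname(os.path.abspath(path))
    # The temporary file lives in the same directory, so os.replace is a rename
    # on the same filesystem rather than a copy.
    fd, tmp_path = tempfile.mkstemp(dir=directory, suffix=".tmp")
    try:
        with os.fdopen(fd, "wb") as f:
            pickle.dump(model, f)
        os.replace(tmp_path, path)
    except BaseException:
        os.unlink(tmp_path)  # clean up the temp file if writing failed
        raise


def load_model(path):
    # Binary mode, matching the "wb" used when saving
    with open(path, "rb") as f:
        return pickle.load(f)
```

In DTC.py the final `with open(...)` block would then be replaced by `save_model_atomic(clf, 'dtc.model')`.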
Client (Ubuntu):
# main.py
# Client
import grpc

import hello_grpc_pb2 as pb2
import hello_grpc_pb2_grpc as pb2_grpc


def run():
    # The channel is closed automatically when the with-block exits
    with grpc.insecure_channel('192.168.31.125:5000') as conn:
        client = pb2_grpc.GRPCStub(channel=conn)
        response = client.HelloHarrytsz(pb2.HelloReq(name="Tom", age=20))
        # Four iris features: sepal length, sepal width, petal length, petal width (cm)
        s_list = [5.0, 3.5, 1.3, 0.3]
        response2 = client.pre(pb2.Req(arr=s_list))
    print(response.res)
    print(response2.res)


if __name__ == "__main__":
    run()
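The client sends the s_list array, and the server predicts its class with the trained model, a DecisionTreeClassifier fitted on the iris dataset. response2.res contains the predicted class index; for this sample (petal length 1.3 cm) it is typically 0, i.e. setosa.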
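`Reply.res` only carries class indices. scikit-learn's `load_iris()` orders its `target_names` as setosa, versicolor, virginica, so the client can translate indices into species names without depending on scikit-learn itself. A minimal sketch, with the hypothetical names `IRIS_CLASSES` and `label_names`; it assumes the server keeps using the label order from `load_iris()`:

```python
# Class order of sklearn.datasets.load_iris().target_names
IRIS_CLASSES = ("setosa", "versicolor", "virginica")


def label_names(res):
    """Map the class indices in Reply.res to iris species names."""
    return [IRIS_CLASSES[i] for i in res]
```

On the client, `print(label_names(response2.res))` would then print species names instead of bare indices, since the repeated field iterates like a list of ints.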