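On the server side, a DecisionTreeClassifier model is trained on the Iris dataset. The server waits for the client to send the four Iris features (as an array), predicts the class, and sends the result back to the client.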
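The four values are expected in the column order of `load_iris().data`: sepal length, sepal width, petal length and petal width, all in centimetres. Below is a minimal sketch of a client-side helper that packs the measurements in that order and rejects negative values before they are sent. `iris_features` and `IRIS_FEATURES` are hypothetical names, not part of the original code.

```python
from typing import List

# Column order of load_iris().data (all values in centimetres)
IRIS_FEATURES = ("sepal length", "sepal width", "petal length", "petal width")


def iris_features(sepal_length: float, sepal_width: float,
                  petal_length: float, petal_width: float) -> List[float]:
    """Hypothetical helper: pack the four measurements in load_iris() column order."""
    arr = [float(sepal_length), float(sepal_width),
           float(petal_length), float(petal_width)]
    # Physical lengths cannot be negative; reject obviously bad input early
    for name, value in zip(IRIS_FEATURES, arr):
        if value < 0:
            raise ValueError(f"{name} must be non-negative, got {value}")
    return arr


if __name__ == "__main__":
    print(iris_features(5.0, 3.5, 1.3, 0.3))  # the sample used by the client
```

In the client, `s_list` could then be built as `iris_features(5.0, 3.5, 1.3, 0.3)` instead of a bare list literal.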

The hello_grpc.proto file is identical on the server and the client:

syntax = "proto3";

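service GRPC {
    rpc HelloHarrytsz(HelloReq) returns (HelloReply) {}
    // Predict the iris class for one sample
    rpc pre(Req) returns (Reply) {}
}

message Req {
    repeated double arr = 1;  // sepal length, sepal width, petal length, petal width
}

message Reply {
    repeated int32 res = 1;   // predicted class index (0, 1 or 2)
}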

message HelloReq {
    string name = 1;
    int32 age = 2;
}

message HelloReply {
    string res = 1;
}
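In the generated Python code, `repeated double arr` and `repeated int32 res` behave like lists of `float` and `int`. `model.predict` returns a NumPy array of NumPy integers, so it is safest to convert the labels to plain Python `int` values that fit in the signed 32-bit range before building `Reply`. The following is a stdlib-only sketch; `to_int32_list` is a hypothetical helper, not part of the generated stubs. NumPy integer scalars are registered as `numbers.Integral`, so they pass the type check.

```python
import numbers
from typing import Iterable, List

# Range of a protobuf int32 field
INT32_MIN, INT32_MAX = -2**31, 2**31 - 1


def to_int32_list(values: Iterable) -> List[int]:
    """Convert predicted labels (e.g. NumPy integers) to plain ints that fit an int32 field."""
    out = []
    for v in values:
        if not isinstance(v, numbers.Integral):
            raise TypeError(f"expected an integer label, got {type(v).__name__}")
        i = int(v)
        if not INT32_MIN <= i <= INT32_MAX:
            raise ValueError(f"{i} does not fit in int32")
        out.append(i)
    return out
```

On the server this would be used as `pb2.Reply(res=to_int32_list(res))`.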

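On both the server and the client, run the following command in the directory that contains hello_grpc.proto to generate hello_grpc_pb2.py and hello_grpc_pb2_grpc.py (this requires the grpcio-tools package):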

python -m grpc_tools.protoc -I. --python_out=. --grpc_python_out=. hello_grpc.proto
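If you prefer to regenerate the stubs from Python (for example from a build script on either machine), here is a minimal sketch that builds the same `grpc_tools.protoc` command and lists the two modules it produces. `protoc_command` and `generated_modules` are hypothetical helper names. Actually running the command still requires grpcio-tools.

```python
import shlex
import sys
from pathlib import Path
from typing import List, Tuple


def protoc_command(proto_file: str, out_dir: str = ".") -> List[str]:
    """Build the same grpc_tools.protoc invocation as the shell command above."""
    return [sys.executable, "-m", "grpc_tools.protoc", "-I.",
            f"--python_out={out_dir}", f"--grpc_python_out={out_dir}", proto_file]


def generated_modules(proto_file: str) -> Tuple[str, str]:
    """File names protoc generates for a .proto file (messages, then gRPC stubs)."""
    stem = Path(proto_file).stem
    return f"{stem}_pb2.py", f"{stem}_pb2_grpc.py"


if __name__ == "__main__":
    cmd = protoc_command("hello_grpc.proto")
    print(shlex.join(cmd))
    # To actually run it (needs grpcio-tools installed):
    # import subprocess; subprocess.run(cmd, check=True)
```

`--python_out` produces the message classes (`hello_grpc_pb2`), and `--grpc_python_out` produces the service stub and servicer (`hello_grpc_pb2_grpc`). That is why both modules are imported in service.py and main.py.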

Server (Mac):

# service.py
# coding:utf-8
import time
import pickle
import grpc
import hello_grpc_pb2 as pb2
import hello_grpc_pb2_grpc as pb2_grpc
from concurrent import futures


class GRPC(pb2_grpc.GRPCServicer):
    def HelloHarrytsz(self, request, context):
        name = request.name
        age = request.age
        res = f"My name is {name}, I'm {age} years old"
        return pb2.HelloReply(res=res)

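    def pre(self, request, context):
        # Load the pickled model; the file must be opened in binary mode ('rb').
        # Assumes dtc.model (written by DTC.py) is located under ./tree/
        with open('./tree/dtc.model', 'rb') as f:
            model = pickle.load(f)
        features = list(request.arr)  # the four iris features sent by the client
        print("requ: ", features)
        res = model.predict([features])  # predict expects 2-D input: one row per sample
        print("res: ", res)
        # Convert NumPy integers to plain ints for the repeated int32 field
        return pb2.Reply(res=[int(v) for v in res])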


def run():
    grpc_server = grpc.server(
        futures.ThreadPoolExecutor(max_workers=4)
    )
    pb2_grpc.add_GRPCServicer_to_server(GRPC(), grpc_server)
    grpc_server.add_insecure_port('192.168.31.125:5000')
    print('Server will start at 192.168.31.125:5000')
    grpc_server.start()
    try:
        while True:
            time.sleep(3600)
    except KeyboardInterrupt:
        grpc_server.stop(0)  # stop() requires a grace period (seconds) argument


if __name__ == "__main__":
    run()
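The `pre` method above reopens and unpickles dtc.model on every request. Below is a minimal sketch of loading it only once with `functools.lru_cache`. `load_model` and `_load_abs` are hypothetical names, and the default path `./tree/dtc.model` is the one service.py already uses. Inside `pre`, the file-reading lines would then become `model = load_model()`.

```python
import functools
import os
import pickle


@functools.lru_cache(maxsize=None)
def _load_abs(abs_path: str):
    """Unpickle the model once per absolute path; later calls return the cached object."""
    with open(abs_path, "rb") as f:  # pickle files must be opened in binary mode
        return pickle.load(f)


def load_model(path: str = "./tree/dtc.model"):
    """Normalize the path so './x' and 'x' share one cache entry."""
    return _load_abs(os.path.abspath(path))
```

Calling `load_model()` once in `run()` before `grpc_server.start()` makes the first request as fast as the rest. It also avoids several worker threads unpickling the file at the same time during the first burst of requests, because `lru_cache` does not prevent concurrent first calls. As with any pickle, only load model files from a trusted source, since unpickling can execute arbitrary code.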

Model training script:

# DTC.py
import pickle
from sklearn import datasets
from sklearn.tree import DecisionTreeClassifier
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score

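# Load the 150-sample Iris dataset: 4 features per sample, 3 classes
data = datasets.load_iris()
x, y = data.data, data.target

# Hold out 25% of the samples for evaluation
x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.25)

clf = DecisionTreeClassifier()
clf.fit(x_train, y_train)

y_pred = clf.predict(x_test)
print("Accuracy: ", accuracy_score(y_test, y_pred))

# Serialize the trained model. The file is written to the current directory,
# while service.py reads ./tree/dtc.model, so run this script inside tree/
# (or move the file there).
with open('dtc.model', 'wb') as f:
    pickle.dump(clf, f)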
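`accuracy_score(y_test, y_pred)` reports the fraction of test samples whose predicted label equals the true label. With 150 Iris samples and `test_size=0.25`, the test split holds 38 samples, so the printed accuracy moves in steps of 1/38. It also changes from run to run, because the split is random unless `random_state` is fixed. Below is a plain-Python sketch of the same computation. `accuracy` is a hypothetical name, not the scikit-learn function.

```python
from typing import Sequence


def accuracy(y_true: Sequence[int], y_pred: Sequence[int]) -> float:
    """Fraction of positions where the predicted label matches the true label."""
    if len(y_true) != len(y_pred):
        raise ValueError("y_true and y_pred must have the same length")
    if len(y_true) == 0:
        raise ValueError("need at least one sample")
    hits = sum(1 for t, p in zip(y_true, y_pred) if t == p)
    return hits / len(y_true)


if __name__ == "__main__":
    print(accuracy([0, 1, 2, 2], [0, 1, 1, 2]))  # 3 of 4 correct -> 0.75
```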

Client (Ubuntu):

# main.py

# Client

import grpc
import hello_grpc_pb2 as pb2
import hello_grpc_pb2_grpc as pb2_grpc


def run():
    conn = grpc.insecure_channel('192.168.31.125:5000')
    client = pb2_grpc.GRPCStub(channel=conn)
    response = client.HelloHarrytsz(pb2.HelloReq(name="Tom", age=20))
    s_list = [5.0, 3.5, 1.3, 0.3]
    response2 = client.pre(pb2.Req(arr=s_list))
    print(response.res)
    print(response2.res)


if __name__ == "__main__":
    run()

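The client sends the s_list array. The server predicts its class with the trained model, a DecisionTreeClassifier trained on the Iris dataset, and returns the predicted class index.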
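`Reply.res` carries class indices rather than names. In `load_iris()`, `target_names` is `['setosa', 'versicolor', 'virginica']`, so index 0 is setosa, 1 is versicolor and 2 is virginica. Below is a small client-side sketch that turns `response2.res` into species names. `decode_species` is a hypothetical helper, and the hard-coded tuple assumes the model was trained on the `load_iris()` targets as in DTC.py.

```python
from typing import List, Sequence

# Order of load_iris().target_names: labels 0, 1, 2
IRIS_TARGET_NAMES = ("setosa", "versicolor", "virginica")


def decode_species(labels: Sequence[int]) -> List[str]:
    """Map class indices from Reply.res to iris species names."""
    names = []
    for label in labels:
        if not 0 <= label < len(IRIS_TARGET_NAMES):
            raise ValueError(f"unknown iris class index: {label}")
        names.append(IRIS_TARGET_NAMES[label])
    return names


if __name__ == "__main__":
    print(decode_species([0, 2]))  # -> ['setosa', 'virginica']
```

In main.py this would be `print(decode_species(response2.res))`.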

Last modification: April 13, 2022