全部产品
Search
文档中心

人工智能平台 PAI:Golang SDK使用说明

更新时间:Oct 31, 2023

推荐使用EAS提供的官方SDK进行服务调用,从而有效减少编写调用逻辑的时间并提高调用稳定性。本文介绍官方Golang SDK接口详情,并以常见类型的输入输出为例,提供了使用Golang SDK进行服务调用的完整程序示例。

背景信息

使用Golang SDK进行服务调用时,由于在编译代码时,Golang的包管理工具会自动从Github上将Golang SDK的代码下载到本地,因此您无需提前安装Golang SDK。如果您需要自定义部分调用逻辑,可以先下载Golang SDK代码,再对其进行修改。

接口列表

接口

描述

PredictClient

NewPredictClient(endpoint string, serviceName string) *PredictClient

  • 功能:PredictClient类构造函数。

  • 参数:

    • endpoint:必填,表示服务端的Endpoint地址。对于普通服务,将其设置为默认网关Endpoint。

    • serviceName:必填,表示服务名称。

  • 返回值:创建的PredictClient对象。

SetEndpoint(endpointName string)

  • 功能:设置服务的Endpoint。

  • 参数:endpointName 表示服务端的Endpoint地址。对于普通服务,将其设置为默认网关Endpoint。

SetServiceName(serviceName string)

  • 功能:设置请求的服务名称。

  • 参数:serviceName表示请求的服务名称。

SetEndpointType(endpointType string)

  • 功能:设置服务端的网关类型。

  • 参数:endpointType表示网关类型。系统支持以下网关类型:

    • "DEFAULT":默认网关。如果不指定网关类型,默认为该类型。

    • "DIRECT":使用高速直连通道访问服务。

SetToken(token string)

  • 功能:设置服务访问的Token。

  • 参数:token表示访问服务时使用的鉴权Token。

SetHttpTransport(transport *http.Transport)

  • 功能:设置HTTP客户端的Transport属性。

  • 参数:transport表示发送HTTP请求时使用的Transport对象。

SetRetryCount(max_retry_count int)

  • 功能:设置请求失败重试次数。

  • 参数:max_retry_count表示请求失败后重连的次数,默认为5。

    重要

    对于服务端进程异常、服务器异常或网关长连接断开等情况导致的个别请求失败,均需要客户端重新发送请求。因此,请勿将该参数设置为0。

SetTimeout(timeout int)

  • 功能:设置请求的超时时间。

  • 参数:timeout表示请求的超时时间,单位为ms,默认值为5000。

Init()

对PredictClient对象进行初始化。在上述设置参数的接口执行完成后,需要调用Init()接口才能生效。

Predict(request Request) Response

  • 功能:向在线预测服务提交一个预测请求。

  • 参数:Request对象是interface(StringRequest, TFRequest,TorchRequest)

  • 返回值:Response对象是interface(StringResponse, TFResponse,TorchResponse)

StringPredict(request string) string

  • 功能:向在线预测服务提交一个预测请求。

  • 参数:request对象表示待发送的请求字符串。

  • 返回值:STRING类型的服务响应。

TorchPredict(request TorchRequest) TorchResponse

  • 功能:向在线预测服务提交一个PyTorch预测请求。

  • 参数:request表示TorchRequest类的对象。

  • 返回值:对应的TorchResponse。

TFPredict(request TFRequest) TFResponse

  • 功能:向在线预测服务提交一个预测请求。

  • 参数:request表示TFRequest类的对象。

  • 返回值:对应的TFResponse。

TFRequest

TFRequest(signatureName string)

  • 功能:TFRequest类的构建函数。

  • 参数:signatureName表示请求模型的Signature Name。

AddFeed(?)(inputName string, shape []int64{}, content []?)

  • 功能:请求TensorFlow的在线预测服务模型时,设置需要输入的Tensor。

  • 参数:

    • inputName:表示输入Tensor的别名。

    • shape:表示输入Tensor的TensorShape。

    • content:表示输入的Tensor的内容,通过一维数组展开表示。支持的类型包括INT32、INT64、FLOAT32、FLOAT64、STRING及BOOL,该接口名称与具体类型相关,例如AddFeedInt32()。如果需要其它数据类型,则可以参考代码自行通过PB格式构造。

AddFetch(outputName string)

  • 功能:请求TensorFlow的在线预测服务模型时,设置需要输出Tensor的别名。

  • 参数:outputName表示待获取的输出Tensor的别名。

    对于SavedModel模型,该参数可选。如果未设置,则输出所有的outputs。

    对于Frozen Model,该参数必选。

TFResponse

GetTensorShape(outputName string) []int64

  • 功能:获得指定别名。的输出Tensor的TensorShape。

  • 参数:outputName表示待获取输出Shape的Tensor别名。

  • 返回值:返回的Tensor Shape,各个维度以数组形式表示。

Get(?)Val(outputName string) [](?)

  • 功能:获取输出Tensor的数据向量,输出结果以一维数组的形式保存。您可以配套使用GetTensorShape()接口,获取对应Tensor的Shape,将其还原成所需的多维Tensor。支持的类型包括FLOAT、DOUBLE、INT、INT64、STRING及BOOL,接口名称与具体类型相关,例如GetFloatVal()

  • 参数:outputName表示待获取输出数据的Tensor别名。

  • 返回值:输出Tensor的数据展开成的一维数组。

TorchRequest

TorchRequest()

TFRequest类的构建函数。

AddFeed(?)(index int, shape []int64{}, content []?)

  • 功能:请求PyTorch的在线预测服务模型时,设置需要输入的Tensor。

  • 参数:

    • index:表示待输入的Tensor下标。

    • shape:表示输入Tensor的TensorShape。

    • content:表示输入Tensor的内容,通过一维数组展开表示。支持的类型包括INT32、INT64、FLOAT32及FLOAT64,该接口名称与具体类型相关,例如AddFeedInt32()。如果需要其它数据类型,则可以参考代码自行通过PB格式构造。

AddFetch(outputIndex int)

  • 功能:请求PyTorch的在线预测服务模型时,设置需要输出的Tensor的Index。该接口为可选,如果您没有调用该接口设置输出Tensor的Index,则输出所有的outputs。

  • 参数:outputIndex表示输出Tensor的Index。

TorchResponse

GetTensorShape(outputIndex int) []int64

  • 功能:获得指定下标的输出Tensor的TensorShape。

  • 参数:outputName表示待获取输出Shape的Tensor别名。

  • 返回值:返回的Tensor Shape,各个维度以数组形式表示。

Get(?)Val(outputIndex int) [](?)

  • 功能:获取输出Tensor的数据向量,输出结果以一维数组的形式保存。您可以配套使用GetTensorShape()接口获取对应Tensor的Shape,将其还原成所需的多维Tensor。支持的类型包括FLOAT、DOUBLE、INT及INT64,接口名称与具体类型相关,例如GetFloatVal()

  • 参数:outputIndex表示待获取输出数据Tensor的下标。

  • 返回值:输出Tensor的数据展开成的一维数组。

QueueClient

NewQueueClient(endpoint, queueName, token string) (*QueueClient, error)

  • 功能:QueueClient类构造函数。

  • 参数:

    • endpoint:表示服务端的Endpoint地址。

    • queueName:表示队列服务名称。

    • token:表示队列服务的token。

  • 返回值:创建的QueueClient对象。

Truncate(ctx context.Context, index uint64) error

  • 功能:从指定index向前截断队列中的数据,只保留指定index之后的数据。

  • 参数:

    • ctx:表示当前操作的Context信息。

    • index:表示要截断的队列中数据的index。

Put(ctx context.Context, data []byte, tags types.Tags) (index uint64, requestId string, err error)

  • 功能:向队列中写入一条数据。

  • 参数:

    • ctx:表示当前操作的Context信息。

    • data:表示要向队列中写入的数据内容。

  • 返回值:

    • index:当前写入的数据在队列中的index值,可用于从队列中查询数据。

    • requestId:当前写入数据在队列中自动生成的requestId。requestId是一个特殊的tag,也可用于在队列中查询数据。

GetByIndex(ctx context.Context, index uint64) (dfs []types.DataFrame, err error)

  • 功能:根据index值从队列中查询一条数据,查询完成后,在队列中会自动删除该数据。

  • 参数:

    • ctx:表示当前操作的Context信息。

    • index:表示要从队列中查询的数据所在的index值。

  • 返回值:dfs:队列中查询出的以DataFrame封装的数据结果。

GetByRequestId(ctx context.Context, requestId string) (dfs []types.DataFrame, err error)

  • 功能:根据数据的requestId从队列中查询一条数据,查询完成后,在队列中会自动删除该数据。

  • 参数:

    • ctx:表示当前操作的Context信息。

    • requestId:表示要从队列中查询的数据的requestId值。

  • 返回值:dfs:队列中查询出的以DataFrame封装的数据结果。

Get(ctx context.Context, index uint64, length int, timeout time.Duration, autoDelete bool, tags types.Tags) (dfs []types.DataFrame, err error)

  • 功能:根据指定条件从队列中查询数据,GetByIndex()GetByRequestId()是对Get()函数的简单封装。

  • 参数:

    • ctx:表示当前操作的Context信息。

    • index:表示要查询的数据的起始index。

    • length:表示要查询的数据的条数。返回从index开始计算(包含index)的最大length条数据。

    • timeout:表示查询的等待时间。在等待时间内,如果队列中有length条数据则直接返回,否则等到最大timeout等待时间则停止。

    • auto_delete:表示是否从队列中自动删除已经查询的数据。如果配置为False,则数据可被重复查询,您可以通过调用Del()方法手动删除数据。

    • tags:表示查询包含指定tags的数据,类型为map[string]string。从指定index开始遍历length条数据,返回包含指定tags的数据。

  • 返回值:dfs:队列中查询出的以DataFrame封装的数据结果。

Del(ctx context.Context, indexes ...uint64)

  • 功能:从队列中删除指定index的数据。

  • 参数:

    • ctx:表示当前操作的Context信息。

    • indexes:表示要从队列中删除的数据的index列表。

Attributes() (attrs types.Attributes, err error)

  • 功能:获取队列的属性信息,包含队列总长度、当前的数据长度等信息。

  • 返回值:attrs:队列的属性信息,类型为map[string]string。

Watch(ctx context.Context, index, window uint64, indexOnly bool, autocommit bool) (watcher types.Watcher, err error)

  • 功能:订阅队列中的数据,队列服务会根据条件向客户端推送数据。

  • 参数:

    • ctx:表示当前操作的Context信息。

    • index:表示订阅的起始数据index。

    • window:表示订阅的窗口大小,队列服务一次最多向单个客户端实例推送的数据量。

      说明

      如果推送的数据没有被commit,则服务端不会再推送新数据;如果commit N条数据,则服务队列会向客户端推送N条数据,确保客户端在同一时刻处理的数据不会超过设置的窗口大小,来实现客户端限制并发的功能。

    • index_only:表示是否只推送index值。

    • auto_commit:表示是否在推送完一条数据后,自动commit数据。建议配置为False。在收到推送数据并计算完成后手动Commit,在未完成计算的情况下实例发生异常,则实例上未commit的数据会由队列服务分发给其他实例继续处理。

  • 返回值:返回一个watcher对象,可通过该对象读取推送的数据。

Commit(ctx context.Context, indexes ...uint64) error

  • 功能:commit指定index的数据。

    说明

    commit表示服务队列推送的数据已经处理完成,可以将该数据从队列中清除,且不需要再推送给其他实例。

  • 参数:

    • ctx:表示当前操作的Context信息。

    • indexes:表示要向队列中commit的数据的index值列表。

types.Watcher

FrameChan() <-chan types.DataFrame

  • 功能:返回一个管道对象,服务端推送过来的数据会被写入该管道中,可以从该管道中循环读取数据。

  • 返回值:可用于读取推送数据的管道对象。

Close()

功能:关闭一个Watcher对象,用于关闭后端的数连接。

说明

一个客户端只能启动一个Watcher对象,使用完成后需要将该对象关闭才能启动新的Watcher对象。

程序示例

  • 字符串输入输出示例

    对于使用自定义Processor部署服务的用户而言,通常采用字符串进行服务调用(例如,PMML模型服务的调用),具体的Demo程序如下。

    package main
    
    import (
            "fmt"
            "github.com/pai-eas/eas-golang-sdk/eas"
    )
    
    func main() {
        client := eas.NewPredictClient("182848887922****.cn-shanghai.pai-eas.aliyuncs.com", "scorecard_pmml_example")
        client.SetToken("YWFlMDYyZDNmNTc3M2I3MzMwYmY0MmYwM2Y2MTYxMTY4NzBkNzdj****")
        client.Init()
        req := "[{\"fea1\": 1, \"fea2\": 2}]"
        for i := 0; i < 100; i++ {
            resp, err := client.StringPredict(req)
            if err != nil {
                fmt.Printf("failed to predict: %v\n", err.Error())
            } else {
                fmt.Printf("%v\n", resp)
            }
        }
    }
  • TensorFlow输入输出示例

    使用TensorFlow的用户,需要将TFRequest和TFResponse分别作为输入和输出数据格式,具体Demo示例如下。

    package main
    
    import (
            "fmt"
            "github.com/pai-eas/eas-golang-sdk/eas"
    )
    
    func main() {
        client := eas.NewPredictClient("182848887922****.cn-shanghai.pai-eas.aliyuncs.com", "mnist_saved_model_example")
        client.SetToken("YTg2ZjE0ZjM4ZmE3OTc0NzYxZDMyNmYzMTJjZTQ1YmU0N2FjMTAy****")
        client.Init()
    
        tfreq := eas.TFRequest{}
        tfreq.SetSignatureName("predict_images")
        tfreq.AddFeedFloat32("images", []int64{1, 784}, make([]float32, 784))
    
        for i := 0; i < 100; i++ {
            resp, err := client.TFPredict(tfreq)
            if err != nil {
                fmt.Printf("failed to predict: %v", err)
            } else {
                fmt.Printf("%v\n", resp)
            }
        }
    }
  • PyTorch输入输出示例

    使用PyTorch的用户,需要将TorchRequest和TorchResponse分别作为输入和输出数据格式,具体Demo示例如下。

    package main
    
    import (
            "fmt"
            "github.com/pai-eas/eas-golang-sdk/eas"
    )
    
    func main() {
        client := eas.NewPredictClient("182848887922****.cn-shanghai.pai-eas.aliyuncs.com", "pytorch_resnet_example")
        client.SetTimeout(500)
        client.SetToken("ZjdjZDg1NWVlMWI2NTU5YzJiMmY5ZmE5OTBmYzZkMjI0YjlmYWVl****")
        client.Init()
        req := eas.TorchRequest{}
        req.AddFeedFloat32(0, []int64{1, 3, 224, 224}, make([]float32, 150528))
        req.AddFetch(0)
        for i := 0; i < 10; i++ {
            resp, err := client.TorchPredict(req)
            if err != nil {
                fmt.Printf("failed to predict: %v", err)
            } else {
                fmt.Println(resp.GetTensorShape(0), resp.GetFloatVal(0))
            }
        }
    }
  • 通过VPC网络直连方式调用服务的示例

    通过网络直连方式,您只能访问部署在EAS专属资源组的服务,且需要为该资源组与用户指定的vSwitch连通网络后才能使用。关于如何购买EAS专属资源组和连通网络,请参见使用专属资源组配置网络连通。该调用方式与普通调用方式相比,仅需增加一行代码client.SetEndpointType(eas.EndpointTypeDirect)即可,特别适合大流量高并发的服务,具体示例如下。

    package main
    
    import (
            "fmt"
            "github.com/pai-eas/eas-golang-sdk/eas"
    )
    
    func main() {
        client := eas.NewPredictClient("pai-eas-vpc.cn-shanghai.aliyuncs.com", "scorecard_pmml_example")
        client.SetToken("YWFlMDYyZDNmNTc3M2I3MzMwYmY0MmYwM2Y2MTYxMTY4NzBkNzdj****")
        client.SetEndpointType(eas.EndpointTypeDirect)
        client.Init()
        req := "[{\"fea1\": 1, \"fea2\": 2}]"
        for i := 0; i < 100; i++ {
            resp, err := client.StringPredict(req)
            if err != nil {
                fmt.Printf("failed to predict: %v\n", err.Error())
            } else {
                fmt.Printf("%v\n", resp)
            }
        }
    }
  • 客户端连接参数设置的示例

    您可以通过http.Transport属性设置请求客户端的连接参数,示例代码如下。

    package main
    
    import (
            "fmt"
            "github.com/pai-eas/eas-golang-sdk/eas"
    )
    
    func main() {
        client := eas.NewPredictClient("pai-eas-vpc.cn-shanghai.aliyuncs.com", "network_test")
        client.SetToken("MDAwZDQ3NjE3OThhOTI4ODFmMjJiYzE0MDk1NWRkOGI1MmVhMGI0****")
        client.SetEndpointType(eas.EndpointTypeDirect)
        client.SetHttpTransport(&http.Transport{
            MaxConnsPerHost:       300,
            TLSHandshakeTimeout:   100 * time.Millisecond,
            ResponseHeaderTimeout: 200 * time.Millisecond,
            ExpectContinueTimeout: 200 * time.Millisecond,
        })
    }
  • 队列服务发送、订阅数据示例

    通过QueueClient可向队列服务中发送数据、查询数据、查询队列服务的状态以及订阅队列服务中的数据推送。以下方Demo为例,介绍一个线程向队列服务中推送数据,另一个线程通过Watcher订阅队列服务中推送过来的数据。

        const (
            QueueEndpoint = "182848887922****.cn-shanghai.pai-eas.aliyuncs.com"
            QueueName     = "test_group.qservice"
            QueueToken    = "YmE3NDkyMzdiMzNmMGM3ZmE4ZmNjZDk0M2NiMDA3OTZmNzc1MTUx****"
        )
        queue, err := NewQueueClient(QueueEndpoint, QueueName, QueueToken)
    
        // truncate all messages in the queue
        attrs, err := queue.Attributes()
        if index, ok := attrs["stream.lastEntry"]; ok {
            idx, _ := strconv.ParseUint(index, 10, 64)
            queue.Truncate(context.Background(), idx+1)
        }
    
        ctx, cancel := context.WithCancel(context.Background())
    
        // create a goroutine to send messages to the queue
        go func() {
            i := 0
            for {
                select {
                case <-time.NewTicker(time.Microsecond * 1).C:
                    _, _, err := queue.Put(context.Background(), []byte(strconv.Itoa(i)), types.Tags{})
                    if err != nil {
                        fmt.Printf("Error occured, retry to handle it: %v\n", err)
                    }
                    i += 1
                case <-ctx.Done():
                    break
                }
            }
        }()
    
        // create a watcher to watch the messages from the queue
        watcher, err := queue.Watch(context.Background(), 0, 5, false, false)
        if err != nil {
            fmt.Printf("Failed to create a watcher to watch the queue: %v\n", err)
            return
        }
    
        // read messages from the queue and commit manually
        for i := 0; i < 100; i++ {
            df := <-watcher.FrameChan()
            err := queue.Commit(context.Background(), df.Index.Uint64())
            if err != nil {
                fmt.Printf("Failed to commit index: %v(%v)\n", df.Index, err)
            }
        }
    
        // everything is done, close the watcher
        watcher.Close()
        cancel()