全部產品
Search
文件中心

Platform For AI:使用EasyTransfer進行文本分類

更新時間:Jul 13, 2024

EasyTransfer旨在協助自然語言處理(NLP)情境的遷移學習開發人員方便快捷地構建遷移學習模型。本文以文本分類為例,為您介紹如何在DSW中使用EasyTransfer,包括啟動訓練、評估模型、預測模型及匯出並部署模型。

前提條件

已建立DSW執行個體,且該執行個體滿足版本限制,詳情請參見建立及管理DSW執行個體使用限制

說明

建議建立DSW執行個體時選擇GPU規格。

背景資訊

遷移學習(Transfer Learning)的核心的思想是將一個環境中學到的知識應用到新環境的學習任務中。面向自然語言處理(NLP)情境的遷移學習在工業上擁有大量需求,且不斷湧現新的領域,而傳統的機器學習需要對每個領域都積累大量訓練資料,這將耗費大量的人力和物力。如果能夠利用現有的訓練資料協助學習新領域的學習任務,將會大幅度減少標註的人力和物力。為了方便使用者快速搭建面向NLP情境的遷移學習模型,PAI團隊推出了深度遷移學習架構EasyTransfer。

使用限制

EasyTransfer僅支援如下Python版本和鏡像版本:

  • Python版本:Python 2.7或Python 3.4及其以上版本。

  • 鏡像版本:選擇官方鏡像tensorflow:1.12PAI-gpu-py36-cu101-ubuntu18.04

步驟一:準備資料

  1. 進入DSW開發環境。

    1. 登入PAI控制台

    2. 在左側導覽列單擊工作空間列表,在工作空間列表頁面中單擊待操作的工作空間名稱,進入對應工作空間內。

    3. 在頁面左上方,選擇使用服務的地區。

    4. 在左側導覽列,選擇模型開發與訓練 > 互動式建模(DSW)

    5. 可選:互動式建模(DSW)頁面的搜尋方塊,輸入執行個體名稱或關鍵字,搜尋執行個體。

    6. 單擊需要開啟的執行個體操作列下的開啟

  2. DSW開發環境,單擊頂部功能表列中的Terminal,按照介面操作指引開啟Terminal。

  3. 在Terminal中,使用如下命令下載Demo資料集。

    wget http://atp-modelzoo-sh.oss-cn-shanghai.aliyuncs.com/tutorial/ez_text_classify/zqkd_sample/train.csv
    wget http://atp-modelzoo-sh.oss-cn-shanghai.aliyuncs.com/tutorial/ez_text_classify/zqkd_sample/dev.csv
    說明

    此處僅使用少量樣本進行示範,您訓練自己的新聞分類模型時,需要使用更多的樣本進行模型訓練。

步驟二:啟動訓練任務(在目前的目錄)

使用如下命令,啟動訓練任務。

easy_transfer_app \
  --mode=train \
  --modelName=text_classify_bert \
  --inputTable="./train.csv,./dev.csv" \
  --inputSchema=content:str:1,label:str:1 \
  --firstSequence=content \
  --labelName=label \
  --labelEnumerateValues="教育,三農,娛樂,健康,美文,搞笑,美食,財經,科技,旅遊,汽車,時尚,科學,文化,房產,熱點,母嬰,家居,體育,國際,育兒,寵物,遊戲,健身,職場,讀書,藝術,動漫" \
  --sequenceLength=128 \
  --checkpointDir=./classify_models \
  --batchSize=64 \
  --numEpochs=3 \
  --optimizerType=adam \
  --learningRate=3e-5 \
  --advancedParameters='\
    pretrain_model_name_or_path=pai-bert-base-zh \
    '

命令中的訓練參數介紹如下表所示。

參數

是否必選

描述

預設值

類型

mode

模式,取值包括:

  • train:訓練

  • evaluate:評估

  • predict:預測

  • export:匯出

STRING

modelName

App模型名稱,支援以下模型:

  • BERT分類,該參數取值為text_classify_bert

  • DGCNN分類,該參數取值為text_classify_dgcnn

  • BERT匹配,該參數取值為text_match_bert

  • BERT雙塔匹配,該參數取值為text_match_bert_two_tower

  • BiCNN模型(雙塔CNN模型),該參數取值為text_match_bicnn

  • HCNN模型,該參數取值為text_match_hcnn

  • DAM模型,該參數取值為text_match_dam

  • DAM+模型,該參數取值為text_match_damplus

  • TextCNN模型,該參數取值為text_classify_cnn

  • BERT閱讀理解,該參數取值為text_comprehension_bert

  • BERT-HAE模型,該參數取值為text_comprehension_bert_hae

  • BERT序列標註,該參數取值為sequence_labeling_bert

text_match_bert

STRING

inputTable

輸入的訓練表,使用英文逗號(,)分隔。例如./train.csv,./dev.csv

STRING

inputSchema

輸入檔案的列Schema,取值格式為列名:類型:長度。其中:

  • 類型的取值包括intstrfloat

  • 長度通常為1。如果某列為英文逗號(,)分隔的數組,則長度為數組的長度。

STRING

firstSequence

第一個文本序列在輸入格式中對應的列名。

STRING

labelName

標籤在輸入格式中對應的列名。

Null 字元串('')

STRING

labelEnumerateValues

標籤枚舉值,支援以下兩種格式:

  • 直接列出標籤枚舉值,且多個枚舉值之間以英文逗號(,)分隔。

  • 取值為一個TXT格式的檔案路徑。該檔案內,多個枚舉值以分行符號分隔。

Null 字元串('')

STRING

sequenceLength

序列整體最大長度,取值範圍1~512。

128

INT

checkpointDir

模型儲存路徑所在目錄。例如./classify_models

STRING

batchSize

訓練時的批處理大小。如果是多卡訓練,則為每個GPU上的批處理大小。

32

INT

numEpochs

訓練總Epoch的數量。

1

INT

optimizerType

最佳化器類型,取值包括:

  • adam

  • lamb

  • adagrad

  • adadeleta

adam

STRING

learningRate

學習率。

2e-5

FLOAT

advancedParameters

其他進階參數,詳情請參見下方的進階參數表格。

不涉及

STRING

關於進階參數的介紹如下表所示。

參數

是否必選

描述

預設值

類型

pretrain_model_name_or_path

預訓練模型。不僅支援EasyTransfer下的所有預訓練模型,也支援使用者自己的預訓練模型OSS地址。

pai-bert-base-zh

STRING

步驟三:評估模型

訓練完成後,您可以使用如下命令測試或評估訓練結果。

easy_transfer_app \
  --mode=evaluate \
  --inputTable=./dev.csv \
  --checkpointPath=./classify_models/model.ckpt-64 \
  --batchSize=10

命令中的參數介紹如下表所示。

參數

是否必選

描述

預設值

類型

mode

模式,取值包括:

  • train:訓練

  • evaluate:評估

  • predict:預測

  • export:匯出

STRING

inputTable

輸入的評估表,使用英文逗號(,)分隔。例如./dev.csv

重要

評估集的列Schema必須與訓練集的保持一致。

STRING

checkpointPath

模型CKPT儲存路徑所在的目錄。例如./classify_models/model.ckpt-32

STRING

batchSize

評估時的批處理大小。如果是多卡情境,則為每個GPU上的批處理大小。

32

INT

步驟四:預測模型

訓練完成後,您可以使用如下命令對檔案(可以沒有標籤)進行預測。

easy_transfer_app \
  --mode=predict \
  --inputSchema=content:str:1,label:str:1 \
  --inputTable=dev.csv \
  --outputTable=dev.pred.csv \
  --firstSequence=content \
  --appendCols=label \
  --outputSchema=predictions,probabilities,logits \
  --checkpointPath=./classify_models/ \
  --batchSize=100

命令中的參數介紹如下表所示。

參數

是否必選

描述

預設值

類型

mode

模式,取值包括:

  • train:訓練

  • evaluate:評估

  • predict:預測

  • export:匯出

STRING

inputTable

輸入的待預測表。例如./dev.csv

STRING

outputTable

預測結果的輸出表。例如./dev.pred.csv

STRING

inputSchema

輸入檔案的列Schema,取值格式為列名:類型:長度。其中:

  • 類型的取值包括intstrfloat

  • 長度通常為1。如果某列為英文逗號(,)分隔的數組,則長度為數組的長度。

STRING

firstSequence

第一個文本序列在輸入格式中對應的列名。

STRING

appendCols

輸入表中需要添加到輸出表的列。

Null 字元串('')

STRING

outputSchema

選擇輸出資料中需要的預測值,多個選擇項之間以英文逗號(,)分隔。支援以下三種格式:

  • predictions:對於單標籤模型,輸出相應類型的ID,其中ID與訓練時的labelEnumerateValue順序對應。對於多標籤模型,輸出multi-hot的向量,且使用英文逗號(,)分隔。

  • probabilities:輸出每一個類的機率,多個類之間使用英文逗號(,)分隔。

  • logits:輸出每一個類的Logit值,多個類之間使用英文逗號(,)分隔。

predictions

STRING

checkpointPath

模型儲存路徑所在目錄。例如./bert_classify_models

STRING

batchSize

訓練時的批處理大小。如果是多卡訓練,則為每個GPU上的批處理大小。

32

INT

步驟五:匯出模型並線上部署EAS服務

  1. 匯出模型。

    訓練結束後,預設會匯出最後一個Checkpoint產生的variables和saved_model.pb檔案。如果您需要匯出其他Checkpoint的訓練結果,則可以使用如下命令。

    easy_transfer_app \
      --mode=export \
      --exportType=app_model \
      --checkpointPath=./classify_models/model.ckpt-64 \
      --exportDirBase=./export_model \
      --batchSize=100

    命令中的參數介紹如下表所示。

    參數

    是否必選

    描述

    預設值

    類型

    mode

    模式,取值包括:

    • train:訓練

    • evaluate:評估

    • predict:預測

    • export:匯出

    STRING

    exportType

    匯出的類型,取值包括:

    • app_model: 匯出Finetune模型。

    • ez_bert_feat:匯出文本向量化組件所需模型。

    STRING

    checkpointPath

    模型CKPT儲存路徑所在的目錄。

    STRING

    exportDirBase

    匯出模型的目錄。

    STRING

    batchSize

    評估時的批處理大小。如果是多卡情境,則為每個GPU上的批處理大小。

    32

    INT

  2. 打包模型檔案。

    打包輸出目錄中的variables、saved_model.pb、vocab.txt及定義使用者輸入的label_mapping檔案。例如本文中新聞分類的label_mapping檔案為label_mapping.json,該檔案中的標籤ID必須為INT類型,且順序與訓練時的labelEnumerateValues參數的順序一致。label_mapping.json的內容樣本如下。

    {"教育": 0,
     "三農": 1,
     ...,
     "動漫": 27}

    您也可以從訓練指定的checkpointDir目錄下找到label_mapping.json檔案。

    打包得到的檔案如下所示。打包的模型檔案

  3. 上傳模型檔案至OSS,得到模型的OSS地址。例如oss://xxx/your_model.zip

  4. 部署模型,詳情請參見EasyTransfer Processor