本文為您介紹如何使用OssCheckpoint直接從OSS中讀寫檢查點(模型訓練過程中儲存的特定時間點的模型狀態)。
前提條件
已安裝並配置OSS Connector for AI/ML。具體操作,請參見安裝OSS Connector for AI/ML和配置OSS Connector for AI/ML。
OssCheckpoint
OssCheckpoint適用於資料訓練過程中對訓練結果進行讀寫需求的情境。
以下樣本展示了如何使用OssCheckpoint來進行Checkpoint的讀取和寫入。
import torch
from osstorchconnector import OssCheckpoint
ENDPOINT = "endpoint"
CRED_PATH = "/root/.alibabacloud/credentials"
CONFIG_PATH = "/etc/oss-connector/config.json"
# 使用OssCheckpoint建立checkpoint
checkpoint = OssCheckpoint(endpoint=ENDPOINT, cred_path=CRED_PATH, config_path=CONFIG_PATH)
# 讀 checkpoint
CHECKPOINT_READ_URI = "oss://checkpoint/epoch.0"
with checkpoint.reader(CHECKPOINT_READ_URI) as reader:
state_dict = torch.load(reader)
# 寫 checkpoint
CHECKPOINT_WRITE_URI = "oss://checkpoint/epoch.1"
with checkpoint.writer(CHECKPOINT_WRITE_URI) as writer:
torch.save(state_dict, writer)
資料類型
通過OssCheckpoint建立的checkpoint對象實現了常用的IO介面。更多資訊,請參見OSS Connector for AI/ML中的資料類型。
參數配置
使用OssCheckpoint時需要進行相應配置,具體配置項說明請參見下表。
參數名 | 參數類型 | 是否必選 | 說明 |
endpoint | string | 是 | OSS對外服務的訪問網域名稱。更多資訊,請參見訪問網域名稱和資料中心。 |
cred_path | string | 是 | 鑒權檔案預設路徑為 |
config_path | string | 是 | OSS Connector設定檔預設路徑為 |