このセクションでは、OssCheckpointを使用して、OSSバケットから直接チェックポイント (特定の時点でのモデルトレーニング中に保存されたモデルステータス) を読み書きする方法について説明します。
前提条件
OSS Connector for AI/MLがインストールおよび設定されています。 詳細については、「AI/ML用OSSコネクタのインストール」および「AI/ML用OSSコネクタの設定」をご参照ください。
OssCheckpoint
OssCheckpointは、データトレーニングプロセス中に結果を読み書きするシナリオに適しています。
次の例は、OssCheckpointを使用してチェックポイントを読み書きする方法を示しています。
import torch
from osstorchconnector import OssCheckpoint
ENDPOINT = "endpoint"
CRED_PATH = "/root/.alibabacloud/credentials"
CONFIG_PATH = "/etc/oss-connector/config.json"
# Create a checkpoint by using OssCheckpoint
checkpoint = OssCheckpoint(endpoint=ENDPOINT, cred_path=CRED_PATH, config_path=CONFIG_PATH)
# Read the checkpoint
CHECKPOINT_READ_URI = "oss://checkpoint/epoch.0"
with checkpoint.reader(CHECKPOINT_READ_URI) as reader:
state_dict = torch.load(reader)
# Write the checkpoint
CHECKPOINT_WRITE_URI = "oss://checkpoint/epoch.1"
with checkpoint.writer(CHECKPOINT_WRITE_URI) as writer:
torch.save(state_dict, writer)
データ型
OssCheckpoint providecommon I/O操作を使用して作成されたチェックポイントオブジェクト。 詳細については、「OSS Connector For AI/MLのデータ型」をご参照ください。
パラメーター
次の表に、OssCheckpointを使用するときに設定する必要があるパラメーターを示します。
パラメーター | タイプ | 必須 | 説明 |
endpoint | String | 必須 | OSSへのアクセスに使用されるエンドポイント。 詳細については、「エンドポイントとデータセンター」をご参照ください。 |
cred_path | String | 必須 | 認証ファイルのパス。デフォルト値は |
config_path | String | 必須 | OSS Connector設定ファイルのパス。デフォルト値は |