本文为您介绍如何使用OssCheckpoint直接从OSS中读写检查点(模型训练过程中保存的特定时间点的模型状态)。
前提条件
已安装并配置OSS Connector for AI/ML。具体操作,请参见安装OSS Connector for AI/ML和配置OSS Connector for AI/ML。
OssCheckpoint
OssCheckpoint适用于数据训练过程中对训练结果进行读写需求的场景。
以下示例展示了如何使用OssCheckpoint来进行Checkpoint的读取和写入。
import torch
from osstorchconnector import OssCheckpoint
ENDPOINT = "endpoint"
CRED_PATH = "/root/.alibabacloud/credentials"
CONFIG_PATH = "/etc/oss-connector/config.json"
# 使用OssCheckpoint创建checkpoint
checkpoint = OssCheckpoint(endpoint=ENDPOINT, cred_path=CRED_PATH, config_path=CONFIG_PATH)
# 读 checkpoint
CHECKPOINT_READ_URI = "oss://checkpoint/epoch.0"
with checkpoint.reader(CHECKPOINT_READ_URI) as reader:
state_dict = torch.load(reader)
# 写 checkpoint
CHECKPOINT_WRITE_URI = "oss://checkpoint/epoch.1"
with checkpoint.writer(CHECKPOINT_WRITE_URI) as writer:
torch.save(state_dict, writer)
数据类型
通过OssCheckpoint创建的checkpoint对象实现了常用的IO接口。更多信息,请参见OSS Connector for AI/ML中的数据类型。
参数配置
使用OssCheckpoint时需要进行相应配置,具体配置项说明请参见下表。
参数名 | 参数类型 | 是否必选 | 说明 |
endpoint | string | 是 | OSS对外服务的访问域名。更多信息,请参见访问域名和数据中心。 |
cred_path | string | 是 | 鉴权文件默认路径为 |
config_path | string | 是 | OSS Connector配置文件默认路径为 |