栏目分类:
子分类:
返回
终身学习网用户登录
快速导航关闭
当前搜索
当前分类
子分类
实用工具
热门搜索
终身学习网 > IT > 软件开发 > 后端开发 > Python

关于存储和加载模型权重

Python 更新时间:发布时间: 百科书网 趣学号
        checkpoint = load_checkpoint(args.resume)
        model_dict = model.state_dict()
        checkpoint_load = {k: v for k, v in (checkpoint['state_dict']).items() 
                            if k in model_dict}
        model_dict.update(checkpoint_load)
        model.load_state_dict(model_dict)
        start_epoch = checkpoint['epoch']
        best_top1 = checkpoint['best_top1']
        print("=> Start epoch {}  best top1 {:.1%}".format(start_epoch, best_top1))

1.首先先读取arg.sume(已存储的权重)到checkpoint,相当于字典

2.再读取模型中的参数权重到model_dict

3.将checkpoint中key值对应model_dict的数据加载到checkpoint_load中

4.将已经训练好的模型参数更新并加载到已有模型参数中(单卡)

5.再读取checkpoint中的其他参数,以此类推

model.module.load_state_dict(checkpoint['state_dict'])

 加载模型参数(多卡)

torch.save(model.state_dict(), model_out_path)

存储模型参数(单卡)

torch.save(state, fpath)
save_checkpoint({
                'state_dict': model.module.state_dict(),
                'epoch': epoch + 1,
                'best_top1': best_top1
                })

存储模型参数(多卡)以及其他信息,

转载请注明:文章转载自 www.051e.com
本文地址:http://www.051e.com/it/268532.html
我们一直用心在做
关于我们 文章归档 网站地图 联系我们

版权所有 ©2023-2025 051e.com

ICP备案号:京ICP备12030808号