当前位置 博文首页 > 文章内容

    pytorch将部分参数进行加载

    作者: 栏目:未分类 时间:2020-09-27 10:00:48

    本站于2023年9月4日。收到“大连君*****咨询有限公司”通知
    说我们IIS7站长博客,有一篇博文用了他们的图片。
    要求我们给他们一张图片6000元。要不然法院告我们

    为避免不必要的麻烦,IIS7站长博客,全站内容图片下架、并积极应诉
    博文内容全部不再显示,请需要相关资讯的站长朋友到必应搜索。谢谢!

    另祝:版权碰瓷诈骗团伙,早日弃暗投明。

    相关新闻:借版权之名、行诈骗之实,周某因犯诈骗罪被判处有期徒刑十一年六个月

    叹!百花齐放的时代,渐行渐远!



    参考:
    https://blog.csdn.net/LXX516/article/details/80124768

    示例代码:

    1. 加载相同名称的模块
    pretrained_dict=torch.load(model_weight)
    model_dict=myNet.state_dict()
    # 1. filter out unnecessary keys
    pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
    # 2. overwrite entries in the existing state dict
    model_dict.update(pretrained_dict)
    myNet.load_state_dict(model_dict)
    
    
    1. 直接赋值
    pretrained_dict = torch.load('pre_model/best.pt')
    model_dict = self.get_model().state_dict()
    model_dict['inBlock.0.0.weight'][:,0:10,:,:] = pretrained_dict['inBlock.0.0.weight'][:,0:10,:,:]
    self.get_model().load_state_dict(model_dict)