Python

Pytorch 딥러닝 model 저장하기

qlsenddl 2020. 12. 22. 17:04
728x90

 pytorch로 AI 딥러닝 모델을 학습 시키는 경우, 학습시킨 모델을 저장하고 다시 불러와서 사용할 필요가 생긴다. 해당 코드 내에서 다시 돌리는 경우 모델의 parameter들이 저장되기 때문에 따로 저장할 필요가 없다. 하지만 코드를 다시 열어서 실행하거나 학습한 모델을 다른 코드에서 사용하거나 다른 사람과 공유해야 하는 경우 모델을 저장한 후 다시 불러와야 한다.


 이에 pytorch에서 딥러닝 모델을 저장하는 방법이 다음 코드에 설명되어 있다.

https://tutorials.pytorch.kr/beginner/saving_loading_models.html

 여기에서는 해당 설명 중에서 다시 학습을 하지 않을 때 모델을 저장하는 방법에 대해서 알아보고자 한다. 여기에서는 간략히 개념만 이해하고 바로 코드를 쓸 수 있게 하기 위함이기 때문에 더 자세한 설명이 필요하면 위 링크를 참고하길 바란다.


일러두기1: model의 파일 확장자는 '.pth'이고, 아래에서 나오는 PATH에는 이 확장자까지 포함해서 입력해주어야 한다.

일러두기2: 'ex>'에서 model_dir = 'C:/Users/user/Desktop/'과 같이 이미 선언됐다고 생각하자.(마지막에 /로 끝난 것이 포인트)


1. 전체 model 저장 및 불러오기

저장: torch.save(model, PATH) -> 현재 학습 중인 model을 PATH에 해당하는 directory 및 파일 이름으로 저장

불러오기: torch.load(PATH) -> PATH가 가리키는 directory 및 파일 이름에 해당하는 model을 불어옴


ex1> 저장

torch.save(model, model_dir + 'model_name{}.pth'.format(epoch))


ex2> 불러오기

model = torch.load(model_dir + 'model_name1.pth')

model.eval()


2. model의 parameter만 저장 및 불러오기

 state_dict라는 객체를 통해 저장한다. state_dict에 대한 설명은 위 링크에 설명되어 있기는 한데, 매우 간단하게 딥러닝 model에 대한 parameter 정보를 담고 있는 객체라고 생각하면 된다.

저장: torch.save(model.state_dict(), PATH) -> model의 parameter 저장

불러오기: model.load_state_dict(torch.load(PATH)) -> model의 parameter 불러오기


ex1> 저장

torch.save(model.state_dict(), 'model_name{}.pth'.format(epoch))


ex2> 불러오기

model = DeepLearningClass()

model.load_state_dict(torch.load(model_dir + 'model_name1.pth'))

model.eval()


※ 링크에 있는 설명을 보면 1번과 같이 전체 model을 저장하고 불러오는 경우 다양한 이유로 불러오기가 안될 수 있다고 설명한다. 하지만 나의 경우, 오히려 2번과 같이 parameter만 저장하고 불러오는 경우 더 많은 에러가 발생했다. 때문에 보통 모델을 저장할 때 위의 2가지 방법으로 모두 model을 저장해놓고, 보통은 1번 방식을 통해 model을 불러오고 에러가 발생하면 그 때 2번 방식으로 불러오기를 시도한다.


3. DataParallel을 적용한 model 저장 및 불러오기

 학습하는 model의 용량이 많아서 병렬 GPU로 계산하는 경우 DataParallel로 학습을 시키게 된다. 이 경우 1번 방식과 같이 model을 저장하고 불러오는 경우 에러가 발생할 수 있다. 때문에 병렬 GPU로 학습한 경우 model을 저장할 때는 아래와 같이 한다.

저장: torch.save(model.module.state_dict(), PATH)

불러오기: 2번의 불러오기와 동일


※ 하지만 DataParallel로 학습했는데 1번과 같이 저장했어도 불러오는 것은 가능하다. 방법은 아래와 같다.


model = torch.load(PATH)

model.eval()

output = model.module(input)


 print(model)을 해보면 알겠지만 해당 모델이 module 안에 원하는 학습 model이 저장되어있기 때문에 model.module을 통해서 원하는 학습 model로 접근하는 것이 가능하다. 그러므로 DataParallel로 학습 시 1번과 같이 저장했어도, 1번과 같이 불러오기를 한 후 model을 사용할 때 model.module로 입력해주면 원하는 학습 model을 사용할 수 있게 된다.

728x90