Python/Pytorch

[Pytorch] Pytorch 에서 모델 요약 보기 with torchsummary

언킴 2022. 4. 18. 18:35
반응형

tensorflow 같은 경우에는 model.summary()를 통해서 간편하게 모델의 정보를 확인할 수 있다. 그렇다면 PyTorch에서는 불가능할까? 아니다!! Torch도 torchsummary를 지원해주고 있기에, torchsummary를 설치하면 확인할 수 있다. 

 

 

확인하기에 앞서 우리는 이런 구조의 간단한 CNN을 만들 수 있을 것이다. 그런다음 아래의 코드를 실행하면 tensorflow의 model.summary()와 유사한 그림을 가지고 올 수 있다. 

 

from torchsummary import summary​

 

각 layer의 output shape와 각 parameter에 대한 정보도 다 나온다. 또한 용량까지 나오니 엄청나게 좋다..! 한가지 단점이 있다면 아래와 같이 model의 input과 batch size를 입력해주어야 된다는 것이다.. 

model = CNN(feature)
summary(model, (3, 32, 32), batch_size=64)

# torch는 tensorflow와는 달리 (channel, height, width)이다. 
# (batch, channel, height, width) -> PyTorch
# (batch, height, width, channel) -> Tensorflow