torchinfo is currently one of the more useful tools on GitHub for computing and analyzing the computational cost of PyTorch models, and the project is still actively maintained. It prints a visual summary of the network structure, which makes architecture debugging much easier.
Usage demo
from torchinfo import summary

model = ConvNet()  # ConvNet: the user's own model, assumed defined elsewhere
batch_size = 16
summary(model, input_size=(batch_size, 1, 28, 28))
Output:
==========================================================================================
Layer (type:depth-idx)                   Output Shape              Param #
==========================================================================================
├─Conv2d: 1-1                            [16, 10, 24, 24]          260
├─Conv2d: 1-2                            [16, 20, 8, 8]            5,020
├─Dropout2d: 1-3                         [16, 20, 8, 8]            --
├─Linear: 1-4                            [16, 50]                  16,050
├─Linear: 1-5                            [16, 10]                  510
==========================================================================================
Total params: 21,840
Trainable params: 21,840
Non-trainable params: 0
Total mult-adds (M): 7.69
==========================================================================================
Input size (MB): 0.05
Forward/backward pass size (MB): 0.91
Params size (MB): 0.09
Estimated Total Size (MB): 1.05
==========================================================================================
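The Param # column can be sanity-checked by hand: a Conv2d layer with bias has out_channels * (in_channels * kH * kW + 1) parameters, and a Linear layer has out_features * (in_features + 1). Note the 5x5 kernel sizes and the pooled 20x4x4 feature map below are inferred from the output shapes in the table, not stated in the original post:

```python
# Per-layer parameter counts, using the shapes from the table above.
# Conv2d (with bias): out_ch * (in_ch * kH * kW + 1)
# Linear (with bias): out_features * (in_features + 1)
conv1 = 10 * (1 * 5 * 5 + 1)      # 260
conv2 = 20 * (10 * 5 * 5 + 1)     # 5,020
fc1 = 50 * (20 * 4 * 4 + 1)       # 16,050 (pooled feature map flattened to 320)
fc2 = 10 * (50 + 1)               # 510

total = conv1 + conv2 + fc1 + fc2
print(total)  # 21840 -- matches "Total params"

# Params size (MB): one float32 parameter takes 4 bytes
print(round(total * 4 / 1e6, 2))  # 0.09 -- matches "Params size (MB)"
```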
Features supported in the current version
- RNNs, LSTMs, and other recursive layers
- Sequential containers and ModuleLists
- Branching output to explore model layers at a specified depth
- Returns a ModelStatistics object containing all summary data fields
- Configurable columns
- Jupyter Notebook / Google Colab support
Other features
- Verbose mode to show weight and bias layers
- Accepts either input data or just the input shape
- Customizable line widths and batch dimension
- Comprehensive unit/output testing, linting, and code-coverage testing
Example
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchinfo import summary

class MultipleInputNetDifferentDtypes(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1a = nn.Linear(300, 50)
        self.fc1b = nn.Linear(50, 10)
        self.fc2a = nn.Linear(300, 50)
        self.fc2b = nn.Linear(50, 10)

    def forward(self, x1, x2):
        x1 = F.relu(self.fc1a(x1))
        x1 = self.fc1b(x1)
        # x2 arrives as torch.long and is cast to float before the linear layers
        x2 = x2.type(torch.float)
        x2 = F.relu(self.fc2a(x2))
        x2 = self.fc2b(x2)
        x = torch.cat((x1, x2), 0)
        return F.log_softmax(x, dim=1)

model = MultipleInputNetDifferentDtypes()
# dtypes supplies one dtype per input: x1 is float, x2 is long
summary(model, input_size=[(1, 300), (1, 300)], dtypes=[torch.float, torch.long])