用于计算分析Pytorch网络结构计算量的工具torchinfo

torchinfo是目前github上比较好用,而且项目还很活跃,用于计算分析基于Pytorch的网络计算量的工具,它可以对网络结构进行可视化显示,方便进行网络结构调试。

使用演示

from torchinfo import summary

model = ConvNet()
batch_size = 16
summary(model, input_size=(batch_size, 1, 28, 28))

结果:

==========================================================================================
Layer (type:depth-idx)                   Output Shape              Param #
==========================================================================================
├─Conv2d: 1-1                            [16, 10, 24, 24]          260
├─Conv2d: 1-2                            [16, 20, 8, 8]            5,020
├─Dropout2d: 1-3                         [16, 20, 8, 8]            --
├─Linear: 1-4                            [16, 50]                  16,050
├─Linear: 1-5                            [16, 10]                  510
==========================================================================================
Total params: 21,840
Trainable params: 21,840
Non-trainable params: 0
Total mult-adds (M): 7.69
==========================================================================================
Input size (MB): 0.05
Forward/backward pass size (MB): 0.91
Params size (MB): 0.09
Estimated Total Size (MB): 1.05
==========================================================================================

目前版本支持特性

  • RNNs, LSTMs, and other recursive layers 支持RNNs,LSTMs,以及其他递归层。
  • Sequentials & Module Lists 支持序列或者list模型
  • Branching output used to explore model layers using specified depths。支持特定分支深度的网络结构输出
  • Returns ModelStatistics object containing all summary data fields。返回所有模型统计数据
  • Configurable columns。可自由配置需要返回的模型参数统计
  • Jupyter Notebook / Google Colab。支持jupyter和Colab.

其他特性

  • Verbose mode to show weights and bias layers. 显示权重和偏置层
  • Accepts either input data or simply the input shape! 接受输入数据或仅接受输入形状
  • Customizable line widths and batch dimension。可自定义的线宽和batch维度
  • Comprehensive unit/output testing, linting, and code coverage testing。全面的单元输出测试。

使用例子

class MultipleInputNetDifferentDtypes(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1a = nn.Linear(300, 50)
        self.fc1b = nn.Linear(50, 10)

        self.fc2a = nn.Linear(300, 50)
        self.fc2b = nn.Linear(50, 10)

    def forward(self, x1, x2):
        x1 = F.relu(self.fc1a(x1))
        x1 = self.fc1b(x1)
        x2 = x2.type(torch.float)
        x2 = F.relu(self.fc2a(x2))
        x2 = self.fc2b(x2)
        x = torch.cat((x1, x2), 0)
        return F.log_softmax(x, dim=1)

summary(model, [(1, 300), (1, 300)], dtypes=[torch.float, torch.long])

详细请访问项目:https://github.com/TylerYep/torchinfo

给TA买糖
共{{data.count}}人
人已赞赏
AI框架文章

第一视角:深度学习框架这几年

2021-4-27 10:30:17

文章

PyTorch 中模型的可复现性

2021-4-29 22:36:51

0 条回复 A文章作者 M管理员
    暂无讨论,说说你的看法吧
个人中心
购物车
优惠劵
今日签到
搜索