Introduction to torch.nn
Hey, world! Sup? I am super excited!
I just heard Edward's podcast on torch.nn, and I am eager to share the things I learned with you all through this blog post.
Before that, let’s take a look at the things I will cover in the blog.
- Implementation of torch.nn
- Importance of the NN module
- The role of Parameter
- How modules track the parameters provided by the user
- A top-level view of the NN module class
- New developments in the torch.nn module
- Drawbacks of NN module
Let’s begin!
PyTorch's torch.nn module is the public API designed for neural networks; it provides the functionality for implementing them. I hope you know 😉. Its defining feature is the abstraction it delivers over PyTorch's lower-level machinery.
What's that? Abstraction means hiding internal functionality from the users: they can use all the parts required to build neural networks without ever knowing how PyTorch's torch.nn module works internally. Contrast this with JAX, where the parts are plain functions and you pass the parameters around explicitly yourself, which gets annoying when the model grows too big.
Hence, torch.nn provides an object-oriented interface that automatically collects all the parameters for us. Isn't it interesting? 🤩
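To see this in action, here is a minimal sketch (the class name SimpleNet and the layer sizes are my own, just for illustration):

```python
import torch
import torch.nn as nn

class SimpleNet(nn.Module):
    def __init__(self):
        super().__init__()
        # Submodules and Parameters assigned as attributes get registered.
        self.linear = nn.Linear(4, 2)
        self.scale = nn.Parameter(torch.ones(2))

    def forward(self, x):
        return self.linear(x) * self.scale

net = SimpleNet()
# All registered parameters are collected with no manual bookkeeping.
for name, p in net.named_parameters():
    print(name, tuple(p.shape))
# linear.weight (2, 4), linear.bias (2,), scale (2,)
```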
Another important thing about the torch.nn module is that the developers have not moved its implementation to C++ (even knowing C++ is faster). To keep it hackable and to give users an easier interface, the code is still in Python. But this complicates the module :(
Well, well, these complications can be handled with a few common tricks (:
- How do we know whether something is a parameter or not? Modules can collect parameters! A module keeps track of every nn.Parameter assigned to it, but ignores plain tensors that are not wrapped as parameters.
- The module notices whenever we modify it: __setattr__ and __getattr__ are the two methods that do this bookkeeping.
- Find everything stored in a module: not only parameters but also buffers. For example, .cuda() is built on the _apply() function, which visits each parameter and buffer in turn.
- Modules implement hooks. Hooks are just a way of interposing behavior on modules without manually writing it into them. How is this implemented? Every module has a method named forward, but we don't call forward directly; instead we call the module itself (its __call__ operator), which runs the registered hooks (forward pre-hooks, forward hooks, backward hooks, etc.). These hooks take care of all the administrative services needed before the actual forward implementation runs. (See the sketch after this list.)
Goop in the NN module (three important things):
- Keeping track of parameters
- Keeping track of __setattr__ and __getattr__
- Hooks: interposing on the behavior of a module call
Other important things:
- Serialization: once we have trained our modules, we have to dump all the parameters and the trained model so we can use them in the future. That is serialization, and we use state_dict() to serialize a module. We should also ensure that any extra state we add is pickleable, to make sure it keeps working.
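A minimal sketch of the round trip (the file name model.pt is arbitrary):

```python
import torch
import torch.nn as nn

model = nn.Linear(4, 2)
# state_dict() gathers all parameters and buffers into an ordered dict.
torch.save(model.state_dict(), "model.pt")

# Later: rebuild the module and load the serialized state back into it.
restored = nn.Linear(4, 2)
restored.load_state_dict(torch.load("model.pt"))
```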
Another problem while writing modules in PyTorch: all the modules should be torch scriptable. What's TorchScript? TorchScript is a compiler for PyTorch modules. Its restrictions:
- Limitations on the dtypes you can use (it doesn't support arbitrary types).
- The Python code used inside the forward function must be understandable by the TorchScript compiler.
Unconventional things in TorchScript:
- It's a staged computation: it instantiates the module like normal Python, then compiles the forward implementation. The benefit: since __init__ runs as normal Python, we can go wild and do anything we want there without restriction.
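Here is a small sketch of that staged flow (the class Net and the sizes are mine, for illustration):

```python
import torch
import torch.nn as nn

class Net(nn.Module):
    def __init__(self, depth: int):
        super().__init__()
        # __init__ runs as ordinary Python: loops, configs, anything goes.
        self.layers = nn.ModuleList(nn.Linear(8, 8) for _ in range(depth))

    def forward(self, x):
        # Only forward() must stay within what TorchScript understands.
        for layer in self.layers:
            x = torch.relu(layer(x))
        return x

scripted = torch.jit.script(Net(depth=3))  # compiles forward, not __init__
print(scripted(torch.randn(2, 8)).shape)   # torch.Size([2, 8])
```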
New Developments in NN Module
- Concept of Lazy Modules
- We can allocate a big module directly on CUDA by passing the device keyword, which modules in PyTorch now accept. (A sketch of both ideas follows below.)
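A quick sketch of both developments (the sizes and device string are arbitrary examples):

```python
import torch
import torch.nn as nn

# Lazy module: in_features is inferred from the first input it sees.
lazy = nn.LazyLinear(out_features=4)
lazy(torch.randn(2, 16))      # in_features becomes 16 here
print(lazy.weight.shape)      # torch.Size([4, 16])

# Device keyword: build the parameters directly on the GPU,
# skipping the CPU allocation + copy (assuming CUDA is available).
if torch.cuda.is_available():
    big = nn.Linear(1024, 1024, device="cuda")
    print(big.weight.device)  # cuda:0
```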
TODO: Learn about meta tensors.
Drawbacks of the NN module
- The need for functional versions of modules (see the sketch after this list).
- The initialization schemes in PyTorch are out of date.
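For context, here is what the module/functional split looks like for the operations that do have both forms (the shapes are arbitrary):

```python
import torch
import torch.nn.functional as F

x = torch.randn(2, 4)
w = torch.randn(3, 4)
b = torch.zeros(3)

# The functional form takes the parameters explicitly instead of storing
# them on an object -- this is what nn.Linear calls under the hood.
y = F.linear(x, w, b)
print(y.shape)  # torch.Size([2, 3])
```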
Further Reading
- Implementation of nn.Module: https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/module.py
Acknowledgement
I want to acknowledge Edward Z. Yang for his fantastic podcast, which helped me get an overview of the torch.nn module. Thank you, Edward!
Let me know what you think of this article on Twitter @khushi__411 or leave a comment below!