Introduction to torch.nn


Hey, world! Sup? I am super excited! I just heard Edward’s podcast on torch.nn, and I am eager to share the things I learned with you all through this blog post. Before that, let’s take a look at the things I will cover:

  • The implementation of torch.nn
  • The importance of the nn module
  • The function of Parameter
  • How modules keep track of the parameters provided by the user
  • The top level of the nn.Module class
  • New developments in the torch.nn module
  • Drawbacks of the nn module

Let’s begin! PyTorch is a library designed for deep learning, and torch.nn is the foundational module that provides the building blocks for implementing neural networks. I hope you know 😉. The nn module is known for delivering abstraction. What’s that? Abstraction in Python means hiding the internal workings of something from its users. The user can use all the provided building blocks required for neural networks while staying unaware of the internal functioning of PyTorch’s torch.nn module.
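
Here’s a minimal sketch of that abstraction in action (the layer sizes are made-up numbers for illustration): we use nn.Linear as a ready-made building block without ever looking at how it creates and stores its weights internally.

```python
import torch
import torch.nn as nn

# A ready-made building block: we don't need to know how it
# initializes or stores its weight and bias internally.
layer = nn.Linear(4, 2)

x = torch.randn(3, 4)   # a batch of 3 inputs with 4 features each
out = layer(x)          # calling the module runs its forward pass
print(out.shape)        # torch.Size([3, 2])
```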

This is unlike JAX, which takes a functional approach where you pass all the parameters around explicitly, and that gets annoying when the model gets too big. Hence, torch.nn provides an object-oriented interface that automatically collects all the parameters for us. Isn’t it interesting? 🤩
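
Here’s a small sketch of that automatic parameter collection (TinyNet and its layer sizes are made up for illustration):

```python
import torch
import torch.nn as nn

class TinyNet(nn.Module):
    def __init__(self):
        super().__init__()
        # Assigning submodules registers their parameters automatically.
        self.fc1 = nn.Linear(8, 16)
        self.fc2 = nn.Linear(16, 1)

    def forward(self, x):
        return self.fc2(torch.relu(self.fc1(x)))

net = TinyNet()
# No manual bookkeeping: the module collected every parameter for us,
# ready to hand to an optimizer.
for name, p in net.named_parameters():
    print(name, tuple(p.shape))
# fc1.weight (16, 8), fc1.bias (16,), fc2.weight (1, 16), fc2.bias (1,)
```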

Another important thing about the torch.nn module is that the developers have not moved its code to C++ (even though C++ would be faster). To keep it hackable and give users an easier interface to read and modify, the code is still in Python. But this complicates the module :(

Well, well, this complexity is handled with a few common tricks (:

  • How do we know if something is a parameter or not? Modules can collect parameters! A module keeps track of the nn.Parameter objects assigned to it, but ignores plain tensors that are not wrapped in nn.Parameter (see the sketch after this list).
  • The module keeps track of the parameters as we modify it: __setattr__ and __getattr__ are the two methods that intercept attribute assignment and lookup to do this bookkeeping.
  • Some operations need to find everything stored on a module, which means not only parameters but also buffers. Example: .cuda(). The internal _apply() function walks over each parameter and buffer.
  • Modules implement hooks. Hooks are a way of interposing behavior on modules without manually writing it into each one. How is this implemented? Every module class has a function named forward, but we don’t call the forward function directly; instead we call the module itself, and the __call__ operator runs the registered hooks (forward pre-hooks, forward hooks, backward hooks, etc.). These hooks take care of all the administrative services needed before actually calling the forward implementation.
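
Here is a small sketch tying these tricks together (the names Demo and log_shapes are made up for illustration): nn.Parameter registration through __setattr__, a plain tensor being ignored, a buffer that _apply()-based methods like .cuda() will move, and a forward hook interposed on the module call.

```python
import torch
import torch.nn as nn

class Demo(nn.Module):
    def __init__(self):
        super().__init__()
        # __setattr__ sees this is an nn.Parameter and registers it.
        self.weight = nn.Parameter(torch.randn(3, 3))
        # A plain tensor is NOT registered as a parameter.
        self.scratch = torch.zeros(3)
        # Buffers are registered explicitly; _apply()-based methods
        # like .cuda() and .to() move them along with the parameters.
        self.register_buffer("running_sum", torch.zeros(3))

    def forward(self, x):
        return x @ self.weight

m = Demo()
print([name for name, _ in m.named_parameters()])  # ['weight'] only

# A hook interposes on the module call without touching forward itself:
# m(x) goes through __call__, which runs hooks around forward.
def log_shapes(module, inputs, output):
    print("forward ran:", inputs[0].shape, "->", output.shape)

handle = m.register_forward_hook(log_shapes)
m(torch.randn(2, 3))   # the hook prints, then the output is returned
handle.remove()
```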

Goop in the NN module (three important things):

  • Keeping track of parameters
  • Keeping track of attribute changes via __setattr__ and __getattr__
  • Hooks: customizing the behavior of a module call

Other important things:

  • Serialization: Once we have trained all the modules, we will have to dump all the parameters and the trained model to use them in the future. For this we use serialization. A module provides state_dict() to serialize its state. If we add extra state to a module, we should make sure it is pickleable so that serialization keeps working (see the sketch below).
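
A minimal sketch of that round trip (the file name net.pt is hypothetical):

```python
import torch
import torch.nn as nn

net = nn.Linear(4, 2)

# Serialize: state_dict() collects all parameters and buffers.
torch.save(net.state_dict(), "net.pt")

# Later: rebuild the module and load the trained state back in.
net2 = nn.Linear(4, 2)
net2.load_state_dict(torch.load("net.pt"))
```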

Another problem while writing modules in PyTorch: all the modules have to be TorchScriptable. What’s TorchScript?

TorchScript is a compiler for PyTorch modules. Restrictions of TorchScript:

  • Limitations on the dtypes you can use (it doesn’t support arbitrary types).
  • The Python code used inside the forward function should be understandable by the TorchScript compiler.

Unconventional things in TorchScript:

  • It’s a staged computation. It instantiates a module like normal Python, and then compiles the forward implementation. The benefit (see the sketch below):
    • Since the module is initialized as normal Python, we can go wild in __init__ and do anything we want without restriction.
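
A small sketch of that staging (Stacked is a made-up module): __init__ can use arbitrary Python, such as a loop that builds layers, while only forward has to satisfy the compiler.

```python
import torch
import torch.nn as nn

class Stacked(nn.Module):
    def __init__(self, depth):
        super().__init__()
        # Stage 1: plain Python. Arbitrary logic is fine here because
        # __init__ runs eagerly and is never compiled.
        self.layers = nn.ModuleList([nn.Linear(8, 8) for _ in range(depth)])

    def forward(self, x):
        # Stage 2: only this code must be understandable by TorchScript.
        for layer in self.layers:
            x = torch.relu(layer(x))
        return x

scripted = torch.jit.script(Stacked(3))    # compiles forward
print(scripted(torch.randn(2, 8)).shape)   # torch.Size([2, 8])
```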

New Developments in NN Module

  • The concept of lazy modules, which infer their input shapes the first time they are called.
  • We can allocate a big module on CUDA directly, thanks to the device keyword argument added to the modules in PyTorch (see the sketch below).
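
A small sketch of both developments (the sizes are made up, and the CUDA part assumes a GPU is available):

```python
import torch
import torch.nn as nn

# Lazy module: the input feature size is inferred on the first call.
lazy = nn.LazyLinear(out_features=4)
lazy(torch.randn(2, 16))     # in_features=16 is inferred here
print(lazy.weight.shape)     # torch.Size([4, 16])

# Device keyword: parameters are created directly on the GPU instead
# of being allocated on the CPU and copied over afterwards.
if torch.cuda.is_available():
    big = nn.Linear(4096, 4096, device="cuda")
    print(big.weight.device)  # cuda:0
```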

TODO: Learn about Meta Tensor.

Drawbacks of the NN module

  • The requirement for functional versions of modules (see the sketch below).
  • The default initializations in PyTorch are out of date.
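
To illustrate what a functional version means, here’s a sketch contrasting the stateful nn.Linear with its stateless counterpart torch.nn.functional.linear, where the caller owns the parameters (the shapes are made up):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

x = torch.randn(2, 4)

# Stateful module: owns its own weight and bias.
layer = nn.Linear(4, 3)
y1 = layer(x)

# Functional version: the caller passes the parameters explicitly.
weight = torch.randn(3, 4)
bias = torch.randn(3)
y2 = F.linear(x, weight, bias)

print(y1.shape, y2.shape)   # torch.Size([2, 3]) torch.Size([2, 3])
```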

Further Reading

  • Edward Z. Yang’s podcast episode on torch.nn

Acknowledgement

I want to acknowledge Edward Z. Yang for his fantastic podcast, which helped me get an overview of the torch.nn module. Thank you, Edward!
