.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "beginner/flava_finetuning_tutorial.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        Click :ref:`here <sphx_glr_download_beginner_flava_finetuning_tutorial.py>`
        to download the full example code

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_beginner_flava_finetuning_tutorial.py:


TorchMultimodal Tutorial: Finetuning FLAVA
============================================

.. GENERATED FROM PYTHON SOURCE LINES 8-22

Multimodal AI has recently become very popular owing to its ubiquitous
nature, from use cases such as image captioning and visual search to more
recent applications such as image generation from text. **TorchMultimodal
is a library powered by PyTorch consisting of building blocks and end-to-end
examples, aiming to enable and accelerate research in
multimodality**.

In this tutorial, we will demonstrate how to use a **pretrained SoTA
model called** `FLAVA <https://arxiv.org/pdf/2112.04482.pdf>`__ **from the
TorchMultimodal library to finetune on a multimodal task, namely visual
question answering** (VQA). The model consists of two unimodal transformer-based
encoders for text and image and a multimodal encoder that combines
the two embeddings. It is pretrained with contrastive, image-text matching, and
text, image, and multimodal masking losses.

.. GENERATED FROM PYTHON SOURCE LINES 25-41

Installation
-----------------
We will use the TextVQA dataset and the BERT tokenizer from HuggingFace for this
tutorial, so you need to install ``datasets`` and ``transformers`` in addition to TorchMultimodal.

.. note::

   When running this tutorial in Google Colab, install the required packages by
   creating a new cell and running the following commands:

   .. code-block::

      !pip install torchmultimodal-nightly
      !pip install datasets
      !pip install transformers
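
After installing, you can optionally confirm that the packages import correctly.
This is only a quick sanity check; the printed versions will vary:

.. code-block:: default

    # Optional sanity check that the installed packages can be imported.
    import datasets
    import torchmultimodal
    import transformers

    print("datasets:", datasets.__version__)
    print("transformers:", transformers.__version__)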


.. GENERATED FROM PYTHON SOURCE LINES 43-70

Steps
-----

1. Download and extract the vocab file containing the answer classes to a
   directory on your computer by running the following commands:

   .. code-block::

      wget http://dl.fbaipublicfiles.com/pythia/data/vocab.tar.gz 
      tar xf vocab.tar.gz

   .. note::
      If you are running this tutorial in Google Colab, run these commands
      in a new cell and prepend each command with an exclamation mark (``!``).


2. For this tutorial, we treat VQA as a classification task where
   the inputs are an image and a question (text) and the output is an answer class.
   This is why we downloaded the vocab file with the answer classes in the previous
   step; we use it to create the answer-to-label mapping.

   We also load the `textvqa
   dataset <https://arxiv.org/pdf/1904.08920.pdf>`__ containing 34602 training samples
   (images, questions, and answers) from HuggingFace.

We see there are 3997 answer classes, including a class representing
unknown answers.


.. GENERATED FROM PYTHON SOURCE LINES 70-83

.. code-block:: default


    # Build the answer-to-label mapping from the downloaded vocab file.
    with open("data/vocabs/answers_textvqa_more_than_1.txt") as f:
      vocab = f.readlines()

    answer_to_idx = {}
    for idx, entry in enumerate(vocab):
      answer_to_idx[entry.strip("\n")] = idx
    print(len(vocab))
    print(vocab[:5])

    # Load the TextVQA dataset from HuggingFace.
    from datasets import load_dataset
    dataset = load_dataset("textvqa")


.. GENERATED FROM PYTHON SOURCE LINES 84-86

Let's display a sample entry from the dataset:


.. GENERATED FROM PYTHON SOURCE LINES 86-97

.. code-block:: default


    import matplotlib.pyplot as plt
    import numpy as np

    idx = 5
    print("Question: ", dataset["train"][idx]["question"])
    print("Answers: ", dataset["train"][idx]["answers"])
    im = np.asarray(dataset["train"][idx]["image"].resize((500, 500)))
    plt.imshow(im)
    plt.show()



.. GENERATED FROM PYTHON SOURCE LINES 98-105

3. Next, we write the transform function to convert the image and text into
   Tensors consumable by our model:

   - For images, we use the transforms from torchvision to convert to a
     Tensor and resize it to a uniform size.
   - For text, we tokenize (and pad) it using the ``BertTokenizer`` from
     HuggingFace.
   - For answers (that is, labels), we take the most frequently occurring
     answer as the label to train with.


.. GENERATED FROM PYTHON SOURCE LINES 105-135

.. code-block:: default


    import torch
    from torchvision import transforms
    from collections import defaultdict
    from transformers import BertTokenizer
    from functools import partial

    def transform(tokenizer, input):
      batch = {}
      # Convert the image to a Tensor and resize it to a uniform size.
      image_transform = transforms.Compose([transforms.ToTensor(), transforms.Resize([224, 224])])
      image = image_transform(input["image"][0].convert("RGB"))
      batch["image"] = [image]

      # Tokenize and pad the question text.
      tokenized = tokenizer(input["question"], return_tensors="pt", padding="max_length", max_length=512)
      batch.update(tokenized)

      # Use the most frequently occurring answer as the training label.
      ans_to_count = defaultdict(int)
      for ans in input["answers"][0]:
        ans_to_count[ans] += 1
      most_frequent_answer = max(ans_to_count, key=ans_to_count.get)
      ans_idx = answer_to_idx.get(most_frequent_answer, 0)
      batch["answers"] = torch.as_tensor([ans_idx])
      return batch

    tokenizer = BertTokenizer.from_pretrained("bert-base-uncased", padding="max_length", max_length=512)
    transform = partial(transform, tokenizer)
    dataset.set_transform(transform)
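
As a quick, optional check that the transform is wired up correctly, we can peek
at one transformed sample. This is only a sketch; the exact tensor shapes you see
depend on how ``set_transform`` unwraps single-sample batches:

.. code-block:: default

    # Optional check: inspect one transformed example. The exact shapes depend on
    # how the datasets library unwraps the single-sample batch produced above.
    sample = dataset["train"][0]
    for key, value in sample.items():
      if torch.is_tensor(value):
        print(key, tuple(value.shape))
      else:
        print(key, type(value))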



.. GENERATED FROM PYTHON SOURCE LINES 136-147

4. Next, we import ``flava_model_for_classification`` from
   torchmultimodal. It loads the pretrained FLAVA checkpoint by default and
   includes a classification head.

   The model forward function passes the image through the visual encoder
   and the question through the text encoder. The image and question
   embeddings are then passed through the multimodal encoder. The final
   embedding corresponding to the CLS token is passed through an MLP head
   which produces the probability distribution over the possible
   answers.


.. GENERATED FROM PYTHON SOURCE LINES 147-152

.. code-block:: default


    from torchmultimodal.models.flava.model import flava_model_for_classification
    model = flava_model_for_classification(num_classes=len(vocab))
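
As an optional check, we can count the parameters we are about to finetune. If you
prefer to finetune only the classification head, you could freeze the rest of the
model by filtering ``model.named_parameters()`` by name; the name used in the
commented-out sketch below is an assumption, so inspect the printed parameter names
first to confirm how the head is called in your version of TorchMultimodal:

.. code-block:: default

    # Optional: report the total number of parameters being finetuned.
    total_params = sum(p.numel() for p in model.parameters())
    print(f"Total parameters: {total_params / 1e6:.1f}M")

    # Sketch only (assumption: the classification head's parameters contain
    # "classifier" in their names -- verify with model.named_parameters()):
    # for name, param in model.named_parameters():
    #   if "classifier" not in name:
    #     param.requires_grad = False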



.. GENERATED FROM PYTHON SOURCE LINES 153-156

5. Finally, we put together the dataset and model in a toy training loop to
   demonstrate how to train the model for 3 iterations:


.. GENERATED FROM PYTHON SOURCE LINES 156-179

.. code-block:: default


    from torch.utils.data import DataLoader

    BATCH_SIZE = 2
    MAX_STEPS = 3

    train_dataloader = DataLoader(dataset["train"], batch_size=BATCH_SIZE)
    optimizer = torch.optim.AdamW(model.parameters())

    epochs = 1
    for _ in range(epochs):
      for idx, batch in enumerate(train_dataloader):
        optimizer.zero_grad()
        out = model(text=batch["input_ids"], image=batch["image"], labels=batch["answers"])
        loss = out.loss
        loss.backward()
        optimizer.step()
        print(f"Loss at step {idx} = {loss}")
        # Stop after MAX_STEPS batches since this is only a demonstration.
        if idx >= MAX_STEPS - 1:
          break
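
The loop above runs on the CPU. As an optional variation that uses only standard
PyTorch (nothing specific to TorchMultimodal), you can move the model and each batch
to a GPU when one is available. Note that the optimizer is recreated after moving
the model so that its state lives on the same device as the parameters:

.. code-block:: default

    # Optional variation: run the same toy loop on a GPU if one is available.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)
    # Recreate the optimizer so its state is built on the target device.
    optimizer = torch.optim.AdamW(model.parameters())

    for idx, batch in enumerate(train_dataloader):
      optimizer.zero_grad()
      out = model(
          text=batch["input_ids"].to(device),
          image=batch["image"].to(device),
          labels=batch["answers"].to(device),
      )
      out.loss.backward()
      optimizer.step()
      print(f"Loss at step {idx} = {out.loss}")
      if idx >= MAX_STEPS - 1:
        break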



.. GENERATED FROM PYTHON SOURCE LINES 180-191

Conclusion
-------------------

This tutorial introduced the basics of finetuning on a
multimodal task using FLAVA from TorchMultimodal. Please also check out
other examples from the library like
`MDETR <https://github.com/facebookresearch/multimodal/tree/main/torchmultimodal/models/mdetr>`__,
which is a multimodal model for object detection, and
`Omnivore <https://github.com/facebookresearch/multimodal/blob/main/torchmultimodal/models/omnivore.py>`__,
which is a multitask model spanning image, video, and 3D classification.



.. rst-class:: sphx-glr-timing

   **Total running time of the script:** ( 0 minutes  0.000 seconds)


.. _sphx_glr_download_beginner_flava_finetuning_tutorial.py:

.. only:: html

  .. container:: sphx-glr-footer sphx-glr-footer-example


    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: flava_finetuning_tutorial.py <flava_finetuning_tutorial.py>`

    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: flava_finetuning_tutorial.ipynb <flava_finetuning_tutorial.ipynb>`


.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery <https://sphinx-gallery.github.io>`_