In the 3rd assignment of XCS330, we will implement prototypical networks (protonets) for few-shot image classification on the Omniglot dataset.

Protonets Algorithm

This is the protonet in a nutshell, and the example comes from the assignment:

[Figure: protonet algorithm illustration from the assignment]

In this example, we compute three class prototypes c1, c2, c3 from the support features. The decision boundaries follow from the Euclidean distance to each prototype. When a new query arrives, we assign it to the class of the nearest prototype.
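As a tiny sketch of that decision rule (my own example with made-up tensors, not the assignment's code):

import torch

# hypothetical latents: 3 prototypes (as rows) and one query, dimension 64
prototypes = torch.randn(3, 64)
query = torch.randn(64)

# squared Euclidean distance from the query to each prototype
distances = torch.sum((prototypes - query) ** 2, dim=1)  # shape (3,)
predicted_class = torch.argmin(distances).item()  # nearest prototype wins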

Implement ProtoNet._step Method

The first coding task is to implement the ProtoNet._step method, which computes the ProtoNet loss and accuracy metrics.

The input of this function is a task batch, and the output is the mean ProtoNet loss over the batch.

Here are the instructions:

  1. compute the prototypes
  2. compute the protonet loss
  3. use cross entropy to compute the classification loss
  4. use util.score to compute accuracies

Let’s divide and conquer these tasks.
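Before diving in, here is a rough sketch of how the four steps fit together. The helper names _compute_prototypes and _compute_logits are hypothetical placeholders for the bodies worked out below, I'm assuming util.score takes (logits, labels), and the real starter code also aggregates accuracies across the batch:

import torch
import torch.nn.functional as F

def _step(self, task_batch):
    loss_batch = []
    for task in task_batch:
        images_support, labels_support, images_query, labels_query = task
        # 1. prototypes from the support set (hypothetical helper)
        prototypes = self._compute_prototypes(images_support, labels_support)
        # 2. logits = negative squared Euclidean distances (hypothetical helper)
        logits = self._compute_logits(images_query, prototypes)
        # 3. classification loss, computed on the query set
        loss_batch.append(F.cross_entropy(logits, labels_query))
        # 4. accuracy via the provided util.score
        accuracy_query = util.score(logits, labels_query)
    # mean ProtoNet loss over the task batch
    return torch.mean(torch.stack(loss_batch))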

Input & Output Shape

I ran the basic test with python grader.py 1b-0-basic and got the following shapes:

# print(images_support.shape)
torch.Size([5, 1, 28, 28])

# print(labels_support.shape)
torch.Size([5])

# print(images_query.shape)
torch.Size([75, 1, 28, 28])

# print(labels_query.shape)
torch.Size([75])

These shapes are determined by the setUp() function:

self.dataloader_train = omniglot.get_omniglot_dataloader(
    split='train',
    batch_size=16,
    num_way=5,
    num_support=1,
    num_query=15,
    num_tasks_per_epoch=240000
)

Here we have 5 classes (num_way), with 1 support example and 15 query examples per class. That explains the shapes above: 5 × 1 = 5 support images and 5 × 15 = 75 query images.

Understand ProtoNet

The network definition is in the __init__() function:

class ProtoNetNetwork(nn.Module):
    """Container for ProtoNet weights and image-to-latent computation."""

    def __init__(self, device):
        """Inits ProtoNetNetwork.

        The network consists of four convolutional blocks, each comprising a
        convolution layer, a batch normalization layer, ReLU activation, and 2x2
        max pooling for downsampling. There is an additional flattening
        operation at the end.

        Note that unlike conventional use, batch normalization is always done
        with batch statistics, regardless of whether we are training or
        evaluating. This technically makes meta-learning transductive, as
        opposed to inductive.

        Args:
            device (str): device to be used
        """
        super().__init__()
        layers = []
        in_channels = NUM_INPUT_CHANNELS
        for _ in range(NUM_CONV_LAYERS):
            layers.append(
                nn.Conv2d(
                    in_channels,
                    NUM_HIDDEN_CHANNELS,
                    (KERNEL_SIZE, KERNEL_SIZE),
                    padding='same'
                )
            )
            layers.append(nn.BatchNorm2d(NUM_HIDDEN_CHANNELS))
            layers.append(nn.ReLU())
            layers.append(nn.MaxPool2d(2))
            in_channels = NUM_HIDDEN_CHANNELS
        layers.append(nn.Flatten())
        self._layers = nn.Sequential(*layers)
        self.to(device)

The docstring is self-explanatory.

The neural network takes a batch of images as input and outputs a feature representation for each image, i.e. a tensor of shape (num_images, latent_dim).
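As a quick sanity check of the latent size (my own reasoning, not from the assignment): 'same' padding preserves the spatial size and each 2x2 max pool floor-halves it, so for 28x28 Omniglot images:

# spatial size after each of the four conv blocks
size = 28
for _ in range(4):  # NUM_CONV_LAYERS
    size //= 2      # 28 -> 14 -> 7 -> 3 -> 1
print(size)  # 1, so Flatten yields NUM_HIDDEN_CHANNELS features per image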

Compute Prototypes

This is the math representation of a prototype:

$$ c_n = \frac{1}{K} \sum_{(x, y) \in \mathcal{D}^{\mathrm{tr}}_i : y = n} f_\theta(x) $$

We use a mapping function $f_\theta(\cdot)$ to map the images to features. We then average the features of the K support examples of class n to obtain its prototype.

Here’s the implementation of it:

prototypes = []
for i in range(labels_support.max() + 1):
    # get the latents of all support images for the i-th class
    latents = self._network(images_support[labels_support == i])
    # average the latents to form the class prototype
    prototypes.append(latents.mean(dim=0))
# stack into a (num_classes, latent_dim) tensor
prototypes = torch.stack(prototypes)

Basically, we want to find all the images for the i-th class, use the network to get the latents, and then calculate their mean. Note that dim=0 is needed; otherwise mean() averages over all elements and returns a single scalar instead of a per-dimension mean.
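A quick illustration of the difference:

import torch

latents = torch.arange(6.0).reshape(2, 3)  # two latent vectors of dimension 3
print(latents.mean())       # tensor(2.5000): one scalar over all six numbers
print(latents.mean(dim=0))  # tensor([1.5000, 2.5000, 3.5000]): per-dimension mean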

Compute Loss

Once we have the prototypes, we use the squared Euclidean distance to measure how far each query feature is from them.

$$ d(f_\theta(x), c_n) = \lVert f_\theta(x) - c_n \rVert_2^2 $$

We define the logits as the negative distance $-d$, which makes sense: the closer a query latent is to a prototype, the more similar the images are, and hence the larger the logit should be.

Finally, we use the cross entropy function to calculate the loss.

Note that we use the support set to compute the prototypes, but the query set to calculate the loss.

Also note that the input tensor for the cross_entropy function is the logits, not the probabilities! During the implementation, I first applied the softmax function to get the probabilities and then used them as the input, which is wrong!

torch.nn.functional.cross_entropy

Parameters:
- input: predicted unnormalized logits
- target: ground truth class indices or class probabilities

Shape:
- input: (N, C)
- target: (N)

where:
C = number of classes
N = batch size
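For example, this is the correct way to call it, with the logits going in directly (cross_entropy applies log-softmax internally):

import torch
import torch.nn.functional as F

logits = torch.randn(75, 5)             # unnormalized scores, shape (N, C)
labels = torch.randint(0, 5, (75,))     # ground-truth class indices, shape (N,)
loss = F.cross_entropy(logits, labels)  # log-softmax is applied internally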

Example

Again, let's use the unit test shapes as an example. Assume we have 5 prototypes, where each prototype c is a latent vector:

$$ \begin{bmatrix} \begin{bmatrix} c_1 \end{bmatrix} \\ \begin{bmatrix} c_2 \end{bmatrix} \\ \vdots \\ \begin{bmatrix} c_5 \end{bmatrix} \end{bmatrix} $$

And let’s assume we have 75 query images:

$$ \begin{bmatrix} \begin{bmatrix} \text{image 1} \end{bmatrix} \\ \begin{bmatrix} \text{image 2} \end{bmatrix} \\ \vdots \\ \begin{bmatrix} \text{image 75} \end{bmatrix} \end{bmatrix} $$

The first step is to use the model $f_\theta(\cdot)$ to map the images to features. After the mapping, we have:

$$ \begin{bmatrix} \begin{bmatrix} \text{latent 1} \end{bmatrix} \\ \begin{bmatrix} \text{latent 2} \end{bmatrix} \\ \vdots \\ \begin{bmatrix} \text{latent 75} \end{bmatrix} \end{bmatrix} $$

Let's focus on the operation between all the query latents and one prototype. Here query_latents has shape [75, d] and the prototype c1 has shape [d]. When we compute (query_latents - prototype1)**2, broadcasting subtracts c1 from every row, and the square is applied elementwise, so we get:

$$ \begin{bmatrix} \begin{bmatrix} (\text{latent 1} - c_1)^2 \end{bmatrix} \\ \begin{bmatrix} (\text{latent 2} - c_1)^2 \end{bmatrix} \\ \vdots \\ \begin{bmatrix} (\text{latent 75} - c_1)^2 \end{bmatrix} \end{bmatrix} $$

Once we have the elementwise squared differences, we can use torch.sum(a, dim=1) to sum over the latent dimension and obtain the squared Euclidean distance of each query to c1:

$$ \begin{bmatrix} d_1 \\ d_2 \\ \vdots \\ d_{75} \end{bmatrix} $$

Finally, we repeat this for every prototype and use torch.stack(distances, dim=1) to stack the five distance vectors along the column axis, giving a [75, 5] matrix whose negation is the logits.
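Putting the pieces together, here is a sketch of the query-loss computation under the shapes above (variable names are mine):

query_latents = self._network(images_query)      # shape (75, d)

distances = []
for prototype in prototypes:                     # prototypes has shape (5, d)
    diff_sq = (query_latents - prototype) ** 2   # broadcasts to (75, d)
    distances.append(torch.sum(diff_sq, dim=1))  # squared distances, shape (75,)

logits = -torch.stack(distances, dim=1)          # (75, 5), one column per class
loss = F.cross_entropy(logits, labels_query)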

torch.no_grad()

We also need to use torch.no_grad() to compute prototypes.
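For example, where gradients are not needed, the prototype loop from earlier can be wrapped like this (a sketch, not the official solution):

# same prototype loop as above, but with gradient tracking disabled
with torch.no_grad():
    prototypes = []
    for i in range(labels_support.max() + 1):
        latents = self._network(images_support[labels_support == i])
        prototypes.append(latents.mean(dim=0))
    prototypes = torch.stack(prototypes)

Note that this stops gradients from flowing back through the support features.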