In [1]:
!pip install -q timm torch torchvision

In [2]:
import random
import torch
import torch.nn.functional as F
from torchvision import transforms
from torchvision.datasets import CIFAR10
import timm

In [3]:
# Single source of truth for every random number generator seeded below.
SEED = 42

# Probe CUDA once and reuse the answer for both seeding and device selection.
use_cuda = torch.cuda.is_available()

random.seed(SEED)        # Python's global RNG
torch.manual_seed(SEED)  # PyTorch's default generator
if use_cuda:
    torch.cuda.manual_seed_all(SEED)  # every visible GPU

# Device on which the model and image tensors are placed in later cells.
device = torch.device("cuda") if use_cuda else torch.device("cpu")
print("Using device:", device)

Using device: cuda


In [4]:
# Pretrained ResNet-50 used as a frozen feature extractor.
# num_classes=0 requests the backbone without a classification head (see the
# timm feature-extraction docs), so its forward output serves as an embedding.
model = timm.create_model(model_name="resnet50", pretrained=True, num_classes=0)

# Inference mode: batch-norm layers use their running statistics.
model.eval()

# Move the weights to the selected device; as the last expression of the cell
# the returned module is displayed.
model.to(device)

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


model.safetensors:   0%|          | 0.00/102M [00:00<?, ?B/s]

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (act1): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): Bottleneck(
      (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (act1): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (drop_block): Identity()
      (act2): ReLU(inplace=True)
      (aa): Identity()
      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
     

In [5]:
# Per-channel normalisation statistics (the standard ImageNet values).
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

# Preprocessing pipeline applied to every image: resize and centre-crop to
# 224, convert to a tensor, then normalise each channel.
preprocessing_steps = [
    transforms.Resize(224),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
]
transform = transforms.Compose(preprocessing_steps)


In [6]:
# Arguments shared by both CIFAR-10 splits (downloaded to ./data on first use).
cifar_kwargs = {"root": "./data", "transform": transform, "download": True}
train_dataset = CIFAR10(train=True, **cifar_kwargs)
test_dataset = CIFAR10(train=False, **cifar_kwargs)

# Human-readable class names, indexed by integer label.
class_names = train_dataset.classes

# Few-shot episode layout: which class ids take part and how many examples
# per class go into the support (reference) and query (evaluation) sets.
selected_classes = [0, 1, 2]
num_support_per_class = 5
num_query_per_class = 5

# Dataset indices of the sampled examples; filled in by the sampling cell.
support_indices = []
query_indices = []

100%|██████████| 170M/170M [00:03<00:00, 48.4MB/s]


In [7]:
# Sample the few-shot episode: for every selected class, pick support examples
# from the training split and query examples from the test split.

# Start from empty lists in this cell so that re-running it replaces the
# episode instead of appending a second copy of indices to the previous one.
support_indices, query_indices = [], []

# Dedicated generator: the draw no longer depends on how much of the global
# `random` state earlier (or repeated) cell executions have consumed.
EPISODE_SEED = 42
episode_rng = random.Random(EPISODE_SEED)

for cls in selected_classes:
    # Positions of all examples of this class in each split.
    train_idxs = [i for i, y in enumerate(train_dataset.targets) if y == cls]
    test_idxs = [i for i, y in enumerate(test_dataset.targets) if y == cls]

    # sample() draws without replacement and raises ValueError if a class has
    # fewer examples than requested.
    support_indices.extend(episode_rng.sample(train_idxs, num_support_per_class))
    query_indices.extend(episode_rng.sample(test_idxs, num_query_per_class))

# Sanity check: exactly the requested number of examples per class was drawn.
assert len(support_indices) == len(selected_classes) * num_support_per_class
assert len(query_indices) == len(selected_classes) * num_query_per_class


In [9]:
# Materialise the sampled examples; each dataset item is an (image, label) pair.
support_pairs = [train_dataset[idx] for idx in support_indices]
query_pairs = [test_dataset[idx] for idx in query_indices]

# Stack the per-image tensors into one batch per set and move it to the device.
support_images = torch.stack([image for image, _ in support_pairs]).to(device)
query_images = torch.stack([image for image, _ in query_pairs]).to(device)

# Labels stay as plain Python lists, in the same order as the image batches.
support_labels = [label for _, label in support_pairs]
query_labels = [label for _, label in query_pairs]

print(f"Support set: {support_images.size(0)} images")
print(f"Query set:   {query_images.size(0)} images")

Support set: 15 images
Query set:   15 images


In [10]:
# Embed both sets with the frozen backbone. Gradients are not needed for
# inference, so the whole computation runs under no_grad.
# Each embedding is scaled to unit L2 norm along the feature dimension, which
# makes a dot product between two embeddings equal to their cosine similarity.
with torch.no_grad():
    support_embeddings = F.normalize(model(support_images), p=2, dim=1)
    query_embeddings = F.normalize(model(query_images), p=2, dim=1)

In [11]:
# Pairwise similarity matrix: one row per query, one column per support image.
similarity = query_embeddings @ support_embeddings.T

# For every query, find the support image with the highest similarity ...
best_scores, nearest = similarity.max(dim=1)

# ... and predict that support image's label (1-nearest-neighbour).
predicted_labels = [support_labels[idx] for idx in nearest.cpu().tolist()]

In [12]:
# Per-query report followed by the overall accuracy on the query set.
print("\n Few-Shot Classification Results:")

correct = 0
for i, (t, p) in enumerate(zip(query_labels, predicted_labels), start=1):
    print(f"Query {i}: True = {class_names[t]:12s} | Pred = {class_names[p]}")
    # Count the hit while we are already walking the pairs.
    correct += int(t == p)

# Accuracy as a percentage of correctly classified queries.
acc = 100.0 * correct / len(query_labels)
print(f"\nAccuracy: {acc:.2f}%")


 Few-Shot Classification Results:
Query 1: True = airplane     | Pred = airplane
Query 2: True = airplane     | Pred = airplane
Query 3: True = airplane     | Pred = airplane
Query 4: True = airplane     | Pred = automobile
Query 5: True = airplane     | Pred = automobile
Query 6: True = automobile   | Pred = automobile
Query 7: True = automobile   | Pred = automobile
Query 8: True = automobile   | Pred = automobile
Query 9: True = automobile   | Pred = automobile
Query 10: True = automobile   | Pred = automobile
Query 11: True = bird         | Pred = bird
Query 12: True = bird         | Pred = bird
Query 13: True = bird         | Pred = bird
Query 14: True = bird         | Pred = bird
Query 15: True = bird         | Pred = airplane

Accuracy: 80.00%
