SW/영상인식

CIFAR 10 : Few Shot Learning : 새로운 클래스 인식

얇은생각 2019. 11. 2. 18:30
반응형

과제

"""
* CIFAR-10 testset 소개
32x32 three color = 3072 dimensions per sample
비행기, 자동차, 새, 고양이, 사슴, 개, 개구리, 말, 배, 트럭의 10가지 클래스로 이뤄져 있으며, 60000개(트레이닝셋 50000 / 테스트셋 10000)의 이미지(클래스당 6000개)를 포함한 데이터셋.

=== HW #3 내용 ===========================================

* 필수사항 (100점)
1. Model Saver 기능 포함 필요
2. Random sample or sample index 지정하여 visualization & 추정결과 plotting (MNIST 과제와 비슷)
3. Testset Accuracy 90% 이상 넘기기


* 도전과제 (20~100점)
1. 직접 찍은 실제 자동차, 고양이, 개 사진 등을 이용하여 테스트 결과 출력 (20점)
참고 자료: 업로드
2. Visualize Filters (20점)
관련 문서: https://hwiyong.tistory.com/35
3. 사람 얼굴 등 새로운 클래스를 인식하도록 만들기 (40점)
관련 문서: https://blog.naver.com/cenodim/220946688251
4. Few-shot learning 등으로 새로운 클래스를 인식하도록 만들기 (3의 hard case) (100점)
관련 문서: ① https://wewinserv.tistory.com/123 ② https://arxiv.org/pdf/1904.05046.pdf

① 해당 논문은 distractor라는 개념을 통해서 새로 들어온 샘플이 새로운 클래스로 분류되어야 하는지 아닌지를 먼저 판단하고, 이를 반영하여 새로운 unknown 클래스로 자동으로 분류하는 모델에 대한 논문입니다. 이 논문을 이해하기 위해서는 clustering(k-means 등) 에 대한 지식이 있어야 합니다.
② few-shot learning 에 대한 survey paper 입니다.

도전과제는 모두 할 필요 없습니다. 중복으로 해도 됩니다. (단, 3과 4는 중복이 안 됩니다.)


* 제공 코드
1. CIFAR-10 예제코드 (Simple CNN)
2. 도전과제 1에 대한 예시 code (MNIST case)


* 제출방법: 코드(.ipynb or .py)와 보고서(.hwp or .docx)를 zip으로 압축하여 제출해주세요. 파일명은 '학번_성명_hw3.zip' 와 같이 해 주시길 바랍니다.


* 보고서: 필수사항에서의 결과 및 코드에서의 해당 부분을 표시하여 주시고,
도전과제 결과의 경우 구현 방법에 대한 간략한 설명과 함께 결과를 캡쳐해주시기 바랍니다.
"""

 

해당 과제에서 제시한 내용, 도전 과제 4를 진행해 보았습니다. 이전 과제들도 추후 포스팅 예정입니다. 과연 퓨샷 러닝은 무엇이고, 어떠한 방식으로 연구가 되었는지, 많은 조사를 통해 알게 되었습니다. 저는 그중에서도, GNN과 FewShot을 활용한 방식에서 착안하여 CIFAR에 적용해보았습니다. 

샴 네트워크 등과 같은 FewShot 러닝 역시 매우 흥미로운 분야이고, 아직도 많은 이해가 필요로 하고, 부족한 점이 많다는 것을 알게 해준 과제 내용이었습니다. 

 

Few Shot Learning 개요

샷은 훈련에 사용할 수 있는 하나의 예일 뿐이므로 N-shot 학습에는 훈련에 대한 N 가지 예가 있습니다. “Few-shot learning”이라는 용어를 사용하면 “Few-shot learning”은 일반적으로 0에서 5 사이에 있습니다. 즉, 예가 없는 모델을 훈련하는 것은 zero-shot learning, 하나의 예는 one-shot learning 등입니다. 이러한 모든 변형은 다양한 수준의 교육 자료로 동일한 문제를 해결하려고합니다.

몇 번의 학습은 하나 이상의 학습 예제가 있는 유연한 원샷 학습의 유연한 버전입니다. 일반적으로 2 ~ 5 개의 이미지가 있지만 위에서 언급 한 모델의 대부분은 퓨샷 학습에도 사용될 수 있습니다.

2019 년 컴퓨터 비전 및 패턴 인식에 관한 컨퍼런스에서 몇 번의 학습을위한 메타 전송 학습이 발표되었습니다. 이 모델은 미래의 연구를위한 선례를 설정했습니다. 최신 결과를 제공하고보다 정교한 메타 전송 학습 방법을위한 길을 열었습니다.

이러한 메타 학습 및 강화 학습 알고리즘 중 다수는 일반적인 딥 러닝 알고리즘과 결합되어 놀라운 결과를 생성합니다. 프로토 타입 네트워크는 가장 널리 사용되는 딥 러닝 알고리즘 중 하나이며이 작업에 자주 사용됩니다.

 

GNN과 Few Shot Leaning 관계

GNN과 FewShot 구조

라벨을 관찰 할 수 있는 입력 이미지 모음으로 구성된 부분적으로 관찰 된 그래픽 모델에 대한 추론의 프리즘으로 소수 샷 학습 문제가 연구 되었습니다. 일반적인 네트워크 통과 알고리즘으로 일반 메시지 전달 추론 알고리즘을 동화함으로써 최근에 제안 된 퓨샷 학습 모델을 일반화하는 그래프 신경 네트워크 아키텍처를 정의합니다. 개선된 수치 성능을 제공하는 것 외에도 프레임 워크는 semi-supervised 또는 active learning과 같은 퓨샷 학습 형태로 쉽게 확장되어 '관계형' 작업에서 그래프 기반 모델이 잘 작동하는 능력을 보여줍니다.

최적화 된 기술의 개선, 더 큰 데이터 세트 및 심층 컨볼 루션 또는 반복 아키텍처의 간소화 된 설계로 인해 컴퓨터 비전, 음성 또는 기계 번역 작업에서 supervised 종단 간 학습이 크게 성공했습니다. 그러한 사례 중 하나는 소위 퓨샷 학습 작업에서 몇 가지 예를 통해 배울 수있는 능력입니다.

데이터 부족을 보완하기 위해 정규화에 의존하기보다는 연구자들은 인간 학습에서 영감을 얻은 유사한 작업 분포를 활용하는 방법을 모색했습니다. 이것은 새로운 지도 학습 학습 설정 ('메타-러닝'이라고도 함)을 정의하는데, 여기서 입력-출력 쌍은 더 이상 이미지의 샘플 및 관련 레이블이 아니라 이미지 컬렉션의 샘플 및 관련 레이블 유사성에 의해 제공됩니다.

최근에 성공적으로 수행 된 연구 프로그램은 몇 번의 이미지 분류 작업인 메타 학습 패러다임을 활용했습니다. 본질적으로, 이 작품들은 맥락에 따라 작업별 유사성 측정법을 배우고, 먼저 CNN을 사용하여 입력 이미지를 삽입한 다음 컬렉션에 포함된 이미지를 결합하여 레이블 정보를 대상 이미지로 전파하는 방법을 배웁니다.

특히 지원되는 이미지 세트를 원하는 레이블에 매핑하는 Supervised 분류 작업으로 몇 번의 학습 문제를 해결하고 이러한 지원 세트를 주의 메커니즘을 통해 입력으로 받아들이는 엔드 투 엔드 아키텍처를 개발했습니다. 이 작업에서 작업 라인을 바탕으로 이 작업이 그래프에서 Supervised된 보간 문제로 자연스럽게 표현되며 노드가 컬렉션의 이미지와 연관되고 가장자리는 훈련 가능한 유사성 커널에 의해 제공된다고 주장합니다. 그래프 구조화된 데이터의 표현 학습에 대한 최근 진행 상황을 활용합니다. 작업 중심 메시지 전달 알고리즘을 구현하는 간단한 그래프 기반의 퓨샷 학습 모델을 제안합니다. 결과 아키텍처는 엔드-투-엔드 교육을 받고 입력 컬렉션 내의 순열과 같은 작업의 불일치를 캡처하며 단순성, 일반성, 성능 및 샘플 복잡성간에 적절한 균형을 제공합니다.

몇 번의 학습 이외에도 관련 과제는 레이블이 붙은 예제와 레이블이 없는 예제 (반지도 학습 및 능동 학습)를 혼합하여 학습하는 기능입니다. 학습자는 학습자가 가장 많이 누락 된 레이블을 요청할 수 있습니다 예측 작업에 도움이됩니다. 우리의 그래프 기반 아키텍처는 훈련 설계에서 최소한의 변화만으로 이러한 설정으로 자연스럽게 확장됩니다.

소수의 이미지 분류에 대한 모델을 실험적으로 검증하고, 최첨단 성능을 훨씬 적은 수의 매개 변수와 일치시키고, 반 Supervised 및 능동 학습 설정에 대한 응용 프로그램을 시연합니다.

 

코드

Data 코드

import os
import time
import random
import skimage.io
import numpy as np

import torch
from torch.utils.data import Dataset
import torchvision as tv
from torchvision.datasets import CIFAR10

class self_Dataset(Dataset):
    def __init__(self, data, label=None):
        super(self_Dataset, self).__init__()

        self.data = data
        self.label = label
    def __getitem__(self, index):
        data = self.data[index]

        if self.label is not None:
            label = self.label[index]
            return data, label
        else:
            return data, 1
    def __len__(self):
        return len(self.data)

def count_data(data_dict):
    num = 0
    for key in data_dict.keys():
        num += len(data_dict[key])
    return num

class self_DataLoader(Dataset):
    def __init__(self, root, train=True, dataset='cifar10', seed=1, nway=5):
        super(self_DataLoader, self).__init__()

        self.seed = seed
        self.nway = nway
        self.num_labels = 10
        self.input_channels = 3
        self.size = 32

        self.transform = tv.transforms.Compose([
            tv.transforms.ToTensor(),
            tv.transforms.Normalize([0.5071, 0.4866, 0.4409],
                [0.2673, 0.2564, 0.2762])
            ])

        self.full_data_dict, self.few_data_dict = self.load_data(root, train, dataset)

        print('full_data_num: %d' % count_data(self.full_data_dict))
        print('few_data_num: %d' % count_data(self.few_data_dict))

    def load_data(self, root, train, dataset):
        if dataset == 'cifar10':
            # few_selected_label = random.Random(self.seed).sample(range(self.num_labels), self.nway)
            few_selected_label = random.Random(self.seed).sample([0,3,7,1,2], self.nway)
            
            print('selected labeled', few_selected_label)

            full_data_dict = {}
            few_data_dict = {}

            d = CIFAR10(root, train=train, download=True)

            for i, (data, label) in enumerate(d):

                data = self.transform(data)

                if label in few_selected_label:
                    data_dict = few_data_dict
                else:
                    data_dict = full_data_dict

                if label not in data_dict:
                    data_dict[label] = [data]
                else:
                    data_dict[label].append(data)
            print(i + 1)
        else:
            raise NotImplementedError

        return full_data_dict, few_data_dict

    def load_batch_data(self, train=True, batch_size=16, nway=5, num_shots=1):
        if train:
            data_dict = self.full_data_dict
        else:
            data_dict = self.few_data_dict

        x = []
        label_y = [] 
        one_hot_y = [] # one hot for fake label
        class_y = [] # real label

        xi = []
        label_yi = []
        one_hot_yi = []


        map_label2class = []

        for i in range(batch_size):

            # sample the class to train
            sampled_classes = random.sample(data_dict.keys(), nway)
            
            positive_class = random.randint(0, nway - 1)

            label2class = torch.LongTensor(nway)

            single_xi = []
            single_one_hot_yi = []
            single_label_yi = []
            single_class_yi = []


            for j, _class in enumerate(sampled_classes):
                if j == positive_class:
                    sampled_data = random.sample(data_dict[_class], num_shots+1)

                    x.append(sampled_data[0])
                    label_y.append(torch.LongTensor([j]))

                    one_hot = torch.zeros(nway)
                    one_hot[j] = 1.0
                    one_hot_y.append(one_hot)

                    class_y.append(torch.LongTensor([_class]))

                    shots_data = sampled_data[1:]
                else:
                    shots_data = random.sample(data_dict[_class], num_shots)

                single_xi += shots_data
                single_label_yi.append(torch.LongTensor([j]).repeat(num_shots))
                one_hot = torch.zeros(nway)
                one_hot[j] = 1.0
                single_one_hot_yi.append(one_hot.repeat(num_shots, 1))

                label2class[j] = _class

            shuffle_index = torch.randperm(num_shots*nway)
            xi.append(torch.stack(single_xi, dim=0)[shuffle_index])
            label_yi.append(torch.cat(single_label_yi, dim=0)[shuffle_index])
            one_hot_yi.append(torch.cat(single_one_hot_yi, dim=0)[shuffle_index])

            map_label2class.append(label2class)

        return [torch.stack(x, 0), torch.cat(label_y, 0), torch.stack(one_hot_y, 0), \
            torch.cat(class_y, 0), torch.stack(xi, 0), torch.stack(label_yi, 0), \
            torch.stack(one_hot_yi, 0), torch.stack(map_label2class, 0)]

    def load_tr_batch(self, batch_size=16, nway=5, num_shots=1):
        return self.load_batch_data(True, batch_size, nway, num_shots)

    def load_te_batch(self, batch_size=16, nway=5, num_shots=1):
        return self.load_batch_data(False, batch_size, nway, num_shots)

    def get_data_list(self, data_dict):
        data_list = []
        label_list = []
        for i in data_dict.keys():
            for data in data_dict[i]:
                data_list.append(data)
                label_list.append(i)

        now_time = time.time()

        random.Random(now_time).shuffle(data_list)
        random.Random(now_time).shuffle(label_list)

        return data_list, label_list

    def get_full_data_list(self):
        return self.get_data_list(self.full_data_dict)

    def get_few_data_list(self):
        return self.get_data_list(self.few_data_dict)

if __name__ == '__main__':
    D = self_DataLoader('/home/lab5300/Data', True)

    [x, label_y, one_hot_y, class_y, xi, label_yi, one_hot_yi, class_yi] = \
        D.load_tr_batch(batch_size=16, nway=5, num_shots=5)
    print(x.size(), label_y.size(), one_hot_y.size(), class_y.size())
    print(xi.size(), label_yi.size(), one_hot_yi.size(), class_yi.size())

    print(label_yi[0])
    print(one_hot_yi[0])

 

Layer 코드

import torch
import torch.nn as nn
from torch.autograd import Variable
import torch.nn.functional as F

class Graph_conv_block(nn.Module):
    def __init__(self, input_dim, output_dim, use_bn=True):
        super(Graph_conv_block, self).__init__()

        self.weight = nn.Linear(input_dim, output_dim)
        if use_bn:
            self.bn = nn.BatchNorm1d(output_dim)
        else:
            self.bn = None

    def forward(self, x, A):
        x_next = torch.matmul(A, x) # (b, N, input_dim)
        x_next = self.weight(x_next) # (b, N, output_dim)

        if self.bn is not None:
            x_next = torch.transpose(x_next, 1, 2) # (b, output_dim, N)
            x_next = x_next.contiguous()
            x_next = self.bn(x_next)
            x_next = torch.transpose(x_next, 1, 2) # (b, N, output)

        return x_next

class Adjacency_layer(nn.Module):
    def __init__(self, input_dim, hidden_dim, ratio=[2,2,1,1]):

        super(Adjacency_layer, self).__init__()

        module_list = []

        for i in range(len(ratio)):
            if i == 0:
                module_list.append(nn.Conv2d(input_dim, hidden_dim*ratio[i], 1, 1))
            else:
                module_list.append(nn.Conv2d(hidden_dim*ratio[i-1], hidden_dim*ratio[i], 1, 1))

            module_list.append(nn.BatchNorm2d(hidden_dim*ratio[i]))
            module_list.append(nn.LeakyReLU())

        module_list.append(nn.Conv2d(hidden_dim*ratio[-1], 1, 1, 1))

        self.module_list = nn.ModuleList(module_list)

    def forward(self, x):
        X_i = x.unsqueeze(2) # (b, N , 1, input_dim)
        X_j = torch.transpose(X_i, 1, 2) # (b, 1, N, input_dim)

        phi = torch.abs(X_i - X_j) # (b, N, N, input_dim)

        phi = torch.transpose(phi, 1, 3) # (b, input_dim, N, N)

        A = phi

        for l in self.module_list:
            A = l(A)
        # (b, 1, N, N)

        A = torch.transpose(A, 1, 3) # (b, N, N, 1)

        A = F.softmax(A, 2) # normalize

        return A.squeeze(3) # (b, N, N)

class GNN_module(nn.Module):
    def __init__(self, nway, input_dim, hidden_dim, num_layers, feature_type='dense'):
        super(GNN_module, self).__init__()

        self.feature_type = feature_type

        adjacency_list = []
        graph_conv_list = []

        # ratio = [2, 2, 1, 1]
        ratio = [2, 1]

        if self.feature_type == 'dense':
            for i in range(num_layers):
                adjacency_list.append(Adjacency_layer(
                    input_dim=input_dim+hidden_dim//2*i, 
                    hidden_dim=hidden_dim, 
                    ratio=ratio))

                graph_conv_list.append(Graph_conv_block(
                    input_dim=input_dim+hidden_dim//2*i, 
                    output_dim=hidden_dim//2))

            # last layer
            last_adjacency = Adjacency_layer(
                        input_dim=input_dim+hidden_dim//2*num_layers, 
                        hidden_dim=hidden_dim, 
                        ratio=ratio)

            last_conv = Graph_conv_block(
                    input_dim=input_dim+hidden_dim//2*num_layers, 
                    output_dim=nway, 
                    use_bn=False)

        elif self.feature_type == 'forward':
            for i in range(num_layers):
                adjacency_list.append(Adjacency_layer(
                    input_dim=input_dim if i == 0 else hidden_dim, 
                    hidden_dim=hidden_dim, 
                    ratio=ratio))

                graph_conv_list.append(Graph_conv_block(
                    input_dim=hidden_dim, 
                    output_dim=hidden_dim))

            # last layer
            last_adjacency = Adjacency_layer(
                        input_dim=hidden_dim, 
                        hidden_dim=hidden_dim, 
                        ratio=ratio)

            last_conv = Graph_conv_block(
                    input_dim=hidden_dim, 
                    output_dim=nway,
                    use_bn=False)

        else:
            raise NotImplementedError

        self.adjacency_list = nn.ModuleList(adjacency_list)
        self.graph_conv_list = nn.ModuleList(graph_conv_list)
        self.last_adjacency = last_adjacency
        self.last_conv = last_conv


    def forward(self, x):
        for i, _ in enumerate(self.adjacency_list):
            adjacency_layer = self.adjacency_list[i]
            conv_block = self.graph_conv_list[i]

            A = adjacency_layer(x)

            x_next = conv_block(x, A)

            x_next = F.leaky_relu(x_next, 0.1)

            if self.feature_type == 'dense':
                x = torch.cat([x, x_next], dim=2)
            elif self.feature_type == 'forward':
                x = x_next
            else:
                raise NotImplementedError
        
        A = self.last_adjacency(x)
        out = self.last_conv(x, A)   

        return out[:, 0, :]

 

Train 코드

import os
import numpy as np

import torch
import torch.nn as nn
from torch.autograd import Variable
import torch.nn.functional as F

from time import time

from gnn import GNN_module


def np2cuda(array):
    tensor = torch.from_numpy(array)
    if torch.cuda.is_available():
        tensor = tensor.cuda()
    return tensor


def tensor2cuda(tensor):
    if torch.cuda.is_available():
        tensor = tensor.cuda()
    return tensor

class myModel(nn.Module):
    def __init__(self):
        super(myModel, self).__init__()

    def load(self, file_name):
        self.load_state_dict(torch.load(file_name, map_location=lambda storage, loc: storage))
    def save(self, file_name):
        torch.save(self.state_dict(), file_name)

###############################################################
## Vanilla CNN model, used to extract visual features

class EmbeddingCNN(myModel):

    def __init__(self, image_size, cnn_feature_size, cnn_hidden_dim, cnn_num_layers):
        super(EmbeddingCNN, self).__init__()

        module_list = []
        dim = cnn_hidden_dim
        for i in range(cnn_num_layers):
            if i == 0:
                module_list.append(nn.Conv2d(3, dim, 3, 1, 1, bias=False))
                module_list.append(nn.BatchNorm2d(dim))
            else:
                module_list.append(nn.Conv2d(dim, dim*2, 3, 1, 1, bias=False))
                module_list.append(nn.BatchNorm2d(dim*2))
                dim *= 2
            module_list.append(nn.MaxPool2d(2))
            module_list.append(nn.LeakyReLU(0.1, True))
            image_size //= 2
        module_list.append(nn.Conv2d(dim, cnn_feature_size, image_size, 1, bias=False))
        module_list.append(nn.BatchNorm2d(cnn_feature_size))
        module_list.append(nn.LeakyReLU(0.1, True))

        self.module_list = nn.ModuleList(module_list)

    def forward(self, inputs):
        for l in self.module_list:
            inputs = l(inputs)

        outputs = inputs.view(inputs.size(0), -1)
        return outputs

    def freeze_weight(self):
        for p in self.parameters():
            p.requires_grad = False

class GNN(myModel):
    def __init__(self, cnn_feature_size, gnn_feature_size, nway):
        super(GNN, self).__init__()

        num_inputs = cnn_feature_size + nway
        graph_conv_layer = 2
        self.gnn_obj = GNN_module(nway=nway, input_dim=num_inputs,
            hidden_dim=gnn_feature_size,
            num_layers=graph_conv_layer,
            feature_type='dense')

    def forward(self, inputs):
        logits = self.gnn_obj(inputs).squeeze(-1)

        return logits

class gnnModel(myModel):
    def __init__(self, nway):
        super(myModel, self).__init__()
        image_size = 32
        cnn_feature_size = 64
        cnn_hidden_dim = 32
        cnn_num_layers = 3

        gnn_feature_size = 32

        self.cnn_feature = EmbeddingCNN(image_size, cnn_feature_size, cnn_hidden_dim, cnn_num_layers)
        self.gnn = GNN(cnn_feature_size, gnn_feature_size, nway)

    def forward(self, data):
        [x, _, _, _, xi, _, one_hot_yi, _] = data

        z = self.cnn_feature(x)
        zi_s = [self.cnn_feature(xi[:, i, :, :, :]) for i in range(xi.size(1))]

        zi_s = torch.stack(zi_s, dim=1)


        # follow the paper, concatenate the information of labels to input features
        uniform_pad = torch.FloatTensor(one_hot_yi.size(0), 1, one_hot_yi.size(2)).fill_(
            1.0/one_hot_yi.size(2))
        uniform_pad = tensor2cuda(uniform_pad)

        labels = torch.cat([uniform_pad, one_hot_yi], dim=1)
        features = torch.cat([z.unsqueeze(1), zi_s], dim=1)

        nodes_features = torch.cat([features, labels], dim=2)

        out_logits = self.gnn(inputs=nodes_features)
        logsoft_prob = F.log_softmax(out_logits, dim=1)

        return logsoft_prob

class Trainer():
    def __init__(self, trainer_dict):

        self.num_labels = 10

        self.args = trainer_dict['args']
        self.logger = trainer_dict['logger']

        if self.args.todo == 'train':
            self.tr_dataloader = trainer_dict['tr_dataloader']

        if self.args.model_type == 'gnn':
            Model = gnnModel

        self.model = Model(nway=self.args.nway)

        self.logger.info(self.model)

        self.total_iter = 0
        self.sample_size = 32

    def load_model(self, model_dir):
        self.model.load(model_dir)

        print('load model sucessfully...')

    def load_pretrain(self, model_dir):
        self.model.cnn_feature.load(model_dir)

        print('load pretrain feature sucessfully...')

    def model_cuda(self):
        if torch.cuda.is_available():
            self.model.cuda()

    def eval(self, dataloader, test_sample):
        self.model.eval()
        args = self.args
        iteration = int(test_sample/self.args.batch_size)

        total_loss = 0.0
        total_sample = 0
        total_correct = 0
        with torch.no_grad():
            for i in range(iteration):
                data = dataloader.load_te_batch(batch_size=args.batch_size,
                    nway=args.nway, num_shots=args.shots)

                data_cuda = [tensor2cuda(_data) for _data in data]

                logsoft_prob = self.model(data_cuda)

                label = data_cuda[1]
                loss = F.nll_loss(logsoft_prob, label)

                total_loss += loss.item() * logsoft_prob.shape[0]

                pred = torch.argmax(logsoft_prob, dim=1)

                # print(pred)

                # print(torch.eq(pred, label).float().sum().item())
                # print(label)

                assert pred.shape == label.shape

                total_correct += torch.eq(pred, label).float().sum().item()
                total_sample += pred.shape[0]
        print('correct: %d / %d' % (total_correct, total_sample))
        print(total_correct)
        return total_loss / total_sample, 100.0 * total_correct / total_sample

    def train_batch(self):
        self.model.train()
        args = self.args

        data = self.tr_dataloader.load_tr_batch(batch_size=args.batch_size,
            nway=args.nway, num_shots=args.shots)

        data_cuda = [tensor2cuda(_data) for _data in data]

        self.opt.zero_grad()

        logsoft_prob = self.model(data_cuda)

        # print('pred', torch.argmax(logsoft_prob, dim=1))
        # print('label', data[2])
        label = data_cuda[1]

        loss = F.nll_loss(logsoft_prob, label)
        loss.backward()
        self.opt.step()

        return loss.item()

    def train(self):
        if self.args.freeze_cnn:
            self.model.cnn_feature.freeze_weight()
            print('freeze cnn weight...')

        best_loss = 1e8
        best_acc = 0.0
        stop = 0
        eval_sample = 5000
        self.model_cuda()
        self.model_dir = os.path.join(self.args.model_folder, 'model.pth')

        self.opt = torch.optim.Adam(
            filter(lambda p: p.requires_grad, self.model.parameters()),
            lr=self.args.lr,
            weight_decay=1e-6)
        # self.opt = torch.optim.Adam(self.model.parameters(), lr=self.args.lr,
        #     weight_decay=1e-6)

        start = time()
        tr_loss_list = []
        for i in range(self.args.max_iteration):

            tr_loss = self.train_batch()
            tr_loss_list.append(tr_loss)

            if i % self.args.log_interval == 0:
                self.logger.info('iter: %d, spent: %.4f s, tr loss: %.5f' % (i, time() - start,
                    np.mean(tr_loss_list)))
                del tr_loss_list[:]
                start = time()

            if i % self.args.eval_interval == 0:
                va_loss, va_acc = self.eval(self.tr_dataloader, eval_sample)

                self.logger.info('================== eval ==================')
                self.logger.info('iter: %d, va loss: %.5f, va acc: %.4f %%' % (i, va_loss, va_acc))
                self.logger.info('==========================================')

                if va_loss < best_loss:
                    stop = 0
                    best_loss = va_loss
                    best_acc = va_acc
                    if self.args.save:
                        self.model.save(self.model_dir)

                stop += 1
                start = time()

                if stop > self.args.early_stop:
                    break

            self.total_iter += 1

        self.logger.info('============= best result ===============')
        self.logger.info('best loss: %.5f, best acc: %.4f %%' % (best_loss, best_acc))

    def test(self, test_data_array, te_dataloader):
        self.model_cuda()
        self.model.eval()
        start = 0
        end = 0
        args = self.args
        batch_size = args.batch_size
        pred_list = []

        with torch.no_grad():
            while start < test_data_array.shape[0]:
                end = start + batch_size
                if end >= test_data_array.shape[0]:
                    batch_size = test_data_array.shape[0] - start

                data = te_dataloader.load_te_batch(batch_size=batch_size, nway=args.nway,
                    num_shots=args.shots)

                test_x = test_data_array[start:end]

                data[0] = np2cuda(test_x)

                data_cuda = [tensor2cuda(_data) for _data in data]

                map_label2class = data[-1].cpu().numpy()

                logsoft_prob = self.model(data_cuda)
                # print(logsoft_prob)
                pred = torch.argmax(logsoft_prob, dim=1).cpu().numpy()

                pred = map_label2class[range(len(pred)), pred]

                pred_list.append(pred)

                start = end

        return np.hstack(pred_list)

    def pretrain_eval(self, loader, cnn_feature, classifier):
        total_loss = 0
        total_sample = 0
        total_correct = 0

        with torch.no_grad():

            for j, (data, label) in enumerate(loader):
                data = tensor2cuda(data)
                label = tensor2cuda(label)
                output = classifier(cnn_feature(data))
                output = F.log_softmax(output, dim=1)
                loss = F.nll_loss(output, label)

                total_loss += loss.item() * output.shape[0]

                pred = torch.argmax(output, dim=1)

                assert pred.shape == label.shape

                total_correct += torch.eq(pred, label).float().sum().item()
                total_sample += pred.shape[0]

        return total_loss / total_sample, 100.0 * total_correct / total_sample

    def pretrain(self, pretrain_dataset, test_dataset):
        pretrain_loader = torch.utils.data.DataLoader(pretrain_dataset,
                batch_size=self.args.batch_size, shuffle=True)
        test_loader = torch.utils.data.DataLoader(test_dataset,
                        batch_size=self.args.batch_size, shuffle=True)

        self.model_cuda()

        best_loss = 1e8
        self.model_dir = os.path.join(self.args.model_folder, 'pretrain_model.pth')

        cnn_feature = self.model.cnn_feature
        classifier = nn.Linear(list(cnn_feature.parameters())[-3].shape[0], self.num_labels)

        if torch.cuda.is_available():
            classifier.cuda()
        self.pretrain_opt =  torch.optim.Adam(
            list(cnn_feature.parameters()) + list(classifier.parameters()),
            lr=self.args.lr,
            weight_decay=1e-6)

        start = time()

        for i in range(10000):
            total_tr_loss = []
            for j, (data, label) in enumerate(pretrain_loader):
                data = tensor2cuda(data)
                label = tensor2cuda(label)
                output = classifier(cnn_feature(data))

                output = F.log_softmax(output, dim=1)
                loss = F.nll_loss(output, label)

                self.pretrain_opt.zero_grad()
                loss.backward()
                self.pretrain_opt.step()
                total_tr_loss.append(loss.item())

            te_loss, te_acc = self.pretrain_eval(test_loader, cnn_feature, classifier)
            self.logger.info('iter: %d, tr loss: %.5f, spent: %.4f s' % (i, np.mean(total_tr_loss),
                time() - start))
            self.logger.info('--> eval: te loss: %.5f, te acc: %.4f %%' % (te_loss, te_acc))

            if te_loss < best_loss:
                stop = 0
                best_loss = te_loss
                if self.args.save:
                    cnn_feature.save(self.model_dir)

            stop += 1
            start = time()

            if stop > self.args.early_stop_pretrain:
                break



if __name__ == '__main__':
    import os
    b_s = 10
    nway = 5
    shots = 5
    batch_x = torch.rand(b_s, 3, 32, 32).cuda()
    batches_xi = [torch.rand(b_s, 3, 32, 32).cuda() for i in range(nway*shots)]

    label_x = torch.rand(b_s, nway).cuda()

    labels_yi = [torch.rand(b_s, nway).cuda() for i in range(nway*shots)]

    print('create model...')
    model = gnnModel(128, nway).cuda()
    # print(list(model.cnn_feature.parameters())[-3].shape)
    # print(len(list(model.parameters())))
    print(model([batch_x, label_x, None, None, batches_xi, labels_yi, None]).shape)

 

해당 프로젝트의 코드를 부분적으로 공유를 하였습니다. 실제로 구현을 위해서는, 관련 논문과 코드를 참고하시는 것을 추천드립니다.

 

 

결과

N way K shot Accuracy
5 10 51 / 50 %
5 1 36 %
2 10 92 / 93 %
2 1 87 %

논문 내용과 제공해주는 코드를 기반으로 CIFAR10을 적용해여 구현하였습니다. 성능이 CIFAR100을 적용하였을 떄보다 약 10% 적은 정확성을 보여주었습니다. 제가 개인적으로 연구한 내용은 다음과 같습니다.

CIFAR10에 있는 항목을 5개로 나누었습니다. 비행기, 고양이, 말, 자동차, 새 : 배, 개, 사슴, 트럭, 개구리로 나누어서 퓨샷 러닝을 진행해보았습니다. 처음에는 랜덤하게 5개의 클래스를 선택하여 퓨샷 학습을 하였습니다. 그러다보니 성능이 약 50% 웃돌았습니다. 만약 2개의 클래스를 선택하여 퓨샷 학습을 하면 정확도가 93%를 나타낼 정도로 좋은 성능이 나타났습니다. 그 이유는 역시, 분류할 경우가 2가지 밖에 없기떄문에 이러한 문제에 대해서는 좋은 성능을 나타내는 것입니다.

5개의 클래스로 늘어난 순간 정확도는 50%로 매우 저조해졌습니다. 따라서 저는 위와 같이 비슷한 클래스라고 생각되는 것들을 직접 CIFAR10 항목에서 분리를 해보았습니다. 해당 클래스를 선택하여 학습을 한다면, 비슷한 특성들을 가지고 있는 클래스들을 잘 구별해내지 않을까라는 생각으로 출발하였습니다. 해당 결과에서 약간의 정확도 상승이 있다는 것을 확인하였습니다.

결국, 몇개의 데이터만을 가지고, 인공지능에게 분류를 할 수 있는 능력을 가지게 하는 것은 결국, 사람과 유사한 학습 능력을 기르게 하는 것이라고 생각합니다. 수많은 클래스를 끊임없이 분류해내고 학습할 수 있는 능력을 가진 인공지능의 유의미한 퍼포먼스는 과연 언제쯤 나올 수 있을 지 기대가 됩니다.

 

 

결론

GNN 구조

퓨샷, SemiSupervised 및 Active Learning을 위한 그래프 뉴럴 네트워크를 탐구했습니다. 메타 학습 관점에서 볼 때, 이러한 작업은 신경 메시지 전달 모델과 함께 관계형 구조를 활용할 수 있는 요소 모음 또는 요소 집합으로 입력을 받는 경우 학습 문제를 Supervise합니다. 특히, 스택 노드 및 엣지 기능은 이전 퓨샷 학습 모델을 뒷받침하는 맥락 유사성 학습을 일반화합니다.

그래프 공식은 동일한 체제 하에서 여러 교육 설정 (퓨샷, 활성, 반Supervised)을 통합하는 데 도움이 됩니다. 이는 여러 영역(라벨 스트림)에서 동시에 작동 할 수 있는 단일 학습자를 갖기 위한 필수 단계입니다. 또 다른 기대효과는 Active Learning의 범위를 일반화하여 예를 들어 질문을 하는 능력 또는 강화 학습 설정에서 고정 학습이 정지되지 않은 환경에 적응하는 데 중요합니다.

반응형