DonHurry

step50. A DataLoader That Draws Mini-batches



_๋„๋… 2023. 3. 2. 19:01

📢 This post is based on Deep Learning from Scratch 3. I write it to record what I have learned and for my own study. For the details, I strongly recommend buying the book.

In the previous step, we trained the model by drawing part of the dataset as a mini-batch at each iteration. In this step, we implement that process as a DataLoader class, which provides features such as creating mini-batches and shuffling the dataset. First, we need to understand Python iterators.

For a well-organized explanation, see the article linked below.

[Python] Sorting Out Iterable, Iterator, and Generator, Which Confuse Me Every Time

emjayahn.github.io

In Python, you can write your own iterator. The code below is an example of a custom iterator: it implements the special method __iter__ to return the object itself, and __next__ to return the next element, raising StopIteration when there are no elements left.

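class MyIterator:
    def __init__(self, max_cnt):
        self.max_cnt = max_cnt  # how many values to produce
        self.cnt = 0            # how many values produced so far

    def __iter__(self):
        return self  # an iterator returns itself from __iter__

    def __next__(self):
        if self.cnt == self.max_cnt:
            raise StopIteration()  # tells the for loop to stop

        self.cnt += 1
        return self.cnt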
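
A for loop is roughly equivalent to calling iter() once and then next() repeatedly until StopIteration is raised. The sketch below spells that out with a hypothetical helper, manual_for; a built-in list iterator stands in for MyIterator so the snippet runs on its own, since its __iter__ likewise returns the iterator itself.

```python
def manual_for(iterable):
    """Roughly what `for x in iterable:` does, collecting the items into a list."""
    items = []
    it = iter(iterable)          # calls iterable.__iter__()
    while True:
        try:
            x = next(it)         # calls it.__next__()
        except StopIteration:    # the iterator signals that it is finished
            break
        items.append(x)
    return items


it = iter([1, 2, 3])   # like MyIterator, its __iter__ returns the iterator itself
print(iter(it) is it)  # True
print(manual_for(it))  # [1, 2, 3]
print(manual_for(it))  # [] : once exhausted, the same iterator yields nothing
```

This single-pass behavior is why the DataLoader below calls reset() just before raising StopIteration: the next for loop (the next epoch) then starts again from a fresh index.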

Now that we have covered the basic concept, let's implement DataLoader. The initializer stores its arguments as instance variables and then calls the reset method. reset sets the iteration counter to 0 and prepares the data indices: a random permutation when shuffle is True, or the indices in order otherwise. __next__ takes out the next mini-batch and converts it into ndarray instances; this is the same code we have been using so far, so I will skip the details. Note that max_iter is computed with math.ceil, so when the data size is not divisible by batch_size the last mini-batch is smaller, and that __next__ calls reset before raising StopIteration, so the same DataLoader can be looped over again in the next epoch.

import math
import numpy as np


class DataLoader:
    def __init__(self, dataset, batch_size, shuffle=True):
        self.dataset = dataset  # an instance that satisfies the Dataset interface
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.data_size = len(dataset)
        self.max_iter = math.ceil(self.data_size / batch_size)

        self.reset()
    
    def reset(self):
        self.iteration = 0
        if self.shuffle:
            self.index = np.random.permutation(len(self.dataset))
        else:
            self.index = np.arange(len(self.dataset))
        
    def __iter__(self):
        return self
    
    def __next__(self):
        if self.iteration >= self.max_iter:
            self.reset()
            raise StopIteration
        
        i, batch_size = self.iteration, self.batch_size
        # Indices of the current mini-batch (the last slice may be shorter)
        batch_index = self.index[i * batch_size:(i + 1) * batch_size]
        batch = [self.dataset[j] for j in batch_index]
        # Stack the inputs and the labels into ndarrays
        x = np.array([example[0] for example in batch])
        t = np.array([example[1] for example in batch])

        self.iteration += 1
        return x, t
    
    def next(self):
        return self.__next__()
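
To see what max_iter and the index slicing in __next__ do, here is a NumPy-only trace of one epoch. The toy dataset (a plain list of (x, t) pairs standing in for a Dataset) and the sizes 10 and 3 are made up for illustration; the slicing and stacking lines mirror __next__ above.

```python
import math
import numpy as np

# Hypothetical toy dataset: 10 examples, each an (input, label) pair
dataset = [(np.array([i, i * 10], dtype=np.float32), i % 3) for i in range(10)]
batch_size = 3

data_size = len(dataset)
max_iter = math.ceil(data_size / batch_size)  # 10 / 3 rounded up -> 4 iterations

index = np.random.permutation(data_size)  # shuffled order, as in reset()

batch_sizes = []
seen = []
for i in range(max_iter):
    # Slicing past the end of the array is safe; the last slice is just shorter
    batch_index = index[i * batch_size:(i + 1) * batch_size]
    batch = [dataset[j] for j in batch_index]
    x = np.array([example[0] for example in batch])  # shape (len(batch), 2)
    t = np.array([example[1] for example in batch])  # shape (len(batch),)
    batch_sizes.append(len(t))
    seen.extend(batch_index.tolist())

print(max_iter, batch_sizes)  # 4 [3, 3, 3, 1]
```

Every index appears exactly once per epoch, and only the final mini-batch is smaller than batch_size.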

Before the actual training, let's implement an accuracy function that evaluates prediction accuracy. It takes the model output y and the labels t and computes the fraction of correct answers: the index of the largest score in each row of y (argmax) is taken as the predicted class and compared with t. Note that comparing np.ndarray objects, as in (pred == t.data), returns an element-wise match/mismatch result such as [True True False], and taking the mean of that boolean array gives the accuracy. (Comparing Python lists, by contrast, returns a single boolean indicating whether the two lists are equal as a whole.)

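from dezero.core import Variable, as_variable, as_array


def accuracy(y, t):
    # Accept both ndarrays and Variables
    y, t = as_variable(y), as_variable(t)

    # Predicted class = index of the largest score in each row
    pred = y.data.argmax(axis=1).reshape(t.shape)
    result = (pred == t.data)  # element-wise comparison -> boolean ndarray
    acc = result.mean()        # share of True values = accuracy
    return Variable(as_array(acc))  # mean() gives a NumPy scalar, so wrap it as an ndarray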
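
The argmax, element-wise comparison, and mean steps can be checked with plain NumPy. The scores and labels below are made-up values for three samples and three classes, not output from the spiral model.

```python
import numpy as np

# Hypothetical scores: 3 samples (rows) x 3 classes (columns)
y = np.array([[0.1, 0.8, 0.1],
              [0.9, 0.05, 0.05],
              [0.2, 0.3, 0.5]])
t = np.array([1, 0, 1])  # correct labels

pred = y.argmax(axis=1).reshape(t.shape)  # predicted classes: [1, 0, 2]
result = pred == t                        # element-wise: [ True  True False]
acc = result.mean()                       # True counts as 1, False as 0 -> 2 of 3 correct

print(result)                   # [ True  True False]
print([1, 0, 2] == [1, 0, 1])   # False: comparing lists gives one bool
```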

Now let's train on the spiral dataset. This time, the data is split into a training set and a test set. Backpropagation is not needed during testing, so the test loop runs under with dezero.no_grad(), which we implemented earlier, to avoid the memory and computation spent building the computational graph.

import dezero
import dezero.functions as F
from dezero import optimizers
from dezero import DataLoader
from dezero.models import MLP


max_epoch = 300
batch_size = 30
hidden_size = 10
lr = 1.0

train_set = dezero.datasets.Spiral(train=True)
test_set = dezero.datasets.Spiral(train=False)
train_loader = DataLoader(train_set, batch_size)
test_loader = DataLoader(test_set, batch_size, shuffle=False)

model = MLP((hidden_size, 3))
optimizer = optimizers.SGD(lr).setup(model)

for epoch in range(max_epoch):
    sum_loss, sum_acc = 0, 0

    for x, t in train_loader:
        y = model(x)
        loss = F.softmax_cross_entropy(y, t)
        acc = F.accuracy(y, t)

        model.cleargrads()
        loss.backward()
        optimizer.update()

        sum_loss += float(loss.data) * len(t)
        sum_acc += float(acc.data) * len(t)
    
    print('epoch: {}'.format(epoch + 1))
    print('train loss: {:.4f}, accuracy: {:.4f}'.format(
        sum_loss / len(train_set), sum_acc / len(train_set)))

    sum_loss, sum_acc = 0, 0
    with dezero.no_grad():
        for x, t in test_loader:
            y = model(x)
            loss = F.softmax_cross_entropy(y, t)
            acc = F.accuracy(y, t)

            sum_loss += float(loss.data) * len(t)
            sum_acc += float(acc.data) * len(t)

    print('test loss: {:.4f}, accuracy: {:.4f}'.format(
        sum_loss / len(test_set), sum_acc / len(test_set)))
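
The test loop relies on dezero.no_grad() to stop recording what backpropagation needs. Below is a minimal, self-contained sketch of how such a switch can be built, assuming a global Config.enable_backprop flag and a using_config helper in the style of the earlier no_grad step; square and recorded are hypothetical stand-ins for a DeZero Function and the graph data it would keep.

```python
import contextlib


class Config:
    # Global switch: when False, operations skip recording data for backprop
    enable_backprop = True


@contextlib.contextmanager
def using_config(name, value):
    old_value = getattr(Config, name)
    setattr(Config, name, value)
    try:
        yield
    finally:
        setattr(Config, name, old_value)  # restore even if the body raises


def no_grad():
    return using_config('enable_backprop', False)


# Hypothetical stand-in for a Function: it only keeps its input
# (which backward would need later) when backprop is enabled.
recorded = []


def square(x):
    y = x * x
    if Config.enable_backprop:
        recorded.append(x)
    return y


with no_grad():
    square(3.0)                      # nothing is recorded
    inside = Config.enable_backprop

square(3.0)                          # recorded as usual
after = Config.enable_backprop
print(inside, after, len(recorded))  # False True 1
```

Because the flag is restored in finally, leaving the with block always re-enables backprop for the next training epoch.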

Plotting the results of the training loop above shows the loss decreasing and the accuracy rising as the epochs progress. Training is working as intended, and since the gap between the train and test results is small, the model does not appear to be overfitting. In the next step, we will use the MNIST dataset instead of the spiral dataset.