DonHurry

step16. Complex Computational Graphs (Implementation)


_๋„๋… 2023. 1. 17. 00:01

๐Ÿ“ข ๋ณธ ํฌ์ŠคํŒ…์€ ๋ฐ‘๋ฐ”๋‹ฅ๋ถ€ํ„ฐ ์‹œ์ž‘ํ•˜๋Š” ๋”ฅ๋Ÿฌ๋‹3์„ ๊ธฐ๋ฐ˜์œผ๋กœ ์ž‘์„ฑํ•˜์˜€์Šต๋‹ˆ๋‹ค. ๋ฐฐ์šด ๋‚ด์šฉ์„ ๊ธฐ๋กํ•˜๊ณ , ๊ฐœ์ธ์ ์ธ ๊ณต๋ถ€๋ฅผ ์œ„ํ•ด ์ž‘์„ฑํ•˜๋Š” ํฌ์ŠคํŒ…์ž…๋‹ˆ๋‹ค. ์ž์„ธํ•œ ๋‚ด์šฉ์€ ๊ต์žฌ ๊ตฌ๋งค๋ฅผ ๊ฐ•๋ ฅ ์ถ”์ฒœ๋“œ๋ฆฝ๋‹ˆ๋‹ค.


์ด๋ฒˆ ๋‹จ๊ณ„์—์„œ๋Š” 15๋‹จ๊ณ„์˜ ์ด๋ก ์„ ์ง์ ‘ ๊ตฌํ˜„ํ•ฉ๋‹ˆ๋‹ค. ๋จผ์ € ์ˆœ์ „ํŒŒ ์‹œ ์„ธ๋Œ€๋ฅผ ์„ค์ •ํ•˜๊ฒ ์Šต๋‹ˆ๋‹ค. Variableํด๋ž˜์Šค์™€ Function ํด๋ž˜์Šค์— ์ธ์Šคํ„ด์Šค ๋ณ€์ˆ˜ generation์„ ์ถ”๊ฐ€ํ•ฉ๋‹ˆ๋‹ค. ๋ช‡ ๋ฒˆ์งธ ์„ธ๋Œ€์˜ ํ•จ์ˆ˜์ธ์ง€ ๋‚˜ํƒ€๋‚ด๋Š” ๋ณ€์ˆ˜์ž…๋‹ˆ๋‹ค. set_creator ๋ฉ”์„œ๋“œ๊ฐ€ ํ˜ธ์ถœ๋  ๋•Œ ์„ธ๋Œ€๋ฅผ 1๋งŒํผ ๋Š˜๋ ค์ค๋‹ˆ๋‹ค. ์ฆ‰, ์–ด๋–ค ํ•จ์ˆ˜๋กœ๋ถ€ํ„ฐ ๋‚˜์˜จ ๋ณ€์ˆ˜๋Š” ํ•ด๋‹น ํ•จ์ˆ˜๋ณด๋‹ค 1๋งŒํผ ํฐ ์„ธ๋Œ€ ๊ฐ’์„ ๊ฐ–์Šต๋‹ˆ๋‹ค.

import numpy as np


class Variable:
    def __init__(self, data):
        if data is not None:
            if not isinstance(data, np.ndarray):
                raise TypeError('{} is not supported'.format(type(data)))

        self.data = data
        self.grad = None
        self.creator = None
        self.generation = 0  # record the generation number

    def set_creator(self, func):
        self.creator = func
        self.generation = func.generation + 1  # record the generation (parent function's generation + 1)
    
    ...

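Here is a minimal sketch of this bookkeeping rule on its own, which you can run as is. To keep it short, this Variable drops the ndarray check and the grad field. StubFunction is a hypothetical placeholder that only carries a generation attribute. Neither is the book's actual code.

```python
class Variable:
    """Stripped-down stand-in: tracks only the creator and the generation."""

    def __init__(self, data):
        self.data = data
        self.creator = None
        self.generation = 0  # variables created directly by the user start at 0

    def set_creator(self, func):
        self.creator = func
        self.generation = func.generation + 1  # one more than the creating function


class StubFunction:
    """Hypothetical placeholder (not DeZero): it only carries a generation number."""

    def __init__(self, generation):
        self.generation = generation


x = Variable(1.0)
print(x.generation)  # 0

y = Variable(2.0)
y.set_creator(StubFunction(generation=2))
print(y.generation)  # 3
```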

๋‹ค์Œ์€ Function ํด๋ž˜์Šค์ž…๋‹ˆ๋‹ค. ์—ฌ๊ธฐ์„œ๋Š” ์ž…๋ ฅ ๋ณ€์ˆ˜๊ฐ€ ๋‘˜ ์ด์ƒ์ผ ๋•Œ, ๊ฐ€์žฅ ํฐ ์„ธ๋Œ€๋ฅผ ์„ ํƒํ•˜๋„๋ก ํ•ฉ๋‹ˆ๋‹ค.

class Function:
    def __call__(self, *inputs):
        xs = [x.data for x in inputs]
        ys = self.forward(*xs)
        if not isinstance(ys, tuple):
            ys = (ys, )
        outputs = [Variable(as_array(y)) for y in ys]

        self.generation = max([x.generation for x in inputs])  # take the largest generation among the inputs
        for output in outputs:
            output.set_creator(self)
        self.inputs = inputs
        self.outputs = outputs
        return outputs if len(outputs) > 1 else outputs[0]
    
    ...

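The sketch below applies both rules to the branching graph tested at the end of this post. It also adds one extra addition whose inputs sit at different generations. It is a simplified reconstruction built on a few assumptions: data is plain Python floats instead of np.ndarray, there is no backward pass, and Square and Add here implement only forward.

```python
class Variable:
    def __init__(self, data):
        self.data = data
        self.creator = None
        self.generation = 0

    def set_creator(self, func):
        self.creator = func
        self.generation = func.generation + 1  # creator's generation + 1


class Function:
    def __call__(self, *inputs):
        xs = [x.data for x in inputs]
        ys = self.forward(*xs)
        if not isinstance(ys, tuple):
            ys = (ys,)
        outputs = [Variable(y) for y in ys]
        # a function takes the largest generation among its inputs
        self.generation = max(x.generation for x in inputs)
        for output in outputs:
            output.set_creator(self)
        self.inputs = inputs
        self.outputs = outputs
        return outputs if len(outputs) > 1 else outputs[0]


class Square(Function):
    def forward(self, x):
        return x ** 2


class Add(Function):
    def forward(self, x0, x1):
        return x0 + x1


x = Variable(2.0)
a = Square()(x)   # function gen 0 -> a gen 1
b = Square()(a)   # function gen 1 -> b gen 2
c = Square()(a)   # function gen 1 -> c gen 2
y = Add()(b, c)   # function gen max(2, 2) = 2 -> y gen 3
print(x.generation, a.generation, b.generation, c.generation, y.generation)  # 0 1 2 2 3

z = Add()(b, x)   # inputs at generations 2 and 0 -> function gen 2 -> z gen 3
print(z.creator.generation, z.generation)  # 2 3
```

z's creator takes generation 2 from b, not 0 from x. Taking the maximum ensures that a function's generation is never lower than that of any function upstream of it, and the backward ordering relies on this.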

์ด์ „ ๋‹จ๊ณ„์—์„œ, ์„ธ๋Œ€๊ฐ€ ํฐ ํ•จ์ˆ˜๋ถ€ํ„ฐ ์ฒ˜๋ฆฌํ•˜๋ฉด ์˜ฌ๋ฐ”๋ฅธ ์ˆœ์„œ๋กœ ์—ญ์ „ํŒŒ๊ฐ€ ๊ฐ€๋Šฅํ•˜๋‹ค๊ณ  ํ–ˆ์Šต๋‹ˆ๋‹ค. ๊ฐ€์žฅ ํฐ ์›์†Œ๋ฅผ ๊บผ๋‚ด๊ธฐ ์œ„ํ•œ ๋‹จ์ˆœํ•œ ๋ฐฉ๋ฒ• ์ค‘ ํ•˜๋‚˜๋Š” ์ •๋ ฌ์„ ์ด์šฉํ•˜๋Š” ๊ฒƒ์ž…๋‹ˆ๋‹ค. ํšจ์œจ์„ ๊ณ ๋ คํ•˜๋ฉด ์šฐ์„ ์ˆœ์œ„ ํ๋ฅผ ์ด์šฉํ•ด ๊ตฌํ˜„ํ•˜๋Š” ๊ฒƒ์ด ๋งž์ง€๋งŒ, ์šฐ์„ ์€ ์ •๋ ฌ์„ ์ด์šฉํ•˜๊ฒ ์Šต๋‹ˆ๋‹ค. ๊ทธ๋™์•ˆ์—๋Š” DeZero์˜ ํ•จ์ˆ˜๋“ค์„ ๋ฆฌ์ŠคํŠธ์— ์ถ”๊ฐ€ํ•˜๋Š” ๋ฐฉ์‹์„ ์ด์šฉํ–ˆ์ง€๋งŒ, add_func ํ•จ์ˆ˜๋ฅผ ํ˜ธ์ถœํ•˜๋„๋ก ๋ณ€๊ฒฝํ•˜์˜€์Šต๋‹ˆ๋‹ค. add_func ํ•จ์ˆ˜๋Š” ํ•จ์ˆ˜๋“ค์„ ์„ธ๋Œ€ ์ˆœ์œผ๋กœ ์ •๋ ฌํ•˜๋Š” ์—ญํ• ์ด ํฌํ•จ๋˜์–ด ์žˆ์Šต๋‹ˆ๋‹ค.

class Variable:
    ...
    
    def backward(self):
        if self.grad is None:
            self.grad = np.ones_like(self.data)
        
        funcs = []
        seen_set = set()

        # ์ถ”ํ›„ ์šฐ์„ ์ˆœ์œ„ ํ๋กœ ๊ตฌํ˜„
        def add_func(f):
            if f not in seen_set:
                funcs.append(f)
                seen_set.add(f)
                funcs.sort(key=lambda x: x.generation)
        
        add_func(self.creator)

        while funcs:
            f = funcs.pop()
            gys = [output.grad for output in f.outputs]
            gxs = f.backward(*gys)
            if not isinstance(gxs, tuple):
                gxs = (gxs, )
            
            for x, gx in zip(f.inputs, gxs):
                if x.grad is None:
                    x.grad = gx
                else:
                    x.grad = x.grad + gx
                
                if x.creator is not None:
                    add_func(x.creator)  # funcs.append(x.creator)์—์„œ ์ˆ˜์ •

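The post notes that a priority queue would be more efficient than re-sorting the list on every add_func call. One way to do this in Python is with the standard heapq module. This is my own sketch, not the book's implementation. heapq is a min-heap, so the queue stores -generation to make the largest generation pop first. A running counter breaks ties, so two functions are never compared directly. FuncQueue and StubFunction are hypothetical names used only for illustration.

```python
import heapq
import itertools


class FuncQueue:
    """Max-priority queue of functions keyed on generation (a sketch, not DeZero's code)."""

    def __init__(self):
        self._heap = []
        self._seen = set()
        self._counter = itertools.count()  # tie-breaker, so functions are never compared

    def add_func(self, f):
        if f not in self._seen:
            self._seen.add(f)
            # heapq is a min-heap, so store -generation to pop the largest generation first
            heapq.heappush(self._heap, (-f.generation, next(self._counter), f))

    def pop(self):
        return heapq.heappop(self._heap)[2]

    def __len__(self):
        return len(self._heap)


class StubFunction:
    """Hypothetical stand-in for a DeZero function: a name and a generation only."""

    def __init__(self, name, generation):
        self.name = name
        self.generation = generation


q = FuncQueue()
sq1 = StubFunction("square1", 0)
add = StubFunction("add", 2)
sq2 = StubFunction("square2", 1)
sq3 = StubFunction("square3", 1)
for f in (sq1, add, sq2, sq1, sq3):  # sq1 is added twice on purpose; the second is ignored
    q.add_func(f)

order = []
while q:  # truthiness comes from __len__
    order.append(q.pop().name)
print(order)  # ['add', 'square2', 'square3', 'square1']
```

Inside backward, funcs = FuncQueue() would replace the list, and funcs.add_func(...) would replace the nested add_func. funcs.pop() would still return the function with the largest generation. Each push and pop costs O(log n) instead of a re-sort of the whole list.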

์ด์ œ ์ด์ „์—๋Š” ๊ตฌํ˜„ํ•˜์ง€ ๋ชปํ–ˆ๋˜ ์œ„์™€ ๊ฐ™์€ ๊ณ„์‚ฐ ๊ทธ๋ž˜ํ”„๋ฅผ ํ…Œ์ŠคํŠธํ•ด๋ณด๊ฒ ์Šต๋‹ˆ๋‹ค. ๊ณ„์‚ฐ ๊ทธ๋ž˜ํ”„๋Š” $y = (x^2)^2 + (x^2)^2$์ด๋ฏ€๋กœ $y=2x^4$์„ ๋ฏธ๋ถ„ํ•˜๋Š” ๋ฌธ์ œ๊ฐ€ ๋ฉ๋‹ˆ๋‹ค. ๋”ฐ๋ผ์„œ ํ…Œ์ŠคํŠธ ๊ฒฐ๊ณผ๊ฐ€ ์ œ๋Œ€๋กœ ์ถœ๋ ฅ๋œ ๊ฒƒ์„ ํ™•์ธํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค. ์ด๊ฒƒ์œผ๋กœ DeZero๋Š” ์•„๋ฌด๋ฆฌ ๋ณต์žกํ•œ ์—ฐ๊ฒฐ๋„ ์ œ๋Œ€๋กœ ๋ฏธ๋ถ„ํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค.

x = Variable(np.array(2.0))
a = square(x)
y = add(square(a), square(a))
y.backward()

print(y.data)  # 32.0
print(x.grad)  # 64.0
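
The test above relies on square, add and as_array from earlier steps, which this post does not show. To run it end to end, here is a self-contained sketch that puts this step's pieces into one file. The bodies of as_array, Square, Add, square and add are my reconstruction of the earlier steps, so treat them as assumptions rather than the book's exact code.

```python
import numpy as np


def as_array(x):
    # ops on 0-d arrays return NumPy scalars; wrap them back into ndarrays
    return np.array(x) if np.isscalar(x) else x


class Variable:
    def __init__(self, data):
        if data is not None and not isinstance(data, np.ndarray):
            raise TypeError('{} is not supported'.format(type(data)))
        self.data = data
        self.grad = None
        self.creator = None
        self.generation = 0

    def set_creator(self, func):
        self.creator = func
        self.generation = func.generation + 1

    def backward(self):
        if self.grad is None:
            self.grad = np.ones_like(self.data)
        funcs, seen_set = [], set()

        def add_func(f):
            if f not in seen_set:
                funcs.append(f)
                seen_set.add(f)
                funcs.sort(key=lambda x: x.generation)

        add_func(self.creator)
        while funcs:
            f = funcs.pop()  # largest generation first
            gys = [output.grad for output in f.outputs]
            gxs = f.backward(*gys)
            if not isinstance(gxs, tuple):
                gxs = (gxs,)
            for x, gx in zip(f.inputs, gxs):
                x.grad = gx if x.grad is None else x.grad + gx  # accumulate on branches
                if x.creator is not None:
                    add_func(x.creator)


class Function:
    def __call__(self, *inputs):
        xs = [x.data for x in inputs]
        ys = self.forward(*xs)
        if not isinstance(ys, tuple):
            ys = (ys,)
        outputs = [Variable(as_array(y)) for y in ys]
        self.generation = max(x.generation for x in inputs)
        for output in outputs:
            output.set_creator(self)
        self.inputs = inputs
        self.outputs = outputs
        return outputs if len(outputs) > 1 else outputs[0]


class Square(Function):
    def forward(self, x):
        return x ** 2

    def backward(self, gy):
        x = self.inputs[0].data
        return 2 * x * gy  # d(x^2)/dx = 2x


class Add(Function):
    def forward(self, x0, x1):
        return x0 + x1

    def backward(self, gy):
        return gy, gy  # the gradient passes through unchanged to both inputs


def square(x):
    return Square()(x)


def add(x0, x1):
    return Add()(x0, x1)


x = Variable(np.array(2.0))
a = square(x)
y = add(square(a), square(a))
y.backward()
print(y.data)  # 32.0
print(x.grad)  # 64.0
```

Tracing the run, a receives 8 from each branch, so a.grad ends at 16. Only then is the first square processed, giving $2 \cdot 2 \cdot 16 = 64$. This accumulation happens before the gradient flows further back, and the generation ordering is what guarantees it.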