DonHurry

step18. ๋ฉ”๋ชจ๋ฆฌ ์ ˆ์•ฝ ๋ชจ๋“œ ๋ณธ๋ฌธ

DeZero/๐Ÿ—ป์ œ2๊ณ ์ง€

step18. ๋ฉ”๋ชจ๋ฆฌ ์ ˆ์•ฝ ๋ชจ๋“œ

_๋„๋… 2023. 1. 19. 00:01

๐Ÿ“ข ๋ณธ ํฌ์ŠคํŒ…์€ ๋ฐ‘๋ฐ”๋‹ฅ๋ถ€ํ„ฐ ์‹œ์ž‘ํ•˜๋Š” ๋”ฅ๋Ÿฌ๋‹3์„ ๊ธฐ๋ฐ˜์œผ๋กœ ์ž‘์„ฑํ•˜์˜€์Šต๋‹ˆ๋‹ค. ๋ฐฐ์šด ๋‚ด์šฉ์„ ๊ธฐ๋กํ•˜๊ณ , ๊ฐœ์ธ์ ์ธ ๊ณต๋ถ€๋ฅผ ์œ„ํ•ด ์ž‘์„ฑํ•˜๋Š” ํฌ์ŠคํŒ…์ž…๋‹ˆ๋‹ค. ์ž์„ธํ•œ ๋‚ด์šฉ์€ ๊ต์žฌ ๊ตฌ๋งค๋ฅผ ๊ฐ•๋ ฅ ์ถ”์ฒœ๋“œ๋ฆฝ๋‹ˆ๋‹ค.

 

 

์ด๋ฒˆ ๋‹จ๊ณ„์—์„œ๋Š” ๋ฉ”๋ชจ๋ฆฌ ์ ˆ์•ฝ ๋ชจ๋“œ๋ฅผ ๊ตฌํ˜„ํ•ด๋ณด๋„๋ก ํ•˜๊ฒ ์Šต๋‹ˆ๋‹ค. ์šฐ์„  ํ•„์š” ์—†๋Š” ๋ฏธ๋ถ„๊ฐ’๋“ค์„ ์‚ญ์ œํ•˜๊ฒ ์Šต๋‹ˆ๋‹ค. ํ˜„์žฌ์˜ DeZero๋Š” ๋ฏธ๋ถ„์„ ์ง„ํ–‰ํ•˜๋ฉด ๋ชจ๋“  ๋ณ€์ˆ˜๊ฐ€ ๋ฏธ๋ถ„ ๊ฒฐ๊ณผ๋ฅผ ๋ฉ”๋ชจ๋ฆฌ์— ์œ ์ง€ํ•ฉ๋‹ˆ๋‹ค. ๋ณดํ†ต ๋จธ์‹ ๋Ÿฌ๋‹์—์„œ๋Š” ๋ง๋‹จ ๋ณ€์ˆ˜์˜ ๋ฏธ๋ถ„๊ฐ’๋งŒ ํ•„์š”ํ•˜๋ฏ€๋กœ, ์ค‘๊ฐ„ ๋ณ€์ˆ˜์˜ ๋ฏธ๋ถ„๊ฐ’์„ ์ œ๊ฑฐํ•˜๋Š” ๊ธฐ๋Šฅ(retain_grad)์„ ์ถ”๊ฐ€ํ•ฉ๋‹ˆ๋‹ค.

class Variable:
    ...
    def backward(self, retain_grad=False):
        ...
        while funcs:
            f = funcs.pop()
            gys = [output().grad for output in f.outputs]
            gxs = f.backward(*gys)
            if not isinstance(gxs, tuple):
                gxs = (gxs, )
            
            for x, gx in zip(f.inputs, gxs):
                if x.grad is None:
                    x.grad = gx
                else:
                    x.grad = x.grad + gx

                if x.creator is not None:
                    add_func(x.creator)
            
            if not retain_grad:
                for y in f.outputs:
                    y().grad = None  # y๋Š” ์•ฝํ•œ ์ฐธ์กฐ(weakref)

 

ํ…Œ์ŠคํŠธ๋กœ ๋‹ค์Œ ์ฝ”๋“œ๋ฅผ ์‹คํ–‰ํ•ด๋ณด๊ฒ ์Šต๋‹ˆ๋‹ค. ์ค‘๊ฐ„ ๋ณ€์ˆ˜์ธ y์™€ t์˜ ๋ฏธ๋ถ„๊ฐ’์€ ์‚ญ์ œ๋˜๊ณ , ๋ง๋‹จ ๋ณ€์ˆ˜์ธ x0๊ณผ x1์˜ ๋ฏธ๋ถ„๊ฐ’๋งŒ ์œ ์ง€๋ฉ๋‹ˆ๋‹ค. ๋•๋ถ„์— ์ ˆ์•ฝ๋œ ๋ฉ”๋ชจ๋ฆฌ๋ฅผ ๋‹ค๋ฅธ ๊ณณ์— ์‚ฌ์šฉํ•  ์ˆ˜ ์žˆ๊ฒŒ ๋ฉ๋‹ˆ๋‹ค.

x0 = Variable(np.array(1.0))
x1 = Variable(np.array(1.0))
t = add(x0, x1)
y = add(x0, t)
y.backward()

print(y.grad, t.grad)  # None None
print(x0.grad, x1.grad)  # 2.0 1.0

 

์‹ ๊ฒฝ๋ง์€ ํฌ๊ฒŒ ํ•™์Šต๊ณผ ์ถ”๋ก ์ด๋ผ๋Š” ๋‘ ๊ฐ€์ง€ ๋‹จ๊ณ„๋กœ ๋‚˜๋‰ฉ๋‹ˆ๋‹ค. ํ•™์Šต ์‹œ์—๋Š” ๋ฏธ๋ถ„๊ฐ’์„ ๊ตฌํ•ด์•ผํ•˜์ง€๋งŒ, ์ถ”๋ก  ์‹œ์—๋Š” ์ˆœ์ „ํŒŒ๋งŒ์„ ์ง„ํ–‰ํ•˜๊ธฐ ๋•Œ๋ฌธ์— ์ค‘๊ฐ„ ๊ณ„์‚ฐ ๊ฒฐ๊ณผ๋ฅผ ๊ธฐ๋กํ•˜์ง€ ์•Š์œผ๋ฉด ๋ฉ”๋ชจ๋ฆฌ ์‚ฌ์šฉ๋Ÿ‰์„ ํฌ๊ฒŒ ์ค„์ด๋Š” ๊ฒƒ์ด ๊ฐ€๋Šฅํ•ฉ๋‹ˆ๋‹ค. ์ฃผ๋กœ ์‚ฌ์šฉ๋˜๋Š” ๋”ฅ๋Ÿฌ๋‹ ํ”„๋ ˆ์ž„์›Œํฌ์ธ PyTorch์—์„œ๋Š” ์ด ๊ธฐ๋Šฅ์„ torch.no_grad()๋ผ๋Š” ํ•จ์ˆ˜ ํ˜•ํƒœ๋กœ ์ œ๊ณตํ•ฉ๋‹ˆ๋‹ค. (๋งค์šฐ ์ž์ฃผ ์“ฐ์ž…๋‹ˆ๋‹ค.)

 

์ˆœ์ „ํŒŒ๋งŒ ํ™œ์šฉํ•  ๋•Œ๋ฅผ ์œ„ํ•ด, ์—ญ์ „ํŒŒ ํ™œ์„ฑ ๋ชจ๋“œ์™€ ์—ญ์ „ํŒŒ ๋น„ํ™œ์„ฑ ๋ชจ๋“œ๋ฅผ ์ „ํ™˜ํ•˜๋Š” ๊ตฌ์กฐ๋ฅผ ๊ตฌ์ถ•ํ•ด๋ณด๊ฒ ์Šต๋‹ˆ๋‹ค. ๋จผ์ € Config ํด๋ž˜์Šค๋ฅผ ํ™œ์šฉํ•˜๊ฒ ์Šต๋‹ˆ๋‹ค. enable_backprop์ด True์ด๋ฉด ์—ญ์ „ํŒŒ ํ™œ์„ฑ ๋ชจ๋“œ์ž…๋‹ˆ๋‹ค. Config ๊ฐ™์€ ์„ค์ • ๋ฐ์ดํ„ฐ๋Š” ๋‹จ ํ•œ ๊ตฐ๋ฐ์—๋งŒ ์กด์žฌํ•˜๋Š” ๊ฒƒ์ด ์ข‹๊ธฐ ๋•Œ๋ฌธ์—, ์ธ์Šคํ„ด์Šคํ™”ํ•˜์ง€ ์•Š๊ณ  ํด๋ž˜์Šค ์ƒํƒœ๋กœ ๋‘๊ฒ ์Šต๋‹ˆ๋‹ค.

class Config:
    enable_backprop = True

 

์ด์ œ Function์—์„œ Config ํด๋ž˜์Šค๋ฅผ ์ฐธ์กฐํ•  ์ˆ˜ ์žˆ๋„๋ก ํ•˜๊ฒ ์Šต๋‹ˆ๋‹ค. ์—ญ์ „ํŒŒ ์‹œ์— ํ•„์š”ํ•œ ์„ธ๋Œ€์™€ ๊ณ„์‚ฐ๋“ค์˜ ์—ฐ๊ฒฐ์„ ๋งŒ๋“ค์–ด๋‚ด๋Š” output.set_creator(self)๋“ฑ์˜ ๊ธฐ๋Šฅ์„ if๋ฌธ ์•ˆ์— ๋„ฃ์—ˆ์Šต๋‹ˆ๋‹ค.

class Function:
    def __call__(self, *inputs):
        ...
        if Config.enable_backprop:
            self.generation = max([x.generation for x in inputs])  # ์„ธ๋Œ€ ์„ค์ •
            for output in outputs:
                output.set_creator(self)  # ์—ฐ๊ฒฐ ์„ค์ •
            self.inputs = inputs
            self.outputs = [weakref.ref(output) for output in outputs]
        
        return outputs if len(outputs) > 1 else outputs[0]

 

ํŒŒ์ด์ฌ์—๋Š” with๋ผ๋Š” ํŽธ๋ฆฌํ•œ ๊ตฌ๋ฌธ์ด ์žˆ์Šต๋‹ˆ๋‹ค. ์•„๋ž˜ ์ฝ”๋“œ์™€ ๊ฐ™์ด with ๋ธ”๋ก์— ๋“ค์–ด๊ฐˆ ๋•Œ ์–ด๋–ค ์ฒ˜๋ฆฌ(์ „์ฒ˜๋ฆฌ)๋ฅผ ํ•ด์ฃผ๊ณ , with ๋ธ”๋ก์„ ๋น ์ ธ๋‚˜์˜ฌ ๋•Œ ์ฒ˜๋ฆฌ(ํ›„์ฒ˜๋ฆฌ)๋ฅผ ์ž๋™์œผ๋กœ ํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค. ํ•ด๋‹น ์ฝ”๋“œ๋Š” using_config ์•ˆ์—์„œ๋งŒ ์—ญ์ „ํŒŒ ๋น„ํ™œ์„ฑ ๋ชจ๋“œ์ธ ๊ฒƒ์ด๊ณ , ๋น ์ ธ๋‚˜์˜ค๋ฉด ์ผ๋ฐ˜ ๋ชจ๋“œ์ธ ์—ญ์ „ํŒŒ ํ™œ์„ฑ ๋ชจ๋“œ๋กœ ๋Œ์•„๊ฐ€๋Š” ๊ฒƒ์ž…๋‹ˆ๋‹ค.

with using_config('enable_backprop', False):
    x = Variable(np.array(2.0))
    y = square(x)

 

์•ž์„œ ํ™œ์šฉํ–ˆ๋˜ using_config ํ•จ์ˆ˜๋Š” ๋‹ค์Œ๊ณผ ๊ฐ™์ด ๊ตฌํ˜„ํ•ฉ๋‹ˆ๋‹ค. ๋‹ค์†Œ ์ƒ์†Œํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค. ์šฐ์„  with ๋ธ”๋ก ์•ˆ์—์„œ ์˜ˆ์™ธ๊ฐ€ ๋ฐœ์ƒํ•  ๊ฒƒ์„ ๊ณ ๋ คํ•˜์—ฌ try/finally ๊ตฌ๋ฌธ์„ ํ™œ์šฉํ•ฉ๋‹ˆ๋‹ค. ๋ฏธ๋ฆฌ ์ด์ „ ๊ฐ’์ธ old_value๋ฅผ ๋ฐ›์•„๋†“๊ณ , setattr๋ฅผ ํ†ตํ•ด ์ƒˆ๋กœ์šด value๋ฅผ ์„ค์ •ํ•ฉ๋‹ˆ๋‹ค. ์ตœ์ข…์ ์œผ๋กœ with ๋ธ”๋ก์„ ๋น ์ ธ๋‚˜์˜ฌ ๋•Œ์—๋Š” ์›๋ž˜ ๊ฐ’์ธ old_value๋กœ ๋ณต์›๋ฉ๋‹ˆ๋‹ค.

import contextlib

@contextlib.contextmanager
def using_config(name, value):
    old_value = getattr(Config, name)
    setattr(Config, name, value)
    try:
        yield
    finally:
        setattr(Config, name, old_value)

 

ํŽธ์˜์„ฑ์„ ์œ„ํ•ด ๋‹ค์Œ๊ณผ ๊ฐ™์€ ํ•จ์ˆ˜๋ฅผ ๊ตฌํ˜„ํ•ฉ๋‹ˆ๋‹ค. ์ด๋กœ์จ PyTorch์˜ torch.no_grad์™€ ๋™์ผํ•˜๊ฒŒ ์ž‘๋™์‹œํ‚ฌ ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค. ์•ž์œผ๋กœ ๊ธฐ์šธ๊ธฐ ๊ณ„์‚ฐ์ด ํ•„์š” ์—†์„ ๋•Œ, no_grad ํ•จ์ˆ˜๋ฅผ ํ˜ธ์ถœํ•˜๋ฉด ๋ฉ๋‹ˆ๋‹ค.

def no_grad():
    return using_config('enable_backprop', False)
    
with no_grad():
    x = Variable(np.array(2.0))
    y = square(x)