DonHurry

step12. ๊ฐ€๋ณ€ ๊ธธ์ด ์ธ์ˆ˜(๊ฐœ์„  ํŽธ) ๋ณธ๋ฌธ

DeZero/๐Ÿ—ป์ œ2๊ณ ์ง€

step12. ๊ฐ€๋ณ€ ๊ธธ์ด ์ธ์ˆ˜(๊ฐœ์„  ํŽธ)

_๋„๋… 2023. 1. 13. 00:08

๐Ÿ“ข ๋ณธ ํฌ์ŠคํŒ…์€ ๋ฐ‘๋ฐ”๋‹ฅ๋ถ€ํ„ฐ ์‹œ์ž‘ํ•˜๋Š” ๋”ฅ๋Ÿฌ๋‹3์„ ๊ธฐ๋ฐ˜์œผ๋กœ ์ž‘์„ฑํ•˜์˜€์Šต๋‹ˆ๋‹ค. ๋ฐฐ์šด ๋‚ด์šฉ์„ ๊ธฐ๋กํ•˜๊ณ , ๊ฐœ์ธ์ ์ธ ๊ณต๋ถ€๋ฅผ ์œ„ํ•ด ์ž‘์„ฑํ•˜๋Š” ํฌ์ŠคํŒ…์ž…๋‹ˆ๋‹ค. ์ž์„ธํ•œ ๋‚ด์šฉ์€ ๊ต์žฌ ๊ตฌ๋งค๋ฅผ ๊ฐ•๋ ฅ ์ถ”์ฒœ๋“œ๋ฆฝ๋‹ˆ๋‹ค.

 

 

์ด๋ฒˆ ๋‹จ๊ณ„์—์„œ๋Š” ๋‘ ๊ฐ€์ง€๋ฅผ ๊ฐœ์„ ํ•ด DeZero์˜ ํŽธ์˜์„ฑ์„ ๊ฐœ์„ ํ•˜๊ฒ ์Šต๋‹ˆ๋‹ค. ์ฒซ ๋ฒˆ์งธ๋Š” Addํด๋ž˜์Šค๋ฅผ ์‚ฌ์šฉํ•˜๋Š” ์‚ฌ๋žŒ, ๋‘ ๋ฒˆ์งธ๋Š” ๊ตฌํ˜„ํ•˜๋Š” ์‚ฌ๋žŒ์„ ์œ„ํ•œ ๊ฐœ์„ ์ž…๋‹ˆ๋‹ค. ์šฐ์„  ๋‹ค์Œ ๊ทธ๋ฆผ๊ณผ ๊ฐ™์ด ํ•จ์ˆ˜๋ฅผ ์‚ฌ์šฉํ•˜๊ธฐ ์‰ฝ๊ฒŒ ๋ณ€ํ™˜ํ•ฉ๋‹ˆ๋‹ค.

 

Add ํด๋ž˜์Šค๋Š” ์ธ์ˆ˜๋ฅผ ๋ฆฌ์ŠคํŠธ์— ๋ชจ์•„์„œ ๋ฐ›๊ณ  ๊ฒฐ๊ณผ๋Š” ํŠœํ”Œ๋กœ ๋ฐ˜ํ™˜ํ•ฉ๋‹ˆ๋‹ค. ํ•˜์ง€๋งŒ ์˜ค๋ฅธ์ชฝ ๊ทธ๋ฆผ์ฒ˜๋Ÿผ ๋ฆฌ์ŠคํŠธ๋‚˜ ํŠœํ”Œ์„ ๊ฑฐ์น˜์ง€ ์•Š๋Š” ๊ฒƒ์ด ๋” ์ž์—ฐ์Šค๋Ÿฝ์Šต๋‹ˆ๋‹ค. ํ•จ์ˆ˜ ์ •์˜ ์‹œ ์ธ์ˆ˜ ์•ž์— ๋ณ„ํ‘œ(*)๋ฅผ ๋ถ™์ด๋ฉด, ์ž„์˜ ๊ฐœ์ˆ˜์˜ ์ธ์ˆ˜(๊ฐ€๋ณ€ ๊ธธ์ด ์ธ์ˆ˜)๋ฅผ ๋ฐ›์„ ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค. ๋ฐ˜ํ™˜๊ฐ’์€ outputs ์›์†Œ๊ฐ€ ํ•˜๋‚˜์ธ ๊ฒฝ์šฐ ์ฒซ ๋ฒˆ์งธ ์›์†Œ๋ฅผ ๋ฐ˜ํ™˜ํ•˜๋„๋ก ํ•ฉ๋‹ˆ๋‹ค.

class Function:
    def __call__(self, *inputs):  # ๊ฐ€๋ณ€ ๊ธธ์ด ์ธ์ˆ˜
        xs = [input.data for input in inputs]
        ys = self.forward(xs)
        outputs = [Variable(as_array(y)) for y in ys]
        
        for output in outputs:
            output.set_creator(self)
        self.inputs = inputs
        self.outputs = outputs
        return outputs if len(outputs) > 1 else outputs[0]  # ๋ฆฌ์ŠคํŠธ ์›์†Œ๊ฐ€ ํ•˜๋‚˜๋ฉด ์ฒซ ๋ฒˆ์งธ ์›์†Œ ๋ฐ˜ํ™˜

 

๋‘ ๋ฒˆ์งธ ๊ฐœ์„ ์€ ๊ตฌํ˜„ํ•˜๋Š” ์‚ฌ๋žŒ์„ ์œ„ํ•œ ๊ฐœ์„ ์ž…๋‹ˆ๋‹ค. ์ด์ „๊นŒ์ง€๋Š” ์™ผ์ชฝ์ฒ˜๋Ÿผ ๊ตฌํ˜„ํ•ด์•ผ ํ–ˆ์ง€๋งŒ, ์˜ค๋ฅธ์ชฝ์ด ์ž์—ฐ์Šค๋Ÿฌ์›Œ๋ณด์ž…๋‹ˆ๋‹ค.

 

๋‘ ๋ฒˆ์งธ ๊ฐœ์„ ์„ ์œ„ํ•ด Function ํด๋ž˜์Šค๋ฅผ ์ˆ˜์ •ํ•ฉ๋‹ˆ๋‹ค. ์ด๋ฒˆ์—๋Š” forward ํ•จ์ˆ˜๋ฅผ ํ˜ธ์ถœํ•  ๋•Œ ๋ณ„ํ‘œ๋ฅผ ๋ถ™์˜€์Šต๋‹ˆ๋‹ค. ์ด๋Ÿด ๋•Œ๋Š” ๋ฆฌ์ŠคํŠธ ์–ธํŒฉ์˜ ๊ธฐ๋Šฅ์„ ํ•ฉ๋‹ˆ๋‹ค. ๋ฆฌ์ŠคํŠธ์˜ ์›์†Œ๋ฅผ ๋‚ฑ๊ฐœ๋กœ ํ’€์–ด์„œ ์ „๋‹ฌํ•ฉ๋‹ˆ๋‹ค. ๋˜ํ•œ ys๊ฐ€ ํŠœํ”Œ์ด ์•„๋‹Œ ๊ฒฝ์šฐ ํŠœํ”Œ๋กœ ๋ณ€๊ฒฝํ•ด์ค๋‹ˆ๋‹ค.

class Function:
    def __call__(self, *inputs):
        xs = [input.data for input in inputs]
        ys = self.forward(*xs)  # ๋ณ„ํ‘œ๋ฅผ ๋ถ™์—ฌ ์–ธํŒฉ
        if not isinstance(ys, tuple):  # ํŠœํ”Œ์ด ์•„๋‹Œ ๊ฒฝ์šฐ ๋ณ€๊ฒฝ
            ys = (ys, )
        outputs = [Variable(as_array(y)) for y in ys]
        
        for output in outputs:
            output.set_creator(self)
        self.inputs = inputs
        self.outputs = outputs
        return outputs if len(outputs) > 1 else outputs[0]

 

์œ„์˜ ๊ฐœ์„  ์ž‘์—…์ด ๋๋‚˜๋ฉด ๋‹ค์Œ๊ณผ ๊ฐ™์ด ์ž์—ฐ์Šค๋Ÿฌ์šด ์ˆœ์ „ํŒŒ ๋ฉ”์„œ๋“œ๋ฅผ ์ •์˜ํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค.

class Add(Function):
    def forward(self, x0, x1):
        y = x0 + x1
        return y

 

๋‹ค์Œ์œผ๋กœ Add ํด๋ž˜์Šค๋ฅผ ํŒŒ์ด์ฌ ํ•จ์ˆ˜๋กœ ์‚ฌ์šฉํ•  ์ˆ˜ ์žˆ๋„๋ก ๋ณ€๊ฒฝํ•ฉ๋‹ˆ๋‹ค.

def add(x0, x1):
    return Add()(x0, x1)

 

๋งˆ์ง€๋ง‰์œผ๋กœ ํ…Œ์ŠคํŠธ๋ฅผ ์ง„ํ–‰ํ•ฉ๋‹ˆ๋‹ค. ์ด์ œ ์ˆœ์ „ํŒŒ์—์„œ ๊ฐ€๋ณ€ ๊ธธ์ด ์ธ์ˆ˜๋ฅผ ๋‹ค๋ฃฐ ์ˆ˜ ์žˆ๊ฒŒ ๋˜์—ˆ์Šต๋‹ˆ๋‹ค. ๋‹ค์Œ ๋‹จ๊ณ„์—์„œ๋Š” ์—ญ์ „ํŒŒ๋ฅผ ๊ตฌํ˜„ํ•ฉ๋‹ˆ๋‹ค.

x0 = Variable(np.array(2))
x1 = Variable(np.array(3))
y = add(x0, x1)  # Add ํด๋ž˜์Šค ์ƒ์„ฑ ๊ณผ์ •์„ ๊ฐ์ถ”๊ธฐ
print(y.data)  # 5