DonHurry

step20. μ—°μ‚°μž μ˜€λ²„λ‘œλ“œ(1) λ³Έλ¬Έ

DeZero/πŸ—»μ œ2κ³ μ§€

step20. μ—°μ‚°μž μ˜€λ²„λ‘œλ“œ(1)

_도녁 2023. 1. 21. 00:34

πŸ“’ λ³Έ ν¬μŠ€νŒ…μ€ λ°‘λ°”λ‹₯λΆ€ν„° μ‹œμž‘ν•˜λŠ” λ”₯λŸ¬λ‹3을 기반으둜 μž‘μ„±ν•˜μ˜€μŠ΅λ‹ˆλ‹€. 배운 λ‚΄μš©μ„ κΈ°λ‘ν•˜κ³ , 개인적인 곡뢀λ₯Ό μœ„ν•΄ μž‘μ„±ν•˜λŠ” ν¬μŠ€νŒ…μž…λ‹ˆλ‹€. μžμ„Έν•œ λ‚΄μš©μ€ ꡐ재 ꡬ맀λ₯Ό κ°•λ ₯ μΆ”μ²œλ“œλ¦½λ‹ˆλ‹€.

 

 

이번 λ‹¨κ³„μ—μ„œλŠ” μ—°μ‚°μž μ˜€λ²„λ‘œλ“œλ₯Ό μ§„ν–‰ν•˜κ² μŠ΅λ‹ˆλ‹€. μ˜€λ²„λ‘œλ“œ μžμ²΄μ— κ΄€ν•œ λ‚΄μš©μ€ 잘 μ •λ¦¬λœ μ—¬λŸ¬ λΈ”λ‘œκ·Έλ“€μ΄ μžˆμœΌλ‹ˆ μ°Έκ³ ν•˜μ‹œκΈΈ λ°”λžλ‹ˆλ‹€. μš°μ„  Mul, κ³±μ…ˆ 클래슀λ₯Ό κ΅¬ν˜„ν•΄λ³΄κ² μŠ΅λ‹ˆλ‹€. κ³±μ…ˆμ˜ μ—­μ „νŒŒλŠ” μ•„λž˜ κ·Έλ¦Όκ³Ό 같이 ν˜λŸ¬λ“€μ–΄μ˜¨ κΈ°μšΈκΈ°μ— μ„œλ‘œμ˜ μž…λ ₯값을 κ΅ν™˜ν•˜μ—¬ κ³±ν•΄μ£ΌλŠ” λ°©μ‹μœΌλ‘œ μ§„ν–‰λ©λ‹ˆλ‹€. λ°‘μ‹œλ”₯ 1μ—μ„œ μžμ„Έν•˜κ²Œ μ„€λͺ…ν•˜λŠ” λ‚΄μš©μž…λ‹ˆλ‹€.

 

μ•„λž˜λŠ” Mul 클래슀의 μ½”λ“œμž…λ‹ˆλ‹€. 이전과 λ§ˆμ°¬κ°€μ§€λ‘œ 파이썬 ν•¨μˆ˜λ‘œ μ‚¬μš©ν•˜κΈ° μœ„ν•œ μ½”λ“œλ„ ν•¨κ»˜ κ΅¬ν˜„ν•©λ‹ˆλ‹€.

class Mul(Function):
    def forward(self, x0, x1):
        y = x0 * x1
        return y
    
    def backward(self, gy):
        x0, x1 = self.inputs[0].data, self.inputs[1].data
        return gy * x1, gy * x0


def mul(x0, x1):
    return Mul()(x0, x1)

 

ν˜„μž¬λŠ” κ³±μ…ˆμ„ μˆ˜ν–‰ν•˜κΈ° μœ„ν•΄ mul(a, b)와 같은 μ‹μœΌλ‘œ 번거둭게 κ΅¬ν˜„ν•΄μ•Όν•©λ‹ˆλ‹€. a * b와 같이 κΉ”λ”ν•˜κ²Œ μ‚¬μš©ν•  수 μžˆλ„λ‘ μ—°μ‚°μž μ˜€λ²„λ‘œλ“œλ₯Ό μ§„ν–‰ν•˜κ² μŠ΅λ‹ˆλ‹€. κ³±μ…ˆμ˜ 특수 λ©”μ„œλ“œλŠ” __mul__(self, other) μž…λ‹ˆλ‹€.

 

λ‹€μŒκ³Ό 같이 κ΅¬ν˜„ν•  수 μžˆμŠ΅λ‹ˆλ‹€.

Variable:
    ...
    def __mul__(self, other):
    	return mul(self, other)

 

ν•˜μ§€λ§Œ 더 κ°„λ‹¨ν•œ 방법도 μžˆμŠ΅λ‹ˆλ‹€. νŒŒμ΄μ¬μ€ ν•¨μˆ˜λ„ κ°μ²΄μ΄λ―€λ‘œ 클래슀λ₯Ό μ •μ˜ν•œ ν›„ ν•¨μˆ˜ 자체λ₯Ό ν• λ‹Ή κ°€λŠ₯ν•©λ‹ˆλ‹€.

class Variable:
    ...

# μ—°μ‚°μž μ˜€λ²„λ‘œλ“œ
Variable.__mul__ = mul
Variable.__add__ = add

 

λ§ˆμ§€λ§‰μœΌλ‘œ ν…ŒμŠ€νŠΈλ₯Ό μ§„ν–‰ν•΄λ³΄κ² μŠ΅λ‹ˆλ‹€. κ²°κ³Όκ°€ 잘 λ‚˜μ˜€λŠ” 것을 확인할 수 μžˆμŠ΅λ‹ˆλ‹€.

a = Variable(np.array(3.0))
b = Variable(np.array(2.0))
c = Variable(np.array(1.0))

# y = add(mul(a, b), c)
y = a * b + c
y.backward()

print(y)  # variable(7.0)
print(a.grad)  # 2.0
print(b.grad)  # 3.0