DonHurry

step19. ๋ณ€์ˆ˜ ์‚ฌ์šฉ์„ฑ ๊ฐœ์„  ๋ณธ๋ฌธ

DeZero/๐Ÿ—ป์ œ2๊ณ ์ง€

step19. ๋ณ€์ˆ˜ ์‚ฌ์šฉ์„ฑ ๊ฐœ์„ 

_๋„๋… 2023. 1. 20. 00:40

๐Ÿ“ข ๋ณธ ํฌ์ŠคํŒ…์€ ๋ฐ‘๋ฐ”๋‹ฅ๋ถ€ํ„ฐ ์‹œ์ž‘ํ•˜๋Š” ๋”ฅ๋Ÿฌ๋‹3์„ ๊ธฐ๋ฐ˜์œผ๋กœ ์ž‘์„ฑํ•˜์˜€์Šต๋‹ˆ๋‹ค. ๋ฐฐ์šด ๋‚ด์šฉ์„ ๊ธฐ๋กํ•˜๊ณ , ๊ฐœ์ธ์ ์ธ ๊ณต๋ถ€๋ฅผ ์œ„ํ•ด ์ž‘์„ฑํ•˜๋Š” ํฌ์ŠคํŒ…์ž…๋‹ˆ๋‹ค. ์ž์„ธํ•œ ๋‚ด์šฉ์€ ๊ต์žฌ ๊ตฌ๋งค๋ฅผ ๊ฐ•๋ ฅ ์ถ”์ฒœ๋“œ๋ฆฝ๋‹ˆ๋‹ค.

 

 

์ด๋ฏธ DeZero์˜ ๊ธฐ์ดˆ๋Š” ๋งŒ๋“ค์–ด์กŒ๊ณ , ์•ž์œผ๋กœ๋Š” ๋” ์‰ฝ๊ฒŒ ์‚ฌ์šฉํ•  ์ˆ˜ ์žˆ๋„๋ก ๊ฐœ์„ ํ•˜๊ฒ ์Šต๋‹ˆ๋‹ค. ์ด๋ฒˆ ๋‹จ๊ณ„๋Š” Variable ํด๋ž˜์Šค ๊ฐœ์„ ์ž…๋‹ˆ๋‹ค. ์šฐ์„  ๋ณ€์ˆ˜๋“ค์„ ๊ตฌ๋ถ„์ง“๊ธฐ ์œ„ํ•œ ์ด๋ฆ„ ์„ค์ •์ž…๋‹ˆ๋‹ค. Variable ํด๋ž˜์Šค์— name์ด๋ผ๋Š” ์ธ์Šคํ„ด์Šค ๋ณ€์ˆ˜๋ฅผ ์ถ”๊ฐ€ํ•ฉ๋‹ˆ๋‹ค.

class Variable:
    def __init__(self, data, name=None):
        if data is not None:
            if not isinstance(data, np.ndarray):
                raise TypeError('{} is not supported'.format(type(data)))

        self.data = data
        self.name = name
        self.grad = None
        self.creator = None
        self.generation = 0
    ...

 

๋ฐ์ดํ„ฐ๋ฅผ ๋‹ด๊ณ  ์žˆ๋Š” ์ƒ์ž์ธ Variable์„ ๋ฐ์ดํ„ฐ์ธ ๊ฒƒ์ฒ˜๋Ÿผ ๋ณด์ด๊ฒŒ ํ•˜๋Š” ์žฅ์น˜๋ฅผ ๋งŒ๋“ค๊ฒ ์Šต๋‹ˆ๋‹ค. ์—ฌ๋Ÿฌ๊ฐ€์ง€ ๋ฉ”์„œ๋“œ๋ฅผ ์ถ”๊ฐ€ํ•˜๋Š”๋ฐ, ํŒŒ์ด์ฌ์˜ @property๋ฅผ ํ™œ์šฉํ•˜์—ฌ ๊ตฌํ˜„ํ•˜๊ฒ ์Šต๋‹ˆ๋‹ค. ์ด๋Š” ๋ฉ”์„œ๋“œ๋ฅผ ์ธ์Šคํ„ด์Šค ๋ณ€์ˆ˜์ฒ˜๋Ÿผ ์‚ฌ์šฉํ•  ์ˆ˜ ์žˆ๊ฒŒ ํ•ฉ๋‹ˆ๋‹ค. ์ž์„ธํ•˜๊ฒŒ ์„ค๋ช…ํ•ด์ฃผ์‹œ๋Š” ๋ถ„๋“ค์ด ๋งŽ์œผ๋‹ˆ ์ฐพ์•„๋ณด์‹œ๋ฉด ์ข‹์„ ๊ฒƒ ๊ฐ™์Šต๋‹ˆ๋‹ค. ํ˜„์žฌ ๊ตฌํ˜„ํ•œ ๊ฒƒ ์™ธ์—๋„ ndarray์˜ ์ธ์Šคํ„ด์Šค ๋ณ€์ˆ˜๋ฅผ ํ™œ์šฉํ•˜๋ฉด ๋”์šฑ ๋‹ค์–‘ํ•œ ์ถ”๊ฐ€ ๊ธฐ๋Šฅ์ด ๊ฐ€๋Šฅํ•ฉ๋‹ˆ๋‹ค.

class Variable:
    ...
    @property
    def shape(self):
        return self.data.shape
    
    @property
    def ndim(self):
        return self.data.ndim
    
    @property
    def size(self):
        return self.data.size
    
    @property
    def dtype(self):
        return self.data.dtype

 

ํŒŒ์ด์ฌ์˜ len ํ•จ์ˆ˜๋„ ํ™œ์šฉํ•ด๋ณด๊ฒ ์Šต๋‹ˆ๋‹ค. ์ถ”๊ฐ€๋กœ Variable์˜ ๋‚ด์šฉ ํ™•์ธ์„ ์‰ฝ๊ฒŒ ํ•  ์ˆ˜ ์žˆ๋Š” ๊ธฐ๋Šฅ๋„ ์ถ”๊ฐ€ํ•ฉ๋‹ˆ๋‹ค.

class Variable:
    ...
    def __len__(self):
        return len(self.data)

    def __repr__(self):
        if self.data is None:
            return 'variable(None)'
        p = str(self.data).replace('\n', '\n' + ' ' * 9)
        return 'variable(' + p + ')'

 

ํ…Œ์ŠคํŠธํ•ด๋ณด๋ฉด ๋‹ค์Œ๊ณผ ๊ฐ™์€ ๊ฒฐ๊ณผ๋ฅผ ์–ป์„ ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค. ์•ž์œผ๋กœ๋„ ๊ฐœ์„ ์„ ์ด์–ด๋‚˜๊ฐ€๋„๋ก ํ•˜๊ฒ ์Šต๋‹ˆ๋‹ค.

x = Variable(np.array([[1, 2, 3], [4, 5, 6]]))
x.name = 'x'

print(x.name)  # x
print(x.shape)  # (2, 3)
print(x)  # variable([[1 2 3]
          #           [4 5 6]])