DonHurry

[Python] KNN (K-Nearest Neighbors) 본문

Data Science

[Python] KNN (K-Nearest Neighbors)

_도녁 2022. 12. 13. 00:01

데이터 준비

우선 오늘 실습에 사용할 데이터입니다. 머신러닝 분야에서 매우 자주 사용되는 붓꽃(iris) 데이터셋입니다.

 

UCI Machine Learning Repository: Iris Data Set

Data Set Characteristics:   Multivariate Number of Instances: 150 Area: Life Attribute Characteristics: Real Number of Attributes: 4 Date Donated 1988-07-01 Associated Tasks: Classification Missing Values? No Number of Web Hits: 5048222 Source: Creator:

archive.ics.uci.edu

 

 

 

모듈 불러오기

사용할 모듈을 import 합니다.

import matplotlib.pyplot as plt
import random

 

데이터 불러오기

이전 실습에서는 pandas 등을 활용하여 데이터를 불러왔습니다. 본 실습에서는 다른 방식으로 진행해보도록 하겠습니다. 참고로 iris.data의 마지막 줄에는 공백이 있습니다. pandas를 활용하면 자동으로 무시되지만, 아래와 같은 방식으로 이용할 경우 데이터 내부에서 수기로 지워주거나 전처리를 해주어야 합니다.

# iris data에서 마지막 줄 공백은 수기로 삭제
data = []

for line in open("iris.data", "r"):
  raw = line.split(",")
  x = [float(r) for r in raw[:4]]
  y = raw[4].strip()
  data.append((x, y))

 

붓꽃 데이터에 대한 이해를 위해 아래 코드의 결과를 첨부했습니다.

꽃받침(sepal)과 꽃잎(petal)에 대한 길이, 넓이, 품종 정보가 포함되어 있습니다.

import pandas as pd

data = pd.read_csv("iris.data", names=['sepal length', 'sepal width', 'petal length', 'petal width', 'label'])
data

 

데이터 살펴보기

matplotlib을 활용하여 데이터를 시각화합니다. 우선 꽃받침의 길이와 넓이에 따라 구분해봅니다.

species = {'Iris-setosa':0, 'Iris-versicolor':1, 'Iris-virginica': 2}
plt.scatter([d[0][0] for d in data], [d[0][1] for d in data], c=[species[d[1]] for d in data], alpha=0.3)
plt.show()

 

다음으로 꽃잎의 길이와 넓이로 구분해봅니다.

species = {'Iris-setosa':0, 'Iris-versicolor':1, 'Iris-virginica': 2}
plt.scatter([d[0][2] for d in data], [d[0][3] for d in data], c=[species[d[1]] for d in data], alpha=0.3)
plt.show()

 

distance

거리를 구하는 함수를 구현합니다. 거리를 구하는 방법에는 유클리디안 거리, 맨하튼 거리 등 다양하게 존재하지만 가장 간단한 유클리디안으로 구현해보겠습니다. distance2는 distance 함수의 이해를 돕기 위해 그 과정을 풀어놓은 것입니다.

def distance(a, b):
  return sum((x - y) ** 2 for x, y in zip(a, b)) ** 0.5


def distance2(a, b):
  s = 0
  for x, y in zip(a, b):
    s += (x - y) ** 2
  return s ** 0.5

 

결과는 같습니다.

a = distance(data[0][0], data[1][0])
print(a)
b = distance2(data[0][0], data[1][0])
print(b)

0.5385164807134502

0.5385164807134502

 

knn_classify

본격적으로 knn을 구현합니다. 우선 query 벡터가 주어지면 기존 데이터셋 벡터와의 거리를 계산하여 거리가 가까운 순으로 정렬해줍니다. 이때 정렬 함수를 이용하면 시간 복잡도가 O(n log n) 이지만 max heap을 활용하면 O(n log k)로 구현 가능합니다. 본 실습에서는 간단하게 sort함수를 활용하겠습니다.

 

knn_classify 함수의 리턴값은 query와 가장 유사한 붓꽃 품종입니다. 이때 k는 계산한 거리들 중 k개를 붓꽃 품종을 결정하는데 활용하겠다는 뜻입니다. 그런데 만약 정렬하는 과정에서 1번째와 2번째 거리가 동일하다면 제대로 분류하지 못할 가능성이 있습니다. 따라서 거리가 같은 경우에는 정렬해준 결과에서 임의로 하나의 데이터를 제거하고 다시 계산해줍니다.

from collections import Counter

# sorting 사용 O(n log n)
# max heap 사용 O(n log k)
def knn_classify(k, query, train):
  # query = [5.9, 3.0, 5.1, 1.8]
  # train = [([6.5, 3.0, 5.2, 2.0], 'Iris-virginica'), ...]
  res = sorted([(distance(query, v), l) for v, l in train])[:k]
  cnts = Counter(r[1] for r in res)
  mc = cnts.most_common(2)

  while len(cnts) >= 2 and mc[0][1] == mc[1][1]:
    res = res[:-1]
    cnts = Counter(r[1] for r in res)
    mc = cnts.most_common(2)

  return cnts.most_common(1)[0][0]
  

print(knn_classify(10, [5.5, 3.1, 5.0, 1.7], data))

Iris-virginica

 

k를 3으로 설정하고 정확도를 계산해보겠습니다. 이때 query는 기존 데이터를 그대로 활용하되, 동일한 벡터는 제외합니다. 매 쿼리마다 knn_classify를 진행하고 실제 정답과 일치하는지 비교해 정확도를 구해줍니다.

k = 3

n_correct = 0
n_incorrect = 0

for i, q in enumerate(data):
  new_data = []
  for j, d in enumerate(data):
    if i == j: continue
    new_data.append(d)
  
  ans = knn_classify(k, q[0], new_data)
  true_ans = q[1]
  if ans == true_ans:
    n_correct += 1
  else:
    n_incorrect += 1

print("accuracy", n_correct / (n_correct + n_incorrect))

accuracy 0.96

 

 

최적의 k 찾기

k를 변환해가며 실행하면, 몇 개의 거리 데이터를 사용하느냐에 따라 품종 예측의 정확도가 달라지는 것을 확인할 수 있습니다. 그렇다면 최적의 k개는 무엇일까요? 가장 간단한 방법은 전부 시행해보는 것입니다. 다음은 k를 1부터 50까지 시행해보았을 때의 결과입니다. 시각화하면 k가 20일 때 정확도가 가장 높은 것을 확인할 수 있습니다.

Y = []
X = []

for k in range(1, 50):
  n_correct = 0
  n_incorrect = 0

  for i, q in enumerate(data):
    new_data = []
    for j, d in enumerate(data):
      if i == j: continue
      new_data.append(d)
    
    ans = knn_classify(k, q[0], new_data)
    true_ans = q[1]
    if ans == true_ans:
      n_correct += 1
    else:
      n_incorrect += 1

  X.append(k)
  Y.append(n_correct / (n_correct + n_incorrect))


plt.plot(X, Y)
plt.show()

 

'Data Science' 카테고리의 다른 글

[Python] PageRank  (1) 2022.12.15
[Python] PCA (Principal Component Analysis)  (0) 2022.12.14
[Python] Clustering  (0) 2022.12.12
[Python] Latent Factor Model  (0) 2022.12.12
[Python] Linear Regression  (2) 2022.11.07