Naive Bayes Classifier


  • 본문의 일부는 Data Science from scratch(밑바닥부터 시작하는 데이터 과학)을 참고했습니다.
  • 본문의 소스코드는 Joel Grus의 Github에서 Unlicensed 라이선스로 배포되고 있습니다.

스팸 필터

임의로 메시지를 선택하는 ‘공간’을 생각해보자. 메시지가 스팸인 경우를 S, 메시지에 비아그라라는 단어가 포함된 경우를 V라고 정의하자. 베이즈 정리를 사용하면 메시지에 비아그라라는 단어가 포함됐을 때 해당 메시지가 스팸일 확률은 아래와 같다.

V일 때 S의 확률 =
S일 때 V일 확률 x S일 확률 /
S일 때 V일 확률 x S일 확률 + V일 때 S가 아닐 확률 x S가 아닐 확률

분자는 메시지가 스팸이면서 비아그라라는 단어를 포함할 확률,
분모는 메시지에 비아그라라는 포함됐을 확률을 나타낸다.

만약 수 많은 스팸 메시지와 스팸이 아닌 메시지를 갖고 있으면 $P(V|S)$와 $P(V|ㄱS)$를 추정할 수 있다. 메시지가 스팸일 확률과 그렇지 않을 확률이 동일$(P(S)=P(ㄱS)=0.5)$하다면 위의 식을 아래와 같이 정리할 수 있다.

스팸 메시지의 50%, 스팸이 아닌 메시지의 1%만이 비아그라라는 단어를 포함한다면, 비아그라라는 단어를 포함한 메시지가 스팸일 확률은 아래와 같다.

조금은 더 똑똑한 스팸 필터

스팸을 처리하기 위한 말뭉치(corpus, w_1, …. w_n)가 주어졌다고 해보자. 확률 이론 적용을 위해 단어 w_i가 메시지에 포함되는 경우를 X_i로 나타내고 스팸 메시지에 i번째 단어가 포함되어 있는 확률인 $P(X_i|S)$와 스팸이 아닌 메시지에 i번째 단어가 포함되어 있는 확률 $P(X_I|ㄱS)$가 주어졌다고 해보자. 나이브 베이즈의 핵심은 ‘메시지가 스팸인가 아니냐가 주어졌다는 전제하에 각 단어의 존재는 서로 조건부 독립적이다’는(말도 안되는) 가정에 기반을 둔다. 좀 더 직관적으로 이해하자면, 어떤 스팸 메시지가 ‘비아그라’란 단어를 포함하고 있는 점이 같은 메시지가 ‘맥북프로’라는 단어를 포함하고 있는지 판단하는 데 도움을 주지 않음을 뜻한다.

나이브(Naive, 단순, 순진한) 베이즈라는 이름처럼 나이브 베이즈는 지나치게 극단적인 가정을 한다. 예컨데 사전에 수록된 단어가 ‘비아그라’와 ‘맥북프로’뿐이며, 모든 스팸 메시지의 반은 ‘값싼 비아그라’에 대한 메시지이고, 나머지 스팸 메시지는 ‘정품 맥북프로’에 대한 메시라고 해보자. 이런 경우 나이브 베이즈는 스팸 메시지에 ‘비아그라’와 ‘맥북프로’라는 단어가 포함될 확률을 다음처럼 추정한다.

현실에서는 ‘비아그라’와 ‘맥북프로’가 동시에 등장하지 않는다는 가정이 없었기 때문에 말도 안되는 가정으로 만들어진 모델이지만 성능은 꽤 뛰어나며, 이는 실제 스팸 필터에 적용할 수 있다.

‘비아그라’라는 단어만으로 스팸을 걸러 내는 필터에서도 사용된 베이즈 정리를 통해 메시지가 스팸일 확률을 다음처럼 계산할 수 있다.

베이즈의 가정을 따르면 각 단어가 메시지에 포함될 확률 값을 모두 곱해 위 식의 우변에 해당하는 확률 값을 모두 구할 수 있다.
하지만 이를 실제로 구현할 때 끊임없이 확률 값을 곱하는 일은 피해야 한다. 컴퓨터가 0에 가까운 부동소수점(floating point number)을 제대로 처리하지 못하므로 언더플로 문제가 발생할 수 있기 때문이다. 이 때 log(ab)=log a + log b이고 exp(log x) =x라는 것을 이용해서 부동소수점 문제$(p_1 *…*p_n)$를 다음과 같이 계산할 수 있다.

  • exp = 지수함수, log의 역함수.

이제 스팸이나 스팸이 아닌 메시지에 단어 w_i가 들어갈 확률 $P(X_i|S)$와 $P(X_i|ㄱS)$를 추정하는 일이 남았다. 만약 이미 스팸과 스팸이 아닌 메시지로 분류된 학습 메시지가 충분하다면 $P(X_i|S)$를 추정할 수 있는 간단한 방법은 스팸 메시지 중 w_i가 포함된 메시지의 비율을 사용하는 것이다. 그런데 이 방법에는 큰 단점이 있다. 학습 데이터에 ‘고양이’라는 단어는 스팸이 아닌 메시지에만 포함된다고 상상해보자. 즉, P(“고양이”|S)=0이다. 그런 경우 나이브 베이즈 분류기는 고양이라는 단어가 들어간 메시지를 항상 스팸 메시지가 아니라고 예측한다. ‘심지어 고양이도 인정하는 값싼 비아그라’라는 메시지도 스팸이 아니라 예측할 것이다. 이런 문제 해결을 위해서는 smoothing(평활화)가 필요하다.
Smoothing을 위해선 가짜 빈도수(pseudocount) k를 결정하고 스팸 메시지에서 i번째 단어가 나올 확률을 다음처럼 추정할 수 있다.

$P(X_i|ㄱS)$도 비슷하게 계산할 수 있다. 확률을 계산할 때 해당 단어가 포함된 스팸과 포함되지 않은 스팸이 이미 k번씩 나왔다고 가정을 한다. 예컨데 ‘고양이’라는 단어가 98개의 스팸 문서에서 단 한 번도 등장하진 않았지만, k가 1이라면 $P(“고양이”|S)$는 1/100=0.01로 계산할 수 있다. k를 통해 ‘고양이’란 단어가 포함된 메시지가 스팸 메시지일 확률을 0이 아닌 수로 설정할 수 있다.

구현하기

분류기(Classifier)에 필요한 이론적 부분을 살펴봤다면 이제 직접 구현할 차례다. 먼저 분류기를 적용하고자 하는 언어에 맞게 메시지를 전처리하는 과정이 필요한데 이를 토큰화(Tokenize)라고 한다. 영어를 기준으로 모든 메시지를 소문자로 바꾸고 숫자, 문자, 아포스트로피(apostrophe, ‘)가 포함된 단어를 추출해보자. 그리고 중복되는 단어를 제거해보자

from collections import Counter, defaultdict
import glob, re, random, math

def tokenize(message):
    message = message.lower()
    allWords = re.findall("[a-z0-9']+", message)
    return set(allWords)

# 학습 데이:q터에서 빈도수를 세는 함수를 만들자.
# 모양은 이렇다. [스팸 메시지에 나온 빈도수, 스팸이 아닌 메시지에 나온 빈도수]
def countWords(trainingSet):
    """데이터는 메시지 내용, 스팸 여부 꼴로 구성되어 있음"""
    counts = defaultdict(lambda: [0, 0])
    for message, isSpam in trainingSet:
        for word in tokenize(message):
            counts[word][0 if isSpam else 1] += 1
    return counts

# smoothing을 통해 단어의 빈도수로 확률 값을 추정하자. > k는 가짜빈도.
# 스팸에서 단어가 나올 확률, 스팸이 아닌 메시지에서 단어가 나올 확률의 list를 반환한다.
def wordProbabilities(counts, totalSpams, totalNonSpams, k=0.5):
    """빈도수를 [단어, p(w|스팸), p(w|~스팸)] 형태로 변환"""
    return [(w,
            (spam+k)/(totalSpams+2*k),             # spam+0.5 / totalSpams+1
            (nonSpam+k)/(totalNonSpams+2*k))       # nonSpam+0.5 / totalNonSpams+1
            for w, (spam, nonSpam) in counts.items()]

# 단어가 나올 경우 / 나오지 않을 경우에 대응해 log 덧셈한 값을 return한다.
def spamProbabilities(wordProbs, message):
    messageWords = tokenize(message)
    logProbIfSpam = logProbIfNotSpam = 0.0

    # 모든 단어 반복
    for word, probIfSpam, probIfNotSpam in wordProbs:

        if word in messageWords:
            logProbIfSpam += math.log(probIfSpam)
            logProbIfNotSpam += math.log(probIfNotSpam)

        else:
            logProbIfSpam += math.log(1.0 - probIfSpam)
            logProbIfNotSpam += math.log(1.0 - probIfNotSpam)

    probIfSpam = math.exp(logProbIfSpam)
    probIfNotSpam = math.exp(logProbIfNotSpam)
    return probIfSpam / (probIfSpam+probIfNotSpam)

class naiveBayesClassifier:

    def __init__(self, k=0.5):
        self.k = k
        self.wordProbs = []

    def train(self, trainingSet):
        numSpams = len([isSpam for message, isSpam in trainingSet if isSpam])
        numNonSpams = len(trainingSet) - numSpams

        wordCounts = countWords(trainingSet)
        self.wordProbs = wordProbabilities(wordCounts, numSpams, numNonSpams, self.k)

    def classify(self, message):
        return spamProbabilities(self.wordProbs, message)

path = r"C:\path\to\spam\*\*"

def pSpamGivenWord(wordProb):
    word, probIfSpam, probIfNotSpam = wordProb
    return probIfSpam / (probIfSpam + probIfNotSpam)

def getSubjectData(path):
    data = []
    subjectRegex = re.compile(r"Subject:\s+")
    for fn in glob.glob(path):
        isSpam = "ham" not in fn

        with open(fn,'r',encoding='ISO-8859-1') as file:
                for line in file:
                    if line.startswith("Subject:"):
                        subject = subjectRegex.sub("", line).strip()
                        data.append((subject, isSpam))       
    return data              

def splitData(data, prob):
    """split data into fractions [prob, 1 - prob]"""
    results = [], []
    for row in data:
        results[0 if random.random() < prob else 1].append(row)
    return results

def trainAndTest(path):
    data = getSubjectData(path)
    random.seed(0)
    trainData, testData = splitData(data, 0.75)

    classifier = naiveBayesClassifier()
    classifier.train(trainData)

    classified = [(subject, isSpam, classifier.classify(subject))
                   for subject, isSpam in testData]

    counts = Counter((isSpam, spamProbability > 0.5)
                      for _, isSpam, spamProbability in classified)

    print(counts)
    classified.sort(key=lambda row: row[2])
    spammiestHams = list(filter(lambda row: not row[1], classified))[-5:]
    hammiestSpams = list(filter(lambda row: row[1], classified))[:5]

    print("\n")
    print("Spammiest Hams :",spammiestHams)
    print("\n")
    print("Hammiest Spams :",hammiestSpams)

    words = sorted(classifier.wordProbs, key=pSpamGivenWord)
    spammiestWords = words[-5:]
    hammiestWords = words[:5]

    print("\nSpammiest Words :",spammiestWords)
    print("\nHammiest Words :",hammiestWords)

trainAndTest(path) # 가음성 704, 진양성 101, 가양성 38, 진음성 33
>>>
Counter({(False, False): 704, (True, True): 101, (True, False): 38, (False, True): 33})


Spammiest Hams : [('Attn programmers: support offered [FLOSS-Sarai Initiative]', False, 0.9756020492343607), ('2000+ year old Greek computer reinterpreted', False, 0.9835280743958158), ('What to look for in your next smart phone (Tech Update)', False, 0.9898673229124078), ('[ILUG-Social] Re: Important - reenactor insurance needed', False, 0.9995346925825354), ('[ILUG-Social] Re: Important - reenactor insurance needed', False, 0.9995346925825354)]


Hammiest Spams : [('Re: girls', True, 0.0009520823968456288), ('Introducing Chase Platinum for Students with a 0% Introductory APR', True, 0.001256093787333424), ('.Message report from your contact page....//ytu855 rkq', True, 0.0015102442616159145), ('Testing a system, please delete', True, 0.0026908231662706177), ('Never pay for the goodz again (8SimUgQ)', True, 0.005908929345307938)]

Spammiest Words : [('year', 0.028767123287671233, 0.00022893772893772894), ('sale', 0.031506849315068496, 0.00022893772893772894), ('rates', 0.031506849315068496, 0.00022893772893772894), ('systemworks', 0.036986301369863014, 0.00022893772893772894), ('money', 0.03972602739726028, 0.00022893772893772894)]

Hammiest Words : [('spambayes', 0.0013698630136986301, 0.04601648351648352), ('users', 0.0013698630136986301, 0.036401098901098904), ('razor', 0.0013698630136986301, 0.030906593406593408), ('zzzzteana', 0.0013698630136986301, 0.029075091575091576), ('sadev', 0.0013698630136986301, 0.026785714285714284)]

# 정밀도
print(101/(101+33))
# 재현율
print(101/(101+38))
>>
0.753731343283582
0.7266187050359713

한 발짝 더

나이브 분류기의 성능을 끌어 올리는 데에는 몇 가지 방법이 있다.

  • 학습 데이터를 대량으로 늘린다.
  • 메시지의 제목(Subject :) 부분 이외의 내용 또한 학습시킬 수 있도록 분류기를 개선한다.
  • 학습 데이터가 최소 빈도수만큼 등장하지 않으면 무시할 수 있도록 설정한다.
  • 이어진 한 단어를 필요에 맞게 잘라서 저장하는 걸 고려할 수 있다. (영문의 경우 cheapest를 cheap으로) 이를 stemmer라고 한다. 또 나이브 분류기에 한국어를 학습시키고 분류하는 방법으로 KoNLPy 라이브러리를 이용한 형태소 분석이 있다.