From afff4010fcba4f42dfd71a3a7b5a6ca5be4736bf Mon Sep 17 00:00:00 2001 From: Aurelien Geron Date: Wed, 4 Oct 2017 13:43:43 +0200 Subject: [PATCH] Add spam classifier exercise solution in chapter 3 --- 03_classification.ipynb | 988 +++++++++++++++++++++++++++++++++++++++- 1 file changed, 982 insertions(+), 6 deletions(-) diff --git a/03_classification.ipynb b/03_classification.ipynb index 40dd82b..e8ef581 100644 --- a/03_classification.ipynb +++ b/03_classification.ipynb @@ -749,7 +749,9 @@ { "cell_type": "code", "execution_count": 34, - "metadata": {}, + "metadata": { + "collapsed": true + }, "outputs": [], "source": [ "y_scores = cross_val_predict(sgd_clf, X_train, y_train_5, cv=3,\n", @@ -3688,19 +3690,993 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "## 4. Spam classifier\n", - "\n", - "Coming soon..." + "## 4. Spam classifier" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "First, let's fetch the data:" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 126, "metadata": { "collapsed": true }, "outputs": [], - "source": [] + "source": [ + "import os\n", + "import tarfile\n", + "from six.moves import urllib\n", + "\n", + "DOWNLOAD_ROOT = \"http://spamassassin.apache.org/old/publiccorpus/\"\n", + "HAM_URL = DOWNLOAD_ROOT + \"20030228_easy_ham.tar.bz2\"\n", + "SPAM_URL = DOWNLOAD_ROOT + \"20030228_spam.tar.bz2\"\n", + "SPAM_PATH = os.path.join(\"datasets\", \"spam\")\n", + "\n", + "def fetch_spam_data(spam_url=SPAM_URL, spam_path=SPAM_PATH):\n", + " if not os.path.isdir(spam_path):\n", + " os.makedirs(spam_path)\n", + " for filename, url in ((\"ham.tar.bz2\", HAM_URL), (\"spam.tar.bz2\", SPAM_URL)):\n", + " path = os.path.join(spam_path, filename)\n", + " if not os.path.isfile(path):\n", + " urllib.request.urlretrieve(url, path)\n", + " tar_bz2_file = tarfile.open(path)\n", + " tar_bz2_file.extractall(path=SPAM_PATH)\n", + " tar_bz2_file.close()" + ] + }, + { + "cell_type": "code", + "execution_count": 127, + "metadata": { + "collapsed": true + }, + "outputs": [], + "source": [ + "fetch_spam_data()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Next, let's load all the emails:" + ] + }, + { + "cell_type": "code", + "execution_count": 128, + "metadata": { + "collapsed": true + }, + "outputs": [], + "source": [ + "HAM_DIR = os.path.join(SPAM_PATH, \"easy_ham\")\n", + "SPAM_DIR = os.path.join(SPAM_PATH, \"spam\")\n", + "ham_filenames = [name for name in sorted(os.listdir(HAM_DIR)) if len(name) > 20]\n", + "spam_filenames = [name for name in sorted(os.listdir(SPAM_DIR)) if len(name) > 20]" + ] + }, + { + "cell_type": "code", + "execution_count": 129, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "2500" + ] + }, + "execution_count": 129, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "len(ham_filenames)" + ] + }, + { + "cell_type": "code", + "execution_count": 130, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "500" + ] + }, + "execution_count": 130, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "len(spam_filenames)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We can use Python's `email` module to parse these emails (this handles headers, encoding, and so on):" + ] + }, + { + "cell_type": "code", + "execution_count": 131, + "metadata": { + "collapsed": true + }, + "outputs": [], + "source": [ + "import email\n", + "import email.policy\n", + "\n", + "def load_email(is_spam, filename, spam_path=SPAM_PATH):\n", + " directory = \"spam\" if is_spam else \"easy_ham\"\n", + " with open(os.path.join(spam_path, directory, filename), \"rb\") as f:\n", + " return email.parser.BytesParser(policy=email.policy.default).parse(f)" + ] + }, + { + "cell_type": "code", + "execution_count": 132, + "metadata": { + "collapsed": true + }, + "outputs": [], + "source": [ + "ham_emails = [load_email(is_spam=False, filename=name) for name in ham_filenames]\n", + "spam_emails = [load_email(is_spam=True, filename=name) for name in spam_filenames]" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Let's look at one example of ham and one example of spam, to get a feel of what the data looks like:" + ] + }, + { + "cell_type": "code", + "execution_count": 133, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Martin A posted:\n", + "Tassos Papadopoulos, the Greek sculptor behind the plan, judged that the\n", + " limestone of Mount Kerdylio, 70 miles east of Salonika and not far from the\n", + " Mount Athos monastic community, was ideal for the patriotic sculpture. \n", + " \n", + " As well as Alexander's granite features, 240 ft high and 170 ft wide, a\n", + " museum, a restored amphitheatre and car park for admiring crowds are\n", + "planned\n", + "---------------------\n", + "So is this mountain limestone or granite?\n", + "If it's limestone, it'll weather pretty fast.\n", + "\n", + "------------------------ Yahoo! Groups Sponsor ---------------------~-->\n", + "4 DVDs Free +s&p Join Now\n", + "http://us.click.yahoo.com/pt6YBB/NXiEAA/mG3HAA/7gSolB/TM\n", + "---------------------------------------------------------------------~->\n", + "\n", + "To unsubscribe from this group, send an email to:\n", + "forteana-unsubscribe@egroups.com\n", + "\n", + " \n", + "\n", + "Your use of Yahoo! Groups is subject to http://docs.yahoo.com/info/terms/\n" + ] + } + ], + "source": [ + "print(ham_emails[1].get_content().strip())" + ] + }, + { + "cell_type": "code", + "execution_count": 134, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Help wanted. We are a 14 year old fortune 500 company, that is\n", + "growing at a tremendous rate. We are looking for individuals who\n", + "want to work from home.\n", + "\n", + "This is an opportunity to make an excellent income. No experience\n", + "is required. We will train you.\n", + "\n", + "So if you are looking to be employed from home with a career that has\n", + "vast opportunities, then go:\n", + "\n", + "http://www.basetel.com/wealthnow\n", + "\n", + "We are looking for energetic and self motivated people. If that is you\n", + "than click on the link and fill out the form, and one of our\n", + "employement specialist will contact you.\n", + "\n", + "To be removed from our link simple go to:\n", + "\n", + "http://www.basetel.com/remove.html\n", + "\n", + "\n", + "4139vOLW7-758DoDY1425FRhM1-764SMFc8513fCsLl40\n" + ] + } + ], + "source": [ + "print(spam_emails[6].get_content().strip())" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Some emails are actually multipart, with images and attachments (which can have their own attachments). Let's look at the various types of structures we have:" + ] + }, + { + "cell_type": "code", + "execution_count": 135, + "metadata": { + "collapsed": true + }, + "outputs": [], + "source": [ + "def get_email_structure(email):\n", + " if isinstance(email, str):\n", + " return email\n", + " payload = email.get_payload()\n", + " if isinstance(payload, list):\n", + " return \"multipart({})\".format(\", \".join([\n", + " get_email_structure(sub_email)\n", + " for sub_email in payload\n", + " ]))\n", + " else:\n", + " return email.get_content_type()" + ] + }, + { + "cell_type": "code", + "execution_count": 136, + "metadata": { + "collapsed": true + }, + "outputs": [], + "source": [ + "from collections import Counter\n", + "\n", + "def structures_counter(emails):\n", + " structures = Counter()\n", + " for email in emails:\n", + " structure = get_email_structure(email)\n", + " structures[structure] += 1\n", + " return structures" + ] + }, + { + "cell_type": "code", + "execution_count": 137, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "[('text/plain', 2408),\n", + " ('multipart(text/plain, application/pgp-signature)', 66),\n", + " ('multipart(text/plain, text/html)', 8),\n", + " ('multipart(text/plain, text/plain)', 4),\n", + " ('multipart(text/plain)', 3),\n", + " ('multipart(text/plain, application/octet-stream)', 2),\n", + " ('multipart(text/plain, text/enriched)', 1),\n", + " ('multipart(text/plain, video/mng)', 1),\n", + " ('multipart(text/plain, application/x-pkcs7-signature)', 1),\n", + " ('multipart(text/plain, multipart(text/plain))', 1),\n", + " ('multipart(multipart(text/plain, text/plain, text/plain), application/pgp-signature)',\n", + " 1),\n", + " ('multipart(text/plain, application/ms-tnef, text/plain)', 1),\n", + " ('multipart(text/plain, application/x-java-applet)', 1),\n", + " ('multipart(text/plain, multipart(text/plain, text/plain), text/rfc822-headers)',\n", + " 1),\n", + " ('multipart(text/plain, multipart(text/plain, text/plain), multipart(multipart(text/plain, application/x-pkcs7-signature)))',\n", + " 1)]" + ] + }, + "execution_count": 137, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "structures_counter(ham_emails).most_common()" + ] + }, + { + "cell_type": "code", + "execution_count": 138, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "[('text/plain', 218),\n", + " ('text/html', 183),\n", + " ('multipart(text/plain, text/html)', 45),\n", + " ('multipart(text/html)', 20),\n", + " ('multipart(text/plain)', 19),\n", + " ('multipart(multipart(text/html))', 5),\n", + " ('multipart(text/plain, image/jpeg)', 3),\n", + " ('multipart(text/html, application/octet-stream)', 2),\n", + " ('multipart(multipart(text/plain, text/html), image/gif)', 1),\n", + " ('multipart(multipart(text/html), application/octet-stream, image/jpeg)', 1),\n", + " ('multipart/alternative', 1),\n", + " ('multipart(text/html, text/plain)', 1),\n", + " ('multipart(text/plain, application/octet-stream)', 1)]" + ] + }, + "execution_count": 138, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "structures_counter(spam_emails).most_common()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "It seems that the ham emails are more often plain text, while spam has quite a lot of HTML. Moreover, quite a few ham emails are signed using PGP, while no spam is. In short, it seems that the email structure is a usual information to have." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Now let's take a look at the email headers:" + ] + }, + { + "cell_type": "code", + "execution_count": 139, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Return-Path : <12a1mailbot1@web.de>\n", + "Delivered-To : zzzz@localhost.spamassassin.taint.org\n", + "Received : from localhost (localhost [127.0.0.1])\tby phobos.labs.spamassassin.taint.org (Postfix) with ESMTP id 136B943C32\tfor ; Thu, 22 Aug 2002 08:17:21 -0400 (EDT)\n", + "Received : from mail.webnote.net [193.120.211.219]\tby localhost with POP3 (fetchmail-5.9.0)\tfor zzzz@localhost (single-drop); Thu, 22 Aug 2002 13:17:21 +0100 (IST)\n", + "Received : from dd_it7 ([210.97.77.167])\tby webnote.net (8.9.3/8.9.3) with ESMTP id NAA04623\tfor ; Thu, 22 Aug 2002 13:09:41 +0100\n", + "From : 12a1mailbot1@web.de\n", + "Received : from r-smtp.korea.com - 203.122.2.197 by dd_it7 with Microsoft SMTPSVC(5.5.1775.675.6);\t Sat, 24 Aug 2002 09:42:10 +0900\n", + "To : dcek1a1@netsgo.com\n", + "Subject : Life Insurance - Why Pay More?\n", + "Date : Wed, 21 Aug 2002 20:31:57 -1600\n", + "MIME-Version : 1.0\n", + "Message-ID : <0103c1042001882DD_IT7@dd_it7>\n", + "Content-Type : text/html; charset=\"iso-8859-1\"\n", + "Content-Transfer-Encoding : quoted-printable\n" + ] + } + ], + "source": [ + "for header, value in spam_emails[0].items():\n", + " print(header,\":\",value)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "There's probably a lot of useful information in there, such as the sender's email address (12a1mailbot1@web.de looks fishy), but we will just focus on the `Subject` header:" + ] + }, + { + "cell_type": "code", + "execution_count": 140, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "'Life Insurance - Why Pay More?'" + ] + }, + "execution_count": 140, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "spam_emails[0][\"Subject\"]" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Okay, before we learn too much about the data, let's not forget to split it into a training set and a test set:" + ] + }, + { + "cell_type": "code", + "execution_count": 141, + "metadata": { + "collapsed": true + }, + "outputs": [], + "source": [ + "import numpy as np\n", + "from sklearn.model_selection import train_test_split\n", + "\n", + "X = np.array(ham_emails + spam_emails)\n", + "y = np.array([0] * len(ham_emails) + [1] * len(spam_emails))\n", + "\n", + "X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Okay, let's start writing the preprocessing functions. First, we will need a function to convert HTML to plain text. Arguably the best way to do this would be to use the great [BeautifulSoup](https://www.crummy.com/software/BeautifulSoup/) library, but I would like to avoid adding another dependency to this project, so let's hack a quick & dirty solution using regular expressions (at the risk of [un̨ho͞ly radiańcé destro҉ying all enli̍̈́̂̈́ghtenment](https://stackoverflow.com/a/1732454/38626)). The following function first drops the `` section, then converts all `` tags to the word HYPERLINK, then it gets rid of all HTML tags, leaving only the plain text. For readability, it also replaces multiple newlines with single newlines, and finally it unescapes html entities (such as `>` or ` `):" + ] + }, + { + "cell_type": "code", + "execution_count": 142, + "metadata": { + "collapsed": true + }, + "outputs": [], + "source": [ + "import re\n", + "from html import unescape\n", + "\n", + "def html_to_plain_text(html):\n", + " text = re.sub('.*?', '', html, flags=re.M | re.S | re.I)\n", + " text = re.sub('', ' HYPERLINK ', text, flags=re.M | re.S | re.I)\n", + " text = re.sub('<.*?>', '', text, flags=re.M | re.S)\n", + " text = re.sub(r'(\\s*\\n)+', '\\n', text, flags=re.M | re.S)\n", + " return unescape(text)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Let's see if it works. This is HTML spam:" + ] + }, + { + "cell_type": "code", + "execution_count": 143, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "
\n", + "\n", + "OTC
\n", + "\n", + " Newsletter
\n", + "Discover Tomorrow's Winners 
comput\n", + "Computation => comput\n", + "Computing => comput\n", + "Computed => comput\n", + "Compute => comput\n", + "Compulsive => compuls\n" + ] + } + ], + "source": [ + "try:\n", + " import nltk\n", + "\n", + " stemmer = nltk.PorterStemmer()\n", + " for word in (\"Computations\", \"Computation\", \"Computing\", \"Computed\", \"Compute\", \"Compulsive\"):\n", + " print(word, \"=>\", stemmer.stem(word))\n", + "except ImportError:\n", + " print(\"Error: stemming requires the NLTK module.\")\n", + " stemmer = None" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We will also need a way to replace URLs with the word \"URL\". For this, we could use hard core [regular expressions](https://mathiasbynens.be/demo/url-regex) but we will just use the [urlextract](https://github.com/lipoja/URLExtract) library. You can install it with the following command (don't forget to activate your virtualenv first; if you don't have one, you will likely need administrator rights, or use the `--user` option):\n", + "\n", + "`$ pip install urlextract`" + ] + }, + { + "cell_type": "code", + "execution_count": 148, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "['github.com', 'https://youtu.be/7Pq-S557XQU?t=3m32s']\n" + ] + } + ], + "source": [ + "try:\n", + " import urlextract # may require an Internet connection to download root domain names\n", + " \n", + " url_extractor = urlextract.URLExtract()\n", + " print(url_extractor.find_urls(\"Will it detect github.com and https://youtu.be/7Pq-S557XQU?t=3m32s\"))\n", + "except ImportError:\n", + " print(\"Error: replacing URLs requires the urlextract module.\")\n", + " url_extractor = None" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We are ready to put all this together into a transformer that we will use to convert emails to word counters. Note that we split sentences into words using Python's `split()` method, which uses whitespaces for word boundaries. This works for many written languages, but not all. For example, Chinese and Japanese scripts generally don't use spaces between words, and Vietnamese often uses spaces even between syllables. It's okay in this exercise, because the dataset is (mostly) in English." + ] + }, + { + "cell_type": "code", + "execution_count": 149, + "metadata": { + "collapsed": true + }, + "outputs": [], + "source": [ + "from sklearn.base import BaseEstimator, TransformerMixin\n", + "\n", + "class EmailToWordCounterTransformer(BaseEstimator, TransformerMixin):\n", + " def __init__(self, strip_headers=True, lower_case=True, remove_punctuation=True,\n", + " replace_urls=True, replace_numbers=True, stemming=True):\n", + " self.strip_headers = strip_headers\n", + " self.lower_case = lower_case\n", + " self.remove_punctuation = remove_punctuation\n", + " self.replace_urls = replace_urls\n", + " self.replace_numbers = replace_numbers\n", + " self.stemming = stemming\n", + " def fit(self, X, y=None):\n", + " return self\n", + " def transform(self, X, y=None):\n", + " X_transformed = []\n", + " for email in X:\n", + " text = email_to_text(email) or \"\"\n", + " if self.lower_case:\n", + " text = text.lower()\n", + " if self.replace_urls and url_extractor is not None:\n", + " urls = list(set(url_extractor.find_urls(text)))\n", + " urls.sort(key=lambda url: len(url), reverse=True)\n", + " for url in urls:\n", + " text = text.replace(url, \" URL \")\n", + " if self.replace_numbers:\n", + " text = re.sub(r'\\d+(?:\\.\\d*(?:[eE]\\d+))?', 'NUMBER', text)\n", + " if self.remove_punctuation:\n", + " text = re.sub(r'\\W+', ' ', text, flags=re.M)\n", + " word_counts = Counter(text.split())\n", + " if self.stemming and stemmer is not None:\n", + " stemmed_word_counts = Counter()\n", + " for word, count in word_counts.items():\n", + " stemmed_word = stemmer.stem(word)\n", + " stemmed_word_counts[stemmed_word] += count\n", + " word_counts = stemmed_word_counts\n", + " X_transformed.append(word_counts)\n", + " return np.array(X_transformed)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Let's try this transformer on a few emails:" + ] + }, + { + "cell_type": "code", + "execution_count": 150, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "array([ Counter({'wrote': 1, 'r': 1, 'chuck': 1, 'stuff': 1, 'murcko': 1, 'yawn': 1}),\n", + " Counter({'the': 11, 'of': 9, 'and': 8, 'christian': 3, 'all': 3, 'by': 3, 'to': 3, 'superstit': 2, 'been': 2, 'teach': 2, 'jesu': 2, 'on': 2, 'have': 2, 'one': 2, 'rogueri': 2, 'i': 2, 'jefferson': 2, 'half': 2, 'ha': 2, 'e': 1, 'effect': 1, 'mytholog': 1, 'again': 1, 'what': 1, 'most': 1, 'william': 1, 'perpetr': 1, 'do': 1, 'world': 1, 'women': 1, 'first': 1, 'thi': 1, 'redeem': 1, 'url': 1, 'letter': 1, 'ever': 1, 'particular': 1, 'pervert': 1, 'interest': 1, 'histor': 1, 'sinc': 1, 'found': 1, 'find': 1, 'becom': 1, 'hypocrit': 1, 'larg': 1, 'featur': 1, 'some': 1, 'great': 1, 'known': 1, 'fabl': 1, 'fine': 1, 'examin': 1, 'children': 1, 'imprison': 1, 'other': 1, 'introduct': 1, 'john': 1, 'make': 1, 'led': 1, 'absurd': 1, 'quot': 1, 'coercion': 1, 'american': 1, 'million': 1, 'they': 1, 'earth': 1, 'burnt': 1, 'not': 1, 'paul': 1, 'short': 1, 'remsburg': 1, 'man': 1, 'support': 1, 'alik': 1, 'were': 1, 'word': 1, 'shone': 1, 'our': 1, 'thoma': 1, 'corrupt': 1, 'dupe': 1, 'that': 1, 'system': 1, 'a': 1, 'band': 1, 'upon': 1, 'innoc': 1, 'error': 1, 'men': 1, 'are': 1, 'fool': 1, 'tortur': 1, 'six': 1, 'import': 1, 'over': 1, 'untruth': 1, 'in': 1}),\n", + " Counter({'url': 5, 's': 3, 'group': 3, 'to': 3, 'martin': 2, 'we': 2, 'forteana': 2, 'is': 2, 'and': 2, 'an': 2, 'unsubscrib': 2, 'yahoo': 2, 'in': 2, 'join': 1, 'career': 1, 'hamza': 1, 'yemen': 1, 'email': 1, 'muslim': 1, 'should': 1, 'rather': 1, 'wrote': 1, 'more': 1, 'thi': 1, 'includ': 1, 'belief': 1, 'factual': 1, 'altern': 1, 'know': 1, 'html': 1, 'rundown': 1, 'hi': 1, 'use': 1, 'send': 1, 'dvd': 1, 'free': 1, 'all': 1, 'for': 1, 'non': 1, 'that': 1, 'your': 1, 'on': 1, 'sponsor': 1, 'y': 1, 'now': 1, 'adamson': 1, 'p': 1, 'from': 1, 'of': 1, 'number': 1, 'rob': 1, 'murder': 1, 'memri': 1, 'outright': 1, 'how': 1, 'subject': 1, 'be': 1, 'unbias': 1, 'base': 1, 't': 1, 'don': 1})], dtype=object)" + ] + }, + "execution_count": 150, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "X_few = X_train[:3]\n", + "X_few_wordcounts = EmailToWordCounterTransformer().fit_transform(X_few)\n", + "X_few_wordcounts" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "This looks about right!" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Now we have the word counts, and we need to convert them to vectors. For this, we will build another transformer whose `fit()` method will build the vocabulary (an ordered list of the most common words) and whose `transform()` method will use the vocabulary to convert word counts to vectors. The output is a sparse matrix." + ] + }, + { + "cell_type": "code", + "execution_count": 151, + "metadata": { + "collapsed": true + }, + "outputs": [], + "source": [ + "from scipy.sparse import csr_matrix\n", + "\n", + "class WordCounterToVectorTransformer(BaseEstimator, TransformerMixin):\n", + " def __init__(self, vocabulary_size=1000):\n", + " self.vocabulary_size = vocabulary_size\n", + " def fit(self, X, y=None):\n", + " total_count = Counter()\n", + " for word_count in X:\n", + " for word, count in word_count.items():\n", + " total_count[word] += min(count, 10)\n", + " most_common = total_count.most_common()[:self.vocabulary_size]\n", + " self.most_common_ = most_common\n", + " self.vocabulary_ = {word: index + 1 for index, (word, count) in enumerate(most_common)}\n", + " return self\n", + " def transform(self, X, y=None):\n", + " rows = []\n", + " cols = []\n", + " data = []\n", + " for row, word_count in enumerate(X):\n", + " for word, count in word_count.items():\n", + " rows.append(row)\n", + " cols.append(self.vocabulary_.get(word, 0))\n", + " data.append(count)\n", + " return csr_matrix((data, (rows, cols)), shape=(len(X), self.vocabulary_size + 1))" + ] + }, + { + "cell_type": "code", + "execution_count": 152, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "<3x11 sparse matrix of type ''\n", + "\twith 19 stored elements in Compressed Sparse Row format>" + ] + }, + "execution_count": 152, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "vocab_transformer = WordCounterToVectorTransformer(vocabulary_size=10)\n", + "X_few_vectors = vocab_transformer.fit_transform(X_few_wordcounts)\n", + "X_few_vectors" + ] + }, + { + "cell_type": "code", + "execution_count": 153, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "array([[ 6, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n", + " [100, 9, 8, 11, 1, 3, 3, 3, 2, 0, 3],\n", + " [ 64, 1, 2, 0, 5, 3, 1, 0, 1, 3, 0]], dtype=int64)" + ] + }, + "execution_count": 153, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "X_few_vectors.toarray()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "What does this matrix mean? Well, the 65 in the third row, first column, means that the third email contains 65 words that are not part of the vocabulary. The 0 next to it means that the first word in the vocabulary is not present in this email. The 1 next to it means that the second word is present once, and so on. You can look at the vocabulary to know which words we are talking about. The first word is \"the\", the second word is \"of\", etc." + ] + }, + { + "cell_type": "code", + "execution_count": 154, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{'all': 6,\n", + " 'and': 2,\n", + " 'by': 10,\n", + " 'christian': 7,\n", + " 'of': 1,\n", + " 'on': 8,\n", + " 's': 9,\n", + " 'the': 3,\n", + " 'to': 5,\n", + " 'url': 4}" + ] + }, + "execution_count": 154, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "vocab_transformer.vocabulary_" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We are now ready to train our first spam classifier! Let's transform the whole dataset:" + ] + }, + { + "cell_type": "code", + "execution_count": 155, + "metadata": { + "collapsed": true + }, + "outputs": [], + "source": [ + "from sklearn.pipeline import Pipeline\n", + "\n", + "preprocess_pipeline = Pipeline([\n", + " (\"email_to_wordcount\", EmailToWordCounterTransformer()),\n", + " (\"wordcount_to_vector\", WordCounterToVectorTransformer()),\n", + "])\n", + "\n", + "X_train_transformed = preprocess_pipeline.fit_transform(X_train)" + ] + }, + { + "cell_type": "code", + "execution_count": 156, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[CV] ................................................................\n", + "[CV] .................................. , score=0.98375, total= 0.0s\n", + "[CV] ................................................................\n", + "[CV] .................................... , score=0.985, total= 0.1s\n", + "[CV] ................................................................\n", + "[CV] ................................... , score=0.9925, total= 0.1s\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "[Parallel(n_jobs=1)]: Done 1 out of 1 | elapsed: 0.0s remaining: 0.0s\n", + "[Parallel(n_jobs=1)]: Done 2 out of 2 | elapsed: 0.1s remaining: 0.0s\n", + "[Parallel(n_jobs=1)]: Done 3 out of 3 | elapsed: 0.2s finished\n" + ] + }, + { + "data": { + "text/plain": [ + "0.98708333333333342" + ] + }, + "execution_count": 156, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "from sklearn.linear_model import LogisticRegression\n", + "from sklearn.model_selection import cross_val_score\n", + "\n", + "log_clf = LogisticRegression()\n", + "score = cross_val_score(log_clf, X_train_transformed, y_train, cv=3, verbose=3)\n", + "score.mean()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Over 98.7%, not bad for a first try! :) However, remember that we are using the \"easy\" dataset. You can try with the harder datasets, the results won't be so amazing. You would have to try multiple models, select the best ones and fine-tune them using cross-validation, and so on.\n", + "\n", + "But you get the picture, so let's stop now, and just print out the precision/recall we get on the test set:" + ] + }, + { + "cell_type": "code", + "execution_count": 157, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Precision: 0.95%\n", + "Recall: 0.98%\n" + ] + } + ], + "source": [ + "from sklearn.metrics import precision_score, recall_score\n", + "\n", + "X_test_transformed = preprocess_pipeline.transform(X_test)\n", + "\n", + "log_clf = LogisticRegression()\n", + "log_clf.fit(X_train_transformed, y_train)\n", + "\n", + "y_pred = log_clf.predict(X_test_transformed)\n", + "\n", + "print(\"Precision: {:.2f}%\".format(precision_score(y_test, y_pred)))\n", + "print(\"Recall: {:.2f}%\".format(recall_score(y_test, y_pred)))" + ] } ], "metadata": {