diff --git a/03_classification.ipynb b/03_classification.ipynb index ae2ee96..ea39159 100644 --- a/03_classification.ipynb +++ b/03_classification.ipynb @@ -460,7 +460,9 @@ { "cell_type": "code", "execution_count": 34, - "metadata": {}, + "metadata": { + "collapsed": true + }, "outputs": [], "source": [ "y_scores = cross_val_predict(sgd_clf, X_train, y_train_5, cv=3,\n", @@ -2061,19 +2063,660 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "## 4. Spam classifier\n", - "\n", - "Coming soon..." + "## 4. Spam classifier" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "First, let's fetch the data:" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 126, "metadata": { "collapsed": true }, "outputs": [], - "source": [] + "source": [ + "import os\n", + "import tarfile\n", + "from six.moves import urllib\n", + "\n", + "DOWNLOAD_ROOT = \"http://spamassassin.apache.org/old/publiccorpus/\"\n", + "HAM_URL = DOWNLOAD_ROOT + \"20030228_easy_ham.tar.bz2\"\n", + "SPAM_URL = DOWNLOAD_ROOT + \"20030228_spam.tar.bz2\"\n", + "SPAM_PATH = os.path.join(\"datasets\", \"spam\")\n", + "\n", + "def fetch_spam_data(spam_url=SPAM_URL, spam_path=SPAM_PATH):\n", + " if not os.path.isdir(spam_path):\n", + " os.makedirs(spam_path)\n", + " for filename, url in ((\"ham.tar.bz2\", HAM_URL), (\"spam.tar.bz2\", SPAM_URL)):\n", + " path = os.path.join(spam_path, filename)\n", + " if not os.path.isfile(path):\n", + " urllib.request.urlretrieve(url, path)\n", + " tar_bz2_file = tarfile.open(path)\n", + " tar_bz2_file.extractall(path=SPAM_PATH)\n", + " tar_bz2_file.close()" + ] + }, + { + "cell_type": "code", + "execution_count": 127, + "metadata": { + "collapsed": true + }, + "outputs": [], + "source": [ + "fetch_spam_data()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Next, let's load all the emails:" + ] + }, + { + "cell_type": "code", + "execution_count": 128, + "metadata": { + "collapsed": true + }, + "outputs": [], + "source": [ + "HAM_DIR = os.path.join(SPAM_PATH, \"easy_ham\")\n", + "SPAM_DIR = os.path.join(SPAM_PATH, \"spam\")\n", + "ham_filenames = [name for name in sorted(os.listdir(HAM_DIR)) if len(name) > 20]\n", + "spam_filenames = [name for name in sorted(os.listdir(SPAM_DIR)) if len(name) > 20]" + ] + }, + { + "cell_type": "code", + "execution_count": 129, + "metadata": {}, + "outputs": [], + "source": [ + "len(ham_filenames)" + ] + }, + { + "cell_type": "code", + "execution_count": 130, + "metadata": {}, + "outputs": [], + "source": [ + "len(spam_filenames)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We can use Python's `email` module to parse these emails (this handles headers, encoding, and so on):" + ] + }, + { + "cell_type": "code", + "execution_count": 131, + "metadata": { + "collapsed": true + }, + "outputs": [], + "source": [ + "import email\n", + "import email.policy\n", + "\n", + "def load_email(is_spam, filename, spam_path=SPAM_PATH):\n", + " directory = \"spam\" if is_spam else \"easy_ham\"\n", + " with open(os.path.join(spam_path, directory, filename), \"rb\") as f:\n", + " return email.parser.BytesParser(policy=email.policy.default).parse(f)" + ] + }, + { + "cell_type": "code", + "execution_count": 132, + "metadata": { + "collapsed": true + }, + "outputs": [], + "source": [ + "ham_emails = [load_email(is_spam=False, filename=name) for name in ham_filenames]\n", + "spam_emails = [load_email(is_spam=True, filename=name) for name in spam_filenames]" + ] + }, + { + "cell_type": 
"markdown", + "metadata": {}, + "source": [ + "Let's look at one example of ham and one example of spam, to get a feel of what the data looks like:" + ] + }, + { + "cell_type": "code", + "execution_count": 133, + "metadata": {}, + "outputs": [], + "source": [ + "print(ham_emails[1].get_content().strip())" + ] + }, + { + "cell_type": "code", + "execution_count": 134, + "metadata": {}, + "outputs": [], + "source": [ + "print(spam_emails[6].get_content().strip())" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Some emails are actually multipart, with images and attachments (which can have their own attachments). Let's look at the various types of structures we have:" + ] + }, + { + "cell_type": "code", + "execution_count": 135, + "metadata": { + "collapsed": true + }, + "outputs": [], + "source": [ + "def get_email_structure(email):\n", + " if isinstance(email, str):\n", + " return email\n", + " payload = email.get_payload()\n", + " if isinstance(payload, list):\n", + " return \"multipart({})\".format(\", \".join([\n", + " get_email_structure(sub_email)\n", + " for sub_email in payload\n", + " ]))\n", + " else:\n", + " return email.get_content_type()" + ] + }, + { + "cell_type": "code", + "execution_count": 136, + "metadata": { + "collapsed": true + }, + "outputs": [], + "source": [ + "from collections import Counter\n", + "\n", + "def structures_counter(emails):\n", + " structures = Counter()\n", + " for email in emails:\n", + " structure = get_email_structure(email)\n", + " structures[structure] += 1\n", + " return structures" + ] + }, + { + "cell_type": "code", + "execution_count": 137, + "metadata": {}, + "outputs": [], + "source": [ + "structures_counter(ham_emails).most_common()" + ] + }, + { + "cell_type": "code", + "execution_count": 138, + "metadata": {}, + "outputs": [], + "source": [ + "structures_counter(spam_emails).most_common()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "It seems that the ham emails are more often plain text, while spam has quite a lot of HTML. Moreover, quite a few ham emails are signed using PGP, while no spam is. In short, it seems that the email structure is a usual information to have." 
+ ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Now let's take a look at the email headers:" + ] + }, + { + "cell_type": "code", + "execution_count": 139, + "metadata": {}, + "outputs": [], + "source": [ + "for header, value in spam_emails[0].items():\n", + " print(header,\":\",value)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "There's probably a lot of useful information in there, such as the sender's email address (12a1mailbot1@web.de looks fishy), but we will just focus on the `Subject` header:" + ] + }, + { + "cell_type": "code", + "execution_count": 140, + "metadata": {}, + "outputs": [], + "source": [ + "spam_emails[0][\"Subject\"]" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Okay, before we learn too much about the data, let's not forget to split it into a training set and a test set:" + ] + }, + { + "cell_type": "code", + "execution_count": 141, + "metadata": { + "collapsed": true + }, + "outputs": [], + "source": [ + "import numpy as np\n", + "from sklearn.model_selection import train_test_split\n", + "\n", + "X = np.array(ham_emails + spam_emails)\n", + "y = np.array([0] * len(ham_emails) + [1] * len(spam_emails))\n", + "\n", + "X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Okay, let's start writing the preprocessing functions. First, we will need a function to convert HTML to plain text. Arguably the best way to do this would be to use the great [BeautifulSoup](https://www.crummy.com/software/BeautifulSoup/) library, but I would like to avoid adding another dependency to this project, so let's hack a quick & dirty solution using regular expressions (at the risk of [un̨ho͞ly radiańcé destro҉ying all enli̍̈́̂̈́ghtenment](https://stackoverflow.com/a/1732454/38626)). The following function first drops the `
`<head>` section, then converts all `<a>` tags to the word HYPERLINK, then it gets rid of all HTML tags, leaving only the plain text. For readability, it also replaces multiple newlines with single newlines, and finally it unescapes HTML entities (such as `&gt;` or `&nbsp;`):" + ] + }, + { + "cell_type": "code", + "execution_count": 142, + "metadata": { + "collapsed": true + }, + "outputs": [], + "source": [ + "import re\n", + "from html import unescape\n", + "\n", + "def html_to_plain_text(html):\n", + " text = re.sub('