diff --git a/03_classification.ipynb b/03_classification.ipynb index b8f7d75..486fd0c 100644 --- a/03_classification.ipynb +++ b/03_classification.ipynb @@ -1953,15 +1953,15 @@ "SPAM_URL = DOWNLOAD_ROOT + \"20030228_spam.tar.bz2\"\n", "SPAM_PATH = os.path.join(\"datasets\", \"spam\")\n", "\n", - "def fetch_spam_data(spam_url=SPAM_URL, spam_path=SPAM_PATH):\n", + "def fetch_spam_data(ham_url=HAM_URL, spam_url=SPAM_URL, spam_path=SPAM_PATH):\n", " if not os.path.isdir(spam_path):\n", " os.makedirs(spam_path)\n", - " for filename, url in ((\"ham.tar.bz2\", HAM_URL), (\"spam.tar.bz2\", SPAM_URL)):\n", + " for filename, url in ((\"ham.tar.bz2\", ham_url), (\"spam.tar.bz2\", spam_url)):\n", " path = os.path.join(spam_path, filename)\n", " if not os.path.isfile(path):\n", " urllib.request.urlretrieve(url, path)\n", " tar_bz2_file = tarfile.open(path)\n", - " tar_bz2_file.extractall(path=SPAM_PATH)\n", + " tar_bz2_file.extractall(path=spam_path)\n", " tar_bz2_file.close()" ] }, @@ -2392,7 +2392,7 @@ " for url in urls:\n", " text = text.replace(url, \" URL \")\n", " if self.replace_numbers:\n", - " text = re.sub(r'\\d+(?:\\.\\d*(?:[eE]\\d+))?', 'NUMBER', text)\n", + " text = re.sub(r'\\d+(?:\\.\\d*)?(?:[eE][+-]?\\d+)?', 'NUMBER', text)\n", " if self.remove_punctuation:\n", " text = re.sub(r'\\W+', ' ', text, flags=re.M)\n", " word_counts = Counter(text.split())\n", @@ -2455,7 +2455,6 @@ " for word, count in word_count.items():\n", " total_count[word] += min(count, 10)\n", " most_common = total_count.most_common()[:self.vocabulary_size]\n", - " self.most_common_ = most_common\n", " self.vocabulary_ = {word: index + 1 for index, (word, count) in enumerate(most_common)}\n", " return self\n", " def transform(self, X, y=None):\n",