From 517a2f18be426d7f47c622bfcdf28c495697ad64 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Aur=C3=A9lien=20Geron?= Date: Mon, 21 Feb 2022 10:20:48 +1300 Subject: [PATCH] Fix titanic data download function --- 03_classification.ipynb | 23 +++++++++++------------ 1 file changed, 11 insertions(+), 12 deletions(-) diff --git a/03_classification.ipynb b/03_classification.ipynb index b45755d..6f4c61f 100644 --- a/03_classification.ipynb +++ b/03_classification.ipynb @@ -2528,22 +2528,21 @@ "metadata": {}, "outputs": [], "source": [ + "from pathlib import Path\n", "import pandas as pd\n", + "import tarfile\n", "import urllib.request\n", "\n", "def load_titanic_data():\n", - " titanic_path = Path() / \"datasets\" / \"titanic\"\n", - " titanic_path.mkdir(parents=True, exist_ok=True)\n", - " filenames = (\"train.csv\", \"test.csv\")\n", - " for filename in filenames:\n", - " filepath = titanic_path / filename\n", - " if filepath.is_file():\n", - " continue\n", - " data_root = \"https://github.com/ageron/data/raw/main/\"\n", - " url = data_root + \"titanic/\" + filename\n", - " print(\"Downloading\", filename)\n", - " urllib.request.urlretrieve(url, filepath)\n", - " return [pd.read_csv(titanic_path / filename) for filename in filenames]" + " tarball_path = Path(\"datasets/titanic.tgz\")\n", + " if not tarball_path.is_file():\n", + " Path(\"datasets\").mkdir(parents=True, exist_ok=True)\n", + " url = \"https://github.com/ageron/data/raw/main/titanic.tgz\"\n", + " urllib.request.urlretrieve(url, tarball_path)\n", + " with tarfile.open(tarball_path) as titanic_tarball:\n", + " titanic_tarball.extractall(path=\"datasets\")\n", + " return [pd.read_csv(Path(\"datasets/titanic\") / filename)\n", + " for filename in (\"train.csv\", \"test.csv\")]" ] }, {