#!/usr/bin/env python

import collections
import glob
import hashlib
import os
import subprocess


class NotebookAnalyser:
    """Remove Jupyter notebook checkpoints that are identical to their notebooks.

    A checkpoint ``<name>-checkpoint.ipynb`` stored in an ``.ipynb_checkpoints``
    directory is compared with the working file ``<name>.ipynb`` located in the
    parent of that directory. Checkpoints that are byte-for-byte identical to
    their working file are deleted; a checkpoint directory that ends up empty
    is deleted as well.
    """

    NOTEBOOK_SUFFIX = ".ipynb"
    CHECKPOINT_DIR = NOTEBOOK_SUFFIX + "_checkpoints"
    CHECKPOINT_MASK = "*-checkpoint" + NOTEBOOK_SUFFIX
    # Length of the mask without its leading "*", i.e. of "-checkpoint.ipynb".
    CHECKPOINT_MASK_LEN = len(CHECKPOINT_MASK) - 1

    # Status labels printed in front of each reported path.
    MESSAGE_ORPHANED = "missing "
    MESSAGE_MODIFIED = "modified"
    MESSAGE_DELETED = "DELETING"

    # Numeric parameters inserted into the "\033[<n>m" escape sequence.
    # COLOR_WHITE is also the code emitted after the label to end coloring.
    COLOR_WHITE = "0"
    COLOR_RED = "31"
    COLOR_GREEN = "32"
    COLOR_YELLOW = "33"

    def __init__(self, dry_run=False, verbose=False, colorful=False):
        """Configure the analyser.

        Args:
            dry_run: if true, report what would be deleted but remove nothing.
            verbose: if true, also report orphaned checkpoints and modified
                notebooks (deletions are always reported).
            colorful: if true, wrap status labels in terminal escape sequences.
        """
        self._dry_run = dry_run
        self._verbose = verbose
        # Unknown or disabled colors resolve to "" so log() needs no branching.
        self._colors = collections.defaultdict(str)
        if colorful:
            for color in (
                NotebookAnalyser.COLOR_WHITE,
                NotebookAnalyser.COLOR_RED,
                NotebookAnalyser.COLOR_GREEN,
                NotebookAnalyser.COLOR_YELLOW,
            ):
                self._colors[color] = "\033[{}m".format(color)

    @staticmethod
    def get_hash(file_path):
        """Return the hexadecimal MD5 digest of the file at ``file_path``.

        The file is read in 4096-byte chunks so large notebooks are never
        loaded into memory at once.
        """
        digest = hashlib.md5()
        with open(file_path, "rb") as stream:
            for chunk in iter(lambda: stream.read(4096), b""):
                digest.update(chunk)
        return digest.hexdigest()

    def log(self, message, file, color=COLOR_WHITE):
        """Print ``message: file`` with the message wrapped in ``color``.

        Args:
            message: status label to print.
            file: path the label refers to.
            color: one of the ``COLOR_*`` constants; has no visible effect
                unless the analyser was created with ``colorful=True``.
        """
        color_on = self._colors[color]
        color_off = self._colors[NotebookAnalyser.COLOR_WHITE]
        print("{}{}{}: {}".format(color_on, message, color_off, file))

    @staticmethod
    def _workfile_path(checkpoint_path):
        """Return the working-notebook path that corresponds to a checkpoint."""
        # The notebook lives one level above the checkpoint directory.
        workfile_dir = os.path.dirname(os.path.dirname(checkpoint_path))
        checkpoint_name = os.path.basename(checkpoint_path)
        stem = checkpoint_name[:-NotebookAnalyser.CHECKPOINT_MASK_LEN]
        return os.path.join(workfile_dir, stem + NotebookAnalyser.NOTEBOOK_SUFFIX)

    @staticmethod
    def _same_content(checkpoint_path, workfile_path):
        """Return True if both files are byte-for-byte identical."""
        # Differing sizes prove a difference without reading either file.
        if os.stat(workfile_path).st_size != os.stat(checkpoint_path).st_size:
            return False
        checkpoint_hash = NotebookAnalyser.get_hash(checkpoint_path)
        workfile_hash = NotebookAnalyser.get_hash(workfile_path)
        return checkpoint_hash == workfile_hash

    def clean_checkpoints(self, directory):
        """Process every checkpoint found directly inside ``directory``.

        Checkpoints identical to their working notebook are reported and,
        unless in dry-run mode, deleted. Checkpoints without a working file
        and checkpoints that differ from it are kept and reported only in
        verbose mode. If ``directory`` is empty afterwards (and this is not a
        dry run) it is reported and removed too.

        Args:
            directory: path of a checkpoint directory.
        """
        # Escape the directory so that glob metacharacters in its path
        # ("[", "]", "*", "?") are matched literally instead of as a pattern.
        pattern = os.path.join(glob.escape(directory), NotebookAnalyser.CHECKPOINT_MASK)

        for checkpoint_path in sorted(glob.glob(pattern)):
            workfile_path = NotebookAnalyser._workfile_path(checkpoint_path)

            if not os.path.isfile(workfile_path):
                if self._verbose:
                    self.log(NotebookAnalyser.MESSAGE_ORPHANED, workfile_path, NotebookAnalyser.COLOR_RED)
                continue

            if not NotebookAnalyser._same_content(checkpoint_path, workfile_path):
                if self._verbose:
                    self.log(NotebookAnalyser.MESSAGE_MODIFIED, workfile_path, NotebookAnalyser.COLOR_YELLOW)
                continue

            self.log(NotebookAnalyser.MESSAGE_DELETED, checkpoint_path, NotebookAnalyser.COLOR_GREEN)
            if not self._dry_run:
                os.remove(checkpoint_path)

        if not self._dry_run and not os.listdir(directory):
            self.log(NotebookAnalyser.MESSAGE_DELETED, directory, NotebookAnalyser.COLOR_GREEN)
            os.rmdir(directory)

    def clean_checkpoints_recursively(self, directory):
        """Walk ``directory`` and clean every ``.ipynb_checkpoints`` directory in it."""
        for root, subdirs, _files in os.walk(directory):
            subdirs.sort()  # INFO: traverse alphabetically
            if NotebookAnalyser.CHECKPOINT_DIR in subdirs:
                subdirs.remove(NotebookAnalyser.CHECKPOINT_DIR)  # INFO: don't recurse there
                self.clean_checkpoints(os.path.join(root, NotebookAnalyser.CHECKPOINT_DIR))


def main():
    """Command-line entry point.

    Parses the command-line options, builds a ``NotebookAnalyser`` from them
    and cleans checkpoints recursively under every directory given on the
    command line (the current directory when none is given).
    """
    import argparse

    parser = argparse.ArgumentParser(
        description="Remove checkpointed versions of those jupyter notebooks that are identical to their working copies.",
        epilog="""Notebooks will be reported as either
        "DELETING" if the working copy and checkpointed version are identical
        (checkpoint will be deleted),
        "missing" if there is a checkpoint but no corresponding working file can be found
        or "modified" if notebook and the checkpoint are not byte-to-byte identical.
        If removal of checkpoints results in empty ".ipynb_checkpoints" directory
        that directory is also deleted.
        """,
    )
    # The default must be a list: with nargs="*" a bare string default would be
    # iterated character by character in the loop below.
    parser.add_argument("dirs", metavar="DIR", type=str, nargs="*", default=["."], help="directories to search")
    parser.add_argument("-d", "--dry-run", action="store_true", help="only print messages, don't perform any removals")
    parser.add_argument("-v", "--verbose", action="store_true", help="verbose mode")
    parser.add_argument("-c", "--color", action="store_true", help="colorful mode")
    args = parser.parse_args()

    analyser = NotebookAnalyser(args.dry_run, args.verbose, args.color)
    for directory in args.dirs:
        analyser.clean_checkpoints_recursively(directory)

# Run the command-line interface only when executed as a script, not on import.
if __name__ == "__main__":
    main()