#!/usr/bin/env python

import glob
import os
import os.path as osp
import sys
import re
import copy
import time
import math

from optparse import OptionParser

import memory_profiler as mp


def print_usage():
    """Print a one-line usage summary naming the running script."""
    script_name = osp.basename(sys.argv[0])
    print("Usage: %s <command> <options> <arguments>" % script_name)

def get_action():
    """Remove the action name from sys.argv and return it.

    The action is expected as the first command-line argument. When it is
    missing, the usage line is printed; when it is not a known action, the
    list of valid actions is printed. In both cases the program exits with
    status 1.
    """
    valid_actions = ("run", "rm", "clean", "list", "plot")

    if len(sys.argv) < 2:
        print_usage()
        sys.exit(1)

    requested = sys.argv[1]
    if requested not in valid_actions:
        print("Valid actions are: " + " ".join(valid_actions))
        sys.exit(1)

    # Drop the action so that option parsers only see what follows it.
    del sys.argv[1]
    return requested


def get_profile_filenames(args):
    """Return the list of profile filenames selected by ``args``.

    Parameters
    ==========
    args (list or str)
        Either the string "all", which selects every profile matching
        "mprofile_??????????????.dat" in the current directory, or a list
        of filenames and/or integers. An integer is the index of the
        profile in the sorted list of existing profiles: 0 is the oldest,
        -1 is the most recent. A "--" entry is ignored.

    Returns
    =======
    filenames (list)
        list of existing memory profile filenames, followed by their
        companion "<name>_ts<ext>" timestamp files when these exist. It is
        guaranteed that a given file name will not appear twice in this
        list.

    Raises
    ======
    ValueError
        if an index does not match an existing profile, or if a path is a
        directory or does not exist.
    """
    profiles = sorted(glob.glob("mprofile_??????????????.dat"))

    # Compare by value: an identity test against a string literal only
    # succeeds when both strings happen to be interned.
    if args == "all":
        filenames = list(profiles)
    else:
        filenames = []
        for arg in args:
            if arg == "--":  # workaround
                continue
            try:
                index = int(arg)
            except ValueError:
                index = None

            if index is not None:
                try:
                    filename = profiles[index]
                except IndexError:
                    raise ValueError(
                        "Invalid index (non-existing file): %s" % arg)
            elif osp.isfile(arg):
                filename = arg
            elif osp.isdir(arg):
                raise ValueError("Path %s is a directory" % arg)
            else:
                raise ValueError("File %s not found" % arg)

            if filename not in filenames:
                filenames.append(filename)

    # Add timestamp files, if any. Entries appended here are never visited
    # by the reversed iterator, so only the selected profiles are examined.
    for filename in reversed(filenames):
        parts = osp.splitext(filename)
        timestamp_file = parts[0] + "_ts" + parts[1]
        if osp.isfile(timestamp_file) and timestamp_file not in filenames:
            filenames.append(timestamp_file)

    return filenames


def list_action():
    """Print every existing profile with its index and start date.

    The command accepts no positional argument; if any is given, a message
    is printed and the program exits with status 1. The date and time are
    sliced out of the last "_"-separated component of each file name.
    """
    parser = OptionParser(version=mp.__version__)
    parser.disable_interspersed_args()
    _, leftover = parser.parse_args()

    if leftover:
        print("This command takes no argument.")
        sys.exit(1)

    line_format = "{index} {filename} {hour}:{min}:{sec} {day}/{month}/{year}"
    for position, name in enumerate(get_profile_filenames("all")):
        stem = osp.splitext(name)[0]
        stamp = stem.split('_')[-1]
        print(line_format.format(
            index=position, filename=name,
            hour=stamp[8:10], min=stamp[10:12], sec=stamp[12:14],
            day=stamp[6:8], month=stamp[4:6], year=stamp[:4]))


def rm_action():
    """Delete the profiles designated on the command line.

    Each argument is a profile index or a filename, resolved by
    get_profile_filenames(). With --dry-run the files are only listed.
    When no argument is given, a message is printed and the program exits
    with status 1.
    """
    # TODO: merge with clean_action (@pgervais)
    parser = OptionParser(version=mp.__version__)
    parser.disable_interspersed_args()
    parser.add_option(
        "--dry-run", dest="dry_run", action="store_true", default=False,
        help="Show what will be done, without actually doing it.")
    options, args = parser.parse_args()

    if not args:
        print("A profile to remove must be provided (number or filename)")
        sys.exit(1)

    targets = get_profile_filenames(args)
    if not options.dry_run:
        for path in targets:
            os.remove(path)
        return

    print("Files to be removed: ")
    for path in targets:
        print(path)


def clean_action():
    """Delete every profile file found in the current directory.

    The command accepts no positional argument; if any is given, a message
    is printed and the program exits with status 1. With --dry-run the
    files are only listed.
    """
    parser = OptionParser(version=mp.__version__)
    parser.disable_interspersed_args()
    parser.add_option(
        "--dry-run", dest="dry_run", action="store_true", default=False,
        help="Show what will be done, without actually doing it.")
    options, args = parser.parse_args()

    if args:
        print("This command takes no argument.")
        sys.exit(1)

    targets = get_profile_filenames("all")
    if not options.dry_run:
        for path in targets:
            os.remove(path)
        return

    print("Files to be removed: ")
    for path in targets:
        print(path)


def get_cmd_line(args):
    """Join a list of arguments into a single command-line string.

    Arguments containing a space or a tab are wrapped in single quotes;
    the others are left untouched. Arguments are separated by one space.
    """
    def quote_if_blank(token):
        if ' ' in token or '\t' in token:
            return "'" + token + "'"
        return token

    return ' '.join(map(quote_if_blank, args))


def run_action():
    """Run a program and record its memory usage in a new profile file.

    The output file is named "mprofile_<YYYYMMDDhhmmss>.dat" (local time at
    launch) and is created in the current directory. Its first line holds
    the command line; memory samples are then written by
    mp.memory_usage() at the period given by --interval.

    When the program is handled as a Python program (--python, or a first
    argument ending in ".py" unless --nopython is given), it is launched
    through "python -m memory_profiler --timestamp -o <output file>".
    """
    import time
    import subprocess

    parser = OptionParser(version=mp.__version__)
    parser.disable_interspersed_args()
    parser.add_option("--python", dest="python", default=False,
                      action="store_true",
                      help="""Activates extra features when the profiled executable is
                      a Python program (currently: function timestamping.)""")
    parser.add_option("--nopython", dest="nopython", default=False,
                      action="store_true",
                      help="""Disables extra features when the profiled executable is
                      a Python program (currently: function timestamping.)""")
    parser.add_option("--interval", "-T", dest="interval", default="0.1",
                      type="float", action="store",
                      help="Sampling period (in seconds)")
    parser.add_option("--include-children", "-C", dest="include_children",
                      default=False, action="store_true",
                      help="""Monitors forked processes as well (sum up all process memory)""")

    options, args = parser.parse_args()
    script_name = osp.basename(sys.argv[0])
    print("{1}: Sampling memory every {0.interval}s".format(
        options, script_name))

    if not args:
        print("A program to run must be provided. Use -h for help")
        sys.exit(1)

    # The output file name embeds the launch date and time.
    launch_stamp = time.strftime("%Y%m%d%H%M%S", time.localtime())
    mprofile_output = "mprofile_%s.dat" % launch_stamp

    # .. TODO: more than one script as argument ? ..
    if args[0].endswith('.py') and not options.nopython:
        options.python = True

    if options.python:
        print("running as a Python program...")
        if not args[0].startswith("python"):
            args.insert(0, "python")

    # The recorded command line is taken before the memory_profiler
    # arguments are spliced in, so it shows what the user asked to run.
    cmd_line = get_cmd_line(args)
    if options.python:
        args[1:1] = ("-m", "memory_profiler", "--timestamp",
                     "-o", mprofile_output)
    child = subprocess.Popen(args)

    with open(mprofile_output, "a") as out:
        out.write("CMDLINE {0}\n".format(cmd_line))
        mp.memory_usage(proc=child, interval=options.interval,
                        timestamps=True,
                        include_children=options.include_children,
                        stream=out)


def add_brackets(xloc, yloc, xshift=0, color="r", label=None):
    """Add two brackets on the memory line plot.

    This function uses the current figure. The bracket size is derived
    from the current axis limits. Exits the program with status 1 if
    pylab cannot be imported.

    Parameters
    ==========
    xloc: tuple with 2 values
        brackets location (on horizontal axis).
    yloc: tuple with 2 values
        brackets location (on vertical axis)
    xshift: float
        value to subtract from xloc.
    color: str
        one-letter matplotlib color code, appended to the "-" line style.
    label: str or None
        legend label attached to the opening bracket only. None (the
        default) means no label.
    """
    try:
        import pylab as pl
    except ImportError:
        print("matplotlib is needed for plotting.")
        sys.exit(1)
    height_ratio = 20.
    y_low, y_high = pl.ylim()
    x_low, x_high = pl.xlim()
    vsize = (y_high - y_low) / height_ratio
    hsize = (x_high - x_low) / (3.*height_ratio)

    bracket_x = pl.asarray([hsize, 0, 0, hsize])
    bracket_y = pl.asarray([vsize, vsize, -vsize, -vsize])

    # Matplotlib workaround: labels starting with _ aren't displayed.
    # The guard keeps the default label=None (and an empty label) from
    # raising when the first character is inspected.
    if label is not None and label.startswith('_'):
        label = ' ' + label
    pl.plot(bracket_x + xloc[0] - xshift, bracket_y + yloc[0],
            "-" + color, linewidth=2, label=label)
    pl.plot(-bracket_x + xloc[1] - xshift, bracket_y + yloc[1],
            "-" + color, linewidth=2)

    # TODO: use matplotlib.patches.Polygon to draw a colored background for
    # each function.

    # with matplotlib 1.2, use matplotlib.path.Path to create proper markers
    # see http://matplotlib.org/examples/pylab_examples/marker_path.html
    # This works with matplotlib 0.99.1
    ## pl.plot(xloc[0], yloc[0], "<"+color, markersize=7, label=label)
    ## pl.plot(xloc[1], yloc[1], ">"+color, markersize=7)


def read_mprofile_file(filename):
    """Read an mprofile file and return its content.

    Each line starts with a record type followed by a space and the record
    payload. "MEM", "FUNC" and "CMDLINE" records are parsed; blank lines
    and any other line are ignored.

    Parameters
    ==========
    filename: str
        path of the profile file to read.

    Returns
    =======
    content: dict
        Keys:

        - "mem_usage": (list) memory usage values, in MiB
        - "timestamp": (list) time instant for each memory usage value, in
            second
        - "func_timestamp": (dict) for each function, a list of
            [start, end, mem_start, mem_end] entries (timestamps and
            memory usage upon entering and exiting).
        - 'filename': (str) the path that was read.
        - 'cmd_line': (str or None) command-line ran for this profile, as
            stored in the file (trailing newline included); None if the
            file has no CMDLINE record.
    """
    func_ts = {}
    mem_usage = []
    timestamp = []
    cmd_line = None
    # The context manager closes the file even when a record fails to parse.
    with open(filename, "r") as f:
        for line in f:
            # partition() never fails to unpack, so blank lines or lines
            # without a space fall through to the "ignored" branch.
            field, _, value = line.partition(' ')
            if field == "MEM":
                # mem, timestamp
                values = value.split(' ')
                mem_usage.append(float(values[0]))
                timestamp.append(float(values[1]))

            elif field == "FUNC":
                values = value.split(' ')
                f_name, mem_start, start, mem_end, end = values[:5]
                func_ts.setdefault(f_name, []).append(
                    [float(start), float(end),
                     float(mem_start), float(mem_end)])

            elif field == "CMDLINE":
                cmd_line = value

    return {"mem_usage": mem_usage, "timestamp": timestamp,
            "func_timestamp": func_ts, 'filename': filename,
            'cmd_line': cmd_line}



def plot_file(filename, index=0, timestamps=True):
    """Plot one memory profile on the current pylab figure.

    Parameters
    ==========
    filename: str
        path of the profile file, read with read_mprofile_file().
    index: int
        rank of this profile among the plotted ones; it selects the color
        of the memory line (colors cycle).
    timestamps: bool
        if True, draw the function brackets and the dashed lines marking
        the maximum memory value and the instant it is reached.

    Returns
    =======
    mprofile: dict
        content returned by read_mprofile_file(). Its "timestamp" and
        "mem_usage" lists are extended in place with the function
        entry/exit values.

    The program exits with status 1 if pylab cannot be imported, and with
    status 0 if the file holds no memory usage value.
    """
    try:
        import pylab as pl
    except ImportError:
        print("matplotlib is needed for plotting.")
        sys.exit(1)
    import numpy as np  # pylab requires numpy anyway

    mprofile = read_mprofile_file(filename)
    if not mprofile['timestamp']:
        message = ('** No memory usage values have been found in the profile '
                   'file.**\nFile path: {0}\n'
                   'File may be empty or invalid.\n'
                   'It can be deleted with "mprof rm {0}"')
        print(message.format(mprofile['filename']))
        sys.exit(0)

    func_records = mprofile['func_timestamp']
    sample_times = mprofile['timestamp']
    sample_mem = mprofile['mem_usage']

    # Function entry/exit instants come with their own memory values: add
    # them to the sampled curve.
    for executions in func_records.values():
        for record in executions:
            sample_times.extend(record[:2])
            sample_mem.extend(record[2:4])

    # Sort both series by time.
    times = np.asarray(sample_times)
    order = times.argsort()
    times = times[order]
    mem = np.asarray(sample_mem)[order]

    # Times are plotted relative to the first one.
    global_start = float(times[0])
    times = times - global_start

    peak_mem = mem.max()
    peak_position = mem.argmax()

    bracket_colors = ("c", "y", "g", "r", "b")
    line_colors = ("k", "b", "r", "g", "c", "y", "m")
    start_text = time.strftime("%d / %m / %Y - start at %H:%M:%S",
                               time.localtime(global_start))
    milliseconds = int(round(math.modf(global_start)[0]*1000))
    line_label = start_text + ".{0:03d}".format(milliseconds)

    pl.plot(times, mem, "+-" + line_colors[index % len(line_colors)],
            label=line_label)

    # Vertical extent of the peak marker, read before anything else is
    # drawn and kept slightly inside the current limits.
    bottom, top = pl.ylim()
    bottom += 0.001
    top -= 0.001

    if timestamps:
        # One color per function, shared by all of its executions.
        for func_num, (func_name, executions) in enumerate(
                func_records.items()):
            bracket_color = bracket_colors[func_num % len(bracket_colors)]
            short_name = func_name.split(".")[-1]
            for record in executions:
                duration = record[1] - record[0]
                add_brackets(record[:2], record[2:], xshift=global_start,
                             color=bracket_color,
                             label=short_name + " %.3fs" % duration)

        # Dashed lines crossing at the maximum memory value.
        x_left, x_right = pl.xlim()
        pl.hlines(peak_mem, x_left + 0.001, x_right - 0.001,
                  colors="r", linestyles="--")
        pl.vlines(times[peak_position], bottom, top,
                  colors="r", linestyles="--")
    return mprofile


def plot_action():
    """Plot one or several memory profiles.

    Each positional argument is either an existing path or an integer
    index into the sorted list of "mprofile_??????????????.dat" files of
    the current directory. Without argument, the last profile of that
    list is plotted.

    An argument that is neither an existing path nor a valid index is
    reported and skipped. The program exits with status 1 if pylab cannot
    be imported or if no argument designates a profile, and with status -1
    if no argument is given and no profile exists.

    Function timestamps are drawn only when a single profile is plotted
    and --no-function-ts is not given. The figure is saved to the --output
    file if provided, and displayed otherwise.
    """
    try:
        import pylab as pl
    except ImportError:
        print("matplotlib is needed for plotting.")
        sys.exit(1)

    parser = OptionParser(version=mp.__version__)
    parser.disable_interspersed_args()
    parser.add_option("--title", "-t", dest="title", default=None,
                      type="str", action="store",
                      help="String shown as plot title")
    parser.add_option("--no-function-ts", "-n", dest="no_timestamps",
                      default=False, action="store_true",
                      help="Do not display function timestamps on plot.")
    parser.add_option("--output", "-o",
                      help="Save plot to file instead of displaying it.")
    (options, args) = parser.parse_args()

    profiles = glob.glob("mprofile_??????????????.dat")
    profiles.sort()

    if len(args) == 0:
        if len(profiles) == 0:
            print("No input file found. \nThis program looks for "
                  "mprofile_*.dat files, generated by the "
                  "'mprof run' command.")
            sys.exit(-1)
        filenames = [profiles[-1]]
    else:
        filenames = []
        for arg in args:
            if osp.exists(arg):
                selected = arg
            else:
                # Not a path: try to use it as an index into the profiles.
                # A failure must skip the argument, otherwise the lookup
                # below would run with an unset or stale index.
                try:
                    selected = profiles[int(arg)]
                except ValueError:
                    print("Input file not found: " + arg)
                    continue
                except IndexError:
                    print("Invalid index (non-existing file): " + arg)
                    continue
            if selected not in filenames:
                filenames.append(selected)
        if not filenames:
            # Every argument was rejected: there is nothing to plot.
            sys.exit(1)

    pl.figure(figsize=(14, 6), dpi=90)
    # Function brackets are only readable with a single curve.
    timestamps = len(filenames) == 1 and not options.no_timestamps
    for n, filename in enumerate(filenames):
        mprofile = plot_file(filename, index=n, timestamps=timestamps)
    pl.xlabel("time (in seconds)")
    pl.ylabel("memory used (in MiB)")

    if options.title is not None:
        pl.title(options.title)
    elif len(filenames) == 1:
        # Default title: the command line recorded in the profile.
        pl.title(mprofile['cmd_line'])

    ax = pl.gca()
    # place legend within the plot, make partially transparent in
    # case it obscures part of the lineplot
    leg = ax.legend(loc='best', fancybox=True)
    leg.get_frame().set_alpha(0.5)
    pl.grid()
    if options.output:
        pl.savefig(options.output)
    else:
        pl.show()

if __name__ == "__main__":
    # Workaround for optparse limitation: a negative number would be taken
    # for an option, so "--" is inserted before the first argument that
    # starts like one.
    negative_number = re.compile("-[0-9]+")
    for position, token in enumerate(sys.argv):
        if negative_number.match(token) is not None:
            sys.argv.insert(position, "--")
            break

    # Map each action name to the function implementing it.
    dispatch = {
        "run": run_action,
        "rm": rm_action,
        "clean": clean_action,
        "list": list_action,
        "plot": plot_action,
    }
    handler = dispatch[get_action()]
    handler()
