#!/usr/bin/env python3

import argparse
import asyncio
import os
import sys

import pywatchman
from pywatchman.aioclient import AIOClient as WatchmanClient


def _with_timeout(task, timeout):
    """Optionally bound an awaitable by a timeout.

    Args:
        task: The awaitable to run.
        timeout: Time limit in seconds. Values that are zero or negative
            mean "no limit".

    Returns:
        ``asyncio.wait_for(task, timeout)`` when ``timeout`` is positive,
        otherwise ``task`` itself, untouched.
    """
    return asyncio.wait_for(task, timeout) if timeout > 0 else task


class Subscription(object):
    """A single subscription.

    Holds the name and absolute path of one watched location and, once
    :meth:`activate` has run, the watch root and root-relative path that the
    server reported for it.
    """

    def __init__(self, abspath, name):
        self.name = name  # Our name for this subscription
        self.abspath = abspath  # Full path to the dir being watched
        self.root = None  # Watched root
        self.relpath = None  # Path relative to root of the dir being watched

    def __str__(self):
        return "Subscription(name=%s, root=%s, abspath=%s, relpath=%s)" % (
            self.name,
            self.root,
            self.abspath,
            self.relpath,
        )

    async def activate(self, client, patterns=None, fields=None):
        """Activate the subscription. This function is a coroutine.

        Sends ``watch-project`` for the path (or for its parent directory
        when the path is not a directory), stores the reported watch root
        and relative path, reads the current ``clock`` and finally sends
        ``subscribe`` with that clock as ``since``.

        Args:
            client: Object with an awaitable ``query(*args)`` method used to
                send the commands.
            patterns: Optional list of patterns. Each one becomes a
                ``match`` term on the whole name (dot files included) and
                the terms are combined with ``anyof``. Ignored when the
                path is not a directory, because the expression is then
                replaced by a ``name`` term for the path's base name.
            fields: List of field names to request. Defaults to
                ``["name"]``.
        """
        if fields is None:
            # Avoid a shared mutable default argument.
            fields = ["name"]

        dir_to_watch = self.abspath

        if patterns:
            expr = ["anyof"] + [
                ["match", p, "wholename", {"includedotfiles": True}]
                for p in patterns
            ]
        else:
            expr = ["true"]

        if not os.path.isdir(self.abspath):
            # Need to watch its parent
            dir_to_watch = os.path.dirname(self.abspath)
            expr = ["name", os.path.basename(self.abspath)]

        watch = await client.query("watch-project", dir_to_watch)

        if "warning" in watch:
            print("WARNING: ", watch["warning"], file=sys.stderr)

        self.root = watch["watch"]
        if "relative_path" in watch:
            self.relpath = watch["relative_path"]
            # Restrict results to the requested directory inside the root.
            expr = ["allof", ["dirname", self.relpath], expr]

        # Build the query only after the expression is final, so that the
        # relative-path restriction above is actually sent to the server.
        query = {"expression": expr, "fields": fields}

        # get the initial clock value so that we only get updates
        query["since"] = (await client.query("clock", self.root))["clock"]

        await client.query("subscribe", self.root, self.name, query)

    async def get(self, client):
        """Await ``client.get_subscription`` for this subscription and
        return the ``"files"`` entry of the result. This function is a
        coroutine."""
        data = await client.get_subscription(self.name, self.root)
        return data["files"]

    @classmethod
    def from_path(cls, path, name=None):
        """Create a subscription for the given path, using that path as the
        subscription's name if not specified."""
        name = name or path
        abspath = os.path.abspath(path)
        return cls(abspath, name)


class SubscriptionManager(object):
    """Creates one :class:`Subscription` per path and prints their events.

    After :meth:`activate`, ``listeners`` holds one asyncio task per
    subscription. Each task prints every record it receives and stops all
    listeners once ``max_events`` notifications have been counted.
    """

    def __init__(
        self, client, paths, patterns, fields, separator, relative, max_events
    ):
        self.subscriptions = {}  # path -> Subscription
        self.listeners = []  # one task per subscription, see _init_listeners
        self.total_events = 0  # notifications seen across all subscriptions
        self.client = client
        self.paths = paths
        self.patterns = patterns
        self.fields = fields
        self.separator = separator
        self.relative = relative
        self.max_events = max_events  # values <= 0 mean "no limit"

    async def activate(self):
        """Subscribe to every path, then start the listener tasks. This
        function is a coroutine.

        Raises:
            ValueError: if a path is given twice or does not exist.
        """
        for path in self.paths:
            await self._add_subscription(path)
        self._init_listeners()

    def _init_listeners(self):
        """Schedule one task per subscription that prints incoming records."""

        async def process_loop(sub):
            while True:
                for rec in await sub.get(self.client):
                    print(self._format_record(rec))
                    sys.stdout.flush()
                # The counter advances once per notification, regardless of
                # how many records that notification carried.
                self.total_events += 1
                if self.max_events > 0 and self.total_events >= self.max_events:
                    # Note: this cancels every unfinished listener, which
                    # includes the task running this loop.
                    self._cancel_all()
                    break

        self.listeners = [
            asyncio.ensure_future(process_loop(sub))
            for sub in self.subscriptions.values()
        ]

    def _check_path(self, path):
        """Raise ``ValueError`` for a duplicate or non-existent path."""
        if path in self.subscriptions:
            raise ValueError("path %s already specified" % path)
        if not os.path.exists(path):
            raise ValueError("path %s does not exist" % path)

    async def _add_subscription(self, path):
        """Validate ``path``, activate a subscription for it and record it."""
        self._check_path(path)
        sub = Subscription.from_path(path)
        await sub.activate(self.client, self.patterns, self.fields)
        self.subscriptions[path] = sub

    def _cancel_all(self):
        """Cancel every unfinished listener task and close the client."""
        for listener in self.listeners:
            if not listener.done():
                listener.cancel()
        self.client.close()

    def _format_field(self, fname, fval):
        """Return one field value as text.

        The ``name`` field is rewritten with ``os.path.relpath`` against
        ``self.relative``; every other field is passed through ``str``.
        """
        if fname == "name":
            # Accept the name as either bytes or str, so the same code works
            # whichever type the client hands back.
            if isinstance(fval, bytes):
                fval = fval.decode("utf-8")
            # NOTE(review): the name is not joined to the watch root first,
            # so relpath treats it as relative to the current directory —
            # confirm this is intended.
            fval = os.path.relpath(fval, self.relative)
        return str(fval)

    def _format_record(self, rec):
        """Return the output line for one record.

        With a single configured field the record is the bare field value;
        otherwise it is indexed by field name and the formatted values are
        joined with ``self.separator``.
        """
        if len(self.fields) == 1:
            return self._format_field(self.fields[0], rec)
        # str.join lives on the separator; lists have no join method.
        return self.separator.join(
            self._format_field(f, rec[f]) for f in self.fields
        )


def parse_args():
    """Define the command line interface and parse the process arguments.

    Returns:
        The ``argparse.Namespace`` produced from ``sys.argv``, carrying
        ``path``, ``relative``, ``fields``, ``separator``, ``null``,
        ``max_events``, ``pattern`` and ``timeout``.
    """
    # Shown verbatim in --help (RawDescriptionHelpFormatter keeps layout).
    description = """
watchman-wait-aio waits for changes to files. It uses the watchman service to
efficiently and recursively watch your specified list of paths.

It is suitable for waiting for changes to files from shell scripts.

It can stop after a configurable number of events are observed. The default
is a single event.  You may also remove the limit and allow it to execute
continuously.

watchman-wait will print one event per line. The event information includes
your specified list of fields, with each field separated by a space (or your
choice of --separator).

Events are consolidated and settled by the watchman server before they are
dispatched to watchman-wait.

Exit Status:

The following exit status codes can be used to determine what caused
watchman-wait-aio to exit:

0  After successfully waiting for event(s)
1  In case of a runtime error of some kind
2  The -t/--timeout option was used and that amount of time passed
   before an event was received
3  Execution was interrupted (Ctrl-C)
    """
    cli = argparse.ArgumentParser(
        description=description,
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    add = cli.add_argument

    # Positional: one or more paths.
    add("path", type=str, nargs="+", help="path(s) to watch")

    # Output shaping.
    add(
        "--relative",
        type=str,
        default=".",
        help="print paths relative to this dir (default=PWD)",
    )
    add(
        "--fields",
        type=lambda s: s.split(","),
        default=["name"],
        help="""
Comma separated list of file information fields to return.
The default is just the name.  For a list of possible fields, see:
https://facebook.github.io/watchman/docs/cmd/query.html#available-fields.""",
    )
    add(
        "-s",
        "--separator",
        type=str,
        default=" ",
        help="String to use as field separator for event output.",
    )
    add(
        "-0",
        "--null",
        action="store_true",
        help="""
Use a NUL byte as a field separator, takes precedence over --separator.""",
    )

    # Termination and filtering.
    add(
        "-m",
        "--max-events",
        type=int,
        default=1,
        help="""
Set the maximum number of events that will be processed.  When the limit
is reached, watchman-wait will exit.  The default is 1.  Setting the
limit to 0 removes the limit, causing watchman-wait to execute indefinitely.""",
    )
    add(
        "-p",
        "--pattern",
        type=str,
        nargs="+",
        help="""
Only emit paths that match this list of patterns.  Patterns are
applied by the watchman server and are matched against the root-relative
paths.

You will almost certainly want to use quotes around your pattern list
so that your shell doesn't interpret the pattern.

The pattern syntax is wildmatch style; globbing with recursive matching
via '**'.""",
    )
    add(
        "-t",
        "--timeout",
        type=float,
        default=0,
        help="""
Exit if no events trigger within the specified timeout.  If timeout is
zero (the default) then keep running indefinitely.""",
    )
    return cli.parse_args()


def main():
    """Run the tool: connect, subscribe, and print events until done.

    Exits the process with status 1 on a watchman command error, 2 when the
    ``--timeout`` elapses, and 3 on Ctrl-C. Reaching ``--max-events`` ends
    the run with status 0.
    """
    args = parse_args()
    # Create the loop explicitly: asyncio.get_event_loop() without a running
    # loop is deprecated and no longer creates one on recent Python versions.
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    mgr = None
    try:
        client = loop.run_until_complete(
            _with_timeout(WatchmanClient.from_socket(), args.timeout)
        )
        mgr = SubscriptionManager(
            client,
            args.path,
            args.pattern,
            args.fields,
            "\0" if args.null else args.separator,
            args.relative,
            args.max_events,
        )
        loop.run_until_complete(_with_timeout(mgr.activate(), args.timeout))
        # Start the subscription processing tasks
        loop.run_until_complete(
            _with_timeout(asyncio.gather(*mgr.listeners), args.timeout)
        )

    except pywatchman.CommandError as ex:
        print("watchman:", ex.msg, file=sys.stderr)
        if mgr is not None:
            mgr._cancel_all()
        sys.exit(1)

    except asyncio.TimeoutError:
        if mgr is not None:
            mgr._cancel_all()
        sys.exit(2)

    except asyncio.CancelledError:
        # When the event limit is reached the manager cancels its listener
        # tasks, and gather() reports that as CancelledError. That is the
        # normal way for a limited run to finish, not a failure.
        limit_reached = (
            mgr is not None
            and mgr.max_events > 0
            and mgr.total_events >= mgr.max_events
        )
        if limit_reached:
            sys.exit(0)
        # Any other cancellation is unexpected: clean up and report an error.
        if mgr is not None:
            mgr._cancel_all()
        sys.exit(1)

    except KeyboardInterrupt:
        # suppress ugly stack trace when they Ctrl-C
        if mgr is not None:
            mgr._cancel_all()
        sys.exit(3)


# Run the tool only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
