#!/usr/bin/env python
# Copyright (C) 2008 Aaron Bentley
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation; either version 2 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software
# Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA

import cmd
from inspect import getdoc, getargspec
import os.path
import readline
import shlex
import stat
import string
import sys
import tempfile
import time
import shutil

from bzrlib import trace
trace.enable_default_logging()

from bzrlib import errors as bzr_errors, msgeditor, osutils, urlutils
from bzrlib.config import config_dir, ensure_config_dir_exists
from bzrlib.plugin import load_plugins
from bzrlib.transport import get_transport


def note(string):
    """Write ``string`` followed by a newline to standard error."""
    message = string + '\n'
    sys.stderr.write(message)


class UserError(Exception):
    """A problem the user can fix, reported as a plain message.

    The shell prints ``args[0]`` of this exception instead of a traceback.
    """

def command(**kwargs):
    """Decorator factory that turns a PromptCmd method into a shell command.

    Keyword options:
      needs_transport -- when true (the default), the wrapped method raises
          UserError unless ``self.transport`` is set.

    The wrapper records the command's user-visible parameters in ``_args``.
    That is every parameter after ``self``, in declaration order, with a
    trailing '+' on those that have a default value (optional ones).  A
    ``ConnectionReset`` raised by the command clears ``self.transport`` and
    is re-raised as UserError.
    """
    needs_transport = kwargs.get('needs_transport', True)

    def dec_instance(func):
        def decorator(self, *args, **kwargs):
            if needs_transport and self.transport is None:
                raise UserError('No transport open.')
            try:
                return func(self, *args, **kwargs)
            except bzr_errors.ConnectionReset:
                self.transport = None
                raise UserError("Connection reset/could not connect")

        arg_names, _varargs, _varkw, defaults = getargspec(func)
        # Parameters before this index have no default value.
        n_required = len(arg_names) - len(defaults or ())
        decorator._args = [
            name if pos < n_required else name + '+'
            for pos, name in enumerate(arg_names)
            if pos != 0  # skip ``self``
        ]
        decorator.__doc__ = func.__doc__
        return decorator
    return dec_instance


class PromptCmd(cmd.Cmd):

    def __init__(self, location=None, first_command=None):
        cmd.Cmd.__init__(self)
        self.possible_transports = []
        self.transport = None
        self._looping = False
        ensure_config_dir_exists()
        self.history_file = osutils.pathjoin(config_dir(),
            'hitchhiker-history')
        readline.set_completer_delims(string.whitespace)
        if os.access(self.history_file, os.R_OK) and \
            os.path.isfile(self.history_file):
            readline.read_history_file(self.history_file)
        if location is not None:
            self.cmd_open(location)
        if first_command is not None:
            self._run_from_args(first_command)

    def cmdloop(self):
        self._looping = True
        try:
            cmd.Cmd.cmdloop(self)
        finally:
            self._looping = False

    def _get_cmd(self, name):
        name = name.replace('-', '_')
        handler = getattr(self, 'cmd_' + name, None)
        if handler is None:
            raise UserError('No such comand: %s' % name)
        return handler

    def default(self, line):
        args = shlex.split(line)
        self._run_from_args(args)

    def _run_from_args(self, args):
        try:
            handler = self._get_cmd(args[0])
            kwargs = {}
            for pos, arg in enumerate(handler._args):
                try:
                    kwargs[arg.rstrip('+')] = args[pos + 1]
                except IndexError:
                    if arg.endswith('+'):
                        continue
                    raise UserError('Required parameter missing: %s' % arg)
            handler(**kwargs)
        except UserError, e:
            print e.args[0]
        except KeyboardInterrupt:
            print "Command interrupted.",
            if self._looping:
                print " (^C again to quit.)"
            else:
                print

    def do_help(self, line):
        self.default('help ' + line)

    @command(needs_transport=False)
    def cmd_help(self, topic=None):
        """Print help.

        If no topic is supplied, available commands are listed.
        If a topic is supplied, help for that command is provided.
        """
        if topic is not None:
            cmd = self._get_cmd(topic)
            usage = [topic]
            for arg in cmd._args:
                if arg.endswith('+'):
                    usage.append('[%s]' % arg.upper().rstrip('+'))
                else:
                    usage.append(arg.upper())
            print "Usage: %s" % ' '.join(usage)
            docstring = getdoc(cmd)
            if docstring is None:
                print 'No documentation is available.'
            else:
                print docstring
            return
        print "The following commands are supported:"
        for name, command in self._list_commands():
            print name

    def _list_commands(self):
        for key in dir(self):
            member = getattr(self, key)
            if getattr(member, '_args', None) is not None:
                yield key[4:].replace('_', '-'), member

    @command(needs_transport=False)
    def cmd_open(self, location):
        """Open a new location.

        The location may be any URL supported by Bazaar.
        """
        try:
            self.transport = get_transport(location,
                possible_transports=self.possible_transports)
        except (bzr_errors.InvalidURL, bzr_errors.InvalidLocationAlias), e:
            raise UserError(e)
        note('Opened %s' % self.transport.base)

    @command()
    def cmd_ls(self, relpath='.'):
        """List files and directories at current location.

        No distiction is made between files or directories.  (Sorry.)
        """
        try:
            for name in sorted(self.transport.list_dir(relpath)):
                print name
        except bzr_errors.NoSuchFile, e:
            raise UserError(e)
        except bzr_errors.TransportNotPossible:
            raise UserError('This transport is not listable.')

    @command()
    def cmd_lsl(self, relpath='.'):
        """List files and directories at current location. (detailed)

        Prints permissions, human-readable size, filename.
        """
        try:
            for name in sorted(self.transport.list_dir(relpath)):
                path = osutils.pathjoin(relpath, name)
                self.print_listing(path)
        except bzr_errors.NoSuchFile, e:
            raise UserError(e)
        except bzr_errors.TransportNotPossible:
            raise UserError('This transport is not listable.')

    def print_listing(self, name):
        tstat = self.transport.stat(name)
        size = float(tstat.st_size)
        if stat.S_ISDIR(tstat.st_mode):
            mode_str = 'd'
        else:
            mode_str = ''
        for divisor in [64, 8, 1]:
            for letter, val in [('r', 4), ('w', 2), ('x', 1)]:
                if val & (tstat.st_mode / divisor):
                    mode_str += letter
                else:
                    mode_str += '-'
        for unit in ['', 'K', 'M', 'G']:
            if size < 1024:
                if unit == '' or size > 10:
                    size_str = '%.0f%s' % (size, unit)
                else:
                    size_str = '%.1f%s' % (size, unit)
                break
            size /= 1024
        print '%10s%7s %s' % (mode_str, size_str, name)

    @command()
    def cmd_cd(self, relpath):
        """Move around at a location.

        This updates the current location the same way cd does.
        """
        self.transport = self.transport.clone(relpath)

    def do_EOF(self, *args):
        print
        readline.write_history_file(self.history_file)
        sys.exit()

    def _retrieve_file(self, relpath, outfile):
        try:
            return osutils.pumpfile(self.transport.get(relpath), outfile)
        except bzr_errors.ReadError, e:
            raise UserError(str(e))
        except bzr_errors.NoSuchFile, e:
            raise UserError(str(e))

    @command(needs_transport=False)
    def cmd_exit(self):
        """Exit the program."""
        readline.write_history_file(self.history_file)
        sys.exit()

    @command()
    def cmd_cat(self, relpath):
        """Print the contents of a file to screen."""
        self._retrieve_file(relpath, sys.stdout)

    def get_basename(self, relpath):
        full_url = self.transport.clone(relpath).base
        return urlutils.basename(full_url)

    @command()
    def cmd_get(self, relpath):
        """Retrieve a file."""
        outfile = open(self.get_basename(relpath), 'wb')
        try:
            len_copied = self._retrieve_file(relpath, outfile)
        finally:
            outfile.close()
        note("%d bytes copied" % len_copied)

    @command()
    def cmd_mirror(self, relpath='.', target=None):
        from_transport = self.transport.clone(relpath)
        if target is None:
            target = self.get_basename(relpath)
        to_transport = get_transport(target,
            possible_transports=self.possible_transports)
        try:
            to_transport.mkdir('.')
        except bzr_errors.FileExists, e:
            note(str(e))
        else:
            from_transport.copy_tree_to_transport(to_transport)

    def put_path(self, path, target):
        infile = open(path, 'rb')
        try:
            return self.transport.put_file(target, infile)
        finally:
            infile.close()

    @command()
    def cmd_put(self, path, target=None):
        """Copy a local file to a remote location."""
        if target is None:
            target = os.path.basename(path)
        len_copied = self.put_path(path, target)
        note("%d bytes copied" % len_copied)

    @command()
    def cmd_edit(self, relpath):
        """Download, edit, and re-upload a file."""
        tmpdir = tempfile.mkdtemp()
        try:
            filename = osutils.pathjoin(tmpdir, self.get_basename(relpath))
            outfile = open(filename, 'wb')
            try:
                note('Retrieving file...')
                len_copied = self._retrieve_file(relpath, outfile)
            finally:
                outfile.close()
            msgeditor._run_editor(filename)
            note('Uploading file...')
            self.put_path(filename, relpath)
        finally:
            shutil.rmtree(tmpdir)

    @command()
    def cmd_rename(self, source, target):
        """Rename a file from source to target."""
        self.transport.rename(source, target)

    @command()
    def cmd_mkdir(self, relpath):
        try:
            self.transport.mkdir(relpath)
        except bzr_errors.PermissionDenied, e:
            note(str(e))

    @command()
    def cmd_rm(self, relpath):
        """Delete a file."""
        try:
            self.transport.delete(relpath)
        except (bzr_errors.TransportError, bzr_errors.NoSuchFile), e:
            note(str(e))

    @command()
    def cmd_rmtree(self, relpath):
        """Delete a directory and its contents."""
        try:
            self.transport.delete_tree(relpath)
        except (bzr_errors.NoSuchFile), e:
            note(str(e))

    @command()
    def cmd_info(self):
        """Print information about the current location.

        URL, implementation class and smart protocol version are printed,
        if applicable.
        """
        print 'Location: %s' % self.transport.base
        print ('Transport implementation: %s'
               % self.transport.__class__.__name__)
        print 'Credentials: %r' % (self.transport._get_credentials(),)
        try:
            smart_medium = self.transport.get_smart_medium()
        except bzr_errors.NoSmartMedium:
            print 'Protocol does not support smart server.'
        else:
            try:
                print ("Smart protocol version: %s" %
                       smart_medium.protocol_version())
            except bzr_errors.SmartProtocolError:
                print "Unable to use smart protocol."
            else:
                remote_path = getattr(smart_medium, '_bzr_remote_path', None)
                if remote_path is not None:
                    print ("Remote executible path: %s" % remote_path)

    @command()
    def cmd_connect_and_wait(self):
        """Establish a connection and wait forever.

        Useful for SSH multiplexing.
        """
        self.transport.list_dir('.')
        note('Connection established.  ^C to end.')
        time.sleep(60 * 60 * 24 * 10000)


def main():
    # we want things like lp directories to work
    load_plugins()
    trace.enable_default_logging()
    if len(sys.argv) > 2:
        first_command = sys.argv[2:]
    else:
        first_command = None
    if len(sys.argv) > 1:
        location = sys.argv[1]
    else:
        location = None
    try:
        prompt = PromptCmd(location, first_command)
    except UserError, e:
        print e.args[0]
    else:
        if first_command is None:
            try:
                prompt.cmdloop()
            except KeyboardInterrupt:
                sys.stderr.write("\nInterrupted.\n")
                sys.exit(2)


if __name__ == "__main__":
    main()
