#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Copyright (C) 2018-2022 fjord-technologies
# SPDX-License-Identifier: GPL-3.0-or-later
"""fd-thallium-admin"""

import json
import logging
import os
import sys

import argparse
import getpass
import hvac

# Module-wide logger; its level and format are set up by main().
LOG = logging.getLogger('fd-thallium-admin')
# Organization part (after "-at-" in userpass names) whose users get corp areas.
ORGANIZATION_DOMAIN = 'commandersact.com'
# Sub-command name -> handler class; each handler class registers itself.
_SUBCMDS = dict()


class FdThalliumAdminExit(SystemExit):
    """
    Exit request raised by sub-command handlers; main() catches it and uses
    its ``code`` as the process return code.
    """


def _load_conf_file(conf_path = '~/.config/fd-thallium-admin'):
    """
    Load ``KEY=VALUE`` lines from an optional shell-like file into os.environ.

    Nothing happens when the file does not exist. Blank lines and ``#``
    comment lines are ignored. A leading ``export `` is stripped, and a value
    wrapped in matching single or double quotes is unquoted. Variables that
    are already set in the environment are never overridden, so the real
    environment wins over the file.

    conf_path: path of the file to load; ``~`` is expanded.
    """
    conf_path = os.path.expanduser(conf_path)
    if not os.path.isfile(conf_path):
        return
    with open(conf_path, 'r') as f:
        for line in f:
            line = line.strip()
            if not line or line.startswith('#'):
                continue
            if line.startswith('export '):
                line = line[len('export '):]
            key, sep, value = line.partition('=')
            if not sep:
                # Not an assignment: ignore the line.
                continue
            key   = key.strip()
            value = value.strip()
            # Unquote only when there is an opening AND a closing quote; a
            # lone '"' must not collapse into an empty string.
            if len(value) >= 2 and value[0] == value[-1] and value[0] in ('"', "'"):
                value = value[1:-1]
            if key and key not in os.environ:
                os.environ[key] = value


def argv_parse_check(argv = None):
    """
    Parse (and check a little) command line parameters.

    Connection options default to the THALLIUM_* environment variables, after
    ~/.config/fd-thallium-admin has been loaded into the environment.

    argv: argument list to parse; None means sys.argv[1:].

    Returns the argparse namespace, with ``loglevel`` converted to a logging
    level number. Exits through ``parser.error`` in these cases:
      * no sub-command is given;
      * the URI is missing;
      * no usable token is set and the AppRole role/secret ids are incomplete.
    """
    _load_conf_file()

    parser          = argparse.ArgumentParser()

    parser.add_argument("-l",
                        dest    = 'loglevel',
                        default = 'info',   # warning: converted to a level number below
                        choices = ('critical', 'error', 'warning', 'info', 'debug'),
                        help    = ("Emit traces with LOGLEVEL details, must be one of:\t"
                                   "critical, error, warning, info, debug"))
    parser.add_argument("--uri",
                        "--thallium-uri",
                        dest    = 'uri',
                        default = os.environ.get('THALLIUM_URI'),
                        help    = "Thallium URI address")
    parser.add_argument("--role-id",
                        "--thallium-role-id",
                        dest    = 'thallium_role_id',
                        default = os.environ.get('THALLIUM_ROLE_ID'),
                        help    = "Thallium role ID")
    parser.add_argument("--secret-id",
                        "--thallium-secret-id",
                        dest    = 'secret_id',
                        default = os.environ.get('THALLIUM_SECRET_ID'),
                        help    = "Thallium secret ID")
    parser.add_argument("--token",
                        "--thallium-token",
                        dest    = 'token',
                        default = os.environ.get('THALLIUM_TOKEN') or os.environ.get('VAULT_TOKEN'),
                        help    = "Thallium token")

    subparsers = parser.add_subparsers(dest = 'subcommand',
                                       help = "Choice sub-command")

    for subcmd in _SUBCMDS.values():
        subcmd.load_subcmd_parser(subparsers)

    args          = parser.parse_args(argv)
    args.loglevel = getattr(logging, args.loglevel.upper(), logging.INFO)

    # Without a sub-command there is nothing to dispatch (and no 'action').
    if not args.subcommand:
        parser.error("missing sub-command, choose one of: %s"
                     % ', '.join(sorted(_SUBCMDS)))

    if not args.uri:
        parser.error("missing variable THALLIUM_URI")

    # main() ignores a token equal to '-' and logs in with AppRole instead,
    # so the AppRole credentials are required in that case as well.
    if not args.token or args.token == '-':
        if not args.thallium_role_id:
            parser.error("missing variable THALLIUM_ROLE_ID")

        if not args.secret_id:
            parser.error("missing variable THALLIUM_SECRET_ID")

    return args


class FdThalliumCleanKeys(object):
    """
    ``clean-keys`` sub-command: walk a KV tree and report or purge keys.

    Without ``force``, every leaf key is read. A key whose latest version
    cannot be read is reported as useless; hvac raises ``InvalidPath`` for it,
    e.g. when that version is deleted. With ``force``, every leaf key below
    the root is deleted together with all its versions and metadata.
    """

    def __init__(self, options, client, xpath = None, force = False):
        # options: parsed argparse namespace (``options.path`` is the root
        #          when ``xpath`` is None)
        # client:  authenticated hvac client
        # xpath:   explicit root path, takes precedence over ``options.path``
        # force:   delete every key found instead of only inspecting it
        self.options = options
        self.client  = client
        self.path    = xpath
        self.force   = force

    @staticmethod
    def load_subcmd_parser(subparsers):
        """Register the ``clean-keys`` sub-parser and its options."""
        parser = subparsers.add_parser('clean-keys',
                                       help = 'Remove useless keys')
        parser.add_argument("--path",
                            dest    = 'path',
                            default = '',
                            help    = "Begin clean keys from path instead of %(default)s")
        parser.add_argument("-a",
                            dest    = 'action',
                            choices = ('discover',),
                            default = 'discover',
                            help    = "Choose action discover, instead of %(default)s")

    def _walk(self, root):
        """Depth-first visit of every key below *root*."""
        try:
            rs = self.client.secrets.kv.list_secrets(root)
        except (hvac.exceptions.InvalidPath, hvac.exceptions.Forbidden):
            # Nothing listable here (missing folder or no permission).
            return
        if not rs or 'data' not in rs or 'keys' not in rs['data']:
            return

        for xdir in rs['data']['keys']:
            xpath = ("%s/%s" % (root.strip('/'), xdir)).lstrip('/')
            if xpath.endswith('/'):
                # Trailing '/' marks a sub-folder.
                self._walk(xpath)
                continue

            if self.force:
                self.client.secrets.kv.delete_metadata_and_all_versions(xpath)
                LOG.warning("deleted file: %r", xpath)
                continue

            try:
                self.client.secrets.kv.read_secret_version(xpath, raise_on_deleted_version=True)
            except hvac.exceptions.InvalidPath:
                # Listed but its latest version cannot be read: report it,
                # this is what the discover action is looking for.
                LOG.info("useless key (latest version deleted): %r", xpath)
            except hvac.exceptions.Forbidden:
                LOG.debug("access denied, key skipped: %r", xpath)

    def do_discover(self):
        """Scan (and purge when ``force`` is set) the KV tree from its root."""
        # An explicit xpath (programmatic use) wins over the CLI --path value.
        root = self.path if self.path is not None else self.options.path
        LOG.info("scan KV depuis '%s'", root or '(racine)')
        self._walk(root)
        LOG.info("scan terminé")

    def __call__(self, action):
        """Run ``do_<action>`` (dashes in *action* become underscores)."""
        return getattr(self, "do_%s" % action.replace('-', '_'))()


_SUBCMDS.update({'clean-keys': FdThalliumCleanKeys})   # expose as the "clean-keys" sub-command


class FdThalliumUser(object):
    """
    ``user`` sub-command: reconcile userpass users with identity entities,
    groups, policies and per-user / per-team KV areas.

    Actions (``-a``):
      * ``discover``: full reconciliation.
      * ``export-ssh-public-keys``: print the mirrored SSH public keys as
        JSON lines.
      * ``chg-pass``: interactively change a userpass password.
    """

    def __init__(self, options, client):
        # options: parsed argparse namespace (``username`` is read by chg-pass)
        # client:  authenticated hvac client
        self.options   = options
        self.client    = client
        # Non-empty SSH public keys found in users' share areas during
        # discover, keyed by username.
        self._ssh_keys = {'public': {}}

    @staticmethod
    def load_subcmd_parser(subparsers):
        """Register the ``user`` sub-parser and its options on *subparsers*."""
        parser = subparsers.add_parser('user',
                                       help = 'Create and clean users area')
        parser.add_argument("-a",
                            dest    = 'action',
                            choices = ('discover', 'export-ssh-public-keys', 'chg-pass'),
                            default = 'discover',
                            help    = ("Choose action discover, export-ssh-public-keys "
                                       "or chg-pass, instead of %(default)s"))
        parser.add_argument("-u",
                            dest    = 'username',
                            default = '',
                            help    = "Selected username, only for action chg-pass")

    def _fetch_user_entity(self, xid):
        """Read the identity entity whose id is *xid* (raw hvac response)."""
        return self.client.read("identity/entity/id/%s" % xid)

    def _fetch_user_entities(self):
        """
        Return the identity entities whose name contains ``-at-``, by name.

        Each value is ``{'id': <entity id>, 'metadata': <dict>, 'aliases':
        {<alias name>: <alias dict>}}``. Entities that cannot be read are
        skipped.

        Raises FdThalliumAdminExit(1) when the entity list cannot be fetched.
        """
        r  = {}

        rs = self.client.list("identity/entity/id")
        if not rs or 'data' not in rs or 'key_info' not in rs['data']:
            LOG.error("unable to fetch users entities")
            raise FdThalliumAdminExit(1)

        for k, v in rs['data']['key_info'].items():
            if '-at-' not in v['name']:
                continue

            entity_data = self._fetch_user_entity(k)
            if not entity_data or 'data' not in entity_data:
                continue

            r[v['name']] = {
                'id':       k,
                'metadata': entity_data['data'].get('metadata') or {},
                'aliases':  {},
            }

            ref_aliases = r[v['name']]['aliases']
            for alias in (entity_data['data'].get('aliases') or []):
                ref_aliases[alias['name']] = alias

        return r

    def _fetch_users(self):
        """
        Return the list of userpass user names.

        Raises FdThalliumAdminExit(1) when the list cannot be fetched.
        """
        rs = self.client.list("auth/userpass/users")
        if not rs or 'data' not in rs or 'keys' not in rs['data']:
            LOG.error("unable to fetch users list")
            raise FdThalliumAdminExit(1)

        return rs['data']['keys']

    def _fetch_groups(self):
        """
        Return every identity group's ``data`` dict, keyed by group name.

        Raises FdThalliumAdminExit(1) when the list or any group cannot be
        read.
        """
        r  = {}

        rs = self.client.list("identity/group/id")
        if not rs or 'data' not in rs or 'key_info' not in rs['data']:
            LOG.error("unable to fetch group")
            raise FdThalliumAdminExit(1)

        for k, v in rs['data']['key_info'].items():
            grp = self.client.read("identity/group/id/%s" % k)
            if not grp or 'data' not in grp:
                LOG.error("unable to fetch group id: %r", k)
                raise FdThalliumAdminExit(1)

            r[v['name']] = grp['data']

        return r

    def _fetch_auth_info(self, name):
        """
        Return the description of the auth method mounted at *name*.

        The caller reads its ``accessor``. Returns None when the method is
        not mounted.
        """
        rs = self.client.sys.list_auth_methods() or {}
        name = name.rstrip('/') + '/'
        # Mounts are listed under 'data' in the hvac response; the top level
        # is still checked as a fallback.
        for mounts in (rs.get('data') or {}, rs):
            if name in mounts:
                return mounts[name]

        return None

    def _update_groups(self, grps):
        """Write every group of *grps* back after de-duplicating its members."""
        for name, data in grps.items():
            if data.get('member_entity_ids'):
                # dict.fromkeys keeps first-seen order, unlike set(), so the
                # written membership list is stable from one run to the next.
                data['member_entity_ids'] = list(dict.fromkeys(data['member_entity_ids']))
            self.client.write("identity/group/id/%s" % data['id'], **data)
            LOG.info("groupe %r mis à jour (%d membres)",
                     name, len(data.get('member_entity_ids') or []))

    def _upsert_entity(self, name, metadata = None, policies = None, xid = None):
        """
        Write entity *name* with *metadata* and *policies* to ``identity/entity``.

        The id *xid* is included when given, presumably so the existing
        entity is updated. Returns the response ``data`` or None.
        """
        if not metadata:
            metadata = {}

        if not policies:
            policies = []

        xdict = {'name': name,
                 'metadata': metadata,
                 'policies': policies}

        if xid:
            xdict['id'] = xid

        rs = self.client.write("identity/entity",
                               **xdict)

        if rs and 'data' in rs:
            return rs['data']

        return None

    def _upsert_entity_alias(self, name, canonical_id, mount_accessor, xid = None):
        """
        Write alias *name* of entity *canonical_id* for *mount_accessor*.

        The alias id *xid* is included when given. Returns the response
        ``data``, or None when there is none or the server answers
        InvalidRequest (presumably the alias already exists; to confirm).
        """
        xdict = {'name': name,
                 'canonical_id': canonical_id,
                 'mount_accessor': mount_accessor}

        if xid:
            xdict['id'] = xid

        try:
            rs = self.client.write("identity/entity-alias",
                                   **xdict)
        except hvac.exceptions.InvalidRequest:
            return None

        if rs and 'data' in rs:
            return rs['data']

        return None

    @staticmethod
    def _get_teams(grpslist):
        """Map team name -> group name for groups named ``corp::team::<team>``."""
        r = {}

        for grpname in grpslist.keys():
            if grpname.startswith("corp::team::"):
                r[grpname[12:]] = grpname

        return r

    @staticmethod
    def _user_organization(username):
        """
        Return the organization part of a ``<localpart>-at-<organization>``
        userpass name, or None when *username* has no ``-at-`` separator.
        """
        _localpart, sep, organization = username.partition("-at-")
        if not sep:
            return None
        return organization

    @staticmethod
    def _user_group_memberships(userpolicies, grpslist):
        """
        Map a user's ``grp::<group>`` policies to groups present in *grpslist*.

        Returns ``(grpnames, userteams)``. ``grpnames`` lists the matching
        group names in policy order. ``userteams`` is the set of team names
        taken from the ``corp::team::<team>`` groups among them.
        """
        grpnames  = []
        userteams = set()

        for userpolicy in userpolicies:
            if not userpolicy.startswith("grp::"):
                continue

            grppolicy = userpolicy[5:]
            if grppolicy not in grpslist:
                continue

            grpnames.append(grppolicy)
            if grppolicy.startswith("corp::team::"):
                userteams.add(grppolicy[12:])

        return grpnames, userteams

    @staticmethod
    def _apply_team_metadata(metadata, userteams):
        """
        Record *userteams* in the entity *metadata* dict (updated in place).

        * ``corp-teams`` gets the sorted, comma-separated team list, so the
          value is identical from one run to the next.
        * ``main-corp-team`` is forced when the user has exactly one team.
          Otherwise an existing value is kept only while it still names one
          of the user's teams, and is dropped when it has gone stale.
        """
        metadata['corp-teams'] = ','.join(sorted(userteams))

        if len(userteams) == 1:
            metadata['main-corp-team'] = next(iter(userteams))
        elif metadata.get('main-corp-team') not in userteams:
            metadata.pop('main-corp-team', None)

    def _create_user_area(self, user):
        """
        Provision the KV area ``corp/users/<user>``.

        * The ``private/.motd`` and ``share/.motd`` secrets are recreated.
        * ``share/.ssh/public_keys`` is created empty when missing.
        * A non-empty ``default`` public key is remembered in
          ``self._ssh_keys``.
        """
        self.client.secrets.kv.delete_metadata_and_all_versions("corp/users/%s/private/.motd" % user)
        self.client.secrets.kv.delete_metadata_and_all_versions("corp/users/%s/share/.motd" % user)
        self.client.secrets.kv.create_or_update_secret("corp/users/%s/private/.motd" % user,
                                                       {'msg': "your private area"})
        self.client.secrets.kv.create_or_update_secret("corp/users/%s/share/.motd" % user,
                                                       {'msg': "share what you want"})

        try:
            rs = self.client.secrets.kv.read_secret_version(
                "corp/users/%s/share/.ssh/public_keys" % user,
                raise_on_deleted_version=True
            )
            if rs and 'data' in rs and 'data' in rs['data'] and rs['data']['data'].get('default'):
                self._ssh_keys['public'][user] = rs['data']['data']['default']
        except hvac.exceptions.InvalidPath as e:
            LOG.debug("Secret not found for %s: %s", user, str(e))
            self.client.secrets.kv.create_or_update_secret("corp/users/%s/share/.ssh/public_keys" % user,
                                                           {'default': ""})

    def _sync_ssh_keys(self, users):
        """
        Mirror the collected SSH public keys to ``si/ssh-keys/users/<user>``.

        A mirrored entry is deleted when its user is not in *users*, or when
        the user no longer has a public key. The latter makes clearing one's
        key an effective revocation.
        """
        keys  = self._ssh_keys['public']
        users = set(users)

        for user, key in keys.items():
            self.client.secrets.kv.create_or_update_secret("si/ssh-keys/users/%s" % user,
                                                           {"public": key})
            LOG.debug("clé SSH synchronisée : %r", user)

        try:
            rs = self.client.secrets.kv.list_secrets("si/ssh-keys/users")
        except hvac.exceptions.InvalidPath as e:
            LOG.debug("No ssh keys to list in si/ssh-keys/users: %s", str(e))
            return

        if not rs or 'data' not in rs or 'keys' not in rs['data']:
            return

        for xdir in rs['data']['keys']:
            if xdir.endswith('/'):
                continue

            if xdir not in users:
                self.client.secrets.kv.delete_metadata_and_all_versions("si/ssh-keys/users/%s" % xdir)
                LOG.info("clé SSH supprimée (user inexistant) : %r", xdir)
            elif xdir not in keys:
                self.client.secrets.kv.delete_metadata_and_all_versions("si/ssh-keys/users/%s" % xdir)
                LOG.info("clé SSH supprimée (clé publique vide) : %r", xdir)

    def _clean_user_area(self, users, xpath = "corp/users"):
        """Force-delete every ``<xpath>/<dir>/`` folder whose name is not in *users*."""
        try:
            rs = self.client.secrets.kv.list_secrets(xpath)
        except hvac.exceptions.InvalidPath:
            return
        if not rs or 'data' not in rs or 'keys' not in rs['data']:
            return

        for xdir in rs['data']['keys']:
            if not xdir.endswith('/'):
                continue

            xdir  = xdir.rstrip('/')

            if xdir not in users:
                LOG.info("nettoyage area KV (user supprimé) : %r", xdir)
                FdThalliumCleanKeys(self.options,
                                    self.client,
                                    xpath = "%s/%s/" % (xpath, xdir),
                                    force = True)('discover')

    def _create_user_policies(self, users, teams):
        """
        (Re)write the ``corp::users`` policy.

        * cubbyhole gets ``deny``.
        * ``corp/share`` is read/write, its ``.motd`` read-only.
        * The ``{{identity.entity.name}}`` templated ``private`` and
          ``share`` areas get full access, their ``.motd`` read-only.
        * For every user in *users* and team in *teams*, the ``private``
          paths are declared with an empty capability list and the
          ``share`` paths get read (data) and read/list (metadata).
        """
        paths = {}
        paths["cubbyhole/*"] = {'capabilities': ['deny']}

        paths["secret/metadata"] = {'capabilities': ['list']}
        paths["secret/metadata/*"] = {'capabilities': []}
        paths["secret/metadata/corp/*"] = {'capabilities': ['list']}

        paths["secret/data/corp/share/*"] = {'capabilities':
                                             ['read',
                                              'create',
                                              'update',
                                              'delete']}
        paths["secret/metadata/corp/share/*"] = {'capabilities':
                                                 ['list',
                                                  'delete']}

        paths["secret/metadata/corp/share/.motd"] = {'capabilities': ['read']}
        paths["secret/data/corp/share/.motd"] = {'capabilities': ['read']}

        for xdir in ('private', 'share'):
            paths["secret/data/corp/users/{{identity.entity.name}}/%s/*" % xdir] = {'capabilities':
                                                                                    ['read',
                                                                                     'create',
                                                                                     'update',
                                                                                     'delete']}
            paths["secret/metadata/corp/users/{{identity.entity.name}}/%s/*" % xdir] = {'capabilities': ['read',
                                                                                                         'list',
                                                                                                         'delete']}
            paths["secret/data/corp/users/{{identity.entity.name}}/%s/.motd" % xdir] = {'capabilities': ['read']}
            paths["secret/metadata/corp/users/{{identity.entity.name}}/%s/.motd" % xdir] = {'capabilities': ['read']}

        for user in users:
            paths["secret/data/corp/users/%s/private/*" % user] = {'capabilities': []}
            paths["secret/metadata/corp/users/%s/private/*" % user] = {'capabilities': []}
            paths["secret/data/corp/users/%s/share/*" % user] = {'capabilities': ['read']}
            paths["secret/metadata/corp/users/%s/share/*" % user] = {'capabilities': ['read',
                                                                                      'list']}

        for team in teams.keys():
            paths["secret/data/corp/teams/%s/private/*" % team] = {'capabilities': []}
            paths["secret/metadata/corp/teams/%s/private/*" % team] = {'capabilities': []}
            paths["secret/data/corp/teams/%s/share/*" % team] = {'capabilities': ['read']}
            paths["secret/metadata/corp/teams/%s/share/*" % team] = {'capabilities': ['read',
                                                                                      'list']}

        self.client.sys.create_or_update_policy("corp::users", {'path': paths})

    def _create_teams_area(self, teams):
        """
        For every team, (re)write the ``grp::<group>`` policy and recreate
        the ``corp/teams/<team>/{private,share}/.motd`` secrets.

        The policy gives full access to the team's ``private`` and ``share``
        areas, with ``.motd`` read-only.
        """
        for team, grpname in teams.items():
            paths = {}

            for xdir in ('private', 'share'):
                paths["secret/data/corp/teams/%s/%s/*" % (team, xdir)] = {'capabilities':
                                                                          ['read',
                                                                           'create',
                                                                           'update',
                                                                           'delete']}
                paths["secret/metadata/corp/teams/%s/%s/*" % (team, xdir)] = {'capabilities':
                                                                              ['read',
                                                                               'list',
                                                                               'delete']}
                paths["secret/data/corp/teams/%s/%s/.motd" % (team, xdir)] = {'capabilities': ['read']}
                paths["secret/metadata/corp/teams/%s/%s/.motd" % (team, xdir)] = {'capabilities': ['read']}

            self.client.sys.create_or_update_policy("grp::%s" % grpname, {'path': paths})

            self.client.secrets.kv.delete_metadata_and_all_versions("corp/teams/%s/private/.motd" % team)
            self.client.secrets.kv.create_or_update_secret("corp/teams/%s/private/.motd" % team,
                                                           {'msg': "private area for team: %s" % team})

            self.client.secrets.kv.delete_metadata_and_all_versions("corp/teams/%s/share/.motd" % team)
            self.client.secrets.kv.create_or_update_secret("corp/teams/%s/share/.motd" % team,
                                                           {'msg': "share what you want"})

    def do_discover(self):
        """
        Reconcile every userpass user with the identity store and KV areas.

        For each ``<localpart>-at-<organization>`` user that has policies:
          * upsert its entity with team metadata; ORGANIZATION_DOMAIN users
            without the ``admin`` policy also get the ``corp::users`` policy;
          * add the entity to the groups named by its ``grp::*`` policies;
          * upsert its userpass alias;
          * for ORGANIZATION_DOMAIN users, provision the KV area.

        Then rewrite the groups, the ``corp::users`` policy, the team areas
        and the SSH key mirror, and purge the KV areas of domain users that
        no longer exist in userpass.

        Raises FdThalliumAdminExit(1) when the userpass mount or the
        users/groups/entities listings cannot be fetched.
        """
        auth_info = self._fetch_auth_info('userpass')
        if not auth_info:
            LOG.error("unable to fetch userpass info")
            raise FdThalliumAdminExit(1)

        grpslist  = self._fetch_groups()
        entities  = self._fetch_user_entities()
        userlist  = self._fetch_users()

        LOG.info("discover — %d users, %d groupes, %d entités",
                 len(userlist), len(grpslist), len(entities))

        # Group memberships are rebuilt from scratch from the users' policies.
        for grpvalues in grpslist.values():
            grpvalues['member_entity_ids'] = []

        users          = []     # domain users fully processed in this run
        existing_users = set()  # every domain user present in userpass
        skipped        = 0

        for username in userlist:
            organization = self._user_organization(username)
            if organization is None:
                # Cannot be split into localpart/organization: skip it rather
                # than aborting the whole reconciliation.
                LOG.warning("unexpected username format, skipped: %r", username)
                skipped += 1
                continue

            if organization == ORGANIZATION_DOMAIN:
                existing_users.add(username)

            userinfo     = self.client.auth.userpass.read_user(username)
            userpolicies = []
            if userinfo and userinfo.get('data'):
                userpolicies = (userinfo['data'].get('token_policies')
                                or userinfo['data'].get('policies')
                                or [])

            if not userpolicies:
                LOG.error("missing policy for username: %r", username)
                skipped += 1
                continue

            entity_id = None
            metadata  = {}
            if username in entities:
                entity_id = entities[username]['id']
                metadata  = (entities[username]['metadata'] or {}).copy()

            metadata['organization'] = organization

            policies = []
            if 'admin' not in userpolicies and organization == ORGANIZATION_DOMAIN:
                policies.append('corp::users')

            # Teams only depend on the policies, so the entity can be written
            # once with its final metadata.
            grpnames, userteams = self._user_group_memberships(userpolicies, grpslist)
            self._apply_team_metadata(metadata, userteams)

            rs = self._upsert_entity(username,
                                     metadata = metadata,
                                     policies = policies,
                                     xid      = entity_id)

            if not entity_id:
                if rs:
                    entities[username] = rs
                    LOG.info("entity créée : %r", username)
                else:
                    LOG.error("unable to create entity: %r", username)
                    skipped += 1
                    continue
            else:
                LOG.debug("entity mise à jour : %r", username)

            entity_id   = entities[username]['id']
            ref_aliases = entities[username].get('aliases') or {}

            for grpname in grpnames:
                grpslist[grpname]['member_entity_ids'].append(entity_id)

            alias_id = None
            if username in ref_aliases:
                alias_id = ref_aliases[username]['id']

            rs = self._upsert_entity_alias(username,
                                           canonical_id   = entity_id,
                                           mount_accessor = auth_info['accessor'],
                                           xid            = alias_id)

            if not rs and not alias_id:
                LOG.debug("alias already exists or skipped for %r", username)

            if organization == ORGANIZATION_DOMAIN:
                users.append(username)
                self._create_user_area(username)
                LOG.debug("area KV provisionnée : %r", username)

        LOG.info("users traités : %d, ignorés : %d", len(users), skipped)

        self._update_groups(grpslist)

        teams = self._get_teams(grpslist)

        self._create_user_policies(users, teams)
        LOG.info("policy corp::users régénérée (%d users, %d équipes)", len(users), len(teams))

        self._create_teams_area(teams)
        LOG.info("areas équipes mises à jour : %s", ', '.join(teams.keys()) or 'aucune')

        self.client.secrets.kv.delete_metadata_and_all_versions("corp/share/.motd")
        self.client.secrets.kv.create_or_update_secret("corp/share/.motd",
                                                       {'msg': "share what you want"})

        self._sync_ssh_keys(users)
        LOG.info("clés SSH synchronisées (%d)", len(self._ssh_keys['public']))

        # Purge only areas of users that are gone from userpass: a user that
        # was skipped above because of an error keeps its private data.
        self._clean_user_area(existing_users)
        LOG.info("discover terminé")

    def do_export_ssh_public_keys(self):
        """
        Print one JSON line per exportable mirrored SSH public key.

        Each line is ``{"<user>@corp::team::<main-corp-team>": <public key>}``.
        Only users whose entity has ``main-corp-team`` metadata are exported.

        Raises FdThalliumAdminExit(1) when the key listing is malformed.
        """
        try:
            rs = self.client.secrets.kv.list_secrets("si/ssh-keys/users")
        except hvac.exceptions.InvalidPath as e:
            LOG.info("Aucune clé SSH à exporter")
            LOG.debug("No ssh keys found when listing: %s", str(e))
            return

        if not rs or 'data' not in rs or 'keys' not in rs['data']:
            LOG.error("unable to fetch users ssh public keys")
            raise FdThalliumAdminExit(1)

        entities  = self._fetch_user_entities()
        exported  = 0

        for user in rs['data']['keys']:
            if user not in entities \
               or not entities[user].get('metadata') \
               or 'main-corp-team' not in entities[user]['metadata']:
                continue

            key = "%s@corp::team::%s" % (user, entities[user]['metadata']['main-corp-team'])

            try:
                ssh_keys = self.client.secrets.kv.read_secret_version(
                    "si/ssh-keys/users/%s" % user,
                    raise_on_deleted_version=True
                )
            except hvac.exceptions.InvalidPath as e:
                LOG.debug("Missing ssh key for user %s: %s", user, str(e))
                continue

            if not ssh_keys \
               or 'data' not in ssh_keys \
               or 'data' not in ssh_keys['data'] \
               or 'public' not in ssh_keys['data']['data']:
                continue

            sys.stdout.write(json.dumps({key: ssh_keys['data']['data']['public']}) + "\n")
            exported += 1

        LOG.info("%d clé(s) SSH exportée(s)", exported)

    def do_chg_pass(self):
        """
        Prompt for a new password and set it on the ``-u`` userpass user.

        An ``@`` in the user name is mapped to ``-at-``, and the password
        must be at least 13 characters long.

        Raises FdThalliumAdminExit(1) on any failure.
        """
        if not self.options.username:
            LOG.error("invalid username: %r", self.options.username)
            raise FdThalliumAdminExit(1)

        username = self.options.username.replace('@', '-at-')

        try:
            self.client.auth.userpass.read_user(username)
        except hvac.exceptions.InvalidPath:
            LOG.error("user inexistant : %r", username)
            raise FdThalliumAdminExit(1)

        passwd = getpass.getpass('Password:')
        if not passwd or len(passwd) < 13:
            LOG.error("mot de passe trop court (minimum 13 caractères)")
            raise FdThalliumAdminExit(1)

        try:
            self.client.write("auth/userpass/users/%s/password" % username, password=passwd)
        except Exception as e:
            LOG.error("impossible de changer le mot de passe pour %r : %s", username, e)
            raise FdThalliumAdminExit(1)

        LOG.info("mot de passe mis à jour pour %r", username)

    def __call__(self, action):
        """Run ``do_<action>`` (dashes in *action* become underscores)."""
        return getattr(self, "do_%s" % action.replace('-', '_'))()

_SUBCMDS.update({'user': FdThalliumUser})   # expose as the "user" sub-command


def main(options):
    """
    Main function: set up logging, authenticate and run the selected action.

    Authentication uses ``options.token`` unless it is empty or ``'-'``. In
    that case an AppRole login is done with ``options.thallium_role_id`` and
    ``options.secret_id``.

    Returns the process exit code:
      * 0 on success;
      * the code carried by FdThalliumAdminExit;
      * 2 for a missing or unknown sub-command;
      * 5 for IOError;
      * 6 for any other exception, including authentication errors;
      * 255 for other SystemExit / KeyboardInterrupt.
    """
    xformat = "%(levelname)s:%(asctime)-15s: %(message)s"
    datefmt = '%Y-%m-%d %H:%M:%S'
    logging.basicConfig(level   = options.loglevel,
                        format  = xformat,
                        datefmt = datefmt)

    # Validate the dispatch target before contacting the server, instead of
    # failing later with a KeyError traceback.
    subcmd = _SUBCMDS.get(getattr(options, 'subcommand', None))
    if subcmd is None:
        LOG.error("missing or unknown sub-command, expected one of: %s",
                  ', '.join(sorted(_SUBCMDS)))
        return 2

    rc = 0

    try:
        # Authentication is inside the try so that login failures are
        # reported through the same exit codes as the sub-commands.
        client = hvac.Client(url = options.uri)

        if options.token and options.token != '-':
            client.token = options.token
        else:
            client.auth.approle.login(role_id    = options.thallium_role_id,
                                      secret_id  = options.secret_id)

        subcmd(options, client)(options.action)
    except FdThalliumAdminExit as e:
        rc = e.code
    except (SystemExit, KeyboardInterrupt):
        rc = 255
    except IOError as e:
        rc = 5
        LOG.error(e)
    except Exception as e:
        rc = 6
        LOG.exception(e)

    return rc


if __name__ == '__main__':
    # Parse and validate the command line, then exit with main()'s status.
    _OPTIONS = argv_parse_check()
    sys.exit(main(_OPTIONS))
