#!/usr/bin/env python

import ldap
import ldap.modlist
import copy
import re
import types
from base64 import b16encode
from base64 import b16decode
from base64 import b64encode
from base64 import b64decode

from dm.common.utility.loggingManager import LoggingManager
from dm.common.utility.configurationManager import ConfigurationManager
from dm.common.exceptions.configurationError import ConfigurationError
from dm.common.exceptions.internalError import InternalError
from dm.common.exceptions.objectNotFound import ObjectNotFound
from dm.common.exceptions.authenticationError import AuthenticationError
from dm.common.exceptions.communicationError import CommunicationError
from dm.common.exceptions.invalidArgument import InvalidArgument
from dm.common.exceptions.dmException import DmException
from ldapClient import LdapClient

class LdapUserManager(LdapClient):
    """ Manage DM user entries and posixGroup entries in an LDAP directory.

    User and group DNs are built from format strings containing a single %s
    placeholder for the user/group name, e.g. 'uid=%s,ou=People,...' and
    'cn=%s,ou=Group,...'. Connection and logging helpers (getLdapClient,
    getLogger, executeLdapCall) come from LdapClient, which is not shown here.
    """

    def __init__(self, serverUrl, adminDn, adminPasswordFile, userDnFormat, groupDnFormat, minGidNumber=None):
        """
        :param serverUrl: LDAP server URL, passed through to LdapClient.
        :param adminDn: bind DN, passed through to LdapClient.
        :param adminPasswordFile: password file path, passed through to LdapClient.
        :param userDnFormat: user DN format with one %s for the username.
        :param groupDnFormat: group DN format with one %s for the group name;
            its first RDN is dropped to obtain the group search base.
        :param minGidNumber: optional floor for new group gidNumbers (int or
            numeric string); see createGroup().
        """
        LdapClient.__init__(self, serverUrl, adminDn, adminPasswordFile)
        self.userDnFormat = userDnFormat
        self.groupDnFormat = groupDnFormat
        self.minGidNumber = minGidNumber
        self.getLogger().debug('Min GID number: %s' % minGidNumber)
        # Remove first entry from the dn format to get tree base.
        # NOTE(review): a plain split on ',' assumes no escaped commas in the DN.
        self.groupBaseDn = ','.join(groupDnFormat.split(',')[1:])
        self.getLogger().debug('Group base DN: %s' % self.groupBaseDn)

    @classmethod
    def decodePasswordHash(cls, b64EncodedString):
        """ Convert a '{SHA}<base64>' userPassword value to an uppercase hex digest. """
        decodedString = b64EncodedString.replace('{SHA}', '')
        decodedString = b16encode(b64decode(decodedString)).upper()
        return decodedString

    @classmethod
    def encodePasswordHash(cls, passwordHash):
        """ Convert a hex digest to a '{SHA}<base64>' userPassword value.

        Hex digits are accepted in either case. hashlib's hexdigest() is
        lowercase, while decodePasswordHash() returns uppercase, and plain
        b16decode() rejects lowercase input.
        """
        encodedString = '{SHA}' + b64encode(b16decode(passwordHash, casefold=True))
        return encodedString

    @LdapClient.executeLdapCall
    def getUserInfo(self, username):
        """ Get user info.

        :param username: user name substituted into userDnFormat.
        :returns: the first (dn, attrs) entry of a base-scope search on the
            user DN.
        Error translation, if any, is done by LdapClient.executeLdapCall
        (defined outside this file).
        """
        userDn = self.userDnFormat % str(username)
        ldapClient = self.getLdapClient()
        resultList = ldapClient.search_s(userDn, ldap.SCOPE_BASE)
        userTuple = resultList[0]
        return userTuple

    def modifyUserInfo(self, username, attrDict):
        """ Modify existing attributes of a user entry.

        List values replace the attribute verbatim. Scalar values are stored
        as a one-element list of str. A scalar 'userPassword' is treated as a
        hex digest and encoded with encodePasswordHash() first.

        :param username: user whose entry is modified.
        :param attrDict: mapping of attribute name -> new value.
        :raises InvalidArgument: if attrDict names an attribute the entry lacks.
        """
        logger = self.getLogger()
        ldapClient = self.getLdapClient()
        userDn, userAttrs = self.getUserInfo(username)
        logger.debug('Modifying user %s attrs %s' % (username, attrDict))
        # Shallow copy is enough: values are replaced, never mutated in place.
        userAttrs2 = copy.copy(userAttrs)
        for name, value in attrDict.items():
            if name not in userAttrs2:
                raise InvalidArgument('No such attribute: %s' % name)
            if isinstance(value, list):
                userAttrs2[name] = value
            else:
                if name == 'userPassword':
                    value = self.encodePasswordHash(value)
                userAttrs2[name] = [str(value)]

        userLdif = ldap.modlist.modifyModlist(userAttrs, userAttrs2)
        ldapClient.modify_s(userDn, userLdif)

    def _findGroupEntry(self, ldapClient, groupName):
        """ Fetch a group entry by name via a base-scope search on its DN.

        :returns: (groupDn, groupTuple), where groupDn is built from
            groupDnFormat and groupTuple is the first (dn, attrs) result.
        :raises InternalError: wrapping any failure, including a missing group.
        """
        try:
            groupDn = self.groupDnFormat % groupName
            resultList = ldapClient.search_s(groupDn, ldap.SCOPE_BASE)
            groupTuple = resultList[0]
        except Exception as ex:
            raise InternalError(exception=ex)
        return groupDn, groupTuple

    def _replaceGroupMemberUids(self, ldapClient, groupDn, groupAttrs, memberUidList):
        """ Write memberUidList as the group's memberUid values (diffed against groupAttrs). """
        groupAttrs2 = copy.copy(groupAttrs)
        groupAttrs2['memberUid'] = memberUidList
        groupLdif = ldap.modlist.modifyModlist(groupAttrs, groupAttrs2)
        ldapClient.modify_s(groupDn, groupLdif)

    def createGroup(self, name):
        """ Create a posixGroup named `name` if it does not exist.

        The new gidNumber is one more than the largest of minGidNumber (if
        set) and the gidNumbers of entries directly under groupBaseDn.

        :raises InternalError: on lookup, gid computation or add failure.
        """
        logger = self.getLogger()
        ldapClient = self.getLdapClient()
        name = str(name)
        try:
            groupDn = self.groupDnFormat % name
            logger.debug('Looking for group DN: %s' % groupDn)
            # this method will throw exception if group is not found
            resultList = ldapClient.search_s(groupDn, ldap.SCOPE_BASE)
            groupTuple = resultList[0]
            logger.debug('Group %s already exists' % groupTuple[0])
            return
        except ldap.NO_SUCH_OBJECT:
            logger.debug('Group DN %s must be created' % groupDn)
        except Exception as ex:
            raise InternalError(exception=ex)

        # Determine gidNumber: scan all entries for the max value, then
        # increment it. The scan-then-add is not atomic, so concurrent callers
        # may pick the same gid. LDAP should really be configured to assign
        # gids itself and to reject invalid entries.
        try:
            logger.debug('Looking for max group id')
            resultList = ldapClient.search_s(self.groupBaseDn, ldap.SCOPE_ONELEVEL, attrlist=['gidNumber'])
            maxGid = 0
            if self.minGidNumber:
                # minGidNumber may be a numeric string (e.g. from config);
                # compare and increment as int.
                maxGid = int(self.minGidNumber)
            for result in resultList:
                gidList = result[1].get('gidNumber', [])
                gid = int(gidList[0]) if gidList else 0
                if gid > maxGid:
                    maxGid = gid
            gidNumber = str(maxGid + 1)
            logger.debug('Max GID is %s, new group id will be %s' % (maxGid, gidNumber))
        except Exception as ex:
            raise InternalError(exception=ex)

        attrs = {}
        attrs['objectclass'] = ['posixGroup', 'top']
        # List-valued like every other attribute here.
        attrs['cn'] = [name]
        attrs['gidNumber'] = [gidNumber]
        attrs['memberUid'] = []  # no initial members
        try:
            groupLdif = ldap.modlist.addModlist(attrs)
            ldapClient.add_s(groupDn, groupLdif)
        except Exception as ex:
            logger.error('Could not add group %s: %s' % (groupDn, ex))
            raise InternalError(exception=ex)

    def addUserToGroup(self, username, groupName):
        """ Add user to the group's memberUid list (no-op if already a member).

        :raises InternalError: if the group cannot be read or modified.
        """
        logger = self.getLogger()
        ldapClient = self.getLdapClient()
        username = str(username)
        groupName = str(groupName)
        groupDn, groupTuple = self._findGroupEntry(ldapClient, groupName)
        groupAttrs = groupTuple[1]
        memberUidList = groupAttrs.get('memberUid', [])
        if username in memberUidList:
            logger.debug('Group %s already contains user %s' % (groupName, username))
            return
        logger.debug('Adding user %s to group %s' % (username, groupName))
        try:
            self._replaceGroupMemberUids(ldapClient, groupDn, groupAttrs, memberUidList + [username])
        except Exception as ex:
            logger.error('Could not add user %s to group %s: %s' % (username, groupName, ex))
            raise InternalError(exception=ex)

    def deleteUserFromGroup(self, username, groupName):
        """ Remove user from the group's memberUid list (no-op if not a member).

        :raises InternalError: if the group cannot be read or modified.
        """
        logger = self.getLogger()
        ldapClient = self.getLdapClient()
        username = str(username)
        groupName = str(groupName)
        groupDn, groupTuple = self._findGroupEntry(ldapClient, groupName)
        groupAttrs = groupTuple[1]
        memberUidList = groupAttrs.get('memberUid', [])
        if username not in memberUidList:
            logger.debug('Group %s does not contain user %s' % (groupName, username))
            return
        logger.debug('Removing user %s from group %s' % (username, groupName))
        memberUidList2 = list(memberUidList)
        memberUidList2.remove(username)
        try:
            self._replaceGroupMemberUids(ldapClient, groupDn, groupAttrs, memberUidList2)
        except Exception as ex:
            logger.error('Could not remove user %s from group %s: %s' % (username, groupName, ex))
            raise InternalError(exception=ex)

    def getGroupInfo(self, groupName):
        """ Get given group info.

        :returns: the first (dn, attrs) entry of a base-scope search on the
            group DN.
        :raises InternalError: if the group cannot be read.
        """
        ldapClient = self.getLdapClient()
        groupDn, groupTuple = self._findGroupEntry(ldapClient, str(groupName))
        return groupTuple

    def setGroupUsers(self, groupName, usernameList):
        """ Replace the group's memberUid list with usernameList (as str).

        :raises InternalError: if the group cannot be read or modified.
        """
        logger = self.getLogger()
        ldapClient = self.getLdapClient()
        groupName = str(groupName)
        groupDn, groupTuple = self._findGroupEntry(ldapClient, groupName)
        groupAttrs = groupTuple[1]
        logger.debug('Setting users %s for group %s' % (usernameList, groupName))
        memberUidList = [str(username) for username in usernameList]
        try:
            self._replaceGroupMemberUids(ldapClient, groupDn, groupAttrs, memberUidList)
        except Exception as ex:
            logger.error('Could not set users %s for group %s: %s' % (usernameList, groupName, ex))
            raise InternalError(exception=ex)

#######################################################################
# Testing.

if __name__ == '__main__':
    # Manual smoke test against a live directory server.
    # NOTE: the modifyUserInfo() call below writes to the real entry for d225159.
    utility = LdapUserManager('ldaps://dmid-vm.xray.aps.anl.gov:636', 'uid=dmadmin,ou=People,o=aps.anl.gov,dc=aps,dc=anl,dc=gov', '/tmp/ldapPassword', userDnFormat='uid=%s,ou=DM,ou=People,o=aps.anl.gov,dc=aps,dc=anl,dc=gov', groupDnFormat='cn=%s,ou=DM,ou=Group,o=aps.anl.gov,dc=aps,dc=anl,dc=gov', minGidNumber=66000)
    print(utility.getGroupInfo(u's1id-test03'))
    user = utility.getUserInfo(u'd225159')
    print(user)
    utility.modifyUserInfo(u'd225159', {'homeDirectory' : '/data'})
    user = utility.getUserInfo(u'd225159')
    print(user)
    user = utility.getUserInfo(u'd65114')
    print(user)

    passwordHash = LdapUserManager.decodePasswordHash(user[1]['userPassword'][0])
    print(passwordHash)
    #print(LdapUserManager.encodePasswordHash(passwordHash))

    #utility.addUserToGroup(u'sveseli', u'id8i-test02')
    #print(utility.getGroupInfo(u'id8i-test02'))
    #utility.deleteUserFromGroup(u'sveseli', u'id8i-test02')
    #print(utility.getGroupInfo(u'id8i-test02'))