/*
 * Decompiled with CFR 0.152.
 */
package org.keycloak.saml.processing.core.saml.v2.util;

import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.security.PrivateKey;
import java.security.PublicKey;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Objects;
import java.util.Set;
import javax.xml.datatype.XMLGregorianCalendar;
import javax.xml.stream.XMLEventReader;
import org.apache.xml.security.encryption.EncryptedData;
import org.keycloak.dom.saml.v1.assertion.SAML11AssertionType;
import org.keycloak.dom.saml.v1.assertion.SAML11AttributeStatementType;
import org.keycloak.dom.saml.v1.assertion.SAML11AttributeType;
import org.keycloak.dom.saml.v1.assertion.SAML11ConditionsType;
import org.keycloak.dom.saml.v1.assertion.SAML11StatementAbstractType;
import org.keycloak.dom.saml.v2.assertion.AssertionType;
import org.keycloak.dom.saml.v2.assertion.AttributeStatementType;
import org.keycloak.dom.saml.v2.assertion.AttributeType;
import org.keycloak.dom.saml.v2.assertion.BaseIDAbstractType;
import org.keycloak.dom.saml.v2.assertion.ConditionsType;
import org.keycloak.dom.saml.v2.assertion.EncryptedAssertionType;
import org.keycloak.dom.saml.v2.assertion.EncryptedElementType;
import org.keycloak.dom.saml.v2.assertion.NameIDType;
import org.keycloak.dom.saml.v2.assertion.StatementAbstractType;
import org.keycloak.dom.saml.v2.assertion.SubjectType;
import org.keycloak.dom.saml.v2.protocol.ResponseType;
import org.keycloak.rotation.HardcodedKeyLocator;
import org.keycloak.rotation.KeyLocator;
import org.keycloak.saml.common.PicketLinkLogger;
import org.keycloak.saml.common.PicketLinkLoggerFactory;
import org.keycloak.saml.common.constants.GeneralConstants;
import org.keycloak.saml.common.constants.JBossSAMLConstants;
import org.keycloak.saml.common.exceptions.ConfigurationException;
import org.keycloak.saml.common.exceptions.ParsingException;
import org.keycloak.saml.common.exceptions.ProcessingException;
import org.keycloak.saml.common.exceptions.fed.IssueInstantMissingException;
import org.keycloak.saml.common.util.DocumentUtil;
import org.keycloak.saml.common.util.StaxParserUtil;
import org.keycloak.saml.common.util.StaxUtil;
import org.keycloak.saml.processing.api.saml.v2.sig.SAML2Signature;
import org.keycloak.saml.processing.core.parsers.saml.SAMLParser;
import org.keycloak.saml.processing.core.parsers.util.SAMLParserUtil;
import org.keycloak.saml.processing.core.saml.v2.common.SAMLDocumentHolder;
import org.keycloak.saml.processing.core.saml.v2.util.XMLTimeUtil;
import org.keycloak.saml.processing.core.saml.v2.writers.SAMLAssertionWriter;
import org.keycloak.saml.processing.core.util.JAXPValidationUtil;
import org.keycloak.saml.processing.core.util.XMLEncryptionUtil;
import org.keycloak.saml.processing.core.util.XMLSignatureUtil;
import org.w3c.dom.Document;
import org.w3c.dom.Element;
import org.w3c.dom.Node;

/**
 * Static helpers for building, serializing, validating and decrypting SAML assertions, covering
 * both SAML 2.0 ({@link AssertionType}) and SAML 1.1 ({@link SAML11AssertionType}) object models.
 */
public class AssertionUtil {
    private static final PicketLinkLogger logger = PicketLinkLoggerFactory.getLogger();

    /**
     * Serializes a SAML 2.0 assertion to XML text.
     *
     * @param assertion the assertion to write
     * @return the XML text, decoded with {@link GeneralConstants#SAML_CHARSET}
     * @throws ProcessingException if the assertion cannot be written
     */
    public static String asString(AssertionType assertion) throws ProcessingException {
        return new String(writeToBytes(assertion), GeneralConstants.SAML_CHARSET);
    }

    /**
     * Serializes a SAML 2.0 assertion and parses the result back into a DOM {@link Document}.
     *
     * @param assertion the assertion to convert
     * @return a freshly parsed document containing the assertion
     * @throws ProcessingException if writing fails, or (wrapped via the logger) if parsing fails
     */
    public static Document asDocument(AssertionType assertion) throws ProcessingException {
        byte[] xml = writeToBytes(assertion);
        try {
            return DocumentUtil.getDocument(new ByteArrayInputStream(xml));
        } catch (Exception e) {
            throw logger.processingError(e);
        }
    }

    /** Writes {@code assertion} with a {@link SAMLAssertionWriter} into an in-memory byte buffer. */
    private static byte[] writeToBytes(AssertionType assertion) throws ProcessingException {
        ByteArrayOutputStream baos = new ByteArrayOutputStream();
        SAMLAssertionWriter writer = new SAMLAssertionWriter(StaxUtil.getXMLStreamWriter(baos));
        writer.write(assertion);
        return baos.toByteArray();
    }

    /**
     * Creates a SAML 1.1 assertion with the given identity data.
     *
     * @param id           assertion ID
     * @param issueInstant issue instant to record
     * @param issuer       issuer name
     * @return the new assertion
     */
    public static SAML11AssertionType createSAML11Assertion(String id, XMLGregorianCalendar issueInstant, String issuer) {
        SAML11AssertionType assertion = new SAML11AssertionType(id, issueInstant);
        assertion.setIssuer(issuer);
        return assertion;
    }

    /**
     * Creates a SAML 2.0 assertion whose issue instant is taken from {@link XMLTimeUtil#getIssueInstant()}.
     *
     * @param id     assertion ID
     * @param issuer issuer NameID
     * @return the new assertion
     */
    public static AssertionType createAssertion(String id, NameIDType issuer) {
        XMLGregorianCalendar issueInstant = XMLTimeUtil.getIssueInstant();
        AssertionType assertion = new AssertionType(id, issueInstant);
        assertion.setIssuer(issuer);
        return assertion;
    }

    /**
     * Builds a subject whose sub-type holds a single NameID carrying {@code userName}.
     *
     * @param userName value for the NameID
     * @return the new subject
     */
    public static SubjectType createAssertionSubject(String userName) {
        NameIDType nameId = new NameIDType();
        nameId.setValue(userName);
        SubjectType.STSubType subType = new SubjectType.STSubType();
        subType.addBaseID((BaseIDAbstractType) nameId);
        SubjectType assertionSubject = new SubjectType();
        assertionSubject.setSubType(subType);
        return assertionSubject;
    }

    /**
     * Creates an attribute with the given name, name format and values.
     *
     * @param name            attribute name
     * @param nameFormat      attribute name format
     * @param attributeValues values to add in order; {@code null} or empty adds none
     * @return the new attribute
     */
    public static AttributeType createAttribute(String name, String nameFormat, Object... attributeValues) {
        AttributeType att = new AttributeType(name);
        att.setNameFormat(nameFormat);
        if (attributeValues != null) {
            for (Object attributeValue : attributeValues) {
                att.addAttributeValue(attributeValue);
            }
        }
        return att;
    }

    /**
     * Sets Conditions on the assertion spanning {@code [issueInstant, issueInstant + durationInMilis)}.
     *
     * @param assertion        assertion to update; must already carry an issue instant
     * @param durationInMilis  validity length in milliseconds
     * @throws IssueInstantMissingException if the assertion has no issue instant
     */
    public static void createTimedConditions(AssertionType assertion, long durationInMilis) throws ConfigurationException, IssueInstantMissingException {
        XMLGregorianCalendar issueInstant = assertion.getIssueInstant();
        if (issueInstant == null) {
            throw new IssueInstantMissingException("PL00088: Null IssueInstant");
        }
        XMLGregorianCalendar assertionValidityLength = XMLTimeUtil.add(issueInstant, durationInMilis);
        ConditionsType conditionsType = new ConditionsType();
        conditionsType.setNotBefore(issueInstant);
        conditionsType.setNotOnOrAfter(assertionValidityLength);
        assertion.setConditions(conditionsType);
    }

    /**
     * Sets Conditions on the assertion spanning
     * {@code [issueInstant - clockSkew, issueInstant + durationInMilis + clockSkew)}.
     *
     * @param assertion       assertion to update; must already carry an issue instant
     * @param durationInMilis validity length in milliseconds
     * @param clockSkew       tolerance in milliseconds applied to both ends of the window
     * @throws IssueInstantMissingException if the assertion has no issue instant
     */
    public static void createTimedConditions(AssertionType assertion, long durationInMilis, long clockSkew) throws ConfigurationException, IssueInstantMissingException {
        XMLGregorianCalendar issueInstant = assertion.getIssueInstant();
        if (issueInstant == null) {
            throw logger.samlIssueInstantMissingError();
        }
        XMLGregorianCalendar assertionValidityLength = XMLTimeUtil.add(issueInstant, durationInMilis + clockSkew);
        ConditionsType conditionsType = new ConditionsType();
        XMLGregorianCalendar beforeInstant = XMLTimeUtil.subtract(issueInstant, clockSkew);
        conditionsType.setNotBefore(beforeInstant);
        conditionsType.setNotOnOrAfter(assertionValidityLength);
        assertion.setConditions(conditionsType);
    }

    /**
     * SAML 1.1 counterpart of {@link #createTimedConditions(AssertionType, long, long)}.
     *
     * @param assertion       assertion to update; must already carry an issue instant
     * @param durationInMilis validity length in milliseconds
     * @param clockSkew       tolerance in milliseconds applied to both ends of the window
     * @throws IssueInstantMissingException if the assertion has no issue instant
     */
    public static void createSAML11TimedConditions(SAML11AssertionType assertion, long durationInMilis, long clockSkew) throws ConfigurationException, IssueInstantMissingException {
        XMLGregorianCalendar issueInstant = assertion.getIssueInstant();
        if (issueInstant == null) {
            throw new IssueInstantMissingException("PL00088: Null IssueInstant");
        }
        XMLGregorianCalendar assertionValidityLength = XMLTimeUtil.add(issueInstant, durationInMilis + clockSkew);
        SAML11ConditionsType conditionsType = new SAML11ConditionsType();
        XMLGregorianCalendar beforeInstant = XMLTimeUtil.subtract(issueInstant, clockSkew);
        conditionsType.setNotBefore(beforeInstant);
        conditionsType.setNotOnOrAfter(assertionValidityLength);
        assertion.setConditions(conditionsType);
    }

    /**
     * Validates the signature on {@code element} against a single known public key.
     *
     * @see #isSignatureValid(Element, KeyLocator)
     */
    public static boolean isSignatureValid(Element element, PublicKey publicKey) {
        return AssertionUtil.isSignatureValid(element, new HardcodedKeyLocator(publicKey));
    }

    /**
     * Validates the signature on {@code element} using keys from {@code keyLocator}.
     *
     * @return {@code true} only if a signature is present and validates; {@code false} when there is
     *         no signature or when validation throws (the exception is logged, not propagated)
     */
    public static boolean isSignatureValid(Element element, KeyLocator keyLocator) {
        try {
            SAML2Signature.configureIdAttribute(element);
            Element signature = AssertionUtil.getSignature(element);
            if (signature != null) {
                return XMLSignatureUtil.validateSingleNode(signature, keyLocator);
            }
        } catch (Exception e) {
            logger.signatureAssertionValidationError(e);
        }
        return false;
    }

    /** Returns whether {@code element} carries a signature, as located by {@link XMLSignatureUtil#getSignature}. */
    public static boolean isSignedElement(Element element) {
        return AssertionUtil.getSignature(element) != null;
    }

    /** Locates the signature element of {@code element}; may return {@code null}. */
    protected static Element getSignature(Element element) {
        return XMLSignatureUtil.getSignature(element);
    }

    /**
     * Checks whether "now" falls outside the assertion's Conditions window, as judged by
     * {@link XMLTimeUtil#isValid}. Either bound may be absent.
     *
     * @return {@code true} if expired; {@code false} if valid or if the assertion has no Conditions
     */
    public static boolean hasExpired(AssertionType assertion) throws ConfigurationException {
        ConditionsType conditionsType = assertion.getConditions();
        if (conditionsType == null) {
            return false;
        }
        XMLGregorianCalendar now = XMLTimeUtil.getIssueInstant();
        XMLGregorianCalendar notBefore = conditionsType.getNotBefore();
        XMLGregorianCalendar notOnOrAfter = conditionsType.getNotOnOrAfter();
        if (notBefore != null) {
            logger.trace("Assertion: " + assertion.getID() + " ::Now=" + now.toXMLFormat() + " ::notBefore=" + notBefore.toXMLFormat());
        }
        if (notOnOrAfter != null) {
            logger.trace("Assertion: " + assertion.getID() + " ::Now=" + now.toXMLFormat() + " ::notOnOrAfter=" + String.valueOf(notOnOrAfter));
        }
        return isOutsideWindow(assertion.getID(), now, notBefore, notOnOrAfter);
    }

    /**
     * Like {@link #hasExpired(AssertionType)}, but widens the window by {@code clockSkewInMilis} on both
     * ends. Either bound may be absent; an absent bound is passed to the validity check as {@code null}.
     */
    public static boolean hasExpired(AssertionType assertion, long clockSkewInMilis) throws ConfigurationException {
        ConditionsType conditionsType = assertion.getConditions();
        if (conditionsType == null) {
            return false;
        }
        XMLGregorianCalendar now = XMLTimeUtil.getIssueInstant();
        XMLGregorianCalendar notBefore = conditionsType.getNotBefore();
        XMLGregorianCalendar notOnOrAfter = conditionsType.getNotOnOrAfter();
        // String.valueOf: NotBefore is optional, so it must not be dereferenced here.
        logger.trace("Now=" + now.toXMLFormat() + " ::notBefore=" + String.valueOf(notBefore) + " ::notOnOrAfter=" + String.valueOf(notOnOrAfter));
        return isOutsideWindow(assertion.getID(), now,
                subtractSkew(notBefore, clockSkewInMilis), addSkew(notOnOrAfter, clockSkewInMilis));
    }

    /** SAML 1.1 counterpart of {@link #hasExpired(AssertionType)}; either bound may be absent. */
    public static boolean hasExpired(SAML11AssertionType assertion) throws ConfigurationException {
        SAML11ConditionsType conditionsType = assertion.getConditions();
        if (conditionsType == null) {
            return false;
        }
        XMLGregorianCalendar now = XMLTimeUtil.getIssueInstant();
        XMLGregorianCalendar notBefore = conditionsType.getNotBefore();
        XMLGregorianCalendar notOnOrAfter = conditionsType.getNotOnOrAfter();
        logger.trace("Now=" + now.toXMLFormat() + " ::notBefore=" + String.valueOf(notBefore) + " ::notOnOrAfter=" + String.valueOf(notOnOrAfter));
        return isOutsideWindow(assertion.getID(), now, notBefore, notOnOrAfter);
    }

    /** SAML 1.1 counterpart of {@link #hasExpired(AssertionType, long)}; either bound may be absent. */
    public static boolean hasExpired(SAML11AssertionType assertion, long clockSkewInMilis) throws ConfigurationException {
        SAML11ConditionsType conditionsType = assertion.getConditions();
        if (conditionsType == null) {
            return false;
        }
        XMLGregorianCalendar now = XMLTimeUtil.getIssueInstant();
        XMLGregorianCalendar notBefore = conditionsType.getNotBefore();
        XMLGregorianCalendar notOnOrAfter = conditionsType.getNotOnOrAfter();
        logger.trace("Now=" + now.toXMLFormat() + " ::notBefore=" + String.valueOf(notBefore) + " ::notOnOrAfter=" + String.valueOf(notOnOrAfter));
        return isOutsideWindow(assertion.getID(), now,
                subtractSkew(notBefore, clockSkewInMilis), addSkew(notOnOrAfter, clockSkewInMilis));
    }

    /** Evaluates the validity window and logs {@code samlAssertionExpired} when it fails. */
    private static boolean isOutsideWindow(String assertionId, XMLGregorianCalendar now,
            XMLGregorianCalendar notBefore, XMLGregorianCalendar notOnOrAfter) throws ConfigurationException {
        boolean expired = !XMLTimeUtil.isValid(now, notBefore, notOnOrAfter);
        if (expired) {
            logger.samlAssertionExpired(assertionId);
        }
        return expired;
    }

    /** Moves an optional lower bound earlier by the skew; {@code null} stays {@code null}. */
    private static XMLGregorianCalendar subtractSkew(XMLGregorianCalendar bound, long clockSkewInMilis) throws ConfigurationException {
        return bound == null ? null : XMLTimeUtil.subtract(bound, clockSkewInMilis);
    }

    /** Moves an optional upper bound later by the skew; {@code null} stays {@code null}. */
    private static XMLGregorianCalendar addSkew(XMLGregorianCalendar bound, long clockSkewInMilis) throws ConfigurationException {
        return bound == null ? null : XMLTimeUtil.add(bound, clockSkewInMilis);
    }

    /**
     * Returns the assertion's NotOnOrAfter bound.
     *
     * @return the bound, or {@code null} if there are no Conditions or no upper bound
     */
    public static XMLGregorianCalendar getExpiration(AssertionType assertion) {
        ConditionsType conditionsType = assertion.getConditions();
        return conditionsType == null ? null : conditionsType.getNotOnOrAfter();
    }

    /**
     * Collects attribute values from all attribute statements of a SAML 2.0 assertion.
     *
     * @param assertion the assertion to read
     * @param roleKeys  attribute names to consider; {@code null} or empty means all attributes
     * @return the collected values (a new mutable list)
     */
    public static List<String> getRoles(AssertionType assertion, List<String> roleKeys) {
        List<String> roles = new ArrayList<String>();
        for (StatementAbstractType statement : assertion.getStatements()) {
            if (!(statement instanceof AttributeStatementType)) {
                continue;
            }
            AttributeStatementType attributeStatement = (AttributeStatementType) statement;
            for (AttributeStatementType.ASTChoiceType choice : attributeStatement.getAttributes()) {
                AttributeType attr = choice.getAttribute();
                // NOTE(review): a choice entry that is not a plain attribute presumably yields null here
                // (e.g. an encrypted attribute); skip it rather than NPE.
                if (attr == null || !isSelectedKey(roleKeys, attr.getName())) {
                    continue;
                }
                addRoleValues(attr.getAttributeValue(), roles);
            }
        }
        return roles;
    }

    /**
     * Collects attribute values from all attribute statements of a SAML 1.1 assertion.
     *
     * @param assertion the assertion to read
     * @param roleKeys  attribute names to consider; {@code null} or empty means all attributes
     * @return the collected values (a new mutable list)
     */
    public static List<String> getRoles(SAML11AssertionType assertion, List<String> roleKeys) {
        List<String> roles = new ArrayList<String>();
        for (SAML11StatementAbstractType statement : assertion.getStatements()) {
            if (!(statement instanceof SAML11AttributeStatementType)) {
                continue;
            }
            SAML11AttributeStatementType attributeStatement = (SAML11AttributeStatementType) statement;
            for (SAML11AttributeType attr : attributeStatement.get()) {
                if (!isSelectedKey(roleKeys, attr.getAttributeName())) {
                    continue;
                }
                addRoleValues(attr.get(), roles);
            }
        }
        return roles;
    }

    /** An attribute is selected when no filter is given or its name is in the filter. */
    private static boolean isSelectedKey(List<String> roleKeys, String attributeName) {
        return roleKeys == null || roleKeys.isEmpty() || roleKeys.contains(attributeName);
    }

    /**
     * Appends each value to {@code roles}: Strings as-is, DOM nodes via their first child's node value.
     * A {@code null} list adds nothing; any other value type raises {@code logger.unknownObjectType}.
     */
    private static void addRoleValues(List<?> attributeValues, List<String> roles) {
        if (attributeValues == null) {
            return;
        }
        for (Object attrValue : attributeValues) {
            if (attrValue instanceof String) {
                roles.add((String) attrValue);
            } else if (attrValue instanceof Node) {
                Node roleNode = (Node) attrValue;
                roles.add(roleNode.getFirstChild().getNodeValue());
            } else {
                throw logger.unknownObjectType(attrValue);
            }
        }
    }

    /**
     * Returns the first assertion of the response, decrypting it in place first if it is encrypted.
     *
     * @param holder       unused
     * @param responseType the response to read (mutated if decryption occurs)
     * @param privateKey   decryption key; required only when the assertion is encrypted
     * @return the (possibly decrypted) first assertion
     * @throws ProcessingException if the response has no assertion, or it is encrypted and no key is given
     */
    public static AssertionType getAssertion(SAMLDocumentHolder holder, ResponseType responseType, PrivateKey privateKey) throws ParsingException, ProcessingException, ConfigurationException {
        if (isAssertionEncrypted(responseType)) {
            if (privateKey == null) {
                throw new ProcessingException("Encryptd assertion and decrypt private key is null");
            }
            AssertionUtil.decryptAssertion(responseType, privateKey);
        }
        return responseType.getAssertions().get(0).getAssertion();
    }

    /**
     * Returns whether the response's first assertion choice is an encrypted assertion.
     *
     * @throws ProcessingException if the response contains no assertions
     */
    public static boolean isAssertionEncrypted(ResponseType responseType) throws ProcessingException {
        List<ResponseType.RTChoiceType> assertions = responseType.getAssertions();
        if (assertions.isEmpty()) {
            throw new ProcessingException("No assertion from response.");
        }
        return assertions.get(0).getEncryptedAssertion() != null;
    }

    /** Decrypts the first encrypted assertion using a single fixed private key. */
    public static Element decryptAssertion(ResponseType responseType, PrivateKey privateKey) throws ParsingException, ProcessingException, ConfigurationException {
        return AssertionUtil.decryptAssertion(responseType, (EncryptedData encryptedData) -> Collections.singletonList(privateKey));
    }

    /**
     * Decrypts the first encrypted assertion of the response, schema-validates and parses it, and
     * replaces the encrypted entry (matched by its ID attribute) with the parsed assertion.
     *
     * @return the decrypted assertion element
     * @throws ProcessingException if the response holds no encrypted assertion
     */
    public static Element decryptAssertion(ResponseType responseType, XMLEncryptionUtil.DecryptionKeyLocator decryptionKeyLocator) throws ParsingException, ProcessingException, ConfigurationException {
        Element enc = responseType.getAssertions().stream()
                .map(ResponseType.RTChoiceType::getEncryptedAssertion)
                .filter(Objects::nonNull)
                .findFirst()
                .map(EncryptedElementType::getEncryptedElement)
                .orElseThrow(() -> new ProcessingException("No encrypted assertion found."));
        String oldID = enc.getAttribute(JBossSAMLConstants.ID.get());
        // Decrypt inside a standalone document so the response's own DOM is not touched.
        Document newDoc = DocumentUtil.createDocument();
        Node importedNode = newDoc.importNode(enc, true);
        newDoc.appendChild(importedNode);
        Element decryptedDocumentElement = XMLEncryptionUtil.decryptElementInDocument(newDoc, decryptionKeyLocator);
        SAMLParser parser = SAMLParser.getInstance();
        JAXPValidationUtil.checkSchemaValidation(decryptedDocumentElement);
        AssertionType assertion = (AssertionType) parser.parse(SAMLParser.createEventReader(DocumentUtil.getNodeAsStream(decryptedDocumentElement)));
        responseType.replaceAssertion(oldID, new ResponseType.RTChoiceType(assertion));
        return decryptedDocumentElement;
    }

    /** Returns whether the first (plain) assertion's subject carries an EncryptedID. */
    public static boolean isIdEncrypted(ResponseType responseType) {
        SubjectType.STSubType subTypeElement = AssertionUtil.getSubTypeElement(responseType);
        return subTypeElement != null && subTypeElement.getEncryptedID() != null;
    }

    /**
     * Decrypts the first assertion's EncryptedID, if any, adding the resulting NameID to the subject and
     * clearing the EncryptedID. Does nothing when there is no plain assertion, subject or EncryptedID.
     */
    public static void decryptId(ResponseType responseType, XMLEncryptionUtil.DecryptionKeyLocator decryptionKeyLocator) throws ConfigurationException, ProcessingException, ParsingException {
        SubjectType.STSubType subTypeElement = AssertionUtil.getSubTypeElement(responseType);
        if (subTypeElement == null) {
            return;
        }
        EncryptedElementType encryptedID = subTypeElement.getEncryptedID();
        if (encryptedID == null) {
            return;
        }
        Element encryptedElement = encryptedID.getEncryptedElement();
        Document newDoc = DocumentUtil.createDocument();
        Node importedNode = newDoc.importNode(encryptedElement, true);
        newDoc.appendChild(importedNode);
        Element decryptedNameIdElement = XMLEncryptionUtil.decryptElementInDocument(newDoc, decryptionKeyLocator);
        XMLEventReader xmlEventReader = StaxParserUtil.getXMLEventReader(DocumentUtil.getNodeAsStream(decryptedNameIdElement));
        NameIDType nameIDType = SAMLParserUtil.parseNameIDType(xmlEventReader);
        subTypeElement.addBaseID((BaseIDAbstractType) nameIDType);
        subTypeElement.setEncryptedID(null);
    }

    /**
     * Returns the subject sub-type of the response's first assertion, or {@code null} when there are no
     * assertions, the first choice holds no plain (decrypted) assertion, or the assertion has no subject.
     */
    private static SubjectType.STSubType getSubTypeElement(ResponseType responseType) {
        List<ResponseType.RTChoiceType> assertions = responseType.getAssertions();
        if (assertions.isEmpty()) {
            return null;
        }
        AssertionType assertion = assertions.get(0).getAssertion();
        // An encrypted first choice has no plain assertion; treat it as "no subject" instead of NPE.
        if (assertion == null) {
            return null;
        }
        SubjectType subject = assertion.getSubject();
        return subject == null ? null : subject.getSubType();
    }
}

