/** 
 * XMLSec library
 *
 * RSA algorithm support
 * 
 * See Copyright for the status of this software.
 * 
 * Author: Aleksey Sanin <aleksey@aleksey.com>
 */
#include <stdlib.h>
#include <string.h>

#include <openssl/objects.h>
#include <openssl/sha.h>

#include <xmlsec/xmlsec.h>
#include <xmlsec/base64.h>
#include <xmlsec/bn.h>

#include <xmlsec/rsa.h>

/**
 * RSA-SHA1 transform context.
 *
 * Lives in the same xmlMalloc() block as its owning xmlSecBinTransform,
 * placed directly after the transform structure (see
 * xmlSecRsaSha1TransformCreate). That block is over-allocated by
 * digestSize bytes so digest[] can hold a full RSA signature.
 */
typedef struct _xmlSecRsaContext {
    RSA			*rsa;		/* private copy of the key's n, e and d; freed by the transform's destroy callback */
    SHA_CTX		sha1;		/* running SHA-1 over the transform's input */	
    size_t		digestLen;	/* number of valid bytes currently in digest[] */
    size_t		digestSize;	/* usable size of digest[]: RSA_size() of the key */	
    unsigned char	digest[1]; /* must be the last one! signing: receives the signature; verifying: holds the decoded signature to check */
} xmlSecRsaSha1Context, *xmlSecRsaSha1ContextPtr;

/*
 * Destroy callback installed by xmlSecRsaKeyCreate(): releases everything
 * the key owns (its OpenSSL RSA object and its name) and then the key
 * structure itself. A NULL key is ignored.
 */
static void
xmlSecRsaKeyDestroy(xmlSecRsaKeyPtr ptr) {
    if(ptr != NULL) {
        if(ptr->rsa != NULL) {
            RSA_free(ptr->rsa);
        }
        if(ptr->name != NULL) {
            xmlFree(ptr->name);
        }
        xmlFree(ptr);
    }
#ifdef DEBUG_XMLSEC
    else {
        xmlGenericError(xmlGenericErrorContext,
            "xmlSecRsaKeyDestroy: ptr is null\n");
    }
#endif
}

/**
 * xmlSecRsaKeyCreate:
 * @name: optional key name; a private copy is stored. May be NULL.
 *
 * Allocates a zero-initialized RSA key tagged with the RSA-SHA1 signature
 * algorithm, with xmlSecRsaKeyDestroy as its destroy callback and a fresh,
 * empty OpenSSL RSA object.
 *
 * Returns the new key, or NULL if any allocation fails. The key is released
 * with xmlSecKeyDestroy(), as the error paths below do.
 */
xmlSecRsaKeyPtr
xmlSecRsaKeyCreate(const xmlChar* name) {
    xmlSecRsaKeyPtr ptr;

    ptr = (xmlSecRsaKeyPtr) xmlMalloc(sizeof(xmlSecRsaKey));
    if (ptr == NULL) {
#ifdef DEBUG_XMLSEC
        xmlGenericError(xmlGenericErrorContext,
	    "xmlSecRsaKeyCreate: malloc failed\n");
#endif 	    
	return(NULL);
    }
    memset(ptr, 0, sizeof(xmlSecRsaKey));
    ptr->algorithm = xmlSecSignRsaSha1;
    ptr->size = sizeof(xmlSecRsaKey);
    /* set before anything can fail so xmlSecKeyDestroy() can clean up */
    ptr->destroyCallback = (xmlSecKeyDestroyCallback)xmlSecRsaKeyDestroy;

    if(name != NULL) {
	ptr->name = xmlStrdup(name);
	if(ptr->name == NULL) {
	    /* don't silently hand back an unnamed key when a name was requested */
#ifdef DEBUG_XMLSEC
	    xmlGenericError(xmlGenericErrorContext,
		"xmlSecRsaKeyCreate: failed to copy key name\n");
#endif 	    
	    xmlSecKeyDestroy((xmlSecKeyPtr)ptr);
	    return(NULL);
	}
    }

    /* create new (empty) RSA object; its components are filled in later */
    ptr->rsa = RSA_new();
    if(ptr->rsa == NULL) {
#ifdef DEBUG_XMLSEC
        xmlGenericError(xmlGenericErrorContext,
	    "xmlSecRsaKeyCreate: failed to create RSA object\n");
#endif 	    
	xmlSecKeyDestroy((xmlSecKeyPtr)ptr);
	return(NULL);	
    }    
    return(ptr);    
}

/**
 * http://www.w3.org/TR/xmldsig-core/#sec-RSAKeyValue
 * The RSAKeyValue Element
 *
 * RSA key values have two fields: Modulus and Exponent.
 *
 * <RSAKeyValue>
 *   <Modulus>xA7SEU+e0yQH5rm9kbCDN9o3aPIo7HbP7tX6WOocLZAtNfyxSZDU16ksL6W
 *     jubafOqNEpcwR3RdFsT7bCqnXPBe5ELh5u4VEy19MzxkXRgrMvavzyBpVRgBUwUlV
 *   	  5foK5hhmbktQhyNdy/6LpQRhDUDsTvK+g9Ucj47es9AQJ3U=
 *   </Modulus>
 *   <Exponent>AQAB</Exponent>
 * </RSAKeyValue>
 *
 * Arbitrary-length integers (e.g. "bignums" such as RSA moduli) are 
 * represented in XML as octet strings as defined by the ds:CryptoBinary type.
 *
 * Schema Definition:
 * 
 * <element name="RSAKeyValue" type="ds:RSAKeyValueType"/>
 * <complexType name="RSAKeyValueType">
 *   <sequence>
 *     <element name="Modulus" type="ds:CryptoBinary"/> 
 *     <element name="Exponent" type="ds:CryptoBinary"/>
 *   </sequence>
 * </complexType>
 *
 * DTD Definition:
 * 
 * <!ELEMENT RSAKeyValue (Modulus, Exponent) > 
 * <!ELEMENT Modulus (#PCDATA) >
 * <!ELEMENT Exponent (#PCDATA) >
 *
 * ============================================================================
 * 
 * To support reading/writing private keys, an optional PrivateExponent element
 * (the RSA private exponent d, encoded like Modulus and Exponent) may follow
 * the Exponent element.
 **/ 
xmlSecRsaKeyPtr
xmlSecRsaKeyRead(xmlDocPtr doc, xmlNodePtr rsaKeyValueNode, 
		 const xmlChar* name) {
    xmlSecRsaKeyPtr key;
    xmlNodePtr node;

    if(rsaKeyValueNode == NULL) {
#ifdef DEBUG_XMLSEC
        xmlGenericError(xmlGenericErrorContext,
	    "xmlSecRsaKeyRead: cur is null\n");
#endif 	    
	return(NULL);
    }
    node = xmlSecGetNextElementNode(rsaKeyValueNode->children);

    key = xmlSecRsaKeyCreate(name);
    if(key == NULL) {
#ifdef DEBUG_XMLSEC
        xmlGenericError(xmlGenericErrorContext,
	    "xmlSecRsaKeyRead: failed to create xmlSecRsaKey \n");
#endif 	    
	return(NULL);
    }

    /* Modulus: required, must be the first child element */
    if((node == NULL) || !xmlSecCheckNodeName(doc, node, BAD_CAST "Modulus", xmlDSigNs)) {
	xmlGenericError(xmlGenericErrorContext,
	    "xmlSecRsaKeyRead: required element \"Modulus\" missed\n");
	goto onError;
    }
    if(xmlSecNodeGetBNValue(node, &(key->rsa->n)) == NULL) {
#ifdef DEBUG_XMLSEC    
	xmlGenericError(xmlGenericErrorContext,
	    "xmlSecRsaKeyRead: failed to convert element \"Modulus\" value\n");
#endif	    
	goto onError;
    }

    /* Exponent: required, must directly follow Modulus */
    node = xmlSecGetNextElementNode(node->next);
    if((node == NULL) || !xmlSecCheckNodeName(doc, node, BAD_CAST "Exponent", xmlDSigNs)) {
	xmlGenericError(xmlGenericErrorContext,
	    "xmlSecRsaKeyRead: required element \"Exponent\" missed\n");
	goto onError;
    }
    if(xmlSecNodeGetBNValue(node, &(key->rsa->e)) == NULL) {
#ifdef DEBUG_XMLSEC    
	xmlGenericError(xmlGenericErrorContext,
	    "xmlSecRsaKeyRead: failed to convert element \"Exponent\" value\n");
#endif	    
	goto onError;
    }

    /* PrivateExponent: optional extension; its presence marks a private key */
    node = xmlSecGetNextElementNode(node->next);
    if((node != NULL) && xmlSecCheckNodeName(doc, node, BAD_CAST "PrivateExponent", xmlDSigNs)) {
	if(xmlSecNodeGetBNValue(node, &(key->rsa->d)) == NULL) {
#ifdef DEBUG_XMLSEC    
	    xmlGenericError(xmlGenericErrorContext,
		"xmlSecRsaKeyRead: failed to convert element \"PrivateExponent\" value\n");
#endif	    
	    goto onError;
	}
	key->privateKey = 1;
	node = xmlSecGetNextElementNode(node->next);
    }

    /* nothing else is allowed after the recognized children */
    if(node == NULL) {
	return(key);
    }
    xmlGenericError(xmlGenericErrorContext,
	"xmlSecRsaKeyRead: unexpected node found\n");

onError:
    xmlSecKeyDestroy((xmlSecKeyPtr)key);
    return(NULL);
}

/**
 * xmlSecRsaKeyWrite:
 * @key: the RSA key to serialize.
 * @doc: the document owning @parent (used for node name checks).
 * @parent: the RSAKeyValue element to fill in.
 * @privateKey: non-zero to also write PrivateExponent (requires the key's
 *              private exponent d); zero to write only the public part.
 *
 * Writes the Modulus, Exponent and (optionally) PrivateExponent child
 * elements of @parent, reusing existing children with the expected names
 * and creating missing ones. Every element child left after the written
 * ones (including a stale PrivateExponent when writing a public key) is
 * unlinked and freed.
 *
 * Returns 0 on success or -1 on failure.
 */
int
xmlSecRsaKeyWrite(const xmlSecRsaKeyPtr key, xmlDocPtr doc, xmlNodePtr parent, 
		 int privateKey) {
    xmlNodePtr cur;
    xmlNodePtr prev;
    int ret;
    
    if((key == NULL) || (key->rsa == NULL) || (parent == NULL)) {
#ifdef DEBUG_XMLSEC
        xmlGenericError(xmlGenericErrorContext,
	    "xmlSecRsaKeyWrite: key or parent is null \n");
#endif 	    
	return(-1);	
    }
    /*
     * Never hand a NULL BIGNUM to xmlSecNodeSetBNValue(), and report this
     * common failure before @parent is modified.
     */
    if((key->rsa->n == NULL) || (key->rsa->e == NULL) ||
       (privateKey && (key->rsa->d == NULL))) {
#ifdef DEBUG_XMLSEC
        xmlGenericError(xmlGenericErrorContext,
	    "xmlSecRsaKeyWrite: key is missing a required component\n");
#endif 	    
	return(-1);	
    }
    cur = xmlSecGetNextElementNode(parent->children);
    
    /* first is Modulus node */
    if((cur == NULL) || (!xmlSecCheckNodeName(doc, cur, BAD_CAST "Modulus", xmlDSigNs))) {
	cur = xmlSecNewSibling(parent, cur, NULL, BAD_CAST "Modulus");
	if(cur == NULL) {
#ifdef DEBUG_XMLSEC
    	    xmlGenericError(xmlGenericErrorContext,
		"xmlSecRsaKeyWrite: failed create new node\n");
#endif 	    
	    return(-1);	
	}
    }
    ret = xmlSecNodeSetBNValue(cur, key->rsa->n, 1);
    if(ret < 0) {
#ifdef DEBUG_XMLSEC    
	xmlGenericError(xmlGenericErrorContext,
	    "xmlSecRsaKeyWrite: failed to convert element \"Modulus\" value\n");
#endif	    
	return(-1);
    }    
    prev = cur;
    cur = xmlSecGetNextElementNode(cur->next);

    /* next is Exponent node. */
    if((cur == NULL) || (!xmlSecCheckNodeName(doc, cur, BAD_CAST "Exponent", xmlDSigNs))) {
	cur = xmlSecNewSibling(parent, prev, NULL, BAD_CAST "Exponent");
	if(cur == NULL) {
#ifdef DEBUG_XMLSEC
    	    xmlGenericError(xmlGenericErrorContext,
		"xmlSecRsaKeyWrite: failed create new node\n");
#endif 	    
	    return(-1);	
	}
    }
    ret = xmlSecNodeSetBNValue(cur, key->rsa->e, 1);
    if(ret < 0) {
#ifdef DEBUG_XMLSEC    
	xmlGenericError(xmlGenericErrorContext,
	    "xmlSecRsaKeyWrite: failed to convert element \"Exponent\" value\n");
#endif	    
	return(-1);
    }
    prev = cur;
    cur = xmlSecGetNextElementNode(cur->next);

    /* next is PrivateExponent node, written only for private keys */
    if(privateKey) {
	if((cur == NULL) || (!xmlSecCheckNodeName(doc, cur, BAD_CAST "PrivateExponent", xmlDSigNs))) {
	    cur = xmlSecNewSibling(parent, prev, NULL, BAD_CAST "PrivateExponent");    
	    if(cur == NULL) {
#ifdef DEBUG_XMLSEC
    		xmlGenericError(xmlGenericErrorContext,
	    	    "xmlSecRsaKeyWrite: failed create new node\n");
#endif 	    
		return(-1);	
	    }
	}
	ret = xmlSecNodeSetBNValue(cur, key->rsa->d, 1);
	if(ret < 0) {
#ifdef DEBUG_XMLSEC    
	    xmlGenericError(xmlGenericErrorContext,
		"xmlSecRsaKeyWrite: failed to convert element \"PrivateExponent\" value\n");
#endif	    
	    return(-1);
	}
	cur = xmlSecGetNextElementNode(cur->next);
    }

    /*
     * Remove every element left after the ones written above. For a public
     * key this also drops a stale PrivateExponent so no private material is
     * left in the output.
     */
    while(cur != NULL) {
	xmlNodePtr next;

	next = xmlSecGetNextElementNode(cur->next);
	xmlUnlinkNode(cur);
	xmlFreeNode(cur);
	cur = next;
    }    
    return(0);   
}

/**
 * RSA-SHA1 binary transform
 */							 
/*
 * Finalizes the running SHA-1 hash of @ctx, then either:
 *  - sign != 0: RSA-signs it, storing the signature in ctx->digest and its
 *    length in ctx->digestLen;
 *  - sign == 0: verifies the signature already held in ctx->digest /
 *    ctx->digestLen against it.
 *
 * Returns 0 on success and -1 on failure, including a signature that does
 * not verify. Whenever the arguments are valid the SHA-1 context is
 * finalized, so it must not be fed or calculated again afterwards.
 */
static int
xmlSecRsaSha1ContextCalculate(xmlSecRsaSha1ContextPtr ctx, int sign) {
    unsigned char buf[SHA_DIGEST_LENGTH];
    unsigned int sigLen = 0;
    int ret;

    if((ctx == NULL) || (ctx->rsa == NULL)) {
#ifdef DEBUG_XMLSEC
	xmlGenericError(xmlGenericErrorContext,
	    "xmlSecRsaSha1ContextCalculate: ctx is null\n");
#endif	    
	return(-1);
    }
    
    /* first of all finalize SHA1 */
    SHA1_Final(buf, &(ctx->sha1));

    if(sign) {
	/* RSA_sign() writes up to RSA_size() bytes; make sure they fit */
	if(ctx->digestSize < (size_t)RSA_size(ctx->rsa)) {
#ifdef DEBUG_XMLSEC
	    xmlGenericError(xmlGenericErrorContext,
		"xmlSecRsaSha1ContextCalculate: signature buffer is too small\n");
#endif 	    
	    return(-1);
	}
	/*
	 * RSA_sign() reports the length through an unsigned int*, which is not
	 * interchangeable with size_t* (ctx->digestLen) on LP64 platforms.
	 */
	ret = RSA_sign(NID_sha1, buf, SHA_DIGEST_LENGTH, 
		       ctx->digest, &sigLen, ctx->rsa);
	if(ret != 1) {
#ifdef DEBUG_XMLSEC
	    xmlGenericError(xmlGenericErrorContext,
		"xmlSecRsaSha1ContextCalculate: RSA sign failed\n");
#endif 	    
	    return(-1);	    
	}
	ctx->digestLen = sigLen;
    } else {
	ret = RSA_verify(NID_sha1, buf, SHA_DIGEST_LENGTH, 
		         ctx->digest, (unsigned int)ctx->digestLen, ctx->rsa);
	if(ret != 1) { 
#ifdef DEBUG_XMLSEC
	    xmlGenericError(xmlGenericErrorContext,
		"xmlSecRsaSha1ContextCalculate: RSA verify failed\n");
#endif 	    
	    return(-1);
	} 
    }
    return(0);
}    

static int
/*
 * Read callback: drains the previous transform through SHA-1, then signs
 * (ptr->encode != 0) or verifies the hash via
 * xmlSecRsaSha1ContextCalculate() and copies ctx->digest (the signature
 * bytes) into @buf.
 *
 * @buf doubles as scratch space while draining the previous transform and
 * must be able to hold the whole signature (ctx->digestLen bytes).
 *
 * Returns the number of bytes placed in @buf, 0 if the transform was
 * already finalized (no more data), or -1 on error.
 */
static int
xmlSecRsaSha1TransformRead(xmlSecBinTransformPtr ptr, unsigned char *buf, size_t len) {
    xmlSecRsaSha1ContextPtr ctx;
    int ret;

    if((ptr == NULL) || (ptr->data == NULL) || (buf == NULL)) {
#ifdef DEBUG_XMLSEC
	xmlGenericError(xmlGenericErrorContext,
	    "xmlSecRsaSha1TransformRead: ptr, ptr->data or buf is null\n");
#endif	    
	return(-1);
    }

    /* if we already called Final then nothing to read more! */
    if(ptr->finalized) {
	return(0);
    }
    
    ctx = (xmlSecRsaSha1ContextPtr)(ptr->data);
    /* the signature covers all input, so read everything first */
    do {
	ret = xmlSecBinTransformRead(ptr->prev, buf, len);
        if(ret < 0) {
#ifdef DEBUG_XMLSEC
	    xmlGenericError(xmlGenericErrorContext,
		"xmlSecRsaSha1TransformRead: prev read failed\n");
#endif 	    
	    return(-1);
	} else if(ret > 0) {
	    SHA1_Update(&(ctx->sha1), buf, ret);
	}
    } while(ret > 0);
    
    /* ret == 0 i.e. there is no more data */
    ptr->finalized = 1;
    ret = xmlSecRsaSha1ContextCalculate(ctx, ptr->encode);
    if(ret < 0) {
#ifdef DEBUG_XMLSEC
	xmlGenericError(xmlGenericErrorContext,
	    "xmlSecRsaSha1TransformRead: signature is NOT verified or NOT written\n");
#endif	    
	return(-1);
    }

    if(len < ctx->digestLen) {
#ifdef DEBUG_XMLSEC
	xmlGenericError(xmlGenericErrorContext,
	    "xmlSecRsaSha1TransformRead: buf is too small\n");
#endif	    
	return(-1);
    }
    memcpy(buf, ctx->digest, ctx->digestLen);
    /* report the bytes delivered; returning 0 would signal "no data" and drop them */
    return((int)ctx->digestLen);
}


/*
 * Write callback: feeds @len bytes from @buf into the running SHA-1 hash.
 * Empty writes, and writes arriving after the transform was finalized, are
 * accepted and ignored.
 *
 * Returns @len when the data was hashed, 0 when it was ignored, or -1 if
 * the transform or its context is missing.
 */
static int
xmlSecRsaSha1TransformWrite(xmlSecBinTransformPtr ptr, const unsigned char *buf, size_t len) {
    xmlSecRsaSha1ContextPtr ctx;

    if((ptr != NULL) && (ptr->data != NULL)) {
	/* nothing to hash, or Final was already called */
	if((buf == NULL) || (len == 0) || ptr->finalized) {
	    return(0);
	}
	ctx = (xmlSecRsaSha1ContextPtr)(ptr->data);
	SHA1_Update(&(ctx->sha1), buf, len);
	return(len);
    }

#ifdef DEBUG_XMLSEC
    xmlGenericError(xmlGenericErrorContext,
	"xmlSecRsaSha1TransformWrite: ptr, ptr->data is null\n");
#endif	    
    return(-1);
}


/*
 * Flush callback: signs (ptr->encode != 0) or verifies the data hashed so
 * far, writes ctx->digest (the signature bytes) to the next transform and
 * flushes it. Only the first flush does work; later calls return 0.
 *
 * Returns 0 on success or -1 on failure.
 */
static int
xmlSecRsaSha1TransformFlush(xmlSecBinTransformPtr ptr) {
    xmlSecRsaSha1ContextPtr ctx;
    int ret;
    
    if((ptr == NULL) || (ptr->data == NULL)) {
#ifdef DEBUG_XMLSEC
	xmlGenericError(xmlGenericErrorContext,
	    "xmlSecRsaSha1TransformFlush: ptr, ptr->data is null\n");
#endif	    
	return(-1);
    }

    /* if we already called Final then nothing to flush more! */
    if(ptr->finalized) {
	return(0);
    }

    ctx = (xmlSecRsaSha1ContextPtr)(ptr->data);

    /*
     * Mark finalized before calculating (as the read callback does): the
     * SHA-1 context is finalized by the calculation, so a second flush must
     * not run it again or emit the signature twice.
     */
    ptr->finalized = 1;
    ret = xmlSecRsaSha1ContextCalculate(ctx, ptr->encode);
    if(ret < 0) {
#ifdef DEBUG_XMLSEC
	xmlGenericError(xmlGenericErrorContext,
	    "xmlSecRsaSha1TransformFlush: signature is NOT verified or NOT written\n");
#endif	    
	return(-1);
    }
    
    ret = xmlSecBinTransformWrite(ptr->next, ctx->digest, ctx->digestLen);
    if(ret < 0) {
#ifdef DEBUG_XMLSEC    
	xmlGenericError(xmlGenericErrorContext,
	    "xmlSecRsaSha1TransformFlush: next write failed\n");
#endif 	    
	return(-1);
    }

    ret = xmlSecBinTransformFlush(ptr->next);
    if(ret < 0) {
#ifdef DEBUG_XMLSEC    
	xmlGenericError(xmlGenericErrorContext,
	    "xmlSecRsaSha1TransformFlush: next flush failed\n");
#endif 	    
	return(-1);
    } 

    return(0);
}

/*
 * Destroy callback: frees the transform's private RSA copy, then zeroes and
 * releases the single allocation holding the transform, its context and
 * the signature buffer. A transform without a context is simply freed.
 */
static void
xmlSecRsaSha1TransformDestroy(xmlSecBinTransformPtr ptr) {
    xmlSecRsaSha1ContextPtr ctx;
    size_t total;

    if(ptr == NULL) {
#ifdef DEBUG_XMLSEC
	xmlGenericError(xmlGenericErrorContext,
	    "xmlSecRsaSha1TransformDestroy: ptr is null\n");
#endif	    
	return;
    }

    ctx = (xmlSecRsaSha1ContextPtr)(ptr->data);
    if(ctx != NULL) {
	if(ctx->rsa != NULL) {
	    RSA_free(ctx->rsa);
	}
	/* same size as computed in xmlSecRsaSha1TransformCreate() */
	total = sizeof(xmlSecBinTransform) + sizeof(xmlSecRsaSha1Context) + ctx->digestSize;
	memset(ptr, 0, total);
    }
    xmlFree(ptr);
}


xmlSecBinTransformPtr
xmlSecRsaSha1TransformCreate(int encode, xmlSecRsaKeyPtr rsaKey, const xmlChar* digest) {
    xmlSecBinTransformPtr ptr;
    xmlSecRsaSha1ContextPtr ctx;
    size_t size;
    size_t digestSize = 0;
    int ret;

    if((rsaKey == NULL) || (rsaKey->rsa == NULL)) {
#ifdef DEBUG_XMLSEC
	xmlGenericError(xmlGenericErrorContext,
	    "xmlSecRsaSha1TransformCreate: rsaKey is null\n");
#endif	    
	return(NULL);
    }
    /*
     * Allocate a new xmlSecBinTransform  + xmlSecRsaSha1Context + 
     * space for result md and fill the fields.
     */
    if(rsaKey != NULL) {
	digestSize = RSA_size(rsaKey->rsa);
    }
    size = sizeof(xmlSecBinTransform) + sizeof(xmlSecRsaSha1Context) + digestSize;
    ptr = (xmlSecBinTransformPtr) xmlMalloc(size);
    if (ptr == NULL) {
#ifdef DEBUG_XMLSEC
        xmlGenericError(xmlGenericErrorContext,
	    "xmlSecRsaSha1TransformCreate: malloc failed\n");
#endif 	    
	return(NULL);
    }
    memset(ptr, 0, size);

    ctx = (xmlSecRsaSha1ContextPtr)(((unsigned char*)ptr) + sizeof(xmlSecBinTransform));

    ptr->algorithm = xmlSecSignRsaSha1;
    ptr->encode = encode;
    ptr->data = ctx;
    ptr->destroyCallback = (xmlSecBinTransformDestroyCallback)xmlSecRsaSha1TransformDestroy;
    ptr->readCallback = (xmlSecBinTransformReadCallback)xmlSecRsaSha1TransformRead;
    ptr->writeCallback = (xmlSecBinTransformWriteCallback)xmlSecRsaSha1TransformWrite;
    ptr->flushCallback = (xmlSecBinTransformFlushCallback)xmlSecRsaSha1TransformFlush;
    
    /* 
     * decode digest
     */
    ctx->digestSize = digestSize;
    if(!encode && digest != NULL) {
	ret = xmlSecBase64Decode(digest, ctx->digest, ctx->digestSize);
	if(ret < 0) {	
#ifdef DEBUG_XMLSEC
    	    xmlGenericError(xmlGenericErrorContext,
		"xmlSecRsaSha1TransformCreate: failed to base64 decode digest\n");
#endif 	    
	    xmlSecRsaSha1TransformDestroy(ptr);	
	    return(NULL);	    
	}	
	ctx->digestLen = ret;
    }

    /* 
     * create SHA1
     */
    SHA1_Init(&(ctx->sha1));
     
    /* 
     * now create RSA context and copy all keys. 
     * todo: do we need this or we can 
     * reuse existing structure?
     */
    ctx->rsa = RSA_new();
    if(ctx->rsa == NULL) {
#ifdef DEBUG_XMLSEC
        xmlGenericError(xmlGenericErrorContext,
	    "xmlSecRsaSha1TransformCreate: failed to create RSA context\n");
#endif 	    
	xmlSecRsaSha1TransformDestroy(ptr);
	return(NULL);
    }  
     
    if((rsaKey->rsa != NULL) && (rsaKey->rsa->n != NULL)) {
	ctx->rsa->n = BN_dup(rsaKey->rsa->n);
    } else {
	ctx->rsa->n = BN_new();
    }
    if(ctx->rsa->n == NULL) {
#ifdef DEBUG_XMLSEC
        xmlGenericError(xmlGenericErrorContext,
	    "xmlSecRsaSha1TransformCreate: failed to create n \n");
#endif 	    
	xmlSecRsaSha1TransformDestroy(ptr);
	return(NULL);
    }   

    if((rsaKey->rsa != NULL) && (rsaKey->rsa->e != NULL)) {
	ctx->rsa->e = BN_dup(rsaKey->rsa->e);
    } else {
	ctx->rsa->e = BN_new();
    }
    if(ctx->rsa->e == NULL) {
#ifdef DEBUG_XMLSEC
        xmlGenericError(xmlGenericErrorContext,
	    "xmlSecRsaSha1TransformCreate: failed to create e \n");
#endif 	    
	xmlSecRsaSha1TransformDestroy(ptr);
	return(NULL);
    }   

    if((rsaKey->rsa != NULL) && (rsaKey->rsa->d != NULL)) {
	ctx->rsa->d = BN_dup(rsaKey->rsa->d);
    } else {
	ctx->rsa->d = BN_new();
    }
    if(ctx->rsa->d == NULL) {
#ifdef DEBUG_XMLSEC
        xmlGenericError(xmlGenericErrorContext,
	    "xmlSecRsaSha1TransformCreate: failed to create d \n");
#endif 	    
	xmlSecRsaSha1TransformDestroy(ptr);
	return(NULL);
    }   

    return(ptr);
}



