1 module hunt.jwt.JwtOpenSSL;
2 
3 import deimos.openssl.ssl;
4 import deimos.openssl.pem;
5 import deimos.openssl.rsa;
6 import deimos.openssl.hmac;
7 import deimos.openssl.err;
8 
9 import hunt.jwt.Exceptions;
10 import hunt.jwt.JwtAlgorithm;
11 
12 import hunt.logging;
13 
14 import std.conv;
15 import std.range;
16 
17 import core.stdc.stdlib : alloca;
18 import core.stdc.config : c_long;
19 
/**
 * Computes the raw (not base64-encoded) signature of `msg` using `key` and `algo`.
 *
 * Params:
 *   msg  = the bytes to sign.
 *   key  = the HMAC secret for HS256/HS384/HS512, or a PEM-encoded private key
 *          (RSA for RS*, EC for ES*) that is handed to `signShaPem`.
 *   algo = the signature algorithm; defaults to HS256.
 *
 * Returns: the raw signature bytes reinterpreted as a string; an empty string
 *          for `JwtAlgorithm.NONE`.
 *
 * Throws: SignException when the algorithm is not supported, when an HMAC step
 *         fails, or when PEM-based signing fails (`signShaPem` reports failure
 *         by returning null, which is turned into an exception here so callers
 *         never receive an empty signature by accident).
 */
string sign(string msg, string key, JwtAlgorithm algo = JwtAlgorithm.HS256) {
    ubyte[] sign;

    // HMAC-based signing; `signLen` is the digest size of `evp`.
    void sign_hs(const(EVP_MD)* evp, uint signLen) {
        sign = new ubyte[signLen];

        HMAC_CTX ctx;
        scope(exit) HMAC_CTX_reset(&ctx);
        HMAC_CTX_reset(&ctx);

        if(0 == HMAC_Init_ex(&ctx, key.ptr, cast(int)key.length, evp, null)) {
            throw new SignException("Can't initialize HMAC context.");
        }
        // Pass the length as size_t directly; a ulong cast does not convert
        // implicitly to size_t on 32-bit targets.
        if(0 == HMAC_Update(&ctx, cast(const(ubyte)*)msg.ptr, msg.length)) {
            throw new SignException("Can't update HMAC.");
        }
        if(0 == HMAC_Final(&ctx, sign.ptr, &signLen)) {
            throw new SignException("Can't finalize HMAC.");
        }
        // HMAC_Final writes back the number of bytes it produced.
        sign = sign[0 .. signLen];
    }

    // PEM-key based signing (RSA / ECDSA). signShaPem logs a warning and
    // returns null on failure; escalate that instead of returning "".
    void sign_pem(const(EVP_MD)* evp, int keyType) {
        sign = signShaPem(evp, keyType, key, msg);
        if(sign.length == 0) {
            throw new SignException("Can't sign the message with algorithm: " ~ to!string(algo));
        }
    }

version(HUNT_JWT_DEBUG) {
    trace("msg: ", msg);
    trace("key: ", key);
    trace("algo: ", algo);
}

    switch(algo) {
        case JwtAlgorithm.NONE: {
            break;
        }
        case JwtAlgorithm.HS256: {
            sign_hs(EVP_sha256(), SHA256_DIGEST_LENGTH);
            break;
        }
        case JwtAlgorithm.HS384: {
            sign_hs(EVP_sha384(), SHA384_DIGEST_LENGTH);
            break;
        }
        case JwtAlgorithm.HS512: {
            sign_hs(EVP_sha512(), SHA512_DIGEST_LENGTH);
            break;
        }

        /* RSA */
        case JwtAlgorithm.RS256: {
            sign_pem(EVP_sha256(), EVP_PKEY_RSA);
            break;
        }
        case JwtAlgorithm.RS384: {
            sign_pem(EVP_sha384(), EVP_PKEY_RSA);
            break;
        }
        case JwtAlgorithm.RS512: {
            sign_pem(EVP_sha512(), EVP_PKEY_RSA);
            break;
        }

        /* ECC */
        case JwtAlgorithm.ES256: {
            sign_pem(EVP_sha256(), EVP_PKEY_EC);
            break;
        }
        case JwtAlgorithm.ES384: {
            sign_pem(EVP_sha384(), EVP_PKEY_EC);
            break;
        }
        case JwtAlgorithm.ES512: {
            sign_pem(EVP_sha512(), EVP_PKEY_EC);
            break;
        }

        default:
            throw new SignException("Wrong algorithm: " ~ to!string(algo));
    }

    return cast(string)sign;
}
105 
// Ported from https://github.com/benmcollins/libjwt/blob/master/libjwt/jwt-openssl.c
/**
 * Signs `msg` with the PEM-encoded private key in `key` via EVP DigestSign.
 *
 * Params:
 *   alg  = message digest to use (e.g. EVP_sha256()).
 *   type = expected key type (EVP_PKEY_RSA or EVP_PKEY_EC); a key of any
 *          other type is rejected.
 *   key  = PEM-encoded private key.
 *   msg  = bytes to sign.
 *
 * Returns: the signature, or null on any failure (a warning is logged).
 *   Non-EC signatures are returned exactly as OpenSSL produced them. EC
 *   signatures are converted from DER into the fixed-width R||S form, with
 *   each half left-padded with zeroes to the curve's byte length.
 */
private ubyte[] signShaPem(const(EVP_MD) *alg, int type, string key, string msg) {
    BIO * bufkey = BIO_new_mem_buf(cast(void*)key.ptr, cast(int)key.length);
    if(bufkey is null) {
        warning("Can't load the private key.");
        return null;
    }
    scope(exit) BIO_free(bufkey);

    /* This uses OpenSSL's default passphrase callback if needed. The
     * library caller can override this in many ways, all of which are
     * outside of the scope of LibJWT and this is documented in jwt.h. */
    EVP_PKEY *pkey = PEM_read_bio_PrivateKey(bufkey, null, null, null);
    if (pkey is null) {
        warning("Invalid argument");
        return null;
    }
    scope(exit) EVP_PKEY_free(pkey);

    int pkey_type = EVP_PKEY_id(pkey);
    if (pkey_type != type) {
        warning("Invalid argument");
        return null;
    }

    EVP_MD_CTX *mdctx = EVP_MD_CTX_create();
    if (mdctx is null) {
        warning("Out of memory");
        return null;
    }
    scope(exit) EVP_MD_CTX_destroy(mdctx);

    /* Initialize the DigestSign operation using alg */
    if (EVP_DigestSignInit(mdctx, null, alg, null, pkey) != 1) {
        warning("Invalid argument");
        return null;
    }

    /* Call update with the message */
    if (EVP_DigestSignUpdate(mdctx, cast(void*)msg.ptr, msg.length) != 1) {
        warning("Invalid argument");
        return null;
    }

    /* First, call EVP_DigestSignFinal with a null sig parameter to get the
     * maximum length of the signature. */
    size_t slen;
    if (EVP_DigestSignFinal(mdctx, null, &slen) != 1) {
        warning("Invalid argument");
        return null;
    }

    // A GC buffer replaces the former alloca(). The crash that motivated
    // alloca came from passing `sig.ptr` (the signature bytes themselves) to
    // d2i_ECDSA_SIG as if it were a pointer-to-pointer, not from this buffer.
    ubyte[] sigBuf = new ubyte[slen];

    /* Get the signature */
    if (EVP_DigestSignFinal(mdctx, sigBuf.ptr, &slen) != 1) {
        warning("Invalid argument");
        return null;
    }
    // The second call reports the bytes actually written, which can be fewer
    // than the upper bound from the first call (DER-encoded ECDSA).
    ubyte[] der = sigBuf[0 .. slen];

    ubyte[] resultSig;

    if (pkey_type != EVP_PKEY_EC) {
        resultSig = der;
    } else {
        /* For EC we need to convert to a raw format of R/S. */

        /* Get the actual ec_key */
        EC_KEY *ec_key = EVP_PKEY_get1_EC_KEY(pkey);
        if (ec_key is null) {
            warning("Out of memory");
            return null;
        }

        uint degree = EC_GROUP_get_degree(EC_KEY_get0_group(ec_key));

        // get1 took a reference; release it now that the degree is known.
        EC_KEY_free(ec_key);

        /* Get the sig from the DER encoded version. */
        version(HUNT_JWT_DEBUG) {
            infof("slen: %d, sig: %(%02X %)", slen, der);
        }

        // d2i_* advances the pointer it is given, so decode through a
        // separate cursor and keep `der` pointing at the start of the data.
        const(ubyte)* cursor = der.ptr;
        ECDSA_SIG *ec_sig = d2i_ECDSA_SIG(null, &cursor, cast(c_long)der.length);
        if (ec_sig is null) {
            warning("Can't decode ECDSA signature.");
            return null;
        }
        scope(exit) ECDSA_SIG_free(ec_sig);

        BIGNUM *ec_sig_r;
        BIGNUM *ec_sig_s;
        ECDSA_SIG_get0(ec_sig, &ec_sig_r, &ec_sig_s);
        uint r_len = BN_num_bytes(ec_sig_r);
        uint s_len = BN_num_bytes(ec_sig_s);
        uint bn_len = (degree + 7) / 8;
        if ((r_len > bn_len) || (s_len > bn_len)) {
            warning("Invalid argument");
            return null;
        }

        uint buf_len = 2 * bn_len;
        // D zero-initialises new arrays, which provides the leading-zero padding.
        ubyte[] raw_buf = new ubyte[buf_len];

        /* Right-align R in the first half and S in the second half. */
        BN_bn2bin(ec_sig_r, raw_buf.ptr + bn_len - r_len);
        BN_bn2bin(ec_sig_s, raw_buf.ptr + buf_len - s_len);

        resultSig = raw_buf;
    }

    version(HUNT_JWT_DEBUG) {
        tracef("%d, buffer: %(%02X %)", resultSig.length, resultSig);
    }

    return resultSig;
}
238 
239 
/**
 * Verifies a JWT signature.
 *
 * Params:
 *   head      = the signed bytes (the message that was passed to `sign`).
 *   signature = the base64url-encoded signature to check.
 *   key       = HMAC secret for HS*, or PEM-encoded public key for RS* / ES*.
 *   algo      = signature algorithm; defaults to HS256.
 *
 * Returns: true if the signature is valid for `head` under `key`.
 *
 * Throws: VerifyException for unsupported algorithms; whatever
 *         `urlsafeB64Decode` or `sign` throw is propagated.
 *
 * NOTE(review): for JwtAlgorithm.NONE only the key is checked (it must be
 * empty); the signature content is not inspected. Confirm this is intended.
 */
bool verifySignature(string head, string signature, string key, JwtAlgorithm algo = JwtAlgorithm.HS256) {
    import hunt.jwt.Base64Codec;

    version(HUNT_JWT_DEBUG) {
        infof("signature: %s", signature);
    }

    ubyte[] decodedSign = cast(ubyte[])urlsafeB64Decode(signature);

    // Compares two byte strings in time independent of where they differ,
    // so HMAC verification does not leak how many leading bytes matched.
    // (Only the length, which is public for a given algorithm, may leak.)
    static bool constantTimeEquals(const(ubyte)[] a, const(ubyte)[] b) {
        if (a.length != b.length)
            return false;
        uint diff = 0;
        foreach (i; 0 .. a.length)
            diff |= a[i] ^ b[i];
        return diff == 0;
    }

    switch(algo) {
        case JwtAlgorithm.NONE: {
            return key.length == 0;
        }
        case JwtAlgorithm.HS256:
        case JwtAlgorithm.HS384:
        case JwtAlgorithm.HS512: {
            // Recompute the MAC and compare without short-circuiting.
            const(ubyte)[] expected = cast(const(ubyte)[])sign(head, key, algo);
            return constantTimeEquals(decodedSign, expected);
        }

        /* RSA */
        case JwtAlgorithm.RS256: {
            const(EVP_MD) *alg = EVP_sha256();
            return verifyShaPem(alg, EVP_PKEY_RSA, head, decodedSign, key);
        }
        case JwtAlgorithm.RS384: {
            const(EVP_MD) *alg = EVP_sha384();
            return verifyShaPem(alg, EVP_PKEY_RSA, head, decodedSign, key);
        }
        case JwtAlgorithm.RS512: {
            const(EVP_MD) *alg = EVP_sha512();
            return verifyShaPem(alg, EVP_PKEY_RSA, head, decodedSign, key);
        }

        /* ECC */
        case JwtAlgorithm.ES256: {
            const(EVP_MD) *alg = EVP_sha256();
            return verifyShaPem(alg, EVP_PKEY_EC, head, decodedSign, key);
        }
        case JwtAlgorithm.ES384: {
            const(EVP_MD) *alg = EVP_sha384();
            return verifyShaPem(alg, EVP_PKEY_EC, head, decodedSign, key);
        }
        case JwtAlgorithm.ES512: {
            const(EVP_MD) *alg = EVP_sha512();
            return verifyShaPem(alg, EVP_PKEY_EC, head, decodedSign, key);
        }

        default:
            throw new VerifyException("Wrong algorithm.");
    }
}
315 
/**
 * Verifies `sig` over `head` with the PEM-encoded public key in `key` via
 * EVP DigestVerify.
 *
 * Params:
 *   alg  = message digest to use (e.g. EVP_sha256()).
 *   type = expected key type (EVP_PKEY_RSA or EVP_PKEY_EC).
 *   head = the signed bytes.
 *   sig  = decoded signature. For EC keys this must be the fixed-width R||S
 *          form (2 * curve byte length); it is re-encoded to DER before
 *          verification.
 *   key  = PEM-encoded public key.
 *
 * Returns: true only if every OpenSSL step succeeds and the signature checks
 *          out; false otherwise (reasons are logged only in HUNT_JWT_DEBUG
 *          builds).
 */
private bool verifyShaPem(const(EVP_MD) *alg, int type, string head, const(ubyte)[] sig, string key) {
    version(HUNT_JWT_DEBUG) {
        tracef("head: %s", head);
        tracef("sig: %(%02X %)", sig);
    }

    if (sig.empty()) {
        version(HUNT_JWT_DEBUG) warning("Invalid argument");
        return false;
    }

    BIO *bufkey = BIO_new_mem_buf(cast(void*)key.ptr, cast(int)key.length);
    if (bufkey is null) {
        version(HUNT_JWT_DEBUG) warning("Out of memory");
        return false;
    }
    scope(exit) BIO_free(bufkey);

    /* This uses OpenSSL's default passphrase callback if needed. The
     * library caller can override this in many ways, all of which are
     * outside of the scope of LibJWT and this is documented in jwt.h. */
    EVP_PKEY *pkey = PEM_read_bio_PUBKEY(bufkey, null, null, null);
    if (pkey is null) {
        version(HUNT_JWT_DEBUG) warning("Invalid argument");
        return false;
    }
    scope(exit) EVP_PKEY_free(pkey);

    int pkey_type = EVP_PKEY_id(pkey);
    if (pkey_type != type) {
        version(HUNT_JWT_DEBUG) warning("Invalid argument");
        return false;
    }

    /* Convert EC sigs back to ASN1. */
    if (pkey_type == EVP_PKEY_EC) {
        ECDSA_SIG *ec_sig = ECDSA_SIG_new();
        if (ec_sig is null) {
            version(HUNT_JWT_DEBUG) warning("Out of memory");
            return false;
        }
        scope(exit) ECDSA_SIG_free(ec_sig);

        /* Get the actual ec_key */
        EC_KEY *ec_key = EVP_PKEY_get1_EC_KEY(pkey);
        if (ec_key is null) {
            version(HUNT_JWT_DEBUG) warning("Out of memory");
            return false;
        }

        uint degree = EC_GROUP_get_degree(EC_KEY_get0_group(ec_key));

        // get1 took a reference; release it now that the degree is known.
        EC_KEY_free(ec_key);

        uint bn_len = (degree + 7) / 8;
        if ((bn_len * 2) != sig.length) {
            version(HUNT_JWT_DEBUG) warning("Invalid argument");
            return false;
        }

        BIGNUM *ec_sig_r = BN_bin2bn(cast(const(ubyte)*)sig.ptr, cast(int)bn_len, null);
        BIGNUM *ec_sig_s = BN_bin2bn(cast(const(ubyte)*)sig.ptr + bn_len, cast(int)bn_len, null);
        if (ec_sig_r is null || ec_sig_s is null) {
            // Release whichever one was allocated; BN_free(null) is a no-op.
            BN_free(ec_sig_r);
            BN_free(ec_sig_s);
            version(HUNT_JWT_DEBUG) warning("Invalid argument");
            return false;
        }

        // Ownership of both BIGNUMs passes to ec_sig (freed with it). With
        // both pointers non-null, ECDSA_SIG_set0 cannot fail.
        ECDSA_SIG_set0(ec_sig, ec_sig_r, ec_sig_s);

        // i2d_* returns a negative value on error; never size a buffer from it.
        int derLen = i2d_ECDSA_SIG(ec_sig, null);
        if (derLen <= 0) {
            version(HUNT_JWT_DEBUG) warning("Invalid argument");
            return false;
        }
        ubyte[] derBuffer = new ubyte[derLen];
        ubyte* p = derBuffer.ptr;   // advanced by i2d_ECDSA_SIG
        derLen = i2d_ECDSA_SIG(ec_sig, &p);
        if (derLen <= 0) {
            version(HUNT_JWT_DEBUG) warning("Invalid argument");
            return false;
        }
        sig = derBuffer[0 .. derLen];
    }

    EVP_MD_CTX *mdctx = EVP_MD_CTX_create();
    if (mdctx is null) {
        version(HUNT_JWT_DEBUG) warning("Out of memory");
        return false;
    }
    scope(exit) EVP_MD_CTX_destroy(mdctx);

    /* Initialize the DigestVerify operation using alg */
    if (EVP_DigestVerifyInit(mdctx, null, alg, null, pkey) != 1) {
        version(HUNT_JWT_DEBUG) warning("Invalid argument");
        return false;
    }

    /* Call update with the message (size_t length, no int truncation). */
    if (EVP_DigestVerifyUpdate(mdctx, head.ptr, head.length) != 1) {
        version(HUNT_JWT_DEBUG) warning("Invalid argument");
        return false;
    }

    version(HUNT_JWT_DEBUG) {
        tracef("slen: %d, sig: %(%02X %)", sig.length, sig);
    }

    /* Now check the sig for validity. */
    if (EVP_DigestVerifyFinal(mdctx, cast(ubyte*)sig.ptr, sig.length) != 1) {
        version(HUNT_JWT_DEBUG) warning("Invalid argument");
        return false;
    }

    return true;
}