1 package com.atlassian.asap.nimbus.parser;
2
3 import com.atlassian.asap.api.JwsHeader;
4 import com.atlassian.asap.api.Jwt;
5 import com.atlassian.asap.api.JwtBuilder;
6 import com.atlassian.asap.api.SigningAlgorithm;
7 import com.atlassian.asap.core.SecurityProvider;
8 import com.atlassian.asap.core.exception.SignatureMismatchException;
9 import com.atlassian.asap.core.exception.UnsupportedAlgorithmException;
10 import com.atlassian.asap.core.parser.VerifiableJwt;
11 import com.atlassian.asap.nimbus.serializer.NimbusJwtSerializer;
12 import com.nimbusds.jose.JWSObject;
13 import com.nimbusds.jose.crypto.RSASSAVerifier;
14 import com.nimbusds.jwt.JWTClaimsSet;
15 import org.junit.Before;
16 import org.junit.Test;
17 import org.junit.runner.RunWith;
18 import org.mockito.Mock;
19 import org.mockito.runners.MockitoJUnitRunner;
20
21 import java.security.KeyPair;
22 import java.security.KeyPairGenerator;
23 import java.security.Provider;
24 import java.security.interfaces.RSAPublicKey;
25 import java.time.Instant;
26 import java.time.temporal.ChronoUnit;
27 import java.util.Collections;
28 import java.util.Optional;
29
30 import static org.junit.Assert.assertEquals;
31 import static org.junit.Assert.fail;
32 import static org.mockito.Matchers.any;
33 import static org.mockito.Mockito.verify;
34 import static org.mockito.Mockito.when;
35
36 @RunWith(MockitoJUnitRunner.class)
37 public class NimbusVerifiableJwtTest {
38 private static final Provider PROVIDER = SecurityProvider.getProvider();
39
40 @Mock
41 private Jwt jwt;
42 @Mock
43 private JWSObject jwsObject;
44 @Mock
45 private RSAPublicKey rsaPublicKey;
46
47 @Mock
48 private JwsHeader jwsHeader;
49
50 @Before
51 public void setUpMocks() {
52 when(jwt.getHeader()).thenReturn(jwsHeader);
53 }
54
55 @Test
56 public void shouldVerifyValidRS256Signature() throws Exception {
57 when(jwsHeader.getAlgorithm()).thenReturn(SigningAlgorithm.RS256);
58 when(jwsObject.verify(any(RSASSAVerifier.class))).thenReturn(true);
59
60 new NimbusVerifiableJwt(jwt, jwsObject, PROVIDER).verifySignature(rsaPublicKey);
61
62 verify(jwsObject).verify(any(RSASSAVerifier.class));
63 }
64
65 @Test
66 public void shouldThrowIfRS256SignatureIsInvalid() throws Exception {
67 when(jwsHeader.getAlgorithm()).thenReturn(SigningAlgorithm.RS256);
68 when(jwsObject.verify(any(RSASSAVerifier.class))).thenReturn(false);
69
70 try {
71 new NimbusVerifiableJwt(jwt, jwsObject, PROVIDER).verifySignature(rsaPublicKey);
72 fail("Should have thrown");
73 } catch (SignatureMismatchException ex) {
74 verify(jwsObject).verify(any(RSASSAVerifier.class));
75 }
76 }
77
78 @Test(expected = UnsupportedAlgorithmException.class)
79 public void shouldThrowIfAlgorithmIsUnsupported() throws Exception {
80 when(jwsHeader.getAlgorithm()).thenReturn(SigningAlgorithm.ES256);
81
82 new NimbusVerifiableJwt(jwt, jwsObject, PROVIDER).verifySignature(rsaPublicKey);
83 }
84
85 @Test
86 public void shouldReturnValidVerifiableJwt() throws Exception {
87
88 KeyPair keyPair = KeyPairGenerator.getInstance("RSA").generateKeyPair();
89
90 Instant now = Instant.now();
91
92 Jwt unverifiedJwt = JwtBuilder.newJwt()
93 .algorithm(SigningAlgorithm.RS256)
94 .keyId("my_key")
95 .issuer("my_issuer")
96 .jwtId("my_jwtId")
97 .subject(Optional.of("my_subject"))
98 .audience("my_audience")
99 .expirationTime(now.plusSeconds(60))
100 .issuedAt(now)
101 .notBefore(Optional.of(now))
102 .build();
103
104 String serializedJwt = new NimbusJwtSerializer().serialize(unverifiedJwt, keyPair.getPrivate());
105 JWSObject jwsObject = JWSObject.parse(serializedJwt);
106 JWTClaimsSet jwtClaimsSet = JWTClaimsSet.parse(jwsObject.getPayload().toJSONObject());
107
108
109 VerifiableJwt result = NimbusVerifiableJwt.buildVerifiableJwt(jwsObject, jwtClaimsSet, PROVIDER);
110
111 assertEquals(SigningAlgorithm.RS256, result.getHeader().getAlgorithm());
112 assertEquals("my_key", result.getHeader().getKeyId());
113 assertEquals("my_issuer", result.getClaims().getIssuer());
114 assertEquals("my_jwtId", result.getClaims().getJwtId());
115 assertEquals(Optional.of("my_subject"), result.getClaims().getSubject());
116 assertEquals(Collections.singleton("my_audience"), result.getClaims().getAudience());
117 assertEquals(now.plusSeconds(60).truncatedTo(ChronoUnit.SECONDS), result.getClaims().getExpiry());
118 assertEquals(now.truncatedTo(ChronoUnit.SECONDS), result.getClaims().getIssuedAt());
119 assertEquals(Optional.of(now.truncatedTo(ChronoUnit.SECONDS)), result.getClaims().getNotBefore());
120 }
121
122
123 }