1 package com.atlassian.asap.api;
2
3 import com.atlassian.asap.core.serializer.Json;
4 import com.google.common.collect.ImmutableList;
5 import com.google.common.collect.ImmutableSet;
6 import org.apache.commons.lang3.builder.EqualsBuilder;
7 import org.apache.commons.lang3.builder.HashCodeBuilder;
8 import org.apache.commons.lang3.builder.ToStringBuilder;
9 import org.apache.commons.lang3.builder.ToStringStyle;
10
11 import javax.json.JsonArray;
12 import javax.json.JsonArrayBuilder;
13 import javax.json.JsonNumber;
14 import javax.json.JsonObject;
15 import javax.json.JsonObjectBuilder;
16 import javax.json.JsonString;
17 import javax.json.JsonValue;
18 import java.io.IOException;
19 import java.io.ObjectInputStream;
20 import java.io.ObjectOutputStream;
21 import java.io.Serializable;
22 import java.time.Duration;
23 import java.time.Instant;
24 import java.util.List;
25 import java.util.Map;
26 import java.util.Objects;
27 import java.util.Optional;
28 import java.util.Set;
29 import java.util.UUID;
30 import java.util.stream.Collectors;
31
32
33
34
35 public final class JwtBuilder {
36
37
38
39
40 public static final Duration DEFAULT_LIFETIME = Duration.ofSeconds(60);
41
42 private SigningAlgorithm alg;
43 private String keyId;
44 private String iss;
45 private Optional<String> sub;
46 private List<String> aud;
47 private Instant iat;
48 private Instant exp;
49 private Optional<Instant> nbf;
50 private String jti;
51 private JsonObject customClaims;
52
53 private JwtBuilder() {
54 Instant now = Instant.now();
55 notBefore(Optional.of(now));
56 issuedAt(now);
57 expirationTime(now.plus(DEFAULT_LIFETIME));
58 jwtId(UUID.randomUUID().toString());
59 algorithm(SigningAlgorithm.RS256);
60 sub = Optional.empty();
61 customClaims = Json.provider().createObjectBuilder().build();
62 }
63
64
65
66
67
68
69
70
71
72
73
74
75 public static JwtBuilder newJwt() {
76 return new JwtBuilder();
77 }
78
79
80
81
82
83
84
85
86 public static JwtBuilder newFromPrototype(Jwt jwtPrototype) {
87 Instant now = Instant.now();
88 return copyJwt(jwtPrototype)
89 .notBefore(Optional.of(now))
90 .issuedAt(now)
91 .expirationTime(now.plus(DEFAULT_LIFETIME))
92 .jwtId(UUID.randomUUID().toString());
93 }
94
95
96
97
98
99
100
101 public static JwtBuilder copyJwt(Jwt prototype) {
102 return new JwtBuilder()
103 .algorithm(prototype.getHeader().getAlgorithm())
104 .keyId(prototype.getHeader().getKeyId())
105 .issuer(prototype.getClaims().getIssuer())
106 .subject(prototype.getClaims().getSubject())
107 .audience(prototype.getClaims().getAudience())
108 .issuedAt(prototype.getClaims().getIssuedAt())
109 .expirationTime(prototype.getClaims().getExpiry())
110 .notBefore(prototype.getClaims().getNotBefore())
111 .jwtId(prototype.getClaims().getJwtId())
112 .customClaims(prototype.getClaims().getJson());
113 }
114
115
116
117
118
119
120
121 public JwtBuilder keyId(String keyId) {
122 this.keyId = keyId;
123 return this;
124 }
125
126
127
128
129
130
131
132 public JwtBuilder algorithm(SigningAlgorithm alg) {
133 this.alg = alg;
134 return this;
135 }
136
137
138
139
140
141
142
143 public JwtBuilder audience(Iterable<String> aud) {
144 this.aud = ImmutableList.copyOf(aud);
145 return this;
146 }
147
148
149
150
151
152
153
154 public JwtBuilder audience(String... aud) {
155 this.aud = ImmutableList.copyOf(aud);
156 return this;
157 }
158
159
160
161
162
163
164
165 public JwtBuilder expirationTime(Instant expiry) {
166 this.exp = expiry;
167 return this;
168 }
169
170
171
172
173
174
175
176 public JwtBuilder issuedAt(Instant iat) {
177 this.iat = iat;
178 return this;
179 }
180
181
182
183
184
185
186
187 public JwtBuilder issuer(String iss) {
188 this.iss = iss;
189 return this;
190 }
191
192
193
194
195
196
197
198 public JwtBuilder jwtId(String jti) {
199 this.jti = jti;
200 return this;
201 }
202
203
204
205
206
207
208
209 public JwtBuilder notBefore(Optional<Instant> nbf) {
210 this.nbf = nbf;
211 return this;
212 }
213
214
215
216
217
218
219
220 public JwtBuilder subject(Optional<String> sub) {
221 this.sub = sub;
222 return this;
223 }
224
225
226
227
228
229
230
231
232 public JwtBuilder customClaims(JsonObject customClaims) {
233 this.customClaims = customClaims;
234 return this;
235 }
236
237
238
239
240
241 public Jwt build() {
242 JwsHeader header = new ImmutableJwsHeader(alg, keyId);
243
244 JsonObjectBuilder claimsJsonObjectBuilder = Json.provider().createObjectBuilder();
245
246 Set<String> optionalRegisteredClaimKeys = ImmutableSet.of(JwtClaims.RegisteredClaim.SUBJECT.key(), JwtClaims.RegisteredClaim.NOT_BEFORE.key());
247 for (Map.Entry<String, JsonValue> entry : customClaims.entrySet()) {
248 if (!optionalRegisteredClaimKeys.contains(entry.getKey())) {
249 claimsJsonObjectBuilder.add(entry.getKey(), entry.getValue());
250 }
251 }
252 for (Map.Entry<String, JsonValue> entry : ImmutableJwtClaims.getRegisteredClaims(iss, sub, aud, exp, nbf, iat, jti).entrySet()) {
253 claimsJsonObjectBuilder.add(entry.getKey(), entry.getValue());
254 }
255 JwtClaims claims = new ImmutableJwtClaims(claimsJsonObjectBuilder.build());
256
257 return new ImmutableJwt(header, claims);
258 }
259
260
261
262
263 private static class ImmutableJwt implements Jwt, Serializable {
264 private static final long serialVersionUID = 4437693510625284065L;
265
266 private final JwsHeader header;
267 private final JwtClaims claimsSet;
268
269 ImmutableJwt(JwsHeader header, JwtClaims claims) {
270 this.header = Objects.requireNonNull(header);
271 this.claimsSet = Objects.requireNonNull(claims);
272 }
273
274 @Override
275 public JwsHeader getHeader() {
276 return header;
277 }
278
279 @Override
280 public JwtClaims getClaims() {
281 return claimsSet;
282 }
283
284 @Override
285 public String toString() {
286 return new ToStringBuilder(this, ToStringStyle.SHORT_PREFIX_STYLE)
287 .append("header", header)
288 .append("claims", claimsSet)
289 .toString();
290 }
291
292 @Override
293 public boolean equals(Object o) {
294 if (o == null) {
295 return false;
296 }
297 if (o == this) {
298 return true;
299 }
300 if (o.getClass() != getClass()) {
301 return false;
302 }
303 ImmutableJwt rhs = (ImmutableJwt) o;
304 return new EqualsBuilder()
305 .append(header, rhs.header)
306 .append(claimsSet, rhs.claimsSet)
307 .isEquals();
308 }
309
310 @Override
311 public int hashCode() {
312 return new HashCodeBuilder()
313 .append(header)
314 .append(claimsSet)
315 .hashCode();
316 }
317 }
318
319
320
321
322 private static final class ImmutableJwsHeader implements JwsHeader, Serializable {
323 private static final long serialVersionUID = -4791575710345791956L;
324
325 private final SigningAlgorithm algorithm;
326 private final String keyId;
327
328 ImmutableJwsHeader(SigningAlgorithm algorithm, String keyId) {
329 this.algorithm = Objects.requireNonNull(algorithm, "JWT header 'alg' cannot be null");
330 this.keyId = Objects.requireNonNull(keyId, "JWT header 'kid' cannot be null");
331 }
332
333 @Override
334 public String getKeyId() {
335 return keyId;
336 }
337
338 @Override
339 public SigningAlgorithm getAlgorithm() {
340 return algorithm;
341 }
342
343 @Override
344 public String toString() {
345 return new ToStringBuilder(this, ToStringStyle.SHORT_PREFIX_STYLE)
346 .append(Header.ALGORITHM.key(), algorithm)
347 .append(Header.KEY_ID.key(), keyId)
348 .toString();
349 }
350
351 @Override
352 public boolean equals(Object o) {
353 if (o == null) {
354 return false;
355 }
356 if (o == this) {
357 return true;
358 }
359 if (o.getClass() != getClass()) {
360 return false;
361 }
362 ImmutableJwsHeader rhs = (ImmutableJwsHeader) o;
363 return new EqualsBuilder()
364 .append(algorithm, rhs.algorithm)
365 .append(keyId, rhs.keyId)
366 .isEquals();
367 }
368
369 @Override
370 public int hashCode() {
371 return new HashCodeBuilder()
372 .append(algorithm)
373 .append(keyId)
374 .hashCode();
375 }
376 }
377
378
379
380
381 private static class ImmutableJwtClaims implements JwtClaims, Serializable {
382 private static final long serialVersionUID = 5227306085811054804L;
383
384 private transient JsonObject claimsJsonObject;
385
386 ImmutableJwtClaims(JsonObject claimsJsonObject) {
387 this.claimsJsonObject = Objects.requireNonNull(claimsJsonObject);
388 }
389
390 static JsonObject getRegisteredClaims(String iss,
391 Optional<String> sub,
392 List<String> aud,
393 Instant exp,
394 Optional<Instant> nbf,
395 Instant iat,
396 String jti) {
397 Objects.requireNonNull(iss, "JWT claim 'iss' cannot be null");
398 Objects.requireNonNull(sub, "JWT claim 'sub' cannot be null (but it can be None)");
399 Objects.requireNonNull(aud, "JWT claim 'aud' cannot be null");
400 Objects.requireNonNull(iat, "JWT claim 'iat' cannot be null");
401 Objects.requireNonNull(exp, "JWT claim 'exp' cannot be null");
402 Objects.requireNonNull(nbf, "JWT claim 'nbf' cannot be null (but it can be None)");
403 Objects.requireNonNull(jti, "JWT claim 'jit' cannot be null");
404
405 JsonObjectBuilder jsonObjectBuilder = Json.provider().createObjectBuilder();
406 jsonObjectBuilder.add(RegisteredClaim.ISSUER.key(), iss);
407 jsonObjectBuilder.add(RegisteredClaim.ISSUED_AT.key(), iat.getEpochSecond());
408 jsonObjectBuilder.add(RegisteredClaim.EXPIRY.key(), exp.getEpochSecond());
409 jsonObjectBuilder.add(RegisteredClaim.JWT_ID.key(), jti);
410 sub.ifPresent(s -> jsonObjectBuilder.add(RegisteredClaim.SUBJECT.key(), s));
411 nbf.ifPresent(n -> jsonObjectBuilder.add(RegisteredClaim.NOT_BEFORE.key(), n.getEpochSecond()));
412 if (aud.size() == 1) {
413 jsonObjectBuilder.add(RegisteredClaim.AUDIENCE.key(), aud.iterator().next());
414 } else {
415 JsonArrayBuilder audienceArrayBuilder = Json.provider().createArrayBuilder();
416 aud.forEach(audienceArrayBuilder::add);
417 jsonObjectBuilder.add(RegisteredClaim.AUDIENCE.key(), audienceArrayBuilder);
418 }
419 return jsonObjectBuilder.build();
420 }
421
422 @Override
423 public String getIssuer() {
424 return claimsJsonObject.getString(RegisteredClaim.ISSUER.key());
425 }
426
427 @Override
428 public Optional<String> getSubject() {
429 return Optional.ofNullable(claimsJsonObject.getString(RegisteredClaim.SUBJECT.key(), null));
430 }
431
432 @Override
433 public Set<String> getAudience() {
434 JsonValue jsonValue = claimsJsonObject.getOrDefault(RegisteredClaim.AUDIENCE.key(), Json.provider().createArrayBuilder().build());
435 switch (jsonValue.getValueType()) {
436 case ARRAY:
437 return ((JsonArray) jsonValue).getValuesAs(JsonString.class).stream().map(JsonString::getString).collect(Collectors.toSet());
438 case STRING:
439 return ImmutableSet.of(((JsonString) jsonValue).getString());
440 default:
441 throw new RuntimeException("Unexpected 'aud' claim type");
442 }
443 }
444
445 @Override
446 public Instant getExpiry() {
447 return Instant.ofEpochSecond(claimsJsonObject.getJsonNumber(RegisteredClaim.EXPIRY.key()).longValueExact());
448 }
449
450 @Override
451 public Optional<Instant> getNotBefore() {
452 return Optional.ofNullable(claimsJsonObject.getJsonNumber(RegisteredClaim.NOT_BEFORE.key())).map(JsonNumber::longValueExact).map(Instant::ofEpochSecond);
453 }
454
455 @Override
456 public Instant getIssuedAt() {
457 return Instant.ofEpochSecond(claimsJsonObject.getJsonNumber(RegisteredClaim.ISSUED_AT.key()).longValueExact());
458 }
459
460 @Override
461 public String getJwtId() {
462 return claimsJsonObject.getString(RegisteredClaim.JWT_ID.key());
463 }
464
465 @Override
466 public JsonObject getJson() {
467 return claimsJsonObject;
468 }
469
470 @Override
471 public String toString() {
472 return new ToStringBuilder(this, ToStringStyle.SHORT_PREFIX_STYLE)
473 .append(RegisteredClaim.ISSUER.key(), getIssuer())
474 .append(RegisteredClaim.SUBJECT.key(), getSubject())
475 .append(RegisteredClaim.AUDIENCE.key(), getAudience())
476 .append(RegisteredClaim.ISSUED_AT.key(), getIssuedAt())
477 .append(RegisteredClaim.EXPIRY.key(), getExpiry())
478 .append(RegisteredClaim.NOT_BEFORE.key(), getNotBefore())
479 .append(RegisteredClaim.JWT_ID.key(), getJwtId())
480 .toString();
481 }
482
483 @Override
484 public boolean equals(Object o) {
485 if (o == null) {
486 return false;
487 }
488 if (o == this) {
489 return true;
490 }
491 if (o.getClass() != getClass()) {
492 return false;
493 }
494 ImmutableJwtClaims rhs = (ImmutableJwtClaims) o;
495 return claimsJsonObject.equals(rhs.claimsJsonObject);
496 }
497
498 @Override
499 public int hashCode() {
500 return claimsJsonObject.hashCode();
501 }
502
503 private void readObject(ObjectInputStream inputStream) throws ClassNotFoundException, IOException {
504 inputStream.defaultReadObject();
505 this.claimsJsonObject = Json.provider().createReader(inputStream).readObject();
506 }
507
508 private void writeObject(ObjectOutputStream outputStream) throws IOException {
509 outputStream.defaultWriteObject();
510 Json.provider().createWriter(outputStream).writeObject(this.claimsJsonObject);
511 }
512 }
513 }