View Javadoc

1   package com.atlassian.asap.api;
2   
import com.atlassian.asap.core.serializer.Json;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import org.apache.commons.lang3.builder.EqualsBuilder;
import org.apache.commons.lang3.builder.HashCodeBuilder;
import org.apache.commons.lang3.builder.ToStringBuilder;
import org.apache.commons.lang3.builder.ToStringStyle;

import javax.json.JsonArray;
import javax.json.JsonArrayBuilder;
import javax.json.JsonNumber;
import javax.json.JsonObject;
import javax.json.JsonObjectBuilder;
import javax.json.JsonReader;
import javax.json.JsonString;
import javax.json.JsonValue;
import javax.json.JsonWriter;
import java.io.FilterInputStream;
import java.io.FilterOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.OutputStream;
import java.io.Serializable;
import java.time.Duration;
import java.time.Instant;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.UUID;
import java.util.stream.Collectors;
31  
32  /**
33   * A fluent builder for constructing a {@link com.atlassian.asap.api.Jwt} object.
34   */
35  public final class JwtBuilder {
36      /**
37       * Token lifetime, unless a specific lifetime is explicitly set. This is the time span between the
38       * iat and the exp claims.
39       */
40      public static final Duration DEFAULT_LIFETIME = Duration.ofSeconds(60);
41  
42      private SigningAlgorithm alg;
43      private String keyId;
44      private String iss;
45      private Optional<String> sub;
46      private List<String> aud;
47      private Instant iat;
48      private Instant exp;
49      private Optional<Instant> nbf;
50      private String jti;
51      private JsonObject customClaims;
52  
53      private JwtBuilder() {
54          Instant now = Instant.now();
55          notBefore(Optional.of(now));
56          issuedAt(now);
57          expirationTime(now.plus(DEFAULT_LIFETIME));
58          jwtId(UUID.randomUUID().toString());
59          algorithm(SigningAlgorithm.RS256);
60          sub = Optional.empty();
61          customClaims = Json.provider().createObjectBuilder().build();
62      }
63  
64      /**
65       * Construct a simple jwt builder initialised with default claim values as follows:
66       * <ul>
67       * <li> nbf, iat claim set to current system time </li>
68       * <li> exp claim set current time plus default expiry as defined in {@link JwtBuilder#DEFAULT_LIFETIME} </li>
69       * <li> jti claim set to a random UUID. </li>
70       * <li> alg header set to {@link SigningAlgorithm#RS256}. </li>
71       * </ul>
72       *
73       * @return a new fluent builder
74       */
75      public static JwtBuilder newJwt() {
76          return new JwtBuilder();
77      }
78  
79      /**
80       * Construct a JWT builder initialised from a prototype JWT. Time claims are rebased to the current time,
81       * and the jti is set to a random UUID.
82       *
83       * @param jwtPrototype Jwt to use as prototype
84       * @return a new fluent builder
85       */
86      public static JwtBuilder newFromPrototype(Jwt jwtPrototype) {
87          Instant now = Instant.now();
88          return copyJwt(jwtPrototype)
89                  .notBefore(Optional.of(now))
90                  .issuedAt(now)
91                  .expirationTime(now.plus(DEFAULT_LIFETIME))
92                  .jwtId(UUID.randomUUID().toString());
93      }
94  
95      /**
96       * Returns a builder initialised to make an identical copy of the given original Jwt.
97       *
98       * @param prototype Jwt to use as original
99       * @return a Jwt builder initialised to be an identical copy of the original
100      */
101     public static JwtBuilder copyJwt(Jwt prototype) {
102         return new JwtBuilder()
103                 .algorithm(prototype.getHeader().getAlgorithm())
104                 .keyId(prototype.getHeader().getKeyId())
105                 .issuer(prototype.getClaims().getIssuer())
106                 .subject(prototype.getClaims().getSubject())
107                 .audience(prototype.getClaims().getAudience())
108                 .issuedAt(prototype.getClaims().getIssuedAt())
109                 .expirationTime(prototype.getClaims().getExpiry())
110                 .notBefore(prototype.getClaims().getNotBefore())
111                 .jwtId(prototype.getClaims().getJwtId())
112                 .customClaims(prototype.getClaims().getJson());
113     }
114 
115     /**
116      * Sets the key id jws header.
117      *
118      * @param keyId the key id for the jws header of this jwt
119      * @return the fluent builder
120      */
121     public JwtBuilder keyId(String keyId) {
122         this.keyId = keyId;
123         return this;
124     }
125 
126     /**
127      * Sets the algorithm (alg) jws header.
128      *
129      * @param alg the alg for the jws header of this jwt
130      * @return the fluent builder
131      */
132     public JwtBuilder algorithm(SigningAlgorithm alg) {
133         this.alg = alg;
134         return this;
135     }
136 
137     /**
138      * Sets the audience (aud) claim.
139      *
140      * @param aud an iterable containing one or more audiences for the jwt
141      * @return the fluent builder
142      */
143     public JwtBuilder audience(Iterable<String> aud) {
144         this.aud = ImmutableList.copyOf(aud);
145         return this;
146     }
147 
148     /**
149      * Sets the audience (aud) claim.
150      *
151      * @param aud one or more audiences for the jwt
152      * @return the fluent builder
153      */
154     public JwtBuilder audience(String... aud) {
155         this.aud = ImmutableList.copyOf(aud);
156         return this;
157     }
158 
159     /**
160      * Sets the expiration time (exp) claim.
161      *
162      * @param expiry the expiration time
163      * @return the fluent builder
164      */
165     public JwtBuilder expirationTime(Instant expiry) {
166         this.exp = expiry;
167         return this;
168     }
169 
170     /**
171      * Set the issued at (iat) claim.
172      *
173      * @param iat the issued at time
174      * @return the fluent builder
175      */
176     public JwtBuilder issuedAt(Instant iat) {
177         this.iat = iat;
178         return this;
179     }
180 
181     /**
182      * Set the issuer (iss) claim.
183      *
184      * @param iss the issuer
185      * @return the fluent builder
186      */
187     public JwtBuilder issuer(String iss) {
188         this.iss = iss;
189         return this;
190     }
191 
192     /**
193      * Set the jwt id (jti) claim.
194      *
195      * @param jti a unique id for the jwt
196      * @return the fluent builder
197      */
198     public JwtBuilder jwtId(String jti) {
199         this.jti = jti;
200         return this;
201     }
202 
203     /**
204      * Set the not before (nbf) claim.
205      *
206      * @param nbf the not before date
207      * @return the fluent builder
208      */
209     public JwtBuilder notBefore(Optional<Instant> nbf) {
210         this.nbf = nbf;
211         return this;
212     }
213 
214     /**
215      * Sets the subject (sub) claim for this jwt.
216      *
217      * @param sub the subject
218      * @return the fluent builder
219      */
220     public JwtBuilder subject(Optional<String> sub) {
221         this.sub = sub;
222         return this;
223     }
224 
225     /**
226      * Sets the custom claims for this jwt. Any previously set custom claim is discarded. If any of the custom
227      * claims conflicts with a registered claim, the registered claim takes precedence.
228      *
229      * @param customClaims the custom claims
230      * @return the fluent builder
231      */
232     public JwtBuilder customClaims(JsonObject customClaims) {
233         this.customClaims = customClaims;
234         return this;
235     }
236 
237     /**
238      * @return a serializable JWT object representing all the values specified to this builder
239      * @throws java.lang.NullPointerException if some required parameter has not been specified
240      */
241     public Jwt build() {
242         JwsHeader header = new ImmutableJwsHeader(alg, keyId);
243 
244         JsonObjectBuilder claimsJsonObjectBuilder = Json.provider().createObjectBuilder();
245         // custom claims are added first, so in case in conflict with a registered claim the latter wins
246         Set<String> optionalRegisteredClaimKeys = ImmutableSet.of(JwtClaims.RegisteredClaim.SUBJECT.key(), JwtClaims.RegisteredClaim.NOT_BEFORE.key());
247         for (Map.Entry<String, JsonValue> entry : customClaims.entrySet()) {
248             if (!optionalRegisteredClaimKeys.contains(entry.getKey())) {
249                 claimsJsonObjectBuilder.add(entry.getKey(), entry.getValue());
250             }
251         }
252         for (Map.Entry<String, JsonValue> entry : ImmutableJwtClaims.getRegisteredClaims(iss, sub, aud, exp, nbf, iat, jti).entrySet()) {
253             claimsJsonObjectBuilder.add(entry.getKey(), entry.getValue());
254         }
255         JwtClaims claims = new ImmutableJwtClaims(claimsJsonObjectBuilder.build());
256 
257         return new ImmutableJwt(header, claims);
258     }
259 
260     /**
261      * An immutable value object that represents a JWT.
262      */
263     private static class ImmutableJwt implements Jwt, Serializable {
264         private static final long serialVersionUID = 4437693510625284065L;
265 
266         private final JwsHeader header;
267         private final JwtClaims claimsSet;
268 
269         ImmutableJwt(JwsHeader header, JwtClaims claims) {
270             this.header = Objects.requireNonNull(header);
271             this.claimsSet = Objects.requireNonNull(claims);
272         }
273 
274         @Override
275         public JwsHeader getHeader() {
276             return header;
277         }
278 
279         @Override
280         public JwtClaims getClaims() {
281             return claimsSet;
282         }
283 
284         @Override
285         public String toString() {
286             return new ToStringBuilder(this, ToStringStyle.SHORT_PREFIX_STYLE)
287                     .append("header", header)
288                     .append("claims", claimsSet)
289                     .toString();
290         }
291 
292         @Override
293         public boolean equals(Object o) {
294             if (o == null) {
295                 return false;
296             }
297             if (o == this) {
298                 return true;
299             }
300             if (o.getClass() != getClass()) {
301                 return false;
302             }
303             ImmutableJwt rhs = (ImmutableJwt) o;
304             return new EqualsBuilder()
305                     .append(header, rhs.header)
306                     .append(claimsSet, rhs.claimsSet)
307                     .isEquals();
308         }
309 
310         @Override
311         public int hashCode() {
312             return new HashCodeBuilder()
313                     .append(header)
314                     .append(claimsSet)
315                     .hashCode();
316         }
317     }
318 
319     /**
320      * An immutable value object that represents the information contained in the JWS header.
321      */
322     private static final class ImmutableJwsHeader implements JwsHeader, Serializable {
323         private static final long serialVersionUID = -4791575710345791956L;
324 
325         private final SigningAlgorithm algorithm;
326         private final String keyId;
327 
328         ImmutableJwsHeader(SigningAlgorithm algorithm, String keyId) {
329             this.algorithm = Objects.requireNonNull(algorithm, "JWT header 'alg' cannot be null");
330             this.keyId = Objects.requireNonNull(keyId, "JWT header 'kid' cannot be null");
331         }
332 
333         @Override
334         public String getKeyId() {
335             return keyId;
336         }
337 
338         @Override
339         public SigningAlgorithm getAlgorithm() {
340             return algorithm;
341         }
342 
343         @Override
344         public String toString() {
345             return new ToStringBuilder(this, ToStringStyle.SHORT_PREFIX_STYLE)
346                     .append(Header.ALGORITHM.key(), algorithm)
347                     .append(Header.KEY_ID.key(), keyId)
348                     .toString();
349         }
350 
351         @Override
352         public boolean equals(Object o) {
353             if (o == null) {
354                 return false;
355             }
356             if (o == this) {
357                 return true;
358             }
359             if (o.getClass() != getClass()) {
360                 return false;
361             }
362             ImmutableJwsHeader rhs = (ImmutableJwsHeader) o;
363             return new EqualsBuilder()
364                     .append(algorithm, rhs.algorithm)
365                     .append(keyId, rhs.keyId)
366                     .isEquals();
367         }
368 
369         @Override
370         public int hashCode() {
371             return new HashCodeBuilder()
372                     .append(algorithm)
373                     .append(keyId)
374                     .hashCode();
375         }
376     }
377 
378     /**
379      * An immutable value object that represents the information contained in the claims (payload) of a JWT.
380      */
381     private static class ImmutableJwtClaims implements JwtClaims, Serializable {
382         private static final long serialVersionUID = 5227306085811054804L;
383 
384         private transient JsonObject claimsJsonObject;  // cannot be final due to custom serialization
385 
386         ImmutableJwtClaims(JsonObject claimsJsonObject) {
387             this.claimsJsonObject = Objects.requireNonNull(claimsJsonObject);
388         }
389 
390         static JsonObject getRegisteredClaims(String iss,
391                                               Optional<String> sub,
392                                               List<String> aud,
393                                               Instant exp,
394                                               Optional<Instant> nbf,
395                                               Instant iat,
396                                               String jti) {
397             Objects.requireNonNull(iss, "JWT claim 'iss' cannot be null");
398             Objects.requireNonNull(sub, "JWT claim 'sub' cannot be null (but it can be None)");
399             Objects.requireNonNull(aud, "JWT claim 'aud' cannot be null");
400             Objects.requireNonNull(iat, "JWT claim 'iat' cannot be null");
401             Objects.requireNonNull(exp, "JWT claim 'exp' cannot be null");
402             Objects.requireNonNull(nbf, "JWT claim 'nbf' cannot be null (but it can be None)");
403             Objects.requireNonNull(jti, "JWT claim 'jit' cannot be null");
404 
405             JsonObjectBuilder jsonObjectBuilder = Json.provider().createObjectBuilder();
406             jsonObjectBuilder.add(RegisteredClaim.ISSUER.key(), iss);
407             jsonObjectBuilder.add(RegisteredClaim.ISSUED_AT.key(), iat.getEpochSecond());
408             jsonObjectBuilder.add(RegisteredClaim.EXPIRY.key(), exp.getEpochSecond());
409             jsonObjectBuilder.add(RegisteredClaim.JWT_ID.key(), jti);
410             sub.ifPresent(s -> jsonObjectBuilder.add(RegisteredClaim.SUBJECT.key(), s));
411             nbf.ifPresent(n -> jsonObjectBuilder.add(RegisteredClaim.NOT_BEFORE.key(), n.getEpochSecond()));
412             if (aud.size() == 1) { // optimise JSON in case of a single audience
413                 jsonObjectBuilder.add(RegisteredClaim.AUDIENCE.key(), aud.iterator().next());
414             } else {
415                 JsonArrayBuilder audienceArrayBuilder = Json.provider().createArrayBuilder();
416                 aud.forEach(audienceArrayBuilder::add);
417                 jsonObjectBuilder.add(RegisteredClaim.AUDIENCE.key(), audienceArrayBuilder);
418             }
419             return jsonObjectBuilder.build();
420         }
421 
422         @Override
423         public String getIssuer() {
424             return claimsJsonObject.getString(RegisteredClaim.ISSUER.key());
425         }
426 
427         @Override
428         public Optional<String> getSubject() {
429             return Optional.ofNullable(claimsJsonObject.getString(RegisteredClaim.SUBJECT.key(), null));
430         }
431 
432         @Override
433         public Set<String> getAudience() {
434             JsonValue jsonValue = claimsJsonObject.getOrDefault(RegisteredClaim.AUDIENCE.key(), Json.provider().createArrayBuilder().build());
435             switch (jsonValue.getValueType()) {
436                 case ARRAY:
437                     return ((JsonArray) jsonValue).getValuesAs(JsonString.class).stream().map(JsonString::getString).collect(Collectors.toSet());
438                 case STRING:
439                     return ImmutableSet.of(((JsonString) jsonValue).getString());
440                 default:
441                     throw new RuntimeException("Unexpected 'aud' claim type");
442             }
443         }
444 
445         @Override
446         public Instant getExpiry() {
447             return Instant.ofEpochSecond(claimsJsonObject.getJsonNumber(RegisteredClaim.EXPIRY.key()).longValueExact());
448         }
449 
450         @Override
451         public Optional<Instant> getNotBefore() {
452             return Optional.ofNullable(claimsJsonObject.getJsonNumber(RegisteredClaim.NOT_BEFORE.key())).map(JsonNumber::longValueExact).map(Instant::ofEpochSecond);
453         }
454 
455         @Override
456         public Instant getIssuedAt() {
457             return Instant.ofEpochSecond(claimsJsonObject.getJsonNumber(RegisteredClaim.ISSUED_AT.key()).longValueExact());
458         }
459 
460         @Override
461         public String getJwtId() {
462             return claimsJsonObject.getString(RegisteredClaim.JWT_ID.key());
463         }
464 
465         @Override
466         public JsonObject getJson() {
467             return claimsJsonObject;
468         }
469 
470         @Override
471         public String toString() {
472             return new ToStringBuilder(this, ToStringStyle.SHORT_PREFIX_STYLE)
473                     .append(RegisteredClaim.ISSUER.key(), getIssuer())
474                     .append(RegisteredClaim.SUBJECT.key(), getSubject())
475                     .append(RegisteredClaim.AUDIENCE.key(), getAudience())
476                     .append(RegisteredClaim.ISSUED_AT.key(), getIssuedAt())
477                     .append(RegisteredClaim.EXPIRY.key(), getExpiry())
478                     .append(RegisteredClaim.NOT_BEFORE.key(), getNotBefore())
479                     .append(RegisteredClaim.JWT_ID.key(), getJwtId())
480                     .toString();
481         }
482 
483         @Override
484         public boolean equals(Object o) {
485             if (o == null) {
486                 return false;
487             }
488             if (o == this) {
489                 return true;
490             }
491             if (o.getClass() != getClass()) {
492                 return false;
493             }
494             ImmutableJwtClaims rhs = (ImmutableJwtClaims) o;
495             return claimsJsonObject.equals(rhs.claimsJsonObject);
496         }
497 
498         @Override
499         public int hashCode() {
500             return claimsJsonObject.hashCode();
501         }
502 
503         private void readObject(ObjectInputStream inputStream) throws ClassNotFoundException, IOException {
504             inputStream.defaultReadObject();
505             this.claimsJsonObject = Json.provider().createReader(inputStream).readObject();
506         }
507 
508         private void writeObject(ObjectOutputStream outputStream) throws IOException {
509             outputStream.defaultWriteObject();
510             Json.provider().createWriter(outputStream).writeObject(this.claimsJsonObject);
511         }
512     }
513 }