View Javadoc

1   package com.atlassian.plugins.rest.common.security.jersey;
2   
3   import com.atlassian.plugin.tracker.PluginModuleTracker;
4   import com.atlassian.plugins.rest.common.security.CorsPreflightCheckCompleteException;
5   import com.atlassian.plugins.rest.common.security.descriptor.CorsDefaults;
6   import com.atlassian.plugins.rest.common.security.descriptor.CorsDefaultsModuleDescriptor;
7   import com.google.common.base.Function;
8   import com.google.common.base.Joiner;
9   import com.google.common.base.Predicate;
10  import com.google.common.collect.ImmutableSet;
11  import com.google.common.collect.Iterables;
12  import com.google.common.collect.Sets;
13  import com.sun.jersey.spi.container.ContainerRequest;
14  import com.sun.jersey.spi.container.ContainerRequestFilter;
15  import com.sun.jersey.spi.container.ContainerResponse;
16  import com.sun.jersey.spi.container.ContainerResponseFilter;
17  import com.sun.jersey.spi.container.ResourceFilter;
18  import org.slf4j.Logger;
19  import org.slf4j.LoggerFactory;
20  
21  import javax.ws.rs.core.Response;
22  import java.net.URI;
23  import java.util.Arrays;
24  import java.util.Collections;
25  import java.util.List;
26  import java.util.Set;
27  import java.util.HashSet;
28  import java.util.Locale;
29  
30  import static com.atlassian.plugins.rest.common.security.CorsHeaders.ACCESS_CONTROL_ALLOW_CREDENTIALS;
31  import static com.atlassian.plugins.rest.common.security.CorsHeaders.ACCESS_CONTROL_ALLOW_HEADERS;
32  import static com.atlassian.plugins.rest.common.security.CorsHeaders.ACCESS_CONTROL_ALLOW_METHODS;
33  import static com.atlassian.plugins.rest.common.security.CorsHeaders.ACCESS_CONTROL_ALLOW_ORIGIN;
34  import static com.atlassian.plugins.rest.common.security.CorsHeaders.ACCESS_CONTROL_EXPOSE_HEADERS;
35  import static com.atlassian.plugins.rest.common.security.CorsHeaders.ACCESS_CONTROL_MAX_AGE;
36  import static com.atlassian.plugins.rest.common.security.CorsHeaders.ACCESS_CONTROL_REQUEST_HEADERS;
37  import static com.atlassian.plugins.rest.common.security.CorsHeaders.ACCESS_CONTROL_REQUEST_METHOD;
38  import static com.atlassian.plugins.rest.common.security.CorsHeaders.ORIGIN;
39  import static com.google.common.collect.Lists.newArrayList;
40  import static com.google.common.collect.Sets.newHashSet;
41  
42  /**
43   * A filter that handles Cross-Origin Resource Sharing preflight checks and response headers.  Handles simple and preflight
44   * requests.
45   *
46   * See spec at http://www.w3.org/TR/cors
47   *
48   * @since 2.6
49   */
50  public class CorsResourceFilter implements ResourceFilter, ContainerRequestFilter, ContainerResponseFilter {
51      private static final String CORS_PREFLIGHT_FAILED = "Cors-Preflight-Failed";
52      private static final String CORS_PREFLIGHT_SUCCEEDED = "Cors-Preflight-Succeeded";
53      public static final String CORS_PREFLIGHT_REQUESTED = "Cors-Preflight-Requested";
54      private static final Logger log = LoggerFactory.getLogger(CorsResourceFilter.class);
55  
56      private final PluginModuleTracker<CorsDefaults, CorsDefaultsModuleDescriptor> pluginModuleTracker;
57      private final String allowMethod;
58  
59      public CorsResourceFilter(PluginModuleTracker<CorsDefaults, CorsDefaultsModuleDescriptor> pluginModuleTracker, String allowMethod) {
60          this.allowMethod = allowMethod;
61          this.pluginModuleTracker = pluginModuleTracker;
62      }
63  
64      /**
65       * Adds the appropriate <a href="http://www.w3.org/TR/cors/#resource-preflight-requests">
66       * cors preflight </a>
67       * response headers for cors preflight requests from a whitelisted origin.
68       *
69       * @param request the request.
70       * @return request if the request is not a cors preflight request.
71       */
72      public ContainerRequest filter(final ContainerRequest request) {
73          if (!request.getProperties().containsKey(CORS_PREFLIGHT_REQUESTED)) {
74              return request;
75          }
76          Iterable<CorsDefaults> defaults = pluginModuleTracker.getModules();
77          try {
78              String origin = validateSingleOriginInWhitelist(defaults, request);
79              Iterable<CorsDefaults> defaultsWithAllowedOrigin = allowsOrigin(defaults, origin);
80  
81              Response.ResponseBuilder response = Response.ok();
82              validateAccessControlRequestMethod(allowMethod, request);
83              Set<String> allowedRequestHeaders = getAllowedRequestHeaders(defaultsWithAllowedOrigin, origin);
84              validateAccessControlRequestHeaders(allowedRequestHeaders, request);
85  
86              addAccessControlAllowOrigin(response, origin);
87              conditionallyAddAccessControlAllowCredentials(response, origin, defaultsWithAllowedOrigin);
88              addAccessControlMaxAge(response);
89              addAccessControlAllowMethods(response, allowMethod);
90              addAccessControlAllowHeaders(response, allowedRequestHeaders);
91  
92              request.getProperties().put(CORS_PREFLIGHT_SUCCEEDED, "true");
93              // exceptions are the only way to return a response here in Jersey
94              throw new CorsPreflightCheckCompleteException(response.build());
95          } catch (PreflightFailedException ex) {
96              Response.ResponseBuilder response = Response.ok();
97              request.getProperties().put(CORS_PREFLIGHT_FAILED, "true");
98              log.info("CORS preflight failed: " + ex.getMessage());
99              throw new CorsPreflightCheckCompleteException(response.build());
100         }
101 
102     }
103 
104     /**
105      * Adds the appropriate cors response headers to the response of
106      * <a href="http://www.w3.org/TR/cors/#resource-requests">cors requests</a>
107      * from a whitelisted origin.
108      *
109      * @param request           the request.
110      * @param containerResponse the response.
111      * @return containerResponse
112      */
113     public ContainerResponse filter(ContainerRequest request, ContainerResponse containerResponse) {
114         if (request.getProperties().containsKey(CORS_PREFLIGHT_FAILED) ||
115                 request.getProperties().containsKey(CORS_PREFLIGHT_SUCCEEDED) ||
116                 extractOrigin(request) == null) {
117             return containerResponse;
118         }
119 
120         Iterable<CorsDefaults> defaults = pluginModuleTracker.getModules();
121         try {
122             String origin = validateSingleOriginInWhitelist(defaults, request);
123             Iterable<CorsDefaults> defaultsWithAllowedOrigin = allowsOrigin(defaults, origin);
124 
125             Response.ResponseBuilder response = Response.fromResponse(containerResponse.getResponse());
126             addAccessControlAllowOrigin(response, origin);
127             conditionallyAddAccessControlAllowCredentials(response, origin, defaultsWithAllowedOrigin);
128             addAccessControlExposeHeaders(response, getAllowedResponseHeaders(defaultsWithAllowedOrigin, origin));
129             containerResponse.setResponse(response.build());
130             return containerResponse;
131         } catch (PreflightFailedException ex) {
132             log.info("Unable to add CORS headers to response: " + ex.getMessage());
133         }
134 
135         return containerResponse;
136     }
137 
138     private void addAccessControlExposeHeaders(Response.ResponseBuilder response, Set<String> allowedHeaders) {
139         response.header(ACCESS_CONTROL_EXPOSE_HEADERS.value(), Joiner.on(", ").join(allowedHeaders));
140     }
141 
142     private void addAccessControlAllowHeaders(Response.ResponseBuilder response, Set<String> allowedHeaders) {
143         response.header(ACCESS_CONTROL_ALLOW_HEADERS.value(), Joiner.on(", ").join(allowedHeaders));
144     }
145 
146     private void addAccessControlAllowMethods(Response.ResponseBuilder response, String allowMethod) {
147         response.header(ACCESS_CONTROL_ALLOW_METHODS.value(), allowMethod);
148     }
149 
150     private void addAccessControlMaxAge(Response.ResponseBuilder response) {
151         response.header(ACCESS_CONTROL_MAX_AGE.value(), 60 * 60);
152     }
153 
154     private void addAccessControlAllowOrigin(Response.ResponseBuilder response, String origin) {
155         response.header(ACCESS_CONTROL_ALLOW_ORIGIN.value(), origin);
156     }
157 
158     private void conditionallyAddAccessControlAllowCredentials(Response.ResponseBuilder response, String origin, Iterable<CorsDefaults> defaultsWithAllowedOrigin) {
159         if (anyAllowsCredentials(defaultsWithAllowedOrigin, origin)) {
160             response.header(ACCESS_CONTROL_ALLOW_CREDENTIALS.value(), "true");
161         }
162     }
163 
164     private void validateAccessControlRequestHeaders(Set<String> allowedHeaders, ContainerRequest request) throws PreflightFailedException {
165         List<String> requestedHeaders = request.getRequestHeader(ACCESS_CONTROL_REQUEST_HEADERS.value());
166         requestedHeaders = requestedHeaders != null ? requestedHeaders : Collections.<String>emptyList();
167         Set<String> flatRequestedHeaders = new HashSet<String>();
168         for (String requestedHeader : requestedHeaders) {
169             flatRequestedHeaders.addAll(Arrays.asList(
170                     requestedHeader.toLowerCase(Locale.US).trim().split("\\s*,\\s*")));
171         }
172         ImmutableSet<String> allowedHeadersLowerCase = ImmutableSet.copyOf(
173                 Iterables.transform(allowedHeaders,
174                         new Function<String, String>() {
175                             public String apply(String from) {
176                                 return from.toLowerCase(Locale.US);
177                             }
178                         })
179         );
180         final Set<String> difference = Sets.difference(flatRequestedHeaders,
181                 allowedHeadersLowerCase);
182         if (!difference.isEmpty()) {
183             throw new PreflightFailedException(
184                     "Unexpected headers in CORS request: " + newArrayList(difference));
185         }
186     }
187 
188     private void validateAccessControlRequestMethod(String allowMethod, ContainerRequest request) throws PreflightFailedException {
189         String requestedMethod = request.getHeaderValue(ACCESS_CONTROL_REQUEST_METHOD.value());
190         if (!allowMethod.equals(requestedMethod)) {
191             throw new PreflightFailedException("Invalid method: " + requestedMethod);
192         }
193     }
194 
195     private String validateSingleOriginInWhitelist(Iterable<CorsDefaults> defaults, ContainerRequest request) throws PreflightFailedException {
196         String origin = extractOrigin(request);
197         validateOriginAsUri(origin);
198 
199         if (Iterables.isEmpty(allowsOrigin(defaults, origin))) {
200             throw new PreflightFailedException("Origin '" + origin + "' not in whitelist");
201         }
202         return origin;
203     }
204 
205     private void validateOriginAsUri(String origin) throws PreflightFailedException {
206         try {
207             final URI originUri = URI.create(origin);
208             if (originUri.isOpaque() || !originUri.isAbsolute()) {
209                 throw new IllegalArgumentException(
210                         "The origin URI must be absolute and not opaque.");
211             }
212         } catch (IllegalArgumentException ex) {
213             throw new PreflightFailedException("Origin '" + origin + "' is not a valid URI");
214         }
215     }
216 
217     public static String extractOrigin(ContainerRequest request) {
218         return request.getHeaderValue(ORIGIN.value());
219     }
220 
221     public ContainerRequestFilter getRequestFilter() {
222         return this;
223     }
224 
225     public ContainerResponseFilter getResponseFilter() {
226         return this;
227     }
228 
229     /**
230      * Thrown if the preflight or simple cross-origin check process fails
231      */
232     private static class PreflightFailedException extends Exception {
233         private PreflightFailedException(String message) {
234             super(message);
235         }
236     }
237 
238     private static Iterable<CorsDefaults> allowsOrigin(Iterable<CorsDefaults> delegates, final String uri) {
239         return Iterables.filter(delegates, new Predicate<CorsDefaults>() {
240             public boolean apply(CorsDefaults delegate) {
241                 return delegate.allowsOrigin(uri);
242             }
243         });
244     }
245 
246     private static boolean anyAllowsCredentials(Iterable<CorsDefaults> delegatesWhichAllowOrigin, final String uri) {
247         for (CorsDefaults defs : delegatesWhichAllowOrigin) {
248             if (defs.allowsCredentials(uri)) {
249                 return true;
250             }
251         }
252         return false;
253     }
254 
255 
256     private static Set<String> getAllowedRequestHeaders(Iterable<CorsDefaults> delegatesWhichAllowOrigin, String uri) {
257         Set<String> result = newHashSet();
258         for (CorsDefaults defs : delegatesWhichAllowOrigin) {
259             result.addAll(defs.getAllowedRequestHeaders(uri));
260         }
261         return result;
262     }
263 
264     private static Set<String> getAllowedResponseHeaders(Iterable<CorsDefaults> delegatesWithAllowedOrigin, String uri) {
265         Set<String> result = newHashSet();
266         for (CorsDefaults defs : delegatesWithAllowedOrigin) {
267             result.addAll(defs.getAllowedResponseHeaders(uri));
268         }
269         return result;
270     }
271 
272 }