View Javadoc

1   package com.atlassian.plugins.rest.common.security.jersey;
2   
3   import com.atlassian.plugin.tracker.PluginModuleTracker;
4   import com.atlassian.plugins.rest.common.security.CorsPreflightCheckCompleteException;
5   import com.atlassian.plugins.rest.common.security.descriptor.CorsDefaults;
6   import com.atlassian.plugins.rest.common.security.descriptor.CorsDefaultsModuleDescriptor;
7   import com.google.common.base.Predicate;
8   import com.google.common.collect.Iterables;
9   import com.sun.jersey.spi.container.ContainerRequest;
10  import com.sun.jersey.spi.container.ContainerRequestFilter;
11  import com.sun.jersey.spi.container.ContainerResponse;
12  import com.sun.jersey.spi.container.ContainerResponseFilter;
13  import com.sun.jersey.spi.container.ResourceFilter;
14  import org.slf4j.Logger;
15  import org.slf4j.LoggerFactory;
16  
17  import javax.ws.rs.core.Response;
18  import java.net.URI;
19  import java.util.Collections;
20  import java.util.List;
21  import java.util.Set;
22  
23  import static com.atlassian.plugins.rest.common.security.CorsHeaders.ACCESS_CONTROL_ALLOW_CREDENTIALS;
24  import static com.atlassian.plugins.rest.common.security.CorsHeaders.ACCESS_CONTROL_ALLOW_HEADERS;
25  import static com.atlassian.plugins.rest.common.security.CorsHeaders.ACCESS_CONTROL_ALLOW_METHODS;
26  import static com.atlassian.plugins.rest.common.security.CorsHeaders.ACCESS_CONTROL_ALLOW_ORIGIN;
27  import static com.atlassian.plugins.rest.common.security.CorsHeaders.ACCESS_CONTROL_EXPOSE_HEADERS;
28  import static com.atlassian.plugins.rest.common.security.CorsHeaders.ACCESS_CONTROL_MAX_AGE;
29  import static com.atlassian.plugins.rest.common.security.CorsHeaders.ACCESS_CONTROL_REQUEST_HEADERS;
30  import static com.atlassian.plugins.rest.common.security.CorsHeaders.ACCESS_CONTROL_REQUEST_METHOD;
31  import static com.atlassian.plugins.rest.common.security.CorsHeaders.ORIGIN;
32  import static com.google.common.collect.Lists.newArrayList;
33  import static com.google.common.collect.Sets.newHashSet;
34  
35  /**
36   * A filter that handles Cross-Origin Resource Sharing preflight checks and response headers.  Handles simple and preflight
37   * requests.
38   *
39   * See spec at http://www.w3.org/TR/cors
40   *
41   * @since 2.6
42   */
43  public class CorsResourceFilter implements ResourceFilter, ContainerRequestFilter, ContainerResponseFilter
44  {
45      private static final String CORS_PREFLIGHT_FAILED = "Cors-Preflight-Failed";
46      private static final String CORS_PREFLIGHT_SUCCEEDED = "Cors-Preflight-Succeeded";
47      public static final String CORS_PREFLIGHT_REQUESTED = "Cors-Preflight-Requested";
48      private static final Logger log = LoggerFactory.getLogger(CorsResourceFilter.class);
49  
50      private final PluginModuleTracker<CorsDefaults, CorsDefaultsModuleDescriptor> pluginModuleTracker;
51      private final String allowMethod;
52  
53      public CorsResourceFilter(PluginModuleTracker<CorsDefaults, CorsDefaultsModuleDescriptor> pluginModuleTracker, String allowMethod)
54      {
55          this.allowMethod = allowMethod;
56          this.pluginModuleTracker = pluginModuleTracker;
57      }
58  
59      public ContainerRequest filter(final ContainerRequest request)
60      {
61          // if origin is present, match exactly or terminate
62          // if any tokens are not a case-sensitive match for the whitelist, terminate
63          // method = Access-Control-Request-Method, if null terminate
64          // headers = Access-Control-Request-Headers values, or empty list
65          // if method is not a case-sensitive match in list, terminate
66          // If any of 'headers' is not a case-insensitive match for values, terminate
67          // add single Access-Control-Allow-Origin with the Origin value and add Access-Control-Allow-Credentials to true
68          // add Access-Control-Max-Age header in seconds
69          // add Access-Control-Allow-Methods (optional)
70          // add one or more Access-Control-Allow-Headers for each header to expose (optional)
71  
72          if (request.getProperties().containsKey(CORS_PREFLIGHT_REQUESTED))
73          {
74              Iterable<CorsDefaults> defaults = pluginModuleTracker.getModules();
75              try
76              {
77                  String origin = validateSingleOriginInWhitelist(defaults, request);
78                  Iterable<CorsDefaults> defaultsWithAllowedOrigin = allowsOrigin(defaults, origin);
79  
80                  Response.ResponseBuilder response = Response.ok();
81                  validateAccessControlRequestMethod(allowMethod, request);
82                  Set<String> allowedRequestHeaders = getAllowedRequestHeaders(defaultsWithAllowedOrigin, origin);
83                  validateAccessControlRequestHeaders(allowedRequestHeaders, request);
84  
85                  addAccessControlAllowOrigin(response, origin);
86                  conditionallyAddAccessControlAllowCredentials(response, origin, defaultsWithAllowedOrigin);
87                  addAccessControlMaxAge(response);
88                  addAccessControlAllowMethods(response, allowMethod);
89                  addAccessControlAllowHeaders(response, allowedRequestHeaders);
90  
91                  request.getProperties().put(CORS_PREFLIGHT_SUCCEEDED, "true");
92                  // exceptions are the only way to return a response here in Jersey
93                  throw new CorsPreflightCheckCompleteException(response.build());
94              }
95              catch (PreflightFailedException ex)
96              {
97                  Response.ResponseBuilder response = Response.ok();
98                  request.getProperties().put(CORS_PREFLIGHT_FAILED, "true");
99                  log.info("CORS preflight failed: " + ex.getMessage());
100                 throw new CorsPreflightCheckCompleteException(response.build());
101             }
102         }
103         else
104         {
105             return request;
106         }
107     }
108 
109     public ContainerResponse filter(ContainerRequest request, ContainerResponse containerResponse)
110     {
111         // if origin is present, split otherwise terminate
112         // if any tokens are not a case-sensitive match for the whitelist, terminate
113         // add single Access-Control-Allow-Origin with the Origin value and add Access-Control-Allow-Credentials to true
114         // add one or more Access-Control-Expose-Headers for each header to expose
115 
116         if (!request.getProperties().containsKey(CORS_PREFLIGHT_FAILED) &&
117                 !request.getProperties().containsKey(CORS_PREFLIGHT_SUCCEEDED) &&
118                 extractOrigin(request) != null)
119         {
120             Iterable<CorsDefaults> defaults = pluginModuleTracker.getModules();
121             try
122             {
123                 String origin = validateAnyOriginInListInWhitelist(defaults, request);
124                 Iterable<CorsDefaults> defaultsWithAllowedOrigin = allowsOrigin(defaults, origin);
125 
126                 Response.ResponseBuilder response = Response.fromResponse(containerResponse.getResponse());
127                 addAccessControlAllowOrigin(response, origin);
128                 conditionallyAddAccessControlAllowCredentials(response, origin, defaultsWithAllowedOrigin);
129                 addAccessControlExposeHeaders(response, getAllowedResponseHeaders(defaultsWithAllowedOrigin, origin));
130                 containerResponse.setResponse(response.build());
131                 return containerResponse;
132             }
133             catch (PreflightFailedException ex)
134             {
135                 log.info("Unable to add CORS headers to response: " + ex.getMessage());
136             }
137         }
138         return containerResponse;
139     }
140 
141     private void addAccessControlExposeHeaders(Response.ResponseBuilder response, Set<String> allowedHeaders)
142     {
143         for (String header : allowedHeaders)
144         {
145             response.header(ACCESS_CONTROL_EXPOSE_HEADERS.value(), header);
146         }
147     }
148 
149     private void addAccessControlAllowHeaders(Response.ResponseBuilder response, Set<String> allowedHeaders)
150     {
151         for (String header : allowedHeaders)
152         {
153             response.header(ACCESS_CONTROL_ALLOW_HEADERS.value(), header);
154         }
155     }
156 
157     private void addAccessControlAllowMethods(Response.ResponseBuilder response, String allowMethod)
158     {
159         response.header(ACCESS_CONTROL_ALLOW_METHODS.value(), allowMethod);
160     }
161 
162     private void addAccessControlMaxAge(Response.ResponseBuilder response)
163     {
164         response.header(ACCESS_CONTROL_MAX_AGE.value(), 60 * 60);
165     }
166 
167     private void addAccessControlAllowOrigin(Response.ResponseBuilder response, String origin)
168     {
169         response.header(ACCESS_CONTROL_ALLOW_ORIGIN.value(), origin);
170     }
171 
172     private void conditionallyAddAccessControlAllowCredentials(Response.ResponseBuilder response, String origin, Iterable<CorsDefaults> defaultsWithAllowedOrigin)
173     {
174         if (anyAllowsCredentials(defaultsWithAllowedOrigin, origin))
175         {
176             response.header(ACCESS_CONTROL_ALLOW_CREDENTIALS.value(), "true");
177         }
178     }
179 
180     private void validateAccessControlRequestHeaders(Set<String> allowedHeaders, ContainerRequest request) throws PreflightFailedException
181     {
182         //Note: According to the spec, this should be a case-insensitive comparison
183         List<String> requestedHeaders = request.getRequestHeader(ACCESS_CONTROL_REQUEST_HEADERS.value());
184         requestedHeaders = requestedHeaders != null ? requestedHeaders : Collections.<String>emptyList();
185         if (!allowedHeaders.containsAll(requestedHeaders))
186         {
187             List<String> unexpectedHeaders = newArrayList(requestedHeaders);
188             unexpectedHeaders.removeAll(allowedHeaders);
189 
190             throw new PreflightFailedException("Unexpected headers in CORS request: " + unexpectedHeaders);
191         }
192     }
193 
194     private void validateAccessControlRequestMethod(String allowMethod, ContainerRequest request) throws PreflightFailedException
195     {
196         String requestedMethod = request.getHeaderValue(ACCESS_CONTROL_REQUEST_METHOD.value());
197         if (!allowMethod.equals(requestedMethod))
198         {
199             throw new PreflightFailedException("Invalid method: " + requestedMethod);
200         }
201     }
202 
203     private String validateAnyOriginInListInWhitelist(Iterable<CorsDefaults> defaults, ContainerRequest request) throws PreflightFailedException
204     {
205         String originRaw = extractOrigin(request);
206         String[] originList = originRaw.split(" ");
207         for (String origin : originList)
208         {
209             validateOriginAsUri(origin);
210             if (! Iterables.isEmpty(allowsOrigin(defaults, origin)))
211             {
212                 return origin;
213             }
214         }
215         throw new PreflightFailedException("Origins '" + originRaw + "' not in whitelist");
216     }
217 
218     private String validateSingleOriginInWhitelist(Iterable<CorsDefaults> defaults, ContainerRequest request) throws PreflightFailedException
219     {
220         String origin = extractOrigin(request);
221         validateOriginAsUri(origin);
222 
223         if (Iterables.isEmpty(allowsOrigin(defaults, origin)))
224         {
225             throw new PreflightFailedException("Origin '" + origin + "' not in whitelist");
226         }
227         return origin;
228     }
229 
230     private void validateOriginAsUri(String origin) throws PreflightFailedException
231     {
232         try
233         {
234             URI.create(origin);
235         }
236         catch (IllegalArgumentException ex)
237         {
238             throw new PreflightFailedException("Origin '" + origin + "' is not a valid URI");
239         }
240     }
241 
242     public static String extractOrigin(ContainerRequest request)
243     {
244         return request.getHeaderValue(ORIGIN.value());
245     }
246 
247     public ContainerRequestFilter getRequestFilter()
248     {
249         return this;
250     }
251 
252     public ContainerResponseFilter getResponseFilter()
253     {
254         return this;
255     }
256 
257     /**
258      * Thrown if the preflight or simple cross-origin check process fails
259      */
260     private static class PreflightFailedException extends Exception
261     {
262         private PreflightFailedException(String message)
263         {
264             super(message);
265         }
266     }
267 
268     private static Iterable<CorsDefaults> allowsOrigin(Iterable<CorsDefaults> delegates, final String uri)
269     {
270         return Iterables.filter(delegates, new Predicate<CorsDefaults>()
271         {
272             public boolean apply(CorsDefaults delegate)
273             {
274                 return delegate.allowsOrigin(uri);
275             }
276         });
277     }
278 
279     private static boolean anyAllowsCredentials(Iterable<CorsDefaults> delegatesWhichAllowOrigin, final String uri)
280     {
281         for (CorsDefaults defs : delegatesWhichAllowOrigin)
282         {
283             if (defs.allowsCredentials(uri))
284             {
285                 return true;
286             }
287         }
288         return false;
289     }
290 
291 
292     private static Set<String> getAllowedRequestHeaders(Iterable<CorsDefaults> delegatesWhichAllowOrigin, String uri)
293     {
294         Set<String> result = newHashSet();
295         for (CorsDefaults defs : delegatesWhichAllowOrigin)
296         {
297             result.addAll(defs.getAllowedRequestHeaders(uri));
298         }
299         return result;
300     }
301 
302     private static Set<String> getAllowedResponseHeaders(Iterable<CorsDefaults> delegatesWithAllowedOrigin, String uri)
303     {
304         Set<String> result = newHashSet();
305         for (CorsDefaults defs : delegatesWithAllowedOrigin)
306         {
307             result.addAll(defs.getAllowedResponseHeaders(uri));
308         }
309         return result;
310     }
311 
312 }