1 package com.atlassian.plugins.rest.common.security.jersey; 2 3 import com.atlassian.plugin.tracker.PluginModuleTracker; 4 import com.atlassian.plugins.rest.common.security.CorsPreflightCheckCompleteException; 5 import com.atlassian.plugins.rest.common.security.descriptor.CorsDefaults; 6 import com.atlassian.plugins.rest.common.security.descriptor.CorsDefaultsModuleDescriptor; 7 import com.google.common.collect.ImmutableSet; 8 import com.sun.jersey.spi.container.ContainerRequest; 9 import com.sun.jersey.spi.container.ContainerResponse; 10 import org.junit.Before; 11 import org.junit.Test; 12 import org.junit.runner.RunWith; 13 import org.mockito.ArgumentCaptor; 14 import org.mockito.Matchers; 15 import org.mockito.Mock; 16 import org.mockito.runners.MockitoJUnitRunner; 17 18 import javax.ws.rs.core.MultivaluedMap; 19 import javax.ws.rs.core.Response; 20 import java.util.Arrays; 21 import java.util.Map; 22 23 import static com.google.common.collect.Maps.newHashMap; 24 import static com.google.common.collect.Sets.newHashSet; 25 import static org.junit.Assert.assertEquals; 26 import static org.junit.Assert.fail; 27 import static org.mockito.Mockito.*; 28 29 @RunWith (MockitoJUnitRunner.class) 30 public class TestCorsResourceFilter 31 { 32 @Mock 33 private CorsDefaults corsDefaults; 34 35 @Mock 36 private CorsDefaults corsDefaults2; 37 38 private CorsResourceFilter corsResourceFilter; 39 @Mock 40 private ContainerRequest request; 41 private Map<String,Object> requestProps; 42 @Mock 43 private ContainerResponse response; 44 @Mock 45 private PluginModuleTracker<CorsDefaults, CorsDefaultsModuleDescriptor> tracker; 46 47 @Before 48 public void setUp() 49 { 50 when(tracker.getModules()).thenReturn(newHashSet(corsDefaults)); 51 when(response.getResponse()).thenReturn(Response.ok().build()); 52 requestProps = newHashMap(); 53 when(request.getProperties()).thenReturn(requestProps); 54 corsResourceFilter = new CorsResourceFilter(tracker, "GET"); 55 } 56 57 @Test 58 public void testSimplePreflightForGet() 59 { 60 String origin = "http://localhost"; 61 requestProps.put("Cors-Preflight-Requested", "true"); 62 when(corsDefaults.allowsOrigin(origin)).thenReturn(true); 63 when(request.getHeaderValue("Access-Control-Request-Method")).thenReturn("GET"); 64 when(request.getHeaderValue("Origin")).thenReturn(origin); 65 MultivaluedMap<String, Object> headers = execPreflight(); 66 assertEquals(headers.getFirst("Access-Control-Allow-Origin"), origin); 67 } 68 69 @Test 70 public void testPreflightSucceedsWhenOneCorsDefaultsAllowsOrigin() 71 { 72 String origin = "http://localhost"; 73 MultivaluedMap<String, Object> headers = execPreflightWithTwoCorsDefaults(origin); 74 assertEquals(headers.getFirst("Access-Control-Allow-Origin"), origin); 75 } 76 77 @Test 78 public void testSecondCorsDefaultsIsNotHitIfDoesntAllowOrigin() 79 { 80 String origin = "http://localhost"; 81 MultivaluedMap<String, Object> headers = execPreflightWithTwoCorsDefaults(origin); 82 verify(corsDefaults2, never()).allowsCredentials(Matchers.<String>any()); 83 verify(corsDefaults2, never()).getAllowedRequestHeaders(Matchers.<String>any()); 84 verify(corsDefaults2, never()).getAllowedResponseHeaders(Matchers.<String>any()); 85 } 86 87 @Test 88 public void testSimplePreflightForGetWrongDomain() 89 { 90 String origin = "http://localhost"; 91 requestProps.put("Cors-Preflight-Requested", "true"); 92 when(corsDefaults.allowsOrigin(origin)).thenReturn(false); 93 when(request.getHeaderValue("Access-Control-Request-Method")).thenReturn("GET"); 94 when(request.getHeaderValue("Origin")).thenReturn(origin); 95 execBadPreflight(); 96 } 97 98 @Test 99 public void testSimplePreflightForGetWrongMethod() 100 { 101 String origin = "http://localhost"; 102 requestProps.put("Cors-Preflight-Requested", "true"); 103 when(corsDefaults.allowsOrigin(origin)).thenReturn(false); 104 when(request.getHeaderValue("Access-Control-Request-Method")).thenReturn("POST"); 105 when(request.getHeaderValue("Origin")).thenReturn(origin); 106 execBadPreflight(); 107 } 108 109 @Test 110 public void testSimplePreflightForGetWrongHeaders() 111 { 112 String origin = "http://localhost"; 113 requestProps.put("Cors-Preflight-Requested", "true"); 114 when(corsDefaults.allowsOrigin(origin)).thenReturn(true); 115 when(corsDefaults.getAllowedRequestHeaders(origin)).thenReturn(ImmutableSet.<String>of("Foo-Header")); 116 when(request.getHeaderValue("Access-Control-Request-Method")).thenReturn("GET"); 117 when(request.getRequestHeader("Access-Control-Request-Headers")).thenReturn(Arrays.asList("Bar-Header")); 118 when(request.getHeaderValue("Origin")).thenReturn(origin); 119 execBadPreflight(); 120 } 121 122 @Test 123 public void testSimpleGet() 124 { 125 String origin = "http://localhost"; 126 when(request.getMethod()).thenReturn("GET"); 127 when(corsDefaults.allowsOrigin(origin)).thenReturn(true); 128 when(request.getHeaderValue("Origin")).thenReturn(origin); 129 MultivaluedMap<String, Object> headers = execNoPreflightWithHeaders(); 130 assertEquals(headers.getFirst("Access-Control-Allow-Origin"), origin); 131 } 132 133 @Test 134 public void testSimpleGetWhenOneCorsDefaultsAllowsOrigin() 135 { 136 String origin = "http://localhost"; 137 MultivaluedMap<String, Object> headers = execNoPreflightWithHeadersForTwoCorsDefaults(origin); 138 assertEquals(headers.getFirst("Access-Control-Allow-Origin"), origin); 139 } 140 141 @Test 142 public void testSecondCorsDefaultIsNotCalledWhenItDoesntAllowOrigin() 143 { 144 String origin = "http://localhost"; 145 execNoPreflightWithHeadersForTwoCorsDefaults(origin); 146 verify(corsDefaults2, never()).allowsCredentials(Matchers.<String>any()); 147 verify(corsDefaults2, never()).getAllowedRequestHeaders(Matchers.<String>any()); 148 verify(corsDefaults2, never()).getAllowedResponseHeaders(Matchers.<String>any()); 149 } 150 151 @Test 152 public void testSimpleGetWrongOrigin() 153 { 154 String origin = "http://localhost"; 155 when(request.getMethod()).thenReturn("GET"); 156 when(corsDefaults.allowsOrigin(origin)).thenReturn(true); 157 when(request.getHeaderValue("Origin")).thenReturn("http://foo.com"); 158 execNoPreflightNoHeaders(); 159 } 160 161 @Test 162 public void testSimpleGetNoOrigin() 163 { 164 when(request.getMethod()).thenReturn("GET"); 165 execNoPreflightNoHeaders(); 166 } 167 168 private MultivaluedMap<String, Object> execPreflight() 169 { 170 try 171 { 172 corsResourceFilter.filter(request); 173 fail("Should have thrown preflight exception"); 174 return null; 175 } 176 catch (CorsPreflightCheckCompleteException ex) 177 { 178 return ex.getResponse().getMetadata(); 179 } 180 } 181 182 private MultivaluedMap<String, Object> execNoPreflightWithHeaders() 183 { 184 try 185 { 186 corsResourceFilter.filter(request); 187 ArgumentCaptor<Response> argument = ArgumentCaptor.forClass(Response.class); 188 corsResourceFilter.filter(request, response); 189 verify(response).setResponse(argument.capture()); 190 return argument.getValue().getMetadata(); 191 } 192 catch (CorsPreflightCheckCompleteException ex) 193 { 194 fail("Shouldn't have thrown preflight exception"); 195 return null; 196 } 197 } 198 199 private void execNoPreflightNoHeaders() 200 { 201 try 202 { 203 corsResourceFilter.filter(request); 204 ArgumentCaptor<Response> argument = ArgumentCaptor.forClass(Response.class); 205 corsResourceFilter.filter(request, response); 206 verify(response, never()).setResponse(argument.capture()); 207 } 208 catch (CorsPreflightCheckCompleteException ex) 209 { 210 fail("Shouldn't have thrown preflight exception"); 211 } 212 } 213 214 private CorsPreflightCheckCompleteException execBadPreflight() 215 { 216 try 217 { 218 corsResourceFilter.filter(request); 219 corsResourceFilter.filter(request, response); 220 fail("Should have thrown preflight exception"); 221 return null; 222 } 223 catch (CorsPreflightCheckCompleteException ex) 224 { 225 return ex; 226 } 227 } 228 229 private MultivaluedMap<String, Object> execPreflightWithTwoCorsDefaults(String origin) 230 { 231 requestProps.put("Cors-Preflight-Requested", "true"); 232 when(tracker.getModules()).thenReturn(newHashSet(corsDefaults, corsDefaults2)); 233 when(corsDefaults.allowsOrigin(origin)).thenReturn(true); 234 when(corsDefaults2.allowsOrigin(origin)).thenReturn(false); 235 when(request.getHeaderValue("Access-Control-Request-Method")).thenReturn("GET"); 236 when(request.getHeaderValue("Origin")).thenReturn(origin); 237 MultivaluedMap<String, Object> headers = execPreflight(); 238 return headers; 239 } 240 241 private MultivaluedMap<String, Object> execNoPreflightWithHeadersForTwoCorsDefaults(String origin) 242 { 243 when(request.getMethod()).thenReturn("GET"); 244 when(tracker.getModules()).thenReturn(newHashSet(corsDefaults, corsDefaults2)); 245 when(corsDefaults.allowsOrigin(origin)).thenReturn(true); 246 when(corsDefaults2.allowsOrigin(origin)).thenReturn(false); 247 when(request.getHeaderValue("Origin")).thenReturn(origin); 248 MultivaluedMap<String, Object> headers = execNoPreflightWithHeaders(); 249 return headers; 250 } 251 }