1 package com.atlassian.sal.core.net;
2
3 import com.atlassian.sal.api.net.Request;
4 import com.atlassian.sal.api.net.ResponseException;
5 import com.atlassian.sal.api.user.UserManager;
6 import com.atlassian.sal.core.trusted.CertificateFactory;
7 import com.atlassian.security.auth.trustedapps.EncryptedCertificate;
8 import com.atlassian.security.auth.trustedapps.TrustedApplicationUtils;
9 import org.apache.http.Header;
10 import org.apache.http.HttpClientConnection;
11 import org.apache.http.HttpException;
12 import org.apache.http.HttpRequest;
13 import org.apache.http.HttpResponse;
14 import org.apache.http.HttpStatus;
15 import org.apache.http.ProtocolVersion;
16 import org.apache.http.conn.ConnectionRequest;
17 import org.apache.http.conn.HttpClientConnectionManager;
18 import org.apache.http.conn.routing.HttpRoute;
19 import org.apache.http.entity.StringEntity;
20 import org.apache.http.message.BasicHttpResponse;
21 import org.apache.http.protocol.HttpContext;
22 import org.apache.http.protocol.HttpRequestExecutor;
23 import org.hamcrest.FeatureMatcher;
24 import org.hamcrest.Matcher;
25 import org.junit.Before;
26 import org.junit.Rule;
27 import org.junit.Test;
28 import org.junit.rules.ExpectedException;
29 import org.junit.runner.RunWith;
30 import org.mockito.ArgumentCaptor;
31 import org.mockito.Captor;
32 import org.mockito.Mock;
33 import org.mockito.junit.MockitoJUnitRunner;
34
35 import java.io.IOException;
36 import java.nio.charset.StandardCharsets;
37 import java.text.MessageFormat;
38 import java.util.concurrent.ExecutionException;
39 import java.util.concurrent.TimeUnit;
40
41 import static org.hamcrest.MatcherAssert.assertThat;
42 import static org.hamcrest.Matchers.arrayContaining;
43 import static org.hamcrest.Matchers.equalTo;
44 import static org.mockito.ArgumentMatchers.any;
45 import static org.mockito.ArgumentMatchers.anyLong;
46 import static org.mockito.ArgumentMatchers.eq;
47 import static org.mockito.Mockito.mock;
48 import static org.mockito.Mockito.verify;
49 import static org.mockito.Mockito.when;
50
51 @RunWith(MockitoJUnitRunner.class)
52 public class TestHttpClientTrustedRequest {
53
54 private static final String DUMMY_HOST = "dummy.atlassian.test";
55 private static final String DUMMY_HTTP_URL = MessageFormat.format("http://{0}/", DUMMY_HOST);
56
57 private static final String DUMMY_USERNAME = "dummy";
58 private static final String DUMMY_TRUSTED_TOKEN_ID = "dummy-id";
59
60 @Rule
61 public ExpectedException thrown = ExpectedException.none();
62
63 @Mock
64 private HttpRequestExecutor mockRequestExecutor;
65
66 @Mock
67 private HttpClientConnectionManager mockConnectionManager;
68
69 @Mock
70 private CertificateFactory certificateFactory;
71
72 @Mock
73 private EncryptedCertificate encryptedCertificate;
74
75 @Mock
76 private UserManager userManager;
77
78 private HttpClientTrustedRequestFactory requestFactory;
79
80 @Captor
81 private ArgumentCaptor<HttpRequest> requestCaptor;
82
83 @Before
84 public void setup() throws InterruptedException, ExecutionException, IOException, HttpException {
85 requestFactory = new HttpClientWithMockConnectionTrustedRequestFactory(userManager, certificateFactory, mockConnectionManager, mockRequestExecutor);
86
87
88 when(mockRequestExecutor.execute(any(HttpRequest.class), any(HttpClientConnection.class),
89 any(HttpContext.class))).thenReturn(createOkResponse());
90
91
92 final HttpClientConnection conn = mock(HttpClientConnection.class);
93 final ConnectionRequest connRequest = mock(ConnectionRequest.class);
94 when(connRequest.get(anyLong(), any(TimeUnit.class))).thenReturn(conn);
95 when(mockConnectionManager.requestConnection(any(HttpRoute.class), any())).thenReturn(connRequest);
96
97
98 when(userManager.getRemoteUsername()).thenReturn(DUMMY_USERNAME);
99 when(encryptedCertificate.getID()).thenReturn(DUMMY_TRUSTED_TOKEN_ID);
100 when(certificateFactory.createCertificate(eq(DUMMY_USERNAME), eq(DUMMY_HTTP_URL))).thenReturn(encryptedCertificate);
101 }
102
103 private static HttpResponse createOkResponse() {
104 final BasicHttpResponse response = new BasicHttpResponse(new ProtocolVersion("HTTP", 1, 1), HttpStatus.SC_OK, "OK");
105 response.setEntity(new StringEntity("test body", StandardCharsets.UTF_8));
106 return response;
107 }
108
109
110 @Test
111 public void assertThatTrustedTokenHeadersAdded() throws ResponseException, IOException, HttpException {
112 final HttpClientTrustedRequest request = requestFactory.createTrustedRequest(Request.MethodType.GET, DUMMY_HTTP_URL);
113
114 request.addTrustedTokenAuthentication(DUMMY_HOST);
115 request.execute();
116
117 verify(mockRequestExecutor).execute(requestCaptor.capture(), any(HttpClientConnection.class), any(HttpContext.class));
118 final HttpRequest lastRequest = requestCaptor.getValue();
119 final Header[] headers = lastRequest.getHeaders(TrustedApplicationUtils.Header.Request.ID);
120
121
122 assertThat(headers, arrayContaining(headerWithValue(equalTo(DUMMY_TRUSTED_TOKEN_ID))));
123
124 }
125
126 private static Matcher<Header> headerWithValue(final Matcher<String> valueMatcher) {
127 return new FeatureMatcher<Header, String>(valueMatcher, "header with value", "header value") {
128 @Override
129 protected String featureValueOf(final Header header) {
130 return header.getValue();
131 }
132 };
133 }
134 }