1 package com.atlassian.sal.core.net;
2
3 import java.io.IOException;
4 import java.net.InetAddress;
5 import java.net.Socket;
6 import java.net.UnknownHostException;
7 import java.security.NoSuchAlgorithmException;
8 import java.util.Arrays;
9
10 import javax.net.ssl.SSLContext;
11 import javax.net.ssl.SSLSocket;
12
13 import org.apache.commons.httpclient.ConnectTimeoutException;
14 import org.apache.commons.httpclient.params.HttpConnectionParams;
15 import org.apache.commons.httpclient.protocol.SSLProtocolSocketFactory;
16
17
18
19
20
21
22
23 class CustomSSLProtocolSocketFactory extends SSLProtocolSocketFactory
24 {
25 private final String[] protocols;
26
27
28
29
30
31
32
33
34
35
36 public CustomSSLProtocolSocketFactory(String protocols)
37 {
38 String protocolsProperty = System.getProperty("https.protocols", protocols);
39 if (protocolsProperty == null)
40 {
41 protocolsProperty = getSupportedTLSProtocols();
42 }
43 this.protocols = protocolsProperty.split(",");
44 }
45
46
47
48
49
50 public static String getSupportedTLSProtocols()
51 {
52 StringBuilder sb = new StringBuilder();
53 try
54 {
55 for (String protocol: SSLContext.getDefault()
56 .getSupportedSSLParameters().getProtocols())
57 {
58 if (protocol.startsWith("TLS"))
59 {
60 sb.append(protocol).append(",");
61 }
62 }
63 }
64 catch (NoSuchAlgorithmException e)
65 {
66 throw new RuntimeException(e);
67 }
68 String protocols = sb.toString();
69 if (protocols.endsWith(","))
70 {
71 protocols = protocols.substring(0, protocols.length() - 1);
72 }
73 return protocols;
74 }
75
76 @Override
77 protected void setSSLProtocols(SSLSocket socket)
78 {
79 setSocketProtocols(socket);
80 }
81
82 private void setSocketProtocols(Socket socket)
83 {
84 SSLSocket sslSocket = (SSLSocket) socket;
85 sslSocket.setEnabledProtocols(protocols);
86 }
87
88 @Override
89 public Socket createSocket(final String host, final int port, final InetAddress clientHost, final int clientPort)
90 throws IOException, UnknownHostException
91 {
92 Socket socket = super.createSocket(host, port, clientHost, clientPort);
93 setSocketProtocols(socket);
94 return socket;
95 }
96
97 @Override
98 public Socket createSocket(final String host, final int port, final InetAddress localAddress, final int localPort, final HttpConnectionParams params)
99 throws IOException, UnknownHostException, ConnectTimeoutException
100 {
101 Socket socket = super.createSocket(host, port, localAddress, localPort, params);
102 setSocketProtocols(socket);
103 return socket;
104 }
105
106 @Override
107 public Socket createSocket(final String host, final int port) throws IOException, UnknownHostException
108 {
109 Socket socket = super.createSocket(host, port);
110 setSocketProtocols(socket);
111 return socket;
112 }
113
114 @Override
115 public Socket createSocket(final Socket socket, final String host, final int port, final boolean autoClose)
116 throws IOException, UnknownHostException
117 {
118 Socket newSocket = super.createSocket(socket, host, port, autoClose);
119 setSocketProtocols(newSocket);
120 return newSocket;
121 }
122
123 public String[] getProtocols()
124 {
125 return protocols;
126 }
127
128 @Override
129 public boolean equals(Object o)
130 {
131 if (this == o)
132 {
133 return true;
134 }
135 if (o == null || getClass() != o.getClass())
136 {
137 return false;
138 }
139 if (!super.equals(o))
140 {
141 return false;
142 }
143
144 CustomSSLProtocolSocketFactory that = (CustomSSLProtocolSocketFactory) o;
145 if (!Arrays.equals(protocols, that.protocols))
146 {
147 return false;
148 }
149
150 return true;
151 }
152
153 @Override
154 public int hashCode()
155 {
156 int result = super.hashCode();
157 result = 31 * result + (protocols != null ? Arrays.hashCode(protocols) : 0);
158 return result;
159 }
160 }