View Javadoc

1   package com.atlassian.sal.core.net;
2   
3   import java.io.IOException;
4   import java.net.InetAddress;
5   import java.net.Socket;
6   import java.net.UnknownHostException;
7   import java.security.NoSuchAlgorithmException;
8   import java.util.Arrays;
9   
10  import javax.net.ssl.SSLContext;
11  import javax.net.ssl.SSLSocket;
12  
13  import org.apache.commons.httpclient.ConnectTimeoutException;
14  import org.apache.commons.httpclient.params.HttpConnectionParams;
15  import org.apache.commons.httpclient.protocol.SSLProtocolSocketFactory;
16  
17  /**
18   * A custom SecureProtocolSocketFactory that uses JSSE to create sockets.
19   * <p/>
20   * This factory handles the JSSE system property "https.protocols" used to specify which protocol suites to enable.
21   * http://docs.oracle.com/javase/7/docs/technotes/guides/security/jsse/JSSERefGuide.html#Customization
22   */
23  class CustomSSLProtocolSocketFactory extends SSLProtocolSocketFactory
24  {
25      private final String[] protocols;
26  
27      /**
28       * Creates CustomSSLProtocolSocketFactory.
29       * Uses the "https.protocols" system property if defined otherwise uses protocols parameter.
30       *
31       * Available protocols:
32       * http://docs.oracle.com/javase/1.5.0/docs/guide/security/jsse/JSSERefGuide.html#AppA
33       *
34       * @param protocols Comma-separated list of protocol versions enabled for use on socket created by this factory.
35       */
36      public CustomSSLProtocolSocketFactory(String protocols)
37      {
38          String protocolsProperty = System.getProperty("https.protocols", protocols);
39          if (protocolsProperty == null)
40          {
41              protocolsProperty = getSupportedTLSProtocols();
42          }
43          this.protocols = protocolsProperty.split(",");
44      }
45  
46      /**
47       * Returns a String of supported TLS protocols delimited by ',' - TLSv1.0 is the earliest supported version.
48       * @return a String of supported TLS protocols delimited by ',' - TLSv1.0 is the earliest supported version.
49       */
50      public static String getSupportedTLSProtocols()
51      {
52          StringBuilder sb = new StringBuilder();
53          try
54          {
55              for (String protocol: SSLContext.getDefault()
56                  .getSupportedSSLParameters().getProtocols())
57              {
58                  if (protocol.startsWith("TLS"))
59                  {
60                      sb.append(protocol).append(",");
61                  }
62              }
63          }
64          catch (NoSuchAlgorithmException e)
65          {
66              throw new RuntimeException(e);
67          }
68          String protocols = sb.toString();
69          if (protocols.endsWith(","))
70          {
71              protocols = protocols.substring(0, protocols.length() - 1);
72          }
73          return protocols;
74      }
75  
76      @Override
77      protected void setSSLProtocols(SSLSocket socket)
78      {
79          setSocketProtocols(socket);
80      }
81  
82      private void setSocketProtocols(Socket socket)
83      {
84          SSLSocket sslSocket = (SSLSocket) socket;
85          sslSocket.setEnabledProtocols(protocols);
86      }
87  
88      @Override
89      public Socket createSocket(final String host, final int port, final InetAddress clientHost, final int clientPort)
90              throws IOException, UnknownHostException
91      {
92          Socket socket = super.createSocket(host, port, clientHost, clientPort);
93          setSocketProtocols(socket);
94          return socket;
95      }
96  
97      @Override
98      public Socket createSocket(final String host, final int port, final InetAddress localAddress, final int localPort, final HttpConnectionParams params)
99              throws IOException, UnknownHostException, ConnectTimeoutException
100     {
101         Socket socket = super.createSocket(host, port, localAddress, localPort, params);
102         setSocketProtocols(socket);
103         return socket;
104     }
105 
106     @Override
107     public Socket createSocket(final String host, final int port) throws IOException, UnknownHostException
108     {
109         Socket socket = super.createSocket(host, port);
110         setSocketProtocols(socket);
111         return socket;
112     }
113 
114     @Override
115     public Socket createSocket(final Socket socket, final String host, final int port, final boolean autoClose)
116             throws IOException, UnknownHostException
117     {
118         Socket newSocket = super.createSocket(socket, host, port, autoClose);
119         setSocketProtocols(newSocket);
120         return newSocket;
121     }
122 
123     public String[] getProtocols()
124     {
125         return protocols;
126     }
127 
128     @Override
129     public boolean equals(Object o)
130     {
131         if (this == o)
132         {
133             return true;
134         }
135         if (o == null || getClass() != o.getClass())
136         {
137             return false;
138         }
139         if (!super.equals(o))
140         {
141             return false;
142         }
143 
144         CustomSSLProtocolSocketFactory that = (CustomSSLProtocolSocketFactory) o;
145         if (!Arrays.equals(protocols, that.protocols))
146         {
147             return false;
148         }
149 
150         return true;
151     }
152 
153     @Override
154     public int hashCode()
155     {
156         int result = super.hashCode();
157         result = 31 * result + (protocols != null ? Arrays.hashCode(protocols) : 0);
158         return result;
159     }
160 }