1 package com.atlassian.activeobjects.confluence;
2
3 import com.atlassian.activeobjects.confluence.hibernate.DialectExtractor;
4 import com.atlassian.activeobjects.spi.AbstractTenantAwareDataSourceProvider;
5 import com.atlassian.activeobjects.spi.ConnectionHandler;
6 import com.atlassian.activeobjects.spi.DatabaseType;
7 import com.atlassian.hibernate.PluginHibernateSessionFactory;
8 import com.atlassian.tenancy.api.Tenant;
9 import com.google.common.collect.ImmutableMap;
10 import net.sf.hibernate.HibernateException;
11 import net.sf.hibernate.Session;
12 import net.sf.hibernate.dialect.DB2Dialect;
13 import net.sf.hibernate.dialect.Dialect;
14 import net.sf.hibernate.dialect.HSQLDialect;
15 import net.sf.hibernate.dialect.MySQLDialect;
16 import net.sf.hibernate.dialect.Oracle9Dialect;
17 import net.sf.hibernate.dialect.PostgreSQLDialect;
18 import net.sf.hibernate.dialect.SQLServerDialect;
19
20 import javax.annotation.Nonnull;
21 import javax.sql.DataSource;
22 import java.io.PrintWriter;
23 import java.sql.Connection;
24 import java.sql.SQLException;
25 import java.sql.SQLFeatureNotSupportedException;
26 import java.util.Map;
27 import java.util.logging.Logger;
28
29 import static com.google.common.base.Preconditions.checkNotNull;
30
31 public final class ConfluenceTenantAwareDataSourceProvider extends AbstractTenantAwareDataSourceProvider {
32 private static final Map<Class<? extends Dialect>, DatabaseType> DIALECT_TO_DATABASE_MAPPING = ImmutableMap.<Class<? extends Dialect>, DatabaseType>builder()
33 .put(HSQLDialect.class, DatabaseType.HSQL)
34 .put(MySQLDialect.class, DatabaseType.MYSQL)
35 .put(PostgreSQLDialect.class, DatabaseType.POSTGRESQL)
36 .put(Oracle9Dialect.class, DatabaseType.ORACLE)
37 .put(SQLServerDialect.class, DatabaseType.MS_SQL)
38 .put(DB2Dialect.class, DatabaseType.DB2)
39 .build();
40
41 private final SessionFactoryDataSource dataSource;
42 private final DialectExtractor dialectExtractor;
43
44 public ConfluenceTenantAwareDataSourceProvider(PluginHibernateSessionFactory sessionFactory, DialectExtractor dialectExtractor) {
45 this.dataSource = new SessionFactoryDataSource(checkNotNull(sessionFactory));
46 this.dialectExtractor = checkNotNull(dialectExtractor);
47 }
48
49 @Nonnull
50 @Override
51 public DataSource getDataSource(@Nonnull final Tenant tenant) {
52 return dataSource;
53 }
54
55 @Nonnull
56 @Override
57 public DatabaseType getDatabaseType(@Nonnull final Tenant tenant) {
58 final Class<? extends Dialect> dialect = dialectExtractor.getDialect();
59 if (dialect == null) {
60 return DatabaseType.UNKNOWN;
61 }
62 for (Map.Entry<Class<? extends Dialect>, DatabaseType> entry : DIALECT_TO_DATABASE_MAPPING.entrySet()) {
63 if (entry.getKey().isAssignableFrom(dialect)) {
64 return entry.getValue();
65 }
66 }
67 return super.getDatabaseType(tenant);
68 }
69
70 private static class SessionFactoryDataSource extends AbstractDataSource {
71 private final PluginHibernateSessionFactory sessionFactory;
72
73 public SessionFactoryDataSource(PluginHibernateSessionFactory sessionFactory) {
74 this.sessionFactory = sessionFactory;
75 }
76
77 @Override
78 public Connection getConnection() throws SQLException {
79 final Session session = sessionFactory.getSession();
80 try {
81 return ConnectionHandler.newInstance(session.connection());
82 } catch (HibernateException e) {
83 throw new SQLException(e.getMessage());
84 }
85 }
86
87 @Override
88 public Connection getConnection(String username, String password) throws SQLException {
89 throw new IllegalStateException("Not allowed to get a connection for non default username/password");
90 }
91
92 @Override
93 public <T> T unwrap(Class<T> tClass) throws SQLException {
94 return null;
95 }
96
97 @Override
98 public boolean isWrapperFor(Class<?> aClass) throws SQLException {
99 return false;
100 }
101 }
102
103 private static abstract class AbstractDataSource implements DataSource {
104
105
106
107 @Override
108 public int getLoginTimeout() throws SQLException {
109 return 0;
110 }
111
112
113
114
115 @Override
116 public void setLoginTimeout(int timeout) throws SQLException {
117 throw new UnsupportedOperationException("setLoginTimeout");
118 }
119
120
121
122
123 @Override
124 public PrintWriter getLogWriter() {
125 throw new UnsupportedOperationException("getLogWriter");
126 }
127
128
129
130
131 @Override
132 public void setLogWriter(PrintWriter pw) throws SQLException {
133 throw new UnsupportedOperationException("setLogWriter");
134 }
135
136
137 public Logger getParentLogger() throws SQLFeatureNotSupportedException {
138 throw new SQLFeatureNotSupportedException();
139 }
140 }
141 }