1 package com.atlassian.dbexporter.importer;
2
3 import com.atlassian.dbexporter.Column;
4 import com.atlassian.dbexporter.Context;
5 import com.atlassian.dbexporter.DatabaseInformations;
6 import com.atlassian.dbexporter.ImportExportErrorService;
7 import com.atlassian.dbexporter.Table;
8 import com.google.common.base.Predicate;
9 import com.google.common.collect.Iterables;
10
import java.sql.Connection;
import java.sql.SQLException;
import java.sql.Statement;
import java.util.NoSuchElementException;
14
15 import static com.atlassian.dbexporter.DatabaseInformations.database;
16 import static com.atlassian.dbexporter.jdbc.JdbcUtils.closeQuietly;
17 import static com.atlassian.dbexporter.jdbc.JdbcUtils.quote;
18 import static com.google.common.base.Preconditions.checkNotNull;
19
20 public final class SqlServerAroundTableImporter implements DataImporter.AroundTableImporter {
21 private final ImportExportErrorService errorService;
22 private final String schema;
23
24 public SqlServerAroundTableImporter(ImportExportErrorService errorService, String schema) {
25 this.errorService = checkNotNull(errorService);
26 this.schema = schema;
27 }
28
29 @Override
30 public void before(ImportConfiguration configuration, Context context, String table, Connection connection) {
31 setIdentityInsert(configuration, context, connection, table, "ON");
32 }
33
34 @Override
35 public void after(ImportConfiguration configuration, Context context, String table, Connection connection) {
36 setIdentityInsert(configuration, context, connection, table, "OFF");
37 }
38
39 private void setIdentityInsert(ImportConfiguration configuration, Context context, Connection connection, String table, String onOff) {
40 if (isSqlServer(configuration) && isAutoIncrementTable(context, table)) {
41 setIdentityInsert(connection, table, onOff);
42 }
43 }
44
45 private boolean isAutoIncrementTable(Context context, final String tableName) {
46 return hasAnyAutoIncrementColumn(findTable(context, tableName));
47 }
48
49 private boolean hasAnyAutoIncrementColumn(Table table) {
50 return Iterables.any(table.getColumns(), new Predicate<Column>() {
51 @Override
52 public boolean apply(Column c) {
53 return c.isAutoIncrement();
54 }
55 });
56 }
57
58 private Table findTable(Context context, final String tableName) {
59 return Iterables.find(context.getAll(Table.class), new Predicate<Table>() {
60 @Override
61 public boolean apply(Table t) {
62 return t.getName().equals(tableName);
63 }
64 });
65 }
66
67 private void setIdentityInsert(Connection connection, String table, String onOff) {
68 Statement s = null;
69 try {
70 s = connection.createStatement();
71 s.execute(setIdentityInsertSql(quote(errorService, table, connection, table), onOff));
72 } catch (SQLException e) {
73 throw errorService.newImportExportSqlException(table, "", e);
74 } finally {
75 closeQuietly(s);
76 }
77 }
78
79 private String setIdentityInsertSql(String table, String onOff) {
80 return String.format("SET IDENTITY_INSERT %s %s", schema != null ? schema + "." + table : table, onOff);
81 }
82
83 private boolean isSqlServer(ImportConfiguration configuration) {
84 return DatabaseInformations.Database.Type.MSSQL.equals(database(configuration.getDatabaseInformation()).getType());
85 }
86 }