View Javadoc
1   /*
2    *    Copyright 2010-2023 the original author or authors.
3    *
4    *    Licensed under the Apache License, Version 2.0 (the "License");
5    *    you may not use this file except in compliance with the License.
6    *    You may obtain a copy of the License at
7    *
8    *       https://www.apache.org/licenses/LICENSE-2.0
9    *
10   *    Unless required by applicable law or agreed to in writing, software
11   *    distributed under the License is distributed on an "AS IS" BASIS,
12   *    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13   *    See the License for the specific language governing permissions and
14   *    limitations under the License.
15   */
16  package org.apache.ibatis.migration.operations;
17  
18  import java.io.BufferedReader;
19  import java.io.PrintWriter;
20  import java.io.Reader;
21  import java.sql.Connection;
22  import java.sql.ResultSet;
23  import java.sql.ResultSetMetaData;
24  import java.sql.SQLException;
25  import java.sql.SQLWarning;
26  import java.sql.Statement;
27  import java.util.regex.Matcher;
28  import java.util.regex.Pattern;
29  
30  /**
31   * @author Clinton Begin
32   */
/**
 * Executes SQL scripts read from a {@link Reader} against a JDBC {@link Connection}.
 * <p>
 * By default the script is split into statements on a delimiter ({@code ;} unless changed through
 * {@link #setDelimiter(String)} or an {@code @DELIMITER} directive placed in a comment line) and each statement is
 * sent on its own. When full-script mode is enabled ({@link #setSendFullScript(boolean)}) the whole text is sent as
 * a single statement instead.
 *
 * @author Clinton Begin
 */
public class ScriptRunner {

  private static final String LINE_SEPARATOR = System.lineSeparator();

  private static final String DEFAULT_DELIMITER = ";";

  // Recognizes an "@DELIMITER <token>" directive, optionally preceded by "--" or "//"; group 5 is the token.
  private static final Pattern DELIMITER_PATTERN = Pattern
      .compile("^\\s*((--)|(//))?\\s*(//)?\\s*@DELIMITER\\s+([^\\s]+)", Pattern.CASE_INSENSITIVE);

  private final Connection connection;

  private boolean stopOnError;
  private boolean throwWarning;
  private boolean autoCommit;
  private boolean sendFullScript;
  private boolean removeCRs;
  private boolean escapeProcessing = true;

  private PrintWriter logWriter = new PrintWriter(System.out);
  private PrintWriter errorLogWriter = new PrintWriter(System.err);

  private String delimiter = DEFAULT_DELIMITER;
  private boolean fullLineDelimiter;

  /**
   * Creates a runner bound to the given connection.
   *
   * @param connection
   *          the connection every statement is executed on
   */
  public ScriptRunner(Connection connection) {
    this.connection = connection;
  }

  /**
   * Chooses whether a failing statement aborts the run ({@code true}) or is only reported to the error writer
   * ({@code false}, the default).
   *
   * @param stopOnError
   *          whether to abort on the first failing statement
   */
  public void setStopOnError(boolean stopOnError) {
    this.stopOnError = stopOnError;
  }

  /**
   * Chooses whether a warning reported on a statement is raised as a failure.
   *
   * @param throwWarning
   *          whether statement warnings should be thrown
   */
  public void setThrowWarning(boolean throwWarning) {
    this.throwWarning = throwWarning;
  }

  /**
   * Sets the auto-commit mode applied to the connection before a script runs.
   *
   * @param autoCommit
   *          the desired auto-commit mode
   */
  public void setAutoCommit(boolean autoCommit) {
    this.autoCommit = autoCommit;
  }

  /**
   * Chooses whether the whole script is sent as one statement instead of being split on the delimiter.
   *
   * @param sendFullScript
   *          whether to send the script in one piece
   */
  public void setSendFullScript(boolean sendFullScript) {
    this.sendFullScript = sendFullScript;
  }

  /**
   * Chooses whether {@code \r\n} sequences are turned into {@code \n} before a statement is sent.
   *
   * @param removeCRs
   *          whether to strip carriage returns preceding line feeds
   */
  public void setRemoveCRs(boolean removeCRs) {
    this.removeCRs = removeCRs;
  }

  /**
   * Sets the escape processing.
   *
   * @param escapeProcessing
   *          the new escape processing
   *
   * @since 3.1.1
   */
  public void setEscapeProcessing(boolean escapeProcessing) {
    this.escapeProcessing = escapeProcessing;
  }

  /**
   * Sets the writer that receives executed statements, comment lines and query results.
   *
   * @param logWriter
   *          the log writer, or {@code null} to suppress this output
   */
  public void setLogWriter(PrintWriter logWriter) {
    this.logWriter = logWriter;
  }

  /**
   * Sets the writer that receives error messages.
   *
   * @param errorLogWriter
   *          the error writer, or {@code null} to suppress error output
   */
  public void setErrorLogWriter(PrintWriter errorLogWriter) {
    this.errorLogWriter = errorLogWriter;
  }

  /**
   * Sets the statement delimiter used when splitting a script. An {@code @DELIMITER} directive in a comment line
   * replaces it while the script is being read.
   *
   * @param delimiter
   *          the new delimiter
   */
  public void setDelimiter(String delimiter) {
    this.delimiter = delimiter;
  }

  /**
   * Chooses how a delimiter ends a statement: when {@code true}, only a line consisting solely of the delimiter
   * (ignoring surrounding whitespace) does; otherwise any line containing the delimiter does.
   *
   * @param fullLineDelimiter
   *          whether the delimiter must occupy a whole line
   */
  public void setFullLineDelimiter(boolean fullLineDelimiter) {
    this.fullLineDelimiter = fullLineDelimiter;
  }

  /**
   * Runs the given script.
   * <p>
   * The connection's auto-commit mode is first aligned with {@link #setAutoCommit(boolean)}. After a successful run
   * the work is committed when auto-commit is off. A rollback is always attempted afterwards (again only when
   * auto-commit is off), which discards anything a failure left uncommitted.
   *
   * @param reader
   *          the script source; in full-script mode this method closes it, in line-by-line mode it is left open
   *
   * @throws RuntimeException
   *           if the auto-commit mode cannot be applied, or reading, executing (when failures are not merely
   *           reported) or committing the script fails
   */
  public void runScript(Reader reader) {
    applyAutoCommitSetting();
    try {
      if (!sendFullScript) {
        runLineByLine(reader);
      } else {
        runAsSingleStatement(reader);
      }
    } finally {
      rollbackQuietly();
    }
  }

  /** Sends the entire script as a single statement, then commits; closes {@code reader}. */
  private void runAsSingleStatement(Reader reader) {
    StringBuilder buffer = new StringBuilder();
    try (BufferedReader in = new BufferedReader(reader)) {
      for (String line = in.readLine(); line != null; line = in.readLine()) {
        buffer.append(line).append(LINE_SEPARATOR);
      }
      String fullScript = buffer.toString();
      println(fullScript);
      runStatement(fullScript);
      commitIfManual();
    } catch (Exception e) {
      throw reportFailure(buffer, e);
    }
  }

  /** Reads the script line by line and executes each statement as soon as its delimiter is reached. */
  private void runLineByLine(Reader reader) {
    StringBuilder pending = new StringBuilder();
    try {
      // Unlike full-script mode, the caller's reader is not closed here.
      BufferedReader in = new BufferedReader(reader);
      for (String line = in.readLine(); line != null; line = in.readLine()) {
        processLine(pending, line);
      }
      // Statements executed so far are committed before any unterminated remainder is reported.
      commitIfManual();
      checkForMissingLineTerminator(pending);
    } catch (Exception e) {
      throw reportFailure(pending, e);
    }
  }

  /** Reports a failure of the statement text in {@code context} and wraps the cause for rethrowing. */
  private RuntimeException reportFailure(StringBuilder context, Exception cause) {
    String message = "Error executing: " + context + ".  Cause: " + cause;
    printlnError(message);
    return new RuntimeException(message, cause);
  }

  /**
   * Closes the underlying connection, ignoring any error.
   *
   * @deprecated Since 3.5.4, this method is deprecated. Please close the {@link Connection} outside of this class.
   */
  @Deprecated
  public void closeConnection() {
    try {
      connection.close();
    } catch (Exception ignored) {
      // best effort: a failure to close is not reported
    }
  }

  /** Switches the connection's auto-commit mode to the configured value when the two differ. */
  private void applyAutoCommitSetting() {
    try {
      boolean current = connection.getAutoCommit();
      if (current != autoCommit) {
        connection.setAutoCommit(autoCommit);
      }
    } catch (Throwable t) {
      throw new RuntimeException("Could not set AutoCommit to " + autoCommit + ". Cause: " + t, t);
    }
  }

  /** Commits the current transaction unless the connection is in auto-commit mode. */
  private void commitIfManual() {
    try {
      if (connection.getAutoCommit()) {
        return;
      }
      connection.commit();
    } catch (Throwable t) {
      throw new RuntimeException("Could not commit transaction. Cause: " + t, t);
    }
  }

  /** Rolls back when auto-commit is off; any error is discarded so it cannot mask the outcome of the run. */
  private void rollbackQuietly() {
    try {
      if (!connection.getAutoCommit()) {
        connection.rollback();
      }
    } catch (Throwable ignored) {
      // best effort only
    }
  }

  /** Fails when text that was never terminated by the delimiter is still buffered. */
  private void checkForMissingLineTerminator(StringBuilder command) {
    boolean hasLeftover = command != null && !command.toString().trim().isEmpty();
    if (hasLeftover) {
      throw new RuntimeException("Line missing end-of-line terminator (" + delimiter + ") => " + command);
    }
  }

  /**
   * Handles one script line: comments are logged (and may change the delimiter), a line carrying the delimiter
   * completes and executes the buffered statement, and any other non-blank line is buffered.
   */
  private void processLine(StringBuilder command, String line) throws SQLException {
    String trimmed = line.trim();

    if (isCommentLine(trimmed)) {
      Matcher directive = DELIMITER_PATTERN.matcher(trimmed);
      if (directive.find()) {
        delimiter = directive.group(5);
      }
      println(trimmed);
      return;
    }

    if (endsStatement(trimmed)) {
      // issue #561: drop whatever follows the (last) delimiter on this line
      int cut = line.lastIndexOf(delimiter);
      command.append(line, 0, cut).append(LINE_SEPARATOR);
      println(command);
      runStatement(command.toString());
      command.setLength(0);
      return;
    }

    if (!trimmed.isEmpty()) {
      command.append(line).append(LINE_SEPARATOR);
    }
  }

  /** Returns whether the trimmed line starts with a {@code --} or {@code //} comment marker. */
  private boolean isCommentLine(String trimmedLine) {
    return trimmedLine.startsWith("--") || trimmedLine.startsWith("//");
  }

  /** Returns whether the trimmed line terminates the current statement under the active delimiter mode. */
  private boolean endsStatement(String trimmedLine) {
    if (fullLineDelimiter) {
      return trimmedLine.equals(delimiter);
    }
    return trimmedLine.contains(delimiter);
  }

  /**
   * Executes one statement and prints every result it produces. A {@link SQLException} is rethrown when
   * stop-on-error is enabled and otherwise only reported; a {@link SQLWarning} is always rethrown.
   */
  private void runStatement(String command) throws SQLException {
    Statement statement = connection.createStatement();
    try {
      statement.setEscapeProcessing(escapeProcessing);
      String sql = removeCRs ? command.replace("\r\n", "\n") : command;
      try {
        // DO NOT try to 'improve' the loop condition even if an IDE suggests it!
        // Results are exhausted only when there is no result set AND getUpdateCount() returns -1,
        // so getUpdateCount() must be called whenever hasResults is false.
        for (boolean hasResults = statement.execute(sql); !(!hasResults && statement.getUpdateCount() == -1);
            hasResults = statement.getMoreResults()) {
          checkWarnings(statement);
          printResultSet(statement, hasResults);
        }
      } catch (SQLWarning e) {
        // SQLWarning extends SQLException; it must propagate even when stopOnError is false.
        throw e;
      } catch (SQLException e) {
        if (stopOnError) {
          throw e;
        }
        printlnError("Error executing: " + command + ".  Cause: " + e);
      }
    } finally {
      closeQuietly(statement);
    }
  }

  /** Closes a statement, discarding any exception raised while doing so. */
  private static void closeQuietly(Statement statement) {
    try {
      statement.close();
    } catch (Exception ignored) {
      // Ignored to work around a bug in some connection pools (the exact bug is not documented).
    }
  }

  /** Raises the first warning attached to the statement when warnings are configured to be thrown. */
  private void checkWarnings(Statement statement) throws SQLException {
    if (throwWarning) {
      // In Oracle, CREATE PROCEDURE, FUNCTION, etc. report a compilation error as a warning
      // instead of throwing an exception.
      SQLWarning warning = statement.getWarnings();
      if (warning != null) {
        throw warning;
      }
    }
  }

  /** Prints the current result set (header row of column labels, then data rows) as tab-separated text. */
  private void printResultSet(Statement statement, boolean hasResults) {
    if (!hasResults) {
      return;
    }
    try (ResultSet rs = statement.getResultSet()) {
      ResultSetMetaData md = rs.getMetaData();
      int columnCount = md.getColumnCount();
      for (int col = 1; col <= columnCount; col++) {
        print(md.getColumnLabel(col) + "\t");
      }
      println("");
      while (rs.next()) {
        for (int col = 1; col <= columnCount; col++) {
          print(rs.getString(col) + "\t");
        }
        println("");
      }
    } catch (SQLException e) {
      printlnError("Error printing results: " + e.getMessage());
    }
  }

  private void print(Object o) {
    emit(logWriter, o, false);
  }

  private void println(Object o) {
    emit(logWriter, o, true);
  }

  private void printlnError(Object o) {
    emit(errorLogWriter, o, true);
  }

  /** Writes {@code value} (plus a line break when requested) to {@code writer} and flushes; no-op if null. */
  private static void emit(PrintWriter writer, Object value, boolean newLine) {
    if (writer == null) {
      return;
    }
    if (newLine) {
      writer.println(value);
    } else {
      writer.print(value);
    }
    writer.flush();
  }

}
329 }