View Javadoc
1   /*
2    *    Copyright 2009-2023 the original author or authors.
3    *
4    *    Licensed under the Apache License, Version 2.0 (the "License");
5    *    you may not use this file except in compliance with the License.
6    *    You may obtain a copy of the License at
7    *
8    *       https://www.apache.org/licenses/LICENSE-2.0
9    *
10   *    Unless required by applicable law or agreed to in writing, software
11   *    distributed under the License is distributed on an "AS IS" BASIS,
12   *    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13   *    See the License for the specific language governing permissions and
14   *    limitations under the License.
15   */
16  package org.mybatis.guice.transactional;
17  
18  import java.lang.reflect.Field;
19  import java.util.Arrays;
20  import java.util.IdentityHashMap;
21  import java.util.concurrent.ConcurrentHashMap;
22  
23  import javax.transaction.xa.XAException;
24  import javax.transaction.xa.XAResource;
25  import javax.transaction.xa.Xid;
26  
27  import org.apache.ibatis.logging.Log;
28  import org.apache.ibatis.logging.LogFactory;
29  import org.apache.ibatis.session.SqlSession;
30  import org.apache.ibatis.session.SqlSessionManager;
31  
32  public class XASqlSessionManager implements XAResource {
33    private static final Log log = LogFactory.getLog(XASqlSessionManager.class);
34  
35    public static final int NO_TX = 0;
36    public static final int STARTED = 1;
37    public static final int ENDED = 2;
38    public static final int PREPARED = 3;
39  
40    private SqlSessionManager sqlSessionManager;
41    private int transactionTimeout;
42    private String id;
43    private Xid xid;
44    private int state = NO_TX;
45  
46    private static ConcurrentHashMap<GlobalKey, GlobalToken> globalTokens = new ConcurrentHashMap<XASqlSessionManager.GlobalKey, XASqlSessionManager.GlobalToken>();
47  
48    public XASqlSessionManager(SqlSessionManager sqlSessionManager) {
49      this.sqlSessionManager = sqlSessionManager;
50      id = sqlSessionManager.getConfiguration().getEnvironment().getId();
51    }
52  
53    public String getId() {
54      return id;
55    }
56  
57    public int getState() {
58      return state;
59    }
60  
61    private String xlatedState() {
62      switch (state) {
63        case NO_TX:
64          return "NO_TX";
65        case STARTED:
66          return "STARTED";
67        case ENDED:
68          return "ENDED";
69        case PREPARED:
70          return "PREPARED";
71        default:
72          return "!invalid state (" + state + ")!";
73      }
74    }
75  
76    private String decodeXAResourceFlag(int flag) {
77      switch (flag) {
78        case XAResource.TMENDRSCAN:
79          return "TMENDRSCAN";
80        case XAResource.TMFAIL:
81          return "TMFAIL";
82        case XAResource.TMJOIN:
83          return "TMJOIN";
84        case XAResource.TMNOFLAGS:
85          return "TMNOFLAGS";
86        case XAResource.TMONEPHASE:
87          return "TMONEPHASE";
88        case XAResource.TMRESUME:
89          return "TMRESUME";
90        case XAResource.TMSTARTRSCAN:
91          return "TMSTARTRSCAN";
92        case XAResource.TMSUCCESS:
93          return "TMSUCCESS";
94        case XAResource.TMSUSPEND:
95          return "TMSUSPEND";
96        default:
97          return "" + flag;
98      }
99    }
100 
101   @Override
102   public int getTransactionTimeout() throws XAException {
103     return transactionTimeout;
104   }
105 
106   @Override
107   public boolean setTransactionTimeout(int second) throws XAException {
108     transactionTimeout = second;
109     return true;
110   }
111 
112   @Override
113   public void forget(Xid xid) throws XAException {
114   }
115 
116   @Override
117   public Xid[] recover(int flags) throws XAException {
118     return new Xid[0];
119   }
120 
121   @Override
122   public boolean isSameRM(XAResource xares) throws XAException {
123     return this == xares;
124   }
125 
126   @Override
127   public void start(Xid xid, int flag) throws XAException {
128     if (log.isDebugEnabled()) {
129       log.debug(
130           id + ": call start old state=" + xlatedState() + ", XID=" + xid + ", flag=" + decodeXAResourceFlag(flag));
131     }
132 
133     if (flag != XAResource.TMNOFLAGS && flag != XAResource.TMJOIN) {
134       throw new MyBatisXAException(id + ": unsupported start flag " + decodeXAResourceFlag(flag),
135           XAException.XAER_RMERR);
136     }
137 
138     if (xid == null) {
139       throw new MyBatisXAException(id + ": XID cannot be null", XAException.XAER_INVAL);
140     }
141 
142     if (state == NO_TX) {
143       if (this.xid != null) {
144         throw new MyBatisXAException(id + ": resource already started on XID " + this.xid, XAException.XAER_PROTO);
145       } else {
146         if (flag == XAResource.TMJOIN) {
147           throw new MyBatisXAException(id + ": resource not yet started", XAException.XAER_PROTO);
148         } else {
149           if (log.isDebugEnabled()) {
150             log.debug(id + ": OK to start, old state=" + xlatedState() + ", XID=" + xid + ", flag="
151                 + decodeXAResourceFlag(flag));
152           }
153           this.xid = xid;
154         }
155       }
156     } else if (state == STARTED) {
157       throw new MyBatisXAException(id + ": resource already started on XID " + this.xid, XAException.XAER_PROTO);
158     } else if (state == ENDED) {
159       if (flag == XAResource.TMNOFLAGS) {
160         throw new MyBatisXAException(id + ": resource already registered XID " + this.xid, XAException.XAER_DUPID);
161       } else {
162         if (xid.equals(this.xid)) {
163           if (log.isDebugEnabled()) {
164             log.debug(id + ": OK to join, old state=" + xlatedState() + ", XID=" + xid + ", flag="
165                 + decodeXAResourceFlag(flag));
166           }
167         } else {
168           throw new MyBatisXAException(id + ": resource already started on XID " + this.xid
169               + " - cannot start it on more than one XID at a time", XAException.XAER_RMERR);
170         }
171       }
172     } else if (state == PREPARED) {
173       throw new MyBatisXAException(id + ": resource already prepared on XID " + this.xid, XAException.XAER_PROTO);
174     }
175 
176     state = STARTED;
177     parentSuspend(xid);
178   }
179 
180   @Override
181   public void end(Xid xid, int flag) throws XAException {
182     if (log.isDebugEnabled()) {
183       log.debug(
184           id + ": call end old state=" + xlatedState() + ", XID=" + xid + " and flag " + decodeXAResourceFlag(flag));
185     }
186 
187     if (flag != XAResource.TMSUCCESS && flag != XAResource.TMFAIL) {
188       throw new MyBatisXAException(id + ": unsupported end flag " + decodeXAResourceFlag(flag), XAException.XAER_RMERR);
189     }
190 
191     if (xid == null) {
192       throw new MyBatisXAException(id + ": XID cannot be null", XAException.XAER_INVAL);
193     }
194 
195     if (state == NO_TX) {
196       throw new MyBatisXAException(id + ": resource never started on XID " + xid, XAException.XAER_PROTO);
197     } else if (state == STARTED) {
198       if (this.xid.equals(xid)) {
199         if (log.isDebugEnabled()) {
200           log.debug(
201               id + ": OK to end, old state=" + xlatedState() + ", XID=" + xid + ", flag=" + decodeXAResourceFlag(flag));
202         }
203       } else {
204         throw new MyBatisXAException(
205             id + ": resource already started on XID " + this.xid + " - cannot end it on another XID " + xid,
206             XAException.XAER_PROTO);
207       }
208     } else if (state == ENDED) {
209       throw new MyBatisXAException(id + ": resource already ended on XID " + xid, XAException.XAER_PROTO);
210     } else if (state == PREPARED) {
211       throw new MyBatisXAException(id + ": cannot end, resource already prepared on XID " + xid,
212           XAException.XAER_PROTO);
213     }
214 
215     if (flag == XAResource.TMFAIL) {
216       // Rollback transaction. After call method end() call method rollback()
217       if (log.isDebugEnabled()) {
218         log.debug(id + ": after end TMFAIL reset state to ENDED and roolback");
219       }
220     }
221 
222     this.state = ENDED;
223   }
224 
225   @Override
226   public int prepare(Xid xid) throws XAException {
227     if (log.isDebugEnabled()) {
228       log.debug(id + ": call prepare old state=" + xlatedState() + ", XID=" + xid);
229     }
230 
231     if (xid == null) {
232       throw new MyBatisXAException(id + ": XID cannot be null", XAException.XAER_INVAL);
233     }
234 
235     if (state == NO_TX) {
236       throw new MyBatisXAException(id + ": resource never started on XID " + xid, XAException.XAER_PROTO);
237     } else if (state == STARTED) {
238       throw new MyBatisXAException(id + ": resource never ended on XID " + xid, XAException.XAER_PROTO);
239     } else if (state == ENDED) {
240       if (this.xid.equals(xid)) {
241         if (log.isDebugEnabled()) {
242           log.debug(id + ": OK to prepare, old state=" + xlatedState() + ", XID=" + xid);
243         }
244       } else {
245         throw new MyBatisXAException(
246             id + ": resource already started on XID " + this.xid + " - cannot prepare it on another XID " + xid,
247             XAException.XAER_PROTO);
248       }
249     } else if (state == PREPARED) {
250       throw new MyBatisXAException(id + ": resource already prepared on XID " + this.xid, XAException.XAER_PROTO);
251     }
252 
253     this.state = PREPARED;
254     return XAResource.XA_OK;
255   }
256 
257   @Override
258   public void commit(Xid xid, boolean onePhase) throws XAException {
259     if (log.isDebugEnabled()) {
260       log.debug(id + ": call commit old state=" + xlatedState() + ", XID=" + xid + " onePhase is " + onePhase);
261     }
262 
263     if (xid == null) {
264       throw new MyBatisXAException(id + ": XID cannot be null", XAException.XAER_INVAL);
265     }
266 
267     if (state == NO_TX) {
268       throw new MyBatisXAException(id + ": resource never started on XID " + xid, XAException.XAER_PROTO);
269     } else if (state == STARTED) {
270       throw new MyBatisXAException(id + ": resource never ended on XID " + xid, XAException.XAER_PROTO);
271     } else if (state == ENDED) {
272       if (onePhase) {
273         if (log.isDebugEnabled()) {
274           log.debug(id + ": OK to commit with 1PC, old state=" + xlatedState() + ", XID=" + xid);
275         }
276       } else {
277         throw new MyBatisXAException(id + ": resource never prepared on XID " + xid, XAException.XAER_PROTO);
278       }
279     } else if (state == PREPARED) {
280       if (!onePhase) {
281         if (this.xid.equals(xid)) {
282           if (log.isDebugEnabled()) {
283             log.debug(id + ": OK to commit, old state=" + xlatedState() + ", XID=" + xid);
284           }
285         } else {
286           throw new MyBatisXAException(
287               id + ": resource already started on XID " + this.xid + " - cannot commit it on another XID " + xid,
288               XAException.XAER_PROTO);
289         }
290       } else {
291         throw new MyBatisXAException(id + ": cannot commit in one phase as resource has been prepared on XID " + xid,
292             XAException.XAER_PROTO);
293       }
294     }
295 
296     try {
297       parentResume(xid);
298     } finally {
299       if (log.isDebugEnabled()) {
300         log.debug(id + ": after commit reset state to NO_TX");
301       }
302       this.state = NO_TX;
303       this.xid = null;
304     }
305   }
306 
307   @Override
308   public void rollback(Xid xid) throws XAException {
309     if (log.isDebugEnabled()) {
310       log.debug(id + ": call roolback old state=" + xlatedState() + ", XID=" + xid);
311     }
312 
313     if (xid == null) {
314       throw new MyBatisXAException(id + ": XID cannot be null", XAException.XAER_INVAL);
315     }
316 
317     if (state == NO_TX) {
318       throw new MyBatisXAException(id + ": resource never started on XID " + xid, XAException.XAER_PROTO);
319     } else if (state == STARTED) {
320       throw new MyBatisXAException(id + ": resource never ended on XID " + xid, XAException.XAER_PROTO);
321     } else if (state == ENDED) {
322       if (this.xid.equals(xid)) {
323         if (log.isDebugEnabled()) {
324           log.debug(id + ": OK to rollback, old state=" + xlatedState() + ", XID=" + xid);
325         }
326       } else {
327         throw new MyBatisXAException(
328             id + ": resource already started on XID " + this.xid + " - cannot roll it back on another XID " + xid,
329             XAException.XAER_PROTO);
330       }
331     } else if (state == PREPARED) {
332       if (log.isDebugEnabled()) {
333         log.debug(id + ": rollback reset state from PREPARED to NO_TX");
334       }
335       this.state = NO_TX;
336       throw new MyBatisXAException(id + ": resource committed during prepare on XID " + this.xid,
337           XAException.XA_HEURCOM);
338     }
339 
340     try {
341       parentResume(xid);
342     } finally {
343       if (log.isDebugEnabled()) {
344         log.debug(id + ": after rollback reset state to NO_TX");
345       }
346       this.state = NO_TX;
347       this.xid = null;
348     }
349   }
350 
351   private void parentSuspend(Xid xid) {
352     if (log.isDebugEnabled()) {
353       log.debug(id + ": suspend parent session " + xid);
354     }
355 
356     byte[] trId = xid.getGlobalTransactionId();
357     GlobalKey key = new GlobalKey(trId);
358     GlobalToken globalToken = globalTokens.get(key);
359 
360     if (globalToken == null) {
361       if (log.isDebugEnabled()) {
362         log.debug(id + ": add GlobalToken " + key);
363       }
364 
365       globalTokens.put(key, globalToken = new GlobalToken());
366     } else {
367       if (log.isDebugEnabled()) {
368         log.debug(id + ": present GlobalToken " + key);
369       }
370     }
371     globalToken.parentSuspend(id, sqlSessionManager);
372   }
373 
374   private void parentResume(Xid xid) {
375     if (log.isDebugEnabled()) {
376       log.debug(id + ": resume parent session " + xid);
377     }
378 
379     byte[] trId = xid.getGlobalTransactionId();
380     GlobalKey key = new GlobalKey(trId);
381     GlobalToken globalToken = globalTokens.get(key);
382 
383     if (globalToken != null) {
384       globalToken.parentResume(id, sqlSessionManager);
385 
386       if (globalToken.isEmpty()) {
387         if (log.isDebugEnabled()) {
388           log.debug(id + ": remove GlobalToken " + key);
389         }
390 
391         globalTokens.remove(key);
392       } else {
393         if (log.isDebugEnabled()) {
394           log.debug(id + ": not remove GlobalToken " + key);
395         }
396       }
397     } else {
398       if (log.isDebugEnabled()) {
399         log.debug(id + ": not find GlobalToken " + key);
400       }
401     }
402   }
403 
404   static class GlobalKey {
405     final byte[] globalId;
406     final int arrayHash;
407 
408     public GlobalKey(byte[] globalId) {
409       this.globalId = globalId;
410       this.arrayHash = Arrays.hashCode(globalId);
411     }
412 
413     @Override
414     public int hashCode() {
415       return arrayHash;
416     }
417 
418     @Override
419     public boolean equals(Object obj) {
420       if (this == obj) {
421         return true;
422       }
423 
424       if (obj == null) {
425         return false;
426       }
427 
428       if (getClass() != obj.getClass()) {
429         return false;
430       }
431 
432       GlobalKey other = (GlobalKey) obj;
433       if (!Arrays.equals(globalId, other.globalId)) {
434         return false;
435       }
436       return true;
437     }
438 
439     @Override
440     public String toString() {
441       StringBuilder s = new StringBuilder();
442       s.append("[Xid:globalId=");
443       for (int i = 0; i < globalId.length; i++) {
444         s.append(Integer.toHexString(globalId[i]));
445       }
446       s.append(",length=").append(globalId.length);
447       return s.toString();
448     }
449   }
450 
451   static class GlobalToken {
452     private final Log log = LogFactory.getLog(getClass());
453     IdentityHashMap<SqlSessionManager, Token> tokens = new IdentityHashMap<SqlSessionManager, XASqlSessionManager.Token>();
454 
455     public GlobalToken() {
456     }
457 
458     void parentSuspend(String id, SqlSessionManager sqlSessionManager) {
459       Token token = tokens.get(sqlSessionManager);
460 
461       if (token == null) {
462         if (log.isDebugEnabled()) {
463           log.debug(id + ": add Token " + sqlSessionManager);
464         }
465 
466         token = new Token(sqlSessionManager);
467         tokens.put(sqlSessionManager, token);
468       } else {
469         if (log.isDebugEnabled()) {
470           log.debug(id + ": present Token " + sqlSessionManager);
471         }
472       }
473       token.parentSuspend(id);
474     }
475 
476     void parentResume(String id, SqlSessionManager sqlSessionManager) {
477       Token token = tokens.get(sqlSessionManager);
478 
479       if (token != null) {
480         token.parentResume(id);
481 
482         // remove last
483         if (token.isFirst()) {
484           if (log.isDebugEnabled()) {
485             log.debug(id + ": remove parent session " + sqlSessionManager);
486           }
487 
488           tokens.remove(sqlSessionManager);
489         }
490       } else {
491         if (log.isDebugEnabled()) {
492           log.debug(id + ": not find parent session " + sqlSessionManager);
493         }
494       }
495     }
496 
497     boolean isEmpty() {
498       return tokens.isEmpty();
499     }
500   }
501 
502   static class Token {
503     private final Log log = LogFactory.getLog(getClass());
504     final SqlSessionManager sqlSessionManager;
505     ThreadLocal<SqlSession> localSqlSession;
506     SqlSession suspendedSqlSession;
507     int count;
508 
509     @SuppressWarnings("unchecked")
510     public Token(SqlSessionManager sqlSessionManager) {
511       this.sqlSessionManager = sqlSessionManager;
512       this.count = 0;
513       try {
514         Field field = SqlSessionManager.class.getDeclaredField("localSqlSession");
515         field.setAccessible(true);
516         localSqlSession = (ThreadLocal<SqlSession>) field.get(sqlSessionManager);
517       } catch (Exception e) {
518       }
519     }
520 
521     boolean isFirst() {
522       return count == 0;
523     }
524 
525     void parentSuspend(String id) {
526       if (isFirst()) {
527         if (log.isDebugEnabled()) {
528           log.debug(id + " suspend parent session");
529         }
530 
531         if (localSqlSession != null) {
532           suspendedSqlSession = localSqlSession.get();
533           localSqlSession.remove();
534         }
535       } else {
536         if (log.isDebugEnabled()) {
537           log.debug(id + " skip suspend parent session");
538         }
539       }
540       count++;
541     }
542 
543     void parentResume(String id) {
544       if (count > 0) {
545         count--;
546       }
547 
548       if (isFirst()) {
549         if (log.isDebugEnabled()) {
550           log.debug(id + " resume parent session");
551         }
552 
553         if (localSqlSession != null) {
554           if (suspendedSqlSession == null) {
555             localSqlSession.remove();
556           } else {
557             localSqlSession.set(suspendedSqlSession);
558             suspendedSqlSession = null;
559           }
560         }
561       } else {
562         if (log.isDebugEnabled()) {
563           log.debug(id + " skip resume parent session");
564         }
565       }
566     }
567   }
568 }