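All operations in a transaction must either all succeed or all fail. Inside a single JVM this is fairly easy to achieve: for JDBC operations against one database, Spring can manage the transaction for us. In a distributed environment, however, service A calls service B. If an exception occurs along the way, how do we make sure that the transactions in both A and B are rolled back?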
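There are generally three approaches to distributed transactions:
1. 2PC (two-phase commit)
2. Eventual consistency via reliable messaging
3. TCC (Try-Confirm-Cancel)
Here we cover 2PC. In the first phase, every participating service executes its local transaction but keeps it open, then reports to a coordinator whether it succeeded. In the second phase, the coordinator tells all participants to commit if every one of them succeeded, and tells them all to roll back otherwise. In this implementation the coordinator is a Netty server.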
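The coordinator's phase-two decision is a single rule: commit only if every vote is COMMIT. Below is a minimal sketch of just that rule. `Vote` and `TwoPhaseDecision.decide` are hypothetical names and are not part of the framework that follows.

```java
import java.util.List;

// Hypothetical vote type: what each participant reports after phase 1.
enum Vote { COMMIT, ROLLBACK }

class TwoPhaseDecision {
    // Phase 2: commit only if no participant voted ROLLBACK.
    static String decide(List<Vote> votes) {
        return votes.contains(Vote.ROLLBACK) ? "rollback" : "commit";
    }
}
```

In the example that follows, A throws after calling B, so A votes ROLLBACK and B votes COMMIT. The decision is therefore "rollback", and both inserts must be undone.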
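In this example, A calls B, and then A itself throws an exception. Because the exception occurs after B has already written its data, B's work must be rolled back too.
Service A: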
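@Service
public class DemoService {
    @Autowired
    private DemoDao demoDao;

    @SxmTransactional(start = true)
    public void test() {
        demoDao.insert("server1");
        // call service B; HttpUtil propagates the current groupId in the request headers
        HttpUtil.post("http://localhost:8082/server2/test");
        // simulate a failure in A after B has already done its work
        int i = 1 / 0;
    }
}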
Service B:
@Service
public class DemoService {
    @Autowired
    private DemoDao demoDao;

    @SxmTransactional(end = true)
    public void test() {
        System.out.println("executing server2 business logic");
        demoDao.insert("server2");
    }
}
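The custom distributed-transaction annotation. It is meta-annotated with @Transactional, so every annotated method also runs inside an ordinary local Spring transaction: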
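@Target(ElementType.METHOD)
@Retention(RetentionPolicy.RUNTIME)
@Transactional
public @interface SxmTransactional {
    // true on the method that starts the distributed transaction (creates the group)
    boolean start() default false;
    // true on the last participant; its report tells the coordinator how many transactions to expect
    boolean end() default false;
}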
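`@Retention(RetentionPolicy.RUNTIME)` is what lets the aspect read `start` and `end` through reflection at run time. With the default CLASS retention, `getAnnotation` would return null. The sketch below is JDK-only and uses a hypothetical stand-in annotation `TxRole` without the Spring meta-annotation. It performs the same `method.getAnnotation(...)` lookup that the aspect uses.

```java
import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;
import java.lang.reflect.Method;

// Hypothetical stand-in for SxmTransactional, minus @Transactional, so it runs on the plain JDK.
@Target(ElementType.METHOD)
@Retention(RetentionPolicy.RUNTIME)
@interface TxRole {
    boolean start() default false;
    boolean end() default false;
}

class ServiceA {
    @TxRole(start = true)
    public void test() { }
}

class ServiceB {
    @TxRole(end = true)
    public void test() { }
}

class PlainService {
    public void test() { }
}

class RoleReader {
    // Same lookup the aspect performs: find the method, then read its annotation.
    static String role(Class<?> type) {
        try {
            Method m = type.getMethod("test");
            TxRole r = m.getAnnotation(TxRole.class);
            if (r == null) return "none";
            if (r.start()) return "start";
            return r.end() ? "end" : "middle";
        } catch (NoSuchMethodException e) {
            throw new IllegalStateException(e);
        }
    }
}
```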
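Before service A's method runs, some preparation is needed: creating or joining the transaction group and registering a local transaction. This is done with Spring AOP.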
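// Lower @Order values have higher precedence. Spring's transaction interceptor defaults to
// Ordered.LOWEST_PRECEDENCE, so with 10000 this aspect wraps it and binds the local Transaction
// to the thread before DataSource.getConnection() (intercepted by ConnectionAspect) is called.
@Order(10000)
@Aspect
@Component
public class TransactionAspect {
    @Around("@annotation(com.su.annotation.SxmTransactional)")
    public Object invoke(ProceedingJoinPoint joinPoint) {
        MethodSignature signature = (MethodSignature) joinPoint.getSignature();
        Method method = signature.getMethod();
        SxmTransactional annotation = method.getAnnotation(SxmTransactional.class);
        String gid;
        if (annotation.start()) {
            // create the transaction group
            gid = TransactionManager.createGroup();
        } else {
            // join the group passed in by the calling service
            gid = TransactionManager.getCurrentGroup();
        }
        // create the local transaction
        Transaction transaction = TransactionManager.createTransaction(gid);
        try {
            // Run the target method (this is where we need control of the real database connection).
            // The method must finish quickly and must not wait for the global decision, or it would
            // deadlock: its vote is only sent after it returns. So commit/rollback on the wrapped
            // connection hands the real work to a new thread.
            Object result = joinPoint.proceed();
            // report the local outcome COMMIT to the Netty coordinator
            TransactionManager.commitTransaction(transaction, annotation.end(), TransactionType.COMMIT);
            return result;
        } catch (Throwable throwable) {
            // report the local outcome ROLLBACK to the Netty coordinator (the exception is swallowed, as in the demo)
            TransactionManager.commitTransaction(transaction, annotation.end(), TransactionType.ROLLBACK);
            throwable.printStackTrace();
            return null;
        }
    }
}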
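Without Spring, the aspect's job comes down to three steps: pick or create the group id, run the target, and turn the outcome into a COMMIT or ROLLBACK report. The JDK-only sketch below uses hypothetical names (`AspectFlowSketch`, `Report`). It catches `RuntimeException` rather than `Throwable` only because a `Runnable` cannot throw checked exceptions.

```java
import java.util.UUID;
import java.util.function.Supplier;

// Hypothetical model of TransactionAspect.invoke(); no Spring involved.
class AspectFlowSketch {
    // Stand-in for the "add" message that commitTransaction() sends to the coordinator.
    static final class Report {
        final String gid;
        final boolean end;
        final String type;

        Report(String gid, boolean end, String type) {
            this.gid = gid;
            this.end = end;
            this.type = type;
        }
    }

    static Report invoke(boolean start, boolean end, Supplier<String> incomingGid, Runnable target) {
        // start = true opens a new group; otherwise join the one propagated by the caller
        String gid = start ? UUID.randomUUID().toString() : incomingGid.get();
        try {
            target.run();
            return new Report(gid, end, "COMMIT");
        } catch (RuntimeException e) {
            // any failure in the local step becomes a ROLLBACK vote
            return new Report(gid, end, "ROLLBACK");
        }
    }
}
```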
The target method performs JDBC operations, so we need to take control of the database connection it uses. This is again done with AOP.
@Aspect
@Component
public class ConnectionAspect {
    @Around("execution(* javax.sql.DataSource.getConnection(..))")
    public Connection getConnection(ProceedingJoinPoint joinPoint) throws Throwable {
        // the real JDBC connection; failures propagate instead of returning null
        Connection connection = (Connection) joinPoint.proceed();
        // the local Transaction bound to this thread's ThreadLocal, which is why TransactionAspect must run first
        Transaction currentTransaction = TransactionManager.getCurrentTransaction();
        if (currentTransaction == null) {
            // not inside an @SxmTransactional method: hand back the plain connection
            return connection;
        }
        return new SxmConnection(connection, currentTransaction);
    }
}
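The ordering requirement is easy to misread. TransactionAspect does not run "before ConnectionAspect" as such. It has to wrap Spring's transaction interceptor, because that interceptor is what calls `getConnection()`. The JDK-only model below nests around-advice by order value, with the lowest value outermost, and records what the connection lookup sees. It is a hypothetical illustration, not Spring's implementation. The Spring default of `Ordered.LOWEST_PRECEDENCE` for the transaction advisor is an assumption worth checking against your configuration.

```java
import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;
import java.util.function.UnaryOperator;

// Hypothetical model of how @Around advice nests by @Order.
class AdviceOrderSketch {
    static final ThreadLocal<String> currentTx = new ThreadLocal<>();
    static final List<String> log = new ArrayList<>();

    static final class Advice {
        final int order;
        final UnaryOperator<Runnable> wrapper;

        Advice(int order, UnaryOperator<Runnable> wrapper) {
            this.order = order;
            this.wrapper = wrapper;
        }
    }

    // Lowest order value = highest precedence = outermost, so wrap starting from the highest value.
    static Runnable chain(List<Advice> advices, Runnable target) {
        List<Advice> sorted = new ArrayList<>(advices);
        sorted.sort(Comparator.comparingInt((Advice a) -> a.order).reversed());
        Runnable r = target;
        for (Advice a : sorted) {
            r = a.wrapper.apply(r);
        }
        return r;
    }

    // Like TransactionAspect: binds the local transaction, then proceeds.
    static Advice sxmAspect(int order) {
        return new Advice(order, next -> () -> { currentTx.set("tx-1"); next.run(); });
    }

    // Like Spring's transaction interceptor: obtains the connection (where ConnectionAspect reads the ThreadLocal).
    static Advice springTx(int order) {
        return new Advice(order, next -> () -> { log.add("getConnection sees " + currentTx.get()); next.run(); });
    }

    static String run(int sxmOrder, int springOrder) {
        log.clear();
        currentTx.remove();
        chain(List.of(sxmAspect(sxmOrder), springTx(springOrder)), () -> log.add("target")).run();
        return log.get(0);
    }
}
```

`run(10000, Integer.MAX_VALUE)` mirrors the article's setup. Swapping the orders shows the failure mode: the connection is obtained before any transaction is bound.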
The wrapped connection returned by the aspect:
public class SxmConnection implements Connection {
    // the real database connection
    private Connection connection;
    // the local custom transaction this connection belongs to
    private Transaction transaction;

    public SxmConnection(Connection connection, Transaction transaction) {
        this.connection = connection;
        this.transaction = transaction;
    }

    @Override
    public void commit() throws SQLException {
        // called by Spring when the local method succeeds; the real work is deferred to a new thread
        new Thread(() -> {
            // wait for the Netty coordinator's notice before committing
            transaction.await();
            try {
                // the coordinator may have changed the transaction's type:
                // commit only if the whole group commits, otherwise roll back
                if (transaction.getType() == TransactionType.COMMIT) {
                    connection.commit();
                } else {
                    connection.rollback();
                }
                connection.close();
            } catch (Exception e) {
                e.printStackTrace();
            }
        }).start();
    }

    @Override
    public void rollback() throws SQLException {
        new Thread(() -> {
            // wait for the Netty coordinator's notice before rolling back
            transaction.await();
            try {
                connection.rollback();
                connection.close();
            } catch (SQLException e) {
                e.printStackTrace();
            }
        }).start();
    }

    // Must not call connection.close() here: the real connection has to stay open until the
    // deferred commit/rollback thread has finished, and that thread closes it.
    @Override
    public void close() throws SQLException {
    }

    @Override
    public Statement createStatement() throws SQLException {
        return connection.createStatement();
    }

    @Override
    public PreparedStatement prepareStatement(String sql) throws SQLException {
        return connection.prepareStatement(sql);
    }

    // ...plus the many remaining Connection methods, overridden the same way
    ......
}
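The `......` stands for the dozens of remaining `Connection` methods. One way to avoid writing them by hand is a JDK dynamic proxy that forwards every call except the ones that must wait for the coordinator. This is an alternative technique, not the author's class. The sketch also defers `setAutoCommit`. The `java.sql.Connection.setAutoCommit` documentation says that changing the mode in the middle of a transaction commits it, and Spring's `DataSourceTransactionManager` restores auto-commit after the local transaction completes, as far as I know. The article does not show whether the elided methods handle this case. `DeferringConnections`, `wrap` and `demo` are hypothetical names.

```java
import java.lang.reflect.InvocationHandler;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Proxy;
import java.sql.Connection;
import java.sql.SQLException;
import java.util.ArrayList;
import java.util.List;

// Hypothetical: forward everything to the real connection except calls that must wait.
class DeferringConnections {
    private static final ClassLoader LOADER = DeferringConnections.class.getClassLoader();
    private static final Class<?>[] IFACES = {Connection.class};

    static Connection wrap(Connection real, List<String> deferred) {
        InvocationHandler handler = (proxy, method, args) -> {
            switch (method.getName()) {
                case "commit":
                case "rollback":
                case "close":
                    // SxmConnection defers these until the coordinator decides; here we only record them
                    deferred.add(method.getName());
                    return null;
                case "setAutoCommit":
                    // changing auto-commit mid-transaction would commit it, so defer this as well
                    deferred.add("setAutoCommit(" + args[0] + ")");
                    return null;
                default:
                    try {
                        return method.invoke(real, args);
                    } catch (InvocationTargetException e) {
                        throw e.getCause(); // rethrow the real connection's own exception
                    }
            }
        };
        return (Connection) Proxy.newProxyInstance(LOADER, IFACES, handler);
    }

    // Drives the wrapper against a fake connection that records which calls reach it.
    static List<String> demo(List<String> reachedReal) {
        Connection fake = (Connection) Proxy.newProxyInstance(LOADER, IFACES, (p, m, a) -> {
            reachedReal.add(m.getName());
            return m.getReturnType() == boolean.class ? Boolean.FALSE : null;
        });
        List<String> deferred = new ArrayList<>();
        Connection c = wrap(fake, deferred);
        try {
            c.setAutoCommit(true);
            c.commit();
            c.close();
            c.isClosed();
        } catch (SQLException e) {
            throw new IllegalStateException(e);
        }
        return deferred;
    }
}
```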
The local transaction manager:
@Component("localTransactionManager")
public class TransactionManager {
    // all transaction groups managed locally: groupId -> (transactionId -> Transaction).
    // Written by request threads and read by the Netty event-loop thread, hence ConcurrentHashMap.
    public static Map<String, Map<String, Transaction>> groupMap = new ConcurrentHashMap<>();
    // the current thread's transaction group id
    public static ThreadLocal<String> groupThreadLocal = new ThreadLocal<>();
    // the current thread's local transaction
    public static ThreadLocal<Transaction> transactionThreadLocal = new ThreadLocal<>();
    // the current transaction number
    public static ThreadLocal<Integer> currentTransactionNum = new ThreadLocal<>();

    public static NettyClient client;

    @Autowired
    public void setNettyClient(NettyClient nettyClient) {
        client = nettyClient;
    }

    /**
     * Create a transaction group
     */
    public static String createGroup() {
        String gid = UUID.randomUUID().toString();
        groupMap.put(gid, new ConcurrentHashMap<>());
        groupThreadLocal.set(gid);
        // tell the Netty server to create the group
        JSONObject jsonObject = new JSONObject();
        jsonObject.put("groupId", gid);
        jsonObject.put("command", "create");
        client.send(jsonObject);
        return gid;
    }

    /**
     * Create a local transaction
     * @param gid the transaction group id
     */
    public static Transaction createTransaction(String gid) {
        Transaction transaction = new Transaction();
        String tid = UUID.randomUUID().toString();
        transaction.setId(tid);
        transaction.setGid(gid);
        // register it with the local transaction manager
        transactionThreadLocal.set(transaction);
        groupMap.computeIfAbsent(gid, k -> new ConcurrentHashMap<>()).put(tid, transaction);
        // store the local transaction number
        saveCurrentTransactionNum();
        return transaction;
    }

    /**
     * Store the local transaction number:
     * the previous service's number plus 1
     */
    private static void saveCurrentTransactionNum() {
        Integer num = currentTransactionNum.get() == null ? 0 : currentTransactionNum.get() + 1;
        currentTransactionNum.set(num);
    }

    /**
     * Get the local transaction number
     */
    public static Integer getCurrentTransactionNum() {
        return currentTransactionNum.get();
    }

    /**
     * Set the local transaction number
     */
    public static void setCurrentTransactionNum(Integer num) {
        currentTransactionNum.set(num);
    }

    /**
     * Set the transaction group maintained by the current thread
     * @param gid the transaction group id
     */
    public static void setCurrentGroup(String gid) {
        groupThreadLocal.set(gid);
    }

    /**
     * Get the transaction group maintained by the current thread
     * @return the transaction group id
     */
    public static String getCurrentGroup() {
        return groupThreadLocal.get();
    }

    /**
     * Get the transaction maintained by the current thread
     */
    public static Transaction getCurrentTransaction() {
        return transactionThreadLocal.get();
    }

    /**
     * Report the local transaction's outcome to the Netty server
     */
    public static void commitTransaction(Transaction transaction, boolean end, TransactionType type) {
        JSONObject jsonObject = new JSONObject();
        jsonObject.put("groupId", transaction.getGid());
        jsonObject.put("transactionId", transaction.getId());
        jsonObject.put("command", "add");
        jsonObject.put("end", end);
        jsonObject.put("transactionType", type);
        jsonObject.put("transactionNum", getCurrentTransactionNum());
        client.send(jsonObject);
    }
}
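Request threads write `groupMap`, and the Netty event-loop thread reads it when a notice arrives, so it needs a thread-safe map. Nothing ever removes finished transactions from it. The sketch below is a hypothetical registry (`TxRegistrySketch`, `register`, `take`). It uses `ConcurrentHashMap.compute` and `computeIfPresent`, which run atomically per key, so registration cannot race with cleanup, and an empty group is dropped as soon as its last transaction has been handled.

```java
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

// Hypothetical registry mirroring TransactionManager.groupMap: groupId -> (transactionId -> transaction).
class TxRegistrySketch {
    static final Map<String, Map<String, Object>> groupMap = new ConcurrentHashMap<>();

    // compute() is atomic per key, so concurrent registrations cannot lose an inner map
    static void register(String gid, String tid, Object tx) {
        groupMap.compute(gid, (k, group) -> {
            Map<String, Object> g = group == null ? new ConcurrentHashMap<>() : group;
            g.put(tid, tx);
            return g;
        });
    }

    // Removes a transaction once its notice is handled; returning null from the function drops an empty group.
    static Object take(String gid, String tid) {
        Object[] taken = new Object[1];
        groupMap.computeIfPresent(gid, (k, group) -> {
            taken[0] = group.remove(tid);
            return group.isEmpty() ? null : group;
        });
        return taken[0];
    }
}
```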
The local transaction:
public class Transaction {
    // transaction id
    private String id;
    // transaction group id
    private String gid;
    // decision (COMMIT or ROLLBACK), set from the coordinator's notice
    private volatile TransactionType type;
    // lock and condition the committing thread waits on
    private final Lock lock = new ReentrantLock();
    private final Condition condition = lock.newCondition();
    // set once the notice has arrived, so a signal that comes before await() is not lost
    // and spurious wakeups are ignored
    private boolean notified = false;

    public String getId() {
        return id;
    }

    public void setId(String id) {
        this.id = id;
    }

    public String getGid() {
        return gid;
    }

    public void setGid(String gid) {
        this.gid = gid;
    }

    public TransactionType getType() {
        return type;
    }

    public void setType(TransactionType type) {
        this.type = type;
    }

    public Condition getCondition() {
        return condition;
    }

    // wait (the current thread is parked in this condition's wait queue) until the notice arrives
    public void await() {
        lock.lock();
        try {
            while (!notified) {
                condition.await();
            }
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
        } finally {
            lock.unlock();
        }
    }

    // wake up the waiting thread
    public void signal() {
        lock.lock();
        try {
            notified = true;
            condition.signal();
        } finally {
            lock.unlock();
        }
    }
}
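A bare `condition.await()` without a guard flag loses any `signal()` that arrives before the waiter starts waiting, and it never times out. That ordering is possible here. The commit thread has only been started, not necessarily parked, when the vote goes out, and the coordinator may reply quickly. A `CountDownLatch` avoids the lost-signal problem by construction because it remembers that it has been released, and its timed `await` bounds the wait. Below is a sketch that assumes a timeout fallback is acceptable. `DecisionLatch` is a hypothetical name.

```java
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;

// Hypothetical one-shot handoff of the coordinator's decision.
class DecisionLatch {
    private final CountDownLatch latch = new CountDownLatch(1);
    private volatile String decision;

    // Called by the Netty handler when the notice arrives.
    void signal(String decision) {
        this.decision = decision;
        latch.countDown();
    }

    // Called by the committing thread; returns null on timeout instead of waiting forever.
    String await(long timeout, TimeUnit unit) {
        try {
            return latch.await(timeout, unit) ? decision : null;
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            return null;
        }
    }
}
```

A null result would then have to be treated as "roll back", which is a design choice the original framework does not make.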
Transaction type:
public enum TransactionType {
    COMMIT, ROLLBACK
}
The local Netty client:
@Component
public class NettyClient implements InitializingBean {
    public NettyClientHandler client = null;

    @Override
    public void afterPropertiesSet() throws Exception {
        // connect to the coordinator (the Netty server) when the bean is initialized
        start("localhost", 8080);
    }

    public void start(String hostName, Integer port) {
        client = new NettyClientHandler();
        Bootstrap b = new Bootstrap();
        EventLoopGroup group = new NioEventLoopGroup();
        b.group(group)
                .channel(NioSocketChannel.class)
                .option(ChannelOption.TCP_NODELAY, true)
                .handler(new ChannelInitializer<SocketChannel>() {
                    @Override
                    protected void initChannel(SocketChannel socketChannel) throws Exception {
                        ChannelPipeline pipeline = socketChannel.pipeline();
                        pipeline.addLast("decoder", new StringDecoder());
                        pipeline.addLast("encoder", new StringEncoder());
                        pipeline.addLast("handler", client);
                    }
                });
        try {
            b.connect(hostName, port).sync();
        } catch (InterruptedException e) {
            e.printStackTrace();
        }
    }

    public void send(JSONObject jsonObject) {
        try {
            client.call(jsonObject);
        } catch (Exception e) {
            e.printStackTrace();
        }
    }
}
public class NettyClientHandler extends ChannelInboundHandlerAdapter {
    private ChannelHandlerContext context;

    @Override
    public void channelActive(ChannelHandlerContext ctx) throws Exception {
        context = ctx;
    }

    /**
     * Receive a notice from the server (the coordinator)
     */
    @Override
    public synchronized void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
        JSONObject jo = JSON.parseObject((String) msg);
        String groupId = jo.getString("groupId");
        String noticeCommand = jo.getString("noticeCommand");
        String transactionId = jo.getString("transactionId");
        System.out.println("client receive command:" + noticeCommand);
        Transaction transaction = TransactionManager.groupMap.get(groupId).get(transactionId);
        // record the group's decision, then wake up the thread waiting in SxmConnection
        if ("commit".equals(noticeCommand)) {
            transaction.setType(TransactionType.COMMIT);
        } else {
            transaction.setType(TransactionType.ROLLBACK);
        }
        transaction.signal();
    }

    public synchronized Object call(JSONObject data) throws Exception {
        // send the message to the coordinator; writeAndFlush is asynchronous
        context.writeAndFlush(data.toJSONString());
        return null;
    }
}
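The client pipeline above uses only `StringDecoder` and `StringEncoder`, and the server's bootstrap is not shown. Neither codec marks where one message ends. TCP is a byte stream, so a "create" and an "add" written back to back can arrive in a single read, or one message can be split across two reads, and `JSON.parseObject` would then fail. A common remedy is to end every JSON message with `"\n"` and put Netty's `LineBasedFrameDecoder` ahead of `StringDecoder` on both sides. The plain-Java sketch below shows only the framing idea. `LineFramer` is a hypothetical name.

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical newline framing: chunks arrive arbitrarily, only complete lines become messages.
class LineFramer {
    private final StringBuilder buffer = new StringBuilder();

    List<String> feed(String chunk) {
        buffer.append(chunk);
        List<String> frames = new ArrayList<>();
        int nl;
        while ((nl = buffer.indexOf("\n")) >= 0) {
            frames.add(buffer.substring(0, nl)); // one complete message, without the delimiter
            buffer.delete(0, nl + 1);            // keep any partial message for the next chunk
        }
        return frames;
    }
}
```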
HttpUtil:
@Component
public class HttpUtil {
    private static RestTemplate restTemplate = new RestTemplate();

    public static Object post(String url) {
        HttpHeaders header = new HttpHeaders();
        header.set("groupId", TransactionManager.getCurrentGroup());
        header.set("transactionNum", String.valueOf(TransactionManager.getCurrentTransactionNum()));
        HttpEntity<MultiValueMap<String, String>> httpEntity = new HttpEntity<>(null, header);
        return restTemplate.postForObject(url, httpEntity, Object.class);
    }
}
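This is the propagation half of the protocol. The caller's `groupId` and `transactionNum` travel as HTTP headers, and the callee's interceptor reads them back. For illustration outside Spring, the same headers can be attached with the JDK 11+ `java.net.http` client. `PropagatingRequests` is a hypothetical name, and this variant simply omits a header whose value is absent.

```java
import java.net.URI;
import java.net.http.HttpRequest;

// Hypothetical JDK-only counterpart of HttpUtil.post's header propagation.
class PropagatingRequests {
    static HttpRequest build(String url, String groupId, Integer transactionNum) {
        HttpRequest.Builder b = HttpRequest.newBuilder(URI.create(url))
                .POST(HttpRequest.BodyPublishers.noBody());
        // only attach the headers that are actually set
        if (groupId != null) b.header("groupId", groupId);
        if (transactionNum != null) b.header("transactionNum", String.valueOf(transactionNum));
        return b.build();
    }
}
```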
Request interceptor:
@Configuration
public class WebAppConfig extends WebMvcConfigurerAdapter {
    @Override
    public void addInterceptors(InterceptorRegistry registry) {
        registry.addInterceptor(new RequestInterceptor());
    }
}
public class RequestInterceptor implements HandlerInterceptor {
    @Override
    public boolean preHandle(HttpServletRequest request, HttpServletResponse response, Object handler) throws Exception {
        // pick up what the previous service passed along
        String groupId = request.getHeader("groupId");
        String transactionNum = request.getHeader("transactionNum");
        TransactionManager.setCurrentGroup(groupId);
        TransactionManager.setCurrentTransactionNum(Integer.valueOf(transactionNum == null ? "0" : transactionNum));
        return true;
    }

    @Override
    public void postHandle(HttpServletRequest request, HttpServletResponse response, Object handler, ModelAndView modelAndView) throws Exception {
    }

    @Override
    public void afterCompletion(HttpServletRequest request, HttpServletResponse response, Object handler, Exception ex) throws Exception {
        // clear the ThreadLocals so a pooled request thread does not carry this request's
        // group or transaction into the next request
        TransactionManager.groupThreadLocal.remove();
        TransactionManager.transactionThreadLocal.remove();
        TransactionManager.currentTransactionNum.remove();
    }
}
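The interceptor also explains the numbering the coordinator relies on. `preHandle` sets the number to 0 even on service A, whose request carries no header, so A's transaction becomes 1 and B's becomes 2. The `end` participant's number therefore equals the total number of transactions in the group. This holds for one linear call chain in which only the last service is marked `end`. If A called two downstream services in turn, both would receive the same number. The JDK-only simulation below mirrors `preHandle` and `saveCurrentTransactionNum`. `NumberingSketch` is a hypothetical name, and both "services" run on one thread for simplicity.

```java
// Hypothetical simulation of transactionNum across A -> B.
class NumberingSketch {
    static final ThreadLocal<Integer> currentTransactionNum = new ThreadLocal<>();

    // RequestInterceptor.preHandle: the caller's number, or 0 without a header
    static void preHandle(String header) {
        currentTransactionNum.set(Integer.valueOf(header == null ? "0" : header));
    }

    // TransactionManager.saveCurrentTransactionNum: previous number + 1
    static int createTransaction() {
        Integer prev = currentTransactionNum.get();
        int num = prev == null ? 0 : prev + 1;
        currentTransactionNum.set(num);
        return num;
    }

    static int[] simulate() {
        preHandle(null);                   // request into A has no headers
        int a = createTransaction();       // A: 0 + 1 = 1
        preHandle(String.valueOf(a));      // HttpUtil.post sends transactionNum = 1 to B
        int b = createTransaction();       // B: 1 + 1 = 2, reported with end = true
        currentTransactionNum.remove();
        return new int[] {a, b};
    }
}
```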
The server-side transaction manager (the coordinator):
/**
 * As the distributed transaction coordinator, this handler:
 * 1. creates and stores transaction groups
 * 2. stores each branch transaction under its group
 * 3. collects the status of every branch in a group to work out the group's outcome (commit or rollback)
 * 4. notifies every branch to commit or roll back
 */
public class NettyServerHandler extends ChannelInboundHandlerAdapter {
    //private static ChannelGroup channelGroup = new DefaultChannelGroup(GlobalEventExecutor.INSTANCE);
    // channels of every transaction in a group: groupId -> (transactionId -> channel)
    private static Map<String, Map<String, Channel>> channelGroupMap = new ConcurrentHashMap<>();
    // all transactions reported in each group
    private static Map<String, List<JSONObject>> groupTransactions = new ConcurrentHashMap<>();
    // reported status (COMMIT/ROLLBACK) of each transaction in a group
    private static Map<String, List<String>> groupStatus = new ConcurrentHashMap<>();
    // whether each group's end transaction has reported
    private static Map<String, Boolean> endGroupMap = new ConcurrentHashMap<>();
    // number of transactions each group is expected to contain
    private static Map<String, Integer> countTransGroupMap = new ConcurrentHashMap<>();

    @Override
    public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
        System.out.println("received: " + msg);
        JSONObject jsonObject = JSON.parseObject((String) msg);
        String groupId = jsonObject.getString("groupId");
        String command = jsonObject.getString("command");
        System.out.println("groupId:" + groupId);
        System.out.println("command:" + command);
        // each service has its own channel, possibly on a different event-loop thread,
        // so updates to the shared maps are serialized
        synchronized (NettyServerHandler.class) {
            if ("create".equals(command)) {
                createGroup(groupId);
            } else if ("add".equals(command)) {
                boolean end = jsonObject.getBoolean("end");
                String transactionId = jsonObject.getString("transactionId");
                String transactionType = jsonObject.getString("transactionType");
                Integer transactionNum = jsonObject.getInteger("transactionNum");
                // remember the channel for this transaction
                addChannelMap(groupId, transactionId, ctx.channel());
                // store the transaction and its status in the group
                addGroupTransactions(groupId, jsonObject);
                addGroupStatus(groupId, transactionType);
                if (end) {
                    System.out.println("------end transaction reported------");
                    endGroupMap.put(groupId, Boolean.TRUE);
                    countTransGroupMap.put(groupId, transactionNum);
                }
                // Decide once the group's end transaction has reported and the number of reported
                // transactions equals the expected count. A non-end transaction may report first,
                // so null-safe comparisons are used instead of unboxing.
                if (Boolean.TRUE.equals(endGroupMap.get(groupId))
                        && Objects.equals(countTransGroupMap.get(groupId), groupTransactions.get(groupId).size())) {
                    // a single ROLLBACK rolls back the whole group
                    String noticeCommand = groupStatus.get(groupId).contains("ROLLBACK") ? "rollback" : "commit";
                    sendResult(groupId, noticeCommand);
                }
            }
        }
    }

    private void createGroup(String groupId) {
        groupTransactions.put(groupId, new LinkedList<>());
        groupStatus.put(groupId, new LinkedList<>());
    }

    private void addGroupTransactions(String groupId, JSONObject jsonObject) {
        if (groupTransactions.get(groupId) == null) {
            groupTransactions.put(groupId, new LinkedList<>());
        }
        groupTransactions.get(groupId).add(jsonObject);
    }

    private void addGroupStatus(String groupId, String transactionType) {
        if (groupStatus.get(groupId) == null) {
            groupStatus.put(groupId, new LinkedList<>());
        }
        groupStatus.get(groupId).add(transactionType);
    }

    private void addChannelMap(String groupId, String transactionId, Channel channel) {
        if (channelGroupMap.get(groupId) == null) {
            channelGroupMap.put(groupId, new HashMap<>());
        }
        channelGroupMap.get(groupId).put(transactionId, channel);
    }

    /**
     * Notify the local transactions.
     * Send failures and similar cases are not handled here.
     */
    private void sendResult(String groupId, String noticeCommand) {
        Map<String, Channel> channels = channelGroupMap.get(groupId);
        for (Map.Entry<String, Channel> entry : channels.entrySet()) {
            JSONObject jo = new JSONObject();
            jo.put("groupId", groupId);
            jo.put("noticeCommand", noticeCommand);
            jo.put("transactionId", entry.getKey());
            ChannelFuture channelFuture = entry.getValue().writeAndFlush(jo.toJSONString());
            System.out.println(channelFuture);
        }
        // release resources
        channelGroupMap.remove(groupId);
        groupTransactions.remove(groupId);
        groupStatus.remove(groupId);
        endGroupMap.remove(groupId);
        countTransGroupMap.remove(groupId);
    }
}
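The coordinator's rule can be tested without Netty. The sketch below keeps only each group's statuses and its expected count. It is `synchronized` because different services' channels may be served by different event-loop threads. It also shows that arrival order does not matter: a non-end report that arrives first simply waits for the end transaction. `CoordinatorSketch` is a hypothetical name.

```java
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

// Hypothetical, JDK-only model of the coordinator's bookkeeping.
class CoordinatorSketch {
    private final Map<String, List<String>> statuses = new HashMap<>();
    private final Map<String, Integer> expected = new HashMap<>();

    // Returns "commit" or "rollback" once the group is complete, otherwise null.
    synchronized String add(String groupId, String status, boolean end, int transactionNum) {
        List<String> list = statuses.computeIfAbsent(groupId, k -> new ArrayList<>());
        list.add(status);
        if (end) {
            expected.put(groupId, transactionNum); // the end participant's number = group size
        }
        Integer want = expected.get(groupId);
        if (want == null || want != list.size()) {
            return null; // end not seen yet, or still waiting for other branches
        }
        statuses.remove(groupId);
        expected.remove(groupId);
        return list.contains("ROLLBACK") ? "rollback" : "commit";
    }
}
```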
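That completes the distributed transaction framework: if an exception occurs in A or B, both databases roll back. It is not suitable for production, though, and many things remain to be improved. For example, coordinator failures, lost notifications and timeouts are not handled. The goal here is only to show how 2PC works.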