MongoTemplate很好用,但是基于xml注册为Bean时只能绑定在一个database上。
遇到需要同时支持多个database,或需要动态切换database的项目时,就很不方便了。
解决的思路是把MongoTemplate按database名缓存在Map中;由于MongoTemplate底层使用的MongoClient自带连接池,所以不用再自己管理连接池。
把管理容器的类声明为Spring的组件,这样一来就可以通过@Value引入properties文件中的属性
使用ThreadLocal保存每个线程当前使用的MongoTemplate,避免多线程并发调用时互相干扰、导致结果不一致。
import com.mongodb.*; import org.springframework.beans.factory.annotation.Value; import org.springframework.data.mongodb.MongoDbFactory; import org.springframework.data.mongodb.core.MongoTemplate; import org.springframework.data.mongodb.core.SimpleMongoDbFactory; import org.springframework.data.mongodb.core.convert.DefaultDbRefResolver; import org.springframework.data.mongodb.core.convert.DefaultMongoTypeMapper; import org.springframework.data.mongodb.core.convert.MappingMongoConverter; import org.springframework.data.mongodb.core.mapping.MongoMappingContext; import org.springframework.data.mongodb.core.query.Query; import org.springframework.data.mongodb.core.query.Update; import org.springframework.stereotype.Repository; import java.net.UnknownHostException; import java.util.Collections; import java.util.List; import java.util.Map; import java.util.Optional; import java.util.concurrent.ConcurrentHashMap; /** * @author ParanoidCAT * @since JDK 1.8 */ @Repository(value = "mongoRepository") public class MongoRepository { @Value("${mongo.host}") private String host; @Value("${mongo.port}") private Integer port; @Value("${mongo.username}") private String username; @Value("${mongo.password}") private String password; @Value("${mongo.database}") private String database; @Value("${mongo.connectionsPerHost}") private Integer connectionsPerHost; @Value("${mongo.threadsAllowedToBlockForConnectionMultiplier}") private Integer threadsAllowedToBlockForConnectionMultiplier; @Value("${mongo.connectTimeout}") private Integer connectTimeout; @Value("${mongo.maxWaitTime}") private Integer maxWaitTime; @Value("${mongo.socketTimeout}") private Integer socketTimeout; @Value("${mongo.socketKeepAlive}") private Boolean socketKeepAlive; private ThreadLocal<MongoTemplate> threadLocal = new ThreadLocal<>(); private static final Map<String, MongoTemplate> MONGO_TEMPLATE_CACHE = new ConcurrentHashMap<>(16); private void changeDatabase(String databaseName) { if 
(Optional.ofNullable(threadLocal.get()).map(MongoTemplate::getDb).map(DB::getName).orElse(database).equals(databaseName)) { return; } if (MONGO_TEMPLATE_CACHE.containsKey(databaseName)) { threadLocal.remove(); threadLocal.set(MONGO_TEMPLATE_CACHE.get(databaseName)); return; } synchronized (MONGO_TEMPLATE_CACHE) { if (MONGO_TEMPLATE_CACHE.containsKey(databaseName)) { changeDatabase(databaseName); } else { threadLocal.remove(); try { threadLocal.set(createMongoTemplate(databaseName)); MONGO_TEMPLATE_CACHE.putIfAbsent(databaseName, threadLocal.get()); } catch (Exception e) { // TODO 输出日志 System.out.println(e.toString()); } } } } private MongoTemplate createMongoTemplate(String databaseName) throws UnknownHostException { MongoClient mongoClient = new MongoClient( Collections.singletonList(new ServerAddress(host, port)), Collections.singletonList(MongoCredential.createCredential(username, database, password.toCharArray())), new MongoClientOptions .Builder() .connectionsPerHost(connectionsPerHost) .threadsAllowedToBlockForConnectionMultiplier(threadsAllowedToBlockForConnectionMultiplier) .connectTimeout(connectTimeout) .maxWaitTime(maxWaitTime) .socketTimeout(socketTimeout) .socketKeepAlive(socketKeepAlive) .cursorFinalizerEnabled(true) .build() ); MongoDbFactory mongoDbFactory = new SimpleMongoDbFactory(mongoClient, databaseName); MappingMongoConverter mappingMongoConverter = new MappingMongoConverter(new DefaultDbRefResolver(mongoDbFactory), new MongoMappingContext()); mappingMongoConverter.setTypeMapper(new DefaultMongoTypeMapper(null)); return new MongoTemplate(mongoDbFactory, mappingMongoConverter); } /** * 插入一条记录 * * @param databaseName 数据库名 * @param t 实例 * @param <T> 实例所属的类 */ public <T> void insert(String databaseName, T t) { changeDatabase(databaseName); threadLocal.get().insert(t); } /** * 插入一条记录 * * @param databaseName 数据库名 * @param collectionName 集合名 * @param t 实例 * @param <T> 实例所属的类 */ public <T> void insert(String databaseName, String collectionName, T t) { 
changeDatabase(databaseName); threadLocal.get().insert(t, collectionName); } /** * 插入多条记录 * * @param databaseName 数据库名 * @param tClass 实例的class * @param tList 实例 * @param <T> 实例所属的类 */ public <T> void insertAll(String databaseName, Class<T> tClass, List<T> tList) { changeDatabase(databaseName); threadLocal.get().insert(tList, tClass); } /** * 插入多条记录 * * @param databaseName 数据库名 * @param collectionName 集合名 * @param tList 实例 * @param <T> 实例所属的类 */ public <T> void insertAll(String databaseName, String collectionName, List<T> tList) { changeDatabase(databaseName); threadLocal.get().insert(tList, collectionName); } /** * 移除一条或多条记录 * * @param databaseName 数据库名 * @param tClass 实例的class * @param query 查询条件 * @param <T> 实例所属的类 * @return */ public <T> long remove(String databaseName, Class<T> tClass, Query query) { changeDatabase(databaseName); return threadLocal.get().remove(query, tClass).getN(); } /** * 移除一条或多条记录 * * @param databaseName 数据库名 * @param collectionName 集合名 * @param query 查询条件 * @return 受影响的记录条数 */ public long remove(String databaseName, String collectionName, Query query) { changeDatabase(databaseName); return threadLocal.get().remove(query, collectionName).getN(); } /** * 更新多条记录 * * @param databaseName 数据库名 * @param tClass 实例的class * @param query 查询条件 * @param update 更新内容 * @param <T> 实例所属的类 * @return 受影响的记录条数 */ public <T> long updateMulti(String databaseName, Class<T> tClass, Query query, Update update) { changeDatabase(databaseName); return threadLocal.get().updateMulti(query, update, tClass).getN(); } /** * 更新多条记录 * * @param databaseName 数据库名 * @param collectionName 集合名 * @param query 查询条件 * @param update 更新内容 * @return 受影响的记录条数 */ public long updateMulti(String databaseName, String collectionName, Query query, Update update) { changeDatabase(databaseName); return threadLocal.get().updateMulti(query, update, collectionName).getN(); } /** * 查询多条记录 * * @param databaseName 数据库名 * @param tClass 实例的class * @param query 查询条件 * @param <T> 实例所属的类 * @return 实例 */ 
public <T> List<T> find(String databaseName, Class<T> tClass, Query query) { changeDatabase(databaseName); return threadLocal.get().find(query, tClass); } /** * 查询多条记录 * * @param databaseName 数据库名 * @param collectionName 集合名 * @param query 查询条件 * @param tClass 实例的class * @param <T> 实例所属的类 * @return 实例 */ public <T> List<T> find(String databaseName, String collectionName, Query query, Class<T> tClass) { changeDatabase(databaseName); return threadLocal.get().find(query, tClass, collectionName); } /** * 查询第一条记录 * * @param databaseName 数据库名 * @param tClass 实例的class * @param query 查询条件 * @param <T> 实例所属的类 * @return 实例 */ public <T> T findOne(String databaseName, Class<T> tClass, Query query) { changeDatabase(databaseName); return threadLocal.get().findOne(query, tClass); } /** * 查询第一条记录 * * @param databaseName 数据库名 * @param tClass 实例的class * @param collectionName 集合名 * @param query 查询条件 * @param <T> 实例所属的类 * @return 实例 */ public <T> T findOne(String databaseName, Class<T> tClass, String collectionName, Query query) { changeDatabase(databaseName); return threadLocal.get().findOne(query, tClass, collectionName); } }
测试类:
让40个请求线程同时并发执行查询,测试多线程调用是否安全(需要同时向 /test/run/{databaseName} 发起40个请求,凑齐40个后才会一起放行)。
这里我在五个database中各创建了一个名为"test"的collection,每个collection里放了一个{"name":"数据库名"}的Document,用来校验查到的数据是否来自请求的那个database。
import com.mdruby.repository.MongoRepository; import com.mongodb.BasicDBObject; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.data.mongodb.core.query.Query; import org.springframework.stereotype.Controller; import org.springframework.web.bind.annotation.PathVariable; import org.springframework.web.bind.annotation.RequestMapping; import org.springframework.web.bind.annotation.RequestMethod; import org.springframework.web.bind.annotation.ResponseBody; import java.util.Optional; import java.util.concurrent.CountDownLatch; /** * @author ParanoidCAT * @since JDK 1.8 */ @Controller(value = "testController") @RequestMapping(value = {"test"}) public class TestController { @Autowired private MongoRepository mongoRepository; private static final CountDownLatch COUNT_DOWN_LATCH = new CountDownLatch(40); @RequestMapping(value = {"/run/{databaseName}"}, method = {RequestMethod.GET}, produces = {"application/json;charset=utf-8"}) @ResponseBody public void run(@PathVariable String databaseName) throws InterruptedException { // 线程计数+1 COUNT_DOWN_LATCH.countDown(); // 线程没到40个就等等 COUNT_DOWN_LATCH.await(); // 线程如果到了40个就一起放行,每个线程执行150次query for (int i = 0; i < 150; i++) { BasicDBObject basicDBObject = mongoRepository.findOne(databaseName, BasicDBObject.class, "test", new Query()); if (!databaseName.equals(Optional.ofNullable(basicDBObject).map(basicDBObject1 -> basicDBObject1.getString("name")).orElse("testString"))) { System.out.println(Thread.currentThread().getName() + ": " + databaseName + " - " + basicDBObject); } } } }