前几天学习了AQS源码为了加深印象今天来基于AQS自己实现一个锁
1.基于AQS实现不可重入的锁
之前我们学习了AQS的源码,了解到了自定义AQS需要实现重写一系列函数,还需要定义原子变量state的含义。
下文我们自己实现一个锁,定义state为0表示锁没有被线程持有,state为1表示锁已经被某一个线程持有,由于是不可重入锁,所以不需要记录持有锁的线程获取锁的次数,另外,我们自定义的锁支持条件变量,因为我们要实现生产者——消费者模型
class NonReentrantLock implements Lock, Serializable {
//实现AQS
private static class Sync extends AbstractQueuedSynchronizer {
//是否锁被占有
@Override
protected boolean isHeldExclusively() {
return getState() == 1;
}
//如果state为0,则尝试获取锁
@Override
protected boolean tryAcquire(int acquires) {
assert acquires == 1;
if (compareAndSetState(0, 1)) {
//CAS成功则将当前线程设置获取到锁
setExclusiveOwnerThread(Thread.currentThread());
return true;
}
return false;
}
//尝试释放锁,将state改为1
@Override
protected boolean tryRelease(int releases) {
assert releases == 1;
if (getState() == 0)
throw new IllegalMonitorStateException();
setExclusiveOwnerThread(null);
setState(0);
return true;
}
//提供条件变量接口
Condition newCondition() {
return new ConditionObject();
}
}
//创建一个Sync来做具体的工作
private final Sync sync = new Sync();
@Override
public void lock() {
sync.acquire(1);
}
@Override
public boolean tryLock() {
return sync.tryAcquire(1);
}
@Override
public void unlock() {
sync.release(1);
}
@Override
public Condition newCondition() {
return sync.newCondition();
}
public boolean isLocked() {
return sync.isHeldExclusively();
}
@Override
public void lockInterruptibly() throws InterruptedException {
sync.acquireInterruptibly(1);
}
@Override
public boolean tryLock(long time, TimeUnit unit) throws InterruptedException {
return sync.tryAcquireNanos(1, unit.toNanos(time));
}
}
在如上代码中,NonReentrantLock定义了一个内部类Sync用来实现具体的锁的操作,Sync则继承了AQS,由于我们实现的是独占模式的锁,所以Sync重写了tryAcquire、tryRelease和isHeldExclusively三个方法,另外Sync提供了newCondition这个方法用来支持条件变量
2.使用自定义锁实现生产者—消费者模型
public class AQSDemo {
final static NonReentrantLock lock = new NonReentrantLock();
final static Condition notFull = lock.newCondition();
final static Condition notEmpty = lock.newCondition();
final static Queue<String> queue = new LinkedBlockingQueue<>();
final static int queueSize = 10;
public static void main(String[] args) {
Thread producer = new Thread(() -> {
lock.lock();
try {
//如果队列满了则等待
while (queue.size() == queueSize) {
notEmpty.await();
}
//添加元素到队列
queue.add("ele");
//唤醒消费者线程
notFull.signalAll();
} catch (Exception e) {
e.printStackTrace();
} finally {
lock.unlock();
}
});
Thread consumer = new Thread(() -> {
lock.lock();
try {
//如果队列满了则等待
while (queue.size() == 0) {
notFull.await();
}
//消费队列
queue.poll();
//唤醒生产线程
notEmpty.signalAll();
} catch (Exception e) {
e.printStackTrace();
} finally {
lock.unlock();
}
});
producer.start();
consumer.start();
}
}
如上代码首先创建了一个NonReentrantLock的一个对象Lock,然后调用lock.newCondition创建了两个条件变量,用来进行生产者和消费者线程之间的同步。
在main函数里面,首先创建了生产者线程,在线程内部先调用lock.lock()获取独占锁,然后判断当前队列是否已经满了,如果满了掉用notEmpty.await()阻塞挂起当前线程。需要注意的是,这里使用while而不是if是为了避免虚假唤醒,如果队列不满则直接向队列里面添加元素,然后调用notFull.signalAll()唤醒所有因为消费元素而被阻塞的消费线程,最后释放获取的锁。