[片段] @CreatedBy / @ModifiedBy 拦截器实现

拦截器实现:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
package app.pooi.common.entity;

import app.pooi.common.entity.anno.CreatedBy;
import app.pooi.common.entity.anno.ModifiedBy;
import lombok.Data;
import org.apache.ibatis.executor.Executor;
import org.apache.ibatis.mapping.MappedStatement;
import org.apache.ibatis.mapping.SqlCommandType;
import org.apache.ibatis.plugin.Intercepts;
import org.apache.ibatis.plugin.Invocation;
import org.apache.ibatis.plugin.Plugin;
import org.apache.ibatis.plugin.Signature;

import java.lang.reflect.Field;
import java.lang.reflect.Modifier;
import java.util.Arrays;
import java.util.Properties;
import java.util.function.Supplier;

@Data
@Intercepts({
        @Signature(type = Executor.class, method = "update", args = {MappedStatement.class, Object.class}),
})
public class EntityInterceptor implements org.apache.ibatis.plugin.Interceptor {

    /**
     * Supplies the id of the current user that is written into audit fields. It is invoked at most
     * once per intercepted statement, and only if the parameter object has a field to fill.
     */
    private Supplier<Long> auditorAware;

    /**
     * Fills audit fields on the parameter object of an {@code Executor.update} call, then lets the
     * statement run.
     *
     * <ul>
     *   <li>{@link ModifiedBy} fields are set for every intercepted statement;</li>
     *   <li>{@link CreatedBy} fields are set only for INSERT statements, so an UPDATE does not
     *       overwrite the recorded creator.</li>
     * </ul>
     * Fields declared on superclasses of the parameter object are included; static fields are skipped.
     *
     * @param invocation the intercepted {@code update(MappedStatement, Object)} call
     * @return the result of the underlying update
     * @throws IllegalStateException if an annotated field cannot be written reflectively
     */
    @Override
    public Object intercept(Invocation invocation) throws Throwable {
        final MappedStatement ms = (MappedStatement) invocation.getArgs()[0];
        final Object parameter = invocation.getArgs()[1];

        // Statements without a parameter object receive null here; there is nothing to fill.
        if (parameter != null) {
            fillAuditFields(parameter, ms.getSqlCommandType() == SqlCommandType.INSERT);
        }

        return invocation.proceed();
    }

    /**
     * Writes the current auditor into every applicable audit field of {@code target}, walking the
     * class hierarchy up to (but excluding) {@code Object}.
     */
    private void fillAuditFields(Object target, boolean insert) {
        boolean auditorResolved = false;
        Long auditor = null;

        for (Class<?> type = target.getClass(); type != null && type != Object.class; type = type.getSuperclass()) {
            for (Field field : type.getDeclaredFields()) {
                if (Modifier.isStatic(field.getModifiers()) || !isAuditField(field, insert)) {
                    continue;
                }
                if (!auditorResolved) {
                    // Resolve once and only when needed: the supplier may depend on request state.
                    auditor = auditorAware.get();
                    auditorResolved = true;
                }
                field.setAccessible(true);
                try {
                    field.set(target, auditor);
                } catch (IllegalAccessException e) {
                    // Fail loudly instead of silently persisting the row without audit data.
                    throw new IllegalStateException("Cannot set audit field " + field, e);
                }
            }
        }
    }

    /** True for {@link ModifiedBy} fields, and for {@link CreatedBy} fields on INSERT statements. */
    private static boolean isAuditField(Field field, boolean insert) {
        return field.isAnnotationPresent(ModifiedBy.class)
                || (insert && field.isAnnotationPresent(CreatedBy.class));
    }

    @Override
    public Object plugin(Object target) {
        return Plugin.wrap(target, this);
    }

    /** This interceptor has no configurable properties. */
    @Override
    public void setProperties(Properties properties) {

    }
}

配置:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
@Configuration
static class MybatisInterceptorConfig {

    /**
     * Registers the MyBatis interceptors: a {@code DecryptInterceptor} backed by {@code cipherSpi}, and
     * an {@code EntityInterceptor} whose auditor id is read from the current HTTP request.
     *
     * @param cipherSpi cipher used by the decrypt interceptor
     * @return the interceptors, in registration order
     */
    @Bean
    public Interceptor[] configurationCustomizer(CipherSpi cipherSpi) {
        final EntityInterceptor entityInterceptor = new EntityInterceptor();
        entityInterceptor.setAuditorAware(MybatisInterceptorConfig::currentLoginUserId);
        return new Interceptor[]{new DecryptInterceptor(cipherSpi), entityInterceptor};
    }

    /**
     * Reads the login user id from the {@code XHeaders.LOGIN_USER_ID} header of the servlet request
     * bound to the current thread.
     *
     * @throws IllegalStateException if no servlet request is bound (e.g. a scheduled job) or the
     *                               header is absent; previously these surfaced as bare NPE/CCE/NFE
     * @throws NumberFormatException if the header is present but not a valid long
     */
    private static Long currentLoginUserId() {
        final Object attributes = RequestContextHolder.getRequestAttributes();
        if (!(attributes instanceof ServletRequestAttributes)) {
            throw new IllegalStateException(
                    "No servlet request bound to the current thread; cannot resolve the auditing user id");
        }
        final String header = ((ServletRequestAttributes) attributes).getRequest().getHeader(XHeaders.LOGIN_USER_ID);
        if (header == null) {
            throw new IllegalStateException("Missing login user id header: " + XHeaders.LOGIN_USER_ID);
        }
        return Long.valueOf(header);
    }
}

[片段] Java收集方法参数+Spring DataBinder

收集参数

目前是使用 Spring AOP 拦截方法调用,把方法参数按参数名包装成 Map。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
/**
 * Marks a method whose invocation arguments should be captured by name, so that code running during
 * the call can read them (see the {@code ArgumentsCollector} aspect, whose pointcut matches this
 * annotation).
 */
@Documented
@Retention(RetentionPolicy.RUNTIME)
@Target(ElementType.METHOD)
public @interface CollectArguments {
}

@Aspect
public class ArgumentsCollector {

    /**
     * Arguments captured for the {@code @CollectArguments} method currently running on this thread,
     * keyed by parameter name; an empty immutable map when none is running.
     */
    private static final ThreadLocal<Map<String, Object>> ARGUMENTS = ThreadLocal.withInitial(ImmutableMap::of);

    /**
     * Returns the captured arguments of the {@code @CollectArguments} method currently running on this
     * thread (unmodifiable; values may be null), or an empty map if none is running.
     *
     * <p>NOTE: captures are not stacked. If one annotated method calls another on the same thread,
     * the inner call replaces the outer capture and clears it when it finishes.
     */
    static Map<String, Object> getArgs() {
        return ARGUMENTS.get();
    }

    /**
     * Returns {@code args} unchanged if it already has {@code expectedLength} elements, otherwise a
     * copy truncated (or null-padded) to that length.
     */
    private Object[] args(Object[] args, int expectedLength) {
        return expectedLength == args.length ? args : Arrays.copyOf(args, expectedLength);
    }

    /** Matches any method annotated with {@link CollectArguments}. */
    @Pointcut("@annotation(CollectArguments)")
    void collectArgumentsAnnotationPointCut() {
    }

    /**
     * Before an annotated method runs, stores all of its arguments, keyed by parameter name, so they
     * can be read through {@link #getArgs()}.
     */
    @Before("collectArgumentsAnnotationPointCut()")
    public void doAccessCheck(JoinPoint joinPoint) {
        final String[] parameterNames = ((MethodSignature) joinPoint.getSignature()).getParameterNames();
        final Object[] args = args(joinPoint.getArgs(), parameterNames.length);

        // A plain HashMap is used because argument values may be null.
        // Iterate over every parameter, including the last one.
        final Map<String, Object> collected = new HashMap<>();
        for (int i = 0; i < parameterNames.length; i++) {
            collected.put(parameterNames[i], args[i]);
        }
        ARGUMENTS.set(Collections.unmodifiableMap(collected));
    }

    /** After the annotated method completes, normally or by throwing, clears this thread's capture. */
    @After("collectArgumentsAnnotationPointCut()")
    public void remove() {
        ARGUMENTS.remove();
    }
}

通过Map构造对象

1
2
3
4
5
6
7
8
9
10
11
12
/**
 * Builds beans from the method arguments captured by {@code ArgumentsCollector}.
 */
public class BinderUtil {

    BinderUtil() {
    }

    /**
     * Instantiates {@code beanClazz} through its no-arg constructor and binds the current thread's
     * captured arguments ({@code ArgumentsCollector.getArgs()}) onto its properties by name.
     *
     * <p>{@code DataBinder.bind} records binding problems (such as type-conversion failures) in the
     * binder's BindingResult instead of throwing. That result is not inspected here, so affected
     * properties are silently left unset.
     *
     * @param beanClazz the bean type to create; must have an accessible no-arg constructor
     * @return the populated bean
     */
    public static <T> T getTarget(Class<T> beanClazz) {
        // instantiateClass replaces BeanUtils.instantiate, which is deprecated since Spring 5.0.
        final DataBinder binder = new DataBinder(BeanUtils.instantiateClass(beanClazz));
        binder.bind(new MutablePropertyValues(ArgumentsCollector.getArgs()));
        // Checked cast instead of an unchecked (T) cast.
        return beanClazz.cast(binder.getTarget());
    }
}

[片段] Mybatis ResultSetHandler实践

这次拦截的方法是handleResultSets(Statement stmt),用来批量解密用@Encrypted注解的String字段,可能还有一些坑。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
/**
 * Maps every JDBC result set produced by {@code stmt} into objects and returns them; this is the
 * method an interceptor on ResultSetHandler.handleResultSets wraps.
 *
 * NOTE(review): this appears to be MyBatis's own DefaultResultSetHandler implementation, quoted
 * for reference - confirm against the MyBatis version in use.
 *
 * @param stmt the executed statement whose result sets are read
 * @return the mapped rows; the exact shape depends on collapseSingleResultList (presumably the
 *         single row list when there is one result map, otherwise one list per result set)
 * @throws SQLException if reading a result set fails
 */
@Override
public List<Object> handleResultSets(Statement stmt) throws SQLException {
ErrorContext.instance().activity("handling results").object(mappedStatement.getId());

final List<Object> multipleResults = new ArrayList<Object>();

int resultSetCount = 0;
ResultSetWrapper rsw = getFirstResultSet(stmt);

List<ResultMap> resultMaps = mappedStatement.getResultMaps();
int resultMapCount = resultMaps.size();
validateResultMapsCount(rsw, resultMapCount);
// The i-th result set is mapped with the i-th configured result map; rows go into multipleResults.
while (rsw != null && resultMapCount > resultSetCount) {
ResultMap resultMap = resultMaps.get(resultSetCount);
handleResultSet(rsw, resultMap, multipleResults, null);
rsw = getNextResultSet(stmt);
cleanUpAfterHandlingResultSet();
resultSetCount++;
}

// Remaining result sets named by the statement's resultSets attribute are mapped for nested
// mappings: they pass null instead of multipleResults, so their rows are not added to the returned
// list but linked through parentMapping.
String[] resultSets = mappedStatement.getResultSets();
if (resultSets != null) {
while (rsw != null && resultSetCount < resultSets.length) {
ResultMapping parentMapping = nextResultMaps.get(resultSets[resultSetCount]);
if (parentMapping != null) {
String nestedResultMapId = parentMapping.getNestedResultMapId();
ResultMap resultMap = configuration.getResultMap(nestedResultMapId);
handleResultSet(rsw, resultMap, null, parentMapping);
}
rsw = getNextResultSet(stmt);
cleanUpAfterHandlingResultSet();
resultSetCount++;
}
}

return collapseSingleResultList(multipleResults);
}
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
package app.pooi.common.encrypt;


import app.pooi.common.encrypt.anno.CipherSpi;
import app.pooi.common.encrypt.anno.Encrypted;
import lombok.Getter;
import org.apache.ibatis.executor.resultset.ResultSetHandler;
import org.apache.ibatis.plugin.*;
import org.apache.ibatis.reflection.MetaObject;
import org.apache.ibatis.reflection.SystemMetaObject;

import java.lang.reflect.Field;
import java.sql.Statement;
import java.util.*;
import java.util.function.Function;
import java.util.logging.Logger;
import java.util.stream.Collectors;

@Intercepts({
        @Signature(type = ResultSetHandler.class, method = "handleResultSets", args = {Statement.class}),
})
public class EncryptInterceptor implements Interceptor {

    private static final Logger logger = Logger.getLogger(EncryptInterceptor.class.getName());

    /** Maximum number of ciphertexts passed to a single {@code cipherSpi.decrypt} call. */
    private static final int DECRYPT_BATCH_SIZE = 20;

    private CipherSpi cipherSpi;

    public EncryptInterceptor(CipherSpi cipherSpi) {
        this.cipherSpi = cipherSpi;
    }

    /**
     * Runs the query, then replaces every non-null {@link Encrypted} String field of the returned rows
     * with its decrypted value.
     *
     * <p>The set of fields comes from the class of the first non-null row, including inherited fields.
     * All distinct ciphertexts across all rows are decrypted in batches of at most
     * {@value #DECRYPT_BATCH_SIZE}, so the number of {@code cipherSpi.decrypt} calls depends on the
     * number of distinct values rather than on the number of rows. A ciphertext missing from the
     * decrypt result is replaced by {@code ""}. Null rows and null field values are left untouched.
     *
     * @param invocation the intercepted {@code handleResultSets(Statement)} call
     * @return the same result list, with decrypted fields
     */
    @Override
    public Object intercept(Invocation invocation) throws Throwable {
        final Object proceed = invocation.proceed();
        if (proceed == null) {
            return proceed;
        }

        final List<?> results = (List<?>) proceed;
        final Object first = firstNonNull(results);
        if (first == null) {
            // Empty result, or every row mapped to null.
            return proceed;
        }

        final List<String> fieldsNeedDecrypt = encryptedStringFields(first.getClass());
        if (fieldsNeedDecrypt.isEmpty()) {
            return proceed;
        }

        // Pass 1: snapshot each non-null row's ciphertexts and collect the distinct non-null ones.
        // Snapshotting keeps pass 2 idempotent even if the same row instance appears twice.
        final List<MetaObject> rows = new ArrayList<>(results.size());
        final List<String[]> rowCipherTexts = new ArrayList<>(results.size());
        final Set<String> cipherTexts = new LinkedHashSet<>();
        for (Object r : results) {
            if (r == null) {
                continue;
            }
            final MetaObject metaObject = SystemMetaObject.forObject(r);
            final String[] values = new String[fieldsNeedDecrypt.size()];
            for (int i = 0; i < values.length; i++) {
                values[i] = (String) metaObject.getValue(fieldsNeedDecrypt.get(i));
                // Null columns are skipped; Collectors.toMap would have thrown an NPE on them.
                if (values[i] != null) {
                    cipherTexts.add(values[i]);
                }
            }
            rows.add(metaObject);
            rowCipherTexts.add(values);
        }

        final Map<String, String> plainTexts = new HashMap<>();
        for (List<String> batch : partition(new ArrayList<>(cipherTexts), DECRYPT_BATCH_SIZE)) {
            plainTexts.putAll(cipherSpi.decrypt(new ArrayList<>(batch)));
        }

        // Pass 2: write decrypted values back.
        for (int r = 0; r < rows.size(); r++) {
            final MetaObject metaObject = rows.get(r);
            final String[] values = rowCipherTexts.get(r);
            for (int i = 0; i < values.length; i++) {
                if (values[i] != null) {
                    metaObject.setValue(fieldsNeedDecrypt.get(i), plainTexts.getOrDefault(values[i], ""));
                }
            }
        }

        return results;
    }

    /** Returns the first non-null element of {@code list}, or null if there is none. */
    private static Object firstNonNull(List<?> list) {
        for (Object element : list) {
            if (element != null) {
                return element;
            }
        }
        return null;
    }

    /**
     * Returns the distinct names of String fields annotated with {@link Encrypted} on {@code modelClazz}
     * and its superclasses. Annotated fields of any other type are logged and skipped.
     */
    private static List<String> encryptedStringFields(Class<?> modelClazz) {
        final Set<String> names = new LinkedHashSet<>();
        for (Class<?> c = modelClazz; c != null && c != Object.class; c = c.getSuperclass()) {
            for (Field f : c.getDeclaredFields()) {
                if (f.getAnnotation(Encrypted.class) == null) {
                    continue;
                }
                if (f.getType() != String.class) {
                    logger.warning(f.getName() + " is not String, actual type is " + f.getType().getSimpleName() + ", ignored");
                    continue;
                }
                names.add(f.getName());
            }
        }
        return new ArrayList<>(names);
    }

    /**
     * Splits {@code list} into consecutive chunks of at most {@code batchCount} elements.
     * The chunks are subList views of {@code list}.
     *
     * @throws IllegalArgumentException if {@code batchCount} is not positive
     */
    private <T> List<List<T>> partition(List<T> list, int batchCount) {
        if (batchCount <= 0) {
            throw new IllegalArgumentException("batch count must greater than zero");
        }

        final int size = list.size();
        final List<List<T>> partitionList = new ArrayList<>((size + batchCount - 1) / batchCount);
        for (int from = 0; from < size; from += batchCount) {
            // subList is O(1); the former skip/limit stream rescanned the list for every chunk.
            partitionList.add(list.subList(from, Math.min(from + batchCount, size)));
        }
        return partitionList;
    }

    @Override
    public Object plugin(Object target) {
        return Plugin.wrap(target, this);
    }

    /** This interceptor has no configurable properties. */
    @Override
    public void setProperties(Properties properties) {

    }
}

/**
 * Immutable pair of two values, read through {@code getT1()} and {@code getT2()}.
 *
 * @param <T1> type of the first value
 * @param <T2> type of the second value
 */
class Tuple2<T1, T2> {

    private final T1 t1;
    private final T2 t2;

    Tuple2(T1 first, T2 second) {
        this.t1 = first;
        this.t2 = second;
    }

    /** Factory equivalent to the constructor. */
    static <T1, T2> Tuple2<T1, T2> of(T1 first, T2 second) {
        return new Tuple2<>(first, second);
    }

    /** Returns the first value (may be null). */
    public T1 getT1() {
        return t1;
    }

    /** Returns the second value (may be null). */
    public T2 getT2() {
        return t2;
    }
}

AbstractQueuedSynchronizer解析

AbstractQueuedSynchronizer 数据结构

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
    /**
     * Head of the wait queue, lazily initialized. Except for
     * initialization, it is modified only via method setHead. Note:
     * If head exists, its waitStatus is guaranteed not to be
     * CANCELLED.
     */
    private transient volatile Node head; 

    /**
     * Tail of the wait queue, lazily initialized. Modified only via
     * method enq to add new wait node.
     *
     * NOTE(review): in the addWaiter listing quoted in this article, tail is
     * updated via compareAndSetTail / initializeSyncQueue rather than enq -
     * this comment likely comes from a different JDK version than that code.
     */
    private transient volatile Node tail; 

    /**
     * The synchronization state. Its meaning is defined by the subclass;
     * e.g. ReentrantLock's tryAcquire/tryRelease treat it as the hold count
     * (0 = unlocked).
     */
    private volatile int state;

稍微注意下:head/tail 都是懒初始化的,只有在线程争用锁时才会初始化链表。

AbstractQueuedSynchronizer.Node 数据结构

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
/**
 * Status field, taking on only the values:
 *   SIGNAL:     The successor of this node is (or will soon be)
 *               blocked (via park), so the current node must
 *               unpark its successor when it releases or
 *               cancels. To avoid races, acquire methods must
 *               first indicate they need a signal,
 *               then retry the atomic acquire, and then,
 *               on failure, block.
 *   CANCELLED:  This node is cancelled due to timeout or interrupt.
 *               Nodes never leave this state. In particular,
 *               a thread with cancelled node never again blocks.
 *   CONDITION:  This node is currently on a condition queue.
 *               It will not be used as a sync queue node
 *               until transferred, at which time the status
 *               will be set to 0. (Use of this value here has
 *               nothing to do with the other uses of the
 *               field, but simplifies mechanics.)
 *   PROPAGATE:  A releaseShared should be propagated to other
 *               nodes. This is set (for head node only) in
 *               doReleaseShared to ensure propagation
 *               continues, even if other operations have
 *               since intervened.
 *   0:          None of the above
 *
 * The values are arranged numerically to simplify use.
 * Non-negative values mean that a node doesn't need to
 * signal. So, most code doesn't need to check for particular
 * values, just for sign.
 *
 * The field is initialized to 0 for normal sync nodes, and
 * CONDITION for condition nodes. It is modified using CAS
 * (or when possible, unconditional volatile writes).
 */
volatile int waitStatus; 

/**
 * Link to predecessor node that current node/thread relies on
 * for checking waitStatus. Assigned during enqueuing, and nulled
 * out (for sake of GC) only upon dequeuing. Also, upon
 * cancellation of a predecessor, we short-circuit while
 * finding a non-cancelled one, which will always exist
 * because the head node is never cancelled: A node becomes
 * head only as a result of successful acquire. A
 * cancelled thread never succeeds in acquiring, and a thread only
 * cancels itself, not any other node.
 */
volatile Node prev; 

/**
 * Link to the successor node that the current node/thread
 * unparks upon release. Assigned during enqueuing, adjusted
 * when bypassing cancelled predecessors, and nulled out (for
 * sake of GC) when dequeued. The enq operation does not
 * assign next field of a predecessor until after attachment,
 * so seeing a null next field does not necessarily mean that
 * node is at end of queue. However, if a next field appears
 * to be null, we can scan prev's from the tail to
 * double-check. The next field of cancelled nodes is set to
 * point to the node itself instead of null, to make life
 * easier for isOnSyncQueue.
 */
volatile Node next; 

/**
 * The thread that enqueued this node. Initialized on
 * construction and nulled out after use.
 */
volatile Thread thread; 

/**
 * Link to next node waiting on condition, or the special
 * value SHARED. Because condition queues are accessed only
 * when holding in exclusive mode, we just need a simple
 * linked queue to hold nodes while they are waiting on
 * conditions. They are then transferred to the queue to
 * re-acquire. And because conditions can only be exclusive,
 * we save a field by using special value to indicate shared
 * mode.
 */
Node nextWaiter;

AbstractQueuedSynchronizer 的数据结构(盗用的图)

AbstractQueuedSynchronizer 做了什么 ?

内部维护state和CLH队列,负责在资源争用时让线程入队,在资源释放时唤醒队列中的线程。

而实现类只需要定义 **什么条件下获取资源成功**、**什么条件下释放资源成功** 就可以了

所以,最简单的CountDownLatch使用AbstractQueuedSynchronizer实现非常简单:

  • 声明AbstractQueuedSynchronizer的state初始值(比如10)
  • await方法尝试获取资源:state>0表示获取失败( **什么条件获取资源成功** ,由CountDownLatch实现),获取失败的线程入队休眠(由AbstractQueuedSynchronizer负责)
  • countDown方法将state减1,state减到0时表示资源释放成功( **什么条件释放资源成功** ,由CountDownLatch实现),随后唤醒队列中的所有线程(由AbstractQueuedSynchronizer负责)
    

AbstractQueuedSynchronizer 怎么做的?

顺着ReentrantLock lock、unlock看一遍我们就大致总结出AbstractQueuedSynchronizer工作原理了

先简单介绍下ReentrantLock的特性:可重入、可响应中断、支持超时获取。

ReentrantLock lock() 流程 ( 再盗图 )

黄色表示ReentrantLock实现,绿色表示AbstractQueuedSynchronizer内部实现

  1. lock方法入口 直接调用 AbstractQueuedSynchronizer.acquire方法
  2. tryAcquire
  3. addWaiter
  4. acquireQueued
AbstractQueuedSynchronizer.acquire
1
2
3
4
5
**public** final void acquire ( int arg) { 
**if** (! tryAcquire (arg) &&
acquireQueued ( addWaiter ( Node . EXCLUSIVE ), arg))
selfInterrupt ();
}

获取锁的逻辑:直接获取成功则返回;如果没有获取成功,就入队休眠(对,就是这么简单)

下面我们逐个方法仔细看。

ReentrantLock.tryAcquire

我这里贴的是非公平锁的获取逻辑。公平锁和非公平锁的区别在于:公平锁在队列中已有等待线程时会老老实实地排队,非公平锁则会先检查资源是否可用,如果可用就不管队列中的情况,直接尝试获取锁。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
/**
 * Non-fair tryAcquire: takes the lock whenever state is 0, without looking at
 * queued threads ("barging"), and lets the current owner re-enter.
 *
 * @param acquires number of holds to add to state
 * @return true if the lock was taken or re-entered; false if another thread holds it
 *         or the CAS race was lost
 * @throws Error if the hold count would overflow int
 */
final boolean nonfairTryAcquire ( int acquires) { 
final Thread current = Thread . currentThread ();
int c = getState ();
if (c == 0 ) {
// Lock looks free: claim it with a CAS so that only one racing thread wins.
if ( compareAndSetState ( 0 , acquires)) {
setExclusiveOwnerThread (current);
return true ;
}
}
else if (current == getExclusiveOwnerThread ()) {
// Re-entrant acquire by the owner: no other thread can change state now, so a plain setState suffices.
int nextc = c + acquires;
if (nextc < 0 ) // overflow
throw new Error ( "Maximum lock count exceeded" );
setState (nextc);
return true ;
}
return false ;
}

ReentrantLock.tryAcquire读取到state==0时尝试占用锁,并保证同一线程可以重复占用。其他情况下获取资源失败。如果获取成功就没啥事了,不过关键不就是锁争用的时候是如何处理的吗?

AbstractQueuedSynchronizer.addWaiter(Node.EXCLUSIVE)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
/**
 * Creates a node for the current thread in the given mode and appends it to the
 * tail of the sync queue, retrying until the CAS on tail succeeds.
 *
 * @param mode Node.EXCLUSIVE (as passed by acquire) or presumably Node.SHARED for shared acquires
 * @return the newly enqueued node
 */
private Node addWaiter ( Node mode) { 
Node node = new Node (mode);

for (;;) {
Node oldTail = tail;
if (oldTail != null ) {
// Link prev before publishing the node as tail, so anyone walking prev pointers from tail sees a connected chain.
node. setPrevRelaxed (oldTail);
if ( compareAndSetTail (oldTail, node)) {
// next is linked only after the CAS succeeds: until then a null next does not mean "end of queue",
// which is why waiters are located by scanning prev pointers from tail when next looks null.
oldTail. next = node;
return node;
}
} else {
// Queue not created yet (no contention so far): initialize it lazily, then retry.
initializeSyncQueue ();
}
}
}

一旦发生锁争用,如果队列还没有初始化,就会先初始化队列(因为排队的线程需要由前驱节点唤醒,所以要先创建一个作为前驱的头节点),之后通过自旋+CAS成为队列的尾节点。

简单来说就是获取不到锁就放进队列里维护起来,等锁释放的时候再用。

这里还有一个 **很具有参考性的小细节** :先设置新节点的前驱节点,CAS 成为尾节点之后才设置前驱节点的后继(next)。因此 next 为 null 并不代表已经到了队尾,这也是 unparkSuccessor 在 next 为空时要从队尾向前查找的原因。

AbstractQueuedSynchronizer.acquireQueued
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
final boolean acquireQueued final Node node, int arg)
    boolean interrupted = false ; 
    try { 
        for (;;) { 
            final Node p = node. predecessor (); 
            if (p == head && tryAcquire (arg)) { 
                setHead (node); 
                p. next = null ; // help GC 
                return interrupted; 
            } 
            if ( shouldParkAfterFailedAcquire (p, node)) 
                interrupted |= parkAndCheckInterrupt (); 
        } 
    } catch ( Throwable t) { 
        cancelAcquire (node); 
        if (interrupted) 
            selfInterrupt (); 
        throw t; 
    } 


private static boolean shouldParkAfterFailedAcquire ( Node pred, Node node)
    int ws = pred. waitStatus ; 
    if (ws == Node . SIGNAL ) 
        /* 
             * This node has already set status asking a release 
             * to signal it, so it can safely park. 
             */ 
        return true ; 
    if (ws > 0 ) { 
        /* 
             * Predecessor was cancelled. Skip over predecessors and 
             * indicate retry. 
             */ 
        do { 
            node. prev = pred = pred. prev ; 
        } while (pred. waitStatus > 0 ); 
        pred. next = node; 
    } else { 
        /* 
             * waitStatus must be 0 or PROPAGATE.  Indicate that we 
             * need a signal, but don't park yet.  Caller will need to 
             * retry to make sure it cannot acquire before parking. 
             */ 
        pred. compareAndSetWaitStatus (ws, Node . SIGNAL ); 
    } 
    return false ; 


private final boolean parkAndCheckInterrupt ()
        LockSupport . park ( this ); 
        return Thread . interrupted (); 
    }

前面只是维护下链表数据结构,这里负责找到合适的唤醒前驱,然后让线程休眠。

这里主要是一个循环过程:

  1. 检查是否能获取到锁,获取到则返回
  2. 失败则寻找前面最近的未放弃争用的前驱,把前驱的waitStatus设置为-1,并把放弃争用的节点抛弃
  3. 检查是否能休眠
  4. 使用LockSupport.park休眠(底层是Unsafe.park,不是Object.wait)

ReentrantLock lock 总结

ReentrantLock unlock()

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
public final boolean release int arg)
    if ( tryRelease (arg)) { 
        Node h = head; 
        if (h != null && h. waitStatus != 0 ) 
            unparkSuccessor (h); 
        return true ; 
    } 
    return false ; 


protected final boolean tryRelease int releases)
    int c = getState () - releases; 
    if ( Thread . currentThread () != getExclusiveOwnerThread ()) 
        throw new IllegalMonitorStateException (); 
    boolean free = false ; 
    if (c == 0 ) { 
        free = true ; 
        setExclusiveOwnerThread ( null ); 
    } 
    setState (c); 
    return free; 


private void unparkSuccessor ( Node node)
    /* 
         * If status is negative (i.e., possibly needing signal) try 
         * to clear in anticipation of signalling.  It is OK if this 
         * fails or if status is changed by waiting thread. 
         */ 
    int ws = node. waitStatus ; 
    if (ws < 0 ) 
        node. compareAndSetWaitStatus (ws, 0 ); 

    /* 
         * Thread to unpark is held in successor, which is normally 
         * just the next node.  But if cancelled or apparently null, 
         * traverse backwards from tail to find the actual 
         * non-cancelled successor. 
         */ 
    Node s = node. next ; 
    if (s == null || s. waitStatus > 0 ) { 
        s = null ; 
        for ( Node p = tail; p != node && p != null ; p = p. prev ) 
            if (p. waitStatus <= 0 ) 
                s = p; 
    } 
    if (s != null ) 
        LockSupport . unpark (s. thread ); 
}

unlock的代码特别简单:

  1. 每unlock一次state-1
  2. state == 0 时资源成功释放
  3. 如果释放成功,且head的waitStatus不为0,就唤醒head的后继节点(即队列中的第二个节点)
  4. 如果第二个节点为null或者已放弃争用(waitStatus>0),就从队尾向前查找,唤醒离head最近的、未放弃争用的节点的线程