一、简介

一般一些后台管理系统都会有租户或者组织架构的概念,由于租户或者组织都是属于某一群体的,因此他们的数据应该是隔离的。
就拿组织架构来说,一个公司相当于一个组织,因此由这个组织的用户创建的任何东西都是属于这个组织的,其他组织是看不到的,所以后台会有一个组织表 dept,当然,组织自然也是有层级的,即会存在父部门和子部门,子部门不能查看父部门的数据,但是父部门可以查看子部门的数据,因此 dept 表中还需有一个字段 parent_id,用于表示该部门的父部门,基本结构如下图所示:

部门表.png
下面我将要通过上面的表来实现一套数据权限注解,其功能为:

  • 各部门之间的数据是隔离的
  • 父部门可以查看子部门数据
  • 子部门无法查看父部门数据

使用框架:druid 数据库连接池、mybatis plus。也可以不用 druid,使用 druid 主要是用了它的 SQL 解析器,可以用 JSqlParser 代替,效果都是一样的。

二、实现

1、获取部门树

本质上实现数据权限是通过获取到指定的部门 id,然后在 SQL 中将这个部门 id 通过 in 拼接进去实现数据过滤,所以我们需要根据父部门获取到所有子部门。我们可以通过 MySQL 里的递归方法,可以很快获取到父部门下的所有子部门。

WITH RECURSIVE dept_hierarchy AS (
  SELECT 
		id, 
		name, 
		parent_id parentId
  FROM 
		dept
  WHERE 
		id = 1  -- 这里的 id 是你传入的父部门的 ID
		AND deleted = 0
  UNION ALL
  SELECT 
		d.id, 
		d.name, 
		d.parent_id parentId
  FROM 
		dept d
		JOIN dept_hierarchy dh ON d.parent_id = dh.id
	WHERE
		deleted = 0
)
SELECT * FROM dept_hierarchy;

只需要传入父部门 id,即可获取到该部门下所有部门
我们再处理一下,因为只需要部门 id 即可:

@Override
public List<Integer> getDeptIdsByParentId(Integer parentId) {
    List<Dept> deptList = deptMapper.getDeptByParentId(parentId);
    List<Integer> collect = deptList.stream().map(x -> x.getId()).collect(Collectors.toList());
    return collect;
}

2、创建用户工具类

该类的作用是获取到当前请求的用户的信息,拿到所在部门,获取部门 id 集合

@Component
public class DataScope {

    @Autowired
    private DeptService deptService;

    public List<Integer> getDeptIdByUser() {
        // 先拿到用户信息,因为各个项目的都不一样,我这里就不写了
        User user = getUser();
        // 拿到用户信息后,根据用户所在的部门获取子部门
        List<Integer> deptIds = deptService.getDeptIdsByParentId(user.getDeptId());
        return deptIds;
    }

    // 获取用户信息
    public User getUser() {
        User user = new User();
        user.setDeptId(1);
        user.setIsAdmin(false);
        return user;
    }
}

3、创建注解

@Retention(RetentionPolicy.RUNTIME)
@Target(ElementType.METHOD)
public @interface DataFilter {
    
    /**
     * 需要过滤的表的别名
     * @return
     */
    String tableAlias();

    /**
     * 部门id字段名
     * @return
     */
    String field() default "dept_id";
    
}

4、编写 Mybatis SQL 拦截器

@Intercepts({
        // 这里是拦截方法,如果使用了PageHelper分页插件则要替换成注释的这个
//        @Signature(
//                type = Executor.class,
//                method = "query",
//                args = {MappedStatement.class, Object.class, RowBounds.class, ResultHandler.class, CacheKey.class, BoundSql.class}
//        )
        @Signature(
                type = Executor.class,
                method = "query",
                args = {MappedStatement.class, Object.class, RowBounds.class, ResultHandler.class}
        )
})
public class DataFilterInterceptor implements Interceptor {

    @Autowired
    private ApplicationContext applicationContext;

    @Override
    public Object intercept(Invocation invocation) throws Throwable{
        // 拿到运行SQL的mappedStatement
        MappedStatement mappedStatement = (MappedStatement)invocation.getArgs()[0];
        // 判断是否是select
        if (!"SELECT".equals(mappedStatement.getSqlCommandType().name())) {
            return invocation.proceed();
        }
        // 拿到类名,方法名以及class
        String nameSpace = mappedStatement.getId();
        String classPath = nameSpace.substring(0, nameSpace.lastIndexOf("."));
        String methodName = nameSpace.substring(nameSpace.lastIndexOf(".") + 1);
        Class<?> clazz = Class.forName(classPath);
        // 拿到该类的全部方法,并找到上面的方法,这个方法也就是本次即将要运行的SQL对应的方法
        Method[] methods = clazz.getDeclaredMethods();
        for (Method method : methods) {
            // 注意,这里不能分辨出重载方法,所以尽量不要有方法名相同的方法
            if (methodName.equals(method.getName())) {
                // 判断是否有注解
                DataFilter annotation = method.getAnnotation(DataFilter.class);
                if (annotation == null) {
                    return invocation.proceed();
                }
                // 获取到注解后,可以拿到注解的信息
                String field = annotation.field();
                String tableAlias = annotation.tableAlias();
                // 获取原始SQL
                BoundSql boundSql = mappedStatement.getBoundSql(invocation.getArgs()[1]);
                String sql = boundSql.getSql();
                // 修改SQL
                String newSql = getFilterSql(sql, tableAlias, field);
                // 处理SQL
                sqlHandle(mappedStatement, new MySqlSource(boundSql), invocation, newSql);
                break;
            }
        }
        return invocation.proceed();
    }

    /**
     * 处理SQL
     * @param ms
     * @param sqlSource
     * @param invocation
     * @param newSql
     */
    private void sqlHandle(MappedStatement ms, SqlSource sqlSource, Invocation invocation, String newSql) {
        // 组装 MappedStatement
        MappedStatement.Builder builder = new MappedStatement.Builder(ms.getConfiguration(), ms.getId(), sqlSource, ms.getSqlCommandType());
        builder.resource(ms.getResource());
        builder.fetchSize(ms.getFetchSize());
        builder.statementType(ms.getStatementType());
        builder.keyGenerator(ms.getKeyGenerator());
        if (ms.getKeyProperties() != null && ms.getKeyProperties().length != 0) {
            StringBuilder keyProperties = new StringBuilder();
            for (String keyProperty : ms.getKeyProperties()) {
                keyProperties.append(keyProperty).append(",");
            }
            keyProperties.delete(keyProperties.length() - 1, keyProperties.length());
            builder.keyProperty(keyProperties.toString());
        }
        builder.timeout(ms.getTimeout());
        builder.parameterMap(ms.getParameterMap());
        builder.resultMaps(ms.getResultMaps());
        builder.resultSetType(ms.getResultSetType());
        builder.cache(ms.getCache());
        builder.flushCacheRequired(ms.isFlushCacheRequired());
        builder.useCache(ms.isUseCache());
        MappedStatement newMappedStatement = builder.build();
        MetaObject metaObject =  MetaObject.forObject(newMappedStatement, new DefaultObjectFactory(), new DefaultObjectWrapperFactory(), new DefaultReflectorFactory());
        metaObject.setValue("sqlSource.boundSql.sql", newSql);
        invocation.getArgs()[0] = newMappedStatement;

        // 组装BoundSql
        // 如果使用了PageHelper分页插件则要取消下面的注释
//        invocation.getArgs()[5] = newMappedStatement.getSqlSource().getBoundSql(invocation.getArgs()[1]);
    }

    /**
     * 拼接权限过滤SQL
     * @param sql 原始SQL
     * @param tableAlias 要拼接的表的别名
     * @param field 要拼接的表的字段
     * @return
     */
    private String getFilterSql(String sql, String tableAlias, String field) {
        // 在这里需要获取当前的用户,如果是管理员则不需要权限过滤,直接返回原始SQL
        // 这里不能直接注入,会导致循环依赖
        DataScope dataScope = (DataScope) applicationContext.getBean("dataScope");
        User user = dataScope.getUser();
        if (user.getIsAdmin()) {
            // 如果是管理员则不需要权限过滤
            return sql;
        }
        // 否则获取到当前用户所在的部门,并根据这个部门id拿到子部门
        List<Integer> deptIds = dataScope.getDeptIdByUser();
        String deptIdsStr = deptIds.stream().map(String::valueOf).collect(Collectors.joining(","));
        String condition = tableAlias + "." + field + " IN (" + deptIdsStr + ")";
        String newSql = SQLUtils.addCondition(sql, condition, JdbcConstants.MYSQL);

        return newSql;
    }

    // 定义一个内部类,作用是包装sql
    class MySqlSource implements SqlSource {

        private BoundSql boundSql;

        public MySqlSource(BoundSql boundSql) {
            this.boundSql = boundSql;
        }

        @Override
        public BoundSql getBoundSql(Object parameterObject) {
            return boundSql;
        }
    }

    @Override
    public Object plugin(Object o) {
        return Plugin.wrap(o, this);
    }

    @Override
    public void setProperties(Properties properties) {

    }
}

5、注册拦截器

@Configuration
public class MybatisPlusConfig {

    /**
     * 权限过滤器拦截插件
     * @return
     */
    @Bean
    public DataFilterInterceptor dataFilterInterceptor() {
        DataFilterInterceptor dataFilterInterceptor = new DataFilterInterceptor();
        return dataFilterInterceptor;
    }
}

6、测试效果

我们建一个文件表 file,并加上一些数据

文件表.png
接下来我们写一条 SQL 用来测试注解好不好使

<select id="getFileList" resultType="club.gggd.datafilterdemo.domain.File">
    select
        id,
        dept_id deptId,
        file_name name
    from file f -- 这里一定要定义一个别名
</select>

在 mapper 中加入注解

@Mapper
public interface FileMapper {

    // 这里加上注解,这里的参数指定哪个表就表示以哪个表来进行数据过滤
    @DataFilter(tableAlias = "f")
    List<File> getFileList();
}

为了方便查看效果我们打开打印日志功能,在配置文件中加入配置:

mybatis-plus:
  configuration:
    log-impl: org.apache.ibatis.logging.stdout.StdOutImpl

编写测试接口:

@RestController
@RequestMapping("/test")
public class TestController {

    @Autowired
    private FileMapper fileMapper;

    @GetMapping
    public String test() {
        List<File> fileList = fileMapper.getFileList();
        return fileList.toString();
    }
}

请求一下

请求结果.png
可以正常请求到数据,看看打印出来的 SQL:

请求结果SQL.png

可以看到,在原来 SQL 的基础上,拦截器自动帮我们拼接了 in 查询进去,大家也可以试试在 DataScope 中的getUser 方法里改动 deptid 和 isAdmin 看看效果

需要注意的地方
如果使用了PageHelper分页插件,则需要替换拦截器的拦截方法,也就是替换成注释的那个,同时要把sqlHandle 方法中最后面那行的注释去掉。
具体原因是因为 Mybatis 拦截器是采用的责任链模式,一般拦截器中intercept方法中最后执行 invocation.proceed() 方法,将拦截器责任链向后传递,但是查看pageHelper源码可以发现,他的拦截器方法中并没有向后传递责任链,而是直接执行了另一个query方法,就导致没有向后传递从而拦截不到,所以就需要更换拦截方法。
我个人建议分页插件直接用 mybatis 的就可以了。