JPA 表租户 SQL解析实现

1. 功能介绍

  • 适用于通过表中的租户ID字段区分租户数据的多租户系统

  • 参考了Mybatis-Plus插件的TenantSqlParser进行的JPA实现,使用jsqlparser解析并修改SQL,我们不生产代码,我们只做代码的搬运工

  • 实现获取当前用户租户ID,SQL增删改查时处理租户字段,实现租户数据的隔离
    参考项目:

  • https://github.com/baomidou/mybatis-plus

  • https://github.com/JSQLParser/JSqlParser

    2. 在JPA项目中引入jsqlparser依赖,本例中使用的版本号为 3.1(版本号通过pom中的 jsqlparser.version 属性指定)
                <!-- jsqlparser parses and rewrites the SQL in TenantInterceptor; ${jsqlparser.version} must be defined in <properties> (this example uses 3.1) -->
                <dependency>
                    <groupId>com.github.jsqlparser</groupId>
                    <artifactId>jsqlparser</artifactId>
                    <version>${jsqlparser.version}</version>
                </dependency>
    

    3. 编写租户拦截器TenantInterceptor

    实现Hibernate提供的StatementInspector接口的inspect方法:参数为Hibernate生成的原始SQL,返回值为修改后的SQL(返回null表示沿用原始SQL)

    import lombok.Data;
    import lombok.experimental.Accessors;
    import lombok.extern.slf4j.Slf4j;
    import net.sf.jsqlparser.expression.BinaryExpression;
    import net.sf.jsqlparser.expression.Expression;
    import net.sf.jsqlparser.expression.Parenthesis;
    import net.sf.jsqlparser.expression.StringValue;
    import net.sf.jsqlparser.expression.operators.conditional.AndExpression;
    import net.sf.jsqlparser.expression.operators.conditional.OrExpression;
    import net.sf.jsqlparser.expression.operators.relational.*;
    import net.sf.jsqlparser.parser.CCJSqlParserUtil;
    import net.sf.jsqlparser.schema.Column;
    import net.sf.jsqlparser.schema.Table;
    import net.sf.jsqlparser.statement.Statement;
    import net.sf.jsqlparser.statement.Statements;
    import net.sf.jsqlparser.statement.delete.Delete;
    import net.sf.jsqlparser.statement.insert.Insert;
    import net.sf.jsqlparser.statement.select.*;
    import net.sf.jsqlparser.statement.update.Update;
    import org.hibernate.resource.jdbc.spi.StatementInspector;
    import java.util.List;
     * 参考Mybatis-Plus插件中的TenantSqlParser进行租户解析处理,其实现为使用jsqlparser对sql进行解析,拼装SQL语句
     * @author wangqichang
     * @since 2019/12/5
    @Slf4j
    @Data
    @Accessors(chain = true)
    public class TenantInterceptor implements StatementInspector {
         * 当前租户ID,从UserContext获取
        private String tenantId;
         * 需进行租户解析的表名,需要注入
        private List<String> tenantTables;
         * 需进行租户解析的租户字段名,本项目中为固定名称
        private String tenantIdColumn = "tenant_id";
         * 重写StatementInspector的inspect接口,参数为hibernate处理后的原始SQL,返回值为我们修改后的SQL
         * @param sql
         * @return
        @Override
        public String inspect(String sql) {
            try {
                 * 非租户用户不进行解析
                if (UserContext.current() == null || UserContext.current().getAdministrator()) {
                    return null;
                 * 初始化需要进行租户解析的租户表
                if (tenantTables == null) {
                    TenantProperties bean = SpringContextUtil.getBean(TenantProperties.class);
                    if (bean != null) {
                        tenantTables = bean.getTables();
                    } else {
                        throw new RuntimeException("未能获取TenantProperties参数配置");
                 * 从当前线程获取登录用户的所属租户ID
                CurrentUser user = UserContext.current();
                tenantId = user.getTenantId();
                log.info("租户解析开始,原始SQL:{}", sql);
                Statements statements = CCJSqlParserUtil.parseStatements(sql);
                StringBuilder sqlStringBuilder = new StringBuilder();
                int i = 0;
                for (Statement statement : statements.getStatements()) {
                    if (null != statement) {
                        if (i++ > 0) {
                            sqlStringBuilder.append(';');
                        sqlStringBuilder.append(this.processParser(statement));
                String newSql = sqlStringBuilder.toString();
                log.info("租户解析结束,解析后SQL:{}", newSql);
                return newSql;
            } catch (Exception e) {
                log.error("租户解析失败,解析SQL异常{}", e.getMessage());
                e.printStackTrace();
            } finally {
                tenantId = null;
            return null;
        private String processParser(Statement statement) {
            if (statement instanceof Insert) {
                this.processInsert((Insert) statement);
            } else if (statement instanceof Select) {
                this.processSelectBody(((Select) statement).getSelectBody());
            } else if (statement instanceof Update) {
                this.processUpdate((Update) statement);
            } else if (statement instanceof Delete) {
                this.processDelete((Delete) statement);
             * 返回处理后的SQL
            return statement.toString();
         * select 语句处理
        public void processSelectBody(SelectBody selectBody) {
            if (selectBody instanceof PlainSelect) {
                processPlainSelect((PlainSelect) selectBody);
            } else if (selectBody instanceof WithItem) {
                WithItem withItem = (WithItem) selectBody;
                if (withItem.getSelectBody() != null) {
                    processSelectBody(withItem.getSelectBody());
            } else {
                SetOperationList operationList = (SetOperationList) selectBody;
                if (operationList.getSelects() != null && operationList.getSelects().size() > 0) {
                    operationList.getSelects().forEach(this::processSelectBody);
         * insert 语句处理
        public void processInsert(Insert insert) {
            if (tenantTables.contains(insert.getTable().getFullyQualifiedName())) {
                insert.getColumns().add(new Column(tenantIdColumn));
                if (insert.getSelect() != null) {
                    processPlainSelect((PlainSelect) insert.getSelect().getSelectBody(), true);
                } else if (insert.getItemsList() != null) {
                    // fixed github pull/295
                    ItemsList itemsList = insert.getItemsList();
                    if (itemsList instanceof MultiExpressionList) {
                        ((MultiExpressionList) itemsList).getExprList().forEach(el -> el.getExpressions().add(new StringValue(tenantId)));
                    } else {
                        ((ExpressionList) insert.getItemsList()).getExpressions().add(new StringValue(tenantId));
                } else {
                    throw new RuntimeException("Failed to process multiple-table update, please exclude the tableName or statementId");
         * update 语句处理
        public void processUpdate(Update update) {
            final Table table = update.getTable();
            if (tenantTables.contains(table.getFullyQualifiedName())) {
                update.setWhere(this.andExpression(table, update.getWhere()));
         * delete 语句处理
        public void processDelete(Delete delete) {
            if (tenantTables.contains(delete.getTable().getFullyQualifiedName())) {
                delete.setWhere(this.andExpression(delete.getTable(), delete.getWhere()));
         * delete update 语句 where 处理
        protected BinaryExpression andExpression(Table table, Expression where) {
            //获得where条件表达式
            EqualsTo equalsTo = new EqualsTo();
            equalsTo.setLeftExpression(this.getAliasColumn(table));
            equalsTo.setRightExpression(new StringValue(tenantId));
            if (null != where) {
                if (where instanceof OrExpression) {
                    return new AndExpression(equalsTo, new Parenthesis(where));
                } else {
                    return new AndExpression(equalsTo, where);
            return equalsTo;
         * 处理 PlainSelect
        protected void processPlainSelect(PlainSelect plainSelect) {
            processPlainSelect(plainSelect, false);
         * 处理 PlainSelect
         * @param plainSelect ignore
         * @param addColumn   是否添加租户列,insert into select语句中需要
        protected void processPlainSelect(PlainSelect plainSelect, boolean addColumn) {
            FromItem fromItem = plainSelect.getFromItem();
            if (fromItem instanceof Table) {
                Table fromTable = (Table) fromItem;
                if (tenantTables.contains(fromTable.getFullyQualifiedName())) {
                    //#1186 github
                    plainSelect.setWhere(builderExpression(plainSelect.getWhere(), fromTable));
                    if (addColumn) {
                        plainSelect.getSelectItems().add(new SelectExpressionItem(
                                new Column(tenantIdColumn)));
            } else {
                processFromItem(fromItem);
            List<Join> joins = plainSelect.getJoins();
            if (joins != null && joins.size() > 0) {
                joins.forEach(j -> {
                    processJoin(j);
                    processFromItem(j.getRightItem());
         * 处理子查询等
        protected void processFromItem(FromItem fromItem) {
            if (fromItem instanceof SubJoin) {
                SubJoin subJoin = (SubJoin) fromItem;
                if (subJoin.getJoinList() != null) {
                    subJoin.getJoinList().forEach(this::processJoin);
                if (subJoin.getLeft() != null) {
                    processFromItem(subJoin.getLeft());
            } else if (fromItem instanceof SubSelect) {
                SubSelect subSelect = (SubSelect) fromItem;
                if (subSelect.getSelectBody() != null) {
                    processSelectBody(subSelect.getSelectBody());
            } else if (fromItem instanceof ValuesList) {
                log.debug("Perform a subquery, if you do not give us feedback");
            } else if (fromItem instanceof LateralSubSelect) {
                LateralSubSelect lateralSubSelect = (LateralSubSelect) fromItem;
                if (lateralSubSelect.getSubSelect() != null) {
                    SubSelect subSelect = lateralSubSelect.getSubSelect();
                    if (subSelect.getSelectBody() != null) {
                        processSelectBody(subSelect.getSelectBody());
         * 处理联接语句
        protected void processJoin(Join join) {
            if (join.getRightItem() instanceof Table) {
                Table fromTable = (Table) join.getRightItem();
                if (tenantTables.contains(fromTable.getFullyQualifiedName())) {
                    join.setOnExpression(builderExpression(join.getOnExpression(), fromTable));
         * 处理条件:
         * 支持 getTenantHandler().getTenantId()是一个完整的表达式:tenant in (1,2)
         * 默认tenantId的表达式: LongValue(1)这种依旧支持
        protected Expression builderExpression(Expression currentExpression, Table table) {
            final Expression tenantExpression = new StringValue(tenantId);
            Expression appendExpression;
            if (!(tenantExpression instanceof SupportsOldOracleJoinSyntax)) {
                appendExpression = new EqualsTo();
                ((EqualsTo) appendExpression).setLeftExpression(this.getAliasColumn(table));
                ((EqualsTo) appendExpression).setRightExpression(tenantExpression);
            } else {
                appendExpression = processTableAlias4CustomizedTenantIdExpression(tenantExpression, table);
            if (currentExpression == null) {
                return appendExpression;
            if (currentExpression instanceof BinaryExpression) {
                BinaryExpression binaryExpression = (BinaryExpression) currentExpression;
                doExpression(binaryExpression.getLeftExpression());
                doExpression(binaryExpression.getRightExpression());
            } else if (currentExpression instanceof InExpression) {
                InExpression inExp = (InExpression) currentExpression;
                ItemsList rightItems = inExp.getRightItemsList();
                if (rightItems instanceof SubSelect) {
                    processSelectBody(((SubSelect) rightItems).getSelectBody());
            if (currentExpression instanceof OrExpression) {
                return new AndExpression(new Parenthesis(currentExpression), appendExpression);
            } else {
                return new AndExpression(currentExpression, appendExpression);
        protected void doExpression(Expression expression) {
            if (expression instanceof FromItem) {
                processFromItem((FromItem) expression);
            } else if (expression instanceof InExpression) {
                InExpression inExp = (InExpression) expression;
                ItemsList rightItems = inExp.getRightItemsList();
                if (rightItems instanceof SubSelect) {
                    processSelectBody(((SubSelect) rightItems).getSelectBody());
         * 目前: 针对自定义的tenantId的条件表达式[tenant_id in (1,2,3)],无法处理多租户的字段加上表别名
         * select a.id, b.name
         * from a
         * join b on b.aid = a.id and [b.]tenant_id in (1,2) --别名[b.]无法加上 TODO
         * @param expression
         * @param table
         * @return 加上别名的多租户字段表达式
        protected Expression processTableAlias4CustomizedTenantIdExpression(Expression expression, Table table) {
            //cannot add table alias for customized tenantId expression,
            // when tables including tenantId at the join table poistion
            return expression;
         * 租户字段别名设置
         * <p>tableName.tenantId 或 tableAlias.tenantId</p>
         * @param table 表对象
         * @return 字段
        protected Column getAliasColumn(Table table) {
            StringBuilder column = new StringBuilder();
            if (null == table.getAlias()) {
                column.append(table.getName());
            } else {
                column.append(table.getAlias().getName());
            column.append(".");
            column.append(tenantIdColumn);
            return new Column(column.toString());
    

    4. JPA拦截器yml配置

    spring:
      # Spring Boot binds these keys under spring.jpa; without the jpa level they are ignored
      # and the statement inspector is never registered.
      jpa:
        database: mysql
        show-sql: true
        hibernate:
          ddl-auto: update
        properties:
          hibernate:
            session_factory:
              # Hibernate instantiates this class once per SessionFactory and runs every SQL statement through it.
              statement_inspector: com.tba.sc.common.intercepters.TenantInterceptor
    

    5. 租户表yml配置

    # Tenant tables that require tenant SQL parsing (bound to TenantProperties.tables via prefix "tenant");
    # each listed table gets a tenant_id condition injected by TenantInterceptor.
    tenant:
      tables:
        - sys_user
    

    6. 租户表配置类

    @Data
    @Component
    @ConfigurationProperties(prefix = "tenant")
    public class TenantProperties {
         * 需要进行租户解析的租户表
        private List<String> tables;
    

    7. 测试类

    * @author wangqichang * @since 2019/12/5 @Slf4j @SpringBootTest(classes = SystemApplication.class) @RunWith(SpringRunner.class) public class TanentTest { @Autowired UserService userService; @Test public void tenantTest() { CurrentUser user = new CurrentUser(); user.setId("40285b816252ff61016253008f9f0000"); user.setTenantId("40285b816252ff61016253008f9f0001"); user.setAdministrator(false); UserContext.setCurrentUser(user); while (true) { List<UserDTO> all = userService.findAll(); all.forEach(x -> log.info(x.toString()));

    8. 测试效果如下

    可以看到查询的SQL语句自动拼接了WHERE user0_.tenant_id = '40285b816252ff61016253008f9f0001'条件

    2019-12-06 10:02:22.345  INFO 174116 --- [           main] c.t.s.c.intercepters.TenantInterceptor   : 租户解析开始,原始SQL:select user0_.id as id1_9_, user0_.create_date as create_d2_9_, user0_.update_date as update_d3_9_, user0_.administrator as administ4_9_, user0_.org_id as org_id10_9_, user0_.password as password5_9_, user0_.real_name as real_nam6_9_, user0_.salt as salt7_9_, user0_.tenant_id as tenant_i8_9_, user0_.user_name as user_nam9_9_ from sys_user user0_
    2019-12-06 10:02:22.348  INFO 174116 --- [           main] c.t.s.c.intercepters.TenantInterceptor   : 租户解析结束,解析后SQL:SELECT user0_.id AS id1_9_, user0_.create_date AS create_d2_9_, user0_.update_date AS update_d3_9_, user0_.administrator AS administ4_9_, user0_.org_id AS org_id10_9_, user0_.password AS password5_9_, user0_.real_name AS real_nam6_9_, user0_.salt AS salt7_9_, user0_.tenant_id AS tenant_i8_9_, user0_.user_name AS user_nam9_9_ FROM sys_user user0_ WHERE user0_.tenant_id = '40285b816252ff61016253008f9f0001'