diff --git a/hibernate-core/src/main/java/org/hibernate/dialect/MariaDBDialect.java b/hibernate-core/src/main/java/org/hibernate/dialect/MariaDBDialect.java index 2c5db6f4334a..40a51496d6f5 100644 --- a/hibernate-core/src/main/java/org/hibernate/dialect/MariaDBDialect.java +++ b/hibernate-core/src/main/java/org/hibernate/dialect/MariaDBDialect.java @@ -35,6 +35,7 @@ import org.hibernate.exception.spi.SQLExceptionConversionDelegate; import org.hibernate.exception.spi.TemplatedViolatedConstraintNameExtractor; import org.hibernate.exception.spi.ViolatedConstraintNameExtractor; +import org.hibernate.persister.entity.mutation.EntityMutationTarget; import org.hibernate.query.sqm.CastType; import org.hibernate.service.ServiceRegistry; import org.hibernate.sql.ast.SqlAstTranslator; @@ -42,6 +43,8 @@ import org.hibernate.sql.ast.spi.StandardSqlAstTranslatorFactory; import org.hibernate.sql.ast.tree.Statement; import org.hibernate.sql.exec.spi.JdbcOperation; +import org.hibernate.sql.model.MutationOperation; +import org.hibernate.sql.model.internal.OptionalTableUpdate; import org.hibernate.tool.schema.extract.internal.SequenceInformationExtractorMariaDBDatabaseImpl; import org.hibernate.tool.schema.extract.spi.SequenceInformationExtractor; import org.hibernate.type.SqlTypes; @@ -421,4 +424,10 @@ public boolean supportsWithClauseInSubquery() { return false; } + @Override + public MutationOperation createOptionalTableUpdateOperation(EntityMutationTarget mutationTarget, OptionalTableUpdate optionalTableUpdate, SessionFactoryImplementor factory) { + final MariaDBSqlAstTranslator translator = new MariaDBSqlAstTranslator<>( factory, optionalTableUpdate, MariaDBDialect.this ); + return translator.createMergeOperation( optionalTableUpdate ); + } + } diff --git a/hibernate-core/src/main/java/org/hibernate/dialect/MySQLDialect.java b/hibernate-core/src/main/java/org/hibernate/dialect/MySQLDialect.java index 7e5c119076f6..5f15a8ef7941 100644 --- a/hibernate-core/src/main/java/org/hibernate/dialect/MySQLDialect.java +++ b/hibernate-core/src/main/java/org/hibernate/dialect/MySQLDialect.java @@ -45,6 +45,7 @@ import org.hibernate.mapping.CheckConstraint; import org.hibernate.metamodel.mapping.EntityMappingType; import org.hibernate.metamodel.spi.RuntimeModelCreationContext; +import org.hibernate.persister.entity.mutation.EntityMutationTarget; import org.hibernate.query.common.TemporalUnit; import org.hibernate.query.sqm.CastType; import org.hibernate.query.sqm.IntervalType; @@ -63,6 +64,8 @@ import org.hibernate.sql.ast.spi.StandardSqlAstTranslatorFactory; import org.hibernate.sql.ast.tree.Statement; import org.hibernate.sql.exec.spi.JdbcOperation; +import org.hibernate.sql.model.MutationOperation; +import org.hibernate.sql.model.internal.OptionalTableUpdate; import org.hibernate.type.BasicTypeRegistry; import org.hibernate.type.NullType; import org.hibernate.type.SqlTypes; @@ -1668,4 +1671,10 @@ public boolean supportsRowValueConstructorSyntaxInQuantifiedPredicates() { return false; } + @Override + public MutationOperation createOptionalTableUpdateOperation(EntityMutationTarget mutationTarget, OptionalTableUpdate optionalTableUpdate, SessionFactoryImplementor factory) { + final MySQLSqlAstTranslator translator = new MySQLSqlAstTranslator<>( factory, optionalTableUpdate, MySQLDialect.this ); + return translator.createMergeOperation( optionalTableUpdate ); + } + } diff --git a/hibernate-core/src/main/java/org/hibernate/dialect/sql/ast/MariaDBSqlAstTranslator.java b/hibernate-core/src/main/java/org/hibernate/dialect/sql/ast/MariaDBSqlAstTranslator.java index 481c9445ea81..9316b822312b 100644 --- a/hibernate-core/src/main/java/org/hibernate/dialect/sql/ast/MariaDBSqlAstTranslator.java +++ b/hibernate-core/src/main/java/org/hibernate/dialect/sql/ast/MariaDBSqlAstTranslator.java @@ -14,7 +14,6 @@ import org.hibernate.metamodel.mapping.JdbcMappingContainer; import org.hibernate.query.sqm.ComparisonOperator; import org.hibernate.sql.ast.Clause; -import org.hibernate.sql.ast.spi.AbstractSqlAstTranslator; import org.hibernate.sql.ast.tree.Statement; import org.hibernate.sql.ast.tree.delete.DeleteStatement; import org.hibernate.sql.ast.tree.expression.BinaryArithmeticExpression; @@ -44,7 +43,7 @@ * * @author Christian Beikov */ -public class MariaDBSqlAstTranslator extends AbstractSqlAstTranslator { +public class MariaDBSqlAstTranslator extends SqlAstTranslatorWithOnDuplicateKeyUpdate { private final MariaDBDialect dialect; diff --git a/hibernate-core/src/main/java/org/hibernate/dialect/sql/ast/MySQLSqlAstTranslator.java b/hibernate-core/src/main/java/org/hibernate/dialect/sql/ast/MySQLSqlAstTranslator.java index f4d714987de6..c53283ca195d 100644 --- a/hibernate-core/src/main/java/org/hibernate/dialect/sql/ast/MySQLSqlAstTranslator.java +++ b/hibernate-core/src/main/java/org/hibernate/dialect/sql/ast/MySQLSqlAstTranslator.java @@ -12,7 +12,6 @@ import org.hibernate.internal.util.collections.Stack; import org.hibernate.query.sqm.ComparisonOperator; import org.hibernate.sql.ast.Clause; -import org.hibernate.sql.ast.spi.AbstractSqlAstTranslator; import org.hibernate.sql.ast.tree.Statement; import org.hibernate.sql.ast.tree.delete.DeleteStatement; import org.hibernate.sql.ast.tree.expression.BinaryArithmeticExpression; @@ -46,7 +45,7 @@ * * @author Christian Beikov */ -public class MySQLSqlAstTranslator extends AbstractSqlAstTranslator { +public class MySQLSqlAstTranslator extends SqlAstTranslatorWithOnDuplicateKeyUpdate { /** * On MySQL, 1GB or {@code 2^30 - 1} is the maximum size that a char value can be casted. diff --git a/hibernate-core/src/main/java/org/hibernate/dialect/sql/ast/SqlAstTranslatorWithOnDuplicateKeyUpdate.java b/hibernate-core/src/main/java/org/hibernate/dialect/sql/ast/SqlAstTranslatorWithOnDuplicateKeyUpdate.java new file mode 100644 index 000000000000..b4291e1a667e --- /dev/null +++ b/hibernate-core/src/main/java/org/hibernate/dialect/sql/ast/SqlAstTranslatorWithOnDuplicateKeyUpdate.java @@ -0,0 +1,87 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * Copyright Red Hat Inc. and Hibernate Authors + */ +package org.hibernate.dialect.sql.ast; + + +import org.hibernate.engine.spi.SessionFactoryImplementor; +import org.hibernate.sql.ast.spi.SqlAstTranslatorWithUpsert; +import org.hibernate.sql.ast.tree.Statement; +import org.hibernate.sql.exec.spi.JdbcOperation; +import org.hibernate.sql.model.ast.ColumnValueBinding; +import org.hibernate.sql.model.internal.OptionalTableUpdate; + +import java.util.List; + +/** + * @author Jan Schatteman + */ +public class SqlAstTranslatorWithOnDuplicateKeyUpdate extends SqlAstTranslatorWithUpsert { + + public SqlAstTranslatorWithOnDuplicateKeyUpdate(SessionFactoryImplementor sessionFactory, Statement statement) { + super( sessionFactory, statement ); + } + + @Override + protected void renderUpsertStatement(OptionalTableUpdate optionalTableUpdate) { + // INSERT INTO employees (id, name, salary) + // VALUES (1, 'Alice', 50000) + // ON DUPLICATE KEY UPDATE + // name = VALUES(name), + // salary = VALUES(salary) + renderInsertInto( optionalTableUpdate ); + appendSql( " " ); + renderOnDuplicateKeyUpdate( optionalTableUpdate ); + } + + protected void renderInsertInto(OptionalTableUpdate optionalTableUpdate) { + appendSql( "insert into " ); + appendSql( optionalTableUpdate.getMutatingTable().getTableName() ); + appendSql( " (" ); + + final List keyBindings = optionalTableUpdate.getKeyBindings(); + for ( ColumnValueBinding keyBinding : keyBindings ) { + appendSql( keyBinding.getColumnReference().getColumnExpression() ); + appendSql( ',' ); + } + + optionalTableUpdate.forEachValueBinding( (columnPosition, columnValueBinding) -> { + appendSql( columnValueBinding.getColumnReference().getColumnExpression() ); + if ( columnPosition != optionalTableUpdate.getValueBindings().size() - 1 ) { + appendSql( ',' ); + } + } ); + + appendSql( ") values (" ); + + for ( ColumnValueBinding keyBinding : keyBindings ) { + keyBinding.getValueExpression().accept( this ); + appendSql( ',' ); + } + + optionalTableUpdate.forEachValueBinding( (columnPosition, columnValueBinding) -> { + if ( columnPosition > 0 ) { + appendSql( ',' ); + } + columnValueBinding.getValueExpression().accept( this ); + } ); + + appendSql( ")" ); + } + + protected void renderOnDuplicateKeyUpdate(OptionalTableUpdate optionalTableUpdate) { + appendSql( "on duplicate key update " ); + optionalTableUpdate.forEachValueBinding( (columnPosition, columnValueBinding) -> { + if ( columnPosition > 0 ) { + appendSql( ',' ); + } + appendSql( columnValueBinding.getColumnReference().getColumnExpression() ); + append( " = " ); + appendSql( "values (" ); + appendSql( columnValueBinding.getColumnReference().getColumnExpression() ); + appendSql( ")" ); + } ); + } + +} diff --git a/hibernate-core/src/test/java/org/hibernate/orm/test/stateless/UpsertTest.java b/hibernate-core/src/test/java/org/hibernate/orm/test/stateless/UpsertTest.java index b86e880e70b7..db2113de4099 100644 --- a/hibernate-core/src/test/java/org/hibernate/orm/test/stateless/UpsertTest.java +++ b/hibernate-core/src/test/java/org/hibernate/orm/test/stateless/UpsertTest.java @@ -6,14 +6,19 @@ import jakarta.persistence.Entity; import jakarta.persistence.Id; +import org.hibernate.dialect.MariaDBDialect; +import org.hibernate.dialect.MySQLDialect; +import org.hibernate.testing.jdbc.SQLStatementInspector; import org.hibernate.testing.orm.junit.DomainModel; +import org.hibernate.testing.orm.junit.RequiresDialect; +import org.hibernate.testing.orm.junit.RequiresDialects; import org.hibernate.testing.orm.junit.SessionFactory; import org.hibernate.testing.orm.junit.SessionFactoryScope; import org.junit.jupiter.api.Test; import static org.junit.jupiter.api.Assertions.assertEquals; -@SessionFactory +@SessionFactory(useCollectingStatementInspector = true) @DomainModel(annotatedClasses = UpsertTest.Record.class) public class UpsertTest { @Test void test(SessionFactoryScope scope) { @@ -25,14 +30,45 @@ public class UpsertTest { assertEquals("hello earth", s.get( Record.class,123L).message); assertEquals("hello mars", s.get( Record.class,456L).message); }); - scope.inStatelessTransaction(s-> { - s.upsert(new Record(123L,"goodbye earth")); - }); + scope.inStatelessTransaction(s-> s.upsert(new Record(123L,"goodbye earth")) ); scope.inStatelessTransaction(s-> { assertEquals("goodbye earth", s.get( Record.class,123L).message); assertEquals("hello mars", s.get( Record.class,456L).message); }); } + + @RequiresDialects( + value = { + @RequiresDialect( MySQLDialect.class ), + @RequiresDialect( MariaDBDialect.class ) + } + ) + @Test void testMySQL(SessionFactoryScope scope) { + SQLStatementInspector statementInspector = scope.getCollectingStatementInspector(); + statementInspector.clear(); + + scope.inStatelessTransaction(s-> { + s.upsert(new Record(123L,"hello earth")); + s.upsert(new Record(456L,"hello mars")); + }); + // Verify that only a single query is executed for each upsert, in contrast to the former update+insert + statementInspector.assertExecutedCount( 2 ); + + scope.inStatelessTransaction(s-> { + assertEquals("hello earth",s.get(Record.class,123L).message); + assertEquals("hello mars",s.get(Record.class,456L).message); + }); + statementInspector.clear(); + + scope.inStatelessTransaction(s-> s.upsert(new Record(123L,"goodbye earth")) ); + statementInspector.assertExecutedCount( 1 ); + + scope.inStatelessTransaction(s-> { + assertEquals("goodbye earth",s.get(Record.class,123L).message); + assertEquals("hello mars",s.get(Record.class,456L).message); + }); + } + @Entity static class Record { @Id Long id;