diff --git a/presto-iceberg/pom.xml b/presto-iceberg/pom.xml index 036a44bef74e9..5ebe42b1eba96 100644 --- a/presto-iceberg/pom.xml +++ b/presto-iceberg/pom.xml @@ -31,6 +31,17 @@ + + + io.github.jbellis + jvector + 4.0.0-rc.4 + + + org.apache.commons + commons-math3 + + com.facebook.airlift concurrent @@ -752,11 +763,6 @@ - - org.apache.commons - commons-math3 - test - @@ -771,6 +777,7 @@ org.apache.httpcomponents.core5:httpcore5 com.amazonaws:aws-java-sdk-s3 com.amazonaws:aws-java-sdk-core + org.apache.commons:commons-math3 diff --git a/presto-iceberg/src/main/java/com/facebook/presto/iceberg/IcebergAbstractMetadata.java b/presto-iceberg/src/main/java/com/facebook/presto/iceberg/IcebergAbstractMetadata.java index eb6ea59912b76..558637bdf80df 100644 --- a/presto-iceberg/src/main/java/com/facebook/presto/iceberg/IcebergAbstractMetadata.java +++ b/presto-iceberg/src/main/java/com/facebook/presto/iceberg/IcebergAbstractMetadata.java @@ -32,6 +32,7 @@ import com.facebook.presto.iceberg.changelog.ChangelogOperation; import com.facebook.presto.iceberg.changelog.ChangelogUtil; import com.facebook.presto.iceberg.statistics.StatisticsFileCache; +import com.facebook.presto.iceberg.tvf.ApproxNearestNeighborsFunction; import com.facebook.presto.spi.ColumnHandle; import com.facebook.presto.spi.ColumnMetadata; import com.facebook.presto.spi.ConnectorDeleteTableHandle; @@ -60,7 +61,9 @@ import com.facebook.presto.spi.connector.ConnectorTableVersion; import com.facebook.presto.spi.connector.ConnectorTableVersion.VersionOperator; import com.facebook.presto.spi.connector.ConnectorTableVersion.VersionType; +import com.facebook.presto.spi.connector.TableFunctionApplicationResult; import com.facebook.presto.spi.function.StandardFunctionResolution; +import com.facebook.presto.spi.function.table.ConnectorTableFunctionHandle; import com.facebook.presto.spi.plan.FilterStatsCalculatorService; import com.facebook.presto.spi.relation.RowExpression; import com.facebook.presto.spi.relation.RowExpressionService; @@ -259,6 +262,7 @@ public abstract class IcebergAbstractMetadata protected Transaction transaction; protected final StatisticsFileCache statisticsFileCache; protected final IcebergTableProperties tableProperties; + protected final IcebergConfig icebergConfig; private final StandardFunctionResolution functionResolution; private final ConcurrentMap icebergTables = new ConcurrentHashMap<>(); @@ -272,7 +276,8 @@ public IcebergAbstractMetadata( NodeVersion nodeVersion, FilterStatsCalculatorService filterStatsCalculatorService, StatisticsFileCache statisticsFileCache, - IcebergTableProperties tableProperties) + IcebergTableProperties tableProperties, + IcebergConfig icebergConfig) { this.typeManager = requireNonNull(typeManager, "typeManager is null"); this.commitTaskCodec = requireNonNull(commitTaskCodec, "commitTaskCodec is null"); @@ -283,6 +288,7 @@ public IcebergAbstractMetadata( this.filterStatsCalculatorService = requireNonNull(filterStatsCalculatorService, "filterStatsCalculatorService is null"); this.statisticsFileCache = requireNonNull(statisticsFileCache, "statisticsFileCache is null"); this.tableProperties = requireNonNull(tableProperties, "tableProperties is null"); + this.icebergConfig = requireNonNull(icebergConfig, "icebergConfig is null"); } protected final Table getIcebergTable(ConnectorSession session, SchemaTableName schemaTableName) @@ -1424,6 +1430,11 @@ protected Optional getDataLocationBasedOnWarehouseDataDir(SchemaTableNam return Optional.empty(); } + public TypeManager getTypeManager() + { + return typeManager; + } + @Override public Optional getInfo(ConnectorTableLayoutHandle tableHandle) { @@ -1765,4 +1776,28 @@ private boolean viewExists(ConnectorSession session, ConnectorTableMetadata view return false; } } + + @Override + public Optional> applyTableFunction(ConnectorSession session, ConnectorTableFunctionHandle handle) + { + if (!icebergConfig.isSimilaritySearchEnabled()) { + return Optional.empty(); + } + if (handle instanceof ApproxNearestNeighborsFunction.IcebergAnnTableFunctionHandle) { + ApproxNearestNeighborsFunction.IcebergAnnTableFunctionHandle annTableFunctionHandle = (ApproxNearestNeighborsFunction.IcebergAnnTableFunctionHandle) handle; + ApproxNearestNeighborsFunction.IcebergAnnTableHandle originalHandle = (ApproxNearestNeighborsFunction.IcebergAnnTableHandle) annTableFunctionHandle.getTableHandle(); + SchemaTableName schemaTableName = new SchemaTableName(originalHandle.getSchemaName(), originalHandle.getTableName()); + Table icebergTable = getIcebergTable(session, schemaTableName); + Optional tableLocation = tryGetLocation(icebergTable); + ApproxNearestNeighborsFunction.IcebergAnnTableHandle updatedHandle = new ApproxNearestNeighborsFunction.IcebergAnnTableHandle( + originalHandle.getInputVector(), + originalHandle.getLimit(), + originalHandle.getSchemaName(), + originalHandle.getTableName(), + tableLocation); + return Optional.of(new TableFunctionApplicationResult<>(updatedHandle, annTableFunctionHandle.getColumnHandles())); + } + + throw new IllegalArgumentException("Unsupported function handle: " + handle); + } } diff --git a/presto-iceberg/src/main/java/com/facebook/presto/iceberg/IcebergCommonModule.java b/presto-iceberg/src/main/java/com/facebook/presto/iceberg/IcebergCommonModule.java index 7220800e0959e..b082a8a9a05ee 100644 --- a/presto-iceberg/src/main/java/com/facebook/presto/iceberg/IcebergCommonModule.java +++ b/presto-iceberg/src/main/java/com/facebook/presto/iceberg/IcebergCommonModule.java @@ -43,6 +43,7 @@ import com.facebook.presto.hive.gcs.HiveGcsConfigurationInitializer; import com.facebook.presto.iceberg.nessie.IcebergNessieConfig; import com.facebook.presto.iceberg.optimizer.IcebergPlanOptimizerProvider; +import com.facebook.presto.iceberg.procedure.BuildVectorIndexProcedure; import com.facebook.presto.iceberg.procedure.ExpireSnapshotsProcedure; import com.facebook.presto.iceberg.procedure.FastForwardBranchProcedure; import com.facebook.presto.iceberg.procedure.ManifestFileCacheInvalidationProcedure; @@ -56,6 +57,7 @@ import com.facebook.presto.iceberg.procedure.UnregisterTableProcedure; import com.facebook.presto.iceberg.statistics.StatisticsFileCache; import com.facebook.presto.iceberg.statistics.StatisticsFileCacheKey; +import com.facebook.presto.iceberg.tvf.ApproxNearestNeighborsFunction; import com.facebook.presto.orc.CachingStripeMetadataSource; import com.facebook.presto.orc.DwrfAwareStripeMetadataSourceFactory; import com.facebook.presto.orc.EncryptionLibrary; @@ -83,6 +85,7 @@ import com.facebook.presto.spi.connector.ConnectorPageSourceProvider; import com.facebook.presto.spi.connector.ConnectorPlanOptimizerProvider; import com.facebook.presto.spi.connector.ConnectorSplitManager; +import com.facebook.presto.spi.function.table.ConnectorTableFunction; import com.facebook.presto.spi.procedure.BaseProcedure; import com.facebook.presto.spi.statistics.ColumnStatistics; import com.google.common.cache.Cache; @@ -188,6 +191,9 @@ protected void setup(Binder binder) procedures.addBinding().toProvider(SetTablePropertyProcedure.class).in(Scopes.SINGLETON); procedures.addBinding().toProvider(StatisticsFileCacheInvalidationProcedure.class).in(Scopes.SINGLETON); procedures.addBinding().toProvider(ManifestFileCacheInvalidationProcedure.class).in(Scopes.SINGLETON); + if (icebergConfig.isSimilaritySearchEnabled()) { + procedures.addBinding().toProvider(BuildVectorIndexProcedure.class).in(Scopes.SINGLETON); + } // for orc binder.bind(EncryptionLibrary.class).annotatedWith(HiveDwrfEncryptionProvider.ForCryptoService.class).to(UnsupportedEncryptionLibrary.class).in(Scopes.SINGLETON); @@ -198,7 +204,9 @@ protected void setup(Binder binder) configBinder(binder).bindConfig(OrcFileWriterConfig.class); configBinder(binder).bindConfig(ParquetCacheConfig.class, connectorId); - + if (icebergConfig.isSimilaritySearchEnabled()) { + newSetBinder(binder, ConnectorTableFunction.class).addBinding().toProvider(ApproxNearestNeighborsFunction.class).in(Scopes.SINGLETON); + } binder.bind(ConnectorPlanOptimizerProvider.class).to(IcebergPlanOptimizerProvider.class).in(Scopes.SINGLETON); } diff --git a/presto-iceberg/src/main/java/com/facebook/presto/iceberg/IcebergConfig.java b/presto-iceberg/src/main/java/com/facebook/presto/iceberg/IcebergConfig.java index 432612bb3366b..00f55ef99278b 100644 --- a/presto-iceberg/src/main/java/com/facebook/presto/iceberg/IcebergConfig.java +++ b/presto-iceberg/src/main/java/com/facebook/presto/iceberg/IcebergConfig.java @@ -59,6 +59,7 @@ public class IcebergConfig private boolean mergeOnReadModeEnabled = true; private double statisticSnapshotRecordDifferenceWeight; private boolean pushdownFilterEnabled; + private boolean similaritySearchEnabled; private boolean deleteAsJoinRewriteEnabled = true; private int deleteAsJoinRewriteMaxDeleteColumns = 400; private int rowsForMetadataOptimizationThreshold = 1000; @@ -496,4 +497,17 @@ public IcebergConfig setMaterializedViewStoragePrefix(String materializedViewSto this.materializedViewStoragePrefix = materializedViewStoragePrefix; return this; } + + public boolean isSimilaritySearchEnabled() + { + return similaritySearchEnabled; + } + + @Config("iceberg.similarity-search-enabled") + @ConfigDescription("Enable filter for similarity search") + public IcebergConfig setSimilaritySearchEnabled(boolean similaritySearchEnabled) + { + this.similaritySearchEnabled = similaritySearchEnabled; + return this; + } } diff --git a/presto-iceberg/src/main/java/com/facebook/presto/iceberg/IcebergConnector.java b/presto-iceberg/src/main/java/com/facebook/presto/iceberg/IcebergConnector.java index cff54195d1c81..8558b81474ba3 100644 --- a/presto-iceberg/src/main/java/com/facebook/presto/iceberg/IcebergConnector.java +++ b/presto-iceberg/src/main/java/com/facebook/presto/iceberg/IcebergConnector.java @@ -31,6 +31,7 @@ import com.facebook.presto.spi.connector.ConnectorSplitManager; import com.facebook.presto.spi.connector.ConnectorTransactionHandle; import com.facebook.presto.spi.connector.classloader.ClassLoaderSafeConnectorMetadata; +import com.facebook.presto.spi.function.table.ConnectorTableFunction; import com.facebook.presto.spi.procedure.BaseProcedure; import com.facebook.presto.spi.procedure.DistributedProcedure; import com.facebook.presto.spi.procedure.Procedure; @@ -68,6 +69,7 @@ public class IcebergConnector private final ConnectorAccessControl accessControl; private final Set> procedures; private final ConnectorPlanOptimizerProvider planOptimizerProvider; + private final Set connectorTableFunctions; public IcebergConnector( LifeCycleManager lifeCycleManager, @@ -84,7 +86,8 @@ public IcebergConnector( List> columnProperties, ConnectorAccessControl accessControl, Set> procedures, - ConnectorPlanOptimizerProvider planOptimizerProvider) + ConnectorPlanOptimizerProvider planOptimizerProvider, + Set connectorTableFunctions) { this.lifeCycleManager = requireNonNull(lifeCycleManager, "lifeCycleManager is null"); this.transactionManager = requireNonNull(transactionManager, "transactionManager is null"); @@ -101,6 +104,7 @@ public IcebergConnector( this.accessControl = requireNonNull(accessControl, "accessControl is null"); this.procedures = requireNonNull(procedures, "procedures is null"); this.planOptimizerProvider = requireNonNull(planOptimizerProvider, "planOptimizerProvider is null"); + this.connectorTableFunctions = connectorTableFunctions; } @Override @@ -246,4 +250,9 @@ private > Set getProcedures(Class clazz) .map(clazz::cast) .collect(Collectors.toSet()); } + + public Set getTableFunctions() + { + return connectorTableFunctions; + } } diff --git a/presto-iceberg/src/main/java/com/facebook/presto/iceberg/IcebergHiveMetadata.java b/presto-iceberg/src/main/java/com/facebook/presto/iceberg/IcebergHiveMetadata.java index 5fb9569b5222e..8d3a0b01aa7e1 100644 --- a/presto-iceberg/src/main/java/com/facebook/presto/iceberg/IcebergHiveMetadata.java +++ b/presto-iceberg/src/main/java/com/facebook/presto/iceberg/IcebergHiveMetadata.java @@ -187,9 +187,10 @@ public IcebergHiveMetadata( StatisticsFileCache statisticsFileCache, ManifestFileCache manifestFileCache, IcebergTableProperties tableProperties, - ConnectorSystemConfig connectorSystemConfig) + ConnectorSystemConfig connectorSystemConfig, + IcebergConfig icebergConfig) { - super(typeManager, functionResolution, rowExpressionService, commitTaskCodec, columnMappingsCodec, nodeVersion, filterStatsCalculatorService, statisticsFileCache, tableProperties); + super(typeManager, functionResolution, rowExpressionService, commitTaskCodec, columnMappingsCodec, nodeVersion, filterStatsCalculatorService, statisticsFileCache, tableProperties, icebergConfig); this.catalogName = requireNonNull(catalogName, "catalogName is null"); this.metastore = requireNonNull(metastore, "metastore is null"); this.hdfsEnvironment = requireNonNull(hdfsEnvironment, "hdfsEnvironment is null"); diff --git a/presto-iceberg/src/main/java/com/facebook/presto/iceberg/IcebergHiveMetadataFactory.java b/presto-iceberg/src/main/java/com/facebook/presto/iceberg/IcebergHiveMetadataFactory.java index d01ae7a33dab7..3a4b96ea91758 100644 --- a/presto-iceberg/src/main/java/com/facebook/presto/iceberg/IcebergHiveMetadataFactory.java +++ b/presto-iceberg/src/main/java/com/facebook/presto/iceberg/IcebergHiveMetadataFactory.java @@ -49,6 +49,7 @@ public class IcebergHiveMetadataFactory final ManifestFileCache manifestFileCache; final IcebergTableProperties tableProperties; final ConnectorSystemConfig connectorSystemConfig; + final IcebergConfig icebergConfig; @Inject public IcebergHiveMetadataFactory( @@ -66,7 +67,8 @@ public IcebergHiveMetadataFactory( StatisticsFileCache statisticsFileCache, ManifestFileCache manifestFileCache, IcebergTableProperties tableProperties, - ConnectorSystemConfig connectorSystemConfig) + ConnectorSystemConfig connectorSystemConfig, + IcebergConfig icebergConfig) { this.catalogName = requireNonNull(catalogName, "catalogName is null"); this.metastore = requireNonNull(metastore, "metastore is null"); @@ -83,6 +85,7 @@ public IcebergHiveMetadataFactory( this.manifestFileCache = requireNonNull(manifestFileCache, "manifestFileCache is null"); this.tableProperties = requireNonNull(tableProperties, "icebergTableProperties is null"); this.connectorSystemConfig = requireNonNull(connectorSystemConfig, "connectorSystemConfig is null"); + this.icebergConfig = requireNonNull(icebergConfig, "icebergConfig is null"); } public ConnectorMetadata create() @@ -102,6 +105,7 @@ public ConnectorMetadata create() statisticsFileCache, manifestFileCache, tableProperties, - connectorSystemConfig); + connectorSystemConfig, + icebergConfig); } } diff --git a/presto-iceberg/src/main/java/com/facebook/presto/iceberg/IcebergNativeMetadata.java b/presto-iceberg/src/main/java/com/facebook/presto/iceberg/IcebergNativeMetadata.java index 52d59d7b367f7..08f28637613c7 100644 --- a/presto-iceberg/src/main/java/com/facebook/presto/iceberg/IcebergNativeMetadata.java +++ b/presto-iceberg/src/main/java/com/facebook/presto/iceberg/IcebergNativeMetadata.java @@ -117,9 +117,10 @@ public IcebergNativeMetadata( NodeVersion nodeVersion, FilterStatsCalculatorService filterStatsCalculatorService, StatisticsFileCache statisticsFileCache, - IcebergTableProperties tableProperties) + IcebergTableProperties tableProperties, + IcebergConfig icebergConfig) { - super(typeManager, functionResolution, rowExpressionService, commitTaskCodec, columnMappingsCodec, nodeVersion, filterStatsCalculatorService, statisticsFileCache, tableProperties); + super(typeManager, functionResolution, rowExpressionService, commitTaskCodec, columnMappingsCodec, nodeVersion, filterStatsCalculatorService, statisticsFileCache, tableProperties, icebergConfig); this.catalogFactory = requireNonNull(catalogFactory, "catalogFactory is null"); this.catalogType = requireNonNull(catalogType, "catalogType is null"); this.warehouseDataDir = Optional.ofNullable(catalogFactory.getCatalogWarehouseDataDir()); diff --git a/presto-iceberg/src/main/java/com/facebook/presto/iceberg/IcebergNativeMetadataFactory.java b/presto-iceberg/src/main/java/com/facebook/presto/iceberg/IcebergNativeMetadataFactory.java index b421a823efe2c..6e2cf82061f2d 100644 --- a/presto-iceberg/src/main/java/com/facebook/presto/iceberg/IcebergNativeMetadataFactory.java +++ b/presto-iceberg/src/main/java/com/facebook/presto/iceberg/IcebergNativeMetadataFactory.java @@ -42,6 +42,7 @@ public class IcebergNativeMetadataFactory final FilterStatsCalculatorService filterStatsCalculatorService; final StatisticsFileCache statisticsFileCache; final IcebergTableProperties tableProperties; + final IcebergConfig icebergConfig; @Inject public IcebergNativeMetadataFactory( @@ -55,7 +56,8 @@ public IcebergNativeMetadataFactory( NodeVersion nodeVersion, FilterStatsCalculatorService filterStatsCalculatorService, StatisticsFileCache statisticsFileCache, - IcebergTableProperties tableProperties) + IcebergTableProperties tableProperties, + IcebergConfig icebergConfig) { this.catalogFactory = requireNonNull(catalogFactory, "catalogFactory is null"); this.typeManager = requireNonNull(typeManager, "typeManager is null"); @@ -68,10 +70,11 @@ public IcebergNativeMetadataFactory( this.filterStatsCalculatorService = requireNonNull(filterStatsCalculatorService, "filterStatsCalculatorService is null"); this.statisticsFileCache = requireNonNull(statisticsFileCache, "statisticsFileCache is null"); this.tableProperties = requireNonNull(tableProperties, "tableProperties is null"); + this.icebergConfig = requireNonNull(icebergConfig, "icebergConfig is null"); } public ConnectorMetadata create() { - return new IcebergNativeMetadata(catalogFactory, typeManager, functionResolution, rowExpressionService, commitTaskCodec, columnMappingsCodec, catalogType, nodeVersion, filterStatsCalculatorService, statisticsFileCache, tableProperties); + return new IcebergNativeMetadata(catalogFactory, typeManager, functionResolution, rowExpressionService, commitTaskCodec, columnMappingsCodec, catalogType, nodeVersion, filterStatsCalculatorService, statisticsFileCache, tableProperties, icebergConfig); } } diff --git a/presto-iceberg/src/main/java/com/facebook/presto/iceberg/IcebergPageSourceProvider.java b/presto-iceberg/src/main/java/com/facebook/presto/iceberg/IcebergPageSourceProvider.java index c99e5b8034c3a..901ee328afc58 100644 --- a/presto-iceberg/src/main/java/com/facebook/presto/iceberg/IcebergPageSourceProvider.java +++ b/presto-iceberg/src/main/java/com/facebook/presto/iceberg/IcebergPageSourceProvider.java @@ -47,6 +47,7 @@ import com.facebook.presto.iceberg.delete.IcebergDeletePageSink; import com.facebook.presto.iceberg.delete.PositionDeleteFilter; import com.facebook.presto.iceberg.delete.RowPredicate; +import com.facebook.presto.iceberg.tvf.ANNPageSource; import com.facebook.presto.memory.context.AggregatedMemoryContext; import com.facebook.presto.orc.DwrfEncryptionProvider; import com.facebook.presto.orc.DwrfKeyProvider; @@ -74,6 +75,7 @@ import com.facebook.presto.spi.ConnectorSession; import com.facebook.presto.spi.ConnectorSplit; import com.facebook.presto.spi.ConnectorTableLayoutHandle; +import com.facebook.presto.spi.FixedPageSource; import com.facebook.presto.spi.PageIndexerFactory; import com.facebook.presto.spi.PrestoException; import com.facebook.presto.spi.SplitContext; @@ -208,6 +210,8 @@ public class IcebergPageSourceProvider private final PageIndexerFactory pageIndexerFactory; private final int maxOpenPartitions; private final SortParameters sortParameters; + private final ManifestFileCache manifestFileCache; + private final boolean similaritySearchEnabled; @Inject public IcebergPageSourceProvider( @@ -223,7 +227,8 @@ public IcebergPageSourceProvider( JsonCodec jsonCodec, PageIndexerFactory pageIndexerFactory, IcebergConfig icebergConfig, - SortParameters sortParameters) + SortParameters sortParameters, + ManifestFileCache manifestFileCache) { this.hdfsEnvironment = requireNonNull(hdfsEnvironment, "hdfsEnvironment is null"); this.fileFormatDataSourceStats = requireNonNull(fileFormatDataSourceStats, "fileFormatDataSourceStats is null"); @@ -239,6 +244,8 @@ public IcebergPageSourceProvider( requireNonNull(icebergConfig, "icebergConfig is null"); this.maxOpenPartitions = icebergConfig.getMaxPartitionsPerWriter(); this.sortParameters = requireNonNull(sortParameters, "sortParameters is null"); + this.manifestFileCache = requireNonNull(manifestFileCache, "manifestFileCache is null"); + this.similaritySearchEnabled = icebergConfig.isSimilaritySearchEnabled(); } private static ConnectorPageSourceWithRowPositions createParquetPageSource( @@ -734,6 +741,20 @@ public ConnectorPageSource createPageSource( IcebergSplit split = (IcebergSplit) connectorSplit; IcebergTableHandle table = icebergLayout.getTable(); + if (similaritySearchEnabled && split.isAnn()) { + String tableLocation = table.getOutputPath() + .orElseThrow(() -> new IllegalStateException("Table location is required for ANN queries")); + + HdfsContext hdfsContext = new HdfsContext(session, table.getSchemaName(), table.getTableName(), split.getPath(), false); + HdfsFileIO hdfsFileIO = new HdfsFileIO(manifestFileCache, hdfsEnvironment, hdfsContext); + + return new ANNPageSource( + new FixedPageSource(ImmutableList.of()), + split.getQueryVector(), + split.getTopN(), + tableLocation, + hdfsFileIO); + } List columns = desiredColumns; if (split.getChangelogSplitInfo().isPresent()) { diff --git a/presto-iceberg/src/main/java/com/facebook/presto/iceberg/IcebergSplit.java b/presto-iceberg/src/main/java/com/facebook/presto/iceberg/IcebergSplit.java index 2ebcdc5a8ed54..e3a0e4179b869 100644 --- a/presto-iceberg/src/main/java/com/facebook/presto/iceberg/IcebergSplit.java +++ b/presto-iceberg/src/main/java/com/facebook/presto/iceberg/IcebergSplit.java @@ -53,6 +53,9 @@ public class IcebergSplit private final long dataSequenceNumber; private final long affinitySchedulingFileSectionSize; private final long affinitySchedulingFileSectionIndex; + private final boolean ann; + private final List queryVector; + private final int topN; @JsonCreator public IcebergSplit( @@ -69,7 +72,10 @@ public IcebergSplit( @JsonProperty("deletes") List deletes, @JsonProperty("changelogSplitInfo") Optional changelogSplitInfo, @JsonProperty("dataSequenceNumber") long dataSequenceNumber, - @JsonProperty("affinitySchedulingSectionSize") long affinitySchedulingFileSectionSize) + @JsonProperty("affinitySchedulingSectionSize") long affinitySchedulingFileSectionSize, + @JsonProperty("ann") boolean ann, + @JsonProperty("queryVector") List queryVector, + @JsonProperty("topN") int topN) { requireNonNull(nodeSelectionStrategy, "nodeSelectionStrategy is null"); this.path = requireNonNull(path, "path is null"); @@ -87,6 +93,27 @@ public IcebergSplit( this.dataSequenceNumber = dataSequenceNumber; this.affinitySchedulingFileSectionSize = affinitySchedulingFileSectionSize; this.affinitySchedulingFileSectionIndex = start / affinitySchedulingFileSectionSize; + this.ann = ann; + this.queryVector = queryVector; + this.topN = topN; + } + + @JsonProperty + public boolean isAnn() + { + return ann; + } + + @JsonProperty + public List getQueryVector() + { + return queryVector; + } + + @JsonProperty + public int getTopN() + { + return topN; } @JsonProperty diff --git a/presto-iceberg/src/main/java/com/facebook/presto/iceberg/IcebergSplitManager.java b/presto-iceberg/src/main/java/com/facebook/presto/iceberg/IcebergSplitManager.java index 0ad3345b7ae9d..0b7968d79b296 100644 --- a/presto-iceberg/src/main/java/com/facebook/presto/iceberg/IcebergSplitManager.java +++ b/presto-iceberg/src/main/java/com/facebook/presto/iceberg/IcebergSplitManager.java @@ -18,13 +18,17 @@ import com.facebook.presto.common.type.TypeManager; import com.facebook.presto.iceberg.changelog.ChangelogSplitSource; import com.facebook.presto.iceberg.equalitydeletes.EqualityDeletesSplitSource; +import com.facebook.presto.iceberg.tvf.ApproxNearestNeighborsFunction; import com.facebook.presto.spi.ConnectorSession; import com.facebook.presto.spi.ConnectorSplitSource; import com.facebook.presto.spi.ConnectorTableLayoutHandle; import com.facebook.presto.spi.FixedSplitSource; +import com.facebook.presto.spi.SplitWeight; import com.facebook.presto.spi.connector.ConnectorSplitManager; import com.facebook.presto.spi.connector.ConnectorTransactionHandle; +import com.facebook.presto.spi.schedule.NodeSelectionStrategy; import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; import jakarta.inject.Inject; import org.apache.iceberg.DeleteFile; import org.apache.iceberg.IncrementalChangelogScan; @@ -35,6 +39,7 @@ import org.weakref.jmx.Managed; import org.weakref.jmx.Nested; +import java.util.Optional; import java.util.concurrent.ExecutorService; import java.util.concurrent.ThreadPoolExecutor; @@ -53,17 +58,20 @@ public class IcebergSplitManager private final TypeManager typeManager; private final ExecutorService executor; private final ThreadPoolExecutorMBean executorServiceMBean; + private final IcebergConfig icebergConfig; @Inject public IcebergSplitManager( IcebergTransactionManager transactionManager, TypeManager typeManager, - @ForIcebergSplitManager ExecutorService executor) + @ForIcebergSplitManager ExecutorService executor, + IcebergConfig icebergConfig) { this.transactionManager = requireNonNull(transactionManager, "transactionManager is null"); this.typeManager = requireNonNull(typeManager, "typeManager is null"); this.executor = requireNonNull(executor, "executor is null"); this.executorServiceMBean = new ThreadPoolExecutorMBean((ThreadPoolExecutor) executor); + this.icebergConfig = requireNonNull(icebergConfig, "icebergConfig is null"); } @Override @@ -76,6 +84,29 @@ public ConnectorSplitSource getSplits( IcebergTableLayoutHandle layoutHandle = (IcebergTableLayoutHandle) layout; IcebergTableHandle table = layoutHandle.getTable(); + if (icebergConfig.isSimilaritySearchEnabled() && table instanceof ApproxNearestNeighborsFunction.IcebergAnnTableHandle) { + ApproxNearestNeighborsFunction.IcebergAnnTableHandle annHandle = (ApproxNearestNeighborsFunction.IcebergAnnTableHandle) table; + IcebergSplit annSplit = new IcebergSplit( + /* path */ "", // non-null dummy string + /* start */ 1L, + /* length */ 1L, + /* fileFormat */ FileFormat.PARQUET, // pick any valid enum + /* addresses */ ImmutableList.of(), // empty list is fine + /* partitionKeys */ ImmutableMap.of(), // empty map + /* partitionSpecAsJson */ "{}", // minimal valid JSON + /* partitionDataJson */ Optional.of("{}"), // optional non-null JSON + /* nodeSelectionStrategy */ NodeSelectionStrategy.SOFT_AFFINITY, + /* splitWeight */ SplitWeight.fromRawValue(1), // minimal positive weight + /* deletes */ ImmutableList.of(), + /* changelogSplitInfo */ Optional.empty(), + /* dataSequenceNumber */ 1L, + /* affinitySchedulingSectionSize */ 1L, + /* ann */ true, + /* queryVector */ annHandle.getInputVector(), + /* topN */ annHandle.getLimit()); + return new FixedSplitSource(ImmutableList.of(annSplit)); + } + if (!table.getIcebergTableName().getSnapshotId().isPresent()) { return new FixedSplitSource(ImmutableList.of()); } diff --git a/presto-iceberg/src/main/java/com/facebook/presto/iceberg/IcebergSplitSource.java b/presto-iceberg/src/main/java/com/facebook/presto/iceberg/IcebergSplitSource.java index 98ee9f2693450..8fb2e5506aea8 100644 --- a/presto-iceberg/src/main/java/com/facebook/presto/iceberg/IcebergSplitSource.java +++ b/presto-iceberg/src/main/java/com/facebook/presto/iceberg/IcebergSplitSource.java @@ -143,6 +143,9 @@ private ConnectorSplit toIcebergSplit(FileScanTask task) task.deletes().stream().map(DeleteFile::fromIceberg).collect(toImmutableList()), Optional.empty(), getDataSequenceNumber(task.file()), - affinitySchedulingFileSectionSize); + affinitySchedulingFileSectionSize, + false, + ImmutableList.of(), + 0); } } diff --git a/presto-iceberg/src/main/java/com/facebook/presto/iceberg/InternalIcebergConnectorFactory.java b/presto-iceberg/src/main/java/com/facebook/presto/iceberg/InternalIcebergConnectorFactory.java index 3b8994df46855..a596fff322848 100644 --- a/presto-iceberg/src/main/java/com/facebook/presto/iceberg/InternalIcebergConnectorFactory.java +++ b/presto-iceberg/src/main/java/com/facebook/presto/iceberg/InternalIcebergConnectorFactory.java @@ -28,6 +28,7 @@ import com.facebook.presto.hive.gcs.HiveGcsModule; import com.facebook.presto.hive.metastore.ExtendedHiveMetastore; import com.facebook.presto.hive.s3.HiveS3Module; +import com.facebook.presto.iceberg.tvf.ApproxNearestNeighborsFunction; import com.facebook.presto.plugin.base.security.AllowAllAccessControl; import com.facebook.presto.spi.ConnectorSystemConfig; import com.facebook.presto.spi.NodeManager; @@ -47,6 +48,7 @@ import com.facebook.presto.spi.connector.classloader.ClassLoaderSafeNodePartitioningProvider; import com.facebook.presto.spi.function.FunctionMetadataManager; import com.facebook.presto.spi.function.StandardFunctionResolution; +import com.facebook.presto.spi.function.table.ConnectorTableFunction; import com.facebook.presto.spi.plan.FilterStatsCalculatorService; import com.facebook.presto.spi.procedure.BaseProcedure; import com.facebook.presto.spi.relation.RowExpressionService; @@ -60,6 +62,8 @@ import javax.management.MBeanServer; import java.util.ArrayList; +import java.util.Collections; +import java.util.HashSet; import java.util.List; import java.util.Map; import java.util.Optional; @@ -118,13 +122,16 @@ public static Connector createConnector( IcebergSessionProperties icebergSessionProperties = injector.getInstance(IcebergSessionProperties.class); HiveCommonSessionProperties hiveCommonSessionProperties = injector.getInstance(HiveCommonSessionProperties.class); IcebergTableProperties icebergTableProperties = injector.getInstance(IcebergTableProperties.class); + IcebergConfig icebergConfig = injector.getInstance(IcebergConfig.class); Set> procedures = injector.getInstance(Key.get(new TypeLiteral>>() {})); ConnectorPlanOptimizerProvider planOptimizerProvider = injector.getInstance(ConnectorPlanOptimizerProvider.class); List> allSessionProperties = new ArrayList<>(icebergSessionProperties.getSessionProperties()); allSessionProperties.addAll(hiveCommonSessionProperties.getSessionProperties()); - + Set connectorTableFunctions = icebergConfig.isSimilaritySearchEnabled() + ? new HashSet<>(Collections.singleton(new ApproxNearestNeighborsFunction().get())) + : new HashSet<>(); return new IcebergConnector( lifeCycleManager, transactionManager, @@ -140,7 +147,8 @@ public static Connector createConnector( icebergTableProperties.getColumnProperties(), new AllowAllAccessControl(), procedures, - planOptimizerProvider); + planOptimizerProvider, + connectorTableFunctions); } } } diff --git a/presto-iceberg/src/main/java/com/facebook/presto/iceberg/changelog/ChangelogSplitSource.java b/presto-iceberg/src/main/java/com/facebook/presto/iceberg/changelog/ChangelogSplitSource.java index 44f6a2b30d494..0a136fc3ffe83 100644 --- a/presto-iceberg/src/main/java/com/facebook/presto/iceberg/changelog/ChangelogSplitSource.java +++ b/presto-iceberg/src/main/java/com/facebook/presto/iceberg/changelog/ChangelogSplitSource.java @@ -157,6 +157,9 @@ private IcebergSplit splitFromContentScanTask(ContentScanTask task, Ch changeTask.commitSnapshotId(), columnHandles)), getDataSequenceNumber(task.file()), - affinitySchedulingSectionSize); + affinitySchedulingSectionSize, + false, + ImmutableList.of(), + 0); } } diff --git a/presto-iceberg/src/main/java/com/facebook/presto/iceberg/equalitydeletes/EqualityDeletesSplitSource.java b/presto-iceberg/src/main/java/com/facebook/presto/iceberg/equalitydeletes/EqualityDeletesSplitSource.java index a3df4ca65e049..b9097ade05240 100644 --- a/presto-iceberg/src/main/java/com/facebook/presto/iceberg/equalitydeletes/EqualityDeletesSplitSource.java +++ b/presto-iceberg/src/main/java/com/facebook/presto/iceberg/equalitydeletes/EqualityDeletesSplitSource.java @@ -125,6 +125,9 @@ private IcebergSplit splitFromDeleteFile(DeleteFile deleteFile) ImmutableList.of(), Optional.empty(), IcebergUtil.getDataSequenceNumber(deleteFile), - affinitySchedulingSectionSize); + affinitySchedulingSectionSize, + false, + ImmutableList.of(), + 0); } } diff --git a/presto-iceberg/src/main/java/com/facebook/presto/iceberg/procedure/BuildVectorIndexProcedure.java b/presto-iceberg/src/main/java/com/facebook/presto/iceberg/procedure/BuildVectorIndexProcedure.java new file mode 100644 index 0000000000000..88a98817dc761 --- /dev/null +++ b/presto-iceberg/src/main/java/com/facebook/presto/iceberg/procedure/BuildVectorIndexProcedure.java @@ -0,0 +1,157 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.iceberg.procedure; + +import com.facebook.airlift.log.Logger; +import com.facebook.presto.hive.HiveTransactionHandle; +import com.facebook.presto.iceberg.IcebergMetadataFactory; +import com.facebook.presto.iceberg.IcebergTransactionManager; +import com.facebook.presto.iceberg.vectors.IcebergVectorIndexBuilder; +import com.facebook.presto.spi.ConnectorSession; +import com.facebook.presto.spi.SchemaTableName; +import com.facebook.presto.spi.connector.ConnectorMetadata; +import com.facebook.presto.spi.connector.ConnectorPageSourceProvider; +import com.facebook.presto.spi.connector.ConnectorTransactionHandle; +import com.facebook.presto.spi.procedure.Procedure; +import com.google.common.collect.ImmutableList; + +import javax.inject.Inject; +import javax.inject.Provider; + +import java.lang.invoke.MethodHandle; +import java.nio.file.Path; + +import static com.facebook.presto.common.block.MethodHandleUtil.methodHandle; +import static com.facebook.presto.common.type.StandardTypes.VARCHAR; +import static java.util.Objects.requireNonNull; + +/** + * Procedure for building vector indexes from Iceberg table data. + */ +public class BuildVectorIndexProcedure + implements Provider +{ + private static final Logger log = Logger.get(BuildVectorIndexProcedure.class); + + private final IcebergTransactionManager transactionManager; + private final IcebergMetadataFactory metadataFactory; + private final ConnectorPageSourceProvider pageSourceProvider; + + @Inject + public BuildVectorIndexProcedure( + IcebergTransactionManager transactionManager, + IcebergMetadataFactory metadataFactory, + ConnectorPageSourceProvider pageSourceProvider) + { + this.transactionManager = requireNonNull(transactionManager, "transactionManager is null"); + this.metadataFactory = requireNonNull(metadataFactory, "metadataFactory is null"); + this.pageSourceProvider = requireNonNull(pageSourceProvider, "pageSourceProvider is null"); + } + + @Override + public Procedure get() + { + // Create the MethodHandle at runtime and bind to this instance so the arity matches + MethodHandle handle = methodHandle( + BuildVectorIndexProcedure.class, + "buildVectorIndex", + ConnectorSession.class, + String.class // columnPath + ).bindTo(this); + + return new Procedure( + "system", // Use "system" as the schema name instead of an empty string + "CREATE_VEC_INDEX", + ImmutableList.of( + new Procedure.Argument("column_path", VARCHAR)), + handle); + } + + public void buildVectorIndex( + ConnectorSession session, + String columnPath) + { + // Parse the column path to extract catalog, schema, table, and column names + // Format: catalog.schema.table.column or schema.table.column + String[] parts = columnPath.split("\\."); + String catalogName = null; + String schemaName; + String tableName; + String columnName; + if (parts.length == 4) { + // Format: catalog.schema.table.column + catalogName = parts[0]; + schemaName = parts[1]; + tableName = parts[2]; + columnName = parts[3]; + } + else if (parts.length == 3) { + // Format: schema.table.column + schemaName = parts[0]; + tableName = parts[1]; + columnName = parts[2]; + } + else { + throw new IllegalArgumentException("Invalid column path format. Expected: [catalog.]schema.table.column, got: " + columnPath); + } + // Use a generic fixed name for the index + String indexName = "vector_index"; + log.info("Building vector index for %s.%s.%s", schemaName, tableName, columnName); + SchemaTableName schemaTableName = new SchemaTableName(schemaName, tableName); + // Use default values for similarity function, m, and ef_construction + int mValue = 16; + int efValue = 100; + String simFunction = "COSINE"; + + // Create a new transaction handle for this procedure + ConnectorTransactionHandle transactionHandle = new HiveTransactionHandle(); + boolean transactionRegistered = false; + + try { + // Create a new metadata instance and register it with the transaction manager + ConnectorMetadata metadata = metadataFactory.create(); + transactionManager.put(transactionHandle, metadata); + transactionRegistered = true; + + Path resultPath = IcebergVectorIndexBuilder.buildAndSaveVectorIndex( + metadata, + pageSourceProvider, + transactionHandle, + session, + schemaTableName, + columnName, + indexName, + catalogName, // Pass the catalog name to the builder + simFunction, + mValue, + efValue); + log.info("Vector index built and saved to %s", resultPath); + } + catch (Exception e) { + log.error(e, "Error building vector index"); + throw new RuntimeException("Error building vector index: " + e.getMessage(), e); + } + finally { + // Only clean up the transaction if it was successfully registered + if (transactionRegistered) { + try { + transactionManager.remove(transactionHandle); + } + catch (Exception e) { + log.warn(e, "Failed to remove transaction handle"); + } + } + } + } +} diff --git a/presto-iceberg/src/main/java/com/facebook/presto/iceberg/tvf/ANNPageSource.java b/presto-iceberg/src/main/java/com/facebook/presto/iceberg/tvf/ANNPageSource.java new file mode 100644 index 0000000000000..1a884b25466cf --- /dev/null +++ b/presto-iceberg/src/main/java/com/facebook/presto/iceberg/tvf/ANNPageSource.java @@ -0,0 +1,191 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.iceberg.tvf; + +import com.facebook.airlift.log.Logger; +import com.facebook.presto.common.Page; +import com.facebook.presto.common.block.BlockBuilder; +import com.facebook.presto.iceberg.HdfsFileIO; +import com.facebook.presto.iceberg.HdfsInputFile; +import com.facebook.presto.iceberg.vectors.NodeRowIdMapping; +import com.facebook.presto.spi.ConnectorPageSource; +import io.github.jbellis.jvector.disk.ReaderSupplier; +import io.github.jbellis.jvector.disk.ReaderSupplierFactory; +import io.github.jbellis.jvector.graph.GraphSearcher; +import io.github.jbellis.jvector.graph.SearchResult; +import io.github.jbellis.jvector.graph.disk.OnDiskGraphIndex; +import io.github.jbellis.jvector.graph.similarity.DefaultSearchScoreProvider; +import io.github.jbellis.jvector.graph.similarity.SearchScoreProvider; +import io.github.jbellis.jvector.util.Bits; +import io.github.jbellis.jvector.vector.VectorSimilarityFunction; +import io.github.jbellis.jvector.vector.VectorizationProvider; +import io.github.jbellis.jvector.vector.types.VectorFloat; +import io.github.jbellis.jvector.vector.types.VectorTypeSupport; + +import java.io.IOException; +import java.io.InputStream; +import java.nio.file.Files; +import java.nio.file.StandardCopyOption; +import java.util.List; + +import static com.facebook.presto.common.type.BigintType.BIGINT; + +public class ANNPageSource + implements ConnectorPageSource +{ + private final ConnectorPageSource delegate; + private final List queryVector; + private final int topN; + private boolean finished; + private final String tableLocation; + private final HdfsFileIO hdfsFileIO; + private static final VectorTypeSupport vts = VectorizationProvider.getInstance().getVectorTypeSupport(); + private static final String VECTOR_INDEX_DIR = ".vector_index"; + private static final Logger log = Logger.get(ANNPageSource.class); + + public ANNPageSource(ConnectorPageSource delegate, List queryVector, int topN, String tableLocation, HdfsFileIO hdfsFileIO) + { + this.delegate = delegate; + this.queryVector = queryVector; + this.topN = topN; + this.tableLocation = tableLocation; + this.hdfsFileIO = hdfsFileIO; + } + + @Override + public long getCompletedBytes() + { + return delegate.getCompletedBytes(); + } + + @Override + public long getCompletedPositions() + { + return 0; + } + + @Override + public long getReadTimeNanos() + { + return delegate.getReadTimeNanos(); + } + + @Override + public boolean isFinished() + { + return finished; + } + + @Override + public Page getNextPage() + { + if (finished) { + return null; + } + log.info("tableLocation: %s", tableLocation); + + String indexDirPath = tableLocation + "/" + VECTOR_INDEX_DIR; + String indexPath = indexDirPath + "/vector_index.hnsw"; + String mappingPath = indexDirPath + "/vector_index_mapping.bin"; + log.info("Loading vector index from: %s", indexPath); + log.info("Loading node-to-row mapping from: %s", mappingPath); + java.nio.file.Path tempIndexFile = null; + try { + // Load mapping directly from stream + NodeRowIdMapping mapping; + try { + HdfsInputFile mappingInputFile = (HdfsInputFile) hdfsFileIO.newInputFile(mappingPath); + try (InputStream mappingInputStream = mappingInputFile.newStream()) { + mapping = NodeRowIdMapping.load(mappingInputStream); + } + } + catch (Exception e) { + throw new RuntimeException( + String.format("Failed to load node-to-row ID mapping from %s. " + + "The index may need to be rebuilt with the updated version that includes mapping support.", + mappingPath), e); + } + // Create temp file only for index + tempIndexFile = Files.createTempFile("vector-index-", ".hnsw"); + HdfsInputFile indexInputFile = (HdfsInputFile) hdfsFileIO.newInputFile(indexPath); + try (InputStream indexInputStream = indexInputFile.newStream()) { + Files.copy(indexInputStream, tempIndexFile, StandardCopyOption.REPLACE_EXISTING); + } + log.info("Copied index file to temp location"); + try (ReaderSupplier rs = ReaderSupplierFactory.open(tempIndexFile)) { + OnDiskGraphIndex loadedIndex = OnDiskGraphIndex.load(rs); + VectorFloat query = vts.createFloatVector(extractVector(queryVector)); + SearchScoreProvider ssp = DefaultSearchScoreProvider.exact( + query, VectorSimilarityFunction.EUCLIDEAN, loadedIndex.getView()); + SearchResult result; + try (GraphSearcher searcher = new GraphSearcher(loadedIndex)) { + result = searcher.search(ssp, topN, Bits.ALL); + } + // Translate node IDs to row IDs using the mapping + BlockBuilder rowIdBuilder = BIGINT.createBlockBuilder(null, topN); + for (SearchResult.NodeScore ns : result.getNodes()) { + try { + long rowId = mapping.getRowId(ns.node); + BIGINT.writeLong(rowIdBuilder, rowId); + log.debug("Translated node ID %d to row ID %d (score: %.4f)", + ns.node, rowId, ns.score); + } + catch (IndexOutOfBoundsException e) { + log.error("Invalid node ID %d returned from search. Mapping size: %d", + ns.node, mapping.size()); + throw new RuntimeException( + String.format("Invalid node ID %d. This may indicate index corruption.", ns.node), e); + } + } + finished = true; + return new Page(rowIdBuilder.build()); + } + } + catch (IOException e) { + throw new RuntimeException("Error reading vector index or mapping from S3", e); + } + finally { + if (tempIndexFile != null) { + try { + Files.deleteIfExists(tempIndexFile); + log.debug("Cleaned up temporary index file: %s", tempIndexFile); + } + catch (IOException e) { + log.warn(e, "Failed to delete temporary index file: %s", tempIndexFile); + } + } + } + } + + @Override + public long getSystemMemoryUsage() + { + return 0; + } + + private static float[] extractVector(List list) + { + float[] vec = new float[list.size()]; + for (int i = 0; i < list.size(); i++) { + vec[i] = list.get(i); + } + return vec; + } + + @Override + public void close() throws IOException + { + delegate.close(); + } +} diff --git a/presto-iceberg/src/main/java/com/facebook/presto/iceberg/tvf/ApproxNearestNeighborsFunction.java b/presto-iceberg/src/main/java/com/facebook/presto/iceberg/tvf/ApproxNearestNeighborsFunction.java new file mode 100644 index 0000000000000..70fc059879796 --- /dev/null +++ b/presto-iceberg/src/main/java/com/facebook/presto/iceberg/tvf/ApproxNearestNeighborsFunction.java @@ -0,0 +1,301 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.iceberg.tvf; + +import com.facebook.presto.common.block.IntArrayBlock; +import com.facebook.presto.common.type.ArrayType; +import com.facebook.presto.common.type.Type; +import com.facebook.presto.iceberg.ColumnIdentity; +import com.facebook.presto.iceberg.IcebergColumnHandle; +import com.facebook.presto.iceberg.IcebergTableHandle; +import com.facebook.presto.iceberg.IcebergTableName; +import com.facebook.presto.spi.ColumnHandle; +import com.facebook.presto.spi.ConnectorSession; +import com.facebook.presto.spi.ConnectorTableHandle; +import com.facebook.presto.spi.connector.ConnectorTransactionHandle; +import com.facebook.presto.spi.connector.classloader.ClassLoaderSafeConnectorTableFunction; +import com.facebook.presto.spi.function.Description; +import com.facebook.presto.spi.function.table.AbstractConnectorTableFunction; +import com.facebook.presto.spi.function.table.Argument; +import com.facebook.presto.spi.function.table.ConnectorTableFunction; +import com.facebook.presto.spi.function.table.ConnectorTableFunctionHandle; +import com.facebook.presto.spi.function.table.Descriptor; +import com.facebook.presto.spi.function.table.ScalarArgument; +import com.facebook.presto.spi.function.table.ScalarArgumentSpecification; +import com.facebook.presto.spi.function.table.TableFunctionAnalysis; +import com.google.common.collect.ImmutableList; +import io.airlift.slice.Slice; + +import javax.inject.Provider; + +import java.util.ArrayList; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.concurrent.ThreadLocalRandom; + +import static com.facebook.presto.common.type.BigintType.BIGINT; +import static com.facebook.presto.common.type.RealType.REAL; +import static com.facebook.presto.common.type.VarcharType.VARCHAR; +import static com.facebook.presto.common.type.VarcharType.createUnboundedVarcharType; +import static com.facebook.presto.hive.BaseHiveColumnHandle.ColumnType.REGULAR; +import static com.facebook.presto.spi.function.table.ReturnTypeSpecification.GenericTable.GENERIC_TABLE; + +@Description("Approximate Nearest Neighbors for input vector against a vector column") +public class ApproxNearestNeighborsFunction + implements Provider +{ + @Override + public ConnectorTableFunction get() + { + return new ClassLoaderSafeConnectorTableFunction(new QueryFunction(), getClass().getClassLoader()); + } + + public static final String NAME = "approx_nearest_neighbors"; + public static final String SCHEMA_NAME = "system"; + + public static class QueryFunction + extends AbstractConnectorTableFunction + { + private static final String QUERY_VECTOR = "query_vector"; + private static final String COLUMN_NAME = "column_name"; + private static final String LIMIT = "limit"; + + public QueryFunction() + { + super( + SCHEMA_NAME, + NAME, + List.of( + ScalarArgumentSpecification.builder() + .name(QUERY_VECTOR) + .type(new ArrayType(REAL)) + .build(), + ScalarArgumentSpecification.builder() + .name(COLUMN_NAME) + .type(VARCHAR) + .build(), + ScalarArgumentSpecification.builder() + .name(LIMIT) + .type(BIGINT) + .build()), + GENERIC_TABLE); + } + + @Override + public TableFunctionAnalysis analyze( + ConnectorSession session, + ConnectorTransactionHandle transaction, + Map arguments) + { + Descriptor returnedType = new Descriptor(ImmutableList.of( + new Descriptor.Field("row_id", Optional.of(BIGINT)))); + + ScalarArgument queryVector = (ScalarArgument) arguments.get(QUERY_VECTOR); + ScalarArgument limit = (ScalarArgument) arguments.get(LIMIT); + ScalarArgument qualifiedColumnName = (ScalarArgument) arguments.get(COLUMN_NAME); + + String formattedColumnName = stripSurroundingSingleQuotes(((Slice) qualifiedColumnName.getValue()).toStringUtf8()); + String columnName = parseQualifiedName(formattedColumnName).getColumn(); + String schemaName = parseQualifiedName(formattedColumnName).getSchema(); + String tableName = parseQualifiedName(formattedColumnName).getTable(); + + // --- build column handles --- + List columnHandles = ImmutableList.of( + new IcebergColumnHandle( + new ColumnIdentity(0, "row_id", ColumnIdentity.TypeCategory.PRIMITIVE, ImmutableList.of()), + BIGINT, + Optional.empty(), + REGULAR, + ImmutableList.of())); + + IcebergAnnTableFunctionHandle handle = + new IcebergAnnTableFunctionHandle(schemaName, tableName, queryVector, limit, columnHandles); + + return TableFunctionAnalysis.builder() + .returnedType(returnedType) + .handle(handle) + .build(); + } + + private static String stripSurroundingSingleQuotes(String input) + { + if (input != null && input.length() >= 2 && + input.startsWith("'") && input.endsWith("'")) { + return input.substring(1, input.length() - 1); + } + return input; + } + private static List extractColumnParameters(String columnName) + { + List columnHandles = new ArrayList<>(); + + Type prestoType = createUnboundedVarcharType(); + columnHandles.add(new IcebergColumnHandle(new ColumnIdentity(ThreadLocalRandom.current().nextInt(), columnName, ColumnIdentity.TypeCategory.PRIMITIVE, ImmutableList.of()), + prestoType, + Optional.empty(), + REGULAR, + ImmutableList.of())); + return columnHandles; + } + } + public static class QualifiedNameParts + { + public final String schema; + + public String getSchema() + { + return schema; + } + + public String getTable() + { + return table; + } + + public String getColumn() + { + return column; + } + + public final String table; + public final String column; + + public QualifiedNameParts(String schema, String table, String column) + { + this.schema = schema; + this.table = table; + this.column = column; + } + } + + public static QualifiedNameParts parseQualifiedName(String qualifiedName) + { + String[] parts = qualifiedName.split("\\."); + if (parts.length != 3) { + throw new IllegalArgumentException("Expected format: schema.table.column, but got: " + qualifiedName); + } + return new QualifiedNameParts(parts[0], parts[1], parts[2]); + } + + public static class IcebergAnnTableHandle + extends IcebergTableHandle + { + private final List queryVector; + private final int limit; + public IcebergAnnTableHandle(List queryVector, int limit, String schema, String table) + { + this(queryVector, limit, schema, table, Optional.empty()); + } + public IcebergAnnTableHandle(List queryVector, int limit, String schema, String table, Optional outputPath) + { + super(schema, IcebergTableName.from(table), false, outputPath, Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), ImmutableList.of(), ImmutableList.of(), Optional.empty()); + this.queryVector = queryVector; + this.limit = limit; + } + public List getInputVector() + { + return queryVector; + } + + public int getLimit() + { + return limit; + } + } + public static class IcebergAnnTableFunctionHandle + implements ConnectorTableFunctionHandle + { + private final List queryVector; + private final int limit; + private final ConnectorTableHandle tableHandle; + private final List columnHandles; + + public IcebergAnnTableFunctionHandle(String schema, String table, ScalarArgument inputVector, ScalarArgument limit, List columnHandles) + { + this.queryVector = convertBlockToFloatList((IntArrayBlock) inputVector.getValue()); + this.limit = ((Long) limit.getValue()).intValue(); + this.tableHandle = new IcebergAnnTableHandle(queryVector, this.limit, schema, table); + this.columnHandles = columnHandles; + } + + public List getInputVector() + { + return queryVector; + } + + public int getLimit() + { + return limit; + } + + public ConnectorTableHandle getTableHandle() + { + return this.tableHandle; + } + public List getColumnHandles() + { + return this.columnHandles; + } + private static List parseQueryVector(String valueExpression) + { + if (valueExpression == null || valueExpression.isEmpty()) { + return List.of(); + } + + // Remove leading "ARRAY[" and trailing "]" + String inner = valueExpression.trim(); + if (inner.startsWith("ARRAY[")) { + inner = inner.substring("ARRAY[".length(), inner.length() - 1); + } + + // Split by comma + String[] parts = inner.split(","); + + List result = new ArrayList<>(); + for (String part : parts) { + part = part.trim(); + + // Strip DECIMAL '...' + if (part.startsWith("DECIMAL")) { + int start = part.indexOf('\''); + int end = part.lastIndexOf('\''); + if (start > 0 && end > start) { + part = part.substring(start + 1, end); + } + } + // Convert to float + result.add(Float.parseFloat(part)); + } + return result; + } + private static List convertBlockToFloatList(IntArrayBlock intBlock) + { + int positionCount = intBlock.getPositionCount(); + List result = new ArrayList<>(positionCount); + + for (int position = 0; position < positionCount; position++) { + if (intBlock.isNull(position)) { + result.add(null); + } + else { + // REAL type stores floats as int bits - must convert back + result.add(Float.intBitsToFloat(intBlock.getInt(position))); + } + } + + return result; + } + } +} diff --git a/presto-iceberg/src/main/java/com/facebook/presto/iceberg/vectors/CustomVectorFloat.java b/presto-iceberg/src/main/java/com/facebook/presto/iceberg/vectors/CustomVectorFloat.java new file mode 100644 index 0000000000000..0acceb0da0729 --- /dev/null +++ b/presto-iceberg/src/main/java/com/facebook/presto/iceberg/vectors/CustomVectorFloat.java @@ -0,0 +1,184 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.iceberg.vectors; + +import com.facebook.airlift.log.Logger; +import io.github.jbellis.jvector.vector.types.VectorFloat; + +import java.lang.reflect.Constructor; + +/** + * Custom implementation of VectorFloat that wraps a float[] array. + * This is used because ArrayVectorFloat constructors are not public in JVector 4.0.0-rc.4. + * This class also provides methods to convert to ArrayVectorFloat when needed. + */ +public class CustomVectorFloat + implements VectorFloat +{ + private static final Logger log = Logger.get(CustomVectorFloat.class); + private final float[] values; + private static Constructor arrayVectorFloatConstructor; + // Static initializer to find the ArrayVectorFloat constructor + static { + try { + Class arrayVectorFloatClass = Class.forName("io.github.jbellis.jvector.vector.ArrayVectorFloat"); + arrayVectorFloatConstructor = arrayVectorFloatClass.getDeclaredConstructor(float[].class); + arrayVectorFloatConstructor.setAccessible(true); + } + catch (Exception e) { + try { + Package vectorPackage = VectorFloat.class.getPackage(); + String packageName = vectorPackage.getName(); + Class arrayVectorFloatClass = Class.forName(packageName + ".ArrayVectorFloat"); + arrayVectorFloatConstructor = arrayVectorFloatClass.getDeclaredConstructor(float[].class); + arrayVectorFloatConstructor.setAccessible(true); + } + catch (Exception e2) { + log.warn("Could not get ArrayVectorFloat constructor using package approach: %s", e2.getMessage()); + arrayVectorFloatConstructor = null; + } + } + } + + public CustomVectorFloat(float[] values) + { + this.values = values; + } + /** + * Attempts to convert this CustomVectorFloat to an ArrayVectorFloat instance. + * + * @return An ArrayVectorFloat instance if conversion is possible, or this instance if not + */ + public VectorFloat toArrayVectorFloat() + { + if (arrayVectorFloatConstructor != null) { + try { + float[] copy = new float[values.length]; + System.arraycopy(values, 0, copy, 0, values.length); + return (VectorFloat) arrayVectorFloatConstructor.newInstance((Object) copy); + } + catch (Exception e) { + log.warn("Failed to convert to ArrayVectorFloat: %s", e.getMessage()); + } + } + return this; + } + + /** + * Returns the vector instance. + * Required by the VectorFloat interface. + */ + @Override + public CustomVectorFloat get() + { + return this; + } + + @Override + public float get(int index) + { + return values[index]; + } + + @Override + public void set(int index, float value) + { + values[index] = value; + } + /** + * Returns the length (dimension) of this vector. + */ + @Override + public int length() + { + return values.length; + } + + @Override + public void copyFrom(VectorFloat src, int srcOffset, int destOffset, int length) + { + if (src instanceof CustomVectorFloat) { + float[] srcValues = ((CustomVectorFloat) src).values; + System.arraycopy(srcValues, srcOffset, this.values, destOffset, length); + } + else { + for (int i = 0; i < length; i++) { + this.values[destOffset + i] = src.get(srcOffset + i); + } + } + } + @Override + public void zero() + { + for (int i = 0; i < values.length; i++) { + values[i] = 0.0f; + } + } + + /** + * Returns a hash code value for this vector. + */ + @Override + public int getHashCode() + { + int result = 1; + for (int i = 0; i < this.length(); ++i) { + if (this.get(i) != 0.0F) { + result = 31 * result + Float.hashCode(this.get(i)); + } + } + return result; + } + + /** + * Returns the memory usage of this vector in bytes. + * Required by the Accountable interface. + */ + @Override + public long ramBytesUsed() + { + // Base object overhead (16 bytes) + array reference (8 bytes) + float array size + return 16 + 8 + (long) values.length * Float.BYTES; + } + + /** + * Returns a copy of the internal float array. + */ + public float[] vectorValue() + { + float[] copy = new float[values.length]; + System.arraycopy(values, 0, copy, 0, values.length); + return copy; + } + + /** + * Creates a copy of this vector with its own independent float array. + * @return A new CustomVectorFloat with a copy of the data + */ + @Override + public CustomVectorFloat copy() + { + float[] copy = new float[values.length]; + System.arraycopy(values, 0, copy, 0, values.length); + return new CustomVectorFloat(copy); + } + /** + * Returns the raw float array backing this vector. + * @return The internal float array (not a copy) + */ + public float[] getFloatArray() + { + return values; + } +} diff --git a/presto-iceberg/src/main/java/com/facebook/presto/iceberg/vectors/IcebergVectorIndexBuilder.java b/presto-iceberg/src/main/java/com/facebook/presto/iceberg/vectors/IcebergVectorIndexBuilder.java new file mode 100644 index 0000000000000..38d5bad2c0efb --- /dev/null +++ b/presto-iceberg/src/main/java/com/facebook/presto/iceberg/vectors/IcebergVectorIndexBuilder.java @@ -0,0 +1,472 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.iceberg.vectors; + +import com.facebook.airlift.log.Logger; +import com.facebook.presto.common.Page; +import com.facebook.presto.common.RuntimeStats; +import com.facebook.presto.common.block.Block; +import com.facebook.presto.common.predicate.TupleDomain; +import com.facebook.presto.common.type.ArrayType; +import com.facebook.presto.common.type.BooleanType; +import com.facebook.presto.common.type.Type; +import com.facebook.presto.common.type.TypeManager; +import com.facebook.presto.iceberg.HdfsFileIO; +import com.facebook.presto.iceberg.HdfsOutputFile; +import com.facebook.presto.iceberg.IcebergAbstractMetadata; +import com.facebook.presto.iceberg.IcebergColumnHandle; +import com.facebook.presto.iceberg.IcebergSplit; +import com.facebook.presto.iceberg.IcebergTableHandle; +import com.facebook.presto.iceberg.IcebergTableLayoutHandle; +import com.facebook.presto.iceberg.IcebergUtil; +import com.facebook.presto.iceberg.PartitionData; +import com.facebook.presto.iceberg.RuntimeStatsMetricsReporter; +import com.facebook.presto.iceberg.delete.DeleteFile; +import com.facebook.presto.spi.ConnectorPageSource; +import com.facebook.presto.spi.ConnectorSession; +import com.facebook.presto.spi.SchemaTableName; +import com.facebook.presto.spi.SplitContext; +import com.facebook.presto.spi.SplitWeight; +import com.facebook.presto.spi.connector.ConnectorMetadata; +import com.facebook.presto.spi.connector.ConnectorPageSourceProvider; +import com.facebook.presto.spi.connector.ConnectorTransactionHandle; +import com.facebook.presto.spi.relation.ConstantExpression; +import com.facebook.presto.spi.schedule.NodeSelectionStrategy; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import io.github.jbellis.jvector.graph.GraphIndexBuilder; +import io.github.jbellis.jvector.graph.ImmutableGraphIndex; +import io.github.jbellis.jvector.graph.RandomAccessVectorValues; +import io.github.jbellis.jvector.graph.disk.OnDiskGraphIndex; +import io.github.jbellis.jvector.graph.similarity.BuildScoreProvider; +import io.github.jbellis.jvector.vector.VectorSimilarityFunction; +import io.github.jbellis.jvector.vector.VectorUtil; +import org.apache.iceberg.BaseTable; +import org.apache.iceberg.FileScanTask; +import org.apache.iceberg.PartitionSpecParser; +import org.apache.iceberg.Table; +import org.apache.iceberg.TableScan; +import org.apache.iceberg.io.CloseableIterable; + +import java.io.IOException; +import java.io.OutputStream; +import java.nio.file.Files; +import java.nio.file.Path; +import java.nio.file.Paths; +import java.util.ArrayList; +import java.util.List; +import java.util.Optional; + +import static com.google.common.collect.ImmutableList.toImmutableList; + +/** + * Utility class for building vector indexes from Iceberg table data. + */ +public class IcebergVectorIndexBuilder +{ + private static final Logger log = Logger.get(IcebergVectorIndexBuilder.class); + private static final String VECTOR_INDEX_DIR = ".vector_index"; + private static final String ROW_ID_COLUMN = "row_id"; + + private IcebergVectorIndexBuilder() {} + + /** + * Internal class to hold vector data along with corresponding row IDs. + */ + private static class VectorData + { + final List vectors; + final List rowIds; + + VectorData(List vectors, List rowIds) + { + if (vectors.size() != rowIds.size()) { + throw new IllegalStateException( + String.format("Vectors and row IDs must have the same size. Vectors: %d, Row IDs: %d", + vectors.size(), rowIds.size())); + } + this.vectors = vectors; + this.rowIds = rowIds; + } + } + + /** + * Builds a vector index from an Iceberg table column and saves it to S3. + * The index is saved to the table's data location using Iceberg's FileIO, + * which automatically handles S3 configuration. + * Path format: [table_data_location]/.vector_index/vector_index.hnsw + * + * @param metadata The connector metadata + * @param pageSourceProvider The page source provider + * @param transactionHandle The transaction handle + * @param session The connector session + * @param schemaTableName The schema and table name + * @param columnName The name of the column containing vector data + * @param indexName The name of the index + * @param catalogName The catalog name + * @param similarityFunction The similarity function to use for the index + * @param m The maximum number of connections per node in the graph + * @param efConstruction The size of the dynamic candidate list during construction + * @return The path to the saved index file + */ + public static Path buildAndSaveVectorIndex( + ConnectorMetadata metadata, + ConnectorPageSourceProvider pageSourceProvider, + ConnectorTransactionHandle transactionHandle, + ConnectorSession session, + SchemaTableName schemaTableName, + String columnName, + String indexName, + String catalogName, + String similarityFunction, + int m, + int efConstruction) throws Exception + { + // 1. Get the Iceberg table + Table icebergTable = IcebergUtil.getIcebergTable(metadata, session, schemaTableName); + // Get the table's HdfsFileIO for S3 operations + HdfsFileIO hdfsFileIO = (HdfsFileIO) icebergTable.io(); + // Get the table's location + String tableLocation; + if (icebergTable instanceof BaseTable) { + // Get location from metadata without filesystem access + tableLocation = ((BaseTable) icebergTable).operations().current().location(); + } + else { + // Fallback to direct call + tableLocation = icebergTable.location(); + } + // Compute the S3 path for the index + String indexDirPath = tableLocation + "/" + VECTOR_INDEX_DIR; + String indexFileName = indexName + ".hnsw"; + String indexPath = indexDirPath + "/" + indexFileName; + log.info("Vector index will be saved to S3 path: %s", indexPath); + log.info("Table location: %s", tableLocation); + // 2. Read vectors and row IDs from the table + VectorData vectorData = readVectorsFromTable( + metadata, + pageSourceProvider, + transactionHandle, + session, + schemaTableName, + columnName); + if (vectorData.vectors.isEmpty()) { + throw new IllegalStateException("No vectors found in column: " + columnName); + } + log.info("Read %d vectors with corresponding row IDs", vectorData.vectors.size()); + // Normalize all vectors using L2 normalization + log.info("Normalizing vectors using L2 normalization"); + for (float[] vector : vectorData.vectors) { + CustomVectorFloat customVector = new CustomVectorFloat(vector); + VectorUtil.l2normalize(customVector.toArrayVectorFloat()); + } + // 3. Create vector values wrapper + int dimension = vectorData.vectors.get(0).length; + RandomAccessVectorValues ravv = new ListRandomAccessVectorValues(vectorData.vectors, dimension); + // 4. Create similarity function + VectorSimilarityFunction simFunction = getVectorSimilarityFunction(similarityFunction); + // 5. Build the index + log.info("Building vector index with %d vectors of dimension %d", vectorData.vectors.size(), dimension); + // Create a BuildScoreProvider from the RandomAccessVectorValues and similarity function + BuildScoreProvider bsp = BuildScoreProvider.randomAccessScoreProvider(ravv, simFunction); + GraphIndexBuilder builder = new GraphIndexBuilder(bsp, ravv.dimension(), m, efConstruction, 4 * m, 1.2f, false, true); + ImmutableGraphIndex index = builder.build(ravv); + log.info("Vector index built successfully with %d nodes", index.size()); + // 6. Create node-to-row ID mapping + NodeRowIdMapping mapping = new NodeRowIdMapping(vectorData.rowIds); + log.info("Created node-to-row ID mapping with %d entries", mapping.size()); + // 7. Prepare mapping file path + String mappingFileName = indexName + "_mapping.bin"; + String mappingPath = indexDirPath + "/" + mappingFileName; + try { + hdfsFileIO.newInputFile(indexPath).newStream().close(); + log.info("Existing index file found at %s - will be replaced", indexPath); + } + catch (Exception e) { + log.info("No existing index file found at %s - creating new", indexPath); + } + try { + hdfsFileIO.newInputFile(mappingPath).newStream().close(); + log.info("Existing mapping file found at %s - will be replaced", mappingPath); + } + catch (Exception e) { + log.info("No existing mapping file found at %s - creating new", mappingPath); + } + // 8. Save index and mapping to S3 using Iceberg's FileIO with retry logic + int maxRetries = 3; + int retryDelayMs = 1000; + Path localTempIndexPath = null; + Path localTempMappingPath = null; + for (int attempt = 1; attempt <= maxRetries; attempt++) { + try { + log.info("Saving index and mapping to S3 (attempt %d of %d)", attempt, maxRetries); + // Create local temporary files + localTempIndexPath = Files.createTempFile("vector-index-", ".tmp"); + localTempMappingPath = Files.createTempFile("vector-mapping-", ".tmp"); + try { + // Write the index to local temporary file + log.info("Writing index to local temporary file: %s", localTempIndexPath); + OnDiskGraphIndex.write(index, ravv, localTempIndexPath); + + // Write the mapping to local temporary file + log.info("Writing mapping to local temporary file: %s", localTempMappingPath); + try (OutputStream mappingOut = Files.newOutputStream(localTempMappingPath)) { + mapping.save(mappingOut); + } + + // Upload index to S3 + log.info("Uploading index to S3: %s", indexPath); + HdfsOutputFile indexOutputFile = (HdfsOutputFile) hdfsFileIO.newOutputFile(indexPath); + try (OutputStream out = indexOutputFile.createOrOverwrite()) { + Files.copy(localTempIndexPath, out); + } + log.info("Vector index saved successfully to S3: %s", indexPath); + // Upload mapping to S3 + log.info("Uploading mapping to S3: %s", mappingPath); + HdfsOutputFile mappingOutputFile = (HdfsOutputFile) hdfsFileIO.newOutputFile(mappingPath); + try (OutputStream out = mappingOutputFile.createOrOverwrite()) { + Files.copy(localTempMappingPath, out); + } + log.info("Node-to-row ID mapping saved successfully to S3: %s", mappingPath); + return Paths.get(indexPath); + } + finally { + // Clean up local temporary files + if (localTempIndexPath != null) { + try { + Files.delete(localTempIndexPath); + } + catch (IOException e) { + log.warn(e, "Failed to delete local temporary index file: %s", localTempIndexPath); + } + } + if (localTempMappingPath != null) { + try { + Files.delete(localTempMappingPath); + } + catch (IOException e) { + log.warn(e, "Failed to delete local temporary mapping file: %s", localTempMappingPath); + } + } + } + } + catch (Exception e) { + log.error(e, "Error saving index and mapping to S3 (attempt %d of %d): %s", + attempt, maxRetries, e.getMessage()); + if (attempt < maxRetries) { + try { + Thread.sleep(retryDelayMs * attempt); + } + catch (InterruptedException ie) { + Thread.currentThread().interrupt(); + } + } + else { + throw new RuntimeException("Failed to save vector index and mapping to S3 after " + maxRetries + " attempts", e); + } + } + } + throw new RuntimeException("Failed to save vector index and mapping to S3 after " + maxRetries + " attempts"); + } + private static VectorSimilarityFunction getVectorSimilarityFunction(String similarityFunction) + { + switch (similarityFunction.toUpperCase()) { + case "COSINE": + return VectorSimilarityFunction.COSINE; + case "DOT_PRODUCT": + return VectorSimilarityFunction.DOT_PRODUCT; + case "EUCLIDEAN": + return VectorSimilarityFunction.EUCLIDEAN; + default: + throw new IllegalArgumentException("Unsupported similarity function: " + similarityFunction); + } + } + /** + * Reads vector data and row IDs from an Iceberg table. + * + * @return VectorData containing both vectors and their corresponding row IDs + */ + private static VectorData readVectorsFromTable( + ConnectorMetadata metadata, + ConnectorPageSourceProvider pageSourceProvider, + ConnectorTransactionHandle transactionHandle, + ConnectorSession session, + SchemaTableName schemaTableName, + String columnName) throws IOException + { + // Get the table handle + IcebergTableHandle tableHandle = (IcebergTableHandle) metadata.getTableHandle(session, schemaTableName); + if (tableHandle == null) { + throw new IllegalArgumentException("Table not found: " + schemaTableName); + } + // Get TypeManager from metadata + TypeManager typeManager = null; + if (metadata instanceof IcebergAbstractMetadata) { + typeManager = ((IcebergAbstractMetadata) metadata).getTypeManager(); + } + if (typeManager == null) { + throw new IllegalStateException("Could not get TypeManager from metadata"); + } + // Get the Iceberg table + Table icebergTable = IcebergUtil.getIcebergTable(metadata, session, schemaTableName); + // Get all columns + List columns = IcebergUtil.getColumns( + icebergTable.schema(), + icebergTable.spec(), + typeManager); + // Find the target vector column + IcebergColumnHandle targetColumn = null; + for (IcebergColumnHandle column : columns) { + if (column.getName().equals(columnName)) { + targetColumn = column; + break; + } + } + if (targetColumn == null) { + throw new IllegalArgumentException("Vector column not found: " + columnName); + } + // Verify vector column type + Type columnType = targetColumn.getType(); + if (!(columnType instanceof ArrayType)) { + throw new IllegalArgumentException("Vector column must be an array type: " + columnName); + } + // Find the row_id column + IcebergColumnHandle rowIdColumn = null; + for (IcebergColumnHandle column : columns) { + if (column.getName().equals(ROW_ID_COLUMN)) { + rowIdColumn = column; + break; + } + } + if (rowIdColumn == null) { + throw new IllegalArgumentException( + String.format("Required column '%s' not found in table %s. " + + "All tables using vector indexing must have a '%s' column of type BIGINT.", + ROW_ID_COLUMN, schemaTableName, ROW_ID_COLUMN)); + } + // Verify row_id column type + if (!rowIdColumn.getType().equals(com.facebook.presto.common.type.BigintType.BIGINT)) { + throw new IllegalArgumentException( + String.format("Column '%s' must be of type BIGINT, but found: %s", + ROW_ID_COLUMN, rowIdColumn.getType())); + } + log.info("Reading vectors from column '%s' and row IDs from column '%s'", columnName, ROW_ID_COLUMN); + // Create a table layout handle + IcebergTableLayoutHandle layoutHandle = createTableLayoutHandle(tableHandle, columns); + // Use TableScan API to get actual data files + List vectors = new ArrayList<>(); + List rowIds = new ArrayList<>(); + // Create a table scan to get the data files + TableScan tableScan = icebergTable.newScan() + .metricsReporter(new RuntimeStatsMetricsReporter(session.getRuntimeStats())); + // If the table has a current snapshot, use it + if (icebergTable.currentSnapshot() != null) { + tableScan = tableScan.useSnapshot(icebergTable.currentSnapshot().snapshotId()); + } + // Get the data files + try (CloseableIterable fileScanTasks = tableScan.planFiles()) { + for (FileScanTask fileScanTask : fileScanTasks) { + // Create a split for each data file + IcebergSplit split = new IcebergSplit( + fileScanTask.file().path().toString(), // Use the actual file path + fileScanTask.start(), + fileScanTask.length(), + IcebergUtil.getFileFormat(icebergTable), + ImmutableList.of(), + IcebergUtil.getPartitionKeys(fileScanTask), + PartitionSpecParser.toJson(fileScanTask.spec()), + IcebergUtil.partitionDataFromStructLike(fileScanTask.spec(), fileScanTask.file().partition()).map(PartitionData::toJson), + NodeSelectionStrategy.NO_PREFERENCE, + SplitWeight.standard(), + fileScanTask.deletes().stream().map(DeleteFile::fromIceberg).collect(toImmutableList()), + Optional.empty(), + IcebergUtil.getDataSequenceNumber(fileScanTask.file()), + 1, + false, + ImmutableList.of(), + 0); // Use 1 to avoid division by zero + // Read the data from this split - read both vector and row_id columns + try (ConnectorPageSource pageSource = pageSourceProvider.createPageSource( + transactionHandle, + session, + split, + layoutHandle, + ImmutableList.of(targetColumn, rowIdColumn), + new SplitContext(false), + new RuntimeStats())) { + while (!pageSource.isFinished()) { + Page page = pageSource.getNextPage(); + if (page == null) { + continue; + } + // Block 0 is the vector column, Block 1 is the row_id column + Block vectorBlock = page.getBlock(0); + Block rowIdBlock = page.getBlock(1); + for (int position = 0; position < vectorBlock.getPositionCount(); position++) { + // Skip if vector is null + if (vectorBlock.isNull(position)) { + continue; + } + // Skip if row_id is null + if (rowIdBlock.isNull(position)) { + log.warn("Skipping row at position %d: row_id is null", position); + continue; + } + Block arrayBlock = vectorBlock.getBlock(position); + if (arrayBlock.getPositionCount() == 0) { + continue; + } + // Read the vector + float[] vector = new float[arrayBlock.getPositionCount()]; + for (int i = 0; i < arrayBlock.getPositionCount(); i++) { + if (arrayBlock.isNull(i)) { + vector[i] = 0.0f; + } + else { + vector[i] = ((Number) ((ArrayType) columnType).getElementType().getObjectValue( + session.getSqlFunctionProperties(), arrayBlock, i)).floatValue(); + } + } + // Read the row_id + long rowId = com.facebook.presto.common.type.BigintType.BIGINT.getLong(rowIdBlock, position); + // Add both to their respective lists (maintaining order) + vectors.add(vector); + rowIds.add(rowId); + } + } + } + } + } + log.info("Read %d vectors with %d row IDs from table", vectors.size(), rowIds.size()); + return new VectorData(vectors, rowIds); + } + /** + * Creates a minimal IcebergTableLayoutHandle for reading data. + */ + private static IcebergTableLayoutHandle createTableLayoutHandle(IcebergTableHandle tableHandle, List columns) + { + return new IcebergTableLayoutHandle.Builder() + .setPartitionColumns(ImmutableList.of()) + .setDataColumns(ImmutableList.of()) + .setDomainPredicate(TupleDomain.all()) + .setRemainingPredicate(new ConstantExpression(true, BooleanType.BOOLEAN)) + .setPredicateColumns(ImmutableMap.of()) + .setRequestedColumns(Optional.empty()) + .setPushdownFilterEnabled(false) + .setPartitionColumnPredicate(TupleDomain.all()) + .setPartitions(Optional.empty()) + .setTable(tableHandle) + .build(); + } +} diff --git a/presto-iceberg/src/main/java/com/facebook/presto/iceberg/vectors/ListRandomAccessVectorValues.java b/presto-iceberg/src/main/java/com/facebook/presto/iceberg/vectors/ListRandomAccessVectorValues.java new file mode 100644 index 0000000000000..3a86cf4bb9e78 --- /dev/null +++ b/presto-iceberg/src/main/java/com/facebook/presto/iceberg/vectors/ListRandomAccessVectorValues.java @@ -0,0 +1,133 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.iceberg.vectors; + +import com.facebook.airlift.log.Logger; +import io.github.jbellis.jvector.graph.RandomAccessVectorValues; +import io.github.jbellis.jvector.vector.ArrayVectorFloat; +import io.github.jbellis.jvector.vector.types.VectorFloat; + +import java.lang.reflect.Constructor; +import java.util.List; + +/** + * Implementation of RandomAccessVectorValues that wraps a list of vectors. + * This is used to provide random access to vectors read from Iceberg tables. + */ +public class ListRandomAccessVectorValues + implements RandomAccessVectorValues +{ + private static final Logger log = Logger.get(ListRandomAccessVectorValues.class); + private final List vectors; + private final int dimension; + private final Constructor arrayVectorFloatConstructor; + + public ListRandomAccessVectorValues(List vectors, int dimension) + { + this.vectors = vectors; + this.dimension = dimension; + // Trying to get the ArrayVectorFloat constructor using reflection with multiple approaches + Constructor constructor = null; + // First approach: direct class name + try { + Class arrayVectorFloatClass = ArrayVectorFloat.class; + constructor = arrayVectorFloatClass.getDeclaredConstructor(float[].class); + constructor.setAccessible(true); + log.info("Found ArrayVectorFloat constructor using direct class name"); + } + catch (Exception e) { + log.warn("Could not get ArrayVectorFloat constructor using direct class name: %s", e.getMessage()); + // Second approach: try to find the class from VectorFloat package + try { + // Get the package from VectorFloat + Package vectorPackage = VectorFloat.class.getPackage(); + String packageName = vectorPackage.getName(); + // Try to load the class from the same package + Class arrayVectorFloatClass = Class.forName(packageName + ".ArrayVectorFloat"); + constructor = arrayVectorFloatClass.getDeclaredConstructor(float[].class); + constructor.setAccessible(true); + log.info("Found ArrayVectorFloat constructor using package name: %s", packageName); + } + catch (Exception e2) { + log.warn("Could not get ArrayVectorFloat constructor using package approach: %s", e2.getMessage()); + } + } + this.arrayVectorFloatConstructor = constructor; + // Log whether we found the constructor + if (constructor != null) { + log.info("Successfully found ArrayVectorFloat constructor"); + } + else { + log.warn("Could not find ArrayVectorFloat constructor, will use CustomVectorFloat as fallback"); + } + } + + @Override + public int size() + { + return vectors.size(); + } + + @Override + public int dimension() + { + return dimension; + } + + @Override + public VectorFloat getVector(int ord) + { + // Get the vector data + float[] data = vectors.get(ord); + // Try to create an ArrayVectorFloat using reflection + if (arrayVectorFloatConstructor != null) { + try { + Object vectorInstance = arrayVectorFloatConstructor.newInstance((Object) data); + log.debug("Successfully created ArrayVectorFloat instance for ord=%d", ord); + return (VectorFloat) vectorInstance; + } + catch (Exception e) { + // Log the full stack trace for better debugging + log.warn("Could not create ArrayVectorFloat: %s", e.getMessage()); + log.debug("Exception creating ArrayVectorFloat", e); + // Try again with a copy of the data + try { + float[] dataCopy = new float[data.length]; + System.arraycopy(data, 0, dataCopy, 0, data.length); + Object vectorInstance = arrayVectorFloatConstructor.newInstance((Object) dataCopy); + log.debug("Successfully created ArrayVectorFloat with data copy for ord=%d", ord); + return (VectorFloat) vectorInstance; + } + catch (Exception e2) { + log.warn("Could not create ArrayVectorFloat with data copy: %s", e2.getMessage()); + } + } + } + // Fall back to custom implementation if reflection fails + log.debug("Using CustomVectorFloat fallback for ord=%d", ord); + return new CustomVectorFloat(data); + } + + @Override + public boolean isValueShared() + { + return false; + } + + @Override + public RandomAccessVectorValues copy() + { + return this; + } +} diff --git a/presto-iceberg/src/main/java/com/facebook/presto/iceberg/vectors/NodeRowIdMapping.java b/presto-iceberg/src/main/java/com/facebook/presto/iceberg/vectors/NodeRowIdMapping.java new file mode 100644 index 0000000000000..36faf74029a69 --- /dev/null +++ b/presto-iceberg/src/main/java/com/facebook/presto/iceberg/vectors/NodeRowIdMapping.java @@ -0,0 +1,181 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.iceberg.vectors; + +import com.facebook.airlift.log.Logger; + +import java.io.BufferedInputStream; +import java.io.BufferedOutputStream; +import java.io.DataInputStream; +import java.io.DataOutputStream; +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; +import java.nio.ByteBuffer; +import java.nio.ByteOrder; +import java.nio.LongBuffer; +import java.util.List; + +/** + * Maintains the mapping between jvector node IDs and database row IDs. + */ +public class NodeRowIdMapping +{ + private static final Logger log = Logger.get(NodeRowIdMapping.class); + private static final int VERSION = 1; + private static final int BUFFER_SIZE = 65536; // 64KB buffer for I/O + + private final long[] nodeToRowId; + + /** + * Creates a new mapping from a list of row IDs. + * The list order must match the order vectors were added to the index. + * @param rowIds List of row IDs in the same order as vectors were added to the index + */ + public NodeRowIdMapping(List rowIds) + { + if (rowIds == null || rowIds.isEmpty()) { + throw new IllegalArgumentException("Row IDs list cannot be null or empty"); + } + + this.nodeToRowId = new long[rowIds.size()]; + for (int i = 0; i < rowIds.size(); i++) { + this.nodeToRowId[i] = rowIds.get(i); + } + + log.info("Created NodeRowIdMapping with %d entries", nodeToRowId.length); + } + + private NodeRowIdMapping(long[] nodeToRowId) + { + this.nodeToRowId = nodeToRowId; + } + + /** + * Gets the row ID for a given node ID. + * @param nodeId The node ID from the vector index (0-based) + * @return The corresponding row ID from the database table + */ + public long getRowId(int nodeId) + { + if (nodeId < 0 || nodeId >= nodeToRowId.length) { + throw new IndexOutOfBoundsException( + String.format("Invalid node ID %d. Valid range: [0, %d)", + nodeId, nodeToRowId.length)); + } + return nodeToRowId[nodeId]; + } + + /** + * Returns the number of mappings. + */ + public int size() + { + return nodeToRowId.length; + } + + /** + * Saves the mapping to an output stream in binary format. + * @param out The output stream to write to + */ + public void save(OutputStream out) throws IOException + { + BufferedOutputStream bufferedOut = new BufferedOutputStream(out, BUFFER_SIZE); + DataOutputStream dos = new DataOutputStream(bufferedOut); + + try { + // Write header (version and size) + dos.writeInt(VERSION); + dos.writeInt(nodeToRowId.length); + + // Write all row IDs in using ByteBuffer + int dataSize = nodeToRowId.length * 8; + ByteBuffer buffer = ByteBuffer.allocate(Math.min(dataSize, BUFFER_SIZE)); + buffer.order(ByteOrder.BIG_ENDIAN); + + int offset = 0; + while (offset < nodeToRowId.length) { + buffer.clear(); + LongBuffer longBuffer = buffer.asLongBuffer(); + + int remaining = nodeToRowId.length - offset; + int count = Math.min(remaining, BUFFER_SIZE / 8); + longBuffer.put(nodeToRowId, offset, count); + + // Write buffer to stream + int bytesToWrite = count * 8; + dos.write(buffer.array(), 0, bytesToWrite); + + offset += count; + } + + dos.flush(); + } + finally { + dos.flush(); + } + } + + /** + Loads mapping using bulk I/O operations. + */ + public static NodeRowIdMapping load(InputStream in) throws IOException + { + BufferedInputStream bufferedIn = new BufferedInputStream(in, BUFFER_SIZE); + DataInputStream dis = new DataInputStream(bufferedIn); + + int version = dis.readInt(); + if (version != VERSION) { + throw new IOException( + String.format("Unsupported mapping file version: %d (expected %d)", + version, VERSION)); + } + + int size = dis.readInt(); + if (size <= 0) { + throw new IOException("Invalid mapping size: " + size); + } + + //Read all row IDs in using ByteBuffer + long[] nodeToRowId = new long[size]; + int dataSize = size * 8; + + byte[] buffer = new byte[Math.min(dataSize, BUFFER_SIZE)]; + + int offset = 0; + while (offset < size) { + int remaining = size - offset; + int count = Math.min(remaining, BUFFER_SIZE / 8); + int bytesToRead = count * 8; + + dis.readFully(buffer, 0, bytesToRead); + + // Convert bytes to longs + ByteBuffer byteBuffer = ByteBuffer.wrap(buffer, 0, bytesToRead); + byteBuffer.order(ByteOrder.BIG_ENDIAN); + LongBuffer longBuffer = byteBuffer.asLongBuffer(); + longBuffer.get(nodeToRowId, offset, count); + + offset += count; + } + + return new NodeRowIdMapping(nodeToRowId); + } + + @Override + public String toString() + { + return String.format("NodeRowIdMapping[size=%d]", nodeToRowId.length); + } +} diff --git a/presto-iceberg/src/test/java/com/facebook/presto/iceberg/TestIcebergConfig.java b/presto-iceberg/src/test/java/com/facebook/presto/iceberg/TestIcebergConfig.java index bf25891c38d57..fd9dccf8c07b1 100644 --- a/presto-iceberg/src/test/java/com/facebook/presto/iceberg/TestIcebergConfig.java +++ b/presto-iceberg/src/test/java/com/facebook/presto/iceberg/TestIcebergConfig.java @@ -74,7 +74,8 @@ public void testDefaults() .setManifestCacheMaxChunkSize(succinctDataSize(2, MEGABYTE)) .setMaxStatisticsFileCacheSize(succinctDataSize(256, MEGABYTE)) .setStatisticsKllSketchKParameter(1024) - .setMaterializedViewStoragePrefix("__mv_storage__")); + .setMaterializedViewStoragePrefix("__mv_storage__") + .setSimilaritySearchEnabled(false)); } @Test @@ -111,6 +112,7 @@ public void testExplicitPropertyMappings() .put("iceberg.max-statistics-file-cache-size", "512MB") .put("iceberg.statistics-kll-sketch-k-parameter", "4096") .put("iceberg.materialized-view-storage-prefix", "custom_mv_prefix") + .put("iceberg.similarity-search-enabled", "true") .build(); IcebergConfig expected = new IcebergConfig() @@ -143,7 +145,8 @@ public void testExplicitPropertyMappings() .setMetricsMaxInferredColumn(16) .setMaxStatisticsFileCacheSize(succinctDataSize(512, MEGABYTE)) .setStatisticsKllSketchKParameter(4096) - .setMaterializedViewStoragePrefix("custom_mv_prefix"); + .setMaterializedViewStoragePrefix("custom_mv_prefix") + .setSimilaritySearchEnabled(true); assertFullMapping(properties, expected); } diff --git a/presto-iceberg/src/test/java/com/facebook/presto/iceberg/hive/TestRenameTableOnFragileFileSystem.java b/presto-iceberg/src/test/java/com/facebook/presto/iceberg/hive/TestRenameTableOnFragileFileSystem.java index 46d4b7dc7478b..7da65190ab92b 100644 --- a/presto-iceberg/src/test/java/com/facebook/presto/iceberg/hive/TestRenameTableOnFragileFileSystem.java +++ b/presto-iceberg/src/test/java/com/facebook/presto/iceberg/hive/TestRenameTableOnFragileFileSystem.java @@ -422,7 +422,8 @@ private ConnectorMetadata getIcebergHiveMetadata(ExtendedHiveMetastore metastore new StatisticsFileCache(CacheBuilder.newBuilder().build()), new ManifestFileCache(CacheBuilder.newBuilder().build(), false, 0, 1024), new IcebergTableProperties(new IcebergConfig()), - () -> false); + () -> false, + new IcebergConfig()); return icebergHiveMetadataFactory.create(); } diff --git a/presto-native-execution/presto_cpp/presto_protocol/connector/iceberg/presto_protocol_iceberg.cpp b/presto-native-execution/presto_cpp/presto_protocol/connector/iceberg/presto_protocol_iceberg.cpp index 228d31f569937..d58cb3c0683a6 100644 --- a/presto-native-execution/presto_cpp/presto_protocol/connector/iceberg/presto_protocol_iceberg.cpp +++ b/presto-native-execution/presto_cpp/presto_protocol/connector/iceberg/presto_protocol_iceberg.cpp @@ -1172,6 +1172,15 @@ void to_json(json& j, const IcebergSplit& p) { "IcebergSplit", "int64_t", "affinitySchedulingSectionSize"); + to_json_key(j, "ann", p.ann, "IcebergSplit", "bool", "ann"); + to_json_key( + j, + "queryVector", + p.queryVector, + "IcebergSplit", + "List", + "queryVector"); + to_json_key(j, "topN", p.topN, "IcebergSplit", "int", "topN"); } void from_json(const json& j, IcebergSplit& p) { @@ -1251,6 +1260,15 @@ void from_json(const json& j, IcebergSplit& p) { "IcebergSplit", "int64_t", "affinitySchedulingSectionSize"); + from_json_key(j, "ann", p.ann, "IcebergSplit", "bool", "ann"); + from_json_key( + j, + "queryVector", + p.queryVector, + "IcebergSplit", + "List", + "queryVector"); + from_json_key(j, "topN", p.topN, "IcebergSplit", "int", "topN"); } } // namespace facebook::presto::protocol::iceberg namespace facebook::presto::protocol::iceberg { diff --git a/presto-native-execution/presto_cpp/presto_protocol/connector/iceberg/presto_protocol_iceberg.h b/presto-native-execution/presto_cpp/presto_protocol/connector/iceberg/presto_protocol_iceberg.h index a659dc24d103b..11d9cc13b238d 100644 --- a/presto-native-execution/presto_cpp/presto_protocol/connector/iceberg/presto_protocol_iceberg.h +++ b/presto-native-execution/presto_cpp/presto_protocol/connector/iceberg/presto_protocol_iceberg.h @@ -253,6 +253,9 @@ struct IcebergSplit : public ConnectorSplit { std::shared_ptr changelogSplitInfo = {}; int64_t dataSequenceNumber = {}; int64_t affinitySchedulingSectionSize = {}; + bool ann = {}; + List queryVector = {}; + int32_t topN = {}; IcebergSplit() noexcept; }; diff --git a/presto-native-execution/presto_cpp/presto_protocol/connector/iceberg/special/IcebergSplit.hpp.inc b/presto-native-execution/presto_cpp/presto_protocol/connector/iceberg/special/IcebergSplit.hpp.inc index 471beab8d2805..2ab6b0462f604 100644 --- a/presto-native-execution/presto_cpp/presto_protocol/connector/iceberg/special/IcebergSplit.hpp.inc +++ b/presto-native-execution/presto_cpp/presto_protocol/connector/iceberg/special/IcebergSplit.hpp.inc @@ -31,6 +31,9 @@ struct IcebergSplit : public ConnectorSplit { std::shared_ptr changelogSplitInfo = {}; int64_t dataSequenceNumber = {}; int64_t affinitySchedulingSectionSize = {}; + bool ann = {}; + List queryVector = {}; + int32_t topN = {}; IcebergSplit() noexcept; }; diff --git a/presto-spi/src/main/java/com/facebook/presto/spi/connector/classloader/ClassLoaderSafeConnectorTableFunction.java b/presto-spi/src/main/java/com/facebook/presto/spi/connector/classloader/ClassLoaderSafeConnectorTableFunction.java new file mode 100644 index 0000000000000..56a9e818b3ed9 --- /dev/null +++ b/presto-spi/src/main/java/com/facebook/presto/spi/connector/classloader/ClassLoaderSafeConnectorTableFunction.java @@ -0,0 +1,83 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.spi.connector.classloader; + +import com.facebook.presto.spi.ConnectorSession; +import com.facebook.presto.spi.classloader.ThreadContextClassLoader; +import com.facebook.presto.spi.connector.ConnectorTransactionHandle; +import com.facebook.presto.spi.function.table.Argument; +import com.facebook.presto.spi.function.table.ArgumentSpecification; +import com.facebook.presto.spi.function.table.ConnectorTableFunction; +import com.facebook.presto.spi.function.table.ReturnTypeSpecification; +import com.facebook.presto.spi.function.table.TableFunctionAnalysis; + +import java.util.List; +import java.util.Map; + +import static java.util.Objects.requireNonNull; + +public class ClassLoaderSafeConnectorTableFunction + implements ConnectorTableFunction +{ + private final ConnectorTableFunction delegate; + private final ClassLoader classLoader; + + public ClassLoaderSafeConnectorTableFunction(ConnectorTableFunction delegate, ClassLoader classLoader) + { + this.delegate = requireNonNull(delegate, "delegate is null"); + this.classLoader = requireNonNull(classLoader, "classLoader is null"); + } + + @Override + public String getSchema() + { + try (ThreadContextClassLoader a = new ThreadContextClassLoader(classLoader)) { + return delegate.getSchema(); + } + } + + @Override + public String getName() + { + try (ThreadContextClassLoader a = new ThreadContextClassLoader(classLoader)) { + return delegate.getName(); + } + } + + @Override + public List getArguments() + { + try (ThreadContextClassLoader a = new ThreadContextClassLoader(classLoader)) { + return delegate.getArguments(); + } + } + + @Override + public ReturnTypeSpecification getReturnTypeSpecification() + { + try (ThreadContextClassLoader a = new ThreadContextClassLoader(classLoader)) { + return delegate.getReturnTypeSpecification(); + } + } + + @Override + public TableFunctionAnalysis analyze(ConnectorSession session, + ConnectorTransactionHandle transaction, + Map arguments) + { + try (ThreadContextClassLoader a = new ThreadContextClassLoader(classLoader)) { + return delegate.analyze(session, transaction, arguments); + } + } +}