From ebb8f4533ab4fce2aae14fb6e86d958a12832fde Mon Sep 17 00:00:00 2001 From: Ceng <441651826@qq.com> Date: Tue, 14 Jan 2025 18:06:51 +0800 Subject: [PATCH] [Spark/NativeIO] fix lsh hard coding & apply lsh compute with schema metadata (#571) Signed-off-by: zenghua Co-authored-by: zenghua --- .../lakesoul/meta/LakeSoulOptions.scala | 6 + .../lakesoul/commands/MergeIntoSQLSuite.scala | 412 +++++++++--------- .../apache/spark/sql/arrow/ArrowUtils.scala | 21 +- rust/lakesoul-io/src/lakesoul_io_config.rs | 16 +- rust/lakesoul-io/src/lakesoul_writer.rs | 266 +++-------- rust/lakesoul-io/src/lib.rs | 1 + .../src/local_sensitive_hash/mod.rs | 182 ++++++++ 7 files changed, 489 insertions(+), 415 deletions(-) create mode 100644 rust/lakesoul-io/src/local_sensitive_hash/mod.rs diff --git a/lakesoul-common/src/main/scala/com/dmetasoul/lakesoul/meta/LakeSoulOptions.scala b/lakesoul-common/src/main/scala/com/dmetasoul/lakesoul/meta/LakeSoulOptions.scala index 774d78322..a5dc1dafb 100644 --- a/lakesoul-common/src/main/scala/com/dmetasoul/lakesoul/meta/LakeSoulOptions.scala +++ b/lakesoul-common/src/main/scala/com/dmetasoul/lakesoul/meta/LakeSoulOptions.scala @@ -31,6 +31,12 @@ object LakeSoulOptions { val TIME_ZONE = "timezone" val DISCOVERY_INTERVAL = "discoveryinterval" + object SchemaFieldMetadata { + val LSH_EMBEDDING_DIMENSION = "lsh_embedding_dimension" + val LSH_BIT_WIDTH = "lsh_bit_width" + val LSH_RNG_SEED = "lsh_rng_seed" + } + object ReadType extends Enumeration { val FULL_READ = "fullread" val SNAPSHOT_READ = "snapshot" diff --git a/lakesoul-spark/src/test/scala/org/apache/spark/sql/lakesoul/commands/MergeIntoSQLSuite.scala b/lakesoul-spark/src/test/scala/org/apache/spark/sql/lakesoul/commands/MergeIntoSQLSuite.scala index 50002ad5a..bdffb9516 100644 --- a/lakesoul-spark/src/test/scala/org/apache/spark/sql/lakesoul/commands/MergeIntoSQLSuite.scala +++ b/lakesoul-spark/src/test/scala/org/apache/spark/sql/lakesoul/commands/MergeIntoSQLSuite.scala @@ -20,7 +20,7 @@ import org.scalatestplus.junit.JUnitRunner import io.jhdf.HdfFile import io.jhdf.api.Dataset import org.apache.commons.lang3.ArrayUtils -import org.apache.spark.sql.types.{ArrayType, ByteType, FloatType, IntegerType, LongType, StructField, StructType} +import org.apache.spark.sql.types.{ArrayType, ByteType, FloatType, IntegerType, LongType, MetadataBuilder, StructField, StructType} import org.apache.commons.lang3.ArrayUtils import org.apache.spark.sql.{Row, SparkSession} import org.apache.spark.rdd.RDD @@ -78,207 +78,215 @@ class MergeIntoSQLSuite extends QueryTest } } -// test("test lsh"){ -// val filepath = "/Users/beidu/Documents/dataset/glove-200-angular.hdf5" -// val trainPath = "/Users/beidu/Documents/LakeSoul/train" -// val testPath = "/Users/beidu/Documents/LakeSoul/test" -// -// println(filepath) -// val spark = SparkSession.builder -// .appName("Array to RDD") -// .master("local[*]") -// .getOrCreate() -// try{ -// val hdfFile = new HdfFile(Paths.get(filepath)) -// -// val trainDataset = hdfFile.getDatasetByPath("train") -// val testDataset = hdfFile.getDatasetByPath("test") -// val neighborDataset = hdfFile.getDatasetByPath("neighbors") -// val trainData = trainDataset.getData() -// val testData = testDataset.getData() -// val neighborData = neighborDataset.getData() -// println(trainData) -// var float2DDataNeighbor: Array[Array[Int]] = null -// neighborData match { -// case data:Array[Array[Int]] => -// float2DDataNeighbor = data -// case _ => -// println("not") -// } -// // the smaller the Hamming distance,the greater the 
similarity -// val calculateHammingDistanceUDF = udf((trainLSH: Seq[Long], testLSH: Seq[Long]) => { -// require(trainLSH.length == testLSH.length, "The input sequences must have the same length") -// trainLSH.zip(testLSH).map { case (train, test) => -// java.lang.Long.bitCount(train ^ test) -// }.sum -// }) -// // the smaller the Euclidean distance,the greater the similarity -// val calculateEuclideanDistanceUDF = udf((trainEmbedding:Seq[Float],testEmbedding:Seq[Float]) => { -// require(testEmbedding.length == trainEmbedding.length,"The input sequences must have the same length") -// sqrt(trainEmbedding.zip(testEmbedding).map{case (train,test) => -// pow(train - test,2) }.sum) -// }) -// //the greater the Cosine distance,the greater the similarity -// val calculateCosineDistanceUDF = udf((trainEmbedding:Seq[Float],testEmbedding:Seq[Float]) => { -// require(testEmbedding.length == trainEmbedding.length,"The input sequences must have the same length") -// trainEmbedding.zip(testEmbedding).map{case (train,test) => -// train * test}.sum / (sqrt(trainEmbedding.map{train => train * train}.sum) * sqrt(testEmbedding.map{test => test * test}.sum)) -// }) -// //the smaller the Jaccard distance,the greater the similarity -// val calculateJaccardDistanceUDF = udf((trainEmbedding:Seq[Float],testEmbedding:Seq[Float]) => { -// require(testEmbedding.length == trainEmbedding.length,"The input sequences must have the same length") -// val anb = testEmbedding.intersect(trainEmbedding).distinct -// val aub = testEmbedding.union(trainEmbedding).distinct -// val jaccardCoefficient = anb.length.toDouble / aub.length -// 1 - jaccardCoefficient -// }) -// spark.udf.register("calculateHammingDistance",calculateHammingDistanceUDF) -// spark.udf.register("calculateEuclideanDistance",calculateEuclideanDistanceUDF) -// spark.udf.register("calculateCosineDistance",calculateCosineDistanceUDF) -// spark.udf.register("calculateJaccardDistance",calculateJaccardDistanceUDF) -//// println(float2DDataNeighbor.length) -// trainData match { -// case float2DData:Array[Array[Float]] => -// val classIds = (1 to float2DData.length).toArray -// val schema = StructType(Array( -// StructField("IndexId",IntegerType,true), -// StructField("Embedding",ArrayType(FloatType),true), -// StructField("LSH",ArrayType(LongType),true) -// )) -// val rows = float2DData.zip(classIds).map { -// case (embedding,indexId)=> -// Row(indexId,embedding,null) -// } -// val df = spark.createDataFrame(spark.sparkContext.parallelize(rows),schema) -// df.write.format("lakesoul") -// .option("hashPartitions", "IndexId") -// .option("hashBucketNum", 4) -// .option(LakeSoulOptions.SHORT_TABLE_NAME,"trainData") -// .mode("Overwrite").save(trainPath) -//// val startTime1 = System.nanoTime() -// val lakeSoulTable = LakeSoulTable.forPath(trainPath) -// lakeSoulTable.compaction() -// -// testData match { -// case float2DTestData:Array[Array[Float]] => -// val classIdsTest = (1 to float2DTestData.length).toArray -// val schemaTest = StructType(Array( -// StructField("IndexId",IntegerType,true), -// StructField("Embedding",ArrayType(FloatType),true), -// StructField("LSH",ArrayType(LongType),true) -// )) -// val rowsTest = float2DTestData.zip(classIdsTest).map{ -// case (embedding,indexId) => -// Row(indexId,embedding,null) -// } -// -// val num = 50 -// val dfTest = spark.createDataFrame(spark.sparkContext.parallelize(rowsTest),schemaTest).limit(num) -// dfTest.write.format("lakesoul") -// .option("hashPartitions","IndexId") -// .option("hashBucketNum",4) -// 
.option(LakeSoulOptions.SHORT_TABLE_NAME,"testData") -// .mode("Overwrite").save(testPath) -// val lakeSoulTableTest = LakeSoulTable.forPath(testPath) -// -// lakeSoulTableTest.compaction() -//// val endTime1 = System.nanoTime() -//// val duration1 = (endTime1 - startTime1).nanos -//// println(s"time:${duration1.toMillis}") -//// val lshTrain = sql("select LSH from trainData") -//// val lshTest = sql("select LSH from testData") -// -//// val arr = Array(1,5,10,20,40,60,80,100,150,200,250,300) -//// for(n <- arr) { -// val n = 300 -// val topk = 100 -// val topkFirst = n * topk -// -// // val result = sql("select testData.IndexId as indexId,trainData.LSH as trainLSH,testData.LSH as testLSH," + -// // "calculateHammingDistance(testData.LSH,trainData.LSH) AS hamming_distance " + -// // "from testData " + -// // "cross join trainData " + -// // "order by indexId,hamming_distance") -// -// // val result = spark.sql(s""" -// // SELECT * -// // FROM ( -// // SELECT -// // testData.IndexId AS indexIdTest, -// // trainData.IndexId AS indexIdTrain, -// // trainData.LSH AS trainLSH, -// // testData.LSH AS testLSH, -// // calculateHammingDistance(testData.LSH, trainData.LSH) AS hamming_distance, -// // ROW_NUMBER() OVER (PARTITION BY testData.IndexId ORDER BY calculateHammingDistance(testData.LSH, trainData.LSH) asc) AS rank -// // FROM testData -// // CROSS JOIN trainData -// // ) ranked -// // WHERE rank <= $topk -// // """).groupBy("indexIdTest").agg(collect_list("indexIdTrain").alias("indexIdTrainList")) -// val startTime = System.nanoTime() -// val result = spark.sql( -// s""" -// SELECT * -// FROM ( -// SELECT -// testData.IndexId AS indexIdTest, -// trainData.IndexId AS indexIdTrain, -// testData.Embedding as EmbeddingTest, -// trainData.Embedding as EmbeddingTrain, -// ROW_NUMBER() OVER (PARTITION BY testData.IndexId ORDER BY calculateHammingDistance(testData.LSH, trainData.LSH) asc) AS rank -// FROM testData -// CROSS JOIN trainData -// ) ranked -// WHERE rank <= $topkFirst -// """) -// result.createOrReplaceTempView("rank") -// val reResult = spark.sql( -// s""" -// SELECT * -// FROM ( -// SELECT -// rank.indexIdTest, -// rank.indexIDTrain, -// ROW_NUMBER() OVER(PARTITION BY rank.indexIdTest ORDER BY calculateEuclideanDistance(rank.EmbeddingTest,rank.EmbeddingTrain) asc) AS reRank -// FROM rank -// ) reRanked -// WHERE reRank <= $topk -// """).groupBy("indexIdTest").agg(collect_list("indexIdTrain").alias("indexIdTrainList")) -// -// -// val endTime = System.nanoTime() -// val duration = (endTime - startTime).nanos -// println(s"time for query n4topk ${n} :${duration.toMillis} milliseconds") -// -// val startTime2 = System.nanoTime() -// -//// val (totalRecall, count) = reResult.map(row => { -//// val indexIdTest = row.getAs[Int]("indexIdTest") -//// val indexIdTrainList: Array[Int] = row.getAs[Seq[Int]]("indexIdTrainList").toArray -//// val updatedList = indexIdTrainList.map(_ - 1) -//// val count = float2DDataNeighbor(indexIdTest - 1).take(topk).count(updatedList.contains) -//// val recall = (count * 1.0 / topk) -//// (recall, 1) -//// }).reduce((acc1, acc2) => { -//// (acc1._1 + acc2._1, acc1._2 + acc2._2) -//// }) -//// println(totalRecall / count) -// val endTime2 = System.nanoTime() -// val duration2 = (endTime2 - startTime2).nanos -// println(s"time for sort:${duration2.toMillis} milliseconds") -//// } -// } -// case _ => -// println("unexpected data type") -// case _ => -// println("unexpected data type") -// } -// } -// finally { -// -// } -// -// } + // test("test lsh") { + // 
withTempDir { trainDir => + // withTempDir { testDir => + // val filepath = "/Users/ceng/Downloads/fashion-mnist-784-euclidean.hdf5" + // val trainPath = trainDir.getCanonicalPath + // val testPath = testDir.getCanonicalPath + // + // println(filepath) + // val spark = SparkSession.builder + // .appName("Array to RDD") + // .master("local[*]") + // .getOrCreate() + // try { + // val hdfFile = new HdfFile(Paths.get(filepath)) + // + // val trainDataset = hdfFile.getDatasetByPath("train") + // val testDataset = hdfFile.getDatasetByPath("test") + // val neighborDataset = hdfFile.getDatasetByPath("neighbors") + // val trainData = trainDataset.getData() + // val testData = testDataset.getData() + // val neighborData = neighborDataset.getData() + // + // var float2DDataNeighbor: Array[Array[Int]] = null + // neighborData match { + // case data: Array[Array[Int]] => + // float2DDataNeighbor = data + // case _ => + // println("not") + // } + // // the smaller the Hamming distance,the greater the similarity + // val calculateHammingDistanceUDF = udf((trainLSH: Seq[Long], testLSH: Seq[Long]) => { + // require(trainLSH.length == testLSH.length, "The input sequences must have the same length") + // trainLSH.zip(testLSH).map { case (train, test) => + // java.lang.Long.bitCount(train ^ test) + // }.sum + // }) + // // the smaller the Euclidean distance,the greater the similarity + // val calculateEuclideanDistanceUDF = udf((trainEmbedding: Seq[Float], testEmbedding: Seq[Float]) => { + // require(testEmbedding.length == trainEmbedding.length, "The input sequences must have the same length") + // sqrt(trainEmbedding.zip(testEmbedding).map { case (train, test) => + // pow(train - test, 2) + // }.sum) + // }) + // //the greater the Cosine distance,the greater the similarity + // val calculateCosineDistanceUDF = udf((trainEmbedding: Seq[Float], testEmbedding: Seq[Float]) => { + // require(testEmbedding.length == trainEmbedding.length, "The input sequences must have the same length") + // trainEmbedding.zip(testEmbedding).map { case (train, test) => + // train * test + // }.sum / (sqrt(trainEmbedding.map { train => train * train }.sum) * sqrt(testEmbedding.map { test => test * test }.sum)) + // }) + // //the smaller the Jaccard distance,the greater the similarity + // val calculateJaccardDistanceUDF = udf((trainEmbedding: Seq[Float], testEmbedding: Seq[Float]) => { + // require(testEmbedding.length == trainEmbedding.length, "The input sequences must have the same length") + // val anb = testEmbedding.intersect(trainEmbedding).distinct + // val aub = testEmbedding.union(trainEmbedding).distinct + // val jaccardCoefficient = anb.length.toDouble / aub.length + // 1 - jaccardCoefficient + // }) + // spark.udf.register("calculateHammingDistance", calculateHammingDistanceUDF) + // spark.udf.register("calculateEuclideanDistance", calculateEuclideanDistanceUDF) + // spark.udf.register("calculateCosineDistance", calculateCosineDistanceUDF) + // spark.udf.register("calculateJaccardDistance", calculateJaccardDistanceUDF) + // // println(float2DDataNeighbor.length) + // val schema = StructType(Array( + // StructField("IndexId", IntegerType, true), + // StructField("Embedding", ArrayType(FloatType), true, new MetadataBuilder() + // .putString(LakeSoulOptions.SchemaFieldMetadata.LSH_EMBEDDING_DIMENSION, "784") + // .putString(LakeSoulOptions.SchemaFieldMetadata.LSH_BIT_WIDTH, "512") + // .putString(LakeSoulOptions.SchemaFieldMetadata.LSH_RNG_SEED, "1234567890").build()), + // StructField("Embedding_LSH", 
ArrayType(LongType), true) + // )) + // trainData match { + // case float2DData: Array[Array[Float]] => + // val classIds = (1 to float2DData.length).toArray + // + // val rows = float2DData.zip(classIds).map { + // case (embedding, indexId) => + // // Row(indexId, embedding) + // Row(indexId, embedding, null) + // } + // val df = spark.createDataFrame(spark.sparkContext.parallelize(rows), schema) + // df.write.format("lakesoul") + // .option("hashPartitions", "IndexId") + // .option("hashBucketNum", 4) + // .option(LakeSoulOptions.SHORT_TABLE_NAME, "trainData") + // .mode("Overwrite").save(trainPath) + // + // // val startTime1 = System.nanoTime() + // val lakeSoulTable = LakeSoulTable.forPath(trainPath) + // lakeSoulTable.compaction() + // + // testData match { + // case float2DTestData: Array[Array[Float]] => + // val classIdsTest = (1 to float2DTestData.length).toArray + // + // val rowsTest = float2DTestData.zip(classIdsTest).map { + // case (embedding, indexId) => + // // Row(indexId, embedding) + // Row(indexId, embedding, null) + // } + // + // val num = 50 + // val dfTest = spark.createDataFrame(spark.sparkContext.parallelize(rowsTest), schema).limit(num) + // dfTest.write.format("lakesoul") + // .option("hashPartitions", "IndexId") + // .option("hashBucketNum", 4) + // .option(LakeSoulOptions.SHORT_TABLE_NAME, "testData") + // .mode("Overwrite").save(testPath) + // val lakeSoulTableTest = LakeSoulTable.forPath(testPath) + // + // lakeSoulTableTest.compaction() + // // val endTime1 = System.nanoTime() + // // val duration1 = (endTime1 - startTime1).nanos + // // println(s"time:${duration1.toMillis}") + // // val lshTrain = sql("select LSH from trainData") + // // val lshTest = sql("select LSH from testData") + // + // // val arr = Array(1,5,10,20,40,60,80,100,150,200,250,300) + // // for(n <- arr) { + // val n = 3 + // val topk = 100 + // val topkFirst = n * topk + // + // // val result = sql("select testData.IndexId as indexId,trainData.LSH as trainLSH,testData.LSH as testLSH," + + // // "calculateHammingDistance(testData.LSH,trainData.LSH) AS hamming_distance " + + // // "from testData " + + // // "cross join trainData " + + // // "order by indexId,hamming_distance") + // + // // val result = spark.sql(s""" + // // SELECT * + // // FROM ( + // // SELECT + // // testData.IndexId AS indexIdTest, + // // trainData.IndexId AS indexIdTrain, + // // trainData.LSH AS trainLSH, + // // testData.LSH AS testLSH, + // // calculateHammingDistance(testData.LSH, trainData.LSH) AS hamming_distance, + // // ROW_NUMBER() OVER (PARTITION BY testData.IndexId ORDER BY calculateHammingDistance(testData.LSH, trainData.LSH) asc) AS rank + // // FROM testData + // // CROSS JOIN trainData + // // ) ranked + // // WHERE rank <= $topk + // // """).groupBy("indexIdTest").agg(collect_list("indexIdTrain").alias("indexIdTrainList")) + // val startTime = System.nanoTime() + // val result = spark.sql( + // s""" + // SELECT * + // FROM ( + // SELECT + // testData.IndexId AS indexIdTest, + // trainData.IndexId AS indexIdTrain, + // testData.Embedding as EmbeddingTest, + // trainData.Embedding as EmbeddingTrain, + // ROW_NUMBER() OVER (PARTITION BY testData.IndexId ORDER BY calculateHammingDistance(testData.Embedding_LSH, trainData.Embedding_LSH) asc) AS rank + // FROM testData + // CROSS JOIN trainData + // ) ranked + // WHERE rank <= $topkFirst + // """) + // result.createOrReplaceTempView("rank") + // val reResult = spark.sql( + // s""" + // SELECT * + // FROM ( + // SELECT + // rank.indexIdTest, + // 
rank.indexIDTrain, + // ROW_NUMBER() OVER(PARTITION BY rank.indexIdTest ORDER BY calculateEuclideanDistance(rank.EmbeddingTest,rank.EmbeddingTrain) asc) AS reRank + // FROM rank + // ) reRanked + // WHERE reRank <= $topk + // """).groupBy("indexIdTest").agg(collect_list("indexIdTrain").alias("indexIdTrainList")) + // + // + // val endTime = System.nanoTime() + // val duration = (endTime - startTime).nanos + // println(s"time for query n4topk ${n} :${duration.toMillis} milliseconds") + // + // val startTime2 = System.nanoTime() + // + // val (totalRecall, count) = reResult.map(row => { + // val indexIdTest = row.getAs[Int]("indexIdTest") + // val indexIdTrainList: Array[Int] = row.getAs[Seq[Int]]("indexIdTrainList").toArray + // val updatedList = indexIdTrainList.map(_ - 1) + // val count = float2DDataNeighbor(indexIdTest - 1).take(topk).count(updatedList.contains) + // val recall = (count * 1.0 / topk) + // (recall, 1) + // }).reduce((acc1, acc2) => { + // (acc1._1 + acc2._1, acc1._2 + acc2._2) + // }) + // println(s"recall rate = ${totalRecall / count}") + // val endTime2 = System.nanoTime() + // val duration2 = (endTime2 - startTime2).nanos + // println(s"time for sort:${duration2.toMillis} milliseconds") + // // } + // } + // case _ => + // println("unexpected data type") + // case _ => + // println("unexpected data type") + // } + // } + // finally { + // + // } + // } + // } + // } test("merge into table with hash partition -- supported case") { initHashTable() diff --git a/native-io/lakesoul-io-java/src/main/scala/org/apache/spark/sql/arrow/ArrowUtils.scala b/native-io/lakesoul-io-java/src/main/scala/org/apache/spark/sql/arrow/ArrowUtils.scala index b75e375ec..c0c7a27a5 100644 --- a/native-io/lakesoul-io-java/src/main/scala/org/apache/spark/sql/arrow/ArrowUtils.scala +++ b/native-io/lakesoul-io-java/src/main/scala/org/apache/spark/sql/arrow/ArrowUtils.scala @@ -4,6 +4,8 @@ package org.apache.spark.sql.arrow +import com.dmetasoul.lakesoul.meta.LakeSoulOptions +import com.dmetasoul.lakesoul.meta.LakeSoulOptions.SchemaFieldMetadata.{LSH_BIT_WIDTH, LSH_EMBEDDING_DIMENSION, LSH_RNG_SEED} import org.apache.arrow.memory.RootAllocator import org.apache.arrow.vector.complex.MapVector import org.apache.arrow.vector.types.pojo.{ArrowType, Field, FieldType, Schema} @@ -167,11 +169,20 @@ object ArrowUtils { def toArrowSchema(schema: StructType, timeZoneId: String = "UTC"): Schema = { new Schema(schema.map { field => val comment = field.getComment - val metadata = if (comment.isDefined) { - val map = new util.HashMap[String, String] - map.put("spark_comment", comment.get) - map - } else null + val metadata = new util.HashMap[String, String] + if (field.metadata.contains(LSH_EMBEDDING_DIMENSION)) { + metadata.put(LSH_EMBEDDING_DIMENSION, field.metadata.getString(LSH_EMBEDDING_DIMENSION)) + } + if (field.metadata.contains(LSH_BIT_WIDTH)) { + metadata.put(LSH_BIT_WIDTH, field.metadata.getString(LSH_BIT_WIDTH)) + } + if (field.metadata.contains(LSH_RNG_SEED)) { + metadata.put(LSH_RNG_SEED, field.metadata.getString(LSH_RNG_SEED)) + } + + if (comment.isDefined) { + metadata.put("spark_comment", comment.get) + } toArrowField(field.name, field.dataType, field.nullable, timeZoneId, field.metadata, metadata) }.asJava) } diff --git a/rust/lakesoul-io/src/lakesoul_io_config.rs b/rust/lakesoul-io/src/lakesoul_io_config.rs index eeeb3ed92..87c5b52b3 100644 --- a/rust/lakesoul-io/src/lakesoul_io_config.rs +++ b/rust/lakesoul-io/src/lakesoul_io_config.rs @@ -51,9 +51,8 @@ pub static OPTION_KEY_HASH_BUCKET_ID: 
&str = "hash_bucket_id"; pub static OPTION_KEY_CDC_COLUMN: &str = "cdc_column"; pub static OPTION_KEY_IS_COMPACTED: &str = "is_compacted"; pub static OPTION_KEY_MAX_FILE_SIZE: &str = "max_file_size"; -pub static OPTION_KEY_IS_LSH: &str = "is_lsh"; -pub static OPTION_KEY_NBITS: &str = "nbits"; -pub static OPTION_KEY_D: &str= "d"; +pub static OPTION_KEY_COMPUTE_LSH: &str = "compute_lsh"; + #[derive(Debug, Derivative)] #[derivative(Default, Clone)] @@ -196,17 +195,10 @@ impl LakeSoulIOConfig { self.option(OPTION_KEY_IS_COMPACTED).map_or(false, |x| x.eq("true")) } - pub fn is_lsh(&self) -> bool { - self.option(OPTION_KEY_IS_LSH).map_or(false,|x| x.eq("true")) - } - - pub fn nbits(&self) -> Option{ - self.option(OPTION_KEY_NBITS).map(|x| x.parse().unwrap()) + pub fn compute_lsh(&self) -> bool { + self.option(OPTION_KEY_COMPUTE_LSH).map_or(true,|x| x.eq("true")) } - pub fn d(&self) -> Option{ - self.option(OPTION_KEY_D).map(|x| x.parse().unwrap()) - } } #[derive(Derivative, Debug)] diff --git a/rust/lakesoul-io/src/lakesoul_writer.rs b/rust/lakesoul-io/src/lakesoul_writer.rs index 5d95ac711..5deeec54d 100644 --- a/rust/lakesoul-io/src/lakesoul_writer.rs +++ b/rust/lakesoul-io/src/lakesoul_writer.rs @@ -4,20 +4,16 @@ use std::borrow::Borrow; use std::collections::HashMap; -use std:: ptr; use arrow::datatypes::{DataType, Field}; use arrow_array::RecordBatch; -use arrow_schema::SchemaRef; +use arrow_schema::{SchemaBuilder, SchemaRef}; use datafusion_common::{DataFusionError, Result}; use rand::distributions::DistString; use tokio::runtime::Runtime; use tokio::sync::Mutex; use tracing::debug; -use rand::{Rng,SeedableRng,rngs::StdRng}; -use ndarray::{concatenate, s, Array2, Axis,ArrayView2}; -use arrow::array::{Array as OtherArray, Float64Array, ListArray,Float32Array,Int64Array,GenericListArray}; -use arrow::buffer::OffsetBuffer; +use arrow::array::{Array as OtherArray, ListArray}; use std::sync::Arc; @@ -25,6 +21,7 @@ use crate::async_writer::{AsyncBatchWriter, MultiPartAsyncWriter, PartitioningAs use crate::helpers::{get_batch_memory_size, get_file_exist_col}; use crate::lakesoul_io_config::{IOSchema, LakeSoulIOConfig}; use crate::transform::uniform_schema; +use crate::local_sensitive_hash::LSH; pub type SendableWriter = Box; @@ -37,6 +34,7 @@ pub struct SyncSendableMutableLakeSoulWriter { /// The in-progress file writer if any in_progress: Option>>, flush_results: WriterFlushResult, + lsh_computers: HashMap } impl SyncSendableMutableLakeSoulWriter { @@ -45,6 +43,37 @@ impl SyncSendableMutableLakeSoulWriter { runtime.clone().block_on(async move { let mut config = config.clone(); let writer_config = config.clone(); + + // Initialize HashMap instead of Vec + let mut lsh_computers = HashMap::new(); + + if config.compute_lsh() { + let mut target_schema_builder = SchemaBuilder::from(&config.target_schema().fields); + + for field in config.target_schema().fields.iter() { + if let Some(lsh_bit_width) = field.metadata().get("lsh_bit_width") { + let lsh_column_name = format!("{}_LSH", field.name()); + let lsh_column = Arc::new(Field::new( + lsh_column_name.clone(), + DataType::List(Arc::new(Field::new("element", DataType::Int64, true))), + true + )); + target_schema_builder.try_merge(&lsh_column)?; + + // Store LSH computer with column name as key + let lsh = LSH::new( + lsh_bit_width.parse().unwrap(), + field.metadata().get("lsh_embedding_dimension") + .map(|d| d.parse().unwrap()) + .unwrap_or(0), + config.seed + ); + lsh_computers.insert(field.name().to_string(), lsh); + } + } + config.target_schema = 
IOSchema(Arc::new(target_schema_builder.finish())); + } + let writer = Self::create_writer(writer_config).await?; let schema = writer.schema(); if let Some(mem_limit) = config.mem_limit() { @@ -58,9 +87,10 @@ impl SyncSendableMutableLakeSoulWriter { Ok(SyncSendableMutableLakeSoulWriter { in_progress: Some(Arc::new(Mutex::new(writer))), runtime, - schema, // this should be the final written schema + schema, config, flush_results: vec![], + lsh_computers, }) }) } @@ -129,29 +159,41 @@ impl SyncSendableMutableLakeSoulWriter { // for ffi callers pub fn write_batch(&mut self, record_batch: RecordBatch) -> Result<()> { let runtime = self.runtime.clone(); - if record_batch.num_rows() == 0{ + if record_batch.num_rows() == 0 { runtime.block_on(async move { self.write_batch_async(record_batch, false).await }) - } - else{ - if self.config.is_lsh() { - let projection: ListArray= if let Some(array) = record_batch.column_by_name("Embedding") { - let embedding = array.as_any().downcast_ref::().unwrap(); - let projection_result:Result = self.lsh(&Some(embedding.clone())); - projection_result.unwrap().into() - - } else { - eprintln!("there is no column named Embedding"); - return Ok(()) ; - }; - + } else { + if self.config.compute_lsh() { let mut new_columns = record_batch.columns().to_vec(); - new_columns[record_batch.schema().index_of("LSH").unwrap()] = Arc::new(projection.clone()); - let new_record_batch = RecordBatch::try_new(self.config.target_schema(),new_columns).unwrap(); - - runtime.block_on(async move { self.write_batch_async(new_record_batch, false).await }) - } - else{ - runtime.block_on(async move { self.write_batch_async(record_batch, false).await }) + + for (field_name, lsh_computer) in self.lsh_computers.iter() { + let lsh_column_name = format!("{}_LSH", field_name); + + if let Some(array) = record_batch.column_by_name(field_name) { + let embedding = array.as_any().downcast_ref::() + .ok_or_else(|| DataFusionError::Internal( + format!("Column {} is not a ListArray", field_name) + ))?; + + let lsh_array = lsh_computer.compute_lsh(&Some(embedding.clone()))?; + + if let Ok(index) = record_batch.schema().index_of(lsh_column_name.as_str()) { + new_columns[index] = Arc::new(lsh_array); + } + } + } + + let new_record_batch = RecordBatch::try_new( + self.config.target_schema(), + new_columns + )?; + + runtime.block_on(async move { + self.write_batch_async(new_record_batch, false).await + }) + } else { + runtime.block_on(async move { + self.write_batch_async(record_batch, false).await + }) } } } @@ -294,174 +336,6 @@ impl SyncSendableMutableLakeSoulWriter { self.schema.clone() } - // generate random digit with fixed seed - fn create_rng_with_seed(&self) -> StdRng { - StdRng::seed_from_u64(self.config.seed) - } - - // generate random planes - fn generate_random_array(&self) -> Result,String>{ - match self.config.nbits() { - Some(nbits) if nbits > 0 => { - match self.config.d() { - Some(d) if d > 0 => { - let mut rng = self.create_rng_with_seed(); -// assert!(d >= nbits,"the dimension of the embedding must be greater than nbits"); - let random_array = Array2::from_shape_fn((nbits as usize, d as usize), |_| rng.gen_range(-1.0..1.0)); - Ok(random_array) - } - Some(_) => Err("the dimension you input in the config must be greater than 0".to_string()), - None => Err("the dimension you input in the config is None".to_string()), - } - } - Some(_) => Err("the number of bits used for binary encoding must be greater than 0".to_string()), - None => Err("the number of bits used for binary encoding must be greater 
than 0".to_string()), - } - } - - // project the input data - fn project(&self,input_data:&ListArray,random_plans:&Result,String>) -> Result,String>{ - let list_len = input_data.len(); - assert!(list_len > 0,"the length of input data must be large than 0"); - let dimension_len = input_data.value(0).len(); - - let input_values = if let Some(values) = input_data.values().as_any().downcast_ref::(){ - let float64_values: Vec = values.iter().map(|x| x.unwrap() as f64).collect(); - Float64Array::from(float64_values) - } else if let Some(values) = input_data.values().as_any().downcast_ref::(){ - values.clone() - } - else { - return Err("Unsupported data type in ListArray.".to_string()); - }; - - let mut re_array2 = Array2::::zeros((list_len,dimension_len)); - - unsafe { - let data_ptr = input_values.values().as_ptr(); - let data_size = list_len * dimension_len; - ptr::copy_nonoverlapping(data_ptr,re_array2.as_mut_ptr(),data_size); - } - match random_plans { - Ok(random_array) => { - assert!(re_array2.shape()[1] == random_array.shape()[1],"the dimension corresponding to the matrix must be the same"); -// let final_result = re_array2.dot(&random_array.t()); - let batch_size = 1000; - let num_batches = re_array2.shape()[0] / batch_size; - let remaining_rows = re_array2.shape()[0] % batch_size; - let mut result = vec![]; - - for batch_idx in 0..num_batches{ - let batch_start = batch_idx * batch_size; - let batch_end = batch_start + batch_size; - - let current_batch = re_array2.slice(s![batch_start..batch_end,..]); - let random_projection = current_batch.dot(&random_array.t()); - - result.push(random_projection); - } - - if remaining_rows > 0{ - let batch_start = num_batches * batch_size; - let batch_end = batch_start + remaining_rows; - - let remaining_batch = re_array2.slice(s![batch_start..batch_end,..]); - - let random_projection = remaining_batch.dot(&random_array.t()); - - result.push(random_projection); - } - - let result_views: Vec> = result.iter().map(|arr| ArrayView2::from(arr)).collect(); - - - let final_result = concatenate(Axis(0),&result_views).expect("Failed to concatenate results"); - - // println!("{:}",end); - - Ok(final_result) - } - Err(e) => { - eprintln!("Error:{}",e); - Err(e.to_string()) - } - } - } - // add the input data with their projection - pub fn lsh(&self,input_embedding:&Option) -> Result - where - { - match input_embedding { - Some(data) => { - let random_plans = self.generate_random_array(); - let data_projection = self.project(data,&random_plans).unwrap(); - match Ok(data_projection) { - Ok(mut projection) => { - projection.mapv_inplace(|x| if x >= 0.0 {1.0} else {0.0}); - let convert:Vec> = Self::convert_array_to_u64_vec(&projection); - Ok(Self::convert_vec_to_byte_u64(convert)) - } - Err(e) => { - eprintln!("Error:{}",e); - Err(e) - } - } - } - None => { - Err("the input data is None".to_string()) - } - } - } - - fn convert_vec_to_byte_u64(array:Vec>) -> ListArray { - let field = Arc::new(Field::new("element", DataType::Int64,true)); - let values = Int64Array::from(array.iter().flatten().map(|&x| x as i64).collect::>()); - let mut offsets = vec![]; - for subarray in array{ - let current_offset = subarray.len() as usize; - offsets.push(current_offset); - } - let offsets_buffer = OffsetBuffer::from_lengths(offsets); - let list_array = GenericListArray::try_new(field,offsets_buffer,Arc::new(values),None).expect("can not list_array"); - list_array - - } - - fn convert_array_to_u64_vec(array:&Array2) -> Vec> - where - T: TryFrom + Copy, - >::Error: std::fmt::Debug, - { 
- let bianry_encode:Vec> = array - .axis_iter(ndarray::Axis(0)) - .map(|row|{ - let mut results = Vec::new(); - let mut acc = 0u64; - - for(i,&bit) in row.iter().enumerate(){ - acc = (acc << 1) | bit as u64; - if(i + 1) % 64 == 0{ - results.push(acc); - acc = 0; - } - } - if row.len() % 64 != 0{ - results.push(acc); - } - results - }) - .collect(); - - bianry_encode - .into_iter() - .map(|inner_vec|{ - inner_vec - .into_iter() - .map(|x| T::try_from(x).unwrap()) - .collect() - }).collect() - } - } #[cfg(test)] diff --git a/rust/lakesoul-io/src/lib.rs b/rust/lakesoul-io/src/lib.rs index 78a2cde41..c32704698 100644 --- a/rust/lakesoul-io/src/lib.rs +++ b/rust/lakesoul-io/src/lib.rs @@ -10,6 +10,7 @@ pub mod helpers; pub mod lakesoul_io_config; pub mod lakesoul_reader; pub mod lakesoul_writer; +pub mod local_sensitive_hash; mod projection; pub mod repartition; pub mod sorted_merge; diff --git a/rust/lakesoul-io/src/local_sensitive_hash/mod.rs b/rust/lakesoul-io/src/local_sensitive_hash/mod.rs new file mode 100644 index 000000000..8ea1d0aaa --- /dev/null +++ b/rust/lakesoul-io/src/local_sensitive_hash/mod.rs @@ -0,0 +1,182 @@ +use std::ptr; +use arrow::array::{Int64Array,Array, Float32Array, Float64Array, GenericListArray, ListArray}; +use arrow::buffer::OffsetBuffer; +use arrow::datatypes::{DataType, Field}; +use datafusion_common::{DataFusionError, Result}; +use ndarray::{Array2, s, Axis, concatenate, ArrayView2}; +use rand::{rngs::StdRng, Rng, SeedableRng}; +use std::sync::Arc; + +#[derive(Debug, Clone)] +pub struct LSH { + nbits: u32, + d: u32, + seed: u64, +} + +impl LSH { + pub fn new(nbits: u32, d: u32, seed: u64) -> Self { + Self { nbits, d, seed } + } + + // generate random digit with fixed seed + fn create_rng_with_seed(&self) -> StdRng { + StdRng::seed_from_u64(self.seed) + } + + // generate random planes + fn generate_random_array(&self) -> Result> { + if self.nbits > 0 { + if self.d > 0 { + let mut rng = self.create_rng_with_seed(); + let random_array = Array2::from_shape_fn( + (self.nbits as usize, self.d as usize), + |_| rng.gen_range(-1.0..1.0) + ); + Ok(random_array) + } else { + Err(DataFusionError::Internal( + "the dimension you input in the config must be greater than 0".to_string() + )) + } + } else { + Err(DataFusionError::Internal( + "the number of bits used for binary encoding must be greater than 0".to_string() + )) + } + } + + // project the input data + fn project(&self, input_data: &ListArray, random_plans: &Result>) -> Result> { + let list_len = input_data.len(); + assert!(list_len > 0, "the length of input data must be large than 0"); + let dimension_len = input_data.value(0).len(); + + let input_values = if let Some(values) = input_data.values().as_any().downcast_ref::() { + let float64_values: Vec = values.iter().map(|x| x.unwrap() as f64).collect(); + Float64Array::from(float64_values) + } else if let Some(values) = input_data.values().as_any().downcast_ref::() { + values.clone() + } else { + return Err(DataFusionError::Internal("Unsupported data type in ListArray.".to_string())); + }; + + let mut re_array2 = Array2::::zeros((list_len, dimension_len)); + + unsafe { + let data_ptr = input_values.values().as_ptr(); + let data_size = list_len * dimension_len; + ptr::copy_nonoverlapping(data_ptr, re_array2.as_mut_ptr(), data_size); + } + + match random_plans { + Ok(random_array) => { + assert!(re_array2.shape()[1] == random_array.shape()[1], + "the dimension corresponding to the matrix must be the same"); + + let batch_size = 1000; + let num_batches = 
re_array2.shape()[0] / batch_size; + let remaining_rows = re_array2.shape()[0] % batch_size; + let mut result = vec![]; + + for batch_idx in 0..num_batches { + let batch_start = batch_idx * batch_size; + let batch_end = batch_start + batch_size; + + let current_batch = re_array2.slice(s![batch_start..batch_end,..]); + let random_projection = current_batch.dot(&random_array.t()); + + result.push(random_projection); + } + + if remaining_rows > 0 { + let batch_start = num_batches * batch_size; + let batch_end = batch_start + remaining_rows; + + let remaining_batch = re_array2.slice(s![batch_start..batch_end,..]); + let random_projection = remaining_batch.dot(&random_array.t()); + + result.push(random_projection); + } + + let result_views: Vec> = result.iter() + .map(|arr| ArrayView2::from(arr)) + .collect(); + + let final_result = concatenate(Axis(0), &result_views) + .expect("Failed to concatenate results"); + + Ok(final_result) + } + Err(e) => { + eprintln!("Error:{}", e); + Err(DataFusionError::Internal(e.to_string())) + } + } + } + + fn convert_vec_to_byte_u64(array: Vec>) -> ListArray { + let field = Arc::new(Field::new("element", DataType::Int64, true)); + let values = Int64Array::from(array.iter().flatten().map(|&x| x as i64).collect::>()); + let mut offsets = vec![]; + for subarray in array { + let current_offset = subarray.len() as usize; + offsets.push(current_offset); + } + let offsets_buffer = OffsetBuffer::from_lengths(offsets); + let list_array = GenericListArray::try_new(field, offsets_buffer, Arc::new(values), None) + .expect("can not create list_array"); + list_array + } + + fn convert_array_to_u64_vec(array: &Array2) -> Vec> + where + T: TryFrom + Copy, + >::Error: std::fmt::Debug, + { + let binary_encode: Vec> = array + .axis_iter(ndarray::Axis(0)) + .map(|row| { + let mut results = Vec::new(); + let mut acc = 0u64; + + for (i, &bit) in row.iter().enumerate() { + acc = (acc << 1) | bit as u64; + if (i + 1) % 64 == 0 { + results.push(acc); + acc = 0; + } + } + if row.len() % 64 != 0 { + results.push(acc); + } + results + }) + .collect(); + + binary_encode + .into_iter() + .map(|inner_vec| { + inner_vec + .into_iter() + .map(|x| T::try_from(x).unwrap()) + .collect() + }) + .collect() + } + + // add the input data with their projection + pub fn compute_lsh(&self, input_embedding: &Option) -> Result { + match input_embedding { + Some(data) => { + let random_plans = self.generate_random_array(); + let data_projection = self.project(data, &random_plans)?; + let mut projection = data_projection; + projection.mapv_inplace(|x| if x >= 0.0 { 1.0 } else { 0.0 }); + let convert: Vec> = Self::convert_array_to_u64_vec(&projection); + Ok(Self::convert_vec_to_byte_u64(convert)) + } + None => Err(DataFusionError::Internal("the input data is None".to_string())), + } + } +}
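
Usage sketch (not a definitive API reference, assumptions noted): with this patch the LSH parameters move from writer options (the removed is_lsh/nbits/d keys) to per-field schema metadata. The Rust writer reads lsh_bit_width and lsh_embedding_dimension from the embedding field's metadata and fills a pre-declared "<field>_LSH" column of ArrayType(LongType); the lsh_rng_seed key is propagated through ArrowUtils. The Scala snippet below mirrors the commented-out test in MergeIntoSQLSuite; the table name, save path, and toy data are hypothetical.

import com.dmetasoul.lakesoul.meta.LakeSoulOptions
import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.sql.types.{ArrayType, FloatType, IntegerType, LongType, MetadataBuilder, StructField, StructType}

object LshMetadataExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().appName("lsh-metadata-example").master("local[*]").getOrCreate()

    // Attach the LSH parameters to the embedding field itself; the native writer
    // reads lsh_bit_width / lsh_embedding_dimension from the Arrow field metadata
    // and computes the companion "Embedding_LSH" column during the write.
    val embeddingMeta = new MetadataBuilder()
      .putString(LakeSoulOptions.SchemaFieldMetadata.LSH_EMBEDDING_DIMENSION, "784")
      .putString(LakeSoulOptions.SchemaFieldMetadata.LSH_BIT_WIDTH, "512")
      .putString(LakeSoulOptions.SchemaFieldMetadata.LSH_RNG_SEED, "1234567890")
      .build()

    val schema = StructType(Array(
      StructField("IndexId", IntegerType, nullable = true),
      StructField("Embedding", ArrayType(FloatType), nullable = true, embeddingMeta),
      // Pre-declared output column; left null in the input and filled by the writer.
      StructField("Embedding_LSH", ArrayType(LongType), nullable = true)
    ))

    // Toy data: a single 784-dimensional embedding (hypothetical values).
    val rows = Seq(Row(1, Seq.fill(784)(0.1f), null))
    val df = spark.createDataFrame(spark.sparkContext.parallelize(rows), schema)

    df.write.format("lakesoul")
      .option("hashPartitions", "IndexId")
      .option("hashBucketNum", 4)
      .option(LakeSoulOptions.SHORT_TABLE_NAME, "lsh_example") // hypothetical table name
      .mode("Overwrite")
      .save("/tmp/lakesoul/lsh_example") // hypothetical path
  }
}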