Skip to content

Commit

Permalink
[Spark/NativeIO] Fix LSH hard-coding & apply LSH compute with schema …
Browse files Browse the repository at this point in the history
…metadata (#571)

Signed-off-by: zenghua <huazeng@dmetasoul.com>
Co-authored-by: zenghua <huazeng@dmetasoul.com>
  • Loading branch information
Ceng23333 and zenghua authored Jan 14, 2025
1 parent 035f5cf commit ebb8f45
Show file tree
Hide file tree
Showing 7 changed files with 489 additions and 415 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,12 @@ object LakeSoulOptions {
val TIME_ZONE = "timezone"
val DISCOVERY_INTERVAL = "discoveryinterval"

// Keys for per-field schema metadata controlling LSH (locality-sensitive
// hashing) embedding computation. These are written into the Arrow field
// metadata map by ArrowUtils.toArrowSchema and read on the native IO side,
// so the string values must stay in sync with the Rust reader.
object SchemaFieldMetadata {
// Dimensionality of the embedding vector stored in this field.
val LSH_EMBEDDING_DIMENSION = "lsh_embedding_dimension"
// Number of bits used for the LSH signature.
val LSH_BIT_WIDTH = "lsh_bit_width"
// Seed for the random projection generator, for reproducible hashing.
val LSH_RNG_SEED = "lsh_rng_seed"
}

object ReadType extends Enumeration {
val FULL_READ = "fullread"
val SNAPSHOT_READ = "snapshot"
Expand Down

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@

package org.apache.spark.sql.arrow

import com.dmetasoul.lakesoul.meta.LakeSoulOptions
import com.dmetasoul.lakesoul.meta.LakeSoulOptions.SchemaFieldMetadata.{LSH_BIT_WIDTH, LSH_EMBEDDING_DIMENSION, LSH_RNG_SEED}
import org.apache.arrow.memory.RootAllocator
import org.apache.arrow.vector.complex.MapVector
import org.apache.arrow.vector.types.pojo.{ArrowType, Field, FieldType, Schema}
Expand Down Expand Up @@ -167,11 +169,20 @@ object ArrowUtils {
/**
 * Converts a Spark [[StructType]] to an Arrow [[Schema]].
 *
 * Per-field LSH metadata (embedding dimension, bit width, RNG seed) and the
 * Spark column comment are propagated into the Arrow field metadata map so
 * the native IO layer can pick them up.
 *
 * @param schema     the Spark schema to convert
 * @param timeZoneId time zone applied to timestamp fields (default "UTC")
 * @return the equivalent Arrow schema with field-level metadata attached
 */
def toArrowSchema(schema: StructType, timeZoneId: String = "UTC"): Schema = {
  new Schema(schema.map { field =>
    val comment = field.getComment
    val metadata = new util.HashMap[String, String]
    // Copy each LSH-related key verbatim when present; the same key names
    // are expected by the native reader (see SchemaFieldMetadata).
    Seq(LSH_EMBEDDING_DIMENSION, LSH_BIT_WIDTH, LSH_RNG_SEED).foreach { key =>
      if (field.metadata.contains(key)) {
        metadata.put(key, field.metadata.getString(key))
      }
    }
    // Preserve the Spark column comment under a dedicated key.
    if (comment.isDefined) {
      metadata.put("spark_comment", comment.get)
    }
    toArrowField(field.name, field.dataType, field.nullable, timeZoneId, field.metadata, metadata)
  }.asJava)
}
Expand Down
16 changes: 4 additions & 12 deletions rust/lakesoul-io/src/lakesoul_io_config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -51,9 +51,8 @@ pub static OPTION_KEY_HASH_BUCKET_ID: &str = "hash_bucket_id";
pub static OPTION_KEY_CDC_COLUMN: &str = "cdc_column";
pub static OPTION_KEY_IS_COMPACTED: &str = "is_compacted";
pub static OPTION_KEY_MAX_FILE_SIZE: &str = "max_file_size";
pub static OPTION_KEY_IS_LSH: &str = "is_lsh";
pub static OPTION_KEY_NBITS: &str = "nbits";
pub static OPTION_KEY_D: &str= "d";
pub static OPTION_KEY_COMPUTE_LSH: &str = "compute_lsh";


#[derive(Debug, Derivative)]
#[derivative(Default, Clone)]
Expand Down Expand Up @@ -196,17 +195,10 @@ impl LakeSoulIOConfig {
self.option(OPTION_KEY_IS_COMPACTED).map_or(false, |x| x.eq("true"))
}

pub fn is_lsh(&self) -> bool {
self.option(OPTION_KEY_IS_LSH).map_or(false,|x| x.eq("true"))
}

pub fn nbits(&self) -> Option<u64>{
self.option(OPTION_KEY_NBITS).map(|x| x.parse().unwrap())
pub fn compute_lsh(&self) -> bool {
self.option(OPTION_KEY_COMPUTE_LSH).map_or(true,|x| x.eq("true"))
}

pub fn d(&self) -> Option<u64>{
self.option(OPTION_KEY_D).map(|x| x.parse().unwrap())
}
}

#[derive(Derivative, Debug)]
Expand Down
Loading

0 comments on commit ebb8f45

Please sign in to comment.