[Spark/Rust] Support filter on nesting column name (#422)
Signed-off-by: zenghua <huazeng@dmetasoul.com>
Co-authored-by: zenghua <huazeng@dmetasoul.com>
Ceng23333 and zenghua authored Jan 17, 2024
1 parent 1027011 commit cc05f46
Showing 3 changed files with 96 additions and 146 deletions.
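In effect, a pushed-down predicate on a dotted column name such as a.b = 1 can now be translated into a DataFusion struct-field access instead of failing the column lookup. A minimal sketch of the resulting expression, not part of the commit, assuming the DataFusion version this crate pins (it reuses the same col, .field and .eq helpers the new parser.rs code below calls); the column names are illustrative:

    use datafusion::prelude::{col, lit};

    fn main() {
        // Illustrative only: a filter on nested column `a.b` becomes a
        // struct-field access expression rather than a plain column reference.
        let nested_eq = col("a").field("b").eq(lit(1));
        println!("{nested_eq}");
    }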
@@ -20,6 +20,7 @@ import org.apache.spark.sql.execution.datasources.DataSourceStrategy
import org.apache.spark.sql.execution.datasources.parquet.{NumRowGroupsAcc, ParquetFilters, ParquetTest, SparkToParquetSchemaConverter}
import org.apache.spark.sql.execution.datasources.v2.DataSourceV2ScanRelation
import org.apache.spark.sql.execution.datasources.v2.parquet.{NativeParquetScan, ParquetScan}
import org.apache.spark.sql.functions.struct
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.internal.SQLConf.{LegacyBehaviorPolicy, ParquetOutputTimestampType}
import org.apache.spark.sql.lakesoul.catalog.LakeSoulCatalog
@@ -39,66 +40,6 @@ import java.time.{LocalDate, LocalDateTime, ZoneId}
import scala.reflect.ClassTag
import scala.reflect.runtime.universe.TypeTag

@RunWith(classOf[JUnitRunner])
class ParquetV2FilterSuite
extends ParquetFilterSuite {
override protected def sparkConf: SparkConf =
super
.sparkConf
.set(SQLConf.USE_V1_SOURCE_LIST, "")

override def checkFilterPredicate(
df: DataFrame,
predicate: Predicate,
filterClass: Class[_ <: FilterPredicate],
checker: (DataFrame, Seq[Row]) => Unit,
expected: Seq[Row]): Unit = {
val output = predicate.collect { case a: Attribute => a }.distinct

withSQLConf(
SQLConf.PARQUET_FILTER_PUSHDOWN_ENABLED.key -> "true",
SQLConf.PARQUET_FILTER_PUSHDOWN_DATE_ENABLED.key -> "true",
SQLConf.PARQUET_FILTER_PUSHDOWN_TIMESTAMP_ENABLED.key -> "true",
SQLConf.PARQUET_FILTER_PUSHDOWN_DECIMAL_ENABLED.key -> "true",
SQLConf.PARQUET_FILTER_PUSHDOWN_STRING_STARTSWITH_ENABLED.key -> "true",
// Disable adding filters from constraints because it adds, for instance,
// is-not-null to pushed filters, which makes it hard to test if the pushed
// filter is expected or not (this had to be fixed with SPARK-13495).
SQLConf.OPTIMIZER_EXCLUDED_RULES.key -> InferFiltersFromConstraints.ruleName,
SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "false") {
val query = df
.select(output.map(e => Column(e)): _*)
.where(Column(predicate))

query.queryExecution.optimizedPlan.collectFirst {
case PhysicalOperation(_, filters,
DataSourceV2ScanRelation(_, scan: ParquetScan, _, _)) =>
assert(filters.nonEmpty, "No filter is analyzed from the given query")
val sourceFilters = filters.flatMap(DataSourceStrategy.translateFilter(_, true)).toArray
val pushedFilters = scan.pushedFilters
assert(pushedFilters.nonEmpty, "No filter is pushed down")
val schema = new SparkToParquetSchemaConverter(conf).convert(df.schema)
val parquetFilters = createParquetFilters(schema)
// In this test suite, all the simple predicates are convertible here.
assert(parquetFilters.convertibleFilters(sourceFilters) === pushedFilters)
val pushedParquetFilters = pushedFilters.map { pred =>
val maybeFilter = parquetFilters.createFilter(pred)
assert(maybeFilter.isDefined, s"Couldn't generate filter predicate for $pred")
maybeFilter.get
}
// Doesn't bother checking type parameters here (e.g. `Eq[Integer]`)
assert(pushedParquetFilters.exists(_.getClass === filterClass),
s"${pushedParquetFilters.map(_.getClass).toList} did not contain ${filterClass}.")

checker(stripSparkFilter(query), expected)

case _ =>
throw new AnalysisException("Can not match ParquetTable in the query.")
}
}
}
}

@RunWith(classOf[JUnitRunner])
class ParquetNativeFilterSuite
extends ParquetFilterSuite
@@ -310,25 +251,27 @@ abstract class ParquetFilterSuite extends QueryTest with ParquetTest with Shared
"a", // zero nesting
(x: Any) => x
),
// (
// df.withColumn("a", struct(df("temp") as "b")).drop("temp"),
// "a.b", // one level nesting
// (x: Any) => Row(x)),
// (
// df.withColumn("a", struct(struct(df("temp") as "c") as "b")).drop("temp"),
// "a.b.c", // two level nesting
// (x: Any) => Row(Row(x))
// ),
// (
// df.withColumnRenamed("temp", "a.b"),
// "`a.b`", // zero nesting with column name containing `dots`
// (x: Any) => x
// ),
// (
// df.withColumn("a.b", struct(df("temp") as "c.d") ).drop("temp"),
// "`a.b`.`c.d`", // one level nesting with column names containing `dots`
// (x: Any) => Row(x)
// )
(
df.withColumn("a", struct(df("temp") as "b")).drop("temp"),
"a.b", // one level nesting
(x: Any) => Row(x)
),

(
df.withColumn("a", struct(struct(df("temp") as "c") as "b")).drop("temp"),
"a.b.c", // two level nesting
(x: Any) => Row(Row(x))
),
(
df.withColumnRenamed("temp", "a.b"),
"`a.b`", // zero nesting with column name containing `dots`
(x: Any) => x
),
(
df.withColumn("a.b", struct(df("temp") as "c.d")).drop("temp"),
"`a.b`.`c.d`", // one level nesting with column names containing `dots`
(x: Any) => Row(x)
)
)
}

109 changes: 54 additions & 55 deletions rust/lakesoul-io/src/filter/parser.rs
@@ -2,11 +2,12 @@
//
// SPDX-License-Identifier: Apache-2.0

use arrow_schema::{DataType, Field, SchemaRef};
use arrow_schema::{DataType, Field, SchemaRef, Fields};
use datafusion::logical_expr::Expr;
use datafusion::prelude::col;
use datafusion::scalar::ScalarValue;
use std::ops::Not;
use std::sync::Arc;

pub struct Parser {}

@@ -25,50 +26,28 @@ impl Parser {
let inner = Parser::parse(right, schema);
Expr::not(inner)
} else {
let column = qualified_col_name(left.as_str(), schema.clone());
let column = datafusion::common::Column::new_unqualified(column);
match schema.column_with_name(&column.name.clone()) {
None => Expr::Literal(ScalarValue::Boolean(Some(false))),
Some((_, field)) => {
if matches!(field.data_type(), DataType::Struct(_)) {
col(column).is_not_null()
} else if right == "null" {
match op.as_str() {
"eq" => col(column).is_null(),
"noteq" => col(column).is_not_null(),
_ => Expr::Literal(ScalarValue::Boolean(Some(true))),
}
} else {
match op.as_str() {
"eq" => {
let value = Parser::parse_literal(field, right);
col(column).eq(value)
}
"noteq" => {
let value = Parser::parse_literal(field, right);
col(column).not_eq(value)
}
"gt" => {
let value = Parser::parse_literal(field, right);
col(column).gt(value)
}
"gteq" => {
let value = Parser::parse_literal(field, right);
col(column).gt_eq(value)
}
"lt" => {
let value = Parser::parse_literal(field, right);
col(column).lt(value)
}
"lteq" => {
let value = Parser::parse_literal(field, right);
col(column).lt_eq(value)
}

_ => Expr::Literal(ScalarValue::Boolean(Some(true))),
}
let expr_filed = qualified_expr(left.as_str(), schema.clone());
if let Some((expr, field)) = expr_filed {
if right == "null" {
match op.as_str() {
"eq" => expr.is_null(),
"noteq" => expr.is_not_null(),
_ => Expr::Literal(ScalarValue::Boolean(Some(true))),
                    }
                } else {
let value = Parser::parse_literal(field, right);
match op.as_str() {
"eq" => expr.eq(value),
"noteq" => expr.not_eq(value),
"gt" => expr.gt(value),
"gteq" => expr.gt_eq(value),
"lt" => expr.lt(value),
"lteq" => expr.lt_eq(value),
_ => Expr::Literal(ScalarValue::Boolean(Some(true))),
}
}
} else {
Expr::Literal(ScalarValue::Boolean(Some(false)))
}
}
}
@@ -107,7 +86,7 @@ impl Parser {
}
}

fn parse_literal(field: &Field, value: String) -> Expr {
fn parse_literal(field: Arc<Field>, value: String) -> Expr {
let data_type = field.data_type().clone();
match data_type {
DataType::Decimal128(precision, scale) => {
@@ -174,17 +153,37 @@ impl Parser {
}
}

fn qualified_col_name(column: &str, schema: SchemaRef) -> &str {
    if let Ok(_field) = schema.field_with_name(column) {
        return column;
    } else if let Some(dot) = column.find('.') {
        if let Ok(field) = schema.field_with_name(&column[..dot]) {
            if matches!(field.data_type(), DataType::Struct(_)) {
                return &column[..dot];
            }
        }
    }
    column
}
fn qualified_expr(expr_str: &str, schema: SchemaRef) -> Option<(Expr, Arc<Field>)> {
    if let Ok(field) = schema.field_with_name(expr_str) {
        Some((col(datafusion::common::Column::new_unqualified(expr_str)), Arc::new(field.clone())))
    } else {
        let mut expr: Option<(Expr, Arc<Field>)> = None;
        let mut root = "".to_owned();
        let mut sub_fields: &Fields = schema.fields();
        for expr_substr in expr_str.split('.').into_iter() {
            root = if root.is_empty() {
                expr_substr.to_owned()
            } else {
                format!("{}.{}", root, expr_substr)
            };
            if let Some((_, field)) = sub_fields.find(&root) {
                expr = if let Some((folding_exp, _)) = expr {
                    Some((folding_exp.field(field.name()), field.clone()))
                } else {
                    Some((col(datafusion::common::Column::new_unqualified(field.name())), field.clone()))
                };
                root = "".to_owned();
                sub_fields = match field.data_type() {
                    DataType::Struct(struct_sub_fields) => struct_sub_fields,
                    _ => sub_fields,
                };
            }
        }
        expr
    }
}
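The new qualified_expr resolves a dotted name by accumulating name parts from left to right and only resetting the prefix once it matches an actual field, descending into struct children when that field is a struct. This is what lets both genuinely nested columns (a.b.c) and flat columns whose names contain literal dots resolve to the right field. A self-contained sketch of that walk using only arrow_schema; the schema, column names, and resolve_path helper are illustrative, not part of the commit:

    use arrow_schema::{DataType, Field, Fields, Schema};

    // Mirror of the root-accumulation walk in qualified_expr: split on '.',
    // reset the accumulated name only when it matches a real field, and
    // descend into struct children when the matched field is a struct.
    fn resolve_path(name: &str, schema: &Schema) -> Vec<String> {
        let mut fields: Fields = schema.fields().clone();
        let mut root = String::new();
        let mut path = Vec::new();
        for part in name.split('.') {
            root = if root.is_empty() { part.to_owned() } else { format!("{}.{}", root, part) };
            if let Some((_, field)) = fields.find(&root) {
                path.push(field.name().clone());
                root.clear();
                if let DataType::Struct(children) = field.data_type() {
                    fields = children.clone();
                }
            }
        }
        path
    }

    fn main() {
        // `a` is a nested struct a.b.c; `x.y` is a flat column whose name contains a dot.
        let schema = Schema::new(vec![
            Field::new(
                "a",
                DataType::Struct(Fields::from(vec![Field::new(
                    "b",
                    DataType::Struct(Fields::from(vec![Field::new("c", DataType::Int32, true)])),
                    true,
                )])),
                true,
            ),
            Field::new("x.y", DataType::Int64, true),
        ]);
        assert_eq!(resolve_path("a.b.c", &schema), vec!["a", "b", "c"]); // two-level nesting
        assert_eq!(resolve_path("x.y", &schema), vec!["x.y"]); // dot inside the name
    }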

#[cfg(test)]
32 changes: 20 additions & 12 deletions rust/lakesoul-io/src/transform.rs
@@ -11,29 +11,37 @@ use arrow::record_batch::RecordBatch;
use arrow_array::{
new_null_array, types::*, ArrayRef, BooleanArray, PrimitiveArray, RecordBatchOptions, StringArray, StructArray,
};
use arrow_schema::{DataType, Field, Schema, SchemaBuilder, SchemaRef, TimeUnit};
use arrow_schema::{DataType, Field, Schema, SchemaBuilder, SchemaRef, TimeUnit, FieldRef, Fields};
use datafusion::error::Result;
use datafusion_common::DataFusionError::{ArrowError, External};

use crate::constant::{ARROW_CAST_OPTIONS, LAKESOUL_EMPTY_STRING, LAKESOUL_NULL_STRING};

/// adjust time zone to UTC
pub fn uniform_field(orig_field: &FieldRef) -> FieldRef {
let data_type = orig_field.data_type();
match data_type {
DataType::Timestamp(unit, Some(_)) => Arc::new(Field::new(
orig_field.name(),
DataType::Timestamp(unit.clone(), Some(Arc::from(crate::constant::LAKESOUL_TIMEZONE))),
orig_field.is_nullable(),
)),
DataType::Struct(fields) => Arc::new(Field::new(
orig_field.name(),
DataType::Struct(Fields::from(fields.iter().map(uniform_field).collect::<Vec<_>>())),
orig_field.is_nullable()
)),
_ => orig_field.clone(),
}
}

/// adjust time zone to UTC
pub fn uniform_schema(orig_schema: SchemaRef) -> SchemaRef {
Arc::new(Schema::new(
orig_schema
.fields()
.iter()
.map(|field| {
let data_type = field.data_type();
match data_type {
DataType::Timestamp(unit, Some(_)) => Arc::new(Field::new(
field.name(),
DataType::Timestamp(unit.clone(), Some(Arc::from(crate::constant::LAKESOUL_TIMEZONE))),
field.is_nullable(),
)),
_ => field.clone(),
}
})
.map(uniform_field)
.collect::<Vec<_>>(),
))
}
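Because uniform_field now recurses into struct children, timestamps nested inside structs get their time zone normalized as well, not only top-level fields. A self-contained sketch of that behavior; the "+08:00" input zone and the hard-coded "UTC" target stand in for crate::constant::LAKESOUL_TIMEZONE, which this diff does not show:

    use std::sync::Arc;
    use arrow_schema::{DataType, Field, Fields, TimeUnit};

    // Simplified stand-in for uniform_field: rewrite any zoned timestamp to UTC,
    // recursing through struct children so nested fields are covered too.
    fn to_utc(f: &Field) -> Field {
        match f.data_type() {
            DataType::Timestamp(unit, Some(_)) => Field::new(
                f.name(),
                DataType::Timestamp(unit.clone(), Some(Arc::from("UTC"))),
                f.is_nullable(),
            ),
            DataType::Struct(children) => Field::new(
                f.name(),
                DataType::Struct(children.iter().map(|c| to_utc(c.as_ref())).collect()),
                f.is_nullable(),
            ),
            _ => f.clone(),
        }
    }

    fn main() {
        // A "+08:00" timestamp nested one level inside a struct.
        let nested = Field::new(
            "event",
            DataType::Struct(Fields::from(vec![Field::new(
                "ts",
                DataType::Timestamp(TimeUnit::Microsecond, Some(Arc::from("+08:00"))),
                true,
            )])),
            true,
        );
        // The inner field comes back as Timestamp(Microsecond, Some("UTC")).
        println!("{:?}", to_utc(&nested));
    }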