Skip to content

Commit

Permalink
fix: decimal conversion loses value on lower precision (#6836)
Browse files Browse the repository at this point in the history
* decimal conversion loses value on lower precision; it now throws an error on overflow.

* fix review comments and fix formatting.

* for the simple case of equal scale and larger precision, no conversion is needed.

revert whitespace changes

formatting check

---------

Co-authored-by: himadripal <[email protected]>
  • Loading branch information
himadripal and himadripal authored Dec 12, 2024
1 parent 84dba34 commit eb7ab83
Show file tree
Hide file tree
Showing 2 changed files with 116 additions and 32 deletions.
49 changes: 28 additions & 21 deletions arrow-cast/src/cast/decimal.rs
Original file line number Diff line number Diff line change
Expand Up @@ -111,9 +111,13 @@ where
O::Native::from_decimal(adjusted)
};

Ok(match cast_options.safe {
true => array.unary_opt(f),
false => array.try_unary(|x| f(x).ok_or_else(|| error(x)))?,
Ok(if cast_options.safe {
array.unary_opt(|x| f(x).filter(|v| O::is_valid_decimal_precision(*v, output_precision)))
} else {
array.try_unary(|x| {
f(x).ok_or_else(|| error(x))
.and_then(|v| O::validate_decimal_precision(v, output_precision).map(|_| v))
})?
})
}

Expand All @@ -137,15 +141,20 @@ where

let f = |x| O::Native::from_decimal(x).and_then(|x| x.mul_checked(mul).ok());

Ok(match cast_options.safe {
true => array.unary_opt(f),
false => array.try_unary(|x| f(x).ok_or_else(|| error(x)))?,
Ok(if cast_options.safe {
array.unary_opt(|x| f(x).filter(|v| O::is_valid_decimal_precision(*v, output_precision)))
} else {
array.try_unary(|x| {
f(x).ok_or_else(|| error(x))
.and_then(|v| O::validate_decimal_precision(v, output_precision).map(|_| v))
})?
})
}

// Only support one type of decimal cast operations
pub(crate) fn cast_decimal_to_decimal_same_type<T>(
array: &PrimitiveArray<T>,
input_precision: u8,
input_scale: i8,
output_precision: u8,
output_scale: i8,
Expand All @@ -155,29 +164,27 @@ where
T: DecimalType,
T::Native: DecimalCast + ArrowNativeTypeOp,
{
let array: PrimitiveArray<T> = match input_scale.cmp(&output_scale) {
Ordering::Equal => {
// the scale doesn't change, the native value don't need to be changed
let array: PrimitiveArray<T> =
if input_scale == output_scale && input_precision <= output_precision {
array.clone()
}
Ordering::Greater => convert_to_smaller_scale_decimal::<T, T>(
array,
input_scale,
output_precision,
output_scale,
cast_options,
)?,
Ordering::Less => {
// input_scale < output_scale
} else if input_scale < output_scale {
// the output scale is larger here, so values are rescaled upward and may overflow the output precision
convert_to_bigger_or_equal_scale_decimal::<T, T>(
array,
input_scale,
output_precision,
output_scale,
cast_options,
)?
}
};
} else {
convert_to_smaller_scale_decimal::<T, T>(
array,
input_scale,
output_precision,
output_scale,
cast_options,
)?
};

Ok(Arc::new(array.with_precision_and_scale(
output_precision,
Expand Down
99 changes: 88 additions & 11 deletions arrow-cast/src/cast/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -830,18 +830,20 @@ pub fn cast_with_options(
(Map(_, ordered1), Map(_, ordered2)) if ordered1 == ordered2 => {
cast_map_values(array.as_map(), to_type, cast_options, ordered1.to_owned())
}
(Decimal128(_, s1), Decimal128(p2, s2)) => {
(Decimal128(p1, s1), Decimal128(p2, s2)) => {
cast_decimal_to_decimal_same_type::<Decimal128Type>(
array.as_primitive(),
*p1,
*s1,
*p2,
*s2,
cast_options,
)
}
(Decimal256(_, s1), Decimal256(p2, s2)) => {
(Decimal256(p1, s1), Decimal256(p2, s2)) => {
cast_decimal_to_decimal_same_type::<Decimal256Type>(
array.as_primitive(),
*p1,
*s1,
*p2,
*s2,
Expand Down Expand Up @@ -2694,13 +2696,16 @@ mod tests {
// negative test
let array = vec![Some(123456), None];
let array = create_decimal_array(array, 10, 0).unwrap();
let result = cast(&array, &DataType::Decimal128(2, 2));
assert!(result.is_ok());
let array = result.unwrap();
let array: &Decimal128Array = array.as_primitive();
let err = array.validate_decimal_precision(2);
let result_safe = cast(&array, &DataType::Decimal128(2, 2));
assert!(result_safe.is_ok());
let options = CastOptions {
safe: false,
..Default::default()
};

let result_unsafe = cast_with_options(&array, &DataType::Decimal128(2, 2), &options);
assert_eq!("Invalid argument error: 12345600 is too large to store in a Decimal128 of precision 2. Max is 99",
err.unwrap_err().to_string());
result_unsafe.unwrap_err().to_string());
}

#[test]
Expand Down Expand Up @@ -8460,7 +8465,7 @@ mod tests {
let input_type = DataType::Decimal128(10, 3);
let output_type = DataType::Decimal256(10, 5);
assert!(can_cast_types(&input_type, &output_type));
let array = vec![Some(i128::MAX), Some(i128::MIN)];
let array = vec![Some(123456), Some(-123456)];
let input_decimal_array = create_decimal_array(array, 10, 3).unwrap();
let array = Arc::new(input_decimal_array) as ArrayRef;

Expand All @@ -8470,8 +8475,8 @@ mod tests {
Decimal256Array,
&output_type,
vec![
Some(i256::from_i128(i128::MAX).mul_wrapping(hundred)),
Some(i256::from_i128(i128::MIN).mul_wrapping(hundred))
Some(i256::from_i128(123456).mul_wrapping(hundred)),
Some(i256::from_i128(-123456).mul_wrapping(hundred))
]
);
}
Expand Down Expand Up @@ -9935,4 +9940,76 @@ mod tests {
"Cast non-nullable to non-nullable struct field returning null should fail",
);
}

/// Casting `Decimal128(24, 2)` to `Decimal128(6, 2)` with `safe: false` must
/// return an error when the value does not fit in the narrower precision.
#[test]
fn test_decimal_to_decimal_throw_error_on_precision_overflow_same_scale() {
    let input_type = DataType::Decimal128(24, 2);
    let output_type = DataType::Decimal128(6, 2);
    assert!(can_cast_types(&input_type, &output_type));

    let array = create_decimal_array(vec![Some(123456789)], 24, 2).unwrap();

    // Non-safe mode: an out-of-range value is reported as an error.
    let options = CastOptions {
        safe: false,
        ..Default::default()
    };
    let result = cast_with_options(&array, &output_type, &options);

    // NOTE(review): the message reports 123456790 although the input is
    // 123456789 and the scale is unchanged. Presumably the rescale path
    // rounds up when the scale delta is zero; confirm whether this is intended.
    assert_eq!(
        result.unwrap_err().to_string(),
        "Invalid argument error: 123456790 is too large to store in a Decimal128 of precision 6. Max is 999999"
    );
}

/// Casting `Decimal128(24, 4)` to `Decimal128(6, 2)` with `safe: false` must
/// return an error when the down-scaled value does not fit in precision 6.
#[test]
fn test_decimal_to_decimal_throw_error_on_precision_overflow_lower_scale() {
    let input_type = DataType::Decimal128(24, 4);
    let output_type = DataType::Decimal128(6, 2);
    assert!(can_cast_types(&input_type, &output_type));

    // Build the array with the same scale as `input_type` so the cast really
    // goes from scale 4 down to scale 2.
    let array = create_decimal_array(vec![Some(123456789)], 24, 4).unwrap();

    // Non-safe mode: an out-of-range value is reported as an error.
    let options = CastOptions {
        safe: false,
        ..Default::default()
    };
    let result = cast_with_options(&array, &output_type, &options);

    // 123456789 at scale 4 becomes 1234568 at scale 2
    // (assumes the down-scaling path rounds half away from zero; verify),
    // which still exceeds the six-digit maximum.
    assert_eq!(
        result.unwrap_err().to_string(),
        "Invalid argument error: 1234568 is too large to store in a Decimal128 of precision 6. Max is 999999"
    );
}

/// Casting `Decimal128(24, 2)` to `Decimal128(6, 3)` with `safe: false` must
/// return an error when the up-scaled value does not fit in precision 6.
#[test]
fn test_decimal_to_decimal_throw_error_on_precision_overflow_greater_scale() {
    let input_type = DataType::Decimal128(24, 2);
    let output_type = DataType::Decimal128(6, 3);
    assert!(can_cast_types(&input_type, &output_type));

    let array = create_decimal_array(vec![Some(123456789)], 24, 2).unwrap();

    // Non-safe mode: an out-of-range value is reported as an error.
    let options = CastOptions {
        safe: false,
        ..Default::default()
    };
    let result = cast_with_options(&array, &output_type, &options);

    // Going from scale 2 to scale 3 multiplies the stored value by ten.
    assert_eq!(
        result.unwrap_err().to_string(),
        "Invalid argument error: 1234567890 is too large to store in a Decimal128 of precision 6. Max is 999999"
    );
}

/// Casting `Decimal128(24, 2)` to `Decimal256(6, 2)` with `safe: false` must
/// return an error when the value does not fit in the narrower precision.
#[test]
fn test_decimal_to_decimal_throw_error_on_precision_overflow_diff_type() {
    let input_type = DataType::Decimal128(24, 2);
    let output_type = DataType::Decimal256(6, 2);
    assert!(can_cast_types(&input_type, &output_type));

    let array = create_decimal_array(vec![Some(123456789)], 24, 2).unwrap();

    // Non-safe mode: an out-of-range value is reported as an error.
    let options = CastOptions {
        safe: false,
        ..Default::default()
    };
    let result = cast_with_options(&array, &output_type, &options);

    assert_eq!(
        result.unwrap_err().to_string(),
        "Invalid argument error: 123456789 is too large to store in a Decimal256 of precision 6. Max is 999999"
    );
}
}

0 comments on commit eb7ab83

Please sign in to comment.