Skip to content

Commit

Permalink
Merge pull request #647 from rben01/compare-nan
Browse files Browse the repository at this point in the history
Fixed NaN comparison
  • Loading branch information
sharkdp authored Nov 8, 2024
2 parents 382128c + 9d9f379 commit 1eb15a6
Show file tree
Hide file tree
Showing 6 changed files with 81 additions and 21 deletions.
4 changes: 2 additions & 2 deletions numbat/src/interpreter/assert_eq.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use compact_str::{format_compact, CompactString};
use std::fmt::Display;
use thiserror::Error;

#[derive(Debug, Clone, Error, PartialEq, Eq)]
#[derive(Debug, Clone, Error, PartialEq)]
pub struct AssertEq2Error {
pub span_lhs: Span,
pub lhs: Value,
Expand All @@ -28,7 +28,7 @@ impl Display for AssertEq2Error {
}
}

#[derive(Debug, Clone, Error, PartialEq, Eq)]
#[derive(Debug, Clone, Error, PartialEq)]
pub struct AssertEq3Error {
pub span_lhs: Span,
pub lhs_original: Quantity,
Expand Down
4 changes: 2 additions & 2 deletions numbat/src/interpreter/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ use thiserror::Error;

pub use crate::value::Value;

#[derive(Debug, Clone, Error, PartialEq, Eq)]
#[derive(Debug, Clone, Error, PartialEq)]
#[allow(clippy::large_enum_variant)]
pub enum RuntimeError {
#[error("Division by zero")]
Expand Down Expand Up @@ -67,7 +67,7 @@ pub enum RuntimeError {
FileWrite(std::path::PathBuf),
}

#[derive(Debug, PartialEq, Eq)]
#[derive(Debug, PartialEq)]
#[must_use]
pub enum InterpreterResult {
Value(Value),
Expand Down
27 changes: 25 additions & 2 deletions numbat/src/quantity.rs
Original file line number Diff line number Diff line change
Expand Up @@ -332,8 +332,6 @@ impl PartialEq for Quantity {
}
}

impl Eq for Quantity {}

impl PartialOrd for Quantity {
fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
let other_converted = other.convert_to(self.unit()).ok()?;
Expand All @@ -347,7 +345,32 @@ impl PrettyPrint for Quantity {
}
}

pub(crate) enum QuantityOrdering {
IncompatibleUnits,
NanOperand,
Ok(std::cmp::Ordering),
}

impl Quantity {
/// partial_cmp that encodes whether comparison fails because its arguments have
/// incompatible units, or because one of them is NaN
pub(crate) fn partial_cmp_preserve_nan(&self, other: &Self) -> QuantityOrdering {
if self.value.to_f64().is_nan() || other.value.to_f64().is_nan() {
return QuantityOrdering::NanOperand;
}

let Ok(other_converted) = other.convert_to(self.unit()) else {
return QuantityOrdering::IncompatibleUnits;
};

let cmp = self
.value
.partial_cmp(&other_converted.value)
.expect("unexpectedly got a None partial_cmp from non-NaN arguments");

QuantityOrdering::Ok(cmp)
}

/// Pretty prints with the given options.
/// If options is None, default options will be used.
fn pretty_print_with_options(&self, options: Option<FmtFloatConfig>) -> crate::markup::Markup {
Expand Down
2 changes: 1 addition & 1 deletion numbat/src/value.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ impl std::fmt::Display for FunctionReference {
}
}

#[derive(Debug, Clone, PartialEq, Eq)]
#[derive(Debug, Clone, PartialEq)]
pub enum Value {
Quantity(Quantity),
Boolean(bool),
Expand Down
37 changes: 23 additions & 14 deletions numbat/src/vm.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use std::collections::{HashMap, VecDeque};
use std::fmt::Display;
use std::sync::Arc;
use std::{cmp::Ordering, fmt::Display};

use compact_str::{CompactString, ToCompactString};
use indexmap::IndexMap;
Expand Down Expand Up @@ -753,22 +753,31 @@ impl Vm {
self.push(ret);
}
op @ (Op::LessThan | Op::GreaterThan | Op::LessOrEqual | Op::GreatorOrEqual) => {
use crate::quantity::QuantityOrdering;
use std::cmp::Ordering;

let rhs = self.pop_quantity();
let lhs = self.pop_quantity();

let result = lhs.partial_cmp(&rhs).ok_or_else(|| {
RuntimeError::QuantityError(QuantityError::IncompatibleUnits(
lhs.unit().clone(),
rhs.unit().clone(),
))
})?;

let result = match op {
Op::LessThan => result == Ordering::Less,
Op::GreaterThan => result == Ordering::Greater,
Op::LessOrEqual => result != Ordering::Greater,
Op::GreatorOrEqual => result != Ordering::Less,
_ => unreachable!(),
let result = match lhs.partial_cmp_preserve_nan(&rhs) {
QuantityOrdering::IncompatibleUnits => {
return Err(Box::new(RuntimeError::QuantityError(
QuantityError::IncompatibleUnits(
lhs.unit().clone(),
rhs.unit().clone(),
),
)))
}
QuantityOrdering::NanOperand => false,
QuantityOrdering::Ok(Ordering::Less) => {
matches!(op, Op::LessThan | Op::LessOrEqual)
}
QuantityOrdering::Ok(Ordering::Equal) => {
matches!(op, Op::LessOrEqual | Op::GreatorOrEqual)
}
QuantityOrdering::Ok(Ordering::Greater) => {
matches!(op, Op::GreaterThan | Op::GreatorOrEqual)
}
};

self.push(Value::Boolean(result));
Expand Down
28 changes: 28 additions & 0 deletions numbat/tests/interpreter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -575,6 +575,34 @@ fn test_comparisons() {
expect_output("2 >= 2", "true");
expect_output("2 >= 2.1", "false");

// NaN comparison; all false

expect_output("NaN < NaN", "false");
expect_output("NaN < 0", "false");
expect_output("NaN < 0m", "false");
expect_output("0 < NaN", "false");
expect_output("0m < NaN", "false");

expect_output("NaN <= NaN", "false");
expect_output("NaN <= 0", "false");
expect_output("NaN <= 0m", "false");
expect_output("0 <= NaN", "false");
expect_output("0m <= NaN", "false");

expect_output("NaN > NaN", "false");
expect_output("NaN > 0", "false");
expect_output("NaN > 0m", "false");
expect_output("0 > NaN", "false");
expect_output("0m > NaN", "false");

expect_output("NaN >= NaN", "false");
expect_output("NaN >= 0", "false");
expect_output("NaN >= 0m", "false");
expect_output("0 >= NaN", "false");
expect_output("0m >= NaN", "false");

// equality

expect_output("200 cm == 2 m", "true");
expect_output("201 cm == 2 m", "false");

Expand Down

0 comments on commit 1eb15a6

Please sign in to comment.