[TOSA] : Fix float to integer cast for torch.ops.aten.to lowering. (#3946)

The behavior of the float -> integer cast in PyTorch (though I haven't found
the actual code implementing the cast) appears to be, judging by the
results PyTorch produces:

1. round the float toward zero (as `arith.fptosi/ui` does), then
2. perform the conversion.

Currently we emit only a `tosa.cast` for this operation, but per the spec
https://www.mlplatform.org/tosa/tosa_spec.html#_cast the rounding
performed for float -> integer is round-to-nearest, not round-toward-zero.
Hence, the current TOSA lowering for `torch.ops.aten.to` produces
incorrect results.
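
As a quick illustration (a minimal sketch, assuming a recent PyTorch; the values are chosen only so that the two rounding modes disagree):

```python
import torch

x = torch.tensor([-2.7, -1.2, 1.2, 2.7])

# PyTorch's float -> int cast truncates toward zero:
print(x.to(torch.int64))               # tensor([-2, -1,  1,  2])

# Round-to-nearest (what the TOSA spec prescribes for tosa.cast) differs:
print(torch.round(x).to(torch.int64))  # tensor([-3, -1,  1,  3])
```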
sahas3 authored Jan 22, 2025
1 parent 9dd94fb commit 481da8d
Showing 4 changed files with 87 additions and 11 deletions.
20 changes: 19 additions & 1 deletion lib/Conversion/TorchToTosa/TosaLegalizeUtils.cpp
@@ -336,7 +336,8 @@ std::optional<Value> getConstTensor<float>(PatternRewriter &rewriter,
LogicalResult tosaCastTensorToType(PatternRewriter &rewriter, Operation *op,
Value src, Type destType, Value &result) {

- Type srcElemTy = dyn_cast<TensorType>(src.getType()).getElementType();
+ TensorType srcType = dyn_cast<TensorType>(src.getType());
+ Type srcElemTy = srcType.getElementType();
Type destElemTy = dyn_cast<TensorType>(destType).getElementType();

// Temporarily disable checkValidityOfCast as it's currently strictly
@@ -381,6 +382,23 @@ LogicalResult tosaCastTensorToType(PatternRewriter &rewriter, Operation *op,
result = rewriter.create<tosa::LogicalNotOp>(op->getLoc(), destType,
equalToZero);
} else {
if (llvm::isa<FloatType>(srcElemTy) && destElemTy.isInteger()) {
// For float -> int conversion, tosa.cast rounds to the nearest integer,
// whereas torch rounds toward zero. Emit a round-toward-zero (ceil for
// negative values, floor otherwise) ahead of the tosa.cast to match the
// expected torch behavior.
auto floor = rewriter.create<tosa::FloorOp>(op->getLoc(), srcType, src);
auto ceil = rewriter.create<tosa::CeilOp>(op->getLoc(), srcType, src);

auto zeroValue =
tosa::getConstTensor<float>(rewriter, op, 0, {}, srcElemTy).value();

auto boolType = srcType.clone(rewriter.getIntegerType(1));
auto isNegative = tosa::CreateOpAndInfer<tosa::GreaterOp>(
rewriter, op->getLoc(), boolType, zeroValue, src);
src = tosa::CreateOpAndInfer<tosa::SelectOp>(
rewriter, op->getLoc(), srcType, isNegative, ceil, floor);
}
result = rewriter.create<tosa::CastOp>(op->getLoc(), destType, src);
}
return success();
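The ceil/floor select added above implements truncation toward zero. A minimal NumPy sketch of the same trick (illustrative names, not part of the patch):

```python
import numpy as np

def round_toward_zero(x: np.ndarray) -> np.ndarray:
    # Negative values round up (ceil), non-negative values round down
    # (floor); together that is truncation toward zero, i.e. np.trunc.
    return np.where(x < 0, np.ceil(x), np.floor(x))

x = np.array([-2.7, -1.2, 1.2, 2.7])
assert (round_toward_zero(x) == np.trunc(x)).all()
```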
16 changes: 6 additions & 10 deletions projects/pt1/e2e_testing/xfail_sets.py
@@ -1720,6 +1720,8 @@
"TriuIndicesNegativeOffsetModule_basic",
"BmmFloat16Module_basic",
"ElementwiseRreluWithNoiseTrainStaticModule_basic",
"LinspaceDtypeModule_basic",
"Aten_CastLongModule_basic",
"Unfold_Module_Rank_4",
"Unfold_Module_Rank_Zero_basic",
"Unfold_Module_basic",
@@ -2627,6 +2629,7 @@
}

ONNX_XFAIL_SET = {
"ToDtypeIntFromFloatModule_basic",
# This test is expected to time out
"TimeOutModule_basic",
# Failure - cast error
@@ -3333,6 +3336,7 @@
}

FX_IMPORTER_TOSA_XFAIL_SET = {
"ScatterAddDynamicModule_basic",
"UniformModule_basic",
"UniformStaticShapeModule_basic",
"AtenFftRfft2DLastDim_basic",
@@ -3444,7 +3448,6 @@
"AtenSubFloatModule_basic",
"AtenTopKModule_basic",
"AtenTopKSmallestModule_basic",
"Aten_CastLongModule_basic",
"Aten_EmbeddingBagExample_basic",
"AvgPool1dFloatModule_basic",
"AvgPool1dIntModule_basic",
@@ -3501,7 +3504,6 @@
"ConvolutionModule2DTransposeStridedStatic_basic",
"ConvolutionModule2DTransposeStrided_basic",
"ConvolutionModule2DTranspose_basic",
"CopyWithDifferentDTypesModule_basic",
"CumsumInputDtypeInt32Module_basic",
"CumsumModule_basic",
"CumsumStaticModule_basic",
@@ -3544,7 +3546,6 @@
"ElementwiseQuantizePerTensorUIntModule_basic",
"ElementwiseSinhIntModule_basic",
"ElementwiseSinhModule_basic",
"ElementwiseToDtypeF32ToI64Module_basic",
"ElementwiseToDtypeI64ToUI8Module_basic",
"EqIntModule_basic",
"FloatImplicitModule_basic",
@@ -3577,16 +3578,13 @@
"IndexPutImpl2DNoneIndexStaticModule_basic",
"IndexPutImpl3DFloatAccumulateModule_basic",
"IndexPutImplIndexWithNoneModule_basic",
"InterpolateDynamicModule_sizes_bilinear",
"InterpolateDynamicModule_scales_recompute_bilinear",
"IntFloatModule_basic",
"IntImplicitModule_basic",
"IsFloatingPointFloat_True",
"IsFloatingPointInt_False",
"LenStrModule_basic",
"LinalgNormKeepDimComplexModule_basic",
"LinalgVectorNormComplexModule_basic",
"LinspaceDtypeModule_basic",
"LinspaceEmptyModule_basic",
"MaskedScatterStaticBasic_basic",
"MaxPool1dCeilModeTrueModule_basic",
@@ -3649,7 +3647,6 @@
"PrimMaxIntModule_basic",
"PrimMinIntDynamicModule_basic",
"PrimMinIntModule_basic",
"PrimsConvertElementTypeModule_basic",
"PrimsSqueezeEmptyDimensionsModule_basic",
"PrimsSqueezeModule_basic",
"PrimsViewOfModule_basic",
@@ -3734,8 +3731,6 @@
"TensorToInt_basic",
"TestMultipleTensorAndPrimitiveTypesReturn_basic",
"ThresholdBackward2dMixedModule_basic",
"ToCopyWithDTypeFalsePinMemoryModule_basic",
"ToCopyWithDTypeModule_basic",
"TorchPrimLoopForLikeModule_basic",
"TorchPrimLoopWhileLikeModule_basic",
"TraceModule_empty",
@@ -4002,7 +3997,6 @@
"AtenTriuModule_basic",
"AtenTriuWithNegDiagonalModule_basic",
"AtenTriuWithPosDiagonalModule_basic",
"Aten_CastLongModule_basic",
"Aten_EmbeddingBagExample_basic",
"AvgPool1dFloatModule_basic",
"AvgPool1dIntModule_basic",
@@ -4717,6 +4711,8 @@
"ToDtypeLayoutCPUModule_basic",
"ToDtypeLayoutNoneModule_basic",
"ToDtypeLayoutStridedModule_basic",
"ToDtypeIntFromFloatModule_basic",
"ToDtypeFloatFromIntModule_basic",
"TorchPrimLoopForLikeModule_basic",
"TorchPrimLoopWhileLikeModule_basic",
"TraceModule_basic",
39 changes: 39 additions & 0 deletions projects/pt1/python/torch_mlir_e2e_test/test_suite/type_conversion.py
@@ -255,6 +255,45 @@ def ToDtypeBoolLayoutNoneStaticModule_basic(module, tu: TestUtils):
module.forward(tu.randint(3, 5))


class ToDtypeFloatFromIntModule(torch.nn.Module):
def __init__(self):
super().__init__()

@export
@annotate_args([None, ([-1, -1], torch.int64, True)])
def forward(self, x):
return torch.ops.aten.to(
x,
dtype=torch.float32,
)


@register_test_case(module_factory=lambda: ToDtypeFloatFromIntModule())
def ToDtypeFloatFromIntModule_basic(module, tu: TestUtils):
input = torch.randint(low=-5, high=5, size=(2, 2)).to(torch.int64)
module.forward(input)


class ToDtypeIntFromFloatModule(torch.nn.Module):
def __init__(self):
super().__init__()

@export
@annotate_args([None, ([-1, -1], torch.float64, True)])
def forward(self, x):
return torch.ops.aten.to(
x,
dtype=torch.int64,
)


@register_test_case(module_factory=lambda: ToDtypeIntFromFloatModule())
def ToDtypeIntFromFloatModule_basic(module, tu: TestUtils):
input = tu.rand(2, 2, low=-5, high=5)
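# Give one element a fractional part of .7 so that truncation toward
# zero and round-to-nearest produce different results.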
input[1][1] = tu.randint(1, 1) + 0.7
module.forward(input)


class TypeAsSameModule(torch.nn.Module):
def __init__(self):
super().__init__()
23 changes: 23 additions & 0 deletions test/Conversion/TorchToTosa/basic.mlir
@@ -1022,6 +1022,29 @@ func.func @torch.aten.to.dtype(%arg0: !torch.vtensor<[1,128],i1>) -> !torch.vtensor<[1,128],si64> {
return %0 : !torch.vtensor<[1,128],si64>
}

// -----
// CHECK-LABEL: func.func @torch.aten.to.dtype$floatToInt(
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[3,5],f32>) -> !torch.vtensor<[3,5],si64> {
// CHECK: %[[TENSOR:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[3,5],f32> -> tensor<3x5xf32>
// CHECK: %[[INT4:.*]] = torch.constant.int 4
// CHECK: %[[FALSE:.*]] = torch.constant.bool false
// CHECK: %[[NONE:.*]] = torch.constant.none
// CHECK: %[[FLOOR:.*]] = tosa.floor %[[TENSOR]] : (tensor<3x5xf32>) -> tensor<3x5xf32>
// CHECK: %[[CEIL:.*]] = tosa.ceil %[[TENSOR]] : (tensor<3x5xf32>) -> tensor<3x5xf32>
// CHECK: %[[F0:.*]] = "tosa.const"() <{value = dense<0.000000e+00> : tensor<f32>}> : () -> tensor<f32>
// CHECK: %[[IS_NEG:.*]] = tosa.greater %[[F0]], %[[TENSOR]] : (tensor<f32>, tensor<3x5xf32>) -> tensor<3x5xi1>
// CHECK: %[[SELECT:.*]] = tosa.select %[[IS_NEG]], %[[CEIL]], %[[FLOOR]] : (tensor<3x5xi1>, tensor<3x5xf32>, tensor<3x5xf32>) -> tensor<3x5xf32>
// CHECK: %[[CAST:.*]] = tosa.cast %[[SELECT]] : (tensor<3x5xf32>) -> tensor<3x5xi64>
// CHECK: %[[RES:.*]] = torch_c.from_builtin_tensor %[[CAST]] : tensor<3x5xi64> -> !torch.vtensor<[3,5],si64>
// CHECK: return %[[RES]] : !torch.vtensor<[3,5],si64>
func.func @torch.aten.to.dtype$floatToInt(%arg0: !torch.vtensor<[3,5],f32>) -> !torch.vtensor<[3,5],si64> {
%int4 = torch.constant.int 4
%false = torch.constant.bool false
%none = torch.constant.none
%0 = torch.aten.to.dtype %arg0, %int4, %false, %false, %none : !torch.vtensor<[3,5],f32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[3,5],si64>
return %0 : !torch.vtensor<[3,5],si64>
}

// -----
// CHECK-LABEL: func.func @torch.aten.gather(
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[1,4,3],f32>,
