From 78c00c38804a274a4454b486ee3cecb5684f7890 Mon Sep 17 00:00:00 2001 From: martinvuyk Date: Wed, 18 Dec 2024 12:20:02 -0300 Subject: [PATCH 1/4] micro-optimize utf8 helper functions Signed-off-by: martinvuyk --- stdlib/src/utils/string_slice.mojo | 45 ++++++++++++++++++++---------- 1 file changed, 30 insertions(+), 15 deletions(-) diff --git a/stdlib/src/utils/string_slice.mojo b/stdlib/src/utils/string_slice.mojo index c465811f1b..af33847ce0 100644 --- a/stdlib/src/utils/string_slice.mojo +++ b/stdlib/src/utils/string_slice.mojo @@ -68,8 +68,8 @@ fn _unicode_codepoint_utf8_byte_length(c: Int) -> Int: debug_assert( 0 <= c <= 0x10FFFF, "Value: ", c, " is not a valid Unicode code point" ) - alias sizes = SIMD[DType.int32, 4](0, 0b0111_1111, 0b0111_1111_1111, 0xFFFF) - return int((sizes < c).cast[DType.uint8]().reduce_add()) + alias sizes = SIMD[DType.uint32, 4](0, 0x80, 0x8_00, 0x1_00_00) + return int((sizes <= c).cast[DType.uint8]().reduce_add()) @always_inline @@ -81,10 +81,12 @@ fn _utf8_first_byte_sequence_length(b: Byte) -> Int: (b & 0b1100_0000) != 0b1000_0000, "Function does not work correctly if given a continuation byte.", ) - return int(count_leading_zeros(~b)) + int(b < 0b1000_0000) + return int(count_leading_zeros(~b) | (b < 0b1000_0000).cast[DType.uint8]()) -fn _shift_unicode_to_utf8(ptr: UnsafePointer[UInt8], c: Int, num_bytes: Int): +fn _shift_unicode_to_utf8[ + optimize_ascii: Bool = True +](ptr: UnsafePointer[UInt8], c: Int, num_bytes: Int): """Shift unicode to utf8 representation. ### Unicode (represented as UInt32 BE) to UTF-8 conversion: @@ -99,19 +101,32 @@ fn _shift_unicode_to_utf8(ptr: UnsafePointer[UInt8], c: Int, num_bytes: Int): - (a >> 18) | 0b11110000, (b >> 12) | 0b10000000, (c >> 6) | 0b10000000, d | 0b10000000 """ - if num_bytes == 1: - ptr[0] = UInt8(c) - return - var shift = 6 * (num_bytes - 1) - var mask = UInt8(0xFF) >> (num_bytes + 1) - var num_bytes_marker = UInt8(0xFF) << (8 - num_bytes) - ptr[0] = ((c >> shift) & mask) | num_bytes_marker - for i in range(1, num_bytes): - shift -= 6 - ptr[i] = ((c >> shift) & 0b0011_1111) | 0b1000_0000 + @parameter + if optimize_ascii: + if likely(num_bytes == 1): + ptr[0] = UInt8(c) + return + var shift = 6 * (num_bytes - 1) + var mask = UInt8(0xFF) >> (num_bytes + 1) + var num_bytes_marker = UInt8(0xFF) << (8 - num_bytes) + ptr[0] = ((c >> shift) & mask) | num_bytes_marker + for i in range(1, num_bytes): + shift -= 6 + ptr[i] = ((c >> shift) & 0b0011_1111) | 0b1000_0000 + else: + var shift = 6 * (num_bytes - 1) + var mask = UInt8(0xFF) >> (num_bytes + int(num_bytes > 1)) + var num_bytes_marker = UInt8(0xFF) << (8 - num_bytes) + ptr[0] = ((c >> shift) & mask) | ( + num_bytes_marker & (int(num_bytes == 1) - 1) + ) + for i in range(1, num_bytes): + shift -= 6 + ptr[i] = ((c >> shift) & 0b0011_1111) | 0b1000_0000 +@always_inline fn _utf8_byte_type(b: SIMD[DType.uint8, _], /) -> __type_of(b): """UTF-8 byte type. @@ -126,7 +141,7 @@ fn _utf8_byte_type(b: SIMD[DType.uint8, _], /) -> __type_of(b): - 3 -> start of 3 byte long sequence. - 4 -> start of 4 byte long sequence. """ - return count_leading_zeros(~(b & UInt8(0b1111_0000))) + return count_leading_zeros(~b) @always_inline From 0a439e776a5a4e494df46ac11e483416e8595d75 Mon Sep 17 00:00:00 2001 From: martinvuyk Date: Wed, 18 Dec 2024 13:59:57 -0300 Subject: [PATCH 2/4] fix chr implementation and add a testcase that would previously fail Signed-off-by: martinvuyk --- stdlib/src/collections/string.mojo | 6 +----- stdlib/test/collections/test_string.mojo | 1 + 2 files changed, 2 insertions(+), 5 deletions(-) diff --git a/stdlib/src/collections/string.mojo b/stdlib/src/collections/string.mojo index 06cd2afeb5..7a1816c82f 100644 --- a/stdlib/src/collections/string.mojo +++ b/stdlib/src/collections/string.mojo @@ -132,15 +132,11 @@ fn chr(c: Int) -> String: Examples: ```mojo - print(chr(97)) # "a" - print(chr(8364)) # "€" + print(chr(97), chr(8364)) # "a €" ``` . """ - if c < 0b1000_0000: # 1 byte ASCII char - return String(String._buffer_type(c, 0)) - var num_bytes = _unicode_codepoint_utf8_byte_length(c) var p = UnsafePointer[UInt8].alloc(num_bytes + 1) _shift_unicode_to_utf8(p, c, num_bytes) diff --git a/stdlib/test/collections/test_string.mojo b/stdlib/test/collections/test_string.mojo index 4d9151b279..59fdfa1e9c 100644 --- a/stdlib/test/collections/test_string.mojo +++ b/stdlib/test/collections/test_string.mojo @@ -326,6 +326,7 @@ def test_ord(): def test_chr(): + assert_equal("\0", chr(0)) assert_equal("A", chr(65)) assert_equal("a", chr(97)) assert_equal("!", chr(33)) From dce95e45d01ca3f95482668aa175d32c81f5d3d0 Mon Sep 17 00:00:00 2001 From: martinvuyk Date: Thu, 19 Dec 2024 19:57:09 -0300 Subject: [PATCH 3/4] remove use of subtraction Signed-off-by: martinvuyk --- stdlib/src/utils/string_slice.mojo | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/stdlib/src/utils/string_slice.mojo b/stdlib/src/utils/string_slice.mojo index af33847ce0..2513e53f8d 100644 --- a/stdlib/src/utils/string_slice.mojo +++ b/stdlib/src/utils/string_slice.mojo @@ -119,7 +119,7 @@ fn _shift_unicode_to_utf8[ var mask = UInt8(0xFF) >> (num_bytes + int(num_bytes > 1)) var num_bytes_marker = UInt8(0xFF) << (8 - num_bytes) ptr[0] = ((c >> shift) & mask) | ( - num_bytes_marker & (int(num_bytes == 1) - 1) + num_bytes_marker & -int(num_bytes != 1) ) for i in range(1, num_bytes): shift -= 6 From 8e254a267724bf660fc347115bc429d81c34f8d0 Mon Sep 17 00:00:00 2001 From: martinvuyk Date: Thu, 26 Dec 2024 09:38:33 -0300 Subject: [PATCH 4/4] add parameter docstring Signed-off-by: martinvuyk --- stdlib/src/utils/string_slice.mojo | 3 +++ 1 file changed, 3 insertions(+) diff --git a/stdlib/src/utils/string_slice.mojo b/stdlib/src/utils/string_slice.mojo index 2513e53f8d..728d001d77 100644 --- a/stdlib/src/utils/string_slice.mojo +++ b/stdlib/src/utils/string_slice.mojo @@ -89,6 +89,9 @@ fn _shift_unicode_to_utf8[ ](ptr: UnsafePointer[UInt8], c: Int, num_bytes: Int): """Shift unicode to utf8 representation. + Parameters: + optimize_ascii: Optimize for languages with mostly ASCII characters. + ### Unicode (represented as UInt32 BE) to UTF-8 conversion: - 1: 00000000 00000000 00000000 0aaaaaaa -> 0aaaaaaa - a