Skip to content

Commit

Permalink
Add a new validate parameter to the b64decode() function
Browse files Browse the repository at this point in the history
Signed-off-by: Manuel Saelices <[email protected]>
  • Loading branch information
msaelices committed Jan 6, 2025
1 parent 9137d83 commit 565284c
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 7 deletions.
19 changes: 13 additions & 6 deletions stdlib/src/base64/base64.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -112,17 +112,24 @@ fn b64encode(input_bytes: List[UInt8, _]) -> String:


@always_inline
fn b64decode(str: String) -> String:
fn b64decode[validate: Bool = False](str: String) raises -> String:
"""Performs base64 decoding on the input string.
Parameters:
validate: If true, the function will validate the input string.
Args:
str: A base64 encoded string.
Returns:
The decoded string.
"""
var n = str.byte_length()
debug_assert(n % 4 == 0, "Input length must be divisible by 4")

@parameter
if validate:
if n % 4 != 0:
raise Error("ValueError: Input length must be divisible by 4")

var p = String._buffer_type(capacity=n + 1)

Expand All @@ -133,10 +140,10 @@ fn b64decode(str: String) -> String:
var c = _ascii_to_value(str[i + 2])
var d = _ascii_to_value(str[i + 3])

debug_assert(
a >= 0 and b >= 0 and c >= 0 and d >= 0,
"Unexpected character encountered",
)
@parameter
if validate:
if a < 0 or b < 0 or c < 0 or d < 0:
raise Error("ValueError: Unexpected character encountered")

p.append((a << 2) | (b >> 4))
if str[i + 2] == "=":
Expand Down
10 changes: 9 additions & 1 deletion stdlib/test/base64/test_base64.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

from base64 import b16decode, b16encode, b64decode, b64encode

from testing import assert_equal
from testing import assert_equal, assert_raises


def test_b64encode():
Expand Down Expand Up @@ -60,6 +60,14 @@ def test_b64decode():

assert_equal(b64decode("QUJDREVGYWJjZGVm"), "ABCDEFabcdef")

with assert_raises(
contains="ValueError: Input length must be divisible by 4"
):
_ = b64decode[validate=True]("invalid base64 string")

with assert_raises(contains="ValueError: Unexpected character encountered"):
_ = b64decode[validate=True]("invalid base64 string!!!")


def test_b16encode():
assert_equal(b16encode("a"), "61")
Expand Down

0 comments on commit 565284c

Please sign in to comment.