From 00b86e7016bbb5c1c5caf10ede452ccec3174f29 Mon Sep 17 00:00:00 2001 From: Michael Zolotukhin Date: Sat, 5 Oct 2024 17:31:09 -0700 Subject: [PATCH] Allow customizing branch names for branches stack-pr creates. Closes #18. stack-info: PR: https://github.com/modularml/stack-pr/pull/33, branch: ZolotukhinM/stack/3 --- README.md | 1 + src/stack_pr/cli.py | 61 ++++++++++++++++++++++++++++++++++----------- 2 files changed, 47 insertions(+), 15 deletions(-) diff --git a/README.md b/README.md index d30bfba..86a9321 100644 --- a/README.md +++ b/README.md @@ -258,4 +258,5 @@ keep_body=False remote=origin target=main reviewer=GithubHandle1,GithubHandle2 +branch_name_template=$USERNAME/stack ``` diff --git a/src/stack_pr/cli.py b/src/stack_pr/cli.py index 2262a09..4eb82f8 100755 --- a/src/stack_pr/cli.py +++ b/src/stack_pr/cli.py @@ -52,6 +52,7 @@ import json import os import re +from functools import cache from subprocess import SubprocessError from stack_pr.git import ( @@ -382,12 +383,14 @@ def split_header(s: str) -> List[CommitHeader]: return [CommitHeader(h) for h in s.split("\0")[:-1]] -def is_valid_ref(ref: str) -> bool: +def is_valid_ref(ref: str, branch_name_template: str) -> bool: ref = ref.strip("'") - splits = ref.rsplit("/", 2) - if len(splits) < 3: + + branch_name_base = get_branch_name_base(branch_name_template) + splits = ref.rsplit("/", 1) + if len(splits) < 2: return False - return splits[-2] == "stack" and splits[-1].isnumeric() + return splits[-2].endswith(branch_name_base) and splits[-1].isnumeric() def last(ref: str, sep: str = "/") -> str: @@ -555,23 +558,33 @@ def add_or_update_metadata( return True -def get_available_branch_name(remote: str) -> str: +@cache +def get_branch_name_base(branch_name_template: str): username = get_gh_username() + branch_name_base = branch_name_template.replace("$USERNAME", username) + return branch_name_base + + +def get_available_branch_name(remote: str, branch_name_template: str) -> str: + branch_name_base = get_branch_name_base(branch_name_template) refs = get_command_output( [ "git", "for-each-ref", - f"refs/remotes/{remote}/{username}/stack", + f"refs/remotes/{remote}/{branch_name_base}", "--format='%(refname)'", ] ).split() - refs = list(filter(is_valid_ref, refs)) + def check_ref(ref): + return is_valid_ref(ref, branch_name_base) + + refs = list(filter(check_ref, refs)) max_ref_num = max(int(last(ref.strip("'"))) for ref in refs) if refs else 0 new_branch_id = max_ref_num + 1 - return f"{username}/stack/{new_branch_id}" + return f"{branch_name_base}/{new_branch_id}" def get_next_available_branch_name(name: str) -> str: @@ -579,19 +592,23 @@ def get_next_available_branch_name(name: str) -> str: return f"{base}/{int(id) + 1}" -def set_head_branches(st: List[StackEntry], remote: str, verbose: bool): +def set_head_branches( + st: List[StackEntry], remote: str, verbose: bool, branch_name_template: str +): """Set the head ref for each stack entry if it doesn't already have one.""" run_shell_command(["git", "fetch", "--prune", remote], quiet=not verbose) - available_name = get_available_branch_name(remote) + available_name = get_available_branch_name(remote, branch_name_template) for e in filter(lambda e: not e.has_head(), st): e.head = available_name available_name = get_next_available_branch_name(available_name) -def init_local_branches(st: List[StackEntry], remote: str, verbose: bool): +def init_local_branches( + st: List[StackEntry], remote: str, verbose: bool, branch_name_template: str +): log(h("Initializing local branches"), level=1) - set_head_branches(st, remote, verbose) + set_head_branches(st, remote, verbose, branch_name_template) for e in st: run_shell_command( ["git", "checkout", e.commit.commit_id(), "-B", e.head], @@ -785,6 +802,7 @@ class CommonArgs(NamedTuple): target: str hyperlinks: bool verbose: bool + branch_name_template: str @classmethod def from_args(cls, args: argparse.Namespace) -> "CommonArgs": @@ -795,6 +813,7 @@ def from_args(cls, args: argparse.Namespace) -> "CommonArgs": args.target, args.hyperlinks, args.verbose, + args.branch_name_template, ) @@ -822,6 +841,7 @@ def deduce_base(args: CommonArgs) -> CommonArgs: args.target, args.hyperlinks, args.verbose, + args.branch_name_template, ) @@ -876,7 +896,9 @@ def command_submit( # Create local branches and initialize base and head fields in the stack # elements - init_local_branches(st, args.remote, args.verbose) + init_local_branches( + st, args.remote, args.verbose, args.branch_name_template + ) set_base_branches(st, args.target) print_stack(st, args.hyperlinks) @@ -1137,7 +1159,9 @@ def command_abandon(args: CommonArgs): return current_branch = get_current_branch_name() - init_local_branches(st, args.remote, args.verbose) + init_local_branches( + st, args.remote, args.verbose, args.branch_name_template + ) set_base_branches(st, args.target) print_stack(st, args.hyperlinks) @@ -1219,7 +1243,7 @@ def command_view(args: CommonArgs): st = get_stack(args.base, args.head, args.verbose) - set_head_branches(st, args.remote, args.verbose) + set_head_branches(st, args.remote, args.verbose, args.branch_name_template) set_base_branches(st, args.target) print_stack(st, args.hyperlinks) print_tips_after_view(st, args) @@ -1268,6 +1292,13 @@ def create_argparser( default=config.getboolean("common", "verbose", fallback=False), help="Enable verbose output from Git subcommands.", ) + common_parser.add_argument( + "--branch-name-template", + default=config.get( + "repo", "branch_name_template", fallback="$USERNAME/stack" + ), + help="A template for names of the branches stack-pr would use.", + ) parser_submit = subparsers.add_parser( "submit",