Skip to content

Commit

Permalink
Add cluster_counts and balance_factor params to cuvs kmeans properly
Browse files Browse the repository at this point in the history
  • Loading branch information
jacketsj committed Oct 11, 2024
1 parent 408e4a1 commit 9493865
Showing 1 changed file with 3 additions and 0 deletions.
3 changes: 3 additions & 0 deletions python/python/lance/cuvs/kmeans.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ def __init__(
device: Optional[str] = None,
itopk_size: int = 10,
balance_factor: Optional[float] = None,
cluster_counts: Optional[torch.Tensor] = None,
):
if metric == "dot":
raise ValueError(
Expand All @@ -71,6 +72,8 @@ def __init__(
centroids=centroids,
seed=seed,
device=device,
balance_factor=balance_factor,
cluster_counts=cluster_counts,
)

if self.device.type != "cuda" or not torch.cuda.is_available():
Expand Down

0 comments on commit 9493865

Please sign in to comment.