Skip to content

Commit

Permalink
feat: add kwargs to SentenceTransformerEmbeddingModel init
Browse files Browse the repository at this point in the history
  • Loading branch information
Pouyanpi committed Jan 9, 2025
1 parent a6116c1 commit db6245a
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions nemoguardrails/embeddings/providers/sentence_transformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ class SentenceTransformerEmbeddingModel(EmbeddingModel):

engine_name = "SentenceTransformers"

def __init__(self, embedding_model: str):
def __init__(self, embedding_model: str, **kwargs):
try:
from sentence_transformers import SentenceTransformer
except ImportError:
Expand All @@ -58,7 +58,7 @@ def __init__(self, embedding_model: str):
)

device = "cuda" if cuda.is_available() else "cpu"
self.model = SentenceTransformer(embedding_model, device=device)
self.model = SentenceTransformer(embedding_model, device=device, **kwargs)
# Get the embedding dimension of the model
self.embedding_size = self.model.get_sentence_embedding_dimension()

Expand Down

0 comments on commit db6245a

Please sign in to comment.