Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update faiss_index_bq_dataset.py #1837

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 12 additions & 8 deletions ann/src/main/python/dataflow/faiss_index_bq_dataset.py
Original file line number | Diff line number | Diff line change
@@ -1,12 +1,11 @@
import argparse
import logging
import os
import pkgutil
import sys
from urllib.parse import urlsplit


import apache_beam as beam
from apache_beam.options.pipeline_options import PipelineOptions
from .apache_beam.options.pipeline_options import PipelineOptions
import faiss


@@ -94,8 +93,8 @@ def parse_metric(config):
raise Exception(f"Unknown metric: {metric_str}")


def run_pipeline(argv=[]):
config = parse_d6w_config(argv)
def run_pipeline(argv=[], log_level = logging.INFO):
config = parse_d6w_config(argv=None)
argv_with_extras = argv
if config["gpu"]:
argv_with_extras.extend(["--experiments", "use_runner_v2"])
@@ -108,7 +107,7 @@ def run_pipeline(argv=[]):
"gcr.io/twttr-recos-ml-prod/dataflow-gpu/beam2_39_0_py3_7",
]
)

logging.getLogger().setLevel(log_level)
options = PipelineOptions(argv_with_extras)
output_bucket_name = urlsplit(config["output_location"]).netloc

@@ -228,5 +227,10 @@ def extract_output(self, rows):


if __name__ == "__main__":
logging.getLogger().setLevel(logging.INFO)
run_pipeline(sys.argv)
parser = argparse.ArgumentParser()
parser.add_argument("--log_level", dest="log_level", default="INFO", help="Logging level")
args, pipeline_args = parser.parse_known_args()

logging.getLogger().setLevel(args.log_level)
run_pipeline(pipeline_args, log_level=args.log_level)