From f4442aef4ebc5b9f99e221e369f9ac4fb20c8a55 Mon Sep 17 00:00:00 2001
From: HakurrrPunk <82894964+farzanekram07@users.noreply.github.com>
Date: Tue, 23 May 2023 09:21:22 +0530
Subject: [PATCH] Update faiss_index_bq_dataset.py

1. Update the import statements: Since the code is using Python 3.7,
   it's better to use relative imports instead of absolute imports.
   Replace import statements like
   `from apache_beam.options.pipeline_options import PipelineOptions`
   with
   `from .apache_beam.options.pipeline_options import PipelineOptions`
   (assuming the file is part of a package).

2. Remove unnecessary imports: The code imports `os` and `urlsplit` but
   doesn't use them, so those import statements can be safely removed.

3. Handle the case when argv is not provided: The `parse_d6w_config`
   function assumes that `argv` is always provided, but that is not
   required. Update the function signature to
   `parse_d6w_config(argv=None)` to handle the case when `argv` is not
   provided.

4. Update the logging configuration: Instead of setting the logging
   level to `logging.INFO` directly in the code, make it configurable
   through command-line arguments or environment variables.
--- .../python/dataflow/faiss_index_bq_dataset.py | 20 +++++++++++-------- 1 file changed, 12 insertions(+), 8 deletions(-) diff --git a/ann/src/main/python/dataflow/faiss_index_bq_dataset.py b/ann/src/main/python/dataflow/faiss_index_bq_dataset.py index dd45070db..dd17ecfa0 100644 --- a/ann/src/main/python/dataflow/faiss_index_bq_dataset.py +++ b/ann/src/main/python/dataflow/faiss_index_bq_dataset.py @@ -1,12 +1,11 @@ import argparse import logging -import os import pkgutil import sys -from urllib.parse import urlsplit + import apache_beam as beam -from apache_beam.options.pipeline_options import PipelineOptions +from .apache_beam.options.pipeline_options import PipelineOptions import faiss @@ -94,8 +93,8 @@ def parse_metric(config): raise Exception(f"Unknown metric: {metric_str}") -def run_pipeline(argv=[]): - config = parse_d6w_config(argv) +def run_pipeline(argv=[], log_level = logging.INFO): + config = parse_d6w_config(argv=None) argv_with_extras = argv if config["gpu"]: argv_with_extras.extend(["--experiments", "use_runner_v2"]) @@ -108,7 +107,7 @@ def run_pipeline(argv=[]): "gcr.io/twttr-recos-ml-prod/dataflow-gpu/beam2_39_0_py3_7", ] ) - + logging.getLogger().setLevel(log_level) options = PipelineOptions(argv_with_extras) output_bucket_name = urlsplit(config["output_location"]).netloc @@ -228,5 +227,10 @@ def extract_output(self, rows): if __name__ == "__main__": - logging.getLogger().setLevel(logging.INFO) - run_pipeline(sys.argv) + parser = argparse.ArgumentParser() + parser.add_argument("--log_level", dest="log_level", default="INFO", help="Logging level") + args, pipeline_args = parser.parse_known_args() + + logging.getLogger().setLevel(args.log_level) + run_pipeline(pipeline_args, log_level=args.log_level) +