Cerebras / modelzoo


`to_hash` function seems to stall frequently when writing results to disk #20

Open tbarton16 opened 1 year ago

tbarton16 commented 1 year ago

The `to_hash` function seems to stall frequently when writing results to disk. I have been running the pipeline on a 256-CPU machine with 1 TB of RAM, and on a 63 GB slice of data the write to disk stalls for many hours unnecessarily. I fixed it by writing to disk inside each worker, before the worker pool gets joined; the change is below.

I changed `get_documents` to return a list of documents.
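For reference, each document is a (text, file_path, doc_id) tuple. Here is a minimal sketch of the shape `get_documents` returns, assuming one JSON record per line in each input file (the loader details here are illustrative, not the repo's actual implementation):

import json
import os

def get_documents(input_dir, output_dir, dataset_name):
    # Hypothetical sketch: walk the input tree and collect one
    # (text, file_path, doc_id) tuple per document. output_dir and
    # dataset_name are unused here but kept to match the call site.
    documents = []
    for root, _, file_names in os.walk(input_dir):
        for file_name in file_names:
            file_path = os.path.join(root, file_name)
            with open(file_path, "r", encoding="utf-8") as fin:
                for doc_id, line in enumerate(fin):
                    record = json.loads(line)
                    documents.append((record["text"], file_path, doc_id))
    return documents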

import os
import pickle
from itertools import repeat
from multiprocessing import Pool, cpu_count

from datasketch import MinHash
from tqdm import tqdm

# get_features and get_documents come from the repo's dedup pipeline.

def to_minhash(chunks):
    # Each worker hashes its own chunk of documents and writes the result
    # to disk itself, so nothing large is returned to the parent process.
    buckets = []
    documents, output_dir, width, dataset_name, n_docs, chunk_id = chunks
    for doc in documents:
        text, file_path, doc_id = doc[0], doc[1], doc[2]
        file_name = file_path.split("/")[-1]
        if dataset_name == "common_crawl":
            dir_2 = file_path.split("/")[-2]
            output_name = f"{dataset_name}/{dir_2}/{file_name}"
        else:
            output_name = f"{dataset_name}/{file_name}"

        m = MinHash(num_perm=128)
        for x in get_features(text, width):
            m.update(x.encode("utf8"))
        buckets.append(
            {"file_name": output_name, "doc_id": doc_id, "hash": m}
        )
    # Write this chunk's hashes from inside the worker, before the pool
    # is joined, instead of shipping them back through the result queue.
    with open(f"{output_dir}/minhash_nfc/{chunk_id}.pickle", "wb") as fout:
        pickle.dump(buckets, fout)

def generate_hashes(args):
    if not os.path.exists(f"{args.output_dir}/minhash_nfc"):
        os.mkdir(f"{args.output_dir}/minhash_nfc")
    documents = get_documents(
        args.input_dir,
        args.output_dir,
        args.dataset_name,
    )
    print(cpu_count())
    # Leave a few cores free for the parent process and the OS.
    cpucount = int(cpu_count() - 12)
    n_chunks = args.n_docs // cpucount
    files = [documents[i : i + n_chunks] for i in range(0, len(documents), n_chunks)]
    print(files[0][0])
    with Pool(processes=cpucount) as pool:
        for i, chunks in enumerate(
            tqdm(
                pool.imap_unordered(
                    to_minhash,
                    zip(
                        files,
                        repeat(args.output_dir),
                        repeat(args.w),
                        repeat(args.dataset_name),
                        repeat(n_chunks),
                        range(len(files)),
                    ),
                ),
                total=len(files),  # one tick per chunk, not per CPU
            )
        ):
            print("pass")

I guess my question is: why does your code stall out?

frankang commented 11 months ago

The original function is badly written: it should write to disk periodically inside `to_minhash()` rather than holding every process's hash results in memory until each process returns.
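As a minimal sketch of that idea, the worker below flushes every `flush_every` documents; `flush_every` and `flush_buckets` are illustrative names, not from the repo:

import pickle
from datasketch import MinHash

def flush_buckets(buckets, output_dir, chunk_id, part):
    # Hypothetical helper: write the hashes accumulated so far to a
    # numbered part file so the worker's memory stays bounded.
    with open(f"{output_dir}/minhash_nfc/{chunk_id}_{part}.pickle", "wb") as fout:
        pickle.dump(buckets, fout)

def to_minhash_periodic(chunks, flush_every=10_000):
    documents, output_dir, width, dataset_name, n_docs, chunk_id = chunks
    buckets, part = [], 0
    for doc in documents:
        text, file_path, doc_id = doc[0], doc[1], doc[2]
        m = MinHash(num_perm=128)
        for x in get_features(text, width):  # get_features as in the pipeline
            m.update(x.encode("utf8"))
        # (output_name derivation from to_minhash omitted for brevity)
        buckets.append({"file_name": file_path, "doc_id": doc_id, "hash": m})
        # Flush periodically instead of holding the whole chunk in memory.
        if len(buckets) >= flush_every:
            flush_buckets(buckets, output_dir, chunk_id, part)
            buckets, part = [], part + 1
    if buckets:  # flush whatever remains at the end of the chunk
        flush_buckets(buckets, output_dir, chunk_id, part)

Each flush caps the worker's footprint at roughly flush_every MinHash objects, at the cost of more, smaller pickle files per chunk.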