Open skdbsxir opened 11 months ago
user sequence 길이가 20일때는 문제 X.
user sequence 길이가 50일때 Dataset의 __init__() 에서 OOM 발생 (dataset.py)
__init__()
dataset.py
현재 해당 부분 코드는 다음과 같음. (중간 print는 어디가 문제인지 파악하기 위해서 임시로 넣음.)
def __init__(self, dataset:str, split:str, seed:int, user_seq_len:int=20, item_seq_len:int=250): """ Args: dataset: raw dataset name (ciao // epinions) split: dataset split type (train // valid // test) seed: random seed used in dataset split item_seq_len: length of item list (processed in `data_utils.py`) """ self.data_path = os.getcwd() + '/dataset/' + dataset # 전처리 된 .pkl 파일 load with open(self.data_path + '/' f'sequence_data_seed_{seed}_walk_{user_seq_len}_itemlen_{item_seq_len}_{split}.pkl', 'rb') as file: dataframe = pickle.load(file) user_sequences = dataframe['user_sequences'].values user_sequences = np.array([np.array(x) for x in user_sequences]) print("1") user_degree = dataframe['user_degree'].values user_degree = np.array([np.array(x) for x in user_degree]) print("2") item_sequences = dataframe['item_sequences'].values item_sequences = np.array([np.array(x) for x in item_sequences]) print("3") item_degree = dataframe['item_degree'].values item_degree = np.array([np.array(x) for x in item_degree]) print("4") # shape: [total_samples(num_row), seq_len_user, seq_len_item] rating_matrix = dataframe['item_rating'].values rating_matrix = np.array([np.array(x) for x in rating_matrix]) print("5") # shape: [total_samples(num_row), seq_len_user, seq_len_user] spd_matrix = dataframe['spd_matrix'].values spd_matrix = np.array([np.array(x) for x in spd_matrix]) print("6") del dataframe print("Deleted") self.user_sequences = torch.LongTensor(user_sequences) print("user seq OK") self.user_degree = torch.LongTensor(user_degree) print("user degree OK") self.item_sequences = torch.LongTensor(item_sequences) print("item seq OK") self.item_degree = torch.LongTensor(item_degree) print("item degree OK") self.rating_matrix = torch.LongTensor(rating_matrix) print("rating matrix OK") self.spd_matrix = torch.LongTensor(spd_matrix) print("spd matrix OK") print("1111111")
출력 결과는 다음과 같음.
1 2 3 4 5 6 Deleted user seq OK user degree OK item seq OK item degree OK Killed
sparse matrix를 그대로 memory에 올리고 있어서 memory consumption이 상당 수 발생하는 것으로 보임.
How to handle this?
user sequence 길이가 20일때는 문제 X.
user sequence 길이가 50일때 Dataset의
__init__()
에서 OOM 발생 (dataset.py
)현재 해당 부분 코드는 다음과 같음. (중간 print는 어디가 문제인지 파악하기 위해서 임시로 넣음.)
출력 결과는 다음과 같음.
sparse matrix를 그대로 memory에 올리고 있어서 memory consumption이 상당 수 발생하는 것으로 보임.
How to handle this?