AG ๋ด์ค ๋ฐ์ดํฐ์
AG ๋ด์ค ๋ฐ์ดํฐ์
์ ๋ฐ์ดํฐ ๋ง์ด๋๊ณผ ์ ๋ณด ์ถ์ถ ๋ฐฉ๋ฒ ์ฐ๊ตฌ๋ฅผ ๋ชฉ์ ์ผ๋ก 2005๋
์ ์์งํ ๋ด์ค ๊ธฐ์ฌ ๋ชจ์์
๋๋ค. ์ด๋ฒ ์ฅ์ ๋ชฉํ๋ ํ
์คํธ ๋ถ๋ฅ์์ ์ฌ์ ํ๋ จ๋ ๋จ์ด ์๋ฒ ๋ฉ์ ํจ๊ณผ๋ฅผ ์์๋ณด๋ ๊ฒ์
๋๋ค. ๊ธฐ์ฌ ์ ๋ชฉ์ ์ด์ ์ ๋ง์ถฐ ์ฃผ์ด์ง ์ ๋ชฉ์ผ๋ก ์นดํ
๊ณ ๋ฆฌ๋ฅผ ์์ธกํ๋ ๋ค์ค ๋ถ๋ฅ ์์
์ ๋ง๋ค ๊ฒ์
๋๋ค.
ย
ํ
์คํธ ์ ์ฒ๋ฆฌ๋ ํ
์คํธ๋ฅผ ์๋ฌธ์๋ก ๋ณํ ํ ์ผํ, ๋ง์นจํ, ๋๋ํ ๋ฑ์ ์ฃผ์์ ๊ณต๋ฐฑ์ ์ถ๊ฐํ๊ณ ๊ทธ ์ธ ๊ตฌ๋์ ๊ธฐํธ๋ ๋ชจ๋ ์ ๊ฑฐํ๋ ์์ผ๋ก ์งํํฉ๋๋ค. ๋ฐ์ดํฐ์
์ ํ์ต, ๊ฒ์ฆ, ํ
์คํธ ์ธํธ๋ก ๋ถํ ํฉ๋๋ค. ๋ค์ ์ฝ๋๋ ๋ฐ์ดํฐ์
์ ๊ฐ ํ์์ ๋ชจ๋ธ ์
๋ ฅ์ ๋ํ๋ด๋ ๋ฌธ์์ด์ ์ถ์ถํ๊ณ Vectorizer๋ฅผ ์ฌ์ฉํด ๋ฒกํฐ๋ก ๋ณํํ๋ ๊ณผ์ ์ ๋ณด์ฌ์ค๋๋ค. ๊ทธ๋ค์์ผ๋ก ๋ด์ค ์นดํ
๊ณ ๋ฆฌ๋ฅผ ๋ํ๋ด๋ ์ ์์ ์์ ๊ตฌ์ฑํฉ๋๋ค.
ย
๋ค์์ NewsDataset.__getitem__() ๋ฉ์๋ ๊ตฌํ์
๋๋ค.
class NewsDataset(Dataset): @classmethod def load_dataset_and_make_vectorizer(cls, news_csv): # ๋ฐ์ดํฐ์ ์ ๋ก๋ํ๊ณ ์ฒ์๋ถํฐ ์๋ก์ด Vectorizer ๋ง๋ค๊ธฐ # news_csv(str): ๋ฐ์ดํฐ์ ์ ์์น news_df = pd.read_csv(news_csv) train_news_df = news_df[news_df.split=='train'] # NewsDataset์ ์ธ์คํด์ค return cls(news_df, NewsVectorizer.form_dataframe(train_news_df)) def __getitem__(self, index): # ํ์ดํ ์น ๋ฐ์ดํฐ์ ์ ์ฃผ์ ์ง์ ๋ฉ์๋ # index(int): ๋ฐ์ดํฐ ํฌ์ธํธ์ ์ธ๋ฑ์ค row = self._target_df.iloc[index] title_vector = self._vectorizer.vectorize(row.title, self._max_seq_length) category_index = self._vectorizer.category_vocab.lookup_token(row.category) return {'x_data': title_vector, # ๋ฐ์ดํฐ ํฌ์ธํธ์ ํน์ฑ 'y_target': category_index} # ๋ ์ด๋ธ
ย
ย
Vocabulary, Vectorizer, DataLoader
์ด๋ฒ ์ฅ์์๋ Vocabulary ํด๋์ค๋ฅผ ์์ํ SequenceVocabulary๋ฅผ ๋ง๋ญ๋๋ค. ์ด ํด๋์ค์์๋ ์ํ์ค ๋ฐ์ดํฐ์ ์ฌ์ฉํ๋ ํน์ ํ ํฐ 4๊ฐ(UNK ํ ํฐ, MASK ํ ํฐ, BEGIN-OF-SEQUENCE ํ ํฐ, END-OF-SEQUENCE ํ ํฐ)๊ฐ ์๋๋ฐ, ์ถํ์ ์์ธํ ์ค๋ช
ํ๊ฒ ์ง๋ง ํฌ๊ฒ 3๊ฐ์ง ์ฉ๋๋ก ์ฌ์ฉ๋ฉ๋๋ค. UNK ํ ํฐ์ ๋ชจ๋ธ์ด ๋๋ฌผ๊ฒ ๋ฑ์ฅํ๋ ๋จ์ด์ ๋ํ ํํ์ ํ์ตํ๋๋ก ํฉ๋๋ค. ํ
์คํธ ์์ ๋ณธ ์ ์๋ ๋จ์ด๋ฅผ ์ฒ๋ฆฌํ ์ ์๊ฒ ๋ฉ๋๋ค. MASK ํ ํฐ์ Embedding ์ธต์ ๋ง์คํน ์ญํ ์ ์ํํ๊ณ ๊ฐ๋ณ ๊ธธ์ด์ ์ํ์ค๊ฐ ์์ ์์ ์์ค ๊ณ์ฐ์ ๋์ต๋๋ค. ๋ง์ง๋ง 2๊ฐ์ง ํ ํฐ์ ์ํ์ค ๊ฒฝ๊ณ์ ๋ํ ํํธ๋ฅผ ์ ๊ฒฝ๋ง์ ์ ๊ณตํฉ๋๋ค.
ย
ํ
์คํธ๋ฅผ ๋ฒกํฐ์ ๋ฏธ๋๋ฐฐ์น๋ก ๋ณํํ๋ ํ์ดํ๋ผ์ธ์ ๋ ๋ฒ์งธ ๋ถ๋ถ์ Vectorizer์
๋๋ค. ์ด ํด๋์ค๋ SequenceVocabulary ๊ฐ์ฒด๋ฅผ ์์ฑํ๊ณ ์บก์ํํฉ๋๋ค. ์ด ์์ ์ Vectorizer๋ ๋จ์ด ๋น๋๋ฅผ ๊ณ์ฐํ๊ณ ํน์ ์๊ณ๊ฐ์ ์ง์ ํ์ฌ Vocabulary์์ ์ฌ์ฉํ ์ ์๋ ์ ์ฒด ๋จ์ด ์งํฉ์ ์ ํํฉ๋๋ค. ํต์ฌ ๋ชฉ์ ์ ๋น๋๊ฐ ๋ฎ์ ์ก์ ๋จ์ด๋ฅผ ์ ๊ฑฐํ์ฌ ์ ํธ ํ์ง์ ๊ฐ์ ํ๊ณ ๋ชจ๋ธ ๋ฉ๋ชจ๋ฆฌ ์ฌ์ฉ๋ ์ ์ฝ์ ์ํํ๋ ๊ฒ์
๋๋ค.
ย
์ธ์คํด์ค ์์ฑ ํ Vectorizer์ vectorizer() ๋ฉ์๋๋ ๋ด์ค ์ ๋ชฉ ํ๋๋ฅผ ์
๋ ฅ์ผ๋ก ๋ฐ์ ๋ฐ์ดํฐ์
์์ ๊ฐ์ฅ ๊ธด ์ ๋ชฉ๊ณผ ๊ธธ์ด๊ฐ ๊ฐ์ ๋ฒกํฐ๋ฅผ ๋ฐํํฉ๋๋ค. ์ด ๋ฉ์๋๋ 2๊ฐ์ง ์ฃผ์ ์์
์ ์ํํฉ๋๋ค. ์ฒซ์งธ, ์ต๋ ์ํ์ค ๊ธธ์ด๋ฅผ ์ฌ์ฉํฉ๋๋ค. ๋ณดํต ๋ฐ์ดํฐ์
์ด ์ต๋ ์ํ์ค ๊ธธ์ด๋ฅผ ๊ด๋ฆฌํ๊ณ ์ถ๋ก ์ ํ
์คํธ ๋ฐ์ดํฐ์ ์ํ์ค ๊ธธ์ด๋ฅผ ๋ฒกํฐ ๊ธธ์ด๋ก ์ฌ์ฉํ์ง๋ง, CNN ๋ชจ๋ธ์ ์ฌ์ฉํ๋ฏ๋ก ์ถ๋ก ์์๋ ๋ฒกํฐ์ ํฌ๊ธฐ๊ฐ ๊ฐ์์ผ ํฉ๋๋ค. ๋์งธ, ๋จ์ด ์ํ์ค๋ฅผ ๋ํ๋ด๋ 0์ผ๋ก ํจ๋ฉ๋ ์ ์ ๋ฒกํฐ๋ฅผ ์ถ๋ ฅํฉ๋๋ค. ์ด ์ ์ ๋ฒกํฐ๋ ์์ ๋ถ๋ถ์ BEGIN-OF-SEQUENCE ํ ํฐ์ ์ถ๊ฐํ๊ณ ๋์๋ END-OF-SEQUENCE ํ ํฐ์ ์ถ๊ฐํจ์ผ๋ก์จ ๋ถ๋ฅ๊ธฐ๋ ์ํ์ค ๊ฒฝ๊ณ๋ฅผ ๊ตฌ๋ถํ๊ณ ๊ฒฝ๊ณ ๊ทผ์ฒ์ ๋จ์ด์ ์ค์์ ๊ฐ๊น์ด ๋จ์ด์๋ ๋ค๋ฅด๊ฒ ๋ฐ์ํ ์ ์์ต๋๋ค.
ย
๋ค์์ AG ๋ด์ค ๋ฐ์ดํฐ ์
์ ์ํ Vectorizer ๊ตฌํ์
๋๋ค.
class NewsVectorizer(object): def vectorize(self, title, vector_length = -1): # title(str): ๊ณต๋ฐฑ์ผ๋ก ๋๋์ด์ง ๋จ์ด ๋ฌธ์์ด # vector_length(int): ์ธ๋ฑ์ค ๋ฒกํฐ์ ๊ธธ์ด ๋งค๊ฐ๋ณ์ indices = [self.title_vocab.begin_seq_index] indices.extend(self.title_vocab.lookup_token(token) for token in title.split(" ")) indices.append(self.title_vocab.end_seq_index) if vector_length < 0: vector_length = len(indices) out_vector = np.zeros(vector_length, dtype=np.int64) out_vector[:len(indices)] = indices out_vector[len(indices):] = self.title_vocab.mask_index # ๋ฒกํฐ๋ก ๋ณํ๋ ์ ๋ชฉ (๋ํ์ด ์ด๋ ์ด) return out_vector @classmethod def from_dataframe(cls, news_df, cutoff=25): # ๋ฐ์ดํฐ์ ๋ฐ์ดํฐํ๋ ์์์ Vectorizer ๊ฐ์ฒด ๋ง๋ค๊ธฐ # news_df(pandas.DataFrame): ํ๊น ๋ฐ์ดํฐ์ # cutoff(int): Vocabulary์ ํฌํจํ ๋น๋ ์๊ณ๊ฐ category_vocab = Vocabulary() for category in sorted(set(news_df.category)): category_vocab.add_token(category) word_counts = Counter() for title in news_df.title: for token in title.split(" "): if token not in string.punctuation: word_counts[token] += 1 title_vocab = SequenceVocabulary() for word, word_count in word_counts.items(): if word_count >= cutoff: title_vocab.add_token(word) # NewsVectorizer ๊ฐ์ฒด return cls(title_vocab, category_vocab)
ย
NewsClassifier ๋ชจ๋ธ
๋จ์ด ์๋ฒ ๋ฉ์ ์ด๊ธฐ ์๋ฒ ๋ฉ ํ๋ ฌ๋ก ์ฌ์ฉํ๋ ค๋ฉด ๋จผ์ ๋์คํฌ์์ ์๋ฒ ๋ฉ์ ๋ก๋ํ ๋ค์ ์ค์ ๋ฐ์ดํฐ์ ์๋ ๋จ์ด์ ํด๋นํ๋ ์๋ฒ ๋ฉ์ ์ผ๋ถ๋ฅผ ์ ํํฉ๋๋ค. ๋ง์ง๋ง์ผ๋ก Embedding ์ธต์ ๊ฐ์ค์น ํ๋ ฌ์ ์ ํํ ์๋ฒ ๋ฉ์ผ๋ก ์ง์ ํฉ๋๋ค. ์ฒซ ๋ฒ์งธ์ ๋ ๋ฒ์งธ ๋จ๊ณ๋ฅผ ๋ค์ ์ฝ๋์์ ์ค๋ช
ํ๊ฒ ์ต๋๋ค. ์ดํ ์ฌ์ ์ ๊ธฐ๋ฐํ์ฌ ๋จ์ด ์๋ฒ ๋ฉ์ ๋ถ๋ถ ์งํฉ์ ์ ํํฉ๋๋ค.
def load_glove_from_file(glove_filepath): # glove_filepath (str): ์๋ฒ ๋ฉ ํ์ผ ๊ฒฝ๋ก word_to_index = {} embeddings = [] with open(glove_filepath, "r") as fp: for index, line in enumerate(fp): line = line.split(" ") # each line: word num1 num2 ... word_to_index[line[0]] = index # word = line[0] embedding_i = np.array([float(val) for val in line[1:]]) embeddings.append(embedding_i) # word_to_index (dict), embeddings (numpy.ndarary) return word_to_index, np.stack(embeddings) def make_embedding_matrix(glove_filepath, words): # ํน์ ๋จ์ด ์งํฉ์ ๋ํ ์๋ฒ ๋ฉ ํ๋ ฌ ๋ง๋ค๊ธฐ # glove_filepath (str): ์๋ฒ ๋ฉ ํ์ผ ๊ฒฝ๋ก # words (list): ๋จ์ด ๋ฆฌ์คํธ word_to_idx, glove_embeddings = load_glove_from_file(glove_filepath) embedding_size = glove_embeddings.shape[1] final_embeddings = np.zeros((len(words), embedding_size)) for i, word in enumerate(words): if word in word_to_idx: final_embeddings[i, :] = glove_embeddings[word_to_idx[word]] else: embedding_i = torch.ones(1, embedding_size) torch.nn.init.xavier_uniform_(embedding_i) final_embeddings[i, :] = embedding_i # final_embeddings (numpu.ndarray): ์๋ฒ ๋ฉ ํ๋ ฌ return final_embeddings
ย
์
๋ ฅ ํ ํฐ ์ธ๋ฑ์ค๋ฅผ ๋ฒกํฐ ํํ์ผ๋ก ๋งคํํ๋ Embedding์ธต์ ์ฌ์ฉํฉ๋๋ค. ๋ค์ ์ฝ๋์์๋ Embedding ์ธต์ ๊ฐ์ค์น๋ฅผ ์ฌ์ ํ๋ จ๋ ์๋ฒ ๋ฉ์ผ๋ก ๋ฐ๊พธ๊ฒ ๋ฉ๋๋ค. forward() ๋ฉ์๋์์ ์ด ์๋ฒ ๋ฉ์ ์ฌ์ฉํด ์ธ๋ฑ์ค๋ฅผ ๋ฒกํฐ๋ก ๋งคํํฉ๋๋ค.
class NewsClassifier(nn.Module): def __init__(self, embedding_size, num_embeddings, num_channels, hidden_dim, num_classes, dropout_p, pretrained_embeddings=None, padding_idx=0): """ ๋งค๊ฐ๋ณ์: embedding_size (int): ์๋ฒ ๋ฉ ๋ฒกํฐ์ ํฌ๊ธฐ num_embeddings (int): ์๋ฒ ๋ฉ ๋ฒกํฐ์ ๊ฐ์ num_channels (int): ํฉ์ฑ๊ณฑ ์ปค๋ ๊ฐ์ hidden_dim (int): ์๋ ์ฐจ์ ํฌ๊ธฐ num_classes (int): ํด๋์ค ๊ฐ์ dropout_p (float): ๋๋กญ์์ ํ๋ฅ pretrained_embeddings (numpy.array): ์ฌ์ ์ ํ๋ จ๋ ๋จ์ด ์๋ฒ ๋ฉ ๊ธฐ๋ณธ๊ฐ์ None padding_idx (int): ํจ๋ฉ ์ธ๋ฑ์ค """ super(NewsClassifier, self).__init__() if pretrained_embeddings is None: self.emb = nn.Embedding(embedding_dim=embedding_size, num_embeddings=num_embeddings, padding_idx=padding_idx) else: pretrained_embeddings = torch.from_numpy(pretrained_embeddings).float() self.emb = nn.Embedding(embedding_dim=embedding_size, num_embeddings=num_embeddings, padding_idx=padding_idx, _weight=pretrained_embeddings) self.convnet = nn.Sequential( nn.Conv1d(in_channels=embedding_size, out_channels=num_channels, kernel_size=3), nn.ELU(), nn.Conv1d(in_channels=num_channels, out_channels=num_channels, kernel_size=3, stride=2), nn.ELU(), nn.Conv1d(in_channels=num_channels, out_channels=num_channels, kernel_size=3, stride=2), nn.ELU(), nn.Conv1d(in_channels=num_channels, out_channels=num_channels, kernel_size=3), nn.ELU() ) self._dropout_p = dropout_p self.fc1 = nn.Linear(num_channels, hidden_dim) self.fc2 = nn.Linear(hidden_dim, num_classes) def forward(self, x_in, apply_softmax=False): """๋ถ๋ฅ๊ธฐ์ ์ ๋ฐฉํฅ ๊ณ์ฐ ๋งค๊ฐ๋ณ์: x_in (torch.Tensor): ์ ๋ ฅ ๋ฐ์ดํฐ ํ ์ x_in.shape๋ (batch, dataset._max_seq_length)์ ๋๋ค. apply_softmax (bool): ์ํํธ๋งฅ์ค ํ์ฑํ ํจ์๋ฅผ ์ํ ํ๋๊ทธ ํฌ๋ก์ค-์ํธ๋กํผ ์์ค์ ์ฌ์ฉํ๋ ค๋ฉด False๋ก ์ง์ ํฉ๋๋ค ๋ฐํ๊ฐ: ๊ฒฐ๊ณผ ํ ์. tensor.shape์ (batch, num_classes)์ ๋๋ค. """ # ์๋ฒ ๋ฉ์ ์ ์ฉํ๊ณ ํน์ฑ๊ณผ ์ฑ๋ ์ฐจ์์ ๋ฐ๊ฟ๋๋ค x_embedded = self.emb(x_in).permute(0, 2, 1) features = self.convnet(x_embedded) # ํ๊ท ๊ฐ์ ๊ณ์ฐํ์ฌ ๋ถ๊ฐ์ ์ธ ์ฐจ์์ ์ ๊ฑฐํฉ๋๋ค remaining_size = features.size(dim=2) features = F.avg_pool1d(features, remaining_size).squeeze(dim=2) features = F.dropout(features, p=self._dropout_p) # MLP ๋ถ๋ฅ๊ธฐ intermediate_vector = F.relu(F.dropout(self.fc1(features), p=self._dropout_p)) prediction_vector = self.fc2(intermediate_vector) if apply_softmax: prediction_vector = F.softmax(prediction_vector, dim=1) return prediction_vector
ย
๋ชจ๋ธ ํ๋ จ
ํ๋ จ ๊ณผ์ ์ ๋ฐ์ดํฐ์
์ด๊ธฐํ, ๋ชจ๋ธ ์ด๊ธฐํ, ์์ค ํจ์ ์ด๊ธฐํ, ์ตํฐ๋ง์ด์ ์ด๊ธฐํ, ํ๋ จ ์ธํธ์ ๋ํ ๋ฐ๋ณต, ๋ชจ๋ธ ํ๋ผ๋ฏธํฐ ์
๋ฐ์ดํธ, ๊ฒ์ฆ ์ธํธ์ ๋ํ ๋ฐ๋ณต๊ณผ ์ฑ๋ฅ ์ธก์ ์ ํ ๋ค์ ํน์ ํ์ ๋์ ์ด ๋ฐ์ดํฐ์
์ ๋ฐ๋ณตํฉ๋๋ค. ๋ค์ ์ฝ๋๋ ํ์ดํผํ๋ผ๋ฏธํฐ๋ฅผ ํฌํจํ ์์ ์ ํ๋ จ ๋งค๊ฐ๋ณ์์
๋๋ค.
args = Namespace( # ๋ ์ง์ ๊ฒฝ๋ก ์ ๋ณด news_csv="data/ag_news/news_with_splits.csv", vectorizer_file="vectorizer.json", model_state_file="model.pth", save_dir="model_storage/ch5/document_classification", # ๋ชจ๋ธ ํ์ดํผํ๋ผ๋ฏธํฐ glove_filepath='data/glove/glove.6B.100d.txt', use_glove=False, embedding_size=100, hidden_dim=100, num_channels=100, # ํ๋ จ ํ์ดํผํ๋ผ๋ฏธํฐ seed=1337, learning_rate=0.001, dropout_p=0.1, batch_size=128, num_epochs=100, early_stopping_criteria=5, # ์คํ ์ต์ cuda=True, catch_keyboard_interrupt=True, reload_from_files=False, expand_filepaths_to_save_dir=True )
ย
๋ชจ๋ธ ํ๊ฐ์ ์์ธก
๋ชจ๋ธ์ด ์์
์ ์ ์ํํ๋์ง ํ๊ฐํ๋ ๋ฐฉ๋ฒ์ ๋ ๊ฐ์ง๋ก, ํ
์คํธ ์ธํธ๋ฅผ ์ฌ์ฉํ์ฌ ์ ๋์ ์ผ๋ก ํ๊ฐํ๊ฑฐ๋ ๋ถ๋ฅ ๊ฒฐ๊ณผ๋ฅผ ๊ฐ๋ณ์ ์ผ๋ก ์กฐ์ฌํ์ฌ ์ง์ ์ผ๋ก ํ๊ฐํ๋ ๋ฐฉ๋ฒ์ด ์์ต๋๋ค.
ํ ์คํธ ๋ฐ์ดํฐ๋ก ํ๊ฐํ๊ธฐ
classifier.eval() ๋ฉ์๋๋ฅผ ์ฌ์ฉํด ๋ชจ๋ธ์ ํ๊ฐ ๋ชจ๋๋ก ์ค์ ํ์ฌ ๋๋กญ์์๊ณผ ์ญ์ ํ๋ฅผ ๋ ํ ํ๋ จ ์ธํธ ๋ฐ ๊ฒ์ฆ ์ธํธ์ ๊ฐ์ ๋ฐฉ์์ผ๋ก ํ
์คํธ ์ธํธ๋ฅผ ๋ฐ๋ณตํฉ๋๋ค. ์ ์ฒด ํ๋ จ ๊ณผ์ ์์ ํ
์คํธ ์ธํธ๋ ๋ฑ ํ ๋ฒ๋ง ์ฌ์ฉํด์ผ ํฉ๋๋ค.
์๋ก์ด ๋ด์ค ์ ๋ชฉ์ ์นดํ ๊ณ ๋ฆฌ ์์ธกํ๊ธฐ
ํ๋ จ์ ๋ชฉ์ ์ ์ค์ ์ ๋ฐฐ์นํ์ฌ ์ฒ์ ์ ํ๋ ๋ฐ์ดํฐ์ ๋ํด ์ถ๋ก ํน์ ์์ธก์ ์ํํ๊ธฐ ์ํจ์
๋๋ค. ์๋ก์ด ๋ด์ค ์ ๋ชฉ์ ์นดํ
๊ณ ๋ฆฌ๋ฅผ ์์ธกํ๊ธฐ ์ํด์๋ ๋จผ์ ํ๋ จํ ๋ ๋ฐ์ดํฐ๋ฅผ ์ ์ฒ๋ฆฌํ ๋ฐฉ์์ผ๋ก ํ
์คํธ๋ฅผ ์ ์ฒ๋ฆฌํด์ผ ํฉ๋๋ค. ์ ์ฒ๋ฆฌ๋ ๋ฌธ์์ด์ ํ๋ จ์ ์ฌ์ฉํ Vectorizer๋ฅผ ์ฌ์ฉํด ๋ฒกํฐ๋ก ๋ฐ๊พธ๊ณ ํ์ดํ ์น ํ
์๋ก ๋ณํํฉ๋๋ค. ๊ทธ๋ค์์ผ๋ก ์ด ํ
์์ ๋ถ๋ฅ๊ธฐ๋ฅผ ์ ์ฉํฉ๋๋ค. ์์ธก ๋ฒกํฐ์์ ์ต๋๊ฐ์ ์ฐพ์ ์นดํ
๊ณ ๋ฆฌ ์ด๋ฆ์ ์กฐํํ๋๋ฐ, ์ด ๊ณผ์ ์ ์ฝ๋๋ก ์ดํด๋ณด๊ฒ ์ต๋๋ค.
def predict_category(title, classifier, vectorizer, max_length): """๋ด์ค ์ ๋ชฉ์ ๊ธฐ๋ฐ์ผ๋ก ์นดํ ๊ณ ๋ฆฌ๋ฅผ ์์ธกํฉ๋๋ค ๋งค๊ฐ๋ณ์: title (str): ์์ ์ ๋ชฉ ๋ฌธ์์ด classifier (NewsClassifier): ํ๋ จ๋ ๋ถ๋ฅ๊ธฐ ๊ฐ์ฒด vectorizer (NewsVectorizer): ํด๋น Vectorizer max_length (int): ์ต๋ ์ํ์ค ๊ธธ์ด """ title = preprocess_text(title) vectorized_title = \ torch.tensor(vectorizer.vectorize(title, vector_length=max_length)) result = classifier(vectorized_title.unsqueeze(0), apply_softmax=True) probability_values, indices = result.max(dim=1) predicted_category = vectorizer.category_vocab.lookup_index(indices.item()) return {'category': predicted_category, 'probability': probability_values.item()}
ย
ย
์ด์ ๊ธ ์ฝ๊ธฐ
ย
ย