diff --git a/keybert/_mmr.py b/keybert/_mmr.py index ae6b22eb..c11db0f9 100644 --- a/keybert/_mmr.py +++ b/keybert/_mmr.py @@ -41,6 +41,7 @@ def mmr( # Initialize candidates and already choose best keyword/keyphras keywords_idx = [np.argmax(word_doc_similarity)] candidates_idx = [i for i in range(len(words)) if i != keywords_idx[0]] + words_mmr = [(1 - diversity) * np.max(word_doc_similarity)] for _ in range(min(top_n - 1, len(words) - 1)): # Extract similarities within candidates and @@ -55,15 +56,14 @@ def mmr( 1 - diversity ) * candidate_similarities - diversity * target_similarities.reshape(-1, 1) mmr_idx = candidates_idx[np.argmax(mmr)] + words_mmr.append(np.max(mmr)) # Update keywords & candidates keywords_idx.append(mmr_idx) candidates_idx.remove(mmr_idx) # Extract and sort keywords in descending similarity - keywords = [ - (words[idx], round(float(word_doc_similarity.reshape(1, -1)[0][idx]), 4)) - for idx in keywords_idx - ] + keywords = [(words[keywords_idx[i]], round(float(words_mmr[i]), 4)) + for i in range(len(keywords_idx))] keywords = sorted(keywords, key=itemgetter(1), reverse=True) return keywords