-
Notifications
You must be signed in to change notification settings - Fork 0
/
recommender.py
136 lines (100 loc) · 4.29 KB
/
recommender.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
import csv
import json
import os
import random
import requests
import streamlit as st
from annoy import AnnoyIndex
# Base URL of the storage bucket holding the downloadable artefacts; read from
# the STORAGE environment variable and None when that variable is not set.
STORAGE = os.environ.get("STORAGE")

# Local file name the Annoy index is written to and loaded from.
ANNOY_INDEX_FILE = "card-embeddings.ann"

# Number of dimensions handed to AnnoyIndex when the index is created.
EMBEDDING_SIZE = 50

# Browser-tab title/icon and the wide page layout.
st.set_page_config(
    page_title="Yu-Gi-Oh! Card Recommender",
    page_icon="🃏",
    layout="wide",
)
@st.cache_resource
def load_vector_database():
    """Download the Annoy index from the storage bucket and load it.

    The file ``{STORAGE}/card-embeddings.ann`` is fetched over HTTP, written to
    ``ANNOY_INDEX_FILE`` in the current working directory and then loaded into
    an ``AnnoyIndex`` with ``EMBEDDING_SIZE`` dimensions and the "angular"
    metric. The function is wrapped in ``st.cache_resource``.

    Returns:
        The loaded ``AnnoyIndex``.

    Raises:
        requests.RequestException: If the download fails, times out or the
            server answers with an HTTP error status.
    """
    url = f"{STORAGE}/card-embeddings.ann"
    # A timeout keeps the app from hanging forever on an unresponsive bucket.
    response = requests.get(url, timeout=30)
    # Fail loudly instead of saving an HTML/JSON error body as the index file.
    response.raise_for_status()

    with open(ANNOY_INDEX_FILE, "wb") as index_file:
        index_file.write(response.content)

    ann = AnnoyIndex(EMBEDDING_SIZE, "angular")
    ann.load(ANNOY_INDEX_FILE)
    return ann
@st.cache_resource
def load_cards():
    """Download ``cards.csv`` from the storage bucket and index it by card id.

    The file ``{STORAGE}/cards.csv`` is fetched over HTTP, saved as
    ``cards.csv`` in the current working directory and parsed with
    ``csv.DictReader``. The function is wrapped in ``st.cache_resource``.

    Returns:
        A dict mapping each row's ``id`` column to the full row (itself a dict
        of column name to string value). If an ``id`` occurs more than once,
        the last row wins.

    Raises:
        requests.RequestException: If the download fails, times out or the
            server answers with an HTTP error status.
        KeyError: If a row has no ``id`` column.
    """
    url = f"{STORAGE}/cards.csv"
    # A timeout keeps the app from hanging forever on an unresponsive bucket.
    response = requests.get(url, timeout=30)
    # Fail loudly instead of parsing an error page as if it were the CSV.
    response.raise_for_status()

    with open("cards.csv", "wb") as csv_file:
        csv_file.write(response.content)

    # newline="" lets the csv module handle line endings itself, so quoted
    # fields containing embedded newlines are parsed correctly.
    # NOTE(review): the file is decoded with the platform default encoding,
    # as before — confirm whether the CSV should be read as UTF-8 explicitly.
    with open("cards.csv", "r", newline="") as csv_file:
        return {row["id"]: row for row in csv.DictReader(csv_file)}
def _download_json(file_name):
    """Fetch ``{STORAGE}/<file_name>``, save it locally and return the parsed JSON."""
    # A timeout keeps the app from hanging forever on an unresponsive bucket.
    response = requests.get(f"{STORAGE}/{file_name}", timeout=30)
    # Fail loudly instead of trying to parse an error page as JSON.
    response.raise_for_status()

    with open(file_name, "wb") as json_file:
        json_file.write(response.content)
    with open(file_name, "r") as json_file:
        return json.load(json_file)


@st.cache_resource
def load_supporting_dictionsaries():
    """Download the two passcode/id lookup tables from the storage bucket.

    ``passcode_to_id.json`` and ``id_to_passcode.json`` are fetched from
    ``STORAGE``, saved under the same names in the current working directory
    and parsed as JSON. The page script uses the first table to turn a selected
    card key into the item id queried in the Annoy index, and the second to
    turn a neighbour's item id (as a string) back into a card key. The function
    is wrapped in ``st.cache_resource``.

    Returns:
        A ``(passcode_to_id, id_to_passcode)`` tuple holding the parsed content
        of the two JSON files.

    Raises:
        requests.RequestException: If a download fails, times out or the
            server answers with an HTTP error status.
        json.JSONDecodeError: If a downloaded file is not valid JSON.
    """
    passcode_to_id = _download_json("passcode_to_id.json")
    id_to_passcode = _download_json("id_to_passcode.json")
    return passcode_to_id, id_to_passcode
# Select-box entry that stands for "pick a card at random".
# NOTE(review): assumes no real card uses "0" as its id — confirm against cards.csv.
RANDOM_OPTION = "0"

if STORAGE is None:
    st.error("Please set the STORAGE environment variable to the URL of the storage bucket.")
else:
    ann = load_vector_database()
    cards = load_cards()
    passcode_to_id, id_to_passcode = load_supporting_dictionsaries()

    def format_card(card_id):
        """Return the select-box label for ``card_id``."""
        # The placeholder is labelled here instead of being inserted into
        # ``cards``: that dict comes from a cached loader, so adding a fake
        # entry would leak into every later use of the table.
        if card_id == RANDOM_OPTION:
            return "Random card"
        return cards[card_id]["name"]

    st.title("Yu-Gi-Oh! card recommender")
    st.write("Welcome to the recommender! Please select a card to get started.")

    # Real cards first, the "Random card" placeholder exactly once at the end.
    real_ids = [card_key for card_key in cards if card_key != RANDOM_OPTION]
    id_list = real_ids + [RANDOM_OPTION]

    # An optional ?card=<key> query parameter preselects a card.
    selected_passcode = st.query_params["card"] if "card" in st.query_params else None
    try:
        default_index = id_list.index(selected_passcode)
    except ValueError:
        # No parameter, or one that matches no known card: fall back to the
        # "Random card" entry rather than crashing the page.
        default_index = len(id_list) - 1

    query_card_passcode = st.selectbox(
        "Select a card",
        id_list,
        format_func=format_card,
        index=default_index,
    )

    if query_card_passcode == RANDOM_OPTION:
        # Draw from real cards only, so the placeholder can never be picked.
        query_card_passcode = random.choice(real_ids)

    card_id = passcode_to_id[query_card_passcode]
    query_card_embedding = ann.get_item_vector(card_id)
    # Ask for one extra neighbour because the query card itself is normally
    # among the results.
    similar_card_ids, scores = ann.get_nns_by_vector(query_card_embedding, 6, include_distances=True)

    # Drop the query card wherever it appears in the results, then keep at
    # most five neighbours.
    neighbours = [
        (neighbour_id, distance)
        for neighbour_id, distance in zip(similar_card_ids, scores)
        if neighbour_id != card_id
    ][:5]

    st.subheader("Here are some similar cards:")
    columns = st.columns(len(neighbours) + 1)

    with columns[0]:
        st.markdown("#### Query Card:")
        st.image(cards[query_card_passcode]["image_url"])

    for (similar_card_id, score), column in zip(neighbours, columns[1:]):
        with column:
            # id_to_passcode comes from JSON, so its keys are strings.
            passcode = id_to_passcode[str(similar_card_id)]
            st.markdown(f"#### Distance {score:.3f}")
            st.image(cards[passcode]["image_url"])

# Footer: rendered whether or not the data could be loaded.
st.markdown("Want to learn more about this repository? **[check out the repository](https://github.com/fferegrino/card-embeddings)**. Follow me on **[Twitter](https://twitter.com/feregri_no)** or [Threads](https://threads.net/feregri_no).")
st.divider()
st.caption("The literal and graphical information presented on this site about Yu-Gi-Oh!, including card images, the attribute, level/rank and type symbols, and card text, is copyright 4K Media Inc, a subsidiary of Konami Digital Entertainment, Inc. This website is not produced by, endorsed by, supported by, or affiliated with 4k Media or Konami Digital Entertainment.")