-
Notifications
You must be signed in to change notification settings - Fork 1
/
samplers.py
110 lines (99 loc) · 4.54 KB
/
samplers.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
import torch
from torch.distributions.categorical import Categorical
# from fastshap codebase
import numpy as np
from scipy.special import binom
# Shapley Sampler is from FastSHAP Codebase: https://github.com/iancovert/fastshap
# we thank the authors for releasing their codebase
class ShapleySampler:
'''
For sampling player subsets from the Shapley distribution.
Args:
num_players: number of players.
'''
def __init__(self, num_players):
arange = torch.arange(1, num_players)
w = 1 / (arange * (num_players - arange))
w = w / torch.sum(w)
self.categorical = Categorical(probs=w)
self.num_players = num_players
self.tril = torch.tril(
torch.ones(num_players - 1, num_players, dtype=torch.float32),
diagonal=0)
def sample(self, batch_size, paired_sampling):
'''
Generate sample.
Args:
batch_size: number of samples.
paired_sampling: whether to use paired sampling.
'''
num_included = 1 + self.categorical.sample([batch_size])
S = self.tril[num_included - 1]
# TODO ideally avoid for loops
for i in range(batch_size):
if paired_sampling and i % 2 == 1:
S[i] = 1 - S[i - 1]
else:
S[i] = S[i, torch.randperm(self.num_players)]
return S
def svs_sample(self, batch_size):
buf = []
for i in range(batch_size):
buf.append(torch.randperm(self.num_players) == torch.arange(self.num_players))
return torch.stack(buf, dim=0)
def get_categorical(self, num_players):
arange = torch.arange(1, num_players)
w = 1 / (arange * (num_players - arange))
w = w / torch.sum(w)
return Categorical(probs=w)
def dummy_sample_with_weight(self, batch_size, paired_sampling, guide_weight):
#num_included = 1 + self.categorical.sample([batch_size])
# we can only do local normalization when doing importance sampling as global normalizatino is intractable
assert guide_weight is not None
assert torch.is_tensor(guide_weight)
assert batch_size > 2
# require: guide_weight [seq_len]
assert len(guide_weight.shape) == 1
categorical = self.get_categorical(len(guide_weight))
num_included = 1 + categorical.sample([batch_size])
seq_len = guide_weight.shape[0]
#S = torch.zeros(batch_size, seq_len, dtype=torch.float32).cpu()
tril = torch.tril(torch.ones(seq_len - 1, seq_len, dtype=torch.float32), diagonal=0)
S = tril[num_included - 1]
w = torch.ones(batch_size, dtype=torch.float32).cpu()
w[0]=100000
w[1]=100000
S[0]=0
S[1]=1
#guide_weight_cpu = guide_weight.cpu()
for i in range(2, batch_size):
if paired_sampling and i % 2 == 1:
S[i] = 1 - S[i - 1]
else:
S[i] = S[i, torch.randperm(seq_len)]
return S, w
def guided_sample(self, batch_size, paired_sampling, guide_weight):
# we can only do local normalization when doing importance sampling as global normalizatino is intractable
assert guide_weight is not None
assert torch.is_tensor(guide_weight)
# require: guide_weight [seq_len]
assert len(guide_weight.shape) == 1
categorical = self.get_categorical(len(guide_weight))
num_included = 1 + categorical.sample([batch_size])
seq_len = guide_weight.shape[0]
S = torch.zeros(batch_size, seq_len, dtype=torch.float32).cpu()
w = torch.zeros(batch_size, dtype=torch.float32).cpu()
guide_weight_cpu = guide_weight.cpu()
for batch_i in range(batch_size):
current_guide_weight_cpu = guide_weight_cpu.clone().tolist()
for feat_i in range(num_included[batch_i]):
current_guide_weight_cpu_tensor = torch.Tensor(current_guide_weight_cpu)
cat_dist = torch.softmax(current_guide_weight_cpu_tensor, dim=-1)
cat_dist_class = Categorical(cat_dist)
sample_feat_id = cat_dist_class.sample()
S[batch_i][sample_feat_id] = 1
current_guide_weight_cpu = torch.cat([torch.Tensor(current_guide_weight_cpu[:sample_feat_id]), torch.Tensor(current_guide_weight_cpu[sample_feat_id+1:])])
w[batch_i] += torch.log(cat_dist[sample_feat_id])
w[batch_i] = binom(self.num_players, num_included[batch_i]) * torch.exp(w[batch_i]) + 1e-6
w[batch_i] = 1.0 / w[batch_i]
return S, w