forked from kulinseth/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
OpaqueTensorImpl.h
186 lines (164 loc) · 5.95 KB
/
OpaqueTensorImpl.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
#pragma once
#include <c10/core/MemoryFormat.h>
#include <c10/core/SymIntArrayRef.h>
#include <c10/core/TensorImpl.h>
#include <c10/util/Exception.h>

#include <utility>
namespace at {
// An "Opaque" TensorImpl -- there are no strides and (for now)
// even data() is not supported (thus no pointer arithmetic).
// NOTE: We could allow data() in the future, but would have to ensure pointer
// arithmetic code is properly guarded.
//
// NOTE: This does not support resize_ (and other metadata-changing ops) because
// of `shallow_copy_and_detach`. We would need to define an interface to
// "shallow copy" in order to add support.
template <typename OpaqueHandle>
struct TORCH_API OpaqueTensorImpl : public TensorImpl {
  /**
   * Build an opaque tensor wrapping `opaque_handle`.
   *
   * The constructor calls `set_storage_access_should_throw()`, selects
   * `SizesStridesPolicy::CustomStrides`, records `sizes` (no strides are
   * supplied), recomputes numel and stores `is_non_overlapping_and_dense`.
   *
   * @param key_set    dispatch keys forwarded to the TensorImpl base.
   * @param data_type  element type forwarded to the TensorImpl base.
   * @param device     device forwarded to the TensorImpl base.
   * @param opaque_handle  backend-specific handle; taken by value and moved
   *                       into this object.
   * @param sizes      logical sizes of the tensor.
   * @param is_non_overlapping_and_dense  value stored into the corresponding
   *                                      TensorImpl flag (defaults to true).
   */
  // public constructor for now...
  OpaqueTensorImpl(
      at::DispatchKeySet key_set,
      const caffe2::TypeMeta data_type,
      c10::Device device,
      OpaqueHandle opaque_handle,
      c10::IntArrayRef sizes,
      bool is_non_overlapping_and_dense = true)
      : TensorImpl(key_set, data_type, device),
        opaque_handle_(std::move(opaque_handle)) {
    set_storage_access_should_throw();
    set_custom_sizes_strides(SizesStridesPolicy::CustomStrides);
    sizes_and_strides_.set_sizes(sizes);
    refresh_numel();
    is_non_overlapping_and_dense_ = is_non_overlapping_and_dense;
  }

  // Destructor doesn't call release_resources because it's
  // unnecessary; don't forget to change that if needed!

  /**
   * Release the base-class resources, then overwrite the handle with an
   * empty-brace-initialized `OpaqueHandle`.
   */
  void release_resources() override {
    TensorImpl::release_resources();
    opaque_handle_ = {};
  }

  /**
   * Unsupported on opaque tensors: always raises an error.
   *
   * Parameter names are commented out because they are intentionally unused.
   */
  void set_size(int64_t /*dim*/, int64_t /*new_size*/) override {
    // TORCH_CHECK(false, ...) is the non-deprecated spelling of AT_ERROR(...).
    TORCH_CHECK(false, "opaque tensors do not have set_size");
  }

  /**
   * Unsupported on opaque tensors: always raises an error.
   */
  void set_stride(int64_t /*dim*/, int64_t /*new_stride*/) override {
    TORCH_CHECK(false, "opaque tensors do not have set_stride");
  }

  /**
   * Unsupported on opaque tensors: always raises an error.
   */
  void set_storage_offset(int64_t /*storage_offset*/) override {
    TORCH_CHECK(false, "opaque tensors do not have set_storage_offset");
  }

#ifdef DEBUG
  /**
   * Debug-only override: asserts that `storage_` was never set and reports
   * that this tensor has no storage.
   */
  bool has_storage() const override {
    TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
        !storage_, "OpaqueTensorImpl assumes that storage_ is never set");
    return false;
  }
#endif

  /**
   * Return a TensorImpl that is a shallow-copy of this TensorImpl.
   *
   * For usage of `version_counter` and `allow_tensor_metadata_change`,
   * see NOTE [ TensorImpl Shallow-Copying ].
   */
  c10::intrusive_ptr<TensorImpl> shallow_copy_and_detach(
      const c10::VariableVersion& version_counter,
      bool allow_tensor_metadata_change) const override {
    return shallow_copy_and_detach_impl(
        version_counter, allow_tensor_metadata_change);
  }

  /**
   * Return a TensorImpl that is a shallow-copy of this TensorImpl.
   *
   * Rvalue overload: the version counter is moved into the copy.
   *
   * For usage of `version_counter` and `allow_tensor_metadata_change`,
   * see NOTE [ TensorImpl Shallow-Copying ].
   */
  c10::intrusive_ptr<TensorImpl> shallow_copy_and_detach(
      c10::VariableVersion&& version_counter,
      bool allow_tensor_metadata_change) const override {
    return shallow_copy_and_detach_impl(
        std::move(version_counter), allow_tensor_metadata_change);
  }

  /**
   * Shallow-copies data from another TensorImpl into this TensorImpl.
   *
   * `impl` is downcast with `static_cast` to the same OpaqueTensorImpl
   * instantiation; the only guard is the key-set compatibility assertion
   * below, so callers must pass an impl of the matching type.
   *
   * For why this function doesn't check this TensorImpl's
   * `allow_tensor_metadata_change_`, see NOTE [ TensorImpl Shallow-Copying ].
   */
  void shallow_copy_from(const c10::intrusive_ptr<TensorImpl>& impl) override {
    // TORCH_INTERNAL_ASSERT is the non-deprecated spelling of AT_ASSERT.
    TORCH_INTERNAL_ASSERT(has_compatible_shallow_copy_type(impl->key_set()));
    auto opaque_impl =
        static_cast<const OpaqueTensorImpl<OpaqueHandle>*>(impl.get());
    copy_tensor_metadata(
        /*src_impl=*/opaque_impl,
        /*dest_impl=*/this,
        /*version_counter=*/version_counter(),
        /*allow_tensor_metadata_change=*/allow_tensor_metadata_change());
    refresh_numel();
  }

  /// Read-only access to the wrapped handle.
  const OpaqueHandle& opaque_handle() const {
    return opaque_handle_;
  }

  /// Mutable access to the wrapped handle.
  OpaqueHandle& unsafe_opaque_handle() {
    return opaque_handle_;
  }

 protected:
  /**
   * Copy the tensor metadata fields (e.g. sizes / strides / storage pointer /
   * storage_offset) from one TensorImpl to another TensorImpl, then copy the
   * opaque handle as well.
   *
   * For usage of `version_counter` and `allow_tensor_metadata_change`, see NOTE
   * [ TensorImpl Shallow-Copying ].
   */
  static void copy_tensor_metadata(
      const OpaqueTensorImpl<OpaqueHandle>* src_opaque_impl,
      OpaqueTensorImpl<OpaqueHandle>* dest_opaque_impl,
      const c10::VariableVersion& version_counter,
      bool allow_tensor_metadata_change) {
    TensorImpl::copy_tensor_metadata(
        src_opaque_impl,
        dest_opaque_impl,
        version_counter,
        allow_tensor_metadata_change);

    // OpaqueTensorImpl-specific fields.
    dest_opaque_impl->opaque_handle_ = src_opaque_impl->opaque_handle_;
  }

  /**
   * Rvalue-version-counter overload of the function above; the counter is
   * moved into the base-class call.
   */
  static void copy_tensor_metadata(
      const OpaqueTensorImpl<OpaqueHandle>* src_opaque_impl,
      OpaqueTensorImpl<OpaqueHandle>* dest_opaque_impl,
      c10::VariableVersion&& version_counter,
      bool allow_tensor_metadata_change) {
    TensorImpl::copy_tensor_metadata(
        src_opaque_impl,
        dest_opaque_impl,
        std::move(version_counter),
        allow_tensor_metadata_change);

    // OpaqueTensorImpl-specific fields.
    dest_opaque_impl->opaque_handle_ = src_opaque_impl->opaque_handle_;
  }

 private:
  /**
   * Shared body of both `shallow_copy_and_detach` overloads.
   *
   * `VersionCounterT` is deduced as either `const c10::VariableVersion&` or
   * `c10::VariableVersion`; `std::forward` then selects the matching
   * `copy_tensor_metadata` overload, so the lvalue path copies the counter and
   * the rvalue path moves it.
   */
  template <typename VersionCounterT>
  c10::intrusive_ptr<TensorImpl> shallow_copy_and_detach_impl(
      VersionCounterT&& version_counter,
      bool allow_tensor_metadata_change) const {
    auto impl = c10::make_intrusive<OpaqueTensorImpl<OpaqueHandle>>(
        key_set(),
        dtype(),
        device(),
        opaque_handle_,
        sizes_and_strides_.sizes_arrayref());
    copy_tensor_metadata(
        /*src_opaque_impl=*/this,
        /*dest_opaque_impl=*/impl.get(),
        /*version_counter=*/std::forward<VersionCounterT>(version_counter),
        /*allow_tensor_metadata_change=*/allow_tensor_metadata_change);
    impl->refresh_numel();
    return impl;
  }

  /// Type name reported for this TensorImpl subclass.
  const char* tensorimpl_type_name() const override {
    return "OpaqueTensorImpl";
  }

  // The backend-specific handle owned by this tensor.
  OpaqueHandle opaque_handle_;
};
} // namespace at