This repository has been archived by the owner on Oct 13, 2021. It is now read-only.

DLRM TensorFlow2 model conversion to ONNX fails with >2GB tensor #717

Open
piotrm-nvidia opened this issue May 27, 2021 · 0 comments

keras2onnx.convert_keras fails for models that contain tensors larger than 2GB, with the following error:

ValueError: Cannot create a tensor proto whose content is larger than 2GB.

Motivation

I would like to convert the TensorFlow2 DLRM model from NVIDIA DeepLearningExamples to ONNX:
https://github.com/NVIDIA/DeepLearningExamples/tree/master/TensorFlow2/Recommendation/DLRM

This model uses very large embedding tables (>10 GB).

ONNX supports large models through its external data format:
https://github.com/onnx/onnx/blob/master/docs/ExternalData.md

It is possible to export models with >2GB tensors from PyTorch, but Keras2ONNX fails to convert such models.

Reproduction steps

Build a Docker image with TensorFlow2 from NVIDIA NGC using this Dockerfile:

FROM nvcr.io/nvidia/tensorflow:21.04-tf2-py3
RUN pip install -U keras2onnx

Run the following Python script:

import tensorflow as tf
import keras2onnx
import onnx

# Example
FeatureSize = 11316796  # this is too big to be converted
# FeatureSize = 1572176  # this works
OutputSize = 128


class EmbeddingTable(tf.keras.layers.Layer):

    def __init__(self, input_dim, output_dim):
        super(EmbeddingTable, self).__init__(dtype=tf.float32)
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.embedding_table = None

    def build(self, input_shape):
        self.embedding_table = self.add_weight("embedding_table",
                                               shape=[self.input_dim, self.output_dim],
                                               dtype=tf.float32,
                                               initializer=tf.keras.initializers.Identity(),
                                               trainable=False
                                               )

    def call(self, indices):
        return tf.gather(params=self.embedding_table, indices=indices)

class SimpleModel(tf.keras.Model):

    def __init__(self, input_dim, output_dim):
        super(SimpleModel, self).__init__()
        self.embedding_layer = EmbeddingTable(input_dim, output_dim)

    def call(self, input_idxs):
        return self.embedding_layer(input_idxs)


model = SimpleModel(FeatureSize, OutputSize)
output = model(tf.zeros(shape=[1,], dtype="int64"))
onnx_model2 = keras2onnx.convert_keras(model, model.name)  # line 42 (referenced in the traceback)
onnx.save_model(onnx_model2, "model.onnx")

Expected results

The file model.onnx is saved.

Actual results

The script crashes during keras2onnx.convert_keras:

Traceback (most recent call last):
  File "main.py", line 42, in <module>
    onnx_model2 = keras2onnx.convert_keras(model, model.name)
  File "/usr/local/lib/python3.8/dist-packages/keras2onnx/main.py", line 62, in convert_keras
    tf_graph = build_layer_output_from_model(model, output_dict, input_names, output_names)
  File "/usr/local/lib/python3.8/dist-packages/keras2onnx/_parser_tf.py", line 302, in build_layer_output_from_model
    return extract_outputs_from_subclassing_model(model, output_dict, input_names, output_names)
  File "/usr/local/lib/python3.8/dist-packages/keras2onnx/_parser_tf.py", line 263, in extract_outputs_from_subclassing_model
    graph_def, converted_input_indices = _convert_to_constants(
  File "/usr/local/lib/python3.8/dist-packages/keras2onnx/_graph_cvt.py", line 525, in convert_variables_to_constants_v2
    _populate_const_op(output_node, input_node.name, dtype, data, data.shape)
  File "/usr/local/lib/python3.8/dist-packages/keras2onnx/_graph_cvt.py", line 310, in _populate_const_op
    tensor = tensor_util.make_tensor_proto(
  File "/usr/local/lib/python3.8/dist-packages/tensorflow/python/framework/tensor_util.py", line 527, in make_tensor_proto
    raise ValueError(
ValueError: Cannot create a tensor proto whose content is larger than 2GB.
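
A rough size estimate is consistent with the error above. The sketch below computes the raw float32 size of the embedding table for both FeatureSize values from the script and compares it with 2**31 bytes. The threshold is an assumption based on the "2GB" wording of the error message (the protobuf size limit for a single message), and the helper names are made up for this example.

```python
# Assumption: the limit is 2**31 bytes, inferred from the "2GB" in the error message.
LIMIT_BYTES = 2**31
FLOAT32_BYTES = 4
OUTPUT_SIZE = 128  # OutputSize in the reproduction script


def table_bytes(feature_size, output_size=OUTPUT_SIZE, itemsize=FLOAT32_BYTES):
    """Raw size in bytes of a [feature_size, output_size] float32 table."""
    return feature_size * output_size * itemsize


def exceeds_limit(feature_size):
    """True if the table alone would not fit in a single 2GB tensor proto."""
    return table_bytes(feature_size) >= LIMIT_BYTES


for feature_size in (11316796, 1572176):
    size = table_bytes(feature_size)
    print(f"FeatureSize={feature_size}: {size} bytes, "
          f"{size / 2**30:.2f} GiB, exceeds limit: {exceeds_limit(feature_size)}")
```

The failing configuration needs 5,794,199,552 bytes for the table alone, while the working one needs 804,954,112 bytes, which matches the observed behavior.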