Allow saved_model export of TFCLIPModel in save_pretrained by seanmor5 · Pull Request #16886 · huggingface/transformers · GitHub

Conversation

@seanmor5
Contributor

@seanmor5 seanmor5 commented Apr 22, 2022

What does this PR do?

I apologize if this is out of scope. There were a few bugs in TFCLIPModel which prevented the model from being exported using the TensorFlow SavedModel format:

  1. _build_causal_attention_mask uses tf.constant with a runtime-dynamic value. It seems shape_list relies on tf.shape, which returns a symbolic tensor inside autograph, and that prevents the graph from being fully traced. tf.constant does not accept runtime-dynamic values, but tf.fill does, so I replaced the tf.constant with a tf.cast and tf.fill combination (see the sketch after this list). I don't think TFCLIPModel would even run inside a tf.function without this change, because the autograph trace fails.

  2. TFCLIPTextModel needs to override the default serving implementation. The default implementation expects token_type_ids, which is not a valid input here.

  3. serving_output for TFCLIPModel has an issue with tracing through nested dataclasses, which I haven't been able to resolve quite yet. Ideally it should be as easy as calling serving_output on text_model_output and vision_model_output (since there is some convert_to_tensor work going on in each output). TensorFlow complains that TFBaseModelOutputWithPooling is not a tensor, so I figured the tuple conversion would work, but it doesn't seem to be the fix.
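For (1), here is a minimal sketch of the kind of change, just to illustrate the idea; the additive-mask values and the broadcast shape are assumptions for the example, not necessarily the exact merged diff:

import tensorflow as tf

def _build_causal_attention_mask(batch_size, seq_length, dtype=tf.float32):
    # seq_length can be a symbolic (runtime) value inside a traced graph;
    # tf.constant cannot take a symbolic value, but tf.fill can.
    diag = tf.cast(tf.fill((seq_length,), 0.0), dtype)

    # Start from an additive 2D mask where every position is masked out...
    to_mask = tf.cast(tf.fill((seq_length, seq_length), -10000.0), dtype)
    # ...then zero out (unmask) the lower triangle and the diagonal, i.e. causal attention.
    to_mask = tf.linalg.band_part(to_mask, 0, -1)
    to_mask = tf.linalg.set_diag(to_mask, diagonal=diag)

    return tf.broadcast_to(to_mask, shape=(batch_size, 1, seq_length, seq_length))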

I added tests for exporting each of TFCLIPModel, TFCLIPTextModel, and TFCLIPVisionModel as saved models, to verify that the individual components work and that it's the integration of both that's failing.

cc @LysandreJik for TensorFlow changes

@HuggingFaceDocBuilderDev

HuggingFaceDocBuilderDev commented Apr 22, 2022

The documentation is not available anymore as the PR was closed or merged.

Member

@gante gante left a comment

Thank you @seanmor5 for these valuable changes 🙏 We know that serialization (and compiling into graphs) is dodgy at the moment, and we are planning to improve it soon -- perhaps borrowing some changes from this PR.

Added a few minor notes.

Member

Converting to a tuple would break the assumption in the type hint. Can we remove this line?

Contributor Author

Updated, but unfortunately the test fails because TFCLIPOutput is a nested dataclass. I'm not 100% sure whether TF can trace through nested dataclasses, or whether the failure lies somewhere else.

Member

Hmmm, that is inconvenient. Changing the output type would mean a different TF/PT API, which we would like to avoid, but failing at serving time is also very undesirable 🤔

Can you share the stack trace of the error?

@gante
Member

gante commented Apr 26, 2022

Tagging @ydshieh -- this PR adds several tests

@gante gante requested a review from ydshieh April 26, 2022 16:54
@seanmor5 seanmor5 force-pushed the tf-clip-saved-model branch from bcbb239 to 177684a Compare April 26, 2022 22:40
@seanmor5
Contributor Author

@gante No problem, glad to help now. And if you have plans to improve graph/saved_model serialization in the future, I will be glad to help then as well :D

@ydshieh
Collaborator

ydshieh commented Apr 27, 2022

@seanmor5 Thank you for this PR 🚀 !

@gante Let me take a look before merge.

I haven't checked yet, but from @seanmor5's description (regarding tf.constant and dynamic shapes), it sounds as if (almost) all models would be unable to be exported as saved_model or run in graph mode. However, I don't think this is the case, as @gante is able to get things working with XLA.

Therefore I would like to check a bit more on my side 😄

@gante
Member

gante commented Apr 27, 2022

I haven't checked yet, but from @seanmor5's description (regarding tf.constant and dynamic shapes), it sounds as if (almost) all models would be unable to be exported as saved_model or run in graph mode. However, I don't think this is the case, as @gante is able to get things working with XLA.

I think we will have to touch a significant number of models for XLA/Saved Model, to be honest 😅

@ydshieh
Collaborator

ydshieh commented Apr 27, 2022

Hi, I can confirm the issue with tf.constant() when a symbolic tensor is used as the shape.

If I change the signature of serving to a fixed shape, it works:

    @tf.function(
        input_signature=[
            {
                "input_ids": tf.TensorSpec((3, 5), tf.int32, name="input_ids"),
                "attention_mask": tf.TensorSpec((3, 5), tf.int32, name="attention_mask"),
            }
        ]
    )
    def serving(self, inputs):

I will check running the model in tf.function too.

@ydshieh
Collaborator

ydshieh commented Apr 27, 2022

Regarding tf.function, I am able to make the following code work. It works even with jit_compile=True.
@seanmor5 Could you elaborate a bit more on your concern regarding tf.function?

Code snippet

from transformers import TFCLIPTextModel, TFCLIPVisionModel, TFCLIPModel, CLIPConfig
import os
import tensorflow as tf

ckpt = "openai/clip-vit-base-patch32"
config = CLIPConfig.from_pretrained(ckpt)
text_config = config.text_config

model = TFCLIPTextModel.from_config(text_config)


def get_inputs(batch_size, seq_len):
    input_ids = tf.constant(1, shape=[batch_size, seq_len], dtype=tf.int32)
    attention_mask = tf.constant(1, shape=[batch_size, seq_len], dtype=tf.int32)
    inputs = {"input_ids": input_ids, "attention_mask": attention_mask}
    return inputs


inputs_1 = get_inputs(3, 5)
inputs_2 = get_inputs(4, 7)

outputs = model(**inputs_1)
# print(outputs)


@tf.function
def foo(inputs):
    outputs = model(**inputs)
    return outputs

outputs = foo(inputs_1)
outputs = foo(inputs_2)

@ydshieh
Collaborator

ydshieh commented Apr 27, 2022

Sorry for being a bit picky, but I would prefer to have better context before deciding on such changes. In particular, if this issue only occurs when saving to saved_model, I think we can do some more research first to see if there is a better solution.

@seanmor5
Contributor Author

seanmor5 commented Apr 27, 2022

@ydshieh No problem! You are right, I was assuming there might be issues with tf.function, but because the input shapes are static and known at trace time, it makes sense that it works. I think the issue is exclusive to saved_model, where the input shapes might not be known and so the shape could be symbolic.

EDIT: Below is a failing case for tf.function assuming a non-static sequence length. This is probably not really desirable behavior anyway, given the limitations of dynamic shapes in XLA, so it's probably okay to ignore; I'm just pointing it out for due diligence :)

@tf.function(
    input_signature=[tf.TensorSpec((3, None), dtype=tf.int32), tf.TensorSpec((3, None), dtype=tf.int32)],
    jit_compile=True,
    experimental_relax_shapes=True
)
def foo(input_ids, attn_mask):
    outputs = model(input_ids=input_ids, attention_mask=attn_mask)
    return outputs

inputs = [(tf.constant(1, shape=[x, y], dtype=tf.int32),
           tf.constant(1, shape=[x, y], dtype=tf.int32))
          for x, y in zip([3, 3, 3, 3, 3], [1, 2, 3, 4, 5])]

for inp in inputs:
    print(foo(*inp))

I am open to exploring whatever other options you think might be better.

@ydshieh
Collaborator

ydshieh commented Apr 27, 2022

OK. Maybe doing some more research is good. I will find time to get some more ideas.

I always feel that these limitations are not easy to handle, but so far my own use cases could get by with a fixed shape (apart from the batch dim).

@gante
Member

gante commented Apr 27, 2022

For context, we have been suggesting to users to pad and set the second dimension of the shape to the model's maximum sequence length, in this sort of situation. However, it's unoptimized, a manual process, and it doesn't work well in all situations (e.g. in auto-regressive text generation with models like GPT-2, the defined padded input length has to be smaller than the max sequence length, to allow new tokens to come in, but big enough to handle all prompts).
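For context, this is the kind of padding workaround being described; a rough illustration only, where the checkpoint name and max_length=77 (CLIP's text encoder limit) are assumptions for the example:

from transformers import CLIPTokenizer

tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch32")
inputs = tokenizer(
    ["a photo of a cat", "a photo of two dogs"],
    padding="max_length",   # always pad to the same fixed length
    max_length=77,
    return_tensors="tf",
)
# Every batch now has shape (batch_size, 77), so a fixed-shape serving
# signature (or an XLA-compiled tf.function) is traced only once.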

@ydshieh
Collaborator

ydshieh commented Apr 27, 2022

I understand, and that's what I will do, although it's not ideal.

One question: if we use tf.fill as suggested by @seanmor5, are we able to run the failing case provided above?
We know from the PR that it will work for saved_model, but I would like to verify it also works for the above example.
(I have a feeling that it won't work even with tf.fill, but I need to verify.)

@seanmor5
Contributor Author

@ydshieh So I applied the patch with tf.fill, and the function does run with input_signature=[tf.TensorSpec((3, None), dtype=tf.int32), tf.TensorSpec((3, None), dtype=tf.int32)].

One thing to note is that without the input signature to relax the sequence length constraint, the function is retraced, which can be a performance hit. With tf.fill, I can verify the following is not retraced with the relaxed input signature:

compiles = 0  # counts how many times the function body is traced
is_not_retraced = True

@tf.function(
    input_signature=[tf.TensorSpec((3, None), dtype=tf.int32), tf.TensorSpec((3, None), dtype=tf.int32)],
    jit_compile=True,
    experimental_relax_shapes=True
)
def foo(input_ids, attn_mask):
    global compiles
    compiles += 1
    outputs = model(input_ids=input_ids, attention_mask=attn_mask)
    return outputs

inputs = [(tf.constant(1, shape=[x, y], dtype=tf.int32),
           tf.constant(1, shape=[x, y], dtype=tf.int32))
          for x, y in zip([3, 3, 3, 3, 3], [1, 2, 3, 4, 5])]

# Trace once with the first input, then check that later inputs reuse the same concrete function.
prev_concrete_f = foo.get_concrete_function(*inputs[0])

for inp in inputs:
    concrete_f = foo.get_concrete_function(*inp)
    is_not_retraced = is_not_retraced and concrete_f is prev_concrete_f

assert is_not_retraced

But this version is retraced without an input signature:

compiles = 0  # counts how many times the function body is traced
is_not_retraced = True

@tf.function
def foo(input_ids, attn_mask):
    global compiles
    compiles += 1
    outputs = model(input_ids=input_ids, attention_mask=attn_mask)
    return outputs

inputs = [(tf.constant(1, shape=[x, y], dtype=tf.int32),
           tf.constant(1, shape=[x, y], dtype=tf.int32))
          for x, y in zip([3, 3, 3, 3, 3], [1, 2, 3, 4, 5])]

prev_concrete_f = foo.get_concrete_function(*inputs[0])

for inp in inputs:
    concrete_f = foo.get_concrete_function(*inp)
    is_not_retraced = is_not_retraced and concrete_f is prev_concrete_f

assert is_not_retraced
---------------------------------------------------------------------------
AssertionError                            Traceback (most recent call last)
/var/folders/57/yg31bn915kg_s_tzht3by3r80000gp/T/ipykernel_24110/149582172.py in <module>
     18     is_not_retraced = is_not_retraced and concrete_f is prev_concrete_f
     19 
---> 20 assert is_not_retraced

AssertionError: 

@ydshieh
Collaborator

ydshieh commented Apr 27, 2022

@seanmor5 Thank you for all the effort in providing this information. I will take a look (you know a lot about the TF graph stuff 🚀!)

@ydshieh
Collaborator

ydshieh commented Apr 27, 2022

I can confirm everything @seanmor5 states, and I also found the relevant TF doc here.

[Screenshot: excerpt from the TF documentation]

I'm starting to think that's by design. I'm not sure why TF decided to do so; maybe there are reasons to keep static constants separate from dynamically shaped fills (performance considerations?).

I am in favor of approving. I will take some time to check the added tests, and think about it a bit more in the meantime.

@seanmor5
Contributor Author

@ydshieh Thank you! I've had to debug way too much TF code in my life, so I've gotten used to it :)

So unfortunately the last thing that needs to be addressed is the failing serving_output test for the joint TFCLIPModel, and I'm not quite sure what the fix might be. Here is the stack trace of the failing test (cc @gante):

E         ValueError: in user code:
E
E
E             ValueError: Got a non-Tensor value TFBaseModelOutputWithPooling(last_hidden_state=<tf.Tensor 'StatefulPartitionedCall:6' shape=(None, None, 32) dtype=float32>, pooler_output=<tf.Tensor 'StatefulPartitionedCall:7' shape=(None, 32) dtype=float32>, hidden_states=<tf.Tensor 'StatefulPartitionedCall:5' shape=(6, None, None, 32) dtype=float32>, attentions=<tf.Tensor 'StatefulPartitionedCall:4' shape=(5, None, 4, None, None) dtype=float32>) for key 'text_model_output' in the output of the function __inference_serving_217811 used to generate the SavedModel signature 'serving_default'. Outputs for functions used as signatures must be a single Tensor, a sequence of Tensors, or a dictionary from string to Tensor.

@gante
Member

gante commented Apr 28, 2022

@seanmor5 that exception is raised here -- i.e. on the TF serving side, so outside our control.

It means we have to change something about our API if we want to support serving for all outputs. I'm calling in for second opinions: @sgugger, as he's more experienced in these situations, and @Rocketknight1, my fellow TF MLE.


@sgugger @Rocketknight1 some context:

  1. This PR attempts to fix serving for TF CLIP which, as you know, has an image and a text component;
  2. If we want to output everything (i.e. attentions and hidden states for both the vision and the text models), we have the existing TFCLIPOutput (here), which contains tf.Tensor and TFBaseModelOutputWithPooling members. Note: contrary to other ModelOutput child classes, TFCLIPOutput can contain non-tf.Tensor members;
  3. Our serving_output functions return classes inherited from ModelOutput, like TFBaseModelOutputWithPooling (e.g. gpt2, bert). We would like to do the same here;
  4. TF raises an exception whenever we attempt to serve structures that do not contain tensors, sequences of tensors, or dictionaries of tensors (see the first link in this comment)... which is the case here: TFCLIPOutput does not fit those criteria (see why below);
  5. @seanmor5 originally proposed to return .to_tuple() (this one) instead, and it works. However, in that case, the API would be different for this model (and across frameworks), and we would lose the ability to access fields by name.

A few additional notes:

  1. Happy to move this to a separate issue, as it may warrant further discussion;
  2. Any decision here would set a precedent for multimodal TF models;
  3. More specifically, the exception will not be raised if we are serving a structure containing CompositeTensor (code) members. According to its docstring, it can expand whatever tf.nest can expand, and if the leaves are tf.Tensor, we are all good. Looking at the docs of tf.nest (here), we can see that it treats @dataclass-decorated structures as atoms. So while we can serve most @dataclass-decorated ModelOutput classes (their members are tf.Tensor), we cannot serve a @dataclass-decorated ModelOutput containing other @dataclass-decorated ModelOutput members, which would likely be our desired case for multimodal models.

Potential solutions:

  1. Consider using namedtuples? It has a few drawbacks
  2. Use dictionaries and, if needed, add syntactic sugar to access fields by name and by index?
  3. Expand nested fields -- instead of TFCLIPOutput holding two TFBaseModelOutputWithPooling, it could hold a tuple for each TFBaseModelOutputWithPooling, or expand their attributes directly into TFCLIPOutput (e.g. text_model_output -> text_model_output_attentions, text_model_output_hidden_states, ...; see the sketch after this list)
  4. ???
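To make option 3 a bit more concrete, a hypothetical serving-only flattening could look roughly like this; the field names are taken from TFCLIPOutput, but the function body is purely illustrative, not the merged code:

def serving_output(self, output):
    # Expand the nested TFBaseModelOutputWithPooling members into top-level
    # tensors, so that tf.nest only ever sees tf.Tensor leaves at serving time.
    return {
        "logits_per_image": output.logits_per_image,
        "logits_per_text": output.logits_per_text,
        "text_embeds": output.text_embeds,
        "image_embeds": output.image_embeds,
        "text_model_output_last_hidden_state": output.text_model_output.last_hidden_state,
        "text_model_output_pooler_output": output.text_model_output.pooler_output,
        "vision_model_output_last_hidden_state": output.vision_model_output.last_hidden_state,
        "vision_model_output_pooler_output": output.vision_model_output.pooler_output,
    }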

@sgugger
Collaborator

sgugger commented Apr 28, 2022

For more information, is it possible to convert the nested fields (here text_model_output and vision_model_output, which are TFBaseModelOutputWithPooling and not tensors) to tuples inside the serving part only? Or would that change need to be made all the time?

@ydshieh
Collaborator

ydshieh commented Apr 28, 2022

I haven't tried TF serving yet (probably just once). For HF models, while serving_output returns things like TFBaseModelOutputWithPooling, what happens when we use the converted TF serving models? For example, if we call those TF serving models, is the output still a TFBaseModelOutputWithPooling, or is it just a dictionary?

@gante
Member

gante commented Apr 28, 2022

For more information, is it possible to convert the nested fields (here text_model_output and vision_model_output, which are TFBaseModelOutputWithPooling and not tensors) to tuples inside the serving part only? Or would that change need to be made all the time?

@sgugger Yes, it is possible, and it should solve the problem (tuple is supported by tf.nest). The only difference (to PT) is that we wouldn't be able to access the field by name, but that seems like a small price to pay for a simple solution.

@ydshieh
Collaborator

ydshieh commented Apr 28, 2022

For more information, is it possible to convert the nested fields (here text_model_output and vision_model_output, which are TFBaseModelOutputWithPooling and not tensors) to tuples inside the serving part only? Or would that change need to be made all the time?

@sgugger Yes, it is possible, and it should solve the problem (tuple is supported by tf.nest). The only difference (to PT) is that we wouldn't be able to access the field by name, but that seems like a small price to pay for a simple solution.

I guess @sgugger doesn't necessarily mean we should use a tuple instead of a dict; it's rather a question about where we should do the conversion (?).

I would much prefer using a dictionary, though. Maybe let's check what the converted TF serving model really gives as outputs before making the decision?

@gante
Member

gante commented Apr 28, 2022

I haven't tried TF serving yet (probably just once). For HF models, while serving_output returns things like TFBaseModelOutputWithPooling, what happens when we use the converted TF serving models? For example, if we call those TF serving models, is the output still a TFBaseModelOutputWithPooling, or is it just a dictionary?

@ydshieh I can confirm that the output of a model loaded with tf.keras.models.load_model is a dict, not a subclass of ModelOutput.

[Screenshot: output types of a TFCLIPTextModel after reloading it with tf.keras.models.load_model]
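For reference, a minimal sketch of this check; the export path is a placeholder and the inputs mirror the earlier snippets:

import tensorflow as tf

# Assume the text model was exported earlier, e.g. with
# model.save_pretrained("clip_text", saved_model=True), which writes a
# SavedModel under clip_text/saved_model/1.
loaded = tf.keras.models.load_model("clip_text/saved_model/1")

inputs_1 = {
    "input_ids": tf.constant(1, shape=(3, 5), dtype=tf.int32),
    "attention_mask": tf.constant(1, shape=(3, 5), dtype=tf.int32),
}

outputs = loaded(inputs_1)
print(type(outputs))         # a plain dict, not a ModelOutput subclass
print(list(outputs.keys()))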


I experimented with converting the problematic variables to multiple formats:

  • to dict with dict() and with dataclasses.asdict()
  • to tuple with tuple(), .to_tuple()
  • using tf.nest.flatten() as the CompositeTensor docstring suggests

All of them result in the same exception. I also tried to look into documentation on how to cast into a CompositeTensor, but there is none we can use (there is an experimental function in tensorflow-probability, which is not a dependency of ours atm, and even it throws an exception related to the expected input object).

The only thing that seems to work is a flat serving structure, without nested components.

@seanmor5, I got the exact same exception when attempting your original solution, with return output.to_tuple(). The command I ran was RUN_SLOW=1 py.test -vv tests/clip/test_modeling_tf_clip.py::TFCLIPModelTest::test_saved_model_creation_extended -- were you also running this command?

@sgugger
Collaborator

sgugger commented Apr 28, 2022

Are those composite outputs really useful for the serving? Can't we just remove them entirely?

@gante
Member

gante commented Apr 28, 2022

Probably. We can leave them as a TODO and wait for an issue :) @seanmor5 would you be okay with that? (I'm afraid we are hitting a hard problem for a feature we probably don't need)
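If the composite outputs are dropped, a hypothetical serving_output could simply leave the nested members unset; purely a sketch of the idea being discussed here, not the merged code:

def serving_output(self, output):
    # Serve only the flat tensor fields; text_model_output / vision_model_output
    # are intentionally left out (the TODO above), so the serving signature
    # should contain only plain tensors.
    # TFCLIPOutput is the output dataclass defined in modeling_tf_clip.py.
    return TFCLIPOutput(
        logits_per_image=output.logits_per_image,
        logits_per_text=output.logits_per_text,
        text_embeds=output.text_embeds,
        image_embeds=output.image_embeds,
    )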

@ydshieh
Collaborator

ydshieh commented Apr 28, 2022

Thank you for the work, @gante! Surprised that dict is not working 😢. Adding a TODO is good for me. Meanwhile, I think users could customize the format if they want to serve the model. We can even add a comment in the code around serving_output.

Could you share the code you use 🙏 ?

@seanmor5
Contributor Author

seanmor5 commented May 2, 2022

@ydshieh Thank you for the review! I believe I addressed all comments

@seanmor5 seanmor5 changed the title [WIP] Allow saved_model export of TFCLIPModel in save_pretrained Allow saved_model export of TFCLIPModel in save_pretrained May 3, 2022
@ydshieh
Collaborator

ydshieh commented May 3, 2022

@seanmor5 Thank you. I will take a look.

Regarding the tests:

  • You can ignore Model templates runner / run_tests_templates (pull_request).
  • About ci/circleci: check_code_quality, you can run make style and make quality to fix it (after installing hf-doc-builder with pip install hf-doc-builder -U).

@seanmor5
Contributor Author

seanmor5 commented May 3, 2022

@ydshieh Thank you! I was running black locally, but for some reason it was not catching the formatting issues. I have fixed them, though it seems the code quality check is still failing, but from the docs/source path, which I have not touched.

@ydshieh
Collaborator

ydshieh commented May 3, 2022

@seanmor5

Regarding the style, could you follow this comment, please?
#17008 (comment)

And I just merged a (big) PR, so you will need to move the test file a bit; please see this comment:
#17008 (comment)

Thank you so much!

@seanmor5 seanmor5 force-pushed the tf-clip-saved-model branch from b41515f to eb2a2a7 Compare May 3, 2022 13:00
@seanmor5
Contributor Author

seanmor5 commented May 3, 2022

@ydshieh Beautiful! That worked great!

Collaborator

@ydshieh ydshieh left a comment

LGTM. Let's wait for @sgugger's approval too, and see if @Rocketknight1 has any further comments :-)

Collaborator

@sgugger sgugger left a comment

Thanks for all the work on this!

Member

@Rocketknight1 Rocketknight1 left a comment

Quick review: Replacing tf.constant with tf.fill for potentially variable sizes is completely correct, good job! The added test is also very valuable, so I'm happy to approve this.

@ydshieh ydshieh merged commit 279bc58 into huggingface:main May 4, 2022
@ydshieh
Collaborator

ydshieh commented May 4, 2022

Merged. Thank you again, @seanmor5 .

Let's see if we could find a way to make TFCLIPModel work.

@seanmor5 seanmor5 deleted the tf-clip-saved-model branch May 4, 2022 14:40
elusenji pushed a commit to elusenji/transformers that referenced this pull request Jun 12, 2022
…ce#16886)

* CLIP Serving

* Add type hints per code review

* Use black, flake8, and isort

* Update src/transformers/models/clip/modeling_tf_clip.py

Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>

* Rollback serving_output and add TODO

* Remove irrelevant portions of failing tests

* Revert "Rollback serving_output and add TODO"

This reverts commit a4abfa6ba3b7875a13538dbc2ddc4eb17dfcca8d.

* Rollback to original test/serving_output

* Fix unused var

* Apply suggestions from code review

* Update formatting with black

* Fix style again from rebase

* Update tests/models/clip/test_modeling_tf_clip.py

Co-authored-by: Yih-Dar <2521628+ydshieh@users.noreply.github.com>

Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>
Co-authored-by: Sean Moriarity <sean.l.moriarity.mil@army.mil>
Co-authored-by: Yih-Dar <2521628+ydshieh@users.noreply.github.com>