Support for Segment Anything Model 2 (SAM 2) by haithamkhedr · Pull Request #32394 · huggingface/transformers · GitHub

Conversation

@haithamkhedr
Copy link
Contributor

@haithamkhedr haithamkhedr commented Aug 2, 2024

This PR integrates SAM 2 models into Hugging Face Transformers (Closes #32308). Sample usage:

Video Predictor

from transformers import Sam2VideoPredictor
import torch
import numpy as np

video_dir = "path/to/your/video"  # directory containing the video frames
predictor = Sam2VideoPredictor.from_pretrained("facebook/sam2-hiera-large-hf").cuda()
inference_state = predictor.init_state(video_path=video_dir)

# Prompt input: a single positive click
points = np.array([[210, 350]], dtype=np.float32)
labels = np.array([1], dtype=np.int32)

with torch.autocast("cuda", dtype=torch.bfloat16):
    # Select the object of interest on the first frame to generate a mask
    _, out_obj_ids, out_mask_logits = predictor.add_new_points(
        inference_state=inference_state,
        frame_idx=0,
        obj_id=0,
        points=points,
        labels=labels,
    ).to_tuple()
    # Propagate the mask through the video
    for output in predictor.propagate_in_video(inference_state):
        out_frame_idx, out_obj_ids, out_mask_logits = output.to_tuple()
        # store/plot output

Image Predictor

from transformers import Sam2ImagePredictor
import torch
import numpy as np

predictor = Sam2ImagePredictor.from_pretrained("facebook/sam2-hiera-large-hf")

img = np.random.rand(1080, 920, 3).astype(np.float32)  # dummy image
# Prompt input: a single positive click
points = np.array([[100, 100]], dtype=np.float32)  # dummy point
labels = np.array([1], dtype=np.int32)

with torch.autocast("cuda", dtype=torch.bfloat16):
    predictor.set_image(img)
    output = predictor.predict(
        point_coords=points,
        point_labels=labels,
    )
    masks, scores, low_res_masks = output.to_tuple()

TODO:

  • Add Image/Video predictors
  • Push Hiera-S/T/L models
  • Documentation
  • Add tests
  • Notebooks

@SangbumChoi SangbumChoi mentioned this pull request Aug 3, 2024
@SangbumChoi
Copy link
Contributor

@haithamkhedr Hi Haitham! I saw that you were directly working on SAM 2. We have closed #32317 to clarify that this is the main PR, driven by the original authors.
If there is anything that requires some help I can help you with this.

@RUFFY-369
Copy link
Contributor

@haithamkhedr +1

@ArthurZucker
Copy link
Collaborator

cc @NielsRogge , @qubvel and @amyeroberts !

@amyeroberts
Copy link
Contributor

@haithamkhedr Awesome - super excited to see SAM 2 available in transformers! 🥳 Let us know when the PR's ready for review or if you have any qs in the meantime

@bhack
Copy link
Contributor

bhack commented Aug 5, 2024

A question for the HF team here, though this has also been requested upstream by other users:
how do you envision fine-tuning this type of model?

@haithamkhedr haithamkhedr marked this pull request as ready for review August 7, 2024 04:26
@haithamkhedr
Copy link
Contributor Author

@haithamkhedr Awesome - super excited to see SAM 2 available in transformers! 🥳 Let us know when the PR's ready for review or if you have any qs in the meantime

Hi @amyeroberts, it would be great to get an initial review on this PR now, it's currently functional and would appreciate your feedback. Thanks!

@qubvel
Copy link
Contributor

qubvel commented Aug 8, 2024

Hi @haithamkhedr thanks for working on adding the model!

It's great that the model is already functional and we are looking forward to adding it. However, transformers follows specific design patterns for modeling code. Here is the documentation:

You can also get inspiration from the SAM (v1) model:

Let me know if you have any specific questions regarding implementation!

@haithamkhedr
Copy link
Contributor Author

haithamkhedr commented Aug 8, 2024

Hi @haithamkhedr thanks for working on adding the model!

It's great that the model is already functional and we are looking forward to adding it. However, transformers follows specific design patterns for modeling code. Here is the documentation:

You can also get inspiration from the SAM (v1) model:

Let me know if you have any specific questions regarding implementation!

Hi @qubvel, thanks for the feedback. So overall, the interaction with the models has to be through a forward function only? Our video predictor is stateful and supports user interaction as demonstrated in the usage above; what would be the best way to implement this?

@bhack
Copy link
Contributor

bhack commented Aug 8, 2024

There was a similar "stateful issue" for serving inference on TorchServe, but there was a workaround of sorts: pytorch/serve#2743 (comment)

@bhack
Copy link
Contributor

bhack commented Aug 8, 2024

There will be something similar with ONNX exports when they become available for the video mode (currently users only have image exports): microsoft/onnxruntime#20943

@qubvel
Copy link
Contributor

qubvel commented Aug 8, 2024

So overall, the interaction with the models has to be through a forward function only?

I suppose you can make something similar to text models and the model.generate(...) method. The forward method should be able to accept the current frame pixel_values and some state object as inputs and produce masks, scores, ..., and an updated state as output. Then, another method would loop over it (or it might be a separate model for video). Does that fit the SAM 2 model?

@amyeroberts what do you think regarding this design?

@haithamkhedr
Copy link
Contributor Author

So overall, the interaction with the models has to be through a forward function only?

I suppose you can make something similar to text models and the model.generate(...) method. The forward method should be able to accept the current frame pixel_values and some state object as inputs and produce masks, scores, ..., and an updated state as output. Then, another method would loop over it (or it might be a separate model for video). Does that fit the SAM 2 model?

@amyeroberts what do you think regarding this design?

This should work: the state is updated in place, so looping over forward would work. However, I was mainly wondering whether the transformers library guidelines allow supporting other custom APIs, like init_state or add_new_points, which allow user interaction?

@bhack
Copy link
Contributor

bhack commented Aug 8, 2024

And also reset_state when you need to start a new sequence. I am wondering how you are going to orchestrate all these APIs with fine-tuning.

@qubvel
Copy link
Contributor

qubvel commented Aug 8, 2024

However, I was mainly wondering if the transformers library guidelines allow supporting other custom APIs like init_state or add_new_points which allow user interaction?

I haven't yet seen something similar in the codebase, but it's better to ask @ArthurZucker or @amyeroberts.

Here is a rough design I have in mind: the model is "stateless" (without memory) and the state or memory is passed at each step. It also allows user interaction, such as adding points at any frame. Moreover, it allows using the same model with different state objects at the same time, without resetting it or creating several instances of the model. Let me know what you think about it.

import torch
from typing import Optional


class Sam2VideoState:
    """Store frame-wise information for video segmentation:
    points, labels, and past model hidden states used in the model.
    """
    ...

class Sam2ForVideoSegmentation(Sam2PretrainedModel):

    ...

    def forward(
        self,
        pixel_values: torch.Tensor,
        points: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        state: Optional[Sam2VideoState] = None,
    ) -> Sam2ForVideoSegmentationOutput:
        
        if state is None:
            state = Sam2VideoState()

        if points is not None:
            state.add_points(points)  # add points at the current frame
        
        if labels is not None:
            state.add_labels(labels)  # add labels at the current frame

        # Forward pass with `state`. The state is used to access model's hidden states from previous frames.
        # State is not modified inside the model, hidden states are returned in `output` and updated
        # at the next step.
        output = self.video_model(pixel_values, state)

        # Add the current model's hidden states to the `state`
        state.add_hidden_states(output.hidden_states)
        state.step()

        return Sam2ForVideoSegmentationOutput(
            mask_logits=output.mask_logits,
            object_logits=output.object_logits,
            state=state,
            # ... other outputs
        )
    
model = Sam2ForVideoSegmentation.from_pretrained("model_name")
image_processor = Sam2ImageProcessor.from_pretrained("model_name")

state = None
for i, frame in enumerate(video):

    # points and labels can be passed for specific frames
    inputs = image_processor(images=frame, points=points, labels=labels)

    outputs = model(**inputs, state=state)
    state = outputs.state
    
    frame_annotations = image_processor.post_process_image_segmentation(
        **outputs, target_size=frame.size
    )
    # save/plot annotations

@guillochon
Copy link

My understanding is that SAM2 requires torch >=2.3.1. Will that be the case for this transformers implementation as well?

@amyeroberts
Copy link
Contributor

I think @qubvel's suggestion is a good one and aligns well with transformers patterns. For other models which handle state, e.g. RWKV, we pass the state as part of the model outputs. Similarly, for things like the cache when generating, it is returned from and passed back to the model at each generation step.
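For illustration, here is a minimal sketch of that state-passing pattern using the text-generation cache (GPT-2 is used purely as a stand-in; none of this is SAM 2 code). A per-frame Sam2 state object could be threaded through the forward pass in the same way:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

input_ids = tokenizer("Segment anything", return_tensors="pt").input_ids
past_key_values = None  # the "state", initially empty

for _ in range(5):
    with torch.no_grad():
        outputs = model(input_ids=input_ids, past_key_values=past_key_values, use_cache=True)
    # The state is returned with the outputs and fed back in at the next step,
    # which is how a Sam2 video state could be passed from frame to frame.
    past_key_values = outputs.past_key_values
    input_ids = outputs.logits[:, -1:].argmax(dim=-1)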

@bhack
Copy link
Contributor

bhack commented Aug 9, 2024

Have you evaluated the overhead of exchanging the cache/state on every frame in this case?

@haithamkhedr
Copy link
Contributor Author

haithamkhedr commented Aug 9, 2024

Here is a rough design I have in mind: the model is "stateless" (without memory) and the state or memory is passed at each step. It also allows user interaction, such as adding points at any frame.

Thanks for drafting this design. The main concern I have is that it requires the user to use two models to run prediction on videos and to do some bookkeeping on video frames, whereas they really only need one model. Will this require splitting the checkpoints across these two models (Sam2ImageProcessor and Sam2ForVideoSegmentation)?

cc @qubvel @amyeroberts

@qubvel
Copy link
Contributor

qubvel commented Aug 9, 2024

@haithamkhedr ImageProcessor is not a model; it's an object responsible for image/frame preprocessing (resizing, normalizing, ...) and postprocessing (e.g. threshold filtering, applying the final activation to logits, ...). This pattern is used across all our vision models.
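For reference, here is a minimal sketch of that preprocess -> model -> postprocess split with the existing SAM (v1) classes; the checkpoint name and dummy inputs are purely illustrative, and SAM 2 would follow the same pattern:

import numpy as np
import torch
from PIL import Image
from transformers import SamModel, SamProcessor

processor = SamProcessor.from_pretrained("facebook/sam-vit-base")
model = SamModel.from_pretrained("facebook/sam-vit-base")

image = Image.fromarray(np.zeros((1080, 920, 3), dtype=np.uint8))  # dummy image
input_points = [[[100, 100]]]  # one prompt point

# Preprocessing (resizing, normalizing) lives on the processor
inputs = processor(image, input_points=input_points, return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs)

# Postprocessing (resizing masks back to the original image size) also lives
# on the processor, not on the model
masks = processor.image_processor.post_process_masks(
    outputs.pred_masks.cpu(),
    inputs["original_sizes"].cpu(),
    inputs["reshaped_input_sizes"].cpu(),
)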
Let me know if you have more questions 🤗

@devinli123
Copy link

Here is a rough design I have in mind: the model is "stateless" (without memory) and the state or memory is passed at each step. It also allows user interaction, such as adding points at any frame.

Thanks for drafting this design. The main concern I have is that it requires the user to use two models to run prediction on videos and to do some bookkeeping on video frames, whereas they really only need one model. Will this require splitting the checkpoints across these two models (Sam2ImageProcessor and Sam2ForVideoSegmentation)?

cc @qubvel @amyeroberts

@haithamkhedr Thanks for your work! Will this be realized? I'm really looking forward to using it in the live-streaming mode @qubvel suggested.

@dvolgyes
Copy link

dvolgyes commented Sep 5, 2024

@haithamkhedr Will you continue working on this? It would be great to get this contribution merged into HF.

@HuFY-dev
Copy link

Hi, are there any updates on integrating SAM 2 into HF Transformers? We are trying to fine-tune it with the HF trainer, and it would be great if SAM 2 were in native HF format!

@amyeroberts
Copy link
Contributor

@haithamkhedr Thanks for all the work in this PR adding this model to the library!

This is a model that the community is really excited about having in transformers, so there's also a lot of interest in adding it themselves. As there hasn't been any recent activity here, two contributors, @RUFFY-369 and @SangbumChoi, have (re)started an effort on #32317, which will likely be the PR that gets merged. If you're still interested in collaborating, I would suggest helping with that effort; you can be added as a co-author on that PR.

@haithamkhedr
Copy link
Contributor Author

This is currently not in active development. Closing for now

