KEMBAR78
Llava 1.6 support by cmp-nct · Pull Request #5267 · ggml-org/llama.cpp · GitHub
Skip to content

Conversation

@cmp-nct
Copy link
Contributor

@cmp-nct cmp-nct commented Feb 1, 2024

First steps - I got impressive results with llava-1.6-13B on the license_demo example already, despite many open issues.

Todo:

The biggest and most important difference missing is the "spatial_unpad" logic.
The conversion script I added can convert the nested array into a flat 2D array with valid image shapes, but it ignores them at this point.
The new tensor for the image separation is part of the projector - for compatibility with pytorch I removed it from the llava.clip extract.

llava-surgery-v2.py should be compatible with cogvlm, llava-1.6 and llava-1.5

For Mistral and using llava-cli binary:
Add this: -p "<image>\nUSER:\nProvide a full description.\nASSISTANT:\n"
The mistral template for llava-1.6 seems to be no system print and a USER/ASSISTANT role

For Vicunas the default settings work.

For the 34B this should work:
Add this: -e -p <|im_start|>system\nAnswer the questions.<|im_end|><|im_start|>user\n<image>\nWhat can be said about this image?<|im_end|><|im_start|>assistant\n

Do not expect great results before the proper image preprocessing was added

Downloads:
I've extracted the embedded vit and quantized it for all 4 variants (though, not all quantizations)
They are being uploaded here: https://huggingface.co/cmp-nct/llava-1.6-gguf/upload/main

Please note: Until preprocessing is done, expect poor results

@cmp-nct cmp-nct mentioned this pull request Feb 2, 2024
5 tasks
@cjpais
Copy link
Contributor

cjpais commented Feb 2, 2024

one note is that for liuhaotian/llava-v1.6-mistral-7b the model.mm_projector is in model-00003-of-00004

this impacts the current code as it only looks for mmproj in the last of the .safetensors files

will now search for projector
@cmp-nct
Copy link
Contributor Author

cmp-nct commented Feb 2, 2024

one note is that for liuhaotian/llava-v1.6-mistral-7b the model.mm_projector is in model-00003-of-00004

this impacts the current code as it only looks for mmproj in the last of the .safetensors files

I've updated it with a quick solution to search for those two checkpoint paths.
As long as projectors are not split through multiple files it should work now

Cleaning Q:\models\llava\llava-v1.6-vicuna-13b\model-00005-of-00006.safetensors
Searching for vision tower tensors in Q:\models\llava\llava-v1.6-vicuna-13b\model-00005-of-00006.safetensors
No vision tower found in Q:\models\llava\llava-v1.6-vicuna-13b\model-00005-of-00006.safetensors
Done! All vision tower tensors are removed from the model files and stored in llava.clip file.
Taking projector from Q:\models\llava\llava-v1.6-vicuna-13b\model-00006-of-00006.safetensors
Taking newline from Q:\models\llava\llava-v1.6-vicuna-13b\model-00001-of-00006.safetensors
Found 4 tensors to extract.
Found additional 1 tensors to extract.
Done!
Now you can convert Q:\models\llava\llava-v1.6-vicuna-13b\ to a a regular LLaMA GGUF file.
Also, use Q:\models\llava\llava-v1.6-vicuna-13b\/llava.projector to prepare a llava-encoder.gguf file.

@chigkim
Copy link

chigkim commented Feb 3, 2024

It looks like Ollama that uses Llama.cpp as their backend already supports llava:34b-v1.6.
https://ollama.ai/library/llava/tags
Not sure if I can tag people from other repos, but @jmorganca, did you guys modify your own Llama.cpp already to support v1.6 architecture?

@aisensiy
Copy link

aisensiy commented Feb 3, 2024

It looks like Ollama that uses Llama.cpp as their backend already supports llava:34b-v1.6. https://ollama.ai/library/llava/tags Not sure if I can tag people from other repos, but @jmorganca, did you guys modify your own Llama.cpp already to support v1.6 architecture?

I have the same question...

@cmp-nct
Copy link
Contributor Author

cmp-nct commented Feb 3, 2024

With these tools you can convert llava-1.6 into a llama.cpp GGUF file and it will work for inferencing.
But as long as the image preprocessing is not integrated, it will not provide the same quality in results.

Right now llama.cpp will create the usual 14 patches of a rectangular padded 336 pixel image.
But the big change in llava-1.6 was the preprocessing in how patches are split up into image regions of much higher resolutions, it does not need the padding/cropping anymore.

@aisensiy
Copy link

aisensiy commented Feb 3, 2024

With these tools you can convert llava-1.6 into a llama.cpp GGUF file and it will work for inferencing. But as long as the image preprocessing is not integrated, it will not provide the same quality in results.

Right now llama.cpp will create the usual 14 patches of a rectangular padded 336 pixel image. But the big change in llava-1.6 was the preprocessing in how patches are split up into image regions of much higher resolutions, it does not need the padding/cropping anymore.

Thanks for the reply. So right now, ollama is using the old way to use llava-1.6 and it may loss quality for large images?

@chigkim
Copy link

chigkim commented Feb 3, 2024

With these tools you can convert llava-1.6 into a llama.cpp GGUF file and it will work for inferencing. But as long as the image preprocessing is not integrated, it will not provide the same quality in results.

Right now llama.cpp will create the usual 14 patches of a rectangular padded 336 pixel image. But the big change in llava-1.6 was the preprocessing in how patches are split up into image regions of much higher resolutions, it does not need the padding/cropping anymore.

Yeah, I get that, but I was wondering if Ollama forked Llama.cpp they're using and already completed their own implementation to match Llava 1.6 architecture and image preprocessing.

@ggerganov
Copy link
Member

If it is just the pre-processing missing, we can merge this and make a separate issue with a specific goal to implement that pre-processing. Might be a good idea to first confirm that using a correctly pre-processed image (from the reference implementation) yields good results using this code

@ggerganov ggerganov added the help wanted Needs help from the community label Feb 5, 2024
@chigkim
Copy link

chigkim commented Feb 5, 2024

Basically does the resize logic from clip_image_preprocess function in clip.cpp need to be modified to match the resize logic from process_image function in conversation.py?
I need to look more carefully, but it looks like clip_image_preprocess makes the image into square with padding first and resizes to vision_model.hparams.image_size with linear interpolation.

# LLaVA/llava/conversation.py
def process_image(self, image, image_process_mode, return_pil=False, image_format='PNG', max_len=1344, min_len=672):
	if image_process_mode == "Pad":
		def expand2square(pil_img, background_color=(122, 116, 104)):
			width, height = pil_img.size
			if width == height:
				return pil_img
			elif width > height:
				result = Image.new(pil_img.mode, (width, width), background_color)
				result.paste(pil_img, (0, (width - height) // 2))
				return result
			else:
				result = Image.new(pil_img.mode, (height, height), background_color)
				result.paste(pil_img, ((height - width) // 2, 0))
				return result
		image = expand2square(image)
	elif image_process_mode in ["Default", "Crop"]:
		pass
	elif image_process_mode == "Resize":
		image = image.resize((336, 336))
	else:
		raise ValueError(f"Invalid image_process_mode: {image_process_mode}")
	if max(image.size) > max_len:
		max_hw, min_hw = max(image.size), min(image.size)
		aspect_ratio = max_hw / min_hw
		shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw))
		longest_edge = int(shortest_edge * aspect_ratio)
		W, H = image.size
		if H > W:
			H, W = longest_edge, shortest_edge
		else:
			H, W = shortest_edge, longest_edge
		image = image.resize((W, H))
	if return_pil:
		return image
	else:
		buffered = BytesIO()
		image.save(buffered, format=image_format)
		img_b64_str = base64.b64encode(buffered.getvalue()).decode()
		return img_b64_str
// llama.cpp/examples/llava/clip.cpp
bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, clip_image_f32 * res, const bool pad2square) {
	if (!ctx->has_vision_encoder) {
		printf("This gguf file seems to have no vision encoder\n");
		return false;
	}

	// the logic below is to pad the shorter side to the longer side with a background color: rgb(122, 116, 104)
	// see https://github.com/haotian-liu/LLaVA/blob/e854a2bf85118c504f6f16bf5c3c7c92f8fa8c6b/llava/conversation.py#L113-L156

	clip_image_u8 * temp = clip_image_u8_init(); // we will keep the input image data here temporarily
	if (pad2square && img->nx != img->ny) {
		int longer_side = std::max(img->nx, img->ny);
		temp->nx = longer_side;
		temp->ny = longer_side;
		temp->buf.resize(3 * longer_side * longer_side);
		const uint8_t bc[3] = {122, 116, 104}; // background color in RGB from LLaVA

		// fill with background color
		for (size_t i = 0; i < temp->buf.size(); i++) {
			temp->buf[i] = bc[i % 3];
		}

		// copy from the input image
		for (int y = 0; y < img->ny; y++) {
			for (int x = 0; x < img->nx; x++) {
				const int i = 3 * (y * img->nx + x);
				const int j = 3 * (y * temp->nx + x);
				temp->buf[j]   = img->buf[i];
				temp->buf[j+1] = img->buf[i+1];
				temp->buf[j+2] = img->buf[i+2];
			}
		}
	} else {
		temp->nx = img->nx;
		temp->ny = img->ny;
		temp->buf.resize(img->buf.size());
		memcpy(temp->buf.data(), img->buf.data(), temp->buf.size());
	}

	const int nx = temp->nx;
	const int ny = temp->ny;

	const int nx2 = ctx->vision_model.hparams.image_size;
	const int ny2 = ctx->vision_model.hparams.image_size;

	res->nx = nx2;
	res->ny = ny2;
	res->buf.resize(3 * nx2 * ny2);

	const float scale = std::max(nx, ny) / (float)ctx->vision_model.hparams.image_size;

	const int nx3 = int(nx / scale + 0.5f);
	const int ny3 = int(ny / scale + 0.5f);

	const auto & m3 = ctx->image_mean; // {0.48145466f, 0.4578275f, 0.40821073f};
	const auto & s3 = ctx->image_std;  // {0.26862954f, 0.26130258f, 0.27577711f};

	for (int y = 0; y < ny3; y++) {
		for (int x = 0; x < nx3; x++) {
			for (int c = 0; c < 3; c++) {
				// linear interpolation
				const float sx = (x + 0.5f) * scale - 0.5f;
				const float sy = (y + 0.5f) * scale - 0.5f;

				const int x0 = std::max(0, (int)std::floor(sx));
				const int y0 = std::max(0, (int)std::floor(sy));

				const int x1 = std::min(x0 + 1, nx - 1);
				const int y1 = std::min(y0 + 1, ny - 1);

				const float dx = sx - x0;
				const float dy = sy - y0;

				const int j00 = 3 * (y0 * nx + x0) + c;
				const int j01 = 3 * (y0 * nx + x1) + c;
				const int j10 = 3 * (y1 * nx + x0) + c;
				const int j11 = 3 * (y1 * nx + x1) + c;

				const float v00 = temp->buf[j00];
				const float v01 = temp->buf[j01];
				const float v10 = temp->buf[j10];
				const float v11 = temp->buf[j11];

				const float v0 = v00 * (1.0f - dx) + v01 * dx;
				const float v1 = v10 * (1.0f - dx) + v11 * dx;

				const float v = v0 * (1.0f - dy) + v1 * dy;

				const uint8_t v2 = std::min(std::max(std::round(v), 0.0f), 255.0f);

				const int i = 3 * (y * nx3 + x) + c;

				res->buf[i] = ((float(v2) / 255.0f) - m3[c]) / s3[c];
			}
		}
	}
	clip_image_u8_free(temp);

	return true;
}

@cmp-nct
Copy link
Contributor Author

cmp-nct commented Feb 5, 2024

If it is just the pre-processing missing, we can merge this and make a separate issue with a specific goal to implement that pre-processing. Might be a good idea to first confirm that using a correctly pre-processed image (from the reference implementation) yields good results using this code

PR:
I'm investing more time into it already and I expect a more complete PR soon, I also found issues in the 1.5 implementation yesterday which will get covered alongside and should increase llava-1.5.
I expect to have something within this week.

Testing:
The preprocessing involves more than image modifications, you can not preprocess an image outside and feed it in (you'd need to feed it the embeddings for such a test).
With llava-1.6 we actually create up to 5 images out of one source image, process those spatially separated in ViT and then feed those projected embeddings into the LLM.

@cmp-nct
Copy link
Contributor Author

cmp-nct commented Feb 6, 2024

I'm almost done but got stuck for today on the 5 dimensional permutations that arrange the final embeddings.

I tried to create an own slim tensor manipulation class, it's probably buggy.
ggml sadly doesn't support 5 dimensions, maybe someone knows a slim approach for operations like these ?

        Tensor tensor({combined_embeddings.size()});
        tensor.set_data(combined_embeddings);
        std::vector<int> new_shape = {grid_shape.second, grid_shape.first, num_patches_per_side, num_patches_per_side, clip_n_mmproj_embd(ctx_clip)};
        tensor.view(new_shape);
        tensor.permute({0, 2, 1, 3, 4});
        tensor.contiguous();
        tensor.flatten(0, 3);
        tensor.concatenate(base_tensor, 0);
        std::vector<size_t> new_shape_1d = { tensor.shape[0] * tensor.shape[1] * tensor.shape[2] * tensor.shape[3] * tensor.shape[4]};
        tensor.view(new_shape_1d);

@ggerganov
Copy link
Member

Since dims 3 and 4 or not permuted, you can reshape to 4d tensor, apply permute + cont and then reshape back to 5d tensor

@cmp-nct
Copy link
Contributor Author

cmp-nct commented Feb 7, 2024

I'm cleaning up the code and hope a first PR update by tomorrow.
Solved: image grid embeddings with newline embedding, bicubic resize, normalization bugfix (1.5 too)
I'm currently not permuting the grid embeddings into patches, while I was able to solve the 5D limitation for the (non unpad) variant thanks to gg's hint, I am still not getting correct results.
The permutation in python results in different code than in ggml, super frustrating that cost me a full day by now.

Even without the grid-embeddings permutation I am getting results better than any other llava variant I tested.

Below are 34B 3 bit (llm) with 6bit (ViT) quantized results, I'm not satisfied until the permutation works but it's anyway quite good. It's about GPT4V/Cog-VLM level but with the permutation bug solved it will exceed it.

image

The image shows a person standing on the back of a vehicle, which is an unusual and dangerous situation. This is not a typical way to ride in a car or truck, as it poses significant risks for injury due to lack of safety restraints and exposure to potential accidents. The individual appears to be ironing clothes while standing on the vehicle, which adds to the absurdity of the scene, as this activity would normally take place indoors and not while being transported down a street. Additionally, there are other vehicles in motion nearby, which further highlights the hazardous nature of this situation. It's important to note that such behavior is illegal and can lead to severe consequences for both the person involved and others on the road.

image

This image shows three cats sharing a meal from two bowls placed on the ground. The cat in the foreground is black and appears to be eating from one of the bowls, while another cat with a mix of colors (possibly calico) is also eating from the same bowl. A third cat, which seems to have a similar color pattern as the second but is partially obscured by the other two, is peeking out from behind them. The setting appears to be an outdoor area with a concrete surface and what looks like a wall or structure in the background. It's a candid moment that captures the cats' behavior during feeding time. The image has a shallow depth of field, which blurs the background and keeps the focus on the cats and their immediate surroundings.

image

{
  "text": [
    "CALIFORNIA USA DRIVER LICENSE",
    "DL 12345678",
    "EXP 08/31/2014",
    "END NONE",
    "CLASS C",
    "LN CARDHOLDER FN IMA",
    "2570 24TH STREET ANYTOWN, CA 95618",
    "DOB 08/31/1977",
    "VETERAN",
    "SEX F",
    "HAIR BRN",
    "EYES BRN",
    "HT 5'6",
    "WT 125 LB",
    "ISS DD 00000000/ANYNN/ANFDYY",
    "08/31/2009"
  ]
}

Below is a comparison of the same llava-1.6 using the previous inference, same settings but fp16 ViT:

{
  "text": [
    "CALIFORNIA USA DRIVER LICENSE",
    "1234567890",
    "LN: CARDHOLDER NAME",
    "EXP: 09/30/2014",
    "ANYTOWN, CA 91767",
    "ISSUE DATE: 08/30/2014",
    "END DATE: 08/.30.2017",
    "SEX: F",
    "HAIR: BRN",
    "EYES: BLK",
    "HEIGHT: 5'6",
    "WEIGHT: 140",
    "BIRTH DATE: 08/30/1976",
    "ISSUE STATE: CA",
    "ISSUE COUNTRY: USA",
    "ISSUE CITY: ANYTOWN",
    "ISSUE ZIP: 91767",
    "ISSUE STREET: 12345 MAIN ST",
    "ISSUE STATE CODE: CA",
    "ISSUE COUNTRY CODE: USA",
    "ISSUE CITY CODE: ANYTOWN",
    "ISSUE ZIP CODE: 91767",
    "ISSUE STREET CODE: 12345 M

The image shows three cats eating from a bowl. The cats appear to be domestic, and they are focused on the food in front of them. They seem to be comfortable with each other's presence, suggesting that they might be familiar or used to sharing resources like this. The setting looks like an outdoor area, possibly a courtyard or a street, given the concrete surface and the absence of any indoor furniture or decorations. The cats are likely stray or feral, as they are eating from a common bowl rather than individual dishes, which is more typical for domestic pets in a home environment. The image captures a moment of their daily life, where they have found food to sustain themselves.

…lues)

Clip: bicubic resize function
Clip: added save-to-bmp/pil for debugging and conversion from/to 32/8 images
Clip: added normalization with FP16 precision simulation (image tensors match HF implementation, can be switched off, only used for llava-1.6)
Clip: added newline tensor, mergetype kv, image-grid kv, new resize-pad function with resolution from gridpoints
Clip: clip_image_preprocess now returns a float * vector instead of float, this way llava 1.5 and 1.6 is supported
llava: added ggml cpu graph for embedding patching, added spatial_unpad preliminary support, added a lot of comments that need to be cleaned when all is final
convert-image-encoder: fixed image-grid flattening
@cmp-nct
Copy link
Contributor Author

cmp-nct commented Feb 8, 2024

I just pushed, I was not able to finalize it completely and will be mostly busy for the weekend.
This PR now includes:

  • Clip: Bugfix for normalization (it previously did not load the 3 std and mean values, so all llava-1.5 image tensors were with incorrect normalization)
  • Clip: bicubic resize function
  • Clip: added save-to-bmp/pil for debugging and conversion from/to 32/8 images
  • Clip: added normalization with FP16 precision simulation (image tensors match HF implementation, can be switched off, only used for llava-1.6)
  • Clip: added newline tensor, mergetype kv, image-grid kv, new resize-pad function with resolution from gridpoints
  • Clip: clip_image_preprocess now returns a float * vector instead of float, this way llava 1.5 and 1.6 is supported
  • llava: added ggml cpu graph for embedding patching, added spatial_unpad preliminary support, added a lot of comments that need to be cleaned when all is final
  • convert-image-encoder: fixed image-grid flattening

I hope everything is compiling, it's 9AM so I could not test the PR.
It's a large jump in generation quality (as seen in the image demos above) but there still is one problem, I think it's with the tensor strides and would super appreciate if @ggerganov or @slaren maybe could take a look as you've much more tensor knowhow than me: as soon as the permutation of the grid-image embeddings (see at bottom) is applied it fails to inference properly. Without the permutation we get the above (quite good) results.

I have simplified the original python implementation due to the lack of 5 dimensional tensors in GGML, I first tested it on the python side and it did not result in a noticeable output quality drop.

I am uploading the quantized projectors again on HF, they need an update due to the grid array bugfix. the gguf of the LLMs are still compatible.
https://huggingface.co/cmp-nct/llava-1.6-gguf/tree/main
To create a clip+projector yourself, you need to add this into the vision config part:

  "image_aspect_ratio": "anyres",
  "image_crop_resolution": 224,
  "image_grid_pinpoints": [
    [
      336,
      672
    ],
    [
      672,
      336
    ],
    [
      672,
      672
    ],
    [
      1008,
      336
    ],
    [
      336,
      1008
    ]
  ],
  "mm_patch_merge_type": "spatial_unpad",
  "mm_projector_lr": null,
  "mm_projector_type": "mlp2x_gelu",

Running:
The llava-1.6 34B model creates very good results, in my tests GPT-4V level (with this implementation despite the hotfix)
You need a lot of context for the additional embeddings, except for that it's normal llava:
llava-cli.exe -m Q:\models\llava\llava-v1.6-vicuna-7b\ggml-model-f16-q_5_k.gguf --mmproj Q:\models\llava\llava-v1.6-vicuna-7b\mmproj-model-f16.gguf --image C:\temp\LICENSE_DEMO.jpg -p "Provide a full description in JSON format" -ngl 50 --temp 0 -n 500 -ngl 50 -c 4000
Mistral:
-e -p "<image>\nUSER:\nProvide a full description in JSON format\nASSISTANT:\n"
Yi-34B:
-e -p <|im_start|>system\nAnswer the questions.<|im_end|><|im_start|>user\n<image>\nWhat can be said about this image?<|im_end|><|im_start|>assistant\n

General problems:

  1. The llava1.6 model embedded ViT is different than the vanilla ViT they claim to use, I get different embeddings with both but similar output of the llm.
  2. I was not able to replicate the HF-CLIP embeddings (using a fully black image), this is a general issue with the ggml clip implementation, input tensors are exactly the same, output differs (might be not a problem though, maybe due to kernel differences ).
  3. Current ggml CLIP can not be run in batched mode when using a mmprojector, I think we could just extract the llm-projector into a secondary sequential graph for llava-1.6. Then batching could be used for the image-grid. - A 5-image grid (max) inferences in 150ms on my GPU, so it's quite fast without batching.

Implementation problems:

  1. The final embedding permutation fails, I have added the below line of code into the llava.cpp handle_patches() graph which just reverses the permutation again. So the grid-embeddings are added without the 24x24 patching. I think it's a strides error.
  2. GGML 5-dim tensor support would make it possible to run the unpad() routine which makes it more efficient (it removes some embedding tokens)

Here is the current Hotfix in llava.cpp to reverse the permutation again:

permuted_cont = ggml_cont(model.ctx, ggml_permute(model.ctx, permuted_cont, 0, 2, 1, 3)); // permute back to before - todo: fix bug

Details on the modification of the original implementation:
For tests I filled in some of the variables, so the below code is correct for any image with an about quadratic aspect ratio, llava-1.6 will split those into a 2x2 grid, so we've 4 grid embedding tensors + 1 base embedding tensor .
The 576 features from llava-1.5 are the 24x24 features in 1.6. 4096 is the number of embeddings per feature

    // Python reference for full unpad is basically:
        // base_image_feature = image_feature[0]
        // image_feature = image_feature[1:]
        // image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous()
        // image_feature = image_feature.flatten(1, 2).flatten(2, 3)
        // image_feature = unpad_image(image_feature, image_sizes[image_idx])
        // image_feature = torch.cat((
        //     image_feature,
        //     self.model.image_newline[:, None, None].expand(*image_feature.shape[:-1], 1)
        // ), dim=-1)
        // image_feature = image_feature.flatten(1, 2).transpose(0, 1)
        // image_feature = torch.cat((base_image_feature, image_feature), dim=0)
               
        Pytorch reference simplified, modified for ggml compatibility - confirmed identical output in python (for a 2x2 grid image (676x676 scaling))
        # original without unpadding:
        # image_feature = image_feature.view(2, 2, 24, 24, 4096)
        # image_feature = image_feature.permute(0, 2, 1, 3, 4).contiguous()
        # image_feature = image_feature.view(2, 24, 2, 24, 4096)
        # image_feature = image_feature.flatten(0, 3)

        # Final version in 4D without unpadding:
        Reshape to 4D tensor by merging the last two dimensions
        image_feature = image_feature.view(2, 2, 24, 24*4096)
        image_feature = image_feature.permute(0, 2, 1, 3).contiguous()
        image_feature = image_feature.view(-1, 4096)

@haotian-liu
Copy link

WOW thank you so much for implementing LLaVA-1.6 in llama.cpp!!!

One quick note about the prompt for our 34b model: Ideally the correct format should be

<|im_start|>system\nAnswer the questions.<|im_end|><|im_start|>user\n<image>\nWhat can be said about this image?<|im_end|><|im_start|>assistant\n

And please let me know if there is anything I could do to help (but not about tensor in cpp😭).

@ggerganov
Copy link
Member

Yup, the state of the clip/llava implementation is not great - hopefully nobody uses this in production

I've added --skip-unknown to convert.py

@cmp-nct
Copy link
Contributor Author

cmp-nct commented Feb 13, 2024

Yup, the state of the clip/llava implementation is not great - hopefully nobody uses this in production

I've added --skip-unknown to convert.py

Sounds great!
I am using llava using ggml on a production-poc project, it works very well and stable.
Though I wrote my own tool that binds with llava.cpp and clip for that.
It's just the conversion process that's a struggle, also because every single llava release does it different
Most people will just download the GGUF and be happy with it

Copy link
Member

@ggerganov ggerganov left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should be ok to merge - let's give it some time until tomorrow and if no issues are reported we can merge it

@cmp-nct
Copy link
Contributor Author

cmp-nct commented Feb 13, 2024

Thanks, looking forward to see a closure on that one. It was a big pain to get working :-)

@ggerganov
Copy link
Member

It would be very useful to add detailed instructions for LLaVA v1.6 in the README. I tried writing them, but I realized I don't know which is the correct CLIP model to use - I think I'm using the small one with 576 tokens embedding. Not sure

Anyway, if anyone figures out all the steps, please add open a PR

@ggerganov ggerganov merged commit aa23412 into ggml-org:master Feb 14, 2024
@cmp-nct
Copy link
Contributor Author

cmp-nct commented Feb 14, 2024

I've updated the readme with detailed instructions and hints on llava-1.6

In general: I have used llava-surgery-v2 only a little bit outside llava-1.6, it is meant as a full replacement to llava-surgery.py.
It should now work with all models, safetensor as well as pytorch and with projectors and visual tensors in any locations but I've not tested it in great detail on non llava-1.6 yet.

Maybe after a while the original llava-surgery can be removed, once we know the new one works for everything

@Elbios
Copy link
Contributor

Elbios commented Feb 14, 2024

Hi all, I tried running this on latest master:

./bin/server -m ../models/mistral-7b-q_5_k.gguf --mmproj ../models/mmproj-mistral7b-f16-q6_k.gguf -ngl 50 -c 6000 --host 0.0.0.0 --port 8007 --no-mmap

and got:

llama_new_context_with_model: graph splits (measure): 3
Available slots:
-> Slot 0 - max context: 6000
{"timestamp":1707926446,"level":"INFO","function":"main","line":2623,"message":"model loaded"}
all slots are idle and system prompt is empty, clear the KV cache
slot 0 - loaded image
slot 0 is processing [task id: 0]
slot 0 : kv cache rm - [0, end)
slot 0 - encoding image [id: 1]
munmap_chunk(): invalid pointer
Aborted

and traced it down to a memory management error. Fixed it here:

#5491

Great work on this btw, llava 1.6 is fantastic!

@jpohhhh
Copy link

jpohhhh commented Feb 14, 2024

Similarly, possible fix for bad access crash on wide images: #5493

And got the same impression in testing: definitely a huge step up :)

@chigkim
Copy link

chigkim commented Feb 15, 2024

@ggerganov @cmp-nct This is amazing! Thanks!
Does this work on main cli as well as server, or someone has to modify them?
Main cli has --mmproj and --image options, and server has api to send image as well as web UI to upload an image in the chat mode.
What's the difference between using it with main and llava-cli other than llava-cli gives much fewer options?

@cmp-nct
Copy link
Contributor Author

cmp-nct commented Feb 15, 2024

@ggerganov @cmp-nct This is amazing! Thanks! Does this work on main cli as well as server, or someone has to modify them? Main cli has --mmproj and --image options, and server has api to send image as well as web UI to upload an image in the chat mode. What's the difference between using it with main and llava-cli other than llava-cli gives much fewer options?
Update:
It should work on llava-cli.
I personally only tested it on llava-cli in it's latest iteration
CLIP does not come with many options, so both are equal

Make sure to update to the latest commits, several bugs have been corrected

For server I just found this bug report:
#5514

For some reason server is processing the image instead of using processing functions already available in llava.cpp, maybe that's historic weight.
server.cpp needs an update as described in the issue (I am stuck with another project atm)

@chigkim
Copy link

chigkim commented Feb 15, 2024

Just curious...
Why does llava-cli exist separately when main cli has --mmproj and --image flags?
Wouldn't it be better to fold llava-cli into main cli?

@cmp-nct
Copy link
Contributor Author

cmp-nct commented Feb 15, 2024

Just curious... Why does llava-cli exist separately when main cli has --mmproj and --image flags? Wouldn't it be better to fold llava-cli into main cli?

I've not actually tested but last time I looked main did not support that.
The command line arguments are in "common" and they are not modularized, so every command line argument is shared with every example. You also have like 40 command line options when using llava-cli -h, barely any of those will have an effect.

@chigkim
Copy link

chigkim commented Feb 16, 2024

Thanks so much for the explanation . You're totally right. I just tried main, and it doesn't work.

I got llava-cli with llava 1.6 13B to work on free Colab t4 , and hopefully someone can fix server to work properly with llava 1.6. That would be amazing!

https://colab.research.google.com/gist/chigkim/c44dcc37af26f1cb3af03a2209d7c50a/llava16.ipynb

More importantly, I believe Ollama also uses llama.cpp server.

@jxy
Copy link
Contributor

jxy commented Feb 16, 2024

I had two issues with the convert-image-encoder-to-gguf.py.

  1. The correct argument seems to be --clip-model-is-vision instead of the one with underscore _ in README.
  2. I used https://huggingface.co/cmp-nct/llava-1.6-gguf/resolve/main/config.json but I needed to change the python file
diff --git a/examples/llava/convert-image-encoder-to-gguf.py b/examples/llava/convert-image-encoder-to-gguf.py
index c69f89ac..94754a47 100644
--- a/examples/llava/convert-image-encoder-to-gguf.py
+++ b/examples/llava/convert-image-encoder-to-gguf.py
@@ -117,7 +117,7 @@ else:
 with open(dir_model + "/config.json", "r", encoding="utf-8") as f:
     config = json.load(f)
     if args.clip_model_is_vision:
-        v_hparams = config
+        v_hparams = config["vision_config"]
         t_hparams = None
     else:
         v_hparams = config["vision_config"]

@cmp-nct
Copy link
Contributor Author

cmp-nct commented Feb 16, 2024

I had two issues with the convert-image-encoder-to-gguf.py.

  1. The correct argument seems to be --clip-model-is-vision instead of the one with underscore _ in README.
  2. I used https://huggingface.co/cmp-nct/llava-1.6-gguf/resolve/main/config.json but I needed to change the python file
diff --git a/examples/llava/convert-image-encoder-to-gguf.py b/examples/llava/convert-image-encoder-to-gguf.py
index c69f89ac..94754a47 100644
--- a/examples/llava/convert-image-encoder-to-gguf.py
+++ b/examples/llava/convert-image-encoder-to-gguf.py
@@ -117,7 +117,7 @@ else:
 with open(dir_model + "/config.json", "r", encoding="utf-8") as f:
     config = json.load(f)
     if args.clip_model_is_vision:
-        v_hparams = config
+        v_hparams = config["vision_config"]
         t_hparams = None
     else:
         v_hparams = config["vision_config"]

The readme was updated but I think just merged today.
The config I had in the HF was for a full model but not for the extracted ViT. The current readme points to the right file, so your modification won't be necessary.

The entire llava-1.6 change including the last minute refactors were a bit much for me to push at once, I've had dozens of variants local, that's how that issue sneaked in.

At this point llava-cli appears to work flawless, server does not. Server needs an update to use llava.cpp preprocessing functions, it also needs an update to allow flexible system prompt and finetune syntax. It's minor high level work anyone should be able to do.
If no one does it I'll try find time for it next week.

@cjpais
Copy link
Contributor

cjpais commented Feb 17, 2024

At this point llava-cli appears to work flawless, server does not. Server needs an update to use llava.cpp preprocessing functions, it also needs an update to allow flexible system prompt and finetune syntax. It's minor high level work anyone should be able to do. If no one does it I'll try find time for it next week.

Preprocessing step should be implemented in #5553, thanks for the insight into the problem on the other threads.

it also needs an update to allow flexible system prompt and finetune syntax

Could you explain more on this, happy to try to include it as well

@arkohut

This comment was marked as duplicate.

@svenstaro
Copy link

svenstaro commented Mar 4, 2024

@cjpais Would it be reasonable to update the description on your HF repo? I think these are probably the most proper quants for LLaVA v1.6 around and fairly complete. I think the disclaimer should be toned down or removed to reflect that.

@cjpais
Copy link
Contributor

cjpais commented Mar 5, 2024

@svenstaro, updated

jordankanter pushed a commit to jordankanter/llama.cpp that referenced this pull request Mar 13, 2024
* Create llava-survery-v2.py

* Update convert-image-encoder-to-gguf.py

* Update convert-image-encoder-to-gguf.py

* Rename llava-survery-v2.py to llava-surgery-v2.py

* Update convert-image-encoder-to-gguf.py

will now search for projector

* Update convert-image-encoder-to-gguf.py

whoops

* Update llava-surgery-v2.py

* Clip: Bugfix for normalization (it did not loat the 3 std and mean values)
Clip: bicubic resize function
Clip: added save-to-bmp/pil for debugging and conversion from/to 32/8 images
Clip: added normalization with FP16 precision simulation (image tensors match HF implementation, can be switched off, only used for llava-1.6)
Clip: added newline tensor, mergetype kv, image-grid kv, new resize-pad function with resolution from gridpoints
Clip: clip_image_preprocess now returns a float * vector instead of float, this way llava 1.5 and 1.6 is supported
llava: added ggml cpu graph for embedding patching, added spatial_unpad preliminary support, added a lot of comments that need to be cleaned when all is final
convert-image-encoder: fixed image-grid flattening

* whitespace corrections

* ws

* Tensors are now properly permuted.
Before the embeddings were inserted 1:1, now they are split into the 24x24 patches as in reference.

* ws

* added verbose_prompt support into cli
added stopwords for llava-1.6 into cli

* moved llava functions to llava.cpp, made clip.h C compatible API, replaced vector style functions with pointers, added a debug define to remove functions from compilation while not needed

* ws

* convert : skip unknown tensors (need for LLaVA)

* llava : update readme

* llava : fix compile warnings

* llava : style

* convert : add --skip-unknown CLI arg

* server : remove clip structs

* bugfix for non llava-1.6

It should now work with llava-1.5 as well

* clip : minor code rearrange

* llava : update readme a bit

---------

Co-authored-by: John <cmt-nct@users.noreply.github.com>
Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
hodlen pushed a commit to hodlen/llama.cpp that referenced this pull request Apr 1, 2024
* Create llava-survery-v2.py

* Update convert-image-encoder-to-gguf.py

* Update convert-image-encoder-to-gguf.py

* Rename llava-survery-v2.py to llava-surgery-v2.py

* Update convert-image-encoder-to-gguf.py

will now search for projector

* Update convert-image-encoder-to-gguf.py

whoops

* Update llava-surgery-v2.py

* Clip: Bugfix for normalization (it did not loat the 3 std and mean values)
Clip: bicubic resize function
Clip: added save-to-bmp/pil for debugging and conversion from/to 32/8 images
Clip: added normalization with FP16 precision simulation (image tensors match HF implementation, can be switched off, only used for llava-1.6)
Clip: added newline tensor, mergetype kv, image-grid kv, new resize-pad function with resolution from gridpoints
Clip: clip_image_preprocess now returns a float * vector instead of float, this way llava 1.5 and 1.6 is supported
llava: added ggml cpu graph for embedding patching, added spatial_unpad preliminary support, added a lot of comments that need to be cleaned when all is final
convert-image-encoder: fixed image-grid flattening

* whitespace corrections

* ws

* Tensors are now properly permuted.
Before the embeddings were inserted 1:1, now they are split into the 24x24 patches as in reference.

* ws

* added verbose_prompt support into cli
added stopwords for llava-1.6 into cli

* moved llava functions to llava.cpp, made clip.h C compatible API, replaced vector style functions with pointers, added a debug define to remove functions from compilation while not needed

* ws

* convert : skip unknown tensors (need for LLaVA)

* llava : update readme

* llava : fix compile warnings

* llava : style

* convert : add --skip-unknown CLI arg

* server : remove clip structs

* bugfix for non llava-1.6

It should now work with llava-1.5 as well

* clip : minor code rearrange

* llava : update readme a bit

---------

Co-authored-by: John <cmt-nct@users.noreply.github.com>
Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

help wanted Needs help from the community

Projects

None yet

Development

Successfully merging this pull request may close these issues.