feat: Support bloom models by xingchensong · Pull Request #3553 · ggml-org/llama.cpp · GitHub

@xingchensong xingchensong commented Oct 9, 2023

This is a follow-up PR; please see ggml-org/ggml#543.

Test Script

./build/bin/main -m models/bloom-1b7.fp16.gguf \
  -p "Building a website can be done in 10 simple steps:\nStep 1:" \
  -n 100 -e --temp 1.0 --top-k 1 --top-p 1.0 \
  --repeat-last-n 0 -s 2023
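
The flags in the script above are chosen to make the run reproducible: with --top-k 1 only the single most likely token survives filtering, so the temperature and the seed should not change which token is picked. A minimal Python sketch of that idea follows. It is not llama.cpp's sampler; `sample_top_k` and the logits are made up for illustration.

```python
import math
import random


def sample_top_k(logits, k, temperature, rng):
    """Sample a token index from the k highest logits after temperature scaling."""
    # Keep the indices of the k largest logits.
    top = sorted(range(len(logits)), key=lambda i: logits[i], reverse=True)[:k]
    scaled = [logits[i] / temperature for i in top]
    peak = max(scaled)
    # Unnormalised softmax weights; random.choices normalises them.
    weights = [math.exp(s - peak) for s in scaled]
    return rng.choices(top, weights=weights, k=1)[0]


logits = [0.1, 2.5, -1.0, 2.4]  # made-up values, for illustration only

# With k=1 every seed picks the same (argmax) token.
picks = {sample_top_k(logits, 1, 1.0, random.Random(seed)) for seed in range(20)}
print(picks)
```

With k larger than 1 the same loop would typically return more than one distinct index, which is why greedy settings are convenient when comparing two implementations token by token.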

Tested Models

TODO

  • PPL test

@xingchensong
Contributor Author

PPL Test

Test script for torch fp32

import argparse

import torch
from transformers import AutoConfig, BloomForCausalLM, BloomTokenizerFast


def calculate_ppl(
    device, model, tokenizer, sentence: str, max_length: int = 100, stride: int = 50
) -> list[float]:
    """Print and return the running perplexity after each strided window."""
    sentence_ids = tokenizer.encode(sentence)  # do not add bos_token_id
    print(sentence_ids)
    seq_len = len(sentence_ids)
    window = max_length + stride // 2  # tokens fed to the model per step

    nlls = []
    for begin_loc in range(0, seq_len, stride):
        end_loc = min(begin_loc + window, seq_len)
        if (end_loc - begin_loc) != window:
            break  # drop the trailing partial window
        input_ids = torch.tensor([sentence_ids[begin_loc:end_loc]]).to(device)
        target_ids = input_ids.clone()
        target_ids[:, :-stride] = -100  # score only the last `stride` tokens

        with torch.no_grad():
            outputs = model(input_ids, labels=target_ids)
            neg_log_likelihood = outputs.loss  # mean NLL over the scored tokens

        nlls.append(neg_log_likelihood)

        if end_loc == seq_len:
            break

    # Running perplexity: exp(summed NLL so far / number of tokens scored so far).
    ggml_nlls = torch.cumsum(torch.stack(nlls) * stride, dim=0)
    count = torch.arange(stride, len(nlls) * stride + stride, stride)
    chunk_ppls = torch.exp(ggml_nlls / count).cpu().tolist()
    for i, ppl in enumerate(chunk_ppls):
        print("[{}] {}".format(i + 1, ppl))
    return chunk_ppls


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_name_or_path", type=str, default="")

    args = parser.parse_args()

    device = torch.device("cpu")

    tokenizer = BloomTokenizerFast.from_pretrained(args.model_name_or_path)

    model_config = AutoConfig.from_pretrained(
        args.model_name_or_path, trust_remote_code=True
    )

    model = BloomForCausalLM.from_pretrained(
        args.model_name_or_path,
        torch_dtype=torch.float16,
        config=model_config,
        device_map="auto",
    )

    model.to(device).float()  # type: ignore
    model.eval()  # type: ignore

    sens = ["About six million children are reported to child protection agencies in America each year. About 400,000 of those children are placed in protective custody because of severe neglect or abuse. About 500,000 children are placed into foster care and adoptive placements. Abused and neglected children are all around us. These children are invisible in our community, yet each one of us is directly responsible for their plight. They live under our laws; they go to our schools; they are convicted by our courts; many of them spend lifetimes in our prisons. They have no say in the laws and policies that rule their lives. Just like they had no say in the neglect and abuse that was their childhood. Neglected and abused children make up a great majority of the crime, drugs, and violence we experience in our communities. Over fifty percent of the children in the juvenile justice system have diagnosable mental illness, about thirty percent of children in child protection services are proscribed psychotropic medications, & almost eighty percent of youth aging out of foster care lead dysfunctional lives. Ninety percent of the juveniles in the Juvenile Justice System have come out of the Child Protection System (Minnesota’s Chief Justice, Kathleen Blatz). Over 90 percent of the adults in the Criminal Justice System come out of the Juvenile Justice System. Justice Blatz (and others) call it a prison “feeder” system. The United States is the only nation in the world to build prisons based on failed third grade reading scores or the number of children in Child Protection. Children are not aware of the rightness or wrongness of their own abuse. They do not know that abuse is abnormal, or even that it is wrong. To a five-year-old, no matter how painful and frightening her life is, her life is normal. A sad and lasting fact of child abuse is that children blame themselves for the abuse they receive. 
How can sex, drugs, and violence be unlearned by a ten year old child whose entire life has been just that? It takes years of therapy to change a child’s perception of an abusive past. It takes a great deal longer for an abused child to develop a healthy view of the world and a positive self-image. There is no book a child can go to, or code they are born with, that explains the abnormality of what is happening to them. Children can’t call their senators, or complain to the authorities (they can’t even tell their parents). Behaviors learned by abused children to stay alive in toxic homes are terribly counter-productive once the child is out of the abusive circumstances and trying to live a normal life. The behaviors developed for staying alive and avoiding pain dominate and thus can become significant detriments to getting along in society. As a matter of fact, for many troubled youth, their explosive responses and pain avoidance behaviors define them as uneducated social misfits with criminal histories."]
    for sen in sens:
        calculate_ppl(
            device, model, tokenizer, sen, max_length=100, stride=50
        )
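
For reference, the windowing and the running average in calculate_ppl can be followed without loading a model. The sketch below is a simplified stand-in, not part of the PR: `strided_windows` and `running_ppl` are hypothetical helper names, and the per-window losses passed in are made-up numbers.

```python
import math


def strided_windows(seq_len, max_length=100, stride=50):
    """Yield (begin, end, first_scored) for each full window, as in calculate_ppl."""
    window = max_length + stride // 2
    for begin in range(0, seq_len, stride):
        end = min(begin + window, seq_len)
        if end - begin != window:
            break  # the trailing partial window is dropped
        yield begin, end, end - stride  # only tokens [end - stride, end) are scored
        if end == seq_len:
            break


def running_ppl(mean_nlls, stride=50):
    """Running perplexity: exp(total NLL so far / tokens scored so far)."""
    ppls, total = [], 0.0
    for i, nll in enumerate(mean_nlls, start=1):
        total += nll * stride  # mean loss over `stride` tokens -> summed loss
        ppls.append(math.exp(total / (i * stride)))
    return ppls


for begin, end, first_scored in strided_windows(300):
    print(f"window [{begin}, {end}) scores tokens [{first_scored}, {end})")

print(running_ppl([0.0, 2.0]))  # made-up losses, for illustration only
```

With 300 tokens and the default settings, the scored spans tile tokens 75 to 275 without overlap, and the first 75 tokens only ever serve as context.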

Test script for ggml fp16/q4_1

./build/bin/perplexity -m models/bloom-1b7.fp16.gguf \
  -p "About six million children are reported to child protection agencies in America each year. About 400,000 of those children are placed in protective custody because of severe neglect or abuse. About 500,000 children are placed into foster care and adoptive placements. Abused and neglected children are all around us. These children are invisible in our community, yet each one of us is directly responsible for their plight. They live under our laws; they go to our schools; they are convicted by our courts; many of them spend lifetimes in our prisons. They have no say in the laws and policies that rule their lives. Just like they had no say in the neglect and abuse that was their childhood. Neglected and abused children make up a great majority of the crime, drugs, and violence we experience in our communities. Over fifty percent of the children in the juvenile justice system have diagnosable mental illness, about thirty percent of children in child protection services are proscribed psychotropic medications, & almost eighty percent of youth aging out of foster care lead dysfunctional lives. Ninety percent of the juveniles in the Juvenile Justice System have come out of the Child Protection System (Minnesota’s Chief Justice, Kathleen Blatz). Over 90 percent of the adults in the Criminal Justice System come out of the Juvenile Justice System. Justice Blatz (and others) call it a prison “feeder” system. The United States is the only nation in the world to build prisons based on failed third grade reading scores or the number of children in Child Protection. Children are not aware of the rightness or wrongness of their own abuse. They do not know that abuse is abnormal, or even that it is wrong. To a five-year-old, no matter how painful and frightening her life is, her life is normal. A sad and lasting fact of child abuse is that children blame themselves for the abuse they receive. 
How can sex, drugs, and violence be unlearned by a ten year old child whose entire life has been just that? It takes years of therapy to change a child’s perception of an abusive past. It takes a great deal longer for an abused child to develop a healthy view of the world and a positive self-image. There is no book a child can go to, or code they are born with, that explains the abnormality of what is happening to them. Children can’t call their senators, or complain to the authorities (they can’t even tell their parents). Behaviors learned by abused children to stay alive in toxic homes are terribly counter-productive once the child is out of the abusive circumstances and trying to live a normal life. The behaviors developed for staying alive and avoiding pain dominate and thus can become significant detriments to getting along in society. As a matter of fact, for many troubled youth, their explosive responses and pain avoidance behaviors define them as uneducated social misfits with criminal histories." \
  --ppl-stride 50 -c 100 -b 512 -s 2023

Results

| Model | chunk-0 | chunk-1 | chunk-2 | chunk-3 | chunk-4 | chunk-5 | chunk-6 | chunk-7 | chunk-8 | chunk-9 |
| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |
| torch fp32 | 11.9603 | 11.6657 | 14.0454 | 14.4815 | 15.9778 | 16.6827 | 16.0810 | 15.7785 | 16.6674 | 17.1874 |
| ggml fp16 (4.2GB) | 11.9615 | 11.6673 | 14.0466 | 14.4828 | 15.9786 | 16.9979 | 16.1121 | 15.9386 | 17.3866 | 17.9251 |
| ggml q4_1 (1.5GB) | 12.4996 | 12.3940 | 14.6400 | 15.3892 | 16.9743 | 18.2565 | 17.1820 | 17.0910 | 18.4840 | 18.8846 |

| Model | chunk-0 | chunk-1 | chunk-2 | chunk-3 | chunk-4 | chunk-5 | chunk-6 | chunk-7 | chunk-8 | chunk-9 |
| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |
| torch fp32 | 24.8323 | 26.7774 | 30.0789 | 39.0886 | 41.8020 | 40.5517 | 37.7743 | 38.0244 | 38.2512 | 39.4991 |
| ggml fp16 (2.7GB) | 24.8268 | 26.7805 | 30.0809 | 39.0914 | 41.8034 | 40.5512 | 37.7735 | 38.0241 | 38.2522 | 39.5002 |
| ggml q4_1 (855MB) | 26.0862 | 28.5914 | 31.6534 | 40.4253 | 43.3038 | 42.3877 | 39.7752 | 40.3762 | 40.6543 | 41.8351 |
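
One way to read the two result tables above is to look at the largest relative gap between each ggml row and the torch fp32 row. The sketch below only does arithmetic on the numbers already listed; the names `TABLES` and `max_relative_gap` are made up for this example, and the tables are called "first" and "second" because their model names are not given here.

```python
# Perplexities copied from the two result tables (chunk-0 .. chunk-9).
TABLES = {
    "first table": {
        "torch fp32": [11.9603, 11.6657, 14.0454, 14.4815, 15.9778,
                       16.6827, 16.0810, 15.7785, 16.6674, 17.1874],
        "ggml fp16": [11.9615, 11.6673, 14.0466, 14.4828, 15.9786,
                      16.9979, 16.1121, 15.9386, 17.3866, 17.9251],
        "ggml q4_1": [12.4996, 12.3940, 14.6400, 15.3892, 16.9743,
                      18.2565, 17.1820, 17.0910, 18.4840, 18.8846],
    },
    "second table": {
        "torch fp32": [24.8323, 26.7774, 30.0789, 39.0886, 41.8020,
                       40.5517, 37.7743, 38.0244, 38.2512, 39.4991],
        "ggml fp16": [24.8268, 26.7805, 30.0809, 39.0914, 41.8034,
                      40.5512, 37.7735, 38.0241, 38.2522, 39.5002],
        "ggml q4_1": [26.0862, 28.5914, 31.6534, 40.4253, 43.3038,
                      42.3877, 39.7752, 40.3762, 40.6543, 41.8351],
    },
}


def max_relative_gap(reference, other):
    """Return (gap, chunk index) for the largest |other - reference| / reference."""
    gaps = [abs(o - r) / r for r, o in zip(reference, other)]
    worst = max(range(len(gaps)), key=gaps.__getitem__)
    return gaps[worst], worst


for name, rows in TABLES.items():
    reference = rows["torch fp32"]
    for label, values in rows.items():
        if label == "torch fp32":
            continue
        gap, chunk = max_relative_gap(reference, values)
        print(f"{name}, {label}: max gap {gap:.2%} at chunk-{chunk}")
```

The per-chunk values are running perplexities, so a gap that appears in a later chunk also reflects every window scored before it.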

@xingchensong
Contributor Author

PPL results look good to me, I think this PR is ready for a final review :)), @ggerganov

@ggerganov
Member

Nice job. This still lacks tensor offloading for GPU support, but we can fix this later.
I'll review and merge this PR after #3417 is merged.

@ggerganov added the "model" (Model specific) and "need feedback" (Testing and feedback with results are needed) labels Oct 9, 2023
@ggerganov
Member

ggerganov commented Oct 10, 2023

Tested on M2 Ultra using Metal - seems to work as expected:

./main -m ./models/bloom-1b/ggml-model-f16.gguf -p "I believe the meaning of life is" --ignore-eos -n 64 -t 4 -ngl 1 -s 1

llama_new_context_with_model: compute buffer total size = 500.13 MB
llama_new_context_with_model: max tensor size =   980.00 MB
ggml_metal_add_buffer: allocated 'data            ' buffer, size =  4279.47 MB, ( 4280.09 / 147456.00)
ggml_metal_add_buffer: allocated 'kv              ' buffer, size =    98.00 MB, ( 4378.09 / 147456.00)
ggml_metal_add_buffer: allocated 'alloc           ' buffer, size =   494.02 MB, ( 4872.11 / 147456.00)

system_info: n_threads = 4 / 24 | AVX = 0 | AVX2 = 0 | AVX512 = 0 | AVX512_VBMI = 0 | AVX512_VNNI = 0 | FMA = 0 | NEON = 1 | ARM_FMA = 1 | F16C = 0 | FP16_VA = 1 | WASM_SIMD = 0 | BLAS = 1 | SSE3 = 0 | SSSE3 = 0 | VSX = 0 | 
sampling: repeat_last_n = 64, repeat_penalty = 1.100000, presence_penalty = 0.000000, frequency_penalty = 0.000000, top_k = 40, tfs_z = 1.000000, top_p = 0.950000, typical_p = 1.000000, temp = 0.800000, mirostat = 0, mirostat_lr = 0.100000, mirostat_ent = 5.000000
generate: n_ctx = 512, n_batch = 512, n_predict = 64, n_keep = 0

I believe the meaning of life is determined not by an individual's physical, spiritual or mental well-being but rather their place in a more meaningful context.
The term holistic wellbeing was first coined to describe the concept that people should be healthy and happy as individuals without being forced into health care programs (Barnett & Jones, 2006) . To achieve this

llama_print_timings:        load time =     216.92 ms
llama_print_timings:      sample time =     330.09 ms /    64 runs   (    5.16 ms per token,   193.89 tokens per second)
llama_print_timings: prompt eval time =      20.13 ms /     7 tokens (    2.88 ms per token,   347.81 tokens per second)
llama_print_timings:        eval time =     609.28 ms /    63 runs   (    9.67 ms per token,   103.40 tokens per second)
llama_print_timings:       total time =    1002.27 ms

//////////////////

./main -m ./models/bloom-1b/ggml-model-q4_0.gguf -p "I believe the meaning of life is" --ignore-eos -n 64 -t 4 -ngl 1 -s 1

llama_new_context_with_model: compute buffer total size = 500.13 MB
llama_new_context_with_model: max tensor size =   401.95 MB
ggml_metal_add_buffer: allocated 'data            ' buffer, size =  1341.05 MB, ( 1341.67 / 147456.00)
ggml_metal_add_buffer: allocated 'kv              ' buffer, size =    98.00 MB, ( 1439.67 / 147456.00)
ggml_metal_add_buffer: allocated 'alloc           ' buffer, size =   494.02 MB, ( 1933.69 / 147456.00)

system_info: n_threads = 4 / 24 | AVX = 0 | AVX2 = 0 | AVX512 = 0 | AVX512_VBMI = 0 | AVX512_VNNI = 0 | FMA = 0 | NEON = 1 | ARM_FMA = 1 | F16C = 0 | FP16_VA = 1 | WASM_SIMD = 0 | BLAS = 1 | SSE3 = 0 | SSSE3 = 0 | VSX = 0 | 
sampling: repeat_last_n = 64, repeat_penalty = 1.100000, presence_penalty = 0.000000, frequency_penalty = 0.000000, top_k = 40, tfs_z = 1.000000, top_p = 0.950000, typical_p = 1.000000, temp = 0.800000, mirostat = 0, mirostat_lr = 0.100000, mirostat_ent = 5.000000
generate: n_ctx = 512, n_batch = 512, n_predict = 64, n_keep = 0


I believe the meaning of life is finding the right partner. And that you just don't know how to find it when you're young or mature," she said.
"You really need a mentor who will give you guidance and direction in your relationships - whether it's with friends, family, partners, children.
"My advice would be to trust yourself enough so not to let

llama_print_timings:        load time =     167.38 ms
llama_print_timings:      sample time =     319.66 ms /    64 runs   (    4.99 ms per token,   200.21 tokens per second)
llama_print_timings: prompt eval time =      21.48 ms /     7 tokens (    3.07 ms per token,   325.84 tokens per second)
llama_print_timings:        eval time =     402.63 ms /    63 runs   (    6.39 ms per token,   156.47 tokens per second)
llama_print_timings:       total time =     786.13 ms
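
The tokens-per-second figures in the two logs above follow directly from the total time and the run count on each line. A small sketch that recomputes them and compares the two eval lines; the regular expression and the helper name `tokens_per_second` are assumptions made for this example, not anything provided by llama.cpp.

```python
import re

# The two "eval time" lines, copied from the f16 and q4_0 logs.
F16_EVAL = "llama_print_timings:        eval time =     609.28 ms /    63 runs   (    9.67 ms per token,   103.40 tokens per second)"
Q4_0_EVAL = "llama_print_timings:        eval time =     402.63 ms /    63 runs   (    6.39 ms per token,   156.47 tokens per second)"

# Captures "<total> ms / <count> runs" (or "tokens" on prompt eval lines).
PATTERN = re.compile(r"=\s*([\d.]+) ms /\s*(\d+) (?:runs|tokens)")


def tokens_per_second(line):
    """Recompute throughput from the total time and the run count on one timing line."""
    match = PATTERN.search(line)
    if match is None:
        raise ValueError("not a timing line with a run count")
    total_ms, runs = float(match.group(1)), int(match.group(2))
    return runs / (total_ms / 1000.0)


f16 = tokens_per_second(F16_EVAL)
q4_0 = tokens_per_second(Q4_0_EVAL)
print(f"f16:  {f16:.2f} tokens/s")
print(f"q4_0: {q4_0:.2f} tokens/s")
print(f"q4_0 / f16 speedup: {q4_0 / f16:.2f}x")
```

Both runs generate the same number of tokens with the same flags, so the ratio of the two eval throughputs is a fair single-run comparison of generation speed on this machine.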

@ggerganov ggerganov merged commit 02d2875 into ggml-org:master Oct 10, 2023
@xingchensong xingchensong deleted the xcsong-bloom branch October 10, 2023 14:50
joelkuiper added a commit to vortext/llama.cpp that referenced this pull request Oct 12, 2023