diff --git a/ch05/07_gpt_to_llama/tests/tests.py b/ch05/07_gpt_to_llama/tests/tests.py index c5c5e903..22e00e91 100644 --- a/ch05/07_gpt_to_llama/tests/tests.py +++ b/ch05/07_gpt_to_llama/tests/tests.py @@ -20,8 +20,10 @@ transformers_version = transformers.__version__ -# LitGPT code from https://github.com/Lightning-AI/litgpt/blob/main/litgpt/model.py +# LitGPT code function `litgpt_build_rope_cache` from https://github.com/Lightning-AI/litgpt/blob/main/litgpt/model.py # LitGPT is licensed under Apache v2: https://github.com/Lightning-AI/litgpt/blob/main/LICENSE + + def litgpt_build_rope_cache( seq_len: int, n_elem: int,