llama : fix llama3.1 rope_freqs not respecting custom head_dim (#9141)

* fix: llama3.1 rope_freqs not respecting custom head_dim

* fix: use potential head_dim for Exaone
This commit is contained in:
Carsten Kragelund Jørgensen 2024-08-27 08:53:40 +02:00 committed by GitHub
parent ad76569f8e
commit 75e1dbbaab
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 5 additions and 4 deletions

View file

@@ -1572,7 +1572,7 @@ class LlamaModel(Model):
if rope_scaling := self.find_hparam(["rope_scaling"], optional=True):
if rope_scaling.get("rope_type", '').lower() == "llama3":
base = self.hparams.get("rope_theta", 10000.0)
-dim = self.hparams["hidden_size"] // self.hparams["num_attention_heads"]
+dim = self.hparams.get("head_dim", self.hparams["hidden_size"] // self.hparams["num_attention_heads"])
freqs = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
factor = rope_scaling.get("factor", 8.0)
@@ -3820,7 +3820,7 @@ class ExaoneModel(Model):
if rope_scaling := self.find_hparam(["rope_scaling"], optional=True):
if rope_scaling.get("rope_type", '').lower() == "llama3":
base = self.hparams.get("rope_theta", 10000.0)
-dim = self.hparams["hidden_size"] // self.hparams["num_attention_heads"]
+dim = self.hparams.get("head_dim", self.hparams["hidden_size"] // self.hparams["num_attention_heads"])
freqs = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
factor = rope_scaling.get("factor", 8.0)