Skip to content

Add Phi-4 Backbone #2272

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 11 commits into
base: master
Choose a base branch
from
10 changes: 10 additions & 0 deletions keras_hub/api/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -421,6 +421,16 @@
from keras_hub.src.models.phi3.phi3_tokenizer import (
Phi3Tokenizer as Phi3Tokenizer,
)
from keras_hub.src.models.phi4.phi4_backbone import Phi4Backbone as Phi4Backbone
from keras_hub.src.models.phi4.phi4_causal_lm import (
Phi4CausalLM as Phi4CausalLM,
)
from keras_hub.src.models.phi4.phi4_causal_lm_preprocessor import (
Phi4CausalLMPreprocessor as Phi4CausalLMPreprocessor,
)
from keras_hub.src.models.phi4.phi4_tokenizer import (
Phi4Tokenizer as Phi4Tokenizer,
)
from keras_hub.src.models.preprocessor import Preprocessor as Preprocessor
from keras_hub.src.models.qwen.qwen_backbone import (
QwenBackbone as Qwen2Backbone,
Expand Down
3 changes: 3 additions & 0 deletions keras_hub/api/tokenizers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,9 @@
from keras_hub.src.models.phi3.phi3_tokenizer import (
Phi3Tokenizer as Phi3Tokenizer,
)
from keras_hub.src.models.phi4.phi4_tokenizer import (
Phi4Tokenizer as Phi4Tokenizer,
)
from keras_hub.src.models.qwen.qwen_tokenizer import (
QwenTokenizer as Qwen2Tokenizer,
)
Expand Down
1 change: 1 addition & 0 deletions keras_hub/src/models/phi4/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
# TODO: Add a register_presets call once phi4_presets.py is implemented.
63 changes: 63 additions & 0 deletions keras_hub/src/models/phi4/phi4_backbone.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
from keras_hub.src.api_export import keras_hub_export
from keras_hub.src.models.phi3.phi3_backbone import Phi3Backbone


@keras_hub_export("keras_hub.models.Phi4Backbone")
class Phi4Backbone(Phi3Backbone):
"""Phi-4 core network with hyperparameters.

This network implements a Transformer-based decoder network,
Phi-4, as described in ["Phi-4 Technical Report"](https://arxiv.org/pdf/2412.08905).
It includes the embedding lookups and transformer layers.

The default constructor gives a fully customizable, randomly initialized
phi-4 model with any number of layers, heads, and embedding
dimensions. To load preset architectures and weights, use the `from_preset`
constructor.

Note that the defaults here are the Phi-3 defaults, because the Phi-4 model
follows the Phi-3-medium architecture but with different hyper-parameters.
Use `keras_hub.models.Backbone.from_preset` to get the Phi-4 defaults.

Args:
vocabulary_size: int. The size of the token vocabulary.
num_layers: int. The number of transformer layers.
hidden_dim: int. The size of the embeddings and the hidden states of
the transformer layers.
intermediate_dim: int. The output dimension of the first Dense layer in
a three-layer feedforward network for each transformer.
num_query_heads: int. The number of query attention heads for each
transformer layer.
num_key_value_heads: int. The number of key and value attention heads
for each transformer layer.
layer_norm_epsilon: float, optional. Epsilon for the RMS layernorm
layers in the transformer decoder. Defaults to `1e-6`.
dropout:: float, optional. Dropout probability for the Transformer
decoder.
max_sequence_length: int, optional. The maximum sequence length
that this model might ever be used with. Defaults to `4096`.
pretraining_sequence_length: int, optional. The maximum sequence length
that the model was pretrained with. Defaults to `4096`.
rope_max_wavelength: int, optional. The maximum angular wavelength of
the sine/cosine curves, for rotary embeddings. Defaults to `10000`.
rope_scaling_type: str, optional. The type of the rope scaling. Can be
either `None` or `"su"`. `None` is for no rope scaling, `"su"` is
for SuScaled rope, `"su"` is used when `max_sequence_length` is
larger than `original_max_sequence_length`. Defaults to `None`.
rope_scaling_short_factor: list[float]. List of factors used to adjust
rope frequencies when the `rope_scaling_type` is `"su"`. List must
be of length `hidden_dim//num_query_heads//2`. It is used when
`sequence_length` is smaller than `pretraining_sequence_length`.
Defaults to `None`.
rope_scaling_long_factor: list[float]. List of factors used to adjust
rope frequencies when the `rope_scaling_type` is `"su"`. List must
be of length `hidden_dim//num_query_heads//2`. It is used when
`sequence_length` is larger than `pretraining_sequence_length`.
Defaults to `None`.
dtype: string or `keras.mixed_precision.DTypePolicy`. The dtype to use
for model computations and weights. Note that some computations,
such as softmax and layer normalization, will always be done at
float32 precision regardless of dtype.
"""

pass
92 changes: 92 additions & 0 deletions keras_hub/src/models/phi4/phi4_backbone_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
import pytest
from keras import ops

from keras_hub.src.models.phi4.phi4_backbone import Phi4Backbone
from keras_hub.src.tests.test_case import TestCase


class Phi4Test(TestCase):
def setUp(self):
self.init_kwargs = {
"vocabulary_size": 10,
"num_layers": 2,
"num_query_heads": 4,
"num_key_value_heads": 2,
"hidden_dim": 8,
"intermediate_dim": 8,
}
self.su_rotary_init_kwargs = {
"vocabulary_size": 10,
"num_layers": 2,
"num_query_heads": 2,
"num_key_value_heads": 1,
"hidden_dim": 8,
"intermediate_dim": 12,
"max_sequence_length": 10,
"pretraining_sequence_length": 5,
"rope_scaling_type": "su",
"rope_scaling_short_factor": [1.2, 1.4],
"rope_scaling_long_factor": [0.8, 0.6],
}
self.input_data = {
"token_ids": ops.ones((2, 5), dtype="int32"),
"padding_mask": ops.ones((2, 5), dtype="int32"),
}

def test_backbone_basics(self):
self.run_backbone_test(
cls=Phi4Backbone,
init_kwargs=self.init_kwargs,
input_data=self.input_data,
expected_output_shape=(2, 5, 8),
)

@pytest.mark.large
def test_saved_model(self):
self.run_model_saving_test(
cls=Phi4Backbone,
init_kwargs=self.init_kwargs,
input_data=self.input_data,
)

def test_backbone_basics_with_su_rotary(self):
self.run_backbone_test(
cls=Phi4Backbone,
init_kwargs=self.su_rotary_init_kwargs,
input_data=self.input_data,
expected_output_shape=(2, 5, 8),
)

@pytest.mark.large
def test_saved_model_with_su_rotary(self):
self.run_model_saving_test(
cls=Phi4Backbone,
init_kwargs=self.su_rotary_init_kwargs,
input_data=self.input_data,
)

@pytest.mark.extra_large
def test_smallest_preset(self):
self.run_preset_test(
cls=Phi4Backbone,
preset="phi4_mini_4k_instruct_en",
input_data={
"token_ids": ops.array([[1, 450, 4996, 1701, 29916, 29889]]),
"padding_mask": ops.ones((1, 6), dtype="int32"),
},
expected_output_shape=(1, 6, 3072),
# The forward pass from a preset should be stable!
# Reference values computed using PyTorch HF model.
expected_partial_output=ops.array(
[-0.21222, 0.04004, -0.02759, 0.02200]
),
)

@pytest.mark.extra_large
def test_all_presets(self):
for preset in Phi4Backbone.presets:
self.run_preset_test(
cls=Phi4Backbone,
preset=preset,
input_data=self.input_data,
)
Comment on lines +85 to +92
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Usually how big these models will be and how many presets are we testing here?

33 changes: 33 additions & 0 deletions keras_hub/src/models/phi4/phi4_causal_lm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
from keras_hub.src.api_export import keras_hub_export
from keras_hub.src.models.phi3.phi3_causal_lm import Phi3CausalLM
from keras_hub.src.models.phi4.phi4_backbone import Phi4Backbone
from keras_hub.src.models.phi4.phi4_causal_lm_preprocessor import (
Phi4CausalLMPreprocessor,
)


@keras_hub_export("keras_hub.models.Phi4CausalLM")
class Phi4CausalLM(Phi3CausalLM):
"""An end-to-end Phi4 model for causal language modeling.

A causal language model (LM) predicts the next token based on previous
tokens. This task setup can be used to train the model unsupervised on
plain text input, or to autoregressively generate plain text similar to
the data used for training. This task can be used for pre-training or
fine-tuning a Phi-4 model, simply by calling `fit()`.

This model has a `generate()` method, which generates text based on a
prompt. The generation strategy used is controlled by an additional
`sampler` argument on `compile()`. You can recompile the model with
different `keras_hub.samplers` objects to control the generation. By
default, `"top_k"` sampling will be used.

Args:
backbone: A `keras_hub.models.Phi4Backbone` instance.
preprocessor: A `keras_hub.models.Phi4CausalLMPreprocessor` or `None`.
If `None`, this model will not apply preprocessing, and inputs
should be preprocessed before calling the model.
"""

backbone_cls = Phi4Backbone
preprocessor_cls = Phi4CausalLMPreprocessor
76 changes: 76 additions & 0 deletions keras_hub/src/models/phi4/phi4_causal_lm_preprocessor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
from keras_hub.src.api_export import keras_hub_export
from keras_hub.src.models.causal_lm_preprocessor import CausalLMPreprocessor
from keras_hub.src.models.phi4.phi4_backbone import Phi4Backbone
from keras_hub.src.models.phi4.phi4_tokenizer import Phi4Tokenizer


@keras_hub_export("keras_hub.models.Phi4CausalLMPreprocessor")
class Phi4CausalLMPreprocessor(CausalLMPreprocessor):
"""Phi4 Causal LM preprocessor.

This preprocessing layer is meant for use with
`keras_hub.models.Phi4CausalLM`. By default, it will take in batches of
strings, and return outputs in a `(x, y, sample_weight)` format, where the
`y` label is the next token id in the `x` sequence.

For use with generation, the layer also exposes two methods
`generate_preprocess()` and `generate_postprocess()`. When this preprocessor
is attached to a `keras_hub.models.Phi4CausalLM` instance, these methods
will be called implicitly in `generate()`. They can also be called
standalone (e.g. to precompute preprocessing inputs for generation in a
separate process).

Args:
tokenizer: A `keras_hub.models.Phi4Tokenizer` instance.
sequence_length: The length of the packed inputs.
add_start_token: If `True`, the preprocessor will prepend the tokenizer
start token to each input sequence. Default is `True`.
add_end_token: If `True`, the preprocessor will append the tokenizer
end token to each input sequence. Default is `False`.

Call arguments:
x: A string, `tf.Tensor` or list of python strings.
y: Label data. Should always be `None` as the layer generates labels.
sample_weight: Label weights. Should always be `None` as the layer
generates label weights.
sequence_length: Pass to override the configured `sequence_length` of
the layer.

Examples:
```python
# Load the preprocessor from a preset.
preprocessor = keras_hub.models.Phi4CausalLMPreprocessor.from_preset(
"phi4_mini_4k_instruct_en"
)

# Tokenize and pack a single sentence.
sentence = tf.constant("League of legends")
preprocessor(sentence)
# Same output.
preprocessor("League of legends")

# Tokenize a batch of sentences.
sentences = tf.constant(["Taco tuesday", "Fish taco please!"])
preprocessor(sentences)
# Same output.
preprocessor(["Taco tuesday", "Fish taco please!"])

# Map a dataset to preprocess a single sentence.
features = tf.constant(
[
"Avatar 2 is amazing!",
"Well, I am not sure.",
]
)
labels = tf.constant([1, 0])
ds = tf.data.Dataset.from_tensor_slices((features, labels))
ds = ds.map(preprocessor, num_parallel_calls=tf.data.AUTOTUNE)

# Map a dataset to preprocess unlabled sentences.
ds = tf.data.Dataset.from_tensor_slices(features)
ds = ds.map(preprocessor, num_parallel_calls=tf.data.AUTOTUNE)
```
"""

backbone_cls = Phi4Backbone
tokenizer_cls = Phi4Tokenizer
92 changes: 92 additions & 0 deletions keras_hub/src/models/phi4/phi4_causal_lm_preprocessor_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
import pytest

from keras_hub.src.models.phi4.phi4_causal_lm_preprocessor import (
Phi4CausalLMPreprocessor,
)
from keras_hub.src.models.phi4.phi4_tokenizer import Phi4Tokenizer
from keras_hub.src.tests.test_case import TestCase


class Phi4CausalLMPreprocessorTest(TestCase):
def setUp(self):
self.vocab = ["!", "air", "Ġair", "plane", "Ġat", "port"]
self.vocab += [
"<s>",
"</s>",
"<pad>",
"<im_start>",
"<im_sep>",
"<im_end>",
]
self.vocab += ["<fim_prefix>", "<fim_middle>", "<fim_suffix>"]
self.vocab = dict([(token, i) for i, token in enumerate(self.vocab)])
self.merges = ["Ġ a", "Ġ t", "Ġ i", "Ġ b", "a i", "p l", "n e"]
self.merges += ["Ġa t", "p o", "r t", "Ġt h", "ai r", "pl a", "po rt"]
self.merges += ["Ġai r", "Ġa i", "pla ne"]
self.tokenizer = Phi4Tokenizer(
vocabulary=self.vocab, merges=self.merges
)
self.init_kwargs = {
"tokenizer": self.tokenizer,
"sequence_length": 10,
}
# [1, 3, 4, 2, 5]
self.input_data = (["airplane at airport"],)

def test_preprocessor_basics(self):
self.run_preprocessor_test(
cls=Phi4CausalLMPreprocessor,
init_kwargs=self.init_kwargs,
input_data=self.input_data,
expected_output=(
{
"token_ids": [[6, 1, 3, 4, 2, 5, 0, 0, 0, 0]],
"padding_mask": [[1, 1, 1, 1, 1, 1, 1, 1, 1, 1]],
},
[[1, 3, 4, 2, 5, 0, 0, 0, 0, 7]],
[[1, 1, 1, 1, 1, 1, 1, 1, 1, 1]],
),
)

def test_no_start_end_token(self):
input_data = ["airplane at airport"] * 4

preprocessor = Phi4CausalLMPreprocessor(
**self.init_kwargs,
add_start_token=False,
add_end_token=False,
)
x, y, sw = preprocessor(input_data)
self.assertAllEqual(
x["token_ids"], [[1, 3, 4, 2, 5, 0, 0, 0, 0, 0]] * 4
)
self.assertAllEqual(
x["padding_mask"], [[1, 1, 1, 1, 1, 1, 1, 1, 1, 1]] * 4
)
self.assertAllEqual(y, [[3, 4, 2, 5, 0, 0, 0, 0, 0, 0]] * 4)
self.assertAllEqual(sw, [[1, 1, 1, 1, 1, 1, 1, 1, 1, 1]] * 4)

def test_generate_preprocess(self):
input_data = "airplane at airport"
preprocessor = Phi4CausalLMPreprocessor(**self.init_kwargs)
x = preprocessor.generate_preprocess(input_data)
self.assertAllEqual(x["token_ids"], [6, 1, 3, 4, 2, 5, 0, 0, 0, 0])
self.assertAllEqual(x["padding_mask"], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1])

def test_generate_postprocess(self):
input_data = {
"token_ids": [1, 3, 4, 2, 5, 3, 9, 7, 11, 0],
"padding_mask": [1, 1, 1, 1, 1, 0, 0, 0, 0, 0],
}
preprocessor = Phi4CausalLMPreprocessor(**self.init_kwargs)
x = preprocessor.generate_postprocess(input_data)
self.assertAllEqual(x, "airplane at airport")

@pytest.mark.extra_large
def test_all_presets(self):
for preset in Phi4CausalLMPreprocessor.presets:
self.run_preset_test(
cls=Phi4CausalLMPreprocessor,
preset=preset,
input_data=self.input_data,
)
Loading
Loading