Conversation

@Cui-yshoho
Contributor

What does this PR do?

Fixes # (issue)

Adds # (feature)

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline?
  • Did you make sure to update the documentation with your changes? E.g. record bug fixes or new features in What's New. Here are the documentation guidelines.
  • Did you build and run the code without any errors?
  • Did you report the running environment (NPU type/MS version) and performance in the doc? (better record it for data loading, model inference, or training tasks)
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@xxx

@Cui-yshoho Cui-yshoho requested a review from vigo999 as a code owner November 7, 2025 10:06
@gemini-code-assist
Contributor

Summary of Changes

Hello @Cui-yshoho, I'm Gemini Code Assist¹! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request introduces the Arcee model family into the mindone/transformers library. It provides a complete implementation of the Arcee architecture, including its configuration and various model heads for tasks such as causal language modeling, sequence classification, question answering, and token classification. The integration also includes updates to the auto-modeling system and new unit tests to validate the model's functionality and numerical precision.

Highlights

  • New Model Integration: The Arcee model architecture, including its core components and various task-specific heads (CausalLM, Sequence Classification, Question Answering, Token Classification), has been added to the mindone/transformers library.
  • Modular Implementation: The Arcee model's configuration and core components are defined in modular_arcee.py, which then generates modeling_arcee.py, ensuring a structured and maintainable codebase.
  • Auto-Modeling Support: The new Arcee models are integrated into the auto-modeling framework, allowing for easy instantiation and usage across different tasks.
  • Comprehensive Testing: Dedicated unit tests have been added for the Arcee model, including precision comparison tests against PyTorch implementations, to ensure correctness and reliability.
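
The auto-modeling support mentioned above can be exercised roughly as follows. This is a minimal sketch, assuming the mindone.transformers auto classes mirror the upstream transformers API and using the arcee-ai/AFM-4.5B checkpoint referenced elsewhere in this PR; it is an illustration, not code from the PR itself.

```python
# Minimal usage sketch (assumptions: mindone.transformers auto classes mirror
# the upstream transformers API; the arcee-ai/AFM-4.5B checkpoint is reachable).
import mindspore as ms
from transformers import AutoTokenizer

from mindone.transformers import AutoModelForCausalLM

tokenizer = AutoTokenizer.from_pretrained("arcee-ai/AFM-4.5B")
model = AutoModelForCausalLM.from_pretrained("arcee-ai/AFM-4.5B")

# Tokenize with numpy outputs, then hand MindSpore tensors to the model.
inputs = tokenizer("Hello, Arcee!", return_tensors="np")
output_ids = model.generate(input_ids=ms.Tensor(inputs["input_ids"]), max_new_tokens=16)
print(tokenizer.decode(output_ids[0].asnumpy(), skip_special_tokens=True))
```
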
Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page; here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point by creating a comment using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands on the current page.

  • Code Review (/gemini review): Performs a code review for the current pull request in its current state.
  • Pull Request Summary (/gemini summary): Provides a summary of the current pull request in its current state.
  • Comment (@gemini-code-assist): Responds in comments when explicitly tagged, both in pull request comments and review comments.
  • Help (/gemini help): Displays a list of available commands.

Customization

To customize the Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by adding files to a .gemini/ folder at the base of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counterproductive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving feedback about your experience with Gemini Code Assist for GitHub and other Google products, sign up here.

You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.

@gemini-code-assist gemini-code-assist bot left a comment

Code Review

This pull request introduces the Arcee model, a new transformer architecture. The changes include the model implementation, integration with auto-classes, and corresponding tests. My review has identified a few issues that need attention. Most critically, there's a structural problem with the model's configuration file being missing, while an incorrect and unused modular_arcee.py file is present. Additionally, there's a potential bug in parameter passing within the model, a misleading example in a docstring, and a style issue with a wildcard import. Addressing these points will improve the correctness, clarity, and maintainability of the new model.

Comment on lines +1 to +228 of mindone/transformers/models/arcee/modular_arcee.py
# coding=utf-8
# Copyright 2025 Arcee AI and the HuggingFace Inc. team. All rights reserved.
#
# This code is adapted from https://github.com/huggingface/transformers
# with modifications to run transformers on mindspore.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""MindSpore Arcee model."""

from transformers.models.llama.configuration_llama import LlamaConfig
from transformers.utils import auto_docstring, logging

from ..llama.modeling_llama import (
    LlamaForCausalLM,
    LlamaForQuestionAnswering,
    LlamaForSequenceClassification,
    LlamaForTokenClassification,
)
from ..nemotron.modeling_nemotron import NemotronMLP

logger = logging.get_logger(__name__)


class ArceeConfig(LlamaConfig):
r"""
This is the configuration class to store the configuration of a [`ArceeModel`]. It is used to instantiate an Arcee
model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
defaults will yield a similar configuration to that of the AFM-4.5B-Base.

Pre-trained weights are available at
[arcee-ai/AFM-4.5B](https://huggingface.co/arcee-ai/AFM-4.5B)
and were used to build the examples below.

Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.

Args:
vocab_size (`int`, *optional*, defaults to 32000):
Vocabulary size of the Arcee model. Defines the number of different tokens that can be represented by the
`inputs_ids` passed when calling [`ArceeModel`]
hidden_size (`int`, *optional*, defaults to 2560):
Dimension of the hidden representations.
intermediate_size (`int`, *optional*, defaults to 18432):
Dimension of the MLP representations.
num_hidden_layers (`int`, *optional*, defaults to 32):
Number of hidden layers in the Transformer decoder.
num_attention_heads (`int`, *optional*, defaults to 32):
Number of attention heads for each attention layer in the Transformer decoder.
num_key_value_heads (`int`, *optional*):
This is the number of key_value heads that should be used to implement Grouped Query Attention. If
`num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
`num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
by meanpooling all the original heads within that group. For more details checkout [this
paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to
`num_attention_heads`.
hidden_act (`str` or `function`, *optional*, defaults to `"relu2"`):
The non-linear activation function (function or string) in the decoder.
max_position_embeddings (`int`, *optional*, defaults to 4096):
The maximum sequence length that this model might ever be used with. AFM-4.5B-Base supports up to 16384 tokens.
initializer_range (`float`, *optional*, defaults to 0.02):
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
rms_norm_eps (`float`, *optional*, defaults to 1e-05):
The epsilon used by the rms normalization layers.
use_cache (`bool`, *optional*, defaults to `True`):
Whether or not the model should return the last key/values attentions (not used by all models). Only
relevant if `config.is_decoder=True`.
pad_token_id (`int`, *optional*):
Padding token id.
bos_token_id (`int`, *optional*, defaults to 128000):
Beginning of stream token id.
eos_token_id (`int`, *optional*, defaults to 128001):
End of stream token id.
tie_word_embeddings (`bool`, *optional*, defaults to `False`):
Whether to tie weight embeddings
rope_theta (`float`, *optional*, defaults to 10000.0):
The base period of the RoPE embeddings.
rope_scaling (`Dict`, *optional*):
Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type
and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value
accordingly.
Expected contents:
`rope_type` (`str`):
The sub-variant of RoPE to use. Can be one of ['default', 'yarn'], with 'default' being the original RoPE implementation.
`factor` (`float`, *optional*):
Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In
most scaling types, a `factor` of x will enable the model to handle sequences of length x *
original maximum pre-trained length.
`original_max_position_embeddings` (`int`, *optional*):
Used with 'yarn'. The original max position embeddings used during pretraining.
`attention_factor` (`float`, *optional*):
Used with 'yarn'. The scaling factor to be applied on the attention computation. If unspecified,
it defaults to value recommended by the implementation, using the `factor` field to infer the suggested value.
`beta_fast` (`float`, *optional*):
Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear
ramp function. If unspecified, it defaults to 32.
`beta_slow` (`float`, *optional*):
Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear
ramp function. If unspecified, it defaults to 1.
attention_bias (`bool`, *optional*, defaults to `False`):
Whether to use a bias in the query, key, value and output projection layers during self-attention.
attention_dropout (`float`, *optional*, defaults to 0.0):
The dropout ratio for the attention probabilities.
mlp_bias (`bool`, *optional*, defaults to `False`):
Whether to use a bias in up_proj, down_proj and gate_proj layers in the MLP layers.
head_dim (`int`, *optional*):
The attention head dimension. If None, it will default to hidden_size // num_attention_heads

```python
>>> from mindone.transformers import ArceeModel
>>> from transformers import ArceeConfig

>>> # Initializing an Arcee AFM-4.5B-Base style configuration
>>> configuration = ArceeConfig()

>>> # Initializing a model from the AFM-4.5B-Base style configuration
>>> model = ArceeModel(configuration)

>>> # Accessing the model configuration
>>> configuration = model.config
```"""

model_type = "arcee"
base_model_tp_plan = {
"layers.*.self_attn.q_proj": "colwise",
"layers.*.self_attn.k_proj": "colwise",
"layers.*.self_attn.v_proj": "colwise",
"layers.*.self_attn.o_proj": "rowwise",
"layers.*.mlp.up_proj": "colwise",
"layers.*.mlp.down_proj": "rowwise",
}

    def __init__(
        self,
        vocab_size=32000,
        hidden_size=2560,
        intermediate_size=18432,
        num_hidden_layers=32,
        num_attention_heads=32,
        num_key_value_heads=None,
        hidden_act="relu2",
        max_position_embeddings=4096,
        initializer_range=0.02,
        rms_norm_eps=1e-5,
        use_cache=True,
        pad_token_id=None,
        bos_token_id=128000,
        eos_token_id=128001,
        tie_word_embeddings=False,
        rope_theta=10000.0,
        rope_scaling=None,
        attention_bias=False,
        attention_dropout=0.0,
        mlp_bias=False,
        head_dim=None,
        **kwargs,
    ):
        super().__init__(
            vocab_size=vocab_size,
            hidden_size=hidden_size,
            intermediate_size=intermediate_size,
            num_hidden_layers=num_hidden_layers,
            num_attention_heads=num_attention_heads,
            num_key_value_heads=num_key_value_heads,
            hidden_act=hidden_act,
            max_position_embeddings=max_position_embeddings,
            initializer_range=initializer_range,
            rms_norm_eps=rms_norm_eps,
            use_cache=use_cache,
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            tie_word_embeddings=tie_word_embeddings,
            rope_theta=rope_theta,
            rope_scaling=rope_scaling,
            attention_bias=attention_bias,
            attention_dropout=attention_dropout,
            mlp_bias=mlp_bias,
            head_dim=head_dim,
            **kwargs,
        )

        del self.pretraining_tp


class ArceeMLP(NemotronMLP):
    pass


@auto_docstring(checkpoint="arcee-ai/AFM-4.5B")
class ArceeForCausalLM(LlamaForCausalLM):
    pass


@auto_docstring(checkpoint="arcee-ai/AFM-4.5B")
class ArceeForSequenceClassification(LlamaForSequenceClassification):
    pass


@auto_docstring(checkpoint="arcee-ai/AFM-4.5B")
class ArceeForQuestionAnswering(LlamaForQuestionAnswering):
    pass


@auto_docstring(checkpoint="arcee-ai/AFM-4.5B")
class ArceeForTokenClassification(LlamaForTokenClassification):
    pass


__all__ = [
    "ArceeConfig",
    "ArceeForCausalLM",
    "ArceeForQuestionAnswering",
    "ArceeForSequenceClassification",
    "ArceeForTokenClassification",
    "ArceeModel",  # noqa: F822
    "ArceePreTrainedModel",  # noqa: F822
]

Severity: critical

This file seems to be incorrect and is causing several issues:

  1. Missing Configuration: The modeling_arcee.py file imports ArceeConfig from transformers.models.arcee.configuration_arcee, but this configuration file is missing from the pull request. The ArceeConfig class defined in this modular_arcee.py file seems to be what's needed.
  2. Contradictory Definitions: This file defines model classes like ArceeForCausalLM that inherit from LlamaForCausalLM. This contradicts the full implementation provided in modeling_arcee.py.
  3. Unused File: This file is not imported or used anywhere in the codebase, while modeling_arcee.py is. The warning at the top of modeling_arcee.py about it being auto-generated seems to be a red herring, as this modular file is not its source.

To fix this, I recommend the following:

  1. Rename this file to mindone/transformers/models/arcee/configuration_arcee.py.
  2. Remove all model definitions (e.g., ArceeMLP, ArceeForCausalLM) from it, keeping only the ArceeConfig class.
  3. Update the import in mindone/transformers/models/arcee/modeling_arcee.py to from .configuration_arcee import ArceeConfig.
  4. Export ArceeConfig from mindone/transformers/models/arcee/__init__.py and mindone/transformers/__init__.py to make it accessible.
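
As a rough illustration of steps 1 to 4, the resulting configuration_arcee.py could look like the sketch below. Only a few representative constructor arguments are shown; the full signature is the one already present in this PR's ArceeConfig, and modeling_arcee.py would then import the class with `from .configuration_arcee import ArceeConfig`.

```python
# Sketch of mindone/transformers/models/arcee/configuration_arcee.py after the
# suggested rename: only the configuration class remains here; the model heads
# stay in modeling_arcee.py. Constructor arguments are abbreviated for brevity.
from transformers.models.llama.configuration_llama import LlamaConfig


class ArceeConfig(LlamaConfig):
    """Configuration for Arcee models; defaults follow AFM-4.5B-Base."""

    model_type = "arcee"

    def __init__(self, vocab_size=32000, hidden_size=2560, hidden_act="relu2", **kwargs):
        super().__init__(
            vocab_size=vocab_size,
            hidden_size=hidden_size,
            hidden_act=hidden_act,
            **kwargs,
        )
        # Arcee does not use LlamaConfig's pretraining_tp attribute.
        del self.pretraining_tp
```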

# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .modeling_arcee import *

Severity: medium

For better code clarity and to avoid polluting the namespace, it's recommended to use explicit imports instead of a wildcard import. This makes it clear which names are being imported from the module.

Suggested change
from .modeling_arcee import *
from .modeling_arcee import (
    ArceeForCausalLM,
    ArceeForQuestionAnswering,
    ArceeForSequenceClassification,
    ArceeForTokenClassification,
    ArceeModel,
    ArceePreTrainedModel,
)

        hidden_states, _ = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            position_ids=position_ids,

Severity: medium

The position_ids parameter is passed to self.self_attn, but the ArceeAttention.construct method does not accept it in its signature. The rotary position embeddings are already computed in the ArceeModel and passed down as position_embeddings. Passing an unused position_ids parameter can be confusing and may lead to bugs if the signature of the called method changes. It's better to remove it from this call.
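
A minimal sketch of what the cleaned-up call could look like, using the names from the diff fragment above (this is a fragment of the decoder layer, not a complete method):

```python
# Drop the unused position_ids keyword; rotary embeddings already reach the
# attention module via the position_embeddings argument computed in ArceeModel.
hidden_states, _ = self.self_attn(
    hidden_states=hidden_states,
    attention_mask=attention_mask,
    # ... remaining keyword arguments from the original call, minus position_ids ...
)
```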

Comment on lines +446 to +448
>>> model = ArceeForCausalLM.from_pretrained("meta-arcee/Arcee-2-7b-hf")
>>> tokenizer = AutoTokenizer.from_pretrained("meta-arcee/Arcee-2-7b-hf")

Severity: medium

The model name in the example docstring appears to be a copy-paste error from a different model. It should be consistent with the checkpoint name arcee-ai/AFM-4.5B specified in the @auto_docstring decorator to avoid confusion for users.

Suggested change
>>> model = ArceeForCausalLM.from_pretrained("meta-arcee/Arcee-2-7b-hf")
>>> tokenizer = AutoTokenizer.from_pretrained("meta-arcee/Arcee-2-7b-hf")
>>> model = ArceeForCausalLM.from_pretrained("arcee-ai/AFM-4.5B")
>>> tokenizer = AutoTokenizer.from_pretrained("arcee-ai/AFM-4.5B")
