diff --git a/__init__.py b/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/picotron/__init__.py b/picotron/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/picotron/checkpoint.py b/picotron/checkpoint.py
index b063fd6..a612574 100644
--- a/picotron/checkpoint.py
+++ b/picotron/checkpoint.py
@@ -7,10 +7,10 @@
 from safetensors import safe_open
 import contextlib
 
-from picotron.utils import assert_no_meta_tensors, print
-import picotron.process_group_manager as pgm
+from .utils import assert_no_meta_tensors, print
+from . import process_group_manager as pgm
 
-from picotron.pipeline_parallel.pipeline_parallel import PipelineParallel
+from .pipeline_parallel.pipeline_parallel import PipelineParallel
 
 @contextlib.contextmanager
 def init_model_with_dematerialized_weights(include_buffers: bool = False):
diff --git a/picotron/context_parallel/__init__.py b/picotron/context_parallel/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/picotron/context_parallel/context_parallel.py b/picotron/context_parallel/context_parallel.py
index 3c16831..9722542 100644
--- a/picotron/context_parallel/context_parallel.py
+++ b/picotron/context_parallel/context_parallel.py
@@ -4,8 +4,8 @@
 import torch.nn.functional as F
 from typing import Any, Optional, Tuple
 
-import picotron.process_group_manager as pgm
-from picotron.context_parallel.cp_communications import ContextCommunicate
+from .. import process_group_manager as pgm
+from .cp_communications import ContextCommunicate
 
 def apply_context_parallel(model):
     os.environ["CONTEXT_PARALLEL"] = "1" if pgm.process_group_manager.cp_world_size > 1 else "0"
diff --git a/picotron/context_parallel/cp_communications.py b/picotron/context_parallel/cp_communications.py
index 5b3c0ff..905027a 100644
--- a/picotron/context_parallel/cp_communications.py
+++ b/picotron/context_parallel/cp_communications.py
@@ -3,7 +3,7 @@
 from torch import distributed as dist
 
 from typing import List
 
-import picotron.process_group_manager as pgm
+from .. import process_group_manager as pgm
 
 STEP, VERBOSE = 0, os.environ.get("VERBOSE", "0") == "1"
diff --git a/picotron/data.py b/picotron/data.py
index e92bb36..c1a89f9 100644
--- a/picotron/data.py
+++ b/picotron/data.py
@@ -5,9 +5,9 @@
 from functools import partial
 from datasets import Features, Sequence, Value, load_dataset
 from transformers import AutoTokenizer
-from picotron.utils import print
+from .utils import print
 
-import picotron.process_group_manager as pgm
+from . import process_group_manager as pgm
 
 class MicroBatchDataLoader(DataLoader):
     def __init__(self, micro_batch_size, seq_length, dataset_name, tokenizer_name, num_workers, num_proc, grad_acc_steps, device, subset_name=None, split="train", num_samples=None, pin_memory=True):
diff --git a/picotron/data_parallel/__init__.py b/picotron/data_parallel/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/picotron/data_parallel/data_parallel.py b/picotron/data_parallel/data_parallel.py
index a74ae1b..5a63600 100644
--- a/picotron/data_parallel/data_parallel.py
+++ b/picotron/data_parallel/data_parallel.py
@@ -4,8 +4,8 @@
 from torch import nn
 from torch.autograd import Variable
 
-from picotron.data_parallel.bucket import BucketManager
-import picotron.process_group_manager as pgm
+from .bucket import BucketManager
+from .. import process_group_manager as pgm
 
 class DataParallelNaive(nn.Module):
     """
diff --git a/picotron/model.py b/picotron/model.py
index 2538785..3ee3d72 100644
--- a/picotron/model.py
+++ b/picotron/model.py
@@ -3,11 +3,11 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from picotron.context_parallel import context_parallel
+from .context_parallel import context_parallel
 from flash_attn.flash_attn_interface import flash_attn_func
 from flash_attn.layers.rotary import apply_rotary_emb
 from flash_attn.ops.triton.layer_norm import layer_norm_fn
-import picotron.process_group_manager as pgm
+from . import process_group_manager as pgm
 
 def apply_rotary_pos_emb(x, cos, sin):
     #TODO: Maybe do class RotaryEmbedding(nn.Module) later
diff --git a/picotron/pipeline_parallel/__init__.py b/picotron/pipeline_parallel/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/picotron/pipeline_parallel/pipeline_parallel.py b/picotron/pipeline_parallel/pipeline_parallel.py
index 5bd16bb..78c772b 100644
--- a/picotron/pipeline_parallel/pipeline_parallel.py
+++ b/picotron/pipeline_parallel/pipeline_parallel.py
@@ -2,8 +2,8 @@
 import torch.nn as nn
 import torch.nn.functional as F
 
-import picotron.process_group_manager as pgm
-from picotron.pipeline_parallel.pp_communications import pipeline_communicate, bidirectional_pipeline_communicate
+from .. import process_group_manager as pgm
+from .pp_communications import pipeline_communicate, bidirectional_pipeline_communicate
 
 class PipelineParallel(nn.Module):
     def __init__(self, model, config):
diff --git a/picotron/pipeline_parallel/pp_communications.py b/picotron/pipeline_parallel/pp_communications.py
index fcbd31b..41b596b 100644
--- a/picotron/pipeline_parallel/pp_communications.py
+++ b/picotron/pipeline_parallel/pp_communications.py
@@ -1,7 +1,7 @@
 import os
 import torch
 import torch.distributed as dist
-import picotron.process_group_manager as pgm
+from .. import process_group_manager as pgm
 
 STEP, VERBOSE = 0, os.environ.get("VERBOSE", "0") == "1"
 
diff --git a/picotron/tensor_parallel/__init__.py b/picotron/tensor_parallel/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/picotron/tensor_parallel/tensor_parallel.py b/picotron/tensor_parallel/tensor_parallel.py
index 79a4bdb..9490aed 100644
--- a/picotron/tensor_parallel/tensor_parallel.py
+++ b/picotron/tensor_parallel/tensor_parallel.py
@@ -3,8 +3,8 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 
-import picotron.process_group_manager as pgm
-from picotron.tensor_parallel.tp_communications import ReduceFromModelParallelRegion, GatherFromModelParallelRegion, linear_with_all_reduce, linear_with_async_all_reduce
+from .. import process_group_manager as pgm
+from .tp_communications import ReduceFromModelParallelRegion, GatherFromModelParallelRegion, linear_with_all_reduce, linear_with_async_all_reduce
 
 def apply_tensor_parallel(model):
diff --git a/picotron/tensor_parallel/tp_communications.py b/picotron/tensor_parallel/tp_communications.py
index d5345d6..4e05b61 100644
--- a/picotron/tensor_parallel/tp_communications.py
+++ b/picotron/tensor_parallel/tp_communications.py
@@ -1,6 +1,6 @@
 import torch.distributed as dist
 import torch
-import picotron.process_group_manager as pgm
+from .. import process_group_manager as pgm
 import torch.nn.functional as F
 from typing import Tuple
 
diff --git a/picotron/utils.py b/picotron/utils.py
index c61e07d..2a4c4be 100644
--- a/picotron/utils.py
+++ b/picotron/utils.py
@@ -3,10 +3,8 @@
 import numpy as np
 import builtins
 import fcntl
-
+from . import process_group_manager as pgm
 import huggingface_hub
-
-import picotron.process_group_manager as pgm
 import torch, torch.distributed as dist
 
 def print(*args, is_print_rank=True, **kwargs):
diff --git a/tests/__init__.py b/tests/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/train.py b/train.py
index 19cb2ac..b6548a5 100644
--- a/train.py
+++ b/train.py
@@ -12,18 +12,17 @@
 import torch, torch.distributed as dist
 from torch.optim import AdamW
 from transformers import AutoConfig
-from picotron.context_parallel.context_parallel import apply_context_parallel
-from picotron.tensor_parallel.tensor_parallel import apply_tensor_parallel
-import picotron.process_group_manager as pgm
-from picotron.utils import average_loss_across_dp_cp_ranks, set_all_seed, print, to_readable_format, get_mfu, get_num_params
-from picotron.checkpoint import CheckpointManager
-from picotron.checkpoint import init_model_with_dematerialized_weights, init_model_with_materialized_weights
-from picotron.data import MicroBatchDataLoader
-from picotron.process_group_manager import setup_process_group_manager
-from picotron.pipeline_parallel.pipeline_parallel import train_step_pipeline_1f1b, train_step_pipeline_afab, PipelineParallel
-from picotron.data_parallel.data_parallel import DataParallelBucket
-from picotron.model import Llama
-from picotron.utils import download_model
+from .picotron.context_parallel.context_parallel import apply_context_parallel
+from .picotron.tensor_parallel.tensor_parallel import apply_tensor_parallel
+from .picotron import process_group_manager as pgm
+from .picotron.utils import average_loss_across_dp_cp_ranks, set_all_seed, print, to_readable_format, get_mfu, get_num_params
+from .picotron.checkpoint import CheckpointManager
+from .picotron.checkpoint import init_model_with_dematerialized_weights, init_model_with_materialized_weights
+from .picotron.data import MicroBatchDataLoader
+from .picotron.process_group_manager import setup_process_group_manager
+from .picotron.pipeline_parallel.pipeline_parallel import train_step_pipeline_1f1b, train_step_pipeline_afab, PipelineParallel
+from .picotron.data_parallel.data_parallel import DataParallelBucket
+from .picotron.model import Llama
 import wandb
 
 def train_step(model, data_loader, device):