Skip to content

Commit

Permalink
Merge branch 'main' into integrate-swanlab
Browse files — browse the repository at this point in the history
  • Loading branch information
ShaohonChen authored Feb 28, 2025
2 parents 2d9b7e2 + 222505c commit 012a976
Show file tree
Hide file tree
Showing 11 changed files with 587 additions and 32 deletions.
8 changes: 8 additions & 0 deletions .circleci/config.yml
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,14 @@ jobs:
parallelism: 1
steps:
- checkout
- run: if [[ "$CIRCLE_PULL_REQUEST" == "" && "$CIRCLE_BRANCH" != "main" && "$CIRCLE_BRANCH" != *-release ]]; then echo "Not a PR, not the main branch and not a release branch, skip test!"; circleci-agent step halt; fi
- run: 'curl -L -H "Accept: application/vnd.github+json" -H "X-GitHub-Api-Version: 2022-11-28" https://api.github.com/repos/$CIRCLE_PROJECT_USERNAME/$CIRCLE_PROJECT_REPONAME/pulls/${CIRCLE_PULL_REQUEST##*/} >> github.txt'
- run: cat github.txt
- run: (python3 -c 'import json; from datetime import datetime; fp = open("github.txt"); data = json.load(fp); fp.close(); f = "%Y-%m-%dT%H:%M:%SZ"; created = datetime.strptime(data["created_at"], f); updated = datetime.strptime(data["updated_at"], f); s = (updated - created).total_seconds(); print(int(s))' || true) > elapsed.txt
- run: if [ "$(cat elapsed.txt)" == "" ]; then echo 60 > elapsed.txt; fi
- run: cat elapsed.txt
- run: if [ "$(cat elapsed.txt)" -lt "30" ]; then echo "PR is just opened, wait some actions from GitHub"; sleep 30; fi
- run: 'if grep -q "\"draft\": true," github.txt; then echo "draft mode, skip test!"; circleci-agent step halt; fi'
- run: uv pip install -U -e .
- run: echo 'export "GIT_COMMIT_MESSAGE=$(git show -s --format=%s)"' >> "$BASH_ENV" && source "$BASH_ENV"
- run: mkdir -p test_preparation
Expand Down
25 changes: 25 additions & 0 deletions .github/workflows/change_pr_to_draft.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
# Workflow: automatically convert every newly opened (or reopened) pull
# request back to draft state, and leave a comment telling the author how
# to mark it ready for review.
name: Change PR to draft

on:
  # `pull_request_target` (not `pull_request`) so the job runs with the base
  # repository's token and has permission to modify PRs from forks.
  pull_request_target:
    types: [opened, reopened]

jobs:
  convert_pr_to_draft:
    runs-on: ubuntu-22.04
    name: Convert PR to draft
    permissions:
      pull-requests: write
      contents: write
    # Skip PRs that were already opened as drafts — nothing to convert.
    if: github.event.pull_request.draft == false
    steps:
      - name: Convert PR to draft
        shell: bash
        env:
          PR_NUMBER: ${{ github.event.number }}
          GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
          REPO: ${{ github.repository }}
        run: |
          echo $PR_NUMBER
          # `gh pr ready --undo` flips a ready PR back into draft state.
          gh pr ready $PR_NUMBER --repo $REPO --undo
          gh pr comment $PR_NUMBER --repo $REPO --body "Hi 👋, thank you for opening this pull request! The pull request is converted to draft by default. When it is ready for review, please click the \`Ready for review\` button (at the bottom of the PR page)."
15 changes: 13 additions & 2 deletions src/transformers/cache_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import json
import os
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple, Union
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union

import torch
from packaging import version
Expand Down Expand Up @@ -358,12 +358,23 @@ class DynamicCache(Cache):
```
"""

def __init__(self) -> None:
def __init__(self, _distributed_cache_data: Optional[Iterable] = None) -> None:
    """Initialize an empty dynamic cache, optionally pre-filled from distributed replicas.

    Args:
        _distributed_cache_data (`Iterable`, *optional*):
            Pre-gathered per-layer `(key_states, value_states)` pairs used to
            seed the cache. See the note below for why this argument exists.
    """
    super().__init__()
    self._seen_tokens = 0  # Used in `generate` to keep tally of how many tokens the cache has seen
    self.key_cache: List[torch.Tensor] = []
    self.value_cache: List[torch.Tensor] = []

    # `_distributed_cache_data` was originally added for compatibility with `torch.distributed` (DDP). See #36121
    # and #36373 for more information. In a nutshell, it is `map(gather_map, zip(*caches))`, i.e. each item in the
    # iterable contains the key and value states for a layer gathered across replicas by torch.distributed
    # (shape=[global batch size, num_heads, seq_len, head_dim]).
    # WARNING: `_distributed_cache_data` must be the first argument in `__init__`, otherwise we'll break
    # compatibility. The name of the argument doesn't matter.
    if _distributed_cache_data is not None:
        for key_states, value_states in _distributed_cache_data:
            self.key_cache.append(key_states)
            self.value_cache.append(value_states)

def __getitem__(self, layer_idx: int) -> List[Tuple[torch.Tensor]]:
"""
Support for backwards-compatible `past_key_value` indexing, e.g. `past_key_value[0][0].shape[2]` to get the
Expand Down
27 changes: 16 additions & 11 deletions src/transformers/image_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,17 +74,6 @@
PILImageResampling.LANCZOS: InterpolationMode.LANCZOS,
}

if is_decord_available():
from decord import VideoReader, cpu

if is_av_available():
import av

if is_cv2_available():
import cv2

if is_yt_dlp_available():
from yt_dlp import YoutubeDL

if TYPE_CHECKING:
if is_torch_available():
Expand Down Expand Up @@ -608,6 +597,10 @@ def sample_indices_fn(metadata, **kwargs):
- Numpy array of frames in RGB (shape: [num_frames, height, width, 3]).
- `VideoMetadata` object.
"""
# Lazy import cv2
requires_backends(read_video_opencv, ["cv2"])
import cv2

video = cv2.VideoCapture(video_path)
total_num_frames = int(video.get(cv2.CAP_PROP_FRAME_COUNT))
video_fps = video.get(cv2.CAP_PROP_FPS)
Expand Down Expand Up @@ -661,6 +654,10 @@ def sample_indices_fn(metadata, **kwargs):
- Numpy array of frames in RGB (shape: [num_frames, height, width, 3]).
- `VideoMetadata` object.
"""
# Lazy import from decord
requires_backends(read_video_decord, ["decord"])
from decord import VideoReader, cpu

vr = VideoReader(uri=video_path, ctx=cpu(0)) # decord has problems with gpu
video_fps = vr.get_avg_fps()
total_num_frames = len(vr)
Expand Down Expand Up @@ -700,6 +697,10 @@ def sample_indices_fn(metadata, **kwargs):
- Numpy array of frames in RGB (shape: [num_frames, height, width, 3]).
- `VideoMetadata` object.
"""
# Lazy import av
requires_backends(read_video_pyav, ["av"])
import av

container = av.open(video_path)
total_num_frames = container.streams.video[0].frames
video_fps = container.streams.video[0].average_rate # should we better use `av_guess_frame_rate`?
Expand Down Expand Up @@ -834,6 +835,10 @@ def sample_indices_fn_func(metadata, **fn_kwargs):
if video.startswith("https://www.youtube.com") or video.startswith("http://www.youtube.com"):
if not is_yt_dlp_available():
raise ImportError("To load a video from YouTube url you have to install `yt_dlp` first.")
# Lazy import from yt_dlp
requires_backends(load_video, ["yt_dlp"])
from yt_dlp import YoutubeDL

buffer = BytesIO()
with redirect_stdout(buffer), YoutubeDL() as f:
f.download([video])
Expand Down
Loading

0 comments on commit 012a976

Please sign in to comment.